Coverage for postrfp/web/suxint/extractors.py: 100%
207 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-22 21:34 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-22 21:34 +0000
1from typing import Callable, Any
2import re
4from pydantic import ValidationError
5from webob import Request
7from postrfp.model.exc import ValidationFailure
8from postrfp.web.ext.openapi_types import openapi_type_mapping
# Lower-cased string values that a "boolean" typed argument treats as True;
# any other string is treated as False.
truthy_strings = {"1", "true", "yes"}
14class ArgExtractor(object):
15 """
16 Base class for Argument Extractor classes
18 Instances of ArgExtractor are callables that extract a value
19 from an HTTP request
21 e.g. GetArg('name') produces a callable that knows how to extract the
22 GET parameter 'name' from an http request
23 """
25 swagger_in: str | None = None
26 swagger_required = True
28 def __init__(
29 self,
30 arg_name: str,
31 validator: Callable | None = None,
32 doc: str | None = None,
33 default: Any = None,
34 arg_type: str = "int",
35 converter: Callable | None = None,
36 enum_values: set | list | tuple | None = None,
37 required: bool | None = None,
38 ) -> None:
39 self.arg_name = arg_name
40 self.doc = doc
41 self.default = default
42 self.arg_type = arg_type
43 self.converter = converter
44 self.required = required
45 if enum_values is not None and not isinstance(enum_values, (set, list, tuple)):
46 raise ValueError("enum_values must be a set, list or tuple")
47 self.enum_values = enum_values
48 if callable(validator):
49 self.validator: Callable | None = validator
50 elif validator is not None:
51 raise ValueError("validator, if given, must be a callable")
52 # subclasses might define a validator, don't overwrite it
53 if not hasattr(self, "validator"):
54 self.validator = None
56 def validate(self, value: Any) -> None:
57 if self.validator:
58 if not self.validator(value):
59 raise ValueError("%s is not a valid value" % value)
61 def print_enum(self) -> str | None:
62 """Return a string representation of the enum values for documentation"""
63 if self.enum_values is None:
64 return None
65 return ", ".join(str(v) for v in self.enum_values)
67 def validate_enums(self, value: Any) -> None:
68 if self.enum_values is not None:
69 if isinstance(value, (list, set, tuple)):
70 for v in value:
71 if v not in self.enum_values:
72 enum_str = self.print_enum() or "[]"
73 raise ValueError(f"Value -{v}- is not one of {enum_str}")
74 elif value not in self.enum_values:
75 raise ValueError(f"Value -{value}- is not one of {self.print_enum()}")
77 def typed(self, value: Any) -> Any:
78 if value is None:
79 return None
80 if not isinstance(value, str):
81 value = str(value)
82 value = value.strip()
83 if self.arg_type == "int":
84 return int(value)
85 if self.arg_type == "float":
86 return float(value)
87 if self.arg_type == "boolean":
88 return value.lower() in truthy_strings
90 # lstrip is a workaround to enable strings as path arguments
91 # to be recognised, e.g. /user/:bob/delete
92 return value.lstrip(":")
94 def __call__(self, request: Request) -> Any:
95 val = self.extract(request)
96 self.validate(val)
97 if self.converter is not None:
98 return self.converter(val)
99 return val
101 def extract(self, request: Request) -> Any:
102 raise NotImplementedError
104 def update_openapi_path_object(self, path_object: dict) -> None:
105 spec: dict[str, Any] = {
106 "in": self.swagger_in,
107 "required": self.swagger_required,
108 "name": self.doc_name,
109 }
110 if self.required is not None:
111 spec["required"] = self.required
113 if self.doc is not None:
114 spec["description"] = self.doc
115 spec["schema"] = dict(type=self.swagger_type)
116 if self.default is not None:
117 val = self.default
118 spec["schema"]["default"] = val
120 if getattr(self, "enum_values", None) is not None:
121 type_name = f"_{path_object['operationId']}{self.arg_name.capitalize()}"
122 spec["schema"]["$ref"] = f"#/components/schemas/{type_name}"
124 self.augment_openapi_path_object(spec)
126 path_object["parameters"].append(spec)
128 def augment_openapi_path_object(self, spec: dict) -> None:
129 """
130 Subclasses can override this method to add additional information to the
131 OpenAPI path object.
132 """
133 pass
135 @property
136 def swagger_type(self) -> str:
137 return openapi_type_mapping[self.arg_type]
139 @property
140 def doc_name(self) -> str:
141 """The name of this Parameter for documentation purposes"""
142 return self.arg_name
class PathArg(ArgExtractor):
    """
    Extracts a value from the URL path: the segment that follows
    ``/<arg_name>/`` in ``request.path_info``.

    The leading ``.*`` in the pattern is greedy, so when ``/<arg_name>/``
    occurs more than once the last occurrence is the one used.
    """

    swagger_in = "path"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Accepts the same arguments as ArgExtractor and compiles the path regex."""
        super().__init__(*args, **kwargs)
        # Escape the name so regex metacharacters in it are matched literally
        regexp = r".*/%s/([^/]+)/*" % re.escape(str(self.arg_name))
        self.path_regex = re.compile(regexp)

    def extract(self, request: Request) -> Any:
        """
        Return the typed value of the path segment following ``/<arg_name>/``.

        Raises:
            ValueError: if the path does not match, or the matched segment
                cannot be converted by ``typed``. The original error is
                attached as the cause.
        """
        try:
            match = self.path_regex.match(request.path_info)
            if match is None:
                raise ValueError("No match found")
            return self.typed(match.group(1))
        except Exception as exc:
            mess = (
                f"{self.__class__.__name__} Adaptor could not extract "
                f'value "{self.arg_name}" from path {request.path_info} '
                f"using regex {self.path_regex.pattern}"
            )
            raise ValueError(mess) from exc

    @property
    def doc_name(self) -> str:
        """The name of this Parameter for documentation purposes"""
        return self.arg_name + "_id"
class GetArg(ArgExtractor):
    """Extracts a single query-string (GET) parameter named ``arg_name``."""

    swagger_in = "query"
    swagger_required = False

    def extract(self, request: Request) -> Any:
        """
        Return the typed query parameter, or the typed default.

        The default is used when the parameter is absent (and not required)
        or when it is present but blank. Non-blank values are checked against
        ``enum_values`` as raw strings, before type conversion.

        Raises:
            ValueError: if the parameter is absent and ``required`` is truthy,
                if the value is not a permitted enum value, or if ``typed``
                cannot convert it.
        """
        if self.arg_name not in request.GET:
            if self.required:
                raise ValueError(f"Query param '{self.arg_name}' must be provided")
            return self.typed(self.default)
        value = request.GET[self.arg_name]
        if value.strip() == "":
            return self.typed(self.default)
        self.validate_enums(value)
        return self.typed(value)
class GetArgSet(ArgExtractor):
    """Extracts a Set (sequence) of argument values for the given GET arg"""

    swagger_in = "query"
    swagger_type = "array"
    swagger_required = False

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Accepts the ArgExtractor arguments plus three keyword-only extras:

        array_items_type: arg type of each item for the OpenAPI schema
            (default "str").
        min_items / max_items: optional bounds on the number of values.
        """
        self.array_items_type = kwargs.pop("array_items_type", "str")
        self.min_items = kwargs.pop("min_items", None)
        self.max_items = kwargs.pop("max_items", None)
        super().__init__(*args, **kwargs)

    def extract(self, request: Request) -> set[Any]:
        """
        Return the set of typed values supplied for this argument.

        Values are read from ``<arg_name>[]`` unless the plain ``<arg_name>``
        key is present in the query string, in which case that key is used.

        Raises:
            ValueError: if the number of values is outside the
                min_items/max_items bounds, if a value is not a permitted
                enum value, or if ``typed`` cannot convert a value.
        """
        arg_array_name = f"{self.arg_name}[]"
        if self.arg_name in request.GET:
            arg_array_name = self.arg_name
        values = request.GET.getall(arg_array_name)
        if self.min_items is not None and len(values) < self.min_items:
            raise ValueError(
                f"At least {self.min_items} {self.arg_name} parameters required"
            )
        if self.max_items is not None and len(values) > self.max_items:
            raise ValueError(
                f"No more than {self.max_items} {self.arg_name} parameters permitted"
            )
        self.validate_enums(values)
        return {self.typed(v) for v in values}

    def augment_openapi_path_object(self, spec: dict) -> None:
        """Describe the parameter as an array and add any item-count limits."""
        # Build the array schema first: it replaces spec["schema"] wholesale,
        # so the limits must be written afterwards or they would be discarded.
        if self.array_items_type is not None:
            spec["schema"] = {
                "type": "array",
                "items": {"type": openapi_type_mapping[self.array_items_type]},
            }

        if self.min_items is not None:
            spec["schema"]["minItems"] = self.min_items

        if self.max_items is not None:
            spec["schema"]["maxItems"] = self.max_items
234def _get_or_create_schema(
235 path_object: dict, mime_type: str = "application/x-www-form-urlencoded"
236) -> dict:
237 return (
238 path_object.setdefault("requestBody", {})
239 .setdefault("content", {})
240 .setdefault(mime_type, {})
241 .setdefault("schema", {})
242 )
class PostArg(ArgExtractor):
    """Extracts a single form (POST) field named ``arg_name``."""

    swagger_in = "formData"
    swagger_required = False
    # Plain class attributes replace the base-class properties of the same name
    swagger_type = "string"
    doc_name = "body"

    def extract(self, request: Request) -> Any:
        """
        Return the typed form value, falling back to ``default`` when the
        field is absent; None when neither yields a value.

        Raises:
            ValueError: if enum values are set and the typed value is not
                one of them, or if ``typed`` cannot convert the value.
        """
        raw = request.POST.get(self.arg_name, self.default)
        if raw is None:
            return None
        typed_value = self.typed(raw)
        if self.enum_values and typed_value not in self.enum_values:
            raise ValueError(f"{typed_value} is not in {self.enum_values}")
        return typed_value

    def update_openapi_path_object(self, path_object: dict) -> None:
        """Register this field as a property of the multipart/form-data body schema."""
        schema = _get_or_create_schema(path_object, mime_type="multipart/form-data")
        schema.setdefault("type", "object")
        field_spec: dict[str, Any] = {"type": self.swagger_type}
        if self.enum_values is not None:
            field_spec["enum"] = list(self.enum_values)
        schema.setdefault("properties", {})[self.arg_name] = field_spec
class PostFileArg(ArgExtractor):
    """Extracts an uploaded file field named ``arg_name`` from the POST data."""

    swagger_in = "formData"
    # Plain class attributes replace the base-class properties of the same name
    swagger_type = "string"
    doc_name = "body"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Accepts the ArgExtractor arguments; ``arg_type`` is always forced to "file"."""
        super().__init__(*args, **{**kwargs, "arg_type": "file"})

    def extract(self, request: Request) -> Any:
        """
        Return the raw POST value for the field without type conversion.

        Raises:
            ValueError: if the field is missing and ``required`` is truthy.
        """
        uploaded = request.POST.get(self.arg_name, None)
        if uploaded is None and self.required:
            raise ValueError(f"Post Parameter {self.arg_name} must be provided")
        return uploaded

    def update_openapi_path_object(self, path_object: dict) -> None:
        """Register this field as a binary string property of the multipart/form-data body schema."""
        schema = _get_or_create_schema(path_object, mime_type="multipart/form-data")
        schema.setdefault("type", "object")
        schema.setdefault("properties", {})[self.arg_name] = {
            "type": "string",
            "format": "binary",
        }
"""
    Pydantic ALIAS

    - alias should be the public name if different from the db column name
    - when dumping/serialising data for output from the database to the
      outside world, use model.model_dump(by_alias=True)
    - when parsing incoming data, use model.model_dump()

    by_alias == False (the default) means 'dump data for consumption by database'
    by_alias == True means 'dump data for use by API consumers' (default behaviour)

"""
class SchemaDocArg(ArgExtractor):
    """
    Extracts the JSON request body and validates it against a Pydantic
    model class (``schema_cls``).
    """

    swagger_in = "body"

    def __init__(self, schema_cls, as_dict=True, exclude_unset=False):
        """
        Args:
            schema_cls: Pydantic model class used to validate ``request.json``.
            as_dict: when True, ``extract`` returns ``model_dump()`` output;
                otherwise it returns the model instance itself.
            exclude_unset: forwarded to ``model_dump`` when ``as_dict`` is True.
        """
        super().__init__(None)
        self.schema_cls = schema_cls
        self.as_dict = as_dict
        self.exclude_unset = exclude_unset

    def extract(self, request):
        """
        Validate the JSON body and return it as a dict or model instance.

        Raises:
            ValidationFailure: when Pydantic validation fails; carries the
                stringified Pydantic error and a list with one readable
                message per validation error.
        """
        try:
            # NB exclude_unset=True is important for updates. If not set, or
            # False (the default here), all the pydantic Model attributes not
            # set will appear in the data passed to API methods - i.e. lots
            # of None values which can overwrite real data on update
            model = self.schema_cls.model_validate(request.json)
            if not self.as_dict:
                return model
            return model.model_dump(exclude_unset=self.exclude_unset)

        except ValidationError as exc:
            failures = []
            for detail in exc.errors():
                location = " > ".join(f"'{part}'" for part in detail["loc"])
                msg = detail["msg"].title()
                failures.append(f"Validation failed. {msg}: {location}")
            raise ValidationFailure(str(exc), failures)

    def update_openapi_path_object(self, path_obj):
        """Set the JSON request body of ``path_obj`` to a $ref to the schema class's component."""
        schema_ref = {"$ref": "#/components/schemas/%s" % self.schema_cls.__name__}
        path_obj["requestBody"] = {
            "content": {"application/json": {"schema": schema_ref}}
        }