Coverage for postrfp / web / suxint / extractors.py: 99%
225 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 01:35 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 01:35 +0000
1from typing import Callable, Any
2import re
4from pydantic import ValidationError
5from webob import Request
7from postrfp.model.exc import ValidationFailure
8from postrfp.web.ext.openapi_types import openapi_type_mapping
# Lower-cased string values that ArgExtractor.typed() treats as boolean True
# when arg_type is "boolean"; any other string is converted to False.
truthy_strings = {"true", "1", "yes"}
class ArgExtractor(object):
    """
    Base class for Argument Extractor classes

    Instances of ArgExtractor are callables that extract a value
    from an HTTP request

    e.g. GetArg('name') produces a callable that knows how to extract the
    GET parameter 'name' from an http request

    Subclasses implement ``extract``. Calling an instance runs ``extract``,
    then ``validate`` and finally the optional ``converter``.
    """

    # Value written to the "in" field of the OpenAPI parameter object
    swagger_in: str | None = None
    # Value of the OpenAPI "required" field unless ``required`` is given
    swagger_required = True

    def __init__(
        self,
        arg_name: str,
        validator: Callable | None = None,
        doc: str | None = None,
        default: Any = None,
        arg_type: str = "int",
        converter: Callable | None = None,
        enum_values: set | list | tuple | None = None,
        required: bool | None = None,
    ) -> None:
        """
        Store the extractor configuration.

        Args:
            arg_name: name of the request argument to extract.
            validator: optional callable; a falsy result for an extracted
                value makes ``validate`` raise ValueError.
            doc: optional description used in the OpenAPI parameter object.
            default: value used by subclasses when the argument is absent.
            arg_type: type name used by ``typed`` ("int", "float",
                "boolean"; anything else is handled as a string) and as the
                key for the OpenAPI type lookup.
            converter: optional callable applied to the validated value.
            enum_values: optional collection of permitted values.
            required: overrides ``swagger_required`` when not None.

        Raises:
            ValueError: if ``enum_values`` is not a set, list or tuple, or
                if ``validator`` is given but is not callable.
        """
        self.arg_name = arg_name
        self.doc = doc
        self.default = default
        self.arg_type = arg_type
        self.converter = converter
        self.required = required
        if enum_values is not None and not isinstance(enum_values, (set, list, tuple)):
            raise ValueError("enum_values must be a set, list or tuple")
        self.enum_values = enum_values
        if callable(validator):
            self.validator: Callable | None = validator
        elif validator is not None:
            raise ValueError("validator, if given, must be a callable")
        # subclasses might define a validator, don't overwrite it
        if not hasattr(self, "validator"):
            self.validator = None

    def validate(self, value: Any) -> None:
        """Raise ValueError if a validator is set and it rejects ``value``."""
        if self.validator and not self.validator(value):
            # An f-string is used rather than "%s" % value: the %-operator
            # would unpack a tuple value and raise TypeError instead.
            raise ValueError(f"{value} is not a valid value")

    def print_enum(self) -> str | None:
        """Return a string representation of the enum values for documentation"""
        if self.enum_values is None:
            return None
        return ", ".join(str(v) for v in self.enum_values)

    def validate_enums(self, value: Any) -> None:
        """
        Check ``value`` against ``enum_values`` (no-op when no enum is set).

        A list, set or tuple is checked element by element; any other value
        is checked as a single item.

        Raises:
            ValueError: naming the first value that is not a permitted one.
        """
        if self.enum_values is None:
            return
        candidates = value if isinstance(value, (list, set, tuple)) else (value,)
        for candidate in candidates:
            if candidate not in self.enum_values:
                # "[]" keeps the message readable when the enum is empty
                enum_str = self.print_enum() or "[]"
                raise ValueError(f"Value -{candidate}- is not one of {enum_str}")

    def typed(self, value: Any) -> Any:
        """
        Convert a raw value according to ``arg_type``.

        None is returned unchanged. Everything else is converted to a
        whitespace-stripped string and then to int, float or boolean as
        requested; for any other ``arg_type`` the string is returned with
        leading colons removed.
        """
        if value is None:
            return None
        if not isinstance(value, str):
            value = str(value)
        value = value.strip()
        if self.arg_type == "int":
            return int(value)
        if self.arg_type == "float":
            return float(value)
        if self.arg_type == "boolean":
            return value.lower() in truthy_strings

        # lstrip is a workaround to enable strings as path arguments
        # to be recognised, e.g. /user/:bob/delete
        return value.lstrip(":")

    def __call__(self, request: Request) -> Any:
        """Extract, validate and (if a converter is set) convert the value."""
        val = self.extract(request)
        self.validate(val)
        if self.converter is not None:
            return self.converter(val)
        return val

    def extract(self, request: Request) -> Any:
        """Pull the raw value out of ``request``; implemented by subclasses."""
        raise NotImplementedError

    def update_openapi_path_object(self, path_object: dict) -> None:
        """
        Append an OpenAPI parameter object for this extractor.

        Reads ``path_object["responses"]`` and appends to
        ``path_object["parameters"]``; ``path_object["operationId"]`` is
        read only when enum values are configured.
        """
        params_object: dict[str, Any] = {
            "in": self.swagger_in,
            "required": self.swagger_required,
            "name": self.doc_name,
        }
        if self.required is not None:
            params_object["required"] = self.required

        if self.doc is not None:
            params_object["description"] = self.doc
        params_object["schema"] = dict(type=self.swagger_type)
        if self.default is not None:
            params_object["schema"]["default"] = self.default

        if getattr(self, "enum_values", None) is not None:
            # Enum parameters reference a named component schema built from
            # the operation id and the capitalised argument name.
            type_name = f"_{path_object['operationId']}{self.arg_name.capitalize()}"
            params_object["schema"]["$ref"] = f"#/components/schemas/{type_name}"

        self.augment_openapi_params_object(params_object)
        self.augment_openapi_responses(path_object["responses"])

        path_object["parameters"].append(params_object)

    def augment_openapi_params_object(self, params_object: dict) -> None:
        """
        Subclasses can override this method to add additional information to the
        OpenAPI parameters object.
        """
        pass

    def augment_openapi_responses(self, responses_object: dict) -> None:
        """
        Subclasses can override this method to add additional response codes
        to the OpenAPI handler responses object.
        """
        pass

    @property
    def swagger_type(self) -> str:
        """The OpenAPI type name looked up from ``arg_type``."""
        return openapi_type_mapping[self.arg_type]

    @property
    def doc_name(self) -> str:
        """The name of this Parameter for documentation purposes"""
        return self.arg_name
class PathArg(ArgExtractor):
    """
    Extracts the path segment that immediately follows ``/<arg_name>/``.

    e.g. PathArg('user') reads ``5`` from a path such as ``/user/5/delete``
    and converts it with ``typed``.
    """

    swagger_in = "path"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Accept the ArgExtractor arguments and precompile the path regex."""
        super().__init__(*args, **kwargs)
        # NOTE(review): arg_name is interpolated into the pattern unescaped,
        # so regex metacharacters in a name would be interpreted as regex.
        regexp = r".*/%s/([^/]+)/*" % self.arg_name
        self.path_regex = re.compile(regexp)

    def extract(self, request: Request) -> Any:
        """
        Return the typed value of the segment following ``/<arg_name>/``.

        Raises:
            ValueError: if the path does not match or the segment cannot be
                converted; the underlying error is chained as the cause.
        """
        try:
            match = self.path_regex.match(request.path_info)
            if match is None:
                raise ValueError("No match found")
            return self.typed(match.group(1))
        except Exception as exc:
            mess = (
                f"{self.__class__.__name__} Adaptor could not extract "
                f'value "{self.arg_name}" from path {request.path_info} '
                f"using regex {self.path_regex.pattern}"
            )
            # Chain explicitly so the original failure (no match, bad int,
            # ...) is reported as the direct cause of this error.
            raise ValueError(mess) from exc

    @property
    def doc_name(self) -> str:
        """The name of this Parameter for documentation purposes"""
        return self.arg_name + "_id"
class GetArg(ArgExtractor):
    """Extracts a single value from the query string (``request.GET``)."""

    swagger_in = "query"
    swagger_required = False

    def extract(self, request: Request) -> Any:
        """
        Return the typed query parameter, or the typed default.

        The default is used when the parameter is absent (and not required)
        or when it is present but blank. A non-blank value is checked
        against the enum values before conversion.

        Raises:
            ValueError: if the parameter is required but missing, or if the
                value fails the enum check or the type conversion.
        """
        if self.arg_name not in request.GET:
            if self.required:
                raise ValueError(f"Query param '{self.arg_name}' must be provided")
            return self.typed(self.default)

        value = request.GET[self.arg_name]
        if value.strip() == "":
            # A blank value is treated the same as an absent one
            return self.typed(self.default)

        self.validate_enums(value)
        return self.typed(value)
class GetArgSet(ArgExtractor):
    """Extracts a Set (sequence) of argument values for the given GET arg"""

    swagger_in = "query"
    swagger_type = "array"
    swagger_required = False

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Accept the ArgExtractor arguments plus three keyword-only extras.

        Keyword Args:
            array_items_type: type name of each item (default "str").
            min_items: minimum number of values accepted, or None.
            max_items: maximum number of values accepted, or None.
        """
        self.array_items_type = kwargs.pop("array_items_type", "str")
        self.min_items = kwargs.pop("min_items", None)
        self.max_items = kwargs.pop("max_items", None)
        super().__init__(*args, **kwargs)

    def extract(self, request: Request) -> set[Any]:
        """
        Return the set of typed values supplied for this argument.

        Values are read from ``<arg_name>[]`` unless ``<arg_name>`` itself is
        present in the query string, in which case that name is used.

        Raises:
            ValueError: if the number of values is outside the configured
                min/max, or a value fails the enum check or conversion.
        """
        arg_array_name = f"{self.arg_name}[]"
        if self.arg_name in request.GET:
            arg_array_name = self.arg_name
        values = request.GET.getall(arg_array_name)
        if self.min_items is not None and len(values) < self.min_items:
            raise ValueError(
                f"At least {self.min_items} {self.arg_name} parameters required"
            )
        if self.max_items is not None and len(values) > self.max_items:
            raise ValueError(
                f"No more than {self.max_items} {self.arg_name} parameters permitted"
            )
        self.validate_enums(values)
        return {self.typed(v) for v in values}

    def augment_openapi_params_object(self, path_object: dict) -> None:
        """
        Describe the parameter as an array in the OpenAPI parameter object.

        The array schema is installed first and the item-count limits are
        added afterwards, so that minItems/maxItems are not discarded when
        the schema dict is replaced.
        """
        if self.array_items_type is not None:
            path_object["schema"] = {
                "type": "array",
                "items": {"type": openapi_type_mapping[self.array_items_type]},
            }

        if self.min_items is not None:
            path_object["schema"]["minItems"] = self.min_items

        if self.max_items is not None:
            path_object["schema"]["maxItems"] = self.max_items
242def _get_or_create_schema(
243 path_object: dict, mime_type: str = "application/x-www-form-urlencoded"
244) -> dict:
245 return (
246 path_object.setdefault("requestBody", {})
247 .setdefault("content", {})
248 .setdefault(mime_type, {})
249 .setdefault("schema", {})
250 )
class PostArg(ArgExtractor):
    """Extracts a single form field from ``request.POST``."""

    swagger_in = "formData"
    swagger_required = False
    swagger_type = "string"
    doc_name = "body"

    def extract(self, request: Request) -> Any:
        """
        Return the typed form value, or None when neither the field nor a
        default is available.

        Raises:
            ValueError: if enum values are configured and the typed value
                is not one of them.
        """
        raw = request.POST.get(self.arg_name, self.default)
        if raw is None:
            return None
        converted = self.typed(raw)
        if self.enum_values and converted not in self.enum_values:
            raise ValueError(f"{converted} is not in {self.enum_values}")
        return converted

    def update_openapi_path_object(self, path_object: dict) -> None:
        """Register this field as a property of the multipart form schema."""
        form_schema = _get_or_create_schema(
            path_object, mime_type="multipart/form-data"
        )
        form_schema.setdefault("type", "object")
        field_schema: dict[str, Any] = {"type": self.swagger_type}
        if self.enum_values is not None:
            field_schema["enum"] = list(self.enum_values)
        form_schema.setdefault("properties", {})[self.arg_name] = field_schema
class PostFileArg(ArgExtractor):
    """Extracts an uploaded file field from ``request.POST``."""

    swagger_in = "formData"
    swagger_type = "string"
    doc_name = "body"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Accept the ArgExtractor arguments; ``arg_type`` is forced to "file"."""
        kwargs.update(arg_type="file")
        super().__init__(*args, **kwargs)

    def extract(self, request: Request) -> Any:
        """
        Return the posted value for this field, or None if it is absent.

        Raises:
            ValueError: if the field is required but was not posted.
        """
        uploaded = request.POST.get(self.arg_name, None)
        if self.required and uploaded is None:
            raise ValueError(f"Post Parameter {self.arg_name} must be provided")
        return uploaded

    def update_openapi_path_object(self, path_object: dict) -> None:
        """Register this field as a binary string in the multipart form schema."""
        form_schema = _get_or_create_schema(
            path_object, mime_type="multipart/form-data"
        )
        form_schema.setdefault("type", "object")
        form_schema.setdefault("properties", {})[self.arg_name] = {
            "type": "string",
            "format": "binary",
        }
"""
    Pydantic ALIAS

    - alias should be the public name if it differs from the db column name
    - when dumping/serialising, i.e. outputting data to the outside world,
      use model.dict(by_alias)
    - when parsing incoming data, use model.model_dump()

    by_alias == False (the default) means 'dump data for consumption by database'
    by_alias == True means 'dump data for use by API consumers' (default behaviour)

"""
class SchemaDocArg(ArgExtractor):
    """
    Validates the JSON request body against a Pydantic model class.

    Depending on ``as_dict`` the extractor yields either the dumped dict or
    the model instance itself.
    """

    swagger_in = "body"

    def __init__(self, schema_cls, as_dict=True, exclude_unset=False):
        """
        Args:
            schema_cls: model class providing ``model_validate``.
            as_dict: when True return ``model_dump()`` output, otherwise
                return the validated model instance.
            exclude_unset: forwarded to ``model_dump``.
        """
        super().__init__(None)
        self.schema_cls = schema_cls
        self.as_dict = as_dict
        self.exclude_unset = exclude_unset

    @staticmethod
    def _describe_error(err):
        """Format one validation error entry as a human readable message."""
        location = " > ".join(f"'{part}'" for part in err["loc"])
        msg = err["msg"].title()
        return f"Validation failed. {msg}: {location}"

    def extract(self, request):
        """
        Validate ``request.json`` and return the model or its dict form.

        Raises:
            ValidationFailure: wrapping the Pydantic ValidationError text
                together with one formatted message per error.
        """
        try:
            # NB exclude_unset=True is important. If not set, or False,
            # then all the pydantic Model attributes not set will
            # appear in the data passed to API methods - i.e. lots of None
            # values which can overwrite real data on update
            validated = self.schema_cls.model_validate(request.json)
            if not self.as_dict:
                return validated
            return validated.model_dump(exclude_unset=self.exclude_unset)

        except ValidationError as e:
            err_list = [self._describe_error(err) for err in e.errors()]
            raise ValidationFailure(str(e), err_list)

    def update_openapi_path_object(self, path_obj):
        """Point the JSON request body at the component schema for the model."""
        schema_ref = f"#/components/schemas/{self.schema_cls.__name__}"
        json_body = {"schema": {"$ref": schema_ref}}
        path_obj["requestBody"] = {"content": {"application/json": json_body}}
class HeaderArg(ArgExtractor):
    """Extracts a single value from the request headers."""

    swagger_in = "header"
    swagger_type = "string"
    swagger_required = False

    def extract(self, request: Request) -> Any:
        """
        Return the typed header value, or the typed default.

        The default is used when the header is absent (and not required) or
        present but blank. A non-blank value is checked against the enum
        values before conversion.

        Raises:
            ValueError: if the header is required but missing, or if the
                value fails the enum check or the type conversion.
        """
        if self.arg_name not in request.headers:
            if self.required:
                raise ValueError(f"Header '{self.arg_name}' must be provided")
            return self.typed(self.default)

        value = request.headers[self.arg_name]
        if value.strip() == "":
            # A blank header is treated the same as an absent one
            return self.typed(self.default)

        self.validate_enums(value)
        return self.typed(value)

    def augment_openapi_responses(self, responses_object: dict) -> None:
        """Add a 412 response entry to the handler's OpenAPI responses."""
        responses_object["412"] = {
            "description": "Precondition Failed - ETag did not match",
        }