Coverage for postrfp/web/suxint/extractors.py: 100%

207 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-22 21:34 +0000

1from typing import Callable, Any 

2import re 

3 

4from pydantic import ValidationError 

5from webob import Request 

6 

7from postrfp.model.exc import ValidationFailure 

8from postrfp.web.ext.openapi_types import openapi_type_mapping 

9 

10 

# Lower-cased strings that count as True when a "boolean" argument is converted
truthy_strings = {"1", "true", "yes"}

12 

13 

class ArgExtractor:
    """
    Base class for Argument Extractor classes

    Instances of ArgExtractor are callables that extract a value
    from an HTTP request

    e.g. GetArg('name') produces a callable that knows how to extract the
    GET parameter 'name' from an http request

    Subclasses implement ``extract``. Calling an instance runs ``extract``,
    then ``validate`` and finally the optional ``converter``.
    """

    # OpenAPI "in" location of the parameter; set by subclasses
    swagger_in: str | None = None
    # Class-level default for the OpenAPI "required" flag; a non-None
    # ``required`` passed to __init__ takes precedence in the generated spec
    swagger_required = True

    def __init__(
        self,
        arg_name: str,
        validator: Callable | None = None,
        doc: str | None = None,
        default: Any = None,
        arg_type: str = "int",
        converter: Callable | None = None,
        enum_values: set | list | tuple | None = None,
        required: bool | None = None,
    ) -> None:
        """
        :param arg_name: name of the argument to extract
        :param validator: optional callable; a falsy result rejects the value
        :param doc: description emitted in the OpenAPI parameter spec
        :param default: fallback value, also emitted as the schema default
        :param arg_type: "int", "float" or "boolean"; any other value is
            treated as a string by ``typed``
        :param converter: optional callable applied to the validated value
        :param enum_values: permitted values, as a set, list or tuple
        :param required: overrides ``swagger_required`` when not None
        :raises ValueError: if ``enum_values`` has the wrong container type or
            ``validator`` is neither None nor callable
        """
        self.arg_name = arg_name
        self.doc = doc
        self.default = default
        self.arg_type = arg_type
        self.converter = converter
        self.required = required
        if enum_values is not None and not isinstance(enum_values, (set, list, tuple)):
            raise ValueError("enum_values must be a set, list or tuple")
        self.enum_values = enum_values
        if callable(validator):
            self.validator: Callable | None = validator
        elif validator is not None:
            raise ValueError("validator, if given, must be a callable")
        # subclasses might define a validator, don't overwrite it
        if not hasattr(self, "validator"):
            self.validator = None

    def validate(self, value: Any) -> None:
        """Raise ValueError if a validator is set and it rejects ``value``"""
        if self.validator and not self.validator(value):
            # An f-string is used because "%s" % value raises TypeError
            # when value is a tuple, masking the intended ValueError
            raise ValueError(f"{value} is not a valid value")

    def print_enum(self) -> str | None:
        """Return a string representation of the enum values for documentation"""
        if self.enum_values is None:
            return None
        return ", ".join(str(v) for v in self.enum_values)

    def validate_enums(self, value: Any) -> None:
        """
        Check ``value`` against ``enum_values`` (no-op when no enum is set)

        A list, set or tuple is checked element by element; anything else is
        checked as a single value.

        :raises ValueError: for the first value not found in ``enum_values``
        """
        if self.enum_values is None:
            return
        candidates = value if isinstance(value, (list, set, tuple)) else (value,)
        for candidate in candidates:
            if candidate not in self.enum_values:
                # "[]" stands in for an empty enum in both code paths
                enum_str = self.print_enum() or "[]"
                raise ValueError(f"Value -{candidate}- is not one of {enum_str}")

    def typed(self, value: Any) -> Any:
        """
        Convert ``value`` according to ``arg_type``

        None is passed through unchanged. Anything else is stringified and
        stripped, then converted to int, float or bool; every other arg_type
        yields the string with leading colons removed.
        """
        if value is None:
            return None
        if not isinstance(value, str):
            value = str(value)
        value = value.strip()
        if self.arg_type == "int":
            return int(value)
        if self.arg_type == "float":
            return float(value)
        if self.arg_type == "boolean":
            return value.lower() in truthy_strings

        # lstrip is a workaround to enable strings as path arguments
        # to be recognised, e.g. /user/:bob/delete
        return value.lstrip(":")

    def __call__(self, request: Request) -> Any:
        """Extract, validate and optionally convert the value from ``request``"""
        val = self.extract(request)
        self.validate(val)
        if self.converter is not None:
            return self.converter(val)
        return val

    def extract(self, request: Request) -> Any:
        """Pull the raw value out of ``request``; implemented by subclasses"""
        raise NotImplementedError

    def update_openapi_path_object(self, path_object: dict) -> None:
        """
        Append this parameter's OpenAPI spec to ``path_object["parameters"]``

        When enum values are set, the schema also receives a ``$ref`` built
        from ``path_object["operationId"]`` and the capitalised arg name.
        """
        spec: dict[str, Any] = {
            "in": self.swagger_in,
            "required": self.swagger_required,
            "name": self.doc_name,
        }
        # An explicit per-instance setting wins over the class default
        if self.required is not None:
            spec["required"] = self.required

        if self.doc is not None:
            spec["description"] = self.doc
        spec["schema"] = {"type": self.swagger_type}
        if self.default is not None:
            spec["schema"]["default"] = self.default

        if getattr(self, "enum_values", None) is not None:
            type_name = f"_{path_object['operationId']}{self.arg_name.capitalize()}"
            spec["schema"]["$ref"] = f"#/components/schemas/{type_name}"

        self.augment_openapi_path_object(spec)

        path_object["parameters"].append(spec)

    def augment_openapi_path_object(self, spec: dict) -> None:
        """
        Subclasses can override this method to add additional information to the
        OpenAPI path object.
        """
        pass

    @property
    def swagger_type(self) -> str:
        """The OpenAPI type name looked up from ``arg_type``"""
        return openapi_type_mapping[self.arg_type]

    @property
    def doc_name(self) -> str:
        """The name of this Parameter for documentation purposes"""
        return self.arg_name

143 

144 

class PathArg(ArgExtractor):
    """
    Extracts a value from the URL path segment that follows ``/<arg_name>/``

    e.g. PathArg('user') reads the ``42`` in ``/user/42/delete``
    """

    swagger_in = "path"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Escape the name so any regex metacharacters in it match literally
        regexp = r".*/%s/([^/]+)/*" % re.escape(self.arg_name)
        self.path_regex = re.compile(regexp)

    def extract(self, request: Request) -> Any:
        """
        Return the typed path segment found after ``/<arg_name>/``

        :raises ValueError: if the path does not match or the captured
            segment cannot be converted by ``typed``
        """
        try:
            match = self.path_regex.match(request.path_info)
            if match is None:
                raise ValueError("No match found")
            return self.typed(match.group(1))
        except Exception as exc:
            # Every failure is reported as one descriptive ValueError;
            # the original error stays attached as the cause
            mess = (
                f"{self.__class__.__name__} Adaptor could not extract "
                f'value "{self.arg_name}" from path {request.path_info} '
                f"using regex {self.path_regex.pattern}"
            )
            raise ValueError(mess) from exc

    @property
    def doc_name(self) -> str:
        """The name of this Parameter for documentation purposes"""
        return self.arg_name + "_id"

172 

173 

class GetArg(ArgExtractor):
    """Extracts a single query string (GET) parameter from the request"""

    swagger_in = "query"
    swagger_required = False

    def extract(self, request: Request) -> Any:
        """
        Return the typed value of the query parameter ``arg_name``

        The typed default is returned when the parameter is absent (and not
        required) or when it is blank.

        :raises ValueError: if the parameter is required but missing, or the
            supplied value is not one of ``enum_values``
        """
        if self.arg_name not in request.GET:
            if self.required:
                raise ValueError(f"Query param '{self.arg_name}' must be provided")
            return self.typed(self.default)
        value = request.GET[self.arg_name]
        if not value.strip():
            return self.typed(self.default)
        # NOTE(review): the enum check runs on the raw string, before typed()
        # converts it - confirm this is intended for non-string arg types
        self.validate_enums(value)
        return self.typed(value)

189 

190 

class GetArgSet(ArgExtractor):
    """Extracts a Set (sequence) of argument values for the given GET arg"""

    swagger_in = "query"
    swagger_type = "array"
    swagger_required = False

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Accepts the ArgExtractor arguments plus three keyword-only extras:

        :param array_items_type: key into ``openapi_type_mapping`` describing
            the array items (default "str")
        :param min_items: minimum number of values, or None for no limit
        :param max_items: maximum number of values, or None for no limit
        """
        self.array_items_type = kwargs.pop("array_items_type", "str")
        self.min_items = kwargs.pop("min_items", None)
        self.max_items = kwargs.pop("max_items", None)
        super().__init__(*args, **kwargs)

    def extract(self, request: Request) -> set[Any]:
        """
        Return the set of typed values supplied for this argument

        Values are read from ``<arg_name>[]`` unless the plain ``<arg_name>``
        is present in the query string, in which case that name is used.

        :raises ValueError: if the number of values violates ``min_items`` or
            ``max_items``, or a value is not one of ``enum_values``
        """
        arg_array_name = f"{self.arg_name}[]"
        if self.arg_name in request.GET:
            arg_array_name = self.arg_name
        values = request.GET.getall(arg_array_name)
        if self.min_items is not None and len(values) < self.min_items:
            raise ValueError(
                f"At least {self.min_items} {self.arg_name} parameters required"
            )
        if self.max_items is not None and len(values) > self.max_items:
            raise ValueError(
                f"No more than {self.max_items} {self.arg_name} parameters permitted"
            )
        self.validate_enums(values)
        return {self.typed(v) for v in values}

    def augment_openapi_path_object(self, spec: dict) -> None:
        """Describe the parameter as an array and add its item-count limits"""
        # Build the array schema first: it replaces spec["schema"], so limits
        # written before this point would be thrown away
        if self.array_items_type is not None:
            spec["schema"] = {
                "type": "array",
                "items": {"type": openapi_type_mapping[self.array_items_type]},
            }

        if self.min_items is not None:
            spec["schema"]["minItems"] = self.min_items

        if self.max_items is not None:
            spec["schema"]["maxItems"] = self.max_items

232 

233 

234def _get_or_create_schema( 

235 path_object: dict, mime_type: str = "application/x-www-form-urlencoded" 

236) -> dict: 

237 return ( 

238 path_object.setdefault("requestBody", {}) 

239 .setdefault("content", {}) 

240 .setdefault(mime_type, {}) 

241 .setdefault("schema", {}) 

242 ) 

243 

244 

class PostArg(ArgExtractor):
    """Reads a single form field from the POST data of a request"""

    swagger_in = "formData"
    swagger_required = False
    swagger_type = "string"
    doc_name = "body"

    def extract(self, request: Request) -> Any:
        """
        Return the typed form value, or None when neither the field nor a
        default is available

        :raises ValueError: if enum values are set and the typed value is
            not among them
        """
        raw = request.POST.get(self.arg_name, self.default)
        if raw is None:
            return None
        converted = self.typed(raw)
        if self.enum_values and converted not in self.enum_values:
            raise ValueError(f"{converted} is not in {self.enum_values}")
        return converted

    def update_openapi_path_object(self, path_object: dict) -> None:
        """Register this field as a property of the multipart/form-data body schema"""
        form_schema = _get_or_create_schema(path_object, mime_type="multipart/form-data")
        form_schema.setdefault("type", "object")
        field_spec: dict[str, Any] = {"type": self.swagger_type}
        if self.enum_values is not None:
            field_spec["enum"] = list(self.enum_values)
        form_schema.setdefault("properties", {})[self.arg_name] = field_spec

267 

268 

class PostFileArg(ArgExtractor):
    """Reads an uploaded file field from the POST data of a request"""

    swagger_in = "formData"
    swagger_type = "string"
    doc_name = "body"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # arg_type is always forced to "file", whatever the caller passed
        super().__init__(*args, **{**kwargs, "arg_type": "file"})

    def extract(self, request: Request) -> Any:
        """
        Return the posted value for ``arg_name`` untouched (None if absent)

        :raises ValueError: if the field is required but was not posted
        """
        uploaded = request.POST.get(self.arg_name)
        if self.required and uploaded is None:
            raise ValueError(f"Post Parameter {self.arg_name} must be provided")
        return uploaded

    def update_openapi_path_object(self, path_object: dict) -> None:
        """Register this field as a binary string property of the multipart body schema"""
        form_schema = _get_or_create_schema(path_object, mime_type="multipart/form-data")
        form_schema.setdefault("type", "object")
        form_schema.setdefault("properties", {})[self.arg_name] = {
            "type": "string",
            "format": "binary",
        }

289 

290 

"""
 Pydantic ALIAS

 - alias should be the public name if different from the db column name
 - when dumping/serialising data for output to the outside world
   use model.dict(by_alias)
 - when parsing incoming data, use model.model_dump()

 by_alias == False (the default) means 'dump data for consumption by database'
 by_alias == True means 'dump data for use by API consumers' (default behaviour)

"""

303 

304 

class SchemaDocArg(ArgExtractor):
    """Validates the JSON request body against a pydantic model class"""

    swagger_in = "body"

    def __init__(self, schema_cls, as_dict=True, exclude_unset=False):
        super().__init__(None)
        self.schema_cls = schema_cls
        self.as_dict = as_dict
        self.exclude_unset = exclude_unset

    @staticmethod
    def _describe_error(issue):
        """Format one pydantic error entry as a readable message"""
        location = " > ".join(f"'{part}'" for part in issue["loc"])
        msg = issue["msg"].title()
        return f"Validation failed. {msg}: {location}"

    def extract(self, request):
        """
        Validate ``request.json`` with ``schema_cls``

        Returns the model dumped to a dict when ``as_dict`` is true,
        otherwise the model instance itself.

        Raises ValidationFailure, carrying one message per pydantic error,
        when validation fails.
        """
        try:
            # NB callers that perform updates should construct this extractor
            # with exclude_unset=True. Otherwise every model attribute that
            # was not supplied is dumped too - i.e. lots of None values
            # which can overwrite real data on update
            parsed = self.schema_cls.model_validate(request.json)
            if not self.as_dict:
                return parsed
            return parsed.model_dump(exclude_unset=self.exclude_unset)

        except ValidationError as e:
            err_list = [self._describe_error(issue) for issue in e.errors()]
            raise ValidationFailure(str(e), err_list)

    def update_openapi_path_object(self, path_obj):
        """Point the JSON request body at the component schema named after ``schema_cls``"""
        schema_ref = f"#/components/schemas/{self.schema_cls.__name__}"
        path_obj["requestBody"] = {
            "content": {"application/json": {"schema": {"$ref": schema_ref}}}
        }

344 }