Coverage for postrfp / web / suxint / extractors.py: 99%

225 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 01:35 +0000

1from typing import Callable, Any 

2import re 

3 

4from pydantic import ValidationError 

5from webob import Request 

6 

7from postrfp.model.exc import ValidationFailure 

8from postrfp.web.ext.openapi_types import openapi_type_mapping 

9 

10 

# Lower-case spellings treated as True when an argument is typed as "boolean".
truthy_strings = {"1", "yes", "true"}

12 

13 

class ArgExtractor:
    """
    Base class for Argument Extractor classes

    Instances of ArgExtractor are callables that extract a value
    from an HTTP request

    e.g. GetArg('name') produces a callable that knows how to extract the
    GET parameter 'name' from an http request

    Calling an instance runs ``extract`` (implemented by subclasses), then
    ``validate``, then the optional ``converter``.
    """

    # Value of the OpenAPI parameter "in" field; overridden by subclasses.
    swagger_in: str | None = None
    # OpenAPI "required" flag used when no explicit ``required`` was given.
    swagger_required = True

    def __init__(
        self,
        arg_name: str,
        validator: Callable | None = None,
        doc: str | None = None,
        default: Any = None,
        arg_type: str = "int",
        converter: Callable | None = None,
        enum_values: set | list | tuple | None = None,
        required: bool | None = None,
    ) -> None:
        """
        Args:
            arg_name: name of the argument to extract from the request.
            validator: optional callable; a falsy result rejects the value.
            doc: optional description emitted in the OpenAPI parameter.
            default: value subclasses fall back to; also emitted as the
                OpenAPI schema default when not None.
            arg_type: type name used by ``typed`` ("int", "float",
                "boolean", anything else is treated as a string) and as the
                key into ``openapi_type_mapping``.
            converter: optional callable applied to the validated value.
            enum_values: optional collection of permitted values.
            required: optional override for the OpenAPI "required" flag;
                some subclasses also enforce it in ``extract``.

        Raises:
            ValueError: if ``enum_values`` is not a set, list or tuple, or if
                ``validator`` is given but is not callable.
        """
        self.arg_name = arg_name
        self.doc = doc
        self.default = default
        self.arg_type = arg_type
        self.converter = converter
        self.required = required
        if enum_values is not None and not isinstance(enum_values, (set, list, tuple)):
            raise ValueError("enum_values must be a set, list or tuple")
        self.enum_values = enum_values
        if callable(validator):
            self.validator: Callable | None = validator
        elif validator is not None:
            raise ValueError("validator, if given, must be a callable")
        # subclasses might define a validator, don't overwrite it
        if not hasattr(self, "validator"):
            self.validator = None

    def validate(self, value: Any) -> None:
        """Raise ValueError if a validator is set and it rejects ``value``."""
        if self.validator and not self.validator(value):
            # f-string rather than "%s" % value: %-formatting would unpack a
            # tuple value and raise TypeError instead of the intended error.
            raise ValueError(f"{value} is not a valid value")

    def print_enum(self) -> str | None:
        """Return a string representation of the enum values for documentation"""
        if self.enum_values is None:
            return None
        return ", ".join(str(v) for v in self.enum_values)

    def validate_enums(self, value: Any) -> None:
        """
        Check ``value`` against ``enum_values`` (no-op when none are set).

        A list, set or tuple is checked element by element; anything else is
        checked as a single value.

        Raises:
            ValueError: naming the first value that is not permitted.
        """
        if self.enum_values is None:
            return
        candidates = value if isinstance(value, (list, set, tuple)) else (value,)
        for candidate in candidates:
            if candidate not in self.enum_values:
                # An empty enum collection renders as "[]" in both cases.
                enum_str = self.print_enum() or "[]"
                raise ValueError(f"Value -{candidate}- is not one of {enum_str}")

    def typed(self, value: Any) -> Any:
        """
        Convert ``value`` according to ``arg_type``.

        None is passed through unchanged. Any other value is stringified and
        stripped, then converted to int, float or bool; for every other
        ``arg_type`` the stripped string is returned with leading colons
        removed.
        """
        if value is None:
            return None
        if not isinstance(value, str):
            value = str(value)
        value = value.strip()
        if self.arg_type == "int":
            return int(value)
        if self.arg_type == "float":
            return float(value)
        if self.arg_type == "boolean":
            return value.lower() in truthy_strings

        # lstrip is a workaround to enable strings as path arguments
        # to be recognised, e.g. /user/:bob/delete
        return value.lstrip(":")

    def __call__(self, request: Request) -> Any:
        """Extract, validate and (optionally) convert the value from ``request``."""
        val = self.extract(request)
        self.validate(val)
        if self.converter is not None:
            return self.converter(val)
        return val

    def extract(self, request: Request) -> Any:
        """Pull the raw value out of ``request``; implemented by subclasses."""
        raise NotImplementedError

    def update_openapi_path_object(self, path_object: dict) -> None:
        """
        Append an OpenAPI parameter object describing this argument.

        ``path_object`` must already contain "parameters" and "responses";
        "operationId" is also read when ``enum_values`` is set, to build the
        name of the referenced component schema.
        """
        params_object: dict[str, Any] = {
            "in": self.swagger_in,
            "required": self.swagger_required,
            "name": self.doc_name,
        }
        # An explicit ``required`` wins over the class-level default.
        if self.required is not None:
            params_object["required"] = self.required

        if self.doc is not None:
            params_object["description"] = self.doc
        params_object["schema"] = dict(type=self.swagger_type)
        if self.default is not None:
            params_object["schema"]["default"] = self.default

        if self.enum_values is not None:
            type_name = f"_{path_object['operationId']}{self.arg_name.capitalize()}"
            params_object["schema"]["$ref"] = f"#/components/schemas/{type_name}"

        self.augment_openapi_params_object(params_object)
        self.augment_openapi_responses(path_object["responses"])

        path_object["parameters"].append(params_object)

    def augment_openapi_params_object(self, params_object: dict) -> None:
        """
        Subclasses can override this method to add additional information to the
        OpenAPI parameters object.
        """
        pass

    def augment_openapi_responses(self, responses_object: dict) -> None:
        """
        Subclasses can override this method to add additional response codes
        to the OpenAPI handler responses object.
        """
        pass

    @property
    def swagger_type(self) -> str:
        """The OpenAPI type name looked up from ``arg_type``."""
        return openapi_type_mapping[self.arg_type]

    @property
    def doc_name(self) -> str:
        """The name of this Parameter for documentation purposes"""
        return self.arg_name

151 

152 

class PathArg(ArgExtractor):
    """
    Extracts a value from the URL path: the segment that immediately follows
    a ``/<arg_name>/`` segment in ``request.path_info``.
    """

    swagger_in = "path"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Escape the name so any regex metacharacters in it match literally.
        regexp = r".*/%s/([^/]+)/*" % re.escape(self.arg_name)
        self.path_regex = re.compile(regexp)

    def extract(self, request: Request) -> Any:
        """
        Return the typed path segment following ``/<arg_name>/``.

        Raises:
            ValueError: if the path does not match or the captured segment
                cannot be converted by ``typed``; the underlying error is
                attached as the cause.
        """
        try:
            match = self.path_regex.match(request.path_info)
            if match is None:
                raise ValueError("No match found")
            return self.typed(match.group(1))
        except Exception as exc:
            # Every failure is reported as one ValueError with a uniform
            # message; chaining keeps the original error for debugging.
            mess = (
                f"{self.__class__.__name__} Adaptor could not extract "
                f'value "{self.arg_name}" from path {request.path_info} '
                f"using regex {self.path_regex.pattern}"
            )
            raise ValueError(mess) from exc

    @property
    def doc_name(self) -> str:
        """The name of this Parameter for documentation purposes"""
        return self.arg_name + "_id"

180 

181 

class GetArg(ArgExtractor):
    """Extracts a single query-string (GET) parameter."""

    swagger_in = "query"
    swagger_required = False

    def extract(self, request: Request) -> Any:
        """
        Return the typed value of the query parameter.

        Falls back to the typed default when the parameter is absent (and not
        required) or is present but blank. Non-blank values are checked
        against ``enum_values`` before conversion.

        Raises:
            ValueError: if the parameter is absent and ``required`` is truthy,
                or the value fails the enum check or type conversion.
        """
        if self.arg_name not in request.GET:
            if self.required:
                raise ValueError(f"Query param '{self.arg_name}' must be provided")
            return self.typed(self.default)
        value = request.GET[self.arg_name]
        if value.strip() == "":
            return self.typed(self.default)
        self.validate_enums(value)
        return self.typed(value)

197 

198 

class GetArgSet(ArgExtractor):
    """Extracts a Set (sequence) of argument values for the given GET arg"""

    swagger_in = "query"
    swagger_type = "array"
    swagger_required = False

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Accepts the ArgExtractor arguments plus three keyword-only extras:
        ``array_items_type`` (default "str"), ``min_items`` and ``max_items``
        (both default None, meaning no limit).
        """
        self.array_items_type = kwargs.pop("array_items_type", "str")
        self.min_items = kwargs.pop("min_items", None)
        self.max_items = kwargs.pop("max_items", None)
        super().__init__(*args, **kwargs)

    def extract(self, request: Request) -> set[Any]:
        """
        Return the set of typed values supplied for this argument.

        Values are read from ``<arg_name>[]`` unless the plain ``<arg_name>``
        key is present in the query string. The item limits are applied to
        the raw list of values, before duplicates collapse into the set.

        Raises:
            ValueError: if the number of values is outside the configured
                limits, or a value fails the enum check or type conversion.
        """
        arg_array_name = f"{self.arg_name}[]"
        if self.arg_name in request.GET:
            arg_array_name = self.arg_name
        values = request.GET.getall(arg_array_name)
        if self.min_items is not None and len(values) < self.min_items:
            raise ValueError(
                f"At least {self.min_items} {self.arg_name} parameters required"
            )
        if self.max_items is not None and len(values) > self.max_items:
            raise ValueError(
                f"No more than {self.max_items} {self.arg_name} parameters permitted"
            )
        self.validate_enums(values)
        return {self.typed(v) for v in values}

    def augment_openapi_params_object(self, path_object: dict) -> None:
        """
        Describe the parameter as an array in the OpenAPI parameter object.

        ``path_object`` is the parameter object built by the base class (the
        parameter name is kept for compatibility).
        """
        # Install the array schema first: it replaces the schema dict, so
        # limits written before this point would be discarded.
        if self.array_items_type is not None:
            path_object["schema"] = {
                "type": "array",
                "items": {"type": openapi_type_mapping[self.array_items_type]},
            }

        if self.min_items is not None:
            path_object["schema"]["minItems"] = self.min_items

        if self.max_items is not None:
            path_object["schema"]["maxItems"] = self.max_items

240 

241 

242def _get_or_create_schema( 

243 path_object: dict, mime_type: str = "application/x-www-form-urlencoded" 

244) -> dict: 

245 return ( 

246 path_object.setdefault("requestBody", {}) 

247 .setdefault("content", {}) 

248 .setdefault(mime_type, {}) 

249 .setdefault("schema", {}) 

250 ) 

251 

252 

class PostArg(ArgExtractor):
    """Extracts a single form (POST) field."""

    swagger_in = "formData"
    swagger_required = False
    swagger_type = "string"
    doc_name = "body"

    def extract(self, request: Request) -> Any:
        """
        Return the typed form value, or None when the field is absent and no
        default is set. A non-empty ``enum_values`` is checked against the
        typed value.
        """
        raw = request.POST.get(self.arg_name, self.default)
        if raw is None:
            return None
        tv = self.typed(raw)
        if self.enum_values and tv not in self.enum_values:
            raise ValueError(f"{tv} is not in {self.enum_values}")
        return tv

    def update_openapi_path_object(self, path_object: dict) -> None:
        """Register this field as a property of the multipart request body."""
        form_schema = _get_or_create_schema(path_object, mime_type="multipart/form-data")
        form_schema.setdefault("type", "object")
        field_spec: dict[str, Any] = {"type": self.swagger_type}
        if self.enum_values is not None:
            field_spec["enum"] = list(self.enum_values)
        form_schema.setdefault("properties", {})[self.arg_name] = field_spec

275 

276 

class PostFileArg(ArgExtractor):
    """Extracts an uploaded file field from the POST data."""

    swagger_in = "formData"
    swagger_type = "string"
    doc_name = "body"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # arg_type is always "file", whatever the caller passed as keyword.
        forced = {**kwargs, "arg_type": "file"}
        super().__init__(*args, **forced)

    def extract(self, request: Request) -> Any:
        """
        Return the raw POST value for this field, untyped.

        Raises:
            ValueError: if the field is missing and ``required`` is truthy.
        """
        uploaded = request.POST.get(self.arg_name)
        if self.required and uploaded is None:
            raise ValueError(f"Post Parameter {self.arg_name} must be provided")
        return uploaded

    def update_openapi_path_object(self, path_object: dict) -> None:
        """Register this field as a binary property of the multipart body."""
        form_schema = _get_or_create_schema(path_object, mime_type="multipart/form-data")
        form_schema.setdefault("type", "object")
        form_schema.setdefault("properties", {})[self.arg_name] = {
            "type": "string",
            "format": "binary",
        }

297 

298 

"""
Pydantic ALIAS

- alias should be the public name if different from the db column name
- when dumping/serialising data from the database for the outside world,
  use model.model_dump(by_alias=True)
- when parsing incoming data, use model.model_dump()

by_alias == False (the default) means 'dump data for consumption by database'
by_alias == True means 'dump data for use by API consumers' (default behaviour)

"""

311 

312 

class SchemaDocArg(ArgExtractor):
    """
    Extracts the JSON request body and validates it against ``schema_cls``
    (a class exposing pydantic's ``model_validate``).

    ``as_dict`` selects whether ``extract`` returns the dumped dict or the
    model instance; ``exclude_unset`` is forwarded to ``model_dump``.
    """

    swagger_in = "body"

    def __init__(self, schema_cls, as_dict=True, exclude_unset=False):
        self.as_dict = as_dict
        self.exclude_unset = exclude_unset
        # There is no named argument for a body document.
        super().__init__(None)
        self.schema_cls = schema_cls

    def extract(self, request: Request) -> Any:
        """
        Validate ``request.json`` and return the model or its dict dump.

        Raises:
            ValidationFailure: when pydantic rejects the document; carries
                one readable message per validation error.
        """
        try:
            # NB exclude_unset=True is important. If not set, or False,
            # then all the pydantic Model attributes not set will
            # appear in the data passed to API methods - i.e. lots of None
            # values which can overwrite real data on update.
            # Note the constructor default is False, so callers that need
            # this must pass exclude_unset=True explicitly.
            pydantic_model = self.schema_cls.model_validate(request.json)
            if self.as_dict:
                return pydantic_model.model_dump(exclude_unset=self.exclude_unset)
            return pydantic_model

        except ValidationError as exc:
            err_list: list[str] = []
            for err in exc.errors():
                # "loc" is the path to the offending field, rendered 'a' > 'b'.
                location = " > ".join(f"'{part}'" for part in err["loc"])
                msg = err["msg"].title()
                err_list.append(f"Validation failed. {msg}: {location}")
            raise ValidationFailure(str(exc), err_list) from exc

    def update_openapi_path_object(self, path_obj: dict) -> None:
        """Set the JSON request body to a $ref to the schema class's component."""
        path_obj["requestBody"] = {
            "content": {
                "application/json": {
                    "schema": {
                        "$ref": "#/components/schemas/%s" % self.schema_cls.__name__
                    }
                }
            }
        }

353 

354 

class HeaderArg(ArgExtractor):
    """Extracts a single HTTP request header."""

    swagger_in = "header"
    swagger_type = "string"
    swagger_required = False

    def extract(self, request: Request) -> Any:
        """
        Return the typed value of the header.

        Falls back to the typed default when the header is absent (and not
        required) or is present but blank. Non-blank values are checked
        against ``enum_values`` before conversion.

        Raises:
            ValueError: if the header is absent and ``required`` is truthy,
                or the value fails the enum check or type conversion.
        """
        if self.arg_name not in request.headers:
            if self.required:
                raise ValueError(f"Header '{self.arg_name}' must be provided")
            return self.typed(self.default)
        value = request.headers[self.arg_name]
        if value.strip() == "":
            return self.typed(self.default)
        self.validate_enums(value)
        return self.typed(value)

    def augment_openapi_responses(self, responses_object: dict) -> None:
        """Add a 412 (Precondition Failed) entry to the handler's responses."""
        responses_object["412"] = {
            "description": "Precondition Failed - ETag did not match",
        }