Coverage for postrfp / web / base.py: 100%

122 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 01:35 +0000

1import logging 

2from zlib import crc32 

3from contextlib import contextmanager 

4from typing import Dict, List, Union, Callable, Optional, TYPE_CHECKING 

5 

6import orjson 

7import webob.exc 

8from webob import Response 

9from webob.dec import wsgify 

10from pydantic.main import BaseModel 

11from semantic_version import Version # type: ignore[import] 

12 

13from postrfp.shared.tools import read_api_version 

14from postrfp.shared.utils import json_default, benchmark 

15from postrfp import conf 

16from postrfp.shared.constants import RunMode 

17from postrfp.auth.policy import DevHeaderPolicy, AbstractIdentityPolicy, JwtBearerPolicy 

18from postrfp.templates import init_jinja, get_template 

19from postrfp.web.request import HttpRequest 

20from postrfp.web.exception import resolve_exception 

21 

22if TYPE_CHECKING: 

23 from postrfp.web.suxint import Sux 

24 

25 

log = logging.getLogger(__name__)


# Response header that WSGIApp.rest_api adds to every API response,
# carrying str(API_VERSION).
API_VERSION_HTTP_HEADER = "X-POSTRFP-API-VERSION"


# Resolved once, at import time. The raw value from read_api_version() is
# parsed into a semantic_version.Version; it is stringified for the response
# header and the startup log line.
# NOTE(review): assumes read_api_version() returns a semver-formatted string,
# since Version() rejects anything else; confirm.
v = read_api_version()
API_VERSION = Version(v)

33 

34 

def jsonify_models(api_output) -> Union[List, Dict]:
    """
    Convert a non-empty list or set of pydantic models into a list of dicts.

    Only one element is inspected: the first element of a list, or an
    arbitrary element of a set. If that element is a ``BaseModel``, the whole
    collection is assumed to hold models, and every element is dumped with
    ``model_dump(by_alias=True)``.

    Any other value, including empty collections, is returned unchanged so
    the caller can serialise it directly.

    The input is never mutated.
    """
    if isinstance(api_output, (list, set)) and api_output:
        # Peek without consuming. set.pop() would remove an element from the
        # caller's set, and therefore from whatever gets serialised.
        item = next(iter(api_output))
        if isinstance(item, BaseModel):
            return [r.model_dump(by_alias=True) for r in api_output]
    return api_output

41 

42 

def render(request: HttpRequest, api_output):
    """
    Serialise an API handler's result and wrap it in a webob ``Response``.

    Serialisation:
      * ``None`` becomes the literal body ``{"result": "ok"}``.
      * A pydantic ``BaseModel`` is dumped with ``model_dump_json(by_alias=True)``.
      * Anything else goes through ``jsonify_models`` and then ``orjson``,
        with ``json_default`` as the fallback encoder.

    Representation: when ``request.prefers_json`` is true the body is JSON.
    Otherwise the JSON text is embedded in the ``api.html`` template.

    Caching: if the request has a truthy ``generate_etag`` attribute, the
    response gets ``Cache-Control: must-revalidate`` and an ETag computed as
    the CRC32 of the JSON bytes. A matching ``If-None-Match`` then yields a
    304 Not Modified instead. Otherwise ``Cache-Control: no-cache`` is set.
    """
    if api_output is None:
        json_bytes = b'{"result": "ok"}'
    elif isinstance(api_output, BaseModel):
        json_bytes = api_output.model_dump_json(by_alias=True).encode("utf-8")
    else:
        json_bytes = orjson.dumps(jsonify_models(api_output), default=json_default)

    if not request.prefers_json:
        page = get_template("api.html").render(
            js_doc=json_bytes.decode("utf-8"), url=request.path, request=request
        )
        res = Response(page)
    else:
        res = Response(json_bytes, charset="utf-8", content_type="application/json")

    if not getattr(request, "generate_etag", False):
        res.headers.add("Cache-Control", "no-cache")
        return res

    res.headers.add("Cache-Control", "must-revalidate")
    # The ETag is derived from the JSON payload in both the JSON and HTML
    # branches.
    res.etag = str(crc32(json_bytes))
    if res.etag in request.if_none_match:
        return webob.exc.HTTPNotModified(etag=res.etag)
    return res

71 

72 

@contextmanager
def commit_or_rollback(session):
    """
    Scope a unit of work on ``session``.

    If the ``with`` body completes normally, the session is committed.
    If the body raises an ``Exception``, the session is rolled back and the
    exception is re-raised. Whatever happens, the session is closed on exit.
    """
    try:
        try:
            yield
        except Exception:  # nopep8
            session.rollback()
            raise
        # Only reached when the body raised nothing.
        session.commit()
    finally:
        session.close()

84 

85 

class WSGIApp(object):
    """
    Base WSGI application that dispatches on the first PATH_INFO segment.

    Dispatch rules:
      * A segment equal to ``api_path`` is handled by :meth:`rest_api`, which
        delegates to ``self.sux_instance``.
      * A segment registered through :meth:`route` is handled by that handler.
      * Anything else raises ``HTTPNotFound``.

    Every dispatched request gets ``request.session`` from ``session_factory``.
    The session is committed if authentication and the handler succeed,
    rolled back otherwise, and always closed.

    Subclasses must implement :meth:`build_sux` and :meth:`validate_user`.
    """

    # NOTE(review): this dict is created once, on WSGIApp. ``route`` mutates
    # it in place, so unless a subclass defines its own ``routes``, handlers
    # registered through any subclass are shared by all of them.
    routes: dict[str, Callable] = {}

    def __init__(
        self,
        session_factory=None,
        auth_policy: AbstractIdentityPolicy | type[AbstractIdentityPolicy] | None = None,
        api_path="api",
    ):
        """
        :param session_factory: zero-argument callable that returns a new
            session for each request (it must provide commit/rollback/close).
        :param auth_policy: an ``AbstractIdentityPolicy`` instance, or a policy
            class to instantiate with no arguments. If omitted,
            ``DevHeaderPolicy`` is used in test/development run modes and
            ``JwtBearerPolicy`` otherwise.
        :param api_path: first path segment that routes to :meth:`rest_api`.
        :raises TypeError: if ``auth_policy`` is not (and does not produce) an
            ``AbstractIdentityPolicy`` instance.
        """
        self.session_factory = session_factory
        self.sux_instance: Optional[Sux] = None
        self.api_path = api_path

        if auth_policy is None:
            if (
                conf.CONF.run_mode is RunMode.test
                or conf.CONF.run_mode is RunMode.development
            ):
                self.auth_policy = DevHeaderPolicy()
            else:
                self.auth_policy = JwtBearerPolicy()
        else:
            if isinstance(auth_policy, type):
                # A policy class was supplied; build the instance first, then
                # validate it like any other caller-provided policy.
                auth_policy = auth_policy()
            if not isinstance(auth_policy, AbstractIdentityPolicy):
                raise TypeError("auth_policy must inherit from AbstractIdentityPolicy")
            self.auth_policy = auth_policy

        init_jinja()
        self.build_sux()

        log.info(
            "%s App initialised. Auth: %s. API Version %s",
            self.__class__.__name__,
            self.auth_policy.__class__.__name__,
            API_VERSION,
        )

    def build_sux(self):  # pragma: no cover
        """
        If a subclass wants to serve a suxint.Sux API then it must
        implement this method to assign a value to self.sux_instance
        """
        raise NotImplementedError

    @wsgify(RequestClass=HttpRequest)
    def __call__(self, request):
        """
        Handle one WSGI request: route it, open a session, authenticate, run
        the handler inside a transaction, then let the auth policy
        ``remember`` the request.

        In development run mode, exceptions are logged and re-raised.
        Otherwise they are converted into a response by
        ``resolve_exception``.
        """
        try:
            handler = self.resolve_route(request)

            request.session = session = self.session_factory()

            # Authentication runs inside the transaction scope so the session
            # is rolled back and closed even when authentication raises.
            # Authentication's own writes are committed together with the
            # handler's.
            with commit_or_rollback(session):
                self.authenticate(request)
                response = handler(request)

            self.auth_policy.remember(request, response)

            return response

        except Exception as e:
            if conf.CONF.run_mode is RunMode.development:
                log.exception(
                    "Exception in Base webapp, RunMode.development so raising.."
                )
                raise
            else:
                # Set request user to None to avoid detached sqla session
                # errors caused by User object lurking in environ dict
                request.user = None
                return resolve_exception(request, e)

    def authenticate(self, request):
        """Identify the caller via the auth policy, then apply validate_user."""
        self.auth_policy.identify(request)
        self.validate_user(request)

    def resolve_route(self, request):
        """
        Return the handler for the request's first PATH_INFO segment.

        :raises webob.exc.HTTPNotFound: if no handler matches.
        """
        sub_app = request.path_info_peek() or ""

        if sub_app == self.api_path:
            return self.rest_api

        elif sub_app in self.routes:
            return self.routes[sub_app]
        else:
            log.warning("No handler found for sub_app: %s", sub_app)
            raise webob.exc.HTTPNotFound

    def rest_api(self, request):
        """
        Run the Sux API for ``request`` and return a Response.

        If the Sux instance returns a ``Response``, it is used as-is.
        Otherwise the result is passed through ``render``. In both cases the
        API version header is added.
        """
        path_info = request.path_info
        with benchmark("API call to %s %s" % (request.method, path_info)):
            api_output = self.sux_instance(request)

            if isinstance(api_output, Response):
                response = api_output
            else:
                response = render(request, api_output)

            response.headers.add(API_VERSION_HTTP_HEADER, str(API_VERSION))

            return response

    @classmethod
    def route(cls, url_path):
        """Provides a decorator method for handler functions to register
        at the given URL path.

        Leading slashes are stripped from ``url_path`` before registration.

        NOTE(review): ``resolve_route`` compares only the first path segment,
        so a path that still contains "/" after stripping will never be
        dispatched.

        :raises ValueError: (when decorating) if the path is already taken.
        """

        def wrapper(handler_function):
            base_path = url_path.lstrip("/")
            if base_path in cls.routes:
                existing_handler = cls.routes[base_path]
                args = (base_path, existing_handler, cls)
                raise ValueError("%s path already taken by %s in %s" % args)
            cls.routes[base_path] = handler_function
            return handler_function

        return wrapper

    def __repr__(self):
        return "App - Base WSGI application"

    def validate_user(self, request):  # pragma: no cover
        """Hook run after identification; subclasses must implement it."""
        raise NotImplementedError("Subclasses to implement")