Coverage for postrfp/web/base.py: 100%

123 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-22 21:34 +0000

1import logging 

2from zlib import crc32 

3from contextlib import contextmanager 

4from typing import Dict, List, Union, Callable, Optional, TYPE_CHECKING 

5 

6import orjson 

7import webob.exc 

8from webob import Response 

9from webob.dec import wsgify 

10from pydantic.main import BaseModel 

11from semantic_version import Version # type: ignore[import] 

12 

13from postrfp.shared.tools import read_api_version 

14from postrfp.shared.utils import json_default, benchmark 

15from postrfp import conf 

16from postrfp.conf.settings import RunMode 

17from postrfp.auth.policy import DevHeaderPolicy, AbstractIdentityPolicy, JwtBearerPolicy 

18from postrfp.templates import init_jinja, get_template 

19from postrfp.web.request import HttpRequest 

20from postrfp.web.exception import resolve_exception 

21 

22if TYPE_CHECKING: 

23 from postrfp.web.suxint import Sux 

24 

25 

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Generic error text. NOTE(review): not referenced by the code visible in
# this module - presumably imported elsewhere; confirm before removing.
JS_ERR_MSG = "Server Error"

# Name of the HTTP response header that WSGIApp.rest_api adds to every
# API response to advertise the API version.
API_VERSION_HTTP_HEADER = "X-POSTRFP-API-VERSION"

# The version string returned by read_api_version() is parsed once at
# import time into a semantic_version.Version object.
v = read_api_version()
API_VERSION = Version(v)

35 

36 

def jsonify_models(api_output) -> Union[List, Dict]:
    """Convert a collection of pydantic models into a list of plain dicts.

    If ``api_output`` is a non-empty ``list`` or ``set`` whose first element
    is a ``BaseModel``, every element is dumped with
    ``model_dump(by_alias=True)`` and the resulting list is returned.
    Only the first element is inspected to decide this.

    Any other value (empty collections, dicts, scalars, collections of
    non-model items) is returned unchanged.

    The input collection is never mutated.
    """
    if isinstance(api_output, (list, set)) and api_output:
        # Peek at one element without removing it. Using set.pop() here
        # would drop that element from the caller's set and from the
        # serialised output.
        first = next(iter(api_output))
        if isinstance(first, BaseModel):
            return [model.model_dump(by_alias=True) for model in api_output]
    return api_output

43 

44 

def render(request: HttpRequest, api_output):
    """Build the HTTP response for the value returned by an API handler.

    The output is first serialised to JSON bytes:

    * ``None`` becomes the literal ``{"result": "ok"}``
    * a single ``BaseModel`` is dumped with ``model_dump_json(by_alias=True)``
    * anything else goes through ``jsonify_models`` and ``orjson.dumps``

    When ``request.prefers_json`` is true the bytes are sent as
    ``application/json``; otherwise they are embedded in the ``api.html``
    template.

    If the request has a truthy ``generate_etag`` attribute the response
    carries a CRC32-based ETag with ``Cache-Control: must-revalidate``, and
    an ``HTTPNotModified`` is returned instead when that ETag matches the
    request's ``If-None-Match``. Otherwise ``Cache-Control: no-cache`` is set.
    """
    if api_output is None:
        payload = b'{"result": "ok"}'
    elif isinstance(api_output, BaseModel):
        payload = api_output.model_dump_json(by_alias=True).encode("utf-8")
    else:
        payload = orjson.dumps(jsonify_models(api_output), default=json_default)

    if request.prefers_json:
        response = Response(
            payload, charset="utf-8", content_type="application/json"
        )
    else:
        page = get_template("api.html").render(
            js_doc=payload.decode("utf-8"), url=request.path, request=request
        )
        response = Response(page)

    # Guard clause: responses without ETag support must not be cached.
    if not getattr(request, "generate_etag", False):
        response.headers.add("Cache-Control", "no-cache")
        return response

    response.headers.add("Cache-Control", "must-revalidate")
    response.etag = str(crc32(payload))
    if response.etag in request.if_none_match:
        # The client already holds this representation.
        return webob.exc.HTTPNotModified(etag=response.etag)
    return response

73 

74 

@contextmanager
def commit_or_rollback(session):
    """Context manager that wraps a unit of work on ``session``.

    On normal exit from the ``with`` body the session is committed. If the
    body raises an ``Exception`` the session is rolled back and the
    exception is re-raised. In every case the session is closed on exit.
    """
    try:
        try:
            yield
        except Exception:
            session.rollback()
            raise
        # Only reached when the body finished without raising.
        session.commit()
    finally:
        session.close()

86 

87 

class WSGIApp(object):
    """
    Entry point for WSGI commerce

    Base WSGI application. For each request it resolves the first path
    segment to a handler, creates a session from ``session_factory``,
    authenticates the request and runs the handler inside a
    commit-or-rollback block.

    Subclasses must implement ``build_sux`` (called from ``__init__``) and
    ``validate_user`` (called from ``authenticate``).
    """

    # Maps a top-level path segment to its handler; populated through the
    # ``route`` decorator. Defined on this class, so subclasses that do not
    # declare their own ``routes`` share this dict.
    routes: dict[str, Callable] = {}

    def __init__(
        self,
        session_factory=None,
        auth_policy: AbstractIdentityPolicy | None = None,
        api_path="api",
    ):
        """
        :param session_factory: zero-argument callable invoked once per
            request to create the session assigned to ``request.session``
        :param auth_policy: an ``AbstractIdentityPolicy`` instance, or a
            class that is instantiated with no arguments. When ``None`` a
            ``DevHeaderPolicy`` is used in test/development run modes and
            a ``JwtBearerPolicy`` otherwise
        :param api_path: first path segment that is routed to ``rest_api``
        :raises TypeError: if ``auth_policy`` is neither a class nor an
            ``AbstractIdentityPolicy`` instance
        """
        self.session_factory = session_factory
        self.sux_instance: Optional[Sux] = None
        self.api_path = api_path
        self.auth_policy = auth_policy

        if auth_policy is None:
            if (
                conf.CONF.run_mode is RunMode.test
                or conf.CONF.run_mode is RunMode.development
            ):
                self.auth_policy = DevHeaderPolicy()
            else:
                self.auth_policy = JwtBearerPolicy()
        else:
            if isinstance(auth_policy, type):
                # A policy class was passed rather than an instance.
                self.auth_policy = auth_policy()
            elif not isinstance(auth_policy, AbstractIdentityPolicy):
                raise TypeError("auth_policy must inherit from AbstractIdentityPolicy")

        init_jinja()
        self.build_sux()

        log.info(
            "%s App initialised. Auth: %s. API Version %s",
            self.__class__.__name__,
            self.auth_policy.__class__.__name__,
            API_VERSION,
        )

    def build_sux(self):  # pragma: no cover
        """
        If a subclass wants to serve a suxint.Sux API then it must
        implement this method to assign a value to self.sux_instance
        """
        raise NotImplementedError

    @wsgify(RequestClass=HttpRequest)
    def __call__(self, request):
        """Handle one WSGI request and return the response.

        The route is resolved before a session is created, so an unknown
        path does not open a session. Outside development mode any
        exception is converted to a response by ``resolve_exception``;
        in development mode it is logged and re-raised.
        """
        try:
            handler = self.resolve_route(request)

            request.session = session = self.session_factory()

            # Authentication runs inside the managed block so that the
            # session is rolled back and closed when identification or
            # user validation raises, not only when the handler does.
            with commit_or_rollback(session):
                self.authenticate(request)
                response = handler(request)

            self.auth_policy.remember(request, response)

            return response

        except Exception as e:
            if conf.CONF.run_mode is RunMode.development:
                log.exception(
                    "Exception in Base webapp, RunMode.development so raising.."
                )
                raise
            else:
                # Set request user to None to avoid detached sqla session
                # errors caused by User object lurking in environ dict
                request.user = None
                return resolve_exception(request, e)

    def authenticate(self, request):
        """Identify the caller via the auth policy, then validate the user."""
        self.auth_policy.identify(request)
        self.validate_user(request)

    def resolve_route(self, request):
        """Return the handler for the first path segment of ``request``.

        :raises webob.exc.HTTPNotFound: if the segment is neither
            ``api_path`` nor a key of ``routes``
        """
        sub_app = request.path_info_peek() or ""

        if sub_app == self.api_path:
            return self.rest_api

        elif sub_app in self.routes:
            return self.routes[sub_app]
        else:
            log.warning("No handler found for sub_app: %s", sub_app)
            raise webob.exc.HTTPNotFound

    def rest_api(self, request):
        """Dispatch ``request`` to ``sux_instance`` and build the response.

        A ``Response`` returned by the API is passed through as is; any
        other value is converted by ``render``. The API version header is
        added in both cases.
        """
        path_info = request.path_info
        with benchmark("API call to %s %s" % (request.method, path_info)):
            api_output = self.sux_instance(request)

            if isinstance(api_output, Response):
                response = api_output
            else:
                response = render(request, api_output)

            response.headers.add(API_VERSION_HTTP_HEADER, str(API_VERSION))

            return response

    @classmethod
    def route(cls, url_path):
        """Provides a decorator method for handler functions to register
        at the given URL path

        Leading slashes are stripped from ``url_path`` before registration.

        :raises ValueError: if the path is already registered
        """

        def wrapper(handler_function):
            base_path = url_path.lstrip("/")
            if base_path in cls.routes:
                existing_handler = cls.routes[base_path]
                args = (base_path, existing_handler, cls)
                raise ValueError("%s path already taken by %s in %s" % args)
            cls.routes[base_path] = handler_function
            return handler_function

        return wrapper

    def __repr__(self):
        return "App - Base WSGI application"

    def validate_user(self, request):  # pragma: no cover
        """Check the identified user; called by ``authenticate``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError("Subclasses to implement")