Coverage for postrfp / web / base.py: 100%
122 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 01:35 +0000
1import logging
2from zlib import crc32
3from contextlib import contextmanager
4from typing import Dict, List, Union, Callable, Optional, TYPE_CHECKING
6import orjson
7import webob.exc
8from webob import Response
9from webob.dec import wsgify
10from pydantic.main import BaseModel
11from semantic_version import Version # type: ignore[import]
13from postrfp.shared.tools import read_api_version
14from postrfp.shared.utils import json_default, benchmark
15from postrfp import conf
16from postrfp.shared.constants import RunMode
17from postrfp.auth.policy import DevHeaderPolicy, AbstractIdentityPolicy, JwtBearerPolicy
18from postrfp.templates import init_jinja, get_template
19from postrfp.web.request import HttpRequest
20from postrfp.web.exception import resolve_exception
22if TYPE_CHECKING:
23 from postrfp.web.suxint import Sux
log = logging.getLogger(__name__)

# Name of the HTTP response header that advertises the API version.
API_VERSION_HTTP_HEADER = "X-POSTRFP-API-VERSION"

# Raw version string as returned by read_api_version(); it is parsed into a
# semantic_version.Version once, at import time.
v = read_api_version()
API_VERSION = Version(v)
def jsonify_models(api_output) -> Union[List, Dict]:
    """
    Convert a non-empty list or set of pydantic models into a list of dicts.

    The type of the first element is used to decide: if it is a pydantic
    ``BaseModel``, every element is dumped with ``model_dump(by_alias=True)``
    and a ``list`` is returned. In every other case (empty containers,
    containers of plain values, dicts, scalars), ``api_output`` is returned
    unchanged.

    The input container is never mutated.
    """
    if isinstance(api_output, (list, set)) and api_output:
        # Peek at one element without consuming it. The previous set.pop()
        # removed an item from the caller's set, so that item was lost from
        # both the returned list and the returned set.
        first = next(iter(api_output))
        if isinstance(first, BaseModel):
            return [item.model_dump(by_alias=True) for item in api_output]
    return api_output
def render(request: HttpRequest, api_output):
    """
    Serialise an API handler's return value into a webob ``Response``.

    Serialisation:
      * ``None`` becomes ``{"result": "ok"}``.
      * A pydantic model is dumped with ``model_dump_json(by_alias=True)``.
      * Anything else goes through ``jsonify_models`` and ``orjson.dumps``,
        with ``json_default`` handling types orjson does not support.

    When ``request.prefers_json`` is true, the JSON bytes are returned as
    ``application/json``. Otherwise the JSON is embedded in the
    ``api.html`` template.

    Caching: if the request carries a truthy ``generate_etag`` attribute,
    the CRC32 of the JSON bytes becomes the ETag. A ``304 Not Modified`` is
    returned when that ETag matches ``If-None-Match``. Without
    ``generate_etag``, ``Cache-Control: no-cache`` is set.
    """
    if api_output is None:
        json_bytes = b'{"result": "ok"}'
    elif isinstance(api_output, BaseModel):
        json_bytes = api_output.model_dump_json(by_alias=True).encode("utf-8")
    else:
        json_bytes = orjson.dumps(jsonify_models(api_output), default=json_default)

    if request.prefers_json:
        res = Response(json_bytes, charset="utf-8", content_type="application/json")
    else:
        template = get_template("api.html")
        html_output = template.render(
            js_doc=json_bytes.decode("utf-8"), url=request.path, request=request
        )
        res = Response(html_output)

    if not getattr(request, "generate_etag", False):
        res.headers.add("Cache-Control", "no-cache")
        return res

    res.headers.add("Cache-Control", "must-revalidate")
    res.etag = str(crc32(json_bytes))
    if res.etag in request.if_none_match:
        # A 304 must repeat the Cache-Control the 200 would have carried
        # (RFC 9110 §15.4.5); previously only the ETag was sent.
        not_modified = webob.exc.HTTPNotModified(etag=res.etag)
        not_modified.headers.add("Cache-Control", "must-revalidate")
        return not_modified
    return res
@contextmanager
def commit_or_rollback(session):
    """
    Run the ``with`` body as a single unit of work on *session*.

    Outcomes:
      * Body finishes normally: the session is committed.
      * Body raises an ``Exception``: the session is rolled back and the
        exception propagates.
      * Every case: the session is closed on the way out.

    A failure raised by ``commit()`` itself is not rolled back here; it
    propagates after ``close()``.
    """
    try:
        try:
            yield
        except Exception:
            session.rollback()
            raise
        # Only reached when the body did not raise.
        session.commit()
    finally:
        session.close()
class WSGIApp:
    """
    Base WSGI application.

    Requests are routed on the first path segment:

    * ``api_path`` goes to :meth:`rest_api`, which delegates to
      ``self.sux_instance``.
    * Any segment registered with :meth:`route` goes to that handler.
    * Anything else gets ``404 Not Found``.

    Each handler runs inside a session obtained from ``session_factory``.
    The session is committed on success, rolled back on error, and always
    closed.

    Subclasses must implement :meth:`build_sux` and :meth:`validate_user`.
    """

    # Path segment -> handler, populated by the ``route`` decorator. This is
    # a class attribute: subclasses that do not define their own ``routes``
    # register into, and share, this same dict.
    routes: dict[str, Callable] = {}

    def __init__(
        self,
        session_factory=None,
        auth_policy: (
            AbstractIdentityPolicy | type[AbstractIdentityPolicy] | None
        ) = None,
        api_path="api",
    ):
        """
        :param session_factory: zero-argument callable returning a new
            database session with ``commit``, ``rollback`` and ``close``.
        :param auth_policy: one of:
            * an ``AbstractIdentityPolicy`` instance;
            * a class, which is instantiated with no arguments;
            * ``None``, to pick ``DevHeaderPolicy`` in test/development
              run modes and ``JwtBearerPolicy`` otherwise.
        :param api_path: first path segment served by the Sux REST API.
        :raises TypeError: if ``auth_policy`` is neither a class nor an
            ``AbstractIdentityPolicy`` instance.
        """
        self.session_factory = session_factory
        self.sux_instance: Optional[Sux] = None
        self.api_path = api_path
        self.auth_policy = auth_policy

        if auth_policy is None:
            if (
                conf.CONF.run_mode is RunMode.test
                or conf.CONF.run_mode is RunMode.development
            ):
                self.auth_policy = DevHeaderPolicy()
            else:
                self.auth_policy = JwtBearerPolicy()
        elif isinstance(auth_policy, type):
            self.auth_policy = auth_policy()
        elif not isinstance(auth_policy, AbstractIdentityPolicy):
            raise TypeError("auth_policy must inherit from AbstractIdentityPolicy")

        init_jinja()
        self.build_sux()

        log.info(
            "%s App initialised. Auth: %s. API Version %s",
            self.__class__.__name__,
            self.auth_policy.__class__.__name__,
            API_VERSION,
        )

    def build_sux(self):  # pragma: no cover
        """
        Build the Sux API served under ``api_path``.

        A subclass that wants to serve a suxint.Sux API must implement this
        method and assign a value to ``self.sux_instance``.
        """
        raise NotImplementedError

    @wsgify(RequestClass=HttpRequest)
    def __call__(self, request):
        """
        Handle one WSGI request.

        Steps:
          1. Resolve the handler.
          2. Open a session and attach it as ``request.session``.
          3. Inside :func:`commit_or_rollback`, authenticate the request and
             run the handler.
          4. Let the auth policy ``remember`` the identity on the response.

        In development run mode, exceptions are logged and re-raised.
        Otherwise they are converted into a response by ``resolve_exception``.
        """
        try:
            handler = self.resolve_route(request)

            request.session = session = self.session_factory()
            # Authentication runs inside the unit of work so the session is
            # rolled back and closed if identification/validation fails.
            # Previously an auth error left the new session unclosed.
            with commit_or_rollback(session):
                self.authenticate(request)
                response = handler(request)

            self.auth_policy.remember(request, response)
            return response

        except Exception as e:
            if conf.CONF.run_mode is RunMode.development:
                log.exception(
                    "Exception in Base webapp, RunMode.development so raising.."
                )
                raise
            # Set request user to None to avoid detached sqla session
            # errors caused by the User object lurking in the environ dict.
            request.user = None
            return resolve_exception(request, e)

    def authenticate(self, request):
        """Identify the caller via the auth policy, then validate the user."""
        self.auth_policy.identify(request)
        self.validate_user(request)

    def resolve_route(self, request):
        """
        Return the handler for the request's first path segment.

        :raises webob.exc.HTTPNotFound: if no handler matches.
        """
        sub_app = request.path_info_peek() or ""

        if sub_app == self.api_path:
            return self.rest_api
        if sub_app in self.routes:
            return self.routes[sub_app]

        log.warning("No handler found for sub_app: %s", sub_app)
        raise webob.exc.HTTPNotFound

    def rest_api(self, request):
        """
        Dispatch the request to the Sux API and return a response.

        A ``Response`` returned by Sux is used as-is; any other value is
        serialised by ``render``. The API version header is added either way.
        """
        path_info = request.path_info
        with benchmark("API call to %s %s" % (request.method, path_info)):
            api_output = self.sux_instance(request)

        if isinstance(api_output, Response):
            response = api_output
        else:
            response = render(request, api_output)

        response.headers.add(API_VERSION_HTTP_HEADER, str(API_VERSION))
        return response

    @classmethod
    def route(cls, url_path):
        """
        Return a decorator that registers a handler at ``url_path``.

        Leading slashes are stripped from ``url_path``.

        :raises ValueError: if the path is already registered.
        """

        def wrapper(handler_function):
            base_path = url_path.lstrip("/")
            if base_path in cls.routes:
                existing_handler = cls.routes[base_path]
                args = (base_path, existing_handler, cls)
                raise ValueError("%s path already taken by %s in %s" % args)
            cls.routes[base_path] = handler_function
            return handler_function

        return wrapper

    def __repr__(self):
        return "App - Base WSGI application"

    def validate_user(self, request):  # pragma: no cover
        """Validate the identified user; subclasses must implement this."""
        raise NotImplementedError("Subclasses to implement")