Coverage for postrfp / web / middleware.py: 100%
31 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 01:35 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 01:35 +0000
1import logging
2import json
3import os
5from webob import Response
6from webob.static import DirectoryApp
8from postrfp.shared.utils import json_default
9from postrfp import conf
10from postrfp.templates import get_template
11from .request import HttpRequest
12from ..shared.response import X_ACCEL_HEADER
13from postrfp.mail.stub import MAILBOX, clear_mailbox
14from postrfp.model.audit import AuditEvent, Status as EvtStatus
16log = logging.getLogger(__name__)
class DevMiddleware:  # pragma: no cover
    """
    Development middleware for PostRFP.

    Wraps a WSGI application and adds conveniences for local development:

    * a small set of debug endpoints (see ``DEV_ROUTES``) that are answered
      by this middleware before the wrapped application is consulted;
    * anything-goes CORS headers on application responses;
    * an emulation of NGINX's X-Accel-Redirect, serving the referenced file
      from the cache or attachments directory.

    This middleware is intended for use in a development environment only.
    """

    # Names of the handler methods that may be reached by URL. A request for
    # 'GET blah/bloo' is routed to self.blah_bloo(request) only when
    # 'blah_bloo' is listed here. Dispatching to *any* attribute would let
    # '/process_request' recurse forever and '/app' or '/xaccel' invoke
    # callables that do not accept a lone request argument.
    DEV_ROUTES = frozenset(
        {"test_rollback", "test_mailbox", "test_process_events", "config"}
    )

    def __init__(self, app, session_factory):
        """
        :param app: the WSGI application being wrapped.
        :param session_factory: callable returning a database session; used
            by ``test_process_events``.
        """
        self.app = app
        self.session_factory = session_factory
        self.json_app = self.static_app(conf.CONF.cache_dir)
        self.attachment_app = self.static_app(conf.CONF.attachments_dir)
        msg = (
            f"--Initialised DevMiddleware for {self.app.__class__.__name__} - DEV ONLY!"
        )
        # Deliberately both logged and printed so the warning is hard to miss.
        log.warning(msg)
        print(msg)
        log.info("cache_dir : %s ", conf.CONF.cache_dir)
        log.info("attachments_dir: %s", conf.CONF.attachments_dir)

    def static_app(self, dir_path):
        """Return a static file app rooted at the parent of ``dir_path``."""
        # the last bit of the path for attachments and cache directories is the databasename
        # nginx is configured to serve the next directory up (attachments/ or cache/)
        # so use the parent directory
        return DirectoryApp(dir_path.parent)

    def __call__(self, environ, start_response):
        """WSGI entry point: try dev routes, then the app, then post-process."""
        request = HttpRequest(environ)
        preemptive_response = self.process_request(request)

        if preemptive_response is not None:
            return preemptive_response(environ, start_response)

        app_response = request.get_response(self.app)

        updated_response = self.process_response(request, app_response)
        if updated_response is not None:
            return updated_response(environ, start_response)

        return app_response(environ, start_response)

    def process_request(self, request):
        """
        Answer the request directly if its path names a dev route.

        The path is converted to a handler name by stripping the surrounding
        slashes and replacing the remaining ones with underscores.

        :returns: the handler's response, or ``None`` when the path is not a
            dev route and the wrapped application should handle it.
        """
        under_path = request.path_info.strip("/").replace("/", "_")

        if under_path not in self.DEV_ROUTES:
            return None

        # route 'GET blah/bloo' to self.blah_bloo(request)
        return getattr(self, under_path)(request)

    def process_response(self, request, response):
        """
        Post-process the wrapped application's response.

        :returns: a replacement response when the application asked for an
            X-Accel redirect; otherwise ``None`` after adding CORS headers to
            ``response`` in place.
        """
        if X_ACCEL_HEADER in response.headers:
            return self.xaccel(request, response)

        # Anything-goes CORS configuration
        response.headers.update(
            {
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Methods": "GET, POST, DELETE, PUT, PATCH, OPTIONS",
                "Access-Control-Allow-Headers": "Content-Type, api_key, Authorization",
            }
        )
        return None

    def test_rollback(self, _request):
        """
        Rollback SAVEPOINT session when running multiple requests
        within a nested transaction
        """
        from postrfp.shared import constants

        if constants.TEST_SESSION is not None:
            transaction = constants.TEST_SESSION.get_transaction()
            if transaction is not None:
                log.info(
                    "Rolling back transaction. Nested: %s",
                    transaction.nested,
                )
            constants.TEST_SESSION.rollback()
            constants.TEST_SESSION.close()
            constants.TEST_SESSION = None

        # Clear the mailbox of outbound emails
        clear_mailbox()
        return Response("ok")

    def test_mailbox(self, request):
        """
        Inspect contents of debug mailbox

        Returns the stubbed outbound emails in reverse order, as JSON when
        the request prefers it and as a rendered HTML page otherwise.
        """
        if request.prefers_json:
            emails = json.dumps(list(reversed(MAILBOX)), default=json_default)
            return Response(emails, charset="UTF-8", content_type="application/json")

        tmpl = get_template("tools/debug_emails.html")
        html = tmpl.render(emails=reversed(MAILBOX), request=request, user=None)
        return Response(html)

    def test_process_events(self, _request):
        """Run background events processor - for testing without
        background process or thread

        Returns a JSON body of the form ``{"evt_count": <n>}`` giving the
        number of pending events that were handled.
        """
        from postrfp.jobs.events import handle_event

        # NOTE(review): the session is neither committed nor closed here;
        # presumably handle_event or the session factory owns that - confirm.
        session = self.session_factory()
        evt_count = 0
        for evt in session.query(AuditEvent).filter(
            AuditEvent.status == EvtStatus.pending
        ):
            handle_event(evt, session)
            evt_count += 1

        ev_data = json.dumps({"evt_count": evt_count})
        return Response(ev_data, charset="UTF-8", content_type="application/json")

    def xaccel(self, request, response):
        """Mimic NGINX's X-Accel-Redirect functionality

        Serves the file named by the X-Accel header of ``response`` from the
        matching static app, carrying over the content type and content
        disposition chosen by the application.

        :returns: the static app's response; a 404 is logged and returned
            unchanged.
        """
        # Read the same header name that process_response tested for.
        fpath = response.headers[X_ACCEL_HEADER]
        request.path_info = fpath
        # virtual_path is the path used address the relevant
        # 'location{' block in nginx config
        virtual_path = request.path_info_pop()

        if virtual_path == "cache":
            app = self.json_app
        else:
            app = self.attachment_app

        # What is left of the path after popping the virtual prefix is what
        # the static app resolves beneath its root directory; keep it for
        # the log message in case the lookup fails.
        relative_path = request.path_info.lstrip("/")

        new_response = request.get_response(app)

        if new_response.status_int == 404:
            # Pass the values as logging arguments: formatting a two
            # placeholder message with a single '%' operand raises TypeError.
            log.error(
                "File %s not found for request %s",
                os.path.join(app.path, relative_path),
                request,
            )
            return new_response

        new_response.content_type = response.content_type
        new_response.content_disposition = response.content_disposition

        # Never let the browser cache files served through the dev emulation.
        new_response.cache_control = "no-store"

        return new_response

    def config(self, request):
        """Show config information for server - db name etc"""
        tmpl = get_template("tools/conf_info.html")
        html = tmpl.render(conf=conf.CONF, request=request, user=None)
        return Response(html)
class DispatchingMiddleware:
    """Serve several WSGI applications behind one WSGI callable.

    The request path is matched against the mounted prefixes, longest
    prefix first, and the request is handed to the application mounted
    there. The matched prefix is moved from ``PATH_INFO`` onto
    ``SCRIPT_NAME`` before the chosen application is called.

    :param app: The WSGI application that receives requests whose path
        matches none of the mounted prefixes.
    :param mounts: Mapping of path prefix to the application mounted on it.
    """

    def __init__(self, app, mounts=None):
        self.app = app
        self.mounts = mounts or {}

    def __call__(self, environ, start_response):
        mount_point = environ.get("PATH_INFO", "")
        remainder = ""

        while True:
            if "/" not in mount_point:
                # Nothing left to strip: use an exact match on what remains,
                # falling back to the default application.
                handler = self.mounts.get(mount_point, self.app)
                break

            if mount_point in self.mounts:
                handler = self.mounts[mount_point]
                break

            # Move the last path segment from the candidate prefix onto the
            # front of the path the chosen application will see.
            mount_point, _, segment = mount_point.rpartition("/")
            remainder = f"/{segment}{remainder}"

        environ["SCRIPT_NAME"] = environ.get("SCRIPT_NAME", "") + mount_point
        environ["PATH_INFO"] = remainder
        return handler(environ, start_response)