Coverage for postrfp / web / middleware.py: 100%

31 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 01:35 +0000

1import logging 

2import json 

3import os 

4 

5from webob import Response 

6from webob.static import DirectoryApp 

7 

8from postrfp.shared.utils import json_default 

9from postrfp import conf 

10from postrfp.templates import get_template 

11from .request import HttpRequest 

12from ..shared.response import X_ACCEL_HEADER 

13from postrfp.mail.stub import MAILBOX, clear_mailbox 

14from postrfp.model.audit import AuditEvent, Status as EvtStatus 

15 

log = logging.getLogger(__name__)


class DevMiddleware:  # pragma: no cover
    """
    Development middleware for PostRFP.

    Wraps a WSGI application and, for development use:

    * serves a small set of helper endpoints (see ``ROUTES``) that
      short-circuit the wrapped application,
    * emulates nginx's ``X-Accel-Redirect`` handling by serving files from
      the cache / attachments directories, and
    * adds permissive ("anything goes") CORS headers to other responses.

    This middleware is intended for use in a development environment only.
    """

    # Method names reachable as URL routes. A request path is turned into a
    # name by stripping surrounding slashes and replacing inner slashes with
    # underscores, so 'GET /test/rollback' -> self.test_rollback(request).
    # Only names listed here are dispatched; this stops crafted paths from
    # invoking internals such as __init__, xaccel, process_request (infinite
    # recursion) or the wrapped apps. Subclasses may extend this set.
    ROUTES = frozenset(
        {"test_rollback", "test_mailbox", "test_process_events", "config"}
    )

    def __init__(self, app, session_factory):
        """
        :param app: the WSGI application being wrapped.
        :param session_factory: zero-argument callable returning a DB session,
            used by ``test_process_events``.
        """
        self.app = app
        self.session_factory = session_factory
        self.json_app = self.static_app(conf.CONF.cache_dir)
        self.attachment_app = self.static_app(conf.CONF.attachments_dir)
        msg = (
            f"--Initialised DevMiddleware for {self.app.__class__.__name__} - DEV ONLY!"
        )
        log.warning(msg)
        print(msg)
        log.info("cache_dir : %s ", conf.CONF.cache_dir)
        log.info("attachments_dir: %s", conf.CONF.attachments_dir)

    def static_app(self, dir_path):
        """Return a ``DirectoryApp`` serving the parent of ``dir_path``."""
        # the last bit of the path for attachments and cache directories is the databasename
        # nginx is configured to serve the next directory up (attachments/ or cache/)
        # so use the parent directory
        return DirectoryApp(dir_path.parent)

    def __call__(self, environ, start_response):
        """WSGI entry point: pre-process, delegate, then post-process."""
        request = HttpRequest(environ)
        preemptive_response = self.process_request(request)

        if preemptive_response is not None:
            return preemptive_response(environ, start_response)

        app_response = request.get_response(self.app)

        updated_response = self.process_response(request, app_response)
        if updated_response is not None:
            return updated_response(environ, start_response)

        return app_response(environ, start_response)

    def process_request(self, request):
        """
        Dispatch dev-only helper routes before the wrapped app sees the request.

        Returns the handler's response when the path maps to a name in
        ``ROUTES``, otherwise ``None`` so the request falls through to the
        wrapped application.
        """
        under_path = request.path_info.strip("/").replace("/", "_")
        if under_path not in self.ROUTES:
            return None
        # route 'GET blah/bloo' to self.blah_bloo(request)
        return getattr(self, under_path)(request)

    def process_response(self, request, response):
        """
        Post-process the wrapped app's response.

        Responses carrying the X-Accel header are replaced by the file they
        point at (see ``xaccel``). All other responses get permissive CORS
        headers added in place, and ``None`` is returned so the caller uses
        the (mutated) original response.
        """
        if X_ACCEL_HEADER in response.headers:
            return self.xaccel(request, response)

        # Anything-goes CORS configuration
        response.headers.update(
            {
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Methods": "GET, POST, DELETE, PUT, PATCH, OPTIONS",
                "Access-Control-Allow-Headers": "Content-Type, api_key, Authorization",
            }
        )
        return None

    def test_rollback(self, _request):
        """
        Rollback SAVEPOINT session when running multiple requests
        within a nested transaction
        """
        from postrfp.shared import constants

        if constants.TEST_SESSION is not None:
            if constants.TEST_SESSION.get_transaction() is not None:
                log.info(
                    "Rolling back transaction. Nested: %s",
                    constants.TEST_SESSION.get_transaction().nested,
                )
                constants.TEST_SESSION.rollback()
            constants.TEST_SESSION.close()
            constants.TEST_SESSION = None

        # Clear the mailbox of outbound emails
        clear_mailbox()
        return Response("ok")

    def test_mailbox(self, request):
        """
        Inspect contents of debug mailbox (newest first).

        Returns JSON when the request prefers it, otherwise an HTML page.
        """
        if request.prefers_json:
            emails = json.dumps(list(reversed(MAILBOX)), default=json_default)
            return Response(emails, charset="UTF-8", content_type="application/json")
        tmpl = get_template("tools/debug_emails.html")
        html = tmpl.render(emails=reversed(MAILBOX), request=request, user=None)
        return Response(html)

    def test_process_events(self, _request):
        """Run background events processor - for testing without
        background process or thread.

        Handles every pending ``AuditEvent`` and returns
        ``{"evt_count": <number handled>}`` as JSON.
        """
        from postrfp.jobs.events import handle_event

        # NOTE(review): the session is never closed here. Presumably
        # handle_event commits, and session_factory may hand back a session
        # shared with test setup -- confirm before adding cleanup.
        session = self.session_factory()
        evt_count = 0
        for evt in session.query(AuditEvent).filter(
            AuditEvent.status == EvtStatus.pending
        ):
            handle_event(evt, session)
            evt_count += 1

        ev_data = json.dumps({"evt_count": evt_count})
        return Response(ev_data, charset="UTF-8", content_type="application/json")

    def xaccel(self, request, response):
        """
        Mimic NGINX's X-Accel-Redirect functionality.

        The first segment of the header path selects the directory.
        ``cache`` is served by ``json_app``; anything else is served by
        ``attachment_app``. The rest of the path is resolved by that
        ``DirectoryApp``.

        On success the original response's content type and disposition are
        copied onto the file response, which is marked ``no-store``. A 404
        from the directory app is logged and returned unchanged.
        """
        # Use the same header constant that process_response checks.
        fpath = response.headers[X_ACCEL_HEADER]
        request.path_info = fpath
        # virtual_path is the path used to address the relevant
        # 'location{' block in nginx config
        virtual_path = request.path_info_pop()

        app = self.json_app if virtual_path == "cache" else self.attachment_app

        new_response = request.get_response(app)

        if new_response.status_int == 404:
            # After path_info_pop, path_info holds the remainder that the
            # DirectoryApp resolves relative to app.path. The raw header is
            # absolute and would make os.path.join discard app.path.
            missing = os.path.join(app.path, request.path_info.lstrip("/"))
            log.error("File %s not found for request %s", missing, request)
            return new_response

        new_response.content_type = response.content_type
        new_response.content_disposition = response.content_disposition

        new_response.cache_control = "no-store"

        return new_response

    def config(self, request):
        """Show config information for server - db name etc"""
        tmpl = get_template("tools/conf_info.html")
        html = tmpl.render(conf=conf.CONF, request=request, user=None)
        return Response(html)

166 

167 

class DispatchingMiddleware:
    """Combine multiple applications as a single WSGI application.

    Each request is dispatched to the application mounted under the longest
    matching path prefix. The matched prefix is appended to ``SCRIPT_NAME``
    and the remainder becomes ``PATH_INFO``. When no mount matches, the
    default ``app`` receives the request.

    :param app: The WSGI application to dispatch to if the request
        doesn't match a mounted path.
    :param mounts: Maps path prefixes to applications for dispatching.
    """

    def __init__(self, app, mounts=None):
        self.app = app
        self.mounts = mounts if mounts else {}

    def _match_mount(self, path):
        """Return ``(app, mount_prefix, remaining_path)`` for ``path``.

        Trailing segments are peeled off one at a time until the remaining
        prefix is a mount key. If none matches by the time no slash is
        left, the leftover prefix is looked up with ``self.app`` as the
        fallback.
        """
        prefix, remainder = path, ""
        while "/" in prefix:
            if prefix in self.mounts:
                return self.mounts[prefix], prefix, remainder
            prefix, tail = prefix.rsplit("/", 1)
            remainder = "/" + tail + remainder
        return self.mounts.get(prefix, self.app), prefix, remainder

    def __call__(self, environ, start_response):
        target, prefix, remainder = self._match_mount(environ.get("PATH_INFO", ""))
        environ["SCRIPT_NAME"] = environ.get("SCRIPT_NAME", "") + prefix
        environ["PATH_INFO"] = remainder
        return target(environ, start_response)