Coverage for postrfp/web/middleware.py: 100%

31 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-22 21:34 +0000

1import logging 

2import json 

3import os 

4 

5from webob import Response 

6from webob.static import DirectoryApp 

7 

8from postrfp.shared.utils import json_default 

9from postrfp import conf 

10from postrfp.templates import get_template 

11from .request import HttpRequest 

12from ..shared.response import X_ACCEL_HEADER 

13from postrfp.mail.stub import MAILBOX, clear_mailbox 

14from postrfp.model.audit import AuditEvent, Status as EvtStatus 

15 

# Module-level logger, named after this module so handlers/levels can be
# configured per module.
log = logging.getLogger(name=__name__)

17 

18 

class DevMiddleware:  # pragma: no cover
    """WSGI middleware for local development only - never deploy it.

    Wraps the real application and:

    * answers a small set of debug endpoints itself (see ``ROUTES``),
    * emulates nginx's X-Accel-Redirect by serving the referenced file from
      the local cache / attachments directories,
    * adds permissive CORS headers to every other response.
    """

    # Handler methods reachable from a request path. The path is stripped of
    # surrounding slashes and inner slashes become underscores, so
    # ``/test/rollback`` maps to ``test_rollback``. Only names listed here are
    # dispatched; without this restriction paths such as ``/app``, ``/xaccel``
    # or ``/static_app`` would invoke arbitrary attributes, and
    # ``/process_request`` would recurse indefinitely. Subclasses may extend it.
    ROUTES = frozenset(
        {"test_rollback", "test_mailbox", "test_process_events", "config"}
    )

    def __init__(self, app, session_factory):
        """
        :param app: the WSGI application being wrapped.
        :param session_factory: zero-argument callable returning a database
            session; used by ``test_process_events``.
        """
        self.app = app
        self.session_factory = session_factory
        self.json_app = self.static_app(conf.CONF.cache_dir)
        self.attachment_app = self.static_app(conf.CONF.attachments_dir)

        log.warning(" Initialised DevMiddleware - for DEVELOPMENT ONLY! \n")
        log.info("cache_dir : %s ", conf.CONF.cache_dir)
        log.info("attachments_dir: %s", conf.CONF.attachments_dir)

    def static_app(self, dir_path):
        """Return a static-file WSGI app rooted at the parent of ``dir_path``.

        ``dir_path`` must expose ``.parent`` (e.g. a ``pathlib.Path``).
        """
        # the last bit of the path for attachments and cache directories is the databasename
        # nginx is configured to serve the next directory up (attachments/ or cache/)
        # so use the parent directory
        return DirectoryApp(dir_path.parent)

    def __call__(self, environ, start_response):
        """WSGI entry point.

        A debug endpoint handled by this middleware short-circuits the wrapped
        app; otherwise the app's response may be replaced (X-Accel-Redirect)
        or have headers added before it is returned.
        """
        request = HttpRequest(environ)
        preemptive_response = self.process_request(request)

        if preemptive_response is not None:
            return preemptive_response(environ, start_response)

        app_response = request.get_response(self.app)

        updated_response = self.process_response(request, app_response)
        if updated_response is not None:
            return updated_response(environ, start_response)

        return app_response(environ, start_response)

    def process_request(self, request):
        """Dispatch debug endpoints implemented on this middleware.

        :returns: the handler's response when the request path maps to a name
            in ``ROUTES`` (for any HTTP method), otherwise ``None`` so the
            request is passed on to the wrapped application.
        """
        under_path = request.path_info.strip("/").replace("/", "_")

        if under_path in self.ROUTES:
            # route '<any method> blah/bloo' to self.blah_bloo(request)
            return getattr(self, under_path)(request)
        return None

    def process_response(self, request, response):
        """Post-process the wrapped app's response.

        :returns: a replacement response when the app asked for an
            X-Accel-Redirect; otherwise ``None`` after adding CORS headers to
            ``response`` in place.
        """
        if X_ACCEL_HEADER in response.headers:
            return self.xaccel(request, response)

        # Anything-goes CORS configuration
        response.headers.update(
            {
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Methods": "GET, POST, DELETE, PUT, PATCH, OPTIONS",
                "Access-Control-Allow-Headers": "Content-Type, api_key, Authorization",
            }
        )
        return None

    def test_rollback(self, _request):
        """
        Rollback SAVEPOINT session when running multiple requests
        within a nested transaction. Also empties the debug mailbox.
        """
        from postrfp.shared import constants

        if constants.TEST_SESSION is not None:
            transaction = constants.TEST_SESSION.get_transaction()
            if transaction is not None:
                log.info("Rolling back transaction. Nested: %s", transaction.nested)
                constants.TEST_SESSION.rollback()
            constants.TEST_SESSION.close()
            constants.TEST_SESSION = None

        # Clear the mailbox of outbound emails
        clear_mailbox()
        return Response("ok")

    def test_mailbox(self, request):
        """
        Inspect contents of debug mailbox, newest first: JSON when the
        request prefers JSON, otherwise a rendered HTML page.
        """
        if request.prefers_json:
            emails = json.dumps(list(reversed(MAILBOX)), default=json_default)
            return Response(emails, charset="UTF-8", content_type="application/json")

        tmpl = get_template("tools/debug_emails.html")
        html = tmpl.render(emails=reversed(MAILBOX), request=request, user=None)
        return Response(html)

    def test_process_events(self, _request):
        """Run background events processor - for testing without
        background process or thread.

        Handles every pending ``AuditEvent`` and returns
        ``{"evt_count": <number handled>}`` as JSON.
        """
        from postrfp.jobs.events import handle_event

        # NOTE(review): the session is neither committed nor closed here;
        # presumably handle_event commits its own work - confirm.
        session = self.session_factory()
        evt_count = 0
        for evt in session.query(AuditEvent).filter(
            AuditEvent.status == EvtStatus.pending
        ):
            handle_event(evt, session)
            evt_count += 1

        ev_data = json.dumps({"evt_count": evt_count})
        return Response(ev_data, charset="UTF-8", content_type="application/json")

    def xaccel(self, request, response):
        """Mimic NGINX's X-Accel-Redirect functionality.

        The first segment of the redirect path selects the static app
        (``cache`` -> cache dir, anything else -> attachments dir). The rest
        is served from that app, keeping the original response's content
        type and disposition.
        """
        fpath = response.headers[X_ACCEL_HEADER]
        request.path_info = fpath
        # virtual_path is the path used address the relevant
        # 'location{' block in nginx config
        virtual_path = request.path_info_pop()

        app = self.json_app if virtual_path == "cache" else self.attachment_app

        new_response = request.get_response(app)

        if new_response.status_int == 404:
            # path_info now holds the remainder after the virtual segment was
            # popped, i.e. the file path relative to the static app's root.
            missing = os.path.join(app.path, request.path_info.lstrip("/"))
            log.error("File %s not found for request %s", missing, request)
            return new_response

        new_response.content_type = response.content_type
        new_response.content_disposition = response.content_disposition

        new_response.cache_control = "no-store"

        return new_response

    def config(self, request):
        """Show config information for server - db name etc"""
        tmpl = get_template("tools/conf_info.html")
        html = tmpl.render(conf=conf.CONF, request=request, user=None)
        return Response(html)

156 

157 

class DispatchingMiddleware:
    """Combine several WSGI applications into one, routing by path prefix.

    The longest mounted prefix of ``PATH_INFO`` (matched on ``/`` boundaries)
    selects the application. The matched prefix is appended to
    ``SCRIPT_NAME`` and the remainder becomes the new ``PATH_INFO``. Requests
    matching no mount go to the default ``app``.

    :param app: The WSGI application to dispatch to if the request
        doesn't match a mounted path.
    :param mounts: Maps path prefixes to applications for dispatching.
    """

    def __init__(self, app, mounts=None):
        self.app = app
        self.mounts = mounts if mounts else {}

    def _match(self, path):
        """Return ``(app, script, remainder)`` for the longest mounted prefix of path."""
        prefix, remainder = path, ""
        while "/" in prefix:
            if prefix in self.mounts:
                return self.mounts[prefix], prefix, remainder
            # Move the last path segment from the prefix onto the remainder.
            prefix, segment = prefix.rsplit("/", 1)
            remainder = "/" + segment + remainder
        return self.mounts.get(prefix, self.app), prefix, remainder

    def __call__(self, environ, start_response):
        target, script, remainder = self._match(environ.get("PATH_INFO", ""))
        environ["SCRIPT_NAME"] = environ.get("SCRIPT_NAME", "") + script
        environ["PATH_INFO"] = remainder
        return target(environ, start_response)