Coverage for postrfp/model/questionnaire/cte_weights.py: 100%

67 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-22 21:34 +0000

1""" 

2CTE-based weight calculation for PostRFP questionnaires. 

3 

4This module implements weight calculations using Common Table Expressions (CTEs) 

5instead of the traditional visitor pattern approach, providing significant 

6performance improvements for large questionnaires. 

7""" 

8 

9import logging 

10from decimal import Decimal 

11from typing import Optional, Dict, Any, List, TYPE_CHECKING 

12 

13from sqlalchemy import text, insert 

14from sqlalchemy.orm import Session 

15 

16from postrfp.model.questionnaire.weightings import TotalWeighting 

17 

18if TYPE_CHECKING: 

19 from postrfp.model.project import Project 

20 

21log = logging.getLogger(__name__) 

22 

23 

# SQL that derives the per-section weights for one project and weighting set.
#
# It is a recursive CTE (written to be MariaDB-compatible) that starts at the
# sections with no parent, descends through their children, and then uses a
# window function to total the weights of sections sharing the same parent.
# Each section's weight divided by that total is its normalised weight.
#
# Bind parameters: :project_id, :weighting_set_id
# Output columns:  section_id, raw_weight, normalised_weight, path,
#                  level_depth, sort_path
SECTION_WEIGHTS_CTE_QUERY = """
WITH RECURSIVE section_hierarchy AS (
    -- Anchor: Start with root sections (parent_id IS NULL)
    SELECT
        s.id,
        s.parent_id,
        COALESCE(w.value, 1.0) as section_weight,
        CAST(s.id AS CHAR(1000)) as path,
        0 as level_depth,
        CAST(s.id AS CHAR(1000)) as sort_path
    FROM sections s
    LEFT JOIN weightings w ON (
        w.section_id = s.id
        AND w.weighting_set_id = :weighting_set_id
    )
    WHERE s.parent_id IS NULL
    AND s.project_id = :project_id

    UNION ALL

    -- Recursive: Join with child sections
    SELECT
        c.id,
        c.parent_id,
        COALESCE(w.value, 1.0) as section_weight,
        CONCAT(sh.path, ',', c.id) as path,
        sh.level_depth + 1 as level_depth,
        CONCAT(sh.sort_path, ',', LPAD(c.id, 10, '0')) as sort_path
    FROM sections c
    JOIN section_hierarchy sh ON c.parent_id = sh.id
    LEFT JOIN weightings w ON (
        w.section_id = c.id
        AND w.weighting_set_id = :weighting_set_id
    )
    WHERE c.project_id = :project_id
),

-- Calculate sibling totals for normalization
sibling_totals AS (
    SELECT
        sh.id,
        sh.parent_id,
        sh.section_weight,
        sh.path,
        sh.level_depth,
        sh.sort_path,
        SUM(sh.section_weight) OVER (PARTITION BY sh.parent_id) as sibling_total
    FROM section_hierarchy sh
),

-- Calculate normalized weights (weight relative to siblings)
normalized_sections AS (
    SELECT
        st.id,
        st.parent_id,
        st.section_weight, -- Keep the raw weight
        st.path,
        st.level_depth,
        st.sort_path,
        st.sibling_total,
        CASE
            WHEN st.sibling_total = 0 THEN 0.0
            ELSE st.section_weight / st.sibling_total
        END as normalised_weight
    FROM sibling_totals st
)

SELECT
    ns.id as section_id,
    ns.section_weight as raw_weight, -- Add raw weight to output
    ns.normalised_weight,
    ns.path,
    ns.level_depth,
    ns.sort_path
FROM normalized_sections ns
ORDER BY ns.sort_path
"""

103 

104 

def calculate_absolute_weights_from_hierarchy(
    session: Session, project: "Project", weighting_set_id: Optional[int] = None
) -> List[Dict[str, Any]]:
    """
    Return one dict per section holding its raw, normalised and absolute weight.

    The hierarchy and the sibling-normalised weights are produced by
    SECTION_WEIGHTS_CTE_QUERY; the absolute weights are then filled in here
    by multiplying each section's normalised weight by the absolute weight of
    its parent. Sections whose path has a single element (roots) get 1.0.

    If ``weighting_set_id`` is None and the project has a default weighting
    set, that set's id is the one bound into the query.

    Each returned dict has the keys ``section_id``, ``raw_weight``,
    ``normalised_weight``, ``path``, ``level_depth``, ``sort_path`` and
    ``absolute_weight``; the list is ordered by ``sort_path``.
    """
    # Fall back to the project's default weighting set when none was given.
    if weighting_set_id is None and project.default_weighting_set:
        effective_set_id = project.default_weighting_set.id
    else:
        effective_set_id = weighting_set_id

    rows = session.execute(
        text(SECTION_WEIGHTS_CTE_QUERY),
        {"project_id": project.id, "weighting_set_id": effective_set_id},
    )

    # Copy the rows into plain dicts; weights are converted to Decimal via
    # their string form.
    sections = [
        {
            "section_id": row.section_id,
            "raw_weight": Decimal(str(row.raw_weight)),
            "normalised_weight": Decimal(str(row.normalised_weight)),
            "path": row.path,
            "level_depth": row.level_depth,
            "sort_path": row.sort_path,
        }
        for row in rows
    ]

    # Order by sort_path so that a parent is always handled before its
    # children, which the multiplication below depends on.
    sections.sort(key=lambda entry: entry["sort_path"])

    by_id = {entry["section_id"]: entry for entry in sections}

    for entry in sections:
        ancestry = [int(part) for part in entry["path"].split(",")]

        if len(ancestry) == 1:
            # No ancestors in the path: this is a root section.
            entry["absolute_weight"] = Decimal("1.0")
            continue

        # The element just before the section's own id is its parent.
        parent = by_id[ancestry[-2]]
        entry["absolute_weight"] = (
            entry["normalised_weight"] * parent["absolute_weight"]
        )

    return sections

161 

162 

def calculate_question_weights(
    session: Session,
    project: "Project",
    section_weights: List[Dict[str, Any]],
    weighting_set_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    Compute normalised and absolute weights for every question in the project.

    A question's normalised weight is its own weight divided by the summed
    weight of all questions in the same section (0.0 when that sum is zero).
    Its absolute weight is the normalised weight multiplied by the absolute
    weight of the section, taken from ``section_weights``; a section missing
    from ``section_weights`` is treated as having absolute weight 1.0.

    If ``weighting_set_id`` is None and the project has a default weighting
    set, that set's id is the one bound into the query.

    Each returned dict has the keys ``question_id``, ``section_id``,
    ``normalised_weight``, ``absolute_weight`` and ``raw_weight``.
    """
    # Fall back to the project's default weighting set when none was given.
    if weighting_set_id is None and project.default_weighting_set:
        effective_set_id = project.default_weighting_set.id
    else:
        effective_set_id = weighting_set_id

    absolute_by_section = {
        entry["section_id"]: entry["absolute_weight"] for entry in section_weights
    }

    # Every question instance of the project, with a weight of 1.0 where the
    # weighting set holds no explicit value for it.
    question_sql = text("""
        SELECT
            qi.id as question_id,
            qi.section_id,
            COALESCE(w.value, 1.0) as question_weight
        FROM question_instances qi
        LEFT JOIN weightings w ON (
            w.question_instance_id = qi.id
            AND w.weighting_set_id = :weighting_set_id
        )
        WHERE qi.project_id = :project_id
        ORDER BY qi.section_id, qi.id
    """)
    rows = session.execute(
        question_sql,
        {"project_id": project.id, "weighting_set_id": effective_set_id},
    )

    # section_id -> [(question_id, raw weight), ...] in query order
    grouped: Dict[int, List[Any]] = {}
    for row in rows:
        grouped.setdefault(row.section_id, []).append(
            (row.question_id, Decimal(str(row.question_weight)))
        )

    results = []
    for section_id, members in grouped.items():
        weight_sum = sum(raw for _, raw in members)
        section_absolute = absolute_by_section.get(section_id, Decimal("1.0"))

        for question_id, raw in members:
            # Share of this question among the questions of its section.
            share = Decimal("0.0") if weight_sum == 0 else raw / weight_sum

            results.append(
                {
                    "question_id": question_id,
                    "section_id": section_id,
                    "normalised_weight": share,
                    "absolute_weight": share * section_absolute,
                    "raw_weight": raw,
                }
            )

    return results

248 

249 

def save_total_weights_cte(
    session: Session, project: "Project", weighting_set_id: Optional[int] = None
) -> None:
    """
    Calculate and save total weights using CTE approach.

    This is the main entry point that replaces the visitor-based approach.

    Existing TotalWeighting rows matching the project and ``weighting_set_id``
    are deleted, section and question weights are recomputed, and one
    TotalWeighting row is inserted per section and per question. The session
    is flushed but not committed by this function.

    Args:
        session: SQLAlchemy session used for the delete, the weight queries
            and the insert.
        project: Project whose weights are recalculated.
        weighting_set_id: Weighting set to calculate for. The value is stored
            on the saved rows exactly as given (including None); the
            calculation helpers resolve None to the project's default
            weighting set when one exists.
    """

    log.info(
        f"Calculating total weights using CTE for project {project.id}, "
        f"weighting_set {weighting_set_id}"
    )

    # Step 0: Clear existing total weights for this project and weighting set
    session.query(TotalWeighting).filter_by(
        project_id=project.id, weighting_set_id=weighting_set_id
    ).delete(synchronize_session=False)
    session.flush()

    # Step 1: Calculate section weights using CTE
    section_weights = calculate_absolute_weights_from_hierarchy(
        session, project, weighting_set_id
    )

    # Step 2: Calculate question weights
    question_weights = calculate_question_weights(
        session, project, section_weights, weighting_set_id
    )

    # Step 3: Prepare data for bulk insert. Section rows carry
    # question_instance_id 0 and question rows carry section_id 0.
    section_rows = [
        {
            "project_id": project.id,
            "weighting_set_id": weighting_set_id,
            "section_id": section_data["section_id"],
            "question_instance_id": 0,  # Default for sections
            "weight": section_data["raw_weight"],
            "normalised_weight": section_data["normalised_weight"],
            "absolute_weight": section_data["absolute_weight"],
        }
        for section_data in section_weights
    ]
    question_rows = [
        {
            "project_id": project.id,
            "weighting_set_id": weighting_set_id,
            "section_id": 0,  # Default for questions
            "question_instance_id": question_data["question_id"],
            "weight": question_data["raw_weight"],
            "normalised_weight": question_data["normalised_weight"],
            "absolute_weight": question_data["absolute_weight"],
        }
        for question_data in question_weights
    ]
    total_weightings_data = section_rows + question_rows

    # Step 4: Bulk insert the results
    if total_weightings_data:
        # Pass the rows as execute() parameters (the "executemany" form)
        # rather than insert().values(<list>): the latter compiles a single
        # multi-row VALUES statement with every value bound separately, which
        # SQLAlchemy documents as the less efficient choice for large batches.
        session.execute(insert(TotalWeighting), total_weightings_data)
        session.flush()

    log.info(
        f"Successfully saved {len(total_weightings_data)} total weight records "
        f"({len(section_weights)} sections, {len(question_weights)} questions)"
    )