Coverage for postrfp/model/questionnaire/cte_weights.py: 100%
67 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-22 21:34 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-22 21:34 +0000
1"""
2CTE-based weight calculation for PostRFP questionnaires.
4This module implements weight calculations using Common Table Expressions (CTEs)
5instead of the traditional visitor pattern approach, providing significant
6performance improvements for large questionnaires.
7"""
9import logging
10from decimal import Decimal
11from typing import Optional, Dict, Any, List, TYPE_CHECKING
13from sqlalchemy import text, insert
14from sqlalchemy.orm import Session
16from postrfp.model.questionnaire.weightings import TotalWeighting
18if TYPE_CHECKING:
19 from postrfp.model.project import Project
21log = logging.getLogger(__name__)
# Recursive CTE query for calculating section weights.
# MariaDB-compatible recursive CTE that builds the section hierarchy of one
# project and calculates each section's weight normalised against its siblings.
#
# Bind parameters:
#   :project_id        - restricts both the anchor and the recursive member
#                        to sections of this project.
#   :weighting_set_id  - selects which rows of `weightings` are joined in.
#
# Result columns (rows are ordered by sort_path):
#   section_id         - sections.id
#   raw_weight         - the joined weightings.value, or 1.0 when no
#                        weightings row matches (COALESCE)
#   normalised_weight  - raw_weight / SUM(raw_weight) over all rows sharing the
#                        same parent_id; 0.0 when that sum is 0
#   path               - comma-separated ids from the root section down to
#                        this section
#   level_depth        - 0 for root sections, +1 per level of nesting
#   sort_path          - like path, but every non-root id is left-padded with
#                        zeros to 10 characters (the root id is not padded);
#                        a parent's sort_path is a prefix of its children's
SECTION_WEIGHTS_CTE_QUERY = """
WITH RECURSIVE section_hierarchy AS (
    -- Anchor: Start with root sections (parent_id IS NULL)
    SELECT
        s.id,
        s.parent_id,
        COALESCE(w.value, 1.0) as section_weight,
        CAST(s.id AS CHAR(1000)) as path,
        0 as level_depth,
        CAST(s.id AS CHAR(1000)) as sort_path
    FROM sections s
    LEFT JOIN weightings w ON (
        w.section_id = s.id
        AND w.weighting_set_id = :weighting_set_id
    )
    WHERE s.parent_id IS NULL
      AND s.project_id = :project_id

    UNION ALL

    -- Recursive: Join with child sections
    SELECT
        c.id,
        c.parent_id,
        COALESCE(w.value, 1.0) as section_weight,
        CONCAT(sh.path, ',', c.id) as path,
        sh.level_depth + 1 as level_depth,
        CONCAT(sh.sort_path, ',', LPAD(c.id, 10, '0')) as sort_path
    FROM sections c
    JOIN section_hierarchy sh ON c.parent_id = sh.id
    LEFT JOIN weightings w ON (
        w.section_id = c.id
        AND w.weighting_set_id = :weighting_set_id
    )
    WHERE c.project_id = :project_id
),

-- Calculate sibling totals for normalization
sibling_totals AS (
    SELECT
        sh.id,
        sh.parent_id,
        sh.section_weight,
        sh.path,
        sh.level_depth,
        sh.sort_path,
        SUM(sh.section_weight) OVER (PARTITION BY sh.parent_id) as sibling_total
    FROM section_hierarchy sh
),

-- Calculate normalized weights (weight relative to siblings)
normalized_sections AS (
    SELECT
        st.id,
        st.parent_id,
        st.section_weight,  -- Keep the raw weight
        st.path,
        st.level_depth,
        st.sort_path,
        st.sibling_total,
        CASE
            WHEN st.sibling_total = 0 THEN 0.0
            ELSE st.section_weight / st.sibling_total
        END as normalised_weight
    FROM sibling_totals st
)

SELECT
    ns.id as section_id,
    ns.section_weight as raw_weight,  -- Add raw weight to output
    ns.normalised_weight,
    ns.path,
    ns.level_depth,
    ns.sort_path
FROM normalized_sections ns
ORDER BY ns.sort_path
"""
def calculate_absolute_weights_from_hierarchy(
    session: Session, project: "Project", weighting_set_id: Optional[int] = None
) -> List[Dict[str, Any]]:
    """
    Return one record per section of ``project`` with its computed weights.

    The recursive CTE supplies the raw and sibling-normalised weight of every
    section; the absolute weight is then derived in Python by multiplying each
    section's normalised weight with the absolute weight of its parent.

    When ``weighting_set_id`` is None and the project has a default weighting
    set, that set's id is used for the query instead.

    Each returned dict has the keys ``section_id``, ``raw_weight``,
    ``normalised_weight``, ``path``, ``level_depth``, ``sort_path`` and
    ``absolute_weight``. Weights are ``Decimal`` values. The list is ordered
    by ``sort_path``.
    """
    # Fall back to the project's default weighting set if none was requested
    effective_set_id = weighting_set_id
    if effective_set_id is None:
        default_set = project.default_weighting_set
        if default_set:
            effective_set_id = default_set.id

    statement = text(SECTION_WEIGHTS_CTE_QUERY)
    bind_params = {"project_id": project.id, "weighting_set_id": effective_set_id}
    rows = session.execute(statement, bind_params)

    # Materialise the rows as dicts; sorting on sort_path puts every parent
    # ahead of its children because a parent's sort_path prefixes the child's.
    records = sorted(
        (
            {
                "section_id": row.section_id,
                "raw_weight": Decimal(str(row.raw_weight)),
                "normalised_weight": Decimal(str(row.normalised_weight)),
                "path": row.path,
                "level_depth": row.level_depth,
                "sort_path": row.sort_path,
            }
            for row in rows
        ),
        key=lambda record: record["sort_path"],
    )

    records_by_id = {record["section_id"]: record for record in records}

    # absolute_weight = normalised_weight * parent's absolute_weight;
    # parents are already resolved by the time their children are visited.
    for record in records:
        lineage = [int(part) for part in record["path"].split(",")]

        if len(lineage) > 1:
            # The element just before the section's own id is its parent
            parent = records_by_id[lineage[-2]]
            record["absolute_weight"] = (
                record["normalised_weight"] * parent["absolute_weight"]
            )
        else:
            # A single-element path denotes a root section
            record["absolute_weight"] = Decimal("1.0")

    return records
def calculate_question_weights(
    session: Session,
    project: "Project",
    section_weights: List[Dict[str, Any]],
    weighting_set_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    Derive normalised and absolute weights for every question of ``project``.

    ``section_weights`` is a list of dicts carrying at least ``section_id``
    and ``absolute_weight``. A question's weight is normalised against the
    other questions in the same section, then scaled by that section's
    absolute weight (1.0 if the section is not present in ``section_weights``).

    When ``weighting_set_id`` is None and the project has a default weighting
    set, that set's id is used for the query instead.

    Returns a list of dicts with the keys ``question_id``, ``section_id``,
    ``normalised_weight``, ``absolute_weight`` and ``raw_weight``.
    """
    # Fall back to the project's default weighting set if none was requested
    effective_set_id = weighting_set_id
    if effective_set_id is None:
        default_set = project.default_weighting_set
        if default_set:
            effective_set_id = default_set.id

    # section id -> absolute weight of that section
    abs_weight_by_section = {
        entry["section_id"]: entry["absolute_weight"] for entry in section_weights
    }

    # Every question of the project; questions without an explicit weighting
    # in the chosen set default to 1.0
    statement = text("""
        SELECT
            qi.id as question_id,
            qi.section_id,
            COALESCE(w.value, 1.0) as question_weight
        FROM question_instances qi
        LEFT JOIN weightings w ON (
            w.question_instance_id = qi.id
            AND w.weighting_set_id = :weighting_set_id
        )
        WHERE qi.project_id = :project_id
        ORDER BY qi.section_id, qi.id
    """)

    rows = session.execute(
        statement, {"project_id": project.id, "weighting_set_id": effective_set_id}
    )

    # Bucket the questions per section so they can be normalised together
    grouped: Dict[int, List[Dict[str, Any]]] = {}
    for row in rows:
        owner_id = row.section_id
        grouped.setdefault(owner_id, []).append(
            {
                "question_id": row.question_id,
                "section_id": owner_id,
                "question_weight": Decimal(str(row.question_weight)),
            }
        )

    computed = []
    for owner_id, members in grouped.items():
        section_total = sum(member["question_weight"] for member in members)
        owner_absolute = abs_weight_by_section.get(owner_id, Decimal("1.0"))

        for member in members:
            raw = member["question_weight"]
            # Guard against dividing by zero when all weights in a section are 0
            share = Decimal("0.0") if section_total == 0 else raw / section_total

            computed.append(
                {
                    "question_id": member["question_id"],
                    "section_id": owner_id,
                    "normalised_weight": share,
                    "absolute_weight": share * owner_absolute,
                    "raw_weight": raw,
                }
            )

    return computed
def save_total_weights_cte(
    session: Session, project: "Project", weighting_set_id: Optional[int] = None
) -> None:
    """
    Calculate and save total weights using CTE approach.

    This is the main entry point that replaces the visitor-based approach.

    Steps:
      0. Delete the existing ``TotalWeighting`` rows for this project and
         ``weighting_set_id``.
      1. Compute section weights via the recursive CTE.
      2. Compute question weights from the section weights.
      3. Insert one ``TotalWeighting`` row per section and per question.

    Section rows are written with ``question_instance_id`` 0 and question rows
    with ``section_id`` 0. The session is flushed but not committed; committing
    is left to the caller.

    NOTE(review): rows are deleted and stored under the ``weighting_set_id``
    exactly as passed (possibly None), while the calculation helpers fall back
    to the project's default weighting set when it is None - confirm this is
    the intended keying.
    """
    log.info(
        "Calculating total weights using CTE for project %s, weighting_set %s",
        project.id,
        weighting_set_id,
    )

    # Step 0: Clear existing total weights for this project and weighting set
    session.query(TotalWeighting).filter_by(
        project_id=project.id, weighting_set_id=weighting_set_id
    ).delete(synchronize_session=False)
    session.flush()

    # Step 1: Calculate section weights using CTE
    section_weights = calculate_absolute_weights_from_hierarchy(
        session, project, weighting_set_id
    )

    # Step 2: Calculate question weights
    question_weights = calculate_question_weights(
        session, project, section_weights, weighting_set_id
    )

    # Step 3: Prepare data for bulk insert (sections first, then questions)
    project_id = project.id
    total_weightings_data = [
        {
            "project_id": project_id,
            "weighting_set_id": weighting_set_id,
            "section_id": section_data["section_id"],
            "question_instance_id": 0,  # Default for sections
            "weight": section_data["raw_weight"],
            "normalised_weight": section_data["normalised_weight"],
            "absolute_weight": section_data["absolute_weight"],
        }
        for section_data in section_weights
    ]
    total_weightings_data.extend(
        {
            "project_id": project_id,
            "weighting_set_id": weighting_set_id,
            "section_id": 0,  # Default for questions
            "question_instance_id": question_data["question_id"],
            "weight": question_data["raw_weight"],
            "normalised_weight": question_data["normalised_weight"],
            "absolute_weight": question_data["absolute_weight"],
        }
        for question_data in question_weights
    )

    # Step 4: Bulk insert the results
    if total_weightings_data:
        # Pass the rows as execute() parameters (executemany form) rather than
        # insert().values(rows): the latter renders a single multi-VALUES
        # statement with one bind parameter per cell, which scales poorly and
        # can exceed driver/server limits on large questionnaires.
        session.execute(insert(TotalWeighting), total_weightings_data)
        session.flush()

    log.info(
        "Successfully saved %s total weight records (%s sections, %s questions)",
        len(total_weightings_data),
        len(section_weights),
        len(question_weights),
    )