Coverage for postrfp/shared/fetch/flattening.py: 95%
169 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-22 21:34 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-22 21:34 +0000
1from typing import Optional, Any, Generator
2import logging
4from postrfp.model import QElement, Issue
6logger = logging.getLogger(__name__)
class ElementGrid:
    """
    A 2D grid representation of question elements for efficient spatial lookups.

    Merged cells (rowspan/colspan greater than 1) are expanded so that every
    coordinate they cover points at the same element. The public methods take
    1-based coordinates, like the ``row`` / ``col`` attributes of the elements;
    ``self.grid`` itself is indexed 0-based.
    """

    def __init__(self, elements: list[QElement]):
        """Build the expanded grid and index the column-1 labels by row."""
        # Initialize all attributes so an empty element list still yields a
        # usable (empty) grid object.
        self.height = 0
        self.width = 0
        self.grid: list[list[Optional[QElement]]] = []
        # 1-based row number -> "LB" elements that start in column 1 and
        # cover that row.
        self.row_headers_by_row: dict[int, list[QElement]] = {}
        # (row, col) of the queried element -> (row_headers, col_headers).
        self.header_cache: dict[
            tuple[int, int], tuple[list[QElement], list[QElement]]
        ] = {}

        if not elements:
            return

        # The furthest row/column reached by any element (including its span)
        # gives the grid dimensions.
        max_row = max_col = 0
        for e in elements:
            max_row = max(max_row, e.row + e.rowspan - 1)
            max_col = max(max_col, e.col + e.colspan - 1)

        self.height = max_row
        self.width = max_col
        self.grid = [[None for _ in range(self.width)] for _ in range(self.height)]

        for el in elements:
            # Write the element into every cell it spans; when two elements
            # overlap, the one that comes later in ``elements`` wins.
            for r in range(el.row - 1, el.row + el.rowspan - 1):
                for c in range(el.col - 1, el.col + el.colspan - 1):
                    if r < self.height and c < self.width:
                        self.grid[r][c] = el

            # Index labels that start in the first column by each row they cover.
            if el.el_type == "LB" and el.col == 1:
                for r in range(el.row, el.row + el.rowspan):
                    self.row_headers_by_row.setdefault(r, []).append(el)

    def get_element(self, row: int, col: int) -> Optional[QElement]:
        """Gets the element at a specific grid coordinate (1-based).

        Returns None for coordinates outside the grid or for empty cells.
        """
        if 1 <= row <= self.height and 1 <= col <= self.width:
            return self.grid[row - 1][col - 1]
        return None

    def find_headers_for(
        self, element: QElement
    ) -> tuple[list[QElement], list[QElement]]:
        """
        Finds the row and column headers for a given element.

        Row headers are the "LB" elements to the left of the element in its
        own row. When at least three rows carry a column-1 label the table is
        treated as hierarchical, and the "LB" elements found in column 1 of
        the rows above are placed in front of them.

        Column headers are the "LB" elements above the element in its own
        column, plus any "LB" element found further left in those rows whose
        own column span reaches the element's column.

        Returns a ``(row_headers, col_headers)`` tuple. Results are cached per
        ``(row, col)``, so repeated calls return the same list objects.
        """
        cache_key = (element.row, element.col)
        if cache_key in self.header_cache:
            return self.header_cache[cache_key]

        row_headers: list[QElement] = []
        col_headers: list[QElement] = []

        # For empty grid, return empty headers (not cached).
        if not self.grid:
            return row_headers, col_headers

        # 1. Direct row headers: labels in the same row, to the left.
        #    A label spanning several columns is reported once.
        row_header_ids = set()
        for c in range(1, element.col):
            header = self.get_element(element.row, c)
            if header and header.el_type == "LB" and header.id not in row_header_ids:
                row_headers.append(header)
                row_header_ids.add(header.id)

        # 2. Only add parent headers if this is a hierarchical table.
        if len(self.row_headers_by_row) >= 3:
            for r in range(1, element.row):
                header = self.get_element(r, 1)
                if (
                    header
                    and header.el_type == "LB"
                    and header.id not in row_header_ids
                ):
                    # NOTE(review): inserting at index 0 while walking the rows
                    # top-down leaves the nearest parent first and the top-most
                    # parent last among the parents - confirm this is the
                    # intended hierarchy order.
                    row_headers.insert(0, header)
                    row_header_ids.add(header.id)

        # 3. Column headers (direct and spanning).
        col_header_ids = set()
        for r in range(1, element.row):
            # Direct headers (same column).
            header = self.get_element(r, element.col)
            if header and header.el_type == "LB" and header.id not in col_header_ids:
                col_headers.append(header)
                col_header_ids.add(header.id)

            # Spanning headers: labels found to the left whose span reaches
            # this column. The reach is measured from the header's own start
            # column; measuring from the scan column ``c`` would overstate the
            # reach of a merged label whenever ``c`` is past its first column.
            for c in range(1, element.col):
                header = self.get_element(r, c)
                if (
                    header
                    and header.el_type == "LB"
                    and header.id not in col_header_ids
                    and header.col + header.colspan - 1 >= element.col
                ):
                    col_headers.append(header)
                    col_header_ids.add(header.id)

        result = (row_headers, col_headers)
        self.header_cache[cache_key] = result
        return result
def embeddings_for_issue(issue: Issue) -> Generator[dict, None, None]:
    """Yield the flattened label/answer items for each question of the issue's project.

    The issue's answers are indexed by ``element_id`` once, then every
    question instance in ``issue.project.questions`` is flattened with
    ``build_label_answer_mapping`` and its items are yielded one by one.
    """
    answers_by_element = {}
    for given in issue.answers:
        answers_by_element[given.element_id] = given.answer

    for instance in issue.project.questions:
        definition = instance.question_def
        flattened = build_label_answer_mapping(
            definition.elements,
            answers_by_element,
            definition.title,
            instance.id,
        )
        for item in flattened:
            yield item
def build_label_answer_mapping(
    elements: list[QElement],
    answers_by_element_id: dict[int, Any],
    question_label: str,
    question_id: int,
    include_unanswered: bool = False,
) -> list[dict]:
    """
    Creates a flattened list of {label, value} pairs for a question's elements.

    The elements are ordered by (row, col), the layout is classified by
    ``_detect_form_structure`` and the matching processor builds the items.
    An empty list is returned when there are no elements or none of them is
    answerable.
    """
    if not elements:
        return []

    # Deterministic processing order: top-to-bottom, then left-to-right.
    ordered = sorted(elements, key=lambda item: (item.row, item.col))

    answerable = [item for item in ordered if item.is_answerable]
    if not answerable:
        return []

    layout = _detect_form_structure(ordered, answerable)

    # Anything that is not one of the two simple layouts is a complex table.
    fast_paths = {
        "single_column": _process_single_column_form,
        "two_column": _process_two_column_form,
    }
    processor = fast_paths.get(layout, _process_complex_table)
    return processor(
        ordered,
        answerable,
        answers_by_element_id,
        question_label,
        question_id,
        include_unanswered,
    )
def _detect_form_structure(
    elements: list[QElement], answerable_elements: list[QElement]
) -> str:
    """
    Classifies the layout of a question's elements.

    Returns "single_column" when there are no elements or nothing reaches past
    column 1, "two_column" when every label sits in column 1 and every
    answerable element in column 2, and "complex_table" otherwise.
    """
    if not elements:
        return "single_column"

    rightmost = max(item.col + item.colspan - 1 for item in elements)
    if rightmost == 1:
        return "single_column"

    answerable_ids = {item.id for item in answerable_elements}

    def breaks_two_column(item) -> bool:
        """True when this element cannot belong to a label/answer two-column form."""
        kind = item.el_type
        # Labels must live in the first column.
        if kind == "LB" and item.col != 1:
            return True
        # Non-answerable, non-label elements other than "PD"/"FD" are not
        # allowed in the simple layout.
        if kind != "LB" and item.id not in answerable_ids and kind not in ("PD", "FD"):
            return True
        # Answerable elements must live in the second column.
        return item.id in answerable_ids and item.col != 2

    if any(breaks_two_column(item) for item in elements):
        return "complex_table"
    return "two_column"
def _process_complex_table(
    elements: list[QElement],
    answerable_elements: list[QElement],
    answers: dict[int, Any],
    q_label: str,
    q_id: int,
    include_unanswered: bool,
) -> list[dict]:
    """Process a complex table structure using ElementGrid.

    Each answerable element is labelled with the " / "-joined path of its row
    and column headers; the question label is used when that path is empty.
    Elements without an answer are skipped unless ``include_unanswered`` is set.
    """
    layout = ElementGrid(elements)
    items = []

    for field in answerable_elements:
        if include_unanswered or field.id in answers:
            row_hdrs, col_hdrs = layout.find_headers_for(field)
            path = _build_label_path(row_hdrs, col_hdrs)

            joined = " / ".join(path)
            display_label = joined if joined else q_label

            items.append(
                _create_output_item(
                    field, path, display_label, answers.get(field.id), q_id, q_label
                )
            )

    return items
def _build_label_path(
    row_headers: list[QElement], col_headers: list[QElement]
) -> list[str]:
    """Builds a label path from row and column headers.

    Row headers come first, then column headers. Labels are stripped; empty
    labels and labels already present in the path are left out.
    """
    path: list[str] = []
    used: set[str] = set()

    for header in (*row_headers, *col_headers):
        text = header.label.strip() if header.label else ""
        if not text or text in used:
            continue
        used.add(text)
        path.append(text)

    return path
def _create_output_item(
    element: QElement,
    label_path: list[str],
    label: str,
    value: Any,
    question_id: int,
    question_label: str,
) -> dict:
    """Creates a standardized output item for a single element.

    The value is passed through ``_normalize_value``; the question id and
    title are stored under the "meta" key.
    """
    normalized = _normalize_value(element, value)
    meta = {"question_id": question_id, "question_title": question_label}
    item = {
        "element_id": element.id,
        "label_path": label_path,
        "label": label,
        "value": normalized,
        "meta": meta,
    }
    return item
def _normalize_value(element: QElement, value: Any) -> Any:
    """Ensures consistent value representation for different element types.

    - ``None`` is passed through unchanged.
    - "CB" (checkbox) values become a bool. The string "false" is treated as
      False regardless of letter case or surrounding whitespace; every other
      value goes through ``bool()``.
    - "CR" / "CC" values given as a dict with a "label" key are reduced to
      that label.
    - Anything else is returned as-is.
    """
    if value is None:
        return None

    if element.el_type == "CB":  # Checkbox
        # A non-empty string is always truthy, so the textual "false" has to
        # be recognised explicitly, in any spelling such as "False".
        if isinstance(value, str) and value.strip().lower() == "false":
            return False
        return bool(value)

    if element.el_type in ("CR", "CC"):
        if isinstance(value, dict) and "label" in value:
            return value["label"]  # Extract just the label

    return value
317# --- Fast Path Implementations ---
def _process_single_column_form(
    elements: list[QElement],
    answerable_elements: list[QElement],
    answers: dict[int, Any],
    q_label: str,
    q_id: int,
    include_unanswered: bool,
) -> list[dict]:
    """Process a simple one-column form.

    Walks the elements in order, remembering the most recent "LB" label and
    attaching it to each answerable element that follows. Elements without an
    answer are skipped unless ``include_unanswered`` is set.
    """
    items = []
    current_label = ""
    answerable_ids = {candidate.id for candidate in answerable_elements}

    for entry in elements:
        if entry.el_type == "LB":
            # Every label, even an empty one, replaces the remembered label.
            current_label = entry.label.strip() if entry.label else ""
            continue
        if entry.id not in answerable_ids:
            continue
        if not (entry.id in answers or include_unanswered):
            continue

        display_label = _determine_label_with_fallback(
            current_label, entry.label, q_label
        )
        # The path only carries the preceding label, never the fallback.
        path = [current_label] if current_label else []
        items.append(
            _create_output_item(
                entry, path, display_label, answers.get(entry.id), q_id, q_label
            )
        )

    return items
def _process_two_column_form(
    elements: list[QElement],
    answerable_elements: list[QElement],
    answers: dict[int, Any],
    q_label: str,
    q_id: int,
    include_unanswered: bool,
) -> list[dict]:
    """Process a simple two-column form.

    Each answerable element takes the stripped "LB" label found on its own
    row, with the usual fallback when that row has none. Elements without an
    answer are skipped unless ``include_unanswered`` is set.
    """
    # Row number -> stripped label text; a later label on the same row
    # replaces an earlier one.
    labels_by_row = {}
    for entry in elements:
        if entry.el_type == "LB":
            labels_by_row[entry.row] = entry.label.strip() if entry.label else ""

    items = []
    for field in answerable_elements:
        if not (field.id in answers or include_unanswered):
            continue

        display_label = _determine_label_with_fallback(
            labels_by_row.get(field.row, ""), field.label, q_label
        )
        path = [display_label] if display_label else []
        items.append(
            _create_output_item(
                field, path, display_label, answers.get(field.id), q_id, q_label
            )
        )

    return items
def _determine_label_with_fallback(
    primary_label: str, element_label: Optional[str], question_label: str
) -> str:
    """Determines the label using fallback logic.

    Preference order: ``primary_label``, then the stripped ``element_label``,
    then ``question_label``. An element label that is missing, empty or made
    of whitespace only is skipped, so the result is never an empty string
    produced by stripping.
    """
    if primary_label:
        return primary_label

    # Strip before testing: a whitespace-only label is truthy but carries
    # no text, so it must not win over the question label.
    stripped = element_label.strip() if element_label else ""
    if stripped:
        return stripped

    return question_label