Coverage for postrfp/shared/fetch/flattening.py: 95%

169 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-22 21:34 +0000

1from typing import Optional, Any, Generator 

2import logging 

3 

4from postrfp.model import QElement, Issue 

5 

6logger = logging.getLogger(__name__) 

7 

8 

class ElementGrid:
    """
    A 2D grid representation of question elements for efficient spatial lookups.

    Each element is written into every cell covered by its ``row``/``col``
    origin and its ``rowspan``/``colspan``, so merged cells are expanded into
    a normalized rectangular grid. Coordinates exposed by the public methods
    are 1-based; ``self.grid`` itself is indexed 0-based.
    """

    def __init__(self, elements: list[QElement]):
        """
        Build the expanded grid and the per-row index of column-1 labels.

        Args:
            elements: Elements exposing ``row``, ``col``, ``rowspan``,
                ``colspan`` and ``el_type``. May be empty.
        """
        # Initialize all attributes first so that an empty element list still
        # produces a fully usable (empty) grid.
        self.height = 0
        self.width = 0
        self.grid: list[list[Optional[QElement]]] = []
        # 1-based row number -> "LB" elements anchored in column 1 that cover it.
        self.row_headers_by_row: dict[int, list[QElement]] = {}
        # (row, col) of a queried element -> (row_headers, col_headers).
        self.header_cache: dict[
            tuple[int, int], tuple[list[QElement], list[QElement]]
        ] = {}

        if not elements:
            return

        # The grid must reach the bottom-right corner of the furthest element.
        self.height = max(0, max(e.row + e.rowspan - 1 for e in elements))
        self.width = max(0, max(e.col + e.colspan - 1 for e in elements))
        self.grid = [[None] * self.width for _ in range(self.height)]

        for el in elements:
            # Expand the element over every cell it spans. When two elements
            # overlap, the one that comes later in ``elements`` wins the cell.
            for r in range(el.row - 1, el.row + el.rowspan - 1):
                for c in range(el.col - 1, el.col + el.colspan - 1):
                    # Negative indices are rejected too: a row/col below 1
                    # would otherwise wrap around and overwrite the far edge.
                    if 0 <= r < self.height and 0 <= c < self.width:
                        self.grid[r][c] = el

            # Index labels anchored in the first column by each row they cover.
            if el.el_type == "LB" and el.col == 1:
                for r in range(el.row, el.row + el.rowspan):
                    self.row_headers_by_row.setdefault(r, []).append(el)

    def get_element(self, row: int, col: int) -> Optional[QElement]:
        """Gets the element at a specific grid coordinate (1-based).

        Returns ``None`` for an empty cell or a coordinate outside the grid.
        """
        if 1 <= row <= self.height and 1 <= col <= self.width:
            return self.grid[row - 1][col - 1]
        return None

    def find_headers_for(
        self, element: QElement
    ) -> tuple[list[QElement], list[QElement]]:
        """
        Finds the row and column headers for a given element.

        For simple grids: Returns only direct row headers (same row)
        For hierarchical tables: Includes parent headers from column 1

        A table is treated as hierarchical when at least three rows carry a
        column-1 label. Results are cached per ``(element.row, element.col)``.

        Returns:
            ``(row_headers, col_headers)``. Row headers list column-1 parents
            top to bottom, followed by labels to the left of the element on
            its own row. Column headers list labels above the element, top to
            bottom.
        """
        cache_key = (element.row, element.col)
        cached = self.header_cache.get(cache_key)
        if cached is not None:
            return cached

        row_headers: list[QElement] = []
        col_headers: list[QElement] = []

        # For empty grid, return empty headers
        if not self.grid:
            return row_headers, col_headers

        # 1. Direct row headers: labels on the same row, to the left.
        row_header_ids = set()
        for c in range(1, element.col):
            header = self.get_element(element.row, c)
            if (
                header is not None
                and header.el_type == "LB"
                and header.id not in row_header_ids
            ):
                row_headers.append(header)
                row_header_ids.add(header.id)

        # 2. Only add parent headers if this is a hierarchical table.
        if len(self.row_headers_by_row) >= 3:
            # Collect parents in top-to-bottom order and prepend them as one
            # group, so the outermost parent stays first in the path.
            parent_headers: list[QElement] = []
            for r in range(1, element.row):
                header = self.get_element(r, 1)
                if (
                    header is not None
                    and header.el_type == "LB"
                    and header.id not in row_header_ids
                ):
                    parent_headers.append(header)
                    row_header_ids.add(header.id)
            row_headers = parent_headers + row_headers

        # 3. Column headers: labels above the element.
        col_header_ids = set()
        for r in range(1, element.row):
            # Direct headers (same column). Because merged cells are expanded,
            # this also finds a header that spans across this column.
            header = self.get_element(r, element.col)
            if (
                header is not None
                and header.el_type == "LB"
                and header.id not in col_header_ids
            ):
                col_headers.append(header)
                col_header_ids.add(header.id)

            # Spanning headers anchored further left whose span reaches this
            # column (e.g. when the cell above was overwritten by an overlap).
            for c in range(1, element.col):
                header = self.get_element(r, c)
                if (
                    header is not None
                    and header.el_type == "LB"
                    and header.id not in col_header_ids
                    # Measure the span from the header's own origin column:
                    # the probe column ``c`` may lie inside the merged area.
                    and header.col + header.colspan - 1 >= element.col
                ):
                    col_headers.append(header)
                    col_header_ids.add(header.id)

        result = (row_headers, col_headers)
        self.header_cache[cache_key] = result
        return result

130 

131 

def embeddings_for_issue(issue: Issue) -> Generator[dict, None, None]:
    """
    Yield the flattened label/answer records for every question of an issue.

    The issue's answers are indexed by ``element_id`` once, then each question
    instance in ``issue.project.questions`` is flattened through
    ``build_label_answer_mapping`` and its records are yielded in order.
    """
    answers_by_element = {}
    for given in issue.answers:
        answers_by_element[given.element_id] = given.answer

    for instance in issue.project.questions:
        definition = instance.question_def
        yield from build_label_answer_mapping(
            definition.elements, answers_by_element, definition.title, instance.id
        )

140 

141 

def build_label_answer_mapping(
    elements: list[QElement],
    answers_by_element_id: dict[int, Any],
    question_label: str,
    question_id: int,
    include_unanswered: bool = False,
) -> list[dict]:
    """
    Creates a flattened list of {label, value} pairs for a question's elements.

    Elements are ordered by ``(row, col)``, the layout is classified with
    ``_detect_form_structure`` and the matching processor builds the records.
    Returns an empty list when there are no elements or none is answerable.
    """
    if not elements:
        return []

    # A stable positional order keeps the output deterministic.
    ordered = sorted(elements, key=lambda item: (item.row, item.col))

    answerable = [item for item in ordered if item.is_answerable]
    if not answerable:
        return []

    layout = _detect_form_structure(ordered, answerable)
    if layout == "single_column":
        process = _process_single_column_form
    elif layout == "two_column":
        process = _process_two_column_form
    else:
        # Anything else is handled as a complex table.
        process = _process_complex_table

    return process(
        ordered,
        answerable,
        answers_by_element_id,
        question_label,
        question_id,
        include_unanswered,
    )

193 

194 

def _detect_form_structure(
    elements: list[QElement], answerable_elements: list[QElement]
) -> str:
    """
    Classify the layout of a question's elements.

    Returns "single_column" when there are no elements or nothing extends past
    column 1, "two_column" when every label sits in column 1 and every
    answerable element sits in column 2, and "complex_table" otherwise.
    """
    if not elements:
        return "single_column"

    rightmost = max(item.col + item.colspan - 1 for item in elements)
    if rightmost == 1:
        return "single_column"

    answerable_ids = {item.id for item in answerable_elements}
    for item in elements:
        can_answer = item.id in answerable_ids

        if item.el_type == "LB":
            # A label outside the first column means real table headers.
            if item.col != 1:
                return "complex_table"
        elif not can_answer and item.el_type not in ("PD", "FD"):
            # Only "PD"/"FD" are tolerated as non-answerable, non-label cells.
            return "complex_table"

        if can_answer and item.col != 2:
            return "complex_table"

    return "two_column"

225 

226 

def _process_complex_table(
    elements: list[QElement],
    answerable_elements: list[QElement],
    answers: dict[int, Any],
    q_label: str,
    q_id: int,
    include_unanswered: bool,
) -> list[dict]:
    """
    Flatten a complex table by resolving each answerable cell's headers.

    An ``ElementGrid`` built from all elements supplies the row and column
    headers of each answerable element; their labels joined with " / " form
    the record label, falling back to ``q_label`` when no header has a label.
    Elements without an answer are skipped unless ``include_unanswered``.
    """
    table = ElementGrid(elements)
    records = []

    for cell in answerable_elements:
        if include_unanswered or cell.id in answers:
            left_headers, top_headers = table.find_headers_for(cell)
            path = _build_label_path(left_headers, top_headers)
            display_label = " / ".join(path) or q_label
            record = _create_output_item(
                cell, path, display_label, answers.get(cell.id), q_id, q_label
            )
            records.append(record)

    return records

256 

257 

def _build_label_path(
    row_headers: list[QElement], col_headers: list[QElement]
) -> list[str]:
    """
    Builds a label path from row and column headers.

    Row headers come first, then column headers. Labels are stripped; empty
    labels and labels already present in the path are left out.
    """
    path: list[str] = []
    used: set[str] = set()

    # Row headers precede column headers, so one pass over both suffices.
    for header in [*row_headers, *col_headers]:
        text = (header.label or "").strip()
        if not text or text in used:
            continue
        used.add(text)
        path.append(text)

    return path

282 

283 

def _create_output_item(
    element: QElement,
    label_path: list[str],
    label: str,
    value: Any,
    question_id: int,
    question_label: str,
) -> dict:
    """
    Assemble the standard output record for one element.

    The record carries the element id, its label path and label, the value
    passed through ``_normalize_value``, and a ``meta`` dict holding the
    question id and title.
    """
    item: dict = {}
    item["element_id"] = element.id
    item["label_path"] = label_path
    item["label"] = label
    item["value"] = _normalize_value(element, value)
    item["meta"] = {"question_id": question_id, "question_title": question_label}
    return item

300 

301 

def _normalize_value(element: QElement, value: Any) -> Any:
    """
    Return ``value`` in a consistent form for the element's type.

    ``None`` is passed through. For "CB" elements the string "false" maps to
    ``False`` and anything else to its truthiness. For "CR"/"CC" elements a
    dict containing "label" is reduced to that label. Every other value is
    returned unchanged.
    """
    if value is None:
        return None

    kind = element.el_type

    if kind == "CB":  # Checkbox
        return False if value == "false" else bool(value)

    if kind in ("CR", "CC") and isinstance(value, dict) and "label" in value:
        # Keep only the label of the chosen option.
        return value["label"]

    return value

315 

316 

317# --- Fast Path Implementations --- 

318 

319 

def _process_single_column_form(
    elements: list[QElement],
    answerable_elements: list[QElement],
    answers: dict[int, Any],
    q_label: str,
    q_id: int,
    include_unanswered: bool,
) -> list[dict]:
    """
    Flatten a one-column form.

    Elements are walked in the given order. Each "LB" element replaces the
    current heading, and each following answerable element is labelled with
    that heading, falling back to its own label and then to ``q_label``.
    Elements without an answer are skipped unless ``include_unanswered``.
    """
    records = []
    heading = ""
    answerable_ids = {item.id for item in answerable_elements}

    for item in elements:
        if item.el_type == "LB":
            # The most recent label applies to the elements that follow it.
            heading = (item.label or "").strip()
            continue
        if item.id not in answerable_ids:
            continue
        if not (item.id in answers or include_unanswered):
            continue

        resolved = _determine_label_with_fallback(heading, item.label, q_label)
        # The path holds the heading only, never the fallback label.
        path = [heading] if heading else []
        records.append(
            _create_output_item(
                item, path, resolved, answers.get(item.id), q_id, q_label
            )
        )

    return records

349 

350 

def _process_two_column_form(
    elements: list[QElement],
    answerable_elements: list[QElement],
    answers: dict[int, Any],
    q_label: str,
    q_id: int,
    include_unanswered: bool,
) -> list[dict]:
    """
    Flatten a two-column form by pairing answers with the label on their row.

    Each answerable element takes the stripped label of the "LB" element on
    the same row, falling back to its own label and then to ``q_label``. The
    resolved label also forms the single-entry label path. Elements without
    an answer are skipped unless ``include_unanswered``.
    """
    # Row number -> stripped label text; a later label on a row replaces an
    # earlier one.
    row_labels = {}
    for item in elements:
        if item.el_type == "LB":
            row_labels[item.row] = (item.label or "").strip()

    records = []
    for item in answerable_elements:
        if not (item.id in answers or include_unanswered):
            continue

        resolved = _determine_label_with_fallback(
            row_labels.get(item.row, ""), item.label, q_label
        )
        path = [resolved] if resolved else []
        records.append(
            _create_output_item(
                item, path, resolved, answers.get(item.id), q_id, q_label
            )
        )

    return records

381 

382 

def _determine_label_with_fallback(
    primary_label: str, element_label: Optional[str], question_label: str
) -> str:
    """
    Determines the label using fallback logic.

    Preference order: ``primary_label``, then the stripped ``element_label``,
    then ``question_label``.

    Args:
        primary_label: Preferred label; used as-is when non-empty.
        element_label: The element's own label; may be ``None`` or blank.
        question_label: Last-resort label.

    Returns:
        The first non-empty candidate in the order above.
    """
    if primary_label:
        return primary_label

    # Strip before testing: a whitespace-only label is truthy but carries no
    # text, so it must fall through to the question label.
    own_label = element_label.strip() if element_label else ""
    return own_label or question_label