Coverage for postrfp / shared / fetch / flattening.py: 91%

103 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 01:35 +0000

1from typing import Optional, Any, Generator 

2 

3from untabulate import ProjectionGrid as UntabulateProjectionGrid, GridElement 

4 

5from postrfp.model import QElement, Issue 

6 

7 

def _qelement_to_grid_element(element: QElement) -> GridElement:
    """
    Translate a QElement into the GridElement shape used by untabulate.

    Field mapping:
    - is_header: True only when ``el_type`` equals "LB"
    - row, col: copied through unchanged
    - rowspan, colspan: copied through, with falsy values replaced by 1
    - value: the element's ``label``, or "" when the label is falsy
    """
    # Collect the keyword arguments first so the mapping reads as a table.
    fields = {
        "is_header": element.el_type == "LB",
        "row": element.row,
        "col": element.col,
        "rowspan": element.rowspan or 1,
        "colspan": element.colspan or 1,
        "value": element.label or "",
    }
    return GridElement(**fields)

30 

31 

class ProjectionGrid:
    """
    Compatibility facade over untabulate's ProjectionGrid.

    Accepts QElement instances, converts each one to untabulate's
    GridElement format, and delegates every query to the wrapped grid.
    """

    def __init__(self, elements: list[QElement]):
        """Build the wrapped untabulate grid from a list of QElement instances."""
        converted = list(map(_qelement_to_grid_element, elements))
        self._grid = UntabulateProjectionGrid(converted)

    def get_path(self, data_row: int, data_col: int) -> list[str]:
        """
        Return the headers that apply to the data cell at the given position.

        Args:
            data_row: 1-based row index
            data_col: 1-based column index

        Returns:
            List of header strings that govern this cell, as reported by
            the wrapped untabulate grid.
        """
        headers = self._grid.get_path(data_row, data_col)
        return headers

    @property
    def row_headers(self):
        """Pass through the wrapped grid's ``row_headers`` (kept for tests)."""
        wrapped = self._grid
        return wrapped.row_headers

    @property
    def col_headers(self):
        """Pass through the wrapped grid's ``col_headers`` (kept for tests)."""
        wrapped = self._grid
        return wrapped.col_headers

70 

71 

def _process_complex_table(
    elements: list[QElement],
    answerable_elements: list[QElement],
    answers: dict[int, Any],
    q_label: str,
    q_id: int,
    include_unanswered: bool,
) -> list[dict]:
    """
    Flatten a table-shaped question into one output item per answerable cell.

    A ProjectionGrid is built from all ``elements``; each answerable cell's
    header path is looked up by its (row, col). The item label is the path
    joined with " / ", or ``q_label`` when the joined path is empty. Cells
    without an entry in ``answers`` are skipped unless ``include_unanswered``
    is true. Returns [] when ``elements`` is empty.
    """
    if not elements:
        return []

    grid = ProjectionGrid(elements)

    def _item_for(cell: QElement) -> dict:
        # Header path comes straight from the cell's coordinates.
        headers = grid.get_path(cell.row, cell.col)
        joined = " / ".join(headers)
        display = joined if joined else q_label
        return _create_output_item(
            cell, headers, display, answers.get(cell.id), q_id, q_label
        )

    return [
        _item_for(cell)
        for cell in answerable_elements
        if cell.id in answers or include_unanswered
    ]

104 

105 

def embeddings_for_issue(issue: Issue) -> Generator[dict, None, None]:
    """
    Yield the flattened label/answer items for every question of an issue.

    Answers are indexed by ``element_id`` once, then each question in
    ``issue.project.questions`` is flattened with
    ``build_label_answer_mapping`` using the question definition's elements
    and title and the question instance's id.
    """
    answers_by_element = {}
    for answer in issue.answers:
        answers_by_element[answer.element_id] = answer.answer

    for instance in issue.project.questions:
        definition = instance.question_def
        items = build_label_answer_mapping(
            definition.elements, answers_by_element, definition.title, instance.id
        )
        for item in items:
            yield item

114 

115 

def build_label_answer_mapping(
    elements: list[QElement],
    answers_by_element_id: dict[int, Any],
    question_label: str,
    question_id: int,
    include_unanswered: bool = False,
) -> list[dict]:
    """
    Flatten a question's elements into a list of label/value items.

    Elements are ordered by (row, col), the answerable ones are selected,
    and the layout reported by ``_detect_form_structure`` decides which
    processor builds the output. Returns [] when there are no elements or
    none of them is answerable.
    """
    if not elements:
        return []

    # Position order keeps the output deterministic.
    ordered = sorted(elements, key=lambda item: (item.row, item.col))

    fillable = [item for item in ordered if item.is_answerable]
    if not fillable:
        return []

    layout = _detect_form_structure(ordered, fillable)
    if layout == "single_column":
        handler = _process_single_column_form
    elif layout == "two_column":
        handler = _process_two_column_form
    else:
        # Anything else is treated as a complex table.
        handler = _process_complex_table

    return handler(
        ordered,
        fillable,
        answers_by_element_id,
        question_label,
        question_id,
        include_unanswered,
    )

167 

168 

def _detect_form_structure(
    elements: list[QElement], answerable_elements: list[QElement]
) -> str:
    """
    Classify the layout of a question's elements.

    Args:
        elements: every element of the question.
        answerable_elements: the subset of ``elements`` that can hold an answer.

    Returns:
        "single_column" when there are no elements or nothing extends past
        column 1; "two_column" when every "LB" element sits in column 1,
        every answerable element sits in column 2, and any remaining element
        has type "PD" or "FD"; otherwise "complex_table".
    """
    if not elements:
        return "single_column"

    # colspan may be unset (None/0); treat it as 1, the same way
    # _qelement_to_grid_element does, instead of raising TypeError.
    max_col = max(e.col + (e.colspan or 1) - 1 for e in elements)
    if max_col == 1:
        return "single_column"

    # Any single violation of the two-column shape means a complex table.
    answerable_ids = {e.id for e in answerable_elements}
    for e in elements:
        is_label = e.el_type == "LB"
        is_answerable = e.id in answerable_ids

        if is_label and e.col != 1:
            return "complex_table"
        if not is_label and not is_answerable and e.el_type not in ("PD", "FD"):
            return "complex_table"
        if is_answerable and e.col != 2:
            return "complex_table"

    return "two_column"

199 

200 

def _create_output_item(
    element: QElement,
    label_path: list[str],
    label: str,
    value: Any,
    question_id: int,
    question_label: str,
) -> dict:
    """
    Build the standard output record for one element.

    The record holds the element id, its label path and label, the value
    after ``_normalize_value``, and a ``meta`` dict carrying the question id
    and title.
    """
    cleaned_value = _normalize_value(element, value)
    question_meta = dict(question_id=question_id, question_title=question_label)
    return dict(
        element_id=element.id,
        label_path=label_path,
        label=label,
        value=cleaned_value,
        meta=question_meta,
    )

217 

218 

def _normalize_value(element: QElement, value: Any) -> Any:
    """
    Return a consistent representation of ``value`` for the element's type.

    Args:
        element: the element the value belongs to; only ``el_type`` is read.
        value: the raw answer value, possibly None.

    Returns:
        - None when ``value`` is None.
        - For "CB" (checkbox) elements, a bool: the string "false" in any
          letter case and with surrounding whitespace maps to False; every
          other value maps to its truthiness.
        - For "CR"/"CC" elements, ``value["label"]`` when the value is a
          dict containing a "label" key.
        - Otherwise ``value`` unchanged.
    """
    if value is None:
        return None

    kind = element.el_type

    if kind == "CB":  # Checkbox
        # A bare bool() would turn "False" / " false " into True because any
        # non-empty string is truthy, so compare the text case-insensitively.
        if isinstance(value, str) and value.strip().lower() == "false":
            return False
        return bool(value)

    if kind in ("CR", "CC") and isinstance(value, dict) and "label" in value:
        return value["label"]  # Keep only the label

    return value

232 

233 

234# --- Simple Table Implementations --- 

235 

236 

def _process_single_column_form(
    elements: list[QElement],
    answerable_elements: list[QElement],
    answers: dict[int, Any],
    q_label: str,
    q_id: int,
    include_unanswered: bool,
) -> list[dict]:
    """
    Flatten a one-column form.

    Elements are walked in the given order. Each "LB" element replaces the
    current heading (its stripped label, or "" when it has none); each
    answerable element that follows is emitted under that heading. Elements
    without an entry in ``answers`` are skipped unless ``include_unanswered``
    is true.
    """
    fillable_ids = {item.id for item in answerable_elements}  # set for O(1) lookup
    results = []
    heading = ""

    for item in elements:
        if item.el_type == "LB":
            heading = item.label.strip() if item.label else ""
            continue
        if item.id not in fillable_ids:
            continue
        if not (item.id in answers or include_unanswered):
            continue

        # Heading first, then the element's own label, then the question label.
        resolved = _determine_label_with_fallback(heading, item.label, q_label)
        path = [heading] if heading else []
        results.append(
            _create_output_item(
                item, path, resolved, answers.get(item.id), q_id, q_label
            )
        )

    return results

266 

267 

def _process_two_column_form(
    elements: list[QElement],
    answerable_elements: list[QElement],
    answers: dict[int, Any],
    q_label: str,
    q_id: int,
    include_unanswered: bool,
) -> list[dict]:
    """
    Flatten a two-column form.

    Every "LB" element contributes its stripped label for its row (a later
    label in the same row replaces an earlier one). Each answerable element
    is labelled from its row's label, falling back per
    ``_determine_label_with_fallback``. Elements without an entry in
    ``answers`` are skipped unless ``include_unanswered`` is true.
    """
    row_labels = {}
    for item in elements:
        if item.el_type == "LB":
            row_labels[item.row] = item.label.strip() if item.label else ""

    results = []
    for item in answerable_elements:
        if not (item.id in answers or include_unanswered):
            continue

        resolved = _determine_label_with_fallback(
            row_labels.get(item.row, ""), item.label, q_label
        )
        path = [resolved] if resolved else []
        results.append(
            _create_output_item(
                item, path, resolved, answers.get(item.id), q_id, q_label
            )
        )

    return results

298 

299 

def _determine_label_with_fallback(
    primary_label: str, element_label: Optional[str], question_label: str
) -> str:
    """
    Pick the label for an element using a three-step fallback.

    Args:
        primary_label: preferred label; returned as-is when non-empty.
        element_label: the element's own label, possibly None.
        question_label: last-resort label.

    Returns:
        ``primary_label`` if truthy; otherwise the stripped ``element_label``
        if it still has content after stripping; otherwise ``question_label``.
    """
    if primary_label:
        return primary_label

    # Strip before testing so a whitespace-only element label falls through
    # to the question label instead of producing an empty label.
    stripped = element_label.strip() if element_label else ""
    return stripped or question_label