Coverage for postrfp/shared/fetch/secq.py: 100%

50 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-22 21:34 +0000

1""" 

2Functions which fetch Section objects from the database 

3""" 

4 

5import logging 

6from typing import Optional 

7 

8from sqlalchemy import literal, func, case, cast, INTEGER, null 

9from sqlalchemy.orm import Session, Query 

10from sqlalchemy.orm.session import object_session 

11 

12from postrfp.model import ( 

13 User, 

14 Project, 

15 Section, 

16 SectionPermission, 

17 QuestionInstance, 

18 QuestionDefinition, 

19) 

20from postrfp.model.questionnaire.weightings import TotalWeighting 

21from postrfp.model.questionnaire.b36 import visible_relatives_regex 

22 

# Module-level logger; sec_total_weighting uses it to warn when totals are rebuilt.
log = logging.getLogger(__name__)

24 

25 

def section(session: Session, section_id: int) -> Section:
    """
    Load the Section whose primary key is ``section_id``.

    Parameters
    ----------
    session : Session
        Session used to run the lookup.
    section_id : int
        Primary key of the Section to load.

    Returns
    -------
    Section
        The single matching Section.

    Raises
    ------
    NoResultFound
        If no Section exists for ``section_id`` (raised by ``Query.one``).
    """
    matching = session.query(Section).filter_by(id=section_id)
    return matching.one()

41 

42 

def section_of_project(project: Project, section_id: int) -> Section:
    """
    Return section ``section_id``, but only if it belongs to ``project``.

    No permission check happens here: the caller is expected to have
    already established that the user may access ``project``.

    Raises
    ------
    NoResultFound
        If ``project`` has no section with that id.
    """
    candidates = project.sections.filter_by(id=section_id)
    return candidates.one()

53 

54 

def section_by_id(session: Session, section_id: int) -> Section:
    """
    Fetch a Section by primary key.

    Equivalent to ``section(session, section_id)``; this name is kept for
    existing callers and delegates so the two lookups cannot drift apart.

    Raises
    ------
    NoResultFound
        If no Section exists for ``section_id``.
    """
    return section(session, section_id)

57 

58 

def sections(project: Project, user: User) -> Query:
    """
    Return a query over ``project``'s sections, ordered by ``b36_number``.

    When ``user.is_restricted`` is true, only sections joined to a
    ``SectionPermission`` row for ``user`` are included.
    """
    query = project.sections
    if user.is_restricted:
        query = query.join(SectionPermission).filter(SectionPermission.user == user)
    return query.order_by(Section.b36_number)

72 

73 

def visible_subsections_query(parent: Section, user: User) -> Query:
    """
    Query the direct child sections of ``parent``, ordered by ``b36_number``.

    For a restricted user, only children joined to a ``SectionPermission``
    row for ``user`` are included. ``parent`` must be attached to a
    session; this is asserted.
    """
    db = object_session(parent)
    assert db is not None
    children = db.query(Section).filter(Section.parent_id == parent.id)
    if user.is_restricted:
        children = children.join(SectionPermission).filter(
            SectionPermission.user == user
        )
    return children.order_by(Section.b36_number)

85 

86 

def get_subsections_recursive(session: Session, section_id: int) -> Query:
    """
    Query section ``section_id`` together with all of its descendant sections.

    Builds a recursive CTE named "children_for". It starts from the given
    section and repeatedly adds any Section whose ``parent_id`` is the id
    of a row already in the CTE. Each result row has these columns:

    - ``id``, ``parent_id``: the section's own columns.
    - ``recursive_depth``: 0 for the starting section, then +1 per level
      below it.
    - ``parent_lvl_1``: NULL for the starting section. For each descendant,
      it is the id of the depth-1 section (a direct child of the starting
      section) on its branch; for depth-1 rows this is their own id.

    Descendants are found purely through ``parent_id`` links; no project
    filter is applied. The Query is returned unexecuted.
    """
    # Anchor member: just the starting section.
    beginning_getter = (
        session.query(
            Section.id,
            Section.parent_id,
            literal(0).label("recursive_depth"),
            # Typed NULL, presumably so the anchor column's type lines up with
            # the integer ids produced by the recursive member in UNION ALL.
            cast(null(), INTEGER).label("parent_lvl_1"),
        )
        .filter(Section.id == section_id)
        .cte(name="children_for", recursive=True)
    )
    # Recursive member: children of rows already in the CTE.
    with_recursive = beginning_getter.union_all(
        session.query(
            Section.id,
            Section.parent_id,
            (beginning_getter.c.recursive_depth + 1).label("recursive_depth"),
            # Children of the depth-0 row record their own id; deeper rows
            # inherit the value carried by their parent row.
            case(
                (beginning_getter.c.recursive_depth == 0, Section.id),
                else_=beginning_getter.c.parent_lvl_1,
            ).label("parent_lvl_1"),
        ).filter(Section.parent_id == beginning_getter.c.id)
    )
    return session.query(with_recursive)

113 

114 

def sec_total_weighting(section: Section, weightset_id: Optional[int] = None) -> float:
    """
    Look up the total (absolute) weighting of ``section`` in a weight set.

    Parameters
    ----------
    section : Section
        Must be attached to a session (asserted).
    weightset_id : int or None
        Weighting set to read. ``None`` matches rows whose
        ``weighting_set_id`` IS NULL, because SQLAlchemy renders
        ``== None`` as ``IS NULL``.

    Returns
    -------
    The ``TotalWeighting.absolute_weight`` value for the section.

    If no row is found, a warning is logged. The project's total weights
    for this weight set are then deleted and re-saved via
    ``Project.delete_total_weights`` / ``save_total_weights``, and the
    lookup is retried once. NOTE(review): if the row is still missing after
    regeneration, ``None`` is returned despite the ``float`` annotation.
    """
    session = object_session(section)
    assert session is not None
    weight_query = (
        session.query(TotalWeighting.absolute_weight)
        .filter(TotalWeighting.section_id == section.id)
        .filter(TotalWeighting.weighting_set_id == weightset_id)
    )
    total = weight_query.scalar()
    if total is None:
        # Totals are missing or stale: rebuild them for the whole project/
        # weight set, then retry the lookup once.
        # Lazy %-args: formatting is deferred to the logging framework.
        log.warning(
            "Replacing totals of weight set #%s for %s", weightset_id, section
        )
        project = section.project
        project.delete_total_weights(weighting_set_id=weightset_id)
        project.save_total_weights(weighting_set_id=weightset_id)
        total = weight_query.scalar()
    return total

131 

132 

def visible_nodes(
    session: Session,
    section: Section,
    with_questions: bool = True,
    with_ancestors: bool = False,
) -> Query:
    """
    Build a query of the tree nodes visible when viewing ``section``.

    Section rows come from ``section``'s project and are selected by
    matching ``b36_number`` against a REGEXP pattern:

    - ``with_ancestors=False``: ``section`` itself, plus any section whose
      ``b36_number`` extends it by at most two characters.
    - ``with_ancestors=True``: the pattern returned by
      ``visible_relatives_regex(section.b36_number)``.

    When ``with_questions`` is true, the question instances whose
    ``section_id`` equals ``section.id`` are UNIONed in.

    Every row has the columns id, title, description, type ("section" or
    "question"), parent_id, b36_number, position and depth (length of
    ``b36_number`` divided by 2). The result is ordered by ``b36_number``.

    NOTE(review): ``b36_number`` is interpolated into the pattern unescaped.
    It is presumably alphanumeric; confirm.
    """
    pattern = (
        visible_relatives_regex(section.b36_number)
        if with_ancestors
        else f"^{section.b36_number}.{{0,2}}$"
    )

    section_rows = session.query(
        Section.id,
        Section.title,
        Section.description,
        literal("section").label("type"),
        Section.parent_id,
        Section.b36_number.label("b36_number"),
        Section.position,
        (func.length(Section.b36_number) / 2).label("depth"),
    )
    section_rows = section_rows.filter(Section.project_id == section.project_id)
    section_rows = section_rows.filter(Section.b36_number.op("REGEXP")(pattern))

    if not with_questions:
        return section_rows.order_by("b36_number")

    # Column order and labels must mirror section_rows for the UNION.
    question_rows = (
        session.query(
            QuestionInstance.id,
            QuestionDefinition.title,
            literal("").label("description"),
            literal("question").label("type"),
            QuestionInstance.section_id.label("parent_id"),
            QuestionInstance.b36_number.label("b36_number"),
            QuestionInstance.position,
            (func.length(QuestionInstance.b36_number) / 2).label("depth"),
        )
        .join(QuestionDefinition)
        .filter(QuestionInstance.project_id == section.project_id)
        .filter(QuestionInstance.section_id == section.id)
    )
    return section_rows.union(question_rows).order_by("b36_number")