Coverage for postrfp/fsm/service.py: 94%

108 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-22 21:34 +0000

1from typing import Any, Optional, Type, TypeVar 

2from sqlalchemy import select 

3from sqlalchemy.orm import Session 

4 

5from ..shared.fsm_entity import FSMEntity 

6from postrfp.shared.types import PermString 

7from postrfp.shared.serial import TransitionResult 

8from postrfp.shared.expression import evaluate_expression 

9 

10from ..model.fsm import Workflow, Status, Transition, StatusAction 

11 

12EntityT = TypeVar("EntityT", bound=FSMEntity) 

13 

14 

def get_entity_type_name(entity_class: Type[EntityT]) -> str:
    """Return the entity type name used in FSM definitions: the class's ``__name__``."""
    type_name = entity_class.__name__
    return type_name

18 

19 

def get_definitions_for_entity(
    session: Session, entity_class: Type[EntityT], organisation_id: Optional[int] = None
) -> list[Workflow]:
    """
    Fetch every Workflow whose ``entity_type`` matches the given entity class.

    When ``organisation_id`` is not None, only workflows with that
    ``organisation_id`` are returned.
    """
    # Collect the filter criteria first, then apply them in a single where().
    criteria = [Workflow.entity_type == get_entity_type_name(entity_class)]
    if organisation_id is not None:
        criteria.append(Workflow.organisation_id == organisation_id)

    stmt = select(Workflow).where(*criteria)
    workflows = session.execute(stmt).scalars().all()
    return list(workflows)

32 

33 

def get_all_states(session: Session, fsm_definition: Workflow) -> list[Status]:
    """Return every Status row whose ``workflow_id`` is the id of ``fsm_definition``."""
    stmt = select(Status).where(Status.workflow_id == fsm_definition.id)
    statuses = session.execute(stmt).scalars().all()
    return list(statuses)

38 

39 

def get_all_transitions(session: Session, fsm_definition: Workflow) -> list[Transition]:
    """Return every Transition row whose ``workflow_id`` is the id of ``fsm_definition``."""
    stmt = select(Transition).where(Transition.workflow_id == fsm_definition.id)
    transitions = session.execute(stmt).scalars().all()
    return list(transitions)

44 

45 

def get_available_transitions(session: Session, entity: EntityT) -> list[Transition]:
    """
    List the transitions that start at the entity's current status.

    Returns an empty list when the entity has no current status or no workflow.
    """
    # Both ids must be set (truthy) before a lookup makes sense.
    if not (entity.current_status_id and entity.workflow_id):
        return []

    stmt = select(Transition).where(
        Transition.workflow_id == entity.workflow_id,
        Transition.source_status_id == entity.current_status_id,
    )
    outgoing = session.execute(stmt).scalars().all()
    return list(outgoing)

56 

57 

def get_permitted_actions(session: Session, entity: EntityT) -> set[str]:
    """
    Collect the action names attached to the entity's current status.

    Returns an empty set when the entity has no current status.
    """
    status_id = entity.current_status_id
    if not status_id:
        return set()

    stmt = select(StatusAction.action).where(StatusAction.status_id == status_id)
    return set(session.execute(stmt).scalars().all())

68 

69 

def get_next_state_for_transition(
    session: Session, entity: EntityT, transition_name: str
) -> Optional[Status]:
    """
    Look up the status the entity would reach via the named transition.

    Returns None when the entity has no current status or workflow, or when no
    transition with that name leaves the entity's current status.
    """
    if not (entity.current_status_id and entity.workflow_id):
        return None

    stmt = select(Transition).where(
        Transition.workflow_id == entity.workflow_id,
        Transition.source_status_id == entity.current_status_id,
        Transition.name == transition_name,
    )
    found = session.execute(stmt).scalar_one_or_none()
    return found.target_status if found else None

88 

89 

def create_fsm_from_status_actions(
    session: Session,
    entity_type: str,
    organisation_id: str,
    status_actions: dict[str, set[PermString]],
) -> Workflow:
    """
    Create an FSM definition from a legacy status_actions dictionary.

    Useful for migrating from the old system. One Status is created per key of
    ``status_actions`` (the "__new__" key is skipped), one StatusAction per
    action listed for that status, and a linear sequence of Transitions linking
    each status to the next one in dict key order.

    The workflow's ``initial_status_code`` is set to the code of the first
    status created. If ``status_actions`` yields no statuses, the "TEMP"
    placeholder is left unchanged.

    Returns the new Workflow. It is added to ``session`` and flushed; no commit
    is issued here.
    """
    workflow = Workflow(
        title=f"{entity_type} Default Workflow",
        entity_type=entity_type,
        organisation_id=organisation_id,
        version=1,
        is_active=True,
        initial_status_code="TEMP",  # Placeholder, replaced by the first status code below
    )
    session.add(workflow)
    session.flush()  # Let the db set the ID

    ordered_status_names = [name for name in status_actions if name != "__new__"]
    ordered_states = []

    for position, status_name in enumerate(ordered_status_names):
        code = status_name.lower().replace(" ", "_")
        if position == 0:
            # The first status in dict order is the workflow's initial status.
            # (Previously guarded by an `is None` test that could never pass,
            # because the placeholder value is "TEMP".)
            workflow.initial_status_code = code
        state = Status(
            workflow_id=workflow.id,
            name=status_name,
            code=code,
        )
        session.add(state)
        session.flush()  # Populate state.id before it is referenced below
        ordered_states.append(state)

        for action_name in status_actions[status_name]:
            session.add(StatusAction(status_id=state.id, action=action_name))

    # Link each status to its successor.
    for source_state, target_state in zip(ordered_states, ordered_states[1:]):
        # transition name, e.g., "draft_to_live"
        transition_name = f"{source_state.name.lower()}_to_{target_state.name.lower()}"
        transition_name = transition_name.replace(" ", "_")

        session.add(
            Transition(
                workflow_id=workflow.id,
                name=transition_name,
                source_status_id=source_state.id,
                target_status_id=target_state.id,
            )
        )

    return workflow

153 

154 

def get_all_states_for_entity(
    session: Session, entity_class: Type[EntityT], org_id: Optional[int] = None
) -> dict[int, list[Status]]:
    """
    Map each FSM definition ID for the entity type to that definition's states.

    Convenience wrapper intended for API endpoints.
    """
    return {
        definition.id: get_all_states(session, definition)
        for definition in get_definitions_for_entity(session, entity_class, org_id)
    }

169 

170 

def get_all_transitions_for_entity(
    session: Session, entity_class: Type[EntityT], org_id: Optional[int] = None
) -> dict[int, list[Transition]]:
    """
    Map each FSM definition ID for the entity type to that definition's transitions.

    Convenience wrapper intended for API endpoints.
    """
    return {
        definition.id: get_all_transitions(session, definition)
        for definition in get_definitions_for_entity(session, entity_class, org_id)
    }

185 

186 

def migrate_from_status_actions(
    session: Session,
    entity_type: str,
    organisation_id: str,
    status_actions: dict[str, set[PermString]],
) -> Workflow:
    """
    Build an FSM definition from a legacy status actions dictionary.

    Migration utility; delegates directly to create_fsm_from_status_actions.
    """
    workflow = create_fsm_from_status_actions(
        session, entity_type, organisation_id, status_actions
    )
    return workflow

200 

201 

def evaluate_transition(
    transition: Transition, context_data: dict[str, Any]
) -> TransitionResult:
    """
    Decide whether ``transition`` is permitted for the given context data.

    This is a stub for evaluating transition conditions. In a real implementation,
    this would execute the functions associated with the transition to determine
    if the transition is permitted and possibly return a job reference for async tasks.

    Currently only ``transition.guard_policy`` is considered: when set, it is
    passed to ``evaluate_expression`` together with ``context_data``. A falsy
    result, or any exception raised while evaluating, yields a result with
    ``transition_permitted=False``. ``job_ref`` is always None.
    """
    # Function-scope import so this block does not depend on the module's import list.
    import logging

    if not transition.guard_policy:
        return TransitionResult(transition_permitted=True, message="OK", job_ref=None)

    # Debug-level logging instead of print(): the context may be large or
    # sensitive and should not be written to stdout unconditionally.
    logging.getLogger(__name__).debug(
        "Evaluating guard function: %s with context %s",
        transition.guard_policy,
        context_data,
    )

    # Execute guard function, assuming it's a python expression for now
    try:
        permitted = evaluate_expression(transition.guard_policy, context_data)
    except Exception as e:
        # Deliberately broad: any failure in the guard is reported as a refusal
        # rather than propagated to the caller.
        return TransitionResult(
            transition_permitted=False,
            message=f"Error evaluating guard function: {e}",
            job_ref=None,
        )

    if not permitted:
        return TransitionResult(
            transition_permitted=False,
            message="Guard condition not met",
            job_ref=None,
        )
    return TransitionResult(transition_permitted=True, message="OK", job_ref=None)

231 

232 

def execute_transition(entity: FSMEntity, transition_name: str) -> None:
    """
    Execute a transition on the given entity, updating its status if permitted.

    The transition is looked up by name among ``entity.workflow.transitions``.
    If it reports functions via ``has_functions()``, they are evaluated against
    ``entity.get_context_data()``. On success ``entity.status_name`` is set to
    the name of the transition's target status.

    Raises:
        ValueError: if the entity has no workflow, if the workflow has no
            transition with the given name, or if evaluation does not permit
            the transition.
    """
    workflow = entity.workflow
    if workflow is None:
        raise ValueError("No Workflow associated with this entity")

    # NOTE(review): the lookup matches on name only; it does not compare the
    # transition's source status with the entity's current status. Confirm
    # that this is intended.
    transition = next(
        (t for t in workflow.transitions if t.name == transition_name), None
    )
    if transition is None:
        # Raised outside any except block so the error carries no stray
        # StopIteration as implicit context.
        raise ValueError(
            f"No transition named {transition_name} found in Workflow {workflow.id}"
        )

    if transition.has_functions():
        context_data = entity.get_context_data()
        trans_result: TransitionResult = evaluate_transition(transition, context_data)

        if not trans_result.transition_permitted:
            raise ValueError(f"Status change not permitted: {trans_result.message}")

    entity.status_name = transition.target_status.name