Coverage for postrfp/shared/fsm_entity.py: 95%

87 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-22 21:34 +0000

1from typing import Optional, Dict, Type, Set 

2 

3from sqlalchemy import ForeignKey, and_, ForeignKeyConstraint 

4from sqlalchemy.orm import Mapped, mapped_column, relationship, foreign, object_session 

5from sqlalchemy.ext.declarative import declared_attr 

6 

7from postrfp.model.fsm import Workflow, Status 

8from postrfp.shared.types import PermString 

9 

10 

class FSMEntity:
    """Mixin that gives entity models FSM (workflow/status) columns and relationships.

    It also keeps a registry of its direct, table-mapped subclasses and
    declares the hooks (``get_context_schema`` / ``get_context_data``) that
    subclasses are expected to override.
    """

    @declared_attr
    def __table_args__(cls):
        """Return the table-level constraints contributed by this mixin.

        A single composite foreign key pairs (current_status_id, workflow_id)
        with (statuses.id, statuses.workflow_id); its name is derived from
        the mapped class's ``__tablename__``.
        """
        status_workflow_fk = ForeignKeyConstraint(
            ["current_status_id", "workflow_id"],
            ["statuses.id", "statuses.workflow_id"],
            name=f"fk_{cls.__tablename__}_status_workflow",
        )
        return (status_workflow_fk,)

    # Registry shared by all subclasses: class name -> FSM entity class
    _fsm_entity_registry: Dict[str, Type["FSMEntity"]] = {}

    def __init_subclass__(cls, **kwargs):
        """Add each direct subclass that declares ``__tablename__`` to the registry."""
        super().__init_subclass__(**kwargs)
        # Indirect descendants are skipped: FSMEntity must be an immediate base
        if FSMEntity not in cls.__bases__:
            return
        # Classes without a table name are not registered
        if not hasattr(cls, "__tablename__"):
            return
        FSMEntity._fsm_entity_registry[cls.__name__] = cls

    @classmethod
    def get_registered_entities(cls) -> Dict[str, Type["FSMEntity"]]:
        """Return a shallow copy of the name -> class registry."""
        return dict(cls._fsm_entity_registry)

    @classmethod
    def get_entity_by_name(cls, name: str) -> Optional[Type["FSMEntity"]]:
        """Look up a registered FSM entity class by class name; None if unknown."""
        return cls._fsm_entity_registry.get(name)

    @classmethod
    def get_entity_names(cls) -> Set[str]:
        """Return the class names of every registered FSM entity as a new set."""
        return set(cls._fsm_entity_registry)

    @declared_attr
    def current_status_id(cls) -> Mapped[Optional[int]]:
        """Nullable column holding a foreign key to ``statuses.id``."""
        status_fk = ForeignKey("statuses.id")
        return mapped_column(status_fk, nullable=True)

    @declared_attr
    def workflow_id(cls) -> Mapped[Optional[int]]:
        """Nullable column holding a foreign key to ``workflows.id``."""
        workflow_fk = ForeignKey("workflows.id")
        return mapped_column(workflow_fk, nullable=True)

    @declared_attr
    def current_status(cls) -> Mapped[Optional[Status]]:
        """Joined-load relationship to the ``Status`` referenced by ``current_status_id``."""
        return relationship(
            "Status",
            lazy="joined",  # type: ignore
            foreign_keys=[cls.current_status_id],  # type: ignore
        )

    @declared_attr
    def workflow(cls) -> Mapped[Optional[Workflow]]:
        """Relationship to the ``Workflow`` referenced by ``workflow_id``."""
        workflow_fk_columns = [cls.workflow_id]
        return relationship("Workflow", foreign_keys=workflow_fk_columns)  # type: ignore

    def clear_workflow(self):
        """Detach the entity from its workflow and status.

        Both the relationship attributes and their foreign key columns are
        reset to None.
        """
        self.workflow = None
        self.workflow_id = None
        self.current_status = None
        self.current_status_id = None

    def actions_for_current_status(self) -> set[PermString]:
        """Return the current status's ``status_actions``, or an empty set without a status."""
        status = self.current_status
        if not status:
            # No current status: nothing is permitted
            return set()
        return status.status_actions

    @declared_attr
    def status_by_name(cls):
        """View-only list relationship to the named statuses of the entity's workflow.

        Used by the ``status_name`` setter to resolve a status name to a
        ``Status`` row.
        """

        def _named_statuses_in_workflow():
            # Evaluated lazily by SQLAlchemy, as the original lambda was
            return and_(
                foreign(cls.workflow_id) == Status.workflow_id,
                # SQL-level NULL test: must stay `!=` so SQLAlchemy builds the clause
                Status.name != None,  # noqa
            )

        return relationship(
            "Status",
            primaryjoin=_named_statuses_in_workflow,
            viewonly=True,
            uselist=True,
        )

    @property
    def status_name(self) -> Optional[str]:
        """Name of the current status, or None when no status is set."""
        status = self.current_status
        return None if status is None else status.name

    @status_name.setter
    def status_name(self, status_name: str):
        """Set ``current_status`` to the status with the given name.

        A falsy name clears the current status. Otherwise the name is
        searched for in ``status_by_name``.

        Raises:
            ValueError: If the entity has no ``workflow_id``, or no status
                in ``status_by_name`` carries the requested name.
        """
        # Falsy name: clear the status and skip the lookup entirely
        if not status_name:
            self.current_status = None
            return

        if not self.workflow_id:
            raise ValueError(
                "Cannot set status by name: entity has no workflow assigned"
            )

        # Expire status_by_name on session-attached instances so the next
        # access reloads it for the current workflow rather than reusing
        # previously loaded statuses
        owning_session = object_session(self)
        if owning_session:
            owning_session.expire(self, ["status_by_name"])

        # First status in the workflow whose name matches, if any
        target = next(
            (
                candidate
                for candidate in self.status_by_name
                if candidate.name == status_name
            ),
            None,
        )
        if not target:
            raise ValueError(
                f"Status '{status_name}' not found in the current workflow"
            )

        self.current_status = target

    # FSMEntity methods to be implemented by subclasses

    @classmethod
    def get_context_schema(cls) -> dict:
        """Return the JSON schema for the context data (subclasses must override)."""
        raise NotImplementedError

    def get_context_data(self) -> dict:
        """Return the context data for this entity instance (subclasses must override)."""
        raise NotImplementedError

146 

147 

148def require_fsm_methods(cls): 

149 """ 

150 Class decorator that ensures subclasses of FSMEntity 

151 implement all required abstract methods. 

152 """ 

153 required_methods = [ 

154 "get_context_schema", 

155 "get_context_data", 

156 ] 

157 

158 if FSMEntity not in cls.__bases__: 

159 raise TypeError( 

160 "require_fsm_methods class decorator is only for classes using the FSMEntity as a mixin" 

161 ) 

162 

163 missing = [] 

164 for method_name in required_methods: 

165 method = getattr(cls, method_name, None) 

166 # Check if the method exists and is not the original implementation from FSMEntity 

167 if method is None or method.__qualname__.startswith("FSMEntity"): 

168 missing.append(method_name) 

169 

170 if missing: 

171 raise TypeError( 

172 f"Class {cls.__name__} must implement these abstract FSM methods: {', '.join(missing)}" 

173 ) 

174 

175 return cls