Coverage for postrfp/shared/fsm_entity.py: 95%
87 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-22 21:34 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-22 21:34 +0000
1from typing import Optional, Dict, Type, Set
3from sqlalchemy import ForeignKey, and_, ForeignKeyConstraint
4from sqlalchemy.orm import Mapped, mapped_column, relationship, foreign, object_session
5from sqlalchemy.ext.declarative import declared_attr
7from postrfp.model.fsm import Workflow, Status
8from postrfp.shared.types import PermString
class FSMEntity:
    """Mixin to add FSM properties to entity models and declare abstract methods to implement.

    Adds ``workflow_id``/``current_status_id`` columns and the matching
    ``workflow``/``current_status`` relationships to the mapped subclass.
    Subclasses that inherit *directly* from this mixin and declare a
    ``__tablename__`` are recorded in a class-level registry. Subclasses are
    expected to override ``get_context_schema`` and ``get_context_data``.
    """

    @declared_attr
    def __table_args__(cls):
        """Define table constraints for FSM entities.

        The composite foreign key maps ``(current_status_id, workflow_id)``
        onto ``statuses(id, workflow_id)``, so the stored status/workflow
        pair must correspond to a row of the ``statuses`` table.
        """
        return (
            ForeignKeyConstraint(
                ["current_status_id", "workflow_id"],
                ["statuses.id", "statuses.workflow_id"],
                name=f"fk_{cls.__tablename__}_status_workflow",
            ),
        )

    # Class-level registry of FSM entity classes keyed by class name. It is
    # stored on FSMEntity itself (see __init_subclass__), so it is shared by
    # every subclass.
    _fsm_entity_registry: Dict[str, Type["FSMEntity"]] = {}

    def __init_subclass__(cls, **kwargs):
        """Automatically register FSM entity subclasses."""
        super().__init_subclass__(**kwargs)
        # Only register classes that list FSMEntity as a *direct* base and
        # carry a SQLAlchemy table name; indirect descendants are skipped.
        if FSMEntity in cls.__bases__ and hasattr(cls, "__tablename__"):
            FSMEntity._fsm_entity_registry[cls.__name__] = cls

    @classmethod
    def get_registered_entities(cls) -> Dict[str, Type["FSMEntity"]]:
        """Return a shallow copy of the registry (name -> entity class)."""
        return cls._fsm_entity_registry.copy()

    @classmethod
    def get_entity_by_name(cls, name: str) -> Optional[Type["FSMEntity"]]:
        """Return the registered FSM entity class called ``name``, or None."""
        return cls._fsm_entity_registry.get(name)

    @classmethod
    def get_entity_names(cls) -> Set[str]:
        """Return the names of all registered FSM entity classes."""
        return set(cls._fsm_entity_registry.keys())

    @declared_attr
    def current_status_id(cls) -> Mapped[Optional[int]]:
        """Nullable FK to ``statuses.id`` holding the entity's current status."""
        return mapped_column(ForeignKey("statuses.id"), nullable=True)

    @declared_attr
    def workflow_id(cls) -> Mapped[Optional[int]]:
        """Nullable FK to ``workflows.id`` holding the entity's workflow."""
        return mapped_column(ForeignKey("workflows.id"), nullable=True)

    @declared_attr
    def current_status(cls) -> Mapped[Optional[Status]]:
        """Current Status, loaded eagerly with a JOIN."""
        return relationship(
            "Status",
            # Explicit foreign_keys: the table also has the composite FK that
            # includes workflow_id, so the join path must be disambiguated.
            foreign_keys=[cls.current_status_id],  # type: ignore
            lazy="joined",  # type: ignore
        )

    @declared_attr
    def workflow(cls) -> Mapped[Optional[Workflow]]:
        """Workflow the entity follows (lazy-loaded)."""
        return relationship("Workflow", foreign_keys=[cls.workflow_id])  # type: ignore

    def clear_workflow(self):
        """Clear the workflow and current status of the entity."""
        # Reset both the relationships and the raw FK columns so the
        # in-memory object is consistent before the next flush.
        self.workflow = None
        self.workflow_id = None
        self.current_status = None
        self.current_status_id = None

    def actions_for_current_status(self) -> set[PermString]:
        """Return the actions permitted at the current status.

        Returns an empty set when the entity has no current status.
        """
        if self.current_status:
            return self.current_status.status_actions
        return set()

    @declared_attr
    def status_by_name(cls):
        """View-only list of the named statuses in the entity's workflow.

        Joins on ``Status.workflow_id == self.workflow_id`` and excludes
        statuses whose ``name`` is NULL. Used by the ``status_name`` setter.
        """
        return relationship(
            "Status",
            primaryjoin=lambda: and_(
                foreign(cls.workflow_id) == Status.workflow_id,
                Status.name.is_not(None),
            ),
            viewonly=True,
            uselist=True,
        )

    @property
    def status_name(self) -> Optional[str]:
        """Name of the current status, or None when no status is set."""
        if self.current_status is None:
            return None
        return self.current_status.name

    @status_name.setter
    def status_name(self, status_name: Optional[str]):
        """Set the current status by its name within the entity's workflow.

        A falsy value (``None`` or ``""``) clears the current status.

        Raises:
            ValueError: if the entity has no workflow, or no status with
                that name exists among ``status_by_name``.
        """
        # Local import keeps the module's top-level imports unchanged.
        from sqlalchemy import inspect as sa_inspect

        if not status_name:
            self.current_status = None
            return

        if not self.workflow_id:
            raise ValueError(
                "Cannot set status by name: entity has no workflow assigned"
            )

        # Force a reload of status_by_name so it reflects the current
        # workflow_id rather than a previously cached collection.
        # Session.expire() only accepts instances that are persistent in the
        # session; a pending (added but not yet flushed) entity would make it
        # raise InvalidRequestError, so the refresh is skipped in that case.
        session = object_session(self)
        if session is not None and sa_inspect(self).persistent:
            session.expire(self, ["status_by_name"])

        matching_status = next(
            (status for status in self.status_by_name if status.name == status_name),
            None,
        )
        if matching_status is None:
            raise ValueError(
                f"Status '{status_name}' not found in the current workflow"
            )

        self.current_status = matching_status

    # --- Methods subclasses must implement ---------------------------------

    @classmethod
    def get_context_schema(cls) -> dict:
        """Return the JSON schema for the context data."""
        raise NotImplementedError

    def get_context_data(self) -> dict:
        """Return the context data for the entity instance."""
        raise NotImplementedError
def require_fsm_methods(cls):
    """
    Class decorator that ensures subclasses of FSMEntity
    implement all required abstract methods.

    Args:
        cls: The class being decorated; FSMEntity must be one of its direct
            bases.

    Returns:
        ``cls`` itself, unchanged.

    Raises:
        TypeError: if ``cls`` does not inherit directly from FSMEntity, or if
            any required method is missing or still resolves to FSMEntity's
            placeholder implementation.
    """
    required_methods = [
        "get_context_schema",
        "get_context_data",
    ]

    if FSMEntity not in cls.__bases__:
        raise TypeError(
            "require_fsm_methods class decorator is only for classes using the FSMEntity as a mixin"
        )

    def _defining_class(name):
        # First class in the MRO whose own namespace defines `name` -- the
        # same class attribute lookup would resolve it from.
        return next((klass for klass in cls.__mro__ if name in vars(klass)), None)

    # Compare the defining class by identity rather than by __qualname__
    # prefix: a prefix check also matches unrelated classes whose names merely
    # start with "FSMEntity" (e.g. "FSMEntityBase").
    missing = []
    for method_name in required_methods:
        owner = _defining_class(method_name)
        if (
            getattr(cls, method_name, None) is None
            or owner is None
            or owner is FSMEntity
        ):
            missing.append(method_name)

    if missing:
        raise TypeError(
            f"Class {cls.__name__} must implement these abstract FSM methods: {', '.join(missing)}"
        )

    return cls