Coverage for postrfp/model/helpers.py: 100%

24 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-22 21:34 +0000

1import random 

2import string 

3from typing import Set, TYPE_CHECKING, Iterable 

4 

5from pydantic import BaseModel 

6 

7 

8if TYPE_CHECKING: 

9 from postrfp.model import AuditEvent 

10 from postrfp.model.meta import Base 

11 

12 

def validate_section_children(
    session,
    node_type: type,
    provided_id_set: Set[int],
    found_node_ids: Set[int],
    current_child_ids: Set[int],
    delete_orphans=False,
) -> None:
    """
    Validate a set of child ids supplied for a section.

    Ensures that:
    - ``provided_id_set`` and ``found_node_ids`` contain exactly the same ids, and
    - any ids present in ``current_child_ids`` but absent from ``provided_id_set``
      (i.e. children that would be orphaned) are only allowed when
      ``delete_orphans`` is True.

    :param session: not referenced by this function.
    :param node_type: class of the child nodes; its ``__name__`` is used in
        error messages.
    :param provided_id_set: ids supplied by the caller.
    :param found_node_ids: ids that were actually found (presumably by a lookup
        scoped to the current project — verify at the call site).
    :param current_child_ids: ids of the section's current children.
    :param delete_orphans: when False, refuse to proceed if any current child
        would be left out of ``provided_id_set``.
    :raises ValueError: if the provided and found ids differ, or if children
        would be orphaned while ``delete_orphans`` is False.
    """
    # Compare the sets themselves rather than their sizes: two equal-sized
    # sets can still contain different ids.
    if found_node_ids != provided_id_set:
        alien = provided_id_set - found_node_ids
        if alien:
            m = f"Some provided {node_type.__name__} ids not found in this project: {alien}"
        else:
            # Every provided id was found, but additional ids came back too.
            unexpected = found_node_ids - provided_id_set
            m = f"Found {node_type.__name__} ids that were not provided: {unexpected}"
        raise ValueError(m)

    potential_orphans = current_child_ids - provided_id_set
    if potential_orphans and not delete_orphans:
        m = f"{node_type.__name__} ids {potential_orphans} exist in the current section but were "
        m += "not provided in the provided ID list. Will not delete because delete_orphans is false"
        raise ValueError(m)

37 

38 

def audited_patch(
    db_model: "Base",
    serial_doc: "BaseModel",
    audit_event: "AuditEvent",
    patch_keys: Iterable[str],
    prefix=None,
):
    """
    Copy selected attributes from ``serial_doc`` onto ``db_model``, auditing changes.

    For each name in ``patch_keys``, the value on ``db_model`` is compared (with
    ``!=``) against the value on ``serial_doc``. Only differing values are
    recorded via ``audit_event.add_change(name, old, new)`` and then assigned to
    ``db_model``; equal values are left untouched and not recorded.

    :param prefix: when truthy, recorded change names take the form
        ``"{prefix}.{key}"``; otherwise the bare key is used.
    """
    for key in patch_keys:
        previous_value = getattr(db_model, key)
        incoming_value = getattr(serial_doc, key)
        if not previous_value != incoming_value:
            continue
        change_label = k_label = key if not prefix else f"{prefix}.{key}"
        audit_event.add_change(change_label, previous_value, incoming_value)
        setattr(db_model, key, incoming_value)

56 

57 

def random_string(length=5) -> str:
    """
    Return a string of ``length`` random lowercase ASCII letters.

    Letters are drawn with replacement from the module-level ``random``
    generator, which is not suitable for security-sensitive values
    (the ``secrets`` module exists for that purpose).
    """
    letters = random.choices(population=string.ascii_lowercase, k=length)
    return "".join(letters)