Coverage for postrfp/model/meta.py: 83% (88 statements)
Generated by coverage.py v7.12.0 on 2025-12-03 01:35 +0000
1from unicodedata import normalize
2from typing import Sequence, Optional
3import mimetypes
4import sqids
5import pathlib
6import re
7from sqlalchemy import MetaData
8from sqlalchemy import Integer, Unicode, JSON
9from sqlalchemy.orm import (
10 object_session,
11 validates,
12 DeclarativeBase,
13 Mapped,
14 mapped_column,
15)
def human_friendly_bytes(size):
    """
    Format a byte count as a short, rounded-down "N KB" / "N MB" string.

    * ``None`` or 0            -> "0 KB"
    * 1 .. 1023 bytes          -> "1 KB" (a non-empty file shows as at least 1 KB)
    * below 1024 * 1024 bytes  -> whole KB, e.g. 2048 -> "2 KB"
    * otherwise                -> whole MB, e.g. 5 * 1024 * 1024 -> "5 MB"

    Units are binary (1 KB = 1024 bytes) and values are truncated, not rounded.
    Integer floor division keeps very large sizes exact instead of routing
    them through float division.

    NOTE(review): negative sizes fall into the "1 KB" branch, as they always
    have; presumably they never occur - confirm.
    """
    if not size:
        return "0 KB"
    if size < 1024:
        return "1 KB"
    kib = int(size) // 1024
    if kib < 1024:
        return f"{kib} KB"
    return f"{kib // 1024} MB"
class Visitor:  # pragma: no cover
    """
    Base visitor class. Subclasses implement only the hooks they need.

    ``hello_section`` and ``goodbye_section`` are no-ops here.
    ``visit_question`` is deliberately left unimplemented (see its docstring).
    """

    def hello_section(self, sec):
        """Hook for section *sec* (presumably on entry); no-op by default."""
        pass

    def goodbye_section(self, sec):
        """Hook for section *sec* (presumably on exit); no-op by default."""
        pass

    def visit_question(self, question):
        """
        Visit a single question. Always raises ``NotImplementedError`` here.

        Whether a visitor provides this method is used to decide whether
        questions need to be loaded at all. Raising ``NotImplementedError``
        (the exception, not the ``NotImplemented`` constant) makes it explicit
        that this base version is never meant to be called: visitors that
        want questions must override it.
        """
        raise NotImplementedError()
# Naming convention handed to MetaData(naming_convention=...) so SQLAlchemy
# generates predictable names for indexes (ix), unique constraints (uq),
# check constraints (ck), foreign keys (fk) and primary keys (pk).
# %(column_0_N_name)s expands to all of the constraint's column names joined
# with "_".
naming_dict = dict(
    ix="ix__%(table_name)s__%(column_0_N_name)s",
    uq="uq__%(table_name)s__%(column_0_N_name)s",
    ck="ck__%(table_name)s__%(constraint_name)s",
    fk="fk__%(table_name)s__%(column_0_N_name)s__%(referred_table_name)s",
    pk="pk__%(table_name)s",
)
class Base(DeclarativeBase):
    """
    Declarative base for the ORM models.

    Provides a MetaData using ``naming_dict`` as its naming convention, an
    integer ``id`` primary key, and small helpers for serialisation, repr and
    session lookup.
    """

    metadata = MetaData(naming_convention=naming_dict)

    # Names included by as_dict(). Only "id" by default; subclasses override
    # this to expose more fields.
    public_attrs: Sequence = ["id"]

    id: Mapped[int] = mapped_column(Integer, primary_key=True)

    def as_dict(self, *args, **kwargs):
        """
        Return ``{name: value}`` for every name listed in ``public_attrs``.

        Attributes that don't exist map to the string "Not Provided".
        Extra positional/keyword arguments are accepted and ignored
        (presumably so overrides can take options - confirm with callers).
        """
        result = {}
        for attr_name in self.public_attrs:
            result[attr_name] = getattr(self, attr_name, "Not Provided")
        return result

    def __repr__(self):
        return f"<{type(self).__name__} {self.id}>"

    @property
    def _instance_session(self):
        """The Session this instance is attached to, or None if detached."""
        return object_session(self)
class AttachmentMixin:
    """
    Mixin adding file-attachment columns: size in bytes, a sanitised
    filename and a MIME type.
    """

    # Stored in the DB column "size"; the Python attribute is size_bytes so
    # that ``size`` is free for the human-readable property below.
    size_bytes: Mapped[int] = mapped_column("size", Integer, default=0, nullable=False)
    filename: Mapped[str] = mapped_column(Unicode(255), nullable=False)
    mimetype: Mapped[str] = mapped_column(Unicode(100), nullable=False)

    @validates("filename")
    def _make_safe_filename(self, _attr_name, filename):
        """
        Sanitise a filename whenever ``filename`` is assigned.

        Uploaded files can use trick filenames like '../../../rc.local'
        to try to hack an operating system. Filenames aren't used directly
        in this app, so this shouldn't be a danger, but we clean up anyway:

        * characters NFKD can fold to ASCII are folded (e.g. "é" -> "e");
          any other non-ASCII characters are dropped;
        * any directory part is removed, keeping only the final component;
        * whitespace, slashes, backslashes and ASCII control characters
          (including NUL) are replaced with "_";
        * a bare ".." is replaced with "__".
        """
        ascii_filename = (
            normalize("NFKD", filename).encode("ascii", "ignore").decode("ascii")
        )
        ascii_filename = pathlib.Path(ascii_filename).name
        safe_name = re.sub(r"[\s/\\\x00-\x1f\x7f]", "_", ascii_filename)
        # pathlib does not collapse parent references: Path("..").name and
        # Path("foo/..").name are both "..", so it must be neutralised here.
        if safe_name == "..":
            safe_name = "__"
        return safe_name

    def guess_set_mimetype(self, filename):
        """
        Set ``mimetype`` from the extension of *filename* and return it.

        ``mimetypes.guess_type`` returns None for an unrecognised extension.
        In that case None is assigned and returned.
        NOTE(review): ``mimetype`` is non-nullable, so the caller must set a
        value before flushing in that case.
        """
        mtype, _enc = mimetypes.guess_type(filename)
        self.mimetype = mtype
        return mtype

    @property
    def size(self):
        """``size_bytes`` formatted by human_friendly_bytes, e.g. "2 KB"."""
        return human_friendly_bytes(self.size_bytes)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__} - filename: {self.filename} size: {self.size}"
        )
class HTTPHeadersMixin:
    """
    Mixin adding an optional JSON ``http_headers`` column.

    A value is either None or a list of at most 5 objects, each with string
    "header" and "value" entries. This is checked by a Python validator when
    the attribute is assigned; the database itself only stores JSON.
    """

    http_headers: Mapped[Optional[list[dict[str, str]]]] = mapped_column(
        JSON(),
        nullable=True,
        comment="Optional HTTP headers for the request",
    )

    @validates("http_headers")
    def validate_headers(self, key, value):
        """
        Check a new ``http_headers`` value and return it unchanged.

        :raises ValueError: if *value* is not None and is not a list of at
            most 5 dicts, each with string "header" and "value" entries.
        """
        # None is allowed (nullable column) and needs no further checks.
        if value is None:
            return value

        if not isinstance(value, list):
            raise ValueError(
                "DB Validation: http_headers must be a list of JSON objects"
            )
        if len(value) > 5:
            raise ValueError(
                "DB Validation: http_headers cannot have more than 5 entries"
            )

        for entry in value:
            if not isinstance(entry, dict):
                raise ValueError(
                    "DB Validation: http_headers must be a list of JSON objects"
                )
            has_both_keys = "header" in entry and "value" in entry
            if not has_both_keys:
                raise ValueError(
                    "DB Validation: http_headers must contain 'header' and 'value' keys"
                )
            both_strings = isinstance(entry["header"], str) and isinstance(
                entry["value"], str
            )
            if not both_strings:
                raise ValueError(
                    "DB Validation: http_headers keys and values must be strings"
                )
        return value
class SqidsMixin:
    """
    Mixin to add UID encoding/decoding capabilities using Sqids.

    Subclasses must define ``_sqids_alphabet``; this is checked when the
    subclass is created. They may override ``_sqids_min_length`` (default 5).
    ``uid`` encodes the instance's ``id``, and ``decode_uid`` reverses it.
    """

    _sqids_alphabet: str
    _sqids_min_length: int = 5
    # Lazily-built codec, cached per class by _get_sqids().
    _sqids_instance: Optional[sqids.Sqids] = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # The bare annotation above creates no attribute, so hasattr() is
        # only true once this class or one of its bases assigns a value.
        if not hasattr(cls, "_sqids_alphabet"):
            raise TypeError(f"Class {cls.__name__} must define _sqids_alphabet")

    @classmethod
    def _get_sqids(cls) -> sqids.Sqids:
        """
        Return this class's Sqids codec, building and caching it on first use.

        :raises ValueError: if the class has no ``_sqids_alphabet``.
        """
        # Read the cache from this class's own namespace only. A plain
        # ``cls._sqids_instance`` lookup would find a parent's cached codec
        # through inheritance. That codec was built with the parent's
        # alphabet/min_length and would be silently reused here.
        codec = cls.__dict__.get("_sqids_instance")
        if codec is None:
            if not hasattr(cls, "_sqids_alphabet"):
                raise ValueError(f"{cls.__name__} must define _sqids_alphabet")
            codec = sqids.Sqids(
                alphabet=cls._sqids_alphabet,
                min_length=cls._sqids_min_length,
            )
            cls._sqids_instance = codec
        return codec

    @property
    def uid(self) -> str:
        """The Sqids encoding of ``self.id``, or "" while ``id`` is unset."""
        self_id = getattr(self, "id", None)
        if self_id is None:
            return ""
        return self._get_sqids().encode([self_id])

    @classmethod
    def decode_uid(cls, value: str) -> int:
        """
        Return the id encoded in *value*.

        :raises ValueError: if *value* does not decode to exactly one number,
            or is not the canonical UID for that number (the string ``uid``
            would produce).
        """
        codec = cls._get_sqids()
        decoded_parts = codec.decode(value)
        if len(decoded_parts) != 1:
            raise ValueError(f"Invalid {cls.__name__} UID")
        # Different strings can decode to the same numbers. Accept only the
        # string encode() produces, so every id has exactly one valid UID.
        try:
            canonical = codec.encode(decoded_parts)
        except ValueError:
            # sqids' encode() rejects numbers outside its supported range.
            canonical = None
        if canonical != value:
            raise ValueError(f"Invalid {cls.__name__} UID")
        return decoded_parts[0]