Coverage for amqtt/contrib/shadows/models.py: 94%
69 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
from collections.abc import Sequence
from dataclasses import asdict
import logging
import time
from typing import Any, Optional
import uuid

from sqlalchemy import JSON, CheckConstraint, Integer, String, UniqueConstraint, desc, event, func, select
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, Mapper, Session, make_transient, mapped_column

from amqtt.contrib.shadows.states import StateDocument
# Module logger, named after this module.
logger = logging.getLogger(name=__name__)
class ShadowUpdateError(Exception):
    """Signal an attempt to modify an already-persisted Shadow row in place."""

    def __init__(self, message: str = "updating an existing Shadow is not allowed") -> None:
        """Build the error, defaulting to a generic "updates not allowed" message.

        Args:
            message: Human-readable explanation stored as the exception's sole argument.
        """
        super().__init__(message)
class ShadowBase(DeclarativeBase):
    """Declarative base whose metadata collects the shadow tables."""
async def sync_shadow_base(connection: AsyncConnection) -> None:
    """Create the tables and table schemas registered on ``ShadowBase``.

    ``MetaData.create_all`` is a synchronous callable, so it is executed via
    ``AsyncConnection.run_sync``.
    """
    create_all = ShadowBase.metadata.create_all
    await connection.run_sync(create_all)
def default_state_document() -> dict[str, Any]:
    """Return a freshly constructed, empty ``StateDocument`` as a plain dict.

    Meant to be used as a default factory for a model field.
    """
    empty_document = StateDocument()
    return asdict(empty_document)
class Shadow(ShadowBase):
    """One stored version of a named shadow state document for a device.

    Each row has a random UUID string id; ``(device_id, name, version)`` is
    unique and ``version`` must be positive. The state document is persisted
    in the JSON ``state`` column and exposed through the ``state`` property.
    """

    __tablename__ = "shadows_shadow"

    id: Mapped[str | None] = mapped_column(
        String(36), primary_key=True, default=lambda: str(uuid.uuid4())
    )

    device_id: Mapped[str] = mapped_column(String(128), nullable=False)
    name: Mapped[str] = mapped_column(String(128), nullable=False)
    # Filled in by the module's ``before_insert`` listener (max existing version + 1).
    version: Mapped[int] = mapped_column(Integer, nullable=False)

    # Raw JSON payload stored in the "state" column; use the ``state`` property instead.
    _state: Mapped[dict[str, Any]] = mapped_column("state", JSON, nullable=False, default=dict)

    # Unix timestamp (whole seconds) taken when the row is first inserted.
    created_at: Mapped[int] = mapped_column(Integer, default=lambda: int(time.time()), nullable=False)

    __table_args__ = (
        # NOTE(review): the name says "quantity" but the check guards ``version``;
        # renaming it would change the emitted DDL, so it is left as-is.
        CheckConstraint("version > 0", name="check_quantity_positive"),
        UniqueConstraint("device_id", "name", "version", name="uq_device_id_name_version"),
    )

    @property
    def state(self) -> StateDocument:
        """Decode the stored payload; an empty/missing payload yields a blank document."""
        payload = self._state
        return StateDocument.from_dict(payload) if payload else StateDocument()

    @state.setter
    def state(self, value: StateDocument) -> None:
        """Serialize *value* with ``dataclasses.asdict`` into the JSON column."""
        self._state = asdict(value)

    @classmethod
    async def latest_version(cls, session: AsyncSession, device_id: str, name: str) -> Optional["Shadow"]:
        """Fetch the highest-numbered version for ``(device_id, name)``, or ``None`` if absent."""
        query = (
            select(cls)
            .where(cls.device_id == device_id, cls.name == name)
            .order_by(cls.version.desc())
            .limit(1)
        )
        outcome = await session.execute(query)
        return outcome.scalar_one_or_none()

    @classmethod
    async def all(cls, session: AsyncSession, device_id: str, name: str) -> Sequence["Shadow"]:
        """List every stored version for ``(device_id, name)``, newest version first."""
        query = (
            select(cls)
            .where(cls.device_id == device_id, cls.name == name)
            .order_by(cls.version.desc())
        )
        outcome = await session.execute(query)
        return outcome.scalars().all()
@event.listens_for(Shadow, "before_insert")
def assign_incremental_version(_: Mapper[Any], connection: Connection, target: "Shadow") -> None:
    """Give ``target`` the next version number for its ``(device_id, name)`` pair.

    Runs right before the INSERT for ``target`` is emitted. It queries the
    highest existing ``version`` for the same device and shadow name, then sets
    ``target.version`` to that value plus one (``1`` when no row exists yet).
    Any version already set on ``target`` is overwritten.

    Args:
        _: The mapper for ``Shadow`` (unused).
        connection: The ``Connection`` the INSERT will be executed on; SQLAlchemy
            passes a Connection (not a Session) to ``before_insert`` listeners.
        target: The ``Shadow`` instance about to be inserted.
    """
    # NOTE(review): two concurrent inserts can read the same maximum; the
    # uq_device_id_name_version constraint would then reject the second one.
    stmt = select(func.max(Shadow.version)).where(
        Shadow.device_id == target.device_id,
        Shadow.name == target.name,
    )
    # MAX() over zero rows yields NULL -> None, hence the ``or 0`` fallback.
    latest = connection.execute(stmt).scalar_one_or_none()
    target.version = (latest or 0) + 1
@event.listens_for(Shadow, "before_update")
def prevent_update(_mapper: Mapper[Any], _session: Session, _instance: "Shadow") -> None:
    """Reject any UPDATE of an existing Shadow row.

    Shadows are versioned by inserting new rows, so an update reaching the
    flush is always refused.

    Raises:
        ShadowUpdateError: unconditionally.
    """
    # NOTE(review): SQLAlchemy passes a Connection as the second argument to
    # before_update listeners; the ``Session`` annotation is inaccurate.
    raise ShadowUpdateError()
@event.listens_for(Session, "before_flush")
def convert_update_to_insert(session: Session, _flush_context: object, _instances: object | None) -> None:
    """Force a modified Shadow to be inserted as a new version instead of updated.

    Registered on the ``Session`` class, so it runs before every flush of every
    session. Only ``Shadow`` instances with net attribute changes are affected.
    Each one is expunged, made transient, given a cleared primary key and re-added,
    so the flush emits an INSERT of a new row rather than an UPDATE of the old one.

    Args:
        session: The session being flushed.
        _flush_context: Unused flush context.
        _instances: Unused (the legacy ``instances`` argument of ``before_flush``).
    """
    # Snapshot the dirty set first: expunging and re-adding below mutates it.
    for obj in list(session.dirty):
        # Cheap type filter first; other models are left to update normally.
        if not isinstance(obj, Shadow):
            continue
        if not session.is_modified(obj, include_collections=False):
            continue  # touched but no net change: nothing to version

        session.expunge(obj)  # detach from this session
        make_transient(obj)  # drop identity key and persistent state
        # Clear the primary key with None (not ""): only a None value lets the
        # column's UUID default fire. A literal "" would be inserted as the id
        # and every later conversion would collide on that same key.
        obj.id = None
        # Provisional bump; the before_insert listener recomputes the version
        # from the database before the INSERT is emitted.
        obj.version += 1
        # NOTE(review): ``created_at`` is carried over from the original row, so
        # the new version keeps the old timestamp. Confirm this is intended.
        session.add(obj)  # re-add as a new pending object -> INSERT on flush
# Inert reference snippet: the code below is only a string and is never executed.
# It shows a ``before_insert`` listener that converted a StateDocument ``state`` to a
# dict; presumably superseded by the ``Shadow.state`` setter, which stores
# ``asdict(value)`` in the JSON column. TODO(review): delete if no longer needed.
_listener_example = '''#
# @event.listens_for(Shadow, "before_insert")
# def convert_state_document_to_json(_1: Mapper[Any], _2: Session, target: "Shadow") -> None:
# """Listen for insertion and convert state document to json."""
# if not isinstance(target.state, StateDocument):
# msg = "'state' field needs to be a StateDocument"
# raise TypeError(msg)
#
# target.state = target.state.to_dict()
'''