Coverage for amqtt/contrib/shadows/models.py: 94%

69 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

from collections.abc import Sequence
from dataclasses import asdict
import logging
import time
from typing import Any, Optional
import uuid

from sqlalchemy import JSON, CheckConstraint, Connection, Integer, String, UniqueConstraint, desc, event, func, select
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, Mapper, Session, make_transient, mapped_column

from amqtt.contrib.shadows.states import StateDocument

13 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

15 

16 

class ShadowUpdateError(Exception):
    """Signal that an in-place update of a persisted Shadow was attempted.

    Args:
        message: Human-readable description; defaults to a generic
            "update not allowed" message.
    """

    def __init__(self, message: str = "updating an existing Shadow is not allowed") -> None:
        super().__init__(message)

20 

21 

class ShadowBase(DeclarativeBase):
    """Declarative base whose metadata collects the shadow tables."""

24 

25 

async def sync_shadow_base(connection: AsyncConnection) -> None:
    """Create tables and table schemas.

    Args:
        connection: Async connection on which the synchronous
            ``ShadowBase.metadata.create_all`` is executed via ``run_sync``.
    """
    await connection.run_sync(ShadowBase.metadata.create_all)

29 

30 

def default_state_document() -> dict[str, Any]:
    """Create a default (empty) state document, factory for model field.

    Returns:
        The dict form (``dataclasses.asdict``) of a ``StateDocument``
        constructed with no arguments.
    """
    return asdict(StateDocument())

34 

35 

class Shadow(ShadowBase):
    """One stored version of a device shadow, keyed by device id, name and version.

    The state is persisted as JSON in the ``state`` column and exposed to
    callers as a ``StateDocument`` through the ``state`` property.
    """

    __tablename__ = "shadows_shadow"

    # Primary key; when left unset, a UUID4 string is generated at insert time.
    id: Mapped[str | None] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))

    device_id: Mapped[str] = mapped_column(String(128), nullable=False)
    name: Mapped[str] = mapped_column(String(128), nullable=False)
    # No column default: the module's "before_insert" listener assigns the value.
    version: Mapped[int] = mapped_column(Integer, nullable=False)

    # Raw JSON payload stored in the "state" column; use the `state` property
    # for typed access.
    # NOTE(review): the default is an empty dict, not `default_state_document()`
    # - confirm which one is intended.
    _state: Mapped[dict[str, Any]] = mapped_column("state", JSON, nullable=False, default=dict)

    # Creation time as integer seconds since the epoch (from `time.time()`).
    created_at: Mapped[int] = mapped_column(Integer, default=lambda: int(time.time()), nullable=False)

    __table_args__ = (
        # NOTE(review): the constraint name mentions "quantity" but it guards
        # `version`; renaming would change the schema, so it is left as is.
        CheckConstraint("version > 0", name="check_quantity_positive"),
        UniqueConstraint("device_id", "name", "version", name="uq_device_id_name_version"),
    )

    @property
    def state(self) -> StateDocument:
        """Return the stored state as a ``StateDocument``.

        An empty or missing payload yields a default ``StateDocument()``.
        """
        if not self._state:
            return StateDocument()
        return StateDocument.from_dict(self._state)

    @state.setter
    def state(self, value: StateDocument) -> None:
        """Store ``value`` as its dict form (``dataclasses.asdict``)."""
        self._state = asdict(value)

    @classmethod
    async def latest_version(cls, session: AsyncSession, device_id: str, name: str) -> Optional["Shadow"]:
        """Get the latest version of the shadow associated with the device and name.

        Returns:
            The row with the highest ``version`` for the pair, or ``None``
            when no row matches.
        """
        stmt = (
            select(cls)
            .where(
                cls.device_id == device_id,
                cls.name == name,
            )
            .order_by(desc(cls.version))
            .limit(1)
        )
        result = await session.execute(stmt)
        return result.scalar_one_or_none()

    @classmethod
    async def all(cls, session: AsyncSession, device_id: str, name: str) -> Sequence["Shadow"]:
        """Return a list of all shadows associated with the device and name.

        Rows are ordered by ``version``, highest first.
        """
        stmt = (
            select(cls)
            .where(
                cls.device_id == device_id,
                cls.name == name,
            )
            .order_by(desc(cls.version))
        )
        result = await session.execute(stmt)
        return result.scalars().all()

86 

87 

@event.listens_for(Shadow, "before_insert")
def assign_incremental_version(_: Mapper[Any], connection: Connection, target: "Shadow") -> None:
    """Assign the next version number to a shadow that is about to be inserted.

    Looks up the highest stored ``version`` for the target's
    ``(device_id, name)`` pair and sets ``target.version`` to that value plus
    one, or to 1 when no row exists yet. Any version already set on the
    target is overwritten.

    Args:
        _: Mapper of the target (unused).
        connection: Connection used for the flush; SQLAlchemy passes a
            ``Connection`` (not a ``Session``) to mapper-level
            ``before_insert`` listeners.
        target: The ``Shadow`` instance being inserted.
    """
    stmt = (
        select(func.max(Shadow.version))
        .where(
            Shadow.device_id == target.device_id,
            Shadow.name == target.name,
        )
    )
    current_max = connection.execute(stmt).scalar_one_or_none()
    # MAX() yields NULL when there are no matching rows; treat that as 0.
    target.version = (current_max or 0) + 1

100 

101 

@event.listens_for(Shadow, "before_update")
def prevent_update(_mapper: Mapper[Any], _session: Connection, _instance: "Shadow") -> None:
    """Prevent existing shadow from being updated.

    All arguments are ignored. The second argument keeps its original name,
    but SQLAlchemy passes a ``Connection`` to mapper-level ``before_update``
    listeners.

    Raises:
        ShadowUpdateError: Always.
    """
    raise ShadowUpdateError

106 

107 

@event.listens_for(Session, "before_flush")
def convert_update_to_insert(session: Session, _flush_context: object, _instances: object | None) -> None:
    """Force a shadow to insert a new version, instead of updating an existing.

    Every modified ``Shadow`` in the session's dirty set is detached, made
    transient, stripped of its primary key and re-added, so the flush emits
    an INSERT for a new row rather than an UPDATE of the stored one.

    The listener is registered on the ``Session`` class itself, so it runs
    for every session; objects that are not ``Shadow`` instances are skipped.

    Args:
        session: The session about to flush.
        _flush_context: Unused.
        _instances: Unused.
    """
    # Make a copy of the dirty set so we can safely mutate the session
    dirty = list(session.dirty)

    for obj in dirty:
        # Scope the conversion to Shadow rows only (cheap check first).
        if not isinstance(obj, Shadow):
            continue

        if not session.is_modified(obj, include_collections=False):
            continue  # skip unchanged

        # Clone logic: convert update into insert
        session.expunge(obj)  # remove from session
        make_transient(obj)  # remove identity and history
        # Clear the primary key with None, not "": an empty string is a real
        # value that would be inserted as-is (colliding on the next converted
        # update), whereas None lets the column default generate a new UUID.
        obj.id = None
        obj.version += 1  # bump version or modify fields

        session.add(obj)  # re-add as new object

129 

130 

# Unused example of a "before_insert" listener, kept as a string so it is
# never executed or registered.
_listener_example = '''#
# @event.listens_for(Shadow, "before_insert")
# def convert_state_document_to_json(_1: Mapper[Any], _2: Session, target: "Shadow") -> None:
#     """Listen for insertion and convert state document to json."""
#     if not isinstance(target.state, StateDocument):
#         msg = "'state' field needs to be a StateDocument"
#         raise TypeError(msg)
#
#     target.state = target.state.to_dict()
'''