Coverage for amqtt/contrib/shadows/plugin.py: 85%
99 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
1from collections import defaultdict
2from dataclasses import dataclass, field
3import json
4import re
5from typing import Any
7from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
9from amqtt.broker import BrokerContext
10from amqtt.contexts import Action
11from amqtt.contrib.shadows.messages import (
12 GetAcceptedMessage,
13 GetRejectedMessage,
14 UpdateAcceptedMessage,
15 UpdateDeltaMessage,
16 UpdateDocumentMessage,
17 UpdateIotaMessage,
18)
19from amqtt.contrib.shadows.models import Shadow, sync_shadow_base
20from amqtt.contrib.shadows.states import (
21 ShadowOperation,
22 StateDocument,
23 calculate_delta_update,
24 calculate_iota_update,
25)
26from amqtt.plugins.base import BasePlugin, BaseTopicPlugin
27from amqtt.session import ApplicationMessage, Session
# Aliases that document what the plain strings used as shadow-mapping keys represent.
DeviceID = str
ShadowName = str

# Matches "$shadow/<client_id>/<shadow_name>/<get|update>" at the start of a topic.
# There is no end anchor, so any text following the request segment is ignored.
shadow_topic_re = re.compile(
    r"^\$shadow/(?P<client_id>[a-zA-Z0-9_-]+?)/(?P<shadow_name>[a-zA-Z0-9_-]+?)/(?P<request>get|update)"
)
@dataclass
class ShadowTopic:
    """The pieces of a ``$shadow/<device>/<name>/<op>`` topic, split into fields."""

    # the device identifier segment of the topic (regex group "client_id")
    device_id: DeviceID
    # the shadow-name segment of the topic
    name: ShadowName
    # the requested operation, built from the "get"/"update" segment
    message_op: ShadowOperation
def shadow_dict() -> dict[DeviceID, dict[ShadowName, StateDocument]]:
    """Build an auto-vivifying nested mapping for the shadow cache.

    Looking up a missing key inserts and returns a fresh ``shadow_dict()``, so
    chains like ``d[device][name]`` work without creating each level by hand.
    """
    nested = defaultdict(shadow_dict)  # type: ignore[arg-type]
    return nested
47class ShadowPlugin(BasePlugin[BrokerContext]):
49 def __init__(self, context: BrokerContext) -> None:
50 super().__init__(context)
51 self._shadows: dict[DeviceID, dict[ShadowName, StateDocument]] = defaultdict(dict)
53 self._engine = create_async_engine(self.config.connection)
54 self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
56 async def on_broker_pre_start(self) -> None:
57 """Sync the schema."""
58 async with self._engine.begin() as conn:
59 await sync_shadow_base(conn)
61 @staticmethod
62 def shadow_topic_match(topic: str) -> ShadowTopic | None:
63 """Check if topic matches the shadow topic format."""
64 # pattern is "$shadow/<username>/<shadow_name>/get, update, etc
65 match = shadow_topic_re.search(topic)
66 if match:
67 groups = match.groupdict()
68 return ShadowTopic(groups["client_id"], groups["shadow_name"], ShadowOperation(groups["request"]))
69 return None
71 async def _handle_get(self, st: ShadowTopic) -> None:
72 """Send 'accepted."""
73 async with self._db_session_maker() as db_session, db_session.begin():
74 shadow = await Shadow.latest_version(db_session, st.device_id, st.name)
75 if not shadow:
76 reject_msg = GetRejectedMessage(
77 code=404,
78 message="shadow not found",
79 )
80 await self.context.broadcast_message(reject_msg.topic(st.device_id, st.name), reject_msg.to_message())
81 return
83 accept_msg = GetAcceptedMessage(
84 state=shadow.state.state,
85 metadata=shadow.state.metadata,
86 timestamp=shadow.created_at,
87 version=shadow.version
88 )
89 await self.context.broadcast_message(accept_msg.topic(st.device_id, st.name), accept_msg.to_message())
91 async def _handle_update(self, st: ShadowTopic, update: dict[str, Any]) -> None:
92 async with self._db_session_maker() as db_session, db_session.begin():
93 shadow = await Shadow.latest_version(db_session, st.device_id, st.name)
94 if not shadow: 94 ↛ 97line 94 didn't jump to line 97 because the condition on line 94 was always true
95 shadow = Shadow(device_id=st.device_id, name=st.name)
97 state_update = StateDocument.from_dict(update)
99 prev_state = shadow.state or StateDocument()
100 prev_state.version = shadow.version or 0 # only required when generating shadow messages
101 prev_state.timestamp = shadow.created_at or 0 # only required when generating shadow messages
103 next_state = prev_state + state_update
105 shadow.state = next_state
106 db_session.add(shadow)
107 await db_session.commit()
109 next_state.version = shadow.version
110 next_state.timestamp = shadow.created_at
112 accept_msg = UpdateAcceptedMessage(
113 state=next_state.state,
114 metadata=next_state.metadata,
115 timestamp=123,
116 version=1
117 )
119 await self.context.broadcast_message(accept_msg.topic(st.device_id, st.name), accept_msg.to_message())
121 delta_msg = UpdateDeltaMessage(
122 state=calculate_delta_update(next_state.state.desired, next_state.state.reported),
123 metadata=calculate_delta_update(next_state.metadata.desired, next_state.metadata.reported),
124 version=shadow.version,
125 timestamp=shadow.created_at
126 )
127 await self.context.broadcast_message(delta_msg.topic(st.device_id, st.name), delta_msg.to_message())
129 iota_msg = UpdateIotaMessage(
130 state=calculate_iota_update(next_state.state.desired, next_state.state.reported),
131 metadata=calculate_delta_update(next_state.metadata.desired, next_state.metadata.reported),
132 version=shadow.version,
133 timestamp=shadow.created_at
134 )
135 await self.context.broadcast_message(iota_msg.topic(st.device_id, st.name), iota_msg.to_message())
137 doc_msg = UpdateDocumentMessage(
138 previous=prev_state,
139 current=next_state,
140 timestamp=shadow.created_at
141 )
143 await self.context.broadcast_message(doc_msg.topic(st.device_id, st.name), doc_msg.to_message())
145 async def on_broker_message_received(self, *, client_id: str, message: ApplicationMessage) -> None:
146 """Process a message that was received from a client."""
147 topic = message.topic
148 if not topic.startswith("$shadow"): # this is less overhead than do the full regular expression match 148 ↛ 149line 148 didn't jump to line 149 because the condition on line 148 was never true
149 return
151 if not (shadow_topic := self.shadow_topic_match(topic)): 151 ↛ 152line 151 didn't jump to line 152 because the condition on line 151 was never true
152 return
154 match shadow_topic.message_op:
156 case ShadowOperation.GET:
157 await self._handle_get(shadow_topic)
158 case ShadowOperation.UPDATE: 158 ↛ exitline 158 didn't return from function 'on_broker_message_received' because the pattern on line 158 always matched
159 await self._handle_update(shadow_topic, json.loads(message.data.decode("utf-8")))
161 @dataclass
162 class Config:
163 """Configuration for shadow plugin."""
165 connection: str
166 """SQLAlchemy connection string for the asyncio version of the database connector:
168 - `mysql+aiomysql://user:password@host:port/dbname`
169 - `postgresql+asyncpg://user:password@host:port/dbname`
170 - `sqlite+aiosqlite:///dbfilename.db`
171 """
class ShadowTopicAuthPlugin(BaseTopicPlugin):
    """Topic filter that restricts each user to the shadow topics of its own device."""

    async def topic_filtering(self, *,
                              session: Session | None = None,
                              topic: str | None = None,
                              action: Action | None = None) -> bool | None:
        """Decide whether ``session`` may access ``topic``.

        Access is granted when the topic parses as a shadow topic whose device id
        equals the session's username, or when that username is listed in
        ``config.superusers``. Missing/empty topics and non-shadow topics are denied.
        ``action`` is not consulted.
        """
        current = session or Session()

        parsed = ShadowPlugin.shadow_topic_match(topic) if topic else None
        if parsed is None:
            return False

        username = current.username
        if parsed.device_id == username:
            return True
        return username in self.config.superusers

    @dataclass
    class Config:
        """Configuration for only allowing devices access to their own shadow topics."""

        superusers: list[str] = field(default_factory=list)
        """A list of one or more usernames that can write to any device topic,
        primarily for the central app sending updates to devices."""