Coverage for amqtt/contrib/shadows/plugin.py: 85%

99 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

1from collections import defaultdict 

2from dataclasses import dataclass, field 

3import json 

4import re 

5from typing import Any 

6 

7from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine 

8 

9from amqtt.broker import BrokerContext 

10from amqtt.contexts import Action 

11from amqtt.contrib.shadows.messages import ( 

12 GetAcceptedMessage, 

13 GetRejectedMessage, 

14 UpdateAcceptedMessage, 

15 UpdateDeltaMessage, 

16 UpdateDocumentMessage, 

17 UpdateIotaMessage, 

18) 

19from amqtt.contrib.shadows.models import Shadow, sync_shadow_base 

20from amqtt.contrib.shadows.states import ( 

21 ShadowOperation, 

22 StateDocument, 

23 calculate_delta_update, 

24 calculate_iota_update, 

25) 

26from amqtt.plugins.base import BasePlugin, BaseTopicPlugin 

27from amqtt.session import ApplicationMessage, Session 

28 

# Recognises shadow request topics of the form "$shadow/<client_id>/<shadow_name>/<get|update>".
# Only the start of the topic is anchored (there is no trailing "$"), so a topic that
# continues after the request segment (e.g. "$shadow/a/b/update/xyz") also matches.
shadow_topic_re = re.compile(
    r"^\$shadow/"
    r"(?P<client_id>[a-zA-Z0-9_-]+?)/"  # device / client identifier segment
    r"(?P<shadow_name>[a-zA-Z0-9_-]+?)/"  # shadow name segment
    r"(?P<request>get|update)"  # requested operation
)

# Readability aliases for the keys of the shadow cache.
DeviceID = str
ShadowName = str

33 

34 

@dataclass
class ShadowTopic:
    """The parts extracted from a ``$shadow/<client_id>/<shadow_name>/<request>`` topic."""

    device_id: DeviceID  # the <client_id> segment of the topic
    name: ShadowName  # the <shadow_name> segment of the topic
    message_op: ShadowOperation  # the <request> segment, converted to a ShadowOperation

40 

41 

def shadow_dict() -> dict[DeviceID, dict[ShadowName, StateDocument]]:
    """Create an empty two-level shadow cache: device id -> shadow name -> document.

    The outer level auto-creates a plain ``dict`` for an unseen device id. Looking up an
    unknown shadow name raises ``KeyError``. It does not fabricate a nested mapping where
    a ``StateDocument`` is expected, so the value matches the annotated return type.
    """
    return defaultdict(dict)

45 

46 

47class ShadowPlugin(BasePlugin[BrokerContext]): 

48 

49 def __init__(self, context: BrokerContext) -> None: 

50 super().__init__(context) 

51 self._shadows: dict[DeviceID, dict[ShadowName, StateDocument]] = defaultdict(dict) 

52 

53 self._engine = create_async_engine(self.config.connection) 

54 self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False) 

55 

56 async def on_broker_pre_start(self) -> None: 

57 """Sync the schema.""" 

58 async with self._engine.begin() as conn: 

59 await sync_shadow_base(conn) 

60 

61 @staticmethod 

62 def shadow_topic_match(topic: str) -> ShadowTopic | None: 

63 """Check if topic matches the shadow topic format.""" 

64 # pattern is "$shadow/<username>/<shadow_name>/get, update, etc 

65 match = shadow_topic_re.search(topic) 

66 if match: 

67 groups = match.groupdict() 

68 return ShadowTopic(groups["client_id"], groups["shadow_name"], ShadowOperation(groups["request"])) 

69 return None 

70 

71 async def _handle_get(self, st: ShadowTopic) -> None: 

72 """Send 'accepted.""" 

73 async with self._db_session_maker() as db_session, db_session.begin(): 

74 shadow = await Shadow.latest_version(db_session, st.device_id, st.name) 

75 if not shadow: 

76 reject_msg = GetRejectedMessage( 

77 code=404, 

78 message="shadow not found", 

79 ) 

80 await self.context.broadcast_message(reject_msg.topic(st.device_id, st.name), reject_msg.to_message()) 

81 return 

82 

83 accept_msg = GetAcceptedMessage( 

84 state=shadow.state.state, 

85 metadata=shadow.state.metadata, 

86 timestamp=shadow.created_at, 

87 version=shadow.version 

88 ) 

89 await self.context.broadcast_message(accept_msg.topic(st.device_id, st.name), accept_msg.to_message()) 

90 

91 async def _handle_update(self, st: ShadowTopic, update: dict[str, Any]) -> None: 

92 async with self._db_session_maker() as db_session, db_session.begin(): 

93 shadow = await Shadow.latest_version(db_session, st.device_id, st.name) 

94 if not shadow: 94 ↛ 97line 94 didn't jump to line 97 because the condition on line 94 was always true

95 shadow = Shadow(device_id=st.device_id, name=st.name) 

96 

97 state_update = StateDocument.from_dict(update) 

98 

99 prev_state = shadow.state or StateDocument() 

100 prev_state.version = shadow.version or 0 # only required when generating shadow messages 

101 prev_state.timestamp = shadow.created_at or 0 # only required when generating shadow messages 

102 

103 next_state = prev_state + state_update 

104 

105 shadow.state = next_state 

106 db_session.add(shadow) 

107 await db_session.commit() 

108 

109 next_state.version = shadow.version 

110 next_state.timestamp = shadow.created_at 

111 

112 accept_msg = UpdateAcceptedMessage( 

113 state=next_state.state, 

114 metadata=next_state.metadata, 

115 timestamp=123, 

116 version=1 

117 ) 

118 

119 await self.context.broadcast_message(accept_msg.topic(st.device_id, st.name), accept_msg.to_message()) 

120 

121 delta_msg = UpdateDeltaMessage( 

122 state=calculate_delta_update(next_state.state.desired, next_state.state.reported), 

123 metadata=calculate_delta_update(next_state.metadata.desired, next_state.metadata.reported), 

124 version=shadow.version, 

125 timestamp=shadow.created_at 

126 ) 

127 await self.context.broadcast_message(delta_msg.topic(st.device_id, st.name), delta_msg.to_message()) 

128 

129 iota_msg = UpdateIotaMessage( 

130 state=calculate_iota_update(next_state.state.desired, next_state.state.reported), 

131 metadata=calculate_delta_update(next_state.metadata.desired, next_state.metadata.reported), 

132 version=shadow.version, 

133 timestamp=shadow.created_at 

134 ) 

135 await self.context.broadcast_message(iota_msg.topic(st.device_id, st.name), iota_msg.to_message()) 

136 

137 doc_msg = UpdateDocumentMessage( 

138 previous=prev_state, 

139 current=next_state, 

140 timestamp=shadow.created_at 

141 ) 

142 

143 await self.context.broadcast_message(doc_msg.topic(st.device_id, st.name), doc_msg.to_message()) 

144 

145 async def on_broker_message_received(self, *, client_id: str, message: ApplicationMessage) -> None: 

146 """Process a message that was received from a client.""" 

147 topic = message.topic 

148 if not topic.startswith("$shadow"): # this is less overhead than do the full regular expression match 148 ↛ 149line 148 didn't jump to line 149 because the condition on line 148 was never true

149 return 

150 

151 if not (shadow_topic := self.shadow_topic_match(topic)): 151 ↛ 152line 151 didn't jump to line 152 because the condition on line 151 was never true

152 return 

153 

154 match shadow_topic.message_op: 

155 

156 case ShadowOperation.GET: 

157 await self._handle_get(shadow_topic) 

158 case ShadowOperation.UPDATE: 158 ↛ exitline 158 didn't return from function 'on_broker_message_received' because the pattern on line 158 always matched

159 await self._handle_update(shadow_topic, json.loads(message.data.decode("utf-8"))) 

160 

161 @dataclass 

162 class Config: 

163 """Configuration for shadow plugin.""" 

164 

165 connection: str 

166 """SQLAlchemy connection string for the asyncio version of the database connector: 

167 

168 - `mysql+aiomysql://user:password@host:port/dbname` 

169 - `postgresql+asyncpg://user:password@host:port/dbname` 

170 - `sqlite+aiosqlite:///dbfilename.db` 

171 """ 

172 

173 

class ShadowTopicAuthPlugin(BaseTopicPlugin):
    """Topic filter that admits only ``$shadow/...`` topics belonging to the requesting user.

    A shadow topic is allowed when its ``<client_id>`` segment equals the session's
    username, or when that username is listed in ``Config.superusers``. Empty topics and
    topics that are not shadow topics are refused.
    """

    async def topic_filtering(self, *,
                              session: Session | None = None,
                              topic: str | None = None,
                              action: Action | None = None) -> bool | None:
        """Return True if *session* may use *topic*, otherwise False.

        *action* is part of the filtering interface but is not consulted here.
        When no session is given, a fresh ``Session()`` is used for the username check.
        """
        current = session if session else Session()

        if not topic:
            return False

        parsed = ShadowPlugin.shadow_topic_match(topic)
        if parsed is None:
            return False

        username = current.username
        if parsed.device_id == username:
            return True
        # Not the owning device: fall back to the configured superusers.
        return username in self.config.superusers

    @dataclass
    class Config:
        """Settings for restricting devices to their own shadow topics."""

        superusers: list[str] = field(default_factory=list)
        """Usernames that may access any device's shadow topics, typically a central
        application that sends updates to devices."""