Coverage for amqtt/plugins/topic_checking.py: 94%
78 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
1from dataclasses import dataclass, field
2from typing import Any
3import warnings
5from amqtt.contexts import Action, BaseContext
6from amqtt.errors import PluginInitError
7from amqtt.plugins.base import BaseTopicPlugin
8from amqtt.session import Session
class TopicTabooPlugin(BaseTopicPlugin):
    """Reject a fixed set of forbidden topic names for every user except ``admin``."""

    def __init__(self, context: BaseContext) -> None:
        super().__init__(context)
        # Exact topic strings (no wildcard matching) that non-admin sessions may not use.
        self._taboo: list[str] = ["prohibited", "top-secret", "data/classified"]

    async def topic_filtering(
        self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
    ) -> bool | None:
        """Apply the base filter, then the taboo list.

        Returns False when the base filter rejects the request. Otherwise returns
        True for a session whose username is ``"admin"`` or for an empty/missing
        topic, and for anyone else True only if ``topic`` is not in the taboo list.
        """
        base_verdict = await super().topic_filtering(session=session, topic=topic, action=action)
        if not base_verdict:
            return False

        # The admin account bypasses the taboo list entirely.
        is_admin = bool(session) and session.username == "admin"
        if is_admin:
            return True

        # With no topic there is nothing to compare against the taboo list.
        if not topic:
            return True
        return topic not in self._taboo
27class TopicAccessControlListPlugin(BaseTopicPlugin):
29 def __init__(self, context: BaseContext) -> None:
30 super().__init__(context)
32 if self._get_config_option("acl", None):
33 warnings.warn("The 'acl' option is deprecated, please use 'subscribe-acl' instead.", stacklevel=1)
35 if self._get_config_option("acl", None) and self._get_config_option("subscribe-acl", None): 35 ↛ 36line 35 didn't jump to line 36 because the condition on line 35 was never true
36 msg = "'acl' has been replaced with 'subscribe-acl'; only one may be included"
37 raise PluginInitError(msg)
39 @staticmethod
40 def topic_ac(topic_requested: str, topic_allowed: str) -> bool:
41 req_split = topic_requested.split("/")
42 allowed_split = topic_allowed.split("/")
43 ret = True
44 for i in range(max(len(req_split), len(allowed_split))):
45 try:
46 a_aux = req_split[i]
47 b_aux = allowed_split[i]
48 except IndexError:
49 ret = False
50 break
51 if b_aux == "#":
52 break
53 if b_aux in ("+", a_aux):
54 continue
55 ret = False
56 break
57 return ret
59 async def topic_filtering(
60 self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
61 ) -> bool | None:
62 filter_result = await super().topic_filtering(session=session, topic=topic, action=action)
63 if not filter_result:
64 return False
66 # hbmqtt and older amqtt do not support publish filtering
67 if action == Action.PUBLISH and not self._get_config_option("publish-acl", {}):
68 # maintain backward compatibility, assume permitted
69 return True
71 req_topic = topic
72 if not req_topic:
73 return False
75 username = session.username if session else None
76 if username is None:
77 username = "anonymous"
79 acl: dict[str, Any] | None = None
80 match action:
81 case Action.PUBLISH:
82 acl = self._get_config_option("publish-acl", None)
83 case Action.SUBSCRIBE:
84 acl = self._get_config_option("subscribe-acl", self._get_config_option("acl", None))
85 case Action.RECEIVE: 85 ↛ 87line 85 didn't jump to line 87 because the pattern on line 85 always matched
86 acl = self._get_config_option("receive-acl", None)
87 case _:
88 msg = "Received an invalid action type."
89 raise ValueError(msg)
91 if acl is None:
92 return True
94 allowed_topics = acl.get(username, [])
95 if not allowed_topics:
96 return False
98 return any(self.topic_ac(req_topic, allowed_topic) for allowed_topic in allowed_topics)
100 @dataclass
101 class Config:
102 """Mappings of username and list of approved topics."""
104 publish_acl: dict[str, list[str]] = field(default_factory=dict)
105 acl: dict[str, list[str]] = field(default_factory=dict)