Coverage for amqtt/plugins/topic_checking.py: 94%

78 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

1from dataclasses import dataclass, field 

2from typing import Any 

3import warnings 

4 

5from amqtt.contexts import Action, BaseContext 

6from amqtt.errors import PluginInitError 

7from amqtt.plugins.base import BaseTopicPlugin 

8from amqtt.session import Session 

9 

10 

class TopicTabooPlugin(BaseTopicPlugin):
    """Topic filter that refuses a fixed list of taboo topic names to every user except ``admin``."""

    def __init__(self, context: BaseContext) -> None:
        super().__init__(context)
        # Exact topic strings (compared with ``in``, no wildcard matching) that non-admin users may not use.
        self._taboo: list[str] = ["prohibited", "top-secret", "data/classified"]

    async def topic_filtering(
        self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
    ) -> bool | None:
        """Return whether ``topic`` is permitted for ``session``.

        The base-class filter is consulted first; if it refuses (any falsy result) the answer is ``False``.
        Otherwise a session whose username is ``"admin"`` is always permitted, an empty/``None`` topic is
        permitted, and any other topic is permitted only if it is not one of the taboo names.
        """
        base_verdict = await super().topic_filtering(session=session, topic=topic, action=action)
        if not base_verdict:
            # The base filter refused; normalise whatever falsy value it produced to ``False``.
            return False

        if session and session.username == "admin":
            # Administrators bypass the taboo list entirely.
            return True

        return not topic or topic not in self._taboo

25 

26 

27class TopicAccessControlListPlugin(BaseTopicPlugin): 

28 

29 def __init__(self, context: BaseContext) -> None: 

30 super().__init__(context) 

31 

32 if self._get_config_option("acl", None): 

33 warnings.warn("The 'acl' option is deprecated, please use 'subscribe-acl' instead.", stacklevel=1) 

34 

35 if self._get_config_option("acl", None) and self._get_config_option("subscribe-acl", None): 35 ↛ 36line 35 didn't jump to line 36 because the condition on line 35 was never true

36 msg = "'acl' has been replaced with 'subscribe-acl'; only one may be included" 

37 raise PluginInitError(msg) 

38 

39 @staticmethod 

40 def topic_ac(topic_requested: str, topic_allowed: str) -> bool: 

41 req_split = topic_requested.split("/") 

42 allowed_split = topic_allowed.split("/") 

43 ret = True 

44 for i in range(max(len(req_split), len(allowed_split))): 

45 try: 

46 a_aux = req_split[i] 

47 b_aux = allowed_split[i] 

48 except IndexError: 

49 ret = False 

50 break 

51 if b_aux == "#": 

52 break 

53 if b_aux in ("+", a_aux): 

54 continue 

55 ret = False 

56 break 

57 return ret 

58 

59 async def topic_filtering( 

60 self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None 

61 ) -> bool | None: 

62 filter_result = await super().topic_filtering(session=session, topic=topic, action=action) 

63 if not filter_result: 

64 return False 

65 

66 # hbmqtt and older amqtt do not support publish filtering 

67 if action == Action.PUBLISH and not self._get_config_option("publish-acl", {}): 

68 # maintain backward compatibility, assume permitted 

69 return True 

70 

71 req_topic = topic 

72 if not req_topic: 

73 return False 

74 

75 username = session.username if session else None 

76 if username is None: 

77 username = "anonymous" 

78 

79 acl: dict[str, Any] | None = None 

80 match action: 

81 case Action.PUBLISH: 

82 acl = self._get_config_option("publish-acl", None) 

83 case Action.SUBSCRIBE: 

84 acl = self._get_config_option("subscribe-acl", self._get_config_option("acl", None)) 

85 case Action.RECEIVE: 85 ↛ 87line 85 didn't jump to line 87 because the pattern on line 85 always matched

86 acl = self._get_config_option("receive-acl", None) 

87 case _: 

88 msg = "Received an invalid action type." 

89 raise ValueError(msg) 

90 

91 if acl is None: 

92 return True 

93 

94 allowed_topics = acl.get(username, []) 

95 if not allowed_topics: 

96 return False 

97 

98 return any(self.topic_ac(req_topic, allowed_topic) for allowed_topic in allowed_topics) 

99 

100 @dataclass 

101 class Config: 

102 """Mappings of username and list of approved topics.""" 

103 

104 publish_acl: dict[str, list[str]] = field(default_factory=dict) 

105 acl: dict[str, list[str]] = field(default_factory=dict)