Coverage for amqtt/plugins/base.py: 86%

59 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

1from dataclasses import dataclass, is_dataclass 

2from typing import Any, Generic, TypeVar, cast 

3 

4from amqtt.contexts import Action, BaseContext, BrokerConfig 

5from amqtt.session import Session 

6 

# Type parameter for BasePlugin's context; bound to BaseContext so any
# BaseContext subclass can be used as a plugin's context type.
C = TypeVar("C", bound=BaseContext)

8 

9 

class BasePlugin(Generic[C]):
    """The base from which all plugins should inherit.

    Type Parameters
    ---------------
    C:
        A BaseContext: either BrokerContext or ClientContext, depending on plugin usage

    Attributes
    ----------
    context (C):
        Information about the environment in which this plugin is executed. Modifying
        the broker or client state should happen through methods available here.

    config (self.Config):
        An instance of the Config dataclass defined by the plugin (or an empty dataclass, if not
        defined). If using entrypoint- or mixed-style configuration, use `_get_config_option()`
        to access the variable.

    """

    def __init__(self, context: C) -> None:
        self.context: C = context
        # since the PluginManager will hydrate the config from a plugin's `Config` class, this is a safe cast
        self.config = cast("self.Config", context.config)  # type: ignore[name-defined]

    # Deprecated: included to support entrypoint-style configs. Replaced by dataclass Config class.
    def _get_config_section(self, name: str) -> dict[str, Any] | None:
        """Return the named section of a mapping-style ``context.config``.

        Args:
            name: key of the section to fetch (e.g. ``"auth"`` or ``"topic-check"``).

        Returns:
            The section value, or ``None`` when the config is empty, has no ``get``
            method, the section is missing or falsy, or the section is an ``int``.

        """
        config = self.context.config
        if not config or not hasattr(config, "get"):
            return None

        # Look the section up once and apply every rejection rule to that value.
        section_config: int | dict[str, Any] | None = config.get(name, None)
        # mypy has difficulty excluding int from `config`'s type, unless there's an explicit check
        if not section_config or isinstance(section_config, int):
            return None
        return section_config

    # Deprecated : supports entrypoint-style configs as well as dataclass configuration.
    def _get_config_option(self, option_name: str, default: Any = None) -> Any:
        """Return the value of ``option_name`` from ``context.config``.

        Args:
            option_name: option key; for dataclass configs ``-`` is mapped to ``_``
                to form the attribute name.
            default: value returned when the option cannot be found.

        Returns:
            The option value, or ``default`` if the config is empty or lacks the option.

        """
        if not self.context.config:
            return default

        if is_dataclass(self.context.config):
            # overloaded context.config for BasePlugin `Config` class, so ignoring static type check
            return getattr(self.context.config, option_name.replace("-", "_"), default)
        # Mapping-style (entrypoint) config: exact-key lookup, no '-' -> '_' mapping.
        if option_name in self.context.config:
            return self.context.config[option_name]
        return default

    @dataclass
    class Config:
        """Override to define the configuration and defaults for plugin."""

    async def close(self) -> None:
        """Override if plugin needs to clean up resources upon shutdown."""

66 

67 

class BaseTopicPlugin(BasePlugin[BaseContext]):
    """Base class for topic plugins.

    Settings come either from a plugin ``Config`` dataclass or, for
    entrypoint-style configuration, from the ``topic-check`` section of the
    context configuration (kept in ``self.topic_config``).
    """

    def __init__(self, context: BaseContext) -> None:
        super().__init__(context)

        self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check")
        has_section = bool(self.topic_config)
        uses_dataclass = is_dataclass(self.context.config)
        # Warn only when neither configuration style is available.
        if not (has_section or uses_dataclass):
            self.context.logger.warning("'topic-check' section not found in context configuration")

    def _get_config_option(self, option_name: str, default: Any = None) -> Any:
        """Return ``option_name`` from the plugin's configuration, or ``default``.

        A plugin ``Config`` dataclass (a dataclass that is not a ``BrokerConfig``)
        is read by attribute with ``-`` mapped to ``_``; otherwise the
        ``topic-check`` section is consulted by exact key.
        """
        config = self.context.config
        if not config:
            return default

        # overloaded context.config with either BrokerConfig or plugin's Config
        is_plugin_config = is_dataclass(config) and not isinstance(config, BrokerConfig)
        if is_plugin_config:
            # overloaded context.config for BasePlugin `Config` class, so ignoring static type check
            return getattr(config, option_name.replace("-", "_"), default)

        section = self.topic_config
        if section and option_name in section:
            return section[option_name]
        return default

    async def topic_filtering(
        self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
    ) -> bool | None:
        """Logic for filtering out topics.

        Args:
            session: amqtt.session.Session
            topic: str
            action: amqtt.broker.Action

        Returns:
            bool: `True` if topic is allowed, `False` otherwise. `None` if it can't be determined

        """
        if self.topic_config:
            return True
        # NOTE(review): unlike `_get_config_option`, this does not exclude
        # BrokerConfig; if BrokerConfig is a dataclass, a broker-level config with
        # no 'topic-check' section is treated as "allowed" — confirm intent.
        return is_dataclass(self.context.config)

105 

106 

class BaseAuthPlugin(BasePlugin[BaseContext]):
    """Base class for authentication plugins.

    Settings come either from a plugin ``Config`` dataclass or, for
    entrypoint-style configuration, from the ``auth`` section of the context
    configuration (kept in ``self.auth_config``).
    """

    def __init__(self, context: BaseContext) -> None:
        super().__init__(context)

        self.auth_config: dict[str, Any] | None = self._get_config_section("auth")
        found_section = bool(self.auth_config)
        has_dataclass_config = is_dataclass(self.context.config)
        # auth config section not found and Config dataclass not provided
        if not (found_section or has_dataclass_config):
            self.context.logger.warning("'auth' section not found in context configuration")

    def _get_config_option(self, option_name: str, default: Any = None) -> Any:
        """Return ``option_name`` from the plugin's configuration, or ``default``.

        A plugin ``Config`` dataclass (a dataclass that is not a ``BrokerConfig``)
        is read by attribute with ``-`` mapped to ``_``; otherwise the ``auth``
        section is consulted by exact key.
        """
        cfg = self.context.config
        if not cfg:
            return default

        if is_dataclass(cfg) and not isinstance(cfg, BrokerConfig):
            # overloaded context.config for BasePlugin `Config` class, so ignoring static type check
            return getattr(cfg, option_name.replace("-", "_"), default)

        section = self.auth_config
        if section and option_name in section:
            return section[option_name]
        return default

    async def authenticate(self, *, session: Session) -> bool | None:
        """Logic for session authentication.

        Args:
            session: amqtt.session.Session

        Returns:
            - `True` if user is authentication succeed, `False` if user authentication fails
            - `None` if authentication can't be achieved (then plugin result is then ignored)

        """
        # NOTE(review): the dataclass check does not exclude BrokerConfig (unlike
        # `_get_config_option`); if BrokerConfig is a dataclass this yields True
        # even without an 'auth' section — confirm intent.
        return True if self.auth_config else is_dataclass(self.context.config)