Coverage for amqtt/plugins/base.py: 86%
59 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
1from dataclasses import dataclass, is_dataclass
2from typing import Any, Generic, TypeVar, cast
4from amqtt.contexts import Action, BaseContext, BrokerConfig
5from amqtt.session import Session
C = TypeVar("C", bound=BaseContext)


class BasePlugin(Generic[C]):
    """The base from which all plugins should inherit.

    Type Parameters
    ---------------
    C:
        A BaseContext: either BrokerContext or ClientContext, depending on plugin usage

    Attributes
    ----------
    context (C):
        Information about the environment in which this plugin is executed. Modifying
        the broker or client state should happen through methods available here.

    config (self.Config):
        An instance of the Config dataclass defined by the plugin (or an empty dataclass, if not
        defined). If using entrypoint- or mixed-style configuration, use `_get_config_option()`
        to access the variable.

    """

    def __init__(self, context: C) -> None:
        self.context: C = context
        # since the PluginManager will hydrate the config from a plugin's `Config` class, this is a safe cast
        self.config = cast("self.Config", context.config)  # type: ignore[name-defined]

    # Deprecated: included to support entrypoint-style configs. Replaced by dataclass Config class.
    def _get_config_section(self, name: str) -> dict[str, Any] | None:
        """Return the ``name`` section of a dict-like context config, or ``None``.

        ``None`` is returned when the context config is falsy, when it has no ``get``
        method (e.g. a dataclass config), when the section is missing or falsy, or when
        the section value is an ``int``.
        """
        config = self.context.config
        if not config or not hasattr(config, "get"):
            return None

        # Single lookup: the section used to be fetched twice (once to test, once to return).
        section_config: int | dict[str, Any] | None = config.get(name, None)
        # mypy has difficulty excluding int from `config`'s type, unless there's an explicit check
        if not section_config or isinstance(section_config, int):
            return None
        return section_config

    # Deprecated : supports entrypoint-style configs as well as dataclass configuration.
    def _get_config_option(self, option_name: str, default: Any = None) -> Any:
        """Look up ``option_name`` in the context config, falling back to ``default``.

        Args:
            option_name: name of the option. For a dataclass config, ``-`` is mapped to ``_``
                to form the attribute name; for any other config the name is used as a key as-is.
            default: value returned when the config is falsy or the option is absent.

        Returns:
            The option value, or ``default``.

        """
        config = self.context.config
        if not config:
            return default

        if is_dataclass(config):
            # overloaded context.config for BasePlugin `Config` class, so ignoring static type check
            return getattr(config, option_name.replace("-", "_"), default)
        if option_name in config:
            return config[option_name]
        return default

    @dataclass
    class Config:
        """Override to define the configuration and defaults for plugin."""

    async def close(self) -> None:
        """Override if plugin needs to clean up resources upon shutdown."""
class BaseTopicPlugin(BasePlugin[BaseContext]):
    """Base class for topic plugins."""

    def __init__(self, context: BaseContext) -> None:
        super().__init__(context)

        self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check")
        # Warn only when neither a 'topic-check' section nor a Config dataclass is available.
        if not (self.topic_config or is_dataclass(self.context.config)):
            self.context.logger.warning("'topic-check' section not found in context configuration")

    def _get_config_option(self, option_name: str, default: Any = None) -> Any:
        """Resolve a topic option, preferring the plugin's own Config dataclass.

        Args:
            option_name: option key; when read from a plugin Config dataclass, ``-`` becomes ``_``.
            default: value returned when the option cannot be found.

        Returns:
            The option value, or ``default``.

        """
        cfg = self.context.config
        if not cfg:
            return default

        # context.config is overloaded: it holds either the BrokerConfig or the plugin's own Config
        if is_dataclass(cfg) and not isinstance(cfg, BrokerConfig):
            return getattr(cfg, option_name.replace("-", "_"), default)

        # Otherwise fall back to the dict-style 'topic-check' section captured in __init__.
        section = self.topic_config
        if section and option_name in section:
            return section[option_name]
        return default

    async def topic_filtering(
        self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
    ) -> bool | None:
        """Logic for filtering out topics.

        Args:
            session: amqtt.session.Session
            topic: str
            action: amqtt.contexts.Action

        Returns:
            bool: `True` if topic is allowed, `False` otherwise. `None` if it can't be determined.
            This base implementation returns `True` when a 'topic-check' section or a
            dataclass config is present, and `False` otherwise.

        """
        if self.topic_config:
            return True
        return is_dataclass(self.context.config)
class BaseAuthPlugin(BasePlugin[BaseContext]):
    """Base class for authentication plugins."""

    # __init__ comes first (as in BaseTopicPlugin) because it creates `auth_config`,
    # which `_get_config_option` reads.
    def __init__(self, context: BaseContext) -> None:
        super().__init__(context)

        self.auth_config: dict[str, Any] | None = self._get_config_section("auth")
        if not bool(self.auth_config) and not is_dataclass(self.context.config):
            # auth config section not found and Config dataclass not provided
            self.context.logger.warning("'auth' section not found in context configuration")

    def _get_config_option(self, option_name: str, default: Any = None) -> Any:
        """Resolve an auth option from the plugin's Config dataclass or the 'auth' section.

        Args:
            option_name: option key; when read from a plugin Config dataclass, ``-`` becomes ``_``.
            default: value returned when the option cannot be found.

        Returns:
            The option value, or ``default``.

        """
        if not self.context.config:
            return default

        # context.config is overloaded with either the BrokerConfig or the plugin's own Config
        if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
            # overloaded context.config for BasePlugin `Config` class, so ignoring static type check
            return getattr(self.context.config, option_name.replace("-", "_"), default)
        if self.auth_config and option_name in self.auth_config:
            return self.auth_config[option_name]
        return default

    async def authenticate(self, *, session: Session) -> bool | None:
        """Logic for session authentication.

        Args:
            session: amqtt.session.Session

        Returns:
            - `True` if user authentication succeeds, `False` if user authentication fails
            - `None` if authentication can't be achieved (the plugin result is then ignored)

        This base implementation returns `True` when an 'auth' section or a dataclass
        config is present, and `False` otherwise.

        """
        return bool(self.auth_config) or is_dataclass(self.context.config)