Coverage for amqtt/plugins/manager.py: 81%
234 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
# Public API of this module.
__all__ = [
    "PluginManager",
    "get_plugin_manager",
]
3import asyncio
4from collections import defaultdict
5from collections.abc import Awaitable, Callable, Coroutine
6import contextlib
7import copy
8from importlib.metadata import EntryPoint, EntryPoints, entry_points
9from inspect import iscoroutinefunction
10import logging
11import sys
12import traceback
13from typing import Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar, cast
14import warnings
16from dacite import Config as DaciteConfig, DaciteError, from_dict
18from amqtt.contexts import Action, BaseContext
19from amqtt.errors import PluginCoroError, PluginImportError, PluginInitError, PluginLoadError
20from amqtt.events import BrokerEvents, Events, MQTTEvents
21from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
22from amqtt.session import Session
23from amqtt.utils import import_string
class Plugin(NamedTuple):
    """A plugin loaded through an entry point.

    Bundles the entry-point name, the entry point itself, and the instantiated plugin.
    """

    # name of the entry point the plugin was loaded from
    name: str
    # the entry point that was loaded
    ep: EntryPoint
    # the instantiated plugin object
    object: Any
# Registry of plugin managers keyed by namespace; each new PluginManager registers itself here.
plugins_manager: dict[str, "PluginManager[Any]"] = {}
def get_plugin_manager(namespace: str) -> "PluginManager[Any] | None":
    """Look up the plugin manager registered for ``namespace``.

    :param namespace: The namespace of the plugin manager to retrieve.
    :return: The registered plugin manager, or None if no manager exists for that namespace.
    """
    try:
        return plugins_manager[namespace]
    except KeyError:
        return None
def safe_issubclass(sub_class: Any, super_class: Any) -> bool:
    """Return ``issubclass(sub_class, super_class)``, or False when the check is not possible.

    ``issubclass`` raises ``TypeError`` when either argument is not a class (for example, an
    instance or a module); that case is reported as "not a subclass" instead of raising.
    """
    with contextlib.suppress(TypeError):
        return issubclass(sub_class, super_class)
    return False
# Type variable for the application context a PluginManager is parameterized with.
C = TypeVar("C", bound=BaseContext)

# Type of the coroutine callables stored and invoked as plugin event callbacks.
AsyncFunc: TypeAlias = Callable[..., Coroutine[Any, Any, None]]
class PluginManager(Generic[C]):
    """Load the plugins of a namespace and dispatch calls to them asynchronously.

    Plugins are loaded either from the ``plugins`` section of the application config
    (recommended) or, when that section is absent, from ``importlib.metadata`` entry points
    registered under the namespace (deprecated). Event callbacks (``on_<event>``) and the
    ``authenticate`` / ``topic_filtering`` / ``close`` hooks are run as asyncio tasks.

    Each instance registers itself in the module-level ``plugins_manager`` registry under its
    namespace, replacing any manager previously registered for that namespace.
    """

    def __init__(self, namespace: str, context: C | None, loop: asyncio.AbstractEventLoop | None = None) -> None:
        """Create the manager and load its plugins immediately.

        :param namespace: plugin namespace; used as logger name, entry-point group (deprecated
            loading path) and registry key.
        :param context: application context; each plugin receives a shallow copy of it.
            A new ``BaseContext`` is created when ``None``.
        :param loop: event loop to use; defaults to the running loop, or to a new loop that is
            installed as the current event loop when no loop is running.
        :raises PluginLoadError, PluginImportError, PluginInitError, PluginCoroError:
            if a plugin cannot be loaded.
        """
        try:
            self._loop = loop if loop is not None else asyncio.get_running_loop()
        except RuntimeError:
            # no running loop (constructed from synchronous code): create and install one
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

        self.logger = logging.getLogger(namespace)
        self.context = context if context is not None else BaseContext()
        self.context.loop = self._loop
        self._plugins: list[BasePlugin[C]] = []
        self._auth_plugins: list[BaseAuthPlugin] = []
        self._topic_plugins: list[BaseTopicPlugin] = []
        self._event_plugin_callbacks: dict[str, list[AsyncFunc]] = defaultdict(list)
        self._is_topic_filtering_enabled = False
        self._is_auth_filtering_enabled = False
        # futures created by fire_event() that have not been cleaned up yet
        self._fired_events: list[asyncio.Future[Any]] = []

        self._load_plugins(namespace)
        plugins_manager[namespace] = self

    @property
    def app_context(self) -> BaseContext:
        """The application context shared (as shallow copies) with the plugins."""
        return self.context

    def _load_plugins(self, namespace: str | None = None) -> None:
        """Load plugins from the config dictionary or, failing that, from entry points.

        Config style is recommended; entry-point loading is deprecated. Example::

            config = {
                'listeners': ...,
                'plugins': {
                    'myproject.myfile.MyPlugin': {}
                }
            }

        After loading, every ``on_<event>`` attribute found on a loaded plugin is registered as
        a callback for the corresponding broker/MQTT event.

        :param namespace: entry-point group; only required when the config has no ``plugins``.
        :raises PluginLoadError: if the ``plugins`` section is malformed, or if entry-point
            loading is needed but no namespace was given.
        :raises PluginImportError: if a plugin defines a non-coroutine ``on_<event>`` handler.
        """
        config = self.app_context.config
        if config and config.get("plugins", None) is not None:
            # plugins loaded directly from config dictionary
            if "auth" in config and config["auth"] is not None:
                self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
            if "topic-check" in config and config["topic-check"] is not None:
                self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config")

            plugins_config: list[Any] | dict[str, Any] = config.get("plugins", [])

            # if the config was generated from yaml, the plugins may be a list instead of a
            # dictionary; transform before loading
            #
            #   plugins:
            #     - myproject.myfile.MyPlugin:
            if isinstance(plugins_config, list):
                plugins_info: dict[str, Any] = {}
                for plugin_config in plugins_config:
                    if isinstance(plugin_config, str):
                        plugins_info[plugin_config] = {}
                    elif isinstance(plugin_config, dict):
                        plugins_info.update(plugin_config)
                    else:
                        msg = "malformed 'plugins' configuration"
                        raise PluginLoadError(msg)
                self._load_str_plugins(plugins_info)
            elif isinstance(plugins_config, dict):
                self._load_str_plugins(plugins_config)
            else:
                # any other type used to be silently ignored, leaving the broker without plugins
                msg = "malformed 'plugins' configuration"
                raise PluginLoadError(msg)
        else:
            if not namespace:
                msg = "Namespace needs to be provided for EntryPoint plugin definitions"
                raise PluginLoadError(msg)

            warnings.warn(
                "Loading plugins from EntryPoints is deprecated and will be removed in a future version."
                " Use `plugins` section of config instead.",
                DeprecationWarning,
                stacklevel=4,
            )

            self._load_ep_plugins(namespace)

        # for all the loaded plugins, find all event callbacks
        events = list(BrokerEvents) + list(MQTTEvents)
        for plugin in self._plugins:
            for event in events:
                if awaitable := getattr(plugin, f"on_{event}", None):
                    if not iscoroutinefunction(awaitable):
                        msg = f"'on_{event}' for '{plugin.__class__.__name__}' is not a coroutine"
                        raise PluginImportError(msg)
                    self.logger.debug(f"'{event}' handler found for '{plugin.__class__.__name__}'")
                    self._event_plugin_callbacks[event].append(awaitable)

    def _load_ep_plugins(self, namespace: str) -> None:
        """Load plugins from `pyproject.toml` entrypoints. Deprecated.

        Every entry point of group ``namespace`` is loaded. A plugin is also used for
        authentication when it has an ``authenticate`` attribute and passes the
        ``config['auth']['plugins']`` filter. It is used for topic filtering when it has a
        ``topic_filtering`` attribute and passes the ``config['topic-check']['plugins']``
        filter. A section without a ``plugins`` list accepts every plugin. A missing section
        accepts none.
        """
        self.logger.debug(f"Loading plugins for namespace {namespace}")
        config = self.app_context.config
        auth_filter_list = []
        topic_filter_list = []
        if config and "auth" in config:
            auth_filter_list = config["auth"].get("plugins", None)
        if config and "topic-check" in config:
            topic_filter_list = config["topic-check"].get("plugins", None)

        all_eps = entry_points()
        ep: EntryPoints | list[EntryPoint] = []
        if hasattr(all_eps, "select"):
            ep = all_eps.select(group=namespace)
        elif namespace in all_eps:
            # legacy (pre-3.10) dict-of-groups API: the value is already a sequence of entry
            # points, so it must be converted, not wrapped, into a list
            ep = list(all_eps[namespace])

        for item in ep:
            ep_plugin = self._load_ep_plugin(item)
            if ep_plugin is None:
                continue
            self._plugins.append(ep_plugin.object)
            # maintain legacy behavior that if there is no list, use all auth plugins
            if ((auth_filter_list is None or ep_plugin.name in auth_filter_list)
                    and hasattr(ep_plugin.object, "authenticate")):
                self._auth_plugins.append(ep_plugin.object)
            # maintain legacy behavior that if there is no list, use all topic plugins
            if ((topic_filter_list is None or ep_plugin.name in topic_filter_list)
                    and hasattr(ep_plugin.object, "topic_filtering")):
                self._topic_plugins.append(ep_plugin.object)
            self.logger.debug(f" Plugin {item.name} ready")

    def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None:
        """Import and instantiate the plugin behind one entry point. Deprecated.

        The plugin is called with a shallow copy of the application context whose logger is a
        child logger named after the entry point.

        :raises PluginImportError: if the entry point cannot be imported
        :raises PluginInitError: if the plugin constructor raises
        """
        try:
            self.logger.debug(f" Loading plugin {ep!s}")
            plugin = ep.load()
        except ImportError as e:
            self.logger.debug(f"Plugin import failed: {ep!r}", exc_info=True)
            raise PluginImportError(ep) from e

        self.logger.debug(f" Initializing plugin {ep!s}")

        plugin_context = copy.copy(self.app_context)
        plugin_context.logger = self.logger.getChild(ep.name)
        try:
            obj = plugin(plugin_context)
            return Plugin(ep.name, ep, obj)
        except Exception as e:
            self.logger.debug(f"Plugin init failed: {ep!r}", exc_info=True)
            raise PluginInitError(ep) from e

    def _load_str_plugins(self, plugins_info: dict[str, Any]) -> None:
        """Load every config-defined plugin and register the auth/topic ones.

        :param plugins_info: mapping of dotted plugin path to that plugin's raw configuration
        :raises PluginCoroError: if an auth/topic plugin's hook is not a coroutine function
        """
        self.logger.info("Loading plugins from config")
        # legacy had a filtering 'enabled' flag, even if plugins were loaded/listed
        self._is_topic_filtering_enabled = True
        self._is_auth_filtering_enabled = True
        for plugin_path, plugin_config in plugins_info.items():
            plugin = self._load_str_plugin(plugin_path, plugin_config)
            self._plugins.append(plugin)

            # make sure that authenticate and topic filtering plugins have the appropriate async signature
            if isinstance(plugin, BaseAuthPlugin):
                if not iscoroutinefunction(plugin.authenticate):
                    msg = f"Auth plugin {plugin_path} has non-async authenticate method."
                    raise PluginCoroError(msg)
                self._auth_plugins.append(plugin)
            if isinstance(plugin, BaseTopicPlugin):
                if not iscoroutinefunction(plugin.topic_filtering):
                    msg = f"Topic plugin {plugin_path} has non-async topic_filtering method."
                    raise PluginCoroError(msg)
                self._topic_plugins.append(plugin)

    def _load_str_plugin(self, plugin_path: str, plugin_cfg: dict[str, Any] | None = None) -> "BasePlugin[C]":
        """Load plugin from string dotted path: mymodule.myfile.MyPlugin.

        The plugin's nested ``Config`` dataclass is populated from ``plugin_cfg`` with
        ``dacite`` in strict mode and stored as ``config`` on a shallow copy of the application
        context. That copy is passed to the plugin constructor.

        :param plugin_path: dotted import path of a ``BasePlugin`` subclass
        :param plugin_cfg: raw configuration for the plugin; ``None`` is treated as ``{}``
        :return: the instantiated plugin
        :raises PluginImportError: if the path cannot be imported
        :raises PluginLoadError: if the target is not a ``BasePlugin`` subclass, or its
            configuration cannot be marshalled into ``Config``
        :raises PluginInitError: if the plugin constructor raises
        """
        try:
            plugin_class: Any = import_string(plugin_path)
        except ImportError as ep:
            msg = f"Plugin import failed: {plugin_path}"
            raise PluginImportError(msg) from ep

        if not safe_issubclass(plugin_class, BasePlugin):
            msg = f"Plugin {plugin_path} is not a subclass of 'BasePlugin'"
            raise PluginLoadError(msg)

        plugin_context = copy.copy(self.app_context)
        plugin_context.logger = self.logger.getChild(plugin_class.__name__)
        try:
            # populate the config based on the inner dataclass called `Config`
            # use `dacite` package to type check
            plugin_context.config = from_dict(data_class=plugin_class.Config,
                                              data=plugin_cfg or {},
                                              config=DaciteConfig(strict=True))
        except DaciteError as e:
            # name the offending plugin; a bare PluginLoadError gave no hint which one failed
            msg = f"Invalid configuration for plugin {plugin_path}: {e}"
            raise PluginLoadError(msg) from e
        except TypeError as e:
            msg = f"Could not marshall 'Config' of {plugin_path}; should be a dataclass."
            raise PluginLoadError(msg) from e

        try:
            pc = plugin_class(plugin_context)
        except Exception as e:
            self.logger.debug(f"Plugin init failed: {plugin_class.__name__}", exc_info=True)
            raise PluginInitError(plugin_class) from e
        self.logger.debug(f"Loading plugin {plugin_path}")
        return cast("BasePlugin[C]", pc)

    def get_plugin(self, name: str) -> Optional["BasePlugin[C]"]:
        """Get a plugin by its class name from the plugins loaded for the current namespace.

        Only used for testing purposes to verify plugin loading correctly.

        :param name: class name of the plugin to look for
        :return: the first loaded plugin whose class has that name, or None
        """
        return next((p for p in self._plugins if p.__class__.__name__ == name), None)

    def is_topic_filtering_enabled(self) -> bool:
        """Tell whether topic filtering is active.

        Truthy when ``config['topic-check']['enabled']`` is truthy, or when plugins were loaded
        from the ``plugins`` config section (which always enables filtering).
        """
        config = self.app_context.config
        topic_config = config.get("topic-check", {}) if config else {}
        if isinstance(topic_config, dict):
            return topic_config.get("enabled", False) or self._is_topic_filtering_enabled
        return self._is_topic_filtering_enabled

    async def close(self) -> None:
        """Free PluginManager resources and cancel pending event methods.

        Plugin ``close`` hooks are awaited first. Pending ``fire_event`` tasks are cancelled even
        if one of those hooks raises; the exception is then re-raised.
        """
        try:
            await self.map_plugin_close()
        finally:
            for task in list(self._fired_events):
                task.cancel()
            self._fired_events.clear()

    @property
    def plugins(self) -> list["BasePlugin[C]"]:
        """Get the loaded plugins list.

        :return: the live list of loaded plugins, in load order
        """
        return self._plugins

    def _schedule_coro(self, coro: Awaitable[str | bool | None]) -> asyncio.Future[str | bool | None]:
        """Wrap ``coro`` in a future via ``asyncio.ensure_future``."""
        return asyncio.ensure_future(coro)

    def _clean_fired_events(self, future: asyncio.Future[Any]) -> None:
        """Done-callback for ``fire_event`` tasks: report failures when debugging, then forget the future."""
        # results/exceptions are only inspected when DEBUG logging is enabled
        if self.logger.getEffectiveLevel() <= logging.DEBUG:
            try:
                future.result()
            except asyncio.CancelledError:
                self.logger.warning("fired event was cancelled")
            # display plugin fault; don't allow it to cause a broker failure
            except Exception as exc:  # noqa: BLE001, pylint: disable=W0718
                traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr)

        # always drop the finished future so _fired_events does not grow without bound
        with contextlib.suppress(KeyError, ValueError):
            self._fired_events.remove(future)

    async def fire_event(self, event_name: Events, *, wait: bool = False, **method_kwargs: Any) -> None:
        """Fire an event to plugins.

        Schedules, for each plugin that defines it, a call to the method named
        ``"on_" + event_name``; for example, ``on_connect`` is called on event ``connect``.
        The calls are scheduled as tasks on the event loop; set ``wait`` to True to wait until
        all of them have completed.

        :param event_name: the event to fire
        :param wait: indicates if fire_event should wait for plugin calls completion (True), or not
        :param method_kwargs: keyword arguments passed to every handler
        """
        # check if any plugin has defined a callback for this event, skip if none
        if event_name not in self._event_plugin_callbacks:
            return

        async def call_method(method: AsyncFunc, kwargs: dict[str, Any]) -> Any:
            return await method(**kwargs)

        tasks: list[asyncio.Future[Any]] = []
        for event_awaitable in self._event_plugin_callbacks[event_name]:
            task = asyncio.ensure_future(call_method(event_awaitable, method_kwargs))
            task.add_done_callback(self._clean_fired_events)
            tasks.append(task)

        self._fired_events.extend(tasks)
        if wait and tasks:
            await asyncio.wait(tasks)
        self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}")

    @staticmethod
    async def _map_plugin_method(
        plugins: list["BasePlugin[C]"],
        method_name: str,
        method_kwargs: dict[str, Any],
    ) -> dict["BasePlugin[C]", str | bool | None]:
        """Call the coroutine method ``method_name`` concurrently on every plugin that defines it.

        :param plugins: plugins to execute the method on; plugins without the method are skipped
        :param method_name: name of the method to call on each plugin
        :param method_kwargs: keyword arguments to pass to the method
        :return: dict mapping each plugin that defines the method to the value its call returned.
            The first exception raised by any call propagates (``asyncio.gather`` semantics).
        """

        async def call_method(p: "BasePlugin[C]", kwargs: dict[str, Any]) -> Any:
            method = getattr(p, method_name)
            return await method(**kwargs)

        called_plugins: list[BasePlugin[C]] = []
        tasks: list[asyncio.Future[Any]] = []
        for plugin in plugins:
            if not hasattr(plugin, method_name):
                continue
            called_plugins.append(plugin)
            tasks.append(asyncio.ensure_future(call_method(plugin, method_kwargs)))

        ret_dict: dict[BasePlugin[C], str | bool | None] = {}
        if tasks:
            ret_list = await asyncio.gather(*tasks)
            # pair results with the plugins actually called; zipping with `plugins` would shift
            # results onto the wrong plugin whenever one of them was skipped above
            ret_dict = dict(zip(called_plugins, ret_list, strict=True))

        return ret_dict

    async def map_plugin_auth(self, *, session: Session) -> dict["BasePlugin[C]", str | bool | None]:
        """Schedule a coroutine for plugin 'authenticate' calls.

        :param session: the client session associated with the authentication check
        :return: dict containing return from coro call for each plugin.
        """
        return await self._map_plugin_method(
            self._auth_plugins, "authenticate", {"session": session})  # type: ignore[arg-type]

    async def map_plugin_topic(
        self, *, session: Session, topic: str, action: "Action"
    ) -> dict["BasePlugin[C]", str | bool | None]:
        """Schedule a coroutine for plugin 'topic_filtering' calls.

        :param session: the client session associated with the topic_filtering check
        :param topic: the topic that needs to be filtered
        :param action: the action being executed
        :return: dict containing return from coro call for each plugin.
        """
        return await self._map_plugin_method(
            self._topic_plugins, "topic_filtering",  # type: ignore[arg-type]
            {"session": session, "topic": topic, "action": action}
        )

    async def map_plugin_close(self) -> None:
        """Call ``close`` on every loaded plugin that defines it and wait for all of them.

        The individual return values are discarded.
        """
        await self._map_plugin_method(self._plugins, "close", {})