Coverage for amqtt/plugins/manager.py: 81%

234 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

# Public API of this module: the manager class and the namespace lookup helper.
__all__ = [
    "PluginManager",
    "get_plugin_manager",
]

2 

3import asyncio 

4from collections import defaultdict 

5from collections.abc import Awaitable, Callable, Coroutine 

6import contextlib 

7import copy 

8from importlib.metadata import EntryPoint, EntryPoints, entry_points 

9from inspect import iscoroutinefunction 

10import logging 

11import sys 

12import traceback 

13from typing import Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar, cast 

14import warnings 

15 

16from dacite import Config as DaciteConfig, DaciteError, from_dict 

17 

18from amqtt.contexts import Action, BaseContext 

19from amqtt.errors import PluginCoroError, PluginImportError, PluginInitError, PluginLoadError 

20from amqtt.events import BrokerEvents, Events, MQTTEvents 

21from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin 

22from amqtt.session import Session 

23from amqtt.utils import import_string 

24 

25 

class Plugin(NamedTuple):
    """Record describing a plugin loaded from a (deprecated) package entry point.

    Built by ``PluginManager._load_ep_plugin`` once the entry point has been
    loaded and its class instantiated.
    """

    # Name of the entry point the plugin was registered under.
    name: str
    # The entry point the plugin class was loaded from.
    ep: EntryPoint
    # The instantiated plugin object.
    object: Any

30 

31 

# Module-level registry of plugin managers keyed by namespace;
# PluginManager.__init__ stores every new instance here.
plugins_manager: dict[str, "PluginManager[Any]"] = {}


def get_plugin_manager(namespace: str) -> "PluginManager[Any] | None":
    """Look up the plugin manager registered for ``namespace``.

    :param namespace: The namespace of the plugin manager to retrieve.
    :return: The registered plugin manager, or ``None`` if no manager has been
        created for that namespace.
    """
    try:
        return plugins_manager[namespace]
    except KeyError:
        return None

42 

43 

def safe_issubclass(sub_class: Any, super_class: Any) -> bool:
    """Return ``issubclass(sub_class, super_class)``, or ``False`` if that raises ``TypeError``.

    ``issubclass`` raises ``TypeError`` when either argument is not a class
    (for example when a dotted import path resolves to a function or module);
    such values are simply reported as "not a subclass".
    """
    with contextlib.suppress(TypeError):
        return issubclass(sub_class, super_class)
    return False

49 

50 

# Shape of a plugin event handler: an `async def` taking any arguments and returning None.
AsyncFunc: TypeAlias = Callable[..., Coroutine[Any, Any, None]]
# Context type a PluginManager is parameterised with; must be a BaseContext (or subclass).
C = TypeVar("C", bound=BaseContext)

53 

54 

class PluginManager(Generic[C]):
    """Load plugins for a namespace and dispatch events and checks to them.

    Plugins are preferably declared in the ``plugins`` section of the context
    config as dotted class paths; loading from package entry points is kept for
    backward compatibility but deprecated. Event handlers (``on_<event>``
    coroutines) are scheduled as asyncio tasks, while authentication, topic and
    close calls are awaited concurrently and their per-plugin results returned.
    """

    def __init__(self, namespace: str, context: C | None, loop: asyncio.AbstractEventLoop | None = None) -> None:
        """Create the manager, load its plugins and register it under ``namespace``.

        :param namespace: plugin namespace; also used as the logger name and as the
            entry-point group when the config has no ``plugins`` section.
        :param context: context handed (as a shallow copy) to every plugin; a fresh
            ``BaseContext`` is used when ``None``.
        :param loop: event loop stored on the context. Defaults to the running loop;
            if no loop is running, a new one is created and set as the current loop.
        :raises PluginLoadError, PluginImportError, PluginInitError, PluginCoroError:
            when a plugin cannot be loaded.
        """
        try:
            self._loop = loop if loop is not None else asyncio.get_running_loop()
        except RuntimeError:
            # No running loop (constructed from synchronous code): create one.
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

        self.logger = logging.getLogger(namespace)
        self.context = context if context is not None else BaseContext()
        self.context.loop = self._loop
        self._plugins: list[BasePlugin[C]] = []
        self._auth_plugins: list[BaseAuthPlugin] = []
        self._topic_plugins: list[BaseTopicPlugin] = []
        self._event_plugin_callbacks: dict[str, list[AsyncFunc]] = defaultdict(list)
        self._is_topic_filtering_enabled = False
        self._is_auth_filtering_enabled = False

        self._load_plugins(namespace)
        self._fired_events: list[asyncio.Future[Any]] = []
        plugins_manager[namespace] = self

    @property
    def app_context(self) -> BaseContext:
        """Context shared with the plugins (the ``context`` passed to ``__init__``)."""
        return self.context

    def _load_plugins(self, namespace: str | None = None) -> None:
        """Load plugins from the config dictionary, or from entry points as a fallback.

        The config style is recommended; entry-point loading is deprecated. Example::

            config = {
                'listeners': ...,
                'plugins': {
                    'myproject.myfile.MyPlugin': {}
                }
            }

        Afterwards every ``on_<event>`` attribute of each loaded plugin is
        registered as a callback for the matching broker/MQTT event.

        :param namespace: entry-point group; required only when the config has no
            ``plugins`` section.
        :raises PluginLoadError: if ``plugins`` is malformed, or if ``namespace`` is
            missing when falling back to entry points.
        :raises PluginImportError: if a plugin's ``on_<event>`` is not a coroutine function.
        """
        config = self.app_context.config
        if config and config.get("plugins", None) is not None:
            # plugins loaded directly from config dictionary
            if config.get("auth") is not None:
                self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
            if config.get("topic-check") is not None:
                self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config")

            plugins_config: list[Any] | dict[str, Any] = config.get("plugins", [])

            # A config generated from yaml may hold a list instead of a dictionary:
            #
            # plugins:
            #   - myproject.myfile.MyPlugin:
            #
            # Normalise it to {dotted_path: plugin_config} before loading.
            if isinstance(plugins_config, list):
                plugins_info: dict[str, Any] = {}
                for plugin_config in plugins_config:
                    if isinstance(plugin_config, str):
                        plugins_info[plugin_config] = {}
                    elif isinstance(plugin_config, dict):
                        plugins_info.update(plugin_config)
                    else:
                        msg = "malformed 'plugins' configuration"
                        raise PluginLoadError(msg)
                self._load_str_plugins(plugins_info)
            elif isinstance(plugins_config, dict):
                self._load_str_plugins(plugins_config)
            else:
                # Previously any other type (e.g. a bare string) was silently ignored.
                msg = "malformed 'plugins' configuration"
                raise PluginLoadError(msg)
        else:
            if not namespace:
                msg = "Namespace needs to be provided for EntryPoint plugin definitions"
                raise PluginLoadError(msg)

            # stacklevel points past _load_plugins and __init__ to the code creating the manager.
            warnings.warn(
                "Loading plugins from EntryPoints is deprecated and will be removed in a future version."
                " Use `plugins` section of config instead.",
                DeprecationWarning,
                stacklevel=4,
            )

            self._load_ep_plugins(namespace)

        # For all the loaded plugins, find all event callbacks.
        all_events = [*BrokerEvents, *MQTTEvents]
        for plugin in self._plugins:
            plugin_name = plugin.__class__.__name__
            for event in all_events:
                if awaitable := getattr(plugin, f"on_{event}", None):
                    if not iscoroutinefunction(awaitable):
                        msg = f"'on_{event}' for '{plugin_name}' is not a coroutine"
                        raise PluginImportError(msg)
                    self.logger.debug(f"'{event}' handler found for '{plugin_name}'")
                    self._event_plugin_callbacks[event].append(awaitable)

    def _load_ep_plugins(self, namespace: str) -> None:
        """Load plugins from `pyproject.toml` entrypoints. Deprecated.

        Every plugin in the ``namespace`` group is added to the plugin list. Plugins
        exposing ``authenticate`` / ``topic_filtering`` are also registered for
        auth / topic checks, subject to the ``plugins`` lists of the config's
        ``auth`` / ``topic-check`` sections.

        :param namespace: entry-point group to load.
        """
        self.logger.debug(f"Loading plugins for namespace {namespace}")
        config = self.app_context.config
        # NOTE(review): when a section is absent (or None) its filter list stays
        # empty, so no plugin is registered for it; only a section *without* a
        # ``plugins`` key selects every plugin (legacy behaviour) — confirm intent.
        auth_filter_list: list[str] | None = []
        topic_filter_list: list[str] | None = []
        if config:
            auth_section = config.get("auth")
            if auth_section is not None:
                auth_filter_list = auth_section.get("plugins", None)
            topic_section = config.get("topic-check")
            if topic_section is not None:
                topic_filter_list = topic_section.get("plugins", None)

        all_entry_points = entry_points()
        ep: EntryPoints | list[EntryPoint]
        if hasattr(all_entry_points, "select"):
            ep = all_entry_points.select(group=namespace)
        else:
            # Legacy importlib.metadata API: a mapping of group name -> tuple of entry points.
            ep = list(all_entry_points.get(namespace, ()))

        for item in ep:
            ep_plugin = self._load_ep_plugin(item)
            if ep_plugin is not None:
                self._plugins.append(ep_plugin.object)
                # maintain legacy behavior that if there is no list, use all auth plugins
                if ((auth_filter_list is None or ep_plugin.name in auth_filter_list)
                        and hasattr(ep_plugin.object, "authenticate")):
                    self._auth_plugins.append(ep_plugin.object)
                # maintain legacy behavior that if there is no list, use all topic plugins
                if ((topic_filter_list is None or ep_plugin.name in topic_filter_list)
                        and hasattr(ep_plugin.object, "topic_filtering")):
                    self._topic_plugins.append(ep_plugin.object)
                self.logger.debug(f" Plugin {item.name} ready")

    def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None:
        """Load and instantiate a single plugin from an entry point. Deprecated.

        The plugin class receives a shallow copy of the app context whose logger
        is a child logger named after the entry point.

        :raises PluginImportError: if the entry point cannot be imported.
        :raises PluginInitError: if the plugin constructor raises.
        """
        try:
            self.logger.debug(f" Loading plugin {ep!s}")
            plugin = ep.load()
        except ImportError as e:
            self.logger.debug(f"Plugin import failed: {ep!r}", exc_info=True)
            raise PluginImportError(ep) from e

        self.logger.debug(f" Initializing plugin {ep!s}")

        plugin_context = copy.copy(self.app_context)
        plugin_context.logger = self.logger.getChild(ep.name)
        try:
            obj = plugin(plugin_context)
            return Plugin(ep.name, ep, obj)
        except Exception as e:
            self.logger.debug(f"Plugin init failed: {ep!r}", exc_info=True)
            raise PluginInitError(ep) from e

    def _load_str_plugins(self, plugins_info: dict[str, Any]) -> None:
        """Instantiate every configured plugin and register it by capability.

        :param plugins_info: mapping of dotted class path -> plugin config.
        :raises PluginCoroError: if an auth/topic plugin's check method is not async.
        """
        self.logger.info("Loading plugins from config")
        # legacy had a filtering 'enabled' flag, even if plugins were loaded/listed
        self._is_topic_filtering_enabled = True
        self._is_auth_filtering_enabled = True
        for plugin_path, plugin_config in plugins_info.items():
            plugin = self._load_str_plugin(plugin_path, plugin_config)
            self._plugins.append(plugin)

            # make sure that authenticate and topic filtering plugins have the appropriate async signature
            if isinstance(plugin, BaseAuthPlugin):
                if not iscoroutinefunction(plugin.authenticate):
                    msg = f"Auth plugin {plugin_path} has non-async authenticate method."
                    raise PluginCoroError(msg)
                self._auth_plugins.append(plugin)
            if isinstance(plugin, BaseTopicPlugin):
                if not iscoroutinefunction(plugin.topic_filtering):
                    msg = f"Topic plugin {plugin_path} has non-async topic_filtering method."
                    raise PluginCoroError(msg)
                self._topic_plugins.append(plugin)

    def _load_str_plugin(self, plugin_path: str, plugin_cfg: dict[str, Any] | None = None) -> "BasePlugin[C]":
        """Load plugin from string dotted path: mymodule.myfile.MyPlugin.

        The plugin's inner ``Config`` dataclass is populated from ``plugin_cfg``
        (strictly, via dacite) and stored on a shallow copy of the app context,
        which is passed to the plugin constructor.

        :raises PluginImportError: if the path cannot be imported.
        :raises PluginLoadError: if the object is not a ``BasePlugin`` subclass or
            its config cannot be built.
        :raises PluginInitError: if the plugin constructor raises.
        """
        try:
            plugin_class: Any = import_string(plugin_path)
        except ImportError as ep:
            msg = f"Plugin import failed: {plugin_path}"
            raise PluginImportError(msg) from ep

        if not safe_issubclass(plugin_class, BasePlugin):
            msg = f"Plugin {plugin_path} is not a subclass of 'BasePlugin'"
            raise PluginLoadError(msg)

        plugin_context = copy.copy(self.app_context)
        plugin_context.logger = self.logger.getChild(plugin_class.__name__)
        try:
            # populate the config based on the inner dataclass called `Config`
            # use `dacite` package to type check
            plugin_context.config = from_dict(data_class=plugin_class.Config,
                                              data=plugin_cfg or {},
                                              config=DaciteConfig(strict=True))
        except DaciteError as e:
            msg = f"Invalid configuration for plugin {plugin_path}: {e}"
            raise PluginLoadError(msg) from e
        except TypeError as e:
            msg = f"Could not marshall 'Config' of {plugin_path}; should be a dataclass."
            raise PluginLoadError(msg) from e

        try:
            pc = plugin_class(plugin_context)
            self.logger.debug(f"Loading plugin {plugin_path}")
            return cast("BasePlugin[C]", pc)
        except Exception as e:
            self.logger.debug(f"Plugin init failed: {plugin_class.__name__}", exc_info=True)
            raise PluginInitError(plugin_class) from e

    def get_plugin(self, name: str) -> Optional["BasePlugin[C]"]:
        """Get a loaded plugin by its class name.

        Only used for testing purposes to verify plugin loading correctly.

        :param name: class name of the plugin.
        :return: the first loaded plugin whose class is named ``name``, or ``None``.
        """
        return next((p for p in self._plugins if p.__class__.__name__ == name), None)

    def is_topic_filtering_enabled(self) -> bool:
        """Return whether topic filtering is active.

        True when the config's ``topic-check`` section has a truthy ``enabled`` key,
        or when plugins were loaded from the ``plugins`` config section.
        """
        config = self.app_context.config
        topic_config = config.get("topic-check", {}) if config else {}
        if isinstance(topic_config, dict) and topic_config.get("enabled", False):
            return True
        return bool(self._is_topic_filtering_enabled)

    async def close(self) -> None:
        """Free PluginManager resources and cancel pending event methods."""
        await self.map_plugin_close()
        for task in self._fired_events:
            task.cancel()
        self._fired_events.clear()

    @property
    def plugins(self) -> list["BasePlugin[C]"]:
        """Get the loaded plugins list (the internal list, not a copy)."""
        return self._plugins

    def _schedule_coro(self, coro: Awaitable[str | bool | None]) -> asyncio.Future[str | bool | None]:
        """Wrap ``coro`` in a future scheduled on the event loop."""
        return asyncio.ensure_future(coro)

    def _clean_fired_events(self, future: asyncio.Future[Any]) -> None:
        """Done-callback for fired event tasks: report failures and forget the task.

        The task outcome is always retrieved, so a failing plugin handler never
        turns into an asyncio "exception was never retrieved" error. At DEBUG
        level the full traceback is printed to stderr. Plugin faults are reported
        but never re-raised, so they cannot bring the broker down.
        """
        debug_enabled = self.logger.getEffectiveLevel() <= logging.DEBUG
        if future.cancelled():
            if debug_enabled:
                self.logger.warning("fired event was cancelled")
        else:
            exc = future.exception()
            if exc is not None:
                if debug_enabled:
                    traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr)
                else:
                    self.logger.warning(f"plugin event handler failed: {exc!r}")

        # The list may already have been cleared by close().
        with contextlib.suppress(KeyError, ValueError):
            self._fired_events.remove(future)

    async def fire_event(self, event_name: Events, *, wait: bool = False, **method_kwargs: Any) -> None:
        """Fire an event to plugins.

        Schedules a task for every ``on_<event_name>`` coroutine registered by the
        loaded plugins; for example ``on_connect`` is called on event 'connect'.

        :param event_name: event to dispatch.
        :param wait: if True, wait until all scheduled handlers have completed;
            otherwise return right after scheduling them.
        :param method_kwargs: keyword arguments passed to every handler.
        """
        # check if any plugin has defined a callback for this event, skip if none
        callbacks = self._event_plugin_callbacks.get(event_name)
        if not callbacks:
            return

        async def call_method(method: AsyncFunc, kwargs: dict[str, Any]) -> Any:
            # Creating the coroutine inside the task means argument errors surface
            # through the task (and _clean_fired_events), not in the caller.
            return await method(**kwargs)

        tasks: list[asyncio.Future[Any]] = []
        for event_awaitable in callbacks:
            task = asyncio.ensure_future(call_method(event_awaitable, method_kwargs))
            task.add_done_callback(self._clean_fired_events)
            tasks.append(task)

        self._fired_events.extend(tasks)
        if wait:
            await asyncio.wait(tasks)
        self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}")

    @staticmethod
    async def _map_plugin_method(
        plugins: list["BasePlugin[C]"],
        method_name: str,
        method_kwargs: dict[str, Any],
    ) -> dict["BasePlugin[C]", str | bool | None]:
        """Call a coroutine method concurrently on every plugin that defines it.

        :param plugins: List of plugins to execute the method on
        :param method_name: Name of the method to call on each plugin
        :param method_kwargs: Keyword arguments to pass to the method
        :return: dict mapping each plugin that defines ``method_name`` to the value
            its call returned; plugins without the method are not included.
        """

        async def call_method(p: "BasePlugin[C]", kwargs: dict[str, Any]) -> Any:
            method = getattr(p, method_name)
            return await method(**kwargs)

        # Plugins lacking the method are skipped, so remember exactly which plugins
        # were scheduled; pairing results with `plugins` would misalign them.
        scheduled: list[BasePlugin[C]] = []
        tasks: list[asyncio.Future[Any]] = []
        for plugin in plugins:
            if not hasattr(plugin, method_name):
                continue
            scheduled.append(plugin)
            tasks.append(asyncio.ensure_future(call_method(plugin, method_kwargs)))

        if not tasks:
            return {}
        # gather() returns results in the order the tasks were given.
        results = await asyncio.gather(*tasks)
        return dict(zip(scheduled, results, strict=True))

    async def map_plugin_auth(self, *, session: Session) -> dict["BasePlugin[C]", str | bool | None]:
        """Run 'authenticate' concurrently on all auth plugins.

        :param session: the client session associated with the authentication check
        :return: dict containing return from coro call for each plugin.
        """
        return await self._map_plugin_method(
            self._auth_plugins, "authenticate", {"session": session})  # type: ignore[arg-type]

    async def map_plugin_topic(
        self, *, session: Session, topic: str, action: "Action"
    ) -> dict["BasePlugin[C]", str | bool | None]:
        """Run 'topic_filtering' concurrently on all topic plugins.

        :param session: the client session associated with the topic_filtering check
        :param topic: the topic that needs to be filtered
        :param action: the action being executed
        :return: dict containing return from coro call for each plugin.
        """
        return await self._map_plugin_method(
            self._topic_plugins, "topic_filtering",  # type: ignore[arg-type]
            {"session": session, "topic": topic, "action": action}
        )

    async def map_plugin_close(self) -> None:
        """Run 'close' concurrently on every loaded plugin that defines it; results are discarded."""
        await self._map_plugin_method(self._plugins, "close", {})