import asyncio import traceback from asyncio import Queue from astrbot.core import logger from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map from .platform import Platform from .register import platform_cls_map from .sources.webchat.webchat_adapter import WebChatAdapter class PlatformManager: def __init__(self, config: AstrBotConfig, event_queue: Queue): self.platform_insts: list[Platform] = [] """加载的 Platform 的实例""" self._inst_map = {} self.platforms_config = config["platform"] self.settings = config["platform_settings"] """NOTE: 这里是 default 的配置文件,以保证最大的兼容性; 这个配置中的 unique_session 需要特殊处理, 约定整个项目中对 unique_session 的引用都从 default 的配置中获取""" self.event_queue = event_queue async def initialize(self): """初始化所有平台适配器""" for platform in self.platforms_config: try: await self.load_platform(platform) except Exception as e: logger.error(f"初始化 {platform} 平台适配器失败: {e}") # 网页聊天 webchat_inst = WebChatAdapter({}, self.settings, self.event_queue) self.platform_insts.append(webchat_inst) asyncio.create_task( self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat")), ) async def load_platform(self, platform_config: dict): """实例化一个平台""" # 动态导入 try: if not platform_config["enable"]: return logger.info( f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ...", ) match platform_config["type"]: case "aiocqhttp": from .sources.aiocqhttp.aiocqhttp_platform_adapter import ( AiocqhttpAdapter, # noqa: F401 ) case "qq_official": from .sources.qqofficial.qqofficial_platform_adapter import ( QQOfficialPlatformAdapter, # noqa: F401 ) case "qq_official_webhook": from .sources.qqofficial_webhook.qo_webhook_adapter import ( QQOfficialWebhookPlatformAdapter, # noqa: F401 ) case "wechatpadpro": from .sources.wechatpadpro.wechatpadpro_adapter import ( WeChatPadProAdapter, # noqa: F401 ) case "lark": from .sources.lark.lark_adapter import ( LarkPlatformAdapter, # noqa: F401 ) case "dingtalk": from .sources.dingtalk.dingtalk_adapter import ( DingtalkPlatformAdapter, # noqa: F401 ) case "telegram": from .sources.telegram.tg_adapter import ( TelegramPlatformAdapter, # noqa: F401 ) case "wecom": from .sources.wecom.wecom_adapter import ( WecomPlatformAdapter, # noqa: F401 ) case "wecom_ai_bot": from .sources.wecom_ai_bot.wecomai_adapter import ( WecomAIBotAdapter, # noqa: F401 ) case "weixin_official_account": from .sources.weixin_official_account.weixin_offacc_adapter import ( WeixinOfficialAccountPlatformAdapter, # noqa: F401 ) case "discord": from .sources.discord.discord_platform_adapter import ( DiscordPlatformAdapter, # noqa: F401 ) case "misskey": from .sources.misskey.misskey_adapter import ( MisskeyPlatformAdapter, # noqa: F401 ) case "slack": from .sources.slack.slack_adapter import SlackAdapter # noqa: F401 case "satori": from .sources.satori.satori_adapter import ( SatoriPlatformAdapter, # noqa: F401 ) except (ImportError, ModuleNotFoundError) as e: logger.error( f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。", ) except Exception as e: logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。") if platform_config["type"] not in platform_cls_map: logger.error( f"未找到适用于 {platform_config['type']}({platform_config['id']}) 平台适配器,请检查是否已经安装或者名称填写错误", ) return cls_type = platform_cls_map[platform_config["type"]] inst: Platform = cls_type(platform_config, self.settings, self.event_queue) self._inst_map[platform_config["id"]] = { "inst": inst, "client_id": inst.client_self_id, } self.platform_insts.append(inst) asyncio.create_task( self._task_wrapper( asyncio.create_task( inst.run(), name=f"platform_{platform_config['type']}_{platform_config['id']}", ), ), ) handlers = star_handlers_registry.get_handlers_by_event_type( EventType.OnPlatformLoadedEvent, ) for handler in handlers: try: logger.info( f"hook(on_platform_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}", ) await handler.handler() except Exception: logger.error(traceback.format_exc()) async def _task_wrapper(self, task: asyncio.Task): try: await task except asyncio.CancelledError: pass except Exception as e: logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}") for line in traceback.format_exc().split("\n"): logger.error(f"| {line}") logger.error("-------") async def reload(self, platform_config: dict): await self.terminate_platform(platform_config["id"]) if platform_config["enable"]: await self.load_platform(platform_config) # 和配置文件保持同步 config_ids = [provider["id"] for provider in self.platforms_config] for key in list(self._inst_map.keys()): if key not in config_ids: await self.terminate_platform(key) async def terminate_platform(self, platform_id: str): if platform_id in self._inst_map: logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...") # client_id = self._inst_map.pop(platform_id, None) info = self._inst_map.pop(platform_id, None) client_id = info["client_id"] inst = info["inst"] try: self.platform_insts.remove( next( inst for inst in self.platform_insts if inst.client_self_id == client_id ), ) except Exception: logger.warning(f"可能未完全移除 {platform_id} 平台适配器") if getattr(inst, "terminate", None): await inst.terminate() async def terminate(self): for inst in self.platform_insts: if getattr(inst, "terminate", None): await inst.terminate() def get_insts(self): return self.platform_insts