From da14a894902f25f62b493bc9ee4ef91e82d7eb17 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 23 Feb 2025 12:54:25 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=8D=BA=20refactor:=20=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E6=9B=B4=E5=A4=A7=E8=8C=83=E5=9B=B4=E7=9A=84=E7=83=AD=E9=87=8D?= =?UTF-8?q?=E8=BD=BD=E4=BB=A5=E5=8F=8A=E7=AE=A1=E7=90=86=E9=9D=A2=E6=9D=BF?= =?UTF-8?q?=E5=B0=86=E5=B9=B3=E5=8F=B0=E5=92=8C=E6=8F=90=E4=BE=9B=E5=95=86?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E7=8B=AC=E7=AB=8B=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 6 +- astrbot/core/core_lifecycle.py | 14 +- astrbot/core/platform/manager.py | 121 ++++++-- astrbot/core/platform/platform.py | 6 + .../aiocqhttp/aiocqhttp_platform_adapter.py | 8 +- .../core/platform/sources/gewechat/client.py | 4 +- .../gewechat/gewechat_platform_adapter.py | 4 + astrbot/core/provider/manager.py | 276 +++++++++--------- astrbot/dashboard/routes/config.py | 104 ++++++- .../components/shared/WaitingForRestart.vue | 19 +- .../full/vertical-sidebar/sidebarItem.ts | 12 +- dashboard/src/router/MainRoutes.ts | 11 +- dashboard/src/views/ConfigPage.vue | 5 + dashboard/src/views/PlatformPage.vue | 225 ++++++++++++++ dashboard/src/views/ProviderPage.vue | 221 ++++++++++++++ 15 files changed, 844 insertions(+), 192 deletions(-) create mode 100644 dashboard/src/views/PlatformPage.vue create mode 100644 dashboard/src/views/ProviderPage.vue diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 1f625bd1e..95e2a25c9 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -154,7 +154,8 @@ CONFIG_METADATA_2 = { "id": { "description": "ID", "type": "string", - "hint": "用于在多实例下方便管理和识别。自定义,ID 不能重复。", + "obvious_hint": True, + "hint": "ID 不能和其它的平台适配器重复,否则将发生严重冲突。", }, "type": { "description": "适配器类型", @@ -630,7 +631,8 @@ CONFIG_METADATA_2 = { "id": { "description": "ID", "type": "string", - "hint": "提供商 ID 名,用于在多实例下方便管理和识别。自定义,ID 不能重复。", + "obvious_hint": True, + "hint": "ID 不能和其它的服务提供商重复,否则将发生严重冲突。", }, "type": { "description": "模型提供商类型", diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 0289c4d1a..7d4ec008c 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -63,9 +63,6 @@ class AstrBotCoreLifecycle: await self.provider_manager.initialize() '''根据配置实例化各个 Provider''' - await self.platform_manager.initialize() - '''根据配置实例化各个平台适配器''' - self.pipeline_scheduler = PipelineScheduler(PipelineContext(self.astrbot_config, self.plugin_manager)) await self.pipeline_scheduler.initialize() '''初始化消息事件流水线调度器''' @@ -74,19 +71,18 @@ class AstrBotCoreLifecycle: self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler) self.start_time = int(time.time()) self.curr_tasks: List[asyncio.Task] = [] - + + await self.platform_manager.initialize() + '''根据配置实例化各个平台适配器''' + def _load(self): - - platform_tasks = self.load_platform() event_bus_task = asyncio.create_task(self.event_bus.dispatch(), name="event_bus") extra_tasks = [] for task in self.star_context._register_tasks: extra_tasks.append(asyncio.create_task(task, name=task.__name__)) - # self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks] - - tasks_ = [event_bus_task, *platform_tasks, *extra_tasks] + tasks_ = [event_bus_task, *extra_tasks] for task in tasks_: self.curr_tasks.append(asyncio.create_task(self._task_wrapper(task), name=task.get_name())) diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 46118b554..6af0eda7d 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -1,3 +1,5 @@ +import traceback +import asyncio from astrbot.core.config.astrbot_config import AstrBotConfig from .platform import Platform from typing import List @@ -11,43 +13,102 @@ class PlatformManager(): self.platform_insts: List[Platform] = [] '''加载的 Platform 的实例''' + self._inst_map = {} + self.platforms_config = config['platform'] self.settings = config['platform_settings'] self.event_queue = event_queue - - try: - for platform in self.platforms_config: - if not platform['enable']: - continue - match platform['type']: - case "aiocqhttp": - from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401 - case "qq_official": - from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401 - case "qq_official_webhook": - from .sources.qqofficial_webhook.qo_webhook_adapter import QQOfficialWebhookPlatformAdapter # noqa: F401 - case "gewechat": - from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401 - case "lark": - from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401 - except (ImportError, ModuleNotFoundError) as e: - logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。") - except Exception as e: - logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}。") async def initialize(self): + '''初始化所有平台适配器''' for platform in self.platforms_config: - if not platform['enable']: - continue - if platform['type'] not in platform_cls_map: - logger.error(f"未找到适用于 {platform['type']}({platform['id']}) 平台适配器,请检查是否已经安装或者名称填写错误。已跳过。") - continue - cls_type = platform_cls_map[platform['type']] - logger.debug(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...") - inst = cls_type(platform, self.settings, self.event_queue) - self.platform_insts.append(inst) + await self.load_platform(platform) - self.platform_insts.append(WebChatAdapter({}, self.settings, self.event_queue)) + # 网页聊天 + webchat_inst = WebChatAdapter({}, self.settings, self.event_queue) + self.platform_insts.append(webchat_inst) + asyncio.create_task(self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat"))) + + async def load_platform(self, platform_config: dict): + '''实例化一个平台''' + if not platform_config['enable']: + return + + # 动态导入 + try: + match platform_config['type']: + case "aiocqhttp": + from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401 + case "qq_official": + from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401 + case "qq_official_webhook": + from .sources.qqofficial_webhook.qo_webhook_adapter import QQOfficialWebhookPlatformAdapter # noqa: F401 + case "gewechat": + from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401 + case "lark": + from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401 + except (ImportError, ModuleNotFoundError) as e: + logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。") + except Exception as e: + logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。") + + + if platform_config['type'] not in platform_cls_map: + logger.error(f"未找到适用于 {platform_config['type']}({platform_config['id']}) 平台适配器,请检查是否已经安装或者名称填写错误。") + return + cls_type = platform_cls_map[platform_config['type']] + logger.debug(f"尝试实例化 {platform_config['type']}({platform_config['id']}) 平台适配器 ...") + inst = cls_type(platform_config, self.settings, self.event_queue) + self._inst_map[platform_config['id']] = inst + self.platform_insts.append(inst) + + asyncio.create_task(self._task_wrapper(asyncio.create_task(inst.run(), name=platform_config['id'] + "_platform"))) + + async def _task_wrapper(self, task: asyncio.Task): + try: + await task + except asyncio.CancelledError: + pass + except Exception as e: + + logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}") + for line in traceback.format_exc().split("\n"): + logger.error(f"| {line}") + logger.error("-------") + + async def reload(self, platform_config: dict): + # 还未实现完成,不要调用此方法 + + if platform_config['id'] in self._inst_map: + # 正在运行 + if getattr(self._inst_map[platform_config['id']], 'terminate', None): + logger.info(f"正在尝试终止 {platform_config['id']} 平台适配器 ...") + await self._inst_map[platform_config['id']].terminate() + logger.info(f"{platform_config['id']} 平台适配器已终止。") + del self._inst_map[platform_config['id']] + self.platform_insts.remove(self._inst_map[platform_config['id']]) + else: + logger.warning(f"可能无法正常终止 {platform_config['id']} 平台适配器。") + + # 再启动新的实例 + await self.load_platform(platform_config) + + else: + # 先将 _inst_map 中在 platform_config 中不存在的实例删除 + config_ids = [platform['id'] for platform in self.platforms_config] + for key in list(self._inst_map.keys()): + if key not in config_ids: + if getattr(self._inst_map[key], 'terminate', None): + logger.info(f"正在尝试终止 {key} 平台适配器 ...") + await self._inst_map[key].terminate() + logger.info(f"{key} 平台适配器已终止。") + del self._inst_map[key] + self.platform_insts.remove(self._inst_map[key]) + else: + logger.warning(f"可能无法正常终止 {key} 平台适配器。") + + # 再启动新的实例 + await self.load_platform(platform_config) def get_insts(self): return self.platform_insts \ No newline at end of file diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index 1dd356e89..3526d2802 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -20,6 +20,12 @@ class Platform(abc.ABC): ''' raise NotImplementedError + async def terminate(self): + ''' + 终止一个平台的运行实例。 + ''' + pass + @abc.abstractmethod def meta(self) -> PlatformMetadata: ''' diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index 904c5d4a6..248b799e0 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -32,6 +32,8 @@ class AiocqhttpAdapter(Platform): "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。", ) + self.stop = False + async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain) match session.message_type.value: @@ -230,11 +232,15 @@ class AiocqhttpAdapter(Platform): return bot + async def terminate(self): + self.stop = True + await asyncio.sleep(1) + def meta(self) -> PlatformMetadata: return self.metadata async def shutdown_trigger_placeholder(self): - while not self._event_queue.closed: + while not self._event_queue.closed and not self.stop: await asyncio.sleep(1) logger.info("aiocqhttp 适配器已关闭。") diff --git a/astrbot/core/platform/sources/gewechat/client.py b/astrbot/core/platform/sources/gewechat/client.py index 23021f1c4..34aaa7a7d 100644 --- a/astrbot/core/platform/sources/gewechat/client.py +++ b/astrbot/core/platform/sources/gewechat/client.py @@ -54,6 +54,8 @@ class SimpleGewechatClient(): self.multimedia_downloader = None self.userrealnames = {} + + self.stop = False async def get_token_id(self): async with aiohttp.ClientSession() as session: @@ -231,7 +233,7 @@ class SimpleGewechatClient(): ) async def shutdown_trigger_placeholder(self): - while not self.event_queue.closed: + while not self.event_queue.closed and not self.stop: await asyncio.sleep(1) logger.info("gewechat 适配器已关闭。") diff --git a/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py b/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py index 1ca47391e..7e325ce8b 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py @@ -47,6 +47,10 @@ class GewechatPlatformAdapter(Platform): "基于 gewechat 的 Wechat 适配器", ) + async def terminate(self): + self.client.stop = True + await asyncio.sleep(1) + @override def run(self): self.client = SimpleGewechatClient( diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 70a4af389..3546f249f 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -1,11 +1,9 @@ import traceback -import uuid from astrbot.core.config.astrbot_config import AstrBotConfig from .provider import Provider, STTProvider, TTSProvider, Personality from .entites import ProviderType from typing import List from astrbot.core.db import BaseDatabase -from collections import defaultdict from .register import provider_cls_map, llm_tools from astrbot.core import logger, sp @@ -16,6 +14,14 @@ class ProviderManager(): self.provider_stt_settings: dict = config.get('provider_stt_settings', {}) self.provider_tts_settings: dict = config.get('provider_tts_settings', {}) self.persona_configs: list = config.get('persona', []) + self.astrbot_config = config + + self.selected_provider_id = sp.get("curr_provider") + self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id") + self.selected_tts_provider_id = self.provider_settings.get("provider_id") + self.provider_enabled = self.provider_settings.get("enable", False) + self.stt_enabled = self.provider_stt_settings.get("enable", False) + self.tts_enabled = self.provider_tts_settings.get("enable", False) # 人格情景管理 # 目前没有拆成独立的模块 @@ -75,14 +81,15 @@ class ProviderManager(): _mood_imitation_dialogs_processed="" ) self.personas.append(self.selected_default_persona) - - + self.provider_insts: List[Provider] = [] '''加载的 Provider 的实例''' self.stt_provider_insts: List[STTProvider] = [] '''加载的 Speech To Text Provider 的实例''' self.tts_provider_insts: List[TTSProvider] = [] '''加载的 Text To Speech Provider 的实例''' + self.inst_map = {} + '''Provider 实例映射. key: provider_id, value: Provider 实例''' self.llm_tools = llm_tools self.curr_provider_inst: Provider = None '''当前使用的 Provider 实例''' @@ -90,7 +97,6 @@ class ProviderManager(): '''当前使用的 Speech To Text Provider 实例''' self.curr_tts_provider_inst: TTSProvider = None '''当前使用的 Text To Speech Provider 实例''' - self.loaded_ids = defaultdict(bool) self.db_helper = db_helper # kdb(experimental) @@ -99,145 +105,155 @@ class ProviderManager(): if kdb_cfg and len(kdb_cfg): self.curr_kdb_name = list(kdb_cfg.keys())[0] - changed = False - for provider_cfg in self.providers_config: - if not provider_cfg['enable']: - continue - - if provider_cfg['id'] in self.loaded_ids: - new_id = f"{provider_cfg['id']}_{str(uuid.uuid4())[:8]}" - logger.info(f"Provider ID 重复:{provider_cfg['id']}。已自动更改为 {new_id}。") - provider_cfg['id'] = new_id - changed = True - self.loaded_ids[provider_cfg['id']] = True - - try: - match provider_cfg['type']: - case "openai_chat_completion": - from .sources.openai_source import ProviderOpenAIOfficial as ProviderOpenAIOfficial - case "zhipu_chat_completion": - from .sources.zhipu_source import ProviderZhipu as ProviderZhipu - case "anthropic_chat_completion": - from .sources.anthropic_source import ProviderAnthropic as ProviderAnthropic - case "llm_tuner": - logger.info("加载 LLM Tuner 工具 ...") - from .sources.llmtuner_source import LLMTunerModelLoader as LLMTunerModelLoader - case "dify": - from .sources.dify_source import ProviderDify as ProviderDify - case "dashscope": - from .sources.dashscope_source import ProviderDashscope as ProviderDashscope - case "googlegenai_chat_completion": - from .sources.gemini_source import ProviderGoogleGenAI as ProviderGoogleGenAI - case "openai_whisper_api": - from .sources.whisper_api_source import ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI - case "openai_whisper_selfhost": - from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost - case "openai_tts_api": - from .sources.openai_tts_api_source import ProviderOpenAITTSAPI as ProviderOpenAITTSAPI - case "fishaudio_tts_api": - from .sources.fishaudio_tts_api_source import ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI - except (ImportError, ModuleNotFoundError) as e: - logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。") - continue - except Exception as e: - logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因") - continue - - if changed: - try: - config.save_config() - except Exception as e: - logger.warning(f"保存配置文件失败:{e}") - async def initialize(self): - - selected_provider_id = sp.get("curr_provider") - selected_stt_provider_id = self.provider_stt_settings.get("provider_id") - selected_tts_provider_id = self.provider_settings.get("provider_id") - provider_enabled = self.provider_settings.get("enable", False) - stt_enabled = self.provider_stt_settings.get("enable", False) - tts_enabled = self.provider_tts_settings.get("enable", False) - for provider_config in self.providers_config: - if not provider_config['enable']: - continue - if provider_config['type'] not in provider_cls_map: - logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。") - continue + await self.load_provider(provider_config) - provider_metadata = provider_cls_map[provider_config['type']] - logger.debug(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...") - try: - # 按任务实例化提供商 - - if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT: - # STT 任务 - inst = provider_metadata.cls_type(provider_config, self.provider_settings) - - if getattr(inst, "initialize", None): - await inst.initialize() - - self.stt_provider_insts.append(inst) - if selected_stt_provider_id == provider_config['id'] and stt_enabled: - self.curr_stt_provider_inst = inst - logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。") - - elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH: - # TTS 任务 - inst = provider_metadata.cls_type(provider_config, self.provider_settings) - - if getattr(inst, "initialize", None): - await inst.initialize() - - self.tts_provider_insts.append(inst) - if selected_tts_provider_id == provider_config['id'] and tts_enabled: - self.curr_tts_provider_inst = inst - logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。") - - elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION: - # 文本生成任务 - inst = provider_metadata.cls_type( - provider_config, - self.provider_settings, - self.db_helper, - self.provider_settings.get('persistant_history', True), - self.selected_default_persona - ) - - if getattr(inst, "initialize", None): - await inst.initialize() - - self.provider_insts.append(inst) - if selected_provider_id == provider_config['id'] and provider_enabled: - self.curr_provider_inst = inst - logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。") - - except Exception as e: - traceback.print_exc() - logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}") - - if len(self.provider_insts) > 0 and not self.curr_provider_inst and provider_enabled: - self.curr_provider_inst = self.provider_insts[0] - - if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled: - self.curr_stt_provider_inst = self.stt_provider_insts[0] - - if len(self.tts_provider_insts) > 0 and not self.curr_tts_provider_inst and tts_enabled: - self.curr_tts_provider_inst = self.tts_provider_insts[0] - if not self.curr_provider_inst: logger.warning("未启用任何用于 文本生成 的提供商适配器。") - if stt_enabled and not self.curr_stt_provider_inst: + if self.stt_enabled and not self.curr_stt_provider_inst: logger.warning("未启用任何用于 语音转文本 的提供商适配器。") - if tts_enabled and not self.curr_tts_provider_inst: + if self.tts_enabled and not self.curr_tts_provider_inst: logger.warning("未启用任何用于 文本转语音 的提供商适配器。") + + async def load_provider(self, provider_config: dict): + if not provider_config['enable']: + return + # 动态导入 + try: + match provider_config['type']: + case "openai_chat_completion": + from .sources.openai_source import ProviderOpenAIOfficial as ProviderOpenAIOfficial + case "zhipu_chat_completion": + from .sources.zhipu_source import ProviderZhipu as ProviderZhipu + case "anthropic_chat_completion": + from .sources.anthropic_source import ProviderAnthropic as ProviderAnthropic + case "llm_tuner": + logger.info("加载 LLM Tuner 工具 ...") + from .sources.llmtuner_source import LLMTunerModelLoader as LLMTunerModelLoader + case "dify": + from .sources.dify_source import ProviderDify as ProviderDify + case "dashscope": + from .sources.dashscope_source import ProviderDashscope as ProviderDashscope + case "googlegenai_chat_completion": + from .sources.gemini_source import ProviderGoogleGenAI as ProviderGoogleGenAI + case "openai_whisper_api": + from .sources.whisper_api_source import ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI + case "openai_whisper_selfhost": + from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost + case "openai_tts_api": + from .sources.openai_tts_api_source import ProviderOpenAITTSAPI as ProviderOpenAITTSAPI + case "fishaudio_tts_api": + from .sources.fishaudio_tts_api_source import ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI + except (ImportError, ModuleNotFoundError) as e: + logger.critical(f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。") + return + except Exception as e: + logger.critical(f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因") + return + + if provider_config['type'] not in provider_cls_map: + logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。") + return + + provider_metadata = provider_cls_map[provider_config['type']] + logger.debug(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...") + try: + # 按任务实例化提供商 + + if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT: + # STT 任务 + inst = provider_metadata.cls_type(provider_config, self.provider_settings) + + if getattr(inst, "initialize", None): + await inst.initialize() + + self.stt_provider_insts.append(inst) + if self.selected_stt_provider_id == provider_config['id'] and self.stt_enabled: + self.curr_stt_provider_inst = inst + logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。") + if not self.curr_stt_provider_inst and self.stt_enabled: + self.curr_stt_provider_inst = inst + + elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH: + # TTS 任务 + inst = provider_metadata.cls_type(provider_config, self.provider_settings) + + if getattr(inst, "initialize", None): + await inst.initialize() + + self.tts_provider_insts.append(inst) + if self.selected_tts_provider_id == provider_config['id'] and self.tts_enabled: + self.curr_tts_provider_inst = inst + logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。") + if not self.curr_tts_provider_inst and self.tts_enabled: + self.curr_tts_provider_inst = inst + + elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION: + # 文本生成任务 + inst = provider_metadata.cls_type( + provider_config, + self.provider_settings, + self.db_helper, + self.provider_settings.get('persistant_history', True), + self.selected_default_persona + ) + + if getattr(inst, "initialize", None): + await inst.initialize() + + self.provider_insts.append(inst) + if self.selected_provider_id == provider_config['id'] and self.provider_enabled: + self.curr_provider_inst = inst + logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。") + if not self.curr_provider_inst and self.provider_enabled: + self.curr_provider_inst = inst + + self.inst_map[provider_config['id']] = inst + except Exception as e: + traceback.print_exc() + logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}") + + async def reload(self, provider_config: dict): + await self.terminate_provider(provider_config['id']) + if provider_config['enable']: + await self.load_provider(provider_config) + + # 和配置文件保持同步 + config_ids = [provider['id'] for provider in self.providers_config] + for key in list(self.inst_map.keys()): + if key not in config_ids: + await self.terminate_provider(key) + + if len(self.provider_insts) == 0: + self.curr_provider_inst = None + if len(self.stt_provider_insts) == 0: + self.curr_stt_provider_inst = None + if len(self.tts_provider_insts) == 0: + self.curr_tts_provider_inst = None def get_insts(self): return self.provider_insts + async def terminate_provider(self, provider_id: str): + if provider_id in self.inst_map: + + if self.inst_map[provider_id] in self.provider_insts: + self.provider_insts.remove(self.inst_map[provider_id]) + if self.inst_map[provider_id] in self.stt_provider_insts: + self.stt_provider_insts.remove(self.inst_map[provider_id]) + if self.inst_map[provider_id] in self.tts_provider_insts: + self.tts_provider_insts.remove(self.inst_map[provider_id]) + + if getattr(self.inst_map[provider_id], 'terminate', None): + logger.info(f"正在尝试终止 {provider_id} 提供商适配器 ...") + await self.inst_map[provider_id].terminate() + logger.info(f"{provider_id} 提供商适配器已终止。") + del self.inst_map[provider_id] + async def terminate(self): for provider_inst in self.provider_insts: if hasattr(provider_inst, "terminate"): diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 61810705c..9208ffd12 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1,4 +1,5 @@ import typing +import traceback from .route import Route, Response, RouteContext from quart import request from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP @@ -77,6 +78,7 @@ def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False) else: errors, post_config = validate_config(post_config, config.schema, is_core) except BaseException as e: + logger.error(traceback.format_exc()) logger.warning(f"验证配置时出现异常: {e}") if errors: raise ValueError(f"格式校验未通过: {errors}") @@ -90,6 +92,14 @@ class ConfigRoute(Route): '/config/get': ('GET', self.get_configs), '/config/astrbot/update': ('POST', self.post_astrbot_configs), '/config/plugin/update': ('POST', self.post_plugin_configs), + + '/config/platform/new': ('POST', self.post_new_platform), + '/config/platform/update': ('POST', self.post_update_platform), + '/config/platform/delete': ('POST', self.post_delete_platform), + + '/config/provider/new': ('POST', self.post_new_provider), + '/config/provider/update': ('POST', self.post_update_provider), + '/config/provider/delete': ('POST', self.post_delete_provider) } self.register_routes() @@ -118,7 +128,99 @@ class ConfigRoute(Route): return Response().ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在重载配置。").__dict__ except Exception as e: return Response().error(str(e)).__dict__ - + + async def post_new_platform(self): + new_platform_config = await request.json + self.config['platform'].append(new_platform_config) + try: + save_config(self.config, self.config, is_core=True) + await self.core_lifecycle.platform_manager.load_platform(new_platform_config) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "新增平台配置成功~").__dict__ + + async def post_new_provider(self): + new_provider_config = await request.json + self.config['provider'].append(new_provider_config) + try: + save_config(self.config, self.config, is_core=True) + await self.core_lifecycle.provider_manager.load_provider(new_provider_config) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "新增服务提供商配置成功~").__dict__ + + async def post_update_platform(self): + update_platform_config = await request.json + platform_id = update_platform_config.get("id", None) + new_config = update_platform_config.get("config", None) + if not platform_id or not new_config: + return Response().error("参数错误").__dict__ + + for i, platform in enumerate(self.config['platform']): + if platform['id'] == platform_id: + self.config['platform'][i] = new_config + break + else: + return Response().error("未找到对应平台").__dict__ + + try: + await self._save_astrbot_configs(self.config) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "更新平台配置成功~").__dict__ + + async def post_update_provider(self): + update_provider_config = await request.json + provider_id = update_provider_config.get("id", None) + new_config = update_provider_config.get("config", None) + if not provider_id or not new_config: + return Response().error("参数错误").__dict__ + + for i, provider in enumerate(self.config['provider']): + if provider['id'] == provider_id: + self.config['provider'][i] = new_config + break + else: + return Response().error("未找到对应服务提供商").__dict__ + + try: + save_config(self.config, self.config, is_core=True) + await self.core_lifecycle.provider_manager.reload(new_config) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "更新成功,已经实时生效~").__dict__ + + async def post_delete_platform(self): + platform_id = await request.json + platform_id = platform_id.get("id") + for i, platform in enumerate(self.config['platform']): + if platform['id'] == platform_id: + del self.config['platform'][i] + break + else: + return Response().error("未找到对应平台").__dict__ + try: + await self._save_astrbot_configs(self.config) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "删除平台配置成功~").__dict__ + + async def post_delete_provider(self): + provider_id = await request.json + provider_id = provider_id.get("id") + for i, provider in enumerate(self.config['provider']): + if provider['id'] == provider_id: + del self.config['provider'][i] + break + else: + return Response().error("未找到对应服务提供商").__dict__ + try: + save_config(self.config, self.config, is_core=True) + await self.core_lifecycle.provider_manager.terminate_provider(provider_id) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "删除成功,已经实时生效~").__dict__ + async def _get_astrbot_config(self): config = self.config diff --git a/dashboard/src/components/shared/WaitingForRestart.vue b/dashboard/src/components/shared/WaitingForRestart.vue index 3fad97832..088b1ee6e 100644 --- a/dashboard/src/components/shared/WaitingForRestart.vue +++ b/dashboard/src/components/shared/WaitingForRestart.vue @@ -4,17 +4,6 @@ 正在等待 AstrBot 重启... -
-
- -

重启成功!

-
- 当前实例标识:{{ startTime }} - 检查到新实例:{{ newStartTime }},即将自动刷新页面 - {{ status }} - 尝试次数:{{ cnt }} / 60 -
-
@@ -73,11 +62,9 @@ export default { if (this.newStartTime !== this.startTime) { this.newStartTime = newStartTime console.log('wfr: restarted') - setTimeout(() => { - this.visible = false - // reload - window.location.reload() - }, 2000) + this.visible = false + // reload + window.location.reload() } return this.newStartTime } diff --git a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts index 0f6f69111..1addccca6 100644 --- a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts +++ b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts @@ -21,7 +21,17 @@ const sidebarItem: menu[] = [ to: '/dashboard/default' }, { - title: '配置文件', + title: '消息平台', + icon: 'mdi-message-processing', + to: '/platforms', + }, + { + title: '服务提供商', + icon: 'mdi-creation', + to: '/providers', + }, + { + title: '配置', icon: 'mdi-cog', to: '/config', }, diff --git a/dashboard/src/router/MainRoutes.ts b/dashboard/src/router/MainRoutes.ts index bd407ad26..006d67687 100644 --- a/dashboard/src/router/MainRoutes.ts +++ b/dashboard/src/router/MainRoutes.ts @@ -16,12 +16,21 @@ const MainRoutes = { path: '/extension', component: () => import('@/views/ExtensionPage.vue') }, + { + name: 'Platforms', + path: '/platforms', + component: () => import('@/views/PlatformPage.vue') + }, + { + name: 'Providers', + path: '/providers', + component: () => import('@/views/ProviderPage.vue') + }, { name: 'Configs', path: '/config', component: () => import('@/views/ConfigPage.vue') }, - { name: 'Default', path: '/dashboard/default', diff --git a/dashboard/src/views/ConfigPage.vue b/dashboard/src/views/ConfigPage.vue index 7405063b5..3a6f9fa45 100644 --- a/dashboard/src/views/ConfigPage.vue +++ b/dashboard/src/views/ConfigPage.vue @@ -44,6 +44,11 @@ import config from '@/config'; + + + 消息平台适配器和服务提供商的配置已经迁移至更方便的独立页面!推荐前往左栏配置哦~ + + {{ item[metadata[key]['metadata'][key2]?.tmpl_display_title] }} diff --git a/dashboard/src/views/PlatformPage.vue b/dashboard/src/views/PlatformPage.vue new file mode 100644 index 000000000..4732fdf2d --- /dev/null +++ b/dashboard/src/views/PlatformPage.vue @@ -0,0 +1,225 @@ + + + + \ No newline at end of file diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue new file mode 100644 index 000000000..8d30a06b3 --- /dev/null +++ b/dashboard/src/views/ProviderPage.vue @@ -0,0 +1,221 @@ + + + + \ No newline at end of file