diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py
index 1f625bd1e..95e2a25c9 100644
--- a/astrbot/core/config/default.py
+++ b/astrbot/core/config/default.py
@@ -154,7 +154,8 @@ CONFIG_METADATA_2 = {
"id": {
"description": "ID",
"type": "string",
- "hint": "用于在多实例下方便管理和识别。自定义,ID 不能重复。",
+ "obvious_hint": True,
+ "hint": "ID 不能和其它的平台适配器重复,否则将发生严重冲突。",
},
"type": {
"description": "适配器类型",
@@ -630,7 +631,8 @@ CONFIG_METADATA_2 = {
"id": {
"description": "ID",
"type": "string",
- "hint": "提供商 ID 名,用于在多实例下方便管理和识别。自定义,ID 不能重复。",
+ "obvious_hint": True,
+ "hint": "ID 不能和其它的服务提供商重复,否则将发生严重冲突。",
},
"type": {
"description": "模型提供商类型",
diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py
index 0289c4d1a..7d4ec008c 100644
--- a/astrbot/core/core_lifecycle.py
+++ b/astrbot/core/core_lifecycle.py
@@ -63,9 +63,6 @@ class AstrBotCoreLifecycle:
await self.provider_manager.initialize()
'''根据配置实例化各个 Provider'''
- await self.platform_manager.initialize()
- '''根据配置实例化各个平台适配器'''
-
self.pipeline_scheduler = PipelineScheduler(PipelineContext(self.astrbot_config, self.plugin_manager))
await self.pipeline_scheduler.initialize()
'''初始化消息事件流水线调度器'''
@@ -74,19 +71,18 @@ class AstrBotCoreLifecycle:
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
self.start_time = int(time.time())
self.curr_tasks: List[asyncio.Task] = []
-
+
+ await self.platform_manager.initialize()
+ '''根据配置实例化各个平台适配器'''
+
def _load(self):
-
- platform_tasks = self.load_platform()
event_bus_task = asyncio.create_task(self.event_bus.dispatch(), name="event_bus")
extra_tasks = []
for task in self.star_context._register_tasks:
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
- # self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks]
-
- tasks_ = [event_bus_task, *platform_tasks, *extra_tasks]
+ tasks_ = [event_bus_task, *extra_tasks]
for task in tasks_:
self.curr_tasks.append(asyncio.create_task(self._task_wrapper(task), name=task.get_name()))
diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py
index 46118b554..a99fb5532 100644
--- a/astrbot/core/platform/manager.py
+++ b/astrbot/core/platform/manager.py
@@ -1,3 +1,5 @@
+import traceback
+import asyncio
from astrbot.core.config.astrbot_config import AstrBotConfig
from .platform import Platform
from typing import List
@@ -11,43 +13,103 @@ class PlatformManager():
self.platform_insts: List[Platform] = []
'''加载的 Platform 的实例'''
+ self._inst_map = {}
+
self.platforms_config = config['platform']
self.settings = config['platform_settings']
self.event_queue = event_queue
-
- try:
- for platform in self.platforms_config:
- if not platform['enable']:
- continue
- match platform['type']:
- case "aiocqhttp":
- from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401
- case "qq_official":
- from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
- case "qq_official_webhook":
- from .sources.qqofficial_webhook.qo_webhook_adapter import QQOfficialWebhookPlatformAdapter # noqa: F401
- case "gewechat":
- from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
- case "lark":
- from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
- except (ImportError, ModuleNotFoundError) as e:
- logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。")
- except Exception as e:
- logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}。")
async def initialize(self):
+ '''初始化所有平台适配器'''
for platform in self.platforms_config:
- if not platform['enable']:
- continue
- if platform['type'] not in platform_cls_map:
- logger.error(f"未找到适用于 {platform['type']}({platform['id']}) 平台适配器,请检查是否已经安装或者名称填写错误。已跳过。")
- continue
- cls_type = platform_cls_map[platform['type']]
- logger.debug(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
- inst = cls_type(platform, self.settings, self.event_queue)
- self.platform_insts.append(inst)
+ await self.load_platform(platform)
- self.platform_insts.append(WebChatAdapter({}, self.settings, self.event_queue))
+ # 网页聊天
+ webchat_inst = WebChatAdapter({}, self.settings, self.event_queue)
+ self.platform_insts.append(webchat_inst)
+ asyncio.create_task(self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat")))
+
+ async def load_platform(self, platform_config: dict):
+ '''实例化一个平台'''
+ if not platform_config['enable']:
+ return
+
+ logger.info(f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ...")
+
+ # 动态导入
+ try:
+ match platform_config['type']:
+ case "aiocqhttp":
+ from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401
+ case "qq_official":
+ from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
+ case "qq_official_webhook":
+ from .sources.qqofficial_webhook.qo_webhook_adapter import QQOfficialWebhookPlatformAdapter # noqa: F401
+ case "gewechat":
+ from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
+ case "lark":
+ from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
+ except (ImportError, ModuleNotFoundError) as e:
+ logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。")
+ except Exception as e:
+ logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。")
+
+
+ if platform_config['type'] not in platform_cls_map:
+ logger.error(f"未找到适用于 {platform_config['type']}({platform_config['id']}) 平台适配器,请检查是否已经安装或者名称填写错误。")
+ return
+ cls_type = platform_cls_map[platform_config['type']]
+ inst = cls_type(platform_config, self.settings, self.event_queue)
+ self._inst_map[platform_config['id']] = inst
+ self.platform_insts.append(inst)
+
+ asyncio.create_task(self._task_wrapper(asyncio.create_task(inst.run(), name=platform_config['id'] + "_platform")))
+
+ async def _task_wrapper(self, task: asyncio.Task):
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+ except Exception as e:
+
+ logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
+ for line in traceback.format_exc().split("\n"):
+ logger.error(f"| {line}")
+ logger.error("-------")
+
+ async def reload(self, platform_config: dict):
+ # 还未实现完成,不要调用此方法
+
+ if platform_config['id'] in self._inst_map:
+ # 正在运行
+ if getattr(self._inst_map[platform_config['id']], 'terminate', None):
+ logger.info(f"正在尝试终止 {platform_config['id']} 平台适配器 ...")
+ await self._inst_map[platform_config['id']].terminate()
+ logger.info(f"{platform_config['id']} 平台适配器已终止。")
+ del self._inst_map[platform_config['id']]
+ self.platform_insts.remove(self._inst_map[platform_config['id']])
+ else:
+ logger.warning(f"可能无法正常终止 {platform_config['id']} 平台适配器。")
+
+ # 再启动新的实例
+ await self.load_platform(platform_config)
+
+ else:
+ # 先将 _inst_map 中在 platform_config 中不存在的实例删除
+ config_ids = [platform['id'] for platform in self.platforms_config]
+ for key in list(self._inst_map.keys()):
+ if key not in config_ids:
+ if getattr(self._inst_map[key], 'terminate', None):
+ logger.info(f"正在尝试终止 {key} 平台适配器 ...")
+ await self._inst_map[key].terminate()
+ logger.info(f"{key} 平台适配器已终止。")
+ del self._inst_map[key]
+ self.platform_insts.remove(self._inst_map[key])
+ else:
+ logger.warning(f"可能无法正常终止 {key} 平台适配器。")
+
+ # 再启动新的实例
+ await self.load_platform(platform_config)
def get_insts(self):
return self.platform_insts
\ No newline at end of file
diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py
index 1dd356e89..3526d2802 100644
--- a/astrbot/core/platform/platform.py
+++ b/astrbot/core/platform/platform.py
@@ -20,6 +20,12 @@ class Platform(abc.ABC):
'''
raise NotImplementedError
+ async def terminate(self):
+ '''
+ 终止一个平台的运行实例。
+ '''
+ pass
+
@abc.abstractmethod
def meta(self) -> PlatformMetadata:
'''
diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py
index 904c5d4a6..248b799e0 100644
--- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py
+++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py
@@ -32,6 +32,8 @@ class AiocqhttpAdapter(Platform):
"适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
)
+ self.stop = False
+
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain)
match session.message_type.value:
@@ -230,11 +232,15 @@ class AiocqhttpAdapter(Platform):
return bot
+ async def terminate(self):
+ self.stop = True
+ await asyncio.sleep(1)
+
def meta(self) -> PlatformMetadata:
return self.metadata
async def shutdown_trigger_placeholder(self):
- while not self._event_queue.closed:
+ while not self._event_queue.closed and not self.stop:
await asyncio.sleep(1)
logger.info("aiocqhttp 适配器已关闭。")
diff --git a/astrbot/core/platform/sources/gewechat/client.py b/astrbot/core/platform/sources/gewechat/client.py
index 23021f1c4..34aaa7a7d 100644
--- a/astrbot/core/platform/sources/gewechat/client.py
+++ b/astrbot/core/platform/sources/gewechat/client.py
@@ -54,6 +54,8 @@ class SimpleGewechatClient():
self.multimedia_downloader = None
self.userrealnames = {}
+
+ self.stop = False
async def get_token_id(self):
async with aiohttp.ClientSession() as session:
@@ -231,7 +233,7 @@ class SimpleGewechatClient():
)
async def shutdown_trigger_placeholder(self):
- while not self.event_queue.closed:
+ while not self.event_queue.closed and not self.stop:
await asyncio.sleep(1)
logger.info("gewechat 适配器已关闭。")
diff --git a/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py b/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py
index 1ca47391e..7e325ce8b 100644
--- a/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py
+++ b/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py
@@ -47,6 +47,10 @@ class GewechatPlatformAdapter(Platform):
"基于 gewechat 的 Wechat 适配器",
)
+ async def terminate(self):
+ self.client.stop = True
+ await asyncio.sleep(1)
+
@override
def run(self):
self.client = SimpleGewechatClient(
diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py
index 70a4af389..62f79b2f8 100644
--- a/astrbot/core/provider/manager.py
+++ b/astrbot/core/provider/manager.py
@@ -1,11 +1,9 @@
import traceback
-import uuid
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider, STTProvider, TTSProvider, Personality
from .entites import ProviderType
from typing import List
from astrbot.core.db import BaseDatabase
-from collections import defaultdict
from .register import provider_cls_map, llm_tools
from astrbot.core import logger, sp
@@ -16,6 +14,14 @@ class ProviderManager():
self.provider_stt_settings: dict = config.get('provider_stt_settings', {})
self.provider_tts_settings: dict = config.get('provider_tts_settings', {})
self.persona_configs: list = config.get('persona', [])
+ self.astrbot_config = config
+
+ self.selected_provider_id = sp.get("curr_provider")
+ self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
+ self.selected_tts_provider_id = self.provider_settings.get("provider_id")
+ self.provider_enabled = self.provider_settings.get("enable", False)
+ self.stt_enabled = self.provider_stt_settings.get("enable", False)
+ self.tts_enabled = self.provider_tts_settings.get("enable", False)
# 人格情景管理
# 目前没有拆成独立的模块
@@ -75,14 +81,15 @@ class ProviderManager():
_mood_imitation_dialogs_processed=""
)
self.personas.append(self.selected_default_persona)
-
-
+
self.provider_insts: List[Provider] = []
'''加载的 Provider 的实例'''
self.stt_provider_insts: List[STTProvider] = []
'''加载的 Speech To Text Provider 的实例'''
self.tts_provider_insts: List[TTSProvider] = []
'''加载的 Text To Speech Provider 的实例'''
+ self.inst_map = {}
+ '''Provider 实例映射. key: provider_id, value: Provider 实例'''
self.llm_tools = llm_tools
self.curr_provider_inst: Provider = None
'''当前使用的 Provider 实例'''
@@ -90,7 +97,6 @@ class ProviderManager():
'''当前使用的 Speech To Text Provider 实例'''
self.curr_tts_provider_inst: TTSProvider = None
'''当前使用的 Text To Speech Provider 实例'''
- self.loaded_ids = defaultdict(bool)
self.db_helper = db_helper
# kdb(experimental)
@@ -99,145 +105,157 @@ class ProviderManager():
if kdb_cfg and len(kdb_cfg):
self.curr_kdb_name = list(kdb_cfg.keys())[0]
- changed = False
- for provider_cfg in self.providers_config:
- if not provider_cfg['enable']:
- continue
-
- if provider_cfg['id'] in self.loaded_ids:
- new_id = f"{provider_cfg['id']}_{str(uuid.uuid4())[:8]}"
- logger.info(f"Provider ID 重复:{provider_cfg['id']}。已自动更改为 {new_id}。")
- provider_cfg['id'] = new_id
- changed = True
- self.loaded_ids[provider_cfg['id']] = True
-
- try:
- match provider_cfg['type']:
- case "openai_chat_completion":
- from .sources.openai_source import ProviderOpenAIOfficial as ProviderOpenAIOfficial
- case "zhipu_chat_completion":
- from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
- case "anthropic_chat_completion":
- from .sources.anthropic_source import ProviderAnthropic as ProviderAnthropic
- case "llm_tuner":
- logger.info("加载 LLM Tuner 工具 ...")
- from .sources.llmtuner_source import LLMTunerModelLoader as LLMTunerModelLoader
- case "dify":
- from .sources.dify_source import ProviderDify as ProviderDify
- case "dashscope":
- from .sources.dashscope_source import ProviderDashscope as ProviderDashscope
- case "googlegenai_chat_completion":
- from .sources.gemini_source import ProviderGoogleGenAI as ProviderGoogleGenAI
- case "openai_whisper_api":
- from .sources.whisper_api_source import ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI
- case "openai_whisper_selfhost":
- from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost
- case "openai_tts_api":
- from .sources.openai_tts_api_source import ProviderOpenAITTSAPI as ProviderOpenAITTSAPI
- case "fishaudio_tts_api":
- from .sources.fishaudio_tts_api_source import ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI
- except (ImportError, ModuleNotFoundError) as e:
- logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
- continue
- except Exception as e:
- logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因")
- continue
-
- if changed:
- try:
- config.save_config()
- except Exception as e:
- logger.warning(f"保存配置文件失败:{e}")
-
async def initialize(self):
-
- selected_provider_id = sp.get("curr_provider")
- selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
- selected_tts_provider_id = self.provider_settings.get("provider_id")
- provider_enabled = self.provider_settings.get("enable", False)
- stt_enabled = self.provider_stt_settings.get("enable", False)
- tts_enabled = self.provider_tts_settings.get("enable", False)
-
for provider_config in self.providers_config:
- if not provider_config['enable']:
- continue
- if provider_config['type'] not in provider_cls_map:
- logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
- continue
+ await self.load_provider(provider_config)
- provider_metadata = provider_cls_map[provider_config['type']]
- logger.debug(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
- try:
- # 按任务实例化提供商
-
- if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
- # STT 任务
- inst = provider_metadata.cls_type(provider_config, self.provider_settings)
-
- if getattr(inst, "initialize", None):
- await inst.initialize()
-
- self.stt_provider_insts.append(inst)
- if selected_stt_provider_id == provider_config['id'] and stt_enabled:
- self.curr_stt_provider_inst = inst
- logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
-
- elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
- # TTS 任务
- inst = provider_metadata.cls_type(provider_config, self.provider_settings)
-
- if getattr(inst, "initialize", None):
- await inst.initialize()
-
- self.tts_provider_insts.append(inst)
- if selected_tts_provider_id == provider_config['id'] and tts_enabled:
- self.curr_tts_provider_inst = inst
- logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。")
-
- elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
- # 文本生成任务
- inst = provider_metadata.cls_type(
- provider_config,
- self.provider_settings,
- self.db_helper,
- self.provider_settings.get('persistant_history', True),
- self.selected_default_persona
- )
-
- if getattr(inst, "initialize", None):
- await inst.initialize()
-
- self.provider_insts.append(inst)
- if selected_provider_id == provider_config['id'] and provider_enabled:
- self.curr_provider_inst = inst
- logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
-
- except Exception as e:
- traceback.print_exc()
- logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}")
-
- if len(self.provider_insts) > 0 and not self.curr_provider_inst and provider_enabled:
- self.curr_provider_inst = self.provider_insts[0]
-
- if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
- self.curr_stt_provider_inst = self.stt_provider_insts[0]
-
- if len(self.tts_provider_insts) > 0 and not self.curr_tts_provider_inst and tts_enabled:
- self.curr_tts_provider_inst = self.tts_provider_insts[0]
-
if not self.curr_provider_inst:
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
- if stt_enabled and not self.curr_stt_provider_inst:
+ if self.stt_enabled and not self.curr_stt_provider_inst:
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
- if tts_enabled and not self.curr_tts_provider_inst:
+ if self.tts_enabled and not self.curr_tts_provider_inst:
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
+
+ async def load_provider(self, provider_config: dict):
+ if not provider_config['enable']:
+ return
+ logger.info(f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商适配器 ...")
+
+ # 动态导入
+ try:
+ match provider_config['type']:
+ case "openai_chat_completion":
+ from .sources.openai_source import ProviderOpenAIOfficial as ProviderOpenAIOfficial
+ case "zhipu_chat_completion":
+ from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
+ case "anthropic_chat_completion":
+ from .sources.anthropic_source import ProviderAnthropic as ProviderAnthropic
+ case "llm_tuner":
+ logger.info("加载 LLM Tuner 工具 ...")
+ from .sources.llmtuner_source import LLMTunerModelLoader as LLMTunerModelLoader
+ case "dify":
+ from .sources.dify_source import ProviderDify as ProviderDify
+ case "dashscope":
+ from .sources.dashscope_source import ProviderDashscope as ProviderDashscope
+ case "googlegenai_chat_completion":
+ from .sources.gemini_source import ProviderGoogleGenAI as ProviderGoogleGenAI
+ case "openai_whisper_api":
+ from .sources.whisper_api_source import ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI
+ case "openai_whisper_selfhost":
+ from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost
+ case "openai_tts_api":
+ from .sources.openai_tts_api_source import ProviderOpenAITTSAPI as ProviderOpenAITTSAPI
+ case "fishaudio_tts_api":
+ from .sources.fishaudio_tts_api_source import ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI
+ except (ImportError, ModuleNotFoundError) as e:
+ logger.critical(f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
+ return
+ except Exception as e:
+ logger.critical(f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因")
+ return
+
+ if provider_config['type'] not in provider_cls_map:
+ logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
+ return
+
+ provider_metadata = provider_cls_map[provider_config['type']]
+ try:
+ # 按任务实例化提供商
+
+ if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
+ # STT 任务
+ inst = provider_metadata.cls_type(provider_config, self.provider_settings)
+
+ if getattr(inst, "initialize", None):
+ await inst.initialize()
+
+ self.stt_provider_insts.append(inst)
+ if self.selected_stt_provider_id == provider_config['id'] and self.stt_enabled:
+ self.curr_stt_provider_inst = inst
+ logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
+ if not self.curr_stt_provider_inst and self.stt_enabled:
+ self.curr_stt_provider_inst = inst
+
+ elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
+ # TTS 任务
+ inst = provider_metadata.cls_type(provider_config, self.provider_settings)
+
+ if getattr(inst, "initialize", None):
+ await inst.initialize()
+
+ self.tts_provider_insts.append(inst)
+ if self.selected_tts_provider_id == provider_config['id'] and self.tts_enabled:
+ self.curr_tts_provider_inst = inst
+ logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。")
+ if not self.curr_tts_provider_inst and self.tts_enabled:
+ self.curr_tts_provider_inst = inst
+
+ elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
+ # 文本生成任务
+ inst = provider_metadata.cls_type(
+ provider_config,
+ self.provider_settings,
+ self.db_helper,
+ self.provider_settings.get('persistant_history', True),
+ self.selected_default_persona
+ )
+
+ if getattr(inst, "initialize", None):
+ await inst.initialize()
+
+ self.provider_insts.append(inst)
+ if self.selected_provider_id == provider_config['id'] and self.provider_enabled:
+ self.curr_provider_inst = inst
+ logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
+ if not self.curr_provider_inst and self.provider_enabled:
+ self.curr_provider_inst = inst
+
+ self.inst_map[provider_config['id']] = inst
+ except Exception as e:
+ traceback.print_exc()
+ logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}")
+
+ async def reload(self, provider_config: dict):
+ await self.terminate_provider(provider_config['id'])
+ if provider_config['enable']:
+ await self.load_provider(provider_config)
+
+ # 和配置文件保持同步
+ config_ids = [provider['id'] for provider in self.providers_config]
+ for key in list(self.inst_map.keys()):
+ if key not in config_ids:
+ await self.terminate_provider(key)
+
+ if len(self.provider_insts) == 0:
+ self.curr_provider_inst = None
+ if len(self.stt_provider_insts) == 0:
+ self.curr_stt_provider_inst = None
+ if len(self.tts_provider_insts) == 0:
+ self.curr_tts_provider_inst = None
def get_insts(self):
return self.provider_insts
+ async def terminate_provider(self, provider_id: str):
+ if provider_id in self.inst_map:
+
+ logger.info(f"终止 {provider_id} 提供商适配器 ...")
+
+ if self.inst_map[provider_id] in self.provider_insts:
+ self.provider_insts.remove(self.inst_map[provider_id])
+ if self.inst_map[provider_id] in self.stt_provider_insts:
+ self.stt_provider_insts.remove(self.inst_map[provider_id])
+ if self.inst_map[provider_id] in self.tts_provider_insts:
+ self.tts_provider_insts.remove(self.inst_map[provider_id])
+
+ if getattr(self.inst_map[provider_id], 'terminate', None):
+ await self.inst_map[provider_id].terminate()
+ logger.info(f"{provider_id} 提供商适配器已终止。")
+ del self.inst_map[provider_id]
+
async def terminate(self):
for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"):
diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py
index 61810705c..0e2cdc1c4 100644
--- a/astrbot/dashboard/routes/config.py
+++ b/astrbot/dashboard/routes/config.py
@@ -1,4 +1,5 @@
import typing
+import traceback
from .route import Route, Response, RouteContext
from quart import request
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
@@ -61,7 +62,7 @@ def validate_config(data, schema: dict, is_core: bool) -> typing.Tuple[typing.Li
group_meta = group.get("metadata")
if not group_meta:
continue
- logger.info(f"验证配置: 组 {key} ...")
+ # logger.info(f"验证配置: 组 {key} ...")
validate(data, group_meta, path=f"{key}.")
else:
validate(data, schema)
@@ -77,6 +78,7 @@ def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False)
else:
errors, post_config = validate_config(post_config, config.schema, is_core)
except BaseException as e:
+ logger.error(traceback.format_exc())
logger.warning(f"验证配置时出现异常: {e}")
if errors:
raise ValueError(f"格式校验未通过: {errors}")
@@ -90,6 +92,14 @@ class ConfigRoute(Route):
'/config/get': ('GET', self.get_configs),
'/config/astrbot/update': ('POST', self.post_astrbot_configs),
'/config/plugin/update': ('POST', self.post_plugin_configs),
+
+ '/config/platform/new': ('POST', self.post_new_platform),
+ '/config/platform/update': ('POST', self.post_update_platform),
+ '/config/platform/delete': ('POST', self.post_delete_platform),
+
+ '/config/provider/new': ('POST', self.post_new_provider),
+ '/config/provider/update': ('POST', self.post_update_provider),
+ '/config/provider/delete': ('POST', self.post_delete_provider)
}
self.register_routes()
@@ -118,7 +128,99 @@ class ConfigRoute(Route):
return Response().ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在重载配置。").__dict__
except Exception as e:
return Response().error(str(e)).__dict__
-
+
+ async def post_new_platform(self):
+ new_platform_config = await request.json
+ self.config['platform'].append(new_platform_config)
+ try:
+ save_config(self.config, self.config, is_core=True)
+ await self.core_lifecycle.platform_manager.load_platform(new_platform_config)
+ except Exception as e:
+ return Response().error(str(e)).__dict__
+ return Response().ok(None, "新增平台配置成功~").__dict__
+
+ async def post_new_provider(self):
+ new_provider_config = await request.json
+ self.config['provider'].append(new_provider_config)
+ try:
+ save_config(self.config, self.config, is_core=True)
+ await self.core_lifecycle.provider_manager.load_provider(new_provider_config)
+ except Exception as e:
+ return Response().error(str(e)).__dict__
+ return Response().ok(None, "新增服务提供商配置成功~").__dict__
+
+ async def post_update_platform(self):
+ update_platform_config = await request.json
+ platform_id = update_platform_config.get("id", None)
+ new_config = update_platform_config.get("config", None)
+ if not platform_id or not new_config:
+ return Response().error("参数错误").__dict__
+
+ for i, platform in enumerate(self.config['platform']):
+ if platform['id'] == platform_id:
+ self.config['platform'][i] = new_config
+ break
+ else:
+ return Response().error("未找到对应平台").__dict__
+
+ try:
+ await self._save_astrbot_configs(self.config)
+ except Exception as e:
+ return Response().error(str(e)).__dict__
+ return Response().ok(None, "更新平台配置成功~").__dict__
+
+ async def post_update_provider(self):
+ update_provider_config = await request.json
+ provider_id = update_provider_config.get("id", None)
+ new_config = update_provider_config.get("config", None)
+ if not provider_id or not new_config:
+ return Response().error("参数错误").__dict__
+
+ for i, provider in enumerate(self.config['provider']):
+ if provider['id'] == provider_id:
+ self.config['provider'][i] = new_config
+ break
+ else:
+ return Response().error("未找到对应服务提供商").__dict__
+
+ try:
+ save_config(self.config, self.config, is_core=True)
+ await self.core_lifecycle.provider_manager.reload(new_config)
+ except Exception as e:
+ return Response().error(str(e)).__dict__
+ return Response().ok(None, "更新成功,已经实时生效~").__dict__
+
+ async def post_delete_platform(self):
+ platform_id = await request.json
+ platform_id = platform_id.get("id")
+ for i, platform in enumerate(self.config['platform']):
+ if platform['id'] == platform_id:
+ del self.config['platform'][i]
+ break
+ else:
+ return Response().error("未找到对应平台").__dict__
+ try:
+ await self._save_astrbot_configs(self.config)
+ except Exception as e:
+ return Response().error(str(e)).__dict__
+ return Response().ok(None, "删除平台配置成功~").__dict__
+
+ async def post_delete_provider(self):
+ provider_id = await request.json
+ provider_id = provider_id.get("id")
+ for i, provider in enumerate(self.config['provider']):
+ if provider['id'] == provider_id:
+ del self.config['provider'][i]
+ break
+ else:
+ return Response().error("未找到对应服务提供商").__dict__
+ try:
+ save_config(self.config, self.config, is_core=True)
+ await self.core_lifecycle.provider_manager.terminate_provider(provider_id)
+ except Exception as e:
+ return Response().error(str(e)).__dict__
+ return Response().ok(None, "删除成功,已经实时生效~").__dict__
+
async def _get_astrbot_config(self):
config = self.config
diff --git a/dashboard/src/components/shared/ConsoleDisplayer.vue b/dashboard/src/components/shared/ConsoleDisplayer.vue
index 620575102..a24460e37 100644
--- a/dashboard/src/components/shared/ConsoleDisplayer.vue
+++ b/dashboard/src/components/shared/ConsoleDisplayer.vue
@@ -4,7 +4,7 @@ import { useCommonStore } from '@/stores/common';
重启成功!