diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 6af27337b..6523a31ee 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -42,10 +42,12 @@ DEFAULT_CONFIG = { "empty_mention_waiting": True, "friend_message_needs_wake_prefix": False, "ignore_bot_self_message": False, + "seperate_provider": False, }, "provider": [], "provider_settings": { "enable": True, + "default_provider_id": "", "wake_prefix": "", "web_search": False, "web_search_link": False, @@ -1379,9 +1381,14 @@ CONFIG_METADATA_2 = { "enable": { "description": "启用大语言模型聊天", "type": "bool", - "hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。", + "hint": "如需切换大语言模型提供商,请使用 /provider 命令。", "obvious_hint": True, }, + "default_provider_id": { + "description": "默认模型提供商 ID", + "type": "str", + "hint": "可选。每个聊天会话的默认提供商 ID。" + }, "wake_prefix": { "description": "LLM 聊天额外唤醒前缀", "type": "string", diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index d22a1f453..fc70689e3 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -52,6 +52,9 @@ class LLMRequestSubStage(Stage): self.streaming_response = ctx.astrbot_config["provider_settings"][ "streaming_response" ] # bool + self.seperate_provider = ctx.astrbot_config["provider_settings"][ + "seperate_provider" + ] # bool for bwp in self.bot_wake_prefixs: if self.provider_wake_prefix.startswith(bwp): @@ -70,8 +73,10 @@ class LLMRequestSubStage(Stage): if not self.ctx.astrbot_config["provider_settings"]["enable"]: logger.debug("未启用 LLM 能力,跳过处理。") return - - provider = self.ctx.plugin_manager.context.get_using_provider() + umo = None + if self.seperate_provider: + umo = event.unified_msg_origin + provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo) if provider is None: return diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index f653a9fb9..90fac5734 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -6,7 +6,7 @@ from .method.star_request import StarRequestSubStage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.star_handler import StarHandlerMetadata from astrbot.core.provider.entities import ProviderRequest -from astrbot.core import logger +from astrbot.core import logger, sp @register_stage diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index edfd9f581..7180cf454 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -18,12 +18,9 @@ class ProviderManager: self.persona_configs: list = config.get("persona", []) self.astrbot_config = config - self.selected_provider_id = sp.get("curr_provider") + self.selected_provider_id = self.provider_settings.get("default_provider_id") self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id") self.selected_tts_provider_id = self.provider_settings.get("provider_id") - # self.provider_enabled = self.provider_settings.get("enable", False) - # self.stt_enabled = self.provider_stt_settings.get("enable", False) - # self.tts_enabled = self.provider_tts_settings.get("enable", False) # 人格情景管理 # 目前没有拆成独立的模块 @@ -103,12 +100,10 @@ class ProviderManager: self.inst_map = {} """Provider 实例映射. key: provider_id, value: Provider 实例""" self.llm_tools = llm_tools - self.default_provider_inst: Provider = None - """默认的 Provider 实例。第 0 个或者用户以前指定的 Provider 实例""" self.curr_provider_inst: Provider = None - """当前使用的 Provider 实例""" + """默认设置的 Provider 实例""" self.curr_stt_provider_inst: STTProvider = None - """当前使用的 Speech To Text Provider 实例""" + """默认设置的 Speech To Text Provider 实例""" self.curr_tts_provider_inst: TTSProvider = None """当前使用的 Text To Speech Provider 实例""" self.db_helper = db_helper @@ -123,9 +118,10 @@ class ProviderManager: for provider_config in self.providers_config: await self.load_provider(provider_config) - self.default_provider_inst = self.inst_map.get(self.selected_provider_id) - if not self.default_provider_inst and self.provider_insts: - self.default_provider_inst = self.provider_insts[0] + self.curr_provider_inst = self.inst_map.get(self.selected_provider_id) + if not self.curr_provider_inst and self.provider_insts: + self.curr_provider_inst = self.provider_insts[0] + # 初始化 MCP Client 连接 asyncio.create_task( diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 880b0c72c..d0ec99173 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -140,24 +140,49 @@ class Context: """获取所有用于 STT 任务的 Provider。""" return self.provider_manager.stt_provider_insts - def get_using_provider(self) -> Provider: + def get_using_provider(self, umo: str = None) -> Provider: """ - 获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。 + 获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。 - 通过 /provider 指令切换。 + Args: + umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ + if umo: + perf = sp.get("session_provider_perf", {}) + provider_id = perf.get(umo, None) + inst = self.provider_manager.inst_map.get(provider_id, None) + if inst: + return inst return self.provider_manager.curr_provider_inst - def get_using_tts_provider(self) -> TTSProvider: + def get_using_tts_provider(self, umo: str = None) -> TTSProvider: """ 获取当前使用的用于 TTS 任务的 Provider。 + + Args: + umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ + if umo: + perf = sp.get("session_tts_provider_perf", {}) + provider_id = perf.get(umo, None) + inst = self.provider_manager.inst_map.get(provider_id, None) + if inst: + return inst return self.provider_manager.curr_tts_provider_inst - def get_using_stt_provider(self) -> STTProvider: + def get_using_stt_provider(self, umo: str = None) -> STTProvider: """ 获取当前使用的用于 STT 任务的 Provider。 + + Args: + umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ + if umo: + perf = sp.get("session_stt_provider_perf", {}) + provider_id = perf.get(umo, None) + inst = self.provider_manager.inst_map.get(provider_id, None) + if inst: + return inst return self.provider_manager.curr_stt_provider_inst def get_config(self) -> AstrBotConfig: