diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index d1089dd02..a1efc6790 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -46,6 +46,7 @@ DEFAULT_CONFIG = { "provider": [], "provider_settings": { "enable": True, + "default_provider_id": "", "wake_prefix": "", "web_search": False, "web_search_link": False, @@ -57,6 +58,7 @@ DEFAULT_CONFIG = { "dequeue_context_length": 1, "streaming_response": False, "streaming_segmented": False, + "separate_provider": False, }, "provider_stt_settings": { "enable": False, @@ -1386,9 +1388,19 @@ CONFIG_METADATA_2 = { "enable": { "description": "启用大语言模型聊天", "type": "bool", - "hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。", + "hint": "如需切换大语言模型提供商,请使用 /provider 命令。", "obvious_hint": True, }, + "separate_provider": { + "description": "提供商会话隔离", + "type": "bool", + "hint": "启用后,每个会话支持独立选择文本生成、STT、TTS 等提供商。如果会话在使用 /provider 指令时提示无权限,可以将会话加入管理员名单或者使用 /alter_cmd provider member 将指令设为非管理员指令。", + }, + "default_provider_id": { + "description": "默认模型提供商 ID", + "type": "string", + "hint": "可选。每个聊天会话的默认提供商 ID。", + }, "wake_prefix": { "description": "LLM 聊天额外唤醒前缀", "type": "string", @@ -1501,7 +1513,7 @@ CONFIG_METADATA_2 = { "obvious_hint": True, }, "provider_id": { - "description": "提供商 ID,不填则默认第一个STT提供商", + "description": "提供商 ID", "type": "string", "hint": "语音转文本提供商 ID。如果不填写将使用载入的第一个提供商。", }, @@ -1518,7 +1530,7 @@ CONFIG_METADATA_2 = { "obvious_hint": True, }, "provider_id": { - "description": "提供商 ID,不填则默认第一个TTS提供商", + "description": "提供商 ID", "type": "string", "hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。", }, diff --git a/astrbot/core/pipeline/preprocess_stage/stage.py b/astrbot/core/pipeline/preprocess_stage/stage.py index 3e89e1c3e..3e0d4e50f 100644 --- a/astrbot/core/pipeline/preprocess_stage/stage.py +++ b/astrbot/core/pipeline/preprocess_stage/stage.py @@ -43,9 +43,8 @@ class PreProcessStage(Stage): # STT if self.stt_settings.get("enable", False): # TODO: 独立 - stt_provider = ( - self.plugin_manager.context.provider_manager.curr_stt_provider_inst - ) + ctx = self.plugin_manager.context + stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin) if not stt_provider: return message_chain = event.get_messages() diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index d22a1f453..20b7651db 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -70,8 +70,8 @@ class LLMRequestSubStage(Stage): if not self.ctx.astrbot_config["provider_settings"]["enable"]: logger.debug("未启用 LLM 能力,跳过处理。") return - - provider = self.ctx.plugin_manager.context.get_using_provider() + umo = event.unified_msg_origin + provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo) if provider is None: return diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 7a12788a0..2efd98ca3 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -169,8 +169,8 @@ class ResultDecorateStage(Stage): result.chain = new_chain # TTS - tts_provider = ( - self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst + tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider( + event.unified_msg_origin ) if ( self.ctx.astrbot_config["provider_tts_settings"]["enable"] diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 160c624b0..b11a3361a 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -18,13 +18,6 @@ class ProviderManager: self.persona_configs: list = config.get("persona", []) self.astrbot_config = config - self.selected_provider_id = sp.get("curr_provider") - self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id") - self.selected_tts_provider_id = self.provider_settings.get("provider_id") - # self.provider_enabled = self.provider_settings.get("enable", False) - # self.stt_enabled = self.provider_stt_settings.get("enable", False) - # self.tts_enabled = self.provider_tts_settings.get("enable", False) - # 人格情景管理 # 目前没有拆成独立的模块 self.default_persona_name = self.provider_settings.get( @@ -103,14 +96,13 @@ class ProviderManager: self.inst_map = {} """Provider 实例映射. key: provider_id, value: Provider 实例""" self.llm_tools = llm_tools - self.default_provider_inst: Provider = None - """默认的 Provider 实例。第 0 个或者用户以前指定的 Provider 实例""" + self.curr_provider_inst: Provider = None - """当前使用的 Provider 实例""" + """默认的 Provider 实例""" self.curr_stt_provider_inst: STTProvider = None - """当前使用的 Speech To Text Provider 实例""" + """默认的 Speech To Text Provider 实例""" self.curr_tts_provider_inst: TTSProvider = None - """当前使用的 Text To Speech Provider 实例""" + """默认的 Text To Speech Provider 实例""" self.db_helper = db_helper # kdb(experimental) @@ -119,13 +111,57 @@ class ProviderManager: if kdb_cfg and len(kdb_cfg): self.curr_kdb_name = list(kdb_cfg.keys())[0] + async def set_provider( + self, provider_id: str, provider_type: ProviderType, umo: str = None + ): + """设置提供商。 + + Args: + provider_id (str): 提供商 ID。 + provider_type (ProviderType): 提供商类型。 + umo (str, optional): 用户会话 ID,用于提供商会话隔离。当用户启用了提供商会话隔离时此参数才生效。 + """ + if provider_id not in self.inst_map: + raise ValueError(f"提供商 {provider_id} 不存在,无法设置。") + if umo and self.provider_settings["separate_provider"]: + perf = sp.get("session_provider_perf", {}) + session_perf = perf.get(umo, {}) + session_perf[provider_type.value] = provider_id + perf[umo] = session_perf + sp.put("session_provider_perf", perf) + return + # 不启用提供商会话隔离模式的情况 + self.curr_provider_inst = self.inst_map[provider_id] + if provider_type == ProviderType.TEXT_TO_SPEECH: + sp.put("curr_provider_tts", provider_id) + elif provider_type == ProviderType.SPEECH_TO_TEXT: + sp.put("curr_provider_stt", provider_id) + elif provider_type == ProviderType.CHAT_COMPLETION: + sp.put("curr_provider", provider_id) + async def initialize(self): + # 逐个初始化提供商 for provider_config in self.providers_config: await self.load_provider(provider_config) - self.default_provider_inst = self.inst_map.get(self.selected_provider_id) - if not self.default_provider_inst and self.provider_insts: - self.default_provider_inst = self.provider_insts[0] + # 设置默认提供商 + self.curr_provider_inst = self.inst_map.get( + self.provider_settings.get("default_provider_id") + ) + if not self.curr_provider_inst and self.provider_insts: + self.curr_provider_inst = self.provider_insts[0] + + self.curr_stt_provider_inst = self.inst_map.get( + self.provider_stt_settings.get("provider_id") + ) + if not self.curr_stt_provider_inst and self.stt_provider_insts: + self.curr_stt_provider_inst = self.stt_provider_insts[0] + + self.curr_tts_provider_inst = self.inst_map.get( + self.provider_tts_settings.get("provider_id") + ) + if not self.curr_tts_provider_inst and self.tts_provider_insts: + self.curr_tts_provider_inst = self.tts_provider_insts[0] # 初始化 MCP Client 连接 asyncio.create_task( @@ -252,7 +288,10 @@ class ProviderManager: await inst.initialize() self.stt_provider_insts.append(inst) - if self.selected_stt_provider_id == provider_config["id"]: + if ( + self.provider_stt_settings.get("provider_id") + == provider_config["id"] + ): self.curr_stt_provider_inst = inst logger.info( f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。" @@ -270,7 +309,7 @@ class ProviderManager: await inst.initialize() self.tts_provider_insts.append(inst) - if self.selected_tts_provider_id == provider_config["id"]: + if self.provider_settings.get("provider_id") == provider_config["id"]: self.curr_tts_provider_inst = inst logger.info( f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。" @@ -292,7 +331,10 @@ class ProviderManager: await inst.initialize() self.provider_insts.append(inst) - if self.selected_provider_id == provider_config["id"]: + if ( + self.provider_settings.get("default_provider_id") + == provider_config["id"] + ): self.curr_provider_inst = inst logger.info( f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。" @@ -330,7 +372,6 @@ class ProviderManager: self.curr_provider_inst = None elif self.curr_provider_inst is None and len(self.provider_insts) > 0: self.curr_provider_inst = self.provider_insts[0] - self.selected_provider_id = self.curr_provider_inst.meta().id logger.info( f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。" ) @@ -339,7 +380,6 @@ class ProviderManager: self.curr_stt_provider_inst = None elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0: self.curr_stt_provider_inst = self.stt_provider_insts[0] - self.selected_stt_provider_id = self.curr_stt_provider_inst.meta().id logger.info( f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。" ) @@ -348,7 +388,6 @@ class ProviderManager: self.curr_tts_provider_inst = None elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0: self.curr_tts_provider_inst = self.tts_provider_insts[0] - self.selected_tts_provider_id = self.curr_tts_provider_inst.meta().id logger.info( f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。" ) diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 880b0c72c..dda664b1b 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -3,6 +3,7 @@ from typing import List, Union from astrbot.core import sp from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider +from astrbot.core.provider.entities import ProviderType from astrbot.core.db import BaseDatabase from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.provider.func_tool_manager import FuncCall @@ -140,24 +141,46 @@ class Context: """获取所有用于 STT 任务的 Provider。""" return self.provider_manager.stt_provider_insts - def get_using_provider(self) -> Provider: + def get_using_provider(self, umo: str = None) -> Provider: """ - 获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。 + 获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。 - 通过 /provider 指令切换。 + Args: + umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。 """ + if umo and self._config["provider_settings"]["separate_provider"]: + perf = sp.get("session_provider_perf", {}) + prov_id = perf.get(umo, {}).get(ProviderType.CHAT_COMPLETION.value, None) + if inst := self.provider_manager.inst_map.get(prov_id, None): + return inst return self.provider_manager.curr_provider_inst - def get_using_tts_provider(self) -> TTSProvider: + def get_using_tts_provider(self, umo: str = None) -> TTSProvider: """ 获取当前使用的用于 TTS 任务的 Provider。 + + Args: + umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ + if umo and self._config["provider_settings"]["separate_provider"]: + perf = sp.get("session_provider_perf", {}) + prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None) + if inst := self.provider_manager.inst_map.get(prov_id, None): + return inst return self.provider_manager.curr_tts_provider_inst - def get_using_stt_provider(self) -> STTProvider: + def get_using_stt_provider(self, umo: str = None) -> STTProvider: """ 获取当前使用的用于 STT 任务的 Provider。 + + Args: + umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ + if umo and self._config["provider_settings"]["separate_provider"]: + perf = sp.get("session_provider_perf", {}) + prov_id = perf.get(umo, {}).get(ProviderType.SPEECH_TO_TEXT.value, None) + if inst := self.provider_manager.inst_map.get(prov_id, None): + return inst return self.provider_manager.curr_stt_provider_inst def get_config(self) -> AstrBotConfig: diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 4fd783a67..739f12b5c 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -12,6 +12,7 @@ from astrbot.api import sp from astrbot.api.provider import ProviderRequest from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.message_type import MessageType +from astrbot.core.provider.entities import ProviderType from astrbot.core.provider.sources.dify_source import ProviderDify from astrbot.core.utils.io import download_dashboard, get_dashboard_version from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -139,6 +140,7 @@ class Main(star.Star): {notice}""" event.set_result(MessageEventResult().message(msg).use_t2i(False)) + @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("llm") async def llm(self, event: AstrMessageEvent): @@ -413,20 +415,21 @@ UID: {user_id} 此 ID 可用于设置管理员。 event.set_result(MessageEventResult().message("删除白名单成功。")) except ValueError: event.set_result(MessageEventResult().message("此 SID 不在白名单内。")) - + @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("provider") async def provider( self, event: AstrMessageEvent, idx: Union[str, int] = None, idx2: int = None ): """查看或者切换 LLM Provider""" + umo = event.unified_msg_origin if idx is None: ret = "## 载入的 LLM 提供商\n" for idx, llm in enumerate(self.context.get_all_providers()): id_ = llm.meta().id ret += f"{idx + 1}. {id_} ({llm.meta().model})" - provider_using = self.context.get_using_provider() + provider_using = self.context.get_using_provider(umo=umo) if provider_using and provider_using.meta().id == id_: ret += " (当前使用)" ret += "\n" @@ -437,7 +440,7 @@ UID: {user_id} 此 ID 可用于设置管理员。 for idx, tts in enumerate(tts_providers): id_ = tts.meta().id ret += f"{idx + 1}. {id_}" - tts_using = self.context.get_using_tts_provider() + tts_using = self.context.get_using_tts_provider(umo=umo) if tts_using and tts_using.meta().id == id_: ret += " (当前使用)" ret += "\n" @@ -448,7 +451,7 @@ UID: {user_id} 此 ID 可用于设置管理员。 for idx, stt in enumerate(stt_providers): id_ = stt.meta().id ret += f"{idx + 1}. {id_}" - stt_using = self.context.get_using_stt_provider() + stt_using = self.context.get_using_stt_provider(umo=umo) if stt_using and stt_using.meta().id == id_: ret += " (当前使用)" ret += "\n" @@ -461,46 +464,54 @@ UID: {user_id} 此 ID 可用于设置管理员。 ret += "\n使用 /provider stt <切换> STT 提供商。" event.set_result(MessageEventResult().message(ret)) - else: - if idx == "tts": - if idx2 is None: - event.set_result(MessageEventResult().message("请输入序号。")) - return - else: - if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1: - event.set_result(MessageEventResult().message("无效的序号。")) - provider = self.context.get_all_tts_providers()[idx2 - 1] - id_ = provider.meta().id - self.context.provider_manager.curr_tts_provider_inst = provider - sp.put("curr_provider_tts", id_) - event.set_result( - MessageEventResult().message(f"成功切换到 {id_}。") - ) - elif idx == "stt": - if idx2 is None: - event.set_result(MessageEventResult().message("请输入序号。")) - return - else: - if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1: - event.set_result(MessageEventResult().message("无效的序号。")) - provider = self.context.get_all_stt_providers()[idx2 - 1] - id_ = provider.meta().id - self.context.provider_manager.curr_stt_provider_inst = provider - sp.put("curr_provider_stt", id_) - event.set_result( - MessageEventResult().message(f"成功切换到 {id_}。") - ) - elif isinstance(idx, int): - if idx > len(self.context.get_all_providers()) or idx < 1: - event.set_result(MessageEventResult().message("无效的序号。")) - - provider = self.context.get_all_providers()[idx - 1] - id_ = provider.meta().id - self.context.provider_manager.curr_provider_inst = provider - sp.put("curr_provider", id_) - event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) + elif idx == "tts": + if idx2 is None: + event.set_result(MessageEventResult().message("请输入序号。")) + return else: - event.set_result(MessageEventResult().message("无效的参数。")) + if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1: + event.set_result(MessageEventResult().message("无效的序号。")) + provider = self.context.get_all_tts_providers()[idx2 - 1] + id_ = provider.meta().id + await self.context.provider_manager.set_provider( + provider_id=id_, + provider_type=ProviderType.TEXT_TO_SPEECH, + umo=umo, + ) + event.set_result( + MessageEventResult().message(f"成功切换到 {id_}。") + ) + elif idx == "stt": + if idx2 is None: + event.set_result(MessageEventResult().message("请输入序号。")) + return + else: + if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1: + event.set_result(MessageEventResult().message("无效的序号。")) + provider = self.context.get_all_stt_providers()[idx2 - 1] + id_ = provider.meta().id + await self.context.provider_manager.set_provider( + provider_id=id_, + provider_type=ProviderType.SPEECH_TO_TEXT, + umo=umo, + ) + event.set_result( + MessageEventResult().message(f"成功切换到 {id_}。") + ) + elif isinstance(idx, int): + if idx > len(self.context.get_all_providers()) or idx < 1: + event.set_result(MessageEventResult().message("无效的序号。")) + + provider = self.context.get_all_providers()[idx - 1] + id_ = provider.meta().id + await self.context.provider_manager.set_provider( + provider_id=id_, + provider_type=ProviderType.CHAT_COMPLETION, + umo=umo, + ) + event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) + else: + event.set_result(MessageEventResult().message("无效的参数。")) @filter.command("reset") async def reset(self, message: AstrMessageEvent): @@ -572,7 +583,7 @@ UID: {user_id} 此 ID 可用于设置管理员。 ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。" message.set_result(MessageEventResult().message(ret)) - + @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("model") async def model_ls(