From 98d2e9bd270210f91ebaf6be8c023d41d2875100 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 5 Jun 2025 23:30:18 +0800 Subject: [PATCH 1/9] chore: stage --- astrbot/core/config/default.py | 9 ++++- .../process_stage/method/llm_request.py | 9 +++-- astrbot/core/pipeline/process_stage/stage.py | 2 +- astrbot/core/provider/manager.py | 18 ++++------ astrbot/core/star/context.py | 35 ++++++++++++++++--- 5 files changed, 53 insertions(+), 20 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 6af27337b..6523a31ee 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -42,10 +42,12 @@ DEFAULT_CONFIG = { "empty_mention_waiting": True, "friend_message_needs_wake_prefix": False, "ignore_bot_self_message": False, + "seperate_provider": False, }, "provider": [], "provider_settings": { "enable": True, + "default_provider_id": "", "wake_prefix": "", "web_search": False, "web_search_link": False, @@ -1379,9 +1381,14 @@ CONFIG_METADATA_2 = { "enable": { "description": "启用大语言模型聊天", "type": "bool", - "hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。", + "hint": "如需切换大语言模型提供商,请使用 /provider 命令。", "obvious_hint": True, }, + "default_provider_id": { + "description": "默认模型提供商 ID", + "type": "str", + "hint": "可选。每个聊天会话的默认提供商 ID。" + }, "wake_prefix": { "description": "LLM 聊天额外唤醒前缀", "type": "string", diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index d22a1f453..fc70689e3 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -52,6 +52,9 @@ class LLMRequestSubStage(Stage): self.streaming_response = ctx.astrbot_config["provider_settings"][ "streaming_response" ] # bool + self.seperate_provider = ctx.astrbot_config["provider_settings"][ + "seperate_provider" + ] # bool for bwp in self.bot_wake_prefixs: if self.provider_wake_prefix.startswith(bwp): @@ -70,8 +73,10 @@ class LLMRequestSubStage(Stage): if not self.ctx.astrbot_config["provider_settings"]["enable"]: logger.debug("未启用 LLM 能力,跳过处理。") return - - provider = self.ctx.plugin_manager.context.get_using_provider() + umo = None + if self.seperate_provider: + umo = event.unified_msg_origin + provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo) if provider is None: return diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index f653a9fb9..90fac5734 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -6,7 +6,7 @@ from .method.star_request import StarRequestSubStage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.star_handler import StarHandlerMetadata from astrbot.core.provider.entities import ProviderRequest -from astrbot.core import logger +from astrbot.core import logger, sp @register_stage diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index edfd9f581..7180cf454 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -18,12 +18,9 @@ class ProviderManager: self.persona_configs: list = config.get("persona", []) self.astrbot_config = config - self.selected_provider_id = sp.get("curr_provider") + self.selected_provider_id = self.provider_settings.get("default_provider_id") self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id") self.selected_tts_provider_id = self.provider_settings.get("provider_id") - # self.provider_enabled = self.provider_settings.get("enable", False) - # self.stt_enabled = self.provider_stt_settings.get("enable", False) - # self.tts_enabled = self.provider_tts_settings.get("enable", False) # 人格情景管理 # 目前没有拆成独立的模块 @@ -103,12 +100,10 @@ class ProviderManager: self.inst_map = {} """Provider 实例映射. key: provider_id, value: Provider 实例""" self.llm_tools = llm_tools - self.default_provider_inst: Provider = None - """默认的 Provider 实例。第 0 个或者用户以前指定的 Provider 实例""" self.curr_provider_inst: Provider = None - """当前使用的 Provider 实例""" + """默认设置的 Provider 实例""" self.curr_stt_provider_inst: STTProvider = None - """当前使用的 Speech To Text Provider 实例""" + """默认设置的 Speech To Text Provider 实例""" self.curr_tts_provider_inst: TTSProvider = None """当前使用的 Text To Speech Provider 实例""" self.db_helper = db_helper @@ -123,9 +118,10 @@ class ProviderManager: for provider_config in self.providers_config: await self.load_provider(provider_config) - self.default_provider_inst = self.inst_map.get(self.selected_provider_id) - if not self.default_provider_inst and self.provider_insts: - self.default_provider_inst = self.provider_insts[0] + self.curr_provider_inst = self.inst_map.get(self.selected_provider_id) + if not self.curr_provider_inst and self.provider_insts: + self.curr_provider_inst = self.provider_insts[0] + # 初始化 MCP Client 连接 asyncio.create_task( diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 880b0c72c..d0ec99173 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -140,24 +140,49 @@ class Context: """获取所有用于 STT 任务的 Provider。""" return self.provider_manager.stt_provider_insts - def get_using_provider(self) -> Provider: + def get_using_provider(self, umo: str = None) -> Provider: """ - 获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。 + 获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。 - 通过 /provider 指令切换。 + Args: + umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ + if umo: + perf = sp.get("session_provider_perf", {}) + provider_id = perf.get(umo, None) + inst = self.provider_manager.inst_map.get(provider_id, None) + if inst: + return inst return self.provider_manager.curr_provider_inst - def get_using_tts_provider(self) -> TTSProvider: + def get_using_tts_provider(self, umo: str = None) -> TTSProvider: """ 获取当前使用的用于 TTS 任务的 Provider。 + + Args: + umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ + if umo: + perf = sp.get("session_tts_provider_perf", {}) + provider_id = perf.get(umo, None) + inst = self.provider_manager.inst_map.get(provider_id, None) + if inst: + return inst return self.provider_manager.curr_tts_provider_inst - def get_using_stt_provider(self) -> STTProvider: + def get_using_stt_provider(self, umo: str = None) -> STTProvider: """ 获取当前使用的用于 STT 任务的 Provider。 + + Args: + umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ + if umo: + perf = sp.get("session_stt_provider_perf", {}) + provider_id = perf.get(umo, None) + inst = self.provider_manager.inst_map.get(provider_id, None) + if inst: + return inst return self.provider_manager.curr_stt_provider_inst def get_config(self) -> AstrBotConfig: From 621b556856dab964da0c460921eeb9739dbe3728 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 9 Jun 2025 23:33:00 +0800 Subject: [PATCH 2/9] =?UTF-8?q?=E2=9C=A8=20feat:=20=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=AF=B9=E6=8F=90=E4=BE=9B=E5=95=86=E4=BC=9A=E8=AF=9D=E9=9A=94?= =?UTF-8?q?=E7=A6=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fixes: #1762 #602 #479 --- astrbot/core/config/default.py | 15 +++-- .../core/pipeline/preprocess_stage/stage.py | 5 +- .../process_stage/method/llm_request.py | 7 +- .../core/pipeline/result_decorate/stage.py | 4 +- astrbot/core/provider/manager.py | 65 +++++++++++++++---- astrbot/core/star/context.py | 25 +++---- packages/astrbot/main.py | 34 ++++++---- 7 files changed, 102 insertions(+), 53 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 6523a31ee..092971460 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -42,7 +42,6 @@ DEFAULT_CONFIG = { "empty_mention_waiting": True, "friend_message_needs_wake_prefix": False, "ignore_bot_self_message": False, - "seperate_provider": False, }, "provider": [], "provider_settings": { @@ -59,6 +58,7 @@ DEFAULT_CONFIG = { "dequeue_context_length": 1, "streaming_response": False, "streaming_segmented": False, + "seperate_provider": False, }, "provider_stt_settings": { "enable": False, @@ -1384,10 +1384,15 @@ CONFIG_METADATA_2 = { "hint": "如需切换大语言模型提供商,请使用 /provider 命令。", "obvious_hint": True, }, + "seperate_provider": { + "description": "提供商会话隔离", + "type": "bool", + "hint": "启用后,每个会话支持独立选择文本生成、STT、TTS 等提供商。如果会话在使用 /provider 指令时提示无权限,可以将会话加入管理员名单或者使用 /alter_cmd provider member 将指令设为非管理员指令。", + }, "default_provider_id": { "description": "默认模型提供商 ID", - "type": "str", - "hint": "可选。每个聊天会话的默认提供商 ID。" + "type": "string", + "hint": "可选。每个聊天会话的默认提供商 ID。", }, "wake_prefix": { "description": "LLM 聊天额外唤醒前缀", @@ -1501,7 +1506,7 @@ CONFIG_METADATA_2 = { "obvious_hint": True, }, "provider_id": { - "description": "提供商 ID,不填则默认第一个STT提供商", + "description": "提供商 ID", "type": "string", "hint": "语音转文本提供商 ID。如果不填写将使用载入的第一个提供商。", }, @@ -1518,7 +1523,7 @@ CONFIG_METADATA_2 = { "obvious_hint": True, }, "provider_id": { - "description": "提供商 ID,不填则默认第一个TTS提供商", + "description": "提供商 ID", "type": "string", "hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。", }, diff --git a/astrbot/core/pipeline/preprocess_stage/stage.py b/astrbot/core/pipeline/preprocess_stage/stage.py index 3e89e1c3e..3e0d4e50f 100644 --- a/astrbot/core/pipeline/preprocess_stage/stage.py +++ b/astrbot/core/pipeline/preprocess_stage/stage.py @@ -43,9 +43,8 @@ class PreProcessStage(Stage): # STT if self.stt_settings.get("enable", False): # TODO: 独立 - stt_provider = ( - self.plugin_manager.context.provider_manager.curr_stt_provider_inst - ) + ctx = self.plugin_manager.context + stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin) if not stt_provider: return message_chain = event.get_messages() diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index fc70689e3..20b7651db 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -52,9 +52,6 @@ class LLMRequestSubStage(Stage): self.streaming_response = ctx.astrbot_config["provider_settings"][ "streaming_response" ] # bool - self.seperate_provider = ctx.astrbot_config["provider_settings"][ - "seperate_provider" - ] # bool for bwp in self.bot_wake_prefixs: if self.provider_wake_prefix.startswith(bwp): @@ -73,9 +70,7 @@ class LLMRequestSubStage(Stage): if not self.ctx.astrbot_config["provider_settings"]["enable"]: logger.debug("未启用 LLM 能力,跳过处理。") return - umo = None - if self.seperate_provider: - umo = event.unified_msg_origin + umo = event.unified_msg_origin provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo) if provider is None: return diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 7a12788a0..2efd98ca3 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -169,8 +169,8 @@ class ResultDecorateStage(Stage): result.chain = new_chain # TTS - tts_provider = ( - self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst + tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider( + event.unified_msg_origin ) if ( self.ctx.astrbot_config["provider_tts_settings"]["enable"] diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 7180cf454..d8bc302d1 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -18,10 +18,6 @@ class ProviderManager: self.persona_configs: list = config.get("persona", []) self.astrbot_config = config - self.selected_provider_id = self.provider_settings.get("default_provider_id") - self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id") - self.selected_tts_provider_id = self.provider_settings.get("provider_id") - # 人格情景管理 # 目前没有拆成独立的模块 self.default_persona_name = self.provider_settings.get( @@ -100,12 +96,13 @@ class ProviderManager: self.inst_map = {} """Provider 实例映射. key: provider_id, value: Provider 实例""" self.llm_tools = llm_tools + self.curr_provider_inst: Provider = None - """默认设置的 Provider 实例""" + """默认的 Provider 实例""" self.curr_stt_provider_inst: STTProvider = None - """默认设置的 Speech To Text Provider 实例""" + """默认的 Speech To Text Provider 实例""" self.curr_tts_provider_inst: TTSProvider = None - """当前使用的 Text To Speech Provider 实例""" + """默认的 Text To Speech Provider 实例""" self.db_helper = db_helper # kdb(experimental) @@ -114,14 +111,51 @@ class ProviderManager: if kdb_cfg and len(kdb_cfg): self.curr_kdb_name = list(kdb_cfg.keys())[0] + async def set_provider( + self, provider_id: str, provider_type: ProviderType, umo: str = None + ): + """设置提供商""" + if provider_id not in self.inst_map: + raise ValueError(f"提供商 {provider_id} 不存在,无法设置。") + if umo: + perf = sp.get("session_provider_perf", {}) + session_perf = perf.get(umo, {}) + session_perf[provider_type.value] = provider_id + perf[umo] = session_perf + sp.put("session_provider_perf", perf) + return + # 不启用提供商会话隔离模式的情况 + self.curr_provider_inst = self.inst_map[provider_id] + if provider_type == ProviderType.TEXT_TO_SPEECH: + sp.put("curr_provider_tts", provider_id) + elif provider_type == ProviderType.SPEECH_TO_TEXT: + sp.put("curr_provider_stt", provider_id) + elif provider_type == ProviderType.CHAT_COMPLETION: + sp.put("curr_provider", provider_id) + async def initialize(self): + # 逐个初始化提供商 for provider_config in self.providers_config: await self.load_provider(provider_config) - self.curr_provider_inst = self.inst_map.get(self.selected_provider_id) + # 设置默认提供商 + self.curr_provider_inst = self.inst_map.get( + self.provider_settings.get("default_provider_id") + ) if not self.curr_provider_inst and self.provider_insts: self.curr_provider_inst = self.provider_insts[0] + self.curr_stt_provider_inst = self.inst_map.get( + self.provider_stt_settings.get("provider_id") + ) + if not self.curr_stt_provider_inst and self.stt_provider_insts: + self.curr_stt_provider_inst = self.stt_provider_insts[0] + + self.curr_tts_provider_inst = self.inst_map.get( + self.provider_settings.get("provider_id") + ) + if not self.curr_tts_provider_inst and self.tts_provider_insts: + self.curr_tts_provider_inst = self.tts_provider_insts[0] # 初始化 MCP Client 连接 asyncio.create_task( @@ -244,7 +278,10 @@ class ProviderManager: await inst.initialize() self.stt_provider_insts.append(inst) - if self.selected_stt_provider_id == provider_config["id"]: + if ( + self.provider_stt_settings.get("provider_id") + == provider_config["id"] + ): self.curr_stt_provider_inst = inst logger.info( f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。" @@ -262,7 +299,7 @@ class ProviderManager: await inst.initialize() self.tts_provider_insts.append(inst) - if self.selected_tts_provider_id == provider_config["id"]: + if self.provider_settings.get("provider_id") == provider_config["id"]: self.curr_tts_provider_inst = inst logger.info( f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。" @@ -284,7 +321,10 @@ class ProviderManager: await inst.initialize() self.provider_insts.append(inst) - if self.selected_provider_id == provider_config["id"]: + if ( + self.provider_settings.get("default_provider_id") + == provider_config["id"] + ): self.curr_provider_inst = inst logger.info( f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。" @@ -322,7 +362,6 @@ class ProviderManager: self.curr_provider_inst = None elif self.curr_provider_inst is None and len(self.provider_insts) > 0: self.curr_provider_inst = self.provider_insts[0] - self.selected_provider_id = self.curr_provider_inst.meta().id logger.info( f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。" ) @@ -331,7 +370,6 @@ class ProviderManager: self.curr_stt_provider_inst = None elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0: self.curr_stt_provider_inst = self.stt_provider_insts[0] - self.selected_stt_provider_id = self.curr_stt_provider_inst.meta().id logger.info( f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。" ) @@ -340,7 +378,6 @@ class ProviderManager: self.curr_tts_provider_inst = None elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0: self.curr_tts_provider_inst = self.tts_provider_insts[0] - self.selected_tts_provider_id = self.curr_tts_provider_inst.meta().id logger.info( f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。" ) diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index d0ec99173..d985a98d5 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -3,6 +3,7 @@ from typing import List, Union from astrbot.core import sp from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider +from astrbot.core.provider.entities import ProviderType from astrbot.core.db import BaseDatabase from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.provider.func_tool_manager import FuncCall @@ -145,12 +146,12 @@ class Context: 获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。 Args: - umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 + umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。 """ - if umo: + if umo and self._config["provider_settings"]["seperate_provider"]: perf = sp.get("session_provider_perf", {}) - provider_id = perf.get(umo, None) - inst = self.provider_manager.inst_map.get(provider_id, None) + prov_id = perf.get(umo, {}).get(ProviderType.CHAT_COMPLETION.value, None) + inst = self.provider_manager.inst_map.get(prov_id, None) if inst: return inst return self.provider_manager.curr_provider_inst @@ -162,10 +163,10 @@ class Context: Args: umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ - if umo: - perf = sp.get("session_tts_provider_perf", {}) - provider_id = perf.get(umo, None) - inst = self.provider_manager.inst_map.get(provider_id, None) + if umo and self._config["provider_settings"]["seperate_provider"]: + perf = sp.get("session_provider_perf", {}) + prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None) + inst = self.provider_manager.inst_map.get(prov_id, None) if inst: return inst return self.provider_manager.curr_tts_provider_inst @@ -177,10 +178,10 @@ class Context: Args: umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ - if umo: - perf = sp.get("session_stt_provider_perf", {}) - provider_id = perf.get(umo, None) - inst = self.provider_manager.inst_map.get(provider_id, None) + if umo and self._config["provider_settings"]["seperate_provider"]: + perf = sp.get("session_provider_perf", {}) + prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None) + inst = self.provider_manager.inst_map.get(prov_id, None) if inst: return inst return self.provider_manager.curr_stt_provider_inst diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 4fd783a67..d4557e6b8 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -12,6 +12,7 @@ from astrbot.api import sp from astrbot.api.provider import ProviderRequest from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.message_type import MessageType +from astrbot.core.provider.entities import ProviderType from astrbot.core.provider.sources.dify_source import ProviderDify from astrbot.core.utils.io import download_dashboard, get_dashboard_version from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -139,6 +140,7 @@ class Main(star.Star): {notice}""" event.set_result(MessageEventResult().message(msg).use_t2i(False)) + @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("llm") async def llm(self, event: AstrMessageEvent): @@ -413,20 +415,21 @@ UID: {user_id} 此 ID 可用于设置管理员。 event.set_result(MessageEventResult().message("删除白名单成功。")) except ValueError: event.set_result(MessageEventResult().message("此 SID 不在白名单内。")) - + @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("provider") async def provider( self, event: AstrMessageEvent, idx: Union[str, int] = None, idx2: int = None ): """查看或者切换 LLM Provider""" + umo = event.unified_msg_origin if idx is None: ret = "## 载入的 LLM 提供商\n" for idx, llm in enumerate(self.context.get_all_providers()): id_ = llm.meta().id ret += f"{idx + 1}. {id_} ({llm.meta().model})" - provider_using = self.context.get_using_provider() + provider_using = self.context.get_using_provider(umo=umo) if provider_using and provider_using.meta().id == id_: ret += " (当前使用)" ret += "\n" @@ -437,7 +440,7 @@ UID: {user_id} 此 ID 可用于设置管理员。 for idx, tts in enumerate(tts_providers): id_ = tts.meta().id ret += f"{idx + 1}. {id_}" - tts_using = self.context.get_using_tts_provider() + tts_using = self.context.get_using_tts_provider(umo=umo) if tts_using and tts_using.meta().id == id_: ret += " (当前使用)" ret += "\n" @@ -448,7 +451,7 @@ UID: {user_id} 此 ID 可用于设置管理员。 for idx, stt in enumerate(stt_providers): id_ = stt.meta().id ret += f"{idx + 1}. {id_}" - stt_using = self.context.get_using_stt_provider() + stt_using = self.context.get_using_stt_provider(umo=umo) if stt_using and stt_using.meta().id == id_: ret += " (当前使用)" ret += "\n" @@ -471,8 +474,11 @@ UID: {user_id} 此 ID 可用于设置管理员。 event.set_result(MessageEventResult().message("无效的序号。")) provider = self.context.get_all_tts_providers()[idx2 - 1] id_ = provider.meta().id - self.context.provider_manager.curr_tts_provider_inst = provider - sp.put("curr_provider_tts", id_) + await self.context.provider_manager.set_provider( + provider_id=id_, + provider_type=ProviderType.TEXT_TO_SPEECH, + umo=umo, + ) event.set_result( MessageEventResult().message(f"成功切换到 {id_}。") ) @@ -485,8 +491,11 @@ UID: {user_id} 此 ID 可用于设置管理员。 event.set_result(MessageEventResult().message("无效的序号。")) provider = self.context.get_all_stt_providers()[idx2 - 1] id_ = provider.meta().id - self.context.provider_manager.curr_stt_provider_inst = provider - sp.put("curr_provider_stt", id_) + await self.context.provider_manager.set_provider( + provider_id=id_, + provider_type=ProviderType.SPEECH_TO_TEXT, + umo=umo, + ) event.set_result( MessageEventResult().message(f"成功切换到 {id_}。") ) @@ -496,8 +505,11 @@ UID: {user_id} 此 ID 可用于设置管理员。 provider = self.context.get_all_providers()[idx - 1] id_ = provider.meta().id - self.context.provider_manager.curr_provider_inst = provider - sp.put("curr_provider", id_) + await self.context.provider_manager.set_provider( + provider_id=id_, + provider_type=ProviderType.CHAT_COMPLETION, + umo=umo, + ) event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) else: event.set_result(MessageEventResult().message("无效的参数。")) @@ -572,7 +584,7 @@ UID: {user_id} 此 ID 可用于设置管理员。 ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。" message.set_result(MessageEventResult().message(ret)) - + @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("model") async def model_ls( From ffb5605c999224575d674f9011db4c2f41422a86 Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Mon, 9 Jun 2025 23:38:15 +0800 Subject: [PATCH 3/9] fix: default tts provider selection Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- astrbot/core/provider/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index d8bc302d1..8baec0cb5 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -152,7 +152,7 @@ class ProviderManager: self.curr_stt_provider_inst = self.stt_provider_insts[0] self.curr_tts_provider_inst = self.inst_map.get( - self.provider_settings.get("provider_id") + self.provider_tts_settings.get("provider_id") ) if not self.curr_tts_provider_inst and self.tts_provider_insts: self.curr_tts_provider_inst = self.tts_provider_insts[0] From a616adaac482ec5284915e016d72e8ba67f93e69 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 9 Jun 2025 23:46:44 +0800 Subject: [PATCH 4/9] fix: update provider manager set_provider() --- astrbot/core/provider/manager.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 8baec0cb5..d9f1ac046 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -114,10 +114,16 @@ class ProviderManager: async def set_provider( self, provider_id: str, provider_type: ProviderType, umo: str = None ): - """设置提供商""" + """设置提供商。 + + Args: + provider_id (str): 提供商 ID。 + provider_type (ProviderType): 提供商类型。 + umo (str, optional): 用户会话 ID,用于提供商会话隔离。当用户启用了提供商会话隔离时此参数才生效。 + """ if provider_id not in self.inst_map: raise ValueError(f"提供商 {provider_id} 不存在,无法设置。") - if umo: + if umo and self.provider_settings["seperate_provider"]: perf = sp.get("session_provider_perf", {}) session_perf = perf.get(umo, {}) session_perf[provider_type.value] = provider_id From 98800d3426246207c635fcd9f5db17491eaa8cee Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 9 Jun 2025 23:50:31 +0800 Subject: [PATCH 5/9] fix(typo): "seperate_provider" -> "separate_provider" --- astrbot/core/config/default.py | 4 ++-- astrbot/core/provider/manager.py | 2 +- astrbot/core/star/context.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 092971460..5d8f0219a 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -58,7 +58,7 @@ DEFAULT_CONFIG = { "dequeue_context_length": 1, "streaming_response": False, "streaming_segmented": False, - "seperate_provider": False, + "separate_provider": False, }, "provider_stt_settings": { "enable": False, @@ -1384,7 +1384,7 @@ CONFIG_METADATA_2 = { "hint": "如需切换大语言模型提供商,请使用 /provider 命令。", "obvious_hint": True, }, - "seperate_provider": { + "separate_provider": { "description": "提供商会话隔离", "type": "bool", "hint": "启用后,每个会话支持独立选择文本生成、STT、TTS 等提供商。如果会话在使用 /provider 指令时提示无权限,可以将会话加入管理员名单或者使用 /alter_cmd provider member 将指令设为非管理员指令。", diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index d9f1ac046..ffeadd8d3 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -123,7 +123,7 @@ class ProviderManager: """ if provider_id not in self.inst_map: raise ValueError(f"提供商 {provider_id} 不存在,无法设置。") - if umo and self.provider_settings["seperate_provider"]: + if umo and self.provider_settings["separate_provider"]: perf = sp.get("session_provider_perf", {}) session_perf = perf.get(umo, {}) session_perf[provider_type.value] = provider_id diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index d985a98d5..0beac12e7 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -148,7 +148,7 @@ class Context: Args: umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。 """ - if umo and self._config["provider_settings"]["seperate_provider"]: + if umo and self._config["provider_settings"]["separate_provider"]: perf = sp.get("session_provider_perf", {}) prov_id = perf.get(umo, {}).get(ProviderType.CHAT_COMPLETION.value, None) inst = self.provider_manager.inst_map.get(prov_id, None) @@ -163,7 +163,7 @@ class Context: Args: umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ - if umo and self._config["provider_settings"]["seperate_provider"]: + if umo and self._config["provider_settings"]["separate_provider"]: perf = sp.get("session_provider_perf", {}) prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None) inst = self.provider_manager.inst_map.get(prov_id, None) @@ -178,7 +178,7 @@ class Context: Args: umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ - if umo and self._config["provider_settings"]["seperate_provider"]: + if umo and self._config["provider_settings"]["separate_provider"]: perf = sp.get("session_provider_perf", {}) prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None) inst = self.provider_manager.inst_map.get(prov_id, None) From a03c79b89d588ed7482aa323c1534109c9ace307 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 9 Jun 2025 23:51:54 +0800 Subject: [PATCH 6/9] style: use named expression --- astrbot/core/star/context.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 0beac12e7..9db1da124 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -151,8 +151,7 @@ class Context: if umo and self._config["provider_settings"]["separate_provider"]: perf = sp.get("session_provider_perf", {}) prov_id = perf.get(umo, {}).get(ProviderType.CHAT_COMPLETION.value, None) - inst = self.provider_manager.inst_map.get(prov_id, None) - if inst: + if inst := self.provider_manager.inst_map.get(prov_id, None): return inst return self.provider_manager.curr_provider_inst @@ -166,8 +165,7 @@ class Context: if umo and self._config["provider_settings"]["separate_provider"]: perf = sp.get("session_provider_perf", {}) prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None) - inst = self.provider_manager.inst_map.get(prov_id, None) - if inst: + if inst := self.provider_manager.inst_map.get(prov_id, None): return inst return self.provider_manager.curr_tts_provider_inst @@ -181,8 +179,7 @@ class Context: if umo and self._config["provider_settings"]["separate_provider"]: perf = sp.get("session_provider_perf", {}) prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None) - inst = self.provider_manager.inst_map.get(prov_id, None) - if inst: + if inst := self.provider_manager.inst_map.get(prov_id, None): return inst return self.provider_manager.curr_stt_provider_inst From cdb7a1b3fa06f647b5cd0be76fd7342142c81303 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 9 Jun 2025 23:54:51 +0800 Subject: [PATCH 7/9] style: merge else if into elif --- packages/astrbot/main.py | 83 ++++++++++++++++++++-------------------- 1 file changed, 41 insertions(+), 42 deletions(-) diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index d4557e6b8..739f12b5c 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -464,55 +464,54 @@ UID: {user_id} 此 ID 可用于设置管理员。 ret += "\n使用 /provider stt <切换> STT 提供商。" event.set_result(MessageEventResult().message(ret)) - else: - if idx == "tts": - if idx2 is None: - event.set_result(MessageEventResult().message("请输入序号。")) - return - else: - if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1: - event.set_result(MessageEventResult().message("无效的序号。")) - provider = self.context.get_all_tts_providers()[idx2 - 1] - id_ = provider.meta().id - await self.context.provider_manager.set_provider( - provider_id=id_, - provider_type=ProviderType.TEXT_TO_SPEECH, - umo=umo, - ) - event.set_result( - MessageEventResult().message(f"成功切换到 {id_}。") - ) - elif idx == "stt": - if idx2 is None: - event.set_result(MessageEventResult().message("请输入序号。")) - return - else: - if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1: - event.set_result(MessageEventResult().message("无效的序号。")) - provider = self.context.get_all_stt_providers()[idx2 - 1] - id_ = provider.meta().id - await self.context.provider_manager.set_provider( - provider_id=id_, - provider_type=ProviderType.SPEECH_TO_TEXT, - umo=umo, - ) - event.set_result( - MessageEventResult().message(f"成功切换到 {id_}。") - ) - elif isinstance(idx, int): - if idx > len(self.context.get_all_providers()) or idx < 1: + elif idx == "tts": + if idx2 is None: + event.set_result(MessageEventResult().message("请输入序号。")) + return + else: + if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1: event.set_result(MessageEventResult().message("无效的序号。")) - - provider = self.context.get_all_providers()[idx - 1] + provider = self.context.get_all_tts_providers()[idx2 - 1] id_ = provider.meta().id await self.context.provider_manager.set_provider( provider_id=id_, - provider_type=ProviderType.CHAT_COMPLETION, + provider_type=ProviderType.TEXT_TO_SPEECH, umo=umo, ) - event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) + event.set_result( + MessageEventResult().message(f"成功切换到 {id_}。") + ) + elif idx == "stt": + if idx2 is None: + event.set_result(MessageEventResult().message("请输入序号。")) + return else: - event.set_result(MessageEventResult().message("无效的参数。")) + if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1: + event.set_result(MessageEventResult().message("无效的序号。")) + provider = self.context.get_all_stt_providers()[idx2 - 1] + id_ = provider.meta().id + await self.context.provider_manager.set_provider( + provider_id=id_, + provider_type=ProviderType.SPEECH_TO_TEXT, + umo=umo, + ) + event.set_result( + MessageEventResult().message(f"成功切换到 {id_}。") + ) + elif isinstance(idx, int): + if idx > len(self.context.get_all_providers()) or idx < 1: + event.set_result(MessageEventResult().message("无效的序号。")) + + provider = self.context.get_all_providers()[idx - 1] + id_ = provider.meta().id + await self.context.provider_manager.set_provider( + provider_id=id_, + provider_type=ProviderType.CHAT_COMPLETION, + umo=umo, + ) + event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) + else: + event.set_result(MessageEventResult().message("无效的参数。")) @filter.command("reset") async def reset(self, message: AstrMessageEvent): From 16e3cd07844ac2ac11f8e5e079cb37fc8ebc8a95 Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Tue, 10 Jun 2025 12:58:39 +0800 Subject: [PATCH 8/9] fix: get_using_stt_provider is fetching using ProviderType.TEXT_TO_SPEECH but should use ProviderType.SPEECH_TO_TEXT for STT isolation. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- astrbot/core/star/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 9db1da124..dda664b1b 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -178,7 +178,7 @@ class Context: """ if umo and self._config["provider_settings"]["separate_provider"]: perf = sp.get("session_provider_perf", {}) - prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None) + prov_id = perf.get(umo, {}).get(ProviderType.SPEECH_TO_TEXT.value, None) if inst := self.provider_manager.inst_map.get(prov_id, None): return inst return self.provider_manager.curr_stt_provider_inst From 1d561da7fb137dd2c49a7af552456d14902f33db Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Tue, 10 Jun 2025 12:59:20 +0800 Subject: [PATCH 9/9] style: clean code Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- astrbot/core/pipeline/process_stage/stage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index 90fac5734..f653a9fb9 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -6,7 +6,7 @@ from .method.star_request import StarRequestSubStage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.star_handler import StarHandlerMetadata from astrbot.core.provider.entities import ProviderRequest -from astrbot.core import logger, sp +from astrbot.core import logger @register_stage