Merge pull request #1773 from AstrBotDevs/feat-seperate-provider
Feature: 支持对提供商会话隔离
This commit is contained in:
@@ -46,6 +46,7 @@ DEFAULT_CONFIG = {
|
||||
"provider": [],
|
||||
"provider_settings": {
|
||||
"enable": True,
|
||||
"default_provider_id": "",
|
||||
"wake_prefix": "",
|
||||
"web_search": False,
|
||||
"web_search_link": False,
|
||||
@@ -57,6 +58,7 @@ DEFAULT_CONFIG = {
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
"streaming_segmented": False,
|
||||
"separate_provider": False,
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
@@ -1386,9 +1388,19 @@ CONFIG_METADATA_2 = {
|
||||
"enable": {
|
||||
"description": "启用大语言模型聊天",
|
||||
"type": "bool",
|
||||
"hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。",
|
||||
"hint": "如需切换大语言模型提供商,请使用 /provider 命令。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"separate_provider": {
|
||||
"description": "提供商会话隔离",
|
||||
"type": "bool",
|
||||
"hint": "启用后,每个会话支持独立选择文本生成、STT、TTS 等提供商。如果会话在使用 /provider 指令时提示无权限,可以将会话加入管理员名单或者使用 /alter_cmd provider member 将指令设为非管理员指令。",
|
||||
},
|
||||
"default_provider_id": {
|
||||
"description": "默认模型提供商 ID",
|
||||
"type": "string",
|
||||
"hint": "可选。每个聊天会话的默认提供商 ID。",
|
||||
},
|
||||
"wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀",
|
||||
"type": "string",
|
||||
@@ -1501,7 +1513,7 @@ CONFIG_METADATA_2 = {
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID,不填则默认第一个STT提供商",
|
||||
"description": "提供商 ID",
|
||||
"type": "string",
|
||||
"hint": "语音转文本提供商 ID。如果不填写将使用载入的第一个提供商。",
|
||||
},
|
||||
@@ -1518,7 +1530,7 @@ CONFIG_METADATA_2 = {
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID,不填则默认第一个TTS提供商",
|
||||
"description": "提供商 ID",
|
||||
"type": "string",
|
||||
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
|
||||
},
|
||||
|
||||
@@ -43,9 +43,8 @@ class PreProcessStage(Stage):
|
||||
# STT
|
||||
if self.stt_settings.get("enable", False):
|
||||
# TODO: 独立
|
||||
stt_provider = (
|
||||
self.plugin_manager.context.provider_manager.curr_stt_provider_inst
|
||||
)
|
||||
ctx = self.plugin_manager.context
|
||||
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
|
||||
if not stt_provider:
|
||||
return
|
||||
message_chain = event.get_messages()
|
||||
|
||||
@@ -70,8 +70,8 @@ class LLMRequestSubStage(Stage):
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||
return
|
||||
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
umo = event.unified_msg_origin
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo)
|
||||
if provider is None:
|
||||
return
|
||||
|
||||
|
||||
@@ -169,8 +169,8 @@ class ResultDecorateStage(Stage):
|
||||
result.chain = new_chain
|
||||
|
||||
# TTS
|
||||
tts_provider = (
|
||||
self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
||||
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
if (
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
|
||||
@@ -18,13 +18,6 @@ class ProviderManager:
|
||||
self.persona_configs: list = config.get("persona", [])
|
||||
self.astrbot_config = config
|
||||
|
||||
self.selected_provider_id = sp.get("curr_provider")
|
||||
self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
|
||||
self.selected_tts_provider_id = self.provider_settings.get("provider_id")
|
||||
# self.provider_enabled = self.provider_settings.get("enable", False)
|
||||
# self.stt_enabled = self.provider_stt_settings.get("enable", False)
|
||||
# self.tts_enabled = self.provider_tts_settings.get("enable", False)
|
||||
|
||||
# 人格情景管理
|
||||
# 目前没有拆成独立的模块
|
||||
self.default_persona_name = self.provider_settings.get(
|
||||
@@ -103,14 +96,13 @@ class ProviderManager:
|
||||
self.inst_map = {}
|
||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||
self.llm_tools = llm_tools
|
||||
self.default_provider_inst: Provider = None
|
||||
"""默认的 Provider 实例。第 0 个或者用户以前指定的 Provider 实例"""
|
||||
|
||||
self.curr_provider_inst: Provider = None
|
||||
"""当前使用的 Provider 实例"""
|
||||
"""默认的 Provider 实例"""
|
||||
self.curr_stt_provider_inst: STTProvider = None
|
||||
"""当前使用的 Speech To Text Provider 实例"""
|
||||
"""默认的 Speech To Text Provider 实例"""
|
||||
self.curr_tts_provider_inst: TTSProvider = None
|
||||
"""当前使用的 Text To Speech Provider 实例"""
|
||||
"""默认的 Text To Speech Provider 实例"""
|
||||
self.db_helper = db_helper
|
||||
|
||||
# kdb(experimental)
|
||||
@@ -119,13 +111,57 @@ class ProviderManager:
|
||||
if kdb_cfg and len(kdb_cfg):
|
||||
self.curr_kdb_name = list(kdb_cfg.keys())[0]
|
||||
|
||||
async def set_provider(
|
||||
self, provider_id: str, provider_type: ProviderType, umo: str = None
|
||||
):
|
||||
"""设置提供商。
|
||||
|
||||
Args:
|
||||
provider_id (str): 提供商 ID。
|
||||
provider_type (ProviderType): 提供商类型。
|
||||
umo (str, optional): 用户会话 ID,用于提供商会话隔离。当用户启用了提供商会话隔离时此参数才生效。
|
||||
"""
|
||||
if provider_id not in self.inst_map:
|
||||
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
|
||||
if umo and self.provider_settings["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
session_perf = perf.get(umo, {})
|
||||
session_perf[provider_type.value] = provider_id
|
||||
perf[umo] = session_perf
|
||||
sp.put("session_provider_perf", perf)
|
||||
return
|
||||
# 不启用提供商会话隔离模式的情况
|
||||
self.curr_provider_inst = self.inst_map[provider_id]
|
||||
if provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
sp.put("curr_provider_tts", provider_id)
|
||||
elif provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
sp.put("curr_provider_stt", provider_id)
|
||||
elif provider_type == ProviderType.CHAT_COMPLETION:
|
||||
sp.put("curr_provider", provider_id)
|
||||
|
||||
async def initialize(self):
|
||||
# 逐个初始化提供商
|
||||
for provider_config in self.providers_config:
|
||||
await self.load_provider(provider_config)
|
||||
|
||||
self.default_provider_inst = self.inst_map.get(self.selected_provider_id)
|
||||
if not self.default_provider_inst and self.provider_insts:
|
||||
self.default_provider_inst = self.provider_insts[0]
|
||||
# 设置默认提供商
|
||||
self.curr_provider_inst = self.inst_map.get(
|
||||
self.provider_settings.get("default_provider_id")
|
||||
)
|
||||
if not self.curr_provider_inst and self.provider_insts:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
|
||||
self.curr_stt_provider_inst = self.inst_map.get(
|
||||
self.provider_stt_settings.get("provider_id")
|
||||
)
|
||||
if not self.curr_stt_provider_inst and self.stt_provider_insts:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
|
||||
self.curr_tts_provider_inst = self.inst_map.get(
|
||||
self.provider_tts_settings.get("provider_id")
|
||||
)
|
||||
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
# 初始化 MCP Client 连接
|
||||
asyncio.create_task(
|
||||
@@ -252,7 +288,10 @@ class ProviderManager:
|
||||
await inst.initialize()
|
||||
|
||||
self.stt_provider_insts.append(inst)
|
||||
if self.selected_stt_provider_id == provider_config["id"]:
|
||||
if (
|
||||
self.provider_stt_settings.get("provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_stt_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。"
|
||||
@@ -270,7 +309,7 @@ class ProviderManager:
|
||||
await inst.initialize()
|
||||
|
||||
self.tts_provider_insts.append(inst)
|
||||
if self.selected_tts_provider_id == provider_config["id"]:
|
||||
if self.provider_settings.get("provider_id") == provider_config["id"]:
|
||||
self.curr_tts_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。"
|
||||
@@ -292,7 +331,10 @@ class ProviderManager:
|
||||
await inst.initialize()
|
||||
|
||||
self.provider_insts.append(inst)
|
||||
if self.selected_provider_id == provider_config["id"]:
|
||||
if (
|
||||
self.provider_settings.get("default_provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。"
|
||||
@@ -330,7 +372,6 @@ class ProviderManager:
|
||||
self.curr_provider_inst = None
|
||||
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
self.selected_provider_id = self.curr_provider_inst.meta().id
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。"
|
||||
)
|
||||
@@ -339,7 +380,6 @@ class ProviderManager:
|
||||
self.curr_stt_provider_inst = None
|
||||
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
self.selected_stt_provider_id = self.curr_stt_provider_inst.meta().id
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。"
|
||||
)
|
||||
@@ -348,7 +388,6 @@ class ProviderManager:
|
||||
self.curr_tts_provider_inst = None
|
||||
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
self.selected_tts_provider_id = self.curr_tts_provider_inst.meta().id
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。"
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import List, Union
|
||||
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
@@ -140,24 +141,46 @@ class Context:
|
||||
"""获取所有用于 STT 任务的 Provider。"""
|
||||
return self.provider_manager.stt_provider_insts
|
||||
|
||||
def get_using_provider(self) -> Provider:
|
||||
def get_using_provider(self, umo: str = None) -> Provider:
|
||||
"""
|
||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
|
||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
||||
|
||||
通过 /provider 指令切换。
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.CHAT_COMPLETION.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_provider_inst
|
||||
|
||||
def get_using_tts_provider(self) -> TTSProvider:
|
||||
def get_using_tts_provider(self, umo: str = None) -> TTSProvider:
|
||||
"""
|
||||
获取当前使用的用于 TTS 任务的 Provider。
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_tts_provider_inst
|
||||
|
||||
def get_using_stt_provider(self) -> STTProvider:
|
||||
def get_using_stt_provider(self, umo: str = None) -> STTProvider:
|
||||
"""
|
||||
获取当前使用的用于 STT 任务的 Provider。
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.SPEECH_TO_TEXT.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_stt_provider_inst
|
||||
|
||||
def get_config(self) -> AstrBotConfig:
|
||||
|
||||
+55
-44
@@ -12,6 +12,7 @@ from astrbot.api import sp
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.provider.sources.dify_source import ProviderDify
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
@@ -139,6 +140,7 @@ class Main(star.Star):
|
||||
{notice}"""
|
||||
|
||||
event.set_result(MessageEventResult().message(msg).use_t2i(False))
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("llm")
|
||||
async def llm(self, event: AstrMessageEvent):
|
||||
@@ -413,20 +415,21 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
event.set_result(MessageEventResult().message("删除白名单成功。"))
|
||||
except ValueError:
|
||||
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
|
||||
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("provider")
|
||||
async def provider(
|
||||
self, event: AstrMessageEvent, idx: Union[str, int] = None, idx2: int = None
|
||||
):
|
||||
"""查看或者切换 LLM Provider"""
|
||||
umo = event.unified_msg_origin
|
||||
|
||||
if idx is None:
|
||||
ret = "## 载入的 LLM 提供商\n"
|
||||
for idx, llm in enumerate(self.context.get_all_providers()):
|
||||
id_ = llm.meta().id
|
||||
ret += f"{idx + 1}. {id_} ({llm.meta().model})"
|
||||
provider_using = self.context.get_using_provider()
|
||||
provider_using = self.context.get_using_provider(umo=umo)
|
||||
if provider_using and provider_using.meta().id == id_:
|
||||
ret += " (当前使用)"
|
||||
ret += "\n"
|
||||
@@ -437,7 +440,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
for idx, tts in enumerate(tts_providers):
|
||||
id_ = tts.meta().id
|
||||
ret += f"{idx + 1}. {id_}"
|
||||
tts_using = self.context.get_using_tts_provider()
|
||||
tts_using = self.context.get_using_tts_provider(umo=umo)
|
||||
if tts_using and tts_using.meta().id == id_:
|
||||
ret += " (当前使用)"
|
||||
ret += "\n"
|
||||
@@ -448,7 +451,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
for idx, stt in enumerate(stt_providers):
|
||||
id_ = stt.meta().id
|
||||
ret += f"{idx + 1}. {id_}"
|
||||
stt_using = self.context.get_using_stt_provider()
|
||||
stt_using = self.context.get_using_stt_provider(umo=umo)
|
||||
if stt_using and stt_using.meta().id == id_:
|
||||
ret += " (当前使用)"
|
||||
ret += "\n"
|
||||
@@ -461,46 +464,54 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
ret += "\n使用 /provider stt <切换> STT 提供商。"
|
||||
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
else:
|
||||
if idx == "tts":
|
||||
if idx2 is None:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
else:
|
||||
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
provider = self.context.get_all_tts_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
self.context.provider_manager.curr_tts_provider_inst = provider
|
||||
sp.put("curr_provider_tts", id_)
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"成功切换到 {id_}。")
|
||||
)
|
||||
elif idx == "stt":
|
||||
if idx2 is None:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
else:
|
||||
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
provider = self.context.get_all_stt_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
self.context.provider_manager.curr_stt_provider_inst = provider
|
||||
sp.put("curr_provider_stt", id_)
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"成功切换到 {id_}。")
|
||||
)
|
||||
elif isinstance(idx, int):
|
||||
if idx > len(self.context.get_all_providers()) or idx < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
|
||||
provider = self.context.get_all_providers()[idx - 1]
|
||||
id_ = provider.meta().id
|
||||
self.context.provider_manager.curr_provider_inst = provider
|
||||
sp.put("curr_provider", id_)
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
elif idx == "tts":
|
||||
if idx2 is None:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
else:
|
||||
event.set_result(MessageEventResult().message("无效的参数。"))
|
||||
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
provider = self.context.get_all_tts_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
provider_id=id_,
|
||||
provider_type=ProviderType.TEXT_TO_SPEECH,
|
||||
umo=umo,
|
||||
)
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"成功切换到 {id_}。")
|
||||
)
|
||||
elif idx == "stt":
|
||||
if idx2 is None:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
else:
|
||||
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
provider = self.context.get_all_stt_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
provider_id=id_,
|
||||
provider_type=ProviderType.SPEECH_TO_TEXT,
|
||||
umo=umo,
|
||||
)
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"成功切换到 {id_}。")
|
||||
)
|
||||
elif isinstance(idx, int):
|
||||
if idx > len(self.context.get_all_providers()) or idx < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
|
||||
provider = self.context.get_all_providers()[idx - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
provider_id=id_,
|
||||
provider_type=ProviderType.CHAT_COMPLETION,
|
||||
umo=umo,
|
||||
)
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
else:
|
||||
event.set_result(MessageEventResult().message("无效的参数。"))
|
||||
|
||||
@filter.command("reset")
|
||||
async def reset(self, message: AstrMessageEvent):
|
||||
@@ -572,7 +583,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。"
|
||||
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("model")
|
||||
async def model_ls(
|
||||
|
||||
Reference in New Issue
Block a user