3379587223
* feat(mcp): enhance logging and initialize MCP clients in background fixes: #5777 * rf * fix(mcp): simplify MCP client initialization in background * fix(mcp): update error message for MCP background initialization failure
827 lines
34 KiB
Python
827 lines
34 KiB
Python
import asyncio
import contextlib
import copy
import os
import traceback
from collections.abc import Callable
from typing import Protocol, runtime_checkable

from astrbot.core import astrbot_config, logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.error_redaction import safe_error

from ..persona_mgr import PersonaManager
from .entities import ProviderType
from .provider import (
    EmbeddingProvider,
    Provider,
    Providers,
    RerankProvider,
    STTProvider,
    TTSProvider,
)
from .register import llm_tools, provider_cls_map
|
||
|
||
|
||
@runtime_checkable
class HasInitialize(Protocol):
    """Structural type for objects exposing an async ``initialize()`` hook.

    Because the protocol is runtime-checkable, ``isinstance(obj, HasInitialize)``
    is true for any object that has an ``initialize`` attribute; the check does
    not verify that the attribute is actually a coroutine function.
    """

    async def initialize(self) -> None:
        """Run the object's asynchronous set-up step."""
        ...
|
||
|
||
|
||
class ProviderManager:
|
||
def __init__(
|
||
self,
|
||
acm: AstrBotConfigManager,
|
||
db_helper: BaseDatabase,
|
||
persona_mgr: PersonaManager,
|
||
) -> None:
|
||
self.reload_lock = asyncio.Lock()
|
||
self.resource_lock = asyncio.Lock()
|
||
self.persona_mgr = persona_mgr
|
||
self.acm = acm
|
||
config = acm.confs["default"]
|
||
self.providers_config: list = config["provider"]
|
||
self.provider_sources_config: list = config.get("provider_sources", [])
|
||
self.provider_settings: dict = config["provider_settings"]
|
||
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
|
||
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
|
||
|
||
# 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager
|
||
self.default_persona_name = persona_mgr.default_persona
|
||
|
||
self.provider_insts: list[Provider] = []
|
||
"""加载的 Provider 的实例"""
|
||
self.stt_provider_insts: list[STTProvider] = []
|
||
"""加载的 Speech To Text Provider 的实例"""
|
||
self.tts_provider_insts: list[TTSProvider] = []
|
||
"""加载的 Text To Speech Provider 的实例"""
|
||
self.embedding_provider_insts: list[EmbeddingProvider] = []
|
||
"""加载的 Embedding Provider 的实例"""
|
||
self.rerank_provider_insts: list[RerankProvider] = []
|
||
"""加载的 Rerank Provider 的实例"""
|
||
self.inst_map: dict[
|
||
str,
|
||
Providers,
|
||
] = {}
|
||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||
self.llm_tools = llm_tools
|
||
|
||
self.curr_provider_inst: Provider | None = None
|
||
"""默认的 Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
|
||
self.curr_stt_provider_inst: STTProvider | None = None
|
||
"""默认的 Speech To Text Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
|
||
self.curr_tts_provider_inst: TTSProvider | None = None
|
||
"""默认的 Text To Speech Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
|
||
self.db_helper = db_helper
|
||
self._provider_change_callback: (
|
||
Callable[[str, ProviderType, str | None], None] | None
|
||
) = None
|
||
self._provider_change_hooks: list[
|
||
Callable[[str, ProviderType, str | None], None]
|
||
] = []
|
||
self._mcp_init_task: asyncio.Task | None = None
|
||
|
||
def set_provider_change_callback(
|
||
self,
|
||
cb: Callable[[str, ProviderType, str | None], None] | None,
|
||
) -> None:
|
||
# Backward-compatible single-callback setter.
|
||
# This callback coexists with register_provider_change_hook subscriptions.
|
||
self._provider_change_callback = cb
|
||
|
||
def register_provider_change_hook(
|
||
self,
|
||
hook: Callable[[str, ProviderType, str | None], None],
|
||
) -> None:
|
||
if hook not in self._provider_change_hooks:
|
||
self._provider_change_hooks.append(hook)
|
||
|
||
def _notify_provider_changed(
|
||
self,
|
||
provider_id: str,
|
||
provider_type: ProviderType,
|
||
umo: str | None,
|
||
) -> None:
|
||
if self._provider_change_callback is not None:
|
||
try:
|
||
self._provider_change_callback(provider_id, provider_type, umo)
|
||
except Exception as e:
|
||
logger.warning(
|
||
"调用 provider 变更回调失败: provider_id=%s, type=%s, err=%s",
|
||
provider_id,
|
||
provider_type,
|
||
safe_error("", e),
|
||
)
|
||
for hook in list(self._provider_change_hooks):
|
||
if hook is self._provider_change_callback:
|
||
continue
|
||
try:
|
||
hook(provider_id, provider_type, umo)
|
||
except Exception as e:
|
||
logger.warning(
|
||
"调用 provider 变更钩子失败: provider_id=%s, type=%s, err=%s",
|
||
provider_id,
|
||
provider_type,
|
||
safe_error("", e),
|
||
)
|
||
|
||
@property
|
||
def persona_configs(self) -> list:
|
||
"""动态获取最新的 persona 配置"""
|
||
return self.persona_mgr.persona_v3_config
|
||
|
||
@property
|
||
def personas(self) -> list:
|
||
"""动态获取最新的 personas 列表"""
|
||
return self.persona_mgr.personas_v3
|
||
|
||
@property
|
||
def selected_default_persona(self):
|
||
"""动态获取最新的默认选中 persona。已弃用,请使用 context.persona_mgr.get_default_persona_v3()"""
|
||
return self.persona_mgr.selected_default_persona_v3
|
||
|
||
async def set_provider(
|
||
self,
|
||
provider_id: str,
|
||
provider_type: ProviderType,
|
||
umo: str | None = None,
|
||
) -> None:
|
||
"""设置提供商。
|
||
|
||
Args:
|
||
provider_id (str): 提供商 ID。
|
||
provider_type (ProviderType): 提供商类型。
|
||
umo (str, optional): 用户会话 ID,用于提供商会话隔离。
|
||
|
||
Version 4.0.0: 这个版本下已经默认隔离提供商
|
||
|
||
"""
|
||
if provider_id not in self.inst_map:
|
||
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
|
||
if umo:
|
||
await sp.session_put(
|
||
umo,
|
||
f"provider_perf_{provider_type.value}",
|
||
provider_id,
|
||
)
|
||
self._notify_provider_changed(provider_id, provider_type, umo)
|
||
return
|
||
# 不启用提供商会话隔离模式的情况
|
||
|
||
prov = self.inst_map[provider_id]
|
||
if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance(
|
||
prov,
|
||
TTSProvider,
|
||
):
|
||
self.curr_tts_provider_inst = prov
|
||
await sp.put_async(
|
||
key="curr_provider_tts",
|
||
value=provider_id,
|
||
scope="global",
|
||
scope_id="global",
|
||
)
|
||
self._notify_provider_changed(provider_id, provider_type, umo)
|
||
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
|
||
prov,
|
||
STTProvider,
|
||
):
|
||
self.curr_stt_provider_inst = prov
|
||
await sp.put_async(
|
||
key="curr_provider_stt",
|
||
value=provider_id,
|
||
scope="global",
|
||
scope_id="global",
|
||
)
|
||
self._notify_provider_changed(provider_id, provider_type, umo)
|
||
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
|
||
prov,
|
||
Provider,
|
||
):
|
||
self.curr_provider_inst = prov
|
||
await sp.put_async(
|
||
key="curr_provider",
|
||
value=provider_id,
|
||
scope="global",
|
||
scope_id="global",
|
||
)
|
||
self._notify_provider_changed(provider_id, provider_type, umo)
|
||
|
||
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
|
||
"""根据提供商 ID 获取提供商实例"""
|
||
return self.inst_map.get(provider_id)
|
||
|
||
def get_using_provider(
|
||
self, provider_type: ProviderType, umo=None
|
||
) -> Providers | None:
|
||
"""获取正在使用的提供商实例。
|
||
|
||
Args:
|
||
provider_type (ProviderType): 提供商类型。
|
||
umo (str, optional): 用户会话 ID,用于提供商会话隔离。
|
||
|
||
Returns:
|
||
Provider: 正在使用的提供商实例。
|
||
|
||
"""
|
||
provider = None
|
||
provider_id = None
|
||
if umo:
|
||
provider_id = sp.get(
|
||
f"provider_perf_{provider_type.value}",
|
||
None,
|
||
scope="umo",
|
||
scope_id=umo,
|
||
)
|
||
if provider_id:
|
||
provider = self.inst_map.get(provider_id)
|
||
if not provider:
|
||
# default setting
|
||
config = self.acm.get_conf(umo)
|
||
if provider_type == ProviderType.CHAT_COMPLETION:
|
||
provider_id = config["provider_settings"].get("default_provider_id")
|
||
provider = self.inst_map.get(provider_id)
|
||
if not provider:
|
||
provider = self.provider_insts[0] if self.provider_insts else None
|
||
elif provider_type == ProviderType.SPEECH_TO_TEXT:
|
||
provider_id = config["provider_stt_settings"].get("provider_id")
|
||
if not provider_id:
|
||
return None
|
||
provider = self.inst_map.get(provider_id)
|
||
if not provider:
|
||
provider = (
|
||
self.stt_provider_insts[0] if self.stt_provider_insts else None
|
||
)
|
||
elif provider_type == ProviderType.TEXT_TO_SPEECH:
|
||
provider_id = config["provider_tts_settings"].get("provider_id")
|
||
if not provider_id:
|
||
return None
|
||
provider = self.inst_map.get(provider_id)
|
||
if not provider:
|
||
provider = (
|
||
self.tts_provider_insts[0] if self.tts_provider_insts else None
|
||
)
|
||
else:
|
||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||
|
||
if not provider and provider_id:
|
||
logger.warning(
|
||
f"没有找到 ID 为 {provider_id} 的提供商,这可能是由于您修改了提供商(模型)ID 导致的。"
|
||
)
|
||
|
||
return provider
|
||
|
||
async def initialize(self) -> None:
|
||
# 逐个初始化提供商
|
||
for provider_config in self.providers_config:
|
||
try:
|
||
await self.load_provider(provider_config)
|
||
except Exception as e:
|
||
logger.error(traceback.format_exc())
|
||
logger.error(e)
|
||
|
||
selected_provider_id = await sp.get_async(
|
||
key="curr_provider",
|
||
default=self.provider_settings.get("default_provider_id"),
|
||
scope="global",
|
||
scope_id="global",
|
||
)
|
||
selected_stt_provider_id = await sp.get_async(
|
||
key="curr_provider_stt",
|
||
default=self.provider_stt_settings.get("provider_id"),
|
||
scope="global",
|
||
scope_id="global",
|
||
)
|
||
selected_tts_provider_id = await sp.get_async(
|
||
key="curr_provider_tts",
|
||
default=self.provider_tts_settings.get("provider_id"),
|
||
scope="global",
|
||
scope_id="global",
|
||
)
|
||
|
||
temp_provider = (
|
||
self.inst_map.get(selected_provider_id)
|
||
if isinstance(selected_provider_id, str)
|
||
else None
|
||
)
|
||
self.curr_provider_inst = (
|
||
temp_provider if isinstance(temp_provider, Provider) else None
|
||
)
|
||
if not self.curr_provider_inst and self.provider_insts:
|
||
self.curr_provider_inst = self.provider_insts[0]
|
||
|
||
temp_stt = (
|
||
self.inst_map.get(selected_stt_provider_id)
|
||
if isinstance(selected_stt_provider_id, str)
|
||
else None
|
||
)
|
||
self.curr_stt_provider_inst = (
|
||
temp_stt if isinstance(temp_stt, STTProvider) else None
|
||
)
|
||
if not self.curr_stt_provider_inst and self.stt_provider_insts:
|
||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||
|
||
temp_tts = (
|
||
self.inst_map.get(selected_tts_provider_id)
|
||
if isinstance(selected_tts_provider_id, str)
|
||
else None
|
||
)
|
||
self.curr_tts_provider_inst = (
|
||
temp_tts if isinstance(temp_tts, TTSProvider) else None
|
||
)
|
||
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||
|
||
async def _init_mcp_clients_bg() -> None:
|
||
try:
|
||
await self.llm_tools.init_mcp_clients()
|
||
except Exception:
|
||
logger.error("MCP init background task failed", exc_info=True)
|
||
|
||
if self._mcp_init_task is None or self._mcp_init_task.done():
|
||
self._mcp_init_task = asyncio.create_task(
|
||
_init_mcp_clients_bg(),
|
||
name="provider-manager:mcp-init",
|
||
)
|
||
|
||
def dynamic_import_provider(self, type: str) -> None:
|
||
"""动态导入提供商适配器模块
|
||
|
||
Args:
|
||
type (str): 提供商请求类型。
|
||
|
||
Raises:
|
||
ImportError: 如果提供商类型未知或无法导入对应模块,则抛出异常。
|
||
"""
|
||
match type:
|
||
case "openai_chat_completion":
|
||
from .sources.openai_source import (
|
||
ProviderOpenAIOfficial as ProviderOpenAIOfficial,
|
||
)
|
||
case "zhipu_chat_completion":
|
||
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
|
||
case "groq_chat_completion":
|
||
from .sources.groq_source import ProviderGroq as ProviderGroq
|
||
case "xai_chat_completion":
|
||
from .sources.xai_source import ProviderXAI as ProviderXAI
|
||
case "aihubmix_chat_completion":
|
||
from .sources.oai_aihubmix_source import (
|
||
ProviderAIHubMix as ProviderAIHubMix,
|
||
)
|
||
case "openrouter_chat_completion":
|
||
from .sources.openrouter_source import (
|
||
ProviderOpenRouter as ProviderOpenRouter,
|
||
)
|
||
case "anthropic_chat_completion":
|
||
from .sources.anthropic_source import (
|
||
ProviderAnthropic as ProviderAnthropic,
|
||
)
|
||
case "googlegenai_chat_completion":
|
||
from .sources.gemini_source import (
|
||
ProviderGoogleGenAI as ProviderGoogleGenAI,
|
||
)
|
||
case "sensevoice_stt_selfhost":
|
||
from .sources.sensevoice_selfhosted_source import (
|
||
ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost,
|
||
)
|
||
case "openai_whisper_api":
|
||
from .sources.whisper_api_source import (
|
||
ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI,
|
||
)
|
||
case "openai_whisper_selfhost":
|
||
from .sources.whisper_selfhosted_source import (
|
||
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
|
||
)
|
||
case "xinference_stt":
|
||
from .sources.xinference_stt_provider import (
|
||
ProviderXinferenceSTT as ProviderXinferenceSTT,
|
||
)
|
||
case "openai_tts_api":
|
||
from .sources.openai_tts_api_source import (
|
||
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
|
||
)
|
||
case "genie_tts":
|
||
from .sources.genie_tts import (
|
||
GenieTTSProvider as GenieTTSProvider,
|
||
)
|
||
case "edge_tts":
|
||
from .sources.edge_tts_source import (
|
||
ProviderEdgeTTS as ProviderEdgeTTS,
|
||
)
|
||
case "gsv_tts_selfhost":
|
||
from .sources.gsv_selfhosted_source import (
|
||
ProviderGSVTTS as ProviderGSVTTS,
|
||
)
|
||
case "gsvi_tts_api":
|
||
from .sources.gsvi_tts_source import (
|
||
ProviderGSVITTS as ProviderGSVITTS,
|
||
)
|
||
case "fishaudio_tts_api":
|
||
from .sources.fishaudio_tts_api_source import (
|
||
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
|
||
)
|
||
case "dashscope_tts":
|
||
from .sources.dashscope_tts import (
|
||
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
|
||
)
|
||
case "azure_tts":
|
||
from .sources.azure_tts_source import (
|
||
AzureTTSProvider as AzureTTSProvider,
|
||
)
|
||
case "minimax_tts_api":
|
||
from .sources.minimax_tts_api_source import (
|
||
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
|
||
)
|
||
case "volcengine_tts":
|
||
from .sources.volcengine_tts import (
|
||
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
||
)
|
||
case "gemini_tts":
|
||
from .sources.gemini_tts_source import (
|
||
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
|
||
)
|
||
case "openai_embedding":
|
||
from .sources.openai_embedding_source import (
|
||
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
|
||
)
|
||
case "gemini_embedding":
|
||
from .sources.gemini_embedding_source import (
|
||
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
|
||
)
|
||
case "vllm_rerank":
|
||
from .sources.vllm_rerank_source import (
|
||
VLLMRerankProvider as VLLMRerankProvider,
|
||
)
|
||
case "xinference_rerank":
|
||
from .sources.xinference_rerank_source import (
|
||
XinferenceRerankProvider as XinferenceRerankProvider,
|
||
)
|
||
case "bailian_rerank":
|
||
from .sources.bailian_rerank_source import (
|
||
BailianRerankProvider as BailianRerankProvider,
|
||
)
|
||
|
||
def get_merged_provider_config(self, provider_config: dict) -> dict:
|
||
"""获取 provider 配置和 provider_source 配置合并后的结果
|
||
|
||
Returns:
|
||
dict: 合并后的 provider 配置,key 为 provider id,value 为合并后的配置字典
|
||
"""
|
||
pc = copy.deepcopy(provider_config)
|
||
provider_source_id = pc.get("provider_source_id", "")
|
||
if provider_source_id:
|
||
provider_source = None
|
||
for ps in self.provider_sources_config:
|
||
if ps.get("id") == provider_source_id:
|
||
provider_source = ps
|
||
break
|
||
|
||
if provider_source:
|
||
# 合并配置,provider 的配置优先级更高
|
||
merged_config = {**provider_source, **pc}
|
||
# 保持 id 为 provider 的 id,而不是 source 的 id
|
||
merged_config["id"] = pc["id"]
|
||
pc = merged_config
|
||
return pc
|
||
|
||
def _resolve_env_key_list(self, provider_config: dict) -> dict:
|
||
keys = provider_config.get("key", [])
|
||
if not isinstance(keys, list):
|
||
return provider_config
|
||
resolved_keys = []
|
||
for idx, key in enumerate(keys):
|
||
if isinstance(key, str) and key.startswith("$"):
|
||
env_key = key[1:]
|
||
if env_key.startswith("{") and env_key.endswith("}"):
|
||
env_key = env_key[1:-1]
|
||
if env_key:
|
||
env_val = os.getenv(env_key)
|
||
if env_val is None:
|
||
provider_id = provider_config.get("id")
|
||
logger.warning(
|
||
f"Provider {provider_id} 配置项 key[{idx}] 使用环境变量 {env_key} 但未设置。",
|
||
)
|
||
resolved_keys.append("")
|
||
else:
|
||
resolved_keys.append(env_val)
|
||
else:
|
||
resolved_keys.append(key)
|
||
else:
|
||
resolved_keys.append(key)
|
||
provider_config["key"] = resolved_keys
|
||
return provider_config
|
||
|
||
async def load_provider(self, provider_config: dict) -> None:
|
||
# 如果 provider_source_id 存在且不为空,则从 provider_sources 中找到对应的配置并合并
|
||
provider_config = self.get_merged_provider_config(provider_config)
|
||
|
||
if provider_config.get("provider_type", "") == "chat_completion":
|
||
provider_config = self._resolve_env_key_list(provider_config)
|
||
|
||
if not provider_config["enable"]:
|
||
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
|
||
return
|
||
if provider_config.get("provider_type", "") == "agent_runner":
|
||
return
|
||
|
||
logger.info(
|
||
f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ...",
|
||
)
|
||
|
||
# 动态导入
|
||
try:
|
||
self.dynamic_import_provider(provider_config["type"])
|
||
except (ImportError, ModuleNotFoundError) as e:
|
||
logger.critical(
|
||
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
|
||
exc_info=True,
|
||
)
|
||
return
|
||
except Exception as e:
|
||
logger.critical(
|
||
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因",
|
||
exc_info=True,
|
||
)
|
||
return
|
||
|
||
if provider_config["type"] not in provider_cls_map:
|
||
logger.error(
|
||
f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。",
|
||
exc_info=True,
|
||
)
|
||
return
|
||
|
||
provider_metadata = provider_cls_map[provider_config["type"]]
|
||
try:
|
||
# 按任务实例化提供商
|
||
cls_type = provider_metadata.cls_type
|
||
if not cls_type:
|
||
logger.error(f"无法找到 {provider_metadata.type} 的类")
|
||
return
|
||
|
||
provider_metadata.id = provider_config["id"]
|
||
|
||
match provider_metadata.provider_type:
|
||
case ProviderType.SPEECH_TO_TEXT:
|
||
# STT 任务
|
||
if not issubclass(cls_type, STTProvider):
|
||
raise TypeError(
|
||
f"Provider class {cls_type} is not a subclass of STTProvider"
|
||
)
|
||
inst = cls_type(provider_config, self.provider_settings)
|
||
|
||
if isinstance(inst, HasInitialize):
|
||
await inst.initialize()
|
||
|
||
self.stt_provider_insts.append(inst)
|
||
if (
|
||
self.provider_stt_settings.get("provider_id")
|
||
== provider_config["id"]
|
||
):
|
||
self.curr_stt_provider_inst = inst
|
||
logger.info(
|
||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
|
||
)
|
||
if not self.curr_stt_provider_inst:
|
||
self.curr_stt_provider_inst = inst
|
||
|
||
case ProviderType.TEXT_TO_SPEECH:
|
||
# TTS 任务
|
||
if not issubclass(cls_type, TTSProvider):
|
||
raise TypeError(
|
||
f"Provider class {cls_type} is not a subclass of TTSProvider"
|
||
)
|
||
inst = cls_type(provider_config, self.provider_settings)
|
||
|
||
if isinstance(inst, HasInitialize):
|
||
await inst.initialize()
|
||
|
||
self.tts_provider_insts.append(inst)
|
||
if (
|
||
self.provider_settings.get("provider_id")
|
||
== provider_config["id"]
|
||
):
|
||
self.curr_tts_provider_inst = inst
|
||
logger.info(
|
||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
|
||
)
|
||
if not self.curr_tts_provider_inst:
|
||
self.curr_tts_provider_inst = inst
|
||
|
||
case ProviderType.CHAT_COMPLETION:
|
||
# 文本生成任务
|
||
if not issubclass(cls_type, Provider):
|
||
raise TypeError(
|
||
f"Provider class {cls_type} is not a subclass of Provider"
|
||
)
|
||
inst = cls_type(
|
||
provider_config,
|
||
self.provider_settings,
|
||
)
|
||
|
||
if isinstance(inst, HasInitialize):
|
||
await inst.initialize()
|
||
|
||
self.provider_insts.append(inst)
|
||
if (
|
||
self.provider_settings.get("default_provider_id")
|
||
== provider_config["id"]
|
||
):
|
||
self.curr_provider_inst = inst
|
||
logger.info(
|
||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
|
||
)
|
||
if not self.curr_provider_inst:
|
||
self.curr_provider_inst = inst
|
||
|
||
case ProviderType.EMBEDDING:
|
||
if not issubclass(cls_type, EmbeddingProvider):
|
||
raise TypeError(
|
||
f"Provider class {cls_type} is not a subclass of EmbeddingProvider"
|
||
)
|
||
inst = cls_type(provider_config, self.provider_settings)
|
||
if isinstance(inst, HasInitialize):
|
||
await inst.initialize()
|
||
self.embedding_provider_insts.append(inst)
|
||
case ProviderType.RERANK:
|
||
if not issubclass(cls_type, RerankProvider):
|
||
raise TypeError(
|
||
f"Provider class {cls_type} is not a subclass of RerankProvider"
|
||
)
|
||
inst = cls_type(provider_config, self.provider_settings)
|
||
if isinstance(inst, HasInitialize):
|
||
await inst.initialize()
|
||
self.rerank_provider_insts.append(inst)
|
||
case _:
|
||
# 未知供应商抛出异常,确保inst初始化
|
||
# Should be unreachable
|
||
raise Exception(
|
||
f"未知的提供商类型:{provider_metadata.provider_type}"
|
||
)
|
||
|
||
self.inst_map[provider_config["id"]] = inst
|
||
except Exception as e:
|
||
logger.error(
|
||
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}",
|
||
)
|
||
raise Exception(
|
||
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}",
|
||
)
|
||
|
||
async def reload(self, provider_config: dict) -> None:
|
||
async with self.reload_lock:
|
||
await self.terminate_provider(provider_config["id"])
|
||
if provider_config["enable"]:
|
||
await self.load_provider(provider_config)
|
||
|
||
# 和配置文件保持同步
|
||
self.providers_config = astrbot_config["provider"]
|
||
self.provider_sources_config = astrbot_config.get("provider_sources", [])
|
||
config_ids = [provider["id"] for provider in self.providers_config]
|
||
logger.info(f"providers in user's config: {config_ids}")
|
||
for key in list(self.inst_map.keys()):
|
||
if key not in config_ids:
|
||
await self.terminate_provider(key)
|
||
|
||
if len(self.provider_insts) == 0:
|
||
self.curr_provider_inst = None
|
||
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
|
||
self.curr_provider_inst = self.provider_insts[0]
|
||
logger.info(
|
||
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
|
||
)
|
||
|
||
if len(self.stt_provider_insts) == 0:
|
||
self.curr_stt_provider_inst = None
|
||
elif (
|
||
self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0
|
||
):
|
||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||
logger.info(
|
||
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
|
||
)
|
||
|
||
if len(self.tts_provider_insts) == 0:
|
||
self.curr_tts_provider_inst = None
|
||
elif (
|
||
self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0
|
||
):
|
||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||
logger.info(
|
||
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
|
||
)
|
||
|
||
def get_insts(self):
|
||
return self.provider_insts
|
||
|
||
async def terminate_provider(self, provider_id: str) -> None:
|
||
if provider_id in self.inst_map:
|
||
logger.info(
|
||
f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ...",
|
||
)
|
||
|
||
if self.inst_map[provider_id] in self.provider_insts:
|
||
prov_inst = self.inst_map[provider_id]
|
||
if isinstance(prov_inst, Provider):
|
||
self.provider_insts.remove(prov_inst)
|
||
if self.inst_map[provider_id] in self.stt_provider_insts:
|
||
prov_inst = self.inst_map[provider_id]
|
||
if isinstance(prov_inst, STTProvider):
|
||
self.stt_provider_insts.remove(prov_inst)
|
||
if self.inst_map[provider_id] in self.tts_provider_insts:
|
||
prov_inst = self.inst_map[provider_id]
|
||
if isinstance(prov_inst, TTSProvider):
|
||
self.tts_provider_insts.remove(prov_inst)
|
||
|
||
if self.inst_map[provider_id] == self.curr_provider_inst:
|
||
self.curr_provider_inst = None
|
||
if self.inst_map[provider_id] == self.curr_stt_provider_inst:
|
||
self.curr_stt_provider_inst = None
|
||
if self.inst_map[provider_id] == self.curr_tts_provider_inst:
|
||
self.curr_tts_provider_inst = None
|
||
|
||
if getattr(self.inst_map[provider_id], "terminate", None):
|
||
await self.inst_map[provider_id].terminate() # type: ignore
|
||
|
||
logger.info(
|
||
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})",
|
||
)
|
||
del self.inst_map[provider_id]
|
||
|
||
async def delete_provider(
|
||
self, provider_id: str | None = None, provider_source_id: str | None = None
|
||
) -> None:
|
||
"""Delete provider and/or provider source from config and terminate the instances. Config will be saved after deletion."""
|
||
async with self.resource_lock:
|
||
# delete from config
|
||
target_prov_ids = []
|
||
if provider_id:
|
||
target_prov_ids.append(provider_id)
|
||
else:
|
||
for prov in self.providers_config:
|
||
if prov.get("provider_source_id") == provider_source_id:
|
||
target_prov_ids.append(prov.get("id"))
|
||
config = self.acm.default_conf
|
||
for tpid in target_prov_ids:
|
||
await self.terminate_provider(tpid)
|
||
config["provider"] = [
|
||
prov for prov in config["provider"] if prov.get("id") != tpid
|
||
]
|
||
config.save_config()
|
||
logger.info(f"Provider {target_prov_ids} 已从配置中删除。")
|
||
|
||
async def update_provider(self, origin_provider_id: str, new_config: dict) -> None:
|
||
"""Update provider config and reload the instance. Config will be saved after update."""
|
||
async with self.resource_lock:
|
||
npid = new_config.get("id", None)
|
||
if not npid:
|
||
raise ValueError("New provider config must have an 'id' field")
|
||
config = self.acm.default_conf
|
||
for provider in config["provider"]:
|
||
if (
|
||
provider.get("id", None) == npid
|
||
and provider.get("id", None) != origin_provider_id
|
||
):
|
||
raise ValueError(f"Provider ID {npid} already exists")
|
||
# update config
|
||
for idx, provider in enumerate(config["provider"]):
|
||
if provider.get("id", None) == origin_provider_id:
|
||
config["provider"][idx] = new_config
|
||
break
|
||
else:
|
||
raise ValueError(f"Provider ID {origin_provider_id} not found")
|
||
config.save_config()
|
||
# reload instance
|
||
await self.reload(new_config)
|
||
|
||
async def create_provider(self, new_config: dict) -> None:
|
||
"""Add new provider config and load the instance. Config will be saved after addition."""
|
||
async with self.resource_lock:
|
||
npid = new_config.get("id", None)
|
||
if not npid:
|
||
raise ValueError("New provider config must have an 'id' field")
|
||
config = self.acm.default_conf
|
||
for provider in config["provider"]:
|
||
if provider.get("id", None) == npid:
|
||
raise ValueError(f"Provider ID {npid} already exists")
|
||
# add to config
|
||
config["provider"].append(new_config)
|
||
config.save_config()
|
||
# load instance
|
||
await self.load_provider(new_config)
|
||
|
||
async def terminate(self) -> None:
|
||
if self._mcp_init_task and not self._mcp_init_task.done():
|
||
self._mcp_init_task.cancel()
|
||
try:
|
||
await self._mcp_init_task
|
||
except asyncio.CancelledError:
|
||
pass
|
||
|
||
for provider_inst in self.provider_insts:
|
||
if hasattr(provider_inst, "terminate"):
|
||
await provider_inst.terminate() # type: ignore
|
||
try:
|
||
await self.llm_tools.disable_mcp_server()
|
||
except Exception:
|
||
logger.error("Error while disabling MCP servers", exc_info=True)
|