0711ec346f
* fix: resolve MCP tools race condition causing 'completion 无法解析' ("completion could not be parsed") error

- Wait for MCP client initialization to complete before accepting requests
- Add Future-based synchronization in init_mcp_clients()
- Prevent tool_calls from being rejected due to empty func_list
- Improve error logging for MCP initialization failures

Fixes a race condition where the AI attempts to call MCP tools before they are registered, resulting in 'API 返回的 completion 无法解析' ("the completion returned by the API could not be parsed") exceptions.

The issue occurred because:
1. MCP clients were initialized asynchronously without waiting
2. The system accepted user requests immediately after startup
3. The AI received an empty tool list and attempted to call non-existent tools
4. Tool matching failed, causing parsing errors

This fix ensures all MCP tools are loaded before the system processes any requests that might use them.

* perf: add timeout and better error handling for MCP initialization

- Add a 20-second total timeout to prevent slow MCP servers from blocking startup
- Show detailed configuration info when MCP initialization fails
- List all failed services in a summary warning
- Gracefully handle timeout by using already-completed services

This ensures that even if some MCP servers are slow or unreachable, the system starts within a reasonable time and reports clearly which services failed and why.
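The two entries above describe a bounded startup wait. Here is a minimal plain-`asyncio` sketch of it. This is not AstrBot's `init_mcp_clients()`. It assumes each MCP server is represented by a hypothetical coroutine function that returns once the server is connected, or raises on failure. The helper name `init_all_with_timeout` and its return shape are illustrative.

```python
import asyncio
from collections.abc import Callable, Coroutine
from typing import Any

# Hypothetical initializer type: a coroutine function that returns once the server is usable.
Initializer = Callable[[], Coroutine[Any, Any, None]]


async def init_all_with_timeout(
    inits: dict[str, Initializer],
    total_timeout: float = 20.0,
) -> tuple[list[str], dict[str, str]]:
    """Start all initializers concurrently and wait at most ``total_timeout`` seconds.

    Returns the names of servers that finished initializing and a mapping of
    failed server names to a short reason, so the caller can log one summary.
    """
    tasks = {asyncio.create_task(init()): name for name, init in inits.items()}
    if not tasks:
        return [], {}

    # One shared deadline for all servers instead of one timeout per server.
    done, pending = await asyncio.wait(set(tasks), timeout=total_timeout)

    ready: list[str] = []
    failed: dict[str, str] = {}
    for task in done:
        exc = task.exception()
        if exc is None:
            ready.append(tasks[task])
        else:
            failed[tasks[task]] = f"{type(exc).__name__}: {exc}"

    # Servers still connecting at the deadline are cancelled and reported as
    # timed out; the ones that already finished stay usable.
    for task in pending:
        task.cancel()
        failed[tasks[task]] = "timed out"
    await asyncio.gather(*pending, return_exceptions=True)

    return sorted(ready), failed
```

In this sketch, requests that need MCP tools would be accepted only after the coroutine returns. That ordering closes the race: by the time the first request arrives, the tool list is populated, or the failures are already known.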
* refactor: simplify MCP init orchestration and improve log security

- Replace Future-based sync with asyncio.wait + name→task mapping
- Explicitly cancel timed-out tasks after 20s timeout
- Downgrade sensitive config details (command/args/URL) to debug level
- Move urllib.parse import to top-level

* fix: prevent initialized MCP clients from being cleaned up on timeout

- Do not cancel pending tasks on timeout; let them continue running in the background waiting for the termination signal (event.set()), so successfully initialized services remain available
- Track initialization state with a flag to distinguish init failures from post-init cancellations in _init_mcp_client_task_wrapper

* fix: restore task cancellation on timeout per review feedback

Pending tasks in asyncio.wait are tasks that have NOT completed initialization within 20s, so cancelling them is safe and correct.

* fix: separate init signal from client lifetime in MCP task wrapper

The previous design awaited task completion, but tasks only finish on shutdown (after event.wait()), causing asyncio.wait to always hit the 20s timeout and cancel all clients.

Fix: introduce a dedicated ready_event that is set immediately after _init_mcp_client completes. init_mcp_clients now waits only for ready_event (with 20s timeout), while the long-lived client task continues running in the background until shutdown_event is set. This ensures startup returns promptly once clients are ready.

* security: redact sensitive MCP config from debug logs

Only log executable name and argument count instead of full command/args to avoid leaking tokens or credentials even at debug level.
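Below is a minimal sketch of the final design from the entries above. Each client gets its own `ready_event`, set as soon as initialization finishes, fails, or is cancelled. Startup waits only on those events, with a timeout. The long-lived client task keeps running until a shared `shutdown_event` is set. The names `client_lifetime` and `start_clients` and the `connect` coroutine are hypothetical. This is not AstrBot's `_init_mcp_client_task_wrapper` or `init_mcp_clients()`.

```python
import asyncio
from collections.abc import Callable, Coroutine
from typing import Any

Connect = Callable[[], Coroutine[Any, Any, None]]  # hypothetical "open the MCP session" call


async def client_lifetime(
    name: str,
    connect: Connect,
    ready_event: asyncio.Event,
    shutdown_event: asyncio.Event,
    errors: dict[str, BaseException],
) -> None:
    """Initialize one client, signal readiness, then stay alive until shutdown."""
    try:
        await connect()
    except Exception as exc:  # CancelledError is a BaseException, so it propagates
        errors[name] = exc
        return
    finally:
        # Set on success, failure or cancellation, so a waiter never hangs.
        ready_event.set()
    # Long-lived part: the session stays open until shutdown is requested.
    await shutdown_event.wait()


async def start_clients(
    connectors: dict[str, Connect], timeout: float = 20.0
) -> tuple[asyncio.Event, dict[str, asyncio.Task], dict[str, BaseException]]:
    shutdown_event = asyncio.Event()
    errors: dict[str, BaseException] = {}
    ready_events = {name: asyncio.Event() for name in connectors}
    lifetimes = {
        name: asyncio.create_task(
            client_lifetime(name, connect, ready_events[name], shutdown_event, errors)
        )
        for name, connect in connectors.items()
    }

    # Wait for the ready signals, not for the lifetime tasks (those end only at shutdown).
    waiters = [asyncio.create_task(ev.wait()) for ev in ready_events.values()]
    if waiters:
        _, pending = await asyncio.wait(waiters, timeout=timeout)
        for waiter in pending:
            waiter.cancel()
        await asyncio.gather(*pending, return_exceptions=True)

    # Clients that never became ready are cancelled and reported as timed out.
    timed_out = [name for name, ev in ready_events.items() if not ev.is_set()]
    for name in timed_out:
        lifetimes[name].cancel()
        errors[name] = TimeoutError(f"{name} not ready after {timeout}s")
    await asyncio.gather(*(lifetimes[n] for n in timed_out), return_exceptions=True)
    return shutdown_event, lifetimes, errors
```

Keeping the ready signal separate from task completion is what lets startup return in well under the timeout. Successfully connected clients remain alive in the background afterwards.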
* refactor: use McpClientInfo dataclass and MCP_INIT_TIMEOUT constant

- Extract MCP_INIT_TIMEOUT = 20.0 as a named module-level constant
- Replace tuple-based client_info with _McpClientInfo dataclass to eliminate index-based access and improve readability
- Remove _wait_ready helper; use asyncio.create_task(event.wait()) directly
- Await cancelled tasks after timeout to prevent lingering background tasks and unobserved exceptions

* fix: handle CancelledError and clean up wait_tasks on timeout

- Catch asyncio.CancelledError separately in _init_mcp_client_task_wrapper so ready_event.set() is always called (Python 3.8+ CancelledError inherits BaseException, not Exception)
- Cancel and await lingering wait_tasks after timeout to prevent them from hanging indefinitely when ready_event is never set

* fix: align enable_mcp_server with new wrapper API and fix security/config issues

- Fix enable_mcp_server to pass shutdown_event + ready_event instead of ready_future, matching _init_mcp_client_task_wrapper's current signature
- Cancel and await init_task on timeout; clean up mcp_client_event on failure
- Read MCP_INIT_TIMEOUT from env var ASTRBOT_MCP_INIT_TIMEOUT (default 20s) so operators can tune it without code changes
- Strip userinfo from URL in debug log (use hostname+port only, not netloc) to avoid leaking credentials embedded in URLs

* refactor: register mcp_client_event only after successful init in enable_mcp_server

Move self.mcp_client_event[name] assignment to after initialization succeeds, so callers never observe a stale event for a failed client.
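The entries above give two redaction rules. Command-based servers log only the executable name and argument count. URLs log hostname and port rather than `netloc`. Both rules can be sketched with the standard library alone. The function names and the exact output format are assumptions for illustration, not AstrBot's logging code.

```python
import os
from urllib.parse import urlsplit


def describe_stdio_server(command: str, args: list[str]) -> str:
    """Log-safe summary of a stdio MCP server: executable name and argument count only."""
    executable = os.path.basename(command) if command else "<none>"
    return f"executable={executable}, argc={len(args)}"


def describe_url(url: str) -> str:
    """Log-safe summary of an HTTP/SSE MCP server URL.

    Built from hostname and port instead of ``netloc``, so ``user:password@``
    userinfo is dropped; path and query (which may carry tokens) are dropped too.
    """
    parts = urlsplit(url)
    host = parts.hostname or "<unknown>"
    if ":" in host:  # re-bracket IPv6 literals
        host = f"[{host}]"
    try:
        port = parts.port
    except ValueError:  # malformed or out-of-range port
        port = None
    suffix = f":{port}" if port is not None else ""
    return f"{parts.scheme}://{host}{suffix}"
```

Using `hostname`/`port` instead of slicing `netloc` avoids hand-parsing the `@` and `:` separators. The same code path also works for IPv6 hosts.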
* fix: harden MCP init state handling and timeout parsing

* fix: improve MCP timeout and post-init error observability

* refactor: simplify MCP init lifecycle orchestration

* refactor: simplify MCP init flow and cap timeout values

* fix: refine mcp timeout handling and lifecycle task tracking

* fix: harden mcp shutdown and timeout source logging

* refactor: simplify mcp runtime registry and timeout flow

* fix: keep mcp init summary return contract

* refactor: streamline mcp lifecycle and init errors

* refactor: unify mcp lifecycle wait handling

* refactor: simplify mcp runtime ownership and timeout resolution

* fix: harden mcp shutdown waiting and startup signaling

* refactor: streamline mcp lifecycle and shutdown errors

* refactor: harden mcp runtime access and shutdown

* fix: ensure mcp client cleanup and clarify views

* refactor: cache mcp client view and guard startup

* refactor: simplify mcp init cleanup and runtime lock

* refactor: reduce mcp runtime duplication

* refactor: reuse mcp cleanup and client view

---------

Co-authored-by: idiotsj <idiotsj@users.noreply.github.com>
Co-authored-by: 邹永赫 <1259085392@qq.com>
827 lines
34 KiB
Python
import asyncio
import copy
import os
import traceback
from collections.abc import Callable
from typing import Protocol, runtime_checkable

from astrbot.core import astrbot_config, logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.error_redaction import safe_error

from ..persona_mgr import PersonaManager
from .entities import ProviderType
from .provider import (
    EmbeddingProvider,
    Provider,
    Providers,
    RerankProvider,
    STTProvider,
    TTSProvider,
)
from .register import llm_tools, provider_cls_map


@runtime_checkable
class HasInitialize(Protocol):
    async def initialize(self) -> None: ...


class ProviderManager:
    def __init__(
        self,
        acm: AstrBotConfigManager,
        db_helper: BaseDatabase,
        persona_mgr: PersonaManager,
    ) -> None:
        self.reload_lock = asyncio.Lock()
        self.resource_lock = asyncio.Lock()
        self.persona_mgr = persona_mgr
        self.acm = acm
        config = acm.confs["default"]
        self.providers_config: list = config["provider"]
        self.provider_sources_config: list = config.get("provider_sources", [])
        self.provider_settings: dict = config["provider_settings"]
        self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
        self.provider_tts_settings: dict = config.get("provider_tts_settings", {})

        # Persona-related attribute, deprecated since v4.0.0; use PersonaManager instead.
        self.default_persona_name = persona_mgr.default_persona

        self.provider_insts: list[Provider] = []
        """Loaded Provider instances."""
        self.stt_provider_insts: list[STTProvider] = []
        """Loaded Speech-To-Text Provider instances."""
        self.tts_provider_insts: list[TTSProvider] = []
        """Loaded Text-To-Speech Provider instances."""
        self.embedding_provider_insts: list[EmbeddingProvider] = []
        """Loaded Embedding Provider instances."""
        self.rerank_provider_insts: list[RerankProvider] = []
        """Loaded Rerank Provider instances."""
        self.inst_map: dict[
            str,
            Providers,
        ] = {}
        """Provider instance map. key: provider_id, value: Provider instance."""
        self.llm_tools = llm_tools

        self.curr_provider_inst: Provider | None = None
        """Default Provider instance. Deprecated; use get_using_provider() to get the Provider currently in use."""
        self.curr_stt_provider_inst: STTProvider | None = None
        """Default Speech-To-Text Provider instance. Deprecated; use get_using_provider() to get the provider currently in use."""
        self.curr_tts_provider_inst: TTSProvider | None = None
        """Default Text-To-Speech Provider instance. Deprecated; use get_using_provider() to get the provider currently in use."""
        self.db_helper = db_helper
        self._provider_change_callback: (
            Callable[[str, ProviderType, str | None], None] | None
        ) = None
        self._provider_change_hooks: list[
            Callable[[str, ProviderType, str | None], None]
        ] = []

    def set_provider_change_callback(
        self,
        cb: Callable[[str, ProviderType, str | None], None] | None,
    ) -> None:
        # Backward-compatible single-callback setter.
        # This callback coexists with register_provider_change_hook subscriptions.
        self._provider_change_callback = cb

    def register_provider_change_hook(
        self,
        hook: Callable[[str, ProviderType, str | None], None],
    ) -> None:
        if hook not in self._provider_change_hooks:
            self._provider_change_hooks.append(hook)

    def _notify_provider_changed(
        self,
        provider_id: str,
        provider_type: ProviderType,
        umo: str | None,
    ) -> None:
        if self._provider_change_callback is not None:
            try:
                self._provider_change_callback(provider_id, provider_type, umo)
            except Exception as e:
                logger.warning(
                    "Provider change callback failed: provider_id=%s, type=%s, err=%s",
                    provider_id,
                    provider_type,
                    safe_error("", e),
                )
        for hook in list(self._provider_change_hooks):
            # Skip the legacy callback if it was also registered as a hook, so it is
            # not called twice; `==` (not `is`) also matches equal bound methods.
            if hook == self._provider_change_callback:
                continue
            try:
                hook(provider_id, provider_type, umo)
            except Exception as e:
                logger.warning(
                    "Provider change hook failed: provider_id=%s, type=%s, err=%s",
                    provider_id,
                    provider_type,
                    safe_error("", e),
                )

    @property
    def persona_configs(self) -> list:
        """Dynamically fetch the latest persona configs."""
        return self.persona_mgr.persona_v3_config

    @property
    def personas(self) -> list:
        """Dynamically fetch the latest list of personas."""
        return self.persona_mgr.personas_v3

    @property
    def selected_default_persona(self):
        """Dynamically fetch the latest selected default persona. Deprecated; use context.persona_mgr.get_default_persona_v3() instead."""
        return self.persona_mgr.selected_default_persona_v3

    async def set_provider(
        self,
        provider_id: str,
        provider_type: ProviderType,
        umo: str | None = None,
    ) -> None:
        """Set the provider.

        Args:
            provider_id (str): Provider ID.
            provider_type (ProviderType): Provider type.
            umo (str, optional): User session ID, used for per-session provider isolation.

        Version 4.0.0: providers are isolated per session by default starting with this version.

        """
        if provider_id not in self.inst_map:
            raise ValueError(f"Provider {provider_id} does not exist and cannot be set.")
        if umo:
            await sp.session_put(
                umo,
                f"provider_perf_{provider_type.value}",
                provider_id,
            )
            self._notify_provider_changed(provider_id, provider_type, umo)
            return
        # No session given (provider session isolation not in use): set the global provider

        prov = self.inst_map[provider_id]
        if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance(
            prov,
            TTSProvider,
        ):
            self.curr_tts_provider_inst = prov
            await sp.put_async(
                key="curr_provider_tts",
                value=provider_id,
                scope="global",
                scope_id="global",
            )
            self._notify_provider_changed(provider_id, provider_type, umo)
        elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
            prov,
            STTProvider,
        ):
            self.curr_stt_provider_inst = prov
            await sp.put_async(
                key="curr_provider_stt",
                value=provider_id,
                scope="global",
                scope_id="global",
            )
            self._notify_provider_changed(provider_id, provider_type, umo)
        elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
            prov,
            Provider,
        ):
            self.curr_provider_inst = prov
            await sp.put_async(
                key="curr_provider",
                value=provider_id,
                scope="global",
                scope_id="global",
            )
            self._notify_provider_changed(provider_id, provider_type, umo)

    async def get_provider_by_id(self, provider_id: str) -> Providers | None:
        """Get a provider instance by its provider ID."""
        return self.inst_map.get(provider_id)

    def get_using_provider(
        self, provider_type: ProviderType, umo=None
    ) -> Providers | None:
        """Get the provider instance currently in use.

        Args:
            provider_type (ProviderType): Provider type.
            umo (str, optional): User session ID, used for per-session provider isolation.

        Returns:
            Providers | None: The provider instance in use, or None if none is available.

        """
        provider = None
        provider_id = None
        if umo:
            provider_id = sp.get(
                f"provider_perf_{provider_type.value}",
                None,
                scope="umo",
                scope_id=umo,
            )
            if provider_id:
                provider = self.inst_map.get(provider_id)
        if not provider:
            # default setting
            config = self.acm.get_conf(umo)
            if provider_type == ProviderType.CHAT_COMPLETION:
                provider_id = config["provider_settings"].get("default_provider_id")
                provider = self.inst_map.get(provider_id)
                if not provider:
                    provider = self.provider_insts[0] if self.provider_insts else None
            elif provider_type == ProviderType.SPEECH_TO_TEXT:
                provider_id = config["provider_stt_settings"].get("provider_id")
                if not provider_id:
                    return None
                provider = self.inst_map.get(provider_id)
                if not provider:
                    provider = (
                        self.stt_provider_insts[0] if self.stt_provider_insts else None
                    )
            elif provider_type == ProviderType.TEXT_TO_SPEECH:
                provider_id = config["provider_tts_settings"].get("provider_id")
                if not provider_id:
                    return None
                provider = self.inst_map.get(provider_id)
                if not provider:
                    provider = (
                        self.tts_provider_insts[0] if self.tts_provider_insts else None
                    )
            else:
                raise ValueError(f"Unknown provider type: {provider_type}")

        if not provider and provider_id:
            logger.warning(
                f"Provider with ID {provider_id} not found; this may be because the provider (model) ID was changed."
            )

        return provider

    async def initialize(self) -> None:
        # Initialize providers one by one
        for provider_config in self.providers_config:
            try:
                await self.load_provider(provider_config)
            except Exception as e:
                logger.error(traceback.format_exc())
                logger.error(e)

        selected_provider_id = await sp.get_async(
            key="curr_provider",
            default=self.provider_settings.get("default_provider_id"),
            scope="global",
            scope_id="global",
        )
        selected_stt_provider_id = await sp.get_async(
            key="curr_provider_stt",
            default=self.provider_stt_settings.get("provider_id"),
            scope="global",
            scope_id="global",
        )
        selected_tts_provider_id = await sp.get_async(
            key="curr_provider_tts",
            default=self.provider_tts_settings.get("provider_id"),
            scope="global",
            scope_id="global",
        )

        temp_provider = (
            self.inst_map.get(selected_provider_id)
            if isinstance(selected_provider_id, str)
            else None
        )
        self.curr_provider_inst = (
            temp_provider if isinstance(temp_provider, Provider) else None
        )
        if not self.curr_provider_inst and self.provider_insts:
            self.curr_provider_inst = self.provider_insts[0]

        temp_stt = (
            self.inst_map.get(selected_stt_provider_id)
            if isinstance(selected_stt_provider_id, str)
            else None
        )
        self.curr_stt_provider_inst = (
            temp_stt if isinstance(temp_stt, STTProvider) else None
        )
        if not self.curr_stt_provider_inst and self.stt_provider_insts:
            self.curr_stt_provider_inst = self.stt_provider_insts[0]

        temp_tts = (
            self.inst_map.get(selected_tts_provider_id)
            if isinstance(selected_tts_provider_id, str)
            else None
        )
        self.curr_tts_provider_inst = (
            temp_tts if isinstance(temp_tts, TTSProvider) else None
        )
        if not self.curr_tts_provider_inst and self.tts_provider_insts:
            self.curr_tts_provider_inst = self.tts_provider_insts[0]

        # Initialize MCP client connections (wait for completion so the tools are available)
        strict_mcp_init = os.getenv("ASTRBOT_MCP_INIT_STRICT", "").strip().lower() in {
            "1",
            "true",
            "yes",
            "on",
        }
        mcp_init_summary = await self.llm_tools.init_mcp_clients(
            raise_on_all_failed=strict_mcp_init
        )
        if (
            mcp_init_summary.total > 0
            and mcp_init_summary.success == 0
            and not strict_mcp_init
        ):
            logger.warning(
                "All MCP services failed to initialize; startup will continue (set "
                "ASTRBOT_MCP_INIT_STRICT=1 to abort startup in this case)."
            )

    def dynamic_import_provider(self, type: str) -> None:
        """Dynamically import a provider adapter module.

        Args:
            type (str): Provider adapter type (the `type` field of a provider config).

        Raises:
            ImportError: If the adapter module for this type cannot be imported (for example, missing dependencies). Unknown types are not rejected here; load_provider reports them when the type is absent from provider_cls_map.
        """
        # Importing a source module registers its adapter in provider_cls_map;
        # the imported names themselves are not used here.
        match type:
            case "openai_chat_completion":
                from .sources.openai_source import (
                    ProviderOpenAIOfficial as ProviderOpenAIOfficial,
                )
            case "zhipu_chat_completion":
                from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
            case "groq_chat_completion":
                from .sources.groq_source import ProviderGroq as ProviderGroq
            case "xai_chat_completion":
                from .sources.xai_source import ProviderXAI as ProviderXAI
            case "aihubmix_chat_completion":
                from .sources.oai_aihubmix_source import (
                    ProviderAIHubMix as ProviderAIHubMix,
                )
            case "openrouter_chat_completion":
                from .sources.openrouter_source import (
                    ProviderOpenRouter as ProviderOpenRouter,
                )
            case "anthropic_chat_completion":
                from .sources.anthropic_source import (
                    ProviderAnthropic as ProviderAnthropic,
                )
            case "googlegenai_chat_completion":
                from .sources.gemini_source import (
                    ProviderGoogleGenAI as ProviderGoogleGenAI,
                )
            case "sensevoice_stt_selfhost":
                from .sources.sensevoice_selfhosted_source import (
                    ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost,
                )
            case "openai_whisper_api":
                from .sources.whisper_api_source import (
                    ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI,
                )
            case "openai_whisper_selfhost":
                from .sources.whisper_selfhosted_source import (
                    ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
                )
            case "xinference_stt":
                from .sources.xinference_stt_provider import (
                    ProviderXinferenceSTT as ProviderXinferenceSTT,
                )
            case "openai_tts_api":
                from .sources.openai_tts_api_source import (
                    ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
                )
            case "genie_tts":
                from .sources.genie_tts import (
                    GenieTTSProvider as GenieTTSProvider,
                )
            case "edge_tts":
                from .sources.edge_tts_source import (
                    ProviderEdgeTTS as ProviderEdgeTTS,
                )
            case "gsv_tts_selfhost":
                from .sources.gsv_selfhosted_source import (
                    ProviderGSVTTS as ProviderGSVTTS,
                )
            case "gsvi_tts_api":
                from .sources.gsvi_tts_source import (
                    ProviderGSVITTS as ProviderGSVITTS,
                )
            case "fishaudio_tts_api":
                from .sources.fishaudio_tts_api_source import (
                    ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
                )
            case "dashscope_tts":
                from .sources.dashscope_tts import (
                    ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
                )
            case "azure_tts":
                from .sources.azure_tts_source import (
                    AzureTTSProvider as AzureTTSProvider,
                )
            case "minimax_tts_api":
                from .sources.minimax_tts_api_source import (
                    ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
                )
            case "volcengine_tts":
                from .sources.volcengine_tts import (
                    ProviderVolcengineTTS as ProviderVolcengineTTS,
                )
            case "gemini_tts":
                from .sources.gemini_tts_source import (
                    ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
                )
            case "openai_embedding":
                from .sources.openai_embedding_source import (
                    OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
                )
            case "gemini_embedding":
                from .sources.gemini_embedding_source import (
                    GeminiEmbeddingProvider as GeminiEmbeddingProvider,
                )
            case "vllm_rerank":
                from .sources.vllm_rerank_source import (
                    VLLMRerankProvider as VLLMRerankProvider,
                )
            case "xinference_rerank":
                from .sources.xinference_rerank_source import (
                    XinferenceRerankProvider as XinferenceRerankProvider,
                )
            case "bailian_rerank":
                from .sources.bailian_rerank_source import (
                    BailianRerankProvider as BailianRerankProvider,
                )

    def get_merged_provider_config(self, provider_config: dict) -> dict:
        """Merge a provider config with the provider_source config it references.

        Returns:
            dict: A deep copy of the provider config. If its provider_source_id matches an existing provider source, that source's fields are merged in; the provider's own fields take precedence and the provider's id is kept.
        """
        pc = copy.deepcopy(provider_config)
        provider_source_id = pc.get("provider_source_id", "")
        if provider_source_id:
            provider_source = None
            for ps in self.provider_sources_config:
                if ps.get("id") == provider_source_id:
                    provider_source = ps
                    break

            if provider_source:
                # Merge configs; the provider's own settings take precedence
                merged_config = {**provider_source, **pc}
                # Keep the provider's id rather than the source's id
                merged_config["id"] = pc["id"]
                pc = merged_config
        return pc

    def _resolve_env_key_list(self, provider_config: dict) -> dict:
        """Resolve `$VAR` / `${VAR}` entries in provider_config["key"] from environment variables.

        Modifies and returns provider_config. A referenced variable that is not set
        resolves to an empty string and logs a warning; other entries are kept as-is.
        """
        keys = provider_config.get("key", [])
        if not isinstance(keys, list):
            return provider_config
        resolved_keys = []
        for idx, key in enumerate(keys):
            if isinstance(key, str) and key.startswith("$"):
                env_key = key[1:]
                if env_key.startswith("{") and env_key.endswith("}"):
                    env_key = env_key[1:-1]
                if env_key:
                    env_val = os.getenv(env_key)
                    if env_val is None:
                        provider_id = provider_config.get("id")
                        logger.warning(
                            f"Provider {provider_id}: config item key[{idx}] refers to environment variable {env_key}, but it is not set.",
                        )
                        resolved_keys.append("")
                    else:
                        resolved_keys.append(env_val)
                else:
                    resolved_keys.append(key)
            else:
                resolved_keys.append(key)
        provider_config["key"] = resolved_keys
        return provider_config

    async def load_provider(self, provider_config: dict) -> None:
        # If provider_source_id is set and non-empty, look up the matching entry in provider_sources and merge it in
        provider_config = self.get_merged_provider_config(provider_config)

        if provider_config.get("provider_type", "") == "chat_completion":
            provider_config = self._resolve_env_key_list(provider_config)

        if not provider_config["enable"]:
            logger.info(f"Provider {provider_config['id']} is disabled, skipping")
            return
        if provider_config.get("provider_type", "") == "agent_runner":
            return

        logger.info(
            f"Loading provider {provider_config['type']}({provider_config['id']}) ...",
        )

        # Dynamic import
        try:
            self.dynamic_import_provider(provider_config["type"])
        except (ImportError, ModuleNotFoundError) as e:
            logger.critical(
                f"Failed to load provider adapter {provider_config['type']}({provider_config['id']}): {e}. Some dependencies may not be installed.",
                exc_info=True,
            )
            return
        except Exception as e:
            logger.critical(
                f"Failed to load provider adapter {provider_config['type']}({provider_config['id']}): {e}. Unknown cause.",
                exc_info=True,
            )
            return

        if provider_config["type"] not in provider_cls_map:
            logger.error(
                f"No provider adapter found for {provider_config['type']}({provider_config['id']}); check that it is installed and that the type name is correct. Skipped.",
            )
            return

        provider_metadata = provider_cls_map[provider_config["type"]]
        try:
            # Instantiate the provider according to its task type
            cls_type = provider_metadata.cls_type
            if not cls_type:
                logger.error(f"Cannot find the class for {provider_metadata.type}")
                return

            provider_metadata.id = provider_config["id"]

            match provider_metadata.provider_type:
                case ProviderType.SPEECH_TO_TEXT:
                    # STT task
                    if not issubclass(cls_type, STTProvider):
                        raise TypeError(
                            f"Provider class {cls_type} is not a subclass of STTProvider"
                        )
                    inst = cls_type(provider_config, self.provider_settings)

                    if isinstance(inst, HasInitialize):
                        await inst.initialize()

                    self.stt_provider_insts.append(inst)
                    if (
                        self.provider_stt_settings.get("provider_id")
                        == provider_config["id"]
                    ):
                        self.curr_stt_provider_inst = inst
                        logger.info(
                            f"Selected {provider_config['type']}({provider_config['id']}) as the current speech-to-text provider adapter.",
                        )
                    if not self.curr_stt_provider_inst:
                        self.curr_stt_provider_inst = inst

                case ProviderType.TEXT_TO_SPEECH:
                    # TTS task
                    if not issubclass(cls_type, TTSProvider):
                        raise TypeError(
                            f"Provider class {cls_type} is not a subclass of TTSProvider"
                        )
                    inst = cls_type(provider_config, self.provider_settings)

                    if isinstance(inst, HasInitialize):
                        await inst.initialize()

                    self.tts_provider_insts.append(inst)
                    # The TTS provider_id lives in provider_tts_settings, mirroring the STT case.
                    if (
                        self.provider_tts_settings.get("provider_id")
                        == provider_config["id"]
                    ):
                        self.curr_tts_provider_inst = inst
                        logger.info(
                            f"Selected {provider_config['type']}({provider_config['id']}) as the current text-to-speech provider adapter.",
                        )
                    if not self.curr_tts_provider_inst:
                        self.curr_tts_provider_inst = inst

                case ProviderType.CHAT_COMPLETION:
                    # Text generation task
                    if not issubclass(cls_type, Provider):
                        raise TypeError(
                            f"Provider class {cls_type} is not a subclass of Provider"
                        )
                    inst = cls_type(
                        provider_config,
                        self.provider_settings,
                    )

                    if isinstance(inst, HasInitialize):
                        await inst.initialize()

                    self.provider_insts.append(inst)
                    if (
                        self.provider_settings.get("default_provider_id")
                        == provider_config["id"]
                    ):
                        self.curr_provider_inst = inst
                        logger.info(
                            f"Selected {provider_config['type']}({provider_config['id']}) as the current provider adapter.",
                        )
                    if not self.curr_provider_inst:
                        self.curr_provider_inst = inst

                case ProviderType.EMBEDDING:
                    if not issubclass(cls_type, EmbeddingProvider):
                        raise TypeError(
                            f"Provider class {cls_type} is not a subclass of EmbeddingProvider"
                        )
                    inst = cls_type(provider_config, self.provider_settings)
                    if isinstance(inst, HasInitialize):
                        await inst.initialize()
                    self.embedding_provider_insts.append(inst)
                case ProviderType.RERANK:
                    if not issubclass(cls_type, RerankProvider):
                        raise TypeError(
                            f"Provider class {cls_type} is not a subclass of RerankProvider"
                        )
                    inst = cls_type(provider_config, self.provider_settings)
                    if isinstance(inst, HasInitialize):
                        await inst.initialize()
                    self.rerank_provider_insts.append(inst)
                case _:
                    # Unknown provider type: raise so that `inst` is always bound below
                    # Should be unreachable
                    raise Exception(
                        f"Unknown provider type: {provider_metadata.provider_type}"
                    )

            self.inst_map[provider_config["id"]] = inst
        except Exception as e:
            logger.error(
                f"Failed to instantiate provider adapter {provider_config['type']}({provider_config['id']}): {e}",
            )
            raise Exception(
                f"Failed to instantiate provider adapter {provider_config['type']}({provider_config['id']}): {e}",
            ) from e

    async def reload(self, provider_config: dict) -> None:
        async with self.reload_lock:
            await self.terminate_provider(provider_config["id"])
            if provider_config["enable"]:
                await self.load_provider(provider_config)

            # Keep in sync with the config file
            self.providers_config = astrbot_config["provider"]
            self.provider_sources_config = astrbot_config.get("provider_sources", [])
            config_ids = [provider["id"] for provider in self.providers_config]
            logger.info(f"providers in user's config: {config_ids}")
            for key in list(self.inst_map.keys()):
                if key not in config_ids:
                    await self.terminate_provider(key)

            if len(self.provider_insts) == 0:
                self.curr_provider_inst = None
            elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
                self.curr_provider_inst = self.provider_insts[0]
                logger.info(
                    f"Automatically selected {self.curr_provider_inst.meta().id} as the current provider adapter.",
                )

            if len(self.stt_provider_insts) == 0:
                self.curr_stt_provider_inst = None
            elif (
                self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0
            ):
                self.curr_stt_provider_inst = self.stt_provider_insts[0]
                logger.info(
                    f"Automatically selected {self.curr_stt_provider_inst.meta().id} as the current speech-to-text provider adapter.",
                )

            if len(self.tts_provider_insts) == 0:
                self.curr_tts_provider_inst = None
            elif (
                self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0
            ):
                self.curr_tts_provider_inst = self.tts_provider_insts[0]
                logger.info(
                    f"Automatically selected {self.curr_tts_provider_inst.meta().id} as the current text-to-speech provider adapter.",
                )

    def get_insts(self):
        return self.provider_insts

    async def terminate_provider(self, provider_id: str) -> None:
        if provider_id in self.inst_map:
            logger.info(
                f"Terminating provider adapter {provider_id} ({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ...",
            )

            if self.inst_map[provider_id] in self.provider_insts:
                prov_inst = self.inst_map[provider_id]
                if isinstance(prov_inst, Provider):
                    self.provider_insts.remove(prov_inst)
            if self.inst_map[provider_id] in self.stt_provider_insts:
                prov_inst = self.inst_map[provider_id]
                if isinstance(prov_inst, STTProvider):
                    self.stt_provider_insts.remove(prov_inst)
            if self.inst_map[provider_id] in self.tts_provider_insts:
                prov_inst = self.inst_map[provider_id]
                if isinstance(prov_inst, TTSProvider):
                    self.tts_provider_insts.remove(prov_inst)
            # Embedding and rerank instances are tracked in their own lists too.
            prov_inst = self.inst_map[provider_id]
            if (
                isinstance(prov_inst, EmbeddingProvider)
                and prov_inst in self.embedding_provider_insts
            ):
                self.embedding_provider_insts.remove(prov_inst)
            if (
                isinstance(prov_inst, RerankProvider)
                and prov_inst in self.rerank_provider_insts
            ):
                self.rerank_provider_insts.remove(prov_inst)

            if self.inst_map[provider_id] == self.curr_provider_inst:
                self.curr_provider_inst = None
            if self.inst_map[provider_id] == self.curr_stt_provider_inst:
                self.curr_stt_provider_inst = None
            if self.inst_map[provider_id] == self.curr_tts_provider_inst:
                self.curr_tts_provider_inst = None

            if getattr(self.inst_map[provider_id], "terminate", None):
                await self.inst_map[provider_id].terminate()  # type: ignore

            logger.info(
                f"Provider adapter {provider_id} terminated ({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})",
            )
            del self.inst_map[provider_id]

    async def delete_provider(
        self, provider_id: str | None = None, provider_source_id: str | None = None
    ) -> None:
        """Delete a provider, or every provider that references the given provider source, from the config and terminate their instances. The config is saved after deletion."""
        async with self.resource_lock:
            # delete from config
            target_prov_ids = []
            if provider_id:
                target_prov_ids.append(provider_id)
            else:
                for prov in self.providers_config:
                    if prov.get("provider_source_id") == provider_source_id:
                        target_prov_ids.append(prov.get("id"))
            config = self.acm.default_conf
            for tpid in target_prov_ids:
                await self.terminate_provider(tpid)
                config["provider"] = [
                    prov for prov in config["provider"] if prov.get("id") != tpid
                ]
            config.save_config()
            logger.info(f"Provider(s) {target_prov_ids} removed from the config.")

    async def update_provider(self, origin_provider_id: str, new_config: dict) -> None:
        """Update provider config and reload the instance. Config will be saved after update."""
        async with self.resource_lock:
            npid = new_config.get("id", None)
            if not npid:
                raise ValueError("New provider config must have an 'id' field")
            config = self.acm.default_conf
            for provider in config["provider"]:
                if (
                    provider.get("id", None) == npid
                    and provider.get("id", None) != origin_provider_id
                ):
                    raise ValueError(f"Provider ID {npid} already exists")
            # update config
            for idx, provider in enumerate(config["provider"]):
                if provider.get("id", None) == origin_provider_id:
                    config["provider"][idx] = new_config
                    break
            else:
                raise ValueError(f"Provider ID {origin_provider_id} not found")
            config.save_config()
            # reload instance
            await self.reload(new_config)

    async def create_provider(self, new_config: dict) -> None:
        """Add new provider config and load the instance. Config will be saved after addition."""
        async with self.resource_lock:
            npid = new_config.get("id", None)
            if not npid:
                raise ValueError("New provider config must have an 'id' field")
            config = self.acm.default_conf
            for provider in config["provider"]:
                if provider.get("id", None) == npid:
                    raise ValueError(f"Provider ID {npid} already exists")
            # add to config
            config["provider"].append(new_config)
            config.save_config()
            # load instance
            await self.load_provider(new_config)

    async def terminate(self) -> None:
        for provider_inst in self.provider_insts:
            if hasattr(provider_inst, "terminate"):
                await provider_inst.terminate()  # type: ignore
        try:
            await self.llm_tools.disable_mcp_server()
        except Exception:
            logger.error("Error while disabling MCP servers", exc_info=True)