Fix/fix: resolve MCP tools race condition causing 'completion 无法解析' error (#5534)
* fix: resolve MCP tools race condition causing 'completion 无法解析' error - Wait for MCP client initialization to complete before accepting requests - Add Future-based synchronization in init_mcp_clients() - Prevent tool_calls from being rejected due to empty func_list - Improve error logging for MCP initialization failures Fixes race condition where AI attempts to call MCP tools before they are registered, resulting in 'API 返回的 completion 无法解析' exceptions. The issue occurred because: 1. MCP clients were initialized asynchronously without waiting 2. System accepted user requests immediately after startup 3. AI received empty tool list and attempted to call non-existent tools 4. Tool matching failed, causing parsing errors This fix ensures all MCP tools are loaded before the system processes any requests that might use them. * perf: add timeout and better error handling for MCP initialization - Add 20-second total timeout to prevent slow MCP servers from blocking startup - Show detailed configuration info when MCP initialization fails - List all failed services in a summary warning - Gracefully handle timeout by using already-completed services This ensures that even if some MCP servers are slow or unreachable, the system will start within a reasonable time and provide clear feedback about which services failed and why. * refactor: simplify MCP init orchestration and improve log security - Replace Future-based sync with asyncio.wait + name→task mapping - Explicitly cancel timed-out tasks after 20s timeout - Downgrade sensitive config details (command/args/URL) to debug level - Move urllib.parse import to top-level * fix: prevent initialized MCP clients from being cleaned up on timeout - Do not cancel pending tasks on timeout; let them continue running in the background waiting for the termination signal (event.set()), so successfully initialized services remain available - Track initialization state with a flag to distinguish init failures from post-init cancellations in _init_mcp_client_task_wrapper * fix: restore task cancellation on timeout per review feedback Pending tasks in asyncio.wait are tasks that have NOT completed initialization within 20s, so cancelling them is safe and correct. * fix: separate init signal from client lifetime in MCP task wrapper The previous design awaited task completion, but tasks only finish on shutdown (after event.wait()), causing asyncio.wait to always hit the 20s timeout and cancel all clients. Fix: introduce a dedicated ready_event that is set immediately after _init_mcp_client completes. init_mcp_clients now waits only for ready_event (with 20s timeout), while the long-lived client task continues running in the background until shutdown_event is set. This ensures startup returns promptly once clients are ready. * security: redact sensitive MCP config from debug logs Only log executable name and argument count instead of full command/args to avoid leaking tokens or credentials even at debug level. * refactor: use McpClientInfo dataclass and MCP_INIT_TIMEOUT constant - Extract MCP_INIT_TIMEOUT = 20.0 as a named module-level constant - Replace tuple-based client_info with _McpClientInfo dataclass to eliminate index-based access and improve readability - Remove _wait_ready helper; use asyncio.create_task(event.wait()) directly - Await cancelled tasks after timeout to prevent lingering background tasks and unobserved exceptions * fix: handle CancelledError and clean up wait_tasks on timeout - Catch asyncio.CancelledError separately in _init_mcp_client_task_wrapper so ready_event.set() is always called (Python 3.8+ CancelledError inherits BaseException, not Exception) - Cancel and await lingering wait_tasks after timeout to prevent them from hanging indefinitely when ready_event is never set * fix: align enable_mcp_server with new wrapper API and fix security/config issues - Fix enable_mcp_server to pass shutdown_event + ready_event instead of ready_future, matching _init_mcp_client_task_wrapper's current signature - Cancel and await init_task on timeout; clean up mcp_client_event on failure - Read MCP_INIT_TIMEOUT from env var ASTRBOT_MCP_INIT_TIMEOUT (default 20s) so operators can tune it without code changes - Strip userinfo from URL in debug log (use hostname+port only, not netloc) to avoid leaking credentials embedded in URLs * refactor: register mcp_client_event only after successful init in enable_mcp_server Move self.mcp_client_event[name] assignment to after initialization succeeds, so callers never observe a stale event for a failed client. * fix: harden MCP init state handling and timeout parsing * fix: improve MCP timeout and post-init error observability * refactor: simplify MCP init lifecycle orchestration * refactor: simplify MCP init flow and cap timeout values * fix: refine mcp timeout handling and lifecycle task tracking * fix: harden mcp shutdown and timeout source logging * refactor: simplify mcp runtime registry and timeout flow * fix: keep mcp init summary return contract * refactor: streamline mcp lifecycle and init errors * refactor: unify mcp lifecycle wait handling * refactor: simplify mcp runtime ownership and timeout resolution * fix: harden mcp shutdown waiting and startup signaling * refactor: streamline mcp lifecycle and shutdown errors * refactor: harden mcp runtime access and shutdown * fix: ensure mcp client cleanup and clarify views * refactor: cache mcp client view and guard startup * refactor: simplify mcp init cleanup and runtime lock * refactor: reduce mcp runtime duplication * refactor: reuse mcp cleanup and client view --------- Co-authored-by: idiotsj <idiotsj@users.noreply.github.com> Co-authored-by: 邹永赫 <1259085392@qq.com>
This commit is contained in:
@@ -4,7 +4,11 @@ import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
import threading
|
||||
import urllib.parse
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
|
||||
from dataclasses import dataclass
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
@@ -17,6 +21,103 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||
|
||||
DEFAULT_MCP_INIT_TIMEOUT_SECONDS = 20.0
|
||||
DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS = 30.0
|
||||
MCP_INIT_TIMEOUT_ENV = "ASTRBOT_MCP_INIT_TIMEOUT"
|
||||
ENABLE_MCP_TIMEOUT_ENV = "ASTRBOT_MCP_ENABLE_TIMEOUT"
|
||||
MAX_MCP_TIMEOUT_SECONDS = 300.0
|
||||
|
||||
|
||||
class MCPInitError(Exception):
|
||||
"""Base exception for MCP initialization failures."""
|
||||
|
||||
|
||||
class MCPInitTimeoutError(asyncio.TimeoutError, MCPInitError):
|
||||
"""Raised when MCP client initialization exceeds the configured timeout."""
|
||||
|
||||
|
||||
class MCPAllServicesFailedError(MCPInitError):
|
||||
"""Raised when all configured MCP services fail to initialize."""
|
||||
|
||||
|
||||
class MCPShutdownTimeoutError(asyncio.TimeoutError):
|
||||
"""Raised when MCP shutdown exceeds the configured timeout."""
|
||||
|
||||
def __init__(self, names: list[str], timeout: float) -> None:
|
||||
self.names = names
|
||||
self.timeout = timeout
|
||||
message = f"MCP 服务关闭超时({timeout:g} 秒):{', '.join(names)}"
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPInitSummary:
|
||||
total: int
|
||||
success: int
|
||||
failed: list[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _MCPServerRuntime:
|
||||
name: str
|
||||
client: MCPClient
|
||||
shutdown_event: asyncio.Event
|
||||
lifecycle_task: asyncio.Task[None]
|
||||
|
||||
|
||||
class _MCPClientDictView(Mapping[str, MCPClient]):
|
||||
"""Read-only view of MCP clients derived from runtime state."""
|
||||
|
||||
def __init__(self, runtime: dict[str, _MCPServerRuntime]) -> None:
|
||||
self._runtime = runtime
|
||||
|
||||
def __getitem__(self, key: str) -> MCPClient:
|
||||
return self._runtime[key].client
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._runtime)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._runtime)
|
||||
|
||||
|
||||
def _resolve_timeout(
|
||||
timeout: float | int | str | None = None,
|
||||
*,
|
||||
env_name: str = MCP_INIT_TIMEOUT_ENV,
|
||||
default: float = DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
|
||||
) -> float:
|
||||
"""Resolve timeout with precedence: explicit argument > env value > default."""
|
||||
source = f"环境变量 {env_name}"
|
||||
if timeout is None:
|
||||
timeout = os.getenv(env_name, str(default))
|
||||
else:
|
||||
source = "显式参数 timeout"
|
||||
|
||||
try:
|
||||
timeout_value = float(timeout)
|
||||
except (TypeError, ValueError):
|
||||
logger.warning(
|
||||
f"超时配置({source})={timeout!r} 无效,使用默认值 {default:g} 秒。"
|
||||
)
|
||||
return default
|
||||
|
||||
if timeout_value <= 0:
|
||||
logger.warning(
|
||||
f"超时配置({source})={timeout_value:g} 必须大于 0,使用默认值 {default:g} 秒。"
|
||||
)
|
||||
return default
|
||||
|
||||
if timeout_value > MAX_MCP_TIMEOUT_SECONDS:
|
||||
logger.warning(
|
||||
f"超时配置({source})={timeout_value:g} 过大,已限制为最大值 "
|
||||
f"{MAX_MCP_TIMEOUT_SECONDS:g} 秒,以避免长时间等待。"
|
||||
)
|
||||
return MAX_MCP_TIMEOUT_SECONDS
|
||||
|
||||
return timeout_value
|
||||
|
||||
|
||||
SUPPORTED_TYPES = [
|
||||
"string",
|
||||
"number",
|
||||
@@ -106,9 +207,49 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
class FunctionToolManager:
|
||||
def __init__(self) -> None:
|
||||
self.func_list: list[FuncTool] = []
|
||||
self.mcp_client_dict: dict[str, MCPClient] = {}
|
||||
"""MCP 服务列表"""
|
||||
self.mcp_client_event: dict[str, asyncio.Event] = {}
|
||||
self._mcp_server_runtime: dict[str, _MCPServerRuntime] = {}
|
||||
"""MCP 服务运行时状态(唯一事实来源)"""
|
||||
self._mcp_server_runtime_view = MappingProxyType(self._mcp_server_runtime)
|
||||
self._mcp_client_dict_view = _MCPClientDictView(self._mcp_server_runtime)
|
||||
self._timeout_mismatch_warned = False
|
||||
self._timeout_warn_lock = threading.Lock()
|
||||
self._runtime_lock = asyncio.Lock()
|
||||
self._mcp_starting: set[str] = set()
|
||||
self._init_timeout_default = _resolve_timeout(
|
||||
timeout=None,
|
||||
env_name=MCP_INIT_TIMEOUT_ENV,
|
||||
default=DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
|
||||
)
|
||||
self._enable_timeout_default = _resolve_timeout(
|
||||
timeout=None,
|
||||
env_name=ENABLE_MCP_TIMEOUT_ENV,
|
||||
default=DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS,
|
||||
)
|
||||
self._warn_on_timeout_mismatch(
|
||||
self._init_timeout_default,
|
||||
self._enable_timeout_default,
|
||||
)
|
||||
|
||||
@property
|
||||
def mcp_client_dict(self) -> Mapping[str, MCPClient]:
|
||||
"""Read-only compatibility view for external callers that still read mcp_client_dict.
|
||||
|
||||
Note: Mutating this mapping is unsupported and will raise TypeError.
|
||||
"""
|
||||
return self._mcp_client_dict_view
|
||||
|
||||
@property
|
||||
def mcp_server_runtime_view(self) -> Mapping[str, _MCPServerRuntime]:
|
||||
"""Read-only view of MCP runtime metadata for external callers."""
|
||||
return self._mcp_server_runtime_view
|
||||
|
||||
@property
|
||||
def mcp_server_runtime(self) -> Mapping[str, _MCPServerRuntime]:
|
||||
"""Backward-compatible read-only view (deprecated). Do not mutate.
|
||||
|
||||
Note: Mutations are not supported and will raise TypeError.
|
||||
"""
|
||||
return self._mcp_server_runtime_view
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.func_list) == 0
|
||||
@@ -179,7 +320,34 @@ class FunctionToolManager:
|
||||
tool_set = ToolSet(self.func_list.copy())
|
||||
return tool_set
|
||||
|
||||
async def init_mcp_clients(self) -> None:
|
||||
@staticmethod
|
||||
def _log_safe_mcp_debug_config(cfg: dict) -> None:
|
||||
# 仅记录脱敏后的摘要,避免泄露 command/args/url 中的敏感信息
|
||||
if "command" in cfg:
|
||||
cmd = cfg["command"]
|
||||
executable = str(cmd[0] if isinstance(cmd, (list, tuple)) and cmd else cmd)
|
||||
args_val = cfg.get("args", [])
|
||||
args_count = (
|
||||
len(args_val)
|
||||
if isinstance(args_val, (list, tuple))
|
||||
else (0 if args_val is None else 1)
|
||||
)
|
||||
logger.debug(f" 命令可执行文件: {executable}, 参数数量: {args_count}")
|
||||
return
|
||||
|
||||
if "url" in cfg:
|
||||
parsed = urllib.parse.urlparse(str(cfg["url"]))
|
||||
host = parsed.hostname or ""
|
||||
scheme = parsed.scheme or "unknown"
|
||||
try:
|
||||
port = f":{parsed.port}" if parsed.port else ""
|
||||
except ValueError:
|
||||
port = ""
|
||||
logger.debug(f" 主机: {scheme}://{host}{port}")
|
||||
|
||||
async def init_mcp_clients(
|
||||
self, raise_on_all_failed: bool = False
|
||||
) -> MCPInitSummary:
|
||||
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
|
||||
```
|
||||
{
|
||||
@@ -197,6 +365,10 @@ class FunctionToolManager:
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
Timeout behavior:
|
||||
- 初始化超时使用环境变量 ASTRBOT_MCP_INIT_TIMEOUT 或默认值。
|
||||
- 动态启用超时使用 ASTRBOT_MCP_ENABLE_TIMEOUT(独立于初始化超时)。
|
||||
"""
|
||||
data_dir = get_astrbot_data_path()
|
||||
|
||||
@@ -206,56 +378,211 @@ class FunctionToolManager:
|
||||
with open(mcp_json_file, "w", encoding="utf-8") as f:
|
||||
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
|
||||
logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
|
||||
return
|
||||
return MCPInitSummary(total=0, success=0, failed=[])
|
||||
|
||||
mcp_server_json_obj: dict[str, dict] = json.load(
|
||||
open(mcp_json_file, encoding="utf-8"),
|
||||
)["mcpServers"]
|
||||
with open(mcp_json_file, encoding="utf-8") as f:
|
||||
mcp_server_json_obj: dict[str, dict] = json.load(f)["mcpServers"]
|
||||
|
||||
for name in mcp_server_json_obj:
|
||||
cfg = mcp_server_json_obj[name]
|
||||
init_timeout = self._init_timeout_default
|
||||
timeout_display = f"{init_timeout:g}"
|
||||
|
||||
active_configs: list[tuple[str, dict, asyncio.Event]] = []
|
||||
for name, cfg in mcp_server_json_obj.items():
|
||||
if cfg.get("active", True):
|
||||
event = asyncio.Event()
|
||||
asyncio.create_task(
|
||||
self._init_mcp_client_task_wrapper(name, cfg, event),
|
||||
)
|
||||
self.mcp_client_event[name] = event
|
||||
shutdown_event = asyncio.Event()
|
||||
active_configs.append((name, cfg, shutdown_event))
|
||||
|
||||
async def _init_mcp_client_task_wrapper(
|
||||
if not active_configs:
|
||||
return MCPInitSummary(total=0, success=0, failed=[])
|
||||
|
||||
logger.info(f"等待 {len(active_configs)} 个 MCP 服务初始化...")
|
||||
|
||||
init_tasks = [
|
||||
asyncio.create_task(
|
||||
self._start_mcp_server(
|
||||
name=name,
|
||||
cfg=cfg,
|
||||
shutdown_event=shutdown_event,
|
||||
timeout=init_timeout,
|
||||
),
|
||||
name=f"mcp-init:{name}",
|
||||
)
|
||||
for (name, cfg, shutdown_event) in active_configs
|
||||
]
|
||||
results = await asyncio.gather(*init_tasks, return_exceptions=True)
|
||||
|
||||
success_count = 0
|
||||
failed_services: list[str] = []
|
||||
|
||||
for (name, cfg, _), result in zip(active_configs, results, strict=False):
|
||||
if isinstance(result, Exception):
|
||||
if isinstance(result, MCPInitTimeoutError):
|
||||
logger.error(f"MCP 服务 {name} 初始化超时({timeout_display}秒)")
|
||||
else:
|
||||
logger.error(f"MCP 服务 {name} 初始化失败: {result}")
|
||||
self._log_safe_mcp_debug_config(cfg)
|
||||
failed_services.append(name)
|
||||
async with self._runtime_lock:
|
||||
self._mcp_server_runtime.pop(name, None)
|
||||
continue
|
||||
|
||||
success_count += 1
|
||||
|
||||
if failed_services:
|
||||
logger.warning(
|
||||
f"以下 MCP 服务初始化失败: {', '.join(failed_services)}。"
|
||||
f"请检查配置文件 mcp_server.json 和服务器可用性。"
|
||||
)
|
||||
|
||||
summary = MCPInitSummary(
|
||||
total=len(active_configs), success=success_count, failed=failed_services
|
||||
)
|
||||
logger.info(f"MCP 服务初始化完成: {summary.success}/{summary.total} 成功")
|
||||
if summary.total > 0 and summary.success == 0:
|
||||
msg = "全部 MCP 服务初始化失败,请检查 mcp_server.json 配置和服务器可用性。"
|
||||
if raise_on_all_failed:
|
||||
raise MCPAllServicesFailedError(msg)
|
||||
logger.error(msg)
|
||||
return summary
|
||||
|
||||
async def _start_mcp_server(
|
||||
self,
|
||||
name: str,
|
||||
cfg: dict,
|
||||
event: asyncio.Event,
|
||||
ready_future: asyncio.Future | None = None,
|
||||
*,
|
||||
shutdown_event: asyncio.Event | None = None,
|
||||
timeout: float,
|
||||
) -> None:
|
||||
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||
"""Initialize MCP server with timeout and register task/event together.
|
||||
|
||||
This method is idempotent. If the server is already running, the existing
|
||||
runtime is kept and the new config is ignored.
|
||||
"""
|
||||
async with self._runtime_lock:
|
||||
if name in self._mcp_server_runtime or name in self._mcp_starting:
|
||||
logger.warning(
|
||||
f"MCP 服务 {name} 已在运行,忽略本次启用请求(timeout={timeout:g})。"
|
||||
)
|
||||
self._log_safe_mcp_debug_config(cfg)
|
||||
return
|
||||
self._mcp_starting.add(name)
|
||||
|
||||
if shutdown_event is None:
|
||||
shutdown_event = asyncio.Event()
|
||||
|
||||
mcp_client: MCPClient | None = None
|
||||
try:
|
||||
await self._init_mcp_client(name, cfg)
|
||||
tools = await self.mcp_client_dict[name].list_tools_and_save()
|
||||
if ready_future and not ready_future.done():
|
||||
# tell the caller we are ready
|
||||
ready_future.set_result(tools)
|
||||
await event.wait()
|
||||
logger.info(f"收到 MCP 客户端 {name} 终止信号")
|
||||
except Exception as e:
|
||||
mcp_client = await asyncio.wait_for(
|
||||
self._init_mcp_client(name, cfg),
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError as exc:
|
||||
raise MCPInitTimeoutError(
|
||||
f"MCP 服务 {name} 初始化超时({timeout:g} 秒)"
|
||||
) from exc
|
||||
except Exception:
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
|
||||
if ready_future and not ready_future.done():
|
||||
ready_future.set_exception(e)
|
||||
raise
|
||||
finally:
|
||||
# 无论如何都能清理
|
||||
await self._terminate_mcp_client(name)
|
||||
if mcp_client is None:
|
||||
async with self._runtime_lock:
|
||||
self._mcp_starting.discard(name)
|
||||
|
||||
async def _init_mcp_client(self, name: str, config: dict) -> None:
|
||||
async def lifecycle() -> None:
|
||||
try:
|
||||
await shutdown_event.wait()
|
||||
logger.info(f"收到 MCP 客户端 {name} 终止信号")
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"MCP 客户端 {name} 任务被取消")
|
||||
raise
|
||||
finally:
|
||||
await self._terminate_mcp_client(name)
|
||||
|
||||
lifecycle_task = asyncio.create_task(lifecycle(), name=f"mcp-client:{name}")
|
||||
async with self._runtime_lock:
|
||||
self._mcp_server_runtime[name] = _MCPServerRuntime(
|
||||
name=name,
|
||||
client=mcp_client,
|
||||
shutdown_event=shutdown_event,
|
||||
lifecycle_task=lifecycle_task,
|
||||
)
|
||||
self._mcp_starting.discard(name)
|
||||
|
||||
async def _shutdown_runtimes(
|
||||
self,
|
||||
runtimes: list[_MCPServerRuntime],
|
||||
timeout: float,
|
||||
*,
|
||||
strict: bool = True,
|
||||
) -> list[str]:
|
||||
"""Shutdown runtimes and wait for lifecycle tasks to complete."""
|
||||
lifecycle_tasks = [
|
||||
runtime.lifecycle_task
|
||||
for runtime in runtimes
|
||||
if not runtime.lifecycle_task.done()
|
||||
]
|
||||
if not lifecycle_tasks:
|
||||
return []
|
||||
|
||||
for runtime in runtimes:
|
||||
runtime.shutdown_event.set()
|
||||
|
||||
try:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(*lifecycle_tasks, return_exceptions=True),
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pending_names = [
|
||||
runtime.name
|
||||
for runtime in runtimes
|
||||
if not runtime.lifecycle_task.done()
|
||||
]
|
||||
for task in lifecycle_tasks:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
await asyncio.gather(*lifecycle_tasks, return_exceptions=True)
|
||||
if strict:
|
||||
raise MCPShutdownTimeoutError(pending_names, timeout)
|
||||
logger.warning(
|
||||
"MCP 服务关闭超时(%s 秒),以下服务未完全关闭:%s",
|
||||
f"{timeout:g}",
|
||||
", ".join(pending_names),
|
||||
)
|
||||
return pending_names
|
||||
else:
|
||||
for result in results:
|
||||
if isinstance(result, asyncio.CancelledError):
|
||||
logger.debug("MCP lifecycle task was cancelled during shutdown.")
|
||||
elif isinstance(result, Exception):
|
||||
logger.error(
|
||||
"MCP lifecycle task failed during shutdown.",
|
||||
exc_info=(type(result), result, result.__traceback__),
|
||||
)
|
||||
return []
|
||||
|
||||
async def _cleanup_mcp_client_safely(
|
||||
self, mcp_client: MCPClient, name: str
|
||||
) -> None:
|
||||
"""安全清理单个 MCP 客户端,避免清理异常中断主流程。"""
|
||||
try:
|
||||
await mcp_client.cleanup()
|
||||
except Exception as cleanup_exc: # noqa: BLE001 - only log here
|
||||
logger.error(f"清理 MCP 客户端资源 {name} 失败: {cleanup_exc}")
|
||||
|
||||
async def _init_mcp_client(self, name: str, config: dict) -> MCPClient:
|
||||
"""初始化单个MCP客户端"""
|
||||
# 先清理之前的客户端,如果存在
|
||||
if name in self.mcp_client_dict:
|
||||
await self._terminate_mcp_client(name)
|
||||
|
||||
mcp_client = MCPClient()
|
||||
mcp_client.name = name
|
||||
self.mcp_client_dict[name] = mcp_client
|
||||
await mcp_client.connect_to_server(config, name)
|
||||
tools_res = await mcp_client.list_tools_and_save()
|
||||
try:
|
||||
await mcp_client.connect_to_server(config, name)
|
||||
tools_res = await mcp_client.list_tools_and_save()
|
||||
except asyncio.CancelledError:
|
||||
await self._cleanup_mcp_client_safely(mcp_client, name)
|
||||
raise
|
||||
except Exception:
|
||||
await self._cleanup_mcp_client_safely(mcp_client, name)
|
||||
raise
|
||||
logger.debug(f"MCP server {name} list tools response: {tools_res}")
|
||||
tool_names = [tool.name for tool in tools_res.tools]
|
||||
|
||||
@@ -276,26 +603,36 @@ class FunctionToolManager:
|
||||
self.func_list.append(func_tool)
|
||||
|
||||
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
||||
return mcp_client
|
||||
|
||||
async def _terminate_mcp_client(self, name: str) -> None:
|
||||
"""关闭并清理MCP客户端"""
|
||||
if name in self.mcp_client_dict:
|
||||
client = self.mcp_client_dict[name]
|
||||
try:
|
||||
# 关闭MCP连接
|
||||
await client.cleanup()
|
||||
except Exception as e:
|
||||
logger.error(f"清空 MCP 客户端资源 {name}: {e}。")
|
||||
finally:
|
||||
# Remove client from dict after cleanup attempt (successful or not)
|
||||
self.mcp_client_dict.pop(name, None)
|
||||
# 移除关联的FuncTool
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
||||
]
|
||||
logger.info(f"已关闭 MCP 服务 {name}")
|
||||
async with self._runtime_lock:
|
||||
runtime = self._mcp_server_runtime.get(name)
|
||||
if runtime:
|
||||
client = runtime.client
|
||||
# 关闭MCP连接
|
||||
await self._cleanup_mcp_client_safely(client, name)
|
||||
# 移除关联的FuncTool
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
||||
]
|
||||
async with self._runtime_lock:
|
||||
self._mcp_server_runtime.pop(name, None)
|
||||
self._mcp_starting.discard(name)
|
||||
logger.info(f"已关闭 MCP 服务 {name}")
|
||||
return
|
||||
|
||||
# Runtime missing but stale tools may still exist after failed flows.
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
||||
]
|
||||
async with self._runtime_lock:
|
||||
self._mcp_starting.discard(name)
|
||||
|
||||
@staticmethod
|
||||
async def test_mcp_server_connection(config: dict) -> list[str]:
|
||||
@@ -319,42 +656,36 @@ class FunctionToolManager:
|
||||
self,
|
||||
name: str,
|
||||
config: dict,
|
||||
event: asyncio.Event | None = None,
|
||||
ready_future: asyncio.Future | None = None,
|
||||
timeout: int = 30,
|
||||
shutdown_event: asyncio.Event | None = None,
|
||||
timeout: float | int | str | None = None,
|
||||
) -> None:
|
||||
"""Enable_mcp_server a new MCP server to the manager and initialize it.
|
||||
"""Enable a new MCP server and initialize it.
|
||||
|
||||
Args:
|
||||
name (str): The name of the MCP server.
|
||||
config (dict): Configuration for the MCP server.
|
||||
event (asyncio.Event): Event to signal when the MCP client is ready.
|
||||
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
|
||||
timeout (int): Timeout for the initialization.
|
||||
name: The name of the MCP server.
|
||||
config: Configuration for the MCP server.
|
||||
shutdown_event: Event to signal when the MCP client should shut down.
|
||||
timeout: Timeout in seconds for initialization.
|
||||
Uses ASTRBOT_MCP_ENABLE_TIMEOUT by default (separate from init timeout).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the initialization does not complete within the specified timeout.
|
||||
MCPInitTimeoutError: If initialization does not complete within timeout.
|
||||
Exception: If there is an error during initialization.
|
||||
|
||||
"""
|
||||
if not event:
|
||||
event = asyncio.Event()
|
||||
if not ready_future:
|
||||
ready_future = asyncio.Future()
|
||||
if name in self.mcp_client_dict:
|
||||
return
|
||||
asyncio.create_task(
|
||||
self._init_mcp_client_task_wrapper(name, config, event, ready_future),
|
||||
if timeout is None:
|
||||
timeout_value = self._enable_timeout_default
|
||||
else:
|
||||
timeout_value = _resolve_timeout(
|
||||
timeout=timeout,
|
||||
env_name=ENABLE_MCP_TIMEOUT_ENV,
|
||||
default=self._enable_timeout_default,
|
||||
)
|
||||
await self._start_mcp_server(
|
||||
name=name,
|
||||
cfg=config,
|
||||
shutdown_event=shutdown_event,
|
||||
timeout=timeout_value,
|
||||
)
|
||||
try:
|
||||
await asyncio.wait_for(ready_future, timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event[name] = event
|
||||
|
||||
if ready_future.done() and ready_future.exception():
|
||||
exc = ready_future.exception()
|
||||
if exc is not None:
|
||||
raise exc
|
||||
|
||||
async def disable_mcp_server(
|
||||
self,
|
||||
@@ -367,39 +698,40 @@ class FunctionToolManager:
|
||||
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
|
||||
timeout (int): Timeout.
|
||||
|
||||
Raises:
|
||||
MCPShutdownTimeoutError: If shutdown does not complete within timeout.
|
||||
Only raised when disabling a specific server (name is not None).
|
||||
|
||||
"""
|
||||
if name:
|
||||
if name not in self.mcp_client_event:
|
||||
async with self._runtime_lock:
|
||||
runtime = self._mcp_server_runtime.get(name)
|
||||
if runtime is None:
|
||||
return
|
||||
client = self.mcp_client_dict.get(name)
|
||||
self.mcp_client_event[name].set()
|
||||
if not client:
|
||||
return
|
||||
client_running_event = client.running_event
|
||||
try:
|
||||
await asyncio.wait_for(client_running_event.wait(), timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event.pop(name, None)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
||||
]
|
||||
|
||||
await self._shutdown_runtimes([runtime], timeout, strict=True)
|
||||
else:
|
||||
running_events = [
|
||||
client.running_event.wait() for client in self.mcp_client_dict.values()
|
||||
]
|
||||
for key, event in self.mcp_client_event.items():
|
||||
event.set()
|
||||
# waiting for all clients to finish
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event.clear()
|
||||
self.mcp_client_dict.clear()
|
||||
self.func_list = [
|
||||
f for f in self.func_list if not isinstance(f, MCPTool)
|
||||
]
|
||||
async with self._runtime_lock:
|
||||
runtimes = list(self._mcp_server_runtime.values())
|
||||
await self._shutdown_runtimes(runtimes, timeout, strict=False)
|
||||
|
||||
def _warn_on_timeout_mismatch(
|
||||
self,
|
||||
init_timeout: float,
|
||||
enable_timeout: float,
|
||||
) -> None:
|
||||
if init_timeout == enable_timeout:
|
||||
return
|
||||
with self._timeout_warn_lock:
|
||||
if self._timeout_mismatch_warned:
|
||||
return
|
||||
logger.info(
|
||||
"检测到 MCP 初始化超时与动态启用超时配置不同:"
|
||||
"初始化使用 %s 秒,动态启用使用 %s 秒。如需一致,请设置相同值。",
|
||||
f"{init_timeout:g}",
|
||||
f"{enable_timeout:g}",
|
||||
)
|
||||
self._timeout_mismatch_warned = True
|
||||
|
||||
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
|
||||
"""获得 OpenAI API 风格的**已经激活**的工具描述"""
|
||||
|
||||
@@ -330,8 +330,25 @@ class ProviderManager:
|
||||
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
# 初始化 MCP Client 连接
|
||||
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
|
||||
# 初始化 MCP Client 连接(等待完成以确保工具可用)
|
||||
strict_mcp_init = os.getenv("ASTRBOT_MCP_INIT_STRICT", "").strip().lower() in {
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
"on",
|
||||
}
|
||||
mcp_init_summary = await self.llm_tools.init_mcp_clients(
|
||||
raise_on_all_failed=strict_mcp_init
|
||||
)
|
||||
if (
|
||||
mcp_init_summary.total > 0
|
||||
and mcp_init_summary.success == 0
|
||||
and not strict_mcp_init
|
||||
):
|
||||
logger.warning(
|
||||
"MCP 服务全部初始化失败,系统将继续启动(可设置 "
|
||||
"ASTRBOT_MCP_INIT_STRICT=1 以在此场景下中止启动)。"
|
||||
)
|
||||
|
||||
def dynamic_import_provider(self, type: str) -> None:
|
||||
"""动态导入提供商适配器模块
|
||||
|
||||
@@ -51,11 +51,9 @@ class ToolsRoute(Route):
|
||||
server_info[key] = value
|
||||
|
||||
# 如果MCP客户端已初始化,从客户端获取工具名称
|
||||
for (
|
||||
name_key,
|
||||
mcp_client,
|
||||
) in self.tool_mgr.mcp_client_dict.items():
|
||||
for name_key, runtime in self.tool_mgr.mcp_server_runtime_view.items():
|
||||
if name_key == name:
|
||||
mcp_client = runtime.client
|
||||
server_info["tools"] = [tool.name for tool in mcp_client.tools]
|
||||
server_info["errlogs"] = mcp_client.server_errlogs
|
||||
break
|
||||
@@ -192,7 +190,7 @@ class ToolsRoute(Route):
|
||||
# 处理MCP客户端状态变化
|
||||
if active:
|
||||
if (
|
||||
old_name in self.tool_mgr.mcp_client_dict
|
||||
old_name in self.tool_mgr.mcp_server_runtime_view
|
||||
or not only_update_active
|
||||
or is_rename
|
||||
):
|
||||
@@ -233,7 +231,7 @@ class ToolsRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
# 如果要停用服务器
|
||||
elif old_name in self.tool_mgr.mcp_client_dict:
|
||||
elif old_name in self.tool_mgr.mcp_server_runtime_view:
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(old_name, timeout=10)
|
||||
except TimeoutError:
|
||||
@@ -272,7 +270,7 @@ class ToolsRoute(Route):
|
||||
del config["mcpServers"][name]
|
||||
|
||||
if self.tool_mgr.save_mcp_config(config):
|
||||
if name in self.tool_mgr.mcp_client_dict:
|
||||
if name in self.tool_mgr.mcp_server_runtime_view:
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
||||
except TimeoutError:
|
||||
|
||||
Reference in New Issue
Block a user