From 8132ce24ebb059d5037faa9cc939727d52e4b0f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 21:22:41 +0900 Subject: [PATCH] fix: correctly synchronize MCP client initialization --- astrbot/core/provider/func_tool_manager.py | 77 +++++++++++++++++----- 1 file changed, 59 insertions(+), 18 deletions(-) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index be663f0aa..347a564ce 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -214,39 +214,68 @@ class FunctionToolManager: )["mcpServers"] tasks: dict[str, asyncio.Task] = {} + ready_futures: dict[str, asyncio.Future] = {} for name, cfg in mcp_server_json_obj.items(): if cfg.get("active", True): event = asyncio.Event() + ready_future = asyncio.get_running_loop().create_future() task = asyncio.create_task( - self._init_mcp_client_task_wrapper(name, cfg, event), + self._init_mcp_client_task_wrapper( + name, + cfg, + event, + ready_future, + ), ) tasks[name] = task + ready_futures[name] = ready_future self.mcp_client_event[name] = event - if tasks: - logger.info(f"等待 {len(tasks)} 个 MCP 服务初始化...") + if ready_futures: + logger.info(f"等待 {len(ready_futures)} 个 MCP 服务初始化...") - done, pending = await asyncio.wait(tasks.values(), timeout=20.0) + _, pending_futures = await asyncio.wait( + ready_futures.values(), + timeout=20.0, + ) - if pending: + pending_services = { + name + for name, ready_future in ready_futures.items() + if ready_future in pending_futures + } + + if pending_services: logger.warning( "MCP 服务初始化超时(20秒),部分服务可能未完全加载。" "建议检查 MCP 服务器配置和网络连接。" ) - for task in pending: + for name in pending_services: + task = tasks[name] task.cancel() + await asyncio.gather( + *(tasks[name] for name in pending_services), + return_exceptions=True, + ) success_count = 0 failed_services: list[str] = [] - for name, task in tasks.items(): - if task in pending: + for name, ready_future in ready_futures.items(): + if name in pending_services: logger.error(f"MCP 服务 {name} 初始化超时") failed_services.append(name) + self.mcp_client_event.pop(name, None) continue - exc = task.exception() + if ready_future.cancelled(): + logger.error(f"MCP 服务 {name} 初始化已取消") + failed_services.append(name) + self.mcp_client_event.pop(name, None) + continue + + exc = ready_future.exception() if exc is not None: logger.error(f"MCP 服务 {name} 初始化失败: {exc}") # 仅在 debug 级别输出完整配置,避免在生产日志中泄露敏感信息 @@ -259,6 +288,7 @@ class FunctionToolManager: parsed = urllib.parse.urlparse(cfg["url"]) logger.debug(f" 主机: {parsed.scheme}://{parsed.netloc}") failed_services.append(name) + self.mcp_client_event.pop(name, None) else: success_count += 1 @@ -275,15 +305,26 @@ class FunctionToolManager: name: str, cfg: dict, event: asyncio.Event, + ready_future: asyncio.Future | None = None, ) -> None: """初始化 MCP 客户端的包装函数,用于捕获异常""" initialized = False try: await self._init_mcp_client(name, cfg) initialized = True + if ready_future and not ready_future.done(): + ready_future.set_result(True) await event.wait() logger.info(f"收到 MCP 客户端 {name} 终止信号") - except Exception: + except asyncio.CancelledError: + if ready_future and not ready_future.done(): + ready_future.set_exception( + asyncio.TimeoutError("MCP 客户端初始化超时"), + ) + raise + except Exception as e: + if ready_future and not ready_future.done(): + ready_future.set_exception(e) if not initialized: # 初始化阶段失败,记录错误并向上抛出让 task.exception() 捕获 logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True) @@ -387,22 +428,22 @@ class FunctionToolManager: if not event: event = asyncio.Event() if not ready_future: - ready_future = asyncio.Future() + ready_future = asyncio.get_running_loop().create_future() if name in self.mcp_client_dict: return - asyncio.create_task( + init_task = asyncio.create_task( self._init_mcp_client_task_wrapper(name, config, event, ready_future), ) try: await asyncio.wait_for(ready_future, timeout=timeout) - finally: + except asyncio.TimeoutError: + init_task.cancel() + await asyncio.gather(init_task, return_exceptions=True) + self.mcp_client_event.pop(name, None) + raise + else: self.mcp_client_event[name] = event - if ready_future.done() and ready_future.exception(): - exc = ready_future.exception() - if exc is not None: - raise exc - async def disable_mcp_server( self, name: str | None = None,