From 49247394236b1837a14fcd3129d0517cd0d94c4d Mon Sep 17 00:00:00 2001 From: idiotsj Date: Fri, 27 Feb 2026 21:59:32 +0800 Subject: [PATCH] refactor: simplify MCP init orchestration and improve log security MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace Future-based sync with asyncio.wait + name→task mapping - Explicitly cancel timed-out tasks after 20s timeout - Downgrade sensitive config details (command/args/URL) to debug level - Move urllib.parse import to top-level --- astrbot/core/provider/func_tool_manager.py | 77 ++++++++++------------ 1 file changed, 33 insertions(+), 44 deletions(-) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 8242d0614..8d59c4075 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -4,6 +4,7 @@ import asyncio import copy import json import os +import urllib.parse from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Any @@ -212,59 +213,51 @@ class FunctionToolManager: open(mcp_json_file, encoding="utf-8"), )["mcpServers"] - # 收集所有初始化任务的 Future - init_futures: dict[str, asyncio.Future] = {} + tasks: dict[str, asyncio.Task] = {} - for name in mcp_server_json_obj: - cfg = mcp_server_json_obj[name] + for name, cfg in mcp_server_json_obj.items(): if cfg.get("active", True): event = asyncio.Event() - ready_future = asyncio.Future() - init_futures[name] = ready_future - asyncio.create_task( - self._init_mcp_client_task_wrapper(name, cfg, event, ready_future), + task = asyncio.create_task( + self._init_mcp_client_task_wrapper(name, cfg, event), ) + tasks[name] = task self.mcp_client_event[name] = event - # 等待所有 MCP 客户端初始化完成(或失败) - if init_futures: - logger.info(f"等待 {len(init_futures)} 个 MCP 服务初始化...") + if tasks: + logger.info(f"等待 {len(tasks)} 个 MCP 服务初始化...") - try: - # 设置总超时时间为 20 秒,避免慢速 MCP 服务器阻塞启动过久 - results = await asyncio.wait_for( - asyncio.gather(*init_futures.values(), return_exceptions=True), - timeout=20.0, - ) - except asyncio.TimeoutError: + done, pending = await asyncio.wait(tasks.values(), timeout=20.0) + + if pending: logger.warning( "MCP 服务初始化超时(20秒),部分服务可能未完全加载。" "建议检查 MCP 服务器配置和网络连接。" ) - # 即使超时也继续,已完成的服务仍然可用 - results = [] - for name, future in zip(init_futures.keys(), init_futures.values()): - if future.done(): - try: - results.append(future.result()) - except Exception as e: - results.append(e) - else: - results.append(TimeoutError(f"MCP 服务 {name} 初始化超时")) + for task in pending: + task.cancel() success_count = 0 - failed_services = [] - for name, result in zip(init_futures.keys(), results): - if isinstance(result, Exception): - logger.error(f"MCP 服务 {name} 初始化失败: {result}") - # 显示配置信息以便调试 + failed_services: list[str] = [] + + for name, task in tasks.items(): + if task in pending: + logger.error(f"MCP 服务 {name} 初始化超时") + failed_services.append(name) + continue + + exc = task.exception() + if exc is not None: + logger.error(f"MCP 服务 {name} 初始化失败: {exc}") + # 仅在 debug 级别输出完整配置,避免在生产日志中泄露敏感信息 cfg = mcp_server_json_obj.get(name, {}) if "command" in cfg: - logger.error(f" 命令: {cfg['command']}") + logger.debug(f" 命令: {cfg['command']}") if "args" in cfg: - logger.error(f" 参数: {cfg['args']}") + logger.debug(f" 参数: {cfg['args']}") elif "url" in cfg: - logger.error(f" URL: {cfg['url']}") + parsed = urllib.parse.urlparse(cfg["url"]) + logger.debug(f" 主机: {parsed.scheme}://{parsed.netloc}") failed_services.append(name) else: success_count += 1 @@ -275,28 +268,24 @@ class FunctionToolManager: f"请检查配置文件 mcp_server.json 和服务器可用性。" ) - logger.info(f"MCP 服务初始化完成: {success_count}/{len(init_futures)} 成功") + logger.info(f"MCP 服务初始化完成: {success_count}/{len(tasks)} 成功") async def _init_mcp_client_task_wrapper( self, name: str, cfg: dict, event: asyncio.Event, - ready_future: asyncio.Future | None = None, ) -> None: """初始化 MCP 客户端的包装函数,用于捕获异常""" try: await self._init_mcp_client(name, cfg) tools = await self.mcp_client_dict[name].list_tools_and_save() - if ready_future and not ready_future.done(): - # tell the caller we are ready - ready_future.set_result(tools) + logger.debug(f"MCP 服务 {name} 初始化完成,工具: {tools}") await event.wait() logger.info(f"收到 MCP 客户端 {name} 终止信号") - except Exception as e: + except Exception: logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True) - if ready_future and not ready_future.done(): - ready_future.set_exception(e) + raise finally: # 无论如何都能清理 await self._terminate_mcp_client(name)