refactor: simplify MCP init orchestration and improve log security

- Replace Future-based sync with asyncio.wait + name→task mapping
- Explicitly cancel timed-out tasks after 20s timeout
- Downgrade sensitive config details (command/args/URL) to debug level
- Move urllib.parse import to top-level
This commit is contained in:
idiotsj
2026-02-27 21:59:32 +08:00
parent ec9f7403d5
commit 4924739423
+33 -44
View File
@@ -4,6 +4,7 @@ import asyncio
import copy
import json
import os
import urllib.parse
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
@@ -212,59 +213,51 @@ class FunctionToolManager:
open(mcp_json_file, encoding="utf-8"),
)["mcpServers"]
# 收集所有初始化任务的 Future
init_futures: dict[str, asyncio.Future] = {}
tasks: dict[str, asyncio.Task] = {}
for name in mcp_server_json_obj:
cfg = mcp_server_json_obj[name]
for name, cfg in mcp_server_json_obj.items():
if cfg.get("active", True):
event = asyncio.Event()
ready_future = asyncio.Future()
init_futures[name] = ready_future
asyncio.create_task(
self._init_mcp_client_task_wrapper(name, cfg, event, ready_future),
task = asyncio.create_task(
self._init_mcp_client_task_wrapper(name, cfg, event),
)
tasks[name] = task
self.mcp_client_event[name] = event
# 等待所有 MCP 客户端初始化完成(或失败)
if init_futures:
logger.info(f"等待 {len(init_futures)} 个 MCP 服务初始化...")
if tasks:
logger.info(f"等待 {len(tasks)} 个 MCP 服务初始化...")
try:
# 设置总超时时间为 20 秒,避免慢速 MCP 服务器阻塞启动过久
results = await asyncio.wait_for(
asyncio.gather(*init_futures.values(), return_exceptions=True),
timeout=20.0,
)
except asyncio.TimeoutError:
done, pending = await asyncio.wait(tasks.values(), timeout=20.0)
if pending:
logger.warning(
"MCP 服务初始化超时(20秒),部分服务可能未完全加载。"
"建议检查 MCP 服务器配置和网络连接。"
)
# 即使超时也继续,已完成的服务仍然可用
results = []
for name, future in zip(init_futures.keys(), init_futures.values()):
if future.done():
try:
results.append(future.result())
except Exception as e:
results.append(e)
else:
results.append(TimeoutError(f"MCP 服务 {name} 初始化超时"))
for task in pending:
task.cancel()
success_count = 0
failed_services = []
for name, result in zip(init_futures.keys(), results):
if isinstance(result, Exception):
logger.error(f"MCP 服务 {name} 初始化失败: {result}")
# 显示配置信息以便调试
failed_services: list[str] = []
for name, task in tasks.items():
if task in pending:
logger.error(f"MCP 服务 {name} 初始化超时")
failed_services.append(name)
continue
exc = task.exception()
if exc is not None:
logger.error(f"MCP 服务 {name} 初始化失败: {exc}")
# 仅在 debug 级别输出完整配置,避免在生产日志中泄露敏感信息
cfg = mcp_server_json_obj.get(name, {})
if "command" in cfg:
logger.error(f" 命令: {cfg['command']}")
logger.debug(f" 命令: {cfg['command']}")
if "args" in cfg:
logger.error(f" 参数: {cfg['args']}")
logger.debug(f" 参数: {cfg['args']}")
elif "url" in cfg:
logger.error(f" URL: {cfg['url']}")
parsed = urllib.parse.urlparse(cfg["url"])
logger.debug(f" 主机: {parsed.scheme}://{parsed.netloc}")
failed_services.append(name)
else:
success_count += 1
@@ -275,28 +268,24 @@ class FunctionToolManager:
f"请检查配置文件 mcp_server.json 和服务器可用性。"
)
logger.info(f"MCP 服务初始化完成: {success_count}/{len(init_futures)} 成功")
logger.info(f"MCP 服务初始化完成: {success_count}/{len(tasks)} 成功")
async def _init_mcp_client_task_wrapper(
self,
name: str,
cfg: dict,
event: asyncio.Event,
ready_future: asyncio.Future | None = None,
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
try:
await self._init_mcp_client(name, cfg)
tools = await self.mcp_client_dict[name].list_tools_and_save()
if ready_future and not ready_future.done():
# tell the caller we are ready
ready_future.set_result(tools)
logger.debug(f"MCP 服务 {name} 初始化完成,工具: {tools}")
await event.wait()
logger.info(f"收到 MCP 客户端 {name} 终止信号")
except Exception as e:
except Exception:
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
if ready_future and not ready_future.done():
ready_future.set_exception(e)
raise
finally:
# 无论如何都能清理
await self._terminate_mcp_client(name)