fix: correctly synchronize MCP client initialization

This commit is contained in:
邹永赫
2026-03-01 21:22:41 +09:00
parent 38e99cf65c
commit 8132ce24eb
+59 -18
View File
@@ -214,39 +214,68 @@ class FunctionToolManager:
)["mcpServers"]
tasks: dict[str, asyncio.Task] = {}
ready_futures: dict[str, asyncio.Future] = {}
for name, cfg in mcp_server_json_obj.items():
if cfg.get("active", True):
event = asyncio.Event()
ready_future = asyncio.get_running_loop().create_future()
task = asyncio.create_task(
self._init_mcp_client_task_wrapper(name, cfg, event),
self._init_mcp_client_task_wrapper(
name,
cfg,
event,
ready_future,
),
)
tasks[name] = task
ready_futures[name] = ready_future
self.mcp_client_event[name] = event
if tasks:
logger.info(f"等待 {len(tasks)} 个 MCP 服务初始化...")
if ready_futures:
logger.info(f"等待 {len(ready_futures)} 个 MCP 服务初始化...")
done, pending = await asyncio.wait(tasks.values(), timeout=20.0)
_, pending_futures = await asyncio.wait(
ready_futures.values(),
timeout=20.0,
)
if pending:
pending_services = {
name
for name, ready_future in ready_futures.items()
if ready_future in pending_futures
}
if pending_services:
logger.warning(
"MCP 服务初始化超时(20秒),部分服务可能未完全加载。"
"建议检查 MCP 服务器配置和网络连接。"
)
for task in pending:
for name in pending_services:
task = tasks[name]
task.cancel()
await asyncio.gather(
*(tasks[name] for name in pending_services),
return_exceptions=True,
)
success_count = 0
failed_services: list[str] = []
for name, task in tasks.items():
if task in pending:
for name, ready_future in ready_futures.items():
if name in pending_services:
logger.error(f"MCP 服务 {name} 初始化超时")
failed_services.append(name)
self.mcp_client_event.pop(name, None)
continue
exc = task.exception()
if ready_future.cancelled():
logger.error(f"MCP 服务 {name} 初始化已取消")
failed_services.append(name)
self.mcp_client_event.pop(name, None)
continue
exc = ready_future.exception()
if exc is not None:
logger.error(f"MCP 服务 {name} 初始化失败: {exc}")
# 仅在 debug 级别输出完整配置,避免在生产日志中泄露敏感信息
@@ -259,6 +288,7 @@ class FunctionToolManager:
parsed = urllib.parse.urlparse(cfg["url"])
logger.debug(f" 主机: {parsed.scheme}://{parsed.netloc}")
failed_services.append(name)
self.mcp_client_event.pop(name, None)
else:
success_count += 1
@@ -275,15 +305,26 @@ class FunctionToolManager:
name: str,
cfg: dict,
event: asyncio.Event,
ready_future: asyncio.Future | None = None,
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
initialized = False
try:
await self._init_mcp_client(name, cfg)
initialized = True
if ready_future and not ready_future.done():
ready_future.set_result(True)
await event.wait()
logger.info(f"收到 MCP 客户端 {name} 终止信号")
except Exception:
except asyncio.CancelledError:
if ready_future and not ready_future.done():
ready_future.set_exception(
asyncio.TimeoutError("MCP 客户端初始化超时"),
)
raise
except Exception as e:
if ready_future and not ready_future.done():
ready_future.set_exception(e)
if not initialized:
# 初始化阶段失败,记录错误并向上抛出让 task.exception() 捕获
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
@@ -387,22 +428,22 @@ class FunctionToolManager:
if not event:
event = asyncio.Event()
if not ready_future:
ready_future = asyncio.Future()
ready_future = asyncio.get_running_loop().create_future()
if name in self.mcp_client_dict:
return
asyncio.create_task(
init_task = asyncio.create_task(
self._init_mcp_client_task_wrapper(name, config, event, ready_future),
)
try:
await asyncio.wait_for(ready_future, timeout=timeout)
finally:
except asyncio.TimeoutError:
init_task.cancel()
await asyncio.gather(init_task, return_exceptions=True)
self.mcp_client_event.pop(name, None)
raise
else:
self.mcp_client_event[name] = event
if ready_future.done() and ready_future.exception():
exc = ready_future.exception()
if exc is not None:
raise exc
async def disable_mcp_server(
self,
name: str | None = None,