perf: 优化 MCP 服务器的日志回显

This commit is contained in:
Soulter
2025-04-17 13:59:10 +08:00
parent 97cbccc2ba
commit 33fd6a5016
5 changed files with 107 additions and 26 deletions
+31 -13
View File
@@ -4,12 +4,14 @@ import textwrap
import os
import asyncio
import copy
import logging
from typing import Dict, List, Awaitable, Literal, Any
from dataclasses import dataclass
from typing import Optional
from contextlib import AsyncExitStack
from astrbot import logger
from astrbot.core.utils.log_pipe import LogPipe
try:
import mcp
@@ -87,8 +89,9 @@ class MCPClient:
self.name = None
self.active: bool = True
self.tools: List[mcp.Tool] = []
self.server_errlogs: List[str] = []
async def connect_to_server(self, mcp_server_config: dict):
async def connect_to_server(self, mcp_server_config: dict, name: str):
"""Connect to an MCP server
Args:
@@ -98,19 +101,30 @@ class MCPClient:
if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0:
key_0 = list(cfg["mcpServers"].keys())[0]
cfg = cfg["mcpServers"][key_0]
cfg.pop("active", None) # Remove active flag from config
cfg.pop("active", None) # Remove active flag from config
server_params = mcp.StdioServerParameters(
**cfg,
)
def callback(msg: str):
# 处理 MCP 服务的错误日志
self.server_errlogs.append(msg)
stdio_transport = await self.exit_stack.enter_async_context(
mcp.stdio_client(server_params)
mcp.stdio_client(
server_params,
errlog=LogPipe(
level=logging.ERROR,
logger=logger,
identifier=f"MCPServer-{name}",
callback=callback,
),
),
)
self.stdio, self.write = stdio_transport
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(self.stdio, self.write)
)
await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult:
@@ -266,7 +280,9 @@ class FuncCall:
self.func_list = [
f
for f in self.func_list
if not (f.origin == "mcp" and f.mcp_server_name == data["name"])
if not (
f.origin == "mcp" and f.mcp_server_name == data["name"]
)
]
else:
for name in self.mcp_client_dict.keys():
@@ -275,11 +291,7 @@ class FuncCall:
if name in self.mcp_client_event:
self.mcp_client_event[name].set()
self.mcp_client_event.pop(name, None)
self.func_list = [
f
for f in self.func_list
if f.origin != "mcp"
]
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
async def _init_mcp_client_task_wrapper(
self, name: str, cfg: dict, event: asyncio.Event
@@ -291,6 +303,9 @@ class FuncCall:
logger.info(f"收到 MCP 客户端 {name} 终止信号")
await self._terminate_mcp_client(name)
except Exception as e:
import traceback
traceback.print_exc()
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
async def _init_mcp_client(self, name: str, config: dict) -> None:
@@ -302,10 +317,10 @@ class FuncCall:
mcp_client = MCPClient()
mcp_client.name = name
await mcp_client.connect_to_server(config)
self.mcp_client_dict[name] = mcp_client
await mcp_client.connect_to_server(config, name)
tools_res = await mcp_client.list_tools_and_save()
tool_names = [tool.name for tool in tools_res.tools]
self.mcp_client_dict[name] = mcp_client
# 移除该MCP服务之前的工具(如有)
self.func_list = [
@@ -329,6 +344,9 @@ class FuncCall:
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
return True
except Exception as e:
import traceback
logger.error(traceback.format_exc())
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
# 发生错误时确保客户端被清理
if name in self.mcp_client_dict:
@@ -352,7 +370,7 @@ class FuncCall:
]
logger.info(f"已关闭 MCP 服务 {name}")
def get_func_desc_openai_style(self, omit_empty_parameter_field = False) -> list:
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
"""
获得 OpenAI API 风格的**已经激活**的工具描述
"""