From 33fd6a5016fd7efcce2e8112e2ca83a31a944357 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 17 Apr 2025 13:59:10 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=20MCP=20=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1=E5=99=A8=E7=9A=84=E6=97=A5=E5=BF=97=E5=9B=9E=E6=98=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/log.py | 5 ++- astrbot/core/provider/func_tool_manager.py | 44 +++++++++++++++------- astrbot/core/utils/log_pipe.py | 36 ++++++++++++++++++ astrbot/dashboard/routes/tools.py | 9 +++-- dashboard/src/views/ToolUsePage.vue | 39 +++++++++++++++---- 5 files changed, 107 insertions(+), 26 deletions(-) create mode 100644 astrbot/core/utils/log_pipe.py diff --git a/astrbot/core/log.py b/astrbot/core/log.py index 6609b8246..9b78eaec6 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -25,6 +25,7 @@ import logging import colorlog import asyncio import os +import sys from collections import deque from asyncio import Queue from typing import List @@ -171,7 +172,9 @@ class LogManager: if logger.hasHandlers(): return logger # 如果logger没有处理器 - console_handler = logging.StreamHandler() # 创建一个StreamHandler用于控制台输出 + console_handler = logging.StreamHandler( + sys.stdout + ) # 创建一个StreamHandler用于控制台输出 console_handler.setLevel( logging.DEBUG ) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 53f5048fa..793d21a12 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -4,12 +4,14 @@ import textwrap import os import asyncio import copy +import logging from typing import Dict, List, Awaitable, Literal, Any from dataclasses import dataclass from typing import Optional from contextlib import AsyncExitStack from astrbot import logger +from astrbot.core.utils.log_pipe import LogPipe try: import mcp @@ -87,8 +89,9 @@ class MCPClient: self.name = None self.active: bool = True self.tools: List[mcp.Tool] = [] + self.server_errlogs: List[str] = [] - async def connect_to_server(self, mcp_server_config: dict): + async def connect_to_server(self, mcp_server_config: dict, name: str): """Connect to an MCP server Args: @@ -98,19 +101,30 @@ class MCPClient: if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0: key_0 = list(cfg["mcpServers"].keys())[0] cfg = cfg["mcpServers"][key_0] - cfg.pop("active", None) # Remove active flag from config + cfg.pop("active", None) # Remove active flag from config server_params = mcp.StdioServerParameters( **cfg, ) + def callback(msg: str): + # 处理 MCP 服务的错误日志 + self.server_errlogs.append(msg) + stdio_transport = await self.exit_stack.enter_async_context( - mcp.stdio_client(server_params) + mcp.stdio_client( + server_params, + errlog=LogPipe( + level=logging.ERROR, + logger=logger, + identifier=f"MCPServer-{name}", + callback=callback, + ), + ), ) self.stdio, self.write = stdio_transport self.session = await self.exit_stack.enter_async_context( mcp.ClientSession(self.stdio, self.write) ) - await self.session.initialize() async def list_tools_and_save(self) -> mcp.ListToolsResult: @@ -266,7 +280,9 @@ class FuncCall: self.func_list = [ f for f in self.func_list - if not (f.origin == "mcp" and f.mcp_server_name == data["name"]) + if not ( + f.origin == "mcp" and f.mcp_server_name == data["name"] + ) ] else: for name in self.mcp_client_dict.keys(): @@ -275,11 +291,7 @@ class FuncCall: if name in self.mcp_client_event: self.mcp_client_event[name].set() self.mcp_client_event.pop(name, None) - self.func_list = [ - f - for f in self.func_list - if f.origin != "mcp" - ] + self.func_list = [f for f in self.func_list if f.origin != "mcp"] async def _init_mcp_client_task_wrapper( self, name: str, cfg: dict, event: asyncio.Event @@ -291,6 +303,9 @@ class FuncCall: logger.info(f"收到 MCP 客户端 {name} 终止信号") await self._terminate_mcp_client(name) except Exception as e: + import traceback + + traceback.print_exc() logger.error(f"初始化 MCP 客户端 {name} 失败: {e}") async def _init_mcp_client(self, name: str, config: dict) -> None: @@ -302,10 +317,10 @@ class FuncCall: mcp_client = MCPClient() mcp_client.name = name - await mcp_client.connect_to_server(config) + self.mcp_client_dict[name] = mcp_client + await mcp_client.connect_to_server(config, name) tools_res = await mcp_client.list_tools_and_save() tool_names = [tool.name for tool in tools_res.tools] - self.mcp_client_dict[name] = mcp_client # 移除该MCP服务之前的工具(如有) self.func_list = [ @@ -329,6 +344,9 @@ class FuncCall: logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}") return True except Exception as e: + import traceback + + logger.error(traceback.format_exc()) logger.error(f"初始化 MCP 客户端 {name} 失败: {e}") # 发生错误时确保客户端被清理 if name in self.mcp_client_dict: @@ -352,7 +370,7 @@ class FuncCall: ] logger.info(f"已关闭 MCP 服务 {name}") - def get_func_desc_openai_style(self, omit_empty_parameter_field = False) -> list: + def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list: """ 获得 OpenAI API 风格的**已经激活**的工具描述 """ diff --git a/astrbot/core/utils/log_pipe.py b/astrbot/core/utils/log_pipe.py new file mode 100644 index 000000000..bf5402f17 --- /dev/null +++ b/astrbot/core/utils/log_pipe.py @@ -0,0 +1,36 @@ +import threading +import os +from logging import Logger + + +class LogPipe(threading.Thread): + def __init__( + self, + level, + logger: Logger, + identifier=None, + callback=None, + ): + threading.Thread.__init__(self) + self.daemon = True + self.level = level + self.fd_read, self.fd_write = os.pipe() + self.identifier = identifier + self.logger = logger + self.callback = callback + self.reader = os.fdopen(self.fd_read) + self.start() + + def fileno(self): + return self.fd_write + + def run(self): + for line in iter(self.reader.readline, ""): + if self.callback: + self.callback(line.strip()) + self.logger.log(self.level, f"[{self.identifier}] {line.strip()}") + + self.reader.close() + + def close(self): + os.close(self.fd_write) diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index af81fe0d2..9fda62cea 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -80,6 +80,7 @@ class ToolsRoute(Route): ) in self.tool_mgr.mcp_client_dict.items(): if name_key == name: server_info["tools"] = [tool.name for tool in mcp_client.tools] + server_info["errlogs"] = mcp_client.server_errlogs break else: server_info["tools"] = [] @@ -107,7 +108,7 @@ class ToolsRoute(Route): # 复制所有配置字段 for key, value in server_data.items(): - if key not in ["name", "active", "tools"]: # 排除特殊字段 + if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段 if key == "mcpServers": key_0 = list(server_data["mcpServers"].keys())[ 0 @@ -129,7 +130,7 @@ class ToolsRoute(Route): if self.save_mcp_config(config): # 动态初始化新MCP客户端 - self.tool_mgr.mcp_service_queue.put_nowait( + await self.tool_mgr.mcp_service_queue.put( { "type": "init", "name": name, @@ -170,7 +171,7 @@ class ToolsRoute(Route): # 复制所有配置字段 for key, value in server_data.items(): - if key not in ["name", "active", "tools"]: # 排除特殊字段 + if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段 if key == "mcpServers": key_0 = list(server_data["mcpServers"].keys())[ 0 @@ -208,7 +209,7 @@ class ToolsRoute(Route): ) else: # 客户端不存在,初始化 - self.tool_mgr.mcp_service_queue.put_nowait( + await self.tool_mgr.mcp_service_queue.put( { "type": "init", "name": name, diff --git a/dashboard/src/views/ToolUsePage.vue b/dashboard/src/views/ToolUsePage.vue index c5499ca8b..c977613e5 100644 --- a/dashboard/src/views/ToolUsePage.vue +++ b/dashboard/src/views/ToolUsePage.vue @@ -55,9 +55,11 @@ mdi-server MCP 服务器 - - + + 刷新 + + 新增服务器 @@ -77,7 +79,21 @@
- {{ server.name }} +
+ {{ server.name }} + + + +
{{ server.errlogs }}
+
+ +
+