From bf4c2ecd330832e46381fd4140bf96cbec4a0d1a Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sun, 20 Apr 2025 11:02:28 +0800
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20MCP=20=E6=94=AF=E6=8C=81=20?=
=?UTF-8?q?SSE=20=E4=BC=A0=E8=BE=93=E5=8D=8F=E8=AE=AE=E8=BF=9E=E6=8E=A5?=
=?UTF-8?q?=E5=88=B0=E6=9C=8D=E5=8A=A1=E5=99=A8?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
astrbot/core/provider/func_tool_manager.py | 60 ++++++++++++++--------
dashboard/src/views/ToolUsePage.vue | 4 +-
2 files changed, 42 insertions(+), 22 deletions(-)
diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py
index 793d21a12..918542f17 100644
--- a/astrbot/core/provider/func_tool_manager.py
+++ b/astrbot/core/provider/func_tool_manager.py
@@ -15,6 +15,7 @@ from astrbot.core.utils.log_pipe import LogPipe
try:
import mcp
+ from mcp.client.sse import sse_client
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
@@ -92,7 +93,9 @@ class MCPClient:
self.server_errlogs: List[str] = []
async def connect_to_server(self, mcp_server_config: dict, name: str):
- """Connect to an MCP server
+ """连接到 MCP 服务器
+
+ 如果 `url` 参数存在,则使用 SSE 的方式连接到 MCP 服务。
Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
@@ -102,29 +105,44 @@ class MCPClient:
key_0 = list(cfg["mcpServers"].keys())[0]
cfg = cfg["mcpServers"][key_0]
cfg.pop("active", None) # Remove active flag from config
- server_params = mcp.StdioServerParameters(
- **cfg,
- )
- def callback(msg: str):
- # 处理 MCP 服务的错误日志
- self.server_errlogs.append(msg)
+ if "url" in cfg:
+ # SSE transport method
+ self._streams_context = sse_client(url=cfg["url"])
+ streams = await self._streams_context.__aenter__()
- stdio_transport = await self.exit_stack.enter_async_context(
- mcp.stdio_client(
- server_params,
- errlog=LogPipe(
- level=logging.ERROR,
- logger=logger,
- identifier=f"MCPServer-{name}",
- callback=callback,
+ # Create a new client session
+ # self.session = await self._session_context.__aenter__()
+ self.session = await self.exit_stack.enter_async_context(
+ mcp.ClientSession(*streams)
+ )
+
+ else:
+ server_params = mcp.StdioServerParameters(
+ **cfg,
+ )
+
+ def callback(msg: str):
+ # 处理 MCP 服务的错误日志
+ self.server_errlogs.append(msg)
+
+ stdio_transport = await self.exit_stack.enter_async_context(
+ mcp.stdio_client(
+ server_params,
+ errlog=LogPipe(
+ level=logging.ERROR,
+ logger=logger,
+ identifier=f"MCPServer-{name}",
+ callback=callback,
+ ),
),
- ),
- )
- self.stdio, self.write = stdio_transport
- self.session = await self.exit_stack.enter_async_context(
- mcp.ClientSession(self.stdio, self.write)
- )
+ )
+
+ # Create a new client session
+ self.session = await self.exit_stack.enter_async_context(
+ mcp.ClientSession(*stdio_transport)
+ )
+
await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult:
diff --git a/dashboard/src/views/ToolUsePage.vue b/dashboard/src/views/ToolUsePage.vue
index 954e2e3aa..a53be79a0 100644
--- a/dashboard/src/views/ToolUsePage.vue
+++ b/dashboard/src/views/ToolUsePage.vue
@@ -427,7 +427,9 @@
使用模板
- ⚠ 某些 MCP 服务器可能需要按照其要求在 `env` 中填充 `API_KEY` 或 `TOKEN` 等信息,请注意检查是否填写。
+ 1. 某些 MCP 服务器可能需要按照其要求在 env 中填充 `API_KEY` 或 `TOKEN` 等信息,请注意检查是否填写。
+
+ 2. 当配置中带有 url 参数时,将使用 SSE 的方式连接到服务器。