From d4d9a1df4cb9065ef818062478569e47158778c4 Mon Sep 17 00:00:00 2001 From: Alero Date: Sat, 15 Mar 2025 19:47:06 +0800 Subject: [PATCH] =?UTF-8?q?feat:=E6=96=B0=E5=A2=9EMCP=E6=9C=8D=E5=8A=A1?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=B9=B6=E4=BC=98=E5=8C=96=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E8=B0=83=E7=94=A8=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 引入MCP客户端支持,增加mcp_server.json配置样例,完善工具描述生成及调用逻辑以支持MCP服务工具功能。同时调整相关逻辑以区分本地工具与MCP工具的调用方式,提升扩展性和灵活性。 --- .../process_stage/method/llm_request.py | 42 +++++--- astrbot/core/provider/func_tool_manager.py | 98 ++++++++++++++++++- astrbot/core/provider/manager.py | 3 + .../core/provider/sources/llmtuner_source.py | 2 +- .../core/provider/sources/openai_source.py | 8 +- requirements.txt | 3 + 6 files changed, 136 insertions(+), 20 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index e8246805e..c243583a9 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -175,21 +175,35 @@ class LLMRequestSubStage(Stage): for func_tool_name, func_tool_args in zip( llm_response.tools_call_name, llm_response.tools_call_args ): - func_tool = req.func_tool.get_func(func_tool_name) - logger.info( - f"调用工具函数:{func_tool_name},参数:{func_tool_args}" - ) try: - # 尝试调用工具函数 - wrapper = self._call_handler( - self.ctx, event, func_tool.handler, **func_tool_args - ) - async for resp in wrapper: - if resp is not None: # 有 return 返回 - function_calling_result[func_tool_name] = resp - else: - yield # 有生成器返回 - event.clear_result() # 清除上一个 handler 的结果 + if func_tool_name.startswith('mcp:'): + _, mcp_server_name, mcp_func_name = func_tool_name.split(':') + logger.info( + f"从mcp服务 {mcp_server_name} 调用工具函数:{mcp_func_name},参数:{func_tool_args}") + + client = req.func_tool.mcp_client_dict[mcp_server_name] + res = await client.session.call_tool(mcp_func_name, func_tool_args) + if res: + # TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。 + res_event = event.plain_result(res.content[0].text) + event.set_result(res_event) + yield + else: + func_tool = req.func_tool.get_func(func_tool_name) + logger.info( + f"调用工具函数:{func_tool_name},参数:{func_tool_args}" + ) + + # 尝试调用工具函数 + wrapper = self._call_handler( + self.ctx, event, func_tool.handler, **func_tool_args + ) + async for resp in wrapper: + if resp is not None: # 有 return 返回 + function_calling_result[func_tool_name] = resp + else: + yield # 有生成器返回 + event.clear_result() # 清除上一个 handler 的结果 except BaseException as e: logger.warning(traceback.format_exc()) function_calling_result[func_tool_name] = ( diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 0f04628f7..2c7381a16 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -1,9 +1,19 @@ import json import textwrap +import os +import asyncio from typing import Dict, List, Awaitable from dataclasses import dataclass +from typing import Optional +from contextlib import AsyncExitStack + +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client from astrbot import logger +from anthropic import Anthropic + +from ... import logger @dataclass class FuncTool: @@ -33,9 +43,44 @@ SUPPORTED_TYPES = [ ] # json schema 支持的数据类型 +class MCPClient: + def __init__(self): + # Initialize session and client objects + self.session: Optional[ClientSession] = None + self.exit_stack = AsyncExitStack() + self.anthropic = Anthropic() + + self.name = None + self.active: bool = True + + async def connect_to_server(self, server_script_path: str): + """Connect to an MCP server + + Args: + server_script_path: Path to the server script (.py or .js) + """ + is_python = server_script_path.endswith('.py') + is_js = server_script_path.endswith('.js') + if not (is_python or is_js): + raise ValueError("Server script must be a .py or .js file") + + command = "python" if is_python else "node" + server_params = StdioServerParameters( + command=command, + args=[server_script_path], + env=None + ) + + stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params)) + self.stdio, self.write = stdio_transport + self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write)) + + await self.session.initialize() + class FuncCall: def __init__(self) -> None: self.func_list: List[FuncTool] = [] + self.mcp_client_dict: Dict[str: MCPClient] = dict() def empty(self) -> bool: return len(self.func_list) == 0 @@ -90,7 +135,43 @@ class FuncCall: return f return None - def get_func_desc_openai_style(self) -> list: + async def init_mcp_client_list(self) -> None: + """ + 从项目根目录读取mcp_server.json。提供mcp_server.json.example作为参考。 + 内容格式为 + { + "mcpServers": { + "example_cmp_server": { + "script_path": "path/to/cmp/server/script.py" + } + }, + ... + } + """ + current_dir = os.path.dirname(os.path.abspath(__file__)) + project_root = os.path.abspath(os.path.join(current_dir, "../../..")) + + mcp_json_file = os.path.join(project_root, "mcp_server.json") + if not os.path.exists(mcp_json_file): + # 配置文件不存在错误处理 + logger.warning(f"mcp server config file {mcp_json_file} not found. skip init mcp client list.") + return + + mcp_server_json_obj = json.load(open(mcp_json_file, "r", encoding="utf-8")) + + for mcp_server_name, mcp_server_script_path in mcp_server_json_obj["mcpServers"].items(): + if not os.path.exists(mcp_server_script_path["script_path"]): + logger.error(f"mcp server import err: Server script {mcp_server_script_path["script_path"]} not found.") + continue + mcp_client = MCPClient() + mcp_client.name = mcp_server_name + await mcp_client.connect_to_server(mcp_server_script_path["script_path"]) + self.mcp_client_dict[mcp_server_name] = mcp_client + logger.info(f"添加mcp服务 {mcp_server_name}.") + if len(self.mcp_client_dict) == 0: + logger.info("未启用任何mcp服务.") + + async def get_func_desc_openai_style(self) -> list: """ 获得 OpenAI API 风格的**已经激活**的工具描述 """ @@ -105,9 +186,22 @@ class FuncCall: "name": f.name, "parameters": f.parameters, "description": f.description, - }, + } } ) + + loop = asyncio.get_event_loop() + for name, client in self.mcp_client_dict.items(): + responses = await client.session.list_tools() + for tool in responses.tools: + _l.append({ + "type": "function", + "function": { + "name": f"mcp:{name}:{tool.name}", + "parameters": tool.inputSchema, + "description": tool.description, + } + }) return _l def get_func_desc_anthropic_style(self) -> list: diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 28146333f..c1acf74c0 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -127,6 +127,9 @@ class ProviderManager: if self.tts_enabled and not self.curr_tts_provider_inst: logger.warning("未启用任何用于 文本转语音 的提供商适配器。") + # 初始化mcpclient连接 + await self.llm_tools.init_mcp_client_list() + async def load_provider(self, provider_config: dict): if not provider_config["enable"]: return diff --git a/astrbot/core/provider/sources/llmtuner_source.py b/astrbot/core/provider/sources/llmtuner_source.py index bfd9e03a5..adb5bf428 100644 --- a/astrbot/core/provider/sources/llmtuner_source.py +++ b/astrbot/core/provider/sources/llmtuner_source.py @@ -85,7 +85,7 @@ class LLMTunerModelLoader(Provider): "system": system_prompt, } if func_tool: - tool_list = func_tool.get_func_desc_openai_style() + tool_list = await func_tool.get_func_desc_openai_style() if tool_list: conf["tools"] = tool_list diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 4b66c50ef..5da2f9263 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -80,7 +80,7 @@ class ProviderOpenAIOfficial(Provider): async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: if tools: - tool_list = tools.get_func_desc_openai_style() + tool_list = await tools.get_func_desc_openai_style() if tool_list: payloads["tools"] = tool_list @@ -124,8 +124,10 @@ class ProviderOpenAIOfficial(Provider): for tool in tools.func_list: if tool.name == tool_call.function.name: args = json.loads(tool_call.function.arguments) - args_ls.append(args) - func_name_ls.append(tool_call.function.name) + if tool_call.function.name.startswith("mcp:") and tool_call.function.name.split(':')[1] in tools.mcp_client_dict: + args = json.loads(tool_call.function.arguments) + args_ls.append(args) + func_name_ls.append(tool_call.function.name) llm_response.role = "tool" llm_response.tools_call_args = args_ls llm_response.tools_call_name = func_name_ls diff --git a/requirements.txt b/requirements.txt index 07b168cc1..23f86be3d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,6 +21,9 @@ psutil>=5.8.0 lark-oapi ormsgpack cryptography + +mcp~=1.4.1 +anthropic~=0.49.0 dashscope python-telegram-bot wechatpy