From d4d9a1df4cb9065ef818062478569e47158778c4 Mon Sep 17 00:00:00 2001 From: Alero Date: Sat, 15 Mar 2025 19:47:06 +0800 Subject: [PATCH 01/16] =?UTF-8?q?feat:=E6=96=B0=E5=A2=9EMCP=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1=E6=94=AF=E6=8C=81=E5=B9=B6=E4=BC=98=E5=8C=96=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E8=B0=83=E7=94=A8=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 引入MCP客户端支持,增加mcp_server.json配置样例,完善工具描述生成及调用逻辑以支持MCP服务工具功能。同时调整相关逻辑以区分本地工具与MCP工具的调用方式,提升扩展性和灵活性。 --- .../process_stage/method/llm_request.py | 42 +++++--- astrbot/core/provider/func_tool_manager.py | 98 ++++++++++++++++++- astrbot/core/provider/manager.py | 3 + .../core/provider/sources/llmtuner_source.py | 2 +- .../core/provider/sources/openai_source.py | 8 +- requirements.txt | 3 + 6 files changed, 136 insertions(+), 20 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index e8246805e..c243583a9 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -175,21 +175,35 @@ class LLMRequestSubStage(Stage): for func_tool_name, func_tool_args in zip( llm_response.tools_call_name, llm_response.tools_call_args ): - func_tool = req.func_tool.get_func(func_tool_name) - logger.info( - f"调用工具函数:{func_tool_name},参数:{func_tool_args}" - ) try: - # 尝试调用工具函数 - wrapper = self._call_handler( - self.ctx, event, func_tool.handler, **func_tool_args - ) - async for resp in wrapper: - if resp is not None: # 有 return 返回 - function_calling_result[func_tool_name] = resp - else: - yield # 有生成器返回 - event.clear_result() # 清除上一个 handler 的结果 + if func_tool_name.startswith('mcp:'): + _, mcp_server_name, mcp_func_name = func_tool_name.split(':') + logger.info( + f"从mcp服务 {mcp_server_name} 调用工具函数:{mcp_func_name},参数:{func_tool_args}") + + client = req.func_tool.mcp_client_dict[mcp_server_name] + res = await client.session.call_tool(mcp_func_name, func_tool_args) + if res: + # TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。 + res_event = event.plain_result(res.content[0].text) + event.set_result(res_event) + yield + else: + func_tool = req.func_tool.get_func(func_tool_name) + logger.info( + f"调用工具函数:{func_tool_name},参数:{func_tool_args}" + ) + + # 尝试调用工具函数 + wrapper = self._call_handler( + self.ctx, event, func_tool.handler, **func_tool_args + ) + async for resp in wrapper: + if resp is not None: # 有 return 返回 + function_calling_result[func_tool_name] = resp + else: + yield # 有生成器返回 + event.clear_result() # 清除上一个 handler 的结果 except BaseException as e: logger.warning(traceback.format_exc()) function_calling_result[func_tool_name] = ( diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 0f04628f7..2c7381a16 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -1,9 +1,19 @@ import json import textwrap +import os +import asyncio from typing import Dict, List, Awaitable from dataclasses import dataclass +from typing import Optional +from contextlib import AsyncExitStack + +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client from astrbot import logger +from anthropic import Anthropic + +from ... import logger @dataclass class FuncTool: @@ -33,9 +43,44 @@ SUPPORTED_TYPES = [ ] # json schema 支持的数据类型 +class MCPClient: + def __init__(self): + # Initialize session and client objects + self.session: Optional[ClientSession] = None + self.exit_stack = AsyncExitStack() + self.anthropic = Anthropic() + + self.name = None + self.active: bool = True + + async def connect_to_server(self, server_script_path: str): + """Connect to an MCP server + + Args: + server_script_path: Path to the server script (.py or .js) + """ + is_python = server_script_path.endswith('.py') + is_js = server_script_path.endswith('.js') + if not (is_python or is_js): + raise ValueError("Server script must be a .py or .js file") + + command = "python" if is_python else "node" + server_params = StdioServerParameters( + command=command, + args=[server_script_path], + env=None + ) + + stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params)) + self.stdio, self.write = stdio_transport + self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write)) + + await self.session.initialize() + class FuncCall: def __init__(self) -> None: self.func_list: List[FuncTool] = [] + self.mcp_client_dict: Dict[str: MCPClient] = dict() def empty(self) -> bool: return len(self.func_list) == 0 @@ -90,7 +135,43 @@ class FuncCall: return f return None - def get_func_desc_openai_style(self) -> list: + async def init_mcp_client_list(self) -> None: + """ + 从项目根目录读取mcp_server.json。提供mcp_server.json.example作为参考。 + 内容格式为 + { + "mcpServers": { + "example_cmp_server": { + "script_path": "path/to/cmp/server/script.py" + } + }, + ... + } + """ + current_dir = os.path.dirname(os.path.abspath(__file__)) + project_root = os.path.abspath(os.path.join(current_dir, "../../..")) + + mcp_json_file = os.path.join(project_root, "mcp_server.json") + if not os.path.exists(mcp_json_file): + # 配置文件不存在错误处理 + logger.warning(f"mcp server config file {mcp_json_file} not found. skip init mcp client list.") + return + + mcp_server_json_obj = json.load(open(mcp_json_file, "r", encoding="utf-8")) + + for mcp_server_name, mcp_server_script_path in mcp_server_json_obj["mcpServers"].items(): + if not os.path.exists(mcp_server_script_path["script_path"]): + logger.error(f"mcp server import err: Server script {mcp_server_script_path["script_path"]} not found.") + continue + mcp_client = MCPClient() + mcp_client.name = mcp_server_name + await mcp_client.connect_to_server(mcp_server_script_path["script_path"]) + self.mcp_client_dict[mcp_server_name] = mcp_client + logger.info(f"添加mcp服务 {mcp_server_name}.") + if len(self.mcp_client_dict) == 0: + logger.info("未启用任何mcp服务.") + + async def get_func_desc_openai_style(self) -> list: """ 获得 OpenAI API 风格的**已经激活**的工具描述 """ @@ -105,9 +186,22 @@ class FuncCall: "name": f.name, "parameters": f.parameters, "description": f.description, - }, + } } ) + + loop = asyncio.get_event_loop() + for name, client in self.mcp_client_dict.items(): + responses = await client.session.list_tools() + for tool in responses.tools: + _l.append({ + "type": "function", + "function": { + "name": f"mcp:{name}:{tool.name}", + "parameters": tool.inputSchema, + "description": tool.description, + } + }) return _l def get_func_desc_anthropic_style(self) -> list: diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 28146333f..c1acf74c0 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -127,6 +127,9 @@ class ProviderManager: if self.tts_enabled and not self.curr_tts_provider_inst: logger.warning("未启用任何用于 文本转语音 的提供商适配器。") + # 初始化mcpclient连接 + await self.llm_tools.init_mcp_client_list() + async def load_provider(self, provider_config: dict): if not provider_config["enable"]: return diff --git a/astrbot/core/provider/sources/llmtuner_source.py b/astrbot/core/provider/sources/llmtuner_source.py index bfd9e03a5..adb5bf428 100644 --- a/astrbot/core/provider/sources/llmtuner_source.py +++ b/astrbot/core/provider/sources/llmtuner_source.py @@ -85,7 +85,7 @@ class LLMTunerModelLoader(Provider): "system": system_prompt, } if func_tool: - tool_list = func_tool.get_func_desc_openai_style() + tool_list = await func_tool.get_func_desc_openai_style() if tool_list: conf["tools"] = tool_list diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 4b66c50ef..5da2f9263 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -80,7 +80,7 @@ class ProviderOpenAIOfficial(Provider): async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: if tools: - tool_list = tools.get_func_desc_openai_style() + tool_list = await tools.get_func_desc_openai_style() if tool_list: payloads["tools"] = tool_list @@ -124,8 +124,10 @@ class ProviderOpenAIOfficial(Provider): for tool in tools.func_list: if tool.name == tool_call.function.name: args = json.loads(tool_call.function.arguments) - args_ls.append(args) - func_name_ls.append(tool_call.function.name) + if tool_call.function.name.startswith("mcp:") and tool_call.function.name.split(':')[1] in tools.mcp_client_dict: + args = json.loads(tool_call.function.arguments) + args_ls.append(args) + func_name_ls.append(tool_call.function.name) llm_response.role = "tool" llm_response.tools_call_args = args_ls llm_response.tools_call_name = func_name_ls diff --git a/requirements.txt b/requirements.txt index 07b168cc1..23f86be3d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,6 +21,9 @@ psutil>=5.8.0 lark-oapi ormsgpack cryptography + +mcp~=1.4.1 +anthropic~=0.49.0 dashscope python-telegram-bot wechatpy From 3dea60366aee59d1a4f980ce7316936cac83e891 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 Mar 2025 11:54:09 +0000 Subject: [PATCH 02/16] :balloon: auto fixes by pre-commit hooks --- .../process_stage/method/llm_request.py | 13 +++-- astrbot/core/provider/func_tool_manager.py | 50 ++++++++++++------- .../core/provider/sources/openai_source.py | 5 +- 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index c243583a9..2050f37d6 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -176,13 +176,18 @@ class LLMRequestSubStage(Stage): llm_response.tools_call_name, llm_response.tools_call_args ): try: - if func_tool_name.startswith('mcp:'): - _, mcp_server_name, mcp_func_name = func_tool_name.split(':') + if func_tool_name.startswith("mcp:"): + _, mcp_server_name, mcp_func_name = func_tool_name.split( + ":" + ) logger.info( - f"从mcp服务 {mcp_server_name} 调用工具函数:{mcp_func_name},参数:{func_tool_args}") + f"从mcp服务 {mcp_server_name} 调用工具函数:{mcp_func_name},参数:{func_tool_args}" + ) client = req.func_tool.mcp_client_dict[mcp_server_name] - res = await client.session.call_tool(mcp_func_name, func_tool_args) + res = await client.session.call_tool( + mcp_func_name, func_tool_args + ) if res: # TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。 res_event = event.plain_result(res.content[0].text) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 2c7381a16..5bfabf0b5 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -15,6 +15,7 @@ from anthropic import Anthropic from ... import logger + @dataclass class FuncTool: """ @@ -59,28 +60,31 @@ class MCPClient: Args: server_script_path: Path to the server script (.py or .js) """ - is_python = server_script_path.endswith('.py') - is_js = server_script_path.endswith('.js') + is_python = server_script_path.endswith(".py") + is_js = server_script_path.endswith(".js") if not (is_python or is_js): raise ValueError("Server script must be a .py or .js file") command = "python" if is_python else "node" server_params = StdioServerParameters( - command=command, - args=[server_script_path], - env=None + command=command, args=[server_script_path], env=None ) - stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params)) + stdio_transport = await self.exit_stack.enter_async_context( + stdio_client(server_params) + ) self.stdio, self.write = stdio_transport - self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write)) + self.session = await self.exit_stack.enter_async_context( + ClientSession(self.stdio, self.write) + ) await self.session.initialize() + class FuncCall: def __init__(self) -> None: self.func_list: List[FuncTool] = [] - self.mcp_client_dict: Dict[str: MCPClient] = dict() + self.mcp_client_dict: Dict[str:MCPClient] = dict() def empty(self) -> bool: return len(self.func_list) == 0 @@ -154,14 +158,20 @@ class FuncCall: mcp_json_file = os.path.join(project_root, "mcp_server.json") if not os.path.exists(mcp_json_file): # 配置文件不存在错误处理 - logger.warning(f"mcp server config file {mcp_json_file} not found. skip init mcp client list.") + logger.warning( + f"mcp server config file {mcp_json_file} not found. skip init mcp client list." + ) return mcp_server_json_obj = json.load(open(mcp_json_file, "r", encoding="utf-8")) - for mcp_server_name, mcp_server_script_path in mcp_server_json_obj["mcpServers"].items(): + for mcp_server_name, mcp_server_script_path in mcp_server_json_obj[ + "mcpServers" + ].items(): if not os.path.exists(mcp_server_script_path["script_path"]): - logger.error(f"mcp server import err: Server script {mcp_server_script_path["script_path"]} not found.") + logger.error( + f"mcp server import err: Server script {mcp_server_script_path['script_path']} not found." + ) continue mcp_client = MCPClient() mcp_client.name = mcp_server_name @@ -186,7 +196,7 @@ class FuncCall: "name": f.name, "parameters": f.parameters, "description": f.description, - } + }, } ) @@ -194,14 +204,16 @@ class FuncCall: for name, client in self.mcp_client_dict.items(): responses = await client.session.list_tools() for tool in responses.tools: - _l.append({ - "type": "function", - "function": { - "name": f"mcp:{name}:{tool.name}", - "parameters": tool.inputSchema, - "description": tool.description, + _l.append( + { + "type": "function", + "function": { + "name": f"mcp:{name}:{tool.name}", + "parameters": tool.inputSchema, + "description": tool.description, + }, } - }) + ) return _l def get_func_desc_anthropic_style(self) -> list: diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 5da2f9263..342c2febb 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -124,7 +124,10 @@ class ProviderOpenAIOfficial(Provider): for tool in tools.func_list: if tool.name == tool_call.function.name: args = json.loads(tool_call.function.arguments) - if tool_call.function.name.startswith("mcp:") and tool_call.function.name.split(':')[1] in tools.mcp_client_dict: + if ( + tool_call.function.name.startswith("mcp:") + and tool_call.function.name.split(":")[1] in tools.mcp_client_dict + ): args = json.loads(tool_call.function.arguments) args_ls.append(args) func_name_ls.append(tool_call.function.name) From 3d6f7aa0e14a7f6f21c8ccc329c71a10b1d0f7f7 Mon Sep 17 00:00:00 2001 From: Alero Date: Sat, 15 Mar 2025 20:09:49 +0800 Subject: [PATCH 03/16] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dcodecheck?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/func_tool_manager.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 5bfabf0b5..bff0b53c3 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -13,9 +13,6 @@ from astrbot import logger from anthropic import Anthropic -from ... import logger - - @dataclass class FuncTool: """ From 2d1f74228de0432060cb93a9e44c0f178e8d6870 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 Mar 2025 12:10:17 +0000 Subject: [PATCH 04/16] :balloon: auto fixes by pre-commit hooks --- astrbot/core/provider/func_tool_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index bff0b53c3..7b1265b20 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -13,6 +13,7 @@ from astrbot import logger from anthropic import Anthropic + @dataclass class FuncTool: """ From 9fa2a7eeea98ab691edec5e752a567bb37718cbe Mon Sep 17 00:00:00 2001 From: Alero Date: Sat, 15 Mar 2025 20:24:36 +0800 Subject: [PATCH 05/16] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dcodecheck?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/func_tool_manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 7b1265b20..adba2c40b 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -82,7 +82,7 @@ class MCPClient: class FuncCall: def __init__(self) -> None: self.func_list: List[FuncTool] = [] - self.mcp_client_dict: Dict[str:MCPClient] = dict() + self.mcp_client_dict: Dict[str:MCPClient] = {} def empty(self) -> bool: return len(self.func_list) == 0 @@ -198,7 +198,6 @@ class FuncCall: } ) - loop = asyncio.get_event_loop() for name, client in self.mcp_client_dict.items(): responses = await client.session.list_tools() for tool in responses.tools: From 8585cd8e2111908cb8e4c907caf91d910f0d0e22 Mon Sep 17 00:00:00 2001 From: Alero Date: Sat, 15 Mar 2025 20:26:17 +0800 Subject: [PATCH 06/16] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dcodecheck?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/func_tool_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index adba2c40b..2c96dc812 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -1,7 +1,7 @@ import json import textwrap import os -import asyncio + from typing import Dict, List, Awaitable from dataclasses import dataclass from typing import Optional From 4179b0be0a666cf4c93655f21a1ccb5714f0f51d Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 23 Mar 2025 11:31:10 +0800 Subject: [PATCH 07/16] =?UTF-8?q?chore:=20=E4=BC=98=E5=8C=96=E6=B3=A8?= =?UTF-8?q?=E8=A7=A3=E6=A0=BC=E5=BC=8F=E5=92=8C=20requirements.txt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/func_tool_manager.py | 16 ++++++++-------- requirements.txt | 6 ++---- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 2c96dc812..e38bd3a77 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -82,7 +82,7 @@ class MCPClient: class FuncCall: def __init__(self) -> None: self.func_list: List[FuncTool] = [] - self.mcp_client_dict: Dict[str:MCPClient] = {} + self.mcp_client_dict: Dict[str, MCPClient] = {} def empty(self) -> bool: return len(self.func_list) == 0 @@ -138,9 +138,8 @@ class FuncCall: return None async def init_mcp_client_list(self) -> None: - """ - 从项目根目录读取mcp_server.json。提供mcp_server.json.example作为参考。 - 内容格式为 + """从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下: + ``` { "mcpServers": { "example_cmp_server": { @@ -149,6 +148,7 @@ class FuncCall: }, ... } + ``` """ current_dir = os.path.dirname(os.path.abspath(__file__)) project_root = os.path.abspath(os.path.join(current_dir, "../../..")) @@ -161,23 +161,23 @@ class FuncCall: ) return - mcp_server_json_obj = json.load(open(mcp_json_file, "r", encoding="utf-8")) + mcp_server_json_obj: Dict[str, Dict] = json.load(open(mcp_json_file, "r", encoding="utf-8")) for mcp_server_name, mcp_server_script_path in mcp_server_json_obj[ "mcpServers" ].items(): if not os.path.exists(mcp_server_script_path["script_path"]): logger.error( - f"mcp server import err: Server script {mcp_server_script_path['script_path']} not found." + f"MCP server import err: Server script {mcp_server_script_path['script_path']} not found." ) continue mcp_client = MCPClient() mcp_client.name = mcp_server_name await mcp_client.connect_to_server(mcp_server_script_path["script_path"]) self.mcp_client_dict[mcp_server_name] = mcp_client - logger.info(f"添加mcp服务 {mcp_server_name}.") + logger.info(f"添加 MCP 服务 {mcp_server_name}") if len(self.mcp_client_dict) == 0: - logger.info("未启用任何mcp服务.") + logger.info("未启用任何 MCP 服务") async def get_func_desc_openai_style(self) -> list: """ diff --git a/requirements.txt b/requirements.txt index 23f86be3d..313dba0c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,10 +21,8 @@ psutil>=5.8.0 lark-oapi ormsgpack cryptography - -mcp~=1.4.1 -anthropic~=0.49.0 dashscope python-telegram-bot wechatpy -dingtalk-stream \ No newline at end of file +dingtalk-stream +mcp \ No newline at end of file From 9f8e960ebe3b8569f2f4d8f920ad364d39840a8b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 23 Mar 2025 03:31:19 +0000 Subject: [PATCH 08/16] :balloon: auto fixes by pre-commit hooks --- astrbot/core/provider/func_tool_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index e38bd3a77..9379096e4 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -161,7 +161,9 @@ class FuncCall: ) return - mcp_server_json_obj: Dict[str, Dict] = json.load(open(mcp_json_file, "r", encoding="utf-8")) + mcp_server_json_obj: Dict[str, Dict] = json.load( + open(mcp_json_file, "r", encoding="utf-8") + ) for mcp_server_name, mcp_server_script_path in mcp_server_json_obj[ "mcpServers" From 046f5e645ed81c3d23099b11da38a67bfbe63704 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 23 Mar 2025 16:31:26 +0800 Subject: [PATCH 09/16] =?UTF-8?q?=E2=9C=A8=20feat:=20=E5=AE=8C=E5=96=84=20?= =?UTF-8?q?MCP=20=E7=AE=A1=E7=90=86=E5=92=8C=E5=AE=9E=E7=8E=B0=20WebUI=20M?= =?UTF-8?q?CP=20=E7=9B=B8=E5=85=B3=E7=9A=84=E9=A1=B5=E9=9D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../process_stage/method/llm_request.py | 17 +- astrbot/core/provider/func_tool_manager.py | 268 +++++--- astrbot/core/provider/manager.py | 8 +- .../core/provider/sources/llmtuner_source.py | 2 +- .../core/provider/sources/openai_source.py | 7 +- astrbot/dashboard/routes/__init__.py | 2 + astrbot/dashboard/routes/config.py | 7 + astrbot/dashboard/routes/tools.py | 250 +++++++ astrbot/dashboard/server.py | 1 + .../src/components/shared/ItemCardGrid.vue | 134 ++++ .../full/vertical-sidebar/sidebarItem.ts | 5 + dashboard/src/router/MainRoutes.ts | 5 + dashboard/src/views/PlatformPage.vue | 566 ++++++++-------- dashboard/src/views/ProviderPage.vue | 567 ++++++++-------- dashboard/src/views/ToolUsePage.vue | 631 ++++++++++++++++++ 15 files changed, 1820 insertions(+), 650 deletions(-) create mode 100644 astrbot/dashboard/routes/tools.py create mode 100644 dashboard/src/components/shared/ItemCardGrid.vue create mode 100644 dashboard/src/views/ToolUsePage.vue diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 7d7c4516f..3767cc59d 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -176,17 +176,14 @@ class LLMRequestSubStage(Stage): llm_response.tools_call_name, llm_response.tools_call_args ): try: - if func_tool_name.startswith("mcp:"): - _, mcp_server_name, mcp_func_name = func_tool_name.split( - ":" - ) + func_tool = req.func_tool.get_func(func_tool_name) + if func_tool.origin == "mcp": logger.info( - f"从mcp服务 {mcp_server_name} 调用工具函数:{mcp_func_name},参数:{func_tool_args}" + f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}" ) - - client = req.func_tool.mcp_client_dict[mcp_server_name] + client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name] res = await client.session.call_tool( - mcp_func_name, func_tool_args + func_tool.name, func_tool_args ) if res: # TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。 @@ -194,11 +191,9 @@ class LLMRequestSubStage(Stage): event.set_result(res_event) yield else: - func_tool = req.func_tool.get_func(func_tool_name) logger.info( f"调用工具函数:{func_tool_name},参数:{func_tool_args}" ) - # 尝试调用工具函数 wrapper = self._call_handler( self.ctx, event, func_tool.handler, **func_tool_args @@ -208,7 +203,7 @@ class LLMRequestSubStage(Stage): function_calling_result[func_tool_name] = resp else: yield # 有生成器返回 - event.clear_result() # 清除上一个 handler 的结果 + event.clear_result() # 清除上一个 handler 的结果 except BaseException as e: logger.warning(traceback.format_exc()) function_calling_result[func_tool_name] = ( diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 9379096e4..93d99de65 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -1,17 +1,27 @@ +from __future__ import annotations import json import textwrap import os +import asyncio +import mcp -from typing import Dict, List, Awaitable +from typing import Dict, List, Awaitable, Literal, Any from dataclasses import dataclass from typing import Optional from contextlib import AsyncExitStack -from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client from astrbot import logger -from anthropic import Anthropic +DEFAULT_MCP_CONFIG = {"mcpServers": {}} + +SUPPORTED_TYPES = [ + "string", + "number", + "object", + "array", + "boolean", +] # json schema 支持的数据类型 @dataclass @@ -23,49 +33,68 @@ class FuncTool: name: str parameters: Dict description: str - handler: Awaitable - handler_module_path: str = None # 必须要保留这个,handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools + handler: Awaitable = None + """处理函数, 当 origin 为 mcp 时,这个为空""" + handler_module_path: str = None + """处理函数的模块路径,当 origin 为 mcp 时,这个为空 + 必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools + """ active: bool = True """是否激活""" + origin: Literal["local", "mcp"] = "local" + """函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务""" + + # MCP 相关字段 + mcp_server_name: str = None + """MCP 服务名称,当 origin 为 mcp 时有效""" + mcp_client: MCPClient = None + """MCP 客户端,当 origin 为 mcp 时有效""" + def __repr__(self): - return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}), active={self.active})" + return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})" - -SUPPORTED_TYPES = [ - "string", - "number", - "object", - "array", - "boolean", -] # json schema 支持的数据类型 + async def execute(self, **args) -> Any: + """执行函数调用""" + if self.origin == "local": + if not self.handler: + raise Exception(f"Local function {self.name} has no handler") + return await self.handler(**args) + elif self.origin == "mcp": + if not self.mcp_client or not self.mcp_client.session: + raise Exception(f"MCP client for {self.name} is not available") + # 使用name属性而不是额外的mcp_tool_name + if ":" in self.name: + # 如果名字是格式为 mcp:server:tool_name,提取实际的工具名 + actual_tool_name = self.name.split(":")[-1] + return await self.mcp_client.session.call_tool(actual_tool_name, args) + else: + return await self.mcp_client.session.call_tool(self.name, args) + else: + raise Exception(f"Unknown function origin: {self.origin}") class MCPClient: def __init__(self): # Initialize session and client objects - self.session: Optional[ClientSession] = None + self.session: Optional[mcp.ClientSession] = None self.exit_stack = AsyncExitStack() - self.anthropic = Anthropic() self.name = None self.active: bool = True + self.tools: List[mcp.Tool] = [] - async def connect_to_server(self, server_script_path: str): + async def connect_to_server(self, mcp_server_config: dict): """Connect to an MCP server Args: - server_script_path: Path to the server script (.py or .js) + mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server """ - is_python = server_script_path.endswith(".py") - is_js = server_script_path.endswith(".js") - if not (is_python or is_js): - raise ValueError("Server script must be a .py or .js file") - - command = "python" if is_python else "node" - server_params = StdioServerParameters( - command=command, args=[server_script_path], env=None + cfg = mcp_server_config.copy() + cfg.pop("active", None) + server_params = mcp.StdioServerParameters( + **cfg, ) stdio_transport = await self.exit_stack.enter_async_context( @@ -73,16 +102,31 @@ class MCPClient: ) self.stdio, self.write = stdio_transport self.session = await self.exit_stack.enter_async_context( - ClientSession(self.stdio, self.write) + mcp.ClientSession(self.stdio, self.write) ) await self.session.initialize() + async def list_tools_and_save(self) -> mcp.ListToolsResult: + """List all tools from the server and save them to self.tools""" + response = await self.session.list_tools() + logger.debug(f"MCP server {self.name} list tools response: {response}") + self.tools = response.tools + return response + + async def cleanup(self): + """Clean up resources""" + await self.exit_stack.aclose() + class FuncCall: def __init__(self) -> None: self.func_list: List[FuncTool] = [] + """内部加载的 func tools""" self.mcp_client_dict: Dict[str, MCPClient] = {} + """MCP 服务列表""" + self.mcp_service_queue = asyncio.Queue() + """用于外部控制 MCP 服务的启停""" def empty(self) -> bool: return len(self.func_list) == 0 @@ -137,55 +181,139 @@ class FuncCall: return f return None - async def init_mcp_client_list(self) -> None: + async def _init_mcp_clients(self) -> None: """从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下: ``` { "mcpServers": { - "example_cmp_server": { - "script_path": "path/to/cmp/server/script.py" + "weather": { + "command": "uv", + "args": [ + "--directory", + "/ABSOLUTE/PATH/TO/PARENT/FOLDER/weather", + "run", + "weather.py" + ] } - }, + } ... } ``` """ current_dir = os.path.dirname(os.path.abspath(__file__)) - project_root = os.path.abspath(os.path.join(current_dir, "../../..")) + data_dir = os.path.abspath(os.path.join(current_dir, "../../../data")) - mcp_json_file = os.path.join(project_root, "mcp_server.json") + mcp_json_file = os.path.join(data_dir, "mcp_server.json") if not os.path.exists(mcp_json_file): # 配置文件不存在错误处理 - logger.warning( - f"mcp server config file {mcp_json_file} not found. skip init mcp client list." - ) + with open(mcp_json_file, "w", encoding="utf-8") as f: + json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4) + logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}") return mcp_server_json_obj: Dict[str, Dict] = json.load( open(mcp_json_file, "r", encoding="utf-8") - ) + )["mcpServers"] + + for name in mcp_server_json_obj.keys(): + cfg = mcp_server_json_obj[name] + if cfg.get("active", True): + asyncio.create_task(self._init_mcp_client(name, cfg)) + + async def mcp_service_selector(self): + """为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制 + + 使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下: + + {"type": "init"} 初始化所有MCP客户端 + + {"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端 + + {"type": "terminate"} 终止所有MCP客户端 + + {"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端 + """ + while True: + data = await self.mcp_service_queue.get() + if data["type"] == "init": + if "name" in data: + asyncio.create_task(self._init_mcp_client(data["name"], data["cfg"])) + else: + await self._init_mcp_clients() + elif data["type"] == "terminate": + if "name" in data: + await self._terminate_mcp_client(data["name"]) + else: + for name in self.mcp_client_dict.keys(): + await self._terminate_mcp_client(name) + + async def _init_mcp_client(self, name: str, config: dict) -> None: + """初始化单个MCP客户端""" + try: + # 先清理之前的客户端,如果存在 + if name in self.mcp_client_dict: + await self._terminate_mcp_client(name) - for mcp_server_name, mcp_server_script_path in mcp_server_json_obj[ - "mcpServers" - ].items(): - if not os.path.exists(mcp_server_script_path["script_path"]): - logger.error( - f"MCP server import err: Server script {mcp_server_script_path['script_path']} not found." - ) - continue mcp_client = MCPClient() - mcp_client.name = mcp_server_name - await mcp_client.connect_to_server(mcp_server_script_path["script_path"]) - self.mcp_client_dict[mcp_server_name] = mcp_client - logger.info(f"添加 MCP 服务 {mcp_server_name}") - if len(self.mcp_client_dict) == 0: - logger.info("未启用任何 MCP 服务") + mcp_client.name = name + await mcp_client.connect_to_server(config) + tools_res = await mcp_client.list_tools_and_save() + tool_names = [tool.name for tool in tools_res.tools] + self.mcp_client_dict[name] = mcp_client - async def get_func_desc_openai_style(self) -> list: + # 移除该MCP服务之前的工具(如有) + self.func_list = [ + f + for f in self.func_list + if not (f.origin == "mcp" and f.mcp_server_name == name) + ] + + # 将 MCP 工具转换为 FuncTool 并添加到 func_list + for tool in mcp_client.tools: + func_tool = FuncTool( + name=tool.name, + parameters=tool.inputSchema, + description=tool.description, + origin="mcp", + mcp_server_name=name, + mcp_client=mcp_client, + ) + self.func_list.append(func_tool) + + logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}") + return True + except Exception as e: + logger.error(f"初始化 MCP 客户端 {name} 失败: {e}") + # 发生错误时确保客户端被清理 + if name in self.mcp_client_dict: + await self._terminate_mcp_client(name) + return False + + async def _terminate_mcp_client(self, name: str) -> None: + """关闭并清理MCP客户端""" + if name in self.mcp_client_dict: + try: + # 关闭MCP连接 + await self.mcp_client_dict[name].cleanup() + del self.mcp_client_dict[name] + except Exception as e: + logger.info( + f"清空 MCP 客户端资源 {name}: {e}。" + ) + # 移除关联的FuncTool + self.func_list = [ + f + for f in self.func_list + if not (f.origin == "mcp" and f.mcp_server_name == name) + ] + logger.info(f"已关闭 MCP 服务 {name}") + + def get_func_desc_openai_style(self) -> list: """ 获得 OpenAI API 风格的**已经激活**的工具描述 """ _l = [] + # 处理所有工具(包括本地和MCP工具) for f in self.func_list: if not f.active: continue @@ -199,20 +327,6 @@ class FuncCall: }, } ) - - for name, client in self.mcp_client_dict.items(): - responses = await client.session.list_tools() - for tool in responses.tools: - _l.append( - { - "type": "function", - "function": { - "name": f"mcp:{name}:{tool.name}", - "parameters": tool.inputSchema, - "description": tool.description, - }, - } - ) return _l def get_func_desc_anthropic_style(self) -> list: @@ -265,9 +379,9 @@ class FuncCall: continue _l.append( { - "name": f["name"], - "parameters": f["parameters"], - "description": f["description"], + "name": f.name, + "parameters": f.parameters, + "description": f.description, } ) func_definition = json.dumps(_l, ensure_ascii=False) @@ -317,14 +431,11 @@ class FuncCall: func_name = tool["name"] args = tool["args"] # 调用函数 - tool_callable = None - for func in self.func_list: - if func.name == func_name: - tool_callable = func.star_handler_metadata.handler - break - if not tool_callable: + func_tool = self.get_func(func_name) + if not func_tool: raise Exception(f"Request function {func_name} not found.") - ret = await tool_callable(**args) + + ret = await func_tool.execute(**args) if ret: tool_call_result.append(str(ret)) return tool_call_result, True @@ -334,3 +445,8 @@ class FuncCall: def __repr__(self): return str(self.func_list) + + async def terminate(self): + for name in self.mcp_client_dict.keys(): + await self._terminate_mcp_client(name) + logger.debug(f"清理 MCP 客户端 {name} 资源") diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index c1acf74c0..71b38682f 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -1,4 +1,5 @@ import traceback +import asyncio from astrbot.core.config.astrbot_config import AstrBotConfig from .provider import Provider, STTProvider, TTSProvider, Personality from .entites import ProviderType @@ -127,8 +128,9 @@ class ProviderManager: if self.tts_enabled and not self.curr_tts_provider_inst: logger.warning("未启用任何用于 文本转语音 的提供商适配器。") - # 初始化mcpclient连接 - await self.llm_tools.init_mcp_client_list() + # 初始化 MCP Client 连接 + asyncio.create_task(self.llm_tools.mcp_service_selector(), name="mcp-service-handler") + self.llm_tools.mcp_service_queue.put_nowait({"type": "init"}) async def load_provider(self, provider_config: dict): if not provider_config["enable"]: @@ -342,3 +344,5 @@ class ProviderManager: for provider_inst in self.provider_insts: if hasattr(provider_inst, "terminate"): await provider_inst.terminate() + # 清理 MCP Client 连接 + await self.llm_tools.mcp_service_queue.put({"type": "terminate"}) diff --git a/astrbot/core/provider/sources/llmtuner_source.py b/astrbot/core/provider/sources/llmtuner_source.py index adb5bf428..bfd9e03a5 100644 --- a/astrbot/core/provider/sources/llmtuner_source.py +++ b/astrbot/core/provider/sources/llmtuner_source.py @@ -85,7 +85,7 @@ class LLMTunerModelLoader(Provider): "system": system_prompt, } if func_tool: - tool_list = await func_tool.get_func_desc_openai_style() + tool_list = func_tool.get_func_desc_openai_style() if tool_list: conf["tools"] = tool_list diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 342c2febb..897fd4e7e 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -80,7 +80,7 @@ class ProviderOpenAIOfficial(Provider): async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: if tools: - tool_list = await tools.get_func_desc_openai_style() + tool_list = tools.get_func_desc_openai_style() if tool_list: payloads["tools"] = tool_list @@ -124,11 +124,6 @@ class ProviderOpenAIOfficial(Provider): for tool in tools.func_list: if tool.name == tool_call.function.name: args = json.loads(tool_call.function.arguments) - if ( - tool_call.function.name.startswith("mcp:") - and tool_call.function.name.split(":")[1] in tools.mcp_client_dict - ): - args = json.loads(tool_call.function.arguments) args_ls.append(args) func_name_ls.append(tool_call.function.name) llm_response.role = "tool" diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py index f4107bdc5..2e5461981 100644 --- a/astrbot/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -6,6 +6,7 @@ from .stat import StatRoute from .log import LogRoute from .static_file import StaticFileRoute from .chat import ChatRoute +from .tools import ToolsRoute # 导入新的ToolsRoute __all__ = [ @@ -17,4 +18,5 @@ __all__ = [ "LogRoute", "StaticFileRoute", "ChatRoute", + "ToolsRoute", # 添加新的ToolsRoute ] diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 14af21bbc..dcfe50d38 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -146,6 +146,7 @@ class ConfigRoute(Route): "/config/provider/new": ("POST", self.post_new_provider), "/config/provider/update": ("POST", self.post_update_provider), "/config/provider/delete": ("POST", self.post_delete_provider), + "/config/llmtools": ("GET", self.get_llm_tools), } self.register_routes() @@ -278,6 +279,12 @@ class ConfigRoute(Route): return Response().error(str(e)).__dict__ return Response().ok(None, "删除成功,已经实时生效~").__dict__ + async def get_llm_tools(self): + """获取函数调用工具。包含了本地加载的以及 MCP 服务的工具""" + tool_mgr = self.core_lifecycle.provider_manager.llm_tools + tools = tool_mgr.get_func_desc_openai_style() + return Response().ok(tools).__dict__ + async def _get_astrbot_config(self): config = self.config diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py new file mode 100644 index 000000000..8ba166784 --- /dev/null +++ b/astrbot/dashboard/routes/tools.py @@ -0,0 +1,250 @@ +import os +import json +import traceback +from .route import Route, Response, RouteContext +from quart import request +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core import logger + +DEFAULT_MCP_CONFIG = {"mcpServers": {}} + + +class ToolsRoute(Route): + def __init__( + self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle + ) -> None: + super().__init__(context) + self.core_lifecycle = core_lifecycle + self.routes = { + "/tools/mcp/servers": ("GET", self.get_mcp_servers), + "/tools/mcp/add": ("POST", self.add_mcp_server), + "/tools/mcp/update": ("POST", self.update_mcp_server), + "/tools/mcp/delete": ("POST", self.delete_mcp_server), + } + self.register_routes() + self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools + + @property + def mcp_config_path(self): + current_dir = os.path.dirname(os.path.abspath(__file__)) + data_dir = os.path.abspath(os.path.join(current_dir, "../../../data")) + return os.path.join(data_dir, "mcp_server.json") + + def load_mcp_config(self): + if not os.path.exists(self.mcp_config_path): + # 配置文件不存在,创建默认配置 + os.makedirs(os.path.dirname(self.mcp_config_path), exist_ok=True) + with open(self.mcp_config_path, "w", encoding="utf-8") as f: + json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4) + return DEFAULT_MCP_CONFIG + + try: + with open(self.mcp_config_path, "r", encoding="utf-8") as f: + return json.load(f) + except Exception as e: + logger.error(f"加载 MCP 配置失败: {e}") + return DEFAULT_MCP_CONFIG + + def save_mcp_config(self, config): + try: + with open(self.mcp_config_path, "w", encoding="utf-8") as f: + json.dump(config, f, ensure_ascii=False, indent=4) + return True + except Exception as e: + logger.error(f"保存 MCP 配置失败: {e}") + return False + + async def get_mcp_servers(self): + try: + config = self.load_mcp_config() + servers = [] + + # 获取所有服务器并添加它们的工具列表 + for name, server_config in config["mcpServers"].items(): + server_info = { + "name": name, + "active": server_config.get("active", True), + } + + # 复制所有配置字段 + for key, value in server_config.items(): + if key != "active": # active 已经处理 + server_info[key] = value + + # 如果MCP客户端已初始化,从客户端获取工具名称 + for ( + name_key, + mcp_client, + ) in self.tool_mgr.mcp_client_dict.items(): + if name_key == name: + server_info["tools"] = [tool.name for tool in mcp_client.tools] + break + else: + server_info["tools"] = [] + + servers.append(server_info) + + return Response().ok(servers).__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"获取 MCP 服务器列表失败: {str(e)}").__dict__ + + async def add_mcp_server(self): + try: + server_data = await request.json + + name = server_data.get("name", "") + + # 检查必填字段 + if not name: + return Response().error("服务器名称不能为空").__dict__ + + # 移除特殊字段并检查配置是否有效 + has_valid_config = False + server_config = {"active": server_data.get("active", True)} + + # 复制所有配置字段 + for key, value in server_data.items(): + if key not in ["name", "active", "tools"]: # 排除特殊字段 + server_config[key] = value + has_valid_config = True + + if not has_valid_config: + return Response().error("必须提供有效的服务器配置").__dict__ + + config = self.load_mcp_config() + + if name in config["mcpServers"]: + return Response().error(f"服务器 {name} 已存在").__dict__ + + config["mcpServers"][name] = server_config + + if self.save_mcp_config(config): + # 动态初始化新MCP客户端 + self.tool_mgr.mcp_service_queue.put_nowait( + { + "type": "init", + "name": name, + "cfg": config["mcpServers"][name], + } + ) + return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__ + else: + return Response().error("保存配置失败").__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"添加 MCP 服务器失败: {str(e)}").__dict__ + + async def update_mcp_server(self): + try: + server_data = await request.json + + name = server_data.get("name", "") + + if not name: + return Response().error("服务器名称不能为空").__dict__ + + config = self.load_mcp_config() + + if name not in config["mcpServers"]: + return Response().error(f"服务器 {name} 不存在").__dict__ + + # 获取活动状态 + active = server_data.get("active", config["mcpServers"][name].get("active", True)) + + # 创建新的配置对象 + server_config = {"active": active} + + # 仅更新活动状态的特殊处理 + only_update_active = True + + # 复制所有配置字段 + for key, value in server_data.items(): + if key not in ["name", "active", "tools"]: # 排除特殊字段 + server_config[key] = value + only_update_active = False + + # 如果只更新活动状态,保留原始配置 + if only_update_active: + for key, value in config["mcpServers"][name].items(): + if key != "active": # 除了active之外的所有字段都保留 + server_config[key] = value + + config["mcpServers"][name] = server_config + + if self.save_mcp_config(config): + # 处理MCP客户端状态变化 + if active: + # 如果要激活服务器或者配置已更改 + if name in self.tool_mgr.mcp_client_dict or not only_update_active: + await self.tool_mgr.mcp_service_queue.put( + { + "type": "terminate", + "name": name, + } + ) + await self.tool_mgr.mcp_service_queue.put( + { + "type": "init", + "name": name, + "cfg": config["mcpServers"][name], + } + ) + else: + # 客户端不存在,初始化 + self.tool_mgr.mcp_service_queue.put_nowait( + { + "type": "init", + "name": name, + "cfg": config["mcpServers"][name], + } + ) + else: + # 如果要停用服务器 + if name in self.tool_mgr.mcp_client_dict: + self.tool_mgr.mcp_service_queue.put_nowait( + { + "type": "terminate", + "name": name, + } + ) + + return Response().ok(None, f"成功更新 MCP 服务器 {name}").__dict__ + else: + return Response().error("保存配置失败").__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"更新 MCP 服务器失败: {str(e)}").__dict__ + + async def delete_mcp_server(self): + try: + server_data = await request.json + name = server_data.get("name", "") + + if not name: + return Response().error("服务器名称不能为空").__dict__ + + config = self.load_mcp_config() + + if name not in config["mcpServers"]: + return Response().error(f"服务器 {name} 不存在").__dict__ + + # 删除服务器配置 + del config["mcpServers"][name] + + if self.save_mcp_config(config): + # 关闭并删除MCP客户端 + if name in self.tool_mgr.mcp_client_dict: + self.tool_mgr.mcp_service_queue.put_nowait( + { + "type": "terminate", + "name": name, + } + ) + + return Response().ok(None, f"成功删除 MCP 服务器 {name}").__dict__ + else: + return Response().error("保存配置失败").__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"删除 MCP 服务器失败: {str(e)}").__dict__ diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 45aac3cd6..9af11dd53 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -50,6 +50,7 @@ class AstrBotDashboard: self.sfr = StaticFileRoute(self.context) self.ar = AuthRoute(self.context) self.chat_route = ChatRoute(self.context, db, core_lifecycle) + self.tools_root = ToolsRoute(self.context, core_lifecycle) self.shutdown_event = shutdown_event diff --git a/dashboard/src/components/shared/ItemCardGrid.vue b/dashboard/src/components/shared/ItemCardGrid.vue new file mode 100644 index 000000000..a1ed1609e --- /dev/null +++ b/dashboard/src/components/shared/ItemCardGrid.vue @@ -0,0 +1,134 @@ + + + + + diff --git a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts index d2fbb556e..5325e0cfd 100644 --- a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts +++ b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts @@ -45,6 +45,11 @@ const sidebarItem: menu[] = [ icon: 'mdi-storefront', to: '/extension-marketplace' }, + { + title: '函数调用', + icon: 'mdi-function-variant', + to: '/tool-use' + }, { title: '聊天', icon: 'mdi-chat', diff --git a/dashboard/src/router/MainRoutes.ts b/dashboard/src/router/MainRoutes.ts index 5fb0abeea..9ce2402ca 100644 --- a/dashboard/src/router/MainRoutes.ts +++ b/dashboard/src/router/MainRoutes.ts @@ -31,6 +31,11 @@ const MainRoutes = { path: '/providers', component: () => import('@/views/ProviderPage.vue') }, + { + name: 'ToolUsePage', + path: '/tool-use', + component: () => import('@/views/ToolUsePage.vue') + }, { name: 'Configs', path: '/config', diff --git a/dashboard/src/views/PlatformPage.vue b/dashboard/src/views/PlatformPage.vue index ac3d24ad5..38fa08e68 100644 --- a/dashboard/src/views/PlatformPage.vue +++ b/dashboard/src/views/PlatformPage.vue @@ -1,316 +1,304 @@ - - \ No newline at end of file diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue index c29247fb3..e2e1b796b 100644 --- a/dashboard/src/views/ProviderPage.vue +++ b/dashboard/src/views/ProviderPage.vue @@ -1,291 +1,328 @@ - - \ No newline at end of file diff --git a/dashboard/src/views/ToolUsePage.vue b/dashboard/src/views/ToolUsePage.vue new file mode 100644 index 000000000..f67496c43 --- /dev/null +++ b/dashboard/src/views/ToolUsePage.vue @@ -0,0 +1,631 @@ + + + + + \ No newline at end of file From 98e7ed6920caa1847d8cd6278779bee74406e0ae Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 23 Mar 2025 08:34:05 +0000 Subject: [PATCH 10/16] :balloon: auto fixes by pre-commit hooks --- .../process_stage/method/llm_request.py | 4 +++- astrbot/core/provider/func_tool_manager.py | 8 +++---- astrbot/core/provider/manager.py | 4 +++- astrbot/dashboard/routes/tools.py | 24 ++++++++++--------- 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 3767cc59d..9689ef34e 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -181,7 +181,9 @@ class LLMRequestSubStage(Stage): logger.info( f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}" ) - client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name] + client = req.func_tool.mcp_client_dict[ + func_tool.mcp_server_name + ] res = await client.session.call_tool( func_tool.name, func_tool_args ) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 93d99de65..b795027da 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -237,7 +237,9 @@ class FuncCall: data = await self.mcp_service_queue.get() if data["type"] == "init": if "name" in data: - asyncio.create_task(self._init_mcp_client(data["name"], data["cfg"])) + asyncio.create_task( + self._init_mcp_client(data["name"], data["cfg"]) + ) else: await self._init_mcp_clients() elif data["type"] == "terminate": @@ -297,9 +299,7 @@ class FuncCall: await self.mcp_client_dict[name].cleanup() del self.mcp_client_dict[name] except Exception as e: - logger.info( - f"清空 MCP 客户端资源 {name}: {e}。" - ) + logger.info(f"清空 MCP 客户端资源 {name}: {e}。") # 移除关联的FuncTool self.func_list = [ f diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 71b38682f..ef9040445 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -129,7 +129,9 @@ class ProviderManager: logger.warning("未启用任何用于 文本转语音 的提供商适配器。") # 初始化 MCP Client 连接 - asyncio.create_task(self.llm_tools.mcp_service_selector(), name="mcp-service-handler") + asyncio.create_task( + self.llm_tools.mcp_service_selector(), name="mcp-service-handler" + ) self.llm_tools.mcp_service_queue.put_nowait({"type": "init"}) async def load_provider(self, provider_config: dict): diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index 8ba166784..36da48ce4 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -65,7 +65,7 @@ class ToolsRoute(Route): "name": name, "active": server_config.get("active", True), } - + # 复制所有配置字段 for key, value in server_config.items(): if key != "active": # active 已经处理 @@ -94,7 +94,7 @@ class ToolsRoute(Route): server_data = await request.json name = server_data.get("name", "") - + # 检查必填字段 if not name: return Response().error("服务器名称不能为空").__dict__ @@ -102,13 +102,13 @@ class ToolsRoute(Route): # 移除特殊字段并检查配置是否有效 has_valid_config = False server_config = {"active": server_data.get("active", True)} - + # 复制所有配置字段 for key, value in server_data.items(): if key not in ["name", "active", "tools"]: # 排除特殊字段 server_config[key] = value has_valid_config = True - + if not has_valid_config: return Response().error("必须提供有效的服务器配置").__dict__ @@ -140,7 +140,7 @@ class ToolsRoute(Route): server_data = await request.json name = server_data.get("name", "") - + if not name: return Response().error("服务器名称不能为空").__dict__ @@ -150,26 +150,28 @@ class ToolsRoute(Route): return Response().error(f"服务器 {name} 不存在").__dict__ # 获取活动状态 - active = server_data.get("active", config["mcpServers"][name].get("active", True)) - + active = server_data.get( + "active", config["mcpServers"][name].get("active", True) + ) + # 创建新的配置对象 server_config = {"active": active} - + # 仅更新活动状态的特殊处理 only_update_active = True - + # 复制所有配置字段 for key, value in server_data.items(): if key not in ["name", "active", "tools"]: # 排除特殊字段 server_config[key] = value only_update_active = False - + # 如果只更新活动状态,保留原始配置 if only_update_active: for key, value in config["mcpServers"][name].items(): if key != "active": # 除了active之外的所有字段都保留 server_config[key] = value - + config["mcpServers"][name] = server_config if self.save_mcp_config(config): From 873b7715f40b703bd50558cca39f826aaebf2c3a Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 23 Mar 2025 16:50:23 +0800 Subject: [PATCH 11/16] =?UTF-8?q?=F0=9F=8E=88=20perf:=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=20MCP=20Client=20=E5=BC=82=E6=AD=A5=20Event=20=E7=AE=A1?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/func_tool_manager.py | 36 +++++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index b795027da..fd4a14f7f 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -127,6 +127,7 @@ class FuncCall: """MCP 服务列表""" self.mcp_service_queue = asyncio.Queue() """用于外部控制 MCP 服务的启停""" + self.mcp_client_event: Dict[str, asyncio.Event] = {} def empty(self) -> bool: return len(self.func_list) == 0 @@ -218,7 +219,11 @@ class FuncCall: for name in mcp_server_json_obj.keys(): cfg = mcp_server_json_obj[name] if cfg.get("active", True): - asyncio.create_task(self._init_mcp_client(name, cfg)) + event = asyncio.Event() + asyncio.create_task( + self._init_mcp_client_task_wrapper(name, cfg, event) + ) + self.mcp_client_event[name] = event async def mcp_service_selector(self): """为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制 @@ -237,17 +242,40 @@ class FuncCall: data = await self.mcp_service_queue.get() if data["type"] == "init": if "name" in data: + event = asyncio.Event() asyncio.create_task( - self._init_mcp_client(data["name"], data["cfg"]) + self._init_mcp_client_task_wrapper( + data["name"], data["cfg"], event + ) ) + self.mcp_client_event[data["name"]] = event else: await self._init_mcp_clients() elif data["type"] == "terminate": if "name" in data: - await self._terminate_mcp_client(data["name"]) + # await self._terminate_mcp_client(data["name"]) + if data["name"] in self.mcp_client_event: + self.mcp_client_event[data["name"]].set() + self.mcp_client_event.pop(data["name"], None) else: for name in self.mcp_client_dict.keys(): - await self._terminate_mcp_client(name) + # await self._terminate_mcp_client(name) + # self.mcp_client_event[name].set() + if name in self.mcp_client_event: + self.mcp_client_event[name].set() + self.mcp_client_event.pop(name, None) + + async def _init_mcp_client_task_wrapper( + self, name: str, cfg: dict, event: asyncio.Event + ) -> None: + """初始化 MCP 客户端的包装函数,用于捕获异常""" + try: + await self._init_mcp_client(name, cfg) + await event.wait() + logger.info(f"收到 MCP 客户端 {name} 终止信号") + await self._terminate_mcp_client(name) + except Exception as e: + logger.error(f"初始化 MCP 客户端 {name} 失败: {e}") async def _init_mcp_client(self, name: str, config: dict) -> None: """初始化单个MCP客户端""" From 4942d0a6292dc346c3cf8938c00ef223826c03eb Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 23 Mar 2025 17:00:38 +0800 Subject: [PATCH 12/16] =?UTF-8?q?=E2=9C=A8=20feat:=20=E5=9C=A8=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E4=BD=BF=E7=94=A8=E9=A1=B5=E9=9D=A2=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E8=B0=83=E7=94=A8=E4=BF=A1=E6=81=AF=E6=8F=90?= =?UTF-8?q?=E7=A4=BA=E5=92=8C=E9=93=BE=E6=8E=A5=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dashboard/src/views/ToolUsePage.vue | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/dashboard/src/views/ToolUsePage.vue b/dashboard/src/views/ToolUsePage.vue index f67496c43..39164ad36 100644 --- a/dashboard/src/views/ToolUsePage.vue +++ b/dashboard/src/views/ToolUsePage.vue @@ -7,8 +7,22 @@

mdi-function-variant函数工具管理

-

- 管理 MCP 服务器和查看可用的函数工具 +

+ 管理 MCP 服务器和查看可用的函数工具 + + + 函数调用和 MCP 是什么? +

@@ -393,6 +407,9 @@ export default { }, methods: { + openurl(url) { + window.open(url, '_blank'); + }, formatToolName(name) { if (name.includes(':')) { // MCP 工具通常命名为 mcp:server:tool From c59c8e05f76a1aa3c88a4bd434e39f680a5205ab Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 23 Mar 2025 17:03:18 +0800 Subject: [PATCH 13/16] =?UTF-8?q?=F0=9F=90=9B=20fix:=20tools=20result?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/sources/openai_source.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 897fd4e7e..4b66c50ef 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -124,8 +124,8 @@ class ProviderOpenAIOfficial(Provider): for tool in tools.func_list: if tool.name == tool_call.function.name: args = json.loads(tool_call.function.arguments) - args_ls.append(args) - func_name_ls.append(tool_call.function.name) + args_ls.append(args) + func_name_ls.append(tool_call.function.name) llm_response.role = "tool" llm_response.tools_call_args = args_ls llm_response.tools_call_name = func_name_ls From 1aac8d8041f751db4a37c638086d3faa5d735fb2 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 23 Mar 2025 22:21:47 +0800 Subject: [PATCH 14/16] =?UTF-8?q?=E2=9C=A8=20feat:=20=E9=80=82=E9=85=8D?= =?UTF-8?q?=E5=AE=8C=E6=95=B4=E7=9A=84=20function-calling=20=E6=B5=81?= =?UTF-8?q?=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../process_stage/method/llm_request.py | 228 +++++++++++------- astrbot/core/provider/entites.py | 87 ++++++- astrbot/core/provider/func_tool_manager.py | 7 + astrbot/core/provider/provider.py | 4 +- .../core/provider/sources/anthropic_source.py | 10 +- .../core/provider/sources/gemini_source.py | 44 +++- .../core/provider/sources/openai_source.py | 15 +- 7 files changed, 291 insertions(+), 104 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 9689ef34e..acde70c27 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -16,7 +16,13 @@ from astrbot.core.message.message_event_result import ( from astrbot.core.message.components import Image from astrbot.core import logger from astrbot.core.utils.metrics import Metric -from astrbot.core.provider.entites import ProviderRequest, LLMResponse +from astrbot.core.provider.entites import ( + ProviderRequest, + LLMResponse, + ToolCallMessageSegment, + AssistantMessageSegment, + ToolCallsResult, +) from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star import star_map @@ -111,10 +117,18 @@ class LLMRequestSubStage(Stage): req.contexts = json.loads(req.contexts) try: - logger.debug(f"提供商请求 Payload: {req}") - if _nested: - req.func_tool = None # 暂时不支持递归工具调用 - llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM + need_loop = True + while need_loop: + need_loop = False + logger.debug(f"提供商请求 Payload: {req}") + llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM + async for result in self._handle_llm_response(event, req, llm_response): + if isinstance(result, ProviderRequest): + # 有函数工具调用并且返回了结果,我们需要再次请求 LLM + req = result + need_loop = True + else: + yield # 执行 LLM 响应后的事件钩子。 handlers = star_handlers_registry.get_handlers_by_event_type( @@ -135,9 +149,6 @@ class LLMRequestSubStage(Stage): ) return - # 保存到历史记录 - await self._save_to_history(event, req, llm_response) - asyncio.create_task( Metric.upload( llm_tick=1, @@ -146,88 +157,8 @@ class LLMRequestSubStage(Stage): ) ) - if llm_response.role == "assistant": - # text completion - if llm_response.result_chain: - event.set_result( - MessageEventResult( - chain=llm_response.result_chain.chain - ).set_result_content_type(ResultContentType.LLM_RESULT) - ) - else: - event.set_result( - MessageEventResult() - .message(llm_response.completion_text) - .set_result_content_type(ResultContentType.LLM_RESULT) - ) - elif llm_response.role == "err": - event.set_result( - MessageEventResult().message( - f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}" - ) - ) - elif llm_response.role == "tool": - # function calling - function_calling_result = {} - logger.info( - f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}" - ) - for func_tool_name, func_tool_args in zip( - llm_response.tools_call_name, llm_response.tools_call_args - ): - try: - func_tool = req.func_tool.get_func(func_tool_name) - if func_tool.origin == "mcp": - logger.info( - f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}" - ) - client = req.func_tool.mcp_client_dict[ - func_tool.mcp_server_name - ] - res = await client.session.call_tool( - func_tool.name, func_tool_args - ) - if res: - # TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。 - res_event = event.plain_result(res.content[0].text) - event.set_result(res_event) - yield - else: - logger.info( - f"调用工具函数:{func_tool_name},参数:{func_tool_args}" - ) - # 尝试调用工具函数 - wrapper = self._call_handler( - self.ctx, event, func_tool.handler, **func_tool_args - ) - async for resp in wrapper: - if resp is not None: # 有 return 返回 - function_calling_result[func_tool_name] = resp - else: - yield # 有生成器返回 - event.clear_result() # 清除上一个 handler 的结果 - except BaseException as e: - logger.warning(traceback.format_exc()) - function_calling_result[func_tool_name] = ( - "When calling the function, an error occurred: " + str(e) - ) - if function_calling_result: - # 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。 - # 我们重新执行一遍这个 stage - req.func_tool = None # 暂时不支持递归工具调用 - extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n" - for tool_name, tool_result in function_calling_result.items(): - extra_prompt += ( - f"Tool: {tool_name}\nTool Result: {tool_result}\n" - ) - req.prompt += extra_prompt - async for _ in self.process(event, _nested=True): - yield - else: - if llm_response.completion_text: - event.set_result( - MessageEventResult().message(llm_response.completion_text) - ) + # 保存到历史记录 + await self._save_to_history(event, req, llm_response) except BaseException as e: logger.error(traceback.format_exc()) @@ -238,6 +169,116 @@ class LLMRequestSubStage(Stage): ) return + async def _handle_llm_response( + self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse + ) -> AsyncGenerator[None, None]: + """处理 LLM 响应。 + + Returns: + bool: 是否需要继续调用 LLM + + Yields: + Iterator[bool]: 将 event 交付给下一个 stage + """ + if llm_response.role == "assistant": + # text completion + if llm_response.result_chain: + event.set_result( + MessageEventResult( + chain=llm_response.result_chain.chain + ).set_result_content_type(ResultContentType.LLM_RESULT) + ) + else: + event.set_result( + MessageEventResult() + .message(llm_response.completion_text) + .set_result_content_type(ResultContentType.LLM_RESULT) + ) + elif llm_response.role == "err": + event.set_result( + MessageEventResult().message( + f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}" + ) + ) + elif llm_response.role == "tool": + # function calling + tool_call_result: list[ToolCallMessageSegment] = [] + logger.info( + f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}" + ) + for func_tool_name, func_tool_args, func_tool_id in zip( + llm_response.tools_call_name, + llm_response.tools_call_args, + llm_response.tools_call_ids, + ): + try: + func_tool = req.func_tool.get_func(func_tool_name) + if func_tool.origin == "mcp": + logger.info( + f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}" + ) + client = req.func_tool.mcp_client_dict[ + func_tool.mcp_server_name + ] + res = await client.session.call_tool( + func_tool.name, func_tool_args + ) + if res: + # TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。 + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=res.content[0].text, + ) + ) + else: + logger.info( + f"调用工具函数:{func_tool_name},参数:{func_tool_args}" + ) + # 尝试调用工具函数 + wrapper = self._call_handler( + self.ctx, event, func_tool.handler, **func_tool_args + ) + async for resp in wrapper: + if resp is not None: # 有 return 返回 + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=resp, + ) + ) + else: + yield # 有生成器返回 + event.clear_result() # 清除上一个 handler 的结果 + except BaseException as e: + logger.warning(traceback.format_exc()) + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=f"error: {str(e)}", + ) + ) + if tool_call_result: + # 函数调用结果 + req.func_tool = None # 暂时不支持递归工具调用 + assistant_msg_seg = AssistantMessageSegment( + role="assistant", tool_calls=llm_response.to_openai_tool_calls() + ) + # 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。 + req.tool_calls_result = ToolCallsResult( + tool_calls_info=assistant_msg_seg, + tool_calls_result=tool_call_result, + ) + yield req # 再次执行 LLM 请求 + else: + if llm_response.completion_text: + event.set_result( + MessageEventResult().message(llm_response.completion_text) + ) + async def _save_to_history( self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse ): @@ -248,6 +289,13 @@ class LLMRequestSubStage(Stage): # 文本回复 contexts = req.contexts contexts.append(await req.assemble_context()) + + # tool calls result + if req.tool_calls_result: + contexts.extend( + req.tool_calls_result.to_openai_messages() + ) + contexts.append( {"role": "assistant", "content": llm_response.completion_text} ) diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entites.py index 4236214d4..a8ffcdf64 100644 --- a/astrbot/core/provider/entites.py +++ b/astrbot/core/provider/entites.py @@ -1,11 +1,15 @@ import enum import base64 +import json from astrbot.core.utils.io import download_image_by_url from astrbot import logger from dataclasses import dataclass, field from typing import List, Dict, Type from .func_tool_manager import FuncCall from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, +) from astrbot.core.db.po import Conversation from astrbot.core.message.message_event_result import MessageChain import astrbot.core.message.components as Comp @@ -32,6 +36,58 @@ class ProviderMetaData: """显示在 WebUI 配置页中的提供商名称,如空则是 type""" +@dataclass +class ToolCallMessageSegment: + """OpenAI 格式的上下文中 role 为 tool 的消息段。参考: https://platform.openai.com/docs/guides/function-calling""" + + tool_call_id: str + content: str + role: str = "tool" + + def to_dict(self): + return { + "tool_call_id": self.tool_call_id, + "content": self.content, + "role": self.role, + } + + +@dataclass +class AssistantMessageSegment: + """OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling""" + + content: str = None + tool_calls: List[ChatCompletionMessageToolCall | Dict] = None + role: str = "assistant" + + def to_dict(self): + ret = { + "role": self.role, + } + if self.content: + ret["content"] = self.content + elif self.tool_calls: + ret["tool_calls"] = self.tool_calls + return ret + + +@dataclass +class ToolCallsResult: + """工具调用结果""" + + tool_calls_info: AssistantMessageSegment + """函数调用的信息""" + tool_calls_result: List[ToolCallMessageSegment] + """函数调用的结果""" + + def to_openai_messages(self) -> List[Dict]: + ret = [ + self.tool_calls_info.to_dict(), + *[item.to_dict() for item in self.tool_calls_result], + ] + return ret + + @dataclass class ProviderRequest: prompt: str @@ -41,7 +97,7 @@ class ProviderRequest: image_urls: List[str] = None """图片 URL 列表""" func_tool: FuncCall = None - """工具""" + """可用的函数工具""" contexts: List = None """上下文。格式与 openai 的上下文格式一致: 参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages @@ -50,8 +106,11 @@ class ProviderRequest: """系统提示词""" conversation: Conversation = None + tool_calls_result: ToolCallsResult = None + """附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls""" + def __repr__(self): - return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()})" + return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})" def __str__(self): return self.__repr__() @@ -137,6 +196,8 @@ class LLMResponse: """工具调用参数""" tools_call_name: List[str] = field(default_factory=list) """工具调用名称""" + tools_call_ids: List[str] = field(default_factory=list) + """工具调用 ID""" raw_completion: ChatCompletion = None _new_record: Dict[str, any] = None @@ -148,8 +209,9 @@ class LLMResponse: role: str, completion_text: str = "", result_chain: MessageChain = None, - tools_call_args: List[Dict[str, any]] = None, - tools_call_name: List[str] = None, + tools_call_args: List[Dict[str, any]] = [], + tools_call_name: List[str] = [], + tools_call_ids: List[str] = [], raw_completion: ChatCompletion = None, _new_record: Dict[str, any] = None, ): @@ -168,6 +230,7 @@ class LLMResponse: self.result_chain = result_chain self.tools_call_args = tools_call_args self.tools_call_name = tools_call_name + self.tools_call_ids = tools_call_ids self.raw_completion = raw_completion self._new_record = _new_record @@ -188,3 +251,19 @@ class LLMResponse: self.result_chain.chain.insert(0, Comp.Plain(value)) else: self._completion_text = value + + def to_openai_tool_calls(self) -> List[Dict]: + """将工具调用信息转换为 OpenAI 格式""" + ret = [] + for idx, tool_call_arg in enumerate(self.tools_call_args): + ret.append( + { + "id": self.tools_call_ids[idx], + "function": { + "name": self.tools_call_name[idx], + "arguments": json.dumps(tool_call_arg), + }, + "type": "function", + } + ) + return ret diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index fd4a14f7f..52f31c363 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -4,6 +4,7 @@ import textwrap import os import asyncio import mcp +import copy from typing import Dict, List, Awaitable, Literal, Any from dataclasses import dataclass @@ -391,7 +392,13 @@ class FuncCall: # 检查并添加非空的properties参数 params = f.parameters if isinstance(f.parameters, dict) else {} + params = copy.deepcopy(params) if params.get("properties", {}): + properties = params["properties"] + for key, value in properties.items(): + if "default" in value: + del value["default"] + params["properties"] = properties func_declaration["parameters"] = params tools.append(func_declaration) diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 57dc57f90..8dcff9a52 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -3,7 +3,7 @@ from typing import List from astrbot.core.db import BaseDatabase from typing import TypedDict from astrbot.core.provider.func_tool_manager import FuncCall -from astrbot.core.provider.entites import LLMResponse +from astrbot.core.provider.entites import LLMResponse, ToolCallsResult from dataclasses import dataclass @@ -90,6 +90,7 @@ class Provider(AbstractProvider): func_tool: FuncCall = None, contexts: List = None, system_prompt: str = None, + tool_calls_result: ToolCallsResult = None, **kwargs, ) -> LLMResponse: """获得 LLM 的文本对话结果。会使用当前的模型进行对话。 @@ -100,6 +101,7 @@ class Provider(AbstractProvider): image_urls: 图片 URL 列表 tools: Function-calling 工具 contexts: 上下文 + tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling kwargs: 其他参数 Notes: diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 13f7482d4..90efdee91 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -10,7 +10,7 @@ from astrbot.api.provider import Provider, Personality from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall from ..register import register_provider_adapter -from astrbot.core.provider.entites import LLMResponse +from astrbot.core.provider.entites import LLMResponse, ToolCallsResult from .openai_source import ProviderOpenAIOfficial @@ -79,11 +79,14 @@ class ProviderAnthropic(ProviderOpenAIOfficial): # tools call (function calling) args_ls = [] func_name_ls = [] + tool_use_ids = [] func_name_ls.append(content.name) args_ls.append(content.input) + tool_use_ids.append(content.id) llm_response.role = "tool" llm_response.tools_call_args = args_ls llm_response.tools_call_name = func_name_ls + llm_response.tools_call_ids = tool_use_ids if not llm_response.completion_text and not llm_response.tools_call_args: logger.error(f"API 返回的 completion 无法解析:{completion}。") @@ -101,6 +104,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial): func_tool: FuncCall = None, contexts=[], system_prompt=None, + tool_calls_result: ToolCallsResult=None, **kwargs, ) -> LLMResponse: if not prompt: @@ -113,6 +117,10 @@ class ProviderAnthropic(ProviderOpenAIOfficial): if "_no_save" in part: del part["_no_save"] + if tool_calls_result: + # 暂时这样写。 + prompt += f"Here are the related results via using tools: {str(tool_calls_result.tool_calls_result)}" + model_config = self.provider_config.get("model_config", {}) payloads = {"messages": context_query, **model_config} diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 233cc8932..1caccf2b1 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -1,5 +1,6 @@ import base64 import aiohttp +import json import random from astrbot.core.utils.io import download_image_by_url from astrbot.core.db import BaseDatabase @@ -115,6 +116,7 @@ class ProviderGoogleGenAI(Provider): break google_genai_conversation = [] + print(payloads) for message in payloads["messages"]: if message["role"] == "user": if isinstance(message["content"], str): @@ -146,11 +148,39 @@ class ProviderGoogleGenAI(Provider): google_genai_conversation.append({"role": "user", "parts": parts}) elif message["role"] == "assistant": - if not message["content"]: - message["content"] = "" - google_genai_conversation.append( - {"role": "model", "parts": [{"text": message["content"]}]} + if "content" in message: + if not message["content"]: + message["content"] = "" + google_genai_conversation.append( + {"role": "model", "parts": [{"text": message["content"]}]} + ) + elif "tool_calls" in message: + # tool calls in the last turn + parts = [] + for tool_call in message["tool_calls"]: + parts.append( + { + "functionCall": { + "name": tool_call["function"]["name"], + "args": json.loads(tool_call["function"]["arguments"]), + } + } + ) + google_genai_conversation.append({"role": "model", "parts": parts}) + elif message["role"] == "tool": + parts = [] + parts.append( + { + "functionResponse": { + "name": message["tool_call_id"], + "response": { + "name": message["tool_call_id"], + "content": message["content"], + }, + } + } ) + google_genai_conversation.append({"role": "user", "parts": parts}) logger.debug(f"google_genai_conversation: {google_genai_conversation}") @@ -174,6 +204,7 @@ class ProviderGoogleGenAI(Provider): llm_response.role = "tool" llm_response.tools_call_args.append(candidate["functionCall"]["args"]) llm_response.tools_call_name.append(candidate["functionCall"]["name"]) + llm_response.tools_call_ids.append(candidate["functionCall"]["name"]) # 没有 tool id llm_response.completion_text = llm_response.completion_text.strip() return llm_response @@ -186,6 +217,7 @@ class ProviderGoogleGenAI(Provider): func_tool: FuncCall = None, contexts=[], system_prompt=None, + tool_calls_result=None, **kwargs, ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls) @@ -198,6 +230,10 @@ class ProviderGoogleGenAI(Provider): if "_no_save" in part: del part["_no_save"] + # tool calls result + if tool_calls_result: + context_query.extend(tool_calls_result.to_openai_messages()) + model_config = self.provider_config.get("model_config", {}) model_config["model"] = self.get_model() diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 4b66c50ef..628b1849d 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -120,15 +120,18 @@ class ProviderOpenAIOfficial(Provider): # tools call (function calling) args_ls = [] func_name_ls = [] + tool_call_ids = [] for tool_call in choice.message.tool_calls: for tool in tools.func_list: if tool.name == tool_call.function.name: args = json.loads(tool_call.function.arguments) args_ls.append(args) func_name_ls.append(tool_call.function.name) + tool_call_ids.append(tool_call.id) llm_response.role = "tool" llm_response.tools_call_args = args_ls llm_response.tools_call_name = func_name_ls + llm_response.tools_call_ids = tool_call_ids if choice.finish_reason == "content_filter": raise Exception( @@ -151,6 +154,7 @@ class ProviderOpenAIOfficial(Provider): func_tool: FuncCall = None, contexts=[], system_prompt=None, + tool_calls_result=None, **kwargs, ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls) @@ -162,10 +166,15 @@ class ProviderOpenAIOfficial(Provider): if "_no_save" in part: del part["_no_save"] + # tool calls result + if tool_calls_result: + context_query.extend(tool_calls_result.to_openai_messages()) + model_config = self.provider_config.get("model_config", {}) model_config["model"] = self.get_model() payloads = {"messages": context_query, **model_config} + llm_response = None try: llm_response = await self._query(payloads, func_tool) @@ -275,10 +284,8 @@ class ProviderOpenAIOfficial(Provider): def set_key(self, key): self.client.api_key = key - async def assemble_context(self, text: str, image_urls: List[str] = None): - """ - 组装上下文。 - """ + async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict: + """组装成符合 OpenAI 格式的 role 为 user 的消息段""" if image_urls: user_content = {"role": "user", "content": [{"type": "text", "text": text}]} for image_url in image_urls: From db46000337c2fbfa835adb77e6908fcaa7a89fd2 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 23 Mar 2025 22:22:11 +0800 Subject: [PATCH 15/16] =?UTF-8?q?=F0=9F=8E=A8=20style:=20format=20codes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/pipeline/process_stage/method/llm_request.py | 4 +--- astrbot/core/provider/sources/anthropic_source.py | 2 +- astrbot/core/provider/sources/gemini_source.py | 8 ++++++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index acde70c27..d13a101ef 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -292,9 +292,7 @@ class LLMRequestSubStage(Stage): # tool calls result if req.tool_calls_result: - contexts.extend( - req.tool_calls_result.to_openai_messages() - ) + contexts.extend(req.tool_calls_result.to_openai_messages()) contexts.append( {"role": "assistant", "content": llm_response.completion_text} diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 90efdee91..fd19c40ca 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -104,7 +104,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial): func_tool: FuncCall = None, contexts=[], system_prompt=None, - tool_calls_result: ToolCallsResult=None, + tool_calls_result: ToolCallsResult = None, **kwargs, ) -> LLMResponse: if not prompt: diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 1caccf2b1..ec0337389 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -162,7 +162,9 @@ class ProviderGoogleGenAI(Provider): { "functionCall": { "name": tool_call["function"]["name"], - "args": json.loads(tool_call["function"]["arguments"]), + "args": json.loads( + tool_call["function"]["arguments"] + ), } } ) @@ -204,7 +206,9 @@ class ProviderGoogleGenAI(Provider): llm_response.role = "tool" llm_response.tools_call_args.append(candidate["functionCall"]["args"]) llm_response.tools_call_name.append(candidate["functionCall"]["name"]) - llm_response.tools_call_ids.append(candidate["functionCall"]["name"]) # 没有 tool id + llm_response.tools_call_ids.append( + candidate["functionCall"]["name"] + ) # 没有 tool id llm_response.completion_text = llm_response.completion_text.strip() return llm_response From 7f998c76118209a91c5644b3eed04f80f8f26e3a Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 23 Mar 2025 22:28:00 +0800 Subject: [PATCH 16/16] chore: remove useless print output --- astrbot/core/provider/sources/gemini_source.py | 1 - 1 file changed, 1 deletion(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index ec0337389..90a584235 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -116,7 +116,6 @@ class ProviderGoogleGenAI(Provider): break google_genai_conversation = [] - print(payloads) for message in payloads["messages"]: if message["role"] == "user": if isinstance(message["content"], str):