diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index a1d7ac882..d13a101ef 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -16,7 +16,13 @@ from astrbot.core.message.message_event_result import ( from astrbot.core.message.components import Image from astrbot.core import logger from astrbot.core.utils.metrics import Metric -from astrbot.core.provider.entites import ProviderRequest, LLMResponse +from astrbot.core.provider.entites import ( + ProviderRequest, + LLMResponse, + ToolCallMessageSegment, + AssistantMessageSegment, + ToolCallsResult, +) from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star import star_map @@ -111,10 +117,18 @@ class LLMRequestSubStage(Stage): req.contexts = json.loads(req.contexts) try: - logger.debug(f"提供商请求 Payload: {req}") - if _nested: - req.func_tool = None # 暂时不支持递归工具调用 - llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM + need_loop = True + while need_loop: + need_loop = False + logger.debug(f"提供商请求 Payload: {req}") + llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM + async for result in self._handle_llm_response(event, req, llm_response): + if isinstance(result, ProviderRequest): + # 有函数工具调用并且返回了结果,我们需要再次请求 LLM + req = result + need_loop = True + else: + yield # 执行 LLM 响应后的事件钩子。 handlers = star_handlers_registry.get_handlers_by_event_type( @@ -135,9 +149,6 @@ class LLMRequestSubStage(Stage): ) return - # 保存到历史记录 - await self._save_to_history(event, req, llm_response) - asyncio.create_task( Metric.upload( llm_tick=1, @@ -146,72 +157,8 @@ class LLMRequestSubStage(Stage): ) ) - if llm_response.role == "assistant": - # text completion - if llm_response.result_chain: - event.set_result( - MessageEventResult( - chain=llm_response.result_chain.chain - ).set_result_content_type(ResultContentType.LLM_RESULT) - ) - else: - event.set_result( - MessageEventResult() - .message(llm_response.completion_text) - .set_result_content_type(ResultContentType.LLM_RESULT) - ) - elif llm_response.role == "err": - event.set_result( - MessageEventResult().message( - f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}" - ) - ) - elif llm_response.role == "tool": - # function calling - function_calling_result = {} - logger.info( - f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}" - ) - for func_tool_name, func_tool_args in zip( - llm_response.tools_call_name, llm_response.tools_call_args - ): - func_tool = req.func_tool.get_func(func_tool_name) - logger.info( - f"调用工具函数:{func_tool_name},参数:{func_tool_args}" - ) - try: - # 尝试调用工具函数 - wrapper = self._call_handler( - self.ctx, event, func_tool.handler, **func_tool_args - ) - async for resp in wrapper: - if resp is not None: # 有 return 返回 - function_calling_result[func_tool_name] = resp - else: - yield # 有生成器返回 - event.clear_result() # 清除上一个 handler 的结果 - except BaseException as e: - logger.warning(traceback.format_exc()) - function_calling_result[func_tool_name] = ( - "When calling the function, an error occurred: " + str(e) - ) - if function_calling_result: - # 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。 - # 我们重新执行一遍这个 stage - req.func_tool = None # 暂时不支持递归工具调用 - extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n" - for tool_name, tool_result in function_calling_result.items(): - extra_prompt += ( - f"Tool: {tool_name}\nTool Result: {tool_result}\n" - ) - req.prompt += extra_prompt - async for _ in self.process(event, _nested=True): - yield - else: - if llm_response.completion_text: - event.set_result( - MessageEventResult().message(llm_response.completion_text) - ) + # 保存到历史记录 + await self._save_to_history(event, req, llm_response) except BaseException as e: logger.error(traceback.format_exc()) @@ -222,6 +169,116 @@ class LLMRequestSubStage(Stage): ) return + async def _handle_llm_response( + self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse + ) -> AsyncGenerator[None, None]: + """处理 LLM 响应。 + + Returns: + bool: 是否需要继续调用 LLM + + Yields: + Iterator[bool]: 将 event 交付给下一个 stage + """ + if llm_response.role == "assistant": + # text completion + if llm_response.result_chain: + event.set_result( + MessageEventResult( + chain=llm_response.result_chain.chain + ).set_result_content_type(ResultContentType.LLM_RESULT) + ) + else: + event.set_result( + MessageEventResult() + .message(llm_response.completion_text) + .set_result_content_type(ResultContentType.LLM_RESULT) + ) + elif llm_response.role == "err": + event.set_result( + MessageEventResult().message( + f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}" + ) + ) + elif llm_response.role == "tool": + # function calling + tool_call_result: list[ToolCallMessageSegment] = [] + logger.info( + f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}" + ) + for func_tool_name, func_tool_args, func_tool_id in zip( + llm_response.tools_call_name, + llm_response.tools_call_args, + llm_response.tools_call_ids, + ): + try: + func_tool = req.func_tool.get_func(func_tool_name) + if func_tool.origin == "mcp": + logger.info( + f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}" + ) + client = req.func_tool.mcp_client_dict[ + func_tool.mcp_server_name + ] + res = await client.session.call_tool( + func_tool.name, func_tool_args + ) + if res: + # TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。 + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=res.content[0].text, + ) + ) + else: + logger.info( + f"调用工具函数:{func_tool_name},参数:{func_tool_args}" + ) + # 尝试调用工具函数 + wrapper = self._call_handler( + self.ctx, event, func_tool.handler, **func_tool_args + ) + async for resp in wrapper: + if resp is not None: # 有 return 返回 + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=resp, + ) + ) + else: + yield # 有生成器返回 + event.clear_result() # 清除上一个 handler 的结果 + except BaseException as e: + logger.warning(traceback.format_exc()) + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=f"error: {str(e)}", + ) + ) + if tool_call_result: + # 函数调用结果 + req.func_tool = None # 暂时不支持递归工具调用 + assistant_msg_seg = AssistantMessageSegment( + role="assistant", tool_calls=llm_response.to_openai_tool_calls() + ) + # 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。 + req.tool_calls_result = ToolCallsResult( + tool_calls_info=assistant_msg_seg, + tool_calls_result=tool_call_result, + ) + yield req # 再次执行 LLM 请求 + else: + if llm_response.completion_text: + event.set_result( + MessageEventResult().message(llm_response.completion_text) + ) + async def _save_to_history( self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse ): @@ -232,6 +289,11 @@ class LLMRequestSubStage(Stage): # 文本回复 contexts = req.contexts contexts.append(await req.assemble_context()) + + # tool calls result + if req.tool_calls_result: + contexts.extend(req.tool_calls_result.to_openai_messages()) + contexts.append( {"role": "assistant", "content": llm_response.completion_text} ) diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entites.py index 4236214d4..a8ffcdf64 100644 --- a/astrbot/core/provider/entites.py +++ b/astrbot/core/provider/entites.py @@ -1,11 +1,15 @@ import enum import base64 +import json from astrbot.core.utils.io import download_image_by_url from astrbot import logger from dataclasses import dataclass, field from typing import List, Dict, Type from .func_tool_manager import FuncCall from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, +) from astrbot.core.db.po import Conversation from astrbot.core.message.message_event_result import MessageChain import astrbot.core.message.components as Comp @@ -32,6 +36,58 @@ class ProviderMetaData: """显示在 WebUI 配置页中的提供商名称,如空则是 type""" +@dataclass +class ToolCallMessageSegment: + """OpenAI 格式的上下文中 role 为 tool 的消息段。参考: https://platform.openai.com/docs/guides/function-calling""" + + tool_call_id: str + content: str + role: str = "tool" + + def to_dict(self): + return { + "tool_call_id": self.tool_call_id, + "content": self.content, + "role": self.role, + } + + +@dataclass +class AssistantMessageSegment: + """OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling""" + + content: str = None + tool_calls: List[ChatCompletionMessageToolCall | Dict] = None + role: str = "assistant" + + def to_dict(self): + ret = { + "role": self.role, + } + if self.content: + ret["content"] = self.content + elif self.tool_calls: + ret["tool_calls"] = self.tool_calls + return ret + + +@dataclass +class ToolCallsResult: + """工具调用结果""" + + tool_calls_info: AssistantMessageSegment + """函数调用的信息""" + tool_calls_result: List[ToolCallMessageSegment] + """函数调用的结果""" + + def to_openai_messages(self) -> List[Dict]: + ret = [ + self.tool_calls_info.to_dict(), + *[item.to_dict() for item in self.tool_calls_result], + ] + return ret + + @dataclass class ProviderRequest: prompt: str @@ -41,7 +97,7 @@ class ProviderRequest: image_urls: List[str] = None """图片 URL 列表""" func_tool: FuncCall = None - """工具""" + """可用的函数工具""" contexts: List = None """上下文。格式与 openai 的上下文格式一致: 参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages @@ -50,8 +106,11 @@ class ProviderRequest: """系统提示词""" conversation: Conversation = None + tool_calls_result: ToolCallsResult = None + """附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls""" + def __repr__(self): - return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()})" + return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})" def __str__(self): return self.__repr__() @@ -137,6 +196,8 @@ class LLMResponse: """工具调用参数""" tools_call_name: List[str] = field(default_factory=list) """工具调用名称""" + tools_call_ids: List[str] = field(default_factory=list) + """工具调用 ID""" raw_completion: ChatCompletion = None _new_record: Dict[str, any] = None @@ -148,8 +209,9 @@ class LLMResponse: role: str, completion_text: str = "", result_chain: MessageChain = None, - tools_call_args: List[Dict[str, any]] = None, - tools_call_name: List[str] = None, + tools_call_args: List[Dict[str, any]] = [], + tools_call_name: List[str] = [], + tools_call_ids: List[str] = [], raw_completion: ChatCompletion = None, _new_record: Dict[str, any] = None, ): @@ -168,6 +230,7 @@ class LLMResponse: self.result_chain = result_chain self.tools_call_args = tools_call_args self.tools_call_name = tools_call_name + self.tools_call_ids = tools_call_ids self.raw_completion = raw_completion self._new_record = _new_record @@ -188,3 +251,19 @@ class LLMResponse: self.result_chain.chain.insert(0, Comp.Plain(value)) else: self._completion_text = value + + def to_openai_tool_calls(self) -> List[Dict]: + """将工具调用信息转换为 OpenAI 格式""" + ret = [] + for idx, tool_call_arg in enumerate(self.tools_call_args): + ret.append( + { + "id": self.tools_call_ids[idx], + "function": { + "name": self.tools_call_name[idx], + "arguments": json.dumps(tool_call_arg), + }, + "type": "function", + } + ) + return ret diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 0f04628f7..52f31c363 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -1,9 +1,29 @@ +from __future__ import annotations import json import textwrap -from typing import Dict, List, Awaitable +import os +import asyncio +import mcp +import copy + +from typing import Dict, List, Awaitable, Literal, Any from dataclasses import dataclass +from typing import Optional +from contextlib import AsyncExitStack + +from mcp.client.stdio import stdio_client from astrbot import logger +DEFAULT_MCP_CONFIG = {"mcpServers": {}} + +SUPPORTED_TYPES = [ + "string", + "number", + "object", + "array", + "boolean", +] # json schema 支持的数据类型 + @dataclass class FuncTool: @@ -14,28 +34,101 @@ class FuncTool: name: str parameters: Dict description: str - handler: Awaitable - handler_module_path: str = None # 必须要保留这个,handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools + handler: Awaitable = None + """处理函数, 当 origin 为 mcp 时,这个为空""" + handler_module_path: str = None + """处理函数的模块路径,当 origin 为 mcp 时,这个为空 + 必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools + """ active: bool = True """是否激活""" + origin: Literal["local", "mcp"] = "local" + """函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务""" + + # MCP 相关字段 + mcp_server_name: str = None + """MCP 服务名称,当 origin 为 mcp 时有效""" + mcp_client: MCPClient = None + """MCP 客户端,当 origin 为 mcp 时有效""" + def __repr__(self): - return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}), active={self.active})" + return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})" + + async def execute(self, **args) -> Any: + """执行函数调用""" + if self.origin == "local": + if not self.handler: + raise Exception(f"Local function {self.name} has no handler") + return await self.handler(**args) + elif self.origin == "mcp": + if not self.mcp_client or not self.mcp_client.session: + raise Exception(f"MCP client for {self.name} is not available") + # 使用name属性而不是额外的mcp_tool_name + if ":" in self.name: + # 如果名字是格式为 mcp:server:tool_name,提取实际的工具名 + actual_tool_name = self.name.split(":")[-1] + return await self.mcp_client.session.call_tool(actual_tool_name, args) + else: + return await self.mcp_client.session.call_tool(self.name, args) + else: + raise Exception(f"Unknown function origin: {self.origin}") -SUPPORTED_TYPES = [ - "string", - "number", - "object", - "array", - "boolean", -] # json schema 支持的数据类型 +class MCPClient: + def __init__(self): + # Initialize session and client objects + self.session: Optional[mcp.ClientSession] = None + self.exit_stack = AsyncExitStack() + + self.name = None + self.active: bool = True + self.tools: List[mcp.Tool] = [] + + async def connect_to_server(self, mcp_server_config: dict): + """Connect to an MCP server + + Args: + mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server + """ + cfg = mcp_server_config.copy() + cfg.pop("active", None) + server_params = mcp.StdioServerParameters( + **cfg, + ) + + stdio_transport = await self.exit_stack.enter_async_context( + stdio_client(server_params) + ) + self.stdio, self.write = stdio_transport + self.session = await self.exit_stack.enter_async_context( + mcp.ClientSession(self.stdio, self.write) + ) + + await self.session.initialize() + + async def list_tools_and_save(self) -> mcp.ListToolsResult: + """List all tools from the server and save them to self.tools""" + response = await self.session.list_tools() + logger.debug(f"MCP server {self.name} list tools response: {response}") + self.tools = response.tools + return response + + async def cleanup(self): + """Clean up resources""" + await self.exit_stack.aclose() class FuncCall: def __init__(self) -> None: self.func_list: List[FuncTool] = [] + """内部加载的 func tools""" + self.mcp_client_dict: Dict[str, MCPClient] = {} + """MCP 服务列表""" + self.mcp_service_queue = asyncio.Queue() + """用于外部控制 MCP 服务的启停""" + self.mcp_client_event: Dict[str, asyncio.Event] = {} def empty(self) -> bool: return len(self.func_list) == 0 @@ -90,11 +183,166 @@ class FuncCall: return f return None + async def _init_mcp_clients(self) -> None: + """从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下: + ``` + { + "mcpServers": { + "weather": { + "command": "uv", + "args": [ + "--directory", + "/ABSOLUTE/PATH/TO/PARENT/FOLDER/weather", + "run", + "weather.py" + ] + } + } + ... + } + ``` + """ + current_dir = os.path.dirname(os.path.abspath(__file__)) + data_dir = os.path.abspath(os.path.join(current_dir, "../../../data")) + + mcp_json_file = os.path.join(data_dir, "mcp_server.json") + if not os.path.exists(mcp_json_file): + # 配置文件不存在错误处理 + with open(mcp_json_file, "w", encoding="utf-8") as f: + json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4) + logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}") + return + + mcp_server_json_obj: Dict[str, Dict] = json.load( + open(mcp_json_file, "r", encoding="utf-8") + )["mcpServers"] + + for name in mcp_server_json_obj.keys(): + cfg = mcp_server_json_obj[name] + if cfg.get("active", True): + event = asyncio.Event() + asyncio.create_task( + self._init_mcp_client_task_wrapper(name, cfg, event) + ) + self.mcp_client_event[name] = event + + async def mcp_service_selector(self): + """为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制 + + 使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下: + + {"type": "init"} 初始化所有MCP客户端 + + {"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端 + + {"type": "terminate"} 终止所有MCP客户端 + + {"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端 + """ + while True: + data = await self.mcp_service_queue.get() + if data["type"] == "init": + if "name" in data: + event = asyncio.Event() + asyncio.create_task( + self._init_mcp_client_task_wrapper( + data["name"], data["cfg"], event + ) + ) + self.mcp_client_event[data["name"]] = event + else: + await self._init_mcp_clients() + elif data["type"] == "terminate": + if "name" in data: + # await self._terminate_mcp_client(data["name"]) + if data["name"] in self.mcp_client_event: + self.mcp_client_event[data["name"]].set() + self.mcp_client_event.pop(data["name"], None) + else: + for name in self.mcp_client_dict.keys(): + # await self._terminate_mcp_client(name) + # self.mcp_client_event[name].set() + if name in self.mcp_client_event: + self.mcp_client_event[name].set() + self.mcp_client_event.pop(name, None) + + async def _init_mcp_client_task_wrapper( + self, name: str, cfg: dict, event: asyncio.Event + ) -> None: + """初始化 MCP 客户端的包装函数,用于捕获异常""" + try: + await self._init_mcp_client(name, cfg) + await event.wait() + logger.info(f"收到 MCP 客户端 {name} 终止信号") + await self._terminate_mcp_client(name) + except Exception as e: + logger.error(f"初始化 MCP 客户端 {name} 失败: {e}") + + async def _init_mcp_client(self, name: str, config: dict) -> None: + """初始化单个MCP客户端""" + try: + # 先清理之前的客户端,如果存在 + if name in self.mcp_client_dict: + await self._terminate_mcp_client(name) + + mcp_client = MCPClient() + mcp_client.name = name + await mcp_client.connect_to_server(config) + tools_res = await mcp_client.list_tools_and_save() + tool_names = [tool.name for tool in tools_res.tools] + self.mcp_client_dict[name] = mcp_client + + # 移除该MCP服务之前的工具(如有) + self.func_list = [ + f + for f in self.func_list + if not (f.origin == "mcp" and f.mcp_server_name == name) + ] + + # 将 MCP 工具转换为 FuncTool 并添加到 func_list + for tool in mcp_client.tools: + func_tool = FuncTool( + name=tool.name, + parameters=tool.inputSchema, + description=tool.description, + origin="mcp", + mcp_server_name=name, + mcp_client=mcp_client, + ) + self.func_list.append(func_tool) + + logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}") + return True + except Exception as e: + logger.error(f"初始化 MCP 客户端 {name} 失败: {e}") + # 发生错误时确保客户端被清理 + if name in self.mcp_client_dict: + await self._terminate_mcp_client(name) + return False + + async def _terminate_mcp_client(self, name: str) -> None: + """关闭并清理MCP客户端""" + if name in self.mcp_client_dict: + try: + # 关闭MCP连接 + await self.mcp_client_dict[name].cleanup() + del self.mcp_client_dict[name] + except Exception as e: + logger.info(f"清空 MCP 客户端资源 {name}: {e}。") + # 移除关联的FuncTool + self.func_list = [ + f + for f in self.func_list + if not (f.origin == "mcp" and f.mcp_server_name == name) + ] + logger.info(f"已关闭 MCP 服务 {name}") + def get_func_desc_openai_style(self) -> list: """ 获得 OpenAI API 风格的**已经激活**的工具描述 """ _l = [] + # 处理所有工具(包括本地和MCP工具) for f in self.func_list: if not f.active: continue @@ -144,7 +392,13 @@ class FuncCall: # 检查并添加非空的properties参数 params = f.parameters if isinstance(f.parameters, dict) else {} + params = copy.deepcopy(params) if params.get("properties", {}): + properties = params["properties"] + for key, value in properties.items(): + if "default" in value: + del value["default"] + params["properties"] = properties func_declaration["parameters"] = params tools.append(func_declaration) @@ -160,9 +414,9 @@ class FuncCall: continue _l.append( { - "name": f["name"], - "parameters": f["parameters"], - "description": f["description"], + "name": f.name, + "parameters": f.parameters, + "description": f.description, } ) func_definition = json.dumps(_l, ensure_ascii=False) @@ -212,14 +466,11 @@ class FuncCall: func_name = tool["name"] args = tool["args"] # 调用函数 - tool_callable = None - for func in self.func_list: - if func.name == func_name: - tool_callable = func.star_handler_metadata.handler - break - if not tool_callable: + func_tool = self.get_func(func_name) + if not func_tool: raise Exception(f"Request function {func_name} not found.") - ret = await tool_callable(**args) + + ret = await func_tool.execute(**args) if ret: tool_call_result.append(str(ret)) return tool_call_result, True @@ -229,3 +480,8 @@ class FuncCall: def __repr__(self): return str(self.func_list) + + async def terminate(self): + for name in self.mcp_client_dict.keys(): + await self._terminate_mcp_client(name) + logger.debug(f"清理 MCP 客户端 {name} 资源") diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 28146333f..ef9040445 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -1,4 +1,5 @@ import traceback +import asyncio from astrbot.core.config.astrbot_config import AstrBotConfig from .provider import Provider, STTProvider, TTSProvider, Personality from .entites import ProviderType @@ -127,6 +128,12 @@ class ProviderManager: if self.tts_enabled and not self.curr_tts_provider_inst: logger.warning("未启用任何用于 文本转语音 的提供商适配器。") + # 初始化 MCP Client 连接 + asyncio.create_task( + self.llm_tools.mcp_service_selector(), name="mcp-service-handler" + ) + self.llm_tools.mcp_service_queue.put_nowait({"type": "init"}) + async def load_provider(self, provider_config: dict): if not provider_config["enable"]: return @@ -339,3 +346,5 @@ class ProviderManager: for provider_inst in self.provider_insts: if hasattr(provider_inst, "terminate"): await provider_inst.terminate() + # 清理 MCP Client 连接 + await self.llm_tools.mcp_service_queue.put({"type": "terminate"}) diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 57dc57f90..8dcff9a52 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -3,7 +3,7 @@ from typing import List from astrbot.core.db import BaseDatabase from typing import TypedDict from astrbot.core.provider.func_tool_manager import FuncCall -from astrbot.core.provider.entites import LLMResponse +from astrbot.core.provider.entites import LLMResponse, ToolCallsResult from dataclasses import dataclass @@ -90,6 +90,7 @@ class Provider(AbstractProvider): func_tool: FuncCall = None, contexts: List = None, system_prompt: str = None, + tool_calls_result: ToolCallsResult = None, **kwargs, ) -> LLMResponse: """获得 LLM 的文本对话结果。会使用当前的模型进行对话。 @@ -100,6 +101,7 @@ class Provider(AbstractProvider): image_urls: 图片 URL 列表 tools: Function-calling 工具 contexts: 上下文 + tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling kwargs: 其他参数 Notes: diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 13f7482d4..fd19c40ca 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -10,7 +10,7 @@ from astrbot.api.provider import Provider, Personality from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall from ..register import register_provider_adapter -from astrbot.core.provider.entites import LLMResponse +from astrbot.core.provider.entites import LLMResponse, ToolCallsResult from .openai_source import ProviderOpenAIOfficial @@ -79,11 +79,14 @@ class ProviderAnthropic(ProviderOpenAIOfficial): # tools call (function calling) args_ls = [] func_name_ls = [] + tool_use_ids = [] func_name_ls.append(content.name) args_ls.append(content.input) + tool_use_ids.append(content.id) llm_response.role = "tool" llm_response.tools_call_args = args_ls llm_response.tools_call_name = func_name_ls + llm_response.tools_call_ids = tool_use_ids if not llm_response.completion_text and not llm_response.tools_call_args: logger.error(f"API 返回的 completion 无法解析:{completion}。") @@ -101,6 +104,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial): func_tool: FuncCall = None, contexts=[], system_prompt=None, + tool_calls_result: ToolCallsResult = None, **kwargs, ) -> LLMResponse: if not prompt: @@ -113,6 +117,10 @@ class ProviderAnthropic(ProviderOpenAIOfficial): if "_no_save" in part: del part["_no_save"] + if tool_calls_result: + # 暂时这样写。 + prompt += f"Here are the related results via using tools: {str(tool_calls_result.tool_calls_result)}" + model_config = self.provider_config.get("model_config", {}) payloads = {"messages": context_query, **model_config} diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 233cc8932..90a584235 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -1,5 +1,6 @@ import base64 import aiohttp +import json import random from astrbot.core.utils.io import download_image_by_url from astrbot.core.db import BaseDatabase @@ -146,11 +147,41 @@ class ProviderGoogleGenAI(Provider): google_genai_conversation.append({"role": "user", "parts": parts}) elif message["role"] == "assistant": - if not message["content"]: - message["content"] = "" - google_genai_conversation.append( - {"role": "model", "parts": [{"text": message["content"]}]} + if "content" in message: + if not message["content"]: + message["content"] = "" + google_genai_conversation.append( + {"role": "model", "parts": [{"text": message["content"]}]} + ) + elif "tool_calls" in message: + # tool calls in the last turn + parts = [] + for tool_call in message["tool_calls"]: + parts.append( + { + "functionCall": { + "name": tool_call["function"]["name"], + "args": json.loads( + tool_call["function"]["arguments"] + ), + } + } + ) + google_genai_conversation.append({"role": "model", "parts": parts}) + elif message["role"] == "tool": + parts = [] + parts.append( + { + "functionResponse": { + "name": message["tool_call_id"], + "response": { + "name": message["tool_call_id"], + "content": message["content"], + }, + } + } ) + google_genai_conversation.append({"role": "user", "parts": parts}) logger.debug(f"google_genai_conversation: {google_genai_conversation}") @@ -174,6 +205,9 @@ class ProviderGoogleGenAI(Provider): llm_response.role = "tool" llm_response.tools_call_args.append(candidate["functionCall"]["args"]) llm_response.tools_call_name.append(candidate["functionCall"]["name"]) + llm_response.tools_call_ids.append( + candidate["functionCall"]["name"] + ) # 没有 tool id llm_response.completion_text = llm_response.completion_text.strip() return llm_response @@ -186,6 +220,7 @@ class ProviderGoogleGenAI(Provider): func_tool: FuncCall = None, contexts=[], system_prompt=None, + tool_calls_result=None, **kwargs, ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls) @@ -198,6 +233,10 @@ class ProviderGoogleGenAI(Provider): if "_no_save" in part: del part["_no_save"] + # tool calls result + if tool_calls_result: + context_query.extend(tool_calls_result.to_openai_messages()) + model_config = self.provider_config.get("model_config", {}) model_config["model"] = self.get_model() diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 4b66c50ef..628b1849d 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -120,15 +120,18 @@ class ProviderOpenAIOfficial(Provider): # tools call (function calling) args_ls = [] func_name_ls = [] + tool_call_ids = [] for tool_call in choice.message.tool_calls: for tool in tools.func_list: if tool.name == tool_call.function.name: args = json.loads(tool_call.function.arguments) args_ls.append(args) func_name_ls.append(tool_call.function.name) + tool_call_ids.append(tool_call.id) llm_response.role = "tool" llm_response.tools_call_args = args_ls llm_response.tools_call_name = func_name_ls + llm_response.tools_call_ids = tool_call_ids if choice.finish_reason == "content_filter": raise Exception( @@ -151,6 +154,7 @@ class ProviderOpenAIOfficial(Provider): func_tool: FuncCall = None, contexts=[], system_prompt=None, + tool_calls_result=None, **kwargs, ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls) @@ -162,10 +166,15 @@ class ProviderOpenAIOfficial(Provider): if "_no_save" in part: del part["_no_save"] + # tool calls result + if tool_calls_result: + context_query.extend(tool_calls_result.to_openai_messages()) + model_config = self.provider_config.get("model_config", {}) model_config["model"] = self.get_model() payloads = {"messages": context_query, **model_config} + llm_response = None try: llm_response = await self._query(payloads, func_tool) @@ -275,10 +284,8 @@ class ProviderOpenAIOfficial(Provider): def set_key(self, key): self.client.api_key = key - async def assemble_context(self, text: str, image_urls: List[str] = None): - """ - 组装上下文。 - """ + async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict: + """组装成符合 OpenAI 格式的 role 为 user 的消息段""" if image_urls: user_content = {"role": "user", "content": [{"type": "text", "text": text}]} for image_url in image_urls: diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py index f4107bdc5..2e5461981 100644 --- a/astrbot/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -6,6 +6,7 @@ from .stat import StatRoute from .log import LogRoute from .static_file import StaticFileRoute from .chat import ChatRoute +from .tools import ToolsRoute # 导入新的ToolsRoute __all__ = [ @@ -17,4 +18,5 @@ __all__ = [ "LogRoute", "StaticFileRoute", "ChatRoute", + "ToolsRoute", # 添加新的ToolsRoute ] diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 14af21bbc..dcfe50d38 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -146,6 +146,7 @@ class ConfigRoute(Route): "/config/provider/new": ("POST", self.post_new_provider), "/config/provider/update": ("POST", self.post_update_provider), "/config/provider/delete": ("POST", self.post_delete_provider), + "/config/llmtools": ("GET", self.get_llm_tools), } self.register_routes() @@ -278,6 +279,12 @@ class ConfigRoute(Route): return Response().error(str(e)).__dict__ return Response().ok(None, "删除成功,已经实时生效~").__dict__ + async def get_llm_tools(self): + """获取函数调用工具。包含了本地加载的以及 MCP 服务的工具""" + tool_mgr = self.core_lifecycle.provider_manager.llm_tools + tools = tool_mgr.get_func_desc_openai_style() + return Response().ok(tools).__dict__ + async def _get_astrbot_config(self): config = self.config diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py new file mode 100644 index 000000000..36da48ce4 --- /dev/null +++ b/astrbot/dashboard/routes/tools.py @@ -0,0 +1,252 @@ +import os +import json +import traceback +from .route import Route, Response, RouteContext +from quart import request +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core import logger + +DEFAULT_MCP_CONFIG = {"mcpServers": {}} + + +class ToolsRoute(Route): + def __init__( + self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle + ) -> None: + super().__init__(context) + self.core_lifecycle = core_lifecycle + self.routes = { + "/tools/mcp/servers": ("GET", self.get_mcp_servers), + "/tools/mcp/add": ("POST", self.add_mcp_server), + "/tools/mcp/update": ("POST", self.update_mcp_server), + "/tools/mcp/delete": ("POST", self.delete_mcp_server), + } + self.register_routes() + self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools + + @property + def mcp_config_path(self): + current_dir = os.path.dirname(os.path.abspath(__file__)) + data_dir = os.path.abspath(os.path.join(current_dir, "../../../data")) + return os.path.join(data_dir, "mcp_server.json") + + def load_mcp_config(self): + if not os.path.exists(self.mcp_config_path): + # 配置文件不存在,创建默认配置 + os.makedirs(os.path.dirname(self.mcp_config_path), exist_ok=True) + with open(self.mcp_config_path, "w", encoding="utf-8") as f: + json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4) + return DEFAULT_MCP_CONFIG + + try: + with open(self.mcp_config_path, "r", encoding="utf-8") as f: + return json.load(f) + except Exception as e: + logger.error(f"加载 MCP 配置失败: {e}") + return DEFAULT_MCP_CONFIG + + def save_mcp_config(self, config): + try: + with open(self.mcp_config_path, "w", encoding="utf-8") as f: + json.dump(config, f, ensure_ascii=False, indent=4) + return True + except Exception as e: + logger.error(f"保存 MCP 配置失败: {e}") + return False + + async def get_mcp_servers(self): + try: + config = self.load_mcp_config() + servers = [] + + # 获取所有服务器并添加它们的工具列表 + for name, server_config in config["mcpServers"].items(): + server_info = { + "name": name, + "active": server_config.get("active", True), + } + + # 复制所有配置字段 + for key, value in server_config.items(): + if key != "active": # active 已经处理 + server_info[key] = value + + # 如果MCP客户端已初始化,从客户端获取工具名称 + for ( + name_key, + mcp_client, + ) in self.tool_mgr.mcp_client_dict.items(): + if name_key == name: + server_info["tools"] = [tool.name for tool in mcp_client.tools] + break + else: + server_info["tools"] = [] + + servers.append(server_info) + + return Response().ok(servers).__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"获取 MCP 服务器列表失败: {str(e)}").__dict__ + + async def add_mcp_server(self): + try: + server_data = await request.json + + name = server_data.get("name", "") + + # 检查必填字段 + if not name: + return Response().error("服务器名称不能为空").__dict__ + + # 移除特殊字段并检查配置是否有效 + has_valid_config = False + server_config = {"active": server_data.get("active", True)} + + # 复制所有配置字段 + for key, value in server_data.items(): + if key not in ["name", "active", "tools"]: # 排除特殊字段 + server_config[key] = value + has_valid_config = True + + if not has_valid_config: + return Response().error("必须提供有效的服务器配置").__dict__ + + config = self.load_mcp_config() + + if name in config["mcpServers"]: + return Response().error(f"服务器 {name} 已存在").__dict__ + + config["mcpServers"][name] = server_config + + if self.save_mcp_config(config): + # 动态初始化新MCP客户端 + self.tool_mgr.mcp_service_queue.put_nowait( + { + "type": "init", + "name": name, + "cfg": config["mcpServers"][name], + } + ) + return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__ + else: + return Response().error("保存配置失败").__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"添加 MCP 服务器失败: {str(e)}").__dict__ + + async def update_mcp_server(self): + try: + server_data = await request.json + + name = server_data.get("name", "") + + if not name: + return Response().error("服务器名称不能为空").__dict__ + + config = self.load_mcp_config() + + if name not in config["mcpServers"]: + return Response().error(f"服务器 {name} 不存在").__dict__ + + # 获取活动状态 + active = server_data.get( + "active", config["mcpServers"][name].get("active", True) + ) + + # 创建新的配置对象 + server_config = {"active": active} + + # 仅更新活动状态的特殊处理 + only_update_active = True + + # 复制所有配置字段 + for key, value in server_data.items(): + if key not in ["name", "active", "tools"]: # 排除特殊字段 + server_config[key] = value + only_update_active = False + + # 如果只更新活动状态,保留原始配置 + if only_update_active: + for key, value in config["mcpServers"][name].items(): + if key != "active": # 除了active之外的所有字段都保留 + server_config[key] = value + + config["mcpServers"][name] = server_config + + if self.save_mcp_config(config): + # 处理MCP客户端状态变化 + if active: + # 如果要激活服务器或者配置已更改 + if name in self.tool_mgr.mcp_client_dict or not only_update_active: + await self.tool_mgr.mcp_service_queue.put( + { + "type": "terminate", + "name": name, + } + ) + await self.tool_mgr.mcp_service_queue.put( + { + "type": "init", + "name": name, + "cfg": config["mcpServers"][name], + } + ) + else: + # 客户端不存在,初始化 + self.tool_mgr.mcp_service_queue.put_nowait( + { + "type": "init", + "name": name, + "cfg": config["mcpServers"][name], + } + ) + else: + # 如果要停用服务器 + if name in self.tool_mgr.mcp_client_dict: + self.tool_mgr.mcp_service_queue.put_nowait( + { + "type": "terminate", + "name": name, + } + ) + + return Response().ok(None, f"成功更新 MCP 服务器 {name}").__dict__ + else: + return Response().error("保存配置失败").__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"更新 MCP 服务器失败: {str(e)}").__dict__ + + async def delete_mcp_server(self): + try: + server_data = await request.json + name = server_data.get("name", "") + + if not name: + return Response().error("服务器名称不能为空").__dict__ + + config = self.load_mcp_config() + + if name not in config["mcpServers"]: + return Response().error(f"服务器 {name} 不存在").__dict__ + + # 删除服务器配置 + del config["mcpServers"][name] + + if self.save_mcp_config(config): + # 关闭并删除MCP客户端 + if name in self.tool_mgr.mcp_client_dict: + self.tool_mgr.mcp_service_queue.put_nowait( + { + "type": "terminate", + "name": name, + } + ) + + return Response().ok(None, f"成功删除 MCP 服务器 {name}").__dict__ + else: + return Response().error("保存配置失败").__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"删除 MCP 服务器失败: {str(e)}").__dict__ diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 45aac3cd6..9af11dd53 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -50,6 +50,7 @@ class AstrBotDashboard: self.sfr = StaticFileRoute(self.context) self.ar = AuthRoute(self.context) self.chat_route = ChatRoute(self.context, db, core_lifecycle) + self.tools_root = ToolsRoute(self.context, core_lifecycle) self.shutdown_event = shutdown_event diff --git a/dashboard/src/components/shared/ItemCardGrid.vue b/dashboard/src/components/shared/ItemCardGrid.vue new file mode 100644 index 000000000..a1ed1609e --- /dev/null +++ b/dashboard/src/components/shared/ItemCardGrid.vue @@ -0,0 +1,134 @@ + + + + + diff --git a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts index d2fbb556e..5325e0cfd 100644 --- a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts +++ b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts @@ -45,6 +45,11 @@ const sidebarItem: menu[] = [ icon: 'mdi-storefront', to: '/extension-marketplace' }, + { + title: '函数调用', + icon: 'mdi-function-variant', + to: '/tool-use' + }, { title: '聊天', icon: 'mdi-chat', diff --git a/dashboard/src/router/MainRoutes.ts b/dashboard/src/router/MainRoutes.ts index 5fb0abeea..9ce2402ca 100644 --- a/dashboard/src/router/MainRoutes.ts +++ b/dashboard/src/router/MainRoutes.ts @@ -31,6 +31,11 @@ const MainRoutes = { path: '/providers', component: () => import('@/views/ProviderPage.vue') }, + { + name: 'ToolUsePage', + path: '/tool-use', + component: () => import('@/views/ToolUsePage.vue') + }, { name: 'Configs', path: '/config', diff --git a/dashboard/src/views/PlatformPage.vue b/dashboard/src/views/PlatformPage.vue index ac3d24ad5..38fa08e68 100644 --- a/dashboard/src/views/PlatformPage.vue +++ b/dashboard/src/views/PlatformPage.vue @@ -1,316 +1,304 @@ - - \ No newline at end of file diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue index c29247fb3..e2e1b796b 100644 --- a/dashboard/src/views/ProviderPage.vue +++ b/dashboard/src/views/ProviderPage.vue @@ -1,291 +1,328 @@ - - \ No newline at end of file diff --git a/dashboard/src/views/ToolUsePage.vue b/dashboard/src/views/ToolUsePage.vue new file mode 100644 index 000000000..39164ad36 --- /dev/null +++ b/dashboard/src/views/ToolUsePage.vue @@ -0,0 +1,648 @@ + + + + + \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 07b168cc1..313dba0c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,4 +24,5 @@ cryptography dashscope python-telegram-bot wechatpy -dingtalk-stream \ No newline at end of file +dingtalk-stream +mcp \ No newline at end of file