diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 9689ef34e..acde70c27 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -16,7 +16,13 @@ from astrbot.core.message.message_event_result import ( from astrbot.core.message.components import Image from astrbot.core import logger from astrbot.core.utils.metrics import Metric -from astrbot.core.provider.entites import ProviderRequest, LLMResponse +from astrbot.core.provider.entites import ( + ProviderRequest, + LLMResponse, + ToolCallMessageSegment, + AssistantMessageSegment, + ToolCallsResult, +) from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star import star_map @@ -111,10 +117,18 @@ class LLMRequestSubStage(Stage): req.contexts = json.loads(req.contexts) try: - logger.debug(f"提供商请求 Payload: {req}") - if _nested: - req.func_tool = None # 暂时不支持递归工具调用 - llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM + need_loop = True + while need_loop: + need_loop = False + logger.debug(f"提供商请求 Payload: {req}") + llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM + async for result in self._handle_llm_response(event, req, llm_response): + if isinstance(result, ProviderRequest): + # 有函数工具调用并且返回了结果,我们需要再次请求 LLM + req = result + need_loop = True + else: + yield # 执行 LLM 响应后的事件钩子。 handlers = star_handlers_registry.get_handlers_by_event_type( @@ -135,9 +149,6 @@ class LLMRequestSubStage(Stage): ) return - # 保存到历史记录 - await self._save_to_history(event, req, llm_response) - asyncio.create_task( Metric.upload( llm_tick=1, @@ -146,88 +157,8 @@ class LLMRequestSubStage(Stage): ) ) - if llm_response.role == "assistant": - # text completion - if llm_response.result_chain: - event.set_result( - MessageEventResult( - chain=llm_response.result_chain.chain - ).set_result_content_type(ResultContentType.LLM_RESULT) - ) - else: - event.set_result( - MessageEventResult() - .message(llm_response.completion_text) - .set_result_content_type(ResultContentType.LLM_RESULT) - ) - elif llm_response.role == "err": - event.set_result( - MessageEventResult().message( - f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}" - ) - ) - elif llm_response.role == "tool": - # function calling - function_calling_result = {} - logger.info( - f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}" - ) - for func_tool_name, func_tool_args in zip( - llm_response.tools_call_name, llm_response.tools_call_args - ): - try: - func_tool = req.func_tool.get_func(func_tool_name) - if func_tool.origin == "mcp": - logger.info( - f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}" - ) - client = req.func_tool.mcp_client_dict[ - func_tool.mcp_server_name - ] - res = await client.session.call_tool( - func_tool.name, func_tool_args - ) - if res: - # TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。 - res_event = event.plain_result(res.content[0].text) - event.set_result(res_event) - yield - else: - logger.info( - f"调用工具函数:{func_tool_name},参数:{func_tool_args}" - ) - # 尝试调用工具函数 - wrapper = self._call_handler( - self.ctx, event, func_tool.handler, **func_tool_args - ) - async for resp in wrapper: - if resp is not None: # 有 return 返回 - function_calling_result[func_tool_name] = resp - else: - yield # 有生成器返回 - event.clear_result() # 清除上一个 handler 的结果 - except BaseException as e: - logger.warning(traceback.format_exc()) - function_calling_result[func_tool_name] = ( - "When calling the function, an error occurred: " + str(e) - ) - if function_calling_result: - # 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。 - # 我们重新执行一遍这个 stage - req.func_tool = None # 暂时不支持递归工具调用 - extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n" - for tool_name, tool_result in function_calling_result.items(): - extra_prompt += ( - f"Tool: {tool_name}\nTool Result: {tool_result}\n" - ) - req.prompt += extra_prompt - async for _ in self.process(event, _nested=True): - yield - else: - if llm_response.completion_text: - event.set_result( - MessageEventResult().message(llm_response.completion_text) - ) + # 保存到历史记录 + await self._save_to_history(event, req, llm_response) except BaseException as e: logger.error(traceback.format_exc()) @@ -238,6 +169,116 @@ class LLMRequestSubStage(Stage): ) return + async def _handle_llm_response( + self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse + ) -> AsyncGenerator[None, None]: + """处理 LLM 响应。 + + Returns: + bool: 是否需要继续调用 LLM + + Yields: + Iterator[bool]: 将 event 交付给下一个 stage + """ + if llm_response.role == "assistant": + # text completion + if llm_response.result_chain: + event.set_result( + MessageEventResult( + chain=llm_response.result_chain.chain + ).set_result_content_type(ResultContentType.LLM_RESULT) + ) + else: + event.set_result( + MessageEventResult() + .message(llm_response.completion_text) + .set_result_content_type(ResultContentType.LLM_RESULT) + ) + elif llm_response.role == "err": + event.set_result( + MessageEventResult().message( + f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}" + ) + ) + elif llm_response.role == "tool": + # function calling + tool_call_result: list[ToolCallMessageSegment] = [] + logger.info( + f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}" + ) + for func_tool_name, func_tool_args, func_tool_id in zip( + llm_response.tools_call_name, + llm_response.tools_call_args, + llm_response.tools_call_ids, + ): + try: + func_tool = req.func_tool.get_func(func_tool_name) + if func_tool.origin == "mcp": + logger.info( + f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}" + ) + client = req.func_tool.mcp_client_dict[ + func_tool.mcp_server_name + ] + res = await client.session.call_tool( + func_tool.name, func_tool_args + ) + if res: + # TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。 + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=res.content[0].text, + ) + ) + else: + logger.info( + f"调用工具函数:{func_tool_name},参数:{func_tool_args}" + ) + # 尝试调用工具函数 + wrapper = self._call_handler( + self.ctx, event, func_tool.handler, **func_tool_args + ) + async for resp in wrapper: + if resp is not None: # 有 return 返回 + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=resp, + ) + ) + else: + yield # 有生成器返回 + event.clear_result() # 清除上一个 handler 的结果 + except BaseException as e: + logger.warning(traceback.format_exc()) + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=f"error: {str(e)}", + ) + ) + if tool_call_result: + # 函数调用结果 + req.func_tool = None # 暂时不支持递归工具调用 + assistant_msg_seg = AssistantMessageSegment( + role="assistant", tool_calls=llm_response.to_openai_tool_calls() + ) + # 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。 + req.tool_calls_result = ToolCallsResult( + tool_calls_info=assistant_msg_seg, + tool_calls_result=tool_call_result, + ) + yield req # 再次执行 LLM 请求 + else: + if llm_response.completion_text: + event.set_result( + MessageEventResult().message(llm_response.completion_text) + ) + async def _save_to_history( self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse ): @@ -248,6 +289,13 @@ class LLMRequestSubStage(Stage): # 文本回复 contexts = req.contexts contexts.append(await req.assemble_context()) + + # tool calls result + if req.tool_calls_result: + contexts.extend( + req.tool_calls_result.to_openai_messages() + ) + contexts.append( {"role": "assistant", "content": llm_response.completion_text} ) diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entites.py index 4236214d4..a8ffcdf64 100644 --- a/astrbot/core/provider/entites.py +++ b/astrbot/core/provider/entites.py @@ -1,11 +1,15 @@ import enum import base64 +import json from astrbot.core.utils.io import download_image_by_url from astrbot import logger from dataclasses import dataclass, field from typing import List, Dict, Type from .func_tool_manager import FuncCall from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, +) from astrbot.core.db.po import Conversation from astrbot.core.message.message_event_result import MessageChain import astrbot.core.message.components as Comp @@ -32,6 +36,58 @@ class ProviderMetaData: """显示在 WebUI 配置页中的提供商名称,如空则是 type""" +@dataclass +class ToolCallMessageSegment: + """OpenAI 格式的上下文中 role 为 tool 的消息段。参考: https://platform.openai.com/docs/guides/function-calling""" + + tool_call_id: str + content: str + role: str = "tool" + + def to_dict(self): + return { + "tool_call_id": self.tool_call_id, + "content": self.content, + "role": self.role, + } + + +@dataclass +class AssistantMessageSegment: + """OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling""" + + content: str = None + tool_calls: List[ChatCompletionMessageToolCall | Dict] = None + role: str = "assistant" + + def to_dict(self): + ret = { + "role": self.role, + } + if self.content: + ret["content"] = self.content + elif self.tool_calls: + ret["tool_calls"] = self.tool_calls + return ret + + +@dataclass +class ToolCallsResult: + """工具调用结果""" + + tool_calls_info: AssistantMessageSegment + """函数调用的信息""" + tool_calls_result: List[ToolCallMessageSegment] + """函数调用的结果""" + + def to_openai_messages(self) -> List[Dict]: + ret = [ + self.tool_calls_info.to_dict(), + *[item.to_dict() for item in self.tool_calls_result], + ] + return ret + + @dataclass class ProviderRequest: prompt: str @@ -41,7 +97,7 @@ class ProviderRequest: image_urls: List[str] = None """图片 URL 列表""" func_tool: FuncCall = None - """工具""" + """可用的函数工具""" contexts: List = None """上下文。格式与 openai 的上下文格式一致: 参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages @@ -50,8 +106,11 @@ class ProviderRequest: """系统提示词""" conversation: Conversation = None + tool_calls_result: ToolCallsResult = None + """附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls""" + def __repr__(self): - return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()})" + return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})" def __str__(self): return self.__repr__() @@ -137,6 +196,8 @@ class LLMResponse: """工具调用参数""" tools_call_name: List[str] = field(default_factory=list) """工具调用名称""" + tools_call_ids: List[str] = field(default_factory=list) + """工具调用 ID""" raw_completion: ChatCompletion = None _new_record: Dict[str, any] = None @@ -148,8 +209,9 @@ class LLMResponse: role: str, completion_text: str = "", result_chain: MessageChain = None, - tools_call_args: List[Dict[str, any]] = None, - tools_call_name: List[str] = None, + tools_call_args: List[Dict[str, any]] = [], + tools_call_name: List[str] = [], + tools_call_ids: List[str] = [], raw_completion: ChatCompletion = None, _new_record: Dict[str, any] = None, ): @@ -168,6 +230,7 @@ class LLMResponse: self.result_chain = result_chain self.tools_call_args = tools_call_args self.tools_call_name = tools_call_name + self.tools_call_ids = tools_call_ids self.raw_completion = raw_completion self._new_record = _new_record @@ -188,3 +251,19 @@ class LLMResponse: self.result_chain.chain.insert(0, Comp.Plain(value)) else: self._completion_text = value + + def to_openai_tool_calls(self) -> List[Dict]: + """将工具调用信息转换为 OpenAI 格式""" + ret = [] + for idx, tool_call_arg in enumerate(self.tools_call_args): + ret.append( + { + "id": self.tools_call_ids[idx], + "function": { + "name": self.tools_call_name[idx], + "arguments": json.dumps(tool_call_arg), + }, + "type": "function", + } + ) + return ret diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index fd4a14f7f..52f31c363 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -4,6 +4,7 @@ import textwrap import os import asyncio import mcp +import copy from typing import Dict, List, Awaitable, Literal, Any from dataclasses import dataclass @@ -391,7 +392,13 @@ class FuncCall: # 检查并添加非空的properties参数 params = f.parameters if isinstance(f.parameters, dict) else {} + params = copy.deepcopy(params) if params.get("properties", {}): + properties = params["properties"] + for key, value in properties.items(): + if "default" in value: + del value["default"] + params["properties"] = properties func_declaration["parameters"] = params tools.append(func_declaration) diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 57dc57f90..8dcff9a52 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -3,7 +3,7 @@ from typing import List from astrbot.core.db import BaseDatabase from typing import TypedDict from astrbot.core.provider.func_tool_manager import FuncCall -from astrbot.core.provider.entites import LLMResponse +from astrbot.core.provider.entites import LLMResponse, ToolCallsResult from dataclasses import dataclass @@ -90,6 +90,7 @@ class Provider(AbstractProvider): func_tool: FuncCall = None, contexts: List = None, system_prompt: str = None, + tool_calls_result: ToolCallsResult = None, **kwargs, ) -> LLMResponse: """获得 LLM 的文本对话结果。会使用当前的模型进行对话。 @@ -100,6 +101,7 @@ class Provider(AbstractProvider): image_urls: 图片 URL 列表 tools: Function-calling 工具 contexts: 上下文 + tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling kwargs: 其他参数 Notes: diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 13f7482d4..90efdee91 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -10,7 +10,7 @@ from astrbot.api.provider import Provider, Personality from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall from ..register import register_provider_adapter -from astrbot.core.provider.entites import LLMResponse +from astrbot.core.provider.entites import LLMResponse, ToolCallsResult from .openai_source import ProviderOpenAIOfficial @@ -79,11 +79,14 @@ class ProviderAnthropic(ProviderOpenAIOfficial): # tools call (function calling) args_ls = [] func_name_ls = [] + tool_use_ids = [] func_name_ls.append(content.name) args_ls.append(content.input) + tool_use_ids.append(content.id) llm_response.role = "tool" llm_response.tools_call_args = args_ls llm_response.tools_call_name = func_name_ls + llm_response.tools_call_ids = tool_use_ids if not llm_response.completion_text and not llm_response.tools_call_args: logger.error(f"API 返回的 completion 无法解析:{completion}。") @@ -101,6 +104,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial): func_tool: FuncCall = None, contexts=[], system_prompt=None, + tool_calls_result: ToolCallsResult=None, **kwargs, ) -> LLMResponse: if not prompt: @@ -113,6 +117,10 @@ class ProviderAnthropic(ProviderOpenAIOfficial): if "_no_save" in part: del part["_no_save"] + if tool_calls_result: + # 暂时这样写。 + prompt += f"Here are the related results via using tools: {str(tool_calls_result.tool_calls_result)}" + model_config = self.provider_config.get("model_config", {}) payloads = {"messages": context_query, **model_config} diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 233cc8932..1caccf2b1 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -1,5 +1,6 @@ import base64 import aiohttp +import json import random from astrbot.core.utils.io import download_image_by_url from astrbot.core.db import BaseDatabase @@ -115,6 +116,7 @@ class ProviderGoogleGenAI(Provider): break google_genai_conversation = [] + print(payloads) for message in payloads["messages"]: if message["role"] == "user": if isinstance(message["content"], str): @@ -146,11 +148,39 @@ class ProviderGoogleGenAI(Provider): google_genai_conversation.append({"role": "user", "parts": parts}) elif message["role"] == "assistant": - if not message["content"]: - message["content"] = "" - google_genai_conversation.append( - {"role": "model", "parts": [{"text": message["content"]}]} + if "content" in message: + if not message["content"]: + message["content"] = "" + google_genai_conversation.append( + {"role": "model", "parts": [{"text": message["content"]}]} + ) + elif "tool_calls" in message: + # tool calls in the last turn + parts = [] + for tool_call in message["tool_calls"]: + parts.append( + { + "functionCall": { + "name": tool_call["function"]["name"], + "args": json.loads(tool_call["function"]["arguments"]), + } + } + ) + google_genai_conversation.append({"role": "model", "parts": parts}) + elif message["role"] == "tool": + parts = [] + parts.append( + { + "functionResponse": { + "name": message["tool_call_id"], + "response": { + "name": message["tool_call_id"], + "content": message["content"], + }, + } + } ) + google_genai_conversation.append({"role": "user", "parts": parts}) logger.debug(f"google_genai_conversation: {google_genai_conversation}") @@ -174,6 +204,7 @@ class ProviderGoogleGenAI(Provider): llm_response.role = "tool" llm_response.tools_call_args.append(candidate["functionCall"]["args"]) llm_response.tools_call_name.append(candidate["functionCall"]["name"]) + llm_response.tools_call_ids.append(candidate["functionCall"]["name"]) # 没有 tool id llm_response.completion_text = llm_response.completion_text.strip() return llm_response @@ -186,6 +217,7 @@ class ProviderGoogleGenAI(Provider): func_tool: FuncCall = None, contexts=[], system_prompt=None, + tool_calls_result=None, **kwargs, ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls) @@ -198,6 +230,10 @@ class ProviderGoogleGenAI(Provider): if "_no_save" in part: del part["_no_save"] + # tool calls result + if tool_calls_result: + context_query.extend(tool_calls_result.to_openai_messages()) + model_config = self.provider_config.get("model_config", {}) model_config["model"] = self.get_model() diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 4b66c50ef..628b1849d 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -120,15 +120,18 @@ class ProviderOpenAIOfficial(Provider): # tools call (function calling) args_ls = [] func_name_ls = [] + tool_call_ids = [] for tool_call in choice.message.tool_calls: for tool in tools.func_list: if tool.name == tool_call.function.name: args = json.loads(tool_call.function.arguments) args_ls.append(args) func_name_ls.append(tool_call.function.name) + tool_call_ids.append(tool_call.id) llm_response.role = "tool" llm_response.tools_call_args = args_ls llm_response.tools_call_name = func_name_ls + llm_response.tools_call_ids = tool_call_ids if choice.finish_reason == "content_filter": raise Exception( @@ -151,6 +154,7 @@ class ProviderOpenAIOfficial(Provider): func_tool: FuncCall = None, contexts=[], system_prompt=None, + tool_calls_result=None, **kwargs, ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls) @@ -162,10 +166,15 @@ class ProviderOpenAIOfficial(Provider): if "_no_save" in part: del part["_no_save"] + # tool calls result + if tool_calls_result: + context_query.extend(tool_calls_result.to_openai_messages()) + model_config = self.provider_config.get("model_config", {}) model_config["model"] = self.get_model() payloads = {"messages": context_query, **model_config} + llm_response = None try: llm_response = await self._query(payloads, func_tool) @@ -275,10 +284,8 @@ class ProviderOpenAIOfficial(Provider): def set_key(self, key): self.client.api_key = key - async def assemble_context(self, text: str, image_urls: List[str] = None): - """ - 组装上下文。 - """ + async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict: + """组装成符合 OpenAI 格式的 role 为 user 的消息段""" if image_urls: user_content = {"role": "user", "content": [{"type": "text", "text": text}]} for image_url in image_urls: