From d19945009f9bf4c2e619641c57c90561bb9520aa Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 14 Nov 2025 19:17:24 +0800 Subject: [PATCH] refactor: decople the agent impl part and introduce some helper context method to call llm --- astrbot/core/agent/runners/base.py | 7 + .../agent/runners/tool_loop_agent_runner.py | 10 + astrbot/core/astr_agent_context.py | 8 +- astrbot/core/astr_agent_hooks.py | 36 ++ astrbot/core/astr_agent_run_util.py | 77 ++++ astrbot/core/astr_agent_tool_exec.py | 301 +++++++++++++++ astrbot/core/exceptions.py | 9 + astrbot/core/pipeline/context.py | 3 +- astrbot/core/pipeline/context_utils.py | 65 ---- .../process_stage/method/llm_request.py | 348 +----------------- astrbot/core/provider/entities.py | 4 +- astrbot/core/star/context.py | 144 +++++++- 12 files changed, 604 insertions(+), 408 deletions(-) create mode 100644 astrbot/core/astr_agent_hooks.py create mode 100644 astrbot/core/astr_agent_run_util.py create mode 100644 astrbot/core/astr_agent_tool_exec.py create mode 100644 astrbot/core/exceptions.py diff --git a/astrbot/core/agent/runners/base.py b/astrbot/core/agent/runners/base.py index c7cd36d96..f7e0913b4 100644 --- a/astrbot/core/agent/runners/base.py +++ b/astrbot/core/agent/runners/base.py @@ -40,6 +40,13 @@ class BaseAgentRunner(T.Generic[TContext]): """Process a single step of the agent.""" ... + @abc.abstractmethod + async def step_until_done( + self, max_step: int + ) -> T.AsyncGenerator[AgentResponse, None]: + """Process steps until the agent is done.""" + ... + @abc.abstractmethod def done(self) -> bool: """Check if the agent has completed its task. diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 23071d446..f6e613679 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -177,6 +177,16 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): ) self.req.append_tool_calls_result(tool_calls_result) + async def step_until_done( + self, max_step: int + ) -> T.AsyncGenerator[AgentResponse, None]: + """Process steps until the agent is done.""" + step_count = 0 + while not self.done() and step_count < max_step: + step_count += 1 + async for resp in self.step(): + yield resp + async def _handle_function_tools( self, req: ProviderRequest, diff --git a/astrbot/core/astr_agent_context.py b/astrbot/core/astr_agent_context.py index 28b242253..ffe8a199b 100644 --- a/astrbot/core/astr_agent_context.py +++ b/astrbot/core/astr_agent_context.py @@ -1,14 +1,14 @@ from dataclasses import dataclass +from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.provider import Provider -from astrbot.core.provider.entities import ProviderRequest @dataclass class AstrAgentContext: provider: Provider - first_provider_request: ProviderRequest - curr_provider_request: ProviderRequest - streaming: bool event: AstrMessageEvent + + +AgentContextWrapper = ContextWrapper[AstrAgentContext] diff --git a/astrbot/core/astr_agent_hooks.py b/astrbot/core/astr_agent_hooks.py new file mode 100644 index 000000000..f394fc947 --- /dev/null +++ b/astrbot/core/astr_agent_hooks.py @@ -0,0 +1,36 @@ +from typing import Any + +from mcp.types import CallToolResult + +from astrbot.core.agent.hooks import BaseAgentRunHooks +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.tool import FunctionTool +from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.pipeline.context_utils import call_event_hook +from astrbot.core.star.star_handler import EventType + + +class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]): + async def on_agent_done(self, run_context, llm_response): + # 执行事件钩子 + await call_event_hook( + run_context.context.event, + EventType.OnLLMResponseEvent, + llm_response, + ) + + async def on_tool_end( + self, + run_context: ContextWrapper[AstrAgentContext], + tool: FunctionTool[Any], + tool_args: dict | None, + tool_result: CallToolResult | None, + ): + run_context.context.event.clear_result() + + +class EmptyAgentHooks(BaseAgentRunHooks[AstrAgentContext]): + pass + + +MAIN_AGENT_HOOKS = MainAgentHooks() diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py new file mode 100644 index 000000000..ed8c0028d --- /dev/null +++ b/astrbot/core/astr_agent_run_util.py @@ -0,0 +1,77 @@ +import traceback +from collections.abc import AsyncGenerator + +from astrbot.core import logger +from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner +from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.message.message_event_result import ( + MessageChain, + MessageEventResult, + ResultContentType, +) + +AgentRunner = ToolLoopAgentRunner[AstrAgentContext] + + +async def run_agent( + agent_runner: AgentRunner, + max_step: int = 30, + show_tool_use: bool = True, + stream_to_general: bool = False, +) -> AsyncGenerator[MessageChain | None, None]: + step_idx = 0 + astr_event = agent_runner.run_context.context.event + while step_idx < max_step: + step_idx += 1 + try: + async for resp in agent_runner.step(): + if astr_event.is_stopped(): + return + if resp.type == "tool_call_result": + msg_chain = resp.data["chain"] + if msg_chain.type == "tool_direct_result": + # tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容 + resp.data["chain"].type = "tool_call_result" + await astr_event.send(resp.data["chain"]) + continue + # 对于其他情况,暂时先不处理 + continue + elif resp.type == "tool_call": + if agent_runner.streaming: + # 用来标记流式响应需要分节 + yield MessageChain(chain=[], type="break") + if show_tool_use or astr_event.get_platform_name() == "webchat": + resp.data["chain"].type = "tool_call" + await astr_event.send(resp.data["chain"]) + continue + + if stream_to_general and resp.type == "streaming_delta": + continue + + if stream_to_general or not agent_runner.streaming: + content_typ = ( + ResultContentType.LLM_RESULT + if resp.type == "llm_result" + else ResultContentType.GENERAL_RESULT + ) + astr_event.set_result( + MessageEventResult( + chain=resp.data["chain"].chain, + result_content_type=content_typ, + ), + ) + yield + astr_event.clear_result() + elif resp.type == "streaming_delta": + yield resp.data["chain"] # MessageChain + if agent_runner.done(): + break + + except Exception as e: + logger.error(traceback.format_exc()) + err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n" + if agent_runner.streaming: + yield MessageChain().message(err_msg) + else: + astr_event.set_result(MessageEventResult().message(err_msg)) + return diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py new file mode 100644 index 000000000..f7425b0b5 --- /dev/null +++ b/astrbot/core/astr_agent_tool_exec.py @@ -0,0 +1,301 @@ +import asyncio +import inspect +import traceback +import typing as T + +import mcp + +from astrbot import logger +from astrbot.core.agent.handoff import HandoffTool +from astrbot.core.agent.hooks import BaseAgentRunHooks +from astrbot.core.agent.mcp_client import MCPTool +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.tool import FunctionTool, ToolSet +from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor +from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.message.message_event_result import ( + CommandResult, + MessageChain, + MessageEventResult, +) +from astrbot.core.provider.entities import ProviderRequest +from astrbot.core.provider.register import llm_tools + +from .astr_agent_context import AgentContextWrapper +from .astr_agent_run_util import AgentRunner, run_agent + + +class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): + @classmethod + async def execute(cls, tool, run_context, **tool_args): + """执行函数调用。 + + Args: + event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。 + **kwargs: 函数调用的参数。 + + Returns: + AsyncGenerator[None | mcp.types.CallToolResult, None] + + """ + if isinstance(tool, HandoffTool): + async for r in cls._execute_handoff(tool, run_context, **tool_args): + yield r + return + + elif isinstance(tool, MCPTool): + async for r in cls._execute_mcp(tool, run_context, **tool_args): + yield r + return + + else: + async for r in cls._execute_local(tool, run_context, **tool_args): + yield r + return + + @classmethod + async def _execute_handoff( + cls, + tool: HandoffTool, + run_context: ContextWrapper[AstrAgentContext], + **tool_args, + ): + input_ = tool_args.get("input", "agent") + agent_runner = AgentRunner() + + # make toolset for the agent + tools = tool.agent.tools + if tools: + toolset = ToolSet() + for t in tools: + if isinstance(t, str): + _t = llm_tools.get_func(t) + if _t: + toolset.add_tool(_t) + elif isinstance(t, FunctionTool): + toolset.add_tool(t) + else: + toolset = None + + request = ProviderRequest( + prompt=input_, + system_prompt=tool.description or "", + image_urls=[], # 暂时不传递原始 agent 的上下文 + contexts=[], # 暂时不传递原始 agent 的上下文 + func_tool=toolset, + ) + astr_agent_ctx = AstrAgentContext( + provider=run_context.context.provider, + event=run_context.context.event, + ) + + event = run_context.context.event + + logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}") + await event.send( + MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name), + ) + + await agent_runner.reset( + provider=run_context.context.provider, + request=request, + run_context=AgentContextWrapper( + context=astr_agent_ctx, + tool_call_timeout=run_context.tool_call_timeout, + ), + tool_executor=FunctionToolExecutor(), + agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](), + ) + + async for _ in run_agent(agent_runner, 15, True): + pass + + if agent_runner.done(): + llm_response = agent_runner.get_final_llm_resp() + + if not llm_response: + text_content = mcp.types.TextContent( + type="text", + text=f"error when deligate task to {tool.agent.name}", + ) + yield mcp.types.CallToolResult(content=[text_content]) + return + + logger.debug( + f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}", + ) + + result = ( + f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n" + "Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)." + ) + + text_content = mcp.types.TextContent( + type="text", + text=result, + ) + yield mcp.types.CallToolResult(content=[text_content]) + else: + text_content = mcp.types.TextContent( + type="text", + text=f"error when deligate task to {tool.agent.name}", + ) + yield mcp.types.CallToolResult(content=[text_content]) + return + + @classmethod + async def _execute_local( + cls, + tool: FunctionTool, + run_context: ContextWrapper[AstrAgentContext], + **tool_args, + ): + event = run_context.context.event + if not event: + raise ValueError("Event must be provided for local function tools.") + + is_override_call = False + for ty in type(tool).mro(): + if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call: + is_override_call = True + break + + # 检查 tool 下有没有 run 方法 + if not tool.handler and not hasattr(tool, "run") and not is_override_call: + raise ValueError("Tool must have a valid handler or override 'run' method.") + + awaitable = None + method_name = "" + if tool.handler: + awaitable = tool.handler + method_name = "decorator_handler" + elif is_override_call: + awaitable = tool.call + method_name = "call" + elif hasattr(tool, "run"): + awaitable = getattr(tool, "run") + method_name = "run" + if awaitable is None: + raise ValueError("Tool must have a valid handler or override 'run' method.") + + wrapper = call_local_llm_tool( + context=run_context, + handler=awaitable, + method_name=method_name, + **tool_args, + ) + while True: + try: + resp = await asyncio.wait_for( + anext(wrapper), + timeout=run_context.tool_call_timeout, + ) + if resp is not None: + if isinstance(resp, mcp.types.CallToolResult): + yield resp + else: + text_content = mcp.types.TextContent( + type="text", + text=str(resp), + ) + yield mcp.types.CallToolResult(content=[text_content]) + else: + # NOTE: Tool 在这里直接请求发送消息给用户 + # TODO: 是否需要判断 event.get_result() 是否为空? + # 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容" + if res := run_context.context.event.get_result(): + if res.chain: + try: + await event.send( + MessageChain( + chain=res.chain, + type="tool_direct_result", + ) + ) + except Exception as e: + logger.error( + f"Tool 直接发送消息失败: {e}", + exc_info=True, + ) + yield None + except asyncio.TimeoutError: + raise Exception( + f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.", + ) + except StopAsyncIteration: + break + + @classmethod + async def _execute_mcp( + cls, + tool: FunctionTool, + run_context: ContextWrapper[AstrAgentContext], + **tool_args, + ): + res = await tool.call(run_context, **tool_args) + if not res: + return + yield res + + +async def call_local_llm_tool( + context: ContextWrapper[AstrAgentContext], + handler: T.Callable[..., T.Awaitable[T.Any]], + method_name: str, + *args, + **kwargs, +) -> T.AsyncGenerator[T.Any, None]: + """执行本地 LLM 工具的处理函数并处理其返回结果""" + ready_to_call = None # 一个协程或者异步生成器 + + trace_ = None + + event = context.context.event + + try: + if method_name == "run" or method_name == "decorator_handler": + ready_to_call = handler(event, *args, **kwargs) + elif method_name == "call": + ready_to_call = handler(context, *args, **kwargs) + else: + raise ValueError(f"未知的方法名: {method_name}") + except ValueError as e: + logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True) + except TypeError: + logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True) + except Exception as e: + trace_ = traceback.format_exc() + logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}") + + if not ready_to_call: + return + + if inspect.isasyncgen(ready_to_call): + _has_yielded = False + try: + async for ret in ready_to_call: + # 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码 + # 返回值只能是 MessageEventResult 或者 None(无返回值) + _has_yielded = True + if isinstance(ret, (MessageEventResult, CommandResult)): + # 如果返回值是 MessageEventResult, 设置结果并继续 + event.set_result(ret) + yield + else: + # 如果返回值是 None, 则不设置结果并继续 + # 继续执行后续阶段 + yield ret + if not _has_yielded: + # 如果这个异步生成器没有执行到 yield 分支 + yield + except Exception as e: + logger.error(f"Previous Error: {trace_}") + raise e + elif inspect.iscoroutine(ready_to_call): + # 如果只是一个协程, 直接执行 + ret = await ready_to_call + if isinstance(ret, (MessageEventResult, CommandResult)): + event.set_result(ret) + yield + else: + yield ret diff --git a/astrbot/core/exceptions.py b/astrbot/core/exceptions.py new file mode 100644 index 000000000..e637d4930 --- /dev/null +++ b/astrbot/core/exceptions.py @@ -0,0 +1,9 @@ +from __future__ import annotations + + +class AstrBotError(Exception): + """Base exception for all AstrBot errors.""" + + +class ProviderNotFoundError(AstrBotError): + """Raised when a specified provider is not found.""" diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py index 44186764e..a6cd567e0 100644 --- a/astrbot/core/pipeline/context.py +++ b/astrbot/core/pipeline/context.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from astrbot.core.config import AstrBotConfig from astrbot.core.star import PluginManager -from .context_utils import call_event_hook, call_handler, call_local_llm_tool +from .context_utils import call_event_hook, call_handler @dataclass @@ -15,4 +15,3 @@ class PipelineContext: astrbot_config_id: str call_handler = call_handler call_event_hook = call_event_hook - call_local_llm_tool = call_local_llm_tool diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index 371816b6e..73d28c5d1 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -3,8 +3,6 @@ import traceback import typing as T from astrbot import logger -from astrbot.core.agent.run_context import ContextWrapper -from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.message.message_event_result import CommandResult, MessageEventResult from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.star import star_map @@ -107,66 +105,3 @@ async def call_event_hook( return True return event.is_stopped() - - -async def call_local_llm_tool( - context: ContextWrapper[AstrAgentContext], - handler: T.Callable[..., T.Awaitable[T.Any]], - method_name: str, - *args, - **kwargs, -) -> T.AsyncGenerator[T.Any, None]: - """执行本地 LLM 工具的处理函数并处理其返回结果""" - ready_to_call = None # 一个协程或者异步生成器 - - trace_ = None - - event = context.context.event - - try: - if method_name == "run" or method_name == "decorator_handler": - ready_to_call = handler(event, *args, **kwargs) - elif method_name == "call": - ready_to_call = handler(context, *args, **kwargs) - else: - raise ValueError(f"未知的方法名: {method_name}") - except ValueError as e: - logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True) - except TypeError: - logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True) - except Exception as e: - trace_ = traceback.format_exc() - logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}") - - if not ready_to_call: - return - - if inspect.isasyncgen(ready_to_call): - _has_yielded = False - try: - async for ret in ready_to_call: - # 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码 - # 返回值只能是 MessageEventResult 或者 None(无返回值) - _has_yielded = True - if isinstance(ret, (MessageEventResult, CommandResult)): - # 如果返回值是 MessageEventResult, 设置结果并继续 - event.set_result(ret) - yield - else: - # 如果返回值是 None, 则不设置结果并继续 - # 继续执行后续阶段 - yield ret - if not _has_yielded: - # 如果这个异步生成器没有执行到 yield 分支 - yield - except Exception as e: - logger.error(f"Previous Error: {trace_}") - raise e - elif inspect.iscoroutine(ready_to_call): - # 如果只是一个协程, 直接执行 - ret = await ready_to_call - if isinstance(ret, (MessageEventResult, CommandResult)): - event.set_result(ret) - yield - else: - yield ret diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 69bf31a55..eef9d69e2 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -3,20 +3,10 @@ import asyncio import copy import json -import traceback from collections.abc import AsyncGenerator -from typing import Any - -from mcp.types import CallToolResult from astrbot.core import logger -from astrbot.core.agent.handoff import HandoffTool -from astrbot.core.agent.hooks import BaseAgentRunHooks -from astrbot.core.agent.mcp_client import MCPTool -from astrbot.core.agent.run_context import ContextWrapper -from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner -from astrbot.core.agent.tool import FunctionTool, ToolSet -from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor +from astrbot.core.agent.tool import ToolSet from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.conversation_mgr import Conversation from astrbot.core.message.components import Image @@ -31,328 +21,19 @@ from astrbot.core.provider.entities import ( LLMResponse, ProviderRequest, ) -from astrbot.core.provider.register import llm_tools from astrbot.core.star.session_llm_manager import SessionServiceManager from astrbot.core.star.star_handler import EventType, star_map from astrbot.core.utils.metrics import Metric from astrbot.core.utils.session_lock import session_lock_manager -from ...context import PipelineContext, call_event_hook, call_local_llm_tool +from ....astr_agent_context import AgentContextWrapper +from ....astr_agent_hooks import MAIN_AGENT_HOOKS +from ....astr_agent_run_util import AgentRunner, run_agent +from ....astr_agent_tool_exec import FunctionToolExecutor +from ...context import PipelineContext, call_event_hook from ..stage import Stage from ..utils import inject_kb_context -try: - import mcp -except (ModuleNotFoundError, ImportError): - logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。") - - -AgentContextWrapper = ContextWrapper[AstrAgentContext] -AgentRunner = ToolLoopAgentRunner[AstrAgentContext] - - -class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): - @classmethod - async def execute(cls, tool, run_context, **tool_args): - """执行函数调用。 - - Args: - event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。 - **kwargs: 函数调用的参数。 - - Returns: - AsyncGenerator[None | mcp.types.CallToolResult, None] - - """ - if isinstance(tool, HandoffTool): - async for r in cls._execute_handoff(tool, run_context, **tool_args): - yield r - return - - elif isinstance(tool, MCPTool): - async for r in cls._execute_mcp(tool, run_context, **tool_args): - yield r - return - - else: - async for r in cls._execute_local(tool, run_context, **tool_args): - yield r - return - - @classmethod - async def _execute_handoff( - cls, - tool: HandoffTool, - run_context: ContextWrapper[AstrAgentContext], - **tool_args, - ): - input_ = tool_args.get("input", "agent") - agent_runner = AgentRunner() - - # make toolset for the agent - tools = tool.agent.tools - if tools: - toolset = ToolSet() - for t in tools: - if isinstance(t, str): - _t = llm_tools.get_func(t) - if _t: - toolset.add_tool(_t) - elif isinstance(t, FunctionTool): - toolset.add_tool(t) - else: - toolset = None - - request = ProviderRequest( - prompt=input_, - system_prompt=tool.description or "", - image_urls=[], # 暂时不传递原始 agent 的上下文 - contexts=[], # 暂时不传递原始 agent 的上下文 - func_tool=toolset, - ) - astr_agent_ctx = AstrAgentContext( - provider=run_context.context.provider, - first_provider_request=run_context.context.first_provider_request, - curr_provider_request=request, - streaming=run_context.context.streaming, - event=run_context.context.event, - ) - - event = run_context.context.event - - logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}") - await event.send( - MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name), - ) - - await agent_runner.reset( - provider=run_context.context.provider, - request=request, - run_context=AgentContextWrapper( - context=astr_agent_ctx, - tool_call_timeout=run_context.tool_call_timeout, - ), - tool_executor=FunctionToolExecutor(), - agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](), - streaming=run_context.context.streaming, - ) - - async for _ in run_agent(agent_runner, 15, True): - pass - - if agent_runner.done(): - llm_response = agent_runner.get_final_llm_resp() - - if not llm_response: - text_content = mcp.types.TextContent( - type="text", - text=f"error when deligate task to {tool.agent.name}", - ) - yield mcp.types.CallToolResult(content=[text_content]) - return - - logger.debug( - f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}", - ) - - result = ( - f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n" - "Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)." - ) - - text_content = mcp.types.TextContent( - type="text", - text=result, - ) - yield mcp.types.CallToolResult(content=[text_content]) - else: - text_content = mcp.types.TextContent( - type="text", - text=f"error when deligate task to {tool.agent.name}", - ) - yield mcp.types.CallToolResult(content=[text_content]) - return - - @classmethod - async def _execute_local( - cls, - tool: FunctionTool, - run_context: ContextWrapper[AstrAgentContext], - **tool_args, - ): - event = run_context.context.event - if not event: - raise ValueError("Event must be provided for local function tools.") - - is_override_call = False - for ty in type(tool).mro(): - if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call: - is_override_call = True - break - - # 检查 tool 下有没有 run 方法 - if not tool.handler and not hasattr(tool, "run") and not is_override_call: - raise ValueError("Tool must have a valid handler or override 'run' method.") - - awaitable = None - method_name = "" - if tool.handler: - awaitable = tool.handler - method_name = "decorator_handler" - elif is_override_call: - awaitable = tool.call - method_name = "call" - elif hasattr(tool, "run"): - awaitable = getattr(tool, "run") - method_name = "run" - if awaitable is None: - raise ValueError("Tool must have a valid handler or override 'run' method.") - - wrapper = call_local_llm_tool( - context=run_context, - handler=awaitable, - method_name=method_name, - **tool_args, - ) - while True: - try: - resp = await asyncio.wait_for( - anext(wrapper), - timeout=run_context.tool_call_timeout, - ) - if resp is not None: - if isinstance(resp, mcp.types.CallToolResult): - yield resp - else: - text_content = mcp.types.TextContent( - type="text", - text=str(resp), - ) - yield mcp.types.CallToolResult(content=[text_content]) - else: - # NOTE: Tool 在这里直接请求发送消息给用户 - # TODO: 是否需要判断 event.get_result() 是否为空? - # 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容" - if res := run_context.context.event.get_result(): - if res.chain: - try: - await event.send( - MessageChain( - chain=res.chain, - type="tool_direct_result", - ) - ) - except Exception as e: - logger.error( - f"Tool 直接发送消息失败: {e}", - exc_info=True, - ) - yield None - except asyncio.TimeoutError: - raise Exception( - f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.", - ) - except StopAsyncIteration: - break - - @classmethod - async def _execute_mcp( - cls, - tool: FunctionTool, - run_context: ContextWrapper[AstrAgentContext], - **tool_args, - ): - res = await tool.call(run_context, **tool_args) - if not res: - return - yield res - - -class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]): - async def on_agent_done(self, run_context, llm_response): - # 执行事件钩子 - await call_event_hook( - run_context.context.event, - EventType.OnLLMResponseEvent, - llm_response, - ) - - async def on_tool_end( - self, - run_context: ContextWrapper[AstrAgentContext], - tool: FunctionTool[Any], - tool_args: dict | None, - tool_result: CallToolResult | None, - ): - run_context.context.event.clear_result() - - -MAIN_AGENT_HOOKS = MainAgentHooks() - - -async def run_agent( - agent_runner: AgentRunner, - max_step: int = 30, - show_tool_use: bool = True, - stream_to_general: bool = False, -) -> AsyncGenerator[MessageChain, None]: - step_idx = 0 - astr_event = agent_runner.run_context.context.event - while step_idx < max_step: - step_idx += 1 - try: - async for resp in agent_runner.step(): - if astr_event.is_stopped(): - return - if resp.type == "tool_call_result": - msg_chain = resp.data["chain"] - if msg_chain.type == "tool_direct_result": - # tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容 - resp.data["chain"].type = "tool_call_result" - await astr_event.send(resp.data["chain"]) - continue - # 对于其他情况,暂时先不处理 - continue - elif resp.type == "tool_call": - if agent_runner.streaming: - # 用来标记流式响应需要分节 - yield MessageChain(chain=[], type="break") - if show_tool_use or astr_event.get_platform_name() == "webchat": - resp.data["chain"].type = "tool_call" - await astr_event.send(resp.data["chain"]) - continue - - if stream_to_general and resp.type == "streaming_delta": - continue - - if stream_to_general or not agent_runner.streaming: - content_typ = ( - ResultContentType.LLM_RESULT - if resp.type == "llm_result" - else ResultContentType.GENERAL_RESULT - ) - astr_event.set_result( - MessageEventResult( - chain=resp.data["chain"].chain, - result_content_type=content_typ, - ), - ) - yield - astr_event.clear_result() - elif resp.type == "streaming_delta": - yield resp.data["chain"] # MessageChain - if agent_runner.done(): - break - - except Exception as e: - logger.error(traceback.format_exc()) - err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n" - if agent_runner.streaming: - yield MessageChain().message(err_msg) - else: - astr_event.set_result(MessageEventResult().message(err_msg)) - return - class LLMRequestSubStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: @@ -569,6 +250,9 @@ class LLMRequestSubStage(Stage): logger.debug("LLM 响应为空,不保存记录。") return + if req.contexts is None: + req.contexts = [] + # 历史上下文 messages = copy.deepcopy(req.contexts) # 这一轮对话请求的用户输入 @@ -644,7 +328,9 @@ class LLMRequestSubStage(Stage): req.contexts = json.loads(req.conversation.history) else: - req = ProviderRequest(prompt="", image_urls=[]) + req = ProviderRequest() + req.prompt = "" + req.image_urls = [] if sel_model := event.get_extra("selected_model"): req.model = sel_model if self.provider_wake_prefix and not event.message_str.startswith( @@ -681,15 +367,14 @@ class LLMRequestSubStage(Stage): req.contexts = json.loads(req.contexts) # truncate contexts to fit max length - req.contexts = self._truncate_contexts(req.contexts) + if req.contexts: + req.contexts = self._truncate_contexts(req.contexts) + self._fix_messages(req.contexts) # session_id if not req.session_id: req.session_id = event.unified_msg_origin - # fix messages - req.contexts = self._fix_messages(req.contexts) - # check provider modalities, if provider does not support image/tool_use, clear them in request. self._modalities_fix(provider, req) @@ -710,9 +395,6 @@ class LLMRequestSubStage(Stage): ) astr_agent_ctx = AstrAgentContext( provider=provider, - first_provider_request=req, - curr_provider_request=req, - streaming=streaming_response, event=event, ) await agent_runner.reset( diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 6852f9cd6..e26f3ea50 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -66,9 +66,9 @@ class ToolCallsResult: @dataclass class ProviderRequest: - prompt: str + prompt: str | None = None """提示词""" - session_id: str = "" + session_id: str | None = "" """会话 ID""" image_urls: list[str] = field(default_factory=list) """图片 URL 列表""" diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 1a5bc53d9..7e238b8a0 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -5,6 +5,12 @@ from typing import Any from deprecated import deprecated +from astrbot.core.agent.hooks import BaseAgentRunHooks +from astrbot.core.agent.message import Message +from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner +from astrbot.core.agent.tool import ToolSet +from astrbot.core.astr_agent_context import AgentContextWrapper, AstrAgentContext +from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.conversation_mgr import ConversationManager @@ -13,10 +19,10 @@ from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.message.message_event_result import MessageChain from astrbot.core.persona_mgr import PersonaManager from astrbot.core.platform import Platform -from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion from astrbot.core.platform.manager import PlatformManager from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager -from astrbot.core.provider.entities import ProviderType +from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager from astrbot.core.provider.manager import ProviderManager from astrbot.core.provider.provider import ( @@ -31,6 +37,7 @@ from astrbot.core.star.filter.platform_adapter_type import ( PlatformAdapterType, ) +from ..exceptions import ProviderNotFoundError from .filter.command import CommandFilter from .filter.regex import RegexFilter from .star import StarMetadata, star_map, star_registry @@ -75,6 +82,139 @@ class Context: self.astrbot_config_mgr = astrbot_config_mgr self.kb_manager = knowledge_base_manager + async def llm_generate( + self, + *, + chat_provider_id: str, + prompt: str | None = None, + image_urls: list[str] | None = None, + tools: ToolSet | None = None, + system_prompt: str | None = None, + contexts: list[Message] | list[dict] | None = None, + **kwargs: Any, + ) -> LLMResponse: + """Call the LLM to generate a response. The method will not automatically execute tool calls. If you want to use tool calls, please use `tool_loop_agent()`. + + .. versionadded:: 4.5.7 (sdk) + + Args: + chat_provider_id: The chat provider ID to use. + prompt: The prompt to send to the LLM, if `contexts` and `prompt` are both provided, `prompt` will be appended as the last user message + image_urls: List of image URLs to include in the prompt, if `contexts` and `prompt` are both provided, `image_urls` will be appended to the last user message + tools: ToolSet of tools available to the LLM + system_prompt: System prompt to guide the LLM's behavior, if provided, it will always insert as the first system message in the context + contexts: context messages for the LLM + **kwargs: Additional keyword arguments for LLM generation, OpenAI compatible + + Raises: + ChatProviderNotFoundError: If the specified chat provider ID is not found + Exception: For other errors during LLM generation + """ + prov = await self.provider_manager.get_provider_by_id(chat_provider_id) + if not prov or not isinstance(prov, Provider): + raise ProviderNotFoundError(f"Provider {chat_provider_id} not found") + llm_resp = await prov.text_chat( + prompt=prompt, + image_urls=image_urls, + func_tool=tools, + contexts=contexts, + system_prompt=system_prompt, + **kwargs, + ) + return llm_resp + + async def tool_loop_agent( + self, + *, + event: AstrMessageEvent, + chat_provider_id: str, + prompt: str | None = None, + image_urls: list[str] | None = None, + tools: ToolSet | None = None, + system_prompt: str | None = None, + contexts: list[Message] | list[dict] | None = None, + max_steps: int = 30, + tool_call_timeout: int = 60, + **kwargs: Any, + ) -> LLMResponse: + """Run an agent loop that allows the LLM to call tools iteratively until a final answer is produced. + + Args: + chat_provider_id: The chat provider ID to use. + prompt: The prompt to send to the LLM, if `contexts` and `prompt` are both provided, `prompt` will be appended as the last user message + image_urls: List of image URLs to include in the prompt, if `contexts` and `prompt` are both provided, `image_urls` will be appended to the last user message + tools: ToolSet of tools available to the LLM + system_prompt: System prompt to guide the LLM's behavior, if provided, it will always insert as the first system message in the context + contexts: context messages for the LLM + max_steps: Maximum number of tool calls before stopping the loop + **kwargs: Additional keyword arguments for LLM generation, OpenAI compatible + + Returns: + The final LLMResponse after tool calls are completed. + + Raises: + ChatProviderNotFoundError: If the specified chat provider ID is not found + Exception: For other errors during LLM generation + """ + prov = await self.provider_manager.get_provider_by_id(chat_provider_id) + if not prov or not isinstance(prov, Provider): + raise ProviderNotFoundError(f"Provider {chat_provider_id} not found") + + context_ = [] + for msg in contexts or []: + if isinstance(msg, Message): + context_.append(msg.model_dump()) + else: + context_.append(msg) + + request = ProviderRequest( + prompt=prompt, + image_urls=image_urls, + func_tool=tools, + contexts=context_, + system_prompt=system_prompt, + ) + astr_agent_ctx = AstrAgentContext( + provider=prov, + event=event, + ) + agent_runner = ToolLoopAgentRunner() + tool_executor = FunctionToolExecutor() + await agent_runner.reset( + provider=prov, + request=request, + run_context=AgentContextWrapper( + context=astr_agent_ctx, + tool_call_timeout=tool_call_timeout, + ), + tool_executor=tool_executor, + agent_hooks=kwargs.get( + "agent_hooks", BaseAgentRunHooks[AstrAgentContext]() + ), + streaming=kwargs.get("stream", False), + ) + async for _ in agent_runner.step_until_done(max_steps): + pass + llm_resp = agent_runner.get_final_llm_resp() + if not llm_resp: + raise Exception("Agent did not produce a final LLM response") + return llm_resp + + async def get_current_chat_provider_id(self, umo: str) -> str: + """Get the ID of the currently used chat provider. + + Args: + umo(str): unified_message_origin value, if provided and user has enabled provider session isolation, the provider preferred by that session will be used. + + Raises: + ProviderNotFoundError: If the specified chat provider is not found + + """ + prov = self.get_using_provider(umo) + if not prov: + raise ProviderNotFoundError("Provider not found") + return prov.meta().id + def get_registered_star(self, star_name: str) -> StarMetadata | None: """根据插件名获取插件的 Metadata""" for star in star_registry: