From 17422ba9c3e9722002409349a84f4d512b8372e2 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 15 Nov 2025 21:15:20 +0800 Subject: [PATCH] feat: introduce messages field in agent RunContext --- astrbot/core/agent/run_context.py | 9 ++- .../agent/runners/tool_loop_agent_runner.py | 28 ++++++- astrbot/core/agent/tool.py | 5 +- astrbot/core/astr_agent_context.py | 13 ++- astrbot/core/astr_agent_tool_exec.py | 81 +++---------------- .../process_stage/method/llm_request.py | 6 +- astrbot/core/provider/entities.py | 8 ++ astrbot/core/star/context.py | 38 ++++++--- 8 files changed, 95 insertions(+), 93 deletions(-) diff --git a/astrbot/core/agent/run_context.py b/astrbot/core/agent/run_context.py index 395817679..07e435895 100644 --- a/astrbot/core/agent/run_context.py +++ b/astrbot/core/agent/run_context.py @@ -1,16 +1,21 @@ -from dataclasses import dataclass from typing import Any, Generic +from pydantic import Field +from pydantic.dataclasses import dataclass from typing_extensions import TypeVar +from .message import Message + TContext = TypeVar("TContext", default=Any) -@dataclass +@dataclass(config={"arbitrary_types_allowed": True}) class ContextWrapper(Generic[TContext]): """A context for running an agent, which can be used to pass additional data or state.""" context: TContext + messages: list[Message] = Field(default_factory=list) + """This field stores the llm message context for the agent run, agent runners will maintain this field automatically.""" tool_call_timeout: int = 60 # Default tool call timeout in seconds diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index f6e613679..744030bbc 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -23,7 +23,7 @@ from astrbot.core.provider.entities import ( from astrbot.core.provider.provider import Provider from ..hooks import BaseAgentRunHooks -from ..message import AssistantMessageSegment, ToolCallMessageSegment +from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment from ..response import AgentResponseData from ..run_context import ContextWrapper, TContext from ..tool_executor import BaseFunctionToolExecutor @@ -55,6 +55,20 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): self.agent_hooks = agent_hooks self.run_context = run_context + messages = [] + # append existing messages in the run context + for msg in request.contexts: + messages.append(Message.model_validate(msg)) + if request.prompt is not None: + m = await request.assemble_context() + messages.append(Message.model_validate(m)) + if request.system_prompt: + messages.insert( + 0, + Message(role="system", content=request.system_prompt), + ) + self.run_context.messages = messages + def _transition_state(self, new_state: AgentState) -> None: """转换 Agent 状态""" if self._state != new_state: @@ -130,6 +144,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): # 如果没有工具调用,转换到完成状态 self.final_llm_resp = llm_resp self._transition_state(AgentState.DONE) + # record the final assistant message + self.run_context.messages.append( + Message( + role="assistant", + content=llm_resp.completion_text or "", + ), + ) try: await self.agent_hooks.on_agent_done(self.run_context, llm_resp) except Exception as e: @@ -175,6 +196,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): ), tool_calls_result=tool_call_result_blocks, ) + # record the assistant message with tool calls + self.run_context.messages.extend( + tool_calls_result.to_openai_messages_model() + ) + self.req.append_tool_calls_result(tool_calls_result) async def step_until_done( diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index ae240d2e0..45226991c 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -10,6 +10,7 @@ from pydantic.dataclasses import dataclass from .run_context import ContextWrapper, TContext ParametersType = dict[str, Any] +ToolExecResult = str | mcp.types.CallToolResult @dataclass @@ -55,9 +56,7 @@ class FunctionTool(ToolSchema, Generic[TContext]): def __repr__(self): return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})" - async def call( - self, context: ContextWrapper[TContext], **kwargs - ) -> str | mcp.types.CallToolResult: + async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult: """Run the tool with the given arguments. The handler field has priority.""" raise NotImplementedError( "FunctionTool.call() must be implemented by subclasses or set a handler." diff --git a/astrbot/core/astr_agent_context.py b/astrbot/core/astr_agent_context.py index ffe8a199b..5eed5de8f 100644 --- a/astrbot/core/astr_agent_context.py +++ b/astrbot/core/astr_agent_context.py @@ -1,14 +1,19 @@ -from dataclasses import dataclass +from pydantic import Field +from pydantic.dataclasses import dataclass from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.provider import Provider +from astrbot.core.star.context import Context -@dataclass +@dataclass(config={"arbitrary_types_allowed": True}) class AstrAgentContext: - provider: Provider + context: Context + """The star context instance""" event: AstrMessageEvent + """The message event associated with the agent context.""" + extra: dict[str, str] = Field(default_factory=dict) + """Customized extra data.""" AgentContextWrapper = ContextWrapper[AstrAgentContext] diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index f7425b0b5..25a6a06e5 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -7,7 +7,6 @@ import mcp from astrbot import logger from astrbot.core.agent.handoff import HandoffTool -from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.agent.mcp_client import MCPTool from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.tool import FunctionTool, ToolSet @@ -18,12 +17,8 @@ from astrbot.core.message.message_event_result import ( MessageChain, MessageEventResult, ) -from astrbot.core.provider.entities import ProviderRequest from astrbot.core.provider.register import llm_tools -from .astr_agent_context import AgentContextWrapper -from .astr_agent_run_util import AgentRunner, run_agent - class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): @classmethod @@ -60,8 +55,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): run_context: ContextWrapper[AstrAgentContext], **tool_args, ): - input_ = tool_args.get("input", "agent") - agent_runner = AgentRunner() + input_ = tool_args.get("input") # make toolset for the agent tools = tool.agent.tools @@ -77,72 +71,21 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): else: toolset = None - request = ProviderRequest( - prompt=input_, - system_prompt=tool.description or "", - image_urls=[], # 暂时不传递原始 agent 的上下文 - contexts=[], # 暂时不传递原始 agent 的上下文 - func_tool=toolset, - ) - astr_agent_ctx = AstrAgentContext( - provider=run_context.context.provider, - event=run_context.context.event, - ) - + ctx = run_context.context.context event = run_context.context.event - - logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}") - await event.send( - MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name), + umo = event.unified_msg_origin + prov_id = await ctx.get_current_chat_provider_id(umo) + llm_resp = await ctx.tool_loop_agent( + event=event, + chat_provider_id=prov_id, + prompt=input_, + tools=toolset, + max_steps=30, ) - - await agent_runner.reset( - provider=run_context.context.provider, - request=request, - run_context=AgentContextWrapper( - context=astr_agent_ctx, - tool_call_timeout=run_context.tool_call_timeout, - ), - tool_executor=FunctionToolExecutor(), - agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](), + yield mcp.types.CallToolResult( + content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)] ) - async for _ in run_agent(agent_runner, 15, True): - pass - - if agent_runner.done(): - llm_response = agent_runner.get_final_llm_resp() - - if not llm_response: - text_content = mcp.types.TextContent( - type="text", - text=f"error when deligate task to {tool.agent.name}", - ) - yield mcp.types.CallToolResult(content=[text_content]) - return - - logger.debug( - f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}", - ) - - result = ( - f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n" - "Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)." - ) - - text_content = mcp.types.TextContent( - type="text", - text=result, - ) - yield mcp.types.CallToolResult(content=[text_content]) - else: - text_content = mcp.types.TextContent( - type="text", - text=f"error when deligate task to {tool.agent.name}", - ) - yield mcp.types.CallToolResult(content=[text_content]) - return - @classmethod async def _execute_local( cls, diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index eef9d69e2..76481ce56 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -6,6 +6,7 @@ import json from collections.abc import AsyncGenerator from astrbot.core import logger +from astrbot.core.agent.message import Message from astrbot.core.agent.tool import ToolSet from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.conversation_mgr import Conversation @@ -393,8 +394,11 @@ class LLMRequestSubStage(Stage): logger.debug( f"handle provider[id: {provider.provider_config['id']}] request: {req}", ) + context_model: list[Message] = [] + for msg in req.contexts: + context_model.append(Message.model_validate(msg)) astr_agent_ctx = AstrAgentContext( - provider=provider, + context=self.ctx.plugin_manager.context, event=event, ) await agent_runner.reset( diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index e26f3ea50..0a0b8d405 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -63,6 +63,14 @@ class ToolCallsResult: ] return ret + def to_openai_messages_model( + self, + ) -> list[AssistantMessageSegment | ToolCallMessageSegment]: + return [ + self.tool_calls_info, + *self.tool_calls_result, + ] + @dataclass class ProviderRequest: diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 638dd435b..5918b2029 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -9,8 +9,6 @@ from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.agent.message import Message from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner from astrbot.core.agent.tool import ToolSet -from astrbot.core.astr_agent_context import AgentContextWrapper, AstrAgentContext -from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.conversation_mgr import ConversationManager @@ -90,7 +88,7 @@ class Context: image_urls: list[str] | None = None, tools: ToolSet | None = None, system_prompt: str | None = None, - contexts: list[Message] | list[dict] | None = None, + contexts: list[Message] | None = None, **kwargs: Any, ) -> LLMResponse: """Call the LLM to generate a response. The method will not automatically execute tool calls. If you want to use tool calls, please use `tool_loop_agent()`. @@ -132,12 +130,15 @@ class Context: image_urls: list[str] | None = None, tools: ToolSet | None = None, system_prompt: str | None = None, - contexts: list[Message] | list[dict] | None = None, + contexts: list[Message] | None = None, max_steps: int = 30, tool_call_timeout: int = 60, **kwargs: Any, ) -> LLMResponse: """Run an agent loop that allows the LLM to call tools iteratively until a final answer is produced. + If you do not pass the agent_context parameter, the method will recreate a new agent context. + + .. versionadded:: 4.5.7 (sdk) Args: chat_provider_id: The chat provider ID to use. @@ -147,7 +148,9 @@ class Context: system_prompt: System prompt to guide the LLM's behavior, if provided, it will always insert as the first system message in the context contexts: context messages for the LLM max_steps: Maximum number of tool calls before stopping the loop - **kwargs: Additional keyword arguments for LLM generation, OpenAI compatible + **kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include: + agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution + agent_context: AstrAgentContext - context to use for the agent Returns: The final LLMResponse after tool calls are completed. @@ -156,10 +159,20 @@ class Context: ChatProviderNotFoundError: If the specified chat provider ID is not found Exception: For other errors during LLM generation """ + # Import here to avoid circular imports + from astrbot.core.astr_agent_context import ( + AgentContextWrapper, + AstrAgentContext, + ) + from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor + prov = await self.provider_manager.get_provider_by_id(chat_provider_id) if not prov or not isinstance(prov, Provider): raise ProviderNotFoundError(f"Provider {chat_provider_id} not found") + agent_hooks = kwargs.get("agent_hooks") or BaseAgentRunHooks[AstrAgentContext]() + agent_context = kwargs.get("agent_context") + context_ = [] for msg in contexts or []: if isinstance(msg, Message): @@ -174,23 +187,22 @@ class Context: contexts=context_, system_prompt=system_prompt or "", ) - astr_agent_ctx = AstrAgentContext( - provider=prov, - event=event, - ) + if agent_context is None: + agent_context = AstrAgentContext( + context=self, + event=event, + ) agent_runner = ToolLoopAgentRunner() tool_executor = FunctionToolExecutor() await agent_runner.reset( provider=prov, request=request, run_context=AgentContextWrapper( - context=astr_agent_ctx, + context=agent_context, tool_call_timeout=tool_call_timeout, ), tool_executor=tool_executor, - agent_hooks=kwargs.get( - "agent_hooks", BaseAgentRunHooks[AstrAgentContext]() - ), + agent_hooks=agent_hooks, streaming=kwargs.get("stream", False), ) async for _ in agent_runner.step_until_done(max_steps):