feat: introduce messages field in agent RunContext

This commit is contained in:
Soulter
2025-11-15 21:15:20 +08:00
parent 89e79863f6
commit 17422ba9c3
8 changed files with 95 additions and 93 deletions
+7 -2
View File
@@ -1,16 +1,21 @@
from dataclasses import dataclass
from typing import Any, Generic
from pydantic import Field
from pydantic.dataclasses import dataclass
from typing_extensions import TypeVar
from .message import Message
TContext = TypeVar("TContext", default=Any)
@dataclass
@dataclass(config={"arbitrary_types_allowed": True})
class ContextWrapper(Generic[TContext]):
"""A context for running an agent, which can be used to pass additional data or state."""
context: TContext
messages: list[Message] = Field(default_factory=list)
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
tool_call_timeout: int = 60 # Default tool call timeout in seconds
@@ -23,7 +23,7 @@ from astrbot.core.provider.entities import (
from astrbot.core.provider.provider import Provider
from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, ToolCallMessageSegment
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..response import AgentResponseData
from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor
@@ -55,6 +55,20 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.agent_hooks = agent_hooks
self.run_context = run_context
messages = []
# append existing messages in the run context
for msg in request.contexts:
messages.append(Message.model_validate(msg))
if request.prompt is not None:
m = await request.assemble_context()
messages.append(Message.model_validate(m))
if request.system_prompt:
messages.insert(
0,
Message(role="system", content=request.system_prompt),
)
self.run_context.messages = messages
def _transition_state(self, new_state: AgentState) -> None:
"""转换 Agent 状态"""
if self._state != new_state:
@@ -130,6 +144,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 如果没有工具调用,转换到完成状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.DONE)
# record the final assistant message
self.run_context.messages.append(
Message(
role="assistant",
content=llm_resp.completion_text or "",
),
)
try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
except Exception as e:
@@ -175,6 +196,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
),
tool_calls_result=tool_call_result_blocks,
)
# record the assistant message with tool calls
self.run_context.messages.extend(
tool_calls_result.to_openai_messages_model()
)
self.req.append_tool_calls_result(tool_calls_result)
async def step_until_done(
+2 -3
View File
@@ -10,6 +10,7 @@ from pydantic.dataclasses import dataclass
from .run_context import ContextWrapper, TContext
ParametersType = dict[str, Any]
ToolExecResult = str | mcp.types.CallToolResult
@dataclass
@@ -55,9 +56,7 @@ class FunctionTool(ToolSchema, Generic[TContext]):
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
async def call(
self, context: ContextWrapper[TContext], **kwargs
) -> str | mcp.types.CallToolResult:
async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult:
"""Run the tool with the given arguments. The handler field has priority."""
raise NotImplementedError(
"FunctionTool.call() must be implemented by subclasses or set a handler."
+9 -4
View File
@@ -1,14 +1,19 @@
from dataclasses import dataclass
from pydantic import Field
from pydantic.dataclasses import dataclass
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider import Provider
from astrbot.core.star.context import Context
@dataclass
@dataclass(config={"arbitrary_types_allowed": True})
class AstrAgentContext:
provider: Provider
context: Context
"""The star context instance"""
event: AstrMessageEvent
"""The message event associated with the agent context."""
extra: dict[str, str] = Field(default_factory=dict)
"""Customized extra data."""
AgentContextWrapper = ContextWrapper[AstrAgentContext]
+12 -69
View File
@@ -7,7 +7,6 @@ import mcp
from astrbot import logger
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import FunctionTool, ToolSet
@@ -18,12 +17,8 @@ from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
)
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.provider.register import llm_tools
from .astr_agent_context import AgentContextWrapper
from .astr_agent_run_util import AgentRunner, run_agent
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
@@ -60,8 +55,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
input_ = tool_args.get("input", "agent")
agent_runner = AgentRunner()
input_ = tool_args.get("input")
# make toolset for the agent
tools = tool.agent.tools
@@ -77,72 +71,21 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
else:
toolset = None
request = ProviderRequest(
prompt=input_,
system_prompt=tool.description or "",
image_urls=[], # 暂时不传递原始 agent 的上下文
contexts=[], # 暂时不传递原始 agent 的上下文
func_tool=toolset,
)
astr_agent_ctx = AstrAgentContext(
provider=run_context.context.provider,
event=run_context.context.event,
)
ctx = run_context.context.context
event = run_context.context.event
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
await event.send(
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name),
umo = event.unified_msg_origin
prov_id = await ctx.get_current_chat_provider_id(umo)
llm_resp = await ctx.tool_loop_agent(
event=event,
chat_provider_id=prov_id,
prompt=input_,
tools=toolset,
max_steps=30,
)
await agent_runner.reset(
provider=run_context.context.provider,
request=request,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=run_context.tool_call_timeout,
),
tool_executor=FunctionToolExecutor(),
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
yield mcp.types.CallToolResult(
content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)]
)
async for _ in run_agent(agent_runner, 15, True):
pass
if agent_runner.done():
llm_response = agent_runner.get_final_llm_resp()
if not llm_response:
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
logger.debug(
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}",
)
result = (
f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n"
"Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)."
)
text_content = mcp.types.TextContent(
type="text",
text=result,
)
yield mcp.types.CallToolResult(content=[text_content])
else:
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
@classmethod
async def _execute_local(
cls,
@@ -6,6 +6,7 @@ import json
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.conversation_mgr import Conversation
@@ -393,8 +394,11 @@ class LLMRequestSubStage(Stage):
logger.debug(
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
)
context_model: list[Message] = []
for msg in req.contexts:
context_model.append(Message.model_validate(msg))
astr_agent_ctx = AstrAgentContext(
provider=provider,
context=self.ctx.plugin_manager.context,
event=event,
)
await agent_runner.reset(
+8
View File
@@ -63,6 +63,14 @@ class ToolCallsResult:
]
return ret
def to_openai_messages_model(
self,
) -> list[AssistantMessageSegment | ToolCallMessageSegment]:
return [
self.tool_calls_info,
*self.tool_calls_result,
]
@dataclass
class ProviderRequest:
+25 -13
View File
@@ -9,8 +9,6 @@ from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.message import Message
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.tool import ToolSet
from astrbot.core.astr_agent_context import AgentContextWrapper, AstrAgentContext
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.conversation_mgr import ConversationManager
@@ -90,7 +88,7 @@ class Context:
image_urls: list[str] | None = None,
tools: ToolSet | None = None,
system_prompt: str | None = None,
contexts: list[Message] | list[dict] | None = None,
contexts: list[Message] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""Call the LLM to generate a response. The method will not automatically execute tool calls. If you want to use tool calls, please use `tool_loop_agent()`.
@@ -132,12 +130,15 @@ class Context:
image_urls: list[str] | None = None,
tools: ToolSet | None = None,
system_prompt: str | None = None,
contexts: list[Message] | list[dict] | None = None,
contexts: list[Message] | None = None,
max_steps: int = 30,
tool_call_timeout: int = 60,
**kwargs: Any,
) -> LLMResponse:
"""Run an agent loop that allows the LLM to call tools iteratively until a final answer is produced.
If you do not pass the agent_context parameter, the method will recreate a new agent context.
.. versionadded:: 4.5.7 (sdk)
Args:
chat_provider_id: The chat provider ID to use.
@@ -147,7 +148,9 @@ class Context:
system_prompt: System prompt to guide the LLM's behavior, if provided, it will always insert as the first system message in the context
contexts: context messages for the LLM
max_steps: Maximum number of tool calls before stopping the loop
**kwargs: Additional keyword arguments for LLM generation, OpenAI compatible
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
agent_context: AstrAgentContext - context to use for the agent
Returns:
The final LLMResponse after tool calls are completed.
@@ -156,10 +159,20 @@ class Context:
ChatProviderNotFoundError: If the specified chat provider ID is not found
Exception: For other errors during LLM generation
"""
# Import here to avoid circular imports
from astrbot.core.astr_agent_context import (
AgentContextWrapper,
AstrAgentContext,
)
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
prov = await self.provider_manager.get_provider_by_id(chat_provider_id)
if not prov or not isinstance(prov, Provider):
raise ProviderNotFoundError(f"Provider {chat_provider_id} not found")
agent_hooks = kwargs.get("agent_hooks") or BaseAgentRunHooks[AstrAgentContext]()
agent_context = kwargs.get("agent_context")
context_ = []
for msg in contexts or []:
if isinstance(msg, Message):
@@ -174,23 +187,22 @@ class Context:
contexts=context_,
system_prompt=system_prompt or "",
)
astr_agent_ctx = AstrAgentContext(
provider=prov,
event=event,
)
if agent_context is None:
agent_context = AstrAgentContext(
context=self,
event=event,
)
agent_runner = ToolLoopAgentRunner()
tool_executor = FunctionToolExecutor()
await agent_runner.reset(
provider=prov,
request=request,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
context=agent_context,
tool_call_timeout=tool_call_timeout,
),
tool_executor=tool_executor,
agent_hooks=kwargs.get(
"agent_hooks", BaseAgentRunHooks[AstrAgentContext]()
),
agent_hooks=agent_hooks,
streaming=kwargs.get("stream", False),
)
async for _ in agent_runner.step_until_done(max_steps):