feat: introduce messages field in agent RunContext
This commit is contained in:
@@ -1,16 +1,21 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from .message import Message
|
||||
|
||||
TContext = TypeVar("TContext", default=Any)
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(config={"arbitrary_types_allowed": True})
|
||||
class ContextWrapper(Generic[TContext]):
|
||||
"""A context for running an agent, which can be used to pass additional data or state."""
|
||||
|
||||
context: TContext
|
||||
messages: list[Message] = Field(default_factory=list)
|
||||
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
|
||||
tool_call_timeout: int = 60 # Default tool call timeout in seconds
|
||||
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from astrbot.core.provider.entities import (
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..message import AssistantMessageSegment, ToolCallMessageSegment
|
||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||
from ..response import AgentResponseData
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
@@ -55,6 +55,20 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self.agent_hooks = agent_hooks
|
||||
self.run_context = run_context
|
||||
|
||||
messages = []
|
||||
# append existing messages in the run context
|
||||
for msg in request.contexts:
|
||||
messages.append(Message.model_validate(msg))
|
||||
if request.prompt is not None:
|
||||
m = await request.assemble_context()
|
||||
messages.append(Message.model_validate(m))
|
||||
if request.system_prompt:
|
||||
messages.insert(
|
||||
0,
|
||||
Message(role="system", content=request.system_prompt),
|
||||
)
|
||||
self.run_context.messages = messages
|
||||
|
||||
def _transition_state(self, new_state: AgentState) -> None:
|
||||
"""转换 Agent 状态"""
|
||||
if self._state != new_state:
|
||||
@@ -130,6 +144,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
# 如果没有工具调用,转换到完成状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self._transition_state(AgentState.DONE)
|
||||
# record the final assistant message
|
||||
self.run_context.messages.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content=llm_resp.completion_text or "",
|
||||
),
|
||||
)
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
|
||||
except Exception as e:
|
||||
@@ -175,6 +196,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
),
|
||||
tool_calls_result=tool_call_result_blocks,
|
||||
)
|
||||
# record the assistant message with tool calls
|
||||
self.run_context.messages.extend(
|
||||
tool_calls_result.to_openai_messages_model()
|
||||
)
|
||||
|
||||
self.req.append_tool_calls_result(tool_calls_result)
|
||||
|
||||
async def step_until_done(
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic.dataclasses import dataclass
|
||||
from .run_context import ContextWrapper, TContext
|
||||
|
||||
ParametersType = dict[str, Any]
|
||||
ToolExecResult = str | mcp.types.CallToolResult
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -55,9 +56,7 @@ class FunctionTool(ToolSchema, Generic[TContext]):
|
||||
def __repr__(self):
|
||||
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
|
||||
|
||||
async def call(
|
||||
self, context: ContextWrapper[TContext], **kwargs
|
||||
) -> str | mcp.types.CallToolResult:
|
||||
async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult:
|
||||
"""Run the tool with the given arguments. The handler field has priority."""
|
||||
raise NotImplementedError(
|
||||
"FunctionTool.call() must be implemented by subclasses or set a handler."
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
from dataclasses import dataclass
|
||||
from pydantic import Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.star.context import Context
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(config={"arbitrary_types_allowed": True})
|
||||
class AstrAgentContext:
|
||||
provider: Provider
|
||||
context: Context
|
||||
"""The star context instance"""
|
||||
event: AstrMessageEvent
|
||||
"""The message event associated with the agent context."""
|
||||
extra: dict[str, str] = Field(default_factory=dict)
|
||||
"""Customized extra data."""
|
||||
|
||||
|
||||
AgentContextWrapper = ContextWrapper[AstrAgentContext]
|
||||
|
||||
@@ -7,7 +7,6 @@ import mcp
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.mcp_client import MCPTool
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
||||
@@ -18,12 +17,8 @@ from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
)
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
from .astr_agent_context import AgentContextWrapper
|
||||
from .astr_agent_run_util import AgentRunner, run_agent
|
||||
|
||||
|
||||
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
@classmethod
|
||||
@@ -60,8 +55,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
input_ = tool_args.get("input", "agent")
|
||||
agent_runner = AgentRunner()
|
||||
input_ = tool_args.get("input")
|
||||
|
||||
# make toolset for the agent
|
||||
tools = tool.agent.tools
|
||||
@@ -77,72 +71,21 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
else:
|
||||
toolset = None
|
||||
|
||||
request = ProviderRequest(
|
||||
prompt=input_,
|
||||
system_prompt=tool.description or "",
|
||||
image_urls=[], # 暂时不传递原始 agent 的上下文
|
||||
contexts=[], # 暂时不传递原始 agent 的上下文
|
||||
func_tool=toolset,
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
provider=run_context.context.provider,
|
||||
event=run_context.context.event,
|
||||
)
|
||||
|
||||
ctx = run_context.context.context
|
||||
event = run_context.context.event
|
||||
|
||||
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
|
||||
await event.send(
|
||||
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name),
|
||||
umo = event.unified_msg_origin
|
||||
prov_id = await ctx.get_current_chat_provider_id(umo)
|
||||
llm_resp = await ctx.tool_loop_agent(
|
||||
event=event,
|
||||
chat_provider_id=prov_id,
|
||||
prompt=input_,
|
||||
tools=toolset,
|
||||
max_steps=30,
|
||||
)
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=run_context.context.provider,
|
||||
request=request,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=run_context.tool_call_timeout,
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
|
||||
yield mcp.types.CallToolResult(
|
||||
content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)]
|
||||
)
|
||||
|
||||
async for _ in run_agent(agent_runner, 15, True):
|
||||
pass
|
||||
|
||||
if agent_runner.done():
|
||||
llm_response = agent_runner.get_final_llm_resp()
|
||||
|
||||
if not llm_response:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=f"error when deligate task to {tool.agent.name}",
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}",
|
||||
)
|
||||
|
||||
result = (
|
||||
f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n"
|
||||
"Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)."
|
||||
)
|
||||
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=result,
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=f"error when deligate task to {tool.agent.name}",
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
return
|
||||
|
||||
@classmethod
|
||||
async def _execute_local(
|
||||
cls,
|
||||
|
||||
@@ -6,6 +6,7 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
@@ -393,8 +394,11 @@ class LLMRequestSubStage(Stage):
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
|
||||
)
|
||||
context_model: list[Message] = []
|
||||
for msg in req.contexts:
|
||||
context_model.append(Message.model_validate(msg))
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
provider=provider,
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
await agent_runner.reset(
|
||||
|
||||
@@ -63,6 +63,14 @@ class ToolCallsResult:
|
||||
]
|
||||
return ret
|
||||
|
||||
def to_openai_messages_model(
|
||||
self,
|
||||
) -> list[AssistantMessageSegment | ToolCallMessageSegment]:
|
||||
return [
|
||||
self.tool_calls_info,
|
||||
*self.tool_calls_result,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderRequest:
|
||||
|
||||
@@ -9,8 +9,6 @@ from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AgentContextWrapper, AstrAgentContext
|
||||
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
@@ -90,7 +88,7 @@ class Context:
|
||||
image_urls: list[str] | None = None,
|
||||
tools: ToolSet | None = None,
|
||||
system_prompt: str | None = None,
|
||||
contexts: list[Message] | list[dict] | None = None,
|
||||
contexts: list[Message] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""Call the LLM to generate a response. The method will not automatically execute tool calls. If you want to use tool calls, please use `tool_loop_agent()`.
|
||||
@@ -132,12 +130,15 @@ class Context:
|
||||
image_urls: list[str] | None = None,
|
||||
tools: ToolSet | None = None,
|
||||
system_prompt: str | None = None,
|
||||
contexts: list[Message] | list[dict] | None = None,
|
||||
contexts: list[Message] | None = None,
|
||||
max_steps: int = 30,
|
||||
tool_call_timeout: int = 60,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""Run an agent loop that allows the LLM to call tools iteratively until a final answer is produced.
|
||||
If you do not pass the agent_context parameter, the method will recreate a new agent context.
|
||||
|
||||
.. versionadded:: 4.5.7 (sdk)
|
||||
|
||||
Args:
|
||||
chat_provider_id: The chat provider ID to use.
|
||||
@@ -147,7 +148,9 @@ class Context:
|
||||
system_prompt: System prompt to guide the LLM's behavior, if provided, it will always insert as the first system message in the context
|
||||
contexts: context messages for the LLM
|
||||
max_steps: Maximum number of tool calls before stopping the loop
|
||||
**kwargs: Additional keyword arguments for LLM generation, OpenAI compatible
|
||||
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
|
||||
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
|
||||
agent_context: AstrAgentContext - context to use for the agent
|
||||
|
||||
Returns:
|
||||
The final LLMResponse after tool calls are completed.
|
||||
@@ -156,10 +159,20 @@ class Context:
|
||||
ChatProviderNotFoundError: If the specified chat provider ID is not found
|
||||
Exception: For other errors during LLM generation
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from astrbot.core.astr_agent_context import (
|
||||
AgentContextWrapper,
|
||||
AstrAgentContext,
|
||||
)
|
||||
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
|
||||
|
||||
prov = await self.provider_manager.get_provider_by_id(chat_provider_id)
|
||||
if not prov or not isinstance(prov, Provider):
|
||||
raise ProviderNotFoundError(f"Provider {chat_provider_id} not found")
|
||||
|
||||
agent_hooks = kwargs.get("agent_hooks") or BaseAgentRunHooks[AstrAgentContext]()
|
||||
agent_context = kwargs.get("agent_context")
|
||||
|
||||
context_ = []
|
||||
for msg in contexts or []:
|
||||
if isinstance(msg, Message):
|
||||
@@ -174,23 +187,22 @@ class Context:
|
||||
contexts=context_,
|
||||
system_prompt=system_prompt or "",
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
provider=prov,
|
||||
event=event,
|
||||
)
|
||||
if agent_context is None:
|
||||
agent_context = AstrAgentContext(
|
||||
context=self,
|
||||
event=event,
|
||||
)
|
||||
agent_runner = ToolLoopAgentRunner()
|
||||
tool_executor = FunctionToolExecutor()
|
||||
await agent_runner.reset(
|
||||
provider=prov,
|
||||
request=request,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
context=agent_context,
|
||||
tool_call_timeout=tool_call_timeout,
|
||||
),
|
||||
tool_executor=tool_executor,
|
||||
agent_hooks=kwargs.get(
|
||||
"agent_hooks", BaseAgentRunHooks[AstrAgentContext]()
|
||||
),
|
||||
agent_hooks=agent_hooks,
|
||||
streaming=kwargs.get("stream", False),
|
||||
)
|
||||
async for _ in agent_runner.step_until_done(max_steps):
|
||||
|
||||
Reference in New Issue
Block a user