perf(agent): add max step limit to prevent infinite tool call loops (#4110)
* perf(agent): add max step limit to prevent infinite tool call loops * feat: implement max step limit handling in main agent runner - Enhanced the agent runner to enforce a maximum step limit, logging a warning and forcing a final response when the limit is reached. - Updated message handling to append a user prompt when the tool call limit is exceeded. - Refactored tool response handling to yield appropriate messages based on the response type, including handling cases with no response or unsupported types. - Improved conversation message formatting to ensure consistent output in the assistant's responses. * chore: ruff format --------- Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
@@ -76,12 +76,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
|
||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
"""Yields chunks *and* a final LLMResponse."""
|
||||
payload = {
|
||||
"contexts": self.run_context.messages,
|
||||
"func_tool": self.req.func_tool,
|
||||
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
|
||||
"session_id": self.req.session_id,
|
||||
}
|
||||
|
||||
if self.streaming:
|
||||
stream = self.provider.text_chat_stream(**self.req.__dict__)
|
||||
stream = self.provider.text_chat_stream(**payload)
|
||||
async for resp in stream: # type: ignore
|
||||
yield resp
|
||||
else:
|
||||
yield await self.provider.text_chat(**self.req.__dict__)
|
||||
yield await self.provider.text_chat(**payload)
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
@@ -230,6 +237,25 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
# 如果循环结束了但是 agent 还没有完成,说明是达到了 max_step
|
||||
if not self.done():
|
||||
logger.warning(
|
||||
f"Agent reached max steps ({max_step}), forcing a final response."
|
||||
)
|
||||
# 拔掉所有工具
|
||||
if self.req:
|
||||
self.req.func_tool = None
|
||||
# 注入提示词
|
||||
self.run_context.messages.append(
|
||||
Message(
|
||||
role="user",
|
||||
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
|
||||
)
|
||||
)
|
||||
# 再执行最后一步
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
async def _handle_function_tools(
|
||||
self,
|
||||
req: ProviderRequest,
|
||||
|
||||
@@ -2,6 +2,7 @@ import traceback
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.message.components import Json
|
||||
@@ -24,8 +25,25 @@ async def run_agent(
|
||||
) -> AsyncGenerator[MessageChain | None, None]:
|
||||
step_idx = 0
|
||||
astr_event = agent_runner.run_context.context.event
|
||||
while step_idx < max_step:
|
||||
while step_idx < max_step + 1:
|
||||
step_idx += 1
|
||||
|
||||
if step_idx == max_step + 1:
|
||||
logger.warning(
|
||||
f"Agent reached max steps ({max_step}), forcing a final response."
|
||||
)
|
||||
if not agent_runner.done():
|
||||
# 拔掉所有工具
|
||||
if agent_runner.req:
|
||||
agent_runner.req.func_tool = None
|
||||
# 注入提示词
|
||||
agent_runner.run_context.messages.append(
|
||||
Message(
|
||||
role="user",
|
||||
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
async for resp in agent_runner.step():
|
||||
if astr_event.is_stopped():
|
||||
|
||||
@@ -321,7 +321,12 @@ class InternalAgentSubStage(Stage):
|
||||
elif isinstance(req.tool_calls_result, list):
|
||||
for tcr in req.tool_calls_result:
|
||||
messages.extend(tcr.to_openai_messages())
|
||||
messages.append({"role": "assistant", "content": llm_response.completion_text})
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text or "*No response*",
|
||||
}
|
||||
)
|
||||
messages = list(filter(lambda item: "_no_save" not in item, messages))
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
|
||||
@@ -0,0 +1,326 @@
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
# 将项目根目录添加到 sys.path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
||||
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, TokenUsage
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
|
||||
class MockProvider(Provider):
|
||||
"""模拟Provider用于测试"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__({}, {})
|
||||
self.call_count = 0
|
||||
self.should_call_tools = True
|
||||
self.max_calls_before_normal_response = 10
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return "test_key"
|
||||
|
||||
def set_key(self, key: str):
|
||||
pass
|
||||
|
||||
async def get_models(self) -> list[str]:
|
||||
return ["test_model"]
|
||||
|
||||
async def text_chat(self, **kwargs) -> LLMResponse:
|
||||
self.call_count += 1
|
||||
|
||||
# 检查工具是否被禁用
|
||||
func_tool = kwargs.get("func_tool")
|
||||
|
||||
# 如果工具被禁用或超过最大调用次数,返回正常响应
|
||||
if func_tool is None or self.call_count > self.max_calls_before_normal_response:
|
||||
return LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="这是我的最终回答",
|
||||
usage=TokenUsage(input_other=10, output=5),
|
||||
)
|
||||
|
||||
# 模拟工具调用响应
|
||||
if self.should_call_tools:
|
||||
return LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="我需要使用工具来帮助您",
|
||||
tools_call_name=["test_tool"],
|
||||
tools_call_args=[{"query": "test"}],
|
||||
tools_call_ids=["call_123"],
|
||||
usage=TokenUsage(input_other=10, output=5),
|
||||
)
|
||||
|
||||
# 默认返回正常响应
|
||||
return LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="这是我的最终回答",
|
||||
usage=TokenUsage(input_other=10, output=5),
|
||||
)
|
||||
|
||||
async def text_chat_stream(self, **kwargs):
|
||||
response = await self.text_chat(**kwargs)
|
||||
response.is_chunk = True
|
||||
yield response
|
||||
response.is_chunk = False
|
||||
yield response
|
||||
|
||||
|
||||
class MockToolExecutor:
|
||||
"""模拟工具执行器"""
|
||||
|
||||
@classmethod
|
||||
def execute(cls, tool, run_context, **tool_args):
|
||||
async def generator():
|
||||
# 模拟工具返回结果,使用正确的类型
|
||||
from mcp.types import CallToolResult, TextContent
|
||||
|
||||
result = CallToolResult(
|
||||
content=[TextContent(type="text", text="工具执行结果")]
|
||||
)
|
||||
yield result
|
||||
|
||||
return generator()
|
||||
|
||||
|
||||
class MockHooks(BaseAgentRunHooks):
|
||||
"""模拟钩子函数"""
|
||||
|
||||
def __init__(self):
|
||||
self.agent_begin_called = False
|
||||
self.agent_done_called = False
|
||||
self.tool_start_called = False
|
||||
self.tool_end_called = False
|
||||
|
||||
async def on_agent_begin(self, run_context):
|
||||
self.agent_begin_called = True
|
||||
|
||||
async def on_tool_start(self, run_context, tool, tool_args):
|
||||
self.tool_start_called = True
|
||||
|
||||
async def on_tool_end(self, run_context, tool, tool_args, tool_result):
|
||||
self.tool_end_called = True
|
||||
|
||||
async def on_agent_done(self, run_context, llm_response):
|
||||
self.agent_done_called = True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider():
|
||||
return MockProvider()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_executor():
|
||||
return MockToolExecutor()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_hooks():
|
||||
return MockHooks()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_set():
|
||||
"""创建测试用的工具集"""
|
||||
tool = FunctionTool(
|
||||
name="test_tool",
|
||||
description="测试工具",
|
||||
parameters={"type": "object", "properties": {"query": {"type": "string"}}},
|
||||
handler=AsyncMock(),
|
||||
)
|
||||
return ToolSet(tools=[tool])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_request(tool_set):
|
||||
"""创建测试用的ProviderRequest"""
|
||||
return ProviderRequest(prompt="请帮我查询信息", func_tool=tool_set, contexts=[])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
"""创建ToolLoopAgentRunner实例"""
|
||||
return ToolLoopAgentRunner()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_step_limit_functionality(
|
||||
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
||||
):
|
||||
"""测试最大步数限制功能"""
|
||||
|
||||
# 设置模拟provider,让它总是返回工具调用
|
||||
mock_provider.should_call_tools = True
|
||||
mock_provider.max_calls_before_normal_response = (
|
||||
100 # 设置一个很大的值,确保不会自然结束
|
||||
)
|
||||
|
||||
# 初始化runner
|
||||
await runner.reset(
|
||||
provider=mock_provider,
|
||||
request=provider_request,
|
||||
run_context=ContextWrapper(context=None),
|
||||
tool_executor=mock_tool_executor,
|
||||
agent_hooks=mock_hooks,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
# 设置较小的最大步数来测试限制功能
|
||||
max_steps = 3
|
||||
|
||||
# 收集所有响应
|
||||
responses = []
|
||||
async for response in runner.step_until_done(max_steps):
|
||||
responses.append(response)
|
||||
|
||||
# 验证结果
|
||||
assert runner.done(), "代理应该在达到最大步数后完成"
|
||||
|
||||
# 验证工具被禁用(这是最重要的验证点)
|
||||
assert runner.req.func_tool is None, "达到最大步数后工具应该被禁用"
|
||||
|
||||
# 验证有最终响应
|
||||
final_responses = [r for r in responses if r.type == "llm_result"]
|
||||
assert len(final_responses) > 0, "应该有最终的LLM响应"
|
||||
|
||||
# 验证最后一条消息是assistant的最终回答
|
||||
last_message = runner.run_context.messages[-1]
|
||||
assert last_message.role == "assistant", "最后一条消息应该是assistant的最终回答"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normal_completion_without_max_step(
|
||||
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
||||
):
|
||||
"""测试正常完成(不触发最大步数限制)"""
|
||||
|
||||
# 设置模拟provider,让它在第2次调用时返回正常响应
|
||||
mock_provider.should_call_tools = True
|
||||
mock_provider.max_calls_before_normal_response = 2
|
||||
|
||||
# 初始化runner
|
||||
await runner.reset(
|
||||
provider=mock_provider,
|
||||
request=provider_request,
|
||||
run_context=ContextWrapper(context=None),
|
||||
tool_executor=mock_tool_executor,
|
||||
agent_hooks=mock_hooks,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
# 设置足够大的最大步数
|
||||
max_steps = 10
|
||||
|
||||
# 收集所有响应
|
||||
responses = []
|
||||
async for response in runner.step_until_done(max_steps):
|
||||
responses.append(response)
|
||||
|
||||
# 验证结果
|
||||
assert runner.done(), "代理应该正常完成"
|
||||
|
||||
# 验证没有触发最大步数限制 - 通过检查provider调用次数
|
||||
# mock_provider在第2次调用后返回正常响应,所以不应该达到max_steps(10)
|
||||
assert mock_provider.call_count < max_steps, (
|
||||
f"正常完成时调用次数({mock_provider.call_count})应该小于最大步数({max_steps})"
|
||||
)
|
||||
|
||||
# 验证没有最大步数警告消息(注意:实际注入的是user角色的消息)
|
||||
user_messages = [m for m in runner.run_context.messages if m.role == "user"]
|
||||
max_step_messages = [
|
||||
m for m in user_messages if "工具调用次数已达到上限" in m.content
|
||||
]
|
||||
assert len(max_step_messages) == 0, "正常完成时不应该有步数限制消息"
|
||||
|
||||
# 验证工具仍然可用(没有被禁用)
|
||||
assert runner.req.func_tool is not None, "正常完成时工具不应该被禁用"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_step_with_streaming(
|
||||
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
||||
):
|
||||
"""测试流式响应下的最大步数限制"""
|
||||
|
||||
# 设置模拟provider
|
||||
mock_provider.should_call_tools = True
|
||||
mock_provider.max_calls_before_normal_response = 100
|
||||
|
||||
# 初始化runner,启用流式响应
|
||||
await runner.reset(
|
||||
provider=mock_provider,
|
||||
request=provider_request,
|
||||
run_context=ContextWrapper(context=None),
|
||||
tool_executor=mock_tool_executor,
|
||||
agent_hooks=mock_hooks,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
# 设置较小的最大步数
|
||||
max_steps = 2
|
||||
|
||||
# 收集所有响应
|
||||
responses = []
|
||||
async for response in runner.step_until_done(max_steps):
|
||||
responses.append(response)
|
||||
|
||||
# 验证结果
|
||||
assert runner.done(), "代理应该在达到最大步数后完成"
|
||||
|
||||
# 验证有流式响应
|
||||
streaming_responses = [r for r in responses if r.type == "streaming_delta"]
|
||||
assert len(streaming_responses) > 0, "应该有流式响应"
|
||||
|
||||
# 验证工具被禁用
|
||||
assert runner.req.func_tool is None, "达到最大步数后工具应该被禁用"
|
||||
|
||||
# 验证最后一条消息是assistant的最终回答
|
||||
last_message = runner.run_context.messages[-1]
|
||||
assert last_message.role == "assistant", "最后一条消息应该是assistant的最终回答"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hooks_called_with_max_step(
|
||||
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
||||
):
|
||||
"""测试达到最大步数时钩子函数是否被正确调用"""
|
||||
|
||||
# 设置模拟provider
|
||||
mock_provider.should_call_tools = True
|
||||
mock_provider.max_calls_before_normal_response = 100
|
||||
|
||||
# 初始化runner
|
||||
await runner.reset(
|
||||
provider=mock_provider,
|
||||
request=provider_request,
|
||||
run_context=ContextWrapper(context=None),
|
||||
tool_executor=mock_tool_executor,
|
||||
agent_hooks=mock_hooks,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
# 设置较小的最大步数
|
||||
max_steps = 2
|
||||
|
||||
# 执行步骤
|
||||
async for response in runner.step_until_done(max_steps):
|
||||
pass
|
||||
|
||||
# 验证钩子函数被调用
|
||||
assert mock_hooks.agent_begin_called, "on_agent_begin应该被调用"
|
||||
assert mock_hooks.agent_done_called, "on_agent_done应该被调用"
|
||||
assert mock_hooks.tool_start_called, "on_tool_start应该被调用"
|
||||
assert mock_hooks.tool_end_called, "on_tool_end应该被调用"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行测试
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user