diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 069de144f..5079d6484 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -76,12 +76,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]: """Yields chunks *and* a final LLMResponse.""" + payload = { + "contexts": self.run_context.messages, + "func_tool": self.req.func_tool, + "model": self.req.model, # NOTE: in fact, this arg is None in most cases + "session_id": self.req.session_id, + } + if self.streaming: - stream = self.provider.text_chat_stream(**self.req.__dict__) + stream = self.provider.text_chat_stream(**payload) async for resp in stream: # type: ignore yield resp else: - yield await self.provider.text_chat(**self.req.__dict__) + yield await self.provider.text_chat(**payload) @override async def step(self): @@ -230,6 +237,25 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): async for resp in self.step(): yield resp + # 如果循环结束了但是 agent 还没有完成,说明是达到了 max_step + if not self.done(): + logger.warning( + f"Agent reached max steps ({max_step}), forcing a final response." + ) + # 拔掉所有工具 + if self.req: + self.req.func_tool = None + # 注入提示词 + self.run_context.messages.append( + Message( + role="user", + content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。", + ) + ) + # 再执行最后一步 + async for resp in self.step(): + yield resp + async def _handle_function_tools( self, req: ProviderRequest, diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index 5421a14c0..d57cf5e93 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -2,6 +2,7 @@ import traceback from collections.abc import AsyncGenerator from astrbot.core import logger +from astrbot.core.agent.message import Message from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.message.components import Json @@ -24,8 +25,25 @@ async def run_agent( ) -> AsyncGenerator[MessageChain | None, None]: step_idx = 0 astr_event = agent_runner.run_context.context.event - while step_idx < max_step: + while step_idx < max_step + 1: step_idx += 1 + + if step_idx == max_step + 1: + logger.warning( + f"Agent reached max steps ({max_step}), forcing a final response." + ) + if not agent_runner.done(): + # 拔掉所有工具 + if agent_runner.req: + agent_runner.req.func_tool = None + # 注入提示词 + agent_runner.run_context.messages.append( + Message( + role="user", + content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。", + ) + ) + try: async for resp in agent_runner.step(): if astr_event.is_stopped(): diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 7e3305f55..d147a811f 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -321,7 +321,12 @@ class InternalAgentSubStage(Stage): elif isinstance(req.tool_calls_result, list): for tcr in req.tool_calls_result: messages.extend(tcr.to_openai_messages()) - messages.append({"role": "assistant", "content": llm_response.completion_text}) + messages.append( + { + "role": "assistant", + "content": llm_response.completion_text or "*No response*", + } + ) messages = list(filter(lambda item: "_no_save" not in item, messages)) await self.conv_manager.update_conversation( event.unified_msg_origin, diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py new file mode 100644 index 000000000..f0e90002d --- /dev/null +++ b/tests/test_tool_loop_agent_runner.py @@ -0,0 +1,326 @@ +import os +import sys +from unittest.mock import AsyncMock + +import pytest + +# 将项目根目录添加到 sys.path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from astrbot.core.agent.hooks import BaseAgentRunHooks +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner +from astrbot.core.agent.tool import FunctionTool, ToolSet +from astrbot.core.provider.entities import LLMResponse, ProviderRequest, TokenUsage +from astrbot.core.provider.provider import Provider + + +class MockProvider(Provider): + """模拟Provider用于测试""" + + def __init__(self): + super().__init__({}, {}) + self.call_count = 0 + self.should_call_tools = True + self.max_calls_before_normal_response = 10 + + def get_current_key(self) -> str: + return "test_key" + + def set_key(self, key: str): + pass + + async def get_models(self) -> list[str]: + return ["test_model"] + + async def text_chat(self, **kwargs) -> LLMResponse: + self.call_count += 1 + + # 检查工具是否被禁用 + func_tool = kwargs.get("func_tool") + + # 如果工具被禁用或超过最大调用次数,返回正常响应 + if func_tool is None or self.call_count > self.max_calls_before_normal_response: + return LLMResponse( + role="assistant", + completion_text="这是我的最终回答", + usage=TokenUsage(input_other=10, output=5), + ) + + # 模拟工具调用响应 + if self.should_call_tools: + return LLMResponse( + role="assistant", + completion_text="我需要使用工具来帮助您", + tools_call_name=["test_tool"], + tools_call_args=[{"query": "test"}], + tools_call_ids=["call_123"], + usage=TokenUsage(input_other=10, output=5), + ) + + # 默认返回正常响应 + return LLMResponse( + role="assistant", + completion_text="这是我的最终回答", + usage=TokenUsage(input_other=10, output=5), + ) + + async def text_chat_stream(self, **kwargs): + response = await self.text_chat(**kwargs) + response.is_chunk = True + yield response + response.is_chunk = False + yield response + + +class MockToolExecutor: + """模拟工具执行器""" + + @classmethod + def execute(cls, tool, run_context, **tool_args): + async def generator(): + # 模拟工具返回结果,使用正确的类型 + from mcp.types import CallToolResult, TextContent + + result = CallToolResult( + content=[TextContent(type="text", text="工具执行结果")] + ) + yield result + + return generator() + + +class MockHooks(BaseAgentRunHooks): + """模拟钩子函数""" + + def __init__(self): + self.agent_begin_called = False + self.agent_done_called = False + self.tool_start_called = False + self.tool_end_called = False + + async def on_agent_begin(self, run_context): + self.agent_begin_called = True + + async def on_tool_start(self, run_context, tool, tool_args): + self.tool_start_called = True + + async def on_tool_end(self, run_context, tool, tool_args, tool_result): + self.tool_end_called = True + + async def on_agent_done(self, run_context, llm_response): + self.agent_done_called = True + + +@pytest.fixture +def mock_provider(): + return MockProvider() + + +@pytest.fixture +def mock_tool_executor(): + return MockToolExecutor() + + +@pytest.fixture +def mock_hooks(): + return MockHooks() + + +@pytest.fixture +def tool_set(): + """创建测试用的工具集""" + tool = FunctionTool( + name="test_tool", + description="测试工具", + parameters={"type": "object", "properties": {"query": {"type": "string"}}}, + handler=AsyncMock(), + ) + return ToolSet(tools=[tool]) + + +@pytest.fixture +def provider_request(tool_set): + """创建测试用的ProviderRequest""" + return ProviderRequest(prompt="请帮我查询信息", func_tool=tool_set, contexts=[]) + + +@pytest.fixture +def runner(): + """创建ToolLoopAgentRunner实例""" + return ToolLoopAgentRunner() + + +@pytest.mark.asyncio +async def test_max_step_limit_functionality( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + """测试最大步数限制功能""" + + # 设置模拟provider,让它总是返回工具调用 + mock_provider.should_call_tools = True + mock_provider.max_calls_before_normal_response = ( + 100 # 设置一个很大的值,确保不会自然结束 + ) + + # 初始化runner + await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + # 设置较小的最大步数来测试限制功能 + max_steps = 3 + + # 收集所有响应 + responses = [] + async for response in runner.step_until_done(max_steps): + responses.append(response) + + # 验证结果 + assert runner.done(), "代理应该在达到最大步数后完成" + + # 验证工具被禁用(这是最重要的验证点) + assert runner.req.func_tool is None, "达到最大步数后工具应该被禁用" + + # 验证有最终响应 + final_responses = [r for r in responses if r.type == "llm_result"] + assert len(final_responses) > 0, "应该有最终的LLM响应" + + # 验证最后一条消息是assistant的最终回答 + last_message = runner.run_context.messages[-1] + assert last_message.role == "assistant", "最后一条消息应该是assistant的最终回答" + + +@pytest.mark.asyncio +async def test_normal_completion_without_max_step( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + """测试正常完成(不触发最大步数限制)""" + + # 设置模拟provider,让它在第2次调用时返回正常响应 + mock_provider.should_call_tools = True + mock_provider.max_calls_before_normal_response = 2 + + # 初始化runner + await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + # 设置足够大的最大步数 + max_steps = 10 + + # 收集所有响应 + responses = [] + async for response in runner.step_until_done(max_steps): + responses.append(response) + + # 验证结果 + assert runner.done(), "代理应该正常完成" + + # 验证没有触发最大步数限制 - 通过检查provider调用次数 + # mock_provider在第2次调用后返回正常响应,所以不应该达到max_steps(10) + assert mock_provider.call_count < max_steps, ( + f"正常完成时调用次数({mock_provider.call_count})应该小于最大步数({max_steps})" + ) + + # 验证没有最大步数警告消息(注意:实际注入的是user角色的消息) + user_messages = [m for m in runner.run_context.messages if m.role == "user"] + max_step_messages = [ + m for m in user_messages if "工具调用次数已达到上限" in m.content + ] + assert len(max_step_messages) == 0, "正常完成时不应该有步数限制消息" + + # 验证工具仍然可用(没有被禁用) + assert runner.req.func_tool is not None, "正常完成时工具不应该被禁用" + + +@pytest.mark.asyncio +async def test_max_step_with_streaming( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + """测试流式响应下的最大步数限制""" + + # 设置模拟provider + mock_provider.should_call_tools = True + mock_provider.max_calls_before_normal_response = 100 + + # 初始化runner,启用流式响应 + await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=True, + ) + + # 设置较小的最大步数 + max_steps = 2 + + # 收集所有响应 + responses = [] + async for response in runner.step_until_done(max_steps): + responses.append(response) + + # 验证结果 + assert runner.done(), "代理应该在达到最大步数后完成" + + # 验证有流式响应 + streaming_responses = [r for r in responses if r.type == "streaming_delta"] + assert len(streaming_responses) > 0, "应该有流式响应" + + # 验证工具被禁用 + assert runner.req.func_tool is None, "达到最大步数后工具应该被禁用" + + # 验证最后一条消息是assistant的最终回答 + last_message = runner.run_context.messages[-1] + assert last_message.role == "assistant", "最后一条消息应该是assistant的最终回答" + + +@pytest.mark.asyncio +async def test_hooks_called_with_max_step( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + """测试达到最大步数时钩子函数是否被正确调用""" + + # 设置模拟provider + mock_provider.should_call_tools = True + mock_provider.max_calls_before_normal_response = 100 + + # 初始化runner + await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + # 设置较小的最大步数 + max_steps = 2 + + # 执行步骤 + async for response in runner.step_until_done(max_steps): + pass + + # 验证钩子函数被调用 + assert mock_hooks.agent_begin_called, "on_agent_begin应该被调用" + assert mock_hooks.agent_done_called, "on_agent_done应该被调用" + assert mock_hooks.tool_start_called, "on_tool_start应该被调用" + assert mock_hooks.tool_end_called, "on_tool_end应该被调用" + + +if __name__ == "__main__": + # 运行测试 + pytest.main([__file__, "-v"])