5219ba5c4e
* feat: implement follow-up message handling in ToolLoopAgentRunner * fix: correct import path for follow-up module in InternalAgentSubStage
541 lines
16 KiB
Python
541 lines
16 KiB
Python
import os
|
|
import sys
|
|
from unittest.mock import AsyncMock
|
|
|
|
import pytest
|
|
|
|
# 将项目根目录添加到 sys.path
|
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
|
|
|
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
|
from astrbot.core.agent.run_context import ContextWrapper
|
|
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
|
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
|
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, TokenUsage
|
|
from astrbot.core.provider.provider import Provider
|
|
|
|
|
|
class MockProvider(Provider):
|
|
"""模拟Provider用于测试"""
|
|
|
|
def __init__(self):
|
|
super().__init__({}, {})
|
|
self.call_count = 0
|
|
self.should_call_tools = True
|
|
self.max_calls_before_normal_response = 10
|
|
|
|
def get_current_key(self) -> str:
|
|
return "test_key"
|
|
|
|
def set_key(self, key: str):
|
|
pass
|
|
|
|
async def get_models(self) -> list[str]:
|
|
return ["test_model"]
|
|
|
|
async def text_chat(self, **kwargs) -> LLMResponse:
|
|
self.call_count += 1
|
|
|
|
# 检查工具是否被禁用
|
|
func_tool = kwargs.get("func_tool")
|
|
|
|
# 如果工具被禁用或超过最大调用次数,返回正常响应
|
|
if func_tool is None or self.call_count > self.max_calls_before_normal_response:
|
|
return LLMResponse(
|
|
role="assistant",
|
|
completion_text="这是我的最终回答",
|
|
usage=TokenUsage(input_other=10, output=5),
|
|
)
|
|
|
|
# 模拟工具调用响应
|
|
if self.should_call_tools:
|
|
return LLMResponse(
|
|
role="assistant",
|
|
completion_text="我需要使用工具来帮助您",
|
|
tools_call_name=["test_tool"],
|
|
tools_call_args=[{"query": "test"}],
|
|
tools_call_ids=["call_123"],
|
|
usage=TokenUsage(input_other=10, output=5),
|
|
)
|
|
|
|
# 默认返回正常响应
|
|
return LLMResponse(
|
|
role="assistant",
|
|
completion_text="这是我的最终回答",
|
|
usage=TokenUsage(input_other=10, output=5),
|
|
)
|
|
|
|
async def text_chat_stream(self, **kwargs):
|
|
response = await self.text_chat(**kwargs)
|
|
response.is_chunk = True
|
|
yield response
|
|
response.is_chunk = False
|
|
yield response
|
|
|
|
|
|
class MockToolExecutor:
|
|
"""模拟工具执行器"""
|
|
|
|
@classmethod
|
|
def execute(cls, tool, run_context, **tool_args):
|
|
async def generator():
|
|
# 模拟工具返回结果,使用正确的类型
|
|
from mcp.types import CallToolResult, TextContent
|
|
|
|
result = CallToolResult(
|
|
content=[TextContent(type="text", text="工具执行结果")]
|
|
)
|
|
yield result
|
|
|
|
return generator()
|
|
|
|
|
|
class MockFailingProvider(MockProvider):
|
|
async def text_chat(self, **kwargs) -> LLMResponse:
|
|
self.call_count += 1
|
|
raise RuntimeError("primary provider failed")
|
|
|
|
|
|
class MockErrProvider(MockProvider):
|
|
async def text_chat(self, **kwargs) -> LLMResponse:
|
|
self.call_count += 1
|
|
return LLMResponse(
|
|
role="err",
|
|
completion_text="primary provider returned error",
|
|
)
|
|
|
|
|
|
class MockAbortableStreamProvider(MockProvider):
|
|
async def text_chat_stream(self, **kwargs):
|
|
abort_signal = kwargs.get("abort_signal")
|
|
yield LLMResponse(
|
|
role="assistant",
|
|
completion_text="partial ",
|
|
is_chunk=True,
|
|
)
|
|
if abort_signal and abort_signal.is_set():
|
|
yield LLMResponse(
|
|
role="assistant",
|
|
completion_text="partial ",
|
|
is_chunk=False,
|
|
)
|
|
return
|
|
yield LLMResponse(
|
|
role="assistant",
|
|
completion_text="partial final",
|
|
is_chunk=False,
|
|
)
|
|
|
|
|
|
class MockHooks(BaseAgentRunHooks):
|
|
"""模拟钩子函数"""
|
|
|
|
def __init__(self):
|
|
self.agent_begin_called = False
|
|
self.agent_done_called = False
|
|
self.tool_start_called = False
|
|
self.tool_end_called = False
|
|
|
|
async def on_agent_begin(self, run_context):
|
|
self.agent_begin_called = True
|
|
|
|
async def on_tool_start(self, run_context, tool, tool_args):
|
|
self.tool_start_called = True
|
|
|
|
async def on_tool_end(self, run_context, tool, tool_args, tool_result):
|
|
self.tool_end_called = True
|
|
|
|
async def on_agent_done(self, run_context, llm_response):
|
|
self.agent_done_called = True
|
|
|
|
|
|
class MockEvent:
|
|
def __init__(self, umo: str, sender_id: str):
|
|
self.unified_msg_origin = umo
|
|
self._sender_id = sender_id
|
|
|
|
def get_sender_id(self):
|
|
return self._sender_id
|
|
|
|
|
|
class MockAgentContext:
|
|
def __init__(self, event):
|
|
self.event = event
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_provider():
|
|
return MockProvider()
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_tool_executor():
|
|
return MockToolExecutor()
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_hooks():
|
|
return MockHooks()
|
|
|
|
|
|
@pytest.fixture
|
|
def tool_set():
|
|
"""创建测试用的工具集"""
|
|
tool = FunctionTool(
|
|
name="test_tool",
|
|
description="测试工具",
|
|
parameters={"type": "object", "properties": {"query": {"type": "string"}}},
|
|
handler=AsyncMock(),
|
|
)
|
|
return ToolSet(tools=[tool])
|
|
|
|
|
|
@pytest.fixture
|
|
def provider_request(tool_set):
|
|
"""创建测试用的ProviderRequest"""
|
|
return ProviderRequest(prompt="请帮我查询信息", func_tool=tool_set, contexts=[])
|
|
|
|
|
|
@pytest.fixture
|
|
def runner():
|
|
"""创建ToolLoopAgentRunner实例"""
|
|
return ToolLoopAgentRunner()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_max_step_limit_functionality(
|
|
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
|
):
|
|
"""测试最大步数限制功能"""
|
|
|
|
# 设置模拟provider,让它总是返回工具调用
|
|
mock_provider.should_call_tools = True
|
|
mock_provider.max_calls_before_normal_response = (
|
|
100 # 设置一个很大的值,确保不会自然结束
|
|
)
|
|
|
|
# 初始化runner
|
|
await runner.reset(
|
|
provider=mock_provider,
|
|
request=provider_request,
|
|
run_context=ContextWrapper(context=None),
|
|
tool_executor=mock_tool_executor,
|
|
agent_hooks=mock_hooks,
|
|
streaming=False,
|
|
)
|
|
|
|
# 设置较小的最大步数来测试限制功能
|
|
max_steps = 3
|
|
|
|
# 收集所有响应
|
|
responses = []
|
|
async for response in runner.step_until_done(max_steps):
|
|
responses.append(response)
|
|
|
|
# 验证结果
|
|
assert runner.done(), "代理应该在达到最大步数后完成"
|
|
|
|
# 验证工具被禁用(这是最重要的验证点)
|
|
assert runner.req.func_tool is None, "达到最大步数后工具应该被禁用"
|
|
|
|
# 验证有最终响应
|
|
final_responses = [r for r in responses if r.type == "llm_result"]
|
|
assert len(final_responses) > 0, "应该有最终的LLM响应"
|
|
|
|
# 验证最后一条消息是assistant的最终回答
|
|
last_message = runner.run_context.messages[-1]
|
|
assert last_message.role == "assistant", "最后一条消息应该是assistant的最终回答"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_normal_completion_without_max_step(
|
|
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
|
):
|
|
"""测试正常完成(不触发最大步数限制)"""
|
|
|
|
# 设置模拟provider,让它在第2次调用时返回正常响应
|
|
mock_provider.should_call_tools = True
|
|
mock_provider.max_calls_before_normal_response = 2
|
|
|
|
# 初始化runner
|
|
await runner.reset(
|
|
provider=mock_provider,
|
|
request=provider_request,
|
|
run_context=ContextWrapper(context=None),
|
|
tool_executor=mock_tool_executor,
|
|
agent_hooks=mock_hooks,
|
|
streaming=False,
|
|
)
|
|
|
|
# 设置足够大的最大步数
|
|
max_steps = 10
|
|
|
|
# 收集所有响应
|
|
responses = []
|
|
async for response in runner.step_until_done(max_steps):
|
|
responses.append(response)
|
|
|
|
# 验证结果
|
|
assert runner.done(), "代理应该正常完成"
|
|
|
|
# 验证没有触发最大步数限制 - 通过检查provider调用次数
|
|
# mock_provider在第2次调用后返回正常响应,所以不应该达到max_steps(10)
|
|
assert mock_provider.call_count < max_steps, (
|
|
f"正常完成时调用次数({mock_provider.call_count})应该小于最大步数({max_steps})"
|
|
)
|
|
|
|
# 验证没有最大步数警告消息(注意:实际注入的是user角色的消息)
|
|
user_messages = [m for m in runner.run_context.messages if m.role == "user"]
|
|
max_step_messages = [
|
|
m for m in user_messages if "工具调用次数已达到上限" in m.content
|
|
]
|
|
assert len(max_step_messages) == 0, "正常完成时不应该有步数限制消息"
|
|
|
|
# 验证工具仍然可用(没有被禁用)
|
|
assert runner.req.func_tool is not None, "正常完成时工具不应该被禁用"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_max_step_with_streaming(
|
|
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
|
):
|
|
"""测试流式响应下的最大步数限制"""
|
|
|
|
# 设置模拟provider
|
|
mock_provider.should_call_tools = True
|
|
mock_provider.max_calls_before_normal_response = 100
|
|
|
|
# 初始化runner,启用流式响应
|
|
await runner.reset(
|
|
provider=mock_provider,
|
|
request=provider_request,
|
|
run_context=ContextWrapper(context=None),
|
|
tool_executor=mock_tool_executor,
|
|
agent_hooks=mock_hooks,
|
|
streaming=True,
|
|
)
|
|
|
|
# 设置较小的最大步数
|
|
max_steps = 2
|
|
|
|
# 收集所有响应
|
|
responses = []
|
|
async for response in runner.step_until_done(max_steps):
|
|
responses.append(response)
|
|
|
|
# 验证结果
|
|
assert runner.done(), "代理应该在达到最大步数后完成"
|
|
|
|
# 验证有流式响应
|
|
streaming_responses = [r for r in responses if r.type == "streaming_delta"]
|
|
assert len(streaming_responses) > 0, "应该有流式响应"
|
|
|
|
# 验证工具被禁用
|
|
assert runner.req.func_tool is None, "达到最大步数后工具应该被禁用"
|
|
|
|
# 验证最后一条消息是assistant的最终回答
|
|
last_message = runner.run_context.messages[-1]
|
|
assert last_message.role == "assistant", "最后一条消息应该是assistant的最终回答"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_hooks_called_with_max_step(
|
|
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
|
):
|
|
"""测试达到最大步数时钩子函数是否被正确调用"""
|
|
|
|
# 设置模拟provider
|
|
mock_provider.should_call_tools = True
|
|
mock_provider.max_calls_before_normal_response = 100
|
|
|
|
# 初始化runner
|
|
await runner.reset(
|
|
provider=mock_provider,
|
|
request=provider_request,
|
|
run_context=ContextWrapper(context=None),
|
|
tool_executor=mock_tool_executor,
|
|
agent_hooks=mock_hooks,
|
|
streaming=False,
|
|
)
|
|
|
|
# 设置较小的最大步数
|
|
max_steps = 2
|
|
|
|
# 执行步骤
|
|
async for response in runner.step_until_done(max_steps):
|
|
pass
|
|
|
|
# 验证钩子函数被调用
|
|
assert mock_hooks.agent_begin_called, "on_agent_begin应该被调用"
|
|
assert mock_hooks.agent_done_called, "on_agent_done应该被调用"
|
|
assert mock_hooks.tool_start_called, "on_tool_start应该被调用"
|
|
assert mock_hooks.tool_end_called, "on_tool_end应该被调用"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_fallback_provider_used_when_primary_raises(
|
|
runner, provider_request, mock_tool_executor, mock_hooks
|
|
):
|
|
primary_provider = MockFailingProvider()
|
|
fallback_provider = MockProvider()
|
|
fallback_provider.should_call_tools = False
|
|
|
|
await runner.reset(
|
|
provider=primary_provider,
|
|
request=provider_request,
|
|
run_context=ContextWrapper(context=None),
|
|
tool_executor=mock_tool_executor,
|
|
agent_hooks=mock_hooks,
|
|
streaming=False,
|
|
fallback_providers=[fallback_provider],
|
|
)
|
|
|
|
async for _ in runner.step_until_done(5):
|
|
pass
|
|
|
|
final_resp = runner.get_final_llm_resp()
|
|
assert final_resp is not None
|
|
assert final_resp.role == "assistant"
|
|
assert final_resp.completion_text == "这是我的最终回答"
|
|
assert primary_provider.call_count == 1
|
|
assert fallback_provider.call_count == 1
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_fallback_provider_used_when_primary_returns_err(
|
|
runner, provider_request, mock_tool_executor, mock_hooks
|
|
):
|
|
primary_provider = MockErrProvider()
|
|
fallback_provider = MockProvider()
|
|
fallback_provider.should_call_tools = False
|
|
|
|
await runner.reset(
|
|
provider=primary_provider,
|
|
request=provider_request,
|
|
run_context=ContextWrapper(context=None),
|
|
tool_executor=mock_tool_executor,
|
|
agent_hooks=mock_hooks,
|
|
streaming=False,
|
|
fallback_providers=[fallback_provider],
|
|
)
|
|
|
|
async for _ in runner.step_until_done(5):
|
|
pass
|
|
|
|
final_resp = runner.get_final_llm_resp()
|
|
assert final_resp is not None
|
|
assert final_resp.role == "assistant"
|
|
assert final_resp.completion_text == "这是我的最终回答"
|
|
assert primary_provider.call_count == 1
|
|
assert fallback_provider.call_count == 1
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_stop_signal_returns_aborted_and_persists_partial_message(
|
|
runner, provider_request, mock_tool_executor, mock_hooks
|
|
):
|
|
provider = MockAbortableStreamProvider()
|
|
|
|
await runner.reset(
|
|
provider=provider,
|
|
request=provider_request,
|
|
run_context=ContextWrapper(context=None),
|
|
tool_executor=mock_tool_executor,
|
|
agent_hooks=mock_hooks,
|
|
streaming=True,
|
|
)
|
|
|
|
step_iter = runner.step()
|
|
first_resp = await step_iter.__anext__()
|
|
assert first_resp.type == "streaming_delta"
|
|
|
|
runner.request_stop()
|
|
|
|
rest_responses = []
|
|
async for response in step_iter:
|
|
rest_responses.append(response)
|
|
|
|
assert any(resp.type == "aborted" for resp in rest_responses)
|
|
assert runner.was_aborted() is True
|
|
|
|
final_resp = runner.get_final_llm_resp()
|
|
assert final_resp is not None
|
|
assert final_resp.role == "assistant"
|
|
assert final_resp.completion_text == "partial "
|
|
assert runner.run_context.messages[-1].role == "assistant"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tool_result_injects_follow_up_notice(
|
|
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
|
):
|
|
mock_event = MockEvent("test:FriendMessage:follow_up", "u1")
|
|
run_context = ContextWrapper(context=MockAgentContext(mock_event))
|
|
|
|
await runner.reset(
|
|
provider=mock_provider,
|
|
request=provider_request,
|
|
run_context=run_context,
|
|
tool_executor=mock_tool_executor,
|
|
agent_hooks=mock_hooks,
|
|
streaming=False,
|
|
)
|
|
|
|
ticket1 = runner.follow_up(
|
|
message_text="follow up 1",
|
|
)
|
|
ticket2 = runner.follow_up(
|
|
message_text="follow up 2",
|
|
)
|
|
assert ticket1 is not None
|
|
assert ticket2 is not None
|
|
|
|
async for _ in runner.step():
|
|
pass
|
|
|
|
assert provider_request.tool_calls_result is not None
|
|
assert isinstance(provider_request.tool_calls_result, list)
|
|
assert provider_request.tool_calls_result
|
|
tool_result = str(
|
|
provider_request.tool_calls_result[0].tool_calls_result[0].content
|
|
)
|
|
assert "SYSTEM NOTICE" in tool_result
|
|
assert "1. follow up 1" in tool_result
|
|
assert "2. follow up 2" in tool_result
|
|
assert ticket1.resolved.is_set() is True
|
|
assert ticket2.resolved.is_set() is True
|
|
assert ticket1.consumed is True
|
|
assert ticket2.consumed is True
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_follow_up_ticket_not_consumed_when_no_next_tool_call(
|
|
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
|
):
|
|
mock_provider.should_call_tools = False
|
|
mock_event = MockEvent("test:FriendMessage:follow_up_no_tool", "u1")
|
|
run_context = ContextWrapper(context=MockAgentContext(mock_event))
|
|
|
|
await runner.reset(
|
|
provider=mock_provider,
|
|
request=provider_request,
|
|
run_context=run_context,
|
|
tool_executor=mock_tool_executor,
|
|
agent_hooks=mock_hooks,
|
|
streaming=False,
|
|
)
|
|
|
|
ticket = runner.follow_up(message_text="follow up without tool")
|
|
assert ticket is not None
|
|
|
|
async for _ in runner.step():
|
|
pass
|
|
|
|
assert ticket.resolved.is_set() is True
|
|
assert ticket.consumed is False
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# 运行测试
|
|
pytest.main([__file__, "-v"])
|