feat: add fallback chat model chain in tool loop runner (#5109)
* feat: implement fallback provider support for chat models and update configuration * feat: enhance provider selection display with count and chips for selected providers * feat: update fallback chat providers to use provider settings and add warning for non-list fallback models
This commit is contained in:
@@ -90,6 +90,21 @@ class MockToolExecutor:
|
||||
return generator()
|
||||
|
||||
|
||||
class MockFailingProvider(MockProvider):
|
||||
async def text_chat(self, **kwargs) -> LLMResponse:
|
||||
self.call_count += 1
|
||||
raise RuntimeError("primary provider failed")
|
||||
|
||||
|
||||
class MockErrProvider(MockProvider):
|
||||
async def text_chat(self, **kwargs) -> LLMResponse:
|
||||
self.call_count += 1
|
||||
return LLMResponse(
|
||||
role="err",
|
||||
completion_text="primary provider returned error",
|
||||
)
|
||||
|
||||
|
||||
class MockHooks(BaseAgentRunHooks):
|
||||
"""模拟钩子函数"""
|
||||
|
||||
@@ -321,6 +336,64 @@ async def test_hooks_called_with_max_step(
|
||||
assert mock_hooks.tool_end_called, "on_tool_end应该被调用"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_provider_used_when_primary_raises(
|
||||
runner, provider_request, mock_tool_executor, mock_hooks
|
||||
):
|
||||
primary_provider = MockFailingProvider()
|
||||
fallback_provider = MockProvider()
|
||||
fallback_provider.should_call_tools = False
|
||||
|
||||
await runner.reset(
|
||||
provider=primary_provider,
|
||||
request=provider_request,
|
||||
run_context=ContextWrapper(context=None),
|
||||
tool_executor=mock_tool_executor,
|
||||
agent_hooks=mock_hooks,
|
||||
streaming=False,
|
||||
fallback_providers=[fallback_provider],
|
||||
)
|
||||
|
||||
async for _ in runner.step_until_done(5):
|
||||
pass
|
||||
|
||||
final_resp = runner.get_final_llm_resp()
|
||||
assert final_resp is not None
|
||||
assert final_resp.role == "assistant"
|
||||
assert final_resp.completion_text == "这是我的最终回答"
|
||||
assert primary_provider.call_count == 1
|
||||
assert fallback_provider.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_provider_used_when_primary_returns_err(
|
||||
runner, provider_request, mock_tool_executor, mock_hooks
|
||||
):
|
||||
primary_provider = MockErrProvider()
|
||||
fallback_provider = MockProvider()
|
||||
fallback_provider.should_call_tools = False
|
||||
|
||||
await runner.reset(
|
||||
provider=primary_provider,
|
||||
request=provider_request,
|
||||
run_context=ContextWrapper(context=None),
|
||||
tool_executor=mock_tool_executor,
|
||||
agent_hooks=mock_hooks,
|
||||
streaming=False,
|
||||
fallback_providers=[fallback_provider],
|
||||
)
|
||||
|
||||
async for _ in runner.step_until_done(5):
|
||||
pass
|
||||
|
||||
final_resp = runner.get_final_llm_resp()
|
||||
assert final_resp is not None
|
||||
assert final_resp.role == "assistant"
|
||||
assert final_resp.completion_text == "这是我的最终回答"
|
||||
assert primary_provider.call_count == 1
|
||||
assert fallback_provider.call_count == 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行测试
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
Reference in New Issue
Block a user