feat: add fallback chat model chain in tool loop runner (#5109)

* feat: implement fallback provider support for chat models and update configuration

* feat: enhance provider selection display with count and chips for selected providers

* feat: update fallback chat providers to use provider settings and add warning for non-list fallback models
This commit is contained in:
Soulter
2026-02-15 11:51:34 +08:00
committed by GitHub
parent 0faf109c2a
commit 754144ad99
10 changed files with 394 additions and 19 deletions
@@ -91,6 +91,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
custom_token_counter: TokenCounter | None = None,
custom_compressor: ContextCompressor | None = None,
tool_schema_mode: str | None = "full",
fallback_providers: list[Provider] | None = None,
**kwargs: T.Any,
) -> None:
self.req = request
@@ -120,6 +121,17 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.context_manager = ContextManager(self.context_config)
self.provider = provider
self.fallback_providers: list[Provider] = []
seen_provider_ids: set[str] = {str(provider.provider_config.get("id", ""))}
for fallback_provider in fallback_providers or []:
fallback_id = str(fallback_provider.provider_config.get("id", ""))
if fallback_provider is provider:
continue
if fallback_id and fallback_id in seen_provider_ids:
continue
self.fallback_providers.append(fallback_provider)
if fallback_id:
seen_provider_ids.add(fallback_id)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.tool_executor = tool_executor
@@ -166,16 +178,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.stats = AgentStats()
self.stats.start_time = time.time()
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
async def _iter_llm_responses(
self, *, include_model: bool = True
) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
payload = {
"contexts": self.run_context.messages, # list[Message]
"func_tool": self.req.func_tool,
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
"session_id": self.req.session_id,
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
}
if include_model:
# For primary provider we keep explicit model selection if provided.
payload["model"] = self.req.model
if self.streaming:
stream = self.provider.text_chat_stream(**payload)
async for resp in stream: # type: ignore
@@ -183,6 +198,77 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
else:
yield await self.provider.text_chat(**payload)
async def _iter_llm_responses_with_fallback(
self,
) -> T.AsyncGenerator[LLMResponse, None]:
"""Wrap _iter_llm_responses with provider fallback handling."""
candidates = [self.provider, *self.fallback_providers]
total_candidates = len(candidates)
last_exception: Exception | None = None
last_err_response: LLMResponse | None = None
for idx, candidate in enumerate(candidates):
candidate_id = candidate.provider_config.get("id", "<unknown>")
is_last_candidate = idx == total_candidates - 1
if idx > 0:
logger.warning(
"Switched from %s to fallback chat provider: %s",
self.provider.provider_config.get("id", "<unknown>"),
candidate_id,
)
self.provider = candidate
has_stream_output = False
try:
async for resp in self._iter_llm_responses(include_model=idx == 0):
if resp.is_chunk:
has_stream_output = True
yield resp
continue
if (
resp.role == "err"
and not has_stream_output
and (not is_last_candidate)
):
last_err_response = resp
logger.warning(
"Chat Model %s returns error response, trying fallback to next provider.",
candidate_id,
)
break
yield resp
return
if has_stream_output:
return
except Exception as exc: # noqa: BLE001
last_exception = exc
logger.warning(
"Chat Model %s request error: %s",
candidate_id,
exc,
exc_info=True,
)
continue
if last_err_response:
yield last_err_response
return
if last_exception:
yield LLMResponse(
role="err",
completion_text=(
"All chat models failed: "
f"{type(last_exception).__name__}: {last_exception}"
),
)
return
yield LLMResponse(
role="err",
completion_text="All available chat models are unavailable.",
)
def _simple_print_message_role(self, tag: str = ""):
roles = []
for message in self.run_context.messages:
@@ -215,7 +301,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
)
self._simple_print_message_role("[AftCompact]")
async for llm_response in self._iter_llm_responses():
async for llm_response in self._iter_llm_responses_with_fallback():
if llm_response.is_chunk:
# update ttft
if self.stats.time_to_first_token == 0:
+38
View File
@@ -870,6 +870,41 @@ def _get_compress_provider(
return provider
def _get_fallback_chat_providers(
provider: Provider, plugin_context: Context, provider_settings: dict
) -> list[Provider]:
fallback_ids = provider_settings.get("fallback_chat_models", [])
if not isinstance(fallback_ids, list):
logger.warning(
"fallback_chat_models setting is not a list, skip fallback providers."
)
return []
provider_id = str(provider.provider_config.get("id", ""))
seen_provider_ids: set[str] = {provider_id} if provider_id else set()
fallbacks: list[Provider] = []
for fallback_id in fallback_ids:
if not isinstance(fallback_id, str) or not fallback_id:
continue
if fallback_id in seen_provider_ids:
continue
fallback_provider = plugin_context.get_provider_by_id(fallback_id)
if fallback_provider is None:
logger.warning("Fallback chat provider `%s` not found, skip.", fallback_id)
continue
if not isinstance(fallback_provider, Provider):
logger.warning(
"Fallback chat provider `%s` is invalid type: %s, skip.",
fallback_id,
type(fallback_provider),
)
continue
fallbacks.append(fallback_provider)
seen_provider_ids.add(fallback_id)
return fallbacks
async def build_main_agent(
*,
event: AstrMessageEvent,
@@ -1093,6 +1128,9 @@ async def build_main_agent(
truncate_turns=config.dequeue_context_length,
enforce_max_turns=config.max_context_length,
tool_schema_mode=config.tool_schema_mode,
fallback_providers=_get_fallback_chat_providers(
provider, plugin_context, config.provider_settings
),
)
if apply_reset:
+14 -2
View File
@@ -68,6 +68,7 @@ DEFAULT_CONFIG = {
"provider_settings": {
"enable": True,
"default_provider_id": "",
"fallback_chat_models": [],
"default_image_caption_provider_id": "",
"image_caption_prompt": "Please describe the image using Chinese.",
"provider_pool": ["*"], # "*" 表示使用所有可用的提供者
@@ -2207,6 +2208,10 @@ CONFIG_METADATA_2 = {
"default_provider_id": {
"type": "string",
},
"fallback_chat_models": {
"type": "list",
"items": {"type": "string"},
},
"wake_prefix": {
"type": "string",
},
@@ -2504,15 +2509,22 @@ CONFIG_METADATA_3 = {
},
"ai": {
"description": "模型",
"hint": "当使用非内置 Agent 执行器时,默认聊天模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
"hint": "当使用非内置 Agent 执行器时,默认对话模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
"type": "object",
"items": {
"provider_settings.default_provider_id": {
"description": "默认聊天模型",
"description": "默认对话模型",
"type": "string",
"_special": "select_provider",
"hint": "留空时使用第一个模型",
},
"provider_settings.fallback_chat_models": {
"description": "回退对话模型列表",
"type": "list",
"items": {"type": "string"},
"_special": "select_providers",
"hint": "主聊天模型请求失败时,按顺序切换到这些模型。",
},
"provider_settings.default_image_caption_provider_id": {
"description": "默认图片转述模型",
"type": "string",
@@ -10,6 +10,14 @@
<template v-else-if="itemMeta?._special === 'select_provider_tts'">
<ProviderSelector :model-value="modelValue" @update:model-value="emitUpdate" :provider-type="'text_to_speech'" />
</template>
<template v-else-if="itemMeta?._special === 'select_providers'">
<ProviderSelector
:model-value="modelValue"
@update:model-value="emitUpdate"
:provider-type="'chat_completion'"
:multiple="true"
/>
</template>
<template v-else-if="getSpecialName(itemMeta?._special) === 'select_agent_runner_provider'">
<ProviderSelector
:model-value="modelValue"
@@ -1,16 +1,35 @@
<template>
<div class="d-flex align-center justify-space-between">
<span v-if="!modelValue" style="color: rgb(var(--v-theme-primaryText));">
<span v-if="!hasSelection" style="color: rgb(var(--v-theme-primaryText));">
{{ tm('providerSelector.notSelected') }}
</span>
<span v-else class="provider-name-text">
{{ modelValue }}
<template v-if="multiple">
{{ tm('providerSelector.selectedCount', { count: selectedProviders.length }) }}
</template>
<template v-else>
{{ modelValue }}
</template>
</span>
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
{{ buttonText || tm('providerSelector.buttonText') }}
</v-btn>
</div>
<div v-if="multiple && selectedProviders.length > 0" class="selected-preview mt-2">
<v-chip
v-for="providerId in selectedProviders"
:key="`preview-${providerId}`"
size="x-small"
color="primary"
variant="tonal"
class="mr-1 mb-1"
label
>
{{ providerId }}
</v-chip>
</div>
<!-- Provider Selection Dialog -->
<v-dialog v-model="dialog" max-width="600px">
<v-card>
@@ -33,9 +52,51 @@
<v-card-text class="pa-0" style="max-height: 400px; overflow-y: auto;">
<v-progress-linear v-if="loading" indeterminate color="primary"></v-progress-linear>
<div v-if="multiple && selectedProviders.length > 0" class="pa-3">
<div class="text-caption text-medium-emphasis mb-2">
{{ tm('providerSelector.selectedCount', { count: selectedProviders.length }) }}
</div>
<v-list density="compact" class="selected-order-list">
<v-list-item
v-for="(providerId, index) in selectedProviders"
:key="`selected-${providerId}-${index}`"
rounded="md"
class="ma-1"
>
<v-list-item-title>{{ providerId }}</v-list-item-title>
<template #append>
<div class="d-flex ga-1">
<v-btn
icon="mdi-arrow-up"
size="x-small"
variant="text"
:disabled="index === 0"
@click.stop="moveSelected(index, -1)"
/>
<v-btn
icon="mdi-arrow-down"
size="x-small"
variant="text"
:disabled="index === selectedProviders.length - 1"
@click.stop="moveSelected(index, 1)"
/>
<v-btn
icon="mdi-close"
size="x-small"
variant="text"
@click.stop="removeSelected(providerId)"
/>
</div>
</template>
</v-list-item>
</v-list>
<v-divider class="ma-1"></v-divider>
</div>
<v-list v-if="!loading && providerList.length > 0" density="compact">
<!-- 不选择选项 -->
<v-list-item
v-if="!multiple"
key="none"
value=""
@click="selectProvider({ id: '' })"
@@ -57,7 +118,7 @@
:key="provider.id"
:value="provider.id"
@click="selectProvider(provider)"
:active="selectedProvider === provider.id"
:active="isProviderSelected(provider.id)"
rounded="md"
class="ma-1">
<v-list-item-title>{{ provider.id }}</v-list-item-title>
@@ -67,7 +128,7 @@
</v-list-item-subtitle>
<template v-slot:append>
<v-icon v-if="selectedProvider === provider.id" color="primary">mdi-check-circle</v-icon>
<v-icon v-if="isProviderSelected(provider.id)" color="primary">mdi-check-circle</v-icon>
</template>
</v-list-item>
</v-list>
@@ -121,7 +182,7 @@ import ProviderPage from '@/views/ProviderPage.vue'
const props = defineProps({
modelValue: {
type: String,
type: [String, Array],
default: ''
},
providerType: {
@@ -135,6 +196,10 @@ const props = defineProps({
buttonText: {
type: String,
default: ''
},
multiple: {
type: Boolean,
default: false
}
})
@@ -145,8 +210,16 @@ const dialog = ref(false)
const providerList = ref([])
const loading = ref(false)
const selectedProvider = ref('')
const selectedProviders = ref([])
const providerDrawer = ref(false)
const hasSelection = computed(() => {
if (props.multiple) {
return selectedProviders.value.length > 0
}
return Boolean(props.modelValue)
})
const defaultTab = computed(() => {
if (props.providerType === 'agent_runner' && props.providerSubtype) {
return `select_agent_runner_provider:${props.providerSubtype}`
@@ -156,7 +229,13 @@ const defaultTab = computed(() => {
// modelValue selectedProvider
watch(() => props.modelValue, (newValue) => {
selectedProvider.value = newValue || ''
if (props.multiple) {
selectedProviders.value = Array.isArray(newValue)
? [...newValue.filter((v) => typeof v === 'string' && v)]
: []
return
}
selectedProvider.value = typeof newValue === 'string' ? newValue : ''
}, { immediate: true })
watch(providerDrawer, (isOpen, wasOpen) => {
@@ -166,7 +245,13 @@ watch(providerDrawer, (isOpen, wasOpen) => {
})
async function openDialog() {
selectedProvider.value = props.modelValue || ''
if (props.multiple) {
selectedProviders.value = Array.isArray(props.modelValue)
? [...props.modelValue.filter((v) => typeof v === 'string' && v)]
: []
} else {
selectedProvider.value = typeof props.modelValue === 'string' ? props.modelValue : ''
}
dialog.value = true
await loadProviders()
}
@@ -205,19 +290,72 @@ function matchesProviderSubtype(provider, subtype) {
}
function selectProvider(provider) {
if (props.multiple) {
if (!provider.id) {
selectedProviders.value = []
return
}
const idx = selectedProviders.value.indexOf(provider.id)
if (idx >= 0) {
selectedProviders.value.splice(idx, 1)
} else {
selectedProviders.value.push(provider.id)
}
return
}
selectedProvider.value = provider.id
}
function confirmSelection() {
emit('update:modelValue', selectedProvider.value)
if (props.multiple) {
emit('update:modelValue', [...selectedProviders.value])
} else {
emit('update:modelValue', selectedProvider.value)
}
dialog.value = false
}
function cancelSelection() {
selectedProvider.value = props.modelValue || ''
if (props.multiple) {
selectedProviders.value = Array.isArray(props.modelValue)
? [...props.modelValue.filter((v) => typeof v === 'string' && v)]
: []
} else {
selectedProvider.value = typeof props.modelValue === 'string' ? props.modelValue : ''
}
dialog.value = false
}
function isProviderSelected(providerId) {
if (props.multiple) {
return selectedProviders.value.includes(providerId)
}
return selectedProvider.value === providerId
}
function removeSelected(providerId) {
const idx = selectedProviders.value.indexOf(providerId)
if (idx >= 0) {
selectedProviders.value.splice(idx, 1)
}
}
function moveSelected(index, delta) {
const targetIndex = index + delta
if (
targetIndex < 0
|| targetIndex >= selectedProviders.value.length
|| index < 0
|| index >= selectedProviders.value.length
) {
return
}
const copied = [...selectedProviders.value]
const [item] = copied.splice(index, 1)
copied.splice(targetIndex, 0, item)
selectedProviders.value = copied
}
function openProviderDrawer() {
providerDrawer.value = true
}
@@ -236,6 +374,16 @@ function closeProviderDrawer() {
display: inline-block;
}
.selected-preview {
width: 100%;
max-width: 100%;
}
.selected-order-list {
background: rgba(var(--v-theme-surface-variant), 0.15);
border-radius: 10px;
}
.v-list-item {
transition: all 0.2s ease;
}
@@ -45,7 +45,8 @@
"unknownType": "Unknown type",
"createProvider": "Create Provider",
"manageProviders": "Provider Management",
"selectProviderPool": "Select Provider Pool..."
"selectProviderPool": "Select Provider Pool...",
"selectedCount": "{count} provider(s) selected"
},
"personaSelector": {
"notSelected": "Not selected",
@@ -37,6 +37,10 @@
"description": "Default Chat Model",
"hint": "Uses the first model when left empty"
},
"fallback_chat_models": {
"description": "Fallback chat model IDs",
"hint": "When the primary chat model request fails, fallback to these chat models in order."
},
"default_image_caption_provider_id": {
"description": "Default Image Caption Model",
"hint": "Leave empty to disable; useful for non-multimodal models"
@@ -45,7 +45,8 @@
"unknownType": "未知类型",
"createProvider": "创建提供商",
"manageProviders": "提供商管理",
"selectProviderPool": "选择提供商池..."
"selectProviderPool": "选择提供商池...",
"selectedCount": "已选择 {count} 个提供商"
},
"personaSelector": {
"notSelected": "未选择",
@@ -31,12 +31,16 @@
},
"ai": {
"description": "模型",
"hint": "当使用非内置 Agent 执行器时,默认聊天模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
"hint": "当使用非内置 Agent 执行器时,默认对话模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
"provider_settings": {
"default_provider_id": {
"description": "默认聊天模型",
"description": "默认对话模型",
"hint": "留空时使用第一个模型"
},
"fallback_chat_models": {
"description": "回退对话模型列表",
"hint": "主对话模型请求失败时,按顺序切换到这些对话模型。"
},
"default_image_caption_provider_id": {
"description": "默认图片转述模型",
"hint": "留空代表不使用,可用于非多模态模型"
+73
View File
@@ -90,6 +90,21 @@ class MockToolExecutor:
return generator()
class MockFailingProvider(MockProvider):
async def text_chat(self, **kwargs) -> LLMResponse:
self.call_count += 1
raise RuntimeError("primary provider failed")
class MockErrProvider(MockProvider):
async def text_chat(self, **kwargs) -> LLMResponse:
self.call_count += 1
return LLMResponse(
role="err",
completion_text="primary provider returned error",
)
class MockHooks(BaseAgentRunHooks):
"""模拟钩子函数"""
@@ -321,6 +336,64 @@ async def test_hooks_called_with_max_step(
assert mock_hooks.tool_end_called, "on_tool_end应该被调用"
@pytest.mark.asyncio
async def test_fallback_provider_used_when_primary_raises(
runner, provider_request, mock_tool_executor, mock_hooks
):
primary_provider = MockFailingProvider()
fallback_provider = MockProvider()
fallback_provider.should_call_tools = False
await runner.reset(
provider=primary_provider,
request=provider_request,
run_context=ContextWrapper(context=None),
tool_executor=mock_tool_executor,
agent_hooks=mock_hooks,
streaming=False,
fallback_providers=[fallback_provider],
)
async for _ in runner.step_until_done(5):
pass
final_resp = runner.get_final_llm_resp()
assert final_resp is not None
assert final_resp.role == "assistant"
assert final_resp.completion_text == "这是我的最终回答"
assert primary_provider.call_count == 1
assert fallback_provider.call_count == 1
@pytest.mark.asyncio
async def test_fallback_provider_used_when_primary_returns_err(
runner, provider_request, mock_tool_executor, mock_hooks
):
primary_provider = MockErrProvider()
fallback_provider = MockProvider()
fallback_provider.should_call_tools = False
await runner.reset(
provider=primary_provider,
request=provider_request,
run_context=ContextWrapper(context=None),
tool_executor=mock_tool_executor,
agent_hooks=mock_hooks,
streaming=False,
fallback_providers=[fallback_provider],
)
async for _ in runner.step_until_done(5):
pass
final_resp = runner.get_final_llm_resp()
assert final_resp is not None
assert final_resp.role == "assistant"
assert final_resp.completion_text == "这是我的最终回答"
assert primary_provider.call_count == 1
assert fallback_provider.call_count == 1
if __name__ == "__main__":
# 运行测试
pytest.main([__file__, "-v"])