feat: add fallback chat model chain in tool loop runner (#5109)
* feat: implement fallback provider support for chat models and update configuration * feat: enhance provider selection display with count and chips for selected providers * feat: update fallback chat providers to use provider settings and add warning for non-list fallback models
This commit is contained in:
@@ -91,6 +91,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
custom_token_counter: TokenCounter | None = None,
|
||||
custom_compressor: ContextCompressor | None = None,
|
||||
tool_schema_mode: str | None = "full",
|
||||
fallback_providers: list[Provider] | None = None,
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
@@ -120,6 +121,17 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self.context_manager = ContextManager(self.context_config)
|
||||
|
||||
self.provider = provider
|
||||
self.fallback_providers: list[Provider] = []
|
||||
seen_provider_ids: set[str] = {str(provider.provider_config.get("id", ""))}
|
||||
for fallback_provider in fallback_providers or []:
|
||||
fallback_id = str(fallback_provider.provider_config.get("id", ""))
|
||||
if fallback_provider is provider:
|
||||
continue
|
||||
if fallback_id and fallback_id in seen_provider_ids:
|
||||
continue
|
||||
self.fallback_providers.append(fallback_provider)
|
||||
if fallback_id:
|
||||
seen_provider_ids.add(fallback_id)
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
self.tool_executor = tool_executor
|
||||
@@ -166,16 +178,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self.stats = AgentStats()
|
||||
self.stats.start_time = time.time()
|
||||
|
||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
async def _iter_llm_responses(
|
||||
self, *, include_model: bool = True
|
||||
) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
"""Yields chunks *and* a final LLMResponse."""
|
||||
payload = {
|
||||
"contexts": self.run_context.messages, # list[Message]
|
||||
"func_tool": self.req.func_tool,
|
||||
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
|
||||
"session_id": self.req.session_id,
|
||||
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
|
||||
}
|
||||
|
||||
if include_model:
|
||||
# For primary provider we keep explicit model selection if provided.
|
||||
payload["model"] = self.req.model
|
||||
if self.streaming:
|
||||
stream = self.provider.text_chat_stream(**payload)
|
||||
async for resp in stream: # type: ignore
|
||||
@@ -183,6 +198,77 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
else:
|
||||
yield await self.provider.text_chat(**payload)
|
||||
|
||||
async def _iter_llm_responses_with_fallback(
    self,
) -> T.AsyncGenerator[LLMResponse, None]:
    """Wrap _iter_llm_responses with provider fallback handling.

    Candidates are tried in order: the current ``self.provider`` first, then
    every entry of ``self.fallback_providers``. ``self.provider`` is rebound
    to the candidate being tried. A candidate is abandoned for the next one
    when it raises, or when it returns a ``role == "err"`` response while a
    further candidate remains, but only as long as no chunk has been
    forwarded to the caller yet.

    Yields:
        The chunks and final LLMResponse of the first candidate that
        produces output. When every candidate fails, a single
        ``role="err"`` LLMResponse is yielded instead.
    """
    # self.provider is rebound below, so on a later call it can already be
    # one of the fallbacks. Drop repeated objects so a failing provider is
    # not retried twice in a row.
    candidates = []
    for provider in (self.provider, *self.fallback_providers):
        if not any(provider is known for known in candidates):
            candidates.append(provider)
    total_candidates = len(candidates)
    last_exception: Exception | None = None
    last_err_response: LLMResponse | None = None

    for idx, candidate in enumerate(candidates):
        candidate_id = candidate.provider_config.get("id", "<unknown>")
        is_last_candidate = idx == total_candidates - 1
        if idx > 0:
            logger.warning(
                "Switched from %s to fallback chat provider: %s",
                self.provider.provider_config.get("id", "<unknown>"),
                candidate_id,
            )
        self.provider = candidate
        has_stream_output = False
        try:
            # NOTE(review): include_model is keyed on position, so once an
            # earlier call has rebound self.provider to a fallback, that
            # fallback is treated as index 0 here - confirm this is intended.
            async for resp in self._iter_llm_responses(include_model=idx == 0):
                if resp.is_chunk:
                    has_stream_output = True
                    yield resp
                    continue

                if (
                    resp.role == "err"
                    and not has_stream_output
                    and (not is_last_candidate)
                ):
                    last_err_response = resp
                    logger.warning(
                        "Chat Model %s returns error response, trying fallback to next provider.",
                        candidate_id,
                    )
                    break

                yield resp
                return

            # The stream ended after emitting chunks only: nothing left to do.
            if has_stream_output:
                return
        except Exception as exc:  # noqa: BLE001
            logger.warning(
                "Chat Model %s request error: %s",
                candidate_id,
                exc,
                exc_info=True,
            )
            if has_stream_output:
                # Chunks of this answer were already forwarded; switching
                # providers now would replay the answer from the start.
                # Surface the failure instead, mirroring how an "err"
                # response after partial output is passed through above.
                yield LLMResponse(
                    role="err",
                    completion_text=(
                        f"Chat model {candidate_id} failed mid-stream: "
                        f"{type(exc).__name__}: {exc}"
                    ),
                )
                return
            last_exception = exc
            continue

    if last_err_response:
        yield last_err_response
        return
    if last_exception:
        yield LLMResponse(
            role="err",
            completion_text=(
                "All chat models failed: "
                f"{type(last_exception).__name__}: {last_exception}"
            ),
        )
        return
    yield LLMResponse(
        role="err",
        completion_text="All available chat models are unavailable.",
    )
|
||||
|
||||
def _simple_print_message_role(self, tag: str = ""):
|
||||
roles = []
|
||||
for message in self.run_context.messages:
|
||||
@@ -215,7 +301,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
self._simple_print_message_role("[AftCompact]")
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
async for llm_response in self._iter_llm_responses_with_fallback():
|
||||
if llm_response.is_chunk:
|
||||
# update ttft
|
||||
if self.stats.time_to_first_token == 0:
|
||||
|
||||
@@ -870,6 +870,41 @@ def _get_compress_provider(
|
||||
return provider
|
||||
|
||||
|
||||
def _get_fallback_chat_providers(
    provider: Provider, plugin_context: Context, provider_settings: dict
) -> list[Provider]:
    """Resolve the configured fallback chat provider ids into providers.

    Reads ``provider_settings["fallback_chat_models"]`` and looks each id up
    through ``plugin_context.get_provider_by_id``. Entries that are not
    non-empty strings, that repeat an earlier id, or that equal the id of
    ``provider`` are skipped silently. Ids that resolve to ``None`` or to an
    object that is not a ``Provider`` are skipped with a warning.

    Returns:
        The resolved providers in configuration order, or an empty list when
        the setting is not a list.
    """
    configured = provider_settings.get("fallback_chat_models", [])
    if not isinstance(configured, list):
        logger.warning(
            "fallback_chat_models setting is not a list, skip fallback providers."
        )
        return []

    # Seed the de-duplication set with the primary provider's own id.
    visited: set[str] = set()
    primary_id = str(provider.provider_config.get("id", ""))
    if primary_id:
        visited.add(primary_id)

    resolved: list[Provider] = []
    for candidate_id in configured:
        usable_id = isinstance(candidate_id, str) and bool(candidate_id)
        if not usable_id or candidate_id in visited:
            continue

        candidate = plugin_context.get_provider_by_id(candidate_id)
        if candidate is None:
            logger.warning("Fallback chat provider `%s` not found, skip.", candidate_id)
        elif not isinstance(candidate, Provider):
            logger.warning(
                "Fallback chat provider `%s` is invalid type: %s, skip.",
                candidate_id,
                type(candidate),
            )
        else:
            resolved.append(candidate)
            visited.add(candidate_id)
    return resolved
|
||||
|
||||
|
||||
async def build_main_agent(
|
||||
*,
|
||||
event: AstrMessageEvent,
|
||||
@@ -1093,6 +1128,9 @@ async def build_main_agent(
|
||||
truncate_turns=config.dequeue_context_length,
|
||||
enforce_max_turns=config.max_context_length,
|
||||
tool_schema_mode=config.tool_schema_mode,
|
||||
fallback_providers=_get_fallback_chat_providers(
|
||||
provider, plugin_context, config.provider_settings
|
||||
),
|
||||
)
|
||||
|
||||
if apply_reset:
|
||||
|
||||
@@ -68,6 +68,7 @@ DEFAULT_CONFIG = {
|
||||
"provider_settings": {
|
||||
"enable": True,
|
||||
"default_provider_id": "",
|
||||
"fallback_chat_models": [],
|
||||
"default_image_caption_provider_id": "",
|
||||
"image_caption_prompt": "Please describe the image using Chinese.",
|
||||
"provider_pool": ["*"], # "*" 表示使用所有可用的提供者
|
||||
@@ -2207,6 +2208,10 @@ CONFIG_METADATA_2 = {
|
||||
"default_provider_id": {
|
||||
"type": "string",
|
||||
},
|
||||
"fallback_chat_models": {
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"wake_prefix": {
|
||||
"type": "string",
|
||||
},
|
||||
@@ -2504,15 +2509,22 @@ CONFIG_METADATA_3 = {
|
||||
},
|
||||
"ai": {
|
||||
"description": "模型",
|
||||
"hint": "当使用非内置 Agent 执行器时,默认聊天模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
|
||||
"hint": "当使用非内置 Agent 执行器时,默认对话模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"provider_settings.default_provider_id": {
|
||||
"description": "默认聊天模型",
|
||||
"description": "默认对话模型",
|
||||
"type": "string",
|
||||
"_special": "select_provider",
|
||||
"hint": "留空时使用第一个模型",
|
||||
},
|
||||
"provider_settings.fallback_chat_models": {
|
||||
"description": "回退对话模型列表",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"_special": "select_providers",
|
||||
"hint": "主聊天模型请求失败时,按顺序切换到这些模型。",
|
||||
},
|
||||
"provider_settings.default_image_caption_provider_id": {
|
||||
"description": "默认图片转述模型",
|
||||
"type": "string",
|
||||
|
||||
@@ -10,6 +10,14 @@
|
||||
<template v-else-if="itemMeta?._special === 'select_provider_tts'">
|
||||
<ProviderSelector :model-value="modelValue" @update:model-value="emitUpdate" :provider-type="'text_to_speech'" />
|
||||
</template>
|
||||
<template v-else-if="itemMeta?._special === 'select_providers'">
|
||||
<ProviderSelector
|
||||
:model-value="modelValue"
|
||||
@update:model-value="emitUpdate"
|
||||
:provider-type="'chat_completion'"
|
||||
:multiple="true"
|
||||
/>
|
||||
</template>
|
||||
<template v-else-if="getSpecialName(itemMeta?._special) === 'select_agent_runner_provider'">
|
||||
<ProviderSelector
|
||||
:model-value="modelValue"
|
||||
|
||||
@@ -1,16 +1,35 @@
|
||||
<template>
|
||||
<div class="d-flex align-center justify-space-between">
|
||||
<span v-if="!modelValue" style="color: rgb(var(--v-theme-primaryText));">
|
||||
<span v-if="!hasSelection" style="color: rgb(var(--v-theme-primaryText));">
|
||||
{{ tm('providerSelector.notSelected') }}
|
||||
</span>
|
||||
<span v-else class="provider-name-text">
|
||||
{{ modelValue }}
|
||||
<template v-if="multiple">
|
||||
{{ tm('providerSelector.selectedCount', { count: selectedProviders.length }) }}
|
||||
</template>
|
||||
<template v-else>
|
||||
{{ modelValue }}
|
||||
</template>
|
||||
</span>
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
|
||||
{{ buttonText || tm('providerSelector.buttonText') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<div v-if="multiple && selectedProviders.length > 0" class="selected-preview mt-2">
|
||||
<v-chip
|
||||
v-for="providerId in selectedProviders"
|
||||
:key="`preview-${providerId}`"
|
||||
size="x-small"
|
||||
color="primary"
|
||||
variant="tonal"
|
||||
class="mr-1 mb-1"
|
||||
label
|
||||
>
|
||||
{{ providerId }}
|
||||
</v-chip>
|
||||
</div>
|
||||
|
||||
<!-- Provider Selection Dialog -->
|
||||
<v-dialog v-model="dialog" max-width="600px">
|
||||
<v-card>
|
||||
@@ -33,9 +52,51 @@
|
||||
<v-card-text class="pa-0" style="max-height: 400px; overflow-y: auto;">
|
||||
<v-progress-linear v-if="loading" indeterminate color="primary"></v-progress-linear>
|
||||
|
||||
<div v-if="multiple && selectedProviders.length > 0" class="pa-3">
|
||||
<div class="text-caption text-medium-emphasis mb-2">
|
||||
{{ tm('providerSelector.selectedCount', { count: selectedProviders.length }) }}
|
||||
</div>
|
||||
<v-list density="compact" class="selected-order-list">
|
||||
<v-list-item
|
||||
v-for="(providerId, index) in selectedProviders"
|
||||
:key="`selected-${providerId}-${index}`"
|
||||
rounded="md"
|
||||
class="ma-1"
|
||||
>
|
||||
<v-list-item-title>{{ providerId }}</v-list-item-title>
|
||||
<template #append>
|
||||
<div class="d-flex ga-1">
|
||||
<v-btn
|
||||
icon="mdi-arrow-up"
|
||||
size="x-small"
|
||||
variant="text"
|
||||
:disabled="index === 0"
|
||||
@click.stop="moveSelected(index, -1)"
|
||||
/>
|
||||
<v-btn
|
||||
icon="mdi-arrow-down"
|
||||
size="x-small"
|
||||
variant="text"
|
||||
:disabled="index === selectedProviders.length - 1"
|
||||
@click.stop="moveSelected(index, 1)"
|
||||
/>
|
||||
<v-btn
|
||||
icon="mdi-close"
|
||||
size="x-small"
|
||||
variant="text"
|
||||
@click.stop="removeSelected(providerId)"
|
||||
/>
|
||||
</div>
|
||||
</template>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
<v-divider class="ma-1"></v-divider>
|
||||
</div>
|
||||
|
||||
<v-list v-if="!loading && providerList.length > 0" density="compact">
|
||||
<!-- 不选择选项 -->
|
||||
<v-list-item
|
||||
v-if="!multiple"
|
||||
key="none"
|
||||
value=""
|
||||
@click="selectProvider({ id: '' })"
|
||||
@@ -57,7 +118,7 @@
|
||||
:key="provider.id"
|
||||
:value="provider.id"
|
||||
@click="selectProvider(provider)"
|
||||
:active="selectedProvider === provider.id"
|
||||
:active="isProviderSelected(provider.id)"
|
||||
rounded="md"
|
||||
class="ma-1">
|
||||
<v-list-item-title>{{ provider.id }}</v-list-item-title>
|
||||
@@ -67,7 +128,7 @@
|
||||
</v-list-item-subtitle>
|
||||
|
||||
<template v-slot:append>
|
||||
<v-icon v-if="selectedProvider === provider.id" color="primary">mdi-check-circle</v-icon>
|
||||
<v-icon v-if="isProviderSelected(provider.id)" color="primary">mdi-check-circle</v-icon>
|
||||
</template>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
@@ -121,7 +182,7 @@ import ProviderPage from '@/views/ProviderPage.vue'
|
||||
|
||||
const props = defineProps({
|
||||
modelValue: {
|
||||
type: String,
|
||||
type: [String, Array],
|
||||
default: ''
|
||||
},
|
||||
providerType: {
|
||||
@@ -135,6 +196,10 @@ const props = defineProps({
|
||||
buttonText: {
|
||||
type: String,
|
||||
default: ''
|
||||
},
|
||||
multiple: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
}
|
||||
})
|
||||
|
||||
@@ -145,8 +210,16 @@ const dialog = ref(false)
|
||||
const providerList = ref([])
|
||||
const loading = ref(false)
|
||||
const selectedProvider = ref('')
|
||||
const selectedProviders = ref([])
|
||||
const providerDrawer = ref(false)
|
||||
|
||||
// True when something is selected: a non-empty list in multi-select mode,
// otherwise a truthy modelValue.
const hasSelection = computed(() => (
  props.multiple
    ? selectedProviders.value.length > 0
    : Boolean(props.modelValue)
))
|
||||
|
||||
const defaultTab = computed(() => {
|
||||
if (props.providerType === 'agent_runner' && props.providerSubtype) {
|
||||
return `select_agent_runner_provider:${props.providerSubtype}`
|
||||
@@ -156,7 +229,13 @@ const defaultTab = computed(() => {
|
||||
|
||||
// Watch modelValue and mirror it into the local selection state.
watch(() => props.modelValue, (newValue) => {
  if (props.multiple) {
    // Multi-select: keep only non-empty string ids; a non-array value
    // clears the selection. selectedProvider is left untouched so it never
    // ends up holding an array.
    selectedProviders.value = Array.isArray(newValue)
      ? [...newValue.filter((v) => typeof v === 'string' && v)]
      : []
    return
  }
  // Single-select: anything that is not a string counts as "nothing selected".
  selectedProvider.value = typeof newValue === 'string' ? newValue : ''
}, { immediate: true })
|
||||
|
||||
watch(providerDrawer, (isOpen, wasOpen) => {
|
||||
@@ -166,7 +245,13 @@ watch(providerDrawer, (isOpen, wasOpen) => {
|
||||
})
|
||||
|
||||
// Reset the pending selection from modelValue, open the dialog, then load
// the provider list.
async function openDialog() {
  if (props.multiple) {
    // Multi-select: start from a copy holding only non-empty string ids.
    // selectedProvider is not assigned here, so it never receives the array.
    selectedProviders.value = Array.isArray(props.modelValue)
      ? [...props.modelValue.filter((v) => typeof v === 'string' && v)]
      : []
  } else {
    selectedProvider.value = typeof props.modelValue === 'string' ? props.modelValue : ''
  }
  dialog.value = true
  await loadProviders()
}
|
||||
@@ -205,19 +290,72 @@ function matchesProviderSubtype(provider, subtype) {
|
||||
}
|
||||
|
||||
// Handle a click on a provider entry. Single-select stores the id;
// multi-select toggles the id in the ordered list (an empty id clears it).
function selectProvider(provider) {
  if (!props.multiple) {
    selectedProvider.value = provider.id
    return
  }

  if (!provider.id) {
    selectedProviders.value = []
    return
  }

  const position = selectedProviders.value.indexOf(provider.id)
  if (position < 0) {
    selectedProviders.value.push(provider.id)
    return
  }
  selectedProviders.value.splice(position, 1)
}
|
||||
|
||||
// Publish the pending selection through update:modelValue and close the
// dialog. Exactly one event is emitted: a copy of the id array in
// multi-select mode, the single id string otherwise.
function confirmSelection() {
  if (props.multiple) {
    emit('update:modelValue', [...selectedProviders.value])
  } else {
    emit('update:modelValue', selectedProvider.value)
  }
  dialog.value = false
}
|
||||
|
||||
// Discard the pending selection by restoring it from modelValue, then close
// the dialog.
function cancelSelection() {
  if (props.multiple) {
    // Multi-select: restore a copy holding only non-empty string ids.
    // selectedProvider is not assigned here, so it never receives the array.
    selectedProviders.value = Array.isArray(props.modelValue)
      ? [...props.modelValue.filter((v) => typeof v === 'string' && v)]
      : []
  } else {
    selectedProvider.value = typeof props.modelValue === 'string' ? props.modelValue : ''
  }
  dialog.value = false
}
|
||||
|
||||
// Report whether providerId is part of the pending selection in the
// current mode.
function isProviderSelected(providerId) {
  return props.multiple
    ? selectedProviders.value.includes(providerId)
    : selectedProvider.value === providerId
}
|
||||
|
||||
// Remove providerId from the pending multi-select list; no-op when absent.
function removeSelected(providerId) {
  const position = selectedProviders.value.indexOf(providerId)
  if (position < 0) {
    return
  }
  selectedProviders.value.splice(position, 1)
}
|
||||
|
||||
// Move the entry at `index` by `delta` positions within the pending
// multi-select list. Does nothing when either position is out of range.
function moveSelected(index, delta) {
  const count = selectedProviders.value.length
  const destination = index + delta
  const outOfRange = (position) => position < 0 || position >= count
  if (outOfRange(destination) || outOfRange(index)) {
    return
  }

  // Work on a copy and assign it back as a new array.
  const reordered = [...selectedProviders.value]
  const [moved] = reordered.splice(index, 1)
  reordered.splice(destination, 0, moved)
  selectedProviders.value = reordered
}
|
||||
|
||||
function openProviderDrawer() {
|
||||
providerDrawer.value = true
|
||||
}
|
||||
@@ -236,6 +374,16 @@ function closeProviderDrawer() {
|
||||
display: inline-block;
|
||||
}
|
||||
|
||||
.selected-preview {
|
||||
width: 100%;
|
||||
max-width: 100%;
|
||||
}
|
||||
|
||||
.selected-order-list {
|
||||
background: rgba(var(--v-theme-surface-variant), 0.15);
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
.v-list-item {
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
@@ -45,7 +45,8 @@
|
||||
"unknownType": "Unknown type",
|
||||
"createProvider": "Create Provider",
|
||||
"manageProviders": "Provider Management",
|
||||
"selectProviderPool": "Select Provider Pool..."
|
||||
"selectProviderPool": "Select Provider Pool...",
|
||||
"selectedCount": "{count} provider(s) selected"
|
||||
},
|
||||
"personaSelector": {
|
||||
"notSelected": "Not selected",
|
||||
|
||||
@@ -37,6 +37,10 @@
|
||||
"description": "Default Chat Model",
|
||||
"hint": "Uses the first model when left empty"
|
||||
},
|
||||
"fallback_chat_models": {
|
||||
"description": "Fallback chat model IDs",
|
||||
"hint": "When the primary chat model request fails, fallback to these chat models in order."
|
||||
},
|
||||
"default_image_caption_provider_id": {
|
||||
"description": "Default Image Caption Model",
|
||||
"hint": "Leave empty to disable; useful for non-multimodal models"
|
||||
|
||||
@@ -45,7 +45,8 @@
|
||||
"unknownType": "未知类型",
|
||||
"createProvider": "创建提供商",
|
||||
"manageProviders": "提供商管理",
|
||||
"selectProviderPool": "选择提供商池..."
|
||||
"selectProviderPool": "选择提供商池...",
|
||||
"selectedCount": "已选择 {count} 个提供商"
|
||||
},
|
||||
"personaSelector": {
|
||||
"notSelected": "未选择",
|
||||
|
||||
@@ -31,12 +31,16 @@
|
||||
},
|
||||
"ai": {
|
||||
"description": "模型",
|
||||
"hint": "当使用非内置 Agent 执行器时,默认聊天模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
|
||||
"hint": "当使用非内置 Agent 执行器时,默认对话模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
|
||||
"provider_settings": {
|
||||
"default_provider_id": {
|
||||
"description": "默认聊天模型",
|
||||
"description": "默认对话模型",
|
||||
"hint": "留空时使用第一个模型"
|
||||
},
|
||||
"fallback_chat_models": {
|
||||
"description": "回退对话模型列表",
|
||||
"hint": "主对话模型请求失败时,按顺序切换到这些对话模型。"
|
||||
},
|
||||
"default_image_caption_provider_id": {
|
||||
"description": "默认图片转述模型",
|
||||
"hint": "留空代表不使用,可用于非多模态模型"
|
||||
|
||||
@@ -90,6 +90,21 @@ class MockToolExecutor:
|
||||
return generator()
|
||||
|
||||
|
||||
class MockFailingProvider(MockProvider):
    """Provider double whose non-streaming chat call always raises."""

    async def text_chat(self, **kwargs) -> LLMResponse:
        """Count the call, then fail with a RuntimeError."""
        self.call_count += 1
        failure = RuntimeError("primary provider failed")
        raise failure
|
||||
|
||||
|
||||
class MockErrProvider(MockProvider):
    """Provider double whose non-streaming chat call returns an "err" response."""

    async def text_chat(self, **kwargs) -> LLMResponse:
        """Count the call, then hand back an error-role LLMResponse."""
        self.call_count += 1
        error_response = LLMResponse(
            role="err",
            completion_text="primary provider returned error",
        )
        return error_response
|
||||
|
||||
|
||||
class MockHooks(BaseAgentRunHooks):
|
||||
"""模拟钩子函数"""
|
||||
|
||||
@@ -321,6 +336,64 @@ async def test_hooks_called_with_max_step(
|
||||
assert mock_hooks.tool_end_called, "on_tool_end应该被调用"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fallback_provider_used_when_primary_raises(
    runner, provider_request, mock_tool_executor, mock_hooks
):
    """A primary provider that raises hands the request to the fallback."""
    failing_primary = MockFailingProvider()
    healthy_fallback = MockProvider()
    healthy_fallback.should_call_tools = False

    reset_kwargs = {
        "provider": failing_primary,
        "request": provider_request,
        "run_context": ContextWrapper(context=None),
        "tool_executor": mock_tool_executor,
        "agent_hooks": mock_hooks,
        "streaming": False,
        "fallback_providers": [healthy_fallback],
    }
    await runner.reset(**reset_kwargs)

    # Drain the run; only the final state matters here.
    async for _step in runner.step_until_done(5):
        continue

    final_resp = runner.get_final_llm_resp()
    assert final_resp is not None
    assert final_resp.role == "assistant"
    assert final_resp.completion_text == "这是我的最终回答"
    # Each provider is asked exactly once: primary fails, fallback answers.
    assert failing_primary.call_count == 1
    assert healthy_fallback.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fallback_provider_used_when_primary_returns_err(
    runner, provider_request, mock_tool_executor, mock_hooks
):
    """A primary provider that returns an "err" response triggers the fallback."""
    erroring_primary = MockErrProvider()
    healthy_fallback = MockProvider()
    healthy_fallback.should_call_tools = False

    reset_kwargs = {
        "provider": erroring_primary,
        "request": provider_request,
        "run_context": ContextWrapper(context=None),
        "tool_executor": mock_tool_executor,
        "agent_hooks": mock_hooks,
        "streaming": False,
        "fallback_providers": [healthy_fallback],
    }
    await runner.reset(**reset_kwargs)

    # Drain the run; only the final state matters here.
    async for _step in runner.step_until_done(5):
        continue

    final_resp = runner.get_final_llm_resp()
    assert final_resp is not None
    assert final_resp.role == "assistant"
    assert final_resp.completion_text == "这是我的最终回答"
    # Each provider is asked exactly once: primary errs, fallback answers.
    assert erroring_primary.call_count == 1
    assert healthy_fallback.call_count == 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行测试
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
Reference in New Issue
Block a user