Compare commits

...

1 Commits

Author SHA1 Message Date
Soulter 572689b416 feat: implement retry mechanism for model requests in anthropic and openai providers 2026-02-17 13:18:00 +08:00
4 changed files with 84 additions and 114 deletions
@@ -22,6 +22,7 @@ from astrbot.core.utils.network_utils import (
)
from ..register import register_provider_adapter
from .default import with_model_request_retry
@register_provider_adapter(
@@ -204,6 +205,7 @@ class ProviderAnthropic(Provider):
if usage.output_tokens is not None:
token_usage.output = usage.output_tokens
@with_model_request_retry()
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
@@ -265,6 +267,10 @@ class ProviderAnthropic(Provider):
return llm_response
@with_model_request_retry()
async def _create_message_stream(self, payloads: dict, extra_body: dict):
return self.client.messages.stream(**payloads, extra_body=extra_body)
async def _query_stream(
self,
payloads: dict,
@@ -293,9 +299,8 @@ class ProviderAnthropic(Provider):
"type": "enabled",
}
async with self.client.messages.stream(
**payloads, extra_body=extra_body
) as stream:
stream_ctx = await self._create_message_stream(payloads, extra_body)
async with stream_ctx as stream:
assert isinstance(stream, anthropic.AsyncMessageStream)
async for event in stream:
if event.type == "message_start":
+38
View File
@@ -0,0 +1,38 @@
from tenacity import (
AsyncRetrying,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
MODEL_REQUEST_RETRY_ATTEMPTS = 5
MODEL_REQUEST_RETRY_WAIT_MAX_SECONDS = 15
MODEL_REQUEST_RETRY_WAIT_MIN_SECONDS = 1
MODEL_REQUEST_RETRY_WAIT_MULTIPLIER = 1
def with_model_request_retry():
return retry(
retry=retry_if_exception_type(Exception),
stop=stop_after_attempt(MODEL_REQUEST_RETRY_ATTEMPTS),
wait=wait_exponential(
multiplier=MODEL_REQUEST_RETRY_WAIT_MULTIPLIER,
min=MODEL_REQUEST_RETRY_WAIT_MIN_SECONDS,
max=MODEL_REQUEST_RETRY_WAIT_MAX_SECONDS,
),
reraise=True,
)
def get_model_request_async_retrying() -> AsyncRetrying:
return AsyncRetrying(
retry=retry_if_exception_type(Exception),
stop=stop_after_attempt(MODEL_REQUEST_RETRY_ATTEMPTS),
wait=wait_exponential(
multiplier=MODEL_REQUEST_RETRY_WAIT_MULTIPLIER,
min=MODEL_REQUEST_RETRY_WAIT_MIN_SECONDS,
max=MODEL_REQUEST_RETRY_WAIT_MAX_SECONDS,
),
reraise=True,
)
+16 -24
View File
@@ -21,6 +21,7 @@ from astrbot.core.utils.io import download_image_by_url
from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure
from ..register import register_provider_adapter
from .default import get_model_request_async_retrying, with_model_request_retry
class SuppressNonTextPartsWarning(logging.Filter):
@@ -513,6 +514,7 @@ class ProviderGoogleGenAI(Provider):
llm_response.reasoning_signature = base64.b64encode(ts).decode("utf-8")
return MessageChain(chain=chain)
@with_model_request_retry()
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
"""非流式请求 Gemini API"""
system_instruction = next(
@@ -601,6 +603,17 @@ class ProviderGoogleGenAI(Provider):
self,
payloads: dict,
tools: ToolSet | None,
) -> AsyncGenerator[LLMResponse, None]:
async for attempt in get_model_request_async_retrying():
with attempt:
async for response in self._query_stream_once(payloads, tools):
yield response
return
async def _query_stream_once(
self,
payloads: dict,
tools: ToolSet | None,
) -> AsyncGenerator[LLMResponse, None]:
"""流式请求 Gemini API"""
system_instruction = next(
@@ -759,18 +772,7 @@ class ProviderGoogleGenAI(Provider):
payloads = {"messages": context_query, "model": model}
retry = 10
keys = self.api_keys.copy()
for _ in range(retry):
try:
return await self._query(payloads, func_tool)
except APIError as e:
if await self._handle_api_error(e, keys):
continue
break
raise Exception("请求失败。")
return await self._query(payloads, func_tool)
async def text_chat_stream(
self,
@@ -814,18 +816,8 @@ class ProviderGoogleGenAI(Provider):
payloads = {"messages": context_query, "model": model}
retry = 10
keys = self.api_keys.copy()
for _ in range(retry):
try:
async for response in self._query_stream(payloads, func_tool):
yield response
break
except APIError as e:
if await self._handle_api_error(e, keys):
continue
break
async for response in self._query_stream(payloads, func_tool):
yield response
async def get_models(self):
try:
+22 -87
View File
@@ -31,6 +31,7 @@ from astrbot.core.utils.network_utils import (
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
from ..register import register_provider_adapter
from .default import get_model_request_async_retrying, with_model_request_retry
@register_provider_adapter(
@@ -221,6 +222,7 @@ class ProviderOpenAIOfficial(Provider):
except NotFoundError as e:
raise Exception(f"获取模型列表失败:{e}")
@with_model_request_retry()
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
model = payloads.get("model", "").lower()
@@ -246,8 +248,6 @@ class ProviderOpenAIOfficial(Provider):
if isinstance(custom_extra_body, dict):
extra_body.update(custom_extra_body)
model = payloads.get("model", "").lower()
completion = await self.client.chat.completions.create(
**payloads,
stream=False,
@@ -269,6 +269,17 @@ class ProviderOpenAIOfficial(Provider):
self,
payloads: dict,
tools: ToolSet | None,
) -> AsyncGenerator[LLMResponse, None]:
async for attempt in get_model_request_async_retrying():
with attempt:
async for response in self._query_stream_once(payloads, tools):
yield response
return
async def _query_stream_once(
self,
payloads: dict,
tools: ToolSet | None,
) -> AsyncGenerator[LLMResponse, None]:
"""流式查询API,逐步返回结果"""
if tools:
@@ -716,7 +727,7 @@ class ProviderOpenAIOfficial(Provider):
extra_user_content_parts=None,
**kwargs,
) -> LLMResponse:
payloads, context_query = await self._prepare_chat_payload(
payloads, _ = await self._prepare_chat_payload(
prompt,
image_urls,
contexts,
@@ -728,47 +739,9 @@ class ProviderOpenAIOfficial(Provider):
)
llm_response = None
max_retries = 10
available_api_keys = self.api_keys.copy()
chosen_key = random.choice(available_api_keys)
image_fallback_used = False
last_exception = None
retry_cnt = 0
for retry_cnt in range(max_retries):
try:
self.client.api_key = chosen_key
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
last_exception = e
(
success,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
image_fallback_used,
) = await self._handle_api_error(
e,
payloads,
context_query,
func_tool,
chosen_key,
available_api_keys,
retry_cnt,
max_retries,
image_fallback_used=image_fallback_used,
)
if success:
break
if retry_cnt == max_retries - 1 or llm_response is None:
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
if last_exception is None:
raise Exception("未知错误")
raise last_exception
if self.api_keys:
self.client.api_key = random.choice(self.api_keys)
llm_response = await self._query(payloads, func_tool)
return llm_response
async def text_chat_stream(
@@ -784,7 +757,7 @@ class ProviderOpenAIOfficial(Provider):
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话,与服务商交互并逐步返回结果"""
payloads, context_query = await self._prepare_chat_payload(
payloads, _ = await self._prepare_chat_payload(
prompt,
image_urls,
contexts,
@@ -794,48 +767,10 @@ class ProviderOpenAIOfficial(Provider):
**kwargs,
)
max_retries = 10
available_api_keys = self.api_keys.copy()
chosen_key = random.choice(available_api_keys)
image_fallback_used = False
last_exception = None
retry_cnt = 0
for retry_cnt in range(max_retries):
try:
self.client.api_key = chosen_key
async for response in self._query_stream(payloads, func_tool):
yield response
break
except Exception as e:
last_exception = e
(
success,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
image_fallback_used,
) = await self._handle_api_error(
e,
payloads,
context_query,
func_tool,
chosen_key,
available_api_keys,
retry_cnt,
max_retries,
image_fallback_used=image_fallback_used,
)
if success:
break
if retry_cnt == max_retries - 1:
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
if last_exception is None:
raise Exception("未知错误")
raise last_exception
if self.api_keys:
self.client.api_key = random.choice(self.api_keys)
async for response in self._query_stream(payloads, func_tool):
yield response
async def _remove_image_from_context(self, contexts: list):
"""从上下文中删除所有带有 image 的记录"""