diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py
index 80684aca6..18ba6b9ed 100644
--- a/astrbot/core/provider/sources/anthropic_source.py
+++ b/astrbot/core/provider/sources/anthropic_source.py
@@ -22,6 +22,7 @@ from astrbot.core.utils.network_utils import (
 )
 
 from ..register import register_provider_adapter
+from .default import with_model_request_retry
 
 
 @register_provider_adapter(
@@ -204,6 +205,7 @@ class ProviderAnthropic(Provider):
         if usage.output_tokens is not None:
             token_usage.output = usage.output_tokens
 
+    @with_model_request_retry()
     async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
         if tools:
             if tool_list := tools.get_func_desc_anthropic_style():
@@ -265,6 +267,18 @@ class ProviderAnthropic(Provider):
 
         return llm_response
 
+    @with_model_request_retry()
+    async def _open_message_stream(self, stack, payloads: dict, extra_body: dict):
+        """Enter the Anthropic message stream, retrying failed attempts.
+
+        ``messages.stream()`` only returns a context manager; the request is
+        sent when that manager is entered, so the retry wraps the entering.
+        The manager is registered on ``stack``, which the caller closes.
+        """
+        return await stack.enter_async_context(
+            self.client.messages.stream(**payloads, extra_body=extra_body)
+        )
+
     async def _query_stream(
         self,
         payloads: dict,
@@ -293,9 +307,10 @@ class ProviderAnthropic(Provider):
                 "type": "enabled",
             }
 
-        async with self.client.messages.stream(
-            **payloads, extra_body=extra_body
-        ) as stream:
+        from contextlib import AsyncExitStack
+
+        async with AsyncExitStack() as stack:
+            stream = await self._open_message_stream(stack, payloads, extra_body)
             assert isinstance(stream, anthropic.AsyncMessageStream)
             async for event in stream:
                 if event.type == "message_start":
diff --git a/astrbot/core/provider/sources/default.py b/astrbot/core/provider/sources/default.py
new file mode 100644
index 000000000..8554a7a72
--- /dev/null
+++ b/astrbot/core/provider/sources/default.py
@@ -0,0 +1,95 @@
+"""Shared retry policy for requests sent to model provider APIs."""
+
+from collections.abc import AsyncIterator, Callable
+from typing import TypeVar
+
+from tenacity import (
+    AsyncRetrying,
+    retry,
+    retry_if_exception,
+    stop_after_attempt,
+    wait_exponential,
+)
+
+MODEL_REQUEST_RETRY_ATTEMPTS = 5
+MODEL_REQUEST_RETRY_WAIT_MAX_SECONDS = 15
+MODEL_REQUEST_RETRY_WAIT_MIN_SECONDS = 1
+MODEL_REQUEST_RETRY_WAIT_MULTIPLIER = 1
+
+# HTTP statuses that describe a problem with the request itself (bad payload,
+# bad credentials, unknown resource); resending the same request cannot help.
+MODEL_REQUEST_NON_RETRYABLE_STATUS_CODES = frozenset({400, 401, 403, 404, 422})
+
+_T = TypeVar("_T")
+
+
+def is_retryable_model_request_error(exc: BaseException) -> bool:
+    """Return whether a failed model request is worth sending again.
+
+    Only ``Exception`` instances are retried. The HTTP status is read from a
+    ``status_code`` or ``code`` attribute when the exception has one; errors
+    without an integer status are treated as transient.
+    """
+    if not isinstance(exc, Exception):
+        return False
+    status = getattr(exc, "status_code", None)
+    if status is None:
+        status = getattr(exc, "code", None)
+    if isinstance(status, int):
+        return status not in MODEL_REQUEST_NON_RETRYABLE_STATUS_CODES
+    return True
+
+
+def _model_request_retry_options() -> dict:
+    """Build the tenacity options shared by every model-request retry helper."""
+    return {
+        "retry": retry_if_exception(is_retryable_model_request_error),
+        "stop": stop_after_attempt(MODEL_REQUEST_RETRY_ATTEMPTS),
+        "wait": wait_exponential(
+            multiplier=MODEL_REQUEST_RETRY_WAIT_MULTIPLIER,
+            min=MODEL_REQUEST_RETRY_WAIT_MIN_SECONDS,
+            max=MODEL_REQUEST_RETRY_WAIT_MAX_SECONDS,
+        ),
+        "reraise": True,
+    }
+
+
+def with_model_request_retry():
+    """Return a decorator that retries the wrapped callable on failure."""
+    return retry(**_model_request_retry_options())
+
+
+def get_model_request_async_retrying() -> AsyncRetrying:
+    """Return an ``AsyncRetrying`` controller using the shared retry policy."""
+    return AsyncRetrying(**_model_request_retry_options())
+
+
+async def stream_with_model_request_retry(
+    open_stream: Callable[[], AsyncIterator[_T]],
+) -> AsyncIterator[_T]:
+    """Yield from ``open_stream()``, retrying only until the first item arrives.
+
+    Once an item has been handed to the consumer, restarting the request would
+    replay output that was already delivered, so a later failure is re-raised
+    as-is instead of being retried.
+
+    Args:
+        open_stream: Zero-argument callable returning a fresh async iterator
+            for each attempt.
+    """
+    delivered = False
+    mid_stream_error: Exception | None = None
+    async for attempt in get_model_request_async_retrying():
+        with attempt:
+            try:
+                async for item in open_stream():
+                    delivered = True
+                    yield item
+            except Exception as exc:
+                if not delivered:
+                    # Nothing reached the consumer; let the retry policy decide.
+                    raise
+                mid_stream_error = exc
+        # Raised outside ``with attempt`` so the retry controller never sees it.
+        if mid_stream_error is not None:
+            raise mid_stream_error
diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py
index 9557f3dbc..6e64b8770 100644
--- a/astrbot/core/provider/sources/gemini_source.py
+++ b/astrbot/core/provider/sources/gemini_source.py
@@ -21,6 +21,7 @@ from astrbot.core.utils.io import download_image_by_url
 from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure
 
 from ..register import register_provider_adapter
+from .default import stream_with_model_request_retry, with_model_request_retry
 
 
 class SuppressNonTextPartsWarning(logging.Filter):
@@ -513,6 +514,7 @@ class ProviderGoogleGenAI(Provider):
             llm_response.reasoning_signature = base64.b64encode(ts).decode("utf-8")
         return MessageChain(chain=chain)
 
+    @with_model_request_retry()
     async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
         """非流式请求 Gemini API"""
         system_instruction = next(
@@ -601,6 +603,17 @@ class ProviderGoogleGenAI(Provider):
         self,
         payloads: dict,
         tools: ToolSet | None,
+    ) -> AsyncGenerator[LLMResponse, None]:
+        """Stream a Gemini response under the shared model-request retry policy."""
+        async for response in stream_with_model_request_retry(
+            lambda: self._query_stream_once(payloads, tools)
+        ):
+            yield response
+
+    async def _query_stream_once(
+        self,
+        payloads: dict,
+        tools: ToolSet | None,
     ) -> AsyncGenerator[LLMResponse, None]:
         """流式请求 Gemini API"""
         system_instruction = next(
@@ -759,18 +772,7 @@ class ProviderGoogleGenAI(Provider):
 
         payloads = {"messages": context_query, "model": model}
 
-        retry = 10
-        keys = self.api_keys.copy()
-
-        for _ in range(retry):
-            try:
-                return await self._query(payloads, func_tool)
-            except APIError as e:
-                if await self._handle_api_error(e, keys):
-                    continue
-                break
-
-        raise Exception("请求失败。")
+        return await self._query(payloads, func_tool)
 
     async def text_chat_stream(
         self,
@@ -814,18 +816,8 @@ class ProviderGoogleGenAI(Provider):
 
         payloads = {"messages": context_query, "model": model}
 
-        retry = 10
-        keys = self.api_keys.copy()
-
-        for _ in range(retry):
-            try:
-                async for response in self._query_stream(payloads, func_tool):
-                    yield response
-                break
-            except APIError as e:
-                if await self._handle_api_error(e, keys):
-                    continue
-                break
+        async for response in self._query_stream(payloads, func_tool):
+            yield response
 
     async def get_models(self):
         try:
diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py
index 5378385e5..a88544451 100644
--- a/astrbot/core/provider/sources/openai_source.py
+++ b/astrbot/core/provider/sources/openai_source.py
@@ -31,6 +31,7 @@ from astrbot.core.utils.network_utils import (
 from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
 
 from ..register import register_provider_adapter
+from .default import stream_with_model_request_retry, with_model_request_retry
 
 
 @register_provider_adapter(
@@ -221,6 +222,7 @@ class ProviderOpenAIOfficial(Provider):
         except NotFoundError as e:
             raise Exception(f"获取模型列表失败:{e}")
 
+    @with_model_request_retry()
     async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
         if tools:
             model = payloads.get("model", "").lower()
@@ -246,8 +248,6 @@ class ProviderOpenAIOfficial(Provider):
         if isinstance(custom_extra_body, dict):
             extra_body.update(custom_extra_body)
 
-        model = payloads.get("model", "").lower()
-
         completion = await self.client.chat.completions.create(
             **payloads,
             stream=False,
@@ -269,6 +269,17 @@ class ProviderOpenAIOfficial(Provider):
         self,
         payloads: dict,
         tools: ToolSet | None,
+    ) -> AsyncGenerator[LLMResponse, None]:
+        """Stream an API response under the shared model-request retry policy."""
+        async for response in stream_with_model_request_retry(
+            lambda: self._query_stream_once(payloads, tools)
+        ):
+            yield response
+
+    async def _query_stream_once(
+        self,
+        payloads: dict,
+        tools: ToolSet | None,
     ) -> AsyncGenerator[LLMResponse, None]:
         """流式查询API,逐步返回结果"""
         if tools:
@@ -716,7 +727,7 @@ class ProviderOpenAIOfficial(Provider):
         extra_user_content_parts=None,
         **kwargs,
     ) -> LLMResponse:
-        payloads, context_query = await self._prepare_chat_payload(
+        payloads, _ = await self._prepare_chat_payload(
             prompt,
             image_urls,
             contexts,
@@ -728,47 +739,9 @@ class ProviderOpenAIOfficial(Provider):
         )
 
         llm_response = None
-        max_retries = 10
-        available_api_keys = self.api_keys.copy()
-        chosen_key = random.choice(available_api_keys)
-        image_fallback_used = False
-
-        last_exception = None
-        retry_cnt = 0
-        for retry_cnt in range(max_retries):
-            try:
-                self.client.api_key = chosen_key
-                llm_response = await self._query(payloads, func_tool)
-                break
-            except Exception as e:
-                last_exception = e
-                (
-                    success,
-                    chosen_key,
-                    available_api_keys,
-                    payloads,
-                    context_query,
-                    func_tool,
-                    image_fallback_used,
-                ) = await self._handle_api_error(
-                    e,
-                    payloads,
-                    context_query,
-                    func_tool,
-                    chosen_key,
-                    available_api_keys,
-                    retry_cnt,
-                    max_retries,
-                    image_fallback_used=image_fallback_used,
-                )
-                if success:
-                    break
-
-        if retry_cnt == max_retries - 1 or llm_response is None:
-            logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
-            if last_exception is None:
-                raise Exception("未知错误")
-            raise last_exception
+        if self.api_keys:
+            self.client.api_key = random.choice(self.api_keys)
+        llm_response = await self._query(payloads, func_tool)
         return llm_response
 
     async def text_chat_stream(
@@ -784,7 +757,7 @@ class ProviderOpenAIOfficial(Provider):
         **kwargs,
     ) -> AsyncGenerator[LLMResponse, None]:
         """流式对话,与服务商交互并逐步返回结果"""
-        payloads, context_query = await self._prepare_chat_payload(
+        payloads, _ = await self._prepare_chat_payload(
             prompt,
             image_urls,
             contexts,
@@ -794,48 +767,10 @@ class ProviderOpenAIOfficial(Provider):
             **kwargs,
         )
 
-        max_retries = 10
-        available_api_keys = self.api_keys.copy()
-        chosen_key = random.choice(available_api_keys)
-        image_fallback_used = False
-
-        last_exception = None
-        retry_cnt = 0
-        for retry_cnt in range(max_retries):
-            try:
-                self.client.api_key = chosen_key
-                async for response in self._query_stream(payloads, func_tool):
-                    yield response
-                break
-            except Exception as e:
-                last_exception = e
-                (
-                    success,
-                    chosen_key,
-                    available_api_keys,
-                    payloads,
-                    context_query,
-                    func_tool,
-                    image_fallback_used,
-                ) = await self._handle_api_error(
-                    e,
-                    payloads,
-                    context_query,
-                    func_tool,
-                    chosen_key,
-                    available_api_keys,
-                    retry_cnt,
-                    max_retries,
-                    image_fallback_used=image_fallback_used,
-                )
-                if success:
-                    break
-
-        if retry_cnt == max_retries - 1:
-            logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
-            if last_exception is None:
-                raise Exception("未知错误")
-            raise last_exception
+        if self.api_keys:
+            self.client.api_key = random.choice(self.api_keys)
+        async for response in self._query_stream(payloads, func_tool):
+            yield response
 
     async def _remove_image_from_context(self, contexts: list):
         """从上下文中删除所有带有 image 的记录"""