Compare commits
1 Commit
dev
...
perf/tenacity
| Author | SHA1 | Date | |
|---|---|---|---|
| 572689b416 |
@@ -22,6 +22,7 @@ from astrbot.core.utils.network_utils import (
|
||||
)
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
from .default import with_model_request_retry
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -204,6 +205,7 @@ class ProviderAnthropic(Provider):
|
||||
if usage.output_tokens is not None:
|
||||
token_usage.output = usage.output_tokens
|
||||
|
||||
@with_model_request_retry()
|
||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||
if tools:
|
||||
if tool_list := tools.get_func_desc_anthropic_style():
|
||||
@@ -265,6 +267,10 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
return llm_response
|
||||
|
||||
@with_model_request_retry()
async def _create_message_stream(self, payloads: dict, extra_body: dict):
    """Build the streaming handle for a single Anthropic messages request.

    Args:
        payloads: Keyword arguments expanded into ``client.messages.stream``.
        extra_body: Forwarded unchanged as the ``extra_body`` keyword argument.

    Returns:
        The object returned by ``self.client.messages.stream``.
    """
    # NOTE(review): if ``messages.stream()`` only constructs a context manager
    # and the HTTP request is issued when that manager is entered, the retry
    # decorator above never observes request failures -- confirm against the
    # Anthropic SDK before relying on it for streaming retries.
    stream_manager = self.client.messages.stream(**payloads, extra_body=extra_body)
    return stream_manager
|
||||
|
||||
async def _query_stream(
|
||||
self,
|
||||
payloads: dict,
|
||||
@@ -293,9 +299,8 @@ class ProviderAnthropic(Provider):
|
||||
"type": "enabled",
|
||||
}
|
||||
|
||||
async with self.client.messages.stream(
|
||||
**payloads, extra_body=extra_body
|
||||
) as stream:
|
||||
stream_ctx = await self._create_message_stream(payloads, extra_body)
|
||||
async with stream_ctx as stream:
|
||||
assert isinstance(stream, anthropic.AsyncMessageStream)
|
||||
async for event in stream:
|
||||
if event.type == "message_start":
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
from tenacity import (
|
||||
AsyncRetrying,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
# Attempt limit handed to tenacity's stop_after_attempt by the retry helpers.
MODEL_REQUEST_RETRY_ATTEMPTS = 5
|
||||
# Passed as `max` to wait_exponential: upper bound on the wait between attempts.
MODEL_REQUEST_RETRY_WAIT_MAX_SECONDS = 15
|
||||
# Passed as `min` to wait_exponential: lower bound on the wait between attempts.
MODEL_REQUEST_RETRY_WAIT_MIN_SECONDS = 1
|
||||
# Passed as `multiplier` to wait_exponential.
MODEL_REQUEST_RETRY_WAIT_MULTIPLIER = 1
|
||||
|
||||
|
||||
def with_model_request_retry(
    *,
    exception_types: type[BaseException] | tuple[type[BaseException], ...] = Exception,
    max_attempts: int | None = None,
):
    """Return a tenacity ``retry`` decorator for model request functions.

    Called with no arguments, the decorator retries on any ``Exception``,
    stops after ``MODEL_REQUEST_RETRY_ATTEMPTS`` attempts, waits with
    ``wait_exponential`` bounded by the ``MODEL_REQUEST_RETRY_WAIT_*``
    constants, and re-raises the last exception (``reraise=True``).

    Retrying on every ``Exception`` also repeats deterministic failures
    (rejected credentials, malformed requests, bugs in the decorated
    function). Pass ``exception_types`` to limit retries to errors that are
    worth repeating.

    Args:
        exception_types: Exception class, or tuple of classes, that triggers
            a retry. Defaults to ``Exception``.
        max_attempts: Attempt limit. ``None`` (the default) uses
            ``MODEL_REQUEST_RETRY_ATTEMPTS``, read at call time.

    Returns:
        The decorator produced by ``tenacity.retry``.
    """
    attempts = MODEL_REQUEST_RETRY_ATTEMPTS if max_attempts is None else max_attempts
    return retry(
        retry=retry_if_exception_type(exception_types),
        stop=stop_after_attempt(attempts),
        wait=wait_exponential(
            multiplier=MODEL_REQUEST_RETRY_WAIT_MULTIPLIER,
            min=MODEL_REQUEST_RETRY_WAIT_MIN_SECONDS,
            max=MODEL_REQUEST_RETRY_WAIT_MAX_SECONDS,
        ),
        reraise=True,
    )
|
||||
|
||||
|
||||
def get_model_request_async_retrying(
    *,
    exception_types: type[BaseException] | tuple[type[BaseException], ...] = Exception,
    max_attempts: int | None = None,
) -> AsyncRetrying:
    """Create a new ``AsyncRetrying`` controller for model requests.

    Intended for the ``async for attempt in ...: with attempt: ...`` form,
    where the retried code cannot be wrapped by a decorator. Called with no
    arguments it retries on any ``Exception``, stops after
    ``MODEL_REQUEST_RETRY_ATTEMPTS`` attempts, waits with
    ``wait_exponential`` bounded by the ``MODEL_REQUEST_RETRY_WAIT_*``
    constants, and re-raises the last exception (``reraise=True``).

    Caveat: when the attempt body forwards items from a stream to a
    consumer, a failure part-way through starts the next attempt from the
    beginning. Items already forwarded are not rolled back, so the consumer
    can receive them again.

    Args:
        exception_types: Exception class, or tuple of classes, that triggers
            a retry. Defaults to ``Exception``; narrow it to avoid repeating
            deterministic failures.
        max_attempts: Attempt limit. ``None`` (the default) uses
            ``MODEL_REQUEST_RETRY_ATTEMPTS``, read at call time.

    Returns:
        A freshly constructed ``AsyncRetrying`` instance.
    """
    attempts = MODEL_REQUEST_RETRY_ATTEMPTS if max_attempts is None else max_attempts
    return AsyncRetrying(
        retry=retry_if_exception_type(exception_types),
        stop=stop_after_attempt(attempts),
        wait=wait_exponential(
            multiplier=MODEL_REQUEST_RETRY_WAIT_MULTIPLIER,
            min=MODEL_REQUEST_RETRY_WAIT_MIN_SECONDS,
            max=MODEL_REQUEST_RETRY_WAIT_MAX_SECONDS,
        ),
        reraise=True,
    )
|
||||
@@ -21,6 +21,7 @@ from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
from .default import get_model_request_async_retrying, with_model_request_retry
|
||||
|
||||
|
||||
class SuppressNonTextPartsWarning(logging.Filter):
|
||||
@@ -513,6 +514,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
llm_response.reasoning_signature = base64.b64encode(ts).decode("utf-8")
|
||||
return MessageChain(chain=chain)
|
||||
|
||||
@with_model_request_retry()
|
||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||
"""非流式请求 Gemini API"""
|
||||
system_instruction = next(
|
||||
@@ -601,6 +603,17 @@ class ProviderGoogleGenAI(Provider):
|
||||
self,
|
||||
payloads: dict,
|
||||
tools: ToolSet | None,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
async for attempt in get_model_request_async_retrying():
|
||||
with attempt:
|
||||
async for response in self._query_stream_once(payloads, tools):
|
||||
yield response
|
||||
return
|
||||
|
||||
async def _query_stream_once(
|
||||
self,
|
||||
payloads: dict,
|
||||
tools: ToolSet | None,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式请求 Gemini API"""
|
||||
system_instruction = next(
|
||||
@@ -759,18 +772,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
payloads = {"messages": context_query, "model": model}
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
|
||||
for _ in range(retry):
|
||||
try:
|
||||
return await self._query(payloads, func_tool)
|
||||
except APIError as e:
|
||||
if await self._handle_api_error(e, keys):
|
||||
continue
|
||||
break
|
||||
|
||||
raise Exception("请求失败。")
|
||||
return await self._query(payloads, func_tool)
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
@@ -814,18 +816,8 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
payloads = {"messages": context_query, "model": model}
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
|
||||
for _ in range(retry):
|
||||
try:
|
||||
async for response in self._query_stream(payloads, func_tool):
|
||||
yield response
|
||||
break
|
||||
except APIError as e:
|
||||
if await self._handle_api_error(e, keys):
|
||||
continue
|
||||
break
|
||||
async for response in self._query_stream(payloads, func_tool):
|
||||
yield response
|
||||
|
||||
async def get_models(self):
|
||||
try:
|
||||
|
||||
@@ -31,6 +31,7 @@ from astrbot.core.utils.network_utils import (
|
||||
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
from .default import get_model_request_async_retrying, with_model_request_retry
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -221,6 +222,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
except NotFoundError as e:
|
||||
raise Exception(f"获取模型列表失败:{e}")
|
||||
|
||||
@with_model_request_retry()
|
||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||
if tools:
|
||||
model = payloads.get("model", "").lower()
|
||||
@@ -246,8 +248,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if isinstance(custom_extra_body, dict):
|
||||
extra_body.update(custom_extra_body)
|
||||
|
||||
model = payloads.get("model", "").lower()
|
||||
|
||||
completion = await self.client.chat.completions.create(
|
||||
**payloads,
|
||||
stream=False,
|
||||
@@ -269,6 +269,17 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self,
|
||||
payloads: dict,
|
||||
tools: ToolSet | None,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
async for attempt in get_model_request_async_retrying():
|
||||
with attempt:
|
||||
async for response in self._query_stream_once(payloads, tools):
|
||||
yield response
|
||||
return
|
||||
|
||||
async def _query_stream_once(
|
||||
self,
|
||||
payloads: dict,
|
||||
tools: ToolSet | None,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式查询API,逐步返回结果"""
|
||||
if tools:
|
||||
@@ -716,7 +727,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
payloads, context_query = await self._prepare_chat_payload(
|
||||
payloads, _ = await self._prepare_chat_payload(
|
||||
prompt,
|
||||
image_urls,
|
||||
contexts,
|
||||
@@ -728,47 +739,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
)
|
||||
|
||||
llm_response = None
|
||||
max_retries = 10
|
||||
available_api_keys = self.api_keys.copy()
|
||||
chosen_key = random.choice(available_api_keys)
|
||||
image_fallback_used = False
|
||||
|
||||
last_exception = None
|
||||
retry_cnt = 0
|
||||
for retry_cnt in range(max_retries):
|
||||
try:
|
||||
self.client.api_key = chosen_key
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
(
|
||||
success,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
image_fallback_used,
|
||||
) = await self._handle_api_error(
|
||||
e,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
retry_cnt,
|
||||
max_retries,
|
||||
image_fallback_used=image_fallback_used,
|
||||
)
|
||||
if success:
|
||||
break
|
||||
|
||||
if retry_cnt == max_retries - 1 or llm_response is None:
|
||||
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
||||
if last_exception is None:
|
||||
raise Exception("未知错误")
|
||||
raise last_exception
|
||||
if self.api_keys:
|
||||
self.client.api_key = random.choice(self.api_keys)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
@@ -784,7 +757,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式对话,与服务商交互并逐步返回结果"""
|
||||
payloads, context_query = await self._prepare_chat_payload(
|
||||
payloads, _ = await self._prepare_chat_payload(
|
||||
prompt,
|
||||
image_urls,
|
||||
contexts,
|
||||
@@ -794,48 +767,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
max_retries = 10
|
||||
available_api_keys = self.api_keys.copy()
|
||||
chosen_key = random.choice(available_api_keys)
|
||||
image_fallback_used = False
|
||||
|
||||
last_exception = None
|
||||
retry_cnt = 0
|
||||
for retry_cnt in range(max_retries):
|
||||
try:
|
||||
self.client.api_key = chosen_key
|
||||
async for response in self._query_stream(payloads, func_tool):
|
||||
yield response
|
||||
break
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
(
|
||||
success,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
image_fallback_used,
|
||||
) = await self._handle_api_error(
|
||||
e,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
retry_cnt,
|
||||
max_retries,
|
||||
image_fallback_used=image_fallback_used,
|
||||
)
|
||||
if success:
|
||||
break
|
||||
|
||||
if retry_cnt == max_retries - 1:
|
||||
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
||||
if last_exception is None:
|
||||
raise Exception("未知错误")
|
||||
raise last_exception
|
||||
if self.api_keys:
|
||||
self.client.api_key = random.choice(self.api_keys)
|
||||
async for response in self._query_stream(payloads, func_tool):
|
||||
yield response
|
||||
|
||||
async def _remove_image_from_context(self, contexts: list):
|
||||
"""从上下文中删除所有带有 image 的记录"""
|
||||
|
||||
Reference in New Issue
Block a user