Compare commits
1 Commits
master
...
perf/tenacity
| Author | SHA1 | Date | |
|---|---|---|---|
| 572689b416 |
@@ -22,6 +22,7 @@ from astrbot.core.utils.network_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
|
from .default import with_model_request_retry
|
||||||
|
|
||||||
|
|
||||||
@register_provider_adapter(
|
@register_provider_adapter(
|
||||||
@@ -204,6 +205,7 @@ class ProviderAnthropic(Provider):
|
|||||||
if usage.output_tokens is not None:
|
if usage.output_tokens is not None:
|
||||||
token_usage.output = usage.output_tokens
|
token_usage.output = usage.output_tokens
|
||||||
|
|
||||||
|
@with_model_request_retry()
|
||||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||||
if tools:
|
if tools:
|
||||||
if tool_list := tools.get_func_desc_anthropic_style():
|
if tool_list := tools.get_func_desc_anthropic_style():
|
||||||
@@ -265,6 +267,10 @@ class ProviderAnthropic(Provider):
|
|||||||
|
|
||||||
return llm_response
|
return llm_response
|
||||||
|
|
||||||
|
@with_model_request_retry()
async def _create_message_stream(self, payloads: dict, extra_body: dict):
    """Build the Anthropic streaming context manager for one request.

    Forwards ``payloads`` as keyword arguments, together with
    ``extra_body``, to ``self.client.messages.stream`` and returns the
    result unchanged. The caller is expected to enter the returned object
    with ``async with``.

    NOTE(review): the retry decorator only guards the call made in this
    method. If ``messages.stream()`` merely constructs the manager and
    defers the HTTP request until the manager is entered, errors raised
    while entering or iterating the stream are not retried - confirm
    against the SDK before relying on this for resilience.
    """
    return self.client.messages.stream(**payloads, extra_body=extra_body)
|
||||||
|
|
||||||
async def _query_stream(
|
async def _query_stream(
|
||||||
self,
|
self,
|
||||||
payloads: dict,
|
payloads: dict,
|
||||||
@@ -293,9 +299,8 @@ class ProviderAnthropic(Provider):
|
|||||||
"type": "enabled",
|
"type": "enabled",
|
||||||
}
|
}
|
||||||
|
|
||||||
async with self.client.messages.stream(
|
stream_ctx = await self._create_message_stream(payloads, extra_body)
|
||||||
**payloads, extra_body=extra_body
|
async with stream_ctx as stream:
|
||||||
) as stream:
|
|
||||||
assert isinstance(stream, anthropic.AsyncMessageStream)
|
assert isinstance(stream, anthropic.AsyncMessageStream)
|
||||||
async for event in stream:
|
async for event in stream:
|
||||||
if event.type == "message_start":
|
if event.type == "message_start":
|
||||||
|
|||||||
@@ -0,0 +1,38 @@
|
|||||||
|
from tenacity import (
|
||||||
|
AsyncRetrying,
|
||||||
|
retry,
|
||||||
|
retry_if_exception_type,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
)
|
||||||
|
|
||||||
|
MODEL_REQUEST_RETRY_ATTEMPTS = 5
|
||||||
|
MODEL_REQUEST_RETRY_WAIT_MAX_SECONDS = 15
|
||||||
|
MODEL_REQUEST_RETRY_WAIT_MIN_SECONDS = 1
|
||||||
|
MODEL_REQUEST_RETRY_WAIT_MULTIPLIER = 1
|
||||||
|
|
||||||
|
|
||||||
|
def _model_request_retry_options() -> dict:
    """Build the tenacity keyword arguments shared by every retry helper.

    A fresh set of strategy objects is created on each call so that no
    retry state is shared between independently decorated callables.

    Returns:
        A dict suitable for ``retry(**options)`` or
        ``AsyncRetrying(**options)``:

        * ``retry``   - retry on any ``Exception`` subclass.
        * ``stop``    - give up after ``MODEL_REQUEST_RETRY_ATTEMPTS`` attempts.
        * ``wait``    - exponential backoff scaled by
          ``MODEL_REQUEST_RETRY_WAIT_MULTIPLIER`` and clamped to the
          ``MODEL_REQUEST_RETRY_WAIT_MIN_SECONDS`` ..
          ``MODEL_REQUEST_RETRY_WAIT_MAX_SECONDS`` range.
        * ``reraise`` - surface the last underlying exception instead of
          tenacity's ``RetryError`` wrapper.

    NOTE(review): retrying on bare ``Exception`` also retries failures
    that are presumably permanent (invalid credentials, malformed
    requests) - confirm this is intended.
    """
    return {
        "retry": retry_if_exception_type(Exception),
        "stop": stop_after_attempt(MODEL_REQUEST_RETRY_ATTEMPTS),
        "wait": wait_exponential(
            multiplier=MODEL_REQUEST_RETRY_WAIT_MULTIPLIER,
            min=MODEL_REQUEST_RETRY_WAIT_MIN_SECONDS,
            max=MODEL_REQUEST_RETRY_WAIT_MAX_SECONDS,
        ),
        "reraise": True,
    }


def with_model_request_retry():
    """Return a tenacity ``retry`` decorator using the shared retry policy.

    Intended to be applied as ``@with_model_request_retry()`` on a callable
    that performs a single model request.
    """
    return retry(**_model_request_retry_options())


def get_model_request_async_retrying() -> AsyncRetrying:
    """Return an ``AsyncRetrying`` controller using the shared retry policy.

    Use it as ``async for attempt in get_model_request_async_retrying():``
    with the guarded work placed inside ``with attempt:``.

    NOTE(review): when the guarded work is an async generator that has
    already yielded items, a retry restarts it from the beginning, so the
    consumer may receive repeated items - confirm callers tolerate this.
    """
    return AsyncRetrying(**_model_request_retry_options())
|
||||||
@@ -21,6 +21,7 @@ from astrbot.core.utils.io import download_image_by_url
|
|||||||
from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure
|
from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure
|
||||||
|
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
|
from .default import get_model_request_async_retrying, with_model_request_retry
|
||||||
|
|
||||||
|
|
||||||
class SuppressNonTextPartsWarning(logging.Filter):
|
class SuppressNonTextPartsWarning(logging.Filter):
|
||||||
@@ -513,6 +514,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
llm_response.reasoning_signature = base64.b64encode(ts).decode("utf-8")
|
llm_response.reasoning_signature = base64.b64encode(ts).decode("utf-8")
|
||||||
return MessageChain(chain=chain)
|
return MessageChain(chain=chain)
|
||||||
|
|
||||||
|
@with_model_request_retry()
|
||||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||||
"""非流式请求 Gemini API"""
|
"""非流式请求 Gemini API"""
|
||||||
system_instruction = next(
|
system_instruction = next(
|
||||||
@@ -601,6 +603,17 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
self,
|
self,
|
||||||
payloads: dict,
|
payloads: dict,
|
||||||
tools: ToolSet | None,
|
tools: ToolSet | None,
|
||||||
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
|
async for attempt in get_model_request_async_retrying():
|
||||||
|
with attempt:
|
||||||
|
async for response in self._query_stream_once(payloads, tools):
|
||||||
|
yield response
|
||||||
|
return
|
||||||
|
|
||||||
|
async def _query_stream_once(
|
||||||
|
self,
|
||||||
|
payloads: dict,
|
||||||
|
tools: ToolSet | None,
|
||||||
) -> AsyncGenerator[LLMResponse, None]:
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
"""流式请求 Gemini API"""
|
"""流式请求 Gemini API"""
|
||||||
system_instruction = next(
|
system_instruction = next(
|
||||||
@@ -759,18 +772,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
|
|
||||||
payloads = {"messages": context_query, "model": model}
|
payloads = {"messages": context_query, "model": model}
|
||||||
|
|
||||||
retry = 10
|
return await self._query(payloads, func_tool)
|
||||||
keys = self.api_keys.copy()
|
|
||||||
|
|
||||||
for _ in range(retry):
|
|
||||||
try:
|
|
||||||
return await self._query(payloads, func_tool)
|
|
||||||
except APIError as e:
|
|
||||||
if await self._handle_api_error(e, keys):
|
|
||||||
continue
|
|
||||||
break
|
|
||||||
|
|
||||||
raise Exception("请求失败。")
|
|
||||||
|
|
||||||
async def text_chat_stream(
|
async def text_chat_stream(
|
||||||
self,
|
self,
|
||||||
@@ -814,18 +816,8 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
|
|
||||||
payloads = {"messages": context_query, "model": model}
|
payloads = {"messages": context_query, "model": model}
|
||||||
|
|
||||||
retry = 10
|
async for response in self._query_stream(payloads, func_tool):
|
||||||
keys = self.api_keys.copy()
|
yield response
|
||||||
|
|
||||||
for _ in range(retry):
|
|
||||||
try:
|
|
||||||
async for response in self._query_stream(payloads, func_tool):
|
|
||||||
yield response
|
|
||||||
break
|
|
||||||
except APIError as e:
|
|
||||||
if await self._handle_api_error(e, keys):
|
|
||||||
continue
|
|
||||||
break
|
|
||||||
|
|
||||||
async def get_models(self):
|
async def get_models(self):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from astrbot.core.utils.network_utils import (
|
|||||||
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
|
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
|
||||||
|
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
|
from .default import get_model_request_async_retrying, with_model_request_retry
|
||||||
|
|
||||||
|
|
||||||
@register_provider_adapter(
|
@register_provider_adapter(
|
||||||
@@ -221,6 +222,7 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
except NotFoundError as e:
|
except NotFoundError as e:
|
||||||
raise Exception(f"获取模型列表失败:{e}")
|
raise Exception(f"获取模型列表失败:{e}")
|
||||||
|
|
||||||
|
@with_model_request_retry()
|
||||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||||
if tools:
|
if tools:
|
||||||
model = payloads.get("model", "").lower()
|
model = payloads.get("model", "").lower()
|
||||||
@@ -246,8 +248,6 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
if isinstance(custom_extra_body, dict):
|
if isinstance(custom_extra_body, dict):
|
||||||
extra_body.update(custom_extra_body)
|
extra_body.update(custom_extra_body)
|
||||||
|
|
||||||
model = payloads.get("model", "").lower()
|
|
||||||
|
|
||||||
completion = await self.client.chat.completions.create(
|
completion = await self.client.chat.completions.create(
|
||||||
**payloads,
|
**payloads,
|
||||||
stream=False,
|
stream=False,
|
||||||
@@ -269,6 +269,17 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
self,
|
self,
|
||||||
payloads: dict,
|
payloads: dict,
|
||||||
tools: ToolSet | None,
|
tools: ToolSet | None,
|
||||||
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
|
async for attempt in get_model_request_async_retrying():
|
||||||
|
with attempt:
|
||||||
|
async for response in self._query_stream_once(payloads, tools):
|
||||||
|
yield response
|
||||||
|
return
|
||||||
|
|
||||||
|
async def _query_stream_once(
|
||||||
|
self,
|
||||||
|
payloads: dict,
|
||||||
|
tools: ToolSet | None,
|
||||||
) -> AsyncGenerator[LLMResponse, None]:
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
"""流式查询API,逐步返回结果"""
|
"""流式查询API,逐步返回结果"""
|
||||||
if tools:
|
if tools:
|
||||||
@@ -716,7 +727,7 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
extra_user_content_parts=None,
|
extra_user_content_parts=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
payloads, context_query = await self._prepare_chat_payload(
|
payloads, _ = await self._prepare_chat_payload(
|
||||||
prompt,
|
prompt,
|
||||||
image_urls,
|
image_urls,
|
||||||
contexts,
|
contexts,
|
||||||
@@ -728,47 +739,9 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
)
|
)
|
||||||
|
|
||||||
llm_response = None
|
llm_response = None
|
||||||
max_retries = 10
|
if self.api_keys:
|
||||||
available_api_keys = self.api_keys.copy()
|
self.client.api_key = random.choice(self.api_keys)
|
||||||
chosen_key = random.choice(available_api_keys)
|
llm_response = await self._query(payloads, func_tool)
|
||||||
image_fallback_used = False
|
|
||||||
|
|
||||||
last_exception = None
|
|
||||||
retry_cnt = 0
|
|
||||||
for retry_cnt in range(max_retries):
|
|
||||||
try:
|
|
||||||
self.client.api_key = chosen_key
|
|
||||||
llm_response = await self._query(payloads, func_tool)
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
last_exception = e
|
|
||||||
(
|
|
||||||
success,
|
|
||||||
chosen_key,
|
|
||||||
available_api_keys,
|
|
||||||
payloads,
|
|
||||||
context_query,
|
|
||||||
func_tool,
|
|
||||||
image_fallback_used,
|
|
||||||
) = await self._handle_api_error(
|
|
||||||
e,
|
|
||||||
payloads,
|
|
||||||
context_query,
|
|
||||||
func_tool,
|
|
||||||
chosen_key,
|
|
||||||
available_api_keys,
|
|
||||||
retry_cnt,
|
|
||||||
max_retries,
|
|
||||||
image_fallback_used=image_fallback_used,
|
|
||||||
)
|
|
||||||
if success:
|
|
||||||
break
|
|
||||||
|
|
||||||
if retry_cnt == max_retries - 1 or llm_response is None:
|
|
||||||
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
|
||||||
if last_exception is None:
|
|
||||||
raise Exception("未知错误")
|
|
||||||
raise last_exception
|
|
||||||
return llm_response
|
return llm_response
|
||||||
|
|
||||||
async def text_chat_stream(
|
async def text_chat_stream(
|
||||||
@@ -784,7 +757,7 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AsyncGenerator[LLMResponse, None]:
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
"""流式对话,与服务商交互并逐步返回结果"""
|
"""流式对话,与服务商交互并逐步返回结果"""
|
||||||
payloads, context_query = await self._prepare_chat_payload(
|
payloads, _ = await self._prepare_chat_payload(
|
||||||
prompt,
|
prompt,
|
||||||
image_urls,
|
image_urls,
|
||||||
contexts,
|
contexts,
|
||||||
@@ -794,48 +767,10 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
max_retries = 10
|
if self.api_keys:
|
||||||
available_api_keys = self.api_keys.copy()
|
self.client.api_key = random.choice(self.api_keys)
|
||||||
chosen_key = random.choice(available_api_keys)
|
async for response in self._query_stream(payloads, func_tool):
|
||||||
image_fallback_used = False
|
yield response
|
||||||
|
|
||||||
last_exception = None
|
|
||||||
retry_cnt = 0
|
|
||||||
for retry_cnt in range(max_retries):
|
|
||||||
try:
|
|
||||||
self.client.api_key = chosen_key
|
|
||||||
async for response in self._query_stream(payloads, func_tool):
|
|
||||||
yield response
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
last_exception = e
|
|
||||||
(
|
|
||||||
success,
|
|
||||||
chosen_key,
|
|
||||||
available_api_keys,
|
|
||||||
payloads,
|
|
||||||
context_query,
|
|
||||||
func_tool,
|
|
||||||
image_fallback_used,
|
|
||||||
) = await self._handle_api_error(
|
|
||||||
e,
|
|
||||||
payloads,
|
|
||||||
context_query,
|
|
||||||
func_tool,
|
|
||||||
chosen_key,
|
|
||||||
available_api_keys,
|
|
||||||
retry_cnt,
|
|
||||||
max_retries,
|
|
||||||
image_fallback_used=image_fallback_used,
|
|
||||||
)
|
|
||||||
if success:
|
|
||||||
break
|
|
||||||
|
|
||||||
if retry_cnt == max_retries - 1:
|
|
||||||
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
|
||||||
if last_exception is None:
|
|
||||||
raise Exception("未知错误")
|
|
||||||
raise last_exception
|
|
||||||
|
|
||||||
async def _remove_image_from_context(self, contexts: list):
|
async def _remove_image_from_context(self, contexts: list):
|
||||||
"""从上下文中删除所有带有 image 的记录"""
|
"""从上下文中删除所有带有 image 的记录"""
|
||||||
|
|||||||
Reference in New Issue
Block a user