feat: add provider-souce-level proxy (#4949)
* feat: 添加 Provider 级别代理支持及请求失败日志 * refactor: simplify provider source configuration structure * refactor: move env proxy fallback logic to log_connection_failure * refactor: update client proxy handling and add terminate method for cleanup * refactor: update no_proxy configuration to remove redundant subnet --------- Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
@@ -177,7 +177,7 @@ DEFAULT_CONFIG = {
|
||||
"t2i_use_file_service": False,
|
||||
"t2i_active_template": "base",
|
||||
"http_proxy": "",
|
||||
"no_proxy": ["localhost", "127.0.0.1", "::1"],
|
||||
"no_proxy": ["localhost", "127.0.0.1", "::1", "10.*", "192.168.*"],
|
||||
"dashboard": {
|
||||
"enable": True,
|
||||
"username": "astrbot",
|
||||
@@ -913,6 +913,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Google Gemini": {
|
||||
@@ -935,6 +936,7 @@ CONFIG_METADATA_2 = {
|
||||
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
},
|
||||
"gm_thinking_config": {"budget": 0, "level": "HIGH"},
|
||||
"proxy": "",
|
||||
},
|
||||
"Anthropic": {
|
||||
"id": "anthropic",
|
||||
@@ -945,6 +947,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.anthropic.com/v1",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"anth_thinking_config": {"budget": 0},
|
||||
},
|
||||
"Moonshot": {
|
||||
@@ -956,6 +959,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.moonshot.cn/v1",
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"xAI": {
|
||||
@@ -967,6 +971,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
"xai_native_search": False,
|
||||
},
|
||||
@@ -979,6 +984,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.deepseek.com/v1",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Zhipu": {
|
||||
@@ -990,6 +996,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Azure OpenAI": {
|
||||
@@ -1002,6 +1009,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Ollama": {
|
||||
@@ -1012,6 +1020,7 @@ CONFIG_METADATA_2 = {
|
||||
"enable": True,
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://127.0.0.1:11434/v1",
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"LM Studio": {
|
||||
@@ -1022,6 +1031,7 @@ CONFIG_METADATA_2 = {
|
||||
"enable": True,
|
||||
"key": ["lmstudio"],
|
||||
"api_base": "http://127.0.0.1:1234/v1",
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Gemini_OpenAI_API": {
|
||||
@@ -1033,6 +1043,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Groq": {
|
||||
@@ -1044,6 +1055,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.groq.com/openai/v1",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"302.AI": {
|
||||
@@ -1055,6 +1067,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.302.ai/v1",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"SiliconFlow": {
|
||||
@@ -1066,6 +1079,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.siliconflow.cn/v1",
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"PPIO": {
|
||||
@@ -1077,6 +1091,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.ppinfra.com/v3/openai",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"TokenPony": {
|
||||
@@ -1088,6 +1103,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.tokenpony.cn/v1",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Compshare": {
|
||||
@@ -1099,6 +1115,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.modelverse.cn/v1",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"ModelScope": {
|
||||
@@ -1110,6 +1127,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api-inference.modelscope.cn/v1",
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Dify": {
|
||||
@@ -1125,6 +1143,7 @@ CONFIG_METADATA_2 = {
|
||||
"dify_query_input_key": "astrbot_text_query",
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
"proxy": "",
|
||||
},
|
||||
"Coze": {
|
||||
"id": "coze",
|
||||
@@ -1136,6 +1155,7 @@ CONFIG_METADATA_2 = {
|
||||
"bot_id": "",
|
||||
"coze_api_base": "https://api.coze.cn",
|
||||
"timeout": 60,
|
||||
"proxy": "",
|
||||
# "auto_save_history": True,
|
||||
},
|
||||
"阿里云百炼应用": {
|
||||
@@ -1154,6 +1174,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
"proxy": "",
|
||||
},
|
||||
"FastGPT": {
|
||||
"id": "fastgpt",
|
||||
@@ -1164,6 +1185,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.fastgpt.in/api/v1",
|
||||
"timeout": 60,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
@@ -1176,6 +1198,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
"model": "whisper-1",
|
||||
"proxy": "",
|
||||
},
|
||||
"Whisper(Local)": {
|
||||
"provider": "openai",
|
||||
@@ -1205,6 +1228,7 @@ CONFIG_METADATA_2 = {
|
||||
"model": "tts-1",
|
||||
"openai-tts-voice": "alloy",
|
||||
"timeout": "20",
|
||||
"proxy": "",
|
||||
},
|
||||
"Genie TTS": {
|
||||
"id": "genie_tts",
|
||||
@@ -1285,6 +1309,7 @@ CONFIG_METADATA_2 = {
|
||||
"fishaudio-tts-character": "可莉",
|
||||
"fishaudio-tts-reference-id": "",
|
||||
"timeout": "20",
|
||||
"proxy": "",
|
||||
},
|
||||
"阿里云百炼 TTS(API)": {
|
||||
"hint": "API Key 从 https://bailian.console.aliyun.com/?tab=model#/api-key 获取。模型和音色的选择文档请参考: 阿里云百炼语音合成音色名称。具体可参考 https://help.aliyun.com/zh/model-studio/speech-synthesis-and-speech-recognition",
|
||||
@@ -1311,6 +1336,7 @@ CONFIG_METADATA_2 = {
|
||||
"azure_tts_volume": "100",
|
||||
"azure_tts_subscription_key": "",
|
||||
"azure_tts_region": "eastus",
|
||||
"proxy": "",
|
||||
},
|
||||
"MiniMax TTS(API)": {
|
||||
"id": "minimax_tts",
|
||||
@@ -1333,6 +1359,7 @@ CONFIG_METADATA_2 = {
|
||||
"minimax-voice-latex": False,
|
||||
"minimax-voice-english-normalization": False,
|
||||
"timeout": 20,
|
||||
"proxy": "",
|
||||
},
|
||||
"火山引擎_TTS(API)": {
|
||||
"id": "volcengine_tts",
|
||||
@@ -1347,6 +1374,7 @@ CONFIG_METADATA_2 = {
|
||||
"volcengine_speed_ratio": 1.0,
|
||||
"api_base": "https://openspeech.bytedance.com/api/v1/tts",
|
||||
"timeout": 20,
|
||||
"proxy": "",
|
||||
},
|
||||
"Gemini TTS": {
|
||||
"id": "gemini_tts",
|
||||
@@ -1360,6 +1388,7 @@ CONFIG_METADATA_2 = {
|
||||
"gemini_tts_model": "gemini-2.5-flash-preview-tts",
|
||||
"gemini_tts_prefix": "",
|
||||
"gemini_tts_voice_name": "Leda",
|
||||
"proxy": "",
|
||||
},
|
||||
"OpenAI Embedding": {
|
||||
"id": "openai_embedding",
|
||||
@@ -1372,6 +1401,7 @@ CONFIG_METADATA_2 = {
|
||||
"embedding_model": "",
|
||||
"embedding_dimensions": 1024,
|
||||
"timeout": 20,
|
||||
"proxy": "",
|
||||
},
|
||||
"Gemini Embedding": {
|
||||
"id": "gemini_embedding",
|
||||
@@ -1384,6 +1414,7 @@ CONFIG_METADATA_2 = {
|
||||
"embedding_model": "gemini-embedding-exp-03-07",
|
||||
"embedding_dimensions": 768,
|
||||
"timeout": 20,
|
||||
"proxy": "",
|
||||
},
|
||||
"vLLM Rerank": {
|
||||
"id": "vllm_rerank",
|
||||
@@ -2080,6 +2111,11 @@ CONFIG_METADATA_2 = {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
},
|
||||
"proxy": {
|
||||
"description": "代理地址",
|
||||
"type": "string",
|
||||
"hint": "HTTP/HTTPS 代理地址,格式如 http://127.0.0.1:7890。仅对该提供商的 API 请求生效,不影响 Docker 内网通信。",
|
||||
},
|
||||
"model": {
|
||||
"description": "模型 ID",
|
||||
"type": "string",
|
||||
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import anthropic
|
||||
import httpx
|
||||
from anthropic import AsyncAnthropic
|
||||
from anthropic.types import Message
|
||||
from anthropic.types.message_delta_usage import MessageDeltaUsage
|
||||
@@ -14,6 +15,11 @@ from astrbot.core.agent.message import ContentPart, ImageURLPart, TextPart
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.utils.network_utils import (
|
||||
create_proxy_client,
|
||||
is_connection_error,
|
||||
log_connection_failure,
|
||||
)
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
@@ -45,12 +51,18 @@ class ProviderAnthropic(Provider):
|
||||
api_key=self.chosen_api_key,
|
||||
timeout=self.timeout,
|
||||
base_url=self.base_url,
|
||||
http_client=self._create_http_client(provider_config),
|
||||
)
|
||||
|
||||
self.thinking_config = provider_config.get("anth_thinking_config", {})
|
||||
|
||||
self.set_model(provider_config.get("model", "unknown"))
|
||||
|
||||
def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None:
|
||||
"""创建带代理的 HTTP 客户端"""
|
||||
proxy = provider_config.get("proxy", "")
|
||||
return create_proxy_client("Anthropic", proxy)
|
||||
|
||||
def _prepare_payload(self, messages: list[dict]):
|
||||
"""准备 Anthropic API 的请求 payload
|
||||
|
||||
@@ -207,9 +219,19 @@ class ProviderAnthropic(Provider):
|
||||
"type": "enabled",
|
||||
}
|
||||
|
||||
completion = await self.client.messages.create(
|
||||
**payloads, stream=False, extra_body=extra_body
|
||||
)
|
||||
try:
|
||||
completion = await self.client.messages.create(
|
||||
**payloads, stream=False, extra_body=extra_body
|
||||
)
|
||||
except httpx.RequestError as e:
|
||||
proxy = self.provider_config.get("proxy", "")
|
||||
log_connection_failure("Anthropic", e, proxy)
|
||||
raise
|
||||
except Exception as e:
|
||||
if is_connection_error(e):
|
||||
proxy = self.provider_config.get("proxy", "")
|
||||
log_connection_failure("Anthropic", e, proxy)
|
||||
raise
|
||||
|
||||
assert isinstance(completion, Message)
|
||||
logger.debug(f"completion: {completion}")
|
||||
@@ -622,3 +644,7 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
def set_key(self, key: str):
|
||||
self.chosen_api_key = key
|
||||
|
||||
async def terminate(self):
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
|
||||
@@ -10,6 +10,7 @@ from xml.sax.saxutils import escape
|
||||
|
||||
from httpx import AsyncClient, Timeout
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
|
||||
from ..entities import ProviderType
|
||||
@@ -29,6 +30,9 @@ class OTTSProvider:
|
||||
self.last_sync_time = 0
|
||||
self.timeout = Timeout(10.0)
|
||||
self.retry_count = 3
|
||||
self.proxy = config.get("proxy", "")
|
||||
if self.proxy:
|
||||
logger.info(f"[Azure TTS] 使用代理: {self.proxy}")
|
||||
self._client: AsyncClient | None = None
|
||||
|
||||
@property
|
||||
@@ -40,7 +44,9 @@ class OTTSProvider:
|
||||
return self._client
|
||||
|
||||
async def __aenter__(self):
|
||||
self._client = AsyncClient(timeout=self.timeout)
|
||||
self._client = AsyncClient(
|
||||
timeout=self.timeout, proxy=self.proxy if self.proxy else None
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
@@ -125,6 +131,9 @@ class AzureNativeProvider(TTSProvider):
|
||||
"rate": provider_config.get("azure_tts_rate", "1"),
|
||||
"volume": provider_config.get("azure_tts_volume", "100"),
|
||||
}
|
||||
self.proxy = provider_config.get("proxy", "")
|
||||
if self.proxy:
|
||||
logger.info(f"[Azure TTS Native] 使用代理: {self.proxy}")
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
@@ -141,6 +150,7 @@ class AzureNativeProvider(TTSProvider):
|
||||
"Content-Type": "application/ssml+xml",
|
||||
"X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm",
|
||||
},
|
||||
proxy=self.proxy if self.proxy else None,
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import ormsgpack
|
||||
from httpx import AsyncClient
|
||||
from pydantic import BaseModel, conint
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from ..entities import ProviderType
|
||||
@@ -60,6 +61,9 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
self.timeout: int = int(provider_config.get("timeout", 20))
|
||||
except ValueError:
|
||||
self.timeout = 20
|
||||
self.proxy: str = provider_config.get("proxy", "")
|
||||
if self.proxy:
|
||||
logger.info(f"[FishAudio TTS] 使用代理: {self.proxy}")
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.chosen_api_key}",
|
||||
}
|
||||
@@ -79,7 +83,10 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
|
||||
"""
|
||||
sort_options = ["score", "task_count", "created_at"]
|
||||
async with AsyncClient(base_url=self.api_base.replace("/v1", "")) as client:
|
||||
async with AsyncClient(
|
||||
base_url=self.api_base.replace("/v1", ""),
|
||||
proxy=self.proxy if self.proxy else None,
|
||||
) as client:
|
||||
for sort_by in sort_options:
|
||||
params = {"title": character, "sort_by": sort_by}
|
||||
response = await client.get(
|
||||
@@ -139,7 +146,11 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav")
|
||||
self.headers["content-type"] = "application/msgpack"
|
||||
request = await self._generate_request(text)
|
||||
async with AsyncClient(base_url=self.api_base, timeout=self.timeout).stream(
|
||||
async with AsyncClient(
|
||||
base_url=self.api_base,
|
||||
timeout=self.timeout,
|
||||
proxy=self.proxy if self.proxy else None,
|
||||
).stream(
|
||||
"POST",
|
||||
"/tts",
|
||||
headers=self.headers,
|
||||
|
||||
@@ -4,6 +4,8 @@ from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.errors import APIError
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import EmbeddingProvider
|
||||
from ..register import register_provider_adapter
|
||||
@@ -28,6 +30,10 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
if api_base:
|
||||
api_base = api_base.removesuffix("/")
|
||||
http_options.base_url = api_base
|
||||
proxy = provider_config.get("proxy", "")
|
||||
if proxy:
|
||||
http_options.async_client_args = {"proxy": proxy}
|
||||
logger.info(f"[Gemini Embedding] 使用代理: {proxy}")
|
||||
|
||||
self.client = genai.Client(api_key=api_key, http_options=http_options).aio
|
||||
|
||||
@@ -69,3 +75,7 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
return int(self.provider_config.get("embedding_dimensions", 768))
|
||||
|
||||
async def terminate(self):
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
|
||||
@@ -18,6 +18,7 @@ from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
@@ -74,12 +75,17 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
def _init_client(self) -> None:
|
||||
"""初始化Gemini客户端"""
|
||||
proxy = self.provider_config.get("proxy", "")
|
||||
http_options = types.HttpOptions(
|
||||
base_url=self.api_base,
|
||||
timeout=self.timeout * 1000, # 毫秒
|
||||
)
|
||||
if proxy:
|
||||
http_options.async_client_args = {"proxy": proxy}
|
||||
logger.info(f"[Gemini] 使用代理: {proxy}")
|
||||
self.client = genai.Client(
|
||||
api_key=self.chosen_api_key,
|
||||
http_options=types.HttpOptions(
|
||||
base_url=self.api_base,
|
||||
timeout=self.timeout * 1000, # 毫秒
|
||||
),
|
||||
http_options=http_options,
|
||||
).aio
|
||||
|
||||
def _init_safety_settings(self) -> None:
|
||||
@@ -113,9 +119,12 @@ class ProviderGoogleGenAI(Provider):
|
||||
f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}...",
|
||||
)
|
||||
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
|
||||
# logger.error(
|
||||
# f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}",
|
||||
# )
|
||||
|
||||
# 连接错误处理
|
||||
if is_connection_error(e):
|
||||
proxy = self.provider_config.get("proxy", "")
|
||||
log_connection_failure("Gemini", e, proxy)
|
||||
|
||||
raise e
|
||||
|
||||
async def _prepare_query_config(
|
||||
@@ -920,4 +929,5 @@ class ProviderGoogleGenAI(Provider):
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
|
||||
async def terminate(self):
|
||||
logger.info("Google GenAI 适配器已终止。")
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
|
||||
@@ -5,6 +5,7 @@ import wave
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from ..entities import ProviderType
|
||||
@@ -32,6 +33,10 @@ class ProviderGeminiTTSAPI(TTSProvider):
|
||||
if api_base:
|
||||
api_base = api_base.removesuffix("/")
|
||||
http_options.base_url = api_base
|
||||
proxy = provider_config.get("proxy", "")
|
||||
if proxy:
|
||||
http_options.async_client_args = {"proxy": proxy}
|
||||
logger.info(f"[Gemini TTS] 使用代理: {proxy}")
|
||||
|
||||
self.client = genai.Client(api_key=api_key, http_options=http_options).aio
|
||||
self.model: str = provider_config.get(
|
||||
@@ -79,3 +84,7 @@ class ProviderGeminiTTSAPI(TTSProvider):
|
||||
wf.writeframes(response.candidates[0].content.parts[0].inline_data.data)
|
||||
|
||||
return path
|
||||
|
||||
async def terminate(self):
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import httpx
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import EmbeddingProvider
|
||||
from ..register import register_provider_adapter
|
||||
@@ -15,6 +18,11 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
proxy = provider_config.get("proxy", "")
|
||||
http_client = None
|
||||
if proxy:
|
||||
logger.info(f"[OpenAI Embedding] 使用代理: {proxy}")
|
||||
http_client = httpx.AsyncClient(proxy=proxy)
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=provider_config.get("embedding_api_key"),
|
||||
base_url=provider_config.get(
|
||||
@@ -22,6 +30,7 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"https://api.openai.com/v1",
|
||||
),
|
||||
timeout=int(provider_config.get("timeout", 20)),
|
||||
http_client=http_client,
|
||||
)
|
||||
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
|
||||
|
||||
@@ -38,3 +47,7 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
return int(self.provider_config.get("embedding_dimensions", 1024))
|
||||
|
||||
async def terminate(self):
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
|
||||
@@ -2,11 +2,11 @@ import asyncio
|
||||
import base64
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
from openai._exceptions import NotFoundError
|
||||
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
||||
@@ -22,6 +22,11 @@ from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.utils.network_utils import (
|
||||
create_proxy_client,
|
||||
is_connection_error,
|
||||
log_connection_failure,
|
||||
)
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
@@ -31,6 +36,11 @@ from ..register import register_provider_adapter
|
||||
"OpenAI API Chat Completion 提供商适配器",
|
||||
)
|
||||
class ProviderOpenAIOfficial(Provider):
|
||||
def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None:
|
||||
"""创建带代理的 HTTP 客户端"""
|
||||
proxy = provider_config.get("proxy", "")
|
||||
return create_proxy_client("OpenAI", proxy)
|
||||
|
||||
def __init__(self, provider_config, provider_settings) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key = None
|
||||
@@ -55,6 +65,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
default_headers=self.custom_headers,
|
||||
base_url=provider_config.get("api_base", ""),
|
||||
timeout=self.timeout,
|
||||
http_client=self._create_http_client(provider_config),
|
||||
)
|
||||
else:
|
||||
# Using OpenAI Official API
|
||||
@@ -63,6 +74,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
base_url=provider_config.get("api_base", None),
|
||||
default_headers=self.custom_headers,
|
||||
timeout=self.timeout,
|
||||
http_client=self._create_http_client(provider_config),
|
||||
)
|
||||
|
||||
self.default_params = inspect.signature(
|
||||
@@ -455,12 +467,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if "tool" in str(e).lower() and "support" in str(e).lower():
|
||||
logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
|
||||
|
||||
if "Connection error." in str(e):
|
||||
proxy = os.environ.get("http_proxy", None)
|
||||
if proxy:
|
||||
logger.error(
|
||||
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}",
|
||||
)
|
||||
if is_connection_error(e):
|
||||
proxy = self.provider_config.get("proxy", "")
|
||||
log_connection_failure("OpenAI", e, proxy)
|
||||
|
||||
raise e
|
||||
|
||||
@@ -697,3 +706,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
with open(image_url, "rb") as f:
|
||||
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
|
||||
async def terminate(self):
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import httpx
|
||||
from openai import NOT_GIVEN, AsyncOpenAI
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from ..entities import ProviderType
|
||||
@@ -29,10 +31,16 @@ class ProviderOpenAITTSAPI(TTSProvider):
|
||||
if isinstance(timeout, str):
|
||||
timeout = int(timeout)
|
||||
|
||||
proxy = provider_config.get("proxy", "")
|
||||
http_client = None
|
||||
if proxy:
|
||||
logger.info(f"[OpenAI TTS] 使用代理: {proxy}")
|
||||
http_client = httpx.AsyncClient(proxy=proxy)
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=provider_config.get("api_base"),
|
||||
timeout=timeout,
|
||||
http_client=http_client,
|
||||
)
|
||||
|
||||
self.set_model(provider_config.get("model", ""))
|
||||
@@ -50,3 +58,7 @@ class ProviderOpenAITTSAPI(TTSProvider):
|
||||
async for chunk in response.iter_bytes(chunk_size=1024):
|
||||
f.write(chunk)
|
||||
return path
|
||||
|
||||
async def terminate(self):
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
|
||||
@@ -107,3 +107,7 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove temp file {audio_url}: {e}")
|
||||
return result.text
|
||||
|
||||
async def terminate(self):
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
"""Network error handling utilities for providers."""
|
||||
|
||||
import httpx
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
def is_connection_error(exc: BaseException) -> bool:
|
||||
"""Check if an exception is a connection/network related error.
|
||||
|
||||
Uses explicit exception type checking instead of brittle string matching.
|
||||
Handles httpx network errors, timeouts, and common Python network exceptions.
|
||||
|
||||
Args:
|
||||
exc: The exception to check
|
||||
|
||||
Returns:
|
||||
True if the exception is a connection/network error
|
||||
"""
|
||||
# Check for httpx network errors
|
||||
if isinstance(
|
||||
exc,
|
||||
(
|
||||
httpx.ConnectError,
|
||||
httpx.ConnectTimeout,
|
||||
httpx.ReadTimeout,
|
||||
httpx.WriteTimeout,
|
||||
httpx.PoolTimeout,
|
||||
httpx.NetworkError,
|
||||
httpx.ProxyError,
|
||||
httpx.RequestError,
|
||||
),
|
||||
):
|
||||
return True
|
||||
|
||||
# Check for common Python network errors
|
||||
if isinstance(exc, (TimeoutError, OSError, ConnectionError)):
|
||||
return True
|
||||
|
||||
# Check the __cause__ chain for wrapped connection errors
|
||||
cause = getattr(exc, "__cause__", None)
|
||||
if cause is not None and cause is not exc:
|
||||
return is_connection_error(cause)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def log_connection_failure(
|
||||
provider_label: str,
|
||||
error: Exception,
|
||||
proxy: str | None = None,
|
||||
) -> None:
|
||||
"""Log a connection failure with proxy information.
|
||||
|
||||
If proxy is not provided, will fallback to check os.environ for
|
||||
http_proxy/https_proxy environment variables.
|
||||
|
||||
Args:
|
||||
provider_label: The provider name for log prefix (e.g., "OpenAI", "Gemini")
|
||||
error: The exception that occurred
|
||||
proxy: The proxy address if configured, or None/empty string
|
||||
"""
|
||||
import os
|
||||
|
||||
error_type = type(error).__name__
|
||||
|
||||
# Fallback to environment proxy if not configured
|
||||
effective_proxy = proxy
|
||||
if not effective_proxy:
|
||||
effective_proxy = os.environ.get(
|
||||
"http_proxy", os.environ.get("https_proxy", "")
|
||||
)
|
||||
|
||||
if effective_proxy:
|
||||
logger.error(
|
||||
f"[{provider_label}] 网络/代理连接失败 ({error_type})。"
|
||||
f"代理地址: {effective_proxy},错误: {error}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"[{provider_label}] 网络连接失败 ({error_type}),未配置代理。错误: {error}"
|
||||
)
|
||||
|
||||
|
||||
def create_proxy_client(
|
||||
provider_label: str,
|
||||
proxy: str | None = None,
|
||||
) -> httpx.AsyncClient | None:
|
||||
"""Create an httpx AsyncClient with proxy configuration if provided.
|
||||
|
||||
Note: The caller is responsible for closing the client when done.
|
||||
Consider using the client as a context manager or calling aclose() explicitly.
|
||||
|
||||
Args:
|
||||
provider_label: The provider name for log prefix (e.g., "OpenAI", "Gemini")
|
||||
proxy: The proxy address (e.g., "http://127.0.0.1:7890"), or None/empty
|
||||
|
||||
Returns:
|
||||
An httpx.AsyncClient configured with the proxy, or None if no proxy
|
||||
"""
|
||||
if proxy:
|
||||
logger.info(f"[{provider_label}] 使用代理: {proxy}")
|
||||
return httpx.AsyncClient(proxy=proxy)
|
||||
return None
|
||||
@@ -233,6 +233,11 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
|
||||
customSchema.provider.items.key.hint = tm('providerSources.hints.key')
|
||||
customSchema.provider.items.api_base.hint = tm('providerSources.hints.apiBase')
|
||||
}
|
||||
// 为 proxy 字段添加描述和提示
|
||||
if (customSchema.provider?.items?.proxy) {
|
||||
customSchema.provider.items.proxy.description = tm('providerSources.labels.proxy')
|
||||
customSchema.provider.items.proxy.hint = tm('providerSources.hints.proxy')
|
||||
}
|
||||
|
||||
return customSchema
|
||||
})
|
||||
|
||||
@@ -113,7 +113,11 @@
|
||||
"hints": {
|
||||
"id": "Provider source ID (not provider ID)",
|
||||
"key": "API key for authentication",
|
||||
"apiBase": "Custom API endpoint URL"
|
||||
"apiBase": "Custom API endpoint URL",
|
||||
"proxy": "HTTP/HTTPS proxy address, e.g. http://127.0.0.1:7890. Only affects this provider's API requests, doesn't interfere with Docker internal networking."
|
||||
},
|
||||
"labels": {
|
||||
"proxy": "Proxy"
|
||||
}
|
||||
},
|
||||
"models": {
|
||||
@@ -142,4 +146,4 @@
|
||||
"modelId": "Model ID"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -114,7 +114,11 @@
|
||||
"hints": {
|
||||
"id": "提供商源唯一 ID(不是提供商 ID)",
|
||||
"key": "API 密钥",
|
||||
"apiBase": "自定义 API 端点 URL"
|
||||
"apiBase": "自定义 API 端点 URL",
|
||||
"proxy": "HTTP/HTTPS 代理地址,格式如 http://127.0.0.1:7890。仅对该提供商的 API 请求生效,不影响 Docker 内网通信。"
|
||||
},
|
||||
"labels": {
|
||||
"proxy": "代理地址"
|
||||
}
|
||||
},
|
||||
"models": {
|
||||
@@ -143,4 +147,4 @@
|
||||
"modelId": "模型 ID"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user