From 7d4c07e4f6bf279189d3891c97343336e09b9141 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 6 Feb 2025 12:31:39 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E8=AE=BE=E7=BD=AE=20?= =?UTF-8?q?timeout?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 12 ++++++++++++ .../core/provider/sources/gemini_source.py | 19 ++++++++++++------- .../core/provider/sources/openai_source.py | 13 ++++++++----- 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index eec29cb47..b687adbde 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -338,6 +338,7 @@ CONFIG_METADATA_2 = { "enable": True, "key": [], "api_base": "https://api.openai.com/v1", + "timeout": 120, "model_config": { "model": "gpt-4o-mini", }, @@ -349,6 +350,7 @@ CONFIG_METADATA_2 = { "api_version": "2024-05-01-preview", "key": [], "api_base": "", + "timeout": 120, "model_config": { "model": "gpt-4o-mini", }, @@ -369,6 +371,7 @@ CONFIG_METADATA_2 = { "enable": True, "key": [], "api_base": "https://generativelanguage.googleapis.com/v1beta/openai/", + "timeout": 120, "model_config": { "model": "gemini-1.5-flash", }, @@ -379,6 +382,7 @@ CONFIG_METADATA_2 = { "enable": True, "key": [], "api_base": "https://generativelanguage.googleapis.com/", + "timeout": 120, "model_config": { "model": "gemini-1.5-flash", }, @@ -389,6 +393,7 @@ CONFIG_METADATA_2 = { "enable": True, "key": [], "api_base": "https://api.deepseek.com/v1", + "timeout": 120, "model_config": { "model": "deepseek-chat", }, @@ -398,6 +403,7 @@ CONFIG_METADATA_2 = { "type": "zhipu_chat_completion", "enable": True, "key": [], + "timeout": 120, "api_base": "https://open.bigmodel.cn/api/paas/v4/", "model_config": { "model": "glm-4-flash", @@ -408,6 +414,7 @@ CONFIG_METADATA_2 = { "type": "openai_chat_completion", "enable": True, "key": [], + "timeout": 120, "api_base": "https://api.siliconflow.cn/v1", "model_config": { "model": "deepseek-ai/DeepSeek-V3", @@ -459,6 +466,11 @@ CONFIG_METADATA_2 = { }, }, "items": { + "timeout": { + "description": "超时时间", + "type": "int", + "hint": "超时时间,单位为秒。", + }, "openai-tts-voice": { "description": "voice", "type": "string", diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 3de22a4fc..63cf7f9a2 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -1,6 +1,4 @@ -import traceback import base64 -import json import aiohttp from astrbot.core.utils.io import download_image_by_url from astrbot.core.db import BaseDatabase @@ -12,17 +10,18 @@ from ..register import register_provider_adapter from astrbot.core.provider.entites import LLMResponse class SimpleGoogleGenAIClient(): - def __init__(self, api_key: str, api_base: str): + def __init__(self, api_key: str, api_base: str, timeout: int=120) -> None: self.api_key = api_key if api_base.endswith("/"): self.api_base = api_base[:-1] else: self.api_base = api_base self.client = aiohttp.ClientSession(trust_env=True) + self.timeout = timeout async def models_list(self) -> List[str]: request_url = f"{self.api_base}/v1beta/models?key={self.api_key}" - async with self.client.get(request_url, timeout=10) as resp: + async with self.client.get(request_url, timeout=self.timeout) as resp: response = await resp.json() models = [] @@ -48,7 +47,7 @@ class SimpleGoogleGenAIClient(): payload["contents"] = contents logger.debug(f"payload: {payload}") request_url = f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}" - async with self.client.post(request_url, json=payload, timeout=10) as resp: + async with self.client.post(request_url, json=payload, timeout=self.timeout) as resp: response = await resp.json() return response @@ -67,10 +66,13 @@ class ProviderGoogleGenAI(Provider): self.chosen_api_key = None self.api_keys: List = provider_config.get("key", []) self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None - + self.timeout = provider_config.get("timeout", 180) + if isinstance(self.timeout, str): + self.timeout = int(self.timeout) self.client = SimpleGoogleGenAIClient( api_key=self.chosen_api_key, - api_base=provider_config.get("api_base", None) + api_base=provider_config.get("api_base", None), + timeout=self.timeout ) self.set_model(provider_config['model_config']['model']) @@ -224,6 +226,9 @@ class ProviderGoogleGenAI(Provider): if image_url.startswith("http"): image_path = await download_image_by_url(image_url) image_data = await self.encode_image_bs64(image_path) + elif image_url.startswith("file:///"): + image_path = image_url.replace("file:///", "") + image_data = await self.encode_image_bs64(image_path) else: image_data = await self.encode_image_bs64(image_url) user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}}) diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 996693b13..4b4b678ec 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -29,7 +29,9 @@ class ProviderOpenAIOfficial(Provider): self.chosen_api_key = None self.api_keys: List = provider_config.get("key", []) self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None - + self.timeout = provider_config.get("timeout", 120) + if isinstance(self.timeout, str): + self.timeout = int(self.timeout) # 适配 azure openai #332 if "api_version" in provider_config: # 使用 azure api @@ -37,14 +39,14 @@ class ProviderOpenAIOfficial(Provider): api_key=self.chosen_api_key, api_version=provider_config.get("api_version", None), base_url=provider_config.get("api_base", None), - timeout=provider_config.get("timeout", NOT_GIVEN), + timeout=self.timeout ) else: # 使用 openai api self.client = AsyncOpenAI( api_key=self.chosen_api_key, base_url=provider_config.get("api_base", None), - timeout=provider_config.get("timeout", NOT_GIVEN), + timeout=self.timeout ) self.set_model(provider_config['model_config']['model']) @@ -227,9 +229,10 @@ class ProviderOpenAIOfficial(Provider): if image_url.startswith("http"): image_path = await download_image_by_url(image_url) image_data = await self.encode_image_bs64(image_path) + elif image_url.startswith("file:///"): + image_path = image_url.replace("file:///", "") + image_data = await self.encode_image_bs64(image_path) else: - if image_url.startswith("file:///"): - image_url = image_url.replace("file:///", "") image_data = await self.encode_image_bs64(image_url) user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}}) return user_content