From 0b766095d4c2f0fa6d6a2a13ed213d316a73affe Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Fri, 11 Apr 2025 01:03:16 +0800 Subject: [PATCH 01/19] =?UTF-8?q?refactor:=20=E5=88=9D=E6=AD=A5=E5=AE=8C?= =?UTF-8?q?=E6=88=90gemini=5Fsource=E7=9A=84=E9=87=8D=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/gemini_source.py | 428 ++++++++---------- 1 file changed, 193 insertions(+), 235 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 7def203db..c9da2ec9f 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -1,8 +1,9 @@ import base64 -import aiohttp import json import random import asyncio +from google import genai +from google.genai import types, errors import astrbot.core.message.components as Comp from astrbot.core.message.message_event_result import MessageChain from astrbot.core.utils.io import download_image_by_url @@ -10,112 +11,28 @@ from astrbot.core.db import BaseDatabase from astrbot.api.provider import Provider, Personality from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall -from typing import List +from typing import Dict, List from ..register import register_provider_adapter from astrbot.core.provider.entities import LLMResponse - -class SimpleGoogleGenAIClient: - def __init__(self, api_key: str, api_base: str, timeout: int = 120) -> None: - self.api_key = api_key - if api_base.endswith("/"): - self.api_base = api_base[:-1] - else: - self.api_base = api_base - self.client = aiohttp.ClientSession(trust_env=True) - self.timeout = timeout - - async def models_list(self) -> List[str]: - request_url = f"{self.api_base}/v1beta/models?key={self.api_key}" - async with self.client.get(request_url, timeout=self.timeout) as resp: - response = await resp.json() - - models = [] - for model in response["models"]: - if "generateContent" in model["supportedGenerationMethods"]: - models.append(model["name"].replace("models/", "")) - return models - - async def generate_content( - self, - contents: List[dict], - model: str = "gemini-1.5-flash", - system_instruction: str = "", - tools: dict = None, - modalities: List[str] = ["Text"], - safety_settings: List[dict] = [], - ): - payload = {} - if system_instruction: - payload["system_instruction"] = {"parts": {"text": system_instruction}} - if tools: - payload["tools"] = [tools] - payload["contents"] = contents - payload["generationConfig"] = { - "responseModalities": modalities, - } - payload["safetySettings"] = [ - {"category": s["category"], "threshold": s["threshold"]} - for s in safety_settings - ] - logger.debug(f"payload: {payload}") - request_url = ( - f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}" - ) - async with self.client.post( - request_url, json=payload, timeout=self.timeout - ) as resp: - if "application/json" in resp.headers.get("Content-Type"): - try: - response = await resp.json() - except Exception as e: - text = await resp.text() - logger.error(f"Gemini 返回了非 json 数据: {text}") - raise e - return response - else: - text = await resp.text() - logger.error(f"Gemini 返回了非 json 数据: {text}") - raise Exception("Gemini 返回了非 json 数据: ") - - async def stream_generate_content( - self, - contents: List[dict], - model: str = "gemini-1.5-flash", - system_instruction: str = "", - tools: dict = None, - modalities: List[str] = ["Text"], - safety_settings: List[dict] = [], - ): - payload = {} - if system_instruction: - payload["system_instruction"] = {"parts": {"text": system_instruction}} - if tools: - payload["tools"] = [tools] - payload["contents"] = contents - payload["generationConfig"] = { - "responseModalities": modalities, - "stream": True, - } - payload["safetySettings"] = [ - {"category": s["category"], "threshold": s["threshold"]} - for s in safety_settings - ] - logger.debug(f"payload: {payload}") - request_url = ( - f"{self.api_base}/v1beta/models/{model}:streamGenerateContent?key={self.api_key}" - ) - async with self.client.post( - request_url, json=payload, timeout=self.timeout - ) as resp: - async for line in resp.content: - if line: - yield line - @register_provider_adapter( "googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器" ) class ProviderGoogleGenAI(Provider): + CATEGORY_MAPPING = { + "harassment": types.HarmCategory.HARM_CATEGORY_HARASSMENT, + "hate_speech": types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + "sexually_explicit": types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "dangerous_content": types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + } + + THRESHOLD_MAPPING = { + "BLOCK_NONE": types.HarmBlockThreshold.BLOCK_NONE, + "BLOCK_ONLY_HIGH": types.HarmBlockThreshold.BLOCK_ONLY_HIGH, + "BLOCK_MEDIUM_AND_ABOVE": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + "BLOCK_LOW_AND_ABOVE": types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + } + def __init__( self, provider_config: dict, @@ -131,43 +48,145 @@ class ProviderGoogleGenAI(Provider): db_helper, default_persona, ) - self.chosen_api_key = None self.api_keys: List = provider_config.get("key", []) - self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None - self.timeout = provider_config.get("timeout", 180) + self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None + self.timeout: int = provider_config.get("timeout", 180) + self.api_base: str = provider_config.get("api_base", None) + if self.api_base.endswith("/"): + self.api_base = self.api_base[:-1] if isinstance(self.timeout, str): self.timeout = int(self.timeout) - self.client = SimpleGoogleGenAIClient( + self.client = genai.Client( api_key=self.chosen_api_key, - api_base=provider_config.get("api_base", None), - timeout=self.timeout, - ) + http_options=types.HttpOptions( + base_url=self.api_base, + timeout=self.timeout * 1000, # 毫秒 + ), + ).aio self.set_model(provider_config["model_config"]["model"]) - safety_mapping = { - "harassment": "HARM_CATEGORY_HARASSMENT", - "hate_speech": "HARM_CATEGORY_HATE_SPEECH", - "sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT", - "dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT", - } - - self.safety_settings = [] user_safety_config = self.provider_config.get("gm_safety_settings", {}) - for config_key, harm_category in safety_mapping.items(): - if threshold := user_safety_config.get(config_key): - self.safety_settings.append( - {"category": harm_category, "threshold": threshold} - ) + self.safety_settings = [ + types.SafetySetting( + category=harm_category, threshold=self.THRESHOLD_MAPPING[threshold_str] + ) + for config_key, harm_category in self.CATEGORY_MAPPING.items() + if (threshold_str := user_safety_config.get(config_key)) + and threshold_str in self.THRESHOLD_MAPPING + ] async def get_models(self): - return await self.client.models_list() + try: + models = await self.client.models.list() + return [ + m.name.replace("models/", "") + for m in models + if "generateContent" in m.supported_actions + ] + except errors.APIError as e: + raise Exception(f"获取模型列表失败: {e}") + + def _prepare_conversation( + self, + payloads: Dict, + ) -> List[types.Content]: + """准备 Gemini SDK 的 Content 列表""" + gemini_contents = [] + for message in payloads["messages"]: + role = message["role"] + content = message.get("content") + + if role == "user": + if isinstance(content, str): + if content: + gemini_contents.append( + types.UserContent( + parts=[types.Part.from_text(text=content)] + ) + ) + else: + logger.warning("文本内容为空,已添加空格占位") + gemini_contents.append( + types.UserContent(parts=[types.Part.from_text(text=" ")]) + ) + + elif isinstance(content, list): + parts = [] + for item in content: + if item.get("type") == "text": + text_content = item.get("text") + if text_content: + parts.append(types.Part.from_text(text=text_content)) + else: + logger.warning("文本内容为空,已添加空格占位") + parts.append(types.Part.from_text(text=" ")) + elif item.get("type") == "image_url": + image_url_dict = item["image_url"] + url = image_url_dict["url"] + mime_part, base64_data = url.split(",", 1) + mime_type = mime_part.split(":")[1].split(";")[0] + image_bytes = base64.b64decode(base64_data) + parts.append( + types.Part.from_bytes( + data=image_bytes, mime_type=mime_type + ) + ) + gemini_contents.append(types.UserContent(parts=parts)) + + elif role == "assistant": + if content: + gemini_contents.append( + types.ModelContent( + parts=[types.Part.from_text(text=message["content"])] + ) + ) + elif "tool_calls" in message: + parts = [ + { + "name": tool_call["function"]["name"], + "args": json.loads(tool_call["function"]["arguments"]), + } + for tool_call in message["tool_calls"] + ] + gemini_contents.append( + types.ModelContent(parts=[types.Part.from_function_call(parts)]) + ) + else: + logger.warning("assistant 角色的消息内容为空,已添加空格占位") + gemini_contents.append( + types.ModelContent(parts=[types.Part.from_text(text=" ")]) + ) + + elif role == "tool": + gemini_contents.append( + types.UserContent( + parts=[ + types.Part.from_function_response( + { + "name": message["tool_call_id"], + "response": { + "name": message["tool_call_id"], + "content": message["content"], + }, + } + ) + ] + ) + ) + + logger.debug(f"gemini_contents: {gemini_contents}") + + return gemini_contents async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: - tool = None + """非流式请求 Gemini API""" if tools: - tool = tools.get_func_desc_google_genai_style() - if not tool: - tool = None + t = tools.get_func_desc_google_genai_style() + tool = ( + types.Tool(function_declarations=t["function_declarations"]) + if t + else None + ) system_instruction = "" for message in payloads["messages"]: @@ -175,137 +194,78 @@ class ProviderGoogleGenAI(Provider): system_instruction = message["content"] break - google_genai_conversation = [] - for message in payloads["messages"]: - if message["role"] == "user": - if isinstance(message["content"], str): - if not message["content"]: - message["content"] = "" - - google_genai_conversation.append( - {"role": "user", "parts": [{"text": message["content"]}]} - ) - elif isinstance(message["content"], list): - # images - parts = [] - for part in message["content"]: - if part["type"] == "text": - if not part["text"]: - part["text"] = "" - parts.append({"text": part["text"]}) - elif part["type"] == "image_url": - parts.append( - { - "inline_data": { - "mime_type": "image/jpeg", - "data": part["image_url"]["url"].replace( - "data:image/jpeg;base64,", "" - ), # base64 - } - } - ) - google_genai_conversation.append({"role": "user", "parts": parts}) - - elif message["role"] == "assistant": - if "content" in message: - if not message["content"]: - message["content"] = "" - google_genai_conversation.append( - {"role": "model", "parts": [{"text": message["content"]}]} - ) - elif "tool_calls" in message: - # tool calls in the last turn - parts = [] - for tool_call in message["tool_calls"]: - parts.append( - { - "functionCall": { - "name": tool_call["function"]["name"], - "args": json.loads( - tool_call["function"]["arguments"] - ), - } - } - ) - google_genai_conversation.append({"role": "model", "parts": parts}) - elif message["role"] == "tool": - parts = [] - parts.append( - { - "functionResponse": { - "name": message["tool_call_id"], - "response": { - "name": message["tool_call_id"], - "content": message["content"], - }, - } - } - ) - google_genai_conversation.append({"role": "user", "parts": parts}) - - logger.debug(f"google_genai_conversation: {google_genai_conversation}") + conversation = self._prepare_conversation(payloads) modalites = ["Text"] if self.provider_config.get("gm_resp_image_modal", False): modalites.append("Image") loop = True + while loop: loop = False - result = await self.client.generate_content( - contents=google_genai_conversation, + result = await self.client.models.generate_content( model=self.get_model(), - system_instruction=system_instruction, - tools=tool, - modalities=modalites, - safety_settings=self.safety_settings, + contents=conversation, + config=types.GenerateContentConfig( + system_instruction=system_instruction, + tools=[tool] if tool else None, + safety_settings=self.safety_settings + if self.safety_settings + else None, + automatic_function_calling=types.AutomaticFunctionCallingConfig( + disable=True + ), + ), ) - logger.debug(f"result: {result}") + logger.debug(f"gemini result: {result}") - # Developer instruction is not enabled for models/gemini-2.0-flash-exp if "Developer instruction is not enabled" in str(result): - logger.warning( - f"{self.get_model()} 不支持 system prompt, 已自动去除, 将会影响人格设置。" - ) + logger.warning(f"{self.get_model()} 不支持 system prompt,已自动去除。") system_instruction = "" loop = True - + # 不支持函数调用的模型SDK似乎会自动去除,保险起见不删除此行判断。 elif "Function calling is not enabled" in str(result): - logger.warning( - f"{self.get_model()} 不支持函数调用,已自动去除,不影响使用。" - ) + logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除。") tool = None loop = True - elif "Multi-modal output is not supported" in str(result): - logger.warning( - f"{self.get_model()} 不支持多模态输出,降级为文本模态重新请求。" - ) + logger.warning(f"{self.get_model()} 不支持多模态输出,降级为文本模态。") modalites = ["Text"] loop = True - elif "candidates" not in result: - raise Exception("Gemini 返回异常结果: " + str(result)) - - candidates = result["candidates"][0]["content"]["parts"] llm_response = LLMResponse("assistant") chain = [] - for candidate in candidates: - if "text" in candidate: - chain.append(Comp.Plain(candidate["text"])) - elif "functionCall" in candidate: + + finish_reason = result.candidates[0].finish_reason + + if finish_reason == types.FinishReason.SAFETY: + raise Exception("模型生成内容未通过用户定义的内容安全检查") + + if finish_reason in { + types.FinishReason.PROHIBITED_CONTENT, + types.FinishReason.SPII, + types.FinishReason.BLOCKLIST, + types.FinishReason.IMAGE_SAFETY, + }: + raise Exception("模型生成内容违反Gemini平台政策") + + if not result.candidates[0].content.parts: + raise Exception("API 返回的内容为空。") + + for part in result.candidates[0].content.parts: + if part.text: + chain.append(Comp.Plain(part.text)) + elif part.function_call: llm_response.role = "tool" - llm_response.tools_call_args.append(candidate["functionCall"]["args"]) - llm_response.tools_call_name.append(candidate["functionCall"]["name"]) - llm_response.tools_call_ids.append( - candidate["functionCall"]["name"] - ) # 没有 tool id - elif "inlineData" in candidate: - mime_type: str = candidate["inlineData"]["mimeType"] - if mime_type.startswith("image/"): - chain.append(Comp.Image.fromBase64(candidate["inlineData"]["data"])) + llm_response.tools_call_name.append(part.function_call.name) + llm_response.tools_call_args.append(part.function_call.args) + llm_response.tools_call_ids.append(part.function_call.id) + elif part.inline_data and part.inline_data.mime_type.startswith("image/"): + chain.append(Comp.Image.fromBytes(part.inline_data.data)) llm_response.result_chain = MessageChain(chain=chain) + return llm_response async def text_chat( @@ -320,7 +280,6 @@ class ProviderGoogleGenAI(Provider): **kwargs, ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls) - context_query = [] context_query = [*contexts, new_record] if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) @@ -345,7 +304,7 @@ class ProviderGoogleGenAI(Provider): for i in range(retry): try: - self.client.api_key = chosen_key + self.chosen_api_key = chosen_key llm_response = await self._query(payloads, func_tool) break except Exception as e: @@ -399,13 +358,13 @@ class ProviderGoogleGenAI(Provider): yield llm_response def get_current_key(self) -> str: - return self.client.api_key + return self.chosen_api_key def get_keys(self) -> List[str]: return self.api_keys def set_key(self, key): - self.client.api_key = key + self.chosen_api_key = key async def assemble_context(self, text: str, image_urls: List[str] = None): """ @@ -444,5 +403,4 @@ class ProviderGoogleGenAI(Provider): return "" async def terminate(self): - await self.client.client.close() logger.info("Google GenAI 适配器已终止。") From 4244d376250ecf0dbee58ab87944b8748af4b006 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Fri, 11 Apr 2025 01:06:20 +0800 Subject: [PATCH 02/19] =?UTF-8?q?chore:=20=E6=A0=BC=E5=BC=8F=E5=8C=96?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=EF=BC=8C=E7=A6=81=E7=94=A8gemini=20source=20?= =?UTF-8?q?debug=E8=BE=93=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/gemini_source.py | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index c9da2ec9f..527ebf348 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -1,19 +1,23 @@ +import asyncio import base64 import json import random -import asyncio -from google import genai -from google.genai import types, errors -import astrbot.core.message.components as Comp -from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.utils.io import download_image_by_url -from astrbot.core.db import BaseDatabase -from astrbot.api.provider import Provider, Personality -from astrbot import logger -from astrbot.core.provider.func_tool_manager import FuncCall from typing import Dict, List -from ..register import register_provider_adapter + +from google import genai +from google.genai import errors, types + +import astrbot.core.message.components as Comp +from astrbot import logger +from astrbot.api.provider import Personality, Provider +from astrbot.core.db import BaseDatabase +from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import LLMResponse +from astrbot.core.provider.func_tool_manager import FuncCall +from astrbot.core.utils.io import download_image_by_url + +from ..register import register_provider_adapter + @register_provider_adapter( "googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器" @@ -174,7 +178,7 @@ class ProviderGoogleGenAI(Provider): ) ) - logger.debug(f"gemini_contents: {gemini_contents}") + # logger.debug(f"gemini_contents: {gemini_contents}") return gemini_contents @@ -218,7 +222,7 @@ class ProviderGoogleGenAI(Provider): ), ), ) - logger.debug(f"gemini result: {result}") + # logger.debug(f"gemini result: {result}") if "Developer instruction is not enabled" in str(result): logger.warning(f"{self.get_model()} 不支持 system prompt,已自动去除。") @@ -302,7 +306,7 @@ class ProviderGoogleGenAI(Provider): keys = self.api_keys.copy() chosen_key = random.choice(keys) - for i in range(retry): + for _ in range(retry): try: self.chosen_api_key = chosen_key llm_response = await self._query(payloads, func_tool) From cc6cd96d8e8b0df3d50c63d1c481bd8f79d8ded1 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Fri, 11 Apr 2025 11:03:17 +0800 Subject: [PATCH 03/19] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=BD=9C?= =?UTF-8?q?=E5=9C=A8=E7=9A=84=E7=A9=BA=E6=B6=88=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/sources/gemini_source.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 527ebf348..b2fd0a429 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -375,7 +375,10 @@ class ProviderGoogleGenAI(Provider): 组装上下文。 """ if image_urls: - user_content = {"role": "user", "content": [{"type": "text", "text": text}]} + user_content = { + "role": "user", + "content": [{"type": "text", "text": text if text else "[图片]"}], + } for image_url in image_urls: if image_url.startswith("http"): image_path = await download_image_by_url(image_url) From 0dc5b4cdfc79bdad323f4f491a02ceac29e6bb94 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Fri, 11 Apr 2025 12:25:44 +0800 Subject: [PATCH 04/19] =?UTF-8?q?perf:=20=E5=A2=9E=E5=8A=A0=E5=AF=B9RECITA?= =?UTF-8?q?TION=E5=AE=8C=E6=88=90=E5=8E=9F=E5=9B=A0=E7=9A=84=E5=A4=84?= =?UTF-8?q?=E7=90=86=EF=BC=8C=E6=8F=90=E5=8F=96=E5=86=85=E5=AE=B9=E5=A4=84?= =?UTF-8?q?=E7=90=86=E9=80=BB=E8=BE=91=E5=88=B0=E7=8B=AC=E7=AB=8B=E6=96=B9?= =?UTF-8?q?=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/gemini_source.py | 54 +++++++++++++------ 1 file changed, 38 insertions(+), 16 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index b2fd0a429..89ff69559 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -182,7 +182,9 @@ class ProviderGoogleGenAI(Provider): return gemini_contents - async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: + async def _query( + self, payloads: dict, tools: FuncCall, temperature: float = 0.7 + ) -> LLMResponse: """非流式请求 Gemini API""" if tools: t = tools.get_func_desc_google_genai_style() @@ -205,7 +207,6 @@ class ProviderGoogleGenAI(Provider): modalites.append("Image") loop = True - while loop: loop = False result = await self.client.models.generate_content( @@ -213,6 +214,7 @@ class ProviderGoogleGenAI(Provider): contents=conversation, config=types.GenerateContentConfig( system_instruction=system_instruction, + temperature=temperature, tools=[tool] if tool else None, safety_settings=self.safety_settings if self.safety_settings @@ -222,26 +224,35 @@ class ProviderGoogleGenAI(Provider): ), ), ) - # logger.debug(f"gemini result: {result}") - if "Developer instruction is not enabled" in str(result): + result_str = str(result) + finish_reason = result.candidates[0].finish_reason + if "Developer instruction is not enabled" in result_str: logger.warning(f"{self.get_model()} 不支持 system prompt,已自动去除。") system_instruction = "" loop = True - # 不支持函数调用的模型SDK似乎会自动去除,保险起见不删除此行判断。 - elif "Function calling is not enabled" in str(result): + continue + elif "Function calling is not enabled" in result_str: logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除。") tool = None loop = True - elif "Multi-modal output is not supported" in str(result): + continue + elif "Multi-modal output is not supported" in result_str: logger.warning(f"{self.get_model()} 不支持多模态输出,降级为文本模态。") modalites = ["Text"] loop = True + continue + elif finish_reason == types.FinishReason.RECITATION: + logger.warning("发生了recitation,正在尝试加温重试...") + temperature += 0.2 + logger.info(f"当前温度: {temperature}") + if temperature < 2: + loop = True + else: + raise Exception("温度已到达(或超过)2") + continue llm_response = LLMResponse("assistant") - chain = [] - - finish_reason = result.candidates[0].finish_reason if finish_reason == types.FinishReason.SAFETY: raise Exception("模型生成内容未通过用户定义的内容安全检查") @@ -255,9 +266,22 @@ class ProviderGoogleGenAI(Provider): raise Exception("模型生成内容违反Gemini平台政策") if not result.candidates[0].content.parts: + logger.debug(result.candidates) raise Exception("API 返回的内容为空。") - for part in result.candidates[0].content.parts: + llm_response.result_chain = self._process_content_parts( + result.candidates[0].content.parts, llm_response + ) + + return llm_response + + def _process_content_parts( + self, parts: types.Part, llm_response: LLMResponse + ) -> MessageChain: + """处理内容部分并构建消息链""" + chain = [] + part: types.Part + for part in parts: if part.text: chain.append(Comp.Plain(part.text)) elif part.function_call: @@ -267,10 +291,7 @@ class ProviderGoogleGenAI(Provider): llm_response.tools_call_ids.append(part.function_call.id) elif part.inline_data and part.inline_data.mime_type.startswith("image/"): chain.append(Comp.Image.fromBytes(part.inline_data.data)) - - llm_response.result_chain = MessageChain(chain=chain) - - return llm_response + return MessageChain(chain=chain) async def text_chat( self, @@ -305,11 +326,12 @@ class ProviderGoogleGenAI(Provider): retry = 10 keys = self.api_keys.copy() chosen_key = random.choice(keys) + temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7 for _ in range(retry): try: self.chosen_api_key = chosen_key - llm_response = await self._query(payloads, func_tool) + llm_response = await self._query(payloads, func_tool, temp) break except Exception as e: if "429" in str(e) or "API key not valid" in str(e): From 2ca95eaa9f65481d27358953540b1aeefc6089f4 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Fri, 11 Apr 2025 14:42:24 +0800 Subject: [PATCH 05/19] =?UTF-8?q?fix:=20=E5=9C=A8=E8=AE=BE=E7=BD=AE?= =?UTF-8?q?=E6=96=B0key=E5=90=8E=E9=87=8D=E6=96=B0=E5=88=9D=E5=A7=8B?= =?UTF-8?q?=E5=8C=96Gemini=E5=AE=A2=E6=88=B7=E7=AB=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/sources/gemini_source.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 89ff69559..cd42398e0 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -325,27 +325,25 @@ class ProviderGoogleGenAI(Provider): retry = 10 keys = self.api_keys.copy() - chosen_key = random.choice(keys) temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7 for _ in range(retry): try: - self.chosen_api_key = chosen_key llm_response = await self._query(payloads, func_tool, temp) break except Exception as e: if "429" in str(e) or "API key not valid" in str(e): - keys.remove(chosen_key) + keys.remove(self.chosen_api_key) if len(keys) > 0: - chosen_key = random.choice(keys) + self.set_key(random.choice(keys)) logger.info( - f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {chosen_key[:12]}..." + f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..." ) await asyncio.sleep(1) continue else: logger.error( - f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {chosen_key[:12]}..." + f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..." ) raise Exception("达到了 Gemini 速率限制, 请稍后再试...") else: @@ -391,6 +389,14 @@ class ProviderGoogleGenAI(Provider): def set_key(self, key): self.chosen_api_key = key + # 重新初始化客户端 + self.client = genai.Client( + api_key=self.chosen_api_key, + http_options=types.HttpOptions( + base_url=self.api_base, + timeout=self.timeout * 1000, # 毫秒 + ), + ).aio async def assemble_context(self, text: str, image_urls: List[str] = None): """ From e8ffebc006402023aa046478e7bc7c380257ba85 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Fri, 11 Apr 2025 15:01:20 +0800 Subject: [PATCH 06/19] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=B6=88?= =?UTF-8?q?=E6=81=AF=E5=A4=84=E7=90=86=E6=B5=81=E7=A8=8B=E4=B8=AD=E5=8F=AF?= =?UTF-8?q?=E8=83=BD=E5=87=BA=E7=8E=B0=E7=9A=84=E7=A9=BA=E6=B6=88=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/entities.py | 2 +- astrbot/core/provider/sources/openai_source.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 99824fd0e..ffa20029e 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -155,7 +155,7 @@ class ProviderRequest: if self.image_urls: user_content = { "role": "user", - "content": [{"type": "text", "text": self.prompt}], + "content": [{"type": "text", "text": self.prompt if self.prompt else "[图片]"}], } for image_url in self.image_urls: if image_url.startswith("http"): diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 8023d18d1..e945dea3e 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -497,7 +497,7 @@ class ProviderOpenAIOfficial(Provider): async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict: """组装成符合 OpenAI 格式的 role 为 user 的消息段""" if image_urls: - user_content = {"role": "user", "content": [{"type": "text", "text": text}]} + user_content = {"role": "user", "content": [{"type": "text", "text": text if text else "[图片]"}]} for image_url in image_urls: if image_url.startswith("http"): image_path = await download_image_by_url(image_url) From 1b3963ebea0a583eff76c6c2871ca624468a9b22 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Fri, 11 Apr 2025 18:07:00 +0800 Subject: [PATCH 07/19] =?UTF-8?q?fix:=20=E6=9B=B4=E6=96=B0=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=E6=8F=90=E7=A4=BA=EF=BC=8C=E7=AE=80=E5=8C=96=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E5=B9=B6=E4=BF=AE=E5=A4=8D=E6=BD=9C=E5=9C=A8=E7=9A=84?= =?UTF-8?q?=E7=A9=BA=E5=80=BC=E9=97=AE=E9=A2=98=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/gemini_source.py | 174 ++++++++---------- 1 file changed, 77 insertions(+), 97 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index cd42398e0..2f865dbbb 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -2,7 +2,7 @@ import asyncio import base64 import json import random -from typing import Dict, List +from typing import Dict, List, Optional from google import genai from google.genai import errors, types @@ -55,8 +55,8 @@ class ProviderGoogleGenAI(Provider): self.api_keys: List = provider_config.get("key", []) self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None self.timeout: int = provider_config.get("timeout", 180) - self.api_base: str = provider_config.get("api_base", None) - if self.api_base.endswith("/"): + self.api_base: Optional[str] = provider_config.get("api_base", None) + if self.api_base and self.api_base.endswith("/"): self.api_base = self.api_base[:-1] if isinstance(self.timeout, str): self.timeout = int(self.timeout) @@ -90,70 +90,56 @@ class ProviderGoogleGenAI(Provider): except errors.APIError as e: raise Exception(f"获取模型列表失败: {e}") - def _prepare_conversation( - self, - payloads: Dict, - ) -> List[types.Content]: + @staticmethod + def _prepare_conversation(payloads: Dict) -> List[types.Content]: """准备 Gemini SDK 的 Content 列表""" - gemini_contents = [] + + def create_text_part(text: str) -> types.UserContent: + content_a = text if text else " " + if not text: + logger.warning("文本内容为空,已添加空格占位") + return types.UserContent(parts=[types.Part.from_text(text=content_a)]) + + def process_image_url(image_url_dict: dict) -> types.Part: + url = image_url_dict["url"] + mime_type = url.split(":")[1].split(";")[0] + image_bytes = base64.b64decode(url.split(",", 1)[1]) + return types.Part.from_bytes(data=image_bytes, mime_type=mime_type) + + gemini_contents: List[types.Content] = [] for message in payloads["messages"]: - role = message["role"] - content = message.get("content") + role, content = message["role"], message.get("content") if role == "user": if isinstance(content, str): - if content: - gemini_contents.append( - types.UserContent( - parts=[types.Part.from_text(text=content)] - ) - ) - else: - logger.warning("文本内容为空,已添加空格占位") - gemini_contents.append( - types.UserContent(parts=[types.Part.from_text(text=" ")]) - ) - + gemini_contents.append(create_text_part(content)) elif isinstance(content, list): - parts = [] - for item in content: - if item.get("type") == "text": - text_content = item.get("text") - if text_content: - parts.append(types.Part.from_text(text=text_content)) - else: - logger.warning("文本内容为空,已添加空格占位") - parts.append(types.Part.from_text(text=" ")) - elif item.get("type") == "image_url": - image_url_dict = item["image_url"] - url = image_url_dict["url"] - mime_part, base64_data = url.split(",", 1) - mime_type = mime_part.split(":")[1].split(";")[0] - image_bytes = base64.b64decode(base64_data) - parts.append( - types.Part.from_bytes( - data=image_bytes, mime_type=mime_type - ) - ) + parts = [ + types.Part.from_text(text=item["text"] or " ") + if item["type"] == "text" + else process_image_url(item["image_url"]) + for item in content + ] gemini_contents.append(types.UserContent(parts=parts)) elif role == "assistant": if content: gemini_contents.append( - types.ModelContent( - parts=[types.Part.from_text(text=message["content"])] - ) + types.ModelContent(parts=[types.Part.from_text(text=content)]) ) elif "tool_calls" in message: - parts = [ - { - "name": tool_call["function"]["name"], - "args": json.loads(tool_call["function"]["arguments"]), - } - for tool_call in message["tool_calls"] - ] - gemini_contents.append( - types.ModelContent(parts=[types.Part.from_function_call(parts)]) + gemini_contents.extend( + [ + types.ModelContent( + parts=[ + types.Part.from_function_call( + name=tool["function"]["name"], + args=json.loads(tool["function"]["arguments"]), + ) + ] + ) + for tool in message["tool_calls"] + ] ) else: logger.warning("assistant 角色的消息内容为空,已添加空格占位") @@ -166,32 +152,26 @@ class ProviderGoogleGenAI(Provider): types.UserContent( parts=[ types.Part.from_function_response( - { + name=message["tool_call_id"], + response={ "name": message["tool_call_id"], - "response": { - "name": message["tool_call_id"], - "content": message["content"], - }, - } + "content": message["content"], + }, ) ] ) ) - # logger.debug(f"gemini_contents: {gemini_contents}") - return gemini_contents async def _query( self, payloads: dict, tools: FuncCall, temperature: float = 0.7 ) -> LLMResponse: """非流式请求 Gemini API""" - if tools: - t = tools.get_func_desc_google_genai_style() - tool = ( - types.Tool(function_declarations=t["function_declarations"]) - if t - else None + tool_list = [] + if func_desc := tools.get_func_desc_google_genai_style() if tools else None: + tool_list.append( + types.Tool(function_declarations=func_desc["function_declarations"]) ) system_instruction = "" @@ -202,27 +182,28 @@ class ProviderGoogleGenAI(Provider): conversation = self._prepare_conversation(payloads) - modalites = ["Text"] + modalities = ["Text"] if self.provider_config.get("gm_resp_image_modal", False): - modalites.append("Image") + modalities.append("Image") - loop = True - while loop: - loop = False - result = await self.client.models.generate_content( - model=self.get_model(), - contents=conversation, - config=types.GenerateContentConfig( - system_instruction=system_instruction, - temperature=temperature, - tools=[tool] if tool else None, - safety_settings=self.safety_settings - if self.safety_settings - else None, - automatic_function_calling=types.AutomaticFunctionCallingConfig( - disable=True + while True: + result: types.GenerateContentResponse = ( + await self.client.models.generate_content( + model=self.get_model(), + contents=conversation, + config=types.GenerateContentConfig( + system_instruction=system_instruction, + temperature=temperature, + response_modalities=modalities, + tools=tool_list, + safety_settings=self.safety_settings + if self.safety_settings + else None, + automatic_function_calling=types.AutomaticFunctionCallingConfig( + disable=True + ), ), - ), + ) ) result_str = str(result) @@ -230,29 +211,27 @@ class ProviderGoogleGenAI(Provider): if "Developer instruction is not enabled" in result_str: logger.warning(f"{self.get_model()} 不支持 system prompt,已自动去除。") system_instruction = "" - loop = True continue elif "Function calling is not enabled" in result_str: logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除。") - tool = None - loop = True + tool_list = None continue elif "Multi-modal output is not supported" in result_str: logger.warning(f"{self.get_model()} 不支持多模态输出,降级为文本模态。") - modalites = ["Text"] - loop = True + modalities = ["Text"] continue elif finish_reason == types.FinishReason.RECITATION: logger.warning("发生了recitation,正在尝试加温重试...") temperature += 0.2 logger.info(f"当前温度: {temperature}") - if temperature < 2: - loop = True - else: + if temperature > 2: raise Exception("温度已到达(或超过)2") continue + break llm_response = LLMResponse("assistant") + result_parts: Optional[types.Part] = result.candidates[0].content.parts + finish_reason = result.candidates[0].finish_reason if finish_reason == types.FinishReason.SAFETY: raise Exception("模型生成内容未通过用户定义的内容安全检查") @@ -265,18 +244,19 @@ class ProviderGoogleGenAI(Provider): }: raise Exception("模型生成内容违反Gemini平台政策") - if not result.candidates[0].content.parts: + if not result_parts: logger.debug(result.candidates) raise Exception("API 返回的内容为空。") llm_response.result_chain = self._process_content_parts( - result.candidates[0].content.parts, llm_response + result_parts, llm_response ) return llm_response + @staticmethod def _process_content_parts( - self, parts: types.Part, llm_response: LLMResponse + parts: types.Part, llm_response: LLMResponse ) -> MessageChain: """处理内容部分并构建消息链""" chain = [] From c5e8bc7e20969c0729c92eadf17a0d4b855908c6 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Fri, 11 Apr 2025 18:55:46 +0800 Subject: [PATCH 08/19] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E7=94=9F=E6=88=90=E5=86=85=E5=AE=B9=E7=9A=84=E9=87=8D?= =?UTF-8?q?=E8=AF=95=E6=9C=BA=E5=88=B6=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/gemini_source.py | 59 +++++++++++-------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 2f865dbbb..50f4018ee 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -5,7 +5,7 @@ import random from typing import Dict, List, Optional from google import genai -from google.genai import errors, types +from google.genai import types import astrbot.core.message.components as Comp from astrbot import logger @@ -87,7 +87,7 @@ class ProviderGoogleGenAI(Provider): for m in models if "generateContent" in m.supported_actions ] - except errors.APIError as e: + except Exception as e: raise Exception(f"获取模型列表失败: {e}") @staticmethod @@ -186,9 +186,10 @@ class ProviderGoogleGenAI(Provider): if self.provider_config.get("gm_resp_image_modal", False): modalities.append("Image") + result: Optional[types.GenerateContentResponse] = None while True: - result: types.GenerateContentResponse = ( - await self.client.models.generate_content( + try: + result = await self.client.models.generate_content( model=self.get_model(), contents=conversation, config=types.GenerateContentConfig( @@ -204,30 +205,36 @@ class ProviderGoogleGenAI(Provider): ), ), ) - ) - result_str = str(result) - finish_reason = result.candidates[0].finish_reason - if "Developer instruction is not enabled" in result_str: - logger.warning(f"{self.get_model()} 不支持 system prompt,已自动去除。") - system_instruction = "" + if result.candidates[0].finish_reason == types.FinishReason.RECITATION: + if temperature > 2: + raise Exception("温度参数已超过最大值2,仍然发生recitation") + temperature += 0.2 + logger.warning( + f"发生了recitation,正在提高温度至{temperature:.1f}重试..." + ) + continue + + break + + except Exception as e: + error_msg = str(e) + if "Developer instruction is not enabled" in error_msg: + logger.warning( + f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)" + ) + system_instruction = None + elif "Function calling is not enabled" in error_msg: + logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除") + tool_list = None + elif "Multi-modal output is not supported" in error_msg: + logger.warning( + f"{self.get_model()} 不支持多模态输出,降级为文本模态" + ) + modalities = ["Text"] + else: + raise continue - elif "Function calling is not enabled" in result_str: - logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除。") - tool_list = None - continue - elif "Multi-modal output is not supported" in result_str: - logger.warning(f"{self.get_model()} 不支持多模态输出,降级为文本模态。") - modalities = ["Text"] - continue - elif finish_reason == types.FinishReason.RECITATION: - logger.warning("发生了recitation,正在尝试加温重试...") - temperature += 0.2 - logger.info(f"当前温度: {temperature}") - if temperature > 2: - raise Exception("温度已到达(或超过)2") - continue - break llm_response = LLMResponse("assistant") result_parts: Optional[types.Part] = result.candidates[0].content.parts From b493a808fe2cd59741dcb6a3d0f4346ff2a1c762 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Fri, 11 Apr 2025 20:25:20 +0800 Subject: [PATCH 09/19] =?UTF-8?q?fix:=20=E5=A4=84=E7=90=86=E6=9B=B4?= =?UTF-8?q?=E5=A4=9A=E5=A4=9A=E6=A8=A1=E6=80=81=E4=B8=8D=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/sources/gemini_source.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 50f4018ee..1447b9e2c 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -227,7 +227,11 @@ class ProviderGoogleGenAI(Provider): elif "Function calling is not enabled" in error_msg: logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除") tool_list = None - elif "Multi-modal output is not supported" in error_msg: + elif ( + "Multi-modal output is not supported" + or "Model does not support the requested response modalities" + in error_msg + ): logger.warning( f"{self.get_model()} 不支持多模态输出,降级为文本模态" ) From bd24cf3ea40b3943574f3a787ade3cb125bdbff3 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Fri, 11 Apr 2025 23:45:30 +0800 Subject: [PATCH 10/19] =?UTF-8?q?feat:=20=E5=88=9D=E6=AD=A5=E5=AE=8C?= =?UTF-8?q?=E6=88=90=E5=8E=9F=E7=94=9F=E6=B5=81=E5=BC=8F=E8=AF=B7=E6=B1=82?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/gemini_source.py | 221 +++++++++++++----- 1 file changed, 165 insertions(+), 56 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 1447b9e2c..093955c52 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -2,10 +2,11 @@ import asyncio import base64 import json import random -from typing import Dict, List, Optional +from typing import Dict, List, Optional, AsyncGenerator from google import genai from google.genai import types +from google.genai.errors import APIError import astrbot.core.message.components as Comp from astrbot import logger @@ -87,8 +88,8 @@ class ProviderGoogleGenAI(Provider): for m in models if "generateContent" in m.supported_actions ] - except Exception as e: - raise Exception(f"获取模型列表失败: {e}") + except APIError as e: + raise Exception(f"获取模型列表失败: {e.message}") @staticmethod def _prepare_conversation(payloads: Dict) -> List[types.Content]: @@ -168,17 +169,18 @@ class ProviderGoogleGenAI(Provider): self, payloads: dict, tools: FuncCall, temperature: float = 0.7 ) -> LLMResponse: """非流式请求 Gemini API""" - tool_list = [] - if func_desc := tools.get_func_desc_google_genai_style() if tools else None: - tool_list.append( - types.Tool(function_declarations=func_desc["function_declarations"]) - ) + tool_list = None + if tools: + func_desc = tools.get_func_desc_google_genai_style() + if func_desc: + tool_list = [ + types.Tool(function_declarations=func_desc["function_declarations"]) + ] - system_instruction = "" - for message in payloads["messages"]: - if message["role"] == "system": - system_instruction = message["content"] - break + system_instruction = next( + (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"), + None, + ) conversation = self._prepare_conversation(payloads) @@ -217,20 +219,19 @@ class ProviderGoogleGenAI(Provider): break - except Exception as e: - error_msg = str(e) - if "Developer instruction is not enabled" in error_msg: + except APIError as e: + if "Developer instruction is not enabled" in e.message: logger.warning( f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)" ) system_instruction = None - elif "Function calling is not enabled" in error_msg: + elif "Function calling is not enabled" in e.message: logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除") tool_list = None elif ( "Multi-modal output is not supported" or "Model does not support the requested response modalities" - in error_msg + in e.message ): logger.warning( f"{self.get_model()} 不支持多模态输出,降级为文本模态" @@ -241,8 +242,95 @@ class ProviderGoogleGenAI(Provider): continue llm_response = LLMResponse("assistant") - result_parts: Optional[types.Part] = result.candidates[0].content.parts + llm_response.result_chain = self._process_content_parts(result, llm_response) + return llm_response + + async def _query_stream( + self, payloads: dict, tools: FuncCall, temperature: float = 0.7 + ) -> AsyncGenerator[LLMResponse, None]: + """流式请求 Gemini API""" + tool_list = None + if tools: + func_desc = tools.get_func_desc_google_genai_style() + if func_desc: + tool_list = [ + types.Tool(function_declarations=func_desc["function_declarations"]) + ] + + system_instruction = next( + (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"), + None, + ) + + conversation = self._prepare_conversation(payloads) + + result = None + while True: + try: + result = await self.client.models.generate_content_stream( + model=self.get_model(), + contents=conversation, + config=types.GenerateContentConfig( + system_instruction=system_instruction, + temperature=temperature, + tools=tool_list, + safety_settings=self.safety_settings + if self.safety_settings + else None, + automatic_function_calling=types.AutomaticFunctionCallingConfig( + disable=True + ), + ), + ) + + break + + except APIError as e: + if "Developer instruction is not enabled" in e.message: + logger.warning( + f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)" + ) + system_instruction = None + elif "Function calling is not enabled" in e.message: + logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除") + tool_list = None + else: + raise + continue + + if not result: + raise Exception("API 返回异常") + + async for chunk in result: + llm_response = LLMResponse("assistant", is_chunk=True) + + if chunk.candidates[0].content.parts and any( + part.function_call for part in chunk.candidates[0].content.parts + ): + response = LLMResponse("assistant", is_chunk=False) + response.result_chain = self._process_content_parts(chunk, response) + yield response + break + + if chunk.text: + llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)]) + yield llm_response + + if chunk.candidates[0].finish_reason: + llm_response = LLMResponse("assistant", is_chunk=False) + llm_response.result_chain = self._process_content_parts( + chunk, llm_response + ) + yield llm_response + break + + @staticmethod + def _process_content_parts( + result: types.GenerateContentResponse, llm_response: LLMResponse + ) -> MessageChain: + """处理内容部分并构建消息链""" finish_reason = result.candidates[0].finish_reason + result_parts: Optional[types.Part] = result.candidates[0].content.parts if finish_reason == types.FinishReason.SAFETY: raise Exception("模型生成内容未通过用户定义的内容安全检查") @@ -259,20 +347,9 @@ class ProviderGoogleGenAI(Provider): logger.debug(result.candidates) raise Exception("API 返回的内容为空。") - llm_response.result_chain = self._process_content_parts( - result_parts, llm_response - ) - - return llm_response - - @staticmethod - def _process_content_parts( - parts: types.Part, llm_response: LLMResponse - ) -> MessageChain: - """处理内容部分并构建消息链""" chain = [] part: types.Part - for part in parts: + for part in result_parts: if part.text: chain.append(Comp.Plain(part.text)) elif part.function_call: @@ -322,19 +399,19 @@ class ProviderGoogleGenAI(Provider): try: llm_response = await self._query(payloads, func_tool, temp) break - except Exception as e: - if "429" in str(e) or "API key not valid" in str(e): + except APIError as e: + if e.code == 429 or "API key not valid" in e.message: keys.remove(self.chosen_api_key) if len(keys) > 0: self.set_key(random.choice(keys)) logger.info( - f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..." + f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..." ) await asyncio.sleep(1) continue else: logger.error( - f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..." + f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..." ) raise Exception("达到了 Gemini 速率限制, 请稍后再试...") else: @@ -347,30 +424,62 @@ class ProviderGoogleGenAI(Provider): async def text_chat_stream( self, - prompt, - session_id=None, - image_urls=..., - func_tool=None, - contexts=..., + prompt: str, + session_id: str = None, + image_urls: List[str] = [], + func_tool: FuncCall = None, + contexts=[], system_prompt=None, tool_calls_result=None, **kwargs, - ): - # raise NotImplementedError("This method is not implemented yet.") - # 调用 text_chat 模拟流式 - llm_response = await self.text_chat( - prompt=prompt, - session_id=session_id, - image_urls=image_urls, - func_tool=func_tool, - contexts=contexts, - system_prompt=system_prompt, - tool_calls_result=tool_calls_result, - ) - llm_response.is_chunk = True - yield llm_response - llm_response.is_chunk = False - yield llm_response + ) -> AsyncGenerator[LLMResponse, None]: + new_record = await self.assemble_context(prompt, image_urls) + context_query = [*contexts, new_record] + if system_prompt: + context_query.insert(0, {"role": "system", "content": system_prompt}) + + for part in context_query: + if "_no_save" in part: + del part["_no_save"] + + # tool calls result + if tool_calls_result: + context_query.extend(tool_calls_result.to_openai_messages()) + + model_config = self.provider_config.get("model_config", {}) + model_config["model"] = self.get_model() + + payloads = {"messages": context_query, **model_config} + + retry = 10 + keys = self.api_keys.copy() + temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7 + + for _ in range(retry): + try: + async for response in self._query_stream(payloads, func_tool, temp): + yield response + break + except APIError as e: + if e.code == 429 or "API key not valid" in e.message: + keys.remove(self.chosen_api_key) + if len(keys) > 0: + self.set_key(random.choice(keys)) + logger.info( + f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..." + ) + await asyncio.sleep(1) + continue + else: + logger.error( + f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..." + ) + raise Exception("达到了 Gemini 速率限制, 请稍后再试...") + else: + logger.error( + f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}" + ) + raise e def get_current_key(self) -> str: return self.chosen_api_key From 44dbe475afdba26bd2ace2ccbb6f02556deddd75 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Sat, 12 Apr 2025 00:23:57 +0800 Subject: [PATCH 11/19] =?UTF-8?q?refactor:=20=E6=8B=86=E5=88=86=E6=96=B9?= =?UTF-8?q?=E6=B3=95=E4=BB=A5=E6=8F=90=E9=AB=98=E4=BB=A3=E7=A0=81=E5=8F=AF?= =?UTF-8?q?=E8=AF=BB=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/gemini_source.py | 430 +++++++++--------- 1 file changed, 206 insertions(+), 224 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 093955c52..a6de908ce 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -55,12 +55,18 @@ class ProviderGoogleGenAI(Provider): ) self.api_keys: List = provider_config.get("key", []) self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None - self.timeout: int = provider_config.get("timeout", 180) + self.timeout: int = int(provider_config.get("timeout", 180)) + self.api_base: Optional[str] = provider_config.get("api_base", None) if self.api_base and self.api_base.endswith("/"): self.api_base = self.api_base[:-1] - if isinstance(self.timeout, str): - self.timeout = int(self.timeout) + + self._init_client() + self.set_model(provider_config["model_config"]["model"]) + self._init_safety_settings() + + def _init_client(self) -> None: + """初始化Gemini客户端""" self.client = genai.Client( api_key=self.chosen_api_key, http_options=types.HttpOptions( @@ -68,8 +74,9 @@ class ProviderGoogleGenAI(Provider): timeout=self.timeout * 1000, # 毫秒 ), ).aio - self.set_model(provider_config["model_config"]["model"]) + def _init_safety_settings(self) -> None: + """初始化安全设置""" user_safety_config = self.provider_config.get("gm_safety_settings", {}) self.safety_settings = [ types.SafetySetting( @@ -80,16 +87,59 @@ class ProviderGoogleGenAI(Provider): and threshold_str in self.THRESHOLD_MAPPING ] - async def get_models(self): - try: - models = await self.client.models.list() - return [ - m.name.replace("models/", "") - for m in models - if "generateContent" in m.supported_actions - ] - except APIError as e: - raise Exception(f"获取模型列表失败: {e.message}") + async def _handle_api_error(self, e: APIError, keys: List[str]) -> bool: + """处理API错误,返回是否需要重试""" + if e.code == 429 or "API key not valid" in e.message: + keys.remove(self.chosen_api_key) + if len(keys) > 0: + self.set_key(random.choice(keys)) + logger.info( + f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..." + ) + await asyncio.sleep(1) + return True + else: + logger.error( + f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..." + ) + raise Exception("达到了 Gemini 速率限制, 请稍后再试...") + else: + logger.error( + f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}" + ) + raise e + + async def _prepare_query_config( + self, + tools: Optional[FuncCall] = None, + system_instruction: Optional[str] = None, + temperature: Optional[float] = 0.7, + modalities: Optional[List[str]] = None, + ) -> types.GenerateContentConfig: + """准备查询配置""" + if not modalities: + modalities = ["Text"] + if self.provider_config.get("gm_resp_image_modal", False): + modalities.append("Image") + + tool_list = None + if tools: + func_desc = tools.get_func_desc_google_genai_style() + if func_desc: + tool_list = [ + types.Tool(function_declarations=func_desc["function_declarations"]) + ] + + return types.GenerateContentConfig( + system_instruction=system_instruction, + temperature=temperature, + response_modalities=modalities, + tools=tool_list, + safety_settings=self.safety_settings if self.safety_settings else None, + automatic_function_calling=types.AutomaticFunctionCallingConfig( + disable=True + ), + ) @staticmethod def _prepare_conversation(payloads: Dict) -> List[types.Content]: @@ -165,165 +215,6 @@ class ProviderGoogleGenAI(Provider): return gemini_contents - async def _query( - self, payloads: dict, tools: FuncCall, temperature: float = 0.7 - ) -> LLMResponse: - """非流式请求 Gemini API""" - tool_list = None - if tools: - func_desc = tools.get_func_desc_google_genai_style() - if func_desc: - tool_list = [ - types.Tool(function_declarations=func_desc["function_declarations"]) - ] - - system_instruction = next( - (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"), - None, - ) - - conversation = self._prepare_conversation(payloads) - - modalities = ["Text"] - if self.provider_config.get("gm_resp_image_modal", False): - modalities.append("Image") - - result: Optional[types.GenerateContentResponse] = None - while True: - try: - result = await self.client.models.generate_content( - model=self.get_model(), - contents=conversation, - config=types.GenerateContentConfig( - system_instruction=system_instruction, - temperature=temperature, - response_modalities=modalities, - tools=tool_list, - safety_settings=self.safety_settings - if self.safety_settings - else None, - automatic_function_calling=types.AutomaticFunctionCallingConfig( - disable=True - ), - ), - ) - - if result.candidates[0].finish_reason == types.FinishReason.RECITATION: - if temperature > 2: - raise Exception("温度参数已超过最大值2,仍然发生recitation") - temperature += 0.2 - logger.warning( - f"发生了recitation,正在提高温度至{temperature:.1f}重试..." - ) - continue - - break - - except APIError as e: - if "Developer instruction is not enabled" in e.message: - logger.warning( - f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)" - ) - system_instruction = None - elif "Function calling is not enabled" in e.message: - logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除") - tool_list = None - elif ( - "Multi-modal output is not supported" - or "Model does not support the requested response modalities" - in e.message - ): - logger.warning( - f"{self.get_model()} 不支持多模态输出,降级为文本模态" - ) - modalities = ["Text"] - else: - raise - continue - - llm_response = LLMResponse("assistant") - llm_response.result_chain = self._process_content_parts(result, llm_response) - return llm_response - - async def _query_stream( - self, payloads: dict, tools: FuncCall, temperature: float = 0.7 - ) -> AsyncGenerator[LLMResponse, None]: - """流式请求 Gemini API""" - tool_list = None - if tools: - func_desc = tools.get_func_desc_google_genai_style() - if func_desc: - tool_list = [ - types.Tool(function_declarations=func_desc["function_declarations"]) - ] - - system_instruction = next( - (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"), - None, - ) - - conversation = self._prepare_conversation(payloads) - - result = None - while True: - try: - result = await self.client.models.generate_content_stream( - model=self.get_model(), - contents=conversation, - config=types.GenerateContentConfig( - system_instruction=system_instruction, - temperature=temperature, - tools=tool_list, - safety_settings=self.safety_settings - if self.safety_settings - else None, - automatic_function_calling=types.AutomaticFunctionCallingConfig( - disable=True - ), - ), - ) - - break - - except APIError as e: - if "Developer instruction is not enabled" in e.message: - logger.warning( - f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)" - ) - system_instruction = None - elif "Function calling is not enabled" in e.message: - logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除") - tool_list = None - else: - raise - continue - - if not result: - raise Exception("API 返回异常") - - async for chunk in result: - llm_response = LLMResponse("assistant", is_chunk=True) - - if chunk.candidates[0].content.parts and any( - part.function_call for part in chunk.candidates[0].content.parts - ): - response = LLMResponse("assistant", is_chunk=False) - response.result_chain = self._process_content_parts(chunk, response) - yield response - break - - if chunk.text: - llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)]) - yield llm_response - - if chunk.candidates[0].finish_reason: - llm_response = LLMResponse("assistant", is_chunk=False) - llm_response.result_chain = self._process_content_parts( - chunk, llm_response - ) - yield llm_response - break - @staticmethod def _process_content_parts( result: types.GenerateContentResponse, llm_response: LLMResponse @@ -361,6 +252,129 @@ class ProviderGoogleGenAI(Provider): chain.append(Comp.Image.fromBytes(part.inline_data.data)) return MessageChain(chain=chain) + async def _query( + self, payloads: dict, tools: FuncCall, temperature: float = 0.7 + ) -> LLMResponse: + """非流式请求 Gemini API""" + system_instruction = next( + (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"), + None, + ) + + modalities = ["Text"] + if self.provider_config.get("gm_resp_image_modal", False): + modalities.append("Image") + + conversation = self._prepare_conversation(payloads) + + result: Optional[types.GenerateContentResponse] = None + while True: + try: + config = await self._prepare_query_config( + tools, system_instruction, temperature, modalities + ) + result = await self.client.models.generate_content( + model=self.get_model(), + contents=conversation, + config=config, + ) + + if result.candidates[0].finish_reason == types.FinishReason.RECITATION: + if temperature > 2: + raise Exception("温度参数已超过最大值2,仍然发生recitation") + temperature += 0.2 + logger.warning( + f"发生了recitation,正在提高温度至{temperature:.1f}重试..." + ) + continue + + break + + except APIError as e: + if "Developer instruction is not enabled" in e.message: + logger.warning( + f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)" + ) + system_instruction = None + elif "Function calling is not enabled" in e.message: + logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除") + tools = None + elif ( + "Multi-modal output is not supported" + or "Model does not support the requested response modalities" + in e.message + ): + logger.warning( + f"{self.get_model()} 不支持多模态输出,降级为文本模态" + ) + modalities = ["Text"] + else: + raise + continue + + llm_response = LLMResponse("assistant") + llm_response.result_chain = self._process_content_parts(result, llm_response) + return llm_response + + async def _query_stream( + self, payloads: dict, tools: FuncCall, temperature: float = 0.7 + ) -> AsyncGenerator[LLMResponse, None]: + """流式请求 Gemini API""" + system_instruction = next( + (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"), + None, + ) + + conversation = self._prepare_conversation(payloads) + + result = None + while True: + try: + config = await self._prepare_query_config( + tools, system_instruction, temperature + ) + result = await self.client.models.generate_content_stream( + model=self.get_model(), + contents=conversation, + config=config, + ) + break + except APIError as e: + if "Developer instruction is not enabled" in e.message: + logger.warning( + f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)" + ) + system_instruction = None + elif "Function calling is not enabled" in e.message: + logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除") + tools = None + else: + raise + continue + + async for chunk in result: + llm_response = LLMResponse("assistant", is_chunk=True) + + if chunk.candidates[0].content.parts and any( + part.function_call for part in chunk.candidates[0].content.parts + ): + response = LLMResponse("assistant", is_chunk=False) + response.result_chain = self._process_content_parts(chunk, response) + yield response + break + + if chunk.text: + llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)]) + yield llm_response + + if chunk.candidates[0].finish_reason: + llm_response = LLMResponse("assistant", is_chunk=False) + llm_response.result_chain = self._process_content_parts( + chunk, llm_response + ) + yield llm_response + break + async def text_chat( self, prompt: str, @@ -389,7 +403,6 @@ class ProviderGoogleGenAI(Provider): model_config["model"] = self.get_model() payloads = {"messages": context_query, **model_config} - llm_response = None retry = 10 keys = self.api_keys.copy() @@ -397,30 +410,11 @@ class ProviderGoogleGenAI(Provider): for _ in range(retry): try: - llm_response = await self._query(payloads, func_tool, temp) - break + return await self._query(payloads, func_tool, temp) except APIError as e: - if e.code == 429 or "API key not valid" in e.message: - keys.remove(self.chosen_api_key) - if len(keys) > 0: - self.set_key(random.choice(keys)) - logger.info( - f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..." - ) - await asyncio.sleep(1) - continue - else: - logger.error( - f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..." - ) - raise Exception("达到了 Gemini 速率限制, 请稍后再试...") - else: - logger.error( - f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}" - ) - raise e - - return llm_response + if await self._handle_api_error(e, keys): + continue + break async def text_chat_stream( self, @@ -461,25 +455,20 @@ class ProviderGoogleGenAI(Provider): yield response break except APIError as e: - if e.code == 429 or "API key not valid" in e.message: - keys.remove(self.chosen_api_key) - if len(keys) > 0: - self.set_key(random.choice(keys)) - logger.info( - f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..." - ) - await asyncio.sleep(1) - continue - else: - logger.error( - f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..." - ) - raise Exception("达到了 Gemini 速率限制, 请稍后再试...") - else: - logger.error( - f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}" - ) - raise e + if await self._handle_api_error(e, keys): + continue + break + + async def get_models(self): + try: + models = await self.client.models.list() + return [ + m.name.replace("models/", "") + for m in models + if "generateContent" in m.supported_actions + ] + except APIError as e: + raise Exception(f"获取模型列表失败: {e.message}") def get_current_key(self) -> str: return self.chosen_api_key @@ -489,14 +478,7 @@ class ProviderGoogleGenAI(Provider): def set_key(self, key): self.chosen_api_key = key - # 重新初始化客户端 - self.client = genai.Client( - api_key=self.chosen_api_key, - http_options=types.HttpOptions( - base_url=self.api_base, - timeout=self.timeout * 1000, # 毫秒 - ), - ).aio + self._init_client() async def assemble_context(self, text: str, image_urls: List[str] = None): """ From 3860634fd21d85b4fe08d41d507833250c97f41a Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Sat, 12 Apr 2025 19:15:39 +0800 Subject: [PATCH 12/19] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E4=BA=86?= =?UTF-8?q?=E5=A4=9A=E6=A8=A1=E6=80=81=E8=BE=93=E5=87=BA=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=88=A4=E6=96=AD=E9=97=AE=E9=A2=98=E5=B9=B6=E5=AF=B9=E5=8F=AA?= =?UTF-8?q?=E8=BE=93=E5=87=BA=E5=9B=BE=E7=89=87=E7=9A=84=E6=83=85=E5=86=B5?= =?UTF-8?q?=E8=BF=9B=E8=A1=8C=E5=A4=84=E7=90=86=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/gemini_source.py | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index a6de908ce..cc026bb0b 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -240,6 +240,13 @@ class ProviderGoogleGenAI(Provider): chain = [] part: types.Part + + # 暂时这样Fallback + if all( + part.inline_data and part.inline_data.mime_type.startswith("image/") + for part in result_parts + ): + chain.append(Comp.Plain("这是图片")) for part in result_parts: if part.text: chain.append(Comp.Plain(part.text)) @@ -300,7 +307,7 @@ class ProviderGoogleGenAI(Provider): logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除") tools = None elif ( - "Multi-modal output is not supported" + "Multi-modal output is not supported" in e.message or "Model does not support the requested response modalities" in e.message ): @@ -358,9 +365,21 @@ class ProviderGoogleGenAI(Provider): if chunk.candidates[0].content.parts and any( part.function_call for part in chunk.candidates[0].content.parts ): - response = LLMResponse("assistant", is_chunk=False) - response.result_chain = self._process_content_parts(chunk, response) - yield response + llm_response = LLMResponse("assistant", is_chunk=False) + llm_response.result_chain = self._process_content_parts( + chunk, llm_response + ) + yield llm_response + break + + if chunk.candidates[0].content.parts and any( + part.inline_data for part in chunk.candidates[0].content.parts + ): + llm_response = LLMResponse("assistant", is_chunk=False) + llm_response.result_chain = self._process_content_parts( + chunk, llm_response + ) + yield llm_response break if chunk.text: From 9c29df47bb60d339d3e739ef5c06a4e8216197d7 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Sun, 13 Apr 2025 01:09:42 +0800 Subject: [PATCH 13/19] =?UTF-8?q?fix:=20=E6=9B=B4=E6=96=B0=E6=B5=81?= =?UTF-8?q?=E5=BC=8F=E8=BE=93=E5=87=BA=E9=80=BB=E8=BE=91=EF=BC=8C=E7=A6=81?= =?UTF-8?q?=E7=94=A8=E5=9B=BE=E7=89=87=E6=A8=A1=E6=80=81=E5=B9=B6=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E6=97=A5=E5=BF=97=E8=AD=A6=E5=91=8A=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/sources/gemini_source.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index cc026bb0b..580363069 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -119,8 +119,11 @@ class ProviderGoogleGenAI(Provider): """准备查询配置""" if not modalities: modalities = ["Text"] - if self.provider_config.get("gm_resp_image_modal", False): - modalities.append("Image") + + # 流式输出不支持图片模态 + if self.provider_settings.get("streaming_response", False): + logger.warning("流式输出不支持图片模态,已自动降级为文本模态") + modalities = ["Text"] tool_list = None if tools: @@ -372,16 +375,6 @@ class ProviderGoogleGenAI(Provider): yield llm_response break - if chunk.candidates[0].content.parts and any( - part.inline_data for part in chunk.candidates[0].content.parts - ): - llm_response = LLMResponse("assistant", is_chunk=False) - llm_response.result_chain = self._process_content_parts( - chunk, llm_response - ) - yield llm_response - break - if chunk.text: llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)]) yield llm_response From 739f09059e587317b5b6fa7960a6765daf78174a Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Sun, 13 Apr 2025 12:43:25 +0800 Subject: [PATCH 14/19] =?UTF-8?q?feat:=20=E4=B8=BAGemini=E5=8E=9F=E7=94=9F?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E6=89=A7=E8=A1=8C=E5=99=A8=E6=8F=90=E4=BE=9B?= =?UTF-8?q?=E6=9C=89=E9=99=90=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 7 +++++++ astrbot/core/provider/sources/gemini_source.py | 18 ++++++++++++------ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 459ff622a..e31eb5570 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -528,6 +528,7 @@ CONFIG_METADATA_2 = { "model": "gemini-2.0-flash-exp", }, "gm_resp_image_modal": False, + "gm_native_coderunner": False, "gm_safety_settings": { "harassment": "BLOCK_MEDIUM_AND_ABOVE", "hate_speech": "BLOCK_MEDIUM_AND_ABOVE", @@ -704,6 +705,12 @@ CONFIG_METADATA_2 = { "type": "bool", "hint": "启用后,将支持返回图片内容。需要模型支持,否则会报错。具体支持模型请查看 Google Gemini 官方网站。温馨提示,如果您需要生成图片,请关闭 `启用群员识别` 配置获得更好的效果。", }, + "gm_native_coderunner": { + "description": "启用原生代码执行器", + "type": "bool", + "hint": "启用后所有函数工具将全部失效", + "obvious_hint": True, + }, "gm_safety_settings": { "description": "安全过滤器", "type": "object", diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 580363069..905289987 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -126,12 +126,14 @@ class ProviderGoogleGenAI(Provider): modalities = ["Text"] tool_list = None - if tools: - func_desc = tools.get_func_desc_google_genai_style() - if func_desc: - tool_list = [ - types.Tool(function_declarations=func_desc["function_declarations"]) - ] + if self.provider_config.get("gm_native_coderunner", False): + if tools: + logger.warning("Gemini原生代码执行器已启用,函数工具将被忽略") + tool_list = [types.Tool(code_execution=types.ToolCodeExecution())] + elif tools and (func_desc := tools.get_func_desc_google_genai_style()): + tool_list = [ + types.Tool(function_declarations=func_desc["function_declarations"]) + ] return types.GenerateContentConfig( system_instruction=system_instruction, @@ -252,6 +254,10 @@ class ProviderGoogleGenAI(Provider): chain.append(Comp.Plain("这是图片")) for part in result_parts: if part.text: + if part.executable_code: + part.executable_code = None + if part.code_execution_result: + part.code_execution_result = None chain.append(Comp.Plain(part.text)) elif part.function_call: llm_response.role = "tool" From 310ed76b186acaa558db30703e0bce91e18d0a72 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Sun, 13 Apr 2025 17:28:34 +0800 Subject: [PATCH 15/19] =?UTF-8?q?fix:=20=E4=BB=85=E5=9C=A8=E7=A1=AE?= =?UTF-8?q?=E5=AE=9E=E5=8C=85=E5=90=AB=E5=9B=BE=E7=89=87=E6=A8=A1=E6=80=81?= =?UTF-8?q?=E6=97=B6=E9=99=8D=E7=BA=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/sources/gemini_source.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 905289987..ab312e502 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -121,7 +121,10 @@ class ProviderGoogleGenAI(Provider): modalities = ["Text"] # 流式输出不支持图片模态 - if self.provider_settings.get("streaming_response", False): + if ( + self.provider_settings.get("streaming_response", False) + and "Image" in modalities + ): logger.warning("流式输出不支持图片模态,已自动降级为文本模态") modalities = ["Text"] From fe95506db416a42c9b2b21e2ed0a2927fff9d079 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Sun, 13 Apr 2025 17:50:44 +0800 Subject: [PATCH 16/19] =?UTF-8?q?perf:=20=E6=B7=BB=E5=8A=A0=E6=97=A5?= =?UTF-8?q?=E5=BF=97=E8=BF=87=E6=BB=A4=E5=99=A8=E4=BB=A5=E6=8A=91=E5=88=B6?= =?UTF-8?q?=E9=9D=9E=E6=96=87=E6=9C=AC=E9=83=A8=E5=88=86=E8=AD=A6=E5=91=8A?= =?UTF-8?q?=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/sources/gemini_source.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index ab312e502..9b58646f5 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -1,6 +1,7 @@ import asyncio import base64 import json +import logging import random from typing import Dict, List, Optional, AsyncGenerator @@ -20,6 +21,16 @@ from astrbot.core.utils.io import download_image_by_url from ..register import register_provider_adapter +class SuppressNonTextPartsWarning(logging.Filter): + """过滤 Gemini SDK 中的非文本部分警告""" + + def filter(self, record): + return "there are non-text parts in the response" not in record.getMessage() + + +logging.getLogger("google_genai.types").addFilter(SuppressNonTextPartsWarning()) + + @register_provider_adapter( "googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器" ) From 6986c8d8f7ebf6b5cae3ac58518261fbc511a847 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Sun, 13 Apr 2025 18:34:57 +0800 Subject: [PATCH 17/19] =?UTF-8?q?fix:=20clean=20code=EF=BC=8C=E5=A4=84?= =?UTF-8?q?=E7=90=86Gemini=E6=B5=81=E5=BC=8F=E8=BE=93=E5=87=BA=E6=9C=80?= =?UTF-8?q?=E5=90=8E=E4=B8=80=E9=83=A8=E5=88=86=E6=A6=82=E7=8E=87=E6=80=A7?= =?UTF-8?q?=E4=B8=BANone=E7=9A=84=E6=83=85=E5=86=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/sources/gemini_source.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 9b58646f5..5209ccd56 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -268,10 +268,6 @@ class ProviderGoogleGenAI(Provider): chain.append(Comp.Plain("这是图片")) for part in result_parts: if part.text: - if part.executable_code: - part.executable_code = None - if part.code_execution_result: - part.code_execution_result = None chain.append(Comp.Plain(part.text)) elif part.function_call: llm_response.role = "tool" @@ -401,9 +397,12 @@ class ProviderGoogleGenAI(Provider): if chunk.candidates[0].finish_reason: llm_response = LLMResponse("assistant", is_chunk=False) - llm_response.result_chain = self._process_content_parts( - chunk, llm_response - ) + if not chunk.candidates[0].content.parts: + llm_response.result_chain = MessageChain(chain=[Comp.Plain(" ")]) + else: + llm_response.result_chain = self._process_content_parts( + chunk, llm_response + ) yield llm_response break From a769fd7d135d198e1d9a25f3b78227722650a597 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 15 Apr 2025 10:40:42 +0800 Subject: [PATCH 18/19] chore: add google-genai dependency to project --- pyproject.toml | 1 + requirements.txt | 3 +- uv.lock | 76 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 93ad29f42..0f3a40ce8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "defusedxml>=0.7.1", "dingtalk-stream>=0.22.1", "docstring-parser>=0.16", + "google-genai>=1.10.0", "googlesearch-python>=1.3.0", "lark-oapi>=1.4.12", "lxml-html-clean>=0.4.1", diff --git a/requirements.txt b/requirements.txt index e20771e86..2e0ab1ccc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,4 +29,5 @@ defusedxml mcp certifi pip -telegramify-markdown \ No newline at end of file +telegramify-markdown +google-genai \ No newline at end of file diff --git a/uv.lock b/uv.lock index 445941a7d..5c71940d7 100644 --- a/uv.lock +++ b/uv.lock @@ -209,6 +209,7 @@ dependencies = [ { name = "defusedxml" }, { name = "dingtalk-stream" }, { name = "docstring-parser" }, + { name = "google-genai" }, { name = "googlesearch-python" }, { name = "lark-oapi" }, { name = "lxml-html-clean" }, @@ -245,6 +246,7 @@ requires-dist = [ { name = "defusedxml", specifier = ">=0.7.1" }, { name = "dingtalk-stream", specifier = ">=0.22.1" }, { name = "docstring-parser", specifier = ">=0.16" }, + { name = "google-genai", specifier = ">=1.10.0" }, { name = "googlesearch-python", specifier = ">=1.3.0" }, { name = "lark-oapi", specifier = ">=1.4.12" }, { name = "lxml-html-clean", specifier = ">=0.4.1" }, @@ -305,6 +307,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458 }, ] +[[package]] +name = "cachetools" +version = "5.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/81/3747dad6b14fa2cf53fcf10548cf5aea6913e96fab41a3c198676f8948a5/cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4", size = 28380 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a", size = 10080 }, +] + [[package]] name = "certifi" version = "2025.1.31" @@ -676,6 +687,38 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/c8/a5be5b7550c10858fcf9b0ea054baccab474da77d37f1e828ce043a3a5d4/frozenlist-1.5.0-py3-none-any.whl", hash = "sha256:d994863bba198a4a518b467bb971c56e1db3f180a25c6cf7bb1949c267f748c3", size = 11901 }, ] +[[package]] +name = "google-auth" +version = "2.39.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cachetools" }, + { name = "pyasn1-modules" }, + { name = "rsa" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cb/8e/8f45c9a32f73e786e954b8f9761c61422955d23c45d1e8c347f9b4b59e8e/google_auth-2.39.0.tar.gz", hash = "sha256:73222d43cdc35a3aeacbfdcaf73142a97839f10de930550d89ebfe1d0a00cde7", size = 274834 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ce/12/ad37a1ef86006d0a0117fc06a4a00bd461c775356b534b425f00dde208ea/google_auth-2.39.0-py2.py3-none-any.whl", hash = "sha256:0150b6711e97fb9f52fe599f55648950cc4540015565d8fbb31be2ad6e1548a2", size = 212319 }, +] + +[[package]] +name = "google-genai" +version = "1.10.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "google-auth" }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "typing-extensions" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/7a/224e2f70c835202042969685ee3da00a6475508d1b64f0f1e90144f96beb/google_genai-1.10.0.tar.gz", hash = "sha256:f59423e0f155dc66b7792c8a0e6724c75c72dc699d1eb7907d4d0006d4f6186f", size = 156355 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/a0/56839a2e202d79c773edd1c1db124da8eb2a7b657267a888080b678d0369/google_genai-1.10.0-py3-none-any.whl", hash = "sha256:41b105a2fcf8a027fc45cc16694cd559b8cd1272eab7345ad58cfa2c353bf34f", size = 154705 }, +] + [[package]] name = "googlesearch-python" version = "1.3.0" @@ -1402,6 +1445,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/50/1b/6921afe68c74868b4c9fa424dad3be35b095e16687989ebbb50ce4fceb7c/psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553", size = 244885 }, ] +[[package]] +name = "pyasn1" +version = "0.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135 }, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259 }, +] + [[package]] name = "pycparser" version = "2.22" @@ -1697,6 +1761,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/51/d4db610ef29373b879047326cbf6fa98b6c1969d6f6dc423279de2b1be2c/requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06", size = 54481 }, ] +[[package]] +name = "rsa" +version = "4.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/aa/65/7d973b89c4d2351d7fb232c2e452547ddfa243e93131e7cfa766da627b52/rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21", size = 29711 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/97/fa78e3d2f65c02c8e1268b9aba606569fe97f6c8f7c2d74394553347c145/rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7", size = 34315 }, +] + [[package]] name = "silk-python" version = "0.2.6" From 43ee943acb05dae1294461a653b95a9e2f9cde4c Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 15 Apr 2025 10:59:16 +0800 Subject: [PATCH 19/19] =?UTF-8?q?=F0=9F=90=9B=20fix:=20=E5=A4=9A=E8=BD=AE?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E8=B0=83=E7=94=A8=E7=9A=84=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/sources/gemini_source.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 5209ccd56..e0d4b56e9 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -273,7 +273,10 @@ class ProviderGoogleGenAI(Provider): llm_response.role = "tool" llm_response.tools_call_name.append(part.function_call.name) llm_response.tools_call_args.append(part.function_call.args) - llm_response.tools_call_ids.append(part.function_call.id) + # gemini 返回的 function_call.id 可能为 None + llm_response.tools_call_ids.append( + part.function_call.id or part.function_call.name + ) elif part.inline_data and part.inline_data.mime_type.startswith("image/"): chain.append(Comp.Image.fromBytes(part.inline_data.data)) return MessageChain(chain=chain)