diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 7def203db..c9da2ec9f 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -1,8 +1,9 @@ import base64 -import aiohttp import json import random import asyncio +from google import genai +from google.genai import types, errors import astrbot.core.message.components as Comp from astrbot.core.message.message_event_result import MessageChain from astrbot.core.utils.io import download_image_by_url @@ -10,112 +11,28 @@ from astrbot.core.db import BaseDatabase from astrbot.api.provider import Provider, Personality from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall -from typing import List +from typing import Dict, List from ..register import register_provider_adapter from astrbot.core.provider.entities import LLMResponse - -class SimpleGoogleGenAIClient: - def __init__(self, api_key: str, api_base: str, timeout: int = 120) -> None: - self.api_key = api_key - if api_base.endswith("/"): - self.api_base = api_base[:-1] - else: - self.api_base = api_base - self.client = aiohttp.ClientSession(trust_env=True) - self.timeout = timeout - - async def models_list(self) -> List[str]: - request_url = f"{self.api_base}/v1beta/models?key={self.api_key}" - async with self.client.get(request_url, timeout=self.timeout) as resp: - response = await resp.json() - - models = [] - for model in response["models"]: - if "generateContent" in model["supportedGenerationMethods"]: - models.append(model["name"].replace("models/", "")) - return models - - async def generate_content( - self, - contents: List[dict], - model: str = "gemini-1.5-flash", - system_instruction: str = "", - tools: dict = None, - modalities: List[str] = ["Text"], - safety_settings: List[dict] = [], - ): - payload = {} - if system_instruction: - payload["system_instruction"] = {"parts": {"text": system_instruction}} - if tools: - payload["tools"] = [tools] - payload["contents"] = contents - payload["generationConfig"] = { - "responseModalities": modalities, - } - payload["safetySettings"] = [ - {"category": s["category"], "threshold": s["threshold"]} - for s in safety_settings - ] - logger.debug(f"payload: {payload}") - request_url = ( - f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}" - ) - async with self.client.post( - request_url, json=payload, timeout=self.timeout - ) as resp: - if "application/json" in resp.headers.get("Content-Type"): - try: - response = await resp.json() - except Exception as e: - text = await resp.text() - logger.error(f"Gemini 返回了非 json 数据: {text}") - raise e - return response - else: - text = await resp.text() - logger.error(f"Gemini 返回了非 json 数据: {text}") - raise Exception("Gemini 返回了非 json 数据: ") - - async def stream_generate_content( - self, - contents: List[dict], - model: str = "gemini-1.5-flash", - system_instruction: str = "", - tools: dict = None, - modalities: List[str] = ["Text"], - safety_settings: List[dict] = [], - ): - payload = {} - if system_instruction: - payload["system_instruction"] = {"parts": {"text": system_instruction}} - if tools: - payload["tools"] = [tools] - payload["contents"] = contents - payload["generationConfig"] = { - "responseModalities": modalities, - "stream": True, - } - payload["safetySettings"] = [ - {"category": s["category"], "threshold": s["threshold"]} - for s in safety_settings - ] - logger.debug(f"payload: {payload}") - request_url = ( - f"{self.api_base}/v1beta/models/{model}:streamGenerateContent?key={self.api_key}" - ) - async with self.client.post( - request_url, json=payload, timeout=self.timeout - ) as resp: - async for line in resp.content: - if line: - yield line - @register_provider_adapter( "googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器" ) class ProviderGoogleGenAI(Provider): + CATEGORY_MAPPING = { + "harassment": types.HarmCategory.HARM_CATEGORY_HARASSMENT, + "hate_speech": types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + "sexually_explicit": types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "dangerous_content": types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + } + + THRESHOLD_MAPPING = { + "BLOCK_NONE": types.HarmBlockThreshold.BLOCK_NONE, + "BLOCK_ONLY_HIGH": types.HarmBlockThreshold.BLOCK_ONLY_HIGH, + "BLOCK_MEDIUM_AND_ABOVE": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + "BLOCK_LOW_AND_ABOVE": types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + } + def __init__( self, provider_config: dict, @@ -131,43 +48,145 @@ class ProviderGoogleGenAI(Provider): db_helper, default_persona, ) - self.chosen_api_key = None self.api_keys: List = provider_config.get("key", []) - self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None - self.timeout = provider_config.get("timeout", 180) + self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None + self.timeout: int = provider_config.get("timeout", 180) + self.api_base: str = provider_config.get("api_base", None) + if self.api_base.endswith("/"): + self.api_base = self.api_base[:-1] if isinstance(self.timeout, str): self.timeout = int(self.timeout) - self.client = SimpleGoogleGenAIClient( + self.client = genai.Client( api_key=self.chosen_api_key, - api_base=provider_config.get("api_base", None), - timeout=self.timeout, - ) + http_options=types.HttpOptions( + base_url=self.api_base, + timeout=self.timeout * 1000, # 毫秒 + ), + ).aio self.set_model(provider_config["model_config"]["model"]) - safety_mapping = { - "harassment": "HARM_CATEGORY_HARASSMENT", - "hate_speech": "HARM_CATEGORY_HATE_SPEECH", - "sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT", - "dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT", - } - - self.safety_settings = [] user_safety_config = self.provider_config.get("gm_safety_settings", {}) - for config_key, harm_category in safety_mapping.items(): - if threshold := user_safety_config.get(config_key): - self.safety_settings.append( - {"category": harm_category, "threshold": threshold} - ) + self.safety_settings = [ + types.SafetySetting( + category=harm_category, threshold=self.THRESHOLD_MAPPING[threshold_str] + ) + for config_key, harm_category in self.CATEGORY_MAPPING.items() + if (threshold_str := user_safety_config.get(config_key)) + and threshold_str in self.THRESHOLD_MAPPING + ] async def get_models(self): - return await self.client.models_list() + try: + models = await self.client.models.list() + return [ + m.name.replace("models/", "") + for m in models + if "generateContent" in m.supported_actions + ] + except errors.APIError as e: + raise Exception(f"获取模型列表失败: {e}") + + def _prepare_conversation( + self, + payloads: Dict, + ) -> List[types.Content]: + """准备 Gemini SDK 的 Content 列表""" + gemini_contents = [] + for message in payloads["messages"]: + role = message["role"] + content = message.get("content") + + if role == "user": + if isinstance(content, str): + if content: + gemini_contents.append( + types.UserContent( + parts=[types.Part.from_text(text=content)] + ) + ) + else: + logger.warning("文本内容为空,已添加空格占位") + gemini_contents.append( + types.UserContent(parts=[types.Part.from_text(text=" ")]) + ) + + elif isinstance(content, list): + parts = [] + for item in content: + if item.get("type") == "text": + text_content = item.get("text") + if text_content: + parts.append(types.Part.from_text(text=text_content)) + else: + logger.warning("文本内容为空,已添加空格占位") + parts.append(types.Part.from_text(text=" ")) + elif item.get("type") == "image_url": + image_url_dict = item["image_url"] + url = image_url_dict["url"] + mime_part, base64_data = url.split(",", 1) + mime_type = mime_part.split(":")[1].split(";")[0] + image_bytes = base64.b64decode(base64_data) + parts.append( + types.Part.from_bytes( + data=image_bytes, mime_type=mime_type + ) + ) + gemini_contents.append(types.UserContent(parts=parts)) + + elif role == "assistant": + if content: + gemini_contents.append( + types.ModelContent( + parts=[types.Part.from_text(text=message["content"])] + ) + ) + elif "tool_calls" in message: + parts = [ + { + "name": tool_call["function"]["name"], + "args": json.loads(tool_call["function"]["arguments"]), + } + for tool_call in message["tool_calls"] + ] + gemini_contents.append( + types.ModelContent(parts=[types.Part.from_function_call(parts)]) + ) + else: + logger.warning("assistant 角色的消息内容为空,已添加空格占位") + gemini_contents.append( + types.ModelContent(parts=[types.Part.from_text(text=" ")]) + ) + + elif role == "tool": + gemini_contents.append( + types.UserContent( + parts=[ + types.Part.from_function_response( + { + "name": message["tool_call_id"], + "response": { + "name": message["tool_call_id"], + "content": message["content"], + }, + } + ) + ] + ) + ) + + logger.debug(f"gemini_contents: {gemini_contents}") + + return gemini_contents async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: - tool = None + """非流式请求 Gemini API""" if tools: - tool = tools.get_func_desc_google_genai_style() - if not tool: - tool = None + t = tools.get_func_desc_google_genai_style() + tool = ( + types.Tool(function_declarations=t["function_declarations"]) + if t + else None + ) system_instruction = "" for message in payloads["messages"]: @@ -175,137 +194,78 @@ class ProviderGoogleGenAI(Provider): system_instruction = message["content"] break - google_genai_conversation = [] - for message in payloads["messages"]: - if message["role"] == "user": - if isinstance(message["content"], str): - if not message["content"]: - message["content"] = "" - - google_genai_conversation.append( - {"role": "user", "parts": [{"text": message["content"]}]} - ) - elif isinstance(message["content"], list): - # images - parts = [] - for part in message["content"]: - if part["type"] == "text": - if not part["text"]: - part["text"] = "" - parts.append({"text": part["text"]}) - elif part["type"] == "image_url": - parts.append( - { - "inline_data": { - "mime_type": "image/jpeg", - "data": part["image_url"]["url"].replace( - "data:image/jpeg;base64,", "" - ), # base64 - } - } - ) - google_genai_conversation.append({"role": "user", "parts": parts}) - - elif message["role"] == "assistant": - if "content" in message: - if not message["content"]: - message["content"] = "" - google_genai_conversation.append( - {"role": "model", "parts": [{"text": message["content"]}]} - ) - elif "tool_calls" in message: - # tool calls in the last turn - parts = [] - for tool_call in message["tool_calls"]: - parts.append( - { - "functionCall": { - "name": tool_call["function"]["name"], - "args": json.loads( - tool_call["function"]["arguments"] - ), - } - } - ) - google_genai_conversation.append({"role": "model", "parts": parts}) - elif message["role"] == "tool": - parts = [] - parts.append( - { - "functionResponse": { - "name": message["tool_call_id"], - "response": { - "name": message["tool_call_id"], - "content": message["content"], - }, - } - } - ) - google_genai_conversation.append({"role": "user", "parts": parts}) - - logger.debug(f"google_genai_conversation: {google_genai_conversation}") + conversation = self._prepare_conversation(payloads) modalites = ["Text"] if self.provider_config.get("gm_resp_image_modal", False): modalites.append("Image") loop = True + while loop: loop = False - result = await self.client.generate_content( - contents=google_genai_conversation, + result = await self.client.models.generate_content( model=self.get_model(), - system_instruction=system_instruction, - tools=tool, - modalities=modalites, - safety_settings=self.safety_settings, + contents=conversation, + config=types.GenerateContentConfig( + system_instruction=system_instruction, + tools=[tool] if tool else None, + safety_settings=self.safety_settings + if self.safety_settings + else None, + automatic_function_calling=types.AutomaticFunctionCallingConfig( + disable=True + ), + ), ) - logger.debug(f"result: {result}") + logger.debug(f"gemini result: {result}") - # Developer instruction is not enabled for models/gemini-2.0-flash-exp if "Developer instruction is not enabled" in str(result): - logger.warning( - f"{self.get_model()} 不支持 system prompt, 已自动去除, 将会影响人格设置。" - ) + logger.warning(f"{self.get_model()} 不支持 system prompt,已自动去除。") system_instruction = "" loop = True - + # 不支持函数调用的模型SDK似乎会自动去除,保险起见不删除此行判断。 elif "Function calling is not enabled" in str(result): - logger.warning( - f"{self.get_model()} 不支持函数调用,已自动去除,不影响使用。" - ) + logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除。") tool = None loop = True - elif "Multi-modal output is not supported" in str(result): - logger.warning( - f"{self.get_model()} 不支持多模态输出,降级为文本模态重新请求。" - ) + logger.warning(f"{self.get_model()} 不支持多模态输出,降级为文本模态。") modalites = ["Text"] loop = True - elif "candidates" not in result: - raise Exception("Gemini 返回异常结果: " + str(result)) - - candidates = result["candidates"][0]["content"]["parts"] llm_response = LLMResponse("assistant") chain = [] - for candidate in candidates: - if "text" in candidate: - chain.append(Comp.Plain(candidate["text"])) - elif "functionCall" in candidate: + + finish_reason = result.candidates[0].finish_reason + + if finish_reason == types.FinishReason.SAFETY: + raise Exception("模型生成内容未通过用户定义的内容安全检查") + + if finish_reason in { + types.FinishReason.PROHIBITED_CONTENT, + types.FinishReason.SPII, + types.FinishReason.BLOCKLIST, + types.FinishReason.IMAGE_SAFETY, + }: + raise Exception("模型生成内容违反Gemini平台政策") + + if not result.candidates[0].content.parts: + raise Exception("API 返回的内容为空。") + + for part in result.candidates[0].content.parts: + if part.text: + chain.append(Comp.Plain(part.text)) + elif part.function_call: llm_response.role = "tool" - llm_response.tools_call_args.append(candidate["functionCall"]["args"]) - llm_response.tools_call_name.append(candidate["functionCall"]["name"]) - llm_response.tools_call_ids.append( - candidate["functionCall"]["name"] - ) # 没有 tool id - elif "inlineData" in candidate: - mime_type: str = candidate["inlineData"]["mimeType"] - if mime_type.startswith("image/"): - chain.append(Comp.Image.fromBase64(candidate["inlineData"]["data"])) + llm_response.tools_call_name.append(part.function_call.name) + llm_response.tools_call_args.append(part.function_call.args) + llm_response.tools_call_ids.append(part.function_call.id) + elif part.inline_data and part.inline_data.mime_type.startswith("image/"): + chain.append(Comp.Image.fromBytes(part.inline_data.data)) llm_response.result_chain = MessageChain(chain=chain) + return llm_response async def text_chat( @@ -320,7 +280,6 @@ class ProviderGoogleGenAI(Provider): **kwargs, ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls) - context_query = [] context_query = [*contexts, new_record] if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) @@ -345,7 +304,7 @@ class ProviderGoogleGenAI(Provider): for i in range(retry): try: - self.client.api_key = chosen_key + self.chosen_api_key = chosen_key llm_response = await self._query(payloads, func_tool) break except Exception as e: @@ -399,13 +358,13 @@ class ProviderGoogleGenAI(Provider): yield llm_response def get_current_key(self) -> str: - return self.client.api_key + return self.chosen_api_key def get_keys(self) -> List[str]: return self.api_keys def set_key(self, key): - self.client.api_key = key + self.chosen_api_key = key async def assemble_context(self, text: str, image_urls: List[str] = None): """ @@ -444,5 +403,4 @@ class ProviderGoogleGenAI(Provider): return "" async def terminate(self): - await self.client.client.close() logger.info("Google GenAI 适配器已终止。")