diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 28c075e68..f42da1871 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -528,6 +528,7 @@ CONFIG_METADATA_2 = { "model": "gemini-2.0-flash-exp", }, "gm_resp_image_modal": False, + "gm_native_coderunner": False, "gm_safety_settings": { "harassment": "BLOCK_MEDIUM_AND_ABOVE", "hate_speech": "BLOCK_MEDIUM_AND_ABOVE", @@ -704,6 +705,12 @@ CONFIG_METADATA_2 = { "type": "bool", "hint": "启用后,将支持返回图片内容。需要模型支持,否则会报错。具体支持模型请查看 Google Gemini 官方网站。温馨提示,如果您需要生成图片,请关闭 `启用群员识别` 配置获得更好的效果。", }, + "gm_native_coderunner": { + "description": "启用原生代码执行器", + "type": "bool", + "hint": "启用后所有函数工具将全部失效", + "obvious_hint": True, + }, "gm_safety_settings": { "description": "安全过滤器", "type": "object", diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index d0a9d30a5..6ad67da55 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -155,7 +155,7 @@ class ProviderRequest: if self.image_urls: user_content = { "role": "user", - "content": [{"type": "text", "text": self.prompt}], + "content": [{"type": "text", "text": self.prompt if self.prompt else "[图片]"}], } for image_url in self.image_urls: if image_url.startswith("http"): diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index a3ca8c8f2..e0d4b56e9 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -1,121 +1,54 @@ -import base64 -import aiohttp -import json -import random import asyncio +import base64 +import json +import logging +import random +from typing import Dict, List, Optional, AsyncGenerator + +from google import genai +from google.genai import types +from google.genai.errors import APIError + import astrbot.core.message.components as Comp -from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.utils.io import download_image_by_url -from astrbot.core.db import BaseDatabase -from astrbot.api.provider import Provider, Personality from astrbot import logger -from astrbot.core.provider.func_tool_manager import FuncCall -from typing import List -from ..register import register_provider_adapter +from astrbot.api.provider import Personality, Provider +from astrbot.core.db import BaseDatabase +from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import LLMResponse +from astrbot.core.provider.func_tool_manager import FuncCall +from astrbot.core.utils.io import download_image_by_url + +from ..register import register_provider_adapter -class SimpleGoogleGenAIClient: - def __init__(self, api_key: str, api_base: str, timeout: int = 120) -> None: - self.api_key = api_key - if api_base.endswith("/"): - self.api_base = api_base[:-1] - else: - self.api_base = api_base - self.client = aiohttp.ClientSession(trust_env=True) - self.timeout = timeout +class SuppressNonTextPartsWarning(logging.Filter): + """过滤 Gemini SDK 中的非文本部分警告""" - async def models_list(self) -> List[str]: - request_url = f"{self.api_base}/v1beta/models?key={self.api_key}" - async with self.client.get(request_url, timeout=self.timeout) as resp: - response = await resp.json() + def filter(self, record): + return "there are non-text parts in the response" not in record.getMessage() - models = [] - for model in response["models"]: - if "generateContent" in model["supportedGenerationMethods"]: - models.append(model["name"].replace("models/", "")) - return models - async def generate_content( - self, - contents: List[dict], - model: str = "gemini-1.5-flash", - system_instruction: str = "", - tools: dict = None, - modalities: List[str] = ["Text"], - safety_settings: List[dict] = [], - ): - payload = {} - if system_instruction: - payload["system_instruction"] = {"parts": {"text": system_instruction}} - if tools: - payload["tools"] = [tools] - payload["contents"] = contents - payload["generationConfig"] = { - "responseModalities": modalities, - } - payload["safetySettings"] = [ - {"category": s["category"], "threshold": s["threshold"]} - for s in safety_settings - ] - logger.debug(f"payload: {payload}") - request_url = ( - f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}" - ) - async with self.client.post( - request_url, json=payload, timeout=self.timeout - ) as resp: - if "application/json" in resp.headers.get("Content-Type"): - try: - response = await resp.json() - except Exception as e: - text = await resp.text() - logger.error(f"Gemini 返回了非 json 数据: {text}") - raise e - return response - else: - text = await resp.text() - logger.error(f"Gemini 返回了非 json 数据: {text}") - raise Exception("Gemini 返回了非 json 数据: ") +logging.getLogger("google_genai.types").addFilter(SuppressNonTextPartsWarning()) - async def stream_generate_content( - self, - contents: List[dict], - model: str = "gemini-1.5-flash", - system_instruction: str = "", - tools: dict = None, - modalities: List[str] = ["Text"], - safety_settings: List[dict] = [], - ): - payload = {} - if system_instruction: - payload["system_instruction"] = {"parts": {"text": system_instruction}} - if tools: - payload["tools"] = [tools] - payload["contents"] = contents - payload["generationConfig"] = { - "responseModalities": modalities, - "stream": True, - } - payload["safetySettings"] = [ - {"category": s["category"], "threshold": s["threshold"]} - for s in safety_settings - ] - logger.debug(f"payload: {payload}") - request_url = ( - f"{self.api_base}/v1beta/models/{model}:streamGenerateContent?key={self.api_key}" - ) - async with self.client.post( - request_url, json=payload, timeout=self.timeout - ) as resp: - async for line in resp.content: - if line: - yield line @register_provider_adapter( "googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器" ) class ProviderGoogleGenAI(Provider): + CATEGORY_MAPPING = { + "harassment": types.HarmCategory.HARM_CATEGORY_HARASSMENT, + "hate_speech": types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + "sexually_explicit": types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "dangerous_content": types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + } + + THRESHOLD_MAPPING = { + "BLOCK_NONE": types.HarmBlockThreshold.BLOCK_NONE, + "BLOCK_ONLY_HIGH": types.HarmBlockThreshold.BLOCK_ONLY_HIGH, + "BLOCK_MEDIUM_AND_ABOVE": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + "BLOCK_LOW_AND_ABOVE": types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + } + def __init__( self, provider_config: dict, @@ -131,183 +64,351 @@ class ProviderGoogleGenAI(Provider): db_helper, default_persona, ) - self.chosen_api_key = None self.api_keys: List = provider_config.get("key", []) - self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None - self.timeout = provider_config.get("timeout", 180) - if isinstance(self.timeout, str): - self.timeout = int(self.timeout) - self.client = SimpleGoogleGenAIClient( - api_key=self.chosen_api_key, - api_base=provider_config.get("api_base", None), - timeout=self.timeout, - ) + self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None + self.timeout: int = int(provider_config.get("timeout", 180)) + + self.api_base: Optional[str] = provider_config.get("api_base", None) + if self.api_base and self.api_base.endswith("/"): + self.api_base = self.api_base[:-1] + + self._init_client() self.set_model(provider_config["model_config"]["model"]) + self._init_safety_settings() - safety_mapping = { - "harassment": "HARM_CATEGORY_HARASSMENT", - "hate_speech": "HARM_CATEGORY_HATE_SPEECH", - "sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT", - "dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT", - } + def _init_client(self) -> None: + """初始化Gemini客户端""" + self.client = genai.Client( + api_key=self.chosen_api_key, + http_options=types.HttpOptions( + base_url=self.api_base, + timeout=self.timeout * 1000, # 毫秒 + ), + ).aio - self.safety_settings = [] + def _init_safety_settings(self) -> None: + """初始化安全设置""" user_safety_config = self.provider_config.get("gm_safety_settings", {}) - for config_key, harm_category in safety_mapping.items(): - if threshold := user_safety_config.get(config_key): - self.safety_settings.append( - {"category": harm_category, "threshold": threshold} + self.safety_settings = [ + types.SafetySetting( + category=harm_category, threshold=self.THRESHOLD_MAPPING[threshold_str] + ) + for config_key, harm_category in self.CATEGORY_MAPPING.items() + if (threshold_str := user_safety_config.get(config_key)) + and threshold_str in self.THRESHOLD_MAPPING + ] + + async def _handle_api_error(self, e: APIError, keys: List[str]) -> bool: + """处理API错误,返回是否需要重试""" + if e.code == 429 or "API key not valid" in e.message: + keys.remove(self.chosen_api_key) + if len(keys) > 0: + self.set_key(random.choice(keys)) + logger.info( + f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..." ) + await asyncio.sleep(1) + return True + else: + logger.error( + f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..." + ) + raise Exception("达到了 Gemini 速率限制, 请稍后再试...") + else: + logger.error( + f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}" + ) + raise e - async def get_models(self): - return await self.client.models_list() + async def _prepare_query_config( + self, + tools: Optional[FuncCall] = None, + system_instruction: Optional[str] = None, + temperature: Optional[float] = 0.7, + modalities: Optional[List[str]] = None, + ) -> types.GenerateContentConfig: + """准备查询配置""" + if not modalities: + modalities = ["Text"] - async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: - tool = None - if tools: - tool = tools.get_func_desc_google_genai_style() - if not tool: - tool = None + # 流式输出不支持图片模态 + if ( + self.provider_settings.get("streaming_response", False) + and "Image" in modalities + ): + logger.warning("流式输出不支持图片模态,已自动降级为文本模态") + modalities = ["Text"] - system_instruction = "" + tool_list = None + if self.provider_config.get("gm_native_coderunner", False): + if tools: + logger.warning("Gemini原生代码执行器已启用,函数工具将被忽略") + tool_list = [types.Tool(code_execution=types.ToolCodeExecution())] + elif tools and (func_desc := tools.get_func_desc_google_genai_style()): + tool_list = [ + types.Tool(function_declarations=func_desc["function_declarations"]) + ] + + return types.GenerateContentConfig( + system_instruction=system_instruction, + temperature=temperature, + response_modalities=modalities, + tools=tool_list, + safety_settings=self.safety_settings if self.safety_settings else None, + automatic_function_calling=types.AutomaticFunctionCallingConfig( + disable=True + ), + ) + + @staticmethod + def _prepare_conversation(payloads: Dict) -> List[types.Content]: + """准备 Gemini SDK 的 Content 列表""" + + def create_text_part(text: str) -> types.UserContent: + content_a = text if text else " " + if not text: + logger.warning("文本内容为空,已添加空格占位") + return types.UserContent(parts=[types.Part.from_text(text=content_a)]) + + def process_image_url(image_url_dict: dict) -> types.Part: + url = image_url_dict["url"] + mime_type = url.split(":")[1].split(";")[0] + image_bytes = base64.b64decode(url.split(",", 1)[1]) + return types.Part.from_bytes(data=image_bytes, mime_type=mime_type) + + gemini_contents: List[types.Content] = [] for message in payloads["messages"]: - if message["role"] == "system": - system_instruction = message["content"] - break + role, content = message["role"], message.get("content") - google_genai_conversation = [] - for message in payloads["messages"]: - if message["role"] == "user": - if isinstance(message["content"], str): - if not message["content"]: - message["content"] = " " + if role == "user": + if isinstance(content, str): + gemini_contents.append(create_text_part(content)) + elif isinstance(content, list): + parts = [ + types.Part.from_text(text=item["text"] or " ") + if item["type"] == "text" + else process_image_url(item["image_url"]) + for item in content + ] + gemini_contents.append(types.UserContent(parts=parts)) - google_genai_conversation.append( - {"role": "user", "parts": [{"text": message["content"]}]} - ) - elif isinstance(message["content"], list): - # images - parts = [] - for part in message["content"]: - if part["type"] == "text": - if not part["text"]: - part["text"] = "" - parts.append({"text": part["text"]}) - elif part["type"] == "image_url": - parts.append( - { - "inline_data": { - "mime_type": "image/jpeg", - "data": part["image_url"]["url"].replace( - "data:image/jpeg;base64,", "" - ), # base64 - } - } - ) - google_genai_conversation.append({"role": "user", "parts": parts}) - - elif message["role"] == "assistant": - if "content" in message: - if not message["content"]: - message["content"] = " " - google_genai_conversation.append( - {"role": "model", "parts": [{"text": message["content"]}]} + elif role == "assistant": + if content: + gemini_contents.append( + types.ModelContent(parts=[types.Part.from_text(text=content)]) ) elif "tool_calls" in message: - # tool calls in the last turn - parts = [] - for tool_call in message["tool_calls"]: - parts.append( - { - "functionCall": { - "name": tool_call["function"]["name"], - "args": json.loads( - tool_call["function"]["arguments"] - ), - } - } - ) - google_genai_conversation.append({"role": "model", "parts": parts}) - elif message["role"] == "tool": - parts = [] - parts.append( - { - "functionResponse": { - "name": message["tool_call_id"], - "response": { - "name": message["tool_call_id"], - "content": message["content"], - }, - } - } + gemini_contents.extend( + [ + types.ModelContent( + parts=[ + types.Part.from_function_call( + name=tool["function"]["name"], + args=json.loads(tool["function"]["arguments"]), + ) + ] + ) + for tool in message["tool_calls"] + ] + ) + else: + logger.warning("assistant 角色的消息内容为空,已添加空格占位") + gemini_contents.append( + types.ModelContent(parts=[types.Part.from_text(text=" ")]) + ) + + elif role == "tool": + gemini_contents.append( + types.UserContent( + parts=[ + types.Part.from_function_response( + name=message["tool_call_id"], + response={ + "name": message["tool_call_id"], + "content": message["content"], + }, + ) + ] + ) ) - google_genai_conversation.append({"role": "user", "parts": parts}) - logger.debug(f"google_genai_conversation: {google_genai_conversation}") + return gemini_contents - modalites = ["Text"] - if self.provider_config.get("gm_resp_image_modal", False): - modalites.append("Image") + @staticmethod + def _process_content_parts( + result: types.GenerateContentResponse, llm_response: LLMResponse + ) -> MessageChain: + """处理内容部分并构建消息链""" + finish_reason = result.candidates[0].finish_reason + result_parts: Optional[types.Part] = result.candidates[0].content.parts - loop = True - while loop: - loop = False - result = await self.client.generate_content( - contents=google_genai_conversation, - model=self.get_model(), - system_instruction=system_instruction, - tools=tool, - modalities=modalites, - safety_settings=self.safety_settings, - ) - logger.debug(f"result: {result}") + if finish_reason == types.FinishReason.SAFETY: + raise Exception("模型生成内容未通过用户定义的内容安全检查") - # Developer instruction is not enabled for models/gemini-2.0-flash-exp - if "Developer instruction is not enabled" in str(result): - logger.warning( - f"{self.get_model()} 不支持 system prompt, 已自动去除, 将会影响人格设置。" - ) - system_instruction = "" - loop = True + if finish_reason in { + types.FinishReason.PROHIBITED_CONTENT, + types.FinishReason.SPII, + types.FinishReason.BLOCKLIST, + types.FinishReason.IMAGE_SAFETY, + }: + raise Exception("模型生成内容违反Gemini平台政策") - elif "Function calling is not enabled" in str(result): - logger.warning( - f"{self.get_model()} 不支持函数调用,已自动去除,不影响使用。" - ) - tool = None - loop = True + if not result_parts: + logger.debug(result.candidates) + raise Exception("API 返回的内容为空。") - elif "Multi-modal output is not supported" in str(result): - logger.warning( - f"{self.get_model()} 不支持多模态输出,降级为文本模态重新请求。" - ) - modalites = ["Text"] - loop = True - - elif "candidates" not in result: - raise Exception("Gemini 返回异常结果: " + str(result)) - - candidates = result["candidates"][0]["content"]["parts"] - llm_response = LLMResponse("assistant") chain = [] - for candidate in candidates: - if "text" in candidate: - chain.append(Comp.Plain(candidate["text"])) - elif "functionCall" in candidate: - llm_response.role = "tool" - llm_response.tools_call_args.append(candidate["functionCall"]["args"]) - llm_response.tools_call_name.append(candidate["functionCall"]["name"]) - llm_response.tools_call_ids.append( - candidate["functionCall"]["name"] - ) # 没有 tool id - elif "inlineData" in candidate: - mime_type: str = candidate["inlineData"]["mimeType"] - if mime_type.startswith("image/"): - chain.append(Comp.Image.fromBase64(candidate["inlineData"]["data"])) + part: types.Part - llm_response.result_chain = MessageChain(chain=chain) + # 暂时这样Fallback + if all( + part.inline_data and part.inline_data.mime_type.startswith("image/") + for part in result_parts + ): + chain.append(Comp.Plain("这是图片")) + for part in result_parts: + if part.text: + chain.append(Comp.Plain(part.text)) + elif part.function_call: + llm_response.role = "tool" + llm_response.tools_call_name.append(part.function_call.name) + llm_response.tools_call_args.append(part.function_call.args) + # gemini 返回的 function_call.id 可能为 None + llm_response.tools_call_ids.append( + part.function_call.id or part.function_call.name + ) + elif part.inline_data and part.inline_data.mime_type.startswith("image/"): + chain.append(Comp.Image.fromBytes(part.inline_data.data)) + return MessageChain(chain=chain) + + async def _query( + self, payloads: dict, tools: FuncCall, temperature: float = 0.7 + ) -> LLMResponse: + """非流式请求 Gemini API""" + system_instruction = next( + (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"), + None, + ) + + modalities = ["Text"] + if self.provider_config.get("gm_resp_image_modal", False): + modalities.append("Image") + + conversation = self._prepare_conversation(payloads) + + result: Optional[types.GenerateContentResponse] = None + while True: + try: + config = await self._prepare_query_config( + tools, system_instruction, temperature, modalities + ) + result = await self.client.models.generate_content( + model=self.get_model(), + contents=conversation, + config=config, + ) + + if result.candidates[0].finish_reason == types.FinishReason.RECITATION: + if temperature > 2: + raise Exception("温度参数已超过最大值2,仍然发生recitation") + temperature += 0.2 + logger.warning( + f"发生了recitation,正在提高温度至{temperature:.1f}重试..." + ) + continue + + break + + except APIError as e: + if "Developer instruction is not enabled" in e.message: + logger.warning( + f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)" + ) + system_instruction = None + elif "Function calling is not enabled" in e.message: + logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除") + tools = None + elif ( + "Multi-modal output is not supported" in e.message + or "Model does not support the requested response modalities" + in e.message + ): + logger.warning( + f"{self.get_model()} 不支持多模态输出,降级为文本模态" + ) + modalities = ["Text"] + else: + raise + continue + + llm_response = LLMResponse("assistant") + llm_response.result_chain = self._process_content_parts(result, llm_response) return llm_response + async def _query_stream( + self, payloads: dict, tools: FuncCall, temperature: float = 0.7 + ) -> AsyncGenerator[LLMResponse, None]: + """流式请求 Gemini API""" + system_instruction = next( + (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"), + None, + ) + + conversation = self._prepare_conversation(payloads) + + result = None + while True: + try: + config = await self._prepare_query_config( + tools, system_instruction, temperature + ) + result = await self.client.models.generate_content_stream( + model=self.get_model(), + contents=conversation, + config=config, + ) + break + except APIError as e: + if "Developer instruction is not enabled" in e.message: + logger.warning( + f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)" + ) + system_instruction = None + elif "Function calling is not enabled" in e.message: + logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除") + tools = None + else: + raise + continue + + async for chunk in result: + llm_response = LLMResponse("assistant", is_chunk=True) + + if chunk.candidates[0].content.parts and any( + part.function_call for part in chunk.candidates[0].content.parts + ): + llm_response = LLMResponse("assistant", is_chunk=False) + llm_response.result_chain = self._process_content_parts( + chunk, llm_response + ) + yield llm_response + break + + if chunk.text: + llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)]) + yield llm_response + + if chunk.candidates[0].finish_reason: + llm_response = LLMResponse("assistant", is_chunk=False) + if not chunk.candidates[0].content.parts: + llm_response.result_chain = MessageChain(chain=[Comp.Plain(" ")]) + else: + llm_response.result_chain = self._process_content_parts( + chunk, llm_response + ) + yield llm_response + break + async def text_chat( self, prompt: str, @@ -320,7 +421,6 @@ class ProviderGoogleGenAI(Provider): **kwargs, ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls) - context_query = [] context_query = [*contexts, new_record] if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) @@ -337,82 +437,92 @@ class ProviderGoogleGenAI(Provider): model_config["model"] = self.get_model() payloads = {"messages": context_query, **model_config} - llm_response = None retry = 10 keys = self.api_keys.copy() - chosen_key = random.choice(keys) + temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7 - for i in range(retry): + for _ in range(retry): try: - self.client.api_key = chosen_key - llm_response = await self._query(payloads, func_tool) + return await self._query(payloads, func_tool, temp) + except APIError as e: + if await self._handle_api_error(e, keys): + continue break - except Exception as e: - if "429" in str(e) or "API key not valid" in str(e): - keys.remove(chosen_key) - if len(keys) > 0: - chosen_key = random.choice(keys) - logger.info( - f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {chosen_key[:12]}..." - ) - await asyncio.sleep(1) - continue - else: - logger.error( - f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {chosen_key[:12]}..." - ) - raise Exception("达到了 Gemini 速率限制, 请稍后再试...") - else: - logger.error( - f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}" - ) - raise e - - return llm_response async def text_chat_stream( self, - prompt, - session_id=None, - image_urls=..., - func_tool=None, - contexts=..., + prompt: str, + session_id: str = None, + image_urls: List[str] = [], + func_tool: FuncCall = None, + contexts=[], system_prompt=None, tool_calls_result=None, **kwargs, - ): - # raise NotImplementedError("This method is not implemented yet.") - # 调用 text_chat 模拟流式 - llm_response = await self.text_chat( - prompt=prompt, - session_id=session_id, - image_urls=image_urls, - func_tool=func_tool, - contexts=contexts, - system_prompt=system_prompt, - tool_calls_result=tool_calls_result, - ) - llm_response.is_chunk = True - yield llm_response - llm_response.is_chunk = False - yield llm_response + ) -> AsyncGenerator[LLMResponse, None]: + new_record = await self.assemble_context(prompt, image_urls) + context_query = [*contexts, new_record] + if system_prompt: + context_query.insert(0, {"role": "system", "content": system_prompt}) + + for part in context_query: + if "_no_save" in part: + del part["_no_save"] + + # tool calls result + if tool_calls_result: + context_query.extend(tool_calls_result.to_openai_messages()) + + model_config = self.provider_config.get("model_config", {}) + model_config["model"] = self.get_model() + + payloads = {"messages": context_query, **model_config} + + retry = 10 + keys = self.api_keys.copy() + temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7 + + for _ in range(retry): + try: + async for response in self._query_stream(payloads, func_tool, temp): + yield response + break + except APIError as e: + if await self._handle_api_error(e, keys): + continue + break + + async def get_models(self): + try: + models = await self.client.models.list() + return [ + m.name.replace("models/", "") + for m in models + if "generateContent" in m.supported_actions + ] + except APIError as e: + raise Exception(f"获取模型列表失败: {e.message}") def get_current_key(self) -> str: - return self.client.api_key + return self.chosen_api_key def get_keys(self) -> List[str]: return self.api_keys def set_key(self, key): - self.client.api_key = key + self.chosen_api_key = key + self._init_client() async def assemble_context(self, text: str, image_urls: List[str] = None): """ 组装上下文。 """ if image_urls: - user_content = {"role": "user", "content": [{"type": "text", "text": text}]} + user_content = { + "role": "user", + "content": [{"type": "text", "text": text if text else "[图片]"}], + } for image_url in image_urls: if image_url.startswith("http"): image_path = await download_image_by_url(image_url) @@ -444,5 +554,4 @@ class ProviderGoogleGenAI(Provider): return "" async def terminate(self): - await self.client.client.close() logger.info("Google GenAI 适配器已终止。") diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 110d0d435..f4e02b5f5 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -505,7 +505,7 @@ class ProviderOpenAIOfficial(Provider): async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict: """组装成符合 OpenAI 格式的 role 为 user 的消息段""" if image_urls: - user_content = {"role": "user", "content": [{"type": "text", "text": text}]} + user_content = {"role": "user", "content": [{"type": "text", "text": text if text else "[图片]"}]} for image_url in image_urls: if image_url.startswith("http"): image_path = await download_image_by_url(image_url) diff --git a/pyproject.toml b/pyproject.toml index 93ad29f42..0f3a40ce8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "defusedxml>=0.7.1", "dingtalk-stream>=0.22.1", "docstring-parser>=0.16", + "google-genai>=1.10.0", "googlesearch-python>=1.3.0", "lark-oapi>=1.4.12", "lxml-html-clean>=0.4.1", diff --git a/requirements.txt b/requirements.txt index e20771e86..2e0ab1ccc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,4 +29,5 @@ defusedxml mcp certifi pip -telegramify-markdown \ No newline at end of file +telegramify-markdown +google-genai \ No newline at end of file diff --git a/uv.lock b/uv.lock index 445941a7d..5c71940d7 100644 --- a/uv.lock +++ b/uv.lock @@ -209,6 +209,7 @@ dependencies = [ { name = "defusedxml" }, { name = "dingtalk-stream" }, { name = "docstring-parser" }, + { name = "google-genai" }, { name = "googlesearch-python" }, { name = "lark-oapi" }, { name = "lxml-html-clean" }, @@ -245,6 +246,7 @@ requires-dist = [ { name = "defusedxml", specifier = ">=0.7.1" }, { name = "dingtalk-stream", specifier = ">=0.22.1" }, { name = "docstring-parser", specifier = ">=0.16" }, + { name = "google-genai", specifier = ">=1.10.0" }, { name = "googlesearch-python", specifier = ">=1.3.0" }, { name = "lark-oapi", specifier = ">=1.4.12" }, { name = "lxml-html-clean", specifier = ">=0.4.1" }, @@ -305,6 +307,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458 }, ] +[[package]] +name = "cachetools" +version = "5.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/81/3747dad6b14fa2cf53fcf10548cf5aea6913e96fab41a3c198676f8948a5/cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4", size = 28380 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a", size = 10080 }, +] + [[package]] name = "certifi" version = "2025.1.31" @@ -676,6 +687,38 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/c8/a5be5b7550c10858fcf9b0ea054baccab474da77d37f1e828ce043a3a5d4/frozenlist-1.5.0-py3-none-any.whl", hash = "sha256:d994863bba198a4a518b467bb971c56e1db3f180a25c6cf7bb1949c267f748c3", size = 11901 }, ] +[[package]] +name = "google-auth" +version = "2.39.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cachetools" }, + { name = "pyasn1-modules" }, + { name = "rsa" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cb/8e/8f45c9a32f73e786e954b8f9761c61422955d23c45d1e8c347f9b4b59e8e/google_auth-2.39.0.tar.gz", hash = "sha256:73222d43cdc35a3aeacbfdcaf73142a97839f10de930550d89ebfe1d0a00cde7", size = 274834 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ce/12/ad37a1ef86006d0a0117fc06a4a00bd461c775356b534b425f00dde208ea/google_auth-2.39.0-py2.py3-none-any.whl", hash = "sha256:0150b6711e97fb9f52fe599f55648950cc4540015565d8fbb31be2ad6e1548a2", size = 212319 }, +] + +[[package]] +name = "google-genai" +version = "1.10.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "google-auth" }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "typing-extensions" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/7a/224e2f70c835202042969685ee3da00a6475508d1b64f0f1e90144f96beb/google_genai-1.10.0.tar.gz", hash = "sha256:f59423e0f155dc66b7792c8a0e6724c75c72dc699d1eb7907d4d0006d4f6186f", size = 156355 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/a0/56839a2e202d79c773edd1c1db124da8eb2a7b657267a888080b678d0369/google_genai-1.10.0-py3-none-any.whl", hash = "sha256:41b105a2fcf8a027fc45cc16694cd559b8cd1272eab7345ad58cfa2c353bf34f", size = 154705 }, +] + [[package]] name = "googlesearch-python" version = "1.3.0" @@ -1402,6 +1445,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/50/1b/6921afe68c74868b4c9fa424dad3be35b095e16687989ebbb50ce4fceb7c/psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553", size = 244885 }, ] +[[package]] +name = "pyasn1" +version = "0.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135 }, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259 }, +] + [[package]] name = "pycparser" version = "2.22" @@ -1697,6 +1761,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/51/d4db610ef29373b879047326cbf6fa98b6c1969d6f6dc423279de2b1be2c/requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06", size = 54481 }, ] +[[package]] +name = "rsa" +version = "4.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/aa/65/7d973b89c4d2351d7fb232c2e452547ddfa243e93131e7cfa766da627b52/rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21", size = 29711 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/97/fa78e3d2f65c02c8e1268b9aba606569fe97f6c8f7c2d74394553347c145/rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7", size = 34315 }, +] + [[package]] name = "silk-python" version = "0.2.6"