From bd24cf3ea40b3943574f3a787ade3cb125bdbff3 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Fri, 11 Apr 2025 23:45:30 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=88=9D=E6=AD=A5=E5=AE=8C=E6=88=90?= =?UTF-8?q?=E5=8E=9F=E7=94=9F=E6=B5=81=E5=BC=8F=E8=AF=B7=E6=B1=82=E9=80=BB?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/gemini_source.py | 221 +++++++++++++----- 1 file changed, 165 insertions(+), 56 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 1447b9e2c..093955c52 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -2,10 +2,11 @@ import asyncio import base64 import json import random -from typing import Dict, List, Optional +from typing import Dict, List, Optional, AsyncGenerator from google import genai from google.genai import types +from google.genai.errors import APIError import astrbot.core.message.components as Comp from astrbot import logger @@ -87,8 +88,8 @@ class ProviderGoogleGenAI(Provider): for m in models if "generateContent" in m.supported_actions ] - except Exception as e: - raise Exception(f"获取模型列表失败: {e}") + except APIError as e: + raise Exception(f"获取模型列表失败: {e.message}") @staticmethod def _prepare_conversation(payloads: Dict) -> List[types.Content]: @@ -168,17 +169,18 @@ class ProviderGoogleGenAI(Provider): self, payloads: dict, tools: FuncCall, temperature: float = 0.7 ) -> LLMResponse: """非流式请求 Gemini API""" - tool_list = [] - if func_desc := tools.get_func_desc_google_genai_style() if tools else None: - tool_list.append( - types.Tool(function_declarations=func_desc["function_declarations"]) - ) + tool_list = None + if tools: + func_desc = tools.get_func_desc_google_genai_style() + if func_desc: + tool_list = [ + types.Tool(function_declarations=func_desc["function_declarations"]) + ] - system_instruction = "" - for message in payloads["messages"]: - if message["role"] == "system": - system_instruction = message["content"] - break + system_instruction = next( + (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"), + None, + ) conversation = self._prepare_conversation(payloads) @@ -217,20 +219,19 @@ class ProviderGoogleGenAI(Provider): break - except Exception as e: - error_msg = str(e) - if "Developer instruction is not enabled" in error_msg: + except APIError as e: + if "Developer instruction is not enabled" in e.message: logger.warning( f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)" ) system_instruction = None - elif "Function calling is not enabled" in error_msg: + elif "Function calling is not enabled" in e.message: logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除") tool_list = None elif ( "Multi-modal output is not supported" or "Model does not support the requested response modalities" - in error_msg + in e.message ): logger.warning( f"{self.get_model()} 不支持多模态输出,降级为文本模态" @@ -241,8 +242,95 @@ class ProviderGoogleGenAI(Provider): continue llm_response = LLMResponse("assistant") - result_parts: Optional[types.Part] = result.candidates[0].content.parts + llm_response.result_chain = self._process_content_parts(result, llm_response) + return llm_response + + async def _query_stream( + self, payloads: dict, tools: FuncCall, temperature: float = 0.7 + ) -> AsyncGenerator[LLMResponse, None]: + """流式请求 Gemini API""" + tool_list = None + if tools: + func_desc = tools.get_func_desc_google_genai_style() + if func_desc: + tool_list = [ + types.Tool(function_declarations=func_desc["function_declarations"]) + ] + + system_instruction = next( + (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"), + None, + ) + + conversation = self._prepare_conversation(payloads) + + result = None + while True: + try: + result = await self.client.models.generate_content_stream( + model=self.get_model(), + contents=conversation, + config=types.GenerateContentConfig( + system_instruction=system_instruction, + temperature=temperature, + tools=tool_list, + safety_settings=self.safety_settings + if self.safety_settings + else None, + automatic_function_calling=types.AutomaticFunctionCallingConfig( + disable=True + ), + ), + ) + + break + + except APIError as e: + if "Developer instruction is not enabled" in e.message: + logger.warning( + f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)" + ) + system_instruction = None + elif "Function calling is not enabled" in e.message: + logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除") + tool_list = None + else: + raise + continue + + if not result: + raise Exception("API 返回异常") + + async for chunk in result: + llm_response = LLMResponse("assistant", is_chunk=True) + + if chunk.candidates[0].content.parts and any( + part.function_call for part in chunk.candidates[0].content.parts + ): + response = LLMResponse("assistant", is_chunk=False) + response.result_chain = self._process_content_parts(chunk, response) + yield response + break + + if chunk.text: + llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)]) + yield llm_response + + if chunk.candidates[0].finish_reason: + llm_response = LLMResponse("assistant", is_chunk=False) + llm_response.result_chain = self._process_content_parts( + chunk, llm_response + ) + yield llm_response + break + + @staticmethod + def _process_content_parts( + result: types.GenerateContentResponse, llm_response: LLMResponse + ) -> MessageChain: + """处理内容部分并构建消息链""" finish_reason = result.candidates[0].finish_reason + result_parts: Optional[types.Part] = result.candidates[0].content.parts if finish_reason == types.FinishReason.SAFETY: raise Exception("模型生成内容未通过用户定义的内容安全检查") @@ -259,20 +347,9 @@ class ProviderGoogleGenAI(Provider): logger.debug(result.candidates) raise Exception("API 返回的内容为空。") - llm_response.result_chain = self._process_content_parts( - result_parts, llm_response - ) - - return llm_response - - @staticmethod - def _process_content_parts( - parts: types.Part, llm_response: LLMResponse - ) -> MessageChain: - """处理内容部分并构建消息链""" chain = [] part: types.Part - for part in parts: + for part in result_parts: if part.text: chain.append(Comp.Plain(part.text)) elif part.function_call: @@ -322,19 +399,19 @@ class ProviderGoogleGenAI(Provider): try: llm_response = await self._query(payloads, func_tool, temp) break - except Exception as e: - if "429" in str(e) or "API key not valid" in str(e): + except APIError as e: + if e.code == 429 or "API key not valid" in e.message: keys.remove(self.chosen_api_key) if len(keys) > 0: self.set_key(random.choice(keys)) logger.info( - f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..." + f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..." ) await asyncio.sleep(1) continue else: logger.error( - f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..." + f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..." ) raise Exception("达到了 Gemini 速率限制, 请稍后再试...") else: @@ -347,30 +424,62 @@ class ProviderGoogleGenAI(Provider): async def text_chat_stream( self, - prompt, - session_id=None, - image_urls=..., - func_tool=None, - contexts=..., + prompt: str, + session_id: str = None, + image_urls: List[str] = [], + func_tool: FuncCall = None, + contexts=[], system_prompt=None, tool_calls_result=None, **kwargs, - ): - # raise NotImplementedError("This method is not implemented yet.") - # 调用 text_chat 模拟流式 - llm_response = await self.text_chat( - prompt=prompt, - session_id=session_id, - image_urls=image_urls, - func_tool=func_tool, - contexts=contexts, - system_prompt=system_prompt, - tool_calls_result=tool_calls_result, - ) - llm_response.is_chunk = True - yield llm_response - llm_response.is_chunk = False - yield llm_response + ) -> AsyncGenerator[LLMResponse, None]: + new_record = await self.assemble_context(prompt, image_urls) + context_query = [*contexts, new_record] + if system_prompt: + context_query.insert(0, {"role": "system", "content": system_prompt}) + + for part in context_query: + if "_no_save" in part: + del part["_no_save"] + + # tool calls result + if tool_calls_result: + context_query.extend(tool_calls_result.to_openai_messages()) + + model_config = self.provider_config.get("model_config", {}) + model_config["model"] = self.get_model() + + payloads = {"messages": context_query, **model_config} + + retry = 10 + keys = self.api_keys.copy() + temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7 + + for _ in range(retry): + try: + async for response in self._query_stream(payloads, func_tool, temp): + yield response + break + except APIError as e: + if e.code == 429 or "API key not valid" in e.message: + keys.remove(self.chosen_api_key) + if len(keys) > 0: + self.set_key(random.choice(keys)) + logger.info( + f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..." + ) + await asyncio.sleep(1) + continue + else: + logger.error( + f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..." + ) + raise Exception("达到了 Gemini 速率限制, 请稍后再试...") + else: + logger.error( + f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}" + ) + raise e def get_current_key(self) -> str: return self.chosen_api_key