From 0dc5b4cdfc79bdad323f4f491a02ceac29e6bb94 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Fri, 11 Apr 2025 12:25:44 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E5=A2=9E=E5=8A=A0=E5=AF=B9RECITATION?= =?UTF-8?q?=E5=AE=8C=E6=88=90=E5=8E=9F=E5=9B=A0=E7=9A=84=E5=A4=84=E7=90=86?= =?UTF-8?q?=EF=BC=8C=E6=8F=90=E5=8F=96=E5=86=85=E5=AE=B9=E5=A4=84=E7=90=86?= =?UTF-8?q?=E9=80=BB=E8=BE=91=E5=88=B0=E7=8B=AC=E7=AB=8B=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/gemini_source.py | 54 +++++++++++++------ 1 file changed, 38 insertions(+), 16 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index b2fd0a429..89ff69559 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -182,7 +182,9 @@ class ProviderGoogleGenAI(Provider): return gemini_contents - async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: + async def _query( + self, payloads: dict, tools: FuncCall, temperature: float = 0.7 + ) -> LLMResponse: """非流式请求 Gemini API""" if tools: t = tools.get_func_desc_google_genai_style() @@ -205,7 +207,6 @@ class ProviderGoogleGenAI(Provider): modalites.append("Image") loop = True - while loop: loop = False result = await self.client.models.generate_content( @@ -213,6 +214,7 @@ class ProviderGoogleGenAI(Provider): contents=conversation, config=types.GenerateContentConfig( system_instruction=system_instruction, + temperature=temperature, tools=[tool] if tool else None, safety_settings=self.safety_settings if self.safety_settings @@ -222,26 +224,35 @@ class ProviderGoogleGenAI(Provider): ), ), ) - # logger.debug(f"gemini result: {result}") - if "Developer instruction is not enabled" in str(result): + result_str = str(result) + finish_reason = result.candidates[0].finish_reason + if "Developer instruction is not enabled" in result_str: logger.warning(f"{self.get_model()} 不支持 system prompt,已自动去除。") system_instruction = "" loop = True - # 不支持函数调用的模型SDK似乎会自动去除,保险起见不删除此行判断。 - elif "Function calling is not enabled" in str(result): + continue + elif "Function calling is not enabled" in result_str: logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除。") tool = None loop = True - elif "Multi-modal output is not supported" in str(result): + continue + elif "Multi-modal output is not supported" in result_str: logger.warning(f"{self.get_model()} 不支持多模态输出,降级为文本模态。") modalites = ["Text"] loop = True + continue + elif finish_reason == types.FinishReason.RECITATION: + logger.warning("发生了recitation,正在尝试加温重试...") + temperature += 0.2 + logger.info(f"当前温度: {temperature}") + if temperature < 2: + loop = True + else: + raise Exception("温度已到达(或超过)2") + continue llm_response = LLMResponse("assistant") - chain = [] - - finish_reason = result.candidates[0].finish_reason if finish_reason == types.FinishReason.SAFETY: raise Exception("模型生成内容未通过用户定义的内容安全检查") @@ -255,9 +266,22 @@ class ProviderGoogleGenAI(Provider): raise Exception("模型生成内容违反Gemini平台政策") if not result.candidates[0].content.parts: + logger.debug(result.candidates) raise Exception("API 返回的内容为空。") - for part in result.candidates[0].content.parts: + llm_response.result_chain = self._process_content_parts( + result.candidates[0].content.parts, llm_response + ) + + return llm_response + + def _process_content_parts( + self, parts: types.Part, llm_response: LLMResponse + ) -> MessageChain: + """处理内容部分并构建消息链""" + chain = [] + part: types.Part + for part in parts: if part.text: chain.append(Comp.Plain(part.text)) elif part.function_call: @@ -267,10 +291,7 @@ class ProviderGoogleGenAI(Provider): llm_response.tools_call_ids.append(part.function_call.id) elif part.inline_data and part.inline_data.mime_type.startswith("image/"): chain.append(Comp.Image.fromBytes(part.inline_data.data)) - - llm_response.result_chain = MessageChain(chain=chain) - - return llm_response + return MessageChain(chain=chain) async def text_chat( self, @@ -305,11 +326,12 @@ class ProviderGoogleGenAI(Provider): retry = 10 keys = self.api_keys.copy() chosen_key = random.choice(keys) + temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7 for _ in range(retry): try: self.chosen_api_key = chosen_key - llm_response = await self._query(payloads, func_tool) + llm_response = await self._query(payloads, func_tool, temp) break except Exception as e: if "429" in str(e) or "API key not valid" in str(e):