perf: 增加对 RECITATION 完成原因的处理,提取内容处理逻辑到独立方法

This commit is contained in:
Raven95676
2025-04-11 12:25:44 +08:00
parent cc6cd96d8e
commit 0dc5b4cdfc
+38 -16
View File
@@ -182,7 +182,9 @@ class ProviderGoogleGenAI(Provider):
return gemini_contents
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
async def _query(
self, payloads: dict, tools: FuncCall, temperature: float = 0.7
) -> LLMResponse:
"""非流式请求 Gemini API"""
if tools:
t = tools.get_func_desc_google_genai_style()
@@ -205,7 +207,6 @@ class ProviderGoogleGenAI(Provider):
modalites.append("Image")
loop = True
while loop:
loop = False
result = await self.client.models.generate_content(
@@ -213,6 +214,7 @@ class ProviderGoogleGenAI(Provider):
contents=conversation,
config=types.GenerateContentConfig(
system_instruction=system_instruction,
temperature=temperature,
tools=[tool] if tool else None,
safety_settings=self.safety_settings
if self.safety_settings
@@ -222,26 +224,35 @@ class ProviderGoogleGenAI(Provider):
),
),
)
# logger.debug(f"gemini result: {result}")
if "Developer instruction is not enabled" in str(result):
result_str = str(result)
finish_reason = result.candidates[0].finish_reason
if "Developer instruction is not enabled" in result_str:
logger.warning(f"{self.get_model()} 不支持 system prompt,已自动去除。")
system_instruction = ""
loop = True
# 不支持函数调用的模型SDK似乎会自动去除,保险起见不删除此行判断。
elif "Function calling is not enabled" in str(result):
continue
elif "Function calling is not enabled" in result_str:
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除。")
tool = None
loop = True
elif "Multi-modal output is not supported" in str(result):
continue
elif "Multi-modal output is not supported" in result_str:
logger.warning(f"{self.get_model()} 不支持多模态输出,降级为文本模态。")
modalites = ["Text"]
loop = True
continue
elif finish_reason == types.FinishReason.RECITATION:
logger.warning("发生了recitation,正在尝试加温重试...")
temperature += 0.2
logger.info(f"当前温度: {temperature}")
if temperature < 2:
loop = True
else:
raise Exception("温度已到达(或超过)2")
continue
llm_response = LLMResponse("assistant")
chain = []
finish_reason = result.candidates[0].finish_reason
if finish_reason == types.FinishReason.SAFETY:
raise Exception("模型生成内容未通过用户定义的内容安全检查")
@@ -255,9 +266,22 @@ class ProviderGoogleGenAI(Provider):
raise Exception("模型生成内容违反Gemini平台政策")
if not result.candidates[0].content.parts:
logger.debug(result.candidates)
raise Exception("API 返回的内容为空。")
for part in result.candidates[0].content.parts:
llm_response.result_chain = self._process_content_parts(
result.candidates[0].content.parts, llm_response
)
return llm_response
def _process_content_parts(
self, parts: types.Part, llm_response: LLMResponse
) -> MessageChain:
"""处理内容部分并构建消息链"""
chain = []
part: types.Part
for part in parts:
if part.text:
chain.append(Comp.Plain(part.text))
elif part.function_call:
@@ -267,10 +291,7 @@ class ProviderGoogleGenAI(Provider):
llm_response.tools_call_ids.append(part.function_call.id)
elif part.inline_data and part.inline_data.mime_type.startswith("image/"):
chain.append(Comp.Image.fromBytes(part.inline_data.data))
llm_response.result_chain = MessageChain(chain=chain)
return llm_response
return MessageChain(chain=chain)
async def text_chat(
self,
@@ -305,11 +326,12 @@ class ProviderGoogleGenAI(Provider):
retry = 10
keys = self.api_keys.copy()
chosen_key = random.choice(keys)
temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7
for _ in range(retry):
try:
self.chosen_api_key = chosen_key
llm_response = await self._query(payloads, func_tool)
llm_response = await self._query(payloads, func_tool, temp)
break
except Exception as e:
if "429" in str(e) or "API key not valid" in str(e):