perf: 增加对RECITATION完成原因的处理,提取内容处理逻辑到独立方法
This commit is contained in:
@@ -182,7 +182,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
return gemini_contents
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
async def _query(
|
||||
self, payloads: dict, tools: FuncCall, temperature: float = 0.7
|
||||
) -> LLMResponse:
|
||||
"""非流式请求 Gemini API"""
|
||||
if tools:
|
||||
t = tools.get_func_desc_google_genai_style()
|
||||
@@ -205,7 +207,6 @@ class ProviderGoogleGenAI(Provider):
|
||||
modalites.append("Image")
|
||||
|
||||
loop = True
|
||||
|
||||
while loop:
|
||||
loop = False
|
||||
result = await self.client.models.generate_content(
|
||||
@@ -213,6 +214,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
contents=conversation,
|
||||
config=types.GenerateContentConfig(
|
||||
system_instruction=system_instruction,
|
||||
temperature=temperature,
|
||||
tools=[tool] if tool else None,
|
||||
safety_settings=self.safety_settings
|
||||
if self.safety_settings
|
||||
@@ -222,26 +224,35 @@ class ProviderGoogleGenAI(Provider):
|
||||
),
|
||||
),
|
||||
)
|
||||
# logger.debug(f"gemini result: {result}")
|
||||
|
||||
if "Developer instruction is not enabled" in str(result):
|
||||
result_str = str(result)
|
||||
finish_reason = result.candidates[0].finish_reason
|
||||
if "Developer instruction is not enabled" in result_str:
|
||||
logger.warning(f"{self.get_model()} 不支持 system prompt,已自动去除。")
|
||||
system_instruction = ""
|
||||
loop = True
|
||||
# 不支持函数调用的模型SDK似乎会自动去除,保险起见不删除此行判断。
|
||||
elif "Function calling is not enabled" in str(result):
|
||||
continue
|
||||
elif "Function calling is not enabled" in result_str:
|
||||
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除。")
|
||||
tool = None
|
||||
loop = True
|
||||
elif "Multi-modal output is not supported" in str(result):
|
||||
continue
|
||||
elif "Multi-modal output is not supported" in result_str:
|
||||
logger.warning(f"{self.get_model()} 不支持多模态输出,降级为文本模态。")
|
||||
modalites = ["Text"]
|
||||
loop = True
|
||||
continue
|
||||
elif finish_reason == types.FinishReason.RECITATION:
|
||||
logger.warning("发生了recitation,正在尝试加温重试...")
|
||||
temperature += 0.2
|
||||
logger.info(f"当前温度: {temperature}")
|
||||
if temperature < 2:
|
||||
loop = True
|
||||
else:
|
||||
raise Exception("温度已到达(或超过)2")
|
||||
continue
|
||||
|
||||
llm_response = LLMResponse("assistant")
|
||||
chain = []
|
||||
|
||||
finish_reason = result.candidates[0].finish_reason
|
||||
|
||||
if finish_reason == types.FinishReason.SAFETY:
|
||||
raise Exception("模型生成内容未通过用户定义的内容安全检查")
|
||||
@@ -255,9 +266,22 @@ class ProviderGoogleGenAI(Provider):
|
||||
raise Exception("模型生成内容违反Gemini平台政策")
|
||||
|
||||
if not result.candidates[0].content.parts:
|
||||
logger.debug(result.candidates)
|
||||
raise Exception("API 返回的内容为空。")
|
||||
|
||||
for part in result.candidates[0].content.parts:
|
||||
llm_response.result_chain = self._process_content_parts(
|
||||
result.candidates[0].content.parts, llm_response
|
||||
)
|
||||
|
||||
return llm_response
|
||||
|
||||
def _process_content_parts(
|
||||
self, parts: types.Part, llm_response: LLMResponse
|
||||
) -> MessageChain:
|
||||
"""处理内容部分并构建消息链"""
|
||||
chain = []
|
||||
part: types.Part
|
||||
for part in parts:
|
||||
if part.text:
|
||||
chain.append(Comp.Plain(part.text))
|
||||
elif part.function_call:
|
||||
@@ -267,10 +291,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
llm_response.tools_call_ids.append(part.function_call.id)
|
||||
elif part.inline_data and part.inline_data.mime_type.startswith("image/"):
|
||||
chain.append(Comp.Image.fromBytes(part.inline_data.data))
|
||||
|
||||
llm_response.result_chain = MessageChain(chain=chain)
|
||||
|
||||
return llm_response
|
||||
return MessageChain(chain=chain)
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
@@ -305,11 +326,12 @@ class ProviderGoogleGenAI(Provider):
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
chosen_key = random.choice(keys)
|
||||
temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7
|
||||
|
||||
for _ in range(retry):
|
||||
try:
|
||||
self.chosen_api_key = chosen_key
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
llm_response = await self._query(payloads, func_tool, temp)
|
||||
break
|
||||
except Exception as e:
|
||||
if "429" in str(e) or "API key not valid" in str(e):
|
||||
|
||||
Reference in New Issue
Block a user