perf: 支持更多参数

This commit is contained in:
Raven95676
2025-04-20 00:12:14 +08:00
parent 56001ed272
commit 01d52cef74
+18 -10
View File
@@ -123,10 +123,11 @@ class ProviderGoogleGenAI(Provider):
async def _prepare_query_config(
self,
payloads: dict,
tools: Optional[FuncCall] = None,
system_instruction: Optional[str] = None,
temperature: Optional[float] = 0.7,
modalities: Optional[List[str]] = None,
temperature: float = 0.7,
) -> types.GenerateContentConfig:
"""准备查询配置"""
if not modalities:
@@ -158,10 +159,18 @@ class ProviderGoogleGenAI(Provider):
tool_list = [
types.Tool(function_declarations=func_desc["function_declarations"])
]
return types.GenerateContentConfig(
system_instruction=system_instruction,
temperature=temperature,
max_output_tokens=payloads.get("max_tokens") or payloads.get("maxOutputTokens"),
top_p=payloads.get("top_p") or payloads.get("topP"),
top_k=payloads.get("top_k") or payloads.get("topK"),
frequency_penalty=payloads.get("frequency_penalty") or payloads.get("frequencyPenalty"),
presence_penalty=payloads.get("presence_penalty") or payloads.get("presencePenalty"),
stop_sequences=payloads.get("stop") or payloads.get("stopSequences"),
response_logprobs=payloads.get("response_logprobs") or payloads.get("responseLogprobs"),
logprobs=payloads.get("logprobs"),
seed=payloads.get("seed"),
response_modalities=modalities,
tools=tool_list,
safety_settings=self.safety_settings if self.safety_settings else None,
@@ -305,7 +314,7 @@ class ProviderGoogleGenAI(Provider):
return MessageChain(chain=chain)
async def _query(
self, payloads: dict, tools: FuncCall, temperature: float = 0.7
self, payloads: dict, tools: FuncCall
) -> LLMResponse:
"""非流式请求 Gemini API"""
system_instruction = next(
@@ -318,12 +327,13 @@ class ProviderGoogleGenAI(Provider):
modalities.append("Image")
conversation = self._prepare_conversation(payloads)
temperature=payloads.get("temperature", 0.7)
result: Optional[types.GenerateContentResponse] = None
while True:
try:
config = await self._prepare_query_config(
tools, system_instruction, temperature, modalities
payloads, tools, system_instruction, modalities, temperature
)
result = await self.client.models.generate_content(
model=self.get_model(),
@@ -370,7 +380,7 @@ class ProviderGoogleGenAI(Provider):
return llm_response
async def _query_stream(
self, payloads: dict, tools: FuncCall, temperature: float = 0.7
self, payloads: dict, tools: FuncCall
) -> AsyncGenerator[LLMResponse, None]:
"""流式请求 Gemini API"""
system_instruction = next(
@@ -384,7 +394,7 @@ class ProviderGoogleGenAI(Provider):
while True:
try:
config = await self._prepare_query_config(
tools, system_instruction, temperature
payloads, tools, system_instruction
)
result = await self.client.models.generate_content_stream(
model=self.get_model(),
@@ -464,11 +474,10 @@ class ProviderGoogleGenAI(Provider):
retry = 10
keys = self.api_keys.copy()
temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7
for _ in range(retry):
try:
return await self._query(payloads, func_tool, temp)
return await self._query(payloads, func_tool)
except APIError as e:
if await self._handle_api_error(e, keys):
continue
@@ -505,11 +514,10 @@ class ProviderGoogleGenAI(Provider):
retry = 10
keys = self.api_keys.copy()
temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7
for _ in range(retry):
try:
async for response in self._query_stream(payloads, func_tool, temp):
async for response in self._query_stream(payloads, func_tool):
yield response
break
except APIError as e: