perf: 支持更多参数
This commit is contained in:
@@ -123,10 +123,11 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
async def _prepare_query_config(
|
||||
self,
|
||||
payloads: dict,
|
||||
tools: Optional[FuncCall] = None,
|
||||
system_instruction: Optional[str] = None,
|
||||
temperature: Optional[float] = 0.7,
|
||||
modalities: Optional[List[str]] = None,
|
||||
temperature: float = 0.7,
|
||||
) -> types.GenerateContentConfig:
|
||||
"""准备查询配置"""
|
||||
if not modalities:
|
||||
@@ -158,10 +159,18 @@ class ProviderGoogleGenAI(Provider):
|
||||
tool_list = [
|
||||
types.Tool(function_declarations=func_desc["function_declarations"])
|
||||
]
|
||||
|
||||
return types.GenerateContentConfig(
|
||||
system_instruction=system_instruction,
|
||||
temperature=temperature,
|
||||
max_output_tokens=payloads.get("max_tokens") or payloads.get("maxOutputTokens"),
|
||||
top_p=payloads.get("top_p") or payloads.get("topP"),
|
||||
top_k=payloads.get("top_k") or payloads.get("topK"),
|
||||
frequency_penalty=payloads.get("frequency_penalty") or payloads.get("frequencyPenalty"),
|
||||
presence_penalty=payloads.get("presence_penalty") or payloads.get("presencePenalty"),
|
||||
stop_sequences=payloads.get("stop") or payloads.get("stopSequences"),
|
||||
response_logprobs=payloads.get("response_logprobs") or payloads.get("responseLogprobs"),
|
||||
logprobs=payloads.get("logprobs"),
|
||||
seed=payloads.get("seed"),
|
||||
response_modalities=modalities,
|
||||
tools=tool_list,
|
||||
safety_settings=self.safety_settings if self.safety_settings else None,
|
||||
@@ -305,7 +314,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
return MessageChain(chain=chain)
|
||||
|
||||
async def _query(
|
||||
self, payloads: dict, tools: FuncCall, temperature: float = 0.7
|
||||
self, payloads: dict, tools: FuncCall
|
||||
) -> LLMResponse:
|
||||
"""非流式请求 Gemini API"""
|
||||
system_instruction = next(
|
||||
@@ -318,12 +327,13 @@ class ProviderGoogleGenAI(Provider):
|
||||
modalities.append("Image")
|
||||
|
||||
conversation = self._prepare_conversation(payloads)
|
||||
temperature=payloads.get("temperature", 0.7)
|
||||
|
||||
result: Optional[types.GenerateContentResponse] = None
|
||||
while True:
|
||||
try:
|
||||
config = await self._prepare_query_config(
|
||||
tools, system_instruction, temperature, modalities
|
||||
payloads, tools, system_instruction, modalities, temperature
|
||||
)
|
||||
result = await self.client.models.generate_content(
|
||||
model=self.get_model(),
|
||||
@@ -370,7 +380,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
return llm_response
|
||||
|
||||
async def _query_stream(
|
||||
self, payloads: dict, tools: FuncCall, temperature: float = 0.7
|
||||
self, payloads: dict, tools: FuncCall
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式请求 Gemini API"""
|
||||
system_instruction = next(
|
||||
@@ -384,7 +394,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
while True:
|
||||
try:
|
||||
config = await self._prepare_query_config(
|
||||
tools, system_instruction, temperature
|
||||
payloads, tools, system_instruction
|
||||
)
|
||||
result = await self.client.models.generate_content_stream(
|
||||
model=self.get_model(),
|
||||
@@ -464,11 +474,10 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7
|
||||
|
||||
for _ in range(retry):
|
||||
try:
|
||||
return await self._query(payloads, func_tool, temp)
|
||||
return await self._query(payloads, func_tool)
|
||||
except APIError as e:
|
||||
if await self._handle_api_error(e, keys):
|
||||
continue
|
||||
@@ -505,11 +514,10 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7
|
||||
|
||||
for _ in range(retry):
|
||||
try:
|
||||
async for response in self._query_stream(payloads, func_tool, temp):
|
||||
async for response in self._query_stream(payloads, func_tool):
|
||||
yield response
|
||||
break
|
||||
except APIError as e:
|
||||
|
||||
Reference in New Issue
Block a user