diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index aacebb3fb..ef2305069 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -123,10 +123,11 @@ class ProviderGoogleGenAI(Provider): async def _prepare_query_config( self, + payloads: dict, tools: Optional[FuncCall] = None, system_instruction: Optional[str] = None, - temperature: Optional[float] = 0.7, modalities: Optional[List[str]] = None, + temperature: float = 0.7, ) -> types.GenerateContentConfig: """准备查询配置""" if not modalities: @@ -158,10 +159,18 @@ class ProviderGoogleGenAI(Provider): tool_list = [ types.Tool(function_declarations=func_desc["function_declarations"]) ] - return types.GenerateContentConfig( system_instruction=system_instruction, temperature=temperature, + max_output_tokens=payloads.get("max_tokens") or payloads.get("maxOutputTokens"), + top_p=payloads.get("top_p") or payloads.get("topP"), + top_k=payloads.get("top_k") or payloads.get("topK"), + frequency_penalty=payloads.get("frequency_penalty") or payloads.get("frequencyPenalty"), + presence_penalty=payloads.get("presence_penalty") or payloads.get("presencePenalty"), + stop_sequences=payloads.get("stop") or payloads.get("stopSequences"), + response_logprobs=payloads.get("response_logprobs") or payloads.get("responseLogprobs"), + logprobs=payloads.get("logprobs"), + seed=payloads.get("seed"), response_modalities=modalities, tools=tool_list, safety_settings=self.safety_settings if self.safety_settings else None, @@ -305,7 +314,7 @@ class ProviderGoogleGenAI(Provider): return MessageChain(chain=chain) async def _query( - self, payloads: dict, tools: FuncCall, temperature: float = 0.7 + self, payloads: dict, tools: FuncCall ) -> LLMResponse: """非流式请求 Gemini API""" system_instruction = next( @@ -318,12 +327,13 @@ class ProviderGoogleGenAI(Provider): modalities.append("Image") conversation = self._prepare_conversation(payloads) + temperature=payloads.get("temperature", 0.7) result: Optional[types.GenerateContentResponse] = None while True: try: config = await self._prepare_query_config( - tools, system_instruction, temperature, modalities + payloads, tools, system_instruction, modalities, temperature ) result = await self.client.models.generate_content( model=self.get_model(), @@ -370,7 +380,7 @@ class ProviderGoogleGenAI(Provider): return llm_response async def _query_stream( - self, payloads: dict, tools: FuncCall, temperature: float = 0.7 + self, payloads: dict, tools: FuncCall ) -> AsyncGenerator[LLMResponse, None]: """流式请求 Gemini API""" system_instruction = next( @@ -384,7 +394,7 @@ class ProviderGoogleGenAI(Provider): while True: try: config = await self._prepare_query_config( - tools, system_instruction, temperature + payloads, tools, system_instruction ) result = await self.client.models.generate_content_stream( model=self.get_model(), @@ -464,11 +474,10 @@ class ProviderGoogleGenAI(Provider): retry = 10 keys = self.api_keys.copy() - temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7 for _ in range(retry): try: - return await self._query(payloads, func_tool, temp) + return await self._query(payloads, func_tool) except APIError as e: if await self._handle_api_error(e, keys): continue @@ -505,11 +514,10 @@ class ProviderGoogleGenAI(Provider): retry = 10 keys = self.api_keys.copy() - temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7 for _ in range(retry): try: - async for response in self._query_stream(payloads, func_tool, temp): + async for response in self._query_stream(payloads, func_tool): yield response break except APIError as e: