From 85609ea7422fe814996689b4bbfacbaf7327237d Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Fri, 2 May 2025 10:49:45 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81Gemini=E6=80=9D?= =?UTF-8?q?=E8=80=83=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 20 +++++++++++ .../core/provider/sources/gemini_source.py | 34 +++++++++++++------ 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 664a11307..0a27e62ce 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -568,6 +568,10 @@ CONFIG_METADATA_2 = { "sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE", "dangerous_content": "BLOCK_MEDIUM_AND_ABOVE", }, + "gm_thinking_config": { + "enable": False, + "budget": 0, + }, }, "DeepSeek": { "id": "deepseek_default", @@ -801,6 +805,22 @@ CONFIG_METADATA_2 = { }, }, }, + "gm_thinking_config": { + "description": "Gemini思考设置", + "type": "object", + "items": { + "enable": { + "description": "启用思考", + "type": "bool", + "hint": "启用后,模型将在可用时输出思考过程", + }, + "budget": { + "description": "思考预算", + "type": "int", + "hint": "模型应该生成的思考Token的数量", + }, + }, + }, "rag_options": { "description": "RAG 选项", "type": "object", diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index a175a3d68..9a5c0235f 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -162,18 +162,30 @@ class ProviderGoogleGenAI(Provider): return types.GenerateContentConfig( system_instruction=system_instruction, temperature=temperature, - max_output_tokens=payloads.get("max_tokens") or payloads.get("maxOutputTokens"), + max_output_tokens=payloads.get("max_tokens") + or payloads.get("maxOutputTokens"), top_p=payloads.get("top_p") or payloads.get("topP"), top_k=payloads.get("top_k") or payloads.get("topK"), - frequency_penalty=payloads.get("frequency_penalty") or payloads.get("frequencyPenalty"), - presence_penalty=payloads.get("presence_penalty") or payloads.get("presencePenalty"), + frequency_penalty=payloads.get("frequency_penalty") + or payloads.get("frequencyPenalty"), + presence_penalty=payloads.get("presence_penalty") + or payloads.get("presencePenalty"), stop_sequences=payloads.get("stop") or payloads.get("stopSequences"), - response_logprobs=payloads.get("response_logprobs") or payloads.get("responseLogprobs"), + response_logprobs=payloads.get("response_logprobs") + or payloads.get("responseLogprobs"), logprobs=payloads.get("logprobs"), seed=payloads.get("seed"), response_modalities=modalities, tools=tool_list, safety_settings=self.safety_settings if self.safety_settings else None, + thinking_config=types.ThinkingConfig( + include_thoughts=self.provider_config.get("gm_thinking_config", {}).get( + "enable", False + ), + thinking_budget=self.provider_config.get("gm_thinking_config", {}).get( + "budget", 0 + ), + ), automatic_function_calling=types.AutomaticFunctionCallingConfig( disable=True ), @@ -194,7 +206,11 @@ class ProviderGoogleGenAI(Provider): image_bytes = base64.b64decode(url.split(",", 1)[1]) return types.Part.from_bytes(data=image_bytes, mime_type=mime_type) - def append_or_extend(contents: list[types.Content], part: list[types.Part], content_cls: type[types.Content]) -> None: + def append_or_extend( + contents: list[types.Content], + part: list[types.Part], + content_cls: type[types.Content], + ) -> None: if contents and isinstance(contents[-1], content_cls): contents[-1].parts.extend(part) else: @@ -226,7 +242,7 @@ class ProviderGoogleGenAI(Provider): if content: parts = [types.Part.from_text(text=content)] append_or_extend(gemini_contents, parts, types.ModelContent) - elif not native_tool_enabled and "tool_calls" in message : + elif not native_tool_enabled and "tool_calls" in message: parts = [ types.Part.from_function_call( name=tool["function"]["name"], @@ -312,9 +328,7 @@ class ProviderGoogleGenAI(Provider): chain.append(Comp.Image.fromBytes(part.inline_data.data)) return MessageChain(chain=chain) - async def _query( - self, payloads: dict, tools: FuncCall - ) -> LLMResponse: + async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: """非流式请求 Gemini API""" system_instruction = next( (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"), @@ -326,7 +340,7 @@ class ProviderGoogleGenAI(Provider): modalities.append("Image") conversation = self._prepare_conversation(payloads) - temperature=payloads.get("temperature", 0.7) + temperature = payloads.get("temperature", 0.7) result: Optional[types.GenerateContentResponse] = None while True: