From afb56cf707c8b7e29855969fb03de310aa1de43d Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Wed, 19 Nov 2025 18:54:56 +0800 Subject: [PATCH] feat: add supports for gemini-3 series thought signature (#3698) * feat: add supports for gemini-3 series thought signature * feat: refactor tools_call_extra_content to use a dictionary for better structure --- astrbot/core/agent/message.py | 7 ++++ astrbot/core/db/po.py | 20 ++++------- astrbot/core/provider/entities.py | 31 ++++++++++++----- .../core/provider/sources/gemini_source.py | 33 ++++++++++++++----- .../core/provider/sources/openai_source.py | 25 +++++--------- 5 files changed, 70 insertions(+), 46 deletions(-) diff --git a/astrbot/core/agent/message.py b/astrbot/core/agent/message.py index 4a2e1b149..4c65c32f6 100644 --- a/astrbot/core/agent/message.py +++ b/astrbot/core/agent/message.py @@ -119,6 +119,13 @@ class ToolCall(BaseModel): """The ID of the tool call.""" function: FunctionBody """The function body of the tool call.""" + extra_content: dict[str, Any] | None = None + """Extra metadata for the tool call.""" + + def model_dump(self, **kwargs: Any) -> dict[str, Any]: + if self.extra_content is None: + kwargs.setdefault("exclude", set()).add("extra_content") + return super().model_dump(**kwargs) class ToolCallPart(BaseModel): diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 1e7245976..5cf25ec13 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -3,13 +3,7 @@ from dataclasses import dataclass, field from datetime import datetime, timezone from typing import TypedDict -from sqlmodel import ( - JSON, - Field, - SQLModel, - Text, - UniqueConstraint, -) +from sqlmodel import JSON, Field, SQLModel, Text, UniqueConstraint class PlatformStat(SQLModel, table=True): @@ -18,7 +12,7 @@ class PlatformStat(SQLModel, table=True): Note: In astrbot v4, we moved `platform` table to here. """ - __tablename__ = "platform_stats" + __tablename__ = "platform_stats" # type: ignore id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) timestamp: datetime = Field(nullable=False) @@ -37,7 +31,7 @@ class PlatformStat(SQLModel, table=True): class ConversationV2(SQLModel, table=True): - __tablename__ = "conversations" + __tablename__ = "conversations" # type: ignore inner_conversation_id: int = Field( primary_key=True, @@ -74,7 +68,7 @@ class Persona(SQLModel, table=True): It can be used to customize the behavior of LLMs. """ - __tablename__ = "personas" + __tablename__ = "personas" # type: ignore id: int | None = Field( primary_key=True, @@ -104,7 +98,7 @@ class Persona(SQLModel, table=True): class Preference(SQLModel, table=True): """This class represents preferences for bots.""" - __tablename__ = "preferences" + __tablename__ = "preferences" # type: ignore id: int | None = Field( default=None, @@ -140,7 +134,7 @@ class PlatformMessageHistory(SQLModel, table=True): or platform-specific messages. """ - __tablename__ = "platform_message_history" + __tablename__ = "platform_message_history" # type: ignore id: int | None = Field( primary_key=True, @@ -167,7 +161,7 @@ class Attachment(SQLModel, table=True): Attachments can be images, files, or other media types. """ - __tablename__ = "attachments" + __tablename__ = "attachments" # type: ignore inner_attachment_id: int | None = Field( primary_key=True, diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index c6978e7b9..dc188f141 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -211,6 +211,8 @@ class LLMResponse: """Tool call names.""" tools_call_ids: list[str] = field(default_factory=list) """Tool call IDs.""" + tools_call_extra_content: dict[str, dict[str, Any]] = field(default_factory=dict) + """Tool call extra content. tool_call_id -> extra_content dict""" reasoning_content: str = "" """The reasoning content extracted from the LLM, if any.""" @@ -233,6 +235,7 @@ class LLMResponse: tools_call_args: list[dict[str, Any]] | None = None, tools_call_name: list[str] | None = None, tools_call_ids: list[str] | None = None, + tools_call_extra_content: dict[str, dict[str, Any]] | None = None, raw_completion: ChatCompletion | GenerateContentResponse | AnthropicMessage @@ -256,6 +259,8 @@ class LLMResponse: tools_call_name = [] if tools_call_ids is None: tools_call_ids = [] + if tools_call_extra_content is None: + tools_call_extra_content = {} self.role = role self.completion_text = completion_text @@ -263,6 +268,7 @@ class LLMResponse: self.tools_call_args = tools_call_args self.tools_call_name = tools_call_name self.tools_call_ids = tools_call_ids + self.tools_call_extra_content = tools_call_extra_content self.raw_completion = raw_completion self.is_chunk = is_chunk @@ -288,16 +294,19 @@ class LLMResponse: """Convert to OpenAI tool calls format. Deprecated, use to_openai_to_calls_model instead.""" ret = [] for idx, tool_call_arg in enumerate(self.tools_call_args): - ret.append( - { - "id": self.tools_call_ids[idx], - "function": { - "name": self.tools_call_name[idx], - "arguments": json.dumps(tool_call_arg), - }, - "type": "function", + payload = { + "id": self.tools_call_ids[idx], + "function": { + "name": self.tools_call_name[idx], + "arguments": json.dumps(tool_call_arg), }, - ) + "type": "function", + } + if self.tools_call_extra_content.get(self.tools_call_ids[idx]): + payload["extra_content"] = self.tools_call_extra_content[ + self.tools_call_ids[idx] + ] + ret.append(payload) return ret def to_openai_to_calls_model(self) -> list[ToolCall]: @@ -311,6 +320,10 @@ class LLMResponse: name=self.tools_call_name[idx], arguments=json.dumps(tool_call_arg), ), + # the extra_content will not serialize if it's None when calling ToolCall.model_dump() + extra_content=self.tools_call_extra_content.get( + self.tools_call_ids[idx] + ), ), ) return ret diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index b9159eec9..e14140d43 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -290,13 +290,24 @@ class ProviderGoogleGenAI(Provider): parts = [types.Part.from_text(text=content)] append_or_extend(gemini_contents, parts, types.ModelContent) elif not native_tool_enabled and "tool_calls" in message: - parts = [ - types.Part.from_function_call( + parts = [] + for tool in message["tool_calls"]: + part = types.Part.from_function_call( name=tool["function"]["name"], args=json.loads(tool["function"]["arguments"]), ) - for tool in message["tool_calls"] - ] + # we should set thought_signature back to part if exists + # for more info about thought_signature, see: + # https://ai.google.dev/gemini-api/docs/thought-signatures + if "extra_content" in tool: + ts_bs64 = ( + tool["extra_content"] + .get("google", {}) + .get("thought_signature") + ) + if ts_bs64: + part.thought_signature = base64.b64decode(ts_bs64) + parts.append(part) append_or_extend(gemini_contents, parts, types.ModelContent) else: logger.warning("assistant 角色的消息内容为空,已添加空格占位") @@ -393,10 +404,15 @@ class ProviderGoogleGenAI(Provider): llm_response.role = "tool" llm_response.tools_call_name.append(part.function_call.name) llm_response.tools_call_args.append(part.function_call.args) - # gemini 返回的 function_call.id 可能为 None - llm_response.tools_call_ids.append( - part.function_call.id or part.function_call.name, - ) + # function_call.id might be None, use name as fallback + tool_call_id = part.function_call.id or part.function_call.name + llm_response.tools_call_ids.append(tool_call_id) + # extra_content + if part.thought_signature: + ts_bs64 = base64.b64encode(part.thought_signature).decode("utf-8") + llm_response.tools_call_extra_content[tool_call_id] = { + "google": {"thought_signature": ts_bs64} + } elif ( part.inline_data and part.inline_data.mime_type @@ -435,6 +451,7 @@ class ProviderGoogleGenAI(Provider): contents=conversation, config=config, ) + logger.debug(f"genai result: {result}") if not result.candidates: logger.error(f"请求失败, 返回的 candidates 为空: {result}") diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index da2ce68f8..3f1d283ce 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -8,7 +8,7 @@ import re from collections.abc import AsyncGenerator from openai import AsyncAzureOpenAI, AsyncOpenAI -from openai._exceptions import NotFoundError, UnprocessableEntityError +from openai._exceptions import NotFoundError from openai.lib.streaming.chat._completions import ChatCompletionStreamState from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion_chunk import ChatCompletionChunk @@ -279,6 +279,7 @@ class ProviderOpenAIOfficial(Provider): args_ls = [] func_name_ls = [] tool_call_ids = [] + tool_call_extra_content_dict = {} for tool_call in choice.message.tool_calls: if isinstance(tool_call, str): # workaround for #1359 @@ -296,11 +297,16 @@ class ProviderOpenAIOfficial(Provider): args_ls.append(args) func_name_ls.append(tool_call.function.name) tool_call_ids.append(tool_call.id) + + # gemini-2.5 / gemini-3 series extra_content handling + extra_content = getattr(tool_call, "extra_content", None) + if extra_content is not None: + tool_call_extra_content_dict[tool_call.id] = extra_content llm_response.role = "tool" llm_response.tools_call_args = args_ls llm_response.tools_call_name = func_name_ls llm_response.tools_call_ids = tool_call_ids - + llm_response.tools_call_extra_content = tool_call_extra_content_dict # specially handle finish reason if choice.finish_reason == "content_filter": raise Exception( @@ -353,7 +359,7 @@ class ProviderOpenAIOfficial(Provider): payloads = {"messages": context_query, **model_config} - # xAI 原生搜索参数(最小侵入地在此处注入) + # xAI origin search tool inject self._maybe_inject_xai_search(payloads, **kwargs) return payloads, context_query @@ -475,12 +481,6 @@ class ProviderOpenAIOfficial(Provider): self.client.api_key = chosen_key llm_response = await self._query(payloads, func_tool) break - except UnprocessableEntityError as e: - logger.warning(f"不可处理的实体错误:{e},尝试删除图片。") - # 尝试删除所有 image - new_contexts = await self._remove_image_from_context(context_query) - payloads["messages"] = new_contexts - context_query = new_contexts except Exception as e: last_exception = e ( @@ -545,12 +545,6 @@ class ProviderOpenAIOfficial(Provider): async for response in self._query_stream(payloads, func_tool): yield response break - except UnprocessableEntityError as e: - logger.warning(f"不可处理的实体错误:{e},尝试删除图片。") - # 尝试删除所有 image - new_contexts = await self._remove_image_from_context(context_query) - payloads["messages"] = new_contexts - context_query = new_contexts except Exception as e: last_exception = e ( @@ -646,4 +640,3 @@ class ProviderOpenAIOfficial(Provider): with open(image_url, "rb") as f: image_bs64 = base64.b64encode(f.read()).decode("utf-8") return "data:image/jpeg;base64," + image_bs64 - return ""