From 388ae49e55352e91def5757d8c92975e2699c4df Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 3 Sep 2025 22:25:18 +0800 Subject: [PATCH 1/2] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20e.message=20?= =?UTF-8?q?=E4=B8=BA=20None=20=E6=97=B6=E6=8A=A5=E9=94=99=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98=E5=92=8C=E4=B8=80=E4=BA=9B=20lint=20error?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/entities.py | 32 ++++---- .../core/provider/sources/gemini_source.py | 75 ++++++++++++++----- 2 files changed, 74 insertions(+), 33 deletions(-) diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 0a31093ae..aac03e7a8 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -4,9 +4,11 @@ import json from astrbot.core.utils.io import download_image_by_url from astrbot import logger from dataclasses import dataclass, field -from typing import List, Dict, Type +from typing import List, Dict, Type, Any from astrbot.core.agent.tool import ToolSet from openai.types.chat.chat_completion import ChatCompletion +from google.genai.types import GenerateContentResponse +from anthropic.types import Message from openai.types.chat.chat_completion_message_tool_call import ( ChatCompletionMessageToolCall, ) @@ -30,11 +32,11 @@ class ProviderMetaData: desc: str = "" """提供商适配器描述.""" provider_type: ProviderType = ProviderType.CHAT_COMPLETION - cls_type: Type = None + cls_type: Type | None = None - default_config_tmpl: dict = None + default_config_tmpl: dict | None = None """平台的默认配置模板""" - provider_display_name: str = None + provider_display_name: str | None = None """显示在 WebUI 配置页中的提供商名称,如空则是 type""" @@ -58,7 +60,7 @@ class ToolCallMessageSegment: class AssistantMessageSegment: """OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling""" - content: str = None + content: str | None = None tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list) role: str = "assistant" @@ -205,17 +207,17 @@ class ProviderRequest: class LLMResponse: role: str """角色, assistant, tool, err""" - result_chain: MessageChain = None + result_chain: MessageChain | None = None """返回的消息链""" - tools_call_args: List[Dict[str, any]] = field(default_factory=list) + tools_call_args: List[Dict[str, Any]] = field(default_factory=list) """工具调用参数""" tools_call_name: List[str] = field(default_factory=list) """工具调用名称""" tools_call_ids: List[str] = field(default_factory=list) """工具调用 ID""" - raw_completion: ChatCompletion = None - _new_record: Dict[str, any] = None + raw_completion: ChatCompletion | GenerateContentResponse | Message | None = None + _new_record: Dict[str, Any] | None = None _completion_text: str = "" @@ -226,12 +228,12 @@ class LLMResponse: self, role: str, completion_text: str = "", - result_chain: MessageChain = None, - tools_call_args: List[Dict[str, any]] = None, - tools_call_name: List[str] = None, - tools_call_ids: List[str] = None, - raw_completion: ChatCompletion = None, - _new_record: Dict[str, any] = None, + result_chain: MessageChain | None = None, + tools_call_args: List[Dict[str, Any]] | None = None, + tools_call_name: List[str] | None = None, + tools_call_ids: List[str] | None = None, + raw_completion: ChatCompletion | None = None, + _new_record: Dict[str, Any] | None = None, is_chunk: bool = False, ): """初始化 LLMResponse diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index e1d12623c..5d3d579c3 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -15,7 +15,7 @@ from astrbot import logger from astrbot.api.provider import Provider from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import LLMResponse -from astrbot.core.provider.func_tool_manager import FuncCall +from astrbot.core.provider.func_tool_manager import ToolSet from astrbot.core.utils.io import download_image_by_url from ..register import register_provider_adapter @@ -61,7 +61,7 @@ class ProviderGoogleGenAI(Provider): default_persona, ) self.api_keys: list = provider_config.get("key", []) - self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None + self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else "" self.timeout: int = int(provider_config.get("timeout", 180)) self.api_base: Optional[str] = provider_config.get("api_base", None) @@ -96,6 +96,9 @@ class ProviderGoogleGenAI(Provider): async def _handle_api_error(self, e: APIError, keys: list[str]) -> bool: """处理API错误,返回是否需要重试""" + if e.message is None: + e.message = "" + if e.code == 429 or "API key not valid" in e.message: keys.remove(self.chosen_api_key) if len(keys) > 0: @@ -119,7 +122,7 @@ class ProviderGoogleGenAI(Provider): async def _prepare_query_config( self, payloads: dict, - tools: Optional[FuncCall] = None, + tools: Optional[ToolSet] = None, system_instruction: Optional[str] = None, modalities: Optional[list[str]] = None, temperature: float = 0.7, @@ -321,11 +324,15 @@ class ProviderGoogleGenAI(Provider): @staticmethod def _process_content_parts( - result: types.GenerateContentResponse, llm_response: LLMResponse + candidate: types.Candidate, llm_response: LLMResponse ) -> MessageChain: """处理内容部分并构建消息链""" - finish_reason = result.candidates[0].finish_reason - result_parts: Optional[types.Part] = result.candidates[0].content.parts + if not candidate.content: + logger.warning(f"收到的 candidate.content 为空: {candidate}") + raise Exception("API 返回的 candidate.content 为空。") + + finish_reason = candidate.finish_reason + result_parts: list[types.Part] | None = candidate.content.parts if finish_reason == types.FinishReason.SAFETY: raise Exception("模型生成内容未通过 Gemini 平台的安全检查") @@ -343,22 +350,28 @@ class ProviderGoogleGenAI(Provider): raise Exception("模型生成内容违反 Gemini 平台政策") if not result_parts: - logger.debug(result.candidates) - raise Exception("API 返回的内容为空。") + logger.warning(f"收到的 candidate.content.parts 为空: {candidate}") + raise Exception("API 返回的 candidate.content.parts 为空。") chain = [] part: types.Part # 暂时这样Fallback if all( - part.inline_data and part.inline_data.mime_type.startswith("image/") + part.inline_data + and part.inline_data.mime_type + and part.inline_data.mime_type.startswith("image/") for part in result_parts ): chain.append(Comp.Plain("这是图片")) for part in result_parts: if part.text: chain.append(Comp.Plain(part.text)) - elif part.function_call: + elif ( + part.function_call + and part.function_call.name + and part.function_call.args + ): llm_response.role = "tool" llm_response.tools_call_name.append(part.function_call.name) llm_response.tools_call_args.append(part.function_call.args) @@ -366,11 +379,16 @@ class ProviderGoogleGenAI(Provider): llm_response.tools_call_ids.append( part.function_call.id or part.function_call.name ) - elif part.inline_data and part.inline_data.mime_type.startswith("image/"): + elif ( + part.inline_data + and part.inline_data.mime_type + and part.inline_data.mime_type.startswith("image/") + and part.inline_data.data + ): chain.append(Comp.Image.fromBytes(part.inline_data.data)) return MessageChain(chain=chain) - async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: + async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: """非流式请求 Gemini API""" system_instruction = next( (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"), @@ -396,6 +414,10 @@ class ProviderGoogleGenAI(Provider): config=config, ) + if not result.candidates: + logger.error(f"请求失败, 返回的 candidates 为空: {result}") + raise Exception("请求失败, 返回的 candidates 为空。") + if result.candidates[0].finish_reason == types.FinishReason.RECITATION: if temperature > 2: raise Exception("温度参数已超过最大值2,仍然发生recitation") @@ -408,6 +430,8 @@ class ProviderGoogleGenAI(Provider): break except APIError as e: + if e.message is None: + e.message = "" if "Developer instruction is not enabled" in e.message: logger.warning( f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)" @@ -432,11 +456,13 @@ class ProviderGoogleGenAI(Provider): llm_response = LLMResponse("assistant") llm_response.raw_completion = result - llm_response.result_chain = self._process_content_parts(result, llm_response) + llm_response.result_chain = self._process_content_parts( + result.candidates[0], llm_response + ) return llm_response async def _query_stream( - self, payloads: dict, tools: FuncCall + self, payloads: dict, tools: ToolSet | None ) -> AsyncGenerator[LLMResponse, None]: """流式请求 Gemini API""" system_instruction = next( @@ -459,6 +485,8 @@ class ProviderGoogleGenAI(Provider): ) break except APIError as e: + if e.message is None: + e.message = "" if "Developer instruction is not enabled" in e.message: logger.warning( f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)" @@ -478,13 +506,20 @@ class ProviderGoogleGenAI(Provider): async for chunk in result: llm_response = LLMResponse("assistant", is_chunk=True) + if not chunk.candidates: + logger.warning(f"收到的 chunk 中 candidates 为空: {chunk}") + continue + if not chunk.candidates[0].content: + logger.warning(f"收到的 chunk 中 content 为空: {chunk}") + continue + if chunk.candidates[0].content.parts and any( part.function_call for part in chunk.candidates[0].content.parts ): llm_response = LLMResponse("assistant", is_chunk=False) llm_response.raw_completion = chunk llm_response.result_chain = self._process_content_parts( - chunk, llm_response + chunk.candidates[0], llm_response ) yield llm_response return @@ -500,7 +535,7 @@ class ProviderGoogleGenAI(Provider): final_response = LLMResponse("assistant", is_chunk=False) final_response.raw_completion = chunk final_response.result_chain = self._process_content_parts( - chunk, final_response + chunk.candidates[0], final_response ) break @@ -566,6 +601,8 @@ class ProviderGoogleGenAI(Provider): continue break + raise Exception("请求失败。") + async def text_chat_stream( self, prompt, @@ -621,7 +658,9 @@ class ProviderGoogleGenAI(Provider): return [ m.name.replace("models/", "") for m in models - if "generateContent" in m.supported_actions + if m.supported_actions + and "generateContent" in m.supported_actions + and m.name ] except APIError as e: raise Exception(f"获取模型列表失败: {e.message}") @@ -636,7 +675,7 @@ class ProviderGoogleGenAI(Provider): self.chosen_api_key = key self._init_client() - async def assemble_context(self, text: str, image_urls: list[str] = None): + async def assemble_context(self, text: str, image_urls: list[str] | None = None): """ 组装上下文。 """ From fa53b468fdaab500b5bd567d0a6b3b6ed7a5277b Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 4 Sep 2025 11:18:58 +0800 Subject: [PATCH 2/2] fix: ensure function call name and args are not None before processing --- astrbot/core/provider/sources/gemini_source.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 5d3d579c3..cc4475b6b 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -369,8 +369,8 @@ class ProviderGoogleGenAI(Provider): chain.append(Comp.Plain(part.text)) elif ( part.function_call - and part.function_call.name - and part.function_call.args + and part.function_call.name is not None + and part.function_call.args is not None ): llm_response.role = "tool" llm_response.tools_call_name.append(part.function_call.name)