fix: 修复 e.message 为 None 时报错的问题和一些 lint error
This commit is contained in:
@@ -4,9 +4,11 @@ import json
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot import logger
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Type
|
||||
from typing import List, Dict, Type, Any
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from google.genai.types import GenerateContentResponse
|
||||
from anthropic.types import Message
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
)
|
||||
@@ -30,11 +32,11 @@ class ProviderMetaData:
|
||||
desc: str = ""
|
||||
"""提供商适配器描述."""
|
||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
||||
cls_type: Type = None
|
||||
cls_type: Type | None = None
|
||||
|
||||
default_config_tmpl: dict = None
|
||||
default_config_tmpl: dict | None = None
|
||||
"""平台的默认配置模板"""
|
||||
provider_display_name: str = None
|
||||
provider_display_name: str | None = None
|
||||
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
|
||||
|
||||
|
||||
@@ -58,7 +60,7 @@ class ToolCallMessageSegment:
|
||||
class AssistantMessageSegment:
|
||||
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
||||
|
||||
content: str = None
|
||||
content: str | None = None
|
||||
tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list)
|
||||
role: str = "assistant"
|
||||
|
||||
@@ -205,17 +207,17 @@ class ProviderRequest:
|
||||
class LLMResponse:
|
||||
role: str
|
||||
"""角色, assistant, tool, err"""
|
||||
result_chain: MessageChain = None
|
||||
result_chain: MessageChain | None = None
|
||||
"""返回的消息链"""
|
||||
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
|
||||
tools_call_args: List[Dict[str, Any]] = field(default_factory=list)
|
||||
"""工具调用参数"""
|
||||
tools_call_name: List[str] = field(default_factory=list)
|
||||
"""工具调用名称"""
|
||||
tools_call_ids: List[str] = field(default_factory=list)
|
||||
"""工具调用 ID"""
|
||||
|
||||
raw_completion: ChatCompletion = None
|
||||
_new_record: Dict[str, any] = None
|
||||
raw_completion: ChatCompletion | GenerateContentResponse | Message | None = None
|
||||
_new_record: Dict[str, Any] | None = None
|
||||
|
||||
_completion_text: str = ""
|
||||
|
||||
@@ -226,12 +228,12 @@ class LLMResponse:
|
||||
self,
|
||||
role: str,
|
||||
completion_text: str = "",
|
||||
result_chain: MessageChain = None,
|
||||
tools_call_args: List[Dict[str, any]] = None,
|
||||
tools_call_name: List[str] = None,
|
||||
tools_call_ids: List[str] = None,
|
||||
raw_completion: ChatCompletion = None,
|
||||
_new_record: Dict[str, any] = None,
|
||||
result_chain: MessageChain | None = None,
|
||||
tools_call_args: List[Dict[str, Any]] | None = None,
|
||||
tools_call_name: List[str] | None = None,
|
||||
tools_call_ids: List[str] | None = None,
|
||||
raw_completion: ChatCompletion | None = None,
|
||||
_new_record: Dict[str, Any] | None = None,
|
||||
is_chunk: bool = False,
|
||||
):
|
||||
"""初始化 LLMResponse
|
||||
|
||||
@@ -15,7 +15,7 @@ from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
@@ -61,7 +61,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
default_persona,
|
||||
)
|
||||
self.api_keys: list = provider_config.get("key", [])
|
||||
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
|
||||
self.timeout: int = int(provider_config.get("timeout", 180))
|
||||
|
||||
self.api_base: Optional[str] = provider_config.get("api_base", None)
|
||||
@@ -96,6 +96,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
async def _handle_api_error(self, e: APIError, keys: list[str]) -> bool:
|
||||
"""处理API错误,返回是否需要重试"""
|
||||
if e.message is None:
|
||||
e.message = ""
|
||||
|
||||
if e.code == 429 or "API key not valid" in e.message:
|
||||
keys.remove(self.chosen_api_key)
|
||||
if len(keys) > 0:
|
||||
@@ -119,7 +122,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
async def _prepare_query_config(
|
||||
self,
|
||||
payloads: dict,
|
||||
tools: Optional[FuncCall] = None,
|
||||
tools: Optional[ToolSet] = None,
|
||||
system_instruction: Optional[str] = None,
|
||||
modalities: Optional[list[str]] = None,
|
||||
temperature: float = 0.7,
|
||||
@@ -321,11 +324,15 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
@staticmethod
|
||||
def _process_content_parts(
|
||||
result: types.GenerateContentResponse, llm_response: LLMResponse
|
||||
candidate: types.Candidate, llm_response: LLMResponse
|
||||
) -> MessageChain:
|
||||
"""处理内容部分并构建消息链"""
|
||||
finish_reason = result.candidates[0].finish_reason
|
||||
result_parts: Optional[types.Part] = result.candidates[0].content.parts
|
||||
if not candidate.content:
|
||||
logger.warning(f"收到的 candidate.content 为空: {candidate}")
|
||||
raise Exception("API 返回的 candidate.content 为空。")
|
||||
|
||||
finish_reason = candidate.finish_reason
|
||||
result_parts: list[types.Part] | None = candidate.content.parts
|
||||
|
||||
if finish_reason == types.FinishReason.SAFETY:
|
||||
raise Exception("模型生成内容未通过 Gemini 平台的安全检查")
|
||||
@@ -343,22 +350,28 @@ class ProviderGoogleGenAI(Provider):
|
||||
raise Exception("模型生成内容违反 Gemini 平台政策")
|
||||
|
||||
if not result_parts:
|
||||
logger.debug(result.candidates)
|
||||
raise Exception("API 返回的内容为空。")
|
||||
logger.warning(f"收到的 candidate.content.parts 为空: {candidate}")
|
||||
raise Exception("API 返回的 candidate.content.parts 为空。")
|
||||
|
||||
chain = []
|
||||
part: types.Part
|
||||
|
||||
# 暂时这样Fallback
|
||||
if all(
|
||||
part.inline_data and part.inline_data.mime_type.startswith("image/")
|
||||
part.inline_data
|
||||
and part.inline_data.mime_type
|
||||
and part.inline_data.mime_type.startswith("image/")
|
||||
for part in result_parts
|
||||
):
|
||||
chain.append(Comp.Plain("这是图片"))
|
||||
for part in result_parts:
|
||||
if part.text:
|
||||
chain.append(Comp.Plain(part.text))
|
||||
elif part.function_call:
|
||||
elif (
|
||||
part.function_call
|
||||
and part.function_call.name
|
||||
and part.function_call.args
|
||||
):
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_name.append(part.function_call.name)
|
||||
llm_response.tools_call_args.append(part.function_call.args)
|
||||
@@ -366,11 +379,16 @@ class ProviderGoogleGenAI(Provider):
|
||||
llm_response.tools_call_ids.append(
|
||||
part.function_call.id or part.function_call.name
|
||||
)
|
||||
elif part.inline_data and part.inline_data.mime_type.startswith("image/"):
|
||||
elif (
|
||||
part.inline_data
|
||||
and part.inline_data.mime_type
|
||||
and part.inline_data.mime_type.startswith("image/")
|
||||
and part.inline_data.data
|
||||
):
|
||||
chain.append(Comp.Image.fromBytes(part.inline_data.data))
|
||||
return MessageChain(chain=chain)
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||
"""非流式请求 Gemini API"""
|
||||
system_instruction = next(
|
||||
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
||||
@@ -396,6 +414,10 @@ class ProviderGoogleGenAI(Provider):
|
||||
config=config,
|
||||
)
|
||||
|
||||
if not result.candidates:
|
||||
logger.error(f"请求失败, 返回的 candidates 为空: {result}")
|
||||
raise Exception("请求失败, 返回的 candidates 为空。")
|
||||
|
||||
if result.candidates[0].finish_reason == types.FinishReason.RECITATION:
|
||||
if temperature > 2:
|
||||
raise Exception("温度参数已超过最大值2,仍然发生recitation")
|
||||
@@ -408,6 +430,8 @@ class ProviderGoogleGenAI(Provider):
|
||||
break
|
||||
|
||||
except APIError as e:
|
||||
if e.message is None:
|
||||
e.message = ""
|
||||
if "Developer instruction is not enabled" in e.message:
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
|
||||
@@ -432,11 +456,13 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
llm_response = LLMResponse("assistant")
|
||||
llm_response.raw_completion = result
|
||||
llm_response.result_chain = self._process_content_parts(result, llm_response)
|
||||
llm_response.result_chain = self._process_content_parts(
|
||||
result.candidates[0], llm_response
|
||||
)
|
||||
return llm_response
|
||||
|
||||
async def _query_stream(
|
||||
self, payloads: dict, tools: FuncCall
|
||||
self, payloads: dict, tools: ToolSet | None
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式请求 Gemini API"""
|
||||
system_instruction = next(
|
||||
@@ -459,6 +485,8 @@ class ProviderGoogleGenAI(Provider):
|
||||
)
|
||||
break
|
||||
except APIError as e:
|
||||
if e.message is None:
|
||||
e.message = ""
|
||||
if "Developer instruction is not enabled" in e.message:
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
|
||||
@@ -478,13 +506,20 @@ class ProviderGoogleGenAI(Provider):
|
||||
async for chunk in result:
|
||||
llm_response = LLMResponse("assistant", is_chunk=True)
|
||||
|
||||
if not chunk.candidates:
|
||||
logger.warning(f"收到的 chunk 中 candidates 为空: {chunk}")
|
||||
continue
|
||||
if not chunk.candidates[0].content:
|
||||
logger.warning(f"收到的 chunk 中 content 为空: {chunk}")
|
||||
continue
|
||||
|
||||
if chunk.candidates[0].content.parts and any(
|
||||
part.function_call for part in chunk.candidates[0].content.parts
|
||||
):
|
||||
llm_response = LLMResponse("assistant", is_chunk=False)
|
||||
llm_response.raw_completion = chunk
|
||||
llm_response.result_chain = self._process_content_parts(
|
||||
chunk, llm_response
|
||||
chunk.candidates[0], llm_response
|
||||
)
|
||||
yield llm_response
|
||||
return
|
||||
@@ -500,7 +535,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
final_response = LLMResponse("assistant", is_chunk=False)
|
||||
final_response.raw_completion = chunk
|
||||
final_response.result_chain = self._process_content_parts(
|
||||
chunk, final_response
|
||||
chunk.candidates[0], final_response
|
||||
)
|
||||
break
|
||||
|
||||
@@ -566,6 +601,8 @@ class ProviderGoogleGenAI(Provider):
|
||||
continue
|
||||
break
|
||||
|
||||
raise Exception("请求失败。")
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
@@ -621,7 +658,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
return [
|
||||
m.name.replace("models/", "")
|
||||
for m in models
|
||||
if "generateContent" in m.supported_actions
|
||||
if m.supported_actions
|
||||
and "generateContent" in m.supported_actions
|
||||
and m.name
|
||||
]
|
||||
except APIError as e:
|
||||
raise Exception(f"获取模型列表失败: {e.message}")
|
||||
@@ -636,7 +675,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
self.chosen_api_key = key
|
||||
self._init_client()
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: list[str] = None):
|
||||
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
|
||||
"""
|
||||
组装上下文。
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user