f624971613
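* chore(core.utils): 🚨 fix lint errors
* chore(core.provider): 🚨 fix lint errors in the base class
* chore(core.utils): complete the overloads of session_get()
* chore(core.provider): 🚨 fix lint errors in implementations
* chore(core.platform): 🚨 fix lint errors in the platform base class and webchat
* chore(core.platform): fix lint errors in implementations
* fix(core.provider): fix a circular call and an incorrect assert
* chore(core.platform): fix lint in some implementations
* chore(core.provider): add parameter types to Dify.text_chat_stream
* chore(core.pipeline): 🚨 fix lint errors
* fix(core.slack): add missing imports
* chore(core.utils): fix the incorrect session_get declaration
* chore(core.platform): remove the wildcard from the Lark adapter import
* chore(core.db): fix declarations and some logic
* chore(core.db): add typings so faiss parameters are recognized correctly
* chore(core): fix declarations
* chore(core): update declarations
* chore: add faiss declarations
* chore(dashboard): adjust the implementation to reduce errors
* chore(package): adjust some declarations and implementations to reduce errors
* chore(core): add Handler overloads to remove some asserts while still passing type checks
* chore(core.pipeline): make the Pipeline Scheduler's execute check types instead of attributes, passing static type checks
* chore(core.config): add type annotations to pass type checks
* chore(core.message): add checks to File._download_file to pass type checks
* fix: replace assertions with conditional checks so graceful shutdown tolerates failures
* refactor: remove asserts from the Discord client; check for None and raise exceptions instead
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: DiscordPlatformAdapter logs and returns when self.client.user is None; remove the assertion
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: strengthen null/exception checks in Lark-related code and improve log output
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: replace assertions with conditional checks and add logging and error handling
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* chore: remove useless LLM-generated comments
* refactor: replace the download logic with File.get_file, remove the assert, and provide a default filename
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: raise a runtime exception when the Slack socket is not initialized; change the image URL emptiness check to a non-empty check
* refactor: replace WeChatPadProAdapter assertions with null checks and add logging
* refactor: use isinstance instead of assertions for type checks, to aid static analysis
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: remove cast, use direct field and dict access, fix port parsing
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: rework ProviderManager loading with match-case and raise TypeError when type checks fail
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: in group_name_display, log an error and return if the group object is empty
* fix: replace the assert in _get_current_persona_id with an if guard that returns None
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: improve the plugin directory existence check and image URL non-empty validation; update the JSON sort config
* fix: replace the datetime_str assert with an explicit check that raises an exception
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: remove cast; check at runtime and skip when no scheduler is found
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: remove cast; check FaissVecDB with isinstance and warn
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: remove the typing.cast import and validate file_ before getting the file's absolute path
* refactor: remove typing.cast and simplify the content safety check call
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: make PlatformMetadata.id required, pass id at registration, and remove cast
* refactor: remove cast; initialize via HasInitialize and isinstance
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: add an ID type check to ProviderManager.initialize so a None ID does not break get
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: introduce _client and a client property for OTTSProvider and AzureNativeProvider to improve context management
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: add a model-not-initialized check to the self-hosted Whisper source and call transcribe directly
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: remove the unused cast import and simplify the platform_name assignment
* refactor: import cast and apply cast(str, ...) to id for type safety
* fix: make _id_to_sid return str (an empty string for empty values); cast id and message_id
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: rework Discord handling: enforce type conversion, prioritize slash commands, and improve mention detection
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: cast id retrieval consistently and raise when WeChat message parsing fails
* Revert "fix: remove cast, use direct field and dict access, fix port parsing"
This reverts commit 1cbfdf9d1b.
* fix: Bailian (百炼) Rerank returns an empty result when the session is closed; initialize request.prompt to avoid concatenating None
* fix: normalize search result links to strings, add an _get_url helper, and adapt Bing/Sogo
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: adjust call_handler generics, Discord channel annotations, and FishAudioTTS API request types
* refactor: use col(...) instead of bare column references and cast results to CursorResult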
* chore: ruff format
---------
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
787 lines
31 KiB
Python
import asyncio
import base64
import json
import logging
import random
from collections.abc import AsyncGenerator
from typing import cast

from google import genai
from google.genai import types
from google.genai.errors import APIError

import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.io import download_image_by_url

from ..register import register_provider_adapter


class SuppressNonTextPartsWarning(logging.Filter):
    """Filter out the Gemini SDK's "non-text parts" warning."""

    def filter(self, record):
        return "there are non-text parts in the response" not in record.getMessage()


logging.getLogger("google_genai.types").addFilter(SuppressNonTextPartsWarning())


@register_provider_adapter(
    "googlegenai_chat_completion",
    "Google Gemini Chat Completion 提供商适配器",
)
class ProviderGoogleGenAI(Provider):
    CATEGORY_MAPPING = {
        "harassment": types.HarmCategory.HARM_CATEGORY_HARASSMENT,
        "hate_speech": types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        "sexually_explicit": types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        "dangerous_content": types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    }

    THRESHOLD_MAPPING = {
        "BLOCK_NONE": types.HarmBlockThreshold.BLOCK_NONE,
        "BLOCK_ONLY_HIGH": types.HarmBlockThreshold.BLOCK_ONLY_HIGH,
        "BLOCK_MEDIUM_AND_ABOVE": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        "BLOCK_LOW_AND_ABOVE": types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
    }

    def __init__(
        self,
        provider_config,
        provider_settings,
    ) -> None:
        super().__init__(
            provider_config,
            provider_settings,
        )
        self.api_keys: list = super().get_keys()
        self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
        self.timeout: int = int(provider_config.get("timeout", 180))

        self.api_base: str | None = provider_config.get("api_base", None)
        if self.api_base and self.api_base.endswith("/"):
            self.api_base = self.api_base[:-1]

        self._init_client()
        self.set_model(provider_config["model_config"]["model"])
        self._init_safety_settings()

    def _init_client(self) -> None:
        """Initialize the Gemini client."""
        self.client = genai.Client(
            api_key=self.chosen_api_key,
            http_options=types.HttpOptions(
                base_url=self.api_base,
                timeout=self.timeout * 1000,  # milliseconds
            ),
        ).aio

    def _init_safety_settings(self) -> None:
        """Initialize the safety settings from the user's config."""
        user_safety_config = self.provider_config.get("gm_safety_settings", {})
        self.safety_settings = [
            types.SafetySetting(
                category=harm_category,
                threshold=self.THRESHOLD_MAPPING[threshold_str],
            )
            for config_key, harm_category in self.CATEGORY_MAPPING.items()
            if (threshold_str := user_safety_config.get(config_key))
            and threshold_str in self.THRESHOLD_MAPPING
        ]

    async def _handle_api_error(self, e: APIError, keys: list[str]) -> bool:
        """Handle an API error; return whether the request should be retried."""
        if e.message is None:
            e.message = ""

        if e.code == 429 or "API key not valid" in e.message:
            # Drop the failing key; guarded because the chosen key is "" (and not
            # in `keys`) when no keys are configured, and list.remove would raise
            if self.chosen_api_key in keys:
                keys.remove(self.chosen_api_key)
            if len(keys) > 0:
                self.set_key(random.choice(keys))
                logger.info(
                    f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}...",
                )
                await asyncio.sleep(1)
                return True
            logger.error(
                f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}...",
            )
            raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
        # logger.error(
        #     f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}",
        # )
        raise e

    async def _prepare_query_config(
        self,
        payloads: dict,
        tools: ToolSet | None = None,
        system_instruction: str | None = None,
        modalities: list[str] | None = None,
        temperature: float = 0.7,
    ) -> types.GenerateContentConfig:
        """Build the request config."""
        if not modalities:
            modalities = ["Text"]

        # Streaming output does not support the image modality
        if (
            self.provider_settings.get("streaming_response", False)
            and "Image" in modalities
        ):
            logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
            modalities = ["Text"]

        tool_list: list[types.Tool] | None = []
        model_name = self.get_model()
        native_coderunner = self.provider_config.get("gm_native_coderunner", False)
        native_search = self.provider_config.get("gm_native_search", False)
        url_context = self.provider_config.get("gm_url_context", False)

        if "gemini-2.5" in model_name:
            if native_coderunner:
                tool_list.append(types.Tool(code_execution=types.ToolCodeExecution()))
                if native_search:
                    logger.warning("代码执行工具与搜索工具互斥,已忽略搜索工具")
                if url_context:
                    logger.warning(
                        "代码执行工具与URL上下文工具互斥,已忽略URL上下文工具",
                    )
            else:
                if native_search:
                    tool_list.append(types.Tool(google_search=types.GoogleSearch()))

                if url_context:
                    if hasattr(types, "UrlContext"):
                        tool_list.append(types.Tool(url_context=types.UrlContext()))
                    else:
                        logger.warning(
                            "当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包",
                        )

        elif "gemini-2.0-lite" in model_name:
            if native_coderunner or native_search or url_context:
                logger.warning(
                    "gemini-2.0-lite 不支持代码执行、搜索工具和URL上下文,将忽略这些设置",
                )
            tool_list = None

        else:
            if native_coderunner:
                tool_list.append(types.Tool(code_execution=types.ToolCodeExecution()))
                if native_search:
                    logger.warning("代码执行工具与搜索工具互斥,已忽略搜索工具")
            elif native_search:
                tool_list.append(types.Tool(google_search=types.GoogleSearch()))

            if url_context and not native_coderunner:
                if hasattr(types, "UrlContext"):
                    tool_list.append(types.Tool(url_context=types.UrlContext()))
                else:
                    logger.warning(
                        "当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包",
                    )

        if not tool_list:
            tool_list = None

        if tools and tool_list:
            logger.warning("已启用原生工具,函数工具将被忽略")
        elif tools and (func_desc := tools.get_func_desc_google_genai_style()):
            tool_list = [
                types.Tool(function_declarations=func_desc["function_declarations"]),
            ]

        return types.GenerateContentConfig(
            system_instruction=system_instruction,
            temperature=temperature,
            max_output_tokens=payloads.get("max_tokens")
            or payloads.get("maxOutputTokens"),
            top_p=payloads.get("top_p") or payloads.get("topP"),
            top_k=payloads.get("top_k") or payloads.get("topK"),
            frequency_penalty=payloads.get("frequency_penalty")
            or payloads.get("frequencyPenalty"),
            presence_penalty=payloads.get("presence_penalty")
            or payloads.get("presencePenalty"),
            stop_sequences=payloads.get("stop") or payloads.get("stopSequences"),
            response_logprobs=payloads.get("response_logprobs")
            or payloads.get("responseLogprobs"),
            logprobs=payloads.get("logprobs"),
            seed=payloads.get("seed"),
            response_modalities=modalities,
            tools=cast(types.ToolListUnion | None, tool_list),
            safety_settings=self.safety_settings if self.safety_settings else None,
            thinking_config=(
                types.ThinkingConfig(
                    thinking_budget=min(
                        int(
                            self.provider_config.get("gm_thinking_config", {}).get(
                                "budget",
                                0,
                            ),
                        ),
                        24576,
                    ),
                )
                if "gemini-2.5-flash" in self.get_model()
                and hasattr(types.ThinkingConfig, "thinking_budget")
                else None
            ),
            automatic_function_calling=types.AutomaticFunctionCallingConfig(
                disable=True,
            ),
        )

    def _prepare_conversation(self, payloads: dict) -> list[types.Content]:
        """Build the list of Content objects for the Gemini SDK."""

        def create_text_part(text: str) -> types.Part:
            content_a = text if text else " "
            if not text:
                logger.warning("文本内容为空,已添加空格占位")
            return types.Part.from_text(text=content_a)

        def process_image_url(image_url_dict: dict) -> types.Part:
            # Expects a data URL such as "data:image/jpeg;base64,<payload>"
            url = image_url_dict["url"]
            mime_type = url.split(":")[1].split(";")[0]
            image_bytes = base64.b64decode(url.split(",", 1)[1])
            return types.Part.from_bytes(data=image_bytes, mime_type=mime_type)

        def append_or_extend(
            contents: list[types.Content],
            part: list[types.Part],
            content_cls: type[types.Content],
        ) -> None:
            # Merge consecutive messages of the same role into one Content
            if contents and isinstance(contents[-1], content_cls):
                assert contents[-1].parts is not None
                contents[-1].parts.extend(part)
            else:
                contents.append(content_cls(parts=part))

        gemini_contents: list[types.Content] = []
        native_tool_enabled = any(
            [
                self.provider_config.get("gm_native_coderunner", False),
                self.provider_config.get("gm_native_search", False),
            ],
        )
        for message in payloads["messages"]:
            role, content = message["role"], message.get("content")

            if role == "user":
                if isinstance(content, list):
                    parts = [
                        (
                            types.Part.from_text(text=item["text"] or " ")
                            if item["type"] == "text"
                            else process_image_url(item["image_url"])
                        )
                        for item in content
                    ]
                else:
                    parts = [create_text_part(content)]
                append_or_extend(gemini_contents, parts, types.UserContent)

            elif role == "assistant":
                if content:
                    parts = [types.Part.from_text(text=content)]
                    append_or_extend(gemini_contents, parts, types.ModelContent)
                elif not native_tool_enabled and "tool_calls" in message:
                    parts = []
                    for tool in message["tool_calls"]:
                        part = types.Part.from_function_call(
                            name=tool["function"]["name"],
                            args=json.loads(tool["function"]["arguments"]),
                        )
                        # Restore the thought_signature onto the part if one was saved.
                        # For more about thought signatures, see:
                        # https://ai.google.dev/gemini-api/docs/thought-signatures
                        if "extra_content" in tool and tool["extra_content"]:
                            ts_bs64 = (
                                tool["extra_content"]
                                .get("google", {})
                                .get("thought_signature")
                            )
                            if ts_bs64:
                                part.thought_signature = base64.b64decode(ts_bs64)
                        parts.append(part)
                    append_or_extend(gemini_contents, parts, types.ModelContent)
                else:
                    logger.warning("assistant 角色的消息内容为空,已添加空格占位")
                    if native_tool_enabled and "tool_calls" in message:
                        logger.warning(
                            "检测到启用Gemini原生工具,且上下文中存在函数调用,建议使用 /reset 重置上下文",
                        )
                    parts = [types.Part.from_text(text=" ")]
                    append_or_extend(gemini_contents, parts, types.ModelContent)

            elif role == "tool" and not native_tool_enabled:
                parts = [
                    types.Part.from_function_response(
                        name=message["tool_call_id"],
                        response={
                            "name": message["tool_call_id"],
                            "content": message["content"],
                        },
                    ),
                ]
                append_or_extend(gemini_contents, parts, types.UserContent)

        # Drop a leading model turn. The check inspects the first element, so remove
        # that one: a bare pop() would discard the latest message instead.
        if gemini_contents and isinstance(gemini_contents[0], types.ModelContent):
            gemini_contents.pop(0)

        return gemini_contents

    def _extract_reasoning_content(self, candidate: types.Candidate) -> str:
        """Extract reasoning content from candidate parts"""
        if not candidate.content or not candidate.content.parts:
            return ""

        thought_buf: list[str] = [
            (p.text or "") for p in candidate.content.parts if p.thought
        ]
        return "".join(thought_buf).strip()

    def _process_content_parts(
        self,
        candidate: types.Candidate,
        llm_response: LLMResponse,
    ) -> MessageChain:
        """Process the content parts and build the message chain."""
        if not candidate.content:
            logger.warning(f"收到的 candidate.content 为空: {candidate}")
            raise Exception("API 返回的 candidate.content 为空。")

        finish_reason = candidate.finish_reason
        result_parts: list[types.Part] | None = candidate.content.parts

        if finish_reason == types.FinishReason.SAFETY:
            raise Exception("模型生成内容未通过 Gemini 平台的安全检查")

        if finish_reason in {
            types.FinishReason.PROHIBITED_CONTENT,
            types.FinishReason.SPII,
            types.FinishReason.BLOCKLIST,
        }:
            raise Exception("模型生成内容违反 Gemini 平台政策")

        # Older SDK versions may not define IMAGE_SAFETY
        if hasattr(types.FinishReason, "IMAGE_SAFETY"):
            if finish_reason == types.FinishReason.IMAGE_SAFETY:
                raise Exception("模型生成内容违反 Gemini 平台政策")

        if not result_parts:
            logger.warning(f"收到的 candidate.content.parts 为空: {candidate}")
            raise Exception("API 返回的 candidate.content.parts 为空。")

        # Extract reasoning content
        reasoning = self._extract_reasoning_content(candidate)
        if reasoning:
            llm_response.reasoning_content = reasoning

        chain = []
        part: types.Part

        # Temporary fallback: if every part is an image, prepend a placeholder text
        if all(
            part.inline_data
            and part.inline_data.mime_type
            and part.inline_data.mime_type.startswith("image/")
            for part in result_parts
        ):
            chain.append(Comp.Plain("这是图片"))
        for part in result_parts:
            if part.text:
                chain.append(Comp.Plain(part.text))
            elif (
                part.function_call
                and part.function_call.name is not None
                and part.function_call.args is not None
            ):
                llm_response.role = "tool"
                llm_response.tools_call_name.append(part.function_call.name)
                llm_response.tools_call_args.append(part.function_call.args)
                # function_call.id might be None, use name as fallback
                tool_call_id = part.function_call.id or part.function_call.name
                llm_response.tools_call_ids.append(tool_call_id)
                # Keep the thought_signature as extra_content so it can be sent back
                if part.thought_signature:
                    ts_bs64 = base64.b64encode(part.thought_signature).decode("utf-8")
                    llm_response.tools_call_extra_content[tool_call_id] = {
                        "google": {"thought_signature": ts_bs64}
                    }
            elif (
                part.inline_data
                and part.inline_data.mime_type
                and part.inline_data.mime_type.startswith("image/")
                and part.inline_data.data
            ):
                chain.append(Comp.Image.fromBytes(part.inline_data.data))
        return MessageChain(chain=chain)

    async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
        """Send a non-streaming request to the Gemini API."""
        system_instruction = next(
            (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
            None,
        )

        modalities = ["Text"]
        if self.provider_config.get("gm_resp_image_modal", False):
            modalities.append("Image")

        conversation = self._prepare_conversation(payloads)
        temperature = payloads.get("temperature", 0.7)

        result: types.GenerateContentResponse | None = None
        while True:
            try:
                config = await self._prepare_query_config(
                    payloads,
                    tools,
                    system_instruction,
                    modalities,
                    temperature,
                )
                result = await self.client.models.generate_content(
                    model=self.get_model(),
                    contents=cast(types.ContentListUnion, conversation),
                    config=config,
                )
                logger.debug(f"genai result: {result}")

                if not result.candidates:
                    logger.error(f"请求失败, 返回的 candidates 为空: {result}")
                    raise Exception("请求失败, 返回的 candidates 为空。")

                if result.candidates[0].finish_reason == types.FinishReason.RECITATION:
                    if temperature > 2:
                        raise Exception("温度参数已超过最大值2,仍然发生recitation")
                    temperature += 0.2
                    logger.warning(
                        f"发生了recitation,正在提高温度至{temperature:.1f}重试...",
                    )
                    continue

                break

            except APIError as e:
                if e.message is None:
                    e.message = ""
                if "Developer instruction is not enabled" in e.message:
                    logger.warning(
                        f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
                    )
                    system_instruction = None
                elif "Function calling is not enabled" in e.message:
                    logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
                    tools = None
                elif (
                    "Multi-modal output is not supported" in e.message
                    or "Model does not support the requested response modalities"
                    in e.message
                    or "only supports text output" in e.message
                ):
                    logger.warning(
                        f"{self.get_model()} 不支持多模态输出,降级为文本模态",
                    )
                    modalities = ["Text"]
                else:
                    raise
                continue

        llm_response = LLMResponse("assistant")
        llm_response.raw_completion = result
        llm_response.result_chain = self._process_content_parts(
            result.candidates[0],
            llm_response,
        )
        return llm_response

    async def _query_stream(
        self,
        payloads: dict,
        tools: ToolSet | None,
    ) -> AsyncGenerator[LLMResponse, None]:
        """Send a streaming request to the Gemini API."""
        system_instruction = next(
            (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
            None,
        )

        conversation = self._prepare_conversation(payloads)

        result = None
        while True:
            try:
                config = await self._prepare_query_config(
                    payloads,
                    tools,
                    system_instruction,
                )
                result = await self.client.models.generate_content_stream(
                    model=self.get_model(),
                    contents=cast(types.ContentListUnion, conversation),
                    config=config,
                )
                break
            except APIError as e:
                if e.message is None:
                    e.message = ""
                if "Developer instruction is not enabled" in e.message:
                    logger.warning(
                        f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
                    )
                    system_instruction = None
                elif "Function calling is not enabled" in e.message:
                    logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
                    tools = None
                else:
                    raise
                continue

        # Accumulate the complete response text for the final response
        accumulated_text = ""
        accumulated_reasoning = ""
        final_response = None

        async for chunk in result:
            llm_response = LLMResponse("assistant", is_chunk=True)

            if not chunk.candidates:
                logger.warning(f"收到的 chunk 中 candidates 为空: {chunk}")
                continue
            if not chunk.candidates[0].content:
                logger.warning(f"收到的 chunk 中 content 为空: {chunk}")
                continue

            if chunk.candidates[0].content.parts and any(
                part.function_call for part in chunk.candidates[0].content.parts
            ):
                llm_response = LLMResponse("assistant", is_chunk=False)
                llm_response.raw_completion = chunk
                llm_response.result_chain = self._process_content_parts(
                    chunk.candidates[0],
                    llm_response,
                )
                yield llm_response
                return

            _f = False

            # Extract reasoning content
            reasoning = self._extract_reasoning_content(chunk.candidates[0])
            if reasoning:
                _f = True
                accumulated_reasoning += reasoning
                llm_response.reasoning_content = reasoning
            if chunk.text:
                _f = True
                accumulated_text += chunk.text
                llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
            if _f:
                yield llm_response

            if chunk.candidates[0].finish_reason:
                # Process the final chunk for potential tool calls or other content
                if chunk.candidates[0].content.parts:
                    final_response = LLMResponse("assistant", is_chunk=False)
                    final_response.raw_completion = chunk
                    final_response.result_chain = self._process_content_parts(
                        chunk.candidates[0],
                        final_response,
                    )
                break

        # Yield final complete response with accumulated text
        if not final_response:
            final_response = LLMResponse("assistant", is_chunk=False)

        # Set the complete accumulated reasoning in the final response
        if accumulated_reasoning:
            final_response.reasoning_content = accumulated_reasoning

        # Set the complete accumulated text in the final response
        if accumulated_text:
            final_response.result_chain = MessageChain(
                chain=[Comp.Plain(accumulated_text)],
            )
        elif not final_response.result_chain:
            # If no text was accumulated and no final response was set, provide empty space
            final_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])

        yield final_response

    async def text_chat(
        self,
        prompt=None,
        session_id=None,
        image_urls=None,
        func_tool=None,
        contexts=None,
        system_prompt=None,
        tool_calls_result=None,
        model=None,
        **kwargs,
    ) -> LLMResponse:
        if contexts is None:
            contexts = []
        new_record = None
        if prompt is not None:
            new_record = await self.assemble_context(prompt, image_urls)
        context_query = self._ensure_message_to_dicts(contexts)
        if new_record:
            context_query.append(new_record)
        if system_prompt:
            context_query.insert(0, {"role": "system", "content": system_prompt})

        for part in context_query:
            if "_no_save" in part:
                del part["_no_save"]

        # tool calls result
        if tool_calls_result:
            if not isinstance(tool_calls_result, list):
                context_query.extend(tool_calls_result.to_openai_messages())
            else:
                for tcr in tool_calls_result:
                    context_query.extend(tcr.to_openai_messages())

        model_config = self.provider_config.get("model_config", {})
        model_config["model"] = model or self.get_model()

        payloads = {"messages": context_query, **model_config}

        retry = 10
        keys = self.api_keys.copy()

        for _ in range(retry):
            try:
                return await self._query(payloads, func_tool)
            except APIError as e:
                if await self._handle_api_error(e, keys):
                    continue
                break

        raise Exception("请求失败。")

    async def text_chat_stream(
        self,
        prompt=None,
        session_id=None,
        image_urls=None,
        func_tool=None,
        contexts=None,
        system_prompt=None,
        tool_calls_result=None,
        model=None,
        **kwargs,
    ) -> AsyncGenerator[LLMResponse, None]:
        if contexts is None:
            contexts = []
        new_record = None
        if prompt is not None:
            new_record = await self.assemble_context(prompt, image_urls)
        context_query = self._ensure_message_to_dicts(contexts)
        if new_record:
            context_query.append(new_record)
        if system_prompt:
            context_query.insert(0, {"role": "system", "content": system_prompt})

        for part in context_query:
            if "_no_save" in part:
                del part["_no_save"]

        # tool calls result
        if tool_calls_result:
            if not isinstance(tool_calls_result, list):
                context_query.extend(tool_calls_result.to_openai_messages())
            else:
                for tcr in tool_calls_result:
                    context_query.extend(tcr.to_openai_messages())

        model_config = self.provider_config.get("model_config", {})
        model_config["model"] = model or self.get_model()

        payloads = {"messages": context_query, **model_config}

        retry = 10
        keys = self.api_keys.copy()

        for _ in range(retry):
            try:
                async for response in self._query_stream(payloads, func_tool):
                    yield response
                break
            except APIError as e:
                if await self._handle_api_error(e, keys):
                    continue
                break

    async def get_models(self):
        try:
            models = await self.client.models.list()
            return [
                m.name.replace("models/", "")
                for m in models
                if m.supported_actions
                and "generateContent" in m.supported_actions
                and m.name
            ]
        except APIError as e:
            raise Exception(f"获取模型列表失败: {e.message}")

    def get_current_key(self) -> str:
        return self.chosen_api_key

    def get_keys(self) -> list[str]:
        return self.api_keys

    def set_key(self, key):
        self.chosen_api_key = key
        self._init_client()

    async def assemble_context(self, text: str, image_urls: list[str] | None = None):
        """Assemble the user message for the context."""
        if image_urls:
            user_content = {
                "role": "user",
                "content": [{"type": "text", "text": text if text else "[图片]"}],
            }
            for image_url in image_urls:
                if image_url.startswith("http"):
                    image_path = await download_image_by_url(image_url)
                    image_data = await self.encode_image_bs64(image_path)
                elif image_url.startswith("file:///"):
                    image_path = image_url.replace("file:///", "")
                    image_data = await self.encode_image_bs64(image_path)
                else:
                    image_data = await self.encode_image_bs64(image_url)
                if not image_data:
                    logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
                    continue
                user_content["content"].append(
                    {
                        "type": "image_url",
                        "image_url": {"url": image_data},
                    },
                )
            return user_content
        return {"role": "user", "content": text}

    async def encode_image_bs64(self, image_url: str) -> str:
        """Encode an image as a base64 data URL (always labelled image/jpeg)."""
        if image_url.startswith("base64://"):
            return image_url.replace("base64://", "data:image/jpeg;base64,")
        with open(image_url, "rb") as f:
            image_bs64 = base64.b64encode(f.read()).decode("utf-8")
            return "data:image/jpeg;base64," + image_bs64

    async def terminate(self):
        logger.info("Google GenAI 适配器已终止。")