refactor: 初步完成gemini_source的重写

This commit is contained in:
Raven95676
2025-04-11 01:03:16 +08:00
parent f2cc4311c5
commit 0b766095d4
+193 -235
View File
@@ -1,8 +1,9 @@
import base64
import aiohttp
import json
import random
import asyncio
from google import genai
from google.genai import types, errors
import astrbot.core.message.components as Comp
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.utils.io import download_image_by_url
@@ -10,112 +11,28 @@ from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider, Personality
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from typing import Dict, List
from ..register import register_provider_adapter
from astrbot.core.provider.entities import LLMResponse
class SimpleGoogleGenAIClient:
def __init__(self, api_key: str, api_base: str, timeout: int = 120) -> None:
self.api_key = api_key
if api_base.endswith("/"):
self.api_base = api_base[:-1]
else:
self.api_base = api_base
self.client = aiohttp.ClientSession(trust_env=True)
self.timeout = timeout
async def models_list(self) -> List[str]:
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
async with self.client.get(request_url, timeout=self.timeout) as resp:
response = await resp.json()
models = []
for model in response["models"]:
if "generateContent" in model["supportedGenerationMethods"]:
models.append(model["name"].replace("models/", ""))
return models
async def generate_content(
self,
contents: List[dict],
model: str = "gemini-1.5-flash",
system_instruction: str = "",
tools: dict = None,
modalities: List[str] = ["Text"],
safety_settings: List[dict] = [],
):
payload = {}
if system_instruction:
payload["system_instruction"] = {"parts": {"text": system_instruction}}
if tools:
payload["tools"] = [tools]
payload["contents"] = contents
payload["generationConfig"] = {
"responseModalities": modalities,
}
payload["safetySettings"] = [
{"category": s["category"], "threshold": s["threshold"]}
for s in safety_settings
]
logger.debug(f"payload: {payload}")
request_url = (
f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
)
async with self.client.post(
request_url, json=payload, timeout=self.timeout
) as resp:
if "application/json" in resp.headers.get("Content-Type"):
try:
response = await resp.json()
except Exception as e:
text = await resp.text()
logger.error(f"Gemini 返回了非 json 数据: {text}")
raise e
return response
else:
text = await resp.text()
logger.error(f"Gemini 返回了非 json 数据: {text}")
raise Exception("Gemini 返回了非 json 数据: ")
async def stream_generate_content(
self,
contents: List[dict],
model: str = "gemini-1.5-flash",
system_instruction: str = "",
tools: dict = None,
modalities: List[str] = ["Text"],
safety_settings: List[dict] = [],
):
payload = {}
if system_instruction:
payload["system_instruction"] = {"parts": {"text": system_instruction}}
if tools:
payload["tools"] = [tools]
payload["contents"] = contents
payload["generationConfig"] = {
"responseModalities": modalities,
"stream": True,
}
payload["safetySettings"] = [
{"category": s["category"], "threshold": s["threshold"]}
for s in safety_settings
]
logger.debug(f"payload: {payload}")
request_url = (
f"{self.api_base}/v1beta/models/{model}:streamGenerateContent?key={self.api_key}"
)
async with self.client.post(
request_url, json=payload, timeout=self.timeout
) as resp:
async for line in resp.content:
if line:
yield line
@register_provider_adapter(
"googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器"
)
class ProviderGoogleGenAI(Provider):
CATEGORY_MAPPING = {
"harassment": types.HarmCategory.HARM_CATEGORY_HARASSMENT,
"hate_speech": types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
"sexually_explicit": types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
"dangerous_content": types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
}
THRESHOLD_MAPPING = {
"BLOCK_NONE": types.HarmBlockThreshold.BLOCK_NONE,
"BLOCK_ONLY_HIGH": types.HarmBlockThreshold.BLOCK_ONLY_HIGH,
"BLOCK_MEDIUM_AND_ABOVE": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
"BLOCK_LOW_AND_ABOVE": types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
}
def __init__(
self,
provider_config: dict,
@@ -131,43 +48,145 @@ class ProviderGoogleGenAI(Provider):
db_helper,
default_persona,
)
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.timeout = provider_config.get("timeout", 180)
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None
self.timeout: int = provider_config.get("timeout", 180)
self.api_base: str = provider_config.get("api_base", None)
if self.api_base.endswith("/"):
self.api_base = self.api_base[:-1]
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.client = SimpleGoogleGenAIClient(
self.client = genai.Client(
api_key=self.chosen_api_key,
api_base=provider_config.get("api_base", None),
timeout=self.timeout,
)
http_options=types.HttpOptions(
base_url=self.api_base,
timeout=self.timeout * 1000, # 毫秒
),
).aio
self.set_model(provider_config["model_config"]["model"])
safety_mapping = {
"harassment": "HARM_CATEGORY_HARASSMENT",
"hate_speech": "HARM_CATEGORY_HATE_SPEECH",
"sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT",
}
self.safety_settings = []
user_safety_config = self.provider_config.get("gm_safety_settings", {})
for config_key, harm_category in safety_mapping.items():
if threshold := user_safety_config.get(config_key):
self.safety_settings.append(
{"category": harm_category, "threshold": threshold}
)
self.safety_settings = [
types.SafetySetting(
category=harm_category, threshold=self.THRESHOLD_MAPPING[threshold_str]
)
for config_key, harm_category in self.CATEGORY_MAPPING.items()
if (threshold_str := user_safety_config.get(config_key))
and threshold_str in self.THRESHOLD_MAPPING
]
async def get_models(self):
return await self.client.models_list()
try:
models = await self.client.models.list()
return [
m.name.replace("models/", "")
for m in models
if "generateContent" in m.supported_actions
]
except errors.APIError as e:
raise Exception(f"获取模型列表失败: {e}")
def _prepare_conversation(
self,
payloads: Dict,
) -> List[types.Content]:
"""准备 Gemini SDK 的 Content 列表"""
gemini_contents = []
for message in payloads["messages"]:
role = message["role"]
content = message.get("content")
if role == "user":
if isinstance(content, str):
if content:
gemini_contents.append(
types.UserContent(
parts=[types.Part.from_text(text=content)]
)
)
else:
logger.warning("文本内容为空,已添加空格占位")
gemini_contents.append(
types.UserContent(parts=[types.Part.from_text(text=" ")])
)
elif isinstance(content, list):
parts = []
for item in content:
if item.get("type") == "text":
text_content = item.get("text")
if text_content:
parts.append(types.Part.from_text(text=text_content))
else:
logger.warning("文本内容为空,已添加空格占位")
parts.append(types.Part.from_text(text=" "))
elif item.get("type") == "image_url":
image_url_dict = item["image_url"]
url = image_url_dict["url"]
mime_part, base64_data = url.split(",", 1)
mime_type = mime_part.split(":")[1].split(";")[0]
image_bytes = base64.b64decode(base64_data)
parts.append(
types.Part.from_bytes(
data=image_bytes, mime_type=mime_type
)
)
gemini_contents.append(types.UserContent(parts=parts))
elif role == "assistant":
if content:
gemini_contents.append(
types.ModelContent(
parts=[types.Part.from_text(text=message["content"])]
)
)
elif "tool_calls" in message:
parts = [
{
"name": tool_call["function"]["name"],
"args": json.loads(tool_call["function"]["arguments"]),
}
for tool_call in message["tool_calls"]
]
gemini_contents.append(
types.ModelContent(parts=[types.Part.from_function_call(parts)])
)
else:
logger.warning("assistant 角色的消息内容为空,已添加空格占位")
gemini_contents.append(
types.ModelContent(parts=[types.Part.from_text(text=" ")])
)
elif role == "tool":
gemini_contents.append(
types.UserContent(
parts=[
types.Part.from_function_response(
{
"name": message["tool_call_id"],
"response": {
"name": message["tool_call_id"],
"content": message["content"],
},
}
)
]
)
)
logger.debug(f"gemini_contents: {gemini_contents}")
return gemini_contents
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
tool = None
"""非流式请求 Gemini API"""
if tools:
tool = tools.get_func_desc_google_genai_style()
if not tool:
tool = None
t = tools.get_func_desc_google_genai_style()
tool = (
types.Tool(function_declarations=t["function_declarations"])
if t
else None
)
system_instruction = ""
for message in payloads["messages"]:
@@ -175,137 +194,78 @@ class ProviderGoogleGenAI(Provider):
system_instruction = message["content"]
break
google_genai_conversation = []
for message in payloads["messages"]:
if message["role"] == "user":
if isinstance(message["content"], str):
if not message["content"]:
message["content"] = ""
google_genai_conversation.append(
{"role": "user", "parts": [{"text": message["content"]}]}
)
elif isinstance(message["content"], list):
# images
parts = []
for part in message["content"]:
if part["type"] == "text":
if not part["text"]:
part["text"] = ""
parts.append({"text": part["text"]})
elif part["type"] == "image_url":
parts.append(
{
"inline_data": {
"mime_type": "image/jpeg",
"data": part["image_url"]["url"].replace(
"data:image/jpeg;base64,", ""
), # base64
}
}
)
google_genai_conversation.append({"role": "user", "parts": parts})
elif message["role"] == "assistant":
if "content" in message:
if not message["content"]:
message["content"] = ""
google_genai_conversation.append(
{"role": "model", "parts": [{"text": message["content"]}]}
)
elif "tool_calls" in message:
# tool calls in the last turn
parts = []
for tool_call in message["tool_calls"]:
parts.append(
{
"functionCall": {
"name": tool_call["function"]["name"],
"args": json.loads(
tool_call["function"]["arguments"]
),
}
}
)
google_genai_conversation.append({"role": "model", "parts": parts})
elif message["role"] == "tool":
parts = []
parts.append(
{
"functionResponse": {
"name": message["tool_call_id"],
"response": {
"name": message["tool_call_id"],
"content": message["content"],
},
}
}
)
google_genai_conversation.append({"role": "user", "parts": parts})
logger.debug(f"google_genai_conversation: {google_genai_conversation}")
conversation = self._prepare_conversation(payloads)
modalites = ["Text"]
if self.provider_config.get("gm_resp_image_modal", False):
modalites.append("Image")
loop = True
while loop:
loop = False
result = await self.client.generate_content(
contents=google_genai_conversation,
result = await self.client.models.generate_content(
model=self.get_model(),
system_instruction=system_instruction,
tools=tool,
modalities=modalites,
safety_settings=self.safety_settings,
contents=conversation,
config=types.GenerateContentConfig(
system_instruction=system_instruction,
tools=[tool] if tool else None,
safety_settings=self.safety_settings
if self.safety_settings
else None,
automatic_function_calling=types.AutomaticFunctionCallingConfig(
disable=True
),
),
)
logger.debug(f"result: {result}")
logger.debug(f"gemini result: {result}")
# Developer instruction is not enabled for models/gemini-2.0-flash-exp
if "Developer instruction is not enabled" in str(result):
logger.warning(
f"{self.get_model()} 不支持 system prompt, 已自动去除, 将会影响人格设置。"
)
logger.warning(f"{self.get_model()} 不支持 system prompt,已自动去除。")
system_instruction = ""
loop = True
# 不支持函数调用的模型SDK似乎会自动去除,保险起见不删除此行判断。
elif "Function calling is not enabled" in str(result):
logger.warning(
f"{self.get_model()} 不支持函数调用,已自动去除,不影响使用。"
)
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除。")
tool = None
loop = True
elif "Multi-modal output is not supported" in str(result):
logger.warning(
f"{self.get_model()} 不支持多模态输出,降级为文本模态重新请求。"
)
logger.warning(f"{self.get_model()} 不支持多模态输出,降级为文本模态。")
modalites = ["Text"]
loop = True
elif "candidates" not in result:
raise Exception("Gemini 返回异常结果: " + str(result))
candidates = result["candidates"][0]["content"]["parts"]
llm_response = LLMResponse("assistant")
chain = []
for candidate in candidates:
if "text" in candidate:
chain.append(Comp.Plain(candidate["text"]))
elif "functionCall" in candidate:
finish_reason = result.candidates[0].finish_reason
if finish_reason == types.FinishReason.SAFETY:
raise Exception("模型生成内容未通过用户定义的内容安全检查")
if finish_reason in {
types.FinishReason.PROHIBITED_CONTENT,
types.FinishReason.SPII,
types.FinishReason.BLOCKLIST,
types.FinishReason.IMAGE_SAFETY,
}:
raise Exception("模型生成内容违反Gemini平台政策")
if not result.candidates[0].content.parts:
raise Exception("API 返回的内容为空。")
for part in result.candidates[0].content.parts:
if part.text:
chain.append(Comp.Plain(part.text))
elif part.function_call:
llm_response.role = "tool"
llm_response.tools_call_args.append(candidate["functionCall"]["args"])
llm_response.tools_call_name.append(candidate["functionCall"]["name"])
llm_response.tools_call_ids.append(
candidate["functionCall"]["name"]
) # 没有 tool id
elif "inlineData" in candidate:
mime_type: str = candidate["inlineData"]["mimeType"]
if mime_type.startswith("image/"):
chain.append(Comp.Image.fromBase64(candidate["inlineData"]["data"]))
llm_response.tools_call_name.append(part.function_call.name)
llm_response.tools_call_args.append(part.function_call.args)
llm_response.tools_call_ids.append(part.function_call.id)
elif part.inline_data and part.inline_data.mime_type.startswith("image/"):
chain.append(Comp.Image.fromBytes(part.inline_data.data))
llm_response.result_chain = MessageChain(chain=chain)
return llm_response
async def text_chat(
@@ -320,7 +280,6 @@ class ProviderGoogleGenAI(Provider):
**kwargs,
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
@@ -345,7 +304,7 @@ class ProviderGoogleGenAI(Provider):
for i in range(retry):
try:
self.client.api_key = chosen_key
self.chosen_api_key = chosen_key
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
@@ -399,13 +358,13 @@ class ProviderGoogleGenAI(Provider):
yield llm_response
def get_current_key(self) -> str:
return self.client.api_key
return self.chosen_api_key
def get_keys(self) -> List[str]:
return self.api_keys
def set_key(self, key):
self.client.api_key = key
self.chosen_api_key = key
async def assemble_context(self, text: str, image_urls: List[str] = None):
"""
@@ -444,5 +403,4 @@ class ProviderGoogleGenAI(Provider):
return ""
async def terminate(self):
await self.client.client.close()
logger.info("Google GenAI 适配器已终止。")