refactor: 初步完成gemini_source的重写
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
import base64
|
||||
import aiohttp
|
||||
import json
|
||||
import random
|
||||
import asyncio
|
||||
from google import genai
|
||||
from google.genai import types, errors
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
@@ -10,112 +11,28 @@ from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider, Personality
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
|
||||
class SimpleGoogleGenAIClient:
|
||||
def __init__(self, api_key: str, api_base: str, timeout: int = 120) -> None:
|
||||
self.api_key = api_key
|
||||
if api_base.endswith("/"):
|
||||
self.api_base = api_base[:-1]
|
||||
else:
|
||||
self.api_base = api_base
|
||||
self.client = aiohttp.ClientSession(trust_env=True)
|
||||
self.timeout = timeout
|
||||
|
||||
async def models_list(self) -> List[str]:
|
||||
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
|
||||
async with self.client.get(request_url, timeout=self.timeout) as resp:
|
||||
response = await resp.json()
|
||||
|
||||
models = []
|
||||
for model in response["models"]:
|
||||
if "generateContent" in model["supportedGenerationMethods"]:
|
||||
models.append(model["name"].replace("models/", ""))
|
||||
return models
|
||||
|
||||
async def generate_content(
|
||||
self,
|
||||
contents: List[dict],
|
||||
model: str = "gemini-1.5-flash",
|
||||
system_instruction: str = "",
|
||||
tools: dict = None,
|
||||
modalities: List[str] = ["Text"],
|
||||
safety_settings: List[dict] = [],
|
||||
):
|
||||
payload = {}
|
||||
if system_instruction:
|
||||
payload["system_instruction"] = {"parts": {"text": system_instruction}}
|
||||
if tools:
|
||||
payload["tools"] = [tools]
|
||||
payload["contents"] = contents
|
||||
payload["generationConfig"] = {
|
||||
"responseModalities": modalities,
|
||||
}
|
||||
payload["safetySettings"] = [
|
||||
{"category": s["category"], "threshold": s["threshold"]}
|
||||
for s in safety_settings
|
||||
]
|
||||
logger.debug(f"payload: {payload}")
|
||||
request_url = (
|
||||
f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
|
||||
)
|
||||
async with self.client.post(
|
||||
request_url, json=payload, timeout=self.timeout
|
||||
) as resp:
|
||||
if "application/json" in resp.headers.get("Content-Type"):
|
||||
try:
|
||||
response = await resp.json()
|
||||
except Exception as e:
|
||||
text = await resp.text()
|
||||
logger.error(f"Gemini 返回了非 json 数据: {text}")
|
||||
raise e
|
||||
return response
|
||||
else:
|
||||
text = await resp.text()
|
||||
logger.error(f"Gemini 返回了非 json 数据: {text}")
|
||||
raise Exception("Gemini 返回了非 json 数据: ")
|
||||
|
||||
async def stream_generate_content(
|
||||
self,
|
||||
contents: List[dict],
|
||||
model: str = "gemini-1.5-flash",
|
||||
system_instruction: str = "",
|
||||
tools: dict = None,
|
||||
modalities: List[str] = ["Text"],
|
||||
safety_settings: List[dict] = [],
|
||||
):
|
||||
payload = {}
|
||||
if system_instruction:
|
||||
payload["system_instruction"] = {"parts": {"text": system_instruction}}
|
||||
if tools:
|
||||
payload["tools"] = [tools]
|
||||
payload["contents"] = contents
|
||||
payload["generationConfig"] = {
|
||||
"responseModalities": modalities,
|
||||
"stream": True,
|
||||
}
|
||||
payload["safetySettings"] = [
|
||||
{"category": s["category"], "threshold": s["threshold"]}
|
||||
for s in safety_settings
|
||||
]
|
||||
logger.debug(f"payload: {payload}")
|
||||
request_url = (
|
||||
f"{self.api_base}/v1beta/models/{model}:streamGenerateContent?key={self.api_key}"
|
||||
)
|
||||
async with self.client.post(
|
||||
request_url, json=payload, timeout=self.timeout
|
||||
) as resp:
|
||||
async for line in resp.content:
|
||||
if line:
|
||||
yield line
|
||||
|
||||
@register_provider_adapter(
|
||||
"googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器"
|
||||
)
|
||||
class ProviderGoogleGenAI(Provider):
|
||||
CATEGORY_MAPPING = {
|
||||
"harassment": types.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
"hate_speech": types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
"sexually_explicit": types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
"dangerous_content": types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
}
|
||||
|
||||
THRESHOLD_MAPPING = {
|
||||
"BLOCK_NONE": types.HarmBlockThreshold.BLOCK_NONE,
|
||||
"BLOCK_ONLY_HIGH": types.HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
"BLOCK_MEDIUM_AND_ABOVE": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
||||
"BLOCK_LOW_AND_ABOVE": types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
@@ -131,43 +48,145 @@ class ProviderGoogleGenAI(Provider):
|
||||
db_helper,
|
||||
default_persona,
|
||||
)
|
||||
self.chosen_api_key = None
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.timeout = provider_config.get("timeout", 180)
|
||||
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.timeout: int = provider_config.get("timeout", 180)
|
||||
self.api_base: str = provider_config.get("api_base", None)
|
||||
if self.api_base.endswith("/"):
|
||||
self.api_base = self.api_base[:-1]
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
self.client = SimpleGoogleGenAIClient(
|
||||
self.client = genai.Client(
|
||||
api_key=self.chosen_api_key,
|
||||
api_base=provider_config.get("api_base", None),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
http_options=types.HttpOptions(
|
||||
base_url=self.api_base,
|
||||
timeout=self.timeout * 1000, # 毫秒
|
||||
),
|
||||
).aio
|
||||
self.set_model(provider_config["model_config"]["model"])
|
||||
|
||||
safety_mapping = {
|
||||
"harassment": "HARM_CATEGORY_HARASSMENT",
|
||||
"hate_speech": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
}
|
||||
|
||||
self.safety_settings = []
|
||||
user_safety_config = self.provider_config.get("gm_safety_settings", {})
|
||||
for config_key, harm_category in safety_mapping.items():
|
||||
if threshold := user_safety_config.get(config_key):
|
||||
self.safety_settings.append(
|
||||
{"category": harm_category, "threshold": threshold}
|
||||
)
|
||||
self.safety_settings = [
|
||||
types.SafetySetting(
|
||||
category=harm_category, threshold=self.THRESHOLD_MAPPING[threshold_str]
|
||||
)
|
||||
for config_key, harm_category in self.CATEGORY_MAPPING.items()
|
||||
if (threshold_str := user_safety_config.get(config_key))
|
||||
and threshold_str in self.THRESHOLD_MAPPING
|
||||
]
|
||||
|
||||
async def get_models(self):
|
||||
return await self.client.models_list()
|
||||
try:
|
||||
models = await self.client.models.list()
|
||||
return [
|
||||
m.name.replace("models/", "")
|
||||
for m in models
|
||||
if "generateContent" in m.supported_actions
|
||||
]
|
||||
except errors.APIError as e:
|
||||
raise Exception(f"获取模型列表失败: {e}")
|
||||
|
||||
def _prepare_conversation(
|
||||
self,
|
||||
payloads: Dict,
|
||||
) -> List[types.Content]:
|
||||
"""准备 Gemini SDK 的 Content 列表"""
|
||||
gemini_contents = []
|
||||
for message in payloads["messages"]:
|
||||
role = message["role"]
|
||||
content = message.get("content")
|
||||
|
||||
if role == "user":
|
||||
if isinstance(content, str):
|
||||
if content:
|
||||
gemini_contents.append(
|
||||
types.UserContent(
|
||||
parts=[types.Part.from_text(text=content)]
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning("文本内容为空,已添加空格占位")
|
||||
gemini_contents.append(
|
||||
types.UserContent(parts=[types.Part.from_text(text=" ")])
|
||||
)
|
||||
|
||||
elif isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if item.get("type") == "text":
|
||||
text_content = item.get("text")
|
||||
if text_content:
|
||||
parts.append(types.Part.from_text(text=text_content))
|
||||
else:
|
||||
logger.warning("文本内容为空,已添加空格占位")
|
||||
parts.append(types.Part.from_text(text=" "))
|
||||
elif item.get("type") == "image_url":
|
||||
image_url_dict = item["image_url"]
|
||||
url = image_url_dict["url"]
|
||||
mime_part, base64_data = url.split(",", 1)
|
||||
mime_type = mime_part.split(":")[1].split(";")[0]
|
||||
image_bytes = base64.b64decode(base64_data)
|
||||
parts.append(
|
||||
types.Part.from_bytes(
|
||||
data=image_bytes, mime_type=mime_type
|
||||
)
|
||||
)
|
||||
gemini_contents.append(types.UserContent(parts=parts))
|
||||
|
||||
elif role == "assistant":
|
||||
if content:
|
||||
gemini_contents.append(
|
||||
types.ModelContent(
|
||||
parts=[types.Part.from_text(text=message["content"])]
|
||||
)
|
||||
)
|
||||
elif "tool_calls" in message:
|
||||
parts = [
|
||||
{
|
||||
"name": tool_call["function"]["name"],
|
||||
"args": json.loads(tool_call["function"]["arguments"]),
|
||||
}
|
||||
for tool_call in message["tool_calls"]
|
||||
]
|
||||
gemini_contents.append(
|
||||
types.ModelContent(parts=[types.Part.from_function_call(parts)])
|
||||
)
|
||||
else:
|
||||
logger.warning("assistant 角色的消息内容为空,已添加空格占位")
|
||||
gemini_contents.append(
|
||||
types.ModelContent(parts=[types.Part.from_text(text=" ")])
|
||||
)
|
||||
|
||||
elif role == "tool":
|
||||
gemini_contents.append(
|
||||
types.UserContent(
|
||||
parts=[
|
||||
types.Part.from_function_response(
|
||||
{
|
||||
"name": message["tool_call_id"],
|
||||
"response": {
|
||||
"name": message["tool_call_id"],
|
||||
"content": message["content"],
|
||||
},
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"gemini_contents: {gemini_contents}")
|
||||
|
||||
return gemini_contents
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
tool = None
|
||||
"""非流式请求 Gemini API"""
|
||||
if tools:
|
||||
tool = tools.get_func_desc_google_genai_style()
|
||||
if not tool:
|
||||
tool = None
|
||||
t = tools.get_func_desc_google_genai_style()
|
||||
tool = (
|
||||
types.Tool(function_declarations=t["function_declarations"])
|
||||
if t
|
||||
else None
|
||||
)
|
||||
|
||||
system_instruction = ""
|
||||
for message in payloads["messages"]:
|
||||
@@ -175,137 +194,78 @@ class ProviderGoogleGenAI(Provider):
|
||||
system_instruction = message["content"]
|
||||
break
|
||||
|
||||
google_genai_conversation = []
|
||||
for message in payloads["messages"]:
|
||||
if message["role"] == "user":
|
||||
if isinstance(message["content"], str):
|
||||
if not message["content"]:
|
||||
message["content"] = ""
|
||||
|
||||
google_genai_conversation.append(
|
||||
{"role": "user", "parts": [{"text": message["content"]}]}
|
||||
)
|
||||
elif isinstance(message["content"], list):
|
||||
# images
|
||||
parts = []
|
||||
for part in message["content"]:
|
||||
if part["type"] == "text":
|
||||
if not part["text"]:
|
||||
part["text"] = ""
|
||||
parts.append({"text": part["text"]})
|
||||
elif part["type"] == "image_url":
|
||||
parts.append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mime_type": "image/jpeg",
|
||||
"data": part["image_url"]["url"].replace(
|
||||
"data:image/jpeg;base64,", ""
|
||||
), # base64
|
||||
}
|
||||
}
|
||||
)
|
||||
google_genai_conversation.append({"role": "user", "parts": parts})
|
||||
|
||||
elif message["role"] == "assistant":
|
||||
if "content" in message:
|
||||
if not message["content"]:
|
||||
message["content"] = ""
|
||||
google_genai_conversation.append(
|
||||
{"role": "model", "parts": [{"text": message["content"]}]}
|
||||
)
|
||||
elif "tool_calls" in message:
|
||||
# tool calls in the last turn
|
||||
parts = []
|
||||
for tool_call in message["tool_calls"]:
|
||||
parts.append(
|
||||
{
|
||||
"functionCall": {
|
||||
"name": tool_call["function"]["name"],
|
||||
"args": json.loads(
|
||||
tool_call["function"]["arguments"]
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
google_genai_conversation.append({"role": "model", "parts": parts})
|
||||
elif message["role"] == "tool":
|
||||
parts = []
|
||||
parts.append(
|
||||
{
|
||||
"functionResponse": {
|
||||
"name": message["tool_call_id"],
|
||||
"response": {
|
||||
"name": message["tool_call_id"],
|
||||
"content": message["content"],
|
||||
},
|
||||
}
|
||||
}
|
||||
)
|
||||
google_genai_conversation.append({"role": "user", "parts": parts})
|
||||
|
||||
logger.debug(f"google_genai_conversation: {google_genai_conversation}")
|
||||
conversation = self._prepare_conversation(payloads)
|
||||
|
||||
modalites = ["Text"]
|
||||
if self.provider_config.get("gm_resp_image_modal", False):
|
||||
modalites.append("Image")
|
||||
|
||||
loop = True
|
||||
|
||||
while loop:
|
||||
loop = False
|
||||
result = await self.client.generate_content(
|
||||
contents=google_genai_conversation,
|
||||
result = await self.client.models.generate_content(
|
||||
model=self.get_model(),
|
||||
system_instruction=system_instruction,
|
||||
tools=tool,
|
||||
modalities=modalites,
|
||||
safety_settings=self.safety_settings,
|
||||
contents=conversation,
|
||||
config=types.GenerateContentConfig(
|
||||
system_instruction=system_instruction,
|
||||
tools=[tool] if tool else None,
|
||||
safety_settings=self.safety_settings
|
||||
if self.safety_settings
|
||||
else None,
|
||||
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
||||
disable=True
|
||||
),
|
||||
),
|
||||
)
|
||||
logger.debug(f"result: {result}")
|
||||
logger.debug(f"gemini result: {result}")
|
||||
|
||||
# Developer instruction is not enabled for models/gemini-2.0-flash-exp
|
||||
if "Developer instruction is not enabled" in str(result):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt, 已自动去除, 将会影响人格设置。"
|
||||
)
|
||||
logger.warning(f"{self.get_model()} 不支持 system prompt,已自动去除。")
|
||||
system_instruction = ""
|
||||
loop = True
|
||||
|
||||
# 不支持函数调用的模型SDK似乎会自动去除,保险起见不删除此行判断。
|
||||
elif "Function calling is not enabled" in str(result):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持函数调用,已自动去除,不影响使用。"
|
||||
)
|
||||
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除。")
|
||||
tool = None
|
||||
loop = True
|
||||
|
||||
elif "Multi-modal output is not supported" in str(result):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持多模态输出,降级为文本模态重新请求。"
|
||||
)
|
||||
logger.warning(f"{self.get_model()} 不支持多模态输出,降级为文本模态。")
|
||||
modalites = ["Text"]
|
||||
loop = True
|
||||
|
||||
elif "candidates" not in result:
|
||||
raise Exception("Gemini 返回异常结果: " + str(result))
|
||||
|
||||
candidates = result["candidates"][0]["content"]["parts"]
|
||||
llm_response = LLMResponse("assistant")
|
||||
chain = []
|
||||
for candidate in candidates:
|
||||
if "text" in candidate:
|
||||
chain.append(Comp.Plain(candidate["text"]))
|
||||
elif "functionCall" in candidate:
|
||||
|
||||
finish_reason = result.candidates[0].finish_reason
|
||||
|
||||
if finish_reason == types.FinishReason.SAFETY:
|
||||
raise Exception("模型生成内容未通过用户定义的内容安全检查")
|
||||
|
||||
if finish_reason in {
|
||||
types.FinishReason.PROHIBITED_CONTENT,
|
||||
types.FinishReason.SPII,
|
||||
types.FinishReason.BLOCKLIST,
|
||||
types.FinishReason.IMAGE_SAFETY,
|
||||
}:
|
||||
raise Exception("模型生成内容违反Gemini平台政策")
|
||||
|
||||
if not result.candidates[0].content.parts:
|
||||
raise Exception("API 返回的内容为空。")
|
||||
|
||||
for part in result.candidates[0].content.parts:
|
||||
if part.text:
|
||||
chain.append(Comp.Plain(part.text))
|
||||
elif part.function_call:
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_args.append(candidate["functionCall"]["args"])
|
||||
llm_response.tools_call_name.append(candidate["functionCall"]["name"])
|
||||
llm_response.tools_call_ids.append(
|
||||
candidate["functionCall"]["name"]
|
||||
) # 没有 tool id
|
||||
elif "inlineData" in candidate:
|
||||
mime_type: str = candidate["inlineData"]["mimeType"]
|
||||
if mime_type.startswith("image/"):
|
||||
chain.append(Comp.Image.fromBase64(candidate["inlineData"]["data"]))
|
||||
llm_response.tools_call_name.append(part.function_call.name)
|
||||
llm_response.tools_call_args.append(part.function_call.args)
|
||||
llm_response.tools_call_ids.append(part.function_call.id)
|
||||
elif part.inline_data and part.inline_data.mime_type.startswith("image/"):
|
||||
chain.append(Comp.Image.fromBytes(part.inline_data.data))
|
||||
|
||||
llm_response.result_chain = MessageChain(chain=chain)
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat(
|
||||
@@ -320,7 +280,6 @@ class ProviderGoogleGenAI(Provider):
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
@@ -345,7 +304,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
for i in range(retry):
|
||||
try:
|
||||
self.client.api_key = chosen_key
|
||||
self.chosen_api_key = chosen_key
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
@@ -399,13 +358,13 @@ class ProviderGoogleGenAI(Provider):
|
||||
yield llm_response
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return self.client.api_key
|
||||
return self.chosen_api_key
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
return self.api_keys
|
||||
|
||||
def set_key(self, key):
|
||||
self.client.api_key = key
|
||||
self.chosen_api_key = key
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
"""
|
||||
@@ -444,5 +403,4 @@ class ProviderGoogleGenAI(Provider):
|
||||
return ""
|
||||
|
||||
async def terminate(self):
|
||||
await self.client.client.close()
|
||||
logger.info("Google GenAI 适配器已终止。")
|
||||
|
||||
Reference in New Issue
Block a user