Merge branch 'master' into better-stream

This commit is contained in:
渡鸦95676
2025-04-15 21:22:08 +08:00
committed by GitHub
19 changed files with 604 additions and 340 deletions
+4 -3
View File
@@ -13,10 +13,11 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot?style=for-the-badge&color=76bad9)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B4%BB%E8%B7%83%E9%87%8F&cacheSeconds=10800&style=for-the-badge&color=3b618e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=7200)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B4%BB%E8%B7%83%E9%87%8F&cacheSeconds=3600&style=for-the-badge&color=3b618e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
+4 -1
View File
@@ -23,7 +23,10 @@ db_helper = SQLiteDatabase(DB_PATH)
sp = (
SharedPreferences()
) # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
pip_installer = PipInstaller(astrbot_config.get("pip_install_arg", ""))
pip_installer = PipInstaller(
astrbot_config.get("pip_install_arg", ""),
astrbot_config.get("pypi_index_url", None),
)
web_chat_queue = asyncio.Queue(maxsize=32)
web_chat_back_queue = asyncio.Queue(maxsize=32)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
+11 -10
View File
@@ -98,7 +98,7 @@ DEFAULT_CONFIG = {
"wake_prefix": ["/"],
"log_level": "INFO",
"pip_install_arg": "",
"plugin_repo_mirror": "",
"pypi_index_url": "https://mirrors.aliyun.com/pypi/simple/",
"knowledge_db": {},
"persona": [],
"timezone": "",
@@ -529,6 +529,7 @@ CONFIG_METADATA_2 = {
"model": "gemini-2.0-flash-exp",
},
"gm_resp_image_modal": False,
"gm_native_coderunner": False,
"gm_safety_settings": {
"harassment": "BLOCK_MEDIUM_AND_ABOVE",
"hate_speech": "BLOCK_MEDIUM_AND_ABOVE",
@@ -705,6 +706,12 @@ CONFIG_METADATA_2 = {
"type": "bool",
"hint": "启用后,将支持返回图片内容。需要模型支持,否则会报错。具体支持模型请查看 Google Gemini 官方网站。温馨提示,如果您需要生成图片,请关闭 `启用群员识别` 配置获得更好的效果。",
},
"gm_native_coderunner": {
"description": "启用原生代码执行器",
"type": "bool",
"hint": "启用后所有函数工具将全部失效",
"obvious_hint": True,
},
"gm_safety_settings": {
"description": "安全过滤器",
"type": "object",
@@ -1222,16 +1229,10 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "安装插件依赖时,会使用 Python 的 pip 工具。这里可以填写额外的参数,如 `--break-system-package` 等。",
},
"plugin_repo_mirror": {
"description": "件仓库镜像",
"pypi_index_url": {
"description": "PyPI 软件仓库地址",
"type": "string",
"hint": "已废弃,请使用管理面板->设置页的代理地址选择",
"obvious_hint": True,
"options": [
"default",
"https://ghp.ci/",
"https://github-mirror.us.kg/",
],
"hint": "安装 Python 依赖时请求的 PyPI 软件仓库地址。默认为 https://mirrors.aliyun.com/pypi/simple/",
},
},
},
+1 -1
View File
@@ -106,7 +106,7 @@ class AstrBotCoreLifecycle:
await self.pipeline_scheduler.initialize()
# 初始化更新器
self.astrbot_updator = AstrBotUpdator(self.astrbot_config["plugin_repo_mirror"])
self.astrbot_updator = AstrBotUpdator()
# 初始化事件总线
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
+8 -1
View File
@@ -1,6 +1,8 @@
import abc
import asyncio
import re
import hashlib
import uuid
from dataclasses import dataclass
from typing import List, Union, Optional, AsyncGenerator
@@ -402,8 +404,13 @@ class AstrMessageEvent(abc.ABC):
Args:
message (MessageChain): 消息链,具体使用方式请参考文档。
"""
# Leverage BLAKE2 hash function to generate a non-reversible hash of the sender ID for privacy.
hash_obj = hashlib.blake2b(self.get_sender_id().encode("utf-8"), digest_size=16)
sid = str(uuid.UUID(bytes=hash_obj.digest()))
asyncio.create_task(
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
Metric.upload(
msg_event_tick=1, adapter_name=self.platform_meta.name, sid=sid
)
)
self._has_send_oper = True
@@ -30,7 +30,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
# convert to base64
bs64 = await segment.convert_to_base64()
d["data"] = {
"file": bs64,
"file": f"base64://{bs64}",
}
elif isinstance(segment, At):
d["data"] = {
@@ -1,4 +1,5 @@
import asyncio
import re
import sys
import uuid
@@ -118,8 +119,6 @@ class TelegramPlatformAdapter(Platform):
if commands:
await self.client.set_my_commands(commands)
for cmd in commands:
logger.debug(f"已注册指令: /{cmd.command} - {cmd.description}")
except Exception as e:
logger.error(f"向 Telegram 注册指令时发生错误: {e!s}")
@@ -167,6 +166,10 @@ class TelegramPlatformAdapter(Platform):
if not cmd_name or cmd_name in skip_commands:
return None
if not re.match(r"^[a-z0-9_]+$", cmd_name) or len(cmd_name) > 32:
logger.warning(f"跳过无法注册的命令: {cmd_name}")
return None
# Build description.
description = handler_metadata.desc or (
f"指令组: {cmd_name} (包含多个子指令)" if is_group else f"指令: {cmd_name}"
+1 -1
View File
@@ -155,7 +155,7 @@ class ProviderRequest:
if self.image_urls:
user_content = {
"role": "user",
"content": [{"type": "text", "text": self.prompt}],
"content": [{"type": "text", "text": self.prompt if self.prompt else "[图片]"}],
}
for image_url in self.image_urls:
if image_url.startswith("http"):
+417 -308
View File
@@ -1,121 +1,54 @@
import base64
import aiohttp
import json
import random
import asyncio
import base64
import json
import logging
import random
from typing import Dict, List, Optional, AsyncGenerator
from google import genai
from google.genai import types
from google.genai.errors import APIError
import astrbot.core.message.components as Comp
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider, Personality
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from ..register import register_provider_adapter
from astrbot.api.provider import Personality, Provider
from astrbot.core.db import BaseDatabase
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.utils.io import download_image_by_url
from ..register import register_provider_adapter
class SimpleGoogleGenAIClient:
def __init__(self, api_key: str, api_base: str, timeout: int = 120) -> None:
self.api_key = api_key
if api_base.endswith("/"):
self.api_base = api_base[:-1]
else:
self.api_base = api_base
self.client = aiohttp.ClientSession(trust_env=True)
self.timeout = timeout
class SuppressNonTextPartsWarning(logging.Filter):
"""过滤 Gemini SDK 中的非文本部分警告"""
async def models_list(self) -> List[str]:
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
async with self.client.get(request_url, timeout=self.timeout) as resp:
response = await resp.json()
def filter(self, record):
return "there are non-text parts in the response" not in record.getMessage()
models = []
for model in response["models"]:
if "generateContent" in model["supportedGenerationMethods"]:
models.append(model["name"].replace("models/", ""))
return models
async def generate_content(
self,
contents: List[dict],
model: str = "gemini-1.5-flash",
system_instruction: str = "",
tools: dict = None,
modalities: List[str] = ["Text"],
safety_settings: List[dict] = [],
):
payload = {}
if system_instruction:
payload["system_instruction"] = {"parts": {"text": system_instruction}}
if tools:
payload["tools"] = [tools]
payload["contents"] = contents
payload["generationConfig"] = {
"responseModalities": modalities,
}
payload["safetySettings"] = [
{"category": s["category"], "threshold": s["threshold"]}
for s in safety_settings
]
logger.debug(f"payload: {payload}")
request_url = (
f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
)
async with self.client.post(
request_url, json=payload, timeout=self.timeout
) as resp:
if "application/json" in resp.headers.get("Content-Type"):
try:
response = await resp.json()
except Exception as e:
text = await resp.text()
logger.error(f"Gemini 返回了非 json 数据: {text}")
raise e
return response
else:
text = await resp.text()
logger.error(f"Gemini 返回了非 json 数据: {text}")
raise Exception("Gemini 返回了非 json 数据: ")
logging.getLogger("google_genai.types").addFilter(SuppressNonTextPartsWarning())
async def stream_generate_content(
self,
contents: List[dict],
model: str = "gemini-1.5-flash",
system_instruction: str = "",
tools: dict = None,
modalities: List[str] = ["Text"],
safety_settings: List[dict] = [],
):
payload = {}
if system_instruction:
payload["system_instruction"] = {"parts": {"text": system_instruction}}
if tools:
payload["tools"] = [tools]
payload["contents"] = contents
payload["generationConfig"] = {
"responseModalities": modalities,
"stream": True,
}
payload["safetySettings"] = [
{"category": s["category"], "threshold": s["threshold"]}
for s in safety_settings
]
logger.debug(f"payload: {payload}")
request_url = (
f"{self.api_base}/v1beta/models/{model}:streamGenerateContent?key={self.api_key}"
)
async with self.client.post(
request_url, json=payload, timeout=self.timeout
) as resp:
async for line in resp.content:
if line:
yield line
@register_provider_adapter(
"googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器"
)
class ProviderGoogleGenAI(Provider):
CATEGORY_MAPPING = {
"harassment": types.HarmCategory.HARM_CATEGORY_HARASSMENT,
"hate_speech": types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
"sexually_explicit": types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
"dangerous_content": types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
}
THRESHOLD_MAPPING = {
"BLOCK_NONE": types.HarmBlockThreshold.BLOCK_NONE,
"BLOCK_ONLY_HIGH": types.HarmBlockThreshold.BLOCK_ONLY_HIGH,
"BLOCK_MEDIUM_AND_ABOVE": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
"BLOCK_LOW_AND_ABOVE": types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
}
def __init__(
self,
provider_config: dict,
@@ -131,183 +64,351 @@ class ProviderGoogleGenAI(Provider):
db_helper,
default_persona,
)
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.timeout = provider_config.get("timeout", 180)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.client = SimpleGoogleGenAIClient(
api_key=self.chosen_api_key,
api_base=provider_config.get("api_base", None),
timeout=self.timeout,
)
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None
self.timeout: int = int(provider_config.get("timeout", 180))
self.api_base: Optional[str] = provider_config.get("api_base", None)
if self.api_base and self.api_base.endswith("/"):
self.api_base = self.api_base[:-1]
self._init_client()
self.set_model(provider_config["model_config"]["model"])
self._init_safety_settings()
safety_mapping = {
"harassment": "HARM_CATEGORY_HARASSMENT",
"hate_speech": "HARM_CATEGORY_HATE_SPEECH",
"sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT",
}
def _init_client(self) -> None:
"""初始化Gemini客户端"""
self.client = genai.Client(
api_key=self.chosen_api_key,
http_options=types.HttpOptions(
base_url=self.api_base,
timeout=self.timeout * 1000, # 毫秒
),
).aio
self.safety_settings = []
def _init_safety_settings(self) -> None:
"""初始化安全设置"""
user_safety_config = self.provider_config.get("gm_safety_settings", {})
for config_key, harm_category in safety_mapping.items():
if threshold := user_safety_config.get(config_key):
self.safety_settings.append(
{"category": harm_category, "threshold": threshold}
self.safety_settings = [
types.SafetySetting(
category=harm_category, threshold=self.THRESHOLD_MAPPING[threshold_str]
)
for config_key, harm_category in self.CATEGORY_MAPPING.items()
if (threshold_str := user_safety_config.get(config_key))
and threshold_str in self.THRESHOLD_MAPPING
]
async def _handle_api_error(self, e: APIError, keys: List[str]) -> bool:
"""处理API错误,返回是否需要重试"""
if e.code == 429 or "API key not valid" in e.message:
keys.remove(self.chosen_api_key)
if len(keys) > 0:
self.set_key(random.choice(keys))
logger.info(
f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..."
)
await asyncio.sleep(1)
return True
else:
logger.error(
f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..."
)
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
else:
logger.error(
f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}"
)
raise e
async def get_models(self):
return await self.client.models_list()
async def _prepare_query_config(
self,
tools: Optional[FuncCall] = None,
system_instruction: Optional[str] = None,
temperature: Optional[float] = 0.7,
modalities: Optional[List[str]] = None,
) -> types.GenerateContentConfig:
"""准备查询配置"""
if not modalities:
modalities = ["Text"]
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
tool = None
if tools:
tool = tools.get_func_desc_google_genai_style()
if not tool:
tool = None
# 流式输出不支持图片模态
if (
self.provider_settings.get("streaming_response", False)
and "Image" in modalities
):
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
modalities = ["Text"]
system_instruction = ""
tool_list = None
if self.provider_config.get("gm_native_coderunner", False):
if tools:
logger.warning("Gemini原生代码执行器已启用,函数工具将被忽略")
tool_list = [types.Tool(code_execution=types.ToolCodeExecution())]
elif tools and (func_desc := tools.get_func_desc_google_genai_style()):
tool_list = [
types.Tool(function_declarations=func_desc["function_declarations"])
]
return types.GenerateContentConfig(
system_instruction=system_instruction,
temperature=temperature,
response_modalities=modalities,
tools=tool_list,
safety_settings=self.safety_settings if self.safety_settings else None,
automatic_function_calling=types.AutomaticFunctionCallingConfig(
disable=True
),
)
@staticmethod
def _prepare_conversation(payloads: Dict) -> List[types.Content]:
"""准备 Gemini SDK 的 Content 列表"""
def create_text_part(text: str) -> types.UserContent:
content_a = text if text else " "
if not text:
logger.warning("文本内容为空,已添加空格占位")
return types.UserContent(parts=[types.Part.from_text(text=content_a)])
def process_image_url(image_url_dict: dict) -> types.Part:
url = image_url_dict["url"]
mime_type = url.split(":")[1].split(";")[0]
image_bytes = base64.b64decode(url.split(",", 1)[1])
return types.Part.from_bytes(data=image_bytes, mime_type=mime_type)
gemini_contents: List[types.Content] = []
for message in payloads["messages"]:
if message["role"] == "system":
system_instruction = message["content"]
break
role, content = message["role"], message.get("content")
google_genai_conversation = []
for message in payloads["messages"]:
if message["role"] == "user":
if isinstance(message["content"], str):
if not message["content"]:
message["content"] = " "
if role == "user":
if isinstance(content, str):
gemini_contents.append(create_text_part(content))
elif isinstance(content, list):
parts = [
types.Part.from_text(text=item["text"] or " ")
if item["type"] == "text"
else process_image_url(item["image_url"])
for item in content
]
gemini_contents.append(types.UserContent(parts=parts))
google_genai_conversation.append(
{"role": "user", "parts": [{"text": message["content"]}]}
)
elif isinstance(message["content"], list):
# images
parts = []
for part in message["content"]:
if part["type"] == "text":
if not part["text"]:
part["text"] = ""
parts.append({"text": part["text"]})
elif part["type"] == "image_url":
parts.append(
{
"inline_data": {
"mime_type": "image/jpeg",
"data": part["image_url"]["url"].replace(
"data:image/jpeg;base64,", ""
), # base64
}
}
)
google_genai_conversation.append({"role": "user", "parts": parts})
elif message["role"] == "assistant":
if "content" in message:
if not message["content"]:
message["content"] = " "
google_genai_conversation.append(
{"role": "model", "parts": [{"text": message["content"]}]}
elif role == "assistant":
if content:
gemini_contents.append(
types.ModelContent(parts=[types.Part.from_text(text=content)])
)
elif "tool_calls" in message:
# tool calls in the last turn
parts = []
for tool_call in message["tool_calls"]:
parts.append(
{
"functionCall": {
"name": tool_call["function"]["name"],
"args": json.loads(
tool_call["function"]["arguments"]
),
}
}
)
google_genai_conversation.append({"role": "model", "parts": parts})
elif message["role"] == "tool":
parts = []
parts.append(
{
"functionResponse": {
"name": message["tool_call_id"],
"response": {
"name": message["tool_call_id"],
"content": message["content"],
},
}
}
gemini_contents.extend(
[
types.ModelContent(
parts=[
types.Part.from_function_call(
name=tool["function"]["name"],
args=json.loads(tool["function"]["arguments"]),
)
]
)
for tool in message["tool_calls"]
]
)
else:
logger.warning("assistant 角色的消息内容为空,已添加空格占位")
gemini_contents.append(
types.ModelContent(parts=[types.Part.from_text(text=" ")])
)
elif role == "tool":
gemini_contents.append(
types.UserContent(
parts=[
types.Part.from_function_response(
name=message["tool_call_id"],
response={
"name": message["tool_call_id"],
"content": message["content"],
},
)
]
)
)
google_genai_conversation.append({"role": "user", "parts": parts})
logger.debug(f"google_genai_conversation: {google_genai_conversation}")
return gemini_contents
modalites = ["Text"]
if self.provider_config.get("gm_resp_image_modal", False):
modalites.append("Image")
@staticmethod
def _process_content_parts(
result: types.GenerateContentResponse, llm_response: LLMResponse
) -> MessageChain:
"""处理内容部分并构建消息链"""
finish_reason = result.candidates[0].finish_reason
result_parts: Optional[types.Part] = result.candidates[0].content.parts
loop = True
while loop:
loop = False
result = await self.client.generate_content(
contents=google_genai_conversation,
model=self.get_model(),
system_instruction=system_instruction,
tools=tool,
modalities=modalites,
safety_settings=self.safety_settings,
)
logger.debug(f"result: {result}")
if finish_reason == types.FinishReason.SAFETY:
raise Exception("模型生成内容未通过用户定义的内容安全检查")
# Developer instruction is not enabled for models/gemini-2.0-flash-exp
if "Developer instruction is not enabled" in str(result):
logger.warning(
f"{self.get_model()} 不支持 system prompt, 已自动去除, 将会影响人格设置。"
)
system_instruction = ""
loop = True
if finish_reason in {
types.FinishReason.PROHIBITED_CONTENT,
types.FinishReason.SPII,
types.FinishReason.BLOCKLIST,
types.FinishReason.IMAGE_SAFETY,
}:
raise Exception("模型生成内容违反Gemini平台政策")
elif "Function calling is not enabled" in str(result):
logger.warning(
f"{self.get_model()} 不支持函数调用,已自动去除,不影响使用"
)
tool = None
loop = True
if not result_parts:
logger.debug(result.candidates)
raise Exception("API 返回的内容为空")
elif "Multi-modal output is not supported" in str(result):
logger.warning(
f"{self.get_model()} 不支持多模态输出,降级为文本模态重新请求。"
)
modalites = ["Text"]
loop = True
elif "candidates" not in result:
raise Exception("Gemini 返回异常结果: " + str(result))
candidates = result["candidates"][0]["content"]["parts"]
llm_response = LLMResponse("assistant")
chain = []
for candidate in candidates:
if "text" in candidate:
chain.append(Comp.Plain(candidate["text"]))
elif "functionCall" in candidate:
llm_response.role = "tool"
llm_response.tools_call_args.append(candidate["functionCall"]["args"])
llm_response.tools_call_name.append(candidate["functionCall"]["name"])
llm_response.tools_call_ids.append(
candidate["functionCall"]["name"]
) # 没有 tool id
elif "inlineData" in candidate:
mime_type: str = candidate["inlineData"]["mimeType"]
if mime_type.startswith("image/"):
chain.append(Comp.Image.fromBase64(candidate["inlineData"]["data"]))
part: types.Part
llm_response.result_chain = MessageChain(chain=chain)
# 暂时这样Fallback
if all(
part.inline_data and part.inline_data.mime_type.startswith("image/")
for part in result_parts
):
chain.append(Comp.Plain("这是图片"))
for part in result_parts:
if part.text:
chain.append(Comp.Plain(part.text))
elif part.function_call:
llm_response.role = "tool"
llm_response.tools_call_name.append(part.function_call.name)
llm_response.tools_call_args.append(part.function_call.args)
# gemini 返回的 function_call.id 可能为 None
llm_response.tools_call_ids.append(
part.function_call.id or part.function_call.name
)
elif part.inline_data and part.inline_data.mime_type.startswith("image/"):
chain.append(Comp.Image.fromBytes(part.inline_data.data))
return MessageChain(chain=chain)
async def _query(
self, payloads: dict, tools: FuncCall, temperature: float = 0.7
) -> LLMResponse:
"""非流式请求 Gemini API"""
system_instruction = next(
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
None,
)
modalities = ["Text"]
if self.provider_config.get("gm_resp_image_modal", False):
modalities.append("Image")
conversation = self._prepare_conversation(payloads)
result: Optional[types.GenerateContentResponse] = None
while True:
try:
config = await self._prepare_query_config(
tools, system_instruction, temperature, modalities
)
result = await self.client.models.generate_content(
model=self.get_model(),
contents=conversation,
config=config,
)
if result.candidates[0].finish_reason == types.FinishReason.RECITATION:
if temperature > 2:
raise Exception("温度参数已超过最大值2,仍然发生recitation")
temperature += 0.2
logger.warning(
f"发生了recitation,正在提高温度至{temperature:.1f}重试..."
)
continue
break
except APIError as e:
if "Developer instruction is not enabled" in e.message:
logger.warning(
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
)
system_instruction = None
elif "Function calling is not enabled" in e.message:
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
tools = None
elif (
"Multi-modal output is not supported" in e.message
or "Model does not support the requested response modalities"
in e.message
):
logger.warning(
f"{self.get_model()} 不支持多模态输出,降级为文本模态"
)
modalities = ["Text"]
else:
raise
continue
llm_response = LLMResponse("assistant")
llm_response.result_chain = self._process_content_parts(result, llm_response)
return llm_response
async def _query_stream(
self, payloads: dict, tools: FuncCall, temperature: float = 0.7
) -> AsyncGenerator[LLMResponse, None]:
"""流式请求 Gemini API"""
system_instruction = next(
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
None,
)
conversation = self._prepare_conversation(payloads)
result = None
while True:
try:
config = await self._prepare_query_config(
tools, system_instruction, temperature
)
result = await self.client.models.generate_content_stream(
model=self.get_model(),
contents=conversation,
config=config,
)
break
except APIError as e:
if "Developer instruction is not enabled" in e.message:
logger.warning(
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
)
system_instruction = None
elif "Function calling is not enabled" in e.message:
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
tools = None
else:
raise
continue
async for chunk in result:
llm_response = LLMResponse("assistant", is_chunk=True)
if chunk.candidates[0].content.parts and any(
part.function_call for part in chunk.candidates[0].content.parts
):
llm_response = LLMResponse("assistant", is_chunk=False)
llm_response.result_chain = self._process_content_parts(
chunk, llm_response
)
yield llm_response
break
if chunk.text:
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
yield llm_response
if chunk.candidates[0].finish_reason:
llm_response = LLMResponse("assistant", is_chunk=False)
if not chunk.candidates[0].content.parts:
llm_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])
else:
llm_response.result_chain = self._process_content_parts(
chunk, llm_response
)
yield llm_response
break
async def text_chat(
self,
prompt: str,
@@ -320,7 +421,6 @@ class ProviderGoogleGenAI(Provider):
**kwargs,
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
@@ -337,82 +437,92 @@ class ProviderGoogleGenAI(Provider):
model_config["model"] = self.get_model()
payloads = {"messages": context_query, **model_config}
llm_response = None
retry = 10
keys = self.api_keys.copy()
chosen_key = random.choice(keys)
temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7
for i in range(retry):
for _ in range(retry):
try:
self.client.api_key = chosen_key
llm_response = await self._query(payloads, func_tool)
return await self._query(payloads, func_tool, temp)
except APIError as e:
if await self._handle_api_error(e, keys):
continue
break
except Exception as e:
if "429" in str(e) or "API key not valid" in str(e):
keys.remove(chosen_key)
if len(keys) > 0:
chosen_key = random.choice(keys)
logger.info(
f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {chosen_key[:12]}..."
)
await asyncio.sleep(1)
continue
else:
logger.error(
f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {chosen_key[:12]}..."
)
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
else:
logger.error(
f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}"
)
raise e
return llm_response
async def text_chat_stream(
self,
prompt,
session_id=None,
image_urls=...,
func_tool=None,
contexts=...,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts=[],
system_prompt=None,
tool_calls_result=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
) -> AsyncGenerator[LLMResponse, None]:
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if "_no_save" in part:
del part["_no_save"]
# tool calls result
if tool_calls_result:
context_query.extend(tool_calls_result.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model()
payloads = {"messages": context_query, **model_config}
retry = 10
keys = self.api_keys.copy()
temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7
for _ in range(retry):
try:
async for response in self._query_stream(payloads, func_tool, temp):
yield response
break
except APIError as e:
if await self._handle_api_error(e, keys):
continue
break
async def get_models(self):
try:
models = await self.client.models.list()
return [
m.name.replace("models/", "")
for m in models
if "generateContent" in m.supported_actions
]
except APIError as e:
raise Exception(f"获取模型列表失败: {e.message}")
def get_current_key(self) -> str:
return self.client.api_key
return self.chosen_api_key
def get_keys(self) -> List[str]:
return self.api_keys
def set_key(self, key):
self.client.api_key = key
self.chosen_api_key = key
self._init_client()
async def assemble_context(self, text: str, image_urls: List[str] = None):
"""
组装上下文。
"""
if image_urls:
user_content = {"role": "user", "content": [{"type": "text", "text": text}]}
user_content = {
"role": "user",
"content": [{"type": "text", "text": text if text else "[图片]"}],
}
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
@@ -444,5 +554,4 @@ class ProviderGoogleGenAI(Provider):
return ""
async def terminate(self):
await self.client.client.close()
logger.info("Google GenAI 适配器已终止。")
@@ -505,7 +505,7 @@ class ProviderOpenAIOfficial(Provider):
async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict:
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
if image_urls:
user_content = {"role": "user", "content": [{"type": "text", "text": text}]}
user_content = {"role": "user", "content": [{"type": "text", "text": text if text else "[图片]"}]}
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
+1 -1
View File
@@ -28,7 +28,7 @@ from .filter.permission import PermissionTypeFilter, PermissionType
class PluginManager:
def __init__(self, context: Context, config: AstrBotConfig):
self.updator = PluginUpdator(config["plugin_repo_mirror"])
self.updator = PluginUpdator()
self.context = context
self.context._star_manager = self
+41
View File
@@ -1,10 +1,42 @@
import aiohttp
import sys
import os
import socket
import uuid
from astrbot.core.config import VERSION
from astrbot.core import db_helper, logger
class Metric:
_iid_cache = None
@staticmethod
def get_installation_id():
"""获取或创建一个唯一的安装ID"""
if Metric._iid_cache is not None:
return Metric._iid_cache
config_dir = os.path.join(os.path.expanduser("~"), ".astrbot")
id_file = os.path.join(config_dir, ".installation_id")
if os.path.exists(id_file):
try:
with open(id_file, "r") as f:
Metric._iid_cache = f.read().strip()
return Metric._iid_cache
except Exception:
pass
try:
os.makedirs(config_dir, exist_ok=True)
installation_id = str(uuid.uuid4())
with open(id_file, "w") as f:
f.write(installation_id)
Metric._iid_cache = installation_id
return installation_id
except Exception:
Metric._iid_cache = "null"
return "null"
@staticmethod
async def upload(**kwargs):
"""
@@ -16,6 +48,14 @@ class Metric:
kwargs["v"] = VERSION
kwargs["os"] = sys.platform
payload = {"metrics_data": kwargs}
try:
kwargs["hn"] = socket.gethostname()
except Exception:
pass
try:
kwargs["iid"] = Metric.get_installation_id()
except Exception:
pass
try:
if "adapter_name" in kwargs:
db_helper.insert_platform_metrics({kwargs["adapter_name"]: 1})
@@ -24,6 +64,7 @@ class Metric:
except Exception as e:
logger.error(f"保存指标到数据库失败: {e}")
pass
print(f"上传指标: {payload}")
try:
async with aiohttp.ClientSession(trust_env=True) as session:
+4 -4
View File
@@ -5,8 +5,9 @@ logger = logging.getLogger("astrbot")
class PipInstaller:
def __init__(self, pip_install_arg: str):
def __init__(self, pip_install_arg: str, pypi_index_url: str = None):
self.pip_install_arg = pip_install_arg
self.pypi_index_url = pypi_index_url
def install(
self,
@@ -20,10 +21,9 @@ class PipInstaller:
elif requirements_path:
args.extend(["-r", requirements_path])
if not mirror:
mirror = "https://mirrors.aliyun.com/pypi/simple/"
index_url = mirror or self.pypi_index_url or "https://pypi.org/simple"
args.extend(["--trusted-host", "mirrors.aliyun.com", "-i", mirror])
args.extend(["--trusted-host", "mirrors.aliyun.com", "-i", index_url])
if self.pip_install_arg:
args.extend(self.pip_install_arg.split())
+2 -1
View File
@@ -136,10 +136,11 @@ class UpdateRoute(Route):
data = await request.json
package = data.get("package", "")
mirror = data.get("mirror", None)
if not package:
return Response().error("缺少参数 package 或不合法。").__dict__
try:
pip_installer.install(package)
pip_installer.install(package, mirror=mirror)
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(f"/api/update_pip: {traceback.format_exc()}")
+2 -2
View File
@@ -27,8 +27,8 @@ import axios from 'axios';
</v-card-title>
<v-card-text>
<v-text-field v-model="pipInstallPayload.package" label="*库名,如 llmtuner" variant="outlined"></v-text-field>
<v-text-field v-model="pipInstallPayload.mirror" label="镜像站链接(可选)" variant="outlined"></v-text-field>
<small>如果不填镜像站链接默认使用阿里云镜像https://mirrors.aliyun.com/pypi/simple/</small>
<v-text-field v-model="pipInstallPayload.mirror" label="强制 PyPI 软件仓库链接(可选)" variant="outlined"></v-text-field>
<small>强制 PyPI 软件仓库链接 > 配置项 `PyPI 软件仓库地址`</small>
<div>
<small>{{ status }}</small>
</div>
+22 -2
View File
@@ -839,10 +839,30 @@ export default {
//
formatMessage(content) {
if (!content) return '空消息';
// content
// [{"type": "image_url", "image_url": {"url": url_or_base64}}, {"type": "text", "text": "text"}]
let final_content = content;
if (Array.isArray(content)) {
//
final_content = content.map(item => {
if (item.type === 'image_url') {
return `<img src="${item.image_url.url}" alt="Image" />`;
} else if (item.type === 'text') {
return item.text;
}
return '';
}).join('\n');
} else if (typeof content === 'object') {
//
final_content = Object.values(content).join('');
} else if (typeof content === 'string') {
//
final_content = content;
} else if (!final_content) return '空消息';
// 使markedMarkdown
return marked(content);
return marked(final_content);
},
//
+1
View File
@@ -19,6 +19,7 @@ dependencies = [
"defusedxml>=0.7.1",
"dingtalk-stream>=0.22.1",
"docstring-parser>=0.16",
"google-genai>=1.10.0",
"googlesearch-python>=1.3.0",
"lark-oapi>=1.4.12",
"lxml-html-clean>=0.4.1",
+2 -1
View File
@@ -29,4 +29,5 @@ defusedxml
mcp
certifi
pip
telegramify-markdown
telegramify-markdown
google-genai
Generated
+76
View File
@@ -209,6 +209,7 @@ dependencies = [
{ name = "defusedxml" },
{ name = "dingtalk-stream" },
{ name = "docstring-parser" },
{ name = "google-genai" },
{ name = "googlesearch-python" },
{ name = "lark-oapi" },
{ name = "lxml-html-clean" },
@@ -245,6 +246,7 @@ requires-dist = [
{ name = "defusedxml", specifier = ">=0.7.1" },
{ name = "dingtalk-stream", specifier = ">=0.22.1" },
{ name = "docstring-parser", specifier = ">=0.16" },
{ name = "google-genai", specifier = ">=1.10.0" },
{ name = "googlesearch-python", specifier = ">=1.3.0" },
{ name = "lark-oapi", specifier = ">=1.4.12" },
{ name = "lxml-html-clean", specifier = ">=0.4.1" },
@@ -305,6 +307,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458 },
]
[[package]]
name = "cachetools"
version = "5.5.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/6c/81/3747dad6b14fa2cf53fcf10548cf5aea6913e96fab41a3c198676f8948a5/cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4", size = 28380 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a", size = 10080 },
]
[[package]]
name = "certifi"
version = "2025.1.31"
@@ -676,6 +687,38 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c6/c8/a5be5b7550c10858fcf9b0ea054baccab474da77d37f1e828ce043a3a5d4/frozenlist-1.5.0-py3-none-any.whl", hash = "sha256:d994863bba198a4a518b467bb971c56e1db3f180a25c6cf7bb1949c267f748c3", size = 11901 },
]
[[package]]
name = "google-auth"
version = "2.39.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cachetools" },
{ name = "pyasn1-modules" },
{ name = "rsa" },
]
sdist = { url = "https://files.pythonhosted.org/packages/cb/8e/8f45c9a32f73e786e954b8f9761c61422955d23c45d1e8c347f9b4b59e8e/google_auth-2.39.0.tar.gz", hash = "sha256:73222d43cdc35a3aeacbfdcaf73142a97839f10de930550d89ebfe1d0a00cde7", size = 274834 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ce/12/ad37a1ef86006d0a0117fc06a4a00bd461c775356b534b425f00dde208ea/google_auth-2.39.0-py2.py3-none-any.whl", hash = "sha256:0150b6711e97fb9f52fe599f55648950cc4540015565d8fbb31be2ad6e1548a2", size = 212319 },
]
[[package]]
name = "google-genai"
version = "1.10.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "google-auth" },
{ name = "httpx" },
{ name = "pydantic" },
{ name = "requests" },
{ name = "typing-extensions" },
{ name = "websockets" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0e/7a/224e2f70c835202042969685ee3da00a6475508d1b64f0f1e90144f96beb/google_genai-1.10.0.tar.gz", hash = "sha256:f59423e0f155dc66b7792c8a0e6724c75c72dc699d1eb7907d4d0006d4f6186f", size = 156355 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ba/a0/56839a2e202d79c773edd1c1db124da8eb2a7b657267a888080b678d0369/google_genai-1.10.0-py3-none-any.whl", hash = "sha256:41b105a2fcf8a027fc45cc16694cd559b8cd1272eab7345ad58cfa2c353bf34f", size = 154705 },
]
[[package]]
name = "googlesearch-python"
version = "1.3.0"
@@ -1402,6 +1445,27 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/50/1b/6921afe68c74868b4c9fa424dad3be35b095e16687989ebbb50ce4fceb7c/psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553", size = 244885 },
]
[[package]]
name = "pyasn1"
version = "0.6.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135 },
]
[[package]]
name = "pyasn1-modules"
version = "0.4.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pyasn1" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259 },
]
[[package]]
name = "pycparser"
version = "2.22"
@@ -1697,6 +1761,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3f/51/d4db610ef29373b879047326cbf6fa98b6c1969d6f6dc423279de2b1be2c/requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06", size = 54481 },
]
[[package]]
name = "rsa"
version = "4.9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pyasn1" },
]
sdist = { url = "https://files.pythonhosted.org/packages/aa/65/7d973b89c4d2351d7fb232c2e452547ddfa243e93131e7cfa766da627b52/rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21", size = 29711 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/49/97/fa78e3d2f65c02c8e1268b9aba606569fe97f6c8f7c2d74394553347c145/rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7", size = 34315 },
]
[[package]]
name = "silk-python"
version = "0.2.6"