✨ feat: 增加对 Gemini 系列模型的输入安全设置参数支持
fixes: #216 Squashed: Update astrbot/core/config/default.py 描述更正. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> 🎨 style: clean up 🐛 fix: 修复安全设置参数的默认值为列表
This commit is contained in:
@@ -522,6 +522,12 @@ CONFIG_METADATA_2 = {
|
||||
"model": "gemini-2.0-flash-exp",
|
||||
},
|
||||
"gm_resp_image_modal": False,
|
||||
"gm_safety_settings": {
|
||||
"harassment": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"hate_speech": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
},
|
||||
},
|
||||
"DeepSeek": {
|
||||
"id": "deepseek_default",
|
||||
@@ -678,6 +684,57 @@ CONFIG_METADATA_2 = {
|
||||
"type": "bool",
|
||||
"hint": "启用后,将支持返回图片内容。需要模型支持,否则会报错。具体支持模型请查看 Google Gemini 官方网站。温馨提示,如果您需要生成图片,请关闭 `启用群员识别` 配置获得更好的效果。",
|
||||
},
|
||||
"gm_safety_settings": {
|
||||
"description": "安全过滤器",
|
||||
"type": "object",
|
||||
"hint": "设置模型输入的内容安全过滤级别。过滤级别分类为NONE(不屏蔽)、HIGH(高风险时屏蔽)、MEDIUM_AND_ABOVE(中等风险及以上屏蔽)、LOW_AND_ABOVE(低风险及以上时屏蔽),具体参见Gemini API文档。",
|
||||
"items": {
|
||||
"harassment": {
|
||||
"description": "骚扰内容",
|
||||
"type": "string",
|
||||
"hint": "负面或有害评论",
|
||||
"options": [
|
||||
"BLOCK_NONE",
|
||||
"BLOCK_ONLY_HIGH",
|
||||
"BLOCK_MEDIUM_AND_ABOVE",
|
||||
"BLOCK_LOW_AND_ABOVE",
|
||||
],
|
||||
},
|
||||
"hate_speech": {
|
||||
"description": "仇恨言论",
|
||||
"type": "string",
|
||||
"hint": "粗鲁、无礼或亵渎性质内容",
|
||||
"options": [
|
||||
"BLOCK_NONE",
|
||||
"BLOCK_ONLY_HIGH",
|
||||
"BLOCK_MEDIUM_AND_ABOVE",
|
||||
"BLOCK_LOW_AND_ABOVE",
|
||||
],
|
||||
},
|
||||
"sexually_explicit": {
|
||||
"description": "露骨色情内容",
|
||||
"type": "string",
|
||||
"hint": "包含性行为或其他淫秽内容的引用",
|
||||
"options": [
|
||||
"BLOCK_NONE",
|
||||
"BLOCK_ONLY_HIGH",
|
||||
"BLOCK_MEDIUM_AND_ABOVE",
|
||||
"BLOCK_LOW_AND_ABOVE",
|
||||
],
|
||||
},
|
||||
"dangerous_content": {
|
||||
"description": "危险内容",
|
||||
"type": "string",
|
||||
"hint": "宣扬、助长或鼓励有害行为的信息",
|
||||
"options": [
|
||||
"BLOCK_NONE",
|
||||
"BLOCK_ONLY_HIGH",
|
||||
"BLOCK_MEDIUM_AND_ABOVE",
|
||||
"BLOCK_LOW_AND_ABOVE",
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
"rag_options": {
|
||||
"description": "RAG 选项",
|
||||
"type": "object",
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
import json
|
||||
import aiosqlite
|
||||
import os
|
||||
from typing import Any
|
||||
from .plugin_storage import PluginStorage
|
||||
|
||||
DBPATH = "data/plugin_data/sqlite/plugin_data.db"
|
||||
|
||||
|
||||
class SQLitePluginStorage(PluginStorage):
|
||||
"""插件数据的 SQLite 存储实现类。
|
||||
|
||||
该类提供异步方式将插件数据存储到 SQLite 数据库中,支持数据的增删改查操作。
|
||||
所有数据以 (plugin, key) 作为复合主键进行索引。
|
||||
"""
|
||||
|
||||
_instance = None # Standalone instance of the class
|
||||
_db_conn = None
|
||||
db_path = None
|
||||
|
||||
def __new__(cls):
|
||||
"""
|
||||
创建或获取 SQLitePluginStorage 的单例实例。
|
||||
如果实例已存在,则返回现有实例;否则创建一个新实例。
|
||||
数据在 `data/plugin_data/sqlite/plugin_data.db` 下。
|
||||
"""
|
||||
os.makedirs(os.path.dirname(DBPATH), exist_ok=True)
|
||||
if cls._instance is None:
|
||||
cls._instance = super(SQLitePluginStorage, cls).__new__(cls)
|
||||
cls._instance.db_path = DBPATH
|
||||
return cls._instance
|
||||
|
||||
async def _init_db(self):
|
||||
"""初始化数据库连接(只执行一次)"""
|
||||
if SQLitePluginStorage._db_conn is None:
|
||||
SQLitePluginStorage._db_conn = await aiosqlite.connect(self.db_path)
|
||||
await self._setup_db()
|
||||
|
||||
async def _setup_db(self):
|
||||
"""
|
||||
异步初始化数据库。
|
||||
|
||||
创建插件数据表,如果表不存在则创建,表结构包含 plugin、key 和 value 字段,
|
||||
其中 plugin 和 key 组合作为主键。
|
||||
"""
|
||||
await self._db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS plugin_data (
|
||||
plugin TEXT,
|
||||
key TEXT,
|
||||
value TEXT,
|
||||
PRIMARY KEY (plugin, key)
|
||||
)
|
||||
""")
|
||||
await self._db_conn.commit()
|
||||
|
||||
async def set(self, plugin: str, key: str, value: Any):
|
||||
"""
|
||||
异步存储数据。
|
||||
|
||||
将指定插件的键值对存入数据库,如果键已存在则更新值。
|
||||
值会被序列化为 JSON 字符串后存储。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 数据键名
|
||||
value: 要存储的数据值(任意类型,将被 JSON 序列化)
|
||||
"""
|
||||
await self._init_db()
|
||||
await self._db_conn.execute(
|
||||
"INSERT INTO plugin_data (plugin, key, value) VALUES (?, ?, ?) "
|
||||
"ON CONFLICT(plugin, key) DO UPDATE SET value = excluded.value",
|
||||
(plugin, key, json.dumps(value)),
|
||||
)
|
||||
await self._db_conn.commit()
|
||||
|
||||
async def get(self, plugin: str, key: str) -> Any:
|
||||
"""
|
||||
异步获取数据。
|
||||
|
||||
从数据库中获取指定插件和键名对应的值,
|
||||
返回的值会从 JSON 字符串反序列化为原始数据类型。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 数据键名
|
||||
|
||||
Returns:
|
||||
Any: 存储的数据值,如果未找到则返回 None
|
||||
"""
|
||||
await self._init_db()
|
||||
async with self._db_conn.execute(
|
||||
"SELECT value FROM plugin_data WHERE plugin = ? AND key = ?",
|
||||
(plugin, key),
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
return json.loads(row[0]) if row else None
|
||||
|
||||
async def delete(self, plugin: str, key: str):
|
||||
"""
|
||||
异步删除数据。
|
||||
|
||||
从数据库中删除指定插件和键名对应的数据项。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 要删除的数据键名
|
||||
"""
|
||||
await self._init_db()
|
||||
await self._db_conn.execute(
|
||||
"DELETE FROM plugin_data WHERE plugin = ? AND key = ?", (plugin, key)
|
||||
)
|
||||
await self._db_conn.commit()
|
||||
@@ -58,9 +58,9 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(
|
||||
req, ProviderRequest
|
||||
), "provider_request 必须是 ProviderRequest 类型。"
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
)
|
||||
|
||||
if req.conversation:
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
|
||||
@@ -156,9 +156,7 @@ class ResultDecorateStage(Stage):
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
and result.is_llm_result()
|
||||
):
|
||||
tts_provider = (
|
||||
self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
||||
)
|
||||
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
|
||||
@@ -63,9 +63,7 @@ class ProviderEdgeTTS(TTSProvider):
|
||||
ff = FFmpeg()
|
||||
ff.convert(input=mp3_path, output=wav_path)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换"
|
||||
)
|
||||
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
||||
# use ffmpeg command line
|
||||
|
||||
# 使用ffmpeg将MP3转换为标准WAV格式
|
||||
|
||||
@@ -43,6 +43,7 @@ class SimpleGoogleGenAIClient:
|
||||
system_instruction: str = "",
|
||||
tools: dict = None,
|
||||
modalities: List[str] = ["Text"],
|
||||
safety_settings: List[dict] = [],
|
||||
):
|
||||
payload = {}
|
||||
if system_instruction:
|
||||
@@ -53,6 +54,10 @@ class SimpleGoogleGenAIClient:
|
||||
payload["generationConfig"] = {
|
||||
"responseModalities": modalities,
|
||||
}
|
||||
payload["safetySettings"] = [
|
||||
{"category": s["category"], "threshold": s["threshold"]}
|
||||
for s in safety_settings
|
||||
]
|
||||
logger.debug(f"payload: {payload}")
|
||||
request_url = (
|
||||
f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
|
||||
@@ -106,6 +111,21 @@ class ProviderGoogleGenAI(Provider):
|
||||
)
|
||||
self.set_model(provider_config["model_config"]["model"])
|
||||
|
||||
safety_mapping = {
|
||||
"harassment": "HARM_CATEGORY_HARASSMENT",
|
||||
"hate_speech": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
}
|
||||
|
||||
self.safety_settings = []
|
||||
user_safety_config = self.provider_config.get("gm_safety_settings", {})
|
||||
for config_key, harm_category in safety_mapping.items():
|
||||
if threshold := user_safety_config.get(config_key):
|
||||
self.safety_settings.append(
|
||||
{"category": harm_category, "threshold": threshold}
|
||||
)
|
||||
|
||||
async def get_models(self):
|
||||
return await self.client.models_list()
|
||||
|
||||
@@ -205,6 +225,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
system_instruction=system_instruction,
|
||||
tools=tool,
|
||||
modalities=modalites,
|
||||
safety_settings=self.safety_settings,
|
||||
)
|
||||
logger.debug(f"result: {result}")
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class StaticFileRoute(Route):
|
||||
"/about",
|
||||
"/extension-marketplace",
|
||||
"/conversation",
|
||||
"/tool-use"
|
||||
"/tool-use",
|
||||
]
|
||||
for i in index_:
|
||||
self.app.add_url_rule(i, view_func=self.index)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import astrbot.api.message_components as Comp
|
||||
import copy
|
||||
import json
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, filter
|
||||
from astrbot.api.star import Context, Star, register
|
||||
@@ -64,17 +63,11 @@ class Waiter(Star):
|
||||
event.unified_msg_origin
|
||||
)
|
||||
conversation = None
|
||||
context = []
|
||||
|
||||
if curr_cid:
|
||||
conversation = await self.context.conversation_manager.get_conversation(
|
||||
event.unified_msg_origin, curr_cid
|
||||
)
|
||||
context = (
|
||||
json.loads(conversation.history)
|
||||
if conversation.history
|
||||
else []
|
||||
)
|
||||
else:
|
||||
# 创建新对话
|
||||
curr_cid = await self.context.conversation_manager.new_conversation(
|
||||
|
||||
Reference in New Issue
Block a user