diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 879193f43..35e6310ba 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -522,6 +522,12 @@ CONFIG_METADATA_2 = { "model": "gemini-2.0-flash-exp", }, "gm_resp_image_modal": False, + "gm_safety_settings": { + "harassment": "BLOCK_MEDIUM_AND_ABOVE", + "hate_speech": "BLOCK_MEDIUM_AND_ABOVE", + "sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE", + "dangerous_content": "BLOCK_MEDIUM_AND_ABOVE", + }, }, "DeepSeek": { "id": "deepseek_default", @@ -678,6 +684,57 @@ CONFIG_METADATA_2 = { "type": "bool", "hint": "启用后,将支持返回图片内容。需要模型支持,否则会报错。具体支持模型请查看 Google Gemini 官方网站。温馨提示,如果您需要生成图片,请关闭 `启用群员识别` 配置获得更好的效果。", }, + "gm_safety_settings": { + "description": "安全过滤器", + "type": "object", + "hint": "设置模型输入的内容安全过滤级别。过滤级别分类为NONE(不屏蔽)、HIGH(高风险时屏蔽)、MEDIUM_AND_ABOVE(中等风险及以上屏蔽)、LOW_AND_ABOVE(低风险及以上时屏蔽),具体参见Gemini API文档。", + "items": { + "harassment": { + "description": "骚扰内容", + "type": "string", + "hint": "负面或有害评论", + "options": [ + "BLOCK_NONE", + "BLOCK_ONLY_HIGH", + "BLOCK_MEDIUM_AND_ABOVE", + "BLOCK_LOW_AND_ABOVE", + ], + }, + "hate_speech": { + "description": "仇恨言论", + "type": "string", + "hint": "粗鲁、无礼或亵渎性质内容", + "options": [ + "BLOCK_NONE", + "BLOCK_ONLY_HIGH", + "BLOCK_MEDIUM_AND_ABOVE", + "BLOCK_LOW_AND_ABOVE", + ], + }, + "sexually_explicit": { + "description": "露骨色情内容", + "type": "string", + "hint": "包含性行为或其他淫秽内容的引用", + "options": [ + "BLOCK_NONE", + "BLOCK_ONLY_HIGH", + "BLOCK_MEDIUM_AND_ABOVE", + "BLOCK_LOW_AND_ABOVE", + ], + }, + "dangerous_content": { + "description": "危险内容", + "type": "string", + "hint": "宣扬、助长或鼓励有害行为的信息", + "options": [ + "BLOCK_NONE", + "BLOCK_ONLY_HIGH", + "BLOCK_MEDIUM_AND_ABOVE", + "BLOCK_LOW_AND_ABOVE", + ], + }, + }, + }, "rag_options": { "description": "RAG 选项", "type": "object", diff --git a/astrbot/core/db/plugin/sqlite_impl.py b/astrbot/core/db/plugin/sqlite_impl.py new file mode 100644 index 000000000..5440362af --- /dev/null +++ b/astrbot/core/db/plugin/sqlite_impl.py @@ -0,0 +1,112 @@ +import json +import aiosqlite +import os +from typing import Any +from .plugin_storage import PluginStorage + +DBPATH = "data/plugin_data/sqlite/plugin_data.db" + + +class SQLitePluginStorage(PluginStorage): + """插件数据的 SQLite 存储实现类。 + + 该类提供异步方式将插件数据存储到 SQLite 数据库中,支持数据的增删改查操作。 + 所有数据以 (plugin, key) 作为复合主键进行索引。 + """ + + _instance = None # Standalone instance of the class + _db_conn = None + db_path = None + + def __new__(cls): + """ + 创建或获取 SQLitePluginStorage 的单例实例。 + 如果实例已存在,则返回现有实例;否则创建一个新实例。 + 数据在 `data/plugin_data/sqlite/plugin_data.db` 下。 + """ + os.makedirs(os.path.dirname(DBPATH), exist_ok=True) + if cls._instance is None: + cls._instance = super(SQLitePluginStorage, cls).__new__(cls) + cls._instance.db_path = DBPATH + return cls._instance + + async def _init_db(self): + """初始化数据库连接(只执行一次)""" + if SQLitePluginStorage._db_conn is None: + SQLitePluginStorage._db_conn = await aiosqlite.connect(self.db_path) + await self._setup_db() + + async def _setup_db(self): + """ + 异步初始化数据库。 + + 创建插件数据表,如果表不存在则创建,表结构包含 plugin、key 和 value 字段, + 其中 plugin 和 key 组合作为主键。 + """ + await self._db_conn.execute(""" + CREATE TABLE IF NOT EXISTS plugin_data ( + plugin TEXT, + key TEXT, + value TEXT, + PRIMARY KEY (plugin, key) + ) + """) + await self._db_conn.commit() + + async def set(self, plugin: str, key: str, value: Any): + """ + 异步存储数据。 + + 将指定插件的键值对存入数据库,如果键已存在则更新值。 + 值会被序列化为 JSON 字符串后存储。 + + Args: + plugin: 插件标识符 + key: 数据键名 + value: 要存储的数据值(任意类型,将被 JSON 序列化) + """ + await self._init_db() + await self._db_conn.execute( + "INSERT INTO plugin_data (plugin, key, value) VALUES (?, ?, ?) " + "ON CONFLICT(plugin, key) DO UPDATE SET value = excluded.value", + (plugin, key, json.dumps(value)), + ) + await self._db_conn.commit() + + async def get(self, plugin: str, key: str) -> Any: + """ + 异步获取数据。 + + 从数据库中获取指定插件和键名对应的值, + 返回的值会从 JSON 字符串反序列化为原始数据类型。 + + Args: + plugin: 插件标识符 + key: 数据键名 + + Returns: + Any: 存储的数据值,如果未找到则返回 None + """ + await self._init_db() + async with self._db_conn.execute( + "SELECT value FROM plugin_data WHERE plugin = ? AND key = ?", + (plugin, key), + ) as cursor: + row = await cursor.fetchone() + return json.loads(row[0]) if row else None + + async def delete(self, plugin: str, key: str): + """ + 异步删除数据。 + + 从数据库中删除指定插件和键名对应的数据项。 + + Args: + plugin: 插件标识符 + key: 要删除的数据键名 + """ + await self._init_db() + await self._db_conn.execute( + "DELETE FROM plugin_data WHERE plugin = ? AND key = ?", (plugin, key) + ) + await self._db_conn.commit() diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 353d9d3df..7d41f7ecc 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -58,9 +58,9 @@ class LLMRequestSubStage(Stage): if event.get_extra("provider_request"): req = event.get_extra("provider_request") - assert isinstance( - req, ProviderRequest - ), "provider_request 必须是 ProviderRequest 类型。" + assert isinstance(req, ProviderRequest), ( + "provider_request 必须是 ProviderRequest 类型。" + ) if req.conversation: req.contexts = json.loads(req.conversation.history) diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 4894b2e03..d7bb9583c 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -156,9 +156,7 @@ class ResultDecorateStage(Stage): self.ctx.astrbot_config["provider_tts_settings"]["enable"] and result.is_llm_result() ): - tts_provider = ( - self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst - ) + tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst new_chain = [] for comp in result.chain: if isinstance(comp, Plain) and len(comp.text) > 1: diff --git a/astrbot/core/provider/sources/edge_tts_source.py b/astrbot/core/provider/sources/edge_tts_source.py index b6b758e29..0eadb2190 100644 --- a/astrbot/core/provider/sources/edge_tts_source.py +++ b/astrbot/core/provider/sources/edge_tts_source.py @@ -63,9 +63,7 @@ class ProviderEdgeTTS(TTSProvider): ff = FFmpeg() ff.convert(input=mp3_path, output=wav_path) except Exception as e: - logger.debug( - f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换" - ) + logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") # use ffmpeg command line # 使用ffmpeg将MP3转换为标准WAV格式 diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 3233b3453..c316544ff 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -43,6 +43,7 @@ class SimpleGoogleGenAIClient: system_instruction: str = "", tools: dict = None, modalities: List[str] = ["Text"], + safety_settings: List[dict] = [], ): payload = {} if system_instruction: @@ -53,6 +54,10 @@ class SimpleGoogleGenAIClient: payload["generationConfig"] = { "responseModalities": modalities, } + payload["safetySettings"] = [ + {"category": s["category"], "threshold": s["threshold"]} + for s in safety_settings + ] logger.debug(f"payload: {payload}") request_url = ( f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}" @@ -106,6 +111,21 @@ class ProviderGoogleGenAI(Provider): ) self.set_model(provider_config["model_config"]["model"]) + safety_mapping = { + "harassment": "HARM_CATEGORY_HARASSMENT", + "hate_speech": "HARM_CATEGORY_HATE_SPEECH", + "sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT", + } + + self.safety_settings = [] + user_safety_config = self.provider_config.get("gm_safety_settings", {}) + for config_key, harm_category in safety_mapping.items(): + if threshold := user_safety_config.get(config_key): + self.safety_settings.append( + {"category": harm_category, "threshold": threshold} + ) + async def get_models(self): return await self.client.models_list() @@ -205,6 +225,7 @@ class ProviderGoogleGenAI(Provider): system_instruction=system_instruction, tools=tool, modalities=modalites, + safety_settings=self.safety_settings, ) logger.debug(f"result: {result}") diff --git a/astrbot/dashboard/routes/static_file.py b/astrbot/dashboard/routes/static_file.py index 5d4c05c6b..4503a28e5 100644 --- a/astrbot/dashboard/routes/static_file.py +++ b/astrbot/dashboard/routes/static_file.py @@ -21,7 +21,7 @@ class StaticFileRoute(Route): "/about", "/extension-marketplace", "/conversation", - "/tool-use" + "/tool-use", ] for i in index_: self.app.add_url_rule(i, view_func=self.index) diff --git a/packages/session_controller/main.py b/packages/session_controller/main.py index 6c00bc81d..99d0a2e62 100644 --- a/packages/session_controller/main.py +++ b/packages/session_controller/main.py @@ -1,6 +1,5 @@ import astrbot.api.message_components as Comp import copy -import json from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, filter from astrbot.api.star import Context, Star, register @@ -64,17 +63,11 @@ class Waiter(Star): event.unified_msg_origin ) conversation = None - context = [] if curr_cid: conversation = await self.context.conversation_manager.get_conversation( event.unified_msg_origin, curr_cid ) - context = ( - json.loads(conversation.history) - if conversation.history - else [] - ) else: # 创建新对话 curr_cid = await self.context.conversation_manager.new_conversation(