diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 272a0f417..235a8284b 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -20,7 +20,7 @@ html_renderer = HtmlRenderer(t2i_base_url) logger = LogManager.GetLogger(log_name="astrbot") db_helper = SQLiteDatabase(DB_PATH) # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中 -sp = SharedPreferences() +sp = SharedPreferences(db_helper=db_helper) # 文件令牌服务 file_token_service = FileTokenService() pip_installer = PipInstaller( diff --git a/astrbot/core/astrbot_config_mgr.py b/astrbot/core/astrbot_config_mgr.py index 5fe5b6e83..ef40981dc 100644 --- a/astrbot/core/astrbot_config_mgr.py +++ b/astrbot/core/astrbot_config_mgr.py @@ -40,7 +40,9 @@ class AstrBotConfigManager: def _load_all_configs(self): """Load all configurations from the shared preferences.""" - abconf_data = self.sp.get("abconf_mapping", {}) + abconf_data = self.sp.get( + "abconf_mapping", {}, scope="global", scope_id="global" + ) for uuid_, meta in abconf_data.items(): filename = meta["path"] conf_path = os.path.join(get_astrbot_config_path(), filename) @@ -55,13 +57,13 @@ class AstrBotConfigManager: def _is_umo_match(self, p1: str, p2: str) -> bool: """判断 p2 umo 是否逻辑包含于 p1 umo""" - p1 = p1.split(":") - p2 = p2.split(":") + p1_ls = p1.split(":") + p2_ls = p2.split(":") - if len(p1) != 3 or len(p2) != 3: + if len(p1_ls) != 3 or len(p2_ls) != 3: return False # 非法格式 - return all(p == "" or p == t for p, t in zip(p1, p2)) + return all(p == "" or p == t for p, t in zip(p1_ls, p2_ls)) def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo: """获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default") @@ -70,7 +72,9 @@ class AstrBotConfigManager: ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型 """ # uuid -> { "umop": list, "path": str, "name": str } - abconf_data = self.sp.get("abconf_mapping", {}) # default is not included here + abconf_data = self.sp.get( + "abconf_mapping", {}, scope="global", scope_id="global" + ) if isinstance(umo, MessageSession): umo = str(umo) else: @@ -91,7 +95,7 @@ class AstrBotConfigManager: abconf_path: str, abconf_id: str, umo_parts: list[str] | list[MessageSession], - abconf_name: str = None, + abconf_name: str | None = None, ) -> None: """保存配置文件的映射关系""" for part in umo_parts: @@ -101,14 +105,16 @@ class AstrBotConfigManager: raise ValueError( "umo_parts must be a list of strings or MessageSession instances" ) - abconf_data = self.sp.get("abconf_mapping", {}) + abconf_data = self.sp.get( + "abconf_mapping", {}, scope="global", scope_id="global" + ) random_word = abconf_name or uuid.uuid4().hex[:8] abconf_data[abconf_id] = { "umop": umo_parts, "path": abconf_path, "name": random_word, } - self.sp.put("abconf_mapping", abconf_data) + self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global") def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig: """获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。""" @@ -141,7 +147,10 @@ class AstrBotConfigManager: """获取所有配置文件的元数据列表""" conf_list = [] conf_list.append(DEFAULT_CONFIG_CONF_INFO) - for uuid_, meta in self.sp.get("abconf_mapping", {}).items(): + abconf_mapping = self.sp.get( + "abconf_mapping", {}, scope="global", scope_id="global" + ) + for uuid_, meta in abconf_mapping.items(): conf_list.append(ConfInfo(**meta, id=uuid_)) return conf_list @@ -149,7 +158,7 @@ class AstrBotConfigManager: self, umo_parts: list[str] | list[MessageSession], config: dict = DEFAULT_CONFIG, - name: str = None, + name: str | None = None, ) -> str: """ umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。 @@ -181,7 +190,9 @@ class AstrBotConfigManager: raise ValueError("不能删除默认配置文件") # 从映射中移除 - abconf_data = self.sp.get("abconf_mapping", {}) + abconf_data = self.sp.get( + "abconf_mapping", {}, scope="global", scope_id="global" + ) if conf_id not in abconf_data: logger.warning(f"配置文件 {conf_id} 不存在于映射中") return False @@ -206,13 +217,13 @@ class AstrBotConfigManager: # 从映射中移除 del abconf_data[conf_id] - self.sp.put("abconf_mapping", abconf_data) + self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global") logger.info(f"成功删除配置文件 {conf_id}") return True def update_conf_info( - self, conf_id: str, name: str = None, umo_parts: list[str] = None + self, conf_id: str, name: str | None = None, umo_parts: list[str] | None = None ) -> bool: """更新配置文件信息 @@ -227,7 +238,9 @@ class AstrBotConfigManager: if conf_id == "default": raise ValueError("不能更新默认配置文件的信息") - abconf_data = self.sp.get("abconf_mapping", {}) + abconf_data = self.sp.get( + "abconf_mapping", {}, scope="global", scope_id="global" + ) if conf_id not in abconf_data: logger.warning(f"配置文件 {conf_id} 不存在于映射中") return False @@ -249,11 +262,13 @@ class AstrBotConfigManager: abconf_data[conf_id]["umop"] = umo_parts # 保存更新 - self.sp.put("abconf_mapping", abconf_data) + self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global") logger.info(f"成功更新配置文件 {conf_id} 的信息") return True - def g(self, umo: str = None, key: str = None, default: _VT = None) -> _VT: + def g( + self, umo: str | None = None, key: str | None = None, default: _VT = None + ) -> _VT: """获取配置项。umo 为 None 时使用默认配置""" if umo is None: return self.confs["default"].get(key, default) diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 62079b344..1e04e4b30 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -6,7 +6,6 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json """ import json -import asyncio from astrbot.core import sp from typing import Dict, List from astrbot.core.db import BaseDatabase @@ -17,25 +16,9 @@ class ConversationManager: """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" def __init__(self, db_helper: BaseDatabase): - # session_conversations 字典记录会话ID-对话ID 映射关系 - self.session_conversations: Dict[str, str] = sp.get("session_conversation", {}) + self.session_conversations: Dict[str, str] = {} self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 - self._start_periodic_save() - - def _start_periodic_save(self): - """启动定时保存任务""" - asyncio.create_task(self._periodic_save()) - - async def _periodic_save(self): - """定时保存会话对话映射关系到存储中""" - while True: - await asyncio.sleep(self.save_interval) - self._save_to_storage() - - def _save_to_storage(self): - """保存会话对话映射关系到存储中""" - sp.put("session_conversation", self.session_conversations) def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation: """将 ConversationV2 对象转换为 Conversation 对象""" @@ -55,10 +38,10 @@ class ConversationManager: async def new_conversation( self, unified_msg_origin: str, - platform_id: str = None, - content: list[dict] = None, - title: str = None, - persona_id: str = None, + platform_id: str | None = None, + content: list[dict] | None = None, + title: str | None = None, + persona_id: str | None = None, ) -> str: """新建对话,并将当前会话的对话转移到新对话 @@ -82,8 +65,8 @@ class ConversationManager: persona_id=persona_id, ) self.session_conversations[unified_msg_origin] = conv.conversation_id - sp.put("session_conversation", self.session_conversations) - return str(conv.conversation_id) + await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id) + return conv.conversation_id async def switch_conversation(self, unified_msg_origin: str, conversation_id: str): """切换会话的对话 @@ -93,10 +76,10 @@ class ConversationManager: conversation_id (str): 对话 ID, 是 uuid 格式的字符串 """ self.session_conversations[unified_msg_origin] = conversation_id - sp.put("session_conversation", self.session_conversations) + await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id) async def delete_conversation( - self, unified_msg_origin: str, conversation_id: str = None + self, unified_msg_origin: str, conversation_id: str | None = None ): """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话 @@ -113,9 +96,9 @@ class ConversationManager: await self.db.delete_conversation(cid=conversation_id) if f: self.session_conversations.pop(unified_msg_origin, None) - sp.put("session_conversation", self.session_conversations) + await sp.session_remove(unified_msg_origin, "sel_conv_id") - async def get_curr_conversation_id(self, unified_msg_origin: str) -> str: + async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None: """获取会话当前的对话 ID Args: @@ -123,7 +106,12 @@ class ConversationManager: Returns: conversation_id (str): 对话 ID, 是 uuid 格式的字符串 """ - return self.session_conversations.get(unified_msg_origin, None) + ret = self.session_conversations.get(unified_msg_origin, None) + if not ret: + ret = await sp.session_get(unified_msg_origin, "sel_conv_id", None) + if ret: + self.session_conversations[unified_msg_origin] = ret + return ret async def get_conversation( self, @@ -150,7 +138,7 @@ class ConversationManager: return conv_res async def get_conversations( - self, unified_msg_origin: str = None, platform_id: str = None + self, unified_msg_origin: str | None = None, platform_id: str | None = None ) -> List[Conversation]: """获取对话列表 @@ -203,10 +191,10 @@ class ConversationManager: async def update_conversation( self, unified_msg_origin: str, - conversation_id: str = None, - history: list[dict] = None, - title: str = None, - persona_id: str = None, + conversation_id: str | None = None, + history: list[dict] | None = None, + title: str | None = None, + persona_id: str | None = None, ): """更新会话的对话 @@ -216,8 +204,8 @@ class ConversationManager: history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段 """ if not conversation_id: - # 如果没有提供 conversation_id,则从 session_conversations 中获取当前的 - conversation_id = self.session_conversations.get(unified_msg_origin) + # 如果没有提供 conversation_id,则获取当前的 + conversation_id = await self.get_curr_conversation_id(unified_msg_origin) if conversation_id: await self.db.update_conversation( cid=conversation_id, @@ -227,7 +215,7 @@ class ConversationManager: ) async def update_conversation_title( - self, unified_msg_origin: str, title: str, conversation_id: str = None + self, unified_msg_origin: str, title: str, conversation_id: str | None = None ): """更新会话的对话标题 @@ -245,7 +233,10 @@ class ConversationManager: ) async def update_conversation_persona_id( - self, unified_msg_origin: str, persona_id: str, conversation_id: str = None + self, + unified_msg_origin: str, + persona_id: str, + conversation_id: str | None = None, ): """更新会话的对话 Persona ID diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 00c42505f..2de109b7d 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -74,7 +74,7 @@ class BaseDatabase(abc.ABC): platform_id: str, platform_type: str, count: int = 1, - timestamp: datetime.datetime = None, + timestamp: datetime.datetime | None = None, ) -> None: """Insert a new platform statistic record.""" ... @@ -91,7 +91,7 @@ class BaseDatabase(abc.ABC): @abc.abstractmethod async def get_conversations( - self, user_id: str = None, platform_id: str = None + self, user_id: str | None = None, platform_id: str | None = None ) -> list[ConversationV2]: """Get all conversations for a specific user and platform_id(optional). @@ -128,12 +128,12 @@ class BaseDatabase(abc.ABC): self, user_id: str, platform_id: str, - content: list[dict] = None, - title: str = None, - persona_id: str = None, - cid: str = None, - created_at: datetime.datetime = None, - updated_at: datetime.datetime = None, + content: list[dict] | None = None, + title: str | None = None, + persona_id: str | None = None, + cid: str | None = None, + created_at: datetime.datetime | None = None, + updated_at: datetime.datetime | None = None, ) -> ConversationV2: """Create a new conversation.""" ... @@ -142,9 +142,9 @@ class BaseDatabase(abc.ABC): async def update_conversation( self, cid: str, - title: str = None, - persona_id: str = None, - content: list[dict] = None, + title: str | None = None, + persona_id: str | None = None, + content: list[dict] | None = None, ) -> None: """Update a conversation's history.""" ... @@ -160,8 +160,8 @@ class BaseDatabase(abc.ABC): platform_id: str, user_id: str, content: list[dict], - sender_id: str = None, - sender_name: str = None, + sender_id: str | None = None, + sender_name: str | None = None, ) -> None: """Insert a new platform message history record.""" ... @@ -204,8 +204,8 @@ class BaseDatabase(abc.ABC): self, persona_id: str, system_prompt: str, - begin_dialogs: list[str] = None, - tools: list[str] = None, + begin_dialogs: list[str] | None = None, + tools: list[str] | None = None, ) -> Persona: """Insert a new persona record.""" ... @@ -224,9 +224,9 @@ class BaseDatabase(abc.ABC): async def update_persona( self, persona_id: str, - system_prompt: str = None, - begin_dialogs: list[str] = None, - tools: list[str] = None, + system_prompt: str | None = None, + begin_dialogs: list[str] | None = None, + tools: list[str] | None = None, ) -> Persona | None: """Update a persona's system prompt or begin dialogs.""" ... @@ -237,13 +237,32 @@ class BaseDatabase(abc.ABC): ... @abc.abstractmethod - async def insert_preference_or_update(self, key: str, value: str) -> Preference: + async def insert_preference_or_update( + self, scope: str, scope_id: str, key: str, value: dict + ) -> Preference: """Insert a new preference record.""" ... @abc.abstractmethod - async def get_preference(self, key: str) -> Preference: - """Get a preference by bot ID and key.""" + async def get_preference(self, scope: str, scope_id: str, key: str) -> Preference: + """Get a preference by scope ID and key.""" + ... + + @abc.abstractmethod + async def get_preferences( + self, scope: str, scope_id: str | None = None, key: str | None = None + ) -> list[Preference]: + """Get all preferences for a specific scope ID or key.""" + ... + + @abc.abstractmethod + async def remove_preference(self, scope: str, scope_id: str, key: str) -> None: + """Remove a preference by scope ID and key.""" + ... + + @abc.abstractmethod + async def clear_preferences(self, scope: str, scope_id: str) -> None: + """Clear all preferences for a specific scope ID.""" ... # @abc.abstractmethod diff --git a/astrbot/core/db/migration/helper.py b/astrbot/core/db/migration/helper.py index 4ac075a8d..796a7b336 100644 --- a/astrbot/core/db/migration/helper.py +++ b/astrbot/core/db/migration/helper.py @@ -2,12 +2,13 @@ import os from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.db import BaseDatabase from astrbot.core.config import AstrBotConfig -from astrbot.api import logger +from astrbot.api import logger, sp from .migra_3_to_4 import ( migration_conversation_table, migration_platform_table, migration_webchat_data, migration_persona_data, + migration_preferences, ) @@ -19,7 +20,9 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool: data_v3_exists = os.path.exists(get_astrbot_data_path()) if not data_v3_exists: return False - migration_done = await db_helper.get_preference("migration_done_v4") + migration_done = await db_helper.get_preference( + "global", "global", "migration_done_v4" + ) if migration_done: return False return True @@ -49,10 +52,13 @@ async def do_migration_v4( # 执行 WebChat 数据迁移 await migration_webchat_data(db_helper, platform_id_map) + # 执行偏好设置迁移 + await migration_preferences(db_helper,platform_id_map) + # 执行平台统计表迁移 await migration_platform_table(db_helper, platform_id_map) # 标记迁移完成 - await db_helper.insert_preference_or_update("migration_done_v4", "true") + await sp.put_async("global", "global", "migration_done_v4", True) logger.info("数据库迁移完成。") diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py index b8947cbd4..4aa5082db 100644 --- a/astrbot/core/db/migration/migra_3_to_4.py +++ b/astrbot/core/db/migration/migra_3_to_4.py @@ -2,8 +2,9 @@ import json import datetime from .. import BaseDatabase from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3 +from .shared_preferences_v3 import sp as sp_v3 from astrbot.core.config.default import DB_PATH -from astrbot.api import logger +from astrbot.api import logger, sp from astrbot.core.config import AstrBotConfig from astrbot.core.platform.astr_message_event import MessageSesion from sqlalchemy.ext.asyncio import AsyncSession @@ -89,7 +90,7 @@ async def migration_conversation_table( async def migration_platform_table( - db_helper: BaseDatabase, platform_id_map: dict[str, str] + db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] ): db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db") @@ -122,9 +123,7 @@ async def migration_platform_table( for bucket_idx, bucket_end in enumerate(range(start_time, end_time, 3600)): if bucket_idx % 500 == 0: progress = int((bucket_idx + 1) / total_buckets * 100) - logger.info( - f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})" - ) + logger.info(f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})") cnt = 0 while ( idx < len(platform_stats_v3) @@ -258,3 +257,82 @@ async def migration_persona_data( ) except Exception as e: logger.error(f"解析 Persona 配置失败:{e}") + + +async def migration_preferences( + db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] +): + # 1. global scope migration + keys = [ + "inactivated_llm_tools", + "inactivated_plugins", + "curr_provider", + "curr_provider_tts", + "curr_provider_stt", + "alter_cmd", + ] + for key in keys: + value = sp_v3.get(key) + if value is not None: + await sp.put_async("global", "global", key, value) + logger.info(f"迁移全局偏好设置 {key} 成功,值: {value}") + + # 2. umo scope migration + session_conversation = sp_v3.get("session_conversation", default={}) + for umo, conversation_id in session_conversation.items(): + if not umo or not conversation_id: + continue + try: + session = MessageSesion.from_str(session_str=umo) + platform_id = get_platform_id(platform_id_map, session.platform_name) + session.platform_id = platform_id + await sp.put_async("umo", str(session), "sel_conv_id", conversation_id) + logger.info(f"迁移会话 {umo} 的对话数据到新表成功,平台 ID: {platform_id}") + except Exception as e: + logger.error(f"迁移会话 {umo} 的对话数据失败: {e}", exc_info=True) + + session_service_config = sp_v3.get("session_service_config", default={}) + for umo, config in session_service_config.items(): + if not umo or not config: + continue + try: + session = MessageSesion.from_str(session_str=umo) + platform_id = get_platform_id(platform_id_map, session.platform_name) + session.platform_id = platform_id + + await sp.put_async("umo", str(session), "session_service_config", config) + + logger.info(f"迁移会话 {umo} 的服务配置到新表成功,平台 ID: {platform_id}") + except Exception as e: + logger.error(f"迁移会话 {umo} 的服务配置失败: {e}", exc_info=True) + + session_variables = sp_v3.get("session_variables", default={}) + for umo, variables in session_variables.items(): + if not umo or not variables: + continue + try: + session = MessageSesion.from_str(session_str=umo) + platform_id = get_platform_id(platform_id_map, session.platform_name) + session.platform_id = platform_id + await sp.put_async("umo", str(session), "session_variables", variables) + except Exception as e: + logger.error(f"迁移会话 {umo} 的变量失败: {e}", exc_info=True) + + session_provider_perf = sp_v3.get("session_provider_perf", default={}) + for umo, perf in session_provider_perf.items(): + if not umo or not perf: + continue + try: + session = MessageSesion.from_str(session_str=umo) + platform_id = get_platform_id(platform_id_map, session.platform_name) + session.platform_id = platform_id + + for provider_type, provider_id in perf.items(): + await sp.put_async( + "umo", str(session), f"provider_perf_{provider_type}", provider_id + ) + logger.info( + f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}" + ) + except Exception as e: + logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True) diff --git a/astrbot/core/db/migration/shared_preferences_v3.py b/astrbot/core/db/migration/shared_preferences_v3.py new file mode 100644 index 000000000..dda2cbcaf --- /dev/null +++ b/astrbot/core/db/migration/shared_preferences_v3.py @@ -0,0 +1,45 @@ +import json +import os +from typing import TypeVar +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + +_VT = TypeVar("_VT") + +class SharedPreferences: + def __init__(self, path=None): + if path is None: + path = os.path.join(get_astrbot_data_path(), "shared_preferences.json") + self.path = path + self._data = self._load_preferences() + + def _load_preferences(self): + if os.path.exists(self.path): + try: + with open(self.path, "r") as f: + return json.load(f) + except json.JSONDecodeError: + os.remove(self.path) + return {} + + def _save_preferences(self): + with open(self.path, "w") as f: + json.dump(self._data, f, indent=4, ensure_ascii=False) + f.flush() + + def get(self, key, default: _VT = None) -> _VT: + return self._data.get(key, default) + + def put(self, key, value): + self._data[key] = value + self._save_preferences() + + def remove(self, key): + if key in self._data: + del self._data[key] + self._save_preferences() + + def clear(self): + self._data.clear() + self._save_preferences() + +sp = SharedPreferences() diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index cbc8c8d4e..88113d130 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -97,18 +97,34 @@ class Persona(SQLModel, table=True): class Preference(SQLModel, table=True): - """This class represents user preferences for bots.""" + """This class represents preferences for bots.""" __tablename__ = "preferences" - key: str = Field(primary_key=True, nullable=False) - value: str = Field(sa_type=Text, nullable=False) + id: int | None = Field( + default=None, primary_key=True, sa_column_kwargs={"autoincrement": True} + ) + scope: str = Field(nullable=False) + """Scope of the preference, such as 'global', 'umo', 'plugin'.""" + scope_id: str = Field(nullable=False) + """ID of the scope, such as 'global', 'umo', 'plugin_name'.""" + key: str = Field(nullable=False) + value: dict = Field(sa_type=JSON, nullable=False) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field( default_factory=lambda: datetime.now(timezone.utc), sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, ) + __table_args__ = ( + UniqueConstraint( + "scope", + "scope_id", + "key", + name="uix_preference_scope_scope_id_key", + ), + ) + class PlatformMessageHistory(SQLModel, table=True): """This class represents the message history for a specific platform. @@ -184,8 +200,8 @@ class Conversation: """对话 ID, 是 uuid 格式的字符串""" history: str = "" """字符串格式的对话列表。""" - title: str = "" - persona_id: str = "" + title: str | None = "" + persona_id: str | None = "" created_at: int = 0 updated_at: int = 0 diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 0ecf787e5..418b35761 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -21,6 +21,7 @@ from sqlalchemy.sql import func NOT_GIVEN = T.TypeVar("NOT_GIVEN") + class SQLiteDatabase(BaseDatabase): def __init__(self, db_path: str) -> None: self.db_path = db_path @@ -373,29 +374,77 @@ class SQLiteDatabase(BaseDatabase): delete(Persona).where(Persona.persona_id == persona_id) ) - async def insert_preference_or_update(self, key, value): + async def insert_preference_or_update(self, scope, scope_id, key, value): """Insert a new preference record or update if it exists.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): - query = select(Preference).where(Preference.key == key) + query = select(Preference).where( + Preference.scope == scope, + Preference.scope_id == scope_id, + Preference.key == key, + ) result = await session.execute(query) existing_preference = result.scalar_one_or_none() if existing_preference: existing_preference.value = value else: - new_preference = Preference(key=key, value=value) + new_preference = Preference( + scope=scope, scope_id=scope_id, key=key, value=value + ) session.add(new_preference) return existing_preference or new_preference - async def get_preference(self, key): + async def get_preference(self, scope, scope_id, key): """Get a preference by key.""" async with self.get_db() as session: session: AsyncSession - query = select(Preference).where(Preference.key == key) + query = select(Preference).where( + Preference.scope == scope, + Preference.scope_id == scope_id, + Preference.key == key, + ) result = await session.execute(query) return result.scalar_one_or_none() + async def get_preferences(self, scope, scope_id=None, key=None): + """Get all preferences for a specific scope ID or key.""" + async with self.get_db() as session: + session: AsyncSession + query = select(Preference).where(Preference.scope == scope) + if scope_id is not None: + query = query.where(Preference.scope_id == scope_id) + if key is not None: + query = query.where(Preference.key == key) + result = await session.execute(query) + return result.scalars().all() + + async def remove_preference(self, scope, scope_id, key): + """Remove a preference by scope ID and key.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(Preference).where( + Preference.scope == scope, + Preference.scope_id == scope_id, + Preference.key == key, + ) + ) + await session.commit() + + async def clear_preferences(self, scope, scope_id): + """Clear all preferences for a specific scope ID.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(Preference).where( + Preference.scope == scope, Preference.scope_id == scope_id + ) + ) + await session.commit() + # ==== # Deprecated Methods # ==== diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index aaac8e289..43da100f4 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -77,7 +77,7 @@ class WebChatAdapter(Platform): os.makedirs(self.imgs_dir, exist_ok=True) self.metadata = PlatformMetadata( - name="webchat", description="webchat", id=self.config.get("id", "") + name="webchat", description="webchat", id="webchat" ) async def send_by_session( diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 92aedc6c6..a2b006dd0 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -425,10 +425,17 @@ class FunctionToolManager: if func_tool is not None: func_tool.active = False - inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) + inactivated_llm_tools: list = sp.get( + "inactivated_llm_tools", [], scope="global", scope_id="global" + ) if name not in inactivated_llm_tools: inactivated_llm_tools.append(name) - sp.put("inactivated_llm_tools", inactivated_llm_tools) + sp.put( + "inactivated_llm_tools", + inactivated_llm_tools, + scope="global", + scope_id="global", + ) return True return False @@ -445,10 +452,17 @@ class FunctionToolManager: func_tool.active = True - inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) + inactivated_llm_tools: list = sp.get( + "inactivated_llm_tools", [], scope="global", scope_id="global" + ) if name in inactivated_llm_tools: inactivated_llm_tools.remove(name) - sp.put("inactivated_llm_tools", inactivated_llm_tools) + sp.put( + "inactivated_llm_tools", + inactivated_llm_tools, + scope="global", + scope_id="global", + ) return True return False diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index aa0ba25f5..44aeb594f 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -66,7 +66,7 @@ class ProviderManager: return self.persona_mgr.selected_default_persona_v3 async def set_provider( - self, provider_id: str, provider_type: ProviderType, umo: str = None + self, provider_id: str, provider_type: ProviderType, umo: str | None = None ): """设置提供商。 @@ -80,20 +80,20 @@ class ProviderManager: if provider_id not in self.inst_map: raise ValueError(f"提供商 {provider_id} 不存在,无法设置。") if umo: - perf = sp.get("session_provider_perf", {}) - session_perf = perf.get(umo, {}) - session_perf[provider_type.value] = provider_id - perf[umo] = session_perf - sp.put("session_provider_perf", perf) + await sp.session_put( + umo, + f"provider_perf_{provider_type.value}", + provider_id, + ) return # 不启用提供商会话隔离模式的情况 self.curr_provider_inst = self.inst_map[provider_id] if provider_type == ProviderType.TEXT_TO_SPEECH: - sp.put("curr_provider_tts", provider_id) + sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global") elif provider_type == ProviderType.SPEECH_TO_TEXT: - sp.put("curr_provider_stt", provider_id) + sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global") elif provider_type == ProviderType.CHAT_COMPLETION: - sp.put("curr_provider", provider_id) + sp.put("curr_provider", provider_id, scope="global", scope_id="global") async def get_provider_by_id(self, provider_id: str) -> Provider | None: """根据提供商 ID 获取提供商实例""" @@ -111,9 +111,12 @@ class ProviderManager: """ provider = None if umo: - perf = sp.get("session_provider_perf", {}) - session_perf = perf.get(umo, {}) - provider_id = session_perf.get(provider_type.value) + provider_id = sp.get( + f"provider_perf_{provider_type.value}", + None, + scope="umo", + scope_id=umo, + ) if provider_id: provider = self.inst_map.get(provider_id) if not provider: @@ -153,13 +156,22 @@ class ProviderManager: # 设置默认提供商 selected_provider_id = sp.get( - "curr_provider", self.provider_settings.get("default_provider_id") + "curr_provider", + self.provider_settings.get("default_provider_id"), + scope="global", + scope_id="global", ) selected_stt_provider_id = sp.get( - "curr_provider_stt", self.provider_stt_settings.get("provider_id") + "curr_provider_stt", + self.provider_stt_settings.get("provider_id"), + scope="global", + scope_id="global", ) selected_tts_provider_id = sp.get( - "curr_provider_tts", self.provider_tts_settings.get("provider_id") + "curr_provider_tts", + self.provider_tts_settings.get("provider_id"), + scope="global", + scope_id="global", ) self.curr_provider_inst = self.inst_map.get(selected_provider_id) if not self.curr_provider_inst and self.provider_insts: diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index 46b12726b..4e14d20da 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -75,8 +75,7 @@ class ProviderDashscope(ProviderOpenAIOfficial): # 获得会话变量 payload_vars = self.variables.copy() # 动态变量 - session_vars = sp.get("session_variables", {}) - session_var = session_vars.get(session_id, {}) + session_var = await sp.session_get(session_id, "session_variables", default={}) payload_vars.update(session_var) if ( diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index 9539227fe..e19e912ac 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -97,8 +97,7 @@ class ProviderDify(Provider): # 获得会话变量 payload_vars = self.variables.copy() # 动态变量 - session_vars = sp.get("session_variables", {}) - session_var = session_vars.get(session_id, {}) + session_var = await sp.session_get(session_id, "session_variables", default={}) payload_vars.update(session_var) payload_vars["system_prompt"] = system_prompt diff --git a/astrbot/core/star/session_llm_manager.py b/astrbot/core/star/session_llm_manager.py index 4bceb1109..6c5bc994d 100644 --- a/astrbot/core/star/session_llm_manager.py +++ b/astrbot/core/star/session_llm_manager.py @@ -2,8 +2,6 @@ 会话服务管理器 - 负责管理每个会话的LLM、TTS等服务的启停状态 """ -from typing import Dict - from astrbot.core import logger, sp from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -26,8 +24,9 @@ class SessionServiceManager: bool: True表示启用,False表示禁用 """ # 获取会话服务配置 - session_config = sp.get("session_service_config", {}) or {} - session_services = session_config.get(session_id, {}) + session_services = sp.get( + "session_service_config", {}, scope="umo", scope_id=session_id + ) # 如果配置了该会话的LLM状态,返回该状态 llm_enabled = session_services.get("llm_enabled") @@ -45,16 +44,13 @@ class SessionServiceManager: session_id: 会话ID (unified_msg_origin) enabled: True表示启用,False表示禁用 """ - # 获取当前配置 - session_config = sp.get("session_service_config", {}) or {} - if session_id not in session_config: - session_config[session_id] = {} - - # 设置LLM状态 - session_config[session_id]["llm_enabled"] = enabled - - # 保存配置 - sp.put("session_service_config", session_config) + session_config = ( + sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {} + ) + session_config["llm_enabled"] = enabled + sp.put( + "session_service_config", session_config, scope="umo", scope_id=session_id + ) logger.info( f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}" @@ -88,8 +84,9 @@ class SessionServiceManager: bool: True表示启用,False表示禁用 """ # 获取会话服务配置 - session_config = sp.get("session_service_config", {}) or {} - session_services = session_config.get(session_id, {}) + session_services = sp.get( + "session_service_config", {}, scope="umo", scope_id=session_id + ) # 如果配置了该会话的TTS状态,返回该状态 tts_enabled = session_services.get("tts_enabled") @@ -107,16 +104,13 @@ class SessionServiceManager: session_id: 会话ID (unified_msg_origin) enabled: True表示启用,False表示禁用 """ - # 获取当前配置 - session_config = sp.get("session_service_config", {}) or {} - if session_id not in session_config: - session_config[session_id] = {} - - # 设置TTS状态 - session_config[session_id]["tts_enabled"] = enabled - - # 保存配置 - sp.put("session_service_config", session_config) + session_config = ( + sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {} + ) + session_config["tts_enabled"] = enabled + sp.put( + "session_service_config", session_config, scope="umo", scope_id=session_id + ) logger.info( f"会话 {session_id} 的TTS状态已更新为: {'启用' if enabled else '禁用'}" @@ -150,8 +144,9 @@ class SessionServiceManager: bool: True表示启用,False表示禁用 """ # 获取会话服务配置 - session_config = sp.get("session_service_config", {}) or {} - session_services = session_config.get(session_id, {}) + session_services = sp.get( + "session_service_config", {}, scope="umo", scope_id=session_id + ) # 如果配置了该会话的整体状态,返回该状态 session_enabled = session_services.get("session_enabled") @@ -169,16 +164,13 @@ class SessionServiceManager: session_id: 会话ID (unified_msg_origin) enabled: True表示启用,False表示禁用 """ - # 获取当前配置 - session_config = sp.get("session_service_config", {}) or {} - if session_id not in session_config: - session_config[session_id] = {} - - # 设置会话整体状态 - session_config[session_id]["session_enabled"] = enabled - - # 保存配置 - sp.put("session_service_config", session_config) + session_config = ( + sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {} + ) + session_config["session_enabled"] = enabled + sp.put( + "session_service_config", session_config, scope="umo", scope_id=session_id + ) logger.info( f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}" @@ -202,7 +194,7 @@ class SessionServiceManager: # ============================================================================= @staticmethod - def get_session_custom_name(session_id: str) -> str: + def get_session_custom_name(session_id: str) -> str | None: """获取会话的自定义名称 Args: @@ -211,8 +203,9 @@ class SessionServiceManager: Returns: str: 自定义名称,如果没有设置则返回None """ - session_config = sp.get("session_service_config", {}) or {} - session_services = session_config.get(session_id, {}) + session_services = sp.get( + "session_service_config", {}, scope="umo", scope_id=session_id + ) return session_services.get("custom_name") @staticmethod @@ -223,20 +216,17 @@ class SessionServiceManager: session_id: 会话ID (unified_msg_origin) custom_name: 自定义名称,可以为空字符串来清除名称 """ - # 获取当前配置 - session_config = sp.get("session_service_config", {}) or {} - if session_id not in session_config: - session_config[session_id] = {} - - # 设置自定义名称 + session_config = ( + sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {} + ) if custom_name and custom_name.strip(): - session_config[session_id]["custom_name"] = custom_name.strip() + session_config["custom_name"] = custom_name.strip() else: # 如果传入空名称,则删除自定义名称 - session_config[session_id].pop("custom_name", None) - - # 保存配置 - sp.put("session_service_config", session_config) + session_config.pop("custom_name", None) + sp.put( + "session_service_config", session_config, scope="umo", scope_id=session_id + ) logger.info( f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}" @@ -258,36 +248,3 @@ class SessionServiceManager: # 如果没有自定义名称,返回session_id的最后一段 return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id - - # ============================================================================= - # 通用配置方法 - # ============================================================================= - - @staticmethod - def get_session_service_config(session_id: str) -> Dict[str, bool]: - """获取指定会话的服务配置 - - Args: - session_id: 会话ID (unified_msg_origin) - - Returns: - Dict[str, bool]: 包含session_enabled、llm_enabled、tts_enabled的字典 - """ - session_config = sp.get("session_service_config", {}) or {} - return session_config.get( - session_id, - { - "session_enabled": True, # 默认启用 - "llm_enabled": True, # 默认启用 - "tts_enabled": True, # 默认启用 - }, - ) - - @staticmethod - def get_all_session_configs() -> Dict[str, Dict[str, bool]]: - """获取所有会话的服务配置 - - Returns: - Dict[str, Dict[str, bool]]: 所有会话的服务配置 - """ - return sp.get("session_service_config", {}) or {} diff --git a/astrbot/core/star/session_plugin_manager.py b/astrbot/core/star/session_plugin_manager.py index c0d1bbd73..5c7303e8d 100644 --- a/astrbot/core/star/session_plugin_manager.py +++ b/astrbot/core/star/session_plugin_manager.py @@ -22,7 +22,9 @@ class SessionPluginManager: bool: True表示启用,False表示禁用 """ # 获取会话插件配置 - session_plugin_config = sp.get("session_plugin_config", {}) or {} + session_plugin_config = sp.get( + "session_plugin_config", {}, scope="umo", scope_id=session_id + ) session_config = session_plugin_config.get(session_id, {}) enabled_plugins = session_config.get("enabled_plugins", []) @@ -51,7 +53,9 @@ class SessionPluginManager: enabled: True表示启用,False表示禁用 """ # 获取当前配置 - session_plugin_config = sp.get("session_plugin_config", {}) or {} + session_plugin_config = sp.get( + "session_plugin_config", {}, scope="umo", scope_id=session_id + ) if session_id not in session_plugin_config: session_plugin_config[session_id] = { "enabled_plugins": [], @@ -79,7 +83,9 @@ class SessionPluginManager: session_config["enabled_plugins"] = enabled_plugins session_config["disabled_plugins"] = disabled_plugins session_plugin_config[session_id] = session_config - sp.put("session_plugin_config", session_plugin_config) + sp.put( + "session_plugin_config", session_plugin_config, scope="umo", scope_id=session_id + ) logger.info( f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}" @@ -95,7 +101,9 @@ class SessionPluginManager: Returns: Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典 """ - session_plugin_config = sp.get("session_plugin_config", {}) or {} + session_plugin_config = sp.get( + "session_plugin_config", {}, scope="umo", scope_id=session_id + ) return session_plugin_config.get( session_id, {"enabled_plugins": [], "disabled_plugins": []} ) diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 214d4f455..13460d8d7 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -374,10 +374,9 @@ class PluginManager: - success (bool): 是否全部加载成功 - error_message (str|None): 错误信息,成功时为 None """ - inactivated_plugins: list = sp.get("inactivated_plugins", []) - inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) - - alter_cmd = sp.get("alter_cmd", {}) + inactivated_plugins = await sp.global_get("inactivated_plugins", []) + inactivated_llm_tools = await sp.global_get("inactivated_llm_tools", []) + alter_cmd = await sp.global_get("alter_cmd", {}) plugin_modules = self._get_plugin_modules() if plugin_modules is None: @@ -787,12 +786,12 @@ class PluginManager: await self._terminate_plugin(plugin) # 加入到 shared_preferences 中 - inactivated_plugins: list = sp.get("inactivated_plugins", []) + inactivated_plugins: list = await sp.global_get("inactivated_plugins", []) if plugin.module_path not in inactivated_plugins: inactivated_plugins.append(plugin.module_path) inactivated_llm_tools: list = list( - set(sp.get("inactivated_llm_tools", [])) + set(await sp.global_get("inactivated_llm_tools", [])) ) # 后向兼容 # 禁用插件启用的 llm_tool @@ -802,8 +801,8 @@ class PluginManager: if func_tool.name not in inactivated_llm_tools: inactivated_llm_tools.append(func_tool.name) - sp.put("inactivated_plugins", inactivated_plugins) - sp.put("inactivated_llm_tools", inactivated_llm_tools) + await sp.global_put("inactivated_plugins", inactivated_plugins) + await sp.global_put("inactivated_llm_tools", inactivated_llm_tools) plugin.activated = False @@ -829,11 +828,11 @@ class PluginManager: async def turn_on_plugin(self, plugin_name: str): plugin = self.context.get_registered_star(plugin_name) - inactivated_plugins: list = sp.get("inactivated_plugins", []) - inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) + inactivated_plugins: list = await sp.global_get("inactivated_plugins", []) + inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", []) if plugin.module_path in inactivated_plugins: inactivated_plugins.remove(plugin.module_path) - sp.put("inactivated_plugins", inactivated_plugins) + await sp.global_put("inactivated_plugins", inactivated_plugins) # 启用插件启用的 llm_tool for func_tool in llm_tools.func_list: @@ -843,7 +842,7 @@ class PluginManager: ): inactivated_llm_tools.remove(func_tool.name) func_tool.active = True - sp.put("inactivated_llm_tools", inactivated_llm_tools) + await sp.global_put("inactivated_llm_tools", inactivated_llm_tools) await self.reload(plugin_name) diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index 42018d19e..b20333405 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -1,43 +1,214 @@ -import json +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import Preference +import threading +import asyncio import os -from typing import TypeVar +from typing import TypeVar, Any, overload from .astrbot_path import get_astrbot_data_path + _VT = TypeVar("_VT") + class SharedPreferences: - def __init__(self, path=None): - if path is None: - path = os.path.join(get_astrbot_data_path(), "shared_preferences.json") - self.path = path - self._data = self._load_preferences() + def __init__(self, db_helper: BaseDatabase, json_storage_path=None): + if json_storage_path is None: + json_storage_path = os.path.join( + get_astrbot_data_path(), "shared_preferences.json" + ) + self.path = json_storage_path + self.db_helper = db_helper - def _load_preferences(self): - if os.path.exists(self.path): - try: - with open(self.path, "r") as f: - return json.load(f) - except json.JSONDecodeError: - os.remove(self.path) - return {} + async def get_async( + self, + scope: str, + scope_id: str, + key: str, + default: _VT = None, + ) -> _VT: + """获取指定范围和键的偏好设置""" + if scope_id is not None and key is not None: + result = await self.db_helper.get_preference(scope, scope_id, key) + if result: + ret = result.value["val"] + else: + ret = default + return ret + else: + raise ValueError( + "scope_id and key cannot be None when getting a specific preference." + ) - def _save_preferences(self): - with open(self.path, "w") as f: - json.dump(self._data, f, indent=4, ensure_ascii=False) - f.flush() + async def range_get_async( + self, scope: str, scope_id: str | None = None, key: str | None = None + ) -> list[Preference]: + """获取指定范围的偏好设置 + Note: 返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。scope_id 和 key 可以为 None,这时返回该范围下所有的偏好设置。 + """ + ret = await self.db_helper.get_preferences(scope, scope_id, key) + return ret - def get(self, key, default: _VT = None) -> _VT: - return self._data.get(key, default) + @overload + async def session_get( + self, umo: None, key: str, default: Any = None + ) -> list[Preference]: ... - def put(self, key, value): - self._data[key] = value - self._save_preferences() + @overload + async def session_get( + self, umo: str, key: None, default: Any = None + ) -> list[Preference]: ... - def remove(self, key): - if key in self._data: - del self._data[key] - self._save_preferences() + @overload + async def session_get( + self, umo: None, key: None, default: Any = None + ) -> list[Preference]: ... - def clear(self): - self._data.clear() - self._save_preferences() + async def session_get( + self, umo: str | None, key: str | None = None, default: _VT = None + ) -> _VT | list[Preference]: + """获取会话范围的偏好设置 + + Note: 当 scope_id 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。 + """ + if umo is None or key is None: + return await self.range_get_async("umo", umo, key) + return await self.get_async("umo", umo, key, default) + + @overload + async def global_get( + self, key: None, default: Any = None + ) -> list[Preference]: ... + + @overload + async def global_get( + self, key: str, default: _VT = None + ) -> _VT: ... + + async def global_get( + self, key: str | None, default: _VT = None + ) -> _VT | list[Preference]: + """获取全局范围的偏好设置 + + Note: 当 scope_id 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。 + """ + if key is None: + return await self.range_get_async("global", "global", key) + return await self.get_async("global", "global", key, default) + + async def put_async(self, scope: str, scope_id: str, key: str, value: Any): + """设置指定范围和键的偏好设置""" + await self.db_helper.insert_preference_or_update( + scope, scope_id, key, {"val": value} + ) + + async def session_put(self, umo: str, key: str, value: Any): + await self.put_async("umo", umo, key, value) + + async def global_put(self, key: str, value: Any): + await self.put_async("global", "global", key, value) + + async def remove_async(self, scope: str, scope_id: str, key: str): + """删除指定范围和键的偏好设置""" + await self.db_helper.remove_preference(scope, scope_id, key) + + async def session_remove(self, umo: str, key: str): + await self.remove_async("umo", umo, key) + + async def global_remove(self, key: str): + """删除全局偏好设置""" + await self.remove_async("global", "global", key) + + async def clear_async(self, scope: str, scope_id: str): + """清空指定范围的所有偏好设置""" + await self.db_helper.clear_preferences(scope, scope_id) + + # ==== + # DEPRECATED METHODS + # ==== + + def get( + self, + key: str, + default: _VT = None, + scope: str | None = None, + scope_id: str | None = "", + ) -> _VT: + """获取偏好设置(已弃用)""" + result: _VT | None = None + + def runner(): + nonlocal result, scope, scope_id + scope = scope or "unknown" + if scope_id == "": + scope_id = "unknown" + if scope_id is None or key is None: + # result = asyncio.run(self.range_get_async(scope, scope_id, key)) + raise ValueError( + "scope_id and key cannot be None when getting a specific preference." + ) + else: + result = asyncio.run(self.get_async(scope, scope_id, key, default)) + + t = threading.Thread(target=runner) + t.start() + t.join() + + if result is None: + return default + + return result + + def range_get( + self, scope: str, scope_id: str | None = None, key: str | None = None + ) -> list[Preference]: + """获取指定范围的偏好设置(已弃用)""" + result: list[Preference] = [] + + def runner(): + nonlocal result, scope, scope_id, key + result = asyncio.run(self.range_get_async(scope, scope_id, key)) + + t = threading.Thread(target=runner) + t.start() + t.join() + + return result + + def put(self, key, value, scope: str | None = None, scope_id: str | None = None): + """设置偏好设置(已弃用)""" + + def runner(): + nonlocal scope, scope_id + scope = scope or "unknown" + scope_id = scope_id or "unknown" + asyncio.run(self.put_async(scope, scope_id, key, value)) + + t = threading.Thread(target=runner) + t.start() + t.join() + + def remove(self, key, scope: str | None = None, scope_id: str | None = None): + """删除偏好设置(已弃用)""" + + def runner(): + nonlocal scope, scope_id + scope = scope or "unknown" + scope_id = scope_id or "unknown" + asyncio.run(self.remove_async(scope, scope_id, key)) + + t = threading.Thread(target=runner) + t.start() + t.join() + + def clear(self, scope: str | None = None, scope_id: str | None = None): + """清空偏好设置(已弃用)""" + + def runner(): + nonlocal scope, scope_id + scope = scope or "unknown" + scope_id = scope_id or "unknown" + asyncio.run(self.clear_async(scope, scope_id)) + + t = threading.Thread(target=runner) + t.start() + t.join() diff --git a/astrbot/dashboard/routes/session_management.py b/astrbot/dashboard/routes/session_management.py index c8d66e01c..5a3dd3d28 100644 --- a/astrbot/dashboard/routes/session_management.py +++ b/astrbot/dashboard/routes/session_management.py @@ -38,7 +38,13 @@ class SessionManagementRoute(Route): async def list_sessions(self): """获取所有会话的列表,包括 persona 和 provider 信息""" try: - session_conversations = sp.get("session_conversation", {}) or {} + preferences = await sp.session_get(umo=None, key="sel_conv_id", default=[]) + session_conversations = {} + for pref in preferences: + session_conversations[pref.scope_id] = pref.value["val"] + + logger.debug(session_conversations) + provider_manager = self.core_lifecycle.provider_manager persona_mgr = self.core_lifecycle.persona_mgr personas = persona_mgr.personas_v3 @@ -51,13 +57,9 @@ class SessionManagementRoute(Route): "session_id": session_id, "conversation_id": conversation_id, "persona_id": None, - "persona_name": None, "chat_provider_id": None, - "chat_provider_name": None, "stt_provider_id": None, - "stt_provider_name": None, "tts_provider_id": None, - "tts_provider_name": None, "session_enabled": SessionServiceManager.is_session_enabled( session_id ), @@ -92,16 +94,15 @@ class SessionManagementRoute(Route): if conversation.persona_id and conversation.persona_id != "[%None]": for persona in personas: if persona["name"] == conversation.persona_id: - session_info["persona_name"] = persona["name"] + session_info["persona_id"] = persona["name"] break elif conversation.persona_id == "[%None]": - session_info["persona_name"] = "无人格" + session_info["persona_id"] = "无人格" else: # 使用默认人格 default_persona = persona_mgr.selected_default_persona_v3 if default_persona: session_info["persona_id"] = default_persona["name"] - session_info["persona_name"] = default_persona["name"] # 获取 provider 信息 provider_manager = self.core_lifecycle.provider_manager @@ -117,15 +118,12 @@ class SessionManagementRoute(Route): if chat_provider: meta = chat_provider.meta() session_info["chat_provider_id"] = meta.id - session_info["chat_provider_name"] = meta.id if tts_provider: meta = tts_provider.meta() session_info["tts_provider_id"] = meta.id - session_info["tts_provider_name"] = meta.id if stt_provider: meta = stt_provider.meta() session_info["stt_provider_id"] = meta.id - session_info["stt_provider_name"] = meta.id sessions.append(session_info) diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 704b7df2b..0c85cfbf7 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -503,7 +503,7 @@ UID: {user_id} 此 ID 可用于设置管理员。 scene = RstScene.get_scene(is_group, is_unique_session) - alter_cmd_cfg = sp.get("alter_cmd", {}) + alter_cmd_cfg = await sp.get_async("global", "global", "alter_cmd", {}) plugin_config = alter_cmd_cfg.get("astrbot", {}) reset_cfg = plugin_config.get("reset", {}) @@ -1101,29 +1101,22 @@ UID: {user_id} 此 ID 可用于设置管理员。 async def set_variable(self, event: AstrMessageEvent, key: str, value: str): # session_id = event.get_session_id() uid = event.unified_msg_origin - session_vars = sp.get("session_variables", {}) - - session_var = session_vars.get(uid, {}) + session_var = await sp.session_get(uid, "session_variables", {}) session_var[key] = value - - session_vars[uid] = session_var - - sp.put("session_variables", session_vars) + await sp.session_put(uid, "session_variables", session_var) yield event.plain_result(f"会话 {uid} 变量 {key} 存储成功。使用 /unset 移除。") @filter.command("unset") async def unset_variable(self, event: AstrMessageEvent, key: str): uid = event.unified_msg_origin - session_vars = sp.get("session_variables", {}) - - session_var = session_vars.get(uid, {}) + session_var = await sp.session_get(umo="uid", key="session_variables", default={}) if key not in session_var: yield event.plain_result("没有那个变量名。格式 /unset 变量名。") else: del session_var[key] - sp.put("session_variables", session_vars) + await sp.session_put(uid, "session_variables", session_var) yield event.plain_result(f"会话 {uid} 变量 {key} 移除成功。") @filter.platform_adapter_type(filter.PlatformAdapterType.ALL) @@ -1356,7 +1349,7 @@ UID: {user_id} 此 ID 可用于设置管理员。 # 对reset权限进行特殊处理 # ============================ if cmd_name == "reset" and cmd_type == "config": - alter_cmd_cfg = sp.get("alter_cmd", {}) + alter_cmd_cfg = await sp.global_get("alter_cmd", {}) plugin_ = alter_cmd_cfg.get("astrbot", {}) reset_cfg = plugin_.get("reset", {}) @@ -1391,7 +1384,7 @@ UID: {user_id} 此 ID 可用于设置管理员。 scene = RstScene.from_index(scene_num) scene_key = scene.key - self.update_reset_permission(scene_key, perm_type) + await self.update_reset_permission(scene_key, perm_type) yield event.plain_result( f"已将 reset 命令在{scene.name}场景下的权限设为{perm_type}" @@ -1422,14 +1415,14 @@ UID: {user_id} 此 ID 可用于设置管理员。 found_plugin = star_map[found_command.handler_module_path] - alter_cmd_cfg = sp.get("alter_cmd", {}) + alter_cmd_cfg = await sp.global_get("alter_cmd", {}) plugin_ = alter_cmd_cfg.get(found_plugin.name, {}) cfg = plugin_.get(found_command.handler_name, {}) cfg["permission"] = cmd_type plugin_[found_command.handler_name] = cfg alter_cmd_cfg[found_plugin.name] = plugin_ - sp.put("alter_cmd", alter_cmd_cfg) + await sp.global_put("alter_cmd", alter_cmd_cfg) # 注入权限过滤器 found_permission_filter = False @@ -1453,17 +1446,17 @@ UID: {user_id} 此 ID 可用于设置管理员。 yield event.plain_result(f"已将 {cmd_name} 设置为 {cmd_type} 指令") - def update_reset_permission(self, scene_key: str, perm_type: str): + async def update_reset_permission(self, scene_key: str, perm_type: str): """更新reset命令在特定场景下的权限设置 Args: scene_key (str): 场景编号,1-3 perm_type (str): 权限类型,admin或member """ - alter_cmd_cfg = sp.get("alter_cmd", {}) + alter_cmd_cfg = await sp.global_get("alter_cmd", {}) plugin_cfg = alter_cmd_cfg.get("astrbot", {}) reset_cfg = plugin_cfg.get("reset", {}) reset_cfg[scene_key] = perm_type plugin_cfg["reset"] = reset_cfg alter_cmd_cfg["astrbot"] = plugin_cfg - sp.put("alter_cmd", alter_cmd_cfg) + await sp.global_put("alter_cmd", alter_cmd_cfg)