refactor: 重构 SharedPreference 类并采用数据库存储替换 json 存储 (#2482)
This commit is contained in:
@@ -20,7 +20,7 @@ html_renderer = HtmlRenderer(t2i_base_url)
|
||||
logger = LogManager.GetLogger(log_name="astrbot")
|
||||
db_helper = SQLiteDatabase(DB_PATH)
|
||||
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||
sp = SharedPreferences()
|
||||
sp = SharedPreferences(db_helper=db_helper)
|
||||
# 文件令牌服务
|
||||
file_token_service = FileTokenService()
|
||||
pip_installer = PipInstaller(
|
||||
|
||||
@@ -40,7 +40,9 @@ class AstrBotConfigManager:
|
||||
|
||||
def _load_all_configs(self):
|
||||
"""Load all configurations from the shared preferences."""
|
||||
abconf_data = self.sp.get("abconf_mapping", {})
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
for uuid_, meta in abconf_data.items():
|
||||
filename = meta["path"]
|
||||
conf_path = os.path.join(get_astrbot_config_path(), filename)
|
||||
@@ -55,13 +57,13 @@ class AstrBotConfigManager:
|
||||
|
||||
def _is_umo_match(self, p1: str, p2: str) -> bool:
|
||||
"""判断 p2 umo 是否逻辑包含于 p1 umo"""
|
||||
p1 = p1.split(":")
|
||||
p2 = p2.split(":")
|
||||
p1_ls = p1.split(":")
|
||||
p2_ls = p2.split(":")
|
||||
|
||||
if len(p1) != 3 or len(p2) != 3:
|
||||
if len(p1_ls) != 3 or len(p2_ls) != 3:
|
||||
return False # 非法格式
|
||||
|
||||
return all(p == "" or p == t for p, t in zip(p1, p2))
|
||||
return all(p == "" or p == t for p, t in zip(p1_ls, p2_ls))
|
||||
|
||||
def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo:
|
||||
"""获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default")
|
||||
@@ -70,7 +72,9 @@ class AstrBotConfigManager:
|
||||
ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
|
||||
"""
|
||||
# uuid -> { "umop": list, "path": str, "name": str }
|
||||
abconf_data = self.sp.get("abconf_mapping", {}) # default is not included here
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
if isinstance(umo, MessageSession):
|
||||
umo = str(umo)
|
||||
else:
|
||||
@@ -91,7 +95,7 @@ class AstrBotConfigManager:
|
||||
abconf_path: str,
|
||||
abconf_id: str,
|
||||
umo_parts: list[str] | list[MessageSession],
|
||||
abconf_name: str = None,
|
||||
abconf_name: str | None = None,
|
||||
) -> None:
|
||||
"""保存配置文件的映射关系"""
|
||||
for part in umo_parts:
|
||||
@@ -101,14 +105,16 @@ class AstrBotConfigManager:
|
||||
raise ValueError(
|
||||
"umo_parts must be a list of strings or MessageSession instances"
|
||||
)
|
||||
abconf_data = self.sp.get("abconf_mapping", {})
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
random_word = abconf_name or uuid.uuid4().hex[:8]
|
||||
abconf_data[abconf_id] = {
|
||||
"umop": umo_parts,
|
||||
"path": abconf_path,
|
||||
"name": random_word,
|
||||
}
|
||||
self.sp.put("abconf_mapping", abconf_data)
|
||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||
|
||||
def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig:
|
||||
"""获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。"""
|
||||
@@ -141,7 +147,10 @@ class AstrBotConfigManager:
|
||||
"""获取所有配置文件的元数据列表"""
|
||||
conf_list = []
|
||||
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
|
||||
for uuid_, meta in self.sp.get("abconf_mapping", {}).items():
|
||||
abconf_mapping = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
for uuid_, meta in abconf_mapping.items():
|
||||
conf_list.append(ConfInfo(**meta, id=uuid_))
|
||||
return conf_list
|
||||
|
||||
@@ -149,7 +158,7 @@ class AstrBotConfigManager:
|
||||
self,
|
||||
umo_parts: list[str] | list[MessageSession],
|
||||
config: dict = DEFAULT_CONFIG,
|
||||
name: str = None,
|
||||
name: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。
|
||||
@@ -181,7 +190,9 @@ class AstrBotConfigManager:
|
||||
raise ValueError("不能删除默认配置文件")
|
||||
|
||||
# 从映射中移除
|
||||
abconf_data = self.sp.get("abconf_mapping", {})
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
if conf_id not in abconf_data:
|
||||
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
|
||||
return False
|
||||
@@ -206,13 +217,13 @@ class AstrBotConfigManager:
|
||||
|
||||
# 从映射中移除
|
||||
del abconf_data[conf_id]
|
||||
self.sp.put("abconf_mapping", abconf_data)
|
||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||
|
||||
logger.info(f"成功删除配置文件 {conf_id}")
|
||||
return True
|
||||
|
||||
def update_conf_info(
|
||||
self, conf_id: str, name: str = None, umo_parts: list[str] = None
|
||||
self, conf_id: str, name: str | None = None, umo_parts: list[str] | None = None
|
||||
) -> bool:
|
||||
"""更新配置文件信息
|
||||
|
||||
@@ -227,7 +238,9 @@ class AstrBotConfigManager:
|
||||
if conf_id == "default":
|
||||
raise ValueError("不能更新默认配置文件的信息")
|
||||
|
||||
abconf_data = self.sp.get("abconf_mapping", {})
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
if conf_id not in abconf_data:
|
||||
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
|
||||
return False
|
||||
@@ -249,11 +262,13 @@ class AstrBotConfigManager:
|
||||
abconf_data[conf_id]["umop"] = umo_parts
|
||||
|
||||
# 保存更新
|
||||
self.sp.put("abconf_mapping", abconf_data)
|
||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||
logger.info(f"成功更新配置文件 {conf_id} 的信息")
|
||||
return True
|
||||
|
||||
def g(self, umo: str = None, key: str = None, default: _VT = None) -> _VT:
|
||||
def g(
|
||||
self, umo: str | None = None, key: str | None = None, default: _VT = None
|
||||
) -> _VT:
|
||||
"""获取配置项。umo 为 None 时使用默认配置"""
|
||||
if umo is None:
|
||||
return self.confs["default"].get(key, default)
|
||||
|
||||
@@ -6,7 +6,6 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json
|
||||
"""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
from astrbot.core import sp
|
||||
from typing import Dict, List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
@@ -17,25 +16,9 @@ class ConversationManager:
|
||||
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
|
||||
|
||||
def __init__(self, db_helper: BaseDatabase):
|
||||
# session_conversations 字典记录会话ID-对话ID 映射关系
|
||||
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
|
||||
self.session_conversations: Dict[str, str] = {}
|
||||
self.db = db_helper
|
||||
self.save_interval = 60 # 每 60 秒保存一次
|
||||
self._start_periodic_save()
|
||||
|
||||
def _start_periodic_save(self):
|
||||
"""启动定时保存任务"""
|
||||
asyncio.create_task(self._periodic_save())
|
||||
|
||||
async def _periodic_save(self):
|
||||
"""定时保存会话对话映射关系到存储中"""
|
||||
while True:
|
||||
await asyncio.sleep(self.save_interval)
|
||||
self._save_to_storage()
|
||||
|
||||
def _save_to_storage(self):
|
||||
"""保存会话对话映射关系到存储中"""
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
|
||||
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
|
||||
"""将 ConversationV2 对象转换为 Conversation 对象"""
|
||||
@@ -55,10 +38,10 @@ class ConversationManager:
|
||||
async def new_conversation(
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
platform_id: str = None,
|
||||
content: list[dict] = None,
|
||||
title: str = None,
|
||||
persona_id: str = None,
|
||||
platform_id: str | None = None,
|
||||
content: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
) -> str:
|
||||
"""新建对话,并将当前会话的对话转移到新对话
|
||||
|
||||
@@ -82,8 +65,8 @@ class ConversationManager:
|
||||
persona_id=persona_id,
|
||||
)
|
||||
self.session_conversations[unified_msg_origin] = conv.conversation_id
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
return str(conv.conversation_id)
|
||||
await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id)
|
||||
return conv.conversation_id
|
||||
|
||||
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
|
||||
"""切换会话的对话
|
||||
@@ -93,10 +76,10 @@ class ConversationManager:
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
self.session_conversations[unified_msg_origin] = conversation_id
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id)
|
||||
|
||||
async def delete_conversation(
|
||||
self, unified_msg_origin: str, conversation_id: str = None
|
||||
self, unified_msg_origin: str, conversation_id: str | None = None
|
||||
):
|
||||
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
|
||||
|
||||
@@ -113,9 +96,9 @@ class ConversationManager:
|
||||
await self.db.delete_conversation(cid=conversation_id)
|
||||
if f:
|
||||
self.session_conversations.pop(unified_msg_origin, None)
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
||||
|
||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
|
||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
|
||||
"""获取会话当前的对话 ID
|
||||
|
||||
Args:
|
||||
@@ -123,7 +106,12 @@ class ConversationManager:
|
||||
Returns:
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
return self.session_conversations.get(unified_msg_origin, None)
|
||||
ret = self.session_conversations.get(unified_msg_origin, None)
|
||||
if not ret:
|
||||
ret = await sp.session_get(unified_msg_origin, "sel_conv_id", None)
|
||||
if ret:
|
||||
self.session_conversations[unified_msg_origin] = ret
|
||||
return ret
|
||||
|
||||
async def get_conversation(
|
||||
self,
|
||||
@@ -150,7 +138,7 @@ class ConversationManager:
|
||||
return conv_res
|
||||
|
||||
async def get_conversations(
|
||||
self, unified_msg_origin: str = None, platform_id: str = None
|
||||
self, unified_msg_origin: str | None = None, platform_id: str | None = None
|
||||
) -> List[Conversation]:
|
||||
"""获取对话列表
|
||||
|
||||
@@ -203,10 +191,10 @@ class ConversationManager:
|
||||
async def update_conversation(
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
conversation_id: str = None,
|
||||
history: list[dict] = None,
|
||||
title: str = None,
|
||||
persona_id: str = None,
|
||||
conversation_id: str | None = None,
|
||||
history: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
):
|
||||
"""更新会话的对话
|
||||
|
||||
@@ -216,8 +204,8 @@ class ConversationManager:
|
||||
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
|
||||
"""
|
||||
if not conversation_id:
|
||||
# 如果没有提供 conversation_id,则从 session_conversations 中获取当前的
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
# 如果没有提供 conversation_id,则获取当前的
|
||||
conversation_id = await self.get_curr_conversation_id(unified_msg_origin)
|
||||
if conversation_id:
|
||||
await self.db.update_conversation(
|
||||
cid=conversation_id,
|
||||
@@ -227,7 +215,7 @@ class ConversationManager:
|
||||
)
|
||||
|
||||
async def update_conversation_title(
|
||||
self, unified_msg_origin: str, title: str, conversation_id: str = None
|
||||
self, unified_msg_origin: str, title: str, conversation_id: str | None = None
|
||||
):
|
||||
"""更新会话的对话标题
|
||||
|
||||
@@ -245,7 +233,10 @@ class ConversationManager:
|
||||
)
|
||||
|
||||
async def update_conversation_persona_id(
|
||||
self, unified_msg_origin: str, persona_id: str, conversation_id: str = None
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
persona_id: str,
|
||||
conversation_id: str | None = None,
|
||||
):
|
||||
"""更新会话的对话 Persona ID
|
||||
|
||||
|
||||
+40
-21
@@ -74,7 +74,7 @@ class BaseDatabase(abc.ABC):
|
||||
platform_id: str,
|
||||
platform_type: str,
|
||||
count: int = 1,
|
||||
timestamp: datetime.datetime = None,
|
||||
timestamp: datetime.datetime | None = None,
|
||||
) -> None:
|
||||
"""Insert a new platform statistic record."""
|
||||
...
|
||||
@@ -91,7 +91,7 @@ class BaseDatabase(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_conversations(
|
||||
self, user_id: str = None, platform_id: str = None
|
||||
self, user_id: str | None = None, platform_id: str | None = None
|
||||
) -> list[ConversationV2]:
|
||||
"""Get all conversations for a specific user and platform_id(optional).
|
||||
|
||||
@@ -128,12 +128,12 @@ class BaseDatabase(abc.ABC):
|
||||
self,
|
||||
user_id: str,
|
||||
platform_id: str,
|
||||
content: list[dict] = None,
|
||||
title: str = None,
|
||||
persona_id: str = None,
|
||||
cid: str = None,
|
||||
created_at: datetime.datetime = None,
|
||||
updated_at: datetime.datetime = None,
|
||||
content: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
cid: str | None = None,
|
||||
created_at: datetime.datetime | None = None,
|
||||
updated_at: datetime.datetime | None = None,
|
||||
) -> ConversationV2:
|
||||
"""Create a new conversation."""
|
||||
...
|
||||
@@ -142,9 +142,9 @@ class BaseDatabase(abc.ABC):
|
||||
async def update_conversation(
|
||||
self,
|
||||
cid: str,
|
||||
title: str = None,
|
||||
persona_id: str = None,
|
||||
content: list[dict] = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
content: list[dict] | None = None,
|
||||
) -> None:
|
||||
"""Update a conversation's history."""
|
||||
...
|
||||
@@ -160,8 +160,8 @@ class BaseDatabase(abc.ABC):
|
||||
platform_id: str,
|
||||
user_id: str,
|
||||
content: list[dict],
|
||||
sender_id: str = None,
|
||||
sender_name: str = None,
|
||||
sender_id: str | None = None,
|
||||
sender_name: str | None = None,
|
||||
) -> None:
|
||||
"""Insert a new platform message history record."""
|
||||
...
|
||||
@@ -204,8 +204,8 @@ class BaseDatabase(abc.ABC):
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str,
|
||||
begin_dialogs: list[str] = None,
|
||||
tools: list[str] = None,
|
||||
begin_dialogs: list[str] | None = None,
|
||||
tools: list[str] | None = None,
|
||||
) -> Persona:
|
||||
"""Insert a new persona record."""
|
||||
...
|
||||
@@ -224,9 +224,9 @@ class BaseDatabase(abc.ABC):
|
||||
async def update_persona(
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str = None,
|
||||
begin_dialogs: list[str] = None,
|
||||
tools: list[str] = None,
|
||||
system_prompt: str | None = None,
|
||||
begin_dialogs: list[str] | None = None,
|
||||
tools: list[str] | None = None,
|
||||
) -> Persona | None:
|
||||
"""Update a persona's system prompt or begin dialogs."""
|
||||
...
|
||||
@@ -237,13 +237,32 @@ class BaseDatabase(abc.ABC):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_preference_or_update(self, key: str, value: str) -> Preference:
|
||||
async def insert_preference_or_update(
|
||||
self, scope: str, scope_id: str, key: str, value: dict
|
||||
) -> Preference:
|
||||
"""Insert a new preference record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_preference(self, key: str) -> Preference:
|
||||
"""Get a preference by bot ID and key."""
|
||||
async def get_preference(self, scope: str, scope_id: str, key: str) -> Preference:
|
||||
"""Get a preference by scope ID and key."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_preferences(
|
||||
self, scope: str, scope_id: str | None = None, key: str | None = None
|
||||
) -> list[Preference]:
|
||||
"""Get all preferences for a specific scope ID or key."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def remove_preference(self, scope: str, scope_id: str, key: str) -> None:
|
||||
"""Remove a preference by scope ID and key."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def clear_preferences(self, scope: str, scope_id: str) -> None:
|
||||
"""Clear all preferences for a specific scope ID."""
|
||||
...
|
||||
|
||||
# @abc.abstractmethod
|
||||
|
||||
@@ -2,12 +2,13 @@ import os
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.api import logger
|
||||
from astrbot.api import logger, sp
|
||||
from .migra_3_to_4 import (
|
||||
migration_conversation_table,
|
||||
migration_platform_table,
|
||||
migration_webchat_data,
|
||||
migration_persona_data,
|
||||
migration_preferences,
|
||||
)
|
||||
|
||||
|
||||
@@ -19,7 +20,9 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
|
||||
data_v3_exists = os.path.exists(get_astrbot_data_path())
|
||||
if not data_v3_exists:
|
||||
return False
|
||||
migration_done = await db_helper.get_preference("migration_done_v4")
|
||||
migration_done = await db_helper.get_preference(
|
||||
"global", "global", "migration_done_v4"
|
||||
)
|
||||
if migration_done:
|
||||
return False
|
||||
return True
|
||||
@@ -49,10 +52,13 @@ async def do_migration_v4(
|
||||
# 执行 WebChat 数据迁移
|
||||
await migration_webchat_data(db_helper, platform_id_map)
|
||||
|
||||
# 执行偏好设置迁移
|
||||
await migration_preferences(db_helper,platform_id_map)
|
||||
|
||||
# 执行平台统计表迁移
|
||||
await migration_platform_table(db_helper, platform_id_map)
|
||||
|
||||
# 标记迁移完成
|
||||
await db_helper.insert_preference_or_update("migration_done_v4", "true")
|
||||
await sp.put_async("global", "global", "migration_done_v4", True)
|
||||
|
||||
logger.info("数据库迁移完成。")
|
||||
|
||||
@@ -2,8 +2,9 @@ import json
|
||||
import datetime
|
||||
from .. import BaseDatabase
|
||||
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
|
||||
from .shared_preferences_v3 import sp as sp_v3
|
||||
from astrbot.core.config.default import DB_PATH
|
||||
from astrbot.api import logger
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -89,7 +90,7 @@ async def migration_conversation_table(
|
||||
|
||||
|
||||
async def migration_platform_table(
|
||||
db_helper: BaseDatabase, platform_id_map: dict[str, str]
|
||||
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
|
||||
):
|
||||
db_helper_v3 = SQLiteV3DatabaseV3(
|
||||
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
|
||||
@@ -122,9 +123,7 @@ async def migration_platform_table(
|
||||
for bucket_idx, bucket_end in enumerate(range(start_time, end_time, 3600)):
|
||||
if bucket_idx % 500 == 0:
|
||||
progress = int((bucket_idx + 1) / total_buckets * 100)
|
||||
logger.info(
|
||||
f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})"
|
||||
)
|
||||
logger.info(f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})")
|
||||
cnt = 0
|
||||
while (
|
||||
idx < len(platform_stats_v3)
|
||||
@@ -258,3 +257,82 @@ async def migration_persona_data(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Persona 配置失败:{e}")
|
||||
|
||||
|
||||
async def migration_preferences(
|
||||
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
|
||||
):
|
||||
# 1. global scope migration
|
||||
keys = [
|
||||
"inactivated_llm_tools",
|
||||
"inactivated_plugins",
|
||||
"curr_provider",
|
||||
"curr_provider_tts",
|
||||
"curr_provider_stt",
|
||||
"alter_cmd",
|
||||
]
|
||||
for key in keys:
|
||||
value = sp_v3.get(key)
|
||||
if value is not None:
|
||||
await sp.put_async("global", "global", key, value)
|
||||
logger.info(f"迁移全局偏好设置 {key} 成功,值: {value}")
|
||||
|
||||
# 2. umo scope migration
|
||||
session_conversation = sp_v3.get("session_conversation", default={})
|
||||
for umo, conversation_id in session_conversation.items():
|
||||
if not umo or not conversation_id:
|
||||
continue
|
||||
try:
|
||||
session = MessageSesion.from_str(session_str=umo)
|
||||
platform_id = get_platform_id(platform_id_map, session.platform_name)
|
||||
session.platform_id = platform_id
|
||||
await sp.put_async("umo", str(session), "sel_conv_id", conversation_id)
|
||||
logger.info(f"迁移会话 {umo} 的对话数据到新表成功,平台 ID: {platform_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移会话 {umo} 的对话数据失败: {e}", exc_info=True)
|
||||
|
||||
session_service_config = sp_v3.get("session_service_config", default={})
|
||||
for umo, config in session_service_config.items():
|
||||
if not umo or not config:
|
||||
continue
|
||||
try:
|
||||
session = MessageSesion.from_str(session_str=umo)
|
||||
platform_id = get_platform_id(platform_id_map, session.platform_name)
|
||||
session.platform_id = platform_id
|
||||
|
||||
await sp.put_async("umo", str(session), "session_service_config", config)
|
||||
|
||||
logger.info(f"迁移会话 {umo} 的服务配置到新表成功,平台 ID: {platform_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移会话 {umo} 的服务配置失败: {e}", exc_info=True)
|
||||
|
||||
session_variables = sp_v3.get("session_variables", default={})
|
||||
for umo, variables in session_variables.items():
|
||||
if not umo or not variables:
|
||||
continue
|
||||
try:
|
||||
session = MessageSesion.from_str(session_str=umo)
|
||||
platform_id = get_platform_id(platform_id_map, session.platform_name)
|
||||
session.platform_id = platform_id
|
||||
await sp.put_async("umo", str(session), "session_variables", variables)
|
||||
except Exception as e:
|
||||
logger.error(f"迁移会话 {umo} 的变量失败: {e}", exc_info=True)
|
||||
|
||||
session_provider_perf = sp_v3.get("session_provider_perf", default={})
|
||||
for umo, perf in session_provider_perf.items():
|
||||
if not umo or not perf:
|
||||
continue
|
||||
try:
|
||||
session = MessageSesion.from_str(session_str=umo)
|
||||
platform_id = get_platform_id(platform_id_map, session.platform_name)
|
||||
session.platform_id = platform_id
|
||||
|
||||
for provider_type, provider_id in perf.items():
|
||||
await sp.put_async(
|
||||
"umo", str(session), f"provider_perf_{provider_type}", provider_id
|
||||
)
|
||||
logger.info(
|
||||
f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True)
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TypeVar
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
class SharedPreferences:
|
||||
def __init__(self, path=None):
|
||||
if path is None:
|
||||
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
|
||||
self.path = path
|
||||
self._data = self._load_preferences()
|
||||
|
||||
def _load_preferences(self):
|
||||
if os.path.exists(self.path):
|
||||
try:
|
||||
with open(self.path, "r") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
os.remove(self.path)
|
||||
return {}
|
||||
|
||||
def _save_preferences(self):
|
||||
with open(self.path, "w") as f:
|
||||
json.dump(self._data, f, indent=4, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
def get(self, key, default: _VT = None) -> _VT:
|
||||
return self._data.get(key, default)
|
||||
|
||||
def put(self, key, value):
|
||||
self._data[key] = value
|
||||
self._save_preferences()
|
||||
|
||||
def remove(self, key):
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
self._save_preferences()
|
||||
|
||||
def clear(self):
|
||||
self._data.clear()
|
||||
self._save_preferences()
|
||||
|
||||
sp = SharedPreferences()
|
||||
+21
-5
@@ -97,18 +97,34 @@ class Persona(SQLModel, table=True):
|
||||
|
||||
|
||||
class Preference(SQLModel, table=True):
|
||||
"""This class represents user preferences for bots."""
|
||||
"""This class represents preferences for bots."""
|
||||
|
||||
__tablename__ = "preferences"
|
||||
|
||||
key: str = Field(primary_key=True, nullable=False)
|
||||
value: str = Field(sa_type=Text, nullable=False)
|
||||
id: int | None = Field(
|
||||
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
scope: str = Field(nullable=False)
|
||||
"""Scope of the preference, such as 'global', 'umo', 'plugin'."""
|
||||
scope_id: str = Field(nullable=False)
|
||||
"""ID of the scope, such as 'global', 'umo', 'plugin_name'."""
|
||||
key: str = Field(nullable=False)
|
||||
value: dict = Field(sa_type=JSON, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"scope",
|
||||
"scope_id",
|
||||
"key",
|
||||
name="uix_preference_scope_scope_id_key",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class PlatformMessageHistory(SQLModel, table=True):
|
||||
"""This class represents the message history for a specific platform.
|
||||
@@ -184,8 +200,8 @@ class Conversation:
|
||||
"""对话 ID, 是 uuid 格式的字符串"""
|
||||
history: str = ""
|
||||
"""字符串格式的对话列表。"""
|
||||
title: str = ""
|
||||
persona_id: str = ""
|
||||
title: str | None = ""
|
||||
persona_id: str | None = ""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from sqlalchemy.sql import func
|
||||
|
||||
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||
|
||||
|
||||
class SQLiteDatabase(BaseDatabase):
|
||||
def __init__(self, db_path: str) -> None:
|
||||
self.db_path = db_path
|
||||
@@ -373,29 +374,77 @@ class SQLiteDatabase(BaseDatabase):
|
||||
delete(Persona).where(Persona.persona_id == persona_id)
|
||||
)
|
||||
|
||||
async def insert_preference_or_update(self, key, value):
|
||||
async def insert_preference_or_update(self, scope, scope_id, key, value):
|
||||
"""Insert a new preference record or update if it exists."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = select(Preference).where(Preference.key == key)
|
||||
query = select(Preference).where(
|
||||
Preference.scope == scope,
|
||||
Preference.scope_id == scope_id,
|
||||
Preference.key == key,
|
||||
)
|
||||
result = await session.execute(query)
|
||||
existing_preference = result.scalar_one_or_none()
|
||||
if existing_preference:
|
||||
existing_preference.value = value
|
||||
else:
|
||||
new_preference = Preference(key=key, value=value)
|
||||
new_preference = Preference(
|
||||
scope=scope, scope_id=scope_id, key=key, value=value
|
||||
)
|
||||
session.add(new_preference)
|
||||
return existing_preference or new_preference
|
||||
|
||||
async def get_preference(self, key):
|
||||
async def get_preference(self, scope, scope_id, key):
|
||||
"""Get a preference by key."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(Preference).where(Preference.key == key)
|
||||
query = select(Preference).where(
|
||||
Preference.scope == scope,
|
||||
Preference.scope_id == scope_id,
|
||||
Preference.key == key,
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_preferences(self, scope, scope_id=None, key=None):
|
||||
"""Get all preferences for a specific scope ID or key."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(Preference).where(Preference.scope == scope)
|
||||
if scope_id is not None:
|
||||
query = query.where(Preference.scope_id == scope_id)
|
||||
if key is not None:
|
||||
query = query.where(Preference.key == key)
|
||||
result = await session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def remove_preference(self, scope, scope_id, key):
|
||||
"""Remove a preference by scope ID and key."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(Preference).where(
|
||||
Preference.scope == scope,
|
||||
Preference.scope_id == scope_id,
|
||||
Preference.key == key,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def clear_preferences(self, scope, scope_id):
|
||||
"""Clear all preferences for a specific scope ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(Preference).where(
|
||||
Preference.scope == scope, Preference.scope_id == scope_id
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# ====
|
||||
# Deprecated Methods
|
||||
# ====
|
||||
|
||||
@@ -77,7 +77,7 @@ class WebChatAdapter(Platform):
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="webchat", description="webchat", id=self.config.get("id", "")
|
||||
name="webchat", description="webchat", id="webchat"
|
||||
)
|
||||
|
||||
async def send_by_session(
|
||||
|
||||
@@ -425,10 +425,17 @@ class FunctionToolManager:
|
||||
if func_tool is not None:
|
||||
func_tool.active = False
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
inactivated_llm_tools: list = sp.get(
|
||||
"inactivated_llm_tools", [], scope="global", scope_id="global"
|
||||
)
|
||||
if name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
sp.put(
|
||||
"inactivated_llm_tools",
|
||||
inactivated_llm_tools,
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
|
||||
return True
|
||||
return False
|
||||
@@ -445,10 +452,17 @@ class FunctionToolManager:
|
||||
|
||||
func_tool.active = True
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
inactivated_llm_tools: list = sp.get(
|
||||
"inactivated_llm_tools", [], scope="global", scope_id="global"
|
||||
)
|
||||
if name in inactivated_llm_tools:
|
||||
inactivated_llm_tools.remove(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
sp.put(
|
||||
"inactivated_llm_tools",
|
||||
inactivated_llm_tools,
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -66,7 +66,7 @@ class ProviderManager:
|
||||
return self.persona_mgr.selected_default_persona_v3
|
||||
|
||||
async def set_provider(
|
||||
self, provider_id: str, provider_type: ProviderType, umo: str = None
|
||||
self, provider_id: str, provider_type: ProviderType, umo: str | None = None
|
||||
):
|
||||
"""设置提供商。
|
||||
|
||||
@@ -80,20 +80,20 @@ class ProviderManager:
|
||||
if provider_id not in self.inst_map:
|
||||
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
|
||||
if umo:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
session_perf = perf.get(umo, {})
|
||||
session_perf[provider_type.value] = provider_id
|
||||
perf[umo] = session_perf
|
||||
sp.put("session_provider_perf", perf)
|
||||
await sp.session_put(
|
||||
umo,
|
||||
f"provider_perf_{provider_type.value}",
|
||||
provider_id,
|
||||
)
|
||||
return
|
||||
# 不启用提供商会话隔离模式的情况
|
||||
self.curr_provider_inst = self.inst_map[provider_id]
|
||||
if provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
sp.put("curr_provider_tts", provider_id)
|
||||
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
|
||||
elif provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
sp.put("curr_provider_stt", provider_id)
|
||||
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
|
||||
elif provider_type == ProviderType.CHAT_COMPLETION:
|
||||
sp.put("curr_provider", provider_id)
|
||||
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
||||
|
||||
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
||||
"""根据提供商 ID 获取提供商实例"""
|
||||
@@ -111,9 +111,12 @@ class ProviderManager:
|
||||
"""
|
||||
provider = None
|
||||
if umo:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
session_perf = perf.get(umo, {})
|
||||
provider_id = session_perf.get(provider_type.value)
|
||||
provider_id = sp.get(
|
||||
f"provider_perf_{provider_type.value}",
|
||||
None,
|
||||
scope="umo",
|
||||
scope_id=umo,
|
||||
)
|
||||
if provider_id:
|
||||
provider = self.inst_map.get(provider_id)
|
||||
if not provider:
|
||||
@@ -153,13 +156,22 @@ class ProviderManager:
|
||||
|
||||
# 设置默认提供商
|
||||
selected_provider_id = sp.get(
|
||||
"curr_provider", self.provider_settings.get("default_provider_id")
|
||||
"curr_provider",
|
||||
self.provider_settings.get("default_provider_id"),
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
selected_stt_provider_id = sp.get(
|
||||
"curr_provider_stt", self.provider_stt_settings.get("provider_id")
|
||||
"curr_provider_stt",
|
||||
self.provider_stt_settings.get("provider_id"),
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
selected_tts_provider_id = sp.get(
|
||||
"curr_provider_tts", self.provider_tts_settings.get("provider_id")
|
||||
"curr_provider_tts",
|
||||
self.provider_tts_settings.get("provider_id"),
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
|
||||
if not self.curr_provider_inst and self.provider_insts:
|
||||
|
||||
@@ -75,8 +75,7 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
session_vars = sp.get("session_variables", {})
|
||||
session_var = session_vars.get(session_id, {})
|
||||
session_var = await sp.session_get(session_id, "session_variables", default={})
|
||||
payload_vars.update(session_var)
|
||||
|
||||
if (
|
||||
|
||||
@@ -97,8 +97,7 @@ class ProviderDify(Provider):
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
session_vars = sp.get("session_variables", {})
|
||||
session_var = session_vars.get(session_id, {})
|
||||
session_var = await sp.session_get(session_id, "session_variables", default={})
|
||||
payload_vars.update(session_var)
|
||||
payload_vars["system_prompt"] = system_prompt
|
||||
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
会话服务管理器 - 负责管理每个会话的LLM、TTS等服务的启停状态
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
@@ -26,8 +24,9 @@ class SessionServiceManager:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
session_services = sp.get(
|
||||
"session_service_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
# 如果配置了该会话的LLM状态,返回该状态
|
||||
llm_enabled = session_services.get("llm_enabled")
|
||||
@@ -45,16 +44,13 @@ class SessionServiceManager:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
# 设置LLM状态
|
||||
session_config[session_id]["llm_enabled"] = enabled
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
)
|
||||
session_config["llm_enabled"] = enabled
|
||||
sp.put(
|
||||
"session_service_config", session_config, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
@@ -88,8 +84,9 @@ class SessionServiceManager:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
session_services = sp.get(
|
||||
"session_service_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
# 如果配置了该会话的TTS状态,返回该状态
|
||||
tts_enabled = session_services.get("tts_enabled")
|
||||
@@ -107,16 +104,13 @@ class SessionServiceManager:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
# 设置TTS状态
|
||||
session_config[session_id]["tts_enabled"] = enabled
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
)
|
||||
session_config["tts_enabled"] = enabled
|
||||
sp.put(
|
||||
"session_service_config", session_config, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的TTS状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
@@ -150,8 +144,9 @@ class SessionServiceManager:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
session_services = sp.get(
|
||||
"session_service_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
# 如果配置了该会话的整体状态,返回该状态
|
||||
session_enabled = session_services.get("session_enabled")
|
||||
@@ -169,16 +164,13 @@ class SessionServiceManager:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
# 设置会话整体状态
|
||||
session_config[session_id]["session_enabled"] = enabled
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
)
|
||||
session_config["session_enabled"] = enabled
|
||||
sp.put(
|
||||
"session_service_config", session_config, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
@@ -202,7 +194,7 @@ class SessionServiceManager:
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def get_session_custom_name(session_id: str) -> str:
|
||||
def get_session_custom_name(session_id: str) -> str | None:
|
||||
"""获取会话的自定义名称
|
||||
|
||||
Args:
|
||||
@@ -211,8 +203,9 @@ class SessionServiceManager:
|
||||
Returns:
|
||||
str: 自定义名称,如果没有设置则返回None
|
||||
"""
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
session_services = sp.get(
|
||||
"session_service_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
return session_services.get("custom_name")
|
||||
|
||||
@staticmethod
|
||||
@@ -223,20 +216,17 @@ class SessionServiceManager:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
custom_name: 自定义名称,可以为空字符串来清除名称
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
# 设置自定义名称
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
)
|
||||
if custom_name and custom_name.strip():
|
||||
session_config[session_id]["custom_name"] = custom_name.strip()
|
||||
session_config["custom_name"] = custom_name.strip()
|
||||
else:
|
||||
# 如果传入空名称,则删除自定义名称
|
||||
session_config[session_id].pop("custom_name", None)
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
session_config.pop("custom_name", None)
|
||||
sp.put(
|
||||
"session_service_config", session_config, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}"
|
||||
@@ -258,36 +248,3 @@ class SessionServiceManager:
|
||||
|
||||
# 如果没有自定义名称,返回session_id的最后一段
|
||||
return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id
|
||||
|
||||
# =============================================================================
|
||||
# 通用配置方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def get_session_service_config(session_id: str) -> Dict[str, bool]:
|
||||
"""获取指定会话的服务配置
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
Dict[str, bool]: 包含session_enabled、llm_enabled、tts_enabled的字典
|
||||
"""
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
return session_config.get(
|
||||
session_id,
|
||||
{
|
||||
"session_enabled": True, # 默认启用
|
||||
"llm_enabled": True, # 默认启用
|
||||
"tts_enabled": True, # 默认启用
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_all_session_configs() -> Dict[str, Dict[str, bool]]:
|
||||
"""获取所有会话的服务配置
|
||||
|
||||
Returns:
|
||||
Dict[str, Dict[str, bool]]: 所有会话的服务配置
|
||||
"""
|
||||
return sp.get("session_service_config", {}) or {}
|
||||
|
||||
@@ -22,7 +22,9 @@ class SessionPluginManager:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话插件配置
|
||||
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||
session_plugin_config = sp.get(
|
||||
"session_plugin_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
session_config = session_plugin_config.get(session_id, {})
|
||||
|
||||
enabled_plugins = session_config.get("enabled_plugins", [])
|
||||
@@ -51,7 +53,9 @@ class SessionPluginManager:
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||
session_plugin_config = sp.get(
|
||||
"session_plugin_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
if session_id not in session_plugin_config:
|
||||
session_plugin_config[session_id] = {
|
||||
"enabled_plugins": [],
|
||||
@@ -79,7 +83,9 @@ class SessionPluginManager:
|
||||
session_config["enabled_plugins"] = enabled_plugins
|
||||
session_config["disabled_plugins"] = disabled_plugins
|
||||
session_plugin_config[session_id] = session_config
|
||||
sp.put("session_plugin_config", session_plugin_config)
|
||||
sp.put(
|
||||
"session_plugin_config", session_plugin_config, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
@@ -95,7 +101,9 @@ class SessionPluginManager:
|
||||
Returns:
|
||||
Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典
|
||||
"""
|
||||
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||
session_plugin_config = sp.get(
|
||||
"session_plugin_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
return session_plugin_config.get(
|
||||
session_id, {"enabled_plugins": [], "disabled_plugins": []}
|
||||
)
|
||||
|
||||
@@ -374,10 +374,9 @@ class PluginManager:
|
||||
- success (bool): 是否全部加载成功
|
||||
- error_message (str|None): 错误信息,成功时为 None
|
||||
"""
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
|
||||
alter_cmd = sp.get("alter_cmd", {})
|
||||
inactivated_plugins = await sp.global_get("inactivated_plugins", [])
|
||||
inactivated_llm_tools = await sp.global_get("inactivated_llm_tools", [])
|
||||
alter_cmd = await sp.global_get("alter_cmd", {})
|
||||
|
||||
plugin_modules = self._get_plugin_modules()
|
||||
if plugin_modules is None:
|
||||
@@ -787,12 +786,12 @@ class PluginManager:
|
||||
await self._terminate_plugin(plugin)
|
||||
|
||||
# 加入到 shared_preferences 中
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
|
||||
if plugin.module_path not in inactivated_plugins:
|
||||
inactivated_plugins.append(plugin.module_path)
|
||||
|
||||
inactivated_llm_tools: list = list(
|
||||
set(sp.get("inactivated_llm_tools", []))
|
||||
set(await sp.global_get("inactivated_llm_tools", []))
|
||||
) # 后向兼容
|
||||
|
||||
# 禁用插件启用的 llm_tool
|
||||
@@ -802,8 +801,8 @@ class PluginManager:
|
||||
if func_tool.name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(func_tool.name)
|
||||
|
||||
sp.put("inactivated_plugins", inactivated_plugins)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
await sp.global_put("inactivated_plugins", inactivated_plugins)
|
||||
await sp.global_put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
plugin.activated = False
|
||||
|
||||
@@ -829,11 +828,11 @@ class PluginManager:
|
||||
|
||||
async def turn_on_plugin(self, plugin_name: str):
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", [])
|
||||
if plugin.module_path in inactivated_plugins:
|
||||
inactivated_plugins.remove(plugin.module_path)
|
||||
sp.put("inactivated_plugins", inactivated_plugins)
|
||||
await sp.global_put("inactivated_plugins", inactivated_plugins)
|
||||
|
||||
# 启用插件启用的 llm_tool
|
||||
for func_tool in llm_tools.func_list:
|
||||
@@ -843,7 +842,7 @@ class PluginManager:
|
||||
):
|
||||
inactivated_llm_tools.remove(func_tool.name)
|
||||
func_tool.active = True
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
await sp.global_put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
await self.reload(plugin_name)
|
||||
|
||||
|
||||
@@ -1,43 +1,214 @@
|
||||
import json
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import Preference
|
||||
import threading
|
||||
import asyncio
|
||||
import os
|
||||
from typing import TypeVar
|
||||
from typing import TypeVar, Any, overload
|
||||
from .astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class SharedPreferences:
|
||||
def __init__(self, path=None):
|
||||
if path is None:
|
||||
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
|
||||
self.path = path
|
||||
self._data = self._load_preferences()
|
||||
def __init__(self, db_helper: BaseDatabase, json_storage_path=None):
|
||||
if json_storage_path is None:
|
||||
json_storage_path = os.path.join(
|
||||
get_astrbot_data_path(), "shared_preferences.json"
|
||||
)
|
||||
self.path = json_storage_path
|
||||
self.db_helper = db_helper
|
||||
|
||||
def _load_preferences(self):
|
||||
if os.path.exists(self.path):
|
||||
try:
|
||||
with open(self.path, "r") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
os.remove(self.path)
|
||||
return {}
|
||||
async def get_async(
|
||||
self,
|
||||
scope: str,
|
||||
scope_id: str,
|
||||
key: str,
|
||||
default: _VT = None,
|
||||
) -> _VT:
|
||||
"""获取指定范围和键的偏好设置"""
|
||||
if scope_id is not None and key is not None:
|
||||
result = await self.db_helper.get_preference(scope, scope_id, key)
|
||||
if result:
|
||||
ret = result.value["val"]
|
||||
else:
|
||||
ret = default
|
||||
return ret
|
||||
else:
|
||||
raise ValueError(
|
||||
"scope_id and key cannot be None when getting a specific preference."
|
||||
)
|
||||
|
||||
def _save_preferences(self):
|
||||
with open(self.path, "w") as f:
|
||||
json.dump(self._data, f, indent=4, ensure_ascii=False)
|
||||
f.flush()
|
||||
async def range_get_async(
|
||||
self, scope: str, scope_id: str | None = None, key: str | None = None
|
||||
) -> list[Preference]:
|
||||
"""获取指定范围的偏好设置
|
||||
Note: 返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。scope_id 和 key 可以为 None,这时返回该范围下所有的偏好设置。
|
||||
"""
|
||||
ret = await self.db_helper.get_preferences(scope, scope_id, key)
|
||||
return ret
|
||||
|
||||
def get(self, key, default: _VT = None) -> _VT:
|
||||
return self._data.get(key, default)
|
||||
@overload
|
||||
async def session_get(
|
||||
self, umo: None, key: str, default: Any = None
|
||||
) -> list[Preference]: ...
|
||||
|
||||
def put(self, key, value):
|
||||
self._data[key] = value
|
||||
self._save_preferences()
|
||||
@overload
|
||||
async def session_get(
|
||||
self, umo: str, key: None, default: Any = None
|
||||
) -> list[Preference]: ...
|
||||
|
||||
def remove(self, key):
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
self._save_preferences()
|
||||
@overload
|
||||
async def session_get(
|
||||
self, umo: None, key: None, default: Any = None
|
||||
) -> list[Preference]: ...
|
||||
|
||||
def clear(self):
|
||||
self._data.clear()
|
||||
self._save_preferences()
|
||||
async def session_get(
|
||||
self, umo: str | None, key: str | None = None, default: _VT = None
|
||||
) -> _VT | list[Preference]:
|
||||
"""获取会话范围的偏好设置
|
||||
|
||||
Note: 当 scope_id 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。
|
||||
"""
|
||||
if umo is None or key is None:
|
||||
return await self.range_get_async("umo", umo, key)
|
||||
return await self.get_async("umo", umo, key, default)
|
||||
|
||||
@overload
|
||||
async def global_get(
|
||||
self, key: None, default: Any = None
|
||||
) -> list[Preference]: ...
|
||||
|
||||
@overload
|
||||
async def global_get(
|
||||
self, key: str, default: _VT = None
|
||||
) -> _VT: ...
|
||||
|
||||
async def global_get(
|
||||
self, key: str | None, default: _VT = None
|
||||
) -> _VT | list[Preference]:
|
||||
"""获取全局范围的偏好设置
|
||||
|
||||
Note: 当 scope_id 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。
|
||||
"""
|
||||
if key is None:
|
||||
return await self.range_get_async("global", "global", key)
|
||||
return await self.get_async("global", "global", key, default)
|
||||
|
||||
async def put_async(self, scope: str, scope_id: str, key: str, value: Any):
|
||||
"""设置指定范围和键的偏好设置"""
|
||||
await self.db_helper.insert_preference_or_update(
|
||||
scope, scope_id, key, {"val": value}
|
||||
)
|
||||
|
||||
async def session_put(self, umo: str, key: str, value: Any):
|
||||
await self.put_async("umo", umo, key, value)
|
||||
|
||||
async def global_put(self, key: str, value: Any):
|
||||
await self.put_async("global", "global", key, value)
|
||||
|
||||
async def remove_async(self, scope: str, scope_id: str, key: str):
|
||||
"""删除指定范围和键的偏好设置"""
|
||||
await self.db_helper.remove_preference(scope, scope_id, key)
|
||||
|
||||
async def session_remove(self, umo: str, key: str):
|
||||
await self.remove_async("umo", umo, key)
|
||||
|
||||
async def global_remove(self, key: str):
|
||||
"""删除全局偏好设置"""
|
||||
await self.remove_async("global", "global", key)
|
||||
|
||||
async def clear_async(self, scope: str, scope_id: str):
|
||||
"""清空指定范围的所有偏好设置"""
|
||||
await self.db_helper.clear_preferences(scope, scope_id)
|
||||
|
||||
# ====
|
||||
# DEPRECATED METHODS
|
||||
# ====
|
||||
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
default: _VT = None,
|
||||
scope: str | None = None,
|
||||
scope_id: str | None = "",
|
||||
) -> _VT:
|
||||
"""获取偏好设置(已弃用)"""
|
||||
result: _VT | None = None
|
||||
|
||||
def runner():
|
||||
nonlocal result, scope, scope_id
|
||||
scope = scope or "unknown"
|
||||
if scope_id == "":
|
||||
scope_id = "unknown"
|
||||
if scope_id is None or key is None:
|
||||
# result = asyncio.run(self.range_get_async(scope, scope_id, key))
|
||||
raise ValueError(
|
||||
"scope_id and key cannot be None when getting a specific preference."
|
||||
)
|
||||
else:
|
||||
result = asyncio.run(self.get_async(scope, scope_id, key, default))
|
||||
|
||||
t = threading.Thread(target=runner)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
if result is None:
|
||||
return default
|
||||
|
||||
return result
|
||||
|
||||
def range_get(
|
||||
self, scope: str, scope_id: str | None = None, key: str | None = None
|
||||
) -> list[Preference]:
|
||||
"""获取指定范围的偏好设置(已弃用)"""
|
||||
result: list[Preference] = []
|
||||
|
||||
def runner():
|
||||
nonlocal result, scope, scope_id, key
|
||||
result = asyncio.run(self.range_get_async(scope, scope_id, key))
|
||||
|
||||
t = threading.Thread(target=runner)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
return result
|
||||
|
||||
def put(self, key, value, scope: str | None = None, scope_id: str | None = None):
|
||||
"""设置偏好设置(已弃用)"""
|
||||
|
||||
def runner():
|
||||
nonlocal scope, scope_id
|
||||
scope = scope or "unknown"
|
||||
scope_id = scope_id or "unknown"
|
||||
asyncio.run(self.put_async(scope, scope_id, key, value))
|
||||
|
||||
t = threading.Thread(target=runner)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
def remove(self, key, scope: str | None = None, scope_id: str | None = None):
|
||||
"""删除偏好设置(已弃用)"""
|
||||
|
||||
def runner():
|
||||
nonlocal scope, scope_id
|
||||
scope = scope or "unknown"
|
||||
scope_id = scope_id or "unknown"
|
||||
asyncio.run(self.remove_async(scope, scope_id, key))
|
||||
|
||||
t = threading.Thread(target=runner)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
def clear(self, scope: str | None = None, scope_id: str | None = None):
|
||||
"""清空偏好设置(已弃用)"""
|
||||
|
||||
def runner():
|
||||
nonlocal scope, scope_id
|
||||
scope = scope or "unknown"
|
||||
scope_id = scope_id or "unknown"
|
||||
asyncio.run(self.clear_async(scope, scope_id))
|
||||
|
||||
t = threading.Thread(target=runner)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
@@ -38,7 +38,13 @@ class SessionManagementRoute(Route):
|
||||
async def list_sessions(self):
|
||||
"""获取所有会话的列表,包括 persona 和 provider 信息"""
|
||||
try:
|
||||
session_conversations = sp.get("session_conversation", {}) or {}
|
||||
preferences = await sp.session_get(umo=None, key="sel_conv_id", default=[])
|
||||
session_conversations = {}
|
||||
for pref in preferences:
|
||||
session_conversations[pref.scope_id] = pref.value["val"]
|
||||
|
||||
logger.debug(session_conversations)
|
||||
|
||||
provider_manager = self.core_lifecycle.provider_manager
|
||||
persona_mgr = self.core_lifecycle.persona_mgr
|
||||
personas = persona_mgr.personas_v3
|
||||
@@ -51,13 +57,9 @@ class SessionManagementRoute(Route):
|
||||
"session_id": session_id,
|
||||
"conversation_id": conversation_id,
|
||||
"persona_id": None,
|
||||
"persona_name": None,
|
||||
"chat_provider_id": None,
|
||||
"chat_provider_name": None,
|
||||
"stt_provider_id": None,
|
||||
"stt_provider_name": None,
|
||||
"tts_provider_id": None,
|
||||
"tts_provider_name": None,
|
||||
"session_enabled": SessionServiceManager.is_session_enabled(
|
||||
session_id
|
||||
),
|
||||
@@ -92,16 +94,15 @@ class SessionManagementRoute(Route):
|
||||
if conversation.persona_id and conversation.persona_id != "[%None]":
|
||||
for persona in personas:
|
||||
if persona["name"] == conversation.persona_id:
|
||||
session_info["persona_name"] = persona["name"]
|
||||
session_info["persona_id"] = persona["name"]
|
||||
break
|
||||
elif conversation.persona_id == "[%None]":
|
||||
session_info["persona_name"] = "无人格"
|
||||
session_info["persona_id"] = "无人格"
|
||||
else:
|
||||
# 使用默认人格
|
||||
default_persona = persona_mgr.selected_default_persona_v3
|
||||
if default_persona:
|
||||
session_info["persona_id"] = default_persona["name"]
|
||||
session_info["persona_name"] = default_persona["name"]
|
||||
|
||||
# 获取 provider 信息
|
||||
provider_manager = self.core_lifecycle.provider_manager
|
||||
@@ -117,15 +118,12 @@ class SessionManagementRoute(Route):
|
||||
if chat_provider:
|
||||
meta = chat_provider.meta()
|
||||
session_info["chat_provider_id"] = meta.id
|
||||
session_info["chat_provider_name"] = meta.id
|
||||
if tts_provider:
|
||||
meta = tts_provider.meta()
|
||||
session_info["tts_provider_id"] = meta.id
|
||||
session_info["tts_provider_name"] = meta.id
|
||||
if stt_provider:
|
||||
meta = stt_provider.meta()
|
||||
session_info["stt_provider_id"] = meta.id
|
||||
session_info["stt_provider_name"] = meta.id
|
||||
|
||||
sessions.append(session_info)
|
||||
|
||||
|
||||
+12
-19
@@ -503,7 +503,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
|
||||
scene = RstScene.get_scene(is_group, is_unique_session)
|
||||
|
||||
alter_cmd_cfg = sp.get("alter_cmd", {})
|
||||
alter_cmd_cfg = await sp.get_async("global", "global", "alter_cmd", {})
|
||||
plugin_config = alter_cmd_cfg.get("astrbot", {})
|
||||
reset_cfg = plugin_config.get("reset", {})
|
||||
|
||||
@@ -1101,29 +1101,22 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
|
||||
# session_id = event.get_session_id()
|
||||
uid = event.unified_msg_origin
|
||||
session_vars = sp.get("session_variables", {})
|
||||
|
||||
session_var = session_vars.get(uid, {})
|
||||
session_var = await sp.session_get(uid, "session_variables", {})
|
||||
session_var[key] = value
|
||||
|
||||
session_vars[uid] = session_var
|
||||
|
||||
sp.put("session_variables", session_vars)
|
||||
await sp.session_put(uid, "session_variables", session_var)
|
||||
|
||||
yield event.plain_result(f"会话 {uid} 变量 {key} 存储成功。使用 /unset 移除。")
|
||||
|
||||
@filter.command("unset")
|
||||
async def unset_variable(self, event: AstrMessageEvent, key: str):
|
||||
uid = event.unified_msg_origin
|
||||
session_vars = sp.get("session_variables", {})
|
||||
|
||||
session_var = session_vars.get(uid, {})
|
||||
session_var = await sp.session_get(umo="uid", key="session_variables", default={})
|
||||
|
||||
if key not in session_var:
|
||||
yield event.plain_result("没有那个变量名。格式 /unset 变量名。")
|
||||
else:
|
||||
del session_var[key]
|
||||
sp.put("session_variables", session_vars)
|
||||
await sp.session_put(uid, "session_variables", session_var)
|
||||
yield event.plain_result(f"会话 {uid} 变量 {key} 移除成功。")
|
||||
|
||||
@filter.platform_adapter_type(filter.PlatformAdapterType.ALL)
|
||||
@@ -1356,7 +1349,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
# 对reset权限进行特殊处理
|
||||
# ============================
|
||||
if cmd_name == "reset" and cmd_type == "config":
|
||||
alter_cmd_cfg = sp.get("alter_cmd", {})
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
|
||||
plugin_ = alter_cmd_cfg.get("astrbot", {})
|
||||
reset_cfg = plugin_.get("reset", {})
|
||||
|
||||
@@ -1391,7 +1384,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
scene = RstScene.from_index(scene_num)
|
||||
scene_key = scene.key
|
||||
|
||||
self.update_reset_permission(scene_key, perm_type)
|
||||
await self.update_reset_permission(scene_key, perm_type)
|
||||
|
||||
yield event.plain_result(
|
||||
f"已将 reset 命令在{scene.name}场景下的权限设为{perm_type}"
|
||||
@@ -1422,14 +1415,14 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
|
||||
found_plugin = star_map[found_command.handler_module_path]
|
||||
|
||||
alter_cmd_cfg = sp.get("alter_cmd", {})
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
|
||||
plugin_ = alter_cmd_cfg.get(found_plugin.name, {})
|
||||
cfg = plugin_.get(found_command.handler_name, {})
|
||||
cfg["permission"] = cmd_type
|
||||
plugin_[found_command.handler_name] = cfg
|
||||
alter_cmd_cfg[found_plugin.name] = plugin_
|
||||
|
||||
sp.put("alter_cmd", alter_cmd_cfg)
|
||||
await sp.global_put("alter_cmd", alter_cmd_cfg)
|
||||
|
||||
# 注入权限过滤器
|
||||
found_permission_filter = False
|
||||
@@ -1453,17 +1446,17 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
|
||||
yield event.plain_result(f"已将 {cmd_name} 设置为 {cmd_type} 指令")
|
||||
|
||||
def update_reset_permission(self, scene_key: str, perm_type: str):
|
||||
async def update_reset_permission(self, scene_key: str, perm_type: str):
|
||||
"""更新reset命令在特定场景下的权限设置
|
||||
|
||||
Args:
|
||||
scene_key (str): 场景编号,1-3
|
||||
perm_type (str): 权限类型,admin或member
|
||||
"""
|
||||
alter_cmd_cfg = sp.get("alter_cmd", {})
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
|
||||
plugin_cfg = alter_cmd_cfg.get("astrbot", {})
|
||||
reset_cfg = plugin_cfg.get("reset", {})
|
||||
reset_cfg[scene_key] = perm_type
|
||||
plugin_cfg["reset"] = reset_cfg
|
||||
alter_cmd_cfg["astrbot"] = plugin_cfg
|
||||
sp.put("alter_cmd", alter_cmd_cfg)
|
||||
await sp.global_put("alter_cmd", alter_cmd_cfg)
|
||||
|
||||
Reference in New Issue
Block a user