refactor: 重构 SharedPreference 类并采用数据库存储替换 json 存储 (#2482)

This commit is contained in:
Soulter
2025-08-18 19:12:26 +08:00
committed by GitHub
parent 9e7d46f956
commit 64bcbc9fc0
20 changed files with 650 additions and 281 deletions
+1 -1
View File
@@ -20,7 +20,7 @@ html_renderer = HtmlRenderer(t2i_base_url)
logger = LogManager.GetLogger(log_name="astrbot")
db_helper = SQLiteDatabase(DB_PATH)
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
sp = SharedPreferences()
sp = SharedPreferences(db_helper=db_helper)
# 文件令牌服务
file_token_service = FileTokenService()
pip_installer = PipInstaller(
+32 -17
View File
@@ -40,7 +40,9 @@ class AstrBotConfigManager:
def _load_all_configs(self):
"""Load all configurations from the shared preferences."""
abconf_data = self.sp.get("abconf_mapping", {})
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
for uuid_, meta in abconf_data.items():
filename = meta["path"]
conf_path = os.path.join(get_astrbot_config_path(), filename)
@@ -55,13 +57,13 @@ class AstrBotConfigManager:
def _is_umo_match(self, p1: str, p2: str) -> bool:
"""判断 p2 umo 是否逻辑包含于 p1 umo"""
p1 = p1.split(":")
p2 = p2.split(":")
p1_ls = p1.split(":")
p2_ls = p2.split(":")
if len(p1) != 3 or len(p2) != 3:
if len(p1_ls) != 3 or len(p2_ls) != 3:
return False # 非法格式
return all(p == "" or p == t for p, t in zip(p1, p2))
return all(p == "" or p == t for p, t in zip(p1_ls, p2_ls))
def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo:
"""获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default")
@@ -70,7 +72,9 @@ class AstrBotConfigManager:
ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
"""
# uuid -> { "umop": list, "path": str, "name": str }
abconf_data = self.sp.get("abconf_mapping", {}) # default is not included here
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
if isinstance(umo, MessageSession):
umo = str(umo)
else:
@@ -91,7 +95,7 @@ class AstrBotConfigManager:
abconf_path: str,
abconf_id: str,
umo_parts: list[str] | list[MessageSession],
abconf_name: str = None,
abconf_name: str | None = None,
) -> None:
"""保存配置文件的映射关系"""
for part in umo_parts:
@@ -101,14 +105,16 @@ class AstrBotConfigManager:
raise ValueError(
"umo_parts must be a list of strings or MessageSession instances"
)
abconf_data = self.sp.get("abconf_mapping", {})
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
random_word = abconf_name or uuid.uuid4().hex[:8]
abconf_data[abconf_id] = {
"umop": umo_parts,
"path": abconf_path,
"name": random_word,
}
self.sp.put("abconf_mapping", abconf_data)
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig:
"""获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。"""
@@ -141,7 +147,10 @@ class AstrBotConfigManager:
"""获取所有配置文件的元数据列表"""
conf_list = []
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
for uuid_, meta in self.sp.get("abconf_mapping", {}).items():
abconf_mapping = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
for uuid_, meta in abconf_mapping.items():
conf_list.append(ConfInfo(**meta, id=uuid_))
return conf_list
@@ -149,7 +158,7 @@ class AstrBotConfigManager:
self,
umo_parts: list[str] | list[MessageSession],
config: dict = DEFAULT_CONFIG,
name: str = None,
name: str | None = None,
) -> str:
"""
umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。
@@ -181,7 +190,9 @@ class AstrBotConfigManager:
raise ValueError("不能删除默认配置文件")
# 从映射中移除
abconf_data = self.sp.get("abconf_mapping", {})
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
if conf_id not in abconf_data:
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
return False
@@ -206,13 +217,13 @@ class AstrBotConfigManager:
# 从映射中移除
del abconf_data[conf_id]
self.sp.put("abconf_mapping", abconf_data)
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
logger.info(f"成功删除配置文件 {conf_id}")
return True
def update_conf_info(
self, conf_id: str, name: str = None, umo_parts: list[str] = None
self, conf_id: str, name: str | None = None, umo_parts: list[str] | None = None
) -> bool:
"""更新配置文件信息
@@ -227,7 +238,9 @@ class AstrBotConfigManager:
if conf_id == "default":
raise ValueError("不能更新默认配置文件的信息")
abconf_data = self.sp.get("abconf_mapping", {})
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
if conf_id not in abconf_data:
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
return False
@@ -249,11 +262,13 @@ class AstrBotConfigManager:
abconf_data[conf_id]["umop"] = umo_parts
# 保存更新
self.sp.put("abconf_mapping", abconf_data)
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
logger.info(f"成功更新配置文件 {conf_id} 的信息")
return True
def g(self, umo: str = None, key: str = None, default: _VT = None) -> _VT:
def g(
self, umo: str | None = None, key: str | None = None, default: _VT = None
) -> _VT:
"""获取配置项。umo 为 None 时使用默认配置"""
if umo is None:
return self.confs["default"].get(key, default)
+29 -38
View File
@@ -6,7 +6,6 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json
"""
import json
import asyncio
from astrbot.core import sp
from typing import Dict, List
from astrbot.core.db import BaseDatabase
@@ -17,25 +16,9 @@ class ConversationManager:
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
def __init__(self, db_helper: BaseDatabase):
# session_conversations 字典记录会话ID-对话ID 映射关系
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
self.session_conversations: Dict[str, str] = {}
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
self._start_periodic_save()
def _start_periodic_save(self):
"""启动定时保存任务"""
asyncio.create_task(self._periodic_save())
async def _periodic_save(self):
"""定时保存会话对话映射关系到存储中"""
while True:
await asyncio.sleep(self.save_interval)
self._save_to_storage()
def _save_to_storage(self):
"""保存会话对话映射关系到存储中"""
sp.put("session_conversation", self.session_conversations)
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
"""将 ConversationV2 对象转换为 Conversation 对象"""
@@ -55,10 +38,10 @@ class ConversationManager:
async def new_conversation(
self,
unified_msg_origin: str,
platform_id: str = None,
content: list[dict] = None,
title: str = None,
persona_id: str = None,
platform_id: str | None = None,
content: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
) -> str:
"""新建对话,并将当前会话的对话转移到新对话
@@ -82,8 +65,8 @@ class ConversationManager:
persona_id=persona_id,
)
self.session_conversations[unified_msg_origin] = conv.conversation_id
sp.put("session_conversation", self.session_conversations)
return str(conv.conversation_id)
await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id)
return conv.conversation_id
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
"""切换会话的对话
@@ -93,10 +76,10 @@ class ConversationManager:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id)
async def delete_conversation(
self, unified_msg_origin: str, conversation_id: str = None
self, unified_msg_origin: str, conversation_id: str | None = None
):
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
@@ -113,9 +96,9 @@ class ConversationManager:
await self.db.delete_conversation(cid=conversation_id)
if f:
self.session_conversations.pop(unified_msg_origin, None)
sp.put("session_conversation", self.session_conversations)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
"""获取会话当前的对话 ID
Args:
@@ -123,7 +106,12 @@ class ConversationManager:
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
return self.session_conversations.get(unified_msg_origin, None)
ret = self.session_conversations.get(unified_msg_origin, None)
if not ret:
ret = await sp.session_get(unified_msg_origin, "sel_conv_id", None)
if ret:
self.session_conversations[unified_msg_origin] = ret
return ret
async def get_conversation(
self,
@@ -150,7 +138,7 @@ class ConversationManager:
return conv_res
async def get_conversations(
self, unified_msg_origin: str = None, platform_id: str = None
self, unified_msg_origin: str | None = None, platform_id: str | None = None
) -> List[Conversation]:
"""获取对话列表
@@ -203,10 +191,10 @@ class ConversationManager:
async def update_conversation(
self,
unified_msg_origin: str,
conversation_id: str = None,
history: list[dict] = None,
title: str = None,
persona_id: str = None,
conversation_id: str | None = None,
history: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
):
"""更新会话的对话
@@ -216,8 +204,8 @@ class ConversationManager:
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
"""
if not conversation_id:
# 如果没有提供 conversation_id,则从 session_conversations 中获取当前的
conversation_id = self.session_conversations.get(unified_msg_origin)
# 如果没有提供 conversation_id,则获取当前的
conversation_id = await self.get_curr_conversation_id(unified_msg_origin)
if conversation_id:
await self.db.update_conversation(
cid=conversation_id,
@@ -227,7 +215,7 @@ class ConversationManager:
)
async def update_conversation_title(
self, unified_msg_origin: str, title: str, conversation_id: str = None
self, unified_msg_origin: str, title: str, conversation_id: str | None = None
):
"""更新会话的对话标题
@@ -245,7 +233,10 @@ class ConversationManager:
)
async def update_conversation_persona_id(
self, unified_msg_origin: str, persona_id: str, conversation_id: str = None
self,
unified_msg_origin: str,
persona_id: str,
conversation_id: str | None = None,
):
"""更新会话的对话 Persona ID
+40 -21
View File
@@ -74,7 +74,7 @@ class BaseDatabase(abc.ABC):
platform_id: str,
platform_type: str,
count: int = 1,
timestamp: datetime.datetime = None,
timestamp: datetime.datetime | None = None,
) -> None:
"""Insert a new platform statistic record."""
...
@@ -91,7 +91,7 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod
async def get_conversations(
self, user_id: str = None, platform_id: str = None
self, user_id: str | None = None, platform_id: str | None = None
) -> list[ConversationV2]:
"""Get all conversations for a specific user and platform_id(optional).
@@ -128,12 +128,12 @@ class BaseDatabase(abc.ABC):
self,
user_id: str,
platform_id: str,
content: list[dict] = None,
title: str = None,
persona_id: str = None,
cid: str = None,
created_at: datetime.datetime = None,
updated_at: datetime.datetime = None,
content: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
cid: str | None = None,
created_at: datetime.datetime | None = None,
updated_at: datetime.datetime | None = None,
) -> ConversationV2:
"""Create a new conversation."""
...
@@ -142,9 +142,9 @@ class BaseDatabase(abc.ABC):
async def update_conversation(
self,
cid: str,
title: str = None,
persona_id: str = None,
content: list[dict] = None,
title: str | None = None,
persona_id: str | None = None,
content: list[dict] | None = None,
) -> None:
"""Update a conversation's history."""
...
@@ -160,8 +160,8 @@ class BaseDatabase(abc.ABC):
platform_id: str,
user_id: str,
content: list[dict],
sender_id: str = None,
sender_name: str = None,
sender_id: str | None = None,
sender_name: str | None = None,
) -> None:
"""Insert a new platform message history record."""
...
@@ -204,8 +204,8 @@ class BaseDatabase(abc.ABC):
self,
persona_id: str,
system_prompt: str,
begin_dialogs: list[str] = None,
tools: list[str] = None,
begin_dialogs: list[str] | None = None,
tools: list[str] | None = None,
) -> Persona:
"""Insert a new persona record."""
...
@@ -224,9 +224,9 @@ class BaseDatabase(abc.ABC):
async def update_persona(
self,
persona_id: str,
system_prompt: str = None,
begin_dialogs: list[str] = None,
tools: list[str] = None,
system_prompt: str | None = None,
begin_dialogs: list[str] | None = None,
tools: list[str] | None = None,
) -> Persona | None:
"""Update a persona's system prompt or begin dialogs."""
...
@@ -237,13 +237,32 @@ class BaseDatabase(abc.ABC):
...
@abc.abstractmethod
async def insert_preference_or_update(self, key: str, value: str) -> Preference:
async def insert_preference_or_update(
self, scope: str, scope_id: str, key: str, value: dict
) -> Preference:
"""Insert a new preference record."""
...
@abc.abstractmethod
async def get_preference(self, key: str) -> Preference:
"""Get a preference by bot ID and key."""
async def get_preference(self, scope: str, scope_id: str, key: str) -> Preference:
"""Get a preference by scope ID and key."""
...
@abc.abstractmethod
async def get_preferences(
self, scope: str, scope_id: str | None = None, key: str | None = None
) -> list[Preference]:
"""Get all preferences for a specific scope ID or key."""
...
@abc.abstractmethod
async def remove_preference(self, scope: str, scope_id: str, key: str) -> None:
"""Remove a preference by scope ID and key."""
...
@abc.abstractmethod
async def clear_preferences(self, scope: str, scope_id: str) -> None:
"""Clear all preferences for a specific scope ID."""
...
# @abc.abstractmethod
+9 -3
View File
@@ -2,12 +2,13 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.db import BaseDatabase
from astrbot.core.config import AstrBotConfig
from astrbot.api import logger
from astrbot.api import logger, sp
from .migra_3_to_4 import (
migration_conversation_table,
migration_platform_table,
migration_webchat_data,
migration_persona_data,
migration_preferences,
)
@@ -19,7 +20,9 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
data_v3_exists = os.path.exists(get_astrbot_data_path())
if not data_v3_exists:
return False
migration_done = await db_helper.get_preference("migration_done_v4")
migration_done = await db_helper.get_preference(
"global", "global", "migration_done_v4"
)
if migration_done:
return False
return True
@@ -49,10 +52,13 @@ async def do_migration_v4(
# 执行 WebChat 数据迁移
await migration_webchat_data(db_helper, platform_id_map)
# 执行偏好设置迁移
await migration_preferences(db_helper,platform_id_map)
# 执行平台统计表迁移
await migration_platform_table(db_helper, platform_id_map)
# 标记迁移完成
await db_helper.insert_preference_or_update("migration_done_v4", "true")
await sp.put_async("global", "global", "migration_done_v4", True)
logger.info("数据库迁移完成。")
+83 -5
View File
@@ -2,8 +2,9 @@ import json
import datetime
from .. import BaseDatabase
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
from .shared_preferences_v3 import sp as sp_v3
from astrbot.core.config.default import DB_PATH
from astrbot.api import logger
from astrbot.api import logger, sp
from astrbot.core.config import AstrBotConfig
from astrbot.core.platform.astr_message_event import MessageSesion
from sqlalchemy.ext.asyncio import AsyncSession
@@ -89,7 +90,7 @@ async def migration_conversation_table(
async def migration_platform_table(
db_helper: BaseDatabase, platform_id_map: dict[str, str]
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
@@ -122,9 +123,7 @@ async def migration_platform_table(
for bucket_idx, bucket_end in enumerate(range(start_time, end_time, 3600)):
if bucket_idx % 500 == 0:
progress = int((bucket_idx + 1) / total_buckets * 100)
logger.info(
f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})"
)
logger.info(f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})")
cnt = 0
while (
idx < len(platform_stats_v3)
@@ -258,3 +257,82 @@ async def migration_persona_data(
)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
async def migration_preferences(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
# 1. global scope migration
keys = [
"inactivated_llm_tools",
"inactivated_plugins",
"curr_provider",
"curr_provider_tts",
"curr_provider_stt",
"alter_cmd",
]
for key in keys:
value = sp_v3.get(key)
if value is not None:
await sp.put_async("global", "global", key, value)
logger.info(f"迁移全局偏好设置 {key} 成功,值: {value}")
# 2. umo scope migration
session_conversation = sp_v3.get("session_conversation", default={})
for umo, conversation_id in session_conversation.items():
if not umo or not conversation_id:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
await sp.put_async("umo", str(session), "sel_conv_id", conversation_id)
logger.info(f"迁移会话 {umo} 的对话数据到新表成功,平台 ID: {platform_id}")
except Exception as e:
logger.error(f"迁移会话 {umo} 的对话数据失败: {e}", exc_info=True)
session_service_config = sp_v3.get("session_service_config", default={})
for umo, config in session_service_config.items():
if not umo or not config:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
await sp.put_async("umo", str(session), "session_service_config", config)
logger.info(f"迁移会话 {umo} 的服务配置到新表成功,平台 ID: {platform_id}")
except Exception as e:
logger.error(f"迁移会话 {umo} 的服务配置失败: {e}", exc_info=True)
session_variables = sp_v3.get("session_variables", default={})
for umo, variables in session_variables.items():
if not umo or not variables:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
await sp.put_async("umo", str(session), "session_variables", variables)
except Exception as e:
logger.error(f"迁移会话 {umo} 的变量失败: {e}", exc_info=True)
session_provider_perf = sp_v3.get("session_provider_perf", default={})
for umo, perf in session_provider_perf.items():
if not umo or not perf:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
for provider_type, provider_id in perf.items():
await sp.put_async(
"umo", str(session), f"provider_perf_{provider_type}", provider_id
)
logger.info(
f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}"
)
except Exception as e:
logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True)
@@ -0,0 +1,45 @@
import json
import os
from typing import TypeVar
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
_VT = TypeVar("_VT")
class SharedPreferences:
def __init__(self, path=None):
if path is None:
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
self.path = path
self._data = self._load_preferences()
def _load_preferences(self):
if os.path.exists(self.path):
try:
with open(self.path, "r") as f:
return json.load(f)
except json.JSONDecodeError:
os.remove(self.path)
return {}
def _save_preferences(self):
with open(self.path, "w") as f:
json.dump(self._data, f, indent=4, ensure_ascii=False)
f.flush()
def get(self, key, default: _VT = None) -> _VT:
return self._data.get(key, default)
def put(self, key, value):
self._data[key] = value
self._save_preferences()
def remove(self, key):
if key in self._data:
del self._data[key]
self._save_preferences()
def clear(self):
self._data.clear()
self._save_preferences()
sp = SharedPreferences()
+21 -5
View File
@@ -97,18 +97,34 @@ class Persona(SQLModel, table=True):
class Preference(SQLModel, table=True):
"""This class represents user preferences for bots."""
"""This class represents preferences for bots."""
__tablename__ = "preferences"
key: str = Field(primary_key=True, nullable=False)
value: str = Field(sa_type=Text, nullable=False)
id: int | None = Field(
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
)
scope: str = Field(nullable=False)
"""Scope of the preference, such as 'global', 'umo', 'plugin'."""
scope_id: str = Field(nullable=False)
"""ID of the scope, such as 'global', 'umo', 'plugin_name'."""
key: str = Field(nullable=False)
value: dict = Field(sa_type=JSON, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint(
"scope",
"scope_id",
"key",
name="uix_preference_scope_scope_id_key",
),
)
class PlatformMessageHistory(SQLModel, table=True):
"""This class represents the message history for a specific platform.
@@ -184,8 +200,8 @@ class Conversation:
"""对话 ID, 是 uuid 格式的字符串"""
history: str = ""
"""字符串格式的对话列表。"""
title: str = ""
persona_id: str = ""
title: str | None = ""
persona_id: str | None = ""
created_at: int = 0
updated_at: int = 0
+54 -5
View File
@@ -21,6 +21,7 @@ from sqlalchemy.sql import func
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
class SQLiteDatabase(BaseDatabase):
def __init__(self, db_path: str) -> None:
self.db_path = db_path
@@ -373,29 +374,77 @@ class SQLiteDatabase(BaseDatabase):
delete(Persona).where(Persona.persona_id == persona_id)
)
async def insert_preference_or_update(self, key, value):
async def insert_preference_or_update(self, scope, scope_id, key, value):
"""Insert a new preference record or update if it exists."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = select(Preference).where(Preference.key == key)
query = select(Preference).where(
Preference.scope == scope,
Preference.scope_id == scope_id,
Preference.key == key,
)
result = await session.execute(query)
existing_preference = result.scalar_one_or_none()
if existing_preference:
existing_preference.value = value
else:
new_preference = Preference(key=key, value=value)
new_preference = Preference(
scope=scope, scope_id=scope_id, key=key, value=value
)
session.add(new_preference)
return existing_preference or new_preference
async def get_preference(self, key):
async def get_preference(self, scope, scope_id, key):
"""Get a preference by key."""
async with self.get_db() as session:
session: AsyncSession
query = select(Preference).where(Preference.key == key)
query = select(Preference).where(
Preference.scope == scope,
Preference.scope_id == scope_id,
Preference.key == key,
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_preferences(self, scope, scope_id=None, key=None):
"""Get all preferences for a specific scope ID or key."""
async with self.get_db() as session:
session: AsyncSession
query = select(Preference).where(Preference.scope == scope)
if scope_id is not None:
query = query.where(Preference.scope_id == scope_id)
if key is not None:
query = query.where(Preference.key == key)
result = await session.execute(query)
return result.scalars().all()
async def remove_preference(self, scope, scope_id, key):
"""Remove a preference by scope ID and key."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(Preference).where(
Preference.scope == scope,
Preference.scope_id == scope_id,
Preference.key == key,
)
)
await session.commit()
async def clear_preferences(self, scope, scope_id):
"""Clear all preferences for a specific scope ID."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(Preference).where(
Preference.scope == scope, Preference.scope_id == scope_id
)
)
await session.commit()
# ====
# Deprecated Methods
# ====
@@ -77,7 +77,7 @@ class WebChatAdapter(Platform):
os.makedirs(self.imgs_dir, exist_ok=True)
self.metadata = PlatformMetadata(
name="webchat", description="webchat", id=self.config.get("id", "")
name="webchat", description="webchat", id="webchat"
)
async def send_by_session(
+18 -4
View File
@@ -425,10 +425,17 @@ class FunctionToolManager:
if func_tool is not None:
func_tool.active = False
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
inactivated_llm_tools: list = sp.get(
"inactivated_llm_tools", [], scope="global", scope_id="global"
)
if name not in inactivated_llm_tools:
inactivated_llm_tools.append(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
sp.put(
"inactivated_llm_tools",
inactivated_llm_tools,
scope="global",
scope_id="global",
)
return True
return False
@@ -445,10 +452,17 @@ class FunctionToolManager:
func_tool.active = True
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
inactivated_llm_tools: list = sp.get(
"inactivated_llm_tools", [], scope="global", scope_id="global"
)
if name in inactivated_llm_tools:
inactivated_llm_tools.remove(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
sp.put(
"inactivated_llm_tools",
inactivated_llm_tools,
scope="global",
scope_id="global",
)
return True
return False
+27 -15
View File
@@ -66,7 +66,7 @@ class ProviderManager:
return self.persona_mgr.selected_default_persona_v3
async def set_provider(
self, provider_id: str, provider_type: ProviderType, umo: str = None
self, provider_id: str, provider_type: ProviderType, umo: str | None = None
):
"""设置提供商。
@@ -80,20 +80,20 @@ class ProviderManager:
if provider_id not in self.inst_map:
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
if umo:
perf = sp.get("session_provider_perf", {})
session_perf = perf.get(umo, {})
session_perf[provider_type.value] = provider_id
perf[umo] = session_perf
sp.put("session_provider_perf", perf)
await sp.session_put(
umo,
f"provider_perf_{provider_type.value}",
provider_id,
)
return
# 不启用提供商会话隔离模式的情况
self.curr_provider_inst = self.inst_map[provider_id]
if provider_type == ProviderType.TEXT_TO_SPEECH:
sp.put("curr_provider_tts", provider_id)
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
elif provider_type == ProviderType.SPEECH_TO_TEXT:
sp.put("curr_provider_stt", provider_id)
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
elif provider_type == ProviderType.CHAT_COMPLETION:
sp.put("curr_provider", provider_id)
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
"""根据提供商 ID 获取提供商实例"""
@@ -111,9 +111,12 @@ class ProviderManager:
"""
provider = None
if umo:
perf = sp.get("session_provider_perf", {})
session_perf = perf.get(umo, {})
provider_id = session_perf.get(provider_type.value)
provider_id = sp.get(
f"provider_perf_{provider_type.value}",
None,
scope="umo",
scope_id=umo,
)
if provider_id:
provider = self.inst_map.get(provider_id)
if not provider:
@@ -153,13 +156,22 @@ class ProviderManager:
# 设置默认提供商
selected_provider_id = sp.get(
"curr_provider", self.provider_settings.get("default_provider_id")
"curr_provider",
self.provider_settings.get("default_provider_id"),
scope="global",
scope_id="global",
)
selected_stt_provider_id = sp.get(
"curr_provider_stt", self.provider_stt_settings.get("provider_id")
"curr_provider_stt",
self.provider_stt_settings.get("provider_id"),
scope="global",
scope_id="global",
)
selected_tts_provider_id = sp.get(
"curr_provider_tts", self.provider_tts_settings.get("provider_id")
"curr_provider_tts",
self.provider_tts_settings.get("provider_id"),
scope="global",
scope_id="global",
)
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
if not self.curr_provider_inst and self.provider_insts:
@@ -75,8 +75,7 @@ class ProviderDashscope(ProviderOpenAIOfficial):
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
session_var = await sp.session_get(session_id, "session_variables", default={})
payload_vars.update(session_var)
if (
+1 -2
View File
@@ -97,8 +97,7 @@ class ProviderDify(Provider):
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
session_var = await sp.session_get(session_id, "session_variables", default={})
payload_vars.update(session_var)
payload_vars["system_prompt"] = system_prompt
+42 -85
View File
@@ -2,8 +2,6 @@
会话服务管理器 - 负责管理每个会话的LLMTTS等服务的启停状态
"""
from typing import Dict
from astrbot.core import logger, sp
from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -26,8 +24,9 @@ class SessionServiceManager:
bool: True表示启用False表示禁用
"""
# 获取会话服务配置
session_config = sp.get("session_service_config", {}) or {}
session_services = session_config.get(session_id, {})
session_services = sp.get(
"session_service_config", {}, scope="umo", scope_id=session_id
)
# 如果配置了该会话的LLM状态,返回该状态
llm_enabled = session_services.get("llm_enabled")
@@ -45,16 +44,13 @@ class SessionServiceManager:
session_id: 会话ID (unified_msg_origin)
enabled: True表示启用False表示禁用
"""
# 获取当前配置
session_config = sp.get("session_service_config", {}) or {}
if session_id not in session_config:
session_config[session_id] = {}
# 设置LLM状态
session_config[session_id]["llm_enabled"] = enabled
# 保存配置
sp.put("session_service_config", session_config)
session_config = (
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
)
session_config["llm_enabled"] = enabled
sp.put(
"session_service_config", session_config, scope="umo", scope_id=session_id
)
logger.info(
f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
@@ -88,8 +84,9 @@ class SessionServiceManager:
bool: True表示启用False表示禁用
"""
# 获取会话服务配置
session_config = sp.get("session_service_config", {}) or {}
session_services = session_config.get(session_id, {})
session_services = sp.get(
"session_service_config", {}, scope="umo", scope_id=session_id
)
# 如果配置了该会话的TTS状态,返回该状态
tts_enabled = session_services.get("tts_enabled")
@@ -107,16 +104,13 @@ class SessionServiceManager:
session_id: 会话ID (unified_msg_origin)
enabled: True表示启用False表示禁用
"""
# 获取当前配置
session_config = sp.get("session_service_config", {}) or {}
if session_id not in session_config:
session_config[session_id] = {}
# 设置TTS状态
session_config[session_id]["tts_enabled"] = enabled
# 保存配置
sp.put("session_service_config", session_config)
session_config = (
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
)
session_config["tts_enabled"] = enabled
sp.put(
"session_service_config", session_config, scope="umo", scope_id=session_id
)
logger.info(
f"会话 {session_id} 的TTS状态已更新为: {'启用' if enabled else '禁用'}"
@@ -150,8 +144,9 @@ class SessionServiceManager:
bool: True表示启用False表示禁用
"""
# 获取会话服务配置
session_config = sp.get("session_service_config", {}) or {}
session_services = session_config.get(session_id, {})
session_services = sp.get(
"session_service_config", {}, scope="umo", scope_id=session_id
)
# 如果配置了该会话的整体状态,返回该状态
session_enabled = session_services.get("session_enabled")
@@ -169,16 +164,13 @@ class SessionServiceManager:
session_id: 会话ID (unified_msg_origin)
enabled: True表示启用False表示禁用
"""
# 获取当前配置
session_config = sp.get("session_service_config", {}) or {}
if session_id not in session_config:
session_config[session_id] = {}
# 设置会话整体状态
session_config[session_id]["session_enabled"] = enabled
# 保存配置
sp.put("session_service_config", session_config)
session_config = (
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
)
session_config["session_enabled"] = enabled
sp.put(
"session_service_config", session_config, scope="umo", scope_id=session_id
)
logger.info(
f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}"
@@ -202,7 +194,7 @@ class SessionServiceManager:
# =============================================================================
@staticmethod
def get_session_custom_name(session_id: str) -> str:
def get_session_custom_name(session_id: str) -> str | None:
"""获取会话的自定义名称
Args:
@@ -211,8 +203,9 @@ class SessionServiceManager:
Returns:
str: 自定义名称如果没有设置则返回None
"""
session_config = sp.get("session_service_config", {}) or {}
session_services = session_config.get(session_id, {})
session_services = sp.get(
"session_service_config", {}, scope="umo", scope_id=session_id
)
return session_services.get("custom_name")
@staticmethod
@@ -223,20 +216,17 @@ class SessionServiceManager:
session_id: 会话ID (unified_msg_origin)
custom_name: 自定义名称可以为空字符串来清除名称
"""
# 获取当前配置
session_config = sp.get("session_service_config", {}) or {}
if session_id not in session_config:
session_config[session_id] = {}
# 设置自定义名称
session_config = (
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
)
if custom_name and custom_name.strip():
session_config[session_id]["custom_name"] = custom_name.strip()
session_config["custom_name"] = custom_name.strip()
else:
# 如果传入空名称,则删除自定义名称
session_config[session_id].pop("custom_name", None)
# 保存配置
sp.put("session_service_config", session_config)
session_config.pop("custom_name", None)
sp.put(
"session_service_config", session_config, scope="umo", scope_id=session_id
)
logger.info(
f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}"
@@ -258,36 +248,3 @@ class SessionServiceManager:
# 如果没有自定义名称,返回session_id的最后一段
return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id
# =============================================================================
# 通用配置方法
# =============================================================================
@staticmethod
def get_session_service_config(session_id: str) -> Dict[str, bool]:
"""获取指定会话的服务配置
Args:
session_id: 会话ID (unified_msg_origin)
Returns:
Dict[str, bool]: 包含session_enabledllm_enabledtts_enabled的字典
"""
session_config = sp.get("session_service_config", {}) or {}
return session_config.get(
session_id,
{
"session_enabled": True, # 默认启用
"llm_enabled": True, # 默认启用
"tts_enabled": True, # 默认启用
},
)
@staticmethod
def get_all_session_configs() -> Dict[str, Dict[str, bool]]:
"""获取所有会话的服务配置
Returns:
Dict[str, Dict[str, bool]]: 所有会话的服务配置
"""
return sp.get("session_service_config", {}) or {}
+12 -4
View File
@@ -22,7 +22,9 @@ class SessionPluginManager:
bool: True表示启用False表示禁用
"""
# 获取会话插件配置
session_plugin_config = sp.get("session_plugin_config", {}) or {}
session_plugin_config = sp.get(
"session_plugin_config", {}, scope="umo", scope_id=session_id
)
session_config = session_plugin_config.get(session_id, {})
enabled_plugins = session_config.get("enabled_plugins", [])
@@ -51,7 +53,9 @@ class SessionPluginManager:
enabled: True表示启用False表示禁用
"""
# 获取当前配置
session_plugin_config = sp.get("session_plugin_config", {}) or {}
session_plugin_config = sp.get(
"session_plugin_config", {}, scope="umo", scope_id=session_id
)
if session_id not in session_plugin_config:
session_plugin_config[session_id] = {
"enabled_plugins": [],
@@ -79,7 +83,9 @@ class SessionPluginManager:
session_config["enabled_plugins"] = enabled_plugins
session_config["disabled_plugins"] = disabled_plugins
session_plugin_config[session_id] = session_config
sp.put("session_plugin_config", session_plugin_config)
sp.put(
"session_plugin_config", session_plugin_config, scope="umo", scope_id=session_id
)
logger.info(
f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}"
@@ -95,7 +101,9 @@ class SessionPluginManager:
Returns:
Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典
"""
session_plugin_config = sp.get("session_plugin_config", {}) or {}
session_plugin_config = sp.get(
"session_plugin_config", {}, scope="umo", scope_id=session_id
)
return session_plugin_config.get(
session_id, {"enabled_plugins": [], "disabled_plugins": []}
)
+11 -12
View File
@@ -374,10 +374,9 @@ class PluginManager:
- success (bool): 是否全部加载成功
- error_message (str|None): 错误信息成功时为 None
"""
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
alter_cmd = sp.get("alter_cmd", {})
inactivated_plugins = await sp.global_get("inactivated_plugins", [])
inactivated_llm_tools = await sp.global_get("inactivated_llm_tools", [])
alter_cmd = await sp.global_get("alter_cmd", {})
plugin_modules = self._get_plugin_modules()
if plugin_modules is None:
@@ -787,12 +786,12 @@ class PluginManager:
await self._terminate_plugin(plugin)
# 加入到 shared_preferences 中
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
if plugin.module_path not in inactivated_plugins:
inactivated_plugins.append(plugin.module_path)
inactivated_llm_tools: list = list(
set(sp.get("inactivated_llm_tools", []))
set(await sp.global_get("inactivated_llm_tools", []))
) # 后向兼容
# 禁用插件启用的 llm_tool
@@ -802,8 +801,8 @@ class PluginManager:
if func_tool.name not in inactivated_llm_tools:
inactivated_llm_tools.append(func_tool.name)
sp.put("inactivated_plugins", inactivated_plugins)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
await sp.global_put("inactivated_plugins", inactivated_plugins)
await sp.global_put("inactivated_llm_tools", inactivated_llm_tools)
plugin.activated = False
@@ -829,11 +828,11 @@ class PluginManager:
async def turn_on_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", [])
if plugin.module_path in inactivated_plugins:
inactivated_plugins.remove(plugin.module_path)
sp.put("inactivated_plugins", inactivated_plugins)
await sp.global_put("inactivated_plugins", inactivated_plugins)
# 启用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
@@ -843,7 +842,7 @@ class PluginManager:
):
inactivated_llm_tools.remove(func_tool.name)
func_tool.active = True
sp.put("inactivated_llm_tools", inactivated_llm_tools)
await sp.global_put("inactivated_llm_tools", inactivated_llm_tools)
await self.reload(plugin_name)
+202 -31
View File
@@ -1,43 +1,214 @@
import json
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Preference
import threading
import asyncio
import os
from typing import TypeVar
from typing import TypeVar, Any, overload
from .astrbot_path import get_astrbot_data_path
_VT = TypeVar("_VT")
class SharedPreferences:
def __init__(self, path=None):
if path is None:
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
self.path = path
self._data = self._load_preferences()
def __init__(self, db_helper: BaseDatabase, json_storage_path=None):
if json_storage_path is None:
json_storage_path = os.path.join(
get_astrbot_data_path(), "shared_preferences.json"
)
self.path = json_storage_path
self.db_helper = db_helper
def _load_preferences(self):
if os.path.exists(self.path):
try:
with open(self.path, "r") as f:
return json.load(f)
except json.JSONDecodeError:
os.remove(self.path)
return {}
async def get_async(
self,
scope: str,
scope_id: str,
key: str,
default: _VT = None,
) -> _VT:
"""获取指定范围和键的偏好设置"""
if scope_id is not None and key is not None:
result = await self.db_helper.get_preference(scope, scope_id, key)
if result:
ret = result.value["val"]
else:
ret = default
return ret
else:
raise ValueError(
"scope_id and key cannot be None when getting a specific preference."
)
def _save_preferences(self):
with open(self.path, "w") as f:
json.dump(self._data, f, indent=4, ensure_ascii=False)
f.flush()
async def range_get_async(
self, scope: str, scope_id: str | None = None, key: str | None = None
) -> list[Preference]:
"""获取指定范围的偏好设置
Note: 返回 Preference 列表其中的 value 属性是一个 dictvalue["val"] 为值scope_id key 可以为 None这时返回该范围下所有的偏好设置
"""
ret = await self.db_helper.get_preferences(scope, scope_id, key)
return ret
def get(self, key, default: _VT = None) -> _VT:
return self._data.get(key, default)
@overload
async def session_get(
self, umo: None, key: str, default: Any = None
) -> list[Preference]: ...
def put(self, key, value):
self._data[key] = value
self._save_preferences()
@overload
async def session_get(
self, umo: str, key: None, default: Any = None
) -> list[Preference]: ...
def remove(self, key):
if key in self._data:
del self._data[key]
self._save_preferences()
@overload
async def session_get(
self, umo: None, key: None, default: Any = None
) -> list[Preference]: ...
def clear(self):
self._data.clear()
self._save_preferences()
async def session_get(
self, umo: str | None, key: str | None = None, default: _VT = None
) -> _VT | list[Preference]:
"""获取会话范围的偏好设置
Note: scope_id 或者 key None返回 Preference 列表其中的 value 属性是一个 dictvalue["val"] 为值
"""
if umo is None or key is None:
return await self.range_get_async("umo", umo, key)
return await self.get_async("umo", umo, key, default)
@overload
async def global_get(
self, key: None, default: Any = None
) -> list[Preference]: ...
@overload
async def global_get(
self, key: str, default: _VT = None
) -> _VT: ...
async def global_get(
self, key: str | None, default: _VT = None
) -> _VT | list[Preference]:
"""获取全局范围的偏好设置
Note: scope_id 或者 key None返回 Preference 列表其中的 value 属性是一个 dictvalue["val"] 为值
"""
if key is None:
return await self.range_get_async("global", "global", key)
return await self.get_async("global", "global", key, default)
async def put_async(self, scope: str, scope_id: str, key: str, value: Any):
"""设置指定范围和键的偏好设置"""
await self.db_helper.insert_preference_or_update(
scope, scope_id, key, {"val": value}
)
async def session_put(self, umo: str, key: str, value: Any):
await self.put_async("umo", umo, key, value)
async def global_put(self, key: str, value: Any):
await self.put_async("global", "global", key, value)
async def remove_async(self, scope: str, scope_id: str, key: str):
"""删除指定范围和键的偏好设置"""
await self.db_helper.remove_preference(scope, scope_id, key)
async def session_remove(self, umo: str, key: str):
await self.remove_async("umo", umo, key)
async def global_remove(self, key: str):
"""删除全局偏好设置"""
await self.remove_async("global", "global", key)
async def clear_async(self, scope: str, scope_id: str):
"""清空指定范围的所有偏好设置"""
await self.db_helper.clear_preferences(scope, scope_id)
# ====
# DEPRECATED METHODS
# ====
def get(
self,
key: str,
default: _VT = None,
scope: str | None = None,
scope_id: str | None = "",
) -> _VT:
"""获取偏好设置(已弃用)"""
result: _VT | None = None
def runner():
nonlocal result, scope, scope_id
scope = scope or "unknown"
if scope_id == "":
scope_id = "unknown"
if scope_id is None or key is None:
# result = asyncio.run(self.range_get_async(scope, scope_id, key))
raise ValueError(
"scope_id and key cannot be None when getting a specific preference."
)
else:
result = asyncio.run(self.get_async(scope, scope_id, key, default))
t = threading.Thread(target=runner)
t.start()
t.join()
if result is None:
return default
return result
def range_get(
self, scope: str, scope_id: str | None = None, key: str | None = None
) -> list[Preference]:
"""获取指定范围的偏好设置(已弃用)"""
result: list[Preference] = []
def runner():
nonlocal result, scope, scope_id, key
result = asyncio.run(self.range_get_async(scope, scope_id, key))
t = threading.Thread(target=runner)
t.start()
t.join()
return result
def put(self, key, value, scope: str | None = None, scope_id: str | None = None):
"""设置偏好设置(已弃用)"""
def runner():
nonlocal scope, scope_id
scope = scope or "unknown"
scope_id = scope_id or "unknown"
asyncio.run(self.put_async(scope, scope_id, key, value))
t = threading.Thread(target=runner)
t.start()
t.join()
def remove(self, key, scope: str | None = None, scope_id: str | None = None):
"""删除偏好设置(已弃用)"""
def runner():
nonlocal scope, scope_id
scope = scope or "unknown"
scope_id = scope_id or "unknown"
asyncio.run(self.remove_async(scope, scope_id, key))
t = threading.Thread(target=runner)
t.start()
t.join()
def clear(self, scope: str | None = None, scope_id: str | None = None):
"""清空偏好设置(已弃用)"""
def runner():
nonlocal scope, scope_id
scope = scope or "unknown"
scope_id = scope_id or "unknown"
asyncio.run(self.clear_async(scope, scope_id))
t = threading.Thread(target=runner)
t.start()
t.join()
+9 -11
View File
@@ -38,7 +38,13 @@ class SessionManagementRoute(Route):
async def list_sessions(self):
"""获取所有会话的列表,包括 persona 和 provider 信息"""
try:
session_conversations = sp.get("session_conversation", {}) or {}
preferences = await sp.session_get(umo=None, key="sel_conv_id", default=[])
session_conversations = {}
for pref in preferences:
session_conversations[pref.scope_id] = pref.value["val"]
logger.debug(session_conversations)
provider_manager = self.core_lifecycle.provider_manager
persona_mgr = self.core_lifecycle.persona_mgr
personas = persona_mgr.personas_v3
@@ -51,13 +57,9 @@ class SessionManagementRoute(Route):
"session_id": session_id,
"conversation_id": conversation_id,
"persona_id": None,
"persona_name": None,
"chat_provider_id": None,
"chat_provider_name": None,
"stt_provider_id": None,
"stt_provider_name": None,
"tts_provider_id": None,
"tts_provider_name": None,
"session_enabled": SessionServiceManager.is_session_enabled(
session_id
),
@@ -92,16 +94,15 @@ class SessionManagementRoute(Route):
if conversation.persona_id and conversation.persona_id != "[%None]":
for persona in personas:
if persona["name"] == conversation.persona_id:
session_info["persona_name"] = persona["name"]
session_info["persona_id"] = persona["name"]
break
elif conversation.persona_id == "[%None]":
session_info["persona_name"] = "无人格"
session_info["persona_id"] = "无人格"
else:
# 使用默认人格
default_persona = persona_mgr.selected_default_persona_v3
if default_persona:
session_info["persona_id"] = default_persona["name"]
session_info["persona_name"] = default_persona["name"]
# 获取 provider 信息
provider_manager = self.core_lifecycle.provider_manager
@@ -117,15 +118,12 @@ class SessionManagementRoute(Route):
if chat_provider:
meta = chat_provider.meta()
session_info["chat_provider_id"] = meta.id
session_info["chat_provider_name"] = meta.id
if tts_provider:
meta = tts_provider.meta()
session_info["tts_provider_id"] = meta.id
session_info["tts_provider_name"] = meta.id
if stt_provider:
meta = stt_provider.meta()
session_info["stt_provider_id"] = meta.id
session_info["stt_provider_name"] = meta.id
sessions.append(session_info)
+12 -19
View File
@@ -503,7 +503,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
scene = RstScene.get_scene(is_group, is_unique_session)
alter_cmd_cfg = sp.get("alter_cmd", {})
alter_cmd_cfg = await sp.get_async("global", "global", "alter_cmd", {})
plugin_config = alter_cmd_cfg.get("astrbot", {})
reset_cfg = plugin_config.get("reset", {})
@@ -1101,29 +1101,22 @@ UID: {user_id} 此 ID 可用于设置管理员。
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
# session_id = event.get_session_id()
uid = event.unified_msg_origin
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(uid, {})
session_var = await sp.session_get(uid, "session_variables", {})
session_var[key] = value
session_vars[uid] = session_var
sp.put("session_variables", session_vars)
await sp.session_put(uid, "session_variables", session_var)
yield event.plain_result(f"会话 {uid} 变量 {key} 存储成功。使用 /unset 移除。")
@filter.command("unset")
async def unset_variable(self, event: AstrMessageEvent, key: str):
uid = event.unified_msg_origin
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(uid, {})
session_var = await sp.session_get(umo="uid", key="session_variables", default={})
if key not in session_var:
yield event.plain_result("没有那个变量名。格式 /unset 变量名。")
else:
del session_var[key]
sp.put("session_variables", session_vars)
await sp.session_put(uid, "session_variables", session_var)
yield event.plain_result(f"会话 {uid} 变量 {key} 移除成功。")
@filter.platform_adapter_type(filter.PlatformAdapterType.ALL)
@@ -1356,7 +1349,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
# 对reset权限进行特殊处理
# ============================
if cmd_name == "reset" and cmd_type == "config":
alter_cmd_cfg = sp.get("alter_cmd", {})
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
plugin_ = alter_cmd_cfg.get("astrbot", {})
reset_cfg = plugin_.get("reset", {})
@@ -1391,7 +1384,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
scene = RstScene.from_index(scene_num)
scene_key = scene.key
self.update_reset_permission(scene_key, perm_type)
await self.update_reset_permission(scene_key, perm_type)
yield event.plain_result(
f"已将 reset 命令在{scene.name}场景下的权限设为{perm_type}"
@@ -1422,14 +1415,14 @@ UID: {user_id} 此 ID 可用于设置管理员。
found_plugin = star_map[found_command.handler_module_path]
alter_cmd_cfg = sp.get("alter_cmd", {})
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
plugin_ = alter_cmd_cfg.get(found_plugin.name, {})
cfg = plugin_.get(found_command.handler_name, {})
cfg["permission"] = cmd_type
plugin_[found_command.handler_name] = cfg
alter_cmd_cfg[found_plugin.name] = plugin_
sp.put("alter_cmd", alter_cmd_cfg)
await sp.global_put("alter_cmd", alter_cmd_cfg)
# 注入权限过滤器
found_permission_filter = False
@@ -1453,17 +1446,17 @@ UID: {user_id} 此 ID 可用于设置管理员。
yield event.plain_result(f"已将 {cmd_name} 设置为 {cmd_type} 指令")
def update_reset_permission(self, scene_key: str, perm_type: str):
async def update_reset_permission(self, scene_key: str, perm_type: str):
"""更新reset命令在特定场景下的权限设置
Args:
scene_key (str): 场景编号1-3
perm_type (str): 权限类型admin或member
"""
alter_cmd_cfg = sp.get("alter_cmd", {})
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
plugin_cfg = alter_cmd_cfg.get("astrbot", {})
reset_cfg = plugin_cfg.get("reset", {})
reset_cfg[scene_key] = perm_type
plugin_cfg["reset"] = reset_cfg
alter_cmd_cfg["astrbot"] = plugin_cfg
sp.put("alter_cmd", alter_cmd_cfg)
await sp.global_put("alter_cmd", alter_cmd_cfg)