fix: database is locked error when invoking tts command (#4313)

* fix: database is locked error when invoking /tts command

fixes: #4311

* chore: rm pnpm lockfile

* perf: 减少操作数据库的次数
This commit is contained in:
Soulter
2026-01-03 19:12:39 +08:00
committed by GitHub
parent 442b5403df
commit 454841de10
10 changed files with 101 additions and 60 deletions
@@ -14,13 +14,13 @@ class TTSCommand:
async def tts(self, event: AstrMessageEvent):
"""开关文本转语音(会话级别)"""
umo = event.unified_msg_origin
ses_tts = SessionServiceManager.is_tts_enabled_for_session(umo)
ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo)
cfg = self.context.get_config(umo=umo)
tts_enable = cfg["provider_tts_settings"]["enable"]
# 切换状态
new_status = not ses_tts
SessionServiceManager.set_tts_status_for_session(umo, new_status)
await SessionServiceManager.set_tts_status_for_session(umo, new_status)
status_text = "已开启" if new_status else "已关闭"
+1
View File
@@ -90,6 +90,7 @@ class AstrBotCoreLifecycle:
# 初始化 UMOP 配置路由器
self.umop_config_router = UmopConfigRouter(sp=sp)
await self.umop_config_router.initialize()
# 初始化 AstrBot 配置管理器
self.astrbot_config_mgr = AstrBotConfigManager(
@@ -38,7 +38,7 @@ class AgentRequestSubStage(Stage):
)
return
if not SessionServiceManager.should_process_llm_request(event):
if not await SessionServiceManager.should_process_llm_request(event):
logger.debug(
f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing."
)
@@ -260,7 +260,7 @@ class ResultDecorateStage(Stage):
should_tts = (
bool(self.ctx.astrbot_config["provider_tts_settings"]["enable"])
and result.is_llm_result()
and SessionServiceManager.should_process_tts_request(event)
and await SessionServiceManager.should_process_tts_request(event)
and random.random() <= self.tts_trigger_probability
and tts_provider
)
@@ -21,7 +21,7 @@ class SessionStatusCheckStage(Stage):
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
# 检查会话是否整体启用
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
if not await SessionServiceManager.is_session_enabled(event.unified_msg_origin):
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
# workaround for #2309
+1 -1
View File
@@ -227,7 +227,7 @@ class WakingCheckStage(Stage):
event._extras.pop("parsed_params", None)
# 根据会话配置过滤插件处理器
activated_handlers = SessionPluginManager.filter_handlers_by_session(
activated_handlers = await SessionPluginManager.filter_handlers_by_session(
event,
activated_handlers,
)
+27 -12
View File
@@ -119,19 +119,34 @@ class ProviderManager:
TTSProvider,
):
self.curr_tts_provider_inst = prov
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
await sp.put_async(
key="curr_provider_tts",
value=provider_id,
scope="global",
scope_id="global",
)
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
prov,
STTProvider,
):
self.curr_stt_provider_inst = prov
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
await sp.put_async(
key="curr_provider_stt",
value=provider_id,
scope="global",
scope_id="global",
)
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
prov,
Provider,
):
self.curr_provider_inst = prov
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
await sp.put_async(
key="curr_provider",
value=provider_id,
scope="global",
scope_id="global",
)
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
"""根据提供商 ID 获取提供商实例"""
@@ -206,21 +221,21 @@ class ProviderManager:
logger.error(traceback.format_exc())
logger.error(e)
selected_provider_id = sp.get(
"curr_provider",
self.provider_settings.get("default_provider_id"),
selected_provider_id = await sp.get_async(
key="curr_provider",
default=self.provider_settings.get("default_provider_id"),
scope="global",
scope_id="global",
)
selected_stt_provider_id = sp.get(
"curr_provider_stt",
self.provider_stt_settings.get("provider_id"),
selected_stt_provider_id = await sp.get_async(
key="curr_provider_stt",
default=self.provider_stt_settings.get("provider_id"),
scope="global",
scope_id="global",
)
selected_tts_provider_id = sp.get(
"curr_provider_tts",
self.provider_tts_settings.get("provider_id"),
selected_tts_provider_id = await sp.get_async(
key="curr_provider_tts",
default=self.provider_tts_settings.get("provider_id"),
scope="global",
scope_id="global",
)
+38 -26
View File
@@ -12,7 +12,7 @@ class SessionServiceManager:
# =============================================================================
@staticmethod
def is_llm_enabled_for_session(session_id: str) -> bool:
async def is_llm_enabled_for_session(session_id: str) -> bool:
"""检查LLM是否在指定会话中启用
Args:
@@ -23,11 +23,11 @@ class SessionServiceManager:
"""
# 获取会话服务配置
session_services = sp.get(
"session_service_config",
{},
session_services = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
default={},
)
# 如果配置了该会话的LLM状态,返回该状态
@@ -39,7 +39,7 @@ class SessionServiceManager:
return True
@staticmethod
def set_llm_status_for_session(session_id: str, enabled: bool) -> None:
async def set_llm_status_for_session(session_id: str, enabled: bool) -> None:
"""设置LLM在指定会话中的启停状态
Args:
@@ -48,18 +48,24 @@ class SessionServiceManager:
"""
session_config = (
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
default={},
)
or {}
)
session_config["llm_enabled"] = enabled
sp.put(
"session_service_config",
session_config,
await sp.put_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
value=session_config,
)
@staticmethod
def should_process_llm_request(event: AstrMessageEvent) -> bool:
async def should_process_llm_request(event: AstrMessageEvent) -> bool:
"""检查是否应该处理LLM请求
Args:
@@ -70,14 +76,14 @@ class SessionServiceManager:
"""
session_id = event.unified_msg_origin
return SessionServiceManager.is_llm_enabled_for_session(session_id)
return await SessionServiceManager.is_llm_enabled_for_session(session_id)
# =============================================================================
# TTS 相关方法
# =============================================================================
@staticmethod
def is_tts_enabled_for_session(session_id: str) -> bool:
async def is_tts_enabled_for_session(session_id: str) -> bool:
"""检查TTS是否在指定会话中启用
Args:
@@ -88,11 +94,11 @@ class SessionServiceManager:
"""
# 获取会话服务配置
session_services = sp.get(
"session_service_config",
{},
session_services = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
default={},
)
# 如果配置了该会话的TTS状态,返回该状态
@@ -104,7 +110,7 @@ class SessionServiceManager:
return True
@staticmethod
def set_tts_status_for_session(session_id: str, enabled: bool) -> None:
async def set_tts_status_for_session(session_id: str, enabled: bool) -> None:
"""设置TTS在指定会话中的启停状态
Args:
@@ -113,14 +119,20 @@ class SessionServiceManager:
"""
session_config = (
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
default={},
)
or {}
)
session_config["tts_enabled"] = enabled
sp.put(
"session_service_config",
session_config,
await sp.put_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
value=session_config,
)
logger.info(
@@ -128,7 +140,7 @@ class SessionServiceManager:
)
@staticmethod
def should_process_tts_request(event: AstrMessageEvent) -> bool:
async def should_process_tts_request(event: AstrMessageEvent) -> bool:
"""检查是否应该处理TTS请求
Args:
@@ -139,14 +151,14 @@ class SessionServiceManager:
"""
session_id = event.unified_msg_origin
return SessionServiceManager.is_tts_enabled_for_session(session_id)
return await SessionServiceManager.is_tts_enabled_for_session(session_id)
# =============================================================================
# 会话整体启停相关方法
# =============================================================================
@staticmethod
def is_session_enabled(session_id: str) -> bool:
async def is_session_enabled(session_id: str) -> bool:
"""检查会话是否整体启用
Args:
@@ -157,11 +169,11 @@ class SessionServiceManager:
"""
# 获取会话服务配置
session_services = sp.get(
"session_service_config",
{},
session_services = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
default={},
)
# 如果配置了该会话的整体状态,返回该状态
+23 -11
View File
@@ -8,7 +8,10 @@ class SessionPluginManager:
"""管理会话级别的插件启停状态"""
@staticmethod
def is_plugin_enabled_for_session(session_id: str, plugin_name: str) -> bool:
async def is_plugin_enabled_for_session(
session_id: str,
plugin_name: str,
) -> bool:
"""检查插件是否在指定会话中启用
Args:
@@ -20,11 +23,11 @@ class SessionPluginManager:
"""
# 获取会话插件配置
session_plugin_config = sp.get(
"session_plugin_config",
{},
session_plugin_config = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_plugin_config",
default={},
)
session_config = session_plugin_config.get(session_id, {})
@@ -43,7 +46,10 @@ class SessionPluginManager:
return True
@staticmethod
def filter_handlers_by_session(event: AstrMessageEvent, handlers: list) -> list:
async def filter_handlers_by_session(
event: AstrMessageEvent,
handlers: list,
) -> list:
"""根据会话配置过滤处理器列表
Args:
@@ -59,6 +65,15 @@ class SessionPluginManager:
session_id = event.unified_msg_origin
filtered_handlers = []
session_plugin_config = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_plugin_config",
default={},
)
session_config = session_plugin_config.get(session_id, {})
disabled_plugins = session_config.get("disabled_plugins", [])
for handler in handlers:
# 获取处理器对应的插件
plugin = star_map.get(handler.handler_module_path)
@@ -76,14 +91,11 @@ class SessionPluginManager:
continue
# 检查插件是否在当前会话中启用
if SessionPluginManager.is_plugin_enabled_for_session(
session_id,
plugin.name,
):
filtered_handlers.append(handler)
else:
if plugin.name in disabled_plugins:
logger.debug(
f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}",
)
else:
filtered_handlers.append(handler)
return filtered_handlers
+6 -5
View File
@@ -11,14 +11,15 @@ class UmopConfigRouter:
"""UMOP 到配置文件 ID 的映射"""
self.sp = sp
self._load_routing_table()
async def initialize(self):
await self._load_routing_table()
def _load_routing_table(self):
async def _load_routing_table(self):
"""加载路由表"""
# 从 SharedPreferences 中加载 umop_to_conf_id 映射
sp_data = self.sp.get(
"umop_config_routing",
{},
sp_data = await self.sp.get_async(
key="umop_config_routing",
default={},
scope="global",
scope_id="global",
)