From 454841de107925c80d9e3ccb4cc1cb160db5f0b2 Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Sat, 3 Jan 2026 19:12:39 +0800 Subject: [PATCH] fix: database is locked error when invoking tts command (#4313) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: database is locked error when invoking /tts command fixes: #4311 * chore: rm pnpm lockfile * perf: 减少操作数据库的次数 --- .../builtin_commands/commands/tts.py | 4 +- astrbot/core/core_lifecycle.py | 1 + .../process_stage/method/agent_request.py | 2 +- .../core/pipeline/result_decorate/stage.py | 2 +- .../pipeline/session_status_check/stage.py | 2 +- astrbot/core/pipeline/waking_check/stage.py | 2 +- astrbot/core/provider/manager.py | 39 +++++++---- astrbot/core/star/session_llm_manager.py | 64 +++++++++++-------- astrbot/core/star/session_plugin_manager.py | 34 ++++++---- astrbot/core/umop_config_router.py | 11 ++-- 10 files changed, 101 insertions(+), 60 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/tts.py b/astrbot/builtin_stars/builtin_commands/commands/tts.py index d733ba1ea..dee8e31de 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/tts.py +++ b/astrbot/builtin_stars/builtin_commands/commands/tts.py @@ -14,13 +14,13 @@ class TTSCommand: async def tts(self, event: AstrMessageEvent): """开关文本转语音(会话级别)""" umo = event.unified_msg_origin - ses_tts = SessionServiceManager.is_tts_enabled_for_session(umo) + ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo) cfg = self.context.get_config(umo=umo) tts_enable = cfg["provider_tts_settings"]["enable"] # 切换状态 new_status = not ses_tts - SessionServiceManager.set_tts_status_for_session(umo, new_status) + await SessionServiceManager.set_tts_status_for_session(umo, new_status) status_text = "已开启" if new_status else "已关闭" diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 823fdf260..a14d8d970 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -90,6 +90,7 @@ class AstrBotCoreLifecycle: # 初始化 UMOP 配置路由器 self.umop_config_router = UmopConfigRouter(sp=sp) + await self.umop_config_router.initialize() # 初始化 AstrBot 配置管理器 self.astrbot_config_mgr = AstrBotConfigManager( diff --git a/astrbot/core/pipeline/process_stage/method/agent_request.py b/astrbot/core/pipeline/process_stage/method/agent_request.py index f6f81631e..9efe53814 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_request.py +++ b/astrbot/core/pipeline/process_stage/method/agent_request.py @@ -38,7 +38,7 @@ class AgentRequestSubStage(Stage): ) return - if not SessionServiceManager.should_process_llm_request(event): + if not await SessionServiceManager.should_process_llm_request(event): logger.debug( f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing." ) diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 529d0b263..e0bcd5ac9 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -260,7 +260,7 @@ class ResultDecorateStage(Stage): should_tts = ( bool(self.ctx.astrbot_config["provider_tts_settings"]["enable"]) and result.is_llm_result() - and SessionServiceManager.should_process_tts_request(event) + and await SessionServiceManager.should_process_tts_request(event) and random.random() <= self.tts_trigger_probability and tts_provider ) diff --git a/astrbot/core/pipeline/session_status_check/stage.py b/astrbot/core/pipeline/session_status_check/stage.py index 7feeeb86a..26c3c235a 100644 --- a/astrbot/core/pipeline/session_status_check/stage.py +++ b/astrbot/core/pipeline/session_status_check/stage.py @@ -21,7 +21,7 @@ class SessionStatusCheckStage(Stage): event: AstrMessageEvent, ) -> None | AsyncGenerator[None, None]: # 检查会话是否整体启用 - if not SessionServiceManager.is_session_enabled(event.unified_msg_origin): + if not await SessionServiceManager.is_session_enabled(event.unified_msg_origin): logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。") # workaround for #2309 diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index 51e0f8795..00fb41c51 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -227,7 +227,7 @@ class WakingCheckStage(Stage): event._extras.pop("parsed_params", None) # 根据会话配置过滤插件处理器 - activated_handlers = SessionPluginManager.filter_handlers_by_session( + activated_handlers = await SessionPluginManager.filter_handlers_by_session( event, activated_handlers, ) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 0dff2c8ed..b523a0661 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -119,19 +119,34 @@ class ProviderManager: TTSProvider, ): self.curr_tts_provider_inst = prov - sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global") + await sp.put_async( + key="curr_provider_tts", + value=provider_id, + scope="global", + scope_id="global", + ) elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance( prov, STTProvider, ): self.curr_stt_provider_inst = prov - sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global") + await sp.put_async( + key="curr_provider_stt", + value=provider_id, + scope="global", + scope_id="global", + ) elif provider_type == ProviderType.CHAT_COMPLETION and isinstance( prov, Provider, ): self.curr_provider_inst = prov - sp.put("curr_provider", provider_id, scope="global", scope_id="global") + await sp.put_async( + key="curr_provider", + value=provider_id, + scope="global", + scope_id="global", + ) async def get_provider_by_id(self, provider_id: str) -> Providers | None: """根据提供商 ID 获取提供商实例""" @@ -206,21 +221,21 @@ class ProviderManager: logger.error(traceback.format_exc()) logger.error(e) - selected_provider_id = sp.get( - "curr_provider", - self.provider_settings.get("default_provider_id"), + selected_provider_id = await sp.get_async( + key="curr_provider", + default=self.provider_settings.get("default_provider_id"), scope="global", scope_id="global", ) - selected_stt_provider_id = sp.get( - "curr_provider_stt", - self.provider_stt_settings.get("provider_id"), + selected_stt_provider_id = await sp.get_async( + key="curr_provider_stt", + default=self.provider_stt_settings.get("provider_id"), scope="global", scope_id="global", ) - selected_tts_provider_id = sp.get( - "curr_provider_tts", - self.provider_tts_settings.get("provider_id"), + selected_tts_provider_id = await sp.get_async( + key="curr_provider_tts", + default=self.provider_tts_settings.get("provider_id"), scope="global", scope_id="global", ) diff --git a/astrbot/core/star/session_llm_manager.py b/astrbot/core/star/session_llm_manager.py index 9fdca1457..ad4a473b4 100644 --- a/astrbot/core/star/session_llm_manager.py +++ b/astrbot/core/star/session_llm_manager.py @@ -12,7 +12,7 @@ class SessionServiceManager: # ============================================================================= @staticmethod - def is_llm_enabled_for_session(session_id: str) -> bool: + async def is_llm_enabled_for_session(session_id: str) -> bool: """检查LLM是否在指定会话中启用 Args: @@ -23,11 +23,11 @@ class SessionServiceManager: """ # 获取会话服务配置 - session_services = sp.get( - "session_service_config", - {}, + session_services = await sp.get_async( scope="umo", scope_id=session_id, + key="session_service_config", + default={}, ) # 如果配置了该会话的LLM状态,返回该状态 @@ -39,7 +39,7 @@ class SessionServiceManager: return True @staticmethod - def set_llm_status_for_session(session_id: str, enabled: bool) -> None: + async def set_llm_status_for_session(session_id: str, enabled: bool) -> None: """设置LLM在指定会话中的启停状态 Args: @@ -48,18 +48,24 @@ class SessionServiceManager: """ session_config = ( - sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {} + await sp.get_async( + scope="umo", + scope_id=session_id, + key="session_service_config", + default={}, + ) + or {} ) session_config["llm_enabled"] = enabled - sp.put( - "session_service_config", - session_config, + await sp.put_async( scope="umo", scope_id=session_id, + key="session_service_config", + value=session_config, ) @staticmethod - def should_process_llm_request(event: AstrMessageEvent) -> bool: + async def should_process_llm_request(event: AstrMessageEvent) -> bool: """检查是否应该处理LLM请求 Args: @@ -70,14 +76,14 @@ class SessionServiceManager: """ session_id = event.unified_msg_origin - return SessionServiceManager.is_llm_enabled_for_session(session_id) + return await SessionServiceManager.is_llm_enabled_for_session(session_id) # ============================================================================= # TTS 相关方法 # ============================================================================= @staticmethod - def is_tts_enabled_for_session(session_id: str) -> bool: + async def is_tts_enabled_for_session(session_id: str) -> bool: """检查TTS是否在指定会话中启用 Args: @@ -88,11 +94,11 @@ class SessionServiceManager: """ # 获取会话服务配置 - session_services = sp.get( - "session_service_config", - {}, + session_services = await sp.get_async( scope="umo", scope_id=session_id, + key="session_service_config", + default={}, ) # 如果配置了该会话的TTS状态,返回该状态 @@ -104,7 +110,7 @@ class SessionServiceManager: return True @staticmethod - def set_tts_status_for_session(session_id: str, enabled: bool) -> None: + async def set_tts_status_for_session(session_id: str, enabled: bool) -> None: """设置TTS在指定会话中的启停状态 Args: @@ -113,14 +119,20 @@ class SessionServiceManager: """ session_config = ( - sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {} + await sp.get_async( + scope="umo", + scope_id=session_id, + key="session_service_config", + default={}, + ) + or {} ) session_config["tts_enabled"] = enabled - sp.put( - "session_service_config", - session_config, + await sp.put_async( scope="umo", scope_id=session_id, + key="session_service_config", + value=session_config, ) logger.info( @@ -128,7 +140,7 @@ class SessionServiceManager: ) @staticmethod - def should_process_tts_request(event: AstrMessageEvent) -> bool: + async def should_process_tts_request(event: AstrMessageEvent) -> bool: """检查是否应该处理TTS请求 Args: @@ -139,14 +151,14 @@ class SessionServiceManager: """ session_id = event.unified_msg_origin - return SessionServiceManager.is_tts_enabled_for_session(session_id) + return await SessionServiceManager.is_tts_enabled_for_session(session_id) # ============================================================================= # 会话整体启停相关方法 # ============================================================================= @staticmethod - def is_session_enabled(session_id: str) -> bool: + async def is_session_enabled(session_id: str) -> bool: """检查会话是否整体启用 Args: @@ -157,11 +169,11 @@ class SessionServiceManager: """ # 获取会话服务配置 - session_services = sp.get( - "session_service_config", - {}, + session_services = await sp.get_async( scope="umo", scope_id=session_id, + key="session_service_config", + default={}, ) # 如果配置了该会话的整体状态,返回该状态 diff --git a/astrbot/core/star/session_plugin_manager.py b/astrbot/core/star/session_plugin_manager.py index e2ebd11f0..a81113415 100644 --- a/astrbot/core/star/session_plugin_manager.py +++ b/astrbot/core/star/session_plugin_manager.py @@ -8,7 +8,10 @@ class SessionPluginManager: """管理会话级别的插件启停状态""" @staticmethod - def is_plugin_enabled_for_session(session_id: str, plugin_name: str) -> bool: + async def is_plugin_enabled_for_session( + session_id: str, + plugin_name: str, + ) -> bool: """检查插件是否在指定会话中启用 Args: @@ -20,11 +23,11 @@ class SessionPluginManager: """ # 获取会话插件配置 - session_plugin_config = sp.get( - "session_plugin_config", - {}, + session_plugin_config = await sp.get_async( scope="umo", scope_id=session_id, + key="session_plugin_config", + default={}, ) session_config = session_plugin_config.get(session_id, {}) @@ -43,7 +46,10 @@ class SessionPluginManager: return True @staticmethod - def filter_handlers_by_session(event: AstrMessageEvent, handlers: list) -> list: + async def filter_handlers_by_session( + event: AstrMessageEvent, + handlers: list, + ) -> list: """根据会话配置过滤处理器列表 Args: @@ -59,6 +65,15 @@ class SessionPluginManager: session_id = event.unified_msg_origin filtered_handlers = [] + session_plugin_config = await sp.get_async( + scope="umo", + scope_id=session_id, + key="session_plugin_config", + default={}, + ) + session_config = session_plugin_config.get(session_id, {}) + disabled_plugins = session_config.get("disabled_plugins", []) + for handler in handlers: # 获取处理器对应的插件 plugin = star_map.get(handler.handler_module_path) @@ -76,14 +91,11 @@ class SessionPluginManager: continue # 检查插件是否在当前会话中启用 - if SessionPluginManager.is_plugin_enabled_for_session( - session_id, - plugin.name, - ): - filtered_handlers.append(handler) - else: + if plugin.name in disabled_plugins: logger.debug( f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}", ) + else: + filtered_handlers.append(handler) return filtered_handlers diff --git a/astrbot/core/umop_config_router.py b/astrbot/core/umop_config_router.py index b82ece2df..1f2289f4d 100644 --- a/astrbot/core/umop_config_router.py +++ b/astrbot/core/umop_config_router.py @@ -11,14 +11,15 @@ class UmopConfigRouter: """UMOP 到配置文件 ID 的映射""" self.sp = sp - self._load_routing_table() + async def initialize(self): + await self._load_routing_table() - def _load_routing_table(self): + async def _load_routing_table(self): """加载路由表""" # 从 SharedPreferences 中加载 umop_to_conf_id 映射 - sp_data = self.sp.get( - "umop_config_routing", - {}, + sp_data = await self.sp.get_async( + key="umop_config_routing", + default={}, scope="global", scope_id="global", )