From 8ebf087dbfb0a8ae57160943f35f23bdb1ac8e3a Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 10 Jul 2025 23:28:00 +0800 Subject: [PATCH] chore: optimize codes --- .../dashboard/routes/session_management.py | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/astrbot/dashboard/routes/session_management.py b/astrbot/dashboard/routes/session_management.py index f510a1d4c..d1f5f5631 100644 --- a/astrbot/dashboard/routes/session_management.py +++ b/astrbot/dashboard/routes/session_management.py @@ -37,14 +37,11 @@ class SessionManagementRoute(Route): async def list_sessions(self): """获取所有会话的列表,包括 persona 和 provider 信息""" try: - # 获取所有会话的对话信息 - conversations = self.db_helper.get_all_conversations() - # 获取会话对话映射 - session_conversations = sp.get("session_conversation", {}) + session_conversations = sp.get("session_conversation", {}) or {} # 获取会话提供商偏好设置 - session_provider_perf = sp.get("session_provider_perf", {}) + session_provider_perf = sp.get("session_provider_perf", {}) or {} # 获取可用的 personas personas = self.core_lifecycle.star_context.provider_manager.personas @@ -268,9 +265,8 @@ class SessionManagementRoute(Route): data = await request.get_json() session_id = data.get("session_id") provider_id = data.get("provider_id") - provider_type = data.get( - "provider_type" - ) # "chat_completion", "speech_to_text", "text_to_speech" + # "chat_completion", "speech_to_text", "text_to_speech" + provider_type = data.get("provider_type") if not session_id or not provider_id or not provider_type: return ( @@ -280,22 +276,17 @@ class SessionManagementRoute(Route): ) # 转换 provider_type 字符串为枚举 - try: - if provider_type == "chat_completion": - provider_type_enum = ProviderType.CHAT_COMPLETION - elif provider_type == "speech_to_text": - provider_type_enum = ProviderType.SPEECH_TO_TEXT - elif provider_type == "text_to_speech": - provider_type_enum = ProviderType.TEXT_TO_SPEECH - else: - return ( - Response() - .error(f"不支持的 provider_type: {provider_type}") - .__dict__ - ) - except Exception as e: + if provider_type == "chat_completion": + provider_type_enum = ProviderType.CHAT_COMPLETION + elif provider_type == "speech_to_text": + provider_type_enum = ProviderType.SPEECH_TO_TEXT + elif provider_type == "text_to_speech": + provider_type_enum = ProviderType.TEXT_TO_SPEECH + else: return ( - Response().error(f"无效的 provider_type: {provider_type}").__dict__ + Response() + .error(f"不支持的 provider_type: {provider_type}") + .__dict__ ) # 设置 provider @@ -330,7 +321,7 @@ class SessionManagementRoute(Route): if not session_id: return Response().error("缺少必要参数: session_id").__dict__ # 获取会话对话信息 - session_conversations = sp.get("session_conversation", {}) + session_conversations = sp.get("session_conversation", {}) or {} conversation_id = session_conversations.get(session_id) if not conversation_id: @@ -396,7 +387,7 @@ class SessionManagementRoute(Route): session_info["persona_name"] = default_persona["name"] # 获取会话的 provider 偏好设置 - session_provider_perf = sp.get("session_provider_perf", {}) + session_provider_perf = sp.get("session_provider_perf", {}) or {} session_perf = session_provider_perf.get(session_id, {}) # 获取 provider 信息 @@ -461,18 +452,19 @@ class SessionManagementRoute(Route): # 获取所有已激活的插件 all_plugins = [] - plugin_manager = self.core_lifecycle.star_context._star_manager + plugin_manager = self.core_lifecycle.plugin_manager for plugin in plugin_manager.context.get_all_stars(): # 只显示已激活的插件,不包括保留插件 if plugin.activated and not plugin.reserved: + plugin_name = plugin.name or "" plugin_enabled = SessionPluginManager.is_plugin_enabled_for_session( - session_id, plugin.name + session_id, plugin_name ) all_plugins.append( { - "name": plugin.name, + "name": plugin_name, "author": plugin.author, "desc": plugin.desc, "enabled": plugin_enabled, @@ -513,7 +505,7 @@ class SessionManagementRoute(Route): return Response().error("缺少必要参数: enabled").__dict__ # 验证插件是否存在且已激活 - plugin_manager = self.core_lifecycle.star_context._star_manager + plugin_manager = self.core_lifecycle.plugin_manager plugin = plugin_manager.context.get_registered_star(plugin_name) if not plugin: