diff --git a/astrbot/dashboard/routes/session_management.py b/astrbot/dashboard/routes/session_management.py index 6b3b1cb77..f510a1d4c 100644 --- a/astrbot/dashboard/routes/session_management.py +++ b/astrbot/dashboard/routes/session_management.py @@ -39,22 +39,22 @@ class SessionManagementRoute(Route): try: # 获取所有会话的对话信息 conversations = self.db_helper.get_all_conversations() - + # 获取会话对话映射 session_conversations = sp.get("session_conversation", {}) - + # 获取会话提供商偏好设置 session_provider_perf = sp.get("session_provider_perf", {}) - + # 获取可用的 personas personas = self.core_lifecycle.star_context.provider_manager.personas - + # 获取可用的 providers provider_manager = self.core_lifecycle.star_context.provider_manager - + sessions = [] - - # 构建会话信息 + + # 构建会话信息 for session_id, conversation_id in session_conversations.items(): session_info = { "session_id": session_id, @@ -67,18 +67,36 @@ class SessionManagementRoute(Route): "stt_provider_name": None, "tts_provider_id": None, "tts_provider_name": None, - "session_enabled": SessionServiceManager.is_session_enabled(session_id), - "llm_enabled": SessionServiceManager.is_llm_enabled_for_session(session_id), - "tts_enabled": SessionServiceManager.is_tts_enabled_for_session(session_id), - "mcp_enabled": SessionServiceManager.is_mcp_enabled_for_session(session_id), - "platform": session_id.split(":")[0] if ":" in session_id else "unknown", - "message_type": session_id.split(":")[1] if session_id.count(":") >= 1 else "unknown", - "session_name": SessionServiceManager.get_session_display_name(session_id), - "session_raw_name": session_id.split(":")[2] if session_id.count(":") >= 2 else session_id, + "session_enabled": SessionServiceManager.is_session_enabled( + session_id + ), + "llm_enabled": SessionServiceManager.is_llm_enabled_for_session( + session_id + ), + "tts_enabled": SessionServiceManager.is_tts_enabled_for_session( + session_id + ), + "mcp_enabled": SessionServiceManager.is_mcp_enabled_for_session( + session_id + ), + "platform": session_id.split(":")[0] + if ":" in session_id + else "unknown", + "message_type": session_id.split(":")[1] + if session_id.count(":") >= 1 + else "unknown", + "session_name": SessionServiceManager.get_session_display_name( + session_id + ), + "session_raw_name": session_id.split(":")[2] + if session_id.count(":") >= 2 + else session_id, } - + # 获取对话信息 - conversation = self.db_helper.get_conversation_by_user_id(session_id, conversation_id) + conversation = self.db_helper.get_conversation_by_user_id( + session_id, conversation_id + ) if conversation: session_info["persona_id"] = conversation.persona_id # 查找 persona 名称 @@ -95,10 +113,10 @@ class SessionManagementRoute(Route): if default_persona: session_info["persona_id"] = default_persona["name"] session_info["persona_name"] = default_persona["name"] - + # 获取会话的 provider 偏好设置 session_perf = session_provider_perf.get(session_id, {}) - + # Chat completion provider chat_provider_id = session_perf.get(ProviderType.CHAT_COMPLETION.value) if chat_provider_id: @@ -112,7 +130,7 @@ class SessionManagementRoute(Route): if default_provider: session_info["chat_provider_id"] = default_provider.meta().id session_info["chat_provider_name"] = default_provider.meta().id - + # STT provider stt_provider_id = session_perf.get(ProviderType.SPEECH_TO_TEXT.value) if stt_provider_id: @@ -125,8 +143,10 @@ class SessionManagementRoute(Route): default_stt_provider = provider_manager.curr_stt_provider_inst if default_stt_provider: session_info["stt_provider_id"] = default_stt_provider.meta().id - session_info["stt_provider_name"] = default_stt_provider.meta().id - + session_info["stt_provider_name"] = ( + default_stt_provider.meta().id + ) + # TTS provider tts_provider_id = session_perf.get(ProviderType.TEXT_TO_SPEECH.value) if tts_provider_id: @@ -139,43 +159,53 @@ class SessionManagementRoute(Route): default_tts_provider = provider_manager.curr_tts_provider_inst if default_tts_provider: session_info["tts_provider_id"] = default_tts_provider.meta().id - session_info["tts_provider_name"] = default_tts_provider.meta().id - + session_info["tts_provider_name"] = ( + default_tts_provider.meta().id + ) + sessions.append(session_info) - + # 获取可用的 personas 和 providers 列表 - available_personas = [{"name": p["name"], "prompt": p.get("prompt", "")} for p in personas] - + available_personas = [ + {"name": p["name"], "prompt": p.get("prompt", "")} for p in personas + ] + available_chat_providers = [] for provider in provider_manager.provider_insts: meta = provider.meta() - available_chat_providers.append({ - "id": meta.id, - "name": meta.id, - "model": meta.model, - "type": meta.type, - }) - + available_chat_providers.append( + { + "id": meta.id, + "name": meta.id, + "model": meta.model, + "type": meta.type, + } + ) + available_stt_providers = [] for provider in provider_manager.stt_provider_insts: meta = provider.meta() - available_stt_providers.append({ - "id": meta.id, - "name": meta.id, - "model": meta.model, - "type": meta.type, - }) - + available_stt_providers.append( + { + "id": meta.id, + "name": meta.id, + "model": meta.model, + "type": meta.type, + } + ) + available_tts_providers = [] for provider in provider_manager.tts_provider_insts: meta = provider.meta() - available_tts_providers.append({ - "id": meta.id, - "name": meta.id, - "model": meta.model, - "type": meta.type, - }) - + available_tts_providers.append( + { + "id": meta.id, + "name": meta.id, + "model": meta.model, + "type": meta.type, + } + ) + result = { "sessions": sessions, "available_personas": available_personas, @@ -183,9 +213,9 @@ class SessionManagementRoute(Route): "available_stt_providers": available_stt_providers, "available_tts_providers": available_tts_providers, } - + return Response().ok(result).__dict__ - + except Exception as e: error_msg = f"获取会话列表失败: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) @@ -197,26 +227,36 @@ class SessionManagementRoute(Route): data = await request.get_json() session_id = data.get("session_id") persona_name = data.get("persona_name") - + if not session_id: return Response().error("缺少必要参数: session_id").__dict__ - + if persona_name is None: return Response().error("缺少必要参数: persona_name").__dict__ - + # 获取会话当前的对话 ID conversation_manager = self.core_lifecycle.star_context.conversation_manager - conversation_id = await conversation_manager.get_curr_conversation_id(session_id) - + conversation_id = await conversation_manager.get_curr_conversation_id( + session_id + ) + if not conversation_id: # 如果没有对话,创建一个新的对话 - conversation_id = await conversation_manager.new_conversation(session_id) - + conversation_id = await conversation_manager.new_conversation( + session_id + ) + # 更新 persona - await conversation_manager.update_conversation_persona_id(session_id, persona_name) - - return Response().ok({"message": f"成功更新会话 {session_id} 的人格为 {persona_name}"}).__dict__ - + await conversation_manager.update_conversation_persona_id( + session_id, persona_name + ) + + return ( + Response() + .ok({"message": f"成功更新会话 {session_id} 的人格为 {persona_name}"}) + .__dict__ + ) + except Exception as e: error_msg = f"更新会话人格失败: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) @@ -228,11 +268,17 @@ class SessionManagementRoute(Route): data = await request.get_json() session_id = data.get("session_id") provider_id = data.get("provider_id") - provider_type = data.get("provider_type") # "chat_completion", "speech_to_text", "text_to_speech" - + provider_type = data.get( + "provider_type" + ) # "chat_completion", "speech_to_text", "text_to_speech" + if not session_id or not provider_id or not provider_type: - return Response().error("缺少必要参数: session_id, provider_id, provider_type").__dict__ - + return ( + Response() + .error("缺少必要参数: session_id, provider_id, provider_type") + .__dict__ + ) + # 转换 provider_type 字符串为枚举 try: if provider_type == "chat_completion": @@ -242,10 +288,16 @@ class SessionManagementRoute(Route): elif provider_type == "text_to_speech": provider_type_enum = ProviderType.TEXT_TO_SPEECH else: - return Response().error(f"不支持的 provider_type: {provider_type}").__dict__ + return ( + Response() + .error(f"不支持的 provider_type: {provider_type}") + .__dict__ + ) except Exception as e: - return Response().error(f"无效的 provider_type: {provider_type}").__dict__ - + return ( + Response().error(f"无效的 provider_type: {provider_type}").__dict__ + ) + # 设置 provider provider_manager = self.core_lifecycle.star_context.provider_manager await provider_manager.set_provider( @@ -253,11 +305,17 @@ class SessionManagementRoute(Route): provider_type=provider_type_enum, umo=session_id, ) - - return Response().ok({ - "message": f"成功更新会话 {session_id} 的 {provider_type} 提供商为 {provider_id}" - }).__dict__ - + + return ( + Response() + .ok( + { + "message": f"成功更新会话 {session_id} 的 {provider_type} 提供商为 {provider_id}" + } + ) + .__dict__ + ) + except Exception as e: error_msg = f"更新会话提供商失败: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) @@ -268,16 +326,16 @@ class SessionManagementRoute(Route): try: data = await request.get_json() session_id = data.get("session_id") - + if not session_id: return Response().error("缺少必要参数: session_id").__dict__ - # 获取会话对话信息 + # 获取会话对话信息 session_conversations = sp.get("session_conversation", {}) conversation_id = session_conversations.get(session_id) - + if not conversation_id: return Response().error(f"会话 {session_id} 未找到对话").__dict__ - + session_info = { "session_id": session_id, "conversation_id": conversation_id, @@ -289,26 +347,40 @@ class SessionManagementRoute(Route): "stt_provider_name": None, "tts_provider_id": None, "tts_provider_name": None, - "llm_enabled": SessionServiceManager.is_llm_enabled_for_session(session_id), + "llm_enabled": SessionServiceManager.is_llm_enabled_for_session( + session_id + ), "tts_enabled": None, # 将在下面设置 - "mcp_enabled": SessionServiceManager.is_mcp_enabled_for_session(session_id), - "platform": session_id.split(":")[0] if ":" in session_id else "unknown", - "message_type": session_id.split(":")[1] if session_id.count(":") >= 1 else "unknown", - "session_name": session_id.split(":")[2] if session_id.count(":") >= 2 else session_id, + "mcp_enabled": SessionServiceManager.is_mcp_enabled_for_session( + session_id + ), + "platform": session_id.split(":")[0] + if ":" in session_id + else "unknown", + "message_type": session_id.split(":")[1] + if session_id.count(":") >= 1 + else "unknown", + "session_name": session_id.split(":")[2] + if session_id.count(":") >= 2 + else session_id, } - + # 获取TTS状态 - session_info["tts_enabled"] = SessionServiceManager.is_tts_enabled_for_session(session_id) - + session_info["tts_enabled"] = ( + SessionServiceManager.is_tts_enabled_for_session(session_id) + ) + # 获取对话信息 - conversation = self.db_helper.get_conversation_by_user_id(session_id, conversation_id) + conversation = self.db_helper.get_conversation_by_user_id( + session_id, conversation_id + ) if conversation: session_info["persona_id"] = conversation.persona_id - + # 查找 persona 名称 provider_manager = self.core_lifecycle.star_context.provider_manager personas = provider_manager.personas - + if conversation.persona_id and conversation.persona_id != "[%None]": for persona in personas: if persona["name"] == conversation.persona_id: @@ -322,14 +394,14 @@ class SessionManagementRoute(Route): if default_persona: session_info["persona_id"] = default_persona["name"] session_info["persona_name"] = default_persona["name"] - + # 获取会话的 provider 偏好设置 session_provider_perf = sp.get("session_provider_perf", {}) session_perf = session_provider_perf.get(session_id, {}) - + # 获取 provider 信息 provider_manager = self.core_lifecycle.star_context.provider_manager - + # Chat completion provider chat_provider_id = session_perf.get(ProviderType.CHAT_COMPLETION.value) if chat_provider_id: @@ -343,7 +415,7 @@ class SessionManagementRoute(Route): if default_provider: session_info["chat_provider_id"] = default_provider.meta().id session_info["chat_provider_name"] = default_provider.meta().id - + # STT provider stt_provider_id = session_perf.get(ProviderType.SPEECH_TO_TEXT.value) if stt_provider_id: @@ -357,7 +429,7 @@ class SessionManagementRoute(Route): if default_stt_provider: session_info["stt_provider_id"] = default_stt_provider.meta().id session_info["stt_provider_name"] = default_stt_provider.meta().id - + # TTS provider tts_provider_id = session_perf.get(ProviderType.TEXT_TO_SPEECH.value) if tts_provider_id: @@ -371,9 +443,9 @@ class SessionManagementRoute(Route): if default_tts_provider: session_info["tts_provider_id"] = default_tts_provider.meta().id session_info["tts_provider_name"] = default_tts_provider.meta().id - + return Response().ok(session_info).__dict__ - + except Exception as e: error_msg = f"获取会话信息失败: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) @@ -383,36 +455,46 @@ class SessionManagementRoute(Route): """获取指定会话的插件配置信息""" try: session_id = request.args.get("session_id") - + if not session_id: return Response().error("缺少必要参数: session_id").__dict__ - + # 获取所有已激活的插件 all_plugins = [] plugin_manager = self.core_lifecycle.star_context._star_manager - + for plugin in plugin_manager.context.get_all_stars(): # 只显示已激活的插件,不包括保留插件 if plugin.activated and not plugin.reserved: - plugin_enabled = SessionPluginManager.is_plugin_enabled_for_session(session_id, plugin.name) - - all_plugins.append({ - "name": plugin.name, - "author": plugin.author, - "desc": plugin.desc, - "enabled": plugin_enabled, - }) - - return Response().ok({ - "session_id": session_id, - "plugins": all_plugins, - }).__dict__ - + plugin_enabled = SessionPluginManager.is_plugin_enabled_for_session( + session_id, plugin.name + ) + + all_plugins.append( + { + "name": plugin.name, + "author": plugin.author, + "desc": plugin.desc, + "enabled": plugin_enabled, + } + ) + + return ( + Response() + .ok( + { + "session_id": session_id, + "plugins": all_plugins, + } + ) + .__dict__ + ) + except Exception as e: error_msg = f"获取会话插件配置失败: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) return Response().error(f"获取会话插件配置失败: {str(e)}").__dict__ - + async def update_session_plugin(self): """更新指定会话的插件启停状态""" try: @@ -420,145 +502,183 @@ class SessionManagementRoute(Route): session_id = data.get("session_id") plugin_name = data.get("plugin_name") enabled = data.get("enabled") - + if not session_id: return Response().error("缺少必要参数: session_id").__dict__ - + if not plugin_name: return Response().error("缺少必要参数: plugin_name").__dict__ - + if enabled is None: return Response().error("缺少必要参数: enabled").__dict__ - + # 验证插件是否存在且已激活 plugin_manager = self.core_lifecycle.star_context._star_manager plugin = plugin_manager.context.get_registered_star(plugin_name) - + if not plugin: return Response().error(f"插件 {plugin_name} 不存在").__dict__ - + if not plugin.activated: return Response().error(f"插件 {plugin_name} 未激活").__dict__ - + if plugin.reserved: - return Response().error(f"插件 {plugin_name} 是系统保留插件,无法管理").__dict__ - + return ( + Response() + .error(f"插件 {plugin_name} 是系统保留插件,无法管理") + .__dict__ + ) + # 使用 SessionPluginManager 更新插件状态 - SessionPluginManager.set_plugin_status_for_session(session_id, plugin_name, enabled) - - return Response().ok({ - "message": f"插件 {plugin_name} 已{'启用' if enabled else '禁用'}", - "session_id": session_id, - "plugin_name": plugin_name, - "enabled": enabled, - }).__dict__ - + SessionPluginManager.set_plugin_status_for_session( + session_id, plugin_name, enabled + ) + + return ( + Response() + .ok( + { + "message": f"插件 {plugin_name} 已{'启用' if enabled else '禁用'}", + "session_id": session_id, + "plugin_name": plugin_name, + "enabled": enabled, + } + ) + .__dict__ + ) + except Exception as e: error_msg = f"更新会话插件状态失败: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) return Response().error(f"更新会话插件状态失败: {str(e)}").__dict__ - + async def update_session_llm(self): """更新指定会话的LLM启停状态""" try: data = await request.get_json() session_id = data.get("session_id") enabled = data.get("enabled") - + if not session_id: return Response().error("缺少必要参数: session_id").__dict__ - + if enabled is None: return Response().error("缺少必要参数: enabled").__dict__ - + # 使用 SessionServiceManager 更新LLM状态 SessionServiceManager.set_llm_status_for_session(session_id, enabled) - - return Response().ok({ - "message": f"LLM已{'启用' if enabled else '禁用'}", - "session_id": session_id, - "llm_enabled": enabled, - }).__dict__ - + + return ( + Response() + .ok( + { + "message": f"LLM已{'启用' if enabled else '禁用'}", + "session_id": session_id, + "llm_enabled": enabled, + } + ) + .__dict__ + ) + except Exception as e: error_msg = f"更新会话LLM状态失败: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) return Response().error(f"更新会话LLM状态失败: {str(e)}").__dict__ - + async def update_session_tts(self): """更新指定会话的TTS启停状态""" try: data = await request.get_json() session_id = data.get("session_id") enabled = data.get("enabled") - + if not session_id: return Response().error("缺少必要参数: session_id").__dict__ - + if enabled is None: return Response().error("缺少必要参数: enabled").__dict__ - + # 使用 SessionServiceManager 更新TTS状态 SessionServiceManager.set_tts_status_for_session(session_id, enabled) - - return Response().ok({ - "message": f"TTS已{'启用' if enabled else '禁用'}", - "session_id": session_id, - "tts_enabled": enabled, - }).__dict__ - + + return ( + Response() + .ok( + { + "message": f"TTS已{'启用' if enabled else '禁用'}", + "session_id": session_id, + "tts_enabled": enabled, + } + ) + .__dict__ + ) + except Exception as e: error_msg = f"更新会话TTS状态失败: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) return Response().error(f"更新会话TTS状态失败: {str(e)}").__dict__ - + async def update_session_mcp(self): """更新指定会话的MCP启停状态""" try: data = await request.get_json() session_id = data.get("session_id") enabled = data.get("enabled") - + if not session_id: return Response().error("缺少必要参数: session_id").__dict__ - + if enabled is None: return Response().error("缺少必要参数: enabled").__dict__ - + # 使用 SessionServiceManager 更新MCP状态 SessionServiceManager.set_mcp_status_for_session(session_id, enabled) - - return Response().ok({ - "message": f"MCP工具调用已{'启用' if enabled else '禁用'}", - "session_id": session_id, - "mcp_enabled": enabled, - }).__dict__ - + + return ( + Response() + .ok( + { + "message": f"MCP工具调用已{'启用' if enabled else '禁用'}", + "session_id": session_id, + "mcp_enabled": enabled, + } + ) + .__dict__ + ) + except Exception as e: error_msg = f"更新会话MCP状态失败: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) return Response().error(f"更新会话MCP状态失败: {str(e)}").__dict__ - + async def update_session_name(self): """更新指定会话的自定义名称""" try: data = await request.get_json() session_id = data.get("session_id") custom_name = data.get("custom_name", "") - + if not session_id: return Response().error("缺少必要参数: session_id").__dict__ - + # 使用 SessionServiceManager 更新会话名称 SessionServiceManager.set_session_custom_name(session_id, custom_name) - - return Response().ok({ - "message": f"会话名称已更新为: {custom_name if custom_name.strip() else '已清除自定义名称'}", - "session_id": session_id, - "custom_name": custom_name, - "display_name": SessionServiceManager.get_session_display_name(session_id), - }).__dict__ - + + return ( + Response() + .ok( + { + "message": f"会话名称已更新为: {custom_name if custom_name.strip() else '已清除自定义名称'}", + "session_id": session_id, + "custom_name": custom_name, + "display_name": SessionServiceManager.get_session_display_name( + session_id + ), + } + ) + .__dict__ + ) + except Exception as e: error_msg = f"更新会话名称失败: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) @@ -570,22 +690,28 @@ class SessionManagementRoute(Route): data = await request.get_json() session_id = data.get("session_id") session_enabled = data.get("session_enabled") - + if not session_id: return Response().error("缺少必要参数: session_id").__dict__ - + if session_enabled is None: return Response().error("缺少必要参数: session_enabled").__dict__ - + # 使用 SessionServiceManager 更新会话整体状态 SessionServiceManager.set_session_status(session_id, session_enabled) - - return Response().ok({ - "message": f"会话整体状态已更新为: {'启用' if session_enabled else '禁用'}", - "session_id": session_id, - "session_enabled": session_enabled, - }).__dict__ - + + return ( + Response() + .ok( + { + "message": f"会话整体状态已更新为: {'启用' if session_enabled else '禁用'}", + "session_id": session_id, + "session_enabled": session_enabled, + } + ) + .__dict__ + ) + except Exception as e: error_msg = f"更新会话整体状态失败: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg)