From 4e29684aa381cd260d4b496dcbce8a612a401b79 Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Fri, 28 Nov 2025 13:29:50 +0800 Subject: [PATCH] fix: add plugin set and knowledge bases selection in custom rules page (#3813) fixes: #3806 --- astrbot/core/star/session_llm_manager.py | 107 --------- astrbot/core/star/session_plugin_manager.py | 81 ------- astrbot/dashboard/routes/knowledge_base.py | 156 ------------- .../dashboard/routes/session_management.py | 41 +++- .../en-US/features/session-management.json | 11 + .../zh-CN/features/session-management.json | 11 + dashboard/src/views/SessionManagementPage.vue | 214 +++++++++++++++++- 7 files changed, 274 insertions(+), 347 deletions(-) diff --git a/astrbot/core/star/session_llm_manager.py b/astrbot/core/star/session_llm_manager.py index 8c40f25c1..9fdca1457 100644 --- a/astrbot/core/star/session_llm_manager.py +++ b/astrbot/core/star/session_llm_manager.py @@ -171,110 +171,3 @@ class SessionServiceManager: # 如果没有配置,默认为启用(兼容性考虑) return True - - @staticmethod - def set_session_status(session_id: str, enabled: bool) -> None: - """设置会话的整体启停状态 - - Args: - session_id: 会话ID (unified_msg_origin) - enabled: True表示启用,False表示禁用 - - """ - session_config = ( - sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {} - ) - session_config["session_enabled"] = enabled - sp.put( - "session_service_config", - session_config, - scope="umo", - scope_id=session_id, - ) - - logger.info( - f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}", - ) - - @staticmethod - def should_process_session_request(event: AstrMessageEvent) -> bool: - """检查是否应该处理会话请求(会话整体启停检查) - - Args: - event: 消息事件 - - Returns: - bool: True表示应该处理,False表示跳过 - - """ - session_id = event.unified_msg_origin - return SessionServiceManager.is_session_enabled(session_id) - - # ============================================================================= - # 会话命名相关方法 - # ============================================================================= - - @staticmethod - def get_session_custom_name(session_id: str) -> str | None: - """获取会话的自定义名称 - - Args: - session_id: 会话ID (unified_msg_origin) - - Returns: - str: 自定义名称,如果没有设置则返回None - - """ - session_services = sp.get( - "session_service_config", - {}, - scope="umo", - scope_id=session_id, - ) - return session_services.get("custom_name") - - @staticmethod - def set_session_custom_name(session_id: str, custom_name: str) -> None: - """设置会话的自定义名称 - - Args: - session_id: 会话ID (unified_msg_origin) - custom_name: 自定义名称,可以为空字符串来清除名称 - - """ - session_config = ( - sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {} - ) - if custom_name and custom_name.strip(): - session_config["custom_name"] = custom_name.strip() - else: - # 如果传入空名称,则删除自定义名称 - session_config.pop("custom_name", None) - sp.put( - "session_service_config", - session_config, - scope="umo", - scope_id=session_id, - ) - - logger.info( - f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}", - ) - - @staticmethod - def get_session_display_name(session_id: str) -> str: - """获取会话的显示名称(优先显示自定义名称,否则显示原始session_id的最后一段) - - Args: - session_id: 会话ID (unified_msg_origin) - - Returns: - str: 显示名称 - - """ - custom_name = SessionServiceManager.get_session_custom_name(session_id) - if custom_name: - return custom_name - - # 如果没有自定义名称,返回session_id的最后一段 - return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id diff --git a/astrbot/core/star/session_plugin_manager.py b/astrbot/core/star/session_plugin_manager.py index c74546fe7..e2ebd11f0 100644 --- a/astrbot/core/star/session_plugin_manager.py +++ b/astrbot/core/star/session_plugin_manager.py @@ -42,87 +42,6 @@ class SessionPluginManager: # 如果都没有配置,默认为启用(兼容性考虑) return True - @staticmethod - def set_plugin_status_for_session( - session_id: str, - plugin_name: str, - enabled: bool, - ) -> None: - """设置插件在指定会话中的启停状态 - - Args: - session_id: 会话ID (unified_msg_origin) - plugin_name: 插件名称 - enabled: True表示启用,False表示禁用 - - """ - # 获取当前配置 - session_plugin_config = sp.get( - "session_plugin_config", - {}, - scope="umo", - scope_id=session_id, - ) - if session_id not in session_plugin_config: - session_plugin_config[session_id] = { - "enabled_plugins": [], - "disabled_plugins": [], - } - - session_config = session_plugin_config[session_id] - enabled_plugins = session_config.get("enabled_plugins", []) - disabled_plugins = session_config.get("disabled_plugins", []) - - if enabled: - # 启用插件 - if plugin_name in disabled_plugins: - disabled_plugins.remove(plugin_name) - if plugin_name not in enabled_plugins: - enabled_plugins.append(plugin_name) - else: - # 禁用插件 - if plugin_name in enabled_plugins: - enabled_plugins.remove(plugin_name) - if plugin_name not in disabled_plugins: - disabled_plugins.append(plugin_name) - - # 保存配置 - session_config["enabled_plugins"] = enabled_plugins - session_config["disabled_plugins"] = disabled_plugins - session_plugin_config[session_id] = session_config - sp.put( - "session_plugin_config", - session_plugin_config, - scope="umo", - scope_id=session_id, - ) - - logger.info( - f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}", - ) - - @staticmethod - def get_session_plugin_config(session_id: str) -> dict[str, list[str]]: - """获取指定会话的插件配置 - - Args: - session_id: 会话ID (unified_msg_origin) - - Returns: - Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典 - - """ - session_plugin_config = sp.get( - "session_plugin_config", - {}, - scope="umo", - scope_id=session_id, - ) - return session_plugin_config.get( - session_id, - {"enabled_plugins": [], "disabled_plugins": []}, - ) - @staticmethod def filter_handlers_by_session(event: AstrMessageEvent, handlers: list) -> list: """根据会话配置过滤处理器列表 diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index b692f23bb..050d5836c 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -60,10 +60,6 @@ class KnowledgeBaseRoute(Route): # "/kb/media/delete": ("POST", self.delete_media), # 检索 "/kb/retrieve": ("POST", self.retrieve), - # 会话知识库配置 - "/kb/session/config/get": ("GET", self.get_session_kb_config), - "/kb/session/config/set": ("POST", self.set_session_kb_config), - "/kb/session/config/delete": ("POST", self.delete_session_kb_config), } self.register_routes() @@ -920,158 +916,6 @@ class KnowledgeBaseRoute(Route): logger.error(traceback.format_exc()) return Response().error(f"检索失败: {e!s}").__dict__ - # ===== 会话知识库配置 API ===== - - async def get_session_kb_config(self): - """获取会话的知识库配置 - - Query 参数: - - session_id: 会话 ID (必填) - - 返回: - - kb_ids: 知识库 ID 列表 - - top_k: 返回结果数量 - - enable_rerank: 是否启用重排序 - """ - try: - from astrbot.core import sp - - session_id = request.args.get("session_id") - - if not session_id: - return Response().error("缺少参数 session_id").__dict__ - - # 从 SharedPreferences 获取配置 - config = await sp.session_get(session_id, "kb_config", default={}) - - logger.debug(f"[KB配置] 读取到配置: session_id={session_id}") - - # 如果没有配置,返回默认值 - if not config: - config = {"kb_ids": [], "top_k": 5, "enable_rerank": True} - - return Response().ok(config).__dict__ - - except Exception as e: - logger.error(f"[KB配置] 获取配置时出错: {e}", exc_info=True) - return Response().error(f"获取会话知识库配置失败: {e!s}").__dict__ - - async def set_session_kb_config(self): - """设置会话的知识库配置 - - Body: - - scope: 配置范围 (目前只支持 "session") - - scope_id: 会话 ID (必填) - - kb_ids: 知识库 ID 列表 (必填) - - top_k: 返回结果数量 (可选, 默认 5) - - enable_rerank: 是否启用重排序 (可选, 默认 true) - """ - try: - from astrbot.core import sp - - data = await request.json - - scope = data.get("scope") - scope_id = data.get("scope_id") - kb_ids = data.get("kb_ids", []) - top_k = data.get("top_k", 5) - enable_rerank = data.get("enable_rerank", True) - - # 验证参数 - if scope != "session": - return Response().error("目前仅支持 session 范围的配置").__dict__ - - if not scope_id: - return Response().error("缺少参数 scope_id").__dict__ - - if not isinstance(kb_ids, list): - return Response().error("kb_ids 必须是列表").__dict__ - - # 验证知识库是否存在 - kb_mgr = self._get_kb_manager() - invalid_ids = [] - valid_ids = [] - for kb_id in kb_ids: - kb_helper = await kb_mgr.get_kb(kb_id) - if kb_helper: - valid_ids.append(kb_id) - else: - invalid_ids.append(kb_id) - logger.warning(f"[KB配置] 知识库不存在: {kb_id}") - - if invalid_ids: - logger.warning(f"[KB配置] 以下知识库ID无效: {invalid_ids}") - - # 允许保存空列表,表示明确不使用任何知识库 - if kb_ids and not valid_ids: - # 只有当用户提供了 kb_ids 但全部无效时才报错 - return Response().error(f"所有提供的知识库ID都无效: {kb_ids}").__dict__ - - # 如果 kb_ids 为空列表,表示用户想清空配置 - if not kb_ids: - valid_ids = [] - - # 构建配置对象(只保存有效的ID) - config = { - "kb_ids": valid_ids, - "top_k": top_k, - "enable_rerank": enable_rerank, - } - - # 保存到 SharedPreferences - await sp.session_put(scope_id, "kb_config", config) - - # 立即验证是否保存成功 - verify_config = await sp.session_get(scope_id, "kb_config", default={}) - - if verify_config == config: - return ( - Response() - .ok( - {"valid_ids": valid_ids, "invalid_ids": invalid_ids}, - "保存知识库配置成功", - ) - .__dict__ - ) - logger.error("[KB配置] 配置保存失败,验证不匹配") - return Response().error("配置保存失败").__dict__ - - except Exception as e: - logger.error(f"[KB配置] 设置配置时出错: {e}", exc_info=True) - return Response().error(f"设置会话知识库配置失败: {e!s}").__dict__ - - async def delete_session_kb_config(self): - """删除会话的知识库配置 - - Body: - - scope: 配置范围 (目前只支持 "session") - - scope_id: 会话 ID (必填) - """ - try: - from astrbot.core import sp - - data = await request.json - - scope = data.get("scope") - scope_id = data.get("scope_id") - - # 验证参数 - if scope != "session": - return Response().error("目前仅支持 session 范围的配置").__dict__ - - if not scope_id: - return Response().error("缺少参数 scope_id").__dict__ - - # 从 SharedPreferences 删除配置 - await sp.session_remove(scope_id, "kb_config") - - return Response().ok(message="删除知识库配置成功").__dict__ - - except Exception as e: - logger.error(f"删除会话知识库配置失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"删除会话知识库配置失败: {e!s}").__dict__ - async def upload_document_from_url(self): """从 URL 上传文档 diff --git a/astrbot/dashboard/routes/session_management.py b/astrbot/dashboard/routes/session_management.py index 3e1136b79..a938d662d 100644 --- a/astrbot/dashboard/routes/session_management.py +++ b/astrbot/dashboard/routes/session_management.py @@ -74,7 +74,10 @@ class SessionManagementRoute(Route): umo_id = pref.scope_id if umo_id not in umo_rules: umo_rules[umo_id] = {} - umo_rules[umo_id][pref.key] = pref.value["val"] + if pref.key == "session_plugin_config" and umo_id in pref.value["val"]: + umo_rules[umo_id][pref.key] = pref.value["val"][umo_id] + else: + umo_rules[umo_id][pref.key] = pref.value["val"] # 搜索过滤 if search: @@ -185,6 +188,35 @@ class SessionManagementRoute(Route): for p in provider_manager.tts_provider_insts ] + # 获取可用的插件列表(排除 reserved 的系统插件) + plugin_manager = self.core_lifecycle.plugin_manager + available_plugins = [ + { + "name": p.name, + "display_name": p.display_name or p.name, + "desc": p.desc, + } + for p in plugin_manager.context.get_all_stars() + if not p.reserved and p.name + ] + + # 获取可用的知识库列表 + available_kbs = [] + kb_manager = self.core_lifecycle.kb_manager + if kb_manager: + try: + kbs = await kb_manager.list_kbs() + available_kbs = [ + { + "kb_id": kb.kb_id, + "kb_name": kb.kb_name, + "emoji": kb.emoji, + } + for kb in kbs + ] + except Exception as e: + logger.warning(f"获取知识库列表失败: {e!s}") + return ( Response() .ok( @@ -197,6 +229,8 @@ class SessionManagementRoute(Route): "available_chat_providers": available_chat_providers, "available_stt_providers": available_stt_providers, "available_tts_providers": available_tts_providers, + "available_plugins": available_plugins, + "available_kbs": available_kbs, "available_rule_keys": AVAILABLE_SESSION_RULE_KEYS, } ) @@ -229,6 +263,11 @@ class SessionManagementRoute(Route): if rule_key not in AVAILABLE_SESSION_RULE_KEYS: return Response().error(f"不支持的规则键: {rule_key}").__dict__ + if rule_key == "session_plugin_config": + rule_value = { + umo: rule_value, + } + # 使用 shared preferences 更新规则 await sp.session_put(umo, rule_key, rule_value) diff --git a/dashboard/src/i18n/locales/en-US/features/session-management.json b/dashboard/src/i18n/locales/en-US/features/session-management.json index b1fadf666..dc0d3c4a8 100644 --- a/dashboard/src/i18n/locales/en-US/features/session-management.json +++ b/dashboard/src/i18n/locales/en-US/features/session-management.json @@ -73,6 +73,17 @@ "title": "Persona Configuration", "selectPersona": "Select Persona", "hint": "Persona settings affect the conversation style and behavior of the LLM" + }, + "pluginConfig": { + "title": "Plugin Configuration", + "disabledPlugins": "Disabled Plugins", + "hint": "Select plugins to disable for this session. Unselected plugins will remain enabled." + }, + "kbConfig": { + "title": "Knowledge Base Configuration", + "selectKbs": "Select Knowledge Bases", + "topK": "Top K Results", + "enableRerank": "Enable Reranking" } }, "deleteConfirm": { diff --git a/dashboard/src/i18n/locales/zh-CN/features/session-management.json b/dashboard/src/i18n/locales/zh-CN/features/session-management.json index fb924a147..4b9053ebf 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/session-management.json +++ b/dashboard/src/i18n/locales/zh-CN/features/session-management.json @@ -73,6 +73,17 @@ "title": "人格配置", "selectPersona": "选择人格", "hint": "应用人格配置后,将会强制该来源的所有对话使用该人格。" + }, + "pluginConfig": { + "title": "插件配置", + "disabledPlugins": "禁用的插件", + "hint": "选择要在此会话中禁用的插件。未选择的插件将保持启用状态。" + }, + "kbConfig": { + "title": "知识库配置", + "selectKbs": "选择知识库", + "topK": "返回结果数量 (Top K)", + "enableRerank": "启用重排序" } }, "deleteConfirm": { diff --git a/dashboard/src/views/SessionManagementPage.vue b/dashboard/src/views/SessionManagementPage.vue index aea1af9fb..0e2b59555 100644 --- a/dashboard/src/views/SessionManagementPage.vue +++ b/dashboard/src/views/SessionManagementPage.vue @@ -143,11 +143,11 @@ - + {{ tm('ruleEditor.title') }} - + {{ selectedUmo.umo }} @@ -241,6 +241,59 @@ {{ tm('buttons.save') }} + + +
+

{{ tm('ruleEditor.pluginConfig.title') }}

+
+ + + + + + + + {{ tm('ruleEditor.pluginConfig.hint') }} + + + + +
+ + {{ tm('buttons.save') }} + +
+ + +
+

{{ tm('ruleEditor.kbConfig.title') }}

+
+ + + + + + + + + + + + + +
+ + {{ tm('buttons.save') }} + +
@@ -347,6 +400,8 @@ export default { availableChatProviders: [], availableSttProviders: [], availableTtsProviders: [], + availablePlugins: [], + availableKbs: [], // 添加规则 addRuleDialog: false, @@ -374,6 +429,19 @@ export default { text_to_speech: null, }, + // 插件配置 + pluginConfig: { + enabled_plugins: [], + disabled_plugins: [], + }, + + // 知识库配置 + kbConfig: { + kb_ids: [], + top_k: 5, + enable_rerank: true, + }, + // 删除确认 deleteDialog: false, deleteTarget: null, @@ -447,6 +515,20 @@ export default { })) ] }, + + pluginOptions() { + return this.availablePlugins.map(p => ({ + label: p.display_name || p.name, + value: p.name + })) + }, + + kbOptions() { + return this.availableKbs.map(kb => ({ + label: `${kb.emoji || '📚'} ${kb.kb_name}`, + value: kb.kb_id + })) + }, }, watch: { @@ -492,6 +574,8 @@ export default { this.availableChatProviders = data.available_chat_providers this.availableSttProviders = data.available_stt_providers this.availableTtsProviders = data.available_tts_providers + this.availablePlugins = data.available_plugins || [] + this.availableKbs = data.available_kbs || [] } else { this.showError(response.data.message || this.tm('messages.loadError')) } @@ -589,6 +673,21 @@ export default { text_to_speech: this.editingRules['provider_perf_text_to_speech'] || null, } + // 初始化插件配置 + const pluginCfg = this.editingRules.session_plugin_config || {} + this.pluginConfig = { + enabled_plugins: pluginCfg.enabled_plugins || [], + disabled_plugins: pluginCfg.disabled_plugins || [], + } + + // 初始化知识库配置 + const kbCfg = this.editingRules.kb_config || {} + this.kbConfig = { + kb_ids: kbCfg.kb_ids || [], + top_k: kbCfg.top_k ?? 5, + enable_rerank: kbCfg.enable_rerank !== false, + } + this.ruleDialog = true }, @@ -708,6 +807,117 @@ export default { this.saving = false }, + async savePluginConfig() { + if (!this.selectedUmo) return + + this.saving = true + try { + const config = { + enabled_plugins: this.pluginConfig.enabled_plugins, + disabled_plugins: this.pluginConfig.disabled_plugins, + } + + // 如果两个列表都为空,删除配置 + if (config.enabled_plugins.length === 0 && config.disabled_plugins.length === 0) { + if (this.editingRules.session_plugin_config) { + await axios.post('/api/session/delete-rule', { + umo: this.selectedUmo.umo, + rule_key: 'session_plugin_config' + }) + delete this.editingRules.session_plugin_config + let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo) + if (item) delete item.rules.session_plugin_config + } + this.showSuccess(this.tm('messages.saveSuccess')) + } else { + const response = await axios.post('/api/session/update-rule', { + umo: this.selectedUmo.umo, + rule_key: 'session_plugin_config', + rule_value: config + }) + + if (response.data.status === 'ok') { + this.showSuccess(this.tm('messages.saveSuccess')) + this.editingRules.session_plugin_config = config + + let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo) + if (item) { + item.rules.session_plugin_config = config + } else { + this.rulesList.push({ + umo: this.selectedUmo.umo, + platform: this.selectedUmo.platform, + message_type: this.selectedUmo.message_type, + session_id: this.selectedUmo.session_id, + rules: { session_plugin_config: config } + }) + } + } else { + this.showError(response.data.message || this.tm('messages.saveError')) + } + } + } catch (error) { + this.showError(error.response?.data?.message || this.tm('messages.saveError')) + } + this.saving = false + }, + + async saveKbConfig() { + if (!this.selectedUmo) return + + this.saving = true + try { + const config = { + kb_ids: this.kbConfig.kb_ids, + top_k: this.kbConfig.top_k, + enable_rerank: this.kbConfig.enable_rerank, + } + + // 如果 kb_ids 为空,删除配置 + if (config.kb_ids.length === 0) { + if (this.editingRules.kb_config) { + await axios.post('/api/session/delete-rule', { + umo: this.selectedUmo.umo, + rule_key: 'kb_config' + }) + delete this.editingRules.kb_config + let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo) + if (item) delete item.rules.kb_config + } + this.showSuccess(this.tm('messages.saveSuccess')) + } else { + const response = await axios.post('/api/session/update-rule', { + umo: this.selectedUmo.umo, + rule_key: 'kb_config', + rule_value: config + }) + + if (response.data.status === 'ok') { + this.showSuccess(this.tm('messages.saveSuccess')) + this.editingRules.kb_config = config + + let item = this.rulesList.find(u => u.umo === this.selectedUmo.umo) + if (item) { + item.rules.kb_config = config + } else { + this.rulesList.push({ + umo: this.selectedUmo.umo, + platform: this.selectedUmo.platform, + message_type: this.selectedUmo.message_type, + session_id: this.selectedUmo.session_id, + rules: { kb_config: config } + }) + } + } else { + this.showError(response.data.message || this.tm('messages.saveError')) + } + } + } catch (error) { + this.showError(error.response?.data?.message || this.tm('messages.saveError')) + } + this.saving = false + }, + confirmDeleteRules(item) { this.deleteTarget = item this.deleteDialog = true