From 98a75e923dd4e08c4d800ee202dc01137744f6df Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Sun, 19 Oct 2025 18:41:34 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=9B=86=E6=88=90=E7=9F=A5=E8=AF=86?= =?UTF-8?q?=E5=BA=93=E5=88=B0=E6=A0=B8=E5=BF=83=E7=94=9F=E5=91=BD=E5=91=A8?= =?UTF-8?q?=E6=9C=9F=E5=92=8C=E6=B6=88=E6=81=AF=E6=B5=81=E6=B0=B4=E7=BA=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 AstrBotCoreLifecycle 中初始化知识库管理器 - 将知识库注入器添加到消息处理上下文 - 在消息流水线中添加 KBEnhanceStage(知识库增强阶段) - 实现会话删除时的知识库配置级联清理机制 - 添加会话管理器的回调注册机制,支持零侵入扩展 --- astrbot/core/conversation_mgr.py | 37 +++++++++++- astrbot/core/core_lifecycle.py | 16 +++++ astrbot/core/pipeline/__init__.py | 3 + astrbot/core/pipeline/kb_enhance/stage.py | 72 +++++++++++++++++++++++ 4 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 astrbot/core/pipeline/kb_enhance/stage.py diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index a6a2710f1..8f8e2e0e9 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -7,7 +7,7 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json import json from astrbot.core import sp -from typing import Dict, List +from typing import Dict, List, Callable, Awaitable from astrbot.core.db import BaseDatabase from astrbot.core.db.po import Conversation, ConversationV2 @@ -20,6 +20,38 @@ class ConversationManager: self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 + # 会话删除回调函数列表(用于级联清理,如知识库配置) + self._on_session_deleted_callbacks: List[Callable[[str], Awaitable[None]]] = [] + + def register_on_session_deleted( + self, callback: Callable[[str], Awaitable[None]] + ) -> None: + """注册会话删除回调函数 + + 其他模块可以注册回调来响应会话删除事件,实现级联清理。 + 例如:知识库模块可以注册回调来清理会话的知识库配置。 + + Args: + callback: 回调函数,接收会话ID (unified_msg_origin) 作为参数 + """ + self._on_session_deleted_callbacks.append(callback) + + async def _trigger_session_deleted(self, unified_msg_origin: str) -> None: + """触发会话删除回调 + + Args: + unified_msg_origin: 会话ID + """ + for callback in self._on_session_deleted_callbacks: + try: + await callback(unified_msg_origin) + except Exception as e: + from astrbot.core import logger + + logger.error( + f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}" + ) + def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation: """将 ConversationV2 对象转换为 Conversation 对象""" created_at = int(conv_v2.created_at.timestamp()) @@ -106,6 +138,9 @@ class ConversationManager: self.session_conversations.pop(unified_msg_origin, None) await sp.session_remove(unified_msg_origin, "sel_conv_id") + # 触发会话删除回调(级联清理) + await self._trigger_session_deleted(unified_msg_origin) + async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None: """获取会话当前的对话 ID diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 972a5f4f1..612500ef2 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -34,6 +34,7 @@ from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryMana from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star_handler import star_map +from astrbot.core.knowledge_base.kb_manager_lifecycle import KnowledgeBaseManager class AstrBotCoreLifecycle: @@ -132,6 +133,19 @@ class AstrBotCoreLifecycle: # 根据配置实例化各个 Provider await self.provider_manager.initialize() + # 初始化知识库管理器 + self.kb_manager = KnowledgeBaseManager( + self.astrbot_config, self.db, self.provider_manager + ) + await self.kb_manager.initialize() + + # 将知识库注入器添加到 star_context 中,供 Pipeline 使用 + self.star_context.kb_injector = self.kb_manager.get_kb_injector() + + # 注册知识库会话生命周期钩子(零侵入级联清理) + if self.kb_manager.is_initialized: + self.kb_manager.register_session_lifecycle_hooks(self.conversation_manager) + # 初始化消息事件流水线调度器 self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler() @@ -233,6 +247,7 @@ class AstrBotCoreLifecycle: await self.provider_manager.terminate() await self.platform_manager.terminate() + await self.kb_manager.terminate() self.dashboard_shutdown_event.set() # 再次遍历curr_tasks等待每个任务真正结束 @@ -248,6 +263,7 @@ class AstrBotCoreLifecycle: """重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例""" await self.provider_manager.terminate() await self.platform_manager.terminate() + await self.kb_manager.terminate() self.dashboard_shutdown_event.set() threading.Thread( target=self.astrbot_updator._reboot, name="restart", daemon=True diff --git a/astrbot/core/pipeline/__init__.py b/astrbot/core/pipeline/__init__.py index 29a324a1d..fecde7f71 100644 --- a/astrbot/core/pipeline/__init__.py +++ b/astrbot/core/pipeline/__init__.py @@ -5,6 +5,7 @@ from astrbot.core.message.message_event_result import ( from .content_safety_check.stage import ContentSafetyCheckStage from .preprocess_stage.stage import PreProcessStage +from .kb_enhance.stage import KBEnhanceStage from .process_stage.stage import ProcessStage from .rate_limit_check.stage import RateLimitStage from .respond.stage import RespondStage @@ -21,6 +22,7 @@ STAGES_ORDER = [ "RateLimitStage", # 检查会话是否超过频率限制 "ContentSafetyCheckStage", # 检查内容安全 "PreProcessStage", # 预处理 + "KBEnhanceStage", # 知识库增强 "ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用 "ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等 "RespondStage", # 发送消息 @@ -33,6 +35,7 @@ __all__ = [ "RateLimitStage", "ContentSafetyCheckStage", "PreProcessStage", + "KBEnhanceStage", "ProcessStage", "ResultDecorateStage", "RespondStage", diff --git a/astrbot/core/pipeline/kb_enhance/stage.py b/astrbot/core/pipeline/kb_enhance/stage.py new file mode 100644 index 000000000..c8441158e --- /dev/null +++ b/astrbot/core/pipeline/kb_enhance/stage.py @@ -0,0 +1,72 @@ +""" +知识库增强阶段 +在 LLM 调用之前,根据会话配置注入知识库上下文 +""" + +from typing import Union, AsyncGenerator +from ..stage import Stage, register_stage +from ..context import PipelineContext +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core import logger + + +@register_stage +class KBEnhanceStage(Stage): + """知识库增强阶段 + + 功能: + - 检查会话是否配置了知识库 + - 如果配置了知识库,则检索相关知识并注入到事件上下文中 + - 供后续的 ProcessStage 使用 + """ + + async def initialize(self, ctx: PipelineContext) -> None: + self.ctx = ctx + self.config = ctx.astrbot_config + self.kb_config = self.config.get("knowledge_base", {}) + + async def process( + self, event: AstrMessageEvent + ) -> Union[None, AsyncGenerator[None, None]]: + """处理知识库上下文注入""" + + # 检查知识库功能是否启用 + if not self.kb_config.get("enabled", False): + return + + # 检查是否需要调用知识库 (只有在被@或唤醒时才检索) + if not event.is_at_or_wake_command: + return + + try: + # 从 plugin_manager.context 获取 kb_injector + kb_injector = getattr(self.ctx.plugin_manager.context, "kb_injector", None) + + if not kb_injector: + logger.debug("知识库注入器未初始化,跳过知识库增强") + return + + # 获取会话 ID + unified_msg_origin = event.unified_msg_origin + + # 获取用户查询 + query = event.message_str + + # 检索并注入知识 + kb_context = await kb_injector.retrieve_and_inject( + unified_msg_origin=unified_msg_origin, + query=query, + ) + + if kb_context: + # 将知识库上下文存储到事件的 extra 中 + event.set_extra("kb_context", kb_context) + logger.debug( + f"知识库上下文已注入,检索到 {len(kb_context.get('results', []))} 条相关知识" + ) + + except Exception as e: + logger.error(f"知识库增强阶段处理失败: {e}") + import traceback + + logger.error(traceback.format_exc())