feat: 集成知识库管理器,优化知识库上下文注入流程,移除冗余代码

This commit is contained in:
Soulter
2025-10-22 21:59:00 +08:00
parent a05868cc45
commit 65bc5efa19
8 changed files with 54 additions and 109 deletions
+6 -8
View File
@@ -111,6 +111,11 @@ class AstrBotCoreLifecycle:
# 初始化平台消息历史管理器
self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
# 初始化知识库管理器
self.kb_manager = KnowledgeBaseManager(
self.astrbot_config, self.db, self.provider_manager
)
# 初始化提供给插件的上下文
self.star_context = Context(
self.event_queue,
@@ -122,6 +127,7 @@ class AstrBotCoreLifecycle:
self.platform_message_history_manager,
self.persona_mgr,
self.astrbot_config_mgr,
self.kb_manager,
)
# 初始化插件管理器
@@ -133,21 +139,13 @@ class AstrBotCoreLifecycle:
# 根据配置实例化各个 Provider
await self.provider_manager.initialize()
# 初始化知识库管理器
self.kb_manager = KnowledgeBaseManager(
self.astrbot_config, self.db, self.provider_manager
)
await self.kb_manager.initialize()
# 将知识库注入器添加到 star_context 中,供 Pipeline 使用
self.star_context.kb_injector = self.kb_manager.get_kb_injector()
# 注册知识库会话生命周期钩子(零侵入级联清理)
if self.kb_manager.is_initialized:
self.kb_manager.register_session_lifecycle_hooks(self.conversation_manager)
# 初始化消息事件流水线调度器
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
# 初始化更新器
-24
View File
@@ -94,30 +94,6 @@ class KnowledgeBaseInjector:
"results": results_dict,
}
async def inject(
self,
session_id: str,
query: str,
top_k: int = 5,
) -> Optional[str]:
"""注入知识库上下文 (简化版本,仅返回文本)
Args:
session_id: 会话 ID (来自主数据库)
query: 用户查询
top_k: 返回结果数量
Returns:
Optional[str]: 格式化的知识上下文,如果无结果则返回 None
"""
result = await self.retrieve_and_inject(
unified_msg_origin=session_id,
query=query,
top_k=top_k,
)
return result["context_text"] if result else None
def _format_context(self, results: List[RetrievalResult]) -> str:
"""格式化知识上下文
@@ -137,6 +137,7 @@ class RetrievalManager:
query=query,
results=retrieval_results,
top_k=top_m_final,
rerank_provider=self.rerank_provider,
)
else:
retrieval_results = retrieval_results[:top_m_final]
@@ -162,7 +163,7 @@ class RetrievalManager:
# 直接调用向量数据库检索
vec_results = await self.vec_db.retrieve(
query=query,
k=top_k * len(kb_ids) * 2, # 增加候选数量以便过滤
top_k=top_k * len(kb_ids) * 2, # 增加候选数量以便过滤
)
# 过滤:只保留指定知识库的结果
@@ -187,6 +188,7 @@ class RetrievalManager:
query: str,
results: List[RetrievalResult],
top_k: int,
rerank_provider: RerankProvider,
) -> List[RetrievalResult]:
"""Rerank 重排序
@@ -205,7 +207,7 @@ class RetrievalManager:
docs = [r.content for r in results]
# 调用 Rerank Provider
rerank_results = await self.rerank_provider.rerank(
rerank_results = await rerank_provider.rerank(
query=query,
documents=docs,
)
-3
View File
@@ -5,7 +5,6 @@ from astrbot.core.message.message_event_result import (
from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage
from .kb_enhance.stage import KBEnhanceStage
from .process_stage.stage import ProcessStage
from .rate_limit_check.stage import RateLimitStage
from .respond.stage import RespondStage
@@ -22,7 +21,6 @@ STAGES_ORDER = [
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"KBEnhanceStage", # 知识库增强
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
"RespondStage", # 发送消息
@@ -35,7 +33,6 @@ __all__ = [
"RateLimitStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"KBEnhanceStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",
-72
View File
@@ -1,72 +0,0 @@
"""
知识库增强阶段
在 LLM 调用之前,根据会话配置注入知识库上下文
"""
from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core import logger
@register_stage
class KBEnhanceStage(Stage):
"""知识库增强阶段
功能:
- 检查会话是否配置了知识库
- 如果配置了知识库,则检索相关知识并注入到事件上下文中
- 供后续的 ProcessStage 使用
"""
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.config = ctx.astrbot_config
self.kb_config = self.config.get("knowledge_base", {})
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""处理知识库上下文注入"""
# 检查知识库功能是否启用
if not self.kb_config.get("enabled", False):
return
# 检查是否需要调用知识库 (只有在被@或唤醒时才检索)
if not event.is_at_or_wake_command:
return
try:
# 从 plugin_manager.context 获取 kb_injector
kb_injector = getattr(self.ctx.plugin_manager.context, "kb_injector", None)
if not kb_injector:
logger.debug("知识库注入器未初始化,跳过知识库增强")
return
# 获取会话 ID
unified_msg_origin = event.unified_msg_origin
# 获取用户查询
query = event.message_str
# 检索并注入知识
kb_context = await kb_injector.retrieve_and_inject(
unified_msg_origin=unified_msg_origin,
query=query,
)
if kb_context:
# 将知识库上下文存储到事件的 extra 中
event.set_extra("kb_context", kb_context)
logger.debug(
f"知识库上下文已注入,检索到 {len(kb_context.get('results', []))} 条相关知识"
)
except Exception as e:
logger.error(f"知识库增强阶段处理失败: {e}")
import traceback
logger.error(traceback.format_exc())
@@ -33,6 +33,7 @@ from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric
from ...context import PipelineContext, call_event_hook, call_handler
from ..stage import Stage
from ..utils import inject_kb_context
from astrbot.core.provider.register import llm_tools
from astrbot.core.star.star_handler import star_map
from astrbot.core.astr_agent_context import AstrAgentContext
@@ -416,6 +417,14 @@ class LLMRequestSubStage(Stage):
if not req.prompt and not req.image_urls:
return
# 应用知识库
try:
await inject_kb_context(
umo=event.unified_msg_origin, p_ctx=self.ctx, req=req
)
except Exception as e:
logger.error(f"调用知识库时遇到问题: {e}")
# 执行请求 LLM 前事件钩子。
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
@@ -0,0 +1,32 @@
from ..context import PipelineContext
from astrbot.core.provider.entities import ProviderRequest
from astrbot.api import logger
async def inject_kb_context(
umo: str,
p_ctx: PipelineContext,
req: ProviderRequest,
top_k: int = 5,
) -> None:
"""inject knowledge base context into the provider request
Args:
p_ctx: Pipeline context
req: Provider request
"""
kb_injector = p_ctx.plugin_manager.context.kb_manager.get_kb_injector()
if not kb_injector:
return
kb_context = await kb_injector.retrieve_and_inject(
unified_msg_origin=umo,
query=req.prompt,
top_k=top_k,
)
if not kb_context:
return
formatted = kb_context.get("context_text", "") if kb_context else ""
if formatted:
results = kb_context.get("results", [])
logger.debug(f"知识库上下文注入: 为请求注入了 {len(results)} 条相关知识块")
req.system_prompt = f"{formatted}\n\n{req.system_prompt or ''}"
+3
View File
@@ -19,6 +19,7 @@ from astrbot.core.platform import Platform
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.knowledge_base.kb_manager_lifecycle import KnowledgeBaseManager
from astrbot.core.persona_mgr import PersonaManager
from .star import star_registry, StarMetadata, star_map
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
@@ -55,6 +56,7 @@ class Context:
message_history_manager: PlatformMessageHistoryManager,
persona_manager: PersonaManager,
astrbot_config_mgr: AstrBotConfigManager,
knowledge_base_manager: KnowledgeBaseManager,
):
self._event_queue = event_queue
"""事件队列。消息平台通过事件队列传递消息事件。"""
@@ -68,6 +70,7 @@ class Context:
self.message_history_manager = message_history_manager
self.persona_manager = persona_manager
self.astrbot_config_mgr = astrbot_config_mgr
self.kb_manager = knowledge_base_manager
def get_registered_star(self, star_name: str) -> StarMetadata | None:
"""根据插件名获取插件的 Metadata"""