feat: 集成知识库管理器,优化知识库上下文注入流程,移除冗余代码
This commit is contained in:
@@ -111,6 +111,11 @@ class AstrBotCoreLifecycle:
|
||||
# 初始化平台消息历史管理器
|
||||
self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
|
||||
|
||||
# 初始化知识库管理器
|
||||
self.kb_manager = KnowledgeBaseManager(
|
||||
self.astrbot_config, self.db, self.provider_manager
|
||||
)
|
||||
|
||||
# 初始化提供给插件的上下文
|
||||
self.star_context = Context(
|
||||
self.event_queue,
|
||||
@@ -122,6 +127,7 @@ class AstrBotCoreLifecycle:
|
||||
self.platform_message_history_manager,
|
||||
self.persona_mgr,
|
||||
self.astrbot_config_mgr,
|
||||
self.kb_manager,
|
||||
)
|
||||
|
||||
# 初始化插件管理器
|
||||
@@ -133,21 +139,13 @@ class AstrBotCoreLifecycle:
|
||||
# 根据配置实例化各个 Provider
|
||||
await self.provider_manager.initialize()
|
||||
|
||||
# 初始化知识库管理器
|
||||
self.kb_manager = KnowledgeBaseManager(
|
||||
self.astrbot_config, self.db, self.provider_manager
|
||||
)
|
||||
await self.kb_manager.initialize()
|
||||
|
||||
# 将知识库注入器添加到 star_context 中,供 Pipeline 使用
|
||||
self.star_context.kb_injector = self.kb_manager.get_kb_injector()
|
||||
|
||||
# 注册知识库会话生命周期钩子(零侵入级联清理)
|
||||
if self.kb_manager.is_initialized:
|
||||
self.kb_manager.register_session_lifecycle_hooks(self.conversation_manager)
|
||||
|
||||
# 初始化消息事件流水线调度器
|
||||
|
||||
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
|
||||
|
||||
# 初始化更新器
|
||||
|
||||
@@ -94,30 +94,6 @@ class KnowledgeBaseInjector:
|
||||
"results": results_dict,
|
||||
}
|
||||
|
||||
async def inject(
|
||||
self,
|
||||
session_id: str,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
) -> Optional[str]:
|
||||
"""注入知识库上下文 (简化版本,仅返回文本)
|
||||
|
||||
Args:
|
||||
session_id: 会话 ID (来自主数据库)
|
||||
query: 用户查询
|
||||
top_k: 返回结果数量
|
||||
|
||||
Returns:
|
||||
Optional[str]: 格式化的知识上下文,如果无结果则返回 None
|
||||
"""
|
||||
result = await self.retrieve_and_inject(
|
||||
unified_msg_origin=session_id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
return result["context_text"] if result else None
|
||||
|
||||
def _format_context(self, results: List[RetrievalResult]) -> str:
|
||||
"""格式化知识上下文
|
||||
|
||||
|
||||
@@ -137,6 +137,7 @@ class RetrievalManager:
|
||||
query=query,
|
||||
results=retrieval_results,
|
||||
top_k=top_m_final,
|
||||
rerank_provider=self.rerank_provider,
|
||||
)
|
||||
else:
|
||||
retrieval_results = retrieval_results[:top_m_final]
|
||||
@@ -162,7 +163,7 @@ class RetrievalManager:
|
||||
# 直接调用向量数据库检索
|
||||
vec_results = await self.vec_db.retrieve(
|
||||
query=query,
|
||||
k=top_k * len(kb_ids) * 2, # 增加候选数量以便过滤
|
||||
top_k=top_k * len(kb_ids) * 2, # 增加候选数量以便过滤
|
||||
)
|
||||
|
||||
# 过滤:只保留指定知识库的结果
|
||||
@@ -187,6 +188,7 @@ class RetrievalManager:
|
||||
query: str,
|
||||
results: List[RetrievalResult],
|
||||
top_k: int,
|
||||
rerank_provider: RerankProvider,
|
||||
) -> List[RetrievalResult]:
|
||||
"""Rerank 重排序
|
||||
|
||||
@@ -205,7 +207,7 @@ class RetrievalManager:
|
||||
docs = [r.content for r in results]
|
||||
|
||||
# 调用 Rerank Provider
|
||||
rerank_results = await self.rerank_provider.rerank(
|
||||
rerank_results = await rerank_provider.rerank(
|
||||
query=query,
|
||||
documents=docs,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,6 @@ from astrbot.core.message.message_event_result import (
|
||||
|
||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||
from .preprocess_stage.stage import PreProcessStage
|
||||
from .kb_enhance.stage import KBEnhanceStage
|
||||
from .process_stage.stage import ProcessStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .respond.stage import RespondStage
|
||||
@@ -22,7 +21,6 @@ STAGES_ORDER = [
|
||||
"RateLimitStage", # 检查会话是否超过频率限制
|
||||
"ContentSafetyCheckStage", # 检查内容安全
|
||||
"PreProcessStage", # 预处理
|
||||
"KBEnhanceStage", # 知识库增强
|
||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
||||
"RespondStage", # 发送消息
|
||||
@@ -35,7 +33,6 @@ __all__ = [
|
||||
"RateLimitStage",
|
||||
"ContentSafetyCheckStage",
|
||||
"PreProcessStage",
|
||||
"KBEnhanceStage",
|
||||
"ProcessStage",
|
||||
"ResultDecorateStage",
|
||||
"RespondStage",
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
"""
|
||||
知识库增强阶段
|
||||
在 LLM 调用之前,根据会话配置注入知识库上下文
|
||||
"""
|
||||
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@register_stage
|
||||
class KBEnhanceStage(Stage):
|
||||
"""知识库增强阶段
|
||||
|
||||
功能:
|
||||
- 检查会话是否配置了知识库
|
||||
- 如果配置了知识库,则检索相关知识并注入到事件上下文中
|
||||
- 供后续的 ProcessStage 使用
|
||||
"""
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.config = ctx.astrbot_config
|
||||
self.kb_config = self.config.get("knowledge_base", {})
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
"""处理知识库上下文注入"""
|
||||
|
||||
# 检查知识库功能是否启用
|
||||
if not self.kb_config.get("enabled", False):
|
||||
return
|
||||
|
||||
# 检查是否需要调用知识库 (只有在被@或唤醒时才检索)
|
||||
if not event.is_at_or_wake_command:
|
||||
return
|
||||
|
||||
try:
|
||||
# 从 plugin_manager.context 获取 kb_injector
|
||||
kb_injector = getattr(self.ctx.plugin_manager.context, "kb_injector", None)
|
||||
|
||||
if not kb_injector:
|
||||
logger.debug("知识库注入器未初始化,跳过知识库增强")
|
||||
return
|
||||
|
||||
# 获取会话 ID
|
||||
unified_msg_origin = event.unified_msg_origin
|
||||
|
||||
# 获取用户查询
|
||||
query = event.message_str
|
||||
|
||||
# 检索并注入知识
|
||||
kb_context = await kb_injector.retrieve_and_inject(
|
||||
unified_msg_origin=unified_msg_origin,
|
||||
query=query,
|
||||
)
|
||||
|
||||
if kb_context:
|
||||
# 将知识库上下文存储到事件的 extra 中
|
||||
event.set_extra("kb_context", kb_context)
|
||||
logger.debug(
|
||||
f"知识库上下文已注入,检索到 {len(kb_context.get('results', []))} 条相关知识"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"知识库增强阶段处理失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
@@ -33,6 +33,7 @@ from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from ...context import PipelineContext, call_event_hook, call_handler
|
||||
from ..stage import Stage
|
||||
from ..utils import inject_kb_context
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
@@ -416,6 +417,14 @@ class LLMRequestSubStage(Stage):
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
# 应用知识库
|
||||
try:
|
||||
await inject_kb_context(
|
||||
umo=event.unified_msg_origin, p_ctx=self.ctx, req=req
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"调用知识库时遇到问题: {e}")
|
||||
|
||||
# 执行请求 LLM 前事件钩子。
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.api import logger
|
||||
|
||||
|
||||
async def inject_kb_context(
|
||||
umo: str,
|
||||
p_ctx: PipelineContext,
|
||||
req: ProviderRequest,
|
||||
top_k: int = 5,
|
||||
) -> None:
|
||||
"""inject knowledge base context into the provider request
|
||||
|
||||
Args:
|
||||
p_ctx: Pipeline context
|
||||
req: Provider request
|
||||
"""
|
||||
kb_injector = p_ctx.plugin_manager.context.kb_manager.get_kb_injector()
|
||||
if not kb_injector:
|
||||
return
|
||||
kb_context = await kb_injector.retrieve_and_inject(
|
||||
unified_msg_origin=umo,
|
||||
query=req.prompt,
|
||||
top_k=top_k,
|
||||
)
|
||||
if not kb_context:
|
||||
return
|
||||
formatted = kb_context.get("context_text", "") if kb_context else ""
|
||||
if formatted:
|
||||
results = kb_context.get("results", [])
|
||||
logger.debug(f"知识库上下文注入: 为请求注入了 {len(results)} 条相关知识块")
|
||||
req.system_prompt = f"{formatted}\n\n{req.system_prompt or ''}"
|
||||
@@ -19,6 +19,7 @@ from astrbot.core.platform import Platform
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.knowledge_base.kb_manager_lifecycle import KnowledgeBaseManager
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from .star import star_registry, StarMetadata, star_map
|
||||
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
@@ -55,6 +56,7 @@ class Context:
|
||||
message_history_manager: PlatformMessageHistoryManager,
|
||||
persona_manager: PersonaManager,
|
||||
astrbot_config_mgr: AstrBotConfigManager,
|
||||
knowledge_base_manager: KnowledgeBaseManager,
|
||||
):
|
||||
self._event_queue = event_queue
|
||||
"""事件队列。消息平台通过事件队列传递消息事件。"""
|
||||
@@ -68,6 +70,7 @@ class Context:
|
||||
self.message_history_manager = message_history_manager
|
||||
self.persona_manager = persona_manager
|
||||
self.astrbot_config_mgr = astrbot_config_mgr
|
||||
self.kb_manager = knowledge_base_manager
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata | None:
|
||||
"""根据插件名获取插件的 Metadata"""
|
||||
|
||||
Reference in New Issue
Block a user