From c7a58252feba0dff9dc5e1e6b334c1965bdb9db6 Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Mon, 17 Nov 2025 17:29:18 +0800 Subject: [PATCH] feat: supports knowledge base agentic search (#3667) * feat: supports knowledge base agentic search * fix: correct formatting of system prompt in knowledge base results --- astrbot/core/config/default.py | 7 ++ .../process_stage/method/llm_request.py | 42 +++++++---- astrbot/core/pipeline/process_stage/utils.py | 74 +++++++++++++++---- 3 files changed, 94 insertions(+), 29 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 5df165dbf..528611d2d 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -137,6 +137,7 @@ DEFAULT_CONFIG = { "kb_names": [], # 默认知识库名称列表 "kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量 "kb_final_top_k": 5, # 知识库检索最终返回结果数量 + "kb_agentic_mode": False, } @@ -2146,6 +2147,7 @@ CONFIG_METADATA_2 = { "kb_names": {"type": "list", "items": {"type": "string"}}, "kb_fusion_top_k": {"type": "int", "default": 20}, "kb_final_top_k": {"type": "int", "default": 5}, + "kb_agentic_mode": {"type": "bool"}, }, }, } @@ -2241,6 +2243,11 @@ CONFIG_METADATA_3 = { "type": "int", "hint": "从知识库中检索到的结果数量,越大可能获得越多相关信息,但也可能引入噪音。建议根据实际需求调整", }, + "kb_agentic_mode": { + "description": "Agentic 知识库检索", + "type": "bool", + "hint": "启用后,知识库检索将作为 LLM Tool,由模型自主决定何时调用知识库进行查询。需要模型支持函数调用能力。", + }, }, }, "websearch": { diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 4ef10a428..bd9e4ce3b 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -32,7 +32,7 @@ from ....astr_agent_run_util import AgentRunner, run_agent from ....astr_agent_tool_exec import FunctionToolExecutor from ...context import PipelineContext, call_event_hook from ..stage import Stage -from ..utils import inject_kb_context +from ..utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base class LLMRequestSubStage(Stage): @@ -57,6 +57,7 @@ class LLMRequestSubStage(Stage): self.max_step = 30 self.show_tool_use: bool = settings.get("show_tool_use_status", True) self.show_reasoning = settings.get("display_reasoning_text", False) + self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False) for bwp in self.bot_wake_prefixs: if self.provider_wake_prefix.startswith(bwp): @@ -95,20 +96,33 @@ class LLMRequestSubStage(Stage): raise RuntimeError("无法创建新的对话。") return conversation - async def _apply_kb_context( + async def _apply_kb( self, event: AstrMessageEvent, req: ProviderRequest, ): - """应用知识库上下文到请求中""" - try: - await inject_kb_context( - umo=event.unified_msg_origin, - p_ctx=self.ctx, - req=req, - ) - except Exception as e: - logger.error(f"调用知识库时遇到问题: {e}") + """Apply knowledge base context to the provider request""" + if not self.kb_agentic_mode: + if req.prompt is None: + return + try: + kb_result = await retrieve_knowledge_base( + query=req.prompt, + umo=event.unified_msg_origin, + context=self.ctx.plugin_manager.context, + ) + if not kb_result: + return + if req.system_prompt is not None: + req.system_prompt += ( + f"\n\n[Related Knowledge Base Results]:\n{kb_result}" + ) + except Exception as e: + logger.error(f"Error occurred while retrieving knowledge base: {e}") + else: + if req.func_tool is None: + req.func_tool = ToolSet() + req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL) def _truncate_contexts( self, @@ -356,13 +370,13 @@ class LLMRequestSubStage(Stage): if not req.prompt and not req.image_urls: return - # apply knowledge base context - await self._apply_kb_context(event, req) - # call event hook if await call_event_hook(event, EventType.OnLLMRequestEvent, req): return + # apply knowledge base feature + await self._apply_kb(event, req) + # fix contexts json str if isinstance(req.contexts, str): req.contexts = json.loads(req.contexts) diff --git a/astrbot/core/pipeline/process_stage/utils.py b/astrbot/core/pipeline/process_stage/utils.py index b1168aa0a..24e052e1e 100644 --- a/astrbot/core/pipeline/process_stage/utils.py +++ b/astrbot/core/pipeline/process_stage/utils.py @@ -1,23 +1,64 @@ +from pydantic import Field +from pydantic.dataclasses import dataclass + from astrbot.api import logger, sp -from astrbot.core.provider.entities import ProviderRequest - -from ..context import PipelineContext +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.tool import FunctionTool, ToolExecResult +from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.star.context import Context -async def inject_kb_context( +@dataclass +class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]): + name: str = "astr_kb_search" + description: str = ( + "Query the knowledge base for facts or relevant context. " + "Use this tool when the user's question requires factual information, " + "definitions, background knowledge, or previously indexed content. " + "Only send short keywords or a concise question as the query." + ) + parameters: dict = Field( + default_factory=lambda: { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "A concise keyword query for the knowledge base.", + }, + }, + "required": ["query"], + } + ) + + async def call( + self, context: ContextWrapper[AstrAgentContext], **kwargs + ) -> ToolExecResult: + query = kwargs.get("query", "") + if not query: + return "error: Query parameter is empty." + result = await retrieve_knowledge_base( + query=kwargs.get("query", ""), + umo=context.context.event.unified_msg_origin, + context=context.context.context, + ) + if not result: + return "No relevant knowledge found." + return result + + +async def retrieve_knowledge_base( + query: str, umo: str, - p_ctx: PipelineContext, - req: ProviderRequest, -) -> None: + context: Context, +) -> str | None: """Inject knowledge base context into the provider request Args: umo: Unique message object (session ID) p_ctx: Pipeline context - req: Provider request - """ - kb_mgr = p_ctx.plugin_manager.context.kb_manager + kb_mgr = context.kb_manager + config = context.get_config(umo=umo) # 1. 优先读取会话级配置 session_config = await sp.session_get(umo, "kb_config", default={}) @@ -54,18 +95,18 @@ async def inject_kb_context( logger.debug(f"[知识库] 使用会话级配置,知识库数量: {len(kb_names)}") else: - kb_names = p_ctx.astrbot_config.get("kb_names", []) - top_k = p_ctx.astrbot_config.get("kb_final_top_k", 5) + kb_names = config.get("kb_names", []) + top_k = config.get("kb_final_top_k", 5) logger.debug(f"[知识库] 使用全局配置,知识库数量: {len(kb_names)}") - top_k_fusion = p_ctx.astrbot_config.get("kb_fusion_top_k", 20) + top_k_fusion = config.get("kb_fusion_top_k", 20) if not kb_names: return logger.debug(f"[知识库] 开始检索知识库,数量: {len(kb_names)}, top_k={top_k}") kb_context = await kb_mgr.retrieve( - query=req.prompt, + query=query, kb_names=kb_names, top_k_fusion=top_k_fusion, top_m_final=top_k, @@ -78,4 +119,7 @@ async def inject_kb_context( if formatted: results = kb_context.get("results", []) logger.debug(f"[知识库] 为会话 {umo} 注入了 {len(results)} 条相关知识块") - req.system_prompt = f"{formatted}\n\n{req.system_prompt or ''}" + return formatted + + +KNOWLEDGE_BASE_QUERY_TOOL = KnowledgeBaseQueryTool()