"""Knowledge base query tool and retrieval logic. Extracted from ``astr_main_agent_resources.py`` to its own module. """ from __future__ import annotations from typing import TYPE_CHECKING from pydantic import Field from pydantic.dataclasses import dataclass from astrbot.api import logger, sp from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.tool import FunctionTool, ToolExecResult from astrbot.core.astr_agent_context import AstrAgentContext if TYPE_CHECKING: from astrbot.core.star.context import Context @dataclass class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]): name: str = "astr_kb_search" description: str = ( "Query the knowledge base for facts or relevant context. " "Use this tool when the user's question requires factual information, " "definitions, background knowledge, or previously indexed content. " "Only send short keywords or a concise question as the query." ) parameters: dict = Field( default_factory=lambda: { "type": "object", "properties": { "query": { "type": "string", "description": "A concise keyword query for the knowledge base.", }, }, "required": ["query"], } ) async def call( self, context: ContextWrapper[AstrAgentContext], **kwargs ) -> ToolExecResult: query = kwargs.get("query", "") if not query: return "error: Query parameter is empty." result = await retrieve_knowledge_base( query=kwargs.get("query", ""), umo=context.context.event.unified_msg_origin, context=context.context.context, ) if not result: return "No relevant knowledge found." return result async def retrieve_knowledge_base( query: str, umo: str, context: Context, ) -> str | None: """Inject knowledge base context into the provider request Args: query: The search query string umo: Unique message object (session ID) context: Star context """ kb_mgr = context.kb_manager config = context.get_config(umo=umo) # 1. Prefer session-level config session_config = await sp.session_get(umo, "kb_config", default={}) if session_config and "kb_ids" in session_config: kb_ids = session_config.get("kb_ids", []) if not kb_ids: logger.info(f"[知识库] 会话 {umo} 已被配置为不使用知识库") return top_k = session_config.get("top_k", 5) kb_names = [] invalid_kb_ids = [] for kb_id in kb_ids: kb_helper = await kb_mgr.get_kb(kb_id) if kb_helper: kb_names.append(kb_helper.kb.kb_name) else: logger.warning(f"[知识库] 知识库不存在或未加载: {kb_id}") invalid_kb_ids.append(kb_id) if invalid_kb_ids: logger.warning( f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}", ) if not kb_names: return logger.debug(f"[知识库] 使用会话级配置,知识库数量: {len(kb_names)}") else: kb_names = config.get("kb_names", []) top_k = config.get("kb_final_top_k", 5) logger.debug(f"[知识库] 使用全局配置,知识库数量: {len(kb_names)}") top_k_fusion = config.get("kb_fusion_top_k", 20) if not kb_names: return logger.debug(f"[知识库] 开始检索知识库,数量: {len(kb_names)}, top_k={top_k}") kb_context = await kb_mgr.retrieve( query=query, kb_names=kb_names, top_k_fusion=top_k_fusion, top_m_final=top_k, ) if not kb_context: return formatted = kb_context.get("context_text", "") if formatted: results = kb_context.get("results", []) logger.debug(f"[知识库] 为会话 {umo} 注入了 {len(results)} 条相关知识块") return formatted KNOWLEDGE_BASE_QUERY_TOOL = KnowledgeBaseQueryTool() def get_all_tools() -> list[FunctionTool]: """Return all knowledge-base tools for registration.""" return [KNOWLEDGE_BASE_QUERY_TOOL]