140 lines
4.2 KiB
Python
140 lines
4.2 KiB
Python
"""Knowledge base query tool and retrieval logic.
|
|
|
|
Extracted from ``astr_main_agent_resources.py`` to its own module.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
from pydantic import Field
|
|
from pydantic.dataclasses import dataclass
|
|
|
|
from astrbot.api import logger, sp
|
|
from astrbot.core.agent.run_context import ContextWrapper
|
|
from astrbot.core.agent.tool import FunctionTool, ToolExecResult
|
|
from astrbot.core.astr_agent_context import AstrAgentContext
|
|
|
|
if TYPE_CHECKING:
|
|
from astrbot.core.star.context import Context
|
|
|
|
|
|
def _kb_query_parameters() -> dict:
    """Build a fresh JSON-schema dict describing the tool's single argument."""
    query_property = {
        "type": "string",
        "description": "A concise keyword query for the knowledge base.",
    }
    return {
        "type": "object",
        "properties": {"query": query_property},
        "required": ["query"],
    }


@dataclass
class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]):
    """Function tool that looks up a query in the knowledge base.

    The tool is exposed under the name ``astr_kb_search`` and accepts one
    required string argument, ``query``.
    """

    name: str = "astr_kb_search"
    description: str = (
        "Query the knowledge base for facts or relevant context. "
        "Use this tool when the user's question requires factual information, "
        "definitions, background knowledge, or previously indexed content. "
        "Only send short keywords or a concise question as the query."
    )
    # A factory is used so that every instance gets its own schema dict.
    parameters: dict = Field(default_factory=_kb_query_parameters)

    async def call(
        self, context: ContextWrapper[AstrAgentContext], **kwargs
    ) -> ToolExecResult:
        """Run a knowledge-base lookup for the ``query`` keyword argument.

        Args:
            context: Wrapper around the agent context; the wrapped object
                supplies the event (for ``unified_msg_origin``) and the
                context object handed to ``retrieve_knowledge_base``.
            **kwargs: Tool arguments; only ``query`` is read.

        Returns:
            An error message when ``query`` is missing or falsy, the text
            returned by ``retrieve_knowledge_base`` when it is truthy, and a
            "no relevant knowledge" message otherwise.
        """
        search_text = kwargs.get("query", "")
        if not search_text:
            return "error: Query parameter is empty."

        agent_ctx = context.context
        knowledge = await retrieve_knowledge_base(
            query=search_text,
            umo=agent_ctx.event.unified_msg_origin,
            context=agent_ctx.context,
        )
        # A falsy result (None or empty text) is reported as "nothing found".
        return knowledge or "No relevant knowledge found."
|
|
|
|
|
|
async def retrieve_knowledge_base(
    query: str,
    umo: str,
    context: Context,
) -> str | None:
    """Retrieve formatted knowledge-base text that matches ``query``.

    The knowledge bases to search are chosen in this order:

    1. The session-level ``kb_config`` preference, when it exists and has a
       ``kb_ids`` entry. An empty ``kb_ids`` list disables retrieval for the
       session. IDs that cannot be resolved are logged and skipped.
    2. Otherwise the ``kb_names`` / ``kb_final_top_k`` values of the config
       returned by ``context.get_config(umo=umo)``.

    The fusion top-k (``kb_fusion_top_k``) is always read from that config.

    Args:
        query: The search query string.
        umo: Session identifier used for the config and session-preference
            lookups (the tool in this module passes the event's
            ``unified_msg_origin``).
        context: Object providing ``kb_manager`` and ``get_config``.

    Returns:
        The ``context_text`` of the retrieval result, or ``None`` when no
        knowledge base is selected, retrieval yields nothing, or the text is
        empty.
    """
    manager = context.kb_manager
    global_cfg = context.get_config(umo=umo)

    # 1. Prefer session-level config
    session_cfg = await sp.session_get(umo, "kb_config", default={})

    if session_cfg and "kb_ids" in session_cfg:
        configured_ids = session_cfg.get("kb_ids", [])

        # An explicitly empty list means "no knowledge base for this session".
        if not configured_ids:
            logger.info(f"[知识库] 会话 {umo} 已被配置为不使用知识库")
            return None

        top_k = session_cfg.get("top_k", 5)

        # Translate the configured IDs into names, collecting the ones that
        # cannot be resolved so they can be reported together afterwards.
        kb_names = []
        invalid_kb_ids = []
        for kb_id in configured_ids:
            helper = await manager.get_kb(kb_id)
            if not helper:
                logger.warning(f"[知识库] 知识库不存在或未加载: {kb_id}")
                invalid_kb_ids.append(kb_id)
                continue
            kb_names.append(helper.kb.kb_name)

        if invalid_kb_ids:
            logger.warning(
                f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}",
            )

        if not kb_names:
            return None

        logger.debug(f"[知识库] 使用会话级配置,知识库数量: {len(kb_names)}")
    else:
        # 2. Fall back to the global config
        kb_names = global_cfg.get("kb_names", [])
        top_k = global_cfg.get("kb_final_top_k", 5)
        logger.debug(f"[知识库] 使用全局配置,知识库数量: {len(kb_names)}")

    fusion_top_k = global_cfg.get("kb_fusion_top_k", 20)

    if not kb_names:
        return None

    logger.debug(f"[知识库] 开始检索知识库,数量: {len(kb_names)}, top_k={top_k}")
    retrieval = await manager.retrieve(
        query=query,
        kb_names=kb_names,
        top_k_fusion=fusion_top_k,
        top_m_final=top_k,
    )

    if not retrieval:
        return None

    context_text = retrieval.get("context_text", "")
    if not context_text:
        return None

    results = retrieval.get("results", [])
    logger.debug(f"[知识库] 为会话 {umo} 注入了 {len(results)} 条相关知识块")
    return context_text
|
|
|
|
|
|
# Module-level instance of the knowledge-base query tool, created with its
# default field values; get_all_tools() returns this same object.
KNOWLEDGE_BASE_QUERY_TOOL = KnowledgeBaseQueryTool()
|
|
|
|
|
|
def get_all_tools() -> list[FunctionTool]:
    """Collect the knowledge-base tools this module offers for registration.

    A new list is built on every call; its single element is the
    module-level ``KNOWLEDGE_BASE_QUERY_TOOL`` instance.
    """
    tools: list[FunctionTool] = [KNOWLEDGE_BASE_QUERY_TOOL]
    return tools
|