class BaseChunker(ABC):
    """Abstract base class for document chunkers.

    Every concrete chunker should inherit from this class and implement
    the ``chunk`` coroutine.
    """

    @abstractmethod
    async def chunk(self, text: str) -> list[str]:
        """Split a text into chunks.

        Args:
            text: The input text.

        Returns:
            list[str]: The chunks produced from the text.
        """
async def chunk(self, text: str) -> list[str]:
    """Split text into fixed-size character windows with optional overlap.

    Consecutive windows start ``chunk_size - chunk_overlap`` characters
    apart. When the configured overlap is not strictly between 0 and
    ``chunk_size`` (zero, negative, or at least as large as the window),
    the windows are laid out back to back without overlap.

    Args:
        text: The input text.

    Returns:
        list[str]: The chunks in document order; empty for empty input.

    Raises:
        ValueError: If ``chunk_size`` is not positive. Such a window can
            never advance past the text, so the previous implementation
            looped forever instead of failing.
    """
    if not text:
        return []

    size = self.chunk_size
    if size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {size!r}")

    overlap = self.chunk_overlap
    # An unusable overlap must not stall or reverse the window, so it
    # degrades to "no overlap" (stride == window size).
    stride = size - overlap if 0 < overlap < size else size

    return [text[start : start + size] for start in range(0, len(text), stride)]
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
    """Return one page of knowledge bases, newest first.

    Args:
        offset: Number of rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        The selected knowledge bases ordered by ``created_at`` descending.
    """
    newest_first = KnowledgeBase.created_at.desc()
    page_query = (
        select(KnowledgeBase).order_by(newest_first).offset(offset).limit(limit)
    )
    async with self.db.get_db() as session:
        rows = await session.execute(page_query)
        return list(rows.scalars().all())
async def get_chunk_with_metadata(self, chunk_id: str) -> Optional[dict]:
    """Fetch a chunk together with its document and knowledge base.

    Args:
        chunk_id: Identifier of the chunk to look up.

    Returns:
        A dict with the keys ``chunk``, ``document`` and
        ``knowledge_base``, or ``None`` when the joined query yields no
        row.
    """
    joined_query = select(KBChunk, KBDocument, KnowledgeBase)
    joined_query = joined_query.join(KBDocument, KBChunk.doc_id == KBDocument.doc_id)
    joined_query = joined_query.join(
        KnowledgeBase, KBChunk.kb_id == KnowledgeBase.kb_id
    )
    joined_query = joined_query.where(KBChunk.chunk_id == chunk_id)

    async with self.db.get_db() as session:
        row = (await session.execute(joined_query)).first()
        if not row:
            return None

        # The row carries the three entities in the order they were selected.
        return dict(zip(("chunk", "document", "knowledge_base"), row))
async def get_session_kb_ids(self, session_id: str) -> list[str]:
    """Resolve the knowledge-base IDs bound to a session.

    Lookup order:
        1. A ``session``-scoped config whose scope id is the session id.
        2. A ``platform``-scoped config whose scope id is the text before
           the first ``:`` of the session id (only tried when the session
           id contains a ``:``).
        3. Otherwise an empty list.

    Args:
        session_id: Session ID (originating from the main database).

    Returns:
        The JSON-decoded ``kb_ids`` of the first config found, else ``[]``.
    """
    lookups = [("session", session_id)]

    # Session ids look like "platform:xxx:session_id"; the leading
    # segment names the platform.
    platform_id, separator, _ = session_id.partition(":")
    if separator:
        lookups.append(("platform", platform_id))

    async with self.db.get_db() as session:
        for scope, scope_id in lookups:
            stmt = select(KBSessionConfig).where(
                KBSessionConfig.scope == scope,
                KBSessionConfig.scope_id == scope_id,
            )
            config = (await session.execute(stmt)).scalar_one_or_none()
            if config:
                return json.loads(config.kb_ids)

        # No configuration at either level.
        return []
async def retrieve_and_inject(
    self,
    unified_msg_origin: str,
    query: str,
    top_k: int = 5,
) -> Optional[dict]:
    """Retrieve knowledge for a session and package it for prompt injection.

    Args:
        unified_msg_origin: Unified message origin ID (the session ID).
        query: The user's query.
        top_k: Number of results requested from the retrieval manager.

    Returns:
        ``None`` when the session has no knowledge bases bound or the
        retrieval returns nothing; otherwise a dict of the form::

            {
                "context_text": str,   # formatted context text
                "results": list[dict], # one plain dict per retrieval hit
            }
    """
    # Which knowledge bases is this session allowed to search?
    bound_kb_ids = await self.kb_db.get_session_kb_ids(unified_msg_origin)
    if not bound_kb_ids:
        return None

    hits = await self.retrieval_manager.retrieve(
        query=query,
        kb_ids=bound_kb_ids,
        top_m_final=top_k,
    )
    if not hits:
        return None

    formatted = self._format_context(hits)

    # Flatten every hit into a plain dict for callers.
    flattened = []
    for hit in hits:
        flattened.append(
            {
                "chunk_id": hit.chunk_id,
                "doc_id": hit.doc_id,
                "kb_id": hit.kb_id,
                "kb_name": hit.kb_name,
                "doc_name": hit.doc_name,
                "chunk_index": hit.metadata.get("chunk_index", 0),
                "content": hit.content,
                "score": hit.score,
            }
        )

    return {"context_text": formatted, "results": flattened}
async def initialize(self):
    """Initialize the knowledge-base module.

    Does nothing when the ``enabled`` config flag is off or when no
    embedding provider is available. Otherwise builds, in order: the
    standalone database and its query wrapper, the vector database, the
    parsers and chunker, the KB manager, the retrieval manager and the
    context injector, then marks the module as initialized.

    Failures are logged rather than raised. When a step fails after the
    database or vector store was already opened, those handles are closed
    and cleared: ``terminate()`` skips uninitialized modules, so without
    this they would leak and a later retry would overwrite them.
    """
    if not self.config.get("enabled", False):
        logger.info("知识库功能未启用")
        return

    try:
        logger.info("正在初始化知识库模块...")

        # 1. Pick the embedding provider; nothing works without one.
        embedding_provider = self._select_embedding_provider()
        if not embedding_provider:
            logger.warning("未配置 Embedding Provider,知识库功能无法使用")
            return

        # 2. Standalone database and its query wrapper.
        await self._init_kb_database()
        await self._init_database()

        # 3. Vector database.
        await self._init_vector_db(embedding_provider)

        # 4. Parsers and chunker.
        parsers = self._init_parsers()
        chunker = self._init_chunker()

        # 5. Knowledge-base manager.
        await self._init_kb_manager(parsers, chunker)

        # 6. Retrieval manager.
        await self._init_retrieval_manager()

        # 7. Context injector.
        await self._init_injector()

        self._initialized = True
        logger.info("知识库模块初始化完成")

    except ImportError as e:
        logger.error(f"知识库模块导入失败: {e}")
        logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
        await self._release_partial_init()
    except Exception as e:
        logger.error(f"知识库模块初始化失败: {e}")
        import traceback

        logger.error(traceback.format_exc())
        await self._release_partial_init()


async def _release_partial_init(self):
    """Close and clear whatever a failed ``initialize()`` had already created."""
    opened = (
        ("向量数据库", self.kb_vec_db),
        ("知识库数据库", self.kb_db),
    )
    for label, resource in opened:
        if resource is None:
            continue
        try:
            await resource.close()
        except Exception as e:
            # Best effort: a failing close must not mask the original error.
            logger.warning(f"关闭{label}时出错: {e}")

    self._initialized = False
    self.kb_db = None
    self.kb_database = None
    self.kb_manager = None
    self.kb_vec_db = None
    self.retrieval_manager = None
    self.kb_injector = None
chunk_size=chunking_config.get("chunk_size", 512), + chunk_overlap=chunking_config.get("chunk_overlap", 50), + ) + + async def _init_kb_manager(self, parsers: dict, chunker): + """初始化知识库管理器""" + from astrbot.core.knowledge_base.manager import KBManager + + files_path = self.config.get("storage", {}).get( + "files_path", "data/knowledge_base" + ) + + self.kb_manager = KBManager( + db=self.kb_db, # 使用独立的知识库数据库 + vec_db=self.kb_vec_db, + storage_path=files_path, + parsers=parsers, + chunker=chunker, + ) + + async def _init_retrieval_manager(self): + """初始化检索管理器""" + from astrbot.core.knowledge_base.retrieval.manager import RetrievalManager + from astrbot.core.knowledge_base.retrieval.sparse_retriever import ( + SparseRetriever, + ) + from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion + + sparse_retriever = SparseRetriever(self.kb_database) + rank_fusion = RankFusion(self.kb_database) + + # 选择 Rerank Provider (可选) + rerank_provider = self._select_rerank_provider() + + self.retrieval_manager = RetrievalManager( + vec_db=self.kb_vec_db, + sparse_retriever=sparse_retriever, + rank_fusion=rank_fusion, + kb_db=self.kb_database, + rerank_provider=rerank_provider, + ) + + async def _init_injector(self): + """初始化上下文注入器""" + from astrbot.core.knowledge_base.injector import KnowledgeBaseInjector + + self.kb_injector = KnowledgeBaseInjector( + kb_db=self.kb_database, + retrieval_manager=self.retrieval_manager, + ) + + def _select_embedding_provider(self): + """选择 Embedding Provider + + 逻辑: + - 如果配置了 embedding_provider_id,则使用指定的 provider + - 如果没有配置,但有 embedding provider,则使用第一个 + - 如果有多个 embedding provider 但没有指定,则警告并使用第一个 + """ + embedding_providers = self.provider_manager.embedding_provider_insts + + if not embedding_providers: + return None + + configured_provider_id = self.config.get("embedding_provider_id") + + if configured_provider_id: + # 按 ID 查找 + for provider in embedding_providers: + provider_id = provider.meta().id + if provider_id == 
configured_provider_id: + logger.info(f"知识库使用 Embedding Provider: {provider_id}") + return provider + logger.warning( + f"未找到配置的 Embedding Provider ID: {configured_provider_id}," + f"将使用第一个可用的" + ) + + if len(embedding_providers) > 1 and not configured_provider_id: + logger.warning( + f"检测到 {len(embedding_providers)} 个 Embedding Provider," + f"但未指定使用哪个,将默认使用第一个" + ) + + provider = embedding_providers[0] + provider_id = provider.meta().id + logger.info(f"知识库使用 Embedding Provider: {provider_id}") + return provider + + def _select_rerank_provider(self): + """选择 Rerank Provider (可选)""" + if not self.config.get("retrieval", {}).get("enable_rerank", True): + return None + + rerank_providers = self.provider_manager.rerank_provider_insts + if not rerank_providers: + return None + + configured_provider_id = self.config.get("rerank_provider_id") + + if configured_provider_id: + for provider in rerank_providers: + provider_id = provider.meta().id + if provider_id == configured_provider_id: + logger.info(f"知识库使用 Rerank Provider: {provider_id}") + return provider + logger.warning(f"未找到配置的 Rerank Provider ID: {configured_provider_id}") + + if len(rerank_providers) > 0: + provider = rerank_providers[0] + provider_id = provider.meta().id + logger.info(f"知识库使用 Rerank Provider: {provider_id}") + return provider + + return None + + @property + def is_initialized(self) -> bool: + """检查是否已初始化""" + return self._initialized + + def get_kb_manager(self): + """获取知识库管理器""" + return self.kb_manager if self._initialized else None + + def get_kb_injector(self): + """获取知识库上下文注入器""" + return self.kb_injector if self._initialized else None + + def register_session_lifecycle_hooks(self, conversation_manager): + """注册会话生命周期钩子 + + 在会话删除时自动清理知识库配置,实现零侵入的级联清理。 + + Args: + conversation_manager: 会话管理器实例 + """ + if self._session_deleted_callback_registered or not self._initialized: + return + + async def on_session_deleted(session_id: str): + """会话删除回调:清理知识库配置""" + try: + await 
async def terminate(self):
    """Shut the knowledge-base module down and drop its components.

    Returns immediately when the module is not initialized. Otherwise it
    closes the vector database and then the standalone knowledge-base
    database, logging (not raising) any error from ``close()``, and
    finally resets every component reference to ``None``.
    """
    if not self._initialized:
        return

    logger.info("正在终止知识库模块...")

    # (resource, message after a clean close, prefix for a failed close)
    closables = (
        (self.kb_vec_db, "向量数据库已关闭", "关闭向量数据库时出错"),
        (self.kb_db, "知识库数据库已关闭", "关闭知识库数据库时出错"),
    )
    for resource, closed_msg, failure_prefix in closables:
        if not resource:
            continue
        try:
            await resource.close()
            logger.debug(closed_msg)
        except Exception as e:
            logger.warning(f"{failure_prefix}: {e}")

    # Drop every reference so stale components cannot be used.
    self._initialized = False
    for attr in (
        "kb_db",
        "kb_database",
        "kb_manager",
        "kb_vec_db",
        "retrieval_manager",
        "kb_injector",
    ):
        setattr(self, attr, None)

    logger.info("知识库模块已终止")
async def migrate_to_v1(self) -> None:
    """Run the v1 migration of the knowledge-base database.

    Creates the lookup indexes on the knowledge-base, document, chunk,
    media and session-config tables inside a single transaction. Every
    statement is ``CREATE INDEX IF NOT EXISTS``.
    """
    index_statements = (
        # knowledge_bases
        "CREATE INDEX IF NOT EXISTS idx_kb_kb_id ON knowledge_bases(kb_id)",
        "CREATE INDEX IF NOT EXISTS idx_kb_name ON knowledge_bases(kb_name)",
        "CREATE INDEX IF NOT EXISTS idx_kb_created_at ON knowledge_bases(created_at)",
        # kb_documents
        "CREATE INDEX IF NOT EXISTS idx_doc_doc_id ON kb_documents(doc_id)",
        "CREATE INDEX IF NOT EXISTS idx_doc_kb_id ON kb_documents(kb_id)",
        "CREATE INDEX IF NOT EXISTS idx_doc_name ON kb_documents(doc_name)",
        "CREATE INDEX IF NOT EXISTS idx_doc_type ON kb_documents(file_type)",
        "CREATE INDEX IF NOT EXISTS idx_doc_created_at ON kb_documents(created_at)",
        # kb_chunks
        "CREATE INDEX IF NOT EXISTS idx_chunk_chunk_id ON kb_chunks(chunk_id)",
        "CREATE INDEX IF NOT EXISTS idx_chunk_doc_id ON kb_chunks(doc_id)",
        "CREATE INDEX IF NOT EXISTS idx_chunk_kb_id ON kb_chunks(kb_id)",
        "CREATE INDEX IF NOT EXISTS idx_chunk_vec_doc_id ON kb_chunks(vec_doc_id)",
        # kb_media
        "CREATE INDEX IF NOT EXISTS idx_media_media_id ON kb_media(media_id)",
        "CREATE INDEX IF NOT EXISTS idx_media_doc_id ON kb_media(doc_id)",
        "CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)",
        "CREATE INDEX IF NOT EXISTS idx_media_type ON kb_media(media_type)",
        # kb_session_config
        "CREATE INDEX IF NOT EXISTS idx_session_config_scope_id ON kb_session_config(scope_id)",
        "CREATE INDEX IF NOT EXISTS idx_session_config_scope ON kb_session_config(scope)",
    )

    async with self.get_db() as session:
        async with session.begin():
            for statement in index_statements:
                await session.execute(text(statement))

            await session.commit()

    logger.info("知识库数据库迁移 v1 完成")
async def update_kb(
    self,
    kb_id: str,
    kb_name: Optional[str] = None,
    description: Optional[str] = None,
    emoji: Optional[str] = None,
    embedding_provider_id: Optional[str] = None,
    rerank_provider_id: Optional[str] = None,
    chunk_size: Optional[int] = None,
    chunk_overlap: Optional[int] = None,
    top_k_dense: Optional[int] = None,
    top_k_sparse: Optional[int] = None,
    top_m_final: Optional[int] = None,
    enable_rerank: Optional[bool] = None,
) -> Optional[KnowledgeBase]:
    """Update selected fields of a knowledge base.

    Only arguments that are not ``None`` are written; every other field
    keeps its stored value.

    Args:
        kb_id: Identifier of the knowledge base to update.
        kb_name, description, emoji, embedding_provider_id,
        rerank_provider_id, chunk_size, chunk_overlap, top_k_dense,
        top_k_sparse, top_m_final, enable_rerank: New value for the
            attribute of the same name, or ``None`` to leave it untouched.

    Returns:
        The refreshed knowledge base, or ``None`` when ``kb_id`` matches
        no record.
    """
    # Attribute name -> requested value, in the order they are applied.
    requested = {
        "kb_name": kb_name,
        "description": description,
        "emoji": emoji,
        "embedding_provider_id": embedding_provider_id,
        "rerank_provider_id": rerank_provider_id,
        "chunk_size": chunk_size,
        "chunk_overlap": chunk_overlap,
        "top_k_dense": top_k_dense,
        "top_k_sparse": top_k_sparse,
        "top_m_final": top_m_final,
        "enable_rerank": enable_rerank,
    }

    async with self.db.get_db() as session:
        lookup = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
        kb = (await session.execute(lookup)).scalar_one_or_none()
        if not kb:
            return None

        for field_name, new_value in requested.items():
            if new_value is not None:
                setattr(kb, field_name, new_value)

        await session.commit()
        await session.refresh(kb)
        return kb
保存原始文件 + file_path = self.files_path / kb_id / f"{doc_id}.{file_type}" + file_path.parent.mkdir(parents=True, exist_ok=True) + + async with aiofiles.open(file_path, "wb") as f: + await f.write(file_content) + + # 2. 解析文档 + parser = self.parsers.get(file_type) + if not parser: + raise ValueError(f"不支持的文件类型: {file_type}") + + parse_result = await parser.parse(file_content, file_name) + text_content = parse_result.text + media_items = parse_result.media + + # 3. 保存多媒体资源 + from astrbot.core.knowledge_base.manager_ops import KBManagerOps + + ops = KBManagerOps(self) + saved_media = [] + for media_item in media_items: + media = await ops._save_media( + kb_id=kb_id, + doc_id=doc_id, + media_type=media_item.media_type, + file_name=media_item.file_name, + content=media_item.content, + mime_type=media_item.mime_type, + ) + saved_media.append(media) + media_paths.append(Path(media.file_path)) + + # 4. 文档分块 + chunks_text = await self.chunker.chunk(text_content) + + # 5. 生成向量并存储 + saved_chunks = [] + for idx, chunk_text in enumerate(chunks_text): + # 存储到向量数据库 + vec_doc_id = await self.vec_db.insert( + content=chunk_text, + metadata={ + "kb_id": kb_id, + "doc_id": doc_id, + "chunk_index": idx, + }, + ) + vec_doc_ids.append(str(vec_doc_id)) + + # 保存块元数据 + chunk = KBChunk( + doc_id=doc_id, + kb_id=kb_id, + chunk_index=idx, + content=chunk_text, + char_count=len(chunk_text), + vec_doc_id=str(vec_doc_id), + ) + saved_chunks.append(chunk) + + # 6. 保存文档元数据(事务) + doc = KBDocument( + doc_id=doc_id, + kb_id=kb_id, + doc_name=file_name, + file_type=file_type, + file_size=len(file_content), + file_path=str(file_path), + chunk_count=len(saved_chunks), + media_count=len(saved_media), + ) + + async with self.db.get_db() as session: + async with session.begin(): + session.add(doc) + for chunk in saved_chunks: + session.add(chunk) + for media in saved_media: + session.add(media) + await session.commit() + + await session.refresh(doc) + + # 7. 
更新知识库统计 + await self._update_kb_stats(kb_id) + + return doc + + except Exception as e: + # 失败清理:删除已创建的资源 + from astrbot.core import logger + + logger.error(f"文档上传失败,开始清理资源: {e}") + + # 清理向量数据库 + for vec_id in vec_doc_ids: + try: + await self.vec_db.delete(vec_id) + except Exception as ve: + logger.warning(f"清理向量失败 {vec_id}: {ve}") + + # 清理多媒体文件 + for media_path in media_paths: + try: + if media_path.exists(): + media_path.unlink() + except Exception as me: + logger.warning(f"清理多媒体文件失败 {media_path}: {me}") + + # 清理文档文件 + if file_path and file_path.exists(): + try: + file_path.unlink() + except Exception as fe: + logger.warning(f"清理文档文件失败 {file_path}: {fe}") + + # 重新抛出原始异常 + raise + + # ===== 统计更新 ===== + + async def _update_kb_stats(self, kb_id: str): + """更新知识库统计信息(事务中执行)""" + async with self.db.get_db() as session: + async with session.begin(): + # 统计文档数(在事务中查询) + doc_count = await session.scalar( + select(func.count(KBDocument.id)).where(KBDocument.kb_id == kb_id) + ) or 0 + + # 统计块数(在事务中查询) + chunk_count = await session.scalar( + select(func.count(KBChunk.id)).where(KBChunk.kb_id == kb_id) + ) or 0 + + # 更新知识库(在同一事务中) + await session.execute( + update(KnowledgeBase) + .where(KnowledgeBase.kb_id == kb_id) + .values(doc_count=doc_count, chunk_count=chunk_count) + ) + + await session.commit() diff --git a/astrbot/core/knowledge_base/manager_ops.py b/astrbot/core/knowledge_base/manager_ops.py new file mode 100644 index 000000000..521d3de50 --- /dev/null +++ b/astrbot/core/knowledge_base/manager_ops.py @@ -0,0 +1,306 @@ +"""知识库管理器辅助操作 + +该模块提供文档、块和多媒体的管理操作。 +""" + +import uuid +from pathlib import Path +from typing import TYPE_CHECKING + +import aiofiles +from sqlalchemy import delete, func, select + +from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KBMedia + +if TYPE_CHECKING: + from astrbot.core.knowledge_base.manager import KBManager + + +class KBManagerOps: + """知识库管理器辅助操作类 + + 职责: + - 文档管理操作 + - 块管理操作 + - 多媒体管理操作 + """ + + def __init__(self, 
async def delete_document(self, doc_id: str) -> bool:
    """Delete a document together with its chunks, media and vectors.

    Three phases:
        1. Delete vectors from the vector store (partial failure tolerated).
        2. Delete the SQL rows in one transaction.
        3. Delete files from disk (failures are logged and ignored).

    Returns:
        ``True`` when the document was deleted. ``False`` when it does not
        exist, or when more than half of the vector deletions failed and
        the operation was aborted before touching the SQL rows.
    """
    from astrbot.core import logger

    # 0. Load the document and everything that belongs to it.
    doc = await self.get_document(doc_id)
    if not doc:
        return False

    chunks = await self.list_chunks(doc_id)
    media_list = await self.list_media(doc_id)

    # ===== Phase 1: vectors =====
    vec_ids_to_delete = [chunk.vec_doc_id for chunk in chunks]
    failed_vec_ids = []

    for vec_id in vec_ids_to_delete:
        try:
            await self.vec_db.delete(vec_id)
        except Exception as e:
            logger.error(f"删除向量失败: {vec_id}, {e}")
            failed_vec_ids.append(vec_id)

    # Abort when more than 50% of the vector deletions failed.
    if len(failed_vec_ids) > len(vec_ids_to_delete) * 0.5:
        logger.error(
            f"向量删除失败过多 ({len(failed_vec_ids)}/{len(vec_ids_to_delete)}), 中止文档删除"
        )
        return False

    # A minority of failures is logged but does not stop the deletion.
    if failed_vec_ids:
        logger.warning(
            f"部分向量删除失败 ({len(failed_vec_ids)}/{len(vec_ids_to_delete)}), 但继续执行删除操作"
        )

    # ===== Phase 2: SQL rows =====
    async with self.db.get_db() as session:
        # session.begin() commits on clean exit and rolls back on error.
        async with session.begin():
            await session.execute(delete(KBChunk).where(KBChunk.doc_id == doc_id))
            await session.execute(delete(KBMedia).where(KBMedia.doc_id == doc_id))
            await session.execute(
                delete(KBDocument).where(KBDocument.doc_id == doc_id)
            )

    # ===== Phase 3: files =====
    for media in media_list:
        try:
            media_path = Path(media.file_path)
            if media_path.exists():
                media_path.unlink()
        except Exception as e:
            logger.warning(f"删除多媒体文件失败: {media.file_path}, {e}")

    try:
        file_path = Path(doc.file_path)
        if file_path.exists():
            file_path.unlink()
    except Exception as e:
        logger.warning(f"删除文档文件失败: {doc.file_path}, {e}")

    # ===== Refresh knowledge-base statistics =====
    await self.manager._update_kb_stats(doc.kb_id)

    return True
async def delete_chunk(self, chunk_id: str) -> bool:
    """Delete a single chunk.

    Steps:
        1. Look up the chunk
        2. Delete its vector
        3. Delete the chunk row
        4. Refresh document and knowledge-base statistics

    Returns:
        ``True`` on success; ``False`` when the chunk does not exist or its
        vector could not be deleted (the row is kept in that case).
    """
    from astrbot.core import logger

    # 1. Look up the chunk and copy what is needed after the session closes.
    async with self.db.get_db() as session:
        stmt = select(KBChunk).where(KBChunk.chunk_id == chunk_id)
        result = await session.execute(stmt)
        chunk = result.scalar_one_or_none()
        if not chunk:
            return False

        doc_id = chunk.doc_id
        kb_id = chunk.kb_id
        vec_doc_id = chunk.vec_doc_id

    # 2. Delete the vector first; keep the row if that fails.
    try:
        await self.vec_db.delete(vec_doc_id)
    except Exception as e:
        logger.error(f"删除向量失败: {vec_doc_id}, {e}")
        return False

    # 3. Delete the row. session.begin() commits on clean exit.
    async with self.db.get_db() as session:
        async with session.begin():
            await session.execute(delete(KBChunk).where(KBChunk.chunk_id == chunk_id))

    # 4. Refresh the per-document counters and, like delete_document(), the
    #    knowledge-base counters, whose chunk_count would otherwise be stale.
    await self._update_doc_stats(doc_id)
    await self.manager._update_kb_stats(kb_id)

    return True
async def _save_media(
    self,
    kb_id: str,
    doc_id: str,
    media_type: str,
    file_name: str,
    content: bytes,
    mime_type: str,
) -> KBMedia:
    """Write one media payload to disk and build its (unsaved) row.

    The file is stored as ``<media_path>/<kb_id>/<doc_id>/<uuid><suffix>``,
    where the suffix is taken from ``file_name``. The returned ``KBMedia``
    is not added to any session here.
    """
    new_id = str(uuid.uuid4())
    suffix = Path(file_name).suffix

    target_dir = self.media_path / kb_id / doc_id
    target = target_dir / f"{new_id}{suffix}"
    target_dir.mkdir(parents=True, exist_ok=True)

    async with aiofiles.open(target, "wb") as out:
        await out.write(content)

    return KBMedia(
        media_id=new_id,
        doc_id=doc_id,
        kb_id=kb_id,
        media_type=media_type,
        file_name=file_name,
        file_path=str(target),
        file_size=len(content),
        mime_type=mime_type,
    )
class KnowledgeBase(SQLModel, table=True):
    """Knowledge-base table.

    Stores the basic settings and the statistics of a knowledge base.
    """

    __tablename__ = "knowledge_bases"

    id: int | None = Field(
        primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
    )
    # Public identifier (UUID string) used by the other tables.
    kb_id: str = Field(
        max_length=36,
        nullable=False,
        unique=True,
        default_factory=lambda: str(uuid.uuid4()),
        index=True,
    )
    kb_name: str = Field(max_length=100, nullable=False)
    description: Optional[str] = Field(default=None, sa_type=Text)
    emoji: Optional[str] = Field(default="📚", max_length=10)
    embedding_provider_id: Optional[str] = Field(default=None, max_length=100)
    rerank_provider_id: Optional[str] = Field(default=None, max_length=100)
    # Chunking settings
    chunk_size: Optional[int] = Field(default=512, nullable=True)
    chunk_overlap: Optional[int] = Field(default=50, nullable=True)
    # Retrieval settings
    top_k_dense: Optional[int] = Field(default=50, nullable=True)
    top_k_sparse: Optional[int] = Field(default=50, nullable=True)
    top_m_final: Optional[int] = Field(default=5, nullable=True)
    enable_rerank: Optional[bool] = Field(default=True, nullable=True)
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        # onupdate must be a callable: a bare datetime.now(...) would be
        # evaluated once at import and stamped on every later update.
        sa_column_kwargs={"onupdate": lambda: datetime.now(timezone.utc)},
    )
    doc_count: int = Field(default=0, nullable=False)
    chunk_count: int = Field(default=0, nullable=False)
class KBChunk(SQLModel, table=True):
    """Document chunk table.

    Stores the text of each chunk produced from a document, plus the
    identifier that links the chunk to its entry in the vector index.
    """

    __tablename__ = "kb_chunks"

    # Auto-increment surrogate key.
    id: int | None = Field(
        primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
    )
    # Public identifier; a UUID string generated when the row is created.
    chunk_id: str = Field(
        max_length=36,
        nullable=False,
        unique=True,
        default_factory=lambda: str(uuid.uuid4()),
        index=True,
    )
    # Owning document and knowledge base (plain indexed columns, no FK declared).
    doc_id: str = Field(max_length=36, nullable=False, index=True)
    kb_id: str = Field(max_length=36, nullable=False, index=True)
    # Position of the chunk within its document.
    chunk_index: int = Field(nullable=False)
    content: str = Field(sa_type=Text, nullable=False)
    char_count: int = Field(nullable=False)
    # Identifier of the matching entry in the vector store.
    vec_doc_id: str = Field(max_length=100, nullable=False, index=True)
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
class KBSessionConfig(SQLModel, table=True):
    """Session-to-knowledge-base configuration table.

    Stores which knowledge bases are bound to a session or a platform.
    The table lives in the standalone knowledge-base database.

    Two scopes are supported:
    - platform: platform-level configuration (e.g. 'qq', 'telegram')
    - session: session-level configuration (e.g. 'qq:group:12345')
    """

    __tablename__ = "kb_session_config"

    id: int | None = Field(
        primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
    )
    config_id: str = Field(
        max_length=36,
        nullable=False,
        unique=True,
        default_factory=lambda: str(uuid.uuid4()),
    )
    scope: str = Field(max_length=20, nullable=False)
    scope_id: str = Field(max_length=255, nullable=False, index=True)
    # JSON-encoded list of knowledge-base ids.
    kb_ids: str = Field(sa_type=Text, nullable=False)
    top_k: Optional[int] = Field(default=None, nullable=True)
    enable_rerank: Optional[bool] = Field(default=None, nullable=True)
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        # onupdate must be a callable: a bare datetime.now(...) would be
        # evaluated once at import and stamped on every later update.
        sa_column_kwargs={"onupdate": lambda: datetime.now(timezone.utc)},
    )

    __table_args__ = (
        UniqueConstraint("scope", "scope_id", name="uix_scope_scope_id"),
    )
class BaseParser(ABC):
    """Base class for document parsers.

    Every document parser should inherit from this class and implement
    ``parse``.
    """

    @abstractmethod
    async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
        """Parse a document.

        Args:
            file_content: Raw file bytes.
            file_name: Name of the file.

        Returns:
            ParseResult: The parsed text and any extracted media.
        """
class TextParser(BaseParser):
    """Parser for TXT / Markdown files.

    Decodes the bytes by trying a fixed list of encodings in order.
    """

    # "utf-8-sig" decodes plain UTF-8 as well and strips a leading BOM, which
    # plain "utf-8" would keep as U+FEFF at the start of the text.
    # "gb2312" is not listed: "gbk" is a superset and is tried first, so a
    # later "gb2312" attempt could never succeed where "gbk" had failed.
    _ENCODINGS = ("utf-8-sig", "gbk", "gb18030")

    async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
        """Parse a text file.

        Args:
            file_content: Raw file bytes.
            file_name: Name of the file; only used in the error message.

        Returns:
            ParseResult: The decoded text with an empty media list.

        Raises:
            ValueError: If none of the candidate encodings can decode the file.
        """
        for encoding in self._ENCODINGS:
            try:
                text = file_content.decode(encoding)
                break
            except UnicodeDecodeError:
                continue
        else:
            raise ValueError(f"无法解码文件: {file_name}")

        # Text files carry no media resources.
        return ParseResult(text=text, media=[])
@dataclass
class RetrievalResult:
    """One retrieved chunk with its document and knowledge-base names."""

    chunk_id: str
    doc_id: str
    doc_name: str
    kb_id: str
    kb_name: str
    content: str
    score: float
    metadata: dict


class RetrievalManager:
    """Retrieval manager.

    Responsibilities:
    - Coordinate dense retrieval, sparse retrieval and rerank
    - Fuse and order the results
    """

    def __init__(
        self,
        vec_db: BaseVecDB,
        sparse_retriever: SparseRetriever,
        rank_fusion: RankFusion,
        kb_db: KBDatabase,
        rerank_provider: Optional[RerankProvider] = None,
    ):
        """Initialise the retrieval manager.

        Args:
            vec_db: Vector database instance.
            sparse_retriever: Sparse (keyword) retriever.
            rank_fusion: Result fuser.
            kb_db: Knowledge-base database instance.
            rerank_provider: Optional rerank provider.
        """
        self.vec_db = vec_db
        self.sparse_retriever = sparse_retriever
        self.rank_fusion = rank_fusion
        self.kb_db = kb_db
        self.rerank_provider = rerank_provider

    async def retrieve(
        self,
        query: str,
        kb_ids: List[str],
        top_k_dense: int = 50,
        top_k_sparse: int = 50,
        top_n_fusion: int = 20,
        top_m_final: int = 5,
        enable_rerank: bool = True,
    ) -> List[RetrievalResult]:
        """Hybrid retrieval.

        Steps:
            1. Dense retrieval (vector similarity)
            2. Sparse retrieval
            3. Rank fusion
            4. Metadata lookup
            5. Optional rerank

        Args:
            query: Query text.
            kb_ids: Knowledge-base ids to search.
            top_k_dense: Number of dense candidates.
            top_k_sparse: Number of sparse candidates.
            top_n_fusion: Number of results kept after fusion.
            top_m_final: Number of results returned.
            enable_rerank: Whether to rerank when a provider is configured.

        Returns:
            List[RetrievalResult]: The final results.
        """
        # Nothing to search: skip all back-end calls.
        if not kb_ids:
            return []

        # 1. Dense retrieval
        dense_results = await self._dense_retrieve(
            query=query,
            kb_ids=kb_ids,
            top_k=top_k_dense,
        )

        # 2. Sparse retrieval
        sparse_results = await self.sparse_retriever.retrieve(
            query=query,
            kb_ids=kb_ids,
            top_k=top_k_sparse,
        )

        # 3. Fusion
        fused_results = await self.rank_fusion.fuse(
            dense_results=dense_results,
            sparse_results=sparse_results,
            top_k=top_n_fusion,
        )

        # 4. Attach document / knowledge-base metadata; chunks whose metadata
        #    cannot be found are dropped.
        retrieval_results = []
        for fr in fused_results:
            metadata_dict = await self.kb_db.get_chunk_with_metadata(fr.chunk_id)
            if metadata_dict:
                retrieval_results.append(
                    RetrievalResult(
                        chunk_id=fr.chunk_id,
                        doc_id=fr.doc_id,
                        doc_name=metadata_dict["document"].doc_name,
                        kb_id=fr.kb_id,
                        kb_name=metadata_dict["knowledge_base"].kb_name,
                        content=fr.content,
                        score=fr.score,
                        metadata={
                            "chunk_index": metadata_dict["chunk"].chunk_index,
                            "char_count": metadata_dict["chunk"].char_count,
                        },
                    )
                )

        # 5. Rerank (optional)
        if enable_rerank and self.rerank_provider and retrieval_results:
            retrieval_results = await self._rerank(
                query=query,
                results=retrieval_results,
                top_k=top_m_final,
            )
        else:
            retrieval_results = retrieval_results[:top_m_final]

        return retrieval_results

    async def _dense_retrieve(
        self,
        query: str,
        kb_ids: List[str],
        top_k: int,
    ):
        """Dense retrieval restricted to the given knowledge bases.

        The vector store is queried without a filter, so more candidates than
        ``top_k`` are requested and then filtered by the ``kb_id`` stored in
        each result's metadata.

        Args:
            query: Query text.
            kb_ids: Knowledge-base ids to keep.
            top_k: Maximum number of results.

        Returns:
            The vector-store results whose metadata ``kb_id`` is in ``kb_ids``.
        """
        # An empty id list would otherwise ask the vector store for k=0.
        if not kb_ids or top_k <= 0:
            return []

        vec_results = await self.vec_db.retrieve(
            query=query,
            k=top_k * len(kb_ids) * 2,  # over-fetch so filtering still leaves top_k
        )

        filtered_results = []
        for result in vec_results:
            raw_metadata = result.data.get("metadata", "{}")
            if isinstance(raw_metadata, dict):
                # Already decoded by the store; use as-is.
                metadata = raw_metadata
            else:
                try:
                    metadata = json.loads(raw_metadata)
                except (json.JSONDecodeError, TypeError):
                    metadata = {}
                if not isinstance(metadata, dict):
                    # Valid JSON that is not an object carries no kb_id.
                    metadata = {}

            if metadata.get("kb_id") in kb_ids:
                filtered_results.append(result)

            if len(filtered_results) >= top_k:
                break

        return filtered_results[:top_k]

    async def _rerank(
        self,
        query: str,
        results: List[RetrievalResult],
        top_k: int,
    ) -> List[RetrievalResult]:
        """Rerank results with the configured provider.

        Each returned item's ``index`` selects an entry of ``results`` whose
        ``score`` is overwritten with the provider's relevance score.

        Args:
            query: Query text.
            results: Candidates to rerank.
            top_k: Maximum number of results.

        Returns:
            List[RetrievalResult]: Results ordered by rerank score.
        """
        if not results:
            return []

        docs = [r.content for r in results]

        rerank_results = await self.rerank_provider.rerank(
            query=query,
            documents=docs,
        )

        reranked_list = []
        for rerank_result in rerank_results:
            idx = rerank_result.index
            # Ignore indexes outside the candidate list (a negative index
            # would otherwise silently wrap around).
            if 0 <= idx < len(results):
                result = results[idx]
                result.score = rerank_result.relevance_score
                reranked_list.append(result)

        reranked_list.sort(key=lambda x: x.score, reverse=True)

        return reranked_list[:top_k]
@dataclass
class FusedResult:
    """A retrieval result after rank fusion."""

    chunk_id: str
    doc_id: str
    kb_id: str
    content: str
    score: float


class RankFusion:
    """Fuses dense and sparse retrieval results.

    Uses Reciprocal Rank Fusion (RRF).
    """

    def __init__(self, kb_db: KBDatabase, k: int = 60):
        """Initialise the fuser.

        Args:
            kb_db: Knowledge-base database, used to map vector ids to chunks.
            k: RRF smoothing constant.
        """
        self.kb_db = kb_db
        self.k = k

    async def fuse(
        self,
        dense_results: List[Result],
        sparse_results: List[SparseResult],
        top_k: int = 20,
    ) -> List[FusedResult]:
        """Fuse dense and sparse results with RRF.

        RRF formula:
            score(chunk) = sum(1 / (k + rank_i))

        Dense results are identified by the vector-store id in
        ``data["doc_id"]`` while sparse results carry a ``chunk_id``. Dense
        hits are therefore resolved to chunks *before* scoring, so a chunk
        found by both retrievers gets both contributions and appears once.

        Args:
            dense_results: Dense retrieval results, best first.
            sparse_results: Sparse retrieval results, best first.
            top_k: Maximum number of results.

        Returns:
            List[FusedResult]: Fused results, highest score first.
        """
        # chunk_id -> object exposing chunk_id / doc_id / kb_id / content.
        # Sparse results are registered first so their data is preferred.
        sources = {}

        # 1. Sparse ranks (1-based); the best rank wins for duplicates.
        sparse_ranks: Dict[str, int] = {}
        for rank, sr in enumerate(sparse_results, start=1):
            if sr.chunk_id not in sparse_ranks:
                sparse_ranks[sr.chunk_id] = rank
                sources[sr.chunk_id] = sr

        # 2. Dense ranks, keyed by chunk_id after resolving the vector id.
        #    Hits without a chunk row are dropped; the remaining hits keep
        #    their original rank.
        dense_ranks: Dict[str, int] = {}
        for rank, dr in enumerate(dense_results, start=1):
            chunk = await self.kb_db.get_chunk_by_vec_doc_id(dr.data["doc_id"])
            if not chunk:
                continue
            if chunk.chunk_id not in dense_ranks:
                dense_ranks[chunk.chunk_id] = rank
            sources.setdefault(chunk.chunk_id, chunk)

        # 3. RRF score per chunk.
        rrf_scores: Dict[str, float] = {}
        for chunk_id in sources:
            score = 0.0
            if chunk_id in dense_ranks:
                score += 1.0 / (self.k + dense_ranks[chunk_id])
            if chunk_id in sparse_ranks:
                score += 1.0 / (self.k + sparse_ranks[chunk_id])
            rrf_scores[chunk_id] = score

        # 4. Order by score; the sort is stable, so ties keep insertion order.
        sorted_ids = sorted(
            rrf_scores, key=lambda cid: rrf_scores[cid], reverse=True
        )[:top_k]

        # 5. Build the fused results.
        return [
            FusedResult(
                chunk_id=chunk_id,
                doc_id=sources[chunk_id].doc_id,
                kb_id=sources[chunk_id].kb_id,
                content=sources[chunk_id].content,
                score=rrf_scores[chunk_id],
            )
            for chunk_id in sorted_ids
        ]
async def retrieve(
    self,
    query: str,
    kb_ids: List[str],
    top_k: int = 50,
) -> List[SparseResult]:
    """Run sparse (BM25) retrieval over the chunks of the given knowledge bases.

    Text is tokenised by whitespace. Only chunks that share at least one
    token with the query are returned: a chunk with no overlap has no
    keyword evidence, and ranking it would hand it credit during rank fusion.

    Args:
        query: Query text.
        kb_ids: Knowledge-base ids to search.
        top_k: Maximum number of results.

    Returns:
        List[SparseResult]: Matching chunks, highest BM25 score first.
    """
    tokenized_query = query.split()
    if not tokenized_query:
        return []

    # 1. Load every chunk of the requested knowledge bases.
    chunks = await self.kb_db.get_chunks_by_kb_ids(kb_ids)
    if not chunks:
        return []

    # 2. Tokenise the corpus and find chunks with any query-term overlap.
    #    Overlap is tested on tokens rather than on `score > 0` because BM25
    #    scores of real matches can be zero or negative in very small corpora.
    tokenized_corpus = [chunk.content.split() for chunk in chunks]
    query_terms = set(tokenized_query)
    matching = [
        idx
        for idx, tokens in enumerate(tokenized_corpus)
        if not query_terms.isdisjoint(tokens)
    ]
    if not matching:
        return []

    # 3. Score against the full corpus so term statistics stay correct.
    bm25 = BM25Okapi(tokenized_corpus)
    scores = bm25.get_scores(tokenized_query)

    # 4. Build, sort and truncate.
    results = [
        SparseResult(
            chunk_id=chunks[idx].chunk_id,
            doc_id=chunks[idx].doc_id,
            kb_id=chunks[idx].kb_id,
            content=chunks[idx].content,
            score=float(scores[idx]),
        )
        for idx in matching
    ]

    results.sort(key=lambda x: x.score, reverse=True)
    return results[:top_k]
class SessionConfigDB:
    """Database operations for session knowledge-base configuration.

    Reads and writes ``KBSessionConfig`` rows through the knowledge-base
    database handle passed to the constructor. No exceptions are caught in
    this class; database and JSON errors propagate to the caller.
    """

    def __init__(self, db: KBSQLiteDatabase):
        """Store the database handle.

        Args:
            db: Knowledge-base database instance (kb.db), not the main
                database.
        """
        self.db = db

    @staticmethod
    async def _find_config(session, scope: str, scope_id: str):
        """Return the config row matching (scope, scope_id), or None."""
        lookup = select(KBSessionConfig).where(
            KBSessionConfig.scope == scope,
            KBSessionConfig.scope_id == scope_id,
        )
        outcome = await session.execute(lookup)
        return outcome.scalar_one_or_none()

    async def get_session_kb_ids(self, session_id: str) -> list[str]:
        """Return the knowledge-base IDs bound to a session.

        Lookup order:
            1. A row with scope "session" whose scope_id is ``session_id``.
            2. A row with scope "platform" whose scope_id is the text before
               the first ":" of ``session_id`` (only tried when the ID
               contains a ":").
            3. Otherwise an empty list.
        """
        async with self.db.get_db() as session:
            # 1. Session-level configuration takes priority.
            session_cfg = await self._find_config(session, "session", session_id)
            if session_cfg:
                return json.loads(session_cfg.kb_ids)

            # 2. Platform-level configuration, keyed by the leading segment
            #    of the session ID.
            platform_id, separator, _ = session_id.partition(":")
            if separator:
                platform_cfg = await self._find_config(
                    session, "platform", platform_id
                )
                if platform_cfg:
                    return json.loads(platform_cfg.kb_ids)

            # 3. Nothing configured.
            return []

    async def set_session_kb_ids(
        self,
        scope: str,
        scope_id: str,
        kb_ids: list[str],
        top_k: Optional[int] = None,
        enable_rerank: Optional[bool] = None,
    ) -> KBSessionConfig:
        """Create or update the configuration for (scope, scope_id).

        Args:
            scope: Configuration scope (session/platform).
            scope_id: Scope identifier (session ID or platform ID).
            kb_ids: Knowledge-base IDs; stored as a JSON string.
            top_k: Number of results to return. On update, None leaves the
                stored value untouched.
            enable_rerank: Whether rerank is enabled. On update, None leaves
                the stored value untouched.

        Returns:
            KBSessionConfig: The row after commit and refresh.
        """
        async with self.db.get_db() as session:
            record = await self._find_config(session, scope, scope_id)
            serialized_ids = json.dumps(kb_ids)

            if not record:
                # No row yet: insert a new one with the values as given.
                record = KBSessionConfig(
                    scope=scope,
                    scope_id=scope_id,
                    kb_ids=serialized_ids,
                    top_k=top_k,
                    enable_rerank=enable_rerank,
                )
                session.add(record)
            else:
                # Existing row: always replace kb_ids, and only overwrite the
                # optional settings that were explicitly supplied.
                record.kb_ids = serialized_ids
                if top_k is not None:
                    record.top_k = top_k
                if enable_rerank is not None:
                    record.enable_rerank = enable_rerank

            await session.commit()
            await session.refresh(record)
            return record

    async def delete_session_kb_config(self, scope: str, scope_id: str) -> bool:
        """Delete the configuration for (scope, scope_id).

        Returns:
            bool: True if a row was deleted, False if none matched.
        """
        async with self.db.get_db() as session:
            record = await self._find_config(session, scope, scope_id)
            if record:
                await session.delete(record)
                await session.commit()
                return True
            return False

    async def list_all_session_configs(
        self, offset: int = 0, limit: int = 100, scope: Optional[str] = None
    ) -> list[KBSessionConfig]:
        """List configuration rows, newest first by ``created_at``.

        Args:
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.
            scope: When truthy, only rows with this scope are returned.
        """
        async with self.db.get_db() as session:
            query = select(KBSessionConfig)
            if scope:
                query = query.where(KBSessionConfig.scope == scope)

            query = query.offset(offset)
            query = query.limit(limit)
            query = query.order_by(KBSessionConfig.created_at.desc())

            outcome = await session.execute(query)
            rows = outcome.scalars().all()
            return list(rows)