diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index cc0122a1b..a07c4d149 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -113,7 +113,7 @@ class AstrBotCoreLifecycle: # 初始化知识库管理器 self.kb_manager = KnowledgeBaseManager( - self.astrbot_config, self.db, self.provider_manager + self.astrbot_config, self.provider_manager ) # 初始化提供给插件的上下文 diff --git a/astrbot/core/db/vec_db/base.py b/astrbot/core/db/vec_db/base.py index d6e5e56be..d71cf6bf8 100644 --- a/astrbot/core/db/vec_db/base.py +++ b/astrbot/core/db/vec_db/base.py @@ -16,14 +16,23 @@ class BaseVecDB: pass @abc.abstractmethod - async def insert(self, content: str, metadata: dict = None, id: str = None) -> int: + async def insert( + self, content: str, metadata: dict | None = None, id: str | None = None + ) -> int: """ 插入一条文本和其对应向量,自动生成 ID 并保持一致性。 """ ... @abc.abstractmethod - async def retrieve(self, query: str, top_k: int = 5) -> list[Result]: + async def retrieve( + self, + query: str, + top_k: int = 5, + fetch_k: int = 20, + rerank: bool = False, + metadata_filters: dict | None = None, + ) -> list[Result]: """ 搜索最相似的文档。 Args: @@ -44,3 +53,6 @@ class BaseVecDB: bool: 删除是否成功 """ ... + + @abc.abstractmethod + async def close(self): ... diff --git a/astrbot/core/knowledge_base/database.py b/astrbot/core/knowledge_base/database.py index 83ec0e4ba..d48c9e522 100644 --- a/astrbot/core/knowledge_base/database.py +++ b/astrbot/core/knowledge_base/database.py @@ -7,7 +7,6 @@ - 会话配置也存储在此数据库中,会话ID来源于主数据库 """ -import json from typing import Optional from sqlalchemy import func, select @@ -17,7 +16,6 @@ from astrbot.core.knowledge_base.models import ( KBChunk, KBDocument, KBMedia, - KBSessionConfig, KnowledgeBase, ) @@ -183,165 +181,3 @@ class KBDatabase: stmt = select(KBMedia).where(KBMedia.media_id == media_id) result = await session.execute(stmt) return result.scalar_one_or_none() - - # ===== 会话配置查询 ===== - - async def get_session_kb_ids(self, session_id: str) -> list[str]: - """获取会话关联的知识库 ID 列表 - - 查找顺序: - 1. 会话级别配置 (优先) - 2. 平台级别配置 - 3. 返回空列表 - - Args: - session_id: 会话ID(来自主数据库) - - Returns: - 知识库ID列表 - """ - async with self.db.get_db() as session: - # 1. 查找会话级别配置 - stmt = select(KBSessionConfig).where( - KBSessionConfig.scope == "session", - KBSessionConfig.scope_id == session_id, - ) - result = await session.execute(stmt) - config = result.scalar_one_or_none() - - if config: - return json.loads(config.kb_ids) - - # 2. 提取平台 ID (格式: platform:xxx:session_id) - parts = session_id.split(":") - if len(parts) >= 2: - platform_id = parts[0] - - # 查找平台级别配置 - stmt = select(KBSessionConfig).where( - KBSessionConfig.scope == "platform", - KBSessionConfig.scope_id == platform_id, - ) - result = await session.execute(stmt) - config = result.scalar_one_or_none() - - if config: - return json.loads(config.kb_ids) - - # 3. 无配置 - return [] - - async def set_session_kb_ids( - self, - scope: str, - scope_id: str, - kb_ids: list[str], - top_k: Optional[int] = None, - enable_rerank: Optional[bool] = None, - ) -> KBSessionConfig: - """设置会话知识库配置 - - Args: - scope: 配置范围 (session/platform) - scope_id: 范围标识 (会话 ID 或平台 ID,来自主数据库) - kb_ids: 知识库 ID 列表 - top_k: 返回结果数量 (可选) - enable_rerank: 是否启用 Rerank (可选) - - Returns: - 配置对象 - """ - async with self.db.get_db() as session: - # 查找现有配置 - stmt = select(KBSessionConfig).where( - KBSessionConfig.scope == scope, - KBSessionConfig.scope_id == scope_id, - ) - result = await session.execute(stmt) - config = result.scalar_one_or_none() - - if config: - # 更新现有配置 - config.kb_ids = json.dumps(kb_ids) - if top_k is not None: - config.top_k = top_k - if enable_rerank is not None: - config.enable_rerank = enable_rerank - else: - # 创建新配置 - config = KBSessionConfig( - scope=scope, - scope_id=scope_id, - kb_ids=json.dumps(kb_ids), - top_k=top_k, - enable_rerank=enable_rerank, - ) - session.add(config) - - await session.commit() - await session.refresh(config) - return config - - async def delete_session_kb_config(self, scope: str, scope_id: str) -> bool: - """删除会话知识库配置 - - Args: - scope: 配置范围 (session/platform) - scope_id: 范围标识 (会话 ID 或平台 ID) - - Returns: - 是否删除成功 - """ - async with self.db.get_db() as session: - stmt = select(KBSessionConfig).where( - KBSessionConfig.scope == scope, - KBSessionConfig.scope_id == scope_id, - ) - result = await session.execute(stmt) - config = result.scalar_one_or_none() - - if not config: - return False - - await session.delete(config) - await session.commit() - return True - - async def delete_session_kb_config_by_session_id(self, session_id: str) -> bool: - """根据会话ID删除会话配置(用于主数据库会话删除时的级联清理) - - Args: - session_id: 会话ID(来自主数据库) - - Returns: - 是否删除成功 - """ - return await self.delete_session_kb_config("session", session_id) - - async def list_all_session_configs( - self, offset: int = 0, limit: int = 100, scope: Optional[str] = None - ) -> list[KBSessionConfig]: - """列出所有会话配置 - - Args: - offset: 偏移量 - limit: 限制数量 - scope: 可选的范围过滤 (session/platform) - - Returns: - 会话配置列表 - """ - async with self.db.get_db() as session: - stmt = select(KBSessionConfig) - - if scope: - stmt = stmt.where(KBSessionConfig.scope == scope) - - stmt = ( - stmt.offset(offset) - .limit(limit) - .order_by(KBSessionConfig.created_at.desc()) - ) - - result = await session.execute(stmt) - return list(result.scalars().all()) diff --git a/astrbot/core/knowledge_base/injector.py b/astrbot/core/knowledge_base/injector.py index 968e51ff6..d2a813911 100644 --- a/astrbot/core/knowledge_base/injector.py +++ b/astrbot/core/knowledge_base/injector.py @@ -10,6 +10,7 @@ from astrbot.core.knowledge_base.retrieval.manager import ( RetrievalManager, RetrievalResult, ) +from .vec_db_factory import VecDBFactory class KnowledgeBaseInjector: @@ -24,6 +25,7 @@ class KnowledgeBaseInjector: def __init__( self, kb_db: KBDatabase, + vec_db_factory: VecDBFactory, retrieval_manager: RetrievalManager, ): """初始化知识库上下文注入器 @@ -33,18 +35,18 @@ class KnowledgeBaseInjector: retrieval_manager: 检索管理器实例 """ self.kb_db = kb_db + self.vec_db_factory = vec_db_factory self.retrieval_manager = retrieval_manager async def retrieve_and_inject( self, - unified_msg_origin: str, + kb_ids: list[str], query: str, top_k: int = 5, ) -> Optional[dict]: """检索并注入知识库上下文 Args: - unified_msg_origin: 统一消息来源 ID (会话 ID) query: 用户查询 top_k: 返回结果数量 @@ -55,14 +57,9 @@ class KnowledgeBaseInjector: "results": List[dict], # 原始检索结果列表 } """ - # 1. 获取会话关联的知识库 - kb_ids = await self.kb_db.get_session_kb_ids(unified_msg_origin) - - if not kb_ids: - return None - # 2. 检索知识 results = await self.retrieval_manager.retrieve( + vec_db_factory=self.vec_db_factory, query=query, kb_ids=kb_ids, top_m_final=top_k, diff --git a/astrbot/core/knowledge_base/kb_manager_lifecycle.py b/astrbot/core/knowledge_base/kb_manager_lifecycle.py index 0768874fa..c0b709a5c 100644 --- a/astrbot/core/knowledge_base/kb_manager_lifecycle.py +++ b/astrbot/core/knowledge_base/kb_manager_lifecycle.py @@ -9,8 +9,18 @@ from pathlib import Path from astrbot.core import logger -from astrbot.core.db import BaseDatabase from astrbot.core.provider.manager import ProviderManager +from .injector import KnowledgeBaseInjector +from .retrieval.manager import RetrievalManager +from .retrieval.sparse_retriever import SparseRetriever +from .retrieval.rank_fusion import RankFusion +from .kb_sqlite import KBSQLiteDatabase +from .database import KBDatabase +from .vec_db_factory import VecDBFactory +from .manager import KBManager +from .parsers.text_parser import TextParser +from .parsers.pdf_parser import PDFParser +from .chunking.fixed_size import FixedSizeChunker class KnowledgeBaseManager: @@ -28,32 +38,26 @@ class KnowledgeBaseManager: - 通过回调机制实现与主数据库的生命周期同步 """ + kb_db: KBSQLiteDatabase + vec_db_factory: VecDBFactory + kb_database: KBDatabase + kb_manager: KBManager + retrieval_manager: RetrievalManager + kb_injector: KnowledgeBaseInjector + def __init__( self, config: dict, - main_db: BaseDatabase, provider_manager: ProviderManager, ): """初始化知识库管理器 Args: config: 配置字典 - main_db: 主数据库实例 (不直接使用,仅用于类型引用) provider_manager: Provider 管理器 """ self.config = config.get("knowledge_base", {}) self.provider_manager = provider_manager - - # 知识库独立数据库 - self.kb_db = None - - # 组件实例 - self.kb_database = None - self.kb_manager = None - self.kb_vec_db = None - self.retrieval_manager = None - self.kb_injector = None - self._initialized = False self._session_deleted_callback_registered = False @@ -66,31 +70,54 @@ class KnowledgeBaseManager: try: logger.info("正在初始化知识库模块...") - # 1. 检查并选择 Embedding Provider - embedding_provider = self._select_embedding_provider() - if not embedding_provider: - logger.warning("未配置 Embedding Provider,知识库功能无法使用") - return - - # 2. 初始化数据库 + # 初始化数据库 await self._init_kb_database() - await self._init_database() - # 3. 初始化向量数据库 - await self._init_vector_db(embedding_provider) + # 初始化向量数据库工厂 + await self._init_vector_db_factory() - # 4. 初始化解析器和分块器 - parsers = self._init_parsers() - chunker = self._init_chunker() + # 初始化解析器和分块器 + parsers = { + "txt": TextParser(), + "md": TextParser(), + "markdown": TextParser(), + "pdf": PDFParser(), + } + chunking_config = self.config.get("chunking", {}) + chunker = FixedSizeChunker( + chunk_size=chunking_config.get("chunk_size", 512), + chunk_overlap=chunking_config.get("chunk_overlap", 50), + ) - # 5. 初始化知识库管理器 - await self._init_kb_manager(parsers, chunker) + # 初始化知识库管理器 + files_path = self.config.get("storage", {}).get( + "files_path", "data/knowledge_base" + ) + self.kb_manager = KBManager( + db=self.kb_db, + vec_db_factory=self.vec_db_factory, + storage_path=files_path, + parsers=parsers, + chunker=chunker, + provider_manager=self.provider_manager, + ) - # 6. 初始化检索管理器 - await self._init_retrieval_manager() + # 初始化检索管理器 + sparse_retriever = SparseRetriever(self.kb_database) + rank_fusion = RankFusion(self.kb_database) + self.retrieval_manager = RetrievalManager( + vec_db_factory=self.vec_db_factory, + sparse_retriever=sparse_retriever, + rank_fusion=rank_fusion, + kb_db=self.kb_database, + ) - # 7. 初始化上下文注入器 - await self._init_injector() + # 初始化上下文注入器 + self.kb_injector = KnowledgeBaseInjector( + kb_db=self.kb_database, + vec_db_factory=self.vec_db_factory, + retrieval_manager=self.retrieval_manager, + ) self._initialized = True logger.info("知识库模块初始化完成") @@ -106,8 +133,6 @@ class KnowledgeBaseManager: async def _init_kb_database(self): """初始化知识库独立数据库""" - from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase - db_path = self.config.get("storage", {}).get( "kb_db_path", "data/knowledge_base/kb.db" ) @@ -116,168 +141,16 @@ class KnowledgeBaseManager: self.kb_db = KBSQLiteDatabase(db_path) await self.kb_db.initialize() await self.kb_db.migrate_to_v1() - - logger.info(f"知识库独立数据库已初始化: {db_path}") - - async def _init_database(self): - """初始化知识库数据库操作类""" - from astrbot.core.knowledge_base.database import KBDatabase - self.kb_database = KBDatabase(self.kb_db) + logger.info(f"KnowledgeBase database initialized: {db_path}") - async def _init_vector_db(self, embedding_provider): - """初始化向量数据库""" - from astrbot.core.db.vec_db.faiss_impl import FaissVecDB - + async def _init_vector_db_factory(self): + """初始化向量数据库工厂""" storage_path = self.config.get("storage", {}).get( "vector_db_path", "data/knowledge_base/vectors" ) Path(storage_path).mkdir(parents=True, exist_ok=True) - - self.kb_vec_db = FaissVecDB( - doc_store_path=f"{storage_path}/documents.db", - index_store_path=f"{storage_path}/index.faiss", - embedding_provider=embedding_provider, - ) - await self.kb_vec_db.initialize() - - def _init_parsers(self) -> dict: - """初始化文档解析器""" - from astrbot.core.knowledge_base.parsers.text_parser import TextParser - from astrbot.core.knowledge_base.parsers.pdf_parser import PDFParser - - return { - "txt": TextParser(), - "md": TextParser(), - "markdown": TextParser(), - "pdf": PDFParser(), - } - - def _init_chunker(self): - """初始化分块器""" - from astrbot.core.knowledge_base.chunking.fixed_size import FixedSizeChunker - - chunking_config = self.config.get("chunking", {}) - return FixedSizeChunker( - chunk_size=chunking_config.get("chunk_size", 512), - chunk_overlap=chunking_config.get("chunk_overlap", 50), - ) - - async def _init_kb_manager(self, parsers: dict, chunker): - """初始化知识库管理器""" - from astrbot.core.knowledge_base.manager import KBManager - - files_path = self.config.get("storage", {}).get( - "files_path", "data/knowledge_base" - ) - - self.kb_manager = KBManager( - db=self.kb_db, # 使用独立的知识库数据库 - vec_db=self.kb_vec_db, - storage_path=files_path, - parsers=parsers, - chunker=chunker, - provider_manager=self.provider_manager, - ) - - async def _init_retrieval_manager(self): - """初始化检索管理器""" - from astrbot.core.knowledge_base.retrieval.manager import RetrievalManager - from astrbot.core.knowledge_base.retrieval.sparse_retriever import ( - SparseRetriever, - ) - from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion - - sparse_retriever = SparseRetriever(self.kb_database) - rank_fusion = RankFusion(self.kb_database) - - # 选择 Rerank Provider (可选) - rerank_provider = self._select_rerank_provider() - - self.retrieval_manager = RetrievalManager( - vec_db=self.kb_vec_db, - sparse_retriever=sparse_retriever, - rank_fusion=rank_fusion, - kb_db=self.kb_database, - rerank_provider=rerank_provider, - ) - - async def _init_injector(self): - """初始化上下文注入器""" - from astrbot.core.knowledge_base.injector import KnowledgeBaseInjector - - self.kb_injector = KnowledgeBaseInjector( - kb_db=self.kb_database, - retrieval_manager=self.retrieval_manager, - ) - - def _select_embedding_provider(self): - """选择 Embedding Provider - - 逻辑: - - 如果配置了 embedding_provider_id,则使用指定的 provider - - 如果没有配置,但有 embedding provider,则使用第一个 - - 如果有多个 embedding provider 但没有指定,则警告并使用第一个 - """ - embedding_providers = self.provider_manager.embedding_provider_insts - - if not embedding_providers: - return None - - configured_provider_id = self.config.get("embedding_provider_id") - - if configured_provider_id: - # 按 ID 查找 - for provider in embedding_providers: - provider_id = provider.meta().id - if provider_id == configured_provider_id: - logger.info(f"知识库使用 Embedding Provider: {provider_id}") - return provider - logger.warning( - f"未找到配置的 Embedding Provider ID: {configured_provider_id}," - f"将使用第一个可用的" - ) - - if len(embedding_providers) > 1 and not configured_provider_id: - provider = embedding_providers[0] - provider_id = provider.meta().id - logger.info( - f"检测到 {len(embedding_providers)} 个 Embedding Provider," - f"未在配置文件中指定 embedding_provider_id,将使用第一个: {provider_id}" - ) - return provider - - provider = embedding_providers[0] - provider_id = provider.meta().id - logger.info(f"知识库使用 Embedding Provider: {provider_id}") - return provider - - def _select_rerank_provider(self): - """选择 Rerank Provider (可选)""" - if not self.config.get("retrieval", {}).get("enable_rerank", True): - return None - - rerank_providers = self.provider_manager.rerank_provider_insts - if not rerank_providers: - return None - - configured_provider_id = self.config.get("rerank_provider_id") - - if configured_provider_id: - for provider in rerank_providers: - provider_id = provider.meta().id - if provider_id == configured_provider_id: - logger.info(f"知识库使用 Rerank Provider: {provider_id}") - return provider - logger.warning(f"未找到配置的 Rerank Provider ID: {configured_provider_id}") - - if len(rerank_providers) > 0: - provider = rerank_providers[0] - provider_id = provider.meta().id - logger.info(f"知识库使用 Rerank Provider: {provider_id}") - return provider - - return None + self.vec_db_factory = VecDBFactory(storage_base_path=storage_path) @property def is_initialized(self) -> bool: @@ -292,31 +165,6 @@ class KnowledgeBaseManager: """获取知识库上下文注入器""" return self.kb_injector if self._initialized else None - def register_session_lifecycle_hooks(self, conversation_manager): - """注册会话生命周期钩子 - - 在会话删除时自动清理知识库配置,实现零侵入的级联清理。 - - Args: - conversation_manager: 会话管理器实例 - """ - if self._session_deleted_callback_registered or not self._initialized: - return - - async def on_session_deleted(session_id: str): - """会话删除回调:清理知识库配置""" - try: - await self.kb_database.delete_session_kb_config_by_session_id( - session_id - ) - logger.info(f"已清理会话知识库配置: {session_id}") - except Exception as e: - logger.error(f"清理会话知识库配置失败 ({session_id}): {e}") - - conversation_manager.register_on_session_deleted(on_session_deleted) - self._session_deleted_callback_registered = True - logger.info("已注册知识库会话删除回调") - async def reinitialize(self): """重新初始化知识库模块 @@ -336,13 +184,13 @@ class KnowledgeBaseManager: logger.info("正在终止知识库模块...") - # 关闭向量数据库连接 - if self.kb_vec_db: + # 关闭向量数据库工厂(关闭所有向量数据库实例) + if self.vec_db_factory: try: - await self.kb_vec_db.close() - logger.debug("向量数据库已关闭") + await self.vec_db_factory.close_all() + logger.debug("向量数据库工厂已关闭") except Exception as e: - logger.warning(f"关闭向量数据库时出错: {e}") + logger.warning(f"关闭向量数据库工厂时出错: {e}") # 关闭知识库独立数据库连接 if self.kb_db: @@ -352,13 +200,6 @@ class KnowledgeBaseManager: except Exception as e: logger.warning(f"关闭知识库数据库时出错: {e}") - # 清理资源 self._initialized = False - self.kb_db = None - self.kb_database = None - self.kb_manager = None - self.kb_vec_db = None - self.retrieval_manager = None - self.kb_injector = None logger.info("知识库模块已终止") diff --git a/astrbot/core/knowledge_base/manager.py b/astrbot/core/knowledge_base/manager.py index c4d31ba45..f6a6c7f87 100644 --- a/astrbot/core/knowledge_base/manager.py +++ b/astrbot/core/knowledge_base/manager.py @@ -10,12 +10,11 @@ from typing import Optional import aiofiles from sqlalchemy import func, select, update -from astrbot.core.db import BaseDatabase -from astrbot.core.db.vec_db.base import BaseVecDB +from .kb_sqlite import KBSQLiteDatabase from astrbot.core.knowledge_base.chunking.base import BaseChunker from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KnowledgeBase from astrbot.core.knowledge_base.parsers.base import BaseParser - +from .vec_db_factory import VecDBFactory class KBManager: """知识库管理器 @@ -29,15 +28,15 @@ class KBManager: def __init__( self, - db: BaseDatabase, - vec_db: BaseVecDB, + db: KBSQLiteDatabase, + vec_db_factory: VecDBFactory, storage_path: str, parsers: dict[str, BaseParser], chunker: BaseChunker, provider_manager=None, ): self.db = db - self.vec_db = vec_db + self.vec_db_factory = vec_db_factory self.storage_path = Path(storage_path) self.media_path = self.storage_path / "media" self.files_path = self.storage_path / "files" @@ -49,6 +48,48 @@ class KBManager: self.media_path.mkdir(parents=True, exist_ok=True) self.files_path.mkdir(parents=True, exist_ok=True) + async def _get_embedding_provider_for_kb(self, kb_id: str): + """根据知识库配置获取 Embedding Provider + + Args: + kb_id: 知识库 ID + + Returns: + EmbeddingProvider: Embedding Provider 实例 + + Raises: + ValueError: 如果找不到合适的 embedding provider + """ + from astrbot.core.knowledge_base.database import KBDatabase + + # 获取知识库配置 + kb_database = KBDatabase(self.db) + kb = await kb_database.get_kb_by_id(kb_id) + if not kb: + raise ValueError(f"知识库不存在: {kb_id}") + + embedding_provider_id = kb.embedding_provider_id + + # 如果没有 provider_manager,使用默认的第一个 + if not self.provider_manager: + raise ValueError("Provider Manager 未初始化") + + embedding_providers = self.provider_manager.embedding_provider_insts + if not embedding_providers: + raise ValueError("系统中没有可用的 Embedding Provider") + + # 如果指定了 provider ID,则查找该 provider + if embedding_provider_id: + for provider in embedding_providers: + if provider.meta().id == embedding_provider_id: + return provider + raise ValueError( + f"未找到配置的 Embedding Provider: {embedding_provider_id}" + ) + + # 使用第一个可用的 provider + return embedding_providers[0] + # ===== 知识库操作 ===== async def create_kb( @@ -77,7 +118,7 @@ class KBManager: # 检查是否有可用的 rerank provider has_rerank_provider = ( self.provider_manager - and hasattr(self.provider_manager, 'rerank_provider_insts') + and hasattr(self.provider_manager, "rerank_provider_insts") and len(self.provider_manager.rerank_provider_insts) > 0 ) enable_rerank = has_rerank_provider @@ -182,7 +223,10 @@ class KBManager: for doc in docs: await ops.delete_document(doc.doc_id) - # 3. 删除知识库记录 + # 3. 删除向量数据库 + await self.vec_db_factory.delete_vec_db(kb_id) + + # 4. 删除知识库记录 async with self.db.get_db() as session: stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id) result = await session.execute(stmt) @@ -257,11 +301,15 @@ class KBManager: # 4. 文档分块 chunks_text = await self.chunker.chunk(text_content) - # 5. 生成向量并存储 + # 5. 获取 Embedding Provider 和向量数据库 + embedding_provider = await self._get_embedding_provider_for_kb(kb_id) + vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider) + + # 6. 生成向量并存储 saved_chunks = [] for idx, chunk_text in enumerate(chunks_text): # 存储到向量数据库 - vec_doc_id = await self.vec_db.insert( + vec_doc_id = await vec_db.insert( content=chunk_text, metadata={ "kb_id": kb_id, @@ -282,7 +330,7 @@ class KBManager: ) saved_chunks.append(chunk) - # 6. 保存文档元数据(事务) + # 7. 保存文档元数据(事务) doc = KBDocument( doc_id=doc_id, kb_id=kb_id, @@ -305,7 +353,7 @@ class KBManager: await session.refresh(doc) - # 7. 更新知识库统计 + # 8. 更新知识库统计 await self._update_kb_stats(kb_id) return doc @@ -316,12 +364,19 @@ class KBManager: logger.error(f"文档上传失败,开始清理资源: {e}") - # 清理向量数据库 - for vec_id in vec_doc_ids: - try: - await self.vec_db.delete(vec_id) - except Exception as ve: - logger.warning(f"清理向量失败 {vec_id}: {ve}") + # 获取知识库的向量数据库 + try: + embedding_provider = await self._get_embedding_provider_for_kb(kb_id) + vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider) + + # 清理向量数据库 + for vec_id in vec_doc_ids: + try: + await vec_db.delete(vec_id) + except Exception as ve: + logger.warning(f"清理向量失败 {vec_id}: {ve}") + except Exception as vfe: + logger.error(f"获取向量数据库失败: {vfe}") # 清理多媒体文件 for media_path in media_paths: diff --git a/astrbot/core/knowledge_base/manager_ops.py b/astrbot/core/knowledge_base/manager_ops.py index e0ab5f6d6..450689263 100644 --- a/astrbot/core/knowledge_base/manager_ops.py +++ b/astrbot/core/knowledge_base/manager_ops.py @@ -28,7 +28,7 @@ class KBManagerOps: def __init__(self, manager: "KBManager"): self.manager = manager self.db = manager.db - self.vec_db = manager.vec_db + self.vec_db_factory = manager.vec_db_factory self.media_path = manager.media_path self.files_path = manager.files_path @@ -75,6 +75,12 @@ class KBManagerOps: chunks = await self.list_chunks(doc_id) media_list = await self.list_media(doc_id) + # 获取知识库的向量数据库 + embedding_provider = await self.manager._get_embedding_provider_for_kb( + doc.kb_id + ) + vec_db = await self.vec_db_factory.get_vec_db(doc.kb_id, embedding_provider) + # ===== 第一阶段: 删除向量(可重试) ===== vec_ids_to_delete = [chunk.vec_doc_id for chunk in chunks] deleted_vec_ids = [] @@ -82,7 +88,7 @@ class KBManagerOps: for vec_id in vec_ids_to_delete: try: - await self.vec_db.delete(vec_id) + await vec_db.delete(vec_id) deleted_vec_ids.append(vec_id) except Exception as e: logger.error(f"删除向量失败: {vec_id}, {e}") @@ -173,11 +179,16 @@ class KBManagerOps: return False doc_id = chunk.doc_id + kb_id = chunk.kb_id vec_doc_id = chunk.vec_doc_id - # 2. 删除向量 + # 2. 获取知识库的向量数据库并删除向量 try: - await self.vec_db.delete(vec_doc_id) + embedding_provider = await self.manager._get_embedding_provider_for_kb( + kb_id + ) + vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider) + await vec_db.delete(vec_doc_id) except Exception as e: logger.error(f"删除向量失败: {vec_doc_id}, {e}") return False diff --git a/astrbot/core/knowledge_base/models.py b/astrbot/core/knowledge_base/models.py index 6ec77ad93..4777ac474 100644 --- a/astrbot/core/knowledge_base/models.py +++ b/astrbot/core/knowledge_base/models.py @@ -16,7 +16,7 @@ import uuid from datetime import datetime, timezone from typing import Optional -from sqlmodel import Field, SQLModel, Text, UniqueConstraint +from sqlmodel import Field, SQLModel, Text class KnowledgeBase(SQLModel, table=True): @@ -25,7 +25,7 @@ class KnowledgeBase(SQLModel, table=True): 存储知识库的基本信息和统计数据。 """ - __tablename__ = "knowledge_bases" + __tablename__ = "knowledge_bases" # type: ignore id: int | None = Field( primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None @@ -65,7 +65,7 @@ class KBDocument(SQLModel, table=True): 存储上传到知识库的文档元数据。 """ - __tablename__ = "kb_documents" + __tablename__ = "kb_documents" # type: ignore id: int | None = Field( primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None @@ -97,7 +97,7 @@ class KBChunk(SQLModel, table=True): 存储文档分块后的文本内容和向量索引关联信息。 """ - __tablename__ = "kb_chunks" + __tablename__ = "kb_chunks" # type: ignore id: int | None = Field( primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None @@ -124,7 +124,7 @@ class KBMedia(SQLModel, table=True): 存储从文档中提取的图片、视频等多媒体资源。 """ - __tablename__ = "kb_media" + __tablename__ = "kb_media" # type: ignore id: int | None = Field( primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None @@ -144,39 +144,3 @@ class KBMedia(SQLModel, table=True): file_size: int = Field(nullable=False) mime_type: str = Field(max_length=100, nullable=False) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) - - -class KBSessionConfig(SQLModel, table=True): - """会话知识库配置表 - - 存储会话或平台级别的知识库关联配置。 - 该表存储在知识库独立数据库中,保持完全解耦。 - - 支持两种配置范围: - - platform: 平台级别配置 (如 'qq', 'telegram') - - session: 会话级别配置 (如 'qq:group:12345') - """ - - __tablename__ = "kb_session_config" - - id: int | None = Field( - primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None - ) - config_id: str = Field( - max_length=36, - nullable=False, - unique=True, - default_factory=lambda: str(uuid.uuid4()), - ) - scope: str = Field(max_length=20, nullable=False) - scope_id: str = Field(max_length=255, nullable=False, index=True) - kb_ids: str = Field(sa_type=Text, nullable=False) - top_k: Optional[int] = Field(default=None, nullable=True) - enable_rerank: Optional[bool] = Field(default=None, nullable=True) - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) - updated_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), - sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, - ) - - __table_args__ = (UniqueConstraint("scope", "scope_id", name="uix_scope_scope_id"),) diff --git a/astrbot/core/knowledge_base/retrieval/manager.py b/astrbot/core/knowledge_base/retrieval/manager.py index fb19d6717..1b500a1e3 100644 --- a/astrbot/core/knowledge_base/retrieval/manager.py +++ b/astrbot/core/knowledge_base/retrieval/manager.py @@ -3,16 +3,15 @@ 协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口 """ -import json from dataclasses import dataclass -from typing import List, Optional +from typing import List -from astrbot.core.db.vec_db.base import BaseVecDB from astrbot.core.knowledge_base.database import KBDatabase from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever from astrbot.core.provider.provider import RerankProvider - +from astrbot.core.db.vec_db.base import BaseVecDB, Result +from ..vec_db_factory import VecDBFactory @dataclass class RetrievalResult: @@ -38,36 +37,34 @@ class RetrievalManager: def __init__( self, - vec_db: BaseVecDB, + vec_db_factory, # VecDBFactory sparse_retriever: SparseRetriever, rank_fusion: RankFusion, kb_db: KBDatabase, - rerank_provider: Optional[RerankProvider] = None, ): """初始化检索管理器 Args: - vec_db: 向量数据库实例 + vec_db_factory: 向量数据库工厂 sparse_retriever: 稀疏检索器 rank_fusion: 结果融合器 kb_db: 知识库数据库实例 - rerank_provider: Rerank 提供商 (可选) """ - self.vec_db = vec_db + self.vec_db_factory = vec_db_factory self.sparse_retriever = sparse_retriever self.rank_fusion = rank_fusion self.kb_db = kb_db - self.rerank_provider = rerank_provider async def retrieve( self, + vec_db_factory: VecDBFactory, query: str, kb_ids: List[str], top_k_dense: int = 50, top_k_sparse: int = 50, top_n_fusion: int = 20, top_m_final: int = 5, - enable_rerank: bool = True, + rerank_provider: RerankProvider | None = None, ) -> List[RetrievalResult]: """混合检索 @@ -94,6 +91,7 @@ class RetrievalManager: query=query, kb_ids=kb_ids, top_k=top_k_dense, + vec_db=vec_db, ) # 2. 稀疏检索 @@ -131,13 +129,13 @@ class RetrievalManager: ) ) - # 5. Rerank (可选) - if enable_rerank and self.rerank_provider and retrieval_results: + # 5. Rerank + if rerank_provider and retrieval_results: retrieval_results = await self._rerank( query=query, results=retrieval_results, top_k=top_m_final, - rerank_provider=self.rerank_provider, + rerank_provider=rerank_provider, ) else: retrieval_results = retrieval_results[:top_m_final] @@ -149,9 +147,12 @@ class RetrievalManager: query: str, kb_ids: List[str], top_k: int, + vec_db: BaseVecDB, ): """稠密检索 (向量相似度) + 为每个知识库使用独立的向量数据库进行检索,然后合并结果。 + Args: query: 查询文本 kb_ids: 知识库 ID 列表 @@ -160,28 +161,27 @@ class RetrievalManager: Returns: List[Result]: 检索结果列表 """ - # 直接调用向量数据库检索 - vec_results = await self.vec_db.retrieve( - query=query, - top_k=top_k * len(kb_ids) * 2, # 增加候选数量以便过滤 - ) + all_results: list[Result] = [] - # 过滤:只保留指定知识库的结果 - filtered_results = [] - for result in vec_results: - metadata_str = result.data.get("metadata", "{}") + for kb_id in kb_ids: try: - metadata = json.loads(metadata_str) - except (json.JSONDecodeError, TypeError): - metadata = {} + vec_results = await vec_db.retrieve( + query=query, + top_k=top_k, + fetch_k=top_k * 2, + metadata_filters={"kb_id": kb_id}, + ) - if metadata.get("kb_id") in kb_ids: - filtered_results.append(result) + all_results.extend(vec_results) + except Exception as e: + from astrbot.core import logger - if len(filtered_results) >= top_k: - break + logger.warning(f"知识库 {kb_id} 稠密检索失败: {e}") + continue - return filtered_results[:top_k] + # 按相似度排序并返回 top_k + all_results.sort(key=lambda x: x.similarity, reverse=True) + return all_results[:top_k] async def _rerank( self, diff --git a/astrbot/core/knowledge_base/session_config_db.py b/astrbot/core/knowledge_base/session_config_db.py deleted file mode 100644 index ce0c63a4c..000000000 --- a/astrbot/core/knowledge_base/session_config_db.py +++ /dev/null @@ -1,157 +0,0 @@ -"""会话知识库配置数据库操作 - -该模块封装会话知识库配置的数据库查询操作。 - -注意: 会话配置表 (kb_session_config) 存储在知识库独立数据库 (kb.db) 中, - 而不是主数据库 (astrbot.db) 中,以实现完全解耦。 -""" - -import json -from typing import Optional - -from sqlalchemy import select - -from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase -from astrbot.core.knowledge_base.models import KBSessionConfig - - -class SessionConfigDB: - """会话知识库配置数据库操作类 - - 职责: - - 提供会话知识库配置管理 - - 统一异常处理 - - 注意: 该类操作知识库独立数据库,实现完全解耦 - """ - - def __init__(self, db: KBSQLiteDatabase): - """初始化会话配置数据库操作类 - - Args: - db: 知识库独立数据库实例 (kb.db),不是主数据库 - """ - self.db = db - - async def get_session_kb_ids(self, session_id: str) -> list[str]: - """获取会话关联的知识库 ID 列表 - - 查找顺序: - 1. 会话级别配置 (优先) - 2. 平台级别配置 - 3. 返回空列表 - """ - async with self.db.get_db() as session: - # 1. 查找会话级别配置 - stmt = select(KBSessionConfig).where( - KBSessionConfig.scope == "session", - KBSessionConfig.scope_id == session_id, - ) - result = await session.execute(stmt) - config = result.scalar_one_or_none() - - if config: - return json.loads(config.kb_ids) - - # 2. 提取平台 ID (格式: platform:xxx:session_id) - parts = session_id.split(":") - if len(parts) >= 2: - platform_id = parts[0] - - # 查找平台级别配置 - stmt = select(KBSessionConfig).where( - KBSessionConfig.scope == "platform", - KBSessionConfig.scope_id == platform_id, - ) - result = await session.execute(stmt) - config = result.scalar_one_or_none() - - if config: - return json.loads(config.kb_ids) - - # 3. 无配置 - return [] - - async def set_session_kb_ids( - self, - scope: str, - scope_id: str, - kb_ids: list[str], - top_k: Optional[int] = None, - enable_rerank: Optional[bool] = None, - ) -> KBSessionConfig: - """设置会话知识库配置 - - Args: - scope: 配置范围 (session/platform) - scope_id: 范围标识 (会话 ID 或平台 ID) - kb_ids: 知识库 ID 列表 - top_k: 返回结果数量 (可选) - enable_rerank: 是否启用 Rerank (可选) - """ - async with self.db.get_db() as session: - # 查找现有配置 - stmt = select(KBSessionConfig).where( - KBSessionConfig.scope == scope, - KBSessionConfig.scope_id == scope_id, - ) - result = await session.execute(stmt) - config = result.scalar_one_or_none() - - if config: - # 更新现有配置 - config.kb_ids = json.dumps(kb_ids) - if top_k is not None: - config.top_k = top_k - if enable_rerank is not None: - config.enable_rerank = enable_rerank - else: - # 创建新配置 - config = KBSessionConfig( - scope=scope, - scope_id=scope_id, - kb_ids=json.dumps(kb_ids), - top_k=top_k, - enable_rerank=enable_rerank, - ) - session.add(config) - - await session.commit() - await session.refresh(config) - return config - - async def delete_session_kb_config(self, scope: str, scope_id: str) -> bool: - """删除会话知识库配置""" - async with self.db.get_db() as session: - stmt = select(KBSessionConfig).where( - KBSessionConfig.scope == scope, - KBSessionConfig.scope_id == scope_id, - ) - result = await session.execute(stmt) - config = result.scalar_one_or_none() - - if not config: - return False - - await session.delete(config) - await session.commit() - return True - - async def list_all_session_configs( - self, offset: int = 0, limit: int = 100, scope: Optional[str] = None - ) -> list[KBSessionConfig]: - """列出所有会话配置""" - async with self.db.get_db() as session: - stmt = select(KBSessionConfig) - - if scope: - stmt = stmt.where(KBSessionConfig.scope == scope) - - stmt = ( - stmt.offset(offset) - .limit(limit) - .order_by(KBSessionConfig.created_at.desc()) - ) - - result = await session.execute(stmt) - return list(result.scalars().all()) diff --git a/astrbot/core/knowledge_base/vec_db_factory.py b/astrbot/core/knowledge_base/vec_db_factory.py new file mode 100644 index 000000000..ba2187f4d --- /dev/null +++ b/astrbot/core/knowledge_base/vec_db_factory.py @@ -0,0 +1,161 @@ +"""向量数据库工厂 + +负责为每个知识库创建和管理独立的向量数据库实例。 + +架构说明: +- 每个知识库拥有独立的向量数据库实例 +- 向量数据库文件以 kb_id 命名 +- 工厂类负责实例的创建、缓存和生命周期管理 +""" + +from pathlib import Path +from typing import Dict, Optional + +from astrbot.core import logger +from astrbot.core.db.vec_db.base import BaseVecDB +from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB +from astrbot.core.provider.provider import EmbeddingProvider + + +class VecDBFactory: + """向量数据库工厂 + + 职责: + - 为每个知识库创建独立的向量数据库实例 + - 缓存已创建的实例以提高性能 + - 管理向量数据库的生命周期 + """ + + def __init__( + self, + storage_base_path: str, + ): + """初始化向量数据库工厂 + + Args: + storage_base_path: 向量数据库存储基础路径 + """ + self.storage_base_path = Path(storage_base_path) + self._instances: Dict[str, BaseVecDB] = {} + + # 确保基础路径存在 + self.storage_base_path.mkdir(parents=True, exist_ok=True) + + async def get_vec_db( + self, kb_id: str, embedding_provider: EmbeddingProvider + ) -> BaseVecDB: + """获取或创建指定知识库的向量数据库实例 + + Args: + kb_id: 知识库 ID + embedding_provider: Embedding Provider 实例 + + Returns: + BaseVecDB: 向量数据库实例 + """ + # 如果已经创建过,直接返回缓存的实例 + if kb_id in self._instances: + return self._instances[kb_id] + + # 创建新实例 + vec_db = await self._create_vec_db(kb_id, embedding_provider) + self._instances[kb_id] = vec_db + + logger.debug(f"创建知识库 {kb_id} 的向量数据库实例") + + return vec_db + + async def _create_vec_db( + self, kb_id: str, embedding_provider: EmbeddingProvider + ) -> BaseVecDB: + """创建向量数据库实例 + + Args: + kb_id: 知识库 ID + embedding_provider: Embedding Provider 实例 + + Returns: + BaseVecDB: 向量数据库实例 + """ + # 为每个知识库创建独立的存储路径 + kb_storage_path = self.storage_base_path / kb_id + kb_storage_path.mkdir(parents=True, exist_ok=True) + + doc_store_path = str(kb_storage_path / "documents.db") + index_store_path = str(kb_storage_path / "index.faiss") + + vec_db = FaissVecDB( + doc_store_path=doc_store_path, + index_store_path=index_store_path, + embedding_provider=embedding_provider, + ) + + await vec_db.initialize() + + return vec_db + + async def delete_vec_db(self, kb_id: str) -> bool: + """删除指定知识库的向量数据库 + + Args: + kb_id: 知识库 ID + + Returns: + bool: 是否删除成功 + """ + # 关闭并移除缓存的实例 + if kb_id in self._instances: + try: + await self._instances[kb_id].close() + except Exception as e: + logger.warning(f"关闭向量数据库失败 ({kb_id}): {e}") + + del self._instances[kb_id] + + # 删除文件系统中的向量数据库文件 + kb_storage_path = self.storage_base_path / kb_id + if kb_storage_path.exists(): + try: + import shutil + + shutil.rmtree(kb_storage_path) + logger.info(f"已删除知识库 {kb_id} 的向量数据库文件") + return True + except Exception as e: + logger.error(f"删除向量数据库文件失败 ({kb_id}): {e}") + return False + + return True + + async def close_all(self): + """关闭所有向量数据库实例""" + for kb_id, vec_db in list(self._instances.items()): + try: + await vec_db.close() + logger.debug(f"已关闭知识库 {kb_id} 的向量数据库") + except Exception as e: + logger.warning(f"关闭向量数据库失败 ({kb_id}): {e}") + + self._instances.clear() + + def has_instance(self, kb_id: str) -> bool: + """检查是否已创建指定知识库的向量数据库实例 + + Args: + kb_id: 知识库 ID + + Returns: + bool: 是否已创建实例 + """ + return kb_id in self._instances + + def get_cached_instance(self, kb_id: str) -> Optional[BaseVecDB]: + """获取已缓存的向量数据库实例(不创建新实例) + + Args: + kb_id: 知识库 ID + + Returns: + Optional[BaseVecDB]: 向量数据库实例,如果不存在则返回 None + """ + return self._instances.get(kb_id) diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index 53b10d59a..e4736a245 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -1280,7 +1280,13 @@ class KnowledgeBaseRoute(Route): return ( Response() - .ok({"sessions": session_list, "total": len(session_list), "kb_id": kb_id}) + .ok( + { + "sessions": session_list, + "total": len(session_list), + "kb_id": kb_id, + } + ) .__dict__ )