This commit is contained in:
Soulter
2025-10-23 00:31:15 +08:00
parent 65bc5efa19
commit e3aa1315ae
12 changed files with 381 additions and 655 deletions
+1 -1
View File
@@ -113,7 +113,7 @@ class AstrBotCoreLifecycle:
# 初始化知识库管理器
self.kb_manager = KnowledgeBaseManager(
self.astrbot_config, self.db, self.provider_manager
self.astrbot_config, self.provider_manager
)
# 初始化提供给插件的上下文
+14 -2
View File
@@ -16,14 +16,23 @@ class BaseVecDB:
pass
@abc.abstractmethod
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
async def insert(
self, content: str, metadata: dict | None = None, id: str | None = None
) -> int:
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
...
@abc.abstractmethod
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
async def retrieve(
self,
query: str,
top_k: int = 5,
fetch_k: int = 20,
rerank: bool = False,
metadata_filters: dict | None = None,
) -> list[Result]:
"""
搜索最相似的文档。
Args:
@@ -44,3 +53,6 @@ class BaseVecDB:
bool: 删除是否成功
"""
...
@abc.abstractmethod
async def close(self): ...
-164
View File
@@ -7,7 +7,6 @@
- 会话配置也存储在此数据库中,会话ID来源于主数据库
"""
import json
from typing import Optional
from sqlalchemy import func, select
@@ -17,7 +16,6 @@ from astrbot.core.knowledge_base.models import (
KBChunk,
KBDocument,
KBMedia,
KBSessionConfig,
KnowledgeBase,
)
@@ -183,165 +181,3 @@ class KBDatabase:
stmt = select(KBMedia).where(KBMedia.media_id == media_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
# ===== 会话配置查询 =====
async def get_session_kb_ids(self, session_id: str) -> list[str]:
"""获取会话关联的知识库 ID 列表
查找顺序:
1. 会话级别配置 (优先)
2. 平台级别配置
3. 返回空列表
Args:
session_id: 会话ID(来自主数据库)
Returns:
知识库ID列表
"""
async with self.db.get_db() as session:
# 1. 查找会话级别配置
stmt = select(KBSessionConfig).where(
KBSessionConfig.scope == "session",
KBSessionConfig.scope_id == session_id,
)
result = await session.execute(stmt)
config = result.scalar_one_or_none()
if config:
return json.loads(config.kb_ids)
# 2. 提取平台 ID (格式: platform:xxx:session_id)
parts = session_id.split(":")
if len(parts) >= 2:
platform_id = parts[0]
# 查找平台级别配置
stmt = select(KBSessionConfig).where(
KBSessionConfig.scope == "platform",
KBSessionConfig.scope_id == platform_id,
)
result = await session.execute(stmt)
config = result.scalar_one_or_none()
if config:
return json.loads(config.kb_ids)
# 3. 无配置
return []
async def set_session_kb_ids(
self,
scope: str,
scope_id: str,
kb_ids: list[str],
top_k: Optional[int] = None,
enable_rerank: Optional[bool] = None,
) -> KBSessionConfig:
"""设置会话知识库配置
Args:
scope: 配置范围 (session/platform)
scope_id: 范围标识 (会话 ID 或平台 ID,来自主数据库)
kb_ids: 知识库 ID 列表
top_k: 返回结果数量 (可选)
enable_rerank: 是否启用 Rerank (可选)
Returns:
配置对象
"""
async with self.db.get_db() as session:
# 查找现有配置
stmt = select(KBSessionConfig).where(
KBSessionConfig.scope == scope,
KBSessionConfig.scope_id == scope_id,
)
result = await session.execute(stmt)
config = result.scalar_one_or_none()
if config:
# 更新现有配置
config.kb_ids = json.dumps(kb_ids)
if top_k is not None:
config.top_k = top_k
if enable_rerank is not None:
config.enable_rerank = enable_rerank
else:
# 创建新配置
config = KBSessionConfig(
scope=scope,
scope_id=scope_id,
kb_ids=json.dumps(kb_ids),
top_k=top_k,
enable_rerank=enable_rerank,
)
session.add(config)
await session.commit()
await session.refresh(config)
return config
async def delete_session_kb_config(self, scope: str, scope_id: str) -> bool:
"""删除会话知识库配置
Args:
scope: 配置范围 (session/platform)
scope_id: 范围标识 (会话 ID 或平台 ID)
Returns:
是否删除成功
"""
async with self.db.get_db() as session:
stmt = select(KBSessionConfig).where(
KBSessionConfig.scope == scope,
KBSessionConfig.scope_id == scope_id,
)
result = await session.execute(stmt)
config = result.scalar_one_or_none()
if not config:
return False
await session.delete(config)
await session.commit()
return True
async def delete_session_kb_config_by_session_id(self, session_id: str) -> bool:
"""根据会话ID删除会话配置(用于主数据库会话删除时的级联清理)
Args:
session_id: 会话ID(来自主数据库)
Returns:
是否删除成功
"""
return await self.delete_session_kb_config("session", session_id)
async def list_all_session_configs(
self, offset: int = 0, limit: int = 100, scope: Optional[str] = None
) -> list[KBSessionConfig]:
"""列出所有会话配置
Args:
offset: 偏移量
limit: 限制数量
scope: 可选的范围过滤 (session/platform)
Returns:
会话配置列表
"""
async with self.db.get_db() as session:
stmt = select(KBSessionConfig)
if scope:
stmt = stmt.where(KBSessionConfig.scope == scope)
stmt = (
stmt.offset(offset)
.limit(limit)
.order_by(KBSessionConfig.created_at.desc())
)
result = await session.execute(stmt)
return list(result.scalars().all())
+5 -8
View File
@@ -10,6 +10,7 @@ from astrbot.core.knowledge_base.retrieval.manager import (
RetrievalManager,
RetrievalResult,
)
from .vec_db_factory import VecDBFactory
class KnowledgeBaseInjector:
@@ -24,6 +25,7 @@ class KnowledgeBaseInjector:
def __init__(
self,
kb_db: KBDatabase,
vec_db_factory: VecDBFactory,
retrieval_manager: RetrievalManager,
):
"""初始化知识库上下文注入器
@@ -33,18 +35,18 @@ class KnowledgeBaseInjector:
retrieval_manager: 检索管理器实例
"""
self.kb_db = kb_db
self.vec_db_factory = vec_db_factory
self.retrieval_manager = retrieval_manager
async def retrieve_and_inject(
self,
unified_msg_origin: str,
kb_ids: list[str],
query: str,
top_k: int = 5,
) -> Optional[dict]:
"""检索并注入知识库上下文
Args:
unified_msg_origin: 统一消息来源 ID (会话 ID)
query: 用户查询
top_k: 返回结果数量
@@ -55,14 +57,9 @@ class KnowledgeBaseInjector:
"results": List[dict], # 原始检索结果列表
}
"""
# 1. 获取会话关联的知识库
kb_ids = await self.kb_db.get_session_kb_ids(unified_msg_origin)
if not kb_ids:
return None
# 2. 检索知识
results = await self.retrieval_manager.retrieve(
vec_db_factory=self.vec_db_factory,
query=query,
kb_ids=kb_ids,
top_m_final=top_k,
@@ -9,8 +9,18 @@
from pathlib import Path
from astrbot.core import logger
from astrbot.core.db import BaseDatabase
from astrbot.core.provider.manager import ProviderManager
from .injector import KnowledgeBaseInjector
from .retrieval.manager import RetrievalManager
from .retrieval.sparse_retriever import SparseRetriever
from .retrieval.rank_fusion import RankFusion
from .kb_sqlite import KBSQLiteDatabase
from .database import KBDatabase
from .vec_db_factory import VecDBFactory
from .manager import KBManager
from .parsers.text_parser import TextParser
from .parsers.pdf_parser import PDFParser
from .chunking.fixed_size import FixedSizeChunker
class KnowledgeBaseManager:
@@ -28,32 +38,26 @@ class KnowledgeBaseManager:
- 通过回调机制实现与主数据库的生命周期同步
"""
kb_db: KBSQLiteDatabase
vec_db_factory: VecDBFactory
kb_database: KBDatabase
kb_manager: KBManager
retrieval_manager: RetrievalManager
kb_injector: KnowledgeBaseInjector
def __init__(
self,
config: dict,
main_db: BaseDatabase,
provider_manager: ProviderManager,
):
"""初始化知识库管理器
Args:
config: 配置字典
main_db: 主数据库实例 (不直接使用,仅用于类型引用)
provider_manager: Provider 管理器
"""
self.config = config.get("knowledge_base", {})
self.provider_manager = provider_manager
# 知识库独立数据库
self.kb_db = None
# 组件实例
self.kb_database = None
self.kb_manager = None
self.kb_vec_db = None
self.retrieval_manager = None
self.kb_injector = None
self._initialized = False
self._session_deleted_callback_registered = False
@@ -66,31 +70,54 @@ class KnowledgeBaseManager:
try:
logger.info("正在初始化知识库模块...")
# 1. 检查并选择 Embedding Provider
embedding_provider = self._select_embedding_provider()
if not embedding_provider:
logger.warning("未配置 Embedding Provider,知识库功能无法使用")
return
# 2. 初始化数据库
# 初始化数据库
await self._init_kb_database()
await self._init_database()
# 3. 初始化向量数据库
await self._init_vector_db(embedding_provider)
# 初始化向量数据库工厂
await self._init_vector_db_factory()
# 4. 初始化解析器和分块器
parsers = self._init_parsers()
chunker = self._init_chunker()
# 初始化解析器和分块器
parsers = {
"txt": TextParser(),
"md": TextParser(),
"markdown": TextParser(),
"pdf": PDFParser(),
}
chunking_config = self.config.get("chunking", {})
chunker = FixedSizeChunker(
chunk_size=chunking_config.get("chunk_size", 512),
chunk_overlap=chunking_config.get("chunk_overlap", 50),
)
# 5. 初始化知识库管理器
await self._init_kb_manager(parsers, chunker)
# 初始化知识库管理器
files_path = self.config.get("storage", {}).get(
"files_path", "data/knowledge_base"
)
self.kb_manager = KBManager(
db=self.kb_db,
vec_db_factory=self.vec_db_factory,
storage_path=files_path,
parsers=parsers,
chunker=chunker,
provider_manager=self.provider_manager,
)
# 6. 初始化检索管理器
await self._init_retrieval_manager()
# 初始化检索管理器
sparse_retriever = SparseRetriever(self.kb_database)
rank_fusion = RankFusion(self.kb_database)
self.retrieval_manager = RetrievalManager(
vec_db_factory=self.vec_db_factory,
sparse_retriever=sparse_retriever,
rank_fusion=rank_fusion,
kb_db=self.kb_database,
)
# 7. 初始化上下文注入器
await self._init_injector()
# 初始化上下文注入器
self.kb_injector = KnowledgeBaseInjector(
kb_db=self.kb_database,
vec_db_factory=self.vec_db_factory,
retrieval_manager=self.retrieval_manager,
)
self._initialized = True
logger.info("知识库模块初始化完成")
@@ -106,8 +133,6 @@ class KnowledgeBaseManager:
async def _init_kb_database(self):
"""初始化知识库独立数据库"""
from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
db_path = self.config.get("storage", {}).get(
"kb_db_path", "data/knowledge_base/kb.db"
)
@@ -116,168 +141,16 @@ class KnowledgeBaseManager:
self.kb_db = KBSQLiteDatabase(db_path)
await self.kb_db.initialize()
await self.kb_db.migrate_to_v1()
logger.info(f"知识库独立数据库已初始化: {db_path}")
async def _init_database(self):
"""初始化知识库数据库操作类"""
from astrbot.core.knowledge_base.database import KBDatabase
self.kb_database = KBDatabase(self.kb_db)
logger.info(f"KnowledgeBase database initialized: {db_path}")
async def _init_vector_db(self, embedding_provider):
"""初始化向量数据库"""
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
async def _init_vector_db_factory(self):
"""初始化向量数据库工厂"""
storage_path = self.config.get("storage", {}).get(
"vector_db_path", "data/knowledge_base/vectors"
)
Path(storage_path).mkdir(parents=True, exist_ok=True)
self.kb_vec_db = FaissVecDB(
doc_store_path=f"{storage_path}/documents.db",
index_store_path=f"{storage_path}/index.faiss",
embedding_provider=embedding_provider,
)
await self.kb_vec_db.initialize()
def _init_parsers(self) -> dict:
"""初始化文档解析器"""
from astrbot.core.knowledge_base.parsers.text_parser import TextParser
from astrbot.core.knowledge_base.parsers.pdf_parser import PDFParser
return {
"txt": TextParser(),
"md": TextParser(),
"markdown": TextParser(),
"pdf": PDFParser(),
}
def _init_chunker(self):
"""初始化分块器"""
from astrbot.core.knowledge_base.chunking.fixed_size import FixedSizeChunker
chunking_config = self.config.get("chunking", {})
return FixedSizeChunker(
chunk_size=chunking_config.get("chunk_size", 512),
chunk_overlap=chunking_config.get("chunk_overlap", 50),
)
async def _init_kb_manager(self, parsers: dict, chunker):
"""初始化知识库管理器"""
from astrbot.core.knowledge_base.manager import KBManager
files_path = self.config.get("storage", {}).get(
"files_path", "data/knowledge_base"
)
self.kb_manager = KBManager(
db=self.kb_db, # 使用独立的知识库数据库
vec_db=self.kb_vec_db,
storage_path=files_path,
parsers=parsers,
chunker=chunker,
provider_manager=self.provider_manager,
)
async def _init_retrieval_manager(self):
"""初始化检索管理器"""
from astrbot.core.knowledge_base.retrieval.manager import RetrievalManager
from astrbot.core.knowledge_base.retrieval.sparse_retriever import (
SparseRetriever,
)
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
sparse_retriever = SparseRetriever(self.kb_database)
rank_fusion = RankFusion(self.kb_database)
# 选择 Rerank Provider (可选)
rerank_provider = self._select_rerank_provider()
self.retrieval_manager = RetrievalManager(
vec_db=self.kb_vec_db,
sparse_retriever=sparse_retriever,
rank_fusion=rank_fusion,
kb_db=self.kb_database,
rerank_provider=rerank_provider,
)
async def _init_injector(self):
"""初始化上下文注入器"""
from astrbot.core.knowledge_base.injector import KnowledgeBaseInjector
self.kb_injector = KnowledgeBaseInjector(
kb_db=self.kb_database,
retrieval_manager=self.retrieval_manager,
)
def _select_embedding_provider(self):
"""选择 Embedding Provider
逻辑:
- 如果配置了 embedding_provider_id,则使用指定的 provider
- 如果没有配置,但有 embedding provider,则使用第一个
- 如果有多个 embedding provider 但没有指定,则警告并使用第一个
"""
embedding_providers = self.provider_manager.embedding_provider_insts
if not embedding_providers:
return None
configured_provider_id = self.config.get("embedding_provider_id")
if configured_provider_id:
# 按 ID 查找
for provider in embedding_providers:
provider_id = provider.meta().id
if provider_id == configured_provider_id:
logger.info(f"知识库使用 Embedding Provider: {provider_id}")
return provider
logger.warning(
f"未找到配置的 Embedding Provider ID: {configured_provider_id}"
f"将使用第一个可用的"
)
if len(embedding_providers) > 1 and not configured_provider_id:
provider = embedding_providers[0]
provider_id = provider.meta().id
logger.info(
f"检测到 {len(embedding_providers)} 个 Embedding Provider"
f"未在配置文件中指定 embedding_provider_id,将使用第一个: {provider_id}"
)
return provider
provider = embedding_providers[0]
provider_id = provider.meta().id
logger.info(f"知识库使用 Embedding Provider: {provider_id}")
return provider
def _select_rerank_provider(self):
"""选择 Rerank Provider (可选)"""
if not self.config.get("retrieval", {}).get("enable_rerank", True):
return None
rerank_providers = self.provider_manager.rerank_provider_insts
if not rerank_providers:
return None
configured_provider_id = self.config.get("rerank_provider_id")
if configured_provider_id:
for provider in rerank_providers:
provider_id = provider.meta().id
if provider_id == configured_provider_id:
logger.info(f"知识库使用 Rerank Provider: {provider_id}")
return provider
logger.warning(f"未找到配置的 Rerank Provider ID: {configured_provider_id}")
if len(rerank_providers) > 0:
provider = rerank_providers[0]
provider_id = provider.meta().id
logger.info(f"知识库使用 Rerank Provider: {provider_id}")
return provider
return None
self.vec_db_factory = VecDBFactory(storage_base_path=storage_path)
@property
def is_initialized(self) -> bool:
@@ -292,31 +165,6 @@ class KnowledgeBaseManager:
"""获取知识库上下文注入器"""
return self.kb_injector if self._initialized else None
def register_session_lifecycle_hooks(self, conversation_manager):
"""注册会话生命周期钩子
在会话删除时自动清理知识库配置,实现零侵入的级联清理。
Args:
conversation_manager: 会话管理器实例
"""
if self._session_deleted_callback_registered or not self._initialized:
return
async def on_session_deleted(session_id: str):
"""会话删除回调:清理知识库配置"""
try:
await self.kb_database.delete_session_kb_config_by_session_id(
session_id
)
logger.info(f"已清理会话知识库配置: {session_id}")
except Exception as e:
logger.error(f"清理会话知识库配置失败 ({session_id}): {e}")
conversation_manager.register_on_session_deleted(on_session_deleted)
self._session_deleted_callback_registered = True
logger.info("已注册知识库会话删除回调")
async def reinitialize(self):
"""重新初始化知识库模块
@@ -336,13 +184,13 @@ class KnowledgeBaseManager:
logger.info("正在终止知识库模块...")
# 关闭向量数据库连接
if self.kb_vec_db:
# 关闭向量数据库工厂(关闭所有向量数据库实例)
if self.vec_db_factory:
try:
await self.kb_vec_db.close()
logger.debug("向量数据库已关闭")
await self.vec_db_factory.close_all()
logger.debug("向量数据库工厂已关闭")
except Exception as e:
logger.warning(f"关闭向量数据库时出错: {e}")
logger.warning(f"关闭向量数据库工厂时出错: {e}")
# 关闭知识库独立数据库连接
if self.kb_db:
@@ -352,13 +200,6 @@ class KnowledgeBaseManager:
except Exception as e:
logger.warning(f"关闭知识库数据库时出错: {e}")
# 清理资源
self._initialized = False
self.kb_db = None
self.kb_database = None
self.kb_manager = None
self.kb_vec_db = None
self.retrieval_manager = None
self.kb_injector = None
logger.info("知识库模块已终止")
+73 -18
View File
@@ -10,12 +10,11 @@ from typing import Optional
import aiofiles
from sqlalchemy import func, select, update
from astrbot.core.db import BaseDatabase
from astrbot.core.db.vec_db.base import BaseVecDB
from .kb_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.chunking.base import BaseChunker
from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KnowledgeBase
from astrbot.core.knowledge_base.parsers.base import BaseParser
from .vec_db_factory import VecDBFactory
class KBManager:
"""知识库管理器
@@ -29,15 +28,15 @@ class KBManager:
def __init__(
self,
db: BaseDatabase,
vec_db: BaseVecDB,
db: KBSQLiteDatabase,
vec_db_factory: VecDBFactory,
storage_path: str,
parsers: dict[str, BaseParser],
chunker: BaseChunker,
provider_manager=None,
):
self.db = db
self.vec_db = vec_db
self.vec_db_factory = vec_db_factory
self.storage_path = Path(storage_path)
self.media_path = self.storage_path / "media"
self.files_path = self.storage_path / "files"
@@ -49,6 +48,48 @@ class KBManager:
self.media_path.mkdir(parents=True, exist_ok=True)
self.files_path.mkdir(parents=True, exist_ok=True)
async def _get_embedding_provider_for_kb(self, kb_id: str):
"""根据知识库配置获取 Embedding Provider
Args:
kb_id: 知识库 ID
Returns:
EmbeddingProvider: Embedding Provider 实例
Raises:
ValueError: 如果找不到合适的 embedding provider
"""
from astrbot.core.knowledge_base.database import KBDatabase
# 获取知识库配置
kb_database = KBDatabase(self.db)
kb = await kb_database.get_kb_by_id(kb_id)
if not kb:
raise ValueError(f"知识库不存在: {kb_id}")
embedding_provider_id = kb.embedding_provider_id
# 如果没有 provider_manager,使用默认的第一个
if not self.provider_manager:
raise ValueError("Provider Manager 未初始化")
embedding_providers = self.provider_manager.embedding_provider_insts
if not embedding_providers:
raise ValueError("系统中没有可用的 Embedding Provider")
# 如果指定了 provider ID,则查找该 provider
if embedding_provider_id:
for provider in embedding_providers:
if provider.meta().id == embedding_provider_id:
return provider
raise ValueError(
f"未找到配置的 Embedding Provider: {embedding_provider_id}"
)
# 使用第一个可用的 provider
return embedding_providers[0]
# ===== 知识库操作 =====
async def create_kb(
@@ -77,7 +118,7 @@ class KBManager:
# 检查是否有可用的 rerank provider
has_rerank_provider = (
self.provider_manager
and hasattr(self.provider_manager, 'rerank_provider_insts')
and hasattr(self.provider_manager, "rerank_provider_insts")
and len(self.provider_manager.rerank_provider_insts) > 0
)
enable_rerank = has_rerank_provider
@@ -182,7 +223,10 @@ class KBManager:
for doc in docs:
await ops.delete_document(doc.doc_id)
# 3. 删除知识库记录
# 3. 删除向量数据库
await self.vec_db_factory.delete_vec_db(kb_id)
# 4. 删除知识库记录
async with self.db.get_db() as session:
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
result = await session.execute(stmt)
@@ -257,11 +301,15 @@ class KBManager:
# 4. 文档分块
chunks_text = await self.chunker.chunk(text_content)
# 5. 生成向量并存储
# 5. 获取 Embedding Provider 和向量数据库
embedding_provider = await self._get_embedding_provider_for_kb(kb_id)
vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider)
# 6. 生成向量并存储
saved_chunks = []
for idx, chunk_text in enumerate(chunks_text):
# 存储到向量数据库
vec_doc_id = await self.vec_db.insert(
vec_doc_id = await vec_db.insert(
content=chunk_text,
metadata={
"kb_id": kb_id,
@@ -282,7 +330,7 @@ class KBManager:
)
saved_chunks.append(chunk)
# 6. 保存文档元数据(事务)
# 7. 保存文档元数据(事务)
doc = KBDocument(
doc_id=doc_id,
kb_id=kb_id,
@@ -305,7 +353,7 @@ class KBManager:
await session.refresh(doc)
# 7. 更新知识库统计
# 8. 更新知识库统计
await self._update_kb_stats(kb_id)
return doc
@@ -316,12 +364,19 @@ class KBManager:
logger.error(f"文档上传失败,开始清理资源: {e}")
# 清理向量数据库
for vec_id in vec_doc_ids:
try:
await self.vec_db.delete(vec_id)
except Exception as ve:
logger.warning(f"清理向量失败 {vec_id}: {ve}")
# 获取知识库的向量数据库
try:
embedding_provider = await self._get_embedding_provider_for_kb(kb_id)
vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider)
# 清理向量数据库
for vec_id in vec_doc_ids:
try:
await vec_db.delete(vec_id)
except Exception as ve:
logger.warning(f"清理向量失败 {vec_id}: {ve}")
except Exception as vfe:
logger.error(f"获取向量数据库失败: {vfe}")
# 清理多媒体文件
for media_path in media_paths:
+15 -4
View File
@@ -28,7 +28,7 @@ class KBManagerOps:
def __init__(self, manager: "KBManager"):
self.manager = manager
self.db = manager.db
self.vec_db = manager.vec_db
self.vec_db_factory = manager.vec_db_factory
self.media_path = manager.media_path
self.files_path = manager.files_path
@@ -75,6 +75,12 @@ class KBManagerOps:
chunks = await self.list_chunks(doc_id)
media_list = await self.list_media(doc_id)
# 获取知识库的向量数据库
embedding_provider = await self.manager._get_embedding_provider_for_kb(
doc.kb_id
)
vec_db = await self.vec_db_factory.get_vec_db(doc.kb_id, embedding_provider)
# ===== 第一阶段: 删除向量(可重试) =====
vec_ids_to_delete = [chunk.vec_doc_id for chunk in chunks]
deleted_vec_ids = []
@@ -82,7 +88,7 @@ class KBManagerOps:
for vec_id in vec_ids_to_delete:
try:
await self.vec_db.delete(vec_id)
await vec_db.delete(vec_id)
deleted_vec_ids.append(vec_id)
except Exception as e:
logger.error(f"删除向量失败: {vec_id}, {e}")
@@ -173,11 +179,16 @@ class KBManagerOps:
return False
doc_id = chunk.doc_id
kb_id = chunk.kb_id
vec_doc_id = chunk.vec_doc_id
# 2. 删除向量
# 2. 获取知识库的向量数据库并删除向量
try:
await self.vec_db.delete(vec_doc_id)
embedding_provider = await self.manager._get_embedding_provider_for_kb(
kb_id
)
vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider)
await vec_db.delete(vec_doc_id)
except Exception as e:
logger.error(f"删除向量失败: {vec_doc_id}, {e}")
return False
+5 -41
View File
@@ -16,7 +16,7 @@ import uuid
from datetime import datetime, timezone
from typing import Optional
from sqlmodel import Field, SQLModel, Text, UniqueConstraint
from sqlmodel import Field, SQLModel, Text
class KnowledgeBase(SQLModel, table=True):
@@ -25,7 +25,7 @@ class KnowledgeBase(SQLModel, table=True):
存储知识库的基本信息和统计数据。
"""
__tablename__ = "knowledge_bases"
__tablename__ = "knowledge_bases" # type: ignore
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
@@ -65,7 +65,7 @@ class KBDocument(SQLModel, table=True):
存储上传到知识库的文档元数据。
"""
__tablename__ = "kb_documents"
__tablename__ = "kb_documents" # type: ignore
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
@@ -97,7 +97,7 @@ class KBChunk(SQLModel, table=True):
存储文档分块后的文本内容和向量索引关联信息。
"""
__tablename__ = "kb_chunks"
__tablename__ = "kb_chunks" # type: ignore
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
@@ -124,7 +124,7 @@ class KBMedia(SQLModel, table=True):
存储从文档中提取的图片、视频等多媒体资源。
"""
__tablename__ = "kb_media"
__tablename__ = "kb_media" # type: ignore
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
@@ -144,39 +144,3 @@ class KBMedia(SQLModel, table=True):
file_size: int = Field(nullable=False)
mime_type: str = Field(max_length=100, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
class KBSessionConfig(SQLModel, table=True):
"""会话知识库配置表
存储会话或平台级别的知识库关联配置。
该表存储在知识库独立数据库中,保持完全解耦。
支持两种配置范围:
- platform: 平台级别配置 (如 'qq', 'telegram')
- session: 会话级别配置 (如 'qq:group:12345')
"""
__tablename__ = "kb_session_config"
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
config_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
)
scope: str = Field(max_length=20, nullable=False)
scope_id: str = Field(max_length=255, nullable=False, index=True)
kb_ids: str = Field(sa_type=Text, nullable=False)
top_k: Optional[int] = Field(default=None, nullable=True)
enable_rerank: Optional[bool] = Field(default=None, nullable=True)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (UniqueConstraint("scope", "scope_id", name="uix_scope_scope_id"),)
@@ -3,16 +3,15 @@
协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口
"""
import json
from dataclasses import dataclass
from typing import List, Optional
from typing import List
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.knowledge_base.database import KBDatabase
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
from astrbot.core.provider.provider import RerankProvider
from astrbot.core.db.vec_db.base import BaseVecDB, Result
from ..vec_db_factory import VecDBFactory
@dataclass
class RetrievalResult:
@@ -38,36 +37,34 @@ class RetrievalManager:
def __init__(
self,
vec_db: BaseVecDB,
vec_db_factory, # VecDBFactory
sparse_retriever: SparseRetriever,
rank_fusion: RankFusion,
kb_db: KBDatabase,
rerank_provider: Optional[RerankProvider] = None,
):
"""初始化检索管理器
Args:
vec_db: 向量数据库实例
vec_db_factory: 向量数据库工厂
sparse_retriever: 稀疏检索器
rank_fusion: 结果融合器
kb_db: 知识库数据库实例
rerank_provider: Rerank 提供商 (可选)
"""
self.vec_db = vec_db
self.vec_db_factory = vec_db_factory
self.sparse_retriever = sparse_retriever
self.rank_fusion = rank_fusion
self.kb_db = kb_db
self.rerank_provider = rerank_provider
async def retrieve(
self,
vec_db_factory: VecDBFactory,
query: str,
kb_ids: List[str],
top_k_dense: int = 50,
top_k_sparse: int = 50,
top_n_fusion: int = 20,
top_m_final: int = 5,
enable_rerank: bool = True,
rerank_provider: RerankProvider | None = None,
) -> List[RetrievalResult]:
"""混合检索
@@ -94,6 +91,7 @@ class RetrievalManager:
query=query,
kb_ids=kb_ids,
top_k=top_k_dense,
vec_db=vec_db,
)
# 2. 稀疏检索
@@ -131,13 +129,13 @@ class RetrievalManager:
)
)
# 5. Rerank (可选)
if enable_rerank and self.rerank_provider and retrieval_results:
# 5. Rerank
if rerank_provider and retrieval_results:
retrieval_results = await self._rerank(
query=query,
results=retrieval_results,
top_k=top_m_final,
rerank_provider=self.rerank_provider,
rerank_provider=rerank_provider,
)
else:
retrieval_results = retrieval_results[:top_m_final]
@@ -149,9 +147,12 @@ class RetrievalManager:
query: str,
kb_ids: List[str],
top_k: int,
vec_db: BaseVecDB,
):
"""稠密检索 (向量相似度)
为每个知识库使用独立的向量数据库进行检索,然后合并结果。
Args:
query: 查询文本
kb_ids: 知识库 ID 列表
@@ -160,28 +161,27 @@ class RetrievalManager:
Returns:
List[Result]: 检索结果列表
"""
# 直接调用向量数据库检索
vec_results = await self.vec_db.retrieve(
query=query,
top_k=top_k * len(kb_ids) * 2, # 增加候选数量以便过滤
)
all_results: list[Result] = []
# 过滤:只保留指定知识库的结果
filtered_results = []
for result in vec_results:
metadata_str = result.data.get("metadata", "{}")
for kb_id in kb_ids:
try:
metadata = json.loads(metadata_str)
except (json.JSONDecodeError, TypeError):
metadata = {}
vec_results = await vec_db.retrieve(
query=query,
top_k=top_k,
fetch_k=top_k * 2,
metadata_filters={"kb_id": kb_id},
)
if metadata.get("kb_id") in kb_ids:
filtered_results.append(result)
all_results.extend(vec_results)
except Exception as e:
from astrbot.core import logger
if len(filtered_results) >= top_k:
break
logger.warning(f"知识库 {kb_id} 稠密检索失败: {e}")
continue
return filtered_results[:top_k]
# 按相似度排序并返回 top_k
all_results.sort(key=lambda x: x.similarity, reverse=True)
return all_results[:top_k]
async def _rerank(
self,
@@ -1,157 +0,0 @@
"""会话知识库配置数据库操作
该模块封装会话知识库配置的数据库查询操作。
注意: 会话配置表 (kb_session_config) 存储在知识库独立数据库 (kb.db) 中,
而不是主数据库 (astrbot.db) 中,以实现完全解耦。
"""
import json
from typing import Optional
from sqlalchemy import select
from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.models import KBSessionConfig
class SessionConfigDB:
"""会话知识库配置数据库操作类
职责:
- 提供会话知识库配置管理
- 统一异常处理
注意: 该类操作知识库独立数据库,实现完全解耦
"""
def __init__(self, db: KBSQLiteDatabase):
"""初始化会话配置数据库操作类
Args:
db: 知识库独立数据库实例 (kb.db),不是主数据库
"""
self.db = db
async def get_session_kb_ids(self, session_id: str) -> list[str]:
"""获取会话关联的知识库 ID 列表
查找顺序:
1. 会话级别配置 (优先)
2. 平台级别配置
3. 返回空列表
"""
async with self.db.get_db() as session:
# 1. 查找会话级别配置
stmt = select(KBSessionConfig).where(
KBSessionConfig.scope == "session",
KBSessionConfig.scope_id == session_id,
)
result = await session.execute(stmt)
config = result.scalar_one_or_none()
if config:
return json.loads(config.kb_ids)
# 2. 提取平台 ID (格式: platform:xxx:session_id)
parts = session_id.split(":")
if len(parts) >= 2:
platform_id = parts[0]
# 查找平台级别配置
stmt = select(KBSessionConfig).where(
KBSessionConfig.scope == "platform",
KBSessionConfig.scope_id == platform_id,
)
result = await session.execute(stmt)
config = result.scalar_one_or_none()
if config:
return json.loads(config.kb_ids)
# 3. 无配置
return []
async def set_session_kb_ids(
self,
scope: str,
scope_id: str,
kb_ids: list[str],
top_k: Optional[int] = None,
enable_rerank: Optional[bool] = None,
) -> KBSessionConfig:
"""设置会话知识库配置
Args:
scope: 配置范围 (session/platform)
scope_id: 范围标识 (会话 ID 或平台 ID)
kb_ids: 知识库 ID 列表
top_k: 返回结果数量 (可选)
enable_rerank: 是否启用 Rerank (可选)
"""
async with self.db.get_db() as session:
# 查找现有配置
stmt = select(KBSessionConfig).where(
KBSessionConfig.scope == scope,
KBSessionConfig.scope_id == scope_id,
)
result = await session.execute(stmt)
config = result.scalar_one_or_none()
if config:
# 更新现有配置
config.kb_ids = json.dumps(kb_ids)
if top_k is not None:
config.top_k = top_k
if enable_rerank is not None:
config.enable_rerank = enable_rerank
else:
# 创建新配置
config = KBSessionConfig(
scope=scope,
scope_id=scope_id,
kb_ids=json.dumps(kb_ids),
top_k=top_k,
enable_rerank=enable_rerank,
)
session.add(config)
await session.commit()
await session.refresh(config)
return config
async def delete_session_kb_config(self, scope: str, scope_id: str) -> bool:
"""删除会话知识库配置"""
async with self.db.get_db() as session:
stmt = select(KBSessionConfig).where(
KBSessionConfig.scope == scope,
KBSessionConfig.scope_id == scope_id,
)
result = await session.execute(stmt)
config = result.scalar_one_or_none()
if not config:
return False
await session.delete(config)
await session.commit()
return True
async def list_all_session_configs(
self, offset: int = 0, limit: int = 100, scope: Optional[str] = None
) -> list[KBSessionConfig]:
"""列出所有会话配置"""
async with self.db.get_db() as session:
stmt = select(KBSessionConfig)
if scope:
stmt = stmt.where(KBSessionConfig.scope == scope)
stmt = (
stmt.offset(offset)
.limit(limit)
.order_by(KBSessionConfig.created_at.desc())
)
result = await session.execute(stmt)
return list(result.scalars().all())
@@ -0,0 +1,161 @@
"""向量数据库工厂
负责为每个知识库创建和管理独立的向量数据库实例。
架构说明:
- 每个知识库拥有独立的向量数据库实例
- 向量数据库文件以 kb_id 命名
- 工厂类负责实例的创建、缓存和生命周期管理
"""
from pathlib import Path
from typing import Dict, Optional
from astrbot.core import logger
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
from astrbot.core.provider.provider import EmbeddingProvider
class VecDBFactory:
"""向量数据库工厂
职责:
- 为每个知识库创建独立的向量数据库实例
- 缓存已创建的实例以提高性能
- 管理向量数据库的生命周期
"""
def __init__(
self,
storage_base_path: str,
):
"""初始化向量数据库工厂
Args:
storage_base_path: 向量数据库存储基础路径
"""
self.storage_base_path = Path(storage_base_path)
self._instances: Dict[str, BaseVecDB] = {}
# 确保基础路径存在
self.storage_base_path.mkdir(parents=True, exist_ok=True)
async def get_vec_db(
self, kb_id: str, embedding_provider: EmbeddingProvider
) -> BaseVecDB:
"""获取或创建指定知识库的向量数据库实例
Args:
kb_id: 知识库 ID
embedding_provider: Embedding Provider 实例
Returns:
BaseVecDB: 向量数据库实例
"""
# 如果已经创建过,直接返回缓存的实例
if kb_id in self._instances:
return self._instances[kb_id]
# 创建新实例
vec_db = await self._create_vec_db(kb_id, embedding_provider)
self._instances[kb_id] = vec_db
logger.debug(f"创建知识库 {kb_id} 的向量数据库实例")
return vec_db
async def _create_vec_db(
self, kb_id: str, embedding_provider: EmbeddingProvider
) -> BaseVecDB:
"""创建向量数据库实例
Args:
kb_id: 知识库 ID
embedding_provider: Embedding Provider 实例
Returns:
BaseVecDB: 向量数据库实例
"""
# 为每个知识库创建独立的存储路径
kb_storage_path = self.storage_base_path / kb_id
kb_storage_path.mkdir(parents=True, exist_ok=True)
doc_store_path = str(kb_storage_path / "documents.db")
index_store_path = str(kb_storage_path / "index.faiss")
vec_db = FaissVecDB(
doc_store_path=doc_store_path,
index_store_path=index_store_path,
embedding_provider=embedding_provider,
)
await vec_db.initialize()
return vec_db
async def delete_vec_db(self, kb_id: str) -> bool:
"""删除指定知识库的向量数据库
Args:
kb_id: 知识库 ID
Returns:
bool: 是否删除成功
"""
# 关闭并移除缓存的实例
if kb_id in self._instances:
try:
await self._instances[kb_id].close()
except Exception as e:
logger.warning(f"关闭向量数据库失败 ({kb_id}): {e}")
del self._instances[kb_id]
# 删除文件系统中的向量数据库文件
kb_storage_path = self.storage_base_path / kb_id
if kb_storage_path.exists():
try:
import shutil
shutil.rmtree(kb_storage_path)
logger.info(f"已删除知识库 {kb_id} 的向量数据库文件")
return True
except Exception as e:
logger.error(f"删除向量数据库文件失败 ({kb_id}): {e}")
return False
return True
async def close_all(self):
"""关闭所有向量数据库实例"""
for kb_id, vec_db in list(self._instances.items()):
try:
await vec_db.close()
logger.debug(f"已关闭知识库 {kb_id} 的向量数据库")
except Exception as e:
logger.warning(f"关闭向量数据库失败 ({kb_id}): {e}")
self._instances.clear()
def has_instance(self, kb_id: str) -> bool:
"""检查是否已创建指定知识库的向量数据库实例
Args:
kb_id: 知识库 ID
Returns:
bool: 是否已创建实例
"""
return kb_id in self._instances
def get_cached_instance(self, kb_id: str) -> Optional[BaseVecDB]:
"""获取已缓存的向量数据库实例(不创建新实例)
Args:
kb_id: 知识库 ID
Returns:
Optional[BaseVecDB]: 向量数据库实例,如果不存在则返回 None
"""
return self._instances.get(kb_id)
+7 -1
View File
@@ -1280,7 +1280,13 @@ class KnowledgeBaseRoute(Route):
return (
Response()
.ok({"sessions": session_list, "total": len(session_list), "kb_id": kb_id})
.ok(
{
"sessions": session_list,
"total": len(session_list),
"kb_id": kb_id,
}
)
.__dict__
)