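# Commit ad96d676e6:
# - Implement the complete knowledge base data models (knowledge base, document, document chunk, session config)
# - Implement SQLite-based vector database storage and retrieval
# - Implement document parsers (PDF, TXT) and a fixed-size chunker
# - Implement hybrid retrieval (dense vector retrieval + BM25 sparse retrieval + RRF fusion)
# - Implement knowledge base lifecycle management and a message injector
# - Support session-level knowledge base configuration and association
# File: 348 lines, 12 KiB, Python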
"""知识库数据库操作类
|
|
|
|
该模块封装知识库、文档、块、多媒体和会话配置相关的数据库查询操作。
|
|
|
|
注意:
|
|
- 该模块操作的是独立的知识库数据库 (data/knowledge_base/kb.db)
|
|
- 会话配置也存储在此数据库中,会话ID来源于主数据库
|
|
"""

import json
from typing import Optional

from sqlalchemy import func, select

from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.models import (
    KBChunk,
    KBDocument,
    KBMedia,
    KBSessionConfig,
    KnowledgeBase,
)


class KBDatabase:
    """Query helper for the standalone knowledge-base database.

    Responsibilities:
        - Wrap the database queries for knowledge bases, documents, chunks,
          media and session configurations.

    Notes:
        - This class operates on the standalone knowledge-base database
          (kb.db), not on the main database.
        - Session configurations bind a scope ID (a session ID or a platform
          ID) to a list of knowledge base IDs; session IDs originate from the
          main database.
        - No exceptions are caught here: errors raised by the database layer
          or by JSON decoding propagate to the caller.
    """

    def __init__(self, kb_db: KBSQLiteDatabase):
        """Initialize the helper.

        Args:
            kb_db: The standalone knowledge-base database instance (not the
                main database). Every query opens its own session through
                ``kb_db.get_db()``.
        """
        self.db = kb_db

    # ===== Knowledge base queries =====

    async def get_kb_by_id(self, kb_id: str) -> Optional[KnowledgeBase]:
        """Return the knowledge base whose ``kb_id`` matches, or None."""
        async with self.db.get_db() as session:
            stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()

    async def get_kb_by_name(self, kb_name: str) -> Optional[KnowledgeBase]:
        """Return the knowledge base whose ``kb_name`` matches, or None.

        ``scalar_one_or_none`` raises if more than one row matches the name.
        """
        async with self.db.get_db() as session:
            stmt = select(KnowledgeBase).where(KnowledgeBase.kb_name == kb_name)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()

    async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
        """List knowledge bases, newest (by ``created_at``) first.

        Args:
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.
        """
        async with self.db.get_db() as session:
            stmt = (
                select(KnowledgeBase)
                .order_by(KnowledgeBase.created_at.desc())
                .offset(offset)
                .limit(limit)
            )
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def count_kbs(self) -> int:
        """Return the total number of knowledge bases (0 if none)."""
        async with self.db.get_db() as session:
            stmt = select(func.count(KnowledgeBase.id))
            result = await session.execute(stmt)
            return result.scalar() or 0

    # ===== Document queries =====

    async def get_document_by_id(self, doc_id: str) -> Optional[KBDocument]:
        """Return the document whose ``doc_id`` matches, or None."""
        async with self.db.get_db() as session:
            stmt = select(KBDocument).where(KBDocument.doc_id == doc_id)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()

    async def list_documents_by_kb(
        self, kb_id: str, offset: int = 0, limit: int = 100
    ) -> list[KBDocument]:
        """List the documents of one knowledge base, newest first.

        Args:
            kb_id: Knowledge base ID to filter on.
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.
        """
        async with self.db.get_db() as session:
            stmt = (
                select(KBDocument)
                .where(KBDocument.kb_id == kb_id)
                .order_by(KBDocument.created_at.desc())
                .offset(offset)
                .limit(limit)
            )
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def count_documents_by_kb(self, kb_id: str) -> int:
        """Return the number of documents in one knowledge base (0 if none)."""
        async with self.db.get_db() as session:
            stmt = select(func.count(KBDocument.id)).where(KBDocument.kb_id == kb_id)
            result = await session.execute(stmt)
            return result.scalar() or 0

    # ===== Chunk queries =====

    async def get_chunk_by_id(self, chunk_id: str) -> Optional[KBChunk]:
        """Return the chunk whose ``chunk_id`` matches, or None."""
        async with self.db.get_db() as session:
            stmt = select(KBChunk).where(KBChunk.chunk_id == chunk_id)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()

    async def get_chunks_by_kb_ids(self, kb_ids: list[str]) -> list[KBChunk]:
        """Return every chunk that belongs to any of the given knowledge bases.

        Args:
            kb_ids: Knowledge base IDs. An empty list yields an empty result
                without touching the database.
        """
        if not kb_ids:
            # An empty IN (...) can never match; skip the round-trip.
            return []
        async with self.db.get_db() as session:
            stmt = select(KBChunk).where(KBChunk.kb_id.in_(kb_ids))
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def get_chunk_by_vec_doc_id(self, vec_doc_id: str) -> Optional[KBChunk]:
        """Return the chunk linked to the given vector-store document ID, or None."""
        async with self.db.get_db() as session:
            stmt = select(KBChunk).where(KBChunk.vec_doc_id == vec_doc_id)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()

    async def get_chunk_with_metadata(self, chunk_id: str) -> Optional[dict]:
        """Return a chunk together with its document and knowledge base.

        Returns:
            ``{"chunk": ..., "document": ..., "knowledge_base": ...}``, or None
            when no row matches. Because inner joins are used, None is also
            returned when the chunk exists but its document or knowledge base
            row is missing.
        """
        async with self.db.get_db() as session:
            stmt = (
                select(KBChunk, KBDocument, KnowledgeBase)
                .join(KBDocument, KBChunk.doc_id == KBDocument.doc_id)
                .join(KnowledgeBase, KBChunk.kb_id == KnowledgeBase.kb_id)
                .where(KBChunk.chunk_id == chunk_id)
            )
            result = await session.execute(stmt)
            row = result.first()

            if not row:
                return None

            chunk, doc, kb = row
            return {
                "chunk": chunk,
                "document": doc,
                "knowledge_base": kb,
            }

    async def list_chunks_by_doc(
        self, doc_id: str, offset: int = 0, limit: int = 100
    ) -> list[KBChunk]:
        """List the chunks of one document in ``chunk_index`` order.

        Args:
            doc_id: Document ID to filter on.
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.
        """
        async with self.db.get_db() as session:
            stmt = (
                select(KBChunk)
                .where(KBChunk.doc_id == doc_id)
                .order_by(KBChunk.chunk_index)
                .offset(offset)
                .limit(limit)
            )
            result = await session.execute(stmt)
            return list(result.scalars().all())

    # ===== Media queries =====

    async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
        """List every media resource attached to one document."""
        async with self.db.get_db() as session:
            stmt = select(KBMedia).where(KBMedia.doc_id == doc_id)
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def get_media_by_id(self, media_id: str) -> Optional[KBMedia]:
        """Return the media resource whose ``media_id`` matches, or None."""
        async with self.db.get_db() as session:
            stmt = select(KBMedia).where(KBMedia.media_id == media_id)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()

    # ===== Session configuration queries =====

    @staticmethod
    async def _find_session_config(
        session, scope: str, scope_id: str
    ) -> Optional[KBSessionConfig]:
        """Look up the config row for ``(scope, scope_id)`` in an open DB session.

        ``session`` is the object yielded by ``self.db.get_db()``.
        """
        stmt = select(KBSessionConfig).where(
            KBSessionConfig.scope == scope,
            KBSessionConfig.scope_id == scope_id,
        )
        result = await session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_session_kb_ids(self, session_id: str) -> list[str]:
        """Return the knowledge base IDs associated with a session.

        Lookup order:
            1. Session-level config. It wins whenever it exists, even if its
               list is empty.
            2. Platform-level config, keyed by the first ``:``-separated
               segment of ``session_id``. This is only tried when
               ``session_id`` contains at least one ``:``.
            3. Otherwise an empty list.

        Args:
            session_id: Session ID (originating from the main database).

        Returns:
            The decoded list of knowledge base IDs.

        Raises:
            json.JSONDecodeError: If the stored ``kb_ids`` is not valid JSON.
        """
        async with self.db.get_db() as session:
            # 1. Session-level config.
            config = await self._find_session_config(session, "session", session_id)
            if config is not None:
                return json.loads(config.kb_ids)

            # 2. Platform-level config. Expected format is
            #    "<platform_id>:<...>:<...>", so the platform ID is the first segment.
            parts = session_id.split(":")
            if len(parts) >= 2:
                platform_id = parts[0]
                config = await self._find_session_config(session, "platform", platform_id)
                if config is not None:
                    return json.loads(config.kb_ids)

        # 3. No configuration found.
        return []

    async def set_session_kb_ids(
        self,
        scope: str,
        scope_id: str,
        kb_ids: list[str],
        top_k: Optional[int] = None,
        enable_rerank: Optional[bool] = None,
    ) -> KBSessionConfig:
        """Create or update the knowledge base config for a scope.

        ``kb_ids`` always replaces the stored list. ``top_k`` and
        ``enable_rerank`` are only written when they are not None, both when
        updating and when creating. This means a new row keeps the model's own
        defaults for omitted fields instead of having them forced to None.

        Args:
            scope: Config scope (``"session"`` or ``"platform"``).
            scope_id: Scope identifier (session ID or platform ID, from the
                main database).
            kb_ids: Knowledge base IDs, stored JSON-encoded.
            top_k: Number of results to return (optional).
            enable_rerank: Whether to enable reranking (optional).

        Returns:
            The persisted config object, refreshed after commit.
        """
        # Collect only the fields the caller actually supplied.
        values: dict = {"kb_ids": json.dumps(kb_ids)}
        if top_k is not None:
            values["top_k"] = top_k
        if enable_rerank is not None:
            values["enable_rerank"] = enable_rerank

        async with self.db.get_db() as session:
            config = await self._find_session_config(session, scope, scope_id)

            if config is None:
                config = KBSessionConfig(scope=scope, scope_id=scope_id, **values)
                session.add(config)
            else:
                for field_name, value in values.items():
                    setattr(config, field_name, value)

            await session.commit()
            await session.refresh(config)
            return config

    async def delete_session_kb_config(self, scope: str, scope_id: str) -> bool:
        """Delete the knowledge base config for a scope.

        Args:
            scope: Config scope (``"session"`` or ``"platform"``).
            scope_id: Scope identifier (session ID or platform ID).

        Returns:
            True if a config was found and deleted, False if none existed.
        """
        async with self.db.get_db() as session:
            config = await self._find_session_config(session, scope, scope_id)
            if config is None:
                return False

            await session.delete(config)
            await session.commit()
            return True

    async def delete_session_kb_config_by_session_id(self, session_id: str) -> bool:
        """Delete the session-scoped config for a session ID.

        Intended as cascade cleanup when a session is deleted from the main
        database.

        Args:
            session_id: Session ID (originating from the main database).

        Returns:
            True if a config was found and deleted, False otherwise.
        """
        return await self.delete_session_kb_config("session", session_id)

    async def list_all_session_configs(
        self, offset: int = 0, limit: int = 100, scope: Optional[str] = None
    ) -> list[KBSessionConfig]:
        """List session configurations, newest first.

        Args:
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.
            scope: Optional scope filter (``"session"`` / ``"platform"``).
                A falsy value (None or ``""``) disables the filter.

        Returns:
            The matching configuration rows.
        """
        async with self.db.get_db() as session:
            stmt = select(KBSessionConfig)
            if scope:
                stmt = stmt.where(KBSessionConfig.scope == scope)
            stmt = (
                stmt.order_by(KBSessionConfig.created_at.desc())
                .offset(offset)
                .limit(limit)
            )
            result = await session.execute(stmt)
            return list(result.scalars().all())