Files
AstrBot/astrbot/core/knowledge_base/database.py
T
lxfight ad96d676e6 feat: 实现知识库核心后端模块
- 实现完整的知识库数据模型(知识库、文档、文档块、会话配置)
- 实现基于 SQLite 的向量数据库存储和检索
- 实现文档解析器(PDF、TXT)和固定大小分块器
- 实现混合检索系统(密集向量检索 + BM25 稀疏检索 + RRF 融合)
- 实现知识库生命周期管理和消息注入器
- 支持会话级别的知识库配置和关联
2025-10-19 18:40:55 +08:00

348 lines
12 KiB
Python

"""知识库数据库操作类
该模块封装知识库、文档、块、多媒体和会话配置相关的数据库查询操作。
注意:
- 该模块操作的是独立的知识库数据库 (data/knowledge_base/kb.db)
- 会话配置也存储在此数据库中,会话ID来源于主数据库
"""
import json
from typing import Optional
from sqlalchemy import func, select
from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.models import (
KBChunk,
KBDocument,
KBMedia,
KBSessionConfig,
KnowledgeBase,
)
class KBDatabase:
"""知识库数据库操作类
职责:
- 封装知识库、文档、块、多媒体和会话配置的数据库查询操作
- 统一异常处理
注意:
- 该类操作独立的知识库数据库 (kb.db)
- 会话配置存储会话ID与知识库的绑定关系,会话ID来源于主数据库
"""
def __init__(self, kb_db: KBSQLiteDatabase):
"""初始化知识库数据库操作类
Args:
kb_db: 知识库独立数据库实例,而非主数据库
"""
self.db = kb_db
# ===== 知识库查询 =====
async def get_kb_by_id(self, kb_id: str) -> Optional[KnowledgeBase]:
"""根据 ID 获取知识库"""
async with self.db.get_db() as session:
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def get_kb_by_name(self, kb_name: str) -> Optional[KnowledgeBase]:
"""根据名称获取知识库"""
async with self.db.get_db() as session:
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_name == kb_name)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
"""列出所有知识库"""
async with self.db.get_db() as session:
stmt = (
select(KnowledgeBase)
.offset(offset)
.limit(limit)
.order_by(KnowledgeBase.created_at.desc())
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def count_kbs(self) -> int:
"""统计知识库数量"""
async with self.db.get_db() as session:
stmt = select(func.count(KnowledgeBase.id))
result = await session.execute(stmt)
return result.scalar() or 0
# ===== 文档查询 =====
async def get_document_by_id(self, doc_id: str) -> Optional[KBDocument]:
"""根据 ID 获取文档"""
async with self.db.get_db() as session:
stmt = select(KBDocument).where(KBDocument.doc_id == doc_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def list_documents_by_kb(
self, kb_id: str, offset: int = 0, limit: int = 100
) -> list[KBDocument]:
"""列出知识库的所有文档"""
async with self.db.get_db() as session:
stmt = (
select(KBDocument)
.where(KBDocument.kb_id == kb_id)
.offset(offset)
.limit(limit)
.order_by(KBDocument.created_at.desc())
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def count_documents_by_kb(self, kb_id: str) -> int:
"""统计知识库的文档数量"""
async with self.db.get_db() as session:
stmt = select(func.count(KBDocument.id)).where(KBDocument.kb_id == kb_id)
result = await session.execute(stmt)
return result.scalar() or 0
# ===== 块查询 =====
async def get_chunk_by_id(self, chunk_id: str) -> Optional[KBChunk]:
"""根据 ID 获取块"""
async with self.db.get_db() as session:
stmt = select(KBChunk).where(KBChunk.chunk_id == chunk_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def get_chunks_by_kb_ids(self, kb_ids: list[str]) -> list[KBChunk]:
"""根据知识库 ID 列表获取所有块"""
async with self.db.get_db() as session:
stmt = select(KBChunk).where(KBChunk.kb_id.in_(kb_ids))
result = await session.execute(stmt)
return list(result.scalars().all())
async def get_chunk_by_vec_doc_id(self, vec_doc_id: str) -> Optional[KBChunk]:
"""根据向量文档 ID 获取块"""
async with self.db.get_db() as session:
stmt = select(KBChunk).where(KBChunk.vec_doc_id == vec_doc_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def get_chunk_with_metadata(self, chunk_id: str) -> Optional[dict]:
"""获取块及其关联的文档和知识库元数据"""
async with self.db.get_db() as session:
stmt = (
select(KBChunk, KBDocument, KnowledgeBase)
.join(KBDocument, KBChunk.doc_id == KBDocument.doc_id)
.join(KnowledgeBase, KBChunk.kb_id == KnowledgeBase.kb_id)
.where(KBChunk.chunk_id == chunk_id)
)
result = await session.execute(stmt)
row = result.first()
if not row:
return None
chunk, doc, kb = row
return {
"chunk": chunk,
"document": doc,
"knowledge_base": kb,
}
async def list_chunks_by_doc(
self, doc_id: str, offset: int = 0, limit: int = 100
) -> list[KBChunk]:
"""列出文档的所有块"""
async with self.db.get_db() as session:
stmt = (
select(KBChunk)
.where(KBChunk.doc_id == doc_id)
.offset(offset)
.limit(limit)
.order_by(KBChunk.chunk_index)
)
result = await session.execute(stmt)
return list(result.scalars().all())
# ===== 多媒体查询 =====
async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
"""列出文档的所有多媒体资源"""
async with self.db.get_db() as session:
stmt = select(KBMedia).where(KBMedia.doc_id == doc_id)
result = await session.execute(stmt)
return list(result.scalars().all())
async def get_media_by_id(self, media_id: str) -> Optional[KBMedia]:
"""根据 ID 获取多媒体资源"""
async with self.db.get_db() as session:
stmt = select(KBMedia).where(KBMedia.media_id == media_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
# ===== 会话配置查询 =====
async def get_session_kb_ids(self, session_id: str) -> list[str]:
"""获取会话关联的知识库 ID 列表
查找顺序:
1. 会话级别配置 (优先)
2. 平台级别配置
3. 返回空列表
Args:
session_id: 会话ID(来自主数据库)
Returns:
知识库ID列表
"""
async with self.db.get_db() as session:
# 1. 查找会话级别配置
stmt = select(KBSessionConfig).where(
KBSessionConfig.scope == "session",
KBSessionConfig.scope_id == session_id,
)
result = await session.execute(stmt)
config = result.scalar_one_or_none()
if config:
return json.loads(config.kb_ids)
# 2. 提取平台 ID (格式: platform:xxx:session_id)
parts = session_id.split(":")
if len(parts) >= 2:
platform_id = parts[0]
# 查找平台级别配置
stmt = select(KBSessionConfig).where(
KBSessionConfig.scope == "platform",
KBSessionConfig.scope_id == platform_id,
)
result = await session.execute(stmt)
config = result.scalar_one_or_none()
if config:
return json.loads(config.kb_ids)
# 3. 无配置
return []
async def set_session_kb_ids(
self,
scope: str,
scope_id: str,
kb_ids: list[str],
top_k: Optional[int] = None,
enable_rerank: Optional[bool] = None,
) -> KBSessionConfig:
"""设置会话知识库配置
Args:
scope: 配置范围 (session/platform)
scope_id: 范围标识 (会话 ID 或平台 ID,来自主数据库)
kb_ids: 知识库 ID 列表
top_k: 返回结果数量 (可选)
enable_rerank: 是否启用 Rerank (可选)
Returns:
配置对象
"""
async with self.db.get_db() as session:
# 查找现有配置
stmt = select(KBSessionConfig).where(
KBSessionConfig.scope == scope,
KBSessionConfig.scope_id == scope_id,
)
result = await session.execute(stmt)
config = result.scalar_one_or_none()
if config:
# 更新现有配置
config.kb_ids = json.dumps(kb_ids)
if top_k is not None:
config.top_k = top_k
if enable_rerank is not None:
config.enable_rerank = enable_rerank
else:
# 创建新配置
config = KBSessionConfig(
scope=scope,
scope_id=scope_id,
kb_ids=json.dumps(kb_ids),
top_k=top_k,
enable_rerank=enable_rerank,
)
session.add(config)
await session.commit()
await session.refresh(config)
return config
async def delete_session_kb_config(self, scope: str, scope_id: str) -> bool:
"""删除会话知识库配置
Args:
scope: 配置范围 (session/platform)
scope_id: 范围标识 (会话 ID 或平台 ID)
Returns:
是否删除成功
"""
async with self.db.get_db() as session:
stmt = select(KBSessionConfig).where(
KBSessionConfig.scope == scope,
KBSessionConfig.scope_id == scope_id,
)
result = await session.execute(stmt)
config = result.scalar_one_or_none()
if not config:
return False
await session.delete(config)
await session.commit()
return True
async def delete_session_kb_config_by_session_id(self, session_id: str) -> bool:
"""根据会话ID删除会话配置(用于主数据库会话删除时的级联清理)
Args:
session_id: 会话ID(来自主数据库)
Returns:
是否删除成功
"""
return await self.delete_session_kb_config("session", session_id)
async def list_all_session_configs(
self, offset: int = 0, limit: int = 100, scope: Optional[str] = None
) -> list[KBSessionConfig]:
"""列出所有会话配置
Args:
offset: 偏移量
limit: 限制数量
scope: 可选的范围过滤 (session/platform)
Returns:
会话配置列表
"""
async with self.db.get_db() as session:
stmt = select(KBSessionConfig)
if scope:
stmt = stmt.where(KBSessionConfig.scope == scope)
stmt = (
stmt.offset(offset)
.limit(limit)
.order_by(KBSessionConfig.created_at.desc())
)
result = await session.execute(stmt)
return list(result.scalars().all())