feat: 实现知识库核心后端模块

- 实现完整的知识库数据模型(知识库、文档、文档块、会话配置)
- 实现基于 SQLite 的向量数据库存储和检索
- 实现文档解析器(PDF、TXT)和固定大小分块器
- 实现混合检索系统(密集向量检索 + BM25 稀疏检索 + RRF 融合)
- 实现知识库生命周期管理和消息注入器
- 支持会话级别的知识库配置和关联
This commit is contained in:
lxfight
2025-10-19 18:40:55 +08:00
parent 79333bbc35
commit ad96d676e6
20 changed files with 2862 additions and 0 deletions
+34
View File
@@ -0,0 +1,34 @@
"""
知识库管理模块
提供文档上传、解析、分块、向量化、检索等功能
"""
from astrbot.core.db.po import KBSessionConfig
from astrbot.core.knowledge_base.models import (
KBChunk,
KBDocument,
KBMedia,
KnowledgeBase,
)
# 注意: 下方仍被注释的导入 (KnowledgeBaseInjector) 在对应模块可用后再取消注释
from .database import KBDatabase
from .manager import KBManager
from .manager_ops import KBManagerOps
from .session_config_db import SessionConfigDB
# from .injector import KnowledgeBaseInjector
__all__ = [
"KnowledgeBase",
"KBDocument",
"KBChunk",
"KBMedia",
"KBSessionConfig",
"KBDatabase",
"SessionConfigDB",
"KBManager",
"KBManagerOps",
# "KnowledgeBaseInjector",
]
@@ -0,0 +1,11 @@
"""
文档分块模块
"""
from .base import BaseChunker
from .fixed_size import FixedSizeChunker
__all__ = [
"BaseChunker",
"FixedSizeChunker",
]
@@ -0,0 +1,24 @@
"""文档分块器基类
定义了文档分块处理的抽象接口。
"""
from abc import ABC, abstractmethod
class BaseChunker(ABC):
    """Abstract base class for document chunkers.

    Concrete chunkers inherit from this class and implement :meth:`chunk`.
    """

    @abstractmethod
    async def chunk(self, text: str) -> list[str]:
        """Split ``text`` into chunks.

        Args:
            text: The input text.

        Returns:
            list[str]: The chunks produced from ``text``.
        """
@@ -0,0 +1,52 @@
"""固定大小分块器
按照固定的字符数将文本分块,支持重叠区域。
"""
from astrbot.core.knowledge_base.chunking.base import BaseChunker
class FixedSizeChunker(BaseChunker):
    """Chunker that cuts text into fixed-size, optionally overlapping windows.

    Sizes are measured in characters (``len`` of the ``str``).
    """

    def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
        """Store the window configuration.

        Args:
            chunk_size: Maximum number of characters per chunk; must be positive.
            chunk_overlap: Number of characters shared by consecutive chunks.
                A value outside ``[0, chunk_size)`` disables overlapping.

        Raises:
            ValueError: If ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    async def chunk(self, text: str) -> list[str]:
        """Split ``text`` into windows of at most ``chunk_size`` characters.

        Args:
            text: The input text.

        Returns:
            list[str]: The chunks in document order; empty for empty input.

        Raises:
            ValueError: If ``chunk_size`` was changed to a non-positive value
                after construction.
        """
        size = self.chunk_size
        if size <= 0:
            # A non-positive window can never advance through the text.
            raise ValueError(f"chunk_size must be positive, got {size}")

        overlap = self.chunk_overlap
        # An unusable overlap (negative or >= size) falls back to plain
        # back-to-back windows so the window always moves forward.
        step = size - overlap if 0 <= overlap < size else size

        text_len = len(text)
        chunks = []
        for start in range(0, text_len, step):
            end = start + size
            chunks.append(text[start:end])
            if end >= text_len:
                # This window already reached the end of the text; a further
                # window would only repeat the overlapping tail.
                break
        return chunks
+347
View File
@@ -0,0 +1,347 @@
"""知识库数据库操作类
该模块封装知识库、文档、块、多媒体和会话配置相关的数据库查询操作。
注意:
- 该模块操作的是独立的知识库数据库 (data/knowledge_base/kb.db)
- 会话配置也存储在此数据库中,会话ID来源于主数据库
"""
import json
from typing import Optional
from sqlalchemy import func, select
from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.models import (
KBChunk,
KBDocument,
KBMedia,
KBSessionConfig,
KnowledgeBase,
)
class KBDatabase:
    """Query layer over the knowledge-base database.

    Wraps look-ups for knowledge bases, documents, chunks, media and session
    configuration. Every operation opens its session through
    ``self.db.get_db()``.
    """

    def __init__(self, kb_db: KBSQLiteDatabase):
        """Keep a handle on the database wrapper.

        Args:
            kb_db: The knowledge-base database instance used for all queries.
        """
        self.db = kb_db

    # ===== Shared helpers =====

    async def _scalar_or_none(self, stmt):
        """Run ``stmt`` and return its single scalar result, or ``None``."""
        async with self.db.get_db() as session:
            outcome = await session.execute(stmt)
            return outcome.scalar_one_or_none()

    async def _scalar_list(self, stmt) -> list:
        """Run ``stmt`` and return all scalar results as a list."""
        async with self.db.get_db() as session:
            outcome = await session.execute(stmt)
            return list(outcome.scalars().all())

    async def _count(self, stmt) -> int:
        """Run a counting ``stmt`` and return its value (0 when empty)."""
        async with self.db.get_db() as session:
            outcome = await session.execute(stmt)
            return outcome.scalar() or 0

    @staticmethod
    def _config_stmt(scope: str, scope_id: str):
        """Build the SELECT matching one ``(scope, scope_id)`` config row."""
        return select(KBSessionConfig).where(
            KBSessionConfig.scope == scope,
            KBSessionConfig.scope_id == scope_id,
        )

    # ===== Knowledge bases =====

    async def get_kb_by_id(self, kb_id: str) -> Optional[KnowledgeBase]:
        """Return the knowledge base whose ``kb_id`` matches, or ``None``."""
        return await self._scalar_or_none(
            select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
        )

    async def get_kb_by_name(self, kb_name: str) -> Optional[KnowledgeBase]:
        """Return the knowledge base whose ``kb_name`` matches, or ``None``."""
        return await self._scalar_or_none(
            select(KnowledgeBase).where(KnowledgeBase.kb_name == kb_name)
        )

    async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
        """Return one page of knowledge bases, newest ``created_at`` first."""
        query = (
            select(KnowledgeBase)
            .offset(offset)
            .limit(limit)
            .order_by(KnowledgeBase.created_at.desc())
        )
        return await self._scalar_list(query)

    async def count_kbs(self) -> int:
        """Return the number of knowledge bases."""
        return await self._count(select(func.count(KnowledgeBase.id)))

    # ===== Documents =====

    async def get_document_by_id(self, doc_id: str) -> Optional[KBDocument]:
        """Return the document whose ``doc_id`` matches, or ``None``."""
        return await self._scalar_or_none(
            select(KBDocument).where(KBDocument.doc_id == doc_id)
        )

    async def list_documents_by_kb(
        self, kb_id: str, offset: int = 0, limit: int = 100
    ) -> list[KBDocument]:
        """Return one page of a knowledge base's documents, newest first."""
        query = (
            select(KBDocument)
            .where(KBDocument.kb_id == kb_id)
            .offset(offset)
            .limit(limit)
            .order_by(KBDocument.created_at.desc())
        )
        return await self._scalar_list(query)

    async def count_documents_by_kb(self, kb_id: str) -> int:
        """Return the number of documents in the given knowledge base."""
        return await self._count(
            select(func.count(KBDocument.id)).where(KBDocument.kb_id == kb_id)
        )

    # ===== Chunks =====

    async def get_chunk_by_id(self, chunk_id: str) -> Optional[KBChunk]:
        """Return the chunk whose ``chunk_id`` matches, or ``None``."""
        return await self._scalar_or_none(
            select(KBChunk).where(KBChunk.chunk_id == chunk_id)
        )

    async def get_chunks_by_kb_ids(self, kb_ids: list[str]) -> list[KBChunk]:
        """Return every chunk belonging to any of the given knowledge bases."""
        return await self._scalar_list(
            select(KBChunk).where(KBChunk.kb_id.in_(kb_ids))
        )

    async def get_chunk_by_vec_doc_id(self, vec_doc_id: str) -> Optional[KBChunk]:
        """Return the chunk whose ``vec_doc_id`` matches, or ``None``."""
        return await self._scalar_or_none(
            select(KBChunk).where(KBChunk.vec_doc_id == vec_doc_id)
        )

    async def get_chunk_with_metadata(self, chunk_id: str) -> Optional[dict]:
        """Return a chunk together with its document and knowledge base.

        Returns:
            A dict with keys ``chunk``, ``document`` and ``knowledge_base``,
            or ``None`` when the joined query yields no row.
        """
        query = (
            select(KBChunk, KBDocument, KnowledgeBase)
            .join(KBDocument, KBChunk.doc_id == KBDocument.doc_id)
            .join(KnowledgeBase, KBChunk.kb_id == KnowledgeBase.kb_id)
            .where(KBChunk.chunk_id == chunk_id)
        )
        async with self.db.get_db() as session:
            record = (await session.execute(query)).first()
            if not record:
                return None
            found_chunk, found_doc, found_kb = record
            return {
                "chunk": found_chunk,
                "document": found_doc,
                "knowledge_base": found_kb,
            }

    async def list_chunks_by_doc(
        self, doc_id: str, offset: int = 0, limit: int = 100
    ) -> list[KBChunk]:
        """Return one page of a document's chunks ordered by ``chunk_index``."""
        query = (
            select(KBChunk)
            .where(KBChunk.doc_id == doc_id)
            .offset(offset)
            .limit(limit)
            .order_by(KBChunk.chunk_index)
        )
        return await self._scalar_list(query)

    # ===== Media =====

    async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
        """Return every media resource attached to the given document."""
        return await self._scalar_list(
            select(KBMedia).where(KBMedia.doc_id == doc_id)
        )

    async def get_media_by_id(self, media_id: str) -> Optional[KBMedia]:
        """Return the media resource whose ``media_id`` matches, or ``None``."""
        return await self._scalar_or_none(
            select(KBMedia).where(KBMedia.media_id == media_id)
        )

    # ===== Session configuration =====

    async def get_session_kb_ids(self, session_id: str) -> list[str]:
        """Resolve the knowledge-base IDs bound to a session.

        Lookup order:
        1. A ``"session"``-scoped config for ``session_id``.
        2. A ``"platform"``-scoped config for the text before the first ``:``
           of ``session_id`` (only when ``session_id`` contains a ``:``).
        3. Otherwise an empty list.

        Args:
            session_id: The session identifier.

        Returns:
            The JSON-decoded ``kb_ids`` of the first matching config.
        """
        async with self.db.get_db() as session:

            async def lookup(scope: str, scope_id: str):
                outcome = await session.execute(self._config_stmt(scope, scope_id))
                return outcome.scalar_one_or_none()

            session_cfg = await lookup("session", session_id)
            if session_cfg:
                return json.loads(session_cfg.kb_ids)

            segments = session_id.split(":")
            if len(segments) < 2:
                return []

            platform_cfg = await lookup("platform", segments[0])
            if platform_cfg:
                return json.loads(platform_cfg.kb_ids)
            return []

    async def set_session_kb_ids(
        self,
        scope: str,
        scope_id: str,
        kb_ids: list[str],
        top_k: Optional[int] = None,
        enable_rerank: Optional[bool] = None,
    ) -> KBSessionConfig:
        """Create or update the config row for ``(scope, scope_id)``.

        Args:
            scope: Config scope (``session`` or ``platform``).
            scope_id: Session ID or platform ID, depending on ``scope``.
            kb_ids: Knowledge-base IDs; stored JSON-encoded.
            top_k: Result count; on update only applied when not ``None``.
            enable_rerank: Rerank switch; on update only applied when not
                ``None``.

        Returns:
            The committed and refreshed config object.
        """
        encoded_ids = json.dumps(kb_ids)
        async with self.db.get_db() as session:
            outcome = await session.execute(self._config_stmt(scope, scope_id))
            record = outcome.scalar_one_or_none()

            if not record:
                # No row yet: insert one carrying the values exactly as given.
                record = KBSessionConfig(
                    scope=scope,
                    scope_id=scope_id,
                    kb_ids=encoded_ids,
                    top_k=top_k,
                    enable_rerank=enable_rerank,
                )
                session.add(record)
            else:
                record.kb_ids = encoded_ids
                if top_k is not None:
                    record.top_k = top_k
                if enable_rerank is not None:
                    record.enable_rerank = enable_rerank

            await session.commit()
            await session.refresh(record)
            return record

    async def delete_session_kb_config(self, scope: str, scope_id: str) -> bool:
        """Delete the config row for ``(scope, scope_id)``.

        Returns:
            ``True`` when a row was deleted, ``False`` when none existed.
        """
        async with self.db.get_db() as session:
            outcome = await session.execute(self._config_stmt(scope, scope_id))
            record = outcome.scalar_one_or_none()
            if record:
                await session.delete(record)
                await session.commit()
                return True
            return False

    async def delete_session_kb_config_by_session_id(self, session_id: str) -> bool:
        """Delete the ``"session"``-scoped config of ``session_id``.

        Returns:
            ``True`` when a row was deleted, ``False`` when none existed.
        """
        return await self.delete_session_kb_config("session", session_id)

    async def list_all_session_configs(
        self, offset: int = 0, limit: int = 100, scope: Optional[str] = None
    ) -> list[KBSessionConfig]:
        """Return one page of session configs, newest ``created_at`` first.

        Args:
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.
            scope: Optional scope filter; ignored when falsy.
        """
        query = select(KBSessionConfig)
        if scope:
            query = query.where(KBSessionConfig.scope == scope)
        query = (
            query.offset(offset)
            .limit(limit)
            .order_by(KBSessionConfig.created_at.desc())
        )
        return await self._scalar_list(query)
+139
View File
@@ -0,0 +1,139 @@
"""知识库上下文注入器
负责检索相关知识并格式化为 LLM 可用的上下文文本
"""
from typing import List, Optional
from astrbot.core.knowledge_base.database import KBDatabase
from astrbot.core.knowledge_base.retrieval.manager import (
RetrievalManager,
RetrievalResult,
)
class KnowledgeBaseInjector:
    """Retrieve knowledge for a query and format it as LLM context text."""

    def __init__(
        self,
        kb_db: KBDatabase,
        retrieval_manager: RetrievalManager,
    ):
        """Store the collaborators.

        Args:
            kb_db: Database layer used to resolve a session's knowledge bases.
            retrieval_manager: Retrieval manager used to run the search.
        """
        self.kb_db = kb_db
        self.retrieval_manager = retrieval_manager

    async def retrieve_and_inject(
        self,
        unified_msg_origin: str,
        query: str,
        top_k: int = 5,
    ) -> Optional[dict]:
        """Retrieve knowledge for ``query`` and build the context payload.

        Args:
            unified_msg_origin: Session identifier used to look up the bound
                knowledge bases.
            query: The user query.
            top_k: Passed to the retrieval manager as ``top_m_final``.

        Returns:
            ``None`` when the session has no knowledge bases or retrieval
            returns nothing; otherwise a dict with ``context_text`` (formatted
            text) and ``results`` (one plain dict per retrieval result).
        """
        kb_ids = await self.kb_db.get_session_kb_ids(unified_msg_origin)
        if not kb_ids:
            return None

        results = await self.retrieval_manager.retrieve(
            query=query,
            kb_ids=kb_ids,
            top_m_final=top_k,
        )
        if not results:
            return None

        results_dict = [
            {
                "chunk_id": r.chunk_id,
                "doc_id": r.doc_id,
                "kb_id": r.kb_id,
                "kb_name": r.kb_name,
                "doc_name": r.doc_name,
                # Tolerate a result whose metadata is missing/None.
                "chunk_index": (r.metadata or {}).get("chunk_index", 0),
                "content": r.content,
                "score": r.score,
            }
            for r in results
        ]
        return {
            "context_text": self._format_context(results),
            "results": results_dict,
        }

    async def inject(
        self,
        session_id: str,
        query: str,
        top_k: int = 5,
    ) -> Optional[str]:
        """Return only the formatted context text for ``query``.

        Args:
            session_id: Session identifier.
            query: The user query.
            top_k: Number of results to request.

        Returns:
            The formatted context, or ``None`` when nothing was retrieved.
        """
        result = await self.retrieve_and_inject(
            unified_msg_origin=session_id,
            query=query,
            top_k=top_k,
        )
        return result["context_text"] if result else None

    def _format_context(self, results: List[RetrievalResult]) -> str:
        """Render retrieval results as a numbered, human-readable text block.

        Args:
            results: Retrieval results; each must expose ``kb_name``,
                ``doc_name``, ``content`` and ``score``.

        Returns:
            The header line followed by one section per result.
        """
        lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
        for i, result in enumerate(results, 1):
            # The section title needs its closing bracket to be well-formed.
            lines.append(f"【知识 {i}】")
            lines.append(f"来源: {result.kb_name} / {result.doc_name}")
            lines.append(f"内容: {result.content}")
            lines.append(f"相关度: {result.score:.2f}")
            lines.append("")
        return "\n".join(lines)
@@ -0,0 +1,358 @@
"""
知识库管理器
负责知识库模块的初始化、配置和资源管理
架构说明:
- 知识库数据存储在独立的数据库 (data/knowledge_base/kb.db)
- 会话配置同样存储在知识库独立数据库 (kb.db) 中,会话 ID 来源于主数据库 (data/astrbot.db)
"""
from pathlib import Path
from astrbot.core import logger
from astrbot.core.db import BaseDatabase
from astrbot.core.provider.manager import ProviderManager
class KnowledgeBaseManager:
    """Lifecycle manager for the knowledge-base module.

    Responsibilities:
    - Initialise the module and wire its sub-components together.
    - Select the embedding provider and the optional rerank provider.
    - Register a session-deleted callback so a session's knowledge-base
      configuration is removed together with the session.

    Both knowledge-base data and session configuration are accessed through
    the knowledge-base database created in ``_init_kb_database``.
    """

    def __init__(
        self,
        config: dict,
        main_db: BaseDatabase,
        provider_manager: ProviderManager,
    ):
        """Prepare an uninitialised manager.

        Args:
            config: Full configuration dict; only its ``knowledge_base``
                section is kept.
            main_db: Main database instance (accepted but not stored).
            provider_manager: Source of embedding / rerank provider instances.
        """
        self.config = config.get("knowledge_base", {})
        self.provider_manager = provider_manager
        # Knowledge-base database wrapper.
        self.kb_db = None
        # Component instances, populated by initialize().
        self.kb_database = None
        self.kb_manager = None
        self.kb_vec_db = None
        self.retrieval_manager = None
        self.kb_injector = None
        self._initialized = False
        self._session_deleted_callback_registered = False

    async def initialize(self):
        """Initialise the module if it is enabled and a provider is available.

        Failures are logged rather than raised. Any resource opened before
        the failure is released again so a later ``reinitialize()`` starts
        from a clean state.
        """
        if not self.config.get("enabled", False):
            logger.info("知识库功能未启用")
            return
        try:
            logger.info("正在初始化知识库模块...")
            # 1. Pick the embedding provider; nothing works without one.
            embedding_provider = self._select_embedding_provider()
            if not embedding_provider:
                logger.warning("未配置 Embedding Provider,知识库功能无法使用")
                return
            # 2. Databases.
            await self._init_kb_database()
            await self._init_database()
            # 3. Vector database.
            await self._init_vector_db(embedding_provider)
            # 4. Parsers and chunker.
            parsers = self._init_parsers()
            chunker = self._init_chunker()
            # 5. Knowledge-base manager.
            await self._init_kb_manager(parsers, chunker)
            # 6. Retrieval manager.
            await self._init_retrieval_manager()
            # 7. Context injector.
            await self._init_injector()
            self._initialized = True
            logger.info("知识库模块初始化完成")
        except ImportError as e:
            logger.error(f"知识库模块导入失败: {e}")
            logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
            await self._release_resources()
        except Exception as e:
            logger.error(f"知识库模块初始化失败: {e}")
            import traceback

            logger.error(traceback.format_exc())
            await self._release_resources()

    async def _init_kb_database(self):
        """Create, initialise and migrate the knowledge-base database."""
        from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase

        db_path = self.config.get("storage", {}).get(
            "kb_db_path", "data/knowledge_base/kb.db"
        )
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)
        self.kb_db = KBSQLiteDatabase(db_path)
        await self.kb_db.initialize()
        await self.kb_db.migrate_to_v1()
        logger.info(f"知识库独立数据库已初始化: {db_path}")

    async def _init_database(self):
        """Create the query layer on top of the knowledge-base database."""
        from astrbot.core.knowledge_base.database import KBDatabase

        self.kb_database = KBDatabase(self.kb_db)

    async def _init_vector_db(self, embedding_provider):
        """Create and initialise the Faiss vector database."""
        from astrbot.core.db.vec_db.faiss_impl import FaissVecDB

        storage_path = self.config.get("storage", {}).get(
            "vector_db_path", "data/knowledge_base/vectors"
        )
        Path(storage_path).mkdir(parents=True, exist_ok=True)
        self.kb_vec_db = FaissVecDB(
            doc_store_path=f"{storage_path}/documents.db",
            index_store_path=f"{storage_path}/index.faiss",
            embedding_provider=embedding_provider,
        )
        await self.kb_vec_db.initialize()

    def _init_parsers(self) -> dict:
        """Return the parser instances keyed by file type."""
        from astrbot.core.knowledge_base.parsers.text_parser import TextParser
        from astrbot.core.knowledge_base.parsers.pdf_parser import PDFParser

        return {
            "txt": TextParser(),
            "md": TextParser(),
            "markdown": TextParser(),
            "pdf": PDFParser(),
        }

    def _init_chunker(self):
        """Build the fixed-size chunker from the ``chunking`` config section."""
        from astrbot.core.knowledge_base.chunking.fixed_size import FixedSizeChunker

        chunking_config = self.config.get("chunking", {})
        return FixedSizeChunker(
            chunk_size=chunking_config.get("chunk_size", 512),
            chunk_overlap=chunking_config.get("chunk_overlap", 50),
        )

    async def _init_kb_manager(self, parsers: dict, chunker):
        """Create the KBManager that handles CRUD and document upload."""
        from astrbot.core.knowledge_base.manager import KBManager

        files_path = self.config.get("storage", {}).get(
            "files_path", "data/knowledge_base"
        )
        self.kb_manager = KBManager(
            db=self.kb_db,  # the knowledge-base database, not the main one
            vec_db=self.kb_vec_db,
            storage_path=files_path,
            parsers=parsers,
            chunker=chunker,
        )

    async def _init_retrieval_manager(self):
        """Create the retrieval manager and the retrievers it combines."""
        from astrbot.core.knowledge_base.retrieval.manager import RetrievalManager
        from astrbot.core.knowledge_base.retrieval.sparse_retriever import (
            SparseRetriever,
        )
        from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion

        sparse_retriever = SparseRetriever(self.kb_database)
        rank_fusion = RankFusion(self.kb_database)
        # The rerank provider is optional and may be None.
        rerank_provider = self._select_rerank_provider()
        self.retrieval_manager = RetrievalManager(
            vec_db=self.kb_vec_db,
            sparse_retriever=sparse_retriever,
            rank_fusion=rank_fusion,
            kb_db=self.kb_database,
            rerank_provider=rerank_provider,
        )

    async def _init_injector(self):
        """Create the context injector."""
        from astrbot.core.knowledge_base.injector import KnowledgeBaseInjector

        self.kb_injector = KnowledgeBaseInjector(
            kb_db=self.kb_database,
            retrieval_manager=self.retrieval_manager,
        )

    def _select_embedding_provider(self):
        """Pick the embedding provider to use.

        The provider whose ID equals ``embedding_provider_id`` from the
        config wins. When no ID is configured, or the configured ID is not
        found, the first available provider is used (with a warning when the
        choice is ambiguous or the configured ID is missing).

        Returns:
            The chosen provider, or ``None`` when none is available.
        """
        embedding_providers = self.provider_manager.embedding_provider_insts
        if not embedding_providers:
            return None

        configured_provider_id = self.config.get("embedding_provider_id")
        if configured_provider_id:
            for provider in embedding_providers:
                provider_id = provider.meta().id
                if provider_id == configured_provider_id:
                    logger.info(f"知识库使用 Embedding Provider: {provider_id}")
                    return provider
            logger.warning(
                f"未找到配置的 Embedding Provider ID: {configured_provider_id},"
                f"将使用第一个可用的"
            )
        elif len(embedding_providers) > 1:
            logger.warning(
                f"检测到 {len(embedding_providers)} 个 Embedding Provider,"
                f"但未指定使用哪个,将默认使用第一个"
            )

        provider = embedding_providers[0]
        provider_id = provider.meta().id
        logger.info(f"知识库使用 Embedding Provider: {provider_id}")
        return provider

    def _select_rerank_provider(self):
        """Pick the optional rerank provider.

        Returns:
            ``None`` when reranking is disabled in the ``retrieval`` config
            section or no rerank provider exists; otherwise the provider
            matching ``rerank_provider_id``, falling back to the first one.
        """
        if not self.config.get("retrieval", {}).get("enable_rerank", True):
            return None
        rerank_providers = self.provider_manager.rerank_provider_insts
        if not rerank_providers:
            return None

        configured_provider_id = self.config.get("rerank_provider_id")
        if configured_provider_id:
            for provider in rerank_providers:
                provider_id = provider.meta().id
                if provider_id == configured_provider_id:
                    logger.info(f"知识库使用 Rerank Provider: {provider_id}")
                    return provider
            logger.warning(f"未找到配置的 Rerank Provider ID: {configured_provider_id}")

        # The list is known to be non-empty here, so the first entry exists.
        provider = rerank_providers[0]
        provider_id = provider.meta().id
        logger.info(f"知识库使用 Rerank Provider: {provider_id}")
        return provider

    @property
    def is_initialized(self) -> bool:
        """Whether ``initialize()`` completed successfully."""
        return self._initialized

    def get_kb_manager(self):
        """Return the KBManager, or ``None`` while uninitialised."""
        return self.kb_manager if self._initialized else None

    def get_kb_injector(self):
        """Return the context injector, or ``None`` while uninitialised."""
        return self.kb_injector if self._initialized else None

    def register_session_lifecycle_hooks(self, conversation_manager):
        """Register a callback that removes a deleted session's KB config.

        Does nothing when the module is uninitialised or the callback was
        already registered.

        Args:
            conversation_manager: Object exposing
                ``register_on_session_deleted(callback)``.
        """
        if self._session_deleted_callback_registered or not self._initialized:
            return

        async def on_session_deleted(session_id: str):
            """Session-deleted callback: drop the session's KB config."""
            try:
                await self.kb_database.delete_session_kb_config_by_session_id(session_id)
                logger.info(f"已清理会话知识库配置: {session_id}")
            except Exception as e:
                logger.error(f"清理会话知识库配置失败 ({session_id}): {e}")

        conversation_manager.register_on_session_deleted(on_session_deleted)
        self._session_deleted_callback_registered = True
        logger.info("已注册知识库会话删除回调")

    async def reinitialize(self):
        """Terminate (if needed) and initialise again.

        Returns:
            bool: Whether the module is initialised afterwards.
        """
        if self._initialized:
            logger.info("知识库模块已初始化,将重新初始化")
            await self.terminate()
        await self.initialize()
        return self._initialized

    async def _release_resources(self):
        """Close whatever database handles exist and reset all components."""
        if self.kb_vec_db:
            try:
                await self.kb_vec_db.close()
                logger.debug("向量数据库已关闭")
            except Exception as e:
                logger.warning(f"关闭向量数据库时出错: {e}")
        if self.kb_db:
            try:
                await self.kb_db.close()
                logger.debug("知识库数据库已关闭")
            except Exception as e:
                logger.warning(f"关闭知识库数据库时出错: {e}")
        self._initialized = False
        self.kb_db = None
        self.kb_database = None
        self.kb_manager = None
        self.kb_vec_db = None
        self.retrieval_manager = None
        self.kb_injector = None

    async def terminate(self):
        """Shut the module down and release its resources.

        A no-op only when the module is uninitialised *and* holds no open
        database handle; handles left by a partial initialisation are closed.
        """
        if not self._initialized and self.kb_db is None and self.kb_vec_db is None:
            return
        logger.info("正在终止知识库模块...")
        await self._release_resources()
        logger.info("知识库模块已终止")
+231
View File
@@ -0,0 +1,231 @@
"""
知识库独立 SQLite 数据库
该模块提供知识库专用的独立 SQLite 数据库,与主数据库 (astrbot.db) 完全隔离。
职责:
- 管理知识库相关表 (knowledge_bases, kb_documents, kb_chunks, kb_media, kb_session_config)
- 提供数据库连接和会话管理
- 执行数据库迁移和初始化
"""
from contextlib import asynccontextmanager
from pathlib import Path
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from astrbot.core import logger
class KBSQLiteDatabase:
    """Standalone SQLite database dedicated to knowledge-base data.

    Owns an async SQLAlchemy engine and session factory for the file at
    ``db_path`` and provides table creation, index migration and shutdown.
    """

    # (index name, table, column) of every index created by migrate_to_v1().
    _V1_INDEXES = (
        # knowledge_bases
        ("idx_kb_kb_id", "knowledge_bases", "kb_id"),
        ("idx_kb_name", "knowledge_bases", "kb_name"),
        ("idx_kb_created_at", "knowledge_bases", "created_at"),
        # kb_documents
        ("idx_doc_doc_id", "kb_documents", "doc_id"),
        ("idx_doc_kb_id", "kb_documents", "kb_id"),
        ("idx_doc_name", "kb_documents", "doc_name"),
        ("idx_doc_type", "kb_documents", "file_type"),
        ("idx_doc_created_at", "kb_documents", "created_at"),
        # kb_chunks
        ("idx_chunk_chunk_id", "kb_chunks", "chunk_id"),
        ("idx_chunk_doc_id", "kb_chunks", "doc_id"),
        ("idx_chunk_kb_id", "kb_chunks", "kb_id"),
        ("idx_chunk_vec_doc_id", "kb_chunks", "vec_doc_id"),
        # kb_media
        ("idx_media_media_id", "kb_media", "media_id"),
        ("idx_media_doc_id", "kb_media", "doc_id"),
        ("idx_media_kb_id", "kb_media", "kb_id"),
        ("idx_media_type", "kb_media", "media_type"),
        # kb_session_config
        ("idx_session_config_scope_id", "kb_session_config", "scope_id"),
        ("idx_session_config_scope", "kb_session_config", "scope"),
    )

    def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
        """Create the engine and session factory for ``db_path``.

        The parent directory of ``db_path`` is created if missing.

        Args:
            db_path: Path of the SQLite file.
        """
        self.db_path = db_path
        self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
        self.inited = False
        # Make sure the containing directory exists before the engine uses it.
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)
        self.engine = create_async_engine(
            self.DATABASE_URL,
            echo=False,
            pool_pre_ping=True,
            pool_recycle=3600,
        )
        # expire_on_commit=False keeps loaded attributes readable after commit.
        self.async_session = async_sessionmaker(
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )

    @asynccontextmanager
    async def get_db(self):
        """Yield a database session.

        Usage::

            async with kb_db.get_db() as session:
                result = await session.execute(stmt)
        """
        async with self.async_session() as session:
            yield session

    async def initialize(self) -> None:
        """Create the tables and apply the SQLite PRAGMA settings."""
        # Imported for their side effect: defining the model classes before
        # SQLModel.metadata.create_all runs.
        from astrbot.core.knowledge_base.models import (  # noqa: F401
            KBChunk,
            KBDocument,
            KBMedia,
            KBSessionConfig,
            KnowledgeBase,
        )
        from sqlmodel import SQLModel

        pragmas = (
            "PRAGMA journal_mode=WAL",
            "PRAGMA synchronous=NORMAL",
            "PRAGMA cache_size=20000",
            "PRAGMA temp_store=MEMORY",
            "PRAGMA mmap_size=134217728",
            "PRAGMA optimize",
        )
        # engine.begin() commits when the block exits without an error.
        async with self.engine.begin() as conn:
            await conn.run_sync(SQLModel.metadata.create_all)
            for pragma in pragmas:
                await conn.execute(text(pragma))
        self.inited = True
        logger.info(f"知识库数据库已初始化: {self.db_path}")

    @classmethod
    def _v1_index_statements(cls) -> list[str]:
        """Return the ``CREATE INDEX IF NOT EXISTS`` statements of migration v1."""
        return [
            f"CREATE INDEX IF NOT EXISTS {name} ON {table}({column})"
            for name, table, column in cls._V1_INDEXES
        ]

    async def migrate_to_v1(self) -> None:
        """Run migration v1: create every index listed in ``_V1_INDEXES``.

        All statements use ``IF NOT EXISTS``, so running the migration again
        is harmless.
        """
        async with self.get_db() as session:
            session: AsyncSession
            # session.begin() commits when the block exits without an error.
            async with session.begin():
                for statement in self._v1_index_statements():
                    await session.execute(text(statement))
        logger.info("知识库数据库迁移 v1 完成")

    async def close(self) -> None:
        """Dispose of the engine and its connections."""
        await self.engine.dispose()
        logger.info(f"知识库数据库已关闭: {self.db_path}")
+349
View File
@@ -0,0 +1,349 @@
"""知识库管理器
该模块提供知识库的CRUD操作和文档上传处理流程。
"""
import uuid
from pathlib import Path
from typing import Optional
import aiofiles
from sqlalchemy import func, select, update
from astrbot.core.db import BaseDatabase
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.knowledge_base.chunking.base import BaseChunker
from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KnowledgeBase
from astrbot.core.knowledge_base.parsers.base import BaseParser
class KBManager:
"""知识库管理器
职责:
- 知识库的 CRUD 操作
- 文档上传与解析
- 文档块生成与存储
- 多媒体资源管理
"""
def __init__(
self,
db: BaseDatabase,
vec_db: BaseVecDB,
storage_path: str,
parsers: dict[str, BaseParser],
chunker: BaseChunker,
):
self.db = db
self.vec_db = vec_db
self.storage_path = Path(storage_path)
self.media_path = self.storage_path / "media"
self.files_path = self.storage_path / "files"
self.parsers = parsers
self.chunker = chunker
# 确保目录存在
self.media_path.mkdir(parents=True, exist_ok=True)
self.files_path.mkdir(parents=True, exist_ok=True)
# ===== 知识库操作 =====
async def create_kb(
self,
kb_name: str,
description: Optional[str] = None,
emoji: Optional[str] = None,
embedding_provider_id: Optional[str] = None,
rerank_provider_id: Optional[str] = None,
chunk_size: Optional[int] = None,
chunk_overlap: Optional[int] = None,
top_k_dense: Optional[int] = None,
top_k_sparse: Optional[int] = None,
top_m_final: Optional[int] = None,
enable_rerank: Optional[bool] = None,
) -> KnowledgeBase:
"""创建知识库"""
kb = KnowledgeBase(
kb_name=kb_name,
description=description,
emoji=emoji or "📚",
embedding_provider_id=embedding_provider_id,
rerank_provider_id=rerank_provider_id,
chunk_size=chunk_size if chunk_size is not None else 512,
chunk_overlap=chunk_overlap if chunk_overlap is not None else 50,
top_k_dense=top_k_dense if top_k_dense is not None else 50,
top_k_sparse=top_k_sparse if top_k_sparse is not None else 50,
top_m_final=top_m_final if top_m_final is not None else 5,
enable_rerank=enable_rerank if enable_rerank is not None else True,
)
async with self.db.get_db() as session:
session.add(kb)
await session.commit()
await session.refresh(kb)
return kb
async def get_kb(self, kb_id: str) -> Optional[KnowledgeBase]:
"""获取知识库"""
async with self.db.get_db() as session:
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
"""列出所有知识库"""
async with self.db.get_db() as session:
stmt = (
select(KnowledgeBase)
.offset(offset)
.limit(limit)
.order_by(KnowledgeBase.created_at.desc())
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def update_kb(
self,
kb_id: str,
kb_name: Optional[str] = None,
description: Optional[str] = None,
emoji: Optional[str] = None,
embedding_provider_id: Optional[str] = None,
rerank_provider_id: Optional[str] = None,
chunk_size: Optional[int] = None,
chunk_overlap: Optional[int] = None,
top_k_dense: Optional[int] = None,
top_k_sparse: Optional[int] = None,
top_m_final: Optional[int] = None,
enable_rerank: Optional[bool] = None,
) -> Optional[KnowledgeBase]:
"""更新知识库"""
async with self.db.get_db() as session:
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
result = await session.execute(stmt)
kb = result.scalar_one_or_none()
if not kb:
return None
if kb_name is not None:
kb.kb_name = kb_name
if description is not None:
kb.description = description
if emoji is not None:
kb.emoji = emoji
if embedding_provider_id is not None:
kb.embedding_provider_id = embedding_provider_id
if rerank_provider_id is not None:
kb.rerank_provider_id = rerank_provider_id
if chunk_size is not None:
kb.chunk_size = chunk_size
if chunk_overlap is not None:
kb.chunk_overlap = chunk_overlap
if top_k_dense is not None:
kb.top_k_dense = top_k_dense
if top_k_sparse is not None:
kb.top_k_sparse = top_k_sparse
if top_m_final is not None:
kb.top_m_final = top_m_final
if enable_rerank is not None:
kb.enable_rerank = enable_rerank
await session.commit()
await session.refresh(kb)
return kb
async def delete_kb(self, kb_id: str) -> bool:
"""删除知识库(级联删除所有文档和资源)"""
# 1. 获取所有文档
from astrbot.core.knowledge_base.manager_ops import KBManagerOps
ops = KBManagerOps(self)
docs = await ops.list_documents(kb_id)
# 2. 删除所有文档(包括文件和向量)
for doc in docs:
await ops.delete_document(doc.doc_id)
# 3. 删除知识库记录
async with self.db.get_db() as session:
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
result = await session.execute(stmt)
kb = result.scalar_one_or_none()
if not kb:
return False
await session.delete(kb)
await session.commit()
return True
# ===== 文档上传 =====
async def upload_document(
self,
kb_id: str,
file_name: str,
file_content: bytes,
file_type: str,
) -> KBDocument:
"""上传并处理文档(带原子性保证和失败清理)
流程:
1. 保存原始文件
2. 解析文档内容
3. 提取多媒体资源
4. 分块处理
5. 生成向量并存储
6. 保存元数据(事务)
7. 更新统计
"""
doc_id = str(uuid.uuid4())
file_path = None
media_paths = []
vec_doc_ids = []
try:
# 1. 保存原始文件
file_path = self.files_path / kb_id / f"{doc_id}.{file_type}"
file_path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(file_path, "wb") as f:
await f.write(file_content)
# 2. 解析文档
parser = self.parsers.get(file_type)
if not parser:
raise ValueError(f"不支持的文件类型: {file_type}")
parse_result = await parser.parse(file_content, file_name)
text_content = parse_result.text
media_items = parse_result.media
# 3. 保存多媒体资源
from astrbot.core.knowledge_base.manager_ops import KBManagerOps
ops = KBManagerOps(self)
saved_media = []
for media_item in media_items:
media = await ops._save_media(
kb_id=kb_id,
doc_id=doc_id,
media_type=media_item.media_type,
file_name=media_item.file_name,
content=media_item.content,
mime_type=media_item.mime_type,
)
saved_media.append(media)
media_paths.append(Path(media.file_path))
# 4. 文档分块
chunks_text = await self.chunker.chunk(text_content)
# 5. 生成向量并存储
saved_chunks = []
for idx, chunk_text in enumerate(chunks_text):
# 存储到向量数据库
vec_doc_id = await self.vec_db.insert(
content=chunk_text,
metadata={
"kb_id": kb_id,
"doc_id": doc_id,
"chunk_index": idx,
},
)
vec_doc_ids.append(str(vec_doc_id))
# 保存块元数据
chunk = KBChunk(
doc_id=doc_id,
kb_id=kb_id,
chunk_index=idx,
content=chunk_text,
char_count=len(chunk_text),
vec_doc_id=str(vec_doc_id),
)
saved_chunks.append(chunk)
# 6. 保存文档元数据(事务)
doc = KBDocument(
doc_id=doc_id,
kb_id=kb_id,
doc_name=file_name,
file_type=file_type,
file_size=len(file_content),
file_path=str(file_path),
chunk_count=len(saved_chunks),
media_count=len(saved_media),
)
async with self.db.get_db() as session:
async with session.begin():
session.add(doc)
for chunk in saved_chunks:
session.add(chunk)
for media in saved_media:
session.add(media)
await session.commit()
await session.refresh(doc)
# 7. 更新知识库统计
await self._update_kb_stats(kb_id)
return doc
except Exception as e:
# 失败清理:删除已创建的资源
from astrbot.core import logger
logger.error(f"文档上传失败,开始清理资源: {e}")
# 清理向量数据库
for vec_id in vec_doc_ids:
try:
await self.vec_db.delete(vec_id)
except Exception as ve:
logger.warning(f"清理向量失败 {vec_id}: {ve}")
# 清理多媒体文件
for media_path in media_paths:
try:
if media_path.exists():
media_path.unlink()
except Exception as me:
logger.warning(f"清理多媒体文件失败 {media_path}: {me}")
# 清理文档文件
if file_path and file_path.exists():
try:
file_path.unlink()
except Exception as fe:
logger.warning(f"清理文档文件失败 {file_path}: {fe}")
# 重新抛出原始异常
raise
# ===== 统计更新 =====
async def _update_kb_stats(self, kb_id: str):
    """Recount a knowledge base's documents and chunks and store the totals.

    The two COUNT queries and the UPDATE are issued on the same session
    inside a single ``session.begin()`` block.

    Args:
        kb_id: Identifier of the knowledge base whose counters are refreshed.
    """
    count_documents = select(func.count(KBDocument.id)).where(
        KBDocument.kb_id == kb_id
    )
    count_chunks = select(func.count(KBChunk.id)).where(KBChunk.kb_id == kb_id)

    async with self.db.get_db() as session:
        async with session.begin():
            # A falsy scalar result (e.g. None) is normalised to 0.
            total_documents = (await session.scalar(count_documents)) or 0
            total_chunks = (await session.scalar(count_chunks)) or 0

            refresh_counters = (
                update(KnowledgeBase)
                .where(KnowledgeBase.kb_id == kb_id)
                .values(doc_count=total_documents, chunk_count=total_chunks)
            )
            await session.execute(refresh_counters)
            await session.commit()
+306
View File
@@ -0,0 +1,306 @@
"""知识库管理器辅助操作
该模块提供文档、块和多媒体的管理操作。
"""
import uuid
from pathlib import Path
from typing import TYPE_CHECKING
import aiofiles
from sqlalchemy import delete, func, select
from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KBMedia
if TYPE_CHECKING:
from astrbot.core.knowledge_base.manager import KBManager
class KBManagerOps:
    """Auxiliary document, chunk and media operations for a ``KBManager``.

    Responsibilities:
    - document management
    - chunk management
    - media management

    The helper reuses the database handle, vector store and storage paths of
    the manager it is constructed with.
    """

    def __init__(self, manager: "KBManager"):
        """Bind the helper to ``manager``.

        Args:
            manager: Owning manager; its ``db``, ``vec_db``, ``media_path`` and
                ``files_path`` attributes are copied onto this instance.
        """
        self.manager = manager
        self.db = manager.db
        self.vec_db = manager.vec_db
        self.media_path = manager.media_path
        self.files_path = manager.files_path

    # ===== Document operations =====

    async def list_documents(
        self, kb_id: str, offset: int = 0, limit: int = 100
    ) -> list[KBDocument]:
        """Return one page of a knowledge base's documents, newest first.

        Args:
            kb_id: Knowledge base identifier.
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.
        """
        async with self.db.get_db() as session:
            stmt = (
                select(KBDocument)
                .where(KBDocument.kb_id == kb_id)
                .offset(offset)
                .limit(limit)
                .order_by(KBDocument.created_at.desc())
            )
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def get_document(self, doc_id: str) -> KBDocument | None:
        """Return the document with ``doc_id``, or ``None`` if there is none."""
        async with self.db.get_db() as session:
            stmt = select(KBDocument).where(KBDocument.doc_id == doc_id)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()

    async def delete_document(self, doc_id: str) -> bool:
        """Delete a document together with its chunks, media and vectors.

        Three-phase strategy:
        1. Delete the vectors from the vector store (partial failure is
           tolerated; the operation aborts when more than half fail).
        2. Delete the SQL rows inside one transaction.
        3. Delete the files on disk (failures are only logged).

        Afterwards the knowledge-base counters are refreshed.

        Args:
            doc_id: Identifier of the document to delete.

        Returns:
            ``False`` when the document does not exist or too many vector
            deletions failed, ``True`` otherwise.
        """
        from astrbot.core import logger

        # 0. Load the document and everything that hangs off it.
        doc = await self.get_document(doc_id)
        if not doc:
            return False

        chunks = await self.list_chunks(doc_id)
        media_list = await self.list_media(doc_id)

        # ===== Phase 1: delete vectors =====
        vec_ids_to_delete = [chunk.vec_doc_id for chunk in chunks]
        failed_vec_ids = []

        for vec_id in vec_ids_to_delete:
            try:
                await self.vec_db.delete(vec_id)
            except Exception as e:
                logger.error(f"删除向量失败: {vec_id}, {e}")
                failed_vec_ids.append(vec_id)

        # Abort when more than 50% of the vector deletions failed.
        if len(failed_vec_ids) > len(vec_ids_to_delete) * 0.5:
            logger.error(
                f"向量删除失败过多 ({len(failed_vec_ids)}/{len(vec_ids_to_delete)}), 中止文档删除"
            )
            return False

        # Partial failure: log it and carry on.
        if failed_vec_ids:
            logger.warning(
                f"部分向量删除失败 ({len(failed_vec_ids)}/{len(vec_ids_to_delete)}), 但继续执行删除操作"
            )

        # ===== Phase 2: delete database rows (one transaction) =====
        async with self.db.get_db() as session:
            async with session.begin():
                await session.execute(delete(KBChunk).where(KBChunk.doc_id == doc_id))
                await session.execute(delete(KBMedia).where(KBMedia.doc_id == doc_id))
                await session.execute(
                    delete(KBDocument).where(KBDocument.doc_id == doc_id)
                )
                await session.commit()

        # ===== Phase 3: delete files (best effort) =====
        for media in media_list:
            try:
                media_path = Path(media.file_path)
                if media_path.exists():
                    media_path.unlink()
            except Exception as e:
                logger.warning(f"删除多媒体文件失败: {media.file_path}, {e}")

        try:
            file_path = Path(doc.file_path)
            if file_path.exists():
                file_path.unlink()
        except Exception as e:
            logger.warning(f"删除文档文件失败: {doc.file_path}, {e}")

        # ===== Refresh knowledge-base counters =====
        await self.manager._update_kb_stats(doc.kb_id)

        return True

    # ===== Chunk operations =====

    async def list_chunks(self, doc_id: str) -> list[KBChunk]:
        """Return every chunk of a document ordered by ``chunk_index``."""
        async with self.db.get_db() as session:
            stmt = (
                select(KBChunk)
                .where(KBChunk.doc_id == doc_id)
                .order_by(KBChunk.chunk_index)
            )
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def delete_chunk(self, chunk_id: str) -> bool:
        """Delete a single chunk.

        Steps:
        1. Look the chunk up.
        2. Delete its vector.
        3. Delete its database row.
        4. Refresh the owning document's counters.
        5. Refresh the owning knowledge base's counters.

        Args:
            chunk_id: Identifier of the chunk to delete.

        Returns:
            ``False`` when the chunk does not exist or its vector could not be
            deleted, ``True`` otherwise.
        """
        from astrbot.core import logger

        # 1. Look the chunk up and copy the fields needed after the session closes.
        async with self.db.get_db() as session:
            stmt = select(KBChunk).where(KBChunk.chunk_id == chunk_id)
            result = await session.execute(stmt)
            chunk = result.scalar_one_or_none()

            if not chunk:
                return False

            doc_id = chunk.doc_id
            kb_id = chunk.kb_id
            vec_doc_id = chunk.vec_doc_id

        # 2. Delete the vector; without it the row is kept so both stay in sync.
        try:
            await self.vec_db.delete(vec_doc_id)
        except Exception as e:
            logger.error(f"删除向量失败: {vec_doc_id}, {e}")
            return False

        # 3. Delete the database row.
        async with self.db.get_db() as session:
            async with session.begin():
                await session.execute(
                    delete(KBChunk).where(KBChunk.chunk_id == chunk_id)
                )
                await session.commit()

        # 4. Refresh the document counters.
        await self._update_doc_stats(doc_id)

        # 5. Refresh the knowledge-base counters as well, as delete_document
        #    does; otherwise the knowledge base's chunk_count goes stale.
        await self.manager._update_kb_stats(kb_id)

        return True

    # ===== Media operations =====

    async def list_media(self, doc_id: str) -> list[KBMedia]:
        """Return every media resource attached to a document."""
        async with self.db.get_db() as session:
            stmt = select(KBMedia).where(KBMedia.doc_id == doc_id)
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def delete_media(self, media_id: str) -> bool:
        """Delete a media resource.

        Steps:
        1. Look the media row up.
        2. Delete the database row.
        3. Delete the file (failures are only logged).
        4. Refresh the owning document's counters.

        Args:
            media_id: Identifier of the media resource to delete.

        Returns:
            ``False`` when the media row does not exist, ``True`` otherwise.
        """
        from astrbot.core import logger

        # 1. Look the media row up.
        async with self.db.get_db() as session:
            stmt = select(KBMedia).where(KBMedia.media_id == media_id)
            result = await session.execute(stmt)
            media = result.scalar_one_or_none()

            if not media:
                return False

            doc_id = media.doc_id
            file_path_str = media.file_path

        # 2. Delete the database row.
        async with self.db.get_db() as session:
            async with session.begin():
                await session.execute(
                    delete(KBMedia).where(KBMedia.media_id == media_id)
                )
                await session.commit()

        # 3. Delete the file (best effort).
        try:
            media_path = Path(file_path_str)
            if media_path.exists():
                media_path.unlink()
        except Exception as e:
            logger.warning(f"删除多媒体文件失败: {file_path_str}, {e}")

        # 4. Refresh the document counters.
        await self._update_doc_stats(doc_id)

        return True

    # ===== Internal helpers =====

    async def _save_media(
        self,
        kb_id: str,
        doc_id: str,
        media_type: str,
        file_name: str,
        content: bytes,
        mime_type: str,
    ) -> KBMedia:
        """Write a media payload to disk and build its ``KBMedia`` record.

        The file is stored at ``media_path / kb_id / doc_id / <uuid><ext>``,
        where ``<ext>`` is the suffix of ``file_name``. The returned record is
        NOT added to any session here; persisting it is left to the caller.

        Args:
            kb_id: Knowledge base identifier.
            doc_id: Owning document identifier.
            media_type: Media category stored on the record.
            file_name: Original file name; only its suffix is used for the path.
            content: Raw bytes to write.
            mime_type: MIME type stored on the record.

        Returns:
            The unsaved ``KBMedia`` record describing the written file.
        """
        media_id = str(uuid.uuid4())
        ext = Path(file_name).suffix

        # Write the file.
        file_path = self.media_path / kb_id / doc_id / f"{media_id}{ext}"
        file_path.parent.mkdir(parents=True, exist_ok=True)
        async with aiofiles.open(file_path, "wb") as f:
            await f.write(content)

        # Build the record.
        media = KBMedia(
            media_id=media_id,
            doc_id=doc_id,
            kb_id=kb_id,
            media_type=media_type,
            file_name=file_name,
            file_path=str(file_path),
            file_size=len(content),
            mime_type=mime_type,
        )

        return media

    async def _update_doc_stats(self, doc_id: str):
        """Recount a document's chunks and media and store the totals.

        Does nothing beyond the counting when the document row no longer exists.

        Args:
            doc_id: Identifier of the document whose counters are refreshed.
        """
        async with self.db.get_db() as session:
            async with session.begin():
                # Count chunks.
                chunk_count = (
                    await session.scalar(
                        select(func.count(KBChunk.id)).where(KBChunk.doc_id == doc_id)
                    )
                ) or 0

                # Count media.
                media_count = (
                    await session.scalar(
                        select(func.count(KBMedia.id)).where(KBMedia.doc_id == doc_id)
                    )
                ) or 0

                # Update the document row.
                doc = await session.scalar(
                    select(KBDocument).where(KBDocument.doc_id == doc_id)
                )
                if doc:
                    doc.chunk_count = chunk_count
                    doc.media_count = media_count

                await session.commit()
+184
View File
@@ -0,0 +1,184 @@
"""知识库管理功能的数据模型定义
该模块定义了知识库系统所需的数据模型,包括:
- KnowledgeBase: 知识库表 (存储在独立的 kb.db)
- KBDocument: 文档表 (存储在独立的 kb.db)
- KBChunk: 文档块表 (存储在独立的 kb.db)
- KBMedia: 多媒体资源表 (存储在独立的 kb.db)
- KBSessionConfig: 会话配置表 (存储在独立的 kb.db)
注意:
- 所有模型存储在独立的知识库数据库 (data/knowledge_base/kb.db)
- 与主数据库 (astrbot.db) 完全解耦
"""
import uuid
from datetime import datetime, timezone
from typing import Optional
from sqlmodel import Field, SQLModel, Text, UniqueConstraint
class KnowledgeBase(SQLModel, table=True):
    """Knowledge base table.

    Stores the basic information, chunking/retrieval settings and counters of
    a knowledge base.
    """

    __tablename__ = "knowledge_bases"

    # Surrogate auto-increment primary key.
    id: int | None = Field(
        primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
    )
    # Public identifier, a UUID string generated on creation.
    kb_id: str = Field(
        max_length=36,
        nullable=False,
        unique=True,
        default_factory=lambda: str(uuid.uuid4()),
        index=True,
    )
    kb_name: str = Field(max_length=100, nullable=False)
    description: Optional[str] = Field(default=None, sa_type=Text)
    emoji: Optional[str] = Field(default="📚", max_length=10)
    embedding_provider_id: Optional[str] = Field(default=None, max_length=100)
    rerank_provider_id: Optional[str] = Field(default=None, max_length=100)
    # Chunking parameters.
    chunk_size: Optional[int] = Field(default=512, nullable=True)
    chunk_overlap: Optional[int] = Field(default=50, nullable=True)
    # Retrieval parameters.
    top_k_dense: Optional[int] = Field(default=50, nullable=True)
    top_k_sparse: Optional[int] = Field(default=50, nullable=True)
    top_m_final: Optional[int] = Field(default=5, nullable=True)
    enable_rerank: Optional[bool] = Field(default=True, nullable=True)
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        # ``onupdate`` must be a callable: a plain ``datetime.now(...)`` value
        # would be evaluated once at import time and every UPDATE would then
        # write that same stale timestamp.
        sa_column_kwargs={"onupdate": lambda: datetime.now(timezone.utc)},
    )
    # Counters maintained by the manager after uploads and deletions.
    doc_count: int = Field(default=0, nullable=False)
    chunk_count: int = Field(default=0, nullable=False)
class KBDocument(SQLModel, table=True):
    """Document table.

    Stores the metadata of documents uploaded to a knowledge base.
    """

    __tablename__ = "kb_documents"

    # Surrogate auto-increment primary key.
    id: int | None = Field(
        primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
    )
    # Public identifier, a UUID string generated on creation.
    doc_id: str = Field(
        max_length=36,
        nullable=False,
        unique=True,
        default_factory=lambda: str(uuid.uuid4()),
        index=True,
    )
    kb_id: str = Field(max_length=36, nullable=False, index=True)
    doc_name: str = Field(max_length=255, nullable=False)
    file_type: str = Field(max_length=20, nullable=False)
    file_size: int = Field(nullable=False)
    file_path: str = Field(max_length=512, nullable=False)
    chunk_count: int = Field(default=0, nullable=False)
    media_count: int = Field(default=0, nullable=False)
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        # ``onupdate`` must be a callable: a plain ``datetime.now(...)`` value
        # would be evaluated once at import time and every UPDATE would then
        # write that same stale timestamp.
        sa_column_kwargs={"onupdate": lambda: datetime.now(timezone.utc)},
    )
class KBChunk(SQLModel, table=True):
    """Document chunk table.

    Holds the text of each chunk produced from a document and the identifier
    linking it to its entry in the vector index.
    """

    __tablename__ = "kb_chunks"

    # Surrogate auto-increment primary key.
    id: int | None = Field(
        default=None,
        primary_key=True,
        sa_column_kwargs={"autoincrement": True},
    )
    # Public identifier, a UUID string generated on creation.
    chunk_id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        max_length=36,
        nullable=False,
        unique=True,
        index=True,
    )
    # Owning document and knowledge base.
    doc_id: str = Field(nullable=False, index=True, max_length=36)
    kb_id: str = Field(nullable=False, index=True, max_length=36)
    # Position of the chunk within its document.
    chunk_index: int = Field(nullable=False)
    content: str = Field(nullable=False, sa_type=Text)
    char_count: int = Field(nullable=False)
    # Identifier of the matching entry in the vector database.
    vec_doc_id: str = Field(nullable=False, index=True, max_length=100)
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
class KBMedia(SQLModel, table=True):
    """Media resource table.

    Holds the images, videos and other media extracted from documents.
    """

    __tablename__ = "kb_media"

    # Surrogate auto-increment primary key.
    id: int | None = Field(
        default=None,
        primary_key=True,
        sa_column_kwargs={"autoincrement": True},
    )
    # Public identifier, a UUID string generated on creation.
    media_id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        max_length=36,
        nullable=False,
        unique=True,
        index=True,
    )
    # Owning document and knowledge base.
    doc_id: str = Field(nullable=False, index=True, max_length=36)
    kb_id: str = Field(nullable=False, index=True, max_length=36)
    media_type: str = Field(nullable=False, max_length=20)
    # Original file name and the location of the stored copy.
    file_name: str = Field(nullable=False, max_length=255)
    file_path: str = Field(nullable=False, max_length=512)
    file_size: int = Field(nullable=False)
    mime_type: str = Field(nullable=False, max_length=100)
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
class KBSessionConfig(SQLModel, table=True):
    """Session knowledge-base configuration table.

    Stores which knowledge bases are associated with a session or a platform.
    The table lives in the knowledge base's own database so it stays decoupled
    from the main database.

    Two configuration scopes are supported:
    - platform: platform-level configuration (e.g. 'qq', 'telegram')
    - session: session-level configuration (e.g. 'qq:group:12345')
    """

    __tablename__ = "kb_session_config"

    # Surrogate auto-increment primary key.
    id: int | None = Field(
        primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
    )
    # Public identifier, a UUID string generated on creation.
    config_id: str = Field(
        max_length=36,
        nullable=False,
        unique=True,
        default_factory=lambda: str(uuid.uuid4()),
    )
    scope: str = Field(max_length=20, nullable=False)
    scope_id: str = Field(max_length=255, nullable=False, index=True)
    # Knowledge base ids serialised as text (the DB helper stores a JSON list).
    kb_ids: str = Field(sa_type=Text, nullable=False)
    top_k: Optional[int] = Field(default=None, nullable=True)
    enable_rerank: Optional[bool] = Field(default=None, nullable=True)
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        # ``onupdate`` must be a callable: a plain ``datetime.now(...)`` value
        # would be evaluated once at import time and every UPDATE would then
        # write that same stale timestamp.
        sa_column_kwargs={"onupdate": lambda: datetime.now(timezone.utc)},
    )

    # At most one configuration row per (scope, scope_id) pair.
    __table_args__ = (
        UniqueConstraint("scope", "scope_id", name="uix_scope_scope_id"),
    )
@@ -0,0 +1,15 @@
"""
文档解析器模块
"""
from .base import BaseParser, MediaItem, ParseResult
from .text_parser import TextParser
from .pdf_parser import PDFParser
# Names re-exported as the public API of the parsers package.
__all__ = [
    "BaseParser",
    "MediaItem",
    "ParseResult",
    "TextParser",
    "PDFParser",
]
@@ -0,0 +1,50 @@
"""文档解析器基类和数据结构
定义了文档解析器的抽象接口和相关数据类。
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
@dataclass
class MediaItem:
    """A media resource extracted from a document.

    Plain value object produced by parsers and handed to the code that
    stores the media.
    """

    # Media category, e.g. "image" or "video".
    media_type: str
    # File name to associate with the extracted resource.
    file_name: str
    # Raw bytes of the resource.
    content: bytes
    # MIME type describing ``content``.
    mime_type: str
@dataclass
class ParseResult:
    """Outcome of parsing a document.

    Bundles the extracted text with the media resources found alongside it.
    """

    # Text content extracted from the document.
    text: str
    # Media resources extracted from the document (may be empty).
    media: list[MediaItem]
class BaseParser(ABC):
    """Abstract base class for document parsers.

    Concrete parsers subclass this and provide ``parse``.
    """

    @abstractmethod
    async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
        """Parse a document into text and media.

        Args:
            file_content: Raw bytes of the file.
            file_name: Name of the file being parsed.

        Returns:
            ParseResult: The extracted text and media resources.
        """
@@ -0,0 +1,100 @@
"""PDF 文件解析器
支持解析 PDF 文件中的文本和图片资源。
"""
import io
from pypdf import PdfReader
from astrbot.core.knowledge_base.parsers.base import (
BaseParser,
MediaItem,
ParseResult,
)
class PDFParser(BaseParser):
    """Parser for PDF documents.

    Extracts the text of every page and the image XObjects referenced from
    each page's resources.
    """

    async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
        """Parse a PDF file.

        Args:
            file_content: Raw bytes of the PDF.
            file_name: Name of the file (not used by this parser).

        Returns:
            ParseResult: Page texts joined by blank lines, plus one
            ``MediaItem`` per image that could be extracted.
        """
        reader = PdfReader(io.BytesIO(file_content))

        # Pass 1: collect the text of every page that yields any.
        page_texts = [
            extracted
            for extracted in (page.extract_text() for page in reader.pages)
            if extracted
        ]

        # Pass 2: collect embedded images. The counter runs across the whole
        # document and only advances for images that were read successfully.
        images = []
        extracted_images = 0
        for page_index, page in enumerate(reader.pages):
            try:
                # Skip pages that have no XObject resources at all.
                if "/Resources" not in page:
                    continue
                page_resources = page["/Resources"]
                if not page_resources or "/XObject" not in page_resources:
                    continue
                xobject_table = page_resources["/XObject"].get_object()
                if not xobject_table:
                    continue

                for xobject_name in xobject_table:
                    try:
                        xobject = xobject_table[xobject_name]
                        if xobject.get("/Subtype") != "/Image":
                            continue

                        payload = xobject.get_data()

                        # /DCTDecode streams are labelled JPEG; everything else
                        # (including /FlateDecode) is labelled PNG.
                        # NOTE(review): confirm the non-JPEG payloads really are
                        # PNG-encoded bytes rather than raw pixel data.
                        is_jpeg = xobject.get("/Filter", "") == "/DCTDecode"
                        extension = "jpg" if is_jpeg else "png"
                        mime = "image/jpeg" if is_jpeg else "image/png"

                        extracted_images += 1
                        images.append(
                            MediaItem(
                                media_type="image",
                                file_name=f"page_{page_index}_img_{extracted_images}.{extension}",
                                content=payload,
                                mime_type=mime,
                            )
                        )
                    except Exception:
                        # One broken image must not abort the rest.
                        continue
            except Exception:
                # One broken page must not abort the other pages.
                continue

        return ParseResult(text="\n\n".join(page_texts), media=images)
@@ -0,0 +1,41 @@
"""文本文件解析器
支持解析 TXT 和 Markdown 文件。
"""
from astrbot.core.knowledge_base.parsers.base import BaseParser, ParseResult
class TextParser(BaseParser):
    """Parser for TXT / Markdown text files.

    Decodes the file by trying a fixed list of encodings in order.
    """

    async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
        """Parse a text file.

        The encodings are tried in order and the first one that decodes the
        whole payload wins.

        Args:
            file_content: Raw bytes of the file.
            file_name: Name of the file; only used in the error message.

        Returns:
            ParseResult: The decoded text with an empty media list.

        Raises:
            ValueError: If none of the encodings can decode the file.
        """
        # "utf-8-sig" decodes plain UTF-8 exactly like "utf-8" but also strips
        # a leading byte-order mark, so files saved "with BOM" no longer leak
        # "\ufeff" into the text that gets chunked and indexed.
        for encoding in ("utf-8-sig", "gbk", "gb2312", "gb18030"):
            try:
                text = file_content.decode(encoding)
                break
            except UnicodeDecodeError:
                continue
        else:
            raise ValueError(f"无法解码文件: {file_name}")

        # Text files carry no media resources.
        return ParseResult(text=text, media=[])
@@ -0,0 +1,16 @@
"""
检索模块
"""
from .manager import RetrievalManager, RetrievalResult
from .sparse_retriever import SparseRetriever, SparseResult
from .rank_fusion import RankFusion, FusedResult
# Names re-exported as the public API of the retrieval package.
__all__ = [
    "RetrievalManager",
    "RetrievalResult",
    "SparseRetriever",
    "SparseResult",
    "RankFusion",
    "FusedResult",
]
@@ -0,0 +1,224 @@
"""检索管理器
协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口
"""
import json
from dataclasses import dataclass
from typing import List, Optional
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.knowledge_base.database import KBDatabase
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
from astrbot.core.provider.provider import RerankProvider
@dataclass
class RetrievalResult:
    """One retrieved chunk together with display information about its source."""

    # Identifier of the retrieved chunk.
    chunk_id: str
    # Identifier and display name of the document the chunk belongs to.
    doc_id: str
    doc_name: str
    # Identifier and display name of the knowledge base the chunk belongs to.
    kb_id: str
    kb_name: str
    # Chunk text.
    content: str
    # Relevance score; higher sorts first.
    score: float
    # Additional per-chunk information supplied by the retrieval manager.
    metadata: dict
class RetrievalManager:
    """Retrieval manager.

    Responsibilities:
    - coordinate dense retrieval, sparse retrieval and reranking
    - fuse and order the results
    """

    def __init__(
        self,
        vec_db: BaseVecDB,
        sparse_retriever: SparseRetriever,
        rank_fusion: RankFusion,
        kb_db: KBDatabase,
        rerank_provider: Optional[RerankProvider] = None,
    ):
        """Initialise the retrieval manager.

        Args:
            vec_db: Vector database instance.
            sparse_retriever: Sparse (BM25) retriever.
            rank_fusion: Result fusion component.
            kb_db: Knowledge base database instance.
            rerank_provider: Rerank provider (optional).
        """
        self.vec_db = vec_db
        self.sparse_retriever = sparse_retriever
        self.rank_fusion = rank_fusion
        self.kb_db = kb_db
        self.rerank_provider = rerank_provider

    async def retrieve(
        self,
        query: str,
        kb_ids: List[str],
        top_k_dense: int = 50,
        top_k_sparse: int = 50,
        top_n_fusion: int = 20,
        top_m_final: int = 5,
        enable_rerank: bool = True,
    ) -> List[RetrievalResult]:
        """Run hybrid retrieval.

        Pipeline:
        1. Dense retrieval (vector similarity)
        2. Sparse retrieval (BM25)
        3. Result fusion (RRF)
        4. Metadata lookup for each fused chunk
        5. Optional rerank

        Args:
            query: Query text.
            kb_ids: Knowledge base ids to search.
            top_k_dense: Number of dense results to request.
            top_k_sparse: Number of sparse results to request.
            top_n_fusion: Number of results kept after fusion.
            top_m_final: Number of results returned.
            enable_rerank: Whether to rerank (requires a rerank provider).

        Returns:
            List[RetrievalResult]: The retrieved chunks; empty when ``kb_ids``
            is empty.
        """
        # Nothing to search. Returning early also avoids asking the vector
        # store for k == 0 candidates (k is derived from len(kb_ids)).
        if not kb_ids:
            return []

        # 1. Dense retrieval.
        dense_results = await self._dense_retrieve(
            query=query,
            kb_ids=kb_ids,
            top_k=top_k_dense,
        )

        # 2. Sparse retrieval.
        sparse_results = await self.sparse_retriever.retrieve(
            query=query,
            kb_ids=kb_ids,
            top_k=top_k_sparse,
        )

        # 3. Fusion.
        fused_results = await self.rank_fusion.fuse(
            dense_results=dense_results,
            sparse_results=sparse_results,
            top_k=top_n_fusion,
        )

        # 4. Convert to RetrievalResult; chunks without metadata are dropped.
        retrieval_results = []
        for fr in fused_results:
            metadata_dict = await self.kb_db.get_chunk_with_metadata(fr.chunk_id)
            if metadata_dict:
                retrieval_results.append(
                    RetrievalResult(
                        chunk_id=fr.chunk_id,
                        doc_id=fr.doc_id,
                        doc_name=metadata_dict["document"].doc_name,
                        kb_id=fr.kb_id,
                        kb_name=metadata_dict["knowledge_base"].kb_name,
                        content=fr.content,
                        score=fr.score,
                        metadata={
                            "chunk_index": metadata_dict["chunk"].chunk_index,
                            "char_count": metadata_dict["chunk"].char_count,
                        },
                    )
                )

        # 5. Rerank (optional); otherwise just truncate.
        if enable_rerank and self.rerank_provider and retrieval_results:
            retrieval_results = await self._rerank(
                query=query,
                results=retrieval_results,
                top_k=top_m_final,
            )
        else:
            retrieval_results = retrieval_results[:top_m_final]

        return retrieval_results

    async def _dense_retrieve(
        self,
        query: str,
        kb_ids: List[str],
        top_k: int,
    ):
        """Dense retrieval (vector similarity).

        The vector store is queried without a knowledge-base filter, so extra
        candidates are requested and then filtered by the ``kb_id`` found in
        each result's metadata.

        Args:
            query: Query text.
            kb_ids: Knowledge base ids to keep.
            top_k: Maximum number of results to return.

        Returns:
            The first ``top_k`` vector-store results whose metadata ``kb_id``
            is in ``kb_ids``, in the order the store returned them.
        """
        vec_results = await self.vec_db.retrieve(
            query=query,
            k=top_k * len(kb_ids) * 2,  # over-fetch so filtering leaves enough
        )

        filtered_results = []
        for result in vec_results:
            raw_metadata = result.data.get("metadata", "{}")
            if isinstance(raw_metadata, dict):
                # Already decoded; use it as-is instead of discarding it.
                metadata = raw_metadata
            else:
                try:
                    metadata = json.loads(raw_metadata)
                except (json.JSONDecodeError, TypeError):
                    metadata = {}
                if not isinstance(metadata, dict):
                    # Valid JSON that is not an object carries no kb_id.
                    metadata = {}

            if metadata.get("kb_id") in kb_ids:
                filtered_results.append(result)

            if len(filtered_results) >= top_k:
                break

        return filtered_results[:top_k]

    async def _rerank(
        self,
        query: str,
        results: List[RetrievalResult],
        top_k: int,
    ) -> List[RetrievalResult]:
        """Rerank results with the configured rerank provider.

        Each result's ``score`` is overwritten in place with the provider's
        relevance score. Provider entries whose index does not address an
        element of ``results`` are ignored.

        Args:
            query: Query text.
            results: Results to rerank.
            top_k: Maximum number of results to return.

        Returns:
            List[RetrievalResult]: The reranked results, best first.
        """
        if not results:
            return []

        docs = [r.content for r in results]

        rerank_results = await self.rerank_provider.rerank(
            query=query,
            documents=docs,
        )

        reranked_list = []
        for rerank_result in rerank_results:
            idx = rerank_result.index
            # Reject negative indices too: they would silently wrap around.
            if 0 <= idx < len(results):
                result = results[idx]
                result.score = rerank_result.relevance_score
                reranked_list.append(result)

        reranked_list.sort(key=lambda x: x.score, reverse=True)

        return reranked_list[:top_k]
@@ -0,0 +1,134 @@
"""检索结果融合器
使用 Reciprocal Rank Fusion (RRF) 算法融合稠密检索和稀疏检索的结果
"""
from dataclasses import dataclass
from typing import Dict, List
from astrbot.core.db.vec_db.base import Result
from astrbot.core.knowledge_base.database import KBDatabase
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseResult
@dataclass
class FusedResult:
    """A retrieval result after fusion."""

    chunk_id: str
    doc_id: str
    kb_id: str
    content: str
    # Fused (RRF) score; higher is better.
    score: float


class RankFusion:
    """Retrieval result fuser.

    Responsibilities:
    - merge dense and sparse retrieval results
    - score them with Reciprocal Rank Fusion (RRF)
    """

    def __init__(self, kb_db: KBDatabase, k: int = 60):
        """Initialise the fuser.

        Args:
            kb_db: Knowledge base database instance, used to map dense results
                (identified by their vector-store id) to chunks.
            k: RRF smoothing constant added to every rank.
        """
        self.kb_db = kb_db
        self.k = k

    async def fuse(
        self,
        dense_results: List[Result],
        sparse_results: List[SparseResult],
        top_k: int = 20,
    ) -> List[FusedResult]:
        """Fuse dense and sparse retrieval results.

        RRF formula:
            score(chunk) = sum(1 / (k + rank_i))

        Dense results are identified by their vector-store id
        (``data["doc_id"]``) while sparse results are identified by
        ``chunk_id``. Every dense result is therefore resolved to its chunk
        first, so that both rankings share the ``chunk_id`` key space and a
        chunk found by both retrievers receives both contributions.

        Args:
            dense_results: Dense results, best first.
            sparse_results: Sparse results, best first.
            top_k: Maximum number of fused results to return.

        Returns:
            List[FusedResult]: Fused results, best first. Dense results that
            cannot be resolved to a chunk are dropped.
        """
        # 1. Sparse side: already keyed by chunk_id. Keep the best (first)
        #    rank if an id shows up more than once.
        sparse_ranks: Dict[str, int] = {}
        sparse_by_chunk: Dict[str, SparseResult] = {}
        for rank, sparse in enumerate(sparse_results, start=1):
            if sparse.chunk_id not in sparse_ranks:
                sparse_ranks[sparse.chunk_id] = rank
                sparse_by_chunk[sparse.chunk_id] = sparse

        # 2. Dense side: translate vec_doc_id -> chunk so ranks are keyed by
        #    chunk_id as well. The rank is the position in the dense list.
        dense_ranks: Dict[str, int] = {}
        dense_chunks = {}
        for rank, dense in enumerate(dense_results, start=1):
            vec_doc_id = dense.data["doc_id"]
            chunk = await self.kb_db.get_chunk_by_vec_doc_id(vec_doc_id)
            if not chunk:
                continue
            if chunk.chunk_id not in dense_ranks:
                dense_ranks[chunk.chunk_id] = rank
                dense_chunks[chunk.chunk_id] = chunk

        # 3. RRF scores: each ranking contributes 1 / (k + rank).
        rrf_scores: Dict[str, float] = {}
        for ranks in (dense_ranks, sparse_ranks):
            for chunk_id, rank in ranks.items():
                rrf_scores[chunk_id] = rrf_scores.get(chunk_id, 0.0) + 1.0 / (
                    self.k + rank
                )

        # 4. Order by score (ties keep insertion order) and truncate.
        sorted_ids = sorted(
            rrf_scores, key=lambda cid: rrf_scores[cid], reverse=True
        )[:top_k]

        # 5. Build the fused results, preferring the sparse payload when the
        #    chunk was found by both retrievers.
        fused_results = []
        for chunk_id in sorted_ids:
            source = sparse_by_chunk.get(chunk_id) or dense_chunks[chunk_id]
            fused_results.append(
                FusedResult(
                    chunk_id=source.chunk_id,
                    doc_id=source.doc_id,
                    kb_id=source.kb_id,
                    content=source.content,
                    score=rrf_scores[chunk_id],
                )
            )

        return fused_results
@@ -0,0 +1,90 @@
"""稀疏检索器
使用 BM25 算法进行基于关键词的文档检索
"""
from dataclasses import dataclass
from typing import List
from rank_bm25 import BM25Okapi
from astrbot.core.knowledge_base.database import KBDatabase
@dataclass
class SparseResult:
    """One hit produced by sparse (keyword) retrieval."""

    # Identifier of the matched chunk.
    chunk_id: str
    # Document and knowledge base the chunk belongs to.
    doc_id: str
    kb_id: str
    # Chunk text.
    content: str
    # Score assigned by the sparse retriever.
    score: float
class SparseRetriever:
    """BM25 sparse retriever.

    Responsibilities:
    - keyword-based chunk retrieval
    - relevance scoring with BM25
    """

    def __init__(self, kb_db: KBDatabase):
        """Initialise the sparse retriever.

        Args:
            kb_db: Knowledge base database instance.
        """
        self.kb_db = kb_db
        self._index_cache = {}  # reserved for caching BM25 indexes (unused here)

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        """Split ``text`` into BM25 terms.

        Whitespace separates terms, exactly like ``str.split()``. In addition
        every CJK ideograph becomes a term of its own, because Chinese text
        has no spaces and would otherwise collapse into a few huge "words"
        that never match a query. Text without CJK characters tokenizes the
        same as ``text.split()``.
        """
        tokens: List[str] = []
        pending: List[str] = []  # characters of the word being accumulated

        for ch in text:
            is_cjk = "\u4e00" <= ch <= "\u9fff" or "\u3400" <= ch <= "\u4dbf"
            if not (is_cjk or ch.isspace()):
                pending.append(ch)
                continue
            # A space or an ideograph ends the current word.
            if pending:
                tokens.append("".join(pending))
                pending = []
            if is_cjk:
                tokens.append(ch)

        if pending:
            tokens.append("".join(pending))
        return tokens

    async def retrieve(
        self,
        query: str,
        kb_ids: List[str],
        top_k: int = 50,
    ) -> List[SparseResult]:
        """Run sparse retrieval.

        Args:
            query: Query text.
            kb_ids: Knowledge base ids to search.
            top_k: Maximum number of results to return.

        Returns:
            List[SparseResult]: Chunks ordered by BM25 score, best first;
            empty when the knowledge bases hold no chunks or no terms.
        """
        # 1. Load every chunk of the requested knowledge bases.
        chunks = await self.kb_db.get_chunks_by_kb_ids(kb_ids)

        if not chunks:
            return []

        # 2. Tokenize the corpus.
        tokenized_corpus = [self._tokenize(chunk.content) for chunk in chunks]

        # A corpus without a single term cannot be scored.
        # NOTE(review): BM25Okapi appears to divide by the vocabulary size
        # while computing IDF - confirm; the guard is harmless either way.
        if not any(tokenized_corpus):
            return []

        # 3. Build the BM25 index.
        bm25 = BM25Okapi(tokenized_corpus)

        # 4. Score every chunk against the query.
        scores = bm25.get_scores(self._tokenize(query))

        # 5. Sort and return the top-k.
        results = [
            SparseResult(
                chunk_id=chunk.chunk_id,
                doc_id=chunk.doc_id,
                kb_id=chunk.kb_id,
                content=chunk.content,
                score=float(score),
            )
            for chunk, score in zip(chunks, scores)
        ]
        results.sort(key=lambda x: x.score, reverse=True)

        return results[:top_k]
@@ -0,0 +1,157 @@
"""会话知识库配置数据库操作
该模块封装会话知识库配置的数据库查询操作。
注意: 会话配置表 (kb_session_config) 存储在知识库独立数据库 (kb.db) 中,
而不是主数据库 (astrbot.db) 中,以实现完全解耦。
"""
import json
from typing import Optional
from sqlalchemy import select
from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.models import KBSessionConfig
class SessionConfigDB:
    """Database access for session knowledge-base configuration.

    Responsibilities:
    - read, write, delete and list session/platform configuration rows

    Note: operates on the knowledge base's own database, keeping it decoupled
    from the main database.
    """

    def __init__(self, db: KBSQLiteDatabase):
        """Initialise the accessor.

        Args:
            db: Knowledge base database instance (kb.db), not the main database.
        """
        self.db = db

    @staticmethod
    def _scoped_query(scope: str, scope_id: str):
        """Build the SELECT matching one ``(scope, scope_id)`` configuration row."""
        return select(KBSessionConfig).where(
            KBSessionConfig.scope == scope,
            KBSessionConfig.scope_id == scope_id,
        )

    async def get_session_kb_ids(self, session_id: str) -> list[str]:
        """Return the knowledge base ids associated with a session.

        Lookup order:
        1. session-level configuration (takes precedence)
        2. platform-level configuration, using the text before the first ``:``
           of ``session_id`` as the platform id
        3. an empty list

        Args:
            session_id: Session identifier.
        """
        async with self.db.get_db() as session:
            # 1. Session-level configuration.
            found = await session.execute(self._scoped_query("session", session_id))
            session_config = found.scalar_one_or_none()
            if session_config:
                return json.loads(session_config.kb_ids)

            # 2. Platform-level configuration (format: platform:xxx:session_id).
            segments = session_id.split(":")
            if len(segments) >= 2:
                found = await session.execute(
                    self._scoped_query("platform", segments[0])
                )
                platform_config = found.scalar_one_or_none()
                if platform_config:
                    return json.loads(platform_config.kb_ids)

            # 3. Nothing configured.
            return []

    async def set_session_kb_ids(
        self,
        scope: str,
        scope_id: str,
        kb_ids: list[str],
        top_k: Optional[int] = None,
        enable_rerank: Optional[bool] = None,
    ) -> KBSessionConfig:
        """Create or update a knowledge-base configuration row.

        When a row for ``(scope, scope_id)`` exists, ``kb_ids`` is always
        replaced while ``top_k`` / ``enable_rerank`` are only changed when a
        non-``None`` value is given.

        Args:
            scope: Configuration scope (session/platform).
            scope_id: Scope identifier (session id or platform id).
            kb_ids: Knowledge base ids; stored as a JSON string.
            top_k: Number of results to return (optional).
            enable_rerank: Whether reranking is enabled (optional).

        Returns:
            The stored configuration row, refreshed after commit.
        """
        serialized_ids = json.dumps(kb_ids)

        async with self.db.get_db() as session:
            found = await session.execute(self._scoped_query(scope, scope_id))
            config = found.scalar_one_or_none()

            if not config:
                # No row yet: create one.
                config = KBSessionConfig(
                    scope=scope,
                    scope_id=scope_id,
                    kb_ids=serialized_ids,
                    top_k=top_k,
                    enable_rerank=enable_rerank,
                )
                session.add(config)
            else:
                # Update the existing row in place.
                config.kb_ids = serialized_ids
                if top_k is not None:
                    config.top_k = top_k
                if enable_rerank is not None:
                    config.enable_rerank = enable_rerank

            await session.commit()
            await session.refresh(config)
            return config

    async def delete_session_kb_config(self, scope: str, scope_id: str) -> bool:
        """Delete a configuration row.

        Returns:
            ``True`` when a row was deleted, ``False`` when none matched.
        """
        async with self.db.get_db() as session:
            found = await session.execute(self._scoped_query(scope, scope_id))
            config = found.scalar_one_or_none()

            if config:
                await session.delete(config)
                await session.commit()
                return True

            return False

    async def list_all_session_configs(
        self, offset: int = 0, limit: int = 100, scope: Optional[str] = None
    ) -> list[KBSessionConfig]:
        """List configuration rows, newest first.

        Args:
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.
            scope: When truthy, only rows of this scope are returned.
        """
        async with self.db.get_db() as session:
            query = select(KBSessionConfig)
            if scope:
                query = query.where(KBSessionConfig.scope == scope)
            query = query.offset(offset).limit(limit)
            query = query.order_by(KBSessionConfig.created_at.desc())

            rows = await session.execute(query)
            return list(rows.scalars().all())