# File listing metadata: 376 lines, 13 KiB, Python
"""知识库管理器
|
|
|
|
该模块提供知识库的CRUD操作和文档上传处理流程。
|
|
"""
|
|
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import aiofiles
|
|
from sqlalchemy import func, select, update
|
|
|
|
from astrbot.core.db import BaseDatabase
|
|
from astrbot.core.db.vec_db.base import BaseVecDB
|
|
from astrbot.core.knowledge_base.chunking.base import BaseChunker
|
|
from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KnowledgeBase
|
|
from astrbot.core.knowledge_base.parsers.base import BaseParser
|
|
|
|
|
|
class KBManager:
    """Knowledge base manager.

    Responsibilities:
    - CRUD operations on knowledge bases
    - Document upload and parsing
    - Chunk generation and storage
    - Media resource management
    """

    def __init__(
        self,
        db: BaseDatabase,
        vec_db: BaseVecDB,
        storage_path: str,
        parsers: dict[str, BaseParser],
        chunker: BaseChunker,
        provider_manager=None,
    ) -> None:
        """Store collaborators and make sure the storage directories exist.

        Args:
            db: Database wrapper; ``db.get_db()`` is used as an async context
                manager yielding a session.
            vec_db: Vector store used for chunk insertion and deletion.
            storage_path: Root directory; ``media/`` and ``files/``
                sub-directories are created beneath it.
            parsers: Mapping from file type to the parser handling it.
            chunker: Splits parsed text into chunks.
            provider_manager: Optional object whose ``rerank_provider_insts``
                attribute (if present) is consulted by ``create_kb``.
        """
        self.db = db
        self.vec_db = vec_db
        self.storage_path = Path(storage_path)
        self.media_path = self.storage_path / "media"
        self.files_path = self.storage_path / "files"
        self.parsers = parsers
        self.chunker = chunker
        self.provider_manager = provider_manager

        # Make sure the storage directories exist.
        self.media_path.mkdir(parents=True, exist_ok=True)
        self.files_path.mkdir(parents=True, exist_ok=True)

    # ===== Knowledge base operations =====

    async def create_kb(
        self,
        kb_name: str,
        description: Optional[str] = None,
        emoji: Optional[str] = None,
        embedding_provider_id: Optional[str] = None,
        rerank_provider_id: Optional[str] = None,
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        top_k_dense: Optional[int] = None,
        top_k_sparse: Optional[int] = None,
        top_m_final: Optional[int] = None,
        enable_rerank: Optional[bool] = None,
    ) -> KnowledgeBase:
        """Create a knowledge base and persist it.

        Args:
            kb_name: Name of the knowledge base.
            description: Optional description.
            emoji: Optional emoji; falls back to "📚" when empty or None.
            embedding_provider_id: Optional embedding provider id.
            rerank_provider_id: Optional rerank provider id.
            chunk_size: Defaults to 512 when None.
            chunk_overlap: Defaults to 50 when None.
            top_k_dense: Defaults to 50 when None.
            top_k_sparse: Defaults to 50 when None.
            top_m_final: Defaults to 5 when None.
            enable_rerank: Whether reranking is enabled.
                - An explicit True/False is used as given.
                - None means: enable only if the provider manager exposes at
                  least one entry in ``rerank_provider_insts``.

        Returns:
            The refreshed ``KnowledgeBase`` instance.
        """
        if enable_rerank is None:
            manager = self.provider_manager
            # bool(...) is required: without it a missing provider manager
            # makes the `and` chain evaluate to None, and None (not False)
            # would be stored on the model.
            enable_rerank = bool(
                manager
                and hasattr(manager, "rerank_provider_insts")
                and len(manager.rerank_provider_insts) > 0
            )

        kb = KnowledgeBase(
            kb_name=kb_name,
            description=description,
            emoji=emoji or "📚",
            embedding_provider_id=embedding_provider_id,
            rerank_provider_id=rerank_provider_id,
            chunk_size=chunk_size if chunk_size is not None else 512,
            chunk_overlap=chunk_overlap if chunk_overlap is not None else 50,
            top_k_dense=top_k_dense if top_k_dense is not None else 50,
            top_k_sparse=top_k_sparse if top_k_sparse is not None else 50,
            top_m_final=top_m_final if top_m_final is not None else 5,
            enable_rerank=enable_rerank,
        )
        async with self.db.get_db() as session:
            session.add(kb)
            await session.commit()
            await session.refresh(kb)
            return kb

    async def get_kb(self, kb_id: str) -> Optional[KnowledgeBase]:
        """Return the knowledge base with ``kb_id``, or None if not found."""
        async with self.db.get_db() as session:
            stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()

    async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
        """Return one page of knowledge bases, newest ``created_at`` first.

        Args:
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.
        """
        async with self.db.get_db() as session:
            stmt = (
                select(KnowledgeBase)
                .offset(offset)
                .limit(limit)
                .order_by(KnowledgeBase.created_at.desc())
            )
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def update_kb(
        self,
        kb_id: str,
        kb_name: Optional[str] = None,
        description: Optional[str] = None,
        emoji: Optional[str] = None,
        embedding_provider_id: Optional[str] = None,
        rerank_provider_id: Optional[str] = None,
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        top_k_dense: Optional[int] = None,
        top_k_sparse: Optional[int] = None,
        top_m_final: Optional[int] = None,
        enable_rerank: Optional[bool] = None,
    ) -> Optional[KnowledgeBase]:
        """Update the given fields of a knowledge base.

        Only arguments that are not None are written, so a field cannot be
        reset to None through this method.

        Returns:
            The refreshed ``KnowledgeBase``, or None when ``kb_id`` does not
            exist.
        """
        # Insertion order mirrors the parameter order, so attributes are
        # assigned in the same sequence as the signature lists them.
        requested = {
            "kb_name": kb_name,
            "description": description,
            "emoji": emoji,
            "embedding_provider_id": embedding_provider_id,
            "rerank_provider_id": rerank_provider_id,
            "chunk_size": chunk_size,
            "chunk_overlap": chunk_overlap,
            "top_k_dense": top_k_dense,
            "top_k_sparse": top_k_sparse,
            "top_m_final": top_m_final,
            "enable_rerank": enable_rerank,
        }

        async with self.db.get_db() as session:
            stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
            result = await session.execute(stmt)
            kb = result.scalar_one_or_none()
            if not kb:
                return None

            for field_name, value in requested.items():
                if value is not None:
                    setattr(kb, field_name, value)

            await session.commit()
            await session.refresh(kb)
            return kb

    async def delete_kb(self, kb_id: str) -> bool:
        """Delete a knowledge base together with its documents.

        Every document returned by ``KBManagerOps.list_documents`` is removed
        through ``KBManagerOps.delete_document`` before the knowledge base row
        itself is deleted.

        Returns:
            True when the knowledge base row was deleted, False when no row
            with ``kb_id`` exists (documents listed for that id are still
            deleted first).
        """
        # Imported lazily — presumably to avoid a circular import, since
        # KBManagerOps is constructed from this manager; confirm.
        from astrbot.core.knowledge_base.manager_ops import KBManagerOps

        ops = KBManagerOps(self)

        # 1. Collect the documents of this knowledge base.
        # NOTE(review): relies on list_documents(kb_id) returning every
        # document; if that helper paginates by default, confirm that large
        # knowledge bases are fully cleaned up.
        docs = await ops.list_documents(kb_id)

        # 2. Delete each document (delegated to KBManagerOps).
        for doc in docs:
            await ops.delete_document(doc.doc_id)

        # 3. Delete the knowledge base record.
        async with self.db.get_db() as session:
            stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
            result = await session.execute(stmt)
            kb = result.scalar_one_or_none()
            if not kb:
                return False

            await session.delete(kb)
            await session.commit()

        return True

    # ===== Document upload =====

    async def upload_document(
        self,
        kb_id: str,
        file_name: str,
        file_content: bytes,
        file_type: str,
    ) -> KBDocument:
        """Upload and process a document, cleaning up on failure.

        Steps:
        0. Resolve the parser for ``file_type`` (before touching the disk)
        1. Save the original file
        2. Parse the document content
        3. Save extracted media resources
        4. Split the text into chunks
        5. Insert each chunk into the vector store
        6. Persist document / chunk / media metadata in one transaction
        7. Refresh the knowledge base statistics

        Args:
            kb_id: Id of the target knowledge base.
            file_name: Original file name, stored as the document name.
            file_content: Raw bytes of the file.
            file_type: Key into ``self.parsers``; also used as the extension
                of the stored file.

        Returns:
            The refreshed ``KBDocument``.

        Raises:
            ValueError: If no parser is registered for ``file_type``.
            Exception: Any error from a later step is re-raised after the
                vectors and files created so far have been removed.
        """
        doc_id = str(uuid.uuid4())
        file_path: Optional[Path] = None
        media_paths: list[Path] = []
        vec_doc_ids: list[str] = []

        try:
            # 0. Fail fast on unsupported types so nothing is written to disk
            #    (and no caller-supplied extension reaches the file system)
            #    unless a parser is registered for it.
            parser = self.parsers.get(file_type)
            if not parser:
                raise ValueError(f"不支持的文件类型: {file_type}")

            # 1. Save the original file.
            file_path = self.files_path / kb_id / f"{doc_id}.{file_type}"
            file_path.parent.mkdir(parents=True, exist_ok=True)

            async with aiofiles.open(file_path, "wb") as f:
                await f.write(file_content)

            # 2. Parse the document.
            parse_result = await parser.parse(file_content, file_name)
            text_content = parse_result.text
            media_items = parse_result.media

            # 3. Save media resources.
            from astrbot.core.knowledge_base.manager_ops import KBManagerOps

            ops = KBManagerOps(self)
            saved_media = []
            for media_item in media_items:
                media = await ops._save_media(
                    kb_id=kb_id,
                    doc_id=doc_id,
                    media_type=media_item.media_type,
                    file_name=media_item.file_name,
                    content=media_item.content,
                    mime_type=media_item.mime_type,
                )
                saved_media.append(media)
                # Tracked so the file can be removed if a later step fails.
                media_paths.append(Path(media.file_path))

            # 4. Chunk the text.
            chunks_text = await self.chunker.chunk(text_content)

            # 5. Insert chunks into the vector store.
            saved_chunks = []
            for idx, chunk_text in enumerate(chunks_text):
                vec_doc_id = await self.vec_db.insert(
                    content=chunk_text,
                    metadata={
                        "kb_id": kb_id,
                        "doc_id": doc_id,
                        "chunk_index": idx,
                    },
                )
                # Recorded immediately so a failure on a later chunk still
                # cleans up this vector.
                vec_doc_ids.append(str(vec_doc_id))

                saved_chunks.append(
                    KBChunk(
                        doc_id=doc_id,
                        kb_id=kb_id,
                        chunk_index=idx,
                        content=chunk_text,
                        char_count=len(chunk_text),
                        vec_doc_id=str(vec_doc_id),
                    )
                )

            # 6. Persist metadata in a single transaction.
            doc = KBDocument(
                doc_id=doc_id,
                kb_id=kb_id,
                doc_name=file_name,
                file_type=file_type,
                file_size=len(file_content),
                file_path=str(file_path),
                chunk_count=len(saved_chunks),
                media_count=len(saved_media),
            )

            async with self.db.get_db() as session:
                # session.begin() commits when the block exits without an
                # error and rolls back otherwise, so no explicit commit is
                # issued inside it.
                async with session.begin():
                    session.add(doc)
                    for chunk in saved_chunks:
                        session.add(chunk)
                    for media in saved_media:
                        session.add(media)

                await session.refresh(doc)

            # 7. Refresh knowledge base statistics.
            await self._update_kb_stats(kb_id)

            return doc

        except Exception as e:
            # NOTE: asyncio.CancelledError derives from BaseException, so a
            # cancelled upload does not pass through this cleanup.
            await self._cleanup_failed_upload(e, vec_doc_ids, media_paths, file_path)
            # Re-raise the original exception.
            raise

    async def _cleanup_failed_upload(
        self,
        error: Exception,
        vec_doc_ids: list[str],
        media_paths: list[Path],
        file_path: Optional[Path],
    ) -> None:
        """Best-effort removal of resources created by a failed upload.

        Each individual cleanup failure is logged as a warning and skipped so
        that the caller can still re-raise the original ``error``.
        """
        from astrbot.core import logger

        logger.error(f"文档上传失败,开始清理资源: {error}")

        # Remove vectors that were already inserted.
        for vec_id in vec_doc_ids:
            try:
                await self.vec_db.delete(vec_id)
            except Exception as ve:
                logger.warning(f"清理向量失败 {vec_id}: {ve}")

        # Remove media files that were already written.
        for media_path in media_paths:
            try:
                if media_path.exists():
                    media_path.unlink()
            except Exception as me:
                logger.warning(f"清理多媒体文件失败 {media_path}: {me}")

        # Remove the stored original file. The existence check sits inside
        # the try so a filesystem error here cannot mask the original error.
        if file_path:
            try:
                if file_path.exists():
                    file_path.unlink()
            except Exception as fe:
                logger.warning(f"清理文档文件失败 {file_path}: {fe}")

    # ===== Statistics =====

    async def _update_kb_stats(self, kb_id: str) -> None:
        """Recompute ``doc_count`` and ``chunk_count`` for a knowledge base.

        Both counts and the UPDATE run inside one ``session.begin()``
        transaction, which commits when the block exits.
        """
        async with self.db.get_db() as session:
            async with session.begin():
                # Number of documents in this knowledge base.
                doc_count = (
                    await session.scalar(
                        select(func.count(KBDocument.id)).where(
                            KBDocument.kb_id == kb_id
                        )
                    )
                    or 0
                )

                # Number of chunks in this knowledge base.
                chunk_count = (
                    await session.scalar(
                        select(func.count(KBChunk.id)).where(KBChunk.kb_id == kb_id)
                    )
                    or 0
                )

                # Write both counters back in the same transaction.
                await session.execute(
                    update(KnowledgeBase)
                    .where(KnowledgeBase.kb_id == kb_id)
                    .values(doc_count=doc_count, chunk_count=chunk_count)
                )
|