313 lines
9.9 KiB
Python
313 lines
9.9 KiB
Python
"""知识库管理器辅助操作
|
|
|
|
该模块提供文档、块和多媒体的管理操作。
|
|
"""
|
|
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
|
|
import aiofiles
|
|
from sqlalchemy import delete, func, select
|
|
|
|
from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KBMedia
|
|
|
|
if TYPE_CHECKING:
|
|
from astrbot.core.knowledge_base.manager import KBManager
|
|
|
|
|
|
class KBManagerOps:
    """Helper operations for the knowledge-base manager.

    Responsibilities:
    - document management (list / get / delete)
    - chunk management (list / delete)
    - multimedia management (list / delete / save)

    Shares the database handle, vector store and storage paths of the
    owning ``KBManager``.
    """

    def __init__(self, manager: "KBManager"):
        self.manager = manager
        self.db = manager.db
        self.vec_db = manager.vec_db
        self.media_path = manager.media_path
        self.files_path = manager.files_path

    # ===== Document operations =====

    async def list_documents(
        self, kb_id: str, offset: int = 0, limit: int = 100
    ) -> list[KBDocument]:
        """List one page of documents belonging to a knowledge base.

        Args:
            kb_id: ID of the knowledge base.
            offset: Number of documents to skip.
            limit: Maximum number of documents to return.

        Returns:
            The documents, ordered by ``created_at`` descending (newest first).
        """
        async with self.db.get_db() as session:
            stmt = (
                select(KBDocument)
                .where(KBDocument.kb_id == kb_id)
                .order_by(KBDocument.created_at.desc())
                .offset(offset)
                .limit(limit)
            )
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def get_document(self, doc_id: str) -> KBDocument | None:
        """Return the document with ``doc_id``, or None if it does not exist."""
        async with self.db.get_db() as session:
            stmt = select(KBDocument).where(KBDocument.doc_id == doc_id)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()

    async def delete_document(self, doc_id: str) -> bool:
        """Delete a document together with its chunks, media and vectors.

        Deletion runs in three phases:
        1. Delete each chunk's vector from the vector DB. Individual failures
           are tolerated, but if more than half of them fail the deletion is
           aborted before any SQL row or file is touched.
        2. Delete the chunk, media and document rows in one SQL transaction.
        3. Best-effort removal of the media files and the document file;
           failures here are only logged.
        Finally the owning knowledge base's statistics are refreshed.

        Args:
            doc_id: ID of the document to delete.

        Returns:
            True if the document was deleted; False if it does not exist or
            too many vector deletions failed.
        """
        from astrbot.core import logger

        doc = await self.get_document(doc_id)
        if not doc:
            return False

        # Collect everything that must be cleaned up before any row is removed.
        chunks = await self.list_chunks(doc_id)
        media_list = await self.list_media(doc_id)

        # ===== Phase 1: delete vectors (partial failure tolerated, no retry) =====
        vec_ids_to_delete = [chunk.vec_doc_id for chunk in chunks]
        failed_vec_ids = []
        for vec_id in vec_ids_to_delete:
            try:
                await self.vec_db.delete(vec_id)
            except Exception as e:
                logger.error(f"删除向量失败: {vec_id}, {e}")
                failed_vec_ids.append(vec_id)

        # Abort when more than 50% of the vector deletions failed.
        if len(failed_vec_ids) > len(vec_ids_to_delete) * 0.5:
            logger.error(
                f"向量删除失败过多 ({len(failed_vec_ids)}/{len(vec_ids_to_delete)}), 中止文档删除"
            )
            return False

        # Partial failure: log and continue. The failed vectors are left
        # behind in the vector DB.
        if failed_vec_ids:
            logger.warning(
                f"部分向量删除失败 ({len(failed_vec_ids)}/{len(vec_ids_to_delete)}), 但继续执行删除操作"
            )

        # ===== Phase 2: delete DB rows atomically =====
        # session.begin() commits on a clean exit and rolls back on error, so
        # no explicit commit() is needed inside the block.
        async with self.db.get_db() as session:
            async with session.begin():
                await session.execute(delete(KBChunk).where(KBChunk.doc_id == doc_id))
                await session.execute(delete(KBMedia).where(KBMedia.doc_id == doc_id))
                await session.execute(
                    delete(KBDocument).where(KBDocument.doc_id == doc_id)
                )

        # ===== Phase 3: remove files (best effort, never raises) =====
        for media in media_list:
            self._remove_file(media.file_path, "删除多媒体文件失败")
        self._remove_file(doc.file_path, "删除文档文件失败")

        # ===== Refresh knowledge-base statistics =====
        await self.manager._update_kb_stats(doc.kb_id)

        return True

    # ===== Chunk operations =====

    async def list_chunks(self, doc_id: str) -> list[KBChunk]:
        """List all chunks of a document, ordered by ``chunk_index``."""
        async with self.db.get_db() as session:
            stmt = (
                select(KBChunk)
                .where(KBChunk.doc_id == doc_id)
                .order_by(KBChunk.chunk_index)
            )
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def delete_chunk(self, chunk_id: str) -> bool:
        """Delete a single chunk and its vector.

        Steps: look up the chunk, delete its vector, delete its row, then
        refresh the parent document's statistics.

        Args:
            chunk_id: ID of the chunk to delete.

        Returns:
            True on success; False if the chunk does not exist or its vector
            could not be deleted (the row is then left untouched).
        """
        from astrbot.core import logger

        # 1. Look up the chunk and capture the fields needed later.
        async with self.db.get_db() as session:
            stmt = select(KBChunk).where(KBChunk.chunk_id == chunk_id)
            result = await session.execute(stmt)
            chunk = result.scalar_one_or_none()
            if not chunk:
                return False
            doc_id = chunk.doc_id
            vec_doc_id = chunk.vec_doc_id

        # 2. Delete the vector first; keep the row if this fails.
        try:
            await self.vec_db.delete(vec_doc_id)
        except Exception as e:
            logger.error(f"删除向量失败: {vec_doc_id}, {e}")
            return False

        # 3. Delete the row; begin() commits on a clean exit.
        async with self.db.get_db() as session:
            async with session.begin():
                await session.execute(
                    delete(KBChunk).where(KBChunk.chunk_id == chunk_id)
                )

        # 4. Refresh the parent document's counters.
        await self._update_doc_stats(doc_id)

        return True

    # ===== Multimedia operations =====

    async def list_media(self, doc_id: str) -> list[KBMedia]:
        """List all multimedia records attached to a document."""
        async with self.db.get_db() as session:
            stmt = select(KBMedia).where(KBMedia.doc_id == doc_id)
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def delete_media(self, media_id: str) -> bool:
        """Delete a multimedia record and its file.

        Steps: look up the record, delete its row, remove its file (best
        effort; failures are only logged), then refresh the parent
        document's statistics.

        Args:
            media_id: ID of the media record to delete.

        Returns:
            True on success; False if the record does not exist.
        """
        # 1. Look up the record and capture the fields needed later.
        async with self.db.get_db() as session:
            stmt = select(KBMedia).where(KBMedia.media_id == media_id)
            result = await session.execute(stmt)
            media = result.scalar_one_or_none()
            if not media:
                return False
            doc_id = media.doc_id
            file_path_str = media.file_path

        # 2. Delete the row; begin() commits on a clean exit.
        async with self.db.get_db() as session:
            async with session.begin():
                await session.execute(
                    delete(KBMedia).where(KBMedia.media_id == media_id)
                )

        # 3. Remove the file (best effort).
        self._remove_file(file_path_str, "删除多媒体文件失败")

        # 4. Refresh the parent document's counters.
        await self._update_doc_stats(doc_id)

        return True

    # ===== Internal helpers =====

    @staticmethod
    def _remove_file(file_path: str, failure_msg: str) -> None:
        """Best-effort deletion of ``file_path``; errors are logged, never raised.

        A missing file is not an error. ``unlink(missing_ok=True)`` avoids the
        race in a separate ``exists()`` check followed by ``unlink()``.

        Args:
            file_path: Path of the file to remove.
            failure_msg: Prefix for the warning logged when removal fails.
        """
        from astrbot.core import logger

        try:
            Path(file_path).unlink(missing_ok=True)
        except Exception as e:
            logger.warning(f"{failure_msg}: {file_path}, {e}")

    async def _save_media(
        self,
        kb_id: str,
        doc_id: str,
        media_type: str,
        file_name: str,
        content: bytes,
        mime_type: str,
    ) -> KBMedia:
        """Write a media file to disk and build its KBMedia record.

        The file is written to ``media_path/<kb_id>/<doc_id>/<media_id><ext>``,
        where ``media_id`` is a fresh UUID4 and ``ext`` is the suffix of
        ``file_name``. Missing directories are created.

        The returned record is not added to any session here; presumably the
        caller persists it.
        """
        media_id = str(uuid.uuid4())
        ext = Path(file_name).suffix

        # Write the file.
        file_path = self.media_path / kb_id / doc_id / f"{media_id}{ext}"
        file_path.parent.mkdir(parents=True, exist_ok=True)

        async with aiofiles.open(file_path, "wb") as f:
            await f.write(content)

        # Build the (unsaved) record.
        media = KBMedia(
            media_id=media_id,
            doc_id=doc_id,
            kb_id=kb_id,
            media_type=media_type,
            file_name=file_name,
            file_path=str(file_path),
            file_size=len(content),
            mime_type=mime_type,
        )

        return media

    async def _update_doc_stats(self, doc_id: str):
        """Recount a document's chunks and media and store the counts on it.

        Runs in a single transaction. Does nothing if the document no longer
        exists.
        """
        async with self.db.get_db() as session:
            async with session.begin():
                # Count chunks.
                chunk_count = (
                    await session.scalar(
                        select(func.count(KBChunk.id)).where(KBChunk.doc_id == doc_id)
                    )
                ) or 0

                # Count media records.
                media_count = (
                    await session.scalar(
                        select(func.count(KBMedia.id)).where(KBMedia.doc_id == doc_id)
                    )
                ) or 0

                # Update the document; the change is flushed and committed
                # when the begin() block exits.
                doc = await session.scalar(
                    select(KBDocument).where(KBDocument.doc_id == doc_id)
                )
                if doc:
                    doc.chunk_count = chunk_count
                    doc.media_count = media_count