feat: update chunk deletion to include document ID and refresh metadata
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from sqlmodel import SQLModel, col, desc
|
||||
from sqlmodel import col, desc
|
||||
from sqlalchemy import text, func, select, update, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
BaseKBModel,
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KnowledgeBase,
|
||||
@@ -61,7 +62,7 @@ class KBSQLiteDatabase:
|
||||
"""初始化数据库,创建表并配置 SQLite 参数"""
|
||||
async with self.engine.begin() as conn:
|
||||
# 创建所有知识库相关表
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
await conn.run_sync(BaseKBModel.metadata.create_all)
|
||||
|
||||
# 配置 SQLite 性能优化参数
|
||||
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||
|
||||
@@ -15,6 +15,7 @@ from astrbot.core import logger
|
||||
|
||||
class KBHelper:
|
||||
vec_db: BaseVecDB
|
||||
kb: KnowledgeBase
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -216,8 +217,9 @@ class KBHelper:
|
||||
kb_id=self.kb.kb_id,
|
||||
vec_db=self.vec_db, # type: ignore
|
||||
)
|
||||
await self.refresh_kb()
|
||||
|
||||
async def delete_chunk(self, chunk_id: str):
|
||||
async def delete_chunk(self, chunk_id: str, doc_id: str):
|
||||
"""删除单个文本块及其相关数据"""
|
||||
vec_db: FaissVecDB = self.vec_db # type: ignore
|
||||
await vec_db.delete(chunk_id)
|
||||
@@ -225,6 +227,27 @@ class KBHelper:
|
||||
kb_id=self.kb.kb_id,
|
||||
vec_db=self.vec_db, # type: ignore
|
||||
)
|
||||
await self.refresh_kb()
|
||||
await self.refresh_document(doc_id)
|
||||
|
||||
async def refresh_kb(self):
|
||||
if self.kb:
|
||||
kb = await self.kb_db.get_kb_by_id(self.kb.kb_id)
|
||||
if kb:
|
||||
self.kb = kb
|
||||
|
||||
async def refresh_document(self, doc_id: str) -> None:
|
||||
"""更新文档的元数据"""
|
||||
doc = await self.get_document(doc_id)
|
||||
if not doc:
|
||||
raise ValueError(f"无法找到 ID 为 {doc_id} 的文档")
|
||||
chunk_count = await self.get_chunk_count_by_doc_id(doc_id)
|
||||
doc.chunk_count = chunk_count
|
||||
async with self.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
session.add(doc)
|
||||
await session.commit()
|
||||
await session.refresh(doc)
|
||||
|
||||
async def get_chunks_by_doc_id(
|
||||
self, doc_id: str, offset: int = 0, limit: int = 100
|
||||
|
||||
@@ -5,7 +5,11 @@ from typing import Optional
|
||||
from sqlmodel import Field, SQLModel, Text, UniqueConstraint
|
||||
|
||||
|
||||
class KnowledgeBase(SQLModel, table=True):
|
||||
class BaseKBModel(SQLModel, table=False):
|
||||
pass
|
||||
|
||||
|
||||
class KnowledgeBase(BaseKBModel, table=True):
|
||||
"""知识库表
|
||||
|
||||
存储知识库的基本信息和统计数据。
|
||||
@@ -51,7 +55,7 @@ class KnowledgeBase(SQLModel, table=True):
|
||||
)
|
||||
|
||||
|
||||
class KBDocument(SQLModel, table=True):
|
||||
class KBDocument(BaseKBModel, table=True):
|
||||
"""文档表
|
||||
|
||||
存储上传到知识库的文档元数据。
|
||||
@@ -83,7 +87,7 @@ class KBDocument(SQLModel, table=True):
|
||||
)
|
||||
|
||||
|
||||
class KBMedia(SQLModel, table=True):
|
||||
class KBMedia(BaseKBModel, table=True):
|
||||
"""多媒体资源表
|
||||
|
||||
存储从文档中提取的图片、视频等多媒体资源。
|
||||
|
||||
@@ -625,12 +625,15 @@ class KnowledgeBaseRoute(Route):
|
||||
chunk_id = data.get("chunk_id")
|
||||
if not chunk_id:
|
||||
return Response().error("缺少参数 chunk_id").__dict__
|
||||
doc_id = data.get("doc_id")
|
||||
if not doc_id:
|
||||
return Response().error("缺少参数 doc_id").__dict__
|
||||
|
||||
kb_helper = await kb_manager.get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
return Response().error("知识库不存在").__dict__
|
||||
|
||||
await kb_helper.delete_chunk(chunk_id)
|
||||
await kb_helper.delete_chunk(chunk_id, doc_id)
|
||||
return Response().ok(message="删除文本块成功").__dict__
|
||||
|
||||
except ValueError as e:
|
||||
|
||||
Reference in New Issue
Block a user