feat: update chunk deletion to include document ID and refresh metadata

This commit is contained in:
Soulter
2025-10-24 14:18:32 +08:00
parent 2f130ba009
commit 1fd482e899
4 changed files with 38 additions and 7 deletions
+3 -2
View File
@@ -1,12 +1,13 @@
from contextlib import asynccontextmanager
from pathlib import Path
from sqlmodel import SQLModel, col, desc
from sqlmodel import col, desc
from sqlalchemy import text, func, select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from astrbot.core import logger
from astrbot.core.knowledge_base.models import (
BaseKBModel,
KBDocument,
KBMedia,
KnowledgeBase,
@@ -61,7 +62,7 @@ class KBSQLiteDatabase:
"""初始化数据库,创建表并配置 SQLite 参数"""
async with self.engine.begin() as conn:
# 创建所有知识库相关表
await conn.run_sync(SQLModel.metadata.create_all)
await conn.run_sync(BaseKBModel.metadata.create_all)
# 配置 SQLite 性能优化参数
await conn.execute(text("PRAGMA journal_mode=WAL"))
+24 -1
View File
@@ -15,6 +15,7 @@ from astrbot.core import logger
class KBHelper:
vec_db: BaseVecDB
kb: KnowledgeBase
def __init__(
self,
@@ -216,8 +217,9 @@ class KBHelper:
kb_id=self.kb.kb_id,
vec_db=self.vec_db, # type: ignore
)
await self.refresh_kb()
async def delete_chunk(self, chunk_id: str):
async def delete_chunk(self, chunk_id: str, doc_id: str):
"""删除单个文本块及其相关数据"""
vec_db: FaissVecDB = self.vec_db # type: ignore
await vec_db.delete(chunk_id)
@@ -225,6 +227,27 @@ class KBHelper:
kb_id=self.kb.kb_id,
vec_db=self.vec_db, # type: ignore
)
await self.refresh_kb()
await self.refresh_document(doc_id)
async def refresh_kb(self):
    """Reload ``self.kb`` from the knowledge-base database.

    Does nothing when ``self.kb`` is falsy. When the lookup by ``kb_id``
    returns a falsy value, the currently held object is kept unchanged.
    """
    if not self.kb:
        return
    latest = await self.kb_db.get_kb_by_id(self.kb.kb_id)
    if latest:
        self.kb = latest
async def refresh_document(self, doc_id: str) -> None:
    """Recompute and persist the stored chunk count of a document.

    Args:
        doc_id: ID of the document whose metadata should be refreshed.

    Raises:
        ValueError: If ``get_document`` returns no document for ``doc_id``.
    """
    doc = await self.get_document(doc_id)
    if not doc:
        raise ValueError(f"无法找到 ID 为 {doc_id} 的文档")
    doc.chunk_count = await self.get_chunk_count_by_doc_id(doc_id)
    async with self.kb_db.get_db() as session:
        # ``session.begin()`` commits on a clean exit and rolls back on error,
        # so no explicit ``commit()`` is issued inside the block: committing
        # there would close the transaction while the context manager is
        # still open and make further session calls inside it fail.
        async with session.begin():
            session.add(doc)
        # Reload the persisted state only after the transaction has finished.
        await session.refresh(doc)
async def get_chunks_by_doc_id(
self, doc_id: str, offset: int = 0, limit: int = 100
+7 -3
View File
@@ -5,7 +5,11 @@ from typing import Optional
from sqlmodel import Field, SQLModel, Text, UniqueConstraint
class KnowledgeBase(SQLModel, table=True):
# Common non-table base class that the knowledge-base table models in this
# module (KnowledgeBase, KBDocument, KBMedia) inherit from.
# NOTE(review): no `metadata` attribute is overridden here, so
# `BaseKBModel.metadata` presumably resolves to the shared `SQLModel.metadata`;
# confirm whether a dedicated MetaData object was intended for these tables.
class BaseKBModel(SQLModel, table=False):
pass
class KnowledgeBase(BaseKBModel, table=True):
"""知识库表
存储知识库的基本信息和统计数据。
@@ -51,7 +55,7 @@ class KnowledgeBase(SQLModel, table=True):
)
class KBDocument(SQLModel, table=True):
class KBDocument(BaseKBModel, table=True):
"""文档表
存储上传到知识库的文档元数据。
@@ -83,7 +87,7 @@ class KBDocument(SQLModel, table=True):
)
class KBMedia(SQLModel, table=True):
class KBMedia(BaseKBModel, table=True):
"""多媒体资源表
存储从文档中提取的图片、视频等多媒体资源。
+4 -1
View File
@@ -625,12 +625,15 @@ class KnowledgeBaseRoute(Route):
chunk_id = data.get("chunk_id")
if not chunk_id:
return Response().error("缺少参数 chunk_id").__dict__
doc_id = data.get("doc_id")
if not doc_id:
return Response().error("缺少参数 doc_id").__dict__
kb_helper = await kb_manager.get_kb(kb_id)
if not kb_helper:
return Response().error("知识库不存在").__dict__
await kb_helper.delete_chunk(chunk_id)
await kb_helper.delete_chunk(chunk_id, doc_id)
return Response().ok(message="删除文本块成功").__dict__
except ValueError as e: