feat: refactor document storage to use SQLModel and enhance database operations
This commit is contained in:
@@ -1,32 +1,99 @@
|
||||
import aiosqlite
|
||||
import os
|
||||
import json
|
||||
from datetime import datetime
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy import Text, Column
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import Field, SQLModel, select, col, func, text, MetaData
|
||||
|
||||
|
||||
class BaseDocModel(SQLModel, table=False):
    """Non-table base class for the document-storage models.

    It carries its own ``MetaData`` instance, so tables declared on
    subclasses are registered there rather than on the default
    ``SQLModel.metadata``.
    """

    metadata = MetaData()
|
||||
|
||||
|
||||
class Document(BaseDocModel, table=True):
    """Row model for the ``documents`` table."""

    __tablename__ = "documents"  # type: ignore

    # Auto-incrementing integer primary key; stays None until the database
    # assigns a value.
    id: int | None = Field(
        default=None,
        primary_key=True,
        sa_column_kwargs={"autoincrement": True},
    )
    # Caller-supplied string identifier; required, but no uniqueness
    # constraint is declared here.
    doc_id: str = Field(nullable=False)
    # Document body; required.
    text: str = Field(nullable=False)
    # Stored in the database column named "metadata" as TEXT. The attribute
    # carries a trailing underscore, presumably because ``metadata`` is
    # already taken on the model class by the MetaData registry.
    metadata_: str | None = Field(
        default=None,
        sa_column=Column("metadata", Text),
    )
    # Timestamps; the model itself supplies no default other than None.
    created_at: datetime | None = Field(default=None)
    updated_at: datetime | None = Field(default=None)
|
||||
|
||||
|
||||
class DocumentStorage:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self.connection = None
|
||||
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
||||
self.engine: AsyncEngine | None = None
|
||||
self.async_session_maker: sessionmaker | None = None
|
||||
self.sqlite_init_path = os.path.join(
|
||||
os.path.dirname(__file__), "sqlite_init.sql"
|
||||
)
|
||||
|
||||
async def initialize(self):
    """Initialize the SQLite database and create the documents table if it doesn't exist.

    Connects the engine, creates every table registered on
    ``BaseDocModel.metadata``, then ensures the ``kb_doc_id`` and ``user_id``
    generated columns (extracted from the JSON ``metadata`` column) and
    their indexes exist.

    Each column is handled independently, so an already-present column
    no longer prevents the remaining column and the indexes from being
    created.

    Raises:
        RuntimeError: If no engine is available after ``connect()``.
    """
    import logging

    from sqlalchemy.exc import OperationalError

    await self.connect()
    if self.engine is None:
        raise RuntimeError("Failed to connect to the database.")

    # ``engine.begin()`` commits when the block exits without an exception.
    async with self.engine.begin() as conn:
        # Create tables using SQLModel
        await conn.run_sync(BaseDocModel.metadata.create_all)

        for column in ("kb_doc_id", "user_id"):
            try:
                # SQLite refuses to ALTER TABLE ADD COLUMN a STORED generated
                # column; a VIRTUAL one can be added and can still be indexed.
                await conn.execute(
                    text(
                        f"ALTER TABLE documents ADD COLUMN {column} TEXT "
                        f"GENERATED ALWAYS AS (json_extract(metadata, '$.{column}')) VIRTUAL"
                    )
                )
            except OperationalError as exc:
                if "duplicate column name" not in str(exc).lower():
                    # Stay best-effort, but do not hide the failure, and do
                    # not try to index a column that could not be added.
                    logging.getLogger(__name__).warning(
                        "Could not add generated column %s to documents: %s",
                        column,
                        exc,
                    )
                    continue
                # Column already exists: fall through and ensure its index.

            await conn.execute(
                text(
                    f"CREATE INDEX IF NOT EXISTS idx_documents_{column} ON documents({column})"
                )
            )
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to the SQLite database."""
|
||||
self.connection = await aiosqlite.connect(self.db_path)
|
||||
if self.engine is None:
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
future=True,
|
||||
)
|
||||
self.async_session_maker = sessionmaker(
|
||||
self.engine, # type: ignore
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
) # type: ignore
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_session(self):
|
||||
"""Context manager for database sessions."""
|
||||
async with self.async_session_maker() as session: # type: ignore
|
||||
yield session
|
||||
|
||||
async def get_documents(
|
||||
self,
|
||||
@@ -39,37 +106,114 @@ class DocumentStorage:
|
||||
|
||||
Args:
|
||||
metadata_filters (dict): The metadata filters to apply.
|
||||
ids (list | None): Optional list of document IDs to filter.
|
||||
offset (int | None): Offset for pagination.
|
||||
limit (int | None): Limit for pagination.
|
||||
|
||||
Returns:
|
||||
list: The list of document IDs(primary key, not doc_id) that match the filters.
|
||||
list: The list of documents that match the filters.
|
||||
"""
|
||||
assert self.connection is not None, "Database connection is not initialized."
|
||||
# metadata filter -> SQL WHERE clause
|
||||
where_clauses = []
|
||||
values = []
|
||||
for key, val in metadata_filters.items():
|
||||
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
|
||||
values.append(val)
|
||||
if ids is not None and len(ids) > 0:
|
||||
ids = [str(i) for i in ids if i != -1]
|
||||
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
|
||||
values.extend(ids)
|
||||
where_sql = " AND ".join(where_clauses) or "1=1"
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
query = select(Document)
|
||||
|
||||
for key, val in metadata_filters.items():
|
||||
query = query.where(
|
||||
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
|
||||
).params(**{f"filter_{key}": val})
|
||||
|
||||
if ids is not None and len(ids) > 0:
|
||||
valid_ids = [int(i) for i in ids if i != -1]
|
||||
if valid_ids:
|
||||
query = query.where(col(Document.id).in_(valid_ids))
|
||||
|
||||
result = []
|
||||
async with self.connection.cursor() as cursor:
|
||||
sql = f"SELECT * FROM documents WHERE {where_sql}"
|
||||
if limit is not None:
|
||||
sql += " LIMIT ?"
|
||||
values.append(limit)
|
||||
query = query.limit(limit)
|
||||
if offset is not None:
|
||||
sql += " OFFSET ?"
|
||||
values.append(offset)
|
||||
query = query.offset(offset)
|
||||
|
||||
await cursor.execute(sql, values)
|
||||
for row in await cursor.fetchall():
|
||||
result.append(await self.tuple_to_dict(row))
|
||||
return result
|
||||
result = await session.execute(query)
|
||||
documents = result.scalars().all()
|
||||
|
||||
return [self._document_to_dict(doc) for doc in documents]
|
||||
|
||||
async def insert_document(self, doc_id: str, text: str, metadata: dict) -> int:
    """Insert a single document and return its integer ID.

    Args:
        doc_id (str): The document ID (UUID string).
        text (str): The document text.
        metadata (dict): The document metadata; stored as a JSON string.

    Returns:
        int: The integer ID of the inserted document.
    """
    assert self.engine is not None, "Database connection is not initialized."

    # Take the timestamp once so a new row has created_at == updated_at.
    now = datetime.now()
    document = Document(
        doc_id=doc_id,
        text=text,
        metadata_=json.dumps(metadata),
        created_at=now,
        updated_at=now,
    )

    async with self.get_session() as session:
        async with session.begin():
            session.add(document)
            await session.flush()  # Flush to get the ID
            # Returning from inside the block still lets ``session.begin()``
            # commit on exit.
            return document.id  # type: ignore
|
||||
|
||||
async def insert_documents_batch(
|
||||
self, doc_ids: list[str], texts: list[str], metadatas: list[dict]
|
||||
) -> list[int]:
|
||||
"""Batch insert documents and return their integer IDs.
|
||||
|
||||
Args:
|
||||
doc_ids (list[str]): List of document IDs (UUID strings).
|
||||
texts (list[str]): List of document texts.
|
||||
metadatas (list[dict]): List of document metadata.
|
||||
|
||||
Returns:
|
||||
list[int]: List of integer IDs of the inserted documents.
|
||||
"""
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
async with session.begin():
|
||||
import json
|
||||
|
||||
documents = []
|
||||
for doc_id, text, metadata in zip(doc_ids, texts, metadatas):
|
||||
document = Document(
|
||||
doc_id=doc_id,
|
||||
text=text,
|
||||
metadata_=json.dumps(metadata),
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
documents.append(document)
|
||||
session.add(document)
|
||||
|
||||
await session.flush() # Flush to get all IDs
|
||||
return [doc.id for doc in documents] # type: ignore
|
||||
|
||||
async def delete_document_by_doc_id(self, doc_id: str):
    """Delete a document by its doc_id.

    Every row whose ``doc_id`` matches is deleted; nothing happens when no
    row matches.

    Args:
        doc_id (str): The doc_id of the document to delete.
    """
    assert self.engine is not None, "Database connection is not initialized."

    async with self.get_session() as session:
        async with session.begin():
            query = select(Document).where(col(Document.doc_id) == doc_id)
            result = await session.execute(query)

            # ``doc_id`` is not declared unique, so iterate over all matches
            # rather than using scalar_one_or_none(), which raises when more
            # than one row is returned.
            for document in result.scalars().all():
                await session.delete(document)
|
||||
|
||||
async def get_document_by_doc_id(self, doc_id: str):
|
||||
"""Retrieve a document by its doc_id.
|
||||
@@ -78,30 +222,38 @@ class DocumentStorage:
|
||||
doc_id (str): The doc_id of the document to retrieve.
|
||||
|
||||
Returns:
|
||||
dict: The document data.
|
||||
dict: The document data or None if not found.
|
||||
"""
|
||||
assert self.connection is not None, "Database connection is not initialized."
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return await self.tuple_to_dict(row)
|
||||
else:
|
||||
return None
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
query = select(Document).where(col(Document.doc_id) == doc_id)
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if document:
|
||||
return self._document_to_dict(document)
|
||||
return None
|
||||
|
||||
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
    """Update a document by its doc_id.

    Sets ``text`` and refreshes ``updated_at`` on every row whose ``doc_id``
    matches; nothing happens when no row matches.

    Args:
        doc_id (str): The doc_id.
        new_text (str): The new text to update the document with.
    """
    assert self.engine is not None, "Database connection is not initialized."

    async with self.get_session() as session:
        async with session.begin():
            query = select(Document).where(col(Document.doc_id) == doc_id)
            result = await session.execute(query)

            # ``doc_id`` is not declared unique; update all matches instead
            # of raising when more than one row is returned.
            now = datetime.now()
            for document in result.scalars().all():
                document.text = new_text
                document.updated_at = now
                session.add(document)
|
||||
|
||||
async def delete_documents(self, metadata_filters: dict):
|
||||
"""Delete documents by their metadata filters.
|
||||
@@ -109,16 +261,22 @@ class DocumentStorage:
|
||||
Args:
|
||||
metadata_filters (dict): The metadata filters to apply.
|
||||
"""
|
||||
assert self.connection is not None, "Database connection is not initialized."
|
||||
async with self.connection.cursor() as cursor:
|
||||
where_clauses = []
|
||||
values = []
|
||||
for key, val in metadata_filters.items():
|
||||
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
|
||||
values.append(val)
|
||||
where_sql = " AND ".join(where_clauses) or "1=1"
|
||||
await cursor.execute(f"DELETE FROM documents WHERE {where_sql}", values)
|
||||
await self.connection.commit()
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
async with session.begin():
|
||||
query = select(Document)
|
||||
|
||||
for key, val in metadata_filters.items():
|
||||
query = query.where(
|
||||
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
|
||||
).params(**{f"filter_{key}": val})
|
||||
|
||||
result = await session.execute(query)
|
||||
documents = result.scalars().all()
|
||||
|
||||
for doc in documents:
|
||||
await session.delete(doc)
|
||||
|
||||
async def count_documents(self, metadata_filters: dict | None = None) -> int:
|
||||
"""Count documents in the database.
|
||||
@@ -129,20 +287,20 @@ class DocumentStorage:
|
||||
Returns:
|
||||
int: The count of documents.
|
||||
"""
|
||||
assert self.connection is not None, "Database connection is not initialized."
|
||||
async with self.connection.cursor() as cursor:
|
||||
sql = "SELECT COUNT(*) FROM documents"
|
||||
values = []
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
query = select(func.count(col(Document.id)))
|
||||
|
||||
if metadata_filters:
|
||||
where_clauses = []
|
||||
for key, val in metadata_filters.items():
|
||||
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
|
||||
values.append(val)
|
||||
where_sql = " AND ".join(where_clauses)
|
||||
sql += f" WHERE {where_sql}"
|
||||
await cursor.execute(sql, values)
|
||||
count = await cursor.fetchone()
|
||||
return count[0] if count else 0
|
||||
query = query.where(
|
||||
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
|
||||
).params(**{f"filter_{key}": val})
|
||||
|
||||
result = await session.execute(query)
|
||||
count = result.scalar_one_or_none()
|
||||
return count if count is not None else 0
|
||||
|
||||
async def get_user_ids(self) -> list[str]:
|
||||
"""Retrieve all user IDs from the documents table.
|
||||
@@ -150,12 +308,38 @@ class DocumentStorage:
|
||||
Returns:
|
||||
list: A list of user IDs.
|
||||
"""
|
||||
assert self.connection is not None, "Database connection is not initialized."
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT DISTINCT user_id FROM documents")
|
||||
rows = await cursor.fetchall()
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
query = text(
|
||||
"SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL"
|
||||
)
|
||||
result = await session.execute(query)
|
||||
rows = result.fetchall()
|
||||
return [row[0] for row in rows]
|
||||
|
||||
def _document_to_dict(self, document: Document) -> dict:
|
||||
"""Convert a Document model to a dictionary.
|
||||
|
||||
Args:
|
||||
document (Document): The document to convert.
|
||||
|
||||
Returns:
|
||||
dict: The converted dictionary.
|
||||
"""
|
||||
return {
|
||||
"id": document.id,
|
||||
"doc_id": document.doc_id,
|
||||
"text": document.text,
|
||||
"metadata": document.metadata_,
|
||||
"created_at": document.created_at.isoformat()
|
||||
if isinstance(document.created_at, datetime)
|
||||
else document.created_at,
|
||||
"updated_at": document.updated_at.isoformat()
|
||||
if isinstance(document.updated_at, datetime)
|
||||
else document.updated_at,
|
||||
}
|
||||
|
||||
async def tuple_to_dict(self, row):
|
||||
"""Convert a tuple to a dictionary.
|
||||
|
||||
@@ -164,6 +348,8 @@ class DocumentStorage:
|
||||
|
||||
Returns:
|
||||
dict: The converted dictionary.
|
||||
|
||||
Note: This method is kept for backward compatibility but is no longer used internally.
|
||||
"""
|
||||
return {
|
||||
"id": row[0],
|
||||
@@ -176,6 +362,7 @@ class DocumentStorage:
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection to the SQLite database."""
|
||||
if self.connection:
|
||||
await self.connection.close()
|
||||
self.connection = None
|
||||
if self.engine:
|
||||
await self.engine.dispose()
|
||||
self.engine = None
|
||||
self.async_session_maker = None
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import uuid
|
||||
import json
|
||||
import time
|
||||
import numpy as np
|
||||
from .document_storage import DocumentStorage
|
||||
@@ -41,26 +40,18 @@ class FaissVecDB(BaseVecDB):
|
||||
"""
|
||||
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
"""
|
||||
assert self.document_storage.connection is not None, (
|
||||
"Database connection is not initialized."
|
||||
)
|
||||
metadata = metadata or {}
|
||||
str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID
|
||||
|
||||
vector = await self.embedding_provider.get_embedding(content)
|
||||
vector = np.array(vector, dtype=np.float32)
|
||||
async with self.document_storage.connection.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
|
||||
(str_id, content, json.dumps(metadata)),
|
||||
)
|
||||
await self.document_storage.connection.commit()
|
||||
result = await self.document_storage.get_document_by_doc_id(str_id)
|
||||
int_id = result["id"]
|
||||
|
||||
# 插入向量到 FAISS
|
||||
await self.embedding_storage.insert(vector, int_id)
|
||||
return int_id
|
||||
# 使用 DocumentStorage 的方法插入文档
|
||||
int_id = await self.document_storage.insert_document(str_id, content, metadata)
|
||||
|
||||
# 插入向量到 FAISS
|
||||
await self.embedding_storage.insert(vector, int_id)
|
||||
return int_id
|
||||
|
||||
async def insert_batch(
|
||||
self,
|
||||
@@ -78,9 +69,6 @@ class FaissVecDB(BaseVecDB):
|
||||
Args:
|
||||
progress_callback: 进度回调函数,接收参数 (current, total)
|
||||
"""
|
||||
assert self.document_storage.connection is not None, (
|
||||
"Database connection is not initialized."
|
||||
)
|
||||
metadatas = metadatas or [{} for _ in contents]
|
||||
ids = ids or [str(uuid.uuid4()) for _ in contents]
|
||||
|
||||
@@ -98,23 +86,15 @@ class FaissVecDB(BaseVecDB):
|
||||
f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds."
|
||||
)
|
||||
|
||||
int_ids = []
|
||||
async with self.document_storage.connection.cursor() as cursor:
|
||||
for str_id, content, metadata in zip(ids, contents, metadatas):
|
||||
await cursor.execute(
|
||||
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
|
||||
(str_id, content, json.dumps(metadata)),
|
||||
)
|
||||
await self.document_storage.connection.commit()
|
||||
# 使用 DocumentStorage 的批量插入方法
|
||||
int_ids = await self.document_storage.insert_documents_batch(
|
||||
ids, contents, metadatas
|
||||
)
|
||||
|
||||
for str_id in ids:
|
||||
result = await self.document_storage.get_document_by_doc_id(str_id)
|
||||
int_ids.append(result["id"])
|
||||
|
||||
# 批量插入向量到 FAISS
|
||||
vectors_array = np.array(vectors).astype("float32")
|
||||
await self.embedding_storage.insert_batch(vectors_array, int_ids)
|
||||
return int_ids
|
||||
# 批量插入向量到 FAISS
|
||||
vectors_array = np.array(vectors).astype("float32")
|
||||
await self.embedding_storage.insert_batch(vectors_array, int_ids)
|
||||
return int_ids
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
@@ -182,19 +162,15 @@ class FaissVecDB(BaseVecDB):
|
||||
"""
|
||||
删除一条文档块(chunk)
|
||||
"""
|
||||
assert self.document_storage.connection is not None, (
|
||||
"Database connection is not initialized."
|
||||
)
|
||||
# 获得对应的 int id
|
||||
result = await self.document_storage.get_document_by_doc_id(doc_id)
|
||||
int_id = result["id"] if result else None
|
||||
if int_id is None:
|
||||
return
|
||||
await self.document_storage.connection.execute(
|
||||
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
|
||||
)
|
||||
|
||||
# 使用 DocumentStorage 的删除方法
|
||||
await self.document_storage.delete_document_by_doc_id(doc_id)
|
||||
await self.embedding_storage.delete([int_id])
|
||||
await self.document_storage.connection.commit()
|
||||
|
||||
async def close(self):
|
||||
await self.document_storage.close()
|
||||
@@ -206,9 +182,6 @@ class FaissVecDB(BaseVecDB):
|
||||
Args:
|
||||
metadata_filter (dict | None): 元数据过滤器
|
||||
"""
|
||||
assert self.document_storage.connection is not None, (
|
||||
"Database connection is not initialized."
|
||||
)
|
||||
count = await self.document_storage.count_documents(
|
||||
metadata_filters=metadata_filter or {}
|
||||
)
|
||||
|
||||
@@ -14,8 +14,6 @@ from astrbot.core.knowledge_base.models import (
|
||||
)
|
||||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class KBSQLiteDatabase:
|
||||
def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
|
||||
@@ -167,14 +165,14 @@ class KBSQLiteDatabase:
|
||||
await self.engine.dispose()
|
||||
logger.info(f"知识库数据库已关闭: {self.db_path}")
|
||||
|
||||
async def get_kb_by_id(self, kb_id: str) -> Optional[KnowledgeBase]:
|
||||
async def get_kb_by_id(self, kb_id: str) -> KnowledgeBase | None:
|
||||
"""根据 ID 获取知识库"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_id) == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_kb_by_name(self, kb_name: str) -> Optional[KnowledgeBase]:
|
||||
async def get_kb_by_name(self, kb_name: str) -> KnowledgeBase | None:
|
||||
"""根据名称获取知识库"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_name) == kb_name)
|
||||
@@ -202,7 +200,7 @@ class KBSQLiteDatabase:
|
||||
|
||||
# ===== 文档查询 =====
|
||||
|
||||
async def get_document_by_id(self, doc_id: str) -> Optional[KBDocument]:
|
||||
async def get_document_by_id(self, doc_id: str) -> KBDocument | None:
|
||||
"""根据 ID 获取文档"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KBDocument).where(col(KBDocument.doc_id) == doc_id)
|
||||
@@ -233,7 +231,7 @@ class KBSQLiteDatabase:
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_document_with_metadata(self, doc_id: str) -> Optional[dict]:
|
||||
async def get_document_with_metadata(self, doc_id: str) -> dict | None:
|
||||
async with self.get_db() as session:
|
||||
stmt = (
|
||||
select(KBDocument, KnowledgeBase)
|
||||
@@ -262,7 +260,7 @@ class KBSQLiteDatabase:
|
||||
await session.commit()
|
||||
|
||||
# 在 vec db 中删除相关向量
|
||||
await vec_db.delete_documents(metadata_filters={"doc_id": doc_id})
|
||||
await vec_db.delete_documents(metadata_filters={"kb_doc_id": doc_id})
|
||||
|
||||
# ===== 多媒体查询 =====
|
||||
|
||||
@@ -273,7 +271,7 @@ class KBSQLiteDatabase:
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_media_by_id(self, media_id: str) -> Optional[KBMedia]:
|
||||
async def get_media_by_id(self, media_id: str) -> KBMedia | None:
|
||||
"""根据 ID 获取多媒体资源"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KBMedia).where(col(KBMedia.media_id) == media_id)
|
||||
|
||||
@@ -175,7 +175,7 @@ class KBHelper:
|
||||
metadatas.append(
|
||||
{
|
||||
"kb_id": self.kb.kb_id,
|
||||
"doc_id": doc_id,
|
||||
"kb_doc_id": doc_id,
|
||||
"chunk_index": idx,
|
||||
}
|
||||
)
|
||||
@@ -297,7 +297,7 @@ class KBHelper:
|
||||
"""获取文档的所有块及其元数据"""
|
||||
vec_db: FaissVecDB = self.vec_db # type: ignore
|
||||
chunks = await vec_db.document_storage.get_documents(
|
||||
metadata_filters={"doc_id": doc_id}, offset=offset, limit=limit
|
||||
metadata_filters={"kb_doc_id": doc_id}, offset=offset, limit=limit
|
||||
)
|
||||
result = []
|
||||
for chunk in chunks:
|
||||
@@ -305,7 +305,7 @@ class KBHelper:
|
||||
result.append(
|
||||
{
|
||||
"chunk_id": chunk["doc_id"],
|
||||
"doc_id": chunk_md["doc_id"],
|
||||
"doc_id": chunk_md["kb_doc_id"],
|
||||
"kb_id": chunk_md["kb_id"],
|
||||
"chunk_index": chunk_md["chunk_index"],
|
||||
"content": chunk["text"],
|
||||
@@ -317,7 +317,7 @@ class KBHelper:
|
||||
async def get_chunk_count_by_doc_id(self, doc_id: str) -> int:
|
||||
"""获取文档的块数量"""
|
||||
vec_db: FaissVecDB = self.vec_db # type: ignore
|
||||
count = await vec_db.count_documents(metadata_filter={"doc_id": doc_id})
|
||||
count = await vec_db.count_documents(metadata_filter={"kb_doc_id": doc_id})
|
||||
return count
|
||||
|
||||
async def _save_media(
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import Field, SQLModel, Text, UniqueConstraint
|
||||
from sqlmodel import Field, SQLModel, Text, UniqueConstraint, MetaData
|
||||
|
||||
|
||||
class BaseKBModel(SQLModel, table=False):
|
||||
pass
|
||||
metadata = MetaData()
|
||||
|
||||
|
||||
class KnowledgeBase(BaseKBModel, table=True):
|
||||
@@ -28,17 +27,17 @@ class KnowledgeBase(BaseKBModel, table=True):
|
||||
index=True,
|
||||
)
|
||||
kb_name: str = Field(max_length=100, nullable=False)
|
||||
description: Optional[str] = Field(default=None, sa_type=Text)
|
||||
emoji: Optional[str] = Field(default="📚", max_length=10)
|
||||
embedding_provider_id: Optional[str] = Field(default=None, max_length=100)
|
||||
rerank_provider_id: Optional[str] = Field(default=None, max_length=100)
|
||||
description: str | None = Field(default=None, sa_type=Text)
|
||||
emoji: str | None = Field(default="📚", max_length=10)
|
||||
embedding_provider_id: str | None = Field(default=None, max_length=100)
|
||||
rerank_provider_id: str | None = Field(default=None, max_length=100)
|
||||
# 分块配置参数
|
||||
chunk_size: Optional[int] = Field(default=512, nullable=True)
|
||||
chunk_overlap: Optional[int] = Field(default=50, nullable=True)
|
||||
chunk_size: int | None = Field(default=512, nullable=True)
|
||||
chunk_overlap: int | None = Field(default=50, nullable=True)
|
||||
# 检索配置参数
|
||||
top_k_dense: Optional[int] = Field(default=50, nullable=True)
|
||||
top_k_sparse: Optional[int] = Field(default=50, nullable=True)
|
||||
top_m_final: Optional[int] = Field(default=5, nullable=True)
|
||||
top_k_dense: int | None = Field(default=50, nullable=True)
|
||||
top_k_sparse: int | None = Field(default=50, nullable=True)
|
||||
top_m_final: int | None = Field(default=5, nullable=True)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
from astrbot.core.db.vec_db.base import Result
|
||||
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
||||
@@ -44,10 +43,10 @@ class RankFusion:
|
||||
|
||||
async def fuse(
|
||||
self,
|
||||
dense_results: List[Result],
|
||||
sparse_results: List[SparseResult],
|
||||
dense_results: list[Result],
|
||||
sparse_results: list[SparseResult],
|
||||
top_k: int = 20,
|
||||
) -> List[FusedResult]:
|
||||
) -> list[FusedResult]:
|
||||
"""融合稠密和稀疏检索结果
|
||||
|
||||
RRF 公式:
|
||||
@@ -85,7 +84,7 @@ class RankFusion:
|
||||
vec_doc_id_to_dense[vec_doc_id] = r
|
||||
|
||||
# 3. 计算 RRF 分数
|
||||
rrf_scores: Dict[str, float] = {}
|
||||
rrf_scores: dict[str, float] = {}
|
||||
|
||||
for identifier in all_chunk_ids:
|
||||
score = 0.0
|
||||
@@ -129,7 +128,7 @@ class RankFusion:
|
||||
FusedResult(
|
||||
chunk_id=identifier,
|
||||
chunk_index=chunk_md["chunk_index"],
|
||||
doc_id=chunk_md["doc_id"],
|
||||
doc_id=chunk_md["kb_doc_id"],
|
||||
kb_id=chunk_md["kb_id"],
|
||||
content=vec_result.data["text"],
|
||||
score=rrf_scores[identifier],
|
||||
|
||||
@@ -7,7 +7,6 @@ import jieba
|
||||
import os
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from rank_bm25 import BM25Okapi
|
||||
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||
@@ -44,7 +43,6 @@ class SparseRetriever:
|
||||
|
||||
with open(
|
||||
os.path.join(os.path.dirname(__file__), "hit_stopwords.txt"),
|
||||
"r",
|
||||
encoding="utf-8",
|
||||
) as f:
|
||||
self.hit_stopwords = {
|
||||
@@ -54,9 +52,9 @@ class SparseRetriever:
|
||||
async def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
kb_ids: List[str],
|
||||
kb_ids: list[str],
|
||||
kb_options: dict,
|
||||
) -> List[SparseResult]:
|
||||
) -> list[SparseResult]:
|
||||
"""执行稀疏检索
|
||||
|
||||
Args:
|
||||
@@ -82,7 +80,7 @@ class SparseRetriever:
|
||||
{
|
||||
"chunk_id": doc["doc_id"],
|
||||
"chunk_index": chunk_md["chunk_index"],
|
||||
"doc_id": chunk_md["doc_id"],
|
||||
"doc_id": chunk_md["kb_doc_id"],
|
||||
"kb_id": kb_id,
|
||||
"text": doc["text"],
|
||||
}
|
||||
|
||||
@@ -94,8 +94,7 @@
|
||||
|
||||
<v-select v-model="formData.embedding_provider_id" :items="embeddingProviders"
|
||||
:item-title="item => item.embedding_model || item.id" :item-value="'id'"
|
||||
:label="t('create.embeddingModelLabel')" variant="outlined" class="mb-4" :disabled="true"
|
||||
@update:model-value="handleEmbeddingProviderChange">
|
||||
:label="t('create.embeddingModelLabel')" variant="outlined" class="mb-4" :disabled="editingKB !== null">
|
||||
<template #item="{ props, item }">
|
||||
<v-list-item v-bind="props">
|
||||
<template #subtitle>
|
||||
@@ -327,19 +326,6 @@ const editKB = (kb: any) => {
|
||||
showCreateDialog.value = true
|
||||
}
|
||||
|
||||
// 处理 embedding provider 变更
|
||||
const handleEmbeddingProviderChange = (newValue: string | null) => {
|
||||
// 检测是否修改了embedding provider
|
||||
if (newValue && originalEmbeddingProvider.value && newValue !== originalEmbeddingProvider.value) {
|
||||
// 显示二次确认对话框
|
||||
showEmbeddingWarning.value = true
|
||||
pendingEmbeddingProvider.value = newValue
|
||||
embeddingChangeDialog.value = true
|
||||
} else {
|
||||
showEmbeddingWarning.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 确认删除
|
||||
const confirmDelete = (kb: any) => {
|
||||
deleteTarget.value = kb
|
||||
|
||||
Reference in New Issue
Block a user