diff --git a/astrbot/core/db/vec_db/faiss_impl/document_storage.py b/astrbot/core/db/vec_db/faiss_impl/document_storage.py index 17c0cb3ae..1feeb9b92 100644 --- a/astrbot/core/db/vec_db/faiss_impl/document_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/document_storage.py @@ -1,32 +1,99 @@ -import aiosqlite import os +import json +from datetime import datetime +from contextlib import asynccontextmanager + +from sqlalchemy import Text, Column +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker +from sqlmodel import Field, SQLModel, select, col, func, text, MetaData + + +class BaseDocModel(SQLModel, table=False): + metadata = MetaData() + + +class Document(BaseDocModel, table=True): + """SQLModel for documents table.""" + + __tablename__ = "documents" # type: ignore + + id: int | None = Field( + default=None, primary_key=True, sa_column_kwargs={"autoincrement": True} + ) + doc_id: str = Field(nullable=False) + text: str = Field(nullable=False) + metadata_: str | None = Field(default=None, sa_column=Column("metadata", Text)) + created_at: datetime | None = Field(default=None) + updated_at: datetime | None = Field(default=None) class DocumentStorage: def __init__(self, db_path: str): self.db_path = db_path - self.connection = None + self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}" + self.engine: AsyncEngine | None = None + self.async_session_maker: sessionmaker | None = None self.sqlite_init_path = os.path.join( os.path.dirname(__file__), "sqlite_init.sql" ) async def initialize(self): """Initialize the SQLite database and create the documents table if it doesn't exist.""" - if not os.path.exists(self.db_path): - await self.connect() - if not self.connection: - raise RuntimeError("Failed to connect to the database.") - async with self.connection.cursor() as cursor: - with open(self.sqlite_init_path, "r", encoding="utf-8") as f: - sql_script = f.read() - await cursor.executescript(sql_script) - await self.connection.commit() - else: - await self.connect() + await self.connect() + async with self.engine.begin() as conn: # type: ignore + # Create tables using SQLModel + await conn.run_sync(BaseDocModel.metadata.create_all) + + try: + await conn.execute( + text( + "ALTER TABLE documents ADD COLUMN kb_doc_id TEXT " + "GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED" + ) + ) + await conn.execute( + text( + "ALTER TABLE documents ADD COLUMN user_id TEXT " + "GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED" + ) + ) + + # Create indexes + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)" + ) + ) + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)" + ) + ) + except BaseException: + pass + + await conn.commit() async def connect(self): """Connect to the SQLite database.""" - self.connection = await aiosqlite.connect(self.db_path) + if self.engine is None: + self.engine = create_async_engine( + self.DATABASE_URL, + echo=False, + future=True, + ) + self.async_session_maker = sessionmaker( + self.engine, # type: ignore + class_=AsyncSession, + expire_on_commit=False, + ) # type: ignore + + @asynccontextmanager + async def get_session(self): + """Context manager for database sessions.""" + async with self.async_session_maker() as session: # type: ignore + yield session async def get_documents( self, @@ -39,37 +106,114 @@ class DocumentStorage: Args: metadata_filters (dict): The metadata filters to apply. + ids (list | None): Optional list of document IDs to filter. + offset (int | None): Offset for pagination. + limit (int | None): Limit for pagination. Returns: - list: The list of document IDs(primary key, not doc_id) that match the filters. + list: The list of documents that match the filters. """ - assert self.connection is not None, "Database connection is not initialized." - # metadata filter -> SQL WHERE clause - where_clauses = [] - values = [] - for key, val in metadata_filters.items(): - where_clauses.append(f"json_extract(metadata, '$.{key}') = ?") - values.append(val) - if ids is not None and len(ids) > 0: - ids = [str(i) for i in ids if i != -1] - where_clauses.append("id IN ({})".format(",".join("?" * len(ids)))) - values.extend(ids) - where_sql = " AND ".join(where_clauses) or "1=1" + assert self.engine is not None, "Database connection is not initialized." + + async with self.get_session() as session: + query = select(Document) + + for key, val in metadata_filters.items(): + query = query.where( + text(f"json_extract(metadata, '$.{key}') = :filter_{key}") + ).params(**{f"filter_{key}": val}) + + if ids is not None and len(ids) > 0: + valid_ids = [int(i) for i in ids if i != -1] + if valid_ids: + query = query.where(col(Document.id).in_(valid_ids)) - result = [] - async with self.connection.cursor() as cursor: - sql = f"SELECT * FROM documents WHERE {where_sql}" if limit is not None: - sql += " LIMIT ?" - values.append(limit) + query = query.limit(limit) if offset is not None: - sql += " OFFSET ?" - values.append(offset) + query = query.offset(offset) - await cursor.execute(sql, values) - for row in await cursor.fetchall(): - result.append(await self.tuple_to_dict(row)) - return result + result = await session.execute(query) + documents = result.scalars().all() + + return [self._document_to_dict(doc) for doc in documents] + + async def insert_document(self, doc_id: str, text: str, metadata: dict) -> int: + """Insert a single document and return its integer ID. + + Args: + doc_id (str): The document ID (UUID string). + text (str): The document text. + metadata (dict): The document metadata. + + Returns: + int: The integer ID of the inserted document. + """ + assert self.engine is not None, "Database connection is not initialized." + + async with self.get_session() as session: + async with session.begin(): + document = Document( + doc_id=doc_id, + text=text, + metadata_=json.dumps(metadata), + created_at=datetime.now(), + updated_at=datetime.now(), + ) + session.add(document) + await session.flush() # Flush to get the ID + return document.id # type: ignore + + async def insert_documents_batch( + self, doc_ids: list[str], texts: list[str], metadatas: list[dict] + ) -> list[int]: + """Batch insert documents and return their integer IDs. + + Args: + doc_ids (list[str]): List of document IDs (UUID strings). + texts (list[str]): List of document texts. + metadatas (list[dict]): List of document metadata. + + Returns: + list[int]: List of integer IDs of the inserted documents. + """ + assert self.engine is not None, "Database connection is not initialized." + + async with self.get_session() as session: + async with session.begin(): + import json + + documents = [] + for doc_id, text, metadata in zip(doc_ids, texts, metadatas): + document = Document( + doc_id=doc_id, + text=text, + metadata_=json.dumps(metadata), + created_at=datetime.now(), + updated_at=datetime.now(), + ) + documents.append(document) + session.add(document) + + await session.flush() # Flush to get all IDs + return [doc.id for doc in documents] # type: ignore + + async def delete_document_by_doc_id(self, doc_id: str): + """Delete a document by its doc_id. + + Args: + doc_id (str): The doc_id of the document to delete. + """ + assert self.engine is not None, "Database connection is not initialized." + + async with self.get_session() as session: + async with session.begin(): + query = select(Document).where(col(Document.doc_id) == doc_id) + result = await session.execute(query) + document = result.scalar_one_or_none() + + if document: + await session.delete(document) async def get_document_by_doc_id(self, doc_id: str): """Retrieve a document by its doc_id. @@ -78,30 +222,38 @@ class DocumentStorage: doc_id (str): The doc_id of the document to retrieve. Returns: - dict: The document data. + dict: The document data or None if not found. """ - assert self.connection is not None, "Database connection is not initialized." - async with self.connection.cursor() as cursor: - await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,)) - row = await cursor.fetchone() - if row: - return await self.tuple_to_dict(row) - else: - return None + assert self.engine is not None, "Database connection is not initialized." + + async with self.get_session() as session: + query = select(Document).where(col(Document.doc_id) == doc_id) + result = await session.execute(query) + document = result.scalar_one_or_none() + + if document: + return self._document_to_dict(document) + return None async def update_document_by_doc_id(self, doc_id: str, new_text: str): - """Retrieve a document by its doc_id. + """Update a document by its doc_id. Args: doc_id (str): The doc_id. new_text (str): The new text to update the document with. """ - assert self.connection is not None, "Database connection is not initialized." - async with self.connection.cursor() as cursor: - await cursor.execute( - "UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id) - ) - await self.connection.commit() + assert self.engine is not None, "Database connection is not initialized." + + async with self.get_session() as session: + async with session.begin(): + query = select(Document).where(col(Document.doc_id) == doc_id) + result = await session.execute(query) + document = result.scalar_one_or_none() + + if document: + document.text = new_text + document.updated_at = datetime.now() + session.add(document) async def delete_documents(self, metadata_filters: dict): """Delete documents by their metadata filters. @@ -109,16 +261,22 @@ class DocumentStorage: Args: metadata_filters (dict): The metadata filters to apply. """ - assert self.connection is not None, "Database connection is not initialized." - async with self.connection.cursor() as cursor: - where_clauses = [] - values = [] - for key, val in metadata_filters.items(): - where_clauses.append(f"json_extract(metadata, '$.{key}') = ?") - values.append(val) - where_sql = " AND ".join(where_clauses) or "1=1" - await cursor.execute(f"DELETE FROM documents WHERE {where_sql}", values) - await self.connection.commit() + assert self.engine is not None, "Database connection is not initialized." + + async with self.get_session() as session: + async with session.begin(): + query = select(Document) + + for key, val in metadata_filters.items(): + query = query.where( + text(f"json_extract(metadata, '$.{key}') = :filter_{key}") + ).params(**{f"filter_{key}": val}) + + result = await session.execute(query) + documents = result.scalars().all() + + for doc in documents: + await session.delete(doc) async def count_documents(self, metadata_filters: dict | None = None) -> int: """Count documents in the database. @@ -129,20 +287,20 @@ class DocumentStorage: Returns: int: The count of documents. """ - assert self.connection is not None, "Database connection is not initialized." - async with self.connection.cursor() as cursor: - sql = "SELECT COUNT(*) FROM documents" - values = [] + assert self.engine is not None, "Database connection is not initialized." + + async with self.get_session() as session: + query = select(func.count(col(Document.id))) + if metadata_filters: - where_clauses = [] for key, val in metadata_filters.items(): - where_clauses.append(f"json_extract(metadata, '$.{key}') = ?") - values.append(val) - where_sql = " AND ".join(where_clauses) - sql += f" WHERE {where_sql}" - await cursor.execute(sql, values) - count = await cursor.fetchone() - return count[0] if count else 0 + query = query.where( + text(f"json_extract(metadata, '$.{key}') = :filter_{key}") + ).params(**{f"filter_{key}": val}) + + result = await session.execute(query) + count = result.scalar_one_or_none() + return count if count is not None else 0 async def get_user_ids(self) -> list[str]: """Retrieve all user IDs from the documents table. @@ -150,12 +308,38 @@ class DocumentStorage: Returns: list: A list of user IDs. """ - assert self.connection is not None, "Database connection is not initialized." - async with self.connection.cursor() as cursor: - await cursor.execute("SELECT DISTINCT user_id FROM documents") - rows = await cursor.fetchall() + assert self.engine is not None, "Database connection is not initialized." + + async with self.get_session() as session: + query = text( + "SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL" + ) + result = await session.execute(query) + rows = result.fetchall() return [row[0] for row in rows] + def _document_to_dict(self, document: Document) -> dict: + """Convert a Document model to a dictionary. + + Args: + document (Document): The document to convert. + + Returns: + dict: The converted dictionary. + """ + return { + "id": document.id, + "doc_id": document.doc_id, + "text": document.text, + "metadata": document.metadata_, + "created_at": document.created_at.isoformat() + if isinstance(document.created_at, datetime) + else document.created_at, + "updated_at": document.updated_at.isoformat() + if isinstance(document.updated_at, datetime) + else document.updated_at, + } + async def tuple_to_dict(self, row): """Convert a tuple to a dictionary. @@ -164,6 +348,8 @@ class DocumentStorage: Returns: dict: The converted dictionary. + + Note: This method is kept for backward compatibility but is no longer used internally. """ return { "id": row[0], @@ -176,6 +362,7 @@ class DocumentStorage: async def close(self): """Close the connection to the SQLite database.""" - if self.connection: - await self.connection.close() - self.connection = None + if self.engine: + await self.engine.dispose() + self.engine = None + self.async_session_maker = None diff --git a/astrbot/core/db/vec_db/faiss_impl/vec_db.py b/astrbot/core/db/vec_db/faiss_impl/vec_db.py index 587de2496..8a21538ec 100644 --- a/astrbot/core/db/vec_db/faiss_impl/vec_db.py +++ b/astrbot/core/db/vec_db/faiss_impl/vec_db.py @@ -1,5 +1,4 @@ import uuid -import json import time import numpy as np from .document_storage import DocumentStorage @@ -41,26 +40,18 @@ class FaissVecDB(BaseVecDB): """ 插入一条文本和其对应向量,自动生成 ID 并保持一致性。 """ - assert self.document_storage.connection is not None, ( - "Database connection is not initialized." - ) metadata = metadata or {} str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID vector = await self.embedding_provider.get_embedding(content) vector = np.array(vector, dtype=np.float32) - async with self.document_storage.connection.cursor() as cursor: - await cursor.execute( - "INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)", - (str_id, content, json.dumps(metadata)), - ) - await self.document_storage.connection.commit() - result = await self.document_storage.get_document_by_doc_id(str_id) - int_id = result["id"] - # 插入向量到 FAISS - await self.embedding_storage.insert(vector, int_id) - return int_id + # 使用 DocumentStorage 的方法插入文档 + int_id = await self.document_storage.insert_document(str_id, content, metadata) + + # 插入向量到 FAISS + await self.embedding_storage.insert(vector, int_id) + return int_id async def insert_batch( self, @@ -78,9 +69,6 @@ class FaissVecDB(BaseVecDB): Args: progress_callback: 进度回调函数,接收参数 (current, total) """ - assert self.document_storage.connection is not None, ( - "Database connection is not initialized." - ) metadatas = metadatas or [{} for _ in contents] ids = ids or [str(uuid.uuid4()) for _ in contents] @@ -98,23 +86,15 @@ class FaissVecDB(BaseVecDB): f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds." ) - int_ids = [] - async with self.document_storage.connection.cursor() as cursor: - for str_id, content, metadata in zip(ids, contents, metadatas): - await cursor.execute( - "INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)", - (str_id, content, json.dumps(metadata)), - ) - await self.document_storage.connection.commit() + # 使用 DocumentStorage 的批量插入方法 + int_ids = await self.document_storage.insert_documents_batch( + ids, contents, metadatas + ) - for str_id in ids: - result = await self.document_storage.get_document_by_doc_id(str_id) - int_ids.append(result["id"]) - - # 批量插入向量到 FAISS - vectors_array = np.array(vectors).astype("float32") - await self.embedding_storage.insert_batch(vectors_array, int_ids) - return int_ids + # 批量插入向量到 FAISS + vectors_array = np.array(vectors).astype("float32") + await self.embedding_storage.insert_batch(vectors_array, int_ids) + return int_ids async def retrieve( self, @@ -182,19 +162,15 @@ class FaissVecDB(BaseVecDB): """ 删除一条文档块(chunk) """ - assert self.document_storage.connection is not None, ( - "Database connection is not initialized." - ) # 获得对应的 int id result = await self.document_storage.get_document_by_doc_id(doc_id) int_id = result["id"] if result else None if int_id is None: return - await self.document_storage.connection.execute( - "DELETE FROM documents WHERE doc_id = ?", (doc_id,) - ) + + # 使用 DocumentStorage 的删除方法 + await self.document_storage.delete_document_by_doc_id(doc_id) await self.embedding_storage.delete([int_id]) - await self.document_storage.connection.commit() async def close(self): await self.document_storage.close() @@ -206,9 +182,6 @@ class FaissVecDB(BaseVecDB): Args: metadata_filter (dict | None): 元数据过滤器 """ - assert self.document_storage.connection is not None, ( - "Database connection is not initialized." - ) count = await self.document_storage.count_documents( metadata_filters=metadata_filter or {} ) diff --git a/astrbot/core/knowledge_base/kb_db_sqlite.py b/astrbot/core/knowledge_base/kb_db_sqlite.py index 2a67d6514..827d621d3 100644 --- a/astrbot/core/knowledge_base/kb_db_sqlite.py +++ b/astrbot/core/knowledge_base/kb_db_sqlite.py @@ -14,8 +14,6 @@ from astrbot.core.knowledge_base.models import ( ) from astrbot.core.db.vec_db.faiss_impl import FaissVecDB -from typing import Optional - class KBSQLiteDatabase: def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None: @@ -167,14 +165,14 @@ class KBSQLiteDatabase: await self.engine.dispose() logger.info(f"知识库数据库已关闭: {self.db_path}") - async def get_kb_by_id(self, kb_id: str) -> Optional[KnowledgeBase]: + async def get_kb_by_id(self, kb_id: str) -> KnowledgeBase | None: """根据 ID 获取知识库""" async with self.get_db() as session: stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_id) == kb_id) result = await session.execute(stmt) return result.scalar_one_or_none() - async def get_kb_by_name(self, kb_name: str) -> Optional[KnowledgeBase]: + async def get_kb_by_name(self, kb_name: str) -> KnowledgeBase | None: """根据名称获取知识库""" async with self.get_db() as session: stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_name) == kb_name) @@ -202,7 +200,7 @@ class KBSQLiteDatabase: # ===== 文档查询 ===== - async def get_document_by_id(self, doc_id: str) -> Optional[KBDocument]: + async def get_document_by_id(self, doc_id: str) -> KBDocument | None: """根据 ID 获取文档""" async with self.get_db() as session: stmt = select(KBDocument).where(col(KBDocument.doc_id) == doc_id) @@ -233,7 +231,7 @@ class KBSQLiteDatabase: result = await session.execute(stmt) return result.scalar() or 0 - async def get_document_with_metadata(self, doc_id: str) -> Optional[dict]: + async def get_document_with_metadata(self, doc_id: str) -> dict | None: async with self.get_db() as session: stmt = ( select(KBDocument, KnowledgeBase) @@ -262,7 +260,7 @@ class KBSQLiteDatabase: await session.commit() # 在 vec db 中删除相关向量 - await vec_db.delete_documents(metadata_filters={"doc_id": doc_id}) + await vec_db.delete_documents(metadata_filters={"kb_doc_id": doc_id}) # ===== 多媒体查询 ===== @@ -273,7 +271,7 @@ class KBSQLiteDatabase: result = await session.execute(stmt) return list(result.scalars().all()) - async def get_media_by_id(self, media_id: str) -> Optional[KBMedia]: + async def get_media_by_id(self, media_id: str) -> KBMedia | None: """根据 ID 获取多媒体资源""" async with self.get_db() as session: stmt = select(KBMedia).where(col(KBMedia.media_id) == media_id) diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index ebad6d14a..dc9febea9 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -175,7 +175,7 @@ class KBHelper: metadatas.append( { "kb_id": self.kb.kb_id, - "doc_id": doc_id, + "kb_doc_id": doc_id, "chunk_index": idx, } ) @@ -297,7 +297,7 @@ class KBHelper: """获取文档的所有块及其元数据""" vec_db: FaissVecDB = self.vec_db # type: ignore chunks = await vec_db.document_storage.get_documents( - metadata_filters={"doc_id": doc_id}, offset=offset, limit=limit + metadata_filters={"kb_doc_id": doc_id}, offset=offset, limit=limit ) result = [] for chunk in chunks: @@ -305,7 +305,7 @@ class KBHelper: result.append( { "chunk_id": chunk["doc_id"], - "doc_id": chunk_md["doc_id"], + "doc_id": chunk_md["kb_doc_id"], "kb_id": chunk_md["kb_id"], "chunk_index": chunk_md["chunk_index"], "content": chunk["text"], @@ -317,7 +317,7 @@ class KBHelper: async def get_chunk_count_by_doc_id(self, doc_id: str) -> int: """获取文档的块数量""" vec_db: FaissVecDB = self.vec_db # type: ignore - count = await vec_db.count_documents(metadata_filter={"doc_id": doc_id}) + count = await vec_db.count_documents(metadata_filter={"kb_doc_id": doc_id}) return count async def _save_media( diff --git a/astrbot/core/knowledge_base/models.py b/astrbot/core/knowledge_base/models.py index d46cdc9b3..010d6113c 100644 --- a/astrbot/core/knowledge_base/models.py +++ b/astrbot/core/knowledge_base/models.py @@ -1,12 +1,11 @@ import uuid from datetime import datetime, timezone -from typing import Optional -from sqlmodel import Field, SQLModel, Text, UniqueConstraint +from sqlmodel import Field, SQLModel, Text, UniqueConstraint, MetaData class BaseKBModel(SQLModel, table=False): - pass + metadata = MetaData() class KnowledgeBase(BaseKBModel, table=True): @@ -28,17 +27,17 @@ class KnowledgeBase(BaseKBModel, table=True): index=True, ) kb_name: str = Field(max_length=100, nullable=False) - description: Optional[str] = Field(default=None, sa_type=Text) - emoji: Optional[str] = Field(default="📚", max_length=10) - embedding_provider_id: Optional[str] = Field(default=None, max_length=100) - rerank_provider_id: Optional[str] = Field(default=None, max_length=100) + description: str | None = Field(default=None, sa_type=Text) + emoji: str | None = Field(default="📚", max_length=10) + embedding_provider_id: str | None = Field(default=None, max_length=100) + rerank_provider_id: str | None = Field(default=None, max_length=100) # 分块配置参数 - chunk_size: Optional[int] = Field(default=512, nullable=True) - chunk_overlap: Optional[int] = Field(default=50, nullable=True) + chunk_size: int | None = Field(default=512, nullable=True) + chunk_overlap: int | None = Field(default=50, nullable=True) # 检索配置参数 - top_k_dense: Optional[int] = Field(default=50, nullable=True) - top_k_sparse: Optional[int] = Field(default=50, nullable=True) - top_m_final: Optional[int] = Field(default=5, nullable=True) + top_k_dense: int | None = Field(default=50, nullable=True) + top_k_sparse: int | None = Field(default=50, nullable=True) + top_m_final: int | None = Field(default=5, nullable=True) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field( default_factory=lambda: datetime.now(timezone.utc), diff --git a/astrbot/core/knowledge_base/retrieval/rank_fusion.py b/astrbot/core/knowledge_base/retrieval/rank_fusion.py index d148bc2cc..3ceba4ff8 100644 --- a/astrbot/core/knowledge_base/retrieval/rank_fusion.py +++ b/astrbot/core/knowledge_base/retrieval/rank_fusion.py @@ -5,7 +5,6 @@ import json from dataclasses import dataclass -from typing import Dict, List from astrbot.core.db.vec_db.base import Result from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase @@ -44,10 +43,10 @@ class RankFusion: async def fuse( self, - dense_results: List[Result], - sparse_results: List[SparseResult], + dense_results: list[Result], + sparse_results: list[SparseResult], top_k: int = 20, - ) -> List[FusedResult]: + ) -> list[FusedResult]: """融合稠密和稀疏检索结果 RRF 公式: @@ -85,7 +84,7 @@ class RankFusion: vec_doc_id_to_dense[vec_doc_id] = r # 3. 计算 RRF 分数 - rrf_scores: Dict[str, float] = {} + rrf_scores: dict[str, float] = {} for identifier in all_chunk_ids: score = 0.0 @@ -129,7 +128,7 @@ class RankFusion: FusedResult( chunk_id=identifier, chunk_index=chunk_md["chunk_index"], - doc_id=chunk_md["doc_id"], + doc_id=chunk_md["kb_doc_id"], kb_id=chunk_md["kb_id"], content=vec_result.data["text"], score=rrf_scores[identifier], diff --git a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py index d9ff915d3..315930b3e 100644 --- a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py +++ b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py @@ -7,7 +7,6 @@ import jieba import os import json from dataclasses import dataclass -from typing import List from rank_bm25 import BM25Okapi from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase from astrbot.core.db.vec_db.faiss_impl import FaissVecDB @@ -44,7 +43,6 @@ class SparseRetriever: with open( os.path.join(os.path.dirname(__file__), "hit_stopwords.txt"), - "r", encoding="utf-8", ) as f: self.hit_stopwords = { @@ -54,9 +52,9 @@ class SparseRetriever: async def retrieve( self, query: str, - kb_ids: List[str], + kb_ids: list[str], kb_options: dict, - ) -> List[SparseResult]: + ) -> list[SparseResult]: """执行稀疏检索 Args: @@ -82,7 +80,7 @@ class SparseRetriever: { "chunk_id": doc["doc_id"], "chunk_index": chunk_md["chunk_index"], - "doc_id": chunk_md["doc_id"], + "doc_id": chunk_md["kb_doc_id"], "kb_id": kb_id, "text": doc["text"], } diff --git a/dashboard/src/views/knowledge-base/KBList.vue b/dashboard/src/views/knowledge-base/KBList.vue index 3a69bd44e..33516a215 100644 --- a/dashboard/src/views/knowledge-base/KBList.vue +++ b/dashboard/src/views/knowledge-base/KBList.vue @@ -94,8 +94,7 @@ + :label="t('create.embeddingModelLabel')" variant="outlined" class="mb-4" :disabled="editingKB !== null">