feat: refactor document storage to use SQLModel and enhance database operations

This commit is contained in:
Soulter
2025-10-24 23:17:37 +08:00
parent 1969abc340
commit 8f021eb35a
8 changed files with 317 additions and 177 deletions
@@ -1,32 +1,99 @@
import aiosqlite
import os
import json
from datetime import datetime
from contextlib import asynccontextmanager
from sqlalchemy import Text, Column
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import Field, SQLModel, select, col, func, text, MetaData
class BaseDocModel(SQLModel, table=False):
metadata = MetaData()
class Document(BaseDocModel, table=True):
"""SQLModel for documents table."""
__tablename__ = "documents" # type: ignore
id: int | None = Field(
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
)
doc_id: str = Field(nullable=False)
text: str = Field(nullable=False)
metadata_: str | None = Field(default=None, sa_column=Column("metadata", Text))
created_at: datetime | None = Field(default=None)
updated_at: datetime | None = Field(default=None)
class DocumentStorage:
def __init__(self, db_path: str):
self.db_path = db_path
self.connection = None
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
self.engine: AsyncEngine | None = None
self.async_session_maker: sessionmaker | None = None
self.sqlite_init_path = os.path.join(
os.path.dirname(__file__), "sqlite_init.sql"
)
async def initialize(self):
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
if not os.path.exists(self.db_path):
await self.connect()
if not self.connection:
raise RuntimeError("Failed to connect to the database.")
async with self.connection.cursor() as cursor:
with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
sql_script = f.read()
await cursor.executescript(sql_script)
await self.connection.commit()
else:
await self.connect()
await self.connect()
async with self.engine.begin() as conn: # type: ignore
# Create tables using SQLModel
await conn.run_sync(BaseDocModel.metadata.create_all)
try:
await conn.execute(
text(
"ALTER TABLE documents ADD COLUMN kb_doc_id TEXT "
"GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED"
)
)
await conn.execute(
text(
"ALTER TABLE documents ADD COLUMN user_id TEXT "
"GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED"
)
)
# Create indexes
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)"
)
)
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)"
)
)
except BaseException:
pass
await conn.commit()
async def connect(self):
"""Connect to the SQLite database."""
self.connection = await aiosqlite.connect(self.db_path)
if self.engine is None:
self.engine = create_async_engine(
self.DATABASE_URL,
echo=False,
future=True,
)
self.async_session_maker = sessionmaker(
self.engine, # type: ignore
class_=AsyncSession,
expire_on_commit=False,
) # type: ignore
@asynccontextmanager
async def get_session(self):
"""Context manager for database sessions."""
async with self.async_session_maker() as session: # type: ignore
yield session
async def get_documents(
self,
@@ -39,37 +106,114 @@ class DocumentStorage:
Args:
metadata_filters (dict): The metadata filters to apply.
ids (list | None): Optional list of document IDs to filter.
offset (int | None): Offset for pagination.
limit (int | None): Limit for pagination.
Returns:
list: The list of document IDs(primary key, not doc_id) that match the filters.
list: The list of documents that match the filters.
"""
assert self.connection is not None, "Database connection is not initialized."
# metadata filter -> SQL WHERE clause
where_clauses = []
values = []
for key, val in metadata_filters.items():
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
values.append(val)
if ids is not None and len(ids) > 0:
ids = [str(i) for i in ids if i != -1]
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
values.extend(ids)
where_sql = " AND ".join(where_clauses) or "1=1"
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
query = select(Document)
for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
).params(**{f"filter_{key}": val})
if ids is not None and len(ids) > 0:
valid_ids = [int(i) for i in ids if i != -1]
if valid_ids:
query = query.where(col(Document.id).in_(valid_ids))
result = []
async with self.connection.cursor() as cursor:
sql = f"SELECT * FROM documents WHERE {where_sql}"
if limit is not None:
sql += " LIMIT ?"
values.append(limit)
query = query.limit(limit)
if offset is not None:
sql += " OFFSET ?"
values.append(offset)
query = query.offset(offset)
await cursor.execute(sql, values)
for row in await cursor.fetchall():
result.append(await self.tuple_to_dict(row))
return result
result = await session.execute(query)
documents = result.scalars().all()
return [self._document_to_dict(doc) for doc in documents]
async def insert_document(self, doc_id: str, text: str, metadata: dict) -> int:
"""Insert a single document and return its integer ID.
Args:
doc_id (str): The document ID (UUID string).
text (str): The document text.
metadata (dict): The document metadata.
Returns:
int: The integer ID of the inserted document.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
async with session.begin():
document = Document(
doc_id=doc_id,
text=text,
metadata_=json.dumps(metadata),
created_at=datetime.now(),
updated_at=datetime.now(),
)
session.add(document)
await session.flush() # Flush to get the ID
return document.id # type: ignore
async def insert_documents_batch(
self, doc_ids: list[str], texts: list[str], metadatas: list[dict]
) -> list[int]:
"""Batch insert documents and return their integer IDs.
Args:
doc_ids (list[str]): List of document IDs (UUID strings).
texts (list[str]): List of document texts.
metadatas (list[dict]): List of document metadata.
Returns:
list[int]: List of integer IDs of the inserted documents.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
async with session.begin():
import json
documents = []
for doc_id, text, metadata in zip(doc_ids, texts, metadatas):
document = Document(
doc_id=doc_id,
text=text,
metadata_=json.dumps(metadata),
created_at=datetime.now(),
updated_at=datetime.now(),
)
documents.append(document)
session.add(document)
await session.flush() # Flush to get all IDs
return [doc.id for doc in documents] # type: ignore
async def delete_document_by_doc_id(self, doc_id: str):
"""Delete a document by its doc_id.
Args:
doc_id (str): The doc_id of the document to delete.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
async with session.begin():
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
if document:
await session.delete(document)
async def get_document_by_doc_id(self, doc_id: str):
"""Retrieve a document by its doc_id.
@@ -78,30 +222,38 @@ class DocumentStorage:
doc_id (str): The doc_id of the document to retrieve.
Returns:
dict: The document data.
dict: The document data or None if not found.
"""
assert self.connection is not None, "Database connection is not initialized."
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
row = await cursor.fetchone()
if row:
return await self.tuple_to_dict(row)
else:
return None
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
if document:
return self._document_to_dict(document)
return None
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
"""Retrieve a document by its doc_id.
"""Update a document by its doc_id.
Args:
doc_id (str): The doc_id.
new_text (str): The new text to update the document with.
"""
assert self.connection is not None, "Database connection is not initialized."
async with self.connection.cursor() as cursor:
await cursor.execute(
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
)
await self.connection.commit()
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
async with session.begin():
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
if document:
document.text = new_text
document.updated_at = datetime.now()
session.add(document)
async def delete_documents(self, metadata_filters: dict):
"""Delete documents by their metadata filters.
@@ -109,16 +261,22 @@ class DocumentStorage:
Args:
metadata_filters (dict): The metadata filters to apply.
"""
assert self.connection is not None, "Database connection is not initialized."
async with self.connection.cursor() as cursor:
where_clauses = []
values = []
for key, val in metadata_filters.items():
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
values.append(val)
where_sql = " AND ".join(where_clauses) or "1=1"
await cursor.execute(f"DELETE FROM documents WHERE {where_sql}", values)
await self.connection.commit()
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
async with session.begin():
query = select(Document)
for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
).params(**{f"filter_{key}": val})
result = await session.execute(query)
documents = result.scalars().all()
for doc in documents:
await session.delete(doc)
async def count_documents(self, metadata_filters: dict | None = None) -> int:
"""Count documents in the database.
@@ -129,20 +287,20 @@ class DocumentStorage:
Returns:
int: The count of documents.
"""
assert self.connection is not None, "Database connection is not initialized."
async with self.connection.cursor() as cursor:
sql = "SELECT COUNT(*) FROM documents"
values = []
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
query = select(func.count(col(Document.id)))
if metadata_filters:
where_clauses = []
for key, val in metadata_filters.items():
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
values.append(val)
where_sql = " AND ".join(where_clauses)
sql += f" WHERE {where_sql}"
await cursor.execute(sql, values)
count = await cursor.fetchone()
return count[0] if count else 0
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
).params(**{f"filter_{key}": val})
result = await session.execute(query)
count = result.scalar_one_or_none()
return count if count is not None else 0
async def get_user_ids(self) -> list[str]:
"""Retrieve all user IDs from the documents table.
@@ -150,12 +308,38 @@ class DocumentStorage:
Returns:
list: A list of user IDs.
"""
assert self.connection is not None, "Database connection is not initialized."
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT DISTINCT user_id FROM documents")
rows = await cursor.fetchall()
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
query = text(
"SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL"
)
result = await session.execute(query)
rows = result.fetchall()
return [row[0] for row in rows]
def _document_to_dict(self, document: Document) -> dict:
"""Convert a Document model to a dictionary.
Args:
document (Document): The document to convert.
Returns:
dict: The converted dictionary.
"""
return {
"id": document.id,
"doc_id": document.doc_id,
"text": document.text,
"metadata": document.metadata_,
"created_at": document.created_at.isoformat()
if isinstance(document.created_at, datetime)
else document.created_at,
"updated_at": document.updated_at.isoformat()
if isinstance(document.updated_at, datetime)
else document.updated_at,
}
async def tuple_to_dict(self, row):
"""Convert a tuple to a dictionary.
@@ -164,6 +348,8 @@ class DocumentStorage:
Returns:
dict: The converted dictionary.
Note: This method is kept for backward compatibility but is no longer used internally.
"""
return {
"id": row[0],
@@ -176,6 +362,7 @@ class DocumentStorage:
async def close(self):
"""Close the connection to the SQLite database."""
if self.connection:
await self.connection.close()
self.connection = None
if self.engine:
await self.engine.dispose()
self.engine = None
self.async_session_maker = None
+17 -44
View File
@@ -1,5 +1,4 @@
import uuid
import json
import time
import numpy as np
from .document_storage import DocumentStorage
@@ -41,26 +40,18 @@ class FaissVecDB(BaseVecDB):
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
assert self.document_storage.connection is not None, (
"Database connection is not initialized."
)
metadata = metadata or {}
str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID
vector = await self.embedding_provider.get_embedding(content)
vector = np.array(vector, dtype=np.float32)
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute(
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
(str_id, content, json.dumps(metadata)),
)
await self.document_storage.connection.commit()
result = await self.document_storage.get_document_by_doc_id(str_id)
int_id = result["id"]
# 插入向量到 FAISS
await self.embedding_storage.insert(vector, int_id)
return int_id
# 使用 DocumentStorage 的方法插入文档
int_id = await self.document_storage.insert_document(str_id, content, metadata)
# 插入向量到 FAISS
await self.embedding_storage.insert(vector, int_id)
return int_id
async def insert_batch(
self,
@@ -78,9 +69,6 @@ class FaissVecDB(BaseVecDB):
Args:
progress_callback: 进度回调函数,接收参数 (current, total)
"""
assert self.document_storage.connection is not None, (
"Database connection is not initialized."
)
metadatas = metadatas or [{} for _ in contents]
ids = ids or [str(uuid.uuid4()) for _ in contents]
@@ -98,23 +86,15 @@ class FaissVecDB(BaseVecDB):
f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds."
)
int_ids = []
async with self.document_storage.connection.cursor() as cursor:
for str_id, content, metadata in zip(ids, contents, metadatas):
await cursor.execute(
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
(str_id, content, json.dumps(metadata)),
)
await self.document_storage.connection.commit()
# 使用 DocumentStorage 的批量插入方法
int_ids = await self.document_storage.insert_documents_batch(
ids, contents, metadatas
)
for str_id in ids:
result = await self.document_storage.get_document_by_doc_id(str_id)
int_ids.append(result["id"])
# 批量插入向量到 FAISS
vectors_array = np.array(vectors).astype("float32")
await self.embedding_storage.insert_batch(vectors_array, int_ids)
return int_ids
# 批量插入向量到 FAISS
vectors_array = np.array(vectors).astype("float32")
await self.embedding_storage.insert_batch(vectors_array, int_ids)
return int_ids
async def retrieve(
self,
@@ -182,19 +162,15 @@ class FaissVecDB(BaseVecDB):
"""
删除一条文档块(chunk)
"""
assert self.document_storage.connection is not None, (
"Database connection is not initialized."
)
# 获得对应的 int id
result = await self.document_storage.get_document_by_doc_id(doc_id)
int_id = result["id"] if result else None
if int_id is None:
return
await self.document_storage.connection.execute(
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
)
# 使用 DocumentStorage 的删除方法
await self.document_storage.delete_document_by_doc_id(doc_id)
await self.embedding_storage.delete([int_id])
await self.document_storage.connection.commit()
async def close(self):
await self.document_storage.close()
@@ -206,9 +182,6 @@ class FaissVecDB(BaseVecDB):
Args:
metadata_filter (dict | None): 元数据过滤器
"""
assert self.document_storage.connection is not None, (
"Database connection is not initialized."
)
count = await self.document_storage.count_documents(
metadata_filters=metadata_filter or {}
)
+6 -8
View File
@@ -14,8 +14,6 @@ from astrbot.core.knowledge_base.models import (
)
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from typing import Optional
class KBSQLiteDatabase:
def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
@@ -167,14 +165,14 @@ class KBSQLiteDatabase:
await self.engine.dispose()
logger.info(f"知识库数据库已关闭: {self.db_path}")
async def get_kb_by_id(self, kb_id: str) -> Optional[KnowledgeBase]:
async def get_kb_by_id(self, kb_id: str) -> KnowledgeBase | None:
"""根据 ID 获取知识库"""
async with self.get_db() as session:
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_id) == kb_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def get_kb_by_name(self, kb_name: str) -> Optional[KnowledgeBase]:
async def get_kb_by_name(self, kb_name: str) -> KnowledgeBase | None:
"""根据名称获取知识库"""
async with self.get_db() as session:
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_name) == kb_name)
@@ -202,7 +200,7 @@ class KBSQLiteDatabase:
# ===== 文档查询 =====
async def get_document_by_id(self, doc_id: str) -> Optional[KBDocument]:
async def get_document_by_id(self, doc_id: str) -> KBDocument | None:
"""根据 ID 获取文档"""
async with self.get_db() as session:
stmt = select(KBDocument).where(col(KBDocument.doc_id) == doc_id)
@@ -233,7 +231,7 @@ class KBSQLiteDatabase:
result = await session.execute(stmt)
return result.scalar() or 0
async def get_document_with_metadata(self, doc_id: str) -> Optional[dict]:
async def get_document_with_metadata(self, doc_id: str) -> dict | None:
async with self.get_db() as session:
stmt = (
select(KBDocument, KnowledgeBase)
@@ -262,7 +260,7 @@ class KBSQLiteDatabase:
await session.commit()
# 在 vec db 中删除相关向量
await vec_db.delete_documents(metadata_filters={"doc_id": doc_id})
await vec_db.delete_documents(metadata_filters={"kb_doc_id": doc_id})
# ===== 多媒体查询 =====
@@ -273,7 +271,7 @@ class KBSQLiteDatabase:
result = await session.execute(stmt)
return list(result.scalars().all())
async def get_media_by_id(self, media_id: str) -> Optional[KBMedia]:
async def get_media_by_id(self, media_id: str) -> KBMedia | None:
"""根据 ID 获取多媒体资源"""
async with self.get_db() as session:
stmt = select(KBMedia).where(col(KBMedia.media_id) == media_id)
+4 -4
View File
@@ -175,7 +175,7 @@ class KBHelper:
metadatas.append(
{
"kb_id": self.kb.kb_id,
"doc_id": doc_id,
"kb_doc_id": doc_id,
"chunk_index": idx,
}
)
@@ -297,7 +297,7 @@ class KBHelper:
"""获取文档的所有块及其元数据"""
vec_db: FaissVecDB = self.vec_db # type: ignore
chunks = await vec_db.document_storage.get_documents(
metadata_filters={"doc_id": doc_id}, offset=offset, limit=limit
metadata_filters={"kb_doc_id": doc_id}, offset=offset, limit=limit
)
result = []
for chunk in chunks:
@@ -305,7 +305,7 @@ class KBHelper:
result.append(
{
"chunk_id": chunk["doc_id"],
"doc_id": chunk_md["doc_id"],
"doc_id": chunk_md["kb_doc_id"],
"kb_id": chunk_md["kb_id"],
"chunk_index": chunk_md["chunk_index"],
"content": chunk["text"],
@@ -317,7 +317,7 @@ class KBHelper:
async def get_chunk_count_by_doc_id(self, doc_id: str) -> int:
"""获取文档的块数量"""
vec_db: FaissVecDB = self.vec_db # type: ignore
count = await vec_db.count_documents(metadata_filter={"doc_id": doc_id})
count = await vec_db.count_documents(metadata_filter={"kb_doc_id": doc_id})
return count
async def _save_media(
+11 -12
View File
@@ -1,12 +1,11 @@
import uuid
from datetime import datetime, timezone
from typing import Optional
from sqlmodel import Field, SQLModel, Text, UniqueConstraint
from sqlmodel import Field, SQLModel, Text, UniqueConstraint, MetaData
class BaseKBModel(SQLModel, table=False):
pass
metadata = MetaData()
class KnowledgeBase(BaseKBModel, table=True):
@@ -28,17 +27,17 @@ class KnowledgeBase(BaseKBModel, table=True):
index=True,
)
kb_name: str = Field(max_length=100, nullable=False)
description: Optional[str] = Field(default=None, sa_type=Text)
emoji: Optional[str] = Field(default="📚", max_length=10)
embedding_provider_id: Optional[str] = Field(default=None, max_length=100)
rerank_provider_id: Optional[str] = Field(default=None, max_length=100)
description: str | None = Field(default=None, sa_type=Text)
emoji: str | None = Field(default="📚", max_length=10)
embedding_provider_id: str | None = Field(default=None, max_length=100)
rerank_provider_id: str | None = Field(default=None, max_length=100)
# 分块配置参数
chunk_size: Optional[int] = Field(default=512, nullable=True)
chunk_overlap: Optional[int] = Field(default=50, nullable=True)
chunk_size: int | None = Field(default=512, nullable=True)
chunk_overlap: int | None = Field(default=50, nullable=True)
# 检索配置参数
top_k_dense: Optional[int] = Field(default=50, nullable=True)
top_k_sparse: Optional[int] = Field(default=50, nullable=True)
top_m_final: Optional[int] = Field(default=5, nullable=True)
top_k_dense: int | None = Field(default=50, nullable=True)
top_k_sparse: int | None = Field(default=50, nullable=True)
top_m_final: int | None = Field(default=5, nullable=True)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
@@ -5,7 +5,6 @@
import json
from dataclasses import dataclass
from typing import Dict, List
from astrbot.core.db.vec_db.base import Result
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
@@ -44,10 +43,10 @@ class RankFusion:
async def fuse(
self,
dense_results: List[Result],
sparse_results: List[SparseResult],
dense_results: list[Result],
sparse_results: list[SparseResult],
top_k: int = 20,
) -> List[FusedResult]:
) -> list[FusedResult]:
"""融合稠密和稀疏检索结果
RRF 公式:
@@ -85,7 +84,7 @@ class RankFusion:
vec_doc_id_to_dense[vec_doc_id] = r
# 3. 计算 RRF 分数
rrf_scores: Dict[str, float] = {}
rrf_scores: dict[str, float] = {}
for identifier in all_chunk_ids:
score = 0.0
@@ -129,7 +128,7 @@ class RankFusion:
FusedResult(
chunk_id=identifier,
chunk_index=chunk_md["chunk_index"],
doc_id=chunk_md["doc_id"],
doc_id=chunk_md["kb_doc_id"],
kb_id=chunk_md["kb_id"],
content=vec_result.data["text"],
score=rrf_scores[identifier],
@@ -7,7 +7,6 @@ import jieba
import os
import json
from dataclasses import dataclass
from typing import List
from rank_bm25 import BM25Okapi
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
@@ -44,7 +43,6 @@ class SparseRetriever:
with open(
os.path.join(os.path.dirname(__file__), "hit_stopwords.txt"),
"r",
encoding="utf-8",
) as f:
self.hit_stopwords = {
@@ -54,9 +52,9 @@ class SparseRetriever:
async def retrieve(
self,
query: str,
kb_ids: List[str],
kb_ids: list[str],
kb_options: dict,
) -> List[SparseResult]:
) -> list[SparseResult]:
"""执行稀疏检索
Args:
@@ -82,7 +80,7 @@ class SparseRetriever:
{
"chunk_id": doc["doc_id"],
"chunk_index": chunk_md["chunk_index"],
"doc_id": chunk_md["doc_id"],
"doc_id": chunk_md["kb_doc_id"],
"kb_id": kb_id,
"text": doc["text"],
}
+1 -15
View File
@@ -94,8 +94,7 @@
<v-select v-model="formData.embedding_provider_id" :items="embeddingProviders"
:item-title="item => item.embedding_model || item.id" :item-value="'id'"
:label="t('create.embeddingModelLabel')" variant="outlined" class="mb-4" :disabled="true"
@update:model-value="handleEmbeddingProviderChange">
:label="t('create.embeddingModelLabel')" variant="outlined" class="mb-4" :disabled="editingKB !== null">
<template #item="{ props, item }">
<v-list-item v-bind="props">
<template #subtitle>
@@ -327,19 +326,6 @@ const editKB = (kb: any) => {
showCreateDialog.value = true
}
// 处理 embedding provider 变更
const handleEmbeddingProviderChange = (newValue: string | null) => {
// 检测是否修改了embedding provider
if (newValue && originalEmbeddingProvider.value && newValue !== originalEmbeddingProvider.value) {
// 显示二次确认对话框
showEmbeddingWarning.value = true
pendingEmbeddingProvider.value = newValue
embeddingChangeDialog.value = true
} else {
showEmbeddingWarning.value = false
}
}
// 确认删除
const confirmDelete = (kb: any) => {
deleteTarget.value = kb