feat: add progress callback for document upload and enhance upload progress tracking
This commit is contained in:
@@ -33,9 +33,13 @@ class BaseVecDB:
|
||||
batch_size: int = 32,
|
||||
tasks_limit: int = 3,
|
||||
max_retries: int = 3,
|
||||
progress_callback=None,
|
||||
) -> int:
|
||||
"""
|
||||
批量插入文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
|
||||
Args:
|
||||
progress_callback: 进度回调函数,接收参数 (current, total)
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@@ -70,9 +70,13 @@ class FaissVecDB(BaseVecDB):
|
||||
batch_size: int = 32,
|
||||
tasks_limit: int = 3,
|
||||
max_retries: int = 3,
|
||||
progress_callback=None,
|
||||
) -> list[int]:
|
||||
"""
|
||||
批量插入文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
|
||||
Args:
|
||||
progress_callback: 进度回调函数,接收参数 (current, total)
|
||||
"""
|
||||
assert self.document_storage.connection is not None, (
|
||||
"Database connection is not initialized."
|
||||
@@ -87,6 +91,7 @@ class FaissVecDB(BaseVecDB):
|
||||
batch_size=batch_size,
|
||||
tasks_limit=tasks_limit,
|
||||
max_retries=max_retries,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
end = time.time()
|
||||
logger.debug(
|
||||
|
||||
@@ -106,6 +106,7 @@ class KBHelper:
|
||||
batch_size: int = 32,
|
||||
tasks_limit: int = 3,
|
||||
max_retries: int = 3,
|
||||
progress_callback=None,
|
||||
) -> KBDocument:
|
||||
"""上传并处理文档(带原子性保证和失败清理)
|
||||
|
||||
@@ -117,6 +118,12 @@ class KBHelper:
|
||||
5. 生成向量并存储
|
||||
6. 保存元数据(事务)
|
||||
7. 更新统计
|
||||
|
||||
Args:
|
||||
progress_callback: 进度回调函数,接收参数 (stage, current, total)
|
||||
- stage: 当前阶段 ('parsing', 'chunking', 'embedding')
|
||||
- current: 当前进度
|
||||
- total: 总数
|
||||
"""
|
||||
await self._ensure_vec_db()
|
||||
doc_id = str(uuid.uuid4())
|
||||
@@ -127,6 +134,10 @@ class KBHelper:
|
||||
# await f.write(file_content)
|
||||
|
||||
try:
|
||||
# 阶段1: 解析文档
|
||||
if progress_callback:
|
||||
await progress_callback("parsing", 0, 100)
|
||||
|
||||
parser = self.parsers.get(file_type)
|
||||
if not parser:
|
||||
raise ValueError(f"不支持的文件类型: {file_type}")
|
||||
@@ -134,6 +145,9 @@ class KBHelper:
|
||||
text_content = parse_result.text
|
||||
media_items = parse_result.media
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("parsing", 100, 100)
|
||||
|
||||
# 保存媒体文件
|
||||
saved_media = []
|
||||
for media_item in media_items:
|
||||
@@ -147,7 +161,10 @@ class KBHelper:
|
||||
saved_media.append(media)
|
||||
media_paths.append(Path(media.file_path))
|
||||
|
||||
# 分块并生成向量
|
||||
# 阶段2: 分块
|
||||
if progress_callback:
|
||||
await progress_callback("chunking", 0, 100)
|
||||
|
||||
chunks_text = await self.chunker.chunk(
|
||||
text_content, chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
||||
)
|
||||
@@ -162,12 +179,22 @@ class KBHelper:
|
||||
"chunk_index": idx,
|
||||
}
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("chunking", 100, 100)
|
||||
|
||||
# 阶段3: 生成向量(带进度回调)
|
||||
async def embedding_progress_callback(current, total):
|
||||
if progress_callback:
|
||||
await progress_callback("embedding", current, total)
|
||||
|
||||
await self.vec_db.insert_batch(
|
||||
contents=contents,
|
||||
metadatas=metadatas,
|
||||
batch_size=batch_size,
|
||||
tasks_limit=tasks_limit,
|
||||
max_retries=max_retries,
|
||||
progress_callback=embedding_progress_callback,
|
||||
)
|
||||
|
||||
# 保存文档的元数据
|
||||
|
||||
@@ -210,6 +210,7 @@ class EmbeddingProvider(AbstractProvider):
|
||||
batch_size: int = 16,
|
||||
tasks_limit: int = 3,
|
||||
max_retries: int = 3,
|
||||
progress_callback=None,
|
||||
) -> list[list[float]]:
|
||||
"""批量获取文本的向量,分批处理以节省内存
|
||||
|
||||
@@ -218,6 +219,7 @@ class EmbeddingProvider(AbstractProvider):
|
||||
batch_size: 每批处理的文本数量
|
||||
tasks_limit: 并发任务数量限制
|
||||
max_retries: 失败时的最大重试次数
|
||||
progress_callback: 进度回调函数,接收参数 (current, total)
|
||||
|
||||
Returns:
|
||||
向量列表
|
||||
@@ -225,13 +227,19 @@ class EmbeddingProvider(AbstractProvider):
|
||||
semaphore = asyncio.Semaphore(tasks_limit)
|
||||
all_embeddings: list[list[float]] = []
|
||||
failed_batches: list[tuple[int, list[str]]] = []
|
||||
completed_count = 0
|
||||
total_count = len(texts)
|
||||
|
||||
async def process_batch(batch_idx: int, batch_texts: list[str]):
|
||||
nonlocal completed_count
|
||||
async with semaphore:
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
batch_embeddings = await self.get_embeddings(batch_texts)
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
completed_count += len(batch_texts)
|
||||
if progress_callback:
|
||||
await progress_callback(completed_count, total_count)
|
||||
return
|
||||
except Exception as e:
|
||||
if attempt == max_retries - 1:
|
||||
|
||||
@@ -4,6 +4,7 @@ import uuid
|
||||
import aiofiles
|
||||
import os
|
||||
import traceback
|
||||
import asyncio
|
||||
from quart import request
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
@@ -27,6 +28,8 @@ class KnowledgeBaseRoute(Route):
|
||||
self.kb_db = None
|
||||
self.session_config_db = None # 会话配置数据库
|
||||
self.retrieval_manager = None
|
||||
self.upload_progress = {} # 存储上传进度 {task_id: {status, file_index, file_total, stage, current, total}}
|
||||
self.upload_tasks = {} # 存储后台上传任务 {task_id: {"status", "result", "error"}}
|
||||
|
||||
# 注册路由
|
||||
self.routes = {
|
||||
@@ -40,6 +43,7 @@ class KnowledgeBaseRoute(Route):
|
||||
# 文档管理
|
||||
"/kb/document/list": ("GET", self.list_documents),
|
||||
"/kb/document/upload": ("POST", self.upload_document),
|
||||
"/kb/document/upload/progress": ("GET", self.get_upload_progress),
|
||||
"/kb/document/get": ("GET", self.get_document),
|
||||
"/kb/document/delete": ("POST", self.delete_document),
|
||||
# # 块管理
|
||||
@@ -56,6 +60,112 @@ class KnowledgeBaseRoute(Route):
|
||||
def _get_kb_manager(self):
|
||||
return self.core_lifecycle.kb_manager
|
||||
|
||||
async def _background_upload_task(
    self,
    task_id: str,
    kb_helper,
    files_to_upload: list,
    chunk_size: int,
    chunk_overlap: int,
    batch_size: int,
    tasks_limit: int,
    max_retries: int,
):
    """Upload every file of one task and record progress and the outcome.

    State is written to two dicts on ``self``, both keyed by ``task_id``:
    ``upload_tasks`` holds ``status`` / ``result`` / ``error`` and
    ``upload_progress`` holds the file and stage currently being processed.
    A failure of a single file is logged and collected in the result's
    ``failed`` list while the task still ends as ``completed``; only an
    error escaping the per-file handling marks the whole task ``failed``.

    Args:
        task_id: Key under which the task state and progress are stored.
        kb_helper: Object whose awaitable ``upload_document`` handles one file.
        files_to_upload: Dicts with ``file_name``, ``file_content`` and
            ``file_type`` keys.
        chunk_size: Forwarded unchanged to ``kb_helper.upload_document``.
        chunk_overlap: Forwarded unchanged to ``kb_helper.upload_document``.
        batch_size: Forwarded unchanged to ``kb_helper.upload_document``.
        tasks_limit: Forwarded unchanged to ``kb_helper.upload_document``.
        max_retries: Forwarded unchanged to ``kb_helper.upload_document``.
    """
    try:
        # Seed both records so a status query sees "processing" immediately.
        self.upload_tasks[task_id] = dict(
            status="processing",
            result=None,
            error=None,
        )
        self.upload_progress[task_id] = dict(
            status="processing",
            file_index=0,
            file_total=len(files_to_upload),
            stage="waiting",
            current=0,
            total=100,
        )

        succeeded = []
        rejected = []

        for position, entry in enumerate(files_to_upload):
            try:
                # Point the progress record at the file about to be processed.
                self.upload_progress[task_id].update(
                    status="processing",
                    file_index=position,
                    file_name=entry["file_name"],
                    stage="parsing",
                    current=0,
                    total=100,
                )

                async def progress_callback(stage, current, total):
                    # Nothing to report into if the progress record is gone.
                    if task_id not in self.upload_progress:
                        return
                    self.upload_progress[task_id].update(
                        status="processing",
                        file_index=position,
                        file_name=entry["file_name"],
                        stage=stage,
                        current=current,
                        total=total,
                    )

                document = await kb_helper.upload_document(
                    file_name=entry["file_name"],
                    file_content=entry["file_content"],
                    file_type=entry["file_type"],
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                    batch_size=batch_size,
                    tasks_limit=tasks_limit,
                    max_retries=max_retries,
                    progress_callback=progress_callback,
                )
                succeeded.append(document.model_dump())
            except Exception as e:
                # One bad file must not abort the remaining files.
                logger.error(f"上传文档 {entry['file_name']} 失败: {e}")
                rejected.append(
                    {"file_name": entry["file_name"], "error": str(e)}
                )

        # Publish the summary and flip both records to "completed".
        summary = {
            "task_id": task_id,
            "uploaded": succeeded,
            "failed": rejected,
            "total": len(files_to_upload),
            "success_count": len(succeeded),
            "failed_count": len(rejected),
        }
        self.upload_tasks[task_id] = dict(
            status="completed",
            result=summary,
            error=None,
        )
        self.upload_progress[task_id]["status"] = "completed"

    except Exception as e:
        logger.error(f"后台上传任务 {task_id} 失败: {e}")
        logger.error(traceback.format_exc())
        self.upload_tasks[task_id] = dict(
            status="failed",
            result=None,
            error=str(e),
        )
        if task_id in self.upload_progress:
            self.upload_progress[task_id]["status"] = "failed"
|
||||
|
||||
async def list_kbs(self):
|
||||
"""获取知识库列表
|
||||
|
||||
@@ -74,24 +184,7 @@ class KnowledgeBaseRoute(Route):
|
||||
# 转换为字典列表
|
||||
kb_list = []
|
||||
for kb in kbs:
|
||||
kb_dict = {
|
||||
"kb_id": kb.kb_id,
|
||||
"kb_name": kb.kb_name,
|
||||
"description": kb.description,
|
||||
"emoji": kb.emoji or "📚",
|
||||
"embedding_provider_id": kb.embedding_provider_id,
|
||||
"rerank_provider_id": kb.rerank_provider_id,
|
||||
"doc_count": kb.doc_count,
|
||||
"chunk_count": kb.chunk_count,
|
||||
"chunk_size": kb.chunk_size or 512,
|
||||
"chunk_overlap": kb.chunk_overlap or 50,
|
||||
"top_k_dense": kb.top_k_dense or 50,
|
||||
"top_k_sparse": kb.top_k_sparse or 50,
|
||||
"top_m_final": kb.top_m_final or 5,
|
||||
"created_at": kb.created_at.isoformat(),
|
||||
"updated_at": kb.updated_at.isoformat(),
|
||||
}
|
||||
kb_list.append(kb_dict)
|
||||
kb_list.append(kb.model_dump())
|
||||
|
||||
return (
|
||||
Response()
|
||||
@@ -151,25 +244,7 @@ class KnowledgeBaseRoute(Route):
|
||||
)
|
||||
kb = kb_helper.kb
|
||||
|
||||
kb_dict = {
|
||||
"kb_id": kb.kb_id,
|
||||
"kb_name": kb.kb_name,
|
||||
"description": kb.description,
|
||||
"emoji": kb.emoji or "📚",
|
||||
"embedding_provider_id": kb.embedding_provider_id,
|
||||
"rerank_provider_id": kb.rerank_provider_id,
|
||||
"doc_count": kb.doc_count,
|
||||
"chunk_count": kb.chunk_count,
|
||||
"chunk_size": kb.chunk_size or 512,
|
||||
"chunk_overlap": kb.chunk_overlap or 50,
|
||||
"top_k_dense": kb.top_k_dense or 50,
|
||||
"top_k_sparse": kb.top_k_sparse or 50,
|
||||
"top_m_final": kb.top_m_final or 5,
|
||||
"created_at": kb.created_at.isoformat(),
|
||||
"updated_at": kb.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
return Response().ok(kb_dict, "创建知识库成功").__dict__
|
||||
return Response().ok(kb.model_dump(), "创建知识库成功").__dict__
|
||||
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
@@ -195,24 +270,7 @@ class KnowledgeBaseRoute(Route):
|
||||
return Response().error("知识库不存在").__dict__
|
||||
kb = kb_helper.kb
|
||||
|
||||
kb_dict = {
|
||||
"kb_id": kb.kb_id,
|
||||
"kb_name": kb.kb_name,
|
||||
"description": kb.description,
|
||||
"emoji": kb.emoji or "📚",
|
||||
"embedding_provider_id": kb.embedding_provider_id,
|
||||
"rerank_provider_id": kb.rerank_provider_id,
|
||||
"doc_count": kb.doc_count,
|
||||
"chunk_count": kb.chunk_count,
|
||||
"chunk_size": kb.chunk_size,
|
||||
"chunk_overlap": kb.chunk_overlap,
|
||||
"top_k_dense": kb.top_k_dense,
|
||||
"top_k_sparse": kb.top_k_sparse,
|
||||
"created_at": kb.created_at.isoformat(),
|
||||
"updated_at": kb.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
return Response().ok(kb_dict).__dict__
|
||||
return Response().ok(kb.model_dump()).__dict__
|
||||
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
@@ -293,25 +351,7 @@ class KnowledgeBaseRoute(Route):
|
||||
return Response().error("知识库不存在").__dict__
|
||||
|
||||
kb = kb_helper.kb
|
||||
|
||||
kb_dict = {
|
||||
"kb_id": kb.kb_id,
|
||||
"kb_name": kb.kb_name,
|
||||
"description": kb.description,
|
||||
"emoji": kb.emoji or "📚",
|
||||
"embedding_provider_id": kb.embedding_provider_id,
|
||||
"rerank_provider_id": kb.rerank_provider_id,
|
||||
"doc_count": kb.doc_count,
|
||||
"chunk_count": kb.chunk_count,
|
||||
"chunk_size": kb.chunk_size or 512,
|
||||
"chunk_overlap": kb.chunk_overlap or 50,
|
||||
"top_k_dense": kb.top_k_dense or 50,
|
||||
"top_k_sparse": kb.top_k_sparse or 50,
|
||||
"created_at": kb.created_at.isoformat(),
|
||||
"updated_at": kb.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
return Response().ok(kb_dict, "更新知识库成功").__dict__
|
||||
return Response().ok(kb.model_dump(), "更新知识库成功").__dict__
|
||||
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
@@ -440,6 +480,9 @@ class KnowledgeBaseRoute(Route):
|
||||
- files: 文件数组 (必填)
|
||||
- file_name: 文件名 (必填)
|
||||
- file_content: base64 编码的文件内容 (必填)
|
||||
|
||||
返回:
|
||||
- task_id: 任务ID,用于查询上传进度和结果
|
||||
"""
|
||||
try:
|
||||
kb_manager = self._get_kb_manager()
|
||||
@@ -520,60 +563,41 @@ class KnowledgeBaseRoute(Route):
|
||||
if not kb_helper:
|
||||
return Response().error("知识库不存在").__dict__
|
||||
|
||||
# 上传所有文档
|
||||
uploaded_docs = []
|
||||
failed_docs = []
|
||||
# 生成任务ID
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
for file_info in files_to_upload:
|
||||
try:
|
||||
doc = await kb_helper.upload_document(
|
||||
file_name=file_info["file_name"],
|
||||
file_content=file_info["file_content"],
|
||||
file_type=file_info["file_type"],
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
batch_size=batch_size,
|
||||
tasks_limit=tasks_limit,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
|
||||
doc_dict = {
|
||||
"doc_id": doc.doc_id,
|
||||
"kb_id": doc.kb_id,
|
||||
"doc_name": doc.doc_name,
|
||||
"file_type": doc.file_type,
|
||||
"file_size": doc.file_size,
|
||||
"chunk_count": doc.chunk_count,
|
||||
"media_count": doc.media_count,
|
||||
"created_at": doc.created_at.isoformat(),
|
||||
"updated_at": doc.updated_at.isoformat(),
|
||||
}
|
||||
uploaded_docs.append(doc_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"上传文档 {file_info['file_name']} 失败: {e}")
|
||||
failed_docs.append(
|
||||
{"file_name": file_info["file_name"], "error": str(e)}
|
||||
)
|
||||
|
||||
# 返回结果
|
||||
result = {
|
||||
"uploaded": uploaded_docs,
|
||||
"failed": failed_docs,
|
||||
"total": len(files_to_upload),
|
||||
"success_count": len(uploaded_docs),
|
||||
"failed_count": len(failed_docs),
|
||||
# 初始化任务状态
|
||||
self.upload_tasks[task_id] = {
|
||||
"status": "pending",
|
||||
"result": None,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
if failed_docs:
|
||||
message = (
|
||||
f"部分文档上传成功 ({len(uploaded_docs)}/{len(files_to_upload)})"
|
||||
)
|
||||
else:
|
||||
message = (
|
||||
f"所有文档上传成功 ({len(uploaded_docs)}/{len(files_to_upload)})"
|
||||
# 启动后台任务
|
||||
asyncio.create_task(
|
||||
self._background_upload_task(
|
||||
task_id=task_id,
|
||||
kb_helper=kb_helper,
|
||||
files_to_upload=files_to_upload,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
batch_size=batch_size,
|
||||
tasks_limit=tasks_limit,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
)
|
||||
|
||||
return Response().ok(result, message).__dict__
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"file_count": len(files_to_upload),
|
||||
"message": "task created, processing in background",
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
@@ -582,6 +606,59 @@ class KnowledgeBaseRoute(Route):
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"上传文档失败: {str(e)}").__dict__
|
||||
|
||||
async def get_upload_progress(self):
    """Report the state of a background upload task.

    Query parameters:
        - task_id: ID of the task to look up (required)

    The response data always carries ``task_id`` and ``status``. Depending
    on the status it additionally carries:
        - processing: ``progress`` (when a progress record exists)
        - completed: ``result``
        - failed: ``error``
        - pending: nothing further
    """
    try:
        task_id = request.args.get("task_id")
        if not task_id:
            return Response().error("缺少参数 task_id").__dict__

        if task_id not in self.upload_tasks:
            return Response().error("找不到该任务").__dict__

        record = self.upload_tasks[task_id]
        state = record["status"]
        payload = {"task_id": task_id, "status": state}

        if state == "processing":
            # Progress details are only attached while work is under way.
            if task_id in self.upload_progress:
                payload["progress"] = self.upload_progress[task_id]
        elif state == "completed":
            # NOTE: finished tasks are not evicted from upload_tasks or
            # upload_progress here, so the result can be queried repeatedly.
            payload["result"] = record["result"]
        elif state == "failed":
            payload["error"] = record["error"]

        return Response().ok(payload).__dict__

    except Exception as e:
        logger.error(f"获取上传进度失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"获取上传进度失败: {str(e)}").__dict__
|
||||
|
||||
async def get_document(self):
|
||||
"""获取文档详情
|
||||
|
||||
@@ -604,20 +681,7 @@ class KnowledgeBaseRoute(Route):
|
||||
if not doc:
|
||||
return Response().error("文档不存在").__dict__
|
||||
|
||||
doc_dict = {
|
||||
"doc_id": doc.doc_id,
|
||||
"kb_id": doc.kb_id,
|
||||
"doc_name": doc.doc_name,
|
||||
"file_type": doc.file_type,
|
||||
"file_size": doc.file_size,
|
||||
"file_path": doc.file_path,
|
||||
"chunk_count": doc.chunk_count,
|
||||
"media_count": doc.media_count,
|
||||
"created_at": doc.created_at.isoformat(),
|
||||
"updated_at": doc.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
return Response().ok(doc_dict).__dict__
|
||||
return Response().ok(doc.model_dump()).__dict__
|
||||
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
@@ -14,10 +14,23 @@
|
||||
<v-data-table :headers="headers" :items="documents" :loading="loading" :search="searchQuery" :items-per-page="10">
|
||||
<template #item.doc_name="{ item }">
|
||||
<div class="d-flex align-center gap-2">
|
||||
<v-icon :color="getFileColor(item.file_type)">
|
||||
<v-icon :color="getFileColor(item.file_type)" class="mr-2">
|
||||
{{ getFileIcon(item.file_type) }}
|
||||
</v-icon>
|
||||
<span class="font-weight-medium">{{ item.doc_name }}</span>
|
||||
<div class="flex-grow-1" style="padding: 4px 0px;">
|
||||
<span class="font-weight-medium">{{ item.doc_name }}</span>
|
||||
<!-- 上传进度 -->
|
||||
<div v-if="item.uploading" class="mt-1">
|
||||
<div class="text-caption text-medium-emphasis mb-1">
|
||||
{{ getStageText(item.uploadProgress?.stage || 'waiting') }}
|
||||
<span v-if="item.uploadProgress?.current">
|
||||
({{ item.uploadProgress.current }} / {{ item.uploadProgress.total }})
|
||||
</span>
|
||||
</div>
|
||||
<v-progress-linear :model-value="getUploadPercentage(item)" color="primary" height="4" rounded
|
||||
striped />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -72,7 +85,8 @@
|
||||
<v-btn variant="text" size="small" @click="selectedFiles = []">清空</v-btn>
|
||||
</div>
|
||||
<div class="files-list">
|
||||
<div v-for="(file, index) in selectedFiles" :key="index" class="file-item pa-3 mb-2 rounded bg-surface-variant">
|
||||
<div v-for="(file, index) in selectedFiles" :key="index"
|
||||
class="file-item pa-3 mb-2 rounded bg-surface-variant">
|
||||
<div class="d-flex align-center justify-space-between">
|
||||
<div class="d-flex align-center gap-2">
|
||||
<v-icon>{{ getFileIcon(file.name) }}</v-icon>
|
||||
@@ -95,8 +109,8 @@
|
||||
<v-row>
|
||||
<v-col cols="12" sm="6">
|
||||
<v-text-field v-model.number="uploadSettings.chunk_size" :label="t('upload.chunkSize')"
|
||||
:hint="t('upload.chunkSizeHint')" persistent-hint type="number" variant="outlined"
|
||||
density="compact" :placeholder="props.kb?.chunk_size?.toString() || '512'" />
|
||||
:hint="t('upload.chunkSizeHint')" persistent-hint type="number" variant="outlined" density="compact"
|
||||
:placeholder="props.kb?.chunk_size?.toString() || '512'" />
|
||||
</v-col>
|
||||
<v-col cols="12" sm="6">
|
||||
<v-text-field v-model.number="uploadSettings.chunk_overlap" :label="t('upload.chunkOverlap')"
|
||||
@@ -110,8 +124,8 @@
|
||||
<h3 class="text-h6 mb-4">{{ t('upload.batchSettings') }}</h3>
|
||||
<v-row>
|
||||
<v-col cols="12" sm="4">
|
||||
<v-text-field v-model.number="uploadSettings.batch_size" :label="t('upload.batchSize')"
|
||||
hint="每批处理的文本数量" persistent-hint type="number" variant="outlined" density="compact" />
|
||||
<v-text-field v-model.number="uploadSettings.batch_size" :label="t('upload.batchSize')" hint="每批处理的文本数量"
|
||||
persistent-hint type="number" variant="outlined" density="compact" />
|
||||
</v-col>
|
||||
<v-col cols="12" sm="4">
|
||||
<v-text-field v-model.number="uploadSettings.tasks_limit" :label="t('upload.tasksLimit')"
|
||||
@@ -124,13 +138,15 @@
|
||||
</v-row>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
</v-card-text>
|
||||
|
||||
<v-divider />
|
||||
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer />
|
||||
<v-btn variant="text" @click="closeUploadDialog">
|
||||
<v-btn variant="text" @click="closeUploadDialog" :disabled="uploading">
|
||||
{{ t('upload.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn color="primary" variant="elevated" @click="uploadDocument" :loading="uploading"
|
||||
@@ -171,7 +187,7 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { ref, onMounted, onUnmounted } from 'vue'
|
||||
import { useRouter } from 'vue-router'
|
||||
import axios from 'axios'
|
||||
import { useModuleI18n } from '@/i18n/composables'
|
||||
@@ -199,6 +215,10 @@ const deleteTarget = ref<any>(null)
|
||||
const isDragging = ref(false)
|
||||
const fileInput = ref<HTMLInputElement | null>(null)
|
||||
|
||||
// 上传进度 - 用于轮询多个任务
|
||||
const uploadingTasks = ref<Map<string, any>>(new Map())
|
||||
const progressPollingInterval = ref<number | null>(null)
|
||||
|
||||
const snackbar = ref({
|
||||
show: false,
|
||||
text: '',
|
||||
@@ -300,14 +320,15 @@ const uploadDocument = async () => {
|
||||
}
|
||||
|
||||
uploading.value = true
|
||||
|
||||
try {
|
||||
const formData = new FormData()
|
||||
|
||||
|
||||
// 添加所有文件
|
||||
selectedFiles.value.forEach((file, index) => {
|
||||
formData.append(`file${index}`, file)
|
||||
})
|
||||
|
||||
|
||||
formData.append('kb_id', props.kbId)
|
||||
if (uploadSettings.value.chunk_size) {
|
||||
formData.append('chunk_size', uploadSettings.value.chunk_size.toString())
|
||||
@@ -325,18 +346,37 @@ const uploadDocument = async () => {
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
const result = response.data.data
|
||||
const successCount = result.success_count || 0
|
||||
const failedCount = result.failed_count || 0
|
||||
|
||||
if (failedCount === 0) {
|
||||
showSnackbar(`成功上传 ${successCount} 个文档`)
|
||||
} else {
|
||||
showSnackbar(`上传完成: ${successCount} 个成功, ${failedCount} 个失败`, 'warning')
|
||||
}
|
||||
|
||||
const taskId = result.task_id
|
||||
|
||||
showSnackbar(`正在后台上传 ${result.file_count} 个文件...`, 'info')
|
||||
|
||||
// 为每个文件添加占位条目到文档列表
|
||||
const uploadingDocs = selectedFiles.value.map((file, index) => ({
|
||||
doc_id: `uploading_${taskId}_${index}`,
|
||||
doc_name: file.name,
|
||||
file_type: file.name.split('.').pop() || '',
|
||||
file_size: file.size,
|
||||
chunk_count: 0,
|
||||
created_at: new Date().toISOString(),
|
||||
uploading: true,
|
||||
taskId: taskId,
|
||||
uploadProgress: {
|
||||
stage: 'waiting',
|
||||
current: 0,
|
||||
total: 100
|
||||
}
|
||||
}))
|
||||
|
||||
// 添加到文档列表顶部
|
||||
documents.value = [...uploadingDocs, ...documents.value]
|
||||
|
||||
// 关闭对话框
|
||||
closeUploadDialog()
|
||||
await loadDocuments()
|
||||
emit('refresh')
|
||||
|
||||
// 开始轮询进度
|
||||
if (taskId) {
|
||||
startProgressPolling(taskId)
|
||||
}
|
||||
} else {
|
||||
showSnackbar(response.data.message || t('documents.uploadFailed'), 'error')
|
||||
}
|
||||
@@ -348,11 +388,117 @@ const uploadDocument = async () => {
|
||||
}
|
||||
}
|
||||
|
||||
// 开始轮询进度
|
||||
const startProgressPolling = (taskId: string) => {
  // Only one poller runs at a time: starting a new one replaces the old one.
  if (progressPollingInterval.value) {
    stopProgressPolling()
  }

  // True while a progress request is pending. Ticks that fire in the
  // meantime are skipped so a slow response cannot be processed twice
  // (e.g. two "completed" answers triggering two reloads and two snackbars).
  let requestInFlight = false

  const timerId = window.setInterval(async () => {
    if (requestInFlight) return
    requestInFlight = true

    try {
      const response = await axios.get('/api/kb/document/upload/progress', {
        params: { task_id: taskId }
      })

      // This poller was stopped or replaced while the request was pending;
      // drop the stale answer so it cannot clear a newer poller or touch
      // the list after the component is gone.
      if (progressPollingInterval.value !== timerId) return

      if (response.data.status !== 'ok') {
        // The backend does not know this task: stop and drop the placeholders.
        stopProgressPolling()
        documents.value = documents.value.filter(doc => doc.taskId !== taskId)
        return
      }

      const data = response.data.data
      const status = data.status

      if (status === 'processing' && data.progress) {
        const progress = data.progress
        const fileIndex = progress.file_index || 0

        // Placeholder ids end in the file's index within the task; only the
        // placeholder of the file currently being processed is updated.
        documents.value = documents.value.map(doc => {
          if (doc.taskId !== taskId) return doc

          const docIndex = parseInt(doc.doc_id.split('_').pop() || '0', 10)
          if (docIndex !== fileIndex) return doc

          return {
            ...doc,
            uploadProgress: {
              stage: progress.stage || 'waiting',
              current: progress.current || 0,
              total: progress.total || 100
            }
          }
        })
      } else if (status === 'completed') {
        stopProgressPolling()

        const result = data.result
        const successCount = result?.success_count || 0
        const failedCount = result?.failed_count || 0

        // Replace the placeholders with the real documents from the backend.
        documents.value = documents.value.filter(doc => doc.taskId !== taskId)
        await loadDocuments()
        emit('refresh')

        if (failedCount === 0) {
          showSnackbar(`成功上传 ${successCount} 个文档`)
        } else {
          showSnackbar(`上传完成: ${successCount} 个成功, ${failedCount} 个失败`, 'warning')
        }
      } else if (status === 'failed') {
        stopProgressPolling()
        documents.value = documents.value.filter(doc => doc.taskId !== taskId)
        showSnackbar(`上传失败: ${data.error || '未知错误'}`, 'error')
      }
    } catch (error) {
      // Keep polling: a transient request error should not end the tracking.
      console.error('Failed to fetch progress:', error)
    } finally {
      requestInFlight = false
    }
  }, 500) // poll every 500 ms

  progressPollingInterval.value = timerId
}
|
||||
|
||||
// 停止轮询进度
|
||||
const stopProgressPolling = () => {
  // Nothing to do when no poller is active.
  const activeTimer = progressPollingInterval.value
  if (!activeTimer) {
    return
  }

  clearInterval(activeTimer)
  progressPollingInterval.value = null
}
|
||||
|
||||
// 获取上传百分比
|
||||
// Convert an item's uploadProgress ({ current, total }) into a percentage
// for the progress bar. Returns 0 when the progress data is missing,
// non-numeric, or has a non-positive total; the result is always within
// the range 0–100.
const getUploadPercentage = (item: any) => {
  const progress = item?.uploadProgress
  if (!progress) return 0

  const current = Number(progress.current)
  const total = Number(progress.total)
  // Guard against NaN/Infinity reaching the progress bar.
  if (!Number.isFinite(current) || !Number.isFinite(total) || total <= 0) {
    return 0
  }

  return Math.min(100, Math.max(0, (current / total) * 100))
}
|
||||
|
||||
// 获取阶段文本
|
||||
// Map a backend stage identifier to the label shown next to the progress
// bar; identifiers without a label are shown as-is.
const getStageText = (stage: string) => {
  const labels: Record<string, string> = {
    waiting: '等待中...',
    parsing: '解析文档...',
    chunking: '文本分块...',
    embedding: '生成向量...'
  }

  const label = labels[stage]
  if (label) {
    return label
  }
  return stage
}
|
||||
|
||||
// 关闭上传对话框
|
||||
const closeUploadDialog = () => {
  // Hide the dialog and discard the pending file selection.
  showUploadDialog.value = false
  selectedFiles.value = []

  // Reset the upload settings to the knowledge base defaults.
  initUploadSettings()
}
|
||||
|
||||
@@ -440,6 +586,10 @@ const formatDate = (dateStr: string) => {
|
||||
onMounted(() => {
|
||||
loadDocuments()
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
stopProgressPolling()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
||||
Reference in New Issue
Block a user