From 9430e3090d23db136da109ec1dc92fffbd8af447 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 24 Oct 2025 17:13:44 +0800 Subject: [PATCH] feat: add progress callback for document upload and enhance upload progress tracking --- astrbot/core/db/vec_db/base.py | 4 + astrbot/core/db/vec_db/faiss_impl/vec_db.py | 5 + astrbot/core/knowledge_base/kb_helper.py | 29 +- astrbot/core/provider/provider.py | 8 + astrbot/dashboard/routes/knowledge_base.py | 338 +++++++++++------- .../components/DocumentsTab.vue | 196 ++++++++-- 6 files changed, 419 insertions(+), 161 deletions(-) diff --git a/astrbot/core/db/vec_db/base.py b/astrbot/core/db/vec_db/base.py index d100aa71e..27fc9f3fb 100644 --- a/astrbot/core/db/vec_db/base.py +++ b/astrbot/core/db/vec_db/base.py @@ -33,9 +33,13 @@ class BaseVecDB: batch_size: int = 32, tasks_limit: int = 3, max_retries: int = 3, + progress_callback=None, ) -> int: """ 批量插入文本和其对应向量,自动生成 ID 并保持一致性。 + + Args: + progress_callback: 进度回调函数,接收参数 (current, total) """ ... diff --git a/astrbot/core/db/vec_db/faiss_impl/vec_db.py b/astrbot/core/db/vec_db/faiss_impl/vec_db.py index 77a1e605a..587de2496 100644 --- a/astrbot/core/db/vec_db/faiss_impl/vec_db.py +++ b/astrbot/core/db/vec_db/faiss_impl/vec_db.py @@ -70,9 +70,13 @@ class FaissVecDB(BaseVecDB): batch_size: int = 32, tasks_limit: int = 3, max_retries: int = 3, + progress_callback=None, ) -> list[int]: """ 批量插入文本和其对应向量,自动生成 ID 并保持一致性。 + + Args: + progress_callback: 进度回调函数,接收参数 (current, total) """ assert self.document_storage.connection is not None, ( "Database connection is not initialized." @@ -87,6 +91,7 @@ class FaissVecDB(BaseVecDB): batch_size=batch_size, tasks_limit=tasks_limit, max_retries=max_retries, + progress_callback=progress_callback, ) end = time.time() logger.debug( diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index 389d2cb5f..ebad6d14a 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -106,6 +106,7 @@ class KBHelper: batch_size: int = 32, tasks_limit: int = 3, max_retries: int = 3, + progress_callback=None, ) -> KBDocument: """上传并处理文档(带原子性保证和失败清理) @@ -117,6 +118,12 @@ class KBHelper: 5. 生成向量并存储 6. 保存元数据(事务) 7. 更新统计 + + Args: + progress_callback: 进度回调函数,接收参数 (stage, current, total) + - stage: 当前阶段 ('parsing', 'chunking', 'embedding') + - current: 当前进度 + - total: 总数 """ await self._ensure_vec_db() doc_id = str(uuid.uuid4()) @@ -127,6 +134,10 @@ class KBHelper: # await f.write(file_content) try: + # 阶段1: 解析文档 + if progress_callback: + await progress_callback("parsing", 0, 100) + parser = self.parsers.get(file_type) if not parser: raise ValueError(f"不支持的文件类型: {file_type}") @@ -134,6 +145,9 @@ class KBHelper: text_content = parse_result.text media_items = parse_result.media + if progress_callback: + await progress_callback("parsing", 100, 100) + # 保存媒体文件 saved_media = [] for media_item in media_items: @@ -147,7 +161,10 @@ class KBHelper: saved_media.append(media) media_paths.append(Path(media.file_path)) - # 分块并生成向量 + # 阶段2: 分块 + if progress_callback: + await progress_callback("chunking", 0, 100) + chunks_text = await self.chunker.chunk( text_content, chunk_size=chunk_size, chunk_overlap=chunk_overlap ) @@ -162,12 +179,22 @@ class KBHelper: "chunk_index": idx, } ) + + if progress_callback: + await progress_callback("chunking", 100, 100) + + # 阶段3: 生成向量(带进度回调) + async def embedding_progress_callback(current, total): + if progress_callback: + await progress_callback("embedding", current, total) + await self.vec_db.insert_batch( contents=contents, metadatas=metadatas, batch_size=batch_size, tasks_limit=tasks_limit, max_retries=max_retries, + progress_callback=embedding_progress_callback, ) # 保存文档的元数据 diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 99ec9443b..9953e9f17 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -210,6 +210,7 @@ class EmbeddingProvider(AbstractProvider): batch_size: int = 16, tasks_limit: int = 3, max_retries: int = 3, + progress_callback=None, ) -> list[list[float]]: """批量获取文本的向量,分批处理以节省内存 @@ -218,6 +219,7 @@ class EmbeddingProvider(AbstractProvider): batch_size: 每批处理的文本数量 tasks_limit: 并发任务数量限制 max_retries: 失败时的最大重试次数 + progress_callback: 进度回调函数,接收参数 (current, total) Returns: 向量列表 @@ -225,13 +227,19 @@ class EmbeddingProvider(AbstractProvider): semaphore = asyncio.Semaphore(tasks_limit) all_embeddings: list[list[float]] = [] failed_batches: list[tuple[int, list[str]]] = [] + completed_count = 0 + total_count = len(texts) async def process_batch(batch_idx: int, batch_texts: list[str]): + nonlocal completed_count async with semaphore: for attempt in range(max_retries): try: batch_embeddings = await self.get_embeddings(batch_texts) all_embeddings.extend(batch_embeddings) + completed_count += len(batch_texts) + if progress_callback: + await progress_callback(completed_count, total_count) return except Exception as e: if attempt == max_retries - 1: diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index 9bd71b0e6..d461aebfb 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -4,6 +4,7 @@ import uuid import aiofiles import os import traceback +import asyncio from quart import request from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle @@ -27,6 +28,8 @@ class KnowledgeBaseRoute(Route): self.kb_db = None self.session_config_db = None # 会话配置数据库 self.retrieval_manager = None + self.upload_progress = {} # 存储上传进度 {task_id: {status, file_index, file_total, stage, current, total}} + self.upload_tasks = {} # 存储后台上传任务 {task_id: {"status", "result", "error"}} # 注册路由 self.routes = { @@ -40,6 +43,7 @@ class KnowledgeBaseRoute(Route): # 文档管理 "/kb/document/list": ("GET", self.list_documents), "/kb/document/upload": ("POST", self.upload_document), + "/kb/document/upload/progress": ("GET", self.get_upload_progress), "/kb/document/get": ("GET", self.get_document), "/kb/document/delete": ("POST", self.delete_document), # # 块管理 @@ -56,6 +60,112 @@ class KnowledgeBaseRoute(Route): def _get_kb_manager(self): return self.core_lifecycle.kb_manager + async def _background_upload_task( + self, + task_id: str, + kb_helper, + files_to_upload: list, + chunk_size: int, + chunk_overlap: int, + batch_size: int, + tasks_limit: int, + max_retries: int, + ): + """后台上传任务""" + try: + # 初始化任务状态 + self.upload_tasks[task_id] = { + "status": "processing", + "result": None, + "error": None, + } + self.upload_progress[task_id] = { + "status": "processing", + "file_index": 0, + "file_total": len(files_to_upload), + "stage": "waiting", + "current": 0, + "total": 100, + } + + uploaded_docs = [] + failed_docs = [] + + for file_idx, file_info in enumerate(files_to_upload): + try: + # 更新整体进度 + self.upload_progress[task_id].update( + { + "status": "processing", + "file_index": file_idx, + "file_name": file_info["file_name"], + "stage": "parsing", + "current": 0, + "total": 100, + } + ) + + # 创建进度回调函数 + async def progress_callback(stage, current, total): + if task_id in self.upload_progress: + self.upload_progress[task_id].update( + { + "status": "processing", + "file_index": file_idx, + "file_name": file_info["file_name"], + "stage": stage, + "current": current, + "total": total, + } + ) + + doc = await kb_helper.upload_document( + file_name=file_info["file_name"], + file_content=file_info["file_content"], + file_type=file_info["file_type"], + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + progress_callback=progress_callback, + ) + + uploaded_docs.append(doc.model_dump()) + except Exception as e: + logger.error(f"上传文档 {file_info['file_name']} 失败: {e}") + failed_docs.append( + {"file_name": file_info["file_name"], "error": str(e)} + ) + + # 更新任务完成状态 + result = { + "task_id": task_id, + "uploaded": uploaded_docs, + "failed": failed_docs, + "total": len(files_to_upload), + "success_count": len(uploaded_docs), + "failed_count": len(failed_docs), + } + + self.upload_tasks[task_id] = { + "status": "completed", + "result": result, + "error": None, + } + self.upload_progress[task_id]["status"] = "completed" + + except Exception as e: + logger.error(f"后台上传任务 {task_id} 失败: {e}") + logger.error(traceback.format_exc()) + self.upload_tasks[task_id] = { + "status": "failed", + "result": None, + "error": str(e), + } + if task_id in self.upload_progress: + self.upload_progress[task_id]["status"] = "failed" + async def list_kbs(self): """获取知识库列表 @@ -74,24 +184,7 @@ class KnowledgeBaseRoute(Route): # 转换为字典列表 kb_list = [] for kb in kbs: - kb_dict = { - "kb_id": kb.kb_id, - "kb_name": kb.kb_name, - "description": kb.description, - "emoji": kb.emoji or "📚", - "embedding_provider_id": kb.embedding_provider_id, - "rerank_provider_id": kb.rerank_provider_id, - "doc_count": kb.doc_count, - "chunk_count": kb.chunk_count, - "chunk_size": kb.chunk_size or 512, - "chunk_overlap": kb.chunk_overlap or 50, - "top_k_dense": kb.top_k_dense or 50, - "top_k_sparse": kb.top_k_sparse or 50, - "top_m_final": kb.top_m_final or 5, - "created_at": kb.created_at.isoformat(), - "updated_at": kb.updated_at.isoformat(), - } - kb_list.append(kb_dict) + kb_list.append(kb.model_dump()) return ( Response() @@ -151,25 +244,7 @@ class KnowledgeBaseRoute(Route): ) kb = kb_helper.kb - kb_dict = { - "kb_id": kb.kb_id, - "kb_name": kb.kb_name, - "description": kb.description, - "emoji": kb.emoji or "📚", - "embedding_provider_id": kb.embedding_provider_id, - "rerank_provider_id": kb.rerank_provider_id, - "doc_count": kb.doc_count, - "chunk_count": kb.chunk_count, - "chunk_size": kb.chunk_size or 512, - "chunk_overlap": kb.chunk_overlap or 50, - "top_k_dense": kb.top_k_dense or 50, - "top_k_sparse": kb.top_k_sparse or 50, - "top_m_final": kb.top_m_final or 5, - "created_at": kb.created_at.isoformat(), - "updated_at": kb.updated_at.isoformat(), - } - - return Response().ok(kb_dict, "创建知识库成功").__dict__ + return Response().ok(kb.model_dump(), "创建知识库成功").__dict__ except ValueError as e: return Response().error(str(e)).__dict__ @@ -195,24 +270,7 @@ class KnowledgeBaseRoute(Route): return Response().error("知识库不存在").__dict__ kb = kb_helper.kb - kb_dict = { - "kb_id": kb.kb_id, - "kb_name": kb.kb_name, - "description": kb.description, - "emoji": kb.emoji or "📚", - "embedding_provider_id": kb.embedding_provider_id, - "rerank_provider_id": kb.rerank_provider_id, - "doc_count": kb.doc_count, - "chunk_count": kb.chunk_count, - "chunk_size": kb.chunk_size, - "chunk_overlap": kb.chunk_overlap, - "top_k_dense": kb.top_k_dense, - "top_k_sparse": kb.top_k_sparse, - "created_at": kb.created_at.isoformat(), - "updated_at": kb.updated_at.isoformat(), - } - - return Response().ok(kb_dict).__dict__ + return Response().ok(kb.model_dump()).__dict__ except ValueError as e: return Response().error(str(e)).__dict__ @@ -293,25 +351,7 @@ class KnowledgeBaseRoute(Route): return Response().error("知识库不存在").__dict__ kb = kb_helper.kb - - kb_dict = { - "kb_id": kb.kb_id, - "kb_name": kb.kb_name, - "description": kb.description, - "emoji": kb.emoji or "📚", - "embedding_provider_id": kb.embedding_provider_id, - "rerank_provider_id": kb.rerank_provider_id, - "doc_count": kb.doc_count, - "chunk_count": kb.chunk_count, - "chunk_size": kb.chunk_size or 512, - "chunk_overlap": kb.chunk_overlap or 50, - "top_k_dense": kb.top_k_dense or 50, - "top_k_sparse": kb.top_k_sparse or 50, - "created_at": kb.created_at.isoformat(), - "updated_at": kb.updated_at.isoformat(), - } - - return Response().ok(kb_dict, "更新知识库成功").__dict__ + return Response().ok(kb.model_dump(), "更新知识库成功").__dict__ except ValueError as e: return Response().error(str(e)).__dict__ @@ -440,6 +480,9 @@ class KnowledgeBaseRoute(Route): - files: 文件数组 (必填) - file_name: 文件名 (必填) - file_content: base64 编码的文件内容 (必填) + + 返回: + - task_id: 任务ID,用于查询上传进度和结果 """ try: kb_manager = self._get_kb_manager() @@ -520,60 +563,41 @@ class KnowledgeBaseRoute(Route): if not kb_helper: return Response().error("知识库不存在").__dict__ - # 上传所有文档 - uploaded_docs = [] - failed_docs = [] + # 生成任务ID + task_id = str(uuid.uuid4()) - for file_info in files_to_upload: - try: - doc = await kb_helper.upload_document( - file_name=file_info["file_name"], - file_content=file_info["file_content"], - file_type=file_info["file_type"], - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - batch_size=batch_size, - tasks_limit=tasks_limit, - max_retries=max_retries, - ) - - doc_dict = { - "doc_id": doc.doc_id, - "kb_id": doc.kb_id, - "doc_name": doc.doc_name, - "file_type": doc.file_type, - "file_size": doc.file_size, - "chunk_count": doc.chunk_count, - "media_count": doc.media_count, - "created_at": doc.created_at.isoformat(), - "updated_at": doc.updated_at.isoformat(), - } - uploaded_docs.append(doc_dict) - except Exception as e: - logger.error(f"上传文档 {file_info['file_name']} 失败: {e}") - failed_docs.append( - {"file_name": file_info["file_name"], "error": str(e)} - ) - - # 返回结果 - result = { - "uploaded": uploaded_docs, - "failed": failed_docs, - "total": len(files_to_upload), - "success_count": len(uploaded_docs), - "failed_count": len(failed_docs), + # 初始化任务状态 + self.upload_tasks[task_id] = { + "status": "pending", + "result": None, + "error": None, } - if failed_docs: - message = ( - f"部分文档上传成功 ({len(uploaded_docs)}/{len(files_to_upload)})" - ) - else: - message = ( - f"所有文档上传成功 ({len(uploaded_docs)}/{len(files_to_upload)})" + # 启动后台任务 + asyncio.create_task( + self._background_upload_task( + task_id=task_id, + kb_helper=kb_helper, + files_to_upload=files_to_upload, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, ) + ) - return Response().ok(result, message).__dict__ + return ( + Response() + .ok( + { + "task_id": task_id, + "file_count": len(files_to_upload), + "message": "task created, processing in background", + } + ) + .__dict__ + ) except ValueError as e: return Response().error(str(e)).__dict__ @@ -582,6 +606,59 @@ class KnowledgeBaseRoute(Route): logger.error(traceback.format_exc()) return Response().error(f"上传文档失败: {str(e)}").__dict__ + async def get_upload_progress(self): + """获取上传进度和结果 + + Query 参数: + - task_id: 任务 ID (必填) + + 返回状态: + - pending: 任务待处理 + - processing: 任务处理中 + - completed: 任务完成 + - failed: 任务失败 + """ + try: + task_id = request.args.get("task_id") + if not task_id: + return Response().error("缺少参数 task_id").__dict__ + + # 检查任务是否存在 + if task_id not in self.upload_tasks: + return Response().error("找不到该任务").__dict__ + + task_info = self.upload_tasks[task_id] + status = task_info["status"] + + # 构建返回数据 + response_data = { + "task_id": task_id, + "status": status, + } + + # 如果任务正在处理,返回进度信息 + if status == "processing" and task_id in self.upload_progress: + response_data["progress"] = self.upload_progress[task_id] + + # 如果任务完成,返回结果 + if status == "completed": + response_data["result"] = task_info["result"] + # 清理已完成的任务 + # del self.upload_tasks[task_id] + # if task_id in self.upload_progress: + # del self.upload_progress[task_id] + + # 如果任务失败,返回错误信息 + if status == "failed": + response_data["error"] = task_info["error"] + + return Response().ok(response_data).__dict__ + + except Exception as e: + logger.error(f"获取上传进度失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"获取上传进度失败: {str(e)}").__dict__ + async def get_document(self): """获取文档详情 @@ -604,20 +681,7 @@ class KnowledgeBaseRoute(Route): if not doc: return Response().error("文档不存在").__dict__ - doc_dict = { - "doc_id": doc.doc_id, - "kb_id": doc.kb_id, - "doc_name": doc.doc_name, - "file_type": doc.file_type, - "file_size": doc.file_size, - "file_path": doc.file_path, - "chunk_count": doc.chunk_count, - "media_count": doc.media_count, - "created_at": doc.created_at.isoformat(), - "updated_at": doc.updated_at.isoformat(), - } - - return Response().ok(doc_dict).__dict__ + return Response().ok(doc.model_dump()).__dict__ except ValueError as e: return Response().error(str(e)).__dict__ diff --git a/dashboard/src/views/knowledge-base/components/DocumentsTab.vue b/dashboard/src/views/knowledge-base/components/DocumentsTab.vue index 4dd2ff622..34d0d54d4 100644 --- a/dashboard/src/views/knowledge-base/components/DocumentsTab.vue +++ b/dashboard/src/views/knowledge-base/components/DocumentsTab.vue @@ -14,10 +14,23 @@ @@ -72,7 +85,8 @@ 清空
-
+
{{ getFileIcon(file.name) }} @@ -95,8 +109,8 @@ + :hint="t('upload.chunkSizeHint')" persistent-hint type="number" variant="outlined" density="compact" + :placeholder="props.kb?.chunk_size?.toString() || '512'" /> {{ t('upload.batchSettings') }} - +
+ + - + {{ t('upload.cancel') }}