"""知识库管理 API 路由""" import asyncio import os import traceback import uuid from typing import Any import aiofiles from quart import request from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from ..utils import generate_tsne_visualization from .route import Response, Route, RouteContext class KnowledgeBaseRoute(Route): """知识库管理路由 提供知识库、文档、检索、会话配置等 API 接口 """ def __init__( self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle, ) -> None: super().__init__(context) self.core_lifecycle = core_lifecycle self.kb_manager = None # 延迟初始化 self.kb_db = None self.session_config_db = None # 会话配置数据库 self.retrieval_manager = None self.upload_progress = {} # 存储上传进度 {task_id: {status, file_index, file_total, stage, current, total}} self.upload_tasks = {} # 存储后台上传任务 {task_id: {"status", "result", "error"}} # 注册路由 self.routes = { # 知识库管理 "/kb/list": ("GET", self.list_kbs), "/kb/create": ("POST", self.create_kb), "/kb/get": ("GET", self.get_kb), "/kb/update": ("POST", self.update_kb), "/kb/delete": ("POST", self.delete_kb), "/kb/stats": ("GET", self.get_kb_stats), # 文档管理 "/kb/document/list": ("GET", self.list_documents), "/kb/document/upload": ("POST", self.upload_document), "/kb/document/import": ("POST", self.import_documents), "/kb/document/upload/url": ("POST", self.upload_document_from_url), "/kb/document/upload/progress": ("GET", self.get_upload_progress), "/kb/document/get": ("GET", self.get_document), "/kb/document/delete": ("POST", self.delete_document), # # 块管理 "/kb/chunk/list": ("GET", self.list_chunks), "/kb/chunk/delete": ("POST", self.delete_chunk), # # 多媒体管理 # "/kb/media/list": ("GET", self.list_media), # "/kb/media/delete": ("POST", self.delete_media), # 检索 "/kb/retrieve": ("POST", self.retrieve), } self.register_routes() def _get_kb_manager(self): return self.core_lifecycle.kb_manager def _init_task(self, task_id: str, status: str = "pending") -> None: self.upload_tasks[task_id] = { "status": status, "result": None, "error": None, } def _set_task_result( self, task_id: str, status: str, result: Any = None, error: str | None = None ) -> None: self.upload_tasks[task_id] = { "status": status, "result": result, "error": error, } if task_id in self.upload_progress: self.upload_progress[task_id]["status"] = status def _update_progress( self, task_id: str, *, status: str | None = None, file_index: int | None = None, file_name: str | None = None, stage: str | None = None, current: int | None = None, total: int | None = None, ) -> None: if task_id not in self.upload_progress: return p = self.upload_progress[task_id] if status is not None: p["status"] = status if file_index is not None: p["file_index"] = file_index if file_name is not None: p["file_name"] = file_name if stage is not None: p["stage"] = stage if current is not None: p["current"] = current if total is not None: p["total"] = total def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str): async def _callback(stage: str, current: int, total: int) -> None: self._update_progress( task_id, status="processing", file_index=file_idx, file_name=file_name, stage=stage, current=current, total=total, ) return _callback async def _background_upload_task( self, task_id: str, kb_helper, files_to_upload: list, chunk_size: int, chunk_overlap: int, batch_size: int, tasks_limit: int, max_retries: int, ) -> None: """后台上传任务""" try: # 初始化任务状态 self._init_task(task_id, status="processing") self.upload_progress[task_id] = { "status": "processing", "file_index": 0, "file_total": len(files_to_upload), "stage": "waiting", "current": 0, "total": 100, } uploaded_docs = [] failed_docs = [] for file_idx, file_info in enumerate(files_to_upload): try: # 更新整体进度 self._update_progress( task_id, status="processing", file_index=file_idx, file_name=file_info["file_name"], stage="parsing", current=0, total=100, ) # 创建进度回调函数 progress_callback = self._make_progress_callback( task_id, file_idx, file_info["file_name"] ) doc = await kb_helper.upload_document( file_name=file_info["file_name"], file_content=file_info["file_content"], file_type=file_info["file_type"], chunk_size=chunk_size, chunk_overlap=chunk_overlap, batch_size=batch_size, tasks_limit=tasks_limit, max_retries=max_retries, progress_callback=progress_callback, ) uploaded_docs.append(doc.model_dump()) except Exception as e: logger.error(f"上传文档 {file_info['file_name']} 失败: {e}") failed_docs.append( {"file_name": file_info["file_name"], "error": str(e)}, ) # 更新任务完成状态 result = { "task_id": task_id, "uploaded": uploaded_docs, "failed": failed_docs, "total": len(files_to_upload), "success_count": len(uploaded_docs), "failed_count": len(failed_docs), } self._set_task_result(task_id, "completed", result=result) except Exception as e: logger.error(f"后台上传任务 {task_id} 失败: {e}") logger.error(traceback.format_exc()) self._set_task_result(task_id, "failed", error=str(e)) async def _background_import_task( self, task_id: str, kb_helper, documents: list, batch_size: int, tasks_limit: int, max_retries: int, ) -> None: """后台导入预切片文档任务""" try: # 初始化任务状态 self._init_task(task_id, status="processing") self.upload_progress[task_id] = { "status": "processing", "file_index": 0, "file_total": len(documents), "stage": "waiting", "current": 0, "total": 100, } uploaded_docs = [] failed_docs = [] for file_idx, doc_info in enumerate(documents): file_name = doc_info.get("file_name", f"imported_doc_{file_idx}") chunks = doc_info.get("chunks", []) try: # 更新整体进度 self._update_progress( task_id, status="processing", file_index=file_idx, file_name=file_name, stage="importing", current=0, total=100, ) # 创建进度回调函数 progress_callback = self._make_progress_callback( task_id, file_idx, file_name ) # 调用 upload_document,传入 pre_chunked_text doc = await kb_helper.upload_document( file_name=file_name, file_content=None, # 预切片模式下不需要原始内容 file_type=doc_info.get("file_type") or ( file_name.rsplit(".", 1)[-1].lower() if "." in file_name else "txt" ), batch_size=batch_size, tasks_limit=tasks_limit, max_retries=max_retries, progress_callback=progress_callback, pre_chunked_text=chunks, ) uploaded_docs.append(doc.model_dump()) except Exception as e: logger.error(f"导入文档 {file_name} 失败: {e}") failed_docs.append( {"file_name": file_name, "error": str(e)}, ) # 更新任务完成状态 result = { "task_id": task_id, "uploaded": uploaded_docs, "failed": failed_docs, "total": len(documents), "success_count": len(uploaded_docs), "failed_count": len(failed_docs), } self._set_task_result(task_id, "completed", result=result) except Exception as e: logger.error(f"后台导入任务 {task_id} 失败: {e}") logger.error(traceback.format_exc()) self._set_task_result(task_id, "failed", error=str(e)) async def list_kbs(self): """获取知识库列表 Query 参数: - page: 页码 (默认 1) - page_size: 每页数量 (默认 20) - refresh_stats: 是否刷新统计信息 (默认 false,首次加载时可设为 true) """ try: kb_manager = self._get_kb_manager() page = request.args.get("page", 1, type=int) page_size = request.args.get("page_size", 20, type=int) kbs = await kb_manager.list_kbs() # 转换为字典列表 kb_list = [] for kb in kbs: kb_list.append(kb.model_dump()) return ( Response() .ok({"items": kb_list, "page": page, "page_size": page_size}) .__dict__ ) except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: logger.error(f"获取知识库列表失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"获取知识库列表失败: {e!s}").__dict__ async def create_kb(self): """创建知识库 Body: - kb_name: 知识库名称 (必填) - description: 描述 (可选) - emoji: 图标 (可选) - embedding_provider_id: 嵌入模型提供商ID (可选) - rerank_provider_id: 重排序模型提供商ID (可选) - chunk_size: 分块大小 (可选, 默认512) - chunk_overlap: 块重叠大小 (可选, 默认50) - top_k_dense: 密集检索数量 (可选, 默认50) - top_k_sparse: 稀疏检索数量 (可选, 默认50) - top_m_final: 最终返回数量 (可选, 默认5) """ try: kb_manager = self._get_kb_manager() data = await request.json kb_name = data.get("kb_name") if not kb_name: return Response().error("知识库名称不能为空").__dict__ description = data.get("description") emoji = data.get("emoji") embedding_provider_id = data.get("embedding_provider_id") rerank_provider_id = data.get("rerank_provider_id") chunk_size = data.get("chunk_size") chunk_overlap = data.get("chunk_overlap") top_k_dense = data.get("top_k_dense") top_k_sparse = data.get("top_k_sparse") top_m_final = data.get("top_m_final") # pre-check embedding dim if not embedding_provider_id: return Response().error("缺少参数 embedding_provider_id").__dict__ prv = await kb_manager.provider_manager.get_provider_by_id( embedding_provider_id, ) # type: ignore if not prv or not isinstance(prv, EmbeddingProvider): return ( Response().error(f"嵌入模型不存在或类型错误({type(prv)})").__dict__ ) try: vec = await prv.get_embedding("astrbot") if len(vec) != prv.get_dim(): raise ValueError( f"嵌入向量维度不匹配,实际是 {len(vec)},然而配置是 {prv.get_dim()}", ) except Exception as e: return Response().error(f"测试嵌入模型失败: {e!s}").__dict__ # pre-check rerank if rerank_provider_id: rerank_prv: RerankProvider = ( await kb_manager.provider_manager.get_provider_by_id( rerank_provider_id, ) ) # type: ignore if not rerank_prv: return Response().error("重排序模型不存在").__dict__ # 检查重排序模型可用性 try: res = await rerank_prv.rerank( query="astrbot", documents=["astrbot knowledge base"], ) if not res: raise ValueError("重排序模型返回结果异常") except Exception as e: return ( Response() .error(f"测试重排序模型失败: {e!s},请检查平台日志输出。") .__dict__ ) kb_helper = await kb_manager.create_kb( kb_name=kb_name, description=description, emoji=emoji, embedding_provider_id=embedding_provider_id, rerank_provider_id=rerank_provider_id, chunk_size=chunk_size, chunk_overlap=chunk_overlap, top_k_dense=top_k_dense, top_k_sparse=top_k_sparse, top_m_final=top_m_final, ) kb = kb_helper.kb return Response().ok(kb.model_dump(), "创建知识库成功").__dict__ except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: logger.error(f"创建知识库失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"创建知识库失败: {e!s}").__dict__ async def get_kb(self): """获取知识库详情 Query 参数: - kb_id: 知识库 ID (必填) """ try: kb_manager = self._get_kb_manager() kb_id = request.args.get("kb_id") if not kb_id: return Response().error("缺少参数 kb_id").__dict__ kb_helper = await kb_manager.get_kb(kb_id) if not kb_helper: return Response().error("知识库不存在").__dict__ kb = kb_helper.kb return Response().ok(kb.model_dump()).__dict__ except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: logger.error(f"获取知识库详情失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"获取知识库详情失败: {e!s}").__dict__ async def update_kb(self): """更新知识库 Body: - kb_id: 知识库 ID (必填) - kb_name: 新的知识库名称 (可选) - description: 新的描述 (可选) - emoji: 新的图标 (可选) - embedding_provider_id: 新的嵌入模型提供商ID (可选) - rerank_provider_id: 新的重排序模型提供商ID (可选) - chunk_size: 分块大小 (可选) - chunk_overlap: 块重叠大小 (可选) - top_k_dense: 密集检索数量 (可选) - top_k_sparse: 稀疏检索数量 (可选) - top_m_final: 最终返回数量 (可选) """ try: kb_manager = self._get_kb_manager() data = await request.json kb_id = data.get("kb_id") if not kb_id: return Response().error("缺少参数 kb_id").__dict__ kb_name = data.get("kb_name") description = data.get("description") emoji = data.get("emoji") embedding_provider_id = data.get("embedding_provider_id") rerank_provider_id = data.get("rerank_provider_id") chunk_size = data.get("chunk_size") chunk_overlap = data.get("chunk_overlap") top_k_dense = data.get("top_k_dense") top_k_sparse = data.get("top_k_sparse") top_m_final = data.get("top_m_final") # 检查是否至少提供了一个更新字段 if all( v is None for v in [ kb_name, description, emoji, embedding_provider_id, rerank_provider_id, chunk_size, chunk_overlap, top_k_dense, top_k_sparse, top_m_final, ] ): return Response().error("至少需要提供一个更新字段").__dict__ kb_helper = await kb_manager.update_kb( kb_id=kb_id, kb_name=kb_name, description=description, emoji=emoji, embedding_provider_id=embedding_provider_id, rerank_provider_id=rerank_provider_id, chunk_size=chunk_size, chunk_overlap=chunk_overlap, top_k_dense=top_k_dense, top_k_sparse=top_k_sparse, top_m_final=top_m_final, ) if not kb_helper: return Response().error("知识库不存在").__dict__ kb = kb_helper.kb return Response().ok(kb.model_dump(), "更新知识库成功").__dict__ except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: logger.error(f"更新知识库失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"更新知识库失败: {e!s}").__dict__ async def delete_kb(self): """删除知识库 Body: - kb_id: 知识库 ID (必填) """ try: kb_manager = self._get_kb_manager() data = await request.json kb_id = data.get("kb_id") if not kb_id: return Response().error("缺少参数 kb_id").__dict__ success = await kb_manager.delete_kb(kb_id) if not success: return Response().error("知识库不存在").__dict__ return Response().ok(message="删除知识库成功").__dict__ except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: logger.error(f"删除知识库失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"删除知识库失败: {e!s}").__dict__ async def get_kb_stats(self): """获取知识库统计信息 Query 参数: - kb_id: 知识库 ID (必填) """ try: kb_manager = self._get_kb_manager() kb_id = request.args.get("kb_id") if not kb_id: return Response().error("缺少参数 kb_id").__dict__ kb_helper = await kb_manager.get_kb(kb_id) if not kb_helper: return Response().error("知识库不存在").__dict__ kb = kb_helper.kb stats = { "kb_id": kb.kb_id, "kb_name": kb.kb_name, "doc_count": kb.doc_count, "chunk_count": kb.chunk_count, "created_at": kb.created_at.isoformat(), "updated_at": kb.updated_at.isoformat(), } return Response().ok(stats).__dict__ except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: logger.error(f"获取知识库统计失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"获取知识库统计失败: {e!s}").__dict__ # ===== 文档管理 API ===== async def list_documents(self): """获取文档列表 Query 参数: - kb_id: 知识库 ID (必填) - page: 页码 (默认 1) - page_size: 每页数量 (默认 20) """ try: kb_manager = self._get_kb_manager() kb_id = request.args.get("kb_id") if not kb_id: return Response().error("缺少参数 kb_id").__dict__ kb_helper = await kb_manager.get_kb(kb_id) if not kb_helper: return Response().error("知识库不存在").__dict__ page = request.args.get("page", 1, type=int) page_size = request.args.get("page_size", 100, type=int) offset = (page - 1) * page_size limit = page_size doc_list = await kb_helper.list_documents(offset=offset, limit=limit) doc_list = [doc.model_dump() for doc in doc_list] return ( Response() .ok({"items": doc_list, "page": page, "page_size": page_size}) .__dict__ ) except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: logger.error(f"获取文档列表失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"获取文档列表失败: {e!s}").__dict__ async def upload_document(self): """上传文档 支持两种方式: 1. multipart/form-data 文件上传(支持多文件,最多10个) 2. JSON 格式 base64 编码上传(支持多文件,最多10个) Form Data (multipart/form-data): - kb_id: 知识库 ID (必填) - file: 文件对象 (必填,可多个,字段名为 file, file1, file2, ... 或 files[]) JSON Body (application/json): - kb_id: 知识库 ID (必填) - files: 文件数组 (必填) - file_name: 文件名 (必填) - file_content: base64 编码的文件内容 (必填) 返回: - task_id: 任务ID,用于查询上传进度和结果 """ try: kb_manager = self._get_kb_manager() # 检查 Content-Type content_type = request.content_type kb_id = None chunk_size = None chunk_overlap = None batch_size = 32 tasks_limit = 3 max_retries = 3 files_to_upload = [] # 存储待上传的文件信息列表 if content_type and "multipart/form-data" not in content_type: return ( Response().error("Content-Type 须为 multipart/form-data").__dict__ ) form_data = await request.form files = await request.files kb_id = form_data.get("kb_id") chunk_size = int(form_data.get("chunk_size", 512)) chunk_overlap = int(form_data.get("chunk_overlap", 50)) batch_size = int(form_data.get("batch_size", 32)) tasks_limit = int(form_data.get("tasks_limit", 3)) max_retries = int(form_data.get("max_retries", 3)) if not kb_id: return Response().error("缺少参数 kb_id").__dict__ # 收集所有文件 file_list = [] # 支持 file, file1, file2, ... 或 files[] 格式 for key in files.keys(): if key == "file" or key.startswith("file") or key == "files[]": file_items = files.getlist(key) file_list.extend(file_items) if not file_list: return Response().error("缺少文件").__dict__ # 限制文件数量 if len(file_list) > 10: return Response().error("最多只能上传10个文件").__dict__ # 处理每个文件 for file in file_list: file_name = file.filename # 保存到临时文件 temp_file_path = os.path.join( get_astrbot_temp_path(), f"kb_upload_{uuid.uuid4()}_{file_name}", ) await file.save(temp_file_path) try: # 异步读取文件内容 async with aiofiles.open(temp_file_path, "rb") as f: file_content = await f.read() # 提取文件类型 file_type = ( file_name.rsplit(".", 1)[-1].lower() if "." in file_name else "" ) files_to_upload.append( { "file_name": file_name, "file_content": file_content, "file_type": file_type, }, ) finally: # 清理临时文件 if os.path.exists(temp_file_path): os.remove(temp_file_path) # 获取知识库 kb_helper = await kb_manager.get_kb(kb_id) if not kb_helper: return Response().error("知识库不存在").__dict__ # 生成任务ID task_id = str(uuid.uuid4()) # 初始化任务状态 self._init_task(task_id, status="pending") # 启动后台任务 asyncio.create_task( self._background_upload_task( task_id=task_id, kb_helper=kb_helper, files_to_upload=files_to_upload, chunk_size=chunk_size, chunk_overlap=chunk_overlap, batch_size=batch_size, tasks_limit=tasks_limit, max_retries=max_retries, ), ) return ( Response() .ok( { "task_id": task_id, "file_count": len(files_to_upload), "message": "task created, processing in background", }, ) .__dict__ ) except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: logger.error(f"上传文档失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"上传文档失败: {e!s}").__dict__ def _validate_import_request(self, data: dict): kb_id = data.get("kb_id") if not kb_id: raise ValueError("缺少参数 kb_id") documents = data.get("documents") if not documents or not isinstance(documents, list): raise ValueError("缺少参数 documents 或格式错误") for doc in documents: if "file_name" not in doc or "chunks" not in doc: raise ValueError("文档格式错误,必须包含 file_name 和 chunks") if not isinstance(doc["chunks"], list): raise ValueError("chunks 必须是列表") if not all( isinstance(chunk, str) and chunk.strip() for chunk in doc["chunks"] ): raise ValueError("chunks 必须是非空字符串列表") batch_size = data.get("batch_size", 32) tasks_limit = data.get("tasks_limit", 3) max_retries = data.get("max_retries", 3) return kb_id, documents, batch_size, tasks_limit, max_retries async def import_documents(self): """导入预切片文档 Body: - kb_id: 知识库 ID (必填) - documents: 文档列表 (必填) - file_name: 文件名 (必填) - chunks: 切片列表 (必填, list[str]) - file_type: 文件类型 (可选, 默认从文件名推断或为 txt) - batch_size: 批处理大小 (可选, 默认32) - tasks_limit: 并发任务限制 (可选, 默认3) - max_retries: 最大重试次数 (可选, 默认3) """ try: kb_manager = self._get_kb_manager() data = await request.json kb_id, documents, batch_size, tasks_limit, max_retries = ( self._validate_import_request(data) ) # 获取知识库 kb_helper = await kb_manager.get_kb(kb_id) if not kb_helper: return Response().error("知识库不存在").__dict__ # 生成任务ID task_id = str(uuid.uuid4()) # 初始化任务状态 self._init_task(task_id, status="pending") # 启动后台任务 asyncio.create_task( self._background_import_task( task_id=task_id, kb_helper=kb_helper, documents=documents, batch_size=batch_size, tasks_limit=tasks_limit, max_retries=max_retries, ), ) return ( Response() .ok( { "task_id": task_id, "doc_count": len(documents), "message": "import task created, processing in background", }, ) .__dict__ ) except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: logger.error(f"导入文档失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"导入文档失败: {e!s}").__dict__ async def get_upload_progress(self): """获取上传进度和结果 Query 参数: - task_id: 任务 ID (必填) 返回状态: - pending: 任务待处理 - processing: 任务处理中 - completed: 任务完成 - failed: 任务失败 """ try: task_id = request.args.get("task_id") if not task_id: return Response().error("缺少参数 task_id").__dict__ # 检查任务是否存在 if task_id not in self.upload_tasks: return Response().error("找不到该任务").__dict__ task_info = self.upload_tasks[task_id] status = task_info["status"] # 构建返回数据 response_data = { "task_id": task_id, "status": status, } # 如果任务正在处理,返回进度信息 if status == "processing" and task_id in self.upload_progress: response_data["progress"] = self.upload_progress[task_id] # 如果任务完成,返回结果 if status == "completed": response_data["result"] = task_info["result"] # 清理已完成的任务 # del self.upload_tasks[task_id] # if task_id in self.upload_progress: # del self.upload_progress[task_id] # 如果任务失败,返回错误信息 if status == "failed": response_data["error"] = task_info["error"] return Response().ok(response_data).__dict__ except Exception as e: logger.error(f"获取上传进度失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"获取上传进度失败: {e!s}").__dict__ async def get_document(self): """获取文档详情 Query 参数: - doc_id: 文档 ID (必填) """ try: kb_manager = self._get_kb_manager() kb_id = request.args.get("kb_id") if not kb_id: return Response().error("缺少参数 kb_id").__dict__ doc_id = request.args.get("doc_id") if not doc_id: return Response().error("缺少参数 doc_id").__dict__ kb_helper = await kb_manager.get_kb(kb_id) if not kb_helper: return Response().error("知识库不存在").__dict__ doc = await kb_helper.get_document(doc_id) if not doc: return Response().error("文档不存在").__dict__ return Response().ok(doc.model_dump()).__dict__ except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: logger.error(f"获取文档详情失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"获取文档详情失败: {e!s}").__dict__ async def delete_document(self): """删除文档 Body: - kb_id: 知识库 ID (必填) - doc_id: 文档 ID (必填) """ try: kb_manager = self._get_kb_manager() data = await request.json kb_id = data.get("kb_id") if not kb_id: return Response().error("缺少参数 kb_id").__dict__ doc_id = data.get("doc_id") if not doc_id: return Response().error("缺少参数 doc_id").__dict__ kb_helper = await kb_manager.get_kb(kb_id) if not kb_helper: return Response().error("知识库不存在").__dict__ await kb_helper.delete_document(doc_id) return Response().ok(message="删除文档成功").__dict__ except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: logger.error(f"删除文档失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"删除文档失败: {e!s}").__dict__ async def delete_chunk(self): """删除文本块 Body: - kb_id: 知识库 ID (必填) - chunk_id: 块 ID (必填) """ try: kb_manager = self._get_kb_manager() data = await request.json kb_id = data.get("kb_id") if not kb_id: return Response().error("缺少参数 kb_id").__dict__ chunk_id = data.get("chunk_id") if not chunk_id: return Response().error("缺少参数 chunk_id").__dict__ doc_id = data.get("doc_id") if not doc_id: return Response().error("缺少参数 doc_id").__dict__ kb_helper = await kb_manager.get_kb(kb_id) if not kb_helper: return Response().error("知识库不存在").__dict__ await kb_helper.delete_chunk(chunk_id, doc_id) return Response().ok(message="删除文本块成功").__dict__ except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: logger.error(f"删除文本块失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"删除文本块失败: {e!s}").__dict__ async def list_chunks(self): """获取块列表 Query 参数: - kb_id: 知识库 ID (必填) - page: 页码 (默认 1) - page_size: 每页数量 (默认 20) """ try: kb_manager = self._get_kb_manager() kb_id = request.args.get("kb_id") doc_id = request.args.get("doc_id") page = request.args.get("page", 1, type=int) page_size = request.args.get("page_size", 100, type=int) if not kb_id: return Response().error("缺少参数 kb_id").__dict__ if not doc_id: return Response().error("缺少参数 doc_id").__dict__ kb_helper = await kb_manager.get_kb(kb_id) offset = (page - 1) * page_size limit = page_size if not kb_helper: return Response().error("知识库不存在").__dict__ chunk_list = await kb_helper.get_chunks_by_doc_id( doc_id=doc_id, offset=offset, limit=limit, ) return ( Response() .ok( data={ "items": chunk_list, "page": page, "page_size": page_size, "total": await kb_helper.get_chunk_count_by_doc_id(doc_id), }, ) .__dict__ ) except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: logger.error(f"获取块列表失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"获取块列表失败: {e!s}").__dict__ # ===== 检索 API ===== async def retrieve(self): """检索知识库 Body: - query: 查询文本 (必填) - kb_ids: 知识库 ID 列表 (必填) - top_k: 返回结果数量 (可选, 默认 5) - debug: 是否启用调试模式,返回 t-SNE 可视化图片 (可选, 默认 False) """ try: kb_manager = self._get_kb_manager() data = await request.json query = data.get("query") kb_names = data.get("kb_names") debug = data.get("debug", False) if not query: return Response().error("缺少参数 query").__dict__ if not kb_names or not isinstance(kb_names, list): return Response().error("缺少参数 kb_names 或格式错误").__dict__ top_k = data.get("top_k", 5) results = await kb_manager.retrieve( query=query, kb_names=kb_names, top_m_final=top_k, ) result_list = [] if results: result_list = results["results"] response_data = { "results": result_list, "total": len(result_list), "query": query, } # Debug 模式:生成 t-SNE 可视化 if debug: try: img_base64 = await generate_tsne_visualization( query, kb_names, kb_manager, ) if img_base64: response_data["visualization"] = img_base64 except Exception as e: logger.error(f"生成 t-SNE 可视化失败: {e}") logger.error(traceback.format_exc()) response_data["visualization_error"] = str(e) return Response().ok(response_data).__dict__ except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: logger.error(f"检索失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"检索失败: {e!s}").__dict__ async def upload_document_from_url(self): """从 URL 上传文档 Body: - kb_id: 知识库 ID (必填) - url: 要提取内容的网页 URL (必填) - chunk_size: 分块大小 (可选, 默认512) - chunk_overlap: 块重叠大小 (可选, 默认50) - batch_size: 批处理大小 (可选, 默认32) - tasks_limit: 并发任务限制 (可选, 默认3) - max_retries: 最大重试次数 (可选, 默认3) 返回: - task_id: 任务ID,用于查询上传进度和结果 """ try: kb_manager = self._get_kb_manager() data = await request.json kb_id = data.get("kb_id") if not kb_id: return Response().error("缺少参数 kb_id").__dict__ url = data.get("url") if not url: return Response().error("缺少参数 url").__dict__ chunk_size = data.get("chunk_size", 512) chunk_overlap = data.get("chunk_overlap", 50) batch_size = data.get("batch_size", 32) tasks_limit = data.get("tasks_limit", 3) max_retries = data.get("max_retries", 3) enable_cleaning = data.get("enable_cleaning", False) cleaning_provider_id = data.get("cleaning_provider_id") # 获取知识库 kb_helper = await kb_manager.get_kb(kb_id) if not kb_helper: return Response().error("知识库不存在").__dict__ # 生成任务ID task_id = str(uuid.uuid4()) # 初始化任务状态 self._init_task(task_id, status="pending") # 启动后台任务 asyncio.create_task( self._background_upload_from_url_task( task_id=task_id, kb_helper=kb_helper, url=url, chunk_size=chunk_size, chunk_overlap=chunk_overlap, batch_size=batch_size, tasks_limit=tasks_limit, max_retries=max_retries, enable_cleaning=enable_cleaning, cleaning_provider_id=cleaning_provider_id, ), ) return ( Response() .ok( { "task_id": task_id, "url": url, "message": "URL upload task created, processing in background", }, ) .__dict__ ) except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: logger.error(f"从URL上传文档失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"从URL上传文档失败: {e!s}").__dict__ async def _background_upload_from_url_task( self, task_id: str, kb_helper, url: str, chunk_size: int, chunk_overlap: int, batch_size: int, tasks_limit: int, max_retries: int, enable_cleaning: bool, cleaning_provider_id: str | None, ) -> None: """后台上传URL任务""" try: # 初始化任务状态 self._init_task(task_id, status="processing") self.upload_progress[task_id] = { "status": "processing", "file_index": 0, "file_total": 1, "file_name": f"URL: {url}", "stage": "extracting", "current": 0, "total": 100, } # 创建进度回调函数 progress_callback = self._make_progress_callback(task_id, 0, f"URL: {url}") # 上传文档 doc = await kb_helper.upload_from_url( url=url, chunk_size=chunk_size, chunk_overlap=chunk_overlap, batch_size=batch_size, tasks_limit=tasks_limit, max_retries=max_retries, progress_callback=progress_callback, enable_cleaning=enable_cleaning, cleaning_provider_id=cleaning_provider_id, ) # 更新任务完成状态 result = { "task_id": task_id, "uploaded": [doc.model_dump()], "failed": [], "total": 1, "success_count": 1, "failed_count": 0, } self._set_task_result(task_id, "completed", result=result) except Exception as e: logger.error(f"后台上传URL任务 {task_id} 失败: {e}") logger.error(traceback.format_exc()) self._set_task_result(task_id, "failed", error=str(e))