diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index d7db42c40..537a81f0b 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -48,6 +48,7 @@ class KnowledgeBaseRoute(Route): # 文档管理 "/kb/document/list": ("GET", self.list_documents), "/kb/document/upload": ("POST", self.upload_document), + "/kb/document/import": ("POST", self.import_documents), "/kb/document/upload/url": ("POST", self.upload_document_from_url), "/kb/document/upload/progress": ("GET", self.get_upload_progress), "/kb/document/get": ("GET", self.get_document), @@ -66,6 +67,65 @@ class KnowledgeBaseRoute(Route): def _get_kb_manager(self): return self.core_lifecycle.kb_manager + def _init_task(self, task_id: str, status: str = "pending") -> None: + self.upload_tasks[task_id] = { + "status": status, + "result": None, + "error": None, + } + + def _set_task_result( + self, task_id: str, status: str, result: any = None, error: str | None = None + ) -> None: + self.upload_tasks[task_id] = { + "status": status, + "result": result, + "error": error, + } + if task_id in self.upload_progress: + self.upload_progress[task_id]["status"] = status + + def _update_progress( + self, + task_id: str, + *, + status: str | None = None, + file_index: int | None = None, + file_name: str | None = None, + stage: str | None = None, + current: int | None = None, + total: int | None = None, + ) -> None: + if task_id not in self.upload_progress: + return + p = self.upload_progress[task_id] + if status is not None: + p["status"] = status + if file_index is not None: + p["file_index"] = file_index + if file_name is not None: + p["file_name"] = file_name + if stage is not None: + p["stage"] = stage + if current is not None: + p["current"] = current + if total is not None: + p["total"] = total + + def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str): + async def _callback(stage: str, current: int, 
total: int): + self._update_progress( + task_id, + status="processing", + file_index=file_idx, + file_name=file_name, + stage=stage, + current=current, + total=total, + ) + + return _callback + async def _background_upload_task( self, task_id: str, @@ -80,11 +140,7 @@ class KnowledgeBaseRoute(Route): """后台上传任务""" try: # 初始化任务状态 - self.upload_tasks[task_id] = { - "status": "processing", - "result": None, - "error": None, - } + self._init_task(task_id, status="processing") self.upload_progress[task_id] = { "status": "processing", "file_index": 0, @@ -100,30 +156,20 @@ class KnowledgeBaseRoute(Route): for file_idx, file_info in enumerate(files_to_upload): try: # 更新整体进度 - self.upload_progress[task_id].update( - { - "status": "processing", - "file_index": file_idx, - "file_name": file_info["file_name"], - "stage": "parsing", - "current": 0, - "total": 100, - }, + self._update_progress( + task_id, + status="processing", + file_index=file_idx, + file_name=file_info["file_name"], + stage="parsing", + current=0, + total=100, ) # 创建进度回调函数 - async def progress_callback(stage, current, total): - if task_id in self.upload_progress: - self.upload_progress[task_id].update( - { - "status": "processing", - "file_index": file_idx, - "file_name": file_info["file_name"], - "stage": stage, - "current": current, - "total": total, - }, - ) + progress_callback = self._make_progress_callback( + task_id, file_idx, file_info["file_name"] + ) doc = await kb_helper.upload_document( file_name=file_info["file_name"], @@ -154,23 +200,99 @@ class KnowledgeBaseRoute(Route): "failed_count": len(failed_docs), } - self.upload_tasks[task_id] = { - "status": "completed", - "result": result, - "error": None, - } - self.upload_progress[task_id]["status"] = "completed" + self._set_task_result(task_id, "completed", result=result) except Exception as e: logger.error(f"后台上传任务 {task_id} 失败: {e}") logger.error(traceback.format_exc()) - self.upload_tasks[task_id] = { - "status": "failed", - "result": None, - 
# Pre-chunked document import for ``KnowledgeBaseRoute``. The surrounding diff
# carries no indentation, so the methods are written at top level here.


async def _background_import_task(
    self,
    task_id: str,
    kb_helper,
    documents: list,
    batch_size: int,
    tasks_limit: int,
    max_retries: int,
):
    """Import pre-chunked documents one by one and record the task outcome.

    A failure of a single document is logged and collected under ``"failed"``;
    the task itself still finishes as ``"completed"``. Only an error outside
    the per-document handling marks the whole task as ``"failed"``.

    Args:
        task_id: Identifier used in ``upload_tasks`` / ``upload_progress``.
        kb_helper: Knowledge-base helper whose ``upload_document`` is awaited.
        documents: Dicts with ``file_name``, ``chunks`` and optional ``file_type``.
        batch_size: Forwarded to ``upload_document``.
        tasks_limit: Forwarded to ``upload_document``.
        max_retries: Forwarded to ``upload_document``.
    """
    try:
        self._init_task(task_id, status="processing")
        self.upload_progress[task_id] = {
            "status": "processing",
            "file_index": 0,
            "file_total": len(documents),
            "stage": "waiting",
            "current": 0,
            "total": 100,
        }

        uploaded_docs = []
        failed_docs = []

        for file_idx, doc_info in enumerate(documents):
            file_name = doc_info.get("file_name", f"imported_doc_{file_idx}")
            chunks = doc_info.get("chunks", [])

            try:
                self._update_progress(
                    task_id,
                    status="processing",
                    file_index=file_idx,
                    file_name=file_name,
                    stage="importing",
                    current=0,
                    total=100,
                )
                progress_callback = self._make_progress_callback(
                    task_id, file_idx, file_name
                )

                # Explicit file_type wins, then the file extension, then "txt".
                # A name ending in "." has an empty extension and also falls
                # back to "txt".
                extension = (
                    file_name.rsplit(".", 1)[-1].lower() if "." in file_name else ""
                )
                file_type = doc_info.get("file_type") or extension or "txt"

                doc = await kb_helper.upload_document(
                    file_name=file_name,
                    file_content=None,  # no raw content in pre-chunked mode
                    file_type=file_type,
                    batch_size=batch_size,
                    tasks_limit=tasks_limit,
                    max_retries=max_retries,
                    progress_callback=progress_callback,
                    pre_chunked_text=chunks,
                )
                uploaded_docs.append(doc.model_dump())
            except Exception as e:
                logger.error(f"导入文档 {file_name} 失败: {e}")
                failed_docs.append({"file_name": file_name, "error": str(e)})

        result = {
            "task_id": task_id,
            "uploaded": uploaded_docs,
            "failed": failed_docs,
            "total": len(documents),
            "success_count": len(uploaded_docs),
            "failed_count": len(failed_docs),
        }
        self._set_task_result(task_id, "completed", result=result)

    except Exception as e:
        logger.error(f"后台导入任务 {task_id} 失败: {e}")
        logger.error(traceback.format_exc())
        self._set_task_result(task_id, "failed", error=str(e))


def _validate_import_request(self, data: dict):
    """Validate the JSON body of an import request.

    Args:
        data: Decoded request body.

    Returns:
        Tuple ``(kb_id, documents, batch_size, tasks_limit, max_retries)``.

    Raises:
        ValueError: If the body is not an object, a required field is missing,
            a document is malformed, or a numeric option is not an integer in
            range.
    """
    # A missing or non-object body must be a validation error, not an
    # AttributeError on ``.get``.
    if not isinstance(data, dict):
        raise ValueError("请求体必须是 JSON 对象")

    kb_id = data.get("kb_id")
    if not kb_id:
        raise ValueError("缺少参数 kb_id")

    documents = data.get("documents")
    if not documents or not isinstance(documents, list):
        raise ValueError("缺少参数 documents 或格式错误")

    for doc in documents:
        if not isinstance(doc, dict) or "file_name" not in doc or "chunks" not in doc:
            raise ValueError("文档格式错误,必须包含 file_name 和 chunks")
        file_name = doc["file_name"]
        # The background task calls ``file_name.rsplit``, so it must be a string.
        if not isinstance(file_name, str) or not file_name.strip():
            raise ValueError("file_name 必须是非空字符串")
        chunks = doc["chunks"]
        if not isinstance(chunks, list):
            raise ValueError("chunks 必须是列表")
        # ``all([])`` is True, so an empty list has to be rejected explicitly.
        if not chunks or not all(
            isinstance(chunk, str) and chunk.strip() for chunk in chunks
        ):
            raise ValueError("chunks 必须是非空字符串列表")

    options = {}
    for name, default, minimum in (
        ("batch_size", 32, 1),
        ("tasks_limit", 3, 1),
        ("max_retries", 3, 0),
    ):
        value = data.get(name)
        if value is None:
            value = default
        # bool is a subclass of int; True/False are not meaningful here.
        if isinstance(value, bool) or not isinstance(value, int) or value < minimum:
            raise ValueError(f"{name} 必须是不小于 {minimum} 的整数")
        options[name] = value

    return (
        kb_id,
        documents,
        options["batch_size"],
        options["tasks_limit"],
        options["max_retries"],
    )


async def import_documents(self):
    """Import pre-chunked documents into a knowledge base.

    Body:
    - kb_id: knowledge base ID (required)
    - documents: list of documents (required)
        - file_name: file name (required)
        - chunks: list of chunks (required, list[str])
        - file_type: file type (optional; inferred from the name, else txt)
    - batch_size: batch size (optional, default 32)
    - tasks_limit: concurrent task limit (optional, default 3)
    - max_retries: maximum retries (optional, default 3)

    Returns:
        Response dict carrying ``task_id`` and ``doc_count`` once the
        background task is scheduled, or an error response.
    """
    try:
        kb_manager = self._get_kb_manager()
        data = await request.json

        kb_id, documents, batch_size, tasks_limit, max_retries = (
            self._validate_import_request(data)
        )

        kb_helper = await kb_manager.get_kb(kb_id)
        if not kb_helper:
            return Response().error("知识库不存在").__dict__

        task_id = str(uuid.uuid4())
        self._init_task(task_id, status="pending")

        task = asyncio.create_task(
            self._background_import_task(
                task_id=task_id,
                kb_helper=kb_helper,
                documents=documents,
                batch_size=batch_size,
                tasks_limit=tasks_limit,
                max_retries=max_retries,
            ),
        )
        # The event loop keeps only weak references to tasks; hold a strong one
        # until the task is done so it cannot be garbage-collected mid-run.
        running = getattr(self, "_import_tasks", None)
        if running is None:
            running = self._import_tasks = set()
        running.add(task)
        task.add_done_callback(running.discard)

        return (
            Response()
            .ok(
                {
                    "task_id": task_id,
                    "doc_count": len(documents),
                    "message": "import task created, processing in background",
                },
            )
            .__dict__
        )

    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"导入文档失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"导入文档失败: {e!s}").__dict__
"""Tests for the knowledge-base pre-chunked document import endpoint."""

import asyncio
from unittest.mock import AsyncMock, MagicMock

import pytest
import pytest_asyncio
from quart import Quart

from astrbot.core import LogBroker
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.knowledge_base.kb_helper import KBHelper
from astrbot.core.knowledge_base.models import KBDocument
from astrbot.dashboard.server import AstrBotDashboard

IMPORT_URL = "/api/kb/document/import"
PROGRESS_URL = "/api/kb/document/upload/progress"


@pytest_asyncio.fixture(scope="module")
async def core_lifecycle_td(tmp_path_factory):
    """Create a core lifecycle on a temporary database with a mocked kb_manager."""
    tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_kb.db"
    db = SQLiteDatabase(str(tmp_db_path))
    log_broker = LogBroker()
    core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
    await core_lifecycle.initialize()

    # ``get_kb`` is awaited by the route, so it has to be an AsyncMock; it
    # hands out the same helper mock for every knowledge-base id.
    kb_helper = AsyncMock(spec=KBHelper)
    kb_helper.upload_document.return_value = KBDocument(
        doc_id="test_doc_id",
        kb_id="test_kb_id",
        doc_name="test_file.txt",
        file_type="txt",
        file_size=100,
        file_path="",
        chunk_count=2,
        media_count=0,
    )
    kb_manager = MagicMock()
    kb_manager.get_kb = AsyncMock(return_value=kb_helper)
    core_lifecycle.kb_manager = kb_manager

    try:
        yield core_lifecycle
    finally:
        # Best-effort shutdown: a failing stop must not fail the test module.
        try:
            _stop_res = core_lifecycle.stop()
            if asyncio.iscoroutine(_stop_res):
                await _stop_res
        except Exception:
            pass


@pytest.fixture(scope="module")
def app(core_lifecycle_td: AstrBotCoreLifecycle):
    """Create the Quart app of a dashboard bound to the test lifecycle."""
    shutdown_event = asyncio.Event()
    server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
    return server.app


@pytest_asyncio.fixture(scope="module")
async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
    """Log in with the configured dashboard credentials and return the auth header."""
    test_client = app.test_client()
    dashboard_cfg = core_lifecycle_td.astrbot_config["dashboard"]
    response = await test_client.post(
        "/api/auth/login",
        json={
            "username": dashboard_cfg["username"],
            "password": dashboard_cfg["password"],
        },
    )
    data = await response.get_json()
    assert data["status"] == "ok"
    token = data["data"]["token"]
    return {"Authorization": f"Bearer {token}"}


async def _wait_for_completion(test_client, headers, task_id, attempts=10, delay=0.1):
    """Poll the progress endpoint until the task completes; return the last payload."""
    progress_data = None
    for _ in range(attempts):
        progress_response = await test_client.get(
            f"{PROGRESS_URL}?task_id={task_id}",
            headers=headers,
        )
        progress_data = await progress_response.get_json()
        if progress_data["data"]["status"] == "completed":
            break
        await asyncio.sleep(delay)
    return progress_data


@pytest.mark.asyncio
async def test_import_documents(
    app: Quart, authenticated_header: dict, core_lifecycle_td: AstrBotCoreLifecycle
):
    """Import two pre-chunked documents and check task result and helper calls."""
    test_client = app.test_client()

    # The helper mock is module-scoped; drop earlier call history so the
    # call-count assertion below only sees this test's requests.
    kb_helper = core_lifecycle_td.kb_manager.get_kb.return_value
    kb_helper.upload_document.reset_mock()

    import_data = {
        "kb_id": "test_kb_id",
        "documents": [
            {"file_name": "test_file_1.txt", "chunks": ["chunk1", "chunk2"]},
            {"file_name": "test_file_2.md", "chunks": ["chunk3", "chunk4", "chunk5"]},
        ],
    }

    response = await test_client.post(
        IMPORT_URL, json=import_data, headers=authenticated_header
    )

    assert response.status_code == 200
    data = await response.get_json()
    assert data["status"] == "ok"
    assert "task_id" in data["data"]
    assert data["data"]["doc_count"] == 2

    # upload_document is mocked, so the background task finishes quickly.
    progress_data = await _wait_for_completion(
        test_client, authenticated_header, data["data"]["task_id"]
    )
    assert progress_data["data"]["status"] == "completed", (
        f"import task did not complete: {progress_data!r}"
    )
    result = progress_data["data"]["result"]
    assert result["success_count"] == 2
    assert result["failed_count"] == 0

    assert kb_helper.upload_document.call_count == 2
    first_call, second_call = kb_helper.upload_document.call_args_list

    assert first_call.kwargs["file_name"] == "test_file_1.txt"
    assert first_call.kwargs["pre_chunked_text"] == ["chunk1", "chunk2"]

    assert second_call.kwargs["file_name"] == "test_file_2.md"
    assert second_call.kwargs["pre_chunked_text"] == ["chunk3", "chunk4", "chunk5"]


@pytest.mark.asyncio
async def test_import_documents_invalid_input(app: Quart, authenticated_header: dict):
    """Reject malformed import requests with the matching validation message."""
    test_client = app.test_client()

    invalid_cases = [
        # Missing kb_id
        ({"documents": []}, "缺少参数 kb_id"),
        # Missing documents
        ({"kb_id": "test_kb"}, "缺少参数 documents"),
        # Document without chunks
        ({"kb_id": "test_kb", "documents": [{"file_name": "test"}]}, "文档格式错误"),
        # chunks is not a list
        (
            {
                "kb_id": "test_kb",
                "documents": [{"file_name": "test", "chunks": "not-a-list"}],
            },
            "chunks 必须是列表",
        ),
        # chunks contains an empty string
        (
            {
                "kb_id": "test_kb",
                "documents": [{"file_name": "test", "chunks": ["valid", ""]}],
            },
            "chunks 必须是非空字符串列表",
        ),
    ]

    for payload, expected_message in invalid_cases:
        response = await test_client.post(
            IMPORT_URL, json=payload, headers=authenticated_header
        )
        data = await response.get_json()
        assert data["status"] == "error", payload
        assert expected_message in data["message"], payload