Compare commits

...

11 Commits

Author SHA1 Message Date
Soulter 0916177a57 chore: bump version to 4.9.1 2025-12-15 16:07:10 +08:00
Soulter 02cd5e396b feat: add trigger probability setting for TTS and support to render slider in schema (#4047)
* feat: add trigger probability setting for TTS and support to render slider in schema

* chore: ruff format
2025-12-15 16:04:27 +08:00
Soulter 56673ad78f fix: prevent duplicate result content type after streaming finishes in RespondStage 2025-12-15 15:33:40 +08:00
Soulter 9a4d05e2b6 fix: remove unnecessary persistent attribute from ReadmeDialog and adjust dialog structure in ExtensionPage 2025-12-15 15:27:42 +08:00
Soulter c3f45449e8 docs: readme
wa ta shi wa ko sei no de su ka ra!
2025-12-15 11:47:21 +08:00
Copilot 65da469deb feat: add conversation export feature to JSONL for AI training (#4037)
* Initial plan

* Add conversation export functionality (backend and frontend)

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* Address code review feedback: move imports, simplify logic, improve i18n

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* Simplify frontend download logic: remove redundant Blob wrapper and complex filename parsing

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* fix: update conversation export filename format for consistency

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-12-14 21:44:12 +08:00
Soulter 16df64c405 fix: lark domain and log_level of Lark API client (#4038)
fixes: #4035
2025-12-14 21:31:17 +08:00
i0cLiceao 6b73b19e54 fix: support using GitHub Raw content as plugin source (#3975)
* Update plugin.py

* Update plugin.py

* Update plugin.py

* Update plugin.py
2025-12-14 18:23:29 +08:00
Soulter e7e97730af chore: bump version to 4.9.0 2025-12-13 18:49:07 +08:00
Soulter 467ca1eb5c fix: webui log output incompletely (#4029)
* fix: webui log output incompletely

* fix: improve SSE log parsing to handle partial data chunks

* fix: enhance log handling by implementing local cache and fetching history

* fix: log time handling to use epoch time
2025-12-13 18:46:16 +08:00
RC-CHN 46528391c2 feat: add pre-chunk import strategy for knowledge base (#3973)
* feat: 添加文档导入功能及相关测试

* feat: 优化文档上传功能,支持从文件名推断文件类型,并增强文档切片验证

* feat: 添加文档导入功能的无效输入测试,验证 chunks 类型和内容的错误处理

* refactor: 重构文档上传和导入任务的状态管理,添加任务初始化、结果设置和进度更新方法
2025-12-12 23:15:11 +08:00
27 changed files with 920 additions and 223 deletions
+6
View File
@@ -243,4 +243,10 @@ pre-commit install
</details> </details>
<div align="center">
_私は、高性能ですから!_ _私は、高性能ですから!_
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
</div
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.8.0" __version__ = "4.9.1"
+14 -1
View File
@@ -4,7 +4,7 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.8.0" VERSION = "4.9.1"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [ WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -108,6 +108,7 @@ DEFAULT_CONFIG = {
"provider_id": "", "provider_id": "",
"dual_output": False, "dual_output": False,
"use_file_service": False, "use_file_service": False,
"trigger_probability": 1.0,
}, },
"provider_ltm_settings": { "provider_ltm_settings": {
"group_icl_enable": False, "group_icl_enable": False,
@@ -2209,6 +2210,9 @@ CONFIG_METADATA_2 = {
"use_file_service": { "use_file_service": {
"type": "bool", "type": "bool",
}, },
"trigger_probability": {
"type": "float",
},
}, },
}, },
"provider_ltm_settings": { "provider_ltm_settings": {
@@ -2419,6 +2423,14 @@ CONFIG_METADATA_3 = {
"provider_tts_settings.enable": True, "provider_tts_settings.enable": True,
}, },
}, },
"provider_tts_settings.trigger_probability": {
"description": "TTS 触发概率",
"type": "float",
"slider": {"min": 0, "max": 1, "step": 0.05},
"condition": {
"provider_tts_settings.enable": True,
},
},
"provider_settings.image_caption_prompt": { "provider_settings.image_caption_prompt": {
"description": "图片转述提示词", "description": "图片转述提示词",
"type": "text", "type": "text",
@@ -2986,6 +2998,7 @@ CONFIG_METADATA_3 = {
"description": "回复概率", "description": "回复概率",
"type": "float", "type": "float",
"hint": "0.0-1.0 之间的数值", "hint": "0.0-1.0 之间的数值",
"slider": {"min": 0, "max": 1, "step": 0.05},
"condition": { "condition": {
"provider_ltm_settings.active_reply.enable": True, "provider_ltm_settings.active_reply.enable": True,
}, },
+1
View File
@@ -79,6 +79,7 @@ class ConfigMetadataI18n:
"_special", "_special",
"invisible", "invisible",
"options", "options",
"slider",
]: ]:
if attr in field_data: if attr in field_data:
field_result[attr] = field_data[attr] field_result[attr] = field_data[attr]
+2 -1
View File
@@ -24,6 +24,7 @@ import asyncio
import logging import logging
import os import os
import sys import sys
import time
from asyncio import Queue from asyncio import Queue
from collections import deque from collections import deque
@@ -148,7 +149,7 @@ class LogQueueHandler(logging.Handler):
self.log_broker.publish( self.log_broker.publish(
{ {
"level": record.levelname, "level": record.levelname,
"time": record.asctime, "time": time.time(),
"data": log_entry, "data": log_entry,
}, },
) )
+4
View File
@@ -158,7 +158,11 @@ class RespondStage(Stage):
result = event.get_result() result = event.get_result()
if result is None: if result is None:
return return
if event.get_extra("_streaming_finished", False):
# prevent some plugin make result content type to LLM_RESULT after streaming finished, lead to send again
return
if result.result_content_type == ResultContentType.STREAMING_FINISH: if result.result_content_type == ResultContentType.STREAMING_FINISH:
event.set_extra("_streaming_finished", True)
return return
logger.info( logger.info(
+21 -1
View File
@@ -1,3 +1,4 @@
import random
import re import re
import time import time
import traceback import traceback
@@ -42,6 +43,18 @@ class ResultDecorateStage(Stage):
"forward_threshold" "forward_threshold"
] ]
trigger_probability = ctx.astrbot_config["provider_tts_settings"].get(
"trigger_probability",
1,
)
try:
self.tts_trigger_probability = max(
0.0,
min(float(trigger_probability), 1.0),
)
except (TypeError, ValueError):
self.tts_trigger_probability = 1.0
# 分段回复 # 分段回复
self.words_count_threshold = int( self.words_count_threshold = int(
ctx.astrbot_config["platform_settings"]["segmented_reply"][ ctx.astrbot_config["platform_settings"]["segmented_reply"][
@@ -246,7 +259,14 @@ class ResultDecorateStage(Stage):
and result.is_llm_result() and result.is_llm_result()
and SessionServiceManager.should_process_tts_request(event) and SessionServiceManager.should_process_tts_request(event)
): ):
if not tts_provider: should_tts = self.tts_trigger_probability >= 1.0 or (
self.tts_trigger_probability > 0.0
and random.random() <= self.tts_trigger_probability
)
if not should_tts:
logger.debug("跳过 TTS:触发概率未命中。")
elif not tts_provider:
logger.warning( logger.warning(
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。", f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
) )
@@ -81,7 +81,12 @@ class LarkPlatformAdapter(Platform):
) )
self.lark_api = ( self.lark_api = (
lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build() lark.Client.builder()
.app_id(self.appid)
.app_secret(self.appsecret)
.log_level(lark.LogLevel.ERROR)
.domain(self.domain)
.build()
) )
self.webhook_server = None self.webhook_server = None
+91 -1
View File
@@ -1,7 +1,9 @@
import json import json
import traceback import traceback
from datetime import datetime
from io import BytesIO
from quart import request from quart import request, send_file
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
@@ -30,6 +32,7 @@ class ConversationRoute(Route):
"POST", "POST",
self.update_history, self.update_history,
), ),
"/conversation/export": ("POST", self.export_conversations),
} }
self.db_helper = db_helper self.db_helper = db_helper
self.conv_mgr = core_lifecycle.conversation_manager self.conv_mgr = core_lifecycle.conversation_manager
@@ -283,3 +286,90 @@ class ConversationRoute(Route):
except Exception as e: except Exception as e:
logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}") logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"更新对话历史失败: {e!s}").__dict__ return Response().error(f"更新对话历史失败: {e!s}").__dict__
async def export_conversations(self):
"""批量导出对话为 JSONL 格式"""
try:
data = await request.get_json()
conversations_to_export = data.get("conversations", [])
if not conversations_to_export:
return Response().error("导出列表不能为空").__dict__
# 收集所有对话的内容
jsonl_lines = []
exported_count = 0
failed_items = []
for conv_info in conversations_to_export:
user_id = conv_info.get("user_id")
cid = conv_info.get("cid")
if not user_id or not cid:
failed_items.append(
f"user_id:{user_id}, cid:{cid} - 缺少必要参数",
)
continue
try:
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
if not conversation:
failed_items.append(
f"user_id:{user_id}, cid:{cid} - 对话不存在"
)
continue
# 解析对话内容 (history is always a JSON string from _convert_conv_from_v2_to_v1)
content = json.loads(conversation.history)
# 创建导出记录
export_record = {
"cid": cid,
"user_id": user_id,
"platform_id": conversation.platform_id,
"title": conversation.title,
"persona_id": conversation.persona_id,
"created_at": conversation.created_at,
"updated_at": conversation.updated_at,
"content": content,
}
# 将记录转换为 JSON 字符串并添加到 JSONL
jsonl_lines.append(json.dumps(export_record, ensure_ascii=False))
exported_count += 1
except Exception as e:
failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}")
logger.error(
f"导出对话失败: user_id={user_id}, cid={cid}, error={e!s}"
)
if exported_count == 0:
return Response().error("没有成功导出任何对话").__dict__
# 创建 JSONL 内容
jsonl_content = "\n".join(jsonl_lines)
# 创建一个内存文件对象
file_obj = BytesIO(jsonl_content.encode("utf-8"))
file_obj.seek(0)
# 生成文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"astrbot_conversations_export_{timestamp}.jsonl"
# 返回文件流
return await send_file(
file_obj,
mimetype="application/jsonl",
as_attachment=True,
attachment_filename=filename,
)
except Exception as e:
logger.error(f"批量导出对话失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"批量导出对话失败: {e!s}").__dict__
+252 -77
View File
@@ -48,6 +48,7 @@ class KnowledgeBaseRoute(Route):
# 文档管理 # 文档管理
"/kb/document/list": ("GET", self.list_documents), "/kb/document/list": ("GET", self.list_documents),
"/kb/document/upload": ("POST", self.upload_document), "/kb/document/upload": ("POST", self.upload_document),
"/kb/document/import": ("POST", self.import_documents),
"/kb/document/upload/url": ("POST", self.upload_document_from_url), "/kb/document/upload/url": ("POST", self.upload_document_from_url),
"/kb/document/upload/progress": ("GET", self.get_upload_progress), "/kb/document/upload/progress": ("GET", self.get_upload_progress),
"/kb/document/get": ("GET", self.get_document), "/kb/document/get": ("GET", self.get_document),
@@ -66,6 +67,65 @@ class KnowledgeBaseRoute(Route):
def _get_kb_manager(self): def _get_kb_manager(self):
return self.core_lifecycle.kb_manager return self.core_lifecycle.kb_manager
def _init_task(self, task_id: str, status: str = "pending") -> None:
self.upload_tasks[task_id] = {
"status": status,
"result": None,
"error": None,
}
def _set_task_result(
self, task_id: str, status: str, result: any = None, error: str | None = None
) -> None:
self.upload_tasks[task_id] = {
"status": status,
"result": result,
"error": error,
}
if task_id in self.upload_progress:
self.upload_progress[task_id]["status"] = status
def _update_progress(
self,
task_id: str,
*,
status: str | None = None,
file_index: int | None = None,
file_name: str | None = None,
stage: str | None = None,
current: int | None = None,
total: int | None = None,
) -> None:
if task_id not in self.upload_progress:
return
p = self.upload_progress[task_id]
if status is not None:
p["status"] = status
if file_index is not None:
p["file_index"] = file_index
if file_name is not None:
p["file_name"] = file_name
if stage is not None:
p["stage"] = stage
if current is not None:
p["current"] = current
if total is not None:
p["total"] = total
def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str):
async def _callback(stage: str, current: int, total: int):
self._update_progress(
task_id,
status="processing",
file_index=file_idx,
file_name=file_name,
stage=stage,
current=current,
total=total,
)
return _callback
async def _background_upload_task( async def _background_upload_task(
self, self,
task_id: str, task_id: str,
@@ -80,11 +140,7 @@ class KnowledgeBaseRoute(Route):
"""后台上传任务""" """后台上传任务"""
try: try:
# 初始化任务状态 # 初始化任务状态
self.upload_tasks[task_id] = { self._init_task(task_id, status="processing")
"status": "processing",
"result": None,
"error": None,
}
self.upload_progress[task_id] = { self.upload_progress[task_id] = {
"status": "processing", "status": "processing",
"file_index": 0, "file_index": 0,
@@ -100,29 +156,19 @@ class KnowledgeBaseRoute(Route):
for file_idx, file_info in enumerate(files_to_upload): for file_idx, file_info in enumerate(files_to_upload):
try: try:
# 更新整体进度 # 更新整体进度
self.upload_progress[task_id].update( self._update_progress(
{ task_id,
"status": "processing", status="processing",
"file_index": file_idx, file_index=file_idx,
"file_name": file_info["file_name"], file_name=file_info["file_name"],
"stage": "parsing", stage="parsing",
"current": 0, current=0,
"total": 100, total=100,
},
) )
# 创建进度回调函数 # 创建进度回调函数
async def progress_callback(stage, current, total): progress_callback = self._make_progress_callback(
if task_id in self.upload_progress: task_id, file_idx, file_info["file_name"]
self.upload_progress[task_id].update(
{
"status": "processing",
"file_index": file_idx,
"file_name": file_info["file_name"],
"stage": stage,
"current": current,
"total": total,
},
) )
doc = await kb_helper.upload_document( doc = await kb_helper.upload_document(
@@ -154,23 +200,99 @@ class KnowledgeBaseRoute(Route):
"failed_count": len(failed_docs), "failed_count": len(failed_docs),
} }
self.upload_tasks[task_id] = { self._set_task_result(task_id, "completed", result=result)
"status": "completed",
"result": result,
"error": None,
}
self.upload_progress[task_id]["status"] = "completed"
except Exception as e: except Exception as e:
logger.error(f"后台上传任务 {task_id} 失败: {e}") logger.error(f"后台上传任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
self.upload_tasks[task_id] = { self._set_task_result(task_id, "failed", error=str(e))
"status": "failed",
"result": None, async def _background_import_task(
"error": str(e), self,
task_id: str,
kb_helper,
documents: list,
batch_size: int,
tasks_limit: int,
max_retries: int,
):
"""后台导入预切片文档任务"""
try:
# 初始化任务状态
self._init_task(task_id, status="processing")
self.upload_progress[task_id] = {
"status": "processing",
"file_index": 0,
"file_total": len(documents),
"stage": "waiting",
"current": 0,
"total": 100,
} }
if task_id in self.upload_progress:
self.upload_progress[task_id]["status"] = "failed" uploaded_docs = []
failed_docs = []
for file_idx, doc_info in enumerate(documents):
file_name = doc_info.get("file_name", f"imported_doc_{file_idx}")
chunks = doc_info.get("chunks", [])
try:
# 更新整体进度
self._update_progress(
task_id,
status="processing",
file_index=file_idx,
file_name=file_name,
stage="importing",
current=0,
total=100,
)
# 创建进度回调函数
progress_callback = self._make_progress_callback(
task_id, file_idx, file_name
)
# 调用 upload_document,传入 pre_chunked_text
doc = await kb_helper.upload_document(
file_name=file_name,
file_content=None, # 预切片模式下不需要原始内容
file_type=doc_info.get("file_type")
or (
file_name.rsplit(".", 1)[-1].lower()
if "." in file_name
else "txt"
),
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
pre_chunked_text=chunks,
)
uploaded_docs.append(doc.model_dump())
except Exception as e:
logger.error(f"导入文档 {file_name} 失败: {e}")
failed_docs.append(
{"file_name": file_name, "error": str(e)},
)
# 更新任务完成状态
result = {
"task_id": task_id,
"uploaded": uploaded_docs,
"failed": failed_docs,
"total": len(documents),
"success_count": len(uploaded_docs),
"failed_count": len(failed_docs),
}
self._set_task_result(task_id, "completed", result=result)
except Exception as e:
logger.error(f"后台导入任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(e))
async def list_kbs(self): async def list_kbs(self):
"""获取知识库列表 """获取知识库列表
@@ -614,11 +736,7 @@ class KnowledgeBaseRoute(Route):
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
# 初始化任务状态 # 初始化任务状态
self.upload_tasks[task_id] = { self._init_task(task_id, status="pending")
"status": "pending",
"result": None,
"error": None,
}
# 启动后台任务 # 启动后台任务
asyncio.create_task( asyncio.create_task(
@@ -653,6 +771,93 @@ class KnowledgeBaseRoute(Route):
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return Response().error(f"上传文档失败: {e!s}").__dict__ return Response().error(f"上传文档失败: {e!s}").__dict__
def _validate_import_request(self, data: dict):
kb_id = data.get("kb_id")
if not kb_id:
raise ValueError("缺少参数 kb_id")
documents = data.get("documents")
if not documents or not isinstance(documents, list):
raise ValueError("缺少参数 documents 或格式错误")
for doc in documents:
if "file_name" not in doc or "chunks" not in doc:
raise ValueError("文档格式错误,必须包含 file_name 和 chunks")
if not isinstance(doc["chunks"], list):
raise ValueError("chunks 必须是列表")
if not all(
isinstance(chunk, str) and chunk.strip() for chunk in doc["chunks"]
):
raise ValueError("chunks 必须是非空字符串列表")
batch_size = data.get("batch_size", 32)
tasks_limit = data.get("tasks_limit", 3)
max_retries = data.get("max_retries", 3)
return kb_id, documents, batch_size, tasks_limit, max_retries
async def import_documents(self):
"""导入预切片文档
Body:
- kb_id: 知识库 ID (必填)
- documents: 文档列表 (必填)
- file_name: 文件名 (必填)
- chunks: 切片列表 (必填, list[str])
- file_type: 文件类型 (可选, 默认从文件名推断或为 txt)
- batch_size: 批处理大小 (可选, 默认32)
- tasks_limit: 并发任务限制 (可选, 默认3)
- max_retries: 最大重试次数 (可选, 默认3)
"""
try:
kb_manager = self._get_kb_manager()
data = await request.json
kb_id, documents, batch_size, tasks_limit, max_retries = (
self._validate_import_request(data)
)
# 获取知识库
kb_helper = await kb_manager.get_kb(kb_id)
if not kb_helper:
return Response().error("知识库不存在").__dict__
# 生成任务ID
task_id = str(uuid.uuid4())
# 初始化任务状态
self._init_task(task_id, status="pending")
# 启动后台任务
asyncio.create_task(
self._background_import_task(
task_id=task_id,
kb_helper=kb_helper,
documents=documents,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
),
)
return (
Response()
.ok(
{
"task_id": task_id,
"doc_count": len(documents),
"message": "import task created, processing in background",
},
)
.__dict__
)
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"导入文档失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"导入文档失败: {e!s}").__dict__
async def get_upload_progress(self): async def get_upload_progress(self):
"""获取上传进度和结果 """获取上传进度和结果
@@ -960,11 +1165,7 @@ class KnowledgeBaseRoute(Route):
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
# 初始化任务状态 # 初始化任务状态
self.upload_tasks[task_id] = { self._init_task(task_id, status="pending")
"status": "pending",
"result": None,
"error": None,
}
# 启动后台任务 # 启动后台任务
asyncio.create_task( asyncio.create_task(
@@ -1017,11 +1218,7 @@ class KnowledgeBaseRoute(Route):
"""后台上传URL任务""" """后台上传URL任务"""
try: try:
# 初始化任务状态 # 初始化任务状态
self.upload_tasks[task_id] = { self._init_task(task_id, status="processing")
"status": "processing",
"result": None,
"error": None,
}
self.upload_progress[task_id] = { self.upload_progress[task_id] = {
"status": "processing", "status": "processing",
"file_index": 0, "file_index": 0,
@@ -1033,18 +1230,7 @@ class KnowledgeBaseRoute(Route):
} }
# 创建进度回调函数 # 创建进度回调函数
async def progress_callback(stage, current, total): progress_callback = self._make_progress_callback(task_id, 0, f"URL: {url}")
if task_id in self.upload_progress:
self.upload_progress[task_id].update(
{
"status": "processing",
"file_index": 0,
"file_name": f"URL: {url}",
"stage": stage,
"current": current,
"total": total,
},
)
# 上传文档 # 上传文档
doc = await kb_helper.upload_from_url( doc = await kb_helper.upload_from_url(
@@ -1069,20 +1255,9 @@ class KnowledgeBaseRoute(Route):
"failed_count": 0, "failed_count": 0,
} }
self.upload_tasks[task_id] = { self._set_task_result(task_id, "completed", result=result)
"status": "completed",
"result": result,
"error": None,
}
self.upload_progress[task_id]["status"] = "completed"
except Exception as e: except Exception as e:
logger.error(f"后台上传URL任务 {task_id} 失败: {e}") logger.error(f"后台上传URL任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
self.upload_tasks[task_id] = { self._set_task_result(task_id, "failed", error=str(e))
"status": "failed",
"result": None,
"error": str(e),
}
if task_id in self.upload_progress:
self.upload_progress[task_id]["status"] = "failed"
+4
View File
@@ -124,7 +124,11 @@ class PluginRoute(Route):
session.get(url) as response, session.get(url) as response,
): ):
if response.status == 200: if response.status == 200:
try:
remote_data = await response.json() remote_data = await response.json()
except aiohttp.ContentTypeError:
remote_text = await response.text()
remote_data = json.loads(remote_text)
# 检查远程数据是否为空 # 检查远程数据是否为空
if not remote_data or ( if not remote_data or (
+19
View File
@@ -0,0 +1,19 @@
## What's Changed
### 新增
- 支持自定义插件源。
- 支持飞书(Lark)的 Webhook 模式(将事件推送至开发者服务器)。
- 支持 “禁用自带指令” 快捷配置项,启用后将禁用所有 AstrBot 自带指令。入口: WebUI -> 配置文件 -> 平台配置。
### 优化
- 从 WebUI 移除了开发版本渠道。
- 当试图测试"Agent Runner"时,提示前往配置文件页测试。
- WebUI 列表项支持批量粘贴、回车创建项目。
### 修复
- Gemini API 部分调用失败的问题。
- WebUI 插件安装加载 Dialog 关闭按钮在手机端下显示异常的问题。
- 部分情况下,WebUI 日志显示不全的问题。
+16
View File
@@ -0,0 +1,16 @@
## What's Changed
### 修复
- 企业自部署飞书(自定义 domain)可以接收消息但无法发送消息的问题。
- 安装插件 Dialog 的深色样式问题。
### 优化
- 避免某些插件在流式响应结束后重复发送消息的问题。
### 新增
- 支持在对话管理批量导出对话轨迹数据为 `jsonl` 格式文件。入口:WebUI -> 对话管理 -> 批量选中 -> 导出。
- 支持对 TTS(文本转语音)设置概率触发。
- (插件开发)支持在 schema 中对 float 和 int 类型设置 `slider` 滑块控件。例如 `slider: {min: 0, max: 1, step: 0.1}`
Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

@@ -304,16 +304,32 @@ function hasVisibleItemsAfter(items, currentIndex) {
hide-details hide-details
></v-text-field> ></v-text-field>
<!-- Numeric input --> <!-- Numeric input with optional slider -->
<v-text-field <div
v-else-if="(metadata[metadataKey].items[key]?.type === 'int' || metadata[metadataKey].items[key]?.type === 'float') && !metadata[metadataKey]?.invisible" v-else-if="(metadata[metadataKey].items[key]?.type === 'int' || metadata[metadataKey].items[key]?.type === 'float') && !metadata[metadataKey]?.invisible"
v-model="iterable[key]" class="d-flex align-center gap-3"
>
<v-slider
v-if="metadata[metadataKey].items[key]?.slider"
v-model.number="iterable[key]"
:min="metadata[metadataKey].items[key]?.slider?.min ?? 0"
:max="metadata[metadataKey].items[key]?.slider?.max ?? 100"
:step="metadata[metadataKey].items[key]?.slider?.step ?? 1"
color="primary"
density="compact"
hide-details
class="flex-grow-1"
></v-slider>
<v-text-field
v-model.number="iterable[key]"
density="compact" density="compact"
variant="outlined" variant="outlined"
class="config-field" class="config-field"
type="number" type="number"
hide-details hide-details
style="max-width: 140px;"
></v-text-field> ></v-text-field>
</div>
<!-- Text area --> <!-- Text area -->
<v-textarea <v-textarea
@@ -413,16 +429,32 @@ function hasVisibleItemsAfter(items, currentIndex) {
hide-details hide-details
></v-text-field> ></v-text-field>
<!-- Numeric input --> <!-- Numeric input with optional slider -->
<v-text-field <div
v-else-if="(metadata[metadataKey]?.type === 'int' || metadata[metadataKey]?.type === 'float') && !metadata[metadataKey]?.invisible" v-else-if="(metadata[metadataKey]?.type === 'int' || metadata[metadataKey]?.type === 'float') && !metadata[metadataKey]?.invisible"
v-model="iterable[metadataKey]" class="d-flex align-center gap-3"
>
<v-slider
v-if="metadata[metadataKey]?.slider"
v-model.number="iterable[metadataKey]"
:min="metadata[metadataKey]?.slider?.min ?? 0"
:max="metadata[metadataKey]?.slider?.max ?? 100"
:step="metadata[metadataKey]?.slider?.step ?? 1"
color="primary"
density="compact"
hide-details
class="flex-grow-1"
></v-slider>
<v-text-field
v-model.number="iterable[metadataKey]"
density="compact" density="compact"
variant="outlined" variant="outlined"
class="config-field" class="config-field"
type="number" type="number"
hide-details hide-details
style="max-width: 140px;"
></v-text-field> ></v-text-field>
</div>
<!-- Text area --> <!-- Text area -->
<v-textarea <v-textarea
@@ -245,10 +245,29 @@ function getSpecialSubtype(value) {
<v-text-field v-else-if="itemMeta?.type === 'string'" v-model="createSelectorModel(itemKey).value" <v-text-field v-else-if="itemMeta?.type === 'string'" v-model="createSelectorModel(itemKey).value"
density="compact" variant="outlined" class="config-field" hide-details></v-text-field> density="compact" variant="outlined" class="config-field" hide-details></v-text-field>
<!-- Numeric input for JSON selector --> <!-- Numeric input with optional slider for JSON selector -->
<v-text-field v-else-if="itemMeta?.type === 'int' || itemMeta?.type === 'float'" <div v-else-if="itemMeta?.type === 'int' || itemMeta?.type === 'float'" class="d-flex align-center gap-3">
v-model="createSelectorModel(itemKey).value" density="compact" variant="outlined" class="config-field" <v-slider
type="number" hide-details></v-text-field> v-if="itemMeta?.slider"
v-model.number="createSelectorModel(itemKey).value"
:min="itemMeta?.slider?.min ?? 0"
:max="itemMeta?.slider?.max ?? 100"
:step="itemMeta?.slider?.step ?? 1"
color="primary"
density="compact"
hide-details
style="flex: 3"
></v-slider>
<v-text-field
v-model.number="createSelectorModel(itemKey).value"
density="compact"
variant="outlined"
class="config-field"
style="flex: 2"
type="number"
hide-details
></v-text-field>
</div>
<!-- Text area for JSON selector --> <!-- Text area for JSON selector -->
<v-textarea v-else-if="itemMeta?.type === 'text'" v-model="createSelectorModel(itemKey).value" <v-textarea v-else-if="itemMeta?.type === 'text'" v-model="createSelectorModel(itemKey).value"
@@ -1,6 +1,7 @@
<script setup> <script setup>
import { useCommonStore } from '@/stores/common'; import { useCommonStore } from '@/stores/common';
import { storeToRefs } from 'pinia'; import { storeToRefs } from 'pinia';
import axios from 'axios';
</script> </script>
<template> <template>
@@ -24,8 +25,6 @@ import { storeToRefs } from 'pinia';
export default { export default {
name: 'ConsoleDisplayer', name: 'ConsoleDisplayer',
data() { data() {
const commonStore = useCommonStore();
const { log_cache } = storeToRefs(commonStore);
return { return {
autoScroll: true, // autoScroll: true, //
logColorAnsiMap: { logColorAnsiMap: {
@@ -38,7 +37,6 @@ export default {
'\u001b[32m': 'color: #00FF00;', // green '\u001b[32m': 'color: #00FF00;', // green
'default': 'color: #FFFFFF;' 'default': 'color: #FFFFFF;'
}, },
logCache: log_cache,
historyNum_: -1, historyNum_: -1,
logLevels: ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], logLevels: ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
selectedLevels: [0, 1, 2, 3, 4], // selectedLevels: [0, 1, 2, 3, 4], //
@@ -48,7 +46,17 @@ export default {
'WARNING': 'amber', 'WARNING': 'amber',
'ERROR': 'red', 'ERROR': 'red',
'CRITICAL': 'purple' 'CRITICAL': 'purple'
},
lastProcessedTime: 0, //
localLogCache: [], //
} }
},
computed: {
commonStore() {
return useCommonStore();
},
logCache() {
return this.commonStore.log_cache;
} }
}, },
props: { props: {
@@ -63,13 +71,39 @@ export default {
}, },
watch: { watch: {
logCache: { logCache: {
handler(val) { handler(newVal) {
const lastLog = val[this.logCache.length - 1]; // timestamp
if (lastLog && this.isLevelSelected(lastLog.level)) { if (newVal && newVal.length > 0) {
this.printLog(lastLog.data); // DOM
this.$nextTick(() => {
//
const newLogs = newVal.filter(log => log.time > this.lastProcessedTime);
if (newLogs.length > 0) {
this.localLogCache.push(...newLogs);
//
this.localLogCache.sort((a, b) => a.time - b.time);
// log_cache_max_len
if (this.localLogCache.length > this.commonStore.log_cache_max_len) {
this.localLogCache.splice(0, this.localLogCache.length - this.commonStore.log_cache_max_len);
}
//
newLogs.forEach(logItem => {
if (this.isLevelSelected(logItem.level)) {
this.printLog(logItem.data);
}
});
//
this.lastProcessedTime = Math.max(...newLogs.map(log => log.time));
}
});
} }
}, },
deep: true deep: true,
immediate: false
}, },
selectedLevels: { selectedLevels: {
handler() { handler() {
@@ -78,14 +112,37 @@ export default {
deep: true deep: true
} }
}, },
mounted() { async mounted() {
if (this.logCache.length === 0) { //
this.delayInit() await this.fetchLogHistory();
} else {
this.init() // DOM
this.$nextTick(() => {
if (this.localLogCache.length > 0) {
this.localLogCache.forEach(logItem => {
if (this.isLevelSelected(logItem.level)) {
this.printLog(logItem.data);
} }
});
//
this.lastProcessedTime = Math.max(...this.localLogCache.map(log => log.time));
}
});
}, },
methods: { methods: {
async fetchLogHistory() {
try {
const res = await axios.get('/api/log-history');
if (res.data.data.logs && res.data.data.logs.length > 0) {
this.localLogCache = [...res.data.data.logs];
//
this.localLogCache.sort((a, b) => a.time - b.time);
}
} catch (err) {
console.error('Failed to fetch log history:', err);
}
},
getLevelColor(level) { getLevelColor(level) {
return this.levelColors[level] || 'grey'; return this.levelColors[level] || 'grey';
}, },
@@ -101,40 +158,21 @@ export default {
}, },
refreshDisplay() { refreshDisplay() {
//
const termElement = document.getElementById('term'); const termElement = document.getElementById('term');
if (termElement) { if (termElement) {
termElement.innerHTML = ''; termElement.innerHTML = '';
//
if (this.localLogCache && this.localLogCache.length > 0) {
this.localLogCache.forEach(logItem => {
if (this.isLevelSelected(logItem.level)) {
this.printLog(logItem.data);
}
});
} }
//
this.init();
},
delayInit() {
if (this.logCache.length === 0) {
setTimeout(() => {
this.delayInit()
}, 500)
} else {
this.init()
} }
}, },
init() {
this.historyNum_ = parseInt(this.historyNum)
let i = 0
for (let log of this.logCache) {
if (this.isLevelSelected(log.level)) { //
if (this.historyNum_ != -1 && i >= this.logCache.length - this.historyNum_) {
this.printLog(log.data)
++i
} else if (this.historyNum_ == -1) {
this.printLog(log.data)
}
}
}
},
toggleAutoScroll() { toggleAutoScroll() {
this.autoScroll = !this.autoScroll; this.autoScroll = !this.autoScroll;
@@ -143,6 +181,11 @@ export default {
printLog(log) { printLog(log) {
// append span termblock // append span termblock
let ele = document.getElementById('term') let ele = document.getElementById('term')
if (!ele) {
console.warn('term element not found, skipping log print');
return;
}
let span = document.createElement('pre') let span = document.createElement('pre')
let style = this.logColorAnsiMap['default'] let style = this.logColorAnsiMap['default']
for (let key in this.logColorAnsiMap) { for (let key in this.logColorAnsiMap) {
@@ -115,7 +115,7 @@ const _show = computed({
</script> </script>
<template> <template>
<v-dialog v-model="_show" width="800" persistent> <v-dialog v-model="_show" width="800">
<v-card> <v-card>
<v-card-title class="d-flex justify-space-between align-center"> <v-card-title class="d-flex justify-space-between align-center">
<span class="text-h5">{{ t('core.common.readme.title') }}</span> <span class="text-h5">{{ t('core.common.readme.title') }}</span>
@@ -57,6 +57,9 @@
}, },
"provider_id": { "provider_id": {
"description": "Default Text-to-Speech Model" "description": "Default Text-to-Speech Model"
},
"trigger_probability": {
"description": "TTS Trigger Probability"
} }
} }
}, },
@@ -13,7 +13,8 @@
"refresh": "Refresh" "refresh": "Refresh"
}, },
"batch": { "batch": {
"deleteSelected": "Delete Selected ({count})" "deleteSelected": "Delete Selected ({count})",
"exportSelected": "Export Selected ({count})"
}, },
"pagination": { "pagination": {
"itemsPerPage": "Items per page", "itemsPerPage": "Items per page",
@@ -76,7 +77,8 @@
"message": "Are you sure you want to delete the selected {count} conversations? This action cannot be undone, please proceed with caution!", "message": "Are you sure you want to delete the selected {count} conversations? This action cannot be undone, please proceed with caution!",
"andMore": "and {count} more", "andMore": "and {count} more",
"cancel": "Cancel", "cancel": "Cancel",
"confirm": "Batch Delete" "confirm": "Batch Delete",
"warning": "Warning: This action cannot be undone!"
} }
}, },
"messages": { "messages": {
@@ -92,6 +94,9 @@
"noItemSelected": "Please select conversations to delete first", "noItemSelected": "Please select conversations to delete first",
"batchDeleteSuccess": "Successfully deleted {count} conversations", "batchDeleteSuccess": "Successfully deleted {count} conversations",
"batchDeleteError": "Batch delete failed", "batchDeleteError": "Batch delete failed",
"batchDeletePartial": "Delete completed: {deleted} successful, {failed} failed" "batchDeletePartial": "Delete completed: {deleted} successful, {failed} failed",
"exportSuccess": "Export successful",
"exportError": "Export failed",
"noItemSelectedForExport": "Please select conversations to export first"
} }
} }
@@ -62,6 +62,9 @@
}, },
"provider_id": { "provider_id": {
"description": "默认文本转语音模型" "description": "默认文本转语音模型"
},
"trigger_probability": {
"description": "TTS 触发概率"
} }
} }
}, },
@@ -13,7 +13,8 @@
"refresh": "刷新" "refresh": "刷新"
}, },
"batch": { "batch": {
"deleteSelected": "删除选中 ({count})" "deleteSelected": "删除选中 ({count})",
"exportSelected": "导出选中 ({count})"
}, },
"pagination": { "pagination": {
"itemsPerPage": "每页", "itemsPerPage": "每页",
@@ -76,7 +77,8 @@
"message": "确定要删除选中的 {count} 个对话吗?此操作不可恢复,请谨慎操作!", "message": "确定要删除选中的 {count} 个对话吗?此操作不可恢复,请谨慎操作!",
"andMore": "等 {count} 个", "andMore": "等 {count} 个",
"cancel": "取消", "cancel": "取消",
"confirm": "批量删除" "confirm": "批量删除",
"warning": "警告:此操作不可撤销!"
} }
}, },
"messages": { "messages": {
@@ -92,6 +94,9 @@
"noItemSelected": "请先选择要删除的对话", "noItemSelected": "请先选择要删除的对话",
"batchDeleteSuccess": "成功删除 {count} 个对话", "batchDeleteSuccess": "成功删除 {count} 个对话",
"batchDeleteError": "批量删除失败", "batchDeleteError": "批量删除失败",
"batchDeletePartial": "删除完成:成功 {deleted} 个,失败 {failed} 个" "batchDeletePartial": "删除完成:成功 {deleted} 个,失败 {failed} 个",
"exportSuccess": "导出成功",
"exportError": "导出失败",
"noItemSelectedForExport": "请先选择要导出的对话"
} }
} }
+28 -62
View File
@@ -16,21 +16,6 @@ export const useCommonStore = defineStore({
}), }),
actions: { actions: {
async createEventSource() { async createEventSource() {
const fetchLogHistory = async () => {
try {
const res = await axios.get('/api/log-history');
if (res.data.data.logs) {
this.log_cache.push(...res.data.data.logs);
} else {
this.log_cache = [];
}
} catch (err) {
console.error('Failed to fetch log history:', err);
}
};
await fetchLogHistory();
if (this.eventSource) { if (this.eventSource) {
return return
} }
@@ -54,25 +39,9 @@ export const useCommonStore = defineStore({
const reader = response.body.getReader(); const reader = response.body.getReader();
const decoder = new TextDecoder(); const decoder = new TextDecoder();
let bufferedText = '';
let incompleteLine = ""; // 用于存储不完整的行
const handleIncompleteLine = (line) => {
incompleteLine += line;
// if can parse as JSON, return it
try {
const data_json = JSON.parse(incompleteLine);
incompleteLine = ""; // 清空不完整行
return data_json;
} catch (e) {
return null;
}
}
const processStream = ({ done, value }) => { const processStream = ({ done, value }) => {
// get bytes length
const bytesLength = value ? value.byteLength : 0;
console.log(`Received ${bytesLength} bytes from live log`);
if (done) { if (done) {
console.log('SSE stream closed'); console.log('SSE stream closed');
setTimeout(() => { setTimeout(() => {
@@ -82,44 +51,41 @@ export const useCommonStore = defineStore({
return; return;
} }
const text = decoder.decode(value); // Accumulate partial chunks; SSE data may split JSON across reads.
const lines = text.split('\n\n'); const text = decoder.decode(value, { stream: true });
lines.forEach(line => { bufferedText += text;
if (!line.trim()) {
// Split completed events; keep the trailing partial in buffer.
const segments = bufferedText.split('\n\n');
bufferedText = segments.pop() || '';
segments.forEach(segment => {
const line = segment.trim();
if (!line.startsWith('data: ')) {
return; return;
} }
if (line.startsWith('data:')) {
const data = line.substring(5).trim(); const logLine = line.replace('data: ', '').trim();
// {"type":"log","data":"[2021-08-01 00:00:00] INFO: Hello, world!"} if (!logLine) {
let data_json = {} return;
}
try { try {
data_json = JSON.parse(data); const logObject = JSON.parse(logLine);
} catch (e) { // give a uuid if not exists
console.warn('Invalid JSON:', data); if (!logObject.uuid) {
// 尝试处理不完整的行 logObject.uuid = crypto.randomUUID();
const parsedData = handleIncompleteLine(data);
if (parsedData) {
data_json = parsedData;
} else {
return; // 如果无法解析,跳过当前行
} }
} this.log_cache.push(logObject);
if (data_json.type === 'log') { // Limit log cache size
this.log_cache.push(data_json);
if (this.log_cache.length > this.log_cache_max_len) { if (this.log_cache.length > this.log_cache_max_len) {
this.log_cache.shift(); this.log_cache.splice(0, this.log_cache.length - this.log_cache_max_len);
}
}
} else {
const parsedData = handleIncompleteLine(line);
if (parsedData && parsedData.type === 'log') {
this.log_cache.push(parsedData);
if (this.log_cache.length > this.log_cache_max_len) {
this.log_cache.shift();
}
} }
} catch (err) {
console.warn('Failed to parse SSE log line, skipping:', err, logLine);
} }
}); });
return reader.read().then(processStream); return reader.read().then(processStream);
}; };
+58
View File
@@ -40,6 +40,17 @@
:loading="loading" size="small" class="mr-2"> :loading="loading" size="small" class="mr-2">
{{ tm('history.refresh') }} {{ tm('history.refresh') }}
</v-btn> </v-btn>
<v-btn
v-if="selectedItems.length > 0"
color="success"
prepend-icon="mdi-download"
variant="tonal"
@click="exportConversations"
:disabled="loading"
size="small"
class="mr-2">
{{ tm('batch.exportSelected', { count: selectedItems.length }) }}
</v-btn>
<v-btn <v-btn
v-if="selectedItems.length > 0" v-if="selectedItems.length > 0"
color="error" color="error"
@@ -910,6 +921,53 @@ export default {
} }
}, },
//
async exportConversations() {
if (this.selectedItems.length === 0) {
this.showErrorMessage(this.tm('messages.noItemSelectedForExport'));
return;
}
this.loading = true;
try {
//
const conversations = this.selectedItems.map(item => ({
user_id: item.user_id,
cid: item.cid
}));
const response = await axios.post('/api/conversation/export', {
conversations: conversations
}, {
responseType: 'blob' // axios blob
});
//
const url = window.URL.createObjectURL(response.data);
const link = document.createElement('a');
link.href = url;
// 使
const timestamp = new Date().toISOString().replace(/[:.]/g, '-').slice(0, -5);
const filename = `conversations_export_${timestamp}.jsonl`;
link.setAttribute('download', filename);
document.body.appendChild(link);
link.click();
//
link.remove();
window.URL.revokeObjectURL(url);
this.showSuccessMessage(this.tm('messages.exportSuccess'));
} catch (error) {
console.error(this.tm('messages.exportError'), error);
this.showErrorMessage(error.response?.data?.message || error.message || this.tm('messages.exportError'));
} finally {
this.loading = false;
}
},
// //
formatTimestamp(timestamp) { formatTimestamp(timestamp) {
if (!timestamp) return this.tm('status.unknown'); if (!timestamp) return this.tm('status.unknown');
+1 -1
View File
@@ -1568,7 +1568,7 @@ watch(marketSearch, (newVal) => {
<!-- 上传插件对话框 --> <!-- 上传插件对话框 -->
<v-dialog v-model="dialog" width="500"> <v-dialog v-model="dialog" width="500">
<div class="v-card v-theme--PurpleThemeDark v-card--density-default rounded-lg v-card--variant-elevated"> <div class="v-card v-card--density-default rounded-lg v-card--variant-elevated">
<div class="v-card__loader"> <div class="v-card__loader">
<v-progress-linear :indeterminate="loading_" color="primary" height="2" :active="loading_"></v-progress-linear> <v-progress-linear :indeterminate="loading_" color="primary" height="2" :active="loading_"></v-progress-linear>
</div> </div>
+1 -1
View File
@@ -1,6 +1,6 @@
[project] [project]
name = "AstrBot" name = "AstrBot"
version = "4.8.0" version = "4.9.1"
description = "Easy-to-use multi-platform LLM chatbot and development framework" description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"
+209
View File
@@ -0,0 +1,209 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
import pytest_asyncio
from quart import Quart
from astrbot.core import LogBroker
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.knowledge_base.kb_helper import KBHelper
from astrbot.core.knowledge_base.models import KBDocument
from astrbot.dashboard.server import AstrBotDashboard
@pytest_asyncio.fixture(scope="module")
async def core_lifecycle_td(tmp_path_factory):
"""Creates and initializes a core lifecycle instance with a temporary database."""
tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_kb.db"
db = SQLiteDatabase(str(tmp_db_path))
log_broker = LogBroker()
core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
await core_lifecycle.initialize()
# Mock kb_manager and kb_helper
kb_manager = MagicMock()
kb_helper = AsyncMock(spec=KBHelper)
# Configure get_kb to be an async mock that returns kb_helper
kb_manager.get_kb = AsyncMock(return_value=kb_helper)
# Mock upload_document return value
mock_doc = KBDocument(
doc_id="test_doc_id",
kb_id="test_kb_id",
doc_name="test_file.txt",
file_type="txt",
file_size=100,
file_path="",
chunk_count=2,
media_count=0,
)
kb_helper.upload_document.return_value = mock_doc
# kb_manager.get_kb.return_value = kb_helper # Removed this line as it's handled above
core_lifecycle.kb_manager = kb_manager
try:
yield core_lifecycle
finally:
try:
_stop_res = core_lifecycle.stop()
if asyncio.iscoroutine(_stop_res):
await _stop_res
except Exception:
pass
@pytest.fixture(scope="module")
def app(core_lifecycle_td: AstrBotCoreLifecycle):
"""Creates a Quart app instance for testing."""
shutdown_event = asyncio.Event()
server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
return server.app
@pytest_asyncio.fixture(scope="module")
async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
"""Handles login and returns an authenticated header."""
test_client = app.test_client()
response = await test_client.post(
"/api/auth/login",
json={
"username": core_lifecycle_td.astrbot_config["dashboard"]["username"],
"password": core_lifecycle_td.astrbot_config["dashboard"]["password"],
},
)
data = await response.get_json()
assert data["status"] == "ok"
token = data["data"]["token"]
return {"Authorization": f"Bearer {token}"}
@pytest.mark.asyncio
async def test_import_documents(
app: Quart, authenticated_header: dict, core_lifecycle_td: AstrBotCoreLifecycle
):
"""Tests the import documents functionality."""
test_client = app.test_client()
# Test data
import_data = {
"kb_id": "test_kb_id",
"documents": [
{"file_name": "test_file_1.txt", "chunks": ["chunk1", "chunk2"]},
{"file_name": "test_file_2.md", "chunks": ["chunk3", "chunk4", "chunk5"]},
],
}
# Send request
response = await test_client.post(
"/api/kb/document/import", json=import_data, headers=authenticated_header
)
# Verify response
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert "task_id" in data["data"]
assert data["data"]["doc_count"] == 2
task_id = data["data"]["task_id"]
# Wait for background task to complete (mocked)
# Since we mocked upload_document, it should be fast, but we might need to poll progress
for _ in range(10):
progress_response = await test_client.get(
f"/api/kb/document/upload/progress?task_id={task_id}",
headers=authenticated_header,
)
progress_data = await progress_response.get_json()
if progress_data["data"]["status"] == "completed":
break
await asyncio.sleep(0.1)
assert progress_data["data"]["status"] == "completed"
result = progress_data["data"]["result"]
assert result["success_count"] == 2
assert result["failed_count"] == 0
# Verify kb_helper.upload_document was called correctly
kb_helper = await core_lifecycle_td.kb_manager.get_kb("test_kb_id")
assert kb_helper.upload_document.call_count == 2
# Check first call arguments
call_args_list = kb_helper.upload_document.call_args_list
# First document
args1, kwargs1 = call_args_list[0]
assert kwargs1["file_name"] == "test_file_1.txt"
assert kwargs1["pre_chunked_text"] == ["chunk1", "chunk2"]
# Second document
args2, kwargs2 = call_args_list[1]
assert kwargs2["file_name"] == "test_file_2.md"
assert kwargs2["pre_chunked_text"] == ["chunk3", "chunk4", "chunk5"]
@pytest.mark.asyncio
async def test_import_documents_invalid_input(app: Quart, authenticated_header: dict):
"""Tests import documents with invalid input."""
test_client = app.test_client()
# Missing kb_id
response = await test_client.post(
"/api/kb/document/import", json={"documents": []}, headers=authenticated_header
)
data = await response.get_json()
assert data["status"] == "error"
assert "缺少参数 kb_id" in data["message"]
# Missing documents
response = await test_client.post(
"/api/kb/document/import",
json={"kb_id": "test_kb"},
headers=authenticated_header,
)
data = await response.get_json()
assert data["status"] == "error"
assert "缺少参数 documents" in data["message"]
# Invalid document format
response = await test_client.post(
"/api/kb/document/import",
json={
"kb_id": "test_kb",
"documents": [{"file_name": "test"}], # Missing chunks
},
headers=authenticated_header,
)
data = await response.get_json()
assert data["status"] == "error"
assert "文档格式错误" in data["message"]
# Invalid chunks type
response = await test_client.post(
"/api/kb/document/import",
json={
"kb_id": "test_kb",
"documents": [{"file_name": "test", "chunks": "not-a-list"}],
},
headers=authenticated_header,
)
data = await response.get_json()
assert data["status"] == "error"
assert "chunks 必须是列表" in data["message"]
# Invalid chunks content
response = await test_client.post(
"/api/kb/document/import",
json={
"kb_id": "test_kb",
"documents": [{"file_name": "test", "chunks": ["valid", ""]}],
},
headers=authenticated_header,
)
data = await response.get_json()
assert data["status"] == "error"
assert "chunks 必须是非空字符串列表" in data["message"]