Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e7e97730af | |||
| 467ca1eb5c | |||
| 46528391c2 |
@@ -1 +1 @@
|
|||||||
__version__ = "4.8.0"
|
__version__ = "4.9.0"
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import os
|
|||||||
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
VERSION = "4.8.0"
|
VERSION = "4.9.0"
|
||||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||||
|
|
||||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||||
|
|||||||
+2
-1
@@ -24,6 +24,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
@@ -148,7 +149,7 @@ class LogQueueHandler(logging.Handler):
|
|||||||
self.log_broker.publish(
|
self.log_broker.publish(
|
||||||
{
|
{
|
||||||
"level": record.levelname,
|
"level": record.levelname,
|
||||||
"time": record.asctime,
|
"time": time.time(),
|
||||||
"data": log_entry,
|
"data": log_entry,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -112,10 +112,6 @@ class PlatformManager:
|
|||||||
from .sources.satori.satori_adapter import (
|
from .sources.satori.satori_adapter import (
|
||||||
SatoriPlatformAdapter, # noqa: F401
|
SatoriPlatformAdapter, # noqa: F401
|
||||||
)
|
)
|
||||||
case "github_webhook":
|
|
||||||
from .sources.github_webhook.github_webhook_adapter import (
|
|
||||||
GitHubWebhookPlatformAdapter, # noqa: F401
|
|
||||||
)
|
|
||||||
except (ImportError, ModuleNotFoundError) as e:
|
except (ImportError, ModuleNotFoundError) as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
|
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
|
||||||
|
|||||||
@@ -1,315 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import hmac
|
|
||||||
from typing import Any, cast
|
|
||||||
|
|
||||||
from astrbot import logger
|
|
||||||
from astrbot.api.event import MessageChain
|
|
||||||
from astrbot.api.message_components import Plain
|
|
||||||
from astrbot.api.platform import (
|
|
||||||
AstrBotMessage,
|
|
||||||
MessageMember,
|
|
||||||
MessageType,
|
|
||||||
Platform,
|
|
||||||
PlatformMetadata,
|
|
||||||
)
|
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
|
||||||
from astrbot.core.platform.platform import PlatformStatus
|
|
||||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
|
||||||
|
|
||||||
from ...register import register_platform_adapter
|
|
||||||
from .github_webhook_event import GitHubWebhookMessageEvent
|
|
||||||
|
|
||||||
|
|
||||||
@register_platform_adapter(
|
|
||||||
"github_webhook",
|
|
||||||
"GitHub Webhook 适配器",
|
|
||||||
support_streaming_message=False,
|
|
||||||
)
|
|
||||||
class GitHubWebhookPlatformAdapter(Platform):
|
|
||||||
"""GitHub Webhook 平台适配器
|
|
||||||
|
|
||||||
支持的事件:
|
|
||||||
- issues (created)
|
|
||||||
- issue_comment (created)
|
|
||||||
- pull_request (opened)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
platform_config: dict,
|
|
||||||
platform_settings: dict,
|
|
||||||
event_queue: asyncio.Queue,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(platform_config, event_queue)
|
|
||||||
|
|
||||||
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", True)
|
|
||||||
self.webhook_secret = platform_config.get("webhook_secret", "")
|
|
||||||
self.shutdown_event = asyncio.Event()
|
|
||||||
|
|
||||||
async def send_by_session(
|
|
||||||
self,
|
|
||||||
session: MessageSesion,
|
|
||||||
message_chain: MessageChain,
|
|
||||||
):
|
|
||||||
"""GitHub Webhook 是单向接收,不支持主动发送消息"""
|
|
||||||
logger.warning("GitHub Webhook 适配器不支持 send_by_session")
|
|
||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
|
||||||
return PlatformMetadata(
|
|
||||||
name="github_webhook",
|
|
||||||
description="GitHub Webhook 适配器",
|
|
||||||
id=cast(str, self.config.get("id")),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
"""运行适配器"""
|
|
||||||
self.status = PlatformStatus.RUNNING
|
|
||||||
|
|
||||||
# 如果启用统一 webhook 模式
|
|
||||||
webhook_uuid = self.config.get("webhook_uuid")
|
|
||||||
if self.unified_webhook_mode and webhook_uuid:
|
|
||||||
log_webhook_info(f"{self.meta().id}(GitHub Webhook)", webhook_uuid)
|
|
||||||
# 保持运行状态,等待 shutdown
|
|
||||||
await self.shutdown_event.wait()
|
|
||||||
else:
|
|
||||||
logger.warning("GitHub Webhook 适配器需要启用统一 webhook 模式")
|
|
||||||
await self.shutdown_event.wait()
|
|
||||||
|
|
||||||
async def webhook_callback(self, request: Any) -> Any:
|
|
||||||
"""统一 Webhook 回调入口
|
|
||||||
|
|
||||||
处理 GitHub webhook 事件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: Quart 请求对象
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
响应数据
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 获取事件类型
|
|
||||||
event_type = request.headers.get("X-GitHub-Event", "")
|
|
||||||
|
|
||||||
# 获取请求数据
|
|
||||||
payload = await request.json
|
|
||||||
|
|
||||||
# 验证 webhook 签名(如果配置了 secret)
|
|
||||||
if self.webhook_secret:
|
|
||||||
if not await self._verify_signature(request, payload):
|
|
||||||
logger.warning("GitHub webhook 签名验证失败")
|
|
||||||
return {"error": "Invalid signature"}, 401
|
|
||||||
|
|
||||||
logger.debug(f"收到 GitHub Webhook 事件: {event_type}")
|
|
||||||
|
|
||||||
# 处理不同类型的事件
|
|
||||||
if event_type == "issues":
|
|
||||||
await self._handle_issue_event(payload)
|
|
||||||
elif event_type == "issue_comment":
|
|
||||||
await self._handle_issue_comment_event(payload)
|
|
||||||
elif event_type == "pull_request":
|
|
||||||
await self._handle_pull_request_event(payload)
|
|
||||||
elif event_type == "ping":
|
|
||||||
# GitHub webhook 验证事件
|
|
||||||
return {"message": "pong"}
|
|
||||||
else:
|
|
||||||
logger.debug(f"忽略不支持的 GitHub 事件类型: {event_type}")
|
|
||||||
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理 GitHub webhook 回调时发生错误: {e}", exc_info=True)
|
|
||||||
return {"error": str(e)}, 500
|
|
||||||
|
|
||||||
async def _verify_signature(self, request: Any, payload: dict) -> bool:
|
|
||||||
"""验证 GitHub webhook 签名
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: Quart 请求对象
|
|
||||||
payload: 请求负载数据
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
签名是否有效
|
|
||||||
"""
|
|
||||||
signature_header = request.headers.get("X-Hub-Signature-256", "")
|
|
||||||
if not signature_header:
|
|
||||||
# 如果没有签名头,检查是否有旧版本的签名
|
|
||||||
signature_header = request.headers.get("X-Hub-Signature", "")
|
|
||||||
if not signature_header:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 获取原始请求体
|
|
||||||
body = await request.get_data()
|
|
||||||
|
|
||||||
# 计算 HMAC
|
|
||||||
if signature_header.startswith("sha256="):
|
|
||||||
expected_signature = hmac.new(
|
|
||||||
self.webhook_secret.encode("utf-8"),
|
|
||||||
body,
|
|
||||||
hashlib.sha256,
|
|
||||||
).hexdigest()
|
|
||||||
received_signature = signature_header.replace("sha256=", "")
|
|
||||||
elif signature_header.startswith("sha1="):
|
|
||||||
expected_signature = hmac.new(
|
|
||||||
self.webhook_secret.encode("utf-8"),
|
|
||||||
body,
|
|
||||||
hashlib.sha1,
|
|
||||||
).hexdigest()
|
|
||||||
received_signature = signature_header.replace("sha1=", "")
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 使用 hmac.compare_digest 防止时序攻击
|
|
||||||
return hmac.compare_digest(expected_signature, received_signature)
|
|
||||||
|
|
||||||
async def _handle_issue_event(self, payload: dict):
|
|
||||||
"""处理 issue 事件"""
|
|
||||||
action = payload.get("action", "")
|
|
||||||
|
|
||||||
# 只处理创建事件
|
|
||||||
if action != "created" and action != "opened":
|
|
||||||
return
|
|
||||||
|
|
||||||
issue = payload.get("issue", {})
|
|
||||||
repo = payload.get("repository", {})
|
|
||||||
sender = payload.get("sender", {})
|
|
||||||
|
|
||||||
# 构造消息文本
|
|
||||||
message_text = (
|
|
||||||
f"📝 新 Issue 创建\n"
|
|
||||||
f"仓库: {repo.get('full_name', 'unknown')}\n"
|
|
||||||
f"标题: {issue.get('title', 'No title')}\n"
|
|
||||||
f"作者: {sender.get('login', 'unknown')}\n"
|
|
||||||
f"链接: {issue.get('html_url', '')}\n"
|
|
||||||
f"内容:\n{issue.get('body', 'No description')[:200]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 AstrBotMessage
|
|
||||||
abm = self._create_message(
|
|
||||||
message_text,
|
|
||||||
sender.get("login", "unknown"),
|
|
||||||
sender.get("login", "unknown"),
|
|
||||||
repo.get("full_name", "unknown"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 提交事件
|
|
||||||
self.commit_event(
|
|
||||||
GitHubWebhookMessageEvent(
|
|
||||||
message_text,
|
|
||||||
abm,
|
|
||||||
self.meta(),
|
|
||||||
repo.get("full_name", "unknown"),
|
|
||||||
"issues",
|
|
||||||
payload,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _handle_issue_comment_event(self, payload: dict):
|
|
||||||
"""处理 issue 评论事件"""
|
|
||||||
action = payload.get("action", "")
|
|
||||||
|
|
||||||
# 只处理创建事件
|
|
||||||
if action != "created":
|
|
||||||
return
|
|
||||||
|
|
||||||
issue = payload.get("issue", {})
|
|
||||||
comment = payload.get("comment", {})
|
|
||||||
repo = payload.get("repository", {})
|
|
||||||
sender = payload.get("sender", {})
|
|
||||||
|
|
||||||
# 构造消息文本
|
|
||||||
message_text = (
|
|
||||||
f"💬 新 Issue 评论\n"
|
|
||||||
f"仓库: {repo.get('full_name', 'unknown')}\n"
|
|
||||||
f"Issue: {issue.get('title', 'No title')}\n"
|
|
||||||
f"评论者: {sender.get('login', 'unknown')}\n"
|
|
||||||
f"链接: {comment.get('html_url', '')}\n"
|
|
||||||
f"内容:\n{comment.get('body', 'No comment')[:200]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 AstrBotMessage
|
|
||||||
abm = self._create_message(
|
|
||||||
message_text,
|
|
||||||
sender.get("login", "unknown"),
|
|
||||||
sender.get("login", "unknown"),
|
|
||||||
repo.get("full_name", "unknown"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 提交事件
|
|
||||||
self.commit_event(
|
|
||||||
GitHubWebhookMessageEvent(
|
|
||||||
message_text,
|
|
||||||
abm,
|
|
||||||
self.meta(),
|
|
||||||
repo.get("full_name", "unknown"),
|
|
||||||
"issue_comment",
|
|
||||||
payload,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _handle_pull_request_event(self, payload: dict):
|
|
||||||
"""处理 pull request 事件"""
|
|
||||||
action = payload.get("action", "")
|
|
||||||
|
|
||||||
# 只处理打开事件
|
|
||||||
if action != "opened":
|
|
||||||
return
|
|
||||||
|
|
||||||
pr = payload.get("pull_request", {})
|
|
||||||
repo = payload.get("repository", {})
|
|
||||||
sender = payload.get("sender", {})
|
|
||||||
|
|
||||||
# 构造消息文本
|
|
||||||
message_text = (
|
|
||||||
f"🔀 新 Pull Request\n"
|
|
||||||
f"仓库: {repo.get('full_name', 'unknown')}\n"
|
|
||||||
f"标题: {pr.get('title', 'No title')}\n"
|
|
||||||
f"作者: {sender.get('login', 'unknown')}\n"
|
|
||||||
f"链接: {pr.get('html_url', '')}\n"
|
|
||||||
f"内容:\n{pr.get('body', 'No description')[:200]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 AstrBotMessage
|
|
||||||
abm = self._create_message(
|
|
||||||
message_text,
|
|
||||||
sender.get("login", "unknown"),
|
|
||||||
sender.get("login", "unknown"),
|
|
||||||
repo.get("full_name", "unknown"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 提交事件
|
|
||||||
self.commit_event(
|
|
||||||
GitHubWebhookMessageEvent(
|
|
||||||
message_text,
|
|
||||||
abm,
|
|
||||||
self.meta(),
|
|
||||||
repo.get("full_name", "unknown"),
|
|
||||||
"pull_request",
|
|
||||||
payload,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_message(
|
|
||||||
self,
|
|
||||||
message_text: str,
|
|
||||||
user_id: str,
|
|
||||||
nickname: str,
|
|
||||||
session_id: str,
|
|
||||||
) -> AstrBotMessage:
|
|
||||||
"""创建 AstrBotMessage 对象"""
|
|
||||||
abm = AstrBotMessage()
|
|
||||||
abm.type = MessageType.GROUP_MESSAGE
|
|
||||||
abm.self_id = self.client_self_id
|
|
||||||
abm.session_id = session_id
|
|
||||||
abm.message_id = ""
|
|
||||||
abm.sender = MessageMember(user_id=user_id, nickname=nickname)
|
|
||||||
abm.message = [Plain(message_text)]
|
|
||||||
abm.message_str = message_text
|
|
||||||
abm.raw_message = message_text
|
|
||||||
|
|
||||||
return abm
|
|
||||||
|
|
||||||
async def terminate(self):
|
|
||||||
"""终止适配器运行"""
|
|
||||||
self.shutdown_event.set()
|
|
||||||
logger.info("GitHub Webhook 适配器已经被优雅地关闭")
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
|
||||||
|
|
||||||
from ...astr_message_event import AstrMessageEvent
|
|
||||||
|
|
||||||
|
|
||||||
class GitHubWebhookMessageEvent(AstrMessageEvent):
|
|
||||||
"""GitHub Webhook 消息事件"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
message_str: str,
|
|
||||||
message_obj: AstrBotMessage,
|
|
||||||
platform_meta: PlatformMetadata,
|
|
||||||
session_id: str,
|
|
||||||
event_type: str,
|
|
||||||
event_data: dict,
|
|
||||||
):
|
|
||||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
|
||||||
self.event_type = event_type
|
|
||||||
"""GitHub 事件类型: issues, issue_comment, pull_request"""
|
|
||||||
self.event_data = event_data
|
|
||||||
"""原始事件数据"""
|
|
||||||
@@ -48,6 +48,7 @@ class KnowledgeBaseRoute(Route):
|
|||||||
# 文档管理
|
# 文档管理
|
||||||
"/kb/document/list": ("GET", self.list_documents),
|
"/kb/document/list": ("GET", self.list_documents),
|
||||||
"/kb/document/upload": ("POST", self.upload_document),
|
"/kb/document/upload": ("POST", self.upload_document),
|
||||||
|
"/kb/document/import": ("POST", self.import_documents),
|
||||||
"/kb/document/upload/url": ("POST", self.upload_document_from_url),
|
"/kb/document/upload/url": ("POST", self.upload_document_from_url),
|
||||||
"/kb/document/upload/progress": ("GET", self.get_upload_progress),
|
"/kb/document/upload/progress": ("GET", self.get_upload_progress),
|
||||||
"/kb/document/get": ("GET", self.get_document),
|
"/kb/document/get": ("GET", self.get_document),
|
||||||
@@ -66,6 +67,65 @@ class KnowledgeBaseRoute(Route):
|
|||||||
def _get_kb_manager(self):
|
def _get_kb_manager(self):
|
||||||
return self.core_lifecycle.kb_manager
|
return self.core_lifecycle.kb_manager
|
||||||
|
|
||||||
|
def _init_task(self, task_id: str, status: str = "pending") -> None:
|
||||||
|
self.upload_tasks[task_id] = {
|
||||||
|
"status": status,
|
||||||
|
"result": None,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _set_task_result(
|
||||||
|
self, task_id: str, status: str, result: any = None, error: str | None = None
|
||||||
|
) -> None:
|
||||||
|
self.upload_tasks[task_id] = {
|
||||||
|
"status": status,
|
||||||
|
"result": result,
|
||||||
|
"error": error,
|
||||||
|
}
|
||||||
|
if task_id in self.upload_progress:
|
||||||
|
self.upload_progress[task_id]["status"] = status
|
||||||
|
|
||||||
|
def _update_progress(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
*,
|
||||||
|
status: str | None = None,
|
||||||
|
file_index: int | None = None,
|
||||||
|
file_name: str | None = None,
|
||||||
|
stage: str | None = None,
|
||||||
|
current: int | None = None,
|
||||||
|
total: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
if task_id not in self.upload_progress:
|
||||||
|
return
|
||||||
|
p = self.upload_progress[task_id]
|
||||||
|
if status is not None:
|
||||||
|
p["status"] = status
|
||||||
|
if file_index is not None:
|
||||||
|
p["file_index"] = file_index
|
||||||
|
if file_name is not None:
|
||||||
|
p["file_name"] = file_name
|
||||||
|
if stage is not None:
|
||||||
|
p["stage"] = stage
|
||||||
|
if current is not None:
|
||||||
|
p["current"] = current
|
||||||
|
if total is not None:
|
||||||
|
p["total"] = total
|
||||||
|
|
||||||
|
def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str):
|
||||||
|
async def _callback(stage: str, current: int, total: int):
|
||||||
|
self._update_progress(
|
||||||
|
task_id,
|
||||||
|
status="processing",
|
||||||
|
file_index=file_idx,
|
||||||
|
file_name=file_name,
|
||||||
|
stage=stage,
|
||||||
|
current=current,
|
||||||
|
total=total,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _callback
|
||||||
|
|
||||||
async def _background_upload_task(
|
async def _background_upload_task(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
@@ -80,11 +140,7 @@ class KnowledgeBaseRoute(Route):
|
|||||||
"""后台上传任务"""
|
"""后台上传任务"""
|
||||||
try:
|
try:
|
||||||
# 初始化任务状态
|
# 初始化任务状态
|
||||||
self.upload_tasks[task_id] = {
|
self._init_task(task_id, status="processing")
|
||||||
"status": "processing",
|
|
||||||
"result": None,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
self.upload_progress[task_id] = {
|
self.upload_progress[task_id] = {
|
||||||
"status": "processing",
|
"status": "processing",
|
||||||
"file_index": 0,
|
"file_index": 0,
|
||||||
@@ -100,30 +156,20 @@ class KnowledgeBaseRoute(Route):
|
|||||||
for file_idx, file_info in enumerate(files_to_upload):
|
for file_idx, file_info in enumerate(files_to_upload):
|
||||||
try:
|
try:
|
||||||
# 更新整体进度
|
# 更新整体进度
|
||||||
self.upload_progress[task_id].update(
|
self._update_progress(
|
||||||
{
|
task_id,
|
||||||
"status": "processing",
|
status="processing",
|
||||||
"file_index": file_idx,
|
file_index=file_idx,
|
||||||
"file_name": file_info["file_name"],
|
file_name=file_info["file_name"],
|
||||||
"stage": "parsing",
|
stage="parsing",
|
||||||
"current": 0,
|
current=0,
|
||||||
"total": 100,
|
total=100,
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建进度回调函数
|
# 创建进度回调函数
|
||||||
async def progress_callback(stage, current, total):
|
progress_callback = self._make_progress_callback(
|
||||||
if task_id in self.upload_progress:
|
task_id, file_idx, file_info["file_name"]
|
||||||
self.upload_progress[task_id].update(
|
)
|
||||||
{
|
|
||||||
"status": "processing",
|
|
||||||
"file_index": file_idx,
|
|
||||||
"file_name": file_info["file_name"],
|
|
||||||
"stage": stage,
|
|
||||||
"current": current,
|
|
||||||
"total": total,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
doc = await kb_helper.upload_document(
|
doc = await kb_helper.upload_document(
|
||||||
file_name=file_info["file_name"],
|
file_name=file_info["file_name"],
|
||||||
@@ -154,23 +200,99 @@ class KnowledgeBaseRoute(Route):
|
|||||||
"failed_count": len(failed_docs),
|
"failed_count": len(failed_docs),
|
||||||
}
|
}
|
||||||
|
|
||||||
self.upload_tasks[task_id] = {
|
self._set_task_result(task_id, "completed", result=result)
|
||||||
"status": "completed",
|
|
||||||
"result": result,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
self.upload_progress[task_id]["status"] = "completed"
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"后台上传任务 {task_id} 失败: {e}")
|
logger.error(f"后台上传任务 {task_id} 失败: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
self.upload_tasks[task_id] = {
|
self._set_task_result(task_id, "failed", error=str(e))
|
||||||
"status": "failed",
|
|
||||||
"result": None,
|
async def _background_import_task(
|
||||||
"error": str(e),
|
self,
|
||||||
|
task_id: str,
|
||||||
|
kb_helper,
|
||||||
|
documents: list,
|
||||||
|
batch_size: int,
|
||||||
|
tasks_limit: int,
|
||||||
|
max_retries: int,
|
||||||
|
):
|
||||||
|
"""后台导入预切片文档任务"""
|
||||||
|
try:
|
||||||
|
# 初始化任务状态
|
||||||
|
self._init_task(task_id, status="processing")
|
||||||
|
self.upload_progress[task_id] = {
|
||||||
|
"status": "processing",
|
||||||
|
"file_index": 0,
|
||||||
|
"file_total": len(documents),
|
||||||
|
"stage": "waiting",
|
||||||
|
"current": 0,
|
||||||
|
"total": 100,
|
||||||
}
|
}
|
||||||
if task_id in self.upload_progress:
|
|
||||||
self.upload_progress[task_id]["status"] = "failed"
|
uploaded_docs = []
|
||||||
|
failed_docs = []
|
||||||
|
|
||||||
|
for file_idx, doc_info in enumerate(documents):
|
||||||
|
file_name = doc_info.get("file_name", f"imported_doc_{file_idx}")
|
||||||
|
chunks = doc_info.get("chunks", [])
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 更新整体进度
|
||||||
|
self._update_progress(
|
||||||
|
task_id,
|
||||||
|
status="processing",
|
||||||
|
file_index=file_idx,
|
||||||
|
file_name=file_name,
|
||||||
|
stage="importing",
|
||||||
|
current=0,
|
||||||
|
total=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建进度回调函数
|
||||||
|
progress_callback = self._make_progress_callback(
|
||||||
|
task_id, file_idx, file_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# 调用 upload_document,传入 pre_chunked_text
|
||||||
|
doc = await kb_helper.upload_document(
|
||||||
|
file_name=file_name,
|
||||||
|
file_content=None, # 预切片模式下不需要原始内容
|
||||||
|
file_type=doc_info.get("file_type")
|
||||||
|
or (
|
||||||
|
file_name.rsplit(".", 1)[-1].lower()
|
||||||
|
if "." in file_name
|
||||||
|
else "txt"
|
||||||
|
),
|
||||||
|
batch_size=batch_size,
|
||||||
|
tasks_limit=tasks_limit,
|
||||||
|
max_retries=max_retries,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
pre_chunked_text=chunks,
|
||||||
|
)
|
||||||
|
|
||||||
|
uploaded_docs.append(doc.model_dump())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"导入文档 {file_name} 失败: {e}")
|
||||||
|
failed_docs.append(
|
||||||
|
{"file_name": file_name, "error": str(e)},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新任务完成状态
|
||||||
|
result = {
|
||||||
|
"task_id": task_id,
|
||||||
|
"uploaded": uploaded_docs,
|
||||||
|
"failed": failed_docs,
|
||||||
|
"total": len(documents),
|
||||||
|
"success_count": len(uploaded_docs),
|
||||||
|
"failed_count": len(failed_docs),
|
||||||
|
}
|
||||||
|
|
||||||
|
self._set_task_result(task_id, "completed", result=result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"后台导入任务 {task_id} 失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
self._set_task_result(task_id, "failed", error=str(e))
|
||||||
|
|
||||||
async def list_kbs(self):
|
async def list_kbs(self):
|
||||||
"""获取知识库列表
|
"""获取知识库列表
|
||||||
@@ -614,11 +736,7 @@ class KnowledgeBaseRoute(Route):
|
|||||||
task_id = str(uuid.uuid4())
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# 初始化任务状态
|
# 初始化任务状态
|
||||||
self.upload_tasks[task_id] = {
|
self._init_task(task_id, status="pending")
|
||||||
"status": "pending",
|
|
||||||
"result": None,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 启动后台任务
|
# 启动后台任务
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
@@ -653,6 +771,93 @@ class KnowledgeBaseRoute(Route):
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"上传文档失败: {e!s}").__dict__
|
return Response().error(f"上传文档失败: {e!s}").__dict__
|
||||||
|
|
||||||
|
def _validate_import_request(self, data: dict):
|
||||||
|
kb_id = data.get("kb_id")
|
||||||
|
if not kb_id:
|
||||||
|
raise ValueError("缺少参数 kb_id")
|
||||||
|
|
||||||
|
documents = data.get("documents")
|
||||||
|
if not documents or not isinstance(documents, list):
|
||||||
|
raise ValueError("缺少参数 documents 或格式错误")
|
||||||
|
|
||||||
|
for doc in documents:
|
||||||
|
if "file_name" not in doc or "chunks" not in doc:
|
||||||
|
raise ValueError("文档格式错误,必须包含 file_name 和 chunks")
|
||||||
|
if not isinstance(doc["chunks"], list):
|
||||||
|
raise ValueError("chunks 必须是列表")
|
||||||
|
if not all(
|
||||||
|
isinstance(chunk, str) and chunk.strip() for chunk in doc["chunks"]
|
||||||
|
):
|
||||||
|
raise ValueError("chunks 必须是非空字符串列表")
|
||||||
|
|
||||||
|
batch_size = data.get("batch_size", 32)
|
||||||
|
tasks_limit = data.get("tasks_limit", 3)
|
||||||
|
max_retries = data.get("max_retries", 3)
|
||||||
|
return kb_id, documents, batch_size, tasks_limit, max_retries
|
||||||
|
|
||||||
|
async def import_documents(self):
|
||||||
|
"""导入预切片文档
|
||||||
|
|
||||||
|
Body:
|
||||||
|
- kb_id: 知识库 ID (必填)
|
||||||
|
- documents: 文档列表 (必填)
|
||||||
|
- file_name: 文件名 (必填)
|
||||||
|
- chunks: 切片列表 (必填, list[str])
|
||||||
|
- file_type: 文件类型 (可选, 默认从文件名推断或为 txt)
|
||||||
|
- batch_size: 批处理大小 (可选, 默认32)
|
||||||
|
- tasks_limit: 并发任务限制 (可选, 默认3)
|
||||||
|
- max_retries: 最大重试次数 (可选, 默认3)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
kb_manager = self._get_kb_manager()
|
||||||
|
data = await request.json
|
||||||
|
|
||||||
|
kb_id, documents, batch_size, tasks_limit, max_retries = (
|
||||||
|
self._validate_import_request(data)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取知识库
|
||||||
|
kb_helper = await kb_manager.get_kb(kb_id)
|
||||||
|
if not kb_helper:
|
||||||
|
return Response().error("知识库不存在").__dict__
|
||||||
|
|
||||||
|
# 生成任务ID
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# 初始化任务状态
|
||||||
|
self._init_task(task_id, status="pending")
|
||||||
|
|
||||||
|
# 启动后台任务
|
||||||
|
asyncio.create_task(
|
||||||
|
self._background_import_task(
|
||||||
|
task_id=task_id,
|
||||||
|
kb_helper=kb_helper,
|
||||||
|
documents=documents,
|
||||||
|
batch_size=batch_size,
|
||||||
|
tasks_limit=tasks_limit,
|
||||||
|
max_retries=max_retries,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
Response()
|
||||||
|
.ok(
|
||||||
|
{
|
||||||
|
"task_id": task_id,
|
||||||
|
"doc_count": len(documents),
|
||||||
|
"message": "import task created, processing in background",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
return Response().error(str(e)).__dict__
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"导入文档失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return Response().error(f"导入文档失败: {e!s}").__dict__
|
||||||
|
|
||||||
async def get_upload_progress(self):
|
async def get_upload_progress(self):
|
||||||
"""获取上传进度和结果
|
"""获取上传进度和结果
|
||||||
|
|
||||||
@@ -960,11 +1165,7 @@ class KnowledgeBaseRoute(Route):
|
|||||||
task_id = str(uuid.uuid4())
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# 初始化任务状态
|
# 初始化任务状态
|
||||||
self.upload_tasks[task_id] = {
|
self._init_task(task_id, status="pending")
|
||||||
"status": "pending",
|
|
||||||
"result": None,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 启动后台任务
|
# 启动后台任务
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
@@ -1017,11 +1218,7 @@ class KnowledgeBaseRoute(Route):
|
|||||||
"""后台上传URL任务"""
|
"""后台上传URL任务"""
|
||||||
try:
|
try:
|
||||||
# 初始化任务状态
|
# 初始化任务状态
|
||||||
self.upload_tasks[task_id] = {
|
self._init_task(task_id, status="processing")
|
||||||
"status": "processing",
|
|
||||||
"result": None,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
self.upload_progress[task_id] = {
|
self.upload_progress[task_id] = {
|
||||||
"status": "processing",
|
"status": "processing",
|
||||||
"file_index": 0,
|
"file_index": 0,
|
||||||
@@ -1033,18 +1230,7 @@ class KnowledgeBaseRoute(Route):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 创建进度回调函数
|
# 创建进度回调函数
|
||||||
async def progress_callback(stage, current, total):
|
progress_callback = self._make_progress_callback(task_id, 0, f"URL: {url}")
|
||||||
if task_id in self.upload_progress:
|
|
||||||
self.upload_progress[task_id].update(
|
|
||||||
{
|
|
||||||
"status": "processing",
|
|
||||||
"file_index": 0,
|
|
||||||
"file_name": f"URL: {url}",
|
|
||||||
"stage": stage,
|
|
||||||
"current": current,
|
|
||||||
"total": total,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 上传文档
|
# 上传文档
|
||||||
doc = await kb_helper.upload_from_url(
|
doc = await kb_helper.upload_from_url(
|
||||||
@@ -1069,20 +1255,9 @@ class KnowledgeBaseRoute(Route):
|
|||||||
"failed_count": 0,
|
"failed_count": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.upload_tasks[task_id] = {
|
self._set_task_result(task_id, "completed", result=result)
|
||||||
"status": "completed",
|
|
||||||
"result": result,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
self.upload_progress[task_id]["status"] = "completed"
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"后台上传URL任务 {task_id} 失败: {e}")
|
logger.error(f"后台上传URL任务 {task_id} 失败: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
self.upload_tasks[task_id] = {
|
self._set_task_result(task_id, "failed", error=str(e))
|
||||||
"status": "failed",
|
|
||||||
"result": None,
|
|
||||||
"error": str(e),
|
|
||||||
}
|
|
||||||
if task_id in self.upload_progress:
|
|
||||||
self.upload_progress[task_id]["status"] = "failed"
|
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
## What's Changed
|
||||||
|
|
||||||
|
### 新增
|
||||||
|
|
||||||
|
- 支持自定义插件源。
|
||||||
|
- 支持飞书(Lark)的 Webhook 模式(将事件推送至开发者服务器)。
|
||||||
|
- 支持 “禁用自带指令” 快捷配置项,启用后将禁用所有 AstrBot 自带指令。入口: WebUI -> 配置文件 -> 平台配置。
|
||||||
|
|
||||||
|
### 优化
|
||||||
|
|
||||||
|
- 从 WebUI 移除了开发版本渠道。
|
||||||
|
- 当试图测试"Agent Runner"时,提示前往配置文件页测试。
|
||||||
|
- WebUI 列表项支持批量粘贴、回车创建项目。
|
||||||
|
|
||||||
|
### 修复
|
||||||
|
|
||||||
|
- Gemini API 部分调用失败的问题。
|
||||||
|
- WebUI 插件安装加载 Dialog 关闭按钮在手机端下显示异常的问题。
|
||||||
|
- 部分情况下,WebUI 日志显示不全的问题。
|
||||||
Binary file not shown.
|
After Width: | Height: | Size: 12 KiB |
@@ -1,6 +1,7 @@
|
|||||||
<script setup>
|
<script setup>
|
||||||
import { useCommonStore } from '@/stores/common';
|
import { useCommonStore } from '@/stores/common';
|
||||||
import { storeToRefs } from 'pinia';
|
import { storeToRefs } from 'pinia';
|
||||||
|
import axios from 'axios';
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<template>
|
<template>
|
||||||
@@ -24,8 +25,6 @@ import { storeToRefs } from 'pinia';
|
|||||||
export default {
|
export default {
|
||||||
name: 'ConsoleDisplayer',
|
name: 'ConsoleDisplayer',
|
||||||
data() {
|
data() {
|
||||||
const commonStore = useCommonStore();
|
|
||||||
const { log_cache } = storeToRefs(commonStore);
|
|
||||||
return {
|
return {
|
||||||
autoScroll: true, // 默认开启自动滚动
|
autoScroll: true, // 默认开启自动滚动
|
||||||
logColorAnsiMap: {
|
logColorAnsiMap: {
|
||||||
@@ -38,7 +37,6 @@ export default {
|
|||||||
'\u001b[32m': 'color: #00FF00;', // green
|
'\u001b[32m': 'color: #00FF00;', // green
|
||||||
'default': 'color: #FFFFFF;'
|
'default': 'color: #FFFFFF;'
|
||||||
},
|
},
|
||||||
logCache: log_cache,
|
|
||||||
historyNum_: -1,
|
historyNum_: -1,
|
||||||
logLevels: ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
|
logLevels: ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
|
||||||
selectedLevels: [0, 1, 2, 3, 4], // 默认选中所有级别
|
selectedLevels: [0, 1, 2, 3, 4], // 默认选中所有级别
|
||||||
@@ -48,7 +46,17 @@ export default {
|
|||||||
'WARNING': 'amber',
|
'WARNING': 'amber',
|
||||||
'ERROR': 'red',
|
'ERROR': 'red',
|
||||||
'CRITICAL': 'purple'
|
'CRITICAL': 'purple'
|
||||||
}
|
},
|
||||||
|
lastProcessedTime: 0, // 记录最后处理的日志时间戳
|
||||||
|
localLogCache: [], // 本地日志缓存
|
||||||
|
}
|
||||||
|
},
|
||||||
|
computed: {
|
||||||
|
commonStore() {
|
||||||
|
return useCommonStore();
|
||||||
|
},
|
||||||
|
logCache() {
|
||||||
|
return this.commonStore.log_cache;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
props: {
|
props: {
|
||||||
@@ -63,13 +71,39 @@ export default {
|
|||||||
},
|
},
|
||||||
watch: {
|
watch: {
|
||||||
logCache: {
|
logCache: {
|
||||||
handler(val) {
|
handler(newVal) {
|
||||||
const lastLog = val[this.logCache.length - 1];
|
// 基于 timestamp 处理新增的日志
|
||||||
if (lastLog && this.isLevelSelected(lastLog.level)) {
|
if (newVal && newVal.length > 0) {
|
||||||
this.printLog(lastLog.data);
|
// 确保 DOM 已经准备好
|
||||||
|
this.$nextTick(() => {
|
||||||
|
// 合并到本地缓存并按时间排序
|
||||||
|
const newLogs = newVal.filter(log => log.time > this.lastProcessedTime);
|
||||||
|
|
||||||
|
if (newLogs.length > 0) {
|
||||||
|
this.localLogCache.push(...newLogs);
|
||||||
|
// 按时间戳排序
|
||||||
|
this.localLogCache.sort((a, b) => a.time - b.time);
|
||||||
|
|
||||||
|
// 只保留最新的 log_cache_max_len 条
|
||||||
|
if (this.localLogCache.length > this.commonStore.log_cache_max_len) {
|
||||||
|
this.localLogCache.splice(0, this.localLogCache.length - this.commonStore.log_cache_max_len);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 显示新日志
|
||||||
|
newLogs.forEach(logItem => {
|
||||||
|
if (this.isLevelSelected(logItem.level)) {
|
||||||
|
this.printLog(logItem.data);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// 更新最后处理时间
|
||||||
|
this.lastProcessedTime = Math.max(...newLogs.map(log => log.time));
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
deep: true
|
deep: true,
|
||||||
|
immediate: false
|
||||||
},
|
},
|
||||||
selectedLevels: {
|
selectedLevels: {
|
||||||
handler() {
|
handler() {
|
||||||
@@ -78,14 +112,37 @@ export default {
|
|||||||
deep: true
|
deep: true
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
mounted() {
|
async mounted() {
|
||||||
if (this.logCache.length === 0) {
|
// 请求历史日志
|
||||||
this.delayInit()
|
await this.fetchLogHistory();
|
||||||
} else {
|
|
||||||
this.init()
|
// 等待 DOM 准备好后,显示历史日志
|
||||||
}
|
this.$nextTick(() => {
|
||||||
|
if (this.localLogCache.length > 0) {
|
||||||
|
this.localLogCache.forEach(logItem => {
|
||||||
|
if (this.isLevelSelected(logItem.level)) {
|
||||||
|
this.printLog(logItem.data);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
// 更新最后处理时间
|
||||||
|
this.lastProcessedTime = Math.max(...this.localLogCache.map(log => log.time));
|
||||||
|
}
|
||||||
|
});
|
||||||
},
|
},
|
||||||
methods: {
|
methods: {
|
||||||
|
async fetchLogHistory() {
|
||||||
|
try {
|
||||||
|
const res = await axios.get('/api/log-history');
|
||||||
|
if (res.data.data.logs && res.data.data.logs.length > 0) {
|
||||||
|
this.localLogCache = [...res.data.data.logs];
|
||||||
|
// 按时间戳排序
|
||||||
|
this.localLogCache.sort((a, b) => a.time - b.time);
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Failed to fetch log history:', err);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
getLevelColor(level) {
|
getLevelColor(level) {
|
||||||
return this.levelColors[level] || 'grey';
|
return this.levelColors[level] || 'grey';
|
||||||
},
|
},
|
||||||
@@ -101,41 +158,22 @@ export default {
|
|||||||
},
|
},
|
||||||
|
|
||||||
refreshDisplay() {
|
refreshDisplay() {
|
||||||
// 清空现有的显示
|
|
||||||
const termElement = document.getElementById('term');
|
const termElement = document.getElementById('term');
|
||||||
if (termElement) {
|
if (termElement) {
|
||||||
termElement.innerHTML = '';
|
termElement.innerHTML = '';
|
||||||
}
|
|
||||||
|
|
||||||
// 重新显示符合筛选条件的日志
|
// 重新显示所有符合筛选条件的日志
|
||||||
this.init();
|
if (this.localLogCache && this.localLogCache.length > 0) {
|
||||||
},
|
this.localLogCache.forEach(logItem => {
|
||||||
|
if (this.isLevelSelected(logItem.level)) {
|
||||||
delayInit() {
|
this.printLog(logItem.data);
|
||||||
if (this.logCache.length === 0) {
|
}
|
||||||
setTimeout(() => {
|
});
|
||||||
this.delayInit()
|
|
||||||
}, 500)
|
|
||||||
} else {
|
|
||||||
this.init()
|
|
||||||
}
|
|
||||||
},
|
|
||||||
|
|
||||||
init() {
|
|
||||||
this.historyNum_ = parseInt(this.historyNum)
|
|
||||||
let i = 0
|
|
||||||
for (let log of this.logCache) {
|
|
||||||
if (this.isLevelSelected(log.level)) { // 只显示选中级别的日志
|
|
||||||
if (this.historyNum_ != -1 && i >= this.logCache.length - this.historyNum_) {
|
|
||||||
this.printLog(log.data)
|
|
||||||
++i
|
|
||||||
} else if (this.historyNum_ == -1) {
|
|
||||||
this.printLog(log.data)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
|
||||||
toggleAutoScroll() {
|
toggleAutoScroll() {
|
||||||
this.autoScroll = !this.autoScroll;
|
this.autoScroll = !this.autoScroll;
|
||||||
},
|
},
|
||||||
@@ -143,6 +181,11 @@ export default {
|
|||||||
printLog(log) {
|
printLog(log) {
|
||||||
// append 一个 span 标签到 term,block 的方式
|
// append 一个 span 标签到 term,block 的方式
|
||||||
let ele = document.getElementById('term')
|
let ele = document.getElementById('term')
|
||||||
|
if (!ele) {
|
||||||
|
console.warn('term element not found, skipping log print');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
let span = document.createElement('pre')
|
let span = document.createElement('pre')
|
||||||
let style = this.logColorAnsiMap['default']
|
let style = this.logColorAnsiMap['default']
|
||||||
for (let key in this.logColorAnsiMap) {
|
for (let key in this.logColorAnsiMap) {
|
||||||
|
|||||||
@@ -16,21 +16,6 @@ export const useCommonStore = defineStore({
|
|||||||
}),
|
}),
|
||||||
actions: {
|
actions: {
|
||||||
async createEventSource() {
|
async createEventSource() {
|
||||||
|
|
||||||
const fetchLogHistory = async () => {
|
|
||||||
try {
|
|
||||||
const res = await axios.get('/api/log-history');
|
|
||||||
if (res.data.data.logs) {
|
|
||||||
this.log_cache.push(...res.data.data.logs);
|
|
||||||
} else {
|
|
||||||
this.log_cache = [];
|
|
||||||
}
|
|
||||||
} catch (err) {
|
|
||||||
console.error('Failed to fetch log history:', err);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
await fetchLogHistory();
|
|
||||||
|
|
||||||
if (this.eventSource) {
|
if (this.eventSource) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -54,25 +39,9 @@ export const useCommonStore = defineStore({
|
|||||||
|
|
||||||
const reader = response.body.getReader();
|
const reader = response.body.getReader();
|
||||||
const decoder = new TextDecoder();
|
const decoder = new TextDecoder();
|
||||||
|
let bufferedText = '';
|
||||||
let incompleteLine = ""; // 用于存储不完整的行
|
|
||||||
|
|
||||||
const handleIncompleteLine = (line) => {
|
|
||||||
incompleteLine += line;
|
|
||||||
// if can parse as JSON, return it
|
|
||||||
try {
|
|
||||||
const data_json = JSON.parse(incompleteLine);
|
|
||||||
incompleteLine = ""; // 清空不完整行
|
|
||||||
return data_json;
|
|
||||||
} catch (e) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const processStream = ({ done, value }) => {
|
const processStream = ({ done, value }) => {
|
||||||
// get bytes length
|
|
||||||
const bytesLength = value ? value.byteLength : 0;
|
|
||||||
console.log(`Received ${bytesLength} bytes from live log`);
|
|
||||||
if (done) {
|
if (done) {
|
||||||
console.log('SSE stream closed');
|
console.log('SSE stream closed');
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
@@ -82,44 +51,41 @@ export const useCommonStore = defineStore({
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const text = decoder.decode(value);
|
// Accumulate partial chunks; SSE data may split JSON across reads.
|
||||||
const lines = text.split('\n\n');
|
const text = decoder.decode(value, { stream: true });
|
||||||
lines.forEach(line => {
|
bufferedText += text;
|
||||||
if (!line.trim()) {
|
|
||||||
|
// Split completed events; keep the trailing partial in buffer.
|
||||||
|
const segments = bufferedText.split('\n\n');
|
||||||
|
bufferedText = segments.pop() || '';
|
||||||
|
|
||||||
|
segments.forEach(segment => {
|
||||||
|
const line = segment.trim();
|
||||||
|
if (!line.startsWith('data: ')) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (line.startsWith('data:')) {
|
|
||||||
const data = line.substring(5).trim();
|
const logLine = line.replace('data: ', '').trim();
|
||||||
// {"type":"log","data":"[2021-08-01 00:00:00] INFO: Hello, world!"}
|
if (!logLine) {
|
||||||
let data_json = {}
|
return;
|
||||||
try {
|
}
|
||||||
data_json = JSON.parse(data);
|
|
||||||
} catch (e) {
|
try {
|
||||||
console.warn('Invalid JSON:', data);
|
const logObject = JSON.parse(logLine);
|
||||||
// 尝试处理不完整的行
|
// give a uuid if not exists
|
||||||
const parsedData = handleIncompleteLine(data);
|
if (!logObject.uuid) {
|
||||||
if (parsedData) {
|
logObject.uuid = crypto.randomUUID();
|
||||||
data_json = parsedData;
|
|
||||||
} else {
|
|
||||||
return; // 如果无法解析,跳过当前行
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (data_json.type === 'log') {
|
this.log_cache.push(logObject);
|
||||||
this.log_cache.push(data_json);
|
// Limit log cache size
|
||||||
if (this.log_cache.length > this.log_cache_max_len) {
|
if (this.log_cache.length > this.log_cache_max_len) {
|
||||||
this.log_cache.shift();
|
this.log_cache.splice(0, this.log_cache.length - this.log_cache_max_len);
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
const parsedData = handleIncompleteLine(line);
|
|
||||||
if (parsedData && parsedData.type === 'log') {
|
|
||||||
this.log_cache.push(parsedData);
|
|
||||||
if (this.log_cache.length > this.log_cache_max_len) {
|
|
||||||
this.log_cache.shift();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
} catch (err) {
|
||||||
|
console.warn('Failed to parse SSE log line, skipping:', err, logLine);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
return reader.read().then(processStream);
|
return reader.read().then(processStream);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "AstrBot"
|
name = "AstrBot"
|
||||||
version = "4.8.0"
|
version = "4.9.0"
|
||||||
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|||||||
@@ -1,279 +0,0 @@
|
|||||||
"""Test GitHub webhook platform adapter"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import hmac
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from astrbot.core.platform.sources.github_webhook.github_webhook_adapter import (
|
|
||||||
GitHubWebhookPlatformAdapter,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def event_queue():
|
|
||||||
"""Create a test event queue"""
|
|
||||||
return asyncio.Queue()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def platform_config():
|
|
||||||
"""Create test platform configuration"""
|
|
||||||
return {
|
|
||||||
"type": "github_webhook",
|
|
||||||
"enable": True,
|
|
||||||
"id": "test_github_webhook",
|
|
||||||
"unified_webhook_mode": True,
|
|
||||||
"webhook_uuid": "test-uuid-123",
|
|
||||||
"webhook_secret": "", # No secret by default for easier testing
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def platform_settings():
|
|
||||||
"""Create test platform settings"""
|
|
||||||
return {"unique_session": False}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def adapter(platform_config, platform_settings, event_queue):
|
|
||||||
"""Create test adapter instance"""
|
|
||||||
return GitHubWebhookPlatformAdapter(platform_config, platform_settings, event_queue)
|
|
||||||
|
|
||||||
|
|
||||||
class TestGitHubWebhookAdapter:
|
|
||||||
"""Test cases for GitHub webhook adapter"""
|
|
||||||
|
|
||||||
def test_adapter_initialization(self, adapter):
|
|
||||||
"""Test adapter is initialized correctly"""
|
|
||||||
assert adapter.unified_webhook_mode is True
|
|
||||||
assert adapter.webhook_secret == ""
|
|
||||||
assert adapter.meta().name == "github_webhook"
|
|
||||||
assert adapter.meta().description == "GitHub Webhook 适配器"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_ping_event(self, adapter):
|
|
||||||
"""Test GitHub ping event"""
|
|
||||||
# Mock request
|
|
||||||
request = MagicMock()
|
|
||||||
request.headers.get.return_value = "ping"
|
|
||||||
|
|
||||||
async def mock_json():
|
|
||||||
return {}
|
|
||||||
|
|
||||||
request.json = mock_json()
|
|
||||||
|
|
||||||
response = await adapter.webhook_callback(request)
|
|
||||||
assert response == {"message": "pong"}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_issue_created_event(self, adapter, event_queue):
|
|
||||||
"""Test GitHub issue created event"""
|
|
||||||
# Mock request with issue created payload
|
|
||||||
request = MagicMock()
|
|
||||||
request.headers.get.return_value = "issues"
|
|
||||||
payload = {
|
|
||||||
"action": "opened",
|
|
||||||
"issue": {
|
|
||||||
"title": "Test Issue",
|
|
||||||
"body": "This is a test issue",
|
|
||||||
"html_url": "https://github.com/test/repo/issues/1",
|
|
||||||
},
|
|
||||||
"repository": {"full_name": "test/repo"},
|
|
||||||
"sender": {"login": "testuser"},
|
|
||||||
}
|
|
||||||
|
|
||||||
async def mock_json():
|
|
||||||
return payload
|
|
||||||
|
|
||||||
request.json = mock_json()
|
|
||||||
|
|
||||||
response = await adapter.webhook_callback(request)
|
|
||||||
assert response == {"status": "ok"}
|
|
||||||
|
|
||||||
# Verify event was queued
|
|
||||||
assert not event_queue.empty()
|
|
||||||
event = event_queue.get_nowait()
|
|
||||||
assert event.event_type == "issues"
|
|
||||||
assert "新 Issue 创建" in event.message_str
|
|
||||||
assert "Test Issue" in event.message_str
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_issue_comment_event(self, adapter, event_queue):
|
|
||||||
"""Test GitHub issue comment event"""
|
|
||||||
request = MagicMock()
|
|
||||||
request.headers.get.return_value = "issue_comment"
|
|
||||||
payload = {
|
|
||||||
"action": "created",
|
|
||||||
"issue": {"title": "Test Issue"},
|
|
||||||
"comment": {
|
|
||||||
"body": "Test comment",
|
|
||||||
"html_url": "https://github.com/test/repo/issues/1#comment",
|
|
||||||
},
|
|
||||||
"repository": {"full_name": "test/repo"},
|
|
||||||
"sender": {"login": "commenter"},
|
|
||||||
}
|
|
||||||
|
|
||||||
async def mock_json():
|
|
||||||
return payload
|
|
||||||
|
|
||||||
request.json = mock_json()
|
|
||||||
|
|
||||||
response = await adapter.webhook_callback(request)
|
|
||||||
assert response == {"status": "ok"}
|
|
||||||
|
|
||||||
# Verify event was queued
|
|
||||||
assert not event_queue.empty()
|
|
||||||
event = event_queue.get_nowait()
|
|
||||||
assert event.event_type == "issue_comment"
|
|
||||||
assert "新 Issue 评论" in event.message_str
|
|
||||||
assert "Test comment" in event.message_str
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_pull_request_event(self, adapter, event_queue):
|
|
||||||
"""Test GitHub pull request opened event"""
|
|
||||||
request = MagicMock()
|
|
||||||
request.headers.get.return_value = "pull_request"
|
|
||||||
payload = {
|
|
||||||
"action": "opened",
|
|
||||||
"pull_request": {
|
|
||||||
"title": "Test PR",
|
|
||||||
"body": "This is a test PR",
|
|
||||||
"html_url": "https://github.com/test/repo/pull/1",
|
|
||||||
},
|
|
||||||
"repository": {"full_name": "test/repo"},
|
|
||||||
"sender": {"login": "prauthor"},
|
|
||||||
}
|
|
||||||
|
|
||||||
async def mock_json():
|
|
||||||
return payload
|
|
||||||
|
|
||||||
request.json = mock_json()
|
|
||||||
|
|
||||||
response = await adapter.webhook_callback(request)
|
|
||||||
assert response == {"status": "ok"}
|
|
||||||
|
|
||||||
# Verify event was queued
|
|
||||||
assert not event_queue.empty()
|
|
||||||
event = event_queue.get_nowait()
|
|
||||||
assert event.event_type == "pull_request"
|
|
||||||
assert "新 Pull Request" in event.message_str
|
|
||||||
assert "Test PR" in event.message_str
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_unsupported_event(self, adapter, event_queue):
|
|
||||||
"""Test unsupported GitHub event type"""
|
|
||||||
request = MagicMock()
|
|
||||||
request.headers.get.return_value = "push"
|
|
||||||
|
|
||||||
async def mock_json():
|
|
||||||
return {"action": "created"}
|
|
||||||
|
|
||||||
request.json = mock_json()
|
|
||||||
|
|
||||||
response = await adapter.webhook_callback(request)
|
|
||||||
assert response == {"status": "ok"}
|
|
||||||
|
|
||||||
# Verify no event was queued for unsupported events
|
|
||||||
assert event_queue.empty()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_issue_closed_ignored(self, adapter, event_queue):
|
|
||||||
"""Test that issue closed action is ignored"""
|
|
||||||
request = MagicMock()
|
|
||||||
request.headers.get.return_value = "issues"
|
|
||||||
payload = {
|
|
||||||
"action": "closed", # Should be ignored
|
|
||||||
"issue": {"title": "Test Issue"},
|
|
||||||
"repository": {"full_name": "test/repo"},
|
|
||||||
"sender": {"login": "testuser"},
|
|
||||||
}
|
|
||||||
|
|
||||||
async def mock_json():
|
|
||||||
return payload
|
|
||||||
|
|
||||||
request.json = mock_json()
|
|
||||||
|
|
||||||
response = await adapter.webhook_callback(request)
|
|
||||||
assert response == {"status": "ok"}
|
|
||||||
|
|
||||||
# Verify no event was queued
|
|
||||||
assert event_queue.empty()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_signature_verification(self, platform_settings, event_queue):
|
|
||||||
"""Test webhook signature verification"""
|
|
||||||
# Create adapter with webhook secret
|
|
||||||
config_with_secret = {
|
|
||||||
"type": "github_webhook",
|
|
||||||
"enable": True,
|
|
||||||
"id": "test_github_webhook",
|
|
||||||
"unified_webhook_mode": True,
|
|
||||||
"webhook_uuid": "test-uuid-123",
|
|
||||||
"webhook_secret": "test-secret",
|
|
||||||
}
|
|
||||||
adapter = GitHubWebhookPlatformAdapter(
|
|
||||||
config_with_secret, platform_settings, event_queue
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a valid signature
|
|
||||||
body = b'{"action": "opened"}'
|
|
||||||
signature = hmac.new(b"test-secret", body, hashlib.sha256).hexdigest()
|
|
||||||
|
|
||||||
# Mock request with valid signature
|
|
||||||
request = MagicMock()
|
|
||||||
request.headers.get = lambda key, default="": {
|
|
||||||
"X-GitHub-Event": "ping",
|
|
||||||
"X-Hub-Signature-256": f"sha256={signature}",
|
|
||||||
}.get(key, default)
|
|
||||||
|
|
||||||
async def mock_get_data():
|
|
||||||
return body
|
|
||||||
|
|
||||||
request.get_data = mock_get_data
|
|
||||||
|
|
||||||
async def mock_json():
|
|
||||||
return {"action": "opened"}
|
|
||||||
|
|
||||||
request.json = mock_json()
|
|
||||||
|
|
||||||
response = await adapter.webhook_callback(request)
|
|
||||||
assert response == {"message": "pong"}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_invalid_signature(self, platform_settings, event_queue):
|
|
||||||
"""Test webhook with invalid signature is rejected"""
|
|
||||||
# Create adapter with webhook secret
|
|
||||||
config_with_secret = {
|
|
||||||
"type": "github_webhook",
|
|
||||||
"enable": True,
|
|
||||||
"id": "test_github_webhook",
|
|
||||||
"unified_webhook_mode": True,
|
|
||||||
"webhook_uuid": "test-uuid-123",
|
|
||||||
"webhook_secret": "test-secret",
|
|
||||||
}
|
|
||||||
adapter = GitHubWebhookPlatformAdapter(
|
|
||||||
config_with_secret, platform_settings, event_queue
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock request with invalid signature
|
|
||||||
request = MagicMock()
|
|
||||||
request.headers.get = lambda key, default="": {
|
|
||||||
"X-GitHub-Event": "ping",
|
|
||||||
"X-Hub-Signature-256": "sha256=invalidsignature",
|
|
||||||
}.get(key, default)
|
|
||||||
|
|
||||||
async def mock_get_data():
|
|
||||||
return b'{"action": "opened"}'
|
|
||||||
|
|
||||||
request.get_data = mock_get_data
|
|
||||||
|
|
||||||
async def mock_json():
|
|
||||||
return {"action": "opened"}
|
|
||||||
|
|
||||||
request.json = mock_json()
|
|
||||||
|
|
||||||
response = await adapter.webhook_callback(request)
|
|
||||||
assert response == ({"error": "Invalid signature"}, 401)
|
|
||||||
@@ -0,0 +1,209 @@
|
|||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from quart import Quart
|
||||||
|
|
||||||
|
from astrbot.core import LogBroker
|
||||||
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||||
|
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||||
|
from astrbot.core.knowledge_base.kb_helper import KBHelper
|
||||||
|
from astrbot.core.knowledge_base.models import KBDocument
|
||||||
|
from astrbot.dashboard.server import AstrBotDashboard
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="module")
|
||||||
|
async def core_lifecycle_td(tmp_path_factory):
|
||||||
|
"""Creates and initializes a core lifecycle instance with a temporary database."""
|
||||||
|
tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_kb.db"
|
||||||
|
db = SQLiteDatabase(str(tmp_db_path))
|
||||||
|
log_broker = LogBroker()
|
||||||
|
core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
|
||||||
|
await core_lifecycle.initialize()
|
||||||
|
|
||||||
|
# Mock kb_manager and kb_helper
|
||||||
|
kb_manager = MagicMock()
|
||||||
|
kb_helper = AsyncMock(spec=KBHelper)
|
||||||
|
|
||||||
|
# Configure get_kb to be an async mock that returns kb_helper
|
||||||
|
kb_manager.get_kb = AsyncMock(return_value=kb_helper)
|
||||||
|
|
||||||
|
# Mock upload_document return value
|
||||||
|
mock_doc = KBDocument(
|
||||||
|
doc_id="test_doc_id",
|
||||||
|
kb_id="test_kb_id",
|
||||||
|
doc_name="test_file.txt",
|
||||||
|
file_type="txt",
|
||||||
|
file_size=100,
|
||||||
|
file_path="",
|
||||||
|
chunk_count=2,
|
||||||
|
media_count=0,
|
||||||
|
)
|
||||||
|
kb_helper.upload_document.return_value = mock_doc
|
||||||
|
|
||||||
|
# kb_manager.get_kb.return_value = kb_helper # Removed this line as it's handled above
|
||||||
|
core_lifecycle.kb_manager = kb_manager
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield core_lifecycle
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
_stop_res = core_lifecycle.stop()
|
||||||
|
if asyncio.iscoroutine(_stop_res):
|
||||||
|
await _stop_res
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def app(core_lifecycle_td: AstrBotCoreLifecycle):
|
||||||
|
"""Creates a Quart app instance for testing."""
|
||||||
|
shutdown_event = asyncio.Event()
|
||||||
|
server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
|
||||||
|
return server.app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="module")
|
||||||
|
async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
|
||||||
|
"""Handles login and returns an authenticated header."""
|
||||||
|
test_client = app.test_client()
|
||||||
|
response = await test_client.post(
|
||||||
|
"/api/auth/login",
|
||||||
|
json={
|
||||||
|
"username": core_lifecycle_td.astrbot_config["dashboard"]["username"],
|
||||||
|
"password": core_lifecycle_td.astrbot_config["dashboard"]["password"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
data = await response.get_json()
|
||||||
|
assert data["status"] == "ok"
|
||||||
|
token = data["data"]["token"]
|
||||||
|
return {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_import_documents(
|
||||||
|
app: Quart, authenticated_header: dict, core_lifecycle_td: AstrBotCoreLifecycle
|
||||||
|
):
|
||||||
|
"""Tests the import documents functionality."""
|
||||||
|
test_client = app.test_client()
|
||||||
|
|
||||||
|
# Test data
|
||||||
|
import_data = {
|
||||||
|
"kb_id": "test_kb_id",
|
||||||
|
"documents": [
|
||||||
|
{"file_name": "test_file_1.txt", "chunks": ["chunk1", "chunk2"]},
|
||||||
|
{"file_name": "test_file_2.md", "chunks": ["chunk3", "chunk4", "chunk5"]},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Send request
|
||||||
|
response = await test_client.post(
|
||||||
|
"/api/kb/document/import", json=import_data, headers=authenticated_header
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify response
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = await response.get_json()
|
||||||
|
assert data["status"] == "ok"
|
||||||
|
assert "task_id" in data["data"]
|
||||||
|
assert data["data"]["doc_count"] == 2
|
||||||
|
|
||||||
|
task_id = data["data"]["task_id"]
|
||||||
|
|
||||||
|
# Wait for background task to complete (mocked)
|
||||||
|
# Since we mocked upload_document, it should be fast, but we might need to poll progress
|
||||||
|
for _ in range(10):
|
||||||
|
progress_response = await test_client.get(
|
||||||
|
f"/api/kb/document/upload/progress?task_id={task_id}",
|
||||||
|
headers=authenticated_header,
|
||||||
|
)
|
||||||
|
progress_data = await progress_response.get_json()
|
||||||
|
if progress_data["data"]["status"] == "completed":
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
assert progress_data["data"]["status"] == "completed"
|
||||||
|
result = progress_data["data"]["result"]
|
||||||
|
assert result["success_count"] == 2
|
||||||
|
assert result["failed_count"] == 0
|
||||||
|
|
||||||
|
# Verify kb_helper.upload_document was called correctly
|
||||||
|
kb_helper = await core_lifecycle_td.kb_manager.get_kb("test_kb_id")
|
||||||
|
assert kb_helper.upload_document.call_count == 2
|
||||||
|
|
||||||
|
# Check first call arguments
|
||||||
|
call_args_list = kb_helper.upload_document.call_args_list
|
||||||
|
|
||||||
|
# First document
|
||||||
|
args1, kwargs1 = call_args_list[0]
|
||||||
|
assert kwargs1["file_name"] == "test_file_1.txt"
|
||||||
|
assert kwargs1["pre_chunked_text"] == ["chunk1", "chunk2"]
|
||||||
|
|
||||||
|
# Second document
|
||||||
|
args2, kwargs2 = call_args_list[1]
|
||||||
|
assert kwargs2["file_name"] == "test_file_2.md"
|
||||||
|
assert kwargs2["pre_chunked_text"] == ["chunk3", "chunk4", "chunk5"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_import_documents_invalid_input(app: Quart, authenticated_header: dict):
|
||||||
|
"""Tests import documents with invalid input."""
|
||||||
|
test_client = app.test_client()
|
||||||
|
|
||||||
|
# Missing kb_id
|
||||||
|
response = await test_client.post(
|
||||||
|
"/api/kb/document/import", json={"documents": []}, headers=authenticated_header
|
||||||
|
)
|
||||||
|
data = await response.get_json()
|
||||||
|
assert data["status"] == "error"
|
||||||
|
assert "缺少参数 kb_id" in data["message"]
|
||||||
|
|
||||||
|
# Missing documents
|
||||||
|
response = await test_client.post(
|
||||||
|
"/api/kb/document/import",
|
||||||
|
json={"kb_id": "test_kb"},
|
||||||
|
headers=authenticated_header,
|
||||||
|
)
|
||||||
|
data = await response.get_json()
|
||||||
|
assert data["status"] == "error"
|
||||||
|
assert "缺少参数 documents" in data["message"]
|
||||||
|
|
||||||
|
# Invalid document format
|
||||||
|
response = await test_client.post(
|
||||||
|
"/api/kb/document/import",
|
||||||
|
json={
|
||||||
|
"kb_id": "test_kb",
|
||||||
|
"documents": [{"file_name": "test"}], # Missing chunks
|
||||||
|
},
|
||||||
|
headers=authenticated_header,
|
||||||
|
)
|
||||||
|
data = await response.get_json()
|
||||||
|
assert data["status"] == "error"
|
||||||
|
assert "文档格式错误" in data["message"]
|
||||||
|
|
||||||
|
# Invalid chunks type
|
||||||
|
response = await test_client.post(
|
||||||
|
"/api/kb/document/import",
|
||||||
|
json={
|
||||||
|
"kb_id": "test_kb",
|
||||||
|
"documents": [{"file_name": "test", "chunks": "not-a-list"}],
|
||||||
|
},
|
||||||
|
headers=authenticated_header,
|
||||||
|
)
|
||||||
|
data = await response.get_json()
|
||||||
|
assert data["status"] == "error"
|
||||||
|
assert "chunks 必须是列表" in data["message"]
|
||||||
|
|
||||||
|
# Invalid chunks content
|
||||||
|
response = await test_client.post(
|
||||||
|
"/api/kb/document/import",
|
||||||
|
json={
|
||||||
|
"kb_id": "test_kb",
|
||||||
|
"documents": [{"file_name": "test", "chunks": ["valid", ""]}],
|
||||||
|
},
|
||||||
|
headers=authenticated_header,
|
||||||
|
)
|
||||||
|
data = await response.get_json()
|
||||||
|
assert data["status"] == "error"
|
||||||
|
assert "chunks 必须是非空字符串列表" in data["message"]
|
||||||
Reference in New Issue
Block a user