Compare commits

..

5 Commits

Author SHA1 Message Date
copilot-swe-agent[bot] a2fe0ec5a1 Add webhook signature verification for security
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 14:27:51 +00:00
copilot-swe-agent[bot] 6957ec713d Clean up unused imports in tests
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 14:24:18 +00:00
copilot-swe-agent[bot] d97c8b5b2b Add tests for GitHub webhook platform adapter
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 14:23:22 +00:00
copilot-swe-agent[bot] d07a1ad5c9 Add GitHub webhook platform adapter with event handlers
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 14:20:33 +00:00
copilot-swe-agent[bot] d8e6dfbd6b Initial plan 2025-12-12 14:14:49 +00:00
35 changed files with 844 additions and 969 deletions
-6
View File
@@ -243,10 +243,4 @@ pre-commit install
</details> </details>
<div align="center">
_私は、高性能ですから!_ _私は、高性能ですから!_
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
</div
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.9.2" __version__ = "4.8.0"
+1 -14
View File
@@ -4,7 +4,7 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.9.2" VERSION = "4.8.0"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [ WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -108,7 +108,6 @@ DEFAULT_CONFIG = {
"provider_id": "", "provider_id": "",
"dual_output": False, "dual_output": False,
"use_file_service": False, "use_file_service": False,
"trigger_probability": 1.0,
}, },
"provider_ltm_settings": { "provider_ltm_settings": {
"group_icl_enable": False, "group_icl_enable": False,
@@ -2210,9 +2209,6 @@ CONFIG_METADATA_2 = {
"use_file_service": { "use_file_service": {
"type": "bool", "type": "bool",
}, },
"trigger_probability": {
"type": "float",
},
}, },
}, },
"provider_ltm_settings": { "provider_ltm_settings": {
@@ -2423,14 +2419,6 @@ CONFIG_METADATA_3 = {
"provider_tts_settings.enable": True, "provider_tts_settings.enable": True,
}, },
}, },
"provider_tts_settings.trigger_probability": {
"description": "TTS 触发概率",
"type": "float",
"slider": {"min": 0, "max": 1, "step": 0.05},
"condition": {
"provider_tts_settings.enable": True,
},
},
"provider_settings.image_caption_prompt": { "provider_settings.image_caption_prompt": {
"description": "图片转述提示词", "description": "图片转述提示词",
"type": "text", "type": "text",
@@ -2998,7 +2986,6 @@ CONFIG_METADATA_3 = {
"description": "回复概率", "description": "回复概率",
"type": "float", "type": "float",
"hint": "0.0-1.0 之间的数值", "hint": "0.0-1.0 之间的数值",
"slider": {"min": 0, "max": 1, "step": 0.05},
"condition": { "condition": {
"provider_ltm_settings.active_reply.enable": True, "provider_ltm_settings.active_reply.enable": True,
}, },
-1
View File
@@ -79,7 +79,6 @@ class ConfigMetadataI18n:
"_special", "_special",
"invisible", "invisible",
"options", "options",
"slider",
]: ]:
if attr in field_data: if attr in field_data:
field_result[attr] = field_data[attr] field_result[attr] = field_data[attr]
+1 -2
View File
@@ -24,7 +24,6 @@ import asyncio
import logging import logging
import os import os
import sys import sys
import time
from asyncio import Queue from asyncio import Queue
from collections import deque from collections import deque
@@ -149,7 +148,7 @@ class LogQueueHandler(logging.Handler):
self.log_broker.publish( self.log_broker.publish(
{ {
"level": record.levelname, "level": record.levelname,
"time": time.time(), "time": record.asctime,
"data": log_entry, "data": log_entry,
}, },
) )
-4
View File
@@ -158,11 +158,7 @@ class RespondStage(Stage):
result = event.get_result() result = event.get_result()
if result is None: if result is None:
return return
if event.get_extra("_streaming_finished", False):
# prevent some plugin make result content type to LLM_RESULT after streaming finished, lead to send again
return
if result.result_content_type == ResultContentType.STREAMING_FINISH: if result.result_content_type == ResultContentType.STREAMING_FINISH:
event.set_extra("_streaming_finished", True)
return return
logger.info( logger.info(
+1 -21
View File
@@ -1,4 +1,3 @@
import random
import re import re
import time import time
import traceback import traceback
@@ -43,18 +42,6 @@ class ResultDecorateStage(Stage):
"forward_threshold" "forward_threshold"
] ]
trigger_probability = ctx.astrbot_config["provider_tts_settings"].get(
"trigger_probability",
1,
)
try:
self.tts_trigger_probability = max(
0.0,
min(float(trigger_probability), 1.0),
)
except (TypeError, ValueError):
self.tts_trigger_probability = 1.0
# 分段回复 # 分段回复
self.words_count_threshold = int( self.words_count_threshold = int(
ctx.astrbot_config["platform_settings"]["segmented_reply"][ ctx.astrbot_config["platform_settings"]["segmented_reply"][
@@ -259,14 +246,7 @@ class ResultDecorateStage(Stage):
and result.is_llm_result() and result.is_llm_result()
and SessionServiceManager.should_process_tts_request(event) and SessionServiceManager.should_process_tts_request(event)
): ):
should_tts = self.tts_trigger_probability >= 1.0 or ( if not tts_provider:
self.tts_trigger_probability > 0.0
and random.random() <= self.tts_trigger_probability
)
if not should_tts:
logger.debug("跳过 TTS:触发概率未命中。")
elif not tts_provider:
logger.warning( logger.warning(
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。", f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
) )
+4
View File
@@ -112,6 +112,10 @@ class PlatformManager:
from .sources.satori.satori_adapter import ( from .sources.satori.satori_adapter import (
SatoriPlatformAdapter, # noqa: F401 SatoriPlatformAdapter, # noqa: F401
) )
case "github_webhook":
from .sources.github_webhook.github_webhook_adapter import (
GitHubWebhookPlatformAdapter, # noqa: F401
)
except (ImportError, ModuleNotFoundError) as e: except (ImportError, ModuleNotFoundError) as e:
logger.error( logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。", f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
@@ -0,0 +1,315 @@
import asyncio
import hashlib
import hmac
from typing import Any, cast
from astrbot import logger
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain
from astrbot.api.platform import (
AstrBotMessage,
MessageMember,
MessageType,
Platform,
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.platform.platform import PlatformStatus
from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
from .github_webhook_event import GitHubWebhookMessageEvent
@register_platform_adapter(
"github_webhook",
"GitHub Webhook 适配器",
support_streaming_message=False,
)
class GitHubWebhookPlatformAdapter(Platform):
"""GitHub Webhook 平台适配器
支持的事件:
- issues (created)
- issue_comment (created)
- pull_request (opened)
"""
def __init__(
self,
platform_config: dict,
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(platform_config, event_queue)
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", True)
self.webhook_secret = platform_config.get("webhook_secret", "")
self.shutdown_event = asyncio.Event()
async def send_by_session(
self,
session: MessageSesion,
message_chain: MessageChain,
):
"""GitHub Webhook 是单向接收,不支持主动发送消息"""
logger.warning("GitHub Webhook 适配器不支持 send_by_session")
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="github_webhook",
description="GitHub Webhook 适配器",
id=cast(str, self.config.get("id")),
)
async def run(self):
"""运行适配器"""
self.status = PlatformStatus.RUNNING
# 如果启用统一 webhook 模式
webhook_uuid = self.config.get("webhook_uuid")
if self.unified_webhook_mode and webhook_uuid:
log_webhook_info(f"{self.meta().id}(GitHub Webhook)", webhook_uuid)
# 保持运行状态,等待 shutdown
await self.shutdown_event.wait()
else:
logger.warning("GitHub Webhook 适配器需要启用统一 webhook 模式")
await self.shutdown_event.wait()
async def webhook_callback(self, request: Any) -> Any:
"""统一 Webhook 回调入口
处理 GitHub webhook 事件
Args:
request: Quart 请求对象
Returns:
响应数据
"""
try:
# 获取事件类型
event_type = request.headers.get("X-GitHub-Event", "")
# 获取请求数据
payload = await request.json
# 验证 webhook 签名(如果配置了 secret
if self.webhook_secret:
if not await self._verify_signature(request, payload):
logger.warning("GitHub webhook 签名验证失败")
return {"error": "Invalid signature"}, 401
logger.debug(f"收到 GitHub Webhook 事件: {event_type}")
# 处理不同类型的事件
if event_type == "issues":
await self._handle_issue_event(payload)
elif event_type == "issue_comment":
await self._handle_issue_comment_event(payload)
elif event_type == "pull_request":
await self._handle_pull_request_event(payload)
elif event_type == "ping":
# GitHub webhook 验证事件
return {"message": "pong"}
else:
logger.debug(f"忽略不支持的 GitHub 事件类型: {event_type}")
return {"status": "ok"}
except Exception as e:
logger.error(f"处理 GitHub webhook 回调时发生错误: {e}", exc_info=True)
return {"error": str(e)}, 500
async def _verify_signature(self, request: Any, payload: dict) -> bool:
"""验证 GitHub webhook 签名
Args:
request: Quart 请求对象
payload: 请求负载数据
Returns:
签名是否有效
"""
signature_header = request.headers.get("X-Hub-Signature-256", "")
if not signature_header:
# 如果没有签名头,检查是否有旧版本的签名
signature_header = request.headers.get("X-Hub-Signature", "")
if not signature_header:
return False
# 获取原始请求体
body = await request.get_data()
# 计算 HMAC
if signature_header.startswith("sha256="):
expected_signature = hmac.new(
self.webhook_secret.encode("utf-8"),
body,
hashlib.sha256,
).hexdigest()
received_signature = signature_header.replace("sha256=", "")
elif signature_header.startswith("sha1="):
expected_signature = hmac.new(
self.webhook_secret.encode("utf-8"),
body,
hashlib.sha1,
).hexdigest()
received_signature = signature_header.replace("sha1=", "")
else:
return False
# 使用 hmac.compare_digest 防止时序攻击
return hmac.compare_digest(expected_signature, received_signature)
async def _handle_issue_event(self, payload: dict):
"""处理 issue 事件"""
action = payload.get("action", "")
# 只处理创建事件
if action != "created" and action != "opened":
return
issue = payload.get("issue", {})
repo = payload.get("repository", {})
sender = payload.get("sender", {})
# 构造消息文本
message_text = (
f"📝 新 Issue 创建\n"
f"仓库: {repo.get('full_name', 'unknown')}\n"
f"标题: {issue.get('title', 'No title')}\n"
f"作者: {sender.get('login', 'unknown')}\n"
f"链接: {issue.get('html_url', '')}\n"
f"内容:\n{issue.get('body', 'No description')[:200]}"
)
# 创建 AstrBotMessage
abm = self._create_message(
message_text,
sender.get("login", "unknown"),
sender.get("login", "unknown"),
repo.get("full_name", "unknown"),
)
# 提交事件
self.commit_event(
GitHubWebhookMessageEvent(
message_text,
abm,
self.meta(),
repo.get("full_name", "unknown"),
"issues",
payload,
)
)
async def _handle_issue_comment_event(self, payload: dict):
"""处理 issue 评论事件"""
action = payload.get("action", "")
# 只处理创建事件
if action != "created":
return
issue = payload.get("issue", {})
comment = payload.get("comment", {})
repo = payload.get("repository", {})
sender = payload.get("sender", {})
# 构造消息文本
message_text = (
f"💬 新 Issue 评论\n"
f"仓库: {repo.get('full_name', 'unknown')}\n"
f"Issue: {issue.get('title', 'No title')}\n"
f"评论者: {sender.get('login', 'unknown')}\n"
f"链接: {comment.get('html_url', '')}\n"
f"内容:\n{comment.get('body', 'No comment')[:200]}"
)
# 创建 AstrBotMessage
abm = self._create_message(
message_text,
sender.get("login", "unknown"),
sender.get("login", "unknown"),
repo.get("full_name", "unknown"),
)
# 提交事件
self.commit_event(
GitHubWebhookMessageEvent(
message_text,
abm,
self.meta(),
repo.get("full_name", "unknown"),
"issue_comment",
payload,
)
)
async def _handle_pull_request_event(self, payload: dict):
"""处理 pull request 事件"""
action = payload.get("action", "")
# 只处理打开事件
if action != "opened":
return
pr = payload.get("pull_request", {})
repo = payload.get("repository", {})
sender = payload.get("sender", {})
# 构造消息文本
message_text = (
f"🔀 新 Pull Request\n"
f"仓库: {repo.get('full_name', 'unknown')}\n"
f"标题: {pr.get('title', 'No title')}\n"
f"作者: {sender.get('login', 'unknown')}\n"
f"链接: {pr.get('html_url', '')}\n"
f"内容:\n{pr.get('body', 'No description')[:200]}"
)
# 创建 AstrBotMessage
abm = self._create_message(
message_text,
sender.get("login", "unknown"),
sender.get("login", "unknown"),
repo.get("full_name", "unknown"),
)
# 提交事件
self.commit_event(
GitHubWebhookMessageEvent(
message_text,
abm,
self.meta(),
repo.get("full_name", "unknown"),
"pull_request",
payload,
)
)
def _create_message(
self,
message_text: str,
user_id: str,
nickname: str,
session_id: str,
) -> AstrBotMessage:
"""创建 AstrBotMessage 对象"""
abm = AstrBotMessage()
abm.type = MessageType.GROUP_MESSAGE
abm.self_id = self.client_self_id
abm.session_id = session_id
abm.message_id = ""
abm.sender = MessageMember(user_id=user_id, nickname=nickname)
abm.message = [Plain(message_text)]
abm.message_str = message_text
abm.raw_message = message_text
return abm
async def terminate(self):
"""终止适配器运行"""
self.shutdown_event.set()
logger.info("GitHub Webhook 适配器已经被优雅地关闭")
@@ -0,0 +1,22 @@
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from ...astr_message_event import AstrMessageEvent
class GitHubWebhookMessageEvent(AstrMessageEvent):
"""GitHub Webhook 消息事件"""
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
event_type: str,
event_data: dict,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.event_type = event_type
"""GitHub 事件类型: issues, issue_comment, pull_request"""
self.event_data = event_data
"""原始事件数据"""
@@ -81,12 +81,7 @@ class LarkPlatformAdapter(Platform):
) )
self.lark_api = ( self.lark_api = (
lark.Client.builder() lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build()
.app_id(self.appid)
.app_secret(self.appsecret)
.log_level(lark.LogLevel.ERROR)
.domain(self.domain)
.build()
) )
self.webhook_server = None self.webhook_server = None
+1 -5
View File
@@ -2,19 +2,15 @@ from astrbot.core import html_renderer
from astrbot.core.provider import Provider from astrbot.core.provider import Provider
from astrbot.core.star.star_tools import StarTools from astrbot.core.star.star_tools import StarTools
from astrbot.core.utils.command_parser import CommandParserMixin from astrbot.core.utils.command_parser import CommandParserMixin
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
from .context import Context from .context import Context
from .star import StarMetadata, star_map, star_registry from .star import StarMetadata, star_map, star_registry
from .star_manager import PluginManager from .star_manager import PluginManager
class Star(CommandParserMixin, PluginKVStoreMixin): class Star(CommandParserMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类""" """所有插件(Star)的父类,所有插件都应该继承于这个类"""
author: str
name: str
def __init__(self, context: Context, config: dict | None = None): def __init__(self, context: Context, config: dict | None = None):
StarTools.initialize(context) StarTools.initialize(context)
self.context = context self.context = context
-12
View File
@@ -467,18 +467,6 @@ class PluginManager:
metadata.star_cls = metadata.star_cls_type( metadata.star_cls = metadata.star_cls_type(
context=self.context, context=self.context,
) )
p_name = (metadata.name or "unknown").lower().replace("/", "_")
p_author = (
(metadata.author or "unknown").lower().replace("/", "_")
)
setattr(metadata.star_cls, "name", p_name)
setattr(metadata.star_cls, "author", p_author)
setattr(
metadata.star_cls,
"plugin_id",
f"{p_author}/{p_name}",
)
else: else:
logger.info(f"插件 {metadata.name} 已被禁用。") logger.info(f"插件 {metadata.name} 已被禁用。")
-28
View File
@@ -1,28 +0,0 @@
from typing import TypeVar
from astrbot.core import sp
SUPPORTED_VALUE_TYPES = int | float | str | bytes | bool | dict | list | None
_VT = TypeVar("_VT")
class PluginKVStoreMixin:
"""为插件提供键值存储功能的 Mixin 类"""
plugin_id: str
async def put_kv_data(
self,
key: str,
value: SUPPORTED_VALUE_TYPES,
) -> None:
"""为指定插件存储一个键值对"""
await sp.put_async("plugin", self.plugin_id, key, value)
async def get_kv_data(self, key: str, default: _VT) -> _VT | None:
"""获取指定插件存储的键值对"""
return await sp.get_async("plugin", self.plugin_id, key, default)
async def delete_kv_data(self, key: str) -> None:
"""删除指定插件存储的键值对"""
await sp.remove_async("plugin", self.plugin_id, key)
+1 -91
View File
@@ -1,9 +1,7 @@
import json import json
import traceback import traceback
from datetime import datetime
from io import BytesIO
from quart import request, send_file from quart import request
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
@@ -32,7 +30,6 @@ class ConversationRoute(Route):
"POST", "POST",
self.update_history, self.update_history,
), ),
"/conversation/export": ("POST", self.export_conversations),
} }
self.db_helper = db_helper self.db_helper = db_helper
self.conv_mgr = core_lifecycle.conversation_manager self.conv_mgr = core_lifecycle.conversation_manager
@@ -286,90 +283,3 @@ class ConversationRoute(Route):
except Exception as e: except Exception as e:
logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}") logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"更新对话历史失败: {e!s}").__dict__ return Response().error(f"更新对话历史失败: {e!s}").__dict__
async def export_conversations(self):
"""批量导出对话为 JSONL 格式"""
try:
data = await request.get_json()
conversations_to_export = data.get("conversations", [])
if not conversations_to_export:
return Response().error("导出列表不能为空").__dict__
# 收集所有对话的内容
jsonl_lines = []
exported_count = 0
failed_items = []
for conv_info in conversations_to_export:
user_id = conv_info.get("user_id")
cid = conv_info.get("cid")
if not user_id or not cid:
failed_items.append(
f"user_id:{user_id}, cid:{cid} - 缺少必要参数",
)
continue
try:
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
if not conversation:
failed_items.append(
f"user_id:{user_id}, cid:{cid} - 对话不存在"
)
continue
# 解析对话内容 (history is always a JSON string from _convert_conv_from_v2_to_v1)
content = json.loads(conversation.history)
# 创建导出记录
export_record = {
"cid": cid,
"user_id": user_id,
"platform_id": conversation.platform_id,
"title": conversation.title,
"persona_id": conversation.persona_id,
"created_at": conversation.created_at,
"updated_at": conversation.updated_at,
"content": content,
}
# 将记录转换为 JSON 字符串并添加到 JSONL
jsonl_lines.append(json.dumps(export_record, ensure_ascii=False))
exported_count += 1
except Exception as e:
failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}")
logger.error(
f"导出对话失败: user_id={user_id}, cid={cid}, error={e!s}"
)
if exported_count == 0:
return Response().error("没有成功导出任何对话").__dict__
# 创建 JSONL 内容
jsonl_content = "\n".join(jsonl_lines)
# 创建一个内存文件对象
file_obj = BytesIO(jsonl_content.encode("utf-8"))
file_obj.seek(0)
# 生成文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"astrbot_conversations_export_{timestamp}.jsonl"
# 返回文件流
return await send_file(
file_obj,
mimetype="application/jsonl",
as_attachment=True,
attachment_filename=filename,
)
except Exception as e:
logger.error(f"批量导出对话失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"批量导出对话失败: {e!s}").__dict__
+78 -253
View File
@@ -48,7 +48,6 @@ class KnowledgeBaseRoute(Route):
# 文档管理 # 文档管理
"/kb/document/list": ("GET", self.list_documents), "/kb/document/list": ("GET", self.list_documents),
"/kb/document/upload": ("POST", self.upload_document), "/kb/document/upload": ("POST", self.upload_document),
"/kb/document/import": ("POST", self.import_documents),
"/kb/document/upload/url": ("POST", self.upload_document_from_url), "/kb/document/upload/url": ("POST", self.upload_document_from_url),
"/kb/document/upload/progress": ("GET", self.get_upload_progress), "/kb/document/upload/progress": ("GET", self.get_upload_progress),
"/kb/document/get": ("GET", self.get_document), "/kb/document/get": ("GET", self.get_document),
@@ -67,65 +66,6 @@ class KnowledgeBaseRoute(Route):
def _get_kb_manager(self): def _get_kb_manager(self):
return self.core_lifecycle.kb_manager return self.core_lifecycle.kb_manager
def _init_task(self, task_id: str, status: str = "pending") -> None:
self.upload_tasks[task_id] = {
"status": status,
"result": None,
"error": None,
}
def _set_task_result(
self, task_id: str, status: str, result: any = None, error: str | None = None
) -> None:
self.upload_tasks[task_id] = {
"status": status,
"result": result,
"error": error,
}
if task_id in self.upload_progress:
self.upload_progress[task_id]["status"] = status
def _update_progress(
self,
task_id: str,
*,
status: str | None = None,
file_index: int | None = None,
file_name: str | None = None,
stage: str | None = None,
current: int | None = None,
total: int | None = None,
) -> None:
if task_id not in self.upload_progress:
return
p = self.upload_progress[task_id]
if status is not None:
p["status"] = status
if file_index is not None:
p["file_index"] = file_index
if file_name is not None:
p["file_name"] = file_name
if stage is not None:
p["stage"] = stage
if current is not None:
p["current"] = current
if total is not None:
p["total"] = total
def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str):
async def _callback(stage: str, current: int, total: int):
self._update_progress(
task_id,
status="processing",
file_index=file_idx,
file_name=file_name,
stage=stage,
current=current,
total=total,
)
return _callback
async def _background_upload_task( async def _background_upload_task(
self, self,
task_id: str, task_id: str,
@@ -140,7 +80,11 @@ class KnowledgeBaseRoute(Route):
"""后台上传任务""" """后台上传任务"""
try: try:
# 初始化任务状态 # 初始化任务状态
self._init_task(task_id, status="processing") self.upload_tasks[task_id] = {
"status": "processing",
"result": None,
"error": None,
}
self.upload_progress[task_id] = { self.upload_progress[task_id] = {
"status": "processing", "status": "processing",
"file_index": 0, "file_index": 0,
@@ -156,20 +100,30 @@ class KnowledgeBaseRoute(Route):
for file_idx, file_info in enumerate(files_to_upload): for file_idx, file_info in enumerate(files_to_upload):
try: try:
# 更新整体进度 # 更新整体进度
self._update_progress( self.upload_progress[task_id].update(
task_id, {
status="processing", "status": "processing",
file_index=file_idx, "file_index": file_idx,
file_name=file_info["file_name"], "file_name": file_info["file_name"],
stage="parsing", "stage": "parsing",
current=0, "current": 0,
total=100, "total": 100,
},
) )
# 创建进度回调函数 # 创建进度回调函数
progress_callback = self._make_progress_callback( async def progress_callback(stage, current, total):
task_id, file_idx, file_info["file_name"] if task_id in self.upload_progress:
) self.upload_progress[task_id].update(
{
"status": "processing",
"file_index": file_idx,
"file_name": file_info["file_name"],
"stage": stage,
"current": current,
"total": total,
},
)
doc = await kb_helper.upload_document( doc = await kb_helper.upload_document(
file_name=file_info["file_name"], file_name=file_info["file_name"],
@@ -200,99 +154,23 @@ class KnowledgeBaseRoute(Route):
"failed_count": len(failed_docs), "failed_count": len(failed_docs),
} }
self._set_task_result(task_id, "completed", result=result) self.upload_tasks[task_id] = {
"status": "completed",
"result": result,
"error": None,
}
self.upload_progress[task_id]["status"] = "completed"
except Exception as e: except Exception as e:
logger.error(f"后台上传任务 {task_id} 失败: {e}") logger.error(f"后台上传任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(e)) self.upload_tasks[task_id] = {
"status": "failed",
async def _background_import_task( "result": None,
self, "error": str(e),
task_id: str,
kb_helper,
documents: list,
batch_size: int,
tasks_limit: int,
max_retries: int,
):
"""后台导入预切片文档任务"""
try:
# 初始化任务状态
self._init_task(task_id, status="processing")
self.upload_progress[task_id] = {
"status": "processing",
"file_index": 0,
"file_total": len(documents),
"stage": "waiting",
"current": 0,
"total": 100,
} }
if task_id in self.upload_progress:
uploaded_docs = [] self.upload_progress[task_id]["status"] = "failed"
failed_docs = []
for file_idx, doc_info in enumerate(documents):
file_name = doc_info.get("file_name", f"imported_doc_{file_idx}")
chunks = doc_info.get("chunks", [])
try:
# 更新整体进度
self._update_progress(
task_id,
status="processing",
file_index=file_idx,
file_name=file_name,
stage="importing",
current=0,
total=100,
)
# 创建进度回调函数
progress_callback = self._make_progress_callback(
task_id, file_idx, file_name
)
# 调用 upload_document,传入 pre_chunked_text
doc = await kb_helper.upload_document(
file_name=file_name,
file_content=None, # 预切片模式下不需要原始内容
file_type=doc_info.get("file_type")
or (
file_name.rsplit(".", 1)[-1].lower()
if "." in file_name
else "txt"
),
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
pre_chunked_text=chunks,
)
uploaded_docs.append(doc.model_dump())
except Exception as e:
logger.error(f"导入文档 {file_name} 失败: {e}")
failed_docs.append(
{"file_name": file_name, "error": str(e)},
)
# 更新任务完成状态
result = {
"task_id": task_id,
"uploaded": uploaded_docs,
"failed": failed_docs,
"total": len(documents),
"success_count": len(uploaded_docs),
"failed_count": len(failed_docs),
}
self._set_task_result(task_id, "completed", result=result)
except Exception as e:
logger.error(f"后台导入任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(e))
async def list_kbs(self): async def list_kbs(self):
"""获取知识库列表 """获取知识库列表
@@ -736,7 +614,11 @@ class KnowledgeBaseRoute(Route):
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
# 初始化任务状态 # 初始化任务状态
self._init_task(task_id, status="pending") self.upload_tasks[task_id] = {
"status": "pending",
"result": None,
"error": None,
}
# 启动后台任务 # 启动后台任务
asyncio.create_task( asyncio.create_task(
@@ -771,93 +653,6 @@ class KnowledgeBaseRoute(Route):
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return Response().error(f"上传文档失败: {e!s}").__dict__ return Response().error(f"上传文档失败: {e!s}").__dict__
def _validate_import_request(self, data: dict):
kb_id = data.get("kb_id")
if not kb_id:
raise ValueError("缺少参数 kb_id")
documents = data.get("documents")
if not documents or not isinstance(documents, list):
raise ValueError("缺少参数 documents 或格式错误")
for doc in documents:
if "file_name" not in doc or "chunks" not in doc:
raise ValueError("文档格式错误,必须包含 file_name 和 chunks")
if not isinstance(doc["chunks"], list):
raise ValueError("chunks 必须是列表")
if not all(
isinstance(chunk, str) and chunk.strip() for chunk in doc["chunks"]
):
raise ValueError("chunks 必须是非空字符串列表")
batch_size = data.get("batch_size", 32)
tasks_limit = data.get("tasks_limit", 3)
max_retries = data.get("max_retries", 3)
return kb_id, documents, batch_size, tasks_limit, max_retries
async def import_documents(self):
"""导入预切片文档
Body:
- kb_id: 知识库 ID (必填)
- documents: 文档列表 (必填)
- file_name: 文件名 (必填)
- chunks: 切片列表 (必填, list[str])
- file_type: 文件类型 (可选, 默认从文件名推断或为 txt)
- batch_size: 批处理大小 (可选, 默认32)
- tasks_limit: 并发任务限制 (可选, 默认3)
- max_retries: 最大重试次数 (可选, 默认3)
"""
try:
kb_manager = self._get_kb_manager()
data = await request.json
kb_id, documents, batch_size, tasks_limit, max_retries = (
self._validate_import_request(data)
)
# 获取知识库
kb_helper = await kb_manager.get_kb(kb_id)
if not kb_helper:
return Response().error("知识库不存在").__dict__
# 生成任务ID
task_id = str(uuid.uuid4())
# 初始化任务状态
self._init_task(task_id, status="pending")
# 启动后台任务
asyncio.create_task(
self._background_import_task(
task_id=task_id,
kb_helper=kb_helper,
documents=documents,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
),
)
return (
Response()
.ok(
{
"task_id": task_id,
"doc_count": len(documents),
"message": "import task created, processing in background",
},
)
.__dict__
)
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"导入文档失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"导入文档失败: {e!s}").__dict__
async def get_upload_progress(self): async def get_upload_progress(self):
"""获取上传进度和结果 """获取上传进度和结果
@@ -1165,7 +960,11 @@ class KnowledgeBaseRoute(Route):
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
# 初始化任务状态 # 初始化任务状态
self._init_task(task_id, status="pending") self.upload_tasks[task_id] = {
"status": "pending",
"result": None,
"error": None,
}
# 启动后台任务 # 启动后台任务
asyncio.create_task( asyncio.create_task(
@@ -1218,7 +1017,11 @@ class KnowledgeBaseRoute(Route):
"""后台上传URL任务""" """后台上传URL任务"""
try: try:
# 初始化任务状态 # 初始化任务状态
self._init_task(task_id, status="processing") self.upload_tasks[task_id] = {
"status": "processing",
"result": None,
"error": None,
}
self.upload_progress[task_id] = { self.upload_progress[task_id] = {
"status": "processing", "status": "processing",
"file_index": 0, "file_index": 0,
@@ -1230,7 +1033,18 @@ class KnowledgeBaseRoute(Route):
} }
# 创建进度回调函数 # 创建进度回调函数
progress_callback = self._make_progress_callback(task_id, 0, f"URL: {url}") async def progress_callback(stage, current, total):
if task_id in self.upload_progress:
self.upload_progress[task_id].update(
{
"status": "processing",
"file_index": 0,
"file_name": f"URL: {url}",
"stage": stage,
"current": current,
"total": total,
},
)
# 上传文档 # 上传文档
doc = await kb_helper.upload_from_url( doc = await kb_helper.upload_from_url(
@@ -1255,9 +1069,20 @@ class KnowledgeBaseRoute(Route):
"failed_count": 0, "failed_count": 0,
} }
self._set_task_result(task_id, "completed", result=result) self.upload_tasks[task_id] = {
"status": "completed",
"result": result,
"error": None,
}
self.upload_progress[task_id]["status"] = "completed"
except Exception as e: except Exception as e:
logger.error(f"后台上传URL任务 {task_id} 失败: {e}") logger.error(f"后台上传URL任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(e)) self.upload_tasks[task_id] = {
"status": "failed",
"result": None,
"error": str(e),
}
if task_id in self.upload_progress:
self.upload_progress[task_id]["status"] = "failed"
+1 -5
View File
@@ -124,11 +124,7 @@ class PluginRoute(Route):
session.get(url) as response, session.get(url) as response,
): ):
if response.status == 200: if response.status == 200:
try: remote_data = await response.json()
remote_data = await response.json()
except aiohttp.ContentTypeError:
remote_text = await response.text()
remote_data = json.loads(remote_text)
# 检查远程数据是否为空 # 检查远程数据是否为空
if not remote_data or ( if not remote_data or (
-19
View File
@@ -1,19 +0,0 @@
## What's Changed
### 新增
- 支持自定义插件源。
- 支持飞书(Lark)的 Webhook 模式(将事件推送至开发者服务器)。
- 支持 “禁用自带指令” 快捷配置项,启用后将禁用所有 AstrBot 自带指令。入口: WebUI -> 配置文件 -> 平台配置。
### 优化
- 从 WebUI 移除了开发版本渠道。
- 当试图测试"Agent Runner"时,提示前往配置文件页测试。
- WebUI 列表项支持批量粘贴、回车创建项目。
### 修复
- Gemini API 部分调用失败的问题。
- WebUI 插件安装加载 Dialog 关闭按钮在手机端下显示异常的问题。
- 部分情况下,WebUI 日志显示不全的问题。
-3
View File
@@ -1,3 +0,0 @@
## What's Changed
-
-17
View File
@@ -1,17 +0,0 @@
## What's Changed
### 修复
- 企业自部署飞书(自定义 domain)可以接收消息但无法发送消息的问题。
- 安装插件 Dialog 的深色样式问题。
### 优化
- 避免某些插件在流式响应结束后重d复发送消息的问题。
### 新增
- 支持在对话管理批量导出对话轨迹数据为 `jsonl` 格式文件。入口:WebUI -> 对话管理 -> 批量选中 -> 导出。
- 支持对 TTS(文本转语音)设置概率触发。
- (插件开发)支持在 schema 中对 float 和 int 类型设置 `slider` 滑块控件。例如 `slider: {min: 0, max: 1, step: 0.1}`
- (插件开发)支持 key-value 存储功能。例如使用 `await self.put_kv_data("key", value)`, `await self.get_kv_data("key", default_value)``await self.delete_kv_data("key")`
Binary file not shown.

Before

Width:  |  Height:  |  Size: 12 KiB

@@ -304,32 +304,16 @@ function hasVisibleItemsAfter(items, currentIndex) {
hide-details hide-details
></v-text-field> ></v-text-field>
<!-- Numeric input with optional slider --> <!-- Numeric input -->
<div <v-text-field
v-else-if="(metadata[metadataKey].items[key]?.type === 'int' || metadata[metadataKey].items[key]?.type === 'float') && !metadata[metadataKey]?.invisible" v-else-if="(metadata[metadataKey].items[key]?.type === 'int' || metadata[metadataKey].items[key]?.type === 'float') && !metadata[metadataKey]?.invisible"
class="d-flex align-center gap-3" v-model="iterable[key]"
> density="compact"
<v-slider variant="outlined"
v-if="metadata[metadataKey].items[key]?.slider" class="config-field"
v-model.number="iterable[key]" type="number"
:min="metadata[metadataKey].items[key]?.slider?.min ?? 0" hide-details
:max="metadata[metadataKey].items[key]?.slider?.max ?? 100" ></v-text-field>
:step="metadata[metadataKey].items[key]?.slider?.step ?? 1"
color="primary"
density="compact"
hide-details
class="flex-grow-1"
></v-slider>
<v-text-field
v-model.number="iterable[key]"
density="compact"
variant="outlined"
class="config-field"
type="number"
hide-details
style="max-width: 140px;"
></v-text-field>
</div>
<!-- Text area --> <!-- Text area -->
<v-textarea <v-textarea
@@ -429,32 +413,16 @@ function hasVisibleItemsAfter(items, currentIndex) {
hide-details hide-details
></v-text-field> ></v-text-field>
<!-- Numeric input with optional slider --> <!-- Numeric input -->
<div <v-text-field
v-else-if="(metadata[metadataKey]?.type === 'int' || metadata[metadataKey]?.type === 'float') && !metadata[metadataKey]?.invisible" v-else-if="(metadata[metadataKey]?.type === 'int' || metadata[metadataKey]?.type === 'float') && !metadata[metadataKey]?.invisible"
class="d-flex align-center gap-3" v-model="iterable[metadataKey]"
> density="compact"
<v-slider variant="outlined"
v-if="metadata[metadataKey]?.slider" class="config-field"
v-model.number="iterable[metadataKey]" type="number"
:min="metadata[metadataKey]?.slider?.min ?? 0" hide-details
:max="metadata[metadataKey]?.slider?.max ?? 100" ></v-text-field>
:step="metadata[metadataKey]?.slider?.step ?? 1"
color="primary"
density="compact"
hide-details
class="flex-grow-1"
></v-slider>
<v-text-field
v-model.number="iterable[metadataKey]"
density="compact"
variant="outlined"
class="config-field"
type="number"
hide-details
style="max-width: 140px;"
></v-text-field>
</div>
<!-- Text area --> <!-- Text area -->
<v-textarea <v-textarea
@@ -245,29 +245,10 @@ function getSpecialSubtype(value) {
<v-text-field v-else-if="itemMeta?.type === 'string'" v-model="createSelectorModel(itemKey).value" <v-text-field v-else-if="itemMeta?.type === 'string'" v-model="createSelectorModel(itemKey).value"
density="compact" variant="outlined" class="config-field" hide-details></v-text-field> density="compact" variant="outlined" class="config-field" hide-details></v-text-field>
<!-- Numeric input with optional slider for JSON selector --> <!-- Numeric input for JSON selector -->
<div v-else-if="itemMeta?.type === 'int' || itemMeta?.type === 'float'" class="d-flex align-center gap-3"> <v-text-field v-else-if="itemMeta?.type === 'int' || itemMeta?.type === 'float'"
<v-slider v-model="createSelectorModel(itemKey).value" density="compact" variant="outlined" class="config-field"
v-if="itemMeta?.slider" type="number" hide-details></v-text-field>
v-model.number="createSelectorModel(itemKey).value"
:min="itemMeta?.slider?.min ?? 0"
:max="itemMeta?.slider?.max ?? 100"
:step="itemMeta?.slider?.step ?? 1"
color="primary"
density="compact"
hide-details
style="flex: 3"
></v-slider>
<v-text-field
v-model.number="createSelectorModel(itemKey).value"
density="compact"
variant="outlined"
class="config-field"
style="flex: 2"
type="number"
hide-details
></v-text-field>
</div>
<!-- Text area for JSON selector --> <!-- Text area for JSON selector -->
<v-textarea v-else-if="itemMeta?.type === 'text'" v-model="createSelectorModel(itemKey).value" <v-textarea v-else-if="itemMeta?.type === 'text'" v-model="createSelectorModel(itemKey).value"
@@ -1,7 +1,6 @@
<script setup> <script setup>
import { useCommonStore } from '@/stores/common'; import { useCommonStore } from '@/stores/common';
import { storeToRefs } from 'pinia'; import { storeToRefs } from 'pinia';
import axios from 'axios';
</script> </script>
<template> <template>
@@ -25,6 +24,8 @@ import axios from 'axios';
export default { export default {
name: 'ConsoleDisplayer', name: 'ConsoleDisplayer',
data() { data() {
const commonStore = useCommonStore();
const { log_cache } = storeToRefs(commonStore);
return { return {
autoScroll: true, // autoScroll: true, //
logColorAnsiMap: { logColorAnsiMap: {
@@ -37,6 +38,7 @@ export default {
'\u001b[32m': 'color: #00FF00;', // green '\u001b[32m': 'color: #00FF00;', // green
'default': 'color: #FFFFFF;' 'default': 'color: #FFFFFF;'
}, },
logCache: log_cache,
historyNum_: -1, historyNum_: -1,
logLevels: ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], logLevels: ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
selectedLevels: [0, 1, 2, 3, 4], // selectedLevels: [0, 1, 2, 3, 4], //
@@ -46,17 +48,7 @@ export default {
'WARNING': 'amber', 'WARNING': 'amber',
'ERROR': 'red', 'ERROR': 'red',
'CRITICAL': 'purple' 'CRITICAL': 'purple'
}, }
lastProcessedTime: 0, //
localLogCache: [], //
}
},
computed: {
commonStore() {
return useCommonStore();
},
logCache() {
return this.commonStore.log_cache;
} }
}, },
props: { props: {
@@ -71,39 +63,13 @@ export default {
}, },
watch: { watch: {
logCache: { logCache: {
handler(newVal) { handler(val) {
// timestamp const lastLog = val[this.logCache.length - 1];
if (newVal && newVal.length > 0) { if (lastLog && this.isLevelSelected(lastLog.level)) {
// DOM this.printLog(lastLog.data);
this.$nextTick(() => {
//
const newLogs = newVal.filter(log => log.time > this.lastProcessedTime);
if (newLogs.length > 0) {
this.localLogCache.push(...newLogs);
//
this.localLogCache.sort((a, b) => a.time - b.time);
// log_cache_max_len
if (this.localLogCache.length > this.commonStore.log_cache_max_len) {
this.localLogCache.splice(0, this.localLogCache.length - this.commonStore.log_cache_max_len);
}
//
newLogs.forEach(logItem => {
if (this.isLevelSelected(logItem.level)) {
this.printLog(logItem.data);
}
});
//
this.lastProcessedTime = Math.max(...newLogs.map(log => log.time));
}
});
} }
}, },
deep: true, deep: true
immediate: false
}, },
selectedLevels: { selectedLevels: {
handler() { handler() {
@@ -112,37 +78,14 @@ export default {
deep: true deep: true
} }
}, },
async mounted() { mounted() {
// if (this.logCache.length === 0) {
await this.fetchLogHistory(); this.delayInit()
} else {
// DOM this.init()
this.$nextTick(() => { }
if (this.localLogCache.length > 0) {
this.localLogCache.forEach(logItem => {
if (this.isLevelSelected(logItem.level)) {
this.printLog(logItem.data);
}
});
//
this.lastProcessedTime = Math.max(...this.localLogCache.map(log => log.time));
}
});
}, },
methods: { methods: {
async fetchLogHistory() {
try {
const res = await axios.get('/api/log-history');
if (res.data.data.logs && res.data.data.logs.length > 0) {
this.localLogCache = [...res.data.data.logs];
//
this.localLogCache.sort((a, b) => a.time - b.time);
}
} catch (err) {
console.error('Failed to fetch log history:', err);
}
},
getLevelColor(level) { getLevelColor(level) {
return this.levelColors[level] || 'grey'; return this.levelColors[level] || 'grey';
}, },
@@ -158,21 +101,40 @@ export default {
}, },
refreshDisplay() { refreshDisplay() {
//
const termElement = document.getElementById('term'); const termElement = document.getElementById('term');
if (termElement) { if (termElement) {
termElement.innerHTML = ''; termElement.innerHTML = '';
}
//
if (this.localLogCache && this.localLogCache.length > 0) { //
this.localLogCache.forEach(logItem => { this.init();
if (this.isLevelSelected(logItem.level)) { },
this.printLog(logItem.data);
} delayInit() {
}); if (this.logCache.length === 0) {
} setTimeout(() => {
this.delayInit()
}, 500)
} else {
this.init()
} }
}, },
init() {
this.historyNum_ = parseInt(this.historyNum)
let i = 0
for (let log of this.logCache) {
if (this.isLevelSelected(log.level)) { //
if (this.historyNum_ != -1 && i >= this.logCache.length - this.historyNum_) {
this.printLog(log.data)
++i
} else if (this.historyNum_ == -1) {
this.printLog(log.data)
}
}
}
},
toggleAutoScroll() { toggleAutoScroll() {
this.autoScroll = !this.autoScroll; this.autoScroll = !this.autoScroll;
@@ -181,11 +143,6 @@ export default {
printLog(log) { printLog(log) {
// append span termblock // append span termblock
let ele = document.getElementById('term') let ele = document.getElementById('term')
if (!ele) {
console.warn('term element not found, skipping log print');
return;
}
let span = document.createElement('pre') let span = document.createElement('pre')
let style = this.logColorAnsiMap['default'] let style = this.logColorAnsiMap['default']
for (let key in this.logColorAnsiMap) { for (let key in this.logColorAnsiMap) {
@@ -115,7 +115,7 @@ const _show = computed({
</script> </script>
<template> <template>
<v-dialog v-model="_show" width="800"> <v-dialog v-model="_show" width="800" persistent>
<v-card> <v-card>
<v-card-title class="d-flex justify-space-between align-center"> <v-card-title class="d-flex justify-space-between align-center">
<span class="text-h5">{{ t('core.common.readme.title') }}</span> <span class="text-h5">{{ t('core.common.readme.title') }}</span>
@@ -57,9 +57,6 @@
}, },
"provider_id": { "provider_id": {
"description": "Default Text-to-Speech Model" "description": "Default Text-to-Speech Model"
},
"trigger_probability": {
"description": "TTS Trigger Probability"
} }
} }
}, },
@@ -13,8 +13,7 @@
"refresh": "Refresh" "refresh": "Refresh"
}, },
"batch": { "batch": {
"deleteSelected": "Delete Selected ({count})", "deleteSelected": "Delete Selected ({count})"
"exportSelected": "Export Selected ({count})"
}, },
"pagination": { "pagination": {
"itemsPerPage": "Items per page", "itemsPerPage": "Items per page",
@@ -77,8 +76,7 @@
"message": "Are you sure you want to delete the selected {count} conversations? This action cannot be undone, please proceed with caution!", "message": "Are you sure you want to delete the selected {count} conversations? This action cannot be undone, please proceed with caution!",
"andMore": "and {count} more", "andMore": "and {count} more",
"cancel": "Cancel", "cancel": "Cancel",
"confirm": "Batch Delete", "confirm": "Batch Delete"
"warning": "Warning: This action cannot be undone!"
} }
}, },
"messages": { "messages": {
@@ -94,9 +92,6 @@
"noItemSelected": "Please select conversations to delete first", "noItemSelected": "Please select conversations to delete first",
"batchDeleteSuccess": "Successfully deleted {count} conversations", "batchDeleteSuccess": "Successfully deleted {count} conversations",
"batchDeleteError": "Batch delete failed", "batchDeleteError": "Batch delete failed",
"batchDeletePartial": "Delete completed: {deleted} successful, {failed} failed", "batchDeletePartial": "Delete completed: {deleted} successful, {failed} failed"
"exportSuccess": "Export successful",
"exportError": "Export failed",
"noItemSelectedForExport": "Please select conversations to export first"
} }
} }
@@ -62,9 +62,6 @@
}, },
"provider_id": { "provider_id": {
"description": "默认文本转语音模型" "description": "默认文本转语音模型"
},
"trigger_probability": {
"description": "TTS 触发概率"
} }
} }
}, },
@@ -13,8 +13,7 @@
"refresh": "刷新" "refresh": "刷新"
}, },
"batch": { "batch": {
"deleteSelected": "删除选中 ({count})", "deleteSelected": "删除选中 ({count})"
"exportSelected": "导出选中 ({count})"
}, },
"pagination": { "pagination": {
"itemsPerPage": "每页", "itemsPerPage": "每页",
@@ -77,8 +76,7 @@
"message": "确定要删除选中的 {count} 个对话吗?此操作不可恢复,请谨慎操作!", "message": "确定要删除选中的 {count} 个对话吗?此操作不可恢复,请谨慎操作!",
"andMore": "等 {count} 个", "andMore": "等 {count} 个",
"cancel": "取消", "cancel": "取消",
"confirm": "批量删除", "confirm": "批量删除"
"warning": "警告:此操作不可撤销!"
} }
}, },
"messages": { "messages": {
@@ -94,9 +92,6 @@
"noItemSelected": "请先选择要删除的对话", "noItemSelected": "请先选择要删除的对话",
"batchDeleteSuccess": "成功删除 {count} 个对话", "batchDeleteSuccess": "成功删除 {count} 个对话",
"batchDeleteError": "批量删除失败", "batchDeleteError": "批量删除失败",
"batchDeletePartial": "删除完成:成功 {deleted} 个,失败 {failed} 个", "batchDeletePartial": "删除完成:成功 {deleted} 个,失败 {failed} 个"
"exportSuccess": "导出成功",
"exportError": "导出失败",
"noItemSelectedForExport": "请先选择要导出的对话"
} }
} }
+64 -30
View File
@@ -16,6 +16,21 @@ export const useCommonStore = defineStore({
}), }),
actions: { actions: {
async createEventSource() { async createEventSource() {
const fetchLogHistory = async () => {
try {
const res = await axios.get('/api/log-history');
if (res.data.data.logs) {
this.log_cache.push(...res.data.data.logs);
} else {
this.log_cache = [];
}
} catch (err) {
console.error('Failed to fetch log history:', err);
}
};
await fetchLogHistory();
if (this.eventSource) { if (this.eventSource) {
return return
} }
@@ -39,9 +54,25 @@ export const useCommonStore = defineStore({
const reader = response.body.getReader(); const reader = response.body.getReader();
const decoder = new TextDecoder(); const decoder = new TextDecoder();
let bufferedText = '';
let incompleteLine = ""; // 用于存储不完整的行
const handleIncompleteLine = (line) => {
incompleteLine += line;
// if can parse as JSON, return it
try {
const data_json = JSON.parse(incompleteLine);
incompleteLine = ""; // 清空不完整行
return data_json;
} catch (e) {
return null;
}
}
const processStream = ({ done, value }) => { const processStream = ({ done, value }) => {
// get bytes length
const bytesLength = value ? value.byteLength : 0;
console.log(`Received ${bytesLength} bytes from live log`);
if (done) { if (done) {
console.log('SSE stream closed'); console.log('SSE stream closed');
setTimeout(() => { setTimeout(() => {
@@ -51,41 +82,44 @@ export const useCommonStore = defineStore({
return; return;
} }
// Accumulate partial chunks; SSE data may split JSON across reads. const text = decoder.decode(value);
const text = decoder.decode(value, { stream: true }); const lines = text.split('\n\n');
bufferedText += text; lines.forEach(line => {
if (!line.trim()) {
// Split completed events; keep the trailing partial in buffer.
const segments = bufferedText.split('\n\n');
bufferedText = segments.pop() || '';
segments.forEach(segment => {
const line = segment.trim();
if (!line.startsWith('data: ')) {
return; return;
} }
if (line.startsWith('data:')) {
const logLine = line.replace('data: ', '').trim(); const data = line.substring(5).trim();
if (!logLine) { // {"type":"log","data":"[2021-08-01 00:00:00] INFO: Hello, world!"}
return; let data_json = {}
} try {
data_json = JSON.parse(data);
try { } catch (e) {
const logObject = JSON.parse(logLine); console.warn('Invalid JSON:', data);
// give a uuid if not exists // 尝试处理不完整的行
if (!logObject.uuid) { const parsedData = handleIncompleteLine(data);
logObject.uuid = crypto.randomUUID(); if (parsedData) {
data_json = parsedData;
} else {
return; // 如果无法解析,跳过当前行
}
} }
this.log_cache.push(logObject); if (data_json.type === 'log') {
// Limit log cache size this.log_cache.push(data_json);
if (this.log_cache.length > this.log_cache_max_len) { if (this.log_cache.length > this.log_cache_max_len) {
this.log_cache.splice(0, this.log_cache.length - this.log_cache_max_len); this.log_cache.shift();
}
}
} else {
const parsedData = handleIncompleteLine(line);
if (parsedData && parsedData.type === 'log') {
this.log_cache.push(parsedData);
if (this.log_cache.length > this.log_cache_max_len) {
this.log_cache.shift();
}
} }
} catch (err) {
console.warn('Failed to parse SSE log line, skipping:', err, logLine);
} }
}); });
return reader.read().then(processStream); return reader.read().then(processStream);
}; };
-58
View File
@@ -40,17 +40,6 @@
:loading="loading" size="small" class="mr-2"> :loading="loading" size="small" class="mr-2">
{{ tm('history.refresh') }} {{ tm('history.refresh') }}
</v-btn> </v-btn>
<v-btn
v-if="selectedItems.length > 0"
color="success"
prepend-icon="mdi-download"
variant="tonal"
@click="exportConversations"
:disabled="loading"
size="small"
class="mr-2">
{{ tm('batch.exportSelected', { count: selectedItems.length }) }}
</v-btn>
<v-btn <v-btn
v-if="selectedItems.length > 0" v-if="selectedItems.length > 0"
color="error" color="error"
@@ -921,53 +910,6 @@ export default {
} }
}, },
//
async exportConversations() {
if (this.selectedItems.length === 0) {
this.showErrorMessage(this.tm('messages.noItemSelectedForExport'));
return;
}
this.loading = true;
try {
//
const conversations = this.selectedItems.map(item => ({
user_id: item.user_id,
cid: item.cid
}));
const response = await axios.post('/api/conversation/export', {
conversations: conversations
}, {
responseType: 'blob' // axios blob
});
//
const url = window.URL.createObjectURL(response.data);
const link = document.createElement('a');
link.href = url;
// 使
const timestamp = new Date().toISOString().replace(/[:.]/g, '-').slice(0, -5);
const filename = `conversations_export_${timestamp}.jsonl`;
link.setAttribute('download', filename);
document.body.appendChild(link);
link.click();
//
link.remove();
window.URL.revokeObjectURL(url);
this.showSuccessMessage(this.tm('messages.exportSuccess'));
} catch (error) {
console.error(this.tm('messages.exportError'), error);
this.showErrorMessage(error.response?.data?.message || error.message || this.tm('messages.exportError'));
} finally {
this.loading = false;
}
},
// //
formatTimestamp(timestamp) { formatTimestamp(timestamp) {
if (!timestamp) return this.tm('status.unknown'); if (!timestamp) return this.tm('status.unknown');
+1 -1
View File
@@ -1568,7 +1568,7 @@ watch(marketSearch, (newVal) => {
<!-- 上传插件对话框 --> <!-- 上传插件对话框 -->
<v-dialog v-model="dialog" width="500"> <v-dialog v-model="dialog" width="500">
<div class="v-card v-card--density-default rounded-lg v-card--variant-elevated"> <div class="v-card v-theme--PurpleThemeDark v-card--density-default rounded-lg v-card--variant-elevated">
<div class="v-card__loader"> <div class="v-card__loader">
<v-progress-linear :indeterminate="loading_" color="primary" height="2" :active="loading_"></v-progress-linear> <v-progress-linear :indeterminate="loading_" color="primary" height="2" :active="loading_"></v-progress-linear>
</div> </div>
+1 -1
View File
@@ -1,6 +1,6 @@
[project] [project]
name = "AstrBot" name = "AstrBot"
version = "4.9.2" version = "4.8.0"
description = "Easy-to-use multi-platform LLM chatbot and development framework" description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"
+279
View File
@@ -0,0 +1,279 @@
"""Test GitHub webhook platform adapter"""
import asyncio
import hashlib
import hmac
from unittest.mock import MagicMock
import pytest
from astrbot.core.platform.sources.github_webhook.github_webhook_adapter import (
GitHubWebhookPlatformAdapter,
)
@pytest.fixture
def event_queue():
"""Create a test event queue"""
return asyncio.Queue()
@pytest.fixture
def platform_config():
"""Create test platform configuration"""
return {
"type": "github_webhook",
"enable": True,
"id": "test_github_webhook",
"unified_webhook_mode": True,
"webhook_uuid": "test-uuid-123",
"webhook_secret": "", # No secret by default for easier testing
}
@pytest.fixture
def platform_settings():
"""Create test platform settings"""
return {"unique_session": False}
@pytest.fixture
def adapter(platform_config, platform_settings, event_queue):
"""Create test adapter instance"""
return GitHubWebhookPlatformAdapter(platform_config, platform_settings, event_queue)
class TestGitHubWebhookAdapter:
"""Test cases for GitHub webhook adapter"""
def test_adapter_initialization(self, adapter):
"""Test adapter is initialized correctly"""
assert adapter.unified_webhook_mode is True
assert adapter.webhook_secret == ""
assert adapter.meta().name == "github_webhook"
assert adapter.meta().description == "GitHub Webhook 适配器"
@pytest.mark.asyncio
async def test_ping_event(self, adapter):
"""Test GitHub ping event"""
# Mock request
request = MagicMock()
request.headers.get.return_value = "ping"
async def mock_json():
return {}
request.json = mock_json()
response = await adapter.webhook_callback(request)
assert response == {"message": "pong"}
@pytest.mark.asyncio
async def test_issue_created_event(self, adapter, event_queue):
"""Test GitHub issue created event"""
# Mock request with issue created payload
request = MagicMock()
request.headers.get.return_value = "issues"
payload = {
"action": "opened",
"issue": {
"title": "Test Issue",
"body": "This is a test issue",
"html_url": "https://github.com/test/repo/issues/1",
},
"repository": {"full_name": "test/repo"},
"sender": {"login": "testuser"},
}
async def mock_json():
return payload
request.json = mock_json()
response = await adapter.webhook_callback(request)
assert response == {"status": "ok"}
# Verify event was queued
assert not event_queue.empty()
event = event_queue.get_nowait()
assert event.event_type == "issues"
assert "新 Issue 创建" in event.message_str
assert "Test Issue" in event.message_str
@pytest.mark.asyncio
async def test_issue_comment_event(self, adapter, event_queue):
"""Test GitHub issue comment event"""
request = MagicMock()
request.headers.get.return_value = "issue_comment"
payload = {
"action": "created",
"issue": {"title": "Test Issue"},
"comment": {
"body": "Test comment",
"html_url": "https://github.com/test/repo/issues/1#comment",
},
"repository": {"full_name": "test/repo"},
"sender": {"login": "commenter"},
}
async def mock_json():
return payload
request.json = mock_json()
response = await adapter.webhook_callback(request)
assert response == {"status": "ok"}
# Verify event was queued
assert not event_queue.empty()
event = event_queue.get_nowait()
assert event.event_type == "issue_comment"
assert "新 Issue 评论" in event.message_str
assert "Test comment" in event.message_str
@pytest.mark.asyncio
async def test_pull_request_event(self, adapter, event_queue):
"""Test GitHub pull request opened event"""
request = MagicMock()
request.headers.get.return_value = "pull_request"
payload = {
"action": "opened",
"pull_request": {
"title": "Test PR",
"body": "This is a test PR",
"html_url": "https://github.com/test/repo/pull/1",
},
"repository": {"full_name": "test/repo"},
"sender": {"login": "prauthor"},
}
async def mock_json():
return payload
request.json = mock_json()
response = await adapter.webhook_callback(request)
assert response == {"status": "ok"}
# Verify event was queued
assert not event_queue.empty()
event = event_queue.get_nowait()
assert event.event_type == "pull_request"
assert "新 Pull Request" in event.message_str
assert "Test PR" in event.message_str
@pytest.mark.asyncio
async def test_unsupported_event(self, adapter, event_queue):
"""Test unsupported GitHub event type"""
request = MagicMock()
request.headers.get.return_value = "push"
async def mock_json():
return {"action": "created"}
request.json = mock_json()
response = await adapter.webhook_callback(request)
assert response == {"status": "ok"}
# Verify no event was queued for unsupported events
assert event_queue.empty()
@pytest.mark.asyncio
async def test_issue_closed_ignored(self, adapter, event_queue):
"""Test that issue closed action is ignored"""
request = MagicMock()
request.headers.get.return_value = "issues"
payload = {
"action": "closed", # Should be ignored
"issue": {"title": "Test Issue"},
"repository": {"full_name": "test/repo"},
"sender": {"login": "testuser"},
}
async def mock_json():
return payload
request.json = mock_json()
response = await adapter.webhook_callback(request)
assert response == {"status": "ok"}
# Verify no event was queued
assert event_queue.empty()
@pytest.mark.asyncio
async def test_signature_verification(self, platform_settings, event_queue):
"""Test webhook signature verification"""
# Create adapter with webhook secret
config_with_secret = {
"type": "github_webhook",
"enable": True,
"id": "test_github_webhook",
"unified_webhook_mode": True,
"webhook_uuid": "test-uuid-123",
"webhook_secret": "test-secret",
}
adapter = GitHubWebhookPlatformAdapter(
config_with_secret, platform_settings, event_queue
)
# Create a valid signature
body = b'{"action": "opened"}'
signature = hmac.new(b"test-secret", body, hashlib.sha256).hexdigest()
# Mock request with valid signature
request = MagicMock()
request.headers.get = lambda key, default="": {
"X-GitHub-Event": "ping",
"X-Hub-Signature-256": f"sha256={signature}",
}.get(key, default)
async def mock_get_data():
return body
request.get_data = mock_get_data
async def mock_json():
return {"action": "opened"}
request.json = mock_json()
response = await adapter.webhook_callback(request)
assert response == {"message": "pong"}
@pytest.mark.asyncio
async def test_invalid_signature(self, platform_settings, event_queue):
"""Test webhook with invalid signature is rejected"""
# Create adapter with webhook secret
config_with_secret = {
"type": "github_webhook",
"enable": True,
"id": "test_github_webhook",
"unified_webhook_mode": True,
"webhook_uuid": "test-uuid-123",
"webhook_secret": "test-secret",
}
adapter = GitHubWebhookPlatformAdapter(
config_with_secret, platform_settings, event_queue
)
# Mock request with invalid signature
request = MagicMock()
request.headers.get = lambda key, default="": {
"X-GitHub-Event": "ping",
"X-Hub-Signature-256": "sha256=invalidsignature",
}.get(key, default)
async def mock_get_data():
return b'{"action": "opened"}'
request.get_data = mock_get_data
async def mock_json():
return {"action": "opened"}
request.json = mock_json()
response = await adapter.webhook_callback(request)
assert response == ({"error": "Invalid signature"}, 401)
-209
View File
@@ -1,209 +0,0 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
import pytest_asyncio
from quart import Quart
from astrbot.core import LogBroker
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.knowledge_base.kb_helper import KBHelper
from astrbot.core.knowledge_base.models import KBDocument
from astrbot.dashboard.server import AstrBotDashboard
@pytest_asyncio.fixture(scope="module")
async def core_lifecycle_td(tmp_path_factory):
"""Creates and initializes a core lifecycle instance with a temporary database."""
tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_kb.db"
db = SQLiteDatabase(str(tmp_db_path))
log_broker = LogBroker()
core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
await core_lifecycle.initialize()
# Mock kb_manager and kb_helper
kb_manager = MagicMock()
kb_helper = AsyncMock(spec=KBHelper)
# Configure get_kb to be an async mock that returns kb_helper
kb_manager.get_kb = AsyncMock(return_value=kb_helper)
# Mock upload_document return value
mock_doc = KBDocument(
doc_id="test_doc_id",
kb_id="test_kb_id",
doc_name="test_file.txt",
file_type="txt",
file_size=100,
file_path="",
chunk_count=2,
media_count=0,
)
kb_helper.upload_document.return_value = mock_doc
# kb_manager.get_kb.return_value = kb_helper # Removed this line as it's handled above
core_lifecycle.kb_manager = kb_manager
try:
yield core_lifecycle
finally:
try:
_stop_res = core_lifecycle.stop()
if asyncio.iscoroutine(_stop_res):
await _stop_res
except Exception:
pass
@pytest.fixture(scope="module")
def app(core_lifecycle_td: AstrBotCoreLifecycle):
"""Creates a Quart app instance for testing."""
shutdown_event = asyncio.Event()
server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
return server.app
@pytest_asyncio.fixture(scope="module")
async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
"""Handles login and returns an authenticated header."""
test_client = app.test_client()
response = await test_client.post(
"/api/auth/login",
json={
"username": core_lifecycle_td.astrbot_config["dashboard"]["username"],
"password": core_lifecycle_td.astrbot_config["dashboard"]["password"],
},
)
data = await response.get_json()
assert data["status"] == "ok"
token = data["data"]["token"]
return {"Authorization": f"Bearer {token}"}
@pytest.mark.asyncio
async def test_import_documents(
app: Quart, authenticated_header: dict, core_lifecycle_td: AstrBotCoreLifecycle
):
"""Tests the import documents functionality."""
test_client = app.test_client()
# Test data
import_data = {
"kb_id": "test_kb_id",
"documents": [
{"file_name": "test_file_1.txt", "chunks": ["chunk1", "chunk2"]},
{"file_name": "test_file_2.md", "chunks": ["chunk3", "chunk4", "chunk5"]},
],
}
# Send request
response = await test_client.post(
"/api/kb/document/import", json=import_data, headers=authenticated_header
)
# Verify response
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert "task_id" in data["data"]
assert data["data"]["doc_count"] == 2
task_id = data["data"]["task_id"]
# Wait for background task to complete (mocked)
# Since we mocked upload_document, it should be fast, but we might need to poll progress
for _ in range(10):
progress_response = await test_client.get(
f"/api/kb/document/upload/progress?task_id={task_id}",
headers=authenticated_header,
)
progress_data = await progress_response.get_json()
if progress_data["data"]["status"] == "completed":
break
await asyncio.sleep(0.1)
assert progress_data["data"]["status"] == "completed"
result = progress_data["data"]["result"]
assert result["success_count"] == 2
assert result["failed_count"] == 0
# Verify kb_helper.upload_document was called correctly
kb_helper = await core_lifecycle_td.kb_manager.get_kb("test_kb_id")
assert kb_helper.upload_document.call_count == 2
# Check first call arguments
call_args_list = kb_helper.upload_document.call_args_list
# First document
args1, kwargs1 = call_args_list[0]
assert kwargs1["file_name"] == "test_file_1.txt"
assert kwargs1["pre_chunked_text"] == ["chunk1", "chunk2"]
# Second document
args2, kwargs2 = call_args_list[1]
assert kwargs2["file_name"] == "test_file_2.md"
assert kwargs2["pre_chunked_text"] == ["chunk3", "chunk4", "chunk5"]
@pytest.mark.asyncio
async def test_import_documents_invalid_input(app: Quart, authenticated_header: dict):
"""Tests import documents with invalid input."""
test_client = app.test_client()
# Missing kb_id
response = await test_client.post(
"/api/kb/document/import", json={"documents": []}, headers=authenticated_header
)
data = await response.get_json()
assert data["status"] == "error"
assert "缺少参数 kb_id" in data["message"]
# Missing documents
response = await test_client.post(
"/api/kb/document/import",
json={"kb_id": "test_kb"},
headers=authenticated_header,
)
data = await response.get_json()
assert data["status"] == "error"
assert "缺少参数 documents" in data["message"]
# Invalid document format
response = await test_client.post(
"/api/kb/document/import",
json={
"kb_id": "test_kb",
"documents": [{"file_name": "test"}], # Missing chunks
},
headers=authenticated_header,
)
data = await response.get_json()
assert data["status"] == "error"
assert "文档格式错误" in data["message"]
# Invalid chunks type
response = await test_client.post(
"/api/kb/document/import",
json={
"kb_id": "test_kb",
"documents": [{"file_name": "test", "chunks": "not-a-list"}],
},
headers=authenticated_header,
)
data = await response.get_json()
assert data["status"] == "error"
assert "chunks 必须是列表" in data["message"]
# Invalid chunks content
response = await test_client.post(
"/api/kb/document/import",
json={
"kb_id": "test_kb",
"documents": [{"file_name": "test", "chunks": ["valid", ""]}],
},
headers=authenticated_header,
)
data = await response.get_json()
assert data["status"] == "error"
assert "chunks 必须是非空字符串列表" in data["message"]