Compare commits

..

5 Commits

Author SHA1 Message Date
copilot-swe-agent[bot] a2fe0ec5a1 Add webhook signature verification for security
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 14:27:51 +00:00
copilot-swe-agent[bot] 6957ec713d Clean up unused imports in tests
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 14:24:18 +00:00
copilot-swe-agent[bot] d97c8b5b2b Add tests for GitHub webhook platform adapter
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 14:23:22 +00:00
copilot-swe-agent[bot] d07a1ad5c9 Add GitHub webhook platform adapter with event handlers
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 14:20:33 +00:00
copilot-swe-agent[bot] d8e6dfbd6b Initial plan 2025-12-12 14:14:49 +00:00
14 changed files with 809 additions and 602 deletions
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.9.0" __version__ = "4.8.0"
+1 -1
View File
@@ -4,7 +4,7 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.9.0" VERSION = "4.8.0"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [ WEBHOOK_SUPPORTED_PLATFORMS = [
+1 -2
View File
@@ -24,7 +24,6 @@ import asyncio
import logging import logging
import os import os
import sys import sys
import time
from asyncio import Queue from asyncio import Queue
from collections import deque from collections import deque
@@ -149,7 +148,7 @@ class LogQueueHandler(logging.Handler):
self.log_broker.publish( self.log_broker.publish(
{ {
"level": record.levelname, "level": record.levelname,
"time": time.time(), "time": record.asctime,
"data": log_entry, "data": log_entry,
}, },
) )
+4
View File
@@ -112,6 +112,10 @@ class PlatformManager:
from .sources.satori.satori_adapter import ( from .sources.satori.satori_adapter import (
SatoriPlatformAdapter, # noqa: F401 SatoriPlatformAdapter, # noqa: F401
) )
case "github_webhook":
from .sources.github_webhook.github_webhook_adapter import (
GitHubWebhookPlatformAdapter, # noqa: F401
)
except (ImportError, ModuleNotFoundError) as e: except (ImportError, ModuleNotFoundError) as e:
logger.error( logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。", f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
@@ -0,0 +1,315 @@
import asyncio
import hashlib
import hmac
from typing import Any, cast
from astrbot import logger
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain
from astrbot.api.platform import (
AstrBotMessage,
MessageMember,
MessageType,
Platform,
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.platform.platform import PlatformStatus
from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
from .github_webhook_event import GitHubWebhookMessageEvent
@register_platform_adapter(
"github_webhook",
"GitHub Webhook 适配器",
support_streaming_message=False,
)
class GitHubWebhookPlatformAdapter(Platform):
"""GitHub Webhook 平台适配器
支持的事件:
- issues (created)
- issue_comment (created)
- pull_request (opened)
"""
def __init__(
self,
platform_config: dict,
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(platform_config, event_queue)
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", True)
self.webhook_secret = platform_config.get("webhook_secret", "")
self.shutdown_event = asyncio.Event()
async def send_by_session(
self,
session: MessageSesion,
message_chain: MessageChain,
):
"""GitHub Webhook 是单向接收,不支持主动发送消息"""
logger.warning("GitHub Webhook 适配器不支持 send_by_session")
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="github_webhook",
description="GitHub Webhook 适配器",
id=cast(str, self.config.get("id")),
)
async def run(self):
"""运行适配器"""
self.status = PlatformStatus.RUNNING
# 如果启用统一 webhook 模式
webhook_uuid = self.config.get("webhook_uuid")
if self.unified_webhook_mode and webhook_uuid:
log_webhook_info(f"{self.meta().id}(GitHub Webhook)", webhook_uuid)
# 保持运行状态,等待 shutdown
await self.shutdown_event.wait()
else:
logger.warning("GitHub Webhook 适配器需要启用统一 webhook 模式")
await self.shutdown_event.wait()
async def webhook_callback(self, request: Any) -> Any:
"""统一 Webhook 回调入口
处理 GitHub webhook 事件
Args:
request: Quart 请求对象
Returns:
响应数据
"""
try:
# 获取事件类型
event_type = request.headers.get("X-GitHub-Event", "")
# 获取请求数据
payload = await request.json
# 验证 webhook 签名(如果配置了 secret
if self.webhook_secret:
if not await self._verify_signature(request, payload):
logger.warning("GitHub webhook 签名验证失败")
return {"error": "Invalid signature"}, 401
logger.debug(f"收到 GitHub Webhook 事件: {event_type}")
# 处理不同类型的事件
if event_type == "issues":
await self._handle_issue_event(payload)
elif event_type == "issue_comment":
await self._handle_issue_comment_event(payload)
elif event_type == "pull_request":
await self._handle_pull_request_event(payload)
elif event_type == "ping":
# GitHub webhook 验证事件
return {"message": "pong"}
else:
logger.debug(f"忽略不支持的 GitHub 事件类型: {event_type}")
return {"status": "ok"}
except Exception as e:
logger.error(f"处理 GitHub webhook 回调时发生错误: {e}", exc_info=True)
return {"error": str(e)}, 500
async def _verify_signature(self, request: Any, payload: dict) -> bool:
"""验证 GitHub webhook 签名
Args:
request: Quart 请求对象
payload: 请求负载数据
Returns:
签名是否有效
"""
signature_header = request.headers.get("X-Hub-Signature-256", "")
if not signature_header:
# 如果没有签名头,检查是否有旧版本的签名
signature_header = request.headers.get("X-Hub-Signature", "")
if not signature_header:
return False
# 获取原始请求体
body = await request.get_data()
# 计算 HMAC
if signature_header.startswith("sha256="):
expected_signature = hmac.new(
self.webhook_secret.encode("utf-8"),
body,
hashlib.sha256,
).hexdigest()
received_signature = signature_header.replace("sha256=", "")
elif signature_header.startswith("sha1="):
expected_signature = hmac.new(
self.webhook_secret.encode("utf-8"),
body,
hashlib.sha1,
).hexdigest()
received_signature = signature_header.replace("sha1=", "")
else:
return False
# 使用 hmac.compare_digest 防止时序攻击
return hmac.compare_digest(expected_signature, received_signature)
async def _handle_issue_event(self, payload: dict):
"""处理 issue 事件"""
action = payload.get("action", "")
# 只处理创建事件
if action != "created" and action != "opened":
return
issue = payload.get("issue", {})
repo = payload.get("repository", {})
sender = payload.get("sender", {})
# 构造消息文本
message_text = (
f"📝 新 Issue 创建\n"
f"仓库: {repo.get('full_name', 'unknown')}\n"
f"标题: {issue.get('title', 'No title')}\n"
f"作者: {sender.get('login', 'unknown')}\n"
f"链接: {issue.get('html_url', '')}\n"
f"内容:\n{issue.get('body', 'No description')[:200]}"
)
# 创建 AstrBotMessage
abm = self._create_message(
message_text,
sender.get("login", "unknown"),
sender.get("login", "unknown"),
repo.get("full_name", "unknown"),
)
# 提交事件
self.commit_event(
GitHubWebhookMessageEvent(
message_text,
abm,
self.meta(),
repo.get("full_name", "unknown"),
"issues",
payload,
)
)
async def _handle_issue_comment_event(self, payload: dict):
"""处理 issue 评论事件"""
action = payload.get("action", "")
# 只处理创建事件
if action != "created":
return
issue = payload.get("issue", {})
comment = payload.get("comment", {})
repo = payload.get("repository", {})
sender = payload.get("sender", {})
# 构造消息文本
message_text = (
f"💬 新 Issue 评论\n"
f"仓库: {repo.get('full_name', 'unknown')}\n"
f"Issue: {issue.get('title', 'No title')}\n"
f"评论者: {sender.get('login', 'unknown')}\n"
f"链接: {comment.get('html_url', '')}\n"
f"内容:\n{comment.get('body', 'No comment')[:200]}"
)
# 创建 AstrBotMessage
abm = self._create_message(
message_text,
sender.get("login", "unknown"),
sender.get("login", "unknown"),
repo.get("full_name", "unknown"),
)
# 提交事件
self.commit_event(
GitHubWebhookMessageEvent(
message_text,
abm,
self.meta(),
repo.get("full_name", "unknown"),
"issue_comment",
payload,
)
)
async def _handle_pull_request_event(self, payload: dict):
"""处理 pull request 事件"""
action = payload.get("action", "")
# 只处理打开事件
if action != "opened":
return
pr = payload.get("pull_request", {})
repo = payload.get("repository", {})
sender = payload.get("sender", {})
# 构造消息文本
message_text = (
f"🔀 新 Pull Request\n"
f"仓库: {repo.get('full_name', 'unknown')}\n"
f"标题: {pr.get('title', 'No title')}\n"
f"作者: {sender.get('login', 'unknown')}\n"
f"链接: {pr.get('html_url', '')}\n"
f"内容:\n{pr.get('body', 'No description')[:200]}"
)
# 创建 AstrBotMessage
abm = self._create_message(
message_text,
sender.get("login", "unknown"),
sender.get("login", "unknown"),
repo.get("full_name", "unknown"),
)
# 提交事件
self.commit_event(
GitHubWebhookMessageEvent(
message_text,
abm,
self.meta(),
repo.get("full_name", "unknown"),
"pull_request",
payload,
)
)
def _create_message(
self,
message_text: str,
user_id: str,
nickname: str,
session_id: str,
) -> AstrBotMessage:
"""创建 AstrBotMessage 对象"""
abm = AstrBotMessage()
abm.type = MessageType.GROUP_MESSAGE
abm.self_id = self.client_self_id
abm.session_id = session_id
abm.message_id = ""
abm.sender = MessageMember(user_id=user_id, nickname=nickname)
abm.message = [Plain(message_text)]
abm.message_str = message_text
abm.raw_message = message_text
return abm
async def terminate(self):
"""终止适配器运行"""
self.shutdown_event.set()
logger.info("GitHub Webhook 适配器已经被优雅地关闭")
@@ -0,0 +1,22 @@
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from ...astr_message_event import AstrMessageEvent
class GitHubWebhookMessageEvent(AstrMessageEvent):
"""GitHub Webhook 消息事件"""
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
event_type: str,
event_data: dict,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.event_type = event_type
"""GitHub 事件类型: issues, issue_comment, pull_request"""
self.event_data = event_data
"""原始事件数据"""
+78 -253
View File
@@ -48,7 +48,6 @@ class KnowledgeBaseRoute(Route):
# 文档管理 # 文档管理
"/kb/document/list": ("GET", self.list_documents), "/kb/document/list": ("GET", self.list_documents),
"/kb/document/upload": ("POST", self.upload_document), "/kb/document/upload": ("POST", self.upload_document),
"/kb/document/import": ("POST", self.import_documents),
"/kb/document/upload/url": ("POST", self.upload_document_from_url), "/kb/document/upload/url": ("POST", self.upload_document_from_url),
"/kb/document/upload/progress": ("GET", self.get_upload_progress), "/kb/document/upload/progress": ("GET", self.get_upload_progress),
"/kb/document/get": ("GET", self.get_document), "/kb/document/get": ("GET", self.get_document),
@@ -67,65 +66,6 @@ class KnowledgeBaseRoute(Route):
def _get_kb_manager(self): def _get_kb_manager(self):
return self.core_lifecycle.kb_manager return self.core_lifecycle.kb_manager
def _init_task(self, task_id: str, status: str = "pending") -> None:
self.upload_tasks[task_id] = {
"status": status,
"result": None,
"error": None,
}
def _set_task_result(
self, task_id: str, status: str, result: any = None, error: str | None = None
) -> None:
self.upload_tasks[task_id] = {
"status": status,
"result": result,
"error": error,
}
if task_id in self.upload_progress:
self.upload_progress[task_id]["status"] = status
def _update_progress(
self,
task_id: str,
*,
status: str | None = None,
file_index: int | None = None,
file_name: str | None = None,
stage: str | None = None,
current: int | None = None,
total: int | None = None,
) -> None:
if task_id not in self.upload_progress:
return
p = self.upload_progress[task_id]
if status is not None:
p["status"] = status
if file_index is not None:
p["file_index"] = file_index
if file_name is not None:
p["file_name"] = file_name
if stage is not None:
p["stage"] = stage
if current is not None:
p["current"] = current
if total is not None:
p["total"] = total
def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str):
async def _callback(stage: str, current: int, total: int):
self._update_progress(
task_id,
status="processing",
file_index=file_idx,
file_name=file_name,
stage=stage,
current=current,
total=total,
)
return _callback
async def _background_upload_task( async def _background_upload_task(
self, self,
task_id: str, task_id: str,
@@ -140,7 +80,11 @@ class KnowledgeBaseRoute(Route):
"""后台上传任务""" """后台上传任务"""
try: try:
# 初始化任务状态 # 初始化任务状态
self._init_task(task_id, status="processing") self.upload_tasks[task_id] = {
"status": "processing",
"result": None,
"error": None,
}
self.upload_progress[task_id] = { self.upload_progress[task_id] = {
"status": "processing", "status": "processing",
"file_index": 0, "file_index": 0,
@@ -156,20 +100,30 @@ class KnowledgeBaseRoute(Route):
for file_idx, file_info in enumerate(files_to_upload): for file_idx, file_info in enumerate(files_to_upload):
try: try:
# 更新整体进度 # 更新整体进度
self._update_progress( self.upload_progress[task_id].update(
task_id, {
status="processing", "status": "processing",
file_index=file_idx, "file_index": file_idx,
file_name=file_info["file_name"], "file_name": file_info["file_name"],
stage="parsing", "stage": "parsing",
current=0, "current": 0,
total=100, "total": 100,
},
) )
# 创建进度回调函数 # 创建进度回调函数
progress_callback = self._make_progress_callback( async def progress_callback(stage, current, total):
task_id, file_idx, file_info["file_name"] if task_id in self.upload_progress:
) self.upload_progress[task_id].update(
{
"status": "processing",
"file_index": file_idx,
"file_name": file_info["file_name"],
"stage": stage,
"current": current,
"total": total,
},
)
doc = await kb_helper.upload_document( doc = await kb_helper.upload_document(
file_name=file_info["file_name"], file_name=file_info["file_name"],
@@ -200,99 +154,23 @@ class KnowledgeBaseRoute(Route):
"failed_count": len(failed_docs), "failed_count": len(failed_docs),
} }
self._set_task_result(task_id, "completed", result=result) self.upload_tasks[task_id] = {
"status": "completed",
"result": result,
"error": None,
}
self.upload_progress[task_id]["status"] = "completed"
except Exception as e: except Exception as e:
logger.error(f"后台上传任务 {task_id} 失败: {e}") logger.error(f"后台上传任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(e)) self.upload_tasks[task_id] = {
"status": "failed",
async def _background_import_task( "result": None,
self, "error": str(e),
task_id: str,
kb_helper,
documents: list,
batch_size: int,
tasks_limit: int,
max_retries: int,
):
"""后台导入预切片文档任务"""
try:
# 初始化任务状态
self._init_task(task_id, status="processing")
self.upload_progress[task_id] = {
"status": "processing",
"file_index": 0,
"file_total": len(documents),
"stage": "waiting",
"current": 0,
"total": 100,
} }
if task_id in self.upload_progress:
uploaded_docs = [] self.upload_progress[task_id]["status"] = "failed"
failed_docs = []
for file_idx, doc_info in enumerate(documents):
file_name = doc_info.get("file_name", f"imported_doc_{file_idx}")
chunks = doc_info.get("chunks", [])
try:
# 更新整体进度
self._update_progress(
task_id,
status="processing",
file_index=file_idx,
file_name=file_name,
stage="importing",
current=0,
total=100,
)
# 创建进度回调函数
progress_callback = self._make_progress_callback(
task_id, file_idx, file_name
)
# 调用 upload_document,传入 pre_chunked_text
doc = await kb_helper.upload_document(
file_name=file_name,
file_content=None, # 预切片模式下不需要原始内容
file_type=doc_info.get("file_type")
or (
file_name.rsplit(".", 1)[-1].lower()
if "." in file_name
else "txt"
),
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
pre_chunked_text=chunks,
)
uploaded_docs.append(doc.model_dump())
except Exception as e:
logger.error(f"导入文档 {file_name} 失败: {e}")
failed_docs.append(
{"file_name": file_name, "error": str(e)},
)
# 更新任务完成状态
result = {
"task_id": task_id,
"uploaded": uploaded_docs,
"failed": failed_docs,
"total": len(documents),
"success_count": len(uploaded_docs),
"failed_count": len(failed_docs),
}
self._set_task_result(task_id, "completed", result=result)
except Exception as e:
logger.error(f"后台导入任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(e))
async def list_kbs(self): async def list_kbs(self):
"""获取知识库列表 """获取知识库列表
@@ -736,7 +614,11 @@ class KnowledgeBaseRoute(Route):
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
# 初始化任务状态 # 初始化任务状态
self._init_task(task_id, status="pending") self.upload_tasks[task_id] = {
"status": "pending",
"result": None,
"error": None,
}
# 启动后台任务 # 启动后台任务
asyncio.create_task( asyncio.create_task(
@@ -771,93 +653,6 @@ class KnowledgeBaseRoute(Route):
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return Response().error(f"上传文档失败: {e!s}").__dict__ return Response().error(f"上传文档失败: {e!s}").__dict__
def _validate_import_request(self, data: dict):
kb_id = data.get("kb_id")
if not kb_id:
raise ValueError("缺少参数 kb_id")
documents = data.get("documents")
if not documents or not isinstance(documents, list):
raise ValueError("缺少参数 documents 或格式错误")
for doc in documents:
if "file_name" not in doc or "chunks" not in doc:
raise ValueError("文档格式错误,必须包含 file_name 和 chunks")
if not isinstance(doc["chunks"], list):
raise ValueError("chunks 必须是列表")
if not all(
isinstance(chunk, str) and chunk.strip() for chunk in doc["chunks"]
):
raise ValueError("chunks 必须是非空字符串列表")
batch_size = data.get("batch_size", 32)
tasks_limit = data.get("tasks_limit", 3)
max_retries = data.get("max_retries", 3)
return kb_id, documents, batch_size, tasks_limit, max_retries
async def import_documents(self):
"""导入预切片文档
Body:
- kb_id: 知识库 ID (必填)
- documents: 文档列表 (必填)
- file_name: 文件名 (必填)
- chunks: 切片列表 (必填, list[str])
- file_type: 文件类型 (可选, 默认从文件名推断或为 txt)
- batch_size: 批处理大小 (可选, 默认32)
- tasks_limit: 并发任务限制 (可选, 默认3)
- max_retries: 最大重试次数 (可选, 默认3)
"""
try:
kb_manager = self._get_kb_manager()
data = await request.json
kb_id, documents, batch_size, tasks_limit, max_retries = (
self._validate_import_request(data)
)
# 获取知识库
kb_helper = await kb_manager.get_kb(kb_id)
if not kb_helper:
return Response().error("知识库不存在").__dict__
# 生成任务ID
task_id = str(uuid.uuid4())
# 初始化任务状态
self._init_task(task_id, status="pending")
# 启动后台任务
asyncio.create_task(
self._background_import_task(
task_id=task_id,
kb_helper=kb_helper,
documents=documents,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
),
)
return (
Response()
.ok(
{
"task_id": task_id,
"doc_count": len(documents),
"message": "import task created, processing in background",
},
)
.__dict__
)
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"导入文档失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"导入文档失败: {e!s}").__dict__
async def get_upload_progress(self): async def get_upload_progress(self):
"""获取上传进度和结果 """获取上传进度和结果
@@ -1165,7 +960,11 @@ class KnowledgeBaseRoute(Route):
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
# 初始化任务状态 # 初始化任务状态
self._init_task(task_id, status="pending") self.upload_tasks[task_id] = {
"status": "pending",
"result": None,
"error": None,
}
# 启动后台任务 # 启动后台任务
asyncio.create_task( asyncio.create_task(
@@ -1218,7 +1017,11 @@ class KnowledgeBaseRoute(Route):
"""后台上传URL任务""" """后台上传URL任务"""
try: try:
# 初始化任务状态 # 初始化任务状态
self._init_task(task_id, status="processing") self.upload_tasks[task_id] = {
"status": "processing",
"result": None,
"error": None,
}
self.upload_progress[task_id] = { self.upload_progress[task_id] = {
"status": "processing", "status": "processing",
"file_index": 0, "file_index": 0,
@@ -1230,7 +1033,18 @@ class KnowledgeBaseRoute(Route):
} }
# 创建进度回调函数 # 创建进度回调函数
progress_callback = self._make_progress_callback(task_id, 0, f"URL: {url}") async def progress_callback(stage, current, total):
if task_id in self.upload_progress:
self.upload_progress[task_id].update(
{
"status": "processing",
"file_index": 0,
"file_name": f"URL: {url}",
"stage": stage,
"current": current,
"total": total,
},
)
# 上传文档 # 上传文档
doc = await kb_helper.upload_from_url( doc = await kb_helper.upload_from_url(
@@ -1255,9 +1069,20 @@ class KnowledgeBaseRoute(Route):
"failed_count": 0, "failed_count": 0,
} }
self._set_task_result(task_id, "completed", result=result) self.upload_tasks[task_id] = {
"status": "completed",
"result": result,
"error": None,
}
self.upload_progress[task_id]["status"] = "completed"
except Exception as e: except Exception as e:
logger.error(f"后台上传URL任务 {task_id} 失败: {e}") logger.error(f"后台上传URL任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(e)) self.upload_tasks[task_id] = {
"status": "failed",
"result": None,
"error": str(e),
}
if task_id in self.upload_progress:
self.upload_progress[task_id]["status"] = "failed"
-19
View File
@@ -1,19 +0,0 @@
## What's Changed
### 新增
- 支持自定义插件源。
- 支持飞书(Lark)的 Webhook 模式(将事件推送至开发者服务器)。
- 支持 “禁用自带指令” 快捷配置项,启用后将禁用所有 AstrBot 自带指令。入口: WebUI -> 配置文件 -> 平台配置。
### 优化
- 从 WebUI 移除了开发版本渠道。
- 当试图测试"Agent Runner"时,提示前往配置文件页测试。
- WebUI 列表项支持批量粘贴、回车创建项目。
### 修复
- Gemini API 部分调用失败的问题。
- WebUI 插件安装加载 Dialog 关闭按钮在手机端下显示异常的问题。
- 部分情况下,WebUI 日志显示不全的问题。
Binary file not shown.

Before

Width:  |  Height:  |  Size: 12 KiB

@@ -1,7 +1,6 @@
<script setup> <script setup>
import { useCommonStore } from '@/stores/common'; import { useCommonStore } from '@/stores/common';
import { storeToRefs } from 'pinia'; import { storeToRefs } from 'pinia';
import axios from 'axios';
</script> </script>
<template> <template>
@@ -25,6 +24,8 @@ import axios from 'axios';
export default { export default {
name: 'ConsoleDisplayer', name: 'ConsoleDisplayer',
data() { data() {
const commonStore = useCommonStore();
const { log_cache } = storeToRefs(commonStore);
return { return {
autoScroll: true, // autoScroll: true, //
logColorAnsiMap: { logColorAnsiMap: {
@@ -37,6 +38,7 @@ export default {
'\u001b[32m': 'color: #00FF00;', // green '\u001b[32m': 'color: #00FF00;', // green
'default': 'color: #FFFFFF;' 'default': 'color: #FFFFFF;'
}, },
logCache: log_cache,
historyNum_: -1, historyNum_: -1,
logLevels: ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], logLevels: ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
selectedLevels: [0, 1, 2, 3, 4], // selectedLevels: [0, 1, 2, 3, 4], //
@@ -46,17 +48,7 @@ export default {
'WARNING': 'amber', 'WARNING': 'amber',
'ERROR': 'red', 'ERROR': 'red',
'CRITICAL': 'purple' 'CRITICAL': 'purple'
}, }
lastProcessedTime: 0, //
localLogCache: [], //
}
},
computed: {
commonStore() {
return useCommonStore();
},
logCache() {
return this.commonStore.log_cache;
} }
}, },
props: { props: {
@@ -71,39 +63,13 @@ export default {
}, },
watch: { watch: {
logCache: { logCache: {
handler(newVal) { handler(val) {
// timestamp const lastLog = val[this.logCache.length - 1];
if (newVal && newVal.length > 0) { if (lastLog && this.isLevelSelected(lastLog.level)) {
// DOM this.printLog(lastLog.data);
this.$nextTick(() => {
//
const newLogs = newVal.filter(log => log.time > this.lastProcessedTime);
if (newLogs.length > 0) {
this.localLogCache.push(...newLogs);
//
this.localLogCache.sort((a, b) => a.time - b.time);
// log_cache_max_len
if (this.localLogCache.length > this.commonStore.log_cache_max_len) {
this.localLogCache.splice(0, this.localLogCache.length - this.commonStore.log_cache_max_len);
}
//
newLogs.forEach(logItem => {
if (this.isLevelSelected(logItem.level)) {
this.printLog(logItem.data);
}
});
//
this.lastProcessedTime = Math.max(...newLogs.map(log => log.time));
}
});
} }
}, },
deep: true, deep: true
immediate: false
}, },
selectedLevels: { selectedLevels: {
handler() { handler() {
@@ -112,37 +78,14 @@ export default {
deep: true deep: true
} }
}, },
async mounted() { mounted() {
// if (this.logCache.length === 0) {
await this.fetchLogHistory(); this.delayInit()
} else {
// DOM this.init()
this.$nextTick(() => { }
if (this.localLogCache.length > 0) {
this.localLogCache.forEach(logItem => {
if (this.isLevelSelected(logItem.level)) {
this.printLog(logItem.data);
}
});
//
this.lastProcessedTime = Math.max(...this.localLogCache.map(log => log.time));
}
});
}, },
methods: { methods: {
async fetchLogHistory() {
try {
const res = await axios.get('/api/log-history');
if (res.data.data.logs && res.data.data.logs.length > 0) {
this.localLogCache = [...res.data.data.logs];
//
this.localLogCache.sort((a, b) => a.time - b.time);
}
} catch (err) {
console.error('Failed to fetch log history:', err);
}
},
getLevelColor(level) { getLevelColor(level) {
return this.levelColors[level] || 'grey'; return this.levelColors[level] || 'grey';
}, },
@@ -158,21 +101,40 @@ export default {
}, },
refreshDisplay() { refreshDisplay() {
//
const termElement = document.getElementById('term'); const termElement = document.getElementById('term');
if (termElement) { if (termElement) {
termElement.innerHTML = ''; termElement.innerHTML = '';
}
// //
if (this.localLogCache && this.localLogCache.length > 0) { this.init();
this.localLogCache.forEach(logItem => { },
if (this.isLevelSelected(logItem.level)) {
this.printLog(logItem.data); delayInit() {
} if (this.logCache.length === 0) {
}); setTimeout(() => {
} this.delayInit()
}, 500)
} else {
this.init()
} }
}, },
init() {
this.historyNum_ = parseInt(this.historyNum)
let i = 0
for (let log of this.logCache) {
if (this.isLevelSelected(log.level)) { //
if (this.historyNum_ != -1 && i >= this.logCache.length - this.historyNum_) {
this.printLog(log.data)
++i
} else if (this.historyNum_ == -1) {
this.printLog(log.data)
}
}
}
},
toggleAutoScroll() { toggleAutoScroll() {
this.autoScroll = !this.autoScroll; this.autoScroll = !this.autoScroll;
@@ -181,11 +143,6 @@ export default {
printLog(log) { printLog(log) {
// append span termblock // append span termblock
let ele = document.getElementById('term') let ele = document.getElementById('term')
if (!ele) {
console.warn('term element not found, skipping log print');
return;
}
let span = document.createElement('pre') let span = document.createElement('pre')
let style = this.logColorAnsiMap['default'] let style = this.logColorAnsiMap['default']
for (let key in this.logColorAnsiMap) { for (let key in this.logColorAnsiMap) {
+64 -30
View File
@@ -16,6 +16,21 @@ export const useCommonStore = defineStore({
}), }),
actions: { actions: {
async createEventSource() { async createEventSource() {
const fetchLogHistory = async () => {
try {
const res = await axios.get('/api/log-history');
if (res.data.data.logs) {
this.log_cache.push(...res.data.data.logs);
} else {
this.log_cache = [];
}
} catch (err) {
console.error('Failed to fetch log history:', err);
}
};
await fetchLogHistory();
if (this.eventSource) { if (this.eventSource) {
return return
} }
@@ -39,9 +54,25 @@ export const useCommonStore = defineStore({
const reader = response.body.getReader(); const reader = response.body.getReader();
const decoder = new TextDecoder(); const decoder = new TextDecoder();
let bufferedText = '';
let incompleteLine = ""; // 用于存储不完整的行
const handleIncompleteLine = (line) => {
incompleteLine += line;
// if can parse as JSON, return it
try {
const data_json = JSON.parse(incompleteLine);
incompleteLine = ""; // 清空不完整行
return data_json;
} catch (e) {
return null;
}
}
const processStream = ({ done, value }) => { const processStream = ({ done, value }) => {
// get bytes length
const bytesLength = value ? value.byteLength : 0;
console.log(`Received ${bytesLength} bytes from live log`);
if (done) { if (done) {
console.log('SSE stream closed'); console.log('SSE stream closed');
setTimeout(() => { setTimeout(() => {
@@ -51,41 +82,44 @@ export const useCommonStore = defineStore({
return; return;
} }
// Accumulate partial chunks; SSE data may split JSON across reads. const text = decoder.decode(value);
const text = decoder.decode(value, { stream: true }); const lines = text.split('\n\n');
bufferedText += text; lines.forEach(line => {
if (!line.trim()) {
// Split completed events; keep the trailing partial in buffer.
const segments = bufferedText.split('\n\n');
bufferedText = segments.pop() || '';
segments.forEach(segment => {
const line = segment.trim();
if (!line.startsWith('data: ')) {
return; return;
} }
if (line.startsWith('data:')) {
const logLine = line.replace('data: ', '').trim(); const data = line.substring(5).trim();
if (!logLine) { // {"type":"log","data":"[2021-08-01 00:00:00] INFO: Hello, world!"}
return; let data_json = {}
} try {
data_json = JSON.parse(data);
try { } catch (e) {
const logObject = JSON.parse(logLine); console.warn('Invalid JSON:', data);
// give a uuid if not exists // 尝试处理不完整的行
if (!logObject.uuid) { const parsedData = handleIncompleteLine(data);
logObject.uuid = crypto.randomUUID(); if (parsedData) {
data_json = parsedData;
} else {
return; // 如果无法解析,跳过当前行
}
} }
this.log_cache.push(logObject); if (data_json.type === 'log') {
// Limit log cache size this.log_cache.push(data_json);
if (this.log_cache.length > this.log_cache_max_len) { if (this.log_cache.length > this.log_cache_max_len) {
this.log_cache.splice(0, this.log_cache.length - this.log_cache_max_len); this.log_cache.shift();
}
}
} else {
const parsedData = handleIncompleteLine(line);
if (parsedData && parsedData.type === 'log') {
this.log_cache.push(parsedData);
if (this.log_cache.length > this.log_cache_max_len) {
this.log_cache.shift();
}
} }
} catch (err) {
console.warn('Failed to parse SSE log line, skipping:', err, logLine);
} }
}); });
return reader.read().then(processStream); return reader.read().then(processStream);
}; };
+1 -1
View File
@@ -1,6 +1,6 @@
[project] [project]
name = "AstrBot" name = "AstrBot"
version = "4.9.0" version = "4.8.0"
description = "Easy-to-use multi-platform LLM chatbot and development framework" description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"
+279
View File
@@ -0,0 +1,279 @@
"""Test GitHub webhook platform adapter"""
import asyncio
import hashlib
import hmac
from unittest.mock import MagicMock
import pytest
from astrbot.core.platform.sources.github_webhook.github_webhook_adapter import (
GitHubWebhookPlatformAdapter,
)
@pytest.fixture
def event_queue():
"""Create a test event queue"""
return asyncio.Queue()
@pytest.fixture
def platform_config():
"""Create test platform configuration"""
return {
"type": "github_webhook",
"enable": True,
"id": "test_github_webhook",
"unified_webhook_mode": True,
"webhook_uuid": "test-uuid-123",
"webhook_secret": "", # No secret by default for easier testing
}
@pytest.fixture
def platform_settings():
"""Create test platform settings"""
return {"unique_session": False}
@pytest.fixture
def adapter(platform_config, platform_settings, event_queue):
"""Create test adapter instance"""
return GitHubWebhookPlatformAdapter(platform_config, platform_settings, event_queue)
class TestGitHubWebhookAdapter:
"""Test cases for GitHub webhook adapter"""
def test_adapter_initialization(self, adapter):
"""Test adapter is initialized correctly"""
assert adapter.unified_webhook_mode is True
assert adapter.webhook_secret == ""
assert adapter.meta().name == "github_webhook"
assert adapter.meta().description == "GitHub Webhook 适配器"
@pytest.mark.asyncio
async def test_ping_event(self, adapter):
"""Test GitHub ping event"""
# Mock request
request = MagicMock()
request.headers.get.return_value = "ping"
async def mock_json():
return {}
request.json = mock_json()
response = await adapter.webhook_callback(request)
assert response == {"message": "pong"}
@pytest.mark.asyncio
async def test_issue_created_event(self, adapter, event_queue):
"""Test GitHub issue created event"""
# Mock request with issue created payload
request = MagicMock()
request.headers.get.return_value = "issues"
payload = {
"action": "opened",
"issue": {
"title": "Test Issue",
"body": "This is a test issue",
"html_url": "https://github.com/test/repo/issues/1",
},
"repository": {"full_name": "test/repo"},
"sender": {"login": "testuser"},
}
async def mock_json():
return payload
request.json = mock_json()
response = await adapter.webhook_callback(request)
assert response == {"status": "ok"}
# Verify event was queued
assert not event_queue.empty()
event = event_queue.get_nowait()
assert event.event_type == "issues"
assert "新 Issue 创建" in event.message_str
assert "Test Issue" in event.message_str
@pytest.mark.asyncio
async def test_issue_comment_event(self, adapter, event_queue):
"""Test GitHub issue comment event"""
request = MagicMock()
request.headers.get.return_value = "issue_comment"
payload = {
"action": "created",
"issue": {"title": "Test Issue"},
"comment": {
"body": "Test comment",
"html_url": "https://github.com/test/repo/issues/1#comment",
},
"repository": {"full_name": "test/repo"},
"sender": {"login": "commenter"},
}
async def mock_json():
return payload
request.json = mock_json()
response = await adapter.webhook_callback(request)
assert response == {"status": "ok"}
# Verify event was queued
assert not event_queue.empty()
event = event_queue.get_nowait()
assert event.event_type == "issue_comment"
assert "新 Issue 评论" in event.message_str
assert "Test comment" in event.message_str
@pytest.mark.asyncio
async def test_pull_request_event(self, adapter, event_queue):
"""Test GitHub pull request opened event"""
request = MagicMock()
request.headers.get.return_value = "pull_request"
payload = {
"action": "opened",
"pull_request": {
"title": "Test PR",
"body": "This is a test PR",
"html_url": "https://github.com/test/repo/pull/1",
},
"repository": {"full_name": "test/repo"},
"sender": {"login": "prauthor"},
}
async def mock_json():
return payload
request.json = mock_json()
response = await adapter.webhook_callback(request)
assert response == {"status": "ok"}
# Verify event was queued
assert not event_queue.empty()
event = event_queue.get_nowait()
assert event.event_type == "pull_request"
assert "新 Pull Request" in event.message_str
assert "Test PR" in event.message_str
@pytest.mark.asyncio
async def test_unsupported_event(self, adapter, event_queue):
"""Test unsupported GitHub event type"""
request = MagicMock()
request.headers.get.return_value = "push"
async def mock_json():
return {"action": "created"}
request.json = mock_json()
response = await adapter.webhook_callback(request)
assert response == {"status": "ok"}
# Verify no event was queued for unsupported events
assert event_queue.empty()
@pytest.mark.asyncio
async def test_issue_closed_ignored(self, adapter, event_queue):
"""Test that issue closed action is ignored"""
request = MagicMock()
request.headers.get.return_value = "issues"
payload = {
"action": "closed", # Should be ignored
"issue": {"title": "Test Issue"},
"repository": {"full_name": "test/repo"},
"sender": {"login": "testuser"},
}
async def mock_json():
return payload
request.json = mock_json()
response = await adapter.webhook_callback(request)
assert response == {"status": "ok"}
# Verify no event was queued
assert event_queue.empty()
@pytest.mark.asyncio
async def test_signature_verification(self, platform_settings, event_queue):
"""Test webhook signature verification"""
# Create adapter with webhook secret
config_with_secret = {
"type": "github_webhook",
"enable": True,
"id": "test_github_webhook",
"unified_webhook_mode": True,
"webhook_uuid": "test-uuid-123",
"webhook_secret": "test-secret",
}
adapter = GitHubWebhookPlatformAdapter(
config_with_secret, platform_settings, event_queue
)
# Create a valid signature
body = b'{"action": "opened"}'
signature = hmac.new(b"test-secret", body, hashlib.sha256).hexdigest()
# Mock request with valid signature
request = MagicMock()
request.headers.get = lambda key, default="": {
"X-GitHub-Event": "ping",
"X-Hub-Signature-256": f"sha256={signature}",
}.get(key, default)
async def mock_get_data():
return body
request.get_data = mock_get_data
async def mock_json():
return {"action": "opened"}
request.json = mock_json()
response = await adapter.webhook_callback(request)
assert response == {"message": "pong"}
@pytest.mark.asyncio
async def test_invalid_signature(self, platform_settings, event_queue):
"""Test webhook with invalid signature is rejected"""
# Create adapter with webhook secret
config_with_secret = {
"type": "github_webhook",
"enable": True,
"id": "test_github_webhook",
"unified_webhook_mode": True,
"webhook_uuid": "test-uuid-123",
"webhook_secret": "test-secret",
}
adapter = GitHubWebhookPlatformAdapter(
config_with_secret, platform_settings, event_queue
)
# Mock request with invalid signature
request = MagicMock()
request.headers.get = lambda key, default="": {
"X-GitHub-Event": "ping",
"X-Hub-Signature-256": "sha256=invalidsignature",
}.get(key, default)
async def mock_get_data():
return b'{"action": "opened"}'
request.get_data = mock_get_data
async def mock_json():
return {"action": "opened"}
request.json = mock_json()
response = await adapter.webhook_callback(request)
assert response == ({"error": "Invalid signature"}, 401)
-209
View File
@@ -1,209 +0,0 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
import pytest_asyncio
from quart import Quart
from astrbot.core import LogBroker
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.knowledge_base.kb_helper import KBHelper
from astrbot.core.knowledge_base.models import KBDocument
from astrbot.dashboard.server import AstrBotDashboard
@pytest_asyncio.fixture(scope="module")
async def core_lifecycle_td(tmp_path_factory):
"""Creates and initializes a core lifecycle instance with a temporary database."""
tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_kb.db"
db = SQLiteDatabase(str(tmp_db_path))
log_broker = LogBroker()
core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
await core_lifecycle.initialize()
# Mock kb_manager and kb_helper
kb_manager = MagicMock()
kb_helper = AsyncMock(spec=KBHelper)
# Configure get_kb to be an async mock that returns kb_helper
kb_manager.get_kb = AsyncMock(return_value=kb_helper)
# Mock upload_document return value
mock_doc = KBDocument(
doc_id="test_doc_id",
kb_id="test_kb_id",
doc_name="test_file.txt",
file_type="txt",
file_size=100,
file_path="",
chunk_count=2,
media_count=0,
)
kb_helper.upload_document.return_value = mock_doc
# kb_manager.get_kb.return_value = kb_helper # Removed this line as it's handled above
core_lifecycle.kb_manager = kb_manager
try:
yield core_lifecycle
finally:
try:
_stop_res = core_lifecycle.stop()
if asyncio.iscoroutine(_stop_res):
await _stop_res
except Exception:
pass
@pytest.fixture(scope="module")
def app(core_lifecycle_td: AstrBotCoreLifecycle):
"""Creates a Quart app instance for testing."""
shutdown_event = asyncio.Event()
server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
return server.app
@pytest_asyncio.fixture(scope="module")
async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
"""Handles login and returns an authenticated header."""
test_client = app.test_client()
response = await test_client.post(
"/api/auth/login",
json={
"username": core_lifecycle_td.astrbot_config["dashboard"]["username"],
"password": core_lifecycle_td.astrbot_config["dashboard"]["password"],
},
)
data = await response.get_json()
assert data["status"] == "ok"
token = data["data"]["token"]
return {"Authorization": f"Bearer {token}"}
@pytest.mark.asyncio
async def test_import_documents(
app: Quart, authenticated_header: dict, core_lifecycle_td: AstrBotCoreLifecycle
):
"""Tests the import documents functionality."""
test_client = app.test_client()
# Test data
import_data = {
"kb_id": "test_kb_id",
"documents": [
{"file_name": "test_file_1.txt", "chunks": ["chunk1", "chunk2"]},
{"file_name": "test_file_2.md", "chunks": ["chunk3", "chunk4", "chunk5"]},
],
}
# Send request
response = await test_client.post(
"/api/kb/document/import", json=import_data, headers=authenticated_header
)
# Verify response
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert "task_id" in data["data"]
assert data["data"]["doc_count"] == 2
task_id = data["data"]["task_id"]
# Wait for background task to complete (mocked)
# Since we mocked upload_document, it should be fast, but we might need to poll progress
for _ in range(10):
progress_response = await test_client.get(
f"/api/kb/document/upload/progress?task_id={task_id}",
headers=authenticated_header,
)
progress_data = await progress_response.get_json()
if progress_data["data"]["status"] == "completed":
break
await asyncio.sleep(0.1)
assert progress_data["data"]["status"] == "completed"
result = progress_data["data"]["result"]
assert result["success_count"] == 2
assert result["failed_count"] == 0
# Verify kb_helper.upload_document was called correctly
kb_helper = await core_lifecycle_td.kb_manager.get_kb("test_kb_id")
assert kb_helper.upload_document.call_count == 2
# Check first call arguments
call_args_list = kb_helper.upload_document.call_args_list
# First document
args1, kwargs1 = call_args_list[0]
assert kwargs1["file_name"] == "test_file_1.txt"
assert kwargs1["pre_chunked_text"] == ["chunk1", "chunk2"]
# Second document
args2, kwargs2 = call_args_list[1]
assert kwargs2["file_name"] == "test_file_2.md"
assert kwargs2["pre_chunked_text"] == ["chunk3", "chunk4", "chunk5"]
@pytest.mark.asyncio
async def test_import_documents_invalid_input(app: Quart, authenticated_header: dict):
"""Tests import documents with invalid input."""
test_client = app.test_client()
# Missing kb_id
response = await test_client.post(
"/api/kb/document/import", json={"documents": []}, headers=authenticated_header
)
data = await response.get_json()
assert data["status"] == "error"
assert "缺少参数 kb_id" in data["message"]
# Missing documents
response = await test_client.post(
"/api/kb/document/import",
json={"kb_id": "test_kb"},
headers=authenticated_header,
)
data = await response.get_json()
assert data["status"] == "error"
assert "缺少参数 documents" in data["message"]
# Invalid document format
response = await test_client.post(
"/api/kb/document/import",
json={
"kb_id": "test_kb",
"documents": [{"file_name": "test"}], # Missing chunks
},
headers=authenticated_header,
)
data = await response.get_json()
assert data["status"] == "error"
assert "文档格式错误" in data["message"]
# Invalid chunks type
response = await test_client.post(
"/api/kb/document/import",
json={
"kb_id": "test_kb",
"documents": [{"file_name": "test", "chunks": "not-a-list"}],
},
headers=authenticated_header,
)
data = await response.get_json()
assert data["status"] == "error"
assert "chunks 必须是列表" in data["message"]
# Invalid chunks content
response = await test_client.post(
"/api/kb/document/import",
json={
"kb_id": "test_kb",
"documents": [{"file_name": "test", "chunks": ["valid", ""]}],
},
headers=authenticated_header,
)
data = await response.get_json()
assert data["status"] == "error"
assert "chunks 必须是非空字符串列表" in data["message"]