diff --git a/astrbot/cli/__init__.py b/astrbot/cli/__init__.py index ea674c5c5..454a2edae 100644 --- a/astrbot/cli/__init__.py +++ b/astrbot/cli/__init__.py @@ -1 +1 @@ -__version__ = "4.8.0" +__version__ = "4.9.0" diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index e8778bfc6..7f42aa88b 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -4,7 +4,7 @@ import os from astrbot.core.utils.astrbot_path import get_astrbot_data_path -VERSION = "4.8.0" +VERSION = "4.9.0" DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") WEBHOOK_SUPPORTED_PLATFORMS = [ @@ -13,6 +13,7 @@ WEBHOOK_SUPPORTED_PLATFORMS = [ "wecom", "wecom_ai_bot", "slack", + "lark", ] # 默认配置 @@ -277,6 +278,10 @@ CONFIG_METADATA_2 = { "app_id": "", "app_secret": "", "domain": "https://open.feishu.cn", + "lark_connection_mode": "socket", # webhook, socket + "webhook_uuid": "", + "lark_encrypt_key": "", + "lark_verification_token": "", }, "钉钉(DingTalk)": { "id": "dingtalk", @@ -370,6 +375,28 @@ CONFIG_METADATA_2 = { # "type": "string", # "options": ["fullscreen", "embedded"], # }, + "lark_connection_mode": { + "description": "订阅方式", + "type": "string", + "options": ["socket", "webhook"], + "labels": ["长连接模式", "推送至服务器模式"], + }, + "lark_encrypt_key": { + "description": "Encrypt Key", + "type": "string", + "hint": "用于解密飞书回调数据的加密密钥", + "condition": { + "lark_connection_mode": "webhook", + }, + }, + "lark_verification_token": { + "description": "Verification Token", + "type": "string", + "hint": "用于验证飞书回调请求的令牌", + "condition": { + "lark_connection_mode": "webhook", + }, + }, "is_sandbox": { "description": "沙箱模式", "type": "bool", diff --git a/astrbot/core/log.py b/astrbot/core/log.py index 376f5ffd6..806ebcebb 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -24,6 +24,7 @@ import asyncio import logging import os import sys +import time from asyncio import Queue from collections import deque @@ -148,7 +149,7 @@ class LogQueueHandler(logging.Handler): self.log_broker.publish( { "level": record.levelname, - "time": record.asctime, + "time": time.time(), "data": log_entry, }, ) diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index b941c8cbc..f4313f642 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -5,6 +5,7 @@ from asyncio import Queue from astrbot.core import logger from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map +from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config from .platform import Platform, PlatformStatus from .register import platform_cls_map @@ -18,6 +19,7 @@ class PlatformManager: self._inst_map: dict[str, dict] = {} + self.astrbot_config = config self.platforms_config = config["platform"] self.settings = config["platform_settings"] """NOTE: 这里是 default 的配置文件,以保证最大的兼容性; @@ -29,6 +31,8 @@ class PlatformManager: """初始化所有平台适配器""" for platform in self.platforms_config: try: + if ensure_platform_webhook_config(platform): + self.astrbot_config.save_config() await self.load_platform(platform) except Exception as e: logger.error(f"初始化 {platform} 平台适配器失败: {e}") diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index c139b8bd7..c2e55fb63 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -80,6 +80,13 @@ class Platform(abc.ABC): if self._status == PlatformStatus.ERROR: self._status = PlatformStatus.RUNNING + def unified_webhook(self) -> bool: + """是否正在使用统一 Webhook 模式""" + return bool( + self.config.get("unified_webhook_mode", False) + and self.config.get("webhook_uuid") + ) + def get_stats(self) -> dict: """获取平台统计信息""" meta = self.meta() @@ -97,6 +104,7 @@ class Platform(abc.ABC): } if self.last_error else None, + "unified_webhook": self.unified_webhook(), } @abc.abstractmethod diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index b3c2229ab..52dd21d56 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -421,7 +421,7 @@ class AiocqhttpAdapter(Platform): async def shutdown_trigger_placeholder(self): await self.shutdown_event.wait() - logger.info("aiocqhttp 适配器已被优雅地关闭") + logger.info("aiocqhttp 适配器已被关闭") def meta(self) -> PlatformMetadata: return self.metadata diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index 8905698a5..6f9e25df4 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -245,7 +245,7 @@ class DingtalkPlatformAdapter(Platform): task.result() except Exception as e: if "Graceful shutdown" in str(e): - logger.info("钉钉适配器已被优雅地关闭") + logger.info("钉钉适配器已被关闭") return logger.error(f"钉钉机器人启动失败: {e}") diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index 473be096f..08df1f359 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -2,8 +2,9 @@ import asyncio import base64 import json import re +import time import uuid -from typing import cast +from typing import Any, cast import lark_oapi as lark from lark_oapi.api.im.v1 import ( @@ -11,6 +12,7 @@ from lark_oapi.api.im.v1 import ( CreateMessageRequestBody, GetMessageResourceRequest, ) +from lark_oapi.api.im.v1.processor import P2ImMessageReceiveV1Processor import astrbot.api.message_components as Comp from astrbot import logger @@ -23,9 +25,11 @@ from astrbot.api.platform import ( PlatformMetadata, ) from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.utils.webhook_utils import log_webhook_info from ...register import register_platform_adapter from .lark_event import LarkMessageEvent +from .server import LarkWebhookServer @register_platform_adapter( @@ -47,9 +51,13 @@ class LarkPlatformAdapter(Platform): self.domain = platform_config.get("domain", lark.FEISHU_DOMAIN) self.bot_name = platform_config.get("lark_bot_name", "astrbot") + # socket or webhook + self.connection_mode = platform_config.get("lark_connection_mode", "socket") + if not self.bot_name: logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。") + # 初始化 WebSocket 长连接相关配置 async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1): await self.convert_msg(event) @@ -62,6 +70,8 @@ class LarkPlatformAdapter(Platform): .build() ) + self.do_v2_msg_event = do_v2_msg_event + self.client = lark.ws.Client( app_id=self.appid, app_secret=self.appsecret, @@ -71,9 +81,47 @@ class LarkPlatformAdapter(Platform): ) self.lark_api = ( - lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build() + lark.Client.builder() + .app_id(self.appid) + .app_secret(self.appsecret) + .log_level(lark.LogLevel.ERROR) + .domain(self.domain) + .build() ) + self.webhook_server = None + if self.connection_mode == "webhook": + self.webhook_server = LarkWebhookServer(platform_config, event_queue) + self.webhook_server.set_callback(self.handle_webhook_event) + + self.event_id_timestamps: dict[str, float] = {} + + def _clean_expired_events(self): + """清理超过 30 分钟的事件记录""" + current_time = time.time() + expired_keys = [ + event_id + for event_id, timestamp in self.event_id_timestamps.items() + if current_time - timestamp > 1800 + ] + for event_id in expired_keys: + del self.event_id_timestamps[event_id] + + def _is_duplicate_event(self, event_id: str) -> bool: + """检查事件是否重复 + + Args: + event_id: 事件ID + + Returns: + True 表示重复事件,False 表示新事件 + """ + self._clean_expired_events() + if event_id in self.event_id_timestamps: + return True + self.event_id_timestamps[event_id] = time.time() + return False + async def send_by_session( self, session: MessageSesion, @@ -137,7 +185,11 @@ class LarkPlatformAdapter(Platform): return abm = AstrBotMessage() - abm.timestamp = cast(int, message.create_time) // 1000 + + if message.create_time: + abm.timestamp = int(message.create_time) // 1000 + else: + abm.timestamp = int(time.time()) abm.message = [] abm.type = ( MessageType.GROUP_MESSAGE @@ -290,13 +342,61 @@ class LarkPlatformAdapter(Platform): self._event_queue.put_nowait(event) + async def handle_webhook_event(self, event_data: dict): + """处理 Webhook 事件 + + Args: + event_data: Webhook 事件数据 + """ + try: + header = event_data.get("header", {}) + event_id = header.get("event_id", "") + if event_id and self._is_duplicate_event(event_id): + logger.debug(f"[Lark Webhook] 跳过重复事件: {event_id}") + return + event_type = header.get("event_type", "") + if event_type == "im.message.receive_v1": + processor = P2ImMessageReceiveV1Processor(self.do_v2_msg_event) + data = (processor.type())(event_data) + processor.do(data) + else: + logger.debug(f"[Lark Webhook] 未处理的事件类型: {event_type}") + except Exception as e: + logger.error(f"[Lark Webhook] 处理事件失败: {e}", exc_info=True) + async def run(self): - # self.client.start() - await self.client._connect() + if self.connection_mode == "webhook": + # Webhook 模式 + if self.webhook_server is None: + logger.error("[Lark] Webhook 模式已启用,但 webhook_server 未初始化") + return + + webhook_uuid = self.config.get("webhook_uuid") + if webhook_uuid: + log_webhook_info(f"{self.meta().id}(飞书 Webhook)", webhook_uuid) + else: + logger.warning("[Lark] Webhook 模式已启用,但未配置 webhook_uuid") + else: + # 长连接模式 + await self.client._connect() + + async def webhook_callback(self, request: Any) -> Any: + """统一 Webhook 回调入口""" + if not self.webhook_server: + return {"error": "Webhook server not initialized"}, 500 + + return await self.webhook_server.handle_callback(request) async def terminate(self): - await self.client._disconnect() - logger.info("飞书(Lark) 适配器已被优雅地关闭") + if self.connection_mode == "socket": + await self.client._disconnect() + logger.info("飞书(Lark) 适配器已关闭") def get_client(self) -> lark.ws.Client: return self.client + + def unified_webhook(self) -> bool: + return bool( + self.config.get("lark_connection_mode", "") == "webhook" + and self.config.get("webhook_uuid") + ) diff --git a/astrbot/core/platform/sources/lark/server.py b/astrbot/core/platform/sources/lark/server.py new file mode 100644 index 000000000..3921eb8be --- /dev/null +++ b/astrbot/core/platform/sources/lark/server.py @@ -0,0 +1,206 @@ +"""飞书(Lark) Webhook 服务器实现 + +实现飞书事件订阅的 Webhook 模式,支持: +1. 请求 URL 验证 (challenge 验证) +2. 事件加密/解密 (AES-256-CBC) +3. 签名校验 (SHA256) +4. 事件接收和处理 +""" + +import asyncio +import base64 +import hashlib +import json +from collections.abc import Awaitable, Callable + +from Crypto.Cipher import AES + +from astrbot.api import logger + + +class AESCipher: + """AES 加密/解密工具类""" + + def __init__(self, key: str): + self.bs = AES.block_size + self.key = hashlib.sha256(self.str_to_bytes(key)).digest() + + @staticmethod + def str_to_bytes(data): + u_type = type(b"".decode("utf8")) + if isinstance(data, u_type): + return data.encode("utf8") + return data + + @staticmethod + def _unpad(s): + return s[: -ord(s[len(s) - 1 :])] + + def decrypt(self, enc): + iv = enc[: AES.block_size] + cipher = AES.new(self.key, AES.MODE_CBC, iv) + return self._unpad(cipher.decrypt(enc[AES.block_size :])) + + def decrypt_string(self, enc): + enc = base64.b64decode(enc) + return self.decrypt(enc).decode("utf8") + + +class LarkWebhookServer: + """飞书 Webhook 服务器 + + 仅支持统一 Webhook 模式 + """ + + def __init__(self, config: dict, event_queue: asyncio.Queue): + """初始化 Webhook 服务器 + + Args: + config: 飞书配置 + event_queue: 事件队列 + """ + self.app_id = config["app_id"] + self.app_secret = config["app_secret"] + self.encrypt_key = config.get("lark_encrypt_key", "") + self.verification_token = config.get("lark_verification_token", "") + + self.event_queue = event_queue + self.callback: Callable[[dict], Awaitable[None]] | None = None + + # 初始化加密工具 + self.cipher = None + if self.encrypt_key: + self.cipher = AESCipher(self.encrypt_key) + + def verify_signature( + self, + timestamp: str, + nonce: str, + encrypt_key: str, + body: bytes, + signature: str, + ) -> bool: + """验证签名 + + Args: + timestamp: 请求时间戳 + nonce: 随机数 + encrypt_key: 加密密钥 + body: 请求体 + signature: 签名 + + Returns: + 签名是否有效 + """ + # 拼接字符串: timestamp + nonce + encrypt_key + body + bytes_b1 = (timestamp + nonce + encrypt_key).encode("utf-8") + bytes_b = bytes_b1 + body + h = hashlib.sha256(bytes_b) + calculated_signature = h.hexdigest() + return calculated_signature == signature + + def decrypt_event(self, encrypted_data: str) -> dict: + """解密事件数据 + + Args: + encrypted_data: 加密的事件数据 + + Returns: + 解密后的事件字典 + """ + if not self.cipher: + raise ValueError("未配置 encrypt_key,无法解密事件") + + decrypted_str = self.cipher.decrypt_string(encrypted_data) + return json.loads(decrypted_str) + + async def handle_challenge(self, event_data: dict) -> dict: + """处理 challenge 验证请求 + + Args: + event_data: 事件数据 + + Returns: + 包含 challenge 的响应 + """ + challenge = event_data.get("challenge", "") + logger.info(f"[Lark Webhook] 收到 challenge 验证请求: {challenge}") + + return {"challenge": challenge} + + async def handle_callback(self, request) -> tuple[dict, int] | dict: + """处理 webhook 回调,可被统一 webhook 入口复用 + + Args: + request: Quart 请求对象 + + Returns: + 响应数据 + """ + # 获取原始请求体 + body = await request.get_data() + + try: + event_data = await request.json + except Exception as e: + logger.error(f"[Lark Webhook] 解析请求体失败: {e}") + return {"error": "Invalid JSON"}, 400 + + if not event_data: + logger.error("[Lark Webhook] 请求体为空") + return {"error": "Empty request body"}, 400 + + # 如果配置了 encrypt_key,进行签名验证 + if self.encrypt_key: + timestamp = request.headers.get("X-Lark-Request-Timestamp", "") + nonce = request.headers.get("X-Lark-Request-Nonce", "") + signature = request.headers.get("X-Lark-Signature", "") + + if timestamp and nonce and signature: + if not self.verify_signature( + timestamp, nonce, self.encrypt_key, body, signature + ): + logger.error("[Lark Webhook] 签名验证失败") + return {"error": "Invalid signature"}, 401 + + # 检查是否是加密事件 + if "encrypt" in event_data: + try: + event_data = self.decrypt_event(event_data["encrypt"]) + logger.debug(f"[Lark Webhook] 解密后的事件: {event_data}") + except Exception as e: + logger.error(f"[Lark Webhook] 解密事件失败: {e}") + return {"error": "Decryption failed"}, 400 + + # 验证 token + if self.verification_token: + header = event_data.get("header", {}) + if header: + token = header.get("token", "") + else: + token = event_data.get("token", "") + if token != self.verification_token: + logger.error("[Lark Webhook] Verification Token 不匹配。") + return {"error": "Invalid verification token"}, 401 + + # 处理 URL 验证 (challenge) + if event_data.get("type") == "url_verification": + return await self.handle_challenge(event_data) + + # 调用回调函数处理事件 + if self.callback: + try: + await self.callback(event_data) + except Exception as e: + logger.error(f"[Lark Webhook] 处理事件回调失败: {e}", exc_info=True) + return {"error": "Event processing failed"}, 500 + + return {} + + def set_callback(self, callback: Callable[[dict], Awaitable[None]]): + """设置事件回调函数 + + Args: + callback: 处理事件的异步函数 + """ + self.callback = callback diff --git a/astrbot/core/platform/sources/slack/slack_adapter.py b/astrbot/core/platform/sources/slack/slack_adapter.py index 4621f8494..ed838b0a9 100644 --- a/astrbot/core/platform/sources/slack/slack_adapter.py +++ b/astrbot/core/platform/sources/slack/slack_adapter.py @@ -409,7 +409,7 @@ class SlackAdapter(Platform): await self.socket_client.stop() if self.webhook_client: await self.webhook_client.stop() - logger.info("Slack 适配器已被优雅地关闭") + logger.info("Slack 适配器已被关闭") def meta(self) -> PlatformMetadata: return self.metadata @@ -427,3 +427,10 @@ class SlackAdapter(Platform): def get_client(self): return self.web_client + + def unified_webhook(self) -> bool: + return bool( + self.config.get("unified_webhook_mode", False) + and self.config.get("slack_connection_mode", "") == "webhook" + and self.config.get("webhook_uuid") + ) diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index bca45ea8d..218d13bdc 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -424,6 +424,6 @@ class TelegramPlatformAdapter(Platform): if self.application.updater is not None: await self.application.updater.stop() - logger.info("Telegram 适配器已被优雅地关闭") + logger.info("Telegram 适配器已被关闭") except Exception as e: logger.error(f"Telegram 适配器关闭时出错: {e}") diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 8f3d091a4..44ed75117 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -422,4 +422,4 @@ class WecomPlatformAdapter(Platform): await self.server.server.shutdown() except Exception as _: pass - logger.info("企业微信 适配器已被优雅地关闭") + logger.info("企业微信 适配器已被关闭") diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py index d0304a48e..d12285d68 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py @@ -349,4 +349,4 @@ class WeixinOfficialAccountPlatformAdapter(Platform): await self.server.server.shutdown() except Exception as _: pass - logger.info("微信公众平台 适配器已被优雅地关闭") + logger.info("微信公众平台 适配器已被关闭") diff --git a/astrbot/core/utils/webhook_utils.py b/astrbot/core/utils/webhook_utils.py index c56d00b37..0e1c3f9cd 100644 --- a/astrbot/core/utils/webhook_utils.py +++ b/astrbot/core/utils/webhook_utils.py @@ -1,4 +1,7 @@ +import uuid + from astrbot.core import astrbot_config, logger +from astrbot.core.config.default import WEBHOOK_SUPPORTED_PLATFORMS def _get_callback_api_base() -> str: @@ -45,3 +48,19 @@ def log_webhook_info(platform_name: str, webhook_uuid: str): "====================\n" ) logger.info(display_log) + + +def ensure_platform_webhook_config(platform_cfg: dict) -> bool: + """为支持统一 webhook 的平台自动生成 webhook_uuid + + Args: + platform_cfg (dict): 平台配置字典 + + Returns: + bool: 如果生成了 webhook_uuid 则返回 True,否则返回 False + """ + pt = platform_cfg.get("type", "") + if pt in WEBHOOK_SUPPORTED_PLATFORMS and not platform_cfg.get("webhook_uuid"): + platform_cfg["webhook_uuid"] = uuid.uuid4().hex[:16] + return True + return False diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index e8f17cc99..0edbe8377 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -2,7 +2,6 @@ import asyncio import inspect import os import traceback -import uuid from typing import Any from quart import request @@ -15,7 +14,6 @@ from astrbot.core.config.default import ( CONFIG_METADATA_3_SYSTEM, DEFAULT_CONFIG, DEFAULT_VALUE_MAP, - WEBHOOK_SUPPORTED_PLATFORMS, ) from astrbot.core.config.i18n_utils import ConfigMetadataI18n from astrbot.core.core_lifecycle import AstrBotCoreLifecycle @@ -23,6 +21,7 @@ from astrbot.core.platform.register import platform_cls_map, platform_registry from astrbot.core.provider import Provider from astrbot.core.provider.register import provider_registry from astrbot.core.star.star import star_registry +from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config from .route import Response, Route, RouteContext @@ -559,13 +558,8 @@ class ConfigRoute(Route): async def post_new_platform(self): new_platform_config = await request.json - # 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,自动生成 webhook_uuid - platform_type = new_platform_config.get("type", "") - if platform_type in WEBHOOK_SUPPORTED_PLATFORMS: - if new_platform_config.get("unified_webhook_mode", False): - # 如果没有 webhook_uuid 或为空,自动生成 - if not new_platform_config.get("webhook_uuid"): - new_platform_config["webhook_uuid"] = uuid.uuid4().hex[:16] + # 如果是支持统一 webhook 模式的平台,生成 webhook_uuid + ensure_platform_webhook_config(new_platform_config) self.config["platform"].append(new_platform_config) try: @@ -597,12 +591,7 @@ class ConfigRoute(Route): return Response().error("参数错误").__dict__ # 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,确保有 webhook_uuid - platform_type = new_config.get("type", "") - if platform_type in WEBHOOK_SUPPORTED_PLATFORMS: - if new_config.get("unified_webhook_mode", False): - # 如果没有 webhook_uuid 或为空,自动生成 - if not new_config.get("webhook_uuid"): - new_config["webhook_uuid"] = uuid.uuid4().hex + ensure_platform_webhook_config(new_config) for i, platform in enumerate(self.config["platform"]): if platform["id"] == platform_id: diff --git a/astrbot/dashboard/routes/conversation.py b/astrbot/dashboard/routes/conversation.py index d19fdf793..513d3603f 100644 --- a/astrbot/dashboard/routes/conversation.py +++ b/astrbot/dashboard/routes/conversation.py @@ -1,7 +1,9 @@ import json import traceback +from datetime import datetime +from io import BytesIO -from quart import request +from quart import request, send_file from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle @@ -30,6 +32,7 @@ class ConversationRoute(Route): "POST", self.update_history, ), + "/conversation/export": ("POST", self.export_conversations), } self.db_helper = db_helper self.conv_mgr = core_lifecycle.conversation_manager @@ -283,3 +286,90 @@ class ConversationRoute(Route): except Exception as e: logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}") return Response().error(f"更新对话历史失败: {e!s}").__dict__ + + async def export_conversations(self): + """批量导出对话为 JSONL 格式""" + try: + data = await request.get_json() + conversations_to_export = data.get("conversations", []) + + if not conversations_to_export: + return Response().error("导出列表不能为空").__dict__ + + # 收集所有对话的内容 + jsonl_lines = [] + exported_count = 0 + failed_items = [] + + for conv_info in conversations_to_export: + user_id = conv_info.get("user_id") + cid = conv_info.get("cid") + + if not user_id or not cid: + failed_items.append( + f"user_id:{user_id}, cid:{cid} - 缺少必要参数", + ) + continue + + try: + conversation = await self.conv_mgr.get_conversation( + unified_msg_origin=user_id, + conversation_id=cid, + ) + + if not conversation: + failed_items.append( + f"user_id:{user_id}, cid:{cid} - 对话不存在" + ) + continue + + # 解析对话内容 (history is always a JSON string from _convert_conv_from_v2_to_v1) + content = json.loads(conversation.history) + + # 创建导出记录 + export_record = { + "cid": cid, + "user_id": user_id, + "platform_id": conversation.platform_id, + "title": conversation.title, + "persona_id": conversation.persona_id, + "created_at": conversation.created_at, + "updated_at": conversation.updated_at, + "content": content, + } + + # 将记录转换为 JSON 字符串并添加到 JSONL + jsonl_lines.append(json.dumps(export_record, ensure_ascii=False)) + exported_count += 1 + + except Exception as e: + failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}") + logger.error( + f"导出对话失败: user_id={user_id}, cid={cid}, error={e!s}" + ) + + if exported_count == 0: + return Response().error("没有成功导出任何对话").__dict__ + + # 创建 JSONL 内容 + jsonl_content = "\n".join(jsonl_lines) + + # 创建一个内存文件对象 + file_obj = BytesIO(jsonl_content.encode("utf-8")) + file_obj.seek(0) + + # 生成文件名 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"astrbot_conversations_export_{timestamp}.jsonl" + + # 返回文件流 + return await send_file( + file_obj, + mimetype="application/jsonl", + as_attachment=True, + attachment_filename=filename, + ) + + except Exception as e: + logger.error(f"批量导出对话失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"批量导出对话失败: {e!s}").__dict__ diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index d7db42c40..537a81f0b 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -48,6 +48,7 @@ class KnowledgeBaseRoute(Route): # 文档管理 "/kb/document/list": ("GET", self.list_documents), "/kb/document/upload": ("POST", self.upload_document), + "/kb/document/import": ("POST", self.import_documents), "/kb/document/upload/url": ("POST", self.upload_document_from_url), "/kb/document/upload/progress": ("GET", self.get_upload_progress), "/kb/document/get": ("GET", self.get_document), @@ -66,6 +67,65 @@ class KnowledgeBaseRoute(Route): def _get_kb_manager(self): return self.core_lifecycle.kb_manager + def _init_task(self, task_id: str, status: str = "pending") -> None: + self.upload_tasks[task_id] = { + "status": status, + "result": None, + "error": None, + } + + def _set_task_result( + self, task_id: str, status: str, result: any = None, error: str | None = None + ) -> None: + self.upload_tasks[task_id] = { + "status": status, + "result": result, + "error": error, + } + if task_id in self.upload_progress: + self.upload_progress[task_id]["status"] = status + + def _update_progress( + self, + task_id: str, + *, + status: str | None = None, + file_index: int | None = None, + file_name: str | None = None, + stage: str | None = None, + current: int | None = None, + total: int | None = None, + ) -> None: + if task_id not in self.upload_progress: + return + p = self.upload_progress[task_id] + if status is not None: + p["status"] = status + if file_index is not None: + p["file_index"] = file_index + if file_name is not None: + p["file_name"] = file_name + if stage is not None: + p["stage"] = stage + if current is not None: + p["current"] = current + if total is not None: + p["total"] = total + + def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str): + async def _callback(stage: str, current: int, total: int): + self._update_progress( + task_id, + status="processing", + file_index=file_idx, + file_name=file_name, + stage=stage, + current=current, + total=total, + ) + + return _callback + async def _background_upload_task( self, task_id: str, @@ -80,11 +140,7 @@ class KnowledgeBaseRoute(Route): """后台上传任务""" try: # 初始化任务状态 - self.upload_tasks[task_id] = { - "status": "processing", - "result": None, - "error": None, - } + self._init_task(task_id, status="processing") self.upload_progress[task_id] = { "status": "processing", "file_index": 0, @@ -100,30 +156,20 @@ class KnowledgeBaseRoute(Route): for file_idx, file_info in enumerate(files_to_upload): try: # 更新整体进度 - self.upload_progress[task_id].update( - { - "status": "processing", - "file_index": file_idx, - "file_name": file_info["file_name"], - "stage": "parsing", - "current": 0, - "total": 100, - }, + self._update_progress( + task_id, + status="processing", + file_index=file_idx, + file_name=file_info["file_name"], + stage="parsing", + current=0, + total=100, ) # 创建进度回调函数 - async def progress_callback(stage, current, total): - if task_id in self.upload_progress: - self.upload_progress[task_id].update( - { - "status": "processing", - "file_index": file_idx, - "file_name": file_info["file_name"], - "stage": stage, - "current": current, - "total": total, - }, - ) + progress_callback = self._make_progress_callback( + task_id, file_idx, file_info["file_name"] + ) doc = await kb_helper.upload_document( file_name=file_info["file_name"], @@ -154,23 +200,99 @@ class KnowledgeBaseRoute(Route): "failed_count": len(failed_docs), } - self.upload_tasks[task_id] = { - "status": "completed", - "result": result, - "error": None, - } - self.upload_progress[task_id]["status"] = "completed" + self._set_task_result(task_id, "completed", result=result) except Exception as e: logger.error(f"后台上传任务 {task_id} 失败: {e}") logger.error(traceback.format_exc()) - self.upload_tasks[task_id] = { - "status": "failed", - "result": None, - "error": str(e), + self._set_task_result(task_id, "failed", error=str(e)) + + async def _background_import_task( + self, + task_id: str, + kb_helper, + documents: list, + batch_size: int, + tasks_limit: int, + max_retries: int, + ): + """后台导入预切片文档任务""" + try: + # 初始化任务状态 + self._init_task(task_id, status="processing") + self.upload_progress[task_id] = { + "status": "processing", + "file_index": 0, + "file_total": len(documents), + "stage": "waiting", + "current": 0, + "total": 100, } - if task_id in self.upload_progress: - self.upload_progress[task_id]["status"] = "failed" + + uploaded_docs = [] + failed_docs = [] + + for file_idx, doc_info in enumerate(documents): + file_name = doc_info.get("file_name", f"imported_doc_{file_idx}") + chunks = doc_info.get("chunks", []) + + try: + # 更新整体进度 + self._update_progress( + task_id, + status="processing", + file_index=file_idx, + file_name=file_name, + stage="importing", + current=0, + total=100, + ) + + # 创建进度回调函数 + progress_callback = self._make_progress_callback( + task_id, file_idx, file_name + ) + + # 调用 upload_document,传入 pre_chunked_text + doc = await kb_helper.upload_document( + file_name=file_name, + file_content=None, # 预切片模式下不需要原始内容 + file_type=doc_info.get("file_type") + or ( + file_name.rsplit(".", 1)[-1].lower() + if "." in file_name + else "txt" + ), + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + progress_callback=progress_callback, + pre_chunked_text=chunks, + ) + + uploaded_docs.append(doc.model_dump()) + except Exception as e: + logger.error(f"导入文档 {file_name} 失败: {e}") + failed_docs.append( + {"file_name": file_name, "error": str(e)}, + ) + + # 更新任务完成状态 + result = { + "task_id": task_id, + "uploaded": uploaded_docs, + "failed": failed_docs, + "total": len(documents), + "success_count": len(uploaded_docs), + "failed_count": len(failed_docs), + } + + self._set_task_result(task_id, "completed", result=result) + + except Exception as e: + logger.error(f"后台导入任务 {task_id} 失败: {e}") + logger.error(traceback.format_exc()) + self._set_task_result(task_id, "failed", error=str(e)) async def list_kbs(self): """获取知识库列表 @@ -614,11 +736,7 @@ class KnowledgeBaseRoute(Route): task_id = str(uuid.uuid4()) # 初始化任务状态 - self.upload_tasks[task_id] = { - "status": "pending", - "result": None, - "error": None, - } + self._init_task(task_id, status="pending") # 启动后台任务 asyncio.create_task( @@ -653,6 +771,93 @@ class KnowledgeBaseRoute(Route): logger.error(traceback.format_exc()) return Response().error(f"上传文档失败: {e!s}").__dict__ + def _validate_import_request(self, data: dict): + kb_id = data.get("kb_id") + if not kb_id: + raise ValueError("缺少参数 kb_id") + + documents = data.get("documents") + if not documents or not isinstance(documents, list): + raise ValueError("缺少参数 documents 或格式错误") + + for doc in documents: + if "file_name" not in doc or "chunks" not in doc: + raise ValueError("文档格式错误,必须包含 file_name 和 chunks") + if not isinstance(doc["chunks"], list): + raise ValueError("chunks 必须是列表") + if not all( + isinstance(chunk, str) and chunk.strip() for chunk in doc["chunks"] + ): + raise ValueError("chunks 必须是非空字符串列表") + + batch_size = data.get("batch_size", 32) + tasks_limit = data.get("tasks_limit", 3) + max_retries = data.get("max_retries", 3) + return kb_id, documents, batch_size, tasks_limit, max_retries + + async def import_documents(self): + """导入预切片文档 + + Body: + - kb_id: 知识库 ID (必填) + - documents: 文档列表 (必填) + - file_name: 文件名 (必填) + - chunks: 切片列表 (必填, list[str]) + - file_type: 文件类型 (可选, 默认从文件名推断或为 txt) + - batch_size: 批处理大小 (可选, 默认32) + - tasks_limit: 并发任务限制 (可选, 默认3) + - max_retries: 最大重试次数 (可选, 默认3) + """ + try: + kb_manager = self._get_kb_manager() + data = await request.json + + kb_id, documents, batch_size, tasks_limit, max_retries = ( + self._validate_import_request(data) + ) + + # 获取知识库 + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + + # 生成任务ID + task_id = str(uuid.uuid4()) + + # 初始化任务状态 + self._init_task(task_id, status="pending") + + # 启动后台任务 + asyncio.create_task( + self._background_import_task( + task_id=task_id, + kb_helper=kb_helper, + documents=documents, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + ), + ) + + return ( + Response() + .ok( + { + "task_id": task_id, + "doc_count": len(documents), + "message": "import task created, processing in background", + }, + ) + .__dict__ + ) + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"导入文档失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"导入文档失败: {e!s}").__dict__ + async def get_upload_progress(self): """获取上传进度和结果 @@ -960,11 +1165,7 @@ class KnowledgeBaseRoute(Route): task_id = str(uuid.uuid4()) # 初始化任务状态 - self.upload_tasks[task_id] = { - "status": "pending", - "result": None, - "error": None, - } + self._init_task(task_id, status="pending") # 启动后台任务 asyncio.create_task( @@ -1017,11 +1218,7 @@ class KnowledgeBaseRoute(Route): """后台上传URL任务""" try: # 初始化任务状态 - self.upload_tasks[task_id] = { - "status": "processing", - "result": None, - "error": None, - } + self._init_task(task_id, status="processing") self.upload_progress[task_id] = { "status": "processing", "file_index": 0, @@ -1033,18 +1230,7 @@ class KnowledgeBaseRoute(Route): } # 创建进度回调函数 - async def progress_callback(stage, current, total): - if task_id in self.upload_progress: - self.upload_progress[task_id].update( - { - "status": "processing", - "file_index": 0, - "file_name": f"URL: {url}", - "stage": stage, - "current": current, - "total": total, - }, - ) + progress_callback = self._make_progress_callback(task_id, 0, f"URL: {url}") # 上传文档 doc = await kb_helper.upload_from_url( @@ -1069,20 +1255,9 @@ class KnowledgeBaseRoute(Route): "failed_count": 0, } - self.upload_tasks[task_id] = { - "status": "completed", - "result": result, - "error": None, - } - self.upload_progress[task_id]["status"] = "completed" + self._set_task_result(task_id, "completed", result=result) except Exception as e: logger.error(f"后台上传URL任务 {task_id} 失败: {e}") logger.error(traceback.format_exc()) - self.upload_tasks[task_id] = { - "status": "failed", - "result": None, - "error": str(e), - } - if task_id in self.upload_progress: - self.upload_progress[task_id]["status"] = "failed" + self._set_task_result(task_id, "failed", error=str(e)) diff --git a/astrbot/dashboard/routes/platform.py b/astrbot/dashboard/routes/platform.py index 5b709a628..4d8fdddfe 100644 --- a/astrbot/dashboard/routes/platform.py +++ b/astrbot/dashboard/routes/platform.py @@ -82,7 +82,7 @@ class PlatformRoute(Route): """ for platform in self.platform_manager.platform_insts: if platform.config.get("webhook_uuid") == webhook_uuid: - if platform.config.get("unified_webhook_mode", False): + if platform.unified_webhook(): return platform return None diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index c249b07b7..fd808c6c9 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -124,7 +124,11 @@ class PluginRoute(Route): session.get(url) as response, ): if response.status == 200: - remote_data = await response.json() + try: + remote_data = await response.json() + except aiohttp.ContentTypeError: + remote_text = await response.text() + remote_data = json.loads(remote_text) # 检查远程数据是否为空 if not remote_data or ( diff --git a/changelogs/v4.9.0.md b/changelogs/v4.9.0.md new file mode 100644 index 000000000..aeccdb006 --- /dev/null +++ b/changelogs/v4.9.0.md @@ -0,0 +1,19 @@ +## What's Changed + +### 新增 + +- 支持自定义插件源。 +- 支持飞书(Lark)的 Webhook 模式(将事件推送至开发者服务器)。 +- 支持 “禁用自带指令” 快捷配置项,启用后将禁用所有 AstrBot 自带指令。入口: WebUI -> 配置文件 -> 平台配置。 + +### 优化 + +- 从 WebUI 移除了开发版本渠道。 +- 当试图测试"Agent Runner"时,提示前往配置文件页测试。 +- WebUI 列表项支持批量粘贴、回车创建项目。 + +### 修复 + +- Gemini API 部分调用失败的问题。 +- WebUI 插件安装加载 Dialog 关闭按钮在手机端下显示异常的问题。 +- 部分情况下,WebUI 日志显示不全的问题。 \ No newline at end of file diff --git a/dashboard/src/assets/images/loading-seio.webp b/dashboard/src/assets/images/loading-seio.webp new file mode 100644 index 000000000..62e159f98 Binary files /dev/null and b/dashboard/src/assets/images/loading-seio.webp differ diff --git a/dashboard/src/components/shared/ConsoleDisplayer.vue b/dashboard/src/components/shared/ConsoleDisplayer.vue index d3fa591e4..7d6759dfd 100644 --- a/dashboard/src/components/shared/ConsoleDisplayer.vue +++ b/dashboard/src/components/shared/ConsoleDisplayer.vue @@ -1,6 +1,7 @@