From aca5743ab61e523218fcd4b3b2168dccc54b3533 Mon Sep 17 00:00:00 2001 From: Dt8333 <25431943+Dt8333@users.noreply.github.com> Date: Sun, 9 Nov 2025 16:00:24 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=B8=BA=E9=83=A8=E5=88=86=E9=80=82?= =?UTF-8?q?=E9=85=8D=E5=99=A8=E6=B7=BB=E5=8A=A0=E7=BC=BA=E5=A4=B1=E7=9A=84?= =?UTF-8?q?=20send=5Fstreaming=20=E6=96=B9=E6=B3=95=20(#3545)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为Wechatpadpro和discord添加缺失的方法。 --- .../sources/discord/discord_platform_event.py | 23 +++++++++++++------ .../wechatpadpro_message_event.py | 16 +++++++++++++ 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index 0c1778eeb..82eb9f144 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -1,7 +1,7 @@ import asyncio import base64 import binascii -import sys +from collections.abc import AsyncGenerator from io import BytesIO from pathlib import Path @@ -21,11 +21,6 @@ from astrbot.api.platform import AstrBotMessage, At, PlatformMetadata from .client import DiscordBotClient from .components import DiscordEmbed, DiscordView -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - # 自定义Discord视图组件(兼容旧版本) class DiscordViewComponent(BaseMessageComponent): @@ -49,7 +44,6 @@ class DiscordPlatformEvent(AstrMessageEvent): self.client = client self.interaction_followup_webhook = interaction_followup_webhook - @override async def send(self, message: MessageChain): """发送消息到Discord平台""" # 解析消息链为 Discord 所需的对象 @@ -98,6 +92,21 @@ class DiscordPlatformEvent(AstrMessageEvent): await super().send(message) + async def send_streaming( + self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False + ): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return None + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator, use_fallback) + async def _get_channel(self) -> discord.abc.Messageable | None: """获取当前事件对应的频道对象""" try: diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py index e6d83d8ea..08ab27013 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py @@ -1,6 +1,7 @@ import asyncio import base64 import io +from collections.abc import AsyncGenerator from typing import TYPE_CHECKING import aiohttp @@ -50,6 +51,21 @@ class WeChatPadProMessageEvent(AstrMessageEvent): await self._send_voice(session, comp) await super().send(message) + async def send_streaming( + self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False + ): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return None + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator, use_fallback) + async def _send_image(self, session: aiohttp.ClientSession, comp: Image): b64 = await comp.convert_to_base64() raw = self._validate_base64(b64)