diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index 0c1778eeb..82eb9f144 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -1,7 +1,7 @@ import asyncio import base64 import binascii -import sys +from collections.abc import AsyncGenerator from io import BytesIO from pathlib import Path @@ -21,11 +21,6 @@ from astrbot.api.platform import AstrBotMessage, At, PlatformMetadata from .client import DiscordBotClient from .components import DiscordEmbed, DiscordView -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - # 自定义Discord视图组件(兼容旧版本) class DiscordViewComponent(BaseMessageComponent): @@ -49,7 +44,6 @@ class DiscordPlatformEvent(AstrMessageEvent): self.client = client self.interaction_followup_webhook = interaction_followup_webhook - @override async def send(self, message: MessageChain): """发送消息到Discord平台""" # 解析消息链为 Discord 所需的对象 @@ -98,6 +92,21 @@ class DiscordPlatformEvent(AstrMessageEvent): await super().send(message) + async def send_streaming( + self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False + ): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return None + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator, use_fallback) + async def _get_channel(self) -> discord.abc.Messageable | None: """获取当前事件对应的频道对象""" try: diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py index e6d83d8ea..08ab27013 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py @@ -1,6 +1,7 @@ import asyncio import base64 import io +from collections.abc import AsyncGenerator from typing import TYPE_CHECKING import aiohttp @@ -50,6 +51,21 @@ class WeChatPadProMessageEvent(AstrMessageEvent): await self._send_voice(session, comp) await super().send(message) + async def send_streaming( + self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False + ): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return None + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator, use_fallback) + async def _send_image(self, session: aiohttp.ClientSession, comp: Image): b64 = await comp.convert_to_base64() raw = self._validate_base64(b64)