diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 9bb8b938f..24591d3e8 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -1,5 +1,7 @@ import asyncio -import typing +from typing import AsyncGenerator +import re + from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import Group, MessageMember from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes @@ -82,17 +84,30 @@ class AiocqhttpMessageEvent(AstrMessageEvent): await super().send(message) - async def send_streaming(self, generator): - buffer = None + async def send_streaming(self, generator: AsyncGenerator): + buffer = "" + pattern = r"[^。?!~…]+[。?!~…]+" + async for chain in generator: - if not buffer: - buffer = chain - else: - buffer.chain.extend(chain.chain) - if not buffer: - return - buffer.squash_plain() - await self.send(buffer) + if isinstance(chain, MessageChain): + for comp in chain.chain: + if isinstance(comp, Plain): + buffer += comp.text + + if any(p in buffer for p in "。?!~…"): + while True: + match = re.search(pattern, buffer) + if not match: + break + matched_text = match.group() + await self.send(MessageChain([Plain(matched_text)])) + buffer = buffer[match.end() :] + await asyncio.sleep(0.5) # 限速 + else: + await self.send(MessageChain(chain=[comp])) + + if buffer.strip(): + await self.send(MessageChain([Plain(buffer)])) return await super().send_streaming(generator) async def get_group(self, group_id=None, **kwargs): diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py index d850a759f..8292e5521 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py @@ -1,8 +1,12 @@ import asyncio +import re +from typing import AsyncGenerator + import dingtalk_stream import astrbot.api.message_components as Comp from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot import logger +from astrbot.core.message.components import Plain class DingtalkMessageEvent(AstrMessageEvent): @@ -61,15 +65,27 @@ class DingtalkMessageEvent(AstrMessageEvent): await self.send_with_client(self.client, message) await super().send(message) - async def send_streaming(self, generator): - buffer = None + async def send_streaming(self, generator: AsyncGenerator): + buffer = "" + pattern = r"[^。?!~…]+[。?!~…]+" + async for chain in generator: - if not buffer: - buffer = chain - else: - buffer.chain.extend(chain.chain) - if not buffer: - return - buffer.squash_plain() - await self.send(buffer) + if isinstance(chain, MessageChain): + for comp in chain.chain: + if isinstance(comp, Plain): + buffer += comp.text + + if any(p in buffer for p in "。?!~…"): + while True: + match = re.search(pattern, buffer) + if not match: + break + matched_text = match.group() + await self.send(MessageChain([Plain(matched_text)])) + buffer = buffer[match.end() :] + await asyncio.sleep(0.5) # 限速 + else: + await self.send(MessageChain(chain=[comp])) + if buffer.strip(): + await self.send(MessageChain([Plain(buffer)])) return await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/gewechat/gewechat_event.py b/astrbot/core/platform/sources/gewechat/gewechat_event.py index 829a348c6..01e58880a 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_event.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_event.py @@ -1,7 +1,10 @@ +import asyncio +import re import wave import uuid import traceback import os +from typing import AsyncGenerator from astrbot.core.utils.io import save_temp_img, download_file from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk @@ -217,15 +220,27 @@ class GewechatPlatformEvent(AstrMessageEvent): members=members, ) - async def send_streaming(self, generator): - buffer = None + async def send_streaming(self, generator: AsyncGenerator): + buffer = "" + pattern = r"[^。?!~…]+[。?!~…]+" + async for chain in generator: - if not buffer: - buffer = chain - else: - buffer.chain.extend(chain.chain) - if not buffer: - return - buffer.squash_plain() - await self.send(buffer) + if isinstance(chain, MessageChain): + for comp in chain.chain: + if isinstance(comp, Plain): + buffer += comp.text + + if any(p in buffer for p in "。?!~…"): + while True: + match = re.search(pattern, buffer) + if not match: + break + matched_text = match.group() + await self.send(MessageChain([Plain(matched_text)])) + buffer = buffer[match.end() :] + await asyncio.sleep(0.5) # 限速 + else: + await self.send(MessageChain(chain=[comp])) + if buffer.strip(): + await self.send(MessageChain([Plain(buffer)])) return await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index 544a7a5be..de0759710 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -1,7 +1,9 @@ +import asyncio import json +import re import uuid import lark_oapi as lark -from typing import List +from typing import List, AsyncGenerator from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import Plain, Image as AstrBotImage, At from astrbot.core.utils.io import download_image_by_url @@ -92,15 +94,27 @@ class LarkMessageEvent(AstrMessageEvent): await super().send(message) - async def send_streaming(self, generator): - buffer = None + async def send_streaming(self, generator: AsyncGenerator): + buffer = "" + pattern = r"[^。?!~…]+[。?!~…]+" + async for chain in generator: - if not buffer: - buffer = chain - else: - buffer.chain.extend(chain.chain) - if not buffer: - return - buffer.squash_plain() - await self.send(buffer) + if isinstance(chain, MessageChain): + for comp in chain.chain: + if isinstance(comp, Plain): + buffer += comp.text + + if any(p in buffer for p in "。?!~…"): + while True: + match = re.search(pattern, buffer) + if not match: + break + matched_text = match.group() + await self.send(MessageChain([Plain(matched_text)])) + buffer = buffer[match.end() :] + await asyncio.sleep(0.5) # 限速 + else: + await self.send(MessageChain(chain=[comp])) + if buffer.strip(): + await self.send(MessageChain([Plain(buffer)])) return await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index d8ee8b9a3..3c383024e 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -1,4 +1,8 @@ +import asyncio +import re import uuid +from typing import AsyncGenerator + from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.api.message_components import Plain, Image, Record @@ -85,15 +89,27 @@ class WecomPlatformEvent(AstrMessageEvent): await super().send(message) - async def send_streaming(self, generator): - buffer = None + async def send_streaming(self, generator: AsyncGenerator): + buffer = "" + pattern = r"[^。?!~…]+[。?!~…]+" + async for chain in generator: - if not buffer: - buffer = chain - else: - buffer.chain.extend(chain.chain) - if not buffer: - return - buffer.squash_plain() - await self.send(buffer) + if isinstance(chain, MessageChain): + for comp in chain.chain: + if isinstance(comp, Plain): + buffer += comp.text + + if any(p in buffer for p in "。?!~…"): + while True: + match = re.search(pattern, buffer) + if not match: + break + matched_text = match.group() + await self.send(MessageChain([Plain(matched_text)])) + buffer = buffer[match.end() :] + await asyncio.sleep(0.5) # 限速 + else: + await self.send(MessageChain(chain=[comp])) + if buffer.strip(): + await self.send(MessageChain([Plain(buffer)])) return await super().send_streaming(generator)