From ceaa69da757b819ebee7e0f343998ca88e342c6b Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 26 Jan 2025 13:45:32 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E5=88=86=E6=AE=B5=E5=9B=9E=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 33 ++++++++++- astrbot/core/message/message_event_result.py | 13 ----- astrbot/core/pipeline/respond/stage.py | 24 +++++++- .../core/pipeline/result_decorate/stage.py | 58 ++++++++++++++----- astrbot/core/platform/astr_message_event.py | 9 +++ .../aiocqhttp/aiocqhttp_message_event.py | 8 +-- .../qqofficial/qqofficial_message_event.py | 14 ++++- 7 files changed, 121 insertions(+), 38 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 91a5b1a5d..2e5aaf264 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -24,7 +24,13 @@ DEFAULT_CONFIG = { "wl_ignore_admin_on_friend": True, "reply_with_mention": False, "reply_with_quote": False, - "path_mapping": [] + "path_mapping": [], + "segmented_reply": { + "enable": False, + "only_llm_result": True, + "interval": "1.5,3.5", + "regex": ".*?[。?!~…]+|.+$" + } }, "provider": [], "provider_settings": { @@ -182,6 +188,31 @@ CONFIG_METADATA_2 = { }, }, }, + "segmented_reply": { + "description": "分段回复", + "type": "object", + "items": { + "enable": { + "description": "启用分段回复", + "type": "bool", + }, + "only_llm_result": { + "description": "仅对 LLM 结果分段", + "type": "bool", + }, + "interval": { + "description": "随机间隔时间(秒)", + "type": "string", + "hint": "每一段回复的间隔时间,格式为 `最小时间,最大时间`。如 `0.75,2.5`", + }, + "regex": { + "description": "正则表达式", + "type": "string", + "obvious_hint": True, + "hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'', text)", + }, + }, + }, "reply_prefix": { "description": "回复前缀", "type": "string", diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index 9ecb0cd23..c9c13ec9c 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -13,12 +13,10 @@ class MessageChain(): Attributes: `chain` (list): 用于顺序存储各个组件。 `use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。 - `is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。 ''' chain: List[BaseMessageComponent] = field(default_factory=list) use_t2i_: Optional[bool] = None # None 为跟随用户设置 - is_split_: Optional[bool] = False # 是否将消息分条发送。默认为 False。启用后,将会依次发送 chain 中的每个 component。 def message(self, message: str): '''添加一条文本消息到消息链 `chain` 中。 @@ -77,16 +75,6 @@ class MessageChain(): ''' self.use_t2i_ = use_t2i return self - - def is_split(self, is_split: bool): - '''设置是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。 - - Note: - 具体的效果以各适配器实现为准。 - - ''' - self.is_split_ = is_split - return self class EventResultType(enum.Enum): '''用于描述事件处理的结果类型。 @@ -113,7 +101,6 @@ class MessageEventResult(MessageChain): Attributes: `chain` (list): 用于顺序存储各个组件。 `use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。 - `is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。 `result_type` (EventResultType): 事件处理的结果类型。 ''' diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 9b98c6030..95ec8b898 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -1,7 +1,10 @@ +import random +import asyncio from typing import Union, AsyncGenerator from ..stage import register_stage, Stage from ..context import PipelineContext from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.message.message_event_result import MessageChain from astrbot.core import logger from astrbot.core.star.star_handler import star_handlers_registry, EventType @@ -9,6 +12,17 @@ from astrbot.core.star.star_handler import star_handlers_registry, EventType class RespondStage(Stage): async def initialize(self, ctx: PipelineContext): self.ctx = ctx + + # 分段回复 + self.enable_seg: bool = ctx.astrbot_config['platform_settings']['segmented_reply']['enable'] + interval_str: str = ctx.astrbot_config['platform_settings']['segmented_reply']['interval'] + interval_str_ls = interval_str.replace(" ", "").split(",") + try: + self.interval = [float(t) for t in interval_str_ls] + except BaseException as e: + logger.error(f'解析分段回复的间隔时间失败。{e}') + self.interval = [1.5, 3.5] + async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: result = event.get_result() @@ -16,7 +30,15 @@ class RespondStage(Stage): return if len(result.chain) > 0: - await event.send(result) + await event._pre_send() + if self.enable_seg: + # 分段回复 + for comp in result.chain: + await event.send(MessageChain([comp])) + await asyncio.sleep(random.uniform(self.interval[0], self.interval[1])) + else: + await event.send(result) + await event._post_send() logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}") handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent) diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 46e7914ee..c2f856c32 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -1,4 +1,5 @@ import time +import re import traceback from typing import Union, AsyncGenerator from ..stage import register_stage @@ -19,6 +20,11 @@ class ResultDecorateStage: self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote'] self.use_tts = ctx.astrbot_config['provider_tts_settings']['enable'] self.t2i = ctx.astrbot_config['t2i'] + + # 分段回复 + self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable'] + self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result'] + self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex'] async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: result = event.get_result() @@ -33,26 +39,50 @@ class ResultDecorateStage: if len(result.chain) > 0: # 回复前缀 if self.reply_prefix: - result.chain.insert(0, Plain(self.reply_prefix)) + for comp in result.chain: + if isinstance(comp, Plain): + comp.text = self.reply_prefix + comp.text + break + + # 分段回复 + if self.enable_segmented_reply: + if (self.only_llm_result and result.is_llm_result()) or not self.only_llm_result: + new_chain = [] + for comp in result.chain: + if isinstance(comp, Plain): + split_response = re.findall(r".*?[。?!~…]+|.+$", comp.text) + if not split_response: + new_chain.append(comp) + continue + for seg in split_response: + new_chain.append(Plain(seg)) + else: + # 非 Plain 类型的消息段不分段 + new_chain.append(comp) + result.chain = new_chain # TTS if self.use_tts and result.is_llm_result(): tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst - plain_str = "" + new_chain = [] for comp in result.chain: - if isinstance(comp, Plain): - plain_str += " " + comp.text + if isinstance(comp, Plain) and len(comp.text) > 1: + try: + logger.info("TTS 请求: " + plain_str) + audio_path = await tts_provider.get_audio(plain_str) + logger.info("TTS 结果: " + audio_path) + if audio_path: + new_chain.append(Record(file=audio_path, url=audio_path)) + else: + logger.error(f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}") + new_chain.append(comp) + except BaseException: + traceback.print_exc() + logger.error("TTS 失败,使用文本发送。") + new_chain.append(comp) else: - break - if plain_str: - try: - audio_path = await tts_provider.get_audio(plain_str) - logger.info("TTS 结果: " + audio_path) - if audio_path: - result.chain = [Record(file=audio_path, url=audio_path)] - except BaseException: - traceback.print_exc() - logger.error("TTS 失败,使用文本发送。") + new_chain.append(comp) + result.chain = new_chain # 文本转图片 elif (result.use_t2i_ is None and self.t2i) or result.use_t2i_: diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index a743610a3..cf380332c 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -179,6 +179,15 @@ class AstrMessageEvent(abc.ABC): await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name) self._has_send_oper = True + async def _pre_send(self): + '''调度器会在执行 send() 前调用该方法''' + pass + + async def _post_send(self): + '''调度器会在执行 send() 后调用该方法''' + pass + + def set_result(self, result: Union[MessageEventResult, str]): '''设置消息事件的结果。 diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index aa26516c8..482f676cf 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -40,11 +40,5 @@ class AiocqhttpMessageEvent(AstrMessageEvent): ret = await AiocqhttpMessageEvent._parse_onebot_json(message) if os.environ.get('TEST_MODE', 'off') == 'on': return - - if message.is_split_: # 分条发送 - for m in ret: - await self.bot.send(self.message_obj.raw_message, [m]) - await asyncio.sleep(random.uniform(0.75, 2.5)) - else: - await self.bot.send(self.message_obj.raw_message, ret) + await self.bot.send(self.message_obj.raw_message, ret) await super().send(message) \ No newline at end of file diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index 50074d0f1..75db60ded 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -14,12 +14,20 @@ class QQOfficialMessageEvent(AstrMessageEvent): def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, bot: Client): super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot + self.send_buffer = None async def send(self, message: MessageChain): + if not self.send_buffer: + self.send_buffer = message + else: + self.send_buffer.chain.extend(message.chain) + + async def _post_send(self): + '''QQ 官方 API 仅支持回复一次''' source = self.message_obj.raw_message assert isinstance(source, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage)) - plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(message) + plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer) payload = { 'content': plain_text, @@ -48,7 +56,9 @@ class QQOfficialMessageEvent(AstrMessageEvent): payload['file_image'] = image_path await self.bot.api.post_dms(guild_id=source.guild_id, **payload) - await super().send(message) + await super().send(self.send_buffer) + + self.send_buffer = None async def upload_group_and_c2c_image(self, image_base64: str, file_type: int, **kwargs) -> botpy.types.message.Media: payload = {