feat: 支持消息分段回复

This commit is contained in:
Soulter
2025-01-26 13:45:32 +08:00
parent fa8e731576
commit ceaa69da75
7 changed files with 121 additions and 38 deletions
+32 -1
View File
@@ -24,7 +24,13 @@ DEFAULT_CONFIG = {
"wl_ignore_admin_on_friend": True,
"reply_with_mention": False,
"reply_with_quote": False,
"path_mapping": []
"path_mapping": [],
"segmented_reply": {
"enable": False,
"only_llm_result": True,
"interval": "1.5,3.5",
"regex": ".*?[。?!~…]+|.+$"
}
},
"provider": [],
"provider_settings": {
@@ -182,6 +188,31 @@ CONFIG_METADATA_2 = {
},
},
},
"segmented_reply": {
"description": "分段回复",
"type": "object",
"items": {
"enable": {
"description": "启用分段回复",
"type": "bool",
},
"only_llm_result": {
"description": "仅对 LLM 结果分段",
"type": "bool",
},
"interval": {
"description": "随机间隔时间(秒)",
"type": "string",
"hint": "每一段回复的间隔时间,格式为 `最小时间,最大时间`。如 `0.75,2.5`",
},
"regex": {
"description": "正则表达式",
"type": "string",
"obvious_hint": True,
"hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'<regex>', text)",
},
},
},
"reply_prefix": {
"description": "回复前缀",
"type": "string",
@@ -13,12 +13,10 @@ class MessageChain():
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
'''
chain: List[BaseMessageComponent] = field(default_factory=list)
use_t2i_: Optional[bool] = None # None 为跟随用户设置
is_split_: Optional[bool] = False # 是否将消息分条发送。默认为 False。启用后,将会依次发送 chain 中的每个 component。
def message(self, message: str):
'''添加一条文本消息到消息链 `chain` 中。
@@ -77,16 +75,6 @@ class MessageChain():
'''
self.use_t2i_ = use_t2i
return self
def is_split(self, is_split: bool):
'''设置是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
Note:
具体的效果以各适配器实现为准。
'''
self.is_split_ = is_split
return self
class EventResultType(enum.Enum):
'''用于描述事件处理的结果类型。
@@ -113,7 +101,6 @@ class MessageEventResult(MessageChain):
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
`result_type` (EventResultType): 事件处理的结果类型。
'''
+23 -1
View File
@@ -1,7 +1,10 @@
import random
import asyncio
from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core import logger
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@@ -9,6 +12,17 @@ from astrbot.core.star.star_handler import star_handlers_registry, EventType
class RespondStage(Stage):
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
# 分段回复
self.enable_seg: bool = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
interval_str: str = ctx.astrbot_config['platform_settings']['segmented_reply']['interval']
interval_str_ls = interval_str.replace(" ", "").split(",")
try:
self.interval = [float(t) for t in interval_str_ls]
except BaseException as e:
logger.error(f'解析分段回复的间隔时间失败。{e}')
self.interval = [1.5, 3.5]
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
@@ -16,7 +30,15 @@ class RespondStage(Stage):
return
if len(result.chain) > 0:
await event.send(result)
await event._pre_send()
if self.enable_seg:
# 分段回复
for comp in result.chain:
await event.send(MessageChain([comp]))
await asyncio.sleep(random.uniform(self.interval[0], self.interval[1]))
else:
await event.send(result)
await event._post_send()
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
+44 -14
View File
@@ -1,4 +1,5 @@
import time
import re
import traceback
from typing import Union, AsyncGenerator
from ..stage import register_stage
@@ -19,6 +20,11 @@ class ResultDecorateStage:
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
self.use_tts = ctx.astrbot_config['provider_tts_settings']['enable']
self.t2i = ctx.astrbot_config['t2i']
# 分段回复
self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex']
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
@@ -33,26 +39,50 @@ class ResultDecorateStage:
if len(result.chain) > 0:
# 回复前缀
if self.reply_prefix:
result.chain.insert(0, Plain(self.reply_prefix))
for comp in result.chain:
if isinstance(comp, Plain):
comp.text = self.reply_prefix + comp.text
break
# 分段回复
if self.enable_segmented_reply:
if (self.only_llm_result and result.is_llm_result()) or not self.only_llm_result:
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain):
split_response = re.findall(r".*?[。?!~…]+|.+$", comp.text)
if not split_response:
new_chain.append(comp)
continue
for seg in split_response:
new_chain.append(Plain(seg))
else:
# 非 Plain 类型的消息段不分段
new_chain.append(comp)
result.chain = new_chain
# TTS
if self.use_tts and result.is_llm_result():
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
plain_str = ""
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain):
plain_str += " " + comp.text
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info("TTS 请求: " + plain_str)
audio_path = await tts_provider.get_audio(plain_str)
logger.info("TTS 结果: " + audio_path)
if audio_path:
new_chain.append(Record(file=audio_path, url=audio_path))
else:
logger.error(f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}")
new_chain.append(comp)
except BaseException:
traceback.print_exc()
logger.error("TTS 失败,使用文本发送。")
new_chain.append(comp)
else:
break
if plain_str:
try:
audio_path = await tts_provider.get_audio(plain_str)
logger.info("TTS 结果: " + audio_path)
if audio_path:
result.chain = [Record(file=audio_path, url=audio_path)]
except BaseException:
traceback.print_exc()
logger.error("TTS 失败,使用文本发送。")
new_chain.append(comp)
result.chain = new_chain
# 文本转图片
elif (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
@@ -179,6 +179,15 @@ class AstrMessageEvent(abc.ABC):
await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name)
self._has_send_oper = True
async def _pre_send(self):
'''调度器会在执行 send() 前调用该方法'''
pass
async def _post_send(self):
'''调度器会在执行 send() 后调用该方法'''
pass
def set_result(self, result: Union[MessageEventResult, str]):
'''设置消息事件的结果。
@@ -40,11 +40,5 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
if os.environ.get('TEST_MODE', 'off') == 'on':
return
if message.is_split_: # 分条发送
for m in ret:
await self.bot.send(self.message_obj.raw_message, [m])
await asyncio.sleep(random.uniform(0.75, 2.5))
else:
await self.bot.send(self.message_obj.raw_message, ret)
await self.bot.send(self.message_obj.raw_message, ret)
await super().send(message)
@@ -14,12 +14,20 @@ class QQOfficialMessageEvent(AstrMessageEvent):
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, bot: Client):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
self.send_buffer = None
async def send(self, message: MessageChain):
if not self.send_buffer:
self.send_buffer = message
else:
self.send_buffer.chain.extend(message.chain)
async def _post_send(self):
'''QQ 官方 API 仅支持回复一次'''
source = self.message_obj.raw_message
assert isinstance(source, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(message)
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
payload = {
'content': plain_text,
@@ -48,7 +56,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
payload['file_image'] = image_path
await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
await super().send(message)
await super().send(self.send_buffer)
self.send_buffer = None
async def upload_group_and_c2c_image(self, image_base64: str, file_type: int, **kwargs) -> botpy.types.message.Media:
payload = {