feat: 支持消息分段回复
This commit is contained in:
@@ -24,7 +24,13 @@ DEFAULT_CONFIG = {
|
||||
"wl_ignore_admin_on_friend": True,
|
||||
"reply_with_mention": False,
|
||||
"reply_with_quote": False,
|
||||
"path_mapping": []
|
||||
"path_mapping": [],
|
||||
"segmented_reply": {
|
||||
"enable": False,
|
||||
"only_llm_result": True,
|
||||
"interval": "1.5,3.5",
|
||||
"regex": ".*?[。?!~…]+|.+$"
|
||||
}
|
||||
},
|
||||
"provider": [],
|
||||
"provider_settings": {
|
||||
@@ -182,6 +188,31 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"segmented_reply": {
|
||||
"description": "分段回复",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"description": "启用分段回复",
|
||||
"type": "bool",
|
||||
},
|
||||
"only_llm_result": {
|
||||
"description": "仅对 LLM 结果分段",
|
||||
"type": "bool",
|
||||
},
|
||||
"interval": {
|
||||
"description": "随机间隔时间(秒)",
|
||||
"type": "string",
|
||||
"hint": "每一段回复的间隔时间,格式为 `最小时间,最大时间`。如 `0.75,2.5`",
|
||||
},
|
||||
"regex": {
|
||||
"description": "正则表达式",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'<regex>', text)",
|
||||
},
|
||||
},
|
||||
},
|
||||
"reply_prefix": {
|
||||
"description": "回复前缀",
|
||||
"type": "string",
|
||||
|
||||
@@ -13,12 +13,10 @@ class MessageChain():
|
||||
Attributes:
|
||||
`chain` (list): 用于顺序存储各个组件。
|
||||
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
|
||||
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
'''
|
||||
|
||||
chain: List[BaseMessageComponent] = field(default_factory=list)
|
||||
use_t2i_: Optional[bool] = None # None 为跟随用户设置
|
||||
is_split_: Optional[bool] = False # 是否将消息分条发送。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
|
||||
def message(self, message: str):
|
||||
'''添加一条文本消息到消息链 `chain` 中。
|
||||
@@ -77,16 +75,6 @@ class MessageChain():
|
||||
'''
|
||||
self.use_t2i_ = use_t2i
|
||||
return self
|
||||
|
||||
def is_split(self, is_split: bool):
|
||||
'''设置是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
|
||||
Note:
|
||||
具体的效果以各适配器实现为准。
|
||||
|
||||
'''
|
||||
self.is_split_ = is_split
|
||||
return self
|
||||
|
||||
class EventResultType(enum.Enum):
|
||||
'''用于描述事件处理的结果类型。
|
||||
@@ -113,7 +101,6 @@ class MessageEventResult(MessageChain):
|
||||
Attributes:
|
||||
`chain` (list): 用于顺序存储各个组件。
|
||||
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
|
||||
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
`result_type` (EventResultType): 事件处理的结果类型。
|
||||
'''
|
||||
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import random
|
||||
import asyncio
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import register_stage, Stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
|
||||
@@ -9,6 +12,17 @@ from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
class RespondStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
self.ctx = ctx
|
||||
|
||||
# 分段回复
|
||||
self.enable_seg: bool = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
|
||||
interval_str: str = ctx.astrbot_config['platform_settings']['segmented_reply']['interval']
|
||||
interval_str_ls = interval_str.replace(" ", "").split(",")
|
||||
try:
|
||||
self.interval = [float(t) for t in interval_str_ls]
|
||||
except BaseException as e:
|
||||
logger.error(f'解析分段回复的间隔时间失败。{e}')
|
||||
self.interval = [1.5, 3.5]
|
||||
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
result = event.get_result()
|
||||
@@ -16,7 +30,15 @@ class RespondStage(Stage):
|
||||
return
|
||||
|
||||
if len(result.chain) > 0:
|
||||
await event.send(result)
|
||||
await event._pre_send()
|
||||
if self.enable_seg:
|
||||
# 分段回复
|
||||
for comp in result.chain:
|
||||
await event.send(MessageChain([comp]))
|
||||
await asyncio.sleep(random.uniform(self.interval[0], self.interval[1]))
|
||||
else:
|
||||
await event.send(result)
|
||||
await event._post_send()
|
||||
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
|
||||
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
import re
|
||||
import traceback
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import register_stage
|
||||
@@ -19,6 +20,11 @@ class ResultDecorateStage:
|
||||
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
|
||||
self.use_tts = ctx.astrbot_config['provider_tts_settings']['enable']
|
||||
self.t2i = ctx.astrbot_config['t2i']
|
||||
|
||||
# 分段回复
|
||||
self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
|
||||
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
|
||||
self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex']
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
result = event.get_result()
|
||||
@@ -33,26 +39,50 @@ class ResultDecorateStage:
|
||||
if len(result.chain) > 0:
|
||||
# 回复前缀
|
||||
if self.reply_prefix:
|
||||
result.chain.insert(0, Plain(self.reply_prefix))
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
comp.text = self.reply_prefix + comp.text
|
||||
break
|
||||
|
||||
# 分段回复
|
||||
if self.enable_segmented_reply:
|
||||
if (self.only_llm_result and result.is_llm_result()) or not self.only_llm_result:
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
split_response = re.findall(r".*?[。?!~…]+|.+$", comp.text)
|
||||
if not split_response:
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
for seg in split_response:
|
||||
new_chain.append(Plain(seg))
|
||||
else:
|
||||
# 非 Plain 类型的消息段不分段
|
||||
new_chain.append(comp)
|
||||
result.chain = new_chain
|
||||
|
||||
# TTS
|
||||
if self.use_tts and result.is_llm_result():
|
||||
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
||||
plain_str = ""
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
plain_str += " " + comp.text
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info("TTS 请求: " + plain_str)
|
||||
audio_path = await tts_provider.get_audio(plain_str)
|
||||
logger.info("TTS 结果: " + audio_path)
|
||||
if audio_path:
|
||||
new_chain.append(Record(file=audio_path, url=audio_path))
|
||||
else:
|
||||
logger.error(f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}")
|
||||
new_chain.append(comp)
|
||||
except BaseException:
|
||||
traceback.print_exc()
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
else:
|
||||
break
|
||||
if plain_str:
|
||||
try:
|
||||
audio_path = await tts_provider.get_audio(plain_str)
|
||||
logger.info("TTS 结果: " + audio_path)
|
||||
if audio_path:
|
||||
result.chain = [Record(file=audio_path, url=audio_path)]
|
||||
except BaseException:
|
||||
traceback.print_exc()
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
result.chain = new_chain
|
||||
|
||||
# 文本转图片
|
||||
elif (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
|
||||
|
||||
@@ -179,6 +179,15 @@ class AstrMessageEvent(abc.ABC):
|
||||
await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name)
|
||||
self._has_send_oper = True
|
||||
|
||||
async def _pre_send(self):
|
||||
'''调度器会在执行 send() 前调用该方法'''
|
||||
pass
|
||||
|
||||
async def _post_send(self):
|
||||
'''调度器会在执行 send() 后调用该方法'''
|
||||
pass
|
||||
|
||||
|
||||
def set_result(self, result: Union[MessageEventResult, str]):
|
||||
'''设置消息事件的结果。
|
||||
|
||||
|
||||
@@ -40,11 +40,5 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||
if os.environ.get('TEST_MODE', 'off') == 'on':
|
||||
return
|
||||
|
||||
if message.is_split_: # 分条发送
|
||||
for m in ret:
|
||||
await self.bot.send(self.message_obj.raw_message, [m])
|
||||
await asyncio.sleep(random.uniform(0.75, 2.5))
|
||||
else:
|
||||
await self.bot.send(self.message_obj.raw_message, ret)
|
||||
await self.bot.send(self.message_obj.raw_message, ret)
|
||||
await super().send(message)
|
||||
@@ -14,12 +14,20 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, bot: Client):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.bot = bot
|
||||
self.send_buffer = None
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
if not self.send_buffer:
|
||||
self.send_buffer = message
|
||||
else:
|
||||
self.send_buffer.chain.extend(message.chain)
|
||||
|
||||
async def _post_send(self):
|
||||
'''QQ 官方 API 仅支持回复一次'''
|
||||
source = self.message_obj.raw_message
|
||||
assert isinstance(source, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
|
||||
|
||||
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(message)
|
||||
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
|
||||
|
||||
payload = {
|
||||
'content': plain_text,
|
||||
@@ -48,7 +56,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
payload['file_image'] = image_path
|
||||
await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
||||
|
||||
await super().send(message)
|
||||
await super().send(self.send_buffer)
|
||||
|
||||
self.send_buffer = None
|
||||
|
||||
async def upload_group_and_c2c_image(self, image_base64: str, file_type: int, **kwargs) -> botpy.types.message.Media:
|
||||
payload = {
|
||||
|
||||
Reference in New Issue
Block a user