diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index ebbba7ed3..a674be313 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -1,17 +1,15 @@ import random import asyncio import math -import traceback import astrbot.core.message.components as Comp from typing import Union, AsyncGenerator from ..stage import register_stage, Stage -from ..context import PipelineContext +from ..context import PipelineContext, call_event_hook from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.message.message_event_result import MessageChain, ResultContentType from astrbot.core import logger -from astrbot.core.message.message_event_result import BaseMessageComponent -from astrbot.core.star.star_handler import star_handlers_registry, EventType -from astrbot.core.star.star import star_map +from astrbot.core.message.components import BaseMessageComponent, ComponentType +from astrbot.core.star.star_handler import EventType from astrbot.core.utils.path_util import path_Mapping from astrbot.core.utils.session_lock import session_lock_manager @@ -114,6 +112,43 @@ class RespondStage(Stage): # 如果所有组件都为空 return True + def is_seg_reply_required(self, event: AstrMessageEvent) -> bool: + """检查是否需要分段回复""" + if not self.enable_seg: + return False + + if self.only_llm_result and not event.get_result().is_llm_result(): + return False + + if event.get_platform_name() in [ + "qq_official", + "weixin_official_account", + "dingtalk", + ]: + return False + + return True + + def _extract_comp( + self, + raw_chain: list[BaseMessageComponent], + extract_types: set[ComponentType], + modify_raw_chain: bool = True, + ): + extracted = [] + if modify_raw_chain: + remaining = [] + for comp in raw_chain: + if comp.type in extract_types: + extracted.append(comp) + else: + remaining.append(comp) + raw_chain[:] = remaining + else: + extracted = [comp for comp in raw_chain if comp.type in extract_types] + + return extracted + async def process( self, event: AstrMessageEvent ) -> Union[None, AsyncGenerator[None, None]]: @@ -123,7 +158,14 @@ class RespondStage(Stage): if result.result_content_type == ResultContentType.STREAMING_FINISH: return + logger.info( + f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}" + ) + if result.result_content_type == ResultContentType.STREAMING_RESULT: + if result.async_stream is None: + logger.warning("async_stream 为空,跳过发送。") + return # 流式结果直接交付平台适配器处理 use_fallback = self.config.get("provider_settings", {}).get( "streaming_segmented", False @@ -148,87 +190,71 @@ class RespondStage(Stage): except Exception as e: logger.warning(f"空内容检查异常: {e}") - record_comps = [c for c in result.chain if isinstance(c, Comp.Record)] - non_record_comps = [ - c for c in result.chain if not isinstance(c, Comp.Record) - ] - - if ( - self.enable_seg - and ( - (self.only_llm_result and result.is_llm_result()) - or not self.only_llm_result + # 发送消息链 + # Record 需要强制单独发送 + need_separately = {ComponentType.Record} + if self.is_seg_reply_required(event): + header_comps = self._extract_comp( + result.chain, + {ComponentType.Reply, ComponentType.At}, + modify_raw_chain=True, ) - and event.get_platform_name() - not in ["qq_official", "weixin_official_account", "dingtalk"] - ): - decorated_comps = [] - if self.reply_with_mention: - for comp in result.chain: - if isinstance(comp, Comp.At): - decorated_comps.append(comp) - result.chain.remove(comp) - break - if self.reply_with_quote: - for comp in result.chain: - if isinstance(comp, Comp.Reply): - decorated_comps.append(comp) - result.chain.remove(comp) - break - - # leverage lock to guarentee the order of message sending among different events + if not result.chain or len(result.chain) == 0: + # may fix #2670 + logger.warning( + f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}" + ) + return async with session_lock_manager.acquire_lock(event.unified_msg_origin): - for rcomp in record_comps: - i = await self._calc_comp_interval(rcomp) - await asyncio.sleep(i) - try: - await event.send(MessageChain([rcomp])) - except Exception as e: - logger.error(f"发送消息失败: {e} chain: {result.chain}") - break - # 分段回复 - for comp in non_record_comps: + for comp in result.chain: i = await self._calc_comp_interval(comp) await asyncio.sleep(i) try: - await event.send(MessageChain([*decorated_comps, comp])) - decorated_comps = [] # 清空已发送的装饰组件 + if comp.type in need_separately: + await event.send(MessageChain([comp])) + else: + await event.send(MessageChain([*header_comps, comp])) + header_comps.clear() except Exception as e: - logger.error(f"发送消息失败: {e} chain: {result.chain}") - break + logger.error( + f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}", + exc_info=True, + ) else: - for rcomp in record_comps: + if all( + comp.type in {ComponentType.Reply, ComponentType.At} + for comp in result.chain + ): + # may fix #2670 + logger.warning( + f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}" + ) + return + sep_comps = self._extract_comp( + result.chain, + need_separately, + modify_raw_chain=True, + ) + for comp in sep_comps: + chain = MessageChain([comp]) try: - await event.send(MessageChain([rcomp])) + await event.send(chain) except Exception as e: - logger.error(f"发送消息失败: {e} chain: {result.chain}") + logger.error( + f"发送消息链失败: chain = {chain}, error = {e}", + exc_info=True, + ) + chain = MessageChain(result.chain) + if result.chain and len(result.chain) > 0: + try: + await event.send(chain) + except Exception as e: + logger.error( + f"发送消息链失败: chain = {chain}, error = {e}", + exc_info=True, + ) - try: - await event.send(MessageChain(non_record_comps)) - except Exception as e: - logger.error(traceback.format_exc()) - logger.error(f"发送消息失败: {e} chain: {result.chain}") - - logger.info( - f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}" - ) - - handlers = star_handlers_registry.get_handlers_by_event_type( - EventType.OnAfterMessageSentEvent, plugins_name=event.plugins_name - ) - for handler in handlers: - try: - logger.debug( - f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" - ) - await handler.handler(event) - except BaseException: - logger.error(traceback.format_exc()) - - if event.is_stopped(): - logger.info( - f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" - ) - return + if await call_event_hook(event, EventType.OnAfterMessageSentEvent): + return event.clear_result()