diff --git a/astrbot/builtin_stars/builtin_commands/commands/conversation.py b/astrbot/builtin_stars/builtin_commands/commands/conversation.py index eb8cfdefa..f6d5db914 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/conversation.py +++ b/astrbot/builtin_stars/builtin_commands/commands/conversation.py @@ -4,6 +4,7 @@ from astrbot.api import sp, star from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.core.platform.astr_message_event import MessageSession from astrbot.core.platform.message_type import MessageType +from astrbot.core.utils.active_event_registry import active_event_registry from .utils.rst_scene import RstScene @@ -62,6 +63,7 @@ class ConversationCommands: agent_runner_type = cfg["provider_settings"]["agent_runner_type"] if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY: + active_event_registry.stop_all(umo, exclude=message) await sp.remove_async( scope="umo", scope_id=umo, @@ -86,6 +88,8 @@ class ConversationCommands: ) return + active_event_registry.stop_all(umo, exclude=message) + await self.context.conversation_manager.update_conversation( umo, cid, @@ -221,6 +225,7 @@ class ConversationCommands: cfg = self.context.get_config(umo=message.unified_msg_origin) agent_runner_type = cfg["provider_settings"]["agent_runner_type"] if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY: + active_event_registry.stop_all(message.unified_msg_origin, exclude=message) await sp.remove_async( scope="umo", scope_id=message.unified_msg_origin, @@ -229,6 +234,7 @@ class ConversationCommands: message.set_result(MessageEventResult().message("已创建新对话。")) return + active_event_registry.stop_all(message.unified_msg_origin, exclude=message) cpersona = await self._get_current_persona_id(message.unified_msg_origin) cid = await self.context.conversation_manager.new_conversation( message.unified_msg_origin, @@ -321,7 +327,8 @@ class ConversationCommands: async def del_conv(self, message: AstrMessageEvent) -> None: """删除当前对话""" - cfg = self.context.get_config(umo=message.unified_msg_origin) + umo = message.unified_msg_origin + cfg = self.context.get_config(umo=umo) is_unique_session = cfg["platform_settings"]["unique_session"] if message.get_group_id() and not is_unique_session and message.role != "admin": # 群聊,没开独立会话,发送人不是管理员 @@ -334,18 +341,17 @@ class ConversationCommands: agent_runner_type = cfg["provider_settings"]["agent_runner_type"] if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY: + active_event_registry.stop_all(umo, exclude=message) await sp.remove_async( scope="umo", - scope_id=message.unified_msg_origin, + scope_id=umo, key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type], ) message.set_result(MessageEventResult().message("重置对话成功。")) return session_curr_cid = ( - await self.context.conversation_manager.get_curr_conversation_id( - message.unified_msg_origin, - ) + await self.context.conversation_manager.get_curr_conversation_id(umo) ) if not session_curr_cid: @@ -356,8 +362,10 @@ class ConversationCommands: ) return + active_event_registry.stop_all(umo, exclude=message) + await self.context.conversation_manager.delete_conversation( - message.unified_msg_origin, + umo, session_curr_cid, ) diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 71c98778f..c4a65077a 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -6,6 +6,7 @@ from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEv from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import ( WecomAIBotMessageEvent, ) +from astrbot.core.utils.active_event_registry import active_event_registry from . import STAGES_ORDER from .context import PipelineContext @@ -79,10 +80,14 @@ class PipelineScheduler: event (AstrMessageEvent): 事件对象 """ - await self._process_stages(event) + active_event_registry.register(event) + try: + await self._process_stages(event) - # 如果没有发送操作, 则发送一个空消息, 以便于后续的处理 - if isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent): - await event.send(None) + # 如果没有发送操作, 则发送一个空消息, 以便于后续的处理 + if isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent): + await event.send(None) - logger.debug("pipeline 执行完毕。") + logger.debug("pipeline 执行完毕。") + finally: + active_event_registry.unregister(event) diff --git a/astrbot/core/utils/active_event_registry.py b/astrbot/core/utils/active_event_registry.py new file mode 100644 index 000000000..254859933 --- /dev/null +++ b/astrbot/core/utils/active_event_registry.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from collections import defaultdict +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from astrbot.core.platform import AstrMessageEvent + + +class ActiveEventRegistry: + """维护 unified_msg_origin 到活跃事件的映射。 + + 用于在 reset 等场景下终止该会话正在处理的事件。 + """ + + def __init__(self) -> None: + self._events: dict[str, set[AstrMessageEvent]] = defaultdict(set) + + def register(self, event: AstrMessageEvent) -> None: + self._events[event.unified_msg_origin].add(event) + + def unregister(self, event: AstrMessageEvent) -> None: + umo = event.unified_msg_origin + self._events[umo].discard(event) + if not self._events[umo]: + del self._events[umo] + + def stop_all( + self, + umo: str, + exclude: AstrMessageEvent | None = None, + ) -> int: + """终止指定 UMO 的所有活跃事件。 + + Args: + umo: 统一消息来源标识符。 + exclude: 需要排除的事件(通常是发起 reset 的事件本身)。 + + Returns: + 被终止的事件数量。 + """ + count = 0 + for event in list(self._events.get(umo, [])): + if event is not exclude: + event.stop_event() + count += 1 + return count + + +active_event_registry = ActiveEventRegistry()