From 0caff054f5b7feffe6981441611b5c0e8b7063b2 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 8 Mar 2025 20:29:42 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20=E4=BC=9A=E8=AF=9D=E6=8E=A7?= =?UTF-8?q?=E5=88=B6=E5=99=A8=E6=94=AF=E6=8C=81=E8=87=AA=E5=AE=9A=E4=B9=89?= =?UTF-8?q?=E4=BC=9A=E8=AF=9DID=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/utils/session_waiter.py | 48 ++++++++++++++++++++++++---- packages/session_controller/main.py | 14 ++++---- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py index 616aa17b9..4a867bff4 100644 --- a/astrbot/core/utils/session_waiter.py +++ b/astrbot/core/utils/session_waiter.py @@ -2,6 +2,7 @@ 会话控制 """ +import abc import asyncio import time import functools @@ -11,6 +12,7 @@ from typing import Dict, Any, Callable, Awaitable, List from astrbot.core.platform import AstrMessageEvent USER_SESSIONS: Dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例 +FILTERS: List["SessionFilter"] = [] # 存储 SessionFilter 实例 class SessionController: @@ -84,9 +86,30 @@ class SessionController: return self.history_chains +class SessionFilter: + """如何界定一个会话""" + + @abc.abstractmethod + def filter(self, event: AstrMessageEvent) -> str: + """根据事件返回一个会话标识符""" + pass + + +class DefaultSessionFilter(SessionFilter): + def filter(self, event: AstrMessageEvent) -> str: + """默认实现,返回发送者的 ID 作为会话标识符""" + return event.get_sender_id() + + class SessionWaiter: - def __init__(self, session_id: str, record_history_chains: bool): + def __init__( + self, + session_filter: SessionFilter, + session_id: str, + record_history_chains: bool, + ): self.session_id = session_id + self.session_filter = session_filter self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数 self.session_controller = SessionController() @@ -117,6 +140,10 @@ class SessionWaiter: def _cleanup(self, error: Exception = None): """清理会话""" USER_SESSIONS.pop(self.session_id, None) + try: + FILTERS.remove(self.session_filter) + except ValueError: + pass self.session_controller.stop(error) @classmethod @@ -149,12 +176,21 @@ def session_waiter(timeout: int = 30, record_history_chains: bool = False): def decorator(func: Callable[[str], Awaitable[Any]]): @functools.wraps(func) - async def wrapper(*args, **kwargs): - session_id = kwargs.get("session_id", None) - if not session_id: - raise ValueError("缺少 session_id 参数") + async def wrapper( + event: AstrMessageEvent, + session_filter: SessionFilter = None, + *args, + **kwargs, + ): + if not session_filter: + session_filter = DefaultSessionFilter() + if not isinstance(session_filter, SessionFilter): + raise ValueError("session_filter 必须是 SessionFilter") - waiter = SessionWaiter(session_id, record_history_chains) + session_id = session_filter.filter(event) + FILTERS.append(session_filter) + + waiter = SessionWaiter(session_filter, session_id, record_history_chains) return await waiter.register_wait(func, timeout) return wrapper diff --git a/packages/session_controller/main.py b/packages/session_controller/main.py index 0750e4c52..89a98dca0 100644 --- a/packages/session_controller/main.py +++ b/packages/session_controller/main.py @@ -6,6 +6,7 @@ from astrbot.api.star import Context, Star, register from astrbot.core.utils.session_waiter import ( SessionWaiter, USER_SESSIONS, + FILTERS, session_waiter, SessionController, ) @@ -33,10 +34,11 @@ class Waiter(Star): @filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize) async def handle_session_control_agent(self, event: AstrMessageEvent): """会话控制代理""" - session_id = event.get_sender_id() - if session_id in USER_SESSIONS: - await SessionWaiter.trigger(session_id, event) - event.stop_event() + for session_filter in FILTERS: + session_id = session_filter.filter(event) + if session_id in USER_SESSIONS: + await SessionWaiter.trigger(session_id, event) + event.stop_event() @filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize - 1) async def handle_empty_mention(self, event: AstrMessageEvent): @@ -70,9 +72,7 @@ class Waiter(Star): controller.stop() try: - await empty_mention_waiter( - event, session_id=event.get_sender_id() - ) + await empty_mention_waiter(event) except TimeoutError as _: yield event.plain_result("如果需要帮助,请再次 @ 我哦~") except Exception as e: