diff --git a/astrbot/api/util/__init__.py b/astrbot/api/util/__init__.py new file mode 100644 index 000000000..a66206e05 --- /dev/null +++ b/astrbot/api/util/__init__.py @@ -0,0 +1,7 @@ +from astrbot.core.utils.session_waiter import ( + SessionWaiter, + SessionController, + session_waiter, +) + +__all__ = ["SessionWaiter", "SessionController", "session_waiter"] diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py new file mode 100644 index 000000000..7b0c7855f --- /dev/null +++ b/astrbot/core/utils/session_waiter.py @@ -0,0 +1,163 @@ +""" +会话控制 +""" + +import asyncio +import time +import functools +import copy +import astrbot.core.message.components as Comp +from typing import Dict, Any, Callable, Awaitable, List +from astrbot.core.platform import AstrMessageEvent + +USER_SESSIONS: Dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例 + + +class SessionController: + """ + 控制一个 Session 是否已经结束 + """ + + def __init__(self): + self.future = asyncio.Future() + self.current_event: asyncio.Event = None + """当前正在等待的所用的异步事件""" + self.ts: float = None + """上次保持(keep)开始时的时间""" + self.timeout: float | int = None + """上次保持(keep)开始时的超时时间""" + + self.history_chains: List[List[Comp.BaseMessageComponent]] = [] + + def stop(self, error: Exception = None): + """立即结束这个会话""" + if not self.future.done(): + if error: + self.future.set_exception(error) + else: + self.future.set_result(None) + + def keep(self, timeout: float | int = 0, reset_timeout=False): + """保持这个会话 + + Args: + timeout (float): 必填。会话超时时间。 + 当 reset_timeout 设置为 True 时, 代表重置超时时间, timeout 必须 > 0, 如果 <= 0 则立即结束会话。 + 当 reset_timeout 设置为 False 时, 代表继续维持原来的超时时间, 新 timeout = 原来剩余的timeout + timeout (可以 < 0) + """ + new_ts = time.time() + + if reset_timeout: + if timeout <= 0: + self.stop() + return + else: + left_timeout = self.timeout - (new_ts - self.ts) + timeout = left_timeout + timeout + if timeout <= 0: + self.stop() + return + + if self.current_event and not self.current_event.is_set(): + self.current_event.set() # 通知上一个 keep 结束 + + new_event = asyncio.Event() + self.ts = new_ts + self.current_event = new_event + self.timeout = timeout + + asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep + + async def _holding(self, event: asyncio.Event, timeout: int): + """等待事件结束或超时""" + try: + await asyncio.wait_for(event.wait(), timeout) + except asyncio.TimeoutError: + if not self.future.done(): + self.future.set_exception(TimeoutError("等待超时")) + except asyncio.CancelledError: + pass # 避免报错 + # finally: + + def get_history_chains(self) -> List[List[Comp.BaseMessageComponent]]: + """获取历史消息链""" + return self.history_chains + + +class SessionWaiter: + def __init__(self, session_id: str, record_history_chains: bool): + self.session_id = session_id + self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数 + + self.session_controller = SessionController() + self.record_history_chains = record_history_chains + """是否记录历史消息链""" + + self._lock = asyncio.Lock() + """需要保证一个 session 同时只有一个 trigger""" + + async def register_wait( + self, handler: Callable[[str], Awaitable[Any]], timeout: int = 30 + ) -> Any: + """等待外部输入并处理""" + self.handler = handler + USER_SESSIONS[self.session_id] = self + + # 开始一个会话保持事件 + self.session_controller.keep(timeout, reset_timeout=True) + + try: + return await self.session_controller.future + except Exception as e: + self._cleanup(e) + raise e + finally: + self._cleanup() + + def _cleanup(self, error: Exception = None): + """清理会话""" + USER_SESSIONS.pop(self.session_id, None) + self.session_controller.stop(error) + + @classmethod + async def trigger(cls, session_id: str, event: AstrMessageEvent): + """外部输入触发会话处理""" + session = USER_SESSIONS.get(session_id, None) + if not session or session.session_controller.future.done(): + return + + async with session._lock: + if not session.session_controller.future.done(): + if session.record_history_chains: + session.session_controller.history_chains.append( + [copy.deepcopy(comp) for comp in event.get_messages()] + ) + try: + # TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行 + await session.handler(session.session_controller, event) + except Exception as e: + session.session_controller.stop(e) + + +def session_waiter(session_id_param: str, timeout: int = 30, record_history_chains: bool = False): + """ + 装饰器:自动将函数注册为 SessionWaiter 处理函数,并等待外部输入触发执行。 + + :param session_id_param: 用于从参数中获取 session_id 的参数名称 + :param timeout: 超时时间(秒) + :param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。 + """ + + def decorator(func: Callable[[str], Awaitable[Any]]): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + session_id = kwargs.get(session_id_param) + if not session_id: + raise ValueError(f"缺少 session_id 参数 '{session_id_param}'") + + waiter = SessionWaiter(session_id, record_history_chains) + return await waiter.register_wait(func, timeout) + + return wrapper + + return decorator diff --git a/packages/session_controller/main.py b/packages/session_controller/main.py new file mode 100644 index 000000000..280371042 --- /dev/null +++ b/packages/session_controller/main.py @@ -0,0 +1,25 @@ + +from astrbot.api.event import AstrMessageEvent, filter +from astrbot.api.star import Context, Star, register +from astrbot.core.utils.session_waiter import SessionWaiter, USER_SESSIONS +from sys import maxsize + +@register( + "session_controller", + "Cvandia & Soulter", + "为插件支持会话控制", + "v1.0.1", + "https://astrbot.app", +) +class Waiter(Star): + """会话控制""" + + def __init__(self, context: Context): + super().__init__(context) + + @filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize) + async def handle_session_control_agent(self, event: AstrMessageEvent): + session_id = event.unified_msg_origin + if session_id in USER_SESSIONS: + await SessionWaiter.trigger(session_id, event) + event.stop_event()