✨ feat: 会话控制器支持自定义会话ID算子
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
会话控制
|
||||
"""
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import time
|
||||
import functools
|
||||
@@ -11,6 +12,7 @@ from typing import Dict, Any, Callable, Awaitable, List
|
||||
from astrbot.core.platform import AstrMessageEvent
|
||||
|
||||
USER_SESSIONS: Dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例
|
||||
FILTERS: List["SessionFilter"] = [] # 存储 SessionFilter 实例
|
||||
|
||||
|
||||
class SessionController:
|
||||
@@ -84,9 +86,30 @@ class SessionController:
|
||||
return self.history_chains
|
||||
|
||||
|
||||
class SessionFilter:
|
||||
"""如何界定一个会话"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def filter(self, event: AstrMessageEvent) -> str:
|
||||
"""根据事件返回一个会话标识符"""
|
||||
pass
|
||||
|
||||
|
||||
class DefaultSessionFilter(SessionFilter):
|
||||
def filter(self, event: AstrMessageEvent) -> str:
|
||||
"""默认实现,返回发送者的 ID 作为会话标识符"""
|
||||
return event.get_sender_id()
|
||||
|
||||
|
||||
class SessionWaiter:
|
||||
def __init__(self, session_id: str, record_history_chains: bool):
|
||||
def __init__(
|
||||
self,
|
||||
session_filter: SessionFilter,
|
||||
session_id: str,
|
||||
record_history_chains: bool,
|
||||
):
|
||||
self.session_id = session_id
|
||||
self.session_filter = session_filter
|
||||
self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数
|
||||
|
||||
self.session_controller = SessionController()
|
||||
@@ -117,6 +140,10 @@ class SessionWaiter:
|
||||
def _cleanup(self, error: Exception = None):
|
||||
"""清理会话"""
|
||||
USER_SESSIONS.pop(self.session_id, None)
|
||||
try:
|
||||
FILTERS.remove(self.session_filter)
|
||||
except ValueError:
|
||||
pass
|
||||
self.session_controller.stop(error)
|
||||
|
||||
@classmethod
|
||||
@@ -149,12 +176,21 @@ def session_waiter(timeout: int = 30, record_history_chains: bool = False):
|
||||
|
||||
def decorator(func: Callable[[str], Awaitable[Any]]):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
session_id = kwargs.get("session_id", None)
|
||||
if not session_id:
|
||||
raise ValueError("缺少 session_id 参数")
|
||||
async def wrapper(
|
||||
event: AstrMessageEvent,
|
||||
session_filter: SessionFilter = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not session_filter:
|
||||
session_filter = DefaultSessionFilter()
|
||||
if not isinstance(session_filter, SessionFilter):
|
||||
raise ValueError("session_filter 必须是 SessionFilter")
|
||||
|
||||
waiter = SessionWaiter(session_id, record_history_chains)
|
||||
session_id = session_filter.filter(event)
|
||||
FILTERS.append(session_filter)
|
||||
|
||||
waiter = SessionWaiter(session_filter, session_id, record_history_chains)
|
||||
return await waiter.register_wait(func, timeout)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -6,6 +6,7 @@ from astrbot.api.star import Context, Star, register
|
||||
from astrbot.core.utils.session_waiter import (
|
||||
SessionWaiter,
|
||||
USER_SESSIONS,
|
||||
FILTERS,
|
||||
session_waiter,
|
||||
SessionController,
|
||||
)
|
||||
@@ -33,10 +34,11 @@ class Waiter(Star):
|
||||
@filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize)
|
||||
async def handle_session_control_agent(self, event: AstrMessageEvent):
|
||||
"""会话控制代理"""
|
||||
session_id = event.get_sender_id()
|
||||
if session_id in USER_SESSIONS:
|
||||
await SessionWaiter.trigger(session_id, event)
|
||||
event.stop_event()
|
||||
for session_filter in FILTERS:
|
||||
session_id = session_filter.filter(event)
|
||||
if session_id in USER_SESSIONS:
|
||||
await SessionWaiter.trigger(session_id, event)
|
||||
event.stop_event()
|
||||
|
||||
@filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize - 1)
|
||||
async def handle_empty_mention(self, event: AstrMessageEvent):
|
||||
@@ -70,9 +72,7 @@ class Waiter(Star):
|
||||
controller.stop()
|
||||
|
||||
try:
|
||||
await empty_mention_waiter(
|
||||
event, session_id=event.get_sender_id()
|
||||
)
|
||||
await empty_mention_waiter(event)
|
||||
except TimeoutError as _:
|
||||
yield event.plain_result("如果需要帮助,请再次 @ 我哦~")
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user