feat: 会话控制器支持自定义会话ID算子

This commit is contained in:
Soulter
2025-03-08 20:29:42 +08:00
parent 4aa91ad599
commit 0caff054f5
2 changed files with 49 additions and 13 deletions
+42 -6
View File
@@ -2,6 +2,7 @@
会话控制
"""
import abc
import asyncio
import time
import functools
@@ -11,6 +12,7 @@ from typing import Dict, Any, Callable, Awaitable, List
from astrbot.core.platform import AstrMessageEvent
USER_SESSIONS: Dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例
FILTERS: List["SessionFilter"] = [] # 存储 SessionFilter 实例
class SessionController:
@@ -84,9 +86,30 @@ class SessionController:
return self.history_chains
class SessionFilter:
"""如何界定一个会话"""
@abc.abstractmethod
def filter(self, event: AstrMessageEvent) -> str:
"""根据事件返回一个会话标识符"""
pass
class DefaultSessionFilter(SessionFilter):
def filter(self, event: AstrMessageEvent) -> str:
"""默认实现,返回发送者的 ID 作为会话标识符"""
return event.get_sender_id()
class SessionWaiter:
def __init__(self, session_id: str, record_history_chains: bool):
def __init__(
self,
session_filter: SessionFilter,
session_id: str,
record_history_chains: bool,
):
self.session_id = session_id
self.session_filter = session_filter
self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数
self.session_controller = SessionController()
@@ -117,6 +140,10 @@ class SessionWaiter:
def _cleanup(self, error: Exception = None):
"""清理会话"""
USER_SESSIONS.pop(self.session_id, None)
try:
FILTERS.remove(self.session_filter)
except ValueError:
pass
self.session_controller.stop(error)
@classmethod
@@ -149,12 +176,21 @@ def session_waiter(timeout: int = 30, record_history_chains: bool = False):
def decorator(func: Callable[[str], Awaitable[Any]]):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
session_id = kwargs.get("session_id", None)
if not session_id:
raise ValueError("缺少 session_id 参数")
async def wrapper(
event: AstrMessageEvent,
session_filter: SessionFilter = None,
*args,
**kwargs,
):
if not session_filter:
session_filter = DefaultSessionFilter()
if not isinstance(session_filter, SessionFilter):
raise ValueError("session_filter 必须是 SessionFilter")
waiter = SessionWaiter(session_id, record_history_chains)
session_id = session_filter.filter(event)
FILTERS.append(session_filter)
waiter = SessionWaiter(session_filter, session_id, record_history_chains)
return await waiter.register_wait(func, timeout)
return wrapper
+7 -7
View File
@@ -6,6 +6,7 @@ from astrbot.api.star import Context, Star, register
from astrbot.core.utils.session_waiter import (
SessionWaiter,
USER_SESSIONS,
FILTERS,
session_waiter,
SessionController,
)
@@ -33,10 +34,11 @@ class Waiter(Star):
@filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize)
async def handle_session_control_agent(self, event: AstrMessageEvent):
"""会话控制代理"""
session_id = event.get_sender_id()
if session_id in USER_SESSIONS:
await SessionWaiter.trigger(session_id, event)
event.stop_event()
for session_filter in FILTERS:
session_id = session_filter.filter(event)
if session_id in USER_SESSIONS:
await SessionWaiter.trigger(session_id, event)
event.stop_event()
@filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize - 1)
async def handle_empty_mention(self, event: AstrMessageEvent):
@@ -70,9 +72,7 @@ class Waiter(Star):
controller.stop()
try:
await empty_mention_waiter(
event, session_id=event.get_sender_id()
)
await empty_mention_waiter(event)
except TimeoutError as _:
yield event.plain_result("如果需要帮助,请再次 @ 我哦~")
except Exception as e: