Files
AstrBot/astrbot/core/utils/session_waiter.py
T
Dt8333 f624971613 chore: fix bunches of type checking errors (#3213)
* chore(core.utils): 🚨 修正错误Lint

* chore(core.provider): 🚨 修复基类错误Lint

* chore(core.utils): 补全session_get()的重载

* chore(core.provider): 🚨 修正实现错误Lint

* chore(core.platform): 🚨 修正platform基类和webchat的错误Lint

* chore(core.platform): 修正错误实现Lint

* fix(core.provider): 修复循环调用和错误assert

* chore(core.platform): 修复部分实现Lint

* chore(core.provider): 补充Dify.text_chat_stream的参数类型

* chore(core.pipeline): 🚨 修复错误Lint

* fix(core.slack): 补充遗漏导入

* chore(core.utils): 修复错误的session_get声明

* chore(core.platform): 移除Lark adapter import中的wildcard

* chore(core.db): 修复声明和部分逻辑

* chore(core.db): 添加typings,使faiss参数能被正确识别。

* chore(core): 修复声明

* chore(core): 修改声明

* chore: 补充faiss声明

* chore(dashboard): 修改实现,减少报错

* chore(package): 修改部分声明与实现,减少报错

* chore(core): 添加Handler的overload,以去除部分assert同时通过类型检查

* chore(core.pipeline): 修改Pipeline Scheduler的execute,将判断属性改为判断类型,通过静态类型检查

* chore(core.config): 添加类型标注,通过类型检查

* chore(core.message): 为File._download_file添加检查,通过类型检查

* fix: 将断言改为条件判断以实现优雅关闭的容错性

* refactor: 移除 discord 客户端中的 assert,改用 if None 判断并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: DiscordPlatformAdapter 对 self.client.user 为 None 做日志并返回,移除断言

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 增强 Lark 相关空值/异常检查并完善日志输出

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将断言替换为条件检查并加入日志与错误处理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* chore: 移除LLM生成的无用注释

* refactor: 使用 File.get_file 替换下载逻辑并移除 assert,提供默认 filename

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: Slack Socket 未初始化抛出运行时异常,图片 URL 判空改为非空判断

* refactor: 将 WeChatPadProAdapter 的断言改为空值判断并添加日志

* refactor: 使用 isinstance 替代断言实现类型判断,便于静态检查

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 去除cast,直接使用字段与字典访问,修正端口解析

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 使用 match-case 重构 ProviderManager 加载并通过类型检查抛出 TypeError

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: group_name_display 时若 group 对象为空则记录错误并返回

* fix: 将 _get_current_persona_id 的 assert 替换成 if guard 并返回 None

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 优化插件目录存在性检查及图片URL非空验证,更新JSON排序配置

* fix: 将 datetime_str 的 assert 替换为显式检查并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改为运行时检查并在找不到调度器时跳过

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改用 isinstance 检查 FaissVecDB 并警告

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 删除 typing.cast 导入,并在获取文件绝对路径前校验 file_

* refactor: 移除 typing.cast,简化内容安全检查调用

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将 PlatformMetadata.id 设为必填并在注册时传入 id,移除 cast

* refactor: 移除 cast,改用 HasInitialize 与 isinstance 进行初始化

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 ProviderManager.initialize 增加ID类型判断,避免 None 导致 get 失败

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 为 OTTSProvider 与 AzureNativeProvider 引入 _client 与 client 属性改进上下文管理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 Whisper 自托管源添加模型未初始化校验并直接调用 transcribe

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除未使用的 cast 导入并简化 platform_name 赋值

* refactor: 引入 cast 并对 id 使用 cast(str, ...) 提升类型安全

* fix: 将 _id_to_sid 返回改为 str,空值返回空串;对 id 与 message_id 使用 cast

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 重构 Discord 处理逻辑:强制 类型转换、优先斜杠指令并优化提及判断

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 统一对 id 获取执行 cast,并在微信消息解析失败时抛错

* Revert "fix: 去除cast,直接使用字段与字典访问,修正端口解析"

This reverts commit 1cbfdf9d1b.

* fix: 百炼 Rerank 会话关闭时返回空结果;初始化 request.prompt 避免空值拼接

* fix: 统一处理搜索结果链接为字符串,新增 _get_url 助手并适配 Bing/Sogo

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 调整 call_handler 泛型、Discord 通道注解及 FishAudioTTS API 请求类型

* refactor: 使用 col(...) 替代列引用并对结果进行 CursorResult 强转

* chore: ruff format

---------

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
2025-12-09 14:13:47 +08:00

205 lines
6.9 KiB
Python

"""会话控制"""
import abc
import asyncio
import copy
import functools
import time
from collections.abc import Awaitable, Callable
from typing import Any
import astrbot.core.message.components as Comp
from astrbot.core.platform import AstrMessageEvent
USER_SESSIONS: dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例
FILTERS: list["SessionFilter"] = [] # 存储 SessionFilter 实例
class SessionController:
"""控制一个 Session 是否已经结束"""
def __init__(self):
self.future = asyncio.Future()
self.current_event: asyncio.Event | None = None
"""当前正在等待的所用的异步事件"""
self.ts: float | None = None
"""上次保持(keep)开始时的时间"""
self.timeout: float | int | None = None
"""上次保持(keep)开始时的超时时间"""
self.history_chains: list[list[Comp.BaseMessageComponent]] = []
def stop(self, error: Exception | None = None):
"""立即结束这个会话"""
if not self.future.done():
if error:
self.future.set_exception(error)
else:
self.future.set_result(None)
def keep(self, timeout: float = 0, reset_timeout=False):
"""保持这个会话
Args:
timeout (float): 必填。会话超时时间。
当 reset_timeout 设置为 True 时, 代表重置超时时间, timeout 必须 > 0, 如果 <= 0 则立即结束会话。
当 reset_timeout 设置为 False 时, 代表继续维持原来的超时时间, 新 timeout = 原来剩余的timeout + timeout (可以 < 0)
"""
new_ts = time.time()
if reset_timeout:
if timeout <= 0:
self.stop()
return
else:
assert self.timeout is not None
assert self.ts is not None
left_timeout = self.timeout - (new_ts - self.ts)
timeout = left_timeout + timeout
if timeout <= 0:
self.stop()
return
if self.current_event and not self.current_event.is_set():
self.current_event.set() # 通知上一个 keep 结束
new_event = asyncio.Event()
self.ts = new_ts
self.current_event = new_event
self.timeout = timeout
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
async def _holding(self, event: asyncio.Event, timeout: float):
"""等待事件结束或超时"""
try:
await asyncio.wait_for(event.wait(), timeout)
except asyncio.TimeoutError:
if not self.future.done():
self.future.set_exception(TimeoutError("等待超时"))
except asyncio.CancelledError:
pass # 避免报错
# finally:
def get_history_chains(self) -> list[list[Comp.BaseMessageComponent]]:
"""获取历史消息链"""
return self.history_chains
class SessionFilter:
"""如何界定一个会话"""
@abc.abstractmethod
def filter(self, event: AstrMessageEvent) -> str:
"""根据事件返回一个会话标识符"""
class DefaultSessionFilter(SessionFilter):
def filter(self, event: AstrMessageEvent) -> str:
"""默认实现,返回统一消息来源字符串作为会话标识符"""
return event.unified_msg_origin
class SessionWaiter:
def __init__(
self,
session_filter: SessionFilter,
session_id: str,
record_history_chains: bool,
):
self.session_id = session_id
self.session_filter = session_filter
self.handler: (
Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None
) = None # 处理函数
self.session_controller = SessionController()
self.record_history_chains = record_history_chains
"""是否记录历史消息链"""
self._lock = asyncio.Lock()
"""需要保证一个 session 同时只有一个 trigger"""
async def register_wait(
self,
handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
timeout: int = 30,
) -> Any:
"""等待外部输入并处理"""
self.handler = handler
USER_SESSIONS[self.session_id] = self
# 开始一个会话保持事件
self.session_controller.keep(timeout, reset_timeout=True)
try:
return await self.session_controller.future
except Exception as e:
self._cleanup(e)
raise e
finally:
self._cleanup()
def _cleanup(self, error: Exception | None = None):
"""清理会话"""
USER_SESSIONS.pop(self.session_id, None)
try:
FILTERS.remove(self.session_filter)
except ValueError:
pass
self.session_controller.stop(error)
@classmethod
async def trigger(cls, session_id: str, event: AstrMessageEvent):
"""外部输入触发会话处理"""
session = USER_SESSIONS.get(session_id)
if not session or session.session_controller.future.done():
return
async with session._lock:
if not session.session_controller.future.done():
if session.record_history_chains:
session.session_controller.history_chains.append(
[copy.deepcopy(comp) for comp in event.get_messages()],
)
try:
# TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行
assert session.handler is not None
await session.handler(session.session_controller, event)
except Exception as e:
session.session_controller.stop(e)
def session_waiter(timeout: int = 30, record_history_chains: bool = False):
"""装饰器:自动将函数注册为 SessionWaiter 处理函数,并等待外部输入触发执行。
:param timeout: 超时时间(秒)
:param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。
"""
def decorator(
func: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
):
@functools.wraps(func)
async def wrapper(
event: AstrMessageEvent,
session_filter: SessionFilter | None = None,
*args,
**kwargs,
):
if not session_filter:
session_filter = DefaultSessionFilter()
if not isinstance(session_filter, SessionFilter):
raise ValueError("session_filter 必须是 SessionFilter")
session_id = session_filter.filter(event)
FILTERS.append(session_filter)
waiter = SessionWaiter(session_filter, session_id, record_history_chains)
return await waiter.register_wait(func, timeout)
return wrapper
return decorator