f624971613
* chore(core.utils): 🚨 修正错误Lint
* chore(core.provider): 🚨 修复基类错误Lint
* chore(core.utils): 补全session_get()的重载
* chore(core.provider): 🚨 修正实现错误Lint
* chore(core.platform): 🚨 修正platform基类和webchat的错误Lint
* chore(core.platform): 修正错误实现Lint
* fix(core.provider): 修复循环调用和错误assert
* chore(core.platform): 修复部分实现Lint
* chore(core.provider): 补充Dify.text_chat_stream的参数类型
* chore(core.pipeline): 🚨 修复错误Lint
* fix(core.slack): 补充遗漏导入
* chore(core.utils): 修复错误的session_get声明
* chore(core.platform): 移除Lark adapter import中的wildcard
* chore(core.db): 修复声明和部分逻辑
* chore(core.db): 添加typings,使faiss参数能被正确识别。
* chore(core): 修复声明
* chore(core): 修改声明
* chore: 补充faiss声明
* chore(dashboard): 修改实现,减少报错
* chore(package): 修改部分声明与实现,减少报错
* chore(core): 添加Handler的overload,以去除部分assert同时通过类型检查
* chore(core.pipeline): 修改Pipeline Scheduler的execute,将判断属性改为判断类型,通过静态类型检查
* chore(core.config): 添加类型标注,通过类型检查
* chore(core.message): 为File._download_file添加检查,通过类型检查
* fix: 将断言改为条件判断以实现优雅关闭的容错性
* refactor: 移除 discord 客户端中的 assert,改用 if None 判断并抛出异常
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: DiscordPlatformAdapter 对 self.client.user 为 None 做日志并返回,移除断言
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 增强 Lark 相关空值/异常检查并完善日志输出
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 将断言替换为条件检查并加入日志与错误处理
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* chore: 移除LLM生成的无用注释
* refactor: 使用 File.get_file 替换下载逻辑并移除 assert,提供默认 filename
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: Slack Socket 未初始化抛出运行时异常,图片 URL 判空改为非空判断
* refactor: 将 WeChatPadProAdapter 的断言改为空值判断并添加日志
* refactor: 使用 isinstance 替代断言实现类型判断,便于静态检查
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 去除cast,直接使用字段与字典访问,修正端口解析
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 使用 match-case 重构 ProviderManager 加载并通过类型检查抛出 TypeError
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: group_name_display 时若 group 对象为空则记录错误并返回
* fix: 将 _get_current_persona_id 的 assert 替换成 if guard 并返回 None
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 优化插件目录存在性检查及图片URL非空验证,更新JSON排序配置
* fix: 将 datetime_str 的 assert 替换为显式检查并抛出异常
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 移除 cast,改为运行时检查并在找不到调度器时跳过
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 移除 cast,改用 isinstance 检查 FaissVecDB 并警告
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 删除 typing.cast 导入,并在获取文件绝对路径前校验 file_
* refactor: 移除 typing.cast,简化内容安全检查调用
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 将 PlatformMetadata.id 设为必填并在注册时传入 id,移除 cast
* refactor: 移除 cast,改用 HasInitialize 与 isinstance 进行初始化
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 为 ProviderManager.initialize 增加ID类型判断,避免 None 导致 get 失败
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 为 OTTSProvider 与 AzureNativeProvider 引入 _client 与 client 属性改进上下文管理
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 为 Whisper 自托管源添加模型未初始化校验并直接调用 transcribe
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 移除未使用的 cast 导入并简化 platform_name 赋值
* refactor: 引入 cast 并对 id 使用 cast(str, ...) 提升类型安全
* fix: 将 _id_to_sid 返回改为 str,空值返回空串;对 id 与 message_id 使用 cast
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 重构 Discord 处理逻辑:强制 类型转换、优先斜杠指令并优化提及判断
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 统一对 id 获取执行 cast,并在微信消息解析失败时抛错
* Revert "fix: 去除cast,直接使用字段与字典访问,修正端口解析"
This reverts commit 1cbfdf9d1b.
* fix: 百炼 Rerank 会话关闭时返回空结果;初始化 request.prompt 避免空值拼接
* fix: 统一处理搜索结果链接为字符串,新增 _get_url 助手并适配 Bing/Sogo
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 调整 call_handler 泛型、Discord 通道注解及 FishAudioTTS API 请求类型
* refactor: 使用 col(...) 替代列引用并对结果进行 CursorResult 强转
* chore: ruff format
---------
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
205 lines
6.9 KiB
Python
205 lines
6.9 KiB
Python
"""会话控制"""
|
|
|
|
import abc
|
|
import asyncio
|
|
import copy
|
|
import functools
|
|
import time
|
|
from collections.abc import Awaitable, Callable
|
|
from typing import Any
|
|
|
|
import astrbot.core.message.components as Comp
|
|
from astrbot.core.platform import AstrMessageEvent
|
|
|
|
USER_SESSIONS: dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例
|
|
FILTERS: list["SessionFilter"] = [] # 存储 SessionFilter 实例
|
|
|
|
|
|
class SessionController:
|
|
"""控制一个 Session 是否已经结束"""
|
|
|
|
def __init__(self):
|
|
self.future = asyncio.Future()
|
|
self.current_event: asyncio.Event | None = None
|
|
"""当前正在等待的所用的异步事件"""
|
|
self.ts: float | None = None
|
|
"""上次保持(keep)开始时的时间"""
|
|
self.timeout: float | int | None = None
|
|
"""上次保持(keep)开始时的超时时间"""
|
|
|
|
self.history_chains: list[list[Comp.BaseMessageComponent]] = []
|
|
|
|
def stop(self, error: Exception | None = None):
|
|
"""立即结束这个会话"""
|
|
if not self.future.done():
|
|
if error:
|
|
self.future.set_exception(error)
|
|
else:
|
|
self.future.set_result(None)
|
|
|
|
def keep(self, timeout: float = 0, reset_timeout=False):
|
|
"""保持这个会话
|
|
|
|
Args:
|
|
timeout (float): 必填。会话超时时间。
|
|
当 reset_timeout 设置为 True 时, 代表重置超时时间, timeout 必须 > 0, 如果 <= 0 则立即结束会话。
|
|
当 reset_timeout 设置为 False 时, 代表继续维持原来的超时时间, 新 timeout = 原来剩余的timeout + timeout (可以 < 0)
|
|
|
|
"""
|
|
new_ts = time.time()
|
|
|
|
if reset_timeout:
|
|
if timeout <= 0:
|
|
self.stop()
|
|
return
|
|
else:
|
|
assert self.timeout is not None
|
|
assert self.ts is not None
|
|
left_timeout = self.timeout - (new_ts - self.ts)
|
|
timeout = left_timeout + timeout
|
|
if timeout <= 0:
|
|
self.stop()
|
|
return
|
|
|
|
if self.current_event and not self.current_event.is_set():
|
|
self.current_event.set() # 通知上一个 keep 结束
|
|
|
|
new_event = asyncio.Event()
|
|
self.ts = new_ts
|
|
self.current_event = new_event
|
|
self.timeout = timeout
|
|
|
|
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
|
|
|
|
async def _holding(self, event: asyncio.Event, timeout: float):
|
|
"""等待事件结束或超时"""
|
|
try:
|
|
await asyncio.wait_for(event.wait(), timeout)
|
|
except asyncio.TimeoutError:
|
|
if not self.future.done():
|
|
self.future.set_exception(TimeoutError("等待超时"))
|
|
except asyncio.CancelledError:
|
|
pass # 避免报错
|
|
# finally:
|
|
|
|
def get_history_chains(self) -> list[list[Comp.BaseMessageComponent]]:
|
|
"""获取历史消息链"""
|
|
return self.history_chains
|
|
|
|
|
|
class SessionFilter:
|
|
"""如何界定一个会话"""
|
|
|
|
@abc.abstractmethod
|
|
def filter(self, event: AstrMessageEvent) -> str:
|
|
"""根据事件返回一个会话标识符"""
|
|
|
|
|
|
class DefaultSessionFilter(SessionFilter):
|
|
def filter(self, event: AstrMessageEvent) -> str:
|
|
"""默认实现,返回统一消息来源字符串作为会话标识符"""
|
|
return event.unified_msg_origin
|
|
|
|
|
|
class SessionWaiter:
|
|
def __init__(
|
|
self,
|
|
session_filter: SessionFilter,
|
|
session_id: str,
|
|
record_history_chains: bool,
|
|
):
|
|
self.session_id = session_id
|
|
self.session_filter = session_filter
|
|
self.handler: (
|
|
Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None
|
|
) = None # 处理函数
|
|
|
|
self.session_controller = SessionController()
|
|
self.record_history_chains = record_history_chains
|
|
"""是否记录历史消息链"""
|
|
|
|
self._lock = asyncio.Lock()
|
|
"""需要保证一个 session 同时只有一个 trigger"""
|
|
|
|
async def register_wait(
|
|
self,
|
|
handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
|
|
timeout: int = 30,
|
|
) -> Any:
|
|
"""等待外部输入并处理"""
|
|
self.handler = handler
|
|
USER_SESSIONS[self.session_id] = self
|
|
|
|
# 开始一个会话保持事件
|
|
self.session_controller.keep(timeout, reset_timeout=True)
|
|
|
|
try:
|
|
return await self.session_controller.future
|
|
except Exception as e:
|
|
self._cleanup(e)
|
|
raise e
|
|
finally:
|
|
self._cleanup()
|
|
|
|
def _cleanup(self, error: Exception | None = None):
|
|
"""清理会话"""
|
|
USER_SESSIONS.pop(self.session_id, None)
|
|
try:
|
|
FILTERS.remove(self.session_filter)
|
|
except ValueError:
|
|
pass
|
|
self.session_controller.stop(error)
|
|
|
|
@classmethod
|
|
async def trigger(cls, session_id: str, event: AstrMessageEvent):
|
|
"""外部输入触发会话处理"""
|
|
session = USER_SESSIONS.get(session_id)
|
|
if not session or session.session_controller.future.done():
|
|
return
|
|
|
|
async with session._lock:
|
|
if not session.session_controller.future.done():
|
|
if session.record_history_chains:
|
|
session.session_controller.history_chains.append(
|
|
[copy.deepcopy(comp) for comp in event.get_messages()],
|
|
)
|
|
try:
|
|
# TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行
|
|
assert session.handler is not None
|
|
await session.handler(session.session_controller, event)
|
|
except Exception as e:
|
|
session.session_controller.stop(e)
|
|
|
|
|
|
def session_waiter(timeout: int = 30, record_history_chains: bool = False):
|
|
"""装饰器:自动将函数注册为 SessionWaiter 处理函数,并等待外部输入触发执行。
|
|
|
|
:param timeout: 超时时间(秒)
|
|
:param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。
|
|
"""
|
|
|
|
def decorator(
|
|
func: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
|
|
):
|
|
@functools.wraps(func)
|
|
async def wrapper(
|
|
event: AstrMessageEvent,
|
|
session_filter: SessionFilter | None = None,
|
|
*args,
|
|
**kwargs,
|
|
):
|
|
if not session_filter:
|
|
session_filter = DefaultSessionFilter()
|
|
if not isinstance(session_filter, SessionFilter):
|
|
raise ValueError("session_filter 必须是 SessionFilter")
|
|
|
|
session_id = session_filter.filter(event)
|
|
FILTERS.append(session_filter)
|
|
|
|
waiter = SessionWaiter(session_filter, session_id, record_history_chains)
|
|
return await waiter.register_wait(func, timeout)
|
|
|
|
return wrapper
|
|
|
|
return decorator
|