fix: 修复 asyncio 事件循环相关问题 (#5774)
* fix: 修复 asyncio 事件循环相关的问题 1. components.py: 修复异常处理结构错误 - 将 except Exception 移到正确的内部 try 块 - 确保 _download_file() 异常能被正确捕获和记录 2. session_lock.py: 修复跨事件循环 Lock 绑定问题 - 添加 _access_lock_loop_id 追踪事件循环 - 当事件循环变化时重新创建 Lock Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: 根据代码审查反馈修复问题 1. components.py: 移除 asyncio.set_event_loop() 调用 - 创建临时 event loop 时不再设置为全局 - 避免干扰其他 asyncio 使用 2. session_lock.py: 简化延迟初始化逻辑 - 移除 loop-ID 追踪和 _get_lock 方法 - 使用 setdefault 简化 session lock 创建 - 保留延迟初始化行为 3. wecomai_queue_mgr.py: 使用 time.monotonic() 替代 loop.time() - 同步方法不再依赖活动的 event loop - 避免在非异步上下文中抛出 RuntimeError Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: 优化 asyncio 事件循环管理,使用安全的方式创建和关闭事件循环 * fix: 根据代码审查反馈改进异常处理和事件循环使用 - main.py: 显式处理 check_dashboard_files() 返回 None 的情况 - components.py: 使用 logger.exception 保留异常堆栈信息 - star_manager.py: 添加 Future 异常回调处理 __del__ 执行异常 - bay_manager.py: 缓存事件循环引用避免重复调用 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refactor: 简化 SessionLockManager 使用 defaultdict 和 setdefault - 使用 defaultdict(asyncio.Lock) 简化锁的懒创建 - 使用 setdefault 简化 _get_loop_state 逻辑 - 减少 get + if 分支,提升可读性 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: 降低 webui_dir 检查失败时的日志级别为 warning 改为警告而非退出,允许程序在无 WebUI 的情况下继续运行 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refactor: 重构事件循环锁管理,简化锁状态管理逻辑 * 新增对 SessionLockManager 的多事件循环隔离测试 * fix: 修复测试中的变量声明和断言,确保事件循环管理器的正确性 * fix: 修复插件删除时异常处理逻辑,确保正确记录错误信息 * fix: 新增针对多个事件循环的 OneBot 实例的测试,确保锁对象在不同事件循环间不共享 --------- Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -302,7 +302,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
|
||||
|
||||
while True:
|
||||
try:
|
||||
item_type, item_data = await asyncio.get_event_loop().run_in_executor(
|
||||
item_type, item_data = await asyncio.get_running_loop().run_in_executor(
|
||||
None, response_queue.get, True, 1
|
||||
)
|
||||
except queue.Empty:
|
||||
@@ -388,7 +388,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
|
||||
|
||||
# 发起请求
|
||||
partial = functools.partial(Application.call, **payload)
|
||||
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||
response = await asyncio.get_running_loop().run_in_executor(None, partial)
|
||||
|
||||
async for resp in self._handle_streaming_response(response, session_id):
|
||||
yield resp
|
||||
|
||||
@@ -121,11 +121,12 @@ class BayContainerManager:
|
||||
async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None:
|
||||
"""Block until Bay's ``/health`` endpoint returns 200."""
|
||||
url = f"http://127.0.0.1:{self._host_port}/health"
|
||||
deadline = asyncio.get_event_loop().time() + timeout
|
||||
loop = asyncio.get_running_loop()
|
||||
deadline = loop.time() + timeout
|
||||
last_error: str = ""
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
while loop.time() < deadline:
|
||||
try:
|
||||
async with session.get(
|
||||
url, timeout=aiohttp.ClientTimeout(total=3)
|
||||
|
||||
@@ -699,21 +699,24 @@ class File(BaseMessageComponent):
|
||||
|
||||
if self.url:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
logger.warning(
|
||||
"不可以在异步上下文中同步等待下载! "
|
||||
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
|
||||
"请使用 await get_file() 代替直接获取 <File>.file 字段",
|
||||
)
|
||||
return ""
|
||||
# 等待下载完成
|
||||
loop.run_until_complete(self._download_file())
|
||||
# 检查是否有正在运行的 event loop
|
||||
asyncio.get_running_loop()
|
||||
logger.warning(
|
||||
"不可以在异步上下文中同步等待下载! "
|
||||
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
|
||||
"请使用 await get_file() 代替直接获取 <File>.file 字段",
|
||||
)
|
||||
return ""
|
||||
except RuntimeError:
|
||||
# 没有运行中的 event loop,可以同步执行
|
||||
try:
|
||||
# 使用 asyncio.run 安全地创建和关闭事件循环
|
||||
asyncio.run(self._download_file())
|
||||
except Exception:
|
||||
logger.exception("文件下载失败")
|
||||
|
||||
if self.file_ and os.path.exists(self.file_):
|
||||
return os.path.abspath(self.file_)
|
||||
except Exception as e:
|
||||
logger.error(f"文件下载失败: {e}")
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
@@ -367,7 +367,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
|
||||
async def get_access_token(self) -> str:
|
||||
try:
|
||||
access_token = await asyncio.get_event_loop().run_in_executor(
|
||||
access_token = await asyncio.get_running_loop().run_in_executor(
|
||||
None,
|
||||
self.client_.get_access_token,
|
||||
)
|
||||
@@ -760,7 +760,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
return
|
||||
logger.error(f"钉钉机器人启动失败: {e}")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, start_client, loop)
|
||||
|
||||
async def terminate(self) -> None:
|
||||
|
||||
@@ -80,7 +80,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
|
||||
if isinstance(source, botpy.message.C2CMessage):
|
||||
# 真流式传输
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
current_time = asyncio.get_running_loop().time()
|
||||
time_since_last_edit = current_time - last_edit_time
|
||||
|
||||
if time_since_last_edit >= throttle_interval:
|
||||
@@ -90,7 +90,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
stream_payload["index"] += 1
|
||||
stream_payload["id"] = ret["id"]
|
||||
last_edit_time = asyncio.get_event_loop().time()
|
||||
last_edit_time = asyncio.get_running_loop().time()
|
||||
|
||||
if isinstance(source, botpy.message.C2CMessage):
|
||||
# 结束流式对话,并且传输 buffer 中剩余的消息
|
||||
|
||||
@@ -55,7 +55,7 @@ class QQOfficialWebhook:
|
||||
max_async=1,
|
||||
connect=bot_connect,
|
||||
dispatch=self.client.ws_dispatch,
|
||||
loop=asyncio.get_event_loop(),
|
||||
loop=asyncio.get_running_loop(),
|
||||
api=self.api,
|
||||
)
|
||||
|
||||
|
||||
@@ -626,7 +626,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
|
||||
# 发送初始 typing 状态
|
||||
await self._ensure_typing(user_name, message_thread_id)
|
||||
last_chat_action_time = asyncio.get_event_loop().time()
|
||||
last_chat_action_time = asyncio.get_running_loop().time()
|
||||
|
||||
def _append_text(t: str) -> None:
|
||||
nonlocal delta
|
||||
@@ -657,11 +657,11 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
|
||||
# 编辑或发送消息
|
||||
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
current_time = asyncio.get_running_loop().time()
|
||||
time_since_last_edit = current_time - last_edit_time
|
||||
|
||||
if time_since_last_edit >= throttle_interval:
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
current_time = asyncio.get_running_loop().time()
|
||||
if current_time - last_chat_action_time >= chat_action_interval:
|
||||
await self._ensure_typing(user_name, message_thread_id)
|
||||
last_chat_action_time = current_time
|
||||
@@ -674,9 +674,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
current_content = delta
|
||||
except Exception as e:
|
||||
logger.warning(f"编辑消息失败(streaming): {e!s}")
|
||||
last_edit_time = asyncio.get_event_loop().time()
|
||||
last_edit_time = asyncio.get_running_loop().time()
|
||||
else:
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
current_time = asyncio.get_running_loop().time()
|
||||
if current_time - last_chat_action_time >= chat_action_interval:
|
||||
await self._ensure_typing(user_name, message_thread_id)
|
||||
last_chat_action_time = current_time
|
||||
@@ -688,7 +688,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
except Exception as e:
|
||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||
message_id = msg.message_id
|
||||
last_edit_time = asyncio.get_event_loop().time()
|
||||
last_edit_time = asyncio.get_running_loop().time()
|
||||
|
||||
try:
|
||||
if delta and current_content != delta:
|
||||
|
||||
@@ -200,7 +200,7 @@ class WecomPlatformAdapter(Platform):
|
||||
return msg_list[-1]
|
||||
return None
|
||||
|
||||
msg_new = await asyncio.get_event_loop().run_in_executor(
|
||||
msg_new = await asyncio.get_running_loop().run_in_executor(
|
||||
None,
|
||||
get_latest_msg_item,
|
||||
)
|
||||
@@ -261,7 +261,7 @@ class WecomPlatformAdapter(Platform):
|
||||
|
||||
@override
|
||||
async def run(self) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
if self.kf_name:
|
||||
try:
|
||||
acc_list = (
|
||||
@@ -339,7 +339,7 @@ class WecomPlatformAdapter(Platform):
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
elif isinstance(msg, VoiceMessage):
|
||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||
resp: Response = await asyncio.get_running_loop().run_in_executor(
|
||||
None,
|
||||
self.client.media.download,
|
||||
msg.media_id,
|
||||
@@ -395,7 +395,7 @@ class WecomPlatformAdapter(Platform):
|
||||
abm.message_str = text
|
||||
elif msgtype == "image":
|
||||
media_id = msg.get("image", {}).get("media_id", "")
|
||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||
resp: Response = await asyncio.get_running_loop().run_in_executor(
|
||||
None,
|
||||
self.client.media.download,
|
||||
media_id,
|
||||
@@ -407,7 +407,7 @@ class WecomPlatformAdapter(Platform):
|
||||
abm.message = [Image(file=path, url=path)]
|
||||
elif msgtype == "voice":
|
||||
media_id = msg.get("voice", {}).get("media_id", "")
|
||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||
resp: Response = await asyncio.get_running_loop().run_in_executor(
|
||||
None,
|
||||
self.client.media.download,
|
||||
media_id,
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
@@ -82,7 +83,7 @@ class WecomAIQueueMgr:
|
||||
del self.pending_responses[session_id]
|
||||
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
|
||||
if mark_finished:
|
||||
self.completed_streams[session_id] = asyncio.get_event_loop().time()
|
||||
self.completed_streams[session_id] = time.monotonic()
|
||||
logger.debug(f"[WecomAI] 标记流已结束: {session_id}")
|
||||
|
||||
def remove_queue(self, session_id: str):
|
||||
@@ -135,7 +136,7 @@ class WecomAIQueueMgr:
|
||||
"""
|
||||
self.pending_responses[session_id] = {
|
||||
"callback_params": callback_params,
|
||||
"timestamp": asyncio.get_event_loop().time(),
|
||||
"timestamp": time.monotonic(),
|
||||
}
|
||||
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
|
||||
|
||||
@@ -160,7 +161,7 @@ class WecomAIQueueMgr:
|
||||
finished_at = self.completed_streams.get(session_id)
|
||||
if finished_at is None:
|
||||
return False
|
||||
if asyncio.get_event_loop().time() - finished_at > max_age_seconds:
|
||||
if time.monotonic() - finished_at > max_age_seconds:
|
||||
self.completed_streams.pop(session_id, None)
|
||||
return False
|
||||
return True
|
||||
@@ -172,7 +173,7 @@ class WecomAIQueueMgr:
|
||||
max_age_seconds: 最大存活时间(秒)
|
||||
|
||||
"""
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
current_time = time.monotonic()
|
||||
expired_sessions = []
|
||||
|
||||
for session_id, response_data in self.pending_responses.items():
|
||||
|
||||
@@ -369,7 +369,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
if future:
|
||||
logger.debug(f"duplicate message id checked: {msg.id}")
|
||||
else:
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
self.wexin_event_workers[msg_id] = future
|
||||
await self.convert_message(msg, future)
|
||||
# I love shield so much!
|
||||
@@ -461,7 +461,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
elif msg.type == "voice":
|
||||
assert isinstance(msg, VoiceMessage)
|
||||
|
||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||
resp: Response = await asyncio.get_running_loop().run_in_executor(
|
||||
None,
|
||||
self.client.media.download,
|
||||
msg.media_id,
|
||||
|
||||
@@ -87,7 +87,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
model: str,
|
||||
text: str,
|
||||
) -> tuple[bytes | None, str]:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
|
||||
audio_bytes = await self._extract_audio_from_response(response)
|
||||
if not audio_bytes:
|
||||
@@ -143,7 +143,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
voice=self.voice,
|
||||
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
audio_bytes = await loop.run_in_executor(
|
||||
None,
|
||||
synthesizer.call,
|
||||
|
||||
@@ -59,7 +59,7 @@ class GenieTTSProvider(TTSProvider):
|
||||
filename = f"genie_tts_{uuid.uuid4()}.wav"
|
||||
path = os.path.join(temp_dir, filename)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
def _generate(save_path: str) -> None:
|
||||
assert genie is not None
|
||||
@@ -85,7 +85,7 @@ class GenieTTSProvider(TTSProvider):
|
||||
text_queue: asyncio.Queue[str | None],
|
||||
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
|
||||
) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
while True:
|
||||
text = await text_queue.get()
|
||||
|
||||
@@ -43,7 +43,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
||||
logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...")
|
||||
|
||||
# 将模型加载放到线程池中执行
|
||||
self.model = await asyncio.get_event_loop().run_in_executor(
|
||||
self.model = await asyncio.get_running_loop().run_in_executor(
|
||||
None,
|
||||
lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16),
|
||||
)
|
||||
@@ -88,7 +88,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
||||
audio_url = output_path
|
||||
|
||||
# 使用 run_in_executor 来调用模型进行识别
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
res = await loop.run_in_executor(
|
||||
None, # 使用默认的线程池
|
||||
lambda: cast(SenseVoiceSmall, self.model)(
|
||||
|
||||
@@ -31,7 +31,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
self.model = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
|
||||
self.model = await loop.run_in_executor(
|
||||
None,
|
||||
@@ -50,7 +50,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
return False
|
||||
|
||||
async def get_text(self, audio_url: str) -> str:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
is_tencent = False
|
||||
|
||||
|
||||
@@ -1374,10 +1374,23 @@ class PluginManager:
|
||||
return
|
||||
|
||||
if "__del__" in star_metadata.star_cls_type.__dict__:
|
||||
asyncio.get_event_loop().run_in_executor(
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.run_in_executor(
|
||||
None,
|
||||
star_metadata.star_cls.__del__,
|
||||
)
|
||||
|
||||
def _log_del_exception(fut: asyncio.Future) -> None:
|
||||
if fut.cancelled():
|
||||
return
|
||||
if (exc := fut.exception()) is not None:
|
||||
logger.error(
|
||||
"插件 %s 在 __del__ 中抛出了异常:%r",
|
||||
star_metadata.name,
|
||||
exc,
|
||||
)
|
||||
|
||||
future.add_done_callback(_log_del_exception)
|
||||
elif "terminate" in star_metadata.star_cls_type.__dict__:
|
||||
await star_metadata.star_cls.terminate()
|
||||
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import weakref
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
||||
class SessionLockManager:
|
||||
class _PerLoopSessionLockManager:
|
||||
"""Per-event-loop session lock manager; keeps original simple semantics."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
self._lock_count: dict[str, int] = defaultdict(int)
|
||||
@@ -26,4 +30,26 @@ class SessionLockManager:
|
||||
self._lock_count.pop(session_id, None)
|
||||
|
||||
|
||||
class SessionLockManager:
|
||||
"""Thread-safe session lock manager with per-event-loop isolation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._state_guard = threading.Lock()
|
||||
self._loop_managers: weakref.WeakKeyDictionary[
|
||||
asyncio.AbstractEventLoop, _PerLoopSessionLockManager
|
||||
] = weakref.WeakKeyDictionary()
|
||||
|
||||
def _get_loop_manager(self) -> _PerLoopSessionLockManager:
|
||||
"""Get the lock manager for the current event loop."""
|
||||
loop = asyncio.get_running_loop()
|
||||
with self._state_guard:
|
||||
return self._loop_managers.setdefault(loop, _PerLoopSessionLockManager())
|
||||
|
||||
@asynccontextmanager
|
||||
async def acquire_lock(self, session_id: str):
|
||||
manager = self._get_loop_manager()
|
||||
async with manager.acquire_lock(session_id):
|
||||
yield
|
||||
|
||||
|
||||
session_lock_manager = SessionLockManager()
|
||||
|
||||
@@ -101,6 +101,26 @@ async def check_dashboard_files(webui_dir: str | None = None):
|
||||
return data_dist_path
|
||||
|
||||
|
||||
async def main_async(webui_dir_arg: str | None) -> None:
|
||||
"""主异步入口"""
|
||||
# 检查仪表板文件
|
||||
webui_dir = await check_dashboard_files(webui_dir_arg)
|
||||
if webui_dir is None:
|
||||
logger.warning(
|
||||
"管理面板文件检查失败,WebUI 功能将不可用。"
|
||||
"请检查网络连接或手动指定 --webui-dir 参数。"
|
||||
)
|
||||
|
||||
db = db_helper
|
||||
|
||||
# 打印 logo
|
||||
logger.info(logo_tmpl)
|
||||
|
||||
core_lifecycle = InitialLoader(db, log_broker)
|
||||
core_lifecycle.webui_dir = webui_dir
|
||||
await core_lifecycle.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="AstrBot")
|
||||
parser.add_argument(
|
||||
@@ -117,14 +137,5 @@ if __name__ == "__main__":
|
||||
log_broker = LogBroker()
|
||||
LogManager.set_queue_handler(logger, log_broker)
|
||||
|
||||
# 检查仪表板文件
|
||||
webui_dir = asyncio.run(check_dashboard_files(args.webui_dir))
|
||||
|
||||
db = db_helper
|
||||
|
||||
# 打印 logo
|
||||
logger.info(logo_tmpl)
|
||||
|
||||
core_lifecycle = InitialLoader(db, log_broker)
|
||||
core_lifecycle.webui_dir = webui_dir
|
||||
asyncio.run(core_lifecycle.start())
|
||||
# 只使用一次 asyncio.run()
|
||||
asyncio.run(main_async(args.webui_dir))
|
||||
|
||||
@@ -0,0 +1,545 @@
|
||||
"""Tests for SessionLockManager with multi-event-loop isolation."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.utils.session_lock import SessionLockManager
|
||||
|
||||
|
||||
class TestSessionLockManagerBasic:
|
||||
"""Basic functionality tests."""
|
||||
|
||||
def test_init(self):
|
||||
"""Test manager initialization."""
|
||||
manager = SessionLockManager()
|
||||
assert manager._state_guard is not None
|
||||
assert manager._loop_managers is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_release_lock(self):
|
||||
"""Test basic lock acquire and release."""
|
||||
manager = SessionLockManager()
|
||||
session_id = "test-session"
|
||||
|
||||
async with manager.acquire_lock(session_id):
|
||||
# Lock acquired successfully
|
||||
pass
|
||||
|
||||
# Lock should be released and cleaned up
|
||||
state = manager._get_loop_manager()
|
||||
assert session_id not in state._locks
|
||||
assert session_id not in state._lock_count
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_is_reusable(self):
|
||||
"""Test that locks can be acquired multiple times."""
|
||||
manager = SessionLockManager()
|
||||
session_id = "test-session"
|
||||
|
||||
async with manager.acquire_lock(session_id):
|
||||
pass
|
||||
|
||||
async with manager.acquire_lock(session_id):
|
||||
pass
|
||||
|
||||
# Both acquisitions should succeed
|
||||
|
||||
|
||||
class TestCrossLoopIsolation:
|
||||
"""Tests for event loop isolation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_loops_have_different_managers(self):
|
||||
"""Test that different event loops get different per-loop managers."""
|
||||
manager = SessionLockManager()
|
||||
|
||||
# Get manager for current loop
|
||||
manager1 = manager._get_loop_manager()
|
||||
|
||||
# Run in a different event loop
|
||||
def run_in_new_loop():
|
||||
new_loop = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(new_loop)
|
||||
|
||||
async def get_manager():
|
||||
return manager._get_loop_manager()
|
||||
|
||||
return new_loop.run_until_complete(get_manager())
|
||||
finally:
|
||||
new_loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
manager2 = future.result()
|
||||
|
||||
# Should be different manager instances
|
||||
assert manager1 is not manager2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_locks_isolated_across_loops(self):
|
||||
"""Test that locks from different loops are isolated."""
|
||||
manager = SessionLockManager()
|
||||
session_id = "shared-session"
|
||||
results = []
|
||||
|
||||
async def acquire_in_loop(loop_id: int):
|
||||
"""Acquire lock in a new event loop."""
|
||||
async with manager.acquire_lock(session_id):
|
||||
results.append(f"loop-{loop_id}-acquired")
|
||||
await asyncio.sleep(0.05)
|
||||
results.append(f"loop-{loop_id}-released")
|
||||
|
||||
def run_in_thread(loop_id: int):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(acquire_in_loop(loop_id))
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
# Run two loops concurrently - they should NOT block each other
|
||||
# because locks are isolated per-loop
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
futures = [executor.submit(run_in_thread, i) for i in range(2)]
|
||||
for f in futures:
|
||||
f.result()
|
||||
|
||||
# Both loops should acquire immediately (no blocking between loops)
|
||||
# Order should show interleaved acquisitions, not sequential
|
||||
assert len(results) == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_loop_blocks_on_same_session(self):
|
||||
"""Test that same loop blocks when acquiring same session lock."""
|
||||
manager = SessionLockManager()
|
||||
session_id = "test-session"
|
||||
execution_order = []
|
||||
|
||||
async def task1():
|
||||
async with manager.acquire_lock(session_id):
|
||||
execution_order.append("task1-start")
|
||||
await asyncio.sleep(0.1)
|
||||
execution_order.append("task1-end")
|
||||
|
||||
async def task2():
|
||||
await asyncio.sleep(0.01) # Let task1 start first
|
||||
async with manager.acquire_lock(session_id):
|
||||
execution_order.append("task2-start")
|
||||
execution_order.append("task2-end")
|
||||
|
||||
await asyncio.gather(task1(), task2())
|
||||
|
||||
# task2 should wait for task1 to finish
|
||||
assert execution_order.index("task1-start") < execution_order.index("task1-end")
|
||||
assert execution_order.index("task1-end") < execution_order.index("task2-start")
|
||||
|
||||
|
||||
class TestConcurrency:
|
||||
"""Tests for concurrent access."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_acquisitions_same_loop(self):
|
||||
"""Test concurrent lock acquisitions on the same loop."""
|
||||
manager = SessionLockManager()
|
||||
session_id = "concurrent-session"
|
||||
acquired_count = 0
|
||||
max_concurrent = 0
|
||||
lock = asyncio.Lock()
|
||||
|
||||
async def acquire_and_check():
|
||||
nonlocal acquired_count, max_concurrent
|
||||
async with manager.acquire_lock(session_id):
|
||||
async with lock:
|
||||
acquired_count += 1
|
||||
max_concurrent = max(max_concurrent, acquired_count)
|
||||
await asyncio.sleep(0.01)
|
||||
async with lock:
|
||||
acquired_count -= 1
|
||||
|
||||
# Run multiple concurrent tasks
|
||||
tasks = [acquire_and_check() for _ in range(5)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Max concurrent should be 1 (lock serializes access)
|
||||
assert max_concurrent == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_thread_safety_of_loop_manager_creation(self):
|
||||
"""Test that _get_loop_manager is thread-safe."""
|
||||
manager = SessionLockManager()
|
||||
managers = []
|
||||
errors = []
|
||||
|
||||
def create_loop_and_get_manager():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
|
||||
async def get_mgr():
|
||||
return manager._get_loop_manager()
|
||||
|
||||
mgr = loop.run_until_complete(get_mgr())
|
||||
managers.append(mgr)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
threads = [threading.Thread(target=create_loop_and_get_manager) for _ in range(10)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert len(errors) == 0
|
||||
# All managers should be valid
|
||||
for m in managers:
|
||||
assert hasattr(m, "_locks")
|
||||
assert hasattr(m, "_access_lock")
|
||||
|
||||
|
||||
class TestEventLoopCleanup:
|
||||
"""Tests for event loop cleanup behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_weakref_cleanup_on_loop_close(self):
|
||||
"""Test that per-loop managers are cleaned up when loop is closed."""
|
||||
manager = SessionLockManager()
|
||||
loop_ref: weakref.ref[asyncio.AbstractEventLoop] | None = None
|
||||
|
||||
def run_in_new_loop():
|
||||
nonlocal loop_ref
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop_ref = weakref.ref(loop)
|
||||
|
||||
async def use_lock():
|
||||
async with manager.acquire_lock("test-session"):
|
||||
pass
|
||||
return manager._get_loop_manager()
|
||||
|
||||
try:
|
||||
per_loop_mgr = loop.run_until_complete(use_lock())
|
||||
# Keep a weak ref to the per-loop manager
|
||||
return weakref.ref(per_loop_mgr)
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
per_loop_mgr_ref = future.result()
|
||||
|
||||
# Give time for weakref cleanup
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
# The per-loop manager should be cleaned up when the loop is closed
|
||||
# because WeakKeyDictionary removes entries when the key (loop) is gone
|
||||
per_loop_mgr = per_loop_mgr_ref()
|
||||
loop = loop_ref() if loop_ref is not None else None
|
||||
assert per_loop_mgr is None or loop is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_access_after_loop_close_in_new_loop_works(self):
|
||||
"""Test that accessing from a new loop after old loop closes works."""
|
||||
manager = SessionLockManager()
|
||||
|
||||
# Use lock in current loop
|
||||
async with manager.acquire_lock("session-1"):
|
||||
pass
|
||||
|
||||
# Simulate old loop being closed and new loop being created
|
||||
def run_in_new_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
|
||||
async def use_lock():
|
||||
# Should work without issues in new loop
|
||||
async with manager.acquire_lock("session-2"):
|
||||
return "success"
|
||||
|
||||
return loop.run_until_complete(use_lock())
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
result = future.result()
|
||||
|
||||
assert result == "success"
|
||||
|
||||
|
||||
class TestIssue5464:
|
||||
"""Tests for issue #5464: Multiple OneBot instances with different event loops.
|
||||
|
||||
Issue: Running multiple OneBot adapter instances causes
|
||||
"is bound to a different event loop" error.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_event_loops_no_cross_loop_error(self):
|
||||
"""Test that multiple event loops don't cause cross-loop binding errors.
|
||||
|
||||
This simulates the scenario where multiple OneBot instances
|
||||
(each potentially running in different event loops) access the
|
||||
same SessionLockManager concurrently.
|
||||
"""
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
errors: list[Exception] = []
|
||||
results: list[str] = []
|
||||
|
||||
def simulate_onebot_instance(instance_id: int, session_ids: list[str]):
|
||||
"""Simulate a OneBot instance running in its own event loop."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
|
||||
async def process_messages():
|
||||
for session_id in session_ids:
|
||||
try:
|
||||
async with session_lock_manager.acquire_lock(session_id):
|
||||
# Simulate message processing
|
||||
await asyncio.sleep(0.01)
|
||||
results.append(f"instance-{instance_id}-{session_id}")
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
loop.run_until_complete(process_messages())
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
# Simulate 4 OneBot instances (as in the issue report)
|
||||
# Each handles multiple sessions concurrently
|
||||
threads = []
|
||||
for i in range(4):
|
||||
sessions = [f"session-{i}-1", f"session-{i}-2", f"session-{i}-3"]
|
||||
t = threading.Thread(target=simulate_onebot_instance, args=(i, sessions))
|
||||
threads.append(t)
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Should have no errors (especially no "bound to a different event loop")
|
||||
assert len(errors) == 0, f"Errors occurred: {errors}"
|
||||
assert len(results) == 12 # 4 instances * 3 sessions each
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_object_not_shared_across_loops(self):
|
||||
"""Verify that asyncio.Lock objects are not shared across event loops.
|
||||
|
||||
The root cause of issue #5464 was that Lock objects created in one
|
||||
event loop were being used in another, causing the error.
|
||||
"""
|
||||
manager = SessionLockManager()
|
||||
session_id = "shared-session-id"
|
||||
lock_ids: set[int] = set()
|
||||
lock_id_lock = threading.Lock()
|
||||
|
||||
def get_lock_in_new_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
|
||||
async def acquire_and_capture():
|
||||
# Get the per-loop manager
|
||||
per_loop_mgr = manager._get_loop_manager()
|
||||
# Capture the lock object id before acquiring
|
||||
async with per_loop_mgr._access_lock:
|
||||
lock = per_loop_mgr._locks[session_id]
|
||||
with lock_id_lock:
|
||||
lock_ids.add(id(lock))
|
||||
async with manager.acquire_lock(session_id):
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
loop.run_until_complete(acquire_and_capture())
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
# Run multiple loops concurrently
|
||||
threads = [threading.Thread(target=get_lock_in_new_loop) for _ in range(5)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Each loop should have its own Lock object
|
||||
# If locks were shared, we'd only have 1 lock_id
|
||||
assert len(lock_ids) == 5, "Each event loop should have its own Lock object"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_access_same_session_different_loops(self):
|
||||
"""Test that same session ID accessed from different loops doesn't block.
|
||||
|
||||
This verifies the fix: locks are isolated per event loop,
|
||||
so different loops can acquire the "same" session lock concurrently.
|
||||
"""
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
session_id = "global-session"
|
||||
acquisition_times: list[float] = []
|
||||
time_lock = threading.Lock()
|
||||
|
||||
def acquire_lock_in_loop(loop_id: int):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
|
||||
async def acquire():
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
async with session_lock_manager.acquire_lock(session_id):
|
||||
with time_lock:
|
||||
acquisition_times.append(start)
|
||||
await asyncio.sleep(0.1) # Hold the lock
|
||||
|
||||
loop.run_until_complete(acquire())
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
# Start 3 threads nearly simultaneously
|
||||
threads = [threading.Thread(target=acquire_lock_in_loop, args=(i,)) for i in range(3)]
|
||||
|
||||
start_time = time.time()
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
total_time = time.time() - start_time
|
||||
|
||||
# If locks were NOT isolated, we'd need ~0.3s (3 * 0.1s serial)
|
||||
# With isolation, all should complete in ~0.1s (parallel)
|
||||
# Allow some overhead, but should be much less than 0.3s
|
||||
assert total_time < 0.25, (
|
||||
f"Locks should be isolated per loop, but took {total_time:.2f}s"
|
||||
)
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Tests for edge cases."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_session_id(self):
|
||||
"""Test with empty session ID."""
|
||||
manager = SessionLockManager()
|
||||
|
||||
async with manager.acquire_lock(""):
|
||||
pass
|
||||
|
||||
# Should work without issues
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_special_characters_in_session_id(self):
|
||||
"""Test with special characters in session ID."""
|
||||
manager = SessionLockManager()
|
||||
session_id = "session-with-special-chars!@#$%^&*()"
|
||||
|
||||
async with manager.acquire_lock(session_id):
|
||||
pass
|
||||
|
||||
# Should work without issues
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_very_long_session_id(self):
|
||||
"""Test with very long session ID."""
|
||||
manager = SessionLockManager()
|
||||
session_id = "a" * 10000
|
||||
|
||||
async with manager.acquire_lock(session_id):
|
||||
pass
|
||||
|
||||
# Should work without issues
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_not_held_after_context_exit(self):
|
||||
"""Test that lock is released after context manager exit."""
|
||||
manager = SessionLockManager()
|
||||
session_id = "test-session"
|
||||
|
||||
async with manager.acquire_lock(session_id):
|
||||
state = manager._get_loop_manager()
|
||||
# Lock should exist and have count 1
|
||||
assert session_id in state._locks
|
||||
assert state._lock_count[session_id] == 1
|
||||
|
||||
# After exit, lock should be cleaned up
|
||||
state = manager._get_loop_manager()
|
||||
assert session_id not in state._locks
|
||||
assert session_id not in state._lock_count
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_during_lock(self):
|
||||
"""Test that lock is released even if exception occurs."""
|
||||
manager = SessionLockManager()
|
||||
session_id = "test-session"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async with manager.acquire_lock(session_id):
|
||||
raise ValueError("test error")
|
||||
|
||||
# Lock should still be released
|
||||
state = manager._get_loop_manager()
|
||||
assert session_id not in state._locks
|
||||
assert session_id not in state._lock_count
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nested_lock_different_sessions(self):
|
||||
"""Test nested locks on different sessions."""
|
||||
manager = SessionLockManager()
|
||||
|
||||
async with manager.acquire_lock("session-1"):
|
||||
async with manager.acquire_lock("session-2"):
|
||||
state = manager._get_loop_manager()
|
||||
assert "session-1" in state._locks
|
||||
assert "session-2" in state._locks
|
||||
assert state._lock_count["session-1"] == 1
|
||||
assert state._lock_count["session-2"] == 1
|
||||
|
||||
state = manager._get_loop_manager()
|
||||
assert "session-1" not in state._locks
|
||||
assert "session-2" not in state._locks
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reentrant_lock_same_session(self):
|
||||
"""Test reentrant locking on same session (should block)."""
|
||||
manager = SessionLockManager()
|
||||
session_id = "test-session"
|
||||
order = []
|
||||
|
||||
async def outer():
|
||||
async with manager.acquire_lock(session_id):
|
||||
order.append("outer-acquired")
|
||||
await asyncio.sleep(0.1)
|
||||
order.append("outer-done")
|
||||
|
||||
async def inner():
|
||||
await asyncio.sleep(0.01) # Let outer acquire first
|
||||
order.append("inner-attempt")
|
||||
async with manager.acquire_lock(session_id):
|
||||
order.append("inner-acquired")
|
||||
order.append("inner-done")
|
||||
|
||||
await asyncio.gather(outer(), inner())
|
||||
|
||||
# Inner should wait for outer to complete
|
||||
assert order.index("outer-acquired") < order.index("outer-done")
|
||||
assert order.index("outer-done") < order.index("inner-acquired")
|
||||
Reference in New Issue
Block a user