fix: 修复 asyncio 事件循环相关问题 (#5774)

* fix: 修复 asyncio 事件循环相关的问题

1. components.py: 修复异常处理结构错误
   - 将 except Exception 移到正确的内部 try 块
   - 确保 _download_file() 异常能被正确捕获和记录

2. session_lock.py: 修复跨事件循环 Lock 绑定问题
   - 添加 _access_lock_loop_id 追踪事件循环
   - 当事件循环变化时重新创建 Lock

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: 根据代码审查反馈修复问题

1. components.py: 移除 asyncio.set_event_loop() 调用
   - 创建临时 event loop 时不再设置为全局
   - 避免干扰其他 asyncio 使用

2. session_lock.py: 简化延迟初始化逻辑
   - 移除 loop-ID 追踪和 _get_lock 方法
   - 使用 setdefault 简化 session lock 创建
   - 保留延迟初始化行为

3. wecomai_queue_mgr.py: 使用 time.monotonic() 替代 loop.time()
   - 同步方法不再依赖活动的 event loop
   - 避免在非异步上下文中抛出 RuntimeError

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: 优化 asyncio 事件循环管理,使用安全的方式创建和关闭事件循环

* fix: 根据代码审查反馈改进异常处理和事件循环使用

- main.py: 显式处理 check_dashboard_files() 返回 None 的情况
- components.py: 使用 logger.exception 保留异常堆栈信息
- star_manager.py: 添加 Future 异常回调处理 __del__ 执行异常
- bay_manager.py: 缓存事件循环引用避免重复调用

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refactor: 简化 SessionLockManager 使用 defaultdict 和 setdefault

- 使用 defaultdict(asyncio.Lock) 简化锁的懒创建
- 使用 setdefault 简化 _get_loop_state 逻辑
- 减少 get + if 分支,提升可读性

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: 降低 webui_dir 检查失败时的日志级别为 warning

改为警告而非退出,允许程序在无 WebUI 的情况下继续运行

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refactor: 重构事件循环锁管理,简化锁状态管理逻辑

* 新增对 SessionLockManager 的多事件循环隔离测试

* fix: 修复测试中的变量声明和断言,确保事件循环管理器的正确性

* fix: 修复插件删除时异常处理逻辑,确保正确记录错误信息

* fix: 新增针对多个事件循环的 OneBot 实例的测试,确保锁对象在不同事件循环间不共享

---------

Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
whatevertogo
2026-03-09 00:00:13 +08:00
committed by GitHub
parent 5808784f07
commit 3fd6c4c8a6
18 changed files with 659 additions and 59 deletions
@@ -302,7 +302,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
while True:
try:
item_type, item_data = await asyncio.get_event_loop().run_in_executor(
item_type, item_data = await asyncio.get_running_loop().run_in_executor(
None, response_queue.get, True, 1
)
except queue.Empty:
@@ -388,7 +388,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
# 发起请求
partial = functools.partial(Application.call, **payload)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
response = await asyncio.get_running_loop().run_in_executor(None, partial)
async for resp in self._handle_streaming_response(response, session_id):
yield resp
+3 -2
View File
@@ -121,11 +121,12 @@ class BayContainerManager:
async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None:
"""Block until Bay's ``/health`` endpoint returns 200."""
url = f"http://127.0.0.1:{self._host_port}/health"
deadline = asyncio.get_event_loop().time() + timeout
loop = asyncio.get_running_loop()
deadline = loop.time() + timeout
last_error: str = ""
async with aiohttp.ClientSession() as session:
while asyncio.get_event_loop().time() < deadline:
while loop.time() < deadline:
try:
async with session.get(
url, timeout=aiohttp.ClientTimeout(total=3)
+15 -12
View File
@@ -699,21 +699,24 @@ class File(BaseMessageComponent):
if self.url:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
logger.warning(
"不可以在异步上下文中同步等待下载! "
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
"请使用 await get_file() 代替直接获取 <File>.file 字段",
)
return ""
# 等待下载完成
loop.run_until_complete(self._download_file())
# 检查是否有正在运行的 event loop
asyncio.get_running_loop()
logger.warning(
"不可以在异步上下文中同步等待下载! "
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
"请使用 await get_file() 代替直接获取 <File>.file 字段",
)
return ""
except RuntimeError:
# 没有运行中的 event loop,可以同步执行
try:
# 使用 asyncio.run 安全地创建和关闭事件循环
asyncio.run(self._download_file())
except Exception:
logger.exception("文件下载失败")
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
except Exception as e:
logger.error(f"文件下载失败: {e}")
return ""
@@ -367,7 +367,7 @@ class DingtalkPlatformAdapter(Platform):
async def get_access_token(self) -> str:
try:
access_token = await asyncio.get_event_loop().run_in_executor(
access_token = await asyncio.get_running_loop().run_in_executor(
None,
self.client_.get_access_token,
)
@@ -760,7 +760,7 @@ class DingtalkPlatformAdapter(Platform):
return
logger.error(f"钉钉机器人启动失败: {e}")
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, start_client, loop)
async def terminate(self) -> None:
@@ -80,7 +80,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
if isinstance(source, botpy.message.C2CMessage):
# 真流式传输
current_time = asyncio.get_event_loop().time()
current_time = asyncio.get_running_loop().time()
time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval:
@@ -90,7 +90,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
stream_payload["index"] += 1
stream_payload["id"] = ret["id"]
last_edit_time = asyncio.get_event_loop().time()
last_edit_time = asyncio.get_running_loop().time()
if isinstance(source, botpy.message.C2CMessage):
# 结束流式对话,并且传输 buffer 中剩余的消息
@@ -55,7 +55,7 @@ class QQOfficialWebhook:
max_async=1,
connect=bot_connect,
dispatch=self.client.ws_dispatch,
loop=asyncio.get_event_loop(),
loop=asyncio.get_running_loop(),
api=self.api,
)
@@ -626,7 +626,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
# 发送初始 typing 状态
await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = asyncio.get_event_loop().time()
last_chat_action_time = asyncio.get_running_loop().time()
def _append_text(t: str) -> None:
nonlocal delta
@@ -657,11 +657,11 @@ class TelegramPlatformEvent(AstrMessageEvent):
# 编辑或发送消息
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
current_time = asyncio.get_event_loop().time()
current_time = asyncio.get_running_loop().time()
time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval:
current_time = asyncio.get_event_loop().time()
current_time = asyncio.get_running_loop().time()
if current_time - last_chat_action_time >= chat_action_interval:
await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = current_time
@@ -674,9 +674,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
current_content = delta
except Exception as e:
logger.warning(f"编辑消息失败(streaming): {e!s}")
last_edit_time = asyncio.get_event_loop().time()
last_edit_time = asyncio.get_running_loop().time()
else:
current_time = asyncio.get_event_loop().time()
current_time = asyncio.get_running_loop().time()
if current_time - last_chat_action_time >= chat_action_interval:
await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = current_time
@@ -688,7 +688,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
message_id = msg.message_id
last_edit_time = asyncio.get_event_loop().time()
last_edit_time = asyncio.get_running_loop().time()
try:
if delta and current_content != delta:
@@ -200,7 +200,7 @@ class WecomPlatformAdapter(Platform):
return msg_list[-1]
return None
msg_new = await asyncio.get_event_loop().run_in_executor(
msg_new = await asyncio.get_running_loop().run_in_executor(
None,
get_latest_msg_item,
)
@@ -261,7 +261,7 @@ class WecomPlatformAdapter(Platform):
@override
async def run(self) -> None:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
if self.kf_name:
try:
acc_list = (
@@ -339,7 +339,7 @@ class WecomPlatformAdapter(Platform):
abm.session_id = abm.sender.user_id
abm.raw_message = msg
elif isinstance(msg, VoiceMessage):
resp: Response = await asyncio.get_event_loop().run_in_executor(
resp: Response = await asyncio.get_running_loop().run_in_executor(
None,
self.client.media.download,
msg.media_id,
@@ -395,7 +395,7 @@ class WecomPlatformAdapter(Platform):
abm.message_str = text
elif msgtype == "image":
media_id = msg.get("image", {}).get("media_id", "")
resp: Response = await asyncio.get_event_loop().run_in_executor(
resp: Response = await asyncio.get_running_loop().run_in_executor(
None,
self.client.media.download,
media_id,
@@ -407,7 +407,7 @@ class WecomPlatformAdapter(Platform):
abm.message = [Image(file=path, url=path)]
elif msgtype == "voice":
media_id = msg.get("voice", {}).get("media_id", "")
resp: Response = await asyncio.get_event_loop().run_in_executor(
resp: Response = await asyncio.get_running_loop().run_in_executor(
None,
self.client.media.download,
media_id,
@@ -4,6 +4,7 @@
"""
import asyncio
import time
from collections.abc import Awaitable, Callable
from typing import Any
@@ -82,7 +83,7 @@ class WecomAIQueueMgr:
del self.pending_responses[session_id]
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
if mark_finished:
self.completed_streams[session_id] = asyncio.get_event_loop().time()
self.completed_streams[session_id] = time.monotonic()
logger.debug(f"[WecomAI] 标记流已结束: {session_id}")
def remove_queue(self, session_id: str):
@@ -135,7 +136,7 @@ class WecomAIQueueMgr:
"""
self.pending_responses[session_id] = {
"callback_params": callback_params,
"timestamp": asyncio.get_event_loop().time(),
"timestamp": time.monotonic(),
}
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
@@ -160,7 +161,7 @@ class WecomAIQueueMgr:
finished_at = self.completed_streams.get(session_id)
if finished_at is None:
return False
if asyncio.get_event_loop().time() - finished_at > max_age_seconds:
if time.monotonic() - finished_at > max_age_seconds:
self.completed_streams.pop(session_id, None)
return False
return True
@@ -172,7 +173,7 @@ class WecomAIQueueMgr:
max_age_seconds: 最大存活时间
"""
current_time = asyncio.get_event_loop().time()
current_time = time.monotonic()
expired_sessions = []
for session_id, response_data in self.pending_responses.items():
@@ -369,7 +369,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
if future:
logger.debug(f"duplicate message id checked: {msg.id}")
else:
future = asyncio.get_event_loop().create_future()
future = asyncio.get_running_loop().create_future()
self.wexin_event_workers[msg_id] = future
await self.convert_message(msg, future)
# I love shield so much!
@@ -461,7 +461,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
elif msg.type == "voice":
assert isinstance(msg, VoiceMessage)
resp: Response = await asyncio.get_event_loop().run_in_executor(
resp: Response = await asyncio.get_running_loop().run_in_executor(
None,
self.client.media.download,
msg.media_id,
@@ -87,7 +87,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
model: str,
text: str,
) -> tuple[bytes | None, str]:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
audio_bytes = await self._extract_audio_from_response(response)
if not audio_bytes:
@@ -143,7 +143,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
voice=self.voice,
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
)
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
audio_bytes = await loop.run_in_executor(
None,
synthesizer.call,
+2 -2
View File
@@ -59,7 +59,7 @@ class GenieTTSProvider(TTSProvider):
filename = f"genie_tts_{uuid.uuid4()}.wav"
path = os.path.join(temp_dir, filename)
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
def _generate(save_path: str) -> None:
assert genie is not None
@@ -85,7 +85,7 @@ class GenieTTSProvider(TTSProvider):
text_queue: asyncio.Queue[str | None],
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
) -> None:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
while True:
text = await text_queue.get()
@@ -43,7 +43,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...")
# 将模型加载放到线程池中执行
self.model = await asyncio.get_event_loop().run_in_executor(
self.model = await asyncio.get_running_loop().run_in_executor(
None,
lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16),
)
@@ -88,7 +88,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
audio_url = output_path
# 使用 run_in_executor 来调用模型进行识别
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
res = await loop.run_in_executor(
None, # 使用默认的线程池
lambda: cast(SenseVoiceSmall, self.model)(
@@ -31,7 +31,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
self.model = None
async def initialize(self) -> None:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
self.model = await loop.run_in_executor(
None,
@@ -50,7 +50,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
return False
async def get_text(self, audio_url: str) -> str:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
is_tencent = False
+14 -1
View File
@@ -1374,10 +1374,23 @@ class PluginManager:
return
if "__del__" in star_metadata.star_cls_type.__dict__:
asyncio.get_event_loop().run_in_executor(
loop = asyncio.get_running_loop()
future = loop.run_in_executor(
None,
star_metadata.star_cls.__del__,
)
def _log_del_exception(fut: asyncio.Future) -> None:
if fut.cancelled():
return
if (exc := fut.exception()) is not None:
logger.error(
"插件 %s 在 __del__ 中抛出了异常:%r",
star_metadata.name,
exc,
)
future.add_done_callback(_log_del_exception)
elif "terminate" in star_metadata.star_cls_type.__dict__:
await star_metadata.star_cls.terminate()
+27 -1
View File
@@ -1,9 +1,13 @@
import asyncio
import threading
import weakref
from collections import defaultdict
from contextlib import asynccontextmanager
class SessionLockManager:
class _PerLoopSessionLockManager:
"""Per-event-loop session lock manager; keeps original simple semantics."""
def __init__(self) -> None:
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self._lock_count: dict[str, int] = defaultdict(int)
@@ -26,4 +30,26 @@ class SessionLockManager:
self._lock_count.pop(session_id, None)
class SessionLockManager:
    """Session lock manager that keeps a separate lock table per event loop.

    Each running event loop is mapped to its own
    ``_PerLoopSessionLockManager`` so that ``asyncio.Lock`` objects created
    under one loop are never handed to another. The mapping is a
    ``WeakKeyDictionary`` keyed by the loop object and is guarded by a
    ``threading.Lock`` because loops running in different threads may look
    up their manager concurrently.
    """

    def __init__(self) -> None:
        # Guards every read/write of ``_loop_managers`` across threads.
        self._state_guard = threading.Lock()
        self._loop_managers: weakref.WeakKeyDictionary[
            asyncio.AbstractEventLoop, _PerLoopSessionLockManager
        ] = weakref.WeakKeyDictionary()

    def _get_loop_manager(self) -> _PerLoopSessionLockManager:
        """Return the per-loop manager for the running loop, creating it once.

        Raises:
            RuntimeError: If there is no running event loop
                (raised by ``asyncio.get_running_loop``).
        """
        loop = asyncio.get_running_loop()
        with self._state_guard:
            manager = self._loop_managers.get(loop)
            if manager is None:
                # Only build a manager on the first lookup for this loop;
                # ``setdefault(loop, Manager())`` would construct (and then
                # discard) a fresh manager on every single call.
                manager = _PerLoopSessionLockManager()
                self._loop_managers[loop] = manager
            return manager

    @asynccontextmanager
    async def acquire_lock(self, session_id: str):
        """Hold the lock for ``session_id`` on the current loop.

        Usage: ``async with manager.acquire_lock(session_id): ...``
        """
        manager = self._get_loop_manager()
        async with manager.acquire_lock(session_id):
            yield
session_lock_manager = SessionLockManager()
+22 -11
View File
@@ -101,6 +101,26 @@ async def check_dashboard_files(webui_dir: str | None = None):
return data_dist_path
async def main_async(webui_dir_arg: str | None) -> None:
    """Main async entry point: check the dashboard files, then start the core.

    Args:
        webui_dir_arg: Value forwarded to ``check_dashboard_files``; may be
            ``None``.
    """
    # Check the dashboard files; a ``None`` result is only warned about,
    # startup continues without the WebUI directory.
    webui_dir = await check_dashboard_files(webui_dir_arg)
    if webui_dir is None:
        logger.warning(
            "管理面板文件检查失败,WebUI 功能将不可用。"
            "请检查网络连接或手动指定 --webui-dir 参数。"
        )

    # Print the logo.
    logger.info(logo_tmpl)

    core_lifecycle = InitialLoader(db_helper, log_broker)
    core_lifecycle.webui_dir = webui_dir
    await core_lifecycle.start()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="AstrBot")
parser.add_argument(
@@ -117,14 +137,5 @@ if __name__ == "__main__":
log_broker = LogBroker()
LogManager.set_queue_handler(logger, log_broker)
# 检查仪表板文件
webui_dir = asyncio.run(check_dashboard_files(args.webui_dir))
db = db_helper
# 打印 logo
logger.info(logo_tmpl)
core_lifecycle = InitialLoader(db, log_broker)
core_lifecycle.webui_dir = webui_dir
asyncio.run(core_lifecycle.start())
# 只使用一次 asyncio.run()
asyncio.run(main_async(args.webui_dir))
+545
View File
@@ -0,0 +1,545 @@
"""Tests for SessionLockManager with multi-event-loop isolation."""
import asyncio
import threading
import time
import weakref
from concurrent.futures import ThreadPoolExecutor
import pytest
from astrbot.core.utils.session_lock import SessionLockManager
class TestSessionLockManagerBasic:
    """Smoke tests for construction and a plain acquire/release cycle."""

    def test_init(self):
        """A fresh manager exposes its guard and its per-loop registry."""
        mgr = SessionLockManager()
        assert mgr._state_guard is not None
        assert mgr._loop_managers is not None

    @pytest.mark.asyncio
    async def test_acquire_release_lock(self):
        """Acquiring then releasing leaves no bookkeeping behind."""
        mgr = SessionLockManager()
        sid = "test-session"

        async with mgr.acquire_lock(sid):
            # Nothing to do while holding the lock.
            pass

        # After release the per-loop tables must no longer mention the session.
        per_loop = mgr._get_loop_manager()
        assert sid not in per_loop._locks
        assert sid not in per_loop._lock_count

    @pytest.mark.asyncio
    async def test_lock_is_reusable(self):
        """The same session can be locked again after it was released."""
        mgr = SessionLockManager()
        sid = "test-session"

        # Two back-to-back acquisitions; both must complete without error.
        for _ in range(2):
            async with mgr.acquire_lock(sid):
                pass
class TestCrossLoopIsolation:
    """Tests for event loop isolation."""

    @pytest.mark.asyncio
    async def test_different_loops_have_different_managers(self):
        """Test that different event loops get different per-loop managers."""
        manager = SessionLockManager()

        # Manager bound to the loop that is running this test.
        manager1 = manager._get_loop_manager()

        def run_in_new_loop():
            # The worker thread owns a brand-new loop for the duration of the call.
            new_loop = asyncio.new_event_loop()
            try:
                asyncio.set_event_loop(new_loop)

                async def get_manager():
                    return manager._get_loop_manager()

                return new_loop.run_until_complete(get_manager())
            finally:
                new_loop.close()
                asyncio.set_event_loop(None)

        with ThreadPoolExecutor(max_workers=1) as executor:
            manager2 = executor.submit(run_in_new_loop).result()

        # Should be different manager instances.
        assert manager1 is not manager2

    @pytest.mark.asyncio
    async def test_locks_isolated_across_loops(self):
        """Test that locks from different loops are isolated."""
        manager = SessionLockManager()
        session_id = "shared-session"
        results = []

        async def acquire_in_loop(loop_id: int):
            """Acquire the shared session lock from inside a dedicated loop."""
            async with manager.acquire_lock(session_id):
                results.append(f"loop-{loop_id}-acquired")
                await asyncio.sleep(0.05)
            results.append(f"loop-{loop_id}-released")

        def run_in_thread(loop_id: int):
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                loop.run_until_complete(acquire_in_loop(loop_id))
            finally:
                loop.close()
                asyncio.set_event_loop(None)

        # Two loops in two threads use the same session id; ``f.result()``
        # re-raises anything that went wrong inside a worker.
        with ThreadPoolExecutor(max_workers=2) as executor:
            futures = [executor.submit(run_in_thread, i) for i in range(2)]
            for f in futures:
                f.result()

        # Every loop must have logged exactly one acquire and one release ...
        expected = [
            f"loop-{i}-{event}" for i in range(2) for event in ("acquired", "released")
        ]
        assert len(results) == 4
        assert sorted(results) == sorted(expected)
        # ... and within a loop the acquire always precedes the release.
        for i in range(2):
            assert results.index(f"loop-{i}-acquired") < results.index(
                f"loop-{i}-released"
            )

    @pytest.mark.asyncio
    async def test_same_loop_blocks_on_same_session(self):
        """Test that same loop blocks when acquiring same session lock."""
        manager = SessionLockManager()
        session_id = "test-session"
        execution_order = []

        async def task1():
            async with manager.acquire_lock(session_id):
                execution_order.append("task1-start")
                await asyncio.sleep(0.1)
                execution_order.append("task1-end")

        async def task2():
            await asyncio.sleep(0.01)  # Let task1 start first
            async with manager.acquire_lock(session_id):
                execution_order.append("task2-start")
                execution_order.append("task2-end")

        await asyncio.gather(task1(), task2())

        # task2 should wait for task1 to finish.
        assert execution_order.index("task1-start") < execution_order.index("task1-end")
        assert execution_order.index("task1-end") < execution_order.index("task2-start")
class TestConcurrency:
    """Tests for concurrent access."""

    @pytest.mark.asyncio
    async def test_concurrent_acquisitions_same_loop(self):
        """Test concurrent lock acquisitions on the same loop."""
        manager = SessionLockManager()
        session_id = "concurrent-session"
        acquired_count = 0
        max_concurrent = 0
        lock = asyncio.Lock()  # protects the two counters above

        async def acquire_and_check():
            nonlocal acquired_count, max_concurrent
            async with manager.acquire_lock(session_id):
                async with lock:
                    acquired_count += 1
                    max_concurrent = max(max_concurrent, acquired_count)
                await asyncio.sleep(0.01)
                async with lock:
                    acquired_count -= 1

        # Run multiple concurrent tasks against the same session.
        await asyncio.gather(*(acquire_and_check() for _ in range(5)))

        # Max concurrent should be 1 (lock serializes access).
        assert max_concurrent == 1

    @pytest.mark.asyncio
    async def test_thread_safety_of_loop_manager_creation(self):
        """Test that _get_loop_manager is thread-safe."""
        manager = SessionLockManager()
        thread_count = 10
        managers = []
        errors = []

        def create_loop_and_get_manager():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:

                async def get_mgr():
                    return manager._get_loop_manager()

                managers.append(loop.run_until_complete(get_mgr()))
            except Exception as e:
                errors.append(e)
            finally:
                loop.close()
                asyncio.set_event_loop(None)

        threads = [
            threading.Thread(target=create_loop_and_get_manager)
            for _ in range(thread_count)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert len(errors) == 0, f"Errors occurred: {errors}"
        # Every thread must have obtained a manager; without this check the
        # loop below would pass vacuously if nothing was ever appended.
        assert len(managers) == thread_count
        # Each thread ran its own loop, so each must have its own manager.
        # The list keeps all managers alive, so their ids cannot collide.
        assert len({id(m) for m in managers}) == thread_count
        # All managers should be valid.
        for m in managers:
            assert hasattr(m, "_locks")
            assert hasattr(m, "_access_lock")
class TestEventLoopCleanup:
    """Behaviour of the manager once an event loop has been closed."""

    @pytest.mark.asyncio
    async def test_weakref_cleanup_on_loop_close(self):
        """A closed loop's per-loop manager should not be kept alive."""
        manager = SessionLockManager()
        loop_ref: weakref.ref[asyncio.AbstractEventLoop] | None = None

        def run_in_new_loop():
            nonlocal loop_ref
            worker_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(worker_loop)
            loop_ref = weakref.ref(worker_loop)

            async def use_lock():
                async with manager.acquire_lock("test-session"):
                    pass
                return manager._get_loop_manager()

            try:
                # Hand back only a weak reference so this test does not
                # itself keep the per-loop manager alive.
                return weakref.ref(worker_loop.run_until_complete(use_lock()))
            finally:
                worker_loop.close()
                asyncio.set_event_loop(None)

        with ThreadPoolExecutor(max_workers=1) as pool:
            per_loop_mgr_ref = pool.submit(run_in_new_loop).result()

        # Force a collection pass so dead weak references get cleared.
        import gc

        gc.collect()

        # Either the manager or the loop it was keyed on must be gone.
        surviving_mgr = per_loop_mgr_ref()
        surviving_loop = None if loop_ref is None else loop_ref()
        assert surviving_mgr is None or surviving_loop is None

    @pytest.mark.asyncio
    async def test_access_after_loop_close_in_new_loop_works(self):
        """Locking from a second loop works after the manager was used elsewhere."""
        manager = SessionLockManager()

        # First use happens on the loop running this test.
        async with manager.acquire_lock("session-1"):
            pass

        def run_in_new_loop():
            # Stand-in for "a new loop is created later on".
            fresh_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(fresh_loop)
            try:

                async def use_lock():
                    async with manager.acquire_lock("session-2"):
                        return "success"

                return fresh_loop.run_until_complete(use_lock())
            finally:
                fresh_loop.close()
                asyncio.set_event_loop(None)

        with ThreadPoolExecutor(max_workers=1) as pool:
            outcome = pool.submit(run_in_new_loop).result()

        assert outcome == "success"
class TestIssue5464:
    """Tests for issue #5464: Multiple OneBot instances with different event loops.

    Issue: Running multiple OneBot adapter instances causes
    "is bound to a different event loop" error.
    """

    @pytest.mark.asyncio
    async def test_multiple_event_loops_no_cross_loop_error(self):
        """Test that multiple event loops don't cause cross-loop binding errors.

        This simulates the scenario where multiple OneBot instances
        (each potentially running in different event loops) access the
        same SessionLockManager concurrently.
        """
        from astrbot.core.utils.session_lock import session_lock_manager

        errors: list[Exception] = []
        results: list[str] = []

        def simulate_onebot_instance(instance_id: int, session_ids: list[str]):
            """Simulate a OneBot instance running in its own event loop."""
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:

                async def process_messages():
                    for session_id in session_ids:
                        try:
                            async with session_lock_manager.acquire_lock(session_id):
                                # Simulate message processing
                                await asyncio.sleep(0.01)
                                results.append(f"instance-{instance_id}-{session_id}")
                        except Exception as e:
                            errors.append(e)

                loop.run_until_complete(process_messages())
            finally:
                loop.close()
                asyncio.set_event_loop(None)

        # Simulate 4 OneBot instances (as in the issue report), each working
        # through three sessions of its own.
        threads = []
        for i in range(4):
            sessions = [f"session-{i}-1", f"session-{i}-2", f"session-{i}-3"]
            threads.append(
                threading.Thread(target=simulate_onebot_instance, args=(i, sessions))
            )
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Should have no errors (especially no "bound to a different event loop")
        assert len(errors) == 0, f"Errors occurred: {errors}"
        assert len(results) == 12  # 4 instances * 3 sessions each

    @pytest.mark.asyncio
    async def test_lock_object_not_shared_across_loops(self):
        """Verify that asyncio.Lock objects are not shared across event loops.

        The root cause of issue #5464 was that Lock objects created in one
        event loop were being used in another, causing the error.
        """
        manager = SessionLockManager()
        session_id = "shared-session-id"
        loop_count = 5
        # Keep the Lock objects themselves (not just their ``id()``): once an
        # object is garbage-collected its id may be reused by a later object,
        # which would make an id-only comparison flaky.
        captured_locks: list[asyncio.Lock] = []
        captured_guard = threading.Lock()
        errors: list[Exception] = []

        def get_lock_in_new_loop():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:

                async def acquire_and_capture():
                    # Get the per-loop manager
                    per_loop_mgr = manager._get_loop_manager()
                    # Capture the lock object before acquiring it
                    async with per_loop_mgr._access_lock:
                        lock = per_loop_mgr._locks[session_id]
                    with captured_guard:
                        captured_locks.append(lock)
                    async with manager.acquire_lock(session_id):
                        await asyncio.sleep(0.01)

                loop.run_until_complete(acquire_and_capture())
            except Exception as e:
                # Surface worker failures in the main thread's assertions.
                errors.append(e)
            finally:
                loop.close()
                asyncio.set_event_loop(None)

        # Run multiple loops concurrently
        threads = [
            threading.Thread(target=get_lock_in_new_loop) for _ in range(loop_count)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert len(errors) == 0, f"Errors occurred: {errors}"
        assert len(captured_locks) == loop_count
        # Each loop should have its own Lock object. All captured locks are
        # still alive here, so distinct ids mean distinct objects.
        assert len({id(lock) for lock in captured_locks}) == loop_count, (
            "Each event loop should have its own Lock object"
        )

    @pytest.mark.asyncio
    async def test_concurrent_access_same_session_different_loops(self):
        """Test that same session ID accessed from different loops doesn't block.

        This verifies the fix: locks are isolated per event loop,
        so different loops can acquire the "same" session lock concurrently.
        """
        from astrbot.core.utils.session_lock import session_lock_manager

        session_id = "global-session"
        loop_count = 3
        # The barrier only releases once every loop is inside the lock at the
        # same moment. This proves concurrent holding directly instead of
        # inferring it from a wall-clock threshold that is sensitive to
        # machine load.
        all_holding = threading.Barrier(loop_count)
        errors: list[Exception] = []

        def acquire_lock_in_loop(loop_id: int):
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:

                async def acquire():
                    async with session_lock_manager.acquire_lock(session_id):
                        # Blocking here is fine: each loop lives in its own
                        # thread and has nothing else scheduled.
                        all_holding.wait(timeout=5)

                loop.run_until_complete(acquire())
            except Exception as e:
                # Includes BrokenBarrierError when the loops did not all
                # manage to hold the lock simultaneously.
                errors.append(e)
            finally:
                loop.close()
                asyncio.set_event_loop(None)

        threads = [
            threading.Thread(target=acquire_lock_in_loop, args=(i,))
            for i in range(loop_count)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert len(errors) == 0, (
            f"Locks should be isolated per loop, but got errors: {errors}"
        )
class TestEdgeCases:
    """Unusual session ids, error paths and nesting."""

    @pytest.mark.asyncio
    async def test_empty_session_id(self):
        """An empty string is accepted as a session id."""
        mgr = SessionLockManager()
        async with mgr.acquire_lock(""):
            pass

    @pytest.mark.asyncio
    async def test_special_characters_in_session_id(self):
        """Punctuation-heavy session ids are accepted."""
        mgr = SessionLockManager()
        sid = "session-with-special-chars!@#$%^&*()"
        async with mgr.acquire_lock(sid):
            pass

    @pytest.mark.asyncio
    async def test_very_long_session_id(self):
        """A 10000-character session id is accepted."""
        mgr = SessionLockManager()
        sid = "a" * 10000
        async with mgr.acquire_lock(sid):
            pass

    @pytest.mark.asyncio
    async def test_lock_not_held_after_context_exit(self):
        """Bookkeeping exists while the lock is held and is gone afterwards."""
        mgr = SessionLockManager()
        sid = "test-session"

        async with mgr.acquire_lock(sid):
            held = mgr._get_loop_manager()
            # While inside the context the session is tracked with count 1.
            assert sid in held._locks
            assert held._lock_count[sid] == 1

        # Once the context exits, both tables must have dropped the session.
        released = mgr._get_loop_manager()
        assert sid not in released._locks
        assert sid not in released._lock_count

    @pytest.mark.asyncio
    async def test_exception_during_lock(self):
        """An exception raised under the lock still triggers cleanup."""
        mgr = SessionLockManager()
        sid = "test-session"

        with pytest.raises(ValueError):
            async with mgr.acquire_lock(sid):
                raise ValueError("test error")

        # The failed body must not leave the session registered.
        per_loop = mgr._get_loop_manager()
        assert sid not in per_loop._locks
        assert sid not in per_loop._lock_count

    @pytest.mark.asyncio
    async def test_nested_lock_different_sessions(self):
        """Two different sessions can be locked one inside the other."""
        mgr = SessionLockManager()

        async with mgr.acquire_lock("session-1"):
            async with mgr.acquire_lock("session-2"):
                inner_state = mgr._get_loop_manager()
                assert "session-1" in inner_state._locks
                assert "session-2" in inner_state._locks
                assert inner_state._lock_count["session-1"] == 1
                assert inner_state._lock_count["session-2"] == 1

        final_state = mgr._get_loop_manager()
        assert "session-1" not in final_state._locks
        assert "session-2" not in final_state._locks

    @pytest.mark.asyncio
    async def test_reentrant_lock_same_session(self):
        """A second task asking for a held session lock waits for the first."""
        mgr = SessionLockManager()
        sid = "test-session"
        events = []

        async def outer():
            async with mgr.acquire_lock(sid):
                events.append("outer-acquired")
                await asyncio.sleep(0.1)
                events.append("outer-done")

        async def inner():
            # Short delay so that ``outer`` grabs the lock first.
            await asyncio.sleep(0.01)
            events.append("inner-attempt")
            async with mgr.acquire_lock(sid):
                events.append("inner-acquired")
                events.append("inner-done")

        await asyncio.gather(outer(), inner())

        # ``inner`` may only get the lock after ``outer`` has finished.
        position = events.index
        assert position("outer-acquired") < position("outer-done")
        assert position("outer-done") < position("inner-acquired")