diff --git a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py index 1aaf6e3b9..8169a678c 100644 --- a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +++ b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py @@ -302,7 +302,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]): while True: try: - item_type, item_data = await asyncio.get_event_loop().run_in_executor( + item_type, item_data = await asyncio.get_running_loop().run_in_executor( None, response_queue.get, True, 1 ) except queue.Empty: @@ -388,7 +388,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]): # 发起请求 partial = functools.partial(Application.call, **payload) - response = await asyncio.get_event_loop().run_in_executor(None, partial) + response = await asyncio.get_running_loop().run_in_executor(None, partial) async for resp in self._handle_streaming_response(response, session_id): yield resp diff --git a/astrbot/core/computer/booters/bay_manager.py b/astrbot/core/computer/booters/bay_manager.py index 24fa379e8..61ccc1b3a 100644 --- a/astrbot/core/computer/booters/bay_manager.py +++ b/astrbot/core/computer/booters/bay_manager.py @@ -121,11 +121,12 @@ class BayContainerManager: async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None: """Block until Bay's ``/health`` endpoint returns 200.""" url = f"http://127.0.0.1:{self._host_port}/health" - deadline = asyncio.get_event_loop().time() + timeout + loop = asyncio.get_running_loop() + deadline = loop.time() + timeout last_error: str = "" async with aiohttp.ClientSession() as session: - while asyncio.get_event_loop().time() < deadline: + while loop.time() < deadline: try: async with session.get( url, timeout=aiohttp.ClientTimeout(total=3) diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 6dbe78ae4..d9ea6aa26 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -699,21 +699,24 @@ class File(BaseMessageComponent): if self.url: try: - loop = asyncio.get_event_loop() - if loop.is_running(): - logger.warning( - "不可以在异步上下文中同步等待下载! " - "这个警告通常发生于某些逻辑试图通过 .file 获取文件消息段的文件内容。" - "请使用 await get_file() 代替直接获取 .file 字段", - ) - return "" - # 等待下载完成 - loop.run_until_complete(self._download_file()) + # 检查是否有正在运行的 event loop + asyncio.get_running_loop() + logger.warning( + "不可以在异步上下文中同步等待下载! " + "这个警告通常发生于某些逻辑试图通过 .file 获取文件消息段的文件内容。" + "请使用 await get_file() 代替直接获取 .file 字段", + ) + return "" + except RuntimeError: + # 没有运行中的 event loop,可以同步执行 + try: + # 使用 asyncio.run 安全地创建和关闭事件循环 + asyncio.run(self._download_file()) + except Exception: + logger.exception("文件下载失败") if self.file_ and os.path.exists(self.file_): return os.path.abspath(self.file_) - except Exception as e: - logger.error(f"文件下载失败: {e}") return "" diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index 7982a6593..37c3b09ab 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -367,7 +367,7 @@ class DingtalkPlatformAdapter(Platform): async def get_access_token(self) -> str: try: - access_token = await asyncio.get_event_loop().run_in_executor( + access_token = await asyncio.get_running_loop().run_in_executor( None, self.client_.get_access_token, ) @@ -760,7 +760,7 @@ class DingtalkPlatformAdapter(Platform): return logger.error(f"钉钉机器人启动失败: {e}") - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() await loop.run_in_executor(None, start_client, loop) async def terminate(self) -> None: diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index 868ec8a65..2b417f45f 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -80,7 +80,7 @@ class QQOfficialMessageEvent(AstrMessageEvent): if isinstance(source, botpy.message.C2CMessage): # 真流式传输 - current_time = asyncio.get_event_loop().time() + current_time = asyncio.get_running_loop().time() time_since_last_edit = current_time - last_edit_time if time_since_last_edit >= throttle_interval: @@ -90,7 +90,7 @@ class QQOfficialMessageEvent(AstrMessageEvent): ) stream_payload["index"] += 1 stream_payload["id"] = ret["id"] - last_edit_time = asyncio.get_event_loop().time() + last_edit_time = asyncio.get_running_loop().time() if isinstance(source, botpy.message.C2CMessage): # 结束流式对话,并且传输 buffer 中剩余的消息 diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py index 5f35471ee..bcd05faf1 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py @@ -55,7 +55,7 @@ class QQOfficialWebhook: max_async=1, connect=bot_connect, dispatch=self.client.ws_dispatch, - loop=asyncio.get_event_loop(), + loop=asyncio.get_running_loop(), api=self.api, ) diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index 96c7c5568..43e58960e 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -626,7 +626,7 @@ class TelegramPlatformEvent(AstrMessageEvent): # 发送初始 typing 状态 await self._ensure_typing(user_name, message_thread_id) - last_chat_action_time = asyncio.get_event_loop().time() + last_chat_action_time = asyncio.get_running_loop().time() def _append_text(t: str) -> None: nonlocal delta @@ -657,11 +657,11 @@ class TelegramPlatformEvent(AstrMessageEvent): # 编辑或发送消息 if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH: - current_time = asyncio.get_event_loop().time() + current_time = asyncio.get_running_loop().time() time_since_last_edit = current_time - last_edit_time if time_since_last_edit >= throttle_interval: - current_time = asyncio.get_event_loop().time() + current_time = asyncio.get_running_loop().time() if current_time - last_chat_action_time >= chat_action_interval: await self._ensure_typing(user_name, message_thread_id) last_chat_action_time = current_time @@ -674,9 +674,9 @@ class TelegramPlatformEvent(AstrMessageEvent): current_content = delta except Exception as e: logger.warning(f"编辑消息失败(streaming): {e!s}") - last_edit_time = asyncio.get_event_loop().time() + last_edit_time = asyncio.get_running_loop().time() else: - current_time = asyncio.get_event_loop().time() + current_time = asyncio.get_running_loop().time() if current_time - last_chat_action_time >= chat_action_interval: await self._ensure_typing(user_name, message_thread_id) last_chat_action_time = current_time @@ -688,7 +688,7 @@ class TelegramPlatformEvent(AstrMessageEvent): except Exception as e: logger.warning(f"发送消息失败(streaming): {e!s}") message_id = msg.message_id - last_edit_time = asyncio.get_event_loop().time() + last_edit_time = asyncio.get_running_loop().time() try: if delta and current_content != delta: diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 6647db89f..410b30eea 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -200,7 +200,7 @@ class WecomPlatformAdapter(Platform): return msg_list[-1] return None - msg_new = await asyncio.get_event_loop().run_in_executor( + msg_new = await asyncio.get_running_loop().run_in_executor( None, get_latest_msg_item, ) @@ -261,7 +261,7 @@ class WecomPlatformAdapter(Platform): @override async def run(self) -> None: - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() if self.kf_name: try: acc_list = ( @@ -339,7 +339,7 @@ class WecomPlatformAdapter(Platform): abm.session_id = abm.sender.user_id abm.raw_message = msg elif isinstance(msg, VoiceMessage): - resp: Response = await asyncio.get_event_loop().run_in_executor( + resp: Response = await asyncio.get_running_loop().run_in_executor( None, self.client.media.download, msg.media_id, @@ -395,7 +395,7 @@ class WecomPlatformAdapter(Platform): abm.message_str = text elif msgtype == "image": media_id = msg.get("image", {}).get("media_id", "") - resp: Response = await asyncio.get_event_loop().run_in_executor( + resp: Response = await asyncio.get_running_loop().run_in_executor( None, self.client.media.download, media_id, @@ -407,7 +407,7 @@ class WecomPlatformAdapter(Platform): abm.message = [Image(file=path, url=path)] elif msgtype == "voice": media_id = msg.get("voice", {}).get("media_id", "") - resp: Response = await asyncio.get_event_loop().run_in_executor( + resp: Response = await asyncio.get_running_loop().run_in_executor( None, self.client.media.download, media_id, diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py index 9b6e6b968..efa94b58e 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py @@ -4,6 +4,7 @@ """ import asyncio +import time from collections.abc import Awaitable, Callable from typing import Any @@ -82,7 +83,7 @@ class WecomAIQueueMgr: del self.pending_responses[session_id] logger.debug(f"[WecomAI] 移除待处理响应: {session_id}") if mark_finished: - self.completed_streams[session_id] = asyncio.get_event_loop().time() + self.completed_streams[session_id] = time.monotonic() logger.debug(f"[WecomAI] 标记流已结束: {session_id}") def remove_queue(self, session_id: str): @@ -135,7 +136,7 @@ class WecomAIQueueMgr: """ self.pending_responses[session_id] = { "callback_params": callback_params, - "timestamp": asyncio.get_event_loop().time(), + "timestamp": time.monotonic(), } logger.debug(f"[WecomAI] 设置待处理响应: {session_id}") @@ -160,7 +161,7 @@ class WecomAIQueueMgr: finished_at = self.completed_streams.get(session_id) if finished_at is None: return False - if asyncio.get_event_loop().time() - finished_at > max_age_seconds: + if time.monotonic() - finished_at > max_age_seconds: self.completed_streams.pop(session_id, None) return False return True @@ -172,7 +173,7 @@ class WecomAIQueueMgr: max_age_seconds: 最大存活时间(秒) """ - current_time = asyncio.get_event_loop().time() + current_time = time.monotonic() expired_sessions = [] for session_id, response_data in self.pending_responses.items(): diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py index c01355974..bb7061ca1 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py @@ -369,7 +369,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform): if future: logger.debug(f"duplicate message id checked: {msg.id}") else: - future = asyncio.get_event_loop().create_future() + future = asyncio.get_running_loop().create_future() self.wexin_event_workers[msg_id] = future await self.convert_message(msg, future) # I love shield so much! @@ -461,7 +461,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform): elif msg.type == "voice": assert isinstance(msg, VoiceMessage) - resp: Response = await asyncio.get_event_loop().run_in_executor( + resp: Response = await asyncio.get_running_loop().run_in_executor( None, self.client.media.download, msg.media_id, diff --git a/astrbot/core/provider/sources/dashscope_tts.py b/astrbot/core/provider/sources/dashscope_tts.py index 9b6816859..15e763f3e 100644 --- a/astrbot/core/provider/sources/dashscope_tts.py +++ b/astrbot/core/provider/sources/dashscope_tts.py @@ -87,7 +87,7 @@ class ProviderDashscopeTTSAPI(TTSProvider): model: str, text: str, ) -> tuple[bytes | None, str]: - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() response = await loop.run_in_executor(None, self._call_qwen_tts, model, text) audio_bytes = await self._extract_audio_from_response(response) if not audio_bytes: @@ -143,7 +143,7 @@ class ProviderDashscopeTTSAPI(TTSProvider): voice=self.voice, format=AudioFormat.WAV_24000HZ_MONO_16BIT, ) - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() audio_bytes = await loop.run_in_executor( None, synthesizer.call, diff --git a/astrbot/core/provider/sources/genie_tts.py b/astrbot/core/provider/sources/genie_tts.py index 8f9b6d91d..b76bf6b46 100644 --- a/astrbot/core/provider/sources/genie_tts.py +++ b/astrbot/core/provider/sources/genie_tts.py @@ -59,7 +59,7 @@ class GenieTTSProvider(TTSProvider): filename = f"genie_tts_{uuid.uuid4()}.wav" path = os.path.join(temp_dir, filename) - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() def _generate(save_path: str) -> None: assert genie is not None @@ -85,7 +85,7 @@ class GenieTTSProvider(TTSProvider): text_queue: asyncio.Queue[str | None], audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]", ) -> None: - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() while True: text = await text_queue.get() diff --git a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py index af6c0f631..d41ebaf62 100644 --- a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py +++ b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py @@ -43,7 +43,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider): logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...") # 将模型加载放到线程池中执行 - self.model = await asyncio.get_event_loop().run_in_executor( + self.model = await asyncio.get_running_loop().run_in_executor( None, lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16), ) @@ -88,7 +88,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider): audio_url = output_path # 使用 run_in_executor 来调用模型进行识别 - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() res = await loop.run_in_executor( None, # 使用默认的线程池 lambda: cast(SenseVoiceSmall, self.model)( diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py index 678deb948..519a64de6 100644 --- a/astrbot/core/provider/sources/whisper_selfhosted_source.py +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -31,7 +31,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider): self.model = None async def initialize(self) -> None: - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...") self.model = await loop.run_in_executor( None, @@ -50,7 +50,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider): return False async def get_text(self, audio_url: str) -> str: - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() is_tencent = False diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 68c58fdae..b812698f2 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -1374,10 +1374,23 @@ class PluginManager: return if "__del__" in star_metadata.star_cls_type.__dict__: - asyncio.get_event_loop().run_in_executor( + loop = asyncio.get_running_loop() + future = loop.run_in_executor( None, star_metadata.star_cls.__del__, ) + + def _log_del_exception(fut: asyncio.Future) -> None: + if fut.cancelled(): + return + if (exc := fut.exception()) is not None: + logger.error( + "插件 %s 在 __del__ 中抛出了异常:%r", + star_metadata.name, + exc, + ) + + future.add_done_callback(_log_del_exception) elif "terminate" in star_metadata.star_cls_type.__dict__: await star_metadata.star_cls.terminate() diff --git a/astrbot/core/utils/session_lock.py b/astrbot/core/utils/session_lock.py index 7810d6ce4..732a29b72 100644 --- a/astrbot/core/utils/session_lock.py +++ b/astrbot/core/utils/session_lock.py @@ -1,9 +1,13 @@ import asyncio +import threading +import weakref from collections import defaultdict from contextlib import asynccontextmanager -class SessionLockManager: +class _PerLoopSessionLockManager: + """Per-event-loop session lock manager; keeps original simple semantics.""" + def __init__(self) -> None: self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) self._lock_count: dict[str, int] = defaultdict(int) @@ -26,4 +30,26 @@ class SessionLockManager: self._lock_count.pop(session_id, None) +class SessionLockManager: + """Thread-safe session lock manager with per-event-loop isolation.""" + + def __init__(self) -> None: + self._state_guard = threading.Lock() + self._loop_managers: weakref.WeakKeyDictionary[ + asyncio.AbstractEventLoop, _PerLoopSessionLockManager + ] = weakref.WeakKeyDictionary() + + def _get_loop_manager(self) -> _PerLoopSessionLockManager: + """Get the lock manager for the current event loop.""" + loop = asyncio.get_running_loop() + with self._state_guard: + return self._loop_managers.setdefault(loop, _PerLoopSessionLockManager()) + + @asynccontextmanager + async def acquire_lock(self, session_id: str): + manager = self._get_loop_manager() + async with manager.acquire_lock(session_id): + yield + + session_lock_manager = SessionLockManager() diff --git a/main.py b/main.py index 36c46fca3..1cc900982 100644 --- a/main.py +++ b/main.py @@ -101,6 +101,26 @@ async def check_dashboard_files(webui_dir: str | None = None): return data_dist_path +async def main_async(webui_dir_arg: str | None) -> None: + """主异步入口""" + # 检查仪表板文件 + webui_dir = await check_dashboard_files(webui_dir_arg) + if webui_dir is None: + logger.warning( + "管理面板文件检查失败,WebUI 功能将不可用。" + "请检查网络连接或手动指定 --webui-dir 参数。" + ) + + db = db_helper + + # 打印 logo + logger.info(logo_tmpl) + + core_lifecycle = InitialLoader(db, log_broker) + core_lifecycle.webui_dir = webui_dir + await core_lifecycle.start() + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="AstrBot") parser.add_argument( @@ -117,14 +137,5 @@ if __name__ == "__main__": log_broker = LogBroker() LogManager.set_queue_handler(logger, log_broker) - # 检查仪表板文件 - webui_dir = asyncio.run(check_dashboard_files(args.webui_dir)) - - db = db_helper - - # 打印 logo - logger.info(logo_tmpl) - - core_lifecycle = InitialLoader(db, log_broker) - core_lifecycle.webui_dir = webui_dir - asyncio.run(core_lifecycle.start()) + # 只使用一次 asyncio.run() + asyncio.run(main_async(args.webui_dir)) diff --git a/tests/unit/test_session_lock.py b/tests/unit/test_session_lock.py new file mode 100644 index 000000000..fea686b11 --- /dev/null +++ b/tests/unit/test_session_lock.py @@ -0,0 +1,545 @@ +"""Tests for SessionLockManager with multi-event-loop isolation.""" + +import asyncio +import threading +import time +import weakref +from concurrent.futures import ThreadPoolExecutor + +import pytest + +from astrbot.core.utils.session_lock import SessionLockManager + + +class TestSessionLockManagerBasic: + """Basic functionality tests.""" + + def test_init(self): + """Test manager initialization.""" + manager = SessionLockManager() + assert manager._state_guard is not None + assert manager._loop_managers is not None + + @pytest.mark.asyncio + async def test_acquire_release_lock(self): + """Test basic lock acquire and release.""" + manager = SessionLockManager() + session_id = "test-session" + + async with manager.acquire_lock(session_id): + # Lock acquired successfully + pass + + # Lock should be released and cleaned up + state = manager._get_loop_manager() + assert session_id not in state._locks + assert session_id not in state._lock_count + + @pytest.mark.asyncio + async def test_lock_is_reusable(self): + """Test that locks can be acquired multiple times.""" + manager = SessionLockManager() + session_id = "test-session" + + async with manager.acquire_lock(session_id): + pass + + async with manager.acquire_lock(session_id): + pass + + # Both acquisitions should succeed + + +class TestCrossLoopIsolation: + """Tests for event loop isolation.""" + + @pytest.mark.asyncio + async def test_different_loops_have_different_managers(self): + """Test that different event loops get different per-loop managers.""" + manager = SessionLockManager() + + # Get manager for current loop + manager1 = manager._get_loop_manager() + + # Run in a different event loop + def run_in_new_loop(): + new_loop = asyncio.new_event_loop() + try: + asyncio.set_event_loop(new_loop) + + async def get_manager(): + return manager._get_loop_manager() + + return new_loop.run_until_complete(get_manager()) + finally: + new_loop.close() + asyncio.set_event_loop(None) + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(run_in_new_loop) + manager2 = future.result() + + # Should be different manager instances + assert manager1 is not manager2 + + @pytest.mark.asyncio + async def test_locks_isolated_across_loops(self): + """Test that locks from different loops are isolated.""" + manager = SessionLockManager() + session_id = "shared-session" + results = [] + + async def acquire_in_loop(loop_id: int): + """Acquire lock in a new event loop.""" + async with manager.acquire_lock(session_id): + results.append(f"loop-{loop_id}-acquired") + await asyncio.sleep(0.05) + results.append(f"loop-{loop_id}-released") + + def run_in_thread(loop_id: int): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(acquire_in_loop(loop_id)) + finally: + loop.close() + asyncio.set_event_loop(None) + + # Run two loops concurrently - they should NOT block each other + # because locks are isolated per-loop + with ThreadPoolExecutor(max_workers=2) as executor: + futures = [executor.submit(run_in_thread, i) for i in range(2)] + for f in futures: + f.result() + + # Both loops should acquire immediately (no blocking between loops) + # Order should show interleaved acquisitions, not sequential + assert len(results) == 4 + + @pytest.mark.asyncio + async def test_same_loop_blocks_on_same_session(self): + """Test that same loop blocks when acquiring same session lock.""" + manager = SessionLockManager() + session_id = "test-session" + execution_order = [] + + async def task1(): + async with manager.acquire_lock(session_id): + execution_order.append("task1-start") + await asyncio.sleep(0.1) + execution_order.append("task1-end") + + async def task2(): + await asyncio.sleep(0.01) # Let task1 start first + async with manager.acquire_lock(session_id): + execution_order.append("task2-start") + execution_order.append("task2-end") + + await asyncio.gather(task1(), task2()) + + # task2 should wait for task1 to finish + assert execution_order.index("task1-start") < execution_order.index("task1-end") + assert execution_order.index("task1-end") < execution_order.index("task2-start") + + +class TestConcurrency: + """Tests for concurrent access.""" + + @pytest.mark.asyncio + async def test_concurrent_acquisitions_same_loop(self): + """Test concurrent lock acquisitions on the same loop.""" + manager = SessionLockManager() + session_id = "concurrent-session" + acquired_count = 0 + max_concurrent = 0 + lock = asyncio.Lock() + + async def acquire_and_check(): + nonlocal acquired_count, max_concurrent + async with manager.acquire_lock(session_id): + async with lock: + acquired_count += 1 + max_concurrent = max(max_concurrent, acquired_count) + await asyncio.sleep(0.01) + async with lock: + acquired_count -= 1 + + # Run multiple concurrent tasks + tasks = [acquire_and_check() for _ in range(5)] + await asyncio.gather(*tasks) + + # Max concurrent should be 1 (lock serializes access) + assert max_concurrent == 1 + + @pytest.mark.asyncio + async def test_thread_safety_of_loop_manager_creation(self): + """Test that _get_loop_manager is thread-safe.""" + manager = SessionLockManager() + managers = [] + errors = [] + + def create_loop_and_get_manager(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + + async def get_mgr(): + return manager._get_loop_manager() + + mgr = loop.run_until_complete(get_mgr()) + managers.append(mgr) + except Exception as e: + errors.append(e) + finally: + loop.close() + asyncio.set_event_loop(None) + + threads = [threading.Thread(target=create_loop_and_get_manager) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(errors) == 0 + # All managers should be valid + for m in managers: + assert hasattr(m, "_locks") + assert hasattr(m, "_access_lock") + + +class TestEventLoopCleanup: + """Tests for event loop cleanup behavior.""" + + @pytest.mark.asyncio + async def test_weakref_cleanup_on_loop_close(self): + """Test that per-loop managers are cleaned up when loop is closed.""" + manager = SessionLockManager() + loop_ref: weakref.ref[asyncio.AbstractEventLoop] | None = None + + def run_in_new_loop(): + nonlocal loop_ref + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop_ref = weakref.ref(loop) + + async def use_lock(): + async with manager.acquire_lock("test-session"): + pass + return manager._get_loop_manager() + + try: + per_loop_mgr = loop.run_until_complete(use_lock()) + # Keep a weak ref to the per-loop manager + return weakref.ref(per_loop_mgr) + finally: + loop.close() + asyncio.set_event_loop(None) + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(run_in_new_loop) + per_loop_mgr_ref = future.result() + + # Give time for weakref cleanup + import gc + + gc.collect() + + # The per-loop manager should be cleaned up when the loop is closed + # because WeakKeyDictionary removes entries when the key (loop) is gone + per_loop_mgr = per_loop_mgr_ref() + loop = loop_ref() if loop_ref is not None else None + assert per_loop_mgr is None or loop is None + + @pytest.mark.asyncio + async def test_access_after_loop_close_in_new_loop_works(self): + """Test that accessing from a new loop after old loop closes works.""" + manager = SessionLockManager() + + # Use lock in current loop + async with manager.acquire_lock("session-1"): + pass + + # Simulate old loop being closed and new loop being created + def run_in_new_loop(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + + async def use_lock(): + # Should work without issues in new loop + async with manager.acquire_lock("session-2"): + return "success" + + return loop.run_until_complete(use_lock()) + finally: + loop.close() + asyncio.set_event_loop(None) + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(run_in_new_loop) + result = future.result() + + assert result == "success" + + +class TestIssue5464: + """Tests for issue #5464: Multiple OneBot instances with different event loops. + + Issue: Running multiple OneBot adapter instances causes + "is bound to a different event loop" error. + """ + + @pytest.mark.asyncio + async def test_multiple_event_loops_no_cross_loop_error(self): + """Test that multiple event loops don't cause cross-loop binding errors. + + This simulates the scenario where multiple OneBot instances + (each potentially running in different event loops) access the + same SessionLockManager concurrently. + """ + from astrbot.core.utils.session_lock import session_lock_manager + + errors: list[Exception] = [] + results: list[str] = [] + + def simulate_onebot_instance(instance_id: int, session_ids: list[str]): + """Simulate a OneBot instance running in its own event loop.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + + async def process_messages(): + for session_id in session_ids: + try: + async with session_lock_manager.acquire_lock(session_id): + # Simulate message processing + await asyncio.sleep(0.01) + results.append(f"instance-{instance_id}-{session_id}") + except Exception as e: + errors.append(e) + + loop.run_until_complete(process_messages()) + finally: + loop.close() + asyncio.set_event_loop(None) + + # Simulate 4 OneBot instances (as in the issue report) + # Each handles multiple sessions concurrently + threads = [] + for i in range(4): + sessions = [f"session-{i}-1", f"session-{i}-2", f"session-{i}-3"] + t = threading.Thread(target=simulate_onebot_instance, args=(i, sessions)) + threads.append(t) + + for t in threads: + t.start() + for t in threads: + t.join() + + # Should have no errors (especially no "bound to a different event loop") + assert len(errors) == 0, f"Errors occurred: {errors}" + assert len(results) == 12 # 4 instances * 3 sessions each + + @pytest.mark.asyncio + async def test_lock_object_not_shared_across_loops(self): + """Verify that asyncio.Lock objects are not shared across event loops. + + The root cause of issue #5464 was that Lock objects created in one + event loop were being used in another, causing the error. + """ + manager = SessionLockManager() + session_id = "shared-session-id" + lock_ids: set[int] = set() + lock_id_lock = threading.Lock() + + def get_lock_in_new_loop(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + + async def acquire_and_capture(): + # Get the per-loop manager + per_loop_mgr = manager._get_loop_manager() + # Capture the lock object id before acquiring + async with per_loop_mgr._access_lock: + lock = per_loop_mgr._locks[session_id] + with lock_id_lock: + lock_ids.add(id(lock)) + async with manager.acquire_lock(session_id): + await asyncio.sleep(0.01) + + loop.run_until_complete(acquire_and_capture()) + finally: + loop.close() + asyncio.set_event_loop(None) + + # Run multiple loops concurrently + threads = [threading.Thread(target=get_lock_in_new_loop) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Each loop should have its own Lock object + # If locks were shared, we'd only have 1 lock_id + assert len(lock_ids) == 5, "Each event loop should have its own Lock object" + + @pytest.mark.asyncio + async def test_concurrent_access_same_session_different_loops(self): + """Test that same session ID accessed from different loops doesn't block. + + This verifies the fix: locks are isolated per event loop, + so different loops can acquire the "same" session lock concurrently. + """ + from astrbot.core.utils.session_lock import session_lock_manager + + session_id = "global-session" + acquisition_times: list[float] = [] + time_lock = threading.Lock() + + def acquire_lock_in_loop(loop_id: int): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + + async def acquire(): + import time + + start = time.time() + async with session_lock_manager.acquire_lock(session_id): + with time_lock: + acquisition_times.append(start) + await asyncio.sleep(0.1) # Hold the lock + + loop.run_until_complete(acquire()) + finally: + loop.close() + asyncio.set_event_loop(None) + + # Start 3 threads nearly simultaneously + threads = [threading.Thread(target=acquire_lock_in_loop, args=(i,)) for i in range(3)] + + start_time = time.time() + for t in threads: + t.start() + for t in threads: + t.join() + total_time = time.time() - start_time + + # If locks were NOT isolated, we'd need ~0.3s (3 * 0.1s serial) + # With isolation, all should complete in ~0.1s (parallel) + # Allow some overhead, but should be much less than 0.3s + assert total_time < 0.25, ( + f"Locks should be isolated per loop, but took {total_time:.2f}s" + ) + + +class TestEdgeCases: + """Tests for edge cases.""" + + @pytest.mark.asyncio + async def test_empty_session_id(self): + """Test with empty session ID.""" + manager = SessionLockManager() + + async with manager.acquire_lock(""): + pass + + # Should work without issues + + @pytest.mark.asyncio + async def test_special_characters_in_session_id(self): + """Test with special characters in session ID.""" + manager = SessionLockManager() + session_id = "session-with-special-chars!@#$%^&*()" + + async with manager.acquire_lock(session_id): + pass + + # Should work without issues + + @pytest.mark.asyncio + async def test_very_long_session_id(self): + """Test with very long session ID.""" + manager = SessionLockManager() + session_id = "a" * 10000 + + async with manager.acquire_lock(session_id): + pass + + # Should work without issues + + @pytest.mark.asyncio + async def test_lock_not_held_after_context_exit(self): + """Test that lock is released after context manager exit.""" + manager = SessionLockManager() + session_id = "test-session" + + async with manager.acquire_lock(session_id): + state = manager._get_loop_manager() + # Lock should exist and have count 1 + assert session_id in state._locks + assert state._lock_count[session_id] == 1 + + # After exit, lock should be cleaned up + state = manager._get_loop_manager() + assert session_id not in state._locks + assert session_id not in state._lock_count + + @pytest.mark.asyncio + async def test_exception_during_lock(self): + """Test that lock is released even if exception occurs.""" + manager = SessionLockManager() + session_id = "test-session" + + with pytest.raises(ValueError): + async with manager.acquire_lock(session_id): + raise ValueError("test error") + + # Lock should still be released + state = manager._get_loop_manager() + assert session_id not in state._locks + assert session_id not in state._lock_count + + @pytest.mark.asyncio + async def test_nested_lock_different_sessions(self): + """Test nested locks on different sessions.""" + manager = SessionLockManager() + + async with manager.acquire_lock("session-1"): + async with manager.acquire_lock("session-2"): + state = manager._get_loop_manager() + assert "session-1" in state._locks + assert "session-2" in state._locks + assert state._lock_count["session-1"] == 1 + assert state._lock_count["session-2"] == 1 + + state = manager._get_loop_manager() + assert "session-1" not in state._locks + assert "session-2" not in state._locks + + @pytest.mark.asyncio + async def test_reentrant_lock_same_session(self): + """Test reentrant locking on same session (should block).""" + manager = SessionLockManager() + session_id = "test-session" + order = [] + + async def outer(): + async with manager.acquire_lock(session_id): + order.append("outer-acquired") + await asyncio.sleep(0.1) + order.append("outer-done") + + async def inner(): + await asyncio.sleep(0.01) # Let outer acquire first + order.append("inner-attempt") + async with manager.acquire_lock(session_id): + order.append("inner-acquired") + order.append("inner-done") + + await asyncio.gather(outer(), inner()) + + # Inner should wait for outer to complete + assert order.index("outer-acquired") < order.index("outer-done") + assert order.index("outer-done") < order.index("inner-acquired")