fix: 修复 asyncio 事件循环相关问题 (#5774)
* fix: 修复 asyncio 事件循环相关的问题 1. components.py: 修复异常处理结构错误 - 将 except Exception 移到正确的内部 try 块 - 确保 _download_file() 异常能被正确捕获和记录 2. session_lock.py: 修复跨事件循环 Lock 绑定问题 - 添加 _access_lock_loop_id 追踪事件循环 - 当事件循环变化时重新创建 Lock Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: 根据代码审查反馈修复问题 1. components.py: 移除 asyncio.set_event_loop() 调用 - 创建临时 event loop 时不再设置为全局 - 避免干扰其他 asyncio 使用 2. session_lock.py: 简化延迟初始化逻辑 - 移除 loop-ID 追踪和 _get_lock 方法 - 使用 setdefault 简化 session lock 创建 - 保留延迟初始化行为 3. wecomai_queue_mgr.py: 使用 time.monotonic() 替代 loop.time() - 同步方法不再依赖活动的 event loop - 避免在非异步上下文中抛出 RuntimeError Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: 优化 asyncio 事件循环管理,使用安全的方式创建和关闭事件循环 * fix: 根据代码审查反馈改进异常处理和事件循环使用 - main.py: 显式处理 check_dashboard_files() 返回 None 的情况 - components.py: 使用 logger.exception 保留异常堆栈信息 - star_manager.py: 添加 Future 异常回调处理 __del__ 执行异常 - bay_manager.py: 缓存事件循环引用避免重复调用 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refactor: 简化 SessionLockManager 使用 defaultdict 和 setdefault - 使用 defaultdict(asyncio.Lock) 简化锁的懒创建 - 使用 setdefault 简化 _get_loop_state 逻辑 - 减少 get + if 分支,提升可读性 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: 降低 webui_dir 检查失败时的日志级别为 warning 改为警告而非退出,允许程序在无 WebUI 的情况下继续运行 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refactor: 重构事件循环锁管理,简化锁状态管理逻辑 * 新增对 SessionLockManager 的多事件循环隔离测试 * fix: 修复测试中的变量声明和断言,确保事件循环管理器的正确性 * fix: 修复插件删除时异常处理逻辑,确保正确记录错误信息 * fix: 新增针对多个事件循环的 OneBot 实例的测试,确保锁对象在不同事件循环间不共享 --------- Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -302,7 +302,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
item_type, item_data = await asyncio.get_event_loop().run_in_executor(
|
item_type, item_data = await asyncio.get_running_loop().run_in_executor(
|
||||||
None, response_queue.get, True, 1
|
None, response_queue.get, True, 1
|
||||||
)
|
)
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
@@ -388,7 +388,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
|
|
||||||
# 发起请求
|
# 发起请求
|
||||||
partial = functools.partial(Application.call, **payload)
|
partial = functools.partial(Application.call, **payload)
|
||||||
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
response = await asyncio.get_running_loop().run_in_executor(None, partial)
|
||||||
|
|
||||||
async for resp in self._handle_streaming_response(response, session_id):
|
async for resp in self._handle_streaming_response(response, session_id):
|
||||||
yield resp
|
yield resp
|
||||||
|
|||||||
@@ -121,11 +121,12 @@ class BayContainerManager:
|
|||||||
async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None:
|
async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None:
|
||||||
"""Block until Bay's ``/health`` endpoint returns 200."""
|
"""Block until Bay's ``/health`` endpoint returns 200."""
|
||||||
url = f"http://127.0.0.1:{self._host_port}/health"
|
url = f"http://127.0.0.1:{self._host_port}/health"
|
||||||
deadline = asyncio.get_event_loop().time() + timeout
|
loop = asyncio.get_running_loop()
|
||||||
|
deadline = loop.time() + timeout
|
||||||
last_error: str = ""
|
last_error: str = ""
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
while asyncio.get_event_loop().time() < deadline:
|
while loop.time() < deadline:
|
||||||
try:
|
try:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
url, timeout=aiohttp.ClientTimeout(total=3)
|
url, timeout=aiohttp.ClientTimeout(total=3)
|
||||||
|
|||||||
@@ -699,21 +699,24 @@ class File(BaseMessageComponent):
|
|||||||
|
|
||||||
if self.url:
|
if self.url:
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
# 检查是否有正在运行的 event loop
|
||||||
if loop.is_running():
|
asyncio.get_running_loop()
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"不可以在异步上下文中同步等待下载! "
|
"不可以在异步上下文中同步等待下载! "
|
||||||
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
|
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
|
||||||
"请使用 await get_file() 代替直接获取 <File>.file 字段",
|
"请使用 await get_file() 代替直接获取 <File>.file 字段",
|
||||||
)
|
)
|
||||||
return ""
|
return ""
|
||||||
# 等待下载完成
|
except RuntimeError:
|
||||||
loop.run_until_complete(self._download_file())
|
# 没有运行中的 event loop,可以同步执行
|
||||||
|
try:
|
||||||
|
# 使用 asyncio.run 安全地创建和关闭事件循环
|
||||||
|
asyncio.run(self._download_file())
|
||||||
|
except Exception:
|
||||||
|
logger.exception("文件下载失败")
|
||||||
|
|
||||||
if self.file_ and os.path.exists(self.file_):
|
if self.file_ and os.path.exists(self.file_):
|
||||||
return os.path.abspath(self.file_)
|
return os.path.abspath(self.file_)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"文件下载失败: {e}")
|
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|||||||
@@ -367,7 +367,7 @@ class DingtalkPlatformAdapter(Platform):
|
|||||||
|
|
||||||
async def get_access_token(self) -> str:
|
async def get_access_token(self) -> str:
|
||||||
try:
|
try:
|
||||||
access_token = await asyncio.get_event_loop().run_in_executor(
|
access_token = await asyncio.get_running_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
self.client_.get_access_token,
|
self.client_.get_access_token,
|
||||||
)
|
)
|
||||||
@@ -760,7 +760,7 @@ class DingtalkPlatformAdapter(Platform):
|
|||||||
return
|
return
|
||||||
logger.error(f"钉钉机器人启动失败: {e}")
|
logger.error(f"钉钉机器人启动失败: {e}")
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
await loop.run_in_executor(None, start_client, loop)
|
await loop.run_in_executor(None, start_client, loop)
|
||||||
|
|
||||||
async def terminate(self) -> None:
|
async def terminate(self) -> None:
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
if isinstance(source, botpy.message.C2CMessage):
|
if isinstance(source, botpy.message.C2CMessage):
|
||||||
# 真流式传输
|
# 真流式传输
|
||||||
current_time = asyncio.get_event_loop().time()
|
current_time = asyncio.get_running_loop().time()
|
||||||
time_since_last_edit = current_time - last_edit_time
|
time_since_last_edit = current_time - last_edit_time
|
||||||
|
|
||||||
if time_since_last_edit >= throttle_interval:
|
if time_since_last_edit >= throttle_interval:
|
||||||
@@ -90,7 +90,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
stream_payload["index"] += 1
|
stream_payload["index"] += 1
|
||||||
stream_payload["id"] = ret["id"]
|
stream_payload["id"] = ret["id"]
|
||||||
last_edit_time = asyncio.get_event_loop().time()
|
last_edit_time = asyncio.get_running_loop().time()
|
||||||
|
|
||||||
if isinstance(source, botpy.message.C2CMessage):
|
if isinstance(source, botpy.message.C2CMessage):
|
||||||
# 结束流式对话,并且传输 buffer 中剩余的消息
|
# 结束流式对话,并且传输 buffer 中剩余的消息
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class QQOfficialWebhook:
|
|||||||
max_async=1,
|
max_async=1,
|
||||||
connect=bot_connect,
|
connect=bot_connect,
|
||||||
dispatch=self.client.ws_dispatch,
|
dispatch=self.client.ws_dispatch,
|
||||||
loop=asyncio.get_event_loop(),
|
loop=asyncio.get_running_loop(),
|
||||||
api=self.api,
|
api=self.api,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -626,7 +626,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
# 发送初始 typing 状态
|
# 发送初始 typing 状态
|
||||||
await self._ensure_typing(user_name, message_thread_id)
|
await self._ensure_typing(user_name, message_thread_id)
|
||||||
last_chat_action_time = asyncio.get_event_loop().time()
|
last_chat_action_time = asyncio.get_running_loop().time()
|
||||||
|
|
||||||
def _append_text(t: str) -> None:
|
def _append_text(t: str) -> None:
|
||||||
nonlocal delta
|
nonlocal delta
|
||||||
@@ -657,11 +657,11 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
# 编辑或发送消息
|
# 编辑或发送消息
|
||||||
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
|
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
|
||||||
current_time = asyncio.get_event_loop().time()
|
current_time = asyncio.get_running_loop().time()
|
||||||
time_since_last_edit = current_time - last_edit_time
|
time_since_last_edit = current_time - last_edit_time
|
||||||
|
|
||||||
if time_since_last_edit >= throttle_interval:
|
if time_since_last_edit >= throttle_interval:
|
||||||
current_time = asyncio.get_event_loop().time()
|
current_time = asyncio.get_running_loop().time()
|
||||||
if current_time - last_chat_action_time >= chat_action_interval:
|
if current_time - last_chat_action_time >= chat_action_interval:
|
||||||
await self._ensure_typing(user_name, message_thread_id)
|
await self._ensure_typing(user_name, message_thread_id)
|
||||||
last_chat_action_time = current_time
|
last_chat_action_time = current_time
|
||||||
@@ -674,9 +674,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
current_content = delta
|
current_content = delta
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"编辑消息失败(streaming): {e!s}")
|
logger.warning(f"编辑消息失败(streaming): {e!s}")
|
||||||
last_edit_time = asyncio.get_event_loop().time()
|
last_edit_time = asyncio.get_running_loop().time()
|
||||||
else:
|
else:
|
||||||
current_time = asyncio.get_event_loop().time()
|
current_time = asyncio.get_running_loop().time()
|
||||||
if current_time - last_chat_action_time >= chat_action_interval:
|
if current_time - last_chat_action_time >= chat_action_interval:
|
||||||
await self._ensure_typing(user_name, message_thread_id)
|
await self._ensure_typing(user_name, message_thread_id)
|
||||||
last_chat_action_time = current_time
|
last_chat_action_time = current_time
|
||||||
@@ -688,7 +688,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||||
message_id = msg.message_id
|
message_id = msg.message_id
|
||||||
last_edit_time = asyncio.get_event_loop().time()
|
last_edit_time = asyncio.get_running_loop().time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if delta and current_content != delta:
|
if delta and current_content != delta:
|
||||||
|
|||||||
@@ -200,7 +200,7 @@ class WecomPlatformAdapter(Platform):
|
|||||||
return msg_list[-1]
|
return msg_list[-1]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
msg_new = await asyncio.get_event_loop().run_in_executor(
|
msg_new = await asyncio.get_running_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
get_latest_msg_item,
|
get_latest_msg_item,
|
||||||
)
|
)
|
||||||
@@ -261,7 +261,7 @@ class WecomPlatformAdapter(Platform):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
if self.kf_name:
|
if self.kf_name:
|
||||||
try:
|
try:
|
||||||
acc_list = (
|
acc_list = (
|
||||||
@@ -339,7 +339,7 @@ class WecomPlatformAdapter(Platform):
|
|||||||
abm.session_id = abm.sender.user_id
|
abm.session_id = abm.sender.user_id
|
||||||
abm.raw_message = msg
|
abm.raw_message = msg
|
||||||
elif isinstance(msg, VoiceMessage):
|
elif isinstance(msg, VoiceMessage):
|
||||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
resp: Response = await asyncio.get_running_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
self.client.media.download,
|
self.client.media.download,
|
||||||
msg.media_id,
|
msg.media_id,
|
||||||
@@ -395,7 +395,7 @@ class WecomPlatformAdapter(Platform):
|
|||||||
abm.message_str = text
|
abm.message_str = text
|
||||||
elif msgtype == "image":
|
elif msgtype == "image":
|
||||||
media_id = msg.get("image", {}).get("media_id", "")
|
media_id = msg.get("image", {}).get("media_id", "")
|
||||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
resp: Response = await asyncio.get_running_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
self.client.media.download,
|
self.client.media.download,
|
||||||
media_id,
|
media_id,
|
||||||
@@ -407,7 +407,7 @@ class WecomPlatformAdapter(Platform):
|
|||||||
abm.message = [Image(file=path, url=path)]
|
abm.message = [Image(file=path, url=path)]
|
||||||
elif msgtype == "voice":
|
elif msgtype == "voice":
|
||||||
media_id = msg.get("voice", {}).get("media_id", "")
|
media_id = msg.get("voice", {}).get("media_id", "")
|
||||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
resp: Response = await asyncio.get_running_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
self.client.media.download,
|
self.client.media.download,
|
||||||
media_id,
|
media_id,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import time
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -82,7 +83,7 @@ class WecomAIQueueMgr:
|
|||||||
del self.pending_responses[session_id]
|
del self.pending_responses[session_id]
|
||||||
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
|
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
|
||||||
if mark_finished:
|
if mark_finished:
|
||||||
self.completed_streams[session_id] = asyncio.get_event_loop().time()
|
self.completed_streams[session_id] = time.monotonic()
|
||||||
logger.debug(f"[WecomAI] 标记流已结束: {session_id}")
|
logger.debug(f"[WecomAI] 标记流已结束: {session_id}")
|
||||||
|
|
||||||
def remove_queue(self, session_id: str):
|
def remove_queue(self, session_id: str):
|
||||||
@@ -135,7 +136,7 @@ class WecomAIQueueMgr:
|
|||||||
"""
|
"""
|
||||||
self.pending_responses[session_id] = {
|
self.pending_responses[session_id] = {
|
||||||
"callback_params": callback_params,
|
"callback_params": callback_params,
|
||||||
"timestamp": asyncio.get_event_loop().time(),
|
"timestamp": time.monotonic(),
|
||||||
}
|
}
|
||||||
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
|
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
|
||||||
|
|
||||||
@@ -160,7 +161,7 @@ class WecomAIQueueMgr:
|
|||||||
finished_at = self.completed_streams.get(session_id)
|
finished_at = self.completed_streams.get(session_id)
|
||||||
if finished_at is None:
|
if finished_at is None:
|
||||||
return False
|
return False
|
||||||
if asyncio.get_event_loop().time() - finished_at > max_age_seconds:
|
if time.monotonic() - finished_at > max_age_seconds:
|
||||||
self.completed_streams.pop(session_id, None)
|
self.completed_streams.pop(session_id, None)
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
@@ -172,7 +173,7 @@ class WecomAIQueueMgr:
|
|||||||
max_age_seconds: 最大存活时间(秒)
|
max_age_seconds: 最大存活时间(秒)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
current_time = asyncio.get_event_loop().time()
|
current_time = time.monotonic()
|
||||||
expired_sessions = []
|
expired_sessions = []
|
||||||
|
|
||||||
for session_id, response_data in self.pending_responses.items():
|
for session_id, response_data in self.pending_responses.items():
|
||||||
|
|||||||
@@ -369,7 +369,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
|||||||
if future:
|
if future:
|
||||||
logger.debug(f"duplicate message id checked: {msg.id}")
|
logger.debug(f"duplicate message id checked: {msg.id}")
|
||||||
else:
|
else:
|
||||||
future = asyncio.get_event_loop().create_future()
|
future = asyncio.get_running_loop().create_future()
|
||||||
self.wexin_event_workers[msg_id] = future
|
self.wexin_event_workers[msg_id] = future
|
||||||
await self.convert_message(msg, future)
|
await self.convert_message(msg, future)
|
||||||
# I love shield so much!
|
# I love shield so much!
|
||||||
@@ -461,7 +461,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
|||||||
elif msg.type == "voice":
|
elif msg.type == "voice":
|
||||||
assert isinstance(msg, VoiceMessage)
|
assert isinstance(msg, VoiceMessage)
|
||||||
|
|
||||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
resp: Response = await asyncio.get_running_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
self.client.media.download,
|
self.client.media.download,
|
||||||
msg.media_id,
|
msg.media_id,
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
|||||||
model: str,
|
model: str,
|
||||||
text: str,
|
text: str,
|
||||||
) -> tuple[bytes | None, str]:
|
) -> tuple[bytes | None, str]:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
|
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
|
||||||
audio_bytes = await self._extract_audio_from_response(response)
|
audio_bytes = await self._extract_audio_from_response(response)
|
||||||
if not audio_bytes:
|
if not audio_bytes:
|
||||||
@@ -143,7 +143,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
|||||||
voice=self.voice,
|
voice=self.voice,
|
||||||
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
|
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
|
||||||
)
|
)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
audio_bytes = await loop.run_in_executor(
|
audio_bytes = await loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
synthesizer.call,
|
synthesizer.call,
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class GenieTTSProvider(TTSProvider):
|
|||||||
filename = f"genie_tts_{uuid.uuid4()}.wav"
|
filename = f"genie_tts_{uuid.uuid4()}.wav"
|
||||||
path = os.path.join(temp_dir, filename)
|
path = os.path.join(temp_dir, filename)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
def _generate(save_path: str) -> None:
|
def _generate(save_path: str) -> None:
|
||||||
assert genie is not None
|
assert genie is not None
|
||||||
@@ -85,7 +85,7 @@ class GenieTTSProvider(TTSProvider):
|
|||||||
text_queue: asyncio.Queue[str | None],
|
text_queue: asyncio.Queue[str | None],
|
||||||
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
|
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
|
||||||
) -> None:
|
) -> None:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
text = await text_queue.get()
|
text = await text_queue.get()
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
|||||||
logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...")
|
logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...")
|
||||||
|
|
||||||
# 将模型加载放到线程池中执行
|
# 将模型加载放到线程池中执行
|
||||||
self.model = await asyncio.get_event_loop().run_in_executor(
|
self.model = await asyncio.get_running_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16),
|
lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16),
|
||||||
)
|
)
|
||||||
@@ -88,7 +88,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
|||||||
audio_url = output_path
|
audio_url = output_path
|
||||||
|
|
||||||
# 使用 run_in_executor 来调用模型进行识别
|
# 使用 run_in_executor 来调用模型进行识别
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
res = await loop.run_in_executor(
|
res = await loop.run_in_executor(
|
||||||
None, # 使用默认的线程池
|
None, # 使用默认的线程池
|
||||||
lambda: cast(SenseVoiceSmall, self.model)(
|
lambda: cast(SenseVoiceSmall, self.model)(
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
|||||||
self.model = None
|
self.model = None
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
|
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
|
||||||
self.model = await loop.run_in_executor(
|
self.model = await loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
@@ -50,7 +50,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
async def get_text(self, audio_url: str) -> str:
|
async def get_text(self, audio_url: str) -> str:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
is_tencent = False
|
is_tencent = False
|
||||||
|
|
||||||
|
|||||||
@@ -1374,10 +1374,23 @@ class PluginManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
if "__del__" in star_metadata.star_cls_type.__dict__:
|
if "__del__" in star_metadata.star_cls_type.__dict__:
|
||||||
asyncio.get_event_loop().run_in_executor(
|
loop = asyncio.get_running_loop()
|
||||||
|
future = loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
star_metadata.star_cls.__del__,
|
star_metadata.star_cls.__del__,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _log_del_exception(fut: asyncio.Future) -> None:
|
||||||
|
if fut.cancelled():
|
||||||
|
return
|
||||||
|
if (exc := fut.exception()) is not None:
|
||||||
|
logger.error(
|
||||||
|
"插件 %s 在 __del__ 中抛出了异常:%r",
|
||||||
|
star_metadata.name,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
future.add_done_callback(_log_del_exception)
|
||||||
elif "terminate" in star_metadata.star_cls_type.__dict__:
|
elif "terminate" in star_metadata.star_cls_type.__dict__:
|
||||||
await star_metadata.star_cls.terminate()
|
await star_metadata.star_cls.terminate()
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import weakref
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
|
||||||
class SessionLockManager:
|
class _PerLoopSessionLockManager:
|
||||||
|
"""Per-event-loop session lock manager; keeps original simple semantics."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||||
self._lock_count: dict[str, int] = defaultdict(int)
|
self._lock_count: dict[str, int] = defaultdict(int)
|
||||||
@@ -26,4 +30,26 @@ class SessionLockManager:
|
|||||||
self._lock_count.pop(session_id, None)
|
self._lock_count.pop(session_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionLockManager:
|
||||||
|
"""Thread-safe session lock manager with per-event-loop isolation."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._state_guard = threading.Lock()
|
||||||
|
self._loop_managers: weakref.WeakKeyDictionary[
|
||||||
|
asyncio.AbstractEventLoop, _PerLoopSessionLockManager
|
||||||
|
] = weakref.WeakKeyDictionary()
|
||||||
|
|
||||||
|
def _get_loop_manager(self) -> _PerLoopSessionLockManager:
|
||||||
|
"""Get the lock manager for the current event loop."""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
with self._state_guard:
|
||||||
|
return self._loop_managers.setdefault(loop, _PerLoopSessionLockManager())
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def acquire_lock(self, session_id: str):
|
||||||
|
manager = self._get_loop_manager()
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
session_lock_manager = SessionLockManager()
|
session_lock_manager = SessionLockManager()
|
||||||
|
|||||||
@@ -101,6 +101,26 @@ async def check_dashboard_files(webui_dir: str | None = None):
|
|||||||
return data_dist_path
|
return data_dist_path
|
||||||
|
|
||||||
|
|
||||||
|
async def main_async(webui_dir_arg: str | None) -> None:
|
||||||
|
"""主异步入口"""
|
||||||
|
# 检查仪表板文件
|
||||||
|
webui_dir = await check_dashboard_files(webui_dir_arg)
|
||||||
|
if webui_dir is None:
|
||||||
|
logger.warning(
|
||||||
|
"管理面板文件检查失败,WebUI 功能将不可用。"
|
||||||
|
"请检查网络连接或手动指定 --webui-dir 参数。"
|
||||||
|
)
|
||||||
|
|
||||||
|
db = db_helper
|
||||||
|
|
||||||
|
# 打印 logo
|
||||||
|
logger.info(logo_tmpl)
|
||||||
|
|
||||||
|
core_lifecycle = InitialLoader(db, log_broker)
|
||||||
|
core_lifecycle.webui_dir = webui_dir
|
||||||
|
await core_lifecycle.start()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="AstrBot")
|
parser = argparse.ArgumentParser(description="AstrBot")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -117,14 +137,5 @@ if __name__ == "__main__":
|
|||||||
log_broker = LogBroker()
|
log_broker = LogBroker()
|
||||||
LogManager.set_queue_handler(logger, log_broker)
|
LogManager.set_queue_handler(logger, log_broker)
|
||||||
|
|
||||||
# 检查仪表板文件
|
# 只使用一次 asyncio.run()
|
||||||
webui_dir = asyncio.run(check_dashboard_files(args.webui_dir))
|
asyncio.run(main_async(args.webui_dir))
|
||||||
|
|
||||||
db = db_helper
|
|
||||||
|
|
||||||
# 打印 logo
|
|
||||||
logger.info(logo_tmpl)
|
|
||||||
|
|
||||||
core_lifecycle = InitialLoader(db, log_broker)
|
|
||||||
core_lifecycle.webui_dir = webui_dir
|
|
||||||
asyncio.run(core_lifecycle.start())
|
|
||||||
|
|||||||
@@ -0,0 +1,545 @@
|
|||||||
|
"""Tests for SessionLockManager with multi-event-loop isolation."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import weakref
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from astrbot.core.utils.session_lock import SessionLockManager
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionLockManagerBasic:
|
||||||
|
"""Basic functionality tests."""
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
"""Test manager initialization."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
assert manager._state_guard is not None
|
||||||
|
assert manager._loop_managers is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_acquire_release_lock(self):
|
||||||
|
"""Test basic lock acquire and release."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
session_id = "test-session"
|
||||||
|
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
# Lock acquired successfully
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Lock should be released and cleaned up
|
||||||
|
state = manager._get_loop_manager()
|
||||||
|
assert session_id not in state._locks
|
||||||
|
assert session_id not in state._lock_count
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lock_is_reusable(self):
|
||||||
|
"""Test that locks can be acquired multiple times."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
session_id = "test-session"
|
||||||
|
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Both acquisitions should succeed
|
||||||
|
|
||||||
|
|
||||||
|
class TestCrossLoopIsolation:
|
||||||
|
"""Tests for event loop isolation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_different_loops_have_different_managers(self):
|
||||||
|
"""Test that different event loops get different per-loop managers."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
|
||||||
|
# Get manager for current loop
|
||||||
|
manager1 = manager._get_loop_manager()
|
||||||
|
|
||||||
|
# Run in a different event loop
|
||||||
|
def run_in_new_loop():
|
||||||
|
new_loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
asyncio.set_event_loop(new_loop)
|
||||||
|
|
||||||
|
async def get_manager():
|
||||||
|
return manager._get_loop_manager()
|
||||||
|
|
||||||
|
return new_loop.run_until_complete(get_manager())
|
||||||
|
finally:
|
||||||
|
new_loop.close()
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||||
|
future = executor.submit(run_in_new_loop)
|
||||||
|
manager2 = future.result()
|
||||||
|
|
||||||
|
# Should be different manager instances
|
||||||
|
assert manager1 is not manager2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_locks_isolated_across_loops(self):
|
||||||
|
"""Test that locks from different loops are isolated."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
session_id = "shared-session"
|
||||||
|
results = []
|
||||||
|
|
||||||
|
async def acquire_in_loop(loop_id: int):
|
||||||
|
"""Acquire lock in a new event loop."""
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
results.append(f"loop-{loop_id}-acquired")
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
results.append(f"loop-{loop_id}-released")
|
||||||
|
|
||||||
|
def run_in_thread(loop_id: int):
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(acquire_in_loop(loop_id))
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
|
# Run two loops concurrently - they should NOT block each other
|
||||||
|
# because locks are isolated per-loop
|
||||||
|
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||||
|
futures = [executor.submit(run_in_thread, i) for i in range(2)]
|
||||||
|
for f in futures:
|
||||||
|
f.result()
|
||||||
|
|
||||||
|
# Both loops should acquire immediately (no blocking between loops)
|
||||||
|
# Order should show interleaved acquisitions, not sequential
|
||||||
|
assert len(results) == 4
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_same_loop_blocks_on_same_session(self):
|
||||||
|
"""Test that same loop blocks when acquiring same session lock."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
session_id = "test-session"
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
async def task1():
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
execution_order.append("task1-start")
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
execution_order.append("task1-end")
|
||||||
|
|
||||||
|
async def task2():
|
||||||
|
await asyncio.sleep(0.01) # Let task1 start first
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
execution_order.append("task2-start")
|
||||||
|
execution_order.append("task2-end")
|
||||||
|
|
||||||
|
await asyncio.gather(task1(), task2())
|
||||||
|
|
||||||
|
# task2 should wait for task1 to finish
|
||||||
|
assert execution_order.index("task1-start") < execution_order.index("task1-end")
|
||||||
|
assert execution_order.index("task1-end") < execution_order.index("task2-start")
|
||||||
|
|
||||||
|
|
||||||
|
class TestConcurrency:
    """Concurrent-access behaviour of ``SessionLockManager``."""

    @pytest.mark.asyncio
    async def test_concurrent_acquisitions_same_loop(self):
        """Tasks contending for one session on one loop hold its lock one at a time."""
        manager = SessionLockManager()
        session_id = "concurrent-session"
        holders = 0
        peak_holders = 0
        # Protects the two counters above; unrelated to the lock under test.
        counter_guard = asyncio.Lock()

        async def acquire_and_check():
            nonlocal holders, peak_holders
            async with manager.acquire_lock(session_id):
                async with counter_guard:
                    holders += 1
                    if holders > peak_holders:
                        peak_holders = holders
                # Yield while still holding the session lock so the other
                # tasks get a chance to (incorrectly) enter at the same time.
                await asyncio.sleep(0.01)
                async with counter_guard:
                    holders -= 1

        # Five contenders for the same session id.
        await asyncio.gather(*(acquire_and_check() for _ in range(5)))

        # The session lock serializes access, so the peak must be exactly one.
        assert peak_holders == 1

    @pytest.mark.asyncio
    async def test_thread_safety_of_loop_manager_creation(self):
        """``_get_loop_manager`` can be called from ten threads, each with its own loop."""
        manager = SessionLockManager()
        collected = []
        failures = []

        async def fetch_manager():
            # Must run inside a loop so the manager is resolved for that loop.
            return manager._get_loop_manager()

        def worker():
            thread_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(thread_loop)
            try:
                collected.append(thread_loop.run_until_complete(fetch_manager()))
            except Exception as exc:
                failures.append(exc)
            finally:
                thread_loop.close()
                asyncio.set_event_loop(None)

        workers = [threading.Thread(target=worker) for _ in range(10)]
        for thread in workers:
            thread.start()
        for thread in workers:
            thread.join()

        assert len(failures) == 0
        # Every per-loop manager that came back exposes the expected state.
        for per_loop in collected:
            assert hasattr(per_loop, "_locks")
            assert hasattr(per_loop, "_access_lock")
|
||||||
|
|
||||||
|
|
||||||
|
class TestEventLoopCleanup:
    """Behaviour of ``SessionLockManager`` once an event loop has been closed."""

    @pytest.mark.asyncio
    async def test_weakref_cleanup_on_loop_close(self):
        """A closed worker loop, or the per-loop manager created for it, gets collected."""
        import gc

        manager = SessionLockManager()
        loop_ref: weakref.ref[asyncio.AbstractEventLoop] | None = None

        async def use_lock():
            async with manager.acquire_lock("test-session"):
                pass
            return manager._get_loop_manager()

        def run_in_new_loop():
            nonlocal loop_ref
            worker_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(worker_loop)
            loop_ref = weakref.ref(worker_loop)
            try:
                # Only a weak reference leaves this function, so the test
                # itself never keeps the per-loop manager alive.
                return weakref.ref(worker_loop.run_until_complete(use_lock()))
            finally:
                worker_loop.close()
                asyncio.set_event_loop(None)

        with ThreadPoolExecutor(max_workers=1) as executor:
            per_loop_mgr_ref = executor.submit(run_in_new_loop).result()

        # Force a collection pass so unreachable objects are reclaimed now.
        gc.collect()

        # At least one of the two weak references must be dead by this point.
        surviving_mgr = per_loop_mgr_ref()
        surviving_loop = None if loop_ref is None else loop_ref()
        assert surviving_mgr is None or surviving_loop is None

    @pytest.mark.asyncio
    async def test_access_after_loop_close_in_new_loop_works(self):
        """A manager already used on this loop also works from a fresh loop in another thread."""
        manager = SessionLockManager()

        # First use happens on the loop running this test.
        async with manager.acquire_lock("session-1"):
            pass

        async def use_lock():
            # Acquiring on the fresh loop should succeed without errors.
            async with manager.acquire_lock("session-2"):
                return "success"

        def run_in_new_loop():
            fresh_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(fresh_loop)
            try:
                return fresh_loop.run_until_complete(use_lock())
            finally:
                fresh_loop.close()
                asyncio.set_event_loop(None)

        with ThreadPoolExecutor(max_workers=1) as executor:
            result = executor.submit(run_in_new_loop).result()

        assert result == "success"
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssue5464:
    """Tests for issue #5464: Multiple OneBot instances with different event loops.

    Issue: Running multiple OneBot adapter instances causes
    "is bound to a different event loop" error.
    """

    @pytest.mark.asyncio
    async def test_multiple_event_loops_no_cross_loop_error(self):
        """Test that multiple event loops don't cause cross-loop binding errors.

        This simulates the scenario where multiple OneBot instances
        (each running its own event loop in its own thread) access the
        same SessionLockManager concurrently.
        """
        from astrbot.core.utils.session_lock import session_lock_manager

        errors: list[Exception] = []
        results: list[str] = []

        def simulate_onebot_instance(instance_id: int, session_ids: list[str]):
            """Simulate a OneBot instance running in its own event loop."""
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:

                async def process_messages():
                    for session_id in session_ids:
                        try:
                            async with session_lock_manager.acquire_lock(session_id):
                                # Simulate message processing
                                await asyncio.sleep(0.01)
                                results.append(f"instance-{instance_id}-{session_id}")
                        except Exception as e:
                            errors.append(e)

                loop.run_until_complete(process_messages())
            finally:
                loop.close()
                asyncio.set_event_loop(None)

        # Simulate 4 OneBot instances (as in the issue report). The instances
        # run concurrently; each one walks through its own sessions in order.
        threads = []
        for i in range(4):
            sessions = [f"session-{i}-1", f"session-{i}-2", f"session-{i}-3"]
            t = threading.Thread(target=simulate_onebot_instance, args=(i, sessions))
            threads.append(t)

        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Should have no errors (especially no "bound to a different event loop")
        assert len(errors) == 0, f"Errors occurred: {errors}"
        assert len(results) == 12  # 4 instances * 3 sessions each

    @pytest.mark.asyncio
    async def test_lock_object_not_shared_across_loops(self):
        """Verify that asyncio.Lock objects are not shared across event loops.

        The root cause of issue #5464 was that Lock objects created in one
        event loop were being used in another, causing the error.
        """
        manager = SessionLockManager()
        session_id = "shared-session-id"
        # Keep the Lock objects themselves (not just their id()) so none of
        # them can be freed and have its address reused by a later thread's
        # lock, which would make the distinct-identity count unreliable.
        captured_locks: list[asyncio.Lock] = []
        errors: list[Exception] = []
        capture_guard = threading.Lock()

        def get_lock_in_new_loop():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:

                async def acquire_and_capture():
                    # Get the per-loop manager
                    per_loop_mgr = manager._get_loop_manager()
                    # Capture the lock object before acquiring it
                    async with per_loop_mgr._access_lock:
                        lock = per_loop_mgr._locks[session_id]
                    with capture_guard:
                        captured_locks.append(lock)
                    async with manager.acquire_lock(session_id):
                        await asyncio.sleep(0.01)

                loop.run_until_complete(acquire_and_capture())
            except Exception as e:
                # An exception escaping a thread target is otherwise only
                # printed by threading; record it so the test fails on it.
                with capture_guard:
                    errors.append(e)
            finally:
                loop.close()
                asyncio.set_event_loop(None)

        # Run multiple loops concurrently
        threads = [threading.Thread(target=get_lock_in_new_loop) for _ in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert len(errors) == 0, f"Errors occurred: {errors}"
        # Each loop should have its own Lock object.
        # If locks were shared, fewer than 5 distinct identities would appear.
        distinct_locks = {id(lock) for lock in captured_locks}
        assert len(distinct_locks) == 5, "Each event loop should have its own Lock object"

    @pytest.mark.asyncio
    async def test_concurrent_access_same_session_different_loops(self):
        """Test that same session ID accessed from different loops doesn't block.

        This verifies the fix: locks are isolated per event loop,
        so different loops can acquire the "same" session lock concurrently.
        """
        # Imported at method scope: the elapsed-time measurement below needs
        # ``time`` here, not only inside a nested coroutine.
        import time

        from astrbot.core.utils.session_lock import session_lock_manager

        session_id = "global-session"
        errors: list[Exception] = []
        errors_guard = threading.Lock()

        def acquire_lock_in_loop():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:

                async def acquire():
                    async with session_lock_manager.acquire_lock(session_id):
                        await asyncio.sleep(0.1)  # Hold the lock

                loop.run_until_complete(acquire())
            except Exception as e:
                # A worker that dies early would finish quickly and make the
                # timing check pass for the wrong reason; surface it instead.
                with errors_guard:
                    errors.append(e)
            finally:
                loop.close()
                asyncio.set_event_loop(None)

        # Start 3 threads nearly simultaneously
        threads = [threading.Thread(target=acquire_lock_in_loop) for _ in range(3)]

        # Monotonic clock: elapsed time must not be affected by wall-clock jumps.
        start_time = time.monotonic()
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        total_time = time.monotonic() - start_time

        assert len(errors) == 0, f"Errors occurred: {errors}"
        # If locks were NOT isolated, we'd need ~0.3s (3 * 0.1s serial)
        # With isolation, all should complete in ~0.1s (parallel)
        # Allow some overhead, but should be much less than 0.3s
        assert total_time < 0.25, (
            f"Locks should be isolated per loop, but took {total_time:.2f}s"
        )
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
    """Edge-case coverage for ``SessionLockManager.acquire_lock``."""

    @pytest.mark.asyncio
    async def test_empty_session_id(self):
        """An empty string is accepted as a session id."""
        manager = SessionLockManager()

        # Passing means entering and leaving the context raised nothing.
        async with manager.acquire_lock(""):
            pass

    @pytest.mark.asyncio
    async def test_special_characters_in_session_id(self):
        """Punctuation in the session id is accepted."""
        manager = SessionLockManager()
        punctuated_id = "session-with-special-chars!@#$%^&*()"

        # Passing means entering and leaving the context raised nothing.
        async with manager.acquire_lock(punctuated_id):
            pass

    @pytest.mark.asyncio
    async def test_very_long_session_id(self):
        """A 10,000-character session id is accepted."""
        manager = SessionLockManager()
        long_id = "a" * 10000

        # Passing means entering and leaving the context raised nothing.
        async with manager.acquire_lock(long_id):
            pass

    @pytest.mark.asyncio
    async def test_lock_not_held_after_context_exit(self):
        """The session's bookkeeping exists while held and is gone after exit."""
        manager = SessionLockManager()
        session_id = "test-session"

        async with manager.acquire_lock(session_id):
            while_held = manager._get_loop_manager()
            # One holder: the lock entry exists with a count of exactly 1.
            assert session_id in while_held._locks
            assert while_held._lock_count[session_id] == 1

        # Once the context exits, both entries have been removed.
        after_exit = manager._get_loop_manager()
        assert session_id not in after_exit._locks
        assert session_id not in after_exit._lock_count

    @pytest.mark.asyncio
    async def test_exception_during_lock(self):
        """An exception raised inside the context still releases the lock."""
        manager = SessionLockManager()
        session_id = "test-session"

        with pytest.raises(ValueError):
            async with manager.acquire_lock(session_id):
                raise ValueError("test error")

        # The failed body must not leave any bookkeeping behind.
        after_error = manager._get_loop_manager()
        assert session_id not in after_error._locks
        assert session_id not in after_error._lock_count

    @pytest.mark.asyncio
    async def test_nested_lock_different_sessions(self):
        """Two different sessions can be held at once, nested in one task."""
        manager = SessionLockManager()

        async with manager.acquire_lock("session-1"), manager.acquire_lock("session-2"):
            both_held = manager._get_loop_manager()
            for held_id in ("session-1", "session-2"):
                assert held_id in both_held._locks
            for held_id in ("session-1", "session-2"):
                assert both_held._lock_count[held_id] == 1

        after_exit = manager._get_loop_manager()
        assert "session-1" not in after_exit._locks
        assert "session-2" not in after_exit._locks

    @pytest.mark.asyncio
    async def test_reentrant_lock_same_session(self):
        """A second task asking for a held session waits until it is released."""
        manager = SessionLockManager()
        session_id = "test-session"
        events = []

        async def outer():
            async with manager.acquire_lock(session_id):
                events.append("outer-acquired")
                await asyncio.sleep(0.1)
                events.append("outer-done")

        async def inner():
            # Short delay so that outer() gets the lock first.
            await asyncio.sleep(0.01)
            events.append("inner-attempt")
            async with manager.acquire_lock(session_id):
                events.append("inner-acquired")
                events.append("inner-done")

        await asyncio.gather(outer(), inner())

        # inner() may only get the lock after outer() has finished with it.
        acquired_at = events.index("outer-acquired")
        released_at = events.index("outer-done")
        assert acquired_at < released_at
        assert released_at < events.index("inner-acquired")
|
||||||
Reference in New Issue
Block a user