Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5dd30f9a45 | |||
| a53a1ca49b | |||
| 3fd6c4c8a6 | |||
| 5808784f07 | |||
| 537849c1e7 |
@@ -1 +1 @@
|
|||||||
__version__ = "4.19.2"
|
__version__ = "4.19.3"
|
||||||
|
|||||||
@@ -302,7 +302,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
item_type, item_data = await asyncio.get_event_loop().run_in_executor(
|
item_type, item_data = await asyncio.get_running_loop().run_in_executor(
|
||||||
None, response_queue.get, True, 1
|
None, response_queue.get, True, 1
|
||||||
)
|
)
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
@@ -388,7 +388,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
|
|
||||||
# 发起请求
|
# 发起请求
|
||||||
partial = functools.partial(Application.call, **payload)
|
partial = functools.partial(Application.call, **payload)
|
||||||
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
response = await asyncio.get_running_loop().run_in_executor(None, partial)
|
||||||
|
|
||||||
async for resp in self._handle_streaming_response(response, session_id):
|
async for resp in self._handle_streaming_response(response, session_id):
|
||||||
yield resp
|
yield resp
|
||||||
|
|||||||
@@ -121,11 +121,12 @@ class BayContainerManager:
|
|||||||
async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None:
|
async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None:
|
||||||
"""Block until Bay's ``/health`` endpoint returns 200."""
|
"""Block until Bay's ``/health`` endpoint returns 200."""
|
||||||
url = f"http://127.0.0.1:{self._host_port}/health"
|
url = f"http://127.0.0.1:{self._host_port}/health"
|
||||||
deadline = asyncio.get_event_loop().time() + timeout
|
loop = asyncio.get_running_loop()
|
||||||
|
deadline = loop.time() + timeout
|
||||||
last_error: str = ""
|
last_error: str = ""
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
while asyncio.get_event_loop().time() < deadline:
|
while loop.time() < deadline:
|
||||||
try:
|
try:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
url, timeout=aiohttp.ClientTimeout(total=3)
|
url, timeout=aiohttp.ClientTimeout(total=3)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
|
|||||||
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
VERSION = "4.19.2"
|
VERSION = "4.19.3"
|
||||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||||
|
|
||||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||||
@@ -343,11 +343,16 @@ CONFIG_METADATA_2 = {
|
|||||||
"id": "wecom_ai_bot",
|
"id": "wecom_ai_bot",
|
||||||
"type": "wecom_ai_bot",
|
"type": "wecom_ai_bot",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
|
"wecom_ai_bot_connection_mode": "webhook",
|
||||||
"wecomaibot_init_respond_text": "",
|
"wecomaibot_init_respond_text": "",
|
||||||
"wecomaibot_friend_message_welcome_text": "",
|
"wecomaibot_friend_message_welcome_text": "",
|
||||||
"wecom_ai_bot_name": "",
|
"wecom_ai_bot_name": "",
|
||||||
"msg_push_webhook_url": "",
|
"msg_push_webhook_url": "",
|
||||||
"only_use_webhook_url_to_send": False,
|
"only_use_webhook_url_to_send": False,
|
||||||
|
"long_connection_bot_id": "",
|
||||||
|
"long_connection_secret": "",
|
||||||
|
"long_connection_ws_url": "wss://openws.work.weixin.qq.com",
|
||||||
|
"long_connection_heartbeat_interval": 30,
|
||||||
"token": "",
|
"token": "",
|
||||||
"encoding_aes_key": "",
|
"encoding_aes_key": "",
|
||||||
"unified_webhook_mode": True,
|
"unified_webhook_mode": True,
|
||||||
@@ -732,6 +737,13 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "请务必填写正确,否则无法使用一些指令。",
|
"hint": "请务必填写正确,否则无法使用一些指令。",
|
||||||
},
|
},
|
||||||
|
"wecom_ai_bot_connection_mode": {
|
||||||
|
"description": "企业微信智能机器人连接模式",
|
||||||
|
"type": "string",
|
||||||
|
"options": ["webhook", "long_connection"],
|
||||||
|
"labels": ["Webhook 回调", "长连接"],
|
||||||
|
"hint": "Webhook 回调模式需要配置 Token/EncodingAESKey。长连接模式需要配置 BotID/Secret。",
|
||||||
|
},
|
||||||
"wecomaibot_init_respond_text": {
|
"wecomaibot_init_respond_text": {
|
||||||
"description": "企业微信智能机器人初始响应文本",
|
"description": "企业微信智能机器人初始响应文本",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -752,6 +764,38 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "启用后,企业微信智能机器人的所有回复都改为通过消息推送 Webhook 发送。消息推送 Webhook 支持更多的消息类型(如图片、文件等)。",
|
"hint": "启用后,企业微信智能机器人的所有回复都改为通过消息推送 Webhook 发送。消息推送 Webhook 支持更多的消息类型(如图片、文件等)。",
|
||||||
},
|
},
|
||||||
|
"long_connection_bot_id": {
|
||||||
|
"description": "长连接 BotID",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "企业微信智能机器人长连接模式凭证 BotID。",
|
||||||
|
"condition": {
|
||||||
|
"wecom_ai_bot_connection_mode": "long_connection",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"long_connection_secret": {
|
||||||
|
"description": "长连接 Secret",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "企业微信智能机器人长连接模式凭证 Secret。",
|
||||||
|
"condition": {
|
||||||
|
"wecom_ai_bot_connection_mode": "long_connection",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"long_connection_ws_url": {
|
||||||
|
"description": "长连接 WebSocket 地址",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "默认值为 wss://openws.work.weixin.qq.com,一般无需修改。",
|
||||||
|
"condition": {
|
||||||
|
"wecom_ai_bot_connection_mode": "long_connection",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"long_connection_heartbeat_interval": {
|
||||||
|
"description": "长连接心跳间隔",
|
||||||
|
"type": "int",
|
||||||
|
"hint": "长连接模式心跳间隔(秒),建议 30 秒。",
|
||||||
|
"condition": {
|
||||||
|
"wecom_ai_bot_connection_mode": "long_connection",
|
||||||
|
},
|
||||||
|
},
|
||||||
"lark_bot_name": {
|
"lark_bot_name": {
|
||||||
"description": "飞书机器人的名字",
|
"description": "飞书机器人的名字",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
|||||||
@@ -699,21 +699,24 @@ class File(BaseMessageComponent):
|
|||||||
|
|
||||||
if self.url:
|
if self.url:
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
# 检查是否有正在运行的 event loop
|
||||||
if loop.is_running():
|
asyncio.get_running_loop()
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"不可以在异步上下文中同步等待下载! "
|
"不可以在异步上下文中同步等待下载! "
|
||||||
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
|
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
|
||||||
"请使用 await get_file() 代替直接获取 <File>.file 字段",
|
"请使用 await get_file() 代替直接获取 <File>.file 字段",
|
||||||
)
|
)
|
||||||
return ""
|
return ""
|
||||||
# 等待下载完成
|
except RuntimeError:
|
||||||
loop.run_until_complete(self._download_file())
|
# 没有运行中的 event loop,可以同步执行
|
||||||
|
try:
|
||||||
|
# 使用 asyncio.run 安全地创建和关闭事件循环
|
||||||
|
asyncio.run(self._download_file())
|
||||||
|
except Exception:
|
||||||
|
logger.exception("文件下载失败")
|
||||||
|
|
||||||
if self.file_ and os.path.exists(self.file_):
|
if self.file_ and os.path.exists(self.file_):
|
||||||
return os.path.abspath(self.file_)
|
return os.path.abspath(self.file_)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"文件下载失败: {e}")
|
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|||||||
@@ -367,7 +367,7 @@ class DingtalkPlatformAdapter(Platform):
|
|||||||
|
|
||||||
async def get_access_token(self) -> str:
|
async def get_access_token(self) -> str:
|
||||||
try:
|
try:
|
||||||
access_token = await asyncio.get_event_loop().run_in_executor(
|
access_token = await asyncio.get_running_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
self.client_.get_access_token,
|
self.client_.get_access_token,
|
||||||
)
|
)
|
||||||
@@ -760,7 +760,7 @@ class DingtalkPlatformAdapter(Platform):
|
|||||||
return
|
return
|
||||||
logger.error(f"钉钉机器人启动失败: {e}")
|
logger.error(f"钉钉机器人启动失败: {e}")
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
await loop.run_in_executor(None, start_client, loop)
|
await loop.run_in_executor(None, start_client, loop)
|
||||||
|
|
||||||
async def terminate(self) -> None:
|
async def terminate(self) -> None:
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
if isinstance(source, botpy.message.C2CMessage):
|
if isinstance(source, botpy.message.C2CMessage):
|
||||||
# 真流式传输
|
# 真流式传输
|
||||||
current_time = asyncio.get_event_loop().time()
|
current_time = asyncio.get_running_loop().time()
|
||||||
time_since_last_edit = current_time - last_edit_time
|
time_since_last_edit = current_time - last_edit_time
|
||||||
|
|
||||||
if time_since_last_edit >= throttle_interval:
|
if time_since_last_edit >= throttle_interval:
|
||||||
@@ -90,7 +90,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
stream_payload["index"] += 1
|
stream_payload["index"] += 1
|
||||||
stream_payload["id"] = ret["id"]
|
stream_payload["id"] = ret["id"]
|
||||||
last_edit_time = asyncio.get_event_loop().time()
|
last_edit_time = asyncio.get_running_loop().time()
|
||||||
|
|
||||||
if isinstance(source, botpy.message.C2CMessage):
|
if isinstance(source, botpy.message.C2CMessage):
|
||||||
# 结束流式对话,并且传输 buffer 中剩余的消息
|
# 结束流式对话,并且传输 buffer 中剩余的消息
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class QQOfficialWebhook:
|
|||||||
max_async=1,
|
max_async=1,
|
||||||
connect=bot_connect,
|
connect=bot_connect,
|
||||||
dispatch=self.client.ws_dispatch,
|
dispatch=self.client.ws_dispatch,
|
||||||
loop=asyncio.get_event_loop(),
|
loop=asyncio.get_running_loop(),
|
||||||
api=self.api,
|
api=self.api,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -626,7 +626,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
# 发送初始 typing 状态
|
# 发送初始 typing 状态
|
||||||
await self._ensure_typing(user_name, message_thread_id)
|
await self._ensure_typing(user_name, message_thread_id)
|
||||||
last_chat_action_time = asyncio.get_event_loop().time()
|
last_chat_action_time = asyncio.get_running_loop().time()
|
||||||
|
|
||||||
def _append_text(t: str) -> None:
|
def _append_text(t: str) -> None:
|
||||||
nonlocal delta
|
nonlocal delta
|
||||||
@@ -657,11 +657,11 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
# 编辑或发送消息
|
# 编辑或发送消息
|
||||||
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
|
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
|
||||||
current_time = asyncio.get_event_loop().time()
|
current_time = asyncio.get_running_loop().time()
|
||||||
time_since_last_edit = current_time - last_edit_time
|
time_since_last_edit = current_time - last_edit_time
|
||||||
|
|
||||||
if time_since_last_edit >= throttle_interval:
|
if time_since_last_edit >= throttle_interval:
|
||||||
current_time = asyncio.get_event_loop().time()
|
current_time = asyncio.get_running_loop().time()
|
||||||
if current_time - last_chat_action_time >= chat_action_interval:
|
if current_time - last_chat_action_time >= chat_action_interval:
|
||||||
await self._ensure_typing(user_name, message_thread_id)
|
await self._ensure_typing(user_name, message_thread_id)
|
||||||
last_chat_action_time = current_time
|
last_chat_action_time = current_time
|
||||||
@@ -674,9 +674,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
current_content = delta
|
current_content = delta
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"编辑消息失败(streaming): {e!s}")
|
logger.warning(f"编辑消息失败(streaming): {e!s}")
|
||||||
last_edit_time = asyncio.get_event_loop().time()
|
last_edit_time = asyncio.get_running_loop().time()
|
||||||
else:
|
else:
|
||||||
current_time = asyncio.get_event_loop().time()
|
current_time = asyncio.get_running_loop().time()
|
||||||
if current_time - last_chat_action_time >= chat_action_interval:
|
if current_time - last_chat_action_time >= chat_action_interval:
|
||||||
await self._ensure_typing(user_name, message_thread_id)
|
await self._ensure_typing(user_name, message_thread_id)
|
||||||
last_chat_action_time = current_time
|
last_chat_action_time = current_time
|
||||||
@@ -688,7 +688,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||||
message_id = msg.message_id
|
message_id = msg.message_id
|
||||||
last_edit_time = asyncio.get_event_loop().time()
|
last_edit_time = asyncio.get_running_loop().time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if delta and current_content != delta:
|
if delta and current_content != delta:
|
||||||
|
|||||||
@@ -200,7 +200,7 @@ class WecomPlatformAdapter(Platform):
|
|||||||
return msg_list[-1]
|
return msg_list[-1]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
msg_new = await asyncio.get_event_loop().run_in_executor(
|
msg_new = await asyncio.get_running_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
get_latest_msg_item,
|
get_latest_msg_item,
|
||||||
)
|
)
|
||||||
@@ -261,7 +261,7 @@ class WecomPlatformAdapter(Platform):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
if self.kf_name:
|
if self.kf_name:
|
||||||
try:
|
try:
|
||||||
acc_list = (
|
acc_list = (
|
||||||
@@ -339,7 +339,7 @@ class WecomPlatformAdapter(Platform):
|
|||||||
abm.session_id = abm.sender.user_id
|
abm.session_id = abm.sender.user_id
|
||||||
abm.raw_message = msg
|
abm.raw_message = msg
|
||||||
elif isinstance(msg, VoiceMessage):
|
elif isinstance(msg, VoiceMessage):
|
||||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
resp: Response = await asyncio.get_running_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
self.client.media.download,
|
self.client.media.download,
|
||||||
msg.media_id,
|
msg.media_id,
|
||||||
@@ -395,7 +395,7 @@ class WecomPlatformAdapter(Platform):
|
|||||||
abm.message_str = text
|
abm.message_str = text
|
||||||
elif msgtype == "image":
|
elif msgtype == "image":
|
||||||
media_id = msg.get("image", {}).get("media_id", "")
|
media_id = msg.get("image", {}).get("media_id", "")
|
||||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
resp: Response = await asyncio.get_running_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
self.client.media.download,
|
self.client.media.download,
|
||||||
media_id,
|
media_id,
|
||||||
@@ -407,7 +407,7 @@ class WecomPlatformAdapter(Platform):
|
|||||||
abm.message = [Image(file=path, url=path)]
|
abm.message = [Image(file=path, url=path)]
|
||||||
elif msgtype == "voice":
|
elif msgtype == "voice":
|
||||||
media_id = msg.get("voice", {}).get("media_id", "")
|
media_id = msg.get("voice", {}).get("media_id", "")
|
||||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
resp: Response = await asyncio.get_running_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
self.client.media.download,
|
self.client.media.download,
|
||||||
media_id,
|
media_id,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import time
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -82,7 +83,7 @@ class WecomAIQueueMgr:
|
|||||||
del self.pending_responses[session_id]
|
del self.pending_responses[session_id]
|
||||||
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
|
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
|
||||||
if mark_finished:
|
if mark_finished:
|
||||||
self.completed_streams[session_id] = asyncio.get_event_loop().time()
|
self.completed_streams[session_id] = time.monotonic()
|
||||||
logger.debug(f"[WecomAI] 标记流已结束: {session_id}")
|
logger.debug(f"[WecomAI] 标记流已结束: {session_id}")
|
||||||
|
|
||||||
def remove_queue(self, session_id: str):
|
def remove_queue(self, session_id: str):
|
||||||
@@ -135,7 +136,7 @@ class WecomAIQueueMgr:
|
|||||||
"""
|
"""
|
||||||
self.pending_responses[session_id] = {
|
self.pending_responses[session_id] = {
|
||||||
"callback_params": callback_params,
|
"callback_params": callback_params,
|
||||||
"timestamp": asyncio.get_event_loop().time(),
|
"timestamp": time.monotonic(),
|
||||||
}
|
}
|
||||||
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
|
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
|
||||||
|
|
||||||
@@ -160,7 +161,7 @@ class WecomAIQueueMgr:
|
|||||||
finished_at = self.completed_streams.get(session_id)
|
finished_at = self.completed_streams.get(session_id)
|
||||||
if finished_at is None:
|
if finished_at is None:
|
||||||
return False
|
return False
|
||||||
if asyncio.get_event_loop().time() - finished_at > max_age_seconds:
|
if time.monotonic() - finished_at > max_age_seconds:
|
||||||
self.completed_streams.pop(session_id, None)
|
self.completed_streams.pop(session_id, None)
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
@@ -172,7 +173,7 @@ class WecomAIQueueMgr:
|
|||||||
max_age_seconds: 最大存活时间(秒)
|
max_age_seconds: 最大存活时间(秒)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
current_time = asyncio.get_event_loop().time()
|
current_time = time.monotonic()
|
||||||
expired_sessions = []
|
expired_sessions = []
|
||||||
|
|
||||||
for session_id, response_data in self.pending_responses.items():
|
for session_id, response_data in self.pending_responses.items():
|
||||||
|
|||||||
@@ -369,7 +369,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
|||||||
if future:
|
if future:
|
||||||
logger.debug(f"duplicate message id checked: {msg.id}")
|
logger.debug(f"duplicate message id checked: {msg.id}")
|
||||||
else:
|
else:
|
||||||
future = asyncio.get_event_loop().create_future()
|
future = asyncio.get_running_loop().create_future()
|
||||||
self.wexin_event_workers[msg_id] = future
|
self.wexin_event_workers[msg_id] = future
|
||||||
await self.convert_message(msg, future)
|
await self.convert_message(msg, future)
|
||||||
# I love shield so much!
|
# I love shield so much!
|
||||||
@@ -461,7 +461,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
|||||||
elif msg.type == "voice":
|
elif msg.type == "voice":
|
||||||
assert isinstance(msg, VoiceMessage)
|
assert isinstance(msg, VoiceMessage)
|
||||||
|
|
||||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
resp: Response = await asyncio.get_running_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
self.client.media.download,
|
self.client.media.download,
|
||||||
msg.media_id,
|
msg.media_id,
|
||||||
|
|||||||
@@ -276,9 +276,24 @@ class ProviderAnthropic(Provider):
|
|||||||
llm_response.id = completion.id
|
llm_response.id = completion.id
|
||||||
llm_response.usage = self._extract_usage(completion.usage)
|
llm_response.usage = self._extract_usage(completion.usage)
|
||||||
|
|
||||||
# TODO(Soulter): 处理 end_turn 情况
|
# Handle cases where completion only contains ThinkingBlock (e.g., MiniMax max_tokens)
|
||||||
|
# When stop_reason='max_tokens', the model may return only thinking content
|
||||||
|
# This is valid and should not raise an exception
|
||||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||||
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。")
|
# Guard clause: raise early if no valid content at all
|
||||||
|
if not llm_response.reasoning_content:
|
||||||
|
raise ValueError(
|
||||||
|
f"Anthropic API returned unparsable completion: "
|
||||||
|
f"no text, tool_use, or thinking content found. "
|
||||||
|
f"Completion: {completion}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# We have reasoning content (ThinkingBlock) - this is valid
|
||||||
|
stop_reason = getattr(completion, "stop_reason", "unknown")
|
||||||
|
logger.debug(
|
||||||
|
f"Completion contains only ThinkingBlock (stop_reason={stop_reason})"
|
||||||
|
)
|
||||||
|
llm_response.completion_text = "" # Ensure empty string, not None
|
||||||
|
|
||||||
return llm_response
|
return llm_response
|
||||||
|
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
|||||||
model: str,
|
model: str,
|
||||||
text: str,
|
text: str,
|
||||||
) -> tuple[bytes | None, str]:
|
) -> tuple[bytes | None, str]:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
|
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
|
||||||
audio_bytes = await self._extract_audio_from_response(response)
|
audio_bytes = await self._extract_audio_from_response(response)
|
||||||
if not audio_bytes:
|
if not audio_bytes:
|
||||||
@@ -143,7 +143,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
|||||||
voice=self.voice,
|
voice=self.voice,
|
||||||
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
|
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
|
||||||
)
|
)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
audio_bytes = await loop.run_in_executor(
|
audio_bytes = await loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
synthesizer.call,
|
synthesizer.call,
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class GenieTTSProvider(TTSProvider):
|
|||||||
filename = f"genie_tts_{uuid.uuid4()}.wav"
|
filename = f"genie_tts_{uuid.uuid4()}.wav"
|
||||||
path = os.path.join(temp_dir, filename)
|
path = os.path.join(temp_dir, filename)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
def _generate(save_path: str) -> None:
|
def _generate(save_path: str) -> None:
|
||||||
assert genie is not None
|
assert genie is not None
|
||||||
@@ -85,7 +85,7 @@ class GenieTTSProvider(TTSProvider):
|
|||||||
text_queue: asyncio.Queue[str | None],
|
text_queue: asyncio.Queue[str | None],
|
||||||
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
|
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
|
||||||
) -> None:
|
) -> None:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
text = await text_queue.get()
|
text = await text_queue.get()
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
|||||||
logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...")
|
logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...")
|
||||||
|
|
||||||
# 将模型加载放到线程池中执行
|
# 将模型加载放到线程池中执行
|
||||||
self.model = await asyncio.get_event_loop().run_in_executor(
|
self.model = await asyncio.get_running_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16),
|
lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16),
|
||||||
)
|
)
|
||||||
@@ -88,7 +88,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
|||||||
audio_url = output_path
|
audio_url = output_path
|
||||||
|
|
||||||
# 使用 run_in_executor 来调用模型进行识别
|
# 使用 run_in_executor 来调用模型进行识别
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
res = await loop.run_in_executor(
|
res = await loop.run_in_executor(
|
||||||
None, # 使用默认的线程池
|
None, # 使用默认的线程池
|
||||||
lambda: cast(SenseVoiceSmall, self.model)(
|
lambda: cast(SenseVoiceSmall, self.model)(
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
|||||||
self.model = None
|
self.model = None
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
|
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
|
||||||
self.model = await loop.run_in_executor(
|
self.model = await loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
@@ -50,7 +50,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
async def get_text(self, audio_url: str) -> str:
|
async def get_text(self, audio_url: str) -> str:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
is_tencent = False
|
is_tencent = False
|
||||||
|
|
||||||
|
|||||||
@@ -1374,10 +1374,23 @@ class PluginManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
if "__del__" in star_metadata.star_cls_type.__dict__:
|
if "__del__" in star_metadata.star_cls_type.__dict__:
|
||||||
asyncio.get_event_loop().run_in_executor(
|
loop = asyncio.get_running_loop()
|
||||||
|
future = loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
star_metadata.star_cls.__del__,
|
star_metadata.star_cls.__del__,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _log_del_exception(fut: asyncio.Future) -> None:
|
||||||
|
if fut.cancelled():
|
||||||
|
return
|
||||||
|
if (exc := fut.exception()) is not None:
|
||||||
|
logger.error(
|
||||||
|
"插件 %s 在 __del__ 中抛出了异常:%r",
|
||||||
|
star_metadata.name,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
future.add_done_callback(_log_del_exception)
|
||||||
elif "terminate" in star_metadata.star_cls_type.__dict__:
|
elif "terminate" in star_metadata.star_cls_type.__dict__:
|
||||||
await star_metadata.star_cls.terminate()
|
await star_metadata.star_cls.terminate()
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import weakref
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
|
||||||
class SessionLockManager:
|
class _PerLoopSessionLockManager:
|
||||||
|
"""Per-event-loop session lock manager; keeps original simple semantics."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||||
self._lock_count: dict[str, int] = defaultdict(int)
|
self._lock_count: dict[str, int] = defaultdict(int)
|
||||||
@@ -26,4 +30,26 @@ class SessionLockManager:
|
|||||||
self._lock_count.pop(session_id, None)
|
self._lock_count.pop(session_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionLockManager:
|
||||||
|
"""Thread-safe session lock manager with per-event-loop isolation."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._state_guard = threading.Lock()
|
||||||
|
self._loop_managers: weakref.WeakKeyDictionary[
|
||||||
|
asyncio.AbstractEventLoop, _PerLoopSessionLockManager
|
||||||
|
] = weakref.WeakKeyDictionary()
|
||||||
|
|
||||||
|
def _get_loop_manager(self) -> _PerLoopSessionLockManager:
|
||||||
|
"""Get the lock manager for the current event loop."""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
with self._state_guard:
|
||||||
|
return self._loop_managers.setdefault(loop, _PerLoopSessionLockManager())
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def acquire_lock(self, session_id: str):
|
||||||
|
manager = self._get_loop_manager()
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
session_lock_manager = SessionLockManager()
|
session_lock_manager = SessionLockManager()
|
||||||
|
|||||||
@@ -12,6 +12,32 @@ from .route import Response, Route, RouteContext
|
|||||||
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyMcpServersError(ValueError):
|
||||||
|
"""Raised when mcpServers is empty."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_mcp_server_config(mcp_servers_value: object) -> dict:
|
||||||
|
"""Extract server configuration from user-submitted mcpServers field.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Invalid configuration
|
||||||
|
"""
|
||||||
|
if not isinstance(mcp_servers_value, dict):
|
||||||
|
raise ValueError("mcpServers must be a JSON object")
|
||||||
|
if not mcp_servers_value:
|
||||||
|
raise EmptyMcpServersError("mcpServers configuration cannot be empty")
|
||||||
|
key_0 = next(iter(mcp_servers_value))
|
||||||
|
extracted = mcp_servers_value[key_0]
|
||||||
|
if not isinstance(extracted, dict):
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid mcpServers format. Ensure each key in mcpServers is a server name, "
|
||||||
|
"and each value is an object containing fields like command/url."
|
||||||
|
)
|
||||||
|
return extracted
|
||||||
|
|
||||||
|
|
||||||
class ToolsRoute(Route):
|
class ToolsRoute(Route):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -33,13 +59,37 @@ class ToolsRoute(Route):
|
|||||||
self.register_routes()
|
self.register_routes()
|
||||||
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
||||||
|
|
||||||
|
def _rollback_mcp_server(self, name: str) -> bool:
|
||||||
|
try:
|
||||||
|
rollback_config = self.tool_mgr.load_mcp_config()
|
||||||
|
if name in rollback_config["mcpServers"]:
|
||||||
|
rollback_config["mcpServers"].pop(name)
|
||||||
|
return self.tool_mgr.save_mcp_config(rollback_config)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
|
||||||
async def get_mcp_servers(self):
|
async def get_mcp_servers(self):
|
||||||
try:
|
try:
|
||||||
config = self.tool_mgr.load_mcp_config()
|
config = self.tool_mgr.load_mcp_config()
|
||||||
servers = []
|
servers = []
|
||||||
|
mcp_servers = config.get("mcpServers", {})
|
||||||
|
|
||||||
|
if not isinstance(mcp_servers, dict):
|
||||||
|
logger.warning(
|
||||||
|
f"Invalid MCP server config type: {type(mcp_servers).__name__}. Expected object/dict; skipped all MCP servers."
|
||||||
|
)
|
||||||
|
mcp_servers = {}
|
||||||
|
|
||||||
# 获取所有服务器并添加它们的工具列表
|
# 获取所有服务器并添加它们的工具列表
|
||||||
for name, server_config in config["mcpServers"].items():
|
for name, server_config in mcp_servers.items():
|
||||||
|
if not isinstance(server_config, dict):
|
||||||
|
logger.warning(
|
||||||
|
f"Invalid config for MCP server '{name}' (type: {type(server_config).__name__}); skipped."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
server_info = {
|
server_info = {
|
||||||
"name": name,
|
"name": name,
|
||||||
"active": server_config.get("active", True),
|
"active": server_config.get("active", True),
|
||||||
@@ -65,7 +115,7 @@ class ToolsRoute(Route):
|
|||||||
return Response().ok(servers).__dict__
|
return Response().ok(servers).__dict__
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"获取 MCP 服务器列表失败: {e!s}").__dict__
|
return Response().error(f"Failed to get MCP server list: {e!s}").__dict__
|
||||||
|
|
||||||
async def add_mcp_server(self):
|
async def add_mcp_server(self):
|
||||||
try:
|
try:
|
||||||
@@ -75,7 +125,7 @@ class ToolsRoute(Route):
|
|||||||
|
|
||||||
# 检查必填字段
|
# 检查必填字段
|
||||||
if not name:
|
if not name:
|
||||||
return Response().error("服务器名称不能为空").__dict__
|
return Response().error("Server name cannot be empty").__dict__
|
||||||
|
|
||||||
# 移除特殊字段并检查配置是否有效
|
# 移除特殊字段并检查配置是否有效
|
||||||
has_valid_config = False
|
has_valid_config = False
|
||||||
@@ -85,21 +135,33 @@ class ToolsRoute(Route):
|
|||||||
for key, value in server_data.items():
|
for key, value in server_data.items():
|
||||||
if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段
|
if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段
|
||||||
if key == "mcpServers":
|
if key == "mcpServers":
|
||||||
key_0 = list(server_data["mcpServers"].keys())[
|
try:
|
||||||
0
|
server_config = _extract_mcp_server_config(
|
||||||
] # 不考虑为空的情况
|
server_data["mcpServers"]
|
||||||
server_config = server_data["mcpServers"][key_0]
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
return Response().error(f"{e!s}").__dict__
|
||||||
else:
|
else:
|
||||||
server_config[key] = value
|
server_config[key] = value
|
||||||
has_valid_config = True
|
has_valid_config = True
|
||||||
|
|
||||||
if not has_valid_config:
|
if not has_valid_config:
|
||||||
return Response().error("必须提供有效的服务器配置").__dict__
|
return (
|
||||||
|
Response()
|
||||||
|
.error("A valid server configuration is required")
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
|
||||||
config = self.tool_mgr.load_mcp_config()
|
config = self.tool_mgr.load_mcp_config()
|
||||||
|
|
||||||
if name in config["mcpServers"]:
|
if name in config["mcpServers"]:
|
||||||
return Response().error(f"服务器 {name} 已存在").__dict__
|
return Response().error(f"Server {name} already exists").__dict__
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.tool_mgr.test_mcp_server_connection(server_config)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return Response().error(f"MCP connection test failed: {e!s}").__dict__
|
||||||
|
|
||||||
config["mcpServers"][name] = server_config
|
config["mcpServers"][name] = server_config
|
||||||
|
|
||||||
@@ -111,17 +173,27 @@ class ToolsRoute(Route):
|
|||||||
timeout=30,
|
timeout=30,
|
||||||
)
|
)
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
return Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
|
rollback_ok = self._rollback_mcp_server(name)
|
||||||
|
err_msg = f"Timed out while enabling MCP server {name}."
|
||||||
|
if not rollback_ok:
|
||||||
|
err_msg += " Configuration rollback failed. Please check the config manually."
|
||||||
|
return Response().error(err_msg).__dict__
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return (
|
rollback_ok = self._rollback_mcp_server(name)
|
||||||
Response().error(f"启用 MCP 服务器 {name} 失败: {e!s}").__dict__
|
err_msg = f"Failed to enable MCP server {name}: {e!s}"
|
||||||
)
|
if not rollback_ok:
|
||||||
return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__
|
err_msg += " Configuration rollback failed. Please check the config manually."
|
||||||
return Response().error("保存配置失败").__dict__
|
return Response().error(err_msg).__dict__
|
||||||
|
return (
|
||||||
|
Response()
|
||||||
|
.ok(None, f"Successfully added MCP server {name}")
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
return Response().error("Failed to save configuration").__dict__
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"添加 MCP 服务器失败: {e!s}").__dict__
|
return Response().error(f"Failed to add MCP server: {e!s}").__dict__
|
||||||
|
|
||||||
async def update_mcp_server(self):
|
async def update_mcp_server(self):
|
||||||
try:
|
try:
|
||||||
@@ -131,23 +203,25 @@ class ToolsRoute(Route):
|
|||||||
old_name = server_data.get("oldName") or name
|
old_name = server_data.get("oldName") or name
|
||||||
|
|
||||||
if not name:
|
if not name:
|
||||||
return Response().error("服务器名称不能为空").__dict__
|
return Response().error("Server name cannot be empty").__dict__
|
||||||
|
|
||||||
config = self.tool_mgr.load_mcp_config()
|
config = self.tool_mgr.load_mcp_config()
|
||||||
|
|
||||||
if old_name not in config["mcpServers"]:
|
if old_name not in config["mcpServers"]:
|
||||||
return Response().error(f"服务器 {old_name} 不存在").__dict__
|
return Response().error(f"Server {old_name} does not exist").__dict__
|
||||||
|
|
||||||
is_rename = name != old_name
|
is_rename = name != old_name
|
||||||
|
|
||||||
if name in config["mcpServers"] and is_rename:
|
if name in config["mcpServers"] and is_rename:
|
||||||
return Response().error(f"服务器 {name} 已存在").__dict__
|
return Response().error(f"Server {name} already exists").__dict__
|
||||||
|
|
||||||
# 获取活动状态
|
# 获取活动状态
|
||||||
active = server_data.get(
|
old_config = config["mcpServers"][old_name]
|
||||||
"active",
|
if isinstance(old_config, dict):
|
||||||
config["mcpServers"][old_name].get("active", True),
|
old_active = old_config.get("active", True)
|
||||||
)
|
else:
|
||||||
|
old_active = True
|
||||||
|
active = server_data.get("active", old_active)
|
||||||
|
|
||||||
# 创建新的配置对象
|
# 创建新的配置对象
|
||||||
server_config = {"active": active}
|
server_config = {"active": active}
|
||||||
@@ -165,17 +239,19 @@ class ToolsRoute(Route):
|
|||||||
"oldName",
|
"oldName",
|
||||||
]: # 排除特殊字段
|
]: # 排除特殊字段
|
||||||
if key == "mcpServers":
|
if key == "mcpServers":
|
||||||
key_0 = list(server_data["mcpServers"].keys())[
|
try:
|
||||||
0
|
server_config = _extract_mcp_server_config(
|
||||||
] # 不考虑为空的情况
|
server_data["mcpServers"]
|
||||||
server_config = server_data["mcpServers"][key_0]
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
return Response().error(f"{e!s}").__dict__
|
||||||
else:
|
else:
|
||||||
server_config[key] = value
|
server_config[key] = value
|
||||||
only_update_active = False
|
only_update_active = False
|
||||||
|
|
||||||
# 如果只更新活动状态,保留原始配置
|
# 如果只更新活动状态,保留原始配置
|
||||||
if only_update_active:
|
if only_update_active and isinstance(old_config, dict):
|
||||||
for key, value in config["mcpServers"][old_name].items():
|
for key, value in old_config.items():
|
||||||
if key != "active": # 除了active之外的所有字段都保留
|
if key != "active": # 除了active之外的所有字段都保留
|
||||||
server_config[key] = value
|
server_config[key] = value
|
||||||
|
|
||||||
@@ -200,7 +276,7 @@ class ToolsRoute(Route):
|
|||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
.error(
|
.error(
|
||||||
f"启用前停用 MCP 服务器时 {old_name} 超时: {e!s}"
|
f"Timed out while disabling MCP server {old_name} before enabling: {e!s}"
|
||||||
)
|
)
|
||||||
.__dict__
|
.__dict__
|
||||||
)
|
)
|
||||||
@@ -209,7 +285,7 @@ class ToolsRoute(Route):
|
|||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
.error(
|
.error(
|
||||||
f"启用前停用 MCP 服务器时 {old_name} 失败: {e!s}"
|
f"Failed to disable MCP server {old_name} before enabling: {e!s}"
|
||||||
)
|
)
|
||||||
.__dict__
|
.__dict__
|
||||||
)
|
)
|
||||||
@@ -221,13 +297,15 @@ class ToolsRoute(Route):
|
|||||||
)
|
)
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
return (
|
return (
|
||||||
Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
|
Response()
|
||||||
|
.error(f"Timed out while enabling MCP server {name}.")
|
||||||
|
.__dict__
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
.error(f"启用 MCP 服务器 {name} 失败: {e!s}")
|
.error(f"Failed to enable MCP server {name}: {e!s}")
|
||||||
.__dict__
|
.__dict__
|
||||||
)
|
)
|
||||||
# 如果要停用服务器
|
# 如果要停用服务器
|
||||||
@@ -237,22 +315,26 @@ class ToolsRoute(Route):
|
|||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
.error(f"停用 MCP 服务器 {old_name} 超时。")
|
.error(f"Timed out while disabling MCP server {old_name}.")
|
||||||
.__dict__
|
.__dict__
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
.error(f"停用 MCP 服务器 {old_name} 失败: {e!s}")
|
.error(f"Failed to disable MCP server {old_name}: {e!s}")
|
||||||
.__dict__
|
.__dict__
|
||||||
)
|
)
|
||||||
|
|
||||||
return Response().ok(None, f"成功更新 MCP 服务器 {name}").__dict__
|
return (
|
||||||
return Response().error("保存配置失败").__dict__
|
Response()
|
||||||
|
.ok(None, f"Successfully updated MCP server {name}")
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
return Response().error("Failed to save configuration").__dict__
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"更新 MCP 服务器失败: {e!s}").__dict__
|
return Response().error(f"Failed to update MCP server: {e!s}").__dict__
|
||||||
|
|
||||||
async def delete_mcp_server(self):
|
async def delete_mcp_server(self):
|
||||||
try:
|
try:
|
||||||
@@ -260,12 +342,12 @@ class ToolsRoute(Route):
|
|||||||
name = server_data.get("name", "")
|
name = server_data.get("name", "")
|
||||||
|
|
||||||
if not name:
|
if not name:
|
||||||
return Response().error("服务器名称不能为空").__dict__
|
return Response().error("Server name cannot be empty").__dict__
|
||||||
|
|
||||||
config = self.tool_mgr.load_mcp_config()
|
config = self.tool_mgr.load_mcp_config()
|
||||||
|
|
||||||
if name not in config["mcpServers"]:
|
if name not in config["mcpServers"]:
|
||||||
return Response().error(f"服务器 {name} 不存在").__dict__
|
return Response().error(f"Server {name} does not exist").__dict__
|
||||||
|
|
||||||
del config["mcpServers"][name]
|
del config["mcpServers"][name]
|
||||||
|
|
||||||
@@ -275,51 +357,76 @@ class ToolsRoute(Route):
|
|||||||
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
return (
|
return (
|
||||||
Response().error(f"停用 MCP 服务器 {name} 超时。").__dict__
|
Response()
|
||||||
|
.error(f"Timed out while disabling MCP server {name}.")
|
||||||
|
.__dict__
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
.error(f"停用 MCP 服务器 {name} 失败: {e!s}")
|
.error(f"Failed to disable MCP server {name}: {e!s}")
|
||||||
.__dict__
|
.__dict__
|
||||||
)
|
)
|
||||||
return Response().ok(None, f"成功删除 MCP 服务器 {name}").__dict__
|
return (
|
||||||
return Response().error("保存配置失败").__dict__
|
Response()
|
||||||
|
.ok(None, f"Successfully deleted MCP server {name}")
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
return Response().error("Failed to save configuration").__dict__
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"删除 MCP 服务器失败: {e!s}").__dict__
|
return Response().error(f"Failed to delete MCP server: {e!s}").__dict__
|
||||||
|
|
||||||
async def test_mcp_connection(self):
|
async def test_mcp_connection(self):
|
||||||
"""测试 MCP 服务器连接"""
|
"""Test MCP server connection."""
|
||||||
try:
|
try:
|
||||||
server_data = await request.json
|
server_data = await request.json
|
||||||
config = server_data.get("mcp_server_config", None)
|
config = server_data.get("mcp_server_config", None)
|
||||||
|
|
||||||
if not isinstance(config, dict) or not config:
|
if not isinstance(config, dict) or not config:
|
||||||
return Response().error("无效的 MCP 服务器配置").__dict__
|
return Response().error("Invalid MCP server configuration").__dict__
|
||||||
|
|
||||||
if "mcpServers" in config:
|
if "mcpServers" in config:
|
||||||
keys = list(config["mcpServers"].keys())
|
mcp_servers = config["mcpServers"]
|
||||||
if not keys:
|
if isinstance(mcp_servers, dict) and len(mcp_servers) > 1:
|
||||||
return Response().error("MCP 服务器配置不能为空").__dict__
|
return (
|
||||||
if len(keys) > 1:
|
Response()
|
||||||
return Response().error("一次只能配置一个 MCP 服务器配置").__dict__
|
.error(
|
||||||
config = config["mcpServers"][keys[0]]
|
"Only one MCP server configuration can be tested at a time"
|
||||||
|
)
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
config = _extract_mcp_server_config(mcp_servers)
|
||||||
|
except EmptyMcpServersError:
|
||||||
|
return (
|
||||||
|
Response()
|
||||||
|
.error("MCP server configuration cannot be empty")
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
return Response().error(f"{e!s}").__dict__
|
||||||
elif not config:
|
elif not config:
|
||||||
return Response().error("MCP 服务器配置不能为空").__dict__
|
return (
|
||||||
|
Response()
|
||||||
|
.error("MCP server configuration cannot be empty")
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
|
||||||
tools_name = await self.tool_mgr.test_mcp_server_connection(config)
|
tools_name = await self.tool_mgr.test_mcp_server_connection(config)
|
||||||
return (
|
return (
|
||||||
Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__
|
Response()
|
||||||
|
.ok(data=tools_name, message="🎉 MCP server is available!")
|
||||||
|
.__dict__
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"测试 MCP 连接失败: {e!s}").__dict__
|
return Response().error(f"Failed to test MCP connection: {e!s}").__dict__
|
||||||
|
|
||||||
async def get_tool_list(self):
|
async def get_tool_list(self):
|
||||||
"""获取所有注册的工具列表"""
|
"""Get all registered tools."""
|
||||||
try:
|
try:
|
||||||
tools = self.tool_mgr.func_list
|
tools = self.tool_mgr.func_list
|
||||||
tools_dict = []
|
tools_dict = []
|
||||||
@@ -349,36 +456,44 @@ class ToolsRoute(Route):
|
|||||||
return Response().ok(data=tools_dict).__dict__
|
return Response().ok(data=tools_dict).__dict__
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"获取工具列表失败: {e!s}").__dict__
|
return Response().error(f"Failed to get tool list: {e!s}").__dict__
|
||||||
|
|
||||||
async def toggle_tool(self):
|
async def toggle_tool(self):
|
||||||
"""启用或停用指定的工具"""
|
"""Activate or deactivate a specified tool."""
|
||||||
try:
|
try:
|
||||||
data = await request.json
|
data = await request.json
|
||||||
tool_name = data.get("name")
|
tool_name = data.get("name")
|
||||||
action = data.get("activate") # True or False
|
action = data.get("activate") # True or False
|
||||||
|
|
||||||
if not tool_name or action is None:
|
if not tool_name or action is None:
|
||||||
return Response().error("缺少必要参数: name 或 action").__dict__
|
return (
|
||||||
|
Response()
|
||||||
|
.error("Missing required parameters: name or activate")
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
|
||||||
if action:
|
if action:
|
||||||
try:
|
try:
|
||||||
ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map)
|
ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return Response().error(f"启用工具失败: {e!s}").__dict__
|
return Response().error(f"Failed to activate tool: {e!s}").__dict__
|
||||||
else:
|
else:
|
||||||
ok = self.tool_mgr.deactivate_llm_tool(tool_name)
|
ok = self.tool_mgr.deactivate_llm_tool(tool_name)
|
||||||
|
|
||||||
if ok:
|
if ok:
|
||||||
return Response().ok(None, "操作成功。").__dict__
|
return Response().ok(None, "Operation successful.").__dict__
|
||||||
return Response().error(f"工具 {tool_name} 不存在或操作失败。").__dict__
|
return (
|
||||||
|
Response()
|
||||||
|
.error(f"Tool {tool_name} does not exist or the operation failed.")
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"操作工具失败: {e!s}").__dict__
|
return Response().error(f"Failed to operate tool: {e!s}").__dict__
|
||||||
|
|
||||||
async def sync_provider(self):
|
async def sync_provider(self):
|
||||||
"""同步 MCP 提供者配置"""
|
"""Sync MCP provider configuration."""
|
||||||
try:
|
try:
|
||||||
data = await request.json
|
data = await request.json
|
||||||
provider_name = data.get("name") # modelscope, or others
|
provider_name = data.get("name") # modelscope, or others
|
||||||
@@ -387,9 +502,11 @@ class ToolsRoute(Route):
|
|||||||
access_token = data.get("access_token", "")
|
access_token = data.get("access_token", "")
|
||||||
await self.tool_mgr.sync_modelscope_mcp_servers(access_token)
|
await self.tool_mgr.sync_modelscope_mcp_servers(access_token)
|
||||||
case _:
|
case _:
|
||||||
return Response().error(f"未知: {provider_name}").__dict__
|
return (
|
||||||
|
Response().error(f"Unknown provider: {provider_name}").__dict__
|
||||||
|
)
|
||||||
|
|
||||||
return Response().ok(message="同步成功").__dict__
|
return Response().ok(message="Sync completed").__dict__
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"同步失败: {e!s}").__dict__
|
return Response().error(f"Sync failed: {e!s}").__dict__
|
||||||
|
|||||||
@@ -0,0 +1,40 @@
|
|||||||
|
## What's Changed
|
||||||
|
|
||||||
|
### 新增
|
||||||
|
|
||||||
|
- 新增技能 ZIP 批量上传能力 ([#5804](https://github.com/AstrBotDevs/AstrBot/pull/5804))。
|
||||||
|
|
||||||
|
### 修复
|
||||||
|
|
||||||
|
- 修复 MCP Server 配置异常时可能导致崩溃的问题 ([#5666](https://github.com/AstrBotDevs/AstrBot/pull/5666), [#5673](https://github.com/AstrBotDevs/AstrBot/pull/5673))。
|
||||||
|
- 修复钉钉适配器文本消息被忽略、无法主动发送文件的问题 ([#5921](https://github.com/AstrBotDevs/AstrBot/pull/5921))。
|
||||||
|
- 修复钉钉适配器无法接收图片与文件的问题 ([#5920](https://github.com/AstrBotDevs/AstrBot/pull/5920))。
|
||||||
|
- fix(provider): handle MiniMax ThinkingBlock when max_tokens reached ([#5913](https://github.com/AstrBotDevs/AstrBot/pull/5913))。
|
||||||
|
- 修复 OpenRouter `api_base` 配置错误的问题 ([#5911](https://github.com/AstrBotDevs/AstrBot/pull/5911))。
|
||||||
|
- 修复插件市场中按展示名搜索已安装插件不生效的问题 ([#5806](https://github.com/AstrBotDevs/AstrBot/pull/5806), [#5811](https://github.com/AstrBotDevs/AstrBot/pull/5811))。
|
||||||
|
- 修复仅图片响应未应用 `reply_with_quote` 与 `reply_with_mention` 的问题 ([#5219](https://github.com/AstrBotDevs/AstrBot/pull/5219))。
|
||||||
|
- 修复 `RegexFilter` 使用 `re.match` 导致匹配范围不正确的问题 ([#5368](https://github.com/AstrBotDevs/AstrBot/pull/5368))。
|
||||||
|
- 修复桌面运行环境检测依赖 frozen Python 的问题 ([#5859](https://github.com/AstrBotDevs/AstrBot/pull/5859))。
|
||||||
|
- 修复通过“创建新配置”创建平台机器人后找不到 pipeline scheduler 的问题 ([#5776](https://github.com/AstrBotDevs/AstrBot/pull/5776))。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## What's Changed (EN)
|
||||||
|
|
||||||
|
### New Features
|
||||||
|
|
||||||
|
- Added batch upload support for multiple skill ZIP files ([#5804](https://github.com/AstrBotDevs/AstrBot/pull/5804)).
|
||||||
|
|
||||||
|
### Bug Fixes
|
||||||
|
|
||||||
|
- Fixed potential crash on malformed MCP server config ([#5666](https://github.com/AstrBotDevs/AstrBot/pull/5666), [#5673](https://github.com/AstrBotDevs/AstrBot/pull/5673)).
|
||||||
|
- Fixed DingTalk adapter issue where text messages were ignored and files could not be sent proactively ([#5921](https://github.com/AstrBotDevs/AstrBot/pull/5921)).
|
||||||
|
- Fixed DingTalk adapter issue where image and file messages could not be received ([#5920](https://github.com/AstrBotDevs/AstrBot/pull/5920)).
|
||||||
|
- Fixed incorrect OpenRouter `api_base` configuration ([#5911](https://github.com/AstrBotDevs/AstrBot/pull/5911)).
|
||||||
|
- Fixed searching installed plugins by display name in extensions ([#5806](https://github.com/AstrBotDevs/AstrBot/pull/5806), [#5811](https://github.com/AstrBotDevs/AstrBot/pull/5811)).
|
||||||
|
- Fixed image-only responses not applying `reply_with_quote` and `reply_with_mention` ([#5219](https://github.com/AstrBotDevs/AstrBot/pull/5219)).
|
||||||
|
- Fixed `RegexFilter` using `re.match` instead of `re.search` for expected matching behavior ([#5368](https://github.com/AstrBotDevs/AstrBot/pull/5368)).
|
||||||
|
- Fixed desktop runtime detection requiring frozen Python ([#5859](https://github.com/AstrBotDevs/AstrBot/pull/5859)).
|
||||||
|
- Fixed missing pipeline scheduler after creating a platform bot via "create new config" ([#5776](https://github.com/AstrBotDevs/AstrBot/pull/5776)).
|
||||||
|
- fix(provider): handle MiniMax ThinkingBlock when max_tokens reached ([#5913](https://github.com/AstrBotDevs/AstrBot/pull/5913))
|
||||||
|
|
||||||
@@ -300,6 +300,10 @@ export default {
|
|||||||
this.loadingGettingServers = true;
|
this.loadingGettingServers = true;
|
||||||
axios.get('/api/tools/mcp/servers')
|
axios.get('/api/tools/mcp/servers')
|
||||||
.then(response => {
|
.then(response => {
|
||||||
|
if (response.data.status === 'error') {
|
||||||
|
this.showError(response.data.message || this.tm('messages.getServersError', { error: 'Unknown error' }));
|
||||||
|
return;
|
||||||
|
}
|
||||||
this.mcpServers = response.data.data || [];
|
this.mcpServers = response.data.data || [];
|
||||||
this.mcpServers.forEach(server => {
|
this.mcpServers.forEach(server => {
|
||||||
if (!this.mcpServerUpdateLoaders[server.name]) {
|
if (!this.mcpServerUpdateLoaders[server.name]) {
|
||||||
@@ -372,6 +376,10 @@ export default {
|
|||||||
axios.post(endpoint, serverData)
|
axios.post(endpoint, serverData)
|
||||||
.then(response => {
|
.then(response => {
|
||||||
this.loading = false;
|
this.loading = false;
|
||||||
|
if (response.data.status === 'error') {
|
||||||
|
this.showError(response.data.message || this.tm('messages.saveError', { error: 'Unknown error' }));
|
||||||
|
return;
|
||||||
|
}
|
||||||
this.showMcpServerDialog = false;
|
this.showMcpServerDialog = false;
|
||||||
this.addServerDialogMessage = '';
|
this.addServerDialogMessage = '';
|
||||||
this.getServers();
|
this.getServers();
|
||||||
|
|||||||
@@ -101,6 +101,26 @@ async def check_dashboard_files(webui_dir: str | None = None):
|
|||||||
return data_dist_path
|
return data_dist_path
|
||||||
|
|
||||||
|
|
||||||
|
async def main_async(webui_dir_arg: str | None) -> None:
|
||||||
|
"""主异步入口"""
|
||||||
|
# 检查仪表板文件
|
||||||
|
webui_dir = await check_dashboard_files(webui_dir_arg)
|
||||||
|
if webui_dir is None:
|
||||||
|
logger.warning(
|
||||||
|
"管理面板文件检查失败,WebUI 功能将不可用。"
|
||||||
|
"请检查网络连接或手动指定 --webui-dir 参数。"
|
||||||
|
)
|
||||||
|
|
||||||
|
db = db_helper
|
||||||
|
|
||||||
|
# 打印 logo
|
||||||
|
logger.info(logo_tmpl)
|
||||||
|
|
||||||
|
core_lifecycle = InitialLoader(db, log_broker)
|
||||||
|
core_lifecycle.webui_dir = webui_dir
|
||||||
|
await core_lifecycle.start()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="AstrBot")
|
parser = argparse.ArgumentParser(description="AstrBot")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -117,14 +137,5 @@ if __name__ == "__main__":
|
|||||||
log_broker = LogBroker()
|
log_broker = LogBroker()
|
||||||
LogManager.set_queue_handler(logger, log_broker)
|
LogManager.set_queue_handler(logger, log_broker)
|
||||||
|
|
||||||
# 检查仪表板文件
|
# 只使用一次 asyncio.run()
|
||||||
webui_dir = asyncio.run(check_dashboard_files(args.webui_dir))
|
asyncio.run(main_async(args.webui_dir))
|
||||||
|
|
||||||
db = db_helper
|
|
||||||
|
|
||||||
# 打印 logo
|
|
||||||
logger.info(logo_tmpl)
|
|
||||||
|
|
||||||
core_lifecycle = InitialLoader(db, log_broker)
|
|
||||||
core_lifecycle.webui_dir = webui_dir
|
|
||||||
asyncio.run(core_lifecycle.start())
|
|
||||||
|
|||||||
+1
-1
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "AstrBot"
|
name = "AstrBot"
|
||||||
version = "4.19.2"
|
version = "4.19.3"
|
||||||
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
|
|||||||
@@ -0,0 +1,545 @@
|
|||||||
|
"""Tests for SessionLockManager with multi-event-loop isolation."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import weakref
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from astrbot.core.utils.session_lock import SessionLockManager
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionLockManagerBasic:
|
||||||
|
"""Basic functionality tests."""
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
"""Test manager initialization."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
assert manager._state_guard is not None
|
||||||
|
assert manager._loop_managers is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_acquire_release_lock(self):
|
||||||
|
"""Test basic lock acquire and release."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
session_id = "test-session"
|
||||||
|
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
# Lock acquired successfully
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Lock should be released and cleaned up
|
||||||
|
state = manager._get_loop_manager()
|
||||||
|
assert session_id not in state._locks
|
||||||
|
assert session_id not in state._lock_count
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lock_is_reusable(self):
|
||||||
|
"""Test that locks can be acquired multiple times."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
session_id = "test-session"
|
||||||
|
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Both acquisitions should succeed
|
||||||
|
|
||||||
|
|
||||||
|
class TestCrossLoopIsolation:
|
||||||
|
"""Tests for event loop isolation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_different_loops_have_different_managers(self):
|
||||||
|
"""Test that different event loops get different per-loop managers."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
|
||||||
|
# Get manager for current loop
|
||||||
|
manager1 = manager._get_loop_manager()
|
||||||
|
|
||||||
|
# Run in a different event loop
|
||||||
|
def run_in_new_loop():
|
||||||
|
new_loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
asyncio.set_event_loop(new_loop)
|
||||||
|
|
||||||
|
async def get_manager():
|
||||||
|
return manager._get_loop_manager()
|
||||||
|
|
||||||
|
return new_loop.run_until_complete(get_manager())
|
||||||
|
finally:
|
||||||
|
new_loop.close()
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||||
|
future = executor.submit(run_in_new_loop)
|
||||||
|
manager2 = future.result()
|
||||||
|
|
||||||
|
# Should be different manager instances
|
||||||
|
assert manager1 is not manager2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_locks_isolated_across_loops(self):
|
||||||
|
"""Test that locks from different loops are isolated."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
session_id = "shared-session"
|
||||||
|
results = []
|
||||||
|
|
||||||
|
async def acquire_in_loop(loop_id: int):
|
||||||
|
"""Acquire lock in a new event loop."""
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
results.append(f"loop-{loop_id}-acquired")
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
results.append(f"loop-{loop_id}-released")
|
||||||
|
|
||||||
|
def run_in_thread(loop_id: int):
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(acquire_in_loop(loop_id))
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
|
# Run two loops concurrently - they should NOT block each other
|
||||||
|
# because locks are isolated per-loop
|
||||||
|
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||||
|
futures = [executor.submit(run_in_thread, i) for i in range(2)]
|
||||||
|
for f in futures:
|
||||||
|
f.result()
|
||||||
|
|
||||||
|
# Both loops should acquire immediately (no blocking between loops)
|
||||||
|
# Order should show interleaved acquisitions, not sequential
|
||||||
|
assert len(results) == 4
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_same_loop_blocks_on_same_session(self):
|
||||||
|
"""Test that same loop blocks when acquiring same session lock."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
session_id = "test-session"
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
async def task1():
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
execution_order.append("task1-start")
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
execution_order.append("task1-end")
|
||||||
|
|
||||||
|
async def task2():
|
||||||
|
await asyncio.sleep(0.01) # Let task1 start first
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
execution_order.append("task2-start")
|
||||||
|
execution_order.append("task2-end")
|
||||||
|
|
||||||
|
await asyncio.gather(task1(), task2())
|
||||||
|
|
||||||
|
# task2 should wait for task1 to finish
|
||||||
|
assert execution_order.index("task1-start") < execution_order.index("task1-end")
|
||||||
|
assert execution_order.index("task1-end") < execution_order.index("task2-start")
|
||||||
|
|
||||||
|
|
||||||
|
class TestConcurrency:
|
||||||
|
"""Tests for concurrent access."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_acquisitions_same_loop(self):
|
||||||
|
"""Test concurrent lock acquisitions on the same loop."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
session_id = "concurrent-session"
|
||||||
|
acquired_count = 0
|
||||||
|
max_concurrent = 0
|
||||||
|
lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def acquire_and_check():
|
||||||
|
nonlocal acquired_count, max_concurrent
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
async with lock:
|
||||||
|
acquired_count += 1
|
||||||
|
max_concurrent = max(max_concurrent, acquired_count)
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
async with lock:
|
||||||
|
acquired_count -= 1
|
||||||
|
|
||||||
|
# Run multiple concurrent tasks
|
||||||
|
tasks = [acquire_and_check() for _ in range(5)]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# Max concurrent should be 1 (lock serializes access)
|
||||||
|
assert max_concurrent == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_thread_safety_of_loop_manager_creation(self):
|
||||||
|
"""Test that _get_loop_manager is thread-safe."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
managers = []
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
def create_loop_and_get_manager():
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
|
||||||
|
async def get_mgr():
|
||||||
|
return manager._get_loop_manager()
|
||||||
|
|
||||||
|
mgr = loop.run_until_complete(get_mgr())
|
||||||
|
managers.append(mgr)
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=create_loop_and_get_manager) for _ in range(10)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert len(errors) == 0
|
||||||
|
# All managers should be valid
|
||||||
|
for m in managers:
|
||||||
|
assert hasattr(m, "_locks")
|
||||||
|
assert hasattr(m, "_access_lock")
|
||||||
|
|
||||||
|
|
||||||
|
class TestEventLoopCleanup:
|
||||||
|
"""Tests for event loop cleanup behavior."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_weakref_cleanup_on_loop_close(self):
|
||||||
|
"""Test that per-loop managers are cleaned up when loop is closed."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
loop_ref: weakref.ref[asyncio.AbstractEventLoop] | None = None
|
||||||
|
|
||||||
|
def run_in_new_loop():
|
||||||
|
nonlocal loop_ref
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
loop_ref = weakref.ref(loop)
|
||||||
|
|
||||||
|
async def use_lock():
|
||||||
|
async with manager.acquire_lock("test-session"):
|
||||||
|
pass
|
||||||
|
return manager._get_loop_manager()
|
||||||
|
|
||||||
|
try:
|
||||||
|
per_loop_mgr = loop.run_until_complete(use_lock())
|
||||||
|
# Keep a weak ref to the per-loop manager
|
||||||
|
return weakref.ref(per_loop_mgr)
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||||
|
future = executor.submit(run_in_new_loop)
|
||||||
|
per_loop_mgr_ref = future.result()
|
||||||
|
|
||||||
|
# Give time for weakref cleanup
|
||||||
|
import gc
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# The per-loop manager should be cleaned up when the loop is closed
|
||||||
|
# because WeakKeyDictionary removes entries when the key (loop) is gone
|
||||||
|
per_loop_mgr = per_loop_mgr_ref()
|
||||||
|
loop = loop_ref() if loop_ref is not None else None
|
||||||
|
assert per_loop_mgr is None or loop is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_access_after_loop_close_in_new_loop_works(self):
|
||||||
|
"""Test that accessing from a new loop after old loop closes works."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
|
||||||
|
# Use lock in current loop
|
||||||
|
async with manager.acquire_lock("session-1"):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Simulate old loop being closed and new loop being created
|
||||||
|
def run_in_new_loop():
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
|
||||||
|
async def use_lock():
|
||||||
|
# Should work without issues in new loop
|
||||||
|
async with manager.acquire_lock("session-2"):
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
return loop.run_until_complete(use_lock())
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||||
|
future = executor.submit(run_in_new_loop)
|
||||||
|
result = future.result()
|
||||||
|
|
||||||
|
assert result == "success"
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssue5464:
|
||||||
|
"""Tests for issue #5464: Multiple OneBot instances with different event loops.
|
||||||
|
|
||||||
|
Issue: Running multiple OneBot adapter instances causes
|
||||||
|
"is bound to a different event loop" error.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_event_loops_no_cross_loop_error(self):
|
||||||
|
"""Test that multiple event loops don't cause cross-loop binding errors.
|
||||||
|
|
||||||
|
This simulates the scenario where multiple OneBot instances
|
||||||
|
(each potentially running in different event loops) access the
|
||||||
|
same SessionLockManager concurrently.
|
||||||
|
"""
|
||||||
|
from astrbot.core.utils.session_lock import session_lock_manager
|
||||||
|
|
||||||
|
errors: list[Exception] = []
|
||||||
|
results: list[str] = []
|
||||||
|
|
||||||
|
def simulate_onebot_instance(instance_id: int, session_ids: list[str]):
|
||||||
|
"""Simulate a OneBot instance running in its own event loop."""
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
|
||||||
|
async def process_messages():
|
||||||
|
for session_id in session_ids:
|
||||||
|
try:
|
||||||
|
async with session_lock_manager.acquire_lock(session_id):
|
||||||
|
# Simulate message processing
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
results.append(f"instance-{instance_id}-{session_id}")
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
|
||||||
|
loop.run_until_complete(process_messages())
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
|
# Simulate 4 OneBot instances (as in the issue report)
|
||||||
|
# Each handles multiple sessions concurrently
|
||||||
|
threads = []
|
||||||
|
for i in range(4):
|
||||||
|
sessions = [f"session-{i}-1", f"session-{i}-2", f"session-{i}-3"]
|
||||||
|
t = threading.Thread(target=simulate_onebot_instance, args=(i, sessions))
|
||||||
|
threads.append(t)
|
||||||
|
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
# Should have no errors (especially no "bound to a different event loop")
|
||||||
|
assert len(errors) == 0, f"Errors occurred: {errors}"
|
||||||
|
assert len(results) == 12 # 4 instances * 3 sessions each
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lock_object_not_shared_across_loops(self):
|
||||||
|
"""Verify that asyncio.Lock objects are not shared across event loops.
|
||||||
|
|
||||||
|
The root cause of issue #5464 was that Lock objects created in one
|
||||||
|
event loop were being used in another, causing the error.
|
||||||
|
"""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
session_id = "shared-session-id"
|
||||||
|
lock_ids: set[int] = set()
|
||||||
|
lock_id_lock = threading.Lock()
|
||||||
|
|
||||||
|
def get_lock_in_new_loop():
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
|
||||||
|
async def acquire_and_capture():
|
||||||
|
# Get the per-loop manager
|
||||||
|
per_loop_mgr = manager._get_loop_manager()
|
||||||
|
# Capture the lock object id before acquiring
|
||||||
|
async with per_loop_mgr._access_lock:
|
||||||
|
lock = per_loop_mgr._locks[session_id]
|
||||||
|
with lock_id_lock:
|
||||||
|
lock_ids.add(id(lock))
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
loop.run_until_complete(acquire_and_capture())
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
|
# Run multiple loops concurrently
|
||||||
|
threads = [threading.Thread(target=get_lock_in_new_loop) for _ in range(5)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
# Each loop should have its own Lock object
|
||||||
|
# If locks were shared, we'd only have 1 lock_id
|
||||||
|
assert len(lock_ids) == 5, "Each event loop should have its own Lock object"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_access_same_session_different_loops(self):
|
||||||
|
"""Test that same session ID accessed from different loops doesn't block.
|
||||||
|
|
||||||
|
This verifies the fix: locks are isolated per event loop,
|
||||||
|
so different loops can acquire the "same" session lock concurrently.
|
||||||
|
"""
|
||||||
|
from astrbot.core.utils.session_lock import session_lock_manager
|
||||||
|
|
||||||
|
session_id = "global-session"
|
||||||
|
acquisition_times: list[float] = []
|
||||||
|
time_lock = threading.Lock()
|
||||||
|
|
||||||
|
def acquire_lock_in_loop(loop_id: int):
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
|
||||||
|
async def acquire():
|
||||||
|
import time
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
async with session_lock_manager.acquire_lock(session_id):
|
||||||
|
with time_lock:
|
||||||
|
acquisition_times.append(start)
|
||||||
|
await asyncio.sleep(0.1) # Hold the lock
|
||||||
|
|
||||||
|
loop.run_until_complete(acquire())
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
|
# Start 3 threads nearly simultaneously
|
||||||
|
threads = [threading.Thread(target=acquire_lock_in_loop, args=(i,)) for i in range(3)]
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
|
||||||
|
# If locks were NOT isolated, we'd need ~0.3s (3 * 0.1s serial)
|
||||||
|
# With isolation, all should complete in ~0.1s (parallel)
|
||||||
|
# Allow some overhead, but should be much less than 0.3s
|
||||||
|
assert total_time < 0.25, (
|
||||||
|
f"Locks should be isolated per loop, but took {total_time:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
|
||||||
|
"""Tests for edge cases."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_session_id(self):
|
||||||
|
"""Test with empty session ID."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
|
||||||
|
async with manager.acquire_lock(""):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Should work without issues
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_special_characters_in_session_id(self):
|
||||||
|
"""Test with special characters in session ID."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
session_id = "session-with-special-chars!@#$%^&*()"
|
||||||
|
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Should work without issues
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_very_long_session_id(self):
|
||||||
|
"""Test with very long session ID."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
session_id = "a" * 10000
|
||||||
|
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Should work without issues
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lock_not_held_after_context_exit(self):
|
||||||
|
"""Test that lock is released after context manager exit."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
session_id = "test-session"
|
||||||
|
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
state = manager._get_loop_manager()
|
||||||
|
# Lock should exist and have count 1
|
||||||
|
assert session_id in state._locks
|
||||||
|
assert state._lock_count[session_id] == 1
|
||||||
|
|
||||||
|
# After exit, lock should be cleaned up
|
||||||
|
state = manager._get_loop_manager()
|
||||||
|
assert session_id not in state._locks
|
||||||
|
assert session_id not in state._lock_count
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exception_during_lock(self):
|
||||||
|
"""Test that lock is released even if exception occurs."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
session_id = "test-session"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
raise ValueError("test error")
|
||||||
|
|
||||||
|
# Lock should still be released
|
||||||
|
state = manager._get_loop_manager()
|
||||||
|
assert session_id not in state._locks
|
||||||
|
assert session_id not in state._lock_count
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_nested_lock_different_sessions(self):
|
||||||
|
"""Test nested locks on different sessions."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
|
||||||
|
async with manager.acquire_lock("session-1"):
|
||||||
|
async with manager.acquire_lock("session-2"):
|
||||||
|
state = manager._get_loop_manager()
|
||||||
|
assert "session-1" in state._locks
|
||||||
|
assert "session-2" in state._locks
|
||||||
|
assert state._lock_count["session-1"] == 1
|
||||||
|
assert state._lock_count["session-2"] == 1
|
||||||
|
|
||||||
|
state = manager._get_loop_manager()
|
||||||
|
assert "session-1" not in state._locks
|
||||||
|
assert "session-2" not in state._locks
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reentrant_lock_same_session(self):
|
||||||
|
"""Test reentrant locking on same session (should block)."""
|
||||||
|
manager = SessionLockManager()
|
||||||
|
session_id = "test-session"
|
||||||
|
order = []
|
||||||
|
|
||||||
|
async def outer():
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
order.append("outer-acquired")
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
order.append("outer-done")
|
||||||
|
|
||||||
|
async def inner():
|
||||||
|
await asyncio.sleep(0.01) # Let outer acquire first
|
||||||
|
order.append("inner-attempt")
|
||||||
|
async with manager.acquire_lock(session_id):
|
||||||
|
order.append("inner-acquired")
|
||||||
|
order.append("inner-done")
|
||||||
|
|
||||||
|
await asyncio.gather(outer(), inner())
|
||||||
|
|
||||||
|
# Inner should wait for outer to complete
|
||||||
|
assert order.index("outer-acquired") < order.index("outer-done")
|
||||||
|
assert order.index("outer-done") < order.index("inner-acquired")
|
||||||
Reference in New Issue
Block a user