Compare commits

..

12 Commits

Author SHA1 Message Date
Soulter 447b4542d1 chore: bump version to 4.19.1 2026-03-05 01:38:54 +08:00
Soulter ead10b5643 refactor: remove runtime_bootstrap module and related initialization 2026-03-05 01:38:27 +08:00
Soulter 6beca2144c revert: #5729
This reverts commit a9c16febf4.
2026-03-05 01:34:07 +08:00
Soulter 2d27bfb6d0 revert: #5744
This reverts commit 3d1c3946f6.
2026-03-05 01:29:36 +08:00
エイカク 3d1c3946f6 feat(ci): add nightly prerelease release flow and updater support (#5744)
* feat: add nightly prerelease release flow and updater support

* feat(ci): auto-generate nightly release notes from latest stable tag

* fix(ci): correct nightly release notes heredoc YAML indentation

* fix(ci): align nightly notes heredoc terminator

* fix(ci): remove heredoc body indentation in nightly notes script

* fix: align nightly release metadata and prerelease rules

* fix: harden nightly release flow and updater release resolution

* fix: improve nightly branch resolution and updater logging

* fix: simplify updater target resolution and nightly release assets

* fix: avoid inputs lookup on non-dispatch release events

* fix: split nightly release fetch and simplify updater flow

* refactor: simplify updater target resolvers and nightly error checks

* fix: type release fetch errors and streamline updater resolution

* refactor: simplify updater target branching and release artifacts

* refactor: simplify release fetching and harden nightly git diagnostics

* fix: validate release payload shape before parsing

* refactor: harden prerelease handling and nightly constants

* refactor: derive archive urls and enrich fetch errors

* refactor: simplify update target resolution flow

* refactor: linearize update target resolution

* refactor: validate update target inputs and sync nightly tag source

* refactor: simplify updater mode resolution and prerelease tests

* refactor: simplify update target resolution flow

* fix: avoid package import when resolving nightly tag

* refactor: simplify updater resolution and centralize release constants

* fix: harden nightly release notes generation in workflow

* refactor: streamline update target resolution and errors

* refactor: simplify updater target resolution and nightly handling

* refactor: simplify updater errors and package release scripts

* refactor: centralize release api constants and loader

* fix(ci): resolve dispatch fallback tag from stable releases
2026-03-05 01:23:49 +09:00
Soulter cd434c5fed chore: bump version to 4.19.0 2026-03-04 23:39:52 +08:00
camera-2018 9683abeb19 feat(telegram): supports sendMessageDraft API (#5726)
* feat(telegram): 使用 sendMessageDraft API 实现私聊流式输出

- 新增 _send_message_draft 方法封装 Telegram Bot API sendMessageDraft
- 私聊流式输出使用 sendMessageDraft 推送草稿动画,群聊保留 edit_message_text 回退
- 使用独立异步发送循环 (_draft_sender_loop) 按固定间隔推送最新缓冲区内容,
  完全解耦 token 到达速度与 API 网络延迟
- 流式结束后发送真实消息保留最终内容(draft 是临时的)
- 使用模块级递增 draft_id 替代随机生成,确保 Telegram 端动画连续性

* fix(telegram): convert draft text to Markdown before sending message draft

* chore(telegram): telegram 适配器重构

- 提取公共方法
- 有新 token 到达时触发流式
- 生成结束后清除draft内容
- 默认draft发送md格式

* style(telegram): ruff format

* style(telegram): ruff check

---------

Co-authored-by: Soulter <905617992@qq.com>
2026-03-04 23:11:57 +08:00
エイカク ab96537308 [codex] fix mcp init timeout keyword mismatch (#5743)
* fix: use timeout_seconds for mcp init startup

* fix: support overridden mcp init timeout in startup

* fix: resolve mcp init timeout from env when unset

* fix: pass mcp init timeout through lifecycle chain
2026-03-04 21:20:07 +09:00
Soulter 78fa58714c fix: require node.js env when uv sync 2026-03-04 18:25:19 +08:00
エイカク 9afe5757be feat: optimize async io performance and benchmark coverage (#5737)
* docs: align deployment sections across multilingual readmes

* docs: normalize deployment punctuation and AUR guidance

* docs: fix french and russian deployment wording

* perf: optimize async io hot paths and extend benchmarks

* fix: address async io review feedback

* fix: address follow-up async io review comments

* fix: align base64 io error handling in message components

* fix: harden attachment export ids and tune io chunking

* fix: preserve best-effort attachment export and batch writes

* test: expand path conversion and helper coverage
2026-03-04 16:26:34 +09:00
エイカク bbc8c62d43 docs: align deployment sections across multilingual readmes (#5734)
* docs: align deployment sections across multilingual readmes

* docs: normalize deployment punctuation and AUR guidance

* docs: fix french and russian deployment wording
2026-03-04 14:41:36 +09:00
エイカク a9c16febf4 fix: 工程化收敛并移除 ASYNC230/ASYNC240 忽略 (#5729)
* test(skills): align sandbox cache tests with readonly behavior

* ci(release): enforce core quality gate before publish

* ci: enforce locked dependency installs in workflows

* security: remove curl-pipe-shell installs

* chore: align project python baseline to 3.12

* ci(dashboard): add explicit typecheck gate

* chore(pre-commit): align ruff hook version with project

* ci(codeql): add javascript-typescript analysis

* chore(ruff): defer py312 migration lint rules

* fix: resolve ruff violations without new ignores

* fix: resolve ASYNC230 and ASYNC240 without ignores

* fix(auth): replace utcnow with timezone-aware UTC now

* fix: avoid blocking file read in file_to_base64
2026-03-04 13:51:00 +09:00
66 changed files with 1141 additions and 2966 deletions
-12
View File
@@ -83,15 +83,6 @@ astrbot
> Requires [uv](https://docs.astral.sh/uv/) to be installed.
> [!NOTE]
> For macOS user: due to macOS security checks, the first run of the `astrbot` command may take longer (about 10-20s).
Update `astrbot`:
```bash
uv tool upgrade astrbot
```
### Docker Deployment
For users familiar with containers and looking for a more stable, production-ready deployment method, we recommend deploying AstrBot with Docker / Docker Compose.
@@ -225,15 +216,12 @@ pre-commit install
### QQ Groups
- Group 9: 1076659624 (New)
- Group 10: 1078079676 (New)
- Group 1: 322154837
- Group 3: 630166526
- Group 5: 822130018
- Group 6: 753075035
- Group 7: 743746109
- Group 8: 1030353265
- Developer Group: 975206796
### Discord Server
-9
View File
@@ -83,15 +83,6 @@ astrbot
> [uv](https://docs.astral.sh/uv/) doit être installé.
> [!NOTE]
> Pour les utilisateurs macOS : en raison des vérifications de sécurité de macOS, la première exécution de la commande `astrbot` peut prendre plus de temps (environ 10-20s).
Mettre à jour `astrbot` :
```bash
uv tool upgrade astrbot
```
### Déploiement Docker
Pour les utilisateurs familiers avec les conteneurs et qui souhaitent une méthode plus stable et adaptée à la production, nous recommandons de déployer AstrBot avec Docker / Docker Compose.
-9
View File
@@ -83,15 +83,6 @@ astrbot
> [uv](https://docs.astral.sh/uv/) のインストールが必要です。
> [!NOTE]
> macOS ユーザーの場合:macOS のセキュリティチェックにより、`astrbot` コマンドの初回実行に時間がかかる場合があります(約 10〜20 秒)。
`astrbot` の更新:
```bash
uv tool upgrade astrbot
```
### Docker デプロイ
コンテナ運用に慣れており、より安定した本番向けのデプロイ方法を求めるユーザーには、Docker / Docker Compose での AstrBot デプロイをおすすめします。
-9
View File
@@ -83,15 +83,6 @@ astrbot
> Требуется установленный [uv](https://docs.astral.sh/uv/).
> [!NOTE]
> Для пользователей macOS: из-за проверок безопасности macOS первый запуск команды `astrbot` может занять больше времени (около 10-20 секунд).
Обновить `astrbot`:
```bash
uv tool upgrade astrbot
```
### Развёртывание Docker
Для пользователей, знакомых с контейнерами и которым нужен более стабильный и подходящий для production способ, мы рекомендуем разворачивать AstrBot через Docker / Docker Compose.
-13
View File
@@ -83,15 +83,6 @@ astrbot
> 需要安裝 [uv](https://docs.astral.sh/uv/)。
> [!NOTE]
> 對於 macOS 使用者:由於 macOS 安全性檢查,首次執行 `astrbot` 指令可能需要較長時間(約 10-20 秒)。
更新 `astrbot`
```bash
uv tool upgrade astrbot
```
### Docker 部署
對於熟悉容器、希望獲得更穩定且更適合正式環境部署方式的使用者,我們推薦使用 Docker / Docker Compose 部署 AstrBot。
@@ -217,14 +208,10 @@ pre-commit install
### QQ 群組
- 9 群: 1076659624 (新)
- 10 群: 1078079676 (新)
- 1 群:322154837
- 3 群:630166526
- 5 群:822130018
- 6 群:753075035
- 7 群:743746109
- 8 群:1030353265
- 開發者群:975206796
### Discord 群組
-11
View File
@@ -83,15 +83,6 @@ astrbot
> 需要安装 [uv](https://docs.astral.sh/uv/)。
> [!NOTE]
> 对于 macOS 用户:由于 macOS 安全检查,首次运行 `astrbot` 命令可能需要较长时间(约 10-20 秒)。
更新 `astrbot`
```bash
uv tool upgrade astrbot
```
### Docker 部署
对于熟悉容器、希望获得更稳定且更适合生产环境部署方式的用户,我们推荐使用 Docker / Docker Compose 部署 AstrBot。
@@ -218,8 +209,6 @@ pre-commit install
### QQ 群组
- 9 群: 1076659624 (新)
- 10 群: 1078079676 (新)
- 1 群:322154837
- 3 群:630166526
- 5 群:822130018
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.19.3"
__version__ = "4.19.1"
@@ -302,7 +302,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
while True:
try:
item_type, item_data = await asyncio.get_running_loop().run_in_executor(
item_type, item_data = await asyncio.get_event_loop().run_in_executor(
None, response_queue.get, True, 1
)
except queue.Empty:
@@ -388,7 +388,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
# 发起请求
partial = functools.partial(Application.call, **payload)
response = await asyncio.get_running_loop().run_in_executor(None, partial)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
async for resp in self._handle_streaming_response(response, session_id):
yield resp
+2 -3
View File
@@ -121,12 +121,11 @@ class BayContainerManager:
async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None:
"""Block until Bay's ``/health`` endpoint returns 200."""
url = f"http://127.0.0.1:{self._host_port}/health"
loop = asyncio.get_running_loop()
deadline = loop.time() + timeout
deadline = asyncio.get_event_loop().time() + timeout
last_error: str = ""
async with aiohttp.ClientSession() as session:
while loop.time() < deadline:
while asyncio.get_event_loop().time() < deadline:
try:
async with session.get(
url, timeout=aiohttp.ClientTimeout(total=3)
+2 -46
View File
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.19.3"
VERSION = "4.19.1"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -343,16 +343,11 @@ CONFIG_METADATA_2 = {
"id": "wecom_ai_bot",
"type": "wecom_ai_bot",
"enable": True,
"wecom_ai_bot_connection_mode": "webhook",
"wecomaibot_init_respond_text": "",
"wecomaibot_friend_message_welcome_text": "",
"wecom_ai_bot_name": "",
"msg_push_webhook_url": "",
"only_use_webhook_url_to_send": False,
"long_connection_bot_id": "",
"long_connection_secret": "",
"long_connection_ws_url": "wss://openws.work.weixin.qq.com",
"long_connection_heartbeat_interval": 30,
"token": "",
"encoding_aes_key": "",
"unified_webhook_mode": True,
@@ -737,13 +732,6 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "请务必填写正确,否则无法使用一些指令。",
},
"wecom_ai_bot_connection_mode": {
"description": "企业微信智能机器人连接模式",
"type": "string",
"options": ["webhook", "long_connection"],
"labels": ["Webhook 回调", "长连接"],
"hint": "Webhook 回调模式需要配置 Token/EncodingAESKey。长连接模式需要配置 BotID/Secret。",
},
"wecomaibot_init_respond_text": {
"description": "企业微信智能机器人初始响应文本",
"type": "string",
@@ -764,38 +752,6 @@ CONFIG_METADATA_2 = {
"type": "bool",
"hint": "启用后,企业微信智能机器人的所有回复都改为通过消息推送 Webhook 发送。消息推送 Webhook 支持更多的消息类型(如图片、文件等)。",
},
"long_connection_bot_id": {
"description": "长连接 BotID",
"type": "string",
"hint": "企业微信智能机器人长连接模式凭证 BotID。",
"condition": {
"wecom_ai_bot_connection_mode": "long_connection",
},
},
"long_connection_secret": {
"description": "长连接 Secret",
"type": "string",
"hint": "企业微信智能机器人长连接模式凭证 Secret。",
"condition": {
"wecom_ai_bot_connection_mode": "long_connection",
},
},
"long_connection_ws_url": {
"description": "长连接 WebSocket 地址",
"type": "string",
"hint": "默认值为 wss://openws.work.weixin.qq.com,一般无需修改。",
"condition": {
"wecom_ai_bot_connection_mode": "long_connection",
},
},
"long_connection_heartbeat_interval": {
"description": "长连接心跳间隔",
"type": "int",
"hint": "长连接模式心跳间隔(秒),建议 30 秒。",
"condition": {
"wecom_ai_bot_connection_mode": "long_connection",
},
},
"lark_bot_name": {
"description": "飞书机器人的名字",
"type": "string",
@@ -1167,7 +1123,7 @@ CONFIG_METADATA_2 = {
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://openrouter.ai/api/v1",
"api_base": "https://openrouter.ai/v1",
"proxy": "",
"custom_headers": {},
},
+6 -2
View File
@@ -97,7 +97,11 @@ class AstrBotCoreLifecycle:
except Exception as e:
logger.error(f"Subagent orchestrator init failed: {e}", exc_info=True)
async def initialize(self) -> None:
async def initialize(
self,
*,
mcp_init_timeout: float | int | str | None = None,
) -> None:
"""初始化 AstrBot 核心生命周期管理类.
负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
@@ -201,7 +205,7 @@ class AstrBotCoreLifecycle:
await self.plugin_manager.reload()
# 根据配置实例化各个 Provider
await self.provider_manager.initialize()
await self.provider_manager.initialize(init_timeout=mcp_init_timeout)
await self.kb_manager.initialize()
+18 -44
View File
@@ -539,36 +539,13 @@ class Reply(BaseMessageComponent):
class Poke(BaseMessageComponent):
type: ComponentType = ComponentType.Poke
_type: str | int = "126"
id: int | str | None = 0
qq: int | str | None = 0 # deprecated: legacy field, kept for compatibility
type: str = ComponentType.Poke
id: int | None = 0
qq: int | None = 0
def __init__(self, poke_type: str | int | None = None, **_) -> None:
# Backward compatible with old signature: Poke(type="poke", ...)
legacy_type = _.pop("type", None)
if poke_type is None:
poke_type = legacy_type
if poke_type in (None, "", "poke", "Poke"):
poke_type = "126"
super().__init__(_type=str(poke_type), **_)
def target_id(self) -> str | None:
"""Return normalized target id, compatible with old `qq` field."""
for value in (self.id, self.qq):
if value is None:
continue
text = str(value).strip()
if text and text != "0":
return text
return None
def toDict(self):
target_id = self.target_id()
data = {"type": str(self._type or "126")}
if target_id:
data["id"] = target_id
return {"type": "poke", "data": data}
def __init__(self, type: str, **_) -> None:
type = f"Poke:{type}"
super().__init__(type=type, **_)
class Forward(BaseMessageComponent):
@@ -699,24 +676,21 @@ class File(BaseMessageComponent):
if self.url:
try:
# 检查是否有正在运行的 event loop
asyncio.get_running_loop()
logger.warning(
"不可以在异步上下文中同步等待下载! "
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
"请使用 await get_file() 代替直接获取 <File>.file 字段",
)
return ""
except RuntimeError:
# 没有运行中的 event loop,可以同步执行
try:
# 使用 asyncio.run 安全地创建和关闭事件循环
asyncio.run(self._download_file())
except Exception:
logger.exception("文件下载失败")
loop = asyncio.get_event_loop()
if loop.is_running():
logger.warning(
"不可以在异步上下文中同步等待下载! "
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
"请使用 await get_file() 代替直接获取 <File>.file 字段",
)
return ""
# 等待下载完成
loop.run_until_complete(self._download_file())
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
except Exception as e:
logger.error(f"文件下载失败: {e}")
return ""
+1 -1
View File
@@ -28,7 +28,7 @@ class RespondStage(Stage):
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
Comp.Image: lambda comp: bool(comp.file), # 图片
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
Comp.Poke: lambda comp: comp.target_id() is not None, # 戳一戳
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
Comp.Node: lambda comp: bool(comp.content), # 转发节点
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
Comp.File: lambda comp: bool(comp.file_ or comp.url),
@@ -5,7 +5,7 @@ import traceback
from collections.abc import AsyncGenerator
from astrbot.core import file_token_service, html_renderer, logger
from astrbot.core.message.components import At, Image, Node, Plain, Record, Reply
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -383,11 +383,8 @@ class ResultDecorateStage(Stage):
)
result.chain = [node]
# at 回复 / 引用回复仅适用于纯文本或图文消息
can_decorate = all(
isinstance(item, (Plain, Image)) for item in result.chain
)
if can_decorate:
has_plain = any(isinstance(item, Plain) for item in result.chain)
if has_plain:
# at 回复
if (
self.reply_with_mention
@@ -402,4 +399,5 @@ class ResultDecorateStage(Stage):
# 引用回复
if self.reply_with_quote:
result.chain.insert(0, Reply(id=event.message_obj.message_id))
if not any(isinstance(item, File) for item in result.chain):
result.chain.insert(0, Reply(id=event.message_obj.message_id))
@@ -191,7 +191,7 @@ class AiocqhttpAdapter(Platform):
if "sub_type" in event:
if event["sub_type"] == "poke" and "target_id" in event:
abm.message.append(Poke(id=str(event["target_id"])))
abm.message.append(Poke(qq=str(event["target_id"]), type="poke"))
return abm
@@ -11,7 +11,7 @@ from dingtalk_stream import AckMessage
from astrbot import logger
from astrbot.api.event import MessageChain
from astrbot.api.message_components import At, File, Image, Plain, Record, Video
from astrbot.api.message_components import At, Image, Plain, Record, Video
from astrbot.api.platform import (
AstrBotMessage,
MessageMember,
@@ -178,110 +178,29 @@ class DingtalkPlatformAdapter(Platform):
abm.session_id = abm.sender.user_id
message_type: str = cast(str, message.message_type)
robot_code = cast(str, message.robot_code or "")
raw_content = cast(dict, message.extensions.get("content") or {})
if not isinstance(raw_content, dict):
raw_content = {}
match message_type:
case "text":
abm.message_str = message.text.content.strip()
abm.message.append(Plain(abm.message_str))
case "picture":
if not robot_code:
logger.error("钉钉图片消息解析失败: 回调中缺少 robotCode")
await self._remember_sender_binding(message, abm)
return abm
image_content = cast(
dingtalk_stream.ImageContent | None,
message.image_content,
)
download_code = cast(
str, (image_content.download_code if image_content else "") or ""
)
if not download_code:
logger.warning("钉钉图片消息缺少 downloadCode,已跳过")
else:
f_path = await self.download_ding_file(
download_code,
robot_code,
"jpg",
)
if f_path:
abm.message.append(Image.fromFileSystem(f_path))
else:
logger.warning("钉钉图片消息下载失败,无法解析为图片")
case "richText":
rtc: dingtalk_stream.RichTextContent = cast(
dingtalk_stream.RichTextContent, message.rich_text_content
)
contents: list[dict] = cast(list[dict], rtc.rich_text_list)
plain_parts: list[str] = []
for content in contents:
plains = ""
if "text" in content:
plain_text = cast(str, content.get("text") or "")
if plain_text:
plain_parts.append(plain_text)
abm.message.append(Plain(plain_text))
plains += content["text"]
abm.message.append(Plain(plains))
elif "type" in content and content["type"] == "picture":
download_code = cast(str, content.get("downloadCode") or "")
if not download_code:
logger.warning(
"钉钉富文本图片消息缺少 downloadCode,已跳过"
)
continue
if not robot_code:
logger.error(
"钉钉富文本图片消息解析失败: 回调中缺少 robotCode"
)
continue
f_path = await self.download_ding_file(
download_code,
robot_code,
content["downloadCode"],
cast(str, message.robot_code),
"jpg",
)
if f_path:
abm.message.append(Image.fromFileSystem(f_path))
abm.message_str = "".join(plain_parts).strip()
case "audio" | "voice":
download_code = cast(str, raw_content.get("downloadCode") or "")
if not download_code:
logger.warning("钉钉语音消息缺少 downloadCode,已跳过")
elif not robot_code:
logger.error("钉钉语音消息解析失败: 回调中缺少 robotCode")
else:
voice_ext = cast(str, raw_content.get("fileExtension") or "")
if not voice_ext:
voice_ext = "amr"
voice_ext = voice_ext.lstrip(".")
f_path = await self.download_ding_file(
download_code,
robot_code,
voice_ext,
)
if f_path:
abm.message.append(Record.fromFileSystem(f_path))
case "file":
download_code = cast(str, raw_content.get("downloadCode") or "")
if not download_code:
logger.warning("钉钉文件消息缺少 downloadCode,已跳过")
elif not robot_code:
logger.error("钉钉文件消息解析失败: 回调中缺少 robotCode")
else:
file_name = cast(str, raw_content.get("fileName") or "")
file_ext = Path(file_name).suffix.lstrip(".") if file_name else ""
if not file_ext:
file_ext = cast(str, raw_content.get("fileExtension") or "")
if not file_ext:
file_ext = "file"
f_path = await self.download_ding_file(
download_code,
robot_code,
file_ext,
)
if f_path:
if not file_name:
file_name = Path(f_path).name
abm.message.append(File(name=file_name, file=f_path))
abm.message.append(Image.fromFileSystem(f_path))
case "audio":
pass
await self._remember_sender_binding(message, abm)
return abm # 别忘了返回转换后的消息对象
@@ -351,23 +270,13 @@ class DingtalkPlatformAdapter(Platform):
)
return ""
resp_data = await resp.json()
download_url = cast(
str,
(
resp_data.get("downloadUrl")
or resp_data.get("data", {}).get("downloadUrl")
or ""
),
)
if not download_url:
logger.error(f"下载钉钉文件失败: 未找到 downloadUrl, 响应: {resp_data}")
return ""
download_url = resp_data["data"]["downloadUrl"]
await download_file(download_url, str(f_path))
return str(f_path)
async def get_access_token(self) -> str:
try:
access_token = await asyncio.get_running_loop().run_in_executor(
access_token = await asyncio.get_event_loop().run_in_executor(
None,
self.client_.get_access_token,
)
@@ -632,28 +541,6 @@ class DingtalkPlatformAdapter(Platform):
self._safe_remove_file(cover_path)
if converted_video:
self._safe_remove_file(video_path)
elif isinstance(segment, File):
try:
file_path = await segment.get_file()
if not file_path:
logger.warning("钉钉文件发送失败: 无法解析文件路径")
continue
media_id = await self.upload_media(file_path, "file")
if not media_id:
continue
file_name = segment.name or Path(file_path).name
file_type = Path(file_name).suffix.lstrip(".")
await send_message(
msg_key="sampleFile",
msg_param={
"mediaId": media_id,
"fileName": file_name,
"fileType": file_type,
},
)
except Exception as e:
logger.warning(f"钉钉文件发送失败: {e}")
continue
async def send_message_chain_to_group(
self,
@@ -760,7 +647,7 @@ class DingtalkPlatformAdapter(Platform):
return
logger.error(f"钉钉机器人启动失败: {e}")
loop = asyncio.get_running_loop()
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, start_client, loop)
async def terminate(self) -> None:
@@ -80,7 +80,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
if isinstance(source, botpy.message.C2CMessage):
# 真流式传输
current_time = asyncio.get_running_loop().time()
current_time = asyncio.get_event_loop().time()
time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval:
@@ -90,7 +90,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
stream_payload["index"] += 1
stream_payload["id"] = ret["id"]
last_edit_time = asyncio.get_running_loop().time()
last_edit_time = asyncio.get_event_loop().time()
if isinstance(source, botpy.message.C2CMessage):
# 结束流式对话,并且传输 buffer 中剩余的消息
@@ -55,7 +55,7 @@ class QQOfficialWebhook:
max_async=1,
connect=bot_connect,
dispatch=self.client.ws_dispatch,
loop=asyncio.get_running_loop(),
loop=asyncio.get_event_loop(),
api=self.api,
)
@@ -626,7 +626,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
# 发送初始 typing 状态
await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = asyncio.get_running_loop().time()
last_chat_action_time = asyncio.get_event_loop().time()
def _append_text(t: str) -> None:
nonlocal delta
@@ -657,11 +657,11 @@ class TelegramPlatformEvent(AstrMessageEvent):
# 编辑或发送消息
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
current_time = asyncio.get_running_loop().time()
current_time = asyncio.get_event_loop().time()
time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval:
current_time = asyncio.get_running_loop().time()
current_time = asyncio.get_event_loop().time()
if current_time - last_chat_action_time >= chat_action_interval:
await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = current_time
@@ -674,9 +674,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
current_content = delta
except Exception as e:
logger.warning(f"编辑消息失败(streaming): {e!s}")
last_edit_time = asyncio.get_running_loop().time()
last_edit_time = asyncio.get_event_loop().time()
else:
current_time = asyncio.get_running_loop().time()
current_time = asyncio.get_event_loop().time()
if current_time - last_chat_action_time >= chat_action_interval:
await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = current_time
@@ -688,7 +688,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
message_id = msg.message_id
last_edit_time = asyncio.get_running_loop().time()
last_edit_time = asyncio.get_event_loop().time()
try:
if delta and current_content != delta:
@@ -200,7 +200,7 @@ class WecomPlatformAdapter(Platform):
return msg_list[-1]
return None
msg_new = await asyncio.get_running_loop().run_in_executor(
msg_new = await asyncio.get_event_loop().run_in_executor(
None,
get_latest_msg_item,
)
@@ -261,7 +261,7 @@ class WecomPlatformAdapter(Platform):
@override
async def run(self) -> None:
loop = asyncio.get_running_loop()
loop = asyncio.get_event_loop()
if self.kf_name:
try:
acc_list = (
@@ -339,7 +339,7 @@ class WecomPlatformAdapter(Platform):
abm.session_id = abm.sender.user_id
abm.raw_message = msg
elif isinstance(msg, VoiceMessage):
resp: Response = await asyncio.get_running_loop().run_in_executor(
resp: Response = await asyncio.get_event_loop().run_in_executor(
None,
self.client.media.download,
msg.media_id,
@@ -395,7 +395,7 @@ class WecomPlatformAdapter(Platform):
abm.message_str = text
elif msgtype == "image":
media_id = msg.get("image", {}).get("media_id", "")
resp: Response = await asyncio.get_running_loop().run_in_executor(
resp: Response = await asyncio.get_event_loop().run_in_executor(
None,
self.client.media.download,
media_id,
@@ -407,7 +407,7 @@ class WecomPlatformAdapter(Platform):
abm.message = [Image(file=path, url=path)]
elif msgtype == "voice":
media_id = msg.get("voice", {}).get("media_id", "")
resp: Response = await asyncio.get_running_loop().run_in_executor(
resp: Response = await asyncio.get_event_loop().run_in_executor(
None,
self.client.media.download,
media_id,
@@ -4,7 +4,6 @@
"""
import asyncio
import time
from collections.abc import Awaitable, Callable
from typing import Any
@@ -83,7 +82,7 @@ class WecomAIQueueMgr:
del self.pending_responses[session_id]
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
if mark_finished:
self.completed_streams[session_id] = time.monotonic()
self.completed_streams[session_id] = asyncio.get_event_loop().time()
logger.debug(f"[WecomAI] 标记流已结束: {session_id}")
def remove_queue(self, session_id: str):
@@ -136,7 +135,7 @@ class WecomAIQueueMgr:
"""
self.pending_responses[session_id] = {
"callback_params": callback_params,
"timestamp": time.monotonic(),
"timestamp": asyncio.get_event_loop().time(),
}
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
@@ -161,7 +160,7 @@ class WecomAIQueueMgr:
finished_at = self.completed_streams.get(session_id)
if finished_at is None:
return False
if time.monotonic() - finished_at > max_age_seconds:
if asyncio.get_event_loop().time() - finished_at > max_age_seconds:
self.completed_streams.pop(session_id, None)
return False
return True
@@ -173,7 +172,7 @@ class WecomAIQueueMgr:
max_age_seconds: 最大存活时间
"""
current_time = time.monotonic()
current_time = asyncio.get_event_loop().time()
expired_sessions = []
for session_id, response_data in self.pending_responses.items():
@@ -369,7 +369,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
if future:
logger.debug(f"duplicate message id checked: {msg.id}")
else:
future = asyncio.get_running_loop().create_future()
future = asyncio.get_event_loop().create_future()
self.wexin_event_workers[msg_id] = future
await self.convert_message(msg, future)
# I love shield so much!
@@ -461,7 +461,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
elif msg.type == "voice":
assert isinstance(msg, VoiceMessage)
resp: Response = await asyncio.get_running_loop().run_in_executor(
resp: Response = await asyncio.get_event_loop().run_in_executor(
None,
self.client.media.download,
msg.media_id,
+12 -4
View File
@@ -346,7 +346,10 @@ class FunctionToolManager:
logger.debug(f" 主机: {scheme}://{host}{port}")
async def init_mcp_clients(
self, raise_on_all_failed: bool = False
self,
raise_on_all_failed: bool = False,
*,
init_timeout: float | int | str | None = None,
) -> MCPInitSummary:
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
```
@@ -367,6 +370,7 @@ class FunctionToolManager:
```
Timeout behavior:
- 显式 `init_timeout` 参数优先用于测试或调用方覆盖
- 初始化超时使用环境变量 ASTRBOT_MCP_INIT_TIMEOUT 或默认值
- 动态启用超时使用 ASTRBOT_MCP_ENABLE_TIMEOUT独立于初始化超时
"""
@@ -383,8 +387,12 @@ class FunctionToolManager:
with open(mcp_json_file, encoding="utf-8") as f:
mcp_server_json_obj: dict[str, dict] = json.load(f)["mcpServers"]
init_timeout = self._init_timeout_default
timeout_display = f"{init_timeout:g}"
init_timeout_value = _resolve_timeout(
timeout=init_timeout,
env_name=MCP_INIT_TIMEOUT_ENV,
default=self._init_timeout_default,
)
timeout_display = f"{init_timeout_value:g}"
active_configs: list[tuple[str, dict, asyncio.Event]] = []
for name, cfg in mcp_server_json_obj.items():
@@ -403,7 +411,7 @@ class FunctionToolManager:
name=name,
cfg=cfg,
shutdown_event=shutdown_event,
timeout=init_timeout,
timeout_seconds=init_timeout_value,
),
name=f"mcp-init:{name}",
)
+7 -2
View File
@@ -269,7 +269,11 @@ class ProviderManager:
return provider
async def initialize(self) -> None:
async def initialize(
self,
*,
init_timeout: float | int | str | None = None,
) -> None:
# 逐个初始化提供商
for provider_config in self.providers_config:
try:
@@ -338,7 +342,8 @@ class ProviderManager:
"on",
}
mcp_init_summary = await self.llm_tools.init_mcp_clients(
raise_on_all_failed=strict_mcp_init
raise_on_all_failed=strict_mcp_init,
init_timeout=init_timeout,
)
if (
mcp_init_summary.total > 0
@@ -276,24 +276,9 @@ class ProviderAnthropic(Provider):
llm_response.id = completion.id
llm_response.usage = self._extract_usage(completion.usage)
# Handle cases where completion only contains ThinkingBlock (e.g., MiniMax max_tokens)
# When stop_reason='max_tokens', the model may return only thinking content
# This is valid and should not raise an exception
# TODO(Soulter): 处理 end_turn 情况
if not llm_response.completion_text and not llm_response.tools_call_args:
# Guard clause: raise early if no valid content at all
if not llm_response.reasoning_content:
raise ValueError(
f"Anthropic API returned unparsable completion: "
f"no text, tool_use, or thinking content found. "
f"Completion: {completion}"
)
# We have reasoning content (ThinkingBlock) - this is valid
stop_reason = getattr(completion, "stop_reason", "unknown")
logger.debug(
f"Completion contains only ThinkingBlock (stop_reason={stop_reason})"
)
llm_response.completion_text = "" # Ensure empty string, not None
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}")
return llm_response
@@ -87,7 +87,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
model: str,
text: str,
) -> tuple[bytes | None, str]:
loop = asyncio.get_running_loop()
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
audio_bytes = await self._extract_audio_from_response(response)
if not audio_bytes:
@@ -143,7 +143,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
voice=self.voice,
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
)
loop = asyncio.get_running_loop()
loop = asyncio.get_event_loop()
audio_bytes = await loop.run_in_executor(
None,
synthesizer.call,
+2 -2
View File
@@ -59,7 +59,7 @@ class GenieTTSProvider(TTSProvider):
filename = f"genie_tts_{uuid.uuid4()}.wav"
path = os.path.join(temp_dir, filename)
loop = asyncio.get_running_loop()
loop = asyncio.get_event_loop()
def _generate(save_path: str) -> None:
assert genie is not None
@@ -85,7 +85,7 @@ class GenieTTSProvider(TTSProvider):
text_queue: asyncio.Queue[str | None],
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
) -> None:
loop = asyncio.get_running_loop()
loop = asyncio.get_event_loop()
while True:
text = await text_queue.get()
@@ -43,7 +43,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...")
# 将模型加载放到线程池中执行
self.model = await asyncio.get_running_loop().run_in_executor(
self.model = await asyncio.get_event_loop().run_in_executor(
None,
lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16),
)
@@ -88,7 +88,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
audio_url = output_path
# 使用 run_in_executor 来调用模型进行识别
loop = asyncio.get_running_loop()
loop = asyncio.get_event_loop()
res = await loop.run_in_executor(
None, # 使用默认的线程池
lambda: cast(SenseVoiceSmall, self.model)(
@@ -31,7 +31,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
self.model = None
async def initialize(self) -> None:
loop = asyncio.get_running_loop()
loop = asyncio.get_event_loop()
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
self.model = await loop.run_in_executor(
None,
@@ -50,7 +50,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
return False
async def get_text(self, audio_url: str) -> str:
loop = asyncio.get_running_loop()
loop = asyncio.get_event_loop()
is_tencent = False
+2 -17
View File
@@ -26,13 +26,6 @@ _SANDBOX_SKILLS_CACHE_VERSION = 1
_SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$")
def _is_ignored_zip_entry(name: str) -> bool:
parts = PurePosixPath(name).parts
if not parts:
return True
return parts[0] == "__MACOSX"
@dataclass
class SkillInfo:
name: str
@@ -408,11 +401,7 @@ class SkillManager:
raise ValueError("Uploaded file is not a valid zip archive.")
with zipfile.ZipFile(zip_path) as zf:
names = [
name
for name in (entry.replace("\\", "/") for entry in zf.namelist())
if name and not _is_ignored_zip_entry(name)
]
names = [name.replace("\\", "/") for name in zf.namelist()]
file_names = [name for name in names if name and not name.endswith("/")]
if not file_names:
raise ValueError("Zip archive is empty.")
@@ -447,11 +436,7 @@ class SkillManager:
raise ValueError("SKILL.md not found in the skill folder.")
with tempfile.TemporaryDirectory(dir=get_astrbot_temp_path()) as tmp_dir:
for member in zf.infolist():
member_name = member.filename.replace("\\", "/")
if not member_name or _is_ignored_zip_entry(member_name):
continue
zf.extract(member, tmp_dir)
zf.extractall(tmp_dir)
src_dir = Path(tmp_dir) / skill_name
if not src_dir.exists():
raise ValueError("Skill folder not found after extraction.")
+1 -1
View File
@@ -15,4 +15,4 @@ class RegexFilter(HandlerFilter):
self.regex = re.compile(regex)
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
return bool(self.regex.search(event.get_message_str().strip()))
return bool(self.regex.match(event.get_message_str().strip()))
+1 -14
View File
@@ -1374,23 +1374,10 @@ class PluginManager:
return
if "__del__" in star_metadata.star_cls_type.__dict__:
loop = asyncio.get_running_loop()
future = loop.run_in_executor(
asyncio.get_event_loop().run_in_executor(
None,
star_metadata.star_cls.__del__,
)
def _log_del_exception(fut: asyncio.Future) -> None:
if fut.cancelled():
return
if (exc := fut.exception()) is not None:
logger.error(
"插件 %s 在 __del__ 中抛出了异常:%r",
star_metadata.name,
exc,
)
future.add_done_callback(_log_del_exception)
elif "terminate" in star_metadata.star_cls_type.__dict__:
await star_metadata.star_cls.terminate()
+1 -1
View File
@@ -7,4 +7,4 @@ def is_frozen_runtime() -> bool:
def is_packaged_desktop_runtime() -> bool:
return os.environ.get("ASTRBOT_DESKTOP_CLIENT") == "1"
return is_frozen_runtime() and os.environ.get("ASTRBOT_DESKTOP_CLIENT") == "1"
+1 -27
View File
@@ -1,13 +1,9 @@
import asyncio
import threading
import weakref
from collections import defaultdict
from contextlib import asynccontextmanager
class _PerLoopSessionLockManager:
"""Per-event-loop session lock manager; keeps original simple semantics."""
class SessionLockManager:
def __init__(self) -> None:
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self._lock_count: dict[str, int] = defaultdict(int)
@@ -30,26 +26,4 @@ class _PerLoopSessionLockManager:
self._lock_count.pop(session_id, None)
class SessionLockManager:
"""Thread-safe session lock manager with per-event-loop isolation."""
def __init__(self) -> None:
self._state_guard = threading.Lock()
self._loop_managers: weakref.WeakKeyDictionary[
asyncio.AbstractEventLoop, _PerLoopSessionLockManager
] = weakref.WeakKeyDictionary()
def _get_loop_manager(self) -> _PerLoopSessionLockManager:
"""Get the lock manager for the current event loop."""
loop = asyncio.get_running_loop()
with self._state_guard:
return self._loop_managers.setdefault(loop, _PerLoopSessionLockManager())
@asynccontextmanager
async def acquire_lock(self, session_id: str):
manager = self._get_loop_manager()
async with manager.acquire_lock(session_id):
yield
session_lock_manager = SessionLockManager()
-2
View File
@@ -610,7 +610,6 @@ class ConfigRoute(Route):
try:
conf_id = self.acm.create_conf(name=name, config=config)
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)
return Response().ok(message="创建成功", data={"conf_id": conf_id}).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
@@ -650,7 +649,6 @@ class ConfigRoute(Route):
try:
success = self.acm.delete_conf(conf_id)
if success:
self.core_lifecycle.pipeline_scheduler_mapping.pop(conf_id, None)
return Response().ok(message="删除成功").__dict__
return Response().error("删除失败").__dict__
except ValueError as e:
-110
View File
@@ -2,7 +2,6 @@ import os
import re
import shutil
import traceback
import uuid
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import Any
@@ -51,7 +50,6 @@ class SkillsRoute(Route):
self.routes = {
"/skills": ("GET", self.get_skills),
"/skills/upload": ("POST", self.upload_skill),
"/skills/batch-upload": ("POST", self.batch_upload_skills),
"/skills/download": ("GET", self.download_skill),
"/skills/update": ("POST", self.update_skill),
"/skills/delete": ("POST", self.delete_skill),
@@ -190,114 +188,6 @@ class SkillsRoute(Route):
except Exception:
logger.warning(f"Failed to remove temp skill file: {temp_path}")
async def batch_upload_skills(self):
"""批量上传多个 skill ZIP 文件"""
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
try:
files = await request.files
file_list = files.getlist("files")
if not file_list:
return Response().error("No files provided").__dict__
succeeded = []
failed = []
skill_mgr = SkillManager()
temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
for file in file_list:
filename = os.path.basename(file.filename or "unknown.zip")
temp_path = None
try:
if not filename.lower().endswith(".zip"):
failed.append(
{
"filename": filename,
"error": "Only .zip files are supported",
}
)
continue
temp_path = os.path.join(
temp_dir, f"batch_{uuid.uuid4().hex}_{filename}"
)
await file.save(temp_path)
skill_name = skill_mgr.install_skill_from_zip(
temp_path, overwrite=True
)
succeeded.append({"filename": filename, "name": skill_name})
except Exception as e:
failed.append({"filename": filename, "error": str(e)})
finally:
if temp_path and os.path.exists(temp_path):
try:
os.remove(temp_path)
except Exception:
pass
if succeeded:
try:
await sync_skills_to_active_sandboxes()
except Exception:
logger.warning(
"Failed to sync uploaded skills to active sandboxes."
)
total = len(file_list)
success_count = len(succeeded)
if success_count == total:
message = f"All {total} skill(s) uploaded successfully."
return (
Response()
.ok(
{
"total": total,
"succeeded": succeeded,
"failed": failed,
},
message,
)
.__dict__
)
if success_count == 0:
message = f"Upload failed for all {total} file(s)."
resp = Response().error(message)
resp.data = {
"total": total,
"succeeded": succeeded,
"failed": failed,
}
return resp.__dict__
message = f"Partial success: {success_count}/{total} skill(s) uploaded."
return (
Response()
.ok(
{
"total": total,
"succeeded": succeeded,
"failed": failed,
},
message,
)
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def download_skill(self):
try:
name = str(request.args.get("name") or "").strip()
+68 -185
View File
@@ -12,32 +12,6 @@ from .route import Response, Route, RouteContext
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
class EmptyMcpServersError(ValueError):
"""Raised when mcpServers is empty."""
pass
def _extract_mcp_server_config(mcp_servers_value: object) -> dict:
"""Extract server configuration from user-submitted mcpServers field.
Raises:
ValueError: Invalid configuration
"""
if not isinstance(mcp_servers_value, dict):
raise ValueError("mcpServers must be a JSON object")
if not mcp_servers_value:
raise EmptyMcpServersError("mcpServers configuration cannot be empty")
key_0 = next(iter(mcp_servers_value))
extracted = mcp_servers_value[key_0]
if not isinstance(extracted, dict):
raise ValueError(
"Invalid mcpServers format. Ensure each key in mcpServers is a server name, "
"and each value is an object containing fields like command/url."
)
return extracted
class ToolsRoute(Route):
def __init__(
self,
@@ -59,37 +33,13 @@ class ToolsRoute(Route):
self.register_routes()
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
def _rollback_mcp_server(self, name: str) -> bool:
try:
rollback_config = self.tool_mgr.load_mcp_config()
if name in rollback_config["mcpServers"]:
rollback_config["mcpServers"].pop(name)
return self.tool_mgr.save_mcp_config(rollback_config)
return True
except Exception:
logger.error(traceback.format_exc())
return False
async def get_mcp_servers(self):
try:
config = self.tool_mgr.load_mcp_config()
servers = []
mcp_servers = config.get("mcpServers", {})
if not isinstance(mcp_servers, dict):
logger.warning(
f"Invalid MCP server config type: {type(mcp_servers).__name__}. Expected object/dict; skipped all MCP servers."
)
mcp_servers = {}
# 获取所有服务器并添加它们的工具列表
for name, server_config in mcp_servers.items():
if not isinstance(server_config, dict):
logger.warning(
f"Invalid config for MCP server '{name}' (type: {type(server_config).__name__}); skipped."
)
continue
for name, server_config in config["mcpServers"].items():
server_info = {
"name": name,
"active": server_config.get("active", True),
@@ -115,7 +65,7 @@ class ToolsRoute(Route):
return Response().ok(servers).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Failed to get MCP server list: {e!s}").__dict__
return Response().error(f"获取 MCP 服务器列表失败: {e!s}").__dict__
async def add_mcp_server(self):
try:
@@ -125,7 +75,7 @@ class ToolsRoute(Route):
# 检查必填字段
if not name:
return Response().error("Server name cannot be empty").__dict__
return Response().error("服务器名称不能为空").__dict__
# 移除特殊字段并检查配置是否有效
has_valid_config = False
@@ -135,33 +85,21 @@ class ToolsRoute(Route):
for key, value in server_data.items():
if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段
if key == "mcpServers":
try:
server_config = _extract_mcp_server_config(
server_data["mcpServers"]
)
except ValueError as e:
return Response().error(f"{e!s}").__dict__
key_0 = list(server_data["mcpServers"].keys())[
0
] # 不考虑为空的情况
server_config = server_data["mcpServers"][key_0]
else:
server_config[key] = value
has_valid_config = True
if not has_valid_config:
return (
Response()
.error("A valid server configuration is required")
.__dict__
)
return Response().error("必须提供有效的服务器配置").__dict__
config = self.tool_mgr.load_mcp_config()
if name in config["mcpServers"]:
return Response().error(f"Server {name} already exists").__dict__
try:
await self.tool_mgr.test_mcp_server_connection(server_config)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"MCP connection test failed: {e!s}").__dict__
return Response().error(f"服务器 {name} 已存在").__dict__
config["mcpServers"][name] = server_config
@@ -173,27 +111,17 @@ class ToolsRoute(Route):
timeout=30,
)
except TimeoutError:
rollback_ok = self._rollback_mcp_server(name)
err_msg = f"Timed out while enabling MCP server {name}."
if not rollback_ok:
err_msg += " Configuration rollback failed. Please check the config manually."
return Response().error(err_msg).__dict__
return Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
rollback_ok = self._rollback_mcp_server(name)
err_msg = f"Failed to enable MCP server {name}: {e!s}"
if not rollback_ok:
err_msg += " Configuration rollback failed. Please check the config manually."
return Response().error(err_msg).__dict__
return (
Response()
.ok(None, f"Successfully added MCP server {name}")
.__dict__
)
return Response().error("Failed to save configuration").__dict__
return (
Response().error(f"启用 MCP 服务器 {name} 失败: {e!s}").__dict__
)
return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__
return Response().error("保存配置失败").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Failed to add MCP server: {e!s}").__dict__
return Response().error(f"添加 MCP 服务器失败: {e!s}").__dict__
async def update_mcp_server(self):
try:
@@ -203,25 +131,23 @@ class ToolsRoute(Route):
old_name = server_data.get("oldName") or name
if not name:
return Response().error("Server name cannot be empty").__dict__
return Response().error("服务器名称不能为空").__dict__
config = self.tool_mgr.load_mcp_config()
if old_name not in config["mcpServers"]:
return Response().error(f"Server {old_name} does not exist").__dict__
return Response().error(f"服务器 {old_name} 不存在").__dict__
is_rename = name != old_name
if name in config["mcpServers"] and is_rename:
return Response().error(f"Server {name} already exists").__dict__
return Response().error(f"服务器 {name} 已存在").__dict__
# 获取活动状态
old_config = config["mcpServers"][old_name]
if isinstance(old_config, dict):
old_active = old_config.get("active", True)
else:
old_active = True
active = server_data.get("active", old_active)
active = server_data.get(
"active",
config["mcpServers"][old_name].get("active", True),
)
# 创建新的配置对象
server_config = {"active": active}
@@ -239,19 +165,17 @@ class ToolsRoute(Route):
"oldName",
]: # 排除特殊字段
if key == "mcpServers":
try:
server_config = _extract_mcp_server_config(
server_data["mcpServers"]
)
except ValueError as e:
return Response().error(f"{e!s}").__dict__
key_0 = list(server_data["mcpServers"].keys())[
0
] # 不考虑为空的情况
server_config = server_data["mcpServers"][key_0]
else:
server_config[key] = value
only_update_active = False
# 如果只更新活动状态,保留原始配置
if only_update_active and isinstance(old_config, dict):
for key, value in old_config.items():
if only_update_active:
for key, value in config["mcpServers"][old_name].items():
if key != "active": # 除了active之外的所有字段都保留
server_config[key] = value
@@ -276,7 +200,7 @@ class ToolsRoute(Route):
return (
Response()
.error(
f"Timed out while disabling MCP server {old_name} before enabling: {e!s}"
f"启用前停用 MCP 服务器时 {old_name} 超时: {e!s}"
)
.__dict__
)
@@ -285,7 +209,7 @@ class ToolsRoute(Route):
return (
Response()
.error(
f"Failed to disable MCP server {old_name} before enabling: {e!s}"
f"启用前停用 MCP 服务器时 {old_name} 失败: {e!s}"
)
.__dict__
)
@@ -297,15 +221,13 @@ class ToolsRoute(Route):
)
except TimeoutError:
return (
Response()
.error(f"Timed out while enabling MCP server {name}.")
.__dict__
Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return (
Response()
.error(f"Failed to enable MCP server {name}: {e!s}")
.error(f"启用 MCP 服务器 {name} 失败: {e!s}")
.__dict__
)
# 如果要停用服务器
@@ -315,26 +237,22 @@ class ToolsRoute(Route):
except TimeoutError:
return (
Response()
.error(f"Timed out while disabling MCP server {old_name}.")
.error(f"停用 MCP 服务器 {old_name} 超时。")
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return (
Response()
.error(f"Failed to disable MCP server {old_name}: {e!s}")
.error(f"停用 MCP 服务器 {old_name} 失败: {e!s}")
.__dict__
)
return (
Response()
.ok(None, f"Successfully updated MCP server {name}")
.__dict__
)
return Response().error("Failed to save configuration").__dict__
return Response().ok(None, f"成功更新 MCP 服务器 {name}").__dict__
return Response().error("保存配置失败").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Failed to update MCP server: {e!s}").__dict__
return Response().error(f"更新 MCP 服务器失败: {e!s}").__dict__
async def delete_mcp_server(self):
try:
@@ -342,12 +260,12 @@ class ToolsRoute(Route):
name = server_data.get("name", "")
if not name:
return Response().error("Server name cannot be empty").__dict__
return Response().error("服务器名称不能为空").__dict__
config = self.tool_mgr.load_mcp_config()
if name not in config["mcpServers"]:
return Response().error(f"Server {name} does not exist").__dict__
return Response().error(f"服务器 {name} 不存在").__dict__
del config["mcpServers"][name]
@@ -357,76 +275,51 @@ class ToolsRoute(Route):
await self.tool_mgr.disable_mcp_server(name, timeout=10)
except TimeoutError:
return (
Response()
.error(f"Timed out while disabling MCP server {name}.")
.__dict__
Response().error(f"停用 MCP 服务器 {name} 超时。").__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return (
Response()
.error(f"Failed to disable MCP server {name}: {e!s}")
.error(f"停用 MCP 服务器 {name} 失败: {e!s}")
.__dict__
)
return (
Response()
.ok(None, f"Successfully deleted MCP server {name}")
.__dict__
)
return Response().error("Failed to save configuration").__dict__
return Response().ok(None, f"成功删除 MCP 服务器 {name}").__dict__
return Response().error("保存配置失败").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Failed to delete MCP server: {e!s}").__dict__
return Response().error(f"删除 MCP 服务器失败: {e!s}").__dict__
async def test_mcp_connection(self):
"""Test MCP server connection."""
"""测试 MCP 服务器连接"""
try:
server_data = await request.json
config = server_data.get("mcp_server_config", None)
if not isinstance(config, dict) or not config:
return Response().error("Invalid MCP server configuration").__dict__
return Response().error("无效的 MCP 服务器配置").__dict__
if "mcpServers" in config:
mcp_servers = config["mcpServers"]
if isinstance(mcp_servers, dict) and len(mcp_servers) > 1:
return (
Response()
.error(
"Only one MCP server configuration can be tested at a time"
)
.__dict__
)
try:
config = _extract_mcp_server_config(mcp_servers)
except EmptyMcpServersError:
return (
Response()
.error("MCP server configuration cannot be empty")
.__dict__
)
except ValueError as e:
return Response().error(f"{e!s}").__dict__
keys = list(config["mcpServers"].keys())
if not keys:
return Response().error("MCP 服务器配置不能为空").__dict__
if len(keys) > 1:
return Response().error("一次只能配置一个 MCP 服务器配置").__dict__
config = config["mcpServers"][keys[0]]
elif not config:
return (
Response()
.error("MCP server configuration cannot be empty")
.__dict__
)
return Response().error("MCP 服务器配置不能为空").__dict__
tools_name = await self.tool_mgr.test_mcp_server_connection(config)
return (
Response()
.ok(data=tools_name, message="🎉 MCP server is available!")
.__dict__
Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Failed to test MCP connection: {e!s}").__dict__
return Response().error(f"测试 MCP 连接失败: {e!s}").__dict__
async def get_tool_list(self):
"""Get all registered tools."""
"""获取所有注册的工具列表"""
try:
tools = self.tool_mgr.func_list
tools_dict = []
@@ -456,44 +349,36 @@ class ToolsRoute(Route):
return Response().ok(data=tools_dict).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Failed to get tool list: {e!s}").__dict__
return Response().error(f"获取工具列表失败: {e!s}").__dict__
async def toggle_tool(self):
"""Activate or deactivate a specified tool."""
"""启用或停用指定的工具"""
try:
data = await request.json
tool_name = data.get("name")
action = data.get("activate") # True or False
if not tool_name or action is None:
return (
Response()
.error("Missing required parameters: name or activate")
.__dict__
)
return Response().error("缺少必要参数: name 或 action").__dict__
if action:
try:
ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map)
except ValueError as e:
return Response().error(f"Failed to activate tool: {e!s}").__dict__
return Response().error(f"启用工具失败: {e!s}").__dict__
else:
ok = self.tool_mgr.deactivate_llm_tool(tool_name)
if ok:
return Response().ok(None, "Operation successful.").__dict__
return (
Response()
.error(f"Tool {tool_name} does not exist or the operation failed.")
.__dict__
)
return Response().ok(None, "操作成功。").__dict__
return Response().error(f"工具 {tool_name} 不存在或操作失败。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Failed to operate tool: {e!s}").__dict__
return Response().error(f"操作工具失败: {e!s}").__dict__
async def sync_provider(self):
"""Sync MCP provider configuration."""
"""同步 MCP 提供者配置"""
try:
data = await request.json
provider_name = data.get("name") # modelscope, or others
@@ -502,11 +387,9 @@ class ToolsRoute(Route):
access_token = data.get("access_token", "")
await self.tool_mgr.sync_modelscope_mcp_servers(access_token)
case _:
return (
Response().error(f"Unknown provider: {provider_name}").__dict__
)
return Response().error(f"未知: {provider_name}").__dict__
return Response().ok(message="Sync completed").__dict__
return Response().ok(message="同步成功").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Sync failed: {e!s}").__dict__
return Response().error(f"同步失败: {e!s}").__dict__
@@ -3,10 +3,10 @@
### 新增
- 集成 KOOK 平台适配器 ([#5658](https://github.com/AstrBotDevs/AstrBot/pull/5658))。
- 集成 DeerFlow Agent Runner 并优化流式处理 ([#5581](https://github.com/AstrBotDevs/AstrBot/pull/5581))。
- 新增 Discord pre-react Emoji 支持 ([#5609](https://github.com/AstrBotDevs/AstrBot/pull/5609))。
- 新增 Telegram 支持 `sendMessageDraft` 流式实时输出 API ([#5726](https://github.com/AstrBotDevs/AstrBot/issues/5726))
- 支持在 Agent 运行时进行消息跟进能力,跟进的消息实时注入给 Agent ([#5484](https://github.com/AstrBotDevs/AstrBot/pull/5484))。
- 集成 DeerFlow Agent Runner 并优化流式处理 ([#5581](https://github.com/AstrBotDevs/AstrBot/pull/5581))。
- 新增 shell, ipython tool 中包含操作系统信息,提高 windows 下 tool call 成功率 ([#5677](https://github.com/AstrBotDevs/AstrBot/pull/5677))。
- Sandbox 支持 Shipyard-neo - 支持 Skills 自迭代 ([#5028](https://github.com/AstrBotDevs/AstrBot/pull/5028))。
- 新增 ChatUI WebSocket 传输模式选择,OpenAPI Chat API 支持 WebSocket 连接 ([#5410](https://github.com/AstrBotDevs/AstrBot/pull/5410))。
-40
View File
@@ -1,40 +0,0 @@
## What's Changed
### 新增
- 新增技能 ZIP 批量上传能力 ([#5804](https://github.com/AstrBotDevs/AstrBot/pull/5804))。
### 修复
- 修复 MCP Server 配置异常时可能导致崩溃的问题 ([#5666](https://github.com/AstrBotDevs/AstrBot/pull/5666), [#5673](https://github.com/AstrBotDevs/AstrBot/pull/5673))。
- 修复钉钉适配器文本消息被忽略、无法主动发送文件的问题 ([#5921](https://github.com/AstrBotDevs/AstrBot/pull/5921))。
- 修复钉钉适配器无法接收图片与文件的问题 ([#5920](https://github.com/AstrBotDevs/AstrBot/pull/5920))。
- fix(provider): handle MiniMax ThinkingBlock when max_tokens reached ([#5913](https://github.com/AstrBotDevs/AstrBot/pull/5913))。
- 修复 OpenRouter `api_base` 配置错误的问题 ([#5911](https://github.com/AstrBotDevs/AstrBot/pull/5911))。
- 修复插件市场中按展示名搜索已安装插件不生效的问题 ([#5806](https://github.com/AstrBotDevs/AstrBot/pull/5806), [#5811](https://github.com/AstrBotDevs/AstrBot/pull/5811))。
- 修复仅图片响应未应用 `reply_with_quote``reply_with_mention` 的问题 ([#5219](https://github.com/AstrBotDevs/AstrBot/pull/5219))。
- 修复 `RegexFilter` 使用 `re.match` 导致匹配范围不正确的问题 ([#5368](https://github.com/AstrBotDevs/AstrBot/pull/5368))。
- 修复桌面运行环境检测依赖 frozen Python 的问题 ([#5859](https://github.com/AstrBotDevs/AstrBot/pull/5859))。
- 修复通过“创建新配置”创建平台机器人后找不到 pipeline scheduler 的问题 ([#5776](https://github.com/AstrBotDevs/AstrBot/pull/5776))。
---
## What's Changed (EN)
### New Features
- Added batch upload support for multiple skill ZIP files ([#5804](https://github.com/AstrBotDevs/AstrBot/pull/5804)).
### Bug Fixes
- Fixed potential crash on malformed MCP server config ([#5666](https://github.com/AstrBotDevs/AstrBot/pull/5666), [#5673](https://github.com/AstrBotDevs/AstrBot/pull/5673)).
- Fixed DingTalk adapter issue where text messages were ignored and files could not be sent proactively ([#5921](https://github.com/AstrBotDevs/AstrBot/pull/5921)).
- Fixed DingTalk adapter issue where image and file messages could not be received ([#5920](https://github.com/AstrBotDevs/AstrBot/pull/5920)).
- Fixed incorrect OpenRouter `api_base` configuration ([#5911](https://github.com/AstrBotDevs/AstrBot/pull/5911)).
- Fixed searching installed plugins by display name in extensions ([#5806](https://github.com/AstrBotDevs/AstrBot/pull/5806), [#5811](https://github.com/AstrBotDevs/AstrBot/pull/5811)).
- Fixed image-only responses not applying `reply_with_quote` and `reply_with_mention` ([#5219](https://github.com/AstrBotDevs/AstrBot/pull/5219)).
- Fixed `RegexFilter` using `re.match` instead of `re.search` for expected matching behavior ([#5368](https://github.com/AstrBotDevs/AstrBot/pull/5368)).
- Fixed desktop runtime detection requiring frozen Python ([#5859](https://github.com/AstrBotDevs/AstrBot/pull/5859)).
- Fixed missing pipeline scheduler after creating a platform bot via "create new config" ([#5776](https://github.com/AstrBotDevs/AstrBot/pull/5776)).
- fix(provider): handle MiniMax ThinkingBlock when max_tokens reached ([#5913](https://github.com/AstrBotDevs/AstrBot/pull/5913))
@@ -300,10 +300,6 @@ export default {
this.loadingGettingServers = true;
axios.get('/api/tools/mcp/servers')
.then(response => {
if (response.data.status === 'error') {
this.showError(response.data.message || this.tm('messages.getServersError', { error: 'Unknown error' }));
return;
}
this.mcpServers = response.data.data || [];
this.mcpServers.forEach(server => {
if (!this.mcpServerUpdateLoaders[server.name]) {
@@ -376,10 +372,6 @@ export default {
axios.post(endpoint, serverData)
.then(response => {
this.loading = false;
if (response.data.status === 'error') {
this.showError(response.data.message || this.tm('messages.saveError', { error: 'Unknown error' }));
return;
}
this.showMcpServerDialog = false;
this.addServerDialogMessage = '';
this.getServers();
File diff suppressed because it is too large Load Diff
@@ -224,43 +224,10 @@
"empty": "No Skills found",
"emptyHint": "Upload a Skills zip to get started",
"uploadDialogTitle": "Upload Skills",
"uploadHint": "Upload multiple zip skill packages or drag them in. The system validates the structure automatically and shows a result for each file.",
"structureRequirement": "The most common failure is an invalid archive structure. Each zip must contain exactly one top-level folder such as `skillname/`, and that folder must include `SKILL.md`.",
"abilityMultiple": "Upload multiple zip files at once",
"abilityValidate": "Validate `SKILL.md` automatically",
"abilitySkip": "Automatically skip duplicate files.",
"uploadHint": "Upload a zip file that contains skill_name/ and a SKILL.md inside.",
"selectFile": "Select file",
"selectFiles": "Select files (multiple allowed)",
"dropzoneTitle": "Drag multiple zip files here",
"dropzoneAction": "or click to pick multiple files from a folder",
"dropzoneHint": "Batch upload is supported and the structure will be validated automatically",
"fileListTitle": "Files in queue",
"fileListEmpty": "Selected files will appear here with validation feedback and upload status",
"uploading": "Uploading...",
"batchResultTitle": "Batch Upload Results",
"batchResultSummary": "{success} of {total} files uploaded successfully",
"batchSuccessList": "Successfully uploaded",
"batchFailedList": "Failed to upload",
"confirm": "OK",
"confirmUpload": "Start Upload",
"confirmUpload": "Upload",
"cancel": "Cancel",
"statusWaiting": "Waiting",
"statusUploading": "Uploading",
"statusSuccess": "Uploaded",
"statusError": "Failed",
"statusSkipped": "Skipped",
"summaryTotal": "{count} file(s)",
"summaryReady": "Pending {count}",
"summarySuccess": "Success {count}",
"summaryFailed": "Failed {count}",
"summarySkipped": "Skipped {count}",
"validationReady": "Ready to upload. The archive structure will be checked during upload.",
"validationZipOnly": "Only zip skill packages are supported",
"validationDuplicate": "A file with the same name is already in the queue and has been skipped",
"validationUploading": "Validating and uploading...",
"validationUploadFailed": "Upload failed. Please try again.",
"validationUploadedAs": "Installed as {name}",
"validationNoResult": "No validation result was returned. Check the platform logs.",
"noDescription": "No description",
"path": "Path",
"uploadSuccess": "Upload succeeded",
@@ -224,43 +224,10 @@
"empty": "暂无 Skills",
"emptyHint": "请上传 Skills 压缩包",
"uploadDialogTitle": "上传 Skills",
"uploadHint": "支持批量上传 zip 技能包,也支持拖拽批量上传 zip 技能包。系统会自动校验目录结构,并给出逐个文件的结果。",
"structureRequirement": "常见失败原因是压缩包结构不正确。每个 zip 必须只包含一个顶层目录,例如 `skillname/`,且该目录下必须存在 `SKILL.md`。",
"abilityMultiple": "支持一次上传多个zip文件",
"abilityValidate": "自动校验 `SKILL.md`",
"abilitySkip": "自动跳过重复文件",
"uploadHint": "上传 zip 压缩包,解压后为 skill_name/ 目录,且包含 SKILL.md",
"selectFile": "选择文件",
"selectFiles": "选择文件(可多选)",
"dropzoneTitle": "拖拽多个 zip 文件到这里",
"dropzoneAction": "或者点击之后在文件夹中选择多个文件",
"dropzoneHint": "支持批量上传,系统会自动校验目录结构",
"fileListTitle": "待处理文件",
"fileListEmpty": "选择文件后会在这里显示校验结果与上传状态",
"uploading": "正在上传...",
"batchResultTitle": "批量上传结果",
"batchResultSummary": "共 {total} 个文件,成功 {success} 个",
"batchSuccessList": "上传成功",
"batchFailedList": "上传失败",
"confirm": "确定",
"confirmUpload": "开始上传",
"confirmUpload": "上传",
"cancel": "取消",
"statusWaiting": "待上传",
"statusUploading": "上传中",
"statusSuccess": "已上传",
"statusError": "校验失败",
"statusSkipped": "已跳过",
"summaryTotal": "共 {count} 个文件",
"summaryReady": "待处理 {count}",
"summarySuccess": "成功 {count}",
"summaryFailed": "失败 {count}",
"summarySkipped": "跳过 {count}",
"validationReady": "等待上传,上传时会自动校验目录结构",
"validationZipOnly": "仅支持 zip 技能包",
"validationDuplicate": "同名文件已在列表中,已跳过",
"validationUploading": "正在校验并上传...",
"validationUploadFailed": "上传失败,请重试",
"validationUploadedAs": "已安装为 {name}",
"validationNoResult": "未收到校验结果,请检查平台日志",
"noDescription": "无描述",
"path": "路径",
"uploadSuccess": "上传成功",
-102
View File
@@ -1,102 +0,0 @@
import { pinyin } from "pinyin-pro";
const HAN_IDEOGRAPH_RE = /\p{Unified_Ideograph}/u;
export const normalizeStr = (s) => (s ?? "").toString().toLowerCase().trim();
const normalizeLooseFromNormalized = (normalized) =>
normalized.replace(/[\s_-]+/g, "").replace(/[()()【】\[\]{}·•]+/g, "");
export const normalizeLoose = (s) =>
normalizeLooseFromNormalized(normalizeStr(s));
const memoizeStringFn = (fn) => {
const cache = new Map();
return (raw) => {
const key = (raw ?? "").toString();
if (cache.has(key)) {
return cache.get(key);
}
const value = fn(key);
cache.set(key, value);
return value;
};
};
const getNormalizedText = memoizeStringFn(normalizeStr);
const getLooseText = memoizeStringFn((text) =>
normalizeLooseFromNormalized(getNormalizedText(text)),
);
export const toPinyinText = memoizeStringFn((text) =>
pinyin(text, { toneType: "none" })
.toLowerCase()
.replace(/\s+/g, ""),
);
export const toInitials = memoizeStringFn((text) =>
pinyin(text, { pattern: "first", toneType: "none" })
.toLowerCase()
.replace(/\s+/g, ""),
);
export const buildSearchQuery = (raw) => {
const norm = getNormalizedText(raw);
if (!norm) return null;
return {
norm,
loose: getLooseText(raw),
};
};
export const matchesText = (value, query) => {
if (value == null || !query?.norm) return false;
const text = String(value);
const normalizedValue = getNormalizedText(text);
const looseValue = query.loose ? getLooseText(text) : null;
if (normalizedValue.includes(query.norm)) return true;
if (query.loose && looseValue?.includes(query.loose)) return true;
if (!HAN_IDEOGRAPH_RE.test(text)) return false;
const pinyinValue = toPinyinText(text);
if (pinyinValue.includes(query.norm)) return true;
const initialsValue = toInitials(text);
if (initialsValue.includes(query.norm)) return true;
return false;
};
export const getPluginSearchFields = (plugin) => {
const supportPlatforms = Array.isArray(plugin?.support_platforms)
? plugin.support_platforms.join(" ")
: "";
const tags = Array.isArray(plugin?.tags) ? plugin.tags.join(" ") : "";
return [
plugin?.name,
plugin?.trimmedName,
plugin?.display_name,
plugin?.desc,
plugin?.author,
plugin?.repo,
plugin?.version,
plugin?.astrbot_version,
supportPlatforms,
tags,
];
};
export const matchesPluginSearch = (plugin, query) => {
if (!query) return true;
return getPluginSearchFields(plugin).some((candidate) =>
matchesText(candidate, query),
);
};
+1
View File
@@ -84,6 +84,7 @@ const {
normalizeStr,
toPinyinText,
toInitials,
marketCustomFilter,
plugin_handler_info_headers,
pluginHeaders,
filteredExtensions,
@@ -81,6 +81,7 @@ const {
normalizeStr,
toPinyinText,
toInitials,
marketCustomFilter,
plugin_handler_info_headers,
pluginHeaders,
filteredExtensions,
@@ -82,6 +82,7 @@ const {
normalizeStr,
toPinyinText,
toInitials,
marketCustomFilter,
plugin_handler_info_headers,
pluginHeaders,
filteredExtensions,
@@ -1,15 +1,9 @@
import axios from "axios";
import { pinyin } from "pinyin-pro";
import { useCommonStore } from "@/stores/common";
import { useI18n, useModuleI18n } from "@/i18n/composables";
import { getPlatformDisplayName } from "@/utils/platformUtils";
import { resolveErrorMessage } from "@/utils/errorUtils";
import {
buildSearchQuery,
matchesPluginSearch,
normalizeStr,
toInitials,
toPinyinText,
} from "@/utils/pluginSearch";
import { ref, computed, onMounted, onUnmounted, reactive, watch } from "vue";
import { useRoute, useRouter } from "vue-router";
import { useDisplay } from "vuetify";
@@ -246,6 +240,37 @@ export const useExtensionPage = () => {
});
// 插件市场拼音搜索
const normalizeStr = (s) => (s ?? "").toString().toLowerCase().trim();
const toPinyinText = (s) =>
pinyin(s ?? "", { toneType: "none" })
.toLowerCase()
.replace(/\s+/g, "");
const toInitials = (s) =>
pinyin(s ?? "", { pattern: "first", toneType: "none" })
.toLowerCase()
.replace(/\s+/g, "");
const marketCustomFilter = (value, query, item) => {
const q = normalizeStr(query);
if (!q) return true;
const candidates = new Set();
if (value != null) candidates.add(String(value));
if (item?.name) candidates.add(String(item.name));
if (item?.trimmedName) candidates.add(String(item.trimmedName));
if (item?.display_name) candidates.add(String(item.display_name));
if (item?.desc) candidates.add(String(item.desc));
if (item?.author) candidates.add(String(item.author));
for (const v of candidates) {
const nv = normalizeStr(v);
if (nv.includes(q)) return true;
const pv = toPinyinText(v);
if (pv.includes(q)) return true;
const iv = toInitials(v);
if (iv.includes(q)) return true;
}
return false;
};
const plugin_handler_info_headers = computed(() => [
{ title: tm("table.headers.eventType"), key: "event_type_h" },
@@ -322,24 +347,47 @@ export const useExtensionPage = () => {
// 通过搜索过滤插件
const filteredPlugins = computed(() => {
const plugins = filteredExtensions.value;
const query = buildSearchQuery(pluginSearch.value);
const filtered = query
? plugins.filter((plugin) => matchesPluginSearch(plugin, query))
: plugins;
let filtered = plugins;
if (pluginSearch.value) {
const search = pluginSearch.value.toLowerCase();
filtered = plugins.filter((plugin) => {
const pluginName = (plugin.name ?? "").toLowerCase();
const pluginDesc = (plugin.desc ?? "").toLowerCase();
const pluginAuthor = (plugin.author ?? "").toLowerCase();
const supportPlatforms = Array.isArray(plugin.support_platforms)
? plugin.support_platforms.join(" ").toLowerCase()
: "";
const astrbotVersion = (plugin.astrbot_version ?? "").toLowerCase();
return (
pluginName.includes(search) ||
pluginDesc.includes(search) ||
pluginAuthor.includes(search) ||
supportPlatforms.includes(search) ||
astrbotVersion.includes(search)
);
});
}
return sortPluginsByName([...filtered]);
});
// 过滤后的插件市场数据(带搜索)
const filteredMarketPlugins = computed(() => {
const query = buildSearchQuery(debouncedMarketSearch.value);
if (!query) {
if (!debouncedMarketSearch.value) {
return pluginMarketData.value;
}
return pluginMarketData.value.filter((plugin) =>
matchesPluginSearch(plugin, query),
);
const search = debouncedMarketSearch.value.toLowerCase();
return pluginMarketData.value.filter((plugin) => {
// 使用自定义过滤器
return (
marketCustomFilter(plugin.name, search, plugin) ||
marketCustomFilter(plugin.desc, search, plugin) ||
marketCustomFilter(plugin.author, search, plugin)
);
});
});
// 所有插件列表,推荐插件排在前面
@@ -1515,6 +1563,7 @@ export const useExtensionPage = () => {
normalizeStr,
toPinyinText,
toInitials,
marketCustomFilter,
plugin_handler_info_headers,
pluginHeaders,
filteredExtensions,
+11 -26
View File
@@ -5,10 +5,6 @@ import os
import sys
from pathlib import Path
import runtime_bootstrap
runtime_bootstrap.initialize_runtime_bootstrap()
from astrbot.core import LogBroker, LogManager, db_helper, logger # noqa: E402
from astrbot.core.config.default import VERSION # noqa: E402
from astrbot.core.initial_loader import InitialLoader # noqa: E402
@@ -101,26 +97,6 @@ async def check_dashboard_files(webui_dir: str | None = None):
return data_dist_path
async def main_async(webui_dir_arg: str | None) -> None:
"""主异步入口"""
# 检查仪表板文件
webui_dir = await check_dashboard_files(webui_dir_arg)
if webui_dir is None:
logger.warning(
"管理面板文件检查失败,WebUI 功能将不可用。"
"请检查网络连接或手动指定 --webui-dir 参数。"
)
db = db_helper
# 打印 logo
logger.info(logo_tmpl)
core_lifecycle = InitialLoader(db, log_broker)
core_lifecycle.webui_dir = webui_dir
await core_lifecycle.start()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="AstrBot")
parser.add_argument(
@@ -137,5 +113,14 @@ if __name__ == "__main__":
log_broker = LogBroker()
LogManager.set_queue_handler(logger, log_broker)
# 只使用一次 asyncio.run()
asyncio.run(main_async(args.webui_dir))
# 检查仪表板文件
webui_dir = asyncio.run(check_dashboard_files(args.webui_dir))
db = db_helper
# 打印 logo
logger.info(logo_tmpl)
core_lifecycle = InitialLoader(db, log_broker)
core_lifecycle.webui_dir = webui_dir
asyncio.run(core_lifecycle.start())
+3 -3
View File
@@ -1,9 +1,9 @@
[project]
name = "AstrBot"
version = "4.19.3"
version = "4.19.1"
description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.11"
keywords = ["Astrbot", "Astrbot Module", "Astrbot Plugin"]
@@ -61,7 +61,7 @@ dependencies = [
"xinference-client",
"tenacity>=9.1.2",
"shipyard-python-sdk>=0.2.4",
"shipyard-neo-sdk>=0.2.0",
"shipyard-neo-sdk @ git+https://github.com/AstrBotDevs/shipyard-neo.git#subdirectory=shipyard-neo-sdk",
"python-socks>=2.8.0",
"packaging>=24.2",
]
+1 -1
View File
@@ -54,5 +54,5 @@ markitdown-no-magika[docx,xls,xlsx]>=0.1.2
xinference-client
tenacity>=9.1.2
shipyard-python-sdk>=0.2.4
shipyard-neo-sdk>=0.2.0
shipyard-neo-sdk @ git+https://github.com/AstrBotDevs/shipyard-neo.git#subdirectory=shipyard-neo-sdk
packaging>=24.2
-50
View File
@@ -1,50 +0,0 @@
import logging
import ssl
from typing import Any
import aiohttp.connector as aiohttp_connector
from astrbot.utils.http_ssl_common import build_ssl_context_with_certifi
logger = logging.getLogger(__name__)
def _try_patch_aiohttp_ssl_context(
ssl_context: ssl.SSLContext,
log_obj: Any | None = None,
) -> bool:
log = log_obj or logger
attr_name = "_SSL_CONTEXT_VERIFIED"
if not hasattr(aiohttp_connector, attr_name):
log.warning(
"aiohttp connector does not expose _SSL_CONTEXT_VERIFIED; skipped patch.",
)
return False
current_value = getattr(aiohttp_connector, attr_name, None)
if current_value is not None and not isinstance(current_value, ssl.SSLContext):
log.warning(
"aiohttp connector exposes _SSL_CONTEXT_VERIFIED with unexpected type; skipped patch.",
)
return False
setattr(aiohttp_connector, attr_name, ssl_context)
log.info("Configured aiohttp verified SSL context with system+certifi trust chain.")
return True
def configure_runtime_ca_bundle(log_obj: Any | None = None) -> bool:
log = log_obj or logger
try:
log.info("Bootstrapping runtime CA bundle.")
ssl_context = build_ssl_context_with_certifi(log_obj=log)
return _try_patch_aiohttp_ssl_context(ssl_context, log_obj=log)
except Exception as exc:
log.error("Failed to configure runtime CA bundle for aiohttp: %r", exc)
return False
def initialize_runtime_bootstrap(log_obj: Any | None = None) -> bool:
return configure_runtime_ca_bundle(log_obj=log_obj)
+20
View File
@@ -7,6 +7,7 @@ import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable
from urllib.parse import urlparse
from unittest.mock import AsyncMock, MagicMock
from astrbot.core.message.components import BaseMessageComponent
@@ -24,6 +25,25 @@ class NoopAwaitable:
return None
def get_bound_tcp_port(site: Any) -> int:
"""Resolve the bound aiohttp TCP site port for tests.
We prefer the public ``site.name`` first. Some aiohttp test setups with
ephemeral ports may not expose a usable port there, so we fall back to
``site._server.sockets`` as a test-only compatibility path.
"""
parsed = urlparse(getattr(site, "name", ""))
if parsed.port is not None and parsed.port > 0:
return parsed.port
server = getattr(site, "_server", None)
sockets = getattr(server, "sockets", None) if server else None
if sockets:
return sockets[0].getsockname()[1]
raise RuntimeError("Unable to resolve bound TCP port from aiohttp site")
# ============================================================
# 平台配置工厂
# ============================================================
+268
View File
@@ -0,0 +1,268 @@
"""Performance benchmark tests for core AstrBot execution paths.
Run with:
uv run pytest tests/performance/test_benchmarks.py -q -s
Optional output:
ASTRBOT_BENCHMARK_OUTPUT=/tmp/astrbot_benchmark.json
"""
from __future__ import annotations
import asyncio
import json
import math
import os
import time
import zipfile
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Awaitable, Callable
from unittest.mock import MagicMock
import pytest
from aiohttp import web
from astrbot.core.backup.exporter import AstrBotExporter
from astrbot.core.message.components import File, Image, Record
from astrbot.core.utils.io import download_file, file_to_base64
from tests.fixtures.helpers import get_bound_tcp_port
@dataclass(slots=True)
class BenchmarkResult:
name: str
iterations: int
warmup: int
min_ms: float
max_ms: float
mean_ms: float
p50_ms: float
p95_ms: float
ops_per_sec: float
def _percentile(values: list[float], q: float) -> float:
if not values:
return 0.0
sorted_values = sorted(values)
if len(sorted_values) == 1:
return sorted_values[0]
rank = (len(sorted_values) - 1) * q
lower = math.floor(rank)
upper = math.ceil(rank)
if lower == upper:
return sorted_values[lower]
weight = rank - lower
return sorted_values[lower] * (1 - weight) + sorted_values[upper] * weight
async def run_async_benchmark(
name: str,
func: Callable[[], Awaitable[None]],
*,
iterations: int,
warmup: int = 5,
) -> BenchmarkResult:
for _ in range(warmup):
await func()
samples_ms: list[float] = []
for _ in range(iterations):
start_ns = time.perf_counter_ns()
await func()
elapsed_ms = (time.perf_counter_ns() - start_ns) / 1_000_000
samples_ms.append(elapsed_ms)
mean_ms = sum(samples_ms) / len(samples_ms)
return BenchmarkResult(
name=name,
iterations=iterations,
warmup=warmup,
min_ms=min(samples_ms),
max_ms=max(samples_ms),
mean_ms=mean_ms,
p50_ms=_percentile(samples_ms, 0.50),
p95_ms=_percentile(samples_ms, 0.95),
ops_per_sec=1000 / mean_ms if mean_ms > 0 else 0.0,
)
def _print_report(results: list[BenchmarkResult]) -> None:
print("\nAstrBot Benchmark Report")
print("-" * 84)
print(
f"{'case':35} {'iters':>7} {'mean(ms)':>10} {'p50(ms)':>10} "
f"{'p95(ms)':>10} {'ops/s':>10}"
)
print("-" * 84)
for result in results:
print(
f"{result.name:35} {result.iterations:7d} "
f"{result.mean_ms:10.4f} {result.p50_ms:10.4f} "
f"{result.p95_ms:10.4f} {result.ops_per_sec:10.1f}"
)
def _scaled_iterations(value: int) -> int:
scale = int(os.environ.get("ASTRBOT_BENCHMARK_SCALE", "1"))
return max(1, value * scale)
@pytest.mark.asyncio
@pytest.mark.slow
async def test_core_performance_benchmarks(tmp_path: Path) -> None:
"""Measure representative performance paths across core modules."""
data = os.urandom(256 * 1024)
payload_path = tmp_path / "payload.bin"
payload_path.write_bytes(data)
image = Image.fromFileSystem(str(payload_path))
record = Record.fromFileSystem(str(payload_path))
file_component = File(name="payload.bin", file=str(payload_path))
exists_path = tmp_path / "exists_target.txt"
exists_path.write_text("ok", encoding="utf-8")
attachments_dir = tmp_path / "attachments"
attachments_dir.mkdir()
attachments: list[dict[str, str]] = []
attachments_with_missing: list[dict[str, str]] = []
for i in range(64):
file_path = attachments_dir / f"attachment_{i}.bin"
file_path.write_bytes(data[:2048])
attachments.append({"attachment_id": f"att_{i}", "path": str(file_path)})
if i % 4 == 0:
missing_path = attachments_dir / f"missing_{i}.bin"
attachments_with_missing.append(
{"attachment_id": f"att_missing_{i}", "path": str(missing_path)}
)
attachments_with_missing.append(
{"attachment_id": f"att_existing_{i}", "path": str(file_path)}
)
exporter = AstrBotExporter(main_db=MagicMock())
zip_path = tmp_path / "attachments_bench.zip"
micro_batch = 32
download_target = tmp_path / "download_target.bin"
download_payload = os.urandom(512 * 1024)
async def handle_download(_request):
return web.Response(body=download_payload)
app = web.Application()
app.router.add_get("/download.bin", handle_download)
runner = web.AppRunner(app, access_log=None)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", 0)
await site.start()
port = get_bound_tcp_port(site)
download_url = f"http://127.0.0.1:{port}/download.bin"
async def bench_file_to_base64() -> None:
await file_to_base64(str(payload_path))
async def bench_image_convert_to_base64() -> None:
await image.convert_to_base64()
async def bench_record_convert_to_base64() -> None:
await record.convert_to_base64()
async def bench_image_convert_to_file_path() -> None:
for _ in range(micro_batch):
await image.convert_to_file_path()
async def bench_file_component_get_file() -> None:
await file_component.get_file()
async def bench_to_thread_exists() -> None:
await asyncio.to_thread(exists_path.exists)
async def bench_export_attachments_existing() -> None:
if zip_path.exists():
zip_path.unlink()
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
await exporter._export_attachments(zf, attachments)
zip_path.unlink(missing_ok=True)
async def bench_export_attachments_with_missing() -> None:
if zip_path.exists():
zip_path.unlink()
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
await exporter._export_attachments(zf, attachments_with_missing)
zip_path.unlink(missing_ok=True)
async def bench_download_file_local_http() -> None:
await download_file(download_url, str(download_target))
download_target.unlink(missing_ok=True)
try:
results = [
await run_async_benchmark(
"utils.io.file_to_base64(256KB)",
bench_file_to_base64,
iterations=_scaled_iterations(120),
),
await run_async_benchmark(
"components.Image.convert_to_base64",
bench_image_convert_to_base64,
iterations=_scaled_iterations(120),
),
await run_async_benchmark(
"components.Record.convert_to_base64",
bench_record_convert_to_base64,
iterations=_scaled_iterations(120),
),
await run_async_benchmark(
f"components.Image.convert_to_file_path(x{micro_batch})",
bench_image_convert_to_file_path,
iterations=_scaled_iterations(140),
),
await run_async_benchmark(
"components.File.get_file(local)",
bench_file_component_get_file,
iterations=_scaled_iterations(140),
),
await run_async_benchmark(
"asyncio.to_thread(Path.exists)",
bench_to_thread_exists,
iterations=_scaled_iterations(240),
),
await run_async_benchmark(
"backup.exporter._export_attachments(existing)",
bench_export_attachments_existing,
iterations=_scaled_iterations(20),
warmup=2,
),
await run_async_benchmark(
"backup.exporter._export_attachments(mixed)",
bench_export_attachments_with_missing,
iterations=_scaled_iterations(20),
warmup=2,
),
await run_async_benchmark(
"utils.io.download_file(local_http_512KB)",
bench_download_file_local_http,
iterations=_scaled_iterations(12),
warmup=2,
),
]
finally:
await runner.cleanup()
_print_report(results)
output_path = os.environ.get("ASTRBOT_BENCHMARK_OUTPUT")
if output_path:
Path(output_path).write_text(
json.dumps([asdict(result) for result in results], indent=2),
encoding="utf-8",
)
# Keep assertions broad: benchmarks are for measurement, not strict gating.
assert len(results) == 9
for result in results:
assert result.iterations > 0
assert result.mean_ms > 0
assert result.max_ms >= result.min_ms
assert result.p95_ms >= result.p50_ms
+98
View File
@@ -172,6 +172,15 @@ class TestAstrBotExporter:
assert "test.json" in exporter._checksums
assert exporter._checksums["test.json"].startswith("sha256:")
def test_read_text_if_exists(self, tmp_path):
"""测试 _read_text_if_exists 行为。"""
exporter = AstrBotExporter(main_db=MagicMock())
file_path = tmp_path / "config.json"
file_path.write_text('{"k":"v"}', encoding="utf-8")
assert exporter._read_text_if_exists(str(file_path)) == '{"k":"v"}'
assert exporter._read_text_if_exists(str(tmp_path / "missing.json")) is None
def test_generate_manifest(self, mock_main_db, mock_kb_manager):
"""测试生成清单"""
exporter = AstrBotExporter(
@@ -240,6 +249,95 @@ class TestAstrBotExporter:
assert "databases/main_db.json" in namelist
assert "config/cmd_config.json" in namelist
@pytest.mark.asyncio
async def test_export_attachments_exports_existing_and_skips_missing(
self, mock_main_db, tmp_path
):
"""测试附件导出:存在文件写入 ZIP,不存在文件跳过。"""
exporter = AstrBotExporter(main_db=mock_main_db, kb_manager=None)
existing_file = tmp_path / "exists.txt"
existing_file.write_text("hello", encoding="utf-8")
missing_file = tmp_path / "missing.txt"
zip_path = tmp_path / "attachments.zip"
attachments = [
{"attachment_id": "att_ok", "path": str(existing_file)},
{"attachment_id": "att_missing", "path": str(missing_file)},
]
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
await exporter._export_attachments(zf, attachments)
with zipfile.ZipFile(zip_path, "r") as zf:
namelist = zf.namelist()
assert "files/attachments/att_ok.txt" in namelist
assert "files/attachments/att_missing.txt" not in namelist
@pytest.mark.asyncio
async def test_export_attachments_skips_empty_attachment_id(
self, mock_main_db, tmp_path
):
"""测试附件导出:attachment_id 为空时跳过,避免覆盖冲突。"""
exporter = AstrBotExporter(main_db=mock_main_db, kb_manager=None)
file_a = tmp_path / "a.txt"
file_b = tmp_path / "b.txt"
file_a.write_text("a", encoding="utf-8")
file_b.write_text("b", encoding="utf-8")
zip_path = tmp_path / "attachments_empty_id.zip"
attachments = [
{"attachment_id": "", "path": str(file_a)},
{"path": str(file_b)},
{"attachment_id": "att_ok", "path": str(file_a)},
]
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
await exporter._export_attachments(zf, attachments)
with zipfile.ZipFile(zip_path, "r") as zf:
namelist = zf.namelist()
assert "files/attachments/att_ok.txt" in namelist
assert "files/attachments/.txt" not in namelist
@pytest.mark.asyncio
async def test_export_attachments_keeps_best_effort_on_unexpected_write_error(
self, mock_main_db, tmp_path
):
"""测试附件导出:单个非 OSError 写入异常不会中断后续附件导出。"""
exporter = AstrBotExporter(main_db=mock_main_db, kb_manager=None)
file_a = tmp_path / "a.txt"
file_b = tmp_path / "b.txt"
file_a.write_text("a", encoding="utf-8")
file_b.write_text("b", encoding="utf-8")
zip_path = tmp_path / "attachments_best_effort.zip"
attachments = [
{"attachment_id": "att_boom", "path": str(file_a)},
{"attachment_id": "att_ok", "path": str(file_b)},
]
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
original_write = zf.write
def flaky_write(filename, arcname=None, *args, **kwargs):
if arcname == "files/attachments/att_boom.txt":
raise RuntimeError("boom")
return original_write(filename, arcname, *args, **kwargs)
with patch.object(zf, "write", side_effect=flaky_write):
await exporter._export_attachments(zf, attachments)
with zipfile.ZipFile(zip_path, "r") as zf:
namelist = zf.namelist()
assert "files/attachments/att_boom.txt" not in namelist
assert "files/attachments/att_ok.txt" in namelist
class TestAstrBotImporter:
"""AstrBotImporter 类测试"""
+5 -227
View File
@@ -1,14 +1,11 @@
import asyncio
import io
import os
import sys
import zipfile
from types import SimpleNamespace
import pytest
import pytest_asyncio
from quart import Quart
from werkzeug.datastructures import FileStorage
from astrbot.core import LogBroker
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
@@ -18,6 +15,7 @@ from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.dashboard.server import AstrBotDashboard
from tests.fixtures.helpers import (
MockPluginBuilder,
MockPluginConfig,
create_mock_updater_install,
create_mock_updater_update,
)
@@ -147,7 +145,9 @@ async def test_plugins(
monkeypatch.setattr(
core_lifecycle_td.plugin_manager.updator, "install", mock_install
)
monkeypatch.setattr(core_lifecycle_td.plugin_manager.updator, "update", mock_update)
monkeypatch.setattr(
core_lifecycle_td.plugin_manager.updator, "update", mock_update
)
try:
# 插件安装
@@ -158,9 +158,7 @@ async def test_plugins(
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok", (
f"安装失败: {data.get('message', 'unknown error')}"
)
assert data["status"] == "ok", f"安装失败: {data.get('message', 'unknown error')}"
# 验证插件已注册
exists = any(md.name == test_plugin_name for md in star_registry)
@@ -495,223 +493,3 @@ async def test_neo_skills_routes(
data = await response.get_json()
assert data["status"] == "ok"
assert data["data"]["skill_key"] == "neo.demo"
@pytest.mark.asyncio
async def test_batch_upload_skills_returns_error_when_all_files_invalid(
app: Quart,
authenticated_header: dict,
):
test_client = app.test_client()
response = await test_client.post(
"/api/skills/batch-upload",
headers=authenticated_header,
files={
"files": FileStorage(
stream=io.BytesIO(b"not-a-zip"),
filename="invalid.txt",
content_type="text/plain",
),
},
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "error"
assert data["message"] == "Upload failed for all 1 file(s)."
@pytest.mark.asyncio
async def test_batch_upload_skills_accepts_zip_files(
app: Quart,
authenticated_header: dict,
monkeypatch,
):
async def _fake_sync_skills_to_active_sandboxes():
return
def _fake_install_skill_from_zip(
self,
zip_path: str,
*,
overwrite: bool = True,
):
_ = self, overwrite
assert zip_path.endswith(".zip")
return "demo_skill"
monkeypatch.setattr(
"astrbot.dashboard.routes.skills.sync_skills_to_active_sandboxes",
_fake_sync_skills_to_active_sandboxes,
)
monkeypatch.setattr(
"astrbot.dashboard.routes.skills.SkillManager.install_skill_from_zip",
_fake_install_skill_from_zip,
)
test_client = app.test_client()
response = await test_client.post(
"/api/skills/batch-upload",
headers=authenticated_header,
files={
"files": FileStorage(
stream=io.BytesIO(b"fake-zip"),
filename="demo_skill.zip",
content_type="application/zip",
),
},
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert data["message"] == "All 1 skill(s) uploaded successfully."
assert data["data"]["total"] == 1
assert data["data"]["succeeded"] == [
{"filename": "demo_skill.zip", "name": "demo_skill"}
]
assert data["data"]["failed"] == []
@pytest.mark.asyncio
async def test_batch_upload_skills_accepts_valid_skill_archive(
app: Quart,
authenticated_header: dict,
monkeypatch,
tmp_path,
):
data_dir = tmp_path / "data"
skills_dir = tmp_path / "skills"
temp_dir = tmp_path / "temp"
data_dir.mkdir()
skills_dir.mkdir()
temp_dir.mkdir()
async def _fake_sync_skills_to_active_sandboxes():
return
monkeypatch.setattr(
"astrbot.dashboard.routes.skills.sync_skills_to_active_sandboxes",
_fake_sync_skills_to_active_sandboxes,
)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_data_path",
lambda: str(data_dir),
)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_skills_path",
lambda: str(skills_dir),
)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_temp_path",
lambda: str(temp_dir),
)
monkeypatch.setattr(
"astrbot.dashboard.routes.skills.get_astrbot_temp_path",
lambda: str(temp_dir),
)
archive = io.BytesIO()
with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
zf.writestr(
"demo_skill/SKILL.md",
"---\nname: demo-skill\ndescription: Demo skill\n---\n",
)
zf.writestr("demo_skill/notes.txt", "hello")
zf.writestr("__MACOSX/demo_skill/._SKILL.md", "")
zf.writestr("__MACOSX/._demo_skill", "")
archive.seek(0)
test_client = app.test_client()
response = await test_client.post(
"/api/skills/batch-upload",
headers=authenticated_header,
files={
"files": FileStorage(
stream=archive,
filename="demo_skill.zip",
content_type="application/zip",
),
},
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert data["data"]["succeeded"] == [
{"filename": "demo_skill.zip", "name": "demo_skill"}
]
assert data["data"]["failed"] == []
assert (skills_dir / "demo_skill" / "SKILL.md").exists()
@pytest.mark.asyncio
async def test_batch_upload_skills_partial_success(
app: Quart,
authenticated_header: dict,
monkeypatch,
):
async def _fake_sync_skills_to_active_sandboxes():
return
def _fake_install_skill_from_zip(
self,
zip_path: str,
*,
overwrite: bool = True,
):
_ = self, overwrite
if "ok_skill" in zip_path:
return "ok_skill"
raise RuntimeError("install failed")
monkeypatch.setattr(
"astrbot.dashboard.routes.skills.sync_skills_to_active_sandboxes",
_fake_sync_skills_to_active_sandboxes,
)
monkeypatch.setattr(
"astrbot.dashboard.routes.skills.SkillManager.install_skill_from_zip",
_fake_install_skill_from_zip,
)
test_client = app.test_client()
boundary = "----AstrBotBatchBoundary"
body = (
(
f"--{boundary}\r\n"
'Content-Disposition: form-data; name="files"; filename="ok_skill.zip"\r\n'
"Content-Type: application/zip\r\n\r\n"
).encode()
+ b"fake-zip-1\r\n"
+ (
f"--{boundary}\r\n"
'Content-Disposition: form-data; name="files"; filename="bad_skill.zip"\r\n'
"Content-Type: application/zip\r\n\r\n"
).encode()
+ b"fake-zip-2\r\n"
+ f"--{boundary}--\r\n".encode()
)
headers = dict(authenticated_header)
headers["Content-Type"] = f"multipart/form-data; boundary={boundary}"
response = await test_client.post(
"/api/skills/batch-upload",
headers=headers,
data=body,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert data["message"] == "Partial success: 1/2 skill(s) uploaded."
assert data["data"]["total"] == 2
assert data["data"]["succeeded"] == [
{"filename": "ok_skill.zip", "name": "ok_skill"}
]
assert data["data"]["failed"] == [
{"filename": "bad_skill.zip", "error": "install failed"}
]
-41
View File
@@ -1,41 +0,0 @@
from unittest.mock import AsyncMock
import pytest
from astrbot.core.utils.pip_installer import PipInstaller
@pytest.mark.asyncio
async def test_install_targets_site_packages_for_desktop_client(monkeypatch, tmp_path):
monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1")
monkeypatch.delattr("sys.frozen", raising=False)
site_packages_path = tmp_path / "site-packages"
run_pip = AsyncMock(return_value=0)
prepend_sys_path_calls = []
ensure_preferred_calls = []
monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip)
monkeypatch.setattr(
"astrbot.core.utils.pip_installer.get_astrbot_site_packages_path",
lambda: str(site_packages_path),
)
monkeypatch.setattr(
"astrbot.core.utils.pip_installer._prepend_sys_path",
lambda path: prepend_sys_path_calls.append(path),
)
monkeypatch.setattr(
"astrbot.core.utils.pip_installer._ensure_plugin_dependencies_preferred",
lambda path, requirements: ensure_preferred_calls.append((path, requirements)),
)
installer = PipInstaller("")
await installer.install(package_name="demo-package")
run_pip.assert_awaited_once()
recorded_args = run_pip.await_args_list[0].args[0]
assert "--target" in recorded_args
assert str(site_packages_path) in recorded_args
assert prepend_sys_path_calls == [str(site_packages_path), str(site_packages_path)]
assert ensure_preferred_calls == [(str(site_packages_path), {"demo-package"})]
-26
View File
@@ -1,26 +0,0 @@
from astrbot.core.utils.astrbot_path import get_astrbot_root
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
def test_desktop_client_env_marks_desktop_runtime_without_frozen(monkeypatch):
monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1")
monkeypatch.delattr("sys.frozen", raising=False)
assert is_packaged_desktop_runtime() is True
def test_desktop_client_uses_home_root_without_explicit_astrbot_root(monkeypatch):
monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1")
monkeypatch.delenv("ASTRBOT_ROOT", raising=False)
monkeypatch.delattr("sys.frozen", raising=False)
assert get_astrbot_root().endswith(".astrbot")
def test_explicit_astrbot_root_overrides_desktop_default(monkeypatch, tmp_path):
explicit_root = tmp_path / "astrbot-root"
monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1")
monkeypatch.setenv("ASTRBOT_ROOT", str(explicit_root))
monkeypatch.delattr("sys.frozen", raising=False)
assert get_astrbot_root() == str(explicit_root.resolve())
+3 -56
View File
@@ -2,8 +2,6 @@ from __future__ import annotations
from pathlib import Path
import pytest
from astrbot.core.skills.skill_manager import SkillManager
@@ -58,7 +56,7 @@ def test_list_skills_merges_local_and_sandbox_cache(monkeypatch, tmp_path: Path)
assert by_name["custom-local"].description == "local description"
assert by_name["custom-local"].path == "skills/custom-local/SKILL.md"
assert by_name["python-sandbox"].description == "ship built-in"
assert by_name["python-sandbox"].path == "/workspace/skills/python-sandbox/SKILL.md"
assert by_name["python-sandbox"].path == "skills/python-sandbox/SKILL.md"
def test_sandbox_cached_skill_respects_active_and_display_path(
@@ -100,58 +98,7 @@ def test_sandbox_cached_skill_respects_active_and_display_path(
assert len(all_skills) == 1
assert all_skills[0].path == "/app/skills/browser-automation/SKILL.md"
with pytest.raises(PermissionError):
mgr.set_skill_active("browser-automation", False)
mgr.set_skill_active("browser-automation", False)
active_skills = mgr.list_skills(runtime="sandbox", active_only=True)
assert len(active_skills) == 1
assert active_skills[0].name == "browser-automation"
def test_sandbox_and_local_path_resolution_with_show_sandbox_path_false(
monkeypatch,
tmp_path: Path,
):
data_dir = tmp_path / "data"
temp_dir = tmp_path / "temp"
skills_root = tmp_path / "skills"
data_dir.mkdir(parents=True, exist_ok=True)
temp_dir.mkdir(parents=True, exist_ok=True)
skills_root.mkdir(parents=True, exist_ok=True)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_data_path",
lambda: str(data_dir),
)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_temp_path",
lambda: str(temp_dir),
)
mgr = SkillManager(skills_root=str(skills_root))
_write_skill(skills_root, "custom-local", "local description")
mgr.set_sandbox_skills_cache(
[
{
"name": "custom-local",
"description": "cached description should be overridden",
"path": "/app/skills/custom-local/SKILL.md",
},
{
"name": "python-sandbox",
"description": "ship built-in",
"path": "/app/skills/python-sandbox/SKILL.md",
},
]
)
skills = mgr.list_skills(runtime="sandbox", show_sandbox_path=False)
by_name = {item.name: item for item in skills}
assert sorted(by_name) == ["custom-local", "python-sandbox"]
assert by_name["custom-local"].description == "local description"
local_skill_path = Path(by_name["custom-local"].path)
assert local_skill_path.is_relative_to(skills_root)
assert local_skill_path == skills_root / "custom-local" / "SKILL.md"
assert by_name["python-sandbox"].path == "/app/skills/python-sandbox/SKILL.md"
assert active_skills == []
-59
View File
@@ -1,59 +0,0 @@
from unittest.mock import AsyncMock
import pytest
import astrbot.core.message.components as Comp
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.pipeline.respond.stage import RespondStage
from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import (
AiocqhttpMessageEvent,
)
def test_poke_to_dict_matches_onebot_v11_segment_format():
poke = Comp.Poke(type="126", id=2003)
assert poke.toDict() == {
"type": "poke",
"data": {"type": "126", "id": "2003"},
}
def test_poke_to_dict_keeps_legacy_qq_compatible():
poke = Comp.Poke(type="poke", qq=2916963017)
assert poke.toDict() == {
"type": "poke",
"data": {"type": "126", "id": "2916963017"},
}
@pytest.mark.asyncio
async def test_respond_stage_treats_poke_with_target_as_non_empty():
stage = RespondStage()
chain = [Comp.Poke(type="126", id=2003)]
assert await stage._is_empty_message_chain(chain) is False
@pytest.mark.asyncio
async def test_aiocqhttp_parse_json_outputs_standard_poke_data():
chain = MessageChain([Comp.Poke(type="126", id=2003)])
data = await AiocqhttpMessageEvent._parse_onebot_json(chain)
assert data == [{"type": "poke", "data": {"type": "126", "id": "2003"}}]
@pytest.mark.asyncio
async def test_aiocqhttp_send_message_dispatches_onebot_v11_poke_payload():
bot = AsyncMock()
chain = MessageChain([Comp.Poke(type="126", id=2003)])
await AiocqhttpMessageEvent.send_message(
bot=bot,
message_chain=chain,
event=None,
is_group=True,
session_id="123456",
)
bot.send_group_msg.assert_awaited_once_with(
group_id=123456,
message=[{"type": "poke", "data": {"type": "126", "id": "2003"}}],
)
+2 -2
View File
@@ -373,7 +373,7 @@ class TestAstrBotCoreLifecycleInitialize:
new_callable=AsyncMock,
),
):
await lifecycle.initialize()
await lifecycle.initialize(mcp_init_timeout=3.5)
# Verify database initialized
mock_db.initialize.assert_awaited_once()
@@ -388,7 +388,7 @@ class TestAstrBotCoreLifecycleInitialize:
mock_persona_mgr.initialize.assert_awaited_once()
# Verify provider manager initialized
mock_provider_manager.initialize.assert_awaited_once()
mock_provider_manager.initialize.assert_awaited_once_with(init_timeout=3.5)
# Verify platform manager initialized
mock_platform_manager.initialize.assert_awaited_once()
+43
View File
@@ -0,0 +1,43 @@
import pytest
from tests.fixtures.helpers import get_bound_tcp_port
class _DummySiteNoAttrs:
pass
class _DummySocket:
def __init__(self, port: int) -> None:
self._port = port
def getsockname(self):
return ("127.0.0.1", self._port)
class _DummyServer:
def __init__(self, port: int) -> None:
self.sockets = [_DummySocket(port)]
class _DummySiteWithName:
def __init__(self, port: int) -> None:
self.name = f"http://localhost:{port}"
class _DummySiteWithServer:
def __init__(self, port: int) -> None:
self._server = _DummyServer(port)
def test_get_bound_tcp_port_raises_on_unresolvable_site():
with pytest.raises(RuntimeError, match="Unable to resolve bound TCP port"):
get_bound_tcp_port(_DummySiteNoAttrs())
def test_get_bound_tcp_port_uses_name_port_when_available():
assert get_bound_tcp_port(_DummySiteWithName(8081)) == 8081
def test_get_bound_tcp_port_falls_back_to_server_sockets():
assert get_bound_tcp_port(_DummySiteWithServer(9092)) == 9092
+98
View File
@@ -0,0 +1,98 @@
import json
import pytest
from astrbot.core.provider import func_tool_manager
from astrbot.core.provider.func_tool_manager import FunctionToolManager
@pytest.fixture
def mcp_init_harness(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
):
manager = FunctionToolManager()
data_dir = tmp_path / "data"
data_dir.mkdir()
(data_dir / "mcp_server.json").write_text(
json.dumps({"mcpServers": {"demo": {"active": True}}}),
encoding="utf-8",
)
monkeypatch.setattr(
func_tool_manager,
"get_astrbot_data_path",
lambda: data_dir,
)
called = {}
async def fake_start_mcp_server(*, name, cfg, shutdown_event, timeout_seconds):
called[name] = {
"cfg": cfg,
"shutdown_event_type": type(shutdown_event).__name__,
"timeout_seconds": timeout_seconds,
}
monkeypatch.setattr(manager, "_start_mcp_server", fake_start_mcp_server)
return manager, called
def assert_demo_init_result(summary, called, *, timeout_seconds: float) -> None:
assert summary.total == 1
assert summary.success == 1
assert summary.failed == []
assert called["demo"]["cfg"] == {"active": True}
assert called["demo"]["shutdown_event_type"] == "Event"
assert called["demo"]["timeout_seconds"] == timeout_seconds
@pytest.mark.asyncio
async def test_init_mcp_clients_passes_timeout_seconds_keyword(mcp_init_harness):
manager, called = mcp_init_harness
summary = await manager.init_mcp_clients()
assert_demo_init_result(
summary,
called,
timeout_seconds=manager._init_timeout_default,
)
@pytest.mark.asyncio
async def test_init_mcp_clients_passes_overridden_init_timeout(
mcp_init_harness,
):
manager, called = mcp_init_harness
summary = await manager.init_mcp_clients(init_timeout=3.5)
assert_demo_init_result(summary, called, timeout_seconds=3.5)
@pytest.mark.asyncio
async def test_init_mcp_clients_reads_env_timeout_when_not_overridden(
mcp_init_harness,
monkeypatch: pytest.MonkeyPatch,
):
manager, called = mcp_init_harness
manager._init_timeout_default = 20.0 # ensure env override is observable
monkeypatch.setenv("ASTRBOT_MCP_INIT_TIMEOUT", "3.5")
summary = await manager.init_mcp_clients()
assert_demo_init_result(summary, called, timeout_seconds=3.5)
@pytest.mark.asyncio
async def test_init_mcp_clients_prefers_explicit_timeout_over_env(
mcp_init_harness,
monkeypatch: pytest.MonkeyPatch,
):
manager, called = mcp_init_harness
monkeypatch.setenv("ASTRBOT_MCP_INIT_TIMEOUT", "7.0")
summary = await manager.init_mcp_clients(init_timeout=3.5)
assert_demo_init_result(summary, called, timeout_seconds=3.5)
+71
View File
@@ -0,0 +1,71 @@
import pytest
from aiohttp import web
from astrbot.core.utils import io as io_module
from astrbot.core.utils.io import _stream_to_file, download_file
from tests.fixtures.helpers import get_bound_tcp_port
@pytest.mark.asyncio
async def test_download_file_downloads_content(tmp_path):
payload = b"astrbot-download-payload" * 256
async def handle(_request):
return web.Response(body=payload)
app = web.Application()
app.router.add_get("/file.bin", handle)
runner = web.AppRunner(app, access_log=None)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", 0)
await site.start()
try:
port = get_bound_tcp_port(site)
url = f"http://127.0.0.1:{port}/file.bin"
out = tmp_path / "downloaded.bin"
await download_file(url, str(out))
assert out.read_bytes() == payload
finally:
await runner.cleanup()
class _DummyStream:
def __init__(self, chunks: list[bytes]) -> None:
self._chunks = chunks
async def read(self, _size: int) -> bytes:
if not self._chunks:
return b""
return self._chunks.pop(0)
class _RecordingFile:
def __init__(self) -> None:
self.writes: list[bytes] = []
def write(self, data: bytes) -> int:
self.writes.append(data)
return len(data)
@pytest.mark.asyncio
async def test_stream_to_file_batches_multiple_chunks_per_write(monkeypatch):
monkeypatch.setattr(io_module, "_DOWNLOAD_READ_CHUNK_SIZE", 4)
monkeypatch.setattr(io_module, "_DOWNLOAD_FLUSH_THRESHOLD", 10)
stream = _DummyStream([b"aaaa", b"bbbb", b"cccc"])
file_obj = _RecordingFile()
await _stream_to_file(
stream,
file_obj,
total_size=12,
start_time=0.0,
show_progress=False,
)
assert len(file_obj.writes) == 1
assert file_obj.writes[0] == b"aaaabbbbcccc"
+178
View File
@@ -0,0 +1,178 @@
import base64
import os
from pathlib import Path
import pytest
from aiohttp import web
from astrbot.core.message import components as components_module
from astrbot.core.message.components import File, Image, Record
from tests.fixtures.helpers import get_bound_tcp_port
@pytest.mark.asyncio
async def test_image_convert_to_file_path_returns_absolute_path(tmp_path):
file_path = tmp_path / "img.bin"
file_path.write_bytes(b"img")
image = Image(file=str(file_path))
resolved = await image.convert_to_file_path()
assert resolved == os.path.abspath(str(file_path))
@pytest.mark.asyncio
async def test_record_convert_to_file_path_returns_absolute_path(tmp_path):
file_path = tmp_path / "record.bin"
file_path.write_bytes(b"record")
record = Record(file=str(file_path))
resolved = await record.convert_to_file_path()
assert resolved == os.path.abspath(str(file_path))
@pytest.mark.asyncio
async def test_file_component_get_file_returns_absolute_path(tmp_path):
file_path = tmp_path / "file.bin"
file_path.write_bytes(b"file")
file_component = File(name="file.bin", file=str(file_path))
resolved = await file_component.get_file()
assert resolved == os.path.abspath(str(file_path))
@pytest.mark.asyncio
async def test_image_convert_to_base64_raises_on_missing_file(tmp_path):
image = Image(file=str(tmp_path / "missing.bin"))
with pytest.raises(Exception, match="not a valid file"):
await image.convert_to_base64()
@pytest.mark.asyncio
async def test_record_convert_to_base64_raises_on_missing_file(tmp_path):
record = Record(file=str(tmp_path / "missing.bin"))
with pytest.raises(Exception, match="not a valid file"):
await record.convert_to_base64()
@pytest.mark.asyncio
async def test_image_convert_to_base64_reads_existing_local_file(tmp_path):
raw = b"image-bytes"
file_path = tmp_path / "exists_image.bin"
file_path.write_bytes(raw)
image = Image(file=str(file_path))
encoded = await image.convert_to_base64()
assert base64.b64decode(encoded) == raw
@pytest.mark.asyncio
async def test_record_convert_to_base64_reads_existing_local_file(tmp_path):
raw = b"record-bytes"
file_path = tmp_path / "exists_record.bin"
file_path.write_bytes(raw)
record = Record(file=str(file_path))
encoded = await record.convert_to_base64()
assert base64.b64decode(encoded) == raw
@pytest.mark.asyncio
async def test_image_convert_to_base64_maps_permission_error(monkeypatch):
async def _raise_permission_error(_path: str) -> str:
raise PermissionError("permission denied")
monkeypatch.setattr(components_module, "file_to_base64", _raise_permission_error)
image = Image(file="/tmp/forbidden-image")
with pytest.raises(Exception, match="not a valid file"):
await image.convert_to_base64()
@pytest.mark.asyncio
async def test_record_convert_to_base64_maps_permission_error(monkeypatch):
async def _raise_permission_error(_path: str) -> str:
raise PermissionError("permission denied")
monkeypatch.setattr(components_module, "file_to_base64", _raise_permission_error)
record = Record(file="/tmp/forbidden-record")
with pytest.raises(Exception, match="not a valid file"):
await record.convert_to_base64()
@pytest.mark.asyncio
async def test_image_convert_to_file_path_from_base64_creates_absolute_file():
payload = b"image-base64-payload"
image = Image(file=f"base64://{base64.b64encode(payload).decode()}")
resolved = await image.convert_to_file_path()
resolved_path = Path(resolved)
assert resolved_path.is_absolute()
assert resolved_path.exists()
assert resolved_path.read_bytes() == payload
@pytest.mark.asyncio
async def test_record_convert_to_file_path_from_base64_creates_absolute_file():
payload = b"record-base64-payload"
record = Record(file=f"base64://{base64.b64encode(payload).decode()}")
resolved = await record.convert_to_file_path()
resolved_path = Path(resolved)
assert resolved_path.is_absolute()
assert resolved_path.exists()
assert resolved_path.read_bytes() == payload
async def _serve_payload(payload: bytes, route: str):
async def handle(_request):
return web.Response(body=payload)
app = web.Application()
app.router.add_get(route, handle)
runner = web.AppRunner(app, access_log=None)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", 0)
await site.start()
return runner, get_bound_tcp_port(site)
@pytest.mark.asyncio
async def test_image_convert_to_file_path_from_http_creates_absolute_file():
payload = b"image-http-payload"
runner, port = await _serve_payload(payload, "/img.bin")
try:
image = Image(file=f"http://127.0.0.1:{port}/img.bin")
resolved = await image.convert_to_file_path()
resolved_path = Path(resolved)
assert resolved_path.is_absolute()
assert resolved_path.exists()
assert resolved_path.read_bytes() == payload
finally:
await runner.cleanup()
@pytest.mark.asyncio
async def test_record_convert_to_file_path_from_http_creates_absolute_file():
payload = b"record-http-payload"
runner, port = await _serve_payload(payload, "/record.bin")
try:
record = Record(file=f"http://127.0.0.1:{port}/record.bin")
resolved = await record.convert_to_file_path()
resolved_path = Path(resolved)
assert resolved_path.is_absolute()
assert resolved_path.exists()
assert resolved_path.read_bytes() == payload
finally:
await runner.cleanup()
-545
View File
@@ -1,545 +0,0 @@
"""Tests for SessionLockManager with multi-event-loop isolation."""
import asyncio
import threading
import time
import weakref
from concurrent.futures import ThreadPoolExecutor
import pytest
from astrbot.core.utils.session_lock import SessionLockManager
class TestSessionLockManagerBasic:
"""Basic functionality tests."""
def test_init(self):
"""Test manager initialization."""
manager = SessionLockManager()
assert manager._state_guard is not None
assert manager._loop_managers is not None
@pytest.mark.asyncio
async def test_acquire_release_lock(self):
"""Test basic lock acquire and release."""
manager = SessionLockManager()
session_id = "test-session"
async with manager.acquire_lock(session_id):
# Lock acquired successfully
pass
# Lock should be released and cleaned up
state = manager._get_loop_manager()
assert session_id not in state._locks
assert session_id not in state._lock_count
@pytest.mark.asyncio
async def test_lock_is_reusable(self):
"""Test that locks can be acquired multiple times."""
manager = SessionLockManager()
session_id = "test-session"
async with manager.acquire_lock(session_id):
pass
async with manager.acquire_lock(session_id):
pass
# Both acquisitions should succeed
class TestCrossLoopIsolation:
"""Tests for event loop isolation."""
@pytest.mark.asyncio
async def test_different_loops_have_different_managers(self):
"""Test that different event loops get different per-loop managers."""
manager = SessionLockManager()
# Get manager for current loop
manager1 = manager._get_loop_manager()
# Run in a different event loop
def run_in_new_loop():
new_loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(new_loop)
async def get_manager():
return manager._get_loop_manager()
return new_loop.run_until_complete(get_manager())
finally:
new_loop.close()
asyncio.set_event_loop(None)
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
manager2 = future.result()
# Should be different manager instances
assert manager1 is not manager2
@pytest.mark.asyncio
async def test_locks_isolated_across_loops(self):
"""Test that locks from different loops are isolated."""
manager = SessionLockManager()
session_id = "shared-session"
results = []
async def acquire_in_loop(loop_id: int):
"""Acquire lock in a new event loop."""
async with manager.acquire_lock(session_id):
results.append(f"loop-{loop_id}-acquired")
await asyncio.sleep(0.05)
results.append(f"loop-{loop_id}-released")
def run_in_thread(loop_id: int):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(acquire_in_loop(loop_id))
finally:
loop.close()
asyncio.set_event_loop(None)
# Run two loops concurrently - they should NOT block each other
# because locks are isolated per-loop
with ThreadPoolExecutor(max_workers=2) as executor:
futures = [executor.submit(run_in_thread, i) for i in range(2)]
for f in futures:
f.result()
# Both loops should acquire immediately (no blocking between loops)
# Order should show interleaved acquisitions, not sequential
assert len(results) == 4
@pytest.mark.asyncio
async def test_same_loop_blocks_on_same_session(self):
"""Test that same loop blocks when acquiring same session lock."""
manager = SessionLockManager()
session_id = "test-session"
execution_order = []
async def task1():
async with manager.acquire_lock(session_id):
execution_order.append("task1-start")
await asyncio.sleep(0.1)
execution_order.append("task1-end")
async def task2():
await asyncio.sleep(0.01) # Let task1 start first
async with manager.acquire_lock(session_id):
execution_order.append("task2-start")
execution_order.append("task2-end")
await asyncio.gather(task1(), task2())
# task2 should wait for task1 to finish
assert execution_order.index("task1-start") < execution_order.index("task1-end")
assert execution_order.index("task1-end") < execution_order.index("task2-start")
class TestConcurrency:
"""Tests for concurrent access."""
@pytest.mark.asyncio
async def test_concurrent_acquisitions_same_loop(self):
"""Test concurrent lock acquisitions on the same loop."""
manager = SessionLockManager()
session_id = "concurrent-session"
acquired_count = 0
max_concurrent = 0
lock = asyncio.Lock()
async def acquire_and_check():
nonlocal acquired_count, max_concurrent
async with manager.acquire_lock(session_id):
async with lock:
acquired_count += 1
max_concurrent = max(max_concurrent, acquired_count)
await asyncio.sleep(0.01)
async with lock:
acquired_count -= 1
# Run multiple concurrent tasks
tasks = [acquire_and_check() for _ in range(5)]
await asyncio.gather(*tasks)
# Max concurrent should be 1 (lock serializes access)
assert max_concurrent == 1
@pytest.mark.asyncio
async def test_thread_safety_of_loop_manager_creation(self):
"""Test that _get_loop_manager is thread-safe."""
manager = SessionLockManager()
managers = []
errors = []
def create_loop_and_get_manager():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
async def get_mgr():
return manager._get_loop_manager()
mgr = loop.run_until_complete(get_mgr())
managers.append(mgr)
except Exception as e:
errors.append(e)
finally:
loop.close()
asyncio.set_event_loop(None)
threads = [threading.Thread(target=create_loop_and_get_manager) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(errors) == 0
# All managers should be valid
for m in managers:
assert hasattr(m, "_locks")
assert hasattr(m, "_access_lock")
class TestEventLoopCleanup:
"""Tests for event loop cleanup behavior."""
@pytest.mark.asyncio
async def test_weakref_cleanup_on_loop_close(self):
"""Test that per-loop managers are cleaned up when loop is closed."""
manager = SessionLockManager()
loop_ref: weakref.ref[asyncio.AbstractEventLoop] | None = None
def run_in_new_loop():
nonlocal loop_ref
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop_ref = weakref.ref(loop)
async def use_lock():
async with manager.acquire_lock("test-session"):
pass
return manager._get_loop_manager()
try:
per_loop_mgr = loop.run_until_complete(use_lock())
# Keep a weak ref to the per-loop manager
return weakref.ref(per_loop_mgr)
finally:
loop.close()
asyncio.set_event_loop(None)
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
per_loop_mgr_ref = future.result()
# Give time for weakref cleanup
import gc
gc.collect()
# The per-loop manager should be cleaned up when the loop is closed
# because WeakKeyDictionary removes entries when the key (loop) is gone
per_loop_mgr = per_loop_mgr_ref()
loop = loop_ref() if loop_ref is not None else None
assert per_loop_mgr is None or loop is None
@pytest.mark.asyncio
async def test_access_after_loop_close_in_new_loop_works(self):
"""Test that accessing from a new loop after old loop closes works."""
manager = SessionLockManager()
# Use lock in current loop
async with manager.acquire_lock("session-1"):
pass
# Simulate old loop being closed and new loop being created
def run_in_new_loop():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
async def use_lock():
# Should work without issues in new loop
async with manager.acquire_lock("session-2"):
return "success"
return loop.run_until_complete(use_lock())
finally:
loop.close()
asyncio.set_event_loop(None)
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
result = future.result()
assert result == "success"
class TestIssue5464:
"""Tests for issue #5464: Multiple OneBot instances with different event loops.
Issue: Running multiple OneBot adapter instances causes
"is bound to a different event loop" error.
"""
@pytest.mark.asyncio
async def test_multiple_event_loops_no_cross_loop_error(self):
"""Test that multiple event loops don't cause cross-loop binding errors.
This simulates the scenario where multiple OneBot instances
(each potentially running in different event loops) access the
same SessionLockManager concurrently.
"""
from astrbot.core.utils.session_lock import session_lock_manager
errors: list[Exception] = []
results: list[str] = []
def simulate_onebot_instance(instance_id: int, session_ids: list[str]):
"""Simulate a OneBot instance running in its own event loop."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
async def process_messages():
for session_id in session_ids:
try:
async with session_lock_manager.acquire_lock(session_id):
# Simulate message processing
await asyncio.sleep(0.01)
results.append(f"instance-{instance_id}-{session_id}")
except Exception as e:
errors.append(e)
loop.run_until_complete(process_messages())
finally:
loop.close()
asyncio.set_event_loop(None)
# Simulate 4 OneBot instances (as in the issue report)
# Each handles multiple sessions concurrently
threads = []
for i in range(4):
sessions = [f"session-{i}-1", f"session-{i}-2", f"session-{i}-3"]
t = threading.Thread(target=simulate_onebot_instance, args=(i, sessions))
threads.append(t)
for t in threads:
t.start()
for t in threads:
t.join()
# Should have no errors (especially no "bound to a different event loop")
assert len(errors) == 0, f"Errors occurred: {errors}"
assert len(results) == 12 # 4 instances * 3 sessions each
@pytest.mark.asyncio
async def test_lock_object_not_shared_across_loops(self):
"""Verify that asyncio.Lock objects are not shared across event loops.
The root cause of issue #5464 was that Lock objects created in one
event loop were being used in another, causing the error.
"""
manager = SessionLockManager()
session_id = "shared-session-id"
lock_ids: set[int] = set()
lock_id_lock = threading.Lock()
def get_lock_in_new_loop():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
async def acquire_and_capture():
# Get the per-loop manager
per_loop_mgr = manager._get_loop_manager()
# Capture the lock object id before acquiring
async with per_loop_mgr._access_lock:
lock = per_loop_mgr._locks[session_id]
with lock_id_lock:
lock_ids.add(id(lock))
async with manager.acquire_lock(session_id):
await asyncio.sleep(0.01)
loop.run_until_complete(acquire_and_capture())
finally:
loop.close()
asyncio.set_event_loop(None)
# Run multiple loops concurrently
threads = [threading.Thread(target=get_lock_in_new_loop) for _ in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
# Each loop should have its own Lock object
# If locks were shared, we'd only have 1 lock_id
assert len(lock_ids) == 5, "Each event loop should have its own Lock object"
@pytest.mark.asyncio
async def test_concurrent_access_same_session_different_loops(self):
"""Test that same session ID accessed from different loops doesn't block.
This verifies the fix: locks are isolated per event loop,
so different loops can acquire the "same" session lock concurrently.
"""
from astrbot.core.utils.session_lock import session_lock_manager
session_id = "global-session"
acquisition_times: list[float] = []
time_lock = threading.Lock()
def acquire_lock_in_loop(loop_id: int):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
async def acquire():
import time
start = time.time()
async with session_lock_manager.acquire_lock(session_id):
with time_lock:
acquisition_times.append(start)
await asyncio.sleep(0.1) # Hold the lock
loop.run_until_complete(acquire())
finally:
loop.close()
asyncio.set_event_loop(None)
# Start 3 threads nearly simultaneously
threads = [threading.Thread(target=acquire_lock_in_loop, args=(i,)) for i in range(3)]
start_time = time.time()
for t in threads:
t.start()
for t in threads:
t.join()
total_time = time.time() - start_time
# If locks were NOT isolated, we'd need ~0.3s (3 * 0.1s serial)
# With isolation, all should complete in ~0.1s (parallel)
# Allow some overhead, but should be much less than 0.3s
assert total_time < 0.25, (
f"Locks should be isolated per loop, but took {total_time:.2f}s"
)
class TestEdgeCases:
"""Tests for edge cases."""
@pytest.mark.asyncio
async def test_empty_session_id(self):
"""Test with empty session ID."""
manager = SessionLockManager()
async with manager.acquire_lock(""):
pass
# Should work without issues
@pytest.mark.asyncio
async def test_special_characters_in_session_id(self):
"""Test with special characters in session ID."""
manager = SessionLockManager()
session_id = "session-with-special-chars!@#$%^&*()"
async with manager.acquire_lock(session_id):
pass
# Should work without issues
@pytest.mark.asyncio
async def test_very_long_session_id(self):
"""Test with very long session ID."""
manager = SessionLockManager()
session_id = "a" * 10000
async with manager.acquire_lock(session_id):
pass
# Should work without issues
@pytest.mark.asyncio
async def test_lock_not_held_after_context_exit(self):
"""Test that lock is released after context manager exit."""
manager = SessionLockManager()
session_id = "test-session"
async with manager.acquire_lock(session_id):
state = manager._get_loop_manager()
# Lock should exist and have count 1
assert session_id in state._locks
assert state._lock_count[session_id] == 1
# After exit, lock should be cleaned up
state = manager._get_loop_manager()
assert session_id not in state._locks
assert session_id not in state._lock_count
@pytest.mark.asyncio
async def test_exception_during_lock(self):
"""Test that lock is released even if exception occurs."""
manager = SessionLockManager()
session_id = "test-session"
with pytest.raises(ValueError):
async with manager.acquire_lock(session_id):
raise ValueError("test error")
# Lock should still be released
state = manager._get_loop_manager()
assert session_id not in state._locks
assert session_id not in state._lock_count
@pytest.mark.asyncio
async def test_nested_lock_different_sessions(self):
"""Test nested locks on different sessions."""
manager = SessionLockManager()
async with manager.acquire_lock("session-1"):
async with manager.acquire_lock("session-2"):
state = manager._get_loop_manager()
assert "session-1" in state._locks
assert "session-2" in state._locks
assert state._lock_count["session-1"] == 1
assert state._lock_count["session-2"] == 1
state = manager._get_loop_manager()
assert "session-1" not in state._locks
assert "session-2" not in state._locks
@pytest.mark.asyncio
async def test_reentrant_lock_same_session(self):
"""Test reentrant locking on same session (should block)."""
manager = SessionLockManager()
session_id = "test-session"
order = []
async def outer():
async with manager.acquire_lock(session_id):
order.append("outer-acquired")
await asyncio.sleep(0.1)
order.append("outer-done")
async def inner():
await asyncio.sleep(0.01) # Let outer acquire first
order.append("inner-attempt")
async with manager.acquire_lock(session_id):
order.append("inner-acquired")
order.append("inner-done")
await asyncio.gather(outer(), inner())
# Inner should wait for outer to complete
assert order.index("outer-acquired") < order.index("outer-done")
assert order.index("outer-done") < order.index("inner-acquired")