Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 349ca05e26 | |||
| 7f3c0fdeb2 | |||
| 8e431e2076 | |||
| 89c11fd683 | |||
| 7cfe2aca99 | |||
| 3a938d2a13 | |||
| 812834bc9f | |||
| 51ff4f6e46 | |||
| 7ac169c5e8 | |||
| 61648ebe3e | |||
| 0610f0db0a | |||
| 8c935981bb | |||
| 3f3b4e4924 | |||
| af581e7f21 | |||
| 9e371ee10b | |||
| 7cf77adbc8 | |||
| 31673ee521 | |||
| ff22030dde |
@@ -83,6 +83,15 @@ astrbot
|
||||
|
||||
> Requires [uv](https://docs.astral.sh/uv/) to be installed.
|
||||
|
||||
> [!NOTE]
|
||||
> For macOS user: due to macOS security checks, the first run of the `astrbot` command may take longer (about 10-20s).
|
||||
|
||||
Update `astrbot`:
|
||||
|
||||
```bash
|
||||
uv tool upgrade astrbot
|
||||
```
|
||||
|
||||
### Docker Deployment
|
||||
|
||||
For users familiar with containers and looking for a more stable, production-ready deployment method, we recommend deploying AstrBot with Docker / Docker Compose.
|
||||
@@ -216,12 +225,15 @@ pre-commit install
|
||||
|
||||
### QQ Groups
|
||||
|
||||
- Group 9: 1076659624 (New)
|
||||
- Group 10: 1078079676 (New)
|
||||
- Group 1: 322154837
|
||||
- Group 3: 630166526
|
||||
- Group 5: 822130018
|
||||
- Group 6: 753075035
|
||||
- Group 7: 743746109
|
||||
- Group 8: 1030353265
|
||||
|
||||
- Developer Group: 975206796
|
||||
|
||||
### Discord Server
|
||||
|
||||
@@ -83,6 +83,15 @@ astrbot
|
||||
|
||||
> [uv](https://docs.astral.sh/uv/) doit être installé.
|
||||
|
||||
> [!NOTE]
|
||||
> Pour les utilisateurs macOS : en raison des vérifications de sécurité de macOS, la première exécution de la commande `astrbot` peut prendre plus de temps (environ 10-20s).
|
||||
|
||||
Mettre à jour `astrbot` :
|
||||
|
||||
```bash
|
||||
uv tool upgrade astrbot
|
||||
```
|
||||
|
||||
### Déploiement Docker
|
||||
|
||||
Pour les utilisateurs familiers avec les conteneurs et qui souhaitent une méthode plus stable et adaptée à la production, nous recommandons de déployer AstrBot avec Docker / Docker Compose.
|
||||
|
||||
@@ -83,6 +83,15 @@ astrbot
|
||||
|
||||
> [uv](https://docs.astral.sh/uv/) のインストールが必要です。
|
||||
|
||||
> [!NOTE]
|
||||
> macOS ユーザーの場合:macOS のセキュリティチェックにより、`astrbot` コマンドの初回実行に時間がかかる場合があります(約 10〜20 秒)。
|
||||
|
||||
`astrbot` の更新:
|
||||
|
||||
```bash
|
||||
uv tool upgrade astrbot
|
||||
```
|
||||
|
||||
### Docker デプロイ
|
||||
|
||||
コンテナ運用に慣れており、より安定した本番向けのデプロイ方法を求めるユーザーには、Docker / Docker Compose での AstrBot デプロイをおすすめします。
|
||||
|
||||
@@ -83,6 +83,15 @@ astrbot
|
||||
|
||||
> Требуется установленный [uv](https://docs.astral.sh/uv/).
|
||||
|
||||
> [!NOTE]
|
||||
> Для пользователей macOS: из-за проверок безопасности macOS первый запуск команды `astrbot` может занять больше времени (около 10-20 секунд).
|
||||
|
||||
Обновить `astrbot`:
|
||||
|
||||
```bash
|
||||
uv tool upgrade astrbot
|
||||
```
|
||||
|
||||
### Развёртывание Docker
|
||||
|
||||
Для пользователей, знакомых с контейнерами и которым нужен более стабильный и подходящий для production способ, мы рекомендуем разворачивать AstrBot через Docker / Docker Compose.
|
||||
|
||||
@@ -83,6 +83,15 @@ astrbot
|
||||
|
||||
> 需要安裝 [uv](https://docs.astral.sh/uv/)。
|
||||
|
||||
> [!NOTE]
|
||||
> 對於 macOS 使用者:由於 macOS 安全性檢查,首次執行 `astrbot` 指令可能需要較長時間(約 10-20 秒)。
|
||||
|
||||
更新 `astrbot`:
|
||||
|
||||
```bash
|
||||
uv tool upgrade astrbot
|
||||
```
|
||||
|
||||
### Docker 部署
|
||||
|
||||
對於熟悉容器、希望獲得更穩定且更適合正式環境部署方式的使用者,我們推薦使用 Docker / Docker Compose 部署 AstrBot。
|
||||
@@ -208,10 +217,14 @@ pre-commit install
|
||||
|
||||
### QQ 群組
|
||||
|
||||
- 9 群: 1076659624 (新)
|
||||
- 10 群: 1078079676 (新)
|
||||
- 1 群:322154837
|
||||
- 3 群:630166526
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 7 群:743746109
|
||||
- 8 群:1030353265
|
||||
- 開發者群:975206796
|
||||
|
||||
### Discord 群組
|
||||
|
||||
@@ -83,6 +83,15 @@ astrbot
|
||||
|
||||
> 需要安装 [uv](https://docs.astral.sh/uv/)。
|
||||
|
||||
> [!NOTE]
|
||||
> 对于 macOS 用户:由于 macOS 安全检查,首次运行 `astrbot` 命令可能需要较长时间(约 10-20 秒)。
|
||||
|
||||
更新 `astrbot`:
|
||||
|
||||
```bash
|
||||
uv tool upgrade astrbot
|
||||
```
|
||||
|
||||
### Docker 部署
|
||||
|
||||
对于熟悉容器、希望获得更稳定且更适合生产环境部署方式的用户,我们推荐使用 Docker / Docker Compose 部署 AstrBot。
|
||||
@@ -209,6 +218,8 @@ pre-commit install
|
||||
|
||||
### QQ 群组
|
||||
|
||||
- 9 群: 1076659624 (新)
|
||||
- 10 群: 1078079676 (新)
|
||||
- 1 群:322154837
|
||||
- 3 群:630166526
|
||||
- 5 群:822130018
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.19.1"
|
||||
__version__ = "4.19.2"
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.19.1"
|
||||
VERSION = "4.19.2"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
@@ -1123,7 +1123,7 @@ CONFIG_METADATA_2 = {
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://openrouter.ai/v1",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
|
||||
@@ -97,11 +97,7 @@ class AstrBotCoreLifecycle:
|
||||
except Exception as e:
|
||||
logger.error(f"Subagent orchestrator init failed: {e}", exc_info=True)
|
||||
|
||||
async def initialize(
|
||||
self,
|
||||
*,
|
||||
mcp_init_timeout: float | int | str | None = None,
|
||||
) -> None:
|
||||
async def initialize(self) -> None:
|
||||
"""初始化 AstrBot 核心生命周期管理类.
|
||||
|
||||
负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
|
||||
@@ -205,7 +201,7 @@ class AstrBotCoreLifecycle:
|
||||
await self.plugin_manager.reload()
|
||||
|
||||
# 根据配置实例化各个 Provider
|
||||
await self.provider_manager.initialize(init_timeout=mcp_init_timeout)
|
||||
await self.provider_manager.initialize()
|
||||
|
||||
await self.kb_manager.initialize()
|
||||
|
||||
|
||||
@@ -539,13 +539,36 @@ class Reply(BaseMessageComponent):
|
||||
|
||||
|
||||
class Poke(BaseMessageComponent):
|
||||
type: str = ComponentType.Poke
|
||||
id: int | None = 0
|
||||
qq: int | None = 0
|
||||
type: ComponentType = ComponentType.Poke
|
||||
_type: str | int = "126"
|
||||
id: int | str | None = 0
|
||||
qq: int | str | None = 0 # deprecated: legacy field, kept for compatibility
|
||||
|
||||
def __init__(self, type: str, **_) -> None:
|
||||
type = f"Poke:{type}"
|
||||
super().__init__(type=type, **_)
|
||||
def __init__(self, poke_type: str | int | None = None, **_) -> None:
|
||||
# Backward compatible with old signature: Poke(type="poke", ...)
|
||||
legacy_type = _.pop("type", None)
|
||||
if poke_type is None:
|
||||
poke_type = legacy_type
|
||||
if poke_type in (None, "", "poke", "Poke"):
|
||||
poke_type = "126"
|
||||
super().__init__(_type=str(poke_type), **_)
|
||||
|
||||
def target_id(self) -> str | None:
|
||||
"""Return normalized target id, compatible with old `qq` field."""
|
||||
for value in (self.id, self.qq):
|
||||
if value is None:
|
||||
continue
|
||||
text = str(value).strip()
|
||||
if text and text != "0":
|
||||
return text
|
||||
return None
|
||||
|
||||
def toDict(self):
|
||||
target_id = self.target_id()
|
||||
data = {"type": str(self._type or "126")}
|
||||
if target_id:
|
||||
data["id"] = target_id
|
||||
return {"type": "poke", "data": data}
|
||||
|
||||
|
||||
class Forward(BaseMessageComponent):
|
||||
|
||||
@@ -28,7 +28,7 @@ class RespondStage(Stage):
|
||||
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
|
||||
Comp.Image: lambda comp: bool(comp.file), # 图片
|
||||
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
|
||||
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
|
||||
Comp.Poke: lambda comp: comp.target_id() is not None, # 戳一戳
|
||||
Comp.Node: lambda comp: bool(comp.content), # 转发节点
|
||||
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
|
||||
Comp.File: lambda comp: bool(comp.file_ or comp.url),
|
||||
|
||||
@@ -5,7 +5,7 @@ import traceback
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import file_token_service, html_renderer, logger
|
||||
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||
from astrbot.core.message.components import At, Image, Node, Plain, Record, Reply
|
||||
from astrbot.core.message.message_event_result import ResultContentType
|
||||
from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
@@ -383,8 +383,11 @@ class ResultDecorateStage(Stage):
|
||||
)
|
||||
result.chain = [node]
|
||||
|
||||
has_plain = any(isinstance(item, Plain) for item in result.chain)
|
||||
if has_plain:
|
||||
# at 回复 / 引用回复仅适用于纯文本或图文消息
|
||||
can_decorate = all(
|
||||
isinstance(item, (Plain, Image)) for item in result.chain
|
||||
)
|
||||
if can_decorate:
|
||||
# at 回复
|
||||
if (
|
||||
self.reply_with_mention
|
||||
@@ -399,5 +402,4 @@ class ResultDecorateStage(Stage):
|
||||
|
||||
# 引用回复
|
||||
if self.reply_with_quote:
|
||||
if not any(isinstance(item, File) for item in result.chain):
|
||||
result.chain.insert(0, Reply(id=event.message_obj.message_id))
|
||||
result.chain.insert(0, Reply(id=event.message_obj.message_id))
|
||||
|
||||
@@ -191,7 +191,7 @@ class AiocqhttpAdapter(Platform):
|
||||
|
||||
if "sub_type" in event:
|
||||
if event["sub_type"] == "poke" and "target_id" in event:
|
||||
abm.message.append(Poke(qq=str(event["target_id"]), type="poke"))
|
||||
abm.message.append(Poke(id=str(event["target_id"])))
|
||||
|
||||
return abm
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from dingtalk_stream import AckMessage
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import At, Image, Plain, Record, Video
|
||||
from astrbot.api.message_components import At, File, Image, Plain, Record, Video
|
||||
from astrbot.api.platform import (
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
@@ -178,29 +178,110 @@ class DingtalkPlatformAdapter(Platform):
|
||||
abm.session_id = abm.sender.user_id
|
||||
|
||||
message_type: str = cast(str, message.message_type)
|
||||
robot_code = cast(str, message.robot_code or "")
|
||||
raw_content = cast(dict, message.extensions.get("content") or {})
|
||||
if not isinstance(raw_content, dict):
|
||||
raw_content = {}
|
||||
match message_type:
|
||||
case "text":
|
||||
abm.message_str = message.text.content.strip()
|
||||
abm.message.append(Plain(abm.message_str))
|
||||
case "picture":
|
||||
if not robot_code:
|
||||
logger.error("钉钉图片消息解析失败: 回调中缺少 robotCode")
|
||||
await self._remember_sender_binding(message, abm)
|
||||
return abm
|
||||
image_content = cast(
|
||||
dingtalk_stream.ImageContent | None,
|
||||
message.image_content,
|
||||
)
|
||||
download_code = cast(
|
||||
str, (image_content.download_code if image_content else "") or ""
|
||||
)
|
||||
if not download_code:
|
||||
logger.warning("钉钉图片消息缺少 downloadCode,已跳过")
|
||||
else:
|
||||
f_path = await self.download_ding_file(
|
||||
download_code,
|
||||
robot_code,
|
||||
"jpg",
|
||||
)
|
||||
if f_path:
|
||||
abm.message.append(Image.fromFileSystem(f_path))
|
||||
else:
|
||||
logger.warning("钉钉图片消息下载失败,无法解析为图片")
|
||||
case "richText":
|
||||
rtc: dingtalk_stream.RichTextContent = cast(
|
||||
dingtalk_stream.RichTextContent, message.rich_text_content
|
||||
)
|
||||
contents: list[dict] = cast(list[dict], rtc.rich_text_list)
|
||||
plain_parts: list[str] = []
|
||||
for content in contents:
|
||||
plains = ""
|
||||
if "text" in content:
|
||||
plains += content["text"]
|
||||
abm.message.append(Plain(plains))
|
||||
plain_text = cast(str, content.get("text") or "")
|
||||
if plain_text:
|
||||
plain_parts.append(plain_text)
|
||||
abm.message.append(Plain(plain_text))
|
||||
elif "type" in content and content["type"] == "picture":
|
||||
download_code = cast(str, content.get("downloadCode") or "")
|
||||
if not download_code:
|
||||
logger.warning(
|
||||
"钉钉富文本图片消息缺少 downloadCode,已跳过"
|
||||
)
|
||||
continue
|
||||
if not robot_code:
|
||||
logger.error(
|
||||
"钉钉富文本图片消息解析失败: 回调中缺少 robotCode"
|
||||
)
|
||||
continue
|
||||
f_path = await self.download_ding_file(
|
||||
content["downloadCode"],
|
||||
cast(str, message.robot_code),
|
||||
download_code,
|
||||
robot_code,
|
||||
"jpg",
|
||||
)
|
||||
abm.message.append(Image.fromFileSystem(f_path))
|
||||
case "audio":
|
||||
pass
|
||||
if f_path:
|
||||
abm.message.append(Image.fromFileSystem(f_path))
|
||||
abm.message_str = "".join(plain_parts).strip()
|
||||
case "audio" | "voice":
|
||||
download_code = cast(str, raw_content.get("downloadCode") or "")
|
||||
if not download_code:
|
||||
logger.warning("钉钉语音消息缺少 downloadCode,已跳过")
|
||||
elif not robot_code:
|
||||
logger.error("钉钉语音消息解析失败: 回调中缺少 robotCode")
|
||||
else:
|
||||
voice_ext = cast(str, raw_content.get("fileExtension") or "")
|
||||
if not voice_ext:
|
||||
voice_ext = "amr"
|
||||
voice_ext = voice_ext.lstrip(".")
|
||||
f_path = await self.download_ding_file(
|
||||
download_code,
|
||||
robot_code,
|
||||
voice_ext,
|
||||
)
|
||||
if f_path:
|
||||
abm.message.append(Record.fromFileSystem(f_path))
|
||||
case "file":
|
||||
download_code = cast(str, raw_content.get("downloadCode") or "")
|
||||
if not download_code:
|
||||
logger.warning("钉钉文件消息缺少 downloadCode,已跳过")
|
||||
elif not robot_code:
|
||||
logger.error("钉钉文件消息解析失败: 回调中缺少 robotCode")
|
||||
else:
|
||||
file_name = cast(str, raw_content.get("fileName") or "")
|
||||
file_ext = Path(file_name).suffix.lstrip(".") if file_name else ""
|
||||
if not file_ext:
|
||||
file_ext = cast(str, raw_content.get("fileExtension") or "")
|
||||
if not file_ext:
|
||||
file_ext = "file"
|
||||
f_path = await self.download_ding_file(
|
||||
download_code,
|
||||
robot_code,
|
||||
file_ext,
|
||||
)
|
||||
if f_path:
|
||||
if not file_name:
|
||||
file_name = Path(f_path).name
|
||||
abm.message.append(File(name=file_name, file=f_path))
|
||||
|
||||
await self._remember_sender_binding(message, abm)
|
||||
return abm # 别忘了返回转换后的消息对象
|
||||
@@ -270,7 +351,17 @@ class DingtalkPlatformAdapter(Platform):
|
||||
)
|
||||
return ""
|
||||
resp_data = await resp.json()
|
||||
download_url = resp_data["data"]["downloadUrl"]
|
||||
download_url = cast(
|
||||
str,
|
||||
(
|
||||
resp_data.get("downloadUrl")
|
||||
or resp_data.get("data", {}).get("downloadUrl")
|
||||
or ""
|
||||
),
|
||||
)
|
||||
if not download_url:
|
||||
logger.error(f"下载钉钉文件失败: 未找到 downloadUrl, 响应: {resp_data}")
|
||||
return ""
|
||||
await download_file(download_url, str(f_path))
|
||||
return str(f_path)
|
||||
|
||||
@@ -541,6 +632,28 @@ class DingtalkPlatformAdapter(Platform):
|
||||
self._safe_remove_file(cover_path)
|
||||
if converted_video:
|
||||
self._safe_remove_file(video_path)
|
||||
elif isinstance(segment, File):
|
||||
try:
|
||||
file_path = await segment.get_file()
|
||||
if not file_path:
|
||||
logger.warning("钉钉文件发送失败: 无法解析文件路径")
|
||||
continue
|
||||
media_id = await self.upload_media(file_path, "file")
|
||||
if not media_id:
|
||||
continue
|
||||
file_name = segment.name or Path(file_path).name
|
||||
file_type = Path(file_name).suffix.lstrip(".")
|
||||
await send_message(
|
||||
msg_key="sampleFile",
|
||||
msg_param={
|
||||
"mediaId": media_id,
|
||||
"fileName": file_name,
|
||||
"fileType": file_type,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"钉钉文件发送失败: {e}")
|
||||
continue
|
||||
|
||||
async def send_message_chain_to_group(
|
||||
self,
|
||||
|
||||
@@ -346,10 +346,7 @@ class FunctionToolManager:
|
||||
logger.debug(f" 主机: {scheme}://{host}{port}")
|
||||
|
||||
async def init_mcp_clients(
|
||||
self,
|
||||
raise_on_all_failed: bool = False,
|
||||
*,
|
||||
init_timeout: float | int | str | None = None,
|
||||
self, raise_on_all_failed: bool = False
|
||||
) -> MCPInitSummary:
|
||||
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
|
||||
```
|
||||
@@ -370,7 +367,6 @@ class FunctionToolManager:
|
||||
```
|
||||
|
||||
Timeout behavior:
|
||||
- 显式 `init_timeout` 参数优先(用于测试或调用方覆盖)。
|
||||
- 初始化超时使用环境变量 ASTRBOT_MCP_INIT_TIMEOUT 或默认值。
|
||||
- 动态启用超时使用 ASTRBOT_MCP_ENABLE_TIMEOUT(独立于初始化超时)。
|
||||
"""
|
||||
@@ -387,12 +383,8 @@ class FunctionToolManager:
|
||||
with open(mcp_json_file, encoding="utf-8") as f:
|
||||
mcp_server_json_obj: dict[str, dict] = json.load(f)["mcpServers"]
|
||||
|
||||
init_timeout_value = _resolve_timeout(
|
||||
timeout=init_timeout,
|
||||
env_name=MCP_INIT_TIMEOUT_ENV,
|
||||
default=self._init_timeout_default,
|
||||
)
|
||||
timeout_display = f"{init_timeout_value:g}"
|
||||
init_timeout = self._init_timeout_default
|
||||
timeout_display = f"{init_timeout:g}"
|
||||
|
||||
active_configs: list[tuple[str, dict, asyncio.Event]] = []
|
||||
for name, cfg in mcp_server_json_obj.items():
|
||||
@@ -411,7 +403,7 @@ class FunctionToolManager:
|
||||
name=name,
|
||||
cfg=cfg,
|
||||
shutdown_event=shutdown_event,
|
||||
timeout_seconds=init_timeout_value,
|
||||
timeout=init_timeout,
|
||||
),
|
||||
name=f"mcp-init:{name}",
|
||||
)
|
||||
|
||||
@@ -269,11 +269,7 @@ class ProviderManager:
|
||||
|
||||
return provider
|
||||
|
||||
async def initialize(
|
||||
self,
|
||||
*,
|
||||
init_timeout: float | int | str | None = None,
|
||||
) -> None:
|
||||
async def initialize(self) -> None:
|
||||
# 逐个初始化提供商
|
||||
for provider_config in self.providers_config:
|
||||
try:
|
||||
@@ -342,8 +338,7 @@ class ProviderManager:
|
||||
"on",
|
||||
}
|
||||
mcp_init_summary = await self.llm_tools.init_mcp_clients(
|
||||
raise_on_all_failed=strict_mcp_init,
|
||||
init_timeout=init_timeout,
|
||||
raise_on_all_failed=strict_mcp_init
|
||||
)
|
||||
if (
|
||||
mcp_init_summary.total > 0
|
||||
|
||||
@@ -26,6 +26,13 @@ _SANDBOX_SKILLS_CACHE_VERSION = 1
|
||||
_SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$")
|
||||
|
||||
|
||||
def _is_ignored_zip_entry(name: str) -> bool:
|
||||
parts = PurePosixPath(name).parts
|
||||
if not parts:
|
||||
return True
|
||||
return parts[0] == "__MACOSX"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillInfo:
|
||||
name: str
|
||||
@@ -401,7 +408,11 @@ class SkillManager:
|
||||
raise ValueError("Uploaded file is not a valid zip archive.")
|
||||
|
||||
with zipfile.ZipFile(zip_path) as zf:
|
||||
names = [name.replace("\\", "/") for name in zf.namelist()]
|
||||
names = [
|
||||
name
|
||||
for name in (entry.replace("\\", "/") for entry in zf.namelist())
|
||||
if name and not _is_ignored_zip_entry(name)
|
||||
]
|
||||
file_names = [name for name in names if name and not name.endswith("/")]
|
||||
if not file_names:
|
||||
raise ValueError("Zip archive is empty.")
|
||||
@@ -436,7 +447,11 @@ class SkillManager:
|
||||
raise ValueError("SKILL.md not found in the skill folder.")
|
||||
|
||||
with tempfile.TemporaryDirectory(dir=get_astrbot_temp_path()) as tmp_dir:
|
||||
zf.extractall(tmp_dir)
|
||||
for member in zf.infolist():
|
||||
member_name = member.filename.replace("\\", "/")
|
||||
if not member_name or _is_ignored_zip_entry(member_name):
|
||||
continue
|
||||
zf.extract(member, tmp_dir)
|
||||
src_dir = Path(tmp_dir) / skill_name
|
||||
if not src_dir.exists():
|
||||
raise ValueError("Skill folder not found after extraction.")
|
||||
|
||||
@@ -15,4 +15,4 @@ class RegexFilter(HandlerFilter):
|
||||
self.regex = re.compile(regex)
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
return bool(self.regex.match(event.get_message_str().strip()))
|
||||
return bool(self.regex.search(event.get_message_str().strip()))
|
||||
|
||||
@@ -7,4 +7,4 @@ def is_frozen_runtime() -> bool:
|
||||
|
||||
|
||||
def is_packaged_desktop_runtime() -> bool:
|
||||
return is_frozen_runtime() and os.environ.get("ASTRBOT_DESKTOP_CLIENT") == "1"
|
||||
return os.environ.get("ASTRBOT_DESKTOP_CLIENT") == "1"
|
||||
|
||||
@@ -610,6 +610,7 @@ class ConfigRoute(Route):
|
||||
|
||||
try:
|
||||
conf_id = self.acm.create_conf(name=name, config=config)
|
||||
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)
|
||||
return Response().ok(message="创建成功", data={"conf_id": conf_id}).__dict__
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
@@ -649,6 +650,7 @@ class ConfigRoute(Route):
|
||||
try:
|
||||
success = self.acm.delete_conf(conf_id)
|
||||
if success:
|
||||
self.core_lifecycle.pipeline_scheduler_mapping.pop(conf_id, None)
|
||||
return Response().ok(message="删除成功").__dict__
|
||||
return Response().error("删除失败").__dict__
|
||||
except ValueError as e:
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
import re
|
||||
import shutil
|
||||
import traceback
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -50,6 +51,7 @@ class SkillsRoute(Route):
|
||||
self.routes = {
|
||||
"/skills": ("GET", self.get_skills),
|
||||
"/skills/upload": ("POST", self.upload_skill),
|
||||
"/skills/batch-upload": ("POST", self.batch_upload_skills),
|
||||
"/skills/download": ("GET", self.download_skill),
|
||||
"/skills/update": ("POST", self.update_skill),
|
||||
"/skills/delete": ("POST", self.delete_skill),
|
||||
@@ -188,6 +190,114 @@ class SkillsRoute(Route):
|
||||
except Exception:
|
||||
logger.warning(f"Failed to remove temp skill file: {temp_path}")
|
||||
|
||||
async def batch_upload_skills(self):
|
||||
"""批量上传多个 skill ZIP 文件"""
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
try:
|
||||
files = await request.files
|
||||
file_list = files.getlist("files")
|
||||
|
||||
if not file_list:
|
||||
return Response().error("No files provided").__dict__
|
||||
|
||||
succeeded = []
|
||||
failed = []
|
||||
skill_mgr = SkillManager()
|
||||
temp_dir = get_astrbot_temp_path()
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
for file in file_list:
|
||||
filename = os.path.basename(file.filename or "unknown.zip")
|
||||
temp_path = None
|
||||
|
||||
try:
|
||||
if not filename.lower().endswith(".zip"):
|
||||
failed.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"error": "Only .zip files are supported",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
temp_path = os.path.join(
|
||||
temp_dir, f"batch_{uuid.uuid4().hex}_{filename}"
|
||||
)
|
||||
await file.save(temp_path)
|
||||
|
||||
skill_name = skill_mgr.install_skill_from_zip(
|
||||
temp_path, overwrite=True
|
||||
)
|
||||
succeeded.append({"filename": filename, "name": skill_name})
|
||||
|
||||
except Exception as e:
|
||||
failed.append({"filename": filename, "error": str(e)})
|
||||
finally:
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if succeeded:
|
||||
try:
|
||||
await sync_skills_to_active_sandboxes()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to sync uploaded skills to active sandboxes."
|
||||
)
|
||||
|
||||
total = len(file_list)
|
||||
success_count = len(succeeded)
|
||||
|
||||
if success_count == total:
|
||||
message = f"All {total} skill(s) uploaded successfully."
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"total": total,
|
||||
"succeeded": succeeded,
|
||||
"failed": failed,
|
||||
},
|
||||
message,
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
if success_count == 0:
|
||||
message = f"Upload failed for all {total} file(s)."
|
||||
resp = Response().error(message)
|
||||
resp.data = {
|
||||
"total": total,
|
||||
"succeeded": succeeded,
|
||||
"failed": failed,
|
||||
}
|
||||
return resp.__dict__
|
||||
|
||||
message = f"Partial success: {success_count}/{total} skill(s) uploaded."
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"total": total,
|
||||
"succeeded": succeeded,
|
||||
"failed": failed,
|
||||
},
|
||||
message,
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def download_skill(self):
|
||||
try:
|
||||
name = str(request.args.get("name") or "").strip()
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
### 新增
|
||||
|
||||
- 集成 KOOK 平台适配器 ([#5658](https://github.com/AstrBotDevs/AstrBot/pull/5658))。
|
||||
- 集成 DeerFlow Agent Runner 并优化流式处理 ([#5581](https://github.com/AstrBotDevs/AstrBot/pull/5581))。
|
||||
- 新增 Discord pre-react Emoji 支持 ([#5609](https://github.com/AstrBotDevs/AstrBot/pull/5609))。
|
||||
- 新增 Telegram 支持 `sendMessageDraft` 流式实时输出 API ([#5726](https://github.com/AstrBotDevs/AstrBot/issues/5726))
|
||||
- 支持在 Agent 运行时进行消息跟进能力,跟进的消息实时注入给 Agent ([#5484](https://github.com/AstrBotDevs/AstrBot/pull/5484))。
|
||||
- 集成 DeerFlow Agent Runner 并优化流式处理 ([#5581](https://github.com/AstrBotDevs/AstrBot/pull/5581))。
|
||||
- 新增 shell, ipython tool 中包含操作系统信息,提高 windows 下 tool call 成功率 ([#5677](https://github.com/AstrBotDevs/AstrBot/pull/5677))。
|
||||
- Sandbox 支持 Shipyard-neo - 支持 Skills 自迭代 ([#5028](https://github.com/AstrBotDevs/AstrBot/pull/5028))。
|
||||
- 新增 ChatUI WebSocket 传输模式选择,OpenAPI Chat API 支持 WebSocket 连接 ([#5410](https://github.com/AstrBotDevs/AstrBot/pull/5410))。
|
||||
File diff suppressed because it is too large
Load Diff
@@ -224,10 +224,43 @@
|
||||
"empty": "No Skills found",
|
||||
"emptyHint": "Upload a Skills zip to get started",
|
||||
"uploadDialogTitle": "Upload Skills",
|
||||
"uploadHint": "Upload a zip file that contains skill_name/ and a SKILL.md inside.",
|
||||
"uploadHint": "Upload multiple zip skill packages or drag them in. The system validates the structure automatically and shows a result for each file.",
|
||||
"structureRequirement": "The most common failure is an invalid archive structure. Each zip must contain exactly one top-level folder such as `skillname/`, and that folder must include `SKILL.md`.",
|
||||
"abilityMultiple": "Upload multiple zip files at once",
|
||||
"abilityValidate": "Validate `SKILL.md` automatically",
|
||||
"abilitySkip": "Automatically skip duplicate files.",
|
||||
"selectFile": "Select file",
|
||||
"confirmUpload": "Upload",
|
||||
"selectFiles": "Select files (multiple allowed)",
|
||||
"dropzoneTitle": "Drag multiple zip files here",
|
||||
"dropzoneAction": "or click to pick multiple files from a folder",
|
||||
"dropzoneHint": "Batch upload is supported and the structure will be validated automatically",
|
||||
"fileListTitle": "Files in queue",
|
||||
"fileListEmpty": "Selected files will appear here with validation feedback and upload status",
|
||||
"uploading": "Uploading...",
|
||||
"batchResultTitle": "Batch Upload Results",
|
||||
"batchResultSummary": "{success} of {total} files uploaded successfully",
|
||||
"batchSuccessList": "Successfully uploaded",
|
||||
"batchFailedList": "Failed to upload",
|
||||
"confirm": "OK",
|
||||
"confirmUpload": "Start Upload",
|
||||
"cancel": "Cancel",
|
||||
"statusWaiting": "Waiting",
|
||||
"statusUploading": "Uploading",
|
||||
"statusSuccess": "Uploaded",
|
||||
"statusError": "Failed",
|
||||
"statusSkipped": "Skipped",
|
||||
"summaryTotal": "{count} file(s)",
|
||||
"summaryReady": "Pending {count}",
|
||||
"summarySuccess": "Success {count}",
|
||||
"summaryFailed": "Failed {count}",
|
||||
"summarySkipped": "Skipped {count}",
|
||||
"validationReady": "Ready to upload. The archive structure will be checked during upload.",
|
||||
"validationZipOnly": "Only zip skill packages are supported",
|
||||
"validationDuplicate": "A file with the same name is already in the queue and has been skipped",
|
||||
"validationUploading": "Validating and uploading...",
|
||||
"validationUploadFailed": "Upload failed. Please try again.",
|
||||
"validationUploadedAs": "Installed as {name}",
|
||||
"validationNoResult": "No validation result was returned. Check the platform logs.",
|
||||
"noDescription": "No description",
|
||||
"path": "Path",
|
||||
"uploadSuccess": "Upload succeeded",
|
||||
|
||||
@@ -224,10 +224,43 @@
|
||||
"empty": "暂无 Skills",
|
||||
"emptyHint": "请上传 Skills 压缩包",
|
||||
"uploadDialogTitle": "上传 Skills",
|
||||
"uploadHint": "请上传 zip 压缩包,解压后为 skill_name/ 目录,且包含 SKILL.md",
|
||||
"uploadHint": "支持批量上传 zip 技能包,也支持拖拽批量上传 zip 技能包。系统会自动校验目录结构,并给出逐个文件的结果。",
|
||||
"structureRequirement": "常见失败原因是压缩包结构不正确。每个 zip 必须只包含一个顶层目录,例如 `skillname/`,且该目录下必须存在 `SKILL.md`。",
|
||||
"abilityMultiple": "支持一次上传多个zip文件",
|
||||
"abilityValidate": "自动校验 `SKILL.md`",
|
||||
"abilitySkip": "自动跳过重复文件",
|
||||
"selectFile": "选择文件",
|
||||
"confirmUpload": "上传",
|
||||
"selectFiles": "选择文件(可多选)",
|
||||
"dropzoneTitle": "拖拽多个 zip 文件到这里",
|
||||
"dropzoneAction": "或者点击之后在文件夹中选择多个文件",
|
||||
"dropzoneHint": "支持批量上传,系统会自动校验目录结构",
|
||||
"fileListTitle": "待处理文件",
|
||||
"fileListEmpty": "选择文件后会在这里显示校验结果与上传状态",
|
||||
"uploading": "正在上传...",
|
||||
"batchResultTitle": "批量上传结果",
|
||||
"batchResultSummary": "共 {total} 个文件,成功 {success} 个",
|
||||
"batchSuccessList": "上传成功",
|
||||
"batchFailedList": "上传失败",
|
||||
"confirm": "确定",
|
||||
"confirmUpload": "开始上传",
|
||||
"cancel": "取消",
|
||||
"statusWaiting": "待上传",
|
||||
"statusUploading": "上传中",
|
||||
"statusSuccess": "已上传",
|
||||
"statusError": "校验失败",
|
||||
"statusSkipped": "已跳过",
|
||||
"summaryTotal": "共 {count} 个文件",
|
||||
"summaryReady": "待处理 {count}",
|
||||
"summarySuccess": "成功 {count}",
|
||||
"summaryFailed": "失败 {count}",
|
||||
"summarySkipped": "跳过 {count}",
|
||||
"validationReady": "等待上传,上传时会自动校验目录结构",
|
||||
"validationZipOnly": "仅支持 zip 技能包",
|
||||
"validationDuplicate": "同名文件已在列表中,已跳过",
|
||||
"validationUploading": "正在校验并上传...",
|
||||
"validationUploadFailed": "上传失败,请重试",
|
||||
"validationUploadedAs": "已安装为 {name}",
|
||||
"validationNoResult": "未收到校验结果,请检查平台日志",
|
||||
"noDescription": "无描述",
|
||||
"path": "路径",
|
||||
"uploadSuccess": "上传成功",
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
import { pinyin } from "pinyin-pro";
|
||||
|
||||
const HAN_IDEOGRAPH_RE = /\p{Unified_Ideograph}/u;
|
||||
|
||||
export const normalizeStr = (s) => (s ?? "").toString().toLowerCase().trim();
|
||||
|
||||
const normalizeLooseFromNormalized = (normalized) =>
|
||||
normalized.replace(/[\s_-]+/g, "").replace(/[()()【】\[\]{}·•]+/g, "");
|
||||
|
||||
export const normalizeLoose = (s) =>
|
||||
normalizeLooseFromNormalized(normalizeStr(s));
|
||||
|
||||
const memoizeStringFn = (fn) => {
|
||||
const cache = new Map();
|
||||
|
||||
return (raw) => {
|
||||
const key = (raw ?? "").toString();
|
||||
if (cache.has(key)) {
|
||||
return cache.get(key);
|
||||
}
|
||||
|
||||
const value = fn(key);
|
||||
cache.set(key, value);
|
||||
return value;
|
||||
};
|
||||
};
|
||||
|
||||
const getNormalizedText = memoizeStringFn(normalizeStr);
|
||||
|
||||
const getLooseText = memoizeStringFn((text) =>
|
||||
normalizeLooseFromNormalized(getNormalizedText(text)),
|
||||
);
|
||||
|
||||
export const toPinyinText = memoizeStringFn((text) =>
|
||||
pinyin(text, { toneType: "none" })
|
||||
.toLowerCase()
|
||||
.replace(/\s+/g, ""),
|
||||
);
|
||||
|
||||
export const toInitials = memoizeStringFn((text) =>
|
||||
pinyin(text, { pattern: "first", toneType: "none" })
|
||||
.toLowerCase()
|
||||
.replace(/\s+/g, ""),
|
||||
);
|
||||
|
||||
export const buildSearchQuery = (raw) => {
|
||||
const norm = getNormalizedText(raw);
|
||||
if (!norm) return null;
|
||||
return {
|
||||
norm,
|
||||
loose: getLooseText(raw),
|
||||
};
|
||||
};
|
||||
|
||||
export const matchesText = (value, query) => {
|
||||
if (value == null || !query?.norm) return false;
|
||||
const text = String(value);
|
||||
|
||||
const normalizedValue = getNormalizedText(text);
|
||||
const looseValue = query.loose ? getLooseText(text) : null;
|
||||
|
||||
if (normalizedValue.includes(query.norm)) return true;
|
||||
if (query.loose && looseValue?.includes(query.loose)) return true;
|
||||
|
||||
if (!HAN_IDEOGRAPH_RE.test(text)) return false;
|
||||
|
||||
const pinyinValue = toPinyinText(text);
|
||||
if (pinyinValue.includes(query.norm)) return true;
|
||||
|
||||
const initialsValue = toInitials(text);
|
||||
if (initialsValue.includes(query.norm)) return true;
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
export const getPluginSearchFields = (plugin) => {
|
||||
const supportPlatforms = Array.isArray(plugin?.support_platforms)
|
||||
? plugin.support_platforms.join(" ")
|
||||
: "";
|
||||
const tags = Array.isArray(plugin?.tags) ? plugin.tags.join(" ") : "";
|
||||
|
||||
return [
|
||||
plugin?.name,
|
||||
plugin?.trimmedName,
|
||||
plugin?.display_name,
|
||||
plugin?.desc,
|
||||
plugin?.author,
|
||||
plugin?.repo,
|
||||
plugin?.version,
|
||||
plugin?.astrbot_version,
|
||||
supportPlatforms,
|
||||
tags,
|
||||
];
|
||||
};
|
||||
|
||||
export const matchesPluginSearch = (plugin, query) => {
|
||||
if (!query) return true;
|
||||
|
||||
return getPluginSearchFields(plugin).some((candidate) =>
|
||||
matchesText(candidate, query),
|
||||
);
|
||||
};
|
||||
@@ -84,7 +84,6 @@ const {
|
||||
normalizeStr,
|
||||
toPinyinText,
|
||||
toInitials,
|
||||
marketCustomFilter,
|
||||
plugin_handler_info_headers,
|
||||
pluginHeaders,
|
||||
filteredExtensions,
|
||||
|
||||
@@ -81,7 +81,6 @@ const {
|
||||
normalizeStr,
|
||||
toPinyinText,
|
||||
toInitials,
|
||||
marketCustomFilter,
|
||||
plugin_handler_info_headers,
|
||||
pluginHeaders,
|
||||
filteredExtensions,
|
||||
|
||||
@@ -82,7 +82,6 @@ const {
|
||||
normalizeStr,
|
||||
toPinyinText,
|
||||
toInitials,
|
||||
marketCustomFilter,
|
||||
plugin_handler_info_headers,
|
||||
pluginHeaders,
|
||||
filteredExtensions,
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
import axios from "axios";
|
||||
import { pinyin } from "pinyin-pro";
|
||||
import { useCommonStore } from "@/stores/common";
|
||||
import { useI18n, useModuleI18n } from "@/i18n/composables";
|
||||
import { getPlatformDisplayName } from "@/utils/platformUtils";
|
||||
import { resolveErrorMessage } from "@/utils/errorUtils";
|
||||
import {
|
||||
buildSearchQuery,
|
||||
matchesPluginSearch,
|
||||
normalizeStr,
|
||||
toInitials,
|
||||
toPinyinText,
|
||||
} from "@/utils/pluginSearch";
|
||||
import { ref, computed, onMounted, onUnmounted, reactive, watch } from "vue";
|
||||
import { useRoute, useRouter } from "vue-router";
|
||||
import { useDisplay } from "vuetify";
|
||||
@@ -240,37 +246,6 @@ export const useExtensionPage = () => {
|
||||
});
|
||||
|
||||
// 插件市场拼音搜索
|
||||
const normalizeStr = (s) => (s ?? "").toString().toLowerCase().trim();
|
||||
const toPinyinText = (s) =>
|
||||
pinyin(s ?? "", { toneType: "none" })
|
||||
.toLowerCase()
|
||||
.replace(/\s+/g, "");
|
||||
const toInitials = (s) =>
|
||||
pinyin(s ?? "", { pattern: "first", toneType: "none" })
|
||||
.toLowerCase()
|
||||
.replace(/\s+/g, "");
|
||||
const marketCustomFilter = (value, query, item) => {
|
||||
const q = normalizeStr(query);
|
||||
if (!q) return true;
|
||||
|
||||
const candidates = new Set();
|
||||
if (value != null) candidates.add(String(value));
|
||||
if (item?.name) candidates.add(String(item.name));
|
||||
if (item?.trimmedName) candidates.add(String(item.trimmedName));
|
||||
if (item?.display_name) candidates.add(String(item.display_name));
|
||||
if (item?.desc) candidates.add(String(item.desc));
|
||||
if (item?.author) candidates.add(String(item.author));
|
||||
|
||||
for (const v of candidates) {
|
||||
const nv = normalizeStr(v);
|
||||
if (nv.includes(q)) return true;
|
||||
const pv = toPinyinText(v);
|
||||
if (pv.includes(q)) return true;
|
||||
const iv = toInitials(v);
|
||||
if (iv.includes(q)) return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
const plugin_handler_info_headers = computed(() => [
|
||||
{ title: tm("table.headers.eventType"), key: "event_type_h" },
|
||||
@@ -347,47 +322,24 @@ export const useExtensionPage = () => {
|
||||
// 通过搜索过滤插件
|
||||
const filteredPlugins = computed(() => {
|
||||
const plugins = filteredExtensions.value;
|
||||
let filtered = plugins;
|
||||
|
||||
if (pluginSearch.value) {
|
||||
const search = pluginSearch.value.toLowerCase();
|
||||
filtered = plugins.filter((plugin) => {
|
||||
const pluginName = (plugin.name ?? "").toLowerCase();
|
||||
const pluginDesc = (plugin.desc ?? "").toLowerCase();
|
||||
const pluginAuthor = (plugin.author ?? "").toLowerCase();
|
||||
const supportPlatforms = Array.isArray(plugin.support_platforms)
|
||||
? plugin.support_platforms.join(" ").toLowerCase()
|
||||
: "";
|
||||
const astrbotVersion = (plugin.astrbot_version ?? "").toLowerCase();
|
||||
|
||||
return (
|
||||
pluginName.includes(search) ||
|
||||
pluginDesc.includes(search) ||
|
||||
pluginAuthor.includes(search) ||
|
||||
supportPlatforms.includes(search) ||
|
||||
astrbotVersion.includes(search)
|
||||
);
|
||||
});
|
||||
}
|
||||
const query = buildSearchQuery(pluginSearch.value);
|
||||
const filtered = query
|
||||
? plugins.filter((plugin) => matchesPluginSearch(plugin, query))
|
||||
: plugins;
|
||||
|
||||
return sortPluginsByName([...filtered]);
|
||||
});
|
||||
|
||||
// 过滤后的插件市场数据(带搜索)
|
||||
const filteredMarketPlugins = computed(() => {
|
||||
if (!debouncedMarketSearch.value) {
|
||||
const query = buildSearchQuery(debouncedMarketSearch.value);
|
||||
if (!query) {
|
||||
return pluginMarketData.value;
|
||||
}
|
||||
|
||||
const search = debouncedMarketSearch.value.toLowerCase();
|
||||
return pluginMarketData.value.filter((plugin) => {
|
||||
// 使用自定义过滤器
|
||||
return (
|
||||
marketCustomFilter(plugin.name, search, plugin) ||
|
||||
marketCustomFilter(plugin.desc, search, plugin) ||
|
||||
marketCustomFilter(plugin.author, search, plugin)
|
||||
);
|
||||
});
|
||||
|
||||
return pluginMarketData.value.filter((plugin) =>
|
||||
matchesPluginSearch(plugin, query),
|
||||
);
|
||||
});
|
||||
|
||||
// 所有插件列表,推荐插件排在前面
|
||||
@@ -1563,7 +1515,6 @@ export const useExtensionPage = () => {
|
||||
normalizeStr,
|
||||
toPinyinText,
|
||||
toInitials,
|
||||
marketCustomFilter,
|
||||
plugin_handler_info_headers,
|
||||
pluginHeaders,
|
||||
filteredExtensions,
|
||||
|
||||
@@ -5,6 +5,10 @@ import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import runtime_bootstrap
|
||||
|
||||
runtime_bootstrap.initialize_runtime_bootstrap()
|
||||
|
||||
from astrbot.core import LogBroker, LogManager, db_helper, logger # noqa: E402
|
||||
from astrbot.core.config.default import VERSION # noqa: E402
|
||||
from astrbot.core.initial_loader import InitialLoader # noqa: E402
|
||||
|
||||
+3
-3
@@ -1,9 +1,9 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "4.19.1"
|
||||
version = "4.19.2"
|
||||
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
requires-python = ">=3.12"
|
||||
|
||||
keywords = ["Astrbot", "Astrbot Module", "Astrbot Plugin"]
|
||||
|
||||
@@ -61,7 +61,7 @@ dependencies = [
|
||||
"xinference-client",
|
||||
"tenacity>=9.1.2",
|
||||
"shipyard-python-sdk>=0.2.4",
|
||||
"shipyard-neo-sdk @ git+https://github.com/AstrBotDevs/shipyard-neo.git#subdirectory=shipyard-neo-sdk",
|
||||
"shipyard-neo-sdk>=0.2.0",
|
||||
"python-socks>=2.8.0",
|
||||
"packaging>=24.2",
|
||||
]
|
||||
|
||||
+1
-1
@@ -54,5 +54,5 @@ markitdown-no-magika[docx,xls,xlsx]>=0.1.2
|
||||
xinference-client
|
||||
tenacity>=9.1.2
|
||||
shipyard-python-sdk>=0.2.4
|
||||
shipyard-neo-sdk @ git+https://github.com/AstrBotDevs/shipyard-neo.git#subdirectory=shipyard-neo-sdk
|
||||
shipyard-neo-sdk>=0.2.0
|
||||
packaging>=24.2
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
import logging
|
||||
import ssl
|
||||
from typing import Any
|
||||
|
||||
import aiohttp.connector as aiohttp_connector
|
||||
|
||||
from astrbot.utils.http_ssl_common import build_ssl_context_with_certifi
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _try_patch_aiohttp_ssl_context(
|
||||
ssl_context: ssl.SSLContext,
|
||||
log_obj: Any | None = None,
|
||||
) -> bool:
|
||||
log = log_obj or logger
|
||||
attr_name = "_SSL_CONTEXT_VERIFIED"
|
||||
|
||||
if not hasattr(aiohttp_connector, attr_name):
|
||||
log.warning(
|
||||
"aiohttp connector does not expose _SSL_CONTEXT_VERIFIED; skipped patch.",
|
||||
)
|
||||
return False
|
||||
|
||||
current_value = getattr(aiohttp_connector, attr_name, None)
|
||||
if current_value is not None and not isinstance(current_value, ssl.SSLContext):
|
||||
log.warning(
|
||||
"aiohttp connector exposes _SSL_CONTEXT_VERIFIED with unexpected type; skipped patch.",
|
||||
)
|
||||
return False
|
||||
|
||||
setattr(aiohttp_connector, attr_name, ssl_context)
|
||||
log.info("Configured aiohttp verified SSL context with system+certifi trust chain.")
|
||||
return True
|
||||
|
||||
|
||||
def configure_runtime_ca_bundle(log_obj: Any | None = None) -> bool:
|
||||
log = log_obj or logger
|
||||
|
||||
try:
|
||||
log.info("Bootstrapping runtime CA bundle.")
|
||||
ssl_context = build_ssl_context_with_certifi(log_obj=log)
|
||||
return _try_patch_aiohttp_ssl_context(ssl_context, log_obj=log)
|
||||
except Exception as exc:
|
||||
log.error("Failed to configure runtime CA bundle for aiohttp: %r", exc)
|
||||
return False
|
||||
|
||||
|
||||
def initialize_runtime_bootstrap(log_obj: Any | None = None) -> bool:
|
||||
return configure_runtime_ca_bundle(log_obj=log_obj)
|
||||
Vendored
-20
@@ -7,7 +7,6 @@ import shutil
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
from urllib.parse import urlparse
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from astrbot.core.message.components import BaseMessageComponent
|
||||
@@ -25,25 +24,6 @@ class NoopAwaitable:
|
||||
return None
|
||||
|
||||
|
||||
def get_bound_tcp_port(site: Any) -> int:
|
||||
"""Resolve the bound aiohttp TCP site port for tests.
|
||||
|
||||
We prefer the public ``site.name`` first. Some aiohttp test setups with
|
||||
ephemeral ports may not expose a usable port there, so we fall back to
|
||||
``site._server.sockets`` as a test-only compatibility path.
|
||||
"""
|
||||
parsed = urlparse(getattr(site, "name", ""))
|
||||
if parsed.port is not None and parsed.port > 0:
|
||||
return parsed.port
|
||||
|
||||
server = getattr(site, "_server", None)
|
||||
sockets = getattr(server, "sockets", None) if server else None
|
||||
if sockets:
|
||||
return sockets[0].getsockname()[1]
|
||||
|
||||
raise RuntimeError("Unable to resolve bound TCP port from aiohttp site")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 平台配置工厂
|
||||
# ============================================================
|
||||
|
||||
@@ -1,268 +0,0 @@
|
||||
"""Performance benchmark tests for core AstrBot execution paths.
|
||||
|
||||
Run with:
|
||||
uv run pytest tests/performance/test_benchmarks.py -q -s
|
||||
|
||||
Optional output:
|
||||
ASTRBOT_BENCHMARK_OUTPUT=/tmp/astrbot_benchmark.json
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import zipfile
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Awaitable, Callable
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from astrbot.core.backup.exporter import AstrBotExporter
|
||||
from astrbot.core.message.components import File, Image, Record
|
||||
from astrbot.core.utils.io import download_file, file_to_base64
|
||||
from tests.fixtures.helpers import get_bound_tcp_port
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class BenchmarkResult:
|
||||
name: str
|
||||
iterations: int
|
||||
warmup: int
|
||||
min_ms: float
|
||||
max_ms: float
|
||||
mean_ms: float
|
||||
p50_ms: float
|
||||
p95_ms: float
|
||||
ops_per_sec: float
|
||||
|
||||
|
||||
def _percentile(values: list[float], q: float) -> float:
|
||||
if not values:
|
||||
return 0.0
|
||||
sorted_values = sorted(values)
|
||||
if len(sorted_values) == 1:
|
||||
return sorted_values[0]
|
||||
rank = (len(sorted_values) - 1) * q
|
||||
lower = math.floor(rank)
|
||||
upper = math.ceil(rank)
|
||||
if lower == upper:
|
||||
return sorted_values[lower]
|
||||
weight = rank - lower
|
||||
return sorted_values[lower] * (1 - weight) + sorted_values[upper] * weight
|
||||
|
||||
|
||||
async def run_async_benchmark(
|
||||
name: str,
|
||||
func: Callable[[], Awaitable[None]],
|
||||
*,
|
||||
iterations: int,
|
||||
warmup: int = 5,
|
||||
) -> BenchmarkResult:
|
||||
for _ in range(warmup):
|
||||
await func()
|
||||
|
||||
samples_ms: list[float] = []
|
||||
for _ in range(iterations):
|
||||
start_ns = time.perf_counter_ns()
|
||||
await func()
|
||||
elapsed_ms = (time.perf_counter_ns() - start_ns) / 1_000_000
|
||||
samples_ms.append(elapsed_ms)
|
||||
|
||||
mean_ms = sum(samples_ms) / len(samples_ms)
|
||||
return BenchmarkResult(
|
||||
name=name,
|
||||
iterations=iterations,
|
||||
warmup=warmup,
|
||||
min_ms=min(samples_ms),
|
||||
max_ms=max(samples_ms),
|
||||
mean_ms=mean_ms,
|
||||
p50_ms=_percentile(samples_ms, 0.50),
|
||||
p95_ms=_percentile(samples_ms, 0.95),
|
||||
ops_per_sec=1000 / mean_ms if mean_ms > 0 else 0.0,
|
||||
)
|
||||
|
||||
|
||||
def _print_report(results: list[BenchmarkResult]) -> None:
|
||||
print("\nAstrBot Benchmark Report")
|
||||
print("-" * 84)
|
||||
print(
|
||||
f"{'case':35} {'iters':>7} {'mean(ms)':>10} {'p50(ms)':>10} "
|
||||
f"{'p95(ms)':>10} {'ops/s':>10}"
|
||||
)
|
||||
print("-" * 84)
|
||||
for result in results:
|
||||
print(
|
||||
f"{result.name:35} {result.iterations:7d} "
|
||||
f"{result.mean_ms:10.4f} {result.p50_ms:10.4f} "
|
||||
f"{result.p95_ms:10.4f} {result.ops_per_sec:10.1f}"
|
||||
)
|
||||
|
||||
|
||||
def _scaled_iterations(value: int) -> int:
|
||||
scale = int(os.environ.get("ASTRBOT_BENCHMARK_SCALE", "1"))
|
||||
return max(1, value * scale)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.slow
|
||||
async def test_core_performance_benchmarks(tmp_path: Path) -> None:
|
||||
"""Measure representative performance paths across core modules."""
|
||||
data = os.urandom(256 * 1024)
|
||||
|
||||
payload_path = tmp_path / "payload.bin"
|
||||
payload_path.write_bytes(data)
|
||||
|
||||
image = Image.fromFileSystem(str(payload_path))
|
||||
record = Record.fromFileSystem(str(payload_path))
|
||||
file_component = File(name="payload.bin", file=str(payload_path))
|
||||
exists_path = tmp_path / "exists_target.txt"
|
||||
exists_path.write_text("ok", encoding="utf-8")
|
||||
|
||||
attachments_dir = tmp_path / "attachments"
|
||||
attachments_dir.mkdir()
|
||||
attachments: list[dict[str, str]] = []
|
||||
attachments_with_missing: list[dict[str, str]] = []
|
||||
for i in range(64):
|
||||
file_path = attachments_dir / f"attachment_{i}.bin"
|
||||
file_path.write_bytes(data[:2048])
|
||||
attachments.append({"attachment_id": f"att_{i}", "path": str(file_path)})
|
||||
if i % 4 == 0:
|
||||
missing_path = attachments_dir / f"missing_{i}.bin"
|
||||
attachments_with_missing.append(
|
||||
{"attachment_id": f"att_missing_{i}", "path": str(missing_path)}
|
||||
)
|
||||
attachments_with_missing.append(
|
||||
{"attachment_id": f"att_existing_{i}", "path": str(file_path)}
|
||||
)
|
||||
|
||||
exporter = AstrBotExporter(main_db=MagicMock())
|
||||
zip_path = tmp_path / "attachments_bench.zip"
|
||||
micro_batch = 32
|
||||
download_target = tmp_path / "download_target.bin"
|
||||
download_payload = os.urandom(512 * 1024)
|
||||
|
||||
async def handle_download(_request):
|
||||
return web.Response(body=download_payload)
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/download.bin", handle_download)
|
||||
runner = web.AppRunner(app, access_log=None)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", 0)
|
||||
await site.start()
|
||||
port = get_bound_tcp_port(site)
|
||||
download_url = f"http://127.0.0.1:{port}/download.bin"
|
||||
|
||||
async def bench_file_to_base64() -> None:
|
||||
await file_to_base64(str(payload_path))
|
||||
|
||||
async def bench_image_convert_to_base64() -> None:
|
||||
await image.convert_to_base64()
|
||||
|
||||
async def bench_record_convert_to_base64() -> None:
|
||||
await record.convert_to_base64()
|
||||
|
||||
async def bench_image_convert_to_file_path() -> None:
|
||||
for _ in range(micro_batch):
|
||||
await image.convert_to_file_path()
|
||||
|
||||
async def bench_file_component_get_file() -> None:
|
||||
await file_component.get_file()
|
||||
|
||||
async def bench_to_thread_exists() -> None:
|
||||
await asyncio.to_thread(exists_path.exists)
|
||||
|
||||
async def bench_export_attachments_existing() -> None:
|
||||
if zip_path.exists():
|
||||
zip_path.unlink()
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
await exporter._export_attachments(zf, attachments)
|
||||
zip_path.unlink(missing_ok=True)
|
||||
|
||||
async def bench_export_attachments_with_missing() -> None:
|
||||
if zip_path.exists():
|
||||
zip_path.unlink()
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
await exporter._export_attachments(zf, attachments_with_missing)
|
||||
zip_path.unlink(missing_ok=True)
|
||||
|
||||
async def bench_download_file_local_http() -> None:
|
||||
await download_file(download_url, str(download_target))
|
||||
download_target.unlink(missing_ok=True)
|
||||
|
||||
try:
|
||||
results = [
|
||||
await run_async_benchmark(
|
||||
"utils.io.file_to_base64(256KB)",
|
||||
bench_file_to_base64,
|
||||
iterations=_scaled_iterations(120),
|
||||
),
|
||||
await run_async_benchmark(
|
||||
"components.Image.convert_to_base64",
|
||||
bench_image_convert_to_base64,
|
||||
iterations=_scaled_iterations(120),
|
||||
),
|
||||
await run_async_benchmark(
|
||||
"components.Record.convert_to_base64",
|
||||
bench_record_convert_to_base64,
|
||||
iterations=_scaled_iterations(120),
|
||||
),
|
||||
await run_async_benchmark(
|
||||
f"components.Image.convert_to_file_path(x{micro_batch})",
|
||||
bench_image_convert_to_file_path,
|
||||
iterations=_scaled_iterations(140),
|
||||
),
|
||||
await run_async_benchmark(
|
||||
"components.File.get_file(local)",
|
||||
bench_file_component_get_file,
|
||||
iterations=_scaled_iterations(140),
|
||||
),
|
||||
await run_async_benchmark(
|
||||
"asyncio.to_thread(Path.exists)",
|
||||
bench_to_thread_exists,
|
||||
iterations=_scaled_iterations(240),
|
||||
),
|
||||
await run_async_benchmark(
|
||||
"backup.exporter._export_attachments(existing)",
|
||||
bench_export_attachments_existing,
|
||||
iterations=_scaled_iterations(20),
|
||||
warmup=2,
|
||||
),
|
||||
await run_async_benchmark(
|
||||
"backup.exporter._export_attachments(mixed)",
|
||||
bench_export_attachments_with_missing,
|
||||
iterations=_scaled_iterations(20),
|
||||
warmup=2,
|
||||
),
|
||||
await run_async_benchmark(
|
||||
"utils.io.download_file(local_http_512KB)",
|
||||
bench_download_file_local_http,
|
||||
iterations=_scaled_iterations(12),
|
||||
warmup=2,
|
||||
),
|
||||
]
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
|
||||
_print_report(results)
|
||||
|
||||
output_path = os.environ.get("ASTRBOT_BENCHMARK_OUTPUT")
|
||||
if output_path:
|
||||
Path(output_path).write_text(
|
||||
json.dumps([asdict(result) for result in results], indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
# Keep assertions broad: benchmarks are for measurement, not strict gating.
|
||||
assert len(results) == 9
|
||||
for result in results:
|
||||
assert result.iterations > 0
|
||||
assert result.mean_ms > 0
|
||||
assert result.max_ms >= result.min_ms
|
||||
assert result.p95_ms >= result.p50_ms
|
||||
@@ -172,15 +172,6 @@ class TestAstrBotExporter:
|
||||
assert "test.json" in exporter._checksums
|
||||
assert exporter._checksums["test.json"].startswith("sha256:")
|
||||
|
||||
def test_read_text_if_exists(self, tmp_path):
|
||||
"""测试 _read_text_if_exists 行为。"""
|
||||
exporter = AstrBotExporter(main_db=MagicMock())
|
||||
file_path = tmp_path / "config.json"
|
||||
file_path.write_text('{"k":"v"}', encoding="utf-8")
|
||||
|
||||
assert exporter._read_text_if_exists(str(file_path)) == '{"k":"v"}'
|
||||
assert exporter._read_text_if_exists(str(tmp_path / "missing.json")) is None
|
||||
|
||||
def test_generate_manifest(self, mock_main_db, mock_kb_manager):
|
||||
"""测试生成清单"""
|
||||
exporter = AstrBotExporter(
|
||||
@@ -249,95 +240,6 @@ class TestAstrBotExporter:
|
||||
assert "databases/main_db.json" in namelist
|
||||
assert "config/cmd_config.json" in namelist
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_attachments_exports_existing_and_skips_missing(
|
||||
self, mock_main_db, tmp_path
|
||||
):
|
||||
"""测试附件导出:存在文件写入 ZIP,不存在文件跳过。"""
|
||||
exporter = AstrBotExporter(main_db=mock_main_db, kb_manager=None)
|
||||
|
||||
existing_file = tmp_path / "exists.txt"
|
||||
existing_file.write_text("hello", encoding="utf-8")
|
||||
missing_file = tmp_path / "missing.txt"
|
||||
zip_path = tmp_path / "attachments.zip"
|
||||
|
||||
attachments = [
|
||||
{"attachment_id": "att_ok", "path": str(existing_file)},
|
||||
{"attachment_id": "att_missing", "path": str(missing_file)},
|
||||
]
|
||||
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
await exporter._export_attachments(zf, attachments)
|
||||
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
namelist = zf.namelist()
|
||||
|
||||
assert "files/attachments/att_ok.txt" in namelist
|
||||
assert "files/attachments/att_missing.txt" not in namelist
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_attachments_skips_empty_attachment_id(
|
||||
self, mock_main_db, tmp_path
|
||||
):
|
||||
"""测试附件导出:attachment_id 为空时跳过,避免覆盖冲突。"""
|
||||
exporter = AstrBotExporter(main_db=mock_main_db, kb_manager=None)
|
||||
|
||||
file_a = tmp_path / "a.txt"
|
||||
file_b = tmp_path / "b.txt"
|
||||
file_a.write_text("a", encoding="utf-8")
|
||||
file_b.write_text("b", encoding="utf-8")
|
||||
zip_path = tmp_path / "attachments_empty_id.zip"
|
||||
|
||||
attachments = [
|
||||
{"attachment_id": "", "path": str(file_a)},
|
||||
{"path": str(file_b)},
|
||||
{"attachment_id": "att_ok", "path": str(file_a)},
|
||||
]
|
||||
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
await exporter._export_attachments(zf, attachments)
|
||||
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
namelist = zf.namelist()
|
||||
|
||||
assert "files/attachments/att_ok.txt" in namelist
|
||||
assert "files/attachments/.txt" not in namelist
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_attachments_keeps_best_effort_on_unexpected_write_error(
|
||||
self, mock_main_db, tmp_path
|
||||
):
|
||||
"""测试附件导出:单个非 OSError 写入异常不会中断后续附件导出。"""
|
||||
exporter = AstrBotExporter(main_db=mock_main_db, kb_manager=None)
|
||||
|
||||
file_a = tmp_path / "a.txt"
|
||||
file_b = tmp_path / "b.txt"
|
||||
file_a.write_text("a", encoding="utf-8")
|
||||
file_b.write_text("b", encoding="utf-8")
|
||||
zip_path = tmp_path / "attachments_best_effort.zip"
|
||||
|
||||
attachments = [
|
||||
{"attachment_id": "att_boom", "path": str(file_a)},
|
||||
{"attachment_id": "att_ok", "path": str(file_b)},
|
||||
]
|
||||
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
original_write = zf.write
|
||||
|
||||
def flaky_write(filename, arcname=None, *args, **kwargs):
|
||||
if arcname == "files/attachments/att_boom.txt":
|
||||
raise RuntimeError("boom")
|
||||
return original_write(filename, arcname, *args, **kwargs)
|
||||
|
||||
with patch.object(zf, "write", side_effect=flaky_write):
|
||||
await exporter._export_attachments(zf, attachments)
|
||||
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
namelist = zf.namelist()
|
||||
|
||||
assert "files/attachments/att_boom.txt" not in namelist
|
||||
assert "files/attachments/att_ok.txt" in namelist
|
||||
|
||||
|
||||
class TestAstrBotImporter:
|
||||
"""AstrBotImporter 类测试"""
|
||||
|
||||
+227
-5
@@ -1,11 +1,14 @@
|
||||
import asyncio
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import zipfile
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from quart import Quart
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from astrbot.core import LogBroker
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
@@ -15,7 +18,6 @@ from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.dashboard.server import AstrBotDashboard
|
||||
from tests.fixtures.helpers import (
|
||||
MockPluginBuilder,
|
||||
MockPluginConfig,
|
||||
create_mock_updater_install,
|
||||
create_mock_updater_update,
|
||||
)
|
||||
@@ -145,9 +147,7 @@ async def test_plugins(
|
||||
monkeypatch.setattr(
|
||||
core_lifecycle_td.plugin_manager.updator, "install", mock_install
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
core_lifecycle_td.plugin_manager.updator, "update", mock_update
|
||||
)
|
||||
monkeypatch.setattr(core_lifecycle_td.plugin_manager.updator, "update", mock_update)
|
||||
|
||||
try:
|
||||
# 插件安装
|
||||
@@ -158,7 +158,9 @@ async def test_plugins(
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok", f"安装失败: {data.get('message', 'unknown error')}"
|
||||
assert data["status"] == "ok", (
|
||||
f"安装失败: {data.get('message', 'unknown error')}"
|
||||
)
|
||||
|
||||
# 验证插件已注册
|
||||
exists = any(md.name == test_plugin_name for md in star_registry)
|
||||
@@ -493,3 +495,223 @@ async def test_neo_skills_routes(
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["data"]["skill_key"] == "neo.demo"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_upload_skills_returns_error_when_all_files_invalid(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/skills/batch-upload",
|
||||
headers=authenticated_header,
|
||||
files={
|
||||
"files": FileStorage(
|
||||
stream=io.BytesIO(b"not-a-zip"),
|
||||
filename="invalid.txt",
|
||||
content_type="text/plain",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "error"
|
||||
assert data["message"] == "Upload failed for all 1 file(s)."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_upload_skills_accepts_zip_files(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
monkeypatch,
|
||||
):
|
||||
async def _fake_sync_skills_to_active_sandboxes():
|
||||
return
|
||||
|
||||
def _fake_install_skill_from_zip(
|
||||
self,
|
||||
zip_path: str,
|
||||
*,
|
||||
overwrite: bool = True,
|
||||
):
|
||||
_ = self, overwrite
|
||||
assert zip_path.endswith(".zip")
|
||||
return "demo_skill"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.skills.sync_skills_to_active_sandboxes",
|
||||
_fake_sync_skills_to_active_sandboxes,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.skills.SkillManager.install_skill_from_zip",
|
||||
_fake_install_skill_from_zip,
|
||||
)
|
||||
|
||||
test_client = app.test_client()
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/skills/batch-upload",
|
||||
headers=authenticated_header,
|
||||
files={
|
||||
"files": FileStorage(
|
||||
stream=io.BytesIO(b"fake-zip"),
|
||||
filename="demo_skill.zip",
|
||||
content_type="application/zip",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["message"] == "All 1 skill(s) uploaded successfully."
|
||||
assert data["data"]["total"] == 1
|
||||
assert data["data"]["succeeded"] == [
|
||||
{"filename": "demo_skill.zip", "name": "demo_skill"}
|
||||
]
|
||||
assert data["data"]["failed"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_upload_skills_accepts_valid_skill_archive(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
):
|
||||
data_dir = tmp_path / "data"
|
||||
skills_dir = tmp_path / "skills"
|
||||
temp_dir = tmp_path / "temp"
|
||||
data_dir.mkdir()
|
||||
skills_dir.mkdir()
|
||||
temp_dir.mkdir()
|
||||
|
||||
async def _fake_sync_skills_to_active_sandboxes():
|
||||
return
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.skills.sync_skills_to_active_sandboxes",
|
||||
_fake_sync_skills_to_active_sandboxes,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.skills.skill_manager.get_astrbot_data_path",
|
||||
lambda: str(data_dir),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.skills.skill_manager.get_astrbot_skills_path",
|
||||
lambda: str(skills_dir),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.skills.skill_manager.get_astrbot_temp_path",
|
||||
lambda: str(temp_dir),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.skills.get_astrbot_temp_path",
|
||||
lambda: str(temp_dir),
|
||||
)
|
||||
|
||||
archive = io.BytesIO()
|
||||
with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
zf.writestr(
|
||||
"demo_skill/SKILL.md",
|
||||
"---\nname: demo-skill\ndescription: Demo skill\n---\n",
|
||||
)
|
||||
zf.writestr("demo_skill/notes.txt", "hello")
|
||||
zf.writestr("__MACOSX/demo_skill/._SKILL.md", "")
|
||||
zf.writestr("__MACOSX/._demo_skill", "")
|
||||
archive.seek(0)
|
||||
|
||||
test_client = app.test_client()
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/skills/batch-upload",
|
||||
headers=authenticated_header,
|
||||
files={
|
||||
"files": FileStorage(
|
||||
stream=archive,
|
||||
filename="demo_skill.zip",
|
||||
content_type="application/zip",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["data"]["succeeded"] == [
|
||||
{"filename": "demo_skill.zip", "name": "demo_skill"}
|
||||
]
|
||||
assert data["data"]["failed"] == []
|
||||
assert (skills_dir / "demo_skill" / "SKILL.md").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_upload_skills_partial_success(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
monkeypatch,
|
||||
):
|
||||
async def _fake_sync_skills_to_active_sandboxes():
|
||||
return
|
||||
|
||||
def _fake_install_skill_from_zip(
|
||||
self,
|
||||
zip_path: str,
|
||||
*,
|
||||
overwrite: bool = True,
|
||||
):
|
||||
_ = self, overwrite
|
||||
if "ok_skill" in zip_path:
|
||||
return "ok_skill"
|
||||
raise RuntimeError("install failed")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.skills.sync_skills_to_active_sandboxes",
|
||||
_fake_sync_skills_to_active_sandboxes,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.skills.SkillManager.install_skill_from_zip",
|
||||
_fake_install_skill_from_zip,
|
||||
)
|
||||
|
||||
test_client = app.test_client()
|
||||
|
||||
boundary = "----AstrBotBatchBoundary"
|
||||
body = (
|
||||
(
|
||||
f"--{boundary}\r\n"
|
||||
'Content-Disposition: form-data; name="files"; filename="ok_skill.zip"\r\n'
|
||||
"Content-Type: application/zip\r\n\r\n"
|
||||
).encode()
|
||||
+ b"fake-zip-1\r\n"
|
||||
+ (
|
||||
f"--{boundary}\r\n"
|
||||
'Content-Disposition: form-data; name="files"; filename="bad_skill.zip"\r\n'
|
||||
"Content-Type: application/zip\r\n\r\n"
|
||||
).encode()
|
||||
+ b"fake-zip-2\r\n"
|
||||
+ f"--{boundary}--\r\n".encode()
|
||||
)
|
||||
headers = dict(authenticated_header)
|
||||
headers["Content-Type"] = f"multipart/form-data; boundary={boundary}"
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/skills/batch-upload",
|
||||
headers=headers,
|
||||
data=body,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["message"] == "Partial success: 1/2 skill(s) uploaded."
|
||||
assert data["data"]["total"] == 2
|
||||
assert data["data"]["succeeded"] == [
|
||||
{"filename": "ok_skill.zip", "name": "ok_skill"}
|
||||
]
|
||||
assert data["data"]["failed"] == [
|
||||
{"filename": "bad_skill.zip", "error": "install failed"}
|
||||
]
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.utils.pip_installer import PipInstaller
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_install_targets_site_packages_for_desktop_client(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1")
|
||||
monkeypatch.delattr("sys.frozen", raising=False)
|
||||
|
||||
site_packages_path = tmp_path / "site-packages"
|
||||
run_pip = AsyncMock(return_value=0)
|
||||
prepend_sys_path_calls = []
|
||||
ensure_preferred_calls = []
|
||||
|
||||
monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.utils.pip_installer.get_astrbot_site_packages_path",
|
||||
lambda: str(site_packages_path),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.utils.pip_installer._prepend_sys_path",
|
||||
lambda path: prepend_sys_path_calls.append(path),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.utils.pip_installer._ensure_plugin_dependencies_preferred",
|
||||
lambda path, requirements: ensure_preferred_calls.append((path, requirements)),
|
||||
)
|
||||
|
||||
installer = PipInstaller("")
|
||||
await installer.install(package_name="demo-package")
|
||||
|
||||
run_pip.assert_awaited_once()
|
||||
recorded_args = run_pip.await_args_list[0].args[0]
|
||||
|
||||
assert "--target" in recorded_args
|
||||
assert str(site_packages_path) in recorded_args
|
||||
assert prepend_sys_path_calls == [str(site_packages_path), str(site_packages_path)]
|
||||
assert ensure_preferred_calls == [(str(site_packages_path), {"demo-package"})]
|
||||
@@ -0,0 +1,26 @@
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_root
|
||||
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
|
||||
|
||||
|
||||
def test_desktop_client_env_marks_desktop_runtime_without_frozen(monkeypatch):
|
||||
monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1")
|
||||
monkeypatch.delattr("sys.frozen", raising=False)
|
||||
|
||||
assert is_packaged_desktop_runtime() is True
|
||||
|
||||
|
||||
def test_desktop_client_uses_home_root_without_explicit_astrbot_root(monkeypatch):
|
||||
monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1")
|
||||
monkeypatch.delenv("ASTRBOT_ROOT", raising=False)
|
||||
monkeypatch.delattr("sys.frozen", raising=False)
|
||||
|
||||
assert get_astrbot_root().endswith(".astrbot")
|
||||
|
||||
|
||||
def test_explicit_astrbot_root_overrides_desktop_default(monkeypatch, tmp_path):
|
||||
explicit_root = tmp_path / "astrbot-root"
|
||||
monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1")
|
||||
monkeypatch.setenv("ASTRBOT_ROOT", str(explicit_root))
|
||||
monkeypatch.delattr("sys.frozen", raising=False)
|
||||
|
||||
assert get_astrbot_root() == str(explicit_root.resolve())
|
||||
@@ -2,6 +2,8 @@ from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.skills.skill_manager import SkillManager
|
||||
|
||||
|
||||
@@ -56,7 +58,7 @@ def test_list_skills_merges_local_and_sandbox_cache(monkeypatch, tmp_path: Path)
|
||||
assert by_name["custom-local"].description == "local description"
|
||||
assert by_name["custom-local"].path == "skills/custom-local/SKILL.md"
|
||||
assert by_name["python-sandbox"].description == "ship built-in"
|
||||
assert by_name["python-sandbox"].path == "skills/python-sandbox/SKILL.md"
|
||||
assert by_name["python-sandbox"].path == "/workspace/skills/python-sandbox/SKILL.md"
|
||||
|
||||
|
||||
def test_sandbox_cached_skill_respects_active_and_display_path(
|
||||
@@ -98,7 +100,58 @@ def test_sandbox_cached_skill_respects_active_and_display_path(
|
||||
assert len(all_skills) == 1
|
||||
assert all_skills[0].path == "/app/skills/browser-automation/SKILL.md"
|
||||
|
||||
mgr.set_skill_active("browser-automation", False)
|
||||
active_skills = mgr.list_skills(runtime="sandbox", active_only=True)
|
||||
assert active_skills == []
|
||||
with pytest.raises(PermissionError):
|
||||
mgr.set_skill_active("browser-automation", False)
|
||||
|
||||
active_skills = mgr.list_skills(runtime="sandbox", active_only=True)
|
||||
assert len(active_skills) == 1
|
||||
assert active_skills[0].name == "browser-automation"
|
||||
|
||||
|
||||
def test_sandbox_and_local_path_resolution_with_show_sandbox_path_false(
|
||||
monkeypatch,
|
||||
tmp_path: Path,
|
||||
):
|
||||
data_dir = tmp_path / "data"
|
||||
temp_dir = tmp_path / "temp"
|
||||
skills_root = tmp_path / "skills"
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
skills_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.skills.skill_manager.get_astrbot_data_path",
|
||||
lambda: str(data_dir),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.skills.skill_manager.get_astrbot_temp_path",
|
||||
lambda: str(temp_dir),
|
||||
)
|
||||
|
||||
mgr = SkillManager(skills_root=str(skills_root))
|
||||
_write_skill(skills_root, "custom-local", "local description")
|
||||
mgr.set_sandbox_skills_cache(
|
||||
[
|
||||
{
|
||||
"name": "custom-local",
|
||||
"description": "cached description should be overridden",
|
||||
"path": "/app/skills/custom-local/SKILL.md",
|
||||
},
|
||||
{
|
||||
"name": "python-sandbox",
|
||||
"description": "ship built-in",
|
||||
"path": "/app/skills/python-sandbox/SKILL.md",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
skills = mgr.list_skills(runtime="sandbox", show_sandbox_path=False)
|
||||
by_name = {item.name: item for item in skills}
|
||||
|
||||
assert sorted(by_name) == ["custom-local", "python-sandbox"]
|
||||
assert by_name["custom-local"].description == "local description"
|
||||
local_skill_path = Path(by_name["custom-local"].path)
|
||||
assert local_skill_path.is_relative_to(skills_root)
|
||||
assert local_skill_path == skills_root / "custom-local" / "SKILL.md"
|
||||
assert by_name["python-sandbox"].path == "/app/skills/python-sandbox/SKILL.md"
|
||||
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.pipeline.respond.stage import RespondStage
|
||||
from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import (
|
||||
AiocqhttpMessageEvent,
|
||||
)
|
||||
|
||||
|
||||
def test_poke_to_dict_matches_onebot_v11_segment_format():
|
||||
poke = Comp.Poke(type="126", id=2003)
|
||||
assert poke.toDict() == {
|
||||
"type": "poke",
|
||||
"data": {"type": "126", "id": "2003"},
|
||||
}
|
||||
|
||||
|
||||
def test_poke_to_dict_keeps_legacy_qq_compatible():
|
||||
poke = Comp.Poke(type="poke", qq=2916963017)
|
||||
assert poke.toDict() == {
|
||||
"type": "poke",
|
||||
"data": {"type": "126", "id": "2916963017"},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_respond_stage_treats_poke_with_target_as_non_empty():
|
||||
stage = RespondStage()
|
||||
chain = [Comp.Poke(type="126", id=2003)]
|
||||
assert await stage._is_empty_message_chain(chain) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiocqhttp_parse_json_outputs_standard_poke_data():
|
||||
chain = MessageChain([Comp.Poke(type="126", id=2003)])
|
||||
data = await AiocqhttpMessageEvent._parse_onebot_json(chain)
|
||||
assert data == [{"type": "poke", "data": {"type": "126", "id": "2003"}}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiocqhttp_send_message_dispatches_onebot_v11_poke_payload():
|
||||
bot = AsyncMock()
|
||||
chain = MessageChain([Comp.Poke(type="126", id=2003)])
|
||||
|
||||
await AiocqhttpMessageEvent.send_message(
|
||||
bot=bot,
|
||||
message_chain=chain,
|
||||
event=None,
|
||||
is_group=True,
|
||||
session_id="123456",
|
||||
)
|
||||
|
||||
bot.send_group_msg.assert_awaited_once_with(
|
||||
group_id=123456,
|
||||
message=[{"type": "poke", "data": {"type": "126", "id": "2003"}}],
|
||||
)
|
||||
@@ -373,7 +373,7 @@ class TestAstrBotCoreLifecycleInitialize:
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
):
|
||||
await lifecycle.initialize(mcp_init_timeout=3.5)
|
||||
await lifecycle.initialize()
|
||||
|
||||
# Verify database initialized
|
||||
mock_db.initialize.assert_awaited_once()
|
||||
@@ -388,7 +388,7 @@ class TestAstrBotCoreLifecycleInitialize:
|
||||
mock_persona_mgr.initialize.assert_awaited_once()
|
||||
|
||||
# Verify provider manager initialized
|
||||
mock_provider_manager.initialize.assert_awaited_once_with(init_timeout=3.5)
|
||||
mock_provider_manager.initialize.assert_awaited_once()
|
||||
|
||||
# Verify platform manager initialized
|
||||
mock_platform_manager.initialize.assert_awaited_once()
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from tests.fixtures.helpers import get_bound_tcp_port
|
||||
|
||||
|
||||
class _DummySiteNoAttrs:
|
||||
pass
|
||||
|
||||
|
||||
class _DummySocket:
|
||||
def __init__(self, port: int) -> None:
|
||||
self._port = port
|
||||
|
||||
def getsockname(self):
|
||||
return ("127.0.0.1", self._port)
|
||||
|
||||
|
||||
class _DummyServer:
|
||||
def __init__(self, port: int) -> None:
|
||||
self.sockets = [_DummySocket(port)]
|
||||
|
||||
|
||||
class _DummySiteWithName:
|
||||
def __init__(self, port: int) -> None:
|
||||
self.name = f"http://localhost:{port}"
|
||||
|
||||
|
||||
class _DummySiteWithServer:
|
||||
def __init__(self, port: int) -> None:
|
||||
self._server = _DummyServer(port)
|
||||
|
||||
|
||||
def test_get_bound_tcp_port_raises_on_unresolvable_site():
|
||||
with pytest.raises(RuntimeError, match="Unable to resolve bound TCP port"):
|
||||
get_bound_tcp_port(_DummySiteNoAttrs())
|
||||
|
||||
|
||||
def test_get_bound_tcp_port_uses_name_port_when_available():
|
||||
assert get_bound_tcp_port(_DummySiteWithName(8081)) == 8081
|
||||
|
||||
|
||||
def test_get_bound_tcp_port_falls_back_to_server_sockets():
|
||||
assert get_bound_tcp_port(_DummySiteWithServer(9092)) == 9092
|
||||
@@ -1,98 +0,0 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.provider import func_tool_manager
|
||||
from astrbot.core.provider.func_tool_manager import FunctionToolManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_init_harness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
):
|
||||
manager = FunctionToolManager()
|
||||
data_dir = tmp_path / "data"
|
||||
data_dir.mkdir()
|
||||
|
||||
(data_dir / "mcp_server.json").write_text(
|
||||
json.dumps({"mcpServers": {"demo": {"active": True}}}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
func_tool_manager,
|
||||
"get_astrbot_data_path",
|
||||
lambda: data_dir,
|
||||
)
|
||||
|
||||
called = {}
|
||||
|
||||
async def fake_start_mcp_server(*, name, cfg, shutdown_event, timeout_seconds):
|
||||
called[name] = {
|
||||
"cfg": cfg,
|
||||
"shutdown_event_type": type(shutdown_event).__name__,
|
||||
"timeout_seconds": timeout_seconds,
|
||||
}
|
||||
|
||||
monkeypatch.setattr(manager, "_start_mcp_server", fake_start_mcp_server)
|
||||
return manager, called
|
||||
|
||||
|
||||
def assert_demo_init_result(summary, called, *, timeout_seconds: float) -> None:
|
||||
assert summary.total == 1
|
||||
assert summary.success == 1
|
||||
assert summary.failed == []
|
||||
assert called["demo"]["cfg"] == {"active": True}
|
||||
assert called["demo"]["shutdown_event_type"] == "Event"
|
||||
assert called["demo"]["timeout_seconds"] == timeout_seconds
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_mcp_clients_passes_timeout_seconds_keyword(mcp_init_harness):
|
||||
manager, called = mcp_init_harness
|
||||
|
||||
summary = await manager.init_mcp_clients()
|
||||
|
||||
assert_demo_init_result(
|
||||
summary,
|
||||
called,
|
||||
timeout_seconds=manager._init_timeout_default,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_mcp_clients_passes_overridden_init_timeout(
|
||||
mcp_init_harness,
|
||||
):
|
||||
manager, called = mcp_init_harness
|
||||
|
||||
summary = await manager.init_mcp_clients(init_timeout=3.5)
|
||||
|
||||
assert_demo_init_result(summary, called, timeout_seconds=3.5)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_mcp_clients_reads_env_timeout_when_not_overridden(
|
||||
mcp_init_harness,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
manager, called = mcp_init_harness
|
||||
manager._init_timeout_default = 20.0 # ensure env override is observable
|
||||
monkeypatch.setenv("ASTRBOT_MCP_INIT_TIMEOUT", "3.5")
|
||||
|
||||
summary = await manager.init_mcp_clients()
|
||||
|
||||
assert_demo_init_result(summary, called, timeout_seconds=3.5)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_mcp_clients_prefers_explicit_timeout_over_env(
|
||||
mcp_init_harness,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
manager, called = mcp_init_harness
|
||||
monkeypatch.setenv("ASTRBOT_MCP_INIT_TIMEOUT", "7.0")
|
||||
|
||||
summary = await manager.init_mcp_clients(init_timeout=3.5)
|
||||
|
||||
assert_demo_init_result(summary, called, timeout_seconds=3.5)
|
||||
@@ -1,71 +0,0 @@
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from astrbot.core.utils import io as io_module
|
||||
from astrbot.core.utils.io import _stream_to_file, download_file
|
||||
from tests.fixtures.helpers import get_bound_tcp_port
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_downloads_content(tmp_path):
|
||||
payload = b"astrbot-download-payload" * 256
|
||||
|
||||
async def handle(_request):
|
||||
return web.Response(body=payload)
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/file.bin", handle)
|
||||
runner = web.AppRunner(app, access_log=None)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", 0)
|
||||
await site.start()
|
||||
|
||||
try:
|
||||
port = get_bound_tcp_port(site)
|
||||
url = f"http://127.0.0.1:{port}/file.bin"
|
||||
|
||||
out = tmp_path / "downloaded.bin"
|
||||
await download_file(url, str(out))
|
||||
|
||||
assert out.read_bytes() == payload
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
|
||||
|
||||
class _DummyStream:
|
||||
def __init__(self, chunks: list[bytes]) -> None:
|
||||
self._chunks = chunks
|
||||
|
||||
async def read(self, _size: int) -> bytes:
|
||||
if not self._chunks:
|
||||
return b""
|
||||
return self._chunks.pop(0)
|
||||
|
||||
|
||||
class _RecordingFile:
|
||||
def __init__(self) -> None:
|
||||
self.writes: list[bytes] = []
|
||||
|
||||
def write(self, data: bytes) -> int:
|
||||
self.writes.append(data)
|
||||
return len(data)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_to_file_batches_multiple_chunks_per_write(monkeypatch):
|
||||
monkeypatch.setattr(io_module, "_DOWNLOAD_READ_CHUNK_SIZE", 4)
|
||||
monkeypatch.setattr(io_module, "_DOWNLOAD_FLUSH_THRESHOLD", 10)
|
||||
|
||||
stream = _DummyStream([b"aaaa", b"bbbb", b"cccc"])
|
||||
file_obj = _RecordingFile()
|
||||
|
||||
await _stream_to_file(
|
||||
stream,
|
||||
file_obj,
|
||||
total_size=12,
|
||||
start_time=0.0,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
assert len(file_obj.writes) == 1
|
||||
assert file_obj.writes[0] == b"aaaabbbbcccc"
|
||||
@@ -1,178 +0,0 @@
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from astrbot.core.message import components as components_module
|
||||
from astrbot.core.message.components import File, Image, Record
|
||||
from tests.fixtures.helpers import get_bound_tcp_port
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_convert_to_file_path_returns_absolute_path(tmp_path):
|
||||
file_path = tmp_path / "img.bin"
|
||||
file_path.write_bytes(b"img")
|
||||
|
||||
image = Image(file=str(file_path))
|
||||
resolved = await image.convert_to_file_path()
|
||||
|
||||
assert resolved == os.path.abspath(str(file_path))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_convert_to_file_path_returns_absolute_path(tmp_path):
|
||||
file_path = tmp_path / "record.bin"
|
||||
file_path.write_bytes(b"record")
|
||||
|
||||
record = Record(file=str(file_path))
|
||||
resolved = await record.convert_to_file_path()
|
||||
|
||||
assert resolved == os.path.abspath(str(file_path))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_component_get_file_returns_absolute_path(tmp_path):
|
||||
file_path = tmp_path / "file.bin"
|
||||
file_path.write_bytes(b"file")
|
||||
|
||||
file_component = File(name="file.bin", file=str(file_path))
|
||||
resolved = await file_component.get_file()
|
||||
|
||||
assert resolved == os.path.abspath(str(file_path))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_convert_to_base64_raises_on_missing_file(tmp_path):
|
||||
image = Image(file=str(tmp_path / "missing.bin"))
|
||||
|
||||
with pytest.raises(Exception, match="not a valid file"):
|
||||
await image.convert_to_base64()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_convert_to_base64_raises_on_missing_file(tmp_path):
|
||||
record = Record(file=str(tmp_path / "missing.bin"))
|
||||
|
||||
with pytest.raises(Exception, match="not a valid file"):
|
||||
await record.convert_to_base64()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_convert_to_base64_reads_existing_local_file(tmp_path):
|
||||
raw = b"image-bytes"
|
||||
file_path = tmp_path / "exists_image.bin"
|
||||
file_path.write_bytes(raw)
|
||||
|
||||
image = Image(file=str(file_path))
|
||||
encoded = await image.convert_to_base64()
|
||||
|
||||
assert base64.b64decode(encoded) == raw
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_convert_to_base64_reads_existing_local_file(tmp_path):
|
||||
raw = b"record-bytes"
|
||||
file_path = tmp_path / "exists_record.bin"
|
||||
file_path.write_bytes(raw)
|
||||
|
||||
record = Record(file=str(file_path))
|
||||
encoded = await record.convert_to_base64()
|
||||
|
||||
assert base64.b64decode(encoded) == raw
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_convert_to_base64_maps_permission_error(monkeypatch):
|
||||
async def _raise_permission_error(_path: str) -> str:
|
||||
raise PermissionError("permission denied")
|
||||
|
||||
monkeypatch.setattr(components_module, "file_to_base64", _raise_permission_error)
|
||||
|
||||
image = Image(file="/tmp/forbidden-image")
|
||||
with pytest.raises(Exception, match="not a valid file"):
|
||||
await image.convert_to_base64()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_convert_to_base64_maps_permission_error(monkeypatch):
|
||||
async def _raise_permission_error(_path: str) -> str:
|
||||
raise PermissionError("permission denied")
|
||||
|
||||
monkeypatch.setattr(components_module, "file_to_base64", _raise_permission_error)
|
||||
|
||||
record = Record(file="/tmp/forbidden-record")
|
||||
with pytest.raises(Exception, match="not a valid file"):
|
||||
await record.convert_to_base64()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_convert_to_file_path_from_base64_creates_absolute_file():
|
||||
payload = b"image-base64-payload"
|
||||
image = Image(file=f"base64://{base64.b64encode(payload).decode()}")
|
||||
|
||||
resolved = await image.convert_to_file_path()
|
||||
resolved_path = Path(resolved)
|
||||
|
||||
assert resolved_path.is_absolute()
|
||||
assert resolved_path.exists()
|
||||
assert resolved_path.read_bytes() == payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_convert_to_file_path_from_base64_creates_absolute_file():
|
||||
payload = b"record-base64-payload"
|
||||
record = Record(file=f"base64://{base64.b64encode(payload).decode()}")
|
||||
|
||||
resolved = await record.convert_to_file_path()
|
||||
resolved_path = Path(resolved)
|
||||
|
||||
assert resolved_path.is_absolute()
|
||||
assert resolved_path.exists()
|
||||
assert resolved_path.read_bytes() == payload
|
||||
|
||||
|
||||
async def _serve_payload(payload: bytes, route: str):
|
||||
async def handle(_request):
|
||||
return web.Response(body=payload)
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get(route, handle)
|
||||
runner = web.AppRunner(app, access_log=None)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", 0)
|
||||
await site.start()
|
||||
return runner, get_bound_tcp_port(site)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_convert_to_file_path_from_http_creates_absolute_file():
|
||||
payload = b"image-http-payload"
|
||||
runner, port = await _serve_payload(payload, "/img.bin")
|
||||
try:
|
||||
image = Image(file=f"http://127.0.0.1:{port}/img.bin")
|
||||
resolved = await image.convert_to_file_path()
|
||||
resolved_path = Path(resolved)
|
||||
|
||||
assert resolved_path.is_absolute()
|
||||
assert resolved_path.exists()
|
||||
assert resolved_path.read_bytes() == payload
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_convert_to_file_path_from_http_creates_absolute_file():
|
||||
payload = b"record-http-payload"
|
||||
runner, port = await _serve_payload(payload, "/record.bin")
|
||||
try:
|
||||
record = Record(file=f"http://127.0.0.1:{port}/record.bin")
|
||||
resolved = await record.convert_to_file_path()
|
||||
resolved_path = Path(resolved)
|
||||
|
||||
assert resolved_path.is_absolute()
|
||||
assert resolved_path.exists()
|
||||
assert resolved_path.read_bytes() == payload
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
Reference in New Issue
Block a user