Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2a7c8b44bf | |||
| b8e83b772d |
@@ -36,7 +36,7 @@ jobs:
|
||||
zip -r dist.zip dist
|
||||
|
||||
- name: Archive production artifacts
|
||||
uses: actions/upload-artifact@v7
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: dist-without-markdown
|
||||
path: |
|
||||
|
||||
@@ -71,7 +71,7 @@ jobs:
|
||||
zip -r "AstrBot-${{ steps.tag.outputs.tag }}-dashboard.zip" dist
|
||||
|
||||
- name: Upload dashboard artifact
|
||||
uses: actions/upload-artifact@v7
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: Dashboard-${{ steps.tag.outputs.tag }}
|
||||
if-no-files-found: error
|
||||
@@ -132,7 +132,7 @@ jobs:
|
||||
echo "tag=$tag" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Download dashboard artifact
|
||||
uses: actions/download-artifact@v8
|
||||
uses: actions/download-artifact@v7
|
||||
with:
|
||||
name: Dashboard-${{ steps.tag.outputs.tag }}
|
||||
path: release-assets
|
||||
|
||||
@@ -36,9 +36,6 @@ dashboard/dist/
|
||||
package-lock.json
|
||||
yarn.lock
|
||||
|
||||
# Bundled dashboard dist (generated by hatch_build.py during pip wheel build)
|
||||
astrbot/dashboard/dist/
|
||||
|
||||
# Operating System
|
||||
**/.DS_Store
|
||||
.DS_Store
|
||||
@@ -57,7 +54,3 @@ IFLOW.md
|
||||
# genie_tts data
|
||||
CharacterModels/
|
||||
GenieData/
|
||||
.agent/
|
||||
.codex/
|
||||
.opencode/
|
||||
.kilocode/
|
||||
|
||||
@@ -46,32 +46,6 @@ ruff check .
|
||||
|
||||
如果您使用 VSCode,可以安装 `Ruff` 插件。
|
||||
|
||||
##### PR 功能完整性验证(推荐)
|
||||
|
||||
如果您希望在本地做一套接近 CI 的完整验证,可使用:
|
||||
|
||||
```bash
|
||||
make pr-test-neo
|
||||
```
|
||||
|
||||
该命令会执行:
|
||||
- `uv sync --group dev`
|
||||
- `ruff format --check .` 与 `ruff check .`
|
||||
- Neo 相关关键测试
|
||||
- `main.py` 启动 smoke test(检测 `http://localhost:6185`)
|
||||
|
||||
需要全量验证时可使用:
|
||||
|
||||
```bash
|
||||
make pr-test-full
|
||||
```
|
||||
|
||||
如果只想快速重复执行(跳过依赖同步和 dashboard 构建):
|
||||
|
||||
```bash
|
||||
make pr-test-full-fast
|
||||
```
|
||||
|
||||
|
||||
## Contributing Guide
|
||||
|
||||
@@ -114,29 +88,3 @@ We use Ruff as our code formatter and static analysis tool. Before submitting yo
|
||||
ruff format .
|
||||
ruff check .
|
||||
```
|
||||
|
||||
##### PR completeness checks (recommended)
|
||||
|
||||
To run a local validation flow close to CI, use:
|
||||
|
||||
```bash
|
||||
make pr-test-neo
|
||||
```
|
||||
|
||||
This command runs:
|
||||
- `uv sync --group dev`
|
||||
- `ruff format --check .` and `ruff check .`
|
||||
- Neo-related critical tests
|
||||
- a startup smoke test against `http://localhost:6185`
|
||||
|
||||
For full validation, use:
|
||||
|
||||
```bash
|
||||
make pr-test-full
|
||||
```
|
||||
|
||||
For faster repeated runs (skip dependency sync and dashboard build), use:
|
||||
|
||||
```bash
|
||||
make pr-test-full-fast
|
||||
```
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
.PHONY: worktree worktree-add worktree-rm pr-test-neo pr-test-full pr-test-full-fast
|
||||
.PHONY: worktree worktree-add worktree-rm
|
||||
|
||||
WORKTREE_DIR ?= ../astrbot_worktree
|
||||
BRANCH ?= $(word 2,$(MAKECMDGOALS))
|
||||
@@ -27,15 +27,6 @@ endif
|
||||
echo "Worktree $(WORKTREE_DIR)/$(BRANCH) not found."; \
|
||||
fi
|
||||
|
||||
pr-test-neo:
|
||||
./scripts/pr_test_env.sh --profile neo
|
||||
|
||||
pr-test-full:
|
||||
./scripts/pr_test_env.sh --profile full
|
||||
|
||||
pr-test-full-fast:
|
||||
./scripts/pr_test_env.sh --profile full --skip-sync --no-dashboard
|
||||
|
||||
# Swallow extra args (branch/base) so make doesn't treat them as targets
|
||||
%:
|
||||
@true
|
||||
|
||||
@@ -184,12 +184,6 @@ Connect AstrBot to your favorite chat platform.
|
||||
| Minimax TTS | Text-to-Speech Services |
|
||||
| Volcano Engine TTS | Text-to-Speech Services |
|
||||
|
||||
## ❤️ Sponsors
|
||||
|
||||
<p align="center">
|
||||
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
|
||||
</p>
|
||||
|
||||
## ❤️ Contributing
|
||||
|
||||
Issues and Pull Requests are always welcome! Feel free to submit your changes to this project :)
|
||||
|
||||
@@ -8,7 +8,7 @@ from bs4 import BeautifulSoup
|
||||
from readability import Document
|
||||
|
||||
from astrbot.api import AstrBotConfig, llm_tool, logger, sp, star
|
||||
from astrbot.api.event import AstrMessageEvent, filter
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
from astrbot.core.provider.func_tool_manager import FunctionToolManager
|
||||
|
||||
@@ -196,6 +196,15 @@ class Main(star.Star):
|
||||
)
|
||||
return results
|
||||
|
||||
@filter.command("websearch")
|
||||
async def websearch(self, event: AstrMessageEvent, oper: str | None = None) -> None:
|
||||
"""网页搜索指令(已废弃)"""
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
"此指令已经被废弃,请在 WebUI 中开启或关闭网页搜索功能。",
|
||||
),
|
||||
)
|
||||
|
||||
@llm_tool(name="web_search")
|
||||
async def search_from_search_engine(
|
||||
self,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""AstrBot CLI entry point"""
|
||||
"""AstrBot CLI入口"""
|
||||
|
||||
import sys
|
||||
|
||||
@@ -29,23 +29,23 @@ def cli() -> None:
|
||||
@click.command()
|
||||
@click.argument("command_name", required=False, type=str)
|
||||
def help(command_name: str | None) -> None:
|
||||
"""Display help information for commands
|
||||
"""显示命令的帮助信息
|
||||
|
||||
If COMMAND_NAME is provided, display detailed help for that command.
|
||||
Otherwise, display general help information.
|
||||
如果提供了 COMMAND_NAME,则显示该命令的详细帮助信息。
|
||||
否则,显示通用帮助信息。
|
||||
"""
|
||||
ctx = click.get_current_context()
|
||||
if command_name:
|
||||
# Find the specified command
|
||||
# 查找指定命令
|
||||
command = cli.get_command(ctx, command_name)
|
||||
if command:
|
||||
# Display help for the specific command
|
||||
# 显示特定命令的帮助信息
|
||||
click.echo(command.get_help(ctx))
|
||||
else:
|
||||
click.echo(f"Unknown command: {command_name}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
# Display general help information
|
||||
# 显示通用帮助信息
|
||||
click.echo(cli.get_help(ctx))
|
||||
|
||||
|
||||
|
||||
@@ -10,61 +10,57 @@ from ..utils import check_astrbot_root, get_astrbot_root
|
||||
|
||||
|
||||
def _validate_log_level(value: str) -> str:
|
||||
"""Validate log level"""
|
||||
"""验证日志级别"""
|
||||
value = value.upper()
|
||||
if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
|
||||
raise click.ClickException(
|
||||
"Log level must be one of DEBUG/INFO/WARNING/ERROR/CRITICAL",
|
||||
"日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一",
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def _validate_dashboard_port(value: str) -> int:
|
||||
"""Validate Dashboard port"""
|
||||
"""验证 Dashboard 端口"""
|
||||
try:
|
||||
port = int(value)
|
||||
if port < 1 or port > 65535:
|
||||
raise click.ClickException("Port must be in range 1-65535")
|
||||
raise click.ClickException("端口必须在 1-65535 范围内")
|
||||
return port
|
||||
except ValueError:
|
||||
raise click.ClickException("Port must be a number")
|
||||
raise click.ClickException("端口必须是数字")
|
||||
|
||||
|
||||
def _validate_dashboard_username(value: str) -> str:
|
||||
"""Validate Dashboard username"""
|
||||
"""验证 Dashboard 用户名"""
|
||||
if not value:
|
||||
raise click.ClickException("Username cannot be empty")
|
||||
raise click.ClickException("用户名不能为空")
|
||||
return value
|
||||
|
||||
|
||||
def _validate_dashboard_password(value: str) -> str:
|
||||
"""Validate Dashboard password"""
|
||||
"""验证 Dashboard 密码"""
|
||||
if not value:
|
||||
raise click.ClickException("Password cannot be empty")
|
||||
raise click.ClickException("密码不能为空")
|
||||
return hashlib.md5(value.encode()).hexdigest()
|
||||
|
||||
|
||||
def _validate_timezone(value: str) -> str:
|
||||
"""Validate timezone"""
|
||||
"""验证时区"""
|
||||
try:
|
||||
zoneinfo.ZoneInfo(value)
|
||||
except Exception:
|
||||
raise click.ClickException(
|
||||
f"Invalid timezone: {value}. Please use a valid IANA timezone name"
|
||||
)
|
||||
raise click.ClickException(f"无效的时区: {value},请使用有效的IANA时区名称")
|
||||
return value
|
||||
|
||||
|
||||
def _validate_callback_api_base(value: str) -> str:
|
||||
"""Validate callback API base URL"""
|
||||
"""验证回调接口基址"""
|
||||
if not value.startswith("http://") and not value.startswith("https://"):
|
||||
raise click.ClickException(
|
||||
"Callback API base must start with http:// or https://"
|
||||
)
|
||||
raise click.ClickException("回调接口基址必须以 http:// 或 https:// 开头")
|
||||
return value
|
||||
|
||||
|
||||
# Configuration items settable via CLI, mapping config keys to validator functions
|
||||
# 可通过CLI设置的配置项,配置键到验证器函数的映射
|
||||
CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
|
||||
"timezone": _validate_timezone,
|
||||
"log_level": _validate_log_level,
|
||||
@@ -76,11 +72,11 @@ CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
|
||||
|
||||
|
||||
def _load_config() -> dict[str, Any]:
|
||||
"""Load or initialize config file"""
|
||||
"""加载或初始化配置文件"""
|
||||
root = get_astrbot_root()
|
||||
if not check_astrbot_root(root):
|
||||
raise click.ClickException(
|
||||
f"{root} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
|
||||
f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
|
||||
)
|
||||
|
||||
config_path = root / "data" / "cmd_config.json"
|
||||
@@ -95,11 +91,11 @@ def _load_config() -> dict[str, Any]:
|
||||
try:
|
||||
return json.loads(config_path.read_text(encoding="utf-8-sig"))
|
||||
except json.JSONDecodeError as e:
|
||||
raise click.ClickException(f"Failed to parse config file: {e!s}")
|
||||
raise click.ClickException(f"配置文件解析失败: {e!s}")
|
||||
|
||||
|
||||
def _save_config(config: dict[str, Any]) -> None:
|
||||
"""Save config file"""
|
||||
"""保存配置文件"""
|
||||
config_path = get_astrbot_root() / "data" / "cmd_config.json"
|
||||
|
||||
config_path.write_text(
|
||||
@@ -109,21 +105,21 @@ def _save_config(config: dict[str, Any]) -> None:
|
||||
|
||||
|
||||
def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
|
||||
"""Set a value in a nested dictionary"""
|
||||
"""设置嵌套字典中的值"""
|
||||
parts = path.split(".")
|
||||
for part in parts[:-1]:
|
||||
if part not in obj:
|
||||
obj[part] = {}
|
||||
elif not isinstance(obj[part], dict):
|
||||
raise click.ClickException(
|
||||
f"Config path conflict: {'.'.join(parts[: parts.index(part) + 1])} is not a dict",
|
||||
f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典",
|
||||
)
|
||||
obj = obj[part]
|
||||
obj[parts[-1]] = value
|
||||
|
||||
|
||||
def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
|
||||
"""Get a value from a nested dictionary"""
|
||||
"""获取嵌套字典中的值"""
|
||||
parts = path.split(".")
|
||||
for part in parts:
|
||||
obj = obj[part]
|
||||
@@ -132,21 +128,21 @@ def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
|
||||
|
||||
@click.group(name="conf")
|
||||
def conf() -> None:
|
||||
"""Configuration management commands
|
||||
"""配置管理命令
|
||||
|
||||
Supported config keys:
|
||||
支持的配置项:
|
||||
|
||||
- timezone: Timezone setting (e.g. Asia/Shanghai)
|
||||
- timezone: 时区设置 (例如: Asia/Shanghai)
|
||||
|
||||
- log_level: Log level (DEBUG/INFO/WARNING/ERROR/CRITICAL)
|
||||
- log_level: 日志级别 (DEBUG/INFO/WARNING/ERROR/CRITICAL)
|
||||
|
||||
- dashboard.port: Dashboard port
|
||||
- dashboard.port: Dashboard 端口
|
||||
|
||||
- dashboard.username: Dashboard username
|
||||
- dashboard.username: Dashboard 用户名
|
||||
|
||||
- dashboard.password: Dashboard password
|
||||
- dashboard.password: Dashboard 密码
|
||||
|
||||
- callback_api_base: Callback API base URL
|
||||
- callback_api_base: 回调接口基址
|
||||
"""
|
||||
|
||||
|
||||
@@ -154,9 +150,9 @@ def conf() -> None:
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
def set_config(key: str, value: str) -> None:
|
||||
"""Set the value of a config item"""
|
||||
"""设置配置项的值"""
|
||||
if key not in CONFIG_VALIDATORS:
|
||||
raise click.ClickException(f"Unsupported config key: {key}")
|
||||
raise click.ClickException(f"不支持的配置项: {key}")
|
||||
|
||||
config = _load_config()
|
||||
|
||||
@@ -166,29 +162,29 @@ def set_config(key: str, value: str) -> None:
|
||||
_set_nested_item(config, key, validated_value)
|
||||
_save_config(config)
|
||||
|
||||
click.echo(f"Config updated: {key}")
|
||||
click.echo(f"配置已更新: {key}")
|
||||
if key == "dashboard.password":
|
||||
click.echo(" Old value: ********")
|
||||
click.echo(" New value: ********")
|
||||
click.echo(" 原值: ********")
|
||||
click.echo(" 新值: ********")
|
||||
else:
|
||||
click.echo(f" Old value: {old_value}")
|
||||
click.echo(f" New value: {validated_value}")
|
||||
click.echo(f" 原值: {old_value}")
|
||||
click.echo(f" 新值: {validated_value}")
|
||||
|
||||
except KeyError:
|
||||
raise click.ClickException(f"Unknown config key: {key}")
|
||||
raise click.ClickException(f"未知的配置项: {key}")
|
||||
except Exception as e:
|
||||
raise click.UsageError(f"Failed to set config: {e!s}")
|
||||
raise click.UsageError(f"设置配置失败: {e!s}")
|
||||
|
||||
|
||||
@conf.command(name="get")
|
||||
@click.argument("key", required=False)
|
||||
def get_config(key: str | None = None) -> None:
|
||||
"""Get the value of a config item. If no key is provided, show all configurable items"""
|
||||
"""获取配置项的值,不提供key则显示所有可配置项"""
|
||||
config = _load_config()
|
||||
|
||||
if key:
|
||||
if key not in CONFIG_VALIDATORS:
|
||||
raise click.ClickException(f"Unsupported config key: {key}")
|
||||
raise click.ClickException(f"不支持的配置项: {key}")
|
||||
|
||||
try:
|
||||
value = _get_nested_item(config, key)
|
||||
@@ -196,11 +192,11 @@ def get_config(key: str | None = None) -> None:
|
||||
value = "********"
|
||||
click.echo(f"{key}: {value}")
|
||||
except KeyError:
|
||||
raise click.ClickException(f"Unknown config key: {key}")
|
||||
raise click.ClickException(f"未知的配置项: {key}")
|
||||
except Exception as e:
|
||||
raise click.UsageError(f"Failed to get config: {e!s}")
|
||||
raise click.UsageError(f"获取配置失败: {e!s}")
|
||||
else:
|
||||
click.echo("Current config:")
|
||||
click.echo("当前配置:")
|
||||
for key in CONFIG_VALIDATORS:
|
||||
try:
|
||||
value = (
|
||||
|
||||
@@ -8,12 +8,16 @@ from ..utils import check_dashboard, get_astrbot_root
|
||||
|
||||
|
||||
async def initialize_astrbot(astrbot_root: Path) -> None:
|
||||
"""Execute AstrBot initialization logic"""
|
||||
"""执行 AstrBot 初始化逻辑"""
|
||||
dot_astrbot = astrbot_root / ".astrbot"
|
||||
|
||||
if not dot_astrbot.exists():
|
||||
click.echo(f"Current Directory: {astrbot_root}")
|
||||
click.echo(
|
||||
"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。",
|
||||
)
|
||||
if click.confirm(
|
||||
f"Install AstrBot to this directory? {astrbot_root}",
|
||||
f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",
|
||||
default=True,
|
||||
abort=True,
|
||||
):
|
||||
@@ -36,7 +40,7 @@ async def initialize_astrbot(astrbot_root: Path) -> None:
|
||||
|
||||
@click.command()
|
||||
def init() -> None:
|
||||
"""Initialize AstrBot"""
|
||||
"""初始化 AstrBot"""
|
||||
click.echo("Initializing AstrBot...")
|
||||
astrbot_root = get_astrbot_root()
|
||||
lock_file = astrbot_root / "astrbot.lock"
|
||||
@@ -45,11 +49,8 @@ def init() -> None:
|
||||
try:
|
||||
with lock.acquire():
|
||||
asyncio.run(initialize_astrbot(astrbot_root))
|
||||
click.echo("Done! You can now run 'astrbot run' to start AstrBot")
|
||||
except Timeout:
|
||||
raise click.ClickException(
|
||||
"Cannot acquire lock file. Please check if another instance is running"
|
||||
)
|
||||
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
|
||||
|
||||
except Exception as e:
|
||||
raise click.ClickException(f"Initialization failed: {e!s}")
|
||||
raise click.ClickException(f"初始化失败: {e!s}")
|
||||
|
||||
@@ -16,14 +16,14 @@ from ..utils import (
|
||||
|
||||
@click.group()
|
||||
def plug() -> None:
|
||||
"""Plugin management"""
|
||||
"""插件管理"""
|
||||
|
||||
|
||||
def _get_data_path() -> Path:
|
||||
base = get_astrbot_root()
|
||||
if not check_astrbot_root(base):
|
||||
raise click.ClickException(
|
||||
f"{base} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
|
||||
f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
|
||||
)
|
||||
return (base / "data").resolve()
|
||||
|
||||
@@ -32,9 +32,7 @@ def display_plugins(plugins, title=None, color=None) -> None:
|
||||
if title:
|
||||
click.echo(click.style(title, fg=color, bold=True))
|
||||
|
||||
click.echo(
|
||||
f"{'Name':<20} {'Version':<10} {'Status':<10} {'Author':<15} {'Description':<30}"
|
||||
)
|
||||
click.echo(f"{'名称':<20} {'版本':<10} {'状态':<10} {'作者':<15} {'描述':<30}")
|
||||
click.echo("-" * 85)
|
||||
|
||||
for p in plugins:
|
||||
@@ -48,30 +46,30 @@ def display_plugins(plugins, title=None, color=None) -> None:
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
def new(name: str) -> None:
|
||||
"""Create a new plugin"""
|
||||
"""创建新插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins" / name
|
||||
|
||||
if plug_path.exists():
|
||||
raise click.ClickException(f"Plugin {name} already exists")
|
||||
raise click.ClickException(f"插件 {name} 已存在")
|
||||
|
||||
author = click.prompt("Enter plugin author", type=str)
|
||||
desc = click.prompt("Enter plugin description", type=str)
|
||||
version = click.prompt("Enter plugin version", type=str)
|
||||
author = click.prompt("请输入插件作者", type=str)
|
||||
desc = click.prompt("请输入插件描述", type=str)
|
||||
version = click.prompt("请输入插件版本", type=str)
|
||||
if not re.match(r"^\d+\.\d+(\.\d+)?$", version.lower().lstrip("v")):
|
||||
raise click.ClickException("Version must be in x.y or x.y.z format")
|
||||
repo = click.prompt("Enter plugin repository URL:", type=str)
|
||||
raise click.ClickException("版本号必须为 x.y 或 x.y.z 格式")
|
||||
repo = click.prompt("请输入插件仓库:", type=str)
|
||||
if not repo.startswith("http"):
|
||||
raise click.ClickException("Repository URL must start with http")
|
||||
raise click.ClickException("仓库地址必须以 http 开头")
|
||||
|
||||
click.echo("Downloading plugin template...")
|
||||
click.echo("下载插件模板...")
|
||||
get_git_repo(
|
||||
"https://github.com/Soulter/helloworld",
|
||||
plug_path,
|
||||
)
|
||||
|
||||
click.echo("Rewriting plugin metadata...")
|
||||
# Rewrite metadata.yaml
|
||||
click.echo("重写插件信息...")
|
||||
# 重写 metadata.yaml
|
||||
with open(plug_path / "metadata.yaml", "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
f"name: {name}\n"
|
||||
@@ -81,13 +79,11 @@ def new(name: str) -> None:
|
||||
f"repo: {repo}\n",
|
||||
)
|
||||
|
||||
# Rewrite README.md
|
||||
# 重写 README.md
|
||||
with open(plug_path / "README.md", "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
f"# {name}\n\n{desc}\n\n# Support\n\n[Documentation](https://astrbot.app)\n"
|
||||
)
|
||||
f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n")
|
||||
|
||||
# Rewrite main.py
|
||||
# 重写 main.py
|
||||
with open(plug_path / "main.py", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
@@ -99,54 +95,54 @@ def new(name: str) -> None:
|
||||
with open(plug_path / "main.py", "w", encoding="utf-8") as f:
|
||||
f.write(new_content)
|
||||
|
||||
click.echo(f"Plugin {name} created successfully")
|
||||
click.echo(f"插件 {name} 创建成功")
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.option("--all", "-a", is_flag=True, help="List uninstalled plugins")
|
||||
@click.option("--all", "-a", is_flag=True, help="列出未安装的插件")
|
||||
def list(all: bool) -> None:
|
||||
"""List plugins"""
|
||||
"""列出插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
# Unpublished plugins
|
||||
# 未发布的插件
|
||||
not_published_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NOT_PUBLISHED
|
||||
]
|
||||
if not_published_plugins:
|
||||
display_plugins(not_published_plugins, "Unpublished Plugins", "red")
|
||||
display_plugins(not_published_plugins, "未发布的插件", "red")
|
||||
|
||||
# Plugins needing update
|
||||
# 需要更新的插件
|
||||
need_update_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
|
||||
]
|
||||
if need_update_plugins:
|
||||
display_plugins(need_update_plugins, "Plugins Needing Update", "yellow")
|
||||
display_plugins(need_update_plugins, "需要更新的插件", "yellow")
|
||||
|
||||
# Installed plugins
|
||||
# 已安装的插件
|
||||
installed_plugins = [p for p in plugins if p["status"] == PluginStatus.INSTALLED]
|
||||
if installed_plugins:
|
||||
display_plugins(installed_plugins, "Installed Plugins", "green")
|
||||
display_plugins(installed_plugins, "已安装的插件", "green")
|
||||
|
||||
# Uninstalled plugins
|
||||
# 未安装的插件
|
||||
not_installed_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NOT_INSTALLED
|
||||
]
|
||||
if not_installed_plugins and all:
|
||||
display_plugins(not_installed_plugins, "Uninstalled Plugins", "blue")
|
||||
display_plugins(not_installed_plugins, "未安装的插件", "blue")
|
||||
|
||||
if (
|
||||
not any([not_published_plugins, need_update_plugins, installed_plugins])
|
||||
and not all
|
||||
):
|
||||
click.echo("No plugins installed")
|
||||
click.echo("未安装任何插件")
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
@click.option("--proxy", help="Proxy server address")
|
||||
@click.option("--proxy", help="代理服务器地址")
|
||||
def install(name: str, proxy: str | None) -> None:
|
||||
"""Install a plugin"""
|
||||
"""安装插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins"
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
@@ -161,7 +157,7 @@ def install(name: str, proxy: str | None) -> None:
|
||||
)
|
||||
|
||||
if not plugin:
|
||||
raise click.ClickException(f"Plugin {name} not found or already installed")
|
||||
raise click.ClickException(f"未找到可安装的插件 {name},可能是不存在或已安装")
|
||||
|
||||
manage_plugin(plugin, plug_path, is_update=False, proxy=proxy)
|
||||
|
||||
@@ -169,32 +165,30 @@ def install(name: str, proxy: str | None) -> None:
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
def remove(name: str) -> None:
|
||||
"""Uninstall a plugin"""
|
||||
"""卸载插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
plugin = next((p for p in plugins if p["name"] == name), None)
|
||||
|
||||
if not plugin or not plugin.get("local_path"):
|
||||
raise click.ClickException(f"Plugin {name} does not exist or is not installed")
|
||||
raise click.ClickException(f"插件 {name} 不存在或未安装")
|
||||
|
||||
plugin_path = plugin["local_path"]
|
||||
|
||||
click.confirm(
|
||||
f"Are you sure you want to uninstall plugin {name}?", default=False, abort=True
|
||||
)
|
||||
click.confirm(f"确定要卸载插件 {name} 吗?", default=False, abort=True)
|
||||
|
||||
try:
|
||||
shutil.rmtree(plugin_path)
|
||||
click.echo(f"Plugin {name} has been uninstalled")
|
||||
click.echo(f"插件 {name} 已卸载")
|
||||
except Exception as e:
|
||||
raise click.ClickException(f"Failed to uninstall plugin {name}: {e}")
|
||||
raise click.ClickException(f"卸载插件 {name} 失败: {e}")
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name", required=False)
|
||||
@click.option("--proxy", help="GitHub proxy address")
|
||||
@click.option("--proxy", help="Github代理地址")
|
||||
def update(name: str, proxy: str | None) -> None:
|
||||
"""Update plugins"""
|
||||
"""更新插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins"
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
@@ -210,9 +204,7 @@ def update(name: str, proxy: str | None) -> None:
|
||||
)
|
||||
|
||||
if not plugin:
|
||||
raise click.ClickException(
|
||||
f"Plugin {name} does not need updating or cannot be updated"
|
||||
)
|
||||
raise click.ClickException(f"插件 {name} 不需要更新或无法更新")
|
||||
|
||||
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
|
||||
else:
|
||||
@@ -221,20 +213,20 @@ def update(name: str, proxy: str | None) -> None:
|
||||
]
|
||||
|
||||
if not need_update_plugins:
|
||||
click.echo("No plugins need updating")
|
||||
click.echo("没有需要更新的插件")
|
||||
return
|
||||
|
||||
click.echo(f"Found {len(need_update_plugins)} plugin(s) needing update")
|
||||
click.echo(f"发现 {len(need_update_plugins)} 个插件需要更新")
|
||||
for plugin in need_update_plugins:
|
||||
plugin_name = plugin["name"]
|
||||
click.echo(f"Updating plugin {plugin_name}...")
|
||||
click.echo(f"正在更新插件 {plugin_name}...")
|
||||
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("query")
|
||||
def search(query: str) -> None:
|
||||
"""Search for plugins"""
|
||||
"""搜索插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
@@ -247,7 +239,7 @@ def search(query: str) -> None:
|
||||
]
|
||||
|
||||
if not matched_plugins:
|
||||
click.echo(f"No plugins matching '{query}' found")
|
||||
click.echo(f"未找到匹配 '{query}' 的插件")
|
||||
return
|
||||
|
||||
display_plugins(matched_plugins, f"Search results: '{query}'", "cyan")
|
||||
display_plugins(matched_plugins, f"搜索结果: '{query}'", "cyan")
|
||||
|
||||
@@ -11,7 +11,7 @@ from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root
|
||||
|
||||
|
||||
async def run_astrbot(astrbot_root: Path) -> None:
|
||||
"""Run AstrBot"""
|
||||
"""运行 AstrBot"""
|
||||
from astrbot.core import LogBroker, LogManager, db_helper, logger
|
||||
from astrbot.core.initial_loader import InitialLoader
|
||||
|
||||
@@ -26,18 +26,18 @@ async def run_astrbot(astrbot_root: Path) -> None:
|
||||
await core_lifecycle.start()
|
||||
|
||||
|
||||
@click.option("--reload", "-r", is_flag=True, help="Auto-reload plugins")
|
||||
@click.option("--port", "-p", help="AstrBot Dashboard port", required=False, type=str)
|
||||
@click.option("--reload", "-r", is_flag=True, help="插件自动重载")
|
||||
@click.option("--port", "-p", help="Astrbot Dashboard端口", required=False, type=str)
|
||||
@click.command()
|
||||
def run(reload: bool, port: str) -> None:
|
||||
"""Run AstrBot"""
|
||||
"""运行 AstrBot"""
|
||||
try:
|
||||
os.environ["ASTRBOT_CLI"] = "1"
|
||||
astrbot_root = get_astrbot_root()
|
||||
|
||||
if not check_astrbot_root(astrbot_root):
|
||||
raise click.ClickException(
|
||||
f"{astrbot_root} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
|
||||
f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
|
||||
)
|
||||
|
||||
os.environ["ASTRBOT_ROOT"] = str(astrbot_root)
|
||||
@@ -47,7 +47,7 @@ def run(reload: bool, port: str) -> None:
|
||||
os.environ["DASHBOARD_PORT"] = port
|
||||
|
||||
if reload:
|
||||
click.echo("Plugin auto-reload enabled")
|
||||
click.echo("启用插件自动重载")
|
||||
os.environ["ASTRBOT_RELOAD"] = "1"
|
||||
|
||||
lock_file = astrbot_root / "astrbot.lock"
|
||||
@@ -55,10 +55,8 @@ def run(reload: bool, port: str) -> None:
|
||||
with lock.acquire():
|
||||
asyncio.run(run_astrbot(astrbot_root))
|
||||
except KeyboardInterrupt:
|
||||
click.echo("AstrBot has been shut down.")
|
||||
click.echo("AstrBot 已关闭...")
|
||||
except Timeout:
|
||||
raise click.ClickException(
|
||||
"Cannot acquire lock file. Please check if another instance is running"
|
||||
)
|
||||
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
|
||||
except Exception as e:
|
||||
raise click.ClickException(f"Runtime error: {e}\n{traceback.format_exc()}")
|
||||
raise click.ClickException(f"运行时出现错误: {e}\n{traceback.format_exc()}")
|
||||
|
||||
+13
-21
@@ -2,12 +2,9 @@ from pathlib import Path
|
||||
|
||||
import click
|
||||
|
||||
# Static assets bundled inside the installed wheel (built by hatch_build.py).
|
||||
_BUNDLED_DIST = Path(__file__).parent.parent.parent / "dashboard" / "dist"
|
||||
|
||||
|
||||
def check_astrbot_root(path: str | Path) -> bool:
|
||||
"""Check if the path is an AstrBot root directory"""
|
||||
"""检查路径是否为 AstrBot 根目录"""
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
if not path.exists() or not path.is_dir():
|
||||
@@ -18,48 +15,43 @@ def check_astrbot_root(path: str | Path) -> bool:
|
||||
|
||||
|
||||
def get_astrbot_root() -> Path:
|
||||
"""Get the AstrBot root directory path"""
|
||||
"""获取Astrbot根目录路径"""
|
||||
return Path.cwd()
|
||||
|
||||
|
||||
async def check_dashboard(astrbot_root: Path) -> None:
|
||||
"""Check if the dashboard is installed"""
|
||||
"""检查是否安装了dashboard"""
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
|
||||
from .version_comparator import VersionComparator
|
||||
|
||||
# If the wheel ships bundled dashboard assets, no network download is needed.
|
||||
if _BUNDLED_DIST.exists():
|
||||
click.echo("Dashboard is bundled with the package – skipping download.")
|
||||
return
|
||||
|
||||
try:
|
||||
dashboard_version = await get_dashboard_version()
|
||||
match dashboard_version:
|
||||
case None:
|
||||
click.echo("Dashboard is not installed")
|
||||
click.echo("未安装管理面板")
|
||||
if click.confirm(
|
||||
"Install dashboard?",
|
||||
"是否安装管理面板?",
|
||||
default=True,
|
||||
abort=True,
|
||||
):
|
||||
click.echo("Installing dashboard...")
|
||||
click.echo("正在安装管理面板...")
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip",
|
||||
extract_path=str(astrbot_root),
|
||||
version=f"v{VERSION}",
|
||||
latest=False,
|
||||
)
|
||||
click.echo("Dashboard installed successfully")
|
||||
click.echo("管理面板安装完成")
|
||||
|
||||
case str():
|
||||
if VersionComparator.compare_version(VERSION, dashboard_version) <= 0:
|
||||
click.echo("Dashboard is already up to date")
|
||||
click.echo("管理面板已是最新版本")
|
||||
return
|
||||
try:
|
||||
version = dashboard_version.split("v")[1]
|
||||
click.echo(f"Dashboard version: {version}")
|
||||
click.echo(f"管理面板版本: {version}")
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip",
|
||||
extract_path=str(astrbot_root),
|
||||
@@ -67,10 +59,10 @@ async def check_dashboard(astrbot_root: Path) -> None:
|
||||
latest=False,
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"Failed to download dashboard: {e}")
|
||||
click.echo(f"下载管理面板失败: {e}")
|
||||
return
|
||||
except FileNotFoundError:
|
||||
click.echo("Initializing dashboard directory...")
|
||||
click.echo("初始化管理面板目录...")
|
||||
try:
|
||||
await download_dashboard(
|
||||
path=str(astrbot_root / "dashboard.zip"),
|
||||
@@ -78,7 +70,7 @@ async def check_dashboard(astrbot_root: Path) -> None:
|
||||
version=f"v{VERSION}",
|
||||
latest=False,
|
||||
)
|
||||
click.echo("Dashboard initialized successfully")
|
||||
click.echo("管理面板初始化完成")
|
||||
except Exception as e:
|
||||
click.echo(f"Failed to download dashboard: {e}")
|
||||
click.echo(f"下载管理面板失败: {e}")
|
||||
return
|
||||
|
||||
+43
-47
@@ -13,22 +13,22 @@ from .version_comparator import VersionComparator
|
||||
|
||||
|
||||
class PluginStatus(str, Enum):
|
||||
INSTALLED = "installed"
|
||||
NEED_UPDATE = "needs-update"
|
||||
NOT_INSTALLED = "not-installed"
|
||||
NOT_PUBLISHED = "unpublished"
|
||||
INSTALLED = "已安装"
|
||||
NEED_UPDATE = "需更新"
|
||||
NOT_INSTALLED = "未安装"
|
||||
NOT_PUBLISHED = "未发布"
|
||||
|
||||
|
||||
def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
|
||||
"""Download code from a Git repository and extract to the specified path"""
|
||||
"""从 Git 仓库下载代码并解压到指定路径"""
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
try:
|
||||
# Parse repository info
|
||||
# 解析仓库信息
|
||||
repo_namespace = url.split("/")[-2:]
|
||||
author = repo_namespace[0]
|
||||
repo = repo_namespace[1]
|
||||
|
||||
# Try to get the latest release
|
||||
# 尝试获取最新的 release
|
||||
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
|
||||
try:
|
||||
with httpx.Client(
|
||||
@@ -40,21 +40,21 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
|
||||
releases = resp.json()
|
||||
|
||||
if releases:
|
||||
# Use the latest release
|
||||
# 使用最新的 release
|
||||
download_url = releases[0]["zipball_url"]
|
||||
else:
|
||||
# No release found, use default branch
|
||||
click.echo(f"Downloading {author}/{repo} from default branch")
|
||||
# 没有 release,使用默认分支
|
||||
click.echo(f"正在从默认分支下载 {author}/{repo}")
|
||||
download_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
|
||||
except Exception as e:
|
||||
click.echo(f"Failed to get release info: {e}. Using provided URL directly")
|
||||
click.echo(f"获取 release 信息失败: {e},将直接使用提供的 URL")
|
||||
download_url = url
|
||||
|
||||
# Apply proxy
|
||||
# 应用代理
|
||||
if proxy:
|
||||
download_url = f"{proxy}/{download_url}"
|
||||
|
||||
# Download and extract
|
||||
# 下载并解压
|
||||
with httpx.Client(
|
||||
proxy=proxy if proxy else None,
|
||||
follow_redirects=True,
|
||||
@@ -65,7 +65,7 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
|
||||
and "archive/refs/heads/master.zip" in download_url
|
||||
):
|
||||
alt_url = download_url.replace("master.zip", "main.zip")
|
||||
click.echo("Branch 'master' not found, trying 'main' branch")
|
||||
click.echo("master 分支不存在,尝试下载 main 分支")
|
||||
resp = client.get(alt_url)
|
||||
resp.raise_for_status()
|
||||
else:
|
||||
@@ -84,13 +84,13 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
|
||||
|
||||
|
||||
def load_yaml_metadata(plugin_dir: Path) -> dict:
|
||||
"""Load plugin metadata from metadata.yaml file
|
||||
"""从 metadata.yaml 文件加载插件元数据
|
||||
|
||||
Args:
|
||||
plugin_dir: Plugin directory path
|
||||
plugin_dir: 插件目录路径
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing metadata, or empty dict if loading fails
|
||||
dict: 包含元数据的字典,如果读取失败则返回空字典
|
||||
|
||||
"""
|
||||
yaml_path = plugin_dir / "metadata.yaml"
|
||||
@@ -98,33 +98,33 @@ def load_yaml_metadata(plugin_dir: Path) -> dict:
|
||||
try:
|
||||
return yaml.safe_load(yaml_path.read_text(encoding="utf-8")) or {}
|
||||
except Exception as e:
|
||||
click.echo(f"Failed to read {yaml_path}: {e}", err=True)
|
||||
click.echo(f"读取 {yaml_path} 失败: {e}", err=True)
|
||||
return {}
|
||||
|
||||
|
||||
def build_plug_list(plugins_dir: Path) -> list:
|
||||
"""Build plugin list containing local and online plugin information
|
||||
"""构建插件列表,包含本地和在线插件信息
|
||||
|
||||
Args:
|
||||
plugins_dir (Path): Plugin directory path
|
||||
plugins_dir (Path): 插件目录路径
|
||||
|
||||
Returns:
|
||||
list: List of dicts containing plugin information
|
||||
list: 包含插件信息的字典列表
|
||||
|
||||
"""
|
||||
# Get local plugin info
|
||||
# 获取本地插件信息
|
||||
result = []
|
||||
if plugins_dir.exists():
|
||||
for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]:
|
||||
plugin_dir = plugins_dir / plugin_name
|
||||
|
||||
# Load metadata from metadata.yaml
|
||||
# 从 metadata.yaml 加载元数据
|
||||
metadata = load_yaml_metadata(plugin_dir)
|
||||
|
||||
if "desc" not in metadata and "description" in metadata:
|
||||
metadata["desc"] = metadata["description"]
|
||||
|
||||
# If metadata loaded successfully, add to result list
|
||||
# 如果成功加载元数据,添加到结果列表
|
||||
if metadata and all(
|
||||
k in metadata for k in ["name", "desc", "version", "author", "repo"]
|
||||
):
|
||||
@@ -140,7 +140,7 @@ def build_plug_list(plugins_dir: Path) -> list:
|
||||
},
|
||||
)
|
||||
|
||||
# Get online plugin list
|
||||
# 获取在线插件列表
|
||||
online_plugins = []
|
||||
try:
|
||||
with httpx.Client() as client:
|
||||
@@ -160,13 +160,13 @@ def build_plug_list(plugins_dir: Path) -> list:
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"Failed to get online plugin list: {e}", err=True)
|
||||
click.echo(f"获取在线插件列表失败: {e}", err=True)
|
||||
|
||||
# Compare with online plugins and update status
|
||||
# 与在线插件比对,更新状态
|
||||
online_plugin_names = {plugin["name"] for plugin in online_plugins}
|
||||
for local_plugin in result:
|
||||
if local_plugin["name"] in online_plugin_names:
|
||||
# Find the corresponding online plugin
|
||||
# 查找对应的在线插件
|
||||
online_plugin = next(
|
||||
p for p in online_plugins if p["name"] == local_plugin["name"]
|
||||
)
|
||||
@@ -179,10 +179,10 @@ def build_plug_list(plugins_dir: Path) -> list:
|
||||
):
|
||||
local_plugin["status"] = PluginStatus.NEED_UPDATE
|
||||
else:
|
||||
# Local plugin is not published online
|
||||
# 本地插件未在线上发布
|
||||
local_plugin["status"] = PluginStatus.NOT_PUBLISHED
|
||||
|
||||
# Add uninstalled online plugins
|
||||
# 添加未安装的在线插件
|
||||
for online_plugin in online_plugins:
|
||||
if not any(plugin["name"] == online_plugin["name"] for plugin in result):
|
||||
result.append(online_plugin)
|
||||
@@ -196,19 +196,19 @@ def manage_plugin(
|
||||
is_update: bool = False,
|
||||
proxy: str | None = None,
|
||||
) -> None:
|
||||
"""Install or update a plugin
|
||||
"""安装或更新插件
|
||||
|
||||
Args:
|
||||
plugin (dict): Plugin info dict
|
||||
plugins_dir (Path): Plugins directory
|
||||
is_update (bool, optional): Whether this is an update operation. Defaults to False
|
||||
proxy (str, optional): Proxy server address
|
||||
plugin (dict): 插件信息字典
|
||||
plugins_dir (Path): 插件目录
|
||||
is_update (bool, optional): 是否为更新操作. 默认为 False
|
||||
proxy (str, optional): 代理服务器地址
|
||||
|
||||
"""
|
||||
plugin_name = plugin["name"]
|
||||
repo_url = plugin["repo"]
|
||||
|
||||
# If updating and local path exists, use it directly
|
||||
# 如果是更新且有本地路径,直接使用本地路径
|
||||
if is_update and plugin.get("local_path"):
|
||||
target_path = Path(plugin["local_path"])
|
||||
else:
|
||||
@@ -216,13 +216,11 @@ def manage_plugin(
|
||||
|
||||
backup_path = Path(f"{target_path}_backup") if is_update else None
|
||||
|
||||
# Check if plugin exists
|
||||
# 检查插件是否存在
|
||||
if is_update and not target_path.exists():
|
||||
raise click.ClickException(
|
||||
f"Plugin {plugin_name} is not installed and cannot be updated"
|
||||
)
|
||||
raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")
|
||||
|
||||
# Backup existing plugin
|
||||
# 备份现有插件
|
||||
if is_update and backup_path is not None and backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
if is_update and backup_path is not None:
|
||||
@@ -230,21 +228,19 @@ def manage_plugin(
|
||||
|
||||
try:
|
||||
click.echo(
|
||||
f"{'Updating' if is_update else 'Downloading'} plugin {plugin_name} from {repo_url}...",
|
||||
f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}...",
|
||||
)
|
||||
get_git_repo(repo_url, target_path, proxy)
|
||||
|
||||
# Update succeeded, delete backup
|
||||
# 更新成功,删除备份
|
||||
if is_update and backup_path is not None and backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
click.echo(
|
||||
f"Plugin {plugin_name} {'updated' if is_update else 'installed'} successfully"
|
||||
)
|
||||
click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
|
||||
except Exception as e:
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
if is_update and backup_path is not None and backup_path.exists():
|
||||
shutil.move(backup_path, target_path)
|
||||
raise click.ClickException(
|
||||
f"Error {'updating' if is_update else 'installing'} plugin {plugin_name}: {e}",
|
||||
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}",
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Copied from astrbot.core.utils.version_comparator"""
|
||||
"""拷贝自 astrbot.core.utils.version_comparator"""
|
||||
|
||||
import re
|
||||
|
||||
@@ -6,11 +6,11 @@ import re
|
||||
class VersionComparator:
|
||||
@staticmethod
|
||||
def compare_version(v1: str, v2: str) -> int:
|
||||
"""Compare version numbers according to Semver semantics. Supports version numbers with more than 3 digits and handles pre-release tags.
|
||||
"""根据 Semver 语义版本规范来比较版本号的大小。支持不仅局限于 3 个数字的版本号,并处理预发布标签。
|
||||
|
||||
Reference: https://semver.org/
|
||||
参考: https://semver.org/lang/zh-CN/
|
||||
|
||||
Returns 1 if v1 > v2, -1 if v1 < v2, 0 if v1 == v2.
|
||||
返回 1 表示 v1 > v2,返回 -1 表示 v1 < v2,返回 0 表示 v1 = v2。
|
||||
"""
|
||||
v1 = v1.lower().replace("v", "")
|
||||
v2 = v2.lower().replace("v", "")
|
||||
@@ -24,7 +24,7 @@ class VersionComparator:
|
||||
return [], None
|
||||
major_minor_patch = match.group(1).split(".")
|
||||
prerelease = match.group(2)
|
||||
# buildmetadata = match.group(3) # Build metadata is ignored in comparison
|
||||
# buildmetadata = match.group(3) # 构建元数据在比较时忽略
|
||||
parts = [int(x) for x in major_minor_patch]
|
||||
prerelease = VersionComparator._split_prerelease(prerelease)
|
||||
return parts, prerelease
|
||||
@@ -32,7 +32,7 @@ class VersionComparator:
|
||||
v1_parts, v1_prerelease = split_version(v1)
|
||||
v2_parts, v2_prerelease = split_version(v2)
|
||||
|
||||
# Compare numeric parts
|
||||
# 比较数字部分
|
||||
length = max(len(v1_parts), len(v2_parts))
|
||||
v1_parts.extend([0] * (length - len(v1_parts)))
|
||||
v2_parts.extend([0] * (length - len(v2_parts)))
|
||||
@@ -43,11 +43,11 @@ class VersionComparator:
|
||||
if v1_parts[i] < v2_parts[i]:
|
||||
return -1
|
||||
|
||||
# Compare pre-release tags
|
||||
# 比较预发布标签
|
||||
if v1_prerelease is None and v2_prerelease is not None:
|
||||
return 1 # Version without pre-release tag is higher than one with it
|
||||
return 1 # 没有预发布标签的版本高于有预发布标签的版本
|
||||
if v1_prerelease is not None and v2_prerelease is None:
|
||||
return -1 # Version with pre-release tag is lower than one without it
|
||||
return -1 # 有预发布标签的版本低于没有预发布标签的版本
|
||||
if v1_prerelease is not None and v2_prerelease is not None:
|
||||
len_pre = max(len(v1_prerelease), len(v2_prerelease))
|
||||
for i in range(len_pre):
|
||||
@@ -72,9 +72,9 @@ class VersionComparator:
|
||||
return 1
|
||||
if p1 < p2:
|
||||
return -1
|
||||
return 0 # Pre-release tags are identical
|
||||
return 0 # 预发布标签完全相同
|
||||
|
||||
return 0 # Both numeric parts and pre-release tags are equal
|
||||
return 0 # 数字部分和预发布标签都相同
|
||||
|
||||
@staticmethod
|
||||
def _split_prerelease(prerelease):
|
||||
|
||||
@@ -14,7 +14,7 @@ from .utils.astrbot_path import get_astrbot_data_path
|
||||
# 初始化数据存储文件夹
|
||||
os.makedirs(get_astrbot_data_path(), exist_ok=True)
|
||||
|
||||
DEMO_MODE = os.getenv("DEMO_MODE", "False").strip().lower() in ("true", "1", "t")
|
||||
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||
|
||||
astrbot_config = AstrBotConfig()
|
||||
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
|
||||
|
||||
@@ -291,9 +291,6 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
prov_settings: dict = ctx.get_config(umo=umo).get("provider_settings", {})
|
||||
agent_max_step = int(prov_settings.get("max_agent_step", 30))
|
||||
stream = prov_settings.get("streaming_response", False)
|
||||
llm_resp = await ctx.tool_loop_agent(
|
||||
event=event,
|
||||
chat_provider_id=prov_id,
|
||||
@@ -302,8 +299,9 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
system_prompt=tool.agent.instructions,
|
||||
tools=toolset,
|
||||
contexts=contexts,
|
||||
max_steps=agent_max_step,
|
||||
stream=stream,
|
||||
max_steps=30,
|
||||
run_hooks=tool.agent.run_hooks,
|
||||
stream=ctx.get_config().get("provider_settings", {}).get("stream", False),
|
||||
)
|
||||
yield mcp.types.CallToolResult(
|
||||
content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)]
|
||||
|
||||
@@ -20,32 +20,18 @@ from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from astrbot.core.astr_agent_run_util import AgentRunner
|
||||
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
|
||||
from astrbot.core.astr_main_agent_resources import (
|
||||
ANNOTATE_EXECUTION_TOOL,
|
||||
BROWSER_BATCH_EXEC_TOOL,
|
||||
BROWSER_EXEC_TOOL,
|
||||
CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT,
|
||||
CREATE_SKILL_CANDIDATE_TOOL,
|
||||
CREATE_SKILL_PAYLOAD_TOOL,
|
||||
EVALUATE_SKILL_CANDIDATE_TOOL,
|
||||
EXECUTE_SHELL_TOOL,
|
||||
FILE_DOWNLOAD_TOOL,
|
||||
FILE_UPLOAD_TOOL,
|
||||
GET_EXECUTION_HISTORY_TOOL,
|
||||
GET_SKILL_PAYLOAD_TOOL,
|
||||
KNOWLEDGE_BASE_QUERY_TOOL,
|
||||
LIST_SKILL_CANDIDATES_TOOL,
|
||||
LIST_SKILL_RELEASES_TOOL,
|
||||
LIVE_MODE_SYSTEM_PROMPT,
|
||||
LLM_SAFETY_MODE_SYSTEM_PROMPT,
|
||||
LOCAL_EXECUTE_SHELL_TOOL,
|
||||
LOCAL_PYTHON_TOOL,
|
||||
PROMOTE_SKILL_CANDIDATE_TOOL,
|
||||
PYTHON_TOOL,
|
||||
ROLLBACK_SKILL_RELEASE_TOOL,
|
||||
RUN_BROWSER_SKILL_TOOL,
|
||||
SANDBOX_MODE_PROMPT,
|
||||
SEND_MESSAGE_TO_USER_TOOL,
|
||||
SYNC_SKILL_RELEASE_TOOL,
|
||||
TOOL_CALL_PROMPT,
|
||||
TOOL_CALL_PROMPT_SKILLS_LIKE_MODE,
|
||||
retrieve_knowledge_base,
|
||||
@@ -846,10 +832,7 @@ def _apply_sandbox_tools(
|
||||
) -> None:
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
if req.system_prompt is None:
|
||||
req.system_prompt = ""
|
||||
booter = config.sandbox_cfg.get("booter", "shipyard_neo")
|
||||
if booter == "shipyard":
|
||||
if config.sandbox_cfg.get("booter") == "shipyard":
|
||||
ep = config.sandbox_cfg.get("shipyard_endpoint", "")
|
||||
at = config.sandbox_cfg.get("shipyard_access_token", "")
|
||||
if not ep or not at:
|
||||
@@ -857,64 +840,11 @@ def _apply_sandbox_tools(
|
||||
return
|
||||
os.environ["SHIPYARD_ENDPOINT"] = ep
|
||||
os.environ["SHIPYARD_ACCESS_TOKEN"] = at
|
||||
|
||||
req.func_tool.add_tool(EXECUTE_SHELL_TOOL)
|
||||
req.func_tool.add_tool(PYTHON_TOOL)
|
||||
req.func_tool.add_tool(FILE_UPLOAD_TOOL)
|
||||
req.func_tool.add_tool(FILE_DOWNLOAD_TOOL)
|
||||
if booter == "shipyard_neo":
|
||||
# Neo-specific path rule: filesystem tools operate relative to sandbox
|
||||
# workspace root. Do not prepend "/workspace".
|
||||
req.system_prompt += (
|
||||
"\n[Shipyard Neo File Path Rule]\n"
|
||||
"When using sandbox filesystem tools (upload/download/read/write/list/delete), "
|
||||
"always pass paths relative to the sandbox workspace root. "
|
||||
"Example: use `baidu_homepage.png` instead of `/workspace/baidu_homepage.png`.\n"
|
||||
)
|
||||
|
||||
req.system_prompt += (
|
||||
"\n[Neo Skill Lifecycle Workflow]\n"
|
||||
"When user asks to create/update a reusable skill in Neo mode, use lifecycle tools instead of directly writing local skill folders.\n"
|
||||
"Preferred sequence:\n"
|
||||
"1) Use `astrbot_create_skill_payload` to store canonical payload content and get `payload_ref`.\n"
|
||||
"2) Use `astrbot_create_skill_candidate` with `skill_key` + `source_execution_ids` (and optional `payload_ref`) to create a candidate.\n"
|
||||
"3) Use `astrbot_promote_skill_candidate` to release: `stage=canary` for trial; `stage=stable` for production.\n"
|
||||
"For stable release, set `sync_to_local=true` to sync `payload.skill_markdown` into local `SKILL.md`.\n"
|
||||
"Do not treat ad-hoc generated files as reusable Neo skills unless they are captured via payload/candidate/release.\n"
|
||||
"To update an existing skill, create a new payload/candidate and promote a new release version; avoid patching old local folders directly.\n"
|
||||
)
|
||||
|
||||
# Determine sandbox capabilities from an already-booted session.
|
||||
# If no session exists yet (first request), capabilities is None
|
||||
# and we register all tools conservatively.
|
||||
from astrbot.core.computer.computer_client import session_booter
|
||||
|
||||
sandbox_capabilities: list[str] | None = None
|
||||
existing_booter = session_booter.get(session_id)
|
||||
if existing_booter is not None:
|
||||
sandbox_capabilities = getattr(existing_booter, "capabilities", None)
|
||||
|
||||
# Browser tools: only register if profile supports browser
|
||||
# (or if capabilities are unknown because sandbox hasn't booted yet)
|
||||
if sandbox_capabilities is None or "browser" in sandbox_capabilities:
|
||||
req.func_tool.add_tool(BROWSER_EXEC_TOOL)
|
||||
req.func_tool.add_tool(BROWSER_BATCH_EXEC_TOOL)
|
||||
req.func_tool.add_tool(RUN_BROWSER_SKILL_TOOL)
|
||||
|
||||
# Neo-specific tools (always available for shipyard_neo)
|
||||
req.func_tool.add_tool(GET_EXECUTION_HISTORY_TOOL)
|
||||
req.func_tool.add_tool(ANNOTATE_EXECUTION_TOOL)
|
||||
req.func_tool.add_tool(CREATE_SKILL_PAYLOAD_TOOL)
|
||||
req.func_tool.add_tool(GET_SKILL_PAYLOAD_TOOL)
|
||||
req.func_tool.add_tool(CREATE_SKILL_CANDIDATE_TOOL)
|
||||
req.func_tool.add_tool(LIST_SKILL_CANDIDATES_TOOL)
|
||||
req.func_tool.add_tool(EVALUATE_SKILL_CANDIDATE_TOOL)
|
||||
req.func_tool.add_tool(PROMOTE_SKILL_CANDIDATE_TOOL)
|
||||
req.func_tool.add_tool(LIST_SKILL_RELEASES_TOOL)
|
||||
req.func_tool.add_tool(ROLLBACK_SKILL_RELEASE_TOOL)
|
||||
req.func_tool.add_tool(SYNC_SKILL_RELEASE_TOOL)
|
||||
|
||||
req.system_prompt = f"{req.system_prompt or ''}\n{SANDBOX_MODE_PROMPT}\n"
|
||||
req.system_prompt = f"{req.system_prompt}\n{SANDBOX_MODE_PROMPT}\n"
|
||||
|
||||
|
||||
def _proactive_cron_job_tools(req: ProviderRequest) -> None:
|
||||
|
||||
@@ -13,25 +13,11 @@ from astrbot.core.agent.tool import FunctionTool, ToolExecResult
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.computer.computer_client import get_booter
|
||||
from astrbot.core.computer.tools import (
|
||||
AnnotateExecutionTool,
|
||||
BrowserBatchExecTool,
|
||||
BrowserExecTool,
|
||||
CreateSkillCandidateTool,
|
||||
CreateSkillPayloadTool,
|
||||
EvaluateSkillCandidateTool,
|
||||
ExecuteShellTool,
|
||||
FileDownloadTool,
|
||||
FileUploadTool,
|
||||
GetExecutionHistoryTool,
|
||||
GetSkillPayloadTool,
|
||||
ListSkillCandidatesTool,
|
||||
ListSkillReleasesTool,
|
||||
LocalPythonTool,
|
||||
PromoteSkillCandidateTool,
|
||||
PythonTool,
|
||||
RollbackSkillReleaseTool,
|
||||
RunBrowserSkillTool,
|
||||
SyncSkillReleaseTool,
|
||||
)
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
@@ -463,20 +449,6 @@ PYTHON_TOOL = PythonTool()
|
||||
LOCAL_PYTHON_TOOL = LocalPythonTool()
|
||||
FILE_UPLOAD_TOOL = FileUploadTool()
|
||||
FILE_DOWNLOAD_TOOL = FileDownloadTool()
|
||||
BROWSER_EXEC_TOOL = BrowserExecTool()
|
||||
BROWSER_BATCH_EXEC_TOOL = BrowserBatchExecTool()
|
||||
RUN_BROWSER_SKILL_TOOL = RunBrowserSkillTool()
|
||||
GET_EXECUTION_HISTORY_TOOL = GetExecutionHistoryTool()
|
||||
ANNOTATE_EXECUTION_TOOL = AnnotateExecutionTool()
|
||||
CREATE_SKILL_PAYLOAD_TOOL = CreateSkillPayloadTool()
|
||||
GET_SKILL_PAYLOAD_TOOL = GetSkillPayloadTool()
|
||||
CREATE_SKILL_CANDIDATE_TOOL = CreateSkillCandidateTool()
|
||||
LIST_SKILL_CANDIDATES_TOOL = ListSkillCandidatesTool()
|
||||
EVALUATE_SKILL_CANDIDATE_TOOL = EvaluateSkillCandidateTool()
|
||||
PROMOTE_SKILL_CANDIDATE_TOOL = PromoteSkillCandidateTool()
|
||||
LIST_SKILL_RELEASES_TOOL = ListSkillReleasesTool()
|
||||
ROLLBACK_SKILL_RELEASE_TOOL = RollbackSkillReleaseTool()
|
||||
SYNC_SKILL_RELEASE_TOOL = SyncSkillReleaseTool()
|
||||
|
||||
# we prevent astrbot from connecting to known malicious hosts
|
||||
# these hosts are base64 encoded
|
||||
|
||||
@@ -1,9 +1,4 @@
|
||||
from ..olayer import (
|
||||
BrowserComponent,
|
||||
FileSystemComponent,
|
||||
PythonComponent,
|
||||
ShellComponent,
|
||||
)
|
||||
from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
|
||||
|
||||
|
||||
class ComputerBooter:
|
||||
@@ -16,19 +11,6 @@ class ComputerBooter:
|
||||
@property
|
||||
def shell(self) -> ShellComponent: ...
|
||||
|
||||
@property
|
||||
def capabilities(self) -> tuple[str, ...] | None:
|
||||
"""Sandbox capabilities (e.g. ('python', 'shell', 'filesystem', 'browser')).
|
||||
|
||||
Returns None if the booter doesn't support capability introspection
|
||||
(backward-compatible default). Subclasses override after boot.
|
||||
"""
|
||||
return None
|
||||
|
||||
@property
|
||||
def browser(self) -> BrowserComponent | None:
|
||||
return None
|
||||
|
||||
async def boot(self, session_id: str) -> None: ...
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
@@ -1,258 +0,0 @@
|
||||
"""Manage Bay container lifecycle for zero-config Shipyard Neo integration.
|
||||
|
||||
When no Bay endpoint is configured, AstrBot can automatically start a Bay
|
||||
container using the Docker socket (like BoxliteBooter does for Ship
|
||||
containers).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import tarfile
|
||||
from typing import Any
|
||||
|
||||
import aiodocker
|
||||
import aiohttp
|
||||
|
||||
from astrbot.api import logger
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
BAY_IMAGE = "ghcr.io/astrbotdevs/shipyard-neo-bay:latest"
|
||||
BAY_CONTAINER_NAME = "astrbot-bay"
|
||||
BAY_LABEL = "astrbot.bay.managed"
|
||||
BAY_PORT = 8114
|
||||
HEALTH_TIMEOUT_S = 60
|
||||
HEALTH_POLL_INTERVAL_S = 2
|
||||
|
||||
|
||||
class BayContainerManager:
|
||||
"""Start / reuse / stop a Bay container via Docker Engine API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image: str = BAY_IMAGE,
|
||||
host_port: int = BAY_PORT,
|
||||
) -> None:
|
||||
self._image = image
|
||||
self._host_port = host_port
|
||||
self._docker: aiodocker.Docker | None = None
|
||||
self._container: Any = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def ensure_running(self) -> str:
|
||||
"""Make sure a Bay container is running. Returns the endpoint URL.
|
||||
|
||||
If a container labelled ``astrbot.bay.managed`` already exists
|
||||
and is running, it will be reused. Otherwise a new container is
|
||||
created from *self._image*.
|
||||
"""
|
||||
try:
|
||||
self._docker = aiodocker.Docker()
|
||||
except Exception as exc:
|
||||
raise RuntimeError(
|
||||
"Failed to connect to Docker daemon. "
|
||||
"Ensure Docker is installed and running, or configure "
|
||||
"an explicit Bay endpoint instead of auto-start mode."
|
||||
) from exc
|
||||
|
||||
# 1. Look for an existing managed container
|
||||
existing = await self._find_managed_container()
|
||||
if existing is not None:
|
||||
state = existing["State"]
|
||||
if state.get("Running"):
|
||||
cid = existing["Id"][:12]
|
||||
logger.info("[BayManager] Reusing existing Bay container: %s", cid)
|
||||
self._container = await self._docker.containers.get(existing["Id"])
|
||||
return f"http://127.0.0.1:{self._host_port}"
|
||||
else:
|
||||
# Container exists but stopped — restart it
|
||||
logger.info("[BayManager] Restarting stopped Bay container")
|
||||
container = await self._docker.containers.get(existing["Id"])
|
||||
await container.start()
|
||||
self._container = container
|
||||
return f"http://127.0.0.1:{self._host_port}"
|
||||
|
||||
# 2. Pull image if needed
|
||||
await self._pull_image_if_needed()
|
||||
|
||||
# 3. Create and start container
|
||||
logger.info(
|
||||
"[BayManager] Starting Bay container: image=%s, port=%d",
|
||||
self._image,
|
||||
self._host_port,
|
||||
)
|
||||
config = {
|
||||
"Image": self._image,
|
||||
"Labels": {BAY_LABEL: "true"},
|
||||
"Env": [
|
||||
"BAY_SERVER__HOST=0.0.0.0",
|
||||
f"BAY_SERVER__PORT={BAY_PORT}",
|
||||
"BAY_DATA_DIR=/app/data",
|
||||
# allow_anonymous=false → auto-provisions API key
|
||||
"BAY_SECURITY__ALLOW_ANONYMOUS=false",
|
||||
],
|
||||
"HostConfig": {
|
||||
"PortBindings": {
|
||||
f"{BAY_PORT}/tcp": [{"HostPort": str(self._host_port)}],
|
||||
},
|
||||
"Binds": [
|
||||
# Bay needs Docker socket to create sandbox containers
|
||||
"/var/run/docker.sock:/var/run/docker.sock",
|
||||
],
|
||||
"RestartPolicy": {"Name": "unless-stopped"},
|
||||
},
|
||||
}
|
||||
self._container = await self._docker.containers.create_or_replace(
|
||||
BAY_CONTAINER_NAME, config
|
||||
)
|
||||
await self._container.start()
|
||||
logger.info("[BayManager] Bay container started: %s", BAY_CONTAINER_NAME)
|
||||
|
||||
return f"http://127.0.0.1:{self._host_port}"
|
||||
|
||||
async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None:
|
||||
"""Block until Bay's ``/health`` endpoint returns 200."""
|
||||
url = f"http://127.0.0.1:{self._host_port}/health"
|
||||
deadline = asyncio.get_event_loop().time() + timeout
|
||||
last_error: str = ""
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
try:
|
||||
async with session.get(
|
||||
url, timeout=aiohttp.ClientTimeout(total=3)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
logger.info("[BayManager] Bay is healthy")
|
||||
return
|
||||
last_error = f"HTTP {resp.status}"
|
||||
except Exception as exc:
|
||||
last_error = str(exc)
|
||||
|
||||
await asyncio.sleep(HEALTH_POLL_INTERVAL_S)
|
||||
|
||||
raise TimeoutError(
|
||||
f"Bay did not become healthy within {timeout}s (last error: {last_error})"
|
||||
)
|
||||
|
||||
async def read_credentials(self) -> str:
|
||||
"""Read auto-provisioned API key from Bay container.
|
||||
|
||||
Bay writes ``credentials.json`` to its data directory when
|
||||
``allow_anonymous=false`` and no explicit API key is set.
|
||||
"""
|
||||
if self._container is None:
|
||||
return ""
|
||||
|
||||
try:
|
||||
# Read credentials.json from container filesystem
|
||||
tar_stream = await self._container.get_archive("/app/data/credentials.json")
|
||||
# get_archive returns (tar_data, stat)
|
||||
tar_data = tar_stream
|
||||
|
||||
if isinstance(tar_data, dict):
|
||||
raw = tar_data.get("data", b"")
|
||||
elif isinstance(tar_data, tuple):
|
||||
# (stream, stat_info)
|
||||
raw = b""
|
||||
stream = tar_data[0]
|
||||
if hasattr(stream, "read"):
|
||||
raw = await stream.read()
|
||||
elif isinstance(stream, bytes):
|
||||
raw = stream
|
||||
else:
|
||||
# It might be a chunked response
|
||||
chunks = []
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
raw = b"".join(chunks)
|
||||
else:
|
||||
raw = tar_data if isinstance(tar_data, bytes) else b""
|
||||
|
||||
if not raw:
|
||||
logger.debug("[BayManager] Empty tar response from container")
|
||||
return ""
|
||||
|
||||
tario = io.BytesIO(raw)
|
||||
with tarfile.open(fileobj=tario) as tar:
|
||||
for member in tar.getmembers():
|
||||
f = tar.extractfile(member)
|
||||
if f:
|
||||
creds = json.loads(f.read().decode("utf-8"))
|
||||
api_key = creds.get("api_key", "")
|
||||
if api_key:
|
||||
masked = (
|
||||
f"{api_key[:8]}..."
|
||||
if len(api_key) >= 10
|
||||
else "redacted"
|
||||
)
|
||||
logger.info(
|
||||
"[BayManager] Auto-discovered Bay API key: %s",
|
||||
masked,
|
||||
)
|
||||
return api_key
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"[BayManager] Failed to read credentials from container: %s", exc
|
||||
)
|
||||
|
||||
return ""
|
||||
|
||||
async def close_client(self) -> None:
|
||||
"""Close the Docker client without stopping the container.
|
||||
|
||||
The Bay container stays running for reuse by future sessions.
|
||||
"""
|
||||
if self._docker is not None:
|
||||
await self._docker.close()
|
||||
self._docker = None
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop and remove the managed Bay container."""
|
||||
if self._container is not None:
|
||||
try:
|
||||
await self._container.stop()
|
||||
await self._container.delete(force=True)
|
||||
logger.info("[BayManager] Bay container stopped and removed")
|
||||
except Exception as exc:
|
||||
logger.debug("[BayManager] Error stopping Bay container: %s", exc)
|
||||
finally:
|
||||
self._container = None
|
||||
|
||||
await self.close_client()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _find_managed_container(self) -> dict | None:
|
||||
"""Find an existing container with our management label."""
|
||||
assert self._docker is not None
|
||||
containers = await self._docker.containers.list(
|
||||
all=True,
|
||||
filters=json.dumps({"label": [f"{BAY_LABEL}=true"]}),
|
||||
)
|
||||
if containers:
|
||||
# Inspect first match to get full state
|
||||
return await containers[0].show()
|
||||
return None
|
||||
|
||||
async def _pull_image_if_needed(self) -> None:
|
||||
"""Pull the Bay image if it doesn't exist locally."""
|
||||
assert self._docker is not None
|
||||
try:
|
||||
await self._docker.images.inspect(self._image)
|
||||
logger.debug("[BayManager] Image %s already exists", self._image)
|
||||
except aiodocker.exceptions.DockerError:
|
||||
logger.info("[BayManager] Pulling image %s ...", self._image)
|
||||
# Pull with progress logging
|
||||
await self._docker.images.pull(self._image)
|
||||
logger.info("[BayManager] Image %s pulled successfully", self._image)
|
||||
@@ -64,10 +64,6 @@ class MockShipyardSandboxClient:
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(url, data=data) as response:
|
||||
if response.status == 200:
|
||||
logger.info(
|
||||
"[Computer] File uploaded to Boxlite sandbox: %s",
|
||||
remote_path,
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "File uploaded successfully",
|
||||
|
||||
@@ -31,7 +31,7 @@ class ShipyardBooter(ComputerBooter):
|
||||
self._ship = ship
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.info("[Computer] Shipyard booter shutdown.")
|
||||
pass
|
||||
|
||||
@property
|
||||
def fs(self) -> FileSystemComponent:
|
||||
@@ -47,19 +47,11 @@ class ShipyardBooter(ComputerBooter):
|
||||
|
||||
async def upload_file(self, path: str, file_name: str) -> dict:
|
||||
"""Upload file to sandbox"""
|
||||
result = await self._ship.upload_file(path, file_name)
|
||||
logger.info("[Computer] File uploaded to Shipyard sandbox: %s", file_name)
|
||||
return result
|
||||
return await self._ship.upload_file(path, file_name)
|
||||
|
||||
async def download_file(self, remote_path: str, local_path: str):
|
||||
"""Download file from sandbox."""
|
||||
result = await self._ship.download_file(remote_path, local_path)
|
||||
logger.info(
|
||||
"[Computer] File downloaded from Shipyard sandbox: %s -> %s",
|
||||
remote_path,
|
||||
local_path,
|
||||
)
|
||||
return result
|
||||
return await self._ship.download_file(remote_path, local_path)
|
||||
|
||||
async def available(self) -> bool:
|
||||
"""Check if the sandbox is available."""
|
||||
@@ -67,17 +59,8 @@ class ShipyardBooter(ComputerBooter):
|
||||
ship_id = self._ship.id
|
||||
data = await self._sandbox_client.get_ship(ship_id)
|
||||
if not data:
|
||||
logger.info(
|
||||
"[Computer] Shipyard sandbox health check: id=%s, healthy=False (no data)",
|
||||
ship_id,
|
||||
)
|
||||
return False
|
||||
health = bool(data.get("status", 0) == 1)
|
||||
logger.info(
|
||||
"[Computer] Shipyard sandbox health check: id=%s, healthy=%s",
|
||||
ship_id,
|
||||
health,
|
||||
)
|
||||
return health
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking Shipyard sandbox availability: {e}")
|
||||
|
||||
@@ -1,513 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shlex
|
||||
from typing import Any, cast
|
||||
|
||||
from astrbot.api import logger
|
||||
|
||||
from ..olayer import (
|
||||
BrowserComponent,
|
||||
FileSystemComponent,
|
||||
PythonComponent,
|
||||
ShellComponent,
|
||||
)
|
||||
from .base import ComputerBooter
|
||||
|
||||
|
||||
def _maybe_model_dump(value: Any) -> dict[str, Any]:
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if hasattr(value, "model_dump"):
|
||||
dumped = value.model_dump()
|
||||
if isinstance(dumped, dict):
|
||||
return dumped
|
||||
return {}
|
||||
|
||||
|
||||
class NeoPythonComponent(PythonComponent):
|
||||
def __init__(self, sandbox: Any) -> None:
|
||||
self._sandbox = sandbox
|
||||
|
||||
async def exec(
|
||||
self,
|
||||
code: str,
|
||||
kernel_id: str | None = None,
|
||||
timeout: int = 30,
|
||||
silent: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
_ = kernel_id # Bay runtime does not expose kernel_id in current SDK.
|
||||
result = await self._sandbox.python.exec(code, timeout=timeout)
|
||||
payload = _maybe_model_dump(result)
|
||||
|
||||
output_text = payload.get("output", "") or ""
|
||||
error_text = payload.get("error", "") or ""
|
||||
data = payload.get("data") if isinstance(payload.get("data"), dict) else {}
|
||||
rich_output = data.get("output") if isinstance(data.get("output"), dict) else {}
|
||||
if not isinstance(rich_output.get("images"), list):
|
||||
rich_output["images"] = []
|
||||
if "text" not in rich_output:
|
||||
rich_output["text"] = output_text
|
||||
|
||||
if silent:
|
||||
rich_output["text"] = ""
|
||||
|
||||
return {
|
||||
"success": bool(payload.get("success", error_text == "")),
|
||||
"data": {
|
||||
"output": rich_output,
|
||||
"error": error_text,
|
||||
},
|
||||
"execution_id": payload.get("execution_id"),
|
||||
"execution_time_ms": payload.get("execution_time_ms"),
|
||||
"code": payload.get("code"),
|
||||
"output": output_text,
|
||||
"error": error_text,
|
||||
}
|
||||
|
||||
|
||||
class NeoShellComponent(ShellComponent):
|
||||
def __init__(self, sandbox: Any) -> None:
|
||||
self._sandbox = sandbox
|
||||
|
||||
async def exec(
|
||||
self,
|
||||
command: str,
|
||||
cwd: str | None = None,
|
||||
env: dict[str, str] | None = None,
|
||||
timeout: int | None = 30,
|
||||
shell: bool = True,
|
||||
background: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
if not shell:
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": "error: only shell mode is supported in shipyard_neo booter.",
|
||||
"exit_code": 2,
|
||||
"success": False,
|
||||
}
|
||||
|
||||
run_command = command
|
||||
if env:
|
||||
env_prefix = " ".join(
|
||||
f"{k}={shlex.quote(str(v))}" for k, v in sorted(env.items())
|
||||
)
|
||||
run_command = f"{env_prefix} {run_command}"
|
||||
|
||||
if background:
|
||||
run_command = f"nohup sh -lc {shlex.quote(run_command)} >/tmp/astrbot_bg.log 2>&1 & echo $!"
|
||||
|
||||
result = await self._sandbox.shell.exec(
|
||||
run_command,
|
||||
timeout=timeout or 30,
|
||||
cwd=cwd,
|
||||
)
|
||||
payload = _maybe_model_dump(result)
|
||||
|
||||
stdout = payload.get("output", "") or ""
|
||||
stderr = payload.get("error", "") or ""
|
||||
exit_code = payload.get("exit_code")
|
||||
if background:
|
||||
pid: int | None = None
|
||||
try:
|
||||
pid = int(stdout.strip().splitlines()[-1])
|
||||
except Exception:
|
||||
pid = None
|
||||
return {
|
||||
"pid": pid,
|
||||
"stdout": stdout,
|
||||
"stderr": stderr,
|
||||
"exit_code": exit_code,
|
||||
"success": bool(payload.get("success", not stderr)),
|
||||
"execution_id": payload.get("execution_id"),
|
||||
"execution_time_ms": payload.get("execution_time_ms"),
|
||||
"command": payload.get("command"),
|
||||
}
|
||||
|
||||
return {
|
||||
"stdout": stdout,
|
||||
"stderr": stderr,
|
||||
"exit_code": exit_code,
|
||||
"success": bool(payload.get("success", not stderr)),
|
||||
"execution_id": payload.get("execution_id"),
|
||||
"execution_time_ms": payload.get("execution_time_ms"),
|
||||
"command": payload.get("command"),
|
||||
}
|
||||
|
||||
|
||||
class NeoFileSystemComponent(FileSystemComponent):
|
||||
def __init__(self, sandbox: Any) -> None:
|
||||
self._sandbox = sandbox
|
||||
|
||||
async def create_file(
|
||||
self,
|
||||
path: str,
|
||||
content: str = "",
|
||||
mode: int = 0o644,
|
||||
) -> dict[str, Any]:
|
||||
_ = mode
|
||||
await self._sandbox.filesystem.write_file(path, content)
|
||||
return {"success": True, "path": path}
|
||||
|
||||
async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]:
|
||||
_ = encoding
|
||||
content = await self._sandbox.filesystem.read_file(path)
|
||||
return {"success": True, "path": path, "content": content}
|
||||
|
||||
async def write_file(
|
||||
self,
|
||||
path: str,
|
||||
content: str,
|
||||
mode: str = "w",
|
||||
encoding: str = "utf-8",
|
||||
) -> dict[str, Any]:
|
||||
_ = mode
|
||||
_ = encoding
|
||||
await self._sandbox.filesystem.write_file(path, content)
|
||||
return {"success": True, "path": path}
|
||||
|
||||
async def delete_file(self, path: str) -> dict[str, Any]:
|
||||
await self._sandbox.filesystem.delete(path)
|
||||
return {"success": True, "path": path}
|
||||
|
||||
async def list_dir(
|
||||
self,
|
||||
path: str = ".",
|
||||
show_hidden: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
entries = await self._sandbox.filesystem.list_dir(path)
|
||||
data = []
|
||||
for entry in entries:
|
||||
item = _maybe_model_dump(entry)
|
||||
if not show_hidden and str(item.get("name", "")).startswith("."):
|
||||
continue
|
||||
data.append(item)
|
||||
return {"success": True, "path": path, "entries": data}
|
||||
|
||||
|
||||
class NeoBrowserComponent(BrowserComponent):
|
||||
def __init__(self, sandbox: Any) -> None:
|
||||
self._sandbox = sandbox
|
||||
|
||||
async def exec(
|
||||
self,
|
||||
cmd: str,
|
||||
timeout: int = 30,
|
||||
description: str | None = None,
|
||||
tags: str | None = None,
|
||||
learn: bool = False,
|
||||
include_trace: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
result = await self._sandbox.browser.exec(
|
||||
cmd,
|
||||
timeout=timeout,
|
||||
description=description,
|
||||
tags=tags,
|
||||
learn=learn,
|
||||
include_trace=include_trace,
|
||||
)
|
||||
return _maybe_model_dump(result)
|
||||
|
||||
async def exec_batch(
|
||||
self,
|
||||
commands: list[str],
|
||||
timeout: int = 60,
|
||||
stop_on_error: bool = True,
|
||||
description: str | None = None,
|
||||
tags: str | None = None,
|
||||
learn: bool = False,
|
||||
include_trace: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
result = await self._sandbox.browser.exec_batch(
|
||||
commands,
|
||||
timeout=timeout,
|
||||
stop_on_error=stop_on_error,
|
||||
description=description,
|
||||
tags=tags,
|
||||
learn=learn,
|
||||
include_trace=include_trace,
|
||||
)
|
||||
return _maybe_model_dump(result)
|
||||
|
||||
async def run_skill(
|
||||
self,
|
||||
skill_key: str,
|
||||
timeout: int = 60,
|
||||
stop_on_error: bool = True,
|
||||
include_trace: bool = False,
|
||||
description: str | None = None,
|
||||
tags: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
result = await self._sandbox.browser.run_skill(
|
||||
skill_key=skill_key,
|
||||
timeout=timeout,
|
||||
stop_on_error=stop_on_error,
|
||||
include_trace=include_trace,
|
||||
description=description,
|
||||
tags=tags,
|
||||
)
|
||||
return _maybe_model_dump(result)
|
||||
|
||||
|
||||
class ShipyardNeoBooter(ComputerBooter):
|
||||
"""Booter backed by Shipyard Neo (Bay).
|
||||
|
||||
If *endpoint_url* is empty or set to ``"__auto__"``, Bay will be
|
||||
started automatically as a Docker container (like Boxlite does for
|
||||
Ship containers).
|
||||
"""
|
||||
|
||||
AUTO_SENTINEL = "__auto__"
|
||||
DEFAULT_PROFILE = "python-default"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint_url: str,
|
||||
access_token: str,
|
||||
profile: str = DEFAULT_PROFILE,
|
||||
ttl: int = 3600,
|
||||
) -> None:
|
||||
self._endpoint_url = endpoint_url
|
||||
self._access_token = access_token
|
||||
self._profile = profile
|
||||
self._ttl = ttl
|
||||
self._client: Any = None
|
||||
self._sandbox: Any = None
|
||||
self._bay_manager: Any = None # BayContainerManager when auto-started
|
||||
self._fs: FileSystemComponent | None = None
|
||||
self._python: PythonComponent | None = None
|
||||
self._shell: ShellComponent | None = None
|
||||
self._browser: BrowserComponent | None = None
|
||||
|
||||
@property
|
||||
def bay_client(self) -> Any:
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def sandbox(self) -> Any:
|
||||
return self._sandbox
|
||||
|
||||
@property
|
||||
def capabilities(self) -> tuple[str, ...] | None:
|
||||
"""Sandbox capabilities from the Bay profile.
|
||||
|
||||
Returns an immutable tuple after :meth:`boot`; ``None`` before boot.
|
||||
"""
|
||||
if self._sandbox is None:
|
||||
return None
|
||||
caps = getattr(self._sandbox, "capabilities", None)
|
||||
return tuple(caps) if caps is not None else None
|
||||
|
||||
@property
|
||||
def is_auto_mode(self) -> bool:
|
||||
"""True when Bay should be auto-started."""
|
||||
ep = (self._endpoint_url or "").strip()
|
||||
return not ep or ep == self.AUTO_SENTINEL
|
||||
|
||||
async def boot(self, session_id: str) -> None:
|
||||
_ = session_id
|
||||
|
||||
# --- Auto-start Bay if needed ---
|
||||
if self.is_auto_mode:
|
||||
from .bay_manager import BayContainerManager
|
||||
|
||||
# Clean up previous manager if re-booting
|
||||
if self._bay_manager is not None:
|
||||
await self._bay_manager.close_client()
|
||||
|
||||
logger.info("[Computer] Neo auto-start mode: launching Bay container")
|
||||
self._bay_manager = BayContainerManager()
|
||||
self._endpoint_url = await self._bay_manager.ensure_running()
|
||||
await self._bay_manager.wait_healthy()
|
||||
# Read auto-provisioned credentials
|
||||
if not self._access_token:
|
||||
self._access_token = await self._bay_manager.read_credentials()
|
||||
logger.info("[Computer] Bay auto-started at %s", self._endpoint_url)
|
||||
|
||||
if not self._endpoint_url or not self._access_token:
|
||||
if self._bay_manager is not None:
|
||||
raise ValueError(
|
||||
"Bay container started but credentials could not be read. "
|
||||
"Ensure Bay generated credentials.json, or set access_token manually."
|
||||
)
|
||||
raise ValueError(
|
||||
"Shipyard Neo sandbox configuration is incomplete. "
|
||||
"Set endpoint (default http://127.0.0.1:8114) and access token, "
|
||||
"or ensure Bay's credentials.json is accessible for auto-discovery."
|
||||
)
|
||||
|
||||
from shipyard_neo import BayClient
|
||||
|
||||
self._client = BayClient(
|
||||
endpoint_url=self._endpoint_url,
|
||||
access_token=self._access_token,
|
||||
)
|
||||
await self._client.__aenter__()
|
||||
|
||||
# Resolve profile: user-specified > smart selection > default
|
||||
resolved_profile = await self._resolve_profile(self._client)
|
||||
|
||||
self._sandbox = await self._client.create_sandbox(
|
||||
profile=resolved_profile,
|
||||
ttl=self._ttl,
|
||||
)
|
||||
|
||||
self._fs = NeoFileSystemComponent(self._sandbox)
|
||||
self._python = NeoPythonComponent(self._sandbox)
|
||||
self._shell = NeoShellComponent(self._sandbox)
|
||||
|
||||
caps = self.capabilities or ()
|
||||
self._browser = (
|
||||
NeoBrowserComponent(self._sandbox) if "browser" in caps else None
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Got Shipyard Neo sandbox: %s (profile=%s, capabilities=%s, auto=%s)",
|
||||
self._sandbox.id,
|
||||
resolved_profile,
|
||||
list(caps),
|
||||
bool(self._bay_manager),
|
||||
)
|
||||
|
||||
async def _resolve_profile(self, client: Any) -> str:
|
||||
"""Pick the best profile for this session.
|
||||
|
||||
Resolution order:
|
||||
1. User-specified profile (non-empty, non-default) → use as-is.
|
||||
2. Query ``GET /v1/profiles`` and pick the profile with the most
|
||||
capabilities, preferring profiles that include ``"browser"``.
|
||||
3. Fall back to :attr:`DEFAULT_PROFILE`.
|
||||
|
||||
Auth errors (401/403) are re-raised immediately — they indicate a
|
||||
misconfigured token, and silently falling back would just delay the
|
||||
real failure to ``create_sandbox``.
|
||||
"""
|
||||
# User explicitly set a profile → honour it
|
||||
if self._profile and self._profile != self.DEFAULT_PROFILE:
|
||||
logger.info("[Computer] Using user-specified profile: %s", self._profile)
|
||||
return self._profile
|
||||
|
||||
# Query Bay for available profiles
|
||||
from shipyard_neo.errors import ForbiddenError, UnauthorizedError
|
||||
|
||||
try:
|
||||
profile_list = await client.list_profiles()
|
||||
profiles = profile_list.items
|
||||
except (UnauthorizedError, ForbiddenError):
|
||||
raise # auth errors must not be silenced
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[Computer] Failed to query Bay profiles, falling back to %s: %s",
|
||||
self.DEFAULT_PROFILE,
|
||||
exc,
|
||||
)
|
||||
return self.DEFAULT_PROFILE
|
||||
|
||||
if not profiles:
|
||||
return self.DEFAULT_PROFILE
|
||||
|
||||
def _score(p: Any) -> tuple[int, int]:
|
||||
"""(has_browser, capability_count) — higher is better."""
|
||||
caps = getattr(p, "capabilities", []) or []
|
||||
return (1 if "browser" in caps else 0, len(caps))
|
||||
|
||||
best = max(profiles, key=_score)
|
||||
chosen = getattr(best, "id", self.DEFAULT_PROFILE)
|
||||
|
||||
if chosen != self.DEFAULT_PROFILE:
|
||||
caps = getattr(best, "capabilities", [])
|
||||
logger.info(
|
||||
"[Computer] Auto-selected profile %s (capabilities=%s)",
|
||||
chosen,
|
||||
caps,
|
||||
)
|
||||
|
||||
return chosen
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self._client is not None:
|
||||
sandbox_id = getattr(self._sandbox, "id", "unknown")
|
||||
logger.info(
|
||||
"[Computer] Shutting down Shipyard Neo sandbox: id=%s", sandbox_id
|
||||
)
|
||||
await self._client.__aexit__(None, None, None)
|
||||
self._client = None
|
||||
self._sandbox = None
|
||||
logger.info("[Computer] Shipyard Neo sandbox shut down: id=%s", sandbox_id)
|
||||
|
||||
# NOTE: We intentionally do NOT stop the Bay container here.
|
||||
# It stays running for reuse by future sessions. The user can
|
||||
# stop it manually or via ``BayContainerManager.stop()``.
|
||||
if self._bay_manager is not None:
|
||||
await self._bay_manager.close_client()
|
||||
|
||||
@property
|
||||
def fs(self) -> FileSystemComponent:
|
||||
if self._fs is None:
|
||||
raise RuntimeError("ShipyardNeoBooter is not initialized.")
|
||||
return self._fs
|
||||
|
||||
@property
|
||||
def python(self) -> PythonComponent:
|
||||
if self._python is None:
|
||||
raise RuntimeError("ShipyardNeoBooter is not initialized.")
|
||||
return self._python
|
||||
|
||||
@property
|
||||
def shell(self) -> ShellComponent:
|
||||
if self._shell is None:
|
||||
raise RuntimeError("ShipyardNeoBooter is not initialized.")
|
||||
return self._shell
|
||||
|
||||
@property
|
||||
def browser(self) -> BrowserComponent:
|
||||
if self._browser is None:
|
||||
raise RuntimeError("ShipyardNeoBooter is not initialized.")
|
||||
return self._browser
|
||||
|
||||
async def upload_file(self, path: str, file_name: str) -> dict:
|
||||
if self._sandbox is None:
|
||||
raise RuntimeError("ShipyardNeoBooter is not initialized.")
|
||||
with open(path, "rb") as f:
|
||||
content = f.read()
|
||||
remote_path = file_name.lstrip("/")
|
||||
await self._sandbox.filesystem.upload(remote_path, content)
|
||||
logger.info("[Computer] File uploaded to Neo sandbox: %s", remote_path)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "File uploaded successfully",
|
||||
"file_path": remote_path,
|
||||
}
|
||||
|
||||
async def download_file(self, remote_path: str, local_path: str) -> None:
|
||||
if self._sandbox is None:
|
||||
raise RuntimeError("ShipyardNeoBooter is not initialized.")
|
||||
content = await self._sandbox.filesystem.download(remote_path.lstrip("/"))
|
||||
local_dir = os.path.dirname(local_path)
|
||||
if local_dir:
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
with open(local_path, "wb") as f:
|
||||
f.write(cast(bytes, content))
|
||||
logger.info(
|
||||
"[Computer] File downloaded from Neo sandbox: %s -> %s",
|
||||
remote_path,
|
||||
local_path,
|
||||
)
|
||||
|
||||
async def available(self) -> bool:
|
||||
if self._sandbox is None:
|
||||
return False
|
||||
try:
|
||||
await self._sandbox.refresh()
|
||||
status = getattr(self._sandbox.status, "value", str(self._sandbox.status))
|
||||
healthy = status not in {"failed", "expired"}
|
||||
logger.info(
|
||||
"[Computer] Neo sandbox health check: id=%s, status=%s, healthy=%s",
|
||||
getattr(self._sandbox, "id", "unknown"),
|
||||
status,
|
||||
healthy,
|
||||
)
|
||||
return healthy
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking Shipyard Neo sandbox availability: {e}")
|
||||
return False
|
||||
@@ -1,11 +1,10 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.skills.skill_manager import SANDBOX_SKILLS_ROOT, SkillManager
|
||||
from astrbot.core.skills.skill_manager import SANDBOX_SKILLS_ROOT
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_skills_path,
|
||||
@@ -17,401 +16,45 @@ from .booters.local import LocalBooter
|
||||
|
||||
session_booter: dict[str, ComputerBooter] = {}
|
||||
local_booter: ComputerBooter | None = None
|
||||
_MANAGED_SKILLS_FILE = ".astrbot_managed_skills.json"
|
||||
|
||||
|
||||
def _list_local_skill_dirs(skills_root: Path) -> list[Path]:
|
||||
skills: list[Path] = []
|
||||
for entry in sorted(skills_root.iterdir()):
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
skill_md = entry / "SKILL.md"
|
||||
if skill_md.exists():
|
||||
skills.append(entry)
|
||||
return skills
|
||||
|
||||
|
||||
def _discover_bay_credentials(endpoint: str) -> str:
|
||||
"""Try to auto-discover Bay API key from credentials.json.
|
||||
|
||||
Search order:
|
||||
1. BAY_DATA_DIR env var
|
||||
2. Mono-repo relative path: ../pkgs/bay/ (dev layout)
|
||||
3. Current working directory
|
||||
|
||||
Returns:
|
||||
API key string, or empty string if not found.
|
||||
"""
|
||||
candidates: list[Path] = []
|
||||
|
||||
# 1. BAY_DATA_DIR env var
|
||||
bay_data_dir = os.environ.get("BAY_DATA_DIR")
|
||||
if bay_data_dir:
|
||||
candidates.append(Path(bay_data_dir) / "credentials.json")
|
||||
|
||||
# 2. Mono-repo layout: AstrBot/../pkgs/bay/credentials.json
|
||||
astrbot_root = Path(__file__).resolve().parents[3] # astrbot/core/computer/ → root
|
||||
candidates.append(astrbot_root.parent / "pkgs" / "bay" / "credentials.json")
|
||||
|
||||
# 3. Current working directory
|
||||
candidates.append(Path.cwd() / "credentials.json")
|
||||
|
||||
for cred_path in candidates:
|
||||
if not cred_path.is_file():
|
||||
continue
|
||||
try:
|
||||
data = json.loads(cred_path.read_text())
|
||||
api_key = data.get("api_key", "")
|
||||
if api_key:
|
||||
# Optionally verify endpoint matches
|
||||
cred_endpoint = data.get("endpoint", "")
|
||||
if (
|
||||
cred_endpoint
|
||||
and endpoint
|
||||
and cred_endpoint.rstrip("/") != endpoint.rstrip("/")
|
||||
):
|
||||
logger.warning(
|
||||
"[Computer] credentials.json endpoint mismatch: "
|
||||
"file=%s, configured=%s — using key anyway",
|
||||
cred_endpoint,
|
||||
endpoint,
|
||||
)
|
||||
masked_key = f"{api_key[:4]}..." if len(api_key) >= 6 else "redacted"
|
||||
logger.info(
|
||||
"[Computer] Auto-discovered Bay API key from %s (prefix=%s)",
|
||||
cred_path,
|
||||
masked_key,
|
||||
)
|
||||
return api_key
|
||||
except (json.JSONDecodeError, OSError) as exc:
|
||||
logger.debug("[Computer] Failed to read %s: %s", cred_path, exc)
|
||||
|
||||
logger.debug("[Computer] No Bay credentials.json found in search paths")
|
||||
return ""
|
||||
|
||||
|
||||
def _build_python_exec_command(script: str) -> str:
|
||||
return (
|
||||
"if command -v python3 >/dev/null 2>&1; then PYBIN=python3; "
|
||||
"elif command -v python >/dev/null 2>&1; then PYBIN=python; "
|
||||
"else echo 'python not found in sandbox' >&2; exit 127; fi; "
|
||||
"$PYBIN - <<'PY'\n"
|
||||
f"{script}\n"
|
||||
"PY"
|
||||
)
|
||||
|
||||
|
||||
def _build_apply_sync_command() -> str:
|
||||
"""Build shell command for sync stage only.
|
||||
|
||||
This stage mutates sandbox files (managed skill replacement) but does not scan
|
||||
metadata. Keeping it separate allows callers to preserve old behavior while
|
||||
reusing the apply step independently.
|
||||
"""
|
||||
script = f"""
|
||||
import json
|
||||
import shutil
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
root = Path({SANDBOX_SKILLS_ROOT!r})
|
||||
zip_path = root / "skills.zip"
|
||||
tmp_extract = Path(f"{{root}}_tmp_extract")
|
||||
managed_file = root / {_MANAGED_SKILLS_FILE!r}
|
||||
|
||||
|
||||
def remove_tree(path: Path) -> None:
|
||||
if not path.exists():
|
||||
return
|
||||
if path.is_dir():
|
||||
shutil.rmtree(path, ignore_errors=True)
|
||||
else:
|
||||
path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def load_managed_skills() -> list[str]:
|
||||
if not managed_file.exists():
|
||||
return []
|
||||
try:
|
||||
payload = json.loads(managed_file.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
return []
|
||||
if not isinstance(payload, dict):
|
||||
return []
|
||||
items = payload.get("managed_skills", [])
|
||||
if not isinstance(items, list):
|
||||
return []
|
||||
result: list[str] = []
|
||||
for item in items:
|
||||
if isinstance(item, str) and item.strip():
|
||||
result.append(item.strip())
|
||||
return result
|
||||
|
||||
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
for managed_name in load_managed_skills():
|
||||
remove_tree(root / managed_name)
|
||||
|
||||
current_managed: list[str] = []
|
||||
if zip_path.exists():
|
||||
remove_tree(tmp_extract)
|
||||
tmp_extract.mkdir(parents=True, exist_ok=True)
|
||||
with zipfile.ZipFile(zip_path) as zf:
|
||||
zf.extractall(tmp_extract)
|
||||
for entry in sorted(tmp_extract.iterdir()):
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
target = root / entry.name
|
||||
remove_tree(target)
|
||||
shutil.copytree(entry, target)
|
||||
current_managed.append(entry.name)
|
||||
|
||||
remove_tree(tmp_extract)
|
||||
remove_tree(zip_path)
|
||||
managed_file.write_text(
|
||||
json.dumps({{"managed_skills": current_managed}}, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
print(json.dumps({{"managed_skills": current_managed}}, ensure_ascii=False))
|
||||
""".strip()
|
||||
return _build_python_exec_command(script)
|
||||
|
||||
|
||||
def _build_scan_command() -> str:
|
||||
"""Build shell command for scan stage only.
|
||||
|
||||
This stage is read-oriented: it scans SKILL.md metadata and returns the
|
||||
historical payload shape consumed by cache update logic.
|
||||
|
||||
The scan resolves the absolute path of the skills root at runtime so
|
||||
that the LLM can reliably ``cat`` skill files regardless of cwd.
|
||||
Only the ``description`` field is extracted from frontmatter.
|
||||
"""
|
||||
script = f"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
root = Path({SANDBOX_SKILLS_ROOT!r})
|
||||
managed_file = root / {_MANAGED_SKILLS_FILE!r}
|
||||
|
||||
# Resolve absolute path at runtime so prompts always have a reliable path
|
||||
root_abs = str(root.resolve())
|
||||
|
||||
|
||||
# NOTE: This parser mirrors skill_manager._parse_frontmatter_description.
|
||||
# Keep the two implementations in sync when changing parsing logic.
|
||||
def parse_description(text: str) -> str:
|
||||
if not text.startswith("---"):
|
||||
return ""
|
||||
lines = text.splitlines()
|
||||
if not lines or lines[0].strip() != "---":
|
||||
return ""
|
||||
end_idx = None
|
||||
for i in range(1, len(lines)):
|
||||
if lines[i].strip() == "---":
|
||||
end_idx = i
|
||||
break
|
||||
if end_idx is None:
|
||||
return ""
|
||||
for line in lines[1:end_idx]:
|
||||
if ":" not in line:
|
||||
continue
|
||||
key, value = line.split(":", 1)
|
||||
if key.strip().lower() == "description":
|
||||
return value.strip().strip('"').strip("'")
|
||||
return ""
|
||||
|
||||
|
||||
def load_managed_skills() -> list[str]:
|
||||
if not managed_file.exists():
|
||||
return []
|
||||
try:
|
||||
payload = json.loads(managed_file.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
return []
|
||||
if not isinstance(payload, dict):
|
||||
return []
|
||||
items = payload.get("managed_skills", [])
|
||||
if not isinstance(items, list):
|
||||
return []
|
||||
result: list[str] = []
|
||||
for item in items:
|
||||
if isinstance(item, str) and item.strip():
|
||||
result.append(item.strip())
|
||||
return result
|
||||
|
||||
|
||||
def collect_skills() -> list[dict[str, str]]:
|
||||
skills: list[dict[str, str]] = []
|
||||
if not root.exists():
|
||||
return skills
|
||||
for skill_dir in sorted(root.iterdir()):
|
||||
if not skill_dir.is_dir():
|
||||
continue
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
if not skill_md.is_file():
|
||||
continue
|
||||
description = ""
|
||||
try:
|
||||
text = skill_md.read_text(encoding="utf-8")
|
||||
description = parse_description(text)
|
||||
except Exception:
|
||||
description = ""
|
||||
skills.append(
|
||||
{{
|
||||
"name": skill_dir.name,
|
||||
"description": description,
|
||||
"path": f"{{root_abs}}/{{skill_dir.name}}/SKILL.md",
|
||||
}}
|
||||
)
|
||||
return skills
|
||||
|
||||
|
||||
print(
|
||||
json.dumps(
|
||||
{{
|
||||
"managed_skills": load_managed_skills(),
|
||||
"skills": collect_skills(),
|
||||
}},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
""".strip()
|
||||
return _build_python_exec_command(script)
|
||||
|
||||
|
||||
def _build_sync_and_scan_command() -> str:
|
||||
"""Legacy combined command kept for backward compatibility.
|
||||
|
||||
New code paths should prefer apply + scan split helpers.
|
||||
"""
|
||||
return f"{_build_apply_sync_command()}\n{_build_scan_command()}"
|
||||
|
||||
|
||||
def _shell_exec_succeeded(result: dict) -> bool:
|
||||
if "success" in result:
|
||||
return bool(result.get("success"))
|
||||
exit_code = result.get("exit_code")
|
||||
return exit_code in (0, None)
|
||||
|
||||
|
||||
def _format_exec_error_detail(result: dict) -> str:
|
||||
"""Format shell execution details for better observability.
|
||||
|
||||
Keep the message compact while still surfacing exit code and stderr/stdout.
|
||||
"""
|
||||
exit_code = result.get("exit_code")
|
||||
stderr = str(result.get("stderr", "") or "").strip()
|
||||
stdout = str(result.get("stdout", "") or "").strip()
|
||||
stderr_text = stderr[:500]
|
||||
stdout_text = stdout[:300]
|
||||
return f"exit_code={exit_code}, stderr={stderr_text!r}, stdout_tail={stdout_text!r}"
|
||||
|
||||
|
||||
def _decode_sync_payload(stdout: str) -> dict | None:
|
||||
text = stdout.strip()
|
||||
if not text:
|
||||
return None
|
||||
candidates = [text]
|
||||
candidates.extend([line.strip() for line in text.splitlines() if line.strip()])
|
||||
for candidate in reversed(candidates):
|
||||
try:
|
||||
payload = json.loads(candidate)
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(payload, dict):
|
||||
return payload
|
||||
return None
|
||||
|
||||
|
||||
def _update_sandbox_skills_cache(payload: dict | None) -> None:
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
skills = payload.get("skills", [])
|
||||
if not isinstance(skills, list):
|
||||
return
|
||||
SkillManager().set_sandbox_skills_cache(skills)
|
||||
|
||||
|
||||
async def _apply_skills_to_sandbox(booter: ComputerBooter) -> None:
|
||||
"""Apply local skill bundle to sandbox filesystem only.
|
||||
|
||||
This function is intentionally limited to file mutation. Metadata scanning is
|
||||
executed in a separate phase to keep failure domains clear.
|
||||
"""
|
||||
logger.info("[Computer] Skill sync phase=apply start")
|
||||
apply_result = await booter.shell.exec(_build_apply_sync_command())
|
||||
if not _shell_exec_succeeded(apply_result):
|
||||
detail = _format_exec_error_detail(apply_result)
|
||||
logger.error("[Computer] Skill sync phase=apply failed: %s", detail)
|
||||
raise RuntimeError(f"Failed to apply sandbox skill sync strategy: {detail}")
|
||||
logger.info("[Computer] Skill sync phase=apply done")
|
||||
|
||||
|
||||
async def _scan_sandbox_skills(booter: ComputerBooter) -> dict | None:
|
||||
"""Scan sandbox skills and return normalized payload for cache update."""
|
||||
logger.info("[Computer] Skill sync phase=scan start")
|
||||
scan_result = await booter.shell.exec(_build_scan_command())
|
||||
if not _shell_exec_succeeded(scan_result):
|
||||
detail = _format_exec_error_detail(scan_result)
|
||||
logger.error("[Computer] Skill sync phase=scan failed: %s", detail)
|
||||
raise RuntimeError(f"Failed to scan sandbox skills after sync: {detail}")
|
||||
|
||||
payload = _decode_sync_payload(str(scan_result.get("stdout", "") or ""))
|
||||
if payload is None:
|
||||
logger.warning("[Computer] Skill sync phase=scan returned empty payload")
|
||||
else:
|
||||
logger.info("[Computer] Skill sync phase=scan done")
|
||||
return payload
|
||||
|
||||
|
||||
async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None:
|
||||
"""Sync local skills to sandbox and refresh cache.
|
||||
|
||||
Backward-compatible orchestrator: keep historical behavior while internally
|
||||
splitting into `apply` and `scan` phases.
|
||||
"""
|
||||
skills_root = Path(get_astrbot_skills_path())
|
||||
if not skills_root.is_dir():
|
||||
skills_root = get_astrbot_skills_path()
|
||||
if not os.path.isdir(skills_root):
|
||||
return
|
||||
if not any(Path(skills_root).iterdir()):
|
||||
return
|
||||
local_skill_dirs = _list_local_skill_dirs(skills_root)
|
||||
|
||||
temp_dir = Path(get_astrbot_temp_path())
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
zip_base = temp_dir / "skills_bundle"
|
||||
zip_path = zip_base.with_suffix(".zip")
|
||||
temp_dir = get_astrbot_temp_path()
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
zip_base = os.path.join(temp_dir, "skills_bundle")
|
||||
zip_path = f"{zip_base}.zip"
|
||||
|
||||
try:
|
||||
if local_skill_dirs:
|
||||
if zip_path.exists():
|
||||
zip_path.unlink()
|
||||
shutil.make_archive(str(zip_base), "zip", str(skills_root))
|
||||
remote_zip = Path(SANDBOX_SKILLS_ROOT) / "skills.zip"
|
||||
logger.info("Uploading skills bundle to sandbox...")
|
||||
await booter.shell.exec(f"mkdir -p {SANDBOX_SKILLS_ROOT}")
|
||||
upload_result = await booter.upload_file(str(zip_path), str(remote_zip))
|
||||
if not upload_result.get("success", False):
|
||||
raise RuntimeError("Failed to upload skills bundle to sandbox.")
|
||||
else:
|
||||
logger.info(
|
||||
"No local skills found. Keeping sandbox built-ins and refreshing metadata."
|
||||
)
|
||||
await booter.shell.exec(f"rm -f {SANDBOX_SKILLS_ROOT}/skills.zip")
|
||||
|
||||
# Keep backward-compatible behavior while splitting lifecycle into two
|
||||
# observable phases: apply (filesystem mutation) + scan (metadata read).
|
||||
await _apply_skills_to_sandbox(booter)
|
||||
payload = await _scan_sandbox_skills(booter)
|
||||
_update_sandbox_skills_cache(payload)
|
||||
managed = payload.get("managed_skills", []) if isinstance(payload, dict) else []
|
||||
logger.info(
|
||||
"[Computer] Sandbox skill sync complete: managed=%d",
|
||||
len(managed),
|
||||
if os.path.exists(zip_path):
|
||||
os.remove(zip_path)
|
||||
shutil.make_archive(zip_base, "zip", skills_root)
|
||||
remote_zip = Path(SANDBOX_SKILLS_ROOT) / "skills.zip"
|
||||
logger.info("Uploading skills bundle to sandbox...")
|
||||
await booter.shell.exec(f"mkdir -p {SANDBOX_SKILLS_ROOT}")
|
||||
upload_result = await booter.upload_file(zip_path, str(remote_zip))
|
||||
if not upload_result.get("success", False):
|
||||
raise RuntimeError("Failed to upload skills bundle to sandbox.")
|
||||
# Use -n flag to never overwrite existing files, fallback to Python if unzip unavailable
|
||||
await booter.shell.exec(
|
||||
f"unzip -n {remote_zip} -d {SANDBOX_SKILLS_ROOT} || "
|
||||
f"python3 -c \"import zipfile, os, pathlib; z=zipfile.ZipFile('{remote_zip}'); "
|
||||
f"[z.extract(m, '{SANDBOX_SKILLS_ROOT}') for m in z.namelist() "
|
||||
f"if not os.path.exists(os.path.join('{SANDBOX_SKILLS_ROOT}', m))]\" || "
|
||||
f"python -c \"import zipfile, os, pathlib; z=zipfile.ZipFile('{remote_zip}'); "
|
||||
f"[z.extract(m, '{SANDBOX_SKILLS_ROOT}') for m in z.namelist() "
|
||||
f"if not os.path.exists(os.path.join('{SANDBOX_SKILLS_ROOT}', m))]\"; "
|
||||
f"rm -f {remote_zip}"
|
||||
)
|
||||
finally:
|
||||
if zip_path.exists():
|
||||
if os.path.exists(zip_path):
|
||||
try:
|
||||
zip_path.unlink()
|
||||
os.remove(zip_path)
|
||||
except Exception:
|
||||
logger.warning(f"Failed to remove temp skills zip: {zip_path}")
|
||||
|
||||
@@ -423,7 +66,7 @@ async def get_booter(
|
||||
config = context.get_config(umo=session_id)
|
||||
|
||||
sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {})
|
||||
booter_type = sandbox_cfg.get("booter", "shipyard_neo")
|
||||
booter_type = sandbox_cfg.get("booter", "shipyard")
|
||||
|
||||
if session_id in session_booter:
|
||||
booter = session_booter[session_id]
|
||||
@@ -432,9 +75,6 @@ async def get_booter(
|
||||
session_booter.pop(session_id, None)
|
||||
if session_id not in session_booter:
|
||||
uuid_str = uuid.uuid5(uuid.NAMESPACE_DNS, session_id).hex
|
||||
logger.info(
|
||||
f"[Computer] Initializing booter: type={booter_type}, session={session_id}"
|
||||
)
|
||||
if booter_type == "shipyard":
|
||||
from .booters.shipyard import ShipyardBooter
|
||||
|
||||
@@ -446,27 +86,6 @@ async def get_booter(
|
||||
client = ShipyardBooter(
|
||||
endpoint_url=ep, access_token=token, ttl=ttl, session_num=max_sessions
|
||||
)
|
||||
elif booter_type == "shipyard_neo":
|
||||
from .booters.shipyard_neo import ShipyardNeoBooter
|
||||
|
||||
ep = sandbox_cfg.get("shipyard_neo_endpoint", "")
|
||||
token = sandbox_cfg.get("shipyard_neo_access_token", "")
|
||||
ttl = sandbox_cfg.get("shipyard_neo_ttl", 3600)
|
||||
profile = sandbox_cfg.get("shipyard_neo_profile", "python-default")
|
||||
|
||||
# Auto-discover token from Bay's credentials.json if not configured
|
||||
if not token:
|
||||
token = _discover_bay_credentials(ep)
|
||||
|
||||
logger.info(
|
||||
f"[Computer] Shipyard Neo config: endpoint={ep}, profile={profile}, ttl={ttl}"
|
||||
)
|
||||
client = ShipyardNeoBooter(
|
||||
endpoint_url=ep,
|
||||
access_token=token,
|
||||
profile=profile,
|
||||
ttl=ttl,
|
||||
)
|
||||
elif booter_type == "boxlite":
|
||||
from .booters.boxlite import BoxliteBooter
|
||||
|
||||
@@ -476,9 +95,6 @@ async def get_booter(
|
||||
|
||||
try:
|
||||
await client.boot(uuid_str)
|
||||
logger.info(
|
||||
f"[Computer] Sandbox booted successfully: type={booter_type}, session={session_id}"
|
||||
)
|
||||
await _sync_skills_to_sandbox(client)
|
||||
except Exception as e:
|
||||
logger.error(f"Error booting sandbox for session {session_id}: {e}")
|
||||
@@ -488,24 +104,6 @@ async def get_booter(
|
||||
return session_booter[session_id]
|
||||
|
||||
|
||||
async def sync_skills_to_active_sandboxes() -> None:
|
||||
"""Best-effort skills synchronization for all active sandbox sessions."""
|
||||
logger.info(
|
||||
"[Computer] Syncing skills to %d active sandbox(es)", len(session_booter)
|
||||
)
|
||||
for session_id, booter in list(session_booter.items()):
|
||||
try:
|
||||
if not await booter.available():
|
||||
continue
|
||||
await _sync_skills_to_sandbox(booter)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to sync skills to sandbox for session %s: %s",
|
||||
session_id,
|
||||
e,
|
||||
)
|
||||
|
||||
|
||||
def get_local_booter() -> ComputerBooter:
|
||||
global local_booter
|
||||
if local_booter is None:
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
from .browser import BrowserComponent
|
||||
from .filesystem import FileSystemComponent
|
||||
from .python import PythonComponent
|
||||
from .shell import ShellComponent
|
||||
|
||||
__all__ = [
|
||||
"PythonComponent",
|
||||
"ShellComponent",
|
||||
"FileSystemComponent",
|
||||
"BrowserComponent",
|
||||
]
|
||||
__all__ = ["PythonComponent", "ShellComponent", "FileSystemComponent"]
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
"""
|
||||
Browser automation component
|
||||
"""
|
||||
|
||||
from typing import Any, Protocol
|
||||
|
||||
|
||||
class BrowserComponent(Protocol):
|
||||
"""Browser operations component"""
|
||||
|
||||
async def exec(
|
||||
self,
|
||||
cmd: str,
|
||||
timeout: int = 30,
|
||||
description: str | None = None,
|
||||
tags: str | None = None,
|
||||
learn: bool = False,
|
||||
include_trace: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute a browser automation command"""
|
||||
...
|
||||
|
||||
async def exec_batch(
|
||||
self,
|
||||
commands: list[str],
|
||||
timeout: int = 60,
|
||||
stop_on_error: bool = True,
|
||||
description: str | None = None,
|
||||
tags: str | None = None,
|
||||
learn: bool = False,
|
||||
include_trace: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute a browser automation command batch"""
|
||||
...
|
||||
|
||||
async def run_skill(
|
||||
self,
|
||||
skill_key: str,
|
||||
timeout: int = 60,
|
||||
stop_on_error: bool = True,
|
||||
include_trace: bool = False,
|
||||
description: str | None = None,
|
||||
tags: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Run a browser skill by skill key"""
|
||||
...
|
||||
@@ -1,36 +1,8 @@
|
||||
from .browser import BrowserBatchExecTool, BrowserExecTool, RunBrowserSkillTool
|
||||
from .fs import FileDownloadTool, FileUploadTool
|
||||
from .neo_skills import (
|
||||
AnnotateExecutionTool,
|
||||
CreateSkillCandidateTool,
|
||||
CreateSkillPayloadTool,
|
||||
EvaluateSkillCandidateTool,
|
||||
GetExecutionHistoryTool,
|
||||
GetSkillPayloadTool,
|
||||
ListSkillCandidatesTool,
|
||||
ListSkillReleasesTool,
|
||||
PromoteSkillCandidateTool,
|
||||
RollbackSkillReleaseTool,
|
||||
SyncSkillReleaseTool,
|
||||
)
|
||||
from .python import LocalPythonTool, PythonTool
|
||||
from .shell import ExecuteShellTool
|
||||
|
||||
__all__ = [
|
||||
"BrowserExecTool",
|
||||
"BrowserBatchExecTool",
|
||||
"RunBrowserSkillTool",
|
||||
"GetExecutionHistoryTool",
|
||||
"AnnotateExecutionTool",
|
||||
"CreateSkillPayloadTool",
|
||||
"GetSkillPayloadTool",
|
||||
"CreateSkillCandidateTool",
|
||||
"ListSkillCandidatesTool",
|
||||
"EvaluateSkillCandidateTool",
|
||||
"PromoteSkillCandidateTool",
|
||||
"ListSkillReleasesTool",
|
||||
"RollbackSkillReleaseTool",
|
||||
"SyncSkillReleaseTool",
|
||||
"FileUploadTool",
|
||||
"PythonTool",
|
||||
"LocalPythonTool",
|
||||
|
||||
@@ -1,204 +0,0 @@
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from astrbot.api import FunctionTool
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import ToolExecResult
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
|
||||
from ..computer_client import get_booter
|
||||
|
||||
|
||||
def _to_json(data: Any) -> str:
|
||||
return json.dumps(data, ensure_ascii=False, default=str)
|
||||
|
||||
|
||||
def _ensure_admin(context: ContextWrapper[AstrAgentContext]) -> str | None:
|
||||
if context.context.event.role != "admin":
|
||||
return (
|
||||
"error: Permission denied. Browser and skill lifecycle tools are only allowed "
|
||||
"for admin users."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def _get_browser_component(context: ContextWrapper[AstrAgentContext]) -> Any:
|
||||
booter = await get_booter(
|
||||
context.context.context,
|
||||
context.context.event.unified_msg_origin,
|
||||
)
|
||||
browser = getattr(booter, "browser", None)
|
||||
if browser is None:
|
||||
raise RuntimeError(
|
||||
"Current sandbox booter does not support browser capability. "
|
||||
"Please switch to shipyard_neo."
|
||||
)
|
||||
return browser
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrowserExecTool(FunctionTool):
|
||||
name: str = "astrbot_execute_browser"
|
||||
description: str = "Execute one browser automation command in the sandbox."
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"cmd": {"type": "string", "description": "Browser command to execute."},
|
||||
"timeout": {"type": "integer", "default": 30},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Optional execution description.",
|
||||
},
|
||||
"tags": {"type": "string", "description": "Optional tags."},
|
||||
"learn": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to mark execution as learn evidence.",
|
||||
"default": False,
|
||||
},
|
||||
"include_trace": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to include trace_ref in response.",
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
"required": ["cmd"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
cmd: str,
|
||||
timeout: int = 30,
|
||||
description: str | None = None,
|
||||
tags: str | None = None,
|
||||
learn: bool = False,
|
||||
include_trace: bool = False,
|
||||
) -> ToolExecResult:
|
||||
if err := _ensure_admin(context):
|
||||
return err
|
||||
try:
|
||||
browser = await _get_browser_component(context)
|
||||
result = await browser.exec(
|
||||
cmd=cmd,
|
||||
timeout=timeout,
|
||||
description=description,
|
||||
tags=tags,
|
||||
learn=learn,
|
||||
include_trace=include_trace,
|
||||
)
|
||||
return _to_json(result)
|
||||
except Exception as e:
|
||||
return f"Error executing browser command: {str(e)}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrowserBatchExecTool(FunctionTool):
|
||||
name: str = "astrbot_execute_browser_batch"
|
||||
description: str = "Execute a browser command batch in the sandbox."
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"commands": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Ordered browser commands.",
|
||||
},
|
||||
"timeout": {"type": "integer", "default": 60},
|
||||
"stop_on_error": {"type": "boolean", "default": True},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Optional execution description.",
|
||||
},
|
||||
"tags": {"type": "string", "description": "Optional tags."},
|
||||
"learn": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to mark execution as learn evidence.",
|
||||
"default": False,
|
||||
},
|
||||
"include_trace": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to include trace_ref in response.",
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
"required": ["commands"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
commands: list[str],
|
||||
timeout: int = 60,
|
||||
stop_on_error: bool = True,
|
||||
description: str | None = None,
|
||||
tags: str | None = None,
|
||||
learn: bool = False,
|
||||
include_trace: bool = False,
|
||||
) -> ToolExecResult:
|
||||
if err := _ensure_admin(context):
|
||||
return err
|
||||
try:
|
||||
browser = await _get_browser_component(context)
|
||||
result = await browser.exec_batch(
|
||||
commands=commands,
|
||||
timeout=timeout,
|
||||
stop_on_error=stop_on_error,
|
||||
description=description,
|
||||
tags=tags,
|
||||
learn=learn,
|
||||
include_trace=include_trace,
|
||||
)
|
||||
return _to_json(result)
|
||||
except Exception as e:
|
||||
return f"Error executing browser batch command: {str(e)}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunBrowserSkillTool(FunctionTool):
|
||||
name: str = "astrbot_run_browser_skill"
|
||||
description: str = "Run a released browser skill in the sandbox by skill_key."
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"skill_key": {"type": "string"},
|
||||
"timeout": {"type": "integer", "default": 60},
|
||||
"stop_on_error": {"type": "boolean", "default": True},
|
||||
"include_trace": {"type": "boolean", "default": False},
|
||||
"description": {"type": "string"},
|
||||
"tags": {"type": "string"},
|
||||
},
|
||||
"required": ["skill_key"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
skill_key: str,
|
||||
timeout: int = 60,
|
||||
stop_on_error: bool = True,
|
||||
include_trace: bool = False,
|
||||
description: str | None = None,
|
||||
tags: str | None = None,
|
||||
) -> ToolExecResult:
|
||||
if err := _ensure_admin(context):
|
||||
return err
|
||||
try:
|
||||
browser = await _get_browser_component(context)
|
||||
result = await browser.run_skill(
|
||||
skill_key=skill_key,
|
||||
timeout=timeout,
|
||||
stop_on_error=stop_on_error,
|
||||
include_trace=include_trace,
|
||||
description=description,
|
||||
tags=tags,
|
||||
)
|
||||
return _to_json(result)
|
||||
except Exception as e:
|
||||
return f"Error running browser skill: {str(e)}"
|
||||
@@ -1,542 +0,0 @@
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from astrbot.api import FunctionTool
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import ToolExecResult
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager
|
||||
|
||||
from ..computer_client import get_booter
|
||||
|
||||
|
||||
def _to_jsonable(model_like: Any) -> Any:
|
||||
if isinstance(model_like, dict):
|
||||
return model_like
|
||||
if isinstance(model_like, list):
|
||||
return [_to_jsonable(i) for i in model_like]
|
||||
if hasattr(model_like, "model_dump"):
|
||||
return _to_jsonable(model_like.model_dump())
|
||||
return model_like
|
||||
|
||||
|
||||
def _to_json_text(data: Any) -> str:
|
||||
return json.dumps(_to_jsonable(data), ensure_ascii=False, default=str)
|
||||
|
||||
|
||||
def _ensure_admin(context: ContextWrapper[AstrAgentContext]) -> str | None:
|
||||
if context.context.event.role != "admin":
|
||||
return "error: Permission denied. Skill lifecycle tools are only allowed for admin users."
|
||||
return None
|
||||
|
||||
|
||||
async def _get_neo_context(
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
) -> tuple[Any, Any]:
|
||||
booter = await get_booter(
|
||||
context.context.context,
|
||||
context.context.event.unified_msg_origin,
|
||||
)
|
||||
client = getattr(booter, "bay_client", None)
|
||||
sandbox = getattr(booter, "sandbox", None)
|
||||
if client is None or sandbox is None:
|
||||
raise RuntimeError(
|
||||
"Current sandbox booter does not support Neo skill lifecycle APIs. "
|
||||
"Please switch to shipyard_neo."
|
||||
)
|
||||
return client, sandbox
|
||||
|
||||
|
||||
@dataclass
|
||||
class NeoSkillToolBase(FunctionTool):
|
||||
error_prefix: str = "Error"
|
||||
|
||||
async def _run(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
neo_call: Callable[[Any, Any], Awaitable[Any]],
|
||||
error_action: str,
|
||||
) -> ToolExecResult:
|
||||
if err := _ensure_admin(context):
|
||||
return err
|
||||
try:
|
||||
client, sandbox = await _get_neo_context(context)
|
||||
result = await neo_call(client, sandbox)
|
||||
return _to_json_text(result)
|
||||
except Exception as e:
|
||||
return f"{self.error_prefix} {error_action}: {str(e)}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetExecutionHistoryTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_get_execution_history"
|
||||
description: str = "Get execution history from current sandbox."
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"exec_type": {"type": "string"},
|
||||
"success_only": {"type": "boolean", "default": False},
|
||||
"limit": {"type": "integer", "default": 100},
|
||||
"offset": {"type": "integer", "default": 0},
|
||||
"tags": {"type": "string"},
|
||||
"has_notes": {"type": "boolean", "default": False},
|
||||
"has_description": {"type": "boolean", "default": False},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
exec_type: str | None = None,
|
||||
success_only: bool = False,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
tags: str | None = None,
|
||||
has_notes: bool = False,
|
||||
has_description: bool = False,
|
||||
) -> ToolExecResult:
|
||||
return await self._run(
|
||||
context,
|
||||
lambda _client, sandbox: sandbox.get_execution_history(
|
||||
exec_type=exec_type,
|
||||
success_only=success_only,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
tags=tags,
|
||||
has_notes=has_notes,
|
||||
has_description=has_description,
|
||||
),
|
||||
error_action="getting execution history",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnnotateExecutionTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_annotate_execution"
|
||||
description: str = "Annotate one execution history record."
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"execution_id": {"type": "string"},
|
||||
"description": {"type": "string"},
|
||||
"tags": {"type": "string"},
|
||||
"notes": {"type": "string"},
|
||||
},
|
||||
"required": ["execution_id"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
execution_id: str,
|
||||
description: str | None = None,
|
||||
tags: str | None = None,
|
||||
notes: str | None = None,
|
||||
) -> ToolExecResult:
|
||||
return await self._run(
|
||||
context,
|
||||
lambda _client, sandbox: sandbox.annotate_execution(
|
||||
execution_id=execution_id,
|
||||
description=description,
|
||||
tags=tags,
|
||||
notes=notes,
|
||||
),
|
||||
error_action="annotating execution",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CreateSkillPayloadTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_create_skill_payload"
|
||||
description: str = (
|
||||
"Step 1/3 for Neo skill authoring: create immutable payload content and return payload_ref. "
|
||||
"Use this to store skill_markdown and structured metadata; do NOT write local skill folders directly."
|
||||
)
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"payload": {
|
||||
"anyOf": [{"type": "object"}, {"type": "array"}],
|
||||
"description": (
|
||||
"Skill payload JSON. Typical schema: {skill_markdown, inputs, outputs, meta}. "
|
||||
"This only stores content and returns payload_ref; it does not create a candidate or release."
|
||||
),
|
||||
},
|
||||
"kind": {
|
||||
"type": "string",
|
||||
"description": "Payload kind.",
|
||||
"default": "astrbot_skill_v1",
|
||||
},
|
||||
},
|
||||
"required": ["payload"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
payload: dict[str, Any] | list[Any],
|
||||
kind: str = "astrbot_skill_v1",
|
||||
) -> ToolExecResult:
|
||||
return await self._run(
|
||||
context,
|
||||
lambda client, _sandbox: client.skills.create_payload(
|
||||
payload=payload,
|
||||
kind=kind,
|
||||
),
|
||||
error_action="creating skill payload",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetSkillPayloadTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_get_skill_payload"
|
||||
description: str = "Get one skill payload by payload_ref."
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"payload_ref": {"type": "string"},
|
||||
},
|
||||
"required": ["payload_ref"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
payload_ref: str,
|
||||
) -> ToolExecResult:
|
||||
return await self._run(
|
||||
context,
|
||||
lambda client, _sandbox: client.skills.get_payload(payload_ref),
|
||||
error_action="getting skill payload",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CreateSkillCandidateTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_create_skill_candidate"
|
||||
description: str = (
|
||||
"Step 2/3 for Neo skill authoring: create a candidate by binding execution evidence "
|
||||
"(source_execution_ids) with skill identity (skill_key) and optional payload_ref."
|
||||
)
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"skill_key": {
|
||||
"type": "string",
|
||||
"description": "Stable logical identifier, e.g. image-collage-9grid.",
|
||||
},
|
||||
"source_execution_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Execution evidence IDs captured from sandbox history.",
|
||||
},
|
||||
"scenario_key": {
|
||||
"type": "string",
|
||||
"description": "Optional scenario namespace for grouping candidates.",
|
||||
},
|
||||
"payload_ref": {
|
||||
"type": "string",
|
||||
"description": "Optional payload reference created by astrbot_create_skill_payload.",
|
||||
},
|
||||
},
|
||||
"required": ["skill_key", "source_execution_ids"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
skill_key: str,
|
||||
source_execution_ids: list[str],
|
||||
scenario_key: str | None = None,
|
||||
payload_ref: str | None = None,
|
||||
) -> ToolExecResult:
|
||||
return await self._run(
|
||||
context,
|
||||
lambda client, _sandbox: client.skills.create_candidate(
|
||||
skill_key=skill_key,
|
||||
source_execution_ids=source_execution_ids,
|
||||
scenario_key=scenario_key,
|
||||
payload_ref=payload_ref,
|
||||
),
|
||||
error_action="creating skill candidate",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ListSkillCandidatesTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_list_skill_candidates"
|
||||
description: str = "List skill candidates."
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"status": {"type": "string"},
|
||||
"skill_key": {"type": "string"},
|
||||
"limit": {"type": "integer", "default": 100},
|
||||
"offset": {"type": "integer", "default": 0},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
status: str | None = None,
|
||||
skill_key: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> ToolExecResult:
|
||||
return await self._run(
|
||||
context,
|
||||
lambda client, _sandbox: client.skills.list_candidates(
|
||||
status=status,
|
||||
skill_key=skill_key,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
),
|
||||
error_action="listing skill candidates",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluateSkillCandidateTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_evaluate_skill_candidate"
|
||||
description: str = "Evaluate a skill candidate."
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"candidate_id": {"type": "string"},
|
||||
"passed": {"type": "boolean"},
|
||||
"score": {"type": "number"},
|
||||
"benchmark_id": {"type": "string"},
|
||||
"report": {"type": "string"},
|
||||
},
|
||||
"required": ["candidate_id", "passed"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
candidate_id: str,
|
||||
passed: bool,
|
||||
score: float | None = None,
|
||||
benchmark_id: str | None = None,
|
||||
report: str | None = None,
|
||||
) -> ToolExecResult:
|
||||
return await self._run(
|
||||
context,
|
||||
lambda client, _sandbox: client.skills.evaluate_candidate(
|
||||
candidate_id,
|
||||
passed=passed,
|
||||
score=score,
|
||||
benchmark_id=benchmark_id,
|
||||
report=report,
|
||||
),
|
||||
error_action="evaluating skill candidate",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromoteSkillCandidateTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_promote_skill_candidate"
|
||||
description: str = (
|
||||
"Step 3/3 for Neo skill authoring: promote candidate to canary/stable release. "
|
||||
"If stage=stable and sync_to_local=true, payload.skill_markdown is synced to local SKILL.md automatically."
|
||||
)
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"candidate_id": {"type": "string"},
|
||||
"stage": {
|
||||
"type": "string",
|
||||
"description": "Release stage: canary/stable",
|
||||
"default": "canary",
|
||||
},
|
||||
"sync_to_local": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Only used with stage=stable. true means sync payload.skill_markdown to local SKILL.md; "
|
||||
"false means release remains Neo-side only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["candidate_id"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
candidate_id: str,
|
||||
stage: str = "canary",
|
||||
sync_to_local: bool = True,
|
||||
) -> ToolExecResult:
|
||||
if err := _ensure_admin(context):
|
||||
return err
|
||||
if stage not in {"canary", "stable"}:
|
||||
return "Error promoting skill candidate: stage must be canary or stable."
|
||||
|
||||
try:
|
||||
client, _sandbox = await _get_neo_context(context)
|
||||
sync_mgr = NeoSkillSyncManager()
|
||||
result = await sync_mgr.promote_with_optional_sync(
|
||||
client,
|
||||
candidate_id=candidate_id,
|
||||
stage=stage,
|
||||
sync_to_local=sync_to_local,
|
||||
)
|
||||
if result.get("sync_error"):
|
||||
rollback_json = result.get("rollback")
|
||||
if rollback_json:
|
||||
return (
|
||||
"Error promoting skill candidate: stable release synced failed; "
|
||||
f"auto rollback succeeded. sync_error={result['sync_error']}; "
|
||||
f"rollback={_to_json_text(rollback_json)}"
|
||||
)
|
||||
return _to_json_text(
|
||||
{
|
||||
"release": result.get("release"),
|
||||
"sync": result.get("sync"),
|
||||
"rollback": result.get("rollback"),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
return f"Error promoting skill candidate: {str(e)}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ListSkillReleasesTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_list_skill_releases"
|
||||
description: str = "List skill releases."
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"skill_key": {"type": "string"},
|
||||
"active_only": {"type": "boolean", "default": False},
|
||||
"stage": {"type": "string"},
|
||||
"limit": {"type": "integer", "default": 100},
|
||||
"offset": {"type": "integer", "default": 0},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
skill_key: str | None = None,
|
||||
active_only: bool = False,
|
||||
stage: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> ToolExecResult:
|
||||
return await self._run(
|
||||
context,
|
||||
lambda client, _sandbox: client.skills.list_releases(
|
||||
skill_key=skill_key,
|
||||
active_only=active_only,
|
||||
stage=stage,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
),
|
||||
error_action="listing skill releases",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RollbackSkillReleaseTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_rollback_skill_release"
|
||||
description: str = "Rollback one skill release."
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"release_id": {"type": "string"},
|
||||
},
|
||||
"required": ["release_id"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
release_id: str,
|
||||
) -> ToolExecResult:
|
||||
return await self._run(
|
||||
context,
|
||||
lambda client, _sandbox: client.skills.rollback_release(release_id),
|
||||
error_action="rolling back skill release",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SyncSkillReleaseTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_sync_skill_release"
|
||||
description: str = (
|
||||
"Sync stable Neo release payload to local SKILL.md and update mapping metadata."
|
||||
)
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"release_id": {"type": "string"},
|
||||
"skill_key": {"type": "string"},
|
||||
"require_stable": {"type": "boolean", "default": True},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
release_id: str | None = None,
|
||||
skill_key: str | None = None,
|
||||
require_stable: bool = True,
|
||||
) -> ToolExecResult:
|
||||
return await self._run(
|
||||
context,
|
||||
lambda client, _sandbox: _sync_release_to_dict(
|
||||
client,
|
||||
release_id=release_id,
|
||||
skill_key=skill_key,
|
||||
require_stable=require_stable,
|
||||
),
|
||||
error_action="syncing skill release",
|
||||
)
|
||||
|
||||
|
||||
async def _sync_release_to_dict(
|
||||
client: Any,
|
||||
*,
|
||||
release_id: str | None,
|
||||
skill_key: str | None,
|
||||
require_stable: bool,
|
||||
) -> dict[str, str]:
|
||||
sync_mgr = NeoSkillSyncManager()
|
||||
result = await sync_mgr.sync_release(
|
||||
client,
|
||||
release_id=release_id,
|
||||
skill_key=skill_key,
|
||||
require_stable=require_stable,
|
||||
)
|
||||
return sync_mgr.sync_result_to_dict(result)
|
||||
@@ -1,4 +1,3 @@
|
||||
import platform
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import mcp
|
||||
@@ -11,8 +10,6 @@ from astrbot.core.computer.computer_client import get_booter, get_local_booter
|
||||
from astrbot.core.computer.tools.permissions import check_admin_permission
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
_OS_NAME = platform.system()
|
||||
|
||||
param_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -64,7 +61,7 @@ async def handle_result(result: dict, event: AstrMessageEvent) -> ToolExecResult
|
||||
@dataclass
|
||||
class PythonTool(FunctionTool):
|
||||
name: str = "astrbot_execute_ipython"
|
||||
description: str = f"Run codes in an IPython shell. Current OS: {_OS_NAME}."
|
||||
description: str = "Run codes in an IPython shell."
|
||||
parameters: dict = field(default_factory=lambda: param_schema)
|
||||
|
||||
async def call(
|
||||
@@ -86,10 +83,7 @@ class PythonTool(FunctionTool):
|
||||
@dataclass
|
||||
class LocalPythonTool(FunctionTool):
|
||||
name: str = "astrbot_execute_python"
|
||||
description: str = (
|
||||
f"Execute codes in a Python environment. Current OS: {_OS_NAME}. "
|
||||
"Use system-compatible commands."
|
||||
)
|
||||
description: str = "Execute codes in a Python environment."
|
||||
|
||||
parameters: dict = field(default_factory=lambda: param_schema)
|
||||
|
||||
|
||||
+45
-157
@@ -132,15 +132,11 @@ DEFAULT_CONFIG = {
|
||||
"computer_use_runtime": "none",
|
||||
"computer_use_require_admin": True,
|
||||
"sandbox": {
|
||||
"booter": "shipyard_neo",
|
||||
"booter": "shipyard",
|
||||
"shipyard_endpoint": "",
|
||||
"shipyard_access_token": "",
|
||||
"shipyard_ttl": 3600,
|
||||
"shipyard_max_sessions": 10,
|
||||
"shipyard_neo_endpoint": "",
|
||||
"shipyard_neo_access_token": "",
|
||||
"shipyard_neo_profile": "python-default",
|
||||
"shipyard_neo_ttl": 3600,
|
||||
},
|
||||
},
|
||||
# SubAgent orchestrator mode:
|
||||
@@ -395,6 +391,7 @@ CONFIG_METADATA_2 = {
|
||||
"discord_token": "",
|
||||
"discord_proxy": "",
|
||||
"discord_command_register": True,
|
||||
"discord_guild_id_for_debug": "",
|
||||
"discord_activity_name": "",
|
||||
},
|
||||
"Misskey": {
|
||||
@@ -449,20 +446,6 @@ CONFIG_METADATA_2 = {
|
||||
"satori_heartbeat_interval": 10,
|
||||
"satori_reconnect_delay": 5,
|
||||
},
|
||||
"kook": {
|
||||
"id": "kook",
|
||||
"type": "kook",
|
||||
"enable": False,
|
||||
"kook_bot_token": "",
|
||||
"kook_bot_nickname": "",
|
||||
"kook_reconnect_delay": 1,
|
||||
"kook_max_reconnect_delay": 60,
|
||||
"kook_max_retry_delay": 60,
|
||||
"kook_heartbeat_interval": 30,
|
||||
"kook_heartbeat_timeout": 6,
|
||||
"kook_max_heartbeat_failures": 3,
|
||||
"kook_max_consecutive_failures": 5,
|
||||
},
|
||||
# "WebChat": {
|
||||
# "id": "webchat",
|
||||
# "type": "webchat",
|
||||
@@ -768,8 +751,7 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "可选的代理地址:http://ip:port",
|
||||
},
|
||||
"discord_command_register": {
|
||||
"description": "注册 Discord 指令",
|
||||
"hint": "启用后,自动将插件指令注册为 Discord 斜杠指令",
|
||||
"description": "是否自动将插件指令注册为 Discord 斜杠指令",
|
||||
"type": "bool",
|
||||
},
|
||||
"discord_activity_name": {
|
||||
@@ -804,51 +786,6 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "统一 Webhook 模式下的唯一标识符,创建平台时自动生成。",
|
||||
},
|
||||
"kook_bot_token": {
|
||||
"description": "机器人 Token",
|
||||
"type": "string",
|
||||
"hint": "必填项。从 KOOK 开发者平台获取的机器人 Token。",
|
||||
},
|
||||
"kook_bot_nickname": {
|
||||
"description": "Bot Nickname",
|
||||
"type": "string",
|
||||
"hint": "可选项。若发送者昵称与此值一致,将忽略该消息以避免广播风暴。",
|
||||
},
|
||||
"kook_reconnect_delay": {
|
||||
"description": "重连延迟",
|
||||
"type": "int",
|
||||
"hint": "重连延迟时间(秒),使用指数退避策略。",
|
||||
},
|
||||
"kook_max_reconnect_delay": {
|
||||
"description": "最大重连延迟",
|
||||
"type": "int",
|
||||
"hint": "重连延迟的最大值(秒)。",
|
||||
},
|
||||
"kook_max_retry_delay": {
|
||||
"description": "最大重试延迟",
|
||||
"type": "int",
|
||||
"hint": "重试的最大延迟时间(秒)。",
|
||||
},
|
||||
"kook_heartbeat_interval": {
|
||||
"description": "心跳间隔",
|
||||
"type": "int",
|
||||
"hint": "心跳检测间隔时间(秒)。",
|
||||
},
|
||||
"kook_heartbeat_timeout": {
|
||||
"description": "心跳超时时间",
|
||||
"type": "int",
|
||||
"hint": "心跳检测超时时间(秒)。",
|
||||
},
|
||||
"kook_max_heartbeat_failures": {
|
||||
"description": "最大心跳失败次数",
|
||||
"type": "int",
|
||||
"hint": "允许的最大心跳失败次数,超过后断开连接。",
|
||||
},
|
||||
"kook_max_consecutive_failures": {
|
||||
"description": "最大连续失败次数",
|
||||
"type": "int",
|
||||
"hint": "允许的最大连续失败次数,超过后停止重试。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"platform_settings": {
|
||||
@@ -2934,48 +2871,12 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.sandbox.booter": {
|
||||
"description": "沙箱环境驱动器",
|
||||
"type": "string",
|
||||
"options": ["shipyard_neo", "shipyard"],
|
||||
"labels": ["Shipyard Neo", "Shipyard"],
|
||||
"options": ["shipyard"],
|
||||
"labels": ["Shipyard"],
|
||||
"condition": {
|
||||
"provider_settings.computer_use_runtime": "sandbox",
|
||||
},
|
||||
},
|
||||
"provider_settings.sandbox.shipyard_neo_endpoint": {
|
||||
"description": "Shipyard Neo API Endpoint",
|
||||
"type": "string",
|
||||
"hint": "Shipyard Neo(Bay) 服务的 API 地址,默认 http://127.0.0.1:8114。",
|
||||
"condition": {
|
||||
"provider_settings.computer_use_runtime": "sandbox",
|
||||
"provider_settings.sandbox.booter": "shipyard_neo",
|
||||
},
|
||||
},
|
||||
"provider_settings.sandbox.shipyard_neo_access_token": {
|
||||
"description": "Shipyard Neo Access Token",
|
||||
"type": "string",
|
||||
"hint": "Bay 的 API Key(sk-bay-...)。留空时自动从 credentials.json 发现。",
|
||||
"condition": {
|
||||
"provider_settings.computer_use_runtime": "sandbox",
|
||||
"provider_settings.sandbox.booter": "shipyard_neo",
|
||||
},
|
||||
},
|
||||
"provider_settings.sandbox.shipyard_neo_profile": {
|
||||
"description": "Shipyard Neo Profile",
|
||||
"type": "string",
|
||||
"hint": "Shipyard Neo 沙箱 profile,如 python-default。",
|
||||
"condition": {
|
||||
"provider_settings.computer_use_runtime": "sandbox",
|
||||
"provider_settings.sandbox.booter": "shipyard_neo",
|
||||
},
|
||||
},
|
||||
"provider_settings.sandbox.shipyard_neo_ttl": {
|
||||
"description": "Shipyard Neo Sandbox TTL",
|
||||
"type": "int",
|
||||
"hint": "Shipyard Neo 沙箱生存时间(秒)。",
|
||||
"condition": {
|
||||
"provider_settings.computer_use_runtime": "sandbox",
|
||||
"provider_settings.sandbox.booter": "shipyard_neo",
|
||||
},
|
||||
},
|
||||
"provider_settings.sandbox.shipyard_endpoint": {
|
||||
"description": "Shipyard API Endpoint",
|
||||
"type": "string",
|
||||
@@ -3211,6 +3112,46 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.max_quoted_fallback_images": {
|
||||
"description": "引用图片回退解析上限",
|
||||
"type": "int",
|
||||
"hint": "引用/转发消息回退解析图片时的最大注入数量,超出会截断。",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.quoted_message_parser.max_component_chain_depth": {
|
||||
"description": "引用解析组件链深度",
|
||||
"type": "int",
|
||||
"hint": "解析 Reply 组件链时允许的最大递归深度。",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.quoted_message_parser.max_forward_node_depth": {
|
||||
"description": "引用解析转发节点深度",
|
||||
"type": "int",
|
||||
"hint": "解析合并转发节点时允许的最大递归深度。",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.quoted_message_parser.max_forward_fetch": {
|
||||
"description": "引用解析转发拉取上限",
|
||||
"type": "int",
|
||||
"hint": "递归拉取 get_forward_msg 的最大次数。",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.quoted_message_parser.warn_on_action_failure": {
|
||||
"description": "引用解析 action 失败告警",
|
||||
"type": "bool",
|
||||
"hint": "开启后,get_msg/get_forward_msg 全部尝试失败时输出 warning 日志。",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.max_agent_step": {
|
||||
"description": "工具调用轮数上限",
|
||||
"type": "int",
|
||||
@@ -3254,46 +3195,6 @@ CONFIG_METADATA_3 = {
|
||||
"type": "bool",
|
||||
"hint": "/provider 命令列出模型时是否并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。",
|
||||
},
|
||||
"provider_settings.max_quoted_fallback_images": {
|
||||
"description": "引用图片回退解析上限",
|
||||
"type": "int",
|
||||
"hint": "引用/转发消息回退解析图片时的最大注入数量,超出会截断。",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.quoted_message_parser.max_component_chain_depth": {
|
||||
"description": "引用解析组件链深度",
|
||||
"type": "int",
|
||||
"hint": "解析 Reply 组件链时允许的最大递归深度。",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.quoted_message_parser.max_forward_node_depth": {
|
||||
"description": "引用解析转发节点深度",
|
||||
"type": "int",
|
||||
"hint": "解析合并转发节点时允许的最大递归深度。",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.quoted_message_parser.max_forward_fetch": {
|
||||
"description": "引用解析转发拉取上限",
|
||||
"type": "int",
|
||||
"hint": "递归拉取 get_forward_msg 的最大次数。",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.quoted_message_parser.warn_on_action_failure": {
|
||||
"description": "引用解析 action 失败告警",
|
||||
"type": "bool",
|
||||
"hint": "开启后,get_msg/get_forward_msg 全部尝试失败时输出 warning 日志。",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.enable": True,
|
||||
@@ -3505,19 +3406,6 @@ CONFIG_METADATA_3 = {
|
||||
"platform_specific.telegram.pre_ack_emoji.enable": True,
|
||||
},
|
||||
},
|
||||
"platform_specific.discord.pre_ack_emoji.enable": {
|
||||
"description": "[Discord] 启用预回应表情",
|
||||
"type": "bool",
|
||||
},
|
||||
"platform_specific.discord.pre_ack_emoji.emojis": {
|
||||
"description": "表情列表(Unicode 或自定义表情名)",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "填写 Unicode 表情符号,例如:👍、🤔、⏳",
|
||||
"condition": {
|
||||
"platform_specific.discord.pre_ack_emoji.enable": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -175,10 +175,6 @@ class LogManager:
|
||||
_trace_sink_id: int | None = None
|
||||
_NOISY_LOGGER_LEVELS: dict[str, int] = {
|
||||
"aiosqlite": logging.WARNING,
|
||||
"filelock": logging.WARNING,
|
||||
"asyncio": logging.WARNING,
|
||||
"tzlocal": logging.WARNING,
|
||||
"apscheduler": logging.WARNING,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -27,7 +27,7 @@ class PreProcessStage(Stage):
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
"""在处理事件之前的预处理"""
|
||||
# 平台特异配置:platform_specific.<platform>.pre_ack_emoji
|
||||
supported = {"telegram", "lark", "discord"}
|
||||
supported = {"telegram", "lark"}
|
||||
platform = event.get_platform_name()
|
||||
cfg = (
|
||||
self.config.get("platform_specific", {})
|
||||
|
||||
@@ -180,10 +180,6 @@ class PlatformManager:
|
||||
from .sources.line.line_adapter import (
|
||||
LinePlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "kook":
|
||||
from .sources.kook.kook_adapter import (
|
||||
KookPlatformAdapter, # noqa: F401
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.error(
|
||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
|
||||
|
||||
@@ -1,371 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import At, AtAll, Image, Plain
|
||||
from astrbot.api.platform import (
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
MessageType,
|
||||
Platform,
|
||||
PlatformMetadata,
|
||||
register_platform_adapter,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
|
||||
from .kook_client import KookClient
|
||||
from .kook_config import KookConfig
|
||||
from .kook_event import KookEvent
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
"kook",
|
||||
"KOOK 适配器",
|
||||
)
|
||||
class KookPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.kook_config = KookConfig.from_dict(platform_config)
|
||||
logger.debug(f"[KOOK] 配置: {self.kook_config.pretty_jsons()}")
|
||||
self.settings = platform_settings
|
||||
self.client = KookClient(self.kook_config, self._on_received)
|
||||
self._reconnect_task = None
|
||||
self.running = False
|
||||
self._main_task = None
|
||||
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
inner_message = AstrBotMessage()
|
||||
inner_message.session_id = session.session_id
|
||||
inner_message.type = session.message_type
|
||||
message_event = KookEvent(
|
||||
message_str=message_chain.get_plain_text(),
|
||||
message_obj=inner_message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=session.session_id,
|
||||
client=self.client,
|
||||
)
|
||||
await message_event.send(message_chain)
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
name="kook", description="KOOK 适配器", id=self.kook_config.id
|
||||
)
|
||||
|
||||
def _should_ignore_event_by_bot_nickname(self, payload: dict) -> bool:
|
||||
bot_nickname = self.kook_config.bot_nickname.strip()
|
||||
if not bot_nickname:
|
||||
return False
|
||||
|
||||
author = payload.get("extra", {}).get("author", {})
|
||||
if not isinstance(author, dict):
|
||||
return False
|
||||
|
||||
author_nickname = author.get("nickname") or author.get("username") or ""
|
||||
if not isinstance(author_nickname, str):
|
||||
author_nickname = str(author_nickname)
|
||||
|
||||
return author_nickname.strip().casefold() == bot_nickname.casefold()
|
||||
|
||||
async def _on_received(self, data: dict):
|
||||
logger.debug(f"KOOK 收到数据: {data}")
|
||||
if "d" in data and data["s"] == 0:
|
||||
payload = data["d"]
|
||||
event_type = payload.get("type")
|
||||
# 支持type=9(文本)和type=10(卡片)
|
||||
if event_type in (9, 10):
|
||||
if self._should_ignore_event_by_bot_nickname(payload):
|
||||
return
|
||||
try:
|
||||
abm = await self.convert_message(payload)
|
||||
await self.handle_msg(abm)
|
||||
except Exception as e:
|
||||
logger.error(f"[KOOK] 消息处理异常: {e}")
|
||||
|
||||
async def run(self):
|
||||
"""主运行循环"""
|
||||
self.running = True
|
||||
logger.info("[KOOK] 启动KOOK适配器")
|
||||
|
||||
# 启动主循环
|
||||
self._main_task = asyncio.create_task(self._main_loop())
|
||||
|
||||
try:
|
||||
await self._main_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info("[KOOK] 适配器被取消")
|
||||
except Exception as e:
|
||||
logger.error(f"[KOOK] 适配器运行异常: {e}")
|
||||
finally:
|
||||
self.running = False
|
||||
await self._cleanup()
|
||||
|
||||
async def _main_loop(self):
|
||||
"""主循环,处理连接和重连"""
|
||||
consecutive_failures = 0
|
||||
max_consecutive_failures = self.kook_config.max_consecutive_failures
|
||||
max_retry_delay = self.kook_config.max_retry_delay
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
logger.info("[KOOK] 尝试连接KOOK服务器...")
|
||||
|
||||
# 尝试连接
|
||||
success = await self.client.connect()
|
||||
|
||||
if success:
|
||||
logger.info("[KOOK] 连接成功,开始监听消息")
|
||||
consecutive_failures = 0 # 重置失败计数
|
||||
|
||||
# 等待连接结束(可能是正常关闭或异常)
|
||||
while self.client.running and self.running:
|
||||
try:
|
||||
# 等待 client 内部触发 _stop_event,或者超时 1 秒后重试
|
||||
# 使用 wait_for 配合 timeout 是为了防止极端情况下 self.running 变化没被察觉
|
||||
await asyncio.wait_for(
|
||||
self.client.wait_until_closed(), timeout=1.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# 正常超时,继续下一轮 while 检查
|
||||
continue
|
||||
|
||||
if self.running:
|
||||
logger.warning("[KOOK] 连接断开,准备重连")
|
||||
|
||||
else:
|
||||
consecutive_failures += 1
|
||||
logger.error(
|
||||
f"[KOOK] 连接失败,连续失败次数: {consecutive_failures}"
|
||||
)
|
||||
|
||||
if consecutive_failures >= max_consecutive_failures:
|
||||
logger.error("[KOOK] 连续失败次数过多,停止重连")
|
||||
break
|
||||
|
||||
# 等待一段时间后重试
|
||||
wait_time = min(
|
||||
2**consecutive_failures, max_retry_delay
|
||||
) # 指数退避
|
||||
logger.info(f"[KOOK] 等待 {wait_time} 秒后重试...")
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
except Exception as e:
|
||||
consecutive_failures += 1
|
||||
logger.error(f"[KOOK] 主循环异常: {e}")
|
||||
|
||||
if consecutive_failures >= max_consecutive_failures:
|
||||
logger.error("[KOOK] 连续异常次数过多,停止重连")
|
||||
break
|
||||
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def _cleanup(self):
|
||||
"""清理资源"""
|
||||
logger.info("[KOOK] 开始清理资源")
|
||||
|
||||
if self.client:
|
||||
try:
|
||||
await self.client.close()
|
||||
except Exception as e:
|
||||
logger.error(f"[KOOK] 关闭客户端异常: {e}")
|
||||
|
||||
if self._main_task and not self._main_task.done():
|
||||
self._main_task.cancel()
|
||||
try:
|
||||
await self._main_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("[KOOK] 资源清理完成")
|
||||
|
||||
def _parse_kmarkdown_text_message(
|
||||
self, data: dict, self_id: str
|
||||
) -> tuple[list, str]:
|
||||
kmarkdown = data.get("extra", {}).get("kmarkdown", {})
|
||||
content = data.get("content") or ""
|
||||
raw_content = kmarkdown.get("raw_content") or content
|
||||
if not isinstance(content, str):
|
||||
content = str(content)
|
||||
if not isinstance(raw_content, str):
|
||||
raw_content = str(raw_content)
|
||||
|
||||
mention_name_map: dict[str, str] = {}
|
||||
mention_part = kmarkdown.get("mention_part", [])
|
||||
if isinstance(mention_part, list):
|
||||
for item in mention_part:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
mention_id = item.get("id")
|
||||
if mention_id is None:
|
||||
continue
|
||||
mention_name_map[str(mention_id)] = str(item.get("username", ""))
|
||||
|
||||
components = []
|
||||
cursor = 0
|
||||
for match in re.finditer(r"\(met\)([^()]+)\(met\)", content):
|
||||
if match.start() > cursor:
|
||||
plain_text = content[cursor : match.start()]
|
||||
if plain_text:
|
||||
components.append(Plain(text=plain_text))
|
||||
|
||||
mention_target = match.group(1).strip()
|
||||
if mention_target == "all":
|
||||
components.append(AtAll())
|
||||
elif mention_target:
|
||||
components.append(
|
||||
At(
|
||||
qq=mention_target,
|
||||
name=mention_name_map.get(mention_target, ""),
|
||||
)
|
||||
)
|
||||
cursor = match.end()
|
||||
|
||||
if cursor < len(content):
|
||||
tail_text = content[cursor:]
|
||||
if tail_text:
|
||||
components.append(Plain(text=tail_text))
|
||||
|
||||
message_str = raw_content
|
||||
if components:
|
||||
for comp in components:
|
||||
if isinstance(comp, Plain):
|
||||
if not comp.text.strip():
|
||||
continue
|
||||
break
|
||||
if isinstance(comp, At):
|
||||
if str(comp.qq) == str(self_id):
|
||||
message_str = re.sub(
|
||||
r"^@[^\s]+(\s*-\s*[^\s]+)?\s*",
|
||||
"",
|
||||
message_str,
|
||||
count=1,
|
||||
).strip()
|
||||
break
|
||||
if not components:
|
||||
if message_str:
|
||||
components = [Plain(text=message_str)]
|
||||
else:
|
||||
components = []
|
||||
|
||||
return components, message_str
|
||||
|
||||
def _parse_card_message(self, data: dict) -> tuple[list, str]:
|
||||
content = data.get("content", "[]")
|
||||
if not isinstance(content, str):
|
||||
content = str(content)
|
||||
card_list = json.loads(content)
|
||||
|
||||
text_parts: list[str] = []
|
||||
images: list[str] = []
|
||||
|
||||
for card in card_list:
|
||||
if not isinstance(card, dict):
|
||||
continue
|
||||
for module in card.get("modules", []):
|
||||
if not isinstance(module, dict):
|
||||
continue
|
||||
|
||||
module_type = module.get("type")
|
||||
if module_type == "section":
|
||||
section_text = module.get("text", {}).get("content", "")
|
||||
if section_text:
|
||||
text_parts.append(str(section_text))
|
||||
continue
|
||||
|
||||
if module_type != "container":
|
||||
continue
|
||||
|
||||
for element in module.get("elements", []):
|
||||
if not isinstance(element, dict):
|
||||
continue
|
||||
if element.get("type") != "image":
|
||||
continue
|
||||
|
||||
image_src = element.get("src")
|
||||
if not isinstance(image_src, str):
|
||||
logger.warning(
|
||||
f'[KOOK] 处理卡片中的图片时发生错误,图片url "{image_src}" 应该为str类型, 而不是 "{type(image_src)}" '
|
||||
)
|
||||
continue
|
||||
if not image_src.startswith(("http://", "https://")):
|
||||
logger.warning(f"[KOOK] 屏蔽非http图片url: {image_src}")
|
||||
continue
|
||||
images.append(image_src)
|
||||
|
||||
text = "".join(text_parts)
|
||||
message = []
|
||||
if text:
|
||||
message.append(Plain(text=text))
|
||||
for img_url in images:
|
||||
message.append(Image(file=img_url))
|
||||
return message, text
|
||||
|
||||
async def convert_message(self, data: dict) -> AstrBotMessage:
|
||||
abm = AstrBotMessage()
|
||||
abm.raw_message = data
|
||||
abm.self_id = self.client.bot_id
|
||||
|
||||
channel_type = data.get("channel_type")
|
||||
author_id = data.get("author_id", "unknown")
|
||||
# channel_type定义: https://developer.kookapp.cn/doc/event/event-introduction
|
||||
match channel_type:
|
||||
case "GROUP":
|
||||
session_id = data.get("target_id") or "unknown"
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = session_id
|
||||
abm.session_id = session_id
|
||||
case "PERSON":
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.group_id = ""
|
||||
abm.session_id = data.get("author_id", "unknown")
|
||||
case "BROADCAST":
|
||||
session_id = data.get("target_id") or "unknown"
|
||||
abm.type = MessageType.OTHER_MESSAGE
|
||||
abm.group_id = session_id
|
||||
abm.session_id = session_id
|
||||
case _:
|
||||
raise ValueError(f"不支持的频道类型: {channel_type}")
|
||||
|
||||
abm.sender = MessageMember(
|
||||
user_id=author_id,
|
||||
nickname=data.get("extra", {}).get("author", {}).get("username", ""),
|
||||
)
|
||||
|
||||
abm.message_id = data.get("msg_id", "unknown")
|
||||
|
||||
# 普通文本消息
|
||||
if data.get("type") == 9:
|
||||
message, message_str = self._parse_kmarkdown_text_message(
|
||||
data, str(abm.self_id)
|
||||
)
|
||||
abm.message = message
|
||||
abm.message_str = message_str
|
||||
# 卡片消息
|
||||
elif data.get("type") == 10:
|
||||
try:
|
||||
abm.message, abm.message_str = self._parse_card_message(data)
|
||||
except Exception as exp:
|
||||
logger.error(f"[KOOK] 卡片消息解析失败: {exp}")
|
||||
abm.message_str = "[卡片消息解析失败]"
|
||||
abm.message = [Plain(text="[卡片消息解析失败]")]
|
||||
else:
|
||||
logger.warning(f'[KOOK] 不支持的kook消息类型: "{data.get("type")}"')
|
||||
abm.message_str = "[不支持的消息类型]"
|
||||
abm.message = [Plain(text="[不支持的消息类型]")]
|
||||
|
||||
return abm
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
message_event = KookEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.client,
|
||||
)
|
||||
self.commit_event(message_event)
|
||||
@@ -1,437 +0,0 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import zlib
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
import aiohttp
|
||||
import websockets
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
|
||||
from .kook_config import KookConfig
|
||||
from .kook_types import KookApiPaths, KookMessageType
|
||||
|
||||
|
||||
class KookClient:
|
||||
def __init__(self, config: KookConfig, event_callback):
|
||||
# 数据字段
|
||||
self.config = config
|
||||
self._bot_id = ""
|
||||
self._bot_name = ""
|
||||
|
||||
# 资源字段
|
||||
self._http_client = aiohttp.ClientSession(
|
||||
headers={
|
||||
"Authorization": f"Bot {self.config.token}",
|
||||
}
|
||||
)
|
||||
self.event_callback = event_callback # 回调函数,用于处理接收到的事件
|
||||
self.ws = None
|
||||
self.heartbeat_task = None
|
||||
self._stop_event = asyncio.Event() # 用于通知连接结束
|
||||
|
||||
# 状态/计算字段
|
||||
self.running = False
|
||||
self.session_id = None
|
||||
self.last_sn = 0 # 记录最后处理的消息序号
|
||||
self.last_heartbeat_time = 0
|
||||
self.heartbeat_failed_count = 0
|
||||
|
||||
@property
|
||||
def bot_id(self):
|
||||
return self._bot_id
|
||||
|
||||
@property
|
||||
def bot_name(self):
|
||||
return self._bot_name
|
||||
|
||||
async def get_bot_info(self) -> str:
|
||||
"""获取机器人账号ID"""
|
||||
url = KookApiPaths.USER_ME
|
||||
|
||||
try:
|
||||
async with self._http_client.get(url) as resp:
|
||||
if resp.status != 200:
|
||||
logger.error(f"[KOOK] 获取机器人账号ID失败,状态码: {resp.status}")
|
||||
return ""
|
||||
|
||||
data = await resp.json()
|
||||
if data.get("code") != 0:
|
||||
logger.error(f"[KOOK] 获取机器人账号ID失败: {data}")
|
||||
return ""
|
||||
|
||||
bot_id: str = data["data"]["id"]
|
||||
self._bot_id = bot_id
|
||||
logger.info(f"[KOOK] 获取机器人账号ID成功: {bot_id}")
|
||||
bot_name: str = data["data"]["nickname"] or data["data"]["username"]
|
||||
self._bot_name = bot_name
|
||||
logger.info(f"[KOOK] 获取机器人名称成功: {self._bot_name}")
|
||||
|
||||
return bot_id
|
||||
except Exception as e:
|
||||
logger.error(f"[KOOK] 获取机器人账号ID异常: {e}")
|
||||
return ""
|
||||
|
||||
async def get_gateway_url(self, resume=False, sn=0, session_id=None):
|
||||
"""获取网关连接地址"""
|
||||
url = KookApiPaths.GATEWAY_INDEX
|
||||
|
||||
# 构建连接参数
|
||||
params = {}
|
||||
if resume:
|
||||
params["resume"] = 1
|
||||
params["sn"] = sn
|
||||
if session_id:
|
||||
params["session_id"] = session_id
|
||||
|
||||
try:
|
||||
async with self._http_client.get(url, params=params) as resp:
|
||||
if resp.status != 200:
|
||||
logger.error(f"[KOOK] 获取gateway失败,状态码: {resp.status}")
|
||||
return None
|
||||
|
||||
data = await resp.json()
|
||||
if data.get("code") != 0:
|
||||
logger.error(f"[KOOK] 获取gateway失败: {data}")
|
||||
return None
|
||||
|
||||
gateway_url: str = data["data"]["url"]
|
||||
logger.info(f"[KOOK] 获取gateway成功: {gateway_url.split('?')[0]}")
|
||||
return gateway_url
|
||||
except Exception as e:
|
||||
logger.error(f"[KOOK] 获取gateway异常: {e}")
|
||||
return None
|
||||
|
||||
async def connect(self, resume=False):
|
||||
"""连接WebSocket"""
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.ws = None
|
||||
self._stop_event.clear()
|
||||
try:
|
||||
# 获取gateway地址
|
||||
gateway_url = await self.get_gateway_url(
|
||||
resume=resume, sn=self.last_sn, session_id=self.session_id
|
||||
)
|
||||
await self.get_bot_info()
|
||||
|
||||
if not gateway_url:
|
||||
return False
|
||||
|
||||
# 连接WebSocket
|
||||
self.ws = await websockets.connect(gateway_url)
|
||||
self.running = True
|
||||
logger.info("[KOOK] WebSocket 连接成功")
|
||||
|
||||
# 启动心跳任务
|
||||
if self.heartbeat_task:
|
||||
self.heartbeat_task.cancel()
|
||||
self.heartbeat_task = asyncio.create_task(self._heartbeat_loop())
|
||||
|
||||
# 开始监听消息
|
||||
await self.listen()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[KOOK] WebSocket 连接失败: {e}")
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.ws = None
|
||||
return False
|
||||
|
||||
async def listen(self):
|
||||
"""监听WebSocket消息"""
|
||||
try:
|
||||
while self.running:
|
||||
try:
|
||||
msg = await asyncio.wait_for(self.ws.recv(), timeout=10) # type: ignore
|
||||
|
||||
if isinstance(msg, bytes):
|
||||
try:
|
||||
msg = zlib.decompress(msg)
|
||||
except Exception as e:
|
||||
logger.error(f"[KOOK] 解压消息失败: {e}")
|
||||
continue
|
||||
msg = msg.decode("utf-8")
|
||||
|
||||
data = json.loads(msg)
|
||||
|
||||
# 处理不同类型的信令
|
||||
await self._handle_signal(data)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时检查,继续循环
|
||||
continue
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.warning("[KOOK] WebSocket连接已关闭")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"[KOOK] 消息处理异常: {e}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[KOOK] WebSocket 监听异常: {e}")
|
||||
finally:
|
||||
self.running = False
|
||||
self._stop_event.set()
|
||||
|
||||
async def _handle_signal(self, data):
|
||||
"""处理不同类型的信令"""
|
||||
signal_type = data.get("s")
|
||||
|
||||
if signal_type == 0: # 事件消息
|
||||
# 更新消息序号
|
||||
if "sn" in data:
|
||||
self.last_sn = data["sn"]
|
||||
await self.event_callback(data)
|
||||
|
||||
elif signal_type == 1: # HELLO握手
|
||||
await self._handle_hello(data)
|
||||
|
||||
elif signal_type == 3: # PONG心跳响应
|
||||
await self._handle_pong(data)
|
||||
|
||||
elif signal_type == 5: # RECONNECT重连指令
|
||||
await self._handle_reconnect(data)
|
||||
|
||||
elif signal_type == 6: # RESUME ACK
|
||||
await self._handle_resume_ack(data)
|
||||
|
||||
else:
|
||||
logger.debug(f"[KOOK] 未处理的信令类型: {signal_type}")
|
||||
|
||||
async def _handle_hello(self, data):
|
||||
"""处理HELLO握手"""
|
||||
hello_data = data.get("d", {})
|
||||
code = hello_data.get("code", 0)
|
||||
|
||||
if code == 0:
|
||||
self.session_id = hello_data.get("session_id")
|
||||
logger.info(f"[KOOK] 握手成功,session_id: {self.session_id}")
|
||||
# TODO 重置重连延迟
|
||||
# self.reconnect_delay = 1
|
||||
else:
|
||||
logger.error(f"[KOOK] 握手失败,错误码: {code}")
|
||||
if code == 40103: # token过期
|
||||
logger.error("[KOOK] Token已过期,需要重新获取")
|
||||
self.running = False
|
||||
|
||||
async def _handle_pong(self, data):
|
||||
"""处理PONG心跳响应"""
|
||||
self.last_heartbeat_time = time.time()
|
||||
self.heartbeat_failed_count = 0
|
||||
|
||||
async def _handle_reconnect(self, data):
|
||||
"""处理重连指令"""
|
||||
logger.warning("[KOOK] 收到重连指令")
|
||||
# 清空本地状态
|
||||
self.last_sn = 0
|
||||
self.session_id = None
|
||||
self.running = False
|
||||
|
||||
async def _handle_resume_ack(self, data):
|
||||
"""处理RESUME确认"""
|
||||
resume_data = data.get("d", {})
|
||||
self.session_id = resume_data.get("session_id")
|
||||
logger.info(f"[KOOK] Resume成功,session_id: {self.session_id}")
|
||||
|
||||
async def _heartbeat_loop(self):
|
||||
"""心跳循环"""
|
||||
while self.running:
|
||||
try:
|
||||
# 随机化心跳间隔 (±5秒)
|
||||
interval = max(
|
||||
1, self.config.heartbeat_interval + random.randint(-5, 5)
|
||||
)
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
if not self.running:
|
||||
break
|
||||
|
||||
# 发送心跳
|
||||
await self._send_ping()
|
||||
|
||||
# 等待PONG响应
|
||||
await asyncio.sleep(self.config.heartbeat_timeout)
|
||||
|
||||
# 检查是否收到PONG响应
|
||||
if (
|
||||
time.time() - self.last_heartbeat_time
|
||||
> self.config.heartbeat_timeout
|
||||
):
|
||||
self.heartbeat_failed_count += 1
|
||||
logger.warning(
|
||||
f"[KOOK] 心跳超时,失败次数: {self.heartbeat_failed_count}"
|
||||
)
|
||||
|
||||
if (
|
||||
self.heartbeat_failed_count
|
||||
>= self.config.max_heartbeat_failures
|
||||
):
|
||||
logger.error("[KOOK] 心跳失败次数过多,准备重连")
|
||||
self.running = False
|
||||
break
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"[KOOK] 心跳异常: {e}")
|
||||
self.heartbeat_failed_count += 1
|
||||
|
||||
async def _send_ping(self):
|
||||
"""发送心跳PING"""
|
||||
try:
|
||||
ping_data = {"s": 2, "sn": self.last_sn}
|
||||
await self.ws.send(json.dumps(ping_data)) # type: ignore
|
||||
except Exception as e:
|
||||
logger.error(f"[KOOK] 发送心跳失败: {e}")
|
||||
|
||||
async def send_text(
|
||||
self,
|
||||
target_id: str,
|
||||
content: str,
|
||||
astrbot_message_type: MessageType,
|
||||
kook_message_type: KookMessageType,
|
||||
reply_message_id: str | int = "",
|
||||
):
|
||||
"""发送文本消息
|
||||
消息发送接口文档参见: https://developer.kookapp.cn/doc/http/message#%E5%8F%91%E9%80%81%E9%A2%91%E9%81%93%E8%81%8A%E5%A4%A9%E6%B6%88%E6%81%AF
|
||||
KMarkdown格式参见: https://developer.kookapp.cn/doc/kmarkdown-desc
|
||||
"""
|
||||
url = KookApiPaths.CHANNEL_MESSAGE_CREATE
|
||||
if astrbot_message_type == MessageType.FRIEND_MESSAGE:
|
||||
url = KookApiPaths.DIRECT_MESSAGE_CREATE
|
||||
|
||||
payload = {
|
||||
"target_id": target_id,
|
||||
"content": content,
|
||||
"type": kook_message_type,
|
||||
}
|
||||
if reply_message_id:
|
||||
payload["quote"] = reply_message_id
|
||||
payload["reply_msg_id"] = reply_message_id
|
||||
|
||||
try:
|
||||
async with self._http_client.post(url, json=payload) as resp:
|
||||
if resp.status == 200:
|
||||
result = await resp.json()
|
||||
if result.get("code") != 0:
|
||||
raise RuntimeError(
|
||||
f'发送kook消息类型 "{kook_message_type.name}" 失败: {result}'
|
||||
)
|
||||
# else:
|
||||
# logger.info("[KOOK] 发送消息成功")
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f'发送kook消息类型 "{kook_message_type.name}" HTTP错误: {resp.status} , 响应内容 : {await resp.text()}'
|
||||
)
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[KOOK] 发送kook消息类型 "{kook_message_type.name}" 异常: {e}'
|
||||
)
|
||||
|
||||
async def upload_asset(self, file_url: str | None) -> str:
|
||||
"""上传文件到kook,获得远端资源url
|
||||
接口定义参见: https://developer.kookapp.cn/doc/http/asset
|
||||
"""
|
||||
if not file_url:
|
||||
return ""
|
||||
|
||||
bytes_data: bytes | None = None
|
||||
filename = "unknown"
|
||||
if file_url.startswith(("http://", "https://")):
|
||||
filename = file_url.split("/")[-1]
|
||||
return file_url
|
||||
|
||||
if file_url.startswith("base64:///"):
|
||||
# b64decode的时候得开头留一个'/'的, 不然会报错
|
||||
b64_str = file_url.removeprefix("base64://")
|
||||
bytes_data = base64.b64decode(b64_str)
|
||||
|
||||
elif file_url.startswith("file://") or os.path.exists(file_url):
|
||||
file_url = file_url.removeprefix("file:///")
|
||||
file_url = file_url.removeprefix("file://")
|
||||
|
||||
try:
|
||||
target_path = Path(file_url).resolve()
|
||||
except Exception as exp:
|
||||
logger.error(f'[KOOK] 获取文件 "{file_url}" 绝对路径失败: "{exp}"')
|
||||
raise FileNotFoundError(
|
||||
f'获取文件 "{file_url}" 绝对路径失败: "{exp}"'
|
||||
) from exp
|
||||
|
||||
if not target_path.is_file():
|
||||
raise FileNotFoundError(f"文件不存在: {target_path.name}")
|
||||
|
||||
filename = target_path.name
|
||||
async with aiofiles.open(target_path, "rb") as f:
|
||||
bytes_data = await f.read()
|
||||
|
||||
else:
|
||||
raise ValueError(f'[KOOK] 不支持的文件资源类型: "{file_url}"')
|
||||
|
||||
data = aiohttp.FormData()
|
||||
data.add_field("file", bytes_data, filename=filename)
|
||||
|
||||
url = KookApiPaths.ASSET_CREATE
|
||||
try:
|
||||
async with self._http_client.post(url, data=data) as resp:
|
||||
if resp.status == 200:
|
||||
result: dict = await resp.json()
|
||||
logger.debug(f"[KOOK] 上传文件响应: {result}")
|
||||
if result.get("code") == 0:
|
||||
logger.info("[KOOK] 上传文件到kook服务器成功")
|
||||
remote_url = result["data"]["url"]
|
||||
logger.debug(f"[KOOK] 文件远端URL: {remote_url}")
|
||||
return remote_url
|
||||
else:
|
||||
raise RuntimeError(f"上传文件到kook服务器失败: {result}")
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"上传文件到kook服务器 HTTP错误: {resp.status} , {await resp.text()}"
|
||||
)
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"上传文件到kook服务器异常: {e}") from e
|
||||
|
||||
async def wait_until_closed(self):
|
||||
"""提供给外部调用的等待方法"""
|
||||
await self._stop_event.wait()
|
||||
|
||||
async def close(self):
|
||||
"""关闭连接"""
|
||||
self.running = False
|
||||
self._stop_event.set()
|
||||
|
||||
if self.heartbeat_task:
|
||||
self.heartbeat_task.cancel()
|
||||
try:
|
||||
await self.heartbeat_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.close()
|
||||
except Exception as e:
|
||||
logger.error(f"[KOOK] 关闭WebSocket异常: {e}")
|
||||
|
||||
if self._http_client:
|
||||
await self._http_client.close()
|
||||
|
||||
logger.info("[KOOK] 连接已关闭")
|
||||
@@ -1,133 +0,0 @@
|
||||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class KookConfig:
|
||||
"""KOOK 适配器配置类"""
|
||||
|
||||
# 基础配置
|
||||
token: str
|
||||
bot_nickname: str = ""
|
||||
enable: bool = False
|
||||
id: str = "kook"
|
||||
|
||||
# 重连配置
|
||||
reconnect_delay: int = 1
|
||||
"""重连延迟基数(秒),指数退避"""
|
||||
max_reconnect_delay: int = 60
|
||||
"""最大重连延迟(秒)"""
|
||||
max_retry_delay: int = 60
|
||||
"""最大重试延迟(秒)"""
|
||||
|
||||
# 心跳配置
|
||||
heartbeat_interval: int = 30
|
||||
"""心跳间隔(秒)"""
|
||||
heartbeat_timeout: int = 6
|
||||
"""心跳超时时间(秒)"""
|
||||
max_heartbeat_failures: int = 3
|
||||
"""最大心跳失败次数"""
|
||||
|
||||
# 失败处理
|
||||
max_consecutive_failures: int = 5
|
||||
"""最大连续失败次数"""
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict) -> "KookConfig":
|
||||
"""从字典创建配置对象"""
|
||||
return cls(
|
||||
# 适配器id 应该是不能改的
|
||||
# id=config_dict.get("id", "kook"),
|
||||
enable=config_dict.get("enable", False),
|
||||
token=config_dict.get("kook_bot_token", ""),
|
||||
bot_nickname=config_dict.get("kook_bot_nickname", ""),
|
||||
reconnect_delay=config_dict.get(
|
||||
"kook_reconnect_delay",
|
||||
KookConfig.reconnect_delay,
|
||||
),
|
||||
max_reconnect_delay=config_dict.get(
|
||||
"kook_max_reconnect_delay",
|
||||
KookConfig.max_reconnect_delay,
|
||||
),
|
||||
max_retry_delay=config_dict.get(
|
||||
"kook_max_retry_delay",
|
||||
KookConfig.max_retry_delay,
|
||||
),
|
||||
heartbeat_interval=config_dict.get(
|
||||
"kook_heartbeat_interval",
|
||||
KookConfig.heartbeat_interval,
|
||||
),
|
||||
heartbeat_timeout=config_dict.get(
|
||||
"kook_heartbeat_timeout",
|
||||
KookConfig.heartbeat_timeout,
|
||||
),
|
||||
max_heartbeat_failures=config_dict.get(
|
||||
"kook_max_heartbeat_failures",
|
||||
KookConfig.max_heartbeat_failures,
|
||||
),
|
||||
max_consecutive_failures=config_dict.get(
|
||||
"kook_max_consecutive_failures",
|
||||
KookConfig.max_consecutive_failures,
|
||||
),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
def pretty_jsons(self, indent=2) -> str:
|
||||
dict_config = self.to_dict()
|
||||
dict_config["token"] = "*" * len(self.token) if self.token else "MISSING"
|
||||
return json.dumps(dict_config, indent=indent, ensure_ascii=False)
|
||||
|
||||
|
||||
# TODO 没用上的config配置,未来有空会实现这些配置描述的功能?
|
||||
# # 连接配置
|
||||
# CONNECTION_CONFIG = {
|
||||
# # 心跳配置
|
||||
# "heartbeat_interval": 30, # 心跳间隔(秒)
|
||||
# "heartbeat_timeout": 6, # 心跳超时时间(秒)
|
||||
# "max_heartbeat_failures": 3, # 最大心跳失败次数
|
||||
# # 重连配置
|
||||
# "initial_reconnect_delay": 1, # 初始重连延迟(秒)
|
||||
# "max_reconnect_delay": 60, # 最大重连延迟(秒)
|
||||
# "max_consecutive_failures": 5, # 最大连续失败次数
|
||||
# # WebSocket配置
|
||||
# "websocket_timeout": 10, # WebSocket接收超时(秒)
|
||||
# "connection_timeout": 30, # 连接超时(秒)
|
||||
# # 消息处理配置
|
||||
# "enable_compression": True, # 是否启用消息压缩
|
||||
# "max_message_size": 1024 * 1024, # 最大消息大小(字节)
|
||||
# }
|
||||
|
||||
# # 日志配置
|
||||
# LOGGING_CONFIG = {
|
||||
# "level": "INFO", # 日志级别:DEBUG, INFO, WARNING, ERROR
|
||||
# "format": "[KOOK] %(message)s",
|
||||
# "enable_heartbeat_logs": False, # 是否启用心跳日志
|
||||
# "enable_message_logs": False, # 是否启用消息日志
|
||||
# }
|
||||
|
||||
# # 错误处理配置
|
||||
# ERROR_HANDLING_CONFIG = {
|
||||
# "retry_on_network_error": True, # 网络错误时是否重试
|
||||
# "retry_on_token_expired": True, # Token过期时是否重试
|
||||
# "max_retry_attempts": 3, # 最大重试次数
|
||||
# "retry_delay_base": 2, # 重试延迟基数(秒)
|
||||
# }
|
||||
|
||||
# # 性能配置
|
||||
# PERFORMANCE_CONFIG = {
|
||||
# "enable_message_buffering": True, # 是否启用消息缓冲
|
||||
# "buffer_size": 100, # 缓冲区大小
|
||||
# "enable_connection_pooling": True, # 是否启用连接池
|
||||
# "max_concurrent_requests": 10, # 最大并发请求数
|
||||
# }
|
||||
|
||||
# # 安全配置
|
||||
# SECURITY_CONFIG = {
|
||||
# "verify_ssl": True, # 是否验证SSL证书
|
||||
# "enable_rate_limiting": True, # 是否启用速率限制
|
||||
# "rate_limit_requests": 100, # 速率限制请求数
|
||||
# "rate_limit_window": 60, # 速率限制窗口(秒)
|
||||
# }
|
||||
@@ -1,209 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Coroutine
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.core.message.components import (
|
||||
At,
|
||||
AtAll,
|
||||
BaseMessageComponent,
|
||||
File,
|
||||
Image,
|
||||
Json,
|
||||
Plain,
|
||||
Record,
|
||||
Reply,
|
||||
Video,
|
||||
)
|
||||
from astrbot.core.platform import MessageType
|
||||
|
||||
from .kook_client import KookClient
|
||||
from .kook_types import (
|
||||
FileModule,
|
||||
KookCardMessage,
|
||||
KookCardMessageContainer,
|
||||
KookMessageType,
|
||||
OrderMessage,
|
||||
)
|
||||
|
||||
|
||||
class KookEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client: KookClient,
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
self.channel_id = message_obj.group_id or message_obj.session_id
|
||||
self.astrbot_message_type: MessageType = message_obj.type
|
||||
self._file_message_counter = 0
|
||||
|
||||
def _wrap_message(
|
||||
self, index: int, message_component: BaseMessageComponent
|
||||
) -> Coroutine[Any, Any, OrderMessage]:
|
||||
async def wrap_upload(
|
||||
index: int, message_type: KookMessageType, upload_coro
|
||||
) -> OrderMessage:
|
||||
url = await upload_coro
|
||||
return OrderMessage(index=index, text=url, type=message_type)
|
||||
|
||||
async def handle_plain(
|
||||
index: int,
|
||||
text: str | None,
|
||||
reply_id: str | int = "",
|
||||
type: KookMessageType = KookMessageType.KMARKDOWN,
|
||||
):
|
||||
if not text:
|
||||
text = ""
|
||||
return OrderMessage(
|
||||
index=index,
|
||||
text=text,
|
||||
type=type,
|
||||
reply_id=reply_id,
|
||||
)
|
||||
|
||||
match message_component:
|
||||
case Image():
|
||||
self._file_message_counter += 1
|
||||
return wrap_upload(
|
||||
index,
|
||||
KookMessageType.IMAGE,
|
||||
self.client.upload_asset(message_component.file),
|
||||
)
|
||||
|
||||
case Video():
|
||||
self._file_message_counter += 1
|
||||
return wrap_upload(
|
||||
index,
|
||||
KookMessageType.VIDEO,
|
||||
self.client.upload_asset(message_component.file),
|
||||
)
|
||||
case File():
|
||||
|
||||
async def handle_file(index: int, f_item: File):
|
||||
f_data = await f_item.get_file()
|
||||
url = await self.client.upload_asset(f_data)
|
||||
return OrderMessage(
|
||||
index=index, text=url, type=KookMessageType.FILE
|
||||
)
|
||||
|
||||
self._file_message_counter += 1
|
||||
return handle_file(index, message_component)
|
||||
|
||||
case Record():
|
||||
|
||||
async def handle_audio(index: int, f_item: Record):
|
||||
file_path = await f_item.convert_to_file_path()
|
||||
url = await self.client.upload_asset(file_path)
|
||||
title = f_item.text or Path(file_path).name
|
||||
return OrderMessage(
|
||||
index=index,
|
||||
text=KookCardMessageContainer(
|
||||
[
|
||||
KookCardMessage(
|
||||
modules=[
|
||||
FileModule(
|
||||
type="audio",
|
||||
title=title,
|
||||
src=url,
|
||||
)
|
||||
]
|
||||
)
|
||||
]
|
||||
).to_json(),
|
||||
type=KookMessageType.CARD,
|
||||
)
|
||||
|
||||
return handle_audio(index, message_component)
|
||||
case Plain():
|
||||
return handle_plain(index, message_component.text)
|
||||
case At():
|
||||
return handle_plain(index, f"(met){message_component.qq}(met)")
|
||||
case AtAll():
|
||||
return handle_plain(index, "(met)all(met)")
|
||||
case Reply():
|
||||
return handle_plain(index, "", reply_id=message_component.id)
|
||||
case Json():
|
||||
json_data = message_component.data
|
||||
# kook卡片json外层得是一个列表
|
||||
if isinstance(json_data, dict):
|
||||
json_data = [json_data]
|
||||
return handle_plain(
|
||||
index,
|
||||
# 考虑到kook可能会更改消息结构,为了能让插件开发者
|
||||
# 自行根据kook文档描述填卡片json内容,故不做模型校验
|
||||
# KookCardMessage().model_validate(message_component.data).to_json(),
|
||||
text=json.dumps(json_data),
|
||||
type=KookMessageType.CARD,
|
||||
)
|
||||
case _:
|
||||
raise NotImplementedError(
|
||||
f'kook适配器尚未实现对 "{message_component.type}" 消息类型的支持'
|
||||
)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
file_upload_tasks: list[Coroutine[Any, Any, OrderMessage]] = []
|
||||
for index, item in enumerate(message.chain):
|
||||
file_upload_tasks.append(self._wrap_message(index, item))
|
||||
|
||||
if self._file_message_counter > 0:
|
||||
logger.debug("[Kook] 正在向kook服务器上传文件")
|
||||
|
||||
tasks_result = await asyncio.gather(*file_upload_tasks, return_exceptions=True)
|
||||
order_messages: list[OrderMessage] = []
|
||||
|
||||
for index, result in enumerate(tasks_result):
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(f"[Kook] {result}")
|
||||
# 构造一个虚假的 OrderMessage,让用户知道这里本来有张图但坏了
|
||||
# 这样后面的 for 循环就能把它当成普通文本发出去
|
||||
err_node = OrderMessage(
|
||||
index=index,
|
||||
text=str(result),
|
||||
type=KookMessageType.TEXT,
|
||||
)
|
||||
order_messages.append(err_node)
|
||||
else:
|
||||
order_messages.append(result)
|
||||
|
||||
order_messages.sort(key=lambda x: x.index)
|
||||
|
||||
reply_id: str | int = ""
|
||||
errors: list[Exception] = []
|
||||
for item in order_messages:
|
||||
if item.reply_id:
|
||||
reply_id = item.reply_id
|
||||
if not item.text:
|
||||
logger.debug(f'[Kook] 跳过空消息,类型为"{item.type}"')
|
||||
continue
|
||||
try:
|
||||
await self.client.send_text(
|
||||
self.channel_id,
|
||||
item.text,
|
||||
self.astrbot_message_type,
|
||||
item.type,
|
||||
reply_id,
|
||||
)
|
||||
except RuntimeError as exp:
|
||||
await self.client.send_text(
|
||||
self.channel_id,
|
||||
str(exp),
|
||||
self.astrbot_message_type,
|
||||
KookMessageType.TEXT,
|
||||
reply_id,
|
||||
)
|
||||
errors.append(exp)
|
||||
|
||||
if errors:
|
||||
err_msg = "\n".join([str(err) for err in errors])
|
||||
logger.error(f"[kook] {err_msg}")
|
||||
|
||||
await super().send(message)
|
||||
@@ -1,241 +0,0 @@
|
||||
import json
|
||||
from dataclasses import field
|
||||
from enum import IntEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
|
||||
class KookApiPaths:
|
||||
"""Kook Api 路径"""
|
||||
|
||||
BASE_URL = "https://www.kookapp.cn"
|
||||
API_VERSION_PATH = "/api/v3"
|
||||
|
||||
# 初始化相关
|
||||
USER_ME = f"{BASE_URL}{API_VERSION_PATH}/user/me"
|
||||
GATEWAY_INDEX = f"{BASE_URL}{API_VERSION_PATH}/gateway/index"
|
||||
|
||||
# 消息相关
|
||||
ASSET_CREATE = f"{BASE_URL}{API_VERSION_PATH}/asset/create"
|
||||
## 频道消息
|
||||
CHANNEL_MESSAGE_CREATE = f"{BASE_URL}{API_VERSION_PATH}/message/create"
|
||||
## 私聊消息
|
||||
DIRECT_MESSAGE_CREATE = f"{BASE_URL}{API_VERSION_PATH}/direct-message/create"
|
||||
|
||||
|
||||
# 定义参见kook事件结构文档: https://developer.kookapp.cn/doc/event/event-introduction
|
||||
class KookMessageType(IntEnum):
|
||||
TEXT = 1
|
||||
IMAGE = 2
|
||||
VIDEO = 3
|
||||
FILE = 4
|
||||
AUDIO = 8
|
||||
KMARKDOWN = 9
|
||||
CARD = 10
|
||||
SYSTEM = 255
|
||||
|
||||
|
||||
ThemeType = Literal[
|
||||
"primary", "success", "danger", "warning", "info", "secondary", "none", "invisible"
|
||||
]
|
||||
"""主题,可选的值为:primary, success, danger, warning, info, secondary, none.默认为 primary,为 none 时不显示侧边框。"""
|
||||
SizeType = Literal["xs", "sm", "md", "lg"]
|
||||
"""大小,可选值为:xs, sm, md, lg, 一般默认为 lg"""
|
||||
|
||||
SectionMode = Literal["left", "right"]
|
||||
CountdownMode = Literal["day", "hour", "second"]
|
||||
|
||||
|
||||
class KookCardColor(str):
|
||||
"""16 进制色值"""
|
||||
|
||||
|
||||
class KookCardModelBase:
|
||||
"""卡片模块基类"""
|
||||
|
||||
type: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlainTextElement(KookCardModelBase):
|
||||
content: str
|
||||
type: str = "plain-text"
|
||||
emoji: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class KmarkdownElement(KookCardModelBase):
|
||||
content: str
|
||||
type: str = "kmarkdown"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageElement(KookCardModelBase):
|
||||
src: str
|
||||
type: str = "image"
|
||||
alt: str = ""
|
||||
size: SizeType = "lg"
|
||||
circle: bool = False
|
||||
fallbackUrl: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ButtonElement(KookCardModelBase):
|
||||
text: str
|
||||
type: str = "button"
|
||||
theme: ThemeType = "primary"
|
||||
value: str = ""
|
||||
"""当为 link 时,会跳转到 value 代表的链接;
|
||||
当为 return-val 时,系统会通过系统消息将消息 id,点击用户 id 和 value 发回给发送者,发送者可以根据自己的需求进行处理,消息事件参见button 点击事件。私聊和频道内均可使用按钮点击事件。"""
|
||||
click: Literal["", "link", "return-val"] = ""
|
||||
"""click 代表用户点击的事件,默认为"",代表无任何事件。"""
|
||||
|
||||
|
||||
AnyElement = PlainTextElement | KmarkdownElement | ImageElement | ButtonElement | str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParagraphStructure(KookCardModelBase):
|
||||
fields: list[PlainTextElement | KmarkdownElement]
|
||||
type: str = "paragraph"
|
||||
cols: int = 1
|
||||
"""范围是 1-3 , 移动端忽略此参数"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class HeaderModule(KookCardModelBase):
|
||||
text: PlainTextElement
|
||||
type: str = "header"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SectionModule(KookCardModelBase):
|
||||
text: PlainTextElement | KmarkdownElement | ParagraphStructure
|
||||
type: str = "section"
|
||||
mode: SectionMode = "left"
|
||||
accessory: ImageElement | ButtonElement | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageGroupModule(KookCardModelBase):
|
||||
"""1 到多张图片的组合"""
|
||||
|
||||
elements: list[ImageElement]
|
||||
type: str = "image-group"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContainerModule(KookCardModelBase):
|
||||
"""1 到多张图片的组合,与图片组模块(ImageGroupModule)不同,图片并不会裁切为正方形。多张图片会纵向排列。"""
|
||||
|
||||
elements: list[ImageElement]
|
||||
type: str = "container"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionGroupModule(KookCardModelBase):
|
||||
elements: list[ButtonElement]
|
||||
type: str = "action-group"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextModule(KookCardModelBase):
|
||||
elements: list[PlainTextElement | KmarkdownElement | ImageElement]
|
||||
"""最多包含10个元素"""
|
||||
type: str = "context"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DividerModule(KookCardModelBase):
|
||||
type: str = "divider"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileModule(KookCardModelBase):
|
||||
src: str
|
||||
title: str = ""
|
||||
type: Literal["file", "audio", "video"] = "file"
|
||||
cover: str | None = None
|
||||
"""cover 仅音频有效, 是音频的封面图"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CountdownModule(KookCardModelBase):
|
||||
"""startTime 和 endTime 为毫秒时间戳,startTime 和 endTime 不能小于服务器当前时间戳。"""
|
||||
|
||||
endTime: int
|
||||
"""毫秒时间戳"""
|
||||
type: str = "countdown"
|
||||
startTime: int | None = None
|
||||
"""毫秒时间戳, 仅当mode为second才有这个字段"""
|
||||
mode: CountdownMode = "day"
|
||||
"""mode 主要是倒计时的样式"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class InviteModule(KookCardModelBase):
|
||||
code: str
|
||||
"""邀请链接或者邀请码"""
|
||||
type: str = "invite"
|
||||
|
||||
|
||||
# 所有模块的联合类型
|
||||
AnyModule = (
|
||||
HeaderModule
|
||||
| SectionModule
|
||||
| ImageGroupModule
|
||||
| ContainerModule
|
||||
| ActionGroupModule
|
||||
| ContextModule
|
||||
| DividerModule
|
||||
| FileModule
|
||||
| CountdownModule
|
||||
| InviteModule
|
||||
)
|
||||
|
||||
|
||||
class KookCardMessage(BaseModel):
|
||||
"""卡片定义文档详见 : https://developer.kookapp.cn/doc/cardmessage
|
||||
此类型不能直接to_json后发送,因为kook要求卡片容器json顶层必须是**列表**
|
||||
若要发送卡片消息,请使用KookCardMessageContainer
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
type: str = "card"
|
||||
theme: ThemeType | None = None
|
||||
size: SizeType | None = None
|
||||
color: KookCardColor | None = None
|
||||
modules: list[AnyModule] = field(default_factory=list)
|
||||
"""单个 card 模块数量不限制,但是一条消息中所有卡片的模块数量之和最多是 50"""
|
||||
|
||||
def add_module(self, module: AnyModule):
|
||||
self.modules.append(module)
|
||||
|
||||
def to_dict(self, exclude_none: bool = True):
|
||||
"""exclude_none:去掉值为 None 字段,保留结构"""
|
||||
return self.model_dump(exclude_none=exclude_none)
|
||||
|
||||
def to_json(self, indent: int | None = None, ensure_ascii: bool = True):
|
||||
return json.dumps(self.to_dict(), indent=indent, ensure_ascii=ensure_ascii)
|
||||
|
||||
|
||||
class KookCardMessageContainer(list[KookCardMessage]):
|
||||
"""卡片消息容器(列表),此类型可以直接to_json后发送出去"""
|
||||
|
||||
def append(self, object: KookCardMessage) -> None:
|
||||
return super().append(object)
|
||||
|
||||
def to_json(self, indent: int | None = None, ensure_ascii: bool = True) -> str:
|
||||
return json.dumps(
|
||||
[i.to_dict() for i in self], indent=indent, ensure_ascii=ensure_ascii
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrderMessage:
|
||||
index: int
|
||||
text: str
|
||||
type: KookMessageType
|
||||
reply_id: str | int = ""
|
||||
@@ -104,7 +104,7 @@ class LineMessageEvent(AstrMessageEvent):
|
||||
@staticmethod
|
||||
async def _resolve_image_url(segment: Image) -> str:
|
||||
candidate = (segment.url or segment.file or "").strip()
|
||||
if candidate.startswith("https://"):
|
||||
if candidate.startswith("http://") or candidate.startswith("https://"):
|
||||
return candidate
|
||||
try:
|
||||
return await segment.register_to_file_service()
|
||||
@@ -115,7 +115,7 @@ class LineMessageEvent(AstrMessageEvent):
|
||||
@staticmethod
|
||||
async def _resolve_record_url(segment: Record) -> str:
|
||||
candidate = (segment.url or segment.file or "").strip()
|
||||
if candidate.startswith("https://"):
|
||||
if candidate.startswith("http://") or candidate.startswith("https://"):
|
||||
return candidate
|
||||
try:
|
||||
return await segment.register_to_file_service()
|
||||
@@ -137,7 +137,7 @@ class LineMessageEvent(AstrMessageEvent):
|
||||
@staticmethod
|
||||
async def _resolve_video_url(segment: Video) -> str:
|
||||
candidate = (segment.file or "").strip()
|
||||
if candidate.startswith("https://"):
|
||||
if candidate.startswith("http://") or candidate.startswith("https://"):
|
||||
return candidate
|
||||
try:
|
||||
return await segment.register_to_file_service()
|
||||
@@ -148,7 +148,9 @@ class LineMessageEvent(AstrMessageEvent):
|
||||
@staticmethod
|
||||
async def _resolve_video_preview_url(segment: Video) -> str:
|
||||
cover_candidate = (segment.cover or "").strip()
|
||||
if cover_candidate.startswith("https://"):
|
||||
if cover_candidate.startswith("http://") or cover_candidate.startswith(
|
||||
"https://"
|
||||
):
|
||||
return cover_candidate
|
||||
|
||||
if cover_candidate:
|
||||
@@ -189,7 +191,7 @@ class LineMessageEvent(AstrMessageEvent):
|
||||
|
||||
@staticmethod
|
||||
async def _resolve_file_url(segment: File) -> str:
|
||||
if segment.url and segment.url.startswith("https://"):
|
||||
if segment.url and segment.url.startswith(("http://", "https://")):
|
||||
return segment.url
|
||||
try:
|
||||
return await segment.register_to_file_service()
|
||||
|
||||
@@ -4,11 +4,7 @@ import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import urllib.parse
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
|
||||
from dataclasses import dataclass
|
||||
from types import MappingProxyType
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
@@ -21,103 +17,6 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||
|
||||
DEFAULT_MCP_INIT_TIMEOUT_SECONDS = 20.0
|
||||
DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS = 30.0
|
||||
MCP_INIT_TIMEOUT_ENV = "ASTRBOT_MCP_INIT_TIMEOUT"
|
||||
ENABLE_MCP_TIMEOUT_ENV = "ASTRBOT_MCP_ENABLE_TIMEOUT"
|
||||
MAX_MCP_TIMEOUT_SECONDS = 300.0
|
||||
|
||||
|
||||
class MCPInitError(Exception):
|
||||
"""Base exception for MCP initialization failures."""
|
||||
|
||||
|
||||
class MCPInitTimeoutError(asyncio.TimeoutError, MCPInitError):
|
||||
"""Raised when MCP client initialization exceeds the configured timeout."""
|
||||
|
||||
|
||||
class MCPAllServicesFailedError(MCPInitError):
|
||||
"""Raised when all configured MCP services fail to initialize."""
|
||||
|
||||
|
||||
class MCPShutdownTimeoutError(asyncio.TimeoutError):
|
||||
"""Raised when MCP shutdown exceeds the configured timeout."""
|
||||
|
||||
def __init__(self, names: list[str], timeout: float) -> None:
|
||||
self.names = names
|
||||
self.timeout = timeout
|
||||
message = f"MCP 服务关闭超时({timeout:g} 秒):{', '.join(names)}"
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPInitSummary:
|
||||
total: int
|
||||
success: int
|
||||
failed: list[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _MCPServerRuntime:
|
||||
name: str
|
||||
client: MCPClient
|
||||
shutdown_event: asyncio.Event
|
||||
lifecycle_task: asyncio.Task[None]
|
||||
|
||||
|
||||
class _MCPClientDictView(Mapping[str, MCPClient]):
|
||||
"""Read-only view of MCP clients derived from runtime state."""
|
||||
|
||||
def __init__(self, runtime: dict[str, _MCPServerRuntime]) -> None:
|
||||
self._runtime = runtime
|
||||
|
||||
def __getitem__(self, key: str) -> MCPClient:
|
||||
return self._runtime[key].client
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._runtime)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._runtime)
|
||||
|
||||
|
||||
def _resolve_timeout(
|
||||
timeout: float | int | str | None = None,
|
||||
*,
|
||||
env_name: str = MCP_INIT_TIMEOUT_ENV,
|
||||
default: float = DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
|
||||
) -> float:
|
||||
"""Resolve timeout with precedence: explicit argument > env value > default."""
|
||||
source = f"环境变量 {env_name}"
|
||||
if timeout is None:
|
||||
timeout = os.getenv(env_name, str(default))
|
||||
else:
|
||||
source = "显式参数 timeout"
|
||||
|
||||
try:
|
||||
timeout_value = float(timeout)
|
||||
except (TypeError, ValueError):
|
||||
logger.warning(
|
||||
f"超时配置({source})={timeout!r} 无效,使用默认值 {default:g} 秒。"
|
||||
)
|
||||
return default
|
||||
|
||||
if timeout_value <= 0:
|
||||
logger.warning(
|
||||
f"超时配置({source})={timeout_value:g} 必须大于 0,使用默认值 {default:g} 秒。"
|
||||
)
|
||||
return default
|
||||
|
||||
if timeout_value > MAX_MCP_TIMEOUT_SECONDS:
|
||||
logger.warning(
|
||||
f"超时配置({source})={timeout_value:g} 过大,已限制为最大值 "
|
||||
f"{MAX_MCP_TIMEOUT_SECONDS:g} 秒,以避免长时间等待。"
|
||||
)
|
||||
return MAX_MCP_TIMEOUT_SECONDS
|
||||
|
||||
return timeout_value
|
||||
|
||||
|
||||
SUPPORTED_TYPES = [
|
||||
"string",
|
||||
"number",
|
||||
@@ -207,49 +106,9 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
class FunctionToolManager:
|
||||
def __init__(self) -> None:
|
||||
self.func_list: list[FuncTool] = []
|
||||
self._mcp_server_runtime: dict[str, _MCPServerRuntime] = {}
|
||||
"""MCP 服务运行时状态(唯一事实来源)"""
|
||||
self._mcp_server_runtime_view = MappingProxyType(self._mcp_server_runtime)
|
||||
self._mcp_client_dict_view = _MCPClientDictView(self._mcp_server_runtime)
|
||||
self._timeout_mismatch_warned = False
|
||||
self._timeout_warn_lock = threading.Lock()
|
||||
self._runtime_lock = asyncio.Lock()
|
||||
self._mcp_starting: set[str] = set()
|
||||
self._init_timeout_default = _resolve_timeout(
|
||||
timeout=None,
|
||||
env_name=MCP_INIT_TIMEOUT_ENV,
|
||||
default=DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
|
||||
)
|
||||
self._enable_timeout_default = _resolve_timeout(
|
||||
timeout=None,
|
||||
env_name=ENABLE_MCP_TIMEOUT_ENV,
|
||||
default=DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS,
|
||||
)
|
||||
self._warn_on_timeout_mismatch(
|
||||
self._init_timeout_default,
|
||||
self._enable_timeout_default,
|
||||
)
|
||||
|
||||
@property
|
||||
def mcp_client_dict(self) -> Mapping[str, MCPClient]:
|
||||
"""Read-only compatibility view for external callers that still read mcp_client_dict.
|
||||
|
||||
Note: Mutating this mapping is unsupported and will raise TypeError.
|
||||
"""
|
||||
return self._mcp_client_dict_view
|
||||
|
||||
@property
|
||||
def mcp_server_runtime_view(self) -> Mapping[str, _MCPServerRuntime]:
|
||||
"""Read-only view of MCP runtime metadata for external callers."""
|
||||
return self._mcp_server_runtime_view
|
||||
|
||||
@property
|
||||
def mcp_server_runtime(self) -> Mapping[str, _MCPServerRuntime]:
|
||||
"""Backward-compatible read-only view (deprecated). Do not mutate.
|
||||
|
||||
Note: Mutations are not supported and will raise TypeError.
|
||||
"""
|
||||
return self._mcp_server_runtime_view
|
||||
self.mcp_client_dict: dict[str, MCPClient] = {}
|
||||
"""MCP 服务列表"""
|
||||
self.mcp_client_event: dict[str, asyncio.Event] = {}
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.func_list) == 0
|
||||
@@ -320,34 +179,7 @@ class FunctionToolManager:
|
||||
tool_set = ToolSet(self.func_list.copy())
|
||||
return tool_set
|
||||
|
||||
@staticmethod
|
||||
def _log_safe_mcp_debug_config(cfg: dict) -> None:
|
||||
# 仅记录脱敏后的摘要,避免泄露 command/args/url 中的敏感信息
|
||||
if "command" in cfg:
|
||||
cmd = cfg["command"]
|
||||
executable = str(cmd[0] if isinstance(cmd, (list, tuple)) and cmd else cmd)
|
||||
args_val = cfg.get("args", [])
|
||||
args_count = (
|
||||
len(args_val)
|
||||
if isinstance(args_val, (list, tuple))
|
||||
else (0 if args_val is None else 1)
|
||||
)
|
||||
logger.debug(f" 命令可执行文件: {executable}, 参数数量: {args_count}")
|
||||
return
|
||||
|
||||
if "url" in cfg:
|
||||
parsed = urllib.parse.urlparse(str(cfg["url"]))
|
||||
host = parsed.hostname or ""
|
||||
scheme = parsed.scheme or "unknown"
|
||||
try:
|
||||
port = f":{parsed.port}" if parsed.port else ""
|
||||
except ValueError:
|
||||
port = ""
|
||||
logger.debug(f" 主机: {scheme}://{host}{port}")
|
||||
|
||||
async def init_mcp_clients(
|
||||
self, raise_on_all_failed: bool = False
|
||||
) -> MCPInitSummary:
|
||||
async def init_mcp_clients(self) -> None:
|
||||
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
|
||||
```
|
||||
{
|
||||
@@ -365,10 +197,6 @@ class FunctionToolManager:
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
Timeout behavior:
|
||||
- 初始化超时使用环境变量 ASTRBOT_MCP_INIT_TIMEOUT 或默认值。
|
||||
- 动态启用超时使用 ASTRBOT_MCP_ENABLE_TIMEOUT(独立于初始化超时)。
|
||||
"""
|
||||
data_dir = get_astrbot_data_path()
|
||||
|
||||
@@ -378,211 +206,56 @@ class FunctionToolManager:
|
||||
with open(mcp_json_file, "w", encoding="utf-8") as f:
|
||||
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
|
||||
logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
|
||||
return MCPInitSummary(total=0, success=0, failed=[])
|
||||
return
|
||||
|
||||
with open(mcp_json_file, encoding="utf-8") as f:
|
||||
mcp_server_json_obj: dict[str, dict] = json.load(f)["mcpServers"]
|
||||
mcp_server_json_obj: dict[str, dict] = json.load(
|
||||
open(mcp_json_file, encoding="utf-8"),
|
||||
)["mcpServers"]
|
||||
|
||||
init_timeout = self._init_timeout_default
|
||||
timeout_display = f"{init_timeout:g}"
|
||||
|
||||
active_configs: list[tuple[str, dict, asyncio.Event]] = []
|
||||
for name, cfg in mcp_server_json_obj.items():
|
||||
for name in mcp_server_json_obj:
|
||||
cfg = mcp_server_json_obj[name]
|
||||
if cfg.get("active", True):
|
||||
shutdown_event = asyncio.Event()
|
||||
active_configs.append((name, cfg, shutdown_event))
|
||||
event = asyncio.Event()
|
||||
asyncio.create_task(
|
||||
self._init_mcp_client_task_wrapper(name, cfg, event),
|
||||
)
|
||||
self.mcp_client_event[name] = event
|
||||
|
||||
if not active_configs:
|
||||
return MCPInitSummary(total=0, success=0, failed=[])
|
||||
|
||||
logger.info(f"等待 {len(active_configs)} 个 MCP 服务初始化...")
|
||||
|
||||
init_tasks = [
|
||||
asyncio.create_task(
|
||||
self._start_mcp_server(
|
||||
name=name,
|
||||
cfg=cfg,
|
||||
shutdown_event=shutdown_event,
|
||||
timeout=init_timeout,
|
||||
),
|
||||
name=f"mcp-init:{name}",
|
||||
)
|
||||
for (name, cfg, shutdown_event) in active_configs
|
||||
]
|
||||
results = await asyncio.gather(*init_tasks, return_exceptions=True)
|
||||
|
||||
success_count = 0
|
||||
failed_services: list[str] = []
|
||||
|
||||
for (name, cfg, _), result in zip(active_configs, results, strict=False):
|
||||
if isinstance(result, Exception):
|
||||
if isinstance(result, MCPInitTimeoutError):
|
||||
logger.error(f"MCP 服务 {name} 初始化超时({timeout_display}秒)")
|
||||
else:
|
||||
logger.error(f"MCP 服务 {name} 初始化失败: {result}")
|
||||
self._log_safe_mcp_debug_config(cfg)
|
||||
failed_services.append(name)
|
||||
async with self._runtime_lock:
|
||||
self._mcp_server_runtime.pop(name, None)
|
||||
continue
|
||||
|
||||
success_count += 1
|
||||
|
||||
if failed_services:
|
||||
logger.warning(
|
||||
f"以下 MCP 服务初始化失败: {', '.join(failed_services)}。"
|
||||
f"请检查配置文件 mcp_server.json 和服务器可用性。"
|
||||
)
|
||||
|
||||
summary = MCPInitSummary(
|
||||
total=len(active_configs), success=success_count, failed=failed_services
|
||||
)
|
||||
logger.info(f"MCP 服务初始化完成: {summary.success}/{summary.total} 成功")
|
||||
if summary.total > 0 and summary.success == 0:
|
||||
msg = "全部 MCP 服务初始化失败,请检查 mcp_server.json 配置和服务器可用性。"
|
||||
if raise_on_all_failed:
|
||||
raise MCPAllServicesFailedError(msg)
|
||||
logger.error(msg)
|
||||
return summary
|
||||
|
||||
async def _start_mcp_server(
|
||||
async def _init_mcp_client_task_wrapper(
|
||||
self,
|
||||
name: str,
|
||||
cfg: dict,
|
||||
*,
|
||||
shutdown_event: asyncio.Event | None = None,
|
||||
timeout: float,
|
||||
event: asyncio.Event,
|
||||
ready_future: asyncio.Future | None = None,
|
||||
) -> None:
|
||||
"""Initialize MCP server with timeout and register task/event together.
|
||||
|
||||
This method is idempotent. If the server is already running, the existing
|
||||
runtime is kept and the new config is ignored.
|
||||
"""
|
||||
async with self._runtime_lock:
|
||||
if name in self._mcp_server_runtime or name in self._mcp_starting:
|
||||
logger.warning(
|
||||
f"MCP 服务 {name} 已在运行,忽略本次启用请求(timeout={timeout:g})。"
|
||||
)
|
||||
self._log_safe_mcp_debug_config(cfg)
|
||||
return
|
||||
self._mcp_starting.add(name)
|
||||
|
||||
if shutdown_event is None:
|
||||
shutdown_event = asyncio.Event()
|
||||
|
||||
mcp_client: MCPClient | None = None
|
||||
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||
try:
|
||||
mcp_client = await asyncio.wait_for(
|
||||
self._init_mcp_client(name, cfg),
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError as exc:
|
||||
raise MCPInitTimeoutError(
|
||||
f"MCP 服务 {name} 初始化超时({timeout:g} 秒)"
|
||||
) from exc
|
||||
except Exception:
|
||||
await self._init_mcp_client(name, cfg)
|
||||
tools = await self.mcp_client_dict[name].list_tools_and_save()
|
||||
if ready_future and not ready_future.done():
|
||||
# tell the caller we are ready
|
||||
ready_future.set_result(tools)
|
||||
await event.wait()
|
||||
logger.info(f"收到 MCP 客户端 {name} 终止信号")
|
||||
except Exception as e:
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
|
||||
raise
|
||||
if ready_future and not ready_future.done():
|
||||
ready_future.set_exception(e)
|
||||
finally:
|
||||
if mcp_client is None:
|
||||
async with self._runtime_lock:
|
||||
self._mcp_starting.discard(name)
|
||||
# 无论如何都能清理
|
||||
await self._terminate_mcp_client(name)
|
||||
|
||||
async def lifecycle() -> None:
|
||||
try:
|
||||
await shutdown_event.wait()
|
||||
logger.info(f"收到 MCP 客户端 {name} 终止信号")
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"MCP 客户端 {name} 任务被取消")
|
||||
raise
|
||||
finally:
|
||||
await self._terminate_mcp_client(name)
|
||||
|
||||
lifecycle_task = asyncio.create_task(lifecycle(), name=f"mcp-client:{name}")
|
||||
async with self._runtime_lock:
|
||||
self._mcp_server_runtime[name] = _MCPServerRuntime(
|
||||
name=name,
|
||||
client=mcp_client,
|
||||
shutdown_event=shutdown_event,
|
||||
lifecycle_task=lifecycle_task,
|
||||
)
|
||||
self._mcp_starting.discard(name)
|
||||
|
||||
async def _shutdown_runtimes(
|
||||
self,
|
||||
runtimes: list[_MCPServerRuntime],
|
||||
timeout: float,
|
||||
*,
|
||||
strict: bool = True,
|
||||
) -> list[str]:
|
||||
"""Shutdown runtimes and wait for lifecycle tasks to complete."""
|
||||
lifecycle_tasks = [
|
||||
runtime.lifecycle_task
|
||||
for runtime in runtimes
|
||||
if not runtime.lifecycle_task.done()
|
||||
]
|
||||
if not lifecycle_tasks:
|
||||
return []
|
||||
|
||||
for runtime in runtimes:
|
||||
runtime.shutdown_event.set()
|
||||
|
||||
try:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(*lifecycle_tasks, return_exceptions=True),
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pending_names = [
|
||||
runtime.name
|
||||
for runtime in runtimes
|
||||
if not runtime.lifecycle_task.done()
|
||||
]
|
||||
for task in lifecycle_tasks:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
await asyncio.gather(*lifecycle_tasks, return_exceptions=True)
|
||||
if strict:
|
||||
raise MCPShutdownTimeoutError(pending_names, timeout)
|
||||
logger.warning(
|
||||
"MCP 服务关闭超时(%s 秒),以下服务未完全关闭:%s",
|
||||
f"{timeout:g}",
|
||||
", ".join(pending_names),
|
||||
)
|
||||
return pending_names
|
||||
else:
|
||||
for result in results:
|
||||
if isinstance(result, asyncio.CancelledError):
|
||||
logger.debug("MCP lifecycle task was cancelled during shutdown.")
|
||||
elif isinstance(result, Exception):
|
||||
logger.error(
|
||||
"MCP lifecycle task failed during shutdown.",
|
||||
exc_info=(type(result), result, result.__traceback__),
|
||||
)
|
||||
return []
|
||||
|
||||
async def _cleanup_mcp_client_safely(
|
||||
self, mcp_client: MCPClient, name: str
|
||||
) -> None:
|
||||
"""安全清理单个 MCP 客户端,避免清理异常中断主流程。"""
|
||||
try:
|
||||
await mcp_client.cleanup()
|
||||
except Exception as cleanup_exc: # noqa: BLE001 - only log here
|
||||
logger.error(f"清理 MCP 客户端资源 {name} 失败: {cleanup_exc}")
|
||||
|
||||
async def _init_mcp_client(self, name: str, config: dict) -> MCPClient:
|
||||
async def _init_mcp_client(self, name: str, config: dict) -> None:
|
||||
"""初始化单个MCP客户端"""
|
||||
# 先清理之前的客户端,如果存在
|
||||
if name in self.mcp_client_dict:
|
||||
await self._terminate_mcp_client(name)
|
||||
|
||||
mcp_client = MCPClient()
|
||||
mcp_client.name = name
|
||||
try:
|
||||
await mcp_client.connect_to_server(config, name)
|
||||
tools_res = await mcp_client.list_tools_and_save()
|
||||
except asyncio.CancelledError:
|
||||
await self._cleanup_mcp_client_safely(mcp_client, name)
|
||||
raise
|
||||
except Exception:
|
||||
await self._cleanup_mcp_client_safely(mcp_client, name)
|
||||
raise
|
||||
self.mcp_client_dict[name] = mcp_client
|
||||
await mcp_client.connect_to_server(config, name)
|
||||
tools_res = await mcp_client.list_tools_and_save()
|
||||
logger.debug(f"MCP server {name} list tools response: {tools_res}")
|
||||
tool_names = [tool.name for tool in tools_res.tools]
|
||||
|
||||
@@ -603,36 +276,26 @@ class FunctionToolManager:
|
||||
self.func_list.append(func_tool)
|
||||
|
||||
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
||||
return mcp_client
|
||||
|
||||
async def _terminate_mcp_client(self, name: str) -> None:
|
||||
"""关闭并清理MCP客户端"""
|
||||
async with self._runtime_lock:
|
||||
runtime = self._mcp_server_runtime.get(name)
|
||||
if runtime:
|
||||
client = runtime.client
|
||||
# 关闭MCP连接
|
||||
await self._cleanup_mcp_client_safely(client, name)
|
||||
# 移除关联的FuncTool
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
||||
]
|
||||
async with self._runtime_lock:
|
||||
self._mcp_server_runtime.pop(name, None)
|
||||
self._mcp_starting.discard(name)
|
||||
logger.info(f"已关闭 MCP 服务 {name}")
|
||||
return
|
||||
|
||||
# Runtime missing but stale tools may still exist after failed flows.
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
||||
]
|
||||
async with self._runtime_lock:
|
||||
self._mcp_starting.discard(name)
|
||||
if name in self.mcp_client_dict:
|
||||
client = self.mcp_client_dict[name]
|
||||
try:
|
||||
# 关闭MCP连接
|
||||
await client.cleanup()
|
||||
except Exception as e:
|
||||
logger.error(f"清空 MCP 客户端资源 {name}: {e}。")
|
||||
finally:
|
||||
# Remove client from dict after cleanup attempt (successful or not)
|
||||
self.mcp_client_dict.pop(name, None)
|
||||
# 移除关联的FuncTool
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
||||
]
|
||||
logger.info(f"已关闭 MCP 服务 {name}")
|
||||
|
||||
@staticmethod
|
||||
async def test_mcp_server_connection(config: dict) -> list[str]:
|
||||
@@ -656,36 +319,42 @@ class FunctionToolManager:
|
||||
self,
|
||||
name: str,
|
||||
config: dict,
|
||||
shutdown_event: asyncio.Event | None = None,
|
||||
timeout: float | int | str | None = None,
|
||||
event: asyncio.Event | None = None,
|
||||
ready_future: asyncio.Future | None = None,
|
||||
timeout: int = 30,
|
||||
) -> None:
|
||||
"""Enable a new MCP server and initialize it.
|
||||
"""Enable_mcp_server a new MCP server to the manager and initialize it.
|
||||
|
||||
Args:
|
||||
name: The name of the MCP server.
|
||||
config: Configuration for the MCP server.
|
||||
shutdown_event: Event to signal when the MCP client should shut down.
|
||||
timeout: Timeout in seconds for initialization.
|
||||
Uses ASTRBOT_MCP_ENABLE_TIMEOUT by default (separate from init timeout).
|
||||
name (str): The name of the MCP server.
|
||||
config (dict): Configuration for the MCP server.
|
||||
event (asyncio.Event): Event to signal when the MCP client is ready.
|
||||
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
|
||||
timeout (int): Timeout for the initialization.
|
||||
|
||||
Raises:
|
||||
MCPInitTimeoutError: If initialization does not complete within timeout.
|
||||
TimeoutError: If the initialization does not complete within the specified timeout.
|
||||
Exception: If there is an error during initialization.
|
||||
|
||||
"""
|
||||
if timeout is None:
|
||||
timeout_value = self._enable_timeout_default
|
||||
else:
|
||||
timeout_value = _resolve_timeout(
|
||||
timeout=timeout,
|
||||
env_name=ENABLE_MCP_TIMEOUT_ENV,
|
||||
default=self._enable_timeout_default,
|
||||
)
|
||||
await self._start_mcp_server(
|
||||
name=name,
|
||||
cfg=config,
|
||||
shutdown_event=shutdown_event,
|
||||
timeout=timeout_value,
|
||||
if not event:
|
||||
event = asyncio.Event()
|
||||
if not ready_future:
|
||||
ready_future = asyncio.Future()
|
||||
if name in self.mcp_client_dict:
|
||||
return
|
||||
asyncio.create_task(
|
||||
self._init_mcp_client_task_wrapper(name, config, event, ready_future),
|
||||
)
|
||||
try:
|
||||
await asyncio.wait_for(ready_future, timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event[name] = event
|
||||
|
||||
if ready_future.done() and ready_future.exception():
|
||||
exc = ready_future.exception()
|
||||
if exc is not None:
|
||||
raise exc
|
||||
|
||||
async def disable_mcp_server(
|
||||
self,
|
||||
@@ -698,40 +367,39 @@ class FunctionToolManager:
|
||||
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
|
||||
timeout (int): Timeout.
|
||||
|
||||
Raises:
|
||||
MCPShutdownTimeoutError: If shutdown does not complete within timeout.
|
||||
Only raised when disabling a specific server (name is not None).
|
||||
|
||||
"""
|
||||
if name:
|
||||
async with self._runtime_lock:
|
||||
runtime = self._mcp_server_runtime.get(name)
|
||||
if runtime is None:
|
||||
if name not in self.mcp_client_event:
|
||||
return
|
||||
|
||||
await self._shutdown_runtimes([runtime], timeout, strict=True)
|
||||
client = self.mcp_client_dict.get(name)
|
||||
self.mcp_client_event[name].set()
|
||||
if not client:
|
||||
return
|
||||
client_running_event = client.running_event
|
||||
try:
|
||||
await asyncio.wait_for(client_running_event.wait(), timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event.pop(name, None)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
||||
]
|
||||
else:
|
||||
async with self._runtime_lock:
|
||||
runtimes = list(self._mcp_server_runtime.values())
|
||||
await self._shutdown_runtimes(runtimes, timeout, strict=False)
|
||||
|
||||
def _warn_on_timeout_mismatch(
|
||||
self,
|
||||
init_timeout: float,
|
||||
enable_timeout: float,
|
||||
) -> None:
|
||||
if init_timeout == enable_timeout:
|
||||
return
|
||||
with self._timeout_warn_lock:
|
||||
if self._timeout_mismatch_warned:
|
||||
return
|
||||
logger.info(
|
||||
"检测到 MCP 初始化超时与动态启用超时配置不同:"
|
||||
"初始化使用 %s 秒,动态启用使用 %s 秒。如需一致,请设置相同值。",
|
||||
f"{init_timeout:g}",
|
||||
f"{enable_timeout:g}",
|
||||
)
|
||||
self._timeout_mismatch_warned = True
|
||||
running_events = [
|
||||
client.running_event.wait() for client in self.mcp_client_dict.values()
|
||||
]
|
||||
for key, event in self.mcp_client_event.items():
|
||||
event.set()
|
||||
# waiting for all clients to finish
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event.clear()
|
||||
self.mcp_client_dict.clear()
|
||||
self.func_list = [
|
||||
f for f in self.func_list if not isinstance(f, MCPTool)
|
||||
]
|
||||
|
||||
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
|
||||
"""获得 OpenAI API 风格的**已经激活**的工具描述"""
|
||||
|
||||
@@ -330,25 +330,8 @@ class ProviderManager:
|
||||
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
# 初始化 MCP Client 连接(等待完成以确保工具可用)
|
||||
strict_mcp_init = os.getenv("ASTRBOT_MCP_INIT_STRICT", "").strip().lower() in {
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
"on",
|
||||
}
|
||||
mcp_init_summary = await self.llm_tools.init_mcp_clients(
|
||||
raise_on_all_failed=strict_mcp_init
|
||||
)
|
||||
if (
|
||||
mcp_init_summary.total > 0
|
||||
and mcp_init_summary.success == 0
|
||||
and not strict_mcp_init
|
||||
):
|
||||
logger.warning(
|
||||
"MCP 服务全部初始化失败,系统将继续启动(可设置 "
|
||||
"ASTRBOT_MCP_INIT_STRICT=1 以在此场景下中止启动)。"
|
||||
)
|
||||
# 初始化 MCP Client 连接
|
||||
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
|
||||
|
||||
def dynamic_import_provider(self, type: str) -> None:
|
||||
"""动态导入提供商适配器模块
|
||||
|
||||
@@ -1,372 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from astrbot.core.computer.computer_client import sync_skills_to_active_sandboxes
|
||||
from astrbot.core.skills.skill_manager import SkillManager
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_skills_path
|
||||
|
||||
_MAP_VERSION = 1
|
||||
_MAP_FILE_NAME = "neo_skill_map.json"
|
||||
_SKILL_NAME_RE = re.compile(r"[^a-zA-Z0-9._-]+")
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _to_jsonable(model_like: Any) -> dict[str, Any]:
|
||||
if isinstance(model_like, dict):
|
||||
return model_like
|
||||
if hasattr(model_like, "model_dump"):
|
||||
dumped = model_like.model_dump()
|
||||
if isinstance(dumped, dict):
|
||||
return dumped
|
||||
return {}
|
||||
|
||||
|
||||
def _parse_frontmatter(text: str) -> tuple[dict[str, str], str]:
|
||||
if not text.startswith("---"):
|
||||
return {}, text
|
||||
|
||||
lines = text.splitlines()
|
||||
if not lines or lines[0].strip() != "---":
|
||||
return {}, text
|
||||
|
||||
end_idx = None
|
||||
for i in range(1, len(lines)):
|
||||
if lines[i].strip() == "---":
|
||||
end_idx = i
|
||||
break
|
||||
|
||||
if end_idx is None:
|
||||
return {}, text
|
||||
|
||||
data: dict[str, str] = {}
|
||||
for line in lines[1:end_idx]:
|
||||
if ":" not in line:
|
||||
continue
|
||||
key, value = line.split(":", 1)
|
||||
key = key.strip().lower()
|
||||
value = value.strip().strip('"').strip("'")
|
||||
if key in {"name", "description"} and value:
|
||||
data[key] = value
|
||||
|
||||
body = "\n".join(lines[end_idx + 1 :]).lstrip("\n")
|
||||
return data, body
|
||||
|
||||
|
||||
def _derive_description(markdown_body: str) -> str:
|
||||
lines = markdown_body.splitlines()
|
||||
|
||||
heading_idx = None
|
||||
for i, line in enumerate(lines):
|
||||
normalized = line.strip().lower()
|
||||
if normalized in {"## 描述", "## description"}:
|
||||
heading_idx = i
|
||||
break
|
||||
|
||||
if heading_idx is not None:
|
||||
for line in lines[heading_idx + 1 :]:
|
||||
text = line.strip()
|
||||
if not text:
|
||||
continue
|
||||
if text.startswith("#"):
|
||||
break
|
||||
return text
|
||||
|
||||
for line in lines:
|
||||
text = line.strip()
|
||||
if not text or text.startswith("#"):
|
||||
continue
|
||||
return text
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def _ensure_skill_frontmatter(markdown: str, *, skill_name: str, skill_key: str) -> str:
|
||||
frontmatter, body = _parse_frontmatter(markdown)
|
||||
|
||||
name = frontmatter.get("name") or skill_name
|
||||
name = " ".join(str(name).split())
|
||||
description = frontmatter.get("description") or _derive_description(body)
|
||||
if not description:
|
||||
description = f"Synced skill for `{skill_key}`."
|
||||
|
||||
description = " ".join(description.split())
|
||||
|
||||
header = f"---\nname: {name}\ndescription: {description}\n---\n\n"
|
||||
body = body.strip("\n")
|
||||
return f"{header}{body}\n"
|
||||
|
||||
|
||||
@dataclass
|
||||
class NeoSkillSyncResult:
|
||||
skill_key: str
|
||||
local_skill_name: str
|
||||
release_id: str
|
||||
candidate_id: str
|
||||
payload_ref: str
|
||||
map_path: str
|
||||
synced_at: str
|
||||
|
||||
|
||||
class NeoSkillSyncManager:
|
||||
@staticmethod
|
||||
def sync_result_to_dict(result: NeoSkillSyncResult) -> dict[str, str]:
|
||||
return {
|
||||
"skill_key": result.skill_key,
|
||||
"local_skill_name": result.local_skill_name,
|
||||
"release_id": result.release_id,
|
||||
"candidate_id": result.candidate_id,
|
||||
"payload_ref": result.payload_ref,
|
||||
"map_path": result.map_path,
|
||||
"synced_at": result.synced_at,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
skills_root: str | None = None,
|
||||
map_path: str | None = None,
|
||||
) -> None:
|
||||
self.skills_root = skills_root or get_astrbot_skills_path()
|
||||
self.map_path = map_path or str(Path(self.skills_root) / _MAP_FILE_NAME)
|
||||
os.makedirs(self.skills_root, exist_ok=True)
|
||||
|
||||
def _load_map(self) -> dict[str, Any]:
|
||||
if not os.path.exists(self.map_path):
|
||||
return {"version": _MAP_VERSION, "items": {}}
|
||||
try:
|
||||
with open(self.map_path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if not isinstance(data, dict):
|
||||
return {"version": _MAP_VERSION, "items": {}}
|
||||
items = data.get("items", {})
|
||||
if not isinstance(items, dict):
|
||||
items = {}
|
||||
return {"version": int(data.get("version", _MAP_VERSION)), "items": items}
|
||||
except Exception:
|
||||
return {"version": _MAP_VERSION, "items": {}}
|
||||
|
||||
def _save_map(self, data: dict[str, Any]) -> None:
|
||||
os.makedirs(os.path.dirname(self.map_path), exist_ok=True)
|
||||
with open(self.map_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
@staticmethod
|
||||
def normalize_skill_name(skill_key: str) -> str:
|
||||
normalized = _SKILL_NAME_RE.sub("-", skill_key.strip().lower())
|
||||
normalized = normalized.strip("._-")
|
||||
if not normalized:
|
||||
normalized = "skill"
|
||||
return f"neo_{normalized}"
|
||||
|
||||
def _resolve_local_skill_name(self, skill_key: str, mapping: dict[str, Any]) -> str:
|
||||
items = mapping.get("items", {})
|
||||
if not isinstance(items, dict):
|
||||
items = {}
|
||||
existing = items.get(skill_key)
|
||||
if isinstance(existing, dict):
|
||||
local_name = existing.get("local_skill_name")
|
||||
if isinstance(local_name, str) and local_name:
|
||||
return local_name
|
||||
|
||||
base = self.normalize_skill_name(skill_key)
|
||||
used_names = {
|
||||
str(v.get("local_skill_name"))
|
||||
for v in items.values()
|
||||
if isinstance(v, dict) and v.get("local_skill_name")
|
||||
}
|
||||
if base not in used_names:
|
||||
return base
|
||||
suffix = hashlib.sha1(skill_key.encode("utf-8")).hexdigest()[:8]
|
||||
return f"{base}-{suffix}"
|
||||
|
||||
async def _find_release(self, client: Any, *, release_id: str) -> dict[str, Any]:
|
||||
offset = 0
|
||||
while True:
|
||||
page = await client.skills.list_releases(limit=100, offset=offset)
|
||||
page_json = _to_jsonable(page)
|
||||
items = page_json.get("items", [])
|
||||
if not isinstance(items, list):
|
||||
items = []
|
||||
for item in items:
|
||||
if isinstance(item, dict) and item.get("id") == release_id:
|
||||
return item
|
||||
total = int(page_json.get("total", 0) or 0)
|
||||
offset += len(items)
|
||||
if offset >= total or not items:
|
||||
break
|
||||
raise ValueError(f"Release not found: {release_id}")
|
||||
|
||||
async def _find_active_stable_release(
|
||||
self,
|
||||
client: Any,
|
||||
*,
|
||||
skill_key: str,
|
||||
) -> dict[str, Any]:
|
||||
page = await client.skills.list_releases(
|
||||
skill_key=skill_key,
|
||||
active_only=True,
|
||||
stage="stable",
|
||||
limit=1,
|
||||
offset=0,
|
||||
)
|
||||
page_json = _to_jsonable(page)
|
||||
items = page_json.get("items", [])
|
||||
if not isinstance(items, list) or not items:
|
||||
raise ValueError(
|
||||
f"No active stable release found for skill_key: {skill_key}"
|
||||
)
|
||||
if not isinstance(items[0], dict):
|
||||
raise ValueError("Unexpected release payload format.")
|
||||
return items[0]
|
||||
|
||||
async def sync_release(
|
||||
self,
|
||||
client: Any,
|
||||
*,
|
||||
release_id: str | None = None,
|
||||
skill_key: str | None = None,
|
||||
require_stable: bool = True,
|
||||
) -> NeoSkillSyncResult:
|
||||
if release_id:
|
||||
release = await self._find_release(client, release_id=release_id)
|
||||
elif skill_key:
|
||||
release = await self._find_active_stable_release(
|
||||
client, skill_key=skill_key
|
||||
)
|
||||
else:
|
||||
raise ValueError("release_id or skill_key is required for sync.")
|
||||
|
||||
release_id_val = str(release.get("id") or "")
|
||||
release_stage_raw = release.get("stage")
|
||||
release_stage_value = getattr(release_stage_raw, "value", release_stage_raw)
|
||||
release_stage = str(release_stage_value or "").strip().lower()
|
||||
skill_key_val = str(release.get("skill_key") or "")
|
||||
candidate_id = str(release.get("candidate_id") or "")
|
||||
|
||||
if not release_id_val or not skill_key_val or not candidate_id:
|
||||
raise ValueError("Release payload is incomplete.")
|
||||
if require_stable and release_stage != "stable":
|
||||
raise ValueError(
|
||||
"Only stable releases can be synced to local SKILL.md "
|
||||
f"(got: {release_stage_raw})."
|
||||
)
|
||||
|
||||
candidate = await client.skills.get_candidate(candidate_id)
|
||||
candidate_json = _to_jsonable(candidate)
|
||||
payload_ref = candidate_json.get("payload_ref")
|
||||
if not isinstance(payload_ref, str) or not payload_ref:
|
||||
raise ValueError("Candidate payload_ref is missing.")
|
||||
|
||||
payload_resp = await client.skills.get_payload(payload_ref)
|
||||
payload_json = _to_jsonable(payload_resp)
|
||||
payload = payload_json.get("payload")
|
||||
if not isinstance(payload, dict):
|
||||
raise ValueError("Skill payload must be a JSON object.")
|
||||
|
||||
skill_markdown = payload.get("skill_markdown")
|
||||
if not isinstance(skill_markdown, str) or not skill_markdown.strip():
|
||||
raise ValueError(
|
||||
"payload.skill_markdown is required for stable sync to local skill."
|
||||
)
|
||||
|
||||
mapping = self._load_map()
|
||||
local_skill_name = self._resolve_local_skill_name(skill_key_val, mapping)
|
||||
skill_dir = Path(self.skills_root) / local_skill_name
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
normalized_markdown = _ensure_skill_frontmatter(
|
||||
skill_markdown,
|
||||
skill_name=local_skill_name,
|
||||
skill_key=skill_key_val,
|
||||
)
|
||||
|
||||
skill_md_path = skill_dir / "SKILL.md"
|
||||
skill_md_path.write_text(normalized_markdown, encoding="utf-8")
|
||||
|
||||
items = mapping.setdefault("items", {})
|
||||
items[skill_key_val] = {
|
||||
"local_skill_name": local_skill_name,
|
||||
"latest_release_id": release_id_val,
|
||||
"latest_candidate_id": candidate_id,
|
||||
"latest_payload_ref": payload_ref,
|
||||
"updated_at": _now_iso(),
|
||||
}
|
||||
mapping["version"] = _MAP_VERSION
|
||||
self._save_map(mapping)
|
||||
|
||||
# Ensure local skill is visible to AstrBot skill manager.
|
||||
SkillManager().set_skill_active(local_skill_name, True)
|
||||
|
||||
# Best-effort synchronization to active sandboxes.
|
||||
await sync_skills_to_active_sandboxes()
|
||||
|
||||
return NeoSkillSyncResult(
|
||||
skill_key=skill_key_val,
|
||||
local_skill_name=local_skill_name,
|
||||
release_id=release_id_val,
|
||||
candidate_id=candidate_id,
|
||||
payload_ref=payload_ref,
|
||||
map_path=self.map_path,
|
||||
synced_at=_now_iso(),
|
||||
)
|
||||
|
||||
async def promote_with_optional_sync(
|
||||
self,
|
||||
client: Any,
|
||||
*,
|
||||
candidate_id: str,
|
||||
stage: str,
|
||||
sync_to_local: bool,
|
||||
) -> dict[str, Any]:
|
||||
release = await client.skills.promote_candidate(candidate_id, stage=stage)
|
||||
release_json = _to_jsonable(release)
|
||||
|
||||
sync_json: dict[str, Any] | None = None
|
||||
rollback_json: dict[str, Any] | None = None
|
||||
sync_error: str | None = None
|
||||
|
||||
if stage == "stable" and sync_to_local:
|
||||
try:
|
||||
sync_result = await self.sync_release(
|
||||
client,
|
||||
release_id=str(release_json.get("id", "")),
|
||||
require_stable=True,
|
||||
)
|
||||
sync_json = self.sync_result_to_dict(sync_result)
|
||||
except Exception as err:
|
||||
sync_error = str(err)
|
||||
try:
|
||||
rollback = await client.skills.rollback_release(
|
||||
str(release_json.get("id", ""))
|
||||
)
|
||||
rollback_json = _to_jsonable(rollback)
|
||||
except Exception as rollback_err:
|
||||
rollback_msg = str(rollback_err)
|
||||
if "no previous release exists" in rollback_msg.lower():
|
||||
rollback_json = {
|
||||
"skipped": True,
|
||||
"reason": rollback_msg,
|
||||
}
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"stable release synced failed and auto rollback also failed; "
|
||||
f"sync_error={sync_error}; rollback_error={rollback_err}"
|
||||
) from rollback_err
|
||||
|
||||
return {
|
||||
"release": release_json,
|
||||
"sync": sync_json,
|
||||
"rollback": rollback_json,
|
||||
"sync_error": sync_error,
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import shutil
|
||||
import tempfile
|
||||
import zipfile
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path, PurePosixPath
|
||||
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
@@ -17,11 +16,9 @@ from astrbot.core.utils.astrbot_path import (
|
||||
)
|
||||
|
||||
SKILLS_CONFIG_FILENAME = "skills.json"
|
||||
SANDBOX_SKILLS_CACHE_FILENAME = "sandbox_skills_cache.json"
|
||||
DEFAULT_SKILLS_CONFIG: dict[str, dict] = {"skills": {}}
|
||||
# SANDBOX_SKILLS_ROOT = "/home/shared/skills"
|
||||
SANDBOX_SKILLS_ROOT = "skills"
|
||||
SANDBOX_WORKSPACE_ROOT = "/workspace"
|
||||
_SANDBOX_SKILLS_CACHE_VERSION = 1
|
||||
|
||||
_SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$")
|
||||
|
||||
@@ -32,23 +29,9 @@ class SkillInfo:
|
||||
description: str
|
||||
path: str
|
||||
active: bool
|
||||
source_type: str = "local_only"
|
||||
source_label: str = "local"
|
||||
local_exists: bool = True
|
||||
sandbox_exists: bool = False
|
||||
|
||||
|
||||
def _parse_frontmatter_description(text: str) -> str:
|
||||
"""Extract the ``description`` value from YAML frontmatter.
|
||||
|
||||
Expects the standard SKILL.md format used by OpenAI Codex CLI and
|
||||
Anthropic Claude Skills::
|
||||
|
||||
---
|
||||
name: my-skill
|
||||
description: What this skill does and when to use it.
|
||||
---
|
||||
"""
|
||||
if not text.startswith("---"):
|
||||
return ""
|
||||
lines = text.splitlines()
|
||||
@@ -70,74 +53,45 @@ def _parse_frontmatter_description(text: str) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
# Regex for sanitizing paths used in prompt examples — only allow
|
||||
# safe path characters to prevent prompt injection via crafted skill paths.
|
||||
_SAFE_PATH_RE = re.compile(r"[^A-Za-z0-9_./ -]")
|
||||
|
||||
|
||||
def build_skills_prompt(skills: list[SkillInfo]) -> str:
|
||||
"""Build the skills section of the system prompt.
|
||||
|
||||
Generates a markdown-formatted skill inventory for the LLM. Only
|
||||
``name`` and ``description`` are shown upfront; the LLM must read
|
||||
the full ``SKILL.md`` before execution (progressive disclosure).
|
||||
"""
|
||||
skills_lines: list[str] = []
|
||||
example_path = ""
|
||||
skills_lines = []
|
||||
for skill in skills:
|
||||
description = skill.description or "No description"
|
||||
skills_lines.append(
|
||||
f"- **{skill.name}**: {description}\n File: `{skill.path}`"
|
||||
)
|
||||
if not example_path:
|
||||
example_path = skill.path
|
||||
skills_lines.append(f"- {skill.name}: {description} (file: {skill.path})")
|
||||
skills_block = "\n".join(skills_lines)
|
||||
# Sanitize example_path — it may originate from sandbox cache (untrusted)
|
||||
example_path = _SAFE_PATH_RE.sub("", example_path) if example_path else ""
|
||||
example_path = example_path or "<skills_root>/<skill_name>/SKILL.md"
|
||||
|
||||
# Based on openai/codex
|
||||
return (
|
||||
"## Skills\n\n"
|
||||
"You have specialized skills — reusable instruction bundles stored "
|
||||
"in `SKILL.md` files. Each skill has a **name** and a **description** "
|
||||
"that tells you what it does and when to use it.\n\n"
|
||||
"### Available skills\n\n"
|
||||
f"{skills_block}\n\n"
|
||||
"### Skill rules\n\n"
|
||||
"1. **Discovery** — The list above is the complete skill inventory "
|
||||
"for this session. Full instructions are in the referenced "
|
||||
"`SKILL.md` file.\n"
|
||||
"2. **When to trigger** — Use a skill if the user names it "
|
||||
"explicitly, or if the task clearly matches the skill's description. "
|
||||
"*Never silently skip a matching skill* — either use it or briefly "
|
||||
"explain why you chose not to.\n"
|
||||
"3. **Mandatory grounding** — Before executing any skill you MUST "
|
||||
"first read its `SKILL.md` by running a shell command with the "
|
||||
f"**absolute path** shown above (e.g. `cat {example_path}`). "
|
||||
"Never rely on memory or assumptions about a skill's content.\n"
|
||||
"4. **Progressive disclosure** — Load only what is directly "
|
||||
"referenced from `SKILL.md`:\n"
|
||||
" - If `scripts/` exist, prefer running or patching them over "
|
||||
"rewriting code from scratch.\n"
|
||||
" - If `assets/` or templates exist, reuse them.\n"
|
||||
" - Do NOT bulk-load every file in the skill directory.\n"
|
||||
"5. **Coordination** — When multiple skills apply, pick the minimal "
|
||||
"set needed. Announce which skill(s) you are using and why "
|
||||
"(one short line). Prefer `astrbot_*` tools when running skill "
|
||||
"scripts.\n"
|
||||
"6. **Context hygiene** — Avoid deep reference chasing; open only "
|
||||
"files that are directly linked from `SKILL.md`.\n"
|
||||
"7. **Failure handling** — If a skill cannot be applied, state the "
|
||||
"issue clearly and continue with the best alternative.\n"
|
||||
"## Skills\n"
|
||||
"You have many useful skills that can help you accomplish various tasks.\n"
|
||||
"A skill is a set of local instructions stored in a `SKILL.md` file.\n"
|
||||
"### Available skills\n"
|
||||
f"{skills_block}\n"
|
||||
"### Skill Rules\n"
|
||||
"\n"
|
||||
"- Discovery: The list above shows all skills available in this session. Full instructions live in the referenced `SKILL.md`.\n"
|
||||
"- Trigger rules: Use a skill if the user names it or the task matches its description. Do not carry skills across turns unless re-mentioned\n"
|
||||
"### How to use a skill (progressive disclosure):\n"
|
||||
" 0) Mandatory grounding: Before using any skill, you MUST inspect its `SKILL.md` using shell tools"
|
||||
" (e.g., `cat`, `head`, `sed`, `awk`, `grep`). Do not rely on assumptions or memory.\n"
|
||||
" 1) Load only directly referenced files, DO NOT bulk-load everything.\n"
|
||||
" 2) If `scripts/` exist, prefer running or patching them instead of retyping large blocks of code.\n"
|
||||
" 3) If `assets/` or templates exist, reuse them rather than recreating everything from scratch.\n"
|
||||
"- Coordination:\n"
|
||||
" - If multiple skills apply, choose the minimal set that covers the request and state the order in which you will use them.\n"
|
||||
" - Announce which skill(s) you are using and why (one short line). If you skip an obvious skill, explain why.\n"
|
||||
" - Prefer to use `astrbot_*` tools to perform skills that need to run scripts.\n"
|
||||
"- Context hygiene:\n"
|
||||
" - Avoid deep reference chasing: unless blocked, open only files that are directly linked from `SKILL.md`.\n"
|
||||
"- Failure handling: If a skill cannot be applied, state the issue and continue with the best alternative.\n"
|
||||
"### Example\n"
|
||||
"When you decided to use a skill, use shell tool to read its `SKILL.md`, e.g., `head -40 skills/code_formatter/SKILL.md`, and you can increase or decrease the number of lines as needed.\n"
|
||||
)
|
||||
|
||||
|
||||
class SkillManager:
|
||||
def __init__(self, skills_root: str | None = None) -> None:
|
||||
self.skills_root = skills_root or get_astrbot_skills_path()
|
||||
data_path = Path(get_astrbot_data_path())
|
||||
self.config_path = str(data_path / SKILLS_CONFIG_FILENAME)
|
||||
self.sandbox_skills_cache_path = str(data_path / SANDBOX_SKILLS_CACHE_FILENAME)
|
||||
self.config_path = os.path.join(get_astrbot_data_path(), SKILLS_CONFIG_FILENAME)
|
||||
os.makedirs(self.skills_root, exist_ok=True)
|
||||
|
||||
def _load_config(self) -> dict:
|
||||
@@ -154,66 +108,6 @@ class SkillManager:
|
||||
with open(self.config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=4)
|
||||
|
||||
def _load_sandbox_skills_cache(self) -> dict:
|
||||
if not os.path.exists(self.sandbox_skills_cache_path):
|
||||
return {"version": _SANDBOX_SKILLS_CACHE_VERSION, "skills": []}
|
||||
try:
|
||||
with open(self.sandbox_skills_cache_path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if not isinstance(data, dict):
|
||||
return {"version": _SANDBOX_SKILLS_CACHE_VERSION, "skills": []}
|
||||
skills = data.get("skills", [])
|
||||
if not isinstance(skills, list):
|
||||
skills = []
|
||||
return {
|
||||
"version": int(data.get("version", _SANDBOX_SKILLS_CACHE_VERSION)),
|
||||
"skills": skills,
|
||||
"updated_at": data.get("updated_at"),
|
||||
}
|
||||
except Exception:
|
||||
return {"version": _SANDBOX_SKILLS_CACHE_VERSION, "skills": []}
|
||||
|
||||
def _save_sandbox_skills_cache(self, cache: dict) -> None:
|
||||
cache["version"] = _SANDBOX_SKILLS_CACHE_VERSION
|
||||
cache["updated_at"] = datetime.now(timezone.utc).isoformat()
|
||||
with open(self.sandbox_skills_cache_path, "w", encoding="utf-8") as f:
|
||||
json.dump(cache, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def set_sandbox_skills_cache(self, skills: list[dict]) -> None:
|
||||
"""Persist sandbox skill metadata discovered from runtime side."""
|
||||
deduped: dict[str, dict[str, str]] = {}
|
||||
for item in skills:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
name = str(item.get("name", "")).strip()
|
||||
if not name or not _SKILL_NAME_RE.match(name):
|
||||
continue
|
||||
description = str(item.get("description", "") or "")
|
||||
path = str(item.get("path", "") or "")
|
||||
if not path:
|
||||
path = f"{SANDBOX_WORKSPACE_ROOT}/{SANDBOX_SKILLS_ROOT}/{name}/SKILL.md"
|
||||
deduped[name] = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"path": path.replace("\\", "/"),
|
||||
}
|
||||
cache = {
|
||||
"version": _SANDBOX_SKILLS_CACHE_VERSION,
|
||||
"skills": [deduped[name] for name in sorted(deduped)],
|
||||
}
|
||||
self._save_sandbox_skills_cache(cache)
|
||||
|
||||
def get_sandbox_skills_cache_status(self) -> dict[str, object]:
|
||||
cache = self._load_sandbox_skills_cache()
|
||||
skills = cache.get("skills", [])
|
||||
count = len(skills) if isinstance(skills, list) else 0
|
||||
return {
|
||||
"exists": os.path.exists(self.sandbox_skills_cache_path),
|
||||
"ready": count > 0,
|
||||
"count": count,
|
||||
"updated_at": cache.get("updated_at"),
|
||||
}
|
||||
|
||||
def list_skills(
|
||||
self,
|
||||
*,
|
||||
@@ -230,21 +124,7 @@ class SkillManager:
|
||||
config = self._load_config()
|
||||
skill_configs = config.get("skills", {})
|
||||
modified = False
|
||||
skills_by_name: dict[str, SkillInfo] = {}
|
||||
|
||||
sandbox_cached_paths: dict[str, str] = {}
|
||||
sandbox_cached_descriptions: dict[str, str] = {}
|
||||
cache_for_paths = self._load_sandbox_skills_cache()
|
||||
for item in cache_for_paths.get("skills", []):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
name = str(item.get("name", "") or "").strip()
|
||||
path = str(item.get("path", "") or "").strip().replace("\\", "/")
|
||||
if not name or not _SKILL_NAME_RE.match(name):
|
||||
continue
|
||||
sandbox_cached_descriptions[name] = str(item.get("description", "") or "")
|
||||
if path:
|
||||
sandbox_cached_paths[name] = path
|
||||
skills: list[SkillInfo] = []
|
||||
|
||||
for entry in sorted(Path(self.skills_root).iterdir()):
|
||||
if not entry.is_dir():
|
||||
@@ -265,129 +145,36 @@ class SkillManager:
|
||||
description = _parse_frontmatter_description(content)
|
||||
except Exception:
|
||||
description = ""
|
||||
sandbox_exists = (
|
||||
runtime == "sandbox" and skill_name in sandbox_cached_descriptions
|
||||
)
|
||||
source_type = "both" if sandbox_exists else "local_only"
|
||||
source_label = "synced" if sandbox_exists else "local"
|
||||
if runtime == "sandbox" and show_sandbox_path:
|
||||
path_str = sandbox_cached_paths.get(skill_name) or (
|
||||
f"{SANDBOX_WORKSPACE_ROOT}/{SANDBOX_SKILLS_ROOT}/{skill_name}/SKILL.md"
|
||||
)
|
||||
path_str = f"{SANDBOX_SKILLS_ROOT}/{skill_name}/SKILL.md"
|
||||
else:
|
||||
path_str = str(skill_md)
|
||||
path_str = path_str.replace("\\", "/")
|
||||
skills_by_name[skill_name] = SkillInfo(
|
||||
name=skill_name,
|
||||
description=description,
|
||||
path=path_str,
|
||||
active=active,
|
||||
source_type=source_type,
|
||||
source_label=source_label,
|
||||
local_exists=True,
|
||||
sandbox_exists=sandbox_exists,
|
||||
)
|
||||
|
||||
if runtime == "sandbox":
|
||||
cache = self._load_sandbox_skills_cache()
|
||||
for item in cache.get("skills", []):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
skill_name = str(item.get("name", "")).strip()
|
||||
if (
|
||||
not skill_name
|
||||
or skill_name in skills_by_name
|
||||
or not _SKILL_NAME_RE.match(skill_name)
|
||||
):
|
||||
continue
|
||||
active = skill_configs.get(skill_name, {}).get("active", True)
|
||||
if skill_name not in skill_configs:
|
||||
skill_configs[skill_name] = {"active": active}
|
||||
modified = True
|
||||
if active_only and not active:
|
||||
continue
|
||||
description = sandbox_cached_descriptions.get(skill_name, "")
|
||||
if show_sandbox_path:
|
||||
path_str = f"{SANDBOX_WORKSPACE_ROOT}/{SANDBOX_SKILLS_ROOT}/{skill_name}/SKILL.md"
|
||||
else:
|
||||
path_str = sandbox_cached_paths.get(skill_name, "")
|
||||
if not path_str:
|
||||
path_str = f"{SANDBOX_WORKSPACE_ROOT}/{SANDBOX_SKILLS_ROOT}/{skill_name}/SKILL.md"
|
||||
skills_by_name[skill_name] = SkillInfo(
|
||||
skills.append(
|
||||
SkillInfo(
|
||||
name=skill_name,
|
||||
description=description,
|
||||
path=path_str.replace("\\", "/"),
|
||||
path=path_str,
|
||||
active=active,
|
||||
source_type="sandbox_only",
|
||||
source_label="sandbox_preset",
|
||||
local_exists=False,
|
||||
sandbox_exists=True,
|
||||
)
|
||||
)
|
||||
|
||||
if modified:
|
||||
config["skills"] = skill_configs
|
||||
self._save_config(config)
|
||||
|
||||
return [skills_by_name[name] for name in sorted(skills_by_name)]
|
||||
|
||||
def is_sandbox_only_skill(self, name: str) -> bool:
|
||||
skill_dir = Path(self.skills_root) / name
|
||||
skill_md_exists = (skill_dir / "SKILL.md").exists()
|
||||
if skill_md_exists:
|
||||
return False
|
||||
cache = self._load_sandbox_skills_cache()
|
||||
skills = cache.get("skills", [])
|
||||
if not isinstance(skills, list):
|
||||
return False
|
||||
for item in skills:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if str(item.get("name", "")).strip() == name:
|
||||
return True
|
||||
return False
|
||||
return skills
|
||||
|
||||
def set_skill_active(self, name: str, active: bool) -> None:
|
||||
if self.is_sandbox_only_skill(name):
|
||||
raise PermissionError(
|
||||
"Sandbox preset skill cannot be enabled/disabled from local skill management."
|
||||
)
|
||||
config = self._load_config()
|
||||
config.setdefault("skills", {})
|
||||
config["skills"][name] = {"active": bool(active)}
|
||||
self._save_config(config)
|
||||
|
||||
def _remove_skill_from_sandbox_cache(self, name: str) -> None:
|
||||
cache = self._load_sandbox_skills_cache()
|
||||
skills = cache.get("skills", [])
|
||||
if not isinstance(skills, list):
|
||||
return
|
||||
|
||||
filtered = [
|
||||
item
|
||||
for item in skills
|
||||
if not (
|
||||
isinstance(item, dict) and str(item.get("name", "")).strip() == name
|
||||
)
|
||||
]
|
||||
|
||||
if len(filtered) != len(skills):
|
||||
cache["skills"] = filtered
|
||||
self._save_sandbox_skills_cache(cache)
|
||||
|
||||
def delete_skill(self, name: str) -> None:
|
||||
if self.is_sandbox_only_skill(name):
|
||||
raise PermissionError(
|
||||
"Sandbox preset skill cannot be deleted from local skill management."
|
||||
)
|
||||
|
||||
skill_dir = Path(self.skills_root) / name
|
||||
if skill_dir.exists():
|
||||
shutil.rmtree(skill_dir)
|
||||
|
||||
# Ensure UI consistency even when there is no active sandbox session
|
||||
# to refresh cache from runtime side.
|
||||
self._remove_skill_from_sandbox_cache(name)
|
||||
|
||||
config = self._load_config()
|
||||
if name in config.get("skills", {}):
|
||||
config["skills"].pop(name, None)
|
||||
@@ -409,7 +196,7 @@ class SkillManager:
|
||||
top_dirs = {
|
||||
PurePosixPath(name).parts[0] for name in file_names if name.strip()
|
||||
}
|
||||
|
||||
print(top_dirs)
|
||||
if len(top_dirs) != 1:
|
||||
raise ValueError("Zip archive must contain a single top-level folder.")
|
||||
skill_name = next(iter(top_dirs))
|
||||
|
||||
@@ -149,9 +149,7 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
file_url = None
|
||||
|
||||
if os.environ.get("ASTRBOT_CLI") or os.environ.get("ASTRBOT_LAUNCHER"):
|
||||
raise Exception(
|
||||
"Error: You are running AstrBot via CLI, please use `pip` or `uv tool upgrade` to update AstrBot."
|
||||
) # 避免版本管理混乱
|
||||
raise Exception("不支持更新此方式启动的AstrBot") # 避免版本管理混乱
|
||||
|
||||
if latest:
|
||||
latest_version = update_data[0]["tag_name"]
|
||||
|
||||
@@ -14,7 +14,7 @@ import certifi
|
||||
import psutil
|
||||
from PIL import Image
|
||||
|
||||
from .astrbot_path import get_astrbot_data_path, get_astrbot_path, get_astrbot_temp_path
|
||||
from .astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
@@ -219,13 +219,7 @@ def get_local_ip_addresses():
|
||||
|
||||
|
||||
async def get_dashboard_version():
|
||||
# First check user data directory (manually updated / downloaded dashboard).
|
||||
dist_dir = os.path.join(get_astrbot_data_path(), "dist")
|
||||
if not os.path.exists(dist_dir):
|
||||
# Fall back to the dist bundled inside the installed wheel.
|
||||
_bundled = Path(get_astrbot_path()) / "astrbot" / "dashboard" / "dist"
|
||||
if _bundled.exists():
|
||||
dist_dir = str(_bundled)
|
||||
if os.path.exists(dist_dir):
|
||||
version_file = os.path.join(dist_dir, "assets", "version")
|
||||
if os.path.exists(version_file):
|
||||
|
||||
@@ -206,110 +206,12 @@ def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict]
|
||||
return errors, data
|
||||
|
||||
|
||||
def _log_computer_config_changes(old_config: dict, new_config: dict) -> None:
|
||||
"""Compare and log Computer/sandbox configuration changes."""
|
||||
old_ps = old_config.get("provider_settings", {})
|
||||
new_ps = new_config.get("provider_settings", {})
|
||||
|
||||
# Check computer_use_runtime
|
||||
old_runtime = old_ps.get("computer_use_runtime", "none")
|
||||
new_runtime = new_ps.get("computer_use_runtime", "none")
|
||||
if old_runtime != new_runtime:
|
||||
logger.info(
|
||||
"[Computer] Config changed: computer_use_runtime %s -> %s",
|
||||
old_runtime,
|
||||
new_runtime,
|
||||
)
|
||||
|
||||
# Check sandbox sub-keys
|
||||
old_sandbox = old_ps.get("sandbox", {})
|
||||
new_sandbox = new_ps.get("sandbox", {})
|
||||
all_keys = set(old_sandbox.keys()) | set(new_sandbox.keys())
|
||||
for key in sorted(all_keys):
|
||||
old_val = old_sandbox.get(key)
|
||||
new_val = new_sandbox.get(key)
|
||||
if old_val != new_val:
|
||||
# Mask tokens/secrets in log output
|
||||
if "token" in key or "secret" in key:
|
||||
old_display = "***" if old_val else "(empty)"
|
||||
new_display = "***" if new_val else "(empty)"
|
||||
else:
|
||||
old_display = old_val
|
||||
new_display = new_val
|
||||
logger.info(
|
||||
"[Computer] Config changed: sandbox.%s %s -> %s",
|
||||
key,
|
||||
old_display,
|
||||
new_display,
|
||||
)
|
||||
|
||||
|
||||
async def _validate_neo_connectivity(
|
||||
post_config: dict,
|
||||
) -> str | None:
|
||||
"""Check if Bay is reachable when Shipyard Neo sandbox is configured.
|
||||
|
||||
Returns a warning message string if Bay isn't reachable, or None if
|
||||
everything looks fine (or Neo isn't configured).
|
||||
"""
|
||||
ps = post_config.get("provider_settings", {})
|
||||
runtime = ps.get("computer_use_runtime", "none")
|
||||
sandbox = ps.get("sandbox", {})
|
||||
booter = sandbox.get("booter", "")
|
||||
|
||||
# Only check when sandbox mode + shipyard_neo is selected
|
||||
if runtime != "sandbox" or booter != "shipyard_neo":
|
||||
return None
|
||||
|
||||
endpoint = sandbox.get("shipyard_neo_endpoint", "").rstrip("/")
|
||||
if not endpoint:
|
||||
return "⚠️ Shipyard Neo endpoint 未设置"
|
||||
|
||||
access_token = sandbox.get("shipyard_neo_access_token", "")
|
||||
if not access_token:
|
||||
# Try auto-discovery
|
||||
from astrbot.core.computer.computer_client import _discover_bay_credentials
|
||||
|
||||
access_token = _discover_bay_credentials(endpoint)
|
||||
|
||||
if not access_token:
|
||||
return (
|
||||
"⚠️ 未找到 Bay API Key。请填写访问令牌,"
|
||||
"或确保 Bay 的 credentials.json 可被自动发现。"
|
||||
)
|
||||
|
||||
# Connectivity check
|
||||
import aiohttp
|
||||
|
||||
health_url = f"{endpoint}/health"
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
health_url,
|
||||
timeout=aiohttp.ClientTimeout(total=5),
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
return (
|
||||
f"⚠️ Bay 健康检查失败 (HTTP {resp.status}),"
|
||||
f"请确认 Bay 正在运行: {endpoint}"
|
||||
)
|
||||
except Exception:
|
||||
return f"⚠️ 无法连接 Bay ({endpoint}),请确认 Bay 已启动。"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def save_config(
|
||||
post_config: dict, config: AstrBotConfig, is_core: bool = False
|
||||
) -> None:
|
||||
"""验证并保存配置"""
|
||||
errors = None
|
||||
logger.info(f"Saving config, is_core={is_core}")
|
||||
|
||||
# Snapshot old Computer config for change detection
|
||||
if is_core:
|
||||
_log_computer_config_changes(dict(config), post_config)
|
||||
|
||||
try:
|
||||
if is_core:
|
||||
errors, post_config = validate_config(
|
||||
@@ -1026,11 +928,6 @@ class ConfigRoute(Route):
|
||||
|
||||
await self._save_astrbot_configs(config, conf_id)
|
||||
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)
|
||||
|
||||
# Non-blocking Bay connectivity check
|
||||
warning = await _validate_neo_connectivity(config)
|
||||
if warning:
|
||||
return Response().ok(None, f"保存成功。{warning}").__dict__
|
||||
return Response().ok(None, "保存成功~").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
@@ -1,48 +1,15 @@
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import traceback
|
||||
from collections.abc import Awaitable, Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from quart import request, send_file
|
||||
from quart import request
|
||||
|
||||
from astrbot.core import DEMO_MODE, logger
|
||||
from astrbot.core.computer.computer_client import (
|
||||
_discover_bay_credentials,
|
||||
sync_skills_to_active_sandboxes,
|
||||
)
|
||||
from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager
|
||||
from astrbot.core.skills.skill_manager import SkillManager
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
def _to_jsonable(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {k: _to_jsonable(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_to_jsonable(v) for v in value]
|
||||
if hasattr(value, "model_dump"):
|
||||
return _to_jsonable(value.model_dump())
|
||||
return value
|
||||
|
||||
|
||||
def _to_bool(value: Any, default: bool = False) -> bool:
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in {"1", "true", "yes", "y", "on"}
|
||||
return bool(value)
|
||||
|
||||
|
||||
_SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$")
|
||||
|
||||
|
||||
class SkillsRoute(Route):
|
||||
def __init__(self, context: RouteContext, core_lifecycle) -> None:
|
||||
super().__init__(context)
|
||||
@@ -50,81 +17,18 @@ class SkillsRoute(Route):
|
||||
self.routes = {
|
||||
"/skills": ("GET", self.get_skills),
|
||||
"/skills/upload": ("POST", self.upload_skill),
|
||||
"/skills/download": ("GET", self.download_skill),
|
||||
"/skills/update": ("POST", self.update_skill),
|
||||
"/skills/delete": ("POST", self.delete_skill),
|
||||
"/skills/neo/candidates": ("GET", self.get_neo_candidates),
|
||||
"/skills/neo/releases": ("GET", self.get_neo_releases),
|
||||
"/skills/neo/payload": ("GET", self.get_neo_payload),
|
||||
"/skills/neo/evaluate": ("POST", self.evaluate_neo_candidate),
|
||||
"/skills/neo/promote": ("POST", self.promote_neo_candidate),
|
||||
"/skills/neo/rollback": ("POST", self.rollback_neo_release),
|
||||
"/skills/neo/sync": ("POST", self.sync_neo_release),
|
||||
"/skills/neo/delete-candidate": ("POST", self.delete_neo_candidate),
|
||||
"/skills/neo/delete-release": ("POST", self.delete_neo_release),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
def _get_neo_client_config(self) -> tuple[str, str]:
|
||||
provider_settings = self.core_lifecycle.astrbot_config.get(
|
||||
"provider_settings",
|
||||
{},
|
||||
)
|
||||
sandbox = provider_settings.get("sandbox", {})
|
||||
endpoint = sandbox.get("shipyard_neo_endpoint", "")
|
||||
access_token = sandbox.get("shipyard_neo_access_token", "")
|
||||
|
||||
# Auto-discover token from Bay's credentials.json if not configured
|
||||
if not access_token and endpoint:
|
||||
access_token = _discover_bay_credentials(endpoint)
|
||||
|
||||
if not endpoint or not access_token:
|
||||
raise ValueError(
|
||||
"Shipyard Neo endpoint or access token not configured. "
|
||||
"Set them in Dashboard or ensure Bay's credentials.json is accessible."
|
||||
)
|
||||
return endpoint, access_token
|
||||
|
||||
async def _delete_neo_release(
|
||||
self, client: Any, release_id: str, reason: str | None
|
||||
):
|
||||
return await client.skills.delete_release(release_id, reason=reason)
|
||||
|
||||
async def _delete_neo_candidate(
|
||||
self, client: Any, candidate_id: str, reason: str | None
|
||||
):
|
||||
return await client.skills.delete_candidate(candidate_id, reason=reason)
|
||||
|
||||
async def _with_neo_client(
|
||||
self,
|
||||
operation: Callable[[Any], Awaitable[dict]],
|
||||
) -> dict:
|
||||
try:
|
||||
endpoint, access_token = self._get_neo_client_config()
|
||||
|
||||
from shipyard_neo import BayClient
|
||||
|
||||
async with BayClient(
|
||||
endpoint_url=endpoint,
|
||||
access_token=access_token,
|
||||
) as client:
|
||||
return await operation(client)
|
||||
except ValueError as e:
|
||||
# Config not ready — expected when Neo isn't set up yet
|
||||
logger.debug("[Neo] %s", e)
|
||||
return Response().error(str(e)).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def get_skills(self):
|
||||
try:
|
||||
provider_settings = self.core_lifecycle.astrbot_config.get(
|
||||
"provider_settings", {}
|
||||
)
|
||||
runtime = provider_settings.get("computer_use_runtime", "local")
|
||||
skill_mgr = SkillManager()
|
||||
skills = skill_mgr.list_skills(
|
||||
skills = SkillManager().list_skills(
|
||||
active_only=False, runtime=runtime, show_sandbox_path=False
|
||||
)
|
||||
return (
|
||||
@@ -132,8 +36,6 @@ class SkillsRoute(Route):
|
||||
.ok(
|
||||
{
|
||||
"skills": [skill.__dict__ for skill in skills],
|
||||
"runtime": runtime,
|
||||
"sandbox_cache": skill_mgr.get_sandbox_skills_cache_status(),
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
@@ -168,11 +70,6 @@ class SkillsRoute(Route):
|
||||
skill_mgr = SkillManager()
|
||||
skill_name = skill_mgr.install_skill_from_zip(temp_path, overwrite=True)
|
||||
|
||||
try:
|
||||
await sync_skills_to_active_sandboxes()
|
||||
except Exception:
|
||||
logger.warning("Failed to sync uploaded skills to active sandboxes.")
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok({"name": skill_name}, "Skill uploaded successfully.")
|
||||
@@ -188,53 +85,6 @@ class SkillsRoute(Route):
|
||||
except Exception:
|
||||
logger.warning(f"Failed to remove temp skill file: {temp_path}")
|
||||
|
||||
async def download_skill(self):
|
||||
try:
|
||||
name = str(request.args.get("name") or "").strip()
|
||||
if not name:
|
||||
return Response().error("Missing skill name").__dict__
|
||||
if not _SKILL_NAME_RE.match(name):
|
||||
return Response().error("Invalid skill name").__dict__
|
||||
|
||||
skill_mgr = SkillManager()
|
||||
if skill_mgr.is_sandbox_only_skill(name):
|
||||
return (
|
||||
Response()
|
||||
.error(
|
||||
"Sandbox preset skill cannot be downloaded from local skill files."
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
skill_dir = Path(skill_mgr.skills_root) / name
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
if not skill_dir.is_dir() or not skill_md.exists():
|
||||
return Response().error("Local skill not found").__dict__
|
||||
|
||||
export_dir = Path(get_astrbot_temp_path()) / "skill_exports"
|
||||
export_dir.mkdir(parents=True, exist_ok=True)
|
||||
zip_base = export_dir / name
|
||||
zip_path = zip_base.with_suffix(".zip")
|
||||
if zip_path.exists():
|
||||
zip_path.unlink()
|
||||
|
||||
shutil.make_archive(
|
||||
str(zip_base),
|
||||
"zip",
|
||||
root_dir=str(skill_mgr.skills_root),
|
||||
base_dir=name,
|
||||
)
|
||||
|
||||
return await send_file(
|
||||
str(zip_path),
|
||||
as_attachment=True,
|
||||
attachment_filename=f"{name}.zip",
|
||||
conditional=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def update_skill(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
@@ -267,262 +117,7 @@ class SkillsRoute(Route):
|
||||
if not name:
|
||||
return Response().error("Missing skill name").__dict__
|
||||
SkillManager().delete_skill(name)
|
||||
try:
|
||||
await sync_skills_to_active_sandboxes()
|
||||
except Exception:
|
||||
logger.warning("Failed to sync deleted skills to active sandboxes.")
|
||||
return Response().ok({"name": name}).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def get_neo_candidates(self):
|
||||
logger.info("[Neo] GET /skills/neo/candidates requested.")
|
||||
status = request.args.get("status")
|
||||
skill_key = request.args.get("skill_key")
|
||||
limit = int(request.args.get("limit", 100))
|
||||
offset = int(request.args.get("offset", 0))
|
||||
|
||||
async def _do(client):
|
||||
candidates = await client.skills.list_candidates(
|
||||
status=status,
|
||||
skill_key=skill_key,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
result = _to_jsonable(candidates)
|
||||
total = result.get("total", "?") if isinstance(result, dict) else "?"
|
||||
logger.info(f"[Neo] Candidates fetched: total={total}")
|
||||
return Response().ok(result).__dict__
|
||||
|
||||
return await self._with_neo_client(_do)
|
||||
|
||||
async def get_neo_releases(self):
|
||||
logger.info("[Neo] GET /skills/neo/releases requested.")
|
||||
skill_key = request.args.get("skill_key")
|
||||
stage = request.args.get("stage")
|
||||
active_only = _to_bool(request.args.get("active_only"), False)
|
||||
limit = int(request.args.get("limit", 100))
|
||||
offset = int(request.args.get("offset", 0))
|
||||
|
||||
async def _do(client):
|
||||
releases = await client.skills.list_releases(
|
||||
skill_key=skill_key,
|
||||
active_only=active_only,
|
||||
stage=stage,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
result = _to_jsonable(releases)
|
||||
total = result.get("total", "?") if isinstance(result, dict) else "?"
|
||||
logger.info(f"[Neo] Releases fetched: total={total}")
|
||||
return Response().ok(result).__dict__
|
||||
|
||||
return await self._with_neo_client(_do)
|
||||
|
||||
async def get_neo_payload(self):
|
||||
logger.info("[Neo] GET /skills/neo/payload requested.")
|
||||
payload_ref = request.args.get("payload_ref", "")
|
||||
if not payload_ref:
|
||||
return Response().error("Missing payload_ref").__dict__
|
||||
|
||||
async def _do(client):
|
||||
payload = await client.skills.get_payload(payload_ref)
|
||||
logger.info(f"[Neo] Payload fetched: ref={payload_ref}")
|
||||
return Response().ok(_to_jsonable(payload)).__dict__
|
||||
|
||||
return await self._with_neo_client(_do)
|
||||
|
||||
async def evaluate_neo_candidate(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
logger.info("[Neo] POST /skills/neo/evaluate requested.")
|
||||
data = await request.get_json()
|
||||
candidate_id = data.get("candidate_id")
|
||||
passed_value = data.get("passed")
|
||||
if not candidate_id or passed_value is None:
|
||||
return Response().error("Missing candidate_id or passed").__dict__
|
||||
passed = _to_bool(passed_value, False)
|
||||
|
||||
async def _do(client):
|
||||
result = await client.skills.evaluate_candidate(
|
||||
candidate_id,
|
||||
passed=passed,
|
||||
score=data.get("score"),
|
||||
benchmark_id=data.get("benchmark_id"),
|
||||
report=data.get("report"),
|
||||
)
|
||||
logger.info(
|
||||
f"[Neo] Candidate evaluated: id={candidate_id}, passed={passed}"
|
||||
)
|
||||
return Response().ok(_to_jsonable(result)).__dict__
|
||||
|
||||
return await self._with_neo_client(_do)
|
||||
|
||||
async def promote_neo_candidate(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
logger.info("[Neo] POST /skills/neo/promote requested.")
|
||||
data = await request.get_json()
|
||||
candidate_id = data.get("candidate_id")
|
||||
stage = data.get("stage", "canary")
|
||||
sync_to_local = _to_bool(data.get("sync_to_local"), True)
|
||||
if not candidate_id:
|
||||
return Response().error("Missing candidate_id").__dict__
|
||||
if stage not in {"canary", "stable"}:
|
||||
return Response().error("Invalid stage, must be canary/stable").__dict__
|
||||
|
||||
async def _do(client):
|
||||
sync_mgr = NeoSkillSyncManager()
|
||||
result = await sync_mgr.promote_with_optional_sync(
|
||||
client,
|
||||
candidate_id=candidate_id,
|
||||
stage=stage,
|
||||
sync_to_local=sync_to_local,
|
||||
)
|
||||
release_json = result.get("release")
|
||||
logger.info(f"[Neo] Candidate promoted: id={candidate_id}, stage={stage}")
|
||||
|
||||
sync_json = result.get("sync")
|
||||
did_sync_to_local = bool(sync_json)
|
||||
if did_sync_to_local:
|
||||
logger.info(
|
||||
f"[Neo] Stable release synced to local: skill={sync_json.get('local_skill_name', '')}"
|
||||
)
|
||||
|
||||
if result.get("sync_error"):
|
||||
resp = Response().error(
|
||||
"Stable promote synced failed and has been rolled back. "
|
||||
f"sync_error={result['sync_error']}"
|
||||
)
|
||||
resp.data = {
|
||||
"release": release_json,
|
||||
"rollback": result.get("rollback"),
|
||||
}
|
||||
return resp.__dict__
|
||||
|
||||
# Try to push latest local skills to all active sandboxes.
|
||||
if not did_sync_to_local:
|
||||
try:
|
||||
await sync_skills_to_active_sandboxes()
|
||||
except Exception:
|
||||
logger.warning("Failed to sync skills to active sandboxes.")
|
||||
|
||||
return Response().ok({"release": release_json, "sync": sync_json}).__dict__
|
||||
|
||||
return await self._with_neo_client(_do)
|
||||
|
||||
async def rollback_neo_release(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
logger.info("[Neo] POST /skills/neo/rollback requested.")
|
||||
data = await request.get_json()
|
||||
release_id = data.get("release_id")
|
||||
if not release_id:
|
||||
return Response().error("Missing release_id").__dict__
|
||||
|
||||
async def _do(client):
|
||||
result = await client.skills.rollback_release(release_id)
|
||||
logger.info(f"[Neo] Release rolled back: id={release_id}")
|
||||
return Response().ok(_to_jsonable(result)).__dict__
|
||||
|
||||
return await self._with_neo_client(_do)
|
||||
|
||||
async def sync_neo_release(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
logger.info("[Neo] POST /skills/neo/sync requested.")
|
||||
data = await request.get_json()
|
||||
release_id = data.get("release_id")
|
||||
skill_key = data.get("skill_key")
|
||||
require_stable = _to_bool(data.get("require_stable"), True)
|
||||
if not release_id and not skill_key:
|
||||
return Response().error("Missing release_id or skill_key").__dict__
|
||||
|
||||
async def _do(client):
|
||||
sync_mgr = NeoSkillSyncManager()
|
||||
result = await sync_mgr.sync_release(
|
||||
client,
|
||||
release_id=release_id,
|
||||
skill_key=skill_key,
|
||||
require_stable=require_stable,
|
||||
)
|
||||
logger.info(
|
||||
f"[Neo] Release synced to local: skill={result.local_skill_name}, "
|
||||
f"release_id={result.release_id}"
|
||||
)
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"skill_key": result.skill_key,
|
||||
"local_skill_name": result.local_skill_name,
|
||||
"release_id": result.release_id,
|
||||
"candidate_id": result.candidate_id,
|
||||
"payload_ref": result.payload_ref,
|
||||
"map_path": result.map_path,
|
||||
"synced_at": result.synced_at,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
return await self._with_neo_client(_do)
|
||||
|
||||
async def delete_neo_candidate(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
logger.info("[Neo] POST /skills/neo/delete-candidate requested.")
|
||||
data = await request.get_json()
|
||||
candidate_id = data.get("candidate_id")
|
||||
reason = data.get("reason")
|
||||
if not candidate_id:
|
||||
return Response().error("Missing candidate_id").__dict__
|
||||
|
||||
async def _do(client):
|
||||
result = await self._delete_neo_candidate(client, candidate_id, reason)
|
||||
logger.info(f"[Neo] Candidate deleted: id={candidate_id}")
|
||||
return Response().ok(_to_jsonable(result)).__dict__
|
||||
|
||||
return await self._with_neo_client(_do)
|
||||
|
||||
async def delete_neo_release(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
logger.info("[Neo] POST /skills/neo/delete-release requested.")
|
||||
data = await request.get_json()
|
||||
release_id = data.get("release_id")
|
||||
reason = data.get("reason")
|
||||
if not release_id:
|
||||
return Response().error("Missing release_id").__dict__
|
||||
|
||||
async def _do(client):
|
||||
result = await self._delete_neo_release(client, release_id, reason)
|
||||
logger.info(f"[Neo] Release deleted: id={release_id}")
|
||||
return Response().ok(_to_jsonable(result)).__dict__
|
||||
|
||||
return await self._with_neo_client(_do)
|
||||
|
||||
@@ -51,9 +51,11 @@ class ToolsRoute(Route):
|
||||
server_info[key] = value
|
||||
|
||||
# 如果MCP客户端已初始化,从客户端获取工具名称
|
||||
for name_key, runtime in self.tool_mgr.mcp_server_runtime_view.items():
|
||||
for (
|
||||
name_key,
|
||||
mcp_client,
|
||||
) in self.tool_mgr.mcp_client_dict.items():
|
||||
if name_key == name:
|
||||
mcp_client = runtime.client
|
||||
server_info["tools"] = [tool.name for tool in mcp_client.tools]
|
||||
server_info["errlogs"] = mcp_client.server_errlogs
|
||||
break
|
||||
@@ -190,7 +192,7 @@ class ToolsRoute(Route):
|
||||
# 处理MCP客户端状态变化
|
||||
if active:
|
||||
if (
|
||||
old_name in self.tool_mgr.mcp_server_runtime_view
|
||||
old_name in self.tool_mgr.mcp_client_dict
|
||||
or not only_update_active
|
||||
or is_rename
|
||||
):
|
||||
@@ -231,7 +233,7 @@ class ToolsRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
# 如果要停用服务器
|
||||
elif old_name in self.tool_mgr.mcp_server_runtime_view:
|
||||
elif old_name in self.tool_mgr.mcp_client_dict:
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(old_name, timeout=10)
|
||||
except TimeoutError:
|
||||
@@ -270,7 +272,7 @@ class ToolsRoute(Route):
|
||||
del config["mcpServers"][name]
|
||||
|
||||
if self.tool_mgr.save_mcp_config(config):
|
||||
if name in self.tool_mgr.mcp_server_runtime_view:
|
||||
if name in self.tool_mgr.mcp_client_dict:
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
||||
except TimeoutError:
|
||||
|
||||
@@ -33,9 +33,6 @@ from .routes.session_management import SessionManagementRoute
|
||||
from .routes.subagent import SubAgentRoute
|
||||
from .routes.t2i import T2iRoute
|
||||
|
||||
# Static assets shipped inside the wheel (built during `hatch build`).
|
||||
_BUNDLED_DIST = Path(__file__).parent / "dist"
|
||||
|
||||
|
||||
class _AddrWithPort(Protocol):
|
||||
port: int
|
||||
@@ -69,22 +66,13 @@ class AstrBotDashboard:
|
||||
self.config = core_lifecycle.astrbot_config
|
||||
self.db = db
|
||||
|
||||
# Path priority:
|
||||
# 1. Explicit webui_dir argument
|
||||
# 2. data/dist/ (user-installed / manually updated dashboard)
|
||||
# 3. astrbot/dashboard/dist/ (bundled with the wheel)
|
||||
# 参数指定webui目录
|
||||
if webui_dir and os.path.exists(webui_dir):
|
||||
self.data_path = os.path.abspath(webui_dir)
|
||||
else:
|
||||
user_dist = os.path.join(get_astrbot_data_path(), "dist")
|
||||
if os.path.exists(user_dist):
|
||||
self.data_path = os.path.abspath(user_dist)
|
||||
elif _BUNDLED_DIST.exists():
|
||||
self.data_path = str(_BUNDLED_DIST)
|
||||
logger.info("Using bundled dashboard dist: %s", self.data_path)
|
||||
else:
|
||||
# Fall back to expected user path (will fail gracefully later)
|
||||
self.data_path = os.path.abspath(user_dist)
|
||||
self.data_path = os.path.abspath(
|
||||
os.path.join(get_astrbot_data_path(), "dist"),
|
||||
)
|
||||
|
||||
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
|
||||
APP = self.app # noqa
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<link rel="icon" href="/favicon.svg" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<meta name="keywords" content="AstrBot Soulter" />
|
||||
<meta name="description" content="AstrBot Dashboard" />
|
||||
<meta name="robots" content="noindex, nofollow" />
|
||||
|
||||
@@ -65,6 +65,7 @@
|
||||
"sass-loader": "13.3.2",
|
||||
"typescript": "5.1.6",
|
||||
"vite": "4.4.9",
|
||||
"vite-plugin-monaco-editor": "1.1.0",
|
||||
"vue-cli-plugin-vuetify": "2.5.8",
|
||||
"vue-tsc": "1.8.8",
|
||||
"vuetify-loader": "^2.0.0-alpha.9"
|
||||
|
||||
Generated
+12
@@ -159,6 +159,9 @@ importers:
|
||||
vite:
|
||||
specifier: 4.4.9
|
||||
version: 4.4.9(@types/node@20.19.32)(sass@1.66.1)(terser@5.46.0)
|
||||
vite-plugin-monaco-editor:
|
||||
specifier: 1.1.0
|
||||
version: 1.1.0(monaco-editor@0.52.2)
|
||||
vue-cli-plugin-vuetify:
|
||||
specifier: 2.5.8
|
||||
version: 2.5.8(sass-loader@13.3.2(sass@1.66.1)(webpack@5.105.0))(vue@3.3.4)(vuetify-loader@2.0.0-alpha.9(@vue/compiler-sfc@3.3.4)(vue@3.3.4)(vuetify@3.7.11)(webpack@5.105.0))(webpack@5.105.0)
|
||||
@@ -2568,6 +2571,11 @@ packages:
|
||||
vfile@6.0.3:
|
||||
resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==}
|
||||
|
||||
vite-plugin-monaco-editor@1.1.0:
|
||||
resolution: {integrity: sha512-IvtUqZotrRoVqwT0PBBDIZPNraya3BxN/bfcNfnxZ5rkJiGcNtO5eAOWWSgT7zullIAEqQwxMU83yL9J5k7gww==}
|
||||
peerDependencies:
|
||||
monaco-editor: '>=0.33.0'
|
||||
|
||||
vite-plugin-vuetify@1.0.2:
|
||||
resolution: {integrity: sha512-MubIcKD33O8wtgQXlbEXE7ccTEpHZ8nPpe77y9Wy3my2MWw/PgehP9VqTp92BLqr0R1dSL970Lynvisx3UxBFw==}
|
||||
engines: {node: '>=12'}
|
||||
@@ -5297,6 +5305,10 @@ snapshots:
|
||||
'@types/unist': 3.0.3
|
||||
vfile-message: 4.0.3
|
||||
|
||||
vite-plugin-monaco-editor@1.1.0(monaco-editor@0.52.2):
|
||||
dependencies:
|
||||
monaco-editor: 0.52.2
|
||||
|
||||
vite-plugin-vuetify@1.0.2(vite@4.4.9(@types/node@20.19.32)(sass@1.66.1)(terser@5.46.0))(vue@3.3.4)(vuetify@3.7.11):
|
||||
dependencies:
|
||||
'@vuetify/loader-shared': 1.7.1(vue@3.3.4)(vuetify@3.7.11)
|
||||
|
||||
@@ -37,7 +37,14 @@
|
||||
|
||||
<!-- 正常聊天界面 -->
|
||||
<template v-else>
|
||||
<div class="conversation-header fade-in" v-if="isMobile">
|
||||
<!-- 手机端菜单按钮 -->
|
||||
<v-btn icon class="mobile-menu-btn" @click="toggleMobileSidebar" variant="text">
|
||||
<v-icon>mdi-menu</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- 面包屑导航 -->
|
||||
<div v-if="currentSessionProject && messages && messages.length > 0" class="breadcrumb-container">
|
||||
<div class="breadcrumb-content">
|
||||
<span class="breadcrumb-emoji">{{ currentSessionProject.emoji || '📁' }}</span>
|
||||
@@ -234,7 +241,6 @@ const route = useRoute();
|
||||
const { t } = useI18n();
|
||||
const { tm } = useModuleI18n('features/chat');
|
||||
const theme = useTheme();
|
||||
const customizer = useCustomizerStore();
|
||||
|
||||
// UI 状态
|
||||
const isMobile = ref(false);
|
||||
@@ -336,28 +342,19 @@ function checkMobile() {
|
||||
isMobile.value = window.innerWidth <= 768;
|
||||
if (!isMobile.value) {
|
||||
mobileMenuOpen.value = false;
|
||||
customizer.SET_CHAT_SIDEBAR(false);
|
||||
}
|
||||
}
|
||||
|
||||
function toggleMobileSidebar() {
|
||||
mobileMenuOpen.value = !mobileMenuOpen.value;
|
||||
customizer.SET_CHAT_SIDEBAR(mobileMenuOpen.value);
|
||||
}
|
||||
|
||||
function closeMobileSidebar() {
|
||||
mobileMenuOpen.value = false;
|
||||
customizer.SET_CHAT_SIDEBAR(false);
|
||||
}
|
||||
|
||||
// 同步 nav header 中的 sidebar toggle
|
||||
watch(() => customizer.chatSidebarOpen, (val) => {
|
||||
if (isMobile.value) {
|
||||
mobileMenuOpen.value = val;
|
||||
}
|
||||
});
|
||||
|
||||
function toggleTheme() {
|
||||
const customizer = useCustomizerStore();
|
||||
const newTheme = customizer.uiTheme === 'PurpleTheme' ? 'PurpleThemeDark' : 'PurpleTheme';
|
||||
customizer.SET_UI_THEME(newTheme);
|
||||
theme.global.name.value = newTheme;
|
||||
@@ -725,7 +722,6 @@ onBeforeUnmount(() => {
|
||||
height: 100%;
|
||||
max-height: 100%;
|
||||
overflow: hidden;
|
||||
overscroll-behavior: none;
|
||||
}
|
||||
|
||||
.chat-page-container {
|
||||
|
||||
@@ -32,12 +32,11 @@
|
||||
</div>
|
||||
</transition>
|
||||
<textarea ref="inputField" v-model="localPrompt" @keydown="handleKeyDown" :disabled="disabled"
|
||||
placeholder="Ask AstrBot..." class="chat-textarea"
|
||||
autocomplete="off" autocorrect="off" autocapitalize="sentences" spellcheck="false"
|
||||
style="width: 100%; resize: none; outline: none; border: 1px solid var(--v-theme-border); border-radius: 12px; padding: 16px 20px; min-height: 40px; max-height: 200px; overflow-y: auto; font-family: inherit; font-size: 16px; background-color: var(--v-theme-surface);"></textarea>
|
||||
placeholder="Ask AstrBot..."
|
||||
style="width: 100%; resize: none; outline: none; border: 1px solid var(--v-theme-border); border-radius: 12px; padding: 12px 16px; min-height: 40px; font-family: inherit; font-size: 16px; background-color: var(--v-theme-surface);"></textarea>
|
||||
<div style="display: flex; justify-content: space-between; align-items: center; padding: 6px 14px;">
|
||||
<div
|
||||
style="display: flex; justify-content: flex-start; margin-top: 4px; align-items: center; gap: 8px; min-width: 0; flex: 1; overflow: hidden;">
|
||||
style="display: flex; justify-content: flex-start; margin-top: 4px; align-items: center; gap: 8px;">
|
||||
<!-- Settings Menu -->
|
||||
<StyledMenu offset="8" location="top start" :close-on-content-click="false">
|
||||
<template v-slot:activator="{ props: activatorProps }">
|
||||
@@ -73,9 +72,9 @@
|
||||
<!-- Provider/Model Selector Menu -->
|
||||
<ProviderModelMenu v-if="showProviderSelector" ref="providerModelMenuRef" />
|
||||
</div>
|
||||
<div style="display: flex; justify-content: flex-end; margin-top: 8px; align-items: center; flex-shrink: 0;">
|
||||
<div style="display: flex; justify-content: flex-end; margin-top: 8px; align-items: center;">
|
||||
<input type="file" ref="imageInputRef" @change="handleFileSelect" style="display: none" multiple />
|
||||
<v-progress-circular v-if="disabled && !mobile" indeterminate size="16" class="mr-1" width="1.5" />
|
||||
<v-progress-circular v-if="disabled" indeterminate size="16" class="mr-1" width="1.5" />
|
||||
<!-- <v-btn @click="$emit('openLiveMode')"
|
||||
icon
|
||||
variant="text"
|
||||
@@ -88,21 +87,36 @@
|
||||
</v-tooltip>
|
||||
</v-btn> -->
|
||||
<v-btn @click="handleRecordClick" icon variant="text" :color="isRecording ? 'error' : 'deep-purple'"
|
||||
class="record-btn">
|
||||
class="record-btn" size="small">
|
||||
<v-icon :icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'" variant="text"
|
||||
plain></v-icon>
|
||||
<v-tooltip activator="parent" location="top">
|
||||
{{ isRecording ? tm('voice.speaking') : tm('voice.startRecording') }}
|
||||
</v-tooltip>
|
||||
</v-btn>
|
||||
<v-btn icon v-if="isRunning" @click="$emit('stop')" variant="tonal" color="deep-purple" class="send-btn">
|
||||
<v-btn
|
||||
icon
|
||||
v-if="isRunning"
|
||||
@click="$emit('stop')"
|
||||
variant="text"
|
||||
class="send-btn"
|
||||
size="small"
|
||||
>
|
||||
<v-icon icon="mdi-stop" variant="text" plain></v-icon>
|
||||
<v-tooltip activator="parent" location="top">
|
||||
{{ tm('input.stopGenerating') }}
|
||||
</v-tooltip>
|
||||
</v-btn>
|
||||
<v-btn v-else @click="$emit('send')" icon="mdi-send" variant="tonal" color="deep-purple"
|
||||
:disabled="!canSend" class="send-btn" />
|
||||
<v-btn
|
||||
v-else
|
||||
@click="$emit('send')"
|
||||
icon="mdi-send"
|
||||
variant="text"
|
||||
color="deep-purple"
|
||||
:disabled="!canSend"
|
||||
class="send-btn"
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -138,8 +152,7 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, watch, nextTick, onMounted, onBeforeUnmount } from 'vue';
|
||||
import { useDisplay } from 'vuetify';
|
||||
import { ref, computed, onMounted, onBeforeUnmount } from 'vue';
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
import { useCustomizerStore } from '@/stores/customizer';
|
||||
import ConfigSelector from './ConfigSelector.vue';
|
||||
@@ -238,34 +251,21 @@ function handleReplyAfterLeave() {
|
||||
isReplyClosing.value = false;
|
||||
}
|
||||
|
||||
const { mobile } = useDisplay();
|
||||
|
||||
// Auto-resize textarea
|
||||
function autoResize() {
|
||||
const el = inputField.value;
|
||||
if (!el) return;
|
||||
el.style.height = 'auto';
|
||||
el.style.height = Math.min(el.scrollHeight, 200) + 'px';
|
||||
}
|
||||
|
||||
watch(localPrompt, () => {
|
||||
nextTick(autoResize);
|
||||
});
|
||||
|
||||
function handleKeyDown(e: KeyboardEvent) {
|
||||
// Enter 插入换行(桌面和手机端均如此,发送通过右下角发送按鈕)
|
||||
// Shift+Enter 发送(Ctrl+Enter / Cmd+Enter 也保留)
|
||||
if (e.keyCode === 13 && (e.shiftKey || e.ctrlKey || e.metaKey)) {
|
||||
// Enter 发送消息或触发命令
|
||||
if (e.keyCode === 13 && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
|
||||
// 检查是否是 /astr_live_dev 命令
|
||||
if (localPrompt.value.trim() === '/astr_live_dev') {
|
||||
emit('openLiveMode');
|
||||
localPrompt.value = '';
|
||||
return;
|
||||
}
|
||||
|
||||
if (canSend.value) {
|
||||
emit('send');
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Ctrl+B 录音
|
||||
@@ -588,20 +588,11 @@ defineExpose({
|
||||
@media (max-width: 768px) {
|
||||
.input-area {
|
||||
padding: 0 !important;
|
||||
padding-bottom: 10px !important;
|
||||
}
|
||||
|
||||
.input-container {
|
||||
width: 100% !important;
|
||||
max-width: 100% !important;
|
||||
}
|
||||
|
||||
.input-area textarea,
|
||||
.chat-textarea {
|
||||
min-height: 32px !important;
|
||||
max-height: 160px !important;
|
||||
font-size: 16px !important;
|
||||
padding: 16px 16px 12px 16px !important;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -37,7 +37,7 @@
|
||||
@deleteProject="$emit('deleteProject', $event)"
|
||||
/>
|
||||
|
||||
<div style="overflow-y: auto; flex-grow: 1; overscroll-behavior-y: contain;"
|
||||
<div style="overflow-y: auto; flex-grow: 1;"
|
||||
v-if="!sidebarCollapsed || isMobile">
|
||||
<v-card v-if="sessions.length > 0" flat style="background-color: transparent;">
|
||||
<v-list density="compact" nav class="conversation-list"
|
||||
@@ -326,13 +326,6 @@ function handleTransportModeChange(mode: string | null) {
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.conversation-actions {
|
||||
opacity: 1 !important;
|
||||
visibility: visible !important;
|
||||
}
|
||||
}
|
||||
|
||||
.edit-title-btn,
|
||||
.delete-conversation-btn {
|
||||
opacity: 0.7;
|
||||
|
||||
@@ -965,7 +965,6 @@ export default {
|
||||
height: 100%;
|
||||
max-height: 100%;
|
||||
overflow-y: auto;
|
||||
overscroll-behavior-y: contain;
|
||||
padding: 16px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
|
||||
@@ -1,295 +1,60 @@
|
||||
<template>
|
||||
<div class="skills-page">
|
||||
<v-container fluid class="pa-0" elevation="0">
|
||||
<v-row class="d-flex justify-space-between align-center px-4 py-3 pb-4">
|
||||
<v-row class="d-flex justify-space-between align-center px-4 py-3 pb-8">
|
||||
<div>
|
||||
<v-btn
|
||||
v-if="mode === 'local'"
|
||||
color="success"
|
||||
prepend-icon="mdi-upload"
|
||||
class="me-2"
|
||||
variant="tonal"
|
||||
@click="uploadDialog = true"
|
||||
>
|
||||
{{ tm("skills.upload") }}
|
||||
<v-btn color="success" prepend-icon="mdi-upload" class="me-2" variant="tonal" @click="uploadDialog = true">
|
||||
{{ tm('skills.upload') }}
|
||||
</v-btn>
|
||||
<v-btn color="primary" prepend-icon="mdi-refresh" variant="tonal" @click="refreshCurrentMode">
|
||||
{{ tm("skills.refresh") }}
|
||||
<v-btn color="primary" prepend-icon="mdi-refresh" variant="tonal" @click="fetchSkills">
|
||||
{{ tm('skills.refresh') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
<v-btn-toggle v-model="mode" mandatory divided density="comfortable">
|
||||
<v-btn value="local">{{ tm("skills.modeLocal") }}</v-btn>
|
||||
<v-btn value="neo" :disabled="!neoEnabled">{{ tm("skills.modeNeo") }}</v-btn>
|
||||
</v-btn-toggle>
|
||||
</v-row>
|
||||
|
||||
<div v-if="mode === 'local'" class="px-2 pb-2 d-flex flex-column ga-2">
|
||||
<small style="color: grey;">{{ tm("skills.runtimeHint") }}</small>
|
||||
<v-alert
|
||||
v-if="runtime === 'sandbox' && !sandboxCache.ready"
|
||||
type="info"
|
||||
variant="tonal"
|
||||
density="comfortable"
|
||||
border="start"
|
||||
>
|
||||
{{ tm("skills.sandboxDiscoveryPending") }}
|
||||
</v-alert>
|
||||
<div class="px-2 pb-2">
|
||||
<small style="color: grey;">{{ tm('skills.runtimeHint') }}</small>
|
||||
</div>
|
||||
|
||||
<div v-if="mode === 'neo' && !neoEnabled" class="px-3 pb-3">
|
||||
<v-alert type="warning" variant="tonal" density="comfortable" border="start">
|
||||
{{ neoUnavailableMessage }}
|
||||
</v-alert>
|
||||
<v-progress-linear v-if="loading" indeterminate color="primary"></v-progress-linear>
|
||||
|
||||
<div v-else-if="skills.length === 0" class="text-center pa-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-folder-open</v-icon>
|
||||
<p class="text-grey mt-4">{{ tm('skills.empty') }}</p>
|
||||
<small class="text-grey">{{ tm('skills.emptyHint') }}</small>
|
||||
</div>
|
||||
|
||||
<template v-if="mode === 'local'">
|
||||
<v-progress-linear v-if="loading" indeterminate color="primary"></v-progress-linear>
|
||||
|
||||
<div v-else-if="skills.length === 0" class="text-center pa-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-folder-open</v-icon>
|
||||
<p class="text-grey mt-4">{{ tm("skills.empty") }}</p>
|
||||
<small class="text-grey">{{ tm("skills.emptyHint") }}</small>
|
||||
</div>
|
||||
|
||||
<v-row v-else align="stretch">
|
||||
<v-col
|
||||
v-for="skill in skills"
|
||||
:key="skill.name"
|
||||
cols="12"
|
||||
md="6"
|
||||
lg="4"
|
||||
xl="3"
|
||||
class="d-flex"
|
||||
>
|
||||
<item-card
|
||||
:item="skill"
|
||||
title-field="name"
|
||||
enabled-field="active"
|
||||
:loading="itemLoading[skill.name] || false"
|
||||
:show-edit-button="false"
|
||||
:disable-toggle="isSandboxPresetSkill(skill)"
|
||||
:disable-delete="isSandboxPresetSkill(skill)"
|
||||
@toggle-enabled="toggleSkill"
|
||||
@delete="confirmDelete"
|
||||
>
|
||||
<template #item-details="{ item }">
|
||||
<div class="d-flex align-center mb-2 ga-2 flex-wrap">
|
||||
<v-chip
|
||||
size="x-small"
|
||||
variant="tonal"
|
||||
:color="sourceTypeColor(item.source_type)"
|
||||
>
|
||||
{{ sourceTypeLabel(item.source_type) }}
|
||||
</v-chip>
|
||||
<div class="text-caption text-medium-emphasis skill-description">
|
||||
<v-icon size="small" class="me-1">mdi-text</v-icon>
|
||||
{{ item.description || tm("skills.noDescription") }}
|
||||
</div>
|
||||
</div>
|
||||
<div class="text-caption text-medium-emphasis skill-path">
|
||||
<v-icon size="small" class="me-1">mdi-file-document</v-icon>
|
||||
{{ tm("skills.path") }}: {{ item.path }}
|
||||
</div>
|
||||
</template>
|
||||
<template #actions="{ item }">
|
||||
<v-btn
|
||||
variant="tonal"
|
||||
color="primary"
|
||||
size="small"
|
||||
rounded="xl"
|
||||
:disabled="itemLoading[item.name] || false || isSandboxPresetSkill(item)"
|
||||
@click="downloadSkill(item)"
|
||||
>
|
||||
{{ tm("skills.download") }}
|
||||
</v-btn>
|
||||
</template>
|
||||
</item-card>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</template>
|
||||
|
||||
<template v-else-if="mode === 'neo' && neoEnabled">
|
||||
<v-card class="mx-3 mb-4 pa-4 neo-filter-card" variant="outlined">
|
||||
<div class="d-flex flex-wrap justify-space-between align-center ga-2 mb-3">
|
||||
<div>
|
||||
<div class="text-subtitle-1 font-weight-bold">Neo Skills</div>
|
||||
<div class="text-caption text-medium-emphasis">{{ tm("skills.neoFilterHint") }}</div>
|
||||
</div>
|
||||
<v-btn color="primary" prepend-icon="mdi-refresh" variant="flat" @click="fetchNeoData">
|
||||
{{ tm("skills.refresh") }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<v-row class="ga-md-0 ga-2">
|
||||
<v-col cols="12" md="4">
|
||||
<v-text-field
|
||||
v-model="neoFilters.skill_key"
|
||||
:label="tm('skills.neoSkillKey')"
|
||||
prepend-inner-icon="mdi-key-outline"
|
||||
density="comfortable"
|
||||
hide-details
|
||||
variant="outlined"
|
||||
/>
|
||||
</v-col>
|
||||
<v-col cols="12" md="4">
|
||||
<v-select
|
||||
v-model="neoFilters.status"
|
||||
:label="tm('skills.neoStatus')"
|
||||
:items="candidateStatusItems"
|
||||
item-title="title"
|
||||
item-value="value"
|
||||
prepend-inner-icon="mdi-progress-check"
|
||||
density="comfortable"
|
||||
hide-details
|
||||
variant="outlined"
|
||||
/>
|
||||
</v-col>
|
||||
<v-col cols="12" md="4">
|
||||
<v-select
|
||||
v-model="neoFilters.stage"
|
||||
:label="tm('skills.neoStage')"
|
||||
:items="releaseStageItems"
|
||||
item-title="title"
|
||||
item-value="value"
|
||||
prepend-inner-icon="mdi-layers-outline"
|
||||
density="comfortable"
|
||||
hide-details
|
||||
variant="outlined"
|
||||
/>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-card>
|
||||
|
||||
<v-progress-linear v-if="neoLoading" indeterminate color="primary"></v-progress-linear>
|
||||
|
||||
<div class="mx-3 mb-3 d-flex flex-wrap ga-2">
|
||||
<v-chip size="small" color="primary" variant="tonal">Candidates: {{ neoCandidates.length }}</v-chip>
|
||||
<v-chip size="small" color="indigo" variant="tonal">Releases: {{ neoReleases.length }}</v-chip>
|
||||
<v-chip size="small" color="success" variant="tonal">Active: {{ activeReleaseCount }}</v-chip>
|
||||
</div>
|
||||
|
||||
<v-card class="mx-3 mb-4 neo-table-card" variant="outlined">
|
||||
<v-card-title class="text-subtitle-1 font-weight-bold">{{ tm("skills.neoCandidates") }}</v-card-title>
|
||||
<v-data-table
|
||||
:headers="candidateHeaders"
|
||||
:items="neoCandidates"
|
||||
density="compact"
|
||||
:items-per-page="10"
|
||||
class="neo-data-table"
|
||||
>
|
||||
<template #item.latest_score="{ item }">
|
||||
{{ item.latest_score ?? "-" }}
|
||||
</template>
|
||||
<template #item.actions="{ item }">
|
||||
<div class="d-flex ga-1 flex-wrap">
|
||||
<v-btn size="x-small" color="success" variant="tonal" @click="evaluateCandidate(item, true)">
|
||||
{{ tm("skills.neoPass") }}
|
||||
</v-btn>
|
||||
<v-btn size="x-small" color="warning" variant="tonal" @click="evaluateCandidate(item, false)">
|
||||
{{ tm("skills.neoReject") }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
size="x-small"
|
||||
color="primary"
|
||||
variant="tonal"
|
||||
:loading="isCandidatePromoteLoading(item.id, 'canary')"
|
||||
:disabled="isCandidatePromoting(item.id)"
|
||||
@click="promoteCandidate(item, 'canary')"
|
||||
>
|
||||
Canary
|
||||
</v-btn>
|
||||
<v-btn
|
||||
size="x-small"
|
||||
color="primary"
|
||||
variant="tonal"
|
||||
:loading="isCandidatePromoteLoading(item.id, 'stable')"
|
||||
:disabled="isCandidatePromoting(item.id)"
|
||||
@click="promoteCandidate(item, 'stable')"
|
||||
>
|
||||
Stable
|
||||
</v-btn>
|
||||
<v-btn
|
||||
size="x-small"
|
||||
variant="tonal"
|
||||
@click="viewPayload(item.payload_ref)"
|
||||
:disabled="!item.payload_ref"
|
||||
>
|
||||
Payload
|
||||
</v-btn>
|
||||
<v-btn
|
||||
size="x-small"
|
||||
color="error"
|
||||
variant="tonal"
|
||||
@click="deleteCandidate(item)"
|
||||
>
|
||||
{{ tm("skills.neoDelete") }}
|
||||
</v-btn>
|
||||
<v-row v-else>
|
||||
<v-col v-for="skill in skills" :key="skill.name" cols="12" md="6" lg="4" xl="3">
|
||||
<item-card :item="skill" title-field="name" enabled-field="active" :loading="itemLoading[skill.name] || false"
|
||||
:show-edit-button="false" @toggle-enabled="toggleSkill" @delete="confirmDelete">
|
||||
<template v-slot:item-details="{ item }">
|
||||
<div class="text-caption text-medium-emphasis mb-2 skill-description">
|
||||
<v-icon size="small" class="me-1">mdi-text</v-icon>
|
||||
{{ item.description || tm('skills.noDescription') }}
|
||||
</div>
|
||||
<div class="text-caption text-medium-emphasis">
|
||||
<v-icon size="small" class="me-1">mdi-file-document</v-icon>
|
||||
{{ tm('skills.path') }}: {{ item.path }}
|
||||
</div>
|
||||
</template>
|
||||
</v-data-table>
|
||||
</v-card>
|
||||
|
||||
<v-card class="mx-3 mb-4 neo-table-card" variant="outlined">
|
||||
<v-card-title class="text-subtitle-1 font-weight-bold">{{ tm("skills.neoReleases") }}</v-card-title>
|
||||
<v-data-table
|
||||
:headers="releaseHeaders"
|
||||
:items="neoReleases"
|
||||
density="compact"
|
||||
:items-per-page="10"
|
||||
class="neo-data-table"
|
||||
>
|
||||
<template #item.is_active="{ item }">
|
||||
<v-chip size="small" :color="item.is_active ? 'success' : 'default'" variant="tonal">
|
||||
{{ item.is_active ? "active" : "inactive" }}
|
||||
</v-chip>
|
||||
</template>
|
||||
<template #item.actions="{ item }">
|
||||
<div class="d-flex ga-1 flex-wrap">
|
||||
<v-btn
|
||||
size="x-small"
|
||||
color="warning"
|
||||
variant="tonal"
|
||||
@click="handleReleaseLifecycleAction(item)"
|
||||
>
|
||||
{{ item.is_active ? tm("skills.neoDeactivate") : tm("skills.neoRollback") }}
|
||||
</v-btn>
|
||||
<v-btn size="x-small" color="primary" variant="tonal" @click="syncRelease(item)">
|
||||
{{ tm("skills.neoSync") }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
size="x-small"
|
||||
color="error"
|
||||
variant="tonal"
|
||||
@click="deleteRelease(item)"
|
||||
>
|
||||
{{ tm("skills.neoDelete") }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</template>
|
||||
</v-data-table>
|
||||
</v-card>
|
||||
</template>
|
||||
</item-card>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-container>
|
||||
|
||||
<v-dialog v-model="uploadDialog" max-width="520px">
|
||||
<v-card>
|
||||
<v-card-title class="text-h3 pa-4 pb-0 pl-6">{{ tm("skills.uploadDialogTitle") }}</v-card-title>
|
||||
<v-card-title class="text-h3 pa-4 pb-0 pl-6">{{ tm('skills.uploadDialogTitle') }}</v-card-title>
|
||||
<v-card-text>
|
||||
<small class="text-grey">{{ tm("skills.uploadHint") }}</small>
|
||||
<v-file-input
|
||||
v-model="uploadFile"
|
||||
accept=".zip"
|
||||
:label="tm('skills.selectFile')"
|
||||
prepend-icon="mdi-folder-zip-outline"
|
||||
variant="outlined"
|
||||
class="mt-4"
|
||||
:multiple="false"
|
||||
/>
|
||||
<small class="text-grey">{{ tm('skills.uploadHint') }}</small>
|
||||
<v-file-input v-model="uploadFile" accept=".zip" :label="tm('skills.selectFile')"
|
||||
prepend-icon="mdi-folder-zip-outline" variant="outlined" class="mt-4" :multiple="false" />
|
||||
</v-card-text>
|
||||
<v-card-actions class="d-flex justify-end">
|
||||
<v-btn variant="text" @click="uploadDialog = false">{{ tm("skills.cancel") }}</v-btn>
|
||||
<v-btn variant="text" @click="uploadDialog = false">{{ tm('skills.cancel') }}</v-btn>
|
||||
<v-btn color="primary" :loading="uploading" :disabled="!uploadFile" @click="uploadSkill">
|
||||
{{ tm("skills.confirmUpload") }}
|
||||
{{ tm('skills.confirmUpload') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
@@ -297,30 +62,18 @@
|
||||
|
||||
<v-dialog v-model="deleteDialog" max-width="400px">
|
||||
<v-card>
|
||||
<v-card-title>{{ tm("skills.deleteTitle") }}</v-card-title>
|
||||
<v-card-text>{{ tm("skills.deleteMessage") }}</v-card-text>
|
||||
<v-card-title>{{ tm('skills.deleteTitle') }}</v-card-title>
|
||||
<v-card-text>{{ tm('skills.deleteMessage') }}</v-card-text>
|
||||
<v-card-actions class="d-flex justify-end">
|
||||
<v-btn variant="text" @click="deleteDialog = false">{{ tm("skills.cancel") }}</v-btn>
|
||||
<v-btn variant="text" @click="deleteDialog = false">{{ tm('skills.cancel') }}</v-btn>
|
||||
<v-btn color="error" :loading="deleting" @click="deleteSkill">
|
||||
{{ t("core.common.itemCard.delete") }}
|
||||
{{ t('core.common.itemCard.delete') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<v-dialog v-model="payloadDialog.show" max-width="820px">
|
||||
<v-card>
|
||||
<v-card-title>{{ tm("skills.neoPayloadTitle") }}</v-card-title>
|
||||
<v-card-text>
|
||||
<pre class="payload-preview">{{ payloadDialog.content }}</pre>
|
||||
</v-card-text>
|
||||
<v-card-actions class="d-flex justify-end">
|
||||
<v-btn variant="text" @click="payloadDialog.show = false">{{ tm("skills.cancel") }}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<v-snackbar v-model="snackbar.show" :timeout="3500" :color="snackbar.color" elevation="24">
|
||||
<v-snackbar v-model="snackbar.show" :timeout="3000" :color="snackbar.color" elevation="24">
|
||||
{{ snackbar.message }}
|
||||
</v-snackbar>
|
||||
</div>
|
||||
@@ -328,7 +81,7 @@
|
||||
|
||||
<script>
|
||||
import axios from "axios";
|
||||
import { computed, onMounted, reactive, ref, watch } from "vue";
|
||||
import { ref, reactive, onMounted } from "vue";
|
||||
import ItemCard from "@/components/shared/ItemCard.vue";
|
||||
import { useI18n, useModuleI18n } from "@/i18n/composables";
|
||||
|
||||
@@ -339,11 +92,8 @@ export default {
|
||||
const { t } = useI18n();
|
||||
const { tm } = useModuleI18n("features/extension");
|
||||
|
||||
const mode = ref("local");
|
||||
const skills = ref([]);
|
||||
const loading = ref(false);
|
||||
const runtime = ref("local");
|
||||
const sandboxCache = reactive({ ready: false, count: 0, updated_at: null });
|
||||
const uploading = ref(false);
|
||||
const uploadDialog = ref(false);
|
||||
const uploadFile = ref(null);
|
||||
@@ -353,109 +103,23 @@ export default {
|
||||
const skillToDelete = ref(null);
|
||||
const snackbar = reactive({ show: false, message: "", color: "success" });
|
||||
|
||||
const neoLoading = ref(false);
|
||||
const neoCandidates = ref([]);
|
||||
const neoReleases = ref([]);
|
||||
const neoFilters = reactive({
|
||||
skill_key: "",
|
||||
status: "",
|
||||
stage: "",
|
||||
});
|
||||
const candidatePromoteLoading = reactive({});
|
||||
const payloadDialog = reactive({
|
||||
show: false,
|
||||
content: "",
|
||||
});
|
||||
|
||||
const neoEnabled = ref(false);
|
||||
const neoUnavailableMessage = ref("");
|
||||
|
||||
const candidateStatusItems = computed(() => [
|
||||
{ title: tm("skills.neoAll"), value: "" },
|
||||
{ title: "draft", value: "draft" },
|
||||
{ title: "evaluating", value: "evaluating" },
|
||||
{ title: "promoted", value: "promoted" },
|
||||
{ title: "promoted_canary", value: "promoted_canary" },
|
||||
{ title: "promoted_stable", value: "promoted_stable" },
|
||||
{ title: "rejected", value: "rejected" },
|
||||
{ title: "rolled_back", value: "rolled_back" },
|
||||
]);
|
||||
|
||||
const releaseStageItems = computed(() => [
|
||||
{ title: tm("skills.neoAll"), value: "" },
|
||||
{ title: "canary", value: "canary" },
|
||||
{ title: "stable", value: "stable" },
|
||||
]);
|
||||
|
||||
const activeReleaseCount = computed(() => neoReleases.value.filter((item) => item?.is_active).length);
|
||||
|
||||
const candidateHeaders = computed(() => [
|
||||
{ title: "ID", key: "id", width: "180px" },
|
||||
{ title: "skill_key", key: "skill_key" },
|
||||
{ title: "status", key: "status", width: "130px" },
|
||||
{ title: "score", key: "latest_score", width: "90px" },
|
||||
{ title: tm("skills.actions"), key: "actions", sortable: false, width: "420px" },
|
||||
]);
|
||||
|
||||
const releaseHeaders = computed(() => [
|
||||
{ title: "ID", key: "id", width: "180px" },
|
||||
{ title: "skill_key", key: "skill_key" },
|
||||
{ title: "stage", key: "stage", width: "100px" },
|
||||
{ title: "version", key: "version", width: "90px" },
|
||||
{ title: "active", key: "is_active", width: "110px" },
|
||||
{ title: tm("skills.actions"), key: "actions", sortable: false, width: "220px" },
|
||||
]);
|
||||
|
||||
const showMessage = (message, color = "success") => {
|
||||
snackbar.message = message;
|
||||
snackbar.color = color;
|
||||
snackbar.show = true;
|
||||
};
|
||||
|
||||
const normalizeSkillsPayload = (res) => {
|
||||
const payload = res?.data?.data || [];
|
||||
if (Array.isArray(payload)) {
|
||||
runtime.value = "local";
|
||||
sandboxCache.ready = false;
|
||||
sandboxCache.count = 0;
|
||||
sandboxCache.updated_at = null;
|
||||
return payload;
|
||||
}
|
||||
runtime.value = payload.runtime || "local";
|
||||
const cache = payload.sandbox_cache || {};
|
||||
sandboxCache.ready = !!cache.ready;
|
||||
sandboxCache.count = Number(cache.count || 0);
|
||||
sandboxCache.updated_at = cache.updated_at || null;
|
||||
return payload.skills || [];
|
||||
};
|
||||
|
||||
const sourceTypeLabel = (sourceType) => {
|
||||
if (sourceType === "sandbox_only") return tm("skills.sourceSandboxOnly");
|
||||
if (sourceType === "both") return tm("skills.sourceBoth");
|
||||
return tm("skills.sourceLocalOnly");
|
||||
};
|
||||
|
||||
const sourceTypeColor = (sourceType) => {
|
||||
if (sourceType === "sandbox_only") return "indigo";
|
||||
if (sourceType === "both") return "success";
|
||||
return "primary";
|
||||
};
|
||||
|
||||
const isSandboxPresetSkill = (skill) => skill?.source_type === "sandbox_only";
|
||||
|
||||
const normalizeNeoItemsPayload = (res) => {
|
||||
const payload = res?.data?.data || [];
|
||||
if (Array.isArray(payload)) return payload;
|
||||
if (Array.isArray(payload.items)) return payload.items;
|
||||
return [];
|
||||
};
|
||||
|
||||
const fetchSkills = async () => {
|
||||
loading.value = true;
|
||||
try {
|
||||
const res = await axios.get("/api/skills");
|
||||
skills.value = normalizeSkillsPayload(res);
|
||||
} catch (_err) {
|
||||
const payload = res.data?.data || [];
|
||||
if (Array.isArray(payload)) {
|
||||
skills.value = payload;
|
||||
} else {
|
||||
skills.value = payload.skills || [];
|
||||
}
|
||||
} catch (err) {
|
||||
showMessage(tm("skills.loadFailed"), "error");
|
||||
} finally {
|
||||
loading.value = false;
|
||||
@@ -477,7 +141,9 @@ export default {
|
||||
uploading.value = true;
|
||||
try {
|
||||
const formData = new FormData();
|
||||
const file = Array.isArray(uploadFile.value) ? uploadFile.value[0] : uploadFile.value;
|
||||
const file = Array.isArray(uploadFile.value)
|
||||
? uploadFile.value[0]
|
||||
: uploadFile.value;
|
||||
if (!file) {
|
||||
uploading.value = false;
|
||||
return;
|
||||
@@ -486,12 +152,17 @@ export default {
|
||||
const res = await axios.post("/api/skills/upload", formData, {
|
||||
headers: { "Content-Type": "multipart/form-data" },
|
||||
});
|
||||
handleApiResponse(res, tm("skills.uploadSuccess"), tm("skills.uploadFailed"), async () => {
|
||||
uploadDialog.value = false;
|
||||
uploadFile.value = null;
|
||||
await fetchSkills();
|
||||
});
|
||||
} catch (_err) {
|
||||
handleApiResponse(
|
||||
res,
|
||||
tm("skills.uploadSuccess"),
|
||||
tm("skills.uploadFailed"),
|
||||
async () => {
|
||||
uploadDialog.value = false;
|
||||
uploadFile.value = null;
|
||||
await fetchSkills();
|
||||
}
|
||||
);
|
||||
} catch (err) {
|
||||
showMessage(tm("skills.uploadFailed"), "error");
|
||||
} finally {
|
||||
uploading.value = false;
|
||||
@@ -499,10 +170,6 @@ export default {
|
||||
};
|
||||
|
||||
const toggleSkill = async (skill) => {
|
||||
if (isSandboxPresetSkill(skill)) {
|
||||
showMessage(tm("skills.sandboxPresetReadonly"), "warning");
|
||||
return;
|
||||
}
|
||||
const nextActive = !skill.active;
|
||||
itemLoading[skill.name] = true;
|
||||
try {
|
||||
@@ -510,10 +177,15 @@ export default {
|
||||
name: skill.name,
|
||||
active: nextActive,
|
||||
});
|
||||
handleApiResponse(res, tm("skills.updateSuccess"), tm("skills.updateFailed"), () => {
|
||||
skill.active = nextActive;
|
||||
});
|
||||
} catch (_err) {
|
||||
handleApiResponse(
|
||||
res,
|
||||
tm("skills.updateSuccess"),
|
||||
tm("skills.updateFailed"),
|
||||
() => {
|
||||
skill.active = nextActive;
|
||||
}
|
||||
);
|
||||
} catch (err) {
|
||||
showMessage(tm("skills.updateFailed"), "error");
|
||||
} finally {
|
||||
itemLoading[skill.name] = false;
|
||||
@@ -521,10 +193,6 @@ export default {
|
||||
};
|
||||
|
||||
const confirmDelete = (skill) => {
|
||||
if (isSandboxPresetSkill(skill)) {
|
||||
showMessage(tm("skills.sandboxPresetReadonly"), "warning");
|
||||
return;
|
||||
}
|
||||
skillToDelete.value = skill;
|
||||
deleteDialog.value = true;
|
||||
};
|
||||
@@ -536,288 +204,29 @@ export default {
|
||||
const res = await axios.post("/api/skills/delete", {
|
||||
name: skillToDelete.value.name,
|
||||
});
|
||||
handleApiResponse(res, tm("skills.deleteSuccess"), tm("skills.deleteFailed"), async () => {
|
||||
deleteDialog.value = false;
|
||||
await fetchSkills();
|
||||
});
|
||||
} catch (_err) {
|
||||
handleApiResponse(
|
||||
res,
|
||||
tm("skills.deleteSuccess"),
|
||||
tm("skills.deleteFailed"),
|
||||
async () => {
|
||||
deleteDialog.value = false;
|
||||
await fetchSkills();
|
||||
}
|
||||
);
|
||||
} catch (err) {
|
||||
showMessage(tm("skills.deleteFailed"), "error");
|
||||
} finally {
|
||||
deleting.value = false;
|
||||
}
|
||||
};
|
||||
|
||||
const downloadSkill = async (skill) => {
|
||||
if (isSandboxPresetSkill(skill)) {
|
||||
showMessage(tm("skills.sandboxPresetReadonly"), "warning");
|
||||
return;
|
||||
}
|
||||
itemLoading[skill.name] = true;
|
||||
try {
|
||||
const res = await axios.get("/api/skills/download", {
|
||||
params: { name: skill.name },
|
||||
responseType: "blob",
|
||||
});
|
||||
const blob = new Blob([res.data], { type: "application/zip" });
|
||||
const url = window.URL.createObjectURL(blob);
|
||||
const link = document.createElement("a");
|
||||
link.href = url;
|
||||
link.download = `${skill.name}.zip`;
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
window.URL.revokeObjectURL(url);
|
||||
showMessage(tm("skills.downloadSuccess"), "success");
|
||||
} catch (_err) {
|
||||
showMessage(tm("skills.downloadFailed"), "error");
|
||||
} finally {
|
||||
itemLoading[skill.name] = false;
|
||||
}
|
||||
};
|
||||
|
||||
const fetchNeoCandidates = async () => {
|
||||
const params = {
|
||||
skill_key: neoFilters.skill_key || undefined,
|
||||
status: neoFilters.status || undefined,
|
||||
};
|
||||
const res = await axios.get("/api/skills/neo/candidates", { params });
|
||||
neoCandidates.value = normalizeNeoItemsPayload(res);
|
||||
};
|
||||
|
||||
const fetchNeoReleases = async () => {
|
||||
const params = {
|
||||
skill_key: neoFilters.skill_key || undefined,
|
||||
stage: neoFilters.stage || undefined,
|
||||
};
|
||||
const res = await axios.get("/api/skills/neo/releases", { params });
|
||||
neoReleases.value = normalizeNeoItemsPayload(res).map((item) => {
|
||||
if (!item || typeof item !== "object") {
|
||||
return item;
|
||||
}
|
||||
return {
|
||||
...item,
|
||||
is_active: item.is_active ?? item.active ?? false,
|
||||
};
|
||||
});
|
||||
};
|
||||
|
||||
const loadNeoAvailability = async () => {
|
||||
try {
|
||||
const res = await axios.get("/api/config/get");
|
||||
const config = res?.data?.data?.config || {};
|
||||
const providerSettings = config?.provider_settings || {};
|
||||
const runtime = providerSettings?.computer_use_runtime || "local";
|
||||
const booter = providerSettings?.sandbox?.booter || "";
|
||||
neoEnabled.value = runtime === "sandbox" && booter === "shipyard_neo";
|
||||
} catch (_err) {
|
||||
neoEnabled.value = false;
|
||||
}
|
||||
|
||||
neoUnavailableMessage.value = tm("skills.neoRuntimeRequired");
|
||||
if (!neoEnabled.value && mode.value === "neo") {
|
||||
mode.value = "local";
|
||||
}
|
||||
};
|
||||
|
||||
const fetchNeoData = async () => {
|
||||
neoLoading.value = true;
|
||||
try {
|
||||
await Promise.all([fetchNeoCandidates(), fetchNeoReleases()]);
|
||||
} catch (_err) {
|
||||
showMessage(tm("skills.neoLoadFailed"), "error");
|
||||
} finally {
|
||||
neoLoading.value = false;
|
||||
}
|
||||
};
|
||||
|
||||
const evaluateCandidate = async (candidate, passed) => {
|
||||
try {
|
||||
const res = await axios.post("/api/skills/neo/evaluate", {
|
||||
candidate_id: candidate.id,
|
||||
passed,
|
||||
score: passed ? 1.0 : 0.0,
|
||||
report: passed ? "approved_from_webui" : "rejected_from_webui",
|
||||
});
|
||||
handleApiResponse(res, tm("skills.neoEvaluateSuccess"), tm("skills.neoEvaluateFailed"), async () => {
|
||||
await fetchNeoCandidates();
|
||||
});
|
||||
} catch (_err) {
|
||||
showMessage(tm("skills.neoEvaluateFailed"), "error");
|
||||
}
|
||||
};
|
||||
|
||||
const candidatePromoteLoadingKey = (candidateId, stage) => `${candidateId}:${stage}`;
|
||||
const isCandidatePromoteLoading = (candidateId, stage) =>
|
||||
!!candidatePromoteLoading[candidatePromoteLoadingKey(candidateId, stage)];
|
||||
const isCandidatePromoting = (candidateId) =>
|
||||
isCandidatePromoteLoading(candidateId, "canary") || isCandidatePromoteLoading(candidateId, "stable");
|
||||
|
||||
const promoteCandidate = async (candidate, stage) => {
|
||||
const candidateId = candidate?.id;
|
||||
if (!candidateId) return;
|
||||
const loadingKey = candidatePromoteLoadingKey(candidateId, stage);
|
||||
if (candidatePromoteLoading[loadingKey]) return;
|
||||
candidatePromoteLoading[loadingKey] = true;
|
||||
try {
|
||||
const res = await axios.post("/api/skills/neo/promote", {
|
||||
candidate_id: candidateId,
|
||||
stage,
|
||||
sync_to_local: true,
|
||||
});
|
||||
const ok = res?.data?.status === "ok";
|
||||
if (!ok) {
|
||||
showMessage(res?.data?.message || tm("skills.neoPromoteFailed"), "error");
|
||||
} else {
|
||||
showMessage(tm("skills.neoPromoteSuccess"), "success");
|
||||
}
|
||||
await fetchNeoData();
|
||||
if (stage === "stable") {
|
||||
await fetchSkills();
|
||||
}
|
||||
} catch (_err) {
|
||||
showMessage(tm("skills.neoPromoteFailed"), "error");
|
||||
} finally {
|
||||
candidatePromoteLoading[loadingKey] = false;
|
||||
}
|
||||
};
|
||||
|
||||
const rollbackRelease = async (release) => {
|
||||
try {
|
||||
const res = await axios.post("/api/skills/neo/rollback", {
|
||||
release_id: release.id,
|
||||
});
|
||||
handleApiResponse(res, tm("skills.neoRollbackSuccess"), tm("skills.neoRollbackFailed"), async () => {
|
||||
await fetchNeoData();
|
||||
});
|
||||
} catch (_err) {
|
||||
showMessage(tm("skills.neoRollbackFailed"), "error");
|
||||
}
|
||||
};
|
||||
|
||||
const deactivateRelease = async (release) => {
|
||||
try {
|
||||
const res = await axios.post("/api/skills/neo/rollback", {
|
||||
release_id: release.id,
|
||||
});
|
||||
handleApiResponse(
|
||||
res,
|
||||
tm("skills.neoDeactivateSuccess"),
|
||||
tm("skills.neoDeactivateFailed"),
|
||||
async () => {
|
||||
await fetchNeoData();
|
||||
},
|
||||
);
|
||||
} catch (_err) {
|
||||
showMessage(tm("skills.neoDeactivateFailed"), "error");
|
||||
}
|
||||
};
|
||||
|
||||
const handleReleaseLifecycleAction = async (release) => {
|
||||
if (release?.is_active) {
|
||||
await deactivateRelease(release);
|
||||
return;
|
||||
}
|
||||
await rollbackRelease(release);
|
||||
};
|
||||
|
||||
const syncRelease = async (release) => {
|
||||
try {
|
||||
const res = await axios.post("/api/skills/neo/sync", {
|
||||
release_id: release.id,
|
||||
});
|
||||
handleApiResponse(res, tm("skills.neoSyncSuccess"), tm("skills.neoSyncFailed"), async () => {
|
||||
await fetchSkills();
|
||||
});
|
||||
} catch (_err) {
|
||||
showMessage(tm("skills.neoSyncFailed"), "error");
|
||||
}
|
||||
};
|
||||
|
||||
const viewPayload = async (payloadRef) => {
|
||||
if (!payloadRef) return;
|
||||
try {
|
||||
const res = await axios.get("/api/skills/neo/payload", {
|
||||
params: { payload_ref: payloadRef },
|
||||
});
|
||||
if (res?.data?.status !== "ok") {
|
||||
showMessage(res?.data?.message || tm("skills.neoPayloadFailed"), "error");
|
||||
return;
|
||||
}
|
||||
const payload = res?.data?.data || {};
|
||||
payloadDialog.content = JSON.stringify(payload, null, 2);
|
||||
payloadDialog.show = true;
|
||||
} catch (_err) {
|
||||
showMessage(tm("skills.neoPayloadFailed"), "error");
|
||||
}
|
||||
};
|
||||
|
||||
const deleteCandidate = async (candidate) => {
|
||||
try {
|
||||
const res = await axios.post("/api/skills/neo/delete-candidate", {
|
||||
candidate_id: candidate.id,
|
||||
reason: "deleted_from_webui",
|
||||
});
|
||||
handleApiResponse(res, tm("skills.neoDeleteSuccess"), tm("skills.neoDeleteFailed"), async () => {
|
||||
await fetchNeoData();
|
||||
});
|
||||
} catch (_err) {
|
||||
showMessage(tm("skills.neoDeleteFailed"), "error");
|
||||
}
|
||||
};
|
||||
|
||||
const deleteRelease = async (release) => {
|
||||
try {
|
||||
const res = await axios.post("/api/skills/neo/delete-release", {
|
||||
release_id: release.id,
|
||||
reason: "deleted_from_webui",
|
||||
});
|
||||
handleApiResponse(res, tm("skills.neoDeleteSuccess"), tm("skills.neoDeleteFailed"), async () => {
|
||||
await fetchNeoData();
|
||||
});
|
||||
} catch (_err) {
|
||||
showMessage(tm("skills.neoDeleteFailed"), "error");
|
||||
}
|
||||
};
|
||||
|
||||
const refreshCurrentMode = async () => {
|
||||
if (mode.value === "neo") {
|
||||
await loadNeoAvailability();
|
||||
if (neoEnabled.value) {
|
||||
await fetchNeoData();
|
||||
} else {
|
||||
showMessage(tm("skills.neoRuntimeRequired"), "warning");
|
||||
}
|
||||
} else {
|
||||
await fetchSkills();
|
||||
}
|
||||
};
|
||||
|
||||
watch(mode, async (nextMode) => {
|
||||
if (nextMode === "neo") {
|
||||
await loadNeoAvailability();
|
||||
if (neoEnabled.value) {
|
||||
await fetchNeoData();
|
||||
}
|
||||
} else {
|
||||
await fetchSkills();
|
||||
}
|
||||
});
|
||||
|
||||
onMounted(async () => {
|
||||
await Promise.all([fetchSkills(), loadNeoAvailability()]);
|
||||
if (neoEnabled.value) {
|
||||
await fetchNeoData();
|
||||
}
|
||||
});
|
||||
onMounted(fetchSkills);
|
||||
|
||||
return {
|
||||
t,
|
||||
tm,
|
||||
mode,
|
||||
skills,
|
||||
loading,
|
||||
runtime,
|
||||
sandboxCache,
|
||||
uploadDialog,
|
||||
uploadFile,
|
||||
uploading,
|
||||
@@ -825,39 +234,11 @@ export default {
|
||||
deleteDialog,
|
||||
deleting,
|
||||
snackbar,
|
||||
neoEnabled,
|
||||
neoUnavailableMessage,
|
||||
neoLoading,
|
||||
neoCandidates,
|
||||
neoReleases,
|
||||
neoFilters,
|
||||
candidateStatusItems,
|
||||
releaseStageItems,
|
||||
activeReleaseCount,
|
||||
candidateHeaders,
|
||||
releaseHeaders,
|
||||
payloadDialog,
|
||||
refreshCurrentMode,
|
||||
fetchNeoData,
|
||||
fetchSkills,
|
||||
uploadSkill,
|
||||
downloadSkill,
|
||||
toggleSkill,
|
||||
confirmDelete,
|
||||
deleteSkill,
|
||||
evaluateCandidate,
|
||||
promoteCandidate,
|
||||
isCandidatePromoteLoading,
|
||||
isCandidatePromoting,
|
||||
rollbackRelease,
|
||||
deactivateRelease,
|
||||
handleReleaseLifecycleAction,
|
||||
syncRelease,
|
||||
viewPayload,
|
||||
deleteCandidate,
|
||||
deleteRelease,
|
||||
sourceTypeLabel,
|
||||
sourceTypeColor,
|
||||
isSandboxPresetSkill,
|
||||
};
|
||||
},
|
||||
};
|
||||
@@ -869,42 +250,5 @@ export default {
|
||||
-webkit-line-clamp: 1;
|
||||
-webkit-box-orient: vertical;
|
||||
overflow: hidden;
|
||||
min-height: 20px;
|
||||
}
|
||||
|
||||
.skill-path {
|
||||
display: -webkit-box;
|
||||
-webkit-line-clamp: 2;
|
||||
-webkit-box-orient: vertical;
|
||||
overflow: hidden;
|
||||
min-height: 40px;
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
.payload-preview {
|
||||
max-height: 480px;
|
||||
overflow: auto;
|
||||
background: #111;
|
||||
color: #ececec;
|
||||
padding: 12px;
|
||||
border-radius: 8px;
|
||||
font-size: 12px;
|
||||
}
|
||||
.neo-filter-card {
|
||||
border-radius: 14px;
|
||||
border-color: rgba(var(--v-theme-primary), 0.25);
|
||||
background: linear-gradient(180deg, rgba(var(--v-theme-primary), 0.03), rgba(var(--v-theme-surface), 1));
|
||||
}
|
||||
|
||||
.neo-table-card {
|
||||
border-radius: 14px;
|
||||
}
|
||||
|
||||
.neo-data-table :deep(.v-data-table-header__content) {
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
.neo-data-table :deep(tbody tr:hover) {
|
||||
background: rgba(var(--v-theme-primary), 0.04);
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
density="compact"
|
||||
:model-value="getItemEnabled()"
|
||||
:loading="loading"
|
||||
:disabled="loading || disableToggle"
|
||||
:disabled="loading"
|
||||
v-bind="props"
|
||||
@update:model-value="toggleEnabled"
|
||||
></v-switch>
|
||||
@@ -29,7 +29,7 @@
|
||||
color="error"
|
||||
size="small"
|
||||
rounded="xl"
|
||||
:disabled="loading || disableDelete"
|
||||
:disabled="loading"
|
||||
@click="$emit('delete', item)"
|
||||
>
|
||||
{{ t('core.common.itemCard.delete') }}
|
||||
@@ -108,14 +108,6 @@ export default {
|
||||
showEditButton: {
|
||||
type: Boolean,
|
||||
default: true
|
||||
},
|
||||
disableToggle: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
},
|
||||
disableDelete: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
}
|
||||
},
|
||||
emits: ['toggle-enabled', 'delete', 'edit', 'copy'],
|
||||
@@ -140,7 +132,6 @@ export default {
|
||||
transition: all 0.3s ease;
|
||||
overflow: hidden;
|
||||
min-height: 220px;
|
||||
height: 100%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: space-between;
|
||||
|
||||
@@ -161,22 +161,6 @@
|
||||
"booter": {
|
||||
"description": "Sandbox Environment Driver"
|
||||
},
|
||||
"shipyard_neo_endpoint": {
|
||||
"description": "Shipyard Neo API Endpoint",
|
||||
"hint": "Bay API address, default http://127.0.0.1:8114."
|
||||
},
|
||||
"shipyard_neo_access_token": {
|
||||
"description": "Shipyard Neo Access Token",
|
||||
"hint": "Bay API Key (sk-bay-...). Leave empty for auto-discovery from credentials.json."
|
||||
},
|
||||
"shipyard_neo_profile": {
|
||||
"description": "Shipyard Neo Profile",
|
||||
"hint": "Sandbox profile for Shipyard Neo, e.g. python-default."
|
||||
},
|
||||
"shipyard_neo_ttl": {
|
||||
"description": "Shipyard Neo Sandbox TTL",
|
||||
"hint": "Sandbox time-to-live in seconds."
|
||||
},
|
||||
"shipyard_endpoint": {
|
||||
"description": "Shipyard API Endpoint",
|
||||
"hint": "API access address for Shipyard service."
|
||||
@@ -370,8 +354,7 @@
|
||||
"hint": "Optional Discord activity name. Leave empty to disable."
|
||||
},
|
||||
"discord_command_register": {
|
||||
"description": "Register Discord slash commands",
|
||||
"hint": "When enabled, AstrBot will automatically register plugin commands as Discord slash commands"
|
||||
"description": "Auto-register plugin commands as Discord slash commands"
|
||||
},
|
||||
"discord_proxy": {
|
||||
"description": "Discord Proxy URL",
|
||||
@@ -584,51 +567,6 @@
|
||||
"only_use_webhook_url_to_send": {
|
||||
"description": "Send Replies via Webhook Only",
|
||||
"hint": "When enabled, all WeCom AI Bot replies are sent through msg_push_webhook_url. The message push webhook supports more message types (such as images, files, etc.). If you do not need the typing effect, it is strongly recommended to use this option. "
|
||||
},
|
||||
"kook_bot_token": {
|
||||
"description": "Bot Token",
|
||||
"type": "string",
|
||||
"hint": "Required. The Bot Token obtained from the KOOK Developer Platform."
|
||||
},
|
||||
"kook_bot_nickname": {
|
||||
"description": "Bot Nickname",
|
||||
"type": "string",
|
||||
"hint": "Optional. If the sender nickname matches this value, the message will be ignored to prevent broadcast storms."
|
||||
},
|
||||
"kook_reconnect_delay": {
|
||||
"description": "Reconnect Delay",
|
||||
"type": "int",
|
||||
"hint": "Delay time for reconnection (seconds), using an exponential backoff strategy."
|
||||
},
|
||||
"kook_max_reconnect_delay": {
|
||||
"description": "Max Reconnect Delay",
|
||||
"type": "int",
|
||||
"hint": "The maximum value for reconnection delay (seconds)."
|
||||
},
|
||||
"kook_max_retry_delay": {
|
||||
"description": "Max Retry Delay",
|
||||
"type": "int",
|
||||
"hint": "The maximum delay time for retries (seconds)."
|
||||
},
|
||||
"kook_heartbeat_interval": {
|
||||
"description": "Heartbeat Interval",
|
||||
"type": "int",
|
||||
"hint": "The interval time for heartbeat detection (seconds)."
|
||||
},
|
||||
"kook_heartbeat_timeout": {
|
||||
"description": "Heartbeat Timeout",
|
||||
"type": "int",
|
||||
"hint": "The timeout duration for heartbeat detection (seconds)."
|
||||
},
|
||||
"kook_max_heartbeat_failures": {
|
||||
"description": "Max Heartbeat Failures",
|
||||
"type": "int",
|
||||
"hint": "Maximum allowed heartbeat failures; the connection will be dropped if exceeded."
|
||||
},
|
||||
"kook_max_consecutive_failures": {
|
||||
"description": "Max Consecutive Failures",
|
||||
"type": "int",
|
||||
"hint": "Maximum allowed consecutive failures; retries will stop if exceeded."
|
||||
}
|
||||
},
|
||||
"general": {
|
||||
@@ -783,17 +721,6 @@
|
||||
"hint": "Telegram only supports a fixed reaction set, reference: [https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9](https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9)"
|
||||
}
|
||||
}
|
||||
},
|
||||
"discord": {
|
||||
"pre_ack_emoji": {
|
||||
"enable": {
|
||||
"description": "[Discord] Enable Pre-acknowledgment Emoji"
|
||||
},
|
||||
"emojis": {
|
||||
"description": "Emoji List (Unicode or Custom Emoji Name)",
|
||||
"hint": "Enter Unicode emoji symbols, e.g., 👍, 🤔, ⏳"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,9 +216,6 @@
|
||||
"enterUrl": "Enter extension repository URL"
|
||||
},
|
||||
"skills": {
|
||||
"modeLocal": "Local Skills",
|
||||
"modeNeo": "Neo Skills",
|
||||
"actions": "Actions",
|
||||
"upload": "Upload Skills",
|
||||
"refresh": "Refresh",
|
||||
"empty": "No Skills found",
|
||||
@@ -232,9 +229,6 @@
|
||||
"path": "Path",
|
||||
"uploadSuccess": "Upload succeeded",
|
||||
"uploadFailed": "Upload failed",
|
||||
"download": "Download",
|
||||
"downloadSuccess": "Download succeeded",
|
||||
"downloadFailed": "Download failed",
|
||||
"loadFailed": "Failed to load Skills",
|
||||
"updateSuccess": "Updated successfully",
|
||||
"updateFailed": "Update failed",
|
||||
@@ -242,42 +236,8 @@
|
||||
"deleteMessage": "Are you sure you want to delete this Skill?",
|
||||
"deleteSuccess": "Deleted successfully",
|
||||
"deleteFailed": "Delete failed",
|
||||
"neoSkillKey": "Filter by skill_key",
|
||||
"neoStatus": "Candidate Status",
|
||||
"neoStage": "Release Stage",
|
||||
"neoFilterHint": "Filter candidates and release records",
|
||||
"neoAll": "All",
|
||||
"neoCandidates": "Neo Candidates",
|
||||
"neoReleases": "Neo Releases",
|
||||
"neoLoadFailed": "Failed to load Neo skills data",
|
||||
"neoPass": "Pass",
|
||||
"neoReject": "Reject",
|
||||
"neoEvaluateSuccess": "Evaluation updated",
|
||||
"neoEvaluateFailed": "Failed to update evaluation",
|
||||
"neoPromoteSuccess": "Promoted successfully",
|
||||
"neoPromoteFailed": "Failed to promote",
|
||||
"neoRollback": "Rollback",
|
||||
"neoRollbackSuccess": "Rollback succeeded",
|
||||
"neoRollbackFailed": "Rollback failed",
|
||||
"neoDeactivate": "Deactivate",
|
||||
"neoDeactivateSuccess": "Deactivated successfully",
|
||||
"neoDeactivateFailed": "Failed to deactivate",
|
||||
"neoSync": "Sync",
|
||||
"neoSyncSuccess": "Sync succeeded",
|
||||
"neoSyncFailed": "Sync failed",
|
||||
"neoDelete": "Delete",
|
||||
"neoDeleteSuccess": "Deleted successfully",
|
||||
"neoDeleteFailed": "Failed to delete",
|
||||
"neoPayloadTitle": "Neo Payload",
|
||||
"neoPayloadFailed": "Failed to load payload",
|
||||
"runtimeNoneWarning": "Computer Use runtime is set to None; Skills may not run correctly because no runtime is enabled.",
|
||||
"runtimeHint": "Set the Computer Use runtime to Local or Sandbox in settings so AstrBot can use your Skills.",
|
||||
"neoRuntimeRequired": "Neo Skills are available only when runtime is sandbox and sandbox booter is shipyard_neo.",
|
||||
"sourceLocalOnly": "Local Skill",
|
||||
"sourceSandboxOnly": "Sandbox Preset Skill",
|
||||
"sourceBoth": "Local + Sandbox",
|
||||
"sandboxDiscoveryPending": "Sandbox preset skills have not been discovered yet. Start at least one sandbox session to populate this list.",
|
||||
"sandboxPresetReadonly": "Sandbox preset skills are read-only here. You cannot delete or enable/disable them from Local Skills."
|
||||
"runtimeHint": "Set the Computer Use runtime to Local or Sandbox in settings so AstrBot can use your Skills."
|
||||
},
|
||||
"card": {
|
||||
"actions": {
|
||||
|
||||
@@ -164,22 +164,6 @@
|
||||
"booter": {
|
||||
"description": "沙箱环境驱动器"
|
||||
},
|
||||
"shipyard_neo_endpoint": {
|
||||
"description": "Shipyard Neo API Endpoint",
|
||||
"hint": "Shipyard Neo(Bay) 服务的 API 地址,默认 http://127.0.0.1:8114。"
|
||||
},
|
||||
"shipyard_neo_access_token": {
|
||||
"description": "Shipyard Neo 访问令牌",
|
||||
"hint": "Bay 的 API Key(sk-bay-...)。留空时自动从 credentials.json 发现。"
|
||||
},
|
||||
"shipyard_neo_profile": {
|
||||
"description": "Shipyard Neo Profile",
|
||||
"hint": "Shipyard Neo 沙箱 profile,例如 python-default。"
|
||||
},
|
||||
"shipyard_neo_ttl": {
|
||||
"description": "Shipyard Neo Sandbox 存活时间(秒)",
|
||||
"hint": "Shipyard Neo 沙箱的生存时间(秒)。"
|
||||
},
|
||||
"shipyard_endpoint": {
|
||||
"description": "Shipyard API Endpoint",
|
||||
"hint": "Shipyard 服务的 API 访问地址。"
|
||||
@@ -373,8 +357,7 @@
|
||||
"hint": "可选的 Discord 活动名称。留空则不设置活动。"
|
||||
},
|
||||
"discord_command_register": {
|
||||
"description": "注册 Discord 指令",
|
||||
"hint": "启用后,自动将插件指令注册为 Discord 斜杠指令"
|
||||
"description": "是否自动将插件指令注册为 Discord 斜杠指令"
|
||||
},
|
||||
"discord_proxy": {
|
||||
"description": "Discord 代理地址",
|
||||
@@ -587,51 +570,6 @@
|
||||
"only_use_webhook_url_to_send": {
|
||||
"description": "仅使用 Webhook 发送消息",
|
||||
"hint": "启用后,企业微信智能机器人的所有回复都改为通过消息推送 Webhook 发送。消息推送 Webhook 支持更多的消息类型(如图片、文件等)。如果不需要打字机效果,强烈建议使用此选项。"
|
||||
},
|
||||
"kook_bot_token": {
|
||||
"description": "机器人 Token",
|
||||
"type": "string",
|
||||
"hint": "必填项。从 KOOK 开发者平台获取的机器人 Token"
|
||||
},
|
||||
"kook_bot_nickname": {
|
||||
"description": "Bot Nickname",
|
||||
"type": "string",
|
||||
"hint": "可选项。若发送者昵称与此值一致,将忽略该消息。"
|
||||
},
|
||||
"kook_reconnect_delay": {
|
||||
"description": "重连延迟",
|
||||
"type": "int",
|
||||
"hint": "重连延迟时间(秒),使用指数退避策略"
|
||||
},
|
||||
"kook_max_reconnect_delay": {
|
||||
"description": "最大重连延迟",
|
||||
"type": "int",
|
||||
"hint": "重连延迟的最大值(秒)"
|
||||
},
|
||||
"kook_max_retry_delay": {
|
||||
"description": "最大重试延迟",
|
||||
"type": "int",
|
||||
"hint": "重试的最大延迟时间(秒)"
|
||||
},
|
||||
"kook_heartbeat_interval": {
|
||||
"description": "心跳间隔",
|
||||
"type": "int",
|
||||
"hint": "心跳检测间隔时间(秒)"
|
||||
},
|
||||
"kook_heartbeat_timeout": {
|
||||
"description": "心跳超时时间",
|
||||
"type": "int",
|
||||
"hint": "心跳检测超时时间(秒)"
|
||||
},
|
||||
"kook_max_heartbeat_failures": {
|
||||
"description": "最大心跳失败次数",
|
||||
"type": "int",
|
||||
"hint": "允许的最大心跳失败次数,超过后断开连接"
|
||||
},
|
||||
"kook_max_consecutive_failures": {
|
||||
"description": "最大连续失败次数",
|
||||
"type": "int",
|
||||
"hint": "允许的最大连续失败次数,超过后停止重试"
|
||||
}
|
||||
},
|
||||
"general": {
|
||||
@@ -786,17 +724,6 @@
|
||||
"hint": "Telegram 仅支持固定反应集合,参考:[https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9](https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9)"
|
||||
}
|
||||
}
|
||||
},
|
||||
"discord": {
|
||||
"pre_ack_emoji": {
|
||||
"enable": {
|
||||
"description": "[Discord] 启用预回应表情"
|
||||
},
|
||||
"emojis": {
|
||||
"description": "表情列表(Unicode 或自定义表情名)",
|
||||
"hint": "填写 Unicode 表情符号,例如:👍、🤔、⏳"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,9 +216,6 @@
|
||||
"enterUrl": "输入插件仓库链接"
|
||||
},
|
||||
"skills": {
|
||||
"modeLocal": "本地 Skills",
|
||||
"modeNeo": "Neo Skills",
|
||||
"actions": "操作",
|
||||
"upload": "上传 Skills",
|
||||
"refresh": "刷新",
|
||||
"empty": "暂无 Skills",
|
||||
@@ -232,9 +229,6 @@
|
||||
"path": "路径",
|
||||
"uploadSuccess": "上传成功",
|
||||
"uploadFailed": "上传失败",
|
||||
"download": "下载",
|
||||
"downloadSuccess": "下载成功",
|
||||
"downloadFailed": "下载失败",
|
||||
"loadFailed": "加载 Skills 失败",
|
||||
"updateSuccess": "更新成功",
|
||||
"updateFailed": "更新失败",
|
||||
@@ -242,42 +236,8 @@
|
||||
"deleteMessage": "确定要删除该 Skill 吗?",
|
||||
"deleteSuccess": "删除成功",
|
||||
"deleteFailed": "删除失败",
|
||||
"neoSkillKey": "skill_key 过滤",
|
||||
"neoStatus": "候选状态",
|
||||
"neoStage": "发布阶段",
|
||||
"neoFilterHint": "筛选候选与发布记录",
|
||||
"neoAll": "全部",
|
||||
"neoCandidates": "Neo Candidates",
|
||||
"neoReleases": "Neo Releases",
|
||||
"neoLoadFailed": "加载 Neo Skills 数据失败",
|
||||
"neoPass": "通过",
|
||||
"neoReject": "拒绝",
|
||||
"neoEvaluateSuccess": "评测更新成功",
|
||||
"neoEvaluateFailed": "评测更新失败",
|
||||
"neoPromoteSuccess": "发布成功",
|
||||
"neoPromoteFailed": "发布失败",
|
||||
"neoRollback": "回滚",
|
||||
"neoRollbackSuccess": "回滚成功",
|
||||
"neoRollbackFailed": "回滚失败",
|
||||
"neoDeactivate": "失活",
|
||||
"neoDeactivateSuccess": "失活成功",
|
||||
"neoDeactivateFailed": "失活失败",
|
||||
"neoSync": "同步",
|
||||
"neoSyncSuccess": "同步成功",
|
||||
"neoSyncFailed": "同步失败",
|
||||
"neoDelete": "删除",
|
||||
"neoDeleteSuccess": "删除成功",
|
||||
"neoDeleteFailed": "删除失败",
|
||||
"neoPayloadTitle": "Neo Payload 详情",
|
||||
"neoPayloadFailed": "读取 Payload 失败",
|
||||
"runtimeNoneWarning": "Computer Use 运行环境为无,Skills 可能无法正确被 Agent 运行,因为没有启用运行环境。",
|
||||
"runtimeHint": "需要在配置的 “使用电脑能力” 中将运行环境设置为 “local” 或 “sandbox” 才能让 AstrBot 正常使用你提供的 Skills。",
|
||||
"neoRuntimeRequired": "Neo Skills 仅在运行环境为 sandbox 且沙箱驱动为 shipyard_neo 时可用。",
|
||||
"sourceLocalOnly": "本地 Skill",
|
||||
"sourceSandboxOnly": "Sandbox 预置 Skill",
|
||||
"sourceBoth": "本地 + Sandbox",
|
||||
"sandboxDiscoveryPending": "尚未发现 Sandbox 预置 Skill。请至少启动一次 Sandbox 会话后再查看。",
|
||||
"sandboxPresetReadonly": "Sandbox 预置 Skill 在此处为只读,无法在本地 Skills 页面删除或启用/禁用。"
|
||||
"runtimeHint": "需要在配置的 “使用电脑能力” 中将运行环境设置为 “local” 或 “sandbox” 才能让 AstrBot 正常使用你提供的 Skills。"
|
||||
},
|
||||
"card": {
|
||||
"actions": {
|
||||
|
||||
@@ -468,12 +468,6 @@ onMounted(async () => {
|
||||
<v-icon>mdi-menu</v-icon>
|
||||
</v-btn>
|
||||
|
||||
<!-- 移动端 chat sidebar 展开按钮 - 仅在 chat 模式下的小屏幕显示 -->
|
||||
<v-btn v-if="customizer.viewMode === 'chat'" class="hidden-lg-and-up ms-1" icon rounded="sm" variant="flat"
|
||||
@click.stop="customizer.TOGGLE_CHAT_SIDEBAR()">
|
||||
<v-icon>mdi-menu</v-icon>
|
||||
</v-btn>
|
||||
|
||||
<div class="logo-container" :class="{ 'mobile-logo': $vuetify.display.xs, 'chat-mode-logo': customizer.viewMode === 'chat' }" @click="handleLogoClick">
|
||||
<span class="logo-text Outfit">Astr<span class="logo-text bot-text-wrapper">Bot
|
||||
<img v-if="isChristmas" src="@/assets/images/xmas-hat.png" alt="Christmas hat" class="xmas-hat" />
|
||||
@@ -494,13 +488,13 @@ onMounted(async () => {
|
||||
</small>
|
||||
</div>
|
||||
|
||||
<!-- Bot/Chat 模式切换按钮 - 手机端隐藏,移入 ... 菜单 -->
|
||||
<!-- Bot/Chat 模式切换按钮 -->
|
||||
<v-btn-toggle
|
||||
v-model="viewMode"
|
||||
mandatory
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
class="mr-4 hidden-xs"
|
||||
class="mr-4"
|
||||
color="primary"
|
||||
>
|
||||
<v-btn value="bot" size="small">
|
||||
@@ -530,30 +524,6 @@ onMounted(async () => {
|
||||
</v-btn>
|
||||
</template>
|
||||
|
||||
<!-- Bot/Chat 模式切换 - 仅在手机端显示 -->
|
||||
<template v-if="$vuetify.display.xs">
|
||||
<div class="mobile-mode-toggle-wrapper">
|
||||
<v-btn-toggle
|
||||
v-model="viewMode"
|
||||
mandatory
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
color="primary"
|
||||
class="mobile-mode-toggle"
|
||||
>
|
||||
<v-btn value="bot" size="small">
|
||||
<v-icon start>mdi-robot</v-icon>
|
||||
Bot
|
||||
</v-btn>
|
||||
<v-btn value="chat" size="small">
|
||||
<v-icon start>mdi-chat</v-icon>
|
||||
Chat
|
||||
</v-btn>
|
||||
</v-btn-toggle>
|
||||
</div>
|
||||
<v-divider class="my-1" />
|
||||
</template>
|
||||
|
||||
<!-- 语言切换 -->
|
||||
<v-list-item
|
||||
v-for="lang in languages"
|
||||
@@ -918,10 +888,6 @@ onMounted(async () => {
|
||||
margin-left: 22px;
|
||||
}
|
||||
|
||||
.mobile-logo.chat-mode-logo {
|
||||
margin-left: 4px;
|
||||
}
|
||||
|
||||
.logo-text {
|
||||
font-size: 24px;
|
||||
font-weight: 1000;
|
||||
@@ -960,20 +926,6 @@ onMounted(async () => {
|
||||
margin-right: 8px;
|
||||
}
|
||||
|
||||
.mobile-mode-toggle-wrapper {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
padding: 8px 12px 4px;
|
||||
}
|
||||
|
||||
.mobile-mode-toggle {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.mobile-mode-toggle .v-btn {
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
/* 移动端对话框标题样式 */
|
||||
.mobile-card-title {
|
||||
display: flex;
|
||||
|
||||
@@ -38,7 +38,7 @@ const isItemActive = computed(() => {
|
||||
</template>
|
||||
|
||||
<!-- children -->
|
||||
<template v-for="(child, index) in item.children" :key="child.title || child.to || `child-${index}`">
|
||||
<template v-for="(child, index) in item.children" :key="index">
|
||||
<NavItem :item="child" :level="(level || 0) + 1" />
|
||||
</template>
|
||||
</v-list-group>
|
||||
|
||||
@@ -10,60 +10,26 @@ import ChangelogDialog from '@/components/shared/ChangelogDialog.vue';
|
||||
const { t, locale } = useI18n();
|
||||
|
||||
const customizer = useCustomizerStore();
|
||||
|
||||
function collectGroupValues(items, values = new Set()) {
|
||||
items.forEach((item) => {
|
||||
if (item?.children && item.title) {
|
||||
values.add(item.title);
|
||||
collectGroupValues(item.children, values);
|
||||
}
|
||||
});
|
||||
return values;
|
||||
}
|
||||
|
||||
function sanitizeOpenedItems(items, menuItems) {
|
||||
if (!Array.isArray(items)) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const groupValues = collectGroupValues(menuItems);
|
||||
return items.filter((item) => typeof item === 'string' && groupValues.has(item));
|
||||
}
|
||||
|
||||
function getInitialOpenedItems(menuItems) {
|
||||
try {
|
||||
const stored = JSON.parse(localStorage.getItem('sidebar_openedItems') || '[]');
|
||||
return sanitizeOpenedItems(stored, menuItems);
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
const sidebarMenu = shallowRef(applySidebarCustomization(sidebarItems));
|
||||
const sidebarMenu = shallowRef(sidebarItems);
|
||||
|
||||
// 侧边栏分组展开状态持久化
|
||||
const openedItems = ref(getInitialOpenedItems(sidebarMenu.value));
|
||||
watch(openedItems, (val) => {
|
||||
localStorage.setItem('sidebar_openedItems', JSON.stringify(sanitizeOpenedItems(val, sidebarMenu.value)));
|
||||
}, { deep: true });
|
||||
|
||||
function refreshSidebarMenu() {
|
||||
sidebarMenu.value = applySidebarCustomization(sidebarItems);
|
||||
openedItems.value = sanitizeOpenedItems(openedItems.value, sidebarMenu.value);
|
||||
}
|
||||
const openedItems = ref(JSON.parse(localStorage.getItem('sidebar_openedItems') || '[]'));
|
||||
watch(openedItems, (val) => localStorage.setItem('sidebar_openedItems', JSON.stringify(val)), { deep: true });
|
||||
|
||||
// Apply customization on mount and listen for storage changes
|
||||
const handleStorageChange = (e) => {
|
||||
if (e.key === 'astrbot_sidebar_customization') {
|
||||
refreshSidebarMenu();
|
||||
sidebarMenu.value = applySidebarCustomization(sidebarItems);
|
||||
}
|
||||
};
|
||||
|
||||
const handleCustomEvent = () => {
|
||||
refreshSidebarMenu();
|
||||
sidebarMenu.value = applySidebarCustomization(sidebarItems);
|
||||
};
|
||||
|
||||
onMounted(() => {
|
||||
sidebarMenu.value = applySidebarCustomization(sidebarItems);
|
||||
|
||||
window.addEventListener('storage', handleStorageChange);
|
||||
window.addEventListener('sidebar-customization-changed', handleCustomEvent);
|
||||
});
|
||||
@@ -289,7 +255,7 @@ function openChangelogDialog() {
|
||||
>
|
||||
<div class="sidebar-container">
|
||||
<v-list class="pa-4 listitem flex-grow-1" v-model:opened="openedItems" :open-strategy="'multiple'">
|
||||
<template v-for="(item, i) in sidebarMenu" :key="item.title || item.to || `sidebar-item-${i}`">
|
||||
<template v-for="(item, i) in sidebarMenu" :key="i">
|
||||
<NavItem :item="item" class="leftPadding" />
|
||||
</template>
|
||||
</v-list>
|
||||
|
||||
@@ -9,9 +9,12 @@ import '@/scss/style.scss';
|
||||
import VueApexCharts from 'vue3-apexcharts';
|
||||
|
||||
import print from 'vue3-print-nb';
|
||||
import { loader } from '@guolao/vue-monaco-editor'
|
||||
import { loader } from '@guolao/vue-monaco-editor';
|
||||
import * as monaco from 'monaco-editor';
|
||||
import axios from 'axios';
|
||||
|
||||
loader.config({ monaco });
|
||||
|
||||
// 初始化新的i18n系统,等待完成后再挂载应用
|
||||
setupI18n().then(() => {
|
||||
console.log('🌍 新i18n系统初始化完成');
|
||||
@@ -108,9 +111,3 @@ window.fetch = (input: RequestInfo | URL, init?: RequestInit) => {
|
||||
}
|
||||
return _origFetch(input, { ...init, headers });
|
||||
};
|
||||
|
||||
loader.config({
|
||||
paths: {
|
||||
vs: 'https://cdn.jsdelivr.net/npm/monaco-editor@0.54.0/min/vs',
|
||||
},
|
||||
})
|
||||
|
||||
@@ -15,7 +15,3 @@
|
||||
@import './components/VScrollbar';
|
||||
|
||||
@import './pages/dashboards';
|
||||
|
||||
html, body {
|
||||
overscroll-behavior-y: none;
|
||||
}
|
||||
|
||||
@@ -10,8 +10,7 @@ export const useCustomizerStore = defineStore({
|
||||
fontTheme: "Poppins",
|
||||
uiTheme: config.uiTheme,
|
||||
inputBg: config.inputBg,
|
||||
viewMode: (localStorage.getItem('viewMode') as 'bot' | 'chat') || 'bot', // 'bot' 或 'chat'
|
||||
chatSidebarOpen: false // chat mode mobile sidebar state
|
||||
viewMode: (localStorage.getItem('viewMode') as 'bot' | 'chat') || 'bot' // 'bot' 或 'chat'
|
||||
}),
|
||||
|
||||
getters: {},
|
||||
@@ -31,13 +30,7 @@ export const useCustomizerStore = defineStore({
|
||||
},
|
||||
SET_VIEW_MODE(payload: 'bot' | 'chat') {
|
||||
this.viewMode = payload;
|
||||
localStorage.setItem('viewMode', payload);
|
||||
},
|
||||
TOGGLE_CHAT_SIDEBAR() {
|
||||
this.chatSidebarOpen = !this.chatSidebarOpen;
|
||||
},
|
||||
SET_CHAT_SIDEBAR(payload: boolean) {
|
||||
this.chatSidebarOpen = payload;
|
||||
localStorage.setItem("viewMode", payload);
|
||||
},
|
||||
}
|
||||
});
|
||||
|
||||
@@ -46,22 +46,22 @@ export function getPlatformIcon(name) {
|
||||
*/
|
||||
export function getTutorialLink(platformType) {
|
||||
const tutorialMap = {
|
||||
"qq_official_webhook": "https://docs.astrbot.app/platform/qqofficial/webhook.html",
|
||||
"qq_official": "https://docs.astrbot.app/platform/qqofficial/websockets.html",
|
||||
"aiocqhttp": "https://docs.astrbot.app/platform/aiocqhttp/napcat.html",
|
||||
"wecom": "https://docs.astrbot.app/platform/wecom.html",
|
||||
"wecom_ai_bot": "https://docs.astrbot.app/platform/wecom_ai_bot.html",
|
||||
"lark": "https://docs.astrbot.app/platform/lark.html",
|
||||
"telegram": "https://docs.astrbot.app/platform/telegram.html",
|
||||
"dingtalk": "https://docs.astrbot.app/platform/dingtalk.html",
|
||||
"weixin_official_account": "https://docs.astrbot.app/platform/weixin-official-account.html",
|
||||
"discord": "https://docs.astrbot.app/platform/discord.html",
|
||||
"slack": "https://docs.astrbot.app/platform/slack.html",
|
||||
"kook": "https://docs.astrbot.app/platform/kook.html",
|
||||
"vocechat": "https://docs.astrbot.app/platform/vocechat.html",
|
||||
"satori": "https://docs.astrbot.app/platform/satori/llonebot.html",
|
||||
"misskey": "https://docs.astrbot.app/platform/misskey.html",
|
||||
"line": "https://docs.astrbot.app/platform/line.html",
|
||||
"qq_official_webhook": "https://docs.astrbot.app/deploy/platform/qqofficial/webhook.html",
|
||||
"qq_official": "https://docs.astrbot.app/deploy/platform/qqofficial/websockets.html",
|
||||
"aiocqhttp": "https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html",
|
||||
"wecom": "https://docs.astrbot.app/deploy/platform/wecom.html",
|
||||
"wecom_ai_bot": "https://docs.astrbot.app/deploy/platform/wecom_ai_bot.html",
|
||||
"lark": "https://docs.astrbot.app/deploy/platform/lark.html",
|
||||
"telegram": "https://docs.astrbot.app/deploy/platform/telegram.html",
|
||||
"dingtalk": "https://docs.astrbot.app/deploy/platform/dingtalk.html",
|
||||
"weixin_official_account": "https://docs.astrbot.app/deploy/platform/weixin-official-account.html",
|
||||
"discord": "https://docs.astrbot.app/deploy/platform/discord.html",
|
||||
"slack": "https://docs.astrbot.app/deploy/platform/slack.html",
|
||||
"kook": "https://docs.astrbot.app/deploy/platform/kook.html",
|
||||
"vocechat": "https://docs.astrbot.app/deploy/platform/vocechat.html",
|
||||
"satori": "https://docs.astrbot.app/deploy/platform/satori/llonebot.html",
|
||||
"misskey": "https://docs.astrbot.app/deploy/platform/misskey.html",
|
||||
"line": "https://docs.astrbot.app/deploy/platform/line.html",
|
||||
}
|
||||
return tutorialMap[platformType] || "https://docs.astrbot.app";
|
||||
}
|
||||
|
||||
@@ -52,21 +52,6 @@ export function clearSidebarCustomization() {
|
||||
export function resolveSidebarItems(defaultItems, customization, options = {}) {
|
||||
const { cloneItems = false, assembleMoreGroup = false } = options;
|
||||
|
||||
const normalizeKeys = (keys = []) => {
|
||||
const list = Array.isArray(keys) ? keys : [];
|
||||
const deduped = [];
|
||||
const seen = new Set();
|
||||
|
||||
list.forEach((key) => {
|
||||
if (typeof key !== 'string') return;
|
||||
if (seen.has(key)) return;
|
||||
seen.add(key);
|
||||
deduped.push(key);
|
||||
});
|
||||
|
||||
return deduped;
|
||||
};
|
||||
|
||||
const all = new Map();
|
||||
const defaultMain = [];
|
||||
const defaultMore = [];
|
||||
@@ -85,23 +70,9 @@ export function resolveSidebarItems(defaultItems, customization, options = {}) {
|
||||
});
|
||||
|
||||
const hasCustomization = Boolean(customization);
|
||||
let mainKeys = hasCustomization ? normalizeKeys(customization.mainItems || []) : [...defaultMain];
|
||||
let moreKeys = hasCustomization ? normalizeKeys(customization.moreItems || []) : [...defaultMore];
|
||||
|
||||
if (hasCustomization) {
|
||||
mainKeys = mainKeys.filter(title => all.has(title));
|
||||
moreKeys = moreKeys.filter(title => all.has(title));
|
||||
}
|
||||
|
||||
if (hasCustomization) {
|
||||
// 如果同一项同时出现在主区与更多区,主区优先。
|
||||
const mainSet = new Set(mainKeys);
|
||||
moreKeys = moreKeys.filter(title => !mainSet.has(title));
|
||||
}
|
||||
|
||||
const used = hasCustomization
|
||||
? new Set([...mainKeys, ...moreKeys])
|
||||
: new Set(defaultMain.concat(defaultMore));
|
||||
const mainKeys = hasCustomization ? customization.mainItems || [] : defaultMain;
|
||||
const moreKeys = hasCustomization ? customization.moreItems || [] : defaultMore;
|
||||
const used = hasCustomization ? new Set([...mainKeys, ...moreKeys]) : new Set(defaultMain.concat(defaultMore));
|
||||
|
||||
const mainItems = mainKeys
|
||||
.map(title => all.get(title))
|
||||
@@ -148,13 +119,7 @@ export function resolveSidebarItems(defaultItems, customization, options = {}) {
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
mainItems,
|
||||
moreItems,
|
||||
merged,
|
||||
normalizedMainKeys: [...mainKeys],
|
||||
normalizedMoreKeys: [...moreKeys]
|
||||
};
|
||||
return { mainItems, moreItems, merged };
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -164,29 +129,9 @@ export function resolveSidebarItems(defaultItems, customization, options = {}) {
|
||||
*/
|
||||
export function applySidebarCustomization(defaultItems) {
|
||||
const customization = getSidebarCustomization();
|
||||
const {
|
||||
merged,
|
||||
normalizedMainKeys,
|
||||
normalizedMoreKeys
|
||||
} = resolveSidebarItems(defaultItems, customization, {
|
||||
const { merged } = resolveSidebarItems(defaultItems, customization, {
|
||||
cloneItems: true,
|
||||
assembleMoreGroup: true
|
||||
});
|
||||
|
||||
if (customization) {
|
||||
const rawMainKeys = Array.isArray(customization.mainItems) ? customization.mainItems : [];
|
||||
const rawMoreKeys = Array.isArray(customization.moreItems) ? customization.moreItems : [];
|
||||
const hasChanged =
|
||||
JSON.stringify(rawMainKeys) !== JSON.stringify(normalizedMainKeys) ||
|
||||
JSON.stringify(rawMoreKeys) !== JSON.stringify(normalizedMoreKeys);
|
||||
|
||||
if (hasChanged) {
|
||||
setSidebarCustomization({
|
||||
mainItems: normalizedMainKeys,
|
||||
moreItems: normalizedMoreKeys
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return merged || defaultItems;
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ import { fileURLToPath, URL } from 'url';
|
||||
import { defineConfig } from 'vite';
|
||||
import vue from '@vitejs/plugin-vue';
|
||||
import vuetify from 'vite-plugin-vuetify';
|
||||
import monacoEditorPlugin from 'vite-plugin-monaco-editor';
|
||||
|
||||
// https://vitejs.dev/config/
|
||||
export default defineConfig({
|
||||
@@ -15,7 +16,8 @@ export default defineConfig({
|
||||
}),
|
||||
vuetify({
|
||||
autoImport: true
|
||||
})
|
||||
}),
|
||||
monacoEditorPlugin({})
|
||||
],
|
||||
resolve: {
|
||||
alias: {
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
schema: spec-driven
|
||||
|
||||
# Project context (optional)
|
||||
# This is shown to AI when creating artifacts.
|
||||
# Add your tech stack, conventions, style guides, domain knowledge, etc.
|
||||
# Example:
|
||||
# context: |
|
||||
# Tech stack: TypeScript, React, Node.js
|
||||
# We use conventional commits
|
||||
# Domain: e-commerce platform
|
||||
|
||||
# Per-artifact rules (optional)
|
||||
# Add custom rules for specific artifacts.
|
||||
# Example:
|
||||
# rules:
|
||||
# proposal:
|
||||
# - Keep proposals under 500 words
|
||||
# - Always include a "Non-goals" section
|
||||
# tasks:
|
||||
# - Break tasks into chunks of max 2 hours
|
||||
@@ -61,7 +61,6 @@ dependencies = [
|
||||
"xinference-client",
|
||||
"tenacity>=9.1.2",
|
||||
"shipyard-python-sdk>=0.2.4",
|
||||
"shipyard-neo-sdk @ git+https://github.com/AstrBotDevs/shipyard-neo.git#subdirectory=shipyard-neo-sdk",
|
||||
"python-socks>=2.8.0",
|
||||
"packaging>=24.2",
|
||||
]
|
||||
@@ -111,17 +110,6 @@ reportMissingImports = false
|
||||
include = ["astrbot"]
|
||||
exclude = ["dashboard", "node_modules", "dist", "data", "tests"]
|
||||
|
||||
[tool.hatch.metadata]
|
||||
allow-direct-references = true
|
||||
|
||||
# Include bundled dashboard dist even though it is not tracked by VCS.
|
||||
[tool.hatch.build.targets.wheel]
|
||||
artifacts = ["astrbot/dashboard/dist/**"]
|
||||
|
||||
# Custom build hook: builds the Vue dashboard and copies dist into the package.
|
||||
[tool.hatch.build.hooks.custom]
|
||||
path = "scripts/hatch_build.py"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
@@ -54,5 +54,4 @@ markitdown-no-magika[docx,xls,xlsx]>=0.1.2
|
||||
xinference-client
|
||||
tenacity>=9.1.2
|
||||
shipyard-python-sdk>=0.2.4
|
||||
shipyard-neo-sdk @ git+https://github.com/AstrBotDevs/shipyard-neo.git#subdirectory=shipyard-neo-sdk
|
||||
packaging>=24.2
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
"""
|
||||
Custom Hatchling build hook.
|
||||
|
||||
During `hatch build` (or `pip wheel`), this hook:
|
||||
1. Runs `npm run build` inside the `dashboard/` directory.
|
||||
2. Copies the resulting `dashboard/dist/` tree into
|
||||
`astrbot/dashboard/dist/` so the static assets are shipped
|
||||
inside the Python wheel.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from hatchling.builders.hooks.plugin.interface import BuildHookInterface
|
||||
|
||||
|
||||
class CustomBuildHook(BuildHookInterface):
|
||||
PLUGIN_NAME = "custom"
|
||||
|
||||
def initialize(self, version: str, build_data: dict) -> None:
|
||||
root = Path(self.root)
|
||||
dashboard_src = root / "dashboard"
|
||||
dist_src = dashboard_src / "dist"
|
||||
dist_target = root / "astrbot" / "dashboard" / "dist"
|
||||
|
||||
if not dashboard_src.exists():
|
||||
print(
|
||||
"[hatch_build] 'dashboard/' directory not found – skipping dashboard build.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return
|
||||
|
||||
# ── Install Node dependencies if node_modules is absent ─────────────
|
||||
if not (dashboard_src / "node_modules").exists():
|
||||
print("[hatch_build] Installing dashboard Node dependencies...")
|
||||
subprocess.run(
|
||||
["npm", "install"],
|
||||
cwd=dashboard_src,
|
||||
check=True,
|
||||
)
|
||||
|
||||
# ── Build the Vue/Vite dashboard ──────────────────────────────────────
|
||||
print("[hatch_build] Building Vue dashboard (npm run build)...")
|
||||
subprocess.run(
|
||||
["npm", "run", "build"],
|
||||
cwd=dashboard_src,
|
||||
check=True,
|
||||
)
|
||||
|
||||
if not dist_src.exists():
|
||||
print(
|
||||
"[hatch_build] dashboard/dist not found after build – skipping copy.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return
|
||||
|
||||
# ── Copy into the Python package tree ────────────────────────────────
|
||||
if dist_target.exists():
|
||||
shutil.rmtree(dist_target)
|
||||
shutil.copytree(dist_src, dist_target)
|
||||
print(f"[hatch_build] Dashboard dist copied → {dist_target.relative_to(root)}")
|
||||
@@ -1,171 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
ROOT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
cd "$ROOT_DIR"
|
||||
|
||||
PROFILE="neo"
|
||||
RUN_SYNC=true
|
||||
RUN_LINT=true
|
||||
RUN_SMOKE=true
|
||||
RUN_DASHBOARD=false
|
||||
|
||||
usage() {
|
||||
cat <<'EOF'
|
||||
Usage:
|
||||
scripts/pr_test_env.sh [options]
|
||||
|
||||
Options:
|
||||
--profile <neo|full> Test profile. Default: neo
|
||||
--with-dashboard Build dashboard before finishing checks
|
||||
--no-dashboard Disable dashboard build (even for full profile)
|
||||
--skip-sync Skip `uv sync`
|
||||
--skip-lint Skip `ruff format --check` and `ruff check`
|
||||
--skip-smoke Skip startup smoke test
|
||||
-h, --help Show this help message
|
||||
|
||||
Environment:
|
||||
PYTEST_ARGS Extra args appended to pytest command
|
||||
EOF
|
||||
}
|
||||
|
||||
while (($# > 0)); do
|
||||
case "$1" in
|
||||
--profile)
|
||||
PROFILE="${2:-}"
|
||||
if [[ "$PROFILE" != "neo" && "$PROFILE" != "full" ]]; then
|
||||
echo "Unsupported profile: $PROFILE" >&2
|
||||
exit 1
|
||||
fi
|
||||
shift 2
|
||||
;;
|
||||
--with-dashboard)
|
||||
RUN_DASHBOARD=true
|
||||
shift
|
||||
;;
|
||||
--skip-sync)
|
||||
RUN_SYNC=false
|
||||
shift
|
||||
;;
|
||||
--skip-lint)
|
||||
RUN_LINT=false
|
||||
shift
|
||||
;;
|
||||
--skip-smoke)
|
||||
RUN_SMOKE=false
|
||||
shift
|
||||
;;
|
||||
--no-dashboard)
|
||||
RUN_DASHBOARD=false
|
||||
shift
|
||||
;;
|
||||
-h | --help)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1" >&2
|
||||
usage
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [[ "$PROFILE" == "full" && "$RUN_DASHBOARD" == false ]]; then
|
||||
RUN_DASHBOARD=true
|
||||
fi
|
||||
|
||||
echo "==> Profile: $PROFILE"
|
||||
echo "==> Sync dependencies: $RUN_SYNC"
|
||||
echo "==> Run lint: $RUN_LINT"
|
||||
echo "==> Run smoke test: $RUN_SMOKE"
|
||||
echo "==> Build dashboard: $RUN_DASHBOARD"
|
||||
|
||||
if [[ "$RUN_SYNC" == true ]]; then
|
||||
echo "==> Syncing dependencies with uv"
|
||||
uv sync --group dev
|
||||
fi
|
||||
|
||||
echo "==> Preparing test directories"
|
||||
mkdir -p data/plugins data/config data/temp data/skills
|
||||
export TESTING="${TESTING:-true}"
|
||||
export ZHIPU_API_KEY="${ZHIPU_API_KEY:-test-api-key}"
|
||||
|
||||
if [[ "$RUN_LINT" == true ]]; then
|
||||
echo "==> Running Ruff format check"
|
||||
uv run ruff format --check .
|
||||
echo "==> Running Ruff lint check"
|
||||
uv run ruff check .
|
||||
fi
|
||||
|
||||
echo "==> Running pytest"
|
||||
if [[ "$PROFILE" == "neo" ]]; then
|
||||
NEO_TESTS=(
|
||||
"tests/test_neo_skill_sync.py"
|
||||
"tests/test_neo_skill_tools.py"
|
||||
"tests/test_computer_skill_sync.py"
|
||||
"tests/test_skill_manager_sandbox_cache.py"
|
||||
"tests/test_dashboard.py::test_neo_skills_routes"
|
||||
)
|
||||
uv run pytest -q "${NEO_TESTS[@]}" ${PYTEST_ARGS:-}
|
||||
else
|
||||
uv run pytest --cov=. -v -o log_cli=true -o log_level=DEBUG ${PYTEST_ARGS:-}
|
||||
fi
|
||||
|
||||
run_smoke_test() {
|
||||
if ! command -v curl >/dev/null 2>&1; then
|
||||
echo "curl is required for smoke test." >&2
|
||||
return 1
|
||||
fi
|
||||
|
||||
local smoke_port="6185"
|
||||
local smoke_log
|
||||
smoke_log="$(mktemp -t astrbot-smoke.XXXXXX.log)"
|
||||
|
||||
echo "==> Starting smoke test on http://localhost:${smoke_port}"
|
||||
uv run main.py >"$smoke_log" 2>&1 &
|
||||
local app_pid=$!
|
||||
|
||||
for _ in $(seq 1 60); do
|
||||
if curl -sf "http://localhost:${smoke_port}" >/dev/null 2>&1; then
|
||||
echo "==> Smoke test passed"
|
||||
kill "$app_pid" 2>/dev/null || true
|
||||
wait "$app_pid" 2>/dev/null || true
|
||||
rm -f "$smoke_log"
|
||||
return 0
|
||||
fi
|
||||
|
||||
if ! kill -0 "$app_pid" 2>/dev/null; then
|
||||
echo "AstrBot process exited before becoming healthy." >&2
|
||||
tail -n 60 "$smoke_log" || true
|
||||
rm -f "$smoke_log"
|
||||
return 1
|
||||
fi
|
||||
|
||||
sleep 1
|
||||
done
|
||||
|
||||
echo "Smoke test failed: health endpoint did not become ready in time." >&2
|
||||
tail -n 60 "$smoke_log" || true
|
||||
kill "$app_pid" 2>/dev/null || true
|
||||
wait "$app_pid" 2>/dev/null || true
|
||||
rm -f "$smoke_log"
|
||||
return 1
|
||||
}
|
||||
|
||||
if [[ "$RUN_SMOKE" == true ]]; then
|
||||
run_smoke_test
|
||||
fi
|
||||
|
||||
if [[ "$RUN_DASHBOARD" == true ]]; then
|
||||
if ! command -v pnpm >/dev/null 2>&1; then
|
||||
echo "pnpm is required for dashboard build. Install it with: npm install -g pnpm" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "==> Building dashboard"
|
||||
pnpm --dir dashboard install --frozen-lockfile
|
||||
pnpm --dir dashboard run build
|
||||
fi
|
||||
|
||||
echo "==> PR checks completed successfully"
|
||||
@@ -1,298 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# start-with-neo.sh — 一键启动 Shipyard Neo Bay + AstrBot
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/start-with-neo.sh # 默认 Bay :8114
|
||||
# BAY_PORT=9000 bash scripts/start-with-neo.sh # 自定义端口
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
set -euo pipefail
|
||||
|
||||
# ── 路径 ──────────────────────────────────────────────────────
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
ASTRBOT_DIR="$(cd "$SCRIPT_DIR/.." && pwd)"
|
||||
# shipyard-neo mono-repo root is one level above AstrBot
|
||||
NEO_ROOT="$(cd "$ASTRBOT_DIR/.." && pwd)"
|
||||
BAY_DIR="$NEO_ROOT/pkgs/bay"
|
||||
|
||||
BAY_PORT="${BAY_PORT:-8114}"
|
||||
BAY_HOST="0.0.0.0"
|
||||
BAY_PID=""
|
||||
BAY_API_KEY="" # Populated after Bay starts from credentials.json
|
||||
|
||||
# ── 颜色 ──────────────────────────────────────────────────────
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
CYAN='\033[0;36m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
log() { echo -e "${CYAN}[neo]${NC} $*"; }
|
||||
ok() { echo -e "${GREEN}[neo]${NC} $*"; }
|
||||
warn() { echo -e "${YELLOW}[neo]${NC} $*"; }
|
||||
err() { echo -e "${RED}[neo]${NC} $*" >&2; }
|
||||
|
||||
# ── 清理函数 ──────────────────────────────────────────────────
|
||||
cleanup() {
|
||||
log "Shutting down..."
|
||||
if [[ -n "$BAY_PID" ]] && kill -0 "$BAY_PID" 2>/dev/null; then
|
||||
log "Stopping Bay (PID $BAY_PID)..."
|
||||
kill "$BAY_PID" 2>/dev/null || true
|
||||
wait "$BAY_PID" 2>/dev/null || true
|
||||
fi
|
||||
ok "All services stopped."
|
||||
}
|
||||
trap cleanup EXIT INT TERM
|
||||
|
||||
# ── 检查前置条件 ──────────────────────────────────────────────
|
||||
check_prerequisites() {
|
||||
log "Checking prerequisites..."
|
||||
|
||||
if [[ ! -d "$BAY_DIR" ]]; then
|
||||
err "Bay directory not found: $BAY_DIR"
|
||||
err "Expected shipyard-neo mono-repo at: $NEO_ROOT"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! command -v uv &>/dev/null; then
|
||||
err "'uv' is not installed. Please install it first."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check Docker access (try without sudo first, then with sudo)
|
||||
if docker info &>/dev/null 2>&1; then
|
||||
ok "Docker is accessible."
|
||||
elif sudo docker info &>/dev/null 2>&1; then
|
||||
warn "Docker requires sudo. Bay may need socket permissions."
|
||||
warn "If Bay fails to connect to Docker, run: sudo chmod 666 /var/run/docker.sock"
|
||||
else
|
||||
err "Docker is not accessible. Please install Docker or fix permissions."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check Bay venv
|
||||
if [[ ! -d "$BAY_DIR/.venv" ]]; then
|
||||
log "Bay venv not found. Running 'uv sync' in $BAY_DIR ..."
|
||||
(cd "$BAY_DIR" && uv sync)
|
||||
fi
|
||||
|
||||
ok "Prerequisites OK."
|
||||
}
|
||||
|
||||
# ── 生成 Bay config.yaml(如不存在)────────────────────────────
|
||||
ensure_bay_config() {
|
||||
local config_file="$BAY_DIR/config.yaml"
|
||||
|
||||
if [[ -f "$config_file" ]]; then
|
||||
ok "Bay config.yaml already exists."
|
||||
return
|
||||
fi
|
||||
|
||||
log "Generating Bay config.yaml for local development..."
|
||||
|
||||
cat > "$config_file" << 'BAYCONFIG'
|
||||
# Bay Local Development Config (auto-generated by start-with-neo.sh)
|
||||
# For full reference see config.yaml.example
|
||||
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8114
|
||||
|
||||
database:
|
||||
url: "sqlite+aiosqlite:///./bay.db"
|
||||
echo: false
|
||||
|
||||
driver:
|
||||
type: docker
|
||||
image_pull_policy: if_not_present
|
||||
docker:
|
||||
socket: "unix:///var/run/docker.sock"
|
||||
connect_mode: host_port
|
||||
host_address: "127.0.0.1"
|
||||
publish_ports: true
|
||||
host_port: null
|
||||
network: null
|
||||
|
||||
cargo:
|
||||
root_path: "/var/lib/bay/cargos"
|
||||
default_size_limit_mb: 1024
|
||||
mount_path: "/workspace"
|
||||
|
||||
# Security: auto-provision mode
|
||||
# Bay generates sk-bay-* key on first boot → credentials.json
|
||||
security:
|
||||
allow_anonymous: false
|
||||
|
||||
profiles:
|
||||
- id: python-default
|
||||
description: "Standard Python sandbox"
|
||||
image: "ghcr.io/astrbotdevs/shipyard-neo-ship:latest"
|
||||
runtime_type: ship
|
||||
runtime_port: 8123
|
||||
resources:
|
||||
cpus: 1.0
|
||||
memory: "1g"
|
||||
capabilities:
|
||||
- filesystem
|
||||
- shell
|
||||
- python
|
||||
idle_timeout: 1800
|
||||
env: {}
|
||||
|
||||
gc:
|
||||
enabled: true
|
||||
run_on_startup: true
|
||||
interval_seconds: 300
|
||||
idle_session:
|
||||
enabled: true
|
||||
expired_sandbox:
|
||||
enabled: true
|
||||
orphan_cargo:
|
||||
enabled: true
|
||||
orphan_container:
|
||||
enabled: false
|
||||
BAYCONFIG
|
||||
|
||||
ok "Bay config.yaml created at $config_file"
|
||||
}
|
||||
|
||||
# ── 拉取 Ship 镜像 ───────────────────────────────────────────
|
||||
ensure_ship_image() {
|
||||
local image="ghcr.io/astrbotdevs/shipyard-neo-ship:latest"
|
||||
log "Checking Ship image: $image ..."
|
||||
|
||||
if docker image inspect "$image" &>/dev/null 2>&1 || \
|
||||
sudo docker image inspect "$image" &>/dev/null 2>&1; then
|
||||
ok "Ship image is available locally."
|
||||
else
|
||||
log "Pulling Ship image (this may take a while)..."
|
||||
if docker pull "$image" 2>/dev/null || sudo docker pull "$image" 2>/dev/null; then
|
||||
ok "Ship image pulled successfully."
|
||||
else
|
||||
warn "Failed to pull Ship image. Bay will try to pull it on first sandbox creation."
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
# ── 启动 Bay ──────────────────────────────────────────────────
|
||||
start_bay() {
|
||||
log "Starting Bay on :$BAY_PORT ..."
|
||||
|
||||
(cd "$BAY_DIR" && BAY_DATA_DIR="$BAY_DIR" uv run uvicorn app.main:app \
|
||||
--host "$BAY_HOST" \
|
||||
--port "$BAY_PORT" \
|
||||
--reload \
|
||||
2>&1 | sed "s/^/ ${CYAN}[bay]${NC} /") &
|
||||
BAY_PID=$!
|
||||
|
||||
log "Bay started (PID $BAY_PID), waiting for health check..."
|
||||
|
||||
# Wait for Bay to become healthy
|
||||
local max_wait=30
|
||||
local waited=0
|
||||
while [[ $waited -lt $max_wait ]]; do
|
||||
if curl -sf "http://127.0.0.1:$BAY_PORT/health" &>/dev/null; then
|
||||
ok "Bay is healthy at http://127.0.0.1:$BAY_PORT"
|
||||
return
|
||||
fi
|
||||
# Check if process is still alive
|
||||
if ! kill -0 "$BAY_PID" 2>/dev/null; then
|
||||
err "Bay process died unexpectedly. Check the output above."
|
||||
exit 1
|
||||
fi
|
||||
sleep 1
|
||||
waited=$((waited + 1))
|
||||
done
|
||||
|
||||
err "Bay did not become healthy within ${max_wait}s."
|
||||
err "It may still be starting — check http://127.0.0.1:$BAY_PORT/health"
|
||||
}
|
||||
|
||||
# ── 读取 Bay 自动生成的凭证 ───────────────────────────────────
|
||||
read_bay_credentials() {
|
||||
local cred_file="$BAY_DIR/credentials.json"
|
||||
|
||||
# Wait briefly for credentials.json to appear (Bay writes it during startup)
|
||||
local max_wait=5
|
||||
local waited=0
|
||||
while [[ $waited -lt $max_wait ]]; do
|
||||
if [[ -f "$cred_file" ]]; then
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
waited=$((waited + 1))
|
||||
done
|
||||
|
||||
if [[ -f "$cred_file" ]]; then
|
||||
# Extract api_key using python (always available) — no jq dependency
|
||||
BAY_API_KEY=$(python3 -c "
|
||||
import json, sys
|
||||
try:
|
||||
d = json.load(open('$cred_file'))
|
||||
print(d.get('api_key', ''))
|
||||
except Exception:
|
||||
print('')
|
||||
" 2>/dev/null || echo "")
|
||||
|
||||
if [[ -n "$BAY_API_KEY" ]]; then
|
||||
ok "Auto-provisioned API key loaded from credentials.json"
|
||||
else
|
||||
warn "credentials.json found but api_key is empty"
|
||||
fi
|
||||
else
|
||||
warn "credentials.json not found — Bay may be using an existing key or anonymous mode"
|
||||
warn "Check Bay logs above for the API key, or look at: $cred_file"
|
||||
fi
|
||||
}
|
||||
|
||||
# ── 打印 AstrBot 配置提示 ────────────────────────────────────
|
||||
print_astrbot_config_hint() {
|
||||
echo ""
|
||||
echo -e "${GREEN}════════════════════════════════════════════════════════════${NC}"
|
||||
echo -e "${GREEN} Shipyard Neo Bay is running at http://127.0.0.1:$BAY_PORT ${NC}"
|
||||
echo -e "${GREEN}════════════════════════════════════════════════════════════${NC}"
|
||||
echo ""
|
||||
if [[ -n "$BAY_API_KEY" ]]; then
|
||||
echo -e " ${CYAN}Bay API Key (auto-generated):${NC}"
|
||||
echo -e " ${YELLOW}$BAY_API_KEY${NC}"
|
||||
echo ""
|
||||
fi
|
||||
echo -e " ${CYAN}AstrBot Dashboard 配置指引:${NC}"
|
||||
echo -e " 1. AI 配置 → Agent Computer Use"
|
||||
echo -e " • Computer Use Runtime → ${YELLOW}沙箱${NC}"
|
||||
echo -e " • 沙箱环境驱动器 → ${YELLOW}Shipyard Neo${NC}"
|
||||
echo -e " • Shipyard Neo API Endpoint → ${YELLOW}http://127.0.0.1:$BAY_PORT${NC}"
|
||||
if [[ -n "$BAY_API_KEY" ]]; then
|
||||
echo -e " • Shipyard Neo Access Token → ${YELLOW}$BAY_API_KEY${NC}"
|
||||
else
|
||||
echo -e " • Shipyard Neo Access Token → ${YELLOW}(查看 Bay 日志获取 key)${NC}"
|
||||
fi
|
||||
echo -e " • Shipyard Neo Profile → ${YELLOW}python-default${NC}"
|
||||
echo ""
|
||||
}
|
||||
|
||||
# ── 启动 AstrBot ──────────────────────────────────────────────
|
||||
start_astrbot() {
|
||||
log "Starting AstrBot..."
|
||||
cd "$ASTRBOT_DIR"
|
||||
uv run main.py
|
||||
}
|
||||
|
||||
# ── 主流程 ────────────────────────────────────────────────────
|
||||
main() {
|
||||
echo ""
|
||||
echo -e "${CYAN}╔══════════════════════════════════════════╗${NC}"
|
||||
echo -e "${CYAN}║ Shipyard Neo + AstrBot Quick Start ║${NC}"
|
||||
echo -e "${CYAN}╚══════════════════════════════════════════╝${NC}"
|
||||
echo ""
|
||||
|
||||
check_prerequisites
|
||||
ensure_bay_config
|
||||
ensure_ship_image
|
||||
start_bay
|
||||
read_bay_credentials
|
||||
print_astrbot_config_hint
|
||||
start_astrbot
|
||||
}
|
||||
|
||||
main "$@"
|
||||
@@ -52,6 +52,18 @@ class TestContextTruncator:
|
||||
assert len(result) == 3
|
||||
assert result == messages
|
||||
|
||||
def test_fix_messages_tool_with_valid_context(self):
|
||||
"""Test fix_messages with tool message after user+assistant."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 3
|
||||
assert result == messages
|
||||
|
||||
def test_fix_messages_tool_without_context(self):
|
||||
"""Test fix_messages with tool message without enough context."""
|
||||
truncator = ContextTruncator()
|
||||
@@ -62,6 +74,43 @@ class TestContextTruncator:
|
||||
# Tool message without context should be removed
|
||||
assert len(result) == 0
|
||||
|
||||
def test_fix_messages_tool_with_only_one_message(self):
|
||||
"""Test fix_messages with tool message after only one message."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Hello"),
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
# Tool message without enough context should be removed
|
||||
assert len(result) == 0
|
||||
|
||||
def test_fix_messages_multiple_tools(self):
|
||||
"""Test fix_messages with multiple tool messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool 1 result"),
|
||||
self.create_message("tool", "Tool 2 result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 4
|
||||
assert result == messages
|
||||
|
||||
def test_fix_messages_mixed_system_tool(self):
|
||||
"""Test fix_messages with system message and tool messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("system", "System prompt"),
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 4
|
||||
assert result == messages
|
||||
|
||||
# ==================== truncate_by_turns Tests ====================
|
||||
|
||||
def test_truncate_by_turns_no_limit(self):
|
||||
|
||||
@@ -1,325 +0,0 @@
|
||||
"""Tests for _discover_bay_credentials() auto-discovery and _log_computer_config_changes()."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.computer.computer_client import _discover_bay_credentials
|
||||
from astrbot.dashboard.routes.config import _log_computer_config_changes
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# _discover_bay_credentials
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestDiscoverBayCredentials:
|
||||
"""Test Bay API key auto-discovery from credentials.json."""
|
||||
|
||||
def _write_creds(
|
||||
self,
|
||||
path: Path,
|
||||
api_key: str = "sk-bay-abc123",
|
||||
endpoint: str = "http://127.0.0.1:8114",
|
||||
) -> None:
|
||||
"""Helper: write a credentials.json file."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"api_key": api_key,
|
||||
"endpoint": endpoint,
|
||||
"generated_at": "2026-02-17T00:00:00+00:00",
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
def test_discover_from_bay_data_dir_env(
|
||||
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""BAY_DATA_DIR env var takes highest priority."""
|
||||
data_dir = tmp_path / "bay_data"
|
||||
cred_file = data_dir / "credentials.json"
|
||||
self._write_creds(cred_file, api_key="sk-bay-from-env-dir")
|
||||
monkeypatch.setenv("BAY_DATA_DIR", str(data_dir))
|
||||
|
||||
result = _discover_bay_credentials("http://127.0.0.1:8114")
|
||||
assert result == "sk-bay-from-env-dir"
|
||||
|
||||
def test_discover_from_cwd(
|
||||
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Falls back to current working directory."""
|
||||
cred_file = tmp_path / "credentials.json"
|
||||
self._write_creds(cred_file, api_key="sk-bay-from-cwd")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
monkeypatch.delenv("BAY_DATA_DIR", raising=False)
|
||||
|
||||
result = _discover_bay_credentials("http://127.0.0.1:8114")
|
||||
assert result == "sk-bay-from-cwd"
|
||||
|
||||
def test_returns_empty_when_no_credentials_found(
|
||||
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Returns empty string when no credentials.json exists anywhere."""
|
||||
monkeypatch.chdir(tmp_path)
|
||||
monkeypatch.delenv("BAY_DATA_DIR", raising=False)
|
||||
|
||||
result = _discover_bay_credentials("http://127.0.0.1:8114")
|
||||
assert result == ""
|
||||
|
||||
def test_skips_empty_api_key(
|
||||
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Skips credentials.json when api_key is empty."""
|
||||
cred_file = tmp_path / "credentials.json"
|
||||
self._write_creds(cred_file, api_key="")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
monkeypatch.delenv("BAY_DATA_DIR", raising=False)
|
||||
|
||||
result = _discover_bay_credentials("http://127.0.0.1:8114")
|
||||
assert result == ""
|
||||
|
||||
def test_skips_malformed_json(
|
||||
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Handles malformed JSON gracefully."""
|
||||
cred_file = tmp_path / "credentials.json"
|
||||
cred_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
cred_file.write_text("not valid json {{{")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
monkeypatch.delenv("BAY_DATA_DIR", raising=False)
|
||||
|
||||
result = _discover_bay_credentials("http://127.0.0.1:8114")
|
||||
assert result == ""
|
||||
|
||||
@patch("astrbot.core.computer.computer_client.logger")
|
||||
def test_endpoint_mismatch_still_returns_key(
|
||||
self, mock_logger, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Returns key even if endpoint doesn't match, but logs a warning."""
|
||||
data_dir = tmp_path / "bay_data"
|
||||
cred_file = data_dir / "credentials.json"
|
||||
self._write_creds(
|
||||
cred_file, api_key="sk-bay-mismatch", endpoint="http://other-host:9000"
|
||||
)
|
||||
monkeypatch.setenv("BAY_DATA_DIR", str(data_dir))
|
||||
|
||||
result = _discover_bay_credentials("http://127.0.0.1:8114")
|
||||
|
||||
assert result == "sk-bay-mismatch"
|
||||
mock_logger.warning.assert_called_once()
|
||||
warning_msg = mock_logger.warning.call_args[0][0]
|
||||
assert "endpoint mismatch" in warning_msg
|
||||
|
||||
def test_endpoint_match_no_warning(
|
||||
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""No warning when endpoints match."""
|
||||
data_dir = tmp_path / "bay_data"
|
||||
cred_file = data_dir / "credentials.json"
|
||||
self._write_creds(
|
||||
cred_file, api_key="sk-bay-match", endpoint="http://127.0.0.1:8114"
|
||||
)
|
||||
monkeypatch.setenv("BAY_DATA_DIR", str(data_dir))
|
||||
|
||||
with patch("astrbot.core.computer.computer_client.logger") as mock_logger:
|
||||
result = _discover_bay_credentials("http://127.0.0.1:8114")
|
||||
|
||||
assert result == "sk-bay-match"
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
def test_bay_data_dir_priority_over_cwd(
|
||||
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""BAY_DATA_DIR takes priority over cwd."""
|
||||
env_dir = tmp_path / "env_dir"
|
||||
cwd_dir = tmp_path / "cwd_dir"
|
||||
self._write_creds(env_dir / "credentials.json", api_key="sk-bay-env-wins")
|
||||
self._write_creds(cwd_dir / "credentials.json", api_key="sk-bay-cwd-loses")
|
||||
monkeypatch.setenv("BAY_DATA_DIR", str(env_dir))
|
||||
monkeypatch.chdir(cwd_dir)
|
||||
|
||||
result = _discover_bay_credentials("http://127.0.0.1:8114")
|
||||
assert result == "sk-bay-env-wins"
|
||||
|
||||
def test_trailing_slash_normalization(
|
||||
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Trailing slashes on endpoints are normalized before comparison."""
|
||||
data_dir = tmp_path / "bay_data"
|
||||
cred_file = data_dir / "credentials.json"
|
||||
self._write_creds(
|
||||
cred_file, api_key="sk-bay-slash", endpoint="http://127.0.0.1:8114/"
|
||||
)
|
||||
monkeypatch.setenv("BAY_DATA_DIR", str(data_dir))
|
||||
|
||||
with patch("astrbot.core.computer.computer_client.logger") as mock_logger:
|
||||
result = _discover_bay_credentials("http://127.0.0.1:8114")
|
||||
|
||||
assert result == "sk-bay-slash"
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# _log_computer_config_changes
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestLogComputerConfigChanges:
|
||||
"""Test config change detection and logging."""
|
||||
|
||||
@patch("astrbot.dashboard.routes.config.logger")
|
||||
def test_logs_runtime_change(self, mock_logger) -> None:
|
||||
"""Detects computer_use_runtime change."""
|
||||
old = {"provider_settings": {"computer_use_runtime": "none"}}
|
||||
new = {"provider_settings": {"computer_use_runtime": "sandbox"}}
|
||||
|
||||
_log_computer_config_changes(old, new)
|
||||
|
||||
mock_logger.info.assert_called()
|
||||
call_args = [str(c) for c in mock_logger.info.call_args_list]
|
||||
assert any("computer_use_runtime" in c and "none" in c and "sandbox" in c for c in call_args)
|
||||
|
||||
@patch("astrbot.dashboard.routes.config.logger")
|
||||
def test_no_log_when_runtime_unchanged(self, mock_logger) -> None:
|
||||
"""No log when runtime stays the same."""
|
||||
old = {"provider_settings": {"computer_use_runtime": "sandbox"}}
|
||||
new = {"provider_settings": {"computer_use_runtime": "sandbox"}}
|
||||
|
||||
_log_computer_config_changes(old, new)
|
||||
|
||||
mock_logger.info.assert_not_called()
|
||||
|
||||
@patch("astrbot.dashboard.routes.config.logger")
|
||||
def test_logs_sandbox_key_change(self, mock_logger) -> None:
|
||||
"""Detects sandbox sub-key change."""
|
||||
old = {"provider_settings": {"sandbox": {"booter": "shipyard"}}}
|
||||
new = {"provider_settings": {"sandbox": {"booter": "shipyard_neo"}}}
|
||||
|
||||
_log_computer_config_changes(old, new)
|
||||
|
||||
mock_logger.info.assert_called()
|
||||
# logger.info("[Computer] Config changed: sandbox.%s %s -> %s", key, old, new)
|
||||
found = False
|
||||
for call in mock_logger.info.call_args_list:
|
||||
args = call[0] # positional args: (fmt, key, old_val, new_val)
|
||||
if len(args) >= 4 and args[1] == "booter":
|
||||
assert args[2] == "shipyard"
|
||||
assert args[3] == "shipyard_neo"
|
||||
found = True
|
||||
break
|
||||
assert found, f"Expected booter change in log calls: {mock_logger.info.call_args_list}"
|
||||
|
||||
@patch("astrbot.dashboard.routes.config.logger")
|
||||
def test_masks_token_values(self, mock_logger) -> None:
|
||||
"""Token/secret values are masked in log output."""
|
||||
old = {"provider_settings": {"sandbox": {"shipyard_neo_access_token": ""}}}
|
||||
new = {
|
||||
"provider_settings": {
|
||||
"sandbox": {"shipyard_neo_access_token": "sk-bay-secret123"}
|
||||
}
|
||||
}
|
||||
|
||||
_log_computer_config_changes(old, new)
|
||||
|
||||
mock_logger.info.assert_called()
|
||||
call_args_str = str(mock_logger.info.call_args_list)
|
||||
assert "***" in call_args_str
|
||||
assert "sk-bay-secret123" not in call_args_str
|
||||
|
||||
@patch("astrbot.dashboard.routes.config.logger")
|
||||
def test_masks_empty_token_as_empty_label(self, mock_logger) -> None:
|
||||
"""Empty token values show as '(empty)' not '***'."""
|
||||
old = {
|
||||
"provider_settings": {
|
||||
"sandbox": {"shipyard_neo_access_token": "old-key"}
|
||||
}
|
||||
}
|
||||
new = {"provider_settings": {"sandbox": {"shipyard_neo_access_token": ""}}}
|
||||
|
||||
_log_computer_config_changes(old, new)
|
||||
|
||||
mock_logger.info.assert_called()
|
||||
call_args_str = str(mock_logger.info.call_args_list)
|
||||
assert "(empty)" in call_args_str
|
||||
|
||||
@patch("astrbot.dashboard.routes.config.logger")
|
||||
def test_no_log_when_nothing_changed(self, mock_logger) -> None:
|
||||
"""No logs at all when config is identical."""
|
||||
cfg = {
|
||||
"provider_settings": {
|
||||
"computer_use_runtime": "sandbox",
|
||||
"sandbox": {
|
||||
"booter": "shipyard_neo",
|
||||
"shipyard_neo_endpoint": "http://127.0.0.1:8114",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
_log_computer_config_changes(cfg, cfg)
|
||||
|
||||
mock_logger.info.assert_not_called()
|
||||
|
||||
@patch("astrbot.dashboard.routes.config.logger")
|
||||
def test_handles_missing_provider_settings(self, mock_logger) -> None:
|
||||
"""Gracefully handles configs without provider_settings."""
|
||||
_log_computer_config_changes(
|
||||
{}, {"provider_settings": {"computer_use_runtime": "sandbox"}}
|
||||
)
|
||||
|
||||
mock_logger.info.assert_called()
|
||||
call_args_str = str(mock_logger.info.call_args_list)
|
||||
assert "computer_use_runtime" in call_args_str
|
||||
|
||||
@patch("astrbot.dashboard.routes.config.logger")
|
||||
def test_detects_new_sandbox_key(self, mock_logger) -> None:
|
||||
"""Detects a newly added sandbox key."""
|
||||
old = {"provider_settings": {"sandbox": {}}}
|
||||
new = {
|
||||
"provider_settings": {
|
||||
"sandbox": {"shipyard_neo_endpoint": "http://127.0.0.1:8114"}
|
||||
}
|
||||
}
|
||||
|
||||
_log_computer_config_changes(old, new)
|
||||
|
||||
mock_logger.info.assert_called()
|
||||
call_args_str = str(mock_logger.info.call_args_list)
|
||||
assert "shipyard_neo_endpoint" in call_args_str
|
||||
|
||||
@patch("astrbot.dashboard.routes.config.logger")
|
||||
def test_detects_removed_sandbox_key(self, mock_logger) -> None:
|
||||
"""Detects a removed sandbox key."""
|
||||
old = {
|
||||
"provider_settings": {
|
||||
"sandbox": {"shipyard_neo_endpoint": "http://127.0.0.1:8114"}
|
||||
}
|
||||
}
|
||||
new = {"provider_settings": {"sandbox": {}}}
|
||||
|
||||
_log_computer_config_changes(old, new)
|
||||
|
||||
mock_logger.info.assert_called()
|
||||
call_args_str = str(mock_logger.info.call_args_list)
|
||||
assert "shipyard_neo_endpoint" in call_args_str
|
||||
|
||||
@patch("astrbot.dashboard.routes.config.logger")
|
||||
def test_secret_key_masked(self, mock_logger) -> None:
|
||||
"""Any key containing 'secret' is also masked."""
|
||||
old = {"provider_settings": {"sandbox": {"my_secret_key": ""}}}
|
||||
new = {
|
||||
"provider_settings": {"sandbox": {"my_secret_key": "very-secret-value"}}
|
||||
}
|
||||
|
||||
_log_computer_config_changes(old, new)
|
||||
|
||||
mock_logger.info.assert_called()
|
||||
call_args_str = str(mock_logger.info.call_args_list)
|
||||
assert "***" in call_args_str
|
||||
assert "very-secret-value" not in call_args_str
|
||||
@@ -1,123 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from astrbot.core.computer import computer_client
|
||||
|
||||
|
||||
class _FakeShell:
|
||||
def __init__(self, sync_payload_json: str):
|
||||
self.sync_payload_json = sync_payload_json
|
||||
self.commands: list[str] = []
|
||||
|
||||
async def exec(self, command: str, **kwargs):
|
||||
_ = kwargs
|
||||
self.commands.append(command)
|
||||
if "PYBIN" in command and "managed_skills" in command:
|
||||
return {
|
||||
"success": True,
|
||||
"stdout": self.sync_payload_json,
|
||||
"stderr": "",
|
||||
"exit_code": 0,
|
||||
}
|
||||
return {"success": True, "stdout": "", "stderr": "", "exit_code": 0}
|
||||
|
||||
|
||||
class _FakeBooter:
|
||||
def __init__(self, sync_payload_json: str):
|
||||
self.shell = _FakeShell(sync_payload_json)
|
||||
self.uploads: list[tuple[str, str]] = []
|
||||
|
||||
async def upload_file(self, path: str, file_name: str) -> dict:
|
||||
self.uploads.append((path, file_name))
|
||||
return {"success": True}
|
||||
|
||||
|
||||
def test_sync_skills_keeps_builtin_skills_when_local_is_empty(monkeypatch, tmp_path: Path):
|
||||
skills_root = tmp_path / "skills"
|
||||
temp_root = tmp_path / "temp"
|
||||
skills_root.mkdir(parents=True, exist_ok=True)
|
||||
temp_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
captured = {"skills": None}
|
||||
|
||||
def _fake_set_cache(self, skills):
|
||||
captured["skills"] = skills
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.computer.computer_client.get_astrbot_skills_path",
|
||||
lambda: str(skills_root),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.computer.computer_client.get_astrbot_temp_path",
|
||||
lambda: str(temp_root),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.computer.computer_client.SkillManager.set_sandbox_skills_cache",
|
||||
_fake_set_cache,
|
||||
)
|
||||
|
||||
booter = _FakeBooter(
|
||||
'{"skills":[{"name":"python-sandbox","description":"ship","path":"skills/python-sandbox/SKILL.md"}]}'
|
||||
)
|
||||
asyncio.run(computer_client._sync_skills_to_sandbox(booter))
|
||||
|
||||
assert booter.uploads == []
|
||||
assert any(cmd == "rm -f skills/skills.zip" for cmd in booter.shell.commands)
|
||||
assert captured["skills"] == [
|
||||
{
|
||||
"name": "python-sandbox",
|
||||
"description": "ship",
|
||||
"path": "skills/python-sandbox/SKILL.md",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_sync_skills_uses_managed_strategy_instead_of_wiping_all(
|
||||
monkeypatch,
|
||||
tmp_path: Path,
|
||||
):
|
||||
skills_root = tmp_path / "skills"
|
||||
temp_root = tmp_path / "temp"
|
||||
skill_dir = skills_root / "custom-agent-skill"
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
skill_dir.joinpath("SKILL.md").write_text("# demo", encoding="utf-8")
|
||||
temp_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
captured = {"skills": None}
|
||||
|
||||
def _fake_set_cache(self, skills):
|
||||
captured["skills"] = skills
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.computer.computer_client.get_astrbot_skills_path",
|
||||
lambda: str(skills_root),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.computer.computer_client.get_astrbot_temp_path",
|
||||
lambda: str(temp_root),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.computer.computer_client.SkillManager.set_sandbox_skills_cache",
|
||||
_fake_set_cache,
|
||||
)
|
||||
|
||||
booter = _FakeBooter(
|
||||
'{"skills":[{"name":"custom-agent-skill","description":"","path":"skills/custom-agent-skill/SKILL.md"}]}'
|
||||
)
|
||||
asyncio.run(computer_client._sync_skills_to_sandbox(booter))
|
||||
|
||||
assert len(booter.uploads) == 1
|
||||
assert booter.uploads[0][1] == "skills/skills.zip"
|
||||
assert not any(
|
||||
"find skills -mindepth 1 -delete" in cmd for cmd in booter.shell.commands
|
||||
)
|
||||
assert captured["skills"] == [
|
||||
{
|
||||
"name": "custom-agent-skill",
|
||||
"description": "",
|
||||
"path": "skills/custom-agent-skill/SKILL.md",
|
||||
}
|
||||
]
|
||||
|
||||
+1
-183
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
@@ -312,184 +311,3 @@ async def test_do_update(
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert os.path.exists(release_path)
|
||||
|
||||
|
||||
class _FakeNeoSkills:
|
||||
async def list_candidates(self, **kwargs):
|
||||
_ = kwargs
|
||||
return [
|
||||
{
|
||||
"id": "cand-1",
|
||||
"skill_key": "neo.demo",
|
||||
"status": "evaluated_pass",
|
||||
"payload_ref": "pref-1",
|
||||
}
|
||||
]
|
||||
|
||||
async def list_releases(self, **kwargs):
|
||||
_ = kwargs
|
||||
return [
|
||||
{
|
||||
"id": "rel-1",
|
||||
"skill_key": "neo.demo",
|
||||
"candidate_id": "cand-1",
|
||||
"stage": "stable",
|
||||
"active": True,
|
||||
}
|
||||
]
|
||||
|
||||
async def get_payload(self, payload_ref: str):
|
||||
return {
|
||||
"payload_ref": payload_ref,
|
||||
"payload": {"skill_markdown": "# Demo"},
|
||||
}
|
||||
|
||||
async def evaluate_candidate(self, candidate_id: str, **kwargs):
|
||||
return {"candidate_id": candidate_id, **kwargs}
|
||||
|
||||
async def promote_candidate(self, candidate_id: str, stage: str = "canary"):
|
||||
return {
|
||||
"id": "rel-2",
|
||||
"skill_key": "neo.demo",
|
||||
"candidate_id": candidate_id,
|
||||
"stage": stage,
|
||||
}
|
||||
|
||||
async def rollback_release(self, release_id: str):
|
||||
return {"id": "rb-1", "rolled_back_release_id": release_id}
|
||||
|
||||
|
||||
class _FakeNeoBayClient:
|
||||
def __init__(self, endpoint_url: str, access_token: str):
|
||||
self.endpoint_url = endpoint_url
|
||||
self.access_token = access_token
|
||||
self.skills = _FakeNeoSkills()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
_ = exc_type, exc, tb
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_neo_skills_routes(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
monkeypatch,
|
||||
):
|
||||
provider_settings = core_lifecycle_td.astrbot_config.setdefault(
|
||||
"provider_settings", {}
|
||||
)
|
||||
sandbox = provider_settings.setdefault("sandbox", {})
|
||||
sandbox["shipyard_neo_endpoint"] = "http://neo.test"
|
||||
sandbox["shipyard_neo_access_token"] = "neo-token"
|
||||
|
||||
fake_shipyard_neo_module = SimpleNamespace(BayClient=_FakeNeoBayClient)
|
||||
monkeypatch.setitem(sys.modules, "shipyard_neo", fake_shipyard_neo_module)
|
||||
|
||||
async def _fake_sync_release(self, client, **kwargs):
|
||||
_ = self, client, kwargs
|
||||
return SimpleNamespace(
|
||||
skill_key="neo.demo",
|
||||
local_skill_name="neo_demo",
|
||||
release_id="rel-2",
|
||||
candidate_id="cand-1",
|
||||
payload_ref="pref-1",
|
||||
map_path="data/skills/neo_skill_map.json",
|
||||
synced_at="2026-01-01T00:00:00Z",
|
||||
)
|
||||
|
||||
async def _fake_sync_skills_to_active_sandboxes():
|
||||
return
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.skills.NeoSkillSyncManager.sync_release",
|
||||
_fake_sync_release,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.skills.sync_skills_to_active_sandboxes",
|
||||
_fake_sync_skills_to_active_sandboxes,
|
||||
)
|
||||
|
||||
test_client = app.test_client()
|
||||
|
||||
response = await test_client.get(
|
||||
"/api/skills/neo/candidates", headers=authenticated_header
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert isinstance(data["data"], list)
|
||||
assert data["data"][0]["id"] == "cand-1"
|
||||
|
||||
response = await test_client.get(
|
||||
"/api/skills/neo/releases", headers=authenticated_header
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert isinstance(data["data"], list)
|
||||
assert data["data"][0]["id"] == "rel-1"
|
||||
|
||||
response = await test_client.get(
|
||||
"/api/skills/neo/payload?payload_ref=pref-1", headers=authenticated_header
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["data"]["payload_ref"] == "pref-1"
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/skills/neo/evaluate",
|
||||
json={"candidate_id": "cand-1", "passed": True, "score": 0.95},
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["data"]["candidate_id"] == "cand-1"
|
||||
assert data["data"]["passed"] is True
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/skills/neo/evaluate",
|
||||
json={"candidate_id": "cand-1", "passed": "false", "score": 0.0},
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["data"]["passed"] is False
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/skills/neo/promote",
|
||||
json={"candidate_id": "cand-1", "stage": "stable"},
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["data"]["release"]["id"] == "rel-2"
|
||||
assert data["data"]["sync"]["local_skill_name"] == "neo_demo"
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/skills/neo/rollback",
|
||||
json={"release_id": "rel-2"},
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["data"]["rolled_back_release_id"] == "rel-2"
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/skills/neo/sync",
|
||||
json={"release_id": "rel-2"},
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["data"]["skill_key"] == "neo.demo"
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
!data
|
||||
@@ -1,100 +0,0 @@
|
||||
{
|
||||
"type": "card",
|
||||
"theme": "info",
|
||||
"size": "lg",
|
||||
"modules": [
|
||||
{
|
||||
"text": {
|
||||
"content": "test1",
|
||||
"type": "plain-text",
|
||||
"emoji": true
|
||||
},
|
||||
"type": "header"
|
||||
},
|
||||
{
|
||||
"text": {
|
||||
"content": "test2",
|
||||
"type": "kmarkdown"
|
||||
},
|
||||
"type": "section",
|
||||
"mode": "left"
|
||||
},
|
||||
{
|
||||
"type": "divider"
|
||||
},
|
||||
{
|
||||
"text": {
|
||||
"fields": [
|
||||
{
|
||||
"content": "test3",
|
||||
"type": "kmarkdown"
|
||||
},
|
||||
{
|
||||
"content": "**test4**",
|
||||
"type": "kmarkdown"
|
||||
}
|
||||
],
|
||||
"type": "paragraph",
|
||||
"cols": 2
|
||||
},
|
||||
"type": "section",
|
||||
"mode": "left"
|
||||
},
|
||||
{
|
||||
"elements": [
|
||||
{
|
||||
"src": "https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg",
|
||||
"type": "image",
|
||||
"alt": "",
|
||||
"size": "lg",
|
||||
"circle": false
|
||||
}
|
||||
],
|
||||
"type": "image-group"
|
||||
},
|
||||
{
|
||||
"src": "https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg",
|
||||
"title": "test5",
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"endTime": 1772343427360,
|
||||
"type": "countdown",
|
||||
"startTime": 1772343378259,
|
||||
"mode": "second"
|
||||
},
|
||||
{
|
||||
"elements": [
|
||||
{
|
||||
"text": "点我测试回调",
|
||||
"type": "button",
|
||||
"theme": "primary",
|
||||
"value": "btn_clicked",
|
||||
"click": "return-val"
|
||||
},
|
||||
{
|
||||
"text": "访问官网",
|
||||
"type": "button",
|
||||
"theme": "danger",
|
||||
"value": "https://www.kookapp.cn",
|
||||
"click": "link"
|
||||
}
|
||||
],
|
||||
"type": "action-group"
|
||||
},
|
||||
{
|
||||
"elements": [
|
||||
{
|
||||
"content": "test6",
|
||||
"type": "plain-text",
|
||||
"emoji": true
|
||||
}
|
||||
],
|
||||
"type": "context"
|
||||
},
|
||||
{
|
||||
"code": "test7",
|
||||
"type": "invite"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
TEST_DATA_DIR = Path(__file__).parent / "data"
|
||||
@@ -1,223 +0,0 @@
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata, Unknown
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.core.message.components import (
|
||||
File,
|
||||
Image,
|
||||
Plain,
|
||||
Video,
|
||||
At,
|
||||
AtAll,
|
||||
BaseMessageComponent,
|
||||
Json,
|
||||
Record,
|
||||
Reply,
|
||||
)
|
||||
|
||||
|
||||
from astrbot.core.platform.sources.kook.kook_event import KookEvent
|
||||
from astrbot.core.platform.sources.kook.kook_types import KookMessageType, OrderMessage
|
||||
|
||||
|
||||
async def mock_kook_client(upload_asset_return: str, send_text_return: str):
|
||||
# 1. Mock 掉整个 KookClient 类
|
||||
client = MagicMock()
|
||||
|
||||
client.upload_asset = AsyncMock(return_value=upload_asset_return)
|
||||
client.send_text = AsyncMock(return_value=send_text_return)
|
||||
return client
|
||||
|
||||
|
||||
def mock_file_message(input: str):
|
||||
message = MagicMock(spec=File)
|
||||
message.get_file = AsyncMock(return_value=input)
|
||||
return message
|
||||
|
||||
|
||||
def mock_record_message(input: str):
|
||||
message = MagicMock(spec=Record)
|
||||
message.text = input
|
||||
message.convert_to_file_path = AsyncMock(return_value=input)
|
||||
return message
|
||||
|
||||
|
||||
def mock_astrbot_message():
|
||||
message = AstrBotMessage()
|
||||
message.type = MessageType.OTHER_MESSAGE
|
||||
message.group_id = "test"
|
||||
message.session_id = "test"
|
||||
message.message_id = "test"
|
||||
return message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"input_message,upload_asset_return, expected_output, expected_error",
|
||||
[
|
||||
(
|
||||
Image("test image"),
|
||||
"test image",
|
||||
OrderMessage(
|
||||
1,
|
||||
text="test image",
|
||||
type=KookMessageType.IMAGE,
|
||||
),
|
||||
None,
|
||||
),
|
||||
(
|
||||
Video("test video"),
|
||||
"test video",
|
||||
OrderMessage(
|
||||
1,
|
||||
text="test video",
|
||||
type=KookMessageType.VIDEO,
|
||||
),
|
||||
None,
|
||||
),
|
||||
(
|
||||
mock_file_message("test file"),
|
||||
"test file",
|
||||
OrderMessage(
|
||||
1,
|
||||
text="test file",
|
||||
type=KookMessageType.FILE,
|
||||
),
|
||||
None,
|
||||
),
|
||||
(
|
||||
mock_record_message("./tests/file.wav"),
|
||||
"./tests/file.wav",
|
||||
OrderMessage(
|
||||
1,
|
||||
text='[{"type": "card", "modules": [{"src": "./tests/file.wav", "title": "./tests/file.wav", "type": "audio"}]}]',
|
||||
type=KookMessageType.CARD,
|
||||
),
|
||||
None,
|
||||
),
|
||||
(
|
||||
Plain("test plain"),
|
||||
"test plain",
|
||||
OrderMessage(
|
||||
1,
|
||||
text="test plain",
|
||||
type=KookMessageType.KMARKDOWN,
|
||||
),
|
||||
None,
|
||||
),
|
||||
(
|
||||
At(qq="test at"),
|
||||
"test at",
|
||||
OrderMessage(
|
||||
1,
|
||||
text="(met)test at(met)",
|
||||
type=KookMessageType.KMARKDOWN,
|
||||
),
|
||||
None,
|
||||
),
|
||||
(
|
||||
AtAll(qq="all"),
|
||||
"test atAll",
|
||||
OrderMessage(
|
||||
1,
|
||||
text="(met)all(met)",
|
||||
type=KookMessageType.KMARKDOWN,
|
||||
),
|
||||
None,
|
||||
),
|
||||
(
|
||||
Reply(id="test reply"),
|
||||
"test reply",
|
||||
OrderMessage(
|
||||
1,
|
||||
text="",
|
||||
type=KookMessageType.KMARKDOWN,
|
||||
reply_id="test reply",
|
||||
),
|
||||
None,
|
||||
),
|
||||
(
|
||||
Json(data={"test": "json"}),
|
||||
"test json",
|
||||
OrderMessage(
|
||||
1,
|
||||
text='[{"test": "json"}]',
|
||||
type=KookMessageType.CARD,
|
||||
),
|
||||
None,
|
||||
),
|
||||
(
|
||||
Unknown(text="test unknown"),
|
||||
"test unknown",
|
||||
None,
|
||||
NotImplementedError,
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_kook_event_warp_message(
|
||||
input_message: BaseMessageComponent,
|
||||
upload_asset_return: str,
|
||||
expected_output: OrderMessage,
|
||||
expected_error: type[Exception] | None,
|
||||
):
|
||||
client = await mock_kook_client(
|
||||
upload_asset_return,
|
||||
"",
|
||||
)
|
||||
|
||||
event = KookEvent(
|
||||
"",
|
||||
mock_astrbot_message(),
|
||||
PlatformMetadata(
|
||||
name="test",
|
||||
id="test",
|
||||
description="test",
|
||||
),
|
||||
"",
|
||||
client,
|
||||
)
|
||||
|
||||
if expected_error:
|
||||
with pytest.raises(expected_error):
|
||||
await event._wrap_message(1, input_message)
|
||||
return
|
||||
|
||||
result = await event._wrap_message(1, input_message)
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.parametrize(
|
||||
# "message_chain,send_text_expected_output,expected_error",
|
||||
# [
|
||||
# (
|
||||
# MessageChain(
|
||||
# chain=[
|
||||
# Image(file="test image"),
|
||||
# Plain(text="test plain"),
|
||||
# ],
|
||||
# ),
|
||||
# ""
|
||||
# ),
|
||||
# ],
|
||||
# )
|
||||
# async def test_kook_event_send():
|
||||
# client = await mock_kook_client(
|
||||
# "",
|
||||
# "",
|
||||
# )
|
||||
|
||||
# event = KookEvent(
|
||||
# "",
|
||||
# mock_astrbot_message(),
|
||||
# PlatformMetadata(
|
||||
# name="test",
|
||||
# id="test",
|
||||
# description="test",
|
||||
# ),
|
||||
# "",
|
||||
# client,
|
||||
# )
|
||||
|
||||
# await event.send(message=mock_astrbot_message())
|
||||
@@ -1,107 +0,0 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.platform.sources.kook.kook_types import (
|
||||
ActionGroupModule,
|
||||
ButtonElement,
|
||||
ContextModule,
|
||||
CountdownModule,
|
||||
DividerModule,
|
||||
FileModule,
|
||||
HeaderModule,
|
||||
ImageElement,
|
||||
ImageGroupModule,
|
||||
InviteModule,
|
||||
KmarkdownElement,
|
||||
KookCardMessage,
|
||||
ParagraphStructure,
|
||||
PlainTextElement,
|
||||
SectionModule,
|
||||
KookCardMessageContainer,
|
||||
)
|
||||
from tests.test_kook.shared import TEST_DATA_DIR
|
||||
|
||||
|
||||
def test_kook_card_message_container_append():
|
||||
container = KookCardMessageContainer()
|
||||
container.append(KookCardMessage())
|
||||
assert len(container) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input, expect_container_length",
|
||||
[
|
||||
([KookCardMessage()], 1),
|
||||
([KookCardMessage()] * 2, 2),
|
||||
],
|
||||
)
|
||||
def test_kook_card_message_container_to_json(
|
||||
input: list[KookCardMessage], expect_container_length: int
|
||||
):
|
||||
container = KookCardMessageContainer(input)
|
||||
json_output = container.to_json()
|
||||
output = json.loads(json_output)
|
||||
assert isinstance(output, list)
|
||||
assert len(output) == expect_container_length
|
||||
|
||||
|
||||
def test_all_kook_card_type():
|
||||
expect_json_data = Path(TEST_DATA_DIR / "kook_card_data.json").read_text(
|
||||
encoding="utf-8"
|
||||
)
|
||||
json_output = KookCardMessage(
|
||||
theme="info",
|
||||
size="lg",
|
||||
modules=[
|
||||
HeaderModule(text=PlainTextElement(content="test1")),
|
||||
SectionModule(text=KmarkdownElement(content="test2")),
|
||||
DividerModule(),
|
||||
SectionModule(
|
||||
text=ParagraphStructure(
|
||||
cols=2,
|
||||
fields=[
|
||||
KmarkdownElement(content="test3"),
|
||||
KmarkdownElement(content="**test4**"),
|
||||
],
|
||||
)
|
||||
),
|
||||
ImageGroupModule(
|
||||
elements=[
|
||||
ImageElement(
|
||||
src="https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg"
|
||||
)
|
||||
]
|
||||
),
|
||||
FileModule(
|
||||
src="https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg",
|
||||
title="test5",
|
||||
type="file",
|
||||
),
|
||||
CountdownModule(
|
||||
endTime=1772343427360,
|
||||
startTime=1772343378259,
|
||||
mode="second",
|
||||
),
|
||||
ActionGroupModule(
|
||||
elements=[
|
||||
ButtonElement(
|
||||
value="btn_clicked",
|
||||
text="点我测试回调",
|
||||
click="return-val",
|
||||
theme="primary",
|
||||
),
|
||||
ButtonElement(
|
||||
value="https://www.kookapp.cn",
|
||||
text="访问官网",
|
||||
click="link",
|
||||
theme="danger",
|
||||
),
|
||||
]
|
||||
),
|
||||
ContextModule(elements=[PlainTextElement(content="test6")]),
|
||||
InviteModule(code="test7"),
|
||||
],
|
||||
).to_json(indent=4, ensure_ascii=False)
|
||||
assert json_output == expect_json_data
|
||||
+6
-21
@@ -26,21 +26,6 @@ class _version_info:
|
||||
return (self.major, self.minor) >= other[:2]
|
||||
return (self.major, self.minor) >= (other.major, other.minor)
|
||||
|
||||
def __le__(self, other):
|
||||
if isinstance(other, tuple):
|
||||
return (self.major, self.minor) <= other[:2]
|
||||
return (self.major, self.minor) <= (other.major, other.minor)
|
||||
|
||||
def __gt__(self, other):
|
||||
if isinstance(other, tuple):
|
||||
return (self.major, self.minor) > other[:2]
|
||||
return (self.major, self.minor) > (other.major, other.minor)
|
||||
|
||||
def __lt__(self, other):
|
||||
if isinstance(other, tuple):
|
||||
return (self.major, self.minor) < other[:2]
|
||||
return (self.major, self.minor) < (other.major, other.minor)
|
||||
|
||||
|
||||
def test_check_env(monkeypatch):
|
||||
version_info_correct = _version_info(3, 10)
|
||||
@@ -48,12 +33,12 @@ def test_check_env(monkeypatch):
|
||||
monkeypatch.setattr(sys, "version_info", version_info_correct)
|
||||
with mock.patch("os.makedirs") as mock_makedirs:
|
||||
check_env()
|
||||
# check_env uses get_astrbot_*_path() which returns absolute paths,
|
||||
# so just verify makedirs was called the expected number of times
|
||||
assert mock_makedirs.call_count >= 4
|
||||
# Verify all calls used exist_ok=True
|
||||
for call_args in mock_makedirs.call_args_list:
|
||||
assert call_args[1].get("exist_ok") is True
|
||||
# Check that makedirs was called with paths containing expected dirs
|
||||
called_paths = [call[0][0] for call in mock_makedirs.call_args_list]
|
||||
# Use os.path.join for cross-platform path matching
|
||||
assert any(p.rstrip(os.sep).endswith(os.path.join("data", "config")) for p in called_paths)
|
||||
assert any(p.rstrip(os.sep).endswith(os.path.join("data", "plugins")) for p in called_paths)
|
||||
assert any(p.rstrip(os.sep).endswith(os.path.join("data", "temp")) for p in called_paths)
|
||||
|
||||
monkeypatch.setattr(sys, "version_info", version_info_wrong)
|
||||
with pytest.raises(SystemExit):
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager
|
||||
|
||||
|
||||
class _FakeSkills:
|
||||
async def list_releases(self, **kwargs):
|
||||
_ = kwargs
|
||||
return {
|
||||
"items": [
|
||||
{
|
||||
"id": "sr-1",
|
||||
"skill_key": "etl/loader@v1",
|
||||
"candidate_id": "sc-1",
|
||||
"stage": "stable",
|
||||
}
|
||||
],
|
||||
"total": 1,
|
||||
}
|
||||
|
||||
async def get_candidate(self, candidate_id: str):
|
||||
assert candidate_id == "sc-1"
|
||||
return {
|
||||
"id": "sc-1",
|
||||
"payload_ref": "blob:blob-1",
|
||||
}
|
||||
|
||||
async def get_payload(self, payload_ref: str):
|
||||
assert payload_ref == "blob:blob-1"
|
||||
return {
|
||||
"payload_ref": payload_ref,
|
||||
"kind": "astrbot_skill_v1",
|
||||
"payload": {
|
||||
"skill_markdown": "---\ndescription: test\n---\n# title\ncontent",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self):
|
||||
self.skills = _FakeSkills()
|
||||
|
||||
|
||||
def test_sync_release_writes_skill_and_map(monkeypatch, tmp_path: Path):
|
||||
calls = {"active": [], "sandbox_sync": 0}
|
||||
|
||||
def _fake_set_skill_active(self, name, active):
|
||||
calls["active"].append((name, active))
|
||||
|
||||
async def _fake_sync_sandboxes():
|
||||
calls["sandbox_sync"] += 1
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.skills.neo_skill_sync.SkillManager.set_skill_active",
|
||||
_fake_set_skill_active,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.skills.neo_skill_sync.sync_skills_to_active_sandboxes",
|
||||
_fake_sync_sandboxes,
|
||||
)
|
||||
|
||||
skills_root = tmp_path / "skills"
|
||||
map_path = skills_root / "neo_skill_map.json"
|
||||
mgr = NeoSkillSyncManager(skills_root=str(skills_root), map_path=str(map_path))
|
||||
|
||||
result = asyncio.run(
|
||||
mgr.sync_release(_FakeClient(), release_id="sr-1", require_stable=True)
|
||||
)
|
||||
|
||||
assert result.skill_key == "etl/loader@v1"
|
||||
assert result.release_id == "sr-1"
|
||||
assert result.local_skill_name.startswith("neo_")
|
||||
assert calls["active"] == [(result.local_skill_name, True)]
|
||||
assert calls["sandbox_sync"] == 1
|
||||
|
||||
skill_md = skills_root / result.local_skill_name / "SKILL.md"
|
||||
assert skill_md.exists()
|
||||
assert "description: test" in skill_md.read_text(encoding="utf-8")
|
||||
|
||||
assert map_path.exists()
|
||||
map_text = map_path.read_text(encoding="utf-8")
|
||||
assert "etl/loader@v1" in map_text
|
||||
assert result.local_skill_name in map_text
|
||||
|
||||
|
||||
def test_sync_release_rejects_non_stable(monkeypatch, tmp_path: Path):
|
||||
class _CanarySkills(_FakeSkills):
|
||||
async def list_releases(self, **kwargs):
|
||||
_ = kwargs
|
||||
return {
|
||||
"items": [
|
||||
{
|
||||
"id": "sr-1",
|
||||
"skill_key": "etl",
|
||||
"candidate_id": "sc-1",
|
||||
"stage": "canary",
|
||||
}
|
||||
],
|
||||
"total": 1,
|
||||
}
|
||||
|
||||
class _CanaryClient:
|
||||
def __init__(self):
|
||||
self.skills = _CanarySkills()
|
||||
|
||||
async def _fake_sync_sandboxes():
|
||||
return
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.skills.neo_skill_sync.sync_skills_to_active_sandboxes",
|
||||
_fake_sync_sandboxes,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.skills.neo_skill_sync.SkillManager.set_skill_active",
|
||||
lambda self, name, active: None,
|
||||
)
|
||||
|
||||
mgr = NeoSkillSyncManager(
|
||||
skills_root=str(tmp_path / "skills"),
|
||||
map_path=str(tmp_path / "skills" / "neo_skill_map.json"),
|
||||
)
|
||||
with pytest.raises(ValueError, match="Only stable releases"):
|
||||
asyncio.run(
|
||||
mgr.sync_release(_CanaryClient(), release_id="sr-1", require_stable=True)
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.computer.tools.neo_skills import PromoteSkillCandidateTool
|
||||
|
||||
|
||||
class _FakeSkills:
|
||||
def __init__(self):
|
||||
self.rollback_called_with = None
|
||||
|
||||
async def promote_candidate(self, candidate_id: str, stage: str = "canary"):
|
||||
assert candidate_id == "cand-1"
|
||||
assert stage == "stable"
|
||||
return {
|
||||
"id": "sr-1",
|
||||
"skill_key": "k1",
|
||||
"candidate_id": candidate_id,
|
||||
"stage": stage,
|
||||
}
|
||||
|
||||
async def rollback_release(self, release_id: str):
|
||||
self.rollback_called_with = release_id
|
||||
return {"id": "rb-1", "rollback_of": release_id}
|
||||
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self):
|
||||
self.skills = _FakeSkills()
|
||||
|
||||
|
||||
class _FakeBooter:
|
||||
def __init__(self):
|
||||
self.bay_client = _FakeClient()
|
||||
self.sandbox = object()
|
||||
|
||||
|
||||
def test_promote_stable_sync_failure_auto_rolls_back(monkeypatch):
|
||||
async def _fake_get_booter(_ctx, _session_id):
|
||||
return _FakeBooter()
|
||||
|
||||
async def _fake_sync_release(self, client, **kwargs):
|
||||
_ = self, client, kwargs
|
||||
raise ValueError("sync failed")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.computer.tools.neo_skills.get_booter",
|
||||
_fake_get_booter,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.computer.tools.neo_skills.NeoSkillSyncManager.sync_release",
|
||||
_fake_sync_release,
|
||||
)
|
||||
|
||||
event = SimpleNamespace(role="admin", unified_msg_origin="session-1")
|
||||
astr_ctx = SimpleNamespace(context=SimpleNamespace(), event=event)
|
||||
run_ctx = ContextWrapper(context=astr_ctx)
|
||||
|
||||
tool = PromoteSkillCandidateTool()
|
||||
result = asyncio.run(
|
||||
tool.call(
|
||||
run_ctx,
|
||||
candidate_id="cand-1",
|
||||
stage="stable",
|
||||
sync_to_local=True,
|
||||
)
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "auto rollback succeeded" in result
|
||||
assert "sync failed" in result
|
||||
@@ -1,287 +0,0 @@
|
||||
"""Tests for profile-aware sandbox selection and conditional tool registration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# ShipyardNeoBooter.capabilities
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestShipyardNeoBooterCapabilities:
|
||||
"""Test capabilities property on ShipyardNeoBooter."""
|
||||
|
||||
def _make_booter(self, sandbox_caps: list[str] | None = None):
|
||||
from astrbot.core.computer.booters.shipyard_neo import ShipyardNeoBooter
|
||||
|
||||
booter = ShipyardNeoBooter(
|
||||
endpoint_url="http://localhost:8114",
|
||||
access_token="sk-bay-test",
|
||||
)
|
||||
if sandbox_caps is not None:
|
||||
booter._sandbox = SimpleNamespace(capabilities=sandbox_caps)
|
||||
return booter
|
||||
|
||||
def test_none_before_boot(self):
|
||||
booter = self._make_booter()
|
||||
assert booter.capabilities is None
|
||||
|
||||
def test_returns_tuple_after_boot(self):
|
||||
booter = self._make_booter(["python", "shell", "filesystem"])
|
||||
assert booter.capabilities == ("python", "shell", "filesystem")
|
||||
assert isinstance(booter.capabilities, tuple)
|
||||
|
||||
def test_includes_browser_when_present(self):
|
||||
booter = self._make_booter(["python", "shell", "filesystem", "browser"])
|
||||
assert "browser" in booter.capabilities
|
||||
|
||||
def test_no_browser_when_absent(self):
|
||||
booter = self._make_booter(["python", "shell", "filesystem"])
|
||||
assert "browser" not in booter.capabilities
|
||||
|
||||
def test_returns_immutable(self):
|
||||
"""Verify capabilities returns an immutable tuple."""
|
||||
booter = self._make_booter(["python"])
|
||||
caps = booter.capabilities
|
||||
assert isinstance(caps, tuple)
|
||||
with pytest.raises(AttributeError):
|
||||
caps.append("mutated") # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# _apply_sandbox_tools — conditional browser tool registration
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def _make_config(booter_type: str = "shipyard_neo"):
|
||||
return SimpleNamespace(
|
||||
sandbox_cfg={"booter": booter_type},
|
||||
)
|
||||
|
||||
|
||||
def _make_req():
|
||||
return SimpleNamespace(func_tool=None, system_prompt="")
|
||||
|
||||
|
||||
def _import_apply_sandbox_tools():
|
||||
"""Import _apply_sandbox_tools, skipping if circular-import fails."""
|
||||
try:
|
||||
from astrbot.core.astr_main_agent import _apply_sandbox_tools
|
||||
|
||||
return _apply_sandbox_tools
|
||||
except ImportError:
|
||||
pytest.skip("Cannot import _apply_sandbox_tools (circular import in test env)")
|
||||
|
||||
|
||||
class TestApplySandboxToolsConditional:
|
||||
"""Verify browser tools are conditionally registered."""
|
||||
|
||||
def _tool_names(self, req) -> set[str]:
|
||||
"""Extract tool names from a request's func_tool."""
|
||||
if req.func_tool is None:
|
||||
return set()
|
||||
return {t.name for t in req.func_tool.tools}
|
||||
|
||||
def test_no_session_registers_all(self):
|
||||
"""First request (no booted session) → all tools including browser."""
|
||||
fn = _import_apply_sandbox_tools()
|
||||
config = _make_config("shipyard_neo")
|
||||
req = _make_req()
|
||||
|
||||
with patch(
|
||||
"astrbot.core.computer.computer_client.session_booter", {}
|
||||
):
|
||||
fn(config, req, "session-1")
|
||||
|
||||
names = self._tool_names(req)
|
||||
assert "astrbot_execute_browser" in names
|
||||
assert "astrbot_execute_browser_batch" in names
|
||||
assert "astrbot_run_browser_skill" in names
|
||||
|
||||
def test_with_browser_capability(self):
|
||||
"""Booted session with browser capability → browser tools registered."""
|
||||
fn = _import_apply_sandbox_tools()
|
||||
config = _make_config("shipyard_neo")
|
||||
req = _make_req()
|
||||
fake_booter = SimpleNamespace(
|
||||
capabilities=["python", "shell", "filesystem", "browser"]
|
||||
)
|
||||
|
||||
with patch(
|
||||
"astrbot.core.computer.computer_client.session_booter",
|
||||
{"session-1": fake_booter},
|
||||
):
|
||||
fn(config, req, "session-1")
|
||||
|
||||
names = self._tool_names(req)
|
||||
assert "astrbot_execute_browser" in names
|
||||
|
||||
def test_without_browser_capability(self):
|
||||
"""Booted session WITHOUT browser capability → browser tools NOT registered."""
|
||||
fn = _import_apply_sandbox_tools()
|
||||
config = _make_config("shipyard_neo")
|
||||
req = _make_req()
|
||||
fake_booter = SimpleNamespace(
|
||||
capabilities=["python", "shell", "filesystem"]
|
||||
)
|
||||
|
||||
with patch(
|
||||
"astrbot.core.computer.computer_client.session_booter",
|
||||
{"session-1": fake_booter},
|
||||
):
|
||||
fn(config, req, "session-1")
|
||||
|
||||
names = self._tool_names(req)
|
||||
assert "astrbot_execute_browser" not in names
|
||||
assert "astrbot_execute_browser_batch" not in names
|
||||
assert "astrbot_run_browser_skill" not in names
|
||||
# Skill tools should still be registered
|
||||
assert "astrbot_get_execution_history" in names
|
||||
|
||||
def test_skill_tools_always_registered(self):
|
||||
"""Skill lifecycle tools are registered regardless of capabilities."""
|
||||
fn = _import_apply_sandbox_tools()
|
||||
config = _make_config("shipyard_neo")
|
||||
req = _make_req()
|
||||
fake_booter = SimpleNamespace(capabilities=["python"])
|
||||
|
||||
with patch(
|
||||
"astrbot.core.computer.computer_client.session_booter",
|
||||
{"session-1": fake_booter},
|
||||
):
|
||||
fn(config, req, "session-1")
|
||||
|
||||
names = self._tool_names(req)
|
||||
assert "astrbot_create_skill_candidate" in names
|
||||
assert "astrbot_promote_skill_candidate" in names
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# _resolve_profile
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestResolveProfile:
|
||||
"""Test smart profile selection logic."""
|
||||
|
||||
def _make_booter(self, profile: str = "python-default"):
|
||||
from astrbot.core.computer.booters.shipyard_neo import ShipyardNeoBooter
|
||||
|
||||
return ShipyardNeoBooter(
|
||||
endpoint_url="http://localhost:8114",
|
||||
access_token="sk-bay-test",
|
||||
profile=profile,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_specified_profile_honoured(self):
|
||||
"""User explicitly sets a non-default profile → use it directly."""
|
||||
booter = self._make_booter(profile="browser-python")
|
||||
client = SimpleNamespace() # list_profiles should NOT be called
|
||||
result = await booter._resolve_profile(client)
|
||||
assert result == "browser-python"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_selects_browser_profile(self):
|
||||
"""When multiple profiles available, prefer one with browser."""
|
||||
|
||||
async def _mock_list_profiles():
|
||||
return SimpleNamespace(
|
||||
items=[
|
||||
SimpleNamespace(
|
||||
id="python-default",
|
||||
capabilities=["python", "shell", "filesystem"],
|
||||
),
|
||||
SimpleNamespace(
|
||||
id="browser-python",
|
||||
capabilities=["python", "shell", "filesystem", "browser"],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
booter = self._make_booter()
|
||||
client = SimpleNamespace(list_profiles=_mock_list_profiles)
|
||||
result = await booter._resolve_profile(client)
|
||||
assert result == "browser-python"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_to_default_on_api_error(self):
|
||||
"""API error → graceful fallback to python-default."""
|
||||
|
||||
async def _failing_list_profiles():
|
||||
raise ConnectionError("Bay unreachable")
|
||||
|
||||
booter = self._make_booter()
|
||||
client = SimpleNamespace(list_profiles=_failing_list_profiles)
|
||||
result = await booter._resolve_profile(client)
|
||||
assert result == "python-default"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_on_empty_profiles(self):
|
||||
"""Empty profile list → python-default."""
|
||||
|
||||
async def _empty_list_profiles():
|
||||
return SimpleNamespace(items=[])
|
||||
|
||||
booter = self._make_booter()
|
||||
client = SimpleNamespace(list_profiles=_empty_list_profiles)
|
||||
result = await booter._resolve_profile(client)
|
||||
assert result == "python-default"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_profile_selected(self):
|
||||
"""Only one profile available → use it."""
|
||||
|
||||
async def _single_profile():
|
||||
return SimpleNamespace(
|
||||
items=[
|
||||
SimpleNamespace(
|
||||
id="python-data",
|
||||
capabilities=["python", "shell", "filesystem"],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
booter = self._make_booter()
|
||||
client = SimpleNamespace(list_profiles=_single_profile)
|
||||
result = await booter._resolve_profile(client)
|
||||
assert result == "python-data"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_error_not_silenced(self):
|
||||
"""UnauthorizedError must propagate, not be downgraded to fallback."""
|
||||
from shipyard_neo.errors import UnauthorizedError
|
||||
|
||||
async def _unauthorized_list_profiles():
|
||||
raise UnauthorizedError("bad token")
|
||||
|
||||
booter = self._make_booter()
|
||||
client = SimpleNamespace(list_profiles=_unauthorized_list_profiles)
|
||||
with pytest.raises(UnauthorizedError):
|
||||
await booter._resolve_profile(client)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# ComputerBooter base class
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestBaseComputerBooter:
|
||||
"""Verify base class defaults."""
|
||||
|
||||
def test_capabilities_default_none(self):
|
||||
from astrbot.core.computer.booters.base import ComputerBooter
|
||||
|
||||
booter = ComputerBooter()
|
||||
assert booter.capabilities is None
|
||||
|
||||
def test_browser_default_none(self):
|
||||
from astrbot.core.computer.booters.base import ComputerBooter
|
||||
|
||||
booter = ComputerBooter()
|
||||
assert booter.browser is None
|
||||
@@ -1,104 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from astrbot.core.skills.skill_manager import SkillManager
|
||||
|
||||
|
||||
def _write_skill(root: Path, name: str, description: str) -> None:
|
||||
skill_dir = root / name
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
skill_dir.joinpath("SKILL.md").write_text(
|
||||
f"---\ndescription: {description}\n---\n# {name}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def test_list_skills_merges_local_and_sandbox_cache(monkeypatch, tmp_path: Path):
|
||||
data_dir = tmp_path / "data"
|
||||
temp_dir = tmp_path / "temp"
|
||||
skills_root = tmp_path / "skills"
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
skills_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.skills.skill_manager.get_astrbot_data_path",
|
||||
lambda: str(data_dir),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.skills.skill_manager.get_astrbot_temp_path",
|
||||
lambda: str(temp_dir),
|
||||
)
|
||||
|
||||
mgr = SkillManager(skills_root=str(skills_root))
|
||||
_write_skill(skills_root, "custom-local", "local description")
|
||||
|
||||
mgr.set_sandbox_skills_cache(
|
||||
[
|
||||
{
|
||||
"name": "python-sandbox",
|
||||
"description": "ship built-in",
|
||||
"path": "/app/skills/python-sandbox/SKILL.md",
|
||||
},
|
||||
{
|
||||
"name": "custom-local",
|
||||
"description": "should be ignored by local override",
|
||||
"path": "skills/custom-local/SKILL.md",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
skills = mgr.list_skills(runtime="sandbox")
|
||||
by_name = {item.name: item for item in skills}
|
||||
|
||||
assert sorted(by_name) == ["custom-local", "python-sandbox"]
|
||||
assert by_name["custom-local"].description == "local description"
|
||||
assert by_name["custom-local"].path == "skills/custom-local/SKILL.md"
|
||||
assert by_name["python-sandbox"].description == "ship built-in"
|
||||
assert by_name["python-sandbox"].path == "skills/python-sandbox/SKILL.md"
|
||||
|
||||
|
||||
def test_sandbox_cached_skill_respects_active_and_display_path(
|
||||
monkeypatch,
|
||||
tmp_path: Path,
|
||||
):
|
||||
data_dir = tmp_path / "data"
|
||||
temp_dir = tmp_path / "temp"
|
||||
skills_root = tmp_path / "skills"
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
skills_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.skills.skill_manager.get_astrbot_data_path",
|
||||
lambda: str(data_dir),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.skills.skill_manager.get_astrbot_temp_path",
|
||||
lambda: str(temp_dir),
|
||||
)
|
||||
|
||||
mgr = SkillManager(skills_root=str(skills_root))
|
||||
mgr.set_sandbox_skills_cache(
|
||||
[
|
||||
{
|
||||
"name": "browser-automation",
|
||||
"description": "gull built-in",
|
||||
"path": "/app/skills/browser-automation/SKILL.md",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
all_skills = mgr.list_skills(
|
||||
runtime="sandbox",
|
||||
active_only=False,
|
||||
show_sandbox_path=False,
|
||||
)
|
||||
assert len(all_skills) == 1
|
||||
assert all_skills[0].path == "/app/skills/browser-automation/SKILL.md"
|
||||
|
||||
mgr.set_skill_active("browser-automation", False)
|
||||
active_skills = mgr.list_skills(runtime="sandbox", active_only=True)
|
||||
assert active_skills == []
|
||||
|
||||
@@ -1,203 +0,0 @@
|
||||
"""Tests for skill metadata: frontmatter parsing, prompt generation, absolute paths."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from astrbot.core.skills.skill_manager import (
|
||||
SkillInfo,
|
||||
SkillManager,
|
||||
_parse_frontmatter_description,
|
||||
build_skills_prompt,
|
||||
)
|
||||
|
||||
|
||||
# ---------- _parse_frontmatter_description tests ----------
|
||||
|
||||
|
||||
def test_parse_frontmatter_description():
|
||||
text = (
|
||||
"---\n"
|
||||
"name: screenshot-capture\n"
|
||||
"description: Captures full-page screenshots of web pages. "
|
||||
"Use when user asks to screenshot, take a picture of a page, "
|
||||
"截图, or needs a visual snapshot of any URL.\n"
|
||||
"---\n"
|
||||
"# Screenshot Skill\n"
|
||||
)
|
||||
desc = _parse_frontmatter_description(text)
|
||||
assert "Captures full-page screenshots" in desc
|
||||
assert "截图" in desc
|
||||
|
||||
|
||||
def test_parse_frontmatter_description_only():
|
||||
text = "---\ndescription: legacy skill\n---\n# Title\n"
|
||||
assert _parse_frontmatter_description(text) == "legacy skill"
|
||||
|
||||
|
||||
def test_parse_frontmatter_empty():
|
||||
assert _parse_frontmatter_description("no frontmatter") == ""
|
||||
assert _parse_frontmatter_description("") == ""
|
||||
|
||||
|
||||
def test_parse_frontmatter_missing_end_delimiter():
|
||||
text = "---\ndescription: broken\n"
|
||||
assert _parse_frontmatter_description(text) == ""
|
||||
|
||||
|
||||
def test_parse_frontmatter_quoted_description():
|
||||
text = '---\ndescription: "quoted value"\n---\n'
|
||||
assert _parse_frontmatter_description(text) == "quoted value"
|
||||
|
||||
|
||||
# ---------- build_skills_prompt tests ----------
|
||||
|
||||
|
||||
def test_build_skills_prompt_basic_format():
|
||||
skills = [
|
||||
SkillInfo(
|
||||
name="screenshot",
|
||||
description="Take screenshots of web pages",
|
||||
path="/abs/skills/screenshot/SKILL.md",
|
||||
active=True,
|
||||
)
|
||||
]
|
||||
prompt = build_skills_prompt(skills)
|
||||
assert "**screenshot**" in prompt
|
||||
assert "Take screenshots of web pages" in prompt
|
||||
assert "`/abs/skills/screenshot/SKILL.md`" in prompt
|
||||
|
||||
|
||||
def test_build_skills_prompt_absolute_path_in_example():
|
||||
"""The mandatory grounding example should show the absolute path."""
|
||||
skills = [
|
||||
SkillInfo(
|
||||
name="foo",
|
||||
description="do foo",
|
||||
path="/home/pan/AstrBot/skills/foo/SKILL.md",
|
||||
active=True,
|
||||
),
|
||||
]
|
||||
prompt = build_skills_prompt(skills)
|
||||
assert "cat /home/pan/AstrBot/skills/foo/SKILL.md" in prompt
|
||||
|
||||
|
||||
def test_build_skills_prompt_progressive_disclosure_rules():
|
||||
"""The prompt should contain the key progressive disclosure rules."""
|
||||
skills = [
|
||||
SkillInfo(
|
||||
name="test",
|
||||
description="test skill",
|
||||
path="/skills/test/SKILL.md",
|
||||
active=True,
|
||||
)
|
||||
]
|
||||
prompt = build_skills_prompt(skills)
|
||||
# Numbered rules
|
||||
assert "1." in prompt # Discovery
|
||||
assert "2." in prompt # When to trigger
|
||||
assert "3." in prompt # Mandatory grounding
|
||||
assert "4." in prompt # Progressive disclosure
|
||||
# Key concepts
|
||||
assert "Mandatory grounding" in prompt
|
||||
assert "Progressive disclosure" in prompt
|
||||
assert "SKILL.md" in prompt
|
||||
|
||||
|
||||
def test_build_skills_prompt_no_custom_fields():
|
||||
"""Prompt should NOT contain triggers/capabilities/output labels."""
|
||||
skills = [
|
||||
SkillInfo(
|
||||
name="test",
|
||||
description="test skill",
|
||||
path="/skills/test/SKILL.md",
|
||||
active=True,
|
||||
)
|
||||
]
|
||||
prompt = build_skills_prompt(skills)
|
||||
assert "Triggers:" not in prompt
|
||||
assert "Capabilities:" not in prompt
|
||||
assert "Output:" not in prompt
|
||||
|
||||
|
||||
# ---------- list_skills with description ----------
|
||||
|
||||
|
||||
def test_list_skills_parses_description_from_local(monkeypatch, tmp_path: Path):
|
||||
data_dir = tmp_path / "data"
|
||||
temp_dir = tmp_path / "temp"
|
||||
skills_root = tmp_path / "skills"
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
skills_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.skills.skill_manager.get_astrbot_data_path",
|
||||
lambda: str(data_dir),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.skills.skill_manager.get_astrbot_temp_path",
|
||||
lambda: str(temp_dir),
|
||||
)
|
||||
|
||||
skill_dir = skills_root / "screencap"
|
||||
skill_dir.mkdir()
|
||||
skill_dir.joinpath("SKILL.md").write_text(
|
||||
"---\n"
|
||||
"name: screencap\n"
|
||||
"description: Capture screenshots of web pages. "
|
||||
"Use when user asks to screenshot, 截图, or capture a page.\n"
|
||||
"---\n"
|
||||
"# Screenshot\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
mgr = SkillManager(skills_root=str(skills_root))
|
||||
skills = mgr.list_skills()
|
||||
assert len(skills) == 1
|
||||
s = skills[0]
|
||||
assert "Capture screenshots" in s.description
|
||||
assert "截图" in s.description
|
||||
# SkillInfo should NOT have triggers/capabilities/output attributes
|
||||
assert not hasattr(s, "triggers")
|
||||
assert not hasattr(s, "capabilities")
|
||||
assert not hasattr(s, "output")
|
||||
|
||||
|
||||
def test_list_skills_description_from_sandbox_cache(
|
||||
monkeypatch, tmp_path: Path
|
||||
):
|
||||
data_dir = tmp_path / "data"
|
||||
temp_dir = tmp_path / "temp"
|
||||
skills_root = tmp_path / "skills"
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
skills_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.skills.skill_manager.get_astrbot_data_path",
|
||||
lambda: str(data_dir),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.skills.skill_manager.get_astrbot_temp_path",
|
||||
lambda: str(temp_dir),
|
||||
)
|
||||
|
||||
mgr = SkillManager(skills_root=str(skills_root))
|
||||
mgr.set_sandbox_skills_cache(
|
||||
[
|
||||
{
|
||||
"name": "web-scrape",
|
||||
"description": "Scrape web pages and extract structured data. "
|
||||
"Use when user needs to extract content from URLs.",
|
||||
"path": "/home/pan/AstrBot/skills/web-scrape/SKILL.md",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
skills = mgr.list_skills(runtime="sandbox", show_sandbox_path=False)
|
||||
assert len(skills) == 1
|
||||
s = skills[0]
|
||||
assert "Scrape web pages" in s.description
|
||||
# Path should be the absolute path from cache
|
||||
assert "/home/pan/AstrBot/skills/web-scrape/SKILL.md" in s.path
|
||||
@@ -516,6 +516,30 @@ class TestEnsurePersonaAndSkills:
|
||||
|
||||
assert "Persona Instructions" not in req.system_prompt
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_skills(self, mock_event, mock_context):
|
||||
"""Test applying skills to request."""
|
||||
module = ama
|
||||
mock_skill = MagicMock()
|
||||
mock_skill.name = "test_skill"
|
||||
mock_skill.to_prompt.return_value = "Skill description"
|
||||
mock_context.persona_manager.personas_v3 = []
|
||||
mock_context.persona_manager.resolve_selected_persona = AsyncMock(
|
||||
return_value=(None, None, None, False)
|
||||
)
|
||||
|
||||
with patch("astrbot.core.astr_main_agent.SkillManager") as mock_skill_mgr_cls:
|
||||
mock_skill_mgr = MagicMock()
|
||||
mock_skill_mgr.list_skills.return_value = [mock_skill]
|
||||
mock_skill_mgr_cls.return_value = mock_skill_mgr
|
||||
|
||||
req = ProviderRequest()
|
||||
req.conversation = MagicMock(persona_id=None)
|
||||
|
||||
await module._ensure_persona_and_skills(req, {}, mock_context, mock_event)
|
||||
|
||||
assert "test_skill" in req.system_prompt
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_tools_from_persona(self, mock_event, mock_context):
|
||||
"""Test applying tools from persona."""
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
import platform
|
||||
from astrbot.core.computer.tools.python import PythonTool, LocalPythonTool
|
||||
|
||||
def test_python_tool_description_contains_os():
|
||||
"""测试 PythonTool 的描述中是否包含当前操作系统信息"""
|
||||
tool = PythonTool()
|
||||
current_os = platform.system()
|
||||
assert current_os in tool.description
|
||||
assert "IPython" in tool.description
|
||||
|
||||
def test_local_python_tool_description_contains_os():
|
||||
"""测试 LocalPythonTool 的描述中是否包含当前操作系统信息和兼容性提示"""
|
||||
tool = LocalPythonTool()
|
||||
current_os = platform.system()
|
||||
assert current_os in tool.description
|
||||
assert "Python environment" in tool.description
|
||||
assert "system-compatible" in tool.description
|
||||
Reference in New Issue
Block a user