Compare commits

..

2 Commits

Author SHA1 Message Date
copilot-swe-agent[bot] 2a7c8b44bf feat: switch Monaco editor from CDN to local deployment
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2026-03-01 15:15:03 +00:00
copilot-swe-agent[bot] b8e83b772d Initial plan 2026-03-01 15:09:13 +00:00
96 changed files with 722 additions and 9359 deletions
+1 -1
View File
@@ -36,7 +36,7 @@ jobs:
zip -r dist.zip dist
- name: Archive production artifacts
uses: actions/upload-artifact@v7
uses: actions/upload-artifact@v6
with:
name: dist-without-markdown
path: |
+2 -2
View File
@@ -71,7 +71,7 @@ jobs:
zip -r "AstrBot-${{ steps.tag.outputs.tag }}-dashboard.zip" dist
- name: Upload dashboard artifact
uses: actions/upload-artifact@v7
uses: actions/upload-artifact@v6
with:
name: Dashboard-${{ steps.tag.outputs.tag }}
if-no-files-found: error
@@ -132,7 +132,7 @@ jobs:
echo "tag=$tag" >> "$GITHUB_OUTPUT"
- name: Download dashboard artifact
uses: actions/download-artifact@v8
uses: actions/download-artifact@v7
with:
name: Dashboard-${{ steps.tag.outputs.tag }}
path: release-assets
-7
View File
@@ -36,9 +36,6 @@ dashboard/dist/
package-lock.json
yarn.lock
# Bundled dashboard dist (generated by hatch_build.py during pip wheel build)
astrbot/dashboard/dist/
# Operating System
**/.DS_Store
.DS_Store
@@ -57,7 +54,3 @@ IFLOW.md
# genie_tts data
CharacterModels/
GenieData/
.agent/
.codex/
.opencode/
.kilocode/
-52
View File
@@ -46,32 +46,6 @@ ruff check .
如果您使用 VSCode,可以安装 `Ruff` 插件。
##### PR 功能完整性验证(推荐)
如果您希望在本地做一套接近 CI 的完整验证,可使用:
```bash
make pr-test-neo
```
该命令会执行:
- `uv sync --group dev`
- `ruff format --check .``ruff check .`
- Neo 相关关键测试
- `main.py` 启动 smoke test(检测 `http://localhost:6185`
需要全量验证时可使用:
```bash
make pr-test-full
```
如果只想快速重复执行(跳过依赖同步和 dashboard 构建):
```bash
make pr-test-full-fast
```
## Contributing Guide
@@ -114,29 +88,3 @@ We use Ruff as our code formatter and static analysis tool. Before submitting yo
ruff format .
ruff check .
```
##### PR completeness checks (recommended)
To run a local validation flow close to CI, use:
```bash
make pr-test-neo
```
This command runs:
- `uv sync --group dev`
- `ruff format --check .` and `ruff check .`
- Neo-related critical tests
- a startup smoke test against `http://localhost:6185`
For full validation, use:
```bash
make pr-test-full
```
For faster repeated runs (skip dependency sync and dashboard build), use:
```bash
make pr-test-full-fast
```
+1 -10
View File
@@ -1,4 +1,4 @@
.PHONY: worktree worktree-add worktree-rm pr-test-neo pr-test-full pr-test-full-fast
.PHONY: worktree worktree-add worktree-rm
WORKTREE_DIR ?= ../astrbot_worktree
BRANCH ?= $(word 2,$(MAKECMDGOALS))
@@ -27,15 +27,6 @@ endif
echo "Worktree $(WORKTREE_DIR)/$(BRANCH) not found."; \
fi
pr-test-neo:
./scripts/pr_test_env.sh --profile neo
pr-test-full:
./scripts/pr_test_env.sh --profile full
pr-test-full-fast:
./scripts/pr_test_env.sh --profile full --skip-sync --no-dashboard
# Swallow extra args (branch/base) so make doesn't treat them as targets
%:
@true
-6
View File
@@ -184,12 +184,6 @@ Connect AstrBot to your favorite chat platform.
| Minimax TTS | Text-to-Speech Services |
| Volcano Engine TTS | Text-to-Speech Services |
## ❤️ Sponsors
<p align="center">
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
</p>
## ❤️ Contributing
Issues and Pull Requests are always welcome! Feel free to submit your changes to this project :)
+10 -1
View File
@@ -8,7 +8,7 @@ from bs4 import BeautifulSoup
from readability import Document
from astrbot.api import AstrBotConfig, llm_tool, logger, sp, star
from astrbot.api.event import AstrMessageEvent, filter
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
from astrbot.api.provider import ProviderRequest
from astrbot.core.provider.func_tool_manager import FunctionToolManager
@@ -196,6 +196,15 @@ class Main(star.Star):
)
return results
@filter.command("websearch")
async def websearch(self, event: AstrMessageEvent, oper: str | None = None) -> None:
"""网页搜索指令(已废弃)"""
event.set_result(
MessageEventResult().message(
"此指令已经被废弃,请在 WebUI 中开启或关闭网页搜索功能。",
),
)
@llm_tool(name="web_search")
async def search_from_search_engine(
self,
+7 -7
View File
@@ -1,4 +1,4 @@
"""AstrBot CLI entry point"""
"""AstrBot CLI入口"""
import sys
@@ -29,23 +29,23 @@ def cli() -> None:
@click.command()
@click.argument("command_name", required=False, type=str)
def help(command_name: str | None) -> None:
"""Display help information for commands
"""显示命令的帮助信息
If COMMAND_NAME is provided, display detailed help for that command.
Otherwise, display general help information.
如果提供了 COMMAND_NAME,则显示该命令的详细帮助信息。
否则,显示通用帮助信息。
"""
ctx = click.get_current_context()
if command_name:
# Find the specified command
# 查找指定命令
command = cli.get_command(ctx, command_name)
if command:
# Display help for the specific command
# 显示特定命令的帮助信息
click.echo(command.get_help(ctx))
else:
click.echo(f"Unknown command: {command_name}")
sys.exit(1)
else:
# Display general help information
# 显示通用帮助信息
click.echo(cli.get_help(ctx))
+43 -47
View File
@@ -10,61 +10,57 @@ from ..utils import check_astrbot_root, get_astrbot_root
def _validate_log_level(value: str) -> str:
"""Validate log level"""
"""验证日志级别"""
value = value.upper()
if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
raise click.ClickException(
"Log level must be one of DEBUG/INFO/WARNING/ERROR/CRITICAL",
"日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一",
)
return value
def _validate_dashboard_port(value: str) -> int:
"""Validate Dashboard port"""
"""验证 Dashboard 端口"""
try:
port = int(value)
if port < 1 or port > 65535:
raise click.ClickException("Port must be in range 1-65535")
raise click.ClickException("端口必须在 1-65535 范围内")
return port
except ValueError:
raise click.ClickException("Port must be a number")
raise click.ClickException("端口必须是数字")
def _validate_dashboard_username(value: str) -> str:
"""Validate Dashboard username"""
"""验证 Dashboard 用户名"""
if not value:
raise click.ClickException("Username cannot be empty")
raise click.ClickException("用户名不能为空")
return value
def _validate_dashboard_password(value: str) -> str:
"""Validate Dashboard password"""
"""验证 Dashboard 密码"""
if not value:
raise click.ClickException("Password cannot be empty")
raise click.ClickException("密码不能为空")
return hashlib.md5(value.encode()).hexdigest()
def _validate_timezone(value: str) -> str:
"""Validate timezone"""
"""验证时区"""
try:
zoneinfo.ZoneInfo(value)
except Exception:
raise click.ClickException(
f"Invalid timezone: {value}. Please use a valid IANA timezone name"
)
raise click.ClickException(f"无效的时区: {value},请使用有效的IANA时区名称")
return value
def _validate_callback_api_base(value: str) -> str:
"""Validate callback API base URL"""
"""验证回调接口基址"""
if not value.startswith("http://") and not value.startswith("https://"):
raise click.ClickException(
"Callback API base must start with http:// or https://"
)
raise click.ClickException("回调接口基址必须以 http:// 或 https:// 开头")
return value
# Configuration items settable via CLI, mapping config keys to validator functions
# 可通过CLI设置的配置项,配置键到验证器函数的映射
CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
"timezone": _validate_timezone,
"log_level": _validate_log_level,
@@ -76,11 +72,11 @@ CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
def _load_config() -> dict[str, Any]:
"""Load or initialize config file"""
"""加载或初始化配置文件"""
root = get_astrbot_root()
if not check_astrbot_root(root):
raise click.ClickException(
f"{root} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
)
config_path = root / "data" / "cmd_config.json"
@@ -95,11 +91,11 @@ def _load_config() -> dict[str, Any]:
try:
return json.loads(config_path.read_text(encoding="utf-8-sig"))
except json.JSONDecodeError as e:
raise click.ClickException(f"Failed to parse config file: {e!s}")
raise click.ClickException(f"配置文件解析失败: {e!s}")
def _save_config(config: dict[str, Any]) -> None:
"""Save config file"""
"""保存配置文件"""
config_path = get_astrbot_root() / "data" / "cmd_config.json"
config_path.write_text(
@@ -109,21 +105,21 @@ def _save_config(config: dict[str, Any]) -> None:
def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
"""Set a value in a nested dictionary"""
"""设置嵌套字典中的值"""
parts = path.split(".")
for part in parts[:-1]:
if part not in obj:
obj[part] = {}
elif not isinstance(obj[part], dict):
raise click.ClickException(
f"Config path conflict: {'.'.join(parts[: parts.index(part) + 1])} is not a dict",
f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典",
)
obj = obj[part]
obj[parts[-1]] = value
def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
"""Get a value from a nested dictionary"""
"""获取嵌套字典中的值"""
parts = path.split(".")
for part in parts:
obj = obj[part]
@@ -132,21 +128,21 @@ def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
@click.group(name="conf")
def conf() -> None:
"""Configuration management commands
"""配置管理命令
Supported config keys:
支持的配置项:
- timezone: Timezone setting (e.g. Asia/Shanghai)
- timezone: 时区设置 (例如: Asia/Shanghai)
- log_level: Log level (DEBUG/INFO/WARNING/ERROR/CRITICAL)
- log_level: 日志级别 (DEBUG/INFO/WARNING/ERROR/CRITICAL)
- dashboard.port: Dashboard port
- dashboard.port: Dashboard 端口
- dashboard.username: Dashboard username
- dashboard.username: Dashboard 用户名
- dashboard.password: Dashboard password
- dashboard.password: Dashboard 密码
- callback_api_base: Callback API base URL
- callback_api_base: 回调接口基址
"""
@@ -154,9 +150,9 @@ def conf() -> None:
@click.argument("key")
@click.argument("value")
def set_config(key: str, value: str) -> None:
"""Set the value of a config item"""
"""设置配置项的值"""
if key not in CONFIG_VALIDATORS:
raise click.ClickException(f"Unsupported config key: {key}")
raise click.ClickException(f"不支持的配置项: {key}")
config = _load_config()
@@ -166,29 +162,29 @@ def set_config(key: str, value: str) -> None:
_set_nested_item(config, key, validated_value)
_save_config(config)
click.echo(f"Config updated: {key}")
click.echo(f"配置已更新: {key}")
if key == "dashboard.password":
click.echo(" Old value: ********")
click.echo(" New value: ********")
click.echo(" 原值: ********")
click.echo(" 新值: ********")
else:
click.echo(f" Old value: {old_value}")
click.echo(f" New value: {validated_value}")
click.echo(f" 原值: {old_value}")
click.echo(f" 新值: {validated_value}")
except KeyError:
raise click.ClickException(f"Unknown config key: {key}")
raise click.ClickException(f"未知的配置项: {key}")
except Exception as e:
raise click.UsageError(f"Failed to set config: {e!s}")
raise click.UsageError(f"设置配置失败: {e!s}")
@conf.command(name="get")
@click.argument("key", required=False)
def get_config(key: str | None = None) -> None:
"""Get the value of a config item. If no key is provided, show all configurable items"""
"""获取配置项的值,不提供key则显示所有可配置项"""
config = _load_config()
if key:
if key not in CONFIG_VALIDATORS:
raise click.ClickException(f"Unsupported config key: {key}")
raise click.ClickException(f"不支持的配置项: {key}")
try:
value = _get_nested_item(config, key)
@@ -196,11 +192,11 @@ def get_config(key: str | None = None) -> None:
value = "********"
click.echo(f"{key}: {value}")
except KeyError:
raise click.ClickException(f"Unknown config key: {key}")
raise click.ClickException(f"未知的配置项: {key}")
except Exception as e:
raise click.UsageError(f"Failed to get config: {e!s}")
raise click.UsageError(f"获取配置失败: {e!s}")
else:
click.echo("Current config:")
click.echo("当前配置:")
for key in CONFIG_VALIDATORS:
try:
value = (
+9 -8
View File
@@ -8,12 +8,16 @@ from ..utils import check_dashboard, get_astrbot_root
async def initialize_astrbot(astrbot_root: Path) -> None:
"""Execute AstrBot initialization logic"""
"""执行 AstrBot 初始化逻辑"""
dot_astrbot = astrbot_root / ".astrbot"
if not dot_astrbot.exists():
click.echo(f"Current Directory: {astrbot_root}")
click.echo(
"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。",
)
if click.confirm(
f"Install AstrBot to this directory? {astrbot_root}",
f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",
default=True,
abort=True,
):
@@ -36,7 +40,7 @@ async def initialize_astrbot(astrbot_root: Path) -> None:
@click.command()
def init() -> None:
"""Initialize AstrBot"""
"""初始化 AstrBot"""
click.echo("Initializing AstrBot...")
astrbot_root = get_astrbot_root()
lock_file = astrbot_root / "astrbot.lock"
@@ -45,11 +49,8 @@ def init() -> None:
try:
with lock.acquire():
asyncio.run(initialize_astrbot(astrbot_root))
click.echo("Done! You can now run 'astrbot run' to start AstrBot")
except Timeout:
raise click.ClickException(
"Cannot acquire lock file. Please check if another instance is running"
)
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
except Exception as e:
raise click.ClickException(f"Initialization failed: {e!s}")
raise click.ClickException(f"初始化失败: {e!s}")
+46 -54
View File
@@ -16,14 +16,14 @@ from ..utils import (
@click.group()
def plug() -> None:
"""Plugin management"""
"""插件管理"""
def _get_data_path() -> Path:
base = get_astrbot_root()
if not check_astrbot_root(base):
raise click.ClickException(
f"{base} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
)
return (base / "data").resolve()
@@ -32,9 +32,7 @@ def display_plugins(plugins, title=None, color=None) -> None:
if title:
click.echo(click.style(title, fg=color, bold=True))
click.echo(
f"{'Name':<20} {'Version':<10} {'Status':<10} {'Author':<15} {'Description':<30}"
)
click.echo(f"{'名称':<20} {'版本':<10} {'状态':<10} {'作者':<15} {'描述':<30}")
click.echo("-" * 85)
for p in plugins:
@@ -48,30 +46,30 @@ def display_plugins(plugins, title=None, color=None) -> None:
@plug.command()
@click.argument("name")
def new(name: str) -> None:
"""Create a new plugin"""
"""创建新插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins" / name
if plug_path.exists():
raise click.ClickException(f"Plugin {name} already exists")
raise click.ClickException(f"插件 {name} 已存在")
author = click.prompt("Enter plugin author", type=str)
desc = click.prompt("Enter plugin description", type=str)
version = click.prompt("Enter plugin version", type=str)
author = click.prompt("请输入插件作者", type=str)
desc = click.prompt("请输入插件描述", type=str)
version = click.prompt("请输入插件版本", type=str)
if not re.match(r"^\d+\.\d+(\.\d+)?$", version.lower().lstrip("v")):
raise click.ClickException("Version must be in x.y or x.y.z format")
repo = click.prompt("Enter plugin repository URL:", type=str)
raise click.ClickException("版本号必须为 x.y x.y.z 格式")
repo = click.prompt("请输入插件仓库:", type=str)
if not repo.startswith("http"):
raise click.ClickException("Repository URL must start with http")
raise click.ClickException("仓库地址必须以 http 开头")
click.echo("Downloading plugin template...")
click.echo("下载插件模板...")
get_git_repo(
"https://github.com/Soulter/helloworld",
plug_path,
)
click.echo("Rewriting plugin metadata...")
# Rewrite metadata.yaml
click.echo("重写插件信息...")
# 重写 metadata.yaml
with open(plug_path / "metadata.yaml", "w", encoding="utf-8") as f:
f.write(
f"name: {name}\n"
@@ -81,13 +79,11 @@ def new(name: str) -> None:
f"repo: {repo}\n",
)
# Rewrite README.md
# 重写 README.md
with open(plug_path / "README.md", "w", encoding="utf-8") as f:
f.write(
f"# {name}\n\n{desc}\n\n# Support\n\n[Documentation](https://astrbot.app)\n"
)
f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n")
# Rewrite main.py
# 重写 main.py
with open(plug_path / "main.py", encoding="utf-8") as f:
content = f.read()
@@ -99,54 +95,54 @@ def new(name: str) -> None:
with open(plug_path / "main.py", "w", encoding="utf-8") as f:
f.write(new_content)
click.echo(f"Plugin {name} created successfully")
click.echo(f"插件 {name} 创建成功")
@plug.command()
@click.option("--all", "-a", is_flag=True, help="List uninstalled plugins")
@click.option("--all", "-a", is_flag=True, help="列出未安装的插件")
def list(all: bool) -> None:
"""List plugins"""
"""列出插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
# Unpublished plugins
# 未发布的插件
not_published_plugins = [
p for p in plugins if p["status"] == PluginStatus.NOT_PUBLISHED
]
if not_published_plugins:
display_plugins(not_published_plugins, "Unpublished Plugins", "red")
display_plugins(not_published_plugins, "未发布的插件", "red")
# Plugins needing update
# 需要更新的插件
need_update_plugins = [
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
]
if need_update_plugins:
display_plugins(need_update_plugins, "Plugins Needing Update", "yellow")
display_plugins(need_update_plugins, "需要更新的插件", "yellow")
# Installed plugins
# 已安装的插件
installed_plugins = [p for p in plugins if p["status"] == PluginStatus.INSTALLED]
if installed_plugins:
display_plugins(installed_plugins, "Installed Plugins", "green")
display_plugins(installed_plugins, "已安装的插件", "green")
# Uninstalled plugins
# 未安装的插件
not_installed_plugins = [
p for p in plugins if p["status"] == PluginStatus.NOT_INSTALLED
]
if not_installed_plugins and all:
display_plugins(not_installed_plugins, "Uninstalled Plugins", "blue")
display_plugins(not_installed_plugins, "未安装的插件", "blue")
if (
not any([not_published_plugins, need_update_plugins, installed_plugins])
and not all
):
click.echo("No plugins installed")
click.echo("未安装任何插件")
@plug.command()
@click.argument("name")
@click.option("--proxy", help="Proxy server address")
@click.option("--proxy", help="代理服务器地址")
def install(name: str, proxy: str | None) -> None:
"""Install a plugin"""
"""安装插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins"
plugins = build_plug_list(base_path / "plugins")
@@ -161,7 +157,7 @@ def install(name: str, proxy: str | None) -> None:
)
if not plugin:
raise click.ClickException(f"Plugin {name} not found or already installed")
raise click.ClickException(f"未找到可安装的插件 {name},可能是不存在或已安装")
manage_plugin(plugin, plug_path, is_update=False, proxy=proxy)
@@ -169,32 +165,30 @@ def install(name: str, proxy: str | None) -> None:
@plug.command()
@click.argument("name")
def remove(name: str) -> None:
"""Uninstall a plugin"""
"""卸载插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
plugin = next((p for p in plugins if p["name"] == name), None)
if not plugin or not plugin.get("local_path"):
raise click.ClickException(f"Plugin {name} does not exist or is not installed")
raise click.ClickException(f"插件 {name} 不存在或未安装")
plugin_path = plugin["local_path"]
click.confirm(
f"Are you sure you want to uninstall plugin {name}?", default=False, abort=True
)
click.confirm(f"确定要卸载插件 {name} 吗?", default=False, abort=True)
try:
shutil.rmtree(plugin_path)
click.echo(f"Plugin {name} has been uninstalled")
click.echo(f"插件 {name} 已卸载")
except Exception as e:
raise click.ClickException(f"Failed to uninstall plugin {name}: {e}")
raise click.ClickException(f"卸载插件 {name} 失败: {e}")
@plug.command()
@click.argument("name", required=False)
@click.option("--proxy", help="GitHub proxy address")
@click.option("--proxy", help="Github代理地址")
def update(name: str, proxy: str | None) -> None:
"""Update plugins"""
"""更新插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins"
plugins = build_plug_list(base_path / "plugins")
@@ -210,9 +204,7 @@ def update(name: str, proxy: str | None) -> None:
)
if not plugin:
raise click.ClickException(
f"Plugin {name} does not need updating or cannot be updated"
)
raise click.ClickException(f"插件 {name} 不需要更新或无法更新")
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
else:
@@ -221,20 +213,20 @@ def update(name: str, proxy: str | None) -> None:
]
if not need_update_plugins:
click.echo("No plugins need updating")
click.echo("没有需要更新的插件")
return
click.echo(f"Found {len(need_update_plugins)} plugin(s) needing update")
click.echo(f"发现 {len(need_update_plugins)} 个插件需要更新")
for plugin in need_update_plugins:
plugin_name = plugin["name"]
click.echo(f"Updating plugin {plugin_name}...")
click.echo(f"正在更新插件 {plugin_name}...")
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
@plug.command()
@click.argument("query")
def search(query: str) -> None:
"""Search for plugins"""
"""搜索插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
@@ -247,7 +239,7 @@ def search(query: str) -> None:
]
if not matched_plugins:
click.echo(f"No plugins matching '{query}' found")
click.echo(f"未找到匹配 '{query}' 的插件")
return
display_plugins(matched_plugins, f"Search results: '{query}'", "cyan")
display_plugins(matched_plugins, f"搜索结果: '{query}'", "cyan")
+9 -11
View File
@@ -11,7 +11,7 @@ from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root
async def run_astrbot(astrbot_root: Path) -> None:
"""Run AstrBot"""
"""运行 AstrBot"""
from astrbot.core import LogBroker, LogManager, db_helper, logger
from astrbot.core.initial_loader import InitialLoader
@@ -26,18 +26,18 @@ async def run_astrbot(astrbot_root: Path) -> None:
await core_lifecycle.start()
@click.option("--reload", "-r", is_flag=True, help="Auto-reload plugins")
@click.option("--port", "-p", help="AstrBot Dashboard port", required=False, type=str)
@click.option("--reload", "-r", is_flag=True, help="插件自动重载")
@click.option("--port", "-p", help="Astrbot Dashboard端口", required=False, type=str)
@click.command()
def run(reload: bool, port: str) -> None:
"""Run AstrBot"""
"""运行 AstrBot"""
try:
os.environ["ASTRBOT_CLI"] = "1"
astrbot_root = get_astrbot_root()
if not check_astrbot_root(astrbot_root):
raise click.ClickException(
f"{astrbot_root} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
)
os.environ["ASTRBOT_ROOT"] = str(astrbot_root)
@@ -47,7 +47,7 @@ def run(reload: bool, port: str) -> None:
os.environ["DASHBOARD_PORT"] = port
if reload:
click.echo("Plugin auto-reload enabled")
click.echo("启用插件自动重载")
os.environ["ASTRBOT_RELOAD"] = "1"
lock_file = astrbot_root / "astrbot.lock"
@@ -55,10 +55,8 @@ def run(reload: bool, port: str) -> None:
with lock.acquire():
asyncio.run(run_astrbot(astrbot_root))
except KeyboardInterrupt:
click.echo("AstrBot has been shut down.")
click.echo("AstrBot 已关闭...")
except Timeout:
raise click.ClickException(
"Cannot acquire lock file. Please check if another instance is running"
)
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
except Exception as e:
raise click.ClickException(f"Runtime error: {e}\n{traceback.format_exc()}")
raise click.ClickException(f"运行时出现错误: {e}\n{traceback.format_exc()}")
+13 -21
View File
@@ -2,12 +2,9 @@ from pathlib import Path
import click
# Static assets bundled inside the installed wheel (built by hatch_build.py).
_BUNDLED_DIST = Path(__file__).parent.parent.parent / "dashboard" / "dist"
def check_astrbot_root(path: str | Path) -> bool:
"""Check if the path is an AstrBot root directory"""
"""检查路径是否为 AstrBot 根目录"""
if not isinstance(path, Path):
path = Path(path)
if not path.exists() or not path.is_dir():
@@ -18,48 +15,43 @@ def check_astrbot_root(path: str | Path) -> bool:
def get_astrbot_root() -> Path:
"""Get the AstrBot root directory path"""
"""获取Astrbot根目录路径"""
return Path.cwd()
async def check_dashboard(astrbot_root: Path) -> None:
"""Check if the dashboard is installed"""
"""检查是否安装了dashboard"""
from astrbot.core.config.default import VERSION
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
from .version_comparator import VersionComparator
# If the wheel ships bundled dashboard assets, no network download is needed.
if _BUNDLED_DIST.exists():
click.echo("Dashboard is bundled with the package skipping download.")
return
try:
dashboard_version = await get_dashboard_version()
match dashboard_version:
case None:
click.echo("Dashboard is not installed")
click.echo("未安装管理面板")
if click.confirm(
"Install dashboard?",
"是否安装管理面板?",
default=True,
abort=True,
):
click.echo("Installing dashboard...")
click.echo("正在安装管理面板...")
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
)
click.echo("Dashboard installed successfully")
click.echo("管理面板安装完成")
case str():
if VersionComparator.compare_version(VERSION, dashboard_version) <= 0:
click.echo("Dashboard is already up to date")
click.echo("管理面板已是最新版本")
return
try:
version = dashboard_version.split("v")[1]
click.echo(f"Dashboard version: {version}")
click.echo(f"管理面板版本: {version}")
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root),
@@ -67,10 +59,10 @@ async def check_dashboard(astrbot_root: Path) -> None:
latest=False,
)
except Exception as e:
click.echo(f"Failed to download dashboard: {e}")
click.echo(f"下载管理面板失败: {e}")
return
except FileNotFoundError:
click.echo("Initializing dashboard directory...")
click.echo("初始化管理面板目录...")
try:
await download_dashboard(
path=str(astrbot_root / "dashboard.zip"),
@@ -78,7 +70,7 @@ async def check_dashboard(astrbot_root: Path) -> None:
version=f"v{VERSION}",
latest=False,
)
click.echo("Dashboard initialized successfully")
click.echo("管理面板初始化完成")
except Exception as e:
click.echo(f"Failed to download dashboard: {e}")
click.echo(f"下载管理面板失败: {e}")
return
+43 -47
View File
@@ -13,22 +13,22 @@ from .version_comparator import VersionComparator
class PluginStatus(str, Enum):
INSTALLED = "installed"
NEED_UPDATE = "needs-update"
NOT_INSTALLED = "not-installed"
NOT_PUBLISHED = "unpublished"
INSTALLED = "已安装"
NEED_UPDATE = "需更新"
NOT_INSTALLED = "未安装"
NOT_PUBLISHED = "未发布"
def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
"""Download code from a Git repository and extract to the specified path"""
"""从 Git 仓库下载代码并解压到指定路径"""
temp_dir = Path(tempfile.mkdtemp())
try:
# Parse repository info
# 解析仓库信息
repo_namespace = url.split("/")[-2:]
author = repo_namespace[0]
repo = repo_namespace[1]
# Try to get the latest release
# 尝试获取最新的 release
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
try:
with httpx.Client(
@@ -40,21 +40,21 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
releases = resp.json()
if releases:
# Use the latest release
# 使用最新的 release
download_url = releases[0]["zipball_url"]
else:
# No release found, use default branch
click.echo(f"Downloading {author}/{repo} from default branch")
# 没有 release,使用默认分支
click.echo(f"正在从默认分支下载 {author}/{repo}")
download_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
except Exception as e:
click.echo(f"Failed to get release info: {e}. Using provided URL directly")
click.echo(f"获取 release 信息失败: {e},将直接使用提供的 URL")
download_url = url
# Apply proxy
# 应用代理
if proxy:
download_url = f"{proxy}/{download_url}"
# Download and extract
# 下载并解压
with httpx.Client(
proxy=proxy if proxy else None,
follow_redirects=True,
@@ -65,7 +65,7 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
and "archive/refs/heads/master.zip" in download_url
):
alt_url = download_url.replace("master.zip", "main.zip")
click.echo("Branch 'master' not found, trying 'main' branch")
click.echo("master 分支不存在,尝试下载 main 分支")
resp = client.get(alt_url)
resp.raise_for_status()
else:
@@ -84,13 +84,13 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
def load_yaml_metadata(plugin_dir: Path) -> dict:
"""Load plugin metadata from metadata.yaml file
""" metadata.yaml 文件加载插件元数据
Args:
plugin_dir: Plugin directory path
plugin_dir: 插件目录路径
Returns:
dict: Dictionary containing metadata, or empty dict if loading fails
dict: 包含元数据的字典,如果读取失败则返回空字典
"""
yaml_path = plugin_dir / "metadata.yaml"
@@ -98,33 +98,33 @@ def load_yaml_metadata(plugin_dir: Path) -> dict:
try:
return yaml.safe_load(yaml_path.read_text(encoding="utf-8")) or {}
except Exception as e:
click.echo(f"Failed to read {yaml_path}: {e}", err=True)
click.echo(f"读取 {yaml_path} 失败: {e}", err=True)
return {}
def build_plug_list(plugins_dir: Path) -> list:
"""Build plugin list containing local and online plugin information
"""构建插件列表,包含本地和在线插件信息
Args:
plugins_dir (Path): Plugin directory path
plugins_dir (Path): 插件目录路径
Returns:
list: List of dicts containing plugin information
list: 包含插件信息的字典列表
"""
# Get local plugin info
# 获取本地插件信息
result = []
if plugins_dir.exists():
for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]:
plugin_dir = plugins_dir / plugin_name
# Load metadata from metadata.yaml
# metadata.yaml 加载元数据
metadata = load_yaml_metadata(plugin_dir)
if "desc" not in metadata and "description" in metadata:
metadata["desc"] = metadata["description"]
# If metadata loaded successfully, add to result list
# 如果成功加载元数据,添加到结果列表
if metadata and all(
k in metadata for k in ["name", "desc", "version", "author", "repo"]
):
@@ -140,7 +140,7 @@ def build_plug_list(plugins_dir: Path) -> list:
},
)
# Get online plugin list
# 获取在线插件列表
online_plugins = []
try:
with httpx.Client() as client:
@@ -160,13 +160,13 @@ def build_plug_list(plugins_dir: Path) -> list:
},
)
except Exception as e:
click.echo(f"Failed to get online plugin list: {e}", err=True)
click.echo(f"获取在线插件列表失败: {e}", err=True)
# Compare with online plugins and update status
# 与在线插件比对,更新状态
online_plugin_names = {plugin["name"] for plugin in online_plugins}
for local_plugin in result:
if local_plugin["name"] in online_plugin_names:
# Find the corresponding online plugin
# 查找对应的在线插件
online_plugin = next(
p for p in online_plugins if p["name"] == local_plugin["name"]
)
@@ -179,10 +179,10 @@ def build_plug_list(plugins_dir: Path) -> list:
):
local_plugin["status"] = PluginStatus.NEED_UPDATE
else:
# Local plugin is not published online
# 本地插件未在线上发布
local_plugin["status"] = PluginStatus.NOT_PUBLISHED
# Add uninstalled online plugins
# 添加未安装的在线插件
for online_plugin in online_plugins:
if not any(plugin["name"] == online_plugin["name"] for plugin in result):
result.append(online_plugin)
@@ -196,19 +196,19 @@ def manage_plugin(
is_update: bool = False,
proxy: str | None = None,
) -> None:
"""Install or update a plugin
"""安装或更新插件
Args:
plugin (dict): Plugin info dict
plugins_dir (Path): Plugins directory
is_update (bool, optional): Whether this is an update operation. Defaults to False
proxy (str, optional): Proxy server address
plugin (dict): 插件信息字典
plugins_dir (Path): 插件目录
is_update (bool, optional): 是否为更新操作. 默认为 False
proxy (str, optional): 代理服务器地址
"""
plugin_name = plugin["name"]
repo_url = plugin["repo"]
# If updating and local path exists, use it directly
# 如果是更新且有本地路径,直接使用本地路径
if is_update and plugin.get("local_path"):
target_path = Path(plugin["local_path"])
else:
@@ -216,13 +216,11 @@ def manage_plugin(
backup_path = Path(f"{target_path}_backup") if is_update else None
# Check if plugin exists
# 检查插件是否存在
if is_update and not target_path.exists():
raise click.ClickException(
f"Plugin {plugin_name} is not installed and cannot be updated"
)
raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")
# Backup existing plugin
# 备份现有插件
if is_update and backup_path is not None and backup_path.exists():
shutil.rmtree(backup_path)
if is_update and backup_path is not None:
@@ -230,21 +228,19 @@ def manage_plugin(
try:
click.echo(
f"{'Updating' if is_update else 'Downloading'} plugin {plugin_name} from {repo_url}...",
f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}...",
)
get_git_repo(repo_url, target_path, proxy)
# Update succeeded, delete backup
# 更新成功,删除备份
if is_update and backup_path is not None and backup_path.exists():
shutil.rmtree(backup_path)
click.echo(
f"Plugin {plugin_name} {'updated' if is_update else 'installed'} successfully"
)
click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
except Exception as e:
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
if is_update and backup_path is not None and backup_path.exists():
shutil.move(backup_path, target_path)
raise click.ClickException(
f"Error {'updating' if is_update else 'installing'} plugin {plugin_name}: {e}",
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}",
)
+11 -11
View File
@@ -1,4 +1,4 @@
"""Copied from astrbot.core.utils.version_comparator"""
"""拷贝自 astrbot.core.utils.version_comparator"""
import re
@@ -6,11 +6,11 @@ import re
class VersionComparator:
@staticmethod
def compare_version(v1: str, v2: str) -> int:
"""Compare version numbers according to Semver semantics. Supports version numbers with more than 3 digits and handles pre-release tags.
"""根据 Semver 语义版本规范来比较版本号的大小。支持不仅局限于 3 个数字的版本号,并处理预发布标签。
Reference: https://semver.org/
参考: https://semver.org/lang/zh-CN/
Returns 1 if v1 > v2, -1 if v1 < v2, 0 if v1 == v2.
返回 1 表示 v1 > v2,返回 -1 表示 v1 < v2,返回 0 表示 v1 = v2
"""
v1 = v1.lower().replace("v", "")
v2 = v2.lower().replace("v", "")
@@ -24,7 +24,7 @@ class VersionComparator:
return [], None
major_minor_patch = match.group(1).split(".")
prerelease = match.group(2)
# buildmetadata = match.group(3) # Build metadata is ignored in comparison
# buildmetadata = match.group(3) # 构建元数据在比较时忽略
parts = [int(x) for x in major_minor_patch]
prerelease = VersionComparator._split_prerelease(prerelease)
return parts, prerelease
@@ -32,7 +32,7 @@ class VersionComparator:
v1_parts, v1_prerelease = split_version(v1)
v2_parts, v2_prerelease = split_version(v2)
# Compare numeric parts
# 比较数字部分
length = max(len(v1_parts), len(v2_parts))
v1_parts.extend([0] * (length - len(v1_parts)))
v2_parts.extend([0] * (length - len(v2_parts)))
@@ -43,11 +43,11 @@ class VersionComparator:
if v1_parts[i] < v2_parts[i]:
return -1
# Compare pre-release tags
# 比较预发布标签
if v1_prerelease is None and v2_prerelease is not None:
return 1 # Version without pre-release tag is higher than one with it
return 1 # 没有预发布标签的版本高于有预发布标签的版本
if v1_prerelease is not None and v2_prerelease is None:
return -1 # Version with pre-release tag is lower than one without it
return -1 # 有预发布标签的版本低于没有预发布标签的版本
if v1_prerelease is not None and v2_prerelease is not None:
len_pre = max(len(v1_prerelease), len(v2_prerelease))
for i in range(len_pre):
@@ -72,9 +72,9 @@ class VersionComparator:
return 1
if p1 < p2:
return -1
return 0 # Pre-release tags are identical
return 0 # 预发布标签完全相同
return 0 # Both numeric parts and pre-release tags are equal
return 0 # 数字部分和预发布标签都相同
@staticmethod
def _split_prerelease(prerelease):
+1 -1
View File
@@ -14,7 +14,7 @@ from .utils.astrbot_path import get_astrbot_data_path
# 初始化数据存储文件夹
os.makedirs(get_astrbot_data_path(), exist_ok=True)
DEMO_MODE = os.getenv("DEMO_MODE", "False").strip().lower() in ("true", "1", "t")
DEMO_MODE = os.getenv("DEMO_MODE", False)
astrbot_config = AstrBotConfig()
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
+3 -5
View File
@@ -291,9 +291,6 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
except Exception:
continue
prov_settings: dict = ctx.get_config(umo=umo).get("provider_settings", {})
agent_max_step = int(prov_settings.get("max_agent_step", 30))
stream = prov_settings.get("streaming_response", False)
llm_resp = await ctx.tool_loop_agent(
event=event,
chat_provider_id=prov_id,
@@ -302,8 +299,9 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
system_prompt=tool.agent.instructions,
tools=toolset,
contexts=contexts,
max_steps=agent_max_step,
stream=stream,
max_steps=30,
run_hooks=tool.agent.run_hooks,
stream=ctx.get_config().get("provider_settings", {}).get("stream", False),
)
yield mcp.types.CallToolResult(
content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)]
+2 -72
View File
@@ -20,32 +20,18 @@ from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
from astrbot.core.astr_agent_run_util import AgentRunner
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
from astrbot.core.astr_main_agent_resources import (
ANNOTATE_EXECUTION_TOOL,
BROWSER_BATCH_EXEC_TOOL,
BROWSER_EXEC_TOOL,
CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT,
CREATE_SKILL_CANDIDATE_TOOL,
CREATE_SKILL_PAYLOAD_TOOL,
EVALUATE_SKILL_CANDIDATE_TOOL,
EXECUTE_SHELL_TOOL,
FILE_DOWNLOAD_TOOL,
FILE_UPLOAD_TOOL,
GET_EXECUTION_HISTORY_TOOL,
GET_SKILL_PAYLOAD_TOOL,
KNOWLEDGE_BASE_QUERY_TOOL,
LIST_SKILL_CANDIDATES_TOOL,
LIST_SKILL_RELEASES_TOOL,
LIVE_MODE_SYSTEM_PROMPT,
LLM_SAFETY_MODE_SYSTEM_PROMPT,
LOCAL_EXECUTE_SHELL_TOOL,
LOCAL_PYTHON_TOOL,
PROMOTE_SKILL_CANDIDATE_TOOL,
PYTHON_TOOL,
ROLLBACK_SKILL_RELEASE_TOOL,
RUN_BROWSER_SKILL_TOOL,
SANDBOX_MODE_PROMPT,
SEND_MESSAGE_TO_USER_TOOL,
SYNC_SKILL_RELEASE_TOOL,
TOOL_CALL_PROMPT,
TOOL_CALL_PROMPT_SKILLS_LIKE_MODE,
retrieve_knowledge_base,
@@ -846,10 +832,7 @@ def _apply_sandbox_tools(
) -> None:
if req.func_tool is None:
req.func_tool = ToolSet()
if req.system_prompt is None:
req.system_prompt = ""
booter = config.sandbox_cfg.get("booter", "shipyard_neo")
if booter == "shipyard":
if config.sandbox_cfg.get("booter") == "shipyard":
ep = config.sandbox_cfg.get("shipyard_endpoint", "")
at = config.sandbox_cfg.get("shipyard_access_token", "")
if not ep or not at:
@@ -857,64 +840,11 @@ def _apply_sandbox_tools(
return
os.environ["SHIPYARD_ENDPOINT"] = ep
os.environ["SHIPYARD_ACCESS_TOKEN"] = at
req.func_tool.add_tool(EXECUTE_SHELL_TOOL)
req.func_tool.add_tool(PYTHON_TOOL)
req.func_tool.add_tool(FILE_UPLOAD_TOOL)
req.func_tool.add_tool(FILE_DOWNLOAD_TOOL)
if booter == "shipyard_neo":
# Neo-specific path rule: filesystem tools operate relative to sandbox
# workspace root. Do not prepend "/workspace".
req.system_prompt += (
"\n[Shipyard Neo File Path Rule]\n"
"When using sandbox filesystem tools (upload/download/read/write/list/delete), "
"always pass paths relative to the sandbox workspace root. "
"Example: use `baidu_homepage.png` instead of `/workspace/baidu_homepage.png`.\n"
)
req.system_prompt += (
"\n[Neo Skill Lifecycle Workflow]\n"
"When user asks to create/update a reusable skill in Neo mode, use lifecycle tools instead of directly writing local skill folders.\n"
"Preferred sequence:\n"
"1) Use `astrbot_create_skill_payload` to store canonical payload content and get `payload_ref`.\n"
"2) Use `astrbot_create_skill_candidate` with `skill_key` + `source_execution_ids` (and optional `payload_ref`) to create a candidate.\n"
"3) Use `astrbot_promote_skill_candidate` to release: `stage=canary` for trial; `stage=stable` for production.\n"
"For stable release, set `sync_to_local=true` to sync `payload.skill_markdown` into local `SKILL.md`.\n"
"Do not treat ad-hoc generated files as reusable Neo skills unless they are captured via payload/candidate/release.\n"
"To update an existing skill, create a new payload/candidate and promote a new release version; avoid patching old local folders directly.\n"
)
# Determine sandbox capabilities from an already-booted session.
# If no session exists yet (first request), capabilities is None
# and we register all tools conservatively.
from astrbot.core.computer.computer_client import session_booter
sandbox_capabilities: list[str] | None = None
existing_booter = session_booter.get(session_id)
if existing_booter is not None:
sandbox_capabilities = getattr(existing_booter, "capabilities", None)
# Browser tools: only register if profile supports browser
# (or if capabilities are unknown because sandbox hasn't booted yet)
if sandbox_capabilities is None or "browser" in sandbox_capabilities:
req.func_tool.add_tool(BROWSER_EXEC_TOOL)
req.func_tool.add_tool(BROWSER_BATCH_EXEC_TOOL)
req.func_tool.add_tool(RUN_BROWSER_SKILL_TOOL)
# Neo-specific tools (always available for shipyard_neo)
req.func_tool.add_tool(GET_EXECUTION_HISTORY_TOOL)
req.func_tool.add_tool(ANNOTATE_EXECUTION_TOOL)
req.func_tool.add_tool(CREATE_SKILL_PAYLOAD_TOOL)
req.func_tool.add_tool(GET_SKILL_PAYLOAD_TOOL)
req.func_tool.add_tool(CREATE_SKILL_CANDIDATE_TOOL)
req.func_tool.add_tool(LIST_SKILL_CANDIDATES_TOOL)
req.func_tool.add_tool(EVALUATE_SKILL_CANDIDATE_TOOL)
req.func_tool.add_tool(PROMOTE_SKILL_CANDIDATE_TOOL)
req.func_tool.add_tool(LIST_SKILL_RELEASES_TOOL)
req.func_tool.add_tool(ROLLBACK_SKILL_RELEASE_TOOL)
req.func_tool.add_tool(SYNC_SKILL_RELEASE_TOOL)
req.system_prompt = f"{req.system_prompt or ''}\n{SANDBOX_MODE_PROMPT}\n"
req.system_prompt = f"{req.system_prompt}\n{SANDBOX_MODE_PROMPT}\n"
def _proactive_cron_job_tools(req: ProviderRequest) -> None:
-28
View File
@@ -13,25 +13,11 @@ from astrbot.core.agent.tool import FunctionTool, ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.computer.computer_client import get_booter
from astrbot.core.computer.tools import (
AnnotateExecutionTool,
BrowserBatchExecTool,
BrowserExecTool,
CreateSkillCandidateTool,
CreateSkillPayloadTool,
EvaluateSkillCandidateTool,
ExecuteShellTool,
FileDownloadTool,
FileUploadTool,
GetExecutionHistoryTool,
GetSkillPayloadTool,
ListSkillCandidatesTool,
ListSkillReleasesTool,
LocalPythonTool,
PromoteSkillCandidateTool,
PythonTool,
RollbackSkillReleaseTool,
RunBrowserSkillTool,
SyncSkillReleaseTool,
)
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.message_session import MessageSession
@@ -463,20 +449,6 @@ PYTHON_TOOL = PythonTool()
LOCAL_PYTHON_TOOL = LocalPythonTool()
FILE_UPLOAD_TOOL = FileUploadTool()
FILE_DOWNLOAD_TOOL = FileDownloadTool()
BROWSER_EXEC_TOOL = BrowserExecTool()
BROWSER_BATCH_EXEC_TOOL = BrowserBatchExecTool()
RUN_BROWSER_SKILL_TOOL = RunBrowserSkillTool()
GET_EXECUTION_HISTORY_TOOL = GetExecutionHistoryTool()
ANNOTATE_EXECUTION_TOOL = AnnotateExecutionTool()
CREATE_SKILL_PAYLOAD_TOOL = CreateSkillPayloadTool()
GET_SKILL_PAYLOAD_TOOL = GetSkillPayloadTool()
CREATE_SKILL_CANDIDATE_TOOL = CreateSkillCandidateTool()
LIST_SKILL_CANDIDATES_TOOL = ListSkillCandidatesTool()
EVALUATE_SKILL_CANDIDATE_TOOL = EvaluateSkillCandidateTool()
PROMOTE_SKILL_CANDIDATE_TOOL = PromoteSkillCandidateTool()
LIST_SKILL_RELEASES_TOOL = ListSkillReleasesTool()
ROLLBACK_SKILL_RELEASE_TOOL = RollbackSkillReleaseTool()
SYNC_SKILL_RELEASE_TOOL = SyncSkillReleaseTool()
# we prevent astrbot from connecting to known malicious hosts
# these hosts are base64 encoded
+1 -19
View File
@@ -1,9 +1,4 @@
from ..olayer import (
BrowserComponent,
FileSystemComponent,
PythonComponent,
ShellComponent,
)
from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
class ComputerBooter:
@@ -16,19 +11,6 @@ class ComputerBooter:
@property
def shell(self) -> ShellComponent: ...
@property
def capabilities(self) -> tuple[str, ...] | None:
"""Sandbox capabilities (e.g. ('python', 'shell', 'filesystem', 'browser')).
Returns None if the booter doesn't support capability introspection
(backward-compatible default). Subclasses override after boot.
"""
return None
@property
def browser(self) -> BrowserComponent | None:
return None
async def boot(self, session_id: str) -> None: ...
async def shutdown(self) -> None: ...
@@ -1,258 +0,0 @@
"""Manage Bay container lifecycle for zero-config Shipyard Neo integration.
When no Bay endpoint is configured, AstrBot can automatically start a Bay
container using the Docker socket (like BoxliteBooter does for Ship
containers).
"""
from __future__ import annotations
import asyncio
import io
import json
import tarfile
from typing import Any
import aiodocker
import aiohttp
from astrbot.api import logger
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
BAY_IMAGE = "ghcr.io/astrbotdevs/shipyard-neo-bay:latest"
BAY_CONTAINER_NAME = "astrbot-bay"
BAY_LABEL = "astrbot.bay.managed"
BAY_PORT = 8114
HEALTH_TIMEOUT_S = 60
HEALTH_POLL_INTERVAL_S = 2
class BayContainerManager:
"""Start / reuse / stop a Bay container via Docker Engine API."""
def __init__(
self,
image: str = BAY_IMAGE,
host_port: int = BAY_PORT,
) -> None:
self._image = image
self._host_port = host_port
self._docker: aiodocker.Docker | None = None
self._container: Any = None
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
async def ensure_running(self) -> str:
"""Make sure a Bay container is running. Returns the endpoint URL.
If a container labelled ``astrbot.bay.managed`` already exists
and is running, it will be reused. Otherwise a new container is
created from *self._image*.
"""
try:
self._docker = aiodocker.Docker()
except Exception as exc:
raise RuntimeError(
"Failed to connect to Docker daemon. "
"Ensure Docker is installed and running, or configure "
"an explicit Bay endpoint instead of auto-start mode."
) from exc
# 1. Look for an existing managed container
existing = await self._find_managed_container()
if existing is not None:
state = existing["State"]
if state.get("Running"):
cid = existing["Id"][:12]
logger.info("[BayManager] Reusing existing Bay container: %s", cid)
self._container = await self._docker.containers.get(existing["Id"])
return f"http://127.0.0.1:{self._host_port}"
else:
# Container exists but stopped — restart it
logger.info("[BayManager] Restarting stopped Bay container")
container = await self._docker.containers.get(existing["Id"])
await container.start()
self._container = container
return f"http://127.0.0.1:{self._host_port}"
# 2. Pull image if needed
await self._pull_image_if_needed()
# 3. Create and start container
logger.info(
"[BayManager] Starting Bay container: image=%s, port=%d",
self._image,
self._host_port,
)
config = {
"Image": self._image,
"Labels": {BAY_LABEL: "true"},
"Env": [
"BAY_SERVER__HOST=0.0.0.0",
f"BAY_SERVER__PORT={BAY_PORT}",
"BAY_DATA_DIR=/app/data",
# allow_anonymous=false → auto-provisions API key
"BAY_SECURITY__ALLOW_ANONYMOUS=false",
],
"HostConfig": {
"PortBindings": {
f"{BAY_PORT}/tcp": [{"HostPort": str(self._host_port)}],
},
"Binds": [
# Bay needs Docker socket to create sandbox containers
"/var/run/docker.sock:/var/run/docker.sock",
],
"RestartPolicy": {"Name": "unless-stopped"},
},
}
self._container = await self._docker.containers.create_or_replace(
BAY_CONTAINER_NAME, config
)
await self._container.start()
logger.info("[BayManager] Bay container started: %s", BAY_CONTAINER_NAME)
return f"http://127.0.0.1:{self._host_port}"
async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None:
"""Block until Bay's ``/health`` endpoint returns 200."""
url = f"http://127.0.0.1:{self._host_port}/health"
deadline = asyncio.get_event_loop().time() + timeout
last_error: str = ""
async with aiohttp.ClientSession() as session:
while asyncio.get_event_loop().time() < deadline:
try:
async with session.get(
url, timeout=aiohttp.ClientTimeout(total=3)
) as resp:
if resp.status == 200:
logger.info("[BayManager] Bay is healthy")
return
last_error = f"HTTP {resp.status}"
except Exception as exc:
last_error = str(exc)
await asyncio.sleep(HEALTH_POLL_INTERVAL_S)
raise TimeoutError(
f"Bay did not become healthy within {timeout}s (last error: {last_error})"
)
async def read_credentials(self) -> str:
"""Read auto-provisioned API key from Bay container.
Bay writes ``credentials.json`` to its data directory when
``allow_anonymous=false`` and no explicit API key is set.
"""
if self._container is None:
return ""
try:
# Read credentials.json from container filesystem
tar_stream = await self._container.get_archive("/app/data/credentials.json")
# get_archive returns (tar_data, stat)
tar_data = tar_stream
if isinstance(tar_data, dict):
raw = tar_data.get("data", b"")
elif isinstance(tar_data, tuple):
# (stream, stat_info)
raw = b""
stream = tar_data[0]
if hasattr(stream, "read"):
raw = await stream.read()
elif isinstance(stream, bytes):
raw = stream
else:
# It might be a chunked response
chunks = []
async for chunk in stream:
chunks.append(chunk)
raw = b"".join(chunks)
else:
raw = tar_data if isinstance(tar_data, bytes) else b""
if not raw:
logger.debug("[BayManager] Empty tar response from container")
return ""
tario = io.BytesIO(raw)
with tarfile.open(fileobj=tario) as tar:
for member in tar.getmembers():
f = tar.extractfile(member)
if f:
creds = json.loads(f.read().decode("utf-8"))
api_key = creds.get("api_key", "")
if api_key:
masked = (
f"{api_key[:8]}..."
if len(api_key) >= 10
else "redacted"
)
logger.info(
"[BayManager] Auto-discovered Bay API key: %s",
masked,
)
return api_key
except Exception as exc:
logger.debug(
"[BayManager] Failed to read credentials from container: %s", exc
)
return ""
async def close_client(self) -> None:
"""Close the Docker client without stopping the container.
The Bay container stays running for reuse by future sessions.
"""
if self._docker is not None:
await self._docker.close()
self._docker = None
async def stop(self) -> None:
"""Stop and remove the managed Bay container."""
if self._container is not None:
try:
await self._container.stop()
await self._container.delete(force=True)
logger.info("[BayManager] Bay container stopped and removed")
except Exception as exc:
logger.debug("[BayManager] Error stopping Bay container: %s", exc)
finally:
self._container = None
await self.close_client()
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
async def _find_managed_container(self) -> dict | None:
"""Find an existing container with our management label."""
assert self._docker is not None
containers = await self._docker.containers.list(
all=True,
filters=json.dumps({"label": [f"{BAY_LABEL}=true"]}),
)
if containers:
# Inspect first match to get full state
return await containers[0].show()
return None
async def _pull_image_if_needed(self) -> None:
"""Pull the Bay image if it doesn't exist locally."""
assert self._docker is not None
try:
await self._docker.images.inspect(self._image)
logger.debug("[BayManager] Image %s already exists", self._image)
except aiodocker.exceptions.DockerError:
logger.info("[BayManager] Pulling image %s ...", self._image)
# Pull with progress logging
await self._docker.images.pull(self._image)
logger.info("[BayManager] Image %s pulled successfully", self._image)
-4
View File
@@ -64,10 +64,6 @@ class MockShipyardSandboxClient:
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(url, data=data) as response:
if response.status == 200:
logger.info(
"[Computer] File uploaded to Boxlite sandbox: %s",
remote_path,
)
return {
"success": True,
"message": "File uploaded successfully",
+3 -20
View File
@@ -31,7 +31,7 @@ class ShipyardBooter(ComputerBooter):
self._ship = ship
async def shutdown(self) -> None:
logger.info("[Computer] Shipyard booter shutdown.")
pass
@property
def fs(self) -> FileSystemComponent:
@@ -47,19 +47,11 @@ class ShipyardBooter(ComputerBooter):
async def upload_file(self, path: str, file_name: str) -> dict:
"""Upload file to sandbox"""
result = await self._ship.upload_file(path, file_name)
logger.info("[Computer] File uploaded to Shipyard sandbox: %s", file_name)
return result
return await self._ship.upload_file(path, file_name)
async def download_file(self, remote_path: str, local_path: str):
"""Download file from sandbox."""
result = await self._ship.download_file(remote_path, local_path)
logger.info(
"[Computer] File downloaded from Shipyard sandbox: %s -> %s",
remote_path,
local_path,
)
return result
return await self._ship.download_file(remote_path, local_path)
async def available(self) -> bool:
"""Check if the sandbox is available."""
@@ -67,17 +59,8 @@ class ShipyardBooter(ComputerBooter):
ship_id = self._ship.id
data = await self._sandbox_client.get_ship(ship_id)
if not data:
logger.info(
"[Computer] Shipyard sandbox health check: id=%s, healthy=False (no data)",
ship_id,
)
return False
health = bool(data.get("status", 0) == 1)
logger.info(
"[Computer] Shipyard sandbox health check: id=%s, healthy=%s",
ship_id,
health,
)
return health
except Exception as e:
logger.error(f"Error checking Shipyard sandbox availability: {e}")
@@ -1,513 +0,0 @@
from __future__ import annotations
import os
import shlex
from typing import Any, cast
from astrbot.api import logger
from ..olayer import (
BrowserComponent,
FileSystemComponent,
PythonComponent,
ShellComponent,
)
from .base import ComputerBooter
def _maybe_model_dump(value: Any) -> dict[str, Any]:
if isinstance(value, dict):
return value
if hasattr(value, "model_dump"):
dumped = value.model_dump()
if isinstance(dumped, dict):
return dumped
return {}
class NeoPythonComponent(PythonComponent):
def __init__(self, sandbox: Any) -> None:
self._sandbox = sandbox
async def exec(
self,
code: str,
kernel_id: str | None = None,
timeout: int = 30,
silent: bool = False,
) -> dict[str, Any]:
_ = kernel_id # Bay runtime does not expose kernel_id in current SDK.
result = await self._sandbox.python.exec(code, timeout=timeout)
payload = _maybe_model_dump(result)
output_text = payload.get("output", "") or ""
error_text = payload.get("error", "") or ""
data = payload.get("data") if isinstance(payload.get("data"), dict) else {}
rich_output = data.get("output") if isinstance(data.get("output"), dict) else {}
if not isinstance(rich_output.get("images"), list):
rich_output["images"] = []
if "text" not in rich_output:
rich_output["text"] = output_text
if silent:
rich_output["text"] = ""
return {
"success": bool(payload.get("success", error_text == "")),
"data": {
"output": rich_output,
"error": error_text,
},
"execution_id": payload.get("execution_id"),
"execution_time_ms": payload.get("execution_time_ms"),
"code": payload.get("code"),
"output": output_text,
"error": error_text,
}
class NeoShellComponent(ShellComponent):
def __init__(self, sandbox: Any) -> None:
self._sandbox = sandbox
async def exec(
self,
command: str,
cwd: str | None = None,
env: dict[str, str] | None = None,
timeout: int | None = 30,
shell: bool = True,
background: bool = False,
) -> dict[str, Any]:
if not shell:
return {
"stdout": "",
"stderr": "error: only shell mode is supported in shipyard_neo booter.",
"exit_code": 2,
"success": False,
}
run_command = command
if env:
env_prefix = " ".join(
f"{k}={shlex.quote(str(v))}" for k, v in sorted(env.items())
)
run_command = f"{env_prefix} {run_command}"
if background:
run_command = f"nohup sh -lc {shlex.quote(run_command)} >/tmp/astrbot_bg.log 2>&1 & echo $!"
result = await self._sandbox.shell.exec(
run_command,
timeout=timeout or 30,
cwd=cwd,
)
payload = _maybe_model_dump(result)
stdout = payload.get("output", "") or ""
stderr = payload.get("error", "") or ""
exit_code = payload.get("exit_code")
if background:
pid: int | None = None
try:
pid = int(stdout.strip().splitlines()[-1])
except Exception:
pid = None
return {
"pid": pid,
"stdout": stdout,
"stderr": stderr,
"exit_code": exit_code,
"success": bool(payload.get("success", not stderr)),
"execution_id": payload.get("execution_id"),
"execution_time_ms": payload.get("execution_time_ms"),
"command": payload.get("command"),
}
return {
"stdout": stdout,
"stderr": stderr,
"exit_code": exit_code,
"success": bool(payload.get("success", not stderr)),
"execution_id": payload.get("execution_id"),
"execution_time_ms": payload.get("execution_time_ms"),
"command": payload.get("command"),
}
class NeoFileSystemComponent(FileSystemComponent):
def __init__(self, sandbox: Any) -> None:
self._sandbox = sandbox
async def create_file(
self,
path: str,
content: str = "",
mode: int = 0o644,
) -> dict[str, Any]:
_ = mode
await self._sandbox.filesystem.write_file(path, content)
return {"success": True, "path": path}
async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]:
_ = encoding
content = await self._sandbox.filesystem.read_file(path)
return {"success": True, "path": path, "content": content}
async def write_file(
self,
path: str,
content: str,
mode: str = "w",
encoding: str = "utf-8",
) -> dict[str, Any]:
_ = mode
_ = encoding
await self._sandbox.filesystem.write_file(path, content)
return {"success": True, "path": path}
async def delete_file(self, path: str) -> dict[str, Any]:
await self._sandbox.filesystem.delete(path)
return {"success": True, "path": path}
async def list_dir(
self,
path: str = ".",
show_hidden: bool = False,
) -> dict[str, Any]:
entries = await self._sandbox.filesystem.list_dir(path)
data = []
for entry in entries:
item = _maybe_model_dump(entry)
if not show_hidden and str(item.get("name", "")).startswith("."):
continue
data.append(item)
return {"success": True, "path": path, "entries": data}
class NeoBrowserComponent(BrowserComponent):
def __init__(self, sandbox: Any) -> None:
self._sandbox = sandbox
async def exec(
self,
cmd: str,
timeout: int = 30,
description: str | None = None,
tags: str | None = None,
learn: bool = False,
include_trace: bool = False,
) -> dict[str, Any]:
result = await self._sandbox.browser.exec(
cmd,
timeout=timeout,
description=description,
tags=tags,
learn=learn,
include_trace=include_trace,
)
return _maybe_model_dump(result)
async def exec_batch(
self,
commands: list[str],
timeout: int = 60,
stop_on_error: bool = True,
description: str | None = None,
tags: str | None = None,
learn: bool = False,
include_trace: bool = False,
) -> dict[str, Any]:
result = await self._sandbox.browser.exec_batch(
commands,
timeout=timeout,
stop_on_error=stop_on_error,
description=description,
tags=tags,
learn=learn,
include_trace=include_trace,
)
return _maybe_model_dump(result)
async def run_skill(
self,
skill_key: str,
timeout: int = 60,
stop_on_error: bool = True,
include_trace: bool = False,
description: str | None = None,
tags: str | None = None,
) -> dict[str, Any]:
result = await self._sandbox.browser.run_skill(
skill_key=skill_key,
timeout=timeout,
stop_on_error=stop_on_error,
include_trace=include_trace,
description=description,
tags=tags,
)
return _maybe_model_dump(result)
class ShipyardNeoBooter(ComputerBooter):
"""Booter backed by Shipyard Neo (Bay).
If *endpoint_url* is empty or set to ``"__auto__"``, Bay will be
started automatically as a Docker container (like Boxlite does for
Ship containers).
"""
AUTO_SENTINEL = "__auto__"
DEFAULT_PROFILE = "python-default"
def __init__(
self,
endpoint_url: str,
access_token: str,
profile: str = DEFAULT_PROFILE,
ttl: int = 3600,
) -> None:
self._endpoint_url = endpoint_url
self._access_token = access_token
self._profile = profile
self._ttl = ttl
self._client: Any = None
self._sandbox: Any = None
self._bay_manager: Any = None # BayContainerManager when auto-started
self._fs: FileSystemComponent | None = None
self._python: PythonComponent | None = None
self._shell: ShellComponent | None = None
self._browser: BrowserComponent | None = None
@property
def bay_client(self) -> Any:
return self._client
@property
def sandbox(self) -> Any:
return self._sandbox
@property
def capabilities(self) -> tuple[str, ...] | None:
"""Sandbox capabilities from the Bay profile.
Returns an immutable tuple after :meth:`boot`; ``None`` before boot.
"""
if self._sandbox is None:
return None
caps = getattr(self._sandbox, "capabilities", None)
return tuple(caps) if caps is not None else None
@property
def is_auto_mode(self) -> bool:
"""True when Bay should be auto-started."""
ep = (self._endpoint_url or "").strip()
return not ep or ep == self.AUTO_SENTINEL
async def boot(self, session_id: str) -> None:
_ = session_id
# --- Auto-start Bay if needed ---
if self.is_auto_mode:
from .bay_manager import BayContainerManager
# Clean up previous manager if re-booting
if self._bay_manager is not None:
await self._bay_manager.close_client()
logger.info("[Computer] Neo auto-start mode: launching Bay container")
self._bay_manager = BayContainerManager()
self._endpoint_url = await self._bay_manager.ensure_running()
await self._bay_manager.wait_healthy()
# Read auto-provisioned credentials
if not self._access_token:
self._access_token = await self._bay_manager.read_credentials()
logger.info("[Computer] Bay auto-started at %s", self._endpoint_url)
if not self._endpoint_url or not self._access_token:
if self._bay_manager is not None:
raise ValueError(
"Bay container started but credentials could not be read. "
"Ensure Bay generated credentials.json, or set access_token manually."
)
raise ValueError(
"Shipyard Neo sandbox configuration is incomplete. "
"Set endpoint (default http://127.0.0.1:8114) and access token, "
"or ensure Bay's credentials.json is accessible for auto-discovery."
)
from shipyard_neo import BayClient
self._client = BayClient(
endpoint_url=self._endpoint_url,
access_token=self._access_token,
)
await self._client.__aenter__()
# Resolve profile: user-specified > smart selection > default
resolved_profile = await self._resolve_profile(self._client)
self._sandbox = await self._client.create_sandbox(
profile=resolved_profile,
ttl=self._ttl,
)
self._fs = NeoFileSystemComponent(self._sandbox)
self._python = NeoPythonComponent(self._sandbox)
self._shell = NeoShellComponent(self._sandbox)
caps = self.capabilities or ()
self._browser = (
NeoBrowserComponent(self._sandbox) if "browser" in caps else None
)
logger.info(
"Got Shipyard Neo sandbox: %s (profile=%s, capabilities=%s, auto=%s)",
self._sandbox.id,
resolved_profile,
list(caps),
bool(self._bay_manager),
)
async def _resolve_profile(self, client: Any) -> str:
"""Pick the best profile for this session.
Resolution order:
1. User-specified profile (non-empty, non-default) use as-is.
2. Query ``GET /v1/profiles`` and pick the profile with the most
capabilities, preferring profiles that include ``"browser"``.
3. Fall back to :attr:`DEFAULT_PROFILE`.
Auth errors (401/403) are re-raised immediately they indicate a
misconfigured token, and silently falling back would just delay the
real failure to ``create_sandbox``.
"""
# User explicitly set a profile → honour it
if self._profile and self._profile != self.DEFAULT_PROFILE:
logger.info("[Computer] Using user-specified profile: %s", self._profile)
return self._profile
# Query Bay for available profiles
from shipyard_neo.errors import ForbiddenError, UnauthorizedError
try:
profile_list = await client.list_profiles()
profiles = profile_list.items
except (UnauthorizedError, ForbiddenError):
raise # auth errors must not be silenced
except Exception as exc:
logger.warning(
"[Computer] Failed to query Bay profiles, falling back to %s: %s",
self.DEFAULT_PROFILE,
exc,
)
return self.DEFAULT_PROFILE
if not profiles:
return self.DEFAULT_PROFILE
def _score(p: Any) -> tuple[int, int]:
"""(has_browser, capability_count) — higher is better."""
caps = getattr(p, "capabilities", []) or []
return (1 if "browser" in caps else 0, len(caps))
best = max(profiles, key=_score)
chosen = getattr(best, "id", self.DEFAULT_PROFILE)
if chosen != self.DEFAULT_PROFILE:
caps = getattr(best, "capabilities", [])
logger.info(
"[Computer] Auto-selected profile %s (capabilities=%s)",
chosen,
caps,
)
return chosen
async def shutdown(self) -> None:
if self._client is not None:
sandbox_id = getattr(self._sandbox, "id", "unknown")
logger.info(
"[Computer] Shutting down Shipyard Neo sandbox: id=%s", sandbox_id
)
await self._client.__aexit__(None, None, None)
self._client = None
self._sandbox = None
logger.info("[Computer] Shipyard Neo sandbox shut down: id=%s", sandbox_id)
# NOTE: We intentionally do NOT stop the Bay container here.
# It stays running for reuse by future sessions. The user can
# stop it manually or via ``BayContainerManager.stop()``.
if self._bay_manager is not None:
await self._bay_manager.close_client()
@property
def fs(self) -> FileSystemComponent:
if self._fs is None:
raise RuntimeError("ShipyardNeoBooter is not initialized.")
return self._fs
@property
def python(self) -> PythonComponent:
if self._python is None:
raise RuntimeError("ShipyardNeoBooter is not initialized.")
return self._python
@property
def shell(self) -> ShellComponent:
if self._shell is None:
raise RuntimeError("ShipyardNeoBooter is not initialized.")
return self._shell
@property
def browser(self) -> BrowserComponent:
if self._browser is None:
raise RuntimeError("ShipyardNeoBooter is not initialized.")
return self._browser
async def upload_file(self, path: str, file_name: str) -> dict:
if self._sandbox is None:
raise RuntimeError("ShipyardNeoBooter is not initialized.")
with open(path, "rb") as f:
content = f.read()
remote_path = file_name.lstrip("/")
await self._sandbox.filesystem.upload(remote_path, content)
logger.info("[Computer] File uploaded to Neo sandbox: %s", remote_path)
return {
"success": True,
"message": "File uploaded successfully",
"file_path": remote_path,
}
async def download_file(self, remote_path: str, local_path: str) -> None:
if self._sandbox is None:
raise RuntimeError("ShipyardNeoBooter is not initialized.")
content = await self._sandbox.filesystem.download(remote_path.lstrip("/"))
local_dir = os.path.dirname(local_path)
if local_dir:
os.makedirs(local_dir, exist_ok=True)
with open(local_path, "wb") as f:
f.write(cast(bytes, content))
logger.info(
"[Computer] File downloaded from Neo sandbox: %s -> %s",
remote_path,
local_path,
)
async def available(self) -> bool:
if self._sandbox is None:
return False
try:
await self._sandbox.refresh()
status = getattr(self._sandbox.status, "value", str(self._sandbox.status))
healthy = status not in {"failed", "expired"}
logger.info(
"[Computer] Neo sandbox health check: id=%s, status=%s, healthy=%s",
getattr(self._sandbox, "id", "unknown"),
status,
healthy,
)
return healthy
except Exception as e:
logger.error(f"Error checking Shipyard Neo sandbox availability: {e}")
return False
+31 -433
View File
@@ -1,11 +1,10 @@
import json
import os
import shutil
import uuid
from pathlib import Path
from astrbot.api import logger
from astrbot.core.skills.skill_manager import SANDBOX_SKILLS_ROOT, SkillManager
from astrbot.core.skills.skill_manager import SANDBOX_SKILLS_ROOT
from astrbot.core.star.context import Context
from astrbot.core.utils.astrbot_path import (
get_astrbot_skills_path,
@@ -17,401 +16,45 @@ from .booters.local import LocalBooter
session_booter: dict[str, ComputerBooter] = {}
local_booter: ComputerBooter | None = None
_MANAGED_SKILLS_FILE = ".astrbot_managed_skills.json"
def _list_local_skill_dirs(skills_root: Path) -> list[Path]:
skills: list[Path] = []
for entry in sorted(skills_root.iterdir()):
if not entry.is_dir():
continue
skill_md = entry / "SKILL.md"
if skill_md.exists():
skills.append(entry)
return skills
def _discover_bay_credentials(endpoint: str) -> str:
"""Try to auto-discover Bay API key from credentials.json.
Search order:
1. BAY_DATA_DIR env var
2. Mono-repo relative path: ../pkgs/bay/ (dev layout)
3. Current working directory
Returns:
API key string, or empty string if not found.
"""
candidates: list[Path] = []
# 1. BAY_DATA_DIR env var
bay_data_dir = os.environ.get("BAY_DATA_DIR")
if bay_data_dir:
candidates.append(Path(bay_data_dir) / "credentials.json")
# 2. Mono-repo layout: AstrBot/../pkgs/bay/credentials.json
astrbot_root = Path(__file__).resolve().parents[3] # astrbot/core/computer/ → root
candidates.append(astrbot_root.parent / "pkgs" / "bay" / "credentials.json")
# 3. Current working directory
candidates.append(Path.cwd() / "credentials.json")
for cred_path in candidates:
if not cred_path.is_file():
continue
try:
data = json.loads(cred_path.read_text())
api_key = data.get("api_key", "")
if api_key:
# Optionally verify endpoint matches
cred_endpoint = data.get("endpoint", "")
if (
cred_endpoint
and endpoint
and cred_endpoint.rstrip("/") != endpoint.rstrip("/")
):
logger.warning(
"[Computer] credentials.json endpoint mismatch: "
"file=%s, configured=%s — using key anyway",
cred_endpoint,
endpoint,
)
masked_key = f"{api_key[:4]}..." if len(api_key) >= 6 else "redacted"
logger.info(
"[Computer] Auto-discovered Bay API key from %s (prefix=%s)",
cred_path,
masked_key,
)
return api_key
except (json.JSONDecodeError, OSError) as exc:
logger.debug("[Computer] Failed to read %s: %s", cred_path, exc)
logger.debug("[Computer] No Bay credentials.json found in search paths")
return ""
def _build_python_exec_command(script: str) -> str:
return (
"if command -v python3 >/dev/null 2>&1; then PYBIN=python3; "
"elif command -v python >/dev/null 2>&1; then PYBIN=python; "
"else echo 'python not found in sandbox' >&2; exit 127; fi; "
"$PYBIN - <<'PY'\n"
f"{script}\n"
"PY"
)
def _build_apply_sync_command() -> str:
"""Build shell command for sync stage only.
This stage mutates sandbox files (managed skill replacement) but does not scan
metadata. Keeping it separate allows callers to preserve old behavior while
reusing the apply step independently.
"""
script = f"""
import json
import shutil
import zipfile
from pathlib import Path
root = Path({SANDBOX_SKILLS_ROOT!r})
zip_path = root / "skills.zip"
tmp_extract = Path(f"{{root}}_tmp_extract")
managed_file = root / {_MANAGED_SKILLS_FILE!r}
def remove_tree(path: Path) -> None:
if not path.exists():
return
if path.is_dir():
shutil.rmtree(path, ignore_errors=True)
else:
path.unlink(missing_ok=True)
def load_managed_skills() -> list[str]:
if not managed_file.exists():
return []
try:
payload = json.loads(managed_file.read_text(encoding="utf-8"))
except Exception:
return []
if not isinstance(payload, dict):
return []
items = payload.get("managed_skills", [])
if not isinstance(items, list):
return []
result: list[str] = []
for item in items:
if isinstance(item, str) and item.strip():
result.append(item.strip())
return result
root.mkdir(parents=True, exist_ok=True)
for managed_name in load_managed_skills():
remove_tree(root / managed_name)
current_managed: list[str] = []
if zip_path.exists():
remove_tree(tmp_extract)
tmp_extract.mkdir(parents=True, exist_ok=True)
with zipfile.ZipFile(zip_path) as zf:
zf.extractall(tmp_extract)
for entry in sorted(tmp_extract.iterdir()):
if not entry.is_dir():
continue
target = root / entry.name
remove_tree(target)
shutil.copytree(entry, target)
current_managed.append(entry.name)
remove_tree(tmp_extract)
remove_tree(zip_path)
managed_file.write_text(
json.dumps({{"managed_skills": current_managed}}, ensure_ascii=False, indent=2),
encoding="utf-8",
)
print(json.dumps({{"managed_skills": current_managed}}, ensure_ascii=False))
""".strip()
return _build_python_exec_command(script)
def _build_scan_command() -> str:
"""Build shell command for scan stage only.
This stage is read-oriented: it scans SKILL.md metadata and returns the
historical payload shape consumed by cache update logic.
The scan resolves the absolute path of the skills root at runtime so
that the LLM can reliably ``cat`` skill files regardless of cwd.
Only the ``description`` field is extracted from frontmatter.
"""
script = f"""
import json
from pathlib import Path
root = Path({SANDBOX_SKILLS_ROOT!r})
managed_file = root / {_MANAGED_SKILLS_FILE!r}
# Resolve absolute path at runtime so prompts always have a reliable path
root_abs = str(root.resolve())
# NOTE: This parser mirrors skill_manager._parse_frontmatter_description.
# Keep the two implementations in sync when changing parsing logic.
def parse_description(text: str) -> str:
if not text.startswith("---"):
return ""
lines = text.splitlines()
if not lines or lines[0].strip() != "---":
return ""
end_idx = None
for i in range(1, len(lines)):
if lines[i].strip() == "---":
end_idx = i
break
if end_idx is None:
return ""
for line in lines[1:end_idx]:
if ":" not in line:
continue
key, value = line.split(":", 1)
if key.strip().lower() == "description":
return value.strip().strip('"').strip("'")
return ""
def load_managed_skills() -> list[str]:
if not managed_file.exists():
return []
try:
payload = json.loads(managed_file.read_text(encoding="utf-8"))
except Exception:
return []
if not isinstance(payload, dict):
return []
items = payload.get("managed_skills", [])
if not isinstance(items, list):
return []
result: list[str] = []
for item in items:
if isinstance(item, str) and item.strip():
result.append(item.strip())
return result
def collect_skills() -> list[dict[str, str]]:
skills: list[dict[str, str]] = []
if not root.exists():
return skills
for skill_dir in sorted(root.iterdir()):
if not skill_dir.is_dir():
continue
skill_md = skill_dir / "SKILL.md"
if not skill_md.is_file():
continue
description = ""
try:
text = skill_md.read_text(encoding="utf-8")
description = parse_description(text)
except Exception:
description = ""
skills.append(
{{
"name": skill_dir.name,
"description": description,
"path": f"{{root_abs}}/{{skill_dir.name}}/SKILL.md",
}}
)
return skills
print(
json.dumps(
{{
"managed_skills": load_managed_skills(),
"skills": collect_skills(),
}},
ensure_ascii=False,
)
)
""".strip()
return _build_python_exec_command(script)
def _build_sync_and_scan_command() -> str:
"""Legacy combined command kept for backward compatibility.
New code paths should prefer apply + scan split helpers.
"""
return f"{_build_apply_sync_command()}\n{_build_scan_command()}"
def _shell_exec_succeeded(result: dict) -> bool:
if "success" in result:
return bool(result.get("success"))
exit_code = result.get("exit_code")
return exit_code in (0, None)
def _format_exec_error_detail(result: dict) -> str:
"""Format shell execution details for better observability.
Keep the message compact while still surfacing exit code and stderr/stdout.
"""
exit_code = result.get("exit_code")
stderr = str(result.get("stderr", "") or "").strip()
stdout = str(result.get("stdout", "") or "").strip()
stderr_text = stderr[:500]
stdout_text = stdout[:300]
return f"exit_code={exit_code}, stderr={stderr_text!r}, stdout_tail={stdout_text!r}"
def _decode_sync_payload(stdout: str) -> dict | None:
text = stdout.strip()
if not text:
return None
candidates = [text]
candidates.extend([line.strip() for line in text.splitlines() if line.strip()])
for candidate in reversed(candidates):
try:
payload = json.loads(candidate)
except Exception:
continue
if isinstance(payload, dict):
return payload
return None
def _update_sandbox_skills_cache(payload: dict | None) -> None:
if not isinstance(payload, dict):
return
skills = payload.get("skills", [])
if not isinstance(skills, list):
return
SkillManager().set_sandbox_skills_cache(skills)
async def _apply_skills_to_sandbox(booter: ComputerBooter) -> None:
"""Apply local skill bundle to sandbox filesystem only.
This function is intentionally limited to file mutation. Metadata scanning is
executed in a separate phase to keep failure domains clear.
"""
logger.info("[Computer] Skill sync phase=apply start")
apply_result = await booter.shell.exec(_build_apply_sync_command())
if not _shell_exec_succeeded(apply_result):
detail = _format_exec_error_detail(apply_result)
logger.error("[Computer] Skill sync phase=apply failed: %s", detail)
raise RuntimeError(f"Failed to apply sandbox skill sync strategy: {detail}")
logger.info("[Computer] Skill sync phase=apply done")
async def _scan_sandbox_skills(booter: ComputerBooter) -> dict | None:
"""Scan sandbox skills and return normalized payload for cache update."""
logger.info("[Computer] Skill sync phase=scan start")
scan_result = await booter.shell.exec(_build_scan_command())
if not _shell_exec_succeeded(scan_result):
detail = _format_exec_error_detail(scan_result)
logger.error("[Computer] Skill sync phase=scan failed: %s", detail)
raise RuntimeError(f"Failed to scan sandbox skills after sync: {detail}")
payload = _decode_sync_payload(str(scan_result.get("stdout", "") or ""))
if payload is None:
logger.warning("[Computer] Skill sync phase=scan returned empty payload")
else:
logger.info("[Computer] Skill sync phase=scan done")
return payload
async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None:
"""Sync local skills to sandbox and refresh cache.
Backward-compatible orchestrator: keep historical behavior while internally
splitting into `apply` and `scan` phases.
"""
skills_root = Path(get_astrbot_skills_path())
if not skills_root.is_dir():
skills_root = get_astrbot_skills_path()
if not os.path.isdir(skills_root):
return
if not any(Path(skills_root).iterdir()):
return
local_skill_dirs = _list_local_skill_dirs(skills_root)
temp_dir = Path(get_astrbot_temp_path())
temp_dir.mkdir(parents=True, exist_ok=True)
zip_base = temp_dir / "skills_bundle"
zip_path = zip_base.with_suffix(".zip")
temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
zip_base = os.path.join(temp_dir, "skills_bundle")
zip_path = f"{zip_base}.zip"
try:
if local_skill_dirs:
if zip_path.exists():
zip_path.unlink()
shutil.make_archive(str(zip_base), "zip", str(skills_root))
remote_zip = Path(SANDBOX_SKILLS_ROOT) / "skills.zip"
logger.info("Uploading skills bundle to sandbox...")
await booter.shell.exec(f"mkdir -p {SANDBOX_SKILLS_ROOT}")
upload_result = await booter.upload_file(str(zip_path), str(remote_zip))
if not upload_result.get("success", False):
raise RuntimeError("Failed to upload skills bundle to sandbox.")
else:
logger.info(
"No local skills found. Keeping sandbox built-ins and refreshing metadata."
)
await booter.shell.exec(f"rm -f {SANDBOX_SKILLS_ROOT}/skills.zip")
# Keep backward-compatible behavior while splitting lifecycle into two
# observable phases: apply (filesystem mutation) + scan (metadata read).
await _apply_skills_to_sandbox(booter)
payload = await _scan_sandbox_skills(booter)
_update_sandbox_skills_cache(payload)
managed = payload.get("managed_skills", []) if isinstance(payload, dict) else []
logger.info(
"[Computer] Sandbox skill sync complete: managed=%d",
len(managed),
if os.path.exists(zip_path):
os.remove(zip_path)
shutil.make_archive(zip_base, "zip", skills_root)
remote_zip = Path(SANDBOX_SKILLS_ROOT) / "skills.zip"
logger.info("Uploading skills bundle to sandbox...")
await booter.shell.exec(f"mkdir -p {SANDBOX_SKILLS_ROOT}")
upload_result = await booter.upload_file(zip_path, str(remote_zip))
if not upload_result.get("success", False):
raise RuntimeError("Failed to upload skills bundle to sandbox.")
# Use -n flag to never overwrite existing files, fallback to Python if unzip unavailable
await booter.shell.exec(
f"unzip -n {remote_zip} -d {SANDBOX_SKILLS_ROOT} || "
f"python3 -c \"import zipfile, os, pathlib; z=zipfile.ZipFile('{remote_zip}'); "
f"[z.extract(m, '{SANDBOX_SKILLS_ROOT}') for m in z.namelist() "
f"if not os.path.exists(os.path.join('{SANDBOX_SKILLS_ROOT}', m))]\" || "
f"python -c \"import zipfile, os, pathlib; z=zipfile.ZipFile('{remote_zip}'); "
f"[z.extract(m, '{SANDBOX_SKILLS_ROOT}') for m in z.namelist() "
f"if not os.path.exists(os.path.join('{SANDBOX_SKILLS_ROOT}', m))]\"; "
f"rm -f {remote_zip}"
)
finally:
if zip_path.exists():
if os.path.exists(zip_path):
try:
zip_path.unlink()
os.remove(zip_path)
except Exception:
logger.warning(f"Failed to remove temp skills zip: {zip_path}")
@@ -423,7 +66,7 @@ async def get_booter(
config = context.get_config(umo=session_id)
sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {})
booter_type = sandbox_cfg.get("booter", "shipyard_neo")
booter_type = sandbox_cfg.get("booter", "shipyard")
if session_id in session_booter:
booter = session_booter[session_id]
@@ -432,9 +75,6 @@ async def get_booter(
session_booter.pop(session_id, None)
if session_id not in session_booter:
uuid_str = uuid.uuid5(uuid.NAMESPACE_DNS, session_id).hex
logger.info(
f"[Computer] Initializing booter: type={booter_type}, session={session_id}"
)
if booter_type == "shipyard":
from .booters.shipyard import ShipyardBooter
@@ -446,27 +86,6 @@ async def get_booter(
client = ShipyardBooter(
endpoint_url=ep, access_token=token, ttl=ttl, session_num=max_sessions
)
elif booter_type == "shipyard_neo":
from .booters.shipyard_neo import ShipyardNeoBooter
ep = sandbox_cfg.get("shipyard_neo_endpoint", "")
token = sandbox_cfg.get("shipyard_neo_access_token", "")
ttl = sandbox_cfg.get("shipyard_neo_ttl", 3600)
profile = sandbox_cfg.get("shipyard_neo_profile", "python-default")
# Auto-discover token from Bay's credentials.json if not configured
if not token:
token = _discover_bay_credentials(ep)
logger.info(
f"[Computer] Shipyard Neo config: endpoint={ep}, profile={profile}, ttl={ttl}"
)
client = ShipyardNeoBooter(
endpoint_url=ep,
access_token=token,
profile=profile,
ttl=ttl,
)
elif booter_type == "boxlite":
from .booters.boxlite import BoxliteBooter
@@ -476,9 +95,6 @@ async def get_booter(
try:
await client.boot(uuid_str)
logger.info(
f"[Computer] Sandbox booted successfully: type={booter_type}, session={session_id}"
)
await _sync_skills_to_sandbox(client)
except Exception as e:
logger.error(f"Error booting sandbox for session {session_id}: {e}")
@@ -488,24 +104,6 @@ async def get_booter(
return session_booter[session_id]
async def sync_skills_to_active_sandboxes() -> None:
"""Best-effort skills synchronization for all active sandbox sessions."""
logger.info(
"[Computer] Syncing skills to %d active sandbox(es)", len(session_booter)
)
for session_id, booter in list(session_booter.items()):
try:
if not await booter.available():
continue
await _sync_skills_to_sandbox(booter)
except Exception as e:
logger.warning(
"Failed to sync skills to sandbox for session %s: %s",
session_id,
e,
)
def get_local_booter() -> ComputerBooter:
global local_booter
if local_booter is None:
+1 -7
View File
@@ -1,11 +1,5 @@
from .browser import BrowserComponent
from .filesystem import FileSystemComponent
from .python import PythonComponent
from .shell import ShellComponent
__all__ = [
"PythonComponent",
"ShellComponent",
"FileSystemComponent",
"BrowserComponent",
]
__all__ = ["PythonComponent", "ShellComponent", "FileSystemComponent"]
-46
View File
@@ -1,46 +0,0 @@
"""
Browser automation component
"""
from typing import Any, Protocol
class BrowserComponent(Protocol):
"""Browser operations component"""
async def exec(
self,
cmd: str,
timeout: int = 30,
description: str | None = None,
tags: str | None = None,
learn: bool = False,
include_trace: bool = False,
) -> dict[str, Any]:
"""Execute a browser automation command"""
...
async def exec_batch(
self,
commands: list[str],
timeout: int = 60,
stop_on_error: bool = True,
description: str | None = None,
tags: str | None = None,
learn: bool = False,
include_trace: bool = False,
) -> dict[str, Any]:
"""Execute a browser automation command batch"""
...
async def run_skill(
self,
skill_key: str,
timeout: int = 60,
stop_on_error: bool = True,
include_trace: bool = False,
description: str | None = None,
tags: str | None = None,
) -> dict[str, Any]:
"""Run a browser skill by skill key"""
...
-28
View File
@@ -1,36 +1,8 @@
from .browser import BrowserBatchExecTool, BrowserExecTool, RunBrowserSkillTool
from .fs import FileDownloadTool, FileUploadTool
from .neo_skills import (
AnnotateExecutionTool,
CreateSkillCandidateTool,
CreateSkillPayloadTool,
EvaluateSkillCandidateTool,
GetExecutionHistoryTool,
GetSkillPayloadTool,
ListSkillCandidatesTool,
ListSkillReleasesTool,
PromoteSkillCandidateTool,
RollbackSkillReleaseTool,
SyncSkillReleaseTool,
)
from .python import LocalPythonTool, PythonTool
from .shell import ExecuteShellTool
__all__ = [
"BrowserExecTool",
"BrowserBatchExecTool",
"RunBrowserSkillTool",
"GetExecutionHistoryTool",
"AnnotateExecutionTool",
"CreateSkillPayloadTool",
"GetSkillPayloadTool",
"CreateSkillCandidateTool",
"ListSkillCandidatesTool",
"EvaluateSkillCandidateTool",
"PromoteSkillCandidateTool",
"ListSkillReleasesTool",
"RollbackSkillReleaseTool",
"SyncSkillReleaseTool",
"FileUploadTool",
"PythonTool",
"LocalPythonTool",
-204
View File
@@ -1,204 +0,0 @@
import json
from dataclasses import dataclass, field
from typing import Any
from astrbot.api import FunctionTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from ..computer_client import get_booter
def _to_json(data: Any) -> str:
return json.dumps(data, ensure_ascii=False, default=str)
def _ensure_admin(context: ContextWrapper[AstrAgentContext]) -> str | None:
if context.context.event.role != "admin":
return (
"error: Permission denied. Browser and skill lifecycle tools are only allowed "
"for admin users."
)
return None
async def _get_browser_component(context: ContextWrapper[AstrAgentContext]) -> Any:
booter = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
browser = getattr(booter, "browser", None)
if browser is None:
raise RuntimeError(
"Current sandbox booter does not support browser capability. "
"Please switch to shipyard_neo."
)
return browser
@dataclass
class BrowserExecTool(FunctionTool):
name: str = "astrbot_execute_browser"
description: str = "Execute one browser automation command in the sandbox."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"cmd": {"type": "string", "description": "Browser command to execute."},
"timeout": {"type": "integer", "default": 30},
"description": {
"type": "string",
"description": "Optional execution description.",
},
"tags": {"type": "string", "description": "Optional tags."},
"learn": {
"type": "boolean",
"description": "Whether to mark execution as learn evidence.",
"default": False,
},
"include_trace": {
"type": "boolean",
"description": "Whether to include trace_ref in response.",
"default": False,
},
},
"required": ["cmd"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
cmd: str,
timeout: int = 30,
description: str | None = None,
tags: str | None = None,
learn: bool = False,
include_trace: bool = False,
) -> ToolExecResult:
if err := _ensure_admin(context):
return err
try:
browser = await _get_browser_component(context)
result = await browser.exec(
cmd=cmd,
timeout=timeout,
description=description,
tags=tags,
learn=learn,
include_trace=include_trace,
)
return _to_json(result)
except Exception as e:
return f"Error executing browser command: {str(e)}"
@dataclass
class BrowserBatchExecTool(FunctionTool):
name: str = "astrbot_execute_browser_batch"
description: str = "Execute a browser command batch in the sandbox."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"commands": {
"type": "array",
"items": {"type": "string"},
"description": "Ordered browser commands.",
},
"timeout": {"type": "integer", "default": 60},
"stop_on_error": {"type": "boolean", "default": True},
"description": {
"type": "string",
"description": "Optional execution description.",
},
"tags": {"type": "string", "description": "Optional tags."},
"learn": {
"type": "boolean",
"description": "Whether to mark execution as learn evidence.",
"default": False,
},
"include_trace": {
"type": "boolean",
"description": "Whether to include trace_ref in response.",
"default": False,
},
},
"required": ["commands"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
commands: list[str],
timeout: int = 60,
stop_on_error: bool = True,
description: str | None = None,
tags: str | None = None,
learn: bool = False,
include_trace: bool = False,
) -> ToolExecResult:
if err := _ensure_admin(context):
return err
try:
browser = await _get_browser_component(context)
result = await browser.exec_batch(
commands=commands,
timeout=timeout,
stop_on_error=stop_on_error,
description=description,
tags=tags,
learn=learn,
include_trace=include_trace,
)
return _to_json(result)
except Exception as e:
return f"Error executing browser batch command: {str(e)}"
@dataclass
class RunBrowserSkillTool(FunctionTool):
name: str = "astrbot_run_browser_skill"
description: str = "Run a released browser skill in the sandbox by skill_key."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"skill_key": {"type": "string"},
"timeout": {"type": "integer", "default": 60},
"stop_on_error": {"type": "boolean", "default": True},
"include_trace": {"type": "boolean", "default": False},
"description": {"type": "string"},
"tags": {"type": "string"},
},
"required": ["skill_key"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
skill_key: str,
timeout: int = 60,
stop_on_error: bool = True,
include_trace: bool = False,
description: str | None = None,
tags: str | None = None,
) -> ToolExecResult:
if err := _ensure_admin(context):
return err
try:
browser = await _get_browser_component(context)
result = await browser.run_skill(
skill_key=skill_key,
timeout=timeout,
stop_on_error=stop_on_error,
include_trace=include_trace,
description=description,
tags=tags,
)
return _to_json(result)
except Exception as e:
return f"Error running browser skill: {str(e)}"
-542
View File
@@ -1,542 +0,0 @@
import json
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any
from astrbot.api import FunctionTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager
from ..computer_client import get_booter
def _to_jsonable(model_like: Any) -> Any:
if isinstance(model_like, dict):
return model_like
if isinstance(model_like, list):
return [_to_jsonable(i) for i in model_like]
if hasattr(model_like, "model_dump"):
return _to_jsonable(model_like.model_dump())
return model_like
def _to_json_text(data: Any) -> str:
return json.dumps(_to_jsonable(data), ensure_ascii=False, default=str)
def _ensure_admin(context: ContextWrapper[AstrAgentContext]) -> str | None:
if context.context.event.role != "admin":
return "error: Permission denied. Skill lifecycle tools are only allowed for admin users."
return None
async def _get_neo_context(
context: ContextWrapper[AstrAgentContext],
) -> tuple[Any, Any]:
booter = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
client = getattr(booter, "bay_client", None)
sandbox = getattr(booter, "sandbox", None)
if client is None or sandbox is None:
raise RuntimeError(
"Current sandbox booter does not support Neo skill lifecycle APIs. "
"Please switch to shipyard_neo."
)
return client, sandbox
@dataclass
class NeoSkillToolBase(FunctionTool):
error_prefix: str = "Error"
async def _run(
self,
context: ContextWrapper[AstrAgentContext],
neo_call: Callable[[Any, Any], Awaitable[Any]],
error_action: str,
) -> ToolExecResult:
if err := _ensure_admin(context):
return err
try:
client, sandbox = await _get_neo_context(context)
result = await neo_call(client, sandbox)
return _to_json_text(result)
except Exception as e:
return f"{self.error_prefix} {error_action}: {str(e)}"
@dataclass
class GetExecutionHistoryTool(NeoSkillToolBase):
name: str = "astrbot_get_execution_history"
description: str = "Get execution history from current sandbox."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"exec_type": {"type": "string"},
"success_only": {"type": "boolean", "default": False},
"limit": {"type": "integer", "default": 100},
"offset": {"type": "integer", "default": 0},
"tags": {"type": "string"},
"has_notes": {"type": "boolean", "default": False},
"has_description": {"type": "boolean", "default": False},
},
"required": [],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
exec_type: str | None = None,
success_only: bool = False,
limit: int = 100,
offset: int = 0,
tags: str | None = None,
has_notes: bool = False,
has_description: bool = False,
) -> ToolExecResult:
return await self._run(
context,
lambda _client, sandbox: sandbox.get_execution_history(
exec_type=exec_type,
success_only=success_only,
limit=limit,
offset=offset,
tags=tags,
has_notes=has_notes,
has_description=has_description,
),
error_action="getting execution history",
)
@dataclass
class AnnotateExecutionTool(NeoSkillToolBase):
name: str = "astrbot_annotate_execution"
description: str = "Annotate one execution history record."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"execution_id": {"type": "string"},
"description": {"type": "string"},
"tags": {"type": "string"},
"notes": {"type": "string"},
},
"required": ["execution_id"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
execution_id: str,
description: str | None = None,
tags: str | None = None,
notes: str | None = None,
) -> ToolExecResult:
return await self._run(
context,
lambda _client, sandbox: sandbox.annotate_execution(
execution_id=execution_id,
description=description,
tags=tags,
notes=notes,
),
error_action="annotating execution",
)
@dataclass
class CreateSkillPayloadTool(NeoSkillToolBase):
name: str = "astrbot_create_skill_payload"
description: str = (
"Step 1/3 for Neo skill authoring: create immutable payload content and return payload_ref. "
"Use this to store skill_markdown and structured metadata; do NOT write local skill folders directly."
)
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"payload": {
"anyOf": [{"type": "object"}, {"type": "array"}],
"description": (
"Skill payload JSON. Typical schema: {skill_markdown, inputs, outputs, meta}. "
"This only stores content and returns payload_ref; it does not create a candidate or release."
),
},
"kind": {
"type": "string",
"description": "Payload kind.",
"default": "astrbot_skill_v1",
},
},
"required": ["payload"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
payload: dict[str, Any] | list[Any],
kind: str = "astrbot_skill_v1",
) -> ToolExecResult:
return await self._run(
context,
lambda client, _sandbox: client.skills.create_payload(
payload=payload,
kind=kind,
),
error_action="creating skill payload",
)
@dataclass
class GetSkillPayloadTool(NeoSkillToolBase):
name: str = "astrbot_get_skill_payload"
description: str = "Get one skill payload by payload_ref."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"payload_ref": {"type": "string"},
},
"required": ["payload_ref"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
payload_ref: str,
) -> ToolExecResult:
return await self._run(
context,
lambda client, _sandbox: client.skills.get_payload(payload_ref),
error_action="getting skill payload",
)
@dataclass
class CreateSkillCandidateTool(NeoSkillToolBase):
name: str = "astrbot_create_skill_candidate"
description: str = (
"Step 2/3 for Neo skill authoring: create a candidate by binding execution evidence "
"(source_execution_ids) with skill identity (skill_key) and optional payload_ref."
)
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"skill_key": {
"type": "string",
"description": "Stable logical identifier, e.g. image-collage-9grid.",
},
"source_execution_ids": {
"type": "array",
"items": {"type": "string"},
"description": "Execution evidence IDs captured from sandbox history.",
},
"scenario_key": {
"type": "string",
"description": "Optional scenario namespace for grouping candidates.",
},
"payload_ref": {
"type": "string",
"description": "Optional payload reference created by astrbot_create_skill_payload.",
},
},
"required": ["skill_key", "source_execution_ids"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
skill_key: str,
source_execution_ids: list[str],
scenario_key: str | None = None,
payload_ref: str | None = None,
) -> ToolExecResult:
return await self._run(
context,
lambda client, _sandbox: client.skills.create_candidate(
skill_key=skill_key,
source_execution_ids=source_execution_ids,
scenario_key=scenario_key,
payload_ref=payload_ref,
),
error_action="creating skill candidate",
)
@dataclass
class ListSkillCandidatesTool(NeoSkillToolBase):
name: str = "astrbot_list_skill_candidates"
description: str = "List skill candidates."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"status": {"type": "string"},
"skill_key": {"type": "string"},
"limit": {"type": "integer", "default": 100},
"offset": {"type": "integer", "default": 0},
},
"required": [],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
status: str | None = None,
skill_key: str | None = None,
limit: int = 100,
offset: int = 0,
) -> ToolExecResult:
return await self._run(
context,
lambda client, _sandbox: client.skills.list_candidates(
status=status,
skill_key=skill_key,
limit=limit,
offset=offset,
),
error_action="listing skill candidates",
)
@dataclass
class EvaluateSkillCandidateTool(NeoSkillToolBase):
name: str = "astrbot_evaluate_skill_candidate"
description: str = "Evaluate a skill candidate."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"candidate_id": {"type": "string"},
"passed": {"type": "boolean"},
"score": {"type": "number"},
"benchmark_id": {"type": "string"},
"report": {"type": "string"},
},
"required": ["candidate_id", "passed"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
candidate_id: str,
passed: bool,
score: float | None = None,
benchmark_id: str | None = None,
report: str | None = None,
) -> ToolExecResult:
return await self._run(
context,
lambda client, _sandbox: client.skills.evaluate_candidate(
candidate_id,
passed=passed,
score=score,
benchmark_id=benchmark_id,
report=report,
),
error_action="evaluating skill candidate",
)
@dataclass
class PromoteSkillCandidateTool(NeoSkillToolBase):
name: str = "astrbot_promote_skill_candidate"
description: str = (
"Step 3/3 for Neo skill authoring: promote candidate to canary/stable release. "
"If stage=stable and sync_to_local=true, payload.skill_markdown is synced to local SKILL.md automatically."
)
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"candidate_id": {"type": "string"},
"stage": {
"type": "string",
"description": "Release stage: canary/stable",
"default": "canary",
},
"sync_to_local": {
"type": "boolean",
"description": (
"Only used with stage=stable. true means sync payload.skill_markdown to local SKILL.md; "
"false means release remains Neo-side only."
),
"default": True,
},
},
"required": ["candidate_id"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
candidate_id: str,
stage: str = "canary",
sync_to_local: bool = True,
) -> ToolExecResult:
if err := _ensure_admin(context):
return err
if stage not in {"canary", "stable"}:
return "Error promoting skill candidate: stage must be canary or stable."
try:
client, _sandbox = await _get_neo_context(context)
sync_mgr = NeoSkillSyncManager()
result = await sync_mgr.promote_with_optional_sync(
client,
candidate_id=candidate_id,
stage=stage,
sync_to_local=sync_to_local,
)
if result.get("sync_error"):
rollback_json = result.get("rollback")
if rollback_json:
return (
"Error promoting skill candidate: stable release synced failed; "
f"auto rollback succeeded. sync_error={result['sync_error']}; "
f"rollback={_to_json_text(rollback_json)}"
)
return _to_json_text(
{
"release": result.get("release"),
"sync": result.get("sync"),
"rollback": result.get("rollback"),
}
)
except Exception as e:
return f"Error promoting skill candidate: {str(e)}"
@dataclass
class ListSkillReleasesTool(NeoSkillToolBase):
name: str = "astrbot_list_skill_releases"
description: str = "List skill releases."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"skill_key": {"type": "string"},
"active_only": {"type": "boolean", "default": False},
"stage": {"type": "string"},
"limit": {"type": "integer", "default": 100},
"offset": {"type": "integer", "default": 0},
},
"required": [],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
skill_key: str | None = None,
active_only: bool = False,
stage: str | None = None,
limit: int = 100,
offset: int = 0,
) -> ToolExecResult:
return await self._run(
context,
lambda client, _sandbox: client.skills.list_releases(
skill_key=skill_key,
active_only=active_only,
stage=stage,
limit=limit,
offset=offset,
),
error_action="listing skill releases",
)
@dataclass
class RollbackSkillReleaseTool(NeoSkillToolBase):
name: str = "astrbot_rollback_skill_release"
description: str = "Rollback one skill release."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"release_id": {"type": "string"},
},
"required": ["release_id"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
release_id: str,
) -> ToolExecResult:
return await self._run(
context,
lambda client, _sandbox: client.skills.rollback_release(release_id),
error_action="rolling back skill release",
)
@dataclass
class SyncSkillReleaseTool(NeoSkillToolBase):
name: str = "astrbot_sync_skill_release"
description: str = (
"Sync stable Neo release payload to local SKILL.md and update mapping metadata."
)
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"release_id": {"type": "string"},
"skill_key": {"type": "string"},
"require_stable": {"type": "boolean", "default": True},
},
"required": [],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
release_id: str | None = None,
skill_key: str | None = None,
require_stable: bool = True,
) -> ToolExecResult:
return await self._run(
context,
lambda client, _sandbox: _sync_release_to_dict(
client,
release_id=release_id,
skill_key=skill_key,
require_stable=require_stable,
),
error_action="syncing skill release",
)
async def _sync_release_to_dict(
client: Any,
*,
release_id: str | None,
skill_key: str | None,
require_stable: bool,
) -> dict[str, str]:
sync_mgr = NeoSkillSyncManager()
result = await sync_mgr.sync_release(
client,
release_id=release_id,
skill_key=skill_key,
require_stable=require_stable,
)
return sync_mgr.sync_result_to_dict(result)
+2 -8
View File
@@ -1,4 +1,3 @@
import platform
from dataclasses import dataclass, field
import mcp
@@ -11,8 +10,6 @@ from astrbot.core.computer.computer_client import get_booter, get_local_booter
from astrbot.core.computer.tools.permissions import check_admin_permission
from astrbot.core.message.message_event_result import MessageChain
_OS_NAME = platform.system()
param_schema = {
"type": "object",
"properties": {
@@ -64,7 +61,7 @@ async def handle_result(result: dict, event: AstrMessageEvent) -> ToolExecResult
@dataclass
class PythonTool(FunctionTool):
name: str = "astrbot_execute_ipython"
description: str = f"Run codes in an IPython shell. Current OS: {_OS_NAME}."
description: str = "Run codes in an IPython shell."
parameters: dict = field(default_factory=lambda: param_schema)
async def call(
@@ -86,10 +83,7 @@ class PythonTool(FunctionTool):
@dataclass
class LocalPythonTool(FunctionTool):
name: str = "astrbot_execute_python"
description: str = (
f"Execute codes in a Python environment. Current OS: {_OS_NAME}. "
"Use system-compatible commands."
)
description: str = "Execute codes in a Python environment."
parameters: dict = field(default_factory=lambda: param_schema)
+45 -157
View File
@@ -132,15 +132,11 @@ DEFAULT_CONFIG = {
"computer_use_runtime": "none",
"computer_use_require_admin": True,
"sandbox": {
"booter": "shipyard_neo",
"booter": "shipyard",
"shipyard_endpoint": "",
"shipyard_access_token": "",
"shipyard_ttl": 3600,
"shipyard_max_sessions": 10,
"shipyard_neo_endpoint": "",
"shipyard_neo_access_token": "",
"shipyard_neo_profile": "python-default",
"shipyard_neo_ttl": 3600,
},
},
# SubAgent orchestrator mode:
@@ -395,6 +391,7 @@ CONFIG_METADATA_2 = {
"discord_token": "",
"discord_proxy": "",
"discord_command_register": True,
"discord_guild_id_for_debug": "",
"discord_activity_name": "",
},
"Misskey": {
@@ -449,20 +446,6 @@ CONFIG_METADATA_2 = {
"satori_heartbeat_interval": 10,
"satori_reconnect_delay": 5,
},
"kook": {
"id": "kook",
"type": "kook",
"enable": False,
"kook_bot_token": "",
"kook_bot_nickname": "",
"kook_reconnect_delay": 1,
"kook_max_reconnect_delay": 60,
"kook_max_retry_delay": 60,
"kook_heartbeat_interval": 30,
"kook_heartbeat_timeout": 6,
"kook_max_heartbeat_failures": 3,
"kook_max_consecutive_failures": 5,
},
# "WebChat": {
# "id": "webchat",
# "type": "webchat",
@@ -768,8 +751,7 @@ CONFIG_METADATA_2 = {
"hint": "可选的代理地址:http://ip:port",
},
"discord_command_register": {
"description": "注册 Discord 指令",
"hint": "启用后,自动将插件指令注册为 Discord 斜杠指令",
"description": "是否自动将插件指令注册 Discord 斜杠指令",
"type": "bool",
},
"discord_activity_name": {
@@ -804,51 +786,6 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "统一 Webhook 模式下的唯一标识符,创建平台时自动生成。",
},
"kook_bot_token": {
"description": "机器人 Token",
"type": "string",
"hint": "必填项。从 KOOK 开发者平台获取的机器人 Token。",
},
"kook_bot_nickname": {
"description": "Bot Nickname",
"type": "string",
"hint": "可选项。若发送者昵称与此值一致,将忽略该消息以避免广播风暴。",
},
"kook_reconnect_delay": {
"description": "重连延迟",
"type": "int",
"hint": "重连延迟时间(秒),使用指数退避策略。",
},
"kook_max_reconnect_delay": {
"description": "最大重连延迟",
"type": "int",
"hint": "重连延迟的最大值(秒)。",
},
"kook_max_retry_delay": {
"description": "最大重试延迟",
"type": "int",
"hint": "重试的最大延迟时间(秒)。",
},
"kook_heartbeat_interval": {
"description": "心跳间隔",
"type": "int",
"hint": "心跳检测间隔时间(秒)。",
},
"kook_heartbeat_timeout": {
"description": "心跳超时时间",
"type": "int",
"hint": "心跳检测超时时间(秒)。",
},
"kook_max_heartbeat_failures": {
"description": "最大心跳失败次数",
"type": "int",
"hint": "允许的最大心跳失败次数,超过后断开连接。",
},
"kook_max_consecutive_failures": {
"description": "最大连续失败次数",
"type": "int",
"hint": "允许的最大连续失败次数,超过后停止重试。",
},
},
},
"platform_settings": {
@@ -2934,48 +2871,12 @@ CONFIG_METADATA_3 = {
"provider_settings.sandbox.booter": {
"description": "沙箱环境驱动器",
"type": "string",
"options": ["shipyard_neo", "shipyard"],
"labels": ["Shipyard Neo", "Shipyard"],
"options": ["shipyard"],
"labels": ["Shipyard"],
"condition": {
"provider_settings.computer_use_runtime": "sandbox",
},
},
"provider_settings.sandbox.shipyard_neo_endpoint": {
"description": "Shipyard Neo API Endpoint",
"type": "string",
"hint": "Shipyard Neo(Bay) 服务的 API 地址,默认 http://127.0.0.1:8114。",
"condition": {
"provider_settings.computer_use_runtime": "sandbox",
"provider_settings.sandbox.booter": "shipyard_neo",
},
},
"provider_settings.sandbox.shipyard_neo_access_token": {
"description": "Shipyard Neo Access Token",
"type": "string",
"hint": "Bay 的 API Keysk-bay-...)。留空时自动从 credentials.json 发现。",
"condition": {
"provider_settings.computer_use_runtime": "sandbox",
"provider_settings.sandbox.booter": "shipyard_neo",
},
},
"provider_settings.sandbox.shipyard_neo_profile": {
"description": "Shipyard Neo Profile",
"type": "string",
"hint": "Shipyard Neo 沙箱 profile,如 python-default。",
"condition": {
"provider_settings.computer_use_runtime": "sandbox",
"provider_settings.sandbox.booter": "shipyard_neo",
},
},
"provider_settings.sandbox.shipyard_neo_ttl": {
"description": "Shipyard Neo Sandbox TTL",
"type": "int",
"hint": "Shipyard Neo 沙箱生存时间(秒)。",
"condition": {
"provider_settings.computer_use_runtime": "sandbox",
"provider_settings.sandbox.booter": "shipyard_neo",
},
},
"provider_settings.sandbox.shipyard_endpoint": {
"description": "Shipyard API Endpoint",
"type": "string",
@@ -3211,6 +3112,46 @@ CONFIG_METADATA_3 = {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.max_quoted_fallback_images": {
"description": "引用图片回退解析上限",
"type": "int",
"hint": "引用/转发消息回退解析图片时的最大注入数量,超出会截断。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_component_chain_depth": {
"description": "引用解析组件链深度",
"type": "int",
"hint": "解析 Reply 组件链时允许的最大递归深度。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_forward_node_depth": {
"description": "引用解析转发节点深度",
"type": "int",
"hint": "解析合并转发节点时允许的最大递归深度。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_forward_fetch": {
"description": "引用解析转发拉取上限",
"type": "int",
"hint": "递归拉取 get_forward_msg 的最大次数。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.warn_on_action_failure": {
"description": "引用解析 action 失败告警",
"type": "bool",
"hint": "开启后,get_msg/get_forward_msg 全部尝试失败时输出 warning 日志。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.max_agent_step": {
"description": "工具调用轮数上限",
"type": "int",
@@ -3254,46 +3195,6 @@ CONFIG_METADATA_3 = {
"type": "bool",
"hint": "/provider 命令列出模型时是否并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。",
},
"provider_settings.max_quoted_fallback_images": {
"description": "引用图片回退解析上限",
"type": "int",
"hint": "引用/转发消息回退解析图片时的最大注入数量,超出会截断。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_component_chain_depth": {
"description": "引用解析组件链深度",
"type": "int",
"hint": "解析 Reply 组件链时允许的最大递归深度。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_forward_node_depth": {
"description": "引用解析转发节点深度",
"type": "int",
"hint": "解析合并转发节点时允许的最大递归深度。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_forward_fetch": {
"description": "引用解析转发拉取上限",
"type": "int",
"hint": "递归拉取 get_forward_msg 的最大次数。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.warn_on_action_failure": {
"description": "引用解析 action 失败告警",
"type": "bool",
"hint": "开启后,get_msg/get_forward_msg 全部尝试失败时输出 warning 日志。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
},
"condition": {
"provider_settings.enable": True,
@@ -3505,19 +3406,6 @@ CONFIG_METADATA_3 = {
"platform_specific.telegram.pre_ack_emoji.enable": True,
},
},
"platform_specific.discord.pre_ack_emoji.enable": {
"description": "[Discord] 启用预回应表情",
"type": "bool",
},
"platform_specific.discord.pre_ack_emoji.emojis": {
"description": "表情列表(Unicode 或自定义表情名)",
"type": "list",
"items": {"type": "string"},
"hint": "填写 Unicode 表情符号,例如:👍、🤔、⏳",
"condition": {
"platform_specific.discord.pre_ack_emoji.enable": True,
},
},
},
},
},
-4
View File
@@ -175,10 +175,6 @@ class LogManager:
_trace_sink_id: int | None = None
_NOISY_LOGGER_LEVELS: dict[str, int] = {
"aiosqlite": logging.WARNING,
"filelock": logging.WARNING,
"asyncio": logging.WARNING,
"tzlocal": logging.WARNING,
"apscheduler": logging.WARNING,
}
@classmethod
@@ -27,7 +27,7 @@ class PreProcessStage(Stage):
) -> None | AsyncGenerator[None, None]:
"""在处理事件之前的预处理"""
# 平台特异配置:platform_specific.<platform>.pre_ack_emoji
supported = {"telegram", "lark", "discord"}
supported = {"telegram", "lark"}
platform = event.get_platform_name()
cfg = (
self.config.get("platform_specific", {})
-4
View File
@@ -180,10 +180,6 @@ class PlatformManager:
from .sources.line.line_adapter import (
LinePlatformAdapter, # noqa: F401
)
case "kook":
from .sources.kook.kook_adapter import (
KookPlatformAdapter, # noqa: F401
)
except (ImportError, ModuleNotFoundError) as e:
logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
@@ -1,371 +0,0 @@
import asyncio
import json
import re
from astrbot import logger
from astrbot.api.event import MessageChain
from astrbot.api.message_components import At, AtAll, Image, Plain
from astrbot.api.platform import (
AstrBotMessage,
MessageMember,
MessageType,
Platform,
PlatformMetadata,
register_platform_adapter,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from .kook_client import KookClient
from .kook_config import KookConfig
from .kook_event import KookEvent
@register_platform_adapter(
"kook",
"KOOK 适配器",
)
class KookPlatformAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(platform_config, event_queue)
self.kook_config = KookConfig.from_dict(platform_config)
logger.debug(f"[KOOK] 配置: {self.kook_config.pretty_jsons()}")
self.settings = platform_settings
self.client = KookClient(self.kook_config, self._on_received)
self._reconnect_task = None
self.running = False
self._main_task = None
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
inner_message = AstrBotMessage()
inner_message.session_id = session.session_id
inner_message.type = session.message_type
message_event = KookEvent(
message_str=message_chain.get_plain_text(),
message_obj=inner_message,
platform_meta=self.meta(),
session_id=session.session_id,
client=self.client,
)
await message_event.send(message_chain)
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="kook", description="KOOK 适配器", id=self.kook_config.id
)
def _should_ignore_event_by_bot_nickname(self, payload: dict) -> bool:
bot_nickname = self.kook_config.bot_nickname.strip()
if not bot_nickname:
return False
author = payload.get("extra", {}).get("author", {})
if not isinstance(author, dict):
return False
author_nickname = author.get("nickname") or author.get("username") or ""
if not isinstance(author_nickname, str):
author_nickname = str(author_nickname)
return author_nickname.strip().casefold() == bot_nickname.casefold()
async def _on_received(self, data: dict):
logger.debug(f"KOOK 收到数据: {data}")
if "d" in data and data["s"] == 0:
payload = data["d"]
event_type = payload.get("type")
# 支持type=9(文本)和type=10(卡片)
if event_type in (9, 10):
if self._should_ignore_event_by_bot_nickname(payload):
return
try:
abm = await self.convert_message(payload)
await self.handle_msg(abm)
except Exception as e:
logger.error(f"[KOOK] 消息处理异常: {e}")
async def run(self):
"""主运行循环"""
self.running = True
logger.info("[KOOK] 启动KOOK适配器")
# 启动主循环
self._main_task = asyncio.create_task(self._main_loop())
try:
await self._main_task
except asyncio.CancelledError:
logger.info("[KOOK] 适配器被取消")
except Exception as e:
logger.error(f"[KOOK] 适配器运行异常: {e}")
finally:
self.running = False
await self._cleanup()
async def _main_loop(self):
"""主循环,处理连接和重连"""
consecutive_failures = 0
max_consecutive_failures = self.kook_config.max_consecutive_failures
max_retry_delay = self.kook_config.max_retry_delay
while self.running:
try:
logger.info("[KOOK] 尝试连接KOOK服务器...")
# 尝试连接
success = await self.client.connect()
if success:
logger.info("[KOOK] 连接成功,开始监听消息")
consecutive_failures = 0 # 重置失败计数
# 等待连接结束(可能是正常关闭或异常)
while self.client.running and self.running:
try:
# 等待 client 内部触发 _stop_event,或者超时 1 秒后重试
# 使用 wait_for 配合 timeout 是为了防止极端情况下 self.running 变化没被察觉
await asyncio.wait_for(
self.client.wait_until_closed(), timeout=1.0
)
except asyncio.TimeoutError:
# 正常超时,继续下一轮 while 检查
continue
if self.running:
logger.warning("[KOOK] 连接断开,准备重连")
else:
consecutive_failures += 1
logger.error(
f"[KOOK] 连接失败,连续失败次数: {consecutive_failures}"
)
if consecutive_failures >= max_consecutive_failures:
logger.error("[KOOK] 连续失败次数过多,停止重连")
break
# 等待一段时间后重试
wait_time = min(
2**consecutive_failures, max_retry_delay
) # 指数退避
logger.info(f"[KOOK] 等待 {wait_time} 秒后重试...")
await asyncio.sleep(wait_time)
except Exception as e:
consecutive_failures += 1
logger.error(f"[KOOK] 主循环异常: {e}")
if consecutive_failures >= max_consecutive_failures:
logger.error("[KOOK] 连续异常次数过多,停止重连")
break
await asyncio.sleep(5)
async def _cleanup(self):
"""清理资源"""
logger.info("[KOOK] 开始清理资源")
if self.client:
try:
await self.client.close()
except Exception as e:
logger.error(f"[KOOK] 关闭客户端异常: {e}")
if self._main_task and not self._main_task.done():
self._main_task.cancel()
try:
await self._main_task
except asyncio.CancelledError:
pass
logger.info("[KOOK] 资源清理完成")
def _parse_kmarkdown_text_message(
self, data: dict, self_id: str
) -> tuple[list, str]:
kmarkdown = data.get("extra", {}).get("kmarkdown", {})
content = data.get("content") or ""
raw_content = kmarkdown.get("raw_content") or content
if not isinstance(content, str):
content = str(content)
if not isinstance(raw_content, str):
raw_content = str(raw_content)
mention_name_map: dict[str, str] = {}
mention_part = kmarkdown.get("mention_part", [])
if isinstance(mention_part, list):
for item in mention_part:
if not isinstance(item, dict):
continue
mention_id = item.get("id")
if mention_id is None:
continue
mention_name_map[str(mention_id)] = str(item.get("username", ""))
components = []
cursor = 0
for match in re.finditer(r"\(met\)([^()]+)\(met\)", content):
if match.start() > cursor:
plain_text = content[cursor : match.start()]
if plain_text:
components.append(Plain(text=plain_text))
mention_target = match.group(1).strip()
if mention_target == "all":
components.append(AtAll())
elif mention_target:
components.append(
At(
qq=mention_target,
name=mention_name_map.get(mention_target, ""),
)
)
cursor = match.end()
if cursor < len(content):
tail_text = content[cursor:]
if tail_text:
components.append(Plain(text=tail_text))
message_str = raw_content
if components:
for comp in components:
if isinstance(comp, Plain):
if not comp.text.strip():
continue
break
if isinstance(comp, At):
if str(comp.qq) == str(self_id):
message_str = re.sub(
r"^@[^\s]+(\s*-\s*[^\s]+)?\s*",
"",
message_str,
count=1,
).strip()
break
if not components:
if message_str:
components = [Plain(text=message_str)]
else:
components = []
return components, message_str
def _parse_card_message(self, data: dict) -> tuple[list, str]:
content = data.get("content", "[]")
if not isinstance(content, str):
content = str(content)
card_list = json.loads(content)
text_parts: list[str] = []
images: list[str] = []
for card in card_list:
if not isinstance(card, dict):
continue
for module in card.get("modules", []):
if not isinstance(module, dict):
continue
module_type = module.get("type")
if module_type == "section":
section_text = module.get("text", {}).get("content", "")
if section_text:
text_parts.append(str(section_text))
continue
if module_type != "container":
continue
for element in module.get("elements", []):
if not isinstance(element, dict):
continue
if element.get("type") != "image":
continue
image_src = element.get("src")
if not isinstance(image_src, str):
logger.warning(
f'[KOOK] 处理卡片中的图片时发生错误,图片url "{image_src}" 应该为str类型, 而不是 "{type(image_src)}" '
)
continue
if not image_src.startswith(("http://", "https://")):
logger.warning(f"[KOOK] 屏蔽非http图片url: {image_src}")
continue
images.append(image_src)
text = "".join(text_parts)
message = []
if text:
message.append(Plain(text=text))
for img_url in images:
message.append(Image(file=img_url))
return message, text
async def convert_message(self, data: dict) -> AstrBotMessage:
abm = AstrBotMessage()
abm.raw_message = data
abm.self_id = self.client.bot_id
channel_type = data.get("channel_type")
author_id = data.get("author_id", "unknown")
# channel_type定义: https://developer.kookapp.cn/doc/event/event-introduction
match channel_type:
case "GROUP":
session_id = data.get("target_id") or "unknown"
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = session_id
abm.session_id = session_id
case "PERSON":
abm.type = MessageType.FRIEND_MESSAGE
abm.group_id = ""
abm.session_id = data.get("author_id", "unknown")
case "BROADCAST":
session_id = data.get("target_id") or "unknown"
abm.type = MessageType.OTHER_MESSAGE
abm.group_id = session_id
abm.session_id = session_id
case _:
raise ValueError(f"不支持的频道类型: {channel_type}")
abm.sender = MessageMember(
user_id=author_id,
nickname=data.get("extra", {}).get("author", {}).get("username", ""),
)
abm.message_id = data.get("msg_id", "unknown")
# 普通文本消息
if data.get("type") == 9:
message, message_str = self._parse_kmarkdown_text_message(
data, str(abm.self_id)
)
abm.message = message
abm.message_str = message_str
# 卡片消息
elif data.get("type") == 10:
try:
abm.message, abm.message_str = self._parse_card_message(data)
except Exception as exp:
logger.error(f"[KOOK] 卡片消息解析失败: {exp}")
abm.message_str = "[卡片消息解析失败]"
abm.message = [Plain(text="[卡片消息解析失败]")]
else:
logger.warning(f'[KOOK] 不支持的kook消息类型: "{data.get("type")}"')
abm.message_str = "[不支持的消息类型]"
abm.message = [Plain(text="[不支持的消息类型]")]
return abm
async def handle_msg(self, message: AstrBotMessage):
message_event = KookEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.client,
)
self.commit_event(message_event)
@@ -1,437 +0,0 @@
import asyncio
import base64
import json
import os
import random
import time
import zlib
from pathlib import Path
import aiofiles
import aiohttp
import websockets
from astrbot import logger
from astrbot.core.platform.message_type import MessageType
from .kook_config import KookConfig
from .kook_types import KookApiPaths, KookMessageType
class KookClient:
def __init__(self, config: KookConfig, event_callback):
# 数据字段
self.config = config
self._bot_id = ""
self._bot_name = ""
# 资源字段
self._http_client = aiohttp.ClientSession(
headers={
"Authorization": f"Bot {self.config.token}",
}
)
self.event_callback = event_callback # 回调函数,用于处理接收到的事件
self.ws = None
self.heartbeat_task = None
self._stop_event = asyncio.Event() # 用于通知连接结束
# 状态/计算字段
self.running = False
self.session_id = None
self.last_sn = 0 # 记录最后处理的消息序号
self.last_heartbeat_time = 0
self.heartbeat_failed_count = 0
@property
def bot_id(self):
return self._bot_id
@property
def bot_name(self):
return self._bot_name
async def get_bot_info(self) -> str:
"""获取机器人账号ID"""
url = KookApiPaths.USER_ME
try:
async with self._http_client.get(url) as resp:
if resp.status != 200:
logger.error(f"[KOOK] 获取机器人账号ID失败,状态码: {resp.status}")
return ""
data = await resp.json()
if data.get("code") != 0:
logger.error(f"[KOOK] 获取机器人账号ID失败: {data}")
return ""
bot_id: str = data["data"]["id"]
self._bot_id = bot_id
logger.info(f"[KOOK] 获取机器人账号ID成功: {bot_id}")
bot_name: str = data["data"]["nickname"] or data["data"]["username"]
self._bot_name = bot_name
logger.info(f"[KOOK] 获取机器人名称成功: {self._bot_name}")
return bot_id
except Exception as e:
logger.error(f"[KOOK] 获取机器人账号ID异常: {e}")
return ""
async def get_gateway_url(self, resume=False, sn=0, session_id=None):
"""获取网关连接地址"""
url = KookApiPaths.GATEWAY_INDEX
# 构建连接参数
params = {}
if resume:
params["resume"] = 1
params["sn"] = sn
if session_id:
params["session_id"] = session_id
try:
async with self._http_client.get(url, params=params) as resp:
if resp.status != 200:
logger.error(f"[KOOK] 获取gateway失败,状态码: {resp.status}")
return None
data = await resp.json()
if data.get("code") != 0:
logger.error(f"[KOOK] 获取gateway失败: {data}")
return None
gateway_url: str = data["data"]["url"]
logger.info(f"[KOOK] 获取gateway成功: {gateway_url.split('?')[0]}")
return gateway_url
except Exception as e:
logger.error(f"[KOOK] 获取gateway异常: {e}")
return None
async def connect(self, resume=False):
"""连接WebSocket"""
if self.ws:
try:
await self.ws.close()
except Exception:
pass
self.ws = None
self._stop_event.clear()
try:
# 获取gateway地址
gateway_url = await self.get_gateway_url(
resume=resume, sn=self.last_sn, session_id=self.session_id
)
await self.get_bot_info()
if not gateway_url:
return False
# 连接WebSocket
self.ws = await websockets.connect(gateway_url)
self.running = True
logger.info("[KOOK] WebSocket 连接成功")
# 启动心跳任务
if self.heartbeat_task:
self.heartbeat_task.cancel()
self.heartbeat_task = asyncio.create_task(self._heartbeat_loop())
# 开始监听消息
await self.listen()
return True
except Exception as e:
logger.error(f"[KOOK] WebSocket 连接失败: {e}")
if self.ws:
try:
await self.ws.close()
except Exception:
pass
self.ws = None
return False
async def listen(self):
"""监听WebSocket消息"""
try:
while self.running:
try:
msg = await asyncio.wait_for(self.ws.recv(), timeout=10) # type: ignore
if isinstance(msg, bytes):
try:
msg = zlib.decompress(msg)
except Exception as e:
logger.error(f"[KOOK] 解压消息失败: {e}")
continue
msg = msg.decode("utf-8")
data = json.loads(msg)
# 处理不同类型的信令
await self._handle_signal(data)
except asyncio.TimeoutError:
# 超时检查,继续循环
continue
except websockets.exceptions.ConnectionClosed:
logger.warning("[KOOK] WebSocket连接已关闭")
break
except Exception as e:
logger.error(f"[KOOK] 消息处理异常: {e}")
break
except Exception as e:
logger.error(f"[KOOK] WebSocket 监听异常: {e}")
finally:
self.running = False
self._stop_event.set()
async def _handle_signal(self, data):
"""处理不同类型的信令"""
signal_type = data.get("s")
if signal_type == 0: # 事件消息
# 更新消息序号
if "sn" in data:
self.last_sn = data["sn"]
await self.event_callback(data)
elif signal_type == 1: # HELLO握手
await self._handle_hello(data)
elif signal_type == 3: # PONG心跳响应
await self._handle_pong(data)
elif signal_type == 5: # RECONNECT重连指令
await self._handle_reconnect(data)
elif signal_type == 6: # RESUME ACK
await self._handle_resume_ack(data)
else:
logger.debug(f"[KOOK] 未处理的信令类型: {signal_type}")
async def _handle_hello(self, data):
"""处理HELLO握手"""
hello_data = data.get("d", {})
code = hello_data.get("code", 0)
if code == 0:
self.session_id = hello_data.get("session_id")
logger.info(f"[KOOK] 握手成功,session_id: {self.session_id}")
# TODO 重置重连延迟
# self.reconnect_delay = 1
else:
logger.error(f"[KOOK] 握手失败,错误码: {code}")
if code == 40103: # token过期
logger.error("[KOOK] Token已过期,需要重新获取")
self.running = False
async def _handle_pong(self, data):
"""处理PONG心跳响应"""
self.last_heartbeat_time = time.time()
self.heartbeat_failed_count = 0
async def _handle_reconnect(self, data):
"""处理重连指令"""
logger.warning("[KOOK] 收到重连指令")
# 清空本地状态
self.last_sn = 0
self.session_id = None
self.running = False
async def _handle_resume_ack(self, data):
"""处理RESUME确认"""
resume_data = data.get("d", {})
self.session_id = resume_data.get("session_id")
logger.info(f"[KOOK] Resume成功,session_id: {self.session_id}")
async def _heartbeat_loop(self):
"""心跳循环"""
while self.running:
try:
# 随机化心跳间隔 (±5秒)
interval = max(
1, self.config.heartbeat_interval + random.randint(-5, 5)
)
await asyncio.sleep(interval)
if not self.running:
break
# 发送心跳
await self._send_ping()
# 等待PONG响应
await asyncio.sleep(self.config.heartbeat_timeout)
# 检查是否收到PONG响应
if (
time.time() - self.last_heartbeat_time
> self.config.heartbeat_timeout
):
self.heartbeat_failed_count += 1
logger.warning(
f"[KOOK] 心跳超时,失败次数: {self.heartbeat_failed_count}"
)
if (
self.heartbeat_failed_count
>= self.config.max_heartbeat_failures
):
logger.error("[KOOK] 心跳失败次数过多,准备重连")
self.running = False
break
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"[KOOK] 心跳异常: {e}")
self.heartbeat_failed_count += 1
async def _send_ping(self):
"""发送心跳PING"""
try:
ping_data = {"s": 2, "sn": self.last_sn}
await self.ws.send(json.dumps(ping_data)) # type: ignore
except Exception as e:
logger.error(f"[KOOK] 发送心跳失败: {e}")
async def send_text(
self,
target_id: str,
content: str,
astrbot_message_type: MessageType,
kook_message_type: KookMessageType,
reply_message_id: str | int = "",
):
"""发送文本消息
消息发送接口文档参见: https://developer.kookapp.cn/doc/http/message#%E5%8F%91%E9%80%81%E9%A2%91%E9%81%93%E8%81%8A%E5%A4%A9%E6%B6%88%E6%81%AF
KMarkdown格式参见: https://developer.kookapp.cn/doc/kmarkdown-desc
"""
url = KookApiPaths.CHANNEL_MESSAGE_CREATE
if astrbot_message_type == MessageType.FRIEND_MESSAGE:
url = KookApiPaths.DIRECT_MESSAGE_CREATE
payload = {
"target_id": target_id,
"content": content,
"type": kook_message_type,
}
if reply_message_id:
payload["quote"] = reply_message_id
payload["reply_msg_id"] = reply_message_id
try:
async with self._http_client.post(url, json=payload) as resp:
if resp.status == 200:
result = await resp.json()
if result.get("code") != 0:
raise RuntimeError(
f'发送kook消息类型 "{kook_message_type.name}" 失败: {result}'
)
# else:
# logger.info("[KOOK] 发送消息成功")
else:
raise RuntimeError(
f'发送kook消息类型 "{kook_message_type.name}" HTTP错误: {resp.status} , 响应内容 : {await resp.text()}'
)
except RuntimeError:
raise
except Exception as e:
logger.error(
f'[KOOK] 发送kook消息类型 "{kook_message_type.name}" 异常: {e}'
)
async def upload_asset(self, file_url: str | None) -> str:
"""上传文件到kook,获得远端资源url
接口定义参见: https://developer.kookapp.cn/doc/http/asset
"""
if not file_url:
return ""
bytes_data: bytes | None = None
filename = "unknown"
if file_url.startswith(("http://", "https://")):
filename = file_url.split("/")[-1]
return file_url
if file_url.startswith("base64:///"):
# b64decode的时候得开头留一个'/'的, 不然会报错
b64_str = file_url.removeprefix("base64://")
bytes_data = base64.b64decode(b64_str)
elif file_url.startswith("file://") or os.path.exists(file_url):
file_url = file_url.removeprefix("file:///")
file_url = file_url.removeprefix("file://")
try:
target_path = Path(file_url).resolve()
except Exception as exp:
logger.error(f'[KOOK] 获取文件 "{file_url}" 绝对路径失败: "{exp}"')
raise FileNotFoundError(
f'获取文件 "{file_url}" 绝对路径失败: "{exp}"'
) from exp
if not target_path.is_file():
raise FileNotFoundError(f"文件不存在: {target_path.name}")
filename = target_path.name
async with aiofiles.open(target_path, "rb") as f:
bytes_data = await f.read()
else:
raise ValueError(f'[KOOK] 不支持的文件资源类型: "{file_url}"')
data = aiohttp.FormData()
data.add_field("file", bytes_data, filename=filename)
url = KookApiPaths.ASSET_CREATE
try:
async with self._http_client.post(url, data=data) as resp:
if resp.status == 200:
result: dict = await resp.json()
logger.debug(f"[KOOK] 上传文件响应: {result}")
if result.get("code") == 0:
logger.info("[KOOK] 上传文件到kook服务器成功")
remote_url = result["data"]["url"]
logger.debug(f"[KOOK] 文件远端URL: {remote_url}")
return remote_url
else:
raise RuntimeError(f"上传文件到kook服务器失败: {result}")
else:
raise RuntimeError(
f"上传文件到kook服务器 HTTP错误: {resp.status} , {await resp.text()}"
)
except RuntimeError:
raise
except Exception as e:
raise RuntimeError(f"上传文件到kook服务器异常: {e}") from e
async def wait_until_closed(self):
"""提供给外部调用的等待方法"""
await self._stop_event.wait()
async def close(self):
"""关闭连接"""
self.running = False
self._stop_event.set()
if self.heartbeat_task:
self.heartbeat_task.cancel()
try:
await self.heartbeat_task
except asyncio.CancelledError:
pass
if self.ws:
try:
await self.ws.close()
except Exception as e:
logger.error(f"[KOOK] 关闭WebSocket异常: {e}")
if self._http_client:
await self._http_client.close()
logger.info("[KOOK] 连接已关闭")
@@ -1,133 +0,0 @@
import json
from dataclasses import asdict, dataclass
from typing import Any
@dataclass
class KookConfig:
"""KOOK 适配器配置类"""
# 基础配置
token: str
bot_nickname: str = ""
enable: bool = False
id: str = "kook"
# 重连配置
reconnect_delay: int = 1
"""重连延迟基数(秒),指数退避"""
max_reconnect_delay: int = 60
"""最大重连延迟(秒)"""
max_retry_delay: int = 60
"""最大重试延迟(秒)"""
# 心跳配置
heartbeat_interval: int = 30
"""心跳间隔(秒)"""
heartbeat_timeout: int = 6
"""心跳超时时间(秒)"""
max_heartbeat_failures: int = 3
"""最大心跳失败次数"""
# 失败处理
max_consecutive_failures: int = 5
"""最大连续失败次数"""
@classmethod
def from_dict(cls, config_dict: dict) -> "KookConfig":
"""从字典创建配置对象"""
return cls(
# 适配器id 应该是不能改的
# id=config_dict.get("id", "kook"),
enable=config_dict.get("enable", False),
token=config_dict.get("kook_bot_token", ""),
bot_nickname=config_dict.get("kook_bot_nickname", ""),
reconnect_delay=config_dict.get(
"kook_reconnect_delay",
KookConfig.reconnect_delay,
),
max_reconnect_delay=config_dict.get(
"kook_max_reconnect_delay",
KookConfig.max_reconnect_delay,
),
max_retry_delay=config_dict.get(
"kook_max_retry_delay",
KookConfig.max_retry_delay,
),
heartbeat_interval=config_dict.get(
"kook_heartbeat_interval",
KookConfig.heartbeat_interval,
),
heartbeat_timeout=config_dict.get(
"kook_heartbeat_timeout",
KookConfig.heartbeat_timeout,
),
max_heartbeat_failures=config_dict.get(
"kook_max_heartbeat_failures",
KookConfig.max_heartbeat_failures,
),
max_consecutive_failures=config_dict.get(
"kook_max_consecutive_failures",
KookConfig.max_consecutive_failures,
),
)
def to_dict(self) -> dict[str, Any]:
return asdict(self)
def pretty_jsons(self, indent=2) -> str:
dict_config = self.to_dict()
dict_config["token"] = "*" * len(self.token) if self.token else "MISSING"
return json.dumps(dict_config, indent=indent, ensure_ascii=False)
# TODO 没用上的config配置,未来有空会实现这些配置描述的功能?
# # 连接配置
# CONNECTION_CONFIG = {
# # 心跳配置
# "heartbeat_interval": 30, # 心跳间隔(秒)
# "heartbeat_timeout": 6, # 心跳超时时间(秒)
# "max_heartbeat_failures": 3, # 最大心跳失败次数
# # 重连配置
# "initial_reconnect_delay": 1, # 初始重连延迟(秒)
# "max_reconnect_delay": 60, # 最大重连延迟(秒)
# "max_consecutive_failures": 5, # 最大连续失败次数
# # WebSocket配置
# "websocket_timeout": 10, # WebSocket接收超时(秒)
# "connection_timeout": 30, # 连接超时(秒)
# # 消息处理配置
# "enable_compression": True, # 是否启用消息压缩
# "max_message_size": 1024 * 1024, # 最大消息大小(字节)
# }
# # 日志配置
# LOGGING_CONFIG = {
# "level": "INFO", # 日志级别:DEBUG, INFO, WARNING, ERROR
# "format": "[KOOK] %(message)s",
# "enable_heartbeat_logs": False, # 是否启用心跳日志
# "enable_message_logs": False, # 是否启用消息日志
# }
# # 错误处理配置
# ERROR_HANDLING_CONFIG = {
# "retry_on_network_error": True, # 网络错误时是否重试
# "retry_on_token_expired": True, # Token过期时是否重试
# "max_retry_attempts": 3, # 最大重试次数
# "retry_delay_base": 2, # 重试延迟基数(秒)
# }
# # 性能配置
# PERFORMANCE_CONFIG = {
# "enable_message_buffering": True, # 是否启用消息缓冲
# "buffer_size": 100, # 缓冲区大小
# "enable_connection_pooling": True, # 是否启用连接池
# "max_concurrent_requests": 10, # 最大并发请求数
# }
# # 安全配置
# SECURITY_CONFIG = {
# "verify_ssl": True, # 是否验证SSL证书
# "enable_rate_limiting": True, # 是否启用速率限制
# "rate_limit_requests": 100, # 速率限制请求数
# "rate_limit_window": 60, # 速率限制窗口(秒)
# }
@@ -1,209 +0,0 @@
import asyncio
import json
from collections.abc import Coroutine
from pathlib import Path
from typing import Any
from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.core.message.components import (
At,
AtAll,
BaseMessageComponent,
File,
Image,
Json,
Plain,
Record,
Reply,
Video,
)
from astrbot.core.platform import MessageType
from .kook_client import KookClient
from .kook_types import (
FileModule,
KookCardMessage,
KookCardMessageContainer,
KookMessageType,
OrderMessage,
)
class KookEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
client: KookClient,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
self.channel_id = message_obj.group_id or message_obj.session_id
self.astrbot_message_type: MessageType = message_obj.type
self._file_message_counter = 0
def _wrap_message(
self, index: int, message_component: BaseMessageComponent
) -> Coroutine[Any, Any, OrderMessage]:
async def wrap_upload(
index: int, message_type: KookMessageType, upload_coro
) -> OrderMessage:
url = await upload_coro
return OrderMessage(index=index, text=url, type=message_type)
async def handle_plain(
index: int,
text: str | None,
reply_id: str | int = "",
type: KookMessageType = KookMessageType.KMARKDOWN,
):
if not text:
text = ""
return OrderMessage(
index=index,
text=text,
type=type,
reply_id=reply_id,
)
match message_component:
case Image():
self._file_message_counter += 1
return wrap_upload(
index,
KookMessageType.IMAGE,
self.client.upload_asset(message_component.file),
)
case Video():
self._file_message_counter += 1
return wrap_upload(
index,
KookMessageType.VIDEO,
self.client.upload_asset(message_component.file),
)
case File():
async def handle_file(index: int, f_item: File):
f_data = await f_item.get_file()
url = await self.client.upload_asset(f_data)
return OrderMessage(
index=index, text=url, type=KookMessageType.FILE
)
self._file_message_counter += 1
return handle_file(index, message_component)
case Record():
async def handle_audio(index: int, f_item: Record):
file_path = await f_item.convert_to_file_path()
url = await self.client.upload_asset(file_path)
title = f_item.text or Path(file_path).name
return OrderMessage(
index=index,
text=KookCardMessageContainer(
[
KookCardMessage(
modules=[
FileModule(
type="audio",
title=title,
src=url,
)
]
)
]
).to_json(),
type=KookMessageType.CARD,
)
return handle_audio(index, message_component)
case Plain():
return handle_plain(index, message_component.text)
case At():
return handle_plain(index, f"(met){message_component.qq}(met)")
case AtAll():
return handle_plain(index, "(met)all(met)")
case Reply():
return handle_plain(index, "", reply_id=message_component.id)
case Json():
json_data = message_component.data
# kook卡片json外层得是一个列表
if isinstance(json_data, dict):
json_data = [json_data]
return handle_plain(
index,
# 考虑到kook可能会更改消息结构,为了能让插件开发者
# 自行根据kook文档描述填卡片json内容,故不做模型校验
# KookCardMessage().model_validate(message_component.data).to_json(),
text=json.dumps(json_data),
type=KookMessageType.CARD,
)
case _:
raise NotImplementedError(
f'kook适配器尚未实现对 "{message_component.type}" 消息类型的支持'
)
async def send(self, message: MessageChain):
file_upload_tasks: list[Coroutine[Any, Any, OrderMessage]] = []
for index, item in enumerate(message.chain):
file_upload_tasks.append(self._wrap_message(index, item))
if self._file_message_counter > 0:
logger.debug("[Kook] 正在向kook服务器上传文件")
tasks_result = await asyncio.gather(*file_upload_tasks, return_exceptions=True)
order_messages: list[OrderMessage] = []
for index, result in enumerate(tasks_result):
if isinstance(result, BaseException):
logger.error(f"[Kook] {result}")
# 构造一个虚假的 OrderMessage,让用户知道这里本来有张图但坏了
# 这样后面的 for 循环就能把它当成普通文本发出去
err_node = OrderMessage(
index=index,
text=str(result),
type=KookMessageType.TEXT,
)
order_messages.append(err_node)
else:
order_messages.append(result)
order_messages.sort(key=lambda x: x.index)
reply_id: str | int = ""
errors: list[Exception] = []
for item in order_messages:
if item.reply_id:
reply_id = item.reply_id
if not item.text:
logger.debug(f'[Kook] 跳过空消息,类型为"{item.type}"')
continue
try:
await self.client.send_text(
self.channel_id,
item.text,
self.astrbot_message_type,
item.type,
reply_id,
)
except RuntimeError as exp:
await self.client.send_text(
self.channel_id,
str(exp),
self.astrbot_message_type,
KookMessageType.TEXT,
reply_id,
)
errors.append(exp)
if errors:
err_msg = "\n".join([str(err) for err in errors])
logger.error(f"[kook] {err_msg}")
await super().send(message)
@@ -1,241 +0,0 @@
import json
from dataclasses import field
from enum import IntEnum
from typing import Literal
from pydantic import BaseModel, ConfigDict
from pydantic.dataclasses import dataclass
class KookApiPaths:
"""Kook Api 路径"""
BASE_URL = "https://www.kookapp.cn"
API_VERSION_PATH = "/api/v3"
# 初始化相关
USER_ME = f"{BASE_URL}{API_VERSION_PATH}/user/me"
GATEWAY_INDEX = f"{BASE_URL}{API_VERSION_PATH}/gateway/index"
# 消息相关
ASSET_CREATE = f"{BASE_URL}{API_VERSION_PATH}/asset/create"
## 频道消息
CHANNEL_MESSAGE_CREATE = f"{BASE_URL}{API_VERSION_PATH}/message/create"
## 私聊消息
DIRECT_MESSAGE_CREATE = f"{BASE_URL}{API_VERSION_PATH}/direct-message/create"
# 定义参见kook事件结构文档: https://developer.kookapp.cn/doc/event/event-introduction
class KookMessageType(IntEnum):
TEXT = 1
IMAGE = 2
VIDEO = 3
FILE = 4
AUDIO = 8
KMARKDOWN = 9
CARD = 10
SYSTEM = 255
ThemeType = Literal[
"primary", "success", "danger", "warning", "info", "secondary", "none", "invisible"
]
"""主题,可选的值为:primary, success, danger, warning, info, secondary, none.默认为 primary,为 none 时不显示侧边框。"""
SizeType = Literal["xs", "sm", "md", "lg"]
"""大小,可选值为:xs, sm, md, lg, 一般默认为 lg"""
SectionMode = Literal["left", "right"]
CountdownMode = Literal["day", "hour", "second"]
class KookCardColor(str):
"""16 进制色值"""
class KookCardModelBase:
"""卡片模块基类"""
type: str
@dataclass
class PlainTextElement(KookCardModelBase):
content: str
type: str = "plain-text"
emoji: bool = True
@dataclass
class KmarkdownElement(KookCardModelBase):
content: str
type: str = "kmarkdown"
@dataclass
class ImageElement(KookCardModelBase):
src: str
type: str = "image"
alt: str = ""
size: SizeType = "lg"
circle: bool = False
fallbackUrl: str | None = None
@dataclass
class ButtonElement(KookCardModelBase):
text: str
type: str = "button"
theme: ThemeType = "primary"
value: str = ""
"""当为 link 时,会跳转到 value 代表的链接;
当为 return-val 系统会通过系统消息将消息 id,点击用户 id value 发回给发送者发送者可以根据自己的需求进行处理,消息事件参见button 点击事件私聊和频道内均可使用按钮点击事件"""
click: Literal["", "link", "return-val"] = ""
"""click 代表用户点击的事件,默认为"",代表无任何事件。"""
AnyElement = PlainTextElement | KmarkdownElement | ImageElement | ButtonElement | str
@dataclass
class ParagraphStructure(KookCardModelBase):
fields: list[PlainTextElement | KmarkdownElement]
type: str = "paragraph"
cols: int = 1
"""范围是 1-3 , 移动端忽略此参数"""
@dataclass
class HeaderModule(KookCardModelBase):
text: PlainTextElement
type: str = "header"
@dataclass
class SectionModule(KookCardModelBase):
text: PlainTextElement | KmarkdownElement | ParagraphStructure
type: str = "section"
mode: SectionMode = "left"
accessory: ImageElement | ButtonElement | None = None
@dataclass
class ImageGroupModule(KookCardModelBase):
"""1 到多张图片的组合"""
elements: list[ImageElement]
type: str = "image-group"
@dataclass
class ContainerModule(KookCardModelBase):
"""1 到多张图片的组合,与图片组模块(ImageGroupModule)不同,图片并不会裁切为正方形。多张图片会纵向排列。"""
elements: list[ImageElement]
type: str = "container"
@dataclass
class ActionGroupModule(KookCardModelBase):
elements: list[ButtonElement]
type: str = "action-group"
@dataclass
class ContextModule(KookCardModelBase):
elements: list[PlainTextElement | KmarkdownElement | ImageElement]
"""最多包含10个元素"""
type: str = "context"
@dataclass
class DividerModule(KookCardModelBase):
type: str = "divider"
@dataclass
class FileModule(KookCardModelBase):
src: str
title: str = ""
type: Literal["file", "audio", "video"] = "file"
cover: str | None = None
"""cover 仅音频有效, 是音频的封面图"""
@dataclass
class CountdownModule(KookCardModelBase):
"""startTime 和 endTime 为毫秒时间戳,startTime 和 endTime 不能小于服务器当前时间戳。"""
endTime: int
"""毫秒时间戳"""
type: str = "countdown"
startTime: int | None = None
"""毫秒时间戳, 仅当mode为second才有这个字段"""
mode: CountdownMode = "day"
"""mode 主要是倒计时的样式"""
@dataclass
class InviteModule(KookCardModelBase):
code: str
"""邀请链接或者邀请码"""
type: str = "invite"
# 所有模块的联合类型
AnyModule = (
HeaderModule
| SectionModule
| ImageGroupModule
| ContainerModule
| ActionGroupModule
| ContextModule
| DividerModule
| FileModule
| CountdownModule
| InviteModule
)
class KookCardMessage(BaseModel):
"""卡片定义文档详见 : https://developer.kookapp.cn/doc/cardmessage
此类型不能直接to_json后发送,因为kook要求卡片容器json顶层必须是**列表**
若要发送卡片消息请使用KookCardMessageContainer
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
type: str = "card"
theme: ThemeType | None = None
size: SizeType | None = None
color: KookCardColor | None = None
modules: list[AnyModule] = field(default_factory=list)
"""单个 card 模块数量不限制,但是一条消息中所有卡片的模块数量之和最多是 50"""
def add_module(self, module: AnyModule):
self.modules.append(module)
def to_dict(self, exclude_none: bool = True):
"""exclude_none:去掉值为 None 字段,保留结构"""
return self.model_dump(exclude_none=exclude_none)
def to_json(self, indent: int | None = None, ensure_ascii: bool = True):
return json.dumps(self.to_dict(), indent=indent, ensure_ascii=ensure_ascii)
class KookCardMessageContainer(list[KookCardMessage]):
"""卡片消息容器(列表),此类型可以直接to_json后发送出去"""
def append(self, object: KookCardMessage) -> None:
return super().append(object)
def to_json(self, indent: int | None = None, ensure_ascii: bool = True) -> str:
return json.dumps(
[i.to_dict() for i in self], indent=indent, ensure_ascii=ensure_ascii
)
@dataclass
class OrderMessage:
index: int
text: str
type: KookMessageType
reply_id: str | int = ""
@@ -104,7 +104,7 @@ class LineMessageEvent(AstrMessageEvent):
@staticmethod
async def _resolve_image_url(segment: Image) -> str:
candidate = (segment.url or segment.file or "").strip()
if candidate.startswith("https://"):
if candidate.startswith("http://") or candidate.startswith("https://"):
return candidate
try:
return await segment.register_to_file_service()
@@ -115,7 +115,7 @@ class LineMessageEvent(AstrMessageEvent):
@staticmethod
async def _resolve_record_url(segment: Record) -> str:
candidate = (segment.url or segment.file or "").strip()
if candidate.startswith("https://"):
if candidate.startswith("http://") or candidate.startswith("https://"):
return candidate
try:
return await segment.register_to_file_service()
@@ -137,7 +137,7 @@ class LineMessageEvent(AstrMessageEvent):
@staticmethod
async def _resolve_video_url(segment: Video) -> str:
candidate = (segment.file or "").strip()
if candidate.startswith("https://"):
if candidate.startswith("http://") or candidate.startswith("https://"):
return candidate
try:
return await segment.register_to_file_service()
@@ -148,7 +148,9 @@ class LineMessageEvent(AstrMessageEvent):
@staticmethod
async def _resolve_video_preview_url(segment: Video) -> str:
cover_candidate = (segment.cover or "").strip()
if cover_candidate.startswith("https://"):
if cover_candidate.startswith("http://") or cover_candidate.startswith(
"https://"
):
return cover_candidate
if cover_candidate:
@@ -189,7 +191,7 @@ class LineMessageEvent(AstrMessageEvent):
@staticmethod
async def _resolve_file_url(segment: File) -> str:
if segment.url and segment.url.startswith("https://"):
if segment.url and segment.url.startswith(("http://", "https://")):
return segment.url
try:
return await segment.register_to_file_service()
+114 -446
View File
@@ -4,11 +4,7 @@ import asyncio
import copy
import json
import os
import threading
import urllib.parse
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
from dataclasses import dataclass
from types import MappingProxyType
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
import aiohttp
@@ -21,103 +17,6 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
DEFAULT_MCP_INIT_TIMEOUT_SECONDS = 20.0
DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS = 30.0
MCP_INIT_TIMEOUT_ENV = "ASTRBOT_MCP_INIT_TIMEOUT"
ENABLE_MCP_TIMEOUT_ENV = "ASTRBOT_MCP_ENABLE_TIMEOUT"
MAX_MCP_TIMEOUT_SECONDS = 300.0
class MCPInitError(Exception):
"""Base exception for MCP initialization failures."""
class MCPInitTimeoutError(asyncio.TimeoutError, MCPInitError):
"""Raised when MCP client initialization exceeds the configured timeout."""
class MCPAllServicesFailedError(MCPInitError):
"""Raised when all configured MCP services fail to initialize."""
class MCPShutdownTimeoutError(asyncio.TimeoutError):
"""Raised when MCP shutdown exceeds the configured timeout."""
def __init__(self, names: list[str], timeout: float) -> None:
self.names = names
self.timeout = timeout
message = f"MCP 服务关闭超时({timeout:g} 秒):{', '.join(names)}"
super().__init__(message)
@dataclass
class MCPInitSummary:
total: int
success: int
failed: list[str]
@dataclass
class _MCPServerRuntime:
name: str
client: MCPClient
shutdown_event: asyncio.Event
lifecycle_task: asyncio.Task[None]
class _MCPClientDictView(Mapping[str, MCPClient]):
"""Read-only view of MCP clients derived from runtime state."""
def __init__(self, runtime: dict[str, _MCPServerRuntime]) -> None:
self._runtime = runtime
def __getitem__(self, key: str) -> MCPClient:
return self._runtime[key].client
def __iter__(self):
return iter(self._runtime)
def __len__(self) -> int:
return len(self._runtime)
def _resolve_timeout(
timeout: float | int | str | None = None,
*,
env_name: str = MCP_INIT_TIMEOUT_ENV,
default: float = DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
) -> float:
"""Resolve timeout with precedence: explicit argument > env value > default."""
source = f"环境变量 {env_name}"
if timeout is None:
timeout = os.getenv(env_name, str(default))
else:
source = "显式参数 timeout"
try:
timeout_value = float(timeout)
except (TypeError, ValueError):
logger.warning(
f"超时配置({source}={timeout!r} 无效,使用默认值 {default:g} 秒。"
)
return default
if timeout_value <= 0:
logger.warning(
f"超时配置({source}={timeout_value:g} 必须大于 0,使用默认值 {default:g} 秒。"
)
return default
if timeout_value > MAX_MCP_TIMEOUT_SECONDS:
logger.warning(
f"超时配置({source}={timeout_value:g} 过大,已限制为最大值 "
f"{MAX_MCP_TIMEOUT_SECONDS:g} 秒,以避免长时间等待。"
)
return MAX_MCP_TIMEOUT_SECONDS
return timeout_value
SUPPORTED_TYPES = [
"string",
"number",
@@ -207,49 +106,9 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
class FunctionToolManager:
def __init__(self) -> None:
self.func_list: list[FuncTool] = []
self._mcp_server_runtime: dict[str, _MCPServerRuntime] = {}
"""MCP 服务运行时状态(唯一事实来源)"""
self._mcp_server_runtime_view = MappingProxyType(self._mcp_server_runtime)
self._mcp_client_dict_view = _MCPClientDictView(self._mcp_server_runtime)
self._timeout_mismatch_warned = False
self._timeout_warn_lock = threading.Lock()
self._runtime_lock = asyncio.Lock()
self._mcp_starting: set[str] = set()
self._init_timeout_default = _resolve_timeout(
timeout=None,
env_name=MCP_INIT_TIMEOUT_ENV,
default=DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
)
self._enable_timeout_default = _resolve_timeout(
timeout=None,
env_name=ENABLE_MCP_TIMEOUT_ENV,
default=DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS,
)
self._warn_on_timeout_mismatch(
self._init_timeout_default,
self._enable_timeout_default,
)
@property
def mcp_client_dict(self) -> Mapping[str, MCPClient]:
"""Read-only compatibility view for external callers that still read mcp_client_dict.
Note: Mutating this mapping is unsupported and will raise TypeError.
"""
return self._mcp_client_dict_view
@property
def mcp_server_runtime_view(self) -> Mapping[str, _MCPServerRuntime]:
"""Read-only view of MCP runtime metadata for external callers."""
return self._mcp_server_runtime_view
@property
def mcp_server_runtime(self) -> Mapping[str, _MCPServerRuntime]:
"""Backward-compatible read-only view (deprecated). Do not mutate.
Note: Mutations are not supported and will raise TypeError.
"""
return self._mcp_server_runtime_view
self.mcp_client_dict: dict[str, MCPClient] = {}
"""MCP 服务列表"""
self.mcp_client_event: dict[str, asyncio.Event] = {}
def empty(self) -> bool:
return len(self.func_list) == 0
@@ -320,34 +179,7 @@ class FunctionToolManager:
tool_set = ToolSet(self.func_list.copy())
return tool_set
@staticmethod
def _log_safe_mcp_debug_config(cfg: dict) -> None:
# 仅记录脱敏后的摘要,避免泄露 command/args/url 中的敏感信息
if "command" in cfg:
cmd = cfg["command"]
executable = str(cmd[0] if isinstance(cmd, (list, tuple)) and cmd else cmd)
args_val = cfg.get("args", [])
args_count = (
len(args_val)
if isinstance(args_val, (list, tuple))
else (0 if args_val is None else 1)
)
logger.debug(f" 命令可执行文件: {executable}, 参数数量: {args_count}")
return
if "url" in cfg:
parsed = urllib.parse.urlparse(str(cfg["url"]))
host = parsed.hostname or ""
scheme = parsed.scheme or "unknown"
try:
port = f":{parsed.port}" if parsed.port else ""
except ValueError:
port = ""
logger.debug(f" 主机: {scheme}://{host}{port}")
async def init_mcp_clients(
self, raise_on_all_failed: bool = False
) -> MCPInitSummary:
async def init_mcp_clients(self) -> None:
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
```
{
@@ -365,10 +197,6 @@ class FunctionToolManager:
...
}
```
Timeout behavior:
- 初始化超时使用环境变量 ASTRBOT_MCP_INIT_TIMEOUT 或默认值
- 动态启用超时使用 ASTRBOT_MCP_ENABLE_TIMEOUT独立于初始化超时
"""
data_dir = get_astrbot_data_path()
@@ -378,211 +206,56 @@ class FunctionToolManager:
with open(mcp_json_file, "w", encoding="utf-8") as f:
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
return MCPInitSummary(total=0, success=0, failed=[])
return
with open(mcp_json_file, encoding="utf-8") as f:
mcp_server_json_obj: dict[str, dict] = json.load(f)["mcpServers"]
mcp_server_json_obj: dict[str, dict] = json.load(
open(mcp_json_file, encoding="utf-8"),
)["mcpServers"]
init_timeout = self._init_timeout_default
timeout_display = f"{init_timeout:g}"
active_configs: list[tuple[str, dict, asyncio.Event]] = []
for name, cfg in mcp_server_json_obj.items():
for name in mcp_server_json_obj:
cfg = mcp_server_json_obj[name]
if cfg.get("active", True):
shutdown_event = asyncio.Event()
active_configs.append((name, cfg, shutdown_event))
event = asyncio.Event()
asyncio.create_task(
self._init_mcp_client_task_wrapper(name, cfg, event),
)
self.mcp_client_event[name] = event
if not active_configs:
return MCPInitSummary(total=0, success=0, failed=[])
logger.info(f"等待 {len(active_configs)} 个 MCP 服务初始化...")
init_tasks = [
asyncio.create_task(
self._start_mcp_server(
name=name,
cfg=cfg,
shutdown_event=shutdown_event,
timeout=init_timeout,
),
name=f"mcp-init:{name}",
)
for (name, cfg, shutdown_event) in active_configs
]
results = await asyncio.gather(*init_tasks, return_exceptions=True)
success_count = 0
failed_services: list[str] = []
for (name, cfg, _), result in zip(active_configs, results, strict=False):
if isinstance(result, Exception):
if isinstance(result, MCPInitTimeoutError):
logger.error(f"MCP 服务 {name} 初始化超时({timeout_display}秒)")
else:
logger.error(f"MCP 服务 {name} 初始化失败: {result}")
self._log_safe_mcp_debug_config(cfg)
failed_services.append(name)
async with self._runtime_lock:
self._mcp_server_runtime.pop(name, None)
continue
success_count += 1
if failed_services:
logger.warning(
f"以下 MCP 服务初始化失败: {', '.join(failed_services)}"
f"请检查配置文件 mcp_server.json 和服务器可用性。"
)
summary = MCPInitSummary(
total=len(active_configs), success=success_count, failed=failed_services
)
logger.info(f"MCP 服务初始化完成: {summary.success}/{summary.total} 成功")
if summary.total > 0 and summary.success == 0:
msg = "全部 MCP 服务初始化失败,请检查 mcp_server.json 配置和服务器可用性。"
if raise_on_all_failed:
raise MCPAllServicesFailedError(msg)
logger.error(msg)
return summary
async def _start_mcp_server(
async def _init_mcp_client_task_wrapper(
self,
name: str,
cfg: dict,
*,
shutdown_event: asyncio.Event | None = None,
timeout: float,
event: asyncio.Event,
ready_future: asyncio.Future | None = None,
) -> None:
"""Initialize MCP server with timeout and register task/event together.
This method is idempotent. If the server is already running, the existing
runtime is kept and the new config is ignored.
"""
async with self._runtime_lock:
if name in self._mcp_server_runtime or name in self._mcp_starting:
logger.warning(
f"MCP 服务 {name} 已在运行,忽略本次启用请求(timeout={timeout:g})。"
)
self._log_safe_mcp_debug_config(cfg)
return
self._mcp_starting.add(name)
if shutdown_event is None:
shutdown_event = asyncio.Event()
mcp_client: MCPClient | None = None
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
try:
mcp_client = await asyncio.wait_for(
self._init_mcp_client(name, cfg),
timeout=timeout,
)
except asyncio.TimeoutError as exc:
raise MCPInitTimeoutError(
f"MCP 服务 {name} 初始化超时({timeout:g} 秒)"
) from exc
except Exception:
await self._init_mcp_client(name, cfg)
tools = await self.mcp_client_dict[name].list_tools_and_save()
if ready_future and not ready_future.done():
# tell the caller we are ready
ready_future.set_result(tools)
await event.wait()
logger.info(f"收到 MCP 客户端 {name} 终止信号")
except Exception as e:
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
raise
if ready_future and not ready_future.done():
ready_future.set_exception(e)
finally:
if mcp_client is None:
async with self._runtime_lock:
self._mcp_starting.discard(name)
# 无论如何都能清理
await self._terminate_mcp_client(name)
async def lifecycle() -> None:
try:
await shutdown_event.wait()
logger.info(f"收到 MCP 客户端 {name} 终止信号")
except asyncio.CancelledError:
logger.debug(f"MCP 客户端 {name} 任务被取消")
raise
finally:
await self._terminate_mcp_client(name)
lifecycle_task = asyncio.create_task(lifecycle(), name=f"mcp-client:{name}")
async with self._runtime_lock:
self._mcp_server_runtime[name] = _MCPServerRuntime(
name=name,
client=mcp_client,
shutdown_event=shutdown_event,
lifecycle_task=lifecycle_task,
)
self._mcp_starting.discard(name)
async def _shutdown_runtimes(
self,
runtimes: list[_MCPServerRuntime],
timeout: float,
*,
strict: bool = True,
) -> list[str]:
"""Shutdown runtimes and wait for lifecycle tasks to complete."""
lifecycle_tasks = [
runtime.lifecycle_task
for runtime in runtimes
if not runtime.lifecycle_task.done()
]
if not lifecycle_tasks:
return []
for runtime in runtimes:
runtime.shutdown_event.set()
try:
results = await asyncio.wait_for(
asyncio.gather(*lifecycle_tasks, return_exceptions=True),
timeout=timeout,
)
except asyncio.TimeoutError:
pending_names = [
runtime.name
for runtime in runtimes
if not runtime.lifecycle_task.done()
]
for task in lifecycle_tasks:
if not task.done():
task.cancel()
await asyncio.gather(*lifecycle_tasks, return_exceptions=True)
if strict:
raise MCPShutdownTimeoutError(pending_names, timeout)
logger.warning(
"MCP 服务关闭超时(%s 秒),以下服务未完全关闭:%s",
f"{timeout:g}",
", ".join(pending_names),
)
return pending_names
else:
for result in results:
if isinstance(result, asyncio.CancelledError):
logger.debug("MCP lifecycle task was cancelled during shutdown.")
elif isinstance(result, Exception):
logger.error(
"MCP lifecycle task failed during shutdown.",
exc_info=(type(result), result, result.__traceback__),
)
return []
async def _cleanup_mcp_client_safely(
self, mcp_client: MCPClient, name: str
) -> None:
"""安全清理单个 MCP 客户端,避免清理异常中断主流程。"""
try:
await mcp_client.cleanup()
except Exception as cleanup_exc: # noqa: BLE001 - only log here
logger.error(f"清理 MCP 客户端资源 {name} 失败: {cleanup_exc}")
async def _init_mcp_client(self, name: str, config: dict) -> MCPClient:
async def _init_mcp_client(self, name: str, config: dict) -> None:
"""初始化单个MCP客户端"""
# 先清理之前的客户端,如果存在
if name in self.mcp_client_dict:
await self._terminate_mcp_client(name)
mcp_client = MCPClient()
mcp_client.name = name
try:
await mcp_client.connect_to_server(config, name)
tools_res = await mcp_client.list_tools_and_save()
except asyncio.CancelledError:
await self._cleanup_mcp_client_safely(mcp_client, name)
raise
except Exception:
await self._cleanup_mcp_client_safely(mcp_client, name)
raise
self.mcp_client_dict[name] = mcp_client
await mcp_client.connect_to_server(config, name)
tools_res = await mcp_client.list_tools_and_save()
logger.debug(f"MCP server {name} list tools response: {tools_res}")
tool_names = [tool.name for tool in tools_res.tools]
@@ -603,36 +276,26 @@ class FunctionToolManager:
self.func_list.append(func_tool)
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
return mcp_client
async def _terminate_mcp_client(self, name: str) -> None:
"""关闭并清理MCP客户端"""
async with self._runtime_lock:
runtime = self._mcp_server_runtime.get(name)
if runtime:
client = runtime.client
# 关闭MCP连接
await self._cleanup_mcp_client_safely(client, name)
# 移除关联的FuncTool
self.func_list = [
f
for f in self.func_list
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
async with self._runtime_lock:
self._mcp_server_runtime.pop(name, None)
self._mcp_starting.discard(name)
logger.info(f"已关闭 MCP 服务 {name}")
return
# Runtime missing but stale tools may still exist after failed flows.
self.func_list = [
f
for f in self.func_list
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
async with self._runtime_lock:
self._mcp_starting.discard(name)
if name in self.mcp_client_dict:
client = self.mcp_client_dict[name]
try:
# 关闭MCP连接
await client.cleanup()
except Exception as e:
logger.error(f"清空 MCP 客户端资源 {name}: {e}")
finally:
# Remove client from dict after cleanup attempt (successful or not)
self.mcp_client_dict.pop(name, None)
# 移除关联的FuncTool
self.func_list = [
f
for f in self.func_list
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
logger.info(f"已关闭 MCP 服务 {name}")
@staticmethod
async def test_mcp_server_connection(config: dict) -> list[str]:
@@ -656,36 +319,42 @@ class FunctionToolManager:
self,
name: str,
config: dict,
shutdown_event: asyncio.Event | None = None,
timeout: float | int | str | None = None,
event: asyncio.Event | None = None,
ready_future: asyncio.Future | None = None,
timeout: int = 30,
) -> None:
"""Enable a new MCP server and initialize it.
"""Enable_mcp_server a new MCP server to the manager and initialize it.
Args:
name: The name of the MCP server.
config: Configuration for the MCP server.
shutdown_event: Event to signal when the MCP client should shut down.
timeout: Timeout in seconds for initialization.
Uses ASTRBOT_MCP_ENABLE_TIMEOUT by default (separate from init timeout).
name (str): The name of the MCP server.
config (dict): Configuration for the MCP server.
event (asyncio.Event): Event to signal when the MCP client is ready.
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
timeout (int): Timeout for the initialization.
Raises:
MCPInitTimeoutError: If initialization does not complete within timeout.
TimeoutError: If the initialization does not complete within the specified timeout.
Exception: If there is an error during initialization.
"""
if timeout is None:
timeout_value = self._enable_timeout_default
else:
timeout_value = _resolve_timeout(
timeout=timeout,
env_name=ENABLE_MCP_TIMEOUT_ENV,
default=self._enable_timeout_default,
)
await self._start_mcp_server(
name=name,
cfg=config,
shutdown_event=shutdown_event,
timeout=timeout_value,
if not event:
event = asyncio.Event()
if not ready_future:
ready_future = asyncio.Future()
if name in self.mcp_client_dict:
return
asyncio.create_task(
self._init_mcp_client_task_wrapper(name, config, event, ready_future),
)
try:
await asyncio.wait_for(ready_future, timeout=timeout)
finally:
self.mcp_client_event[name] = event
if ready_future.done() and ready_future.exception():
exc = ready_future.exception()
if exc is not None:
raise exc
async def disable_mcp_server(
self,
@@ -698,40 +367,39 @@ class FunctionToolManager:
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
timeout (int): Timeout.
Raises:
MCPShutdownTimeoutError: If shutdown does not complete within timeout.
Only raised when disabling a specific server (name is not None).
"""
if name:
async with self._runtime_lock:
runtime = self._mcp_server_runtime.get(name)
if runtime is None:
if name not in self.mcp_client_event:
return
await self._shutdown_runtimes([runtime], timeout, strict=True)
client = self.mcp_client_dict.get(name)
self.mcp_client_event[name].set()
if not client:
return
client_running_event = client.running_event
try:
await asyncio.wait_for(client_running_event.wait(), timeout=timeout)
finally:
self.mcp_client_event.pop(name, None)
self.func_list = [
f
for f in self.func_list
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
else:
async with self._runtime_lock:
runtimes = list(self._mcp_server_runtime.values())
await self._shutdown_runtimes(runtimes, timeout, strict=False)
def _warn_on_timeout_mismatch(
self,
init_timeout: float,
enable_timeout: float,
) -> None:
if init_timeout == enable_timeout:
return
with self._timeout_warn_lock:
if self._timeout_mismatch_warned:
return
logger.info(
"检测到 MCP 初始化超时与动态启用超时配置不同:"
"初始化使用 %s 秒,动态启用使用 %s 秒。如需一致,请设置相同值。",
f"{init_timeout:g}",
f"{enable_timeout:g}",
)
self._timeout_mismatch_warned = True
running_events = [
client.running_event.wait() for client in self.mcp_client_dict.values()
]
for key, event in self.mcp_client_event.items():
event.set()
# waiting for all clients to finish
try:
await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout)
finally:
self.mcp_client_event.clear()
self.mcp_client_dict.clear()
self.func_list = [
f for f in self.func_list if not isinstance(f, MCPTool)
]
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
"""获得 OpenAI API 风格的**已经激活**的工具描述"""
+2 -19
View File
@@ -330,25 +330,8 @@ class ProviderManager:
if not self.curr_tts_provider_inst and self.tts_provider_insts:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
# 初始化 MCP Client 连接(等待完成以确保工具可用)
strict_mcp_init = os.getenv("ASTRBOT_MCP_INIT_STRICT", "").strip().lower() in {
"1",
"true",
"yes",
"on",
}
mcp_init_summary = await self.llm_tools.init_mcp_clients(
raise_on_all_failed=strict_mcp_init
)
if (
mcp_init_summary.total > 0
and mcp_init_summary.success == 0
and not strict_mcp_init
):
logger.warning(
"MCP 服务全部初始化失败,系统将继续启动(可设置 "
"ASTRBOT_MCP_INIT_STRICT=1 以在此场景下中止启动)。"
)
# 初始化 MCP Client 连接
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
def dynamic_import_provider(self, type: str) -> None:
"""动态导入提供商适配器模块
-372
View File
@@ -1,372 +0,0 @@
from __future__ import annotations
import hashlib
import json
import os
import re
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from astrbot.core.computer.computer_client import sync_skills_to_active_sandboxes
from astrbot.core.skills.skill_manager import SkillManager
from astrbot.core.utils.astrbot_path import get_astrbot_skills_path
_MAP_VERSION = 1
_MAP_FILE_NAME = "neo_skill_map.json"
_SKILL_NAME_RE = re.compile(r"[^a-zA-Z0-9._-]+")
def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _to_jsonable(model_like: Any) -> dict[str, Any]:
if isinstance(model_like, dict):
return model_like
if hasattr(model_like, "model_dump"):
dumped = model_like.model_dump()
if isinstance(dumped, dict):
return dumped
return {}
def _parse_frontmatter(text: str) -> tuple[dict[str, str], str]:
if not text.startswith("---"):
return {}, text
lines = text.splitlines()
if not lines or lines[0].strip() != "---":
return {}, text
end_idx = None
for i in range(1, len(lines)):
if lines[i].strip() == "---":
end_idx = i
break
if end_idx is None:
return {}, text
data: dict[str, str] = {}
for line in lines[1:end_idx]:
if ":" not in line:
continue
key, value = line.split(":", 1)
key = key.strip().lower()
value = value.strip().strip('"').strip("'")
if key in {"name", "description"} and value:
data[key] = value
body = "\n".join(lines[end_idx + 1 :]).lstrip("\n")
return data, body
def _derive_description(markdown_body: str) -> str:
lines = markdown_body.splitlines()
heading_idx = None
for i, line in enumerate(lines):
normalized = line.strip().lower()
if normalized in {"## 描述", "## description"}:
heading_idx = i
break
if heading_idx is not None:
for line in lines[heading_idx + 1 :]:
text = line.strip()
if not text:
continue
if text.startswith("#"):
break
return text
for line in lines:
text = line.strip()
if not text or text.startswith("#"):
continue
return text
return ""
def _ensure_skill_frontmatter(markdown: str, *, skill_name: str, skill_key: str) -> str:
frontmatter, body = _parse_frontmatter(markdown)
name = frontmatter.get("name") or skill_name
name = " ".join(str(name).split())
description = frontmatter.get("description") or _derive_description(body)
if not description:
description = f"Synced skill for `{skill_key}`."
description = " ".join(description.split())
header = f"---\nname: {name}\ndescription: {description}\n---\n\n"
body = body.strip("\n")
return f"{header}{body}\n"
@dataclass
class NeoSkillSyncResult:
skill_key: str
local_skill_name: str
release_id: str
candidate_id: str
payload_ref: str
map_path: str
synced_at: str
class NeoSkillSyncManager:
@staticmethod
def sync_result_to_dict(result: NeoSkillSyncResult) -> dict[str, str]:
return {
"skill_key": result.skill_key,
"local_skill_name": result.local_skill_name,
"release_id": result.release_id,
"candidate_id": result.candidate_id,
"payload_ref": result.payload_ref,
"map_path": result.map_path,
"synced_at": result.synced_at,
}
def __init__(
self,
*,
skills_root: str | None = None,
map_path: str | None = None,
) -> None:
self.skills_root = skills_root or get_astrbot_skills_path()
self.map_path = map_path or str(Path(self.skills_root) / _MAP_FILE_NAME)
os.makedirs(self.skills_root, exist_ok=True)
def _load_map(self) -> dict[str, Any]:
if not os.path.exists(self.map_path):
return {"version": _MAP_VERSION, "items": {}}
try:
with open(self.map_path, encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, dict):
return {"version": _MAP_VERSION, "items": {}}
items = data.get("items", {})
if not isinstance(items, dict):
items = {}
return {"version": int(data.get("version", _MAP_VERSION)), "items": items}
except Exception:
return {"version": _MAP_VERSION, "items": {}}
def _save_map(self, data: dict[str, Any]) -> None:
os.makedirs(os.path.dirname(self.map_path), exist_ok=True)
with open(self.map_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
@staticmethod
def normalize_skill_name(skill_key: str) -> str:
normalized = _SKILL_NAME_RE.sub("-", skill_key.strip().lower())
normalized = normalized.strip("._-")
if not normalized:
normalized = "skill"
return f"neo_{normalized}"
def _resolve_local_skill_name(self, skill_key: str, mapping: dict[str, Any]) -> str:
items = mapping.get("items", {})
if not isinstance(items, dict):
items = {}
existing = items.get(skill_key)
if isinstance(existing, dict):
local_name = existing.get("local_skill_name")
if isinstance(local_name, str) and local_name:
return local_name
base = self.normalize_skill_name(skill_key)
used_names = {
str(v.get("local_skill_name"))
for v in items.values()
if isinstance(v, dict) and v.get("local_skill_name")
}
if base not in used_names:
return base
suffix = hashlib.sha1(skill_key.encode("utf-8")).hexdigest()[:8]
return f"{base}-{suffix}"
async def _find_release(self, client: Any, *, release_id: str) -> dict[str, Any]:
offset = 0
while True:
page = await client.skills.list_releases(limit=100, offset=offset)
page_json = _to_jsonable(page)
items = page_json.get("items", [])
if not isinstance(items, list):
items = []
for item in items:
if isinstance(item, dict) and item.get("id") == release_id:
return item
total = int(page_json.get("total", 0) or 0)
offset += len(items)
if offset >= total or not items:
break
raise ValueError(f"Release not found: {release_id}")
async def _find_active_stable_release(
self,
client: Any,
*,
skill_key: str,
) -> dict[str, Any]:
page = await client.skills.list_releases(
skill_key=skill_key,
active_only=True,
stage="stable",
limit=1,
offset=0,
)
page_json = _to_jsonable(page)
items = page_json.get("items", [])
if not isinstance(items, list) or not items:
raise ValueError(
f"No active stable release found for skill_key: {skill_key}"
)
if not isinstance(items[0], dict):
raise ValueError("Unexpected release payload format.")
return items[0]
async def sync_release(
self,
client: Any,
*,
release_id: str | None = None,
skill_key: str | None = None,
require_stable: bool = True,
) -> NeoSkillSyncResult:
if release_id:
release = await self._find_release(client, release_id=release_id)
elif skill_key:
release = await self._find_active_stable_release(
client, skill_key=skill_key
)
else:
raise ValueError("release_id or skill_key is required for sync.")
release_id_val = str(release.get("id") or "")
release_stage_raw = release.get("stage")
release_stage_value = getattr(release_stage_raw, "value", release_stage_raw)
release_stage = str(release_stage_value or "").strip().lower()
skill_key_val = str(release.get("skill_key") or "")
candidate_id = str(release.get("candidate_id") or "")
if not release_id_val or not skill_key_val or not candidate_id:
raise ValueError("Release payload is incomplete.")
if require_stable and release_stage != "stable":
raise ValueError(
"Only stable releases can be synced to local SKILL.md "
f"(got: {release_stage_raw})."
)
candidate = await client.skills.get_candidate(candidate_id)
candidate_json = _to_jsonable(candidate)
payload_ref = candidate_json.get("payload_ref")
if not isinstance(payload_ref, str) or not payload_ref:
raise ValueError("Candidate payload_ref is missing.")
payload_resp = await client.skills.get_payload(payload_ref)
payload_json = _to_jsonable(payload_resp)
payload = payload_json.get("payload")
if not isinstance(payload, dict):
raise ValueError("Skill payload must be a JSON object.")
skill_markdown = payload.get("skill_markdown")
if not isinstance(skill_markdown, str) or not skill_markdown.strip():
raise ValueError(
"payload.skill_markdown is required for stable sync to local skill."
)
mapping = self._load_map()
local_skill_name = self._resolve_local_skill_name(skill_key_val, mapping)
skill_dir = Path(self.skills_root) / local_skill_name
skill_dir.mkdir(parents=True, exist_ok=True)
normalized_markdown = _ensure_skill_frontmatter(
skill_markdown,
skill_name=local_skill_name,
skill_key=skill_key_val,
)
skill_md_path = skill_dir / "SKILL.md"
skill_md_path.write_text(normalized_markdown, encoding="utf-8")
items = mapping.setdefault("items", {})
items[skill_key_val] = {
"local_skill_name": local_skill_name,
"latest_release_id": release_id_val,
"latest_candidate_id": candidate_id,
"latest_payload_ref": payload_ref,
"updated_at": _now_iso(),
}
mapping["version"] = _MAP_VERSION
self._save_map(mapping)
# Ensure local skill is visible to AstrBot skill manager.
SkillManager().set_skill_active(local_skill_name, True)
# Best-effort synchronization to active sandboxes.
await sync_skills_to_active_sandboxes()
return NeoSkillSyncResult(
skill_key=skill_key_val,
local_skill_name=local_skill_name,
release_id=release_id_val,
candidate_id=candidate_id,
payload_ref=payload_ref,
map_path=self.map_path,
synced_at=_now_iso(),
)
async def promote_with_optional_sync(
self,
client: Any,
*,
candidate_id: str,
stage: str,
sync_to_local: bool,
) -> dict[str, Any]:
release = await client.skills.promote_candidate(candidate_id, stage=stage)
release_json = _to_jsonable(release)
sync_json: dict[str, Any] | None = None
rollback_json: dict[str, Any] | None = None
sync_error: str | None = None
if stage == "stable" and sync_to_local:
try:
sync_result = await self.sync_release(
client,
release_id=str(release_json.get("id", "")),
require_stable=True,
)
sync_json = self.sync_result_to_dict(sync_result)
except Exception as err:
sync_error = str(err)
try:
rollback = await client.skills.rollback_release(
str(release_json.get("id", ""))
)
rollback_json = _to_jsonable(rollback)
except Exception as rollback_err:
rollback_msg = str(rollback_err)
if "no previous release exists" in rollback_msg.lower():
rollback_json = {
"skipped": True,
"reason": rollback_msg,
}
else:
raise RuntimeError(
"stable release synced failed and auto rollback also failed; "
f"sync_error={sync_error}; rollback_error={rollback_err}"
) from rollback_err
return {
"release": release_json,
"sync": sync_json,
"rollback": rollback_json,
"sync_error": sync_error,
}
+37 -250
View File
@@ -7,7 +7,6 @@ import shutil
import tempfile
import zipfile
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path, PurePosixPath
from astrbot.core.utils.astrbot_path import (
@@ -17,11 +16,9 @@ from astrbot.core.utils.astrbot_path import (
)
SKILLS_CONFIG_FILENAME = "skills.json"
SANDBOX_SKILLS_CACHE_FILENAME = "sandbox_skills_cache.json"
DEFAULT_SKILLS_CONFIG: dict[str, dict] = {"skills": {}}
# SANDBOX_SKILLS_ROOT = "/home/shared/skills"
SANDBOX_SKILLS_ROOT = "skills"
SANDBOX_WORKSPACE_ROOT = "/workspace"
_SANDBOX_SKILLS_CACHE_VERSION = 1
_SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$")
@@ -32,23 +29,9 @@ class SkillInfo:
description: str
path: str
active: bool
source_type: str = "local_only"
source_label: str = "local"
local_exists: bool = True
sandbox_exists: bool = False
def _parse_frontmatter_description(text: str) -> str:
"""Extract the ``description`` value from YAML frontmatter.
Expects the standard SKILL.md format used by OpenAI Codex CLI and
Anthropic Claude Skills::
---
name: my-skill
description: What this skill does and when to use it.
---
"""
if not text.startswith("---"):
return ""
lines = text.splitlines()
@@ -70,74 +53,45 @@ def _parse_frontmatter_description(text: str) -> str:
return ""
# Regex for sanitizing paths used in prompt examples — only allow
# safe path characters to prevent prompt injection via crafted skill paths.
_SAFE_PATH_RE = re.compile(r"[^A-Za-z0-9_./ -]")
def build_skills_prompt(skills: list[SkillInfo]) -> str:
"""Build the skills section of the system prompt.
Generates a markdown-formatted skill inventory for the LLM. Only
``name`` and ``description`` are shown upfront; the LLM must read
the full ``SKILL.md`` before execution (progressive disclosure).
"""
skills_lines: list[str] = []
example_path = ""
skills_lines = []
for skill in skills:
description = skill.description or "No description"
skills_lines.append(
f"- **{skill.name}**: {description}\n File: `{skill.path}`"
)
if not example_path:
example_path = skill.path
skills_lines.append(f"- {skill.name}: {description} (file: {skill.path})")
skills_block = "\n".join(skills_lines)
# Sanitize example_path — it may originate from sandbox cache (untrusted)
example_path = _SAFE_PATH_RE.sub("", example_path) if example_path else ""
example_path = example_path or "<skills_root>/<skill_name>/SKILL.md"
# Based on openai/codex
return (
"## Skills\n\n"
"You have specialized skills — reusable instruction bundles stored "
"in `SKILL.md` files. Each skill has a **name** and a **description** "
"that tells you what it does and when to use it.\n\n"
"### Available skills\n\n"
f"{skills_block}\n\n"
"### Skill rules\n\n"
"1. **Discovery** — The list above is the complete skill inventory "
"for this session. Full instructions are in the referenced "
"`SKILL.md` file.\n"
"2. **When to trigger** — Use a skill if the user names it "
"explicitly, or if the task clearly matches the skill's description. "
"*Never silently skip a matching skill* — either use it or briefly "
"explain why you chose not to.\n"
"3. **Mandatory grounding** — Before executing any skill you MUST "
"first read its `SKILL.md` by running a shell command with the "
f"**absolute path** shown above (e.g. `cat {example_path}`). "
"Never rely on memory or assumptions about a skill's content.\n"
"4. **Progressive disclosure** — Load only what is directly "
"referenced from `SKILL.md`:\n"
" - If `scripts/` exist, prefer running or patching them over "
"rewriting code from scratch.\n"
" - If `assets/` or templates exist, reuse them.\n"
" - Do NOT bulk-load every file in the skill directory.\n"
"5. **Coordination** — When multiple skills apply, pick the minimal "
"set needed. Announce which skill(s) you are using and why "
"(one short line). Prefer `astrbot_*` tools when running skill "
"scripts.\n"
"6. **Context hygiene** — Avoid deep reference chasing; open only "
"files that are directly linked from `SKILL.md`.\n"
"7. **Failure handling** — If a skill cannot be applied, state the "
"issue clearly and continue with the best alternative.\n"
"## Skills\n"
"You have many useful skills that can help you accomplish various tasks.\n"
"A skill is a set of local instructions stored in a `SKILL.md` file.\n"
"### Available skills\n"
f"{skills_block}\n"
"### Skill Rules\n"
"\n"
"- Discovery: The list above shows all skills available in this session. Full instructions live in the referenced `SKILL.md`.\n"
"- Trigger rules: Use a skill if the user names it or the task matches its description. Do not carry skills across turns unless re-mentioned\n"
"### How to use a skill (progressive disclosure):\n"
" 0) Mandatory grounding: Before using any skill, you MUST inspect its `SKILL.md` using shell tools"
" (e.g., `cat`, `head`, `sed`, `awk`, `grep`). Do not rely on assumptions or memory.\n"
" 1) Load only directly referenced files, DO NOT bulk-load everything.\n"
" 2) If `scripts/` exist, prefer running or patching them instead of retyping large blocks of code.\n"
" 3) If `assets/` or templates exist, reuse them rather than recreating everything from scratch.\n"
"- Coordination:\n"
" - If multiple skills apply, choose the minimal set that covers the request and state the order in which you will use them.\n"
" - Announce which skill(s) you are using and why (one short line). If you skip an obvious skill, explain why.\n"
" - Prefer to use `astrbot_*` tools to perform skills that need to run scripts.\n"
"- Context hygiene:\n"
" - Avoid deep reference chasing: unless blocked, open only files that are directly linked from `SKILL.md`.\n"
"- Failure handling: If a skill cannot be applied, state the issue and continue with the best alternative.\n"
"### Example\n"
"When you decided to use a skill, use shell tool to read its `SKILL.md`, e.g., `head -40 skills/code_formatter/SKILL.md`, and you can increase or decrease the number of lines as needed.\n"
)
class SkillManager:
def __init__(self, skills_root: str | None = None) -> None:
self.skills_root = skills_root or get_astrbot_skills_path()
data_path = Path(get_astrbot_data_path())
self.config_path = str(data_path / SKILLS_CONFIG_FILENAME)
self.sandbox_skills_cache_path = str(data_path / SANDBOX_SKILLS_CACHE_FILENAME)
self.config_path = os.path.join(get_astrbot_data_path(), SKILLS_CONFIG_FILENAME)
os.makedirs(self.skills_root, exist_ok=True)
def _load_config(self) -> dict:
@@ -154,66 +108,6 @@ class SkillManager:
with open(self.config_path, "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=4)
def _load_sandbox_skills_cache(self) -> dict:
if not os.path.exists(self.sandbox_skills_cache_path):
return {"version": _SANDBOX_SKILLS_CACHE_VERSION, "skills": []}
try:
with open(self.sandbox_skills_cache_path, encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, dict):
return {"version": _SANDBOX_SKILLS_CACHE_VERSION, "skills": []}
skills = data.get("skills", [])
if not isinstance(skills, list):
skills = []
return {
"version": int(data.get("version", _SANDBOX_SKILLS_CACHE_VERSION)),
"skills": skills,
"updated_at": data.get("updated_at"),
}
except Exception:
return {"version": _SANDBOX_SKILLS_CACHE_VERSION, "skills": []}
def _save_sandbox_skills_cache(self, cache: dict) -> None:
cache["version"] = _SANDBOX_SKILLS_CACHE_VERSION
cache["updated_at"] = datetime.now(timezone.utc).isoformat()
with open(self.sandbox_skills_cache_path, "w", encoding="utf-8") as f:
json.dump(cache, f, ensure_ascii=False, indent=2)
def set_sandbox_skills_cache(self, skills: list[dict]) -> None:
"""Persist sandbox skill metadata discovered from runtime side."""
deduped: dict[str, dict[str, str]] = {}
for item in skills:
if not isinstance(item, dict):
continue
name = str(item.get("name", "")).strip()
if not name or not _SKILL_NAME_RE.match(name):
continue
description = str(item.get("description", "") or "")
path = str(item.get("path", "") or "")
if not path:
path = f"{SANDBOX_WORKSPACE_ROOT}/{SANDBOX_SKILLS_ROOT}/{name}/SKILL.md"
deduped[name] = {
"name": name,
"description": description,
"path": path.replace("\\", "/"),
}
cache = {
"version": _SANDBOX_SKILLS_CACHE_VERSION,
"skills": [deduped[name] for name in sorted(deduped)],
}
self._save_sandbox_skills_cache(cache)
def get_sandbox_skills_cache_status(self) -> dict[str, object]:
cache = self._load_sandbox_skills_cache()
skills = cache.get("skills", [])
count = len(skills) if isinstance(skills, list) else 0
return {
"exists": os.path.exists(self.sandbox_skills_cache_path),
"ready": count > 0,
"count": count,
"updated_at": cache.get("updated_at"),
}
def list_skills(
self,
*,
@@ -230,21 +124,7 @@ class SkillManager:
config = self._load_config()
skill_configs = config.get("skills", {})
modified = False
skills_by_name: dict[str, SkillInfo] = {}
sandbox_cached_paths: dict[str, str] = {}
sandbox_cached_descriptions: dict[str, str] = {}
cache_for_paths = self._load_sandbox_skills_cache()
for item in cache_for_paths.get("skills", []):
if not isinstance(item, dict):
continue
name = str(item.get("name", "") or "").strip()
path = str(item.get("path", "") or "").strip().replace("\\", "/")
if not name or not _SKILL_NAME_RE.match(name):
continue
sandbox_cached_descriptions[name] = str(item.get("description", "") or "")
if path:
sandbox_cached_paths[name] = path
skills: list[SkillInfo] = []
for entry in sorted(Path(self.skills_root).iterdir()):
if not entry.is_dir():
@@ -265,129 +145,36 @@ class SkillManager:
description = _parse_frontmatter_description(content)
except Exception:
description = ""
sandbox_exists = (
runtime == "sandbox" and skill_name in sandbox_cached_descriptions
)
source_type = "both" if sandbox_exists else "local_only"
source_label = "synced" if sandbox_exists else "local"
if runtime == "sandbox" and show_sandbox_path:
path_str = sandbox_cached_paths.get(skill_name) or (
f"{SANDBOX_WORKSPACE_ROOT}/{SANDBOX_SKILLS_ROOT}/{skill_name}/SKILL.md"
)
path_str = f"{SANDBOX_SKILLS_ROOT}/{skill_name}/SKILL.md"
else:
path_str = str(skill_md)
path_str = path_str.replace("\\", "/")
skills_by_name[skill_name] = SkillInfo(
name=skill_name,
description=description,
path=path_str,
active=active,
source_type=source_type,
source_label=source_label,
local_exists=True,
sandbox_exists=sandbox_exists,
)
if runtime == "sandbox":
cache = self._load_sandbox_skills_cache()
for item in cache.get("skills", []):
if not isinstance(item, dict):
continue
skill_name = str(item.get("name", "")).strip()
if (
not skill_name
or skill_name in skills_by_name
or not _SKILL_NAME_RE.match(skill_name)
):
continue
active = skill_configs.get(skill_name, {}).get("active", True)
if skill_name not in skill_configs:
skill_configs[skill_name] = {"active": active}
modified = True
if active_only and not active:
continue
description = sandbox_cached_descriptions.get(skill_name, "")
if show_sandbox_path:
path_str = f"{SANDBOX_WORKSPACE_ROOT}/{SANDBOX_SKILLS_ROOT}/{skill_name}/SKILL.md"
else:
path_str = sandbox_cached_paths.get(skill_name, "")
if not path_str:
path_str = f"{SANDBOX_WORKSPACE_ROOT}/{SANDBOX_SKILLS_ROOT}/{skill_name}/SKILL.md"
skills_by_name[skill_name] = SkillInfo(
skills.append(
SkillInfo(
name=skill_name,
description=description,
path=path_str.replace("\\", "/"),
path=path_str,
active=active,
source_type="sandbox_only",
source_label="sandbox_preset",
local_exists=False,
sandbox_exists=True,
)
)
if modified:
config["skills"] = skill_configs
self._save_config(config)
return [skills_by_name[name] for name in sorted(skills_by_name)]
def is_sandbox_only_skill(self, name: str) -> bool:
skill_dir = Path(self.skills_root) / name
skill_md_exists = (skill_dir / "SKILL.md").exists()
if skill_md_exists:
return False
cache = self._load_sandbox_skills_cache()
skills = cache.get("skills", [])
if not isinstance(skills, list):
return False
for item in skills:
if not isinstance(item, dict):
continue
if str(item.get("name", "")).strip() == name:
return True
return False
return skills
def set_skill_active(self, name: str, active: bool) -> None:
if self.is_sandbox_only_skill(name):
raise PermissionError(
"Sandbox preset skill cannot be enabled/disabled from local skill management."
)
config = self._load_config()
config.setdefault("skills", {})
config["skills"][name] = {"active": bool(active)}
self._save_config(config)
def _remove_skill_from_sandbox_cache(self, name: str) -> None:
cache = self._load_sandbox_skills_cache()
skills = cache.get("skills", [])
if not isinstance(skills, list):
return
filtered = [
item
for item in skills
if not (
isinstance(item, dict) and str(item.get("name", "")).strip() == name
)
]
if len(filtered) != len(skills):
cache["skills"] = filtered
self._save_sandbox_skills_cache(cache)
def delete_skill(self, name: str) -> None:
if self.is_sandbox_only_skill(name):
raise PermissionError(
"Sandbox preset skill cannot be deleted from local skill management."
)
skill_dir = Path(self.skills_root) / name
if skill_dir.exists():
shutil.rmtree(skill_dir)
# Ensure UI consistency even when there is no active sandbox session
# to refresh cache from runtime side.
self._remove_skill_from_sandbox_cache(name)
config = self._load_config()
if name in config.get("skills", {}):
config["skills"].pop(name, None)
@@ -409,7 +196,7 @@ class SkillManager:
top_dirs = {
PurePosixPath(name).parts[0] for name in file_names if name.strip()
}
print(top_dirs)
if len(top_dirs) != 1:
raise ValueError("Zip archive must contain a single top-level folder.")
skill_name = next(iter(top_dirs))
+1 -3
View File
@@ -149,9 +149,7 @@ class AstrBotUpdator(RepoZipUpdator):
file_url = None
if os.environ.get("ASTRBOT_CLI") or os.environ.get("ASTRBOT_LAUNCHER"):
raise Exception(
"Error: You are running AstrBot via CLI, please use `pip` or `uv tool upgrade` to update AstrBot."
) # 避免版本管理混乱
raise Exception("不支持更新此方式启动的AstrBot") # 避免版本管理混乱
if latest:
latest_version = update_data[0]["tag_name"]
+1 -7
View File
@@ -14,7 +14,7 @@ import certifi
import psutil
from PIL import Image
from .astrbot_path import get_astrbot_data_path, get_astrbot_path, get_astrbot_temp_path
from .astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
logger = logging.getLogger("astrbot")
@@ -219,13 +219,7 @@ def get_local_ip_addresses():
async def get_dashboard_version():
# First check user data directory (manually updated / downloaded dashboard).
dist_dir = os.path.join(get_astrbot_data_path(), "dist")
if not os.path.exists(dist_dir):
# Fall back to the dist bundled inside the installed wheel.
_bundled = Path(get_astrbot_path()) / "astrbot" / "dashboard" / "dist"
if _bundled.exists():
dist_dir = str(_bundled)
if os.path.exists(dist_dir):
version_file = os.path.join(dist_dir, "assets", "version")
if os.path.exists(version_file):
-103
View File
@@ -206,110 +206,12 @@ def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict]
return errors, data
def _log_computer_config_changes(old_config: dict, new_config: dict) -> None:
"""Compare and log Computer/sandbox configuration changes."""
old_ps = old_config.get("provider_settings", {})
new_ps = new_config.get("provider_settings", {})
# Check computer_use_runtime
old_runtime = old_ps.get("computer_use_runtime", "none")
new_runtime = new_ps.get("computer_use_runtime", "none")
if old_runtime != new_runtime:
logger.info(
"[Computer] Config changed: computer_use_runtime %s -> %s",
old_runtime,
new_runtime,
)
# Check sandbox sub-keys
old_sandbox = old_ps.get("sandbox", {})
new_sandbox = new_ps.get("sandbox", {})
all_keys = set(old_sandbox.keys()) | set(new_sandbox.keys())
for key in sorted(all_keys):
old_val = old_sandbox.get(key)
new_val = new_sandbox.get(key)
if old_val != new_val:
# Mask tokens/secrets in log output
if "token" in key or "secret" in key:
old_display = "***" if old_val else "(empty)"
new_display = "***" if new_val else "(empty)"
else:
old_display = old_val
new_display = new_val
logger.info(
"[Computer] Config changed: sandbox.%s %s -> %s",
key,
old_display,
new_display,
)
async def _validate_neo_connectivity(
post_config: dict,
) -> str | None:
"""Check if Bay is reachable when Shipyard Neo sandbox is configured.
Returns a warning message string if Bay isn't reachable, or None if
everything looks fine (or Neo isn't configured).
"""
ps = post_config.get("provider_settings", {})
runtime = ps.get("computer_use_runtime", "none")
sandbox = ps.get("sandbox", {})
booter = sandbox.get("booter", "")
# Only check when sandbox mode + shipyard_neo is selected
if runtime != "sandbox" or booter != "shipyard_neo":
return None
endpoint = sandbox.get("shipyard_neo_endpoint", "").rstrip("/")
if not endpoint:
return "⚠️ Shipyard Neo endpoint 未设置"
access_token = sandbox.get("shipyard_neo_access_token", "")
if not access_token:
# Try auto-discovery
from astrbot.core.computer.computer_client import _discover_bay_credentials
access_token = _discover_bay_credentials(endpoint)
if not access_token:
return (
"⚠️ 未找到 Bay API Key。请填写访问令牌,"
"或确保 Bay 的 credentials.json 可被自动发现。"
)
# Connectivity check
import aiohttp
health_url = f"{endpoint}/health"
try:
async with aiohttp.ClientSession() as session:
async with session.get(
health_url,
timeout=aiohttp.ClientTimeout(total=5),
) as resp:
if resp.status != 200:
return (
f"⚠️ Bay 健康检查失败 (HTTP {resp.status})"
f"请确认 Bay 正在运行: {endpoint}"
)
except Exception:
return f"⚠️ 无法连接 Bay ({endpoint}),请确认 Bay 已启动。"
return None
def save_config(
post_config: dict, config: AstrBotConfig, is_core: bool = False
) -> None:
"""验证并保存配置"""
errors = None
logger.info(f"Saving config, is_core={is_core}")
# Snapshot old Computer config for change detection
if is_core:
_log_computer_config_changes(dict(config), post_config)
try:
if is_core:
errors, post_config = validate_config(
@@ -1026,11 +928,6 @@ class ConfigRoute(Route):
await self._save_astrbot_configs(config, conf_id)
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)
# Non-blocking Bay connectivity check
warning = await _validate_neo_connectivity(config)
if warning:
return Response().ok(None, f"保存成功。{warning}").__dict__
return Response().ok(None, "保存成功~").__dict__
except Exception as e:
logger.error(traceback.format_exc())
+2 -407
View File
@@ -1,48 +1,15 @@
import os
import re
import shutil
import traceback
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import Any
from quart import request, send_file
from quart import request
from astrbot.core import DEMO_MODE, logger
from astrbot.core.computer.computer_client import (
_discover_bay_credentials,
sync_skills_to_active_sandboxes,
)
from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager
from astrbot.core.skills.skill_manager import SkillManager
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from .route import Response, Route, RouteContext
def _to_jsonable(value: Any) -> Any:
if isinstance(value, dict):
return {k: _to_jsonable(v) for k, v in value.items()}
if isinstance(value, list):
return [_to_jsonable(v) for v in value]
if hasattr(value, "model_dump"):
return _to_jsonable(value.model_dump())
return value
def _to_bool(value: Any, default: bool = False) -> bool:
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.strip().lower() in {"1", "true", "yes", "y", "on"}
return bool(value)
_SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$")
class SkillsRoute(Route):
def __init__(self, context: RouteContext, core_lifecycle) -> None:
super().__init__(context)
@@ -50,81 +17,18 @@ class SkillsRoute(Route):
self.routes = {
"/skills": ("GET", self.get_skills),
"/skills/upload": ("POST", self.upload_skill),
"/skills/download": ("GET", self.download_skill),
"/skills/update": ("POST", self.update_skill),
"/skills/delete": ("POST", self.delete_skill),
"/skills/neo/candidates": ("GET", self.get_neo_candidates),
"/skills/neo/releases": ("GET", self.get_neo_releases),
"/skills/neo/payload": ("GET", self.get_neo_payload),
"/skills/neo/evaluate": ("POST", self.evaluate_neo_candidate),
"/skills/neo/promote": ("POST", self.promote_neo_candidate),
"/skills/neo/rollback": ("POST", self.rollback_neo_release),
"/skills/neo/sync": ("POST", self.sync_neo_release),
"/skills/neo/delete-candidate": ("POST", self.delete_neo_candidate),
"/skills/neo/delete-release": ("POST", self.delete_neo_release),
}
self.register_routes()
def _get_neo_client_config(self) -> tuple[str, str]:
provider_settings = self.core_lifecycle.astrbot_config.get(
"provider_settings",
{},
)
sandbox = provider_settings.get("sandbox", {})
endpoint = sandbox.get("shipyard_neo_endpoint", "")
access_token = sandbox.get("shipyard_neo_access_token", "")
# Auto-discover token from Bay's credentials.json if not configured
if not access_token and endpoint:
access_token = _discover_bay_credentials(endpoint)
if not endpoint or not access_token:
raise ValueError(
"Shipyard Neo endpoint or access token not configured. "
"Set them in Dashboard or ensure Bay's credentials.json is accessible."
)
return endpoint, access_token
async def _delete_neo_release(
self, client: Any, release_id: str, reason: str | None
):
return await client.skills.delete_release(release_id, reason=reason)
async def _delete_neo_candidate(
self, client: Any, candidate_id: str, reason: str | None
):
return await client.skills.delete_candidate(candidate_id, reason=reason)
async def _with_neo_client(
self,
operation: Callable[[Any], Awaitable[dict]],
) -> dict:
try:
endpoint, access_token = self._get_neo_client_config()
from shipyard_neo import BayClient
async with BayClient(
endpoint_url=endpoint,
access_token=access_token,
) as client:
return await operation(client)
except ValueError as e:
# Config not ready — expected when Neo isn't set up yet
logger.debug("[Neo] %s", e)
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def get_skills(self):
try:
provider_settings = self.core_lifecycle.astrbot_config.get(
"provider_settings", {}
)
runtime = provider_settings.get("computer_use_runtime", "local")
skill_mgr = SkillManager()
skills = skill_mgr.list_skills(
skills = SkillManager().list_skills(
active_only=False, runtime=runtime, show_sandbox_path=False
)
return (
@@ -132,8 +36,6 @@ class SkillsRoute(Route):
.ok(
{
"skills": [skill.__dict__ for skill in skills],
"runtime": runtime,
"sandbox_cache": skill_mgr.get_sandbox_skills_cache_status(),
}
)
.__dict__
@@ -168,11 +70,6 @@ class SkillsRoute(Route):
skill_mgr = SkillManager()
skill_name = skill_mgr.install_skill_from_zip(temp_path, overwrite=True)
try:
await sync_skills_to_active_sandboxes()
except Exception:
logger.warning("Failed to sync uploaded skills to active sandboxes.")
return (
Response()
.ok({"name": skill_name}, "Skill uploaded successfully.")
@@ -188,53 +85,6 @@ class SkillsRoute(Route):
except Exception:
logger.warning(f"Failed to remove temp skill file: {temp_path}")
async def download_skill(self):
try:
name = str(request.args.get("name") or "").strip()
if not name:
return Response().error("Missing skill name").__dict__
if not _SKILL_NAME_RE.match(name):
return Response().error("Invalid skill name").__dict__
skill_mgr = SkillManager()
if skill_mgr.is_sandbox_only_skill(name):
return (
Response()
.error(
"Sandbox preset skill cannot be downloaded from local skill files."
)
.__dict__
)
skill_dir = Path(skill_mgr.skills_root) / name
skill_md = skill_dir / "SKILL.md"
if not skill_dir.is_dir() or not skill_md.exists():
return Response().error("Local skill not found").__dict__
export_dir = Path(get_astrbot_temp_path()) / "skill_exports"
export_dir.mkdir(parents=True, exist_ok=True)
zip_base = export_dir / name
zip_path = zip_base.with_suffix(".zip")
if zip_path.exists():
zip_path.unlink()
shutil.make_archive(
str(zip_base),
"zip",
root_dir=str(skill_mgr.skills_root),
base_dir=name,
)
return await send_file(
str(zip_path),
as_attachment=True,
attachment_filename=f"{name}.zip",
conditional=True,
)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def update_skill(self):
if DEMO_MODE:
return (
@@ -267,262 +117,7 @@ class SkillsRoute(Route):
if not name:
return Response().error("Missing skill name").__dict__
SkillManager().delete_skill(name)
try:
await sync_skills_to_active_sandboxes()
except Exception:
logger.warning("Failed to sync deleted skills to active sandboxes.")
return Response().ok({"name": name}).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def get_neo_candidates(self):
logger.info("[Neo] GET /skills/neo/candidates requested.")
status = request.args.get("status")
skill_key = request.args.get("skill_key")
limit = int(request.args.get("limit", 100))
offset = int(request.args.get("offset", 0))
async def _do(client):
candidates = await client.skills.list_candidates(
status=status,
skill_key=skill_key,
limit=limit,
offset=offset,
)
result = _to_jsonable(candidates)
total = result.get("total", "?") if isinstance(result, dict) else "?"
logger.info(f"[Neo] Candidates fetched: total={total}")
return Response().ok(result).__dict__
return await self._with_neo_client(_do)
async def get_neo_releases(self):
logger.info("[Neo] GET /skills/neo/releases requested.")
skill_key = request.args.get("skill_key")
stage = request.args.get("stage")
active_only = _to_bool(request.args.get("active_only"), False)
limit = int(request.args.get("limit", 100))
offset = int(request.args.get("offset", 0))
async def _do(client):
releases = await client.skills.list_releases(
skill_key=skill_key,
active_only=active_only,
stage=stage,
limit=limit,
offset=offset,
)
result = _to_jsonable(releases)
total = result.get("total", "?") if isinstance(result, dict) else "?"
logger.info(f"[Neo] Releases fetched: total={total}")
return Response().ok(result).__dict__
return await self._with_neo_client(_do)
async def get_neo_payload(self):
logger.info("[Neo] GET /skills/neo/payload requested.")
payload_ref = request.args.get("payload_ref", "")
if not payload_ref:
return Response().error("Missing payload_ref").__dict__
async def _do(client):
payload = await client.skills.get_payload(payload_ref)
logger.info(f"[Neo] Payload fetched: ref={payload_ref}")
return Response().ok(_to_jsonable(payload)).__dict__
return await self._with_neo_client(_do)
async def evaluate_neo_candidate(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
logger.info("[Neo] POST /skills/neo/evaluate requested.")
data = await request.get_json()
candidate_id = data.get("candidate_id")
passed_value = data.get("passed")
if not candidate_id or passed_value is None:
return Response().error("Missing candidate_id or passed").__dict__
passed = _to_bool(passed_value, False)
async def _do(client):
result = await client.skills.evaluate_candidate(
candidate_id,
passed=passed,
score=data.get("score"),
benchmark_id=data.get("benchmark_id"),
report=data.get("report"),
)
logger.info(
f"[Neo] Candidate evaluated: id={candidate_id}, passed={passed}"
)
return Response().ok(_to_jsonable(result)).__dict__
return await self._with_neo_client(_do)
async def promote_neo_candidate(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
logger.info("[Neo] POST /skills/neo/promote requested.")
data = await request.get_json()
candidate_id = data.get("candidate_id")
stage = data.get("stage", "canary")
sync_to_local = _to_bool(data.get("sync_to_local"), True)
if not candidate_id:
return Response().error("Missing candidate_id").__dict__
if stage not in {"canary", "stable"}:
return Response().error("Invalid stage, must be canary/stable").__dict__
async def _do(client):
sync_mgr = NeoSkillSyncManager()
result = await sync_mgr.promote_with_optional_sync(
client,
candidate_id=candidate_id,
stage=stage,
sync_to_local=sync_to_local,
)
release_json = result.get("release")
logger.info(f"[Neo] Candidate promoted: id={candidate_id}, stage={stage}")
sync_json = result.get("sync")
did_sync_to_local = bool(sync_json)
if did_sync_to_local:
logger.info(
f"[Neo] Stable release synced to local: skill={sync_json.get('local_skill_name', '')}"
)
if result.get("sync_error"):
resp = Response().error(
"Stable promote synced failed and has been rolled back. "
f"sync_error={result['sync_error']}"
)
resp.data = {
"release": release_json,
"rollback": result.get("rollback"),
}
return resp.__dict__
# Try to push latest local skills to all active sandboxes.
if not did_sync_to_local:
try:
await sync_skills_to_active_sandboxes()
except Exception:
logger.warning("Failed to sync skills to active sandboxes.")
return Response().ok({"release": release_json, "sync": sync_json}).__dict__
return await self._with_neo_client(_do)
async def rollback_neo_release(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
logger.info("[Neo] POST /skills/neo/rollback requested.")
data = await request.get_json()
release_id = data.get("release_id")
if not release_id:
return Response().error("Missing release_id").__dict__
async def _do(client):
result = await client.skills.rollback_release(release_id)
logger.info(f"[Neo] Release rolled back: id={release_id}")
return Response().ok(_to_jsonable(result)).__dict__
return await self._with_neo_client(_do)
async def sync_neo_release(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
logger.info("[Neo] POST /skills/neo/sync requested.")
data = await request.get_json()
release_id = data.get("release_id")
skill_key = data.get("skill_key")
require_stable = _to_bool(data.get("require_stable"), True)
if not release_id and not skill_key:
return Response().error("Missing release_id or skill_key").__dict__
async def _do(client):
sync_mgr = NeoSkillSyncManager()
result = await sync_mgr.sync_release(
client,
release_id=release_id,
skill_key=skill_key,
require_stable=require_stable,
)
logger.info(
f"[Neo] Release synced to local: skill={result.local_skill_name}, "
f"release_id={result.release_id}"
)
return (
Response()
.ok(
{
"skill_key": result.skill_key,
"local_skill_name": result.local_skill_name,
"release_id": result.release_id,
"candidate_id": result.candidate_id,
"payload_ref": result.payload_ref,
"map_path": result.map_path,
"synced_at": result.synced_at,
}
)
.__dict__
)
return await self._with_neo_client(_do)
async def delete_neo_candidate(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
logger.info("[Neo] POST /skills/neo/delete-candidate requested.")
data = await request.get_json()
candidate_id = data.get("candidate_id")
reason = data.get("reason")
if not candidate_id:
return Response().error("Missing candidate_id").__dict__
async def _do(client):
result = await self._delete_neo_candidate(client, candidate_id, reason)
logger.info(f"[Neo] Candidate deleted: id={candidate_id}")
return Response().ok(_to_jsonable(result)).__dict__
return await self._with_neo_client(_do)
async def delete_neo_release(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
logger.info("[Neo] POST /skills/neo/delete-release requested.")
data = await request.get_json()
release_id = data.get("release_id")
reason = data.get("reason")
if not release_id:
return Response().error("Missing release_id").__dict__
async def _do(client):
result = await self._delete_neo_release(client, release_id, reason)
logger.info(f"[Neo] Release deleted: id={release_id}")
return Response().ok(_to_jsonable(result)).__dict__
return await self._with_neo_client(_do)
+7 -5
View File
@@ -51,9 +51,11 @@ class ToolsRoute(Route):
server_info[key] = value
# 如果MCP客户端已初始化,从客户端获取工具名称
for name_key, runtime in self.tool_mgr.mcp_server_runtime_view.items():
for (
name_key,
mcp_client,
) in self.tool_mgr.mcp_client_dict.items():
if name_key == name:
mcp_client = runtime.client
server_info["tools"] = [tool.name for tool in mcp_client.tools]
server_info["errlogs"] = mcp_client.server_errlogs
break
@@ -190,7 +192,7 @@ class ToolsRoute(Route):
# 处理MCP客户端状态变化
if active:
if (
old_name in self.tool_mgr.mcp_server_runtime_view
old_name in self.tool_mgr.mcp_client_dict
or not only_update_active
or is_rename
):
@@ -231,7 +233,7 @@ class ToolsRoute(Route):
.__dict__
)
# 如果要停用服务器
elif old_name in self.tool_mgr.mcp_server_runtime_view:
elif old_name in self.tool_mgr.mcp_client_dict:
try:
await self.tool_mgr.disable_mcp_server(old_name, timeout=10)
except TimeoutError:
@@ -270,7 +272,7 @@ class ToolsRoute(Route):
del config["mcpServers"][name]
if self.tool_mgr.save_mcp_config(config):
if name in self.tool_mgr.mcp_server_runtime_view:
if name in self.tool_mgr.mcp_client_dict:
try:
await self.tool_mgr.disable_mcp_server(name, timeout=10)
except TimeoutError:
+4 -16
View File
@@ -33,9 +33,6 @@ from .routes.session_management import SessionManagementRoute
from .routes.subagent import SubAgentRoute
from .routes.t2i import T2iRoute
# Static assets shipped inside the wheel (built during `hatch build`).
_BUNDLED_DIST = Path(__file__).parent / "dist"
class _AddrWithPort(Protocol):
port: int
@@ -69,22 +66,13 @@ class AstrBotDashboard:
self.config = core_lifecycle.astrbot_config
self.db = db
# Path priority:
# 1. Explicit webui_dir argument
# 2. data/dist/ (user-installed / manually updated dashboard)
# 3. astrbot/dashboard/dist/ (bundled with the wheel)
# 参数指定webui目录
if webui_dir and os.path.exists(webui_dir):
self.data_path = os.path.abspath(webui_dir)
else:
user_dist = os.path.join(get_astrbot_data_path(), "dist")
if os.path.exists(user_dist):
self.data_path = os.path.abspath(user_dist)
elif _BUNDLED_DIST.exists():
self.data_path = str(_BUNDLED_DIST)
logger.info("Using bundled dashboard dist: %s", self.data_path)
else:
# Fall back to expected user path (will fail gracefully later)
self.data_path = os.path.abspath(user_dist)
self.data_path = os.path.abspath(
os.path.join(get_astrbot_data_path(), "dist"),
)
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
APP = self.app # noqa
+1 -1
View File
@@ -3,7 +3,7 @@
<head>
<meta charset="UTF-8" />
<link rel="icon" href="/favicon.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<meta name="keywords" content="AstrBot Soulter" />
<meta name="description" content="AstrBot Dashboard" />
<meta name="robots" content="noindex, nofollow" />
+1
View File
@@ -65,6 +65,7 @@
"sass-loader": "13.3.2",
"typescript": "5.1.6",
"vite": "4.4.9",
"vite-plugin-monaco-editor": "1.1.0",
"vue-cli-plugin-vuetify": "2.5.8",
"vue-tsc": "1.8.8",
"vuetify-loader": "^2.0.0-alpha.9"
+12
View File
@@ -159,6 +159,9 @@ importers:
vite:
specifier: 4.4.9
version: 4.4.9(@types/node@20.19.32)(sass@1.66.1)(terser@5.46.0)
vite-plugin-monaco-editor:
specifier: 1.1.0
version: 1.1.0(monaco-editor@0.52.2)
vue-cli-plugin-vuetify:
specifier: 2.5.8
version: 2.5.8(sass-loader@13.3.2(sass@1.66.1)(webpack@5.105.0))(vue@3.3.4)(vuetify-loader@2.0.0-alpha.9(@vue/compiler-sfc@3.3.4)(vue@3.3.4)(vuetify@3.7.11)(webpack@5.105.0))(webpack@5.105.0)
@@ -2568,6 +2571,11 @@ packages:
vfile@6.0.3:
resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==}
vite-plugin-monaco-editor@1.1.0:
resolution: {integrity: sha512-IvtUqZotrRoVqwT0PBBDIZPNraya3BxN/bfcNfnxZ5rkJiGcNtO5eAOWWSgT7zullIAEqQwxMU83yL9J5k7gww==}
peerDependencies:
monaco-editor: '>=0.33.0'
vite-plugin-vuetify@1.0.2:
resolution: {integrity: sha512-MubIcKD33O8wtgQXlbEXE7ccTEpHZ8nPpe77y9Wy3my2MWw/PgehP9VqTp92BLqr0R1dSL970Lynvisx3UxBFw==}
engines: {node: '>=12'}
@@ -5297,6 +5305,10 @@ snapshots:
'@types/unist': 3.0.3
vfile-message: 4.0.3
vite-plugin-monaco-editor@1.1.0(monaco-editor@0.52.2):
dependencies:
monaco-editor: 0.52.2
vite-plugin-vuetify@1.0.2(vite@4.4.9(@types/node@20.19.32)(sass@1.66.1)(terser@5.46.0))(vue@3.3.4)(vuetify@3.7.11):
dependencies:
'@vuetify/loader-shared': 1.7.1(vue@3.3.4)(vuetify@3.7.11)
+8 -12
View File
@@ -37,7 +37,14 @@
<!-- 正常聊天界面 -->
<template v-else>
<div class="conversation-header fade-in" v-if="isMobile">
<!-- 手机端菜单按钮 -->
<v-btn icon class="mobile-menu-btn" @click="toggleMobileSidebar" variant="text">
<v-icon>mdi-menu</v-icon>
</v-btn>
</div>
<!-- 面包屑导航 -->
<div v-if="currentSessionProject && messages && messages.length > 0" class="breadcrumb-container">
<div class="breadcrumb-content">
<span class="breadcrumb-emoji">{{ currentSessionProject.emoji || '📁' }}</span>
@@ -234,7 +241,6 @@ const route = useRoute();
const { t } = useI18n();
const { tm } = useModuleI18n('features/chat');
const theme = useTheme();
const customizer = useCustomizerStore();
// UI
const isMobile = ref(false);
@@ -336,28 +342,19 @@ function checkMobile() {
isMobile.value = window.innerWidth <= 768;
if (!isMobile.value) {
mobileMenuOpen.value = false;
customizer.SET_CHAT_SIDEBAR(false);
}
}
function toggleMobileSidebar() {
mobileMenuOpen.value = !mobileMenuOpen.value;
customizer.SET_CHAT_SIDEBAR(mobileMenuOpen.value);
}
function closeMobileSidebar() {
mobileMenuOpen.value = false;
customizer.SET_CHAT_SIDEBAR(false);
}
// nav header sidebar toggle
watch(() => customizer.chatSidebarOpen, (val) => {
if (isMobile.value) {
mobileMenuOpen.value = val;
}
});
function toggleTheme() {
const customizer = useCustomizerStore();
const newTheme = customizer.uiTheme === 'PurpleTheme' ? 'PurpleThemeDark' : 'PurpleTheme';
customizer.SET_UI_THEME(newTheme);
theme.global.name.value = newTheme;
@@ -725,7 +722,6 @@ onBeforeUnmount(() => {
height: 100%;
max-height: 100%;
overflow: hidden;
overscroll-behavior: none;
}
.chat-page-container {
+30 -39
View File
@@ -32,12 +32,11 @@
</div>
</transition>
<textarea ref="inputField" v-model="localPrompt" @keydown="handleKeyDown" :disabled="disabled"
placeholder="Ask AstrBot..." class="chat-textarea"
autocomplete="off" autocorrect="off" autocapitalize="sentences" spellcheck="false"
style="width: 100%; resize: none; outline: none; border: 1px solid var(--v-theme-border); border-radius: 12px; padding: 16px 20px; min-height: 40px; max-height: 200px; overflow-y: auto; font-family: inherit; font-size: 16px; background-color: var(--v-theme-surface);"></textarea>
placeholder="Ask AstrBot..."
style="width: 100%; resize: none; outline: none; border: 1px solid var(--v-theme-border); border-radius: 12px; padding: 12px 16px; min-height: 40px; font-family: inherit; font-size: 16px; background-color: var(--v-theme-surface);"></textarea>
<div style="display: flex; justify-content: space-between; align-items: center; padding: 6px 14px;">
<div
style="display: flex; justify-content: flex-start; margin-top: 4px; align-items: center; gap: 8px; min-width: 0; flex: 1; overflow: hidden;">
style="display: flex; justify-content: flex-start; margin-top: 4px; align-items: center; gap: 8px;">
<!-- Settings Menu -->
<StyledMenu offset="8" location="top start" :close-on-content-click="false">
<template v-slot:activator="{ props: activatorProps }">
@@ -73,9 +72,9 @@
<!-- Provider/Model Selector Menu -->
<ProviderModelMenu v-if="showProviderSelector" ref="providerModelMenuRef" />
</div>
<div style="display: flex; justify-content: flex-end; margin-top: 8px; align-items: center; flex-shrink: 0;">
<div style="display: flex; justify-content: flex-end; margin-top: 8px; align-items: center;">
<input type="file" ref="imageInputRef" @change="handleFileSelect" style="display: none" multiple />
<v-progress-circular v-if="disabled && !mobile" indeterminate size="16" class="mr-1" width="1.5" />
<v-progress-circular v-if="disabled" indeterminate size="16" class="mr-1" width="1.5" />
<!-- <v-btn @click="$emit('openLiveMode')"
icon
variant="text"
@@ -88,21 +87,36 @@
</v-tooltip>
</v-btn> -->
<v-btn @click="handleRecordClick" icon variant="text" :color="isRecording ? 'error' : 'deep-purple'"
class="record-btn">
class="record-btn" size="small">
<v-icon :icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'" variant="text"
plain></v-icon>
<v-tooltip activator="parent" location="top">
{{ isRecording ? tm('voice.speaking') : tm('voice.startRecording') }}
</v-tooltip>
</v-btn>
<v-btn icon v-if="isRunning" @click="$emit('stop')" variant="tonal" color="deep-purple" class="send-btn">
<v-btn
icon
v-if="isRunning"
@click="$emit('stop')"
variant="text"
class="send-btn"
size="small"
>
<v-icon icon="mdi-stop" variant="text" plain></v-icon>
<v-tooltip activator="parent" location="top">
{{ tm('input.stopGenerating') }}
</v-tooltip>
</v-btn>
<v-btn v-else @click="$emit('send')" icon="mdi-send" variant="tonal" color="deep-purple"
:disabled="!canSend" class="send-btn" />
<v-btn
v-else
@click="$emit('send')"
icon="mdi-send"
variant="text"
color="deep-purple"
:disabled="!canSend"
class="send-btn"
size="small"
/>
</div>
</div>
</div>
@@ -138,8 +152,7 @@
</template>
<script setup lang="ts">
import { ref, computed, watch, nextTick, onMounted, onBeforeUnmount } from 'vue';
import { useDisplay } from 'vuetify';
import { ref, computed, onMounted, onBeforeUnmount } from 'vue';
import { useModuleI18n } from '@/i18n/composables';
import { useCustomizerStore } from '@/stores/customizer';
import ConfigSelector from './ConfigSelector.vue';
@@ -238,34 +251,21 @@ function handleReplyAfterLeave() {
isReplyClosing.value = false;
}
const { mobile } = useDisplay();
// Auto-resize textarea
function autoResize() {
const el = inputField.value;
if (!el) return;
el.style.height = 'auto';
el.style.height = Math.min(el.scrollHeight, 200) + 'px';
}
watch(localPrompt, () => {
nextTick(autoResize);
});
function handleKeyDown(e: KeyboardEvent) {
// Enter
// Shift+Enter Ctrl+Enter / Cmd+Enter
if (e.keyCode === 13 && (e.shiftKey || e.ctrlKey || e.metaKey)) {
// Enter
if (e.keyCode === 13 && !e.shiftKey) {
e.preventDefault();
// /astr_live_dev
if (localPrompt.value.trim() === '/astr_live_dev') {
emit('openLiveMode');
localPrompt.value = '';
return;
}
if (canSend.value) {
emit('send');
}
return;
}
// Ctrl+B
@@ -588,20 +588,11 @@ defineExpose({
@media (max-width: 768px) {
.input-area {
padding: 0 !important;
padding-bottom: 10px !important;
}
.input-container {
width: 100% !important;
max-width: 100% !important;
}
.input-area textarea,
.chat-textarea {
min-height: 32px !important;
max-height: 160px !important;
font-size: 16px !important;
padding: 16px 16px 12px 16px !important;
}
}
</style>
@@ -37,7 +37,7 @@
@deleteProject="$emit('deleteProject', $event)"
/>
<div style="overflow-y: auto; flex-grow: 1; overscroll-behavior-y: contain;"
<div style="overflow-y: auto; flex-grow: 1;"
v-if="!sidebarCollapsed || isMobile">
<v-card v-if="sessions.length > 0" flat style="background-color: transparent;">
<v-list density="compact" nav class="conversation-list"
@@ -326,13 +326,6 @@ function handleTransportModeChange(mode: string | null) {
transition: all 0.2s ease;
}
@media (max-width: 768px) {
.conversation-actions {
opacity: 1 !important;
visibility: visible !important;
}
}
.edit-title-btn,
.delete-conversation-btn {
opacity: 0.7;
@@ -965,7 +965,6 @@ export default {
height: 100%;
max-height: 100%;
overflow-y: auto;
overscroll-behavior-y: contain;
padding: 16px;
display: flex;
flex-direction: column;
@@ -1,295 +1,60 @@
<template>
<div class="skills-page">
<v-container fluid class="pa-0" elevation="0">
<v-row class="d-flex justify-space-between align-center px-4 py-3 pb-4">
<v-row class="d-flex justify-space-between align-center px-4 py-3 pb-8">
<div>
<v-btn
v-if="mode === 'local'"
color="success"
prepend-icon="mdi-upload"
class="me-2"
variant="tonal"
@click="uploadDialog = true"
>
{{ tm("skills.upload") }}
<v-btn color="success" prepend-icon="mdi-upload" class="me-2" variant="tonal" @click="uploadDialog = true">
{{ tm('skills.upload') }}
</v-btn>
<v-btn color="primary" prepend-icon="mdi-refresh" variant="tonal" @click="refreshCurrentMode">
{{ tm("skills.refresh") }}
<v-btn color="primary" prepend-icon="mdi-refresh" variant="tonal" @click="fetchSkills">
{{ tm('skills.refresh') }}
</v-btn>
</div>
<v-btn-toggle v-model="mode" mandatory divided density="comfortable">
<v-btn value="local">{{ tm("skills.modeLocal") }}</v-btn>
<v-btn value="neo" :disabled="!neoEnabled">{{ tm("skills.modeNeo") }}</v-btn>
</v-btn-toggle>
</v-row>
<div v-if="mode === 'local'" class="px-2 pb-2 d-flex flex-column ga-2">
<small style="color: grey;">{{ tm("skills.runtimeHint") }}</small>
<v-alert
v-if="runtime === 'sandbox' && !sandboxCache.ready"
type="info"
variant="tonal"
density="comfortable"
border="start"
>
{{ tm("skills.sandboxDiscoveryPending") }}
</v-alert>
<div class="px-2 pb-2">
<small style="color: grey;">{{ tm('skills.runtimeHint') }}</small>
</div>
<div v-if="mode === 'neo' && !neoEnabled" class="px-3 pb-3">
<v-alert type="warning" variant="tonal" density="comfortable" border="start">
{{ neoUnavailableMessage }}
</v-alert>
<v-progress-linear v-if="loading" indeterminate color="primary"></v-progress-linear>
<div v-else-if="skills.length === 0" class="text-center pa-8">
<v-icon size="64" color="grey-lighten-1">mdi-folder-open</v-icon>
<p class="text-grey mt-4">{{ tm('skills.empty') }}</p>
<small class="text-grey">{{ tm('skills.emptyHint') }}</small>
</div>
<template v-if="mode === 'local'">
<v-progress-linear v-if="loading" indeterminate color="primary"></v-progress-linear>
<div v-else-if="skills.length === 0" class="text-center pa-8">
<v-icon size="64" color="grey-lighten-1">mdi-folder-open</v-icon>
<p class="text-grey mt-4">{{ tm("skills.empty") }}</p>
<small class="text-grey">{{ tm("skills.emptyHint") }}</small>
</div>
<v-row v-else align="stretch">
<v-col
v-for="skill in skills"
:key="skill.name"
cols="12"
md="6"
lg="4"
xl="3"
class="d-flex"
>
<item-card
:item="skill"
title-field="name"
enabled-field="active"
:loading="itemLoading[skill.name] || false"
:show-edit-button="false"
:disable-toggle="isSandboxPresetSkill(skill)"
:disable-delete="isSandboxPresetSkill(skill)"
@toggle-enabled="toggleSkill"
@delete="confirmDelete"
>
<template #item-details="{ item }">
<div class="d-flex align-center mb-2 ga-2 flex-wrap">
<v-chip
size="x-small"
variant="tonal"
:color="sourceTypeColor(item.source_type)"
>
{{ sourceTypeLabel(item.source_type) }}
</v-chip>
<div class="text-caption text-medium-emphasis skill-description">
<v-icon size="small" class="me-1">mdi-text</v-icon>
{{ item.description || tm("skills.noDescription") }}
</div>
</div>
<div class="text-caption text-medium-emphasis skill-path">
<v-icon size="small" class="me-1">mdi-file-document</v-icon>
{{ tm("skills.path") }}: {{ item.path }}
</div>
</template>
<template #actions="{ item }">
<v-btn
variant="tonal"
color="primary"
size="small"
rounded="xl"
:disabled="itemLoading[item.name] || false || isSandboxPresetSkill(item)"
@click="downloadSkill(item)"
>
{{ tm("skills.download") }}
</v-btn>
</template>
</item-card>
</v-col>
</v-row>
</template>
<template v-else-if="mode === 'neo' && neoEnabled">
<v-card class="mx-3 mb-4 pa-4 neo-filter-card" variant="outlined">
<div class="d-flex flex-wrap justify-space-between align-center ga-2 mb-3">
<div>
<div class="text-subtitle-1 font-weight-bold">Neo Skills</div>
<div class="text-caption text-medium-emphasis">{{ tm("skills.neoFilterHint") }}</div>
</div>
<v-btn color="primary" prepend-icon="mdi-refresh" variant="flat" @click="fetchNeoData">
{{ tm("skills.refresh") }}
</v-btn>
</div>
<v-row class="ga-md-0 ga-2">
<v-col cols="12" md="4">
<v-text-field
v-model="neoFilters.skill_key"
:label="tm('skills.neoSkillKey')"
prepend-inner-icon="mdi-key-outline"
density="comfortable"
hide-details
variant="outlined"
/>
</v-col>
<v-col cols="12" md="4">
<v-select
v-model="neoFilters.status"
:label="tm('skills.neoStatus')"
:items="candidateStatusItems"
item-title="title"
item-value="value"
prepend-inner-icon="mdi-progress-check"
density="comfortable"
hide-details
variant="outlined"
/>
</v-col>
<v-col cols="12" md="4">
<v-select
v-model="neoFilters.stage"
:label="tm('skills.neoStage')"
:items="releaseStageItems"
item-title="title"
item-value="value"
prepend-inner-icon="mdi-layers-outline"
density="comfortable"
hide-details
variant="outlined"
/>
</v-col>
</v-row>
</v-card>
<v-progress-linear v-if="neoLoading" indeterminate color="primary"></v-progress-linear>
<div class="mx-3 mb-3 d-flex flex-wrap ga-2">
<v-chip size="small" color="primary" variant="tonal">Candidates: {{ neoCandidates.length }}</v-chip>
<v-chip size="small" color="indigo" variant="tonal">Releases: {{ neoReleases.length }}</v-chip>
<v-chip size="small" color="success" variant="tonal">Active: {{ activeReleaseCount }}</v-chip>
</div>
<v-card class="mx-3 mb-4 neo-table-card" variant="outlined">
<v-card-title class="text-subtitle-1 font-weight-bold">{{ tm("skills.neoCandidates") }}</v-card-title>
<v-data-table
:headers="candidateHeaders"
:items="neoCandidates"
density="compact"
:items-per-page="10"
class="neo-data-table"
>
<template #item.latest_score="{ item }">
{{ item.latest_score ?? "-" }}
</template>
<template #item.actions="{ item }">
<div class="d-flex ga-1 flex-wrap">
<v-btn size="x-small" color="success" variant="tonal" @click="evaluateCandidate(item, true)">
{{ tm("skills.neoPass") }}
</v-btn>
<v-btn size="x-small" color="warning" variant="tonal" @click="evaluateCandidate(item, false)">
{{ tm("skills.neoReject") }}
</v-btn>
<v-btn
size="x-small"
color="primary"
variant="tonal"
:loading="isCandidatePromoteLoading(item.id, 'canary')"
:disabled="isCandidatePromoting(item.id)"
@click="promoteCandidate(item, 'canary')"
>
Canary
</v-btn>
<v-btn
size="x-small"
color="primary"
variant="tonal"
:loading="isCandidatePromoteLoading(item.id, 'stable')"
:disabled="isCandidatePromoting(item.id)"
@click="promoteCandidate(item, 'stable')"
>
Stable
</v-btn>
<v-btn
size="x-small"
variant="tonal"
@click="viewPayload(item.payload_ref)"
:disabled="!item.payload_ref"
>
Payload
</v-btn>
<v-btn
size="x-small"
color="error"
variant="tonal"
@click="deleteCandidate(item)"
>
{{ tm("skills.neoDelete") }}
</v-btn>
<v-row v-else>
<v-col v-for="skill in skills" :key="skill.name" cols="12" md="6" lg="4" xl="3">
<item-card :item="skill" title-field="name" enabled-field="active" :loading="itemLoading[skill.name] || false"
:show-edit-button="false" @toggle-enabled="toggleSkill" @delete="confirmDelete">
<template v-slot:item-details="{ item }">
<div class="text-caption text-medium-emphasis mb-2 skill-description">
<v-icon size="small" class="me-1">mdi-text</v-icon>
{{ item.description || tm('skills.noDescription') }}
</div>
<div class="text-caption text-medium-emphasis">
<v-icon size="small" class="me-1">mdi-file-document</v-icon>
{{ tm('skills.path') }}: {{ item.path }}
</div>
</template>
</v-data-table>
</v-card>
<v-card class="mx-3 mb-4 neo-table-card" variant="outlined">
<v-card-title class="text-subtitle-1 font-weight-bold">{{ tm("skills.neoReleases") }}</v-card-title>
<v-data-table
:headers="releaseHeaders"
:items="neoReleases"
density="compact"
:items-per-page="10"
class="neo-data-table"
>
<template #item.is_active="{ item }">
<v-chip size="small" :color="item.is_active ? 'success' : 'default'" variant="tonal">
{{ item.is_active ? "active" : "inactive" }}
</v-chip>
</template>
<template #item.actions="{ item }">
<div class="d-flex ga-1 flex-wrap">
<v-btn
size="x-small"
color="warning"
variant="tonal"
@click="handleReleaseLifecycleAction(item)"
>
{{ item.is_active ? tm("skills.neoDeactivate") : tm("skills.neoRollback") }}
</v-btn>
<v-btn size="x-small" color="primary" variant="tonal" @click="syncRelease(item)">
{{ tm("skills.neoSync") }}
</v-btn>
<v-btn
size="x-small"
color="error"
variant="tonal"
@click="deleteRelease(item)"
>
{{ tm("skills.neoDelete") }}
</v-btn>
</div>
</template>
</v-data-table>
</v-card>
</template>
</item-card>
</v-col>
</v-row>
</v-container>
<v-dialog v-model="uploadDialog" max-width="520px">
<v-card>
<v-card-title class="text-h3 pa-4 pb-0 pl-6">{{ tm("skills.uploadDialogTitle") }}</v-card-title>
<v-card-title class="text-h3 pa-4 pb-0 pl-6">{{ tm('skills.uploadDialogTitle') }}</v-card-title>
<v-card-text>
<small class="text-grey">{{ tm("skills.uploadHint") }}</small>
<v-file-input
v-model="uploadFile"
accept=".zip"
:label="tm('skills.selectFile')"
prepend-icon="mdi-folder-zip-outline"
variant="outlined"
class="mt-4"
:multiple="false"
/>
<small class="text-grey">{{ tm('skills.uploadHint') }}</small>
<v-file-input v-model="uploadFile" accept=".zip" :label="tm('skills.selectFile')"
prepend-icon="mdi-folder-zip-outline" variant="outlined" class="mt-4" :multiple="false" />
</v-card-text>
<v-card-actions class="d-flex justify-end">
<v-btn variant="text" @click="uploadDialog = false">{{ tm("skills.cancel") }}</v-btn>
<v-btn variant="text" @click="uploadDialog = false">{{ tm('skills.cancel') }}</v-btn>
<v-btn color="primary" :loading="uploading" :disabled="!uploadFile" @click="uploadSkill">
{{ tm("skills.confirmUpload") }}
{{ tm('skills.confirmUpload') }}
</v-btn>
</v-card-actions>
</v-card>
@@ -297,30 +62,18 @@
<v-dialog v-model="deleteDialog" max-width="400px">
<v-card>
<v-card-title>{{ tm("skills.deleteTitle") }}</v-card-title>
<v-card-text>{{ tm("skills.deleteMessage") }}</v-card-text>
<v-card-title>{{ tm('skills.deleteTitle') }}</v-card-title>
<v-card-text>{{ tm('skills.deleteMessage') }}</v-card-text>
<v-card-actions class="d-flex justify-end">
<v-btn variant="text" @click="deleteDialog = false">{{ tm("skills.cancel") }}</v-btn>
<v-btn variant="text" @click="deleteDialog = false">{{ tm('skills.cancel') }}</v-btn>
<v-btn color="error" :loading="deleting" @click="deleteSkill">
{{ t("core.common.itemCard.delete") }}
{{ t('core.common.itemCard.delete') }}
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<v-dialog v-model="payloadDialog.show" max-width="820px">
<v-card>
<v-card-title>{{ tm("skills.neoPayloadTitle") }}</v-card-title>
<v-card-text>
<pre class="payload-preview">{{ payloadDialog.content }}</pre>
</v-card-text>
<v-card-actions class="d-flex justify-end">
<v-btn variant="text" @click="payloadDialog.show = false">{{ tm("skills.cancel") }}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<v-snackbar v-model="snackbar.show" :timeout="3500" :color="snackbar.color" elevation="24">
<v-snackbar v-model="snackbar.show" :timeout="3000" :color="snackbar.color" elevation="24">
{{ snackbar.message }}
</v-snackbar>
</div>
@@ -328,7 +81,7 @@
<script>
import axios from "axios";
import { computed, onMounted, reactive, ref, watch } from "vue";
import { ref, reactive, onMounted } from "vue";
import ItemCard from "@/components/shared/ItemCard.vue";
import { useI18n, useModuleI18n } from "@/i18n/composables";
@@ -339,11 +92,8 @@ export default {
const { t } = useI18n();
const { tm } = useModuleI18n("features/extension");
const mode = ref("local");
const skills = ref([]);
const loading = ref(false);
const runtime = ref("local");
const sandboxCache = reactive({ ready: false, count: 0, updated_at: null });
const uploading = ref(false);
const uploadDialog = ref(false);
const uploadFile = ref(null);
@@ -353,109 +103,23 @@ export default {
const skillToDelete = ref(null);
const snackbar = reactive({ show: false, message: "", color: "success" });
const neoLoading = ref(false);
const neoCandidates = ref([]);
const neoReleases = ref([]);
const neoFilters = reactive({
skill_key: "",
status: "",
stage: "",
});
const candidatePromoteLoading = reactive({});
const payloadDialog = reactive({
show: false,
content: "",
});
const neoEnabled = ref(false);
const neoUnavailableMessage = ref("");
const candidateStatusItems = computed(() => [
{ title: tm("skills.neoAll"), value: "" },
{ title: "draft", value: "draft" },
{ title: "evaluating", value: "evaluating" },
{ title: "promoted", value: "promoted" },
{ title: "promoted_canary", value: "promoted_canary" },
{ title: "promoted_stable", value: "promoted_stable" },
{ title: "rejected", value: "rejected" },
{ title: "rolled_back", value: "rolled_back" },
]);
const releaseStageItems = computed(() => [
{ title: tm("skills.neoAll"), value: "" },
{ title: "canary", value: "canary" },
{ title: "stable", value: "stable" },
]);
const activeReleaseCount = computed(() => neoReleases.value.filter((item) => item?.is_active).length);
const candidateHeaders = computed(() => [
{ title: "ID", key: "id", width: "180px" },
{ title: "skill_key", key: "skill_key" },
{ title: "status", key: "status", width: "130px" },
{ title: "score", key: "latest_score", width: "90px" },
{ title: tm("skills.actions"), key: "actions", sortable: false, width: "420px" },
]);
const releaseHeaders = computed(() => [
{ title: "ID", key: "id", width: "180px" },
{ title: "skill_key", key: "skill_key" },
{ title: "stage", key: "stage", width: "100px" },
{ title: "version", key: "version", width: "90px" },
{ title: "active", key: "is_active", width: "110px" },
{ title: tm("skills.actions"), key: "actions", sortable: false, width: "220px" },
]);
const showMessage = (message, color = "success") => {
snackbar.message = message;
snackbar.color = color;
snackbar.show = true;
};
const normalizeSkillsPayload = (res) => {
const payload = res?.data?.data || [];
if (Array.isArray(payload)) {
runtime.value = "local";
sandboxCache.ready = false;
sandboxCache.count = 0;
sandboxCache.updated_at = null;
return payload;
}
runtime.value = payload.runtime || "local";
const cache = payload.sandbox_cache || {};
sandboxCache.ready = !!cache.ready;
sandboxCache.count = Number(cache.count || 0);
sandboxCache.updated_at = cache.updated_at || null;
return payload.skills || [];
};
const sourceTypeLabel = (sourceType) => {
if (sourceType === "sandbox_only") return tm("skills.sourceSandboxOnly");
if (sourceType === "both") return tm("skills.sourceBoth");
return tm("skills.sourceLocalOnly");
};
const sourceTypeColor = (sourceType) => {
if (sourceType === "sandbox_only") return "indigo";
if (sourceType === "both") return "success";
return "primary";
};
const isSandboxPresetSkill = (skill) => skill?.source_type === "sandbox_only";
const normalizeNeoItemsPayload = (res) => {
const payload = res?.data?.data || [];
if (Array.isArray(payload)) return payload;
if (Array.isArray(payload.items)) return payload.items;
return [];
};
const fetchSkills = async () => {
loading.value = true;
try {
const res = await axios.get("/api/skills");
skills.value = normalizeSkillsPayload(res);
} catch (_err) {
const payload = res.data?.data || [];
if (Array.isArray(payload)) {
skills.value = payload;
} else {
skills.value = payload.skills || [];
}
} catch (err) {
showMessage(tm("skills.loadFailed"), "error");
} finally {
loading.value = false;
@@ -477,7 +141,9 @@ export default {
uploading.value = true;
try {
const formData = new FormData();
const file = Array.isArray(uploadFile.value) ? uploadFile.value[0] : uploadFile.value;
const file = Array.isArray(uploadFile.value)
? uploadFile.value[0]
: uploadFile.value;
if (!file) {
uploading.value = false;
return;
@@ -486,12 +152,17 @@ export default {
const res = await axios.post("/api/skills/upload", formData, {
headers: { "Content-Type": "multipart/form-data" },
});
handleApiResponse(res, tm("skills.uploadSuccess"), tm("skills.uploadFailed"), async () => {
uploadDialog.value = false;
uploadFile.value = null;
await fetchSkills();
});
} catch (_err) {
handleApiResponse(
res,
tm("skills.uploadSuccess"),
tm("skills.uploadFailed"),
async () => {
uploadDialog.value = false;
uploadFile.value = null;
await fetchSkills();
}
);
} catch (err) {
showMessage(tm("skills.uploadFailed"), "error");
} finally {
uploading.value = false;
@@ -499,10 +170,6 @@ export default {
};
const toggleSkill = async (skill) => {
if (isSandboxPresetSkill(skill)) {
showMessage(tm("skills.sandboxPresetReadonly"), "warning");
return;
}
const nextActive = !skill.active;
itemLoading[skill.name] = true;
try {
@@ -510,10 +177,15 @@ export default {
name: skill.name,
active: nextActive,
});
handleApiResponse(res, tm("skills.updateSuccess"), tm("skills.updateFailed"), () => {
skill.active = nextActive;
});
} catch (_err) {
handleApiResponse(
res,
tm("skills.updateSuccess"),
tm("skills.updateFailed"),
() => {
skill.active = nextActive;
}
);
} catch (err) {
showMessage(tm("skills.updateFailed"), "error");
} finally {
itemLoading[skill.name] = false;
@@ -521,10 +193,6 @@ export default {
};
const confirmDelete = (skill) => {
if (isSandboxPresetSkill(skill)) {
showMessage(tm("skills.sandboxPresetReadonly"), "warning");
return;
}
skillToDelete.value = skill;
deleteDialog.value = true;
};
@@ -536,288 +204,29 @@ export default {
const res = await axios.post("/api/skills/delete", {
name: skillToDelete.value.name,
});
handleApiResponse(res, tm("skills.deleteSuccess"), tm("skills.deleteFailed"), async () => {
deleteDialog.value = false;
await fetchSkills();
});
} catch (_err) {
handleApiResponse(
res,
tm("skills.deleteSuccess"),
tm("skills.deleteFailed"),
async () => {
deleteDialog.value = false;
await fetchSkills();
}
);
} catch (err) {
showMessage(tm("skills.deleteFailed"), "error");
} finally {
deleting.value = false;
}
};
const downloadSkill = async (skill) => {
if (isSandboxPresetSkill(skill)) {
showMessage(tm("skills.sandboxPresetReadonly"), "warning");
return;
}
itemLoading[skill.name] = true;
try {
const res = await axios.get("/api/skills/download", {
params: { name: skill.name },
responseType: "blob",
});
const blob = new Blob([res.data], { type: "application/zip" });
const url = window.URL.createObjectURL(blob);
const link = document.createElement("a");
link.href = url;
link.download = `${skill.name}.zip`;
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
window.URL.revokeObjectURL(url);
showMessage(tm("skills.downloadSuccess"), "success");
} catch (_err) {
showMessage(tm("skills.downloadFailed"), "error");
} finally {
itemLoading[skill.name] = false;
}
};
const fetchNeoCandidates = async () => {
const params = {
skill_key: neoFilters.skill_key || undefined,
status: neoFilters.status || undefined,
};
const res = await axios.get("/api/skills/neo/candidates", { params });
neoCandidates.value = normalizeNeoItemsPayload(res);
};
const fetchNeoReleases = async () => {
const params = {
skill_key: neoFilters.skill_key || undefined,
stage: neoFilters.stage || undefined,
};
const res = await axios.get("/api/skills/neo/releases", { params });
neoReleases.value = normalizeNeoItemsPayload(res).map((item) => {
if (!item || typeof item !== "object") {
return item;
}
return {
...item,
is_active: item.is_active ?? item.active ?? false,
};
});
};
const loadNeoAvailability = async () => {
try {
const res = await axios.get("/api/config/get");
const config = res?.data?.data?.config || {};
const providerSettings = config?.provider_settings || {};
const runtime = providerSettings?.computer_use_runtime || "local";
const booter = providerSettings?.sandbox?.booter || "";
neoEnabled.value = runtime === "sandbox" && booter === "shipyard_neo";
} catch (_err) {
neoEnabled.value = false;
}
neoUnavailableMessage.value = tm("skills.neoRuntimeRequired");
if (!neoEnabled.value && mode.value === "neo") {
mode.value = "local";
}
};
const fetchNeoData = async () => {
neoLoading.value = true;
try {
await Promise.all([fetchNeoCandidates(), fetchNeoReleases()]);
} catch (_err) {
showMessage(tm("skills.neoLoadFailed"), "error");
} finally {
neoLoading.value = false;
}
};
const evaluateCandidate = async (candidate, passed) => {
try {
const res = await axios.post("/api/skills/neo/evaluate", {
candidate_id: candidate.id,
passed,
score: passed ? 1.0 : 0.0,
report: passed ? "approved_from_webui" : "rejected_from_webui",
});
handleApiResponse(res, tm("skills.neoEvaluateSuccess"), tm("skills.neoEvaluateFailed"), async () => {
await fetchNeoCandidates();
});
} catch (_err) {
showMessage(tm("skills.neoEvaluateFailed"), "error");
}
};
const candidatePromoteLoadingKey = (candidateId, stage) => `${candidateId}:${stage}`;
const isCandidatePromoteLoading = (candidateId, stage) =>
!!candidatePromoteLoading[candidatePromoteLoadingKey(candidateId, stage)];
const isCandidatePromoting = (candidateId) =>
isCandidatePromoteLoading(candidateId, "canary") || isCandidatePromoteLoading(candidateId, "stable");
const promoteCandidate = async (candidate, stage) => {
const candidateId = candidate?.id;
if (!candidateId) return;
const loadingKey = candidatePromoteLoadingKey(candidateId, stage);
if (candidatePromoteLoading[loadingKey]) return;
candidatePromoteLoading[loadingKey] = true;
try {
const res = await axios.post("/api/skills/neo/promote", {
candidate_id: candidateId,
stage,
sync_to_local: true,
});
const ok = res?.data?.status === "ok";
if (!ok) {
showMessage(res?.data?.message || tm("skills.neoPromoteFailed"), "error");
} else {
showMessage(tm("skills.neoPromoteSuccess"), "success");
}
await fetchNeoData();
if (stage === "stable") {
await fetchSkills();
}
} catch (_err) {
showMessage(tm("skills.neoPromoteFailed"), "error");
} finally {
candidatePromoteLoading[loadingKey] = false;
}
};
const rollbackRelease = async (release) => {
try {
const res = await axios.post("/api/skills/neo/rollback", {
release_id: release.id,
});
handleApiResponse(res, tm("skills.neoRollbackSuccess"), tm("skills.neoRollbackFailed"), async () => {
await fetchNeoData();
});
} catch (_err) {
showMessage(tm("skills.neoRollbackFailed"), "error");
}
};
const deactivateRelease = async (release) => {
try {
const res = await axios.post("/api/skills/neo/rollback", {
release_id: release.id,
});
handleApiResponse(
res,
tm("skills.neoDeactivateSuccess"),
tm("skills.neoDeactivateFailed"),
async () => {
await fetchNeoData();
},
);
} catch (_err) {
showMessage(tm("skills.neoDeactivateFailed"), "error");
}
};
const handleReleaseLifecycleAction = async (release) => {
if (release?.is_active) {
await deactivateRelease(release);
return;
}
await rollbackRelease(release);
};
const syncRelease = async (release) => {
try {
const res = await axios.post("/api/skills/neo/sync", {
release_id: release.id,
});
handleApiResponse(res, tm("skills.neoSyncSuccess"), tm("skills.neoSyncFailed"), async () => {
await fetchSkills();
});
} catch (_err) {
showMessage(tm("skills.neoSyncFailed"), "error");
}
};
const viewPayload = async (payloadRef) => {
if (!payloadRef) return;
try {
const res = await axios.get("/api/skills/neo/payload", {
params: { payload_ref: payloadRef },
});
if (res?.data?.status !== "ok") {
showMessage(res?.data?.message || tm("skills.neoPayloadFailed"), "error");
return;
}
const payload = res?.data?.data || {};
payloadDialog.content = JSON.stringify(payload, null, 2);
payloadDialog.show = true;
} catch (_err) {
showMessage(tm("skills.neoPayloadFailed"), "error");
}
};
const deleteCandidate = async (candidate) => {
try {
const res = await axios.post("/api/skills/neo/delete-candidate", {
candidate_id: candidate.id,
reason: "deleted_from_webui",
});
handleApiResponse(res, tm("skills.neoDeleteSuccess"), tm("skills.neoDeleteFailed"), async () => {
await fetchNeoData();
});
} catch (_err) {
showMessage(tm("skills.neoDeleteFailed"), "error");
}
};
const deleteRelease = async (release) => {
try {
const res = await axios.post("/api/skills/neo/delete-release", {
release_id: release.id,
reason: "deleted_from_webui",
});
handleApiResponse(res, tm("skills.neoDeleteSuccess"), tm("skills.neoDeleteFailed"), async () => {
await fetchNeoData();
});
} catch (_err) {
showMessage(tm("skills.neoDeleteFailed"), "error");
}
};
const refreshCurrentMode = async () => {
if (mode.value === "neo") {
await loadNeoAvailability();
if (neoEnabled.value) {
await fetchNeoData();
} else {
showMessage(tm("skills.neoRuntimeRequired"), "warning");
}
} else {
await fetchSkills();
}
};
watch(mode, async (nextMode) => {
if (nextMode === "neo") {
await loadNeoAvailability();
if (neoEnabled.value) {
await fetchNeoData();
}
} else {
await fetchSkills();
}
});
onMounted(async () => {
await Promise.all([fetchSkills(), loadNeoAvailability()]);
if (neoEnabled.value) {
await fetchNeoData();
}
});
onMounted(fetchSkills);
return {
t,
tm,
mode,
skills,
loading,
runtime,
sandboxCache,
uploadDialog,
uploadFile,
uploading,
@@ -825,39 +234,11 @@ export default {
deleteDialog,
deleting,
snackbar,
neoEnabled,
neoUnavailableMessage,
neoLoading,
neoCandidates,
neoReleases,
neoFilters,
candidateStatusItems,
releaseStageItems,
activeReleaseCount,
candidateHeaders,
releaseHeaders,
payloadDialog,
refreshCurrentMode,
fetchNeoData,
fetchSkills,
uploadSkill,
downloadSkill,
toggleSkill,
confirmDelete,
deleteSkill,
evaluateCandidate,
promoteCandidate,
isCandidatePromoteLoading,
isCandidatePromoting,
rollbackRelease,
deactivateRelease,
handleReleaseLifecycleAction,
syncRelease,
viewPayload,
deleteCandidate,
deleteRelease,
sourceTypeLabel,
sourceTypeColor,
isSandboxPresetSkill,
};
},
};
@@ -869,42 +250,5 @@ export default {
-webkit-line-clamp: 1;
-webkit-box-orient: vertical;
overflow: hidden;
min-height: 20px;
}
.skill-path {
display: -webkit-box;
-webkit-line-clamp: 2;
-webkit-box-orient: vertical;
overflow: hidden;
min-height: 40px;
word-break: break-all;
}
.payload-preview {
max-height: 480px;
overflow: auto;
background: #111;
color: #ececec;
padding: 12px;
border-radius: 8px;
font-size: 12px;
}
.neo-filter-card {
border-radius: 14px;
border-color: rgba(var(--v-theme-primary), 0.25);
background: linear-gradient(180deg, rgba(var(--v-theme-primary), 0.03), rgba(var(--v-theme-surface), 1));
}
.neo-table-card {
border-radius: 14px;
}
.neo-data-table :deep(.v-data-table-header__content) {
font-weight: 700;
}
.neo-data-table :deep(tbody tr:hover) {
background: rgba(var(--v-theme-primary), 0.04);
}
</style>
+2 -11
View File
@@ -10,7 +10,7 @@
density="compact"
:model-value="getItemEnabled()"
:loading="loading"
:disabled="loading || disableToggle"
:disabled="loading"
v-bind="props"
@update:model-value="toggleEnabled"
></v-switch>
@@ -29,7 +29,7 @@
color="error"
size="small"
rounded="xl"
:disabled="loading || disableDelete"
:disabled="loading"
@click="$emit('delete', item)"
>
{{ t('core.common.itemCard.delete') }}
@@ -108,14 +108,6 @@ export default {
showEditButton: {
type: Boolean,
default: true
},
disableToggle: {
type: Boolean,
default: false
},
disableDelete: {
type: Boolean,
default: false
}
},
emits: ['toggle-enabled', 'delete', 'edit', 'copy'],
@@ -140,7 +132,6 @@ export default {
transition: all 0.3s ease;
overflow: hidden;
min-height: 220px;
height: 100%;
display: flex;
flex-direction: column;
justify-content: space-between;
@@ -161,22 +161,6 @@
"booter": {
"description": "Sandbox Environment Driver"
},
"shipyard_neo_endpoint": {
"description": "Shipyard Neo API Endpoint",
"hint": "Bay API address, default http://127.0.0.1:8114."
},
"shipyard_neo_access_token": {
"description": "Shipyard Neo Access Token",
"hint": "Bay API Key (sk-bay-...). Leave empty for auto-discovery from credentials.json."
},
"shipyard_neo_profile": {
"description": "Shipyard Neo Profile",
"hint": "Sandbox profile for Shipyard Neo, e.g. python-default."
},
"shipyard_neo_ttl": {
"description": "Shipyard Neo Sandbox TTL",
"hint": "Sandbox time-to-live in seconds."
},
"shipyard_endpoint": {
"description": "Shipyard API Endpoint",
"hint": "API access address for Shipyard service."
@@ -370,8 +354,7 @@
"hint": "Optional Discord activity name. Leave empty to disable."
},
"discord_command_register": {
"description": "Register Discord slash commands",
"hint": "When enabled, AstrBot will automatically register plugin commands as Discord slash commands"
"description": "Auto-register plugin commands as Discord slash commands"
},
"discord_proxy": {
"description": "Discord Proxy URL",
@@ -584,51 +567,6 @@
"only_use_webhook_url_to_send": {
"description": "Send Replies via Webhook Only",
"hint": "When enabled, all WeCom AI Bot replies are sent through msg_push_webhook_url. The message push webhook supports more message types (such as images, files, etc.). If you do not need the typing effect, it is strongly recommended to use this option. "
},
"kook_bot_token": {
"description": "Bot Token",
"type": "string",
"hint": "Required. The Bot Token obtained from the KOOK Developer Platform."
},
"kook_bot_nickname": {
"description": "Bot Nickname",
"type": "string",
"hint": "Optional. If the sender nickname matches this value, the message will be ignored to prevent broadcast storms."
},
"kook_reconnect_delay": {
"description": "Reconnect Delay",
"type": "int",
"hint": "Delay time for reconnection (seconds), using an exponential backoff strategy."
},
"kook_max_reconnect_delay": {
"description": "Max Reconnect Delay",
"type": "int",
"hint": "The maximum value for reconnection delay (seconds)."
},
"kook_max_retry_delay": {
"description": "Max Retry Delay",
"type": "int",
"hint": "The maximum delay time for retries (seconds)."
},
"kook_heartbeat_interval": {
"description": "Heartbeat Interval",
"type": "int",
"hint": "The interval time for heartbeat detection (seconds)."
},
"kook_heartbeat_timeout": {
"description": "Heartbeat Timeout",
"type": "int",
"hint": "The timeout duration for heartbeat detection (seconds)."
},
"kook_max_heartbeat_failures": {
"description": "Max Heartbeat Failures",
"type": "int",
"hint": "Maximum allowed heartbeat failures; the connection will be dropped if exceeded."
},
"kook_max_consecutive_failures": {
"description": "Max Consecutive Failures",
"type": "int",
"hint": "Maximum allowed consecutive failures; retries will stop if exceeded."
}
},
"general": {
@@ -783,17 +721,6 @@
"hint": "Telegram only supports a fixed reaction set, reference: [https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9](https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9)"
}
}
},
"discord": {
"pre_ack_emoji": {
"enable": {
"description": "[Discord] Enable Pre-acknowledgment Emoji"
},
"emojis": {
"description": "Emoji List (Unicode or Custom Emoji Name)",
"hint": "Enter Unicode emoji symbols, e.g., 👍, 🤔, ⏳"
}
}
}
}
}
@@ -216,9 +216,6 @@
"enterUrl": "Enter extension repository URL"
},
"skills": {
"modeLocal": "Local Skills",
"modeNeo": "Neo Skills",
"actions": "Actions",
"upload": "Upload Skills",
"refresh": "Refresh",
"empty": "No Skills found",
@@ -232,9 +229,6 @@
"path": "Path",
"uploadSuccess": "Upload succeeded",
"uploadFailed": "Upload failed",
"download": "Download",
"downloadSuccess": "Download succeeded",
"downloadFailed": "Download failed",
"loadFailed": "Failed to load Skills",
"updateSuccess": "Updated successfully",
"updateFailed": "Update failed",
@@ -242,42 +236,8 @@
"deleteMessage": "Are you sure you want to delete this Skill?",
"deleteSuccess": "Deleted successfully",
"deleteFailed": "Delete failed",
"neoSkillKey": "Filter by skill_key",
"neoStatus": "Candidate Status",
"neoStage": "Release Stage",
"neoFilterHint": "Filter candidates and release records",
"neoAll": "All",
"neoCandidates": "Neo Candidates",
"neoReleases": "Neo Releases",
"neoLoadFailed": "Failed to load Neo skills data",
"neoPass": "Pass",
"neoReject": "Reject",
"neoEvaluateSuccess": "Evaluation updated",
"neoEvaluateFailed": "Failed to update evaluation",
"neoPromoteSuccess": "Promoted successfully",
"neoPromoteFailed": "Failed to promote",
"neoRollback": "Rollback",
"neoRollbackSuccess": "Rollback succeeded",
"neoRollbackFailed": "Rollback failed",
"neoDeactivate": "Deactivate",
"neoDeactivateSuccess": "Deactivated successfully",
"neoDeactivateFailed": "Failed to deactivate",
"neoSync": "Sync",
"neoSyncSuccess": "Sync succeeded",
"neoSyncFailed": "Sync failed",
"neoDelete": "Delete",
"neoDeleteSuccess": "Deleted successfully",
"neoDeleteFailed": "Failed to delete",
"neoPayloadTitle": "Neo Payload",
"neoPayloadFailed": "Failed to load payload",
"runtimeNoneWarning": "Computer Use runtime is set to None; Skills may not run correctly because no runtime is enabled.",
"runtimeHint": "Set the Computer Use runtime to Local or Sandbox in settings so AstrBot can use your Skills.",
"neoRuntimeRequired": "Neo Skills are available only when runtime is sandbox and sandbox booter is shipyard_neo.",
"sourceLocalOnly": "Local Skill",
"sourceSandboxOnly": "Sandbox Preset Skill",
"sourceBoth": "Local + Sandbox",
"sandboxDiscoveryPending": "Sandbox preset skills have not been discovered yet. Start at least one sandbox session to populate this list.",
"sandboxPresetReadonly": "Sandbox preset skills are read-only here. You cannot delete or enable/disable them from Local Skills."
"runtimeHint": "Set the Computer Use runtime to Local or Sandbox in settings so AstrBot can use your Skills."
},
"card": {
"actions": {
@@ -164,22 +164,6 @@
"booter": {
"description": "沙箱环境驱动器"
},
"shipyard_neo_endpoint": {
"description": "Shipyard Neo API Endpoint",
"hint": "Shipyard Neo(Bay) 服务的 API 地址,默认 http://127.0.0.1:8114。"
},
"shipyard_neo_access_token": {
"description": "Shipyard Neo 访问令牌",
"hint": "Bay 的 API Keysk-bay-...)。留空时自动从 credentials.json 发现。"
},
"shipyard_neo_profile": {
"description": "Shipyard Neo Profile",
"hint": "Shipyard Neo 沙箱 profile,例如 python-default。"
},
"shipyard_neo_ttl": {
"description": "Shipyard Neo Sandbox 存活时间(秒)",
"hint": "Shipyard Neo 沙箱的生存时间(秒)。"
},
"shipyard_endpoint": {
"description": "Shipyard API Endpoint",
"hint": "Shipyard 服务的 API 访问地址。"
@@ -373,8 +357,7 @@
"hint": "可选的 Discord 活动名称。留空则不设置活动。"
},
"discord_command_register": {
"description": "注册 Discord 指令",
"hint": "启用后,自动将插件指令注册为 Discord 斜杠指令"
"description": "是否自动将插件指令注册 Discord 斜杠指令"
},
"discord_proxy": {
"description": "Discord 代理地址",
@@ -587,51 +570,6 @@
"only_use_webhook_url_to_send": {
"description": "仅使用 Webhook 发送消息",
"hint": "启用后,企业微信智能机器人的所有回复都改为通过消息推送 Webhook 发送。消息推送 Webhook 支持更多的消息类型(如图片、文件等)。如果不需要打字机效果,强烈建议使用此选项。"
},
"kook_bot_token": {
"description": "机器人 Token",
"type": "string",
"hint": "必填项。从 KOOK 开发者平台获取的机器人 Token"
},
"kook_bot_nickname": {
"description": "Bot Nickname",
"type": "string",
"hint": "可选项。若发送者昵称与此值一致,将忽略该消息。"
},
"kook_reconnect_delay": {
"description": "重连延迟",
"type": "int",
"hint": "重连延迟时间(秒),使用指数退避策略"
},
"kook_max_reconnect_delay": {
"description": "最大重连延迟",
"type": "int",
"hint": "重连延迟的最大值(秒)"
},
"kook_max_retry_delay": {
"description": "最大重试延迟",
"type": "int",
"hint": "重试的最大延迟时间(秒)"
},
"kook_heartbeat_interval": {
"description": "心跳间隔",
"type": "int",
"hint": "心跳检测间隔时间(秒)"
},
"kook_heartbeat_timeout": {
"description": "心跳超时时间",
"type": "int",
"hint": "心跳检测超时时间(秒)"
},
"kook_max_heartbeat_failures": {
"description": "最大心跳失败次数",
"type": "int",
"hint": "允许的最大心跳失败次数,超过后断开连接"
},
"kook_max_consecutive_failures": {
"description": "最大连续失败次数",
"type": "int",
"hint": "允许的最大连续失败次数,超过后停止重试"
}
},
"general": {
@@ -786,17 +724,6 @@
"hint": "Telegram 仅支持固定反应集合,参考:[https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9](https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9)"
}
}
},
"discord": {
"pre_ack_emoji": {
"enable": {
"description": "[Discord] 启用预回应表情"
},
"emojis": {
"description": "表情列表(Unicode 或自定义表情名)",
"hint": "填写 Unicode 表情符号,例如:👍、🤔、⏳"
}
}
}
}
}
@@ -216,9 +216,6 @@
"enterUrl": "输入插件仓库链接"
},
"skills": {
"modeLocal": "本地 Skills",
"modeNeo": "Neo Skills",
"actions": "操作",
"upload": "上传 Skills",
"refresh": "刷新",
"empty": "暂无 Skills",
@@ -232,9 +229,6 @@
"path": "路径",
"uploadSuccess": "上传成功",
"uploadFailed": "上传失败",
"download": "下载",
"downloadSuccess": "下载成功",
"downloadFailed": "下载失败",
"loadFailed": "加载 Skills 失败",
"updateSuccess": "更新成功",
"updateFailed": "更新失败",
@@ -242,42 +236,8 @@
"deleteMessage": "确定要删除该 Skill 吗?",
"deleteSuccess": "删除成功",
"deleteFailed": "删除失败",
"neoSkillKey": "skill_key 过滤",
"neoStatus": "候选状态",
"neoStage": "发布阶段",
"neoFilterHint": "筛选候选与发布记录",
"neoAll": "全部",
"neoCandidates": "Neo Candidates",
"neoReleases": "Neo Releases",
"neoLoadFailed": "加载 Neo Skills 数据失败",
"neoPass": "通过",
"neoReject": "拒绝",
"neoEvaluateSuccess": "评测更新成功",
"neoEvaluateFailed": "评测更新失败",
"neoPromoteSuccess": "发布成功",
"neoPromoteFailed": "发布失败",
"neoRollback": "回滚",
"neoRollbackSuccess": "回滚成功",
"neoRollbackFailed": "回滚失败",
"neoDeactivate": "失活",
"neoDeactivateSuccess": "失活成功",
"neoDeactivateFailed": "失活失败",
"neoSync": "同步",
"neoSyncSuccess": "同步成功",
"neoSyncFailed": "同步失败",
"neoDelete": "删除",
"neoDeleteSuccess": "删除成功",
"neoDeleteFailed": "删除失败",
"neoPayloadTitle": "Neo Payload 详情",
"neoPayloadFailed": "读取 Payload 失败",
"runtimeNoneWarning": "Computer Use 运行环境为无,Skills 可能无法正确被 Agent 运行,因为没有启用运行环境。",
"runtimeHint": "需要在配置的 “使用电脑能力” 中将运行环境设置为 “local” 或 “sandbox” 才能让 AstrBot 正常使用你提供的 Skills。",
"neoRuntimeRequired": "Neo Skills 仅在运行环境为 sandbox 且沙箱驱动为 shipyard_neo 时可用。",
"sourceLocalOnly": "本地 Skill",
"sourceSandboxOnly": "Sandbox 预置 Skill",
"sourceBoth": "本地 + Sandbox",
"sandboxDiscoveryPending": "尚未发现 Sandbox 预置 Skill。请至少启动一次 Sandbox 会话后再查看。",
"sandboxPresetReadonly": "Sandbox 预置 Skill 在此处为只读,无法在本地 Skills 页面删除或启用/禁用。"
"runtimeHint": "需要在配置的 “使用电脑能力” 中将运行环境设置为 “local” 或 “sandbox” 才能让 AstrBot 正常使用你提供的 Skills。"
},
"card": {
"actions": {
@@ -468,12 +468,6 @@ onMounted(async () => {
<v-icon>mdi-menu</v-icon>
</v-btn>
<!-- 移动端 chat sidebar 展开按钮 - 仅在 chat 模式下的小屏幕显示 -->
<v-btn v-if="customizer.viewMode === 'chat'" class="hidden-lg-and-up ms-1" icon rounded="sm" variant="flat"
@click.stop="customizer.TOGGLE_CHAT_SIDEBAR()">
<v-icon>mdi-menu</v-icon>
</v-btn>
<div class="logo-container" :class="{ 'mobile-logo': $vuetify.display.xs, 'chat-mode-logo': customizer.viewMode === 'chat' }" @click="handleLogoClick">
<span class="logo-text Outfit">Astr<span class="logo-text bot-text-wrapper">Bot
<img v-if="isChristmas" src="@/assets/images/xmas-hat.png" alt="Christmas hat" class="xmas-hat" />
@@ -494,13 +488,13 @@ onMounted(async () => {
</small>
</div>
<!-- Bot/Chat 模式切换按钮 - 手机端隐藏移入 ... 菜单 -->
<!-- Bot/Chat 模式切换按钮 -->
<v-btn-toggle
v-model="viewMode"
mandatory
variant="outlined"
density="compact"
class="mr-4 hidden-xs"
class="mr-4"
color="primary"
>
<v-btn value="bot" size="small">
@@ -530,30 +524,6 @@ onMounted(async () => {
</v-btn>
</template>
<!-- Bot/Chat 模式切换 - 仅在手机端显示 -->
<template v-if="$vuetify.display.xs">
<div class="mobile-mode-toggle-wrapper">
<v-btn-toggle
v-model="viewMode"
mandatory
variant="outlined"
density="compact"
color="primary"
class="mobile-mode-toggle"
>
<v-btn value="bot" size="small">
<v-icon start>mdi-robot</v-icon>
Bot
</v-btn>
<v-btn value="chat" size="small">
<v-icon start>mdi-chat</v-icon>
Chat
</v-btn>
</v-btn-toggle>
</div>
<v-divider class="my-1" />
</template>
<!-- 语言切换 -->
<v-list-item
v-for="lang in languages"
@@ -918,10 +888,6 @@ onMounted(async () => {
margin-left: 22px;
}
.mobile-logo.chat-mode-logo {
margin-left: 4px;
}
.logo-text {
font-size: 24px;
font-weight: 1000;
@@ -960,20 +926,6 @@ onMounted(async () => {
margin-right: 8px;
}
.mobile-mode-toggle-wrapper {
display: flex;
justify-content: center;
padding: 8px 12px 4px;
}
.mobile-mode-toggle {
width: 100%;
}
.mobile-mode-toggle .v-btn {
flex: 1;
}
/* 移动端对话框标题样式 */
.mobile-card-title {
display: flex;
@@ -38,7 +38,7 @@ const isItemActive = computed(() => {
</template>
<!-- children -->
<template v-for="(child, index) in item.children" :key="child.title || child.to || `child-${index}`">
<template v-for="(child, index) in item.children" :key="index">
<NavItem :item="child" :level="(level || 0) + 1" />
</template>
</v-list-group>
@@ -10,60 +10,26 @@ import ChangelogDialog from '@/components/shared/ChangelogDialog.vue';
const { t, locale } = useI18n();
const customizer = useCustomizerStore();
function collectGroupValues(items, values = new Set()) {
items.forEach((item) => {
if (item?.children && item.title) {
values.add(item.title);
collectGroupValues(item.children, values);
}
});
return values;
}
function sanitizeOpenedItems(items, menuItems) {
if (!Array.isArray(items)) {
return [];
}
const groupValues = collectGroupValues(menuItems);
return items.filter((item) => typeof item === 'string' && groupValues.has(item));
}
function getInitialOpenedItems(menuItems) {
try {
const stored = JSON.parse(localStorage.getItem('sidebar_openedItems') || '[]');
return sanitizeOpenedItems(stored, menuItems);
} catch {
return [];
}
}
const sidebarMenu = shallowRef(applySidebarCustomization(sidebarItems));
const sidebarMenu = shallowRef(sidebarItems);
//
const openedItems = ref(getInitialOpenedItems(sidebarMenu.value));
watch(openedItems, (val) => {
localStorage.setItem('sidebar_openedItems', JSON.stringify(sanitizeOpenedItems(val, sidebarMenu.value)));
}, { deep: true });
function refreshSidebarMenu() {
sidebarMenu.value = applySidebarCustomization(sidebarItems);
openedItems.value = sanitizeOpenedItems(openedItems.value, sidebarMenu.value);
}
const openedItems = ref(JSON.parse(localStorage.getItem('sidebar_openedItems') || '[]'));
watch(openedItems, (val) => localStorage.setItem('sidebar_openedItems', JSON.stringify(val)), { deep: true });
// Apply customization on mount and listen for storage changes
const handleStorageChange = (e) => {
if (e.key === 'astrbot_sidebar_customization') {
refreshSidebarMenu();
sidebarMenu.value = applySidebarCustomization(sidebarItems);
}
};
const handleCustomEvent = () => {
refreshSidebarMenu();
sidebarMenu.value = applySidebarCustomization(sidebarItems);
};
onMounted(() => {
sidebarMenu.value = applySidebarCustomization(sidebarItems);
window.addEventListener('storage', handleStorageChange);
window.addEventListener('sidebar-customization-changed', handleCustomEvent);
});
@@ -289,7 +255,7 @@ function openChangelogDialog() {
>
<div class="sidebar-container">
<v-list class="pa-4 listitem flex-grow-1" v-model:opened="openedItems" :open-strategy="'multiple'">
<template v-for="(item, i) in sidebarMenu" :key="item.title || item.to || `sidebar-item-${i}`">
<template v-for="(item, i) in sidebarMenu" :key="i">
<NavItem :item="item" class="leftPadding" />
</template>
</v-list>
+4 -7
View File
@@ -9,9 +9,12 @@ import '@/scss/style.scss';
import VueApexCharts from 'vue3-apexcharts';
import print from 'vue3-print-nb';
import { loader } from '@guolao/vue-monaco-editor'
import { loader } from '@guolao/vue-monaco-editor';
import * as monaco from 'monaco-editor';
import axios from 'axios';
loader.config({ monaco });
// 初始化新的i18n系统,等待完成后再挂载应用
setupI18n().then(() => {
console.log('🌍 新i18n系统初始化完成');
@@ -108,9 +111,3 @@ window.fetch = (input: RequestInfo | URL, init?: RequestInit) => {
}
return _origFetch(input, { ...init, headers });
};
loader.config({
paths: {
vs: 'https://cdn.jsdelivr.net/npm/monaco-editor@0.54.0/min/vs',
},
})
-4
View File
@@ -15,7 +15,3 @@
@import './components/VScrollbar';
@import './pages/dashboards';
html, body {
overscroll-behavior-y: none;
}
+2 -9
View File
@@ -10,8 +10,7 @@ export const useCustomizerStore = defineStore({
fontTheme: "Poppins",
uiTheme: config.uiTheme,
inputBg: config.inputBg,
viewMode: (localStorage.getItem('viewMode') as 'bot' | 'chat') || 'bot', // 'bot' 或 'chat'
chatSidebarOpen: false // chat mode mobile sidebar state
viewMode: (localStorage.getItem('viewMode') as 'bot' | 'chat') || 'bot' // 'bot' 或 'chat'
}),
getters: {},
@@ -31,13 +30,7 @@ export const useCustomizerStore = defineStore({
},
SET_VIEW_MODE(payload: 'bot' | 'chat') {
this.viewMode = payload;
localStorage.setItem('viewMode', payload);
},
TOGGLE_CHAT_SIDEBAR() {
this.chatSidebarOpen = !this.chatSidebarOpen;
},
SET_CHAT_SIDEBAR(payload: boolean) {
this.chatSidebarOpen = payload;
localStorage.setItem("viewMode", payload);
},
}
});
+16 -16
View File
@@ -46,22 +46,22 @@ export function getPlatformIcon(name) {
*/
export function getTutorialLink(platformType) {
const tutorialMap = {
"qq_official_webhook": "https://docs.astrbot.app/platform/qqofficial/webhook.html",
"qq_official": "https://docs.astrbot.app/platform/qqofficial/websockets.html",
"aiocqhttp": "https://docs.astrbot.app/platform/aiocqhttp/napcat.html",
"wecom": "https://docs.astrbot.app/platform/wecom.html",
"wecom_ai_bot": "https://docs.astrbot.app/platform/wecom_ai_bot.html",
"lark": "https://docs.astrbot.app/platform/lark.html",
"telegram": "https://docs.astrbot.app/platform/telegram.html",
"dingtalk": "https://docs.astrbot.app/platform/dingtalk.html",
"weixin_official_account": "https://docs.astrbot.app/platform/weixin-official-account.html",
"discord": "https://docs.astrbot.app/platform/discord.html",
"slack": "https://docs.astrbot.app/platform/slack.html",
"kook": "https://docs.astrbot.app/platform/kook.html",
"vocechat": "https://docs.astrbot.app/platform/vocechat.html",
"satori": "https://docs.astrbot.app/platform/satori/llonebot.html",
"misskey": "https://docs.astrbot.app/platform/misskey.html",
"line": "https://docs.astrbot.app/platform/line.html",
"qq_official_webhook": "https://docs.astrbot.app/deploy/platform/qqofficial/webhook.html",
"qq_official": "https://docs.astrbot.app/deploy/platform/qqofficial/websockets.html",
"aiocqhttp": "https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html",
"wecom": "https://docs.astrbot.app/deploy/platform/wecom.html",
"wecom_ai_bot": "https://docs.astrbot.app/deploy/platform/wecom_ai_bot.html",
"lark": "https://docs.astrbot.app/deploy/platform/lark.html",
"telegram": "https://docs.astrbot.app/deploy/platform/telegram.html",
"dingtalk": "https://docs.astrbot.app/deploy/platform/dingtalk.html",
"weixin_official_account": "https://docs.astrbot.app/deploy/platform/weixin-official-account.html",
"discord": "https://docs.astrbot.app/deploy/platform/discord.html",
"slack": "https://docs.astrbot.app/deploy/platform/slack.html",
"kook": "https://docs.astrbot.app/deploy/platform/kook.html",
"vocechat": "https://docs.astrbot.app/deploy/platform/vocechat.html",
"satori": "https://docs.astrbot.app/deploy/platform/satori/llonebot.html",
"misskey": "https://docs.astrbot.app/deploy/platform/misskey.html",
"line": "https://docs.astrbot.app/deploy/platform/line.html",
}
return tutorialMap[platformType] || "https://docs.astrbot.app";
}
+5 -60
View File
@@ -52,21 +52,6 @@ export function clearSidebarCustomization() {
export function resolveSidebarItems(defaultItems, customization, options = {}) {
const { cloneItems = false, assembleMoreGroup = false } = options;
const normalizeKeys = (keys = []) => {
const list = Array.isArray(keys) ? keys : [];
const deduped = [];
const seen = new Set();
list.forEach((key) => {
if (typeof key !== 'string') return;
if (seen.has(key)) return;
seen.add(key);
deduped.push(key);
});
return deduped;
};
const all = new Map();
const defaultMain = [];
const defaultMore = [];
@@ -85,23 +70,9 @@ export function resolveSidebarItems(defaultItems, customization, options = {}) {
});
const hasCustomization = Boolean(customization);
let mainKeys = hasCustomization ? normalizeKeys(customization.mainItems || []) : [...defaultMain];
let moreKeys = hasCustomization ? normalizeKeys(customization.moreItems || []) : [...defaultMore];
if (hasCustomization) {
mainKeys = mainKeys.filter(title => all.has(title));
moreKeys = moreKeys.filter(title => all.has(title));
}
if (hasCustomization) {
// 如果同一项同时出现在主区与更多区,主区优先。
const mainSet = new Set(mainKeys);
moreKeys = moreKeys.filter(title => !mainSet.has(title));
}
const used = hasCustomization
? new Set([...mainKeys, ...moreKeys])
: new Set(defaultMain.concat(defaultMore));
const mainKeys = hasCustomization ? customization.mainItems || [] : defaultMain;
const moreKeys = hasCustomization ? customization.moreItems || [] : defaultMore;
const used = hasCustomization ? new Set([...mainKeys, ...moreKeys]) : new Set(defaultMain.concat(defaultMore));
const mainItems = mainKeys
.map(title => all.get(title))
@@ -148,13 +119,7 @@ export function resolveSidebarItems(defaultItems, customization, options = {}) {
}
}
return {
mainItems,
moreItems,
merged,
normalizedMainKeys: [...mainKeys],
normalizedMoreKeys: [...moreKeys]
};
return { mainItems, moreItems, merged };
}
/**
@@ -164,29 +129,9 @@ export function resolveSidebarItems(defaultItems, customization, options = {}) {
*/
export function applySidebarCustomization(defaultItems) {
const customization = getSidebarCustomization();
const {
merged,
normalizedMainKeys,
normalizedMoreKeys
} = resolveSidebarItems(defaultItems, customization, {
const { merged } = resolveSidebarItems(defaultItems, customization, {
cloneItems: true,
assembleMoreGroup: true
});
if (customization) {
const rawMainKeys = Array.isArray(customization.mainItems) ? customization.mainItems : [];
const rawMoreKeys = Array.isArray(customization.moreItems) ? customization.moreItems : [];
const hasChanged =
JSON.stringify(rawMainKeys) !== JSON.stringify(normalizedMainKeys) ||
JSON.stringify(rawMoreKeys) !== JSON.stringify(normalizedMoreKeys);
if (hasChanged) {
setSidebarCustomization({
mainItems: normalizedMainKeys,
moreItems: normalizedMoreKeys
});
}
}
return merged || defaultItems;
}
+3 -1
View File
@@ -2,6 +2,7 @@ import { fileURLToPath, URL } from 'url';
import { defineConfig } from 'vite';
import vue from '@vitejs/plugin-vue';
import vuetify from 'vite-plugin-vuetify';
import monacoEditorPlugin from 'vite-plugin-monaco-editor';
// https://vitejs.dev/config/
export default defineConfig({
@@ -15,7 +16,8 @@ export default defineConfig({
}),
vuetify({
autoImport: true
})
}),
monacoEditorPlugin({})
],
resolve: {
alias: {
-20
View File
@@ -1,20 +0,0 @@
schema: spec-driven
# Project context (optional)
# This is shown to AI when creating artifacts.
# Add your tech stack, conventions, style guides, domain knowledge, etc.
# Example:
# context: |
# Tech stack: TypeScript, React, Node.js
# We use conventional commits
# Domain: e-commerce platform
# Per-artifact rules (optional)
# Add custom rules for specific artifacts.
# Example:
# rules:
# proposal:
# - Keep proposals under 500 words
# - Always include a "Non-goals" section
# tasks:
# - Break tasks into chunks of max 2 hours
-12
View File
@@ -61,7 +61,6 @@ dependencies = [
"xinference-client",
"tenacity>=9.1.2",
"shipyard-python-sdk>=0.2.4",
"shipyard-neo-sdk @ git+https://github.com/AstrBotDevs/shipyard-neo.git#subdirectory=shipyard-neo-sdk",
"python-socks>=2.8.0",
"packaging>=24.2",
]
@@ -111,17 +110,6 @@ reportMissingImports = false
include = ["astrbot"]
exclude = ["dashboard", "node_modules", "dist", "data", "tests"]
[tool.hatch.metadata]
allow-direct-references = true
# Include bundled dashboard dist even though it is not tracked by VCS.
[tool.hatch.build.targets.wheel]
artifacts = ["astrbot/dashboard/dist/**"]
# Custom build hook: builds the Vue dashboard and copies dist into the package.
[tool.hatch.build.hooks.custom]
path = "scripts/hatch_build.py"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
-1
View File
@@ -54,5 +54,4 @@ markitdown-no-magika[docx,xls,xlsx]>=0.1.2
xinference-client
tenacity>=9.1.2
shipyard-python-sdk>=0.2.4
shipyard-neo-sdk @ git+https://github.com/AstrBotDevs/shipyard-neo.git#subdirectory=shipyard-neo-sdk
packaging>=24.2
-63
View File
@@ -1,63 +0,0 @@
"""
Custom Hatchling build hook.
During `hatch build` (or `pip wheel`), this hook:
1. Runs `npm run build` inside the `dashboard/` directory.
2. Copies the resulting `dashboard/dist/` tree into
`astrbot/dashboard/dist/` so the static assets are shipped
inside the Python wheel.
"""
import shutil
import subprocess
import sys
from pathlib import Path
from hatchling.builders.hooks.plugin.interface import BuildHookInterface
class CustomBuildHook(BuildHookInterface):
PLUGIN_NAME = "custom"
def initialize(self, version: str, build_data: dict) -> None:
root = Path(self.root)
dashboard_src = root / "dashboard"
dist_src = dashboard_src / "dist"
dist_target = root / "astrbot" / "dashboard" / "dist"
if not dashboard_src.exists():
print(
"[hatch_build] 'dashboard/' directory not found skipping dashboard build.",
file=sys.stderr,
)
return
# ── Install Node dependencies if node_modules is absent ─────────────
if not (dashboard_src / "node_modules").exists():
print("[hatch_build] Installing dashboard Node dependencies...")
subprocess.run(
["npm", "install"],
cwd=dashboard_src,
check=True,
)
# ── Build the Vue/Vite dashboard ──────────────────────────────────────
print("[hatch_build] Building Vue dashboard (npm run build)...")
subprocess.run(
["npm", "run", "build"],
cwd=dashboard_src,
check=True,
)
if not dist_src.exists():
print(
"[hatch_build] dashboard/dist not found after build skipping copy.",
file=sys.stderr,
)
return
# ── Copy into the Python package tree ────────────────────────────────
if dist_target.exists():
shutil.rmtree(dist_target)
shutil.copytree(dist_src, dist_target)
print(f"[hatch_build] Dashboard dist copied → {dist_target.relative_to(root)}")
-171
View File
@@ -1,171 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
ROOT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$ROOT_DIR"
PROFILE="neo"
RUN_SYNC=true
RUN_LINT=true
RUN_SMOKE=true
RUN_DASHBOARD=false
usage() {
cat <<'EOF'
Usage:
scripts/pr_test_env.sh [options]
Options:
--profile <neo|full> Test profile. Default: neo
--with-dashboard Build dashboard before finishing checks
--no-dashboard Disable dashboard build (even for full profile)
--skip-sync Skip `uv sync`
--skip-lint Skip `ruff format --check` and `ruff check`
--skip-smoke Skip startup smoke test
-h, --help Show this help message
Environment:
PYTEST_ARGS Extra args appended to pytest command
EOF
}
while (($# > 0)); do
case "$1" in
--profile)
PROFILE="${2:-}"
if [[ "$PROFILE" != "neo" && "$PROFILE" != "full" ]]; then
echo "Unsupported profile: $PROFILE" >&2
exit 1
fi
shift 2
;;
--with-dashboard)
RUN_DASHBOARD=true
shift
;;
--skip-sync)
RUN_SYNC=false
shift
;;
--skip-lint)
RUN_LINT=false
shift
;;
--skip-smoke)
RUN_SMOKE=false
shift
;;
--no-dashboard)
RUN_DASHBOARD=false
shift
;;
-h | --help)
usage
exit 0
;;
*)
echo "Unknown option: $1" >&2
usage
exit 1
;;
esac
done
if [[ "$PROFILE" == "full" && "$RUN_DASHBOARD" == false ]]; then
RUN_DASHBOARD=true
fi
echo "==> Profile: $PROFILE"
echo "==> Sync dependencies: $RUN_SYNC"
echo "==> Run lint: $RUN_LINT"
echo "==> Run smoke test: $RUN_SMOKE"
echo "==> Build dashboard: $RUN_DASHBOARD"
if [[ "$RUN_SYNC" == true ]]; then
echo "==> Syncing dependencies with uv"
uv sync --group dev
fi
echo "==> Preparing test directories"
mkdir -p data/plugins data/config data/temp data/skills
export TESTING="${TESTING:-true}"
export ZHIPU_API_KEY="${ZHIPU_API_KEY:-test-api-key}"
if [[ "$RUN_LINT" == true ]]; then
echo "==> Running Ruff format check"
uv run ruff format --check .
echo "==> Running Ruff lint check"
uv run ruff check .
fi
echo "==> Running pytest"
if [[ "$PROFILE" == "neo" ]]; then
NEO_TESTS=(
"tests/test_neo_skill_sync.py"
"tests/test_neo_skill_tools.py"
"tests/test_computer_skill_sync.py"
"tests/test_skill_manager_sandbox_cache.py"
"tests/test_dashboard.py::test_neo_skills_routes"
)
uv run pytest -q "${NEO_TESTS[@]}" ${PYTEST_ARGS:-}
else
uv run pytest --cov=. -v -o log_cli=true -o log_level=DEBUG ${PYTEST_ARGS:-}
fi
run_smoke_test() {
if ! command -v curl >/dev/null 2>&1; then
echo "curl is required for smoke test." >&2
return 1
fi
local smoke_port="6185"
local smoke_log
smoke_log="$(mktemp -t astrbot-smoke.XXXXXX.log)"
echo "==> Starting smoke test on http://localhost:${smoke_port}"
uv run main.py >"$smoke_log" 2>&1 &
local app_pid=$!
for _ in $(seq 1 60); do
if curl -sf "http://localhost:${smoke_port}" >/dev/null 2>&1; then
echo "==> Smoke test passed"
kill "$app_pid" 2>/dev/null || true
wait "$app_pid" 2>/dev/null || true
rm -f "$smoke_log"
return 0
fi
if ! kill -0 "$app_pid" 2>/dev/null; then
echo "AstrBot process exited before becoming healthy." >&2
tail -n 60 "$smoke_log" || true
rm -f "$smoke_log"
return 1
fi
sleep 1
done
echo "Smoke test failed: health endpoint did not become ready in time." >&2
tail -n 60 "$smoke_log" || true
kill "$app_pid" 2>/dev/null || true
wait "$app_pid" 2>/dev/null || true
rm -f "$smoke_log"
return 1
}
if [[ "$RUN_SMOKE" == true ]]; then
run_smoke_test
fi
if [[ "$RUN_DASHBOARD" == true ]]; then
if ! command -v pnpm >/dev/null 2>&1; then
echo "pnpm is required for dashboard build. Install it with: npm install -g pnpm" >&2
exit 1
fi
echo "==> Building dashboard"
pnpm --dir dashboard install --frozen-lockfile
pnpm --dir dashboard run build
fi
echo "==> PR checks completed successfully"
-298
View File
@@ -1,298 +0,0 @@
#!/usr/bin/env bash
# ──────────────────────────────────────────────────────────────
# start-with-neo.sh — 一键启动 Shipyard Neo Bay + AstrBot
#
# Usage:
# bash scripts/start-with-neo.sh # 默认 Bay :8114
# BAY_PORT=9000 bash scripts/start-with-neo.sh # 自定义端口
# ──────────────────────────────────────────────────────────────
set -euo pipefail
# ── 路径 ──────────────────────────────────────────────────────
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
ASTRBOT_DIR="$(cd "$SCRIPT_DIR/.." && pwd)"
# shipyard-neo mono-repo root is one level above AstrBot
NEO_ROOT="$(cd "$ASTRBOT_DIR/.." && pwd)"
BAY_DIR="$NEO_ROOT/pkgs/bay"
BAY_PORT="${BAY_PORT:-8114}"
BAY_HOST="0.0.0.0"
BAY_PID=""
BAY_API_KEY="" # Populated after Bay starts from credentials.json
# ── 颜色 ──────────────────────────────────────────────────────
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
CYAN='\033[0;36m'
NC='\033[0m' # No Color
log() { echo -e "${CYAN}[neo]${NC} $*"; }
ok() { echo -e "${GREEN}[neo]${NC} $*"; }
warn() { echo -e "${YELLOW}[neo]${NC} $*"; }
err() { echo -e "${RED}[neo]${NC} $*" >&2; }
# ── 清理函数 ──────────────────────────────────────────────────
cleanup() {
log "Shutting down..."
if [[ -n "$BAY_PID" ]] && kill -0 "$BAY_PID" 2>/dev/null; then
log "Stopping Bay (PID $BAY_PID)..."
kill "$BAY_PID" 2>/dev/null || true
wait "$BAY_PID" 2>/dev/null || true
fi
ok "All services stopped."
}
trap cleanup EXIT INT TERM
# ── 检查前置条件 ──────────────────────────────────────────────
check_prerequisites() {
log "Checking prerequisites..."
if [[ ! -d "$BAY_DIR" ]]; then
err "Bay directory not found: $BAY_DIR"
err "Expected shipyard-neo mono-repo at: $NEO_ROOT"
exit 1
fi
if ! command -v uv &>/dev/null; then
err "'uv' is not installed. Please install it first."
exit 1
fi
# Check Docker access (try without sudo first, then with sudo)
if docker info &>/dev/null 2>&1; then
ok "Docker is accessible."
elif sudo docker info &>/dev/null 2>&1; then
warn "Docker requires sudo. Bay may need socket permissions."
warn "If Bay fails to connect to Docker, run: sudo chmod 666 /var/run/docker.sock"
else
err "Docker is not accessible. Please install Docker or fix permissions."
exit 1
fi
# Check Bay venv
if [[ ! -d "$BAY_DIR/.venv" ]]; then
log "Bay venv not found. Running 'uv sync' in $BAY_DIR ..."
(cd "$BAY_DIR" && uv sync)
fi
ok "Prerequisites OK."
}
# ── 生成 Bay config.yaml(如不存在)────────────────────────────
ensure_bay_config() {
local config_file="$BAY_DIR/config.yaml"
if [[ -f "$config_file" ]]; then
ok "Bay config.yaml already exists."
return
fi
log "Generating Bay config.yaml for local development..."
cat > "$config_file" << 'BAYCONFIG'
# Bay Local Development Config (auto-generated by start-with-neo.sh)
# For full reference see config.yaml.example
server:
host: "0.0.0.0"
port: 8114
database:
url: "sqlite+aiosqlite:///./bay.db"
echo: false
driver:
type: docker
image_pull_policy: if_not_present
docker:
socket: "unix:///var/run/docker.sock"
connect_mode: host_port
host_address: "127.0.0.1"
publish_ports: true
host_port: null
network: null
cargo:
root_path: "/var/lib/bay/cargos"
default_size_limit_mb: 1024
mount_path: "/workspace"
# Security: auto-provision mode
# Bay generates sk-bay-* key on first boot → credentials.json
security:
allow_anonymous: false
profiles:
- id: python-default
description: "Standard Python sandbox"
image: "ghcr.io/astrbotdevs/shipyard-neo-ship:latest"
runtime_type: ship
runtime_port: 8123
resources:
cpus: 1.0
memory: "1g"
capabilities:
- filesystem
- shell
- python
idle_timeout: 1800
env: {}
gc:
enabled: true
run_on_startup: true
interval_seconds: 300
idle_session:
enabled: true
expired_sandbox:
enabled: true
orphan_cargo:
enabled: true
orphan_container:
enabled: false
BAYCONFIG
ok "Bay config.yaml created at $config_file"
}
# ── 拉取 Ship 镜像 ───────────────────────────────────────────
ensure_ship_image() {
local image="ghcr.io/astrbotdevs/shipyard-neo-ship:latest"
log "Checking Ship image: $image ..."
if docker image inspect "$image" &>/dev/null 2>&1 || \
sudo docker image inspect "$image" &>/dev/null 2>&1; then
ok "Ship image is available locally."
else
log "Pulling Ship image (this may take a while)..."
if docker pull "$image" 2>/dev/null || sudo docker pull "$image" 2>/dev/null; then
ok "Ship image pulled successfully."
else
warn "Failed to pull Ship image. Bay will try to pull it on first sandbox creation."
fi
fi
}
# ── 启动 Bay ──────────────────────────────────────────────────
start_bay() {
log "Starting Bay on :$BAY_PORT ..."
(cd "$BAY_DIR" && BAY_DATA_DIR="$BAY_DIR" uv run uvicorn app.main:app \
--host "$BAY_HOST" \
--port "$BAY_PORT" \
--reload \
2>&1 | sed "s/^/ ${CYAN}[bay]${NC} /") &
BAY_PID=$!
log "Bay started (PID $BAY_PID), waiting for health check..."
# Wait for Bay to become healthy
local max_wait=30
local waited=0
while [[ $waited -lt $max_wait ]]; do
if curl -sf "http://127.0.0.1:$BAY_PORT/health" &>/dev/null; then
ok "Bay is healthy at http://127.0.0.1:$BAY_PORT"
return
fi
# Check if process is still alive
if ! kill -0 "$BAY_PID" 2>/dev/null; then
err "Bay process died unexpectedly. Check the output above."
exit 1
fi
sleep 1
waited=$((waited + 1))
done
err "Bay did not become healthy within ${max_wait}s."
err "It may still be starting — check http://127.0.0.1:$BAY_PORT/health"
}
# ── 读取 Bay 自动生成的凭证 ───────────────────────────────────
read_bay_credentials() {
local cred_file="$BAY_DIR/credentials.json"
# Wait briefly for credentials.json to appear (Bay writes it during startup)
local max_wait=5
local waited=0
while [[ $waited -lt $max_wait ]]; do
if [[ -f "$cred_file" ]]; then
break
fi
sleep 1
waited=$((waited + 1))
done
if [[ -f "$cred_file" ]]; then
# Extract api_key using python (always available) — no jq dependency
BAY_API_KEY=$(python3 -c "
import json, sys
try:
d = json.load(open('$cred_file'))
print(d.get('api_key', ''))
except Exception:
print('')
" 2>/dev/null || echo "")
if [[ -n "$BAY_API_KEY" ]]; then
ok "Auto-provisioned API key loaded from credentials.json"
else
warn "credentials.json found but api_key is empty"
fi
else
warn "credentials.json not found — Bay may be using an existing key or anonymous mode"
warn "Check Bay logs above for the API key, or look at: $cred_file"
fi
}
# ── 打印 AstrBot 配置提示 ────────────────────────────────────
print_astrbot_config_hint() {
echo ""
echo -e "${GREEN}════════════════════════════════════════════════════════════${NC}"
echo -e "${GREEN} Shipyard Neo Bay is running at http://127.0.0.1:$BAY_PORT ${NC}"
echo -e "${GREEN}════════════════════════════════════════════════════════════${NC}"
echo ""
if [[ -n "$BAY_API_KEY" ]]; then
echo -e " ${CYAN}Bay API Key (auto-generated):${NC}"
echo -e " ${YELLOW}$BAY_API_KEY${NC}"
echo ""
fi
echo -e " ${CYAN}AstrBot Dashboard 配置指引:${NC}"
echo -e " 1. AI 配置 → Agent Computer Use"
echo -e " • Computer Use Runtime → ${YELLOW}沙箱${NC}"
echo -e " • 沙箱环境驱动器 → ${YELLOW}Shipyard Neo${NC}"
echo -e " • Shipyard Neo API Endpoint → ${YELLOW}http://127.0.0.1:$BAY_PORT${NC}"
if [[ -n "$BAY_API_KEY" ]]; then
echo -e " • Shipyard Neo Access Token → ${YELLOW}$BAY_API_KEY${NC}"
else
echo -e " • Shipyard Neo Access Token → ${YELLOW}(查看 Bay 日志获取 key${NC}"
fi
echo -e " • Shipyard Neo Profile → ${YELLOW}python-default${NC}"
echo ""
}
# ── 启动 AstrBot ──────────────────────────────────────────────
start_astrbot() {
log "Starting AstrBot..."
cd "$ASTRBOT_DIR"
uv run main.py
}
# ── 主流程 ────────────────────────────────────────────────────
main() {
echo ""
echo -e "${CYAN}╔══════════════════════════════════════════╗${NC}"
echo -e "${CYAN}║ Shipyard Neo + AstrBot Quick Start ║${NC}"
echo -e "${CYAN}╚══════════════════════════════════════════╝${NC}"
echo ""
check_prerequisites
ensure_bay_config
ensure_ship_image
start_bay
read_bay_credentials
print_astrbot_config_hint
start_astrbot
}
main "$@"
+49
View File
@@ -52,6 +52,18 @@ class TestContextTruncator:
assert len(result) == 3
assert result == messages
def test_fix_messages_tool_with_valid_context(self):
"""Test fix_messages with tool message after user+assistant."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 3
assert result == messages
def test_fix_messages_tool_without_context(self):
"""Test fix_messages with tool message without enough context."""
truncator = ContextTruncator()
@@ -62,6 +74,43 @@ class TestContextTruncator:
# Tool message without context should be removed
assert len(result) == 0
def test_fix_messages_tool_with_only_one_message(self):
"""Test fix_messages with tool message after only one message."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Hello"),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
# Tool message without enough context should be removed
assert len(result) == 0
def test_fix_messages_multiple_tools(self):
"""Test fix_messages with multiple tool messages."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool 1 result"),
self.create_message("tool", "Tool 2 result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 4
assert result == messages
def test_fix_messages_mixed_system_tool(self):
"""Test fix_messages with system message and tool messages."""
truncator = ContextTruncator()
messages = [
self.create_message("system", "System prompt"),
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 4
assert result == messages
# ==================== truncate_by_turns Tests ====================
def test_truncate_by_turns_no_limit(self):
-325
View File
@@ -1,325 +0,0 @@
"""Tests for _discover_bay_credentials() auto-discovery and _log_computer_config_changes()."""
from __future__ import annotations
import json
import logging
from pathlib import Path
from unittest.mock import patch
import pytest
from astrbot.core.computer.computer_client import _discover_bay_credentials
from astrbot.dashboard.routes.config import _log_computer_config_changes
# ═══════════════════════════════════════════════════════════════
# _discover_bay_credentials
# ═══════════════════════════════════════════════════════════════
class TestDiscoverBayCredentials:
"""Test Bay API key auto-discovery from credentials.json."""
def _write_creds(
self,
path: Path,
api_key: str = "sk-bay-abc123",
endpoint: str = "http://127.0.0.1:8114",
) -> None:
"""Helper: write a credentials.json file."""
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(
json.dumps(
{
"api_key": api_key,
"endpoint": endpoint,
"generated_at": "2026-02-17T00:00:00+00:00",
}
)
)
def test_discover_from_bay_data_dir_env(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""BAY_DATA_DIR env var takes highest priority."""
data_dir = tmp_path / "bay_data"
cred_file = data_dir / "credentials.json"
self._write_creds(cred_file, api_key="sk-bay-from-env-dir")
monkeypatch.setenv("BAY_DATA_DIR", str(data_dir))
result = _discover_bay_credentials("http://127.0.0.1:8114")
assert result == "sk-bay-from-env-dir"
def test_discover_from_cwd(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Falls back to current working directory."""
cred_file = tmp_path / "credentials.json"
self._write_creds(cred_file, api_key="sk-bay-from-cwd")
monkeypatch.chdir(tmp_path)
monkeypatch.delenv("BAY_DATA_DIR", raising=False)
result = _discover_bay_credentials("http://127.0.0.1:8114")
assert result == "sk-bay-from-cwd"
def test_returns_empty_when_no_credentials_found(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Returns empty string when no credentials.json exists anywhere."""
monkeypatch.chdir(tmp_path)
monkeypatch.delenv("BAY_DATA_DIR", raising=False)
result = _discover_bay_credentials("http://127.0.0.1:8114")
assert result == ""
def test_skips_empty_api_key(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Skips credentials.json when api_key is empty."""
cred_file = tmp_path / "credentials.json"
self._write_creds(cred_file, api_key="")
monkeypatch.chdir(tmp_path)
monkeypatch.delenv("BAY_DATA_DIR", raising=False)
result = _discover_bay_credentials("http://127.0.0.1:8114")
assert result == ""
def test_skips_malformed_json(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Handles malformed JSON gracefully."""
cred_file = tmp_path / "credentials.json"
cred_file.parent.mkdir(parents=True, exist_ok=True)
cred_file.write_text("not valid json {{{")
monkeypatch.chdir(tmp_path)
monkeypatch.delenv("BAY_DATA_DIR", raising=False)
result = _discover_bay_credentials("http://127.0.0.1:8114")
assert result == ""
@patch("astrbot.core.computer.computer_client.logger")
def test_endpoint_mismatch_still_returns_key(
self, mock_logger, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Returns key even if endpoint doesn't match, but logs a warning."""
data_dir = tmp_path / "bay_data"
cred_file = data_dir / "credentials.json"
self._write_creds(
cred_file, api_key="sk-bay-mismatch", endpoint="http://other-host:9000"
)
monkeypatch.setenv("BAY_DATA_DIR", str(data_dir))
result = _discover_bay_credentials("http://127.0.0.1:8114")
assert result == "sk-bay-mismatch"
mock_logger.warning.assert_called_once()
warning_msg = mock_logger.warning.call_args[0][0]
assert "endpoint mismatch" in warning_msg
def test_endpoint_match_no_warning(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""No warning when endpoints match."""
data_dir = tmp_path / "bay_data"
cred_file = data_dir / "credentials.json"
self._write_creds(
cred_file, api_key="sk-bay-match", endpoint="http://127.0.0.1:8114"
)
monkeypatch.setenv("BAY_DATA_DIR", str(data_dir))
with patch("astrbot.core.computer.computer_client.logger") as mock_logger:
result = _discover_bay_credentials("http://127.0.0.1:8114")
assert result == "sk-bay-match"
mock_logger.warning.assert_not_called()
def test_bay_data_dir_priority_over_cwd(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""BAY_DATA_DIR takes priority over cwd."""
env_dir = tmp_path / "env_dir"
cwd_dir = tmp_path / "cwd_dir"
self._write_creds(env_dir / "credentials.json", api_key="sk-bay-env-wins")
self._write_creds(cwd_dir / "credentials.json", api_key="sk-bay-cwd-loses")
monkeypatch.setenv("BAY_DATA_DIR", str(env_dir))
monkeypatch.chdir(cwd_dir)
result = _discover_bay_credentials("http://127.0.0.1:8114")
assert result == "sk-bay-env-wins"
def test_trailing_slash_normalization(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Trailing slashes on endpoints are normalized before comparison."""
data_dir = tmp_path / "bay_data"
cred_file = data_dir / "credentials.json"
self._write_creds(
cred_file, api_key="sk-bay-slash", endpoint="http://127.0.0.1:8114/"
)
monkeypatch.setenv("BAY_DATA_DIR", str(data_dir))
with patch("astrbot.core.computer.computer_client.logger") as mock_logger:
result = _discover_bay_credentials("http://127.0.0.1:8114")
assert result == "sk-bay-slash"
mock_logger.warning.assert_not_called()
# ═══════════════════════════════════════════════════════════════
# _log_computer_config_changes
# ═══════════════════════════════════════════════════════════════
class TestLogComputerConfigChanges:
"""Test config change detection and logging."""
@patch("astrbot.dashboard.routes.config.logger")
def test_logs_runtime_change(self, mock_logger) -> None:
"""Detects computer_use_runtime change."""
old = {"provider_settings": {"computer_use_runtime": "none"}}
new = {"provider_settings": {"computer_use_runtime": "sandbox"}}
_log_computer_config_changes(old, new)
mock_logger.info.assert_called()
call_args = [str(c) for c in mock_logger.info.call_args_list]
assert any("computer_use_runtime" in c and "none" in c and "sandbox" in c for c in call_args)
@patch("astrbot.dashboard.routes.config.logger")
def test_no_log_when_runtime_unchanged(self, mock_logger) -> None:
"""No log when runtime stays the same."""
old = {"provider_settings": {"computer_use_runtime": "sandbox"}}
new = {"provider_settings": {"computer_use_runtime": "sandbox"}}
_log_computer_config_changes(old, new)
mock_logger.info.assert_not_called()
@patch("astrbot.dashboard.routes.config.logger")
def test_logs_sandbox_key_change(self, mock_logger) -> None:
"""Detects sandbox sub-key change."""
old = {"provider_settings": {"sandbox": {"booter": "shipyard"}}}
new = {"provider_settings": {"sandbox": {"booter": "shipyard_neo"}}}
_log_computer_config_changes(old, new)
mock_logger.info.assert_called()
# logger.info("[Computer] Config changed: sandbox.%s %s -> %s", key, old, new)
found = False
for call in mock_logger.info.call_args_list:
args = call[0] # positional args: (fmt, key, old_val, new_val)
if len(args) >= 4 and args[1] == "booter":
assert args[2] == "shipyard"
assert args[3] == "shipyard_neo"
found = True
break
assert found, f"Expected booter change in log calls: {mock_logger.info.call_args_list}"
@patch("astrbot.dashboard.routes.config.logger")
def test_masks_token_values(self, mock_logger) -> None:
"""Token/secret values are masked in log output."""
old = {"provider_settings": {"sandbox": {"shipyard_neo_access_token": ""}}}
new = {
"provider_settings": {
"sandbox": {"shipyard_neo_access_token": "sk-bay-secret123"}
}
}
_log_computer_config_changes(old, new)
mock_logger.info.assert_called()
call_args_str = str(mock_logger.info.call_args_list)
assert "***" in call_args_str
assert "sk-bay-secret123" not in call_args_str
@patch("astrbot.dashboard.routes.config.logger")
def test_masks_empty_token_as_empty_label(self, mock_logger) -> None:
"""Empty token values show as '(empty)' not '***'."""
old = {
"provider_settings": {
"sandbox": {"shipyard_neo_access_token": "old-key"}
}
}
new = {"provider_settings": {"sandbox": {"shipyard_neo_access_token": ""}}}
_log_computer_config_changes(old, new)
mock_logger.info.assert_called()
call_args_str = str(mock_logger.info.call_args_list)
assert "(empty)" in call_args_str
@patch("astrbot.dashboard.routes.config.logger")
def test_no_log_when_nothing_changed(self, mock_logger) -> None:
"""No logs at all when config is identical."""
cfg = {
"provider_settings": {
"computer_use_runtime": "sandbox",
"sandbox": {
"booter": "shipyard_neo",
"shipyard_neo_endpoint": "http://127.0.0.1:8114",
},
}
}
_log_computer_config_changes(cfg, cfg)
mock_logger.info.assert_not_called()
@patch("astrbot.dashboard.routes.config.logger")
def test_handles_missing_provider_settings(self, mock_logger) -> None:
"""Gracefully handles configs without provider_settings."""
_log_computer_config_changes(
{}, {"provider_settings": {"computer_use_runtime": "sandbox"}}
)
mock_logger.info.assert_called()
call_args_str = str(mock_logger.info.call_args_list)
assert "computer_use_runtime" in call_args_str
@patch("astrbot.dashboard.routes.config.logger")
def test_detects_new_sandbox_key(self, mock_logger) -> None:
"""Detects a newly added sandbox key."""
old = {"provider_settings": {"sandbox": {}}}
new = {
"provider_settings": {
"sandbox": {"shipyard_neo_endpoint": "http://127.0.0.1:8114"}
}
}
_log_computer_config_changes(old, new)
mock_logger.info.assert_called()
call_args_str = str(mock_logger.info.call_args_list)
assert "shipyard_neo_endpoint" in call_args_str
@patch("astrbot.dashboard.routes.config.logger")
def test_detects_removed_sandbox_key(self, mock_logger) -> None:
"""Detects a removed sandbox key."""
old = {
"provider_settings": {
"sandbox": {"shipyard_neo_endpoint": "http://127.0.0.1:8114"}
}
}
new = {"provider_settings": {"sandbox": {}}}
_log_computer_config_changes(old, new)
mock_logger.info.assert_called()
call_args_str = str(mock_logger.info.call_args_list)
assert "shipyard_neo_endpoint" in call_args_str
@patch("astrbot.dashboard.routes.config.logger")
def test_secret_key_masked(self, mock_logger) -> None:
"""Any key containing 'secret' is also masked."""
old = {"provider_settings": {"sandbox": {"my_secret_key": ""}}}
new = {
"provider_settings": {"sandbox": {"my_secret_key": "very-secret-value"}}
}
_log_computer_config_changes(old, new)
mock_logger.info.assert_called()
call_args_str = str(mock_logger.info.call_args_list)
assert "***" in call_args_str
assert "very-secret-value" not in call_args_str
-123
View File
@@ -1,123 +0,0 @@
from __future__ import annotations
import asyncio
from pathlib import Path
from astrbot.core.computer import computer_client
class _FakeShell:
def __init__(self, sync_payload_json: str):
self.sync_payload_json = sync_payload_json
self.commands: list[str] = []
async def exec(self, command: str, **kwargs):
_ = kwargs
self.commands.append(command)
if "PYBIN" in command and "managed_skills" in command:
return {
"success": True,
"stdout": self.sync_payload_json,
"stderr": "",
"exit_code": 0,
}
return {"success": True, "stdout": "", "stderr": "", "exit_code": 0}
class _FakeBooter:
def __init__(self, sync_payload_json: str):
self.shell = _FakeShell(sync_payload_json)
self.uploads: list[tuple[str, str]] = []
async def upload_file(self, path: str, file_name: str) -> dict:
self.uploads.append((path, file_name))
return {"success": True}
def test_sync_skills_keeps_builtin_skills_when_local_is_empty(monkeypatch, tmp_path: Path):
skills_root = tmp_path / "skills"
temp_root = tmp_path / "temp"
skills_root.mkdir(parents=True, exist_ok=True)
temp_root.mkdir(parents=True, exist_ok=True)
captured = {"skills": None}
def _fake_set_cache(self, skills):
captured["skills"] = skills
monkeypatch.setattr(
"astrbot.core.computer.computer_client.get_astrbot_skills_path",
lambda: str(skills_root),
)
monkeypatch.setattr(
"astrbot.core.computer.computer_client.get_astrbot_temp_path",
lambda: str(temp_root),
)
monkeypatch.setattr(
"astrbot.core.computer.computer_client.SkillManager.set_sandbox_skills_cache",
_fake_set_cache,
)
booter = _FakeBooter(
'{"skills":[{"name":"python-sandbox","description":"ship","path":"skills/python-sandbox/SKILL.md"}]}'
)
asyncio.run(computer_client._sync_skills_to_sandbox(booter))
assert booter.uploads == []
assert any(cmd == "rm -f skills/skills.zip" for cmd in booter.shell.commands)
assert captured["skills"] == [
{
"name": "python-sandbox",
"description": "ship",
"path": "skills/python-sandbox/SKILL.md",
}
]
def test_sync_skills_uses_managed_strategy_instead_of_wiping_all(
monkeypatch,
tmp_path: Path,
):
skills_root = tmp_path / "skills"
temp_root = tmp_path / "temp"
skill_dir = skills_root / "custom-agent-skill"
skill_dir.mkdir(parents=True, exist_ok=True)
skill_dir.joinpath("SKILL.md").write_text("# demo", encoding="utf-8")
temp_root.mkdir(parents=True, exist_ok=True)
captured = {"skills": None}
def _fake_set_cache(self, skills):
captured["skills"] = skills
monkeypatch.setattr(
"astrbot.core.computer.computer_client.get_astrbot_skills_path",
lambda: str(skills_root),
)
monkeypatch.setattr(
"astrbot.core.computer.computer_client.get_astrbot_temp_path",
lambda: str(temp_root),
)
monkeypatch.setattr(
"astrbot.core.computer.computer_client.SkillManager.set_sandbox_skills_cache",
_fake_set_cache,
)
booter = _FakeBooter(
'{"skills":[{"name":"custom-agent-skill","description":"","path":"skills/custom-agent-skill/SKILL.md"}]}'
)
asyncio.run(computer_client._sync_skills_to_sandbox(booter))
assert len(booter.uploads) == 1
assert booter.uploads[0][1] == "skills/skills.zip"
assert not any(
"find skills -mindepth 1 -delete" in cmd for cmd in booter.shell.commands
)
assert captured["skills"] == [
{
"name": "custom-agent-skill",
"description": "",
"path": "skills/custom-agent-skill/SKILL.md",
}
]
+1 -183
View File
@@ -1,7 +1,6 @@
import asyncio
import os
import sys
from types import SimpleNamespace
from pathlib import Path
import pytest
import pytest_asyncio
@@ -312,184 +311,3 @@ async def test_do_update(
data = await response.get_json()
assert data["status"] == "ok"
assert os.path.exists(release_path)
class _FakeNeoSkills:
async def list_candidates(self, **kwargs):
_ = kwargs
return [
{
"id": "cand-1",
"skill_key": "neo.demo",
"status": "evaluated_pass",
"payload_ref": "pref-1",
}
]
async def list_releases(self, **kwargs):
_ = kwargs
return [
{
"id": "rel-1",
"skill_key": "neo.demo",
"candidate_id": "cand-1",
"stage": "stable",
"active": True,
}
]
async def get_payload(self, payload_ref: str):
return {
"payload_ref": payload_ref,
"payload": {"skill_markdown": "# Demo"},
}
async def evaluate_candidate(self, candidate_id: str, **kwargs):
return {"candidate_id": candidate_id, **kwargs}
async def promote_candidate(self, candidate_id: str, stage: str = "canary"):
return {
"id": "rel-2",
"skill_key": "neo.demo",
"candidate_id": candidate_id,
"stage": stage,
}
async def rollback_release(self, release_id: str):
return {"id": "rb-1", "rolled_back_release_id": release_id}
class _FakeNeoBayClient:
def __init__(self, endpoint_url: str, access_token: str):
self.endpoint_url = endpoint_url
self.access_token = access_token
self.skills = _FakeNeoSkills()
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
_ = exc_type, exc, tb
return False
@pytest.mark.asyncio
async def test_neo_skills_routes(
app: Quart,
authenticated_header: dict,
core_lifecycle_td: AstrBotCoreLifecycle,
monkeypatch,
):
provider_settings = core_lifecycle_td.astrbot_config.setdefault(
"provider_settings", {}
)
sandbox = provider_settings.setdefault("sandbox", {})
sandbox["shipyard_neo_endpoint"] = "http://neo.test"
sandbox["shipyard_neo_access_token"] = "neo-token"
fake_shipyard_neo_module = SimpleNamespace(BayClient=_FakeNeoBayClient)
monkeypatch.setitem(sys.modules, "shipyard_neo", fake_shipyard_neo_module)
async def _fake_sync_release(self, client, **kwargs):
_ = self, client, kwargs
return SimpleNamespace(
skill_key="neo.demo",
local_skill_name="neo_demo",
release_id="rel-2",
candidate_id="cand-1",
payload_ref="pref-1",
map_path="data/skills/neo_skill_map.json",
synced_at="2026-01-01T00:00:00Z",
)
async def _fake_sync_skills_to_active_sandboxes():
return
monkeypatch.setattr(
"astrbot.dashboard.routes.skills.NeoSkillSyncManager.sync_release",
_fake_sync_release,
)
monkeypatch.setattr(
"astrbot.dashboard.routes.skills.sync_skills_to_active_sandboxes",
_fake_sync_skills_to_active_sandboxes,
)
test_client = app.test_client()
response = await test_client.get(
"/api/skills/neo/candidates", headers=authenticated_header
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert isinstance(data["data"], list)
assert data["data"][0]["id"] == "cand-1"
response = await test_client.get(
"/api/skills/neo/releases", headers=authenticated_header
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert isinstance(data["data"], list)
assert data["data"][0]["id"] == "rel-1"
response = await test_client.get(
"/api/skills/neo/payload?payload_ref=pref-1", headers=authenticated_header
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert data["data"]["payload_ref"] == "pref-1"
response = await test_client.post(
"/api/skills/neo/evaluate",
json={"candidate_id": "cand-1", "passed": True, "score": 0.95},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert data["data"]["candidate_id"] == "cand-1"
assert data["data"]["passed"] is True
response = await test_client.post(
"/api/skills/neo/evaluate",
json={"candidate_id": "cand-1", "passed": "false", "score": 0.0},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert data["data"]["passed"] is False
response = await test_client.post(
"/api/skills/neo/promote",
json={"candidate_id": "cand-1", "stage": "stable"},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert data["data"]["release"]["id"] == "rel-2"
assert data["data"]["sync"]["local_skill_name"] == "neo_demo"
response = await test_client.post(
"/api/skills/neo/rollback",
json={"release_id": "rel-2"},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert data["data"]["rolled_back_release_id"] == "rel-2"
response = await test_client.post(
"/api/skills/neo/sync",
json={"release_id": "rel-2"},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert data["data"]["skill_key"] == "neo.demo"
-1
View File
@@ -1 +0,0 @@
!data
-100
View File
@@ -1,100 +0,0 @@
{
"type": "card",
"theme": "info",
"size": "lg",
"modules": [
{
"text": {
"content": "test1",
"type": "plain-text",
"emoji": true
},
"type": "header"
},
{
"text": {
"content": "test2",
"type": "kmarkdown"
},
"type": "section",
"mode": "left"
},
{
"type": "divider"
},
{
"text": {
"fields": [
{
"content": "test3",
"type": "kmarkdown"
},
{
"content": "**test4**",
"type": "kmarkdown"
}
],
"type": "paragraph",
"cols": 2
},
"type": "section",
"mode": "left"
},
{
"elements": [
{
"src": "https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg",
"type": "image",
"alt": "",
"size": "lg",
"circle": false
}
],
"type": "image-group"
},
{
"src": "https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg",
"title": "test5",
"type": "file"
},
{
"endTime": 1772343427360,
"type": "countdown",
"startTime": 1772343378259,
"mode": "second"
},
{
"elements": [
{
"text": "点我测试回调",
"type": "button",
"theme": "primary",
"value": "btn_clicked",
"click": "return-val"
},
{
"text": "访问官网",
"type": "button",
"theme": "danger",
"value": "https://www.kookapp.cn",
"click": "link"
}
],
"type": "action-group"
},
{
"elements": [
{
"content": "test6",
"type": "plain-text",
"emoji": true
}
],
"type": "context"
},
{
"code": "test7",
"type": "invite"
}
]
}
-4
View File
@@ -1,4 +0,0 @@
from pathlib import Path
TEST_DATA_DIR = Path(__file__).parent / "data"
-223
View File
@@ -1,223 +0,0 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata, Unknown
from astrbot.api.event import MessageChain
from astrbot.core.message.components import (
File,
Image,
Plain,
Video,
At,
AtAll,
BaseMessageComponent,
Json,
Record,
Reply,
)
from astrbot.core.platform.sources.kook.kook_event import KookEvent
from astrbot.core.platform.sources.kook.kook_types import KookMessageType, OrderMessage
async def mock_kook_client(upload_asset_return: str, send_text_return: str):
# 1. Mock 掉整个 KookClient 类
client = MagicMock()
client.upload_asset = AsyncMock(return_value=upload_asset_return)
client.send_text = AsyncMock(return_value=send_text_return)
return client
def mock_file_message(input: str):
message = MagicMock(spec=File)
message.get_file = AsyncMock(return_value=input)
return message
def mock_record_message(input: str):
message = MagicMock(spec=Record)
message.text = input
message.convert_to_file_path = AsyncMock(return_value=input)
return message
def mock_astrbot_message():
message = AstrBotMessage()
message.type = MessageType.OTHER_MESSAGE
message.group_id = "test"
message.session_id = "test"
message.message_id = "test"
return message
@pytest.mark.asyncio
@pytest.mark.parametrize(
"input_message,upload_asset_return, expected_output, expected_error",
[
(
Image("test image"),
"test image",
OrderMessage(
1,
text="test image",
type=KookMessageType.IMAGE,
),
None,
),
(
Video("test video"),
"test video",
OrderMessage(
1,
text="test video",
type=KookMessageType.VIDEO,
),
None,
),
(
mock_file_message("test file"),
"test file",
OrderMessage(
1,
text="test file",
type=KookMessageType.FILE,
),
None,
),
(
mock_record_message("./tests/file.wav"),
"./tests/file.wav",
OrderMessage(
1,
text='[{"type": "card", "modules": [{"src": "./tests/file.wav", "title": "./tests/file.wav", "type": "audio"}]}]',
type=KookMessageType.CARD,
),
None,
),
(
Plain("test plain"),
"test plain",
OrderMessage(
1,
text="test plain",
type=KookMessageType.KMARKDOWN,
),
None,
),
(
At(qq="test at"),
"test at",
OrderMessage(
1,
text="(met)test at(met)",
type=KookMessageType.KMARKDOWN,
),
None,
),
(
AtAll(qq="all"),
"test atAll",
OrderMessage(
1,
text="(met)all(met)",
type=KookMessageType.KMARKDOWN,
),
None,
),
(
Reply(id="test reply"),
"test reply",
OrderMessage(
1,
text="",
type=KookMessageType.KMARKDOWN,
reply_id="test reply",
),
None,
),
(
Json(data={"test": "json"}),
"test json",
OrderMessage(
1,
text='[{"test": "json"}]',
type=KookMessageType.CARD,
),
None,
),
(
Unknown(text="test unknown"),
"test unknown",
None,
NotImplementedError,
),
],
)
async def test_kook_event_warp_message(
input_message: BaseMessageComponent,
upload_asset_return: str,
expected_output: OrderMessage,
expected_error: type[Exception] | None,
):
client = await mock_kook_client(
upload_asset_return,
"",
)
event = KookEvent(
"",
mock_astrbot_message(),
PlatformMetadata(
name="test",
id="test",
description="test",
),
"",
client,
)
if expected_error:
with pytest.raises(expected_error):
await event._wrap_message(1, input_message)
return
result = await event._wrap_message(1, input_message)
assert result == expected_output
# @pytest.mark.asyncio
# @pytest.mark.parametrize(
# "message_chain,send_text_expected_output,expected_error",
# [
# (
# MessageChain(
# chain=[
# Image(file="test image"),
# Plain(text="test plain"),
# ],
# ),
# ""
# ),
# ],
# )
# async def test_kook_event_send():
# client = await mock_kook_client(
# "",
# "",
# )
# event = KookEvent(
# "",
# mock_astrbot_message(),
# PlatformMetadata(
# name="test",
# id="test",
# description="test",
# ),
# "",
# client,
# )
# await event.send(message=mock_astrbot_message())
-107
View File
@@ -1,107 +0,0 @@
import json
from pathlib import Path
import pytest
from astrbot.core.platform.sources.kook.kook_types import (
ActionGroupModule,
ButtonElement,
ContextModule,
CountdownModule,
DividerModule,
FileModule,
HeaderModule,
ImageElement,
ImageGroupModule,
InviteModule,
KmarkdownElement,
KookCardMessage,
ParagraphStructure,
PlainTextElement,
SectionModule,
KookCardMessageContainer,
)
from tests.test_kook.shared import TEST_DATA_DIR
def test_kook_card_message_container_append():
container = KookCardMessageContainer()
container.append(KookCardMessage())
assert len(container) == 1
@pytest.mark.parametrize(
"input, expect_container_length",
[
([KookCardMessage()], 1),
([KookCardMessage()] * 2, 2),
],
)
def test_kook_card_message_container_to_json(
input: list[KookCardMessage], expect_container_length: int
):
container = KookCardMessageContainer(input)
json_output = container.to_json()
output = json.loads(json_output)
assert isinstance(output, list)
assert len(output) == expect_container_length
def test_all_kook_card_type():
expect_json_data = Path(TEST_DATA_DIR / "kook_card_data.json").read_text(
encoding="utf-8"
)
json_output = KookCardMessage(
theme="info",
size="lg",
modules=[
HeaderModule(text=PlainTextElement(content="test1")),
SectionModule(text=KmarkdownElement(content="test2")),
DividerModule(),
SectionModule(
text=ParagraphStructure(
cols=2,
fields=[
KmarkdownElement(content="test3"),
KmarkdownElement(content="**test4**"),
],
)
),
ImageGroupModule(
elements=[
ImageElement(
src="https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg"
)
]
),
FileModule(
src="https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg",
title="test5",
type="file",
),
CountdownModule(
endTime=1772343427360,
startTime=1772343378259,
mode="second",
),
ActionGroupModule(
elements=[
ButtonElement(
value="btn_clicked",
text="点我测试回调",
click="return-val",
theme="primary",
),
ButtonElement(
value="https://www.kookapp.cn",
text="访问官网",
click="link",
theme="danger",
),
]
),
ContextModule(elements=[PlainTextElement(content="test6")]),
InviteModule(code="test7"),
],
).to_json(indent=4, ensure_ascii=False)
assert json_output == expect_json_data
+6 -21
View File
@@ -26,21 +26,6 @@ class _version_info:
return (self.major, self.minor) >= other[:2]
return (self.major, self.minor) >= (other.major, other.minor)
def __le__(self, other):
if isinstance(other, tuple):
return (self.major, self.minor) <= other[:2]
return (self.major, self.minor) <= (other.major, other.minor)
def __gt__(self, other):
if isinstance(other, tuple):
return (self.major, self.minor) > other[:2]
return (self.major, self.minor) > (other.major, other.minor)
def __lt__(self, other):
if isinstance(other, tuple):
return (self.major, self.minor) < other[:2]
return (self.major, self.minor) < (other.major, other.minor)
def test_check_env(monkeypatch):
version_info_correct = _version_info(3, 10)
@@ -48,12 +33,12 @@ def test_check_env(monkeypatch):
monkeypatch.setattr(sys, "version_info", version_info_correct)
with mock.patch("os.makedirs") as mock_makedirs:
check_env()
# check_env uses get_astrbot_*_path() which returns absolute paths,
# so just verify makedirs was called the expected number of times
assert mock_makedirs.call_count >= 4
# Verify all calls used exist_ok=True
for call_args in mock_makedirs.call_args_list:
assert call_args[1].get("exist_ok") is True
# Check that makedirs was called with paths containing expected dirs
called_paths = [call[0][0] for call in mock_makedirs.call_args_list]
# Use os.path.join for cross-platform path matching
assert any(p.rstrip(os.sep).endswith(os.path.join("data", "config")) for p in called_paths)
assert any(p.rstrip(os.sep).endswith(os.path.join("data", "plugins")) for p in called_paths)
assert any(p.rstrip(os.sep).endswith(os.path.join("data", "temp")) for p in called_paths)
monkeypatch.setattr(sys, "version_info", version_info_wrong)
with pytest.raises(SystemExit):
-130
View File
@@ -1,130 +0,0 @@
from __future__ import annotations
import asyncio
from pathlib import Path
import pytest
from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager
class _FakeSkills:
async def list_releases(self, **kwargs):
_ = kwargs
return {
"items": [
{
"id": "sr-1",
"skill_key": "etl/loader@v1",
"candidate_id": "sc-1",
"stage": "stable",
}
],
"total": 1,
}
async def get_candidate(self, candidate_id: str):
assert candidate_id == "sc-1"
return {
"id": "sc-1",
"payload_ref": "blob:blob-1",
}
async def get_payload(self, payload_ref: str):
assert payload_ref == "blob:blob-1"
return {
"payload_ref": payload_ref,
"kind": "astrbot_skill_v1",
"payload": {
"skill_markdown": "---\ndescription: test\n---\n# title\ncontent",
},
}
class _FakeClient:
def __init__(self):
self.skills = _FakeSkills()
def test_sync_release_writes_skill_and_map(monkeypatch, tmp_path: Path):
calls = {"active": [], "sandbox_sync": 0}
def _fake_set_skill_active(self, name, active):
calls["active"].append((name, active))
async def _fake_sync_sandboxes():
calls["sandbox_sync"] += 1
monkeypatch.setattr(
"astrbot.core.skills.neo_skill_sync.SkillManager.set_skill_active",
_fake_set_skill_active,
)
monkeypatch.setattr(
"astrbot.core.skills.neo_skill_sync.sync_skills_to_active_sandboxes",
_fake_sync_sandboxes,
)
skills_root = tmp_path / "skills"
map_path = skills_root / "neo_skill_map.json"
mgr = NeoSkillSyncManager(skills_root=str(skills_root), map_path=str(map_path))
result = asyncio.run(
mgr.sync_release(_FakeClient(), release_id="sr-1", require_stable=True)
)
assert result.skill_key == "etl/loader@v1"
assert result.release_id == "sr-1"
assert result.local_skill_name.startswith("neo_")
assert calls["active"] == [(result.local_skill_name, True)]
assert calls["sandbox_sync"] == 1
skill_md = skills_root / result.local_skill_name / "SKILL.md"
assert skill_md.exists()
assert "description: test" in skill_md.read_text(encoding="utf-8")
assert map_path.exists()
map_text = map_path.read_text(encoding="utf-8")
assert "etl/loader@v1" in map_text
assert result.local_skill_name in map_text
def test_sync_release_rejects_non_stable(monkeypatch, tmp_path: Path):
class _CanarySkills(_FakeSkills):
async def list_releases(self, **kwargs):
_ = kwargs
return {
"items": [
{
"id": "sr-1",
"skill_key": "etl",
"candidate_id": "sc-1",
"stage": "canary",
}
],
"total": 1,
}
class _CanaryClient:
def __init__(self):
self.skills = _CanarySkills()
async def _fake_sync_sandboxes():
return
monkeypatch.setattr(
"astrbot.core.skills.neo_skill_sync.sync_skills_to_active_sandboxes",
_fake_sync_sandboxes,
)
monkeypatch.setattr(
"astrbot.core.skills.neo_skill_sync.SkillManager.set_skill_active",
lambda self, name, active: None,
)
mgr = NeoSkillSyncManager(
skills_root=str(tmp_path / "skills"),
map_path=str(tmp_path / "skills" / "neo_skill_map.json"),
)
with pytest.raises(ValueError, match="Only stable releases"):
asyncio.run(
mgr.sync_release(_CanaryClient(), release_id="sr-1", require_stable=True)
)
-73
View File
@@ -1,73 +0,0 @@
from __future__ import annotations
import asyncio
from types import SimpleNamespace
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.computer.tools.neo_skills import PromoteSkillCandidateTool
class _FakeSkills:
def __init__(self):
self.rollback_called_with = None
async def promote_candidate(self, candidate_id: str, stage: str = "canary"):
assert candidate_id == "cand-1"
assert stage == "stable"
return {
"id": "sr-1",
"skill_key": "k1",
"candidate_id": candidate_id,
"stage": stage,
}
async def rollback_release(self, release_id: str):
self.rollback_called_with = release_id
return {"id": "rb-1", "rollback_of": release_id}
class _FakeClient:
def __init__(self):
self.skills = _FakeSkills()
class _FakeBooter:
def __init__(self):
self.bay_client = _FakeClient()
self.sandbox = object()
def test_promote_stable_sync_failure_auto_rolls_back(monkeypatch):
async def _fake_get_booter(_ctx, _session_id):
return _FakeBooter()
async def _fake_sync_release(self, client, **kwargs):
_ = self, client, kwargs
raise ValueError("sync failed")
monkeypatch.setattr(
"astrbot.core.computer.tools.neo_skills.get_booter",
_fake_get_booter,
)
monkeypatch.setattr(
"astrbot.core.computer.tools.neo_skills.NeoSkillSyncManager.sync_release",
_fake_sync_release,
)
event = SimpleNamespace(role="admin", unified_msg_origin="session-1")
astr_ctx = SimpleNamespace(context=SimpleNamespace(), event=event)
run_ctx = ContextWrapper(context=astr_ctx)
tool = PromoteSkillCandidateTool()
result = asyncio.run(
tool.call(
run_ctx,
candidate_id="cand-1",
stage="stable",
sync_to_local=True,
)
)
assert isinstance(result, str)
assert "auto rollback succeeded" in result
assert "sync failed" in result
-287
View File
@@ -1,287 +0,0 @@
"""Tests for profile-aware sandbox selection and conditional tool registration."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import patch
import pytest
# ═══════════════════════════════════════════════════════════════
# ShipyardNeoBooter.capabilities
# ═══════════════════════════════════════════════════════════════
class TestShipyardNeoBooterCapabilities:
"""Test capabilities property on ShipyardNeoBooter."""
def _make_booter(self, sandbox_caps: list[str] | None = None):
from astrbot.core.computer.booters.shipyard_neo import ShipyardNeoBooter
booter = ShipyardNeoBooter(
endpoint_url="http://localhost:8114",
access_token="sk-bay-test",
)
if sandbox_caps is not None:
booter._sandbox = SimpleNamespace(capabilities=sandbox_caps)
return booter
def test_none_before_boot(self):
booter = self._make_booter()
assert booter.capabilities is None
def test_returns_tuple_after_boot(self):
booter = self._make_booter(["python", "shell", "filesystem"])
assert booter.capabilities == ("python", "shell", "filesystem")
assert isinstance(booter.capabilities, tuple)
def test_includes_browser_when_present(self):
booter = self._make_booter(["python", "shell", "filesystem", "browser"])
assert "browser" in booter.capabilities
def test_no_browser_when_absent(self):
booter = self._make_booter(["python", "shell", "filesystem"])
assert "browser" not in booter.capabilities
def test_returns_immutable(self):
"""Verify capabilities returns an immutable tuple."""
booter = self._make_booter(["python"])
caps = booter.capabilities
assert isinstance(caps, tuple)
with pytest.raises(AttributeError):
caps.append("mutated") # type: ignore[attr-defined]
# ═══════════════════════════════════════════════════════════════
# _apply_sandbox_tools — conditional browser tool registration
# ═══════════════════════════════════════════════════════════════
def _make_config(booter_type: str = "shipyard_neo"):
return SimpleNamespace(
sandbox_cfg={"booter": booter_type},
)
def _make_req():
return SimpleNamespace(func_tool=None, system_prompt="")
def _import_apply_sandbox_tools():
"""Import _apply_sandbox_tools, skipping if circular-import fails."""
try:
from astrbot.core.astr_main_agent import _apply_sandbox_tools
return _apply_sandbox_tools
except ImportError:
pytest.skip("Cannot import _apply_sandbox_tools (circular import in test env)")
class TestApplySandboxToolsConditional:
"""Verify browser tools are conditionally registered."""
def _tool_names(self, req) -> set[str]:
"""Extract tool names from a request's func_tool."""
if req.func_tool is None:
return set()
return {t.name for t in req.func_tool.tools}
def test_no_session_registers_all(self):
"""First request (no booted session) → all tools including browser."""
fn = _import_apply_sandbox_tools()
config = _make_config("shipyard_neo")
req = _make_req()
with patch(
"astrbot.core.computer.computer_client.session_booter", {}
):
fn(config, req, "session-1")
names = self._tool_names(req)
assert "astrbot_execute_browser" in names
assert "astrbot_execute_browser_batch" in names
assert "astrbot_run_browser_skill" in names
def test_with_browser_capability(self):
"""Booted session with browser capability → browser tools registered."""
fn = _import_apply_sandbox_tools()
config = _make_config("shipyard_neo")
req = _make_req()
fake_booter = SimpleNamespace(
capabilities=["python", "shell", "filesystem", "browser"]
)
with patch(
"astrbot.core.computer.computer_client.session_booter",
{"session-1": fake_booter},
):
fn(config, req, "session-1")
names = self._tool_names(req)
assert "astrbot_execute_browser" in names
def test_without_browser_capability(self):
"""Booted session WITHOUT browser capability → browser tools NOT registered."""
fn = _import_apply_sandbox_tools()
config = _make_config("shipyard_neo")
req = _make_req()
fake_booter = SimpleNamespace(
capabilities=["python", "shell", "filesystem"]
)
with patch(
"astrbot.core.computer.computer_client.session_booter",
{"session-1": fake_booter},
):
fn(config, req, "session-1")
names = self._tool_names(req)
assert "astrbot_execute_browser" not in names
assert "astrbot_execute_browser_batch" not in names
assert "astrbot_run_browser_skill" not in names
# Skill tools should still be registered
assert "astrbot_get_execution_history" in names
def test_skill_tools_always_registered(self):
"""Skill lifecycle tools are registered regardless of capabilities."""
fn = _import_apply_sandbox_tools()
config = _make_config("shipyard_neo")
req = _make_req()
fake_booter = SimpleNamespace(capabilities=["python"])
with patch(
"astrbot.core.computer.computer_client.session_booter",
{"session-1": fake_booter},
):
fn(config, req, "session-1")
names = self._tool_names(req)
assert "astrbot_create_skill_candidate" in names
assert "astrbot_promote_skill_candidate" in names
# ═══════════════════════════════════════════════════════════════
# _resolve_profile
# ═══════════════════════════════════════════════════════════════
class TestResolveProfile:
"""Test smart profile selection logic."""
def _make_booter(self, profile: str = "python-default"):
from astrbot.core.computer.booters.shipyard_neo import ShipyardNeoBooter
return ShipyardNeoBooter(
endpoint_url="http://localhost:8114",
access_token="sk-bay-test",
profile=profile,
)
@pytest.mark.asyncio
async def test_user_specified_profile_honoured(self):
"""User explicitly sets a non-default profile → use it directly."""
booter = self._make_booter(profile="browser-python")
client = SimpleNamespace() # list_profiles should NOT be called
result = await booter._resolve_profile(client)
assert result == "browser-python"
@pytest.mark.asyncio
async def test_selects_browser_profile(self):
"""When multiple profiles available, prefer one with browser."""
async def _mock_list_profiles():
return SimpleNamespace(
items=[
SimpleNamespace(
id="python-default",
capabilities=["python", "shell", "filesystem"],
),
SimpleNamespace(
id="browser-python",
capabilities=["python", "shell", "filesystem", "browser"],
),
]
)
booter = self._make_booter()
client = SimpleNamespace(list_profiles=_mock_list_profiles)
result = await booter._resolve_profile(client)
assert result == "browser-python"
@pytest.mark.asyncio
async def test_falls_back_to_default_on_api_error(self):
"""API error → graceful fallback to python-default."""
async def _failing_list_profiles():
raise ConnectionError("Bay unreachable")
booter = self._make_booter()
client = SimpleNamespace(list_profiles=_failing_list_profiles)
result = await booter._resolve_profile(client)
assert result == "python-default"
@pytest.mark.asyncio
async def test_falls_back_on_empty_profiles(self):
"""Empty profile list → python-default."""
async def _empty_list_profiles():
return SimpleNamespace(items=[])
booter = self._make_booter()
client = SimpleNamespace(list_profiles=_empty_list_profiles)
result = await booter._resolve_profile(client)
assert result == "python-default"
@pytest.mark.asyncio
async def test_single_profile_selected(self):
"""Only one profile available → use it."""
async def _single_profile():
return SimpleNamespace(
items=[
SimpleNamespace(
id="python-data",
capabilities=["python", "shell", "filesystem"],
),
]
)
booter = self._make_booter()
client = SimpleNamespace(list_profiles=_single_profile)
result = await booter._resolve_profile(client)
assert result == "python-data"
@pytest.mark.asyncio
async def test_auth_error_not_silenced(self):
"""UnauthorizedError must propagate, not be downgraded to fallback."""
from shipyard_neo.errors import UnauthorizedError
async def _unauthorized_list_profiles():
raise UnauthorizedError("bad token")
booter = self._make_booter()
client = SimpleNamespace(list_profiles=_unauthorized_list_profiles)
with pytest.raises(UnauthorizedError):
await booter._resolve_profile(client)
# ═══════════════════════════════════════════════════════════════
# ComputerBooter base class
# ═══════════════════════════════════════════════════════════════
class TestBaseComputerBooter:
"""Verify base class defaults."""
def test_capabilities_default_none(self):
from astrbot.core.computer.booters.base import ComputerBooter
booter = ComputerBooter()
assert booter.capabilities is None
def test_browser_default_none(self):
from astrbot.core.computer.booters.base import ComputerBooter
booter = ComputerBooter()
assert booter.browser is None
-104
View File
@@ -1,104 +0,0 @@
from __future__ import annotations
from pathlib import Path
from astrbot.core.skills.skill_manager import SkillManager
def _write_skill(root: Path, name: str, description: str) -> None:
skill_dir = root / name
skill_dir.mkdir(parents=True, exist_ok=True)
skill_dir.joinpath("SKILL.md").write_text(
f"---\ndescription: {description}\n---\n# {name}\n",
encoding="utf-8",
)
def test_list_skills_merges_local_and_sandbox_cache(monkeypatch, tmp_path: Path):
data_dir = tmp_path / "data"
temp_dir = tmp_path / "temp"
skills_root = tmp_path / "skills"
data_dir.mkdir(parents=True, exist_ok=True)
temp_dir.mkdir(parents=True, exist_ok=True)
skills_root.mkdir(parents=True, exist_ok=True)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_data_path",
lambda: str(data_dir),
)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_temp_path",
lambda: str(temp_dir),
)
mgr = SkillManager(skills_root=str(skills_root))
_write_skill(skills_root, "custom-local", "local description")
mgr.set_sandbox_skills_cache(
[
{
"name": "python-sandbox",
"description": "ship built-in",
"path": "/app/skills/python-sandbox/SKILL.md",
},
{
"name": "custom-local",
"description": "should be ignored by local override",
"path": "skills/custom-local/SKILL.md",
},
]
)
skills = mgr.list_skills(runtime="sandbox")
by_name = {item.name: item for item in skills}
assert sorted(by_name) == ["custom-local", "python-sandbox"]
assert by_name["custom-local"].description == "local description"
assert by_name["custom-local"].path == "skills/custom-local/SKILL.md"
assert by_name["python-sandbox"].description == "ship built-in"
assert by_name["python-sandbox"].path == "skills/python-sandbox/SKILL.md"
def test_sandbox_cached_skill_respects_active_and_display_path(
monkeypatch,
tmp_path: Path,
):
data_dir = tmp_path / "data"
temp_dir = tmp_path / "temp"
skills_root = tmp_path / "skills"
data_dir.mkdir(parents=True, exist_ok=True)
temp_dir.mkdir(parents=True, exist_ok=True)
skills_root.mkdir(parents=True, exist_ok=True)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_data_path",
lambda: str(data_dir),
)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_temp_path",
lambda: str(temp_dir),
)
mgr = SkillManager(skills_root=str(skills_root))
mgr.set_sandbox_skills_cache(
[
{
"name": "browser-automation",
"description": "gull built-in",
"path": "/app/skills/browser-automation/SKILL.md",
}
]
)
all_skills = mgr.list_skills(
runtime="sandbox",
active_only=False,
show_sandbox_path=False,
)
assert len(all_skills) == 1
assert all_skills[0].path == "/app/skills/browser-automation/SKILL.md"
mgr.set_skill_active("browser-automation", False)
active_skills = mgr.list_skills(runtime="sandbox", active_only=True)
assert active_skills == []
-203
View File
@@ -1,203 +0,0 @@
"""Tests for skill metadata: frontmatter parsing, prompt generation, absolute paths."""
from __future__ import annotations
from pathlib import Path
from astrbot.core.skills.skill_manager import (
SkillInfo,
SkillManager,
_parse_frontmatter_description,
build_skills_prompt,
)
# ---------- _parse_frontmatter_description tests ----------
def test_parse_frontmatter_description():
text = (
"---\n"
"name: screenshot-capture\n"
"description: Captures full-page screenshots of web pages. "
"Use when user asks to screenshot, take a picture of a page, "
"截图, or needs a visual snapshot of any URL.\n"
"---\n"
"# Screenshot Skill\n"
)
desc = _parse_frontmatter_description(text)
assert "Captures full-page screenshots" in desc
assert "截图" in desc
def test_parse_frontmatter_description_only():
text = "---\ndescription: legacy skill\n---\n# Title\n"
assert _parse_frontmatter_description(text) == "legacy skill"
def test_parse_frontmatter_empty():
assert _parse_frontmatter_description("no frontmatter") == ""
assert _parse_frontmatter_description("") == ""
def test_parse_frontmatter_missing_end_delimiter():
text = "---\ndescription: broken\n"
assert _parse_frontmatter_description(text) == ""
def test_parse_frontmatter_quoted_description():
text = '---\ndescription: "quoted value"\n---\n'
assert _parse_frontmatter_description(text) == "quoted value"
# ---------- build_skills_prompt tests ----------
def test_build_skills_prompt_basic_format():
skills = [
SkillInfo(
name="screenshot",
description="Take screenshots of web pages",
path="/abs/skills/screenshot/SKILL.md",
active=True,
)
]
prompt = build_skills_prompt(skills)
assert "**screenshot**" in prompt
assert "Take screenshots of web pages" in prompt
assert "`/abs/skills/screenshot/SKILL.md`" in prompt
def test_build_skills_prompt_absolute_path_in_example():
"""The mandatory grounding example should show the absolute path."""
skills = [
SkillInfo(
name="foo",
description="do foo",
path="/home/pan/AstrBot/skills/foo/SKILL.md",
active=True,
),
]
prompt = build_skills_prompt(skills)
assert "cat /home/pan/AstrBot/skills/foo/SKILL.md" in prompt
def test_build_skills_prompt_progressive_disclosure_rules():
"""The prompt should contain the key progressive disclosure rules."""
skills = [
SkillInfo(
name="test",
description="test skill",
path="/skills/test/SKILL.md",
active=True,
)
]
prompt = build_skills_prompt(skills)
# Numbered rules
assert "1." in prompt # Discovery
assert "2." in prompt # When to trigger
assert "3." in prompt # Mandatory grounding
assert "4." in prompt # Progressive disclosure
# Key concepts
assert "Mandatory grounding" in prompt
assert "Progressive disclosure" in prompt
assert "SKILL.md" in prompt
def test_build_skills_prompt_no_custom_fields():
"""Prompt should NOT contain triggers/capabilities/output labels."""
skills = [
SkillInfo(
name="test",
description="test skill",
path="/skills/test/SKILL.md",
active=True,
)
]
prompt = build_skills_prompt(skills)
assert "Triggers:" not in prompt
assert "Capabilities:" not in prompt
assert "Output:" not in prompt
# ---------- list_skills with description ----------
def test_list_skills_parses_description_from_local(monkeypatch, tmp_path: Path):
data_dir = tmp_path / "data"
temp_dir = tmp_path / "temp"
skills_root = tmp_path / "skills"
data_dir.mkdir(parents=True, exist_ok=True)
temp_dir.mkdir(parents=True, exist_ok=True)
skills_root.mkdir(parents=True, exist_ok=True)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_data_path",
lambda: str(data_dir),
)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_temp_path",
lambda: str(temp_dir),
)
skill_dir = skills_root / "screencap"
skill_dir.mkdir()
skill_dir.joinpath("SKILL.md").write_text(
"---\n"
"name: screencap\n"
"description: Capture screenshots of web pages. "
"Use when user asks to screenshot, 截图, or capture a page.\n"
"---\n"
"# Screenshot\n",
encoding="utf-8",
)
mgr = SkillManager(skills_root=str(skills_root))
skills = mgr.list_skills()
assert len(skills) == 1
s = skills[0]
assert "Capture screenshots" in s.description
assert "截图" in s.description
# SkillInfo should NOT have triggers/capabilities/output attributes
assert not hasattr(s, "triggers")
assert not hasattr(s, "capabilities")
assert not hasattr(s, "output")
def test_list_skills_description_from_sandbox_cache(
monkeypatch, tmp_path: Path
):
data_dir = tmp_path / "data"
temp_dir = tmp_path / "temp"
skills_root = tmp_path / "skills"
data_dir.mkdir(parents=True, exist_ok=True)
temp_dir.mkdir(parents=True, exist_ok=True)
skills_root.mkdir(parents=True, exist_ok=True)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_data_path",
lambda: str(data_dir),
)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_temp_path",
lambda: str(temp_dir),
)
mgr = SkillManager(skills_root=str(skills_root))
mgr.set_sandbox_skills_cache(
[
{
"name": "web-scrape",
"description": "Scrape web pages and extract structured data. "
"Use when user needs to extract content from URLs.",
"path": "/home/pan/AstrBot/skills/web-scrape/SKILL.md",
}
]
)
skills = mgr.list_skills(runtime="sandbox", show_sandbox_path=False)
assert len(skills) == 1
s = skills[0]
assert "Scrape web pages" in s.description
# Path should be the absolute path from cache
assert "/home/pan/AstrBot/skills/web-scrape/SKILL.md" in s.path
+24
View File
@@ -516,6 +516,30 @@ class TestEnsurePersonaAndSkills:
assert "Persona Instructions" not in req.system_prompt
@pytest.mark.asyncio
async def test_ensure_skills(self, mock_event, mock_context):
"""Test applying skills to request."""
module = ama
mock_skill = MagicMock()
mock_skill.name = "test_skill"
mock_skill.to_prompt.return_value = "Skill description"
mock_context.persona_manager.personas_v3 = []
mock_context.persona_manager.resolve_selected_persona = AsyncMock(
return_value=(None, None, None, False)
)
with patch("astrbot.core.astr_main_agent.SkillManager") as mock_skill_mgr_cls:
mock_skill_mgr = MagicMock()
mock_skill_mgr.list_skills.return_value = [mock_skill]
mock_skill_mgr_cls.return_value = mock_skill_mgr
req = ProviderRequest()
req.conversation = MagicMock(persona_id=None)
await module._ensure_persona_and_skills(req, {}, mock_context, mock_event)
assert "test_skill" in req.system_prompt
@pytest.mark.asyncio
async def test_ensure_tools_from_persona(self, mock_event, mock_context):
"""Test applying tools from persona."""
-17
View File
@@ -1,17 +0,0 @@
import platform
from astrbot.core.computer.tools.python import PythonTool, LocalPythonTool
def test_python_tool_description_contains_os():
"""测试 PythonTool 的描述中是否包含当前操作系统信息"""
tool = PythonTool()
current_os = platform.system()
assert current_os in tool.description
assert "IPython" in tool.description
def test_local_python_tool_description_contains_os():
"""测试 LocalPythonTool 的描述中是否包含当前操作系统信息和兼容性提示"""
tool = LocalPythonTool()
current_os = platform.system()
assert current_os in tool.description
assert "Python environment" in tool.description
assert "system-compatible" in tool.description