418 lines
14 KiB
Python
418 lines
14 KiB
Python
import asyncio
|
|
import os
|
|
import sys
|
|
from types import SimpleNamespace
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from quart import Quart
|
|
|
|
from astrbot.core import LogBroker
|
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
|
from astrbot.core.db.sqlite import SQLiteDatabase
|
|
from astrbot.core.star.star import star_registry
|
|
from astrbot.core.star.star_handler import star_handlers_registry
|
|
from astrbot.dashboard.server import AstrBotDashboard
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="module")
async def core_lifecycle_td(tmp_path_factory):
    """Create and initialize a core lifecycle backed by a temporary SQLite DB.

    The instance is shared by every test in this module (``scope="module"``).

    Yields:
        The initialized ``AstrBotCoreLifecycle``.

    On teardown ``stop()`` is called so the lifecycle can release its
    resources. A failure during ``stop()`` is reported as a
    ``RuntimeWarning`` instead of being silently discarded.
    """
    import inspect
    import warnings

    tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_v3.db"
    db = SQLiteDatabase(str(tmp_db_path))
    log_broker = LogBroker()
    core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
    await core_lifecycle.initialize()
    try:
        yield core_lifecycle
    finally:
        # Stop the core lifecycle first to release resources (including
        # shutting down background tasks such as MCP).
        try:
            stop_result = core_lifecycle.stop()
            # stop() may be sync or async; await any awaitable it returns
            # (coroutine, Task or Future), not only bare coroutines.
            if inspect.isawaitable(stop_result):
                await stop_result
        except Exception as exc:
            # Teardown is best-effort: an error here must not break the rest
            # of the cleanup, but it should still be visible in the report.
            warnings.warn(
                f"core_lifecycle.stop() raised during teardown: {exc!r}",
                RuntimeWarning,
                stacklevel=1,
            )
|
|
|
|
|
|
@pytest.fixture(scope="module")
def app(core_lifecycle_td: AstrBotCoreLifecycle):
    """Build the dashboard server and hand its Quart app to the tests.

    The dashboard receives the database instance already held by
    ``core_lifecycle_td`` and a freshly created shutdown event.
    """
    dashboard = AstrBotDashboard(
        core_lifecycle_td,
        core_lifecycle_td.db,
        asyncio.Event(),
    )
    return dashboard.app
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="module")
async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
    """Log in with the configured dashboard credentials.

    Returns:
        A headers dict of the form ``{"Authorization": "Bearer <token>"}``.
    """
    client = app.test_client()
    dashboard_cfg = core_lifecycle_td.astrbot_config["dashboard"]
    credentials = {
        "username": dashboard_cfg["username"],
        "password": dashboard_cfg["password"],
    }

    resp = await client.post("/api/auth/login", json=credentials)
    payload = await resp.get_json()
    assert payload["status"] == "ok"

    bearer = payload["data"]["token"]
    return {"Authorization": f"Bearer {bearer}"}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
    """Login is rejected for bad credentials and yields a token for valid ones."""
    client = app.test_client()

    # Unknown user: the endpoint must report an error.
    rejected = await client.post(
        "/api/auth/login",
        json={"username": "wrong", "password": "password"},
    )
    rejected_body = await rejected.get_json()
    assert rejected_body["status"] == "error"

    # Configured credentials: the endpoint must hand back a token.
    dashboard_cfg = core_lifecycle_td.astrbot_config["dashboard"]
    accepted = await client.post(
        "/api/auth/login",
        json={
            "username": dashboard_cfg["username"],
            "password": dashboard_cfg["password"],
        },
    )
    accepted_body = await accepted.get_json()
    assert accepted_body["status"] == "ok"
    assert "token" in accepted_body["data"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_stat(app: Quart, authenticated_header: dict):
    """The stats endpoint requires auth and reports platform data when authorized."""
    client = app.test_client()
    endpoint = "/api/stat/get"

    # Anonymous access is refused.
    anonymous = await client.get(endpoint)
    assert anonymous.status_code == 401

    # Authorized access succeeds and includes platform information.
    authorized = await client.get(endpoint, headers=authenticated_header)
    assert authorized.status_code == 200
    body = await authorized.get_json()
    assert body["status"] == "ok"
    assert "platform" in body["data"]
|
|
|
|
|
|
def _plugin_in_star_registry(name: str) -> bool:
    """Return True if ``star_registry`` holds plugin metadata named *name*."""
    return any(md.name == name for md in star_registry)


def _plugin_has_handlers(name: str) -> bool:
    """Return True if any registered handler's module path contains *name*."""
    return any(name in md.handler_module_path for md in star_handlers_registry)


async def _assert_ok_json(response) -> dict:
    """Assert HTTP 200 with ``status == "ok"`` and return the decoded JSON body."""
    assert response.status_code == 200
    data = await response.get_json()
    assert data["status"] == "ok"
    return data


@pytest.mark.asyncio
async def test_plugins(app: Quart, authenticated_header: dict):
    """Walk the plugin lifecycle: list, market, install, update, uninstall.

    NOTE: install targets a GitHub URL, so this test presumably needs
    network access.
    """
    plugin_name = "astrbot_plugin_essential"
    test_client = app.test_client()

    # Installed plugins
    await _assert_ok_json(
        await test_client.get("/api/plugin/get", headers=authenticated_header)
    )

    # Plugin market
    await _assert_ok_json(
        await test_client.get(
            "/api/plugin/market_list",
            headers=authenticated_header,
        )
    )

    # Install: afterwards the plugin must appear in the star registry.
    await _assert_ok_json(
        await test_client.post(
            "/api/plugin/install",
            json={"url": "https://github.com/Soulter/astrbot_plugin_essential"},
            headers=authenticated_header,
        )
    )
    assert _plugin_in_star_registry(plugin_name), "插件 astrbot_plugin_essential 未成功载入"

    # Update
    await _assert_ok_json(
        await test_client.post(
            "/api/plugin/update",
            json={"name": plugin_name},
            headers=authenticated_header,
        )
    )

    # Uninstall: both the metadata and every handler must be gone.
    await _assert_ok_json(
        await test_client.post(
            "/api/plugin/uninstall",
            json={"name": plugin_name},
            headers=authenticated_header,
        )
    )
    assert not _plugin_in_star_registry(plugin_name), "插件 astrbot_plugin_essential 未成功卸载"
    assert not _plugin_has_handlers(plugin_name), "插件 astrbot_plugin_essential 未成功卸载"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_commands_api(app: Quart, authenticated_header: dict):
    """Tests the command management API endpoints."""
    client = app.test_client()

    # GET /api/commands: list of commands plus an aggregate summary.
    listing = await client.get("/api/commands", headers=authenticated_header)
    assert listing.status_code == 200
    listing_body = await listing.get_json()
    assert listing_body["status"] == "ok"
    payload = listing_body["data"]
    for section in ("items", "summary"):
        assert section in payload
    for counter in ("total", "disabled", "conflicts"):
        assert counter in payload["summary"]

    # GET /api/commands/conflicts: the conflicts are returned as a list.
    conflicts = await client.get(
        "/api/commands/conflicts", headers=authenticated_header
    )
    assert conflicts.status_code == 200
    conflicts_body = await conflicts.get_json()
    assert conflicts_body["status"] == "ok"
    assert isinstance(conflicts_body["data"], list)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_check_update(app: Quart, authenticated_header: dict):
    """The update-check endpoint answers 200 with ``status == "success"``.

    NOTE(review): other dashboard endpoints in this module report ``"ok"``;
    confirm that ``"success"`` is really what this route returns.
    """
    client = app.test_client()
    resp = await client.get("/api/update/check", headers=authenticated_header)
    assert resp.status_code == 200
    body = await resp.get_json()
    assert body["status"] == "success"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_do_update(
    app: Quart,
    authenticated_header: dict,
    core_lifecycle_td: AstrBotCoreLifecycle,
    monkeypatch,
    tmp_path_factory,
):
    """Drive ``/api/update/do`` with the updater, dashboard download and pip stubbed.

    The fake updater only creates a directory under a pytest temp dir, so the
    test can confirm the update path ran without touching the real install.
    """
    release_path = tmp_path_factory.mktemp("release") / "astrbot"

    async def fake_update(*_args, **_kwargs):
        # Stand-in for the real updater: just materialize the release dir.
        release_path.mkdir(parents=True, exist_ok=True)

    async def noop(*_args, **_kwargs):
        # Shared stub for the dashboard download and pip install steps.
        return None

    monkeypatch.setattr(core_lifecycle_td.astrbot_updator, "update", fake_update)
    for dotted_target in (
        "astrbot.dashboard.routes.update.download_dashboard",
        "astrbot.dashboard.routes.update.pip_installer.install",
    ):
        monkeypatch.setattr(dotted_target, noop)

    client = app.test_client()
    resp = await client.post(
        "/api/update/do",
        headers=authenticated_header,
        json={"version": "v3.4.0", "reboot": False},
    )
    assert resp.status_code == 200
    body = await resp.get_json()
    assert body["status"] == "ok"
    assert release_path.exists()
|
|
|
|
|
|
class _FakeNeoSkills:
    """In-memory stand-in for the ``skills`` API of the Neo client.

    Every coroutine returns freshly built, canned dicts so dashboard routes can
    be exercised without a live service.
    """

    async def list_candidates(self, **kwargs):
        """Return one canned candidate; any filter arguments are ignored."""
        del kwargs
        candidate = dict(
            id="cand-1",
            skill_key="neo.demo",
            status="evaluated_pass",
            payload_ref="pref-1",
        )
        return [candidate]

    async def list_releases(self, **kwargs):
        """Return one canned, active stable release; any filters are ignored."""
        del kwargs
        release = dict(
            id="rel-1",
            skill_key="neo.demo",
            candidate_id="cand-1",
            stage="stable",
            active=True,
        )
        return [release]

    async def get_payload(self, payload_ref: str):
        """Echo *payload_ref* alongside a fixed markdown payload."""
        body = {"skill_markdown": "# Demo"}
        return {"payload_ref": payload_ref, "payload": body}

    async def evaluate_candidate(self, candidate_id: str, **kwargs):
        """Echo the candidate id first, followed by all evaluation keyword args."""
        result = {"candidate_id": candidate_id}
        result.update(kwargs)
        return result

    async def promote_candidate(self, candidate_id: str, stage: str = "canary"):
        """Pretend to promote *candidate_id*, producing release ``rel-2`` at *stage*."""
        release = dict(id="rel-2", skill_key="neo.demo")
        release["candidate_id"] = candidate_id
        release["stage"] = stage
        return release

    async def rollback_release(self, release_id: str):
        """Pretend to roll back *release_id*, producing rollback record ``rb-1``."""
        return dict(id="rb-1", rolled_back_release_id=release_id)
|
|
|
|
|
|
class _FakeNeoBayClient:
    """Test double installed as ``BayClient`` on the fake ``shipyard_neo`` module.

    Stores the connection arguments, exposes a ``_FakeNeoSkills`` as
    ``skills``, and works as an async context manager that never suppresses
    exceptions.
    """

    def __init__(self, endpoint_url: str, access_token: str):
        self.endpoint_url, self.access_token = endpoint_url, access_token
        self.skills = _FakeNeoSkills()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        del exc_type, exc, tb
        # A false return value lets any in-flight exception propagate.
        return False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_neo_skills_routes(
    app: Quart,
    authenticated_header: dict,
    core_lifecycle_td: AstrBotCoreLifecycle,
    monkeypatch,
):
    """Exercise the Neo skill routes against an in-process fake client.

    ``shipyard_neo`` is replaced in ``sys.modules`` with a namespace exposing
    ``_FakeNeoBayClient``. ``NeoSkillSyncManager.sync_release`` and
    ``sync_skills_to_active_sandboxes`` are stubbed out as well.
    """
    provider_settings = core_lifecycle_td.astrbot_config.setdefault(
        "provider_settings", {}
    )
    sandbox = provider_settings.setdefault("sandbox", {})
    # core_lifecycle_td is module-scoped: write the Neo settings through
    # monkeypatch so they are reverted after this test rather than leaking
    # into every test that runs later in the module.
    monkeypatch.setitem(sandbox, "shipyard_neo_endpoint", "http://neo.test")
    monkeypatch.setitem(sandbox, "shipyard_neo_access_token", "neo-token")

    fake_shipyard_neo_module = SimpleNamespace(BayClient=_FakeNeoBayClient)
    monkeypatch.setitem(sys.modules, "shipyard_neo", fake_shipyard_neo_module)

    async def _fake_sync_release(self, client, **kwargs):
        _ = self, client, kwargs
        return SimpleNamespace(
            skill_key="neo.demo",
            local_skill_name="neo_demo",
            release_id="rel-2",
            candidate_id="cand-1",
            payload_ref="pref-1",
            map_path="data/skills/neo_skill_map.json",
            synced_at="2026-01-01T00:00:00Z",
        )

    async def _fake_sync_skills_to_active_sandboxes():
        return

    monkeypatch.setattr(
        "astrbot.dashboard.routes.skills.NeoSkillSyncManager.sync_release",
        _fake_sync_release,
    )
    monkeypatch.setattr(
        "astrbot.dashboard.routes.skills.sync_skills_to_active_sandboxes",
        _fake_sync_skills_to_active_sandboxes,
    )

    async def _ok_data(response):
        """Assert HTTP 200 with ``status == "ok"`` and return the ``data`` field."""
        assert response.status_code == 200
        body = await response.get_json()
        assert body["status"] == "ok"
        return body["data"]

    test_client = app.test_client()

    response = await test_client.get(
        "/api/skills/neo/candidates", headers=authenticated_header
    )
    candidates = await _ok_data(response)
    assert isinstance(candidates, list)
    assert candidates[0]["id"] == "cand-1"

    response = await test_client.get(
        "/api/skills/neo/releases", headers=authenticated_header
    )
    releases = await _ok_data(response)
    assert isinstance(releases, list)
    assert releases[0]["id"] == "rel-1"

    response = await test_client.get(
        "/api/skills/neo/payload?payload_ref=pref-1", headers=authenticated_header
    )
    payload = await _ok_data(response)
    assert payload["payload_ref"] == "pref-1"

    response = await test_client.post(
        "/api/skills/neo/evaluate",
        json={"candidate_id": "cand-1", "passed": True, "score": 0.95},
        headers=authenticated_header,
    )
    evaluation = await _ok_data(response)
    assert evaluation["candidate_id"] == "cand-1"
    assert evaluation["passed"] is True

    response = await test_client.post(
        "/api/skills/neo/promote",
        json={"candidate_id": "cand-1", "stage": "stable"},
        headers=authenticated_header,
    )
    promotion = await _ok_data(response)
    assert promotion["release"]["id"] == "rel-2"
    assert promotion["sync"]["local_skill_name"] == "neo_demo"

    response = await test_client.post(
        "/api/skills/neo/rollback",
        json={"release_id": "rel-2"},
        headers=authenticated_header,
    )
    rollback = await _ok_data(response)
    assert rollback["rolled_back_release_id"] == "rel-2"

    response = await test_client.post(
        "/api/skills/neo/sync",
        json={"release_id": "rel-2"},
        headers=authenticated_header,
    )
    synced = await _ok_data(response)
    assert synced["skill_key"] == "neo.demo"
|