test: dashboard test
This commit is contained in:
@@ -44,7 +44,6 @@ class Stage(abc.ABC):
|
||||
try:
|
||||
ready_to_call = handler(event, **params)
|
||||
except TypeError as e:
|
||||
print(e)
|
||||
# 向下兼容
|
||||
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
payloads = {
|
||||
"messages": context_query,
|
||||
**model_cfgs
|
||||
}
|
||||
}
|
||||
llm_response = None
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
|
||||
@@ -140,6 +140,11 @@ class PluginManager:
|
||||
|
||||
def reload(self):
|
||||
'''扫描并加载所有的 Star'''
|
||||
for smd in star_registry:
|
||||
logger.debug(f"尝试终止插件 {smd.name} ...")
|
||||
if hasattr(smd.star_cls, "__del__"):
|
||||
smd.star_cls.__del__()
|
||||
|
||||
star_handlers_registry.clear()
|
||||
star_handlers_registry.star_handlers_map.clear()
|
||||
star_map.clear()
|
||||
@@ -272,6 +277,7 @@ class PluginManager:
|
||||
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
|
||||
|
||||
await self.updator.update(plugin)
|
||||
self.reload()
|
||||
|
||||
def install_plugin_from_file(self, zip_file_path: str):
|
||||
desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path))
|
||||
|
||||
@@ -53,7 +53,6 @@ class PluginUpdator(RepoZipUpdator):
|
||||
|
||||
files = os.listdir(os.path.join(target_dir, update_dir))
|
||||
for f in files:
|
||||
logger.info(f"移动更新文件/目录: {f}")
|
||||
if os.path.isdir(os.path.join(target_dir, update_dir, f)):
|
||||
if os.path.exists(os.path.join(target_dir, f)):
|
||||
shutil.rmtree(os.path.join(target_dir, f), onerror=on_error)
|
||||
@@ -63,7 +62,7 @@ class PluginUpdator(RepoZipUpdator):
|
||||
shutil.move(os.path.join(target_dir, update_dir, f), target_dir)
|
||||
|
||||
try:
|
||||
logger.info(f"删除临时更新文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}")
|
||||
logger.info(f"删除临时文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}")
|
||||
shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error)
|
||||
os.remove(zip_path)
|
||||
except BaseException:
|
||||
|
||||
@@ -111,7 +111,7 @@ class RepoZipUpdator():
|
||||
releases = await self.fetch_release_info(url=release_url)
|
||||
if not releases:
|
||||
# download from the default branch directly.
|
||||
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,将从默认分支下载。")
|
||||
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
|
||||
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
|
||||
else:
|
||||
release_url = releases[0]['zipball_url']
|
||||
|
||||
@@ -53,7 +53,6 @@ class PluginRoute(Route):
|
||||
try:
|
||||
logger.info(f"正在安装插件 {repo_url}")
|
||||
await self.plugin_manager.install_plugin(repo_url)
|
||||
self.core_lifecycle.restart()
|
||||
logger.info(f"安装插件 {repo_url} 成功。")
|
||||
return Response().ok(None, "安装成功。").__dict__
|
||||
except Exception as e:
|
||||
@@ -69,7 +68,6 @@ class PluginRoute(Route):
|
||||
await file.save(file_path)
|
||||
self.plugin_manager.install_plugin_from_file(file_path)
|
||||
logger.info(f"安装插件 {file.filename} 成功")
|
||||
self.core_lifecycle.restart()
|
||||
return Response().ok(None, "安装成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
@@ -93,9 +91,8 @@ class PluginRoute(Route):
|
||||
try:
|
||||
logger.info(f"正在更新插件 {plugin_name}")
|
||||
await self.plugin_manager.update_plugin(plugin_name)
|
||||
self.core_lifecycle.restart()
|
||||
logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
|
||||
return Response().ok(None, "更新成功,程序将在 2 秒内重启。").__dict__
|
||||
logger.info(f"更新插件 {plugin_name} 成功。")
|
||||
return Response().ok(None, "更新成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/extensions/update: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
@@ -32,6 +32,7 @@ class UpdateRoute(Route):
|
||||
async def update_project(self):
|
||||
data = await request.json
|
||||
version = data.get('version', '')
|
||||
reboot = data.get('reboot', True)
|
||||
if version == "" or version == "latest":
|
||||
latest = True
|
||||
version = ''
|
||||
@@ -39,8 +40,11 @@ class UpdateRoute(Route):
|
||||
latest = False
|
||||
try:
|
||||
await self.astrbot_updator.update(latest=latest, version=version)
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
|
||||
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
|
||||
if reboot:
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
|
||||
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
|
||||
else:
|
||||
return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update_project: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
@@ -0,0 +1,149 @@
|
||||
import pytest
|
||||
import os
|
||||
from quart import Quart
|
||||
from astrbot.dashboard.server import AstrBotDashboard
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core import LogBroker
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def core_lifecycle():
|
||||
db = SQLiteDatabase("data/data_v3.db")
|
||||
log_broker = LogBroker()
|
||||
core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
|
||||
return core_lifecycle
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def app(core_lifecycle):
|
||||
db = SQLiteDatabase("data/data_v3.db")
|
||||
server = AstrBotDashboard(core_lifecycle, db)
|
||||
return server.app
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def header():
|
||||
return {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_core_lifecycle(core_lifecycle):
|
||||
await core_lifecycle.initialize()
|
||||
assert core_lifecycle is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_login(app: Quart, core_lifecycle: AstrBotCoreLifecycle, header: dict):
|
||||
test_client = app.test_client()
|
||||
response = await test_client.post('/api/auth/login', json={
|
||||
"username": "wrong",
|
||||
"password": "password"
|
||||
})
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'error'
|
||||
|
||||
response = await test_client.post('/api/auth/login', json={
|
||||
"username": core_lifecycle.astrbot_config['dashboard']['username'],
|
||||
"password": core_lifecycle.astrbot_config['dashboard']['password']
|
||||
})
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok' and 'token' in data['data']
|
||||
header['Authorization'] = f"Bearer {data['data']['token']}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stat(app: Quart, header: dict):
|
||||
test_client = app.test_client()
|
||||
response = await test_client.get('/api/stat/get')
|
||||
assert response.status_code == 401
|
||||
response = await test_client.get('/api/stat/get', headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok' and 'platform' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugins(app: Quart, header: dict):
|
||||
test_client = app.test_client()
|
||||
# 已经安装的插件
|
||||
response = await test_client.get('/api/plugin/get', headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
|
||||
# 插件市场
|
||||
response = await test_client.get('/api/plugin/market_list', headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
|
||||
# 插件安装
|
||||
response = await test_client.post('/api/plugin/install', json={
|
||||
"url": "https://github.com/Soulter/astrbot_plugin_essential"
|
||||
}, headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
exists = False
|
||||
for md in star_registry:
|
||||
if md.name == "astrbot_plugin_essential":
|
||||
exists = True
|
||||
break
|
||||
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
|
||||
|
||||
# 插件更新
|
||||
response = await test_client.post('/api/plugin/update', json={
|
||||
"name": "astrbot_plugin_essential"
|
||||
}, headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
|
||||
# 插件卸载
|
||||
response = await test_client.post('/api/plugin/uninstall', json={
|
||||
"name": "astrbot_plugin_essential"
|
||||
}, headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
exists = False
|
||||
for md in star_registry:
|
||||
if md.name == "astrbot_plugin_essential":
|
||||
exists = True
|
||||
break
|
||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
||||
exists = False
|
||||
for md in star_handlers_registry:
|
||||
if "astrbot_plugin_essential" in md.handler_module_path:
|
||||
exists = True
|
||||
break
|
||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_update(app: Quart, header: dict):
|
||||
test_client = app.test_client()
|
||||
response = await test_client.get('/api/update/check', headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'success'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_do_update(app: Quart, header: dict, core_lifecycle: AstrBotCoreLifecycle):
|
||||
global VERSION
|
||||
test_client = app.test_client()
|
||||
os.makedirs("data/astrbot_release", exist_ok=True)
|
||||
core_lifecycle.astrbot_updator.MAIN_PATH = "data/astrbot_release"
|
||||
VERSION = "114.514.1919810"
|
||||
response = await test_client.post('/api/update/do', headers=header, json={
|
||||
"version": "latest"
|
||||
})
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'error' # 已经是最新版本
|
||||
|
||||
response = await test_client.post('/api/update/do', headers=header, json={
|
||||
"version": "v3.4.0",
|
||||
"reboot": False
|
||||
})
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
assert os.path.exists("data/astrbot_release/astrbot")
|
||||
+15
-8
@@ -1,4 +1,6 @@
|
||||
import pytest, logging, os
|
||||
import pytest
|
||||
import logging
|
||||
import os
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
@@ -11,7 +13,6 @@ from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core import logger
|
||||
from asyncio import Queue
|
||||
|
||||
SESSION_ID_IN_WHITELIST = "test_sid_wl"
|
||||
@@ -29,7 +30,7 @@ TEST_LLM_PROVIDER = {
|
||||
|
||||
TEST_COMMANDS = [
|
||||
["help", "已注册的 AstrBot 内置指令"],
|
||||
["tool ls", "查看、激活、停用当前注册的函数工具"],
|
||||
["tool ls", "函数工具"],
|
||||
["tool on websearch", "激活工具"],
|
||||
["tool off websearch", "停用工具"],
|
||||
["plugin", "已加载的插件"],
|
||||
@@ -145,6 +146,7 @@ async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineSch
|
||||
async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
'''测试唤醒'''
|
||||
# 群聊无 @ 无指令
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
@@ -161,19 +163,21 @@ async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_wl(pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", SESSION_ID_IN_WHITELIST, sender_id="123")
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("不在会话白名单中,已终止事件传播。" not in message for message in caplog.messages), "日志中未找到预期的消息"
|
||||
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123")
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("不在会话白名单中,已终止事件传播。" in message for message in caplog.messages), "日志中未找到预期的消息"
|
||||
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", SESSION_ID_IN_WHITELIST, sender_id="123")
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("不在会话白名单中,已终止事件传播。" not in message for message in caplog.messages), "日志中未找到预期的消息"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
# 测试默认屏蔽词
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("色情", session_id=SESSION_ID_IN_WHITELIST) # 测试需要。
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
@@ -192,6 +196,7 @@ async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, ca
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
@@ -201,6 +206,7 @@ async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
@@ -210,6 +216,7 @@ async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog)
|
||||
@pytest.mark.asyncio
|
||||
async def test_commands(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
for command in TEST_COMMANDS:
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(command[0], session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import pytest
|
||||
import os
|
||||
import shutil
|
||||
from astrbot.core.star.star_manager import PluginManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
@@ -34,7 +33,8 @@ def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
|
||||
assert plugin_manager_pm.context is not None
|
||||
assert plugin_manager_pm.config is not None
|
||||
|
||||
def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
|
||||
success, err_message = plugin_manager_pm.reload()
|
||||
assert success is True
|
||||
assert err_message is None
|
||||
@@ -86,8 +86,3 @@ async def test_plugin_crud(plugin_manager_pm: PluginManager):
|
||||
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha")
|
||||
|
||||
# TODO: file installation
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user