test: dashboard test

This commit is contained in:
Soulter
2024-12-26 22:59:10 +08:00
parent e6205e9aad
commit b72c69892e
10 changed files with 183 additions and 27 deletions
-1
View File
@@ -44,7 +44,6 @@ class Stage(abc.ABC):
try:
ready_to_call = handler(event, **params)
except TypeError as e:
print(e)
# 向下兼容
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
@@ -58,7 +58,7 @@ class ProviderZhipu(ProviderOpenAIOfficial):
payloads = {
"messages": context_query,
**model_cfgs
}
}
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
+6
View File
@@ -140,6 +140,11 @@ class PluginManager:
def reload(self):
'''扫描并加载所有的 Star'''
for smd in star_registry:
logger.debug(f"尝试终止插件 {smd.name} ...")
if hasattr(smd.star_cls, "__del__"):
smd.star_cls.__del__()
star_handlers_registry.clear()
star_handlers_registry.star_handlers_map.clear()
star_map.clear()
@@ -272,6 +277,7 @@ class PluginManager:
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
await self.updator.update(plugin)
self.reload()
def install_plugin_from_file(self, zip_file_path: str):
desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path))
+1 -2
View File
@@ -53,7 +53,6 @@ class PluginUpdator(RepoZipUpdator):
files = os.listdir(os.path.join(target_dir, update_dir))
for f in files:
logger.info(f"移动更新文件/目录: {f}")
if os.path.isdir(os.path.join(target_dir, update_dir, f)):
if os.path.exists(os.path.join(target_dir, f)):
shutil.rmtree(os.path.join(target_dir, f), onerror=on_error)
@@ -63,7 +62,7 @@ class PluginUpdator(RepoZipUpdator):
shutil.move(os.path.join(target_dir, update_dir, f), target_dir)
try:
logger.info(f"删除临时更新文件: {zip_path}{os.path.join(target_dir, update_dir)}")
logger.info(f"删除临时文件: {zip_path}{os.path.join(target_dir, update_dir)}")
shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error)
os.remove(zip_path)
except BaseException:
+1 -1
View File
@@ -111,7 +111,7 @@ class RepoZipUpdator():
releases = await self.fetch_release_info(url=release_url)
if not releases:
# download from the default branch directly.
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,从默认分支下载。")
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
else:
release_url = releases[0]['zipball_url']
+2 -5
View File
@@ -53,7 +53,6 @@ class PluginRoute(Route):
try:
logger.info(f"正在安装插件 {repo_url}")
await self.plugin_manager.install_plugin(repo_url)
self.core_lifecycle.restart()
logger.info(f"安装插件 {repo_url} 成功。")
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
@@ -69,7 +68,6 @@ class PluginRoute(Route):
await file.save(file_path)
self.plugin_manager.install_plugin_from_file(file_path)
logger.info(f"安装插件 {file.filename} 成功")
self.core_lifecycle.restart()
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
@@ -93,9 +91,8 @@ class PluginRoute(Route):
try:
logger.info(f"正在更新插件 {plugin_name}")
await self.plugin_manager.update_plugin(plugin_name)
self.core_lifecycle.restart()
logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
return Response().ok(None, "更新成功,程序将在 2 秒内重启。").__dict__
logger.info(f"更新插件 {plugin_name} 成功。")
return Response().ok(None, "更新成功。").__dict__
except Exception as e:
logger.error(f"/api/extensions/update: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
+6 -2
View File
@@ -32,6 +32,7 @@ class UpdateRoute(Route):
async def update_project(self):
data = await request.json
version = data.get('version', '')
reboot = data.get('reboot', True)
if version == "" or version == "latest":
latest = True
version = ''
@@ -39,8 +40,11 @@ class UpdateRoute(Route):
latest = False
try:
await self.astrbot_updator.update(latest=latest, version=version)
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
if reboot:
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
else:
return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__
except Exception as e:
logger.error(f"/api/update_project: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
+149
View File
@@ -0,0 +1,149 @@
import pytest
import os
from quart import Quart
from astrbot.dashboard.server import AstrBotDashboard
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core import LogBroker
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star import star_registry
from astrbot.core.updator import AstrBotUpdator
@pytest.fixture(scope="module")
def core_lifecycle():
db = SQLiteDatabase("data/data_v3.db")
log_broker = LogBroker()
core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
return core_lifecycle
@pytest.fixture(scope="module")
def app(core_lifecycle):
db = SQLiteDatabase("data/data_v3.db")
server = AstrBotDashboard(core_lifecycle, db)
return server.app
@pytest.fixture(scope="module")
def header():
return {}
@pytest.mark.asyncio
async def test_init_core_lifecycle(core_lifecycle):
await core_lifecycle.initialize()
assert core_lifecycle is not None
@pytest.mark.asyncio
async def test_auth_login(app: Quart, core_lifecycle: AstrBotCoreLifecycle, header: dict):
test_client = app.test_client()
response = await test_client.post('/api/auth/login', json={
"username": "wrong",
"password": "password"
})
data = await response.get_json()
assert data['status'] == 'error'
response = await test_client.post('/api/auth/login', json={
"username": core_lifecycle.astrbot_config['dashboard']['username'],
"password": core_lifecycle.astrbot_config['dashboard']['password']
})
data = await response.get_json()
assert data['status'] == 'ok' and 'token' in data['data']
header['Authorization'] = f"Bearer {data['data']['token']}"
@pytest.mark.asyncio
async def test_get_stat(app: Quart, header: dict):
test_client = app.test_client()
response = await test_client.get('/api/stat/get')
assert response.status_code == 401
response = await test_client.get('/api/stat/get', headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok' and 'platform' in data['data']
@pytest.mark.asyncio
async def test_plugins(app: Quart, header: dict):
test_client = app.test_client()
# 已经安装的插件
response = await test_client.get('/api/plugin/get', headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
# 插件市场
response = await test_client.get('/api/plugin/market_list', headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
# 插件安装
response = await test_client.post('/api/plugin/install', json={
"url": "https://github.com/Soulter/astrbot_plugin_essential"
}, headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
# 插件更新
response = await test_client.post('/api/plugin/update', json={
"name": "astrbot_plugin_essential"
}, headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
# 插件卸载
response = await test_client.post('/api/plugin/uninstall', json={
"name": "astrbot_plugin_essential"
}, headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
exists = False
for md in star_handlers_registry:
if "astrbot_plugin_essential" in md.handler_module_path:
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
@pytest.mark.asyncio
async def test_check_update(app: Quart, header: dict):
test_client = app.test_client()
response = await test_client.get('/api/update/check', headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'success'
@pytest.mark.asyncio
async def test_do_update(app: Quart, header: dict, core_lifecycle: AstrBotCoreLifecycle):
global VERSION
test_client = app.test_client()
os.makedirs("data/astrbot_release", exist_ok=True)
core_lifecycle.astrbot_updator.MAIN_PATH = "data/astrbot_release"
VERSION = "114.514.1919810"
response = await test_client.post('/api/update/do', headers=header, json={
"version": "latest"
})
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'error' # 已经是最新版本
response = await test_client.post('/api/update/do', headers=header, json={
"version": "v3.4.0",
"reboot": False
})
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
assert os.path.exists("data/astrbot_release/astrbot")
+15 -8
View File
@@ -1,4 +1,6 @@
import pytest, logging, os
import pytest
import logging
import os
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.config.astrbot_config import AstrBotConfig
@@ -11,7 +13,6 @@ from astrbot.core.platform.manager import PlatformManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.star.context import Context
from astrbot.core import logger
from asyncio import Queue
SESSION_ID_IN_WHITELIST = "test_sid_wl"
@@ -29,7 +30,7 @@ TEST_LLM_PROVIDER = {
TEST_COMMANDS = [
["help", "已注册的 AstrBot 内置指令"],
["tool ls", "查看、激活、停用当前注册的函数工具"],
["tool ls", "函数工具"],
["tool on websearch", "激活工具"],
["tool off websearch", "停用工具"],
["plugin", "已加载的插件"],
@@ -145,6 +146,7 @@ async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineSch
async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
'''测试唤醒'''
# 群聊无 @ 无指令
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
@@ -161,19 +163,21 @@ async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
@pytest.mark.asyncio
async def test_pipeline_wl(pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("test", SESSION_ID_IN_WHITELIST, sender_id="123")
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("不在会话白名单中,已终止事件传播。" not in message for message in caplog.messages), "日志中未找到预期的消息"
mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123")
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("不在会话白名单中,已终止事件传播。" in message for message in caplog.messages), "日志中未找到预期的消息"
mock_event = FakeAstrMessageEvent.create_fake_event("test", SESSION_ID_IN_WHITELIST, sender_id="123")
await pipeline_scheduler.execute(mock_event)
assert any("不在会话白名单中,已终止事件传播。" not in message for message in caplog.messages), "日志中未找到预期的消息"
@pytest.mark.asyncio
async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog):
# 测试默认屏蔽词
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("色情", session_id=SESSION_ID_IN_WHITELIST) # 测试需要。
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
@@ -192,6 +196,7 @@ async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, ca
@pytest.mark.asyncio
async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
@@ -201,6 +206,7 @@ async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
@pytest.mark.asyncio
async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
@@ -210,6 +216,7 @@ async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog)
@pytest.mark.asyncio
async def test_commands(pipeline_scheduler: PipelineScheduler, caplog):
for command in TEST_COMMANDS:
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(command[0], session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
+2 -7
View File
@@ -1,6 +1,5 @@
import pytest
import os
import shutil
from astrbot.core.star.star_manager import PluginManager
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star import star_registry
@@ -34,7 +33,8 @@ def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
assert plugin_manager_pm.context is not None
assert plugin_manager_pm.config is not None
def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
@pytest.mark.asyncio
async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
success, err_message = plugin_manager_pm.reload()
assert success is True
assert err_message is None
@@ -86,8 +86,3 @@ async def test_plugin_crud(plugin_manager_pm: PluginManager):
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha")
# TODO: file installation