fix: unit tests (#2760)
* fix:修复了main和plugin_manager部分单元测试 * fix: 修复了dashboard部分测试 * remove: 删除暂无用的配置测试脚本 * perf:拆分插件增查删改为独立的单元测试 * refactor: 重构插件管理器测试,使用临时环境隔离测试实例 * test: 增加对仪表板文件检查的单元测试,涵盖不同情况 * style: format code * remove: 删除未使用的导入语句 * delete: remove unused test file for pipeline --------- Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
+79
-43
@@ -1,5 +1,7 @@
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import os
|
||||
import asyncio
|
||||
from quart import Quart
|
||||
from astrbot.dashboard.server import AstrBotDashboard
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
@@ -9,36 +11,46 @@ from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def core_lifecycle_td():
|
||||
db = SQLiteDatabase("data/data_v3.db")
|
||||
@pytest_asyncio.fixture(scope="module")
|
||||
async def core_lifecycle_td(tmp_path_factory):
|
||||
"""Creates and initializes a core lifecycle instance with a temporary database."""
|
||||
tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_v3.db"
|
||||
db = SQLiteDatabase(str(tmp_db_path))
|
||||
log_broker = LogBroker()
|
||||
core_lifecycle_td = AstrBotCoreLifecycle(log_broker, db)
|
||||
return core_lifecycle_td
|
||||
core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
|
||||
await core_lifecycle.initialize()
|
||||
return core_lifecycle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def app(core_lifecycle_td):
|
||||
db = SQLiteDatabase("data/data_v3.db")
|
||||
server = AstrBotDashboard(core_lifecycle_td, db)
|
||||
def app(core_lifecycle_td: AstrBotCoreLifecycle):
|
||||
"""Creates a Quart app instance for testing."""
|
||||
shutdown_event = asyncio.Event()
|
||||
# The db instance is already part of the core_lifecycle_td
|
||||
server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
|
||||
return server.app
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def header():
|
||||
return {}
|
||||
@pytest_asyncio.fixture(scope="module")
|
||||
async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
|
||||
"""Handles login and returns an authenticated header."""
|
||||
test_client = app.test_client()
|
||||
response = await test_client.post(
|
||||
"/api/auth/login",
|
||||
json={
|
||||
"username": core_lifecycle_td.astrbot_config["dashboard"]["username"],
|
||||
"password": core_lifecycle_td.astrbot_config["dashboard"]["password"],
|
||||
},
|
||||
)
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
token = data["data"]["token"]
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_core_lifecycle_td(core_lifecycle_td):
|
||||
await core_lifecycle_td.initialize()
|
||||
assert core_lifecycle_td is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_login(
|
||||
app: Quart, core_lifecycle_td: AstrBotCoreLifecycle, header: dict
|
||||
):
|
||||
async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
|
||||
"""Tests the login functionality with both wrong and correct credentials."""
|
||||
test_client = app.test_client()
|
||||
response = await test_client.post(
|
||||
"/api/auth/login", json={"username": "wrong", "password": "password"}
|
||||
@@ -55,31 +67,32 @@ async def test_auth_login(
|
||||
)
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok" and "token" in data["data"]
|
||||
header["Authorization"] = f"Bearer {data['data']['token']}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stat(app: Quart, header: dict):
|
||||
async def test_get_stat(app: Quart, authenticated_header: dict):
|
||||
test_client = app.test_client()
|
||||
response = await test_client.get("/api/stat/get")
|
||||
assert response.status_code == 401
|
||||
response = await test_client.get("/api/stat/get", headers=header)
|
||||
response = await test_client.get("/api/stat/get", headers=authenticated_header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok" and "platform" in data["data"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugins(app: Quart, header: dict):
|
||||
async def test_plugins(app: Quart, authenticated_header: dict):
|
||||
test_client = app.test_client()
|
||||
# 已经安装的插件
|
||||
response = await test_client.get("/api/plugin/get", headers=header)
|
||||
response = await test_client.get("/api/plugin/get", headers=authenticated_header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
|
||||
# 插件市场
|
||||
response = await test_client.get("/api/plugin/market_list", headers=header)
|
||||
response = await test_client.get(
|
||||
"/api/plugin/market_list", headers=authenticated_header
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
@@ -88,7 +101,7 @@ async def test_plugins(app: Quart, header: dict):
|
||||
response = await test_client.post(
|
||||
"/api/plugin/install",
|
||||
json={"url": "https://github.com/Soulter/astrbot_plugin_essential"},
|
||||
headers=header,
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
@@ -102,7 +115,9 @@ async def test_plugins(app: Quart, header: dict):
|
||||
|
||||
# 插件更新
|
||||
response = await test_client.post(
|
||||
"/api/plugin/update", json={"name": "astrbot_plugin_essential"}, headers=header
|
||||
"/api/plugin/update",
|
||||
json={"name": "astrbot_plugin_essential"},
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
@@ -112,7 +127,7 @@ async def test_plugins(app: Quart, header: dict):
|
||||
response = await test_client.post(
|
||||
"/api/plugin/uninstall",
|
||||
json={"name": "astrbot_plugin_essential"},
|
||||
headers=header,
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
@@ -132,9 +147,9 @@ async def test_plugins(app: Quart, header: dict):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_update(app: Quart, header: dict):
|
||||
async def test_check_update(app: Quart, authenticated_header: dict):
|
||||
test_client = app.test_client()
|
||||
response = await test_client.get("/api/update/check", headers=header)
|
||||
response = await test_client.get("/api/update/check", headers=authenticated_header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "success"
|
||||
@@ -142,24 +157,45 @@ async def test_check_update(app: Quart, header: dict):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_do_update(
|
||||
app: Quart, header: dict, core_lifecycle_td: AstrBotCoreLifecycle
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
monkeypatch,
|
||||
tmp_path_factory,
|
||||
):
|
||||
global VERSION
|
||||
test_client = app.test_client()
|
||||
os.makedirs("data/astrbot_release", exist_ok=True)
|
||||
core_lifecycle_td.astrbot_updator.MAIN_PATH = "data/astrbot_release"
|
||||
VERSION = "114.514.1919810"
|
||||
response = await test_client.post(
|
||||
"/api/update/do", headers=header, json={"version": "latest"}
|
||||
|
||||
# Use a temporary path for the mock update to avoid side effects
|
||||
temp_release_dir = tmp_path_factory.mktemp("release")
|
||||
release_path = temp_release_dir / "astrbot"
|
||||
|
||||
async def mock_update(*args, **kwargs):
|
||||
"""Mocks the update process by creating a directory in the temp path."""
|
||||
os.makedirs(release_path, exist_ok=True)
|
||||
return
|
||||
|
||||
async def mock_download_dashboard(*args, **kwargs):
|
||||
"""Mocks the dashboard download to prevent network access."""
|
||||
return
|
||||
|
||||
async def mock_pip_install(*args, **kwargs):
|
||||
"""Mocks pip install to prevent actual installation."""
|
||||
return
|
||||
|
||||
monkeypatch.setattr(core_lifecycle_td.astrbot_updator, "update", mock_update)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.update.download_dashboard", mock_download_dashboard
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.update.pip_installer.install", mock_pip_install
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "error" # 已经是最新版本
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/update/do", headers=header, json={"version": "v3.4.0", "reboot": False}
|
||||
"/api/update/do",
|
||||
headers=authenticated_header,
|
||||
json={"version": "v3.4.0", "reboot": False},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert os.path.exists("data/astrbot_release/astrbot")
|
||||
assert os.path.exists(release_path)
|
||||
|
||||
+51
-18
@@ -1,5 +1,9 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 将项目根目录添加到 sys.path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
import pytest
|
||||
from unittest import mock
|
||||
from main import check_env, check_dashboard_files
|
||||
@@ -27,29 +31,58 @@ def test_check_env(monkeypatch):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_dashboard_files(monkeypatch):
|
||||
async def test_check_dashboard_files_not_exists(monkeypatch):
|
||||
"""Tests dashboard download when files do not exist."""
|
||||
monkeypatch.setattr(os.path, "exists", lambda x: False)
|
||||
|
||||
async def mock_get(*args, **kwargs):
|
||||
class MockResponse:
|
||||
status = 200
|
||||
with mock.patch("main.download_dashboard") as mock_download:
|
||||
await check_dashboard_files()
|
||||
mock_download.assert_called_once()
|
||||
|
||||
async def read(self):
|
||||
return b"content"
|
||||
|
||||
return MockResponse()
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_dashboard_files_exists_and_version_match(monkeypatch):
|
||||
"""Tests that dashboard is not downloaded when it exists and version matches."""
|
||||
# Mock os.path.exists to return True
|
||||
monkeypatch.setattr(os.path, "exists", lambda x: True)
|
||||
|
||||
with mock.patch("aiohttp.ClientSession.get", new=mock_get):
|
||||
with mock.patch("builtins.open", mock.mock_open()) as mock_file:
|
||||
with mock.patch("zipfile.ZipFile.extractall") as mock_extractall:
|
||||
# Mock get_dashboard_version to return the current version
|
||||
with mock.patch("main.get_dashboard_version") as mock_get_version:
|
||||
# We need to import VERSION from main's context
|
||||
from main import VERSION
|
||||
|
||||
async def mock_aenter(_):
|
||||
await check_dashboard_files()
|
||||
mock_file.assert_called_once_with("data/dashboard.zip", "wb")
|
||||
mock_extractall.assert_called_once()
|
||||
mock_get_version.return_value = f"v{VERSION}"
|
||||
|
||||
async def mock_aexit(obj, exc_type, exc, tb):
|
||||
return
|
||||
with mock.patch("main.download_dashboard") as mock_download:
|
||||
await check_dashboard_files()
|
||||
# Assert that download_dashboard was NOT called
|
||||
mock_download.assert_not_called()
|
||||
|
||||
mock_extractall.__aenter__ = mock_aenter
|
||||
mock_extractall.__aexit__ = mock_aexit
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch):
|
||||
"""Tests that a warning is logged when dashboard version mismatches."""
|
||||
monkeypatch.setattr(os.path, "exists", lambda x: True)
|
||||
|
||||
with mock.patch("main.get_dashboard_version") as mock_get_version:
|
||||
mock_get_version.return_value = "v0.0.1" # A different version
|
||||
|
||||
with mock.patch("main.logger.warning") as mock_logger_warning:
|
||||
await check_dashboard_files()
|
||||
mock_logger_warning.assert_called_once()
|
||||
call_args, _ = mock_logger_warning.call_args
|
||||
assert "不符" in call_args[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch):
|
||||
"""Tests that providing a valid webui_dir skips all checks."""
|
||||
valid_dir = "/tmp/my-custom-webui"
|
||||
monkeypatch.setattr(os.path, "exists", lambda path: path == valid_dir)
|
||||
|
||||
with mock.patch("main.download_dashboard") as mock_download:
|
||||
with mock.patch("main.get_dashboard_version") as mock_get_version:
|
||||
result = await check_dashboard_files(webui_dir=valid_dir)
|
||||
assert result == valid_dir
|
||||
mock_download.assert_not_called()
|
||||
mock_get_version.assert_not_called()
|
||||
|
||||
@@ -1,285 +0,0 @@
|
||||
import pytest
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.astrbot_message import (
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
MessageType,
|
||||
)
|
||||
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||
from astrbot.core.message.components import Plain, At
|
||||
from astrbot.core.platform.platform_metadata import PlatformMetadata
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.star.context import Context
|
||||
from asyncio import Queue
|
||||
|
||||
SESSION_ID_IN_WHITELIST = "test_sid_wl"
|
||||
SESSION_ID_NOT_IN_WHITELIST = "test_sid"
|
||||
TEST_LLM_PROVIDER = {
|
||||
"id": "zhipu_default",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [os.getenv("ZHIPU_API_KEY")],
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"model_config": {
|
||||
"model": "glm-4-flash",
|
||||
},
|
||||
}
|
||||
|
||||
TEST_COMMANDS = [
|
||||
["help", "已注册的 AstrBot 内置指令"],
|
||||
["tool ls", "函数工具"],
|
||||
["tool on websearch", "激活工具"],
|
||||
["tool off websearch", "停用工具"],
|
||||
["plugin", "已加载的插件"],
|
||||
["t2i", "文本转图片模式"],
|
||||
["sid", "此 ID 可用于设置会话白名单。"],
|
||||
["op test_op", "授权成功。"],
|
||||
["deop test_op", "取消授权成功。"],
|
||||
["wl test_platform:FriendMessage:test_sid_wl2", "添加白名单成功。"],
|
||||
["dwl test_platform:FriendMessage:test_sid_wl2", "删除白名单成功。"],
|
||||
["provider", "当前载入的 LLM 提供商"],
|
||||
["reset", "重置成功"],
|
||||
# ["model", "查看、切换提供商模型列表"],
|
||||
["history", "历史记录:"],
|
||||
["key", "当前 Key"],
|
||||
["persona", "[Persona]"],
|
||||
]
|
||||
|
||||
|
||||
class FakeAstrMessageEvent(AstrMessageEvent):
|
||||
def __init__(self, abm: AstrBotMessage = None):
|
||||
meta = PlatformMetadata("test_platform", "test")
|
||||
super().__init__(
|
||||
message_str=abm.message_str,
|
||||
message_obj=abm,
|
||||
platform_meta=meta,
|
||||
session_id=abm.session_id,
|
||||
)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
await super().send(message)
|
||||
|
||||
@staticmethod
|
||||
def create_fake_event(
|
||||
message_str: str,
|
||||
session_id: str = "test_sid",
|
||||
is_at: bool = False,
|
||||
is_group: bool = False,
|
||||
sender_id: str = "123456",
|
||||
):
|
||||
abm = AstrBotMessage()
|
||||
abm.message_str = message_str
|
||||
abm.group_id = "test"
|
||||
abm.message = [Plain(message_str)]
|
||||
if is_at:
|
||||
abm.message.append(At(qq="bot"))
|
||||
abm.self_id = "bot"
|
||||
abm.sender = MessageMember(sender_id, "mika")
|
||||
abm.timestamp = 1234567890
|
||||
abm.message_id = "test"
|
||||
abm.session_id = session_id
|
||||
if is_group:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
return FakeAstrMessageEvent(abm)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_queue():
|
||||
return Queue()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def config():
|
||||
cfg = AstrBotConfig()
|
||||
cfg["platform_settings"]["id_whitelist"] = [
|
||||
"test_platform:FriendMessage:test_sid_wl",
|
||||
"test_platform:GroupMessage:test_sid_wl",
|
||||
]
|
||||
cfg["admins_id"] = ["123456"]
|
||||
cfg["content_safety"]["internal_keywords"]["extra_keywords"] = ["^TEST_NEGATIVE"]
|
||||
cfg["provider"] = [TEST_LLM_PROVIDER]
|
||||
return cfg
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def db():
|
||||
return SQLiteDatabase("data/data_v3.db")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def platform_manager(event_queue, config):
|
||||
return PlatformManager(config, event_queue)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def provider_manager(config, db):
|
||||
return ProviderManager(config, db)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def star_context(event_queue, config, db, platform_manager, provider_manager):
|
||||
star_context = Context(event_queue, config, db, provider_manager, platform_manager)
|
||||
return star_context
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def plugin_manager(star_context, config):
|
||||
plugin_manager = PluginManager(star_context, config)
|
||||
# await plugin_manager.reload()
|
||||
asyncio.run(plugin_manager.reload())
|
||||
return plugin_manager
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def pipeline_context(config, plugin_manager):
|
||||
return PipelineContext(config, plugin_manager)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def pipeline_scheduler(pipeline_context):
|
||||
return PipelineScheduler(pipeline_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_initialization(platform_manager: PlatformManager):
|
||||
await platform_manager.initialize()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_initialization(provider_manager: ProviderManager):
|
||||
await provider_manager.initialize()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineScheduler):
|
||||
await pipeline_scheduler.initialize()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
"""测试唤醒"""
|
||||
# 群聊无 @ 无指令
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any(
|
||||
"执行阶段 WhitelistCheckStage" not in message for message in caplog.messages
|
||||
)
|
||||
# 群聊有 @ 无指令
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"test", is_group=True, is_at=True
|
||||
)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("执行阶段 WhitelistCheckStage" in message for message in caplog.messages)
|
||||
# 群聊有指令
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert mock_event._has_send_oper is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_wl(
|
||||
pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog
|
||||
):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"test", SESSION_ID_IN_WHITELIST, sender_id="123"
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any(
|
||||
"不在会话白名单中,已终止事件传播。" not in message
|
||||
for message in caplog.messages
|
||||
), "日志中未找到预期的消息"
|
||||
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123")
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any(
|
||||
"不在会话白名单中,已终止事件传播。" in message for message in caplog.messages
|
||||
), "日志中未找到预期的消息"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
# 测试默认屏蔽词
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"色情", session_id=SESSION_ID_IN_WHITELIST
|
||||
) # 测试需要。
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" in message for message in caplog.messages), (
|
||||
"日志中未找到预期的消息"
|
||||
)
|
||||
# 测试额外屏蔽词
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" in message for message in caplog.messages), (
|
||||
"日志中未找到预期的消息"
|
||||
)
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" not in message for message in caplog.messages)
|
||||
# TODO: 测试 百度AI 的内容安全检查
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("请求 LLM" in message for message in caplog.messages)
|
||||
assert mock_event.get_result() is not None
|
||||
assert mock_event.get_result().result_content_type == ResultContentType.LLM_RESULT
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("请求 LLM" in message for message in caplog.messages)
|
||||
assert any(
|
||||
"web_searcher - search_from_search_engine" in message
|
||||
for message in caplog.messages
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commands(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
for command in TEST_COMMANDS:
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
command[0], session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
# assert any("执行阶段 ProcessStage" in message for message in caplog.messages)
|
||||
assert any(command[1] in message for message in caplog.messages)
|
||||
+105
-43
@@ -1,5 +1,6 @@
|
||||
import pytest
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
from astrbot.core.star.star_manager import PluginManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
@@ -8,18 +9,51 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from asyncio import Queue
|
||||
|
||||
event_queue = Queue()
|
||||
|
||||
config = AstrBotConfig()
|
||||
|
||||
db = SQLiteDatabase("data/data_v3.db")
|
||||
|
||||
star_context = Context(event_queue, config, db)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def plugin_manager_pm():
|
||||
return PluginManager(star_context, config)
|
||||
def plugin_manager_pm(tmp_path):
|
||||
"""
|
||||
Provides a fully isolated PluginManager instance for testing.
|
||||
- Uses a temporary directory for plugins.
|
||||
- Uses a temporary database.
|
||||
- Creates a fresh context for each test.
|
||||
"""
|
||||
# Create temporary resources
|
||||
temp_plugins_path = tmp_path / "plugins"
|
||||
temp_plugins_path.mkdir()
|
||||
temp_db_path = tmp_path / "test_db.db"
|
||||
|
||||
# Create fresh, isolated instances for the context
|
||||
event_queue = Queue()
|
||||
config = AstrBotConfig()
|
||||
db = SQLiteDatabase(str(temp_db_path))
|
||||
|
||||
# Set the plugin store path in the config to the temporary directory
|
||||
config.plugin_store_path = str(temp_plugins_path)
|
||||
|
||||
# Mock dependencies for the context
|
||||
provider_manager = MagicMock()
|
||||
platform_manager = MagicMock()
|
||||
conversation_manager = MagicMock()
|
||||
message_history_manager = MagicMock()
|
||||
persona_manager = MagicMock()
|
||||
astrbot_config_mgr = MagicMock()
|
||||
|
||||
star_context = Context(
|
||||
event_queue,
|
||||
config,
|
||||
db,
|
||||
provider_manager,
|
||||
platform_manager,
|
||||
conversation_manager,
|
||||
message_history_manager,
|
||||
persona_manager,
|
||||
astrbot_config_mgr,
|
||||
)
|
||||
|
||||
# Create the PluginManager instance
|
||||
manager = PluginManager(star_context, config)
|
||||
yield manager
|
||||
|
||||
|
||||
def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
|
||||
@@ -36,48 +70,76 @@ async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_crud(plugin_manager_pm: PluginManager):
|
||||
"""测试插件安装和重载"""
|
||||
os.makedirs("data/plugins", exist_ok=True)
|
||||
async def test_install_plugin(plugin_manager_pm: PluginManager):
|
||||
"""Tests successful plugin installation in an isolated environment."""
|
||||
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
|
||||
plugin_path = await plugin_manager_pm.install_plugin(test_repo)
|
||||
exists = False
|
||||
for md in star_registry:
|
||||
if md.name == "astrbot_plugin_essential":
|
||||
exists = True
|
||||
break
|
||||
assert plugin_path is not None
|
||||
plugin_info = await plugin_manager_pm.install_plugin(test_repo)
|
||||
plugin_path = os.path.join(
|
||||
plugin_manager_pm.plugin_store_path, "astrbot_plugin_essential"
|
||||
)
|
||||
|
||||
assert plugin_info is not None
|
||||
assert os.path.exists(plugin_path)
|
||||
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
|
||||
# shutil.rmtree(plugin_path)
|
||||
assert any(md.name == "astrbot_plugin_essential" for md in star_registry), (
|
||||
"Plugin 'astrbot_plugin_essential' was not loaded into star_registry."
|
||||
)
|
||||
|
||||
# install plugin which is not exists
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager):
|
||||
"""Tests that installing a non-existent plugin raises an exception."""
|
||||
with pytest.raises(Exception):
|
||||
plugin_path = await plugin_manager_pm.install_plugin(test_repo + "haha")
|
||||
await plugin_manager_pm.install_plugin(
|
||||
"https://github.com/Soulter/non_existent_repo"
|
||||
)
|
||||
|
||||
# update
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_plugin(plugin_manager_pm: PluginManager):
|
||||
"""Tests updating an existing plugin in an isolated environment."""
|
||||
# First, install the plugin
|
||||
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
|
||||
await plugin_manager_pm.install_plugin(test_repo)
|
||||
|
||||
# Then, update it
|
||||
await plugin_manager_pm.update_plugin("astrbot_plugin_essential")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await plugin_manager_pm.update_plugin("astrbot_plugin_essentialhaha")
|
||||
|
||||
# uninstall
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_nonexistent_plugin(plugin_manager_pm: PluginManager):
|
||||
"""Tests that updating a non-existent plugin raises an exception."""
|
||||
with pytest.raises(Exception):
|
||||
await plugin_manager_pm.update_plugin("non_existent_plugin")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uninstall_plugin(plugin_manager_pm: PluginManager):
|
||||
"""Tests successful plugin uninstallation in an isolated environment."""
|
||||
# First, install the plugin
|
||||
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
|
||||
await plugin_manager_pm.install_plugin(test_repo)
|
||||
plugin_path = os.path.join(
|
||||
plugin_manager_pm.plugin_store_path, "astrbot_plugin_essential"
|
||||
)
|
||||
assert os.path.exists(plugin_path) # Pre-condition
|
||||
|
||||
# Then, uninstall it
|
||||
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential")
|
||||
|
||||
assert not os.path.exists(plugin_path)
|
||||
exists = False
|
||||
for md in star_registry:
|
||||
if md.name == "astrbot_plugin_essential":
|
||||
exists = True
|
||||
break
|
||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
||||
exists = False
|
||||
for md in star_handlers_registry:
|
||||
if "astrbot_plugin_essential" in md.handler_module_path:
|
||||
exists = True
|
||||
break
|
||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
||||
assert not any(md.name == "astrbot_plugin_essential" for md in star_registry), (
|
||||
"Plugin 'astrbot_plugin_essential' was not unloaded from star_registry."
|
||||
)
|
||||
assert not any(
|
||||
"astrbot_plugin_essential" in md.handler_module_path
|
||||
for md in star_handlers_registry
|
||||
), (
|
||||
"Plugin 'astrbot_plugin_essential' handler was not unloaded from star_handlers_registry."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uninstall_nonexistent_plugin(plugin_manager_pm: PluginManager):
|
||||
"""Tests that uninstalling a non-existent plugin raises an exception."""
|
||||
with pytest.raises(Exception):
|
||||
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha")
|
||||
|
||||
# TODO: file installation
|
||||
await plugin_manager_pm.uninstall_plugin("non_existent_plugin")
|
||||
|
||||
Reference in New Issue
Block a user