fix: unit tests (#2760)

* fix:修复了main和plugin_manager部分单元测试

* fix: 修复了dashboard部分测试

* remove: 删除暂无用的配置测试脚本

* perf:拆分插件增查删改为独立的单元测试

* refactor: 重构插件管理器测试,使用临时环境隔离测试实例

* test: 增加对仪表板文件检查的单元测试,涵盖不同情况

* style: format code

* remove: 删除未使用的导入语句

* delete: remove unused test file for pipeline

---------

Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
RC-CHN
2025-09-27 14:43:04 +08:00
committed by GitHub
parent ccb380ce06
commit 19d7438499
4 changed files with 235 additions and 389 deletions
+79 -43
View File
@@ -1,5 +1,7 @@
import pytest
import pytest_asyncio
import os
import asyncio
from quart import Quart
from astrbot.dashboard.server import AstrBotDashboard
from astrbot.core.db.sqlite import SQLiteDatabase
@@ -9,36 +11,46 @@ from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star import star_registry
@pytest.fixture(scope="module")
def core_lifecycle_td():
db = SQLiteDatabase("data/data_v3.db")
@pytest_asyncio.fixture(scope="module")
async def core_lifecycle_td(tmp_path_factory):
"""Creates and initializes a core lifecycle instance with a temporary database."""
tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_v3.db"
db = SQLiteDatabase(str(tmp_db_path))
log_broker = LogBroker()
core_lifecycle_td = AstrBotCoreLifecycle(log_broker, db)
return core_lifecycle_td
core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
await core_lifecycle.initialize()
return core_lifecycle
@pytest.fixture(scope="module")
def app(core_lifecycle_td):
db = SQLiteDatabase("data/data_v3.db")
server = AstrBotDashboard(core_lifecycle_td, db)
def app(core_lifecycle_td: AstrBotCoreLifecycle):
"""Creates a Quart app instance for testing."""
shutdown_event = asyncio.Event()
# The db instance is already part of the core_lifecycle_td
server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
return server.app
@pytest.fixture(scope="module")
def header():
return {}
@pytest_asyncio.fixture(scope="module")
async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
"""Handles login and returns an authenticated header."""
test_client = app.test_client()
response = await test_client.post(
"/api/auth/login",
json={
"username": core_lifecycle_td.astrbot_config["dashboard"]["username"],
"password": core_lifecycle_td.astrbot_config["dashboard"]["password"],
},
)
data = await response.get_json()
assert data["status"] == "ok"
token = data["data"]["token"]
return {"Authorization": f"Bearer {token}"}
@pytest.mark.asyncio
async def test_init_core_lifecycle_td(core_lifecycle_td):
await core_lifecycle_td.initialize()
assert core_lifecycle_td is not None
@pytest.mark.asyncio
async def test_auth_login(
app: Quart, core_lifecycle_td: AstrBotCoreLifecycle, header: dict
):
async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
"""Tests the login functionality with both wrong and correct credentials."""
test_client = app.test_client()
response = await test_client.post(
"/api/auth/login", json={"username": "wrong", "password": "password"}
@@ -55,31 +67,32 @@ async def test_auth_login(
)
data = await response.get_json()
assert data["status"] == "ok" and "token" in data["data"]
header["Authorization"] = f"Bearer {data['data']['token']}"
@pytest.mark.asyncio
async def test_get_stat(app: Quart, header: dict):
async def test_get_stat(app: Quart, authenticated_header: dict):
test_client = app.test_client()
response = await test_client.get("/api/stat/get")
assert response.status_code == 401
response = await test_client.get("/api/stat/get", headers=header)
response = await test_client.get("/api/stat/get", headers=authenticated_header)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok" and "platform" in data["data"]
@pytest.mark.asyncio
async def test_plugins(app: Quart, header: dict):
async def test_plugins(app: Quart, authenticated_header: dict):
test_client = app.test_client()
# 已经安装的插件
response = await test_client.get("/api/plugin/get", headers=header)
response = await test_client.get("/api/plugin/get", headers=authenticated_header)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
# 插件市场
response = await test_client.get("/api/plugin/market_list", headers=header)
response = await test_client.get(
"/api/plugin/market_list", headers=authenticated_header
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
@@ -88,7 +101,7 @@ async def test_plugins(app: Quart, header: dict):
response = await test_client.post(
"/api/plugin/install",
json={"url": "https://github.com/Soulter/astrbot_plugin_essential"},
headers=header,
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
@@ -102,7 +115,9 @@ async def test_plugins(app: Quart, header: dict):
# 插件更新
response = await test_client.post(
"/api/plugin/update", json={"name": "astrbot_plugin_essential"}, headers=header
"/api/plugin/update",
json={"name": "astrbot_plugin_essential"},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
@@ -112,7 +127,7 @@ async def test_plugins(app: Quart, header: dict):
response = await test_client.post(
"/api/plugin/uninstall",
json={"name": "astrbot_plugin_essential"},
headers=header,
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
@@ -132,9 +147,9 @@ async def test_plugins(app: Quart, header: dict):
@pytest.mark.asyncio
async def test_check_update(app: Quart, header: dict):
async def test_check_update(app: Quart, authenticated_header: dict):
test_client = app.test_client()
response = await test_client.get("/api/update/check", headers=header)
response = await test_client.get("/api/update/check", headers=authenticated_header)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "success"
@@ -142,24 +157,45 @@ async def test_check_update(app: Quart, header: dict):
@pytest.mark.asyncio
async def test_do_update(
app: Quart, header: dict, core_lifecycle_td: AstrBotCoreLifecycle
app: Quart,
authenticated_header: dict,
core_lifecycle_td: AstrBotCoreLifecycle,
monkeypatch,
tmp_path_factory,
):
global VERSION
test_client = app.test_client()
os.makedirs("data/astrbot_release", exist_ok=True)
core_lifecycle_td.astrbot_updator.MAIN_PATH = "data/astrbot_release"
VERSION = "114.514.1919810"
response = await test_client.post(
"/api/update/do", headers=header, json={"version": "latest"}
# Use a temporary path for the mock update to avoid side effects
temp_release_dir = tmp_path_factory.mktemp("release")
release_path = temp_release_dir / "astrbot"
async def mock_update(*args, **kwargs):
"""Mocks the update process by creating a directory in the temp path."""
os.makedirs(release_path, exist_ok=True)
return
async def mock_download_dashboard(*args, **kwargs):
"""Mocks the dashboard download to prevent network access."""
return
async def mock_pip_install(*args, **kwargs):
"""Mocks pip install to prevent actual installation."""
return
monkeypatch.setattr(core_lifecycle_td.astrbot_updator, "update", mock_update)
monkeypatch.setattr(
"astrbot.dashboard.routes.update.download_dashboard", mock_download_dashboard
)
monkeypatch.setattr(
"astrbot.dashboard.routes.update.pip_installer.install", mock_pip_install
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "error" # 已经是最新版本
response = await test_client.post(
"/api/update/do", headers=header, json={"version": "v3.4.0", "reboot": False}
"/api/update/do",
headers=authenticated_header,
json={"version": "v3.4.0", "reboot": False},
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert os.path.exists("data/astrbot_release/astrbot")
assert os.path.exists(release_path)
+51 -18
View File
@@ -1,5 +1,9 @@
import os
import sys
# 将项目根目录添加到 sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import pytest
from unittest import mock
from main import check_env, check_dashboard_files
@@ -27,29 +31,58 @@ def test_check_env(monkeypatch):
@pytest.mark.asyncio
async def test_check_dashboard_files(monkeypatch):
async def test_check_dashboard_files_not_exists(monkeypatch):
"""Tests dashboard download when files do not exist."""
monkeypatch.setattr(os.path, "exists", lambda x: False)
async def mock_get(*args, **kwargs):
class MockResponse:
status = 200
with mock.patch("main.download_dashboard") as mock_download:
await check_dashboard_files()
mock_download.assert_called_once()
async def read(self):
return b"content"
return MockResponse()
@pytest.mark.asyncio
async def test_check_dashboard_files_exists_and_version_match(monkeypatch):
"""Tests that dashboard is not downloaded when it exists and version matches."""
# Mock os.path.exists to return True
monkeypatch.setattr(os.path, "exists", lambda x: True)
with mock.patch("aiohttp.ClientSession.get", new=mock_get):
with mock.patch("builtins.open", mock.mock_open()) as mock_file:
with mock.patch("zipfile.ZipFile.extractall") as mock_extractall:
# Mock get_dashboard_version to return the current version
with mock.patch("main.get_dashboard_version") as mock_get_version:
# We need to import VERSION from main's context
from main import VERSION
async def mock_aenter(_):
await check_dashboard_files()
mock_file.assert_called_once_with("data/dashboard.zip", "wb")
mock_extractall.assert_called_once()
mock_get_version.return_value = f"v{VERSION}"
async def mock_aexit(obj, exc_type, exc, tb):
return
with mock.patch("main.download_dashboard") as mock_download:
await check_dashboard_files()
# Assert that download_dashboard was NOT called
mock_download.assert_not_called()
mock_extractall.__aenter__ = mock_aenter
mock_extractall.__aexit__ = mock_aexit
@pytest.mark.asyncio
async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch):
"""Tests that a warning is logged when dashboard version mismatches."""
monkeypatch.setattr(os.path, "exists", lambda x: True)
with mock.patch("main.get_dashboard_version") as mock_get_version:
mock_get_version.return_value = "v0.0.1" # A different version
with mock.patch("main.logger.warning") as mock_logger_warning:
await check_dashboard_files()
mock_logger_warning.assert_called_once()
call_args, _ = mock_logger_warning.call_args
assert "不符" in call_args[0]
@pytest.mark.asyncio
async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch):
"""Tests that providing a valid webui_dir skips all checks."""
valid_dir = "/tmp/my-custom-webui"
monkeypatch.setattr(os.path, "exists", lambda path: path == valid_dir)
with mock.patch("main.download_dashboard") as mock_download:
with mock.patch("main.get_dashboard_version") as mock_get_version:
result = await check_dashboard_files(webui_dir=valid_dir)
assert result == valid_dir
mock_download.assert_not_called()
mock_get_version.assert_not_called()
-285
View File
@@ -1,285 +0,0 @@
import pytest
import logging
import os
import asyncio
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import (
AstrBotMessage,
MessageMember,
MessageType,
)
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core.message.components import Plain, At
from astrbot.core.platform.platform_metadata import PlatformMetadata
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.star.context import Context
from asyncio import Queue
SESSION_ID_IN_WHITELIST = "test_sid_wl"
SESSION_ID_NOT_IN_WHITELIST = "test_sid"
TEST_LLM_PROVIDER = {
"id": "zhipu_default",
"type": "openai_chat_completion",
"enable": True,
"key": [os.getenv("ZHIPU_API_KEY")],
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
"model_config": {
"model": "glm-4-flash",
},
}
TEST_COMMANDS = [
["help", "已注册的 AstrBot 内置指令"],
["tool ls", "函数工具"],
["tool on websearch", "激活工具"],
["tool off websearch", "停用工具"],
["plugin", "已加载的插件"],
["t2i", "文本转图片模式"],
["sid", "此 ID 可用于设置会话白名单。"],
["op test_op", "授权成功。"],
["deop test_op", "取消授权成功。"],
["wl test_platform:FriendMessage:test_sid_wl2", "添加白名单成功。"],
["dwl test_platform:FriendMessage:test_sid_wl2", "删除白名单成功。"],
["provider", "当前载入的 LLM 提供商"],
["reset", "重置成功"],
# ["model", "查看、切换提供商模型列表"],
["history", "历史记录:"],
["key", "当前 Key"],
["persona", "[Persona]"],
]
class FakeAstrMessageEvent(AstrMessageEvent):
def __init__(self, abm: AstrBotMessage = None):
meta = PlatformMetadata("test_platform", "test")
super().__init__(
message_str=abm.message_str,
message_obj=abm,
platform_meta=meta,
session_id=abm.session_id,
)
async def send(self, message: MessageChain):
await super().send(message)
@staticmethod
def create_fake_event(
message_str: str,
session_id: str = "test_sid",
is_at: bool = False,
is_group: bool = False,
sender_id: str = "123456",
):
abm = AstrBotMessage()
abm.message_str = message_str
abm.group_id = "test"
abm.message = [Plain(message_str)]
if is_at:
abm.message.append(At(qq="bot"))
abm.self_id = "bot"
abm.sender = MessageMember(sender_id, "mika")
abm.timestamp = 1234567890
abm.message_id = "test"
abm.session_id = session_id
if is_group:
abm.type = MessageType.GROUP_MESSAGE
else:
abm.type = MessageType.FRIEND_MESSAGE
return FakeAstrMessageEvent(abm)
@pytest.fixture(scope="module")
def event_queue():
return Queue()
@pytest.fixture(scope="module")
def config():
cfg = AstrBotConfig()
cfg["platform_settings"]["id_whitelist"] = [
"test_platform:FriendMessage:test_sid_wl",
"test_platform:GroupMessage:test_sid_wl",
]
cfg["admins_id"] = ["123456"]
cfg["content_safety"]["internal_keywords"]["extra_keywords"] = ["^TEST_NEGATIVE"]
cfg["provider"] = [TEST_LLM_PROVIDER]
return cfg
@pytest.fixture(scope="module")
def db():
return SQLiteDatabase("data/data_v3.db")
@pytest.fixture(scope="module")
def platform_manager(event_queue, config):
return PlatformManager(config, event_queue)
@pytest.fixture(scope="module")
def provider_manager(config, db):
return ProviderManager(config, db)
@pytest.fixture(scope="module")
def star_context(event_queue, config, db, platform_manager, provider_manager):
star_context = Context(event_queue, config, db, provider_manager, platform_manager)
return star_context
@pytest.fixture(scope="module")
def plugin_manager(star_context, config):
plugin_manager = PluginManager(star_context, config)
# await plugin_manager.reload()
asyncio.run(plugin_manager.reload())
return plugin_manager
@pytest.fixture(scope="module")
def pipeline_context(config, plugin_manager):
return PipelineContext(config, plugin_manager)
@pytest.fixture(scope="module")
def pipeline_scheduler(pipeline_context):
return PipelineScheduler(pipeline_context)
@pytest.mark.asyncio
async def test_platform_initialization(platform_manager: PlatformManager):
await platform_manager.initialize()
@pytest.mark.asyncio
async def test_provider_initialization(provider_manager: ProviderManager):
await provider_manager.initialize()
@pytest.mark.asyncio
async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineScheduler):
await pipeline_scheduler.initialize()
@pytest.mark.asyncio
async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
"""测试唤醒"""
# 群聊无 @ 无指令
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any(
"执行阶段 WhitelistCheckStage" not in message for message in caplog.messages
)
# 群聊有 @ 无指令
mock_event = FakeAstrMessageEvent.create_fake_event(
"test", is_group=True, is_at=True
)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("执行阶段 WhitelistCheckStage" in message for message in caplog.messages)
# 群聊有指令
mock_event = FakeAstrMessageEvent.create_fake_event(
"/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST
)
await pipeline_scheduler.execute(mock_event)
assert mock_event._has_send_oper is True
@pytest.mark.asyncio
async def test_pipeline_wl(
pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog
):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(
"test", SESSION_ID_IN_WHITELIST, sender_id="123"
)
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any(
"不在会话白名单中,已终止事件传播。" not in message
for message in caplog.messages
), "日志中未找到预期的消息"
mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123")
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any(
"不在会话白名单中,已终止事件传播。" in message for message in caplog.messages
), "日志中未找到预期的消息"
@pytest.mark.asyncio
async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog):
# 测试默认屏蔽词
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(
"色情", session_id=SESSION_ID_IN_WHITELIST
) # 测试需要。
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" in message for message in caplog.messages), (
"日志中未找到预期的消息"
)
# 测试额外屏蔽词
mock_event = FakeAstrMessageEvent.create_fake_event(
"TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST
)
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" in message for message in caplog.messages), (
"日志中未找到预期的消息"
)
mock_event = FakeAstrMessageEvent.create_fake_event(
"_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST
)
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" not in message for message in caplog.messages)
# TODO: 测试 百度AI 的内容安全检查
@pytest.mark.asyncio
async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(
"just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST
)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("请求 LLM" in message for message in caplog.messages)
assert mock_event.get_result() is not None
assert mock_event.get_result().result_content_type == ResultContentType.LLM_RESULT
@pytest.mark.asyncio
async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(
"help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST
)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("请求 LLM" in message for message in caplog.messages)
assert any(
"web_searcher - search_from_search_engine" in message
for message in caplog.messages
)
@pytest.mark.asyncio
async def test_commands(pipeline_scheduler: PipelineScheduler, caplog):
for command in TEST_COMMANDS:
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(
command[0], session_id=SESSION_ID_IN_WHITELIST
)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
# assert any("执行阶段 ProcessStage" in message for message in caplog.messages)
assert any(command[1] in message for message in caplog.messages)
+105 -43
View File
@@ -1,5 +1,6 @@
import pytest
import os
from unittest.mock import MagicMock
from astrbot.core.star.star_manager import PluginManager
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star import star_registry
@@ -8,18 +9,51 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.db.sqlite import SQLiteDatabase
from asyncio import Queue
event_queue = Queue()
config = AstrBotConfig()
db = SQLiteDatabase("data/data_v3.db")
star_context = Context(event_queue, config, db)
@pytest.fixture
def plugin_manager_pm():
return PluginManager(star_context, config)
def plugin_manager_pm(tmp_path):
"""
Provides a fully isolated PluginManager instance for testing.
- Uses a temporary directory for plugins.
- Uses a temporary database.
- Creates a fresh context for each test.
"""
# Create temporary resources
temp_plugins_path = tmp_path / "plugins"
temp_plugins_path.mkdir()
temp_db_path = tmp_path / "test_db.db"
# Create fresh, isolated instances for the context
event_queue = Queue()
config = AstrBotConfig()
db = SQLiteDatabase(str(temp_db_path))
# Set the plugin store path in the config to the temporary directory
config.plugin_store_path = str(temp_plugins_path)
# Mock dependencies for the context
provider_manager = MagicMock()
platform_manager = MagicMock()
conversation_manager = MagicMock()
message_history_manager = MagicMock()
persona_manager = MagicMock()
astrbot_config_mgr = MagicMock()
star_context = Context(
event_queue,
config,
db,
provider_manager,
platform_manager,
conversation_manager,
message_history_manager,
persona_manager,
astrbot_config_mgr,
)
# Create the PluginManager instance
manager = PluginManager(star_context, config)
yield manager
def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
@@ -36,48 +70,76 @@ async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
@pytest.mark.asyncio
async def test_plugin_crud(plugin_manager_pm: PluginManager):
"""测试插件安装和重载"""
os.makedirs("data/plugins", exist_ok=True)
async def test_install_plugin(plugin_manager_pm: PluginManager):
"""Tests successful plugin installation in an isolated environment."""
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
plugin_path = await plugin_manager_pm.install_plugin(test_repo)
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert plugin_path is not None
plugin_info = await plugin_manager_pm.install_plugin(test_repo)
plugin_path = os.path.join(
plugin_manager_pm.plugin_store_path, "astrbot_plugin_essential"
)
assert plugin_info is not None
assert os.path.exists(plugin_path)
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
# shutil.rmtree(plugin_path)
assert any(md.name == "astrbot_plugin_essential" for md in star_registry), (
"Plugin 'astrbot_plugin_essential' was not loaded into star_registry."
)
# install plugin which is not exists
@pytest.mark.asyncio
async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager):
"""Tests that installing a non-existent plugin raises an exception."""
with pytest.raises(Exception):
plugin_path = await plugin_manager_pm.install_plugin(test_repo + "haha")
await plugin_manager_pm.install_plugin(
"https://github.com/Soulter/non_existent_repo"
)
# update
@pytest.mark.asyncio
async def test_update_plugin(plugin_manager_pm: PluginManager):
"""Tests updating an existing plugin in an isolated environment."""
# First, install the plugin
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
await plugin_manager_pm.install_plugin(test_repo)
# Then, update it
await plugin_manager_pm.update_plugin("astrbot_plugin_essential")
with pytest.raises(Exception):
await plugin_manager_pm.update_plugin("astrbot_plugin_essentialhaha")
# uninstall
@pytest.mark.asyncio
async def test_update_nonexistent_plugin(plugin_manager_pm: PluginManager):
"""Tests that updating a non-existent plugin raises an exception."""
with pytest.raises(Exception):
await plugin_manager_pm.update_plugin("non_existent_plugin")
@pytest.mark.asyncio
async def test_uninstall_plugin(plugin_manager_pm: PluginManager):
"""Tests successful plugin uninstallation in an isolated environment."""
# First, install the plugin
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
await plugin_manager_pm.install_plugin(test_repo)
plugin_path = os.path.join(
plugin_manager_pm.plugin_store_path, "astrbot_plugin_essential"
)
assert os.path.exists(plugin_path) # Pre-condition
# Then, uninstall it
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential")
assert not os.path.exists(plugin_path)
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
exists = False
for md in star_handlers_registry:
if "astrbot_plugin_essential" in md.handler_module_path:
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
assert not any(md.name == "astrbot_plugin_essential" for md in star_registry), (
"Plugin 'astrbot_plugin_essential' was not unloaded from star_registry."
)
assert not any(
"astrbot_plugin_essential" in md.handler_module_path
for md in star_handlers_registry
), (
"Plugin 'astrbot_plugin_essential' handler was not unloaded from star_handlers_registry."
)
@pytest.mark.asyncio
async def test_uninstall_nonexistent_plugin(plugin_manager_pm: PluginManager):
"""Tests that uninstalling a non-existent plugin raises an exception."""
with pytest.raises(Exception):
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha")
# TODO: file installation
await plugin_manager_pm.uninstall_plugin("non_existent_plugin")