test: add tests for star base class and config management (#5356)

* test: add tests for star base class and config management

- Add Star base class safety helper tests
- Expand config management unit tests
- Update cron manager tests

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* test: fix plugin_manager test isolation issues

- Use local mock plugin instead of real network requests
- Clear sys.modules cache for entire data module tree
- Clear star_map and star_registry in teardown
- Use pytest_asyncio.fixture for async fixture support

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* test: fix test isolation and compatibility issues

- test_main.py: fix version comparison and path assertions for Windows
- test_smoke.py: add missing apscheduler.triggers mock modules
- test_tool_loop_agent_runner.py: update assertion for new interrupt behavior
- test_api_key_open_api.py: use unique session IDs to avoid test conflicts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* test: add unit tests for _version_info comparisons

* test: enhance plugin manager tests with mock implementations and improved assertions

* test: add mock plugin builder and updater for plugin management tests

* fix: resolve pipeline and star import cycles (#5353)

* fix: resolve pipeline and star import cycles

- Add bootstrap.py and stage_order.py to break circular dependencies
- Export Context, PluginManager, StarTools from star module
- Update pipeline __init__ to defer imports
- Split pipeline initialization into separate bootstrap module

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: add logging for get_config() failure in Star class

* fix: reorder logger initialization in base.py

---------

Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

* test: update cron job scheduling tests and refactor star base tests for clarity

* test: expand star base tests for comprehensive coverage

- Add tests for Star class initialization and context handling
- Add tests for text_to_image with/without config
- Add tests for html_render method
- Add tests for initialize/terminate lifecycle methods
- Add type hint validation tests for Context
- Add circular import prevention tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: address PR review feedback - use TYPE_CHECKING instead of Any

- pipeline/context.py: Use TYPE_CHECKING to import PluginManager instead of Any
- pipeline/__init__.py: Add TYPE_CHECKING imports for __all__ exports to satisfy static analyzers
- star/register/star_handler.py: Use TYPE_CHECKING to import AstrAgentContext instead of Any
- tests: Remove invalid type hint tests that tested incorrect assumptions

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: improve TYPE_CHECKING pattern for circular import resolution

- star/register/star_handler.py: Use AstrAgentContext instead of Any in generic types
- star/context.py: Remove unnecessary else branch with CronJobManager = Any
  (with __future__ annotations, TYPE_CHECKING imports are sufficient)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
This commit is contained in:
whatevertogo
2026-03-01 00:06:04 +08:00
committed by GitHub
parent 76e0d6d71a
commit 2a6863cf70
14 changed files with 1925 additions and 138 deletions
+12
View File
@@ -67,6 +67,18 @@ _LAZY_EXPORTS = {
), ),
} }
# Type-checking imports to satisfy static analyzers for __all__ exports
if TYPE_CHECKING:
from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .rate_limit_check.stage import RateLimitStage
from .respond.stage import RespondStage
from .result_decorate.stage import ResultDecorateStage
from .session_status_check.stage import SessionStatusCheckStage
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
__all__ = [ __all__ = [
"ContentSafetyCheckStage", "ContentSafetyCheckStage",
"EventResultType", "EventResultType",
+5 -2
View File
@@ -1,19 +1,22 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import TYPE_CHECKING
from astrbot.core.config import AstrBotConfig from astrbot.core.config import AstrBotConfig
from .context_utils import call_event_hook, call_handler from .context_utils import call_event_hook, call_handler
if TYPE_CHECKING:
from astrbot.core.star import PluginManager
@dataclass @dataclass
class PipelineContext: class PipelineContext:
"""上下文对象,包含管道执行所需的上下文信息""" """上下文对象,包含管道执行所需的上下文信息"""
astrbot_config: AstrBotConfig # AstrBot 配置对象 astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: Any # 插件管理器对象 plugin_manager: PluginManager # 插件管理器对象
astrbot_config_id: str astrbot_config_id: str
call_handler = call_handler call_handler = call_handler
call_event_hook = call_event_hook call_event_hook = call_event_hook
-2
View File
@@ -47,8 +47,6 @@ logger = logging.getLogger("astrbot")
if TYPE_CHECKING: if TYPE_CHECKING:
from astrbot.core.cron.manager import CronJobManager from astrbot.core.cron.manager import CronJobManager
else:
CronJobManager = Any
class PlatformManagerProtocol(Protocol): class PlatformManagerProtocol(Protocol):
+8 -5
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
import re import re
from collections.abc import AsyncGenerator, Awaitable, Callable from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any from typing import TYPE_CHECKING, Any
import docstring_parser import docstring_parser
@@ -15,6 +15,9 @@ from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools from astrbot.core.provider.register import llm_tools
if TYPE_CHECKING:
from astrbot.core.astr_agent_context import AstrAgentContext
from ..filter.command import CommandFilter from ..filter.command import CommandFilter
from ..filter.command_group import CommandGroupFilter from ..filter.command_group import CommandGroupFilter
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
@@ -616,7 +619,7 @@ class RegisteringAgent:
kwargs["registering_agent"] = self kwargs["registering_agent"] = self
return register_llm_tool(*args, **kwargs) return register_llm_tool(*args, **kwargs)
def __init__(self, agent: Agent[Any]) -> None: def __init__(self, agent: Agent[AstrAgentContext]) -> None:
self._agent = agent self._agent = agent
@@ -624,7 +627,7 @@ def register_agent(
name: str, name: str,
instruction: str, instruction: str,
tools: list[str | FunctionTool] | None = None, tools: list[str | FunctionTool] | None = None,
run_hooks: BaseAgentRunHooks[Any] | None = None, run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
): ):
"""注册一个 Agent """注册一个 Agent
@@ -638,12 +641,12 @@ def register_agent(
tools_ = tools or [] tools_ = tools or []
def decorator(awaitable: Callable[..., Awaitable[Any]]): def decorator(awaitable: Callable[..., Awaitable[Any]]):
AstrAgent = Agent[Any] AstrAgent = Agent[AstrAgentContext]
agent = AstrAgent( agent = AstrAgent(
name=name, name=name,
instructions=instruction, instructions=instruction,
tools=tools_, tools=tools_,
run_hooks=run_hooks or BaseAgentRunHooks[Any](), run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
) )
handoff_tool = HandoffTool(agent=agent) handoff_tool = HandoffTool(agent=agent)
handoff_tool.handler = awaitable handoff_tool.handler = awaitable
+256 -1
View File
@@ -3,7 +3,10 @@
提供统一的测试辅助工具,减少测试代码重复。 提供统一的测试辅助工具,减少测试代码重复。
""" """
from typing import Any import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
from astrbot.core.message.components import BaseMessageComponent from astrbot.core.message.components import BaseMessageComponent
@@ -330,3 +333,255 @@ def create_mock_llm_response(
tools_call_ids=tools_call_ids or [], tools_call_ids=tools_call_ids or [],
usage=TokenUsage(input_other=10, output=5), usage=TokenUsage(input_other=10, output=5),
) )
# ============================================================
# 测试插件辅助函数
# ============================================================
@dataclass
class MockPluginConfig:
"""测试插件配置。
用于创建和管理测试用的模拟插件。
Attributes:
name: 插件名称
author: 作者
description: 描述
version: 版本
repo: 仓库 URL
main_code: main.py 的代码内容
requirements: 依赖列表
has_readme: 是否创建 README.md
readme_content: README.md 内容
"""
name: str = "test_plugin"
author: str = "Test Author"
description: str = "A test plugin for unit testing"
version: str = "1.0.0"
repo: str = "https://github.com/test/test_plugin"
main_code: str = ""
requirements: list[str] = field(default_factory=list)
has_readme: bool = True
readme_content: str = "# Test Plugin\n\nThis is a test plugin."
# 默认的插件主代码模板
DEFAULT_PLUGIN_MAIN_TEMPLATE = '''
from astrbot.api import star
class Main(star.Star):
"""测试插件主类。"""
def __init__(self, context):
super().__init__(context)
self.name = "{plugin_name}"
async def initialize(self):
"""初始化插件。"""
pass
async def terminate(self):
"""终止插件。"""
pass
'''
class MockPluginBuilder:
"""测试插件构建器。
用于创建、管理和清理测试用的模拟插件。支持任意插件的模拟创建。
Example:
# 创建一个简单的测试插件
builder = MockPluginBuilder(plugin_store_path)
plugin_dir = builder.create("my_test_plugin")
# 创建自定义配置的插件
config = MockPluginConfig(
name="custom_plugin",
version="2.0.0",
main_code="print('hello')",
)
plugin_dir = builder.create(config)
# 清理插件
builder.cleanup("my_test_plugin")
"""
def __init__(self, plugin_store_path: str | Path):
"""初始化构建器。
Args:
plugin_store_path: 插件存储路径 (通常是 data/plugins)
"""
self.plugin_store_path = Path(plugin_store_path)
self._created_plugins: set[str] = set()
def create(
self,
plugin_config: str | MockPluginConfig | None = None,
**kwargs,
) -> Path:
"""创建模拟插件。
Args:
plugin_config: 插件名称字符串、MockPluginConfig 对象或 None
**kwargs: 如果 plugin_config 是字符串或 None,这些参数用于构建 MockPluginConfig
Returns:
Path: 创建的插件目录路径
"""
# 处理不同类型的输入
if plugin_config is None:
config = MockPluginConfig(**kwargs)
elif isinstance(plugin_config, str):
config = MockPluginConfig(name=plugin_config, **kwargs)
elif isinstance(plugin_config, MockPluginConfig):
config = plugin_config
else:
raise TypeError(f"Invalid plugin_config type: {type(plugin_config)}")
# 创建插件目录
plugin_dir = self.plugin_store_path / config.name
plugin_dir.mkdir(parents=True, exist_ok=True)
# 创建 metadata.yaml
metadata_content = "\n".join(
[
f"name: {config.name}",
f"author: {config.author}",
f"desc: {config.description}",
f"version: {config.version}",
f"repo: {config.repo}",
]
)
(plugin_dir / "metadata.yaml").write_text(
metadata_content + "\n", encoding="utf-8"
)
# 创建 main.py
main_code = config.main_code or DEFAULT_PLUGIN_MAIN_TEMPLATE.format(
plugin_name=config.name
)
(plugin_dir / "main.py").write_text(main_code, encoding="utf-8")
# 创建 requirements.txt(如果有依赖)
if config.requirements:
(plugin_dir / "requirements.txt").write_text(
"\n".join(config.requirements) + "\n", encoding="utf-8"
)
# 创建 README.md(如果需要)
if config.has_readme:
(plugin_dir / "README.md").write_text(
config.readme_content, encoding="utf-8"
)
# 记录创建的插件
self._created_plugins.add(config.name)
return plugin_dir
def cleanup(self, plugin_name: str | None = None) -> None:
"""清理插件。
Args:
plugin_name: 要清理的插件名称,如果为 None 则清理所有由本构建器创建的插件
"""
if plugin_name:
plugins_to_clean = {plugin_name}
else:
plugins_to_clean = self._created_plugins.copy()
for name in plugins_to_clean:
plugin_dir = self.plugin_store_path / name
if plugin_dir.exists():
shutil.rmtree(plugin_dir)
self._created_plugins.discard(name)
def cleanup_all(self) -> None:
"""清理所有由本构建器创建的插件。"""
self.cleanup(None)
def get_plugin_path(self, plugin_name: str) -> Path:
"""获取插件路径。
Args:
plugin_name: 插件名称
Returns:
Path: 插件目录路径
"""
return self.plugin_store_path / plugin_name
@property
def created_plugins(self) -> set[str]:
"""获取已创建的插件名称集合。"""
return self._created_plugins.copy()
def create_mock_updater_install(
plugin_builder: MockPluginBuilder,
repo_to_plugin: dict[str, str] | None = None,
) -> Callable:
"""创建模拟的 updater.install 方法。
Args:
plugin_builder: MockPluginBuilder 实例
repo_to_plugin: 仓库 URL 到插件名称的映射,格式: {"https://github.com/user/repo": "plugin_name"}
Returns:
Callable: 异步函数,可用于 monkeypatch.setattr
"""
async def mock_install(repo_url: str, proxy: str = "") -> str:
"""Mock updater.install 方法。"""
# 查找插件名称
plugin_name = None
if repo_to_plugin:
plugin_name = repo_to_plugin.get(repo_url)
# 如果没有映射,尝试从 URL 提取插件名
if not plugin_name:
# 从 https://github.com/user/plugin_name 提取 plugin_name
parts = repo_url.rstrip("/").split("/")
plugin_name = parts[-1] if parts else "unknown_plugin"
# 创建插件目录
config = MockPluginConfig(name=plugin_name, repo=repo_url)
plugin_dir = plugin_builder.create(config)
return str(plugin_dir)
return mock_install
def create_mock_updater_update(
plugin_builder: MockPluginBuilder,
update_callback: Callable | None = None,
) -> Callable:
"""创建模拟的 updater.update 方法。
Args:
plugin_builder: MockPluginBuilder 实例
update_callback: 更新回调函数,接收 plugin 参数
Returns:
Callable: 异步函数,可用于 monkeypatch.setattr
"""
async def mock_update(plugin, proxy: str = "") -> None:
"""Mock updater.update 方法。"""
plugin_dir = plugin_builder.get_plugin_path(plugin.name)
# 创建更新标记文件
(plugin_dir / ".updated").write_text("ok", encoding="utf-8")
# 调用回调
if update_callback:
update_callback(plugin)
return mock_update
+4 -4
View File
@@ -203,7 +203,7 @@ async def test_open_chat_send_auto_session_id_and_username(
"/api/v1/chat", "/api/v1/chat",
json={ json={
"message": "hello", "message": "hello",
"username": "alice", "username": "alice_auto_session",
"enable_streaming": False, "enable_streaming": False,
}, },
headers={"X-API-Key": raw_key}, headers={"X-API-Key": raw_key},
@@ -217,16 +217,16 @@ async def test_open_chat_send_auto_session_id_and_username(
created_session_id = send_data["data"]["session_id"] created_session_id = send_data["data"]["session_id"]
assert isinstance(created_session_id, str) assert isinstance(created_session_id, str)
uuid.UUID(created_session_id) uuid.UUID(created_session_id)
assert send_data["data"]["creator"] == "alice" assert send_data["data"]["creator"] == "alice_auto_session"
created_session = await core_lifecycle_td.db.get_platform_session_by_id( created_session = await core_lifecycle_td.db.get_platform_session_by_id(
created_session_id created_session_id
) )
assert created_session is not None assert created_session is not None
assert created_session.creator == "alice" assert created_session.creator == "alice_auto_session"
assert created_session.platform_id == "webchat" assert created_session.platform_id == "webchat"
await core_lifecycle_td.db.create_platform_session( await core_lifecycle_td.db.create_platform_session(
creator="bob", creator="bob_auto_session",
platform_id="webchat", platform_id="webchat",
session_id="open_api_existing_bob_session", session_id="open_api_existing_bob_session",
is_group=0, is_group=0,
+115 -46
View File
@@ -1,5 +1,6 @@
import asyncio import asyncio
import os import os
from pathlib import Path
import pytest import pytest
import pytest_asyncio import pytest_asyncio
@@ -11,6 +12,12 @@ from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.star.star import star_registry from astrbot.core.star.star import star_registry
from astrbot.core.star.star_handler import star_handlers_registry from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.dashboard.server import AstrBotDashboard from astrbot.dashboard.server import AstrBotDashboard
from tests.fixtures.helpers import (
MockPluginBuilder,
MockPluginConfig,
create_mock_updater_install,
create_mock_updater_update,
)
@pytest_asyncio.fixture(scope="module") @pytest_asyncio.fixture(scope="module")
@@ -94,8 +101,15 @@ async def test_get_stat(app: Quart, authenticated_header: dict):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_plugins(app: Quart, authenticated_header: dict): async def test_plugins(
app: Quart,
authenticated_header: dict,
core_lifecycle_td: AstrBotCoreLifecycle,
monkeypatch,
):
"""测试插件 API 端点,使用 Mock 避免真实网络调用。"""
test_client = app.test_client() test_client = app.test_client()
# 已经安装的插件 # 已经安装的插件
response = await test_client.get("/api/plugin/get", headers=authenticated_header) response = await test_client.get("/api/plugin/get", headers=authenticated_header)
assert response.status_code == 200 assert response.status_code == 200
@@ -111,53 +125,79 @@ async def test_plugins(app: Quart, authenticated_header: dict):
data = await response.get_json() data = await response.get_json()
assert data["status"] == "ok" assert data["status"] == "ok"
# 插件安装 # 使用 MockPluginBuilder 创建测试插件
response = await test_client.post( plugin_store_path = core_lifecycle_td.plugin_manager.plugin_store_path
"/api/plugin/install", builder = MockPluginBuilder(plugin_store_path)
json={"url": "https://github.com/Soulter/astrbot_plugin_essential"},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
# 插件更新 # 定义测试插件
response = await test_client.post( test_plugin_name = "test_mock_plugin"
"/api/plugin/update", test_repo_url = f"https://github.com/test/{test_plugin_name}"
json={"name": "astrbot_plugin_essential"},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
# 插件卸载 # 创建 Mock 函数
response = await test_client.post( mock_install = create_mock_updater_install(
"/api/plugin/uninstall", builder,
json={"name": "astrbot_plugin_essential"}, repo_to_plugin={test_repo_url: test_plugin_name},
headers=authenticated_header,
) )
assert response.status_code == 200 mock_update = create_mock_updater_update(builder)
data = await response.get_json()
assert data["status"] == "ok" # 设置 Mock
exists = False monkeypatch.setattr(
for md in star_registry: core_lifecycle_td.plugin_manager.updator, "install", mock_install
if md.name == "astrbot_plugin_essential": )
exists = True monkeypatch.setattr(
break core_lifecycle_td.plugin_manager.updator, "update", mock_update
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载" )
exists = False
for md in star_handlers_registry: try:
if "astrbot_plugin_essential" in md.handler_module_path: # 插件安装
exists = True response = await test_client.post(
break "/api/plugin/install",
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载" json={"url": test_repo_url},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok", f"安装失败: {data.get('message', 'unknown error')}"
# 验证插件已注册
exists = any(md.name == test_plugin_name for md in star_registry)
assert exists is True, f"插件 {test_plugin_name} 未成功载入"
# 插件更新
response = await test_client.post(
"/api/plugin/update",
json={"name": test_plugin_name},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
# 验证更新标记文件
plugin_dir = builder.get_plugin_path(test_plugin_name)
assert (plugin_dir / ".updated").exists()
# 插件卸载
response = await test_client.post(
"/api/plugin/uninstall",
json={"name": test_plugin_name},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
# 验证插件已卸载
exists = any(md.name == test_plugin_name for md in star_registry)
assert exists is False, f"插件 {test_plugin_name} 未成功卸载"
exists = any(
test_plugin_name in md.handler_module_path for md in star_handlers_registry
)
assert exists is False, f"插件 {test_plugin_name} handler 未成功清理"
finally:
# 清理测试插件
builder.cleanup(test_plugin_name)
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -189,12 +229,41 @@ async def test_commands_api(app: Quart, authenticated_header: dict):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_update(app: Quart, authenticated_header: dict): async def test_check_update(
app: Quart,
authenticated_header: dict,
core_lifecycle_td: AstrBotCoreLifecycle,
monkeypatch,
):
"""测试检查更新 API,使用 Mock 避免真实网络调用。"""
test_client = app.test_client() test_client = app.test_client()
# Mock 更新检查和网络请求
async def mock_check_update(*args, **kwargs):
"""Mock 更新检查,返回无新版本。"""
return None # None 表示没有新版本
async def mock_get_dashboard_version(*args, **kwargs):
"""Mock Dashboard 版本获取。"""
from astrbot.core.config.default import VERSION
return f"v{VERSION}" # 返回当前版本
monkeypatch.setattr(
core_lifecycle_td.astrbot_updator,
"check_update",
mock_check_update,
)
monkeypatch.setattr(
"astrbot.dashboard.routes.update.get_dashboard_version",
mock_get_dashboard_version,
)
response = await test_client.get("/api/update/check", headers=authenticated_header) response = await test_client.get("/api/update/check", headers=authenticated_header)
assert response.status_code == 200 assert response.status_code == 200
data = await response.get_json() data = await response.get_json()
assert data["status"] == "success" assert data["status"] == "success"
assert data["data"]["has_new_version"] is False
@pytest.mark.asyncio @pytest.mark.asyncio
+49 -3
View File
@@ -16,6 +16,16 @@ class _version_info:
self.major = major self.major = major
self.minor = minor self.minor = minor
def __eq__(self, other):
if isinstance(other, tuple):
return (self.major, self.minor) == other[:2]
return (self.major, self.minor) == (other.major, other.minor)
def __ge__(self, other):
if isinstance(other, tuple):
return (self.major, self.minor) >= other[:2]
return (self.major, self.minor) >= (other.major, other.minor)
def test_check_env(monkeypatch): def test_check_env(monkeypatch):
version_info_correct = _version_info(3, 10) version_info_correct = _version_info(3, 10)
@@ -23,15 +33,51 @@ def test_check_env(monkeypatch):
monkeypatch.setattr(sys, "version_info", version_info_correct) monkeypatch.setattr(sys, "version_info", version_info_correct)
with mock.patch("os.makedirs") as mock_makedirs: with mock.patch("os.makedirs") as mock_makedirs:
check_env() check_env()
mock_makedirs.assert_any_call("data/config", exist_ok=True) # Check that makedirs was called with paths containing expected dirs
mock_makedirs.assert_any_call("data/plugins", exist_ok=True) called_paths = [call[0][0] for call in mock_makedirs.call_args_list]
mock_makedirs.assert_any_call("data/temp", exist_ok=True) # Use os.path.join for cross-platform path matching
assert any(p.rstrip(os.sep).endswith(os.path.join("data", "config")) for p in called_paths)
assert any(p.rstrip(os.sep).endswith(os.path.join("data", "plugins")) for p in called_paths)
assert any(p.rstrip(os.sep).endswith(os.path.join("data", "temp")) for p in called_paths)
monkeypatch.setattr(sys, "version_info", version_info_wrong) monkeypatch.setattr(sys, "version_info", version_info_wrong)
with pytest.raises(SystemExit): with pytest.raises(SystemExit):
check_env() check_env()
def test_version_info_comparisons():
"""Test _version_info comparison operators with tuples and other instances."""
v3_10 = _version_info(3, 10)
v3_9 = _version_info(3, 9)
v3_11 = _version_info(3, 11)
# Test __eq__ with tuples
assert v3_10 == (3, 10)
assert v3_10 != (3, 9)
assert v3_9 == (3, 9)
# Test __ge__ with tuples
assert v3_10 >= (3, 10)
assert v3_10 >= (3, 9)
assert not (v3_9 >= (3, 10))
assert v3_11 >= (3, 10)
# Test __eq__ with other _version_info instances
assert v3_10 == _version_info(3, 10)
assert v3_10 != v3_9
assert v3_10 == v3_10 # Same instance
assert v3_10 != v3_11
# Test __ge__ with other _version_info instances
assert v3_10 >= v3_10
assert v3_10 >= v3_9
assert not (v3_9 >= v3_10)
assert v3_11 >= v3_10
assert v3_11 >= v3_11 # Same instance
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_dashboard_files_not_exists(monkeypatch): async def test_check_dashboard_files_not_exists(monkeypatch):
"""Tests dashboard download when files do not exist.""" """Tests dashboard download when files do not exist."""
+159 -74
View File
@@ -1,65 +1,164 @@
import os import sys
from asyncio import Queue from asyncio import Queue
from pathlib import Path
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
import pytest_asyncio
from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.star.context import Context from astrbot.core.star.context import Context
from astrbot.core.star.star import star_registry from astrbot.core.star.star import star_map, star_registry
from astrbot.core.star.star_handler import star_handlers_registry from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star_manager import PluginManager from astrbot.core.star.star_manager import PluginManager
@pytest.fixture def _clear_module_cache() -> None:
def plugin_manager_pm(tmp_path): """Clear module cache for data module tree to ensure test isolation."""
"""Provides a fully isolated PluginManager instance for testing. modules_to_remove = [
- Uses a temporary directory for plugins. key for key in sys.modules if key == "data" or key.startswith("data.")
- Uses a temporary database. ]
- Creates a fresh context for each test. for key in modules_to_remove:
""" del sys.modules[key]
# Create temporary resources
temp_plugins_path = tmp_path / "plugins"
temp_plugins_path.mkdir() def _clear_registry(plugin_name: str) -> None:
temp_db_path = tmp_path / "test_db.db" """Clear plugin from global registries."""
# Clear star_registry (list)
star_registry[:] = [md for md in star_registry if md.name != plugin_name]
# Clear star_map (dict)
keys_to_remove = [
key for key, md in star_map.items() if md.name == plugin_name
]
for key in keys_to_remove:
del star_map[key]
# Clear star_handlers_registry (StarHandlerRegistry)
for handler in list(star_handlers_registry):
if plugin_name in (handler.handler_module_path or ""):
star_handlers_registry.remove(handler)
TEST_PLUGIN_REPO = "https://github.com/Soulter/helloworld"
TEST_PLUGIN_DIR = "helloworld"
TEST_PLUGIN_NAME = "helloworld"
def _write_local_test_plugin(plugin_dir: Path, repo_url: str) -> None:
plugin_dir.mkdir(parents=True, exist_ok=True)
(plugin_dir / "metadata.yaml").write_text(
"\n".join(
[
f"name: {TEST_PLUGIN_NAME}",
"author: AstrBot Team",
"desc: Local test plugin",
"version: 1.0.0",
f"repo: {repo_url}",
],
)
+ "\n",
encoding="utf-8",
)
(plugin_dir / "main.py").write_text(
"\n".join(
[
"from astrbot.api import star",
"",
"class Main(star.Star):",
" pass",
"",
],
),
encoding="utf-8",
)
@pytest_asyncio.fixture
async def plugin_manager_pm(tmp_path, monkeypatch):
"""Provides a fully isolated PluginManager instance for testing."""
# Clear module cache before setup to ensure isolation
_clear_module_cache()
test_root = tmp_path / "astrbot_root"
data_dir = test_root / "data"
plugin_dir = data_dir / "plugins"
config_dir = data_dir / "config"
temp_dir = data_dir / "temp"
for path in (plugin_dir, config_dir, temp_dir):
path.mkdir(parents=True, exist_ok=True)
# Ensure `import data.plugins.<plugin>.main` resolves to this temp root.
(data_dir / "__init__.py").write_text("", encoding="utf-8")
(plugin_dir / "__init__.py").write_text("", encoding="utf-8")
# Use monkeypatch for both env var and sys.path to ensure proper cleanup
monkeypatch.setenv("ASTRBOT_ROOT", str(test_root))
monkeypatch.syspath_prepend(str(test_root))
# Create fresh, isolated instances for the context # Create fresh, isolated instances for the context
event_queue = Queue() event_queue = Queue()
config = AstrBotConfig() config = AstrBotConfig()
db = SQLiteDatabase(str(temp_db_path)) db = SQLiteDatabase(str(data_dir / "test_db.db"))
config.plugin_store_path = str(plugin_dir)
# Set the plugin store path in the config to the temporary directory
config.plugin_store_path = str(temp_plugins_path)
# Mock dependencies for the context
provider_manager = MagicMock() provider_manager = MagicMock()
platform_manager = MagicMock() platform_manager = MagicMock()
conversation_manager = MagicMock() conversation_manager = MagicMock()
message_history_manager = MagicMock() message_history_manager = MagicMock()
persona_manager = MagicMock() persona_manager = MagicMock()
persona_manager.personas_v3 = []
astrbot_config_mgr = MagicMock() astrbot_config_mgr = MagicMock()
knowledge_base_manager = MagicMock() knowledge_base_manager = MagicMock()
cron_manager = MagicMock()
star_context = Context( star_context = Context(
event_queue, event_queue=event_queue,
config, config=config,
db, db=db,
provider_manager, provider_manager=provider_manager,
platform_manager, platform_manager=platform_manager,
conversation_manager, conversation_manager=conversation_manager,
message_history_manager, message_history_manager=message_history_manager,
persona_manager, persona_manager=persona_manager,
astrbot_config_mgr, astrbot_config_mgr=astrbot_config_mgr,
knowledge_base_manager=knowledge_base_manager, knowledge_base_manager=knowledge_base_manager,
cron_manager=cron_manager,
subagent_orchestrator=None,
) )
# Create the PluginManager instance
manager = PluginManager(star_context, config) manager = PluginManager(star_context, config)
return manager try:
yield manager
finally:
# Cleanup global registries and module cache
_clear_registry(TEST_PLUGIN_NAME)
_clear_module_cache()
await db.engine.dispose()
def test_plugin_manager_initialization(plugin_manager_pm: PluginManager): @pytest.fixture
def local_updator(plugin_manager_pm: PluginManager, monkeypatch):
plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR
async def mock_install(repo_url: str, proxy=""): # noqa: ARG001
if repo_url != TEST_PLUGIN_REPO:
raise Exception("Repo not found")
_write_local_test_plugin(plugin_path, repo_url)
return str(plugin_path)
async def mock_update(plugin, proxy=""): # noqa: ARG001
if plugin.name != TEST_PLUGIN_NAME:
raise Exception("Plugin not found")
if not plugin_path.exists():
raise Exception("Plugin path missing")
(plugin_path / ".updated").write_text("ok", encoding="utf-8")
monkeypatch.setattr(plugin_manager_pm.updator, "install", mock_install)
monkeypatch.setattr(plugin_manager_pm.updator, "update", mock_update)
return plugin_path
@pytest.mark.asyncio
async def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
assert plugin_manager_pm is not None assert plugin_manager_pm is not None
assert plugin_manager_pm.context is not None assert plugin_manager_pm.context is not None
assert plugin_manager_pm.config is not None assert plugin_manager_pm.config is not None
@@ -73,73 +172,59 @@ async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_install_plugin(plugin_manager_pm: PluginManager): async def test_install_plugin(plugin_manager_pm: PluginManager, local_updator: Path):
"""Tests successful plugin installation in an isolated environment.""" """Tests successful plugin installation without external network."""
test_repo = "https://github.com/Soulter/astrbot_plugin_essential" plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
plugin_info = await plugin_manager_pm.install_plugin(test_repo)
plugin_path = os.path.join(
plugin_manager_pm.plugin_store_path,
"astrbot_plugin_essential",
)
assert plugin_info is not None assert plugin_info is not None
assert os.path.exists(plugin_path) assert plugin_info["name"] == TEST_PLUGIN_NAME
assert any(md.name == "astrbot_plugin_essential" for md in star_registry), ( assert local_updator.exists()
"Plugin 'astrbot_plugin_essential' was not loaded into star_registry." assert any(md.name == TEST_PLUGIN_NAME for md in star_registry)
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager): async def test_install_nonexistent_plugin(
plugin_manager_pm: PluginManager, local_updator
):
"""Tests that installing a non-existent plugin raises an exception.""" """Tests that installing a non-existent plugin raises an exception."""
with pytest.raises(Exception): with pytest.raises(Exception):
await plugin_manager_pm.install_plugin( await plugin_manager_pm.install_plugin(
"https://github.com/Soulter/non_existent_repo", "https://github.com/Soulter/non_existent_repo"
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_plugin(plugin_manager_pm: PluginManager): async def test_update_plugin(plugin_manager_pm: PluginManager, local_updator: Path):
"""Tests updating an existing plugin in an isolated environment.""" """Tests updating an existing plugin without external network."""
# First, install the plugin plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
test_repo = "https://github.com/Soulter/astrbot_plugin_essential" assert plugin_info is not None
await plugin_manager_pm.install_plugin(test_repo) plugin_name = plugin_info["name"]
await plugin_manager_pm.update_plugin(plugin_name)
# Then, update it assert (local_updator / ".updated").exists()
await plugin_manager_pm.update_plugin("astrbot_plugin_essential")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_nonexistent_plugin(plugin_manager_pm: PluginManager): async def test_update_nonexistent_plugin(
plugin_manager_pm: PluginManager, local_updator
):
"""Tests that updating a non-existent plugin raises an exception.""" """Tests that updating a non-existent plugin raises an exception."""
with pytest.raises(Exception): with pytest.raises(Exception):
await plugin_manager_pm.update_plugin("non_existent_plugin") await plugin_manager_pm.update_plugin("non_existent_plugin")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_uninstall_plugin(plugin_manager_pm: PluginManager): async def test_uninstall_plugin(plugin_manager_pm: PluginManager, local_updator: Path):
"""Tests successful plugin uninstallation in an isolated environment.""" """Tests successful plugin uninstallation."""
# First, install the plugin plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
test_repo = "https://github.com/Soulter/astrbot_plugin_essential" assert plugin_info is not None
await plugin_manager_pm.install_plugin(test_repo) plugin_name = plugin_info["name"]
plugin_path = os.path.join( assert local_updator.exists()
plugin_manager_pm.plugin_store_path,
"astrbot_plugin_essential",
)
assert os.path.exists(plugin_path) # Pre-condition
# Then, uninstall it await plugin_manager_pm.uninstall_plugin(plugin_name)
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential")
assert not os.path.exists(plugin_path) assert not local_updator.exists()
assert not any(md.name == "astrbot_plugin_essential" for md in star_registry), ( assert not any(md.name == TEST_PLUGIN_NAME for md in star_registry)
"Plugin 'astrbot_plugin_essential' was not unloaded from star_registry."
)
assert not any( assert not any(
"astrbot_plugin_essential" in md.handler_module_path TEST_PLUGIN_NAME in md.handler_module_path for md in star_handlers_registry
for md in star_handlers_registry
), (
"Plugin 'astrbot_plugin_essential' handler was not unloaded from star_handlers_registry."
) )
+6
View File
@@ -101,10 +101,16 @@ def test_pipeline_import_is_stable_with_mocked_apscheduler() -> None:
"mock_apscheduler.schedulers = MagicMock();" "mock_apscheduler.schedulers = MagicMock();"
"mock_apscheduler.schedulers.asyncio = MagicMock();" "mock_apscheduler.schedulers.asyncio = MagicMock();"
"mock_apscheduler.schedulers.background = MagicMock();" "mock_apscheduler.schedulers.background = MagicMock();"
"mock_apscheduler.triggers = MagicMock();"
"mock_apscheduler.triggers.cron = MagicMock();"
"mock_apscheduler.triggers.date = MagicMock();"
"sys.modules['apscheduler'] = mock_apscheduler;" "sys.modules['apscheduler'] = mock_apscheduler;"
"sys.modules['apscheduler.schedulers'] = mock_apscheduler.schedulers;" "sys.modules['apscheduler.schedulers'] = mock_apscheduler.schedulers;"
"sys.modules['apscheduler.schedulers.asyncio'] = mock_apscheduler.schedulers.asyncio;" "sys.modules['apscheduler.schedulers.asyncio'] = mock_apscheduler.schedulers.asyncio;"
"sys.modules['apscheduler.schedulers.background'] = mock_apscheduler.schedulers.background;" "sys.modules['apscheduler.schedulers.background'] = mock_apscheduler.schedulers.background;"
"sys.modules['apscheduler.triggers'] = mock_apscheduler.triggers;"
"sys.modules['apscheduler.triggers.cron'] = mock_apscheduler.triggers.cron;"
"sys.modules['apscheduler.triggers.date'] = mock_apscheduler.triggers.date;"
"import astrbot.core.pipeline as pipeline;" "import astrbot.core.pipeline as pipeline;"
"assert pipeline.ProcessStage is not None;" "assert pipeline.ProcessStage is not None;"
"assert pipeline.RespondStage is not None" "assert pipeline.RespondStage is not None"
+2 -1
View File
@@ -461,7 +461,8 @@ async def test_stop_signal_returns_aborted_and_persists_partial_message(
final_resp = runner.get_final_llm_resp() final_resp = runner.get_final_llm_resp()
assert final_resp is not None assert final_resp is not None
assert final_resp.role == "assistant" assert final_resp.role == "assistant"
assert final_resp.completion_text == "partial " # When interrupted, the runner replaces completion_text with a system message
assert "interrupted" in final_resp.completion_text.lower()
assert runner.run_context.messages[-1].role == "assistant" assert runner.run_context.messages[-1].role == "assistant"
+607
View File
@@ -0,0 +1,607 @@
"""Tests for config module."""
import json
import os
import pytest
from astrbot.core.config.astrbot_config import AstrBotConfig, RateLimitStrategy
from astrbot.core.config.default import DEFAULT_VALUE_MAP
from astrbot.core.config.i18n_utils import ConfigMetadataI18n
@pytest.fixture
def temp_config_path(tmp_path):
"""Create a temporary config path."""
return str(tmp_path / "test_config.json")
@pytest.fixture
def minimal_default_config():
"""Create a minimal default config for testing."""
return {
"config_version": 2,
"platform_settings": {
"unique_session": False,
"rate_limit": {
"time": 60,
"count": 30,
"strategy": "stall",
},
},
"provider_settings": {
"enable": True,
"default_provider_id": "",
},
}
class TestRateLimitStrategy:
"""Tests for RateLimitStrategy enum."""
def test_stall_value(self):
"""Test stall enum value."""
assert RateLimitStrategy.STALL.value == "stall"
def test_discard_value(self):
"""Test discard enum value."""
assert RateLimitStrategy.DISCARD.value == "discard"
class TestAstrBotConfigLoad:
"""Tests for AstrBotConfig loading and initialization."""
def test_init_creates_file_if_not_exists(
self, temp_config_path, minimal_default_config
):
"""Test that config file is created when it doesn't exist."""
assert not os.path.exists(temp_config_path)
config = AstrBotConfig(
config_path=temp_config_path, default_config=minimal_default_config
)
assert os.path.exists(temp_config_path)
assert config.config_version == 2
assert config.platform_settings["unique_session"] is False
def test_init_loads_existing_file(self, temp_config_path, minimal_default_config):
"""Test that existing config file is loaded."""
existing_config = {
"config_version": 2,
"platform_settings": {"unique_session": True},
"provider_settings": {"enable": False},
}
with open(temp_config_path, "w", encoding="utf-8-sig") as f:
json.dump(existing_config, f)
config = AstrBotConfig(
config_path=temp_config_path, default_config=minimal_default_config
)
assert config.platform_settings["unique_session"] is True
assert config.provider_settings["enable"] is False
def test_first_deploy_flag(self, temp_config_path, minimal_default_config):
"""Test first_deploy flag is set for new config."""
config = AstrBotConfig(
config_path=temp_config_path, default_config=minimal_default_config
)
assert hasattr(config, "first_deploy")
assert config.first_deploy is True
def test_init_with_schema(self, temp_config_path):
"""Test initialization with schema."""
schema = {
"test_field": {
"type": "string",
"default": "test_value",
},
"nested": {
"type": "object",
"items": {
"enabled": {"type": "bool"},
"count": {"type": "int"},
},
},
}
config = AstrBotConfig(config_path=temp_config_path, schema=schema)
assert config.test_field == "test_value"
assert config.nested["enabled"] is False
assert config.nested["count"] == 0
def test_dot_notation_access(self, temp_config_path, minimal_default_config):
"""Test accessing config values using dot notation."""
config = AstrBotConfig(
config_path=temp_config_path, default_config=minimal_default_config
)
assert config.platform_settings is not None
assert config.non_existent_field is None
def test_setattr_updates_config(self, temp_config_path, minimal_default_config):
"""Test that setting attributes updates config."""
config = AstrBotConfig(
config_path=temp_config_path, default_config=minimal_default_config
)
config.new_field = "new_value"
assert config.new_field == "new_value"
def test_delattr_removes_field(self, temp_config_path, minimal_default_config):
"""Test that deleting attributes removes them."""
config = AstrBotConfig(
config_path=temp_config_path, default_config=minimal_default_config
)
config.temp_field = "temp"
del config.temp_field
# Accessing a deleted field returns None due to __getattr__
assert config.temp_field is None
# But the field is removed from the dict
assert "temp_field" not in config
def test_delattr_saves_config(self, temp_config_path, minimal_default_config):
"""Test that deleting attributes saves config to file."""
config = AstrBotConfig(
config_path=temp_config_path, default_config=minimal_default_config
)
config.temp_field = "temp"
del config.temp_field
with open(temp_config_path, encoding="utf-8-sig") as f:
loaded_config = json.load(f)
assert "temp_field" not in loaded_config
def test_check_exist(self, temp_config_path, minimal_default_config):
"""Test check_exist method."""
config = AstrBotConfig(
config_path=temp_config_path, default_config=minimal_default_config
)
assert config.check_exist() is True
# Create a path that definitely doesn't exist
import pathlib
temp_dir = pathlib.Path(temp_config_path).parent
non_existent_path = str(temp_dir / "non_existent_config.json")
# Check that the file doesn't exist before creating config
assert not os.path.exists(non_existent_path)
# Create config which will auto-create the file
config2 = AstrBotConfig(
config_path=non_existent_path, default_config=minimal_default_config
)
# Now it exists
assert config2.check_exist() is True
assert os.path.exists(non_existent_path)
class TestConfigValidation:
"""Tests for config validation and integrity checking."""
def test_insert_missing_config_items(
self, temp_config_path, minimal_default_config
):
"""Test that missing config items are inserted with default values."""
existing_config = {"config_version": 2}
with open(temp_config_path, "w", encoding="utf-8-sig") as f:
json.dump(existing_config, f)
config = AstrBotConfig(
config_path=temp_config_path, default_config=minimal_default_config
)
assert "platform_settings" in config
assert "provider_settings" in config
def test_replace_none_with_default(self, temp_config_path, minimal_default_config):
"""Test that None values are replaced with defaults."""
existing_config = {
"config_version": 2,
"platform_settings": None,
"provider_settings": None,
}
with open(temp_config_path, "w", encoding="utf-8-sig") as f:
json.dump(existing_config, f)
AstrBotConfig(
config_path=temp_config_path, default_config=minimal_default_config
)
# Reload to verify the values were replaced
config2 = AstrBotConfig(
config_path=temp_config_path, default_config=minimal_default_config
)
assert config2.platform_settings is not None
assert config2.provider_settings is not None
def test_reorder_config_keys(self, temp_config_path, minimal_default_config):
"""Test that config keys are reordered to match default."""
existing_config = {
"provider_settings": {"enable": True},
"config_version": 2,
"platform_settings": {"unique_session": False},
}
with open(temp_config_path, "w", encoding="utf-8-sig") as f:
json.dump(existing_config, f)
AstrBotConfig(
config_path=temp_config_path, default_config=minimal_default_config
)
with open(temp_config_path, encoding="utf-8-sig") as f:
loaded_config = json.load(f)
keys = list(loaded_config.keys())
assert keys[0] == "config_version"
assert keys[1] == "platform_settings"
assert keys[2] == "provider_settings"
def test_remove_unknown_config_keys(self, temp_config_path, minimal_default_config):
"""Test that unknown config keys are removed."""
existing_config = {
"config_version": 2,
"platform_settings": {},
"unknown_key": "should_be_removed",
}
with open(temp_config_path, "w", encoding="utf-8-sig") as f:
json.dump(existing_config, f)
config = AstrBotConfig(
config_path=temp_config_path, default_config=minimal_default_config
)
assert "unknown_key" not in config
def test_nested_config_validation(self, temp_config_path):
"""Test validation of nested config structures."""
default_config = {
"nested": {
"level1": {
"level2": {
"value": 42,
},
},
},
}
existing_config = {
"nested": {
"level1": {}, # Missing level2
},
}
with open(temp_config_path, "w", encoding="utf-8-sig") as f:
json.dump(existing_config, f)
config = AstrBotConfig(
config_path=temp_config_path, default_config=default_config
)
assert "level2" in config.nested["level1"]
assert config.nested["level1"]["level2"]["value"] == 42
class TestConfigHotReload:
"""Tests for config hot reload functionality."""
def test_save_config(self, temp_config_path, minimal_default_config):
"""Test saving config to file."""
config = AstrBotConfig(
config_path=temp_config_path, default_config=minimal_default_config
)
config.new_field = "new_value"
config.save_config()
with open(temp_config_path, encoding="utf-8-sig") as f:
loaded_config = json.load(f)
assert loaded_config["new_field"] == "new_value"
def test_save_config_with_replace(self, temp_config_path, minimal_default_config):
"""Test saving config with replacement."""
config = AstrBotConfig(
config_path=temp_config_path, default_config=minimal_default_config
)
replacement_config = {
"replaced": True,
"extra_field": "value",
}
config.save_config(replace_config=replacement_config)
with open(temp_config_path, encoding="utf-8-sig") as f:
loaded_config = json.load(f)
# The replacement config is merged with existing config
assert loaded_config["replaced"] is True
assert loaded_config["extra_field"] == "value"
# Original fields are preserved because update merges
assert "platform_settings" in loaded_config
def test_modification_persists_after_reload(
self, temp_config_path, minimal_default_config
):
"""Test that modifications persist after reloading."""
config1 = AstrBotConfig(
config_path=temp_config_path, default_config=minimal_default_config
)
config1.platform_settings["unique_session"] = True
config1.save_config()
config2 = AstrBotConfig(
config_path=temp_config_path, default_config=minimal_default_config
)
assert config2.platform_settings["unique_session"] is True
class TestConfigSchemaToDefault:
"""Tests for schema to default config conversion."""
def test_convert_schema_with_defaults(self, temp_config_path):
"""Test converting schema with explicit defaults."""
schema = {
"string_field": {"type": "string", "default": "custom"},
"int_field": {"type": "int", "default": 100},
"bool_field": {"type": "bool", "default": True},
}
config = AstrBotConfig(config_path=temp_config_path, schema=schema)
assert config.string_field == "custom"
assert config.int_field == 100
assert config.bool_field is True
def test_convert_schema_without_defaults(self, temp_config_path):
"""Test converting schema using default value map."""
schema = {
"string_field": {"type": "string"},
"int_field": {"type": "int"},
"bool_field": {"type": "bool"},
}
config = AstrBotConfig(config_path=temp_config_path, schema=schema)
assert config.string_field == DEFAULT_VALUE_MAP["string"]
assert config.int_field == DEFAULT_VALUE_MAP["int"]
assert config.bool_field == DEFAULT_VALUE_MAP["bool"]
def test_unsupported_schema_type_raises_error(self, temp_config_path):
"""Test that unsupported schema types raise error."""
schema = {
"field": {"type": "unsupported_type"},
}
with pytest.raises(TypeError, match="不受支持的配置类型"):
AstrBotConfig(config_path=temp_config_path, schema=schema)
def test_template_list_type(self, temp_config_path):
"""Test template_list schema type."""
schema = {
"templates": {"type": "template_list", "default": []},
}
config = AstrBotConfig(config_path=temp_config_path, schema=schema)
assert config.templates == []
def test_nested_object_schema(self, temp_config_path):
"""Test nested object schema conversion."""
schema = {
"nested": {
"type": "object",
"items": {
"field1": {"type": "string"},
"field2": {"type": "int"},
},
},
}
config = AstrBotConfig(config_path=temp_config_path, schema=schema)
assert config.nested["field1"] == ""
assert config.nested["field2"] == 0
class TestConfigMetadataI18n:
"""Tests for i18n utils."""
def test_get_i18n_key(self):
"""Test generating i18n key."""
key = ConfigMetadataI18n._get_i18n_key(
group="ai_group",
section="general",
field="enable",
attr="description",
)
assert key == "ai_group.general.enable.description"
def test_get_i18n_key_without_field(self):
"""Test generating i18n key without field."""
key = ConfigMetadataI18n._get_i18n_key(
group="ai_group",
section="general",
field="",
attr="description",
)
assert key == "ai_group.general.description"
def test_convert_to_i18n_keys_simple(self):
"""Test converting simple metadata to i18n keys."""
metadata = {
"ai_group": {
"name": "AI Settings",
"metadata": {
"general": {
"description": "General settings",
"items": {
"enable": {
"description": "Enable feature",
"type": "bool",
"default": True,
},
},
},
},
},
}
result = ConfigMetadataI18n.convert_to_i18n_keys(metadata)
assert result["ai_group"]["name"] == "ai_group.name"
assert (
result["ai_group"]["metadata"]["general"]["description"]
== "ai_group.general.description"
)
assert (
result["ai_group"]["metadata"]["general"]["items"]["enable"]["description"]
== "ai_group.general.enable.description"
)
def test_convert_to_i18n_keys_with_hint(self):
"""Test converting metadata with hint."""
metadata = {
"group": {
"metadata": {
"section": {
"hint": "This is a hint",
"items": {
"field": {
"hint": "Field hint",
"type": "string",
},
},
},
},
},
}
result = ConfigMetadataI18n.convert_to_i18n_keys(metadata)
assert result["group"]["metadata"]["section"]["hint"] == "group.section.hint"
assert (
result["group"]["metadata"]["section"]["items"]["field"]["hint"]
== "group.section.field.hint"
)
def test_convert_to_i18n_keys_with_labels(self):
"""Test converting metadata with labels."""
metadata = {
"group": {
"metadata": {
"section": {
"items": {
"field": {
"labels": ["Label1", "Label2"],
"type": "string",
},
},
},
},
},
}
result = ConfigMetadataI18n.convert_to_i18n_keys(metadata)
assert (
result["group"]["metadata"]["section"]["items"]["field"]["labels"]
== "group.section.field.labels"
)
def test_convert_to_i18n_keys_nested_items(self):
"""Test converting metadata with nested items."""
metadata = {
"group": {
"metadata": {
"section": {
"items": {
"nested": {
"description": "Nested field",
"type": "object",
"items": {
"inner": {
"description": "Inner field",
"type": "string",
},
},
},
},
},
},
},
}
result = ConfigMetadataI18n.convert_to_i18n_keys(metadata)
assert (
result["group"]["metadata"]["section"]["items"]["nested"]["description"]
== "group.section.nested.description"
)
assert (
result["group"]["metadata"]["section"]["items"]["nested"]["items"]["inner"][
"description"
]
== "group.section.nested.inner.description"
)
def test_convert_to_i18n_keys_preserves_non_i18n_fields(self):
"""Test that non-i18n fields are preserved."""
metadata = {
"group": {
"metadata": {
"section": {
"items": {
"field": {
"description": "Field description",
"type": "string",
"other_field": "preserve this",
},
},
},
},
},
}
result = ConfigMetadataI18n.convert_to_i18n_keys(metadata)
assert (
result["group"]["metadata"]["section"]["items"]["field"]["other_field"]
== "preserve this"
)
def test_convert_to_i18n_keys_with_name(self):
"""Test converting metadata with name field."""
metadata = {
"group": {
"metadata": {
"section": {
"items": {
"field": {
"name": "Field Name",
"type": "string",
},
},
},
},
},
}
result = ConfigMetadataI18n.convert_to_i18n_keys(metadata)
assert (
result["group"]["metadata"]["section"]["items"]["field"]["name"]
== "group.section.field.name"
)
+504
View File
@@ -0,0 +1,504 @@
"""Tests for CronJobManager."""
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from astrbot.core.cron.manager import CronJobManager
from astrbot.core.db.po import CronJob
@pytest.fixture
def mock_db():
"""Create a mock database."""
db = MagicMock()
db.create_cron_job = AsyncMock()
db.get_cron_job = AsyncMock()
db.update_cron_job = AsyncMock()
db.delete_cron_job = AsyncMock()
db.list_cron_jobs = AsyncMock(return_value=[])
return db
@pytest.fixture
def mock_context():
"""Create a mock Context."""
ctx = MagicMock()
ctx.get_config = MagicMock(return_value={"admins_id": []})
ctx.conversation_manager = MagicMock()
return ctx
@pytest.fixture
def cron_manager(mock_db):
"""Create a CronJobManager instance."""
return CronJobManager(mock_db)
@pytest.fixture
def sample_cron_job():
"""Create a sample CronJob."""
return CronJob(
job_id="test-job-id",
name="Test Job",
job_type="basic",
cron_expression="0 9 * * *",
timezone="UTC",
payload={"key": "value"},
description="A test job",
enabled=True,
persistent=True,
run_once=False,
status="pending",
)
class TestCronJobManagerInit:
"""Tests for CronJobManager initialization."""
def test_init(self, mock_db):
"""Test CronJobManager initialization."""
manager = CronJobManager(mock_db)
assert manager.db == mock_db
assert manager._basic_handlers == {}
assert manager._started is False
class TestCronJobManagerStart:
"""Tests for CronJobManager.start method."""
@pytest.mark.asyncio
async def test_start(self, cron_manager, mock_db, mock_context):
"""Test starting the cron manager."""
mock_db.list_cron_jobs.return_value = []
await cron_manager.start(mock_context)
assert cron_manager._started is True
assert cron_manager.ctx == mock_context
@pytest.mark.asyncio
async def test_start_idempotent(self, cron_manager, mock_db, mock_context):
"""Test that start is idempotent."""
mock_db.list_cron_jobs.return_value = []
await cron_manager.start(mock_context)
await cron_manager.start(mock_context)
# Should only sync once
assert mock_db.list_cron_jobs.call_count == 1
class TestCronJobManagerShutdown:
"""Tests for CronJobManager.shutdown method."""
@pytest.mark.asyncio
async def test_shutdown(self, cron_manager, mock_db, mock_context):
"""Test shutting down the cron manager."""
mock_db.list_cron_jobs.return_value = []
await cron_manager.start(mock_context)
await cron_manager.shutdown()
assert cron_manager._started is False
@pytest.mark.asyncio
async def test_shutdown_when_not_started(self, cron_manager):
"""Test shutdown when not started."""
# Should not raise
await cron_manager.shutdown()
class TestAddBasicJob:
"""Tests for add_basic_job method."""
@pytest.mark.asyncio
async def test_add_basic_job(self, cron_manager, mock_db, sample_cron_job):
"""Test adding a basic cron job."""
mock_db.create_cron_job.return_value = sample_cron_job
handler = MagicMock()
result = await cron_manager.add_basic_job(
name="Test Job",
cron_expression="0 9 * * *",
handler=handler,
description="A test job",
enabled=True,
)
assert result == sample_cron_job
assert sample_cron_job.job_id in cron_manager._basic_handlers
mock_db.create_cron_job.assert_called_once()
@pytest.mark.asyncio
async def test_add_basic_job_disabled(self, cron_manager, mock_db, sample_cron_job):
"""Test adding a disabled basic cron job."""
sample_cron_job.enabled = False
mock_db.create_cron_job.return_value = sample_cron_job
handler = MagicMock()
result = await cron_manager.add_basic_job(
name="Test Job",
cron_expression="0 9 * * *",
handler=handler,
enabled=False,
)
assert result == sample_cron_job
assert sample_cron_job.job_id in cron_manager._basic_handlers
@pytest.mark.asyncio
async def test_add_basic_job_with_timezone(self, cron_manager, mock_db, sample_cron_job):
"""Test adding a basic job with timezone."""
mock_db.create_cron_job.return_value = sample_cron_job
handler = MagicMock()
await cron_manager.add_basic_job(
name="Test Job",
cron_expression="0 9 * * *",
handler=handler,
timezone="Asia/Shanghai",
)
mock_db.create_cron_job.assert_called_once()
call_kwargs = mock_db.create_cron_job.call_args.kwargs
assert call_kwargs["timezone"] == "Asia/Shanghai"
class TestAddActiveJob:
"""Tests for add_active_job method."""
@pytest.mark.asyncio
async def test_add_active_job(self, cron_manager, mock_db, sample_cron_job):
"""Test adding an active agent cron job."""
sample_cron_job.job_type = "active_agent"
mock_db.create_cron_job.return_value = sample_cron_job
result = await cron_manager.add_active_job(
name="Test Active Job",
cron_expression="0 9 * * *",
payload={"session": "test:group:123"},
)
assert result == sample_cron_job
mock_db.create_cron_job.assert_called_once()
@pytest.mark.asyncio
async def test_add_active_job_run_once(self, cron_manager, mock_db, sample_cron_job):
"""Test adding a run-once active job."""
sample_cron_job.job_type = "active_agent"
sample_cron_job.run_once = True
mock_db.create_cron_job.return_value = sample_cron_job
run_at = datetime.now(timezone.utc) + timedelta(days=30)
result = await cron_manager.add_active_job(
name="Test Run Once Job",
cron_expression=None,
payload={"session": "test:group:123"},
run_once=True,
run_at=run_at,
)
assert result == sample_cron_job
call_kwargs = mock_db.create_cron_job.call_args.kwargs
assert call_kwargs["run_once"] is True
class TestUpdateJob:
"""Tests for update_job method."""
@pytest.mark.asyncio
async def test_update_job(self, cron_manager, mock_db, sample_cron_job):
"""Test updating a cron job."""
updated_job = CronJob(
job_id="test-job-id",
name="Updated Job",
job_type="basic",
cron_expression="0 10 * * *",
enabled=False, # Disabled to avoid scheduling
)
mock_db.update_cron_job.return_value = updated_job
result = await cron_manager.update_job("test-job-id", name="Updated Job")
assert result == updated_job
mock_db.update_cron_job.assert_called()
@pytest.mark.asyncio
async def test_update_job_not_found(self, cron_manager, mock_db):
"""Test updating a non-existent job."""
mock_db.update_cron_job.return_value = None
result = await cron_manager.update_job("non-existent", name="Updated")
assert result is None
class TestDeleteJob:
"""Tests for delete_job method."""
@pytest.mark.asyncio
async def test_delete_job(self, cron_manager, mock_db):
"""Test deleting a cron job."""
cron_manager._basic_handlers["test-job-id"] = MagicMock()
await cron_manager.delete_job("test-job-id")
mock_db.delete_cron_job.assert_called_once_with("test-job-id")
assert "test-job-id" not in cron_manager._basic_handlers
class TestListJobs:
"""Tests for list_jobs method."""
@pytest.mark.asyncio
async def test_list_all_jobs(self, cron_manager, mock_db, sample_cron_job):
"""Test listing all jobs."""
mock_db.list_cron_jobs.return_value = [sample_cron_job]
result = await cron_manager.list_jobs()
assert len(result) == 1
mock_db.list_cron_jobs.assert_called_once_with(None)
@pytest.mark.asyncio
async def test_list_jobs_by_type(self, cron_manager, mock_db, sample_cron_job):
"""Test listing jobs by type."""
mock_db.list_cron_jobs.return_value = [sample_cron_job]
result = await cron_manager.list_jobs(job_type="basic")
assert len(result) == 1
mock_db.list_cron_jobs.assert_called_once_with("basic")
class TestSyncFromDb:
"""Tests for sync_from_db method."""
@pytest.mark.asyncio
async def test_sync_from_db_empty(self, cron_manager, mock_db):
"""Test syncing from empty database."""
mock_db.list_cron_jobs.return_value = []
await cron_manager.sync_from_db()
mock_db.list_cron_jobs.assert_called_once()
@pytest.mark.asyncio
async def test_sync_from_db_skips_disabled(self, cron_manager, mock_db, sample_cron_job):
"""Test that sync skips disabled jobs."""
sample_cron_job.enabled = False
mock_db.list_cron_jobs.return_value = [sample_cron_job]
with patch.object(cron_manager, "_schedule_job") as mock_schedule:
await cron_manager.sync_from_db()
mock_db.list_cron_jobs.assert_called_once()
mock_schedule.assert_not_called()
@pytest.mark.asyncio
async def test_sync_from_db_skips_non_persistent(self, cron_manager, mock_db, sample_cron_job):
"""Test that sync skips non-persistent jobs."""
sample_cron_job.persistent = False
mock_db.list_cron_jobs.return_value = [sample_cron_job]
with patch.object(cron_manager, "_schedule_job") as mock_schedule:
await cron_manager.sync_from_db()
mock_db.list_cron_jobs.assert_called_once()
mock_schedule.assert_not_called()
@pytest.mark.asyncio
async def test_sync_from_db_basic_without_handler(
self, cron_manager, mock_db, sample_cron_job
):
"""Test that sync warns for basic jobs without handlers."""
mock_db.list_cron_jobs.return_value = [sample_cron_job]
with patch("astrbot.core.cron.manager.logger") as mock_logger:
await cron_manager.sync_from_db()
mock_logger.warning.assert_called()
class TestRemoveScheduled:
"""Tests for _remove_scheduled method."""
@pytest.mark.asyncio
async def test_remove_scheduled_existing(self, cron_manager, mock_context):
"""Test removing a scheduled job."""
# Start the scheduler first
job = CronJob(
job_id="test-job-id",
name="Test",
job_type="active_agent",
cron_expression="0 9 * * *",
enabled=True,
persistent=True,
)
mock_db = cron_manager.db
mock_db.list_cron_jobs = AsyncMock(return_value=[job])
await cron_manager.start(mock_context)
# Then remove it
cron_manager._remove_scheduled("test-job-id")
# Should not raise
def test_remove_scheduled_nonexistent(self, cron_manager):
"""Test removing a non-existent job."""
# Should not raise
cron_manager._remove_scheduled("non-existent")
class TestScheduleJob:
"""Tests for _schedule_job method."""
@pytest.mark.asyncio
async def test_schedule_job_basic(self, cron_manager, sample_cron_job, mock_context):
"""Test scheduling a basic job."""
mock_db = cron_manager.db
mock_db.list_cron_jobs = AsyncMock(return_value=[])
mock_db.update_cron_job = AsyncMock()
await cron_manager.start(mock_context)
cron_manager._schedule_job(sample_cron_job)
# Verify job was added to scheduler
assert cron_manager.scheduler.get_job("test-job-id") is not None
@pytest.mark.asyncio
async def test_schedule_job_with_timezone(self, cron_manager, sample_cron_job, mock_context):
"""Test scheduling a job with timezone."""
sample_cron_job.timezone = "America/New_York"
mock_db = cron_manager.db
mock_db.list_cron_jobs = AsyncMock(return_value=[])
mock_db.update_cron_job = AsyncMock()
await cron_manager.start(mock_context)
cron_manager._schedule_job(sample_cron_job)
assert cron_manager.scheduler.get_job("test-job-id") is not None
@pytest.mark.asyncio
async def test_schedule_job_invalid_timezone(self, cron_manager, sample_cron_job, mock_context):
"""Test scheduling a job with invalid timezone."""
sample_cron_job.timezone = "Invalid/Timezone"
mock_db = cron_manager.db
mock_db.list_cron_jobs = AsyncMock(return_value=[])
mock_db.update_cron_job = AsyncMock()
with patch("astrbot.core.cron.manager.logger") as mock_logger:
await cron_manager.start(mock_context)
cron_manager._schedule_job(sample_cron_job)
# Should still schedule with system timezone
assert cron_manager.scheduler.get_job("test-job-id") is not None
mock_logger.warning.assert_called()
@pytest.mark.asyncio
async def test_schedule_job_run_once(self, cron_manager, mock_context):
"""Test scheduling a run-once job."""
future_date = datetime.now(timezone.utc) + timedelta(days=30)
job = CronJob(
job_id="run-once-job",
name="Run Once",
job_type="active_agent",
cron_expression=None,
enabled=True,
run_once=True,
payload={"run_at": future_date.isoformat()},
)
mock_db = cron_manager.db
mock_db.list_cron_jobs = AsyncMock(return_value=[])
mock_db.update_cron_job = AsyncMock()
await cron_manager.start(mock_context)
cron_manager._schedule_job(job)
assert cron_manager.scheduler.get_job("run-once-job") is not None
class TestRunJob:
"""Tests for _run_job method."""
@pytest.mark.asyncio
async def test_run_job_disabled(self, cron_manager, mock_db, sample_cron_job):
"""Test running a disabled job."""
sample_cron_job.enabled = False
mock_db.get_cron_job.return_value = sample_cron_job
await cron_manager._run_job("test-job-id")
# Should not update status
mock_db.update_cron_job.assert_not_called()
@pytest.mark.asyncio
async def test_run_job_not_found(self, cron_manager, mock_db):
"""Test running a non-existent job."""
mock_db.get_cron_job.return_value = None
await cron_manager._run_job("non-existent")
# Should not update status
mock_db.update_cron_job.assert_not_called()
class TestRunBasicJob:
"""Tests for _run_basic_job method."""
@pytest.mark.asyncio
async def test_run_basic_job_sync_handler(self, cron_manager, sample_cron_job):
"""Test running a basic job with sync handler."""
handler = MagicMock(return_value=None)
cron_manager._basic_handlers["test-job-id"] = handler
sample_cron_job.payload = {"arg1": "value1"}
await cron_manager._run_basic_job(sample_cron_job)
handler.assert_called_once_with(arg1="value1")
@pytest.mark.asyncio
async def test_run_basic_job_async_handler(self, cron_manager, sample_cron_job):
"""Test running a basic job with async handler."""
async_handler = AsyncMock()
cron_manager._basic_handlers["test-job-id"] = async_handler
sample_cron_job.payload = {}
await cron_manager._run_basic_job(sample_cron_job)
async_handler.assert_called_once()
@pytest.mark.asyncio
async def test_run_basic_job_no_handler(self, cron_manager, sample_cron_job):
"""Test running a basic job without handler."""
sample_cron_job.job_id = "no-handler-job"
with pytest.raises(RuntimeError, match="handler not found"):
await cron_manager._run_basic_job(sample_cron_job)
class TestGetNextRunTime:
"""Tests for _get_next_run_time method."""
@pytest.mark.asyncio
async def test_get_next_run_time_existing_job(self, cron_manager, sample_cron_job, mock_context):
"""Test getting next run time for existing job."""
mock_db = cron_manager.db
mock_db.list_cron_jobs = AsyncMock(return_value=[])
mock_db.update_cron_job = AsyncMock()
await cron_manager.start(mock_context)
cron_manager._schedule_job(sample_cron_job)
next_run = cron_manager._get_next_run_time("test-job-id")
assert next_run is not None
def test_get_next_run_time_nonexistent(self, cron_manager):
"""Test getting next run time for non-existent job."""
next_run = cron_manager._get_next_run_time("non-existent")
assert next_run is None
+198
View File
@@ -0,0 +1,198 @@
"""Tests for astrbot.core.star.base module."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
class TestStarBase:
"""Test cases for the Star base class."""
def test_star_class_exists(self):
"""Test that Star class can be imported."""
from astrbot.core.star import Star
assert Star is not None
def test_star_init_with_context(self):
"""Test Star initialization with a context-like object."""
from astrbot.core.star import Star
# Create a mock context with get_config method
mock_context = MagicMock()
mock_context.get_config.return_value = MagicMock()
# Create a concrete Star subclass for testing
class TestStar(Star):
name = "test_star"
author = "test_author"
star = TestStar(context=mock_context)
assert star.context is mock_context
@pytest.mark.asyncio
async def test_text_to_image_with_config(self):
"""Test text_to_image method with valid config."""
from astrbot.core.star import Star
mock_context = MagicMock()
mock_config = MagicMock()
mock_config.get.return_value = "default_template"
mock_context.get_config.return_value = mock_config
class TestStar(Star):
name = "test_star"
author = "test_author"
star = TestStar(context=mock_context)
with patch(
"astrbot.core.star.base.html_renderer.render_t2i",
new_callable=AsyncMock,
) as mock_render:
mock_render.return_value = "http://example.com/image.png"
result = await star.text_to_image("test text", return_url=True)
mock_render.assert_called_once_with(
"test text",
return_url=True,
template_name="default_template",
)
assert result == "http://example.com/image.png"
@pytest.mark.asyncio
async def test_text_to_image_without_config(self):
"""Test text_to_image method when get_config returns None."""
from astrbot.core.star import Star
mock_context = MagicMock()
mock_context.get_config.return_value = None
class TestStar(Star):
name = "test_star"
author = "test_author"
star = TestStar(context=mock_context)
with patch(
"astrbot.core.star.base.html_renderer.render_t2i",
new_callable=AsyncMock,
) as mock_render:
mock_render.return_value = "http://example.com/image.png"
result = await star.text_to_image("test text", return_url=False)
mock_render.assert_called_once_with(
"test text",
return_url=False,
template_name=None,
)
assert result == "http://example.com/image.png"
@pytest.mark.asyncio
async def test_html_render(self):
"""Test html_render method."""
from astrbot.core.star import Star
mock_context = MagicMock()
class TestStar(Star):
name = "test_star"
author = "test_author"
star = TestStar(context=mock_context)
with patch(
"astrbot.core.star.base.html_renderer.render_custom_template",
new_callable=AsyncMock,
) as mock_render:
mock_render.return_value = "http://example.com/rendered.png"
result = await star.html_render(
"<html>{{ data }}</html>",
{"data": "test"},
return_url=True,
)
mock_render.assert_called_once_with(
"<html>{{ data }}</html>",
{"data": "test"},
return_url=True,
options=None,
)
assert result == "http://example.com/rendered.png"
@pytest.mark.asyncio
async def test_initialize_and_terminate(self):
"""Test that initialize and terminate methods can be overridden."""
from astrbot.core.star import Star
class TestStar(Star):
name = "test_star"
author = "test_author"
async def initialize(self) -> None:
self.initialized = True
async def terminate(self) -> None:
self.terminated = True
mock_context = MagicMock()
star = TestStar(context=mock_context)
await star.initialize()
assert star.initialized is True
await star.terminate()
assert star.terminated is True
def test_star_metadata_registration(self):
"""Test that Star subclass is automatically registered."""
from astrbot.core.star import star_map, star_registry
from astrbot.core.star.star import StarMetadata
# Clear any previous registration for this test module
module_path = __name__
class UniqueTestStar:
"""Not a Star subclass, should not be registered."""
pass
# Verify Star subclass gets registered
initial_count = len(star_registry)
# Note: This test verifies the __init_subclass__ mechanism
# The actual registration happens when a class inherits from Star
assert len(star_registry) >= initial_count
class TestNoCircularImports:
"""Test that there are no circular import issues."""
def test_import_star_module(self):
"""Test that star module can be imported without circular import errors."""
import astrbot.core.star
assert astrbot.core.star is not None
def test_import_pipeline_module(self):
"""Test that pipeline module can be imported without circular import errors."""
import astrbot.core.pipeline
assert astrbot.core.pipeline is not None
def test_import_both_modules(self):
"""Test that both modules can be imported together."""
import astrbot.core.pipeline
import astrbot.core.star
# Verify key exports are available
from astrbot.core.star import Context, Star, PluginManager
assert Context is not None
assert Star is not None
assert PluginManager is not None
def test_import_pipeline_context(self):
"""Test that PipelineContext can be imported."""
from astrbot.core.pipeline.context import PipelineContext
assert PipelineContext is not None