test: add tests for star base class and config management (#5356)
* test: add tests for star base class and config management - Add Star base class safety helper tests - Expand config management unit tests - Update cron manager tests Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * test: fix plugin_manager test isolation issues - Use local mock plugin instead of real network requests - Clear sys.modules cache for entire data module tree - Clear star_map and star_registry in teardown - Use pytest_asyncio.fixture for async fixture support Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * test: fix test isolation and compatibility issues - test_main.py: fix version comparison and path assertions for Windows - test_smoke.py: add missing apscheduler.triggers mock modules - test_tool_loop_agent_runner.py: update assertion for new interrupt behavior - test_api_key_open_api.py: use unique session IDs to avoid test conflicts Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * test: add unit tests for _version_info comparisons * test: enhance plugin manager tests with mock implementations and improved assertions * test: add mock plugin builder and updater for plugin management tests * fix: resolve pipeline and star import cycles (#5353) * fix: resolve pipeline and star import cycles - Add bootstrap.py and stage_order.py to break circular dependencies - Export Context, PluginManager, StarTools from star module - Update pipeline __init__ to defer imports - Split pipeline initialization into separate bootstrap module Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: add logging for get_config() failure in Star class * fix: reorder logger initialization in base.py --------- Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> * test: update cron job scheduling tests and refactor star base tests for clarity * test: expand star base tests for comprehensive coverage - Add tests for Star class initialization and context handling - Add tests for text_to_image with/without config - Add tests for html_render method - Add tests for initialize/terminate lifecycle methods - Add type hint validation tests for Context - Add circular import prevention tests Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: address PR review feedback - use TYPE_CHECKING instead of Any - pipeline/context.py: Use TYPE_CHECKING to import PluginManager instead of Any - pipeline/__init__.py: Add TYPE_CHECKING imports for __all__ exports to satisfy static analyzers - star/register/star_handler.py: Use TYPE_CHECKING to import AstrAgentContext instead of Any - tests: Remove invalid type hint tests that tested incorrect assumptions Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: improve TYPE_CHECKING pattern for circular import resolution - star/register/star_handler.py: Use AstrAgentContext instead of Any in generic types - star/context.py: Remove unnecessary else branch with CronJobManager = Any (with __future__ annotations, TYPE_CHECKING imports are sufficient) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
This commit is contained in:
@@ -67,6 +67,18 @@ _LAZY_EXPORTS = {
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Type-checking imports to satisfy static analyzers for __all__ exports
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||||
|
from .preprocess_stage.stage import PreProcessStage
|
||||||
|
from .process_stage.stage import ProcessStage
|
||||||
|
from .rate_limit_check.stage import RateLimitStage
|
||||||
|
from .respond.stage import RespondStage
|
||||||
|
from .result_decorate.stage import ResultDecorateStage
|
||||||
|
from .session_status_check.stage import SessionStatusCheckStage
|
||||||
|
from .waking_check.stage import WakingCheckStage
|
||||||
|
from .whitelist_check.stage import WhitelistCheckStage
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ContentSafetyCheckStage",
|
"ContentSafetyCheckStage",
|
||||||
"EventResultType",
|
"EventResultType",
|
||||||
|
|||||||
@@ -1,19 +1,22 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from astrbot.core.config import AstrBotConfig
|
from astrbot.core.config import AstrBotConfig
|
||||||
|
|
||||||
from .context_utils import call_event_hook, call_handler
|
from .context_utils import call_event_hook, call_handler
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from astrbot.core.star import PluginManager
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PipelineContext:
|
class PipelineContext:
|
||||||
"""上下文对象,包含管道执行所需的上下文信息"""
|
"""上下文对象,包含管道执行所需的上下文信息"""
|
||||||
|
|
||||||
astrbot_config: AstrBotConfig # AstrBot 配置对象
|
astrbot_config: AstrBotConfig # AstrBot 配置对象
|
||||||
plugin_manager: Any # 插件管理器对象
|
plugin_manager: PluginManager # 插件管理器对象
|
||||||
astrbot_config_id: str
|
astrbot_config_id: str
|
||||||
call_handler = call_handler
|
call_handler = call_handler
|
||||||
call_event_hook = call_event_hook
|
call_event_hook = call_event_hook
|
||||||
|
|||||||
@@ -47,8 +47,6 @@ logger = logging.getLogger("astrbot")
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from astrbot.core.cron.manager import CronJobManager
|
from astrbot.core.cron.manager import CronJobManager
|
||||||
else:
|
|
||||||
CronJobManager = Any
|
|
||||||
|
|
||||||
|
|
||||||
class PlatformManagerProtocol(Protocol):
|
class PlatformManagerProtocol(Protocol):
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import docstring_parser
|
import docstring_parser
|
||||||
|
|
||||||
@@ -15,6 +15,9 @@ from astrbot.core.message.message_event_result import MessageEventResult
|
|||||||
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
|
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
|
||||||
from astrbot.core.provider.register import llm_tools
|
from astrbot.core.provider.register import llm_tools
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||||
|
|
||||||
from ..filter.command import CommandFilter
|
from ..filter.command import CommandFilter
|
||||||
from ..filter.command_group import CommandGroupFilter
|
from ..filter.command_group import CommandGroupFilter
|
||||||
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
|
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
|
||||||
@@ -616,7 +619,7 @@ class RegisteringAgent:
|
|||||||
kwargs["registering_agent"] = self
|
kwargs["registering_agent"] = self
|
||||||
return register_llm_tool(*args, **kwargs)
|
return register_llm_tool(*args, **kwargs)
|
||||||
|
|
||||||
def __init__(self, agent: Agent[Any]) -> None:
|
def __init__(self, agent: Agent[AstrAgentContext]) -> None:
|
||||||
self._agent = agent
|
self._agent = agent
|
||||||
|
|
||||||
|
|
||||||
@@ -624,7 +627,7 @@ def register_agent(
|
|||||||
name: str,
|
name: str,
|
||||||
instruction: str,
|
instruction: str,
|
||||||
tools: list[str | FunctionTool] | None = None,
|
tools: list[str | FunctionTool] | None = None,
|
||||||
run_hooks: BaseAgentRunHooks[Any] | None = None,
|
run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
|
||||||
):
|
):
|
||||||
"""注册一个 Agent
|
"""注册一个 Agent
|
||||||
|
|
||||||
@@ -638,12 +641,12 @@ def register_agent(
|
|||||||
tools_ = tools or []
|
tools_ = tools or []
|
||||||
|
|
||||||
def decorator(awaitable: Callable[..., Awaitable[Any]]):
|
def decorator(awaitable: Callable[..., Awaitable[Any]]):
|
||||||
AstrAgent = Agent[Any]
|
AstrAgent = Agent[AstrAgentContext]
|
||||||
agent = AstrAgent(
|
agent = AstrAgent(
|
||||||
name=name,
|
name=name,
|
||||||
instructions=instruction,
|
instructions=instruction,
|
||||||
tools=tools_,
|
tools=tools_,
|
||||||
run_hooks=run_hooks or BaseAgentRunHooks[Any](),
|
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
|
||||||
)
|
)
|
||||||
handoff_tool = HandoffTool(agent=agent)
|
handoff_tool = HandoffTool(agent=agent)
|
||||||
handoff_tool.handler = awaitable
|
handoff_tool.handler = awaitable
|
||||||
|
|||||||
Vendored
+256
-1
@@ -3,7 +3,10 @@
|
|||||||
提供统一的测试辅助工具,减少测试代码重复。
|
提供统一的测试辅助工具,减少测试代码重复。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any
|
import shutil
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
from astrbot.core.message.components import BaseMessageComponent
|
from astrbot.core.message.components import BaseMessageComponent
|
||||||
@@ -330,3 +333,255 @@ def create_mock_llm_response(
|
|||||||
tools_call_ids=tools_call_ids or [],
|
tools_call_ids=tools_call_ids or [],
|
||||||
usage=TokenUsage(input_other=10, output=5),
|
usage=TokenUsage(input_other=10, output=5),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 测试插件辅助函数
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockPluginConfig:
|
||||||
|
"""测试插件配置。
|
||||||
|
|
||||||
|
用于创建和管理测试用的模拟插件。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name: 插件名称
|
||||||
|
author: 作者
|
||||||
|
description: 描述
|
||||||
|
version: 版本
|
||||||
|
repo: 仓库 URL
|
||||||
|
main_code: main.py 的代码内容
|
||||||
|
requirements: 依赖列表
|
||||||
|
has_readme: 是否创建 README.md
|
||||||
|
readme_content: README.md 内容
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "test_plugin"
|
||||||
|
author: str = "Test Author"
|
||||||
|
description: str = "A test plugin for unit testing"
|
||||||
|
version: str = "1.0.0"
|
||||||
|
repo: str = "https://github.com/test/test_plugin"
|
||||||
|
main_code: str = ""
|
||||||
|
requirements: list[str] = field(default_factory=list)
|
||||||
|
has_readme: bool = True
|
||||||
|
readme_content: str = "# Test Plugin\n\nThis is a test plugin."
|
||||||
|
|
||||||
|
|
||||||
|
# 默认的插件主代码模板
|
||||||
|
DEFAULT_PLUGIN_MAIN_TEMPLATE = '''
|
||||||
|
from astrbot.api import star
|
||||||
|
|
||||||
|
class Main(star.Star):
|
||||||
|
"""测试插件主类。"""
|
||||||
|
|
||||||
|
def __init__(self, context):
|
||||||
|
super().__init__(context)
|
||||||
|
self.name = "{plugin_name}"
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""初始化插件。"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
"""终止插件。"""
|
||||||
|
pass
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
class MockPluginBuilder:
|
||||||
|
"""测试插件构建器。
|
||||||
|
|
||||||
|
用于创建、管理和清理测试用的模拟插件。支持任意插件的模拟创建。
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# 创建一个简单的测试插件
|
||||||
|
builder = MockPluginBuilder(plugin_store_path)
|
||||||
|
plugin_dir = builder.create("my_test_plugin")
|
||||||
|
|
||||||
|
# 创建自定义配置的插件
|
||||||
|
config = MockPluginConfig(
|
||||||
|
name="custom_plugin",
|
||||||
|
version="2.0.0",
|
||||||
|
main_code="print('hello')",
|
||||||
|
)
|
||||||
|
plugin_dir = builder.create(config)
|
||||||
|
|
||||||
|
# 清理插件
|
||||||
|
builder.cleanup("my_test_plugin")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, plugin_store_path: str | Path):
|
||||||
|
"""初始化构建器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_store_path: 插件存储路径 (通常是 data/plugins)
|
||||||
|
"""
|
||||||
|
self.plugin_store_path = Path(plugin_store_path)
|
||||||
|
self._created_plugins: set[str] = set()
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
plugin_config: str | MockPluginConfig | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Path:
|
||||||
|
"""创建模拟插件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_config: 插件名称字符串、MockPluginConfig 对象或 None
|
||||||
|
**kwargs: 如果 plugin_config 是字符串或 None,这些参数用于构建 MockPluginConfig
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: 创建的插件目录路径
|
||||||
|
"""
|
||||||
|
# 处理不同类型的输入
|
||||||
|
if plugin_config is None:
|
||||||
|
config = MockPluginConfig(**kwargs)
|
||||||
|
elif isinstance(plugin_config, str):
|
||||||
|
config = MockPluginConfig(name=plugin_config, **kwargs)
|
||||||
|
elif isinstance(plugin_config, MockPluginConfig):
|
||||||
|
config = plugin_config
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Invalid plugin_config type: {type(plugin_config)}")
|
||||||
|
|
||||||
|
# 创建插件目录
|
||||||
|
plugin_dir = self.plugin_store_path / config.name
|
||||||
|
plugin_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 创建 metadata.yaml
|
||||||
|
metadata_content = "\n".join(
|
||||||
|
[
|
||||||
|
f"name: {config.name}",
|
||||||
|
f"author: {config.author}",
|
||||||
|
f"desc: {config.description}",
|
||||||
|
f"version: {config.version}",
|
||||||
|
f"repo: {config.repo}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
(plugin_dir / "metadata.yaml").write_text(
|
||||||
|
metadata_content + "\n", encoding="utf-8"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建 main.py
|
||||||
|
main_code = config.main_code or DEFAULT_PLUGIN_MAIN_TEMPLATE.format(
|
||||||
|
plugin_name=config.name
|
||||||
|
)
|
||||||
|
(plugin_dir / "main.py").write_text(main_code, encoding="utf-8")
|
||||||
|
|
||||||
|
# 创建 requirements.txt(如果有依赖)
|
||||||
|
if config.requirements:
|
||||||
|
(plugin_dir / "requirements.txt").write_text(
|
||||||
|
"\n".join(config.requirements) + "\n", encoding="utf-8"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建 README.md(如果需要)
|
||||||
|
if config.has_readme:
|
||||||
|
(plugin_dir / "README.md").write_text(
|
||||||
|
config.readme_content, encoding="utf-8"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 记录创建的插件
|
||||||
|
self._created_plugins.add(config.name)
|
||||||
|
|
||||||
|
return plugin_dir
|
||||||
|
|
||||||
|
def cleanup(self, plugin_name: str | None = None) -> None:
|
||||||
|
"""清理插件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_name: 要清理的插件名称,如果为 None 则清理所有由本构建器创建的插件
|
||||||
|
"""
|
||||||
|
if plugin_name:
|
||||||
|
plugins_to_clean = {plugin_name}
|
||||||
|
else:
|
||||||
|
plugins_to_clean = self._created_plugins.copy()
|
||||||
|
|
||||||
|
for name in plugins_to_clean:
|
||||||
|
plugin_dir = self.plugin_store_path / name
|
||||||
|
if plugin_dir.exists():
|
||||||
|
shutil.rmtree(plugin_dir)
|
||||||
|
self._created_plugins.discard(name)
|
||||||
|
|
||||||
|
def cleanup_all(self) -> None:
|
||||||
|
"""清理所有由本构建器创建的插件。"""
|
||||||
|
self.cleanup(None)
|
||||||
|
|
||||||
|
def get_plugin_path(self, plugin_name: str) -> Path:
|
||||||
|
"""获取插件路径。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_name: 插件名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: 插件目录路径
|
||||||
|
"""
|
||||||
|
return self.plugin_store_path / plugin_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def created_plugins(self) -> set[str]:
|
||||||
|
"""获取已创建的插件名称集合。"""
|
||||||
|
return self._created_plugins.copy()
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_updater_install(
|
||||||
|
plugin_builder: MockPluginBuilder,
|
||||||
|
repo_to_plugin: dict[str, str] | None = None,
|
||||||
|
) -> Callable:
|
||||||
|
"""创建模拟的 updater.install 方法。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_builder: MockPluginBuilder 实例
|
||||||
|
repo_to_plugin: 仓库 URL 到插件名称的映射,格式: {"https://github.com/user/repo": "plugin_name"}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable: 异步函数,可用于 monkeypatch.setattr
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def mock_install(repo_url: str, proxy: str = "") -> str:
|
||||||
|
"""Mock updater.install 方法。"""
|
||||||
|
# 查找插件名称
|
||||||
|
plugin_name = None
|
||||||
|
if repo_to_plugin:
|
||||||
|
plugin_name = repo_to_plugin.get(repo_url)
|
||||||
|
|
||||||
|
# 如果没有映射,尝试从 URL 提取插件名
|
||||||
|
if not plugin_name:
|
||||||
|
# 从 https://github.com/user/plugin_name 提取 plugin_name
|
||||||
|
parts = repo_url.rstrip("/").split("/")
|
||||||
|
plugin_name = parts[-1] if parts else "unknown_plugin"
|
||||||
|
|
||||||
|
# 创建插件目录
|
||||||
|
config = MockPluginConfig(name=plugin_name, repo=repo_url)
|
||||||
|
plugin_dir = plugin_builder.create(config)
|
||||||
|
return str(plugin_dir)
|
||||||
|
|
||||||
|
return mock_install
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_updater_update(
|
||||||
|
plugin_builder: MockPluginBuilder,
|
||||||
|
update_callback: Callable | None = None,
|
||||||
|
) -> Callable:
|
||||||
|
"""创建模拟的 updater.update 方法。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_builder: MockPluginBuilder 实例
|
||||||
|
update_callback: 更新回调函数,接收 plugin 参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable: 异步函数,可用于 monkeypatch.setattr
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def mock_update(plugin, proxy: str = "") -> None:
|
||||||
|
"""Mock updater.update 方法。"""
|
||||||
|
plugin_dir = plugin_builder.get_plugin_path(plugin.name)
|
||||||
|
|
||||||
|
# 创建更新标记文件
|
||||||
|
(plugin_dir / ".updated").write_text("ok", encoding="utf-8")
|
||||||
|
|
||||||
|
# 调用回调
|
||||||
|
if update_callback:
|
||||||
|
update_callback(plugin)
|
||||||
|
|
||||||
|
return mock_update
|
||||||
|
|||||||
@@ -203,7 +203,7 @@ async def test_open_chat_send_auto_session_id_and_username(
|
|||||||
"/api/v1/chat",
|
"/api/v1/chat",
|
||||||
json={
|
json={
|
||||||
"message": "hello",
|
"message": "hello",
|
||||||
"username": "alice",
|
"username": "alice_auto_session",
|
||||||
"enable_streaming": False,
|
"enable_streaming": False,
|
||||||
},
|
},
|
||||||
headers={"X-API-Key": raw_key},
|
headers={"X-API-Key": raw_key},
|
||||||
@@ -217,16 +217,16 @@ async def test_open_chat_send_auto_session_id_and_username(
|
|||||||
created_session_id = send_data["data"]["session_id"]
|
created_session_id = send_data["data"]["session_id"]
|
||||||
assert isinstance(created_session_id, str)
|
assert isinstance(created_session_id, str)
|
||||||
uuid.UUID(created_session_id)
|
uuid.UUID(created_session_id)
|
||||||
assert send_data["data"]["creator"] == "alice"
|
assert send_data["data"]["creator"] == "alice_auto_session"
|
||||||
created_session = await core_lifecycle_td.db.get_platform_session_by_id(
|
created_session = await core_lifecycle_td.db.get_platform_session_by_id(
|
||||||
created_session_id
|
created_session_id
|
||||||
)
|
)
|
||||||
assert created_session is not None
|
assert created_session is not None
|
||||||
assert created_session.creator == "alice"
|
assert created_session.creator == "alice_auto_session"
|
||||||
assert created_session.platform_id == "webchat"
|
assert created_session.platform_id == "webchat"
|
||||||
|
|
||||||
await core_lifecycle_td.db.create_platform_session(
|
await core_lifecycle_td.db.create_platform_session(
|
||||||
creator="bob",
|
creator="bob_auto_session",
|
||||||
platform_id="webchat",
|
platform_id="webchat",
|
||||||
session_id="open_api_existing_bob_session",
|
session_id="open_api_existing_bob_session",
|
||||||
is_group=0,
|
is_group=0,
|
||||||
|
|||||||
+115
-46
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
@@ -11,6 +12,12 @@ from astrbot.core.db.sqlite import SQLiteDatabase
|
|||||||
from astrbot.core.star.star import star_registry
|
from astrbot.core.star.star import star_registry
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry
|
from astrbot.core.star.star_handler import star_handlers_registry
|
||||||
from astrbot.dashboard.server import AstrBotDashboard
|
from astrbot.dashboard.server import AstrBotDashboard
|
||||||
|
from tests.fixtures.helpers import (
|
||||||
|
MockPluginBuilder,
|
||||||
|
MockPluginConfig,
|
||||||
|
create_mock_updater_install,
|
||||||
|
create_mock_updater_update,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="module")
|
@pytest_asyncio.fixture(scope="module")
|
||||||
@@ -94,8 +101,15 @@ async def test_get_stat(app: Quart, authenticated_header: dict):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_plugins(app: Quart, authenticated_header: dict):
|
async def test_plugins(
|
||||||
|
app: Quart,
|
||||||
|
authenticated_header: dict,
|
||||||
|
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||||
|
monkeypatch,
|
||||||
|
):
|
||||||
|
"""测试插件 API 端点,使用 Mock 避免真实网络调用。"""
|
||||||
test_client = app.test_client()
|
test_client = app.test_client()
|
||||||
|
|
||||||
# 已经安装的插件
|
# 已经安装的插件
|
||||||
response = await test_client.get("/api/plugin/get", headers=authenticated_header)
|
response = await test_client.get("/api/plugin/get", headers=authenticated_header)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -111,53 +125,79 @@ async def test_plugins(app: Quart, authenticated_header: dict):
|
|||||||
data = await response.get_json()
|
data = await response.get_json()
|
||||||
assert data["status"] == "ok"
|
assert data["status"] == "ok"
|
||||||
|
|
||||||
# 插件安装
|
# 使用 MockPluginBuilder 创建测试插件
|
||||||
response = await test_client.post(
|
plugin_store_path = core_lifecycle_td.plugin_manager.plugin_store_path
|
||||||
"/api/plugin/install",
|
builder = MockPluginBuilder(plugin_store_path)
|
||||||
json={"url": "https://github.com/Soulter/astrbot_plugin_essential"},
|
|
||||||
headers=authenticated_header,
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = await response.get_json()
|
|
||||||
assert data["status"] == "ok"
|
|
||||||
exists = False
|
|
||||||
for md in star_registry:
|
|
||||||
if md.name == "astrbot_plugin_essential":
|
|
||||||
exists = True
|
|
||||||
break
|
|
||||||
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
|
|
||||||
|
|
||||||
# 插件更新
|
# 定义测试插件
|
||||||
response = await test_client.post(
|
test_plugin_name = "test_mock_plugin"
|
||||||
"/api/plugin/update",
|
test_repo_url = f"https://github.com/test/{test_plugin_name}"
|
||||||
json={"name": "astrbot_plugin_essential"},
|
|
||||||
headers=authenticated_header,
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = await response.get_json()
|
|
||||||
assert data["status"] == "ok"
|
|
||||||
|
|
||||||
# 插件卸载
|
# 创建 Mock 函数
|
||||||
response = await test_client.post(
|
mock_install = create_mock_updater_install(
|
||||||
"/api/plugin/uninstall",
|
builder,
|
||||||
json={"name": "astrbot_plugin_essential"},
|
repo_to_plugin={test_repo_url: test_plugin_name},
|
||||||
headers=authenticated_header,
|
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
mock_update = create_mock_updater_update(builder)
|
||||||
data = await response.get_json()
|
|
||||||
assert data["status"] == "ok"
|
# 设置 Mock
|
||||||
exists = False
|
monkeypatch.setattr(
|
||||||
for md in star_registry:
|
core_lifecycle_td.plugin_manager.updator, "install", mock_install
|
||||||
if md.name == "astrbot_plugin_essential":
|
)
|
||||||
exists = True
|
monkeypatch.setattr(
|
||||||
break
|
core_lifecycle_td.plugin_manager.updator, "update", mock_update
|
||||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
)
|
||||||
exists = False
|
|
||||||
for md in star_handlers_registry:
|
try:
|
||||||
if "astrbot_plugin_essential" in md.handler_module_path:
|
# 插件安装
|
||||||
exists = True
|
response = await test_client.post(
|
||||||
break
|
"/api/plugin/install",
|
||||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
json={"url": test_repo_url},
|
||||||
|
headers=authenticated_header,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = await response.get_json()
|
||||||
|
assert data["status"] == "ok", f"安装失败: {data.get('message', 'unknown error')}"
|
||||||
|
|
||||||
|
# 验证插件已注册
|
||||||
|
exists = any(md.name == test_plugin_name for md in star_registry)
|
||||||
|
assert exists is True, f"插件 {test_plugin_name} 未成功载入"
|
||||||
|
|
||||||
|
# 插件更新
|
||||||
|
response = await test_client.post(
|
||||||
|
"/api/plugin/update",
|
||||||
|
json={"name": test_plugin_name},
|
||||||
|
headers=authenticated_header,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = await response.get_json()
|
||||||
|
assert data["status"] == "ok"
|
||||||
|
|
||||||
|
# 验证更新标记文件
|
||||||
|
plugin_dir = builder.get_plugin_path(test_plugin_name)
|
||||||
|
assert (plugin_dir / ".updated").exists()
|
||||||
|
|
||||||
|
# 插件卸载
|
||||||
|
response = await test_client.post(
|
||||||
|
"/api/plugin/uninstall",
|
||||||
|
json={"name": test_plugin_name},
|
||||||
|
headers=authenticated_header,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = await response.get_json()
|
||||||
|
assert data["status"] == "ok"
|
||||||
|
|
||||||
|
# 验证插件已卸载
|
||||||
|
exists = any(md.name == test_plugin_name for md in star_registry)
|
||||||
|
assert exists is False, f"插件 {test_plugin_name} 未成功卸载"
|
||||||
|
exists = any(
|
||||||
|
test_plugin_name in md.handler_module_path for md in star_handlers_registry
|
||||||
|
)
|
||||||
|
assert exists is False, f"插件 {test_plugin_name} handler 未成功清理"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 清理测试插件
|
||||||
|
builder.cleanup(test_plugin_name)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -189,12 +229,41 @@ async def test_commands_api(app: Quart, authenticated_header: dict):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_check_update(app: Quart, authenticated_header: dict):
|
async def test_check_update(
|
||||||
|
app: Quart,
|
||||||
|
authenticated_header: dict,
|
||||||
|
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||||
|
monkeypatch,
|
||||||
|
):
|
||||||
|
"""测试检查更新 API,使用 Mock 避免真实网络调用。"""
|
||||||
test_client = app.test_client()
|
test_client = app.test_client()
|
||||||
|
|
||||||
|
# Mock 更新检查和网络请求
|
||||||
|
async def mock_check_update(*args, **kwargs):
|
||||||
|
"""Mock 更新检查,返回无新版本。"""
|
||||||
|
return None # None 表示没有新版本
|
||||||
|
|
||||||
|
async def mock_get_dashboard_version(*args, **kwargs):
|
||||||
|
"""Mock Dashboard 版本获取。"""
|
||||||
|
from astrbot.core.config.default import VERSION
|
||||||
|
|
||||||
|
return f"v{VERSION}" # 返回当前版本
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
core_lifecycle_td.astrbot_updator,
|
||||||
|
"check_update",
|
||||||
|
mock_check_update,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.dashboard.routes.update.get_dashboard_version",
|
||||||
|
mock_get_dashboard_version,
|
||||||
|
)
|
||||||
|
|
||||||
response = await test_client.get("/api/update/check", headers=authenticated_header)
|
response = await test_client.get("/api/update/check", headers=authenticated_header)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = await response.get_json()
|
data = await response.get_json()
|
||||||
assert data["status"] == "success"
|
assert data["status"] == "success"
|
||||||
|
assert data["data"]["has_new_version"] is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
+49
-3
@@ -16,6 +16,16 @@ class _version_info:
|
|||||||
self.major = major
|
self.major = major
|
||||||
self.minor = minor
|
self.minor = minor
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if isinstance(other, tuple):
|
||||||
|
return (self.major, self.minor) == other[:2]
|
||||||
|
return (self.major, self.minor) == (other.major, other.minor)
|
||||||
|
|
||||||
|
def __ge__(self, other):
|
||||||
|
if isinstance(other, tuple):
|
||||||
|
return (self.major, self.minor) >= other[:2]
|
||||||
|
return (self.major, self.minor) >= (other.major, other.minor)
|
||||||
|
|
||||||
|
|
||||||
def test_check_env(monkeypatch):
|
def test_check_env(monkeypatch):
|
||||||
version_info_correct = _version_info(3, 10)
|
version_info_correct = _version_info(3, 10)
|
||||||
@@ -23,15 +33,51 @@ def test_check_env(monkeypatch):
|
|||||||
monkeypatch.setattr(sys, "version_info", version_info_correct)
|
monkeypatch.setattr(sys, "version_info", version_info_correct)
|
||||||
with mock.patch("os.makedirs") as mock_makedirs:
|
with mock.patch("os.makedirs") as mock_makedirs:
|
||||||
check_env()
|
check_env()
|
||||||
mock_makedirs.assert_any_call("data/config", exist_ok=True)
|
# Check that makedirs was called with paths containing expected dirs
|
||||||
mock_makedirs.assert_any_call("data/plugins", exist_ok=True)
|
called_paths = [call[0][0] for call in mock_makedirs.call_args_list]
|
||||||
mock_makedirs.assert_any_call("data/temp", exist_ok=True)
|
# Use os.path.join for cross-platform path matching
|
||||||
|
assert any(p.rstrip(os.sep).endswith(os.path.join("data", "config")) for p in called_paths)
|
||||||
|
assert any(p.rstrip(os.sep).endswith(os.path.join("data", "plugins")) for p in called_paths)
|
||||||
|
assert any(p.rstrip(os.sep).endswith(os.path.join("data", "temp")) for p in called_paths)
|
||||||
|
|
||||||
monkeypatch.setattr(sys, "version_info", version_info_wrong)
|
monkeypatch.setattr(sys, "version_info", version_info_wrong)
|
||||||
with pytest.raises(SystemExit):
|
with pytest.raises(SystemExit):
|
||||||
check_env()
|
check_env()
|
||||||
|
|
||||||
|
|
||||||
|
def test_version_info_comparisons():
|
||||||
|
"""Test _version_info comparison operators with tuples and other instances."""
|
||||||
|
v3_10 = _version_info(3, 10)
|
||||||
|
v3_9 = _version_info(3, 9)
|
||||||
|
v3_11 = _version_info(3, 11)
|
||||||
|
|
||||||
|
# Test __eq__ with tuples
|
||||||
|
assert v3_10 == (3, 10)
|
||||||
|
assert v3_10 != (3, 9)
|
||||||
|
assert v3_9 == (3, 9)
|
||||||
|
|
||||||
|
# Test __ge__ with tuples
|
||||||
|
assert v3_10 >= (3, 10)
|
||||||
|
assert v3_10 >= (3, 9)
|
||||||
|
assert not (v3_9 >= (3, 10))
|
||||||
|
assert v3_11 >= (3, 10)
|
||||||
|
|
||||||
|
# Test __eq__ with other _version_info instances
|
||||||
|
assert v3_10 == _version_info(3, 10)
|
||||||
|
assert v3_10 != v3_9
|
||||||
|
assert v3_10 == v3_10 # Same instance
|
||||||
|
|
||||||
|
assert v3_10 != v3_11
|
||||||
|
|
||||||
|
# Test __ge__ with other _version_info instances
|
||||||
|
assert v3_10 >= v3_10
|
||||||
|
assert v3_10 >= v3_9
|
||||||
|
assert not (v3_9 >= v3_10)
|
||||||
|
assert v3_11 >= v3_10
|
||||||
|
|
||||||
|
assert v3_11 >= v3_11 # Same instance
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_check_dashboard_files_not_exists(monkeypatch):
|
async def test_check_dashboard_files_not_exists(monkeypatch):
|
||||||
"""Tests dashboard download when files do not exist."""
|
"""Tests dashboard download when files do not exist."""
|
||||||
|
|||||||
+159
-74
@@ -1,65 +1,164 @@
|
|||||||
import os
|
import sys
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
|
from pathlib import Path
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||||
from astrbot.core.star.context import Context
|
from astrbot.core.star.context import Context
|
||||||
from astrbot.core.star.star import star_registry
|
from astrbot.core.star.star import star_map, star_registry
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry
|
from astrbot.core.star.star_handler import star_handlers_registry
|
||||||
from astrbot.core.star.star_manager import PluginManager
|
from astrbot.core.star.star_manager import PluginManager
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
def _clear_module_cache() -> None:
|
||||||
def plugin_manager_pm(tmp_path):
|
"""Clear module cache for data module tree to ensure test isolation."""
|
||||||
"""Provides a fully isolated PluginManager instance for testing.
|
modules_to_remove = [
|
||||||
- Uses a temporary directory for plugins.
|
key for key in sys.modules if key == "data" or key.startswith("data.")
|
||||||
- Uses a temporary database.
|
]
|
||||||
- Creates a fresh context for each test.
|
for key in modules_to_remove:
|
||||||
"""
|
del sys.modules[key]
|
||||||
# Create temporary resources
|
|
||||||
temp_plugins_path = tmp_path / "plugins"
|
|
||||||
temp_plugins_path.mkdir()
|
def _clear_registry(plugin_name: str) -> None:
|
||||||
temp_db_path = tmp_path / "test_db.db"
|
"""Clear plugin from global registries."""
|
||||||
|
# Clear star_registry (list)
|
||||||
|
star_registry[:] = [md for md in star_registry if md.name != plugin_name]
|
||||||
|
# Clear star_map (dict)
|
||||||
|
keys_to_remove = [
|
||||||
|
key for key, md in star_map.items() if md.name == plugin_name
|
||||||
|
]
|
||||||
|
for key in keys_to_remove:
|
||||||
|
del star_map[key]
|
||||||
|
# Clear star_handlers_registry (StarHandlerRegistry)
|
||||||
|
for handler in list(star_handlers_registry):
|
||||||
|
if plugin_name in (handler.handler_module_path or ""):
|
||||||
|
star_handlers_registry.remove(handler)
|
||||||
|
|
||||||
|
TEST_PLUGIN_REPO = "https://github.com/Soulter/helloworld"
|
||||||
|
TEST_PLUGIN_DIR = "helloworld"
|
||||||
|
TEST_PLUGIN_NAME = "helloworld"
|
||||||
|
|
||||||
|
|
||||||
|
def _write_local_test_plugin(plugin_dir: Path, repo_url: str) -> None:
|
||||||
|
plugin_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
(plugin_dir / "metadata.yaml").write_text(
|
||||||
|
"\n".join(
|
||||||
|
[
|
||||||
|
f"name: {TEST_PLUGIN_NAME}",
|
||||||
|
"author: AstrBot Team",
|
||||||
|
"desc: Local test plugin",
|
||||||
|
"version: 1.0.0",
|
||||||
|
f"repo: {repo_url}",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
+ "\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
(plugin_dir / "main.py").write_text(
|
||||||
|
"\n".join(
|
||||||
|
[
|
||||||
|
"from astrbot.api import star",
|
||||||
|
"",
|
||||||
|
"class Main(star.Star):",
|
||||||
|
" pass",
|
||||||
|
"",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def plugin_manager_pm(tmp_path, monkeypatch):
|
||||||
|
"""Provides a fully isolated PluginManager instance for testing."""
|
||||||
|
# Clear module cache before setup to ensure isolation
|
||||||
|
_clear_module_cache()
|
||||||
|
|
||||||
|
test_root = tmp_path / "astrbot_root"
|
||||||
|
data_dir = test_root / "data"
|
||||||
|
plugin_dir = data_dir / "plugins"
|
||||||
|
config_dir = data_dir / "config"
|
||||||
|
temp_dir = data_dir / "temp"
|
||||||
|
for path in (plugin_dir, config_dir, temp_dir):
|
||||||
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Ensure `import data.plugins.<plugin>.main` resolves to this temp root.
|
||||||
|
(data_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||||
|
(plugin_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||||
|
|
||||||
|
# Use monkeypatch for both env var and sys.path to ensure proper cleanup
|
||||||
|
monkeypatch.setenv("ASTRBOT_ROOT", str(test_root))
|
||||||
|
monkeypatch.syspath_prepend(str(test_root))
|
||||||
|
|
||||||
# Create fresh, isolated instances for the context
|
# Create fresh, isolated instances for the context
|
||||||
event_queue = Queue()
|
event_queue = Queue()
|
||||||
config = AstrBotConfig()
|
config = AstrBotConfig()
|
||||||
db = SQLiteDatabase(str(temp_db_path))
|
db = SQLiteDatabase(str(data_dir / "test_db.db"))
|
||||||
|
config.plugin_store_path = str(plugin_dir)
|
||||||
|
|
||||||
# Set the plugin store path in the config to the temporary directory
|
|
||||||
config.plugin_store_path = str(temp_plugins_path)
|
|
||||||
|
|
||||||
# Mock dependencies for the context
|
|
||||||
provider_manager = MagicMock()
|
provider_manager = MagicMock()
|
||||||
platform_manager = MagicMock()
|
platform_manager = MagicMock()
|
||||||
conversation_manager = MagicMock()
|
conversation_manager = MagicMock()
|
||||||
message_history_manager = MagicMock()
|
message_history_manager = MagicMock()
|
||||||
persona_manager = MagicMock()
|
persona_manager = MagicMock()
|
||||||
|
persona_manager.personas_v3 = []
|
||||||
astrbot_config_mgr = MagicMock()
|
astrbot_config_mgr = MagicMock()
|
||||||
knowledge_base_manager = MagicMock()
|
knowledge_base_manager = MagicMock()
|
||||||
|
cron_manager = MagicMock()
|
||||||
|
|
||||||
star_context = Context(
|
star_context = Context(
|
||||||
event_queue,
|
event_queue=event_queue,
|
||||||
config,
|
config=config,
|
||||||
db,
|
db=db,
|
||||||
provider_manager,
|
provider_manager=provider_manager,
|
||||||
platform_manager,
|
platform_manager=platform_manager,
|
||||||
conversation_manager,
|
conversation_manager=conversation_manager,
|
||||||
message_history_manager,
|
message_history_manager=message_history_manager,
|
||||||
persona_manager,
|
persona_manager=persona_manager,
|
||||||
astrbot_config_mgr,
|
astrbot_config_mgr=astrbot_config_mgr,
|
||||||
knowledge_base_manager=knowledge_base_manager,
|
knowledge_base_manager=knowledge_base_manager,
|
||||||
|
cron_manager=cron_manager,
|
||||||
|
subagent_orchestrator=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create the PluginManager instance
|
|
||||||
manager = PluginManager(star_context, config)
|
manager = PluginManager(star_context, config)
|
||||||
return manager
|
try:
|
||||||
|
yield manager
|
||||||
|
finally:
|
||||||
|
# Cleanup global registries and module cache
|
||||||
|
_clear_registry(TEST_PLUGIN_NAME)
|
||||||
|
_clear_module_cache()
|
||||||
|
await db.engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
|
@pytest.fixture
|
||||||
|
def local_updator(plugin_manager_pm: PluginManager, monkeypatch):
|
||||||
|
plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR
|
||||||
|
|
||||||
|
async def mock_install(repo_url: str, proxy=""): # noqa: ARG001
|
||||||
|
if repo_url != TEST_PLUGIN_REPO:
|
||||||
|
raise Exception("Repo not found")
|
||||||
|
_write_local_test_plugin(plugin_path, repo_url)
|
||||||
|
return str(plugin_path)
|
||||||
|
|
||||||
|
async def mock_update(plugin, proxy=""): # noqa: ARG001
|
||||||
|
if plugin.name != TEST_PLUGIN_NAME:
|
||||||
|
raise Exception("Plugin not found")
|
||||||
|
if not plugin_path.exists():
|
||||||
|
raise Exception("Plugin path missing")
|
||||||
|
(plugin_path / ".updated").write_text("ok", encoding="utf-8")
|
||||||
|
|
||||||
|
monkeypatch.setattr(plugin_manager_pm.updator, "install", mock_install)
|
||||||
|
monkeypatch.setattr(plugin_manager_pm.updator, "update", mock_update)
|
||||||
|
return plugin_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
|
||||||
assert plugin_manager_pm is not None
|
assert plugin_manager_pm is not None
|
||||||
assert plugin_manager_pm.context is not None
|
assert plugin_manager_pm.context is not None
|
||||||
assert plugin_manager_pm.config is not None
|
assert plugin_manager_pm.config is not None
|
||||||
@@ -73,73 +172,59 @@ async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_install_plugin(plugin_manager_pm: PluginManager):
|
async def test_install_plugin(plugin_manager_pm: PluginManager, local_updator: Path):
|
||||||
"""Tests successful plugin installation in an isolated environment."""
|
"""Tests successful plugin installation without external network."""
|
||||||
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
|
plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
|
||||||
plugin_info = await plugin_manager_pm.install_plugin(test_repo)
|
|
||||||
plugin_path = os.path.join(
|
|
||||||
plugin_manager_pm.plugin_store_path,
|
|
||||||
"astrbot_plugin_essential",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert plugin_info is not None
|
assert plugin_info is not None
|
||||||
assert os.path.exists(plugin_path)
|
assert plugin_info["name"] == TEST_PLUGIN_NAME
|
||||||
assert any(md.name == "astrbot_plugin_essential" for md in star_registry), (
|
assert local_updator.exists()
|
||||||
"Plugin 'astrbot_plugin_essential' was not loaded into star_registry."
|
assert any(md.name == TEST_PLUGIN_NAME for md in star_registry)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager):
|
async def test_install_nonexistent_plugin(
|
||||||
|
plugin_manager_pm: PluginManager, local_updator
|
||||||
|
):
|
||||||
"""Tests that installing a non-existent plugin raises an exception."""
|
"""Tests that installing a non-existent plugin raises an exception."""
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
await plugin_manager_pm.install_plugin(
|
await plugin_manager_pm.install_plugin(
|
||||||
"https://github.com/Soulter/non_existent_repo",
|
"https://github.com/Soulter/non_existent_repo"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_plugin(plugin_manager_pm: PluginManager):
|
async def test_update_plugin(plugin_manager_pm: PluginManager, local_updator: Path):
|
||||||
"""Tests updating an existing plugin in an isolated environment."""
|
"""Tests updating an existing plugin without external network."""
|
||||||
# First, install the plugin
|
plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
|
||||||
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
|
assert plugin_info is not None
|
||||||
await plugin_manager_pm.install_plugin(test_repo)
|
plugin_name = plugin_info["name"]
|
||||||
|
await plugin_manager_pm.update_plugin(plugin_name)
|
||||||
# Then, update it
|
assert (local_updator / ".updated").exists()
|
||||||
await plugin_manager_pm.update_plugin("astrbot_plugin_essential")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_nonexistent_plugin(plugin_manager_pm: PluginManager):
|
async def test_update_nonexistent_plugin(
|
||||||
|
plugin_manager_pm: PluginManager, local_updator
|
||||||
|
):
|
||||||
"""Tests that updating a non-existent plugin raises an exception."""
|
"""Tests that updating a non-existent plugin raises an exception."""
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
await plugin_manager_pm.update_plugin("non_existent_plugin")
|
await plugin_manager_pm.update_plugin("non_existent_plugin")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_uninstall_plugin(plugin_manager_pm: PluginManager):
|
async def test_uninstall_plugin(plugin_manager_pm: PluginManager, local_updator: Path):
|
||||||
"""Tests successful plugin uninstallation in an isolated environment."""
|
"""Tests successful plugin uninstallation."""
|
||||||
# First, install the plugin
|
plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO)
|
||||||
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
|
assert plugin_info is not None
|
||||||
await plugin_manager_pm.install_plugin(test_repo)
|
plugin_name = plugin_info["name"]
|
||||||
plugin_path = os.path.join(
|
assert local_updator.exists()
|
||||||
plugin_manager_pm.plugin_store_path,
|
|
||||||
"astrbot_plugin_essential",
|
|
||||||
)
|
|
||||||
assert os.path.exists(plugin_path) # Pre-condition
|
|
||||||
|
|
||||||
# Then, uninstall it
|
await plugin_manager_pm.uninstall_plugin(plugin_name)
|
||||||
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential")
|
|
||||||
|
|
||||||
assert not os.path.exists(plugin_path)
|
assert not local_updator.exists()
|
||||||
assert not any(md.name == "astrbot_plugin_essential" for md in star_registry), (
|
assert not any(md.name == TEST_PLUGIN_NAME for md in star_registry)
|
||||||
"Plugin 'astrbot_plugin_essential' was not unloaded from star_registry."
|
|
||||||
)
|
|
||||||
assert not any(
|
assert not any(
|
||||||
"astrbot_plugin_essential" in md.handler_module_path
|
TEST_PLUGIN_NAME in md.handler_module_path for md in star_handlers_registry
|
||||||
for md in star_handlers_registry
|
|
||||||
), (
|
|
||||||
"Plugin 'astrbot_plugin_essential' handler was not unloaded from star_handlers_registry."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -101,10 +101,16 @@ def test_pipeline_import_is_stable_with_mocked_apscheduler() -> None:
|
|||||||
"mock_apscheduler.schedulers = MagicMock();"
|
"mock_apscheduler.schedulers = MagicMock();"
|
||||||
"mock_apscheduler.schedulers.asyncio = MagicMock();"
|
"mock_apscheduler.schedulers.asyncio = MagicMock();"
|
||||||
"mock_apscheduler.schedulers.background = MagicMock();"
|
"mock_apscheduler.schedulers.background = MagicMock();"
|
||||||
|
"mock_apscheduler.triggers = MagicMock();"
|
||||||
|
"mock_apscheduler.triggers.cron = MagicMock();"
|
||||||
|
"mock_apscheduler.triggers.date = MagicMock();"
|
||||||
"sys.modules['apscheduler'] = mock_apscheduler;"
|
"sys.modules['apscheduler'] = mock_apscheduler;"
|
||||||
"sys.modules['apscheduler.schedulers'] = mock_apscheduler.schedulers;"
|
"sys.modules['apscheduler.schedulers'] = mock_apscheduler.schedulers;"
|
||||||
"sys.modules['apscheduler.schedulers.asyncio'] = mock_apscheduler.schedulers.asyncio;"
|
"sys.modules['apscheduler.schedulers.asyncio'] = mock_apscheduler.schedulers.asyncio;"
|
||||||
"sys.modules['apscheduler.schedulers.background'] = mock_apscheduler.schedulers.background;"
|
"sys.modules['apscheduler.schedulers.background'] = mock_apscheduler.schedulers.background;"
|
||||||
|
"sys.modules['apscheduler.triggers'] = mock_apscheduler.triggers;"
|
||||||
|
"sys.modules['apscheduler.triggers.cron'] = mock_apscheduler.triggers.cron;"
|
||||||
|
"sys.modules['apscheduler.triggers.date'] = mock_apscheduler.triggers.date;"
|
||||||
"import astrbot.core.pipeline as pipeline;"
|
"import astrbot.core.pipeline as pipeline;"
|
||||||
"assert pipeline.ProcessStage is not None;"
|
"assert pipeline.ProcessStage is not None;"
|
||||||
"assert pipeline.RespondStage is not None"
|
"assert pipeline.RespondStage is not None"
|
||||||
|
|||||||
@@ -461,7 +461,8 @@ async def test_stop_signal_returns_aborted_and_persists_partial_message(
|
|||||||
final_resp = runner.get_final_llm_resp()
|
final_resp = runner.get_final_llm_resp()
|
||||||
assert final_resp is not None
|
assert final_resp is not None
|
||||||
assert final_resp.role == "assistant"
|
assert final_resp.role == "assistant"
|
||||||
assert final_resp.completion_text == "partial "
|
# When interrupted, the runner replaces completion_text with a system message
|
||||||
|
assert "interrupted" in final_resp.completion_text.lower()
|
||||||
assert runner.run_context.messages[-1].role == "assistant"
|
assert runner.run_context.messages[-1].role == "assistant"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,607 @@
|
|||||||
|
"""Tests for config module."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from astrbot.core.config.astrbot_config import AstrBotConfig, RateLimitStrategy
|
||||||
|
from astrbot.core.config.default import DEFAULT_VALUE_MAP
|
||||||
|
from astrbot.core.config.i18n_utils import ConfigMetadataI18n
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_config_path(tmp_path):
|
||||||
|
"""Create a temporary config path."""
|
||||||
|
return str(tmp_path / "test_config.json")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def minimal_default_config():
|
||||||
|
"""Create a minimal default config for testing."""
|
||||||
|
return {
|
||||||
|
"config_version": 2,
|
||||||
|
"platform_settings": {
|
||||||
|
"unique_session": False,
|
||||||
|
"rate_limit": {
|
||||||
|
"time": 60,
|
||||||
|
"count": 30,
|
||||||
|
"strategy": "stall",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"provider_settings": {
|
||||||
|
"enable": True,
|
||||||
|
"default_provider_id": "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimitStrategy:
|
||||||
|
"""Tests for RateLimitStrategy enum."""
|
||||||
|
|
||||||
|
def test_stall_value(self):
|
||||||
|
"""Test stall enum value."""
|
||||||
|
assert RateLimitStrategy.STALL.value == "stall"
|
||||||
|
|
||||||
|
def test_discard_value(self):
|
||||||
|
"""Test discard enum value."""
|
||||||
|
assert RateLimitStrategy.DISCARD.value == "discard"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAstrBotConfigLoad:
|
||||||
|
"""Tests for AstrBotConfig loading and initialization."""
|
||||||
|
|
||||||
|
def test_init_creates_file_if_not_exists(
|
||||||
|
self, temp_config_path, minimal_default_config
|
||||||
|
):
|
||||||
|
"""Test that config file is created when it doesn't exist."""
|
||||||
|
assert not os.path.exists(temp_config_path)
|
||||||
|
|
||||||
|
config = AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert os.path.exists(temp_config_path)
|
||||||
|
assert config.config_version == 2
|
||||||
|
assert config.platform_settings["unique_session"] is False
|
||||||
|
|
||||||
|
def test_init_loads_existing_file(self, temp_config_path, minimal_default_config):
|
||||||
|
"""Test that existing config file is loaded."""
|
||||||
|
existing_config = {
|
||||||
|
"config_version": 2,
|
||||||
|
"platform_settings": {"unique_session": True},
|
||||||
|
"provider_settings": {"enable": False},
|
||||||
|
}
|
||||||
|
with open(temp_config_path, "w", encoding="utf-8-sig") as f:
|
||||||
|
json.dump(existing_config, f)
|
||||||
|
|
||||||
|
config = AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.platform_settings["unique_session"] is True
|
||||||
|
assert config.provider_settings["enable"] is False
|
||||||
|
|
||||||
|
def test_first_deploy_flag(self, temp_config_path, minimal_default_config):
|
||||||
|
"""Test first_deploy flag is set for new config."""
|
||||||
|
config = AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert hasattr(config, "first_deploy")
|
||||||
|
assert config.first_deploy is True
|
||||||
|
|
||||||
|
def test_init_with_schema(self, temp_config_path):
|
||||||
|
"""Test initialization with schema."""
|
||||||
|
schema = {
|
||||||
|
"test_field": {
|
||||||
|
"type": "string",
|
||||||
|
"default": "test_value",
|
||||||
|
},
|
||||||
|
"nested": {
|
||||||
|
"type": "object",
|
||||||
|
"items": {
|
||||||
|
"enabled": {"type": "bool"},
|
||||||
|
"count": {"type": "int"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config = AstrBotConfig(config_path=temp_config_path, schema=schema)
|
||||||
|
|
||||||
|
assert config.test_field == "test_value"
|
||||||
|
assert config.nested["enabled"] is False
|
||||||
|
assert config.nested["count"] == 0
|
||||||
|
|
||||||
|
def test_dot_notation_access(self, temp_config_path, minimal_default_config):
|
||||||
|
"""Test accessing config values using dot notation."""
|
||||||
|
config = AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.platform_settings is not None
|
||||||
|
assert config.non_existent_field is None
|
||||||
|
|
||||||
|
def test_setattr_updates_config(self, temp_config_path, minimal_default_config):
|
||||||
|
"""Test that setting attributes updates config."""
|
||||||
|
config = AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
|
||||||
|
config.new_field = "new_value"
|
||||||
|
|
||||||
|
assert config.new_field == "new_value"
|
||||||
|
|
||||||
|
def test_delattr_removes_field(self, temp_config_path, minimal_default_config):
|
||||||
|
"""Test that deleting attributes removes them."""
|
||||||
|
config = AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
config.temp_field = "temp"
|
||||||
|
|
||||||
|
del config.temp_field
|
||||||
|
|
||||||
|
# Accessing a deleted field returns None due to __getattr__
|
||||||
|
assert config.temp_field is None
|
||||||
|
# But the field is removed from the dict
|
||||||
|
assert "temp_field" not in config
|
||||||
|
|
||||||
|
def test_delattr_saves_config(self, temp_config_path, minimal_default_config):
|
||||||
|
"""Test that deleting attributes saves config to file."""
|
||||||
|
config = AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
config.temp_field = "temp"
|
||||||
|
del config.temp_field
|
||||||
|
|
||||||
|
with open(temp_config_path, encoding="utf-8-sig") as f:
|
||||||
|
loaded_config = json.load(f)
|
||||||
|
|
||||||
|
assert "temp_field" not in loaded_config
|
||||||
|
|
||||||
|
def test_check_exist(self, temp_config_path, minimal_default_config):
|
||||||
|
"""Test check_exist method."""
|
||||||
|
config = AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.check_exist() is True
|
||||||
|
|
||||||
|
# Create a path that definitely doesn't exist
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
temp_dir = pathlib.Path(temp_config_path).parent
|
||||||
|
non_existent_path = str(temp_dir / "non_existent_config.json")
|
||||||
|
|
||||||
|
# Check that the file doesn't exist before creating config
|
||||||
|
assert not os.path.exists(non_existent_path)
|
||||||
|
|
||||||
|
# Create config which will auto-create the file
|
||||||
|
config2 = AstrBotConfig(
|
||||||
|
config_path=non_existent_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now it exists
|
||||||
|
assert config2.check_exist() is True
|
||||||
|
assert os.path.exists(non_existent_path)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigValidation:
|
||||||
|
"""Tests for config validation and integrity checking."""
|
||||||
|
|
||||||
|
def test_insert_missing_config_items(
|
||||||
|
self, temp_config_path, minimal_default_config
|
||||||
|
):
|
||||||
|
"""Test that missing config items are inserted with default values."""
|
||||||
|
existing_config = {"config_version": 2}
|
||||||
|
with open(temp_config_path, "w", encoding="utf-8-sig") as f:
|
||||||
|
json.dump(existing_config, f)
|
||||||
|
|
||||||
|
config = AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "platform_settings" in config
|
||||||
|
assert "provider_settings" in config
|
||||||
|
|
||||||
|
def test_replace_none_with_default(self, temp_config_path, minimal_default_config):
|
||||||
|
"""Test that None values are replaced with defaults."""
|
||||||
|
existing_config = {
|
||||||
|
"config_version": 2,
|
||||||
|
"platform_settings": None,
|
||||||
|
"provider_settings": None,
|
||||||
|
}
|
||||||
|
with open(temp_config_path, "w", encoding="utf-8-sig") as f:
|
||||||
|
json.dump(existing_config, f)
|
||||||
|
|
||||||
|
AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reload to verify the values were replaced
|
||||||
|
config2 = AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config2.platform_settings is not None
|
||||||
|
assert config2.provider_settings is not None
|
||||||
|
|
||||||
|
def test_reorder_config_keys(self, temp_config_path, minimal_default_config):
|
||||||
|
"""Test that config keys are reordered to match default."""
|
||||||
|
existing_config = {
|
||||||
|
"provider_settings": {"enable": True},
|
||||||
|
"config_version": 2,
|
||||||
|
"platform_settings": {"unique_session": False},
|
||||||
|
}
|
||||||
|
with open(temp_config_path, "w", encoding="utf-8-sig") as f:
|
||||||
|
json.dump(existing_config, f)
|
||||||
|
|
||||||
|
AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(temp_config_path, encoding="utf-8-sig") as f:
|
||||||
|
loaded_config = json.load(f)
|
||||||
|
|
||||||
|
keys = list(loaded_config.keys())
|
||||||
|
assert keys[0] == "config_version"
|
||||||
|
assert keys[1] == "platform_settings"
|
||||||
|
assert keys[2] == "provider_settings"
|
||||||
|
|
||||||
|
def test_remove_unknown_config_keys(self, temp_config_path, minimal_default_config):
|
||||||
|
"""Test that unknown config keys are removed."""
|
||||||
|
existing_config = {
|
||||||
|
"config_version": 2,
|
||||||
|
"platform_settings": {},
|
||||||
|
"unknown_key": "should_be_removed",
|
||||||
|
}
|
||||||
|
with open(temp_config_path, "w", encoding="utf-8-sig") as f:
|
||||||
|
json.dump(existing_config, f)
|
||||||
|
|
||||||
|
config = AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "unknown_key" not in config
|
||||||
|
|
||||||
|
def test_nested_config_validation(self, temp_config_path):
|
||||||
|
"""Test validation of nested config structures."""
|
||||||
|
default_config = {
|
||||||
|
"nested": {
|
||||||
|
"level1": {
|
||||||
|
"level2": {
|
||||||
|
"value": 42,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
existing_config = {
|
||||||
|
"nested": {
|
||||||
|
"level1": {}, # Missing level2
|
||||||
|
},
|
||||||
|
}
|
||||||
|
with open(temp_config_path, "w", encoding="utf-8-sig") as f:
|
||||||
|
json.dump(existing_config, f)
|
||||||
|
|
||||||
|
config = AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=default_config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "level2" in config.nested["level1"]
|
||||||
|
assert config.nested["level1"]["level2"]["value"] == 42
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigHotReload:
|
||||||
|
"""Tests for config hot reload functionality."""
|
||||||
|
|
||||||
|
def test_save_config(self, temp_config_path, minimal_default_config):
|
||||||
|
"""Test saving config to file."""
|
||||||
|
config = AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
config.new_field = "new_value"
|
||||||
|
config.save_config()
|
||||||
|
|
||||||
|
with open(temp_config_path, encoding="utf-8-sig") as f:
|
||||||
|
loaded_config = json.load(f)
|
||||||
|
|
||||||
|
assert loaded_config["new_field"] == "new_value"
|
||||||
|
|
||||||
|
def test_save_config_with_replace(self, temp_config_path, minimal_default_config):
|
||||||
|
"""Test saving config with replacement."""
|
||||||
|
config = AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
|
||||||
|
replacement_config = {
|
||||||
|
"replaced": True,
|
||||||
|
"extra_field": "value",
|
||||||
|
}
|
||||||
|
config.save_config(replace_config=replacement_config)
|
||||||
|
|
||||||
|
with open(temp_config_path, encoding="utf-8-sig") as f:
|
||||||
|
loaded_config = json.load(f)
|
||||||
|
|
||||||
|
# The replacement config is merged with existing config
|
||||||
|
assert loaded_config["replaced"] is True
|
||||||
|
assert loaded_config["extra_field"] == "value"
|
||||||
|
# Original fields are preserved because update merges
|
||||||
|
assert "platform_settings" in loaded_config
|
||||||
|
|
||||||
|
def test_modification_persists_after_reload(
|
||||||
|
self, temp_config_path, minimal_default_config
|
||||||
|
):
|
||||||
|
"""Test that modifications persist after reloading."""
|
||||||
|
config1 = AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
config1.platform_settings["unique_session"] = True
|
||||||
|
config1.save_config()
|
||||||
|
|
||||||
|
config2 = AstrBotConfig(
|
||||||
|
config_path=temp_config_path, default_config=minimal_default_config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config2.platform_settings["unique_session"] is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigSchemaToDefault:
|
||||||
|
"""Tests for schema to default config conversion."""
|
||||||
|
|
||||||
|
def test_convert_schema_with_defaults(self, temp_config_path):
|
||||||
|
"""Test converting schema with explicit defaults."""
|
||||||
|
schema = {
|
||||||
|
"string_field": {"type": "string", "default": "custom"},
|
||||||
|
"int_field": {"type": "int", "default": 100},
|
||||||
|
"bool_field": {"type": "bool", "default": True},
|
||||||
|
}
|
||||||
|
|
||||||
|
config = AstrBotConfig(config_path=temp_config_path, schema=schema)
|
||||||
|
|
||||||
|
assert config.string_field == "custom"
|
||||||
|
assert config.int_field == 100
|
||||||
|
assert config.bool_field is True
|
||||||
|
|
||||||
|
def test_convert_schema_without_defaults(self, temp_config_path):
|
||||||
|
"""Test converting schema using default value map."""
|
||||||
|
schema = {
|
||||||
|
"string_field": {"type": "string"},
|
||||||
|
"int_field": {"type": "int"},
|
||||||
|
"bool_field": {"type": "bool"},
|
||||||
|
}
|
||||||
|
|
||||||
|
config = AstrBotConfig(config_path=temp_config_path, schema=schema)
|
||||||
|
|
||||||
|
assert config.string_field == DEFAULT_VALUE_MAP["string"]
|
||||||
|
assert config.int_field == DEFAULT_VALUE_MAP["int"]
|
||||||
|
assert config.bool_field == DEFAULT_VALUE_MAP["bool"]
|
||||||
|
|
||||||
|
def test_unsupported_schema_type_raises_error(self, temp_config_path):
|
||||||
|
"""Test that unsupported schema types raise error."""
|
||||||
|
schema = {
|
||||||
|
"field": {"type": "unsupported_type"},
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(TypeError, match="不受支持的配置类型"):
|
||||||
|
AstrBotConfig(config_path=temp_config_path, schema=schema)
|
||||||
|
|
||||||
|
def test_template_list_type(self, temp_config_path):
|
||||||
|
"""Test template_list schema type."""
|
||||||
|
schema = {
|
||||||
|
"templates": {"type": "template_list", "default": []},
|
||||||
|
}
|
||||||
|
|
||||||
|
config = AstrBotConfig(config_path=temp_config_path, schema=schema)
|
||||||
|
|
||||||
|
assert config.templates == []
|
||||||
|
|
||||||
|
def test_nested_object_schema(self, temp_config_path):
|
||||||
|
"""Test nested object schema conversion."""
|
||||||
|
schema = {
|
||||||
|
"nested": {
|
||||||
|
"type": "object",
|
||||||
|
"items": {
|
||||||
|
"field1": {"type": "string"},
|
||||||
|
"field2": {"type": "int"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config = AstrBotConfig(config_path=temp_config_path, schema=schema)
|
||||||
|
|
||||||
|
assert config.nested["field1"] == ""
|
||||||
|
assert config.nested["field2"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigMetadataI18n:
|
||||||
|
"""Tests for i18n utils."""
|
||||||
|
|
||||||
|
def test_get_i18n_key(self):
|
||||||
|
"""Test generating i18n key."""
|
||||||
|
key = ConfigMetadataI18n._get_i18n_key(
|
||||||
|
group="ai_group",
|
||||||
|
section="general",
|
||||||
|
field="enable",
|
||||||
|
attr="description",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert key == "ai_group.general.enable.description"
|
||||||
|
|
||||||
|
def test_get_i18n_key_without_field(self):
|
||||||
|
"""Test generating i18n key without field."""
|
||||||
|
key = ConfigMetadataI18n._get_i18n_key(
|
||||||
|
group="ai_group",
|
||||||
|
section="general",
|
||||||
|
field="",
|
||||||
|
attr="description",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert key == "ai_group.general.description"
|
||||||
|
|
||||||
|
def test_convert_to_i18n_keys_simple(self):
|
||||||
|
"""Test converting simple metadata to i18n keys."""
|
||||||
|
metadata = {
|
||||||
|
"ai_group": {
|
||||||
|
"name": "AI Settings",
|
||||||
|
"metadata": {
|
||||||
|
"general": {
|
||||||
|
"description": "General settings",
|
||||||
|
"items": {
|
||||||
|
"enable": {
|
||||||
|
"description": "Enable feature",
|
||||||
|
"type": "bool",
|
||||||
|
"default": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = ConfigMetadataI18n.convert_to_i18n_keys(metadata)
|
||||||
|
|
||||||
|
assert result["ai_group"]["name"] == "ai_group.name"
|
||||||
|
assert (
|
||||||
|
result["ai_group"]["metadata"]["general"]["description"]
|
||||||
|
== "ai_group.general.description"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
result["ai_group"]["metadata"]["general"]["items"]["enable"]["description"]
|
||||||
|
== "ai_group.general.enable.description"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_convert_to_i18n_keys_with_hint(self):
|
||||||
|
"""Test converting metadata with hint."""
|
||||||
|
metadata = {
|
||||||
|
"group": {
|
||||||
|
"metadata": {
|
||||||
|
"section": {
|
||||||
|
"hint": "This is a hint",
|
||||||
|
"items": {
|
||||||
|
"field": {
|
||||||
|
"hint": "Field hint",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = ConfigMetadataI18n.convert_to_i18n_keys(metadata)
|
||||||
|
|
||||||
|
assert result["group"]["metadata"]["section"]["hint"] == "group.section.hint"
|
||||||
|
assert (
|
||||||
|
result["group"]["metadata"]["section"]["items"]["field"]["hint"]
|
||||||
|
== "group.section.field.hint"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_convert_to_i18n_keys_with_labels(self):
|
||||||
|
"""Test converting metadata with labels."""
|
||||||
|
metadata = {
|
||||||
|
"group": {
|
||||||
|
"metadata": {
|
||||||
|
"section": {
|
||||||
|
"items": {
|
||||||
|
"field": {
|
||||||
|
"labels": ["Label1", "Label2"],
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = ConfigMetadataI18n.convert_to_i18n_keys(metadata)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
result["group"]["metadata"]["section"]["items"]["field"]["labels"]
|
||||||
|
== "group.section.field.labels"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_convert_to_i18n_keys_nested_items(self):
|
||||||
|
"""Test converting metadata with nested items."""
|
||||||
|
metadata = {
|
||||||
|
"group": {
|
||||||
|
"metadata": {
|
||||||
|
"section": {
|
||||||
|
"items": {
|
||||||
|
"nested": {
|
||||||
|
"description": "Nested field",
|
||||||
|
"type": "object",
|
||||||
|
"items": {
|
||||||
|
"inner": {
|
||||||
|
"description": "Inner field",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = ConfigMetadataI18n.convert_to_i18n_keys(metadata)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
result["group"]["metadata"]["section"]["items"]["nested"]["description"]
|
||||||
|
== "group.section.nested.description"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
result["group"]["metadata"]["section"]["items"]["nested"]["items"]["inner"][
|
||||||
|
"description"
|
||||||
|
]
|
||||||
|
== "group.section.nested.inner.description"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_convert_to_i18n_keys_preserves_non_i18n_fields(self):
|
||||||
|
"""Test that non-i18n fields are preserved."""
|
||||||
|
metadata = {
|
||||||
|
"group": {
|
||||||
|
"metadata": {
|
||||||
|
"section": {
|
||||||
|
"items": {
|
||||||
|
"field": {
|
||||||
|
"description": "Field description",
|
||||||
|
"type": "string",
|
||||||
|
"other_field": "preserve this",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = ConfigMetadataI18n.convert_to_i18n_keys(metadata)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
result["group"]["metadata"]["section"]["items"]["field"]["other_field"]
|
||||||
|
== "preserve this"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_convert_to_i18n_keys_with_name(self):
|
||||||
|
"""Test converting metadata with name field."""
|
||||||
|
metadata = {
|
||||||
|
"group": {
|
||||||
|
"metadata": {
|
||||||
|
"section": {
|
||||||
|
"items": {
|
||||||
|
"field": {
|
||||||
|
"name": "Field Name",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = ConfigMetadataI18n.convert_to_i18n_keys(metadata)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
result["group"]["metadata"]["section"]["items"]["field"]["name"]
|
||||||
|
== "group.section.field.name"
|
||||||
|
)
|
||||||
@@ -0,0 +1,504 @@
|
|||||||
|
"""Tests for CronJobManager."""
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from astrbot.core.cron.manager import CronJobManager
|
||||||
|
from astrbot.core.db.po import CronJob
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_db():
|
||||||
|
"""Create a mock database."""
|
||||||
|
db = MagicMock()
|
||||||
|
db.create_cron_job = AsyncMock()
|
||||||
|
db.get_cron_job = AsyncMock()
|
||||||
|
db.update_cron_job = AsyncMock()
|
||||||
|
db.delete_cron_job = AsyncMock()
|
||||||
|
db.list_cron_jobs = AsyncMock(return_value=[])
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_context():
|
||||||
|
"""Create a mock Context."""
|
||||||
|
ctx = MagicMock()
|
||||||
|
ctx.get_config = MagicMock(return_value={"admins_id": []})
|
||||||
|
ctx.conversation_manager = MagicMock()
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cron_manager(mock_db):
|
||||||
|
"""Create a CronJobManager instance."""
|
||||||
|
return CronJobManager(mock_db)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_cron_job():
|
||||||
|
"""Create a sample CronJob."""
|
||||||
|
return CronJob(
|
||||||
|
job_id="test-job-id",
|
||||||
|
name="Test Job",
|
||||||
|
job_type="basic",
|
||||||
|
cron_expression="0 9 * * *",
|
||||||
|
timezone="UTC",
|
||||||
|
payload={"key": "value"},
|
||||||
|
description="A test job",
|
||||||
|
enabled=True,
|
||||||
|
persistent=True,
|
||||||
|
run_once=False,
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCronJobManagerInit:
|
||||||
|
"""Tests for CronJobManager initialization."""
|
||||||
|
|
||||||
|
def test_init(self, mock_db):
|
||||||
|
"""Test CronJobManager initialization."""
|
||||||
|
manager = CronJobManager(mock_db)
|
||||||
|
|
||||||
|
assert manager.db == mock_db
|
||||||
|
assert manager._basic_handlers == {}
|
||||||
|
assert manager._started is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestCronJobManagerStart:
|
||||||
|
"""Tests for CronJobManager.start method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start(self, cron_manager, mock_db, mock_context):
|
||||||
|
"""Test starting the cron manager."""
|
||||||
|
mock_db.list_cron_jobs.return_value = []
|
||||||
|
|
||||||
|
await cron_manager.start(mock_context)
|
||||||
|
|
||||||
|
assert cron_manager._started is True
|
||||||
|
assert cron_manager.ctx == mock_context
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_idempotent(self, cron_manager, mock_db, mock_context):
|
||||||
|
"""Test that start is idempotent."""
|
||||||
|
mock_db.list_cron_jobs.return_value = []
|
||||||
|
|
||||||
|
await cron_manager.start(mock_context)
|
||||||
|
await cron_manager.start(mock_context)
|
||||||
|
|
||||||
|
# Should only sync once
|
||||||
|
assert mock_db.list_cron_jobs.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestCronJobManagerShutdown:
|
||||||
|
"""Tests for CronJobManager.shutdown method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_shutdown(self, cron_manager, mock_db, mock_context):
|
||||||
|
"""Test shutting down the cron manager."""
|
||||||
|
mock_db.list_cron_jobs.return_value = []
|
||||||
|
await cron_manager.start(mock_context)
|
||||||
|
|
||||||
|
await cron_manager.shutdown()
|
||||||
|
|
||||||
|
assert cron_manager._started is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_shutdown_when_not_started(self, cron_manager):
|
||||||
|
"""Test shutdown when not started."""
|
||||||
|
# Should not raise
|
||||||
|
await cron_manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAddBasicJob:
|
||||||
|
"""Tests for add_basic_job method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_basic_job(self, cron_manager, mock_db, sample_cron_job):
|
||||||
|
"""Test adding a basic cron job."""
|
||||||
|
mock_db.create_cron_job.return_value = sample_cron_job
|
||||||
|
|
||||||
|
handler = MagicMock()
|
||||||
|
|
||||||
|
result = await cron_manager.add_basic_job(
|
||||||
|
name="Test Job",
|
||||||
|
cron_expression="0 9 * * *",
|
||||||
|
handler=handler,
|
||||||
|
description="A test job",
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == sample_cron_job
|
||||||
|
assert sample_cron_job.job_id in cron_manager._basic_handlers
|
||||||
|
mock_db.create_cron_job.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_basic_job_disabled(self, cron_manager, mock_db, sample_cron_job):
|
||||||
|
"""Test adding a disabled basic cron job."""
|
||||||
|
sample_cron_job.enabled = False
|
||||||
|
mock_db.create_cron_job.return_value = sample_cron_job
|
||||||
|
|
||||||
|
handler = MagicMock()
|
||||||
|
|
||||||
|
result = await cron_manager.add_basic_job(
|
||||||
|
name="Test Job",
|
||||||
|
cron_expression="0 9 * * *",
|
||||||
|
handler=handler,
|
||||||
|
enabled=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == sample_cron_job
|
||||||
|
assert sample_cron_job.job_id in cron_manager._basic_handlers
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_basic_job_with_timezone(self, cron_manager, mock_db, sample_cron_job):
|
||||||
|
"""Test adding a basic job with timezone."""
|
||||||
|
mock_db.create_cron_job.return_value = sample_cron_job
|
||||||
|
|
||||||
|
handler = MagicMock()
|
||||||
|
|
||||||
|
await cron_manager.add_basic_job(
|
||||||
|
name="Test Job",
|
||||||
|
cron_expression="0 9 * * *",
|
||||||
|
handler=handler,
|
||||||
|
timezone="Asia/Shanghai",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_db.create_cron_job.assert_called_once()
|
||||||
|
call_kwargs = mock_db.create_cron_job.call_args.kwargs
|
||||||
|
assert call_kwargs["timezone"] == "Asia/Shanghai"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAddActiveJob:
|
||||||
|
"""Tests for add_active_job method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_active_job(self, cron_manager, mock_db, sample_cron_job):
|
||||||
|
"""Test adding an active agent cron job."""
|
||||||
|
sample_cron_job.job_type = "active_agent"
|
||||||
|
mock_db.create_cron_job.return_value = sample_cron_job
|
||||||
|
|
||||||
|
result = await cron_manager.add_active_job(
|
||||||
|
name="Test Active Job",
|
||||||
|
cron_expression="0 9 * * *",
|
||||||
|
payload={"session": "test:group:123"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == sample_cron_job
|
||||||
|
mock_db.create_cron_job.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_active_job_run_once(self, cron_manager, mock_db, sample_cron_job):
|
||||||
|
"""Test adding a run-once active job."""
|
||||||
|
sample_cron_job.job_type = "active_agent"
|
||||||
|
sample_cron_job.run_once = True
|
||||||
|
mock_db.create_cron_job.return_value = sample_cron_job
|
||||||
|
|
||||||
|
run_at = datetime.now(timezone.utc) + timedelta(days=30)
|
||||||
|
|
||||||
|
result = await cron_manager.add_active_job(
|
||||||
|
name="Test Run Once Job",
|
||||||
|
cron_expression=None,
|
||||||
|
payload={"session": "test:group:123"},
|
||||||
|
run_once=True,
|
||||||
|
run_at=run_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == sample_cron_job
|
||||||
|
call_kwargs = mock_db.create_cron_job.call_args.kwargs
|
||||||
|
assert call_kwargs["run_once"] is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpdateJob:
|
||||||
|
"""Tests for update_job method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_job(self, cron_manager, mock_db, sample_cron_job):
|
||||||
|
"""Test updating a cron job."""
|
||||||
|
updated_job = CronJob(
|
||||||
|
job_id="test-job-id",
|
||||||
|
name="Updated Job",
|
||||||
|
job_type="basic",
|
||||||
|
cron_expression="0 10 * * *",
|
||||||
|
enabled=False, # Disabled to avoid scheduling
|
||||||
|
)
|
||||||
|
mock_db.update_cron_job.return_value = updated_job
|
||||||
|
|
||||||
|
result = await cron_manager.update_job("test-job-id", name="Updated Job")
|
||||||
|
|
||||||
|
assert result == updated_job
|
||||||
|
mock_db.update_cron_job.assert_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_job_not_found(self, cron_manager, mock_db):
|
||||||
|
"""Test updating a non-existent job."""
|
||||||
|
mock_db.update_cron_job.return_value = None
|
||||||
|
|
||||||
|
result = await cron_manager.update_job("non-existent", name="Updated")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeleteJob:
|
||||||
|
"""Tests for delete_job method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_job(self, cron_manager, mock_db):
|
||||||
|
"""Test deleting a cron job."""
|
||||||
|
cron_manager._basic_handlers["test-job-id"] = MagicMock()
|
||||||
|
|
||||||
|
await cron_manager.delete_job("test-job-id")
|
||||||
|
|
||||||
|
mock_db.delete_cron_job.assert_called_once_with("test-job-id")
|
||||||
|
assert "test-job-id" not in cron_manager._basic_handlers
|
||||||
|
|
||||||
|
|
||||||
|
class TestListJobs:
|
||||||
|
"""Tests for list_jobs method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_all_jobs(self, cron_manager, mock_db, sample_cron_job):
|
||||||
|
"""Test listing all jobs."""
|
||||||
|
mock_db.list_cron_jobs.return_value = [sample_cron_job]
|
||||||
|
|
||||||
|
result = await cron_manager.list_jobs()
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
mock_db.list_cron_jobs.assert_called_once_with(None)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_jobs_by_type(self, cron_manager, mock_db, sample_cron_job):
|
||||||
|
"""Test listing jobs by type."""
|
||||||
|
mock_db.list_cron_jobs.return_value = [sample_cron_job]
|
||||||
|
|
||||||
|
result = await cron_manager.list_jobs(job_type="basic")
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
mock_db.list_cron_jobs.assert_called_once_with("basic")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncFromDb:
|
||||||
|
"""Tests for sync_from_db method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sync_from_db_empty(self, cron_manager, mock_db):
|
||||||
|
"""Test syncing from empty database."""
|
||||||
|
mock_db.list_cron_jobs.return_value = []
|
||||||
|
|
||||||
|
await cron_manager.sync_from_db()
|
||||||
|
|
||||||
|
mock_db.list_cron_jobs.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sync_from_db_skips_disabled(self, cron_manager, mock_db, sample_cron_job):
|
||||||
|
"""Test that sync skips disabled jobs."""
|
||||||
|
sample_cron_job.enabled = False
|
||||||
|
mock_db.list_cron_jobs.return_value = [sample_cron_job]
|
||||||
|
|
||||||
|
with patch.object(cron_manager, "_schedule_job") as mock_schedule:
|
||||||
|
await cron_manager.sync_from_db()
|
||||||
|
|
||||||
|
mock_db.list_cron_jobs.assert_called_once()
|
||||||
|
mock_schedule.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sync_from_db_skips_non_persistent(self, cron_manager, mock_db, sample_cron_job):
|
||||||
|
"""Test that sync skips non-persistent jobs."""
|
||||||
|
sample_cron_job.persistent = False
|
||||||
|
mock_db.list_cron_jobs.return_value = [sample_cron_job]
|
||||||
|
|
||||||
|
with patch.object(cron_manager, "_schedule_job") as mock_schedule:
|
||||||
|
await cron_manager.sync_from_db()
|
||||||
|
|
||||||
|
mock_db.list_cron_jobs.assert_called_once()
|
||||||
|
mock_schedule.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sync_from_db_basic_without_handler(
|
||||||
|
self, cron_manager, mock_db, sample_cron_job
|
||||||
|
):
|
||||||
|
"""Test that sync warns for basic jobs without handlers."""
|
||||||
|
mock_db.list_cron_jobs.return_value = [sample_cron_job]
|
||||||
|
|
||||||
|
with patch("astrbot.core.cron.manager.logger") as mock_logger:
|
||||||
|
await cron_manager.sync_from_db()
|
||||||
|
|
||||||
|
mock_logger.warning.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestRemoveScheduled:
|
||||||
|
"""Tests for _remove_scheduled method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_remove_scheduled_existing(self, cron_manager, mock_context):
|
||||||
|
"""Test removing a scheduled job."""
|
||||||
|
# Start the scheduler first
|
||||||
|
job = CronJob(
|
||||||
|
job_id="test-job-id",
|
||||||
|
name="Test",
|
||||||
|
job_type="active_agent",
|
||||||
|
cron_expression="0 9 * * *",
|
||||||
|
enabled=True,
|
||||||
|
persistent=True,
|
||||||
|
)
|
||||||
|
mock_db = cron_manager.db
|
||||||
|
mock_db.list_cron_jobs = AsyncMock(return_value=[job])
|
||||||
|
await cron_manager.start(mock_context)
|
||||||
|
|
||||||
|
# Then remove it
|
||||||
|
cron_manager._remove_scheduled("test-job-id")
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
|
|
||||||
|
def test_remove_scheduled_nonexistent(self, cron_manager):
|
||||||
|
"""Test removing a non-existent job."""
|
||||||
|
# Should not raise
|
||||||
|
cron_manager._remove_scheduled("non-existent")
|
||||||
|
|
||||||
|
|
||||||
|
class TestScheduleJob:
|
||||||
|
"""Tests for _schedule_job method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_schedule_job_basic(self, cron_manager, sample_cron_job, mock_context):
|
||||||
|
"""Test scheduling a basic job."""
|
||||||
|
mock_db = cron_manager.db
|
||||||
|
mock_db.list_cron_jobs = AsyncMock(return_value=[])
|
||||||
|
mock_db.update_cron_job = AsyncMock()
|
||||||
|
await cron_manager.start(mock_context)
|
||||||
|
cron_manager._schedule_job(sample_cron_job)
|
||||||
|
|
||||||
|
# Verify job was added to scheduler
|
||||||
|
assert cron_manager.scheduler.get_job("test-job-id") is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_schedule_job_with_timezone(self, cron_manager, sample_cron_job, mock_context):
|
||||||
|
"""Test scheduling a job with timezone."""
|
||||||
|
sample_cron_job.timezone = "America/New_York"
|
||||||
|
mock_db = cron_manager.db
|
||||||
|
mock_db.list_cron_jobs = AsyncMock(return_value=[])
|
||||||
|
mock_db.update_cron_job = AsyncMock()
|
||||||
|
await cron_manager.start(mock_context)
|
||||||
|
cron_manager._schedule_job(sample_cron_job)
|
||||||
|
|
||||||
|
assert cron_manager.scheduler.get_job("test-job-id") is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_schedule_job_invalid_timezone(self, cron_manager, sample_cron_job, mock_context):
|
||||||
|
"""Test scheduling a job with invalid timezone."""
|
||||||
|
sample_cron_job.timezone = "Invalid/Timezone"
|
||||||
|
mock_db = cron_manager.db
|
||||||
|
mock_db.list_cron_jobs = AsyncMock(return_value=[])
|
||||||
|
mock_db.update_cron_job = AsyncMock()
|
||||||
|
|
||||||
|
with patch("astrbot.core.cron.manager.logger") as mock_logger:
|
||||||
|
await cron_manager.start(mock_context)
|
||||||
|
cron_manager._schedule_job(sample_cron_job)
|
||||||
|
|
||||||
|
# Should still schedule with system timezone
|
||||||
|
assert cron_manager.scheduler.get_job("test-job-id") is not None
|
||||||
|
mock_logger.warning.assert_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_schedule_job_run_once(self, cron_manager, mock_context):
|
||||||
|
"""Test scheduling a run-once job."""
|
||||||
|
future_date = datetime.now(timezone.utc) + timedelta(days=30)
|
||||||
|
job = CronJob(
|
||||||
|
job_id="run-once-job",
|
||||||
|
name="Run Once",
|
||||||
|
job_type="active_agent",
|
||||||
|
cron_expression=None,
|
||||||
|
enabled=True,
|
||||||
|
run_once=True,
|
||||||
|
payload={"run_at": future_date.isoformat()},
|
||||||
|
)
|
||||||
|
mock_db = cron_manager.db
|
||||||
|
mock_db.list_cron_jobs = AsyncMock(return_value=[])
|
||||||
|
mock_db.update_cron_job = AsyncMock()
|
||||||
|
await cron_manager.start(mock_context)
|
||||||
|
cron_manager._schedule_job(job)
|
||||||
|
|
||||||
|
assert cron_manager.scheduler.get_job("run-once-job") is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunJob:
|
||||||
|
"""Tests for _run_job method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_job_disabled(self, cron_manager, mock_db, sample_cron_job):
|
||||||
|
"""Test running a disabled job."""
|
||||||
|
sample_cron_job.enabled = False
|
||||||
|
mock_db.get_cron_job.return_value = sample_cron_job
|
||||||
|
|
||||||
|
await cron_manager._run_job("test-job-id")
|
||||||
|
|
||||||
|
# Should not update status
|
||||||
|
mock_db.update_cron_job.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_job_not_found(self, cron_manager, mock_db):
|
||||||
|
"""Test running a non-existent job."""
|
||||||
|
mock_db.get_cron_job.return_value = None
|
||||||
|
|
||||||
|
await cron_manager._run_job("non-existent")
|
||||||
|
|
||||||
|
# Should not update status
|
||||||
|
mock_db.update_cron_job.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunBasicJob:
|
||||||
|
"""Tests for _run_basic_job method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_basic_job_sync_handler(self, cron_manager, sample_cron_job):
|
||||||
|
"""Test running a basic job with sync handler."""
|
||||||
|
handler = MagicMock(return_value=None)
|
||||||
|
cron_manager._basic_handlers["test-job-id"] = handler
|
||||||
|
sample_cron_job.payload = {"arg1": "value1"}
|
||||||
|
|
||||||
|
await cron_manager._run_basic_job(sample_cron_job)
|
||||||
|
|
||||||
|
handler.assert_called_once_with(arg1="value1")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_basic_job_async_handler(self, cron_manager, sample_cron_job):
|
||||||
|
"""Test running a basic job with async handler."""
|
||||||
|
async_handler = AsyncMock()
|
||||||
|
cron_manager._basic_handlers["test-job-id"] = async_handler
|
||||||
|
sample_cron_job.payload = {}
|
||||||
|
|
||||||
|
await cron_manager._run_basic_job(sample_cron_job)
|
||||||
|
|
||||||
|
async_handler.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_basic_job_no_handler(self, cron_manager, sample_cron_job):
|
||||||
|
"""Test running a basic job without handler."""
|
||||||
|
sample_cron_job.job_id = "no-handler-job"
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="handler not found"):
|
||||||
|
await cron_manager._run_basic_job(sample_cron_job)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetNextRunTime:
|
||||||
|
"""Tests for _get_next_run_time method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_next_run_time_existing_job(self, cron_manager, sample_cron_job, mock_context):
|
||||||
|
"""Test getting next run time for existing job."""
|
||||||
|
mock_db = cron_manager.db
|
||||||
|
mock_db.list_cron_jobs = AsyncMock(return_value=[])
|
||||||
|
mock_db.update_cron_job = AsyncMock()
|
||||||
|
await cron_manager.start(mock_context)
|
||||||
|
cron_manager._schedule_job(sample_cron_job)
|
||||||
|
|
||||||
|
next_run = cron_manager._get_next_run_time("test-job-id")
|
||||||
|
|
||||||
|
assert next_run is not None
|
||||||
|
|
||||||
|
def test_get_next_run_time_nonexistent(self, cron_manager):
|
||||||
|
"""Test getting next run time for non-existent job."""
|
||||||
|
next_run = cron_manager._get_next_run_time("non-existent")
|
||||||
|
|
||||||
|
assert next_run is None
|
||||||
@@ -0,0 +1,198 @@
|
|||||||
|
"""Tests for astrbot.core.star.base module."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
class TestStarBase:
|
||||||
|
"""Test cases for the Star base class."""
|
||||||
|
|
||||||
|
def test_star_class_exists(self):
|
||||||
|
"""Test that Star class can be imported."""
|
||||||
|
from astrbot.core.star import Star
|
||||||
|
|
||||||
|
assert Star is not None
|
||||||
|
|
||||||
|
def test_star_init_with_context(self):
|
||||||
|
"""Test Star initialization with a context-like object."""
|
||||||
|
from astrbot.core.star import Star
|
||||||
|
|
||||||
|
# Create a mock context with get_config method
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.get_config.return_value = MagicMock()
|
||||||
|
|
||||||
|
# Create a concrete Star subclass for testing
|
||||||
|
class TestStar(Star):
|
||||||
|
name = "test_star"
|
||||||
|
author = "test_author"
|
||||||
|
|
||||||
|
star = TestStar(context=mock_context)
|
||||||
|
|
||||||
|
assert star.context is mock_context
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_text_to_image_with_config(self):
|
||||||
|
"""Test text_to_image method with valid config."""
|
||||||
|
from astrbot.core.star import Star
|
||||||
|
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.get.return_value = "default_template"
|
||||||
|
mock_context.get_config.return_value = mock_config
|
||||||
|
|
||||||
|
class TestStar(Star):
|
||||||
|
name = "test_star"
|
||||||
|
author = "test_author"
|
||||||
|
|
||||||
|
star = TestStar(context=mock_context)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"astrbot.core.star.base.html_renderer.render_t2i",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_render:
|
||||||
|
mock_render.return_value = "http://example.com/image.png"
|
||||||
|
result = await star.text_to_image("test text", return_url=True)
|
||||||
|
|
||||||
|
mock_render.assert_called_once_with(
|
||||||
|
"test text",
|
||||||
|
return_url=True,
|
||||||
|
template_name="default_template",
|
||||||
|
)
|
||||||
|
assert result == "http://example.com/image.png"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_text_to_image_without_config(self):
|
||||||
|
"""Test text_to_image method when get_config returns None."""
|
||||||
|
from astrbot.core.star import Star
|
||||||
|
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.get_config.return_value = None
|
||||||
|
|
||||||
|
class TestStar(Star):
|
||||||
|
name = "test_star"
|
||||||
|
author = "test_author"
|
||||||
|
|
||||||
|
star = TestStar(context=mock_context)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"astrbot.core.star.base.html_renderer.render_t2i",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_render:
|
||||||
|
mock_render.return_value = "http://example.com/image.png"
|
||||||
|
result = await star.text_to_image("test text", return_url=False)
|
||||||
|
|
||||||
|
mock_render.assert_called_once_with(
|
||||||
|
"test text",
|
||||||
|
return_url=False,
|
||||||
|
template_name=None,
|
||||||
|
)
|
||||||
|
assert result == "http://example.com/image.png"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_html_render(self):
|
||||||
|
"""Test html_render method."""
|
||||||
|
from astrbot.core.star import Star
|
||||||
|
|
||||||
|
mock_context = MagicMock()
|
||||||
|
|
||||||
|
class TestStar(Star):
|
||||||
|
name = "test_star"
|
||||||
|
author = "test_author"
|
||||||
|
|
||||||
|
star = TestStar(context=mock_context)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"astrbot.core.star.base.html_renderer.render_custom_template",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_render:
|
||||||
|
mock_render.return_value = "http://example.com/rendered.png"
|
||||||
|
result = await star.html_render(
|
||||||
|
"<html>{{ data }}</html>",
|
||||||
|
{"data": "test"},
|
||||||
|
return_url=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_render.assert_called_once_with(
|
||||||
|
"<html>{{ data }}</html>",
|
||||||
|
{"data": "test"},
|
||||||
|
return_url=True,
|
||||||
|
options=None,
|
||||||
|
)
|
||||||
|
assert result == "http://example.com/rendered.png"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_initialize_and_terminate(self):
|
||||||
|
"""Test that initialize and terminate methods can be overridden."""
|
||||||
|
from astrbot.core.star import Star
|
||||||
|
|
||||||
|
class TestStar(Star):
|
||||||
|
name = "test_star"
|
||||||
|
author = "test_author"
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
self.initialized = True
|
||||||
|
|
||||||
|
async def terminate(self) -> None:
|
||||||
|
self.terminated = True
|
||||||
|
|
||||||
|
mock_context = MagicMock()
|
||||||
|
star = TestStar(context=mock_context)
|
||||||
|
|
||||||
|
await star.initialize()
|
||||||
|
assert star.initialized is True
|
||||||
|
|
||||||
|
await star.terminate()
|
||||||
|
assert star.terminated is True
|
||||||
|
|
||||||
|
def test_star_metadata_registration(self):
|
||||||
|
"""Test that Star subclass is automatically registered."""
|
||||||
|
from astrbot.core.star import star_map, star_registry
|
||||||
|
from astrbot.core.star.star import StarMetadata
|
||||||
|
|
||||||
|
# Clear any previous registration for this test module
|
||||||
|
module_path = __name__
|
||||||
|
|
||||||
|
class UniqueTestStar:
|
||||||
|
"""Not a Star subclass, should not be registered."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Verify Star subclass gets registered
|
||||||
|
initial_count = len(star_registry)
|
||||||
|
|
||||||
|
# Note: This test verifies the __init_subclass__ mechanism
|
||||||
|
# The actual registration happens when a class inherits from Star
|
||||||
|
assert len(star_registry) >= initial_count
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoCircularImports:
|
||||||
|
"""Test that there are no circular import issues."""
|
||||||
|
|
||||||
|
def test_import_star_module(self):
|
||||||
|
"""Test that star module can be imported without circular import errors."""
|
||||||
|
import astrbot.core.star
|
||||||
|
|
||||||
|
assert astrbot.core.star is not None
|
||||||
|
|
||||||
|
def test_import_pipeline_module(self):
|
||||||
|
"""Test that pipeline module can be imported without circular import errors."""
|
||||||
|
import astrbot.core.pipeline
|
||||||
|
|
||||||
|
assert astrbot.core.pipeline is not None
|
||||||
|
|
||||||
|
def test_import_both_modules(self):
|
||||||
|
"""Test that both modules can be imported together."""
|
||||||
|
import astrbot.core.pipeline
|
||||||
|
import astrbot.core.star
|
||||||
|
|
||||||
|
# Verify key exports are available
|
||||||
|
from astrbot.core.star import Context, Star, PluginManager
|
||||||
|
|
||||||
|
assert Context is not None
|
||||||
|
assert Star is not None
|
||||||
|
assert PluginManager is not None
|
||||||
|
|
||||||
|
def test_import_pipeline_context(self):
|
||||||
|
"""Test that PipelineContext can be imported."""
|
||||||
|
from astrbot.core.pipeline.context import PipelineContext
|
||||||
|
|
||||||
|
assert PipelineContext is not None
|
||||||
Reference in New Issue
Block a user