chore: update tests
This commit is contained in:
@@ -21,16 +21,10 @@ jobs:
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install pytest pytest-cov pytest-asyncio
|
||||
mkdir data
|
||||
mkdir data/plugins
|
||||
mkdir data/config
|
||||
mkdir temp
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
export LLM_MODEL=${{ secrets.LLM_MODEL }}
|
||||
export OPENAI_API_BASE=${{ secrets.OPENAI_API_BASE }}
|
||||
export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
||||
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
||||
PYTHONPATH=./ pytest --cov=. tests/ -v
|
||||
|
||||
- name: Upload results to Codecov
|
||||
|
||||
@@ -51,6 +51,9 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
|
||||
ls = re.split(r"\s+", message_str)
|
||||
if self.command_name != ls[0]:
|
||||
return False
|
||||
if len(self.handler_params) == 0 and len(ls) > 1:
|
||||
# 一定程度避免 LLM 聊天时误判为指令
|
||||
return False
|
||||
# params_str = message_str[len(self.command_name):].strip()
|
||||
ls = ls[1:]
|
||||
# 去除空字符串
|
||||
|
||||
@@ -289,7 +289,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
- 重置 LLM 会话(保留人格): /reset p
|
||||
|
||||
【当前人格】: {str(self.context.get_using_provider().curr_personality['prompt'])}
|
||||
"""))
|
||||
""").use_t2i(False))
|
||||
elif l[1] == "list":
|
||||
msg = "人格列表:\n"
|
||||
for key in personalities.keys():
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
from unittest import mock
|
||||
from main import check_env, check_dashboard_files
|
||||
|
||||
class _version_info():
|
||||
def __init__(self, major, minor):
|
||||
self.major = major
|
||||
self.minor = minor
|
||||
|
||||
def test_check_env(monkeypatch):
|
||||
version_info_correct = _version_info(3, 10)
|
||||
version_info_wrong = _version_info(3, 9)
|
||||
monkeypatch.setattr(sys, 'version_info', version_info_correct)
|
||||
with mock.patch('os.makedirs') as mock_makedirs:
|
||||
check_env()
|
||||
mock_makedirs.assert_any_call("data/config", exist_ok=True)
|
||||
mock_makedirs.assert_any_call("data/plugins", exist_ok=True)
|
||||
mock_makedirs.assert_any_call("data/temp", exist_ok=True)
|
||||
|
||||
monkeypatch.setattr(sys, 'version_info', version_info_wrong)
|
||||
with pytest.raises(SystemExit):
|
||||
check_env()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_dashboard_files(monkeypatch):
|
||||
monkeypatch.setattr(os.path, 'exists', lambda x: False)
|
||||
async def mock_get(*args, **kwargs):
|
||||
class MockResponse:
|
||||
status = 200
|
||||
async def read(self):
|
||||
return b'content'
|
||||
return MockResponse()
|
||||
|
||||
with mock.patch('aiohttp.ClientSession.get', new=mock_get):
|
||||
with mock.patch('builtins.open', mock.mock_open()) as mock_file:
|
||||
with mock.patch('zipfile.ZipFile.extractall') as mock_extractall:
|
||||
async def mock_aenter(_):
|
||||
await check_dashboard_files()
|
||||
mock_file.assert_called_once_with("data/dashboard.zip", "wb")
|
||||
mock_extractall.assert_called_once()
|
||||
|
||||
async def mock_aexit(obj, exc_type, exc, tb):
|
||||
return
|
||||
|
||||
mock_extractall.__aenter__ = mock_aenter
|
||||
mock_extractall.__aexit__ = mock_aexit
|
||||
@@ -0,0 +1,217 @@
|
||||
import pytest, logging, os
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||
from astrbot.core.message.components import Plain, At
|
||||
from astrbot.core.platform.platform_metadata import PlatformMetadata
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core import logger
|
||||
from asyncio import Queue
|
||||
|
||||
SESSION_ID_IN_WHITELIST = "test_sid_wl"
|
||||
SESSION_ID_NOT_IN_WHITELIST = "test_sid"
|
||||
TEST_LLM_PROVIDER = {
|
||||
"id": "zhipu_default",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [os.getenv("ZHIPU_API_KEY")],
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"model_config": {
|
||||
"model": "glm-4-flash",
|
||||
},
|
||||
}
|
||||
|
||||
TEST_COMMANDS = [
|
||||
["help", "已注册的 AstrBot 内置指令"],
|
||||
["tool ls", "查看、激活、停用当前注册的函数工具"],
|
||||
["tool on websearch", "激活工具"],
|
||||
["tool off websearch", "停用工具"],
|
||||
["plugin", "已加载的插件"],
|
||||
["t2i", "文本转图片模式"],
|
||||
["sid", "此 ID 可用于设置会话白名单。"],
|
||||
["op test_op", "授权成功。"],
|
||||
["deop test_op", "取消授权成功。"],
|
||||
["wl test_platform:FriendMessage:test_sid_wl2", "添加白名单成功。"],
|
||||
["dwl test_platform:FriendMessage:test_sid_wl2", "删除白名单成功。"],
|
||||
["provider", "当前载入的 LLM 提供商"],
|
||||
["reset", "重置成功"],
|
||||
# ["model", "查看、切换提供商模型列表"],
|
||||
["history", "历史记录:"],
|
||||
["key", "当前 Key"],
|
||||
["persona", "[Persona]"]
|
||||
]
|
||||
|
||||
class FakeAstrMessageEvent(AstrMessageEvent):
|
||||
def __init__(self, abm: AstrBotMessage = None):
|
||||
meta = PlatformMetadata("test_platform", "test")
|
||||
super().__init__(
|
||||
message_str=abm.message_str,
|
||||
message_obj=abm,
|
||||
platform_meta=meta,
|
||||
session_id=abm.session_id
|
||||
)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
await super().send(message)
|
||||
|
||||
@staticmethod
|
||||
def create_fake_event(
|
||||
message_str: str,
|
||||
session_id: str = "test_sid",
|
||||
is_at: bool = False,
|
||||
is_group: bool = False,
|
||||
sender_id: str = "123456"
|
||||
):
|
||||
abm = AstrBotMessage()
|
||||
abm.message_str = message_str
|
||||
abm.group_id = "test"
|
||||
abm.message = [Plain(message_str)]
|
||||
if is_at:
|
||||
abm.message.append(At(qq="bot"))
|
||||
abm.self_id = "bot"
|
||||
abm.sender = MessageMember(sender_id, "mika")
|
||||
abm.timestamp = 1234567890
|
||||
abm.message_id = "test"
|
||||
abm.session_id = session_id
|
||||
if is_group:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
return FakeAstrMessageEvent(abm)
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_queue():
|
||||
return Queue()
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def config():
|
||||
cfg = AstrBotConfig()
|
||||
cfg['platform_settings']['id_whitelist'] = ["test_platform:FriendMessage:test_sid_wl", "test_platform:GroupMessage:test_sid_wl"]
|
||||
cfg['admins_id'] = ["123456"]
|
||||
cfg['content_safety']['internal_keywords']['extra_keywords'] = ["^TEST_NEGATIVE"]
|
||||
cfg['provider'] = [TEST_LLM_PROVIDER]
|
||||
return cfg
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def db():
|
||||
return SQLiteDatabase("data/data_v3.db")
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def platform_manager(event_queue, config):
|
||||
return PlatformManager(config, event_queue)
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def provider_manager(config, db):
|
||||
return ProviderManager(config, db)
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def star_context(event_queue, config, db, platform_manager, provider_manager):
|
||||
star_context = Context(event_queue, config, db, provider_manager, platform_manager)
|
||||
return star_context
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def plugin_manager(star_context, config):
|
||||
plugin_manager = PluginManager(star_context, config)
|
||||
plugin_manager.reload()
|
||||
return plugin_manager
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def pipeline_context(config, plugin_manager):
|
||||
return PipelineContext(config, plugin_manager)
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def pipeline_scheduler(pipeline_context):
|
||||
return PipelineScheduler(pipeline_context)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_initialization(platform_manager: PlatformManager):
|
||||
await platform_manager.initialize()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_initialization(provider_manager: ProviderManager):
|
||||
await provider_manager.initialize()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineScheduler):
|
||||
await pipeline_scheduler.initialize()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
'''测试唤醒'''
|
||||
# 群聊无 @ 无指令
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("执行阶段 WhitelistCheckStage" not in message for message in caplog.messages)
|
||||
# 群聊有 @ 无指令
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True, is_at=True)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("执行阶段 WhitelistCheckStage" in message for message in caplog.messages)
|
||||
# 群聊有指令
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST)
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert mock_event._has_send_oper is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_wl(pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog):
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123")
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("不在会话白名单中,已终止事件传播。" in message for message in caplog.messages), "日志中未找到预期的消息"
|
||||
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", SESSION_ID_IN_WHITELIST, sender_id="123")
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("不在会话白名单中,已终止事件传播。" not in message for message in caplog.messages), "日志中未找到预期的消息"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
# 测试默认屏蔽词
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("色情", session_id=SESSION_ID_IN_WHITELIST) # 测试需要。
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息"
|
||||
# 测试额外屏蔽词
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息"
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" not in message for message in caplog.messages)
|
||||
# TODO: 测试 百度AI 的内容安全检查
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("请求 LLM" in message for message in caplog.messages)
|
||||
assert mock_event.get_result() is not None
|
||||
assert mock_event.get_result().result_content_type == ResultContentType.LLM_RESULT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("请求 LLM" in message for message in caplog.messages)
|
||||
assert any("web_searcher - search_from_search_engine" in message for message in caplog.messages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commands(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
for command in TEST_COMMANDS:
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(command[0], session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("执行阶段 ProcessStage" in message for message in caplog.messages)
|
||||
assert any(command[1] in message for message in caplog.messages)
|
||||
@@ -0,0 +1,93 @@
|
||||
import pytest
|
||||
import os
|
||||
import shutil
|
||||
from astrbot.core.star.star_manager import PluginManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from asyncio import Queue
|
||||
|
||||
@pytest.fixture
|
||||
def event_queue():
|
||||
return Queue()
|
||||
|
||||
@pytest.fixture
|
||||
def config():
|
||||
return AstrBotConfig()
|
||||
|
||||
@pytest.fixture
|
||||
def db():
|
||||
return SQLiteDatabase("data/data_v3.db")
|
||||
|
||||
@pytest.fixture
|
||||
def star_context(event_queue, config, db):
|
||||
return Context(event_queue, config, db)
|
||||
|
||||
@pytest.fixture
|
||||
def plugin_manager_pm(star_context, config):
|
||||
return PluginManager(star_context, config)
|
||||
|
||||
def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
|
||||
assert plugin_manager_pm is not None
|
||||
assert plugin_manager_pm.context is not None
|
||||
assert plugin_manager_pm.config is not None
|
||||
|
||||
def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
|
||||
success, err_message = plugin_manager_pm.reload()
|
||||
assert success is True
|
||||
assert err_message is None
|
||||
assert len(star_handlers_registry) > 0 # package
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_crud(plugin_manager_pm: PluginManager):
|
||||
'''测试插件安装和重载'''
|
||||
os.makedirs("data/plugins", exist_ok=True)
|
||||
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
|
||||
plugin_path = await plugin_manager_pm.install_plugin(test_repo)
|
||||
exists = False
|
||||
for md in star_registry:
|
||||
if md.name == "astrbot_plugin_essential":
|
||||
exists = True
|
||||
break
|
||||
assert plugin_path is not None
|
||||
assert os.path.exists(plugin_path)
|
||||
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
|
||||
# shutil.rmtree(plugin_path)
|
||||
|
||||
# install plugin which is not exists
|
||||
with pytest.raises(Exception):
|
||||
plugin_path = await plugin_manager_pm.install_plugin(test_repo + "haha")
|
||||
|
||||
# update
|
||||
await plugin_manager_pm.update_plugin("astrbot_plugin_essential")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await plugin_manager_pm.update_plugin("astrbot_plugin_essentialhaha")
|
||||
|
||||
# uninstall
|
||||
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential")
|
||||
assert not os.path.exists(plugin_path)
|
||||
exists = False
|
||||
for md in star_registry:
|
||||
if md.name == "astrbot_plugin_essential":
|
||||
exists = True
|
||||
break
|
||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
||||
exists = False
|
||||
for md in star_handlers_registry:
|
||||
if "astrbot_plugin_essential" in md.handler_module_path:
|
||||
exists = True
|
||||
break
|
||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha")
|
||||
|
||||
# TODO: file installation
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user