test: add comprehensive tests for message event handling (#5355)

* test: add comprehensive tests for message event handling

- Add AstrMessageEvent unit tests (688 lines)
- Add AstrBotMessage unit tests
- Enhance smoke tests with message event scenarios

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: improve message type handling and add defensive tests

---------

Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
whatevertogo
2026-02-23 23:36:39 +08:00
committed by GitHub
parent 7b731ebda8
commit 5af5ad9e36
4 changed files with 1197 additions and 11 deletions
+33 -11
View File
@@ -52,9 +52,19 @@ class AstrMessageEvent(abc.ABC):
self.is_at_or_wake_command = False
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
self._extras: dict[str, Any] = {}
message_type = getattr(message_obj, "type", None)
if not isinstance(message_type, MessageType):
try:
message_type = MessageType(str(message_type))
except (ValueError, TypeError, AttributeError):
logger.warning(
f"Failed to convert message type {message_obj.type!r} to MessageType. "
f"Falling back to FRIEND_MESSAGE."
)
message_type = MessageType.FRIEND_MESSAGE
self.session = MessageSession(
platform_name=platform_meta.id,
message_type=message_obj.type,
message_type=message_type,
session_id=session_id,
)
# self.unified_msg_origin = str(self.session)
@@ -159,15 +169,18 @@ class AstrMessageEvent(abc.ABC):
除了文本消息外,其他消息类型会被转换为对应的占位符。如图片消息会被转换为 [图片]。
"""
return self._outline_chain(self.message_obj.message)
return self._outline_chain(getattr(self.message_obj, "message", None))
def get_messages(self) -> list[BaseMessageComponent]:
"""获取消息链。"""
return self.message_obj.message
return getattr(self.message_obj, "message", [])
def get_message_type(self) -> MessageType:
"""获取消息类型。"""
return self.message_obj.type
message_type = getattr(self.message_obj, "type", None)
if isinstance(message_type, MessageType):
return message_type
return self.session.message_type
def get_session_id(self) -> str:
"""获取会话id。"""
@@ -175,21 +188,30 @@ class AstrMessageEvent(abc.ABC):
def get_group_id(self) -> str:
"""获取群组id。如果不是群组消息,返回空字符串。"""
return self.message_obj.group_id
return getattr(self.message_obj, "group_id", "")
def get_self_id(self) -> str:
"""获取机器人自身的id。"""
return self.message_obj.self_id
return getattr(self.message_obj, "self_id", "")
def get_sender_id(self) -> str:
"""获取消息发送者的id。"""
return self.message_obj.sender.user_id
sender = getattr(self.message_obj, "sender", None)
if sender and isinstance(getattr(sender, "user_id", None), str):
return sender.user_id
return ""
def get_sender_name(self) -> str:
"""获取消息发送者的名称。(可能会返回空字符串)"""
if isinstance(self.message_obj.sender.nickname, str):
return self.message_obj.sender.nickname
return ""
sender = getattr(self.message_obj, "sender", None)
if not sender:
return ""
nickname = getattr(sender, "nickname", None)
if nickname is None:
return ""
if isinstance(nickname, str):
return nickname
return str(nickname)
def set_extra(self, key, value) -> None:
"""设置额外的信息。"""
@@ -208,7 +230,7 @@ class AstrMessageEvent(abc.ABC):
def is_private_chat(self) -> bool:
"""是否是私聊。"""
return self.message_obj.type.value == (MessageType.FRIEND_MESSAGE).value
return self.get_message_type() == MessageType.FRIEND_MESSAGE
def is_wake_up(self) -> bool:
"""是否是唤醒机器人的事件。"""
+115
View File
@@ -0,0 +1,115 @@
"""Smoke tests for critical startup and import paths."""
from __future__ import annotations
import subprocess
import sys
from pathlib import Path
from astrbot.core.pipeline.bootstrap import ensure_builtin_stages_registered
from astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal import (
InternalAgentSubStage,
)
from astrbot.core.pipeline.process_stage.method.agent_sub_stages.third_party import (
ThirdPartyAgentSubStage,
)
from astrbot.core.pipeline.stage import Stage, registered_stages
from astrbot.core.pipeline.stage_order import STAGES_ORDER
REPO_ROOT = Path(__file__).resolve().parents[1]
def _run_code_in_fresh_interpreter(code: str, failure_message: str) -> None:
proc = subprocess.run(
[sys.executable, "-c", code],
cwd=REPO_ROOT,
capture_output=True,
text=True,
check=False,
)
assert proc.returncode == 0, (
f"{failure_message}\nstdout:\n{proc.stdout}\nstderr:\n{proc.stderr}\n"
)
def test_smoke_critical_imports_in_fresh_interpreter() -> None:
code = (
"import importlib;"
"mods=["
"'astrbot.core.core_lifecycle',"
"'astrbot.core.astr_main_agent',"
"'astrbot.core.pipeline.scheduler',"
"'astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal',"
"'astrbot.core.pipeline.process_stage.method.agent_sub_stages.third_party'"
"];"
"[importlib.import_module(m) for m in mods]"
)
_run_code_in_fresh_interpreter(code, "Smoke import check failed.")
def test_smoke_pipeline_stage_registration_matches_order() -> None:
ensure_builtin_stages_registered()
stage_names = {cls.__name__ for cls in registered_stages}
assert set(STAGES_ORDER).issubset(stage_names)
assert len(stage_names) == len(registered_stages)
def test_smoke_agent_sub_stages_are_stage_subclasses() -> None:
assert issubclass(InternalAgentSubStage, Stage)
assert issubclass(ThirdPartyAgentSubStage, Stage)
def test_pipeline_package_exports_remain_compatible() -> None:
import astrbot.core.pipeline as pipeline
assert pipeline.ProcessStage is not None
assert pipeline.RespondStage is not None
assert isinstance(pipeline.STAGES_ORDER, list)
assert "ProcessStage" in pipeline.STAGES_ORDER
def test_builtin_stage_bootstrap_is_idempotent() -> None:
ensure_builtin_stages_registered()
before_count = len(registered_stages)
stage_names = {cls.__name__ for cls in registered_stages}
expected_stage_names = {
"WakingCheckStage",
"WhitelistCheckStage",
"SessionStatusCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",
}
assert expected_stage_names.issubset(stage_names)
ensure_builtin_stages_registered()
assert len(registered_stages) == before_count
def test_pipeline_import_is_stable_with_mocked_apscheduler() -> None:
"""Regression: importing pipeline should not require cron/apscheduler modules."""
code = (
"import sys;"
"from unittest.mock import MagicMock;"
"mock_apscheduler = MagicMock();"
"mock_apscheduler.schedulers = MagicMock();"
"mock_apscheduler.schedulers.asyncio = MagicMock();"
"mock_apscheduler.schedulers.background = MagicMock();"
"sys.modules['apscheduler'] = mock_apscheduler;"
"sys.modules['apscheduler.schedulers'] = mock_apscheduler.schedulers;"
"sys.modules['apscheduler.schedulers.asyncio'] = mock_apscheduler.schedulers.asyncio;"
"sys.modules['apscheduler.schedulers.background'] = mock_apscheduler.schedulers.background;"
"import astrbot.core.pipeline as pipeline;"
"assert pipeline.ProcessStage is not None;"
"assert pipeline.RespondStage is not None"
)
_run_code_in_fresh_interpreter(
code,
"Pipeline import should not depend on real apscheduler package.",
)
+781
View File
@@ -0,0 +1,781 @@
"""Tests for AstrMessageEvent class."""
import re
from unittest.mock import AsyncMock, patch
import pytest
from astrbot.core.message.components import (
At,
AtAll,
Face,
Forward,
Image,
Plain,
Reply,
)
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember
from astrbot.core.platform.message_type import MessageType
from astrbot.core.platform.platform_metadata import PlatformMetadata
class ConcreteAstrMessageEvent(AstrMessageEvent):
"""Concrete implementation of AstrMessageEvent for testing purposes."""
async def send(self, message):
"""Send message implementation."""
await super().send(message)
@pytest.fixture
def platform_meta():
"""Create platform metadata for testing."""
return PlatformMetadata(
name="test_platform",
description="Test platform",
id="test_platform_id",
)
@pytest.fixture
def message_member():
"""Create a message member for testing."""
return MessageMember(user_id="user123", nickname="TestUser")
@pytest.fixture
def astrbot_message(message_member):
"""Create an AstrBotMessage for testing."""
message = AstrBotMessage()
message.type = MessageType.FRIEND_MESSAGE
message.self_id = "bot123"
message.session_id = "session123"
message.message_id = "msg123"
message.sender = message_member
message.message = [Plain(text="Hello world")]
message.message_str = "Hello world"
message.raw_message = None
return message
@pytest.fixture
def astr_message_event(platform_meta, astrbot_message):
"""Create an AstrMessageEvent instance for testing."""
return ConcreteAstrMessageEvent(
message_str="Hello world",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
class TestAstrMessageEventInit:
"""Tests for AstrMessageEvent initialization."""
def test_init_basic(self, astr_message_event):
"""Test basic AstrMessageEvent initialization."""
assert astr_message_event.message_str == "Hello world"
assert astr_message_event.role == "member"
assert astr_message_event.is_wake is False
assert astr_message_event.is_at_or_wake_command is False
assert astr_message_event._extras == {}
assert astr_message_event._result is None
assert astr_message_event.call_llm is False
def test_init_session(self, astr_message_event):
"""Test session initialization."""
assert astr_message_event.session_id == "session123"
assert astr_message_event.session.platform_name == "test_platform_id"
def test_init_platform_reference(self, astr_message_event, platform_meta):
"""Test platform reference initialization."""
assert astr_message_event.platform_meta == platform_meta
assert astr_message_event.platform == platform_meta # back compatibility
def test_init_created_at(self, astr_message_event):
"""Test created_at timestamp is set."""
assert astr_message_event.created_at is not None
assert isinstance(astr_message_event.created_at, float)
def test_init_trace(self, astr_message_event):
"""Test trace/span initialization."""
assert astr_message_event.trace is not None
assert astr_message_event.span is not None
assert astr_message_event.trace == astr_message_event.span
class TestUnifiedMsgOrigin:
"""Tests for unified_msg_origin property."""
def test_unified_msg_origin_getter(self, astr_message_event):
"""Test unified_msg_origin getter."""
expected = "test_platform_id:FriendMessage:session123"
assert astr_message_event.unified_msg_origin == expected
def test_unified_msg_origin_setter(self, astr_message_event):
"""Test unified_msg_origin setter."""
astr_message_event.unified_msg_origin = "new_platform:GroupMessage:new_session"
assert astr_message_event.session.platform_name == "new_platform"
assert astr_message_event.session.session_id == "new_session"
class TestSessionId:
"""Tests for session_id property."""
def test_session_id_getter(self, astr_message_event):
"""Test session_id getter."""
assert astr_message_event.session_id == "session123"
def test_session_id_setter(self, astr_message_event):
"""Test session_id setter."""
astr_message_event.session_id = "new_session_id"
assert astr_message_event.session_id == "new_session_id"
class TestGetPlatformInfo:
"""Tests for platform info methods."""
def test_get_platform_name(self, astr_message_event):
"""Test get_platform_name method."""
assert astr_message_event.get_platform_name() == "test_platform"
def test_get_platform_id(self, astr_message_event):
"""Test get_platform_id method."""
assert astr_message_event.get_platform_id() == "test_platform_id"
class TestGetMessageInfo:
"""Tests for message info methods."""
def test_get_message_str(self, astr_message_event):
"""Test get_message_str method."""
assert astr_message_event.get_message_str() == "Hello world"
def test_get_message_str_none(self, platform_meta, astrbot_message):
"""Test get_message_str keeps None when source message_str is None."""
astrbot_message.message_str = None
event = ConcreteAstrMessageEvent(
message_str=None,
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.get_message_str() is None
def test_get_messages(self, astr_message_event):
"""Test get_messages method."""
messages = astr_message_event.get_messages()
assert len(messages) == 1
assert isinstance(messages[0], Plain)
assert messages[0].text == "Hello world"
def test_get_message_type(self, astr_message_event):
"""Test get_message_type method."""
assert astr_message_event.get_message_type() == MessageType.FRIEND_MESSAGE
def test_get_session_id(self, astr_message_event):
"""Test get_session_id method."""
assert astr_message_event.get_session_id() == "session123"
def test_get_group_id_empty_for_private(self, astr_message_event):
"""Test get_group_id returns empty for private messages."""
assert astr_message_event.get_group_id() == ""
def test_get_self_id(self, astr_message_event):
"""Test get_self_id method."""
assert astr_message_event.get_self_id() == "bot123"
def test_get_sender_id(self, astr_message_event):
"""Test get_sender_id method."""
assert astr_message_event.get_sender_id() == "user123"
def test_get_sender_name(self, astr_message_event):
"""Test get_sender_name method."""
assert astr_message_event.get_sender_name() == "TestUser"
def test_get_sender_name_empty_when_none(self, platform_meta, astrbot_message):
"""Test get_sender_name returns empty string when nickname is None."""
astrbot_message.sender = MessageMember(user_id="user123", nickname=None)
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.get_sender_name() == ""
def test_get_sender_name_coerces_non_string(self, platform_meta, astrbot_message):
"""Test get_sender_name stringifies non-string nickname values."""
astrbot_message.sender = MessageMember(user_id="user123", nickname=None)
astrbot_message.sender.nickname = 12345
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.get_sender_name() == "12345"
class TestGetMessageOutline:
"""Tests for get_message_outline method."""
def test_outline_plain_text(self, astr_message_event):
"""Test outline with plain text message."""
outline = astr_message_event.get_message_outline()
assert "Hello world" in outline
def test_outline_with_image(self, platform_meta, astrbot_message):
"""Test outline with image component."""
astrbot_message.message = [
Plain(text="Look at this"),
Image(file="http://example.com/img.jpg"),
]
event = ConcreteAstrMessageEvent(
message_str="Look at this",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "Look at this" in outline
assert "[图片]" in outline
def test_outline_with_at(self, platform_meta, astrbot_message):
"""Test outline with At component."""
astrbot_message.message = [At(qq="12345"), Plain(text=" hello")]
event = ConcreteAstrMessageEvent(
message_str=" hello",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "[At:12345]" in outline
def test_outline_with_at_all(self, platform_meta, astrbot_message):
"""Test outline with AtAll component."""
astrbot_message.message = [AtAll()]
event = ConcreteAstrMessageEvent(
message_str="",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
# AtAll format is "[At:all]" in the actual implementation
assert "[At:" in outline and "all" in outline.lower()
def test_outline_with_face(self, platform_meta, astrbot_message):
"""Test outline with Face component."""
astrbot_message.message = [Face(id="123")]
event = ConcreteAstrMessageEvent(
message_str="",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "[表情:123]" in outline
def test_outline_with_forward(self, platform_meta, astrbot_message):
"""Test outline with Forward component."""
# Forward requires an id parameter
astrbot_message.message = [Forward(id="test_forward_id")]
event = ConcreteAstrMessageEvent(
message_str="",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "[转发消息]" in outline
def test_outline_with_reply(self, platform_meta, astrbot_message):
"""Test outline with Reply component."""
# Reply requires an id parameter
reply = Reply(id="test_reply_id")
reply.message_str = "Original message"
reply.sender_nickname = "Sender"
astrbot_message.message = [reply, Plain(text=" reply")]
event = ConcreteAstrMessageEvent(
message_str=" reply",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "[引用消息(Sender: Original message)]" in outline
def test_outline_with_reply_no_message(self, platform_meta, astrbot_message):
"""Test outline with Reply component without message_str."""
# Reply requires an id parameter
reply = Reply(id="test_reply_id")
reply.message_str = None
astrbot_message.message = [reply]
event = ConcreteAstrMessageEvent(
message_str="",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "[引用消息]" in outline
def test_outline_empty_chain(self, platform_meta, astrbot_message):
"""Test outline with empty message chain."""
astrbot_message.message = []
event = ConcreteAstrMessageEvent(
message_str="",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert outline == ""
def test_outline_very_long_plain_text(self, platform_meta, astrbot_message):
"""Test outline generation for very long plain text content."""
long_text = "A" * 20000
astrbot_message.message = [Plain(text=long_text)]
event = ConcreteAstrMessageEvent(
message_str=long_text,
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert outline.startswith("A")
assert len(outline) >= 20000
class TestExtras:
"""Tests for extra information methods."""
def test_set_extra(self, astr_message_event):
"""Test set_extra method."""
astr_message_event.set_extra("key1", "value1")
assert astr_message_event._extras["key1"] == "value1"
def test_get_extra_with_key(self, astr_message_event):
"""Test get_extra with specific key."""
astr_message_event.set_extra("key1", "value1")
assert astr_message_event.get_extra("key1") == "value1"
def test_get_extra_with_default(self, astr_message_event):
"""Test get_extra with default value."""
result = astr_message_event.get_extra("nonexistent", "default_value")
assert result == "default_value"
def test_get_extra_all(self, astr_message_event):
"""Test get_extra without key returns all extras."""
astr_message_event.set_extra("key1", "value1")
astr_message_event.set_extra("key2", "value2")
all_extras = astr_message_event.get_extra()
assert all_extras == {"key1": "value1", "key2": "value2"}
def test_clear_extra(self, astr_message_event):
"""Test clear_extra method."""
astr_message_event.set_extra("key1", "value1")
astr_message_event.clear_extra()
assert astr_message_event._extras == {}
class TestSetResult:
"""Tests for set_result method."""
def test_set_result_with_message_event_result(self, astr_message_event):
"""Test set_result with MessageEventResult object."""
result = MessageEventResult().message("Test message")
astr_message_event.set_result(result)
assert astr_message_event._result == result
def test_set_result_with_string(self, astr_message_event):
"""Test set_result with string creates MessageEventResult."""
astr_message_event.set_result("Test message")
assert astr_message_event._result is not None
assert len(astr_message_event._result.chain) == 1
assert isinstance(astr_message_event._result.chain[0], Plain)
def test_set_result_with_empty_chain(self, astr_message_event):
"""Test set_result handles empty chain correctly."""
result = MessageEventResult()
# chain is already an empty list by default
astr_message_event.set_result(result)
assert astr_message_event._result.chain == []
class TestStopContinueEvent:
"""Tests for stop_event and continue_event methods."""
def test_stop_event_creates_result_if_none(self, astr_message_event):
"""Test stop_event creates result if none exists."""
astr_message_event.stop_event()
assert astr_message_event._result is not None
assert astr_message_event.is_stopped() is True
def test_stop_event_with_existing_result(self, astr_message_event):
"""Test stop_event with existing result."""
astr_message_event.set_result(MessageEventResult().message("Test"))
astr_message_event.stop_event()
assert astr_message_event.is_stopped() is True
def test_continue_event_creates_result_if_none(self, astr_message_event):
"""Test continue_event creates result if none exists."""
astr_message_event.continue_event()
assert astr_message_event._result is not None
assert astr_message_event.is_stopped() is False
def test_continue_event_with_existing_result(self, astr_message_event):
"""Test continue_event with existing result."""
astr_message_event.set_result(MessageEventResult().message("Test"))
astr_message_event.stop_event()
astr_message_event.continue_event()
assert astr_message_event.is_stopped() is False
def test_is_stopped_default_false(self, astr_message_event):
"""Test is_stopped returns False by default."""
assert astr_message_event.is_stopped() is False
class TestIsPrivateChat:
"""Tests for is_private_chat method."""
def test_is_private_chat_true(self, astr_message_event):
"""Test is_private_chat returns True for friend message."""
assert astr_message_event.is_private_chat() is True
def test_is_private_chat_false(self, platform_meta, astrbot_message):
"""Test is_private_chat returns False for group message."""
astrbot_message.type = MessageType.GROUP_MESSAGE
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.is_private_chat() is False
class TestIsWakeUp:
"""Tests for is_wake_up method."""
def test_is_wake_up_default_false(self, astr_message_event):
"""Test is_wake_up returns False by default."""
assert astr_message_event.is_wake_up() is False
def test_is_wake_up_when_set(self, astr_message_event):
"""Test is_wake_up returns True when is_wake is set."""
astr_message_event.is_wake = True
assert astr_message_event.is_wake_up() is True
class TestIsAdmin:
"""Tests for is_admin method."""
def test_is_admin_default_false(self, astr_message_event):
"""Test is_admin returns False by default."""
assert astr_message_event.is_admin() is False
def test_is_admin_when_admin(self, astr_message_event):
"""Test is_admin returns True when role is admin."""
astr_message_event.role = "admin"
assert astr_message_event.is_admin() is True
class TestProcessBuffer:
"""Tests for process_buffer method."""
@pytest.mark.asyncio
async def test_process_buffer_splits_by_pattern(self, astr_message_event):
"""Test process_buffer splits buffer by pattern."""
buffer = "Line 1\nLine 2\nLine 3\nRemaining"
pattern = re.compile(r".*\n")
with patch.object(
astr_message_event, "send", new_callable=AsyncMock
) as mock_send:
result = await astr_message_event.process_buffer(buffer, pattern)
# Should have sent 3 lines and remaining should be "Remaining"
assert mock_send.call_count == 3
assert result == "Remaining"
@pytest.mark.asyncio
async def test_process_buffer_no_match(self, astr_message_event):
"""Test process_buffer returns original when no match."""
buffer = "No newlines here"
pattern = re.compile(r"\n")
result = await astr_message_event.process_buffer(buffer, pattern)
assert result == "No newlines here"
class TestResultHelpers:
"""Tests for result helper methods."""
def test_make_result(self, astr_message_event):
"""Test make_result creates empty MessageEventResult."""
result = astr_message_event.make_result()
assert isinstance(result, MessageEventResult)
def test_plain_result(self, astr_message_event):
"""Test plain_result creates result with text."""
result = astr_message_event.plain_result("Hello")
assert isinstance(result, MessageEventResult)
assert len(result.chain) == 1
assert isinstance(result.chain[0], Plain)
assert result.chain[0].text == "Hello"
def test_image_result_url(self, astr_message_event):
"""Test image_result with URL."""
result = astr_message_event.image_result("http://example.com/image.jpg")
assert isinstance(result, MessageEventResult)
assert len(result.chain) == 1
assert isinstance(result.chain[0], Image)
def test_image_result_path(self, astr_message_event):
"""Test image_result with file path."""
result = astr_message_event.image_result("/path/to/image.jpg")
assert isinstance(result, MessageEventResult)
assert len(result.chain) == 1
assert isinstance(result.chain[0], Image)
class TestGetResult:
"""Tests for get_result and clear_result methods."""
def test_get_result_returns_none_by_default(self, astr_message_event):
"""Test get_result returns None by default."""
assert astr_message_event.get_result() is None
def test_get_result_returns_set_result(self, astr_message_event):
"""Test get_result returns set result."""
result = MessageEventResult().message("Test")
astr_message_event.set_result(result)
assert astr_message_event.get_result() == result
def test_clear_result(self, astr_message_event):
"""Test clear_result clears the result."""
astr_message_event.set_result(MessageEventResult().message("Test"))
astr_message_event.clear_result()
assert astr_message_event.get_result() is None
class TestShouldCallLlm:
"""Tests for should_call_llm method."""
def test_should_call_llm_default(self, astr_message_event):
"""Test call_llm default is False."""
assert astr_message_event.call_llm is False
def test_should_call_llm_when_set(self, astr_message_event):
"""Test should_call_llm sets call_llm."""
astr_message_event.should_call_llm(True)
assert astr_message_event.call_llm is True
class TestRequestLlm:
"""Tests for request_llm method."""
def test_request_llm_basic(self, astr_message_event):
"""Test request_llm creates ProviderRequest."""
request = astr_message_event.request_llm(prompt="Hello")
assert request.prompt == "Hello"
assert request.session_id == ""
assert request.image_urls == []
assert request.contexts == []
def test_request_llm_with_all_params(self, astr_message_event):
"""Test request_llm with all parameters."""
request = astr_message_event.request_llm(
prompt="Hello",
session_id="session123",
image_urls=["http://example.com/img.jpg"],
contexts=[{"role": "user", "content": "Hi"}],
system_prompt="You are helpful",
)
assert request.prompt == "Hello"
assert request.session_id == "session123"
assert request.image_urls == ["http://example.com/img.jpg"]
assert request.contexts == [{"role": "user", "content": "Hi"}]
assert request.system_prompt == "You are helpful"
class TestSendStreaming:
"""Tests for send_streaming method."""
@pytest.mark.asyncio
async def test_send_streaming_sets_has_send_oper(self, astr_message_event):
"""Test send_streaming sets _has_send_oper flag."""
assert astr_message_event._has_send_oper is False
async def generator():
yield MessageEventResult().message("Test")
with patch(
"astrbot.core.platform.astr_message_event.Metric.upload",
new_callable=AsyncMock,
):
await astr_message_event.send_streaming(generator())
assert astr_message_event._has_send_oper is True
class TestSendTyping:
"""Tests for send_typing method."""
@pytest.mark.asyncio
async def test_send_typing_default_empty(self, astr_message_event):
"""Test send_typing default implementation is empty."""
# Should not raise any exception
await astr_message_event.send_typing()
class TestReact:
"""Tests for react method."""
@pytest.mark.asyncio
async def test_react_sends_emoji(self, astr_message_event):
"""Test react sends emoji as message."""
with patch.object(
astr_message_event, "send", new_callable=AsyncMock
) as mock_send:
await astr_message_event.react("👍")
mock_send.assert_called_once()
call_arg = mock_send.call_args[0][0]
# MessageChain is a dataclass with chain attribute
assert len(call_arg.chain) == 1
assert isinstance(call_arg.chain[0], Plain)
assert call_arg.chain[0].text == "👍"
class TestGetGroup:
"""Tests for get_group method."""
@pytest.mark.asyncio
async def test_get_group_returns_none_for_private(self, astr_message_event):
"""Test get_group returns None for private chat."""
result = await astr_message_event.get_group()
assert result is None
@pytest.mark.asyncio
async def test_get_group_with_group_id_param(self, astr_message_event):
"""Test get_group with group_id parameter."""
# Default implementation returns None
result = await astr_message_event.get_group(group_id="group123")
assert result is None
class TestMessageTypeHandling:
"""Tests for message type handling edge cases."""
def test_message_type_from_valid_string(self, platform_meta):
"""Valid MessageType string should be converted correctly."""
message = AstrBotMessage()
message.type = "FRIEND_MESSAGE"
message.message = []
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.session.message_type == MessageType.FRIEND_MESSAGE
assert event.get_message_type() == MessageType.FRIEND_MESSAGE
def test_message_type_from_invalid_string_defaults_to_friend(self, platform_meta):
"""Invalid message type should default to FRIEND_MESSAGE."""
message = AstrBotMessage()
message.type = "InvalidMessageType"
message.message = []
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.session.message_type == MessageType.FRIEND_MESSAGE
assert event.get_message_type() == MessageType.FRIEND_MESSAGE
def test_message_type_from_none_defaults_to_friend(self, platform_meta):
"""None message type should default to FRIEND_MESSAGE."""
message = AstrBotMessage()
message.type = None
message.message = []
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.session.message_type == MessageType.FRIEND_MESSAGE
assert event.get_message_type() == MessageType.FRIEND_MESSAGE
def test_message_type_from_integer_defaults_to_friend(self, platform_meta):
"""Integer message type should default to FRIEND_MESSAGE."""
message = AstrBotMessage()
message.type = 123
message.message = []
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.session.message_type == MessageType.FRIEND_MESSAGE
assert event.get_message_type() == MessageType.FRIEND_MESSAGE
class TestDefensiveGetattr:
"""Tests for defensive getattr behavior in AstrMessageEvent."""
def test_get_messages_without_message_attr(self, astr_message_event):
"""get_messages should handle message_obj without 'message' attribute."""
astr_message_event.message_obj = type("DummyMessage", (), {})()
messages = astr_message_event.get_messages()
assert isinstance(messages, list)
def test_get_message_type_without_type_attr(self, astr_message_event):
"""get_message_type should handle message_obj without 'type' attribute."""
astr_message_event.message_obj = type("DummyMessage", (), {})()
message_type = astr_message_event.get_message_type()
assert isinstance(message_type, MessageType)
def test_get_sender_fields_without_sender_attr(self, astr_message_event):
"""get_sender_id and get_sender_name should handle missing 'sender'."""
astr_message_event.message_obj = type("DummyMessage", (), {})()
sender_id = astr_message_event.get_sender_id()
sender_name = astr_message_event.get_sender_name()
assert isinstance(sender_id, str)
assert isinstance(sender_name, str)
def test_get_message_type_with_non_enum_type(self, astr_message_event):
"""get_message_type should handle message_obj.type that is not a MessageType."""
class DummyMessage:
def __init__(self):
self.type = "not_an_enum"
self.message = []
astr_message_event.message_obj = DummyMessage()
message_type = astr_message_event.get_message_type()
assert isinstance(message_type, MessageType)
+268
View File
@@ -0,0 +1,268 @@
"""Tests for AstrBotMessage and MessageMember classes."""
import time
from unittest.mock import patch
from astrbot.core.message.components import Image, Plain
from astrbot.core.platform.astrbot_message import AstrBotMessage, Group, MessageMember
from astrbot.core.platform.message_type import MessageType
class TestMessageMember:
"""Tests for MessageMember dataclass."""
def test_message_member_creation_basic(self):
"""Test creating a MessageMember with required fields."""
member = MessageMember(user_id="user123")
assert member.user_id == "user123"
assert member.nickname is None
def test_message_member_creation_with_nickname(self):
"""Test creating a MessageMember with nickname."""
member = MessageMember(user_id="user123", nickname="TestUser")
assert member.user_id == "user123"
assert member.nickname == "TestUser"
def test_message_member_str_with_nickname(self):
"""Test __str__ method with nickname."""
member = MessageMember(user_id="user123", nickname="TestUser")
result = str(member)
assert "User ID: user123" in result
assert "Nickname: TestUser" in result
def test_message_member_str_without_nickname(self):
"""Test __str__ method without nickname."""
member = MessageMember(user_id="user123")
result = str(member)
assert "User ID: user123" in result
assert "Nickname: N/A" in result
class TestGroup:
"""Tests for Group dataclass."""
def test_group_creation_basic(self):
"""Test creating a Group with required fields."""
group = Group(group_id="group123")
assert group.group_id == "group123"
assert group.group_name is None
assert group.group_avatar is None
assert group.group_owner is None
assert group.group_admins is None
assert group.members is None
def test_group_creation_with_all_fields(self):
"""Test creating a Group with all fields."""
members = [MessageMember(user_id="user1"), MessageMember(user_id="user2")]
group = Group(
group_id="group123",
group_name="Test Group",
group_avatar="http://example.com/avatar.jpg",
group_owner="owner123",
group_admins=["admin1", "admin2"],
members=members,
)
assert group.group_id == "group123"
assert group.group_name == "Test Group"
assert group.group_avatar == "http://example.com/avatar.jpg"
assert group.group_owner == "owner123"
assert group.group_admins == ["admin1", "admin2"]
assert group.members == members
def test_group_str_with_all_fields(self):
"""Test __str__ method with all fields."""
members = [MessageMember(user_id="user1", nickname="User One")]
group = Group(
group_id="group123",
group_name="Test Group",
group_avatar="http://example.com/avatar.jpg",
group_owner="owner123",
group_admins=["admin1"],
members=members,
)
result = str(group)
assert "Group ID: group123" in result
assert "Name: Test Group" in result
assert "Avatar: http://example.com/avatar.jpg" in result
assert "Owner ID: owner123" in result
assert "Admin IDs: ['admin1']" in result
assert "Members Len: 1" in result
def test_group_str_with_minimal_fields(self):
"""Test __str__ method with minimal fields."""
group = Group(group_id="group123")
result = str(group)
assert "Group ID: group123" in result
assert "Name: N/A" in result
assert "Avatar: N/A" in result
assert "Owner ID: N/A" in result
assert "Admin IDs: N/A" in result
assert "Members Len: 0" in result
assert "First Member: N/A" in result
class TestAstrBotMessage:
"""Tests for AstrBotMessage class."""
def test_astrbot_message_creation(self):
"""Test creating an AstrBotMessage."""
message = AstrBotMessage()
assert message.group is None
assert message.timestamp is not None
assert isinstance(message.timestamp, int)
def test_astrbot_message_timestamp(self):
"""Test timestamp is set on creation."""
with patch.object(time, "time", return_value=1234567890):
message = AstrBotMessage()
assert message.timestamp == 1234567890
def test_astrbot_message_all_attributes(self):
"""Test setting all attributes on AstrBotMessage."""
message = AstrBotMessage()
message.type = MessageType.FRIEND_MESSAGE
message.self_id = "bot123"
message.session_id = "session123"
message.message_id = "msg123"
message.sender = MessageMember(user_id="user123", nickname="TestUser")
message.message = [Plain(text="Hello")]
message.message_str = "Hello"
message.raw_message = {"raw": "data"}
assert message.type == MessageType.FRIEND_MESSAGE
assert message.self_id == "bot123"
assert message.session_id == "session123"
assert message.message_id == "msg123"
assert message.sender.user_id == "user123"
assert len(message.message) == 1
assert message.message_str == "Hello"
assert message.raw_message == {"raw": "data"}
def test_astrbot_message_str(self):
"""Test __str__ method."""
message = AstrBotMessage()
message.type = MessageType.FRIEND_MESSAGE
message.self_id = "bot123"
result = str(message)
assert "'type'" in result
assert "'self_id'" in result
class TestAstrBotMessageGroupId:
"""Tests for AstrBotMessage group_id property."""
def test_group_id_returns_empty_when_no_group(self):
"""Test group_id returns empty string when group is None."""
message = AstrBotMessage()
assert message.group_id == ""
def test_group_id_returns_group_id_when_group_exists(self):
"""Test group_id returns the group's id when group exists."""
message = AstrBotMessage()
message.group = Group(group_id="group123")
assert message.group_id == "group123"
def test_group_id_setter_creates_new_group(self):
"""Test group_id setter creates a new group if none exists."""
message = AstrBotMessage()
message.group_id = "new_group123"
assert message.group is not None
assert message.group.group_id == "new_group123"
def test_group_id_setter_updates_existing_group(self):
"""Test group_id setter updates existing group's id."""
message = AstrBotMessage()
message.group = Group(group_id="old_group")
message.group_id = "new_group"
assert message.group.group_id == "new_group"
def test_group_id_setter_with_none_removes_group(self):
"""Test group_id setter with None removes the group."""
message = AstrBotMessage()
message.group = Group(group_id="group123")
message.group_id = None
assert message.group is None
def test_group_id_setter_with_empty_string_removes_group(self):
"""Test group_id setter with empty string removes the group."""
message = AstrBotMessage()
message.group = Group(group_id="group123")
message.group_id = ""
assert message.group is None
class TestAstrBotMessageTypes:
"""Tests for AstrBotMessage with different message types."""
def test_friend_message_type(self):
"""Test AstrBotMessage with FRIEND_MESSAGE type."""
message = AstrBotMessage()
message.type = MessageType.FRIEND_MESSAGE
assert message.type == MessageType.FRIEND_MESSAGE
assert message.type.value == "FriendMessage"
def test_group_message_type(self):
"""Test AstrBotMessage with GROUP_MESSAGE type."""
message = AstrBotMessage()
message.type = MessageType.GROUP_MESSAGE
assert message.type == MessageType.GROUP_MESSAGE
assert message.type.value == "GroupMessage"
def test_other_message_type(self):
"""Test AstrBotMessage with OTHER_MESSAGE type."""
message = AstrBotMessage()
message.type = MessageType.OTHER_MESSAGE
assert message.type == MessageType.OTHER_MESSAGE
assert message.type.value == "OtherMessage"
class TestAstrBotMessageChain:
"""Tests for AstrBotMessage message chain."""
def test_message_chain_with_plain_text(self):
"""Test message chain with plain text."""
message = AstrBotMessage()
message.message = [Plain(text="Hello world")]
assert len(message.message) == 1
assert isinstance(message.message[0], Plain)
assert message.message[0].text == "Hello world"
def test_message_chain_with_multiple_components(self):
"""Test message chain with multiple components."""
message = AstrBotMessage()
message.message = [
Plain(text="Hello "),
Plain(text="world"),
Image(file="http://example.com/img.jpg"),
]
assert len(message.message) == 3
assert isinstance(message.message[0], Plain)
assert isinstance(message.message[1], Plain)
assert isinstance(message.message[2], Image)
def test_message_chain_empty(self):
"""Test empty message chain."""
message = AstrBotMessage()
message.message = []
assert len(message.message) == 0