375 lines
14 KiB
Python
375 lines
14 KiB
Python
"""Tests for ContextTruncator."""
|
|
|
|
from astrbot.core.agent.context.truncator import ContextTruncator
|
|
from astrbot.core.agent.message import Message
|
|
|
|
|
|
class TestContextTruncator:
    """Test suite for ContextTruncator."""

    def create_message(self, role: str, content: str = "test content") -> Message:
        """Helper to create a simple test message."""
        return Message(role=role, content=content)

    def create_messages(
        self, count: int, include_system: bool = False
    ) -> list[Message]:
        """Helper to create alternating user/assistant messages.

        Args:
            count: Number of user/assistant messages to create.
            include_system: Whether to include a system message at the start.

        Returns:
            List of messages. Non-system messages alternate user, assistant, ...
            (starting with user) and carry the content ``"Message {i}"``.
        """
        messages = []
        if include_system:
            messages.append(self.create_message("system", "System prompt"))

        for i in range(count):
            role = "user" if i % 2 == 0 else "assistant"
            messages.append(self.create_message(role, f"Message {i}"))
        return messages

    # ==================== fix_messages Tests ====================

    def test_fix_messages_empty_list(self) -> None:
        """Test fix_messages with an empty list."""
        truncator = ContextTruncator()
        result = truncator.fix_messages([])
        assert result == []

    def test_fix_messages_normal_messages(self) -> None:
        """Test fix_messages with normal user/assistant messages."""
        truncator = ContextTruncator()
        messages = [
            self.create_message("user", "Hello"),
            self.create_message("assistant", "Hi"),
            self.create_message("user", "How are you?"),
        ]
        result = truncator.fix_messages(messages)
        assert len(result) == 3
        assert result == messages

    def test_fix_messages_tool_without_context(self) -> None:
        """Test fix_messages with tool message without enough context."""
        truncator = ContextTruncator()
        messages = [
            self.create_message("tool", "Tool result"),
        ]
        result = truncator.fix_messages(messages)
        # Tool message without context should be removed
        assert len(result) == 0

    # ==================== truncate_by_turns Tests ====================

    def test_truncate_by_turns_no_limit(self) -> None:
        """Test truncate_by_turns with -1 (no limit)."""
        truncator = ContextTruncator()
        messages = self.create_messages(20)
        result = truncator.truncate_by_turns(messages, keep_most_recent_turns=-1)
        assert len(result) == 20
        assert result == messages

    def test_truncate_by_turns_basic(self) -> None:
        """Test basic truncate_by_turns functionality."""
        truncator = ContextTruncator()
        # Create 10 messages = 5 turns (user/assistant pairs)
        messages = self.create_messages(10)
        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=3, drop_turns=1
        )

        # 5 turns exceed the limit of 3, so truncation must actually happen.
        assert len(result) < len(messages)
        # Roughly the 3 most recent turns (6 messages) should survive; the bound
        # is kept lenient because the truncator may realign the kept window
        # (e.g. so that it starts with a user message).
        assert len(result) <= 8

    def test_truncate_by_turns_with_system_message(self) -> None:
        """Test truncate_by_turns preserves system messages."""
        truncator = ContextTruncator()
        messages = self.create_messages(10, include_system=True)
        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=2, drop_turns=1
        )

        # System message should always be preserved
        assert result[0].role == "system"
        assert result[0].content == "System prompt"

    def test_truncate_by_turns_zero_keep(self) -> None:
        """Test truncate_by_turns with keep_most_recent_turns=0."""
        truncator = ContextTruncator()
        messages = self.create_messages(10)
        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=0, drop_turns=1
        )

        # Keeping zero turns (and having no system message) leaves nothing.
        assert len(result) == 0

    def test_truncate_by_turns_below_threshold(self) -> None:
        """Test truncate_by_turns when messages are below threshold."""
        truncator = ContextTruncator()
        # Create 4 messages = 2 turns
        messages = self.create_messages(4)
        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=5, drop_turns=1
        )

        # No truncation should happen
        assert len(result) == 4
        assert result == messages

    def test_truncate_by_turns_exact_threshold(self) -> None:
        """Test truncate_by_turns when messages exactly match threshold."""
        truncator = ContextTruncator()
        # Create 6 messages = 3 turns
        messages = self.create_messages(6)
        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=3, drop_turns=1
        )

        # No truncation should happen
        assert len(result) == 6
        assert result == messages

    def test_truncate_by_turns_ensures_user_first(self) -> None:
        """Test that truncate_by_turns ensures user message comes first."""
        truncator = ContextTruncator()
        # Create scenario where truncation might start with assistant
        messages = self.create_messages(20)
        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=3, drop_turns=1
        )

        # There is no system message here, so the very first message must be user.
        assert result, "truncation should keep the most recent turns"
        assert result[0].role == "user"

    def test_truncate_by_turns_multiple_drop(self) -> None:
        """Test truncate_by_turns with multiple turns dropped at once."""
        truncator = ContextTruncator()
        messages = self.create_messages(20)
        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=5, drop_turns=3
        )

        # Should drop 3 turns when limit exceeded
        assert len(result) < len(messages)

    # ==================== truncate_by_dropping_oldest_turns Tests ====================

    def test_truncate_by_dropping_oldest_turns_zero(self) -> None:
        """Test truncate_by_dropping_oldest_turns with drop_turns=0."""
        truncator = ContextTruncator()
        messages = self.create_messages(10)
        result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=0)
        assert result == messages

    def test_truncate_by_dropping_oldest_turns_negative(self) -> None:
        """Test truncate_by_dropping_oldest_turns with negative drop_turns."""
        truncator = ContextTruncator()
        messages = self.create_messages(10)
        result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=-1)
        assert result == messages

    def test_truncate_by_dropping_oldest_turns_basic(self) -> None:
        """Test basic truncate_by_dropping_oldest_turns functionality."""
        truncator = ContextTruncator()
        # Create 10 messages = 5 turns
        messages = self.create_messages(10)
        result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)

        # Should drop 2 oldest turns (4 messages)
        assert len(result) == 6
        # Should start with user message
        assert result[0].role == "user"

    def test_truncate_by_dropping_oldest_turns_with_system(self) -> None:
        """Test truncate_by_dropping_oldest_turns preserves system messages."""
        truncator = ContextTruncator()
        messages = self.create_messages(10, include_system=True)
        result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)

        # System message should be preserved
        assert result[0].role == "system"
        assert result[0].content == "System prompt"

    def test_truncate_by_dropping_oldest_turns_drop_all(self) -> None:
        """Test truncate_by_dropping_oldest_turns dropping all turns."""
        truncator = ContextTruncator()
        # Create 4 messages = 2 turns
        messages = self.create_messages(4)
        result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)

        # Should drop all turns
        assert len(result) == 0

    def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self) -> None:
        """Test truncate_by_dropping_oldest_turns with drop_turns > available turns."""
        truncator = ContextTruncator()
        # Create 4 messages = 2 turns
        messages = self.create_messages(4)
        result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=5)

        # Should result in empty list
        assert len(result) == 0

    def test_truncate_by_dropping_oldest_turns_ensures_user_first(self) -> None:
        """Test that result starts with user message after dropping."""
        truncator = ContextTruncator()
        # 20 messages = 10 turns; dropping 3 must leave turns behind.
        messages = self.create_messages(20)
        result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=3)

        # Assert non-emptiness explicitly: a conditional check here would let the
        # test pass vacuously if truncation wrongly discarded everything.
        assert result, "dropping 3 of 10 turns should leave messages behind"
        # First message should be user
        assert result[0].role == "user"

    # ==================== truncate_by_halving Tests ====================

    def test_truncate_by_halving_empty(self) -> None:
        """Test truncate_by_halving with empty list."""
        truncator = ContextTruncator()
        result = truncator.truncate_by_halving([])
        assert result == []

    def test_truncate_by_halving_single_message(self) -> None:
        """Test truncate_by_halving with single message."""
        truncator = ContextTruncator()
        messages = [self.create_message("user", "Hello")]
        result = truncator.truncate_by_halving(messages)
        # Should not truncate if <= 2 messages
        assert result == messages

    def test_truncate_by_halving_two_messages(self) -> None:
        """Test truncate_by_halving with two messages."""
        truncator = ContextTruncator()
        messages = self.create_messages(2)
        result = truncator.truncate_by_halving(messages)
        # Should not truncate if <= 2 messages
        assert result == messages

    def test_truncate_by_halving_basic(self) -> None:
        """Test basic truncate_by_halving functionality."""
        truncator = ContextTruncator()
        # Create 20 messages
        messages = self.create_messages(20)
        result = truncator.truncate_by_halving(messages)

        # Should delete 50% = 10 messages, keep 10
        assert len(result) == 10
        # First message should be user
        assert result[0].role == "user"

    def test_truncate_by_halving_with_system_message(self) -> None:
        """Test truncate_by_halving preserves system messages."""
        truncator = ContextTruncator()
        messages = self.create_messages(20, include_system=True)
        result = truncator.truncate_by_halving(messages)

        # System message should be preserved
        assert result[0].role == "system"
        assert result[0].content == "System prompt"

    def test_truncate_by_halving_odd_count(self) -> None:
        """Test truncate_by_halving with odd number of messages."""
        truncator = ContextTruncator()
        messages = self.create_messages(11)
        result = truncator.truncate_by_halving(messages)

        # Deleting floor(11/2) = 5 messages keeps 6, but that window starts with
        # an assistant message; realigning to a user message may leave 5.
        assert len(result) >= 5
        assert result[0].role == "user"

    def test_truncate_by_halving_ensures_user_first(self) -> None:
        """Test that result starts with user message."""
        truncator = ContextTruncator()
        # Create messages starting with user
        messages = self.create_messages(30)
        result = truncator.truncate_by_halving(messages)

        # First message should be user
        assert result[0].role == "user"

    def test_truncate_by_halving_preserves_recent_messages(self) -> None:
        """Test that truncate_by_halving keeps the most recent 50%."""
        truncator = ContextTruncator()
        messages = [
            self.create_message("user", "Message 0"),
            self.create_message("assistant", "Message 1"),
            self.create_message("user", "Message 2"),
            self.create_message("assistant", "Message 3"),
        ]
        result = truncator.truncate_by_halving(messages)

        # Should keep last 2 messages
        assert len(result) == 2
        assert result[0].content == "Message 2"
        assert result[1].content == "Message 3"

    # ==================== Integration Tests ====================

    def test_truncate_with_tool_messages(self) -> None:
        """Test truncation with tool messages."""
        truncator = ContextTruncator()
        messages = [
            self.create_message("user", "Run tool"),
            self.create_message("assistant", "Running..."),
            self.create_message("tool", "Tool result"),
            self.create_message("user", "Thanks"),
            self.create_message("assistant", "Welcome"),
        ]

        result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=1)

        # First turn (user+assistant+tool) should be dropped
        # Tool message should be cleaned up by fix_messages
        assert len(result) <= 2

    def test_chain_multiple_truncations(self) -> None:
        """Test chaining multiple truncation methods."""
        truncator = ContextTruncator()
        messages = self.create_messages(40, include_system=True)

        # First: truncate by turns
        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=10, drop_turns=2
        )
        # Then: halve
        result = truncator.truncate_by_halving(result)

        # Should have system message + truncated content
        assert result[0].role == "system"
        assert len(result) < len(messages)

    def test_empty_after_system_message(self) -> None:
        """Test truncation when only system message exists."""
        truncator = ContextTruncator()
        messages = [self.create_message("system", "System prompt")]

        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=5, drop_turns=1
        )

        # Should keep system message
        assert len(result) == 1
        assert result[0].role == "system"

    def test_all_system_messages(self) -> None:
        """Test truncation with only system messages."""
        truncator = ContextTruncator()
        messages = [
            self.create_message("system", "System 1"),
            self.create_message("system", "System 2"),
        ]

        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=0, drop_turns=1
        )

        # With no non-system messages and keep_most_recent_turns=0 the truncator
        # may either keep the system messages or clear everything. Whatever
        # survives must be a system message; all() is vacuously true for [].
        assert all(msg.role == "system" for msg in result)
|