Add comprehensive tests for ContextManager and ContextTruncator

- Implemented a full test suite for ContextManager covering initialization, message processing, token-based compression, and error handling.
- Added tests for ContextTruncator focusing on message fixing, truncation by turns, dropping oldest turns, and halving.
- Ensured that both test suites validate edge cases and maintain expected behavior with various message types, including system and tool messages.
This commit is contained in:
Soulter
2026-01-05 10:48:00 +08:00
parent d842155770
commit 1ed4d9f484
8 changed files with 6786 additions and 118 deletions
+57 -33
View File
@@ -1,24 +1,31 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from astrbot.api import logger
from typing import TYPE_CHECKING, Protocol, runtime_checkable
from ..message import Message
if TYPE_CHECKING:
from astrbot import logger
else:
try:
from astrbot import logger
except ImportError:
import logging
logger = logging.getLogger("astrbot")
if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider
from ..context.truncator import ContextTruncator
class ContextCompressor(ABC):
@runtime_checkable
class ContextCompressor(Protocol):
"""
Abstract base class for context compressors.
Protocol for context compressors.
Provides an interface for compressing message lists.
"""
@abstractmethod
async def compress(self, messages: list[Message]) -> list[Message]:
async def __call__(self, messages: list[Message]) -> list[Message]:
"""Compress the message list.
Args:
@@ -27,19 +34,10 @@ class ContextCompressor(ABC):
Returns:
The compressed message list.
"""
pass
...
class DefaultCompressor(ContextCompressor):
"""Default compressor implementation.
Returns the original messages.
"""
async def compress(self, messages: list[Message]) -> list[Message]:
return messages
class TruncateByTurnsCompressor(ContextCompressor):
class TruncateByTurnsCompressor:
"""Truncate by turns compressor implementation.
Truncates the message list by removing older turns.
"""
@@ -52,17 +50,47 @@ class TruncateByTurnsCompressor(ContextCompressor):
"""
self.truncate_turns = truncate_turns
async def compress(self, messages: list[Message]) -> list[Message]:
async def __call__(self, messages: list[Message]) -> list[Message]:
truncator = ContextTruncator()
truncated_messages = truncator.truncate_by_turns(
truncated_messages = truncator.truncate_by_dropping_oldest_turns(
messages,
keep_most_recent_turns=0,
dequeue_turns=self.truncate_turns,
drop_turns=self.truncate_turns,
)
return truncated_messages
class LLMSummaryCompressor(ContextCompressor):
def split_history(
messages: list[Message], keep_recent: int
) -> tuple[list[Message], list[Message], list[Message]]:
"""Split the message list into system messages, messages to summarize, and recent messages.
Args:
messages: The original message list.
keep_recent: The number of latest messages to keep.
Returns:
tuple: (system_messages, messages_to_summarize, recent_messages)
"""
# keep the system messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) <= keep_recent:
return system_messages, [], non_system_messages
messages_to_summarize = non_system_messages[:-keep_recent]
recent_messages = non_system_messages[-keep_recent:]
return system_messages, messages_to_summarize, recent_messages
class LLMSummaryCompressor:
"""LLM-based summary compressor.
Uses LLM to summarize the old conversation history, keeping the latest messages.
"""
@@ -90,7 +118,7 @@ class LLMSummaryCompressor(ContextCompressor):
"4. Write the summary in the user's language.\n"
)
async def compress(self, messages: list[Message]) -> list[Message]:
async def __call__(self, messages: list[Message]) -> list[Message]:
"""Use LLM to generate a summary of the conversation history.
Process:
@@ -101,12 +129,9 @@ class LLMSummaryCompressor(ContextCompressor):
if len(messages) <= self.keep_recent + 1:
return messages
# keep the system message
system_msg = messages[0] if messages and messages[0].role == "system" else None
start_idx = 1 if system_msg else 0
messages_to_summarize = messages[start_idx : -self.keep_recent]
recent_messages = messages[-self.keep_recent :]
system_messages, messages_to_summarize, recent_messages = split_history(
messages, self.keep_recent
)
if not messages_to_summarize:
return messages
@@ -125,8 +150,7 @@ class LLMSummaryCompressor(ContextCompressor):
# build result
result = []
if system_msg:
result.append(system_msg)
result.extend(system_messages)
result.append(
Message(
+28
View File
@@ -0,0 +1,28 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider
@dataclass
class ContextConfig:
"""Context configuration class."""
max_context_tokens: int = 0
"""Maximum number of context tokens. <= 0 means no limit."""
enforce_max_turns: int = -1 # -1 means no limit
"""Maximum number of conversation turns to keep. -1 means no limit. Executed before compression."""
truncate_turns: int = 1
"""Number of conversation turns to discard at once when truncation is triggered.
Two processes will use this value:
1. Enforce max turns truncation.
2. Truncation by turns compression strategy.
"""
llm_compress_instruction: str | None = None
"""Instruction prompt for LLM-based compression."""
llm_compress_keep_recent: int = 0
"""Number of recent messages to keep during LLM-based compression."""
llm_compress_provider: "Provider | None" = None
"""LLM provider used for compression tasks. If None, truncation strategy is used."""
+42 -37
View File
@@ -1,14 +1,10 @@
from typing import TYPE_CHECKING
from astrbot import logger
from ..message import Message
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
from .token_counter import TokenCounter
from .truncator import ContextTruncator
if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider
from .config import ContextConfig
class ContextManager:
@@ -19,11 +15,7 @@ class ContextManager:
def __init__(
self,
max_context_tokens: int = 0,
truncate_turns: int = 1,
llm_compress_instruction: str | None = None,
llm_compress_keep_recent: int = 4,
llm_compress_provider: "Provider | None" = None,
config: ContextConfig,
):
"""Initialize the context manager.
@@ -32,26 +24,23 @@ class ContextManager:
2. LLM-based compression: use LLM to summarize old messages.
Args:
max_context_tokens: The maximum context tokens. <= 0 means no limit.
truncate_turns: For turncate strategy. The number of turns to discard when truncating.
llm_compress_instruction: The instruction text for LLM compression.
llm_compress_keep_recent: The number of recent messages to keep during LLM compression.
llm_compress_provider: The LLM provider for compression.
config: The context configuration.
"""
self.max_context_tokens = max_context_tokens
self.truncate_turns = truncate_turns
self.config = config
self.token_counter = TokenCounter()
self.truncator = ContextTruncator()
if llm_compress_provider:
if config.llm_compress_provider:
self.compressor = LLMSummaryCompressor(
provider=llm_compress_provider,
keep_recent=llm_compress_keep_recent,
instruction_text=llm_compress_instruction,
provider=config.llm_compress_provider,
keep_recent=config.llm_compress_keep_recent,
instruction_text=config.llm_compress_instruction,
)
else:
self.compressor = TruncateByTurnsCompressor(truncate_turns=truncate_turns)
self.compressor = TruncateByTurnsCompressor(
truncate_turns=config.truncate_turns
)
async def process(self, messages: list[Message]) -> list[Message]:
"""Process the messages.
@@ -62,17 +51,30 @@ class ContextManager:
Returns:
The processed message list.
"""
if self.max_context_tokens <= 0:
try:
result = messages
# 1. 基于轮次的截断 (Enforce max turns)
if self.config.enforce_max_turns != -1:
result = self.truncator.truncate_by_turns(
result,
keep_most_recent_turns=self.config.enforce_max_turns,
drop_turns=self.config.truncate_turns,
)
# 2. 基于 token 的压缩
if self.config.max_context_tokens > 0:
# check if the messages need to be compressed
needs_compression, _ = await self._initial_token_check(result)
# compress/truncate the messages if needed
result = await self._run_compression(result, needs_compression)
return result
except Exception as e:
logger.error(f"Error during context processing: {e}", exc_info=True)
return messages
# check if the messages need to be compressed
needs_compression, _ = await self._initial_token_check(messages)
# compress/truncate the messages if needed
messages = await self._run_compression(messages, needs_compression)
return messages
async def _initial_token_check(
self, messages: list[Message]
) -> tuple[bool, int | None]:
@@ -87,15 +89,15 @@ class ContextManager:
"""
if not messages:
return False, None
if self.max_context_tokens <= 0:
if self.config.max_context_tokens <= 0:
return False, None
total_tokens = self.token_counter.count_tokens(messages)
logger.debug(
f"ContextManager: total tokens = {total_tokens}, max_context_tokens = {self.max_context_tokens}"
f"ContextManager: total tokens = {total_tokens}, max_context_tokens = {self.config.max_context_tokens}"
)
usage_rate = total_tokens / self.max_context_tokens
usage_rate = total_tokens / self.config.max_context_tokens
needs_compression = usage_rate > self.COMPRESSION_THRESHOLD
return needs_compression, total_tokens if needs_compression else None
@@ -115,14 +117,17 @@ class ContextManager:
"""
if not needs_compression:
return messages
if self.max_context_tokens <= 0:
if self.config.max_context_tokens <= 0:
return messages
messages = await self.compressor.compress(messages)
messages = await self.compressor(messages)
# double check
tokens_after_summary = self.token_counter.count_tokens(messages)
if tokens_after_summary / self.max_context_tokens > self.COMPRESSION_THRESHOLD:
if (
tokens_after_summary / self.config.max_context_tokens
> self.COMPRESSION_THRESHOLD
):
# still over 82%, truncate by half
messages = self._compress_by_halving(messages)
+68 -21
View File
@@ -23,7 +23,7 @@ class ContextTruncator:
self,
messages: list[Message],
keep_most_recent_turns: int,
dequeue_turns: int = 1,
drop_turns: int = 1,
) -> list[Message]:
"""截断上下文列表,确保不超过最大长度。
一个 turn 包含一个 user 消息和一个 assistant 消息。
@@ -32,27 +32,33 @@ class ContextTruncator:
Args:
messages: 上下文列表
keep_most_recent_turns: 保留最近的对话轮数
dequeue_turns: 一次性丢弃的对话轮数
drop_turns: 一次性丢弃的对话轮数
Returns:
截断后的上下文列表
"""
if keep_most_recent_turns == -1:
return messages
if len(messages) <= keep_most_recent_turns:
return messages
if len(messages) // 2 <= keep_most_recent_turns:
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) // 2 <= keep_most_recent_turns:
return messages
system_message = None
if messages[0].role == "system":
system_message = messages[0]
messages = messages[1:]
num_to_keep = keep_most_recent_turns - drop_turns + 1
if num_to_keep <= 0:
truncated_contexts = []
else:
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
truncated_contexts = messages[
-(keep_most_recent_turns - dequeue_turns + 1) * 2 :
]
# 找到第一个role 为 user 的索引,确保上下文格式正确
# 找到第一个 role 为 user 的索引,确保上下文格式正确
index = next(
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
None,
@@ -60,10 +66,45 @@ class ContextTruncator:
if index is not None and index > 0:
truncated_contexts = truncated_contexts[index:]
if system_message is not None:
truncated_contexts = [system_message] + truncated_contexts
result = system_messages + truncated_contexts
return self.fix_messages(truncated_contexts)
return self.fix_messages(result)
def truncate_by_dropping_oldest_turns(
self,
messages: list[Message],
drop_turns: int = 1,
) -> list[Message]:
"""丢弃最旧的 N 个对话轮次。"""
if drop_turns <= 0:
return messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) // 2 <= drop_turns:
truncated_non_system = []
else:
truncated_non_system = non_system_messages[drop_turns * 2 :]
index = next(
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
)
if index is not None:
truncated_non_system = truncated_non_system[index:]
elif truncated_non_system:
truncated_non_system = []
result = system_messages + truncated_non_system
return self.fix_messages(result)
def truncate_by_halving(
self,
@@ -79,16 +120,22 @@ class ContextTruncator:
first_non_system = i
break
messages_to_delete = (len(messages) - first_non_system) // 2
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
result = messages[:first_non_system]
result.extend(messages[first_non_system + messages_to_delete :])
messages_to_delete = len(non_system_messages) // 2
if messages_to_delete == 0:
return messages
truncated_non_system = non_system_messages[messages_to_delete:]
index = next(
(i for i, item in enumerate(result) if item.role == "user"),
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
)
if index is not None:
result = result[index:]
truncated_non_system = truncated_non_system[index:]
result = system_messages + truncated_non_system
return self.fix_messages(result)
@@ -26,7 +26,7 @@ from astrbot.core.provider.entities import (
from astrbot.core.provider.provider import Provider
from ..context.manager import ContextManager
from ..context.truncator import ContextTruncator
from ..context.config import ContextConfig
from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..response import AgentResponseData, AgentStats
@@ -70,15 +70,17 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# we will do compress when:
# 1. before requesting LLM
# TODO: 2. after LLM output a tool call
self.context_manager = ContextManager(
# <=0 will never trigger context compression
self.context_config = ContextConfig(
# <=0 will never do compress
max_context_tokens=provider.provider_config.get("max_context_tokens", 0),
# enforce max turns before compression
enforce_max_turns=self.enforce_max_turns,
truncate_turns=self.truncate_turns,
llm_compress_instruction=self.llm_compress_instruction,
llm_compress_keep_recent=self.llm_compress_keep_recent,
llm_compress_provider=self.llm_compress_provider,
)
self.context_truncator = ContextTruncator()
self.context_manager = ContextManager(self.context_config)
self.provider = provider
self.final_llm_resp = None
@@ -121,12 +123,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
else:
yield await self.provider.text_chat(**payload)
async def do_context_compress(self):
"""检查并执行上下文压缩。"""
original_messages = self.run_context.messages
compressed_messages = await self.context_manager.process(original_messages)
self.run_context.messages = compressed_messages
@override
async def step(self):
"""Process a single step of the agent.
@@ -145,23 +141,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self._transition_state(AgentState.RUNNING)
llm_resp_result = None
# do truncate
if self.enforce_max_turns != -1:
try:
truncated_messages = self.context_truncator.truncate_by_turns(
self.run_context.messages,
keep_most_recent_turns=self.enforce_max_turns,
dequeue_turns=self.truncate_turns,
)
self.run_context.messages = truncated_messages
except Exception as e:
logger.error(f"Error during context truncation: {e}", exc_info=True)
# check compress
try:
await self.do_context_compress()
except Exception as e:
logger.error(f"Error during context compression: {e}", exc_info=True)
# do truncate and compress
self.run_context.messages = await self.context_manager.process(
self.run_context.messages
)
async for llm_response in self._iter_llm_responses():
if llm_response.is_chunk:
+5525
View File
File diff suppressed because it is too large Load Diff
+633
View File
@@ -0,0 +1,633 @@
"""Comprehensive tests for ContextManager."""
import sys
from pathlib import Path
from typing import Literal
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Add parent directory to path to avoid circular import issues
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from astrbot.core.agent.context.config import ContextConfig
from astrbot.core.agent.context.manager import ContextManager
from astrbot.core.agent.message import Message, TextPart
class TestContextManager:
"""Test suite for ContextManager."""
def create_message(
self, role: Literal["system", "user", "assistant", "tool"], content: str
) -> Message:
"""Helper to create a simple text message."""
return Message(role=role, content=content)
def create_messages(self, count: int) -> list[Message]:
"""Helper to create alternating user/assistant messages."""
messages = []
for i in range(count):
role = "user" if i % 2 == 0 else "assistant"
messages.append(self.create_message(role, f"Message {i}"))
return messages
# ==================== Basic Initialization Tests ====================
def test_init_with_minimal_config(self):
"""Test initialization with minimal configuration."""
config = ContextConfig()
manager = ContextManager(config)
assert manager.config == config
assert manager.token_counter is not None
assert manager.truncator is not None
assert manager.compressor is not None
def test_init_with_llm_compressor(self):
"""Test initialization with LLM-based compression."""
mock_provider = MagicMock()
config = ContextConfig(
llm_compress_provider=mock_provider,
llm_compress_keep_recent=5,
llm_compress_instruction="Summarize the conversation",
)
manager = ContextManager(config)
from astrbot.core.agent.context.compressor import LLMSummaryCompressor
assert isinstance(manager.compressor, LLMSummaryCompressor)
def test_init_with_truncate_compressor(self):
"""Test initialization with truncate-based compression (default)."""
config = ContextConfig(truncate_turns=3)
manager = ContextManager(config)
from astrbot.core.agent.context.compressor import TruncateByTurnsCompressor
assert isinstance(manager.compressor, TruncateByTurnsCompressor)
# ==================== Empty and Edge Cases ====================
@pytest.mark.asyncio
async def test_process_empty_messages(self):
"""Test processing an empty message list."""
config = ContextConfig()
manager = ContextManager(config)
result = await manager.process([])
assert result == []
@pytest.mark.asyncio
async def test_process_single_message(self):
"""Test processing a single message."""
config = ContextConfig()
manager = ContextManager(config)
messages = [self.create_message("user", "Hello")]
result = await manager.process(messages)
assert len(result) == 1
assert result[0].content == "Hello"
@pytest.mark.asyncio
async def test_process_with_no_limits(self):
"""Test processing when no limits are set (no truncation or compression)."""
config = ContextConfig(max_context_tokens=0, enforce_max_turns=-1)
manager = ContextManager(config)
messages = self.create_messages(20)
result = await manager.process(messages)
assert len(result) == 20
assert result == messages
# ==================== Enforce Max Turns Tests ====================
@pytest.mark.asyncio
async def test_enforce_max_turns_basic(self):
"""Test basic enforce_max_turns functionality."""
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
manager = ContextManager(config)
# Create 10 turns (20 messages)
messages = self.create_messages(20)
result = await manager.process(messages)
# Should keep only 3 most recent turns (6 messages)
assert len(result) <= 8 # May vary due to truncation logic
@pytest.mark.asyncio
async def test_enforce_max_turns_zero(self):
"""Test enforce_max_turns with value 0 (should keep nothing)."""
config = ContextConfig(enforce_max_turns=0, truncate_turns=1)
manager = ContextManager(config)
messages = self.create_messages(10)
result = await manager.process(messages)
# Should result in empty or minimal message list
assert len(result) <= 2
@pytest.mark.asyncio
async def test_enforce_max_turns_negative(self):
"""Test enforce_max_turns with -1 (no limit)."""
config = ContextConfig(enforce_max_turns=-1)
manager = ContextManager(config)
messages = self.create_messages(20)
result = await manager.process(messages)
assert len(result) == 20
@pytest.mark.asyncio
async def test_enforce_max_turns_with_system_messages(self):
"""Test enforce_max_turns preserves system messages."""
config = ContextConfig(enforce_max_turns=2, truncate_turns=1)
manager = ContextManager(config)
messages = [
self.create_message("system", "System instruction"),
*self.create_messages(10),
]
result = await manager.process(messages)
# System message should be preserved
system_msgs = [m for m in result if m.role == "system"]
assert len(system_msgs) >= 1
assert system_msgs[0].content == "System instruction"
# ==================== Token-based Compression Tests ====================
@pytest.mark.asyncio
async def test_token_compression_not_triggered_below_threshold(self):
"""Test that compression is not triggered below threshold."""
config = ContextConfig(max_context_tokens=1000)
manager = ContextManager(config)
# Create messages that total less than threshold
messages = [self.create_message("user", "Hi" * 50)] # ~100 tokens
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager.process(messages)
# Compressor should not be called
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_token_compression_triggered_above_threshold(self):
"""Test that compression is triggered above threshold."""
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
manager = ContextManager(config)
# Create messages that exceed threshold (0.82 * 100 = 82 tokens)
# 300 chars * 0.3 = 90 tokens > 82 threshold
long_text = "x" * 300 # ~90 tokens, above threshold
messages = [self.create_message("user", long_text)]
# Mock compressor to return smaller result
compressed = [self.create_message("user", "short")]
# Create a mock compressor that we can track
mock_compress = AsyncMock(return_value=compressed)
manager.compressor = mock_compress
result = await manager.process(messages)
# Compressor should be called
mock_compress.assert_called_once()
# Result should be the compressed version
assert len(result) == len(compressed)
@pytest.mark.asyncio
async def test_token_compression_with_zero_max_tokens(self):
"""Test that compression is skipped when max_context_tokens is 0."""
config = ContextConfig(max_context_tokens=0)
manager = ContextManager(config)
messages = [self.create_message("user", "x" * 10000)]
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager.process(messages)
# Compressor should not be called
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_token_compression_with_negative_max_tokens(self):
"""Test that compression is skipped when max_context_tokens is negative."""
config = ContextConfig(max_context_tokens=-100)
manager = ContextManager(config)
messages = [self.create_message("user", "x" * 10000)]
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager.process(messages)
# Compressor should not be called
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_double_check_after_compression(self):
"""Test that halving is applied if still over threshold after compression."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Create messages that would still be over threshold after compression
long_messages = [self.create_message("user", "x" * 200) for _ in range(10)]
# Mock compressor to return messages still over threshold
async def mock_compress(msgs):
return msgs # Return same messages (still over limit)
with patch.object(manager.compressor, "__call__", new=mock_compress):
with patch.object(
manager, "_compress_by_halving", return_value=long_messages[:5]
) as mock_halving:
_ = await manager.process(long_messages)
# Halving should be called
mock_halving.assert_called_once()
# ==================== Combined Truncation and Compression Tests ====================
@pytest.mark.asyncio
async def test_combined_enforce_turns_and_token_limit(self):
"""Test combining enforce_max_turns and token limit."""
config = ContextConfig(
enforce_max_turns=5, max_context_tokens=500, truncate_turns=1
)
manager = ContextManager(config)
# Create many messages
messages = self.create_messages(30)
result = await manager.process(messages)
# Should be truncated by both mechanisms
assert len(result) < 30
@pytest.mark.asyncio
async def test_sequential_processing_order(self):
"""Test that enforce_max_turns happens before token compression."""
config = ContextConfig(enforce_max_turns=5, max_context_tokens=1000)
manager = ContextManager(config)
messages = self.create_messages(20)
# Mock the truncator to track calls
with patch.object(
manager.truncator,
"truncate_by_turns",
wraps=manager.truncator.truncate_by_turns,
) as mock_truncate:
await manager.process(messages)
# Truncator should be called first
mock_truncate.assert_called_once()
# ==================== Error Handling Tests ====================
@pytest.mark.asyncio
async def test_error_handling_returns_original_messages(self):
"""Test that errors during processing return original messages."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
messages = self.create_messages(5)
# Make compressor raise an exception
with patch.object(
manager.compressor, "__call__", side_effect=Exception("Test error")
):
result = await manager.process(messages)
# Should return original messages despite error
assert result == messages
@pytest.mark.asyncio
async def test_error_handling_logs_exception(self):
"""Test that errors are logged."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Create messages that will trigger compression (> 82 tokens)
messages = [self.create_message("user", "x" * 300)] # ~90 tokens
# Replace compressor with one that raises an exception
manager.compressor = AsyncMock(side_effect=Exception("Test error"))
with patch("astrbot.core.agent.context.manager.logger") as mock_logger:
result = await manager.process(messages)
# Logger error method should be called
assert mock_logger.error.called
# Should return original messages on error
assert result == messages
# ==================== Multi-modal Content Tests ====================
@pytest.mark.asyncio
async def test_process_messages_with_textpart_content(self):
"""Test processing messages with TextPart content."""
config = ContextConfig()
manager = ContextManager(config)
messages = [
Message(role="user", content=[TextPart(text="Hello")]),
Message(role="assistant", content=[TextPart(text="Hi there")]),
]
result = await manager.process(messages)
assert len(result) == 2
assert result == messages
@pytest.mark.asyncio
async def test_token_counting_with_multimodal_content(self):
"""Test token counting works with multi-modal content."""
config = ContextConfig(max_context_tokens=50)
manager = ContextManager(config)
# Need enough tokens to exceed threshold: 50 * 0.82 = 41 tokens
# 150 chars * 0.3 = 45 tokens > 41
messages = [
Message(role="user", content=[TextPart(text="x" * 150)]),
]
# Should trigger compression due to token count
needs_compression, tokens = await manager._initial_token_check(messages)
assert tokens is not None # Tokens should be counted
assert needs_compression # Should trigger compression
# ==================== Tool Calls Tests ====================
@pytest.mark.asyncio
async def test_process_messages_with_tool_calls(self):
"""Test processing messages with tool calls."""
config = ContextConfig()
manager = ContextManager(config)
messages = [
Message(
role="assistant",
content="Let me search for that",
tool_calls=[
{
"id": "call_1",
"type": "function",
"function": {"name": "search", "arguments": "{}"},
}
],
),
Message(role="tool", content="Search result", tool_call_id="call_1"),
]
result = await manager.process(messages)
assert len(result) == 2
# ==================== Initial Token Check Tests ====================
@pytest.mark.asyncio
async def test_initial_token_check_empty_messages(self):
"""Test _initial_token_check with empty messages."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
needs_compression, tokens = await manager._initial_token_check([])
assert not needs_compression
assert tokens is None
@pytest.mark.asyncio
async def test_initial_token_check_no_limit(self):
"""Test _initial_token_check when max_context_tokens is 0."""
config = ContextConfig(max_context_tokens=0)
manager = ContextManager(config)
messages = [self.create_message("user", "x" * 1000)]
needs_compression, tokens = await manager._initial_token_check(messages)
assert not needs_compression
assert tokens is None
@pytest.mark.asyncio
async def test_initial_token_check_below_threshold(self):
"""Test _initial_token_check when below compression threshold."""
config = ContextConfig(max_context_tokens=1000)
manager = ContextManager(config)
messages = [self.create_message("user", "Hello")]
needs_compression, tokens = await manager._initial_token_check(messages)
assert not needs_compression
assert tokens is None
@pytest.mark.asyncio
async def test_initial_token_check_above_threshold(self):
"""Test _initial_token_check when above compression threshold."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Create message with ~90 tokens (above 0.82 * 100 = 82)
messages = [self.create_message("user", "这是测试" * 50)]
needs_compression, tokens = await manager._initial_token_check(messages)
assert needs_compression
assert tokens is not None
assert tokens > 82
@pytest.mark.asyncio
async def test_initial_token_check_exactly_at_threshold(self):
"""Test _initial_token_check when just above threshold."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Create message with >82 tokens (0.82 * 100)
# 300 chars * 0.3 = 90 tokens > 82 (threshold)
messages = [self.create_message("user", "x" * 300)] # ~90 tokens
needs_compression, tokens = await manager._initial_token_check(messages)
# Above threshold should trigger compression
assert tokens is not None
assert needs_compression
# ==================== Compression by Halving Tests ====================
def test_compress_by_halving_basic(self):
"""Test _compress_by_halving removes middle 50%."""
config = ContextConfig()
manager = ContextManager(config)
messages = self.create_messages(10)
result = manager._compress_by_halving(messages)
# Should keep roughly half
assert len(result) < len(messages)
def test_compress_by_halving_empty_list(self):
"""Test _compress_by_halving with empty list."""
config = ContextConfig()
manager = ContextManager(config)
result = manager._compress_by_halving([])
assert result == []
def test_compress_by_halving_single_message(self):
"""Test _compress_by_halving with single message."""
config = ContextConfig()
manager = ContextManager(config)
messages = [self.create_message("user", "Hello")]
result = manager._compress_by_halving(messages)
assert len(result) <= 1
# ==================== Complex Scenarios ====================
@pytest.mark.asyncio
async def test_multiple_compression_cycles(self):
"""Test that compression can be triggered multiple times in sequence."""
config = ContextConfig(max_context_tokens=50, truncate_turns=1)
manager = ContextManager(config)
# Process messages multiple times
messages = self.create_messages(10)
result1 = await manager.process(messages)
result2 = await manager.process(result1)
result3 = await manager.process(result2)
# Each cycle should maintain or reduce message count
assert len(result3) <= len(result2) <= len(result1)
@pytest.mark.asyncio
async def test_alternating_roles_preserved(self):
"""Test that user/assistant alternation is preserved after processing."""
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
manager = ContextManager(config)
messages = self.create_messages(20)
result = await manager.process(messages)
# Check that roles still alternate (excluding system messages)
non_system = [m for m in result if m.role != "system"]
if len(non_system) >= 2:
# Should start with user
assert non_system[0].role == "user"
@pytest.mark.asyncio
async def test_compression_threshold_constant(self):
"""Test that COMPRESSION_THRESHOLD is used correctly."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Verify the threshold is 0.82
assert manager.COMPRESSION_THRESHOLD == 0.82
# Create messages just below threshold
messages = [self.create_message("user", "x" * 81)] # ~24 tokens
needs_compression, _ = await manager._initial_token_check(messages)
assert not needs_compression
@pytest.mark.asyncio
async def test_large_batch_processing(self):
"""Test processing a large batch of messages."""
config = ContextConfig(
enforce_max_turns=10, max_context_tokens=1000, truncate_turns=2
)
manager = ContextManager(config)
# Create 100 messages (50 turns)
messages = self.create_messages(100)
result = await manager.process(messages)
# Should be significantly reduced
assert len(result) < 100
assert len(result) > 0
@pytest.mark.asyncio
async def test_config_persistence(self):
"""Test that config settings are respected throughout processing."""
config = ContextConfig(
max_context_tokens=500,
enforce_max_turns=5,
truncate_turns=2,
llm_compress_keep_recent=3,
)
manager = ContextManager(config)
# Verify config is stored
assert manager.config.max_context_tokens == 500
assert manager.config.enforce_max_turns == 5
assert manager.config.truncate_turns == 2
assert manager.config.llm_compress_keep_recent == 3
# ==================== Run Compression Tests ====================
@pytest.mark.asyncio
async def test_run_compression_skip_when_not_needed(self):
"""Test _run_compression skips when needs_compression is False."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
messages = self.create_messages(5)
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager._run_compression(messages, needs_compression=False)
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_run_compression_skip_when_zero_limit(self):
"""Test _run_compression skips when max_context_tokens is 0."""
config = ContextConfig(max_context_tokens=0)
manager = ContextManager(config)
messages = self.create_messages(5)
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager._run_compression(messages, needs_compression=True)
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_run_compression_applies_compressor(self):
"""Test _run_compression calls compressor when needed through process()."""
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
manager = ContextManager(config)
# Create messages that will trigger compression
messages = [self.create_message("user", "x" * 300)] # ~90 tokens > 82 threshold
compressed = [self.create_message("user", "short")] # Much smaller
# Replace compressor with mock
mock_compress = AsyncMock(return_value=compressed)
manager.compressor = mock_compress
result = await manager.process(messages)
# Compressor should have been called
mock_compress.assert_called_once()
assert len(result) <= len(messages)
+423
View File
@@ -0,0 +1,423 @@
"""Tests for ContextTruncator."""
from astrbot.core.agent.context.truncator import ContextTruncator
from astrbot.core.agent.message import Message
class TestContextTruncator:
"""Test suite for ContextTruncator."""
def create_message(self, role: str, content: str = "test content") -> Message:
"""Helper to create a simple test message."""
return Message(role=role, content=content)
def create_messages(
self, count: int, include_system: bool = False
) -> list[Message]:
"""Helper to create alternating user/assistant messages.
Args:
count: Number of messages to create
include_system: Whether to include a system message at the start
Returns:
List of messages
"""
messages = []
if include_system:
messages.append(self.create_message("system", "System prompt"))
for i in range(count):
role = "user" if i % 2 == 0 else "assistant"
messages.append(self.create_message(role, f"Message {i}"))
return messages
# ==================== fix_messages Tests ====================
def test_fix_messages_empty_list(self):
"""Test fix_messages with an empty list."""
truncator = ContextTruncator()
result = truncator.fix_messages([])
assert result == []
def test_fix_messages_normal_messages(self):
"""Test fix_messages with normal user/assistant messages."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Hello"),
self.create_message("assistant", "Hi"),
self.create_message("user", "How are you?"),
]
result = truncator.fix_messages(messages)
assert len(result) == 3
assert result == messages
def test_fix_messages_tool_with_valid_context(self):
"""Test fix_messages with tool message after user+assistant."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 3
assert result == messages
def test_fix_messages_tool_without_context(self):
"""Test fix_messages with tool message without enough context."""
truncator = ContextTruncator()
messages = [
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
# Tool message without context should be removed
assert len(result) == 0
def test_fix_messages_tool_with_only_one_message(self):
"""Test fix_messages with tool message after only one message."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Hello"),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
# Tool message without enough context should be removed
assert len(result) == 0
def test_fix_messages_multiple_tools(self):
"""Test fix_messages with multiple tool messages."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool 1 result"),
self.create_message("tool", "Tool 2 result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 4
assert result == messages
def test_fix_messages_mixed_system_tool(self):
"""Test fix_messages with system message and tool messages."""
truncator = ContextTruncator()
messages = [
self.create_message("system", "System prompt"),
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 4
assert result == messages
# ==================== truncate_by_turns Tests ====================
def test_truncate_by_turns_no_limit(self):
"""Test truncate_by_turns with -1 (no limit)."""
truncator = ContextTruncator()
messages = self.create_messages(20)
result = truncator.truncate_by_turns(messages, keep_most_recent_turns=-1)
assert len(result) == 20
assert result == messages
def test_truncate_by_turns_basic(self):
"""Test basic truncate_by_turns functionality."""
truncator = ContextTruncator()
# Create 10 messages = 5 turns (user/assistant pairs)
messages = self.create_messages(10)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=3, drop_turns=1
)
# Should keep 3 most recent turns (6 messages)
assert len(result) <= 8 # (3-1+1)*2 = 6, but may adjust for correct format
def test_truncate_by_turns_with_system_message(self):
"""Test truncate_by_turns preserves system messages."""
truncator = ContextTruncator()
messages = self.create_messages(10, include_system=True)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=2, drop_turns=1
)
# System message should always be preserved
assert result[0].role == "system"
assert result[0].content == "System prompt"
def test_truncate_by_turns_zero_keep(self):
"""Test truncate_by_turns with keep_most_recent_turns=0."""
truncator = ContextTruncator()
messages = self.create_messages(10)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=0, drop_turns=1
)
# Should result in empty or minimal list
assert len(result) == 0
def test_truncate_by_turns_below_threshold(self):
"""Test truncate_by_turns when messages are below threshold."""
truncator = ContextTruncator()
# Create 4 messages = 2 turns
messages = self.create_messages(4)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=5, drop_turns=1
)
# No truncation should happen
assert len(result) == 4
assert result == messages
def test_truncate_by_turns_exact_threshold(self):
"""Test truncate_by_turns when messages exactly match threshold."""
truncator = ContextTruncator()
# Create 6 messages = 3 turns
messages = self.create_messages(6)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=3, drop_turns=1
)
# No truncation should happen
assert len(result) == 6
assert result == messages
def test_truncate_by_turns_ensures_user_first(self):
"""Test that truncate_by_turns ensures user message comes first."""
truncator = ContextTruncator()
# Create scenario where truncation might start with assistant
messages = self.create_messages(20)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=3, drop_turns=1
)
# First non-system message should be user
assert result[0].role == "user"
def test_truncate_by_turns_multiple_drop(self):
"""Test truncate_by_turns with multiple turns dropped at once."""
truncator = ContextTruncator()
messages = self.create_messages(20)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=5, drop_turns=3
)
# Should drop 3 turns when limit exceeded
assert len(result) < len(messages)
# ==================== truncate_by_dropping_oldest_turns Tests ====================
def test_truncate_by_dropping_oldest_turns_zero(self):
"""Test truncate_by_dropping_oldest_turns with drop_turns=0."""
truncator = ContextTruncator()
messages = self.create_messages(10)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=0)
assert result == messages
def test_truncate_by_dropping_oldest_turns_negative(self):
"""Test truncate_by_dropping_oldest_turns with negative drop_turns."""
truncator = ContextTruncator()
messages = self.create_messages(10)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=-1)
assert result == messages
def test_truncate_by_dropping_oldest_turns_basic(self):
"""Test basic truncate_by_dropping_oldest_turns functionality."""
truncator = ContextTruncator()
# Create 10 messages = 5 turns
messages = self.create_messages(10)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
# Should drop 2 oldest turns (4 messages)
assert len(result) == 6
# Should start with user message
assert result[0].role == "user"
def test_truncate_by_dropping_oldest_turns_with_system(self):
"""Test truncate_by_dropping_oldest_turns preserves system messages."""
truncator = ContextTruncator()
messages = self.create_messages(10, include_system=True)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
# System message should be preserved
assert result[0].role == "system"
assert result[0].content == "System prompt"
def test_truncate_by_dropping_oldest_turns_drop_all(self):
"""Test truncate_by_dropping_oldest_turns dropping all turns."""
truncator = ContextTruncator()
# Create 4 messages = 2 turns
messages = self.create_messages(4)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
# Should drop all turns
assert len(result) == 0
def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self):
"""Test truncate_by_dropping_oldest_turns with drop_turns > available turns."""
truncator = ContextTruncator()
# Create 4 messages = 2 turns
messages = self.create_messages(4)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=5)
# Should result in empty list
assert len(result) == 0
def test_truncate_by_dropping_oldest_turns_ensures_user_first(self):
"""Test that result starts with user message after dropping."""
truncator = ContextTruncator()
messages = self.create_messages(20)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=3)
# First message should be user
if len(result) > 0:
assert result[0].role == "user"
# ==================== truncate_by_halving Tests ====================
def test_truncate_by_halving_empty(self):
"""Test truncate_by_halving with empty list."""
truncator = ContextTruncator()
result = truncator.truncate_by_halving([])
assert result == []
def test_truncate_by_halving_single_message(self):
"""Test truncate_by_halving with single message."""
truncator = ContextTruncator()
messages = [self.create_message("user", "Hello")]
result = truncator.truncate_by_halving(messages)
# Should not truncate if <= 2 messages
assert result == messages
def test_truncate_by_halving_two_messages(self):
"""Test truncate_by_halving with two messages."""
truncator = ContextTruncator()
messages = self.create_messages(2)
result = truncator.truncate_by_halving(messages)
# Should not truncate if <= 2 messages
assert result == messages
def test_truncate_by_halving_basic(self):
"""Test basic truncate_by_halving functionality."""
truncator = ContextTruncator()
# Create 20 messages
messages = self.create_messages(20)
result = truncator.truncate_by_halving(messages)
# Should delete 50% = 10 messages, keep 10
assert len(result) == 10
# First message should be user
assert result[0].role == "user"
def test_truncate_by_halving_with_system_message(self):
"""Test truncate_by_halving preserves system messages."""
truncator = ContextTruncator()
messages = self.create_messages(20, include_system=True)
result = truncator.truncate_by_halving(messages)
# System message should be preserved
assert result[0].role == "system"
assert result[0].content == "System prompt"
def test_truncate_by_halving_odd_count(self):
"""Test truncate_by_halving with odd number of messages."""
truncator = ContextTruncator()
messages = self.create_messages(11)
result = truncator.truncate_by_halving(messages)
# Should delete floor(11/2) = 5 messages, keep 6
# But after ensuring user first, may be 5
assert len(result) >= 5
assert result[0].role == "user"
def test_truncate_by_halving_ensures_user_first(self):
"""Test that result starts with user message."""
truncator = ContextTruncator()
# Create messages starting with user
messages = self.create_messages(30)
result = truncator.truncate_by_halving(messages)
# First message should be user
assert result[0].role == "user"
def test_truncate_by_halving_preserves_recent_messages(self):
"""Test that truncate_by_halving keeps the most recent 50%."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Message 0"),
self.create_message("assistant", "Message 1"),
self.create_message("user", "Message 2"),
self.create_message("assistant", "Message 3"),
]
result = truncator.truncate_by_halving(messages)
# Should keep last 2 messages
assert len(result) == 2
assert result[0].content == "Message 2"
assert result[1].content == "Message 3"
# ==================== Integration Tests ====================
def test_truncate_with_tool_messages(self):
"""Test truncation with tool messages."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool result"),
self.create_message("user", "Thanks"),
self.create_message("assistant", "Welcome"),
]
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=1)
# First turn (user+assistant+tool) should be dropped
# Tool message should be cleaned up by fix_messages
assert len(result) <= 2
def test_chain_multiple_truncations(self):
"""Test chaining multiple truncation methods."""
truncator = ContextTruncator()
messages = self.create_messages(40, include_system=True)
# First: truncate by turns
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=10, drop_turns=2
)
# Then: halve
result = truncator.truncate_by_halving(result)
# Should have system message + truncated content
assert result[0].role == "system"
assert len(result) < len(messages)
def test_empty_after_system_message(self):
"""Test truncation when only system message exists."""
truncator = ContextTruncator()
messages = [self.create_message("system", "System prompt")]
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=5, drop_turns=1
)
# Should keep system message
assert len(result) == 1
assert result[0].role == "system"
def test_all_system_messages(self):
"""Test truncation with only system messages."""
truncator = ContextTruncator()
messages = [
self.create_message("system", "System 1"),
self.create_message("system", "System 2"),
]
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=0, drop_turns=1
)
# System messages should be preserved, but since there are no non-system
# messages and keep_most_recent_turns=0, result should be system messages only
assert len(result) >= 0 # May keep system messages or clear all
if len(result) > 0:
assert all(msg.role == "system" for msg in result)