Add comprehensive tests for ContextManager and ContextTruncator
- Implemented a full test suite for ContextManager covering initialization, message processing, token-based compression, and error handling. - Added tests for ContextTruncator focusing on message fixing, truncation by turns, dropping oldest turns, and halving. - Ensured that both test suites validate edge cases and maintain expected behavior with various message types, including system and tool messages.
This commit is contained in:
@@ -1,24 +1,31 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from astrbot.api import logger
|
||||
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||
|
||||
from ..message import Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot import logger
|
||||
else:
|
||||
try:
|
||||
from astrbot import logger
|
||||
except ImportError:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.truncator import ContextTruncator
|
||||
|
||||
|
||||
class ContextCompressor(ABC):
|
||||
@runtime_checkable
|
||||
class ContextCompressor(Protocol):
|
||||
"""
|
||||
Abstract base class for context compressors.
|
||||
Protocol for context compressors.
|
||||
Provides an interface for compressing message lists.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def compress(self, messages: list[Message]) -> list[Message]:
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
"""Compress the message list.
|
||||
|
||||
Args:
|
||||
@@ -27,19 +34,10 @@ class ContextCompressor(ABC):
|
||||
Returns:
|
||||
The compressed message list.
|
||||
"""
|
||||
pass
|
||||
...
|
||||
|
||||
|
||||
class DefaultCompressor(ContextCompressor):
|
||||
"""Default compressor implementation.
|
||||
Returns the original messages.
|
||||
"""
|
||||
|
||||
async def compress(self, messages: list[Message]) -> list[Message]:
|
||||
return messages
|
||||
|
||||
|
||||
class TruncateByTurnsCompressor(ContextCompressor):
|
||||
class TruncateByTurnsCompressor:
|
||||
"""Truncate by turns compressor implementation.
|
||||
Truncates the message list by removing older turns.
|
||||
"""
|
||||
@@ -52,17 +50,47 @@ class TruncateByTurnsCompressor(ContextCompressor):
|
||||
"""
|
||||
self.truncate_turns = truncate_turns
|
||||
|
||||
async def compress(self, messages: list[Message]) -> list[Message]:
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
truncator = ContextTruncator()
|
||||
truncated_messages = truncator.truncate_by_turns(
|
||||
truncated_messages = truncator.truncate_by_dropping_oldest_turns(
|
||||
messages,
|
||||
keep_most_recent_turns=0,
|
||||
dequeue_turns=self.truncate_turns,
|
||||
drop_turns=self.truncate_turns,
|
||||
)
|
||||
return truncated_messages
|
||||
|
||||
|
||||
class LLMSummaryCompressor(ContextCompressor):
|
||||
def split_history(
|
||||
messages: list[Message], keep_recent: int
|
||||
) -> tuple[list[Message], list[Message], list[Message]]:
|
||||
"""Split the message list into system messages, messages to summarize, and recent messages.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
keep_recent: The number of latest messages to keep.
|
||||
|
||||
Returns:
|
||||
tuple: (system_messages, messages_to_summarize, recent_messages)
|
||||
"""
|
||||
# keep the system messages
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) <= keep_recent:
|
||||
return system_messages, [], non_system_messages
|
||||
|
||||
messages_to_summarize = non_system_messages[:-keep_recent]
|
||||
recent_messages = non_system_messages[-keep_recent:]
|
||||
|
||||
return system_messages, messages_to_summarize, recent_messages
|
||||
|
||||
|
||||
class LLMSummaryCompressor:
|
||||
"""LLM-based summary compressor.
|
||||
Uses LLM to summarize the old conversation history, keeping the latest messages.
|
||||
"""
|
||||
@@ -90,7 +118,7 @@ class LLMSummaryCompressor(ContextCompressor):
|
||||
"4. Write the summary in the user's language.\n"
|
||||
)
|
||||
|
||||
async def compress(self, messages: list[Message]) -> list[Message]:
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
"""Use LLM to generate a summary of the conversation history.
|
||||
|
||||
Process:
|
||||
@@ -101,12 +129,9 @@ class LLMSummaryCompressor(ContextCompressor):
|
||||
if len(messages) <= self.keep_recent + 1:
|
||||
return messages
|
||||
|
||||
# keep the system message
|
||||
system_msg = messages[0] if messages and messages[0].role == "system" else None
|
||||
start_idx = 1 if system_msg else 0
|
||||
|
||||
messages_to_summarize = messages[start_idx : -self.keep_recent]
|
||||
recent_messages = messages[-self.keep_recent :]
|
||||
system_messages, messages_to_summarize, recent_messages = split_history(
|
||||
messages, self.keep_recent
|
||||
)
|
||||
|
||||
if not messages_to_summarize:
|
||||
return messages
|
||||
@@ -125,8 +150,7 @@ class LLMSummaryCompressor(ContextCompressor):
|
||||
|
||||
# build result
|
||||
result = []
|
||||
if system_msg:
|
||||
result.append(system_msg)
|
||||
result.extend(system_messages)
|
||||
|
||||
result.append(
|
||||
Message(
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextConfig:
|
||||
"""Context configuration class."""
|
||||
|
||||
max_context_tokens: int = 0
|
||||
"""Maximum number of context tokens. <= 0 means no limit."""
|
||||
enforce_max_turns: int = -1 # -1 means no limit
|
||||
"""Maximum number of conversation turns to keep. -1 means no limit. Executed before compression."""
|
||||
truncate_turns: int = 1
|
||||
"""Number of conversation turns to discard at once when truncation is triggered.
|
||||
Two processes will use this value:
|
||||
|
||||
1. Enforce max turns truncation.
|
||||
2. Truncation by turns compression strategy.
|
||||
"""
|
||||
llm_compress_instruction: str | None = None
|
||||
"""Instruction prompt for LLM-based compression."""
|
||||
llm_compress_keep_recent: int = 0
|
||||
"""Number of recent messages to keep during LLM-based compression."""
|
||||
llm_compress_provider: "Provider | None" = None
|
||||
"""LLM provider used for compression tasks. If None, truncation strategy is used."""
|
||||
@@ -1,14 +1,10 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
from ..message import Message
|
||||
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
|
||||
from .token_counter import TokenCounter
|
||||
from .truncator import ContextTruncator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
from .config import ContextConfig
|
||||
|
||||
|
||||
class ContextManager:
|
||||
@@ -19,11 +15,7 @@ class ContextManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_context_tokens: int = 0,
|
||||
truncate_turns: int = 1,
|
||||
llm_compress_instruction: str | None = None,
|
||||
llm_compress_keep_recent: int = 4,
|
||||
llm_compress_provider: "Provider | None" = None,
|
||||
config: ContextConfig,
|
||||
):
|
||||
"""Initialize the context manager.
|
||||
|
||||
@@ -32,26 +24,23 @@ class ContextManager:
|
||||
2. LLM-based compression: use LLM to summarize old messages.
|
||||
|
||||
Args:
|
||||
max_context_tokens: The maximum context tokens. <= 0 means no limit.
|
||||
truncate_turns: For turncate strategy. The number of turns to discard when truncating.
|
||||
llm_compress_instruction: The instruction text for LLM compression.
|
||||
llm_compress_keep_recent: The number of recent messages to keep during LLM compression.
|
||||
llm_compress_provider: The LLM provider for compression.
|
||||
config: The context configuration.
|
||||
"""
|
||||
self.max_context_tokens = max_context_tokens
|
||||
self.truncate_turns = truncate_turns
|
||||
self.config = config
|
||||
|
||||
self.token_counter = TokenCounter()
|
||||
self.truncator = ContextTruncator()
|
||||
|
||||
if llm_compress_provider:
|
||||
if config.llm_compress_provider:
|
||||
self.compressor = LLMSummaryCompressor(
|
||||
provider=llm_compress_provider,
|
||||
keep_recent=llm_compress_keep_recent,
|
||||
instruction_text=llm_compress_instruction,
|
||||
provider=config.llm_compress_provider,
|
||||
keep_recent=config.llm_compress_keep_recent,
|
||||
instruction_text=config.llm_compress_instruction,
|
||||
)
|
||||
else:
|
||||
self.compressor = TruncateByTurnsCompressor(truncate_turns=truncate_turns)
|
||||
self.compressor = TruncateByTurnsCompressor(
|
||||
truncate_turns=config.truncate_turns
|
||||
)
|
||||
|
||||
async def process(self, messages: list[Message]) -> list[Message]:
|
||||
"""Process the messages.
|
||||
@@ -62,17 +51,30 @@ class ContextManager:
|
||||
Returns:
|
||||
The processed message list.
|
||||
"""
|
||||
if self.max_context_tokens <= 0:
|
||||
try:
|
||||
result = messages
|
||||
|
||||
# 1. 基于轮次的截断 (Enforce max turns)
|
||||
if self.config.enforce_max_turns != -1:
|
||||
result = self.truncator.truncate_by_turns(
|
||||
result,
|
||||
keep_most_recent_turns=self.config.enforce_max_turns,
|
||||
drop_turns=self.config.truncate_turns,
|
||||
)
|
||||
|
||||
# 2. 基于 token 的压缩
|
||||
if self.config.max_context_tokens > 0:
|
||||
# check if the messages need to be compressed
|
||||
needs_compression, _ = await self._initial_token_check(result)
|
||||
|
||||
# compress/truncate the messages if needed
|
||||
result = await self._run_compression(result, needs_compression)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error during context processing: {e}", exc_info=True)
|
||||
return messages
|
||||
|
||||
# check if the messages need to be compressed
|
||||
needs_compression, _ = await self._initial_token_check(messages)
|
||||
|
||||
# compress/truncate the messages if needed
|
||||
messages = await self._run_compression(messages, needs_compression)
|
||||
|
||||
return messages
|
||||
|
||||
async def _initial_token_check(
|
||||
self, messages: list[Message]
|
||||
) -> tuple[bool, int | None]:
|
||||
@@ -87,15 +89,15 @@ class ContextManager:
|
||||
"""
|
||||
if not messages:
|
||||
return False, None
|
||||
if self.max_context_tokens <= 0:
|
||||
if self.config.max_context_tokens <= 0:
|
||||
return False, None
|
||||
|
||||
total_tokens = self.token_counter.count_tokens(messages)
|
||||
|
||||
logger.debug(
|
||||
f"ContextManager: total tokens = {total_tokens}, max_context_tokens = {self.max_context_tokens}"
|
||||
f"ContextManager: total tokens = {total_tokens}, max_context_tokens = {self.config.max_context_tokens}"
|
||||
)
|
||||
usage_rate = total_tokens / self.max_context_tokens
|
||||
usage_rate = total_tokens / self.config.max_context_tokens
|
||||
|
||||
needs_compression = usage_rate > self.COMPRESSION_THRESHOLD
|
||||
return needs_compression, total_tokens if needs_compression else None
|
||||
@@ -115,14 +117,17 @@ class ContextManager:
|
||||
"""
|
||||
if not needs_compression:
|
||||
return messages
|
||||
if self.max_context_tokens <= 0:
|
||||
if self.config.max_context_tokens <= 0:
|
||||
return messages
|
||||
|
||||
messages = await self.compressor.compress(messages)
|
||||
messages = await self.compressor(messages)
|
||||
|
||||
# double check
|
||||
tokens_after_summary = self.token_counter.count_tokens(messages)
|
||||
if tokens_after_summary / self.max_context_tokens > self.COMPRESSION_THRESHOLD:
|
||||
if (
|
||||
tokens_after_summary / self.config.max_context_tokens
|
||||
> self.COMPRESSION_THRESHOLD
|
||||
):
|
||||
# still over 82%, truncate by half
|
||||
messages = self._compress_by_halving(messages)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class ContextTruncator:
|
||||
self,
|
||||
messages: list[Message],
|
||||
keep_most_recent_turns: int,
|
||||
dequeue_turns: int = 1,
|
||||
drop_turns: int = 1,
|
||||
) -> list[Message]:
|
||||
"""截断上下文列表,确保不超过最大长度。
|
||||
一个 turn 包含一个 user 消息和一个 assistant 消息。
|
||||
@@ -32,27 +32,33 @@ class ContextTruncator:
|
||||
Args:
|
||||
messages: 上下文列表
|
||||
keep_most_recent_turns: 保留最近的对话轮数
|
||||
dequeue_turns: 一次性丢弃的对话轮数
|
||||
drop_turns: 一次性丢弃的对话轮数
|
||||
|
||||
Returns:
|
||||
截断后的上下文列表
|
||||
"""
|
||||
if keep_most_recent_turns == -1:
|
||||
return messages
|
||||
if len(messages) <= keep_most_recent_turns:
|
||||
return messages
|
||||
if len(messages) // 2 <= keep_most_recent_turns:
|
||||
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) // 2 <= keep_most_recent_turns:
|
||||
return messages
|
||||
|
||||
system_message = None
|
||||
if messages[0].role == "system":
|
||||
system_message = messages[0]
|
||||
messages = messages[1:]
|
||||
num_to_keep = keep_most_recent_turns - drop_turns + 1
|
||||
if num_to_keep <= 0:
|
||||
truncated_contexts = []
|
||||
else:
|
||||
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
|
||||
|
||||
truncated_contexts = messages[
|
||||
-(keep_most_recent_turns - dequeue_turns + 1) * 2 :
|
||||
]
|
||||
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
||||
# 找到第一个 role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
|
||||
None,
|
||||
@@ -60,10 +66,45 @@ class ContextTruncator:
|
||||
if index is not None and index > 0:
|
||||
truncated_contexts = truncated_contexts[index:]
|
||||
|
||||
if system_message is not None:
|
||||
truncated_contexts = [system_message] + truncated_contexts
|
||||
result = system_messages + truncated_contexts
|
||||
|
||||
return self.fix_messages(truncated_contexts)
|
||||
return self.fix_messages(result)
|
||||
|
||||
def truncate_by_dropping_oldest_turns(
|
||||
self,
|
||||
messages: list[Message],
|
||||
drop_turns: int = 1,
|
||||
) -> list[Message]:
|
||||
"""丢弃最旧的 N 个对话轮次。"""
|
||||
if drop_turns <= 0:
|
||||
return messages
|
||||
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) // 2 <= drop_turns:
|
||||
truncated_non_system = []
|
||||
else:
|
||||
truncated_non_system = non_system_messages[drop_turns * 2 :]
|
||||
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None:
|
||||
truncated_non_system = truncated_non_system[index:]
|
||||
elif truncated_non_system:
|
||||
truncated_non_system = []
|
||||
|
||||
result = system_messages + truncated_non_system
|
||||
|
||||
return self.fix_messages(result)
|
||||
|
||||
def truncate_by_halving(
|
||||
self,
|
||||
@@ -79,16 +120,22 @@ class ContextTruncator:
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
messages_to_delete = (len(messages) - first_non_system) // 2
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
result = messages[:first_non_system]
|
||||
result.extend(messages[first_non_system + messages_to_delete :])
|
||||
messages_to_delete = len(non_system_messages) // 2
|
||||
if messages_to_delete == 0:
|
||||
return messages
|
||||
|
||||
truncated_non_system = non_system_messages[messages_to_delete:]
|
||||
|
||||
index = next(
|
||||
(i for i, item in enumerate(result) if item.role == "user"),
|
||||
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None:
|
||||
result = result[index:]
|
||||
truncated_non_system = truncated_non_system[index:]
|
||||
|
||||
result = system_messages + truncated_non_system
|
||||
|
||||
return self.fix_messages(result)
|
||||
|
||||
@@ -26,7 +26,7 @@ from astrbot.core.provider.entities import (
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.manager import ContextManager
|
||||
from ..context.truncator import ContextTruncator
|
||||
from ..context.config import ContextConfig
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||
from ..response import AgentResponseData, AgentStats
|
||||
@@ -70,15 +70,17 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
# we will do compress when:
|
||||
# 1. before requesting LLM
|
||||
# TODO: 2. after LLM output a tool call
|
||||
self.context_manager = ContextManager(
|
||||
# <=0 will never trigger context compression
|
||||
self.context_config = ContextConfig(
|
||||
# <=0 will never do compress
|
||||
max_context_tokens=provider.provider_config.get("max_context_tokens", 0),
|
||||
# enforce max turns before compression
|
||||
enforce_max_turns=self.enforce_max_turns,
|
||||
truncate_turns=self.truncate_turns,
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider=self.llm_compress_provider,
|
||||
)
|
||||
self.context_truncator = ContextTruncator()
|
||||
self.context_manager = ContextManager(self.context_config)
|
||||
|
||||
self.provider = provider
|
||||
self.final_llm_resp = None
|
||||
@@ -121,12 +123,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
else:
|
||||
yield await self.provider.text_chat(**payload)
|
||||
|
||||
async def do_context_compress(self):
|
||||
"""检查并执行上下文压缩。"""
|
||||
original_messages = self.run_context.messages
|
||||
compressed_messages = await self.context_manager.process(original_messages)
|
||||
self.run_context.messages = compressed_messages
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
"""Process a single step of the agent.
|
||||
@@ -145,23 +141,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
llm_resp_result = None
|
||||
|
||||
# do truncate
|
||||
if self.enforce_max_turns != -1:
|
||||
try:
|
||||
truncated_messages = self.context_truncator.truncate_by_turns(
|
||||
self.run_context.messages,
|
||||
keep_most_recent_turns=self.enforce_max_turns,
|
||||
dequeue_turns=self.truncate_turns,
|
||||
)
|
||||
self.run_context.messages = truncated_messages
|
||||
except Exception as e:
|
||||
logger.error(f"Error during context truncation: {e}", exc_info=True)
|
||||
|
||||
# check compress
|
||||
try:
|
||||
await self.do_context_compress()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during context compression: {e}", exc_info=True)
|
||||
# do truncate and compress
|
||||
self.run_context.messages = await self.context_manager.process(
|
||||
self.run_context.messages
|
||||
)
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
if llm_response.is_chunk:
|
||||
|
||||
Generated
+5525
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,633 @@
|
||||
"""Comprehensive tests for ContextManager."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Add parent directory to path to avoid circular import issues
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from astrbot.core.agent.context.config import ContextConfig
|
||||
from astrbot.core.agent.context.manager import ContextManager
|
||||
from astrbot.core.agent.message import Message, TextPart
|
||||
|
||||
|
||||
class TestContextManager:
|
||||
"""Test suite for ContextManager."""
|
||||
|
||||
def create_message(
|
||||
self, role: Literal["system", "user", "assistant", "tool"], content: str
|
||||
) -> Message:
|
||||
"""Helper to create a simple text message."""
|
||||
return Message(role=role, content=content)
|
||||
|
||||
def create_messages(self, count: int) -> list[Message]:
|
||||
"""Helper to create alternating user/assistant messages."""
|
||||
messages = []
|
||||
for i in range(count):
|
||||
role = "user" if i % 2 == 0 else "assistant"
|
||||
messages.append(self.create_message(role, f"Message {i}"))
|
||||
return messages
|
||||
|
||||
# ==================== Basic Initialization Tests ====================
|
||||
|
||||
def test_init_with_minimal_config(self):
|
||||
"""Test initialization with minimal configuration."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
assert manager.config == config
|
||||
assert manager.token_counter is not None
|
||||
assert manager.truncator is not None
|
||||
assert manager.compressor is not None
|
||||
|
||||
def test_init_with_llm_compressor(self):
|
||||
"""Test initialization with LLM-based compression."""
|
||||
mock_provider = MagicMock()
|
||||
config = ContextConfig(
|
||||
llm_compress_provider=mock_provider,
|
||||
llm_compress_keep_recent=5,
|
||||
llm_compress_instruction="Summarize the conversation",
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
from astrbot.core.agent.context.compressor import LLMSummaryCompressor
|
||||
|
||||
assert isinstance(manager.compressor, LLMSummaryCompressor)
|
||||
|
||||
def test_init_with_truncate_compressor(self):
|
||||
"""Test initialization with truncate-based compression (default)."""
|
||||
config = ContextConfig(truncate_turns=3)
|
||||
manager = ContextManager(config)
|
||||
|
||||
from astrbot.core.agent.context.compressor import TruncateByTurnsCompressor
|
||||
|
||||
assert isinstance(manager.compressor, TruncateByTurnsCompressor)
|
||||
|
||||
# ==================== Empty and Edge Cases ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_empty_messages(self):
|
||||
"""Test processing an empty message list."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
result = await manager.process([])
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_single_message(self):
|
||||
"""Test processing a single message."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "Hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_with_no_limits(self):
|
||||
"""Test processing when no limits are set (no truncation or compression)."""
|
||||
config = ContextConfig(max_context_tokens=0, enforce_max_turns=-1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(20)
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 20
|
||||
assert result == messages
|
||||
|
||||
# ==================== Enforce Max Turns Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_basic(self):
|
||||
"""Test basic enforce_max_turns functionality."""
|
||||
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create 10 turns (20 messages)
|
||||
messages = self.create_messages(20)
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should keep only 3 most recent turns (6 messages)
|
||||
assert len(result) <= 8 # May vary due to truncation logic
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_zero(self):
|
||||
"""Test enforce_max_turns with value 0 (should keep nothing)."""
|
||||
config = ContextConfig(enforce_max_turns=0, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(10)
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should result in empty or minimal message list
|
||||
assert len(result) <= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_negative(self):
|
||||
"""Test enforce_max_turns with -1 (no limit)."""
|
||||
config = ContextConfig(enforce_max_turns=-1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(20)
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 20
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_with_system_messages(self):
|
||||
"""Test enforce_max_turns preserves system messages."""
|
||||
config = ContextConfig(enforce_max_turns=2, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [
|
||||
self.create_message("system", "System instruction"),
|
||||
*self.create_messages(10),
|
||||
]
|
||||
result = await manager.process(messages)
|
||||
|
||||
# System message should be preserved
|
||||
system_msgs = [m for m in result if m.role == "system"]
|
||||
assert len(system_msgs) >= 1
|
||||
assert system_msgs[0].content == "System instruction"
|
||||
|
||||
# ==================== Token-based Compression Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_not_triggered_below_threshold(self):
|
||||
"""Test that compression is not triggered below threshold."""
|
||||
config = ContextConfig(max_context_tokens=1000)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that total less than threshold
|
||||
messages = [self.create_message("user", "Hi" * 50)] # ~100 tokens
|
||||
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should not be called
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_triggered_above_threshold(self):
|
||||
"""Test that compression is triggered above threshold."""
|
||||
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that exceed threshold (0.82 * 100 = 82 tokens)
|
||||
# 300 chars * 0.3 = 90 tokens > 82 threshold
|
||||
long_text = "x" * 300 # ~90 tokens, above threshold
|
||||
messages = [self.create_message("user", long_text)]
|
||||
|
||||
# Mock compressor to return smaller result
|
||||
compressed = [self.create_message("user", "short")]
|
||||
|
||||
# Create a mock compressor that we can track
|
||||
mock_compress = AsyncMock(return_value=compressed)
|
||||
manager.compressor = mock_compress
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should be called
|
||||
mock_compress.assert_called_once()
|
||||
# Result should be the compressed version
|
||||
assert len(result) == len(compressed)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_with_zero_max_tokens(self):
|
||||
"""Test that compression is skipped when max_context_tokens is 0."""
|
||||
config = ContextConfig(max_context_tokens=0)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "x" * 10000)]
|
||||
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should not be called
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_with_negative_max_tokens(self):
|
||||
"""Test that compression is skipped when max_context_tokens is negative."""
|
||||
config = ContextConfig(max_context_tokens=-100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "x" * 10000)]
|
||||
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should not be called
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_double_check_after_compression(self):
|
||||
"""Test that halving is applied if still over threshold after compression."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that would still be over threshold after compression
|
||||
long_messages = [self.create_message("user", "x" * 200) for _ in range(10)]
|
||||
|
||||
# Mock compressor to return messages still over threshold
|
||||
async def mock_compress(msgs):
|
||||
return msgs # Return same messages (still over limit)
|
||||
|
||||
with patch.object(manager.compressor, "__call__", new=mock_compress):
|
||||
with patch.object(
|
||||
manager, "_compress_by_halving", return_value=long_messages[:5]
|
||||
) as mock_halving:
|
||||
_ = await manager.process(long_messages)
|
||||
|
||||
# Halving should be called
|
||||
mock_halving.assert_called_once()
|
||||
|
||||
# ==================== Combined Truncation and Compression Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_combined_enforce_turns_and_token_limit(self):
|
||||
"""Test combining enforce_max_turns and token limit."""
|
||||
config = ContextConfig(
|
||||
enforce_max_turns=5, max_context_tokens=500, truncate_turns=1
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create many messages
|
||||
messages = self.create_messages(30)
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should be truncated by both mechanisms
|
||||
assert len(result) < 30
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sequential_processing_order(self):
|
||||
"""Test that enforce_max_turns happens before token compression."""
|
||||
config = ContextConfig(enforce_max_turns=5, max_context_tokens=1000)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(20)
|
||||
|
||||
# Mock the truncator to track calls
|
||||
with patch.object(
|
||||
manager.truncator,
|
||||
"truncate_by_turns",
|
||||
wraps=manager.truncator.truncate_by_turns,
|
||||
) as mock_truncate:
|
||||
await manager.process(messages)
|
||||
|
||||
# Truncator should be called first
|
||||
mock_truncate.assert_called_once()
|
||||
|
||||
# ==================== Error Handling Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_returns_original_messages(self):
|
||||
"""Test that errors during processing return original messages."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(5)
|
||||
|
||||
# Make compressor raise an exception
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", side_effect=Exception("Test error")
|
||||
):
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should return original messages despite error
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_logs_exception(self):
|
||||
"""Test that errors are logged."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that will trigger compression (> 82 tokens)
|
||||
messages = [self.create_message("user", "x" * 300)] # ~90 tokens
|
||||
|
||||
# Replace compressor with one that raises an exception
|
||||
manager.compressor = AsyncMock(side_effect=Exception("Test error"))
|
||||
|
||||
with patch("astrbot.core.agent.context.manager.logger") as mock_logger:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Logger error method should be called
|
||||
assert mock_logger.error.called
|
||||
# Should return original messages on error
|
||||
assert result == messages
|
||||
|
||||
# ==================== Multi-modal Content Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_messages_with_textpart_content(self):
|
||||
"""Test processing messages with TextPart content."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [
|
||||
Message(role="user", content=[TextPart(text="Hello")]),
|
||||
Message(role="assistant", content=[TextPart(text="Hi there")]),
|
||||
]
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_counting_with_multimodal_content(self):
|
||||
"""Test token counting works with multi-modal content."""
|
||||
config = ContextConfig(max_context_tokens=50)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Need enough tokens to exceed threshold: 50 * 0.82 = 41 tokens
|
||||
# 150 chars * 0.3 = 45 tokens > 41
|
||||
messages = [
|
||||
Message(role="user", content=[TextPart(text="x" * 150)]),
|
||||
]
|
||||
|
||||
# Should trigger compression due to token count
|
||||
needs_compression, tokens = await manager._initial_token_check(messages)
|
||||
|
||||
assert tokens is not None # Tokens should be counted
|
||||
assert needs_compression # Should trigger compression
|
||||
|
||||
# ==================== Tool Calls Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_messages_with_tool_calls(self):
|
||||
"""Test processing messages with tool calls."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [
|
||||
Message(
|
||||
role="assistant",
|
||||
content="Let me search for that",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
),
|
||||
Message(role="tool", content="Search result", tool_call_id="call_1"),
|
||||
]
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
# ==================== Initial Token Check Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_token_check_empty_messages(self):
|
||||
"""Test _initial_token_check with empty messages."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
needs_compression, tokens = await manager._initial_token_check([])
|
||||
|
||||
assert not needs_compression
|
||||
assert tokens is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_token_check_no_limit(self):
|
||||
"""Test _initial_token_check when max_context_tokens is 0."""
|
||||
config = ContextConfig(max_context_tokens=0)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "x" * 1000)]
|
||||
needs_compression, tokens = await manager._initial_token_check(messages)
|
||||
|
||||
assert not needs_compression
|
||||
assert tokens is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_token_check_below_threshold(self):
|
||||
"""Test _initial_token_check when below compression threshold."""
|
||||
config = ContextConfig(max_context_tokens=1000)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
needs_compression, tokens = await manager._initial_token_check(messages)
|
||||
|
||||
assert not needs_compression
|
||||
assert tokens is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_token_check_above_threshold(self):
|
||||
"""Test _initial_token_check when above compression threshold."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create message with ~90 tokens (above 0.82 * 100 = 82)
|
||||
messages = [self.create_message("user", "这是测试" * 50)]
|
||||
needs_compression, tokens = await manager._initial_token_check(messages)
|
||||
|
||||
assert needs_compression
|
||||
assert tokens is not None
|
||||
assert tokens > 82
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_token_check_exactly_at_threshold(self):
|
||||
"""Test _initial_token_check when just above threshold."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create message with >82 tokens (0.82 * 100)
|
||||
# 300 chars * 0.3 = 90 tokens > 82 (threshold)
|
||||
messages = [self.create_message("user", "x" * 300)] # ~90 tokens
|
||||
needs_compression, tokens = await manager._initial_token_check(messages)
|
||||
|
||||
# Above threshold should trigger compression
|
||||
assert tokens is not None
|
||||
assert needs_compression
|
||||
|
||||
# ==================== Compression by Halving Tests ====================
|
||||
|
||||
def test_compress_by_halving_basic(self):
|
||||
"""Test _compress_by_halving removes middle 50%."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(10)
|
||||
result = manager._compress_by_halving(messages)
|
||||
|
||||
# Should keep roughly half
|
||||
assert len(result) < len(messages)
|
||||
|
||||
def test_compress_by_halving_empty_list(self):
|
||||
"""Test _compress_by_halving with empty list."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
result = manager._compress_by_halving([])
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_compress_by_halving_single_message(self):
|
||||
"""Test _compress_by_halving with single message."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
result = manager._compress_by_halving(messages)
|
||||
|
||||
assert len(result) <= 1
|
||||
|
||||
# ==================== Complex Scenarios ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_compression_cycles(self):
|
||||
"""Test that compression can be triggered multiple times in sequence."""
|
||||
config = ContextConfig(max_context_tokens=50, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Process messages multiple times
|
||||
messages = self.create_messages(10)
|
||||
|
||||
result1 = await manager.process(messages)
|
||||
result2 = await manager.process(result1)
|
||||
result3 = await manager.process(result2)
|
||||
|
||||
# Each cycle should maintain or reduce message count
|
||||
assert len(result3) <= len(result2) <= len(result1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alternating_roles_preserved(self):
|
||||
"""Test that user/assistant alternation is preserved after processing."""
|
||||
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(20)
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Check that roles still alternate (excluding system messages)
|
||||
non_system = [m for m in result if m.role != "system"]
|
||||
if len(non_system) >= 2:
|
||||
# Should start with user
|
||||
assert non_system[0].role == "user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compression_threshold_constant(self):
|
||||
"""Test that COMPRESSION_THRESHOLD is used correctly."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Verify the threshold is 0.82
|
||||
assert manager.COMPRESSION_THRESHOLD == 0.82
|
||||
|
||||
# Create messages just below threshold
|
||||
messages = [self.create_message("user", "x" * 81)] # ~24 tokens
|
||||
|
||||
needs_compression, _ = await manager._initial_token_check(messages)
|
||||
assert not needs_compression
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_batch_processing(self):
|
||||
"""Test processing a large batch of messages."""
|
||||
config = ContextConfig(
|
||||
enforce_max_turns=10, max_context_tokens=1000, truncate_turns=2
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create 100 messages (50 turns)
|
||||
messages = self.create_messages(100)
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should be significantly reduced
|
||||
assert len(result) < 100
|
||||
assert len(result) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_persistence(self):
|
||||
"""Test that config settings are respected throughout processing."""
|
||||
config = ContextConfig(
|
||||
max_context_tokens=500,
|
||||
enforce_max_turns=5,
|
||||
truncate_turns=2,
|
||||
llm_compress_keep_recent=3,
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Verify config is stored
|
||||
assert manager.config.max_context_tokens == 500
|
||||
assert manager.config.enforce_max_turns == 5
|
||||
assert manager.config.truncate_turns == 2
|
||||
assert manager.config.llm_compress_keep_recent == 3
|
||||
|
||||
# ==================== Run Compression Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_compression_skip_when_not_needed(self):
|
||||
"""Test _run_compression skips when needs_compression is False."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(5)
|
||||
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager._run_compression(messages, needs_compression=False)
|
||||
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_compression_skip_when_zero_limit(self):
|
||||
"""Test _run_compression skips when max_context_tokens is 0."""
|
||||
config = ContextConfig(max_context_tokens=0)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(5)
|
||||
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager._run_compression(messages, needs_compression=True)
|
||||
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_compression_applies_compressor(self):
|
||||
"""Test _run_compression calls compressor when needed through process()."""
|
||||
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that will trigger compression
|
||||
messages = [self.create_message("user", "x" * 300)] # ~90 tokens > 82 threshold
|
||||
compressed = [self.create_message("user", "short")] # Much smaller
|
||||
|
||||
# Replace compressor with mock
|
||||
mock_compress = AsyncMock(return_value=compressed)
|
||||
manager.compressor = mock_compress
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should have been called
|
||||
mock_compress.assert_called_once()
|
||||
assert len(result) <= len(messages)
|
||||
@@ -0,0 +1,423 @@
|
||||
"""Tests for ContextTruncator."""
|
||||
|
||||
from astrbot.core.agent.context.truncator import ContextTruncator
|
||||
from astrbot.core.agent.message import Message
|
||||
|
||||
|
||||
class TestContextTruncator:
|
||||
"""Test suite for ContextTruncator."""
|
||||
|
||||
def create_message(self, role: str, content: str = "test content") -> Message:
|
||||
"""Helper to create a simple test message."""
|
||||
return Message(role=role, content=content)
|
||||
|
||||
def create_messages(
|
||||
self, count: int, include_system: bool = False
|
||||
) -> list[Message]:
|
||||
"""Helper to create alternating user/assistant messages.
|
||||
|
||||
Args:
|
||||
count: Number of messages to create
|
||||
include_system: Whether to include a system message at the start
|
||||
|
||||
Returns:
|
||||
List of messages
|
||||
"""
|
||||
messages = []
|
||||
if include_system:
|
||||
messages.append(self.create_message("system", "System prompt"))
|
||||
|
||||
for i in range(count):
|
||||
role = "user" if i % 2 == 0 else "assistant"
|
||||
messages.append(self.create_message(role, f"Message {i}"))
|
||||
return messages
|
||||
|
||||
# ==================== fix_messages Tests ====================
|
||||
|
||||
def test_fix_messages_empty_list(self):
|
||||
"""Test fix_messages with an empty list."""
|
||||
truncator = ContextTruncator()
|
||||
result = truncator.fix_messages([])
|
||||
assert result == []
|
||||
|
||||
def test_fix_messages_normal_messages(self):
|
||||
"""Test fix_messages with normal user/assistant messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Hello"),
|
||||
self.create_message("assistant", "Hi"),
|
||||
self.create_message("user", "How are you?"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 3
|
||||
assert result == messages
|
||||
|
||||
def test_fix_messages_tool_with_valid_context(self):
|
||||
"""Test fix_messages with tool message after user+assistant."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 3
|
||||
assert result == messages
|
||||
|
||||
def test_fix_messages_tool_without_context(self):
|
||||
"""Test fix_messages with tool message without enough context."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
# Tool message without context should be removed
|
||||
assert len(result) == 0
|
||||
|
||||
def test_fix_messages_tool_with_only_one_message(self):
|
||||
"""Test fix_messages with tool message after only one message."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Hello"),
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
# Tool message without enough context should be removed
|
||||
assert len(result) == 0
|
||||
|
||||
def test_fix_messages_multiple_tools(self):
|
||||
"""Test fix_messages with multiple tool messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool 1 result"),
|
||||
self.create_message("tool", "Tool 2 result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 4
|
||||
assert result == messages
|
||||
|
||||
def test_fix_messages_mixed_system_tool(self):
|
||||
"""Test fix_messages with system message and tool messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("system", "System prompt"),
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 4
|
||||
assert result == messages
|
||||
|
||||
# ==================== truncate_by_turns Tests ====================
|
||||
|
||||
def test_truncate_by_turns_no_limit(self):
|
||||
"""Test truncate_by_turns with -1 (no limit)."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_turns(messages, keep_most_recent_turns=-1)
|
||||
assert len(result) == 20
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_turns_basic(self):
|
||||
"""Test basic truncate_by_turns functionality."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 10 messages = 5 turns (user/assistant pairs)
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=3, drop_turns=1
|
||||
)
|
||||
|
||||
# Should keep 3 most recent turns (6 messages)
|
||||
assert len(result) <= 8 # (3-1+1)*2 = 6, but may adjust for correct format
|
||||
|
||||
def test_truncate_by_turns_with_system_message(self):
|
||||
"""Test truncate_by_turns preserves system messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10, include_system=True)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=2, drop_turns=1
|
||||
)
|
||||
|
||||
# System message should always be preserved
|
||||
assert result[0].role == "system"
|
||||
assert result[0].content == "System prompt"
|
||||
|
||||
def test_truncate_by_turns_zero_keep(self):
|
||||
"""Test truncate_by_turns with keep_most_recent_turns=0."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=0, drop_turns=1
|
||||
)
|
||||
|
||||
# Should result in empty or minimal list
|
||||
assert len(result) == 0
|
||||
|
||||
def test_truncate_by_turns_below_threshold(self):
|
||||
"""Test truncate_by_turns when messages are below threshold."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 4 messages = 2 turns
|
||||
messages = self.create_messages(4)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=5, drop_turns=1
|
||||
)
|
||||
|
||||
# No truncation should happen
|
||||
assert len(result) == 4
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_turns_exact_threshold(self):
|
||||
"""Test truncate_by_turns when messages exactly match threshold."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 6 messages = 3 turns
|
||||
messages = self.create_messages(6)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=3, drop_turns=1
|
||||
)
|
||||
|
||||
# No truncation should happen
|
||||
assert len(result) == 6
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_turns_ensures_user_first(self):
|
||||
"""Test that truncate_by_turns ensures user message comes first."""
|
||||
truncator = ContextTruncator()
|
||||
# Create scenario where truncation might start with assistant
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=3, drop_turns=1
|
||||
)
|
||||
|
||||
# First non-system message should be user
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_turns_multiple_drop(self):
|
||||
"""Test truncate_by_turns with multiple turns dropped at once."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=5, drop_turns=3
|
||||
)
|
||||
|
||||
# Should drop 3 turns when limit exceeded
|
||||
assert len(result) < len(messages)
|
||||
|
||||
# ==================== truncate_by_dropping_oldest_turns Tests ====================
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_zero(self):
|
||||
"""Test truncate_by_dropping_oldest_turns with drop_turns=0."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=0)
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_negative(self):
|
||||
"""Test truncate_by_dropping_oldest_turns with negative drop_turns."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=-1)
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_basic(self):
|
||||
"""Test basic truncate_by_dropping_oldest_turns functionality."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 10 messages = 5 turns
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
|
||||
|
||||
# Should drop 2 oldest turns (4 messages)
|
||||
assert len(result) == 6
|
||||
# Should start with user message
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_with_system(self):
|
||||
"""Test truncate_by_dropping_oldest_turns preserves system messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10, include_system=True)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
|
||||
|
||||
# System message should be preserved
|
||||
assert result[0].role == "system"
|
||||
assert result[0].content == "System prompt"
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_drop_all(self):
|
||||
"""Test truncate_by_dropping_oldest_turns dropping all turns."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 4 messages = 2 turns
|
||||
messages = self.create_messages(4)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
|
||||
|
||||
# Should drop all turns
|
||||
assert len(result) == 0
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self):
|
||||
"""Test truncate_by_dropping_oldest_turns with drop_turns > available turns."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 4 messages = 2 turns
|
||||
messages = self.create_messages(4)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=5)
|
||||
|
||||
# Should result in empty list
|
||||
assert len(result) == 0
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_ensures_user_first(self):
|
||||
"""Test that result starts with user message after dropping."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=3)
|
||||
|
||||
# First message should be user
|
||||
if len(result) > 0:
|
||||
assert result[0].role == "user"
|
||||
|
||||
# ==================== truncate_by_halving Tests ====================
|
||||
|
||||
def test_truncate_by_halving_empty(self):
|
||||
"""Test truncate_by_halving with empty list."""
|
||||
truncator = ContextTruncator()
|
||||
result = truncator.truncate_by_halving([])
|
||||
assert result == []
|
||||
|
||||
def test_truncate_by_halving_single_message(self):
|
||||
"""Test truncate_by_halving with single message."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
# Should not truncate if <= 2 messages
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_halving_two_messages(self):
|
||||
"""Test truncate_by_halving with two messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(2)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
# Should not truncate if <= 2 messages
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_halving_basic(self):
|
||||
"""Test basic truncate_by_halving functionality."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 20 messages
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# Should delete 50% = 10 messages, keep 10
|
||||
assert len(result) == 10
|
||||
# First message should be user
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_halving_with_system_message(self):
|
||||
"""Test truncate_by_halving preserves system messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(20, include_system=True)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# System message should be preserved
|
||||
assert result[0].role == "system"
|
||||
assert result[0].content == "System prompt"
|
||||
|
||||
def test_truncate_by_halving_odd_count(self):
|
||||
"""Test truncate_by_halving with odd number of messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(11)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# Should delete floor(11/2) = 5 messages, keep 6
|
||||
# But after ensuring user first, may be 5
|
||||
assert len(result) >= 5
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_halving_ensures_user_first(self):
|
||||
"""Test that result starts with user message."""
|
||||
truncator = ContextTruncator()
|
||||
# Create messages starting with user
|
||||
messages = self.create_messages(30)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# First message should be user
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_halving_preserves_recent_messages(self):
|
||||
"""Test that truncate_by_halving keeps the most recent 50%."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Message 0"),
|
||||
self.create_message("assistant", "Message 1"),
|
||||
self.create_message("user", "Message 2"),
|
||||
self.create_message("assistant", "Message 3"),
|
||||
]
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# Should keep last 2 messages
|
||||
assert len(result) == 2
|
||||
assert result[0].content == "Message 2"
|
||||
assert result[1].content == "Message 3"
|
||||
|
||||
# ==================== Integration Tests ====================
|
||||
|
||||
def test_truncate_with_tool_messages(self):
|
||||
"""Test truncation with tool messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool result"),
|
||||
self.create_message("user", "Thanks"),
|
||||
self.create_message("assistant", "Welcome"),
|
||||
]
|
||||
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=1)
|
||||
|
||||
# First turn (user+assistant+tool) should be dropped
|
||||
# Tool message should be cleaned up by fix_messages
|
||||
assert len(result) <= 2
|
||||
|
||||
def test_chain_multiple_truncations(self):
|
||||
"""Test chaining multiple truncation methods."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(40, include_system=True)
|
||||
|
||||
# First: truncate by turns
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=10, drop_turns=2
|
||||
)
|
||||
# Then: halve
|
||||
result = truncator.truncate_by_halving(result)
|
||||
|
||||
# Should have system message + truncated content
|
||||
assert result[0].role == "system"
|
||||
assert len(result) < len(messages)
|
||||
|
||||
def test_empty_after_system_message(self):
|
||||
"""Test truncation when only system message exists."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [self.create_message("system", "System prompt")]
|
||||
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=5, drop_turns=1
|
||||
)
|
||||
|
||||
# Should keep system message
|
||||
assert len(result) == 1
|
||||
assert result[0].role == "system"
|
||||
|
||||
def test_all_system_messages(self):
|
||||
"""Test truncation with only system messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("system", "System 1"),
|
||||
self.create_message("system", "System 2"),
|
||||
]
|
||||
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=0, drop_turns=1
|
||||
)
|
||||
|
||||
# System messages should be preserved, but since there are no non-system
|
||||
# messages and keep_most_recent_turns=0, result should be system messages only
|
||||
assert len(result) >= 0 # May keep system messages or clear all
|
||||
if len(result) > 0:
|
||||
assert all(msg.role == "system" for msg in result)
|
||||
Reference in New Issue
Block a user