feat: implement context compression logic with dynamic threshold and token tracking
This commit is contained in:
@@ -25,6 +25,21 @@ class ContextCompressor(Protocol):
|
||||
Provides an interface for compressing message lists.
|
||||
"""
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens for the model.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
"""Compress the message list.
|
||||
|
||||
@@ -42,13 +57,33 @@ class TruncateByTurnsCompressor:
|
||||
Truncates the message list by removing older turns.
|
||||
"""
|
||||
|
||||
def __init__(self, truncate_turns: int = 1):
|
||||
def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
|
||||
"""Initialize the truncate by turns compressor.
|
||||
|
||||
Args:
|
||||
truncate_turns: The number of turns to remove when truncating (default: 1).
|
||||
compression_threshold: The compression trigger threshold (default: 0.82).
|
||||
"""
|
||||
self.truncate_turns = truncate_turns
|
||||
self.compression_threshold = compression_threshold
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
if max_tokens <= 0 or current_tokens <= 0:
|
||||
return False
|
||||
usage_rate = current_tokens / max_tokens
|
||||
return usage_rate > self.compression_threshold
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
truncator = ContextTruncator()
|
||||
@@ -116,15 +151,19 @@ class LLMSummaryCompressor:
|
||||
provider: "Provider",
|
||||
keep_recent: int = 4,
|
||||
instruction_text: str | None = None,
|
||||
compression_threshold: float = 0.82,
|
||||
):
|
||||
"""Initialize the LLM summary compressor.
|
||||
|
||||
Args:
|
||||
provider: The LLM provider instance.
|
||||
keep_recent: The number of latest messages to keep (default: 4).
|
||||
instruction_text: Custom instruction for summary generation.
|
||||
compression_threshold: The compression trigger threshold (default: 0.82).
|
||||
"""
|
||||
self.provider = provider
|
||||
self.keep_recent = keep_recent
|
||||
self.compression_threshold = compression_threshold
|
||||
|
||||
self.instruction_text = instruction_text or (
|
||||
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
|
||||
@@ -134,6 +173,24 @@ class LLMSummaryCompressor:
|
||||
"4. Write the summary in the user's language.\n"
|
||||
)
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
if max_tokens <= 0 or current_tokens <= 0:
|
||||
return False
|
||||
usage_rate = current_tokens / max_tokens
|
||||
return usage_rate > self.compression_threshold
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
"""Use LLM to generate a summary of the conversation history.
|
||||
|
||||
@@ -170,9 +227,15 @@ class LLMSummaryCompressor:
|
||||
|
||||
result.append(
|
||||
Message(
|
||||
role="system",
|
||||
content=f"History conversation summary: {summary_content}",
|
||||
),
|
||||
role="user",
|
||||
content=f"Our previous history conversation summary: {summary_content}",
|
||||
)
|
||||
)
|
||||
result.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content="Acknowledged the summary of our previous conversation history.",
|
||||
)
|
||||
)
|
||||
|
||||
result.extend(recent_messages)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .compressor import ContextCompressor
|
||||
from .token_counter import TokenCounter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
@@ -26,3 +29,7 @@ class ContextConfig:
|
||||
"""Number of recent messages to keep during LLM-based compression."""
|
||||
llm_compress_provider: "Provider | None" = None
|
||||
"""LLM provider used for compression tasks. If None, truncation strategy is used."""
|
||||
custom_token_counter: TokenCounter | None = None
|
||||
"""Custom token counting method. If None, the default method is used."""
|
||||
custom_compressor: ContextCompressor | None = None
|
||||
"""Custom context compression method. If None, the default method is used."""
|
||||
|
||||
@@ -3,16 +3,13 @@ from astrbot import logger
|
||||
from ..message import Message
|
||||
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
|
||||
from .config import ContextConfig
|
||||
from .token_counter import TokenCounter
|
||||
from .token_counter import EstimateTokenCounter
|
||||
from .truncator import ContextTruncator
|
||||
|
||||
|
||||
class ContextManager:
|
||||
"""Context compression manager."""
|
||||
|
||||
COMPRESSION_THRESHOLD = 0.82
|
||||
"""compression trigger threshold"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ContextConfig,
|
||||
@@ -28,10 +25,12 @@ class ContextManager:
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
self.token_counter = TokenCounter()
|
||||
self.token_counter = config.custom_token_counter or EstimateTokenCounter()
|
||||
self.truncator = ContextTruncator()
|
||||
|
||||
if config.llm_compress_provider:
|
||||
if config.custom_compressor:
|
||||
self.compressor = config.custom_compressor
|
||||
elif config.llm_compress_provider:
|
||||
self.compressor = LLMSummaryCompressor(
|
||||
provider=config.llm_compress_provider,
|
||||
keep_recent=config.llm_compress_keep_recent,
|
||||
@@ -64,60 +63,32 @@ class ContextManager:
|
||||
|
||||
# 2. 基于 token 的压缩
|
||||
if self.config.max_context_tokens > 0:
|
||||
# check if the messages need to be compressed
|
||||
needs_compression, tokens = await self._initial_token_check(result)
|
||||
total_tokens = self.token_counter.count_tokens(result)
|
||||
|
||||
# compress/truncate the messages if needed
|
||||
result = await self._run_compression(result, tokens, needs_compression)
|
||||
if self.compressor.should_compress(
|
||||
result, total_tokens, self.config.max_context_tokens
|
||||
):
|
||||
result = await self._run_compression(result, total_tokens)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error during context processing: {e}", exc_info=True)
|
||||
return messages
|
||||
|
||||
async def _initial_token_check(self, messages: list[Message]) -> tuple[bool, int]:
|
||||
"""
|
||||
Check if the messages need to be compressed.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
|
||||
Returns:
|
||||
tuple: (whether to compress, initial token count)
|
||||
"""
|
||||
if not messages:
|
||||
return False, 0
|
||||
if self.config.max_context_tokens <= 0:
|
||||
return False, 0
|
||||
|
||||
total_tokens = self.token_counter.count_tokens(messages)
|
||||
|
||||
usage_rate = total_tokens / self.config.max_context_tokens
|
||||
|
||||
needs_compression = usage_rate > self.COMPRESSION_THRESHOLD
|
||||
return needs_compression, total_tokens if needs_compression else 0
|
||||
|
||||
async def _run_compression(
|
||||
self, messages: list[Message], prev_tokens: int, needs_compression: bool
|
||||
self, messages: list[Message], prev_tokens: int
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Compress/truncate the messages if needed.
|
||||
Compress/truncate the messages.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
needs_compression: Whether to compress.
|
||||
prev_tokens: The token count before compression.
|
||||
|
||||
Returns:
|
||||
The compressed/truncated message list.
|
||||
"""
|
||||
if not needs_compression:
|
||||
return messages
|
||||
if self.config.max_context_tokens <= 0:
|
||||
return messages
|
||||
|
||||
logger.debug(
|
||||
f"Reached high water mark {self.COMPRESSION_THRESHOLD}, starting compression..."
|
||||
)
|
||||
logger.debug("Compress triggered, starting compression...")
|
||||
|
||||
messages = await self.compressor(messages)
|
||||
|
||||
@@ -132,23 +103,14 @@ class ContextManager:
|
||||
f" compression rate: {compress_rate:.2f}%.",
|
||||
)
|
||||
|
||||
if (
|
||||
tokens_after_summary / self.config.max_context_tokens
|
||||
> self.COMPRESSION_THRESHOLD
|
||||
# last check
|
||||
if self.compressor.should_compress(
|
||||
messages, tokens_after_summary, self.config.max_context_tokens
|
||||
):
|
||||
# still over 82%, truncate by half
|
||||
messages = self._compress_by_halving(messages)
|
||||
logger.info(
|
||||
"Context still exceeds max tokens after compression, applying halving truncation..."
|
||||
)
|
||||
# still need compress, truncate by half
|
||||
messages = self.truncator.truncate_by_halving(messages)
|
||||
|
||||
return messages
|
||||
|
||||
def _compress_by_halving(self, messages: list[Message]) -> list[Message]:
|
||||
"""
|
||||
对半砍策略:删除中间50%的消息
|
||||
|
||||
Args:
|
||||
messages: 原始消息列表
|
||||
|
||||
Returns:
|
||||
截断后的消息列表
|
||||
"""
|
||||
return self.truncator.truncate_by_halving(messages)
|
||||
|
||||
@@ -1,9 +1,33 @@
|
||||
import json
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from ..message import Message, TextPart
|
||||
|
||||
|
||||
class TokenCounter:
|
||||
@runtime_checkable
|
||||
class TokenCounter(Protocol):
|
||||
"""
|
||||
Protocol for token counters.
|
||||
Provides an interface for counting tokens in message lists.
|
||||
"""
|
||||
|
||||
def count_tokens(self, messages: list[Message]) -> int:
|
||||
"""Count the total tokens in the message list.
|
||||
|
||||
Args:
|
||||
messages: The message list.
|
||||
|
||||
Returns:
|
||||
The total token count.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class EstimateTokenCounter:
|
||||
"""Estimate token counter implementation.
|
||||
Provides a simple estimation of token count based on character types.
|
||||
"""
|
||||
|
||||
def count_tokens(self, messages: list[Message]) -> int:
|
||||
total = 0
|
||||
for msg in messages:
|
||||
|
||||
@@ -25,6 +25,8 @@ from astrbot.core.provider.entities import (
|
||||
)
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.compressor import ContextCompressor
|
||||
from ..context.token_counter import TokenCounter
|
||||
from ..context.config import ContextConfig
|
||||
from ..context.manager import ContextManager
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
@@ -49,24 +51,30 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
streaming: bool = False,
|
||||
# enforce max turns, will discard older turns when exceeded BEFORE compression
|
||||
# -1 means no limit
|
||||
enforce_max_turns: int = -1,
|
||||
# llm compressor
|
||||
llm_compress_instruction: str | None = None,
|
||||
llm_compress_keep_recent: int = 0,
|
||||
llm_compress_provider: Provider | None = None,
|
||||
# truncate by turns compressor
|
||||
truncate_turns: int = 1,
|
||||
# customize
|
||||
custom_token_counter: TokenCounter | None = None,
|
||||
custom_compressor: ContextCompressor | None = None,
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
|
||||
# enforce max turns, will discard older turns when exceeded BEFORE compression
|
||||
# -1 means no limit
|
||||
self.enforce_max_turns = kwargs.get("enforce_max_turns", -1)
|
||||
|
||||
# llm compressor
|
||||
self.llm_compress_instruction = kwargs.get("llm_compress_instruction", None)
|
||||
self.llm_compress_keep_recent = kwargs.get("llm_compress_keep_recent", 0)
|
||||
self.llm_compress_provider: Provider | None = kwargs.get(
|
||||
"llm_compress_provider", None
|
||||
)
|
||||
# truncate by turns compressor
|
||||
self.truncate_turns = kwargs.get("truncate_turns", 1)
|
||||
|
||||
self.streaming = streaming
|
||||
self.enforce_max_turns = enforce_max_turns
|
||||
self.llm_compress_instruction = llm_compress_instruction
|
||||
self.llm_compress_keep_recent = llm_compress_keep_recent
|
||||
self.llm_compress_provider = llm_compress_provider
|
||||
self.truncate_turns = truncate_turns
|
||||
self.custom_token_counter = custom_token_counter
|
||||
self.custom_compressor = custom_compressor
|
||||
# we will do compress when:
|
||||
# 1. before requesting LLM
|
||||
# TODO: 2. after LLM output a tool call
|
||||
@@ -79,6 +87,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider=self.llm_compress_provider,
|
||||
custom_token_counter=self.custom_token_counter,
|
||||
custom_compressor=self.custom_compressor,
|
||||
)
|
||||
self.context_manager = ContextManager(self.context_config)
|
||||
|
||||
|
||||
@@ -149,9 +149,12 @@ class Context:
|
||||
contexts: context messages for the LLM
|
||||
max_steps: Maximum number of tool calls before stopping the loop
|
||||
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
|
||||
stream: bool - whether to stream the LLM response
|
||||
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
|
||||
agent_context: AstrAgentContext - context to use for the agent
|
||||
|
||||
other kwargs will be DIRECTLY passed to the runner.reset() method
|
||||
|
||||
Returns:
|
||||
The final LLMResponse after tool calls are completed.
|
||||
|
||||
@@ -194,6 +197,15 @@ class Context:
|
||||
)
|
||||
agent_runner = ToolLoopAgentRunner()
|
||||
tool_executor = FunctionToolExecutor()
|
||||
|
||||
streaming = kwargs.get("stream", False)
|
||||
|
||||
other_kwargs = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k not in ["stream", "agent_hooks", "agent_context"]
|
||||
}
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=prov,
|
||||
request=request,
|
||||
@@ -203,7 +215,8 @@ class Context:
|
||||
),
|
||||
tool_executor=tool_executor,
|
||||
agent_hooks=agent_hooks,
|
||||
streaming=kwargs.get("stream", False),
|
||||
streaming=streaming,
|
||||
**other_kwargs,
|
||||
)
|
||||
async for _ in agent_runner.step_until_done(max_steps):
|
||||
pass
|
||||
|
||||
+113
-112
@@ -197,13 +197,18 @@ class TestContextManager:
|
||||
messages = [self.create_message("user", "Hi" * 50)] # ~100 tokens
|
||||
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager.process(messages)
|
||||
manager.compressor, "should_compress", return_value=False
|
||||
) as mock_should_compress:
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should not be called
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
# should_compress should be called
|
||||
mock_should_compress.assert_called_once()
|
||||
# Compressor should not be called
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_triggered_above_threshold(self):
|
||||
@@ -219,16 +224,28 @@ class TestContextManager:
|
||||
# Mock compressor to return smaller result
|
||||
compressed = [self.create_message("user", "short")]
|
||||
|
||||
# Create a mock compressor that we can track
|
||||
mock_compress = AsyncMock(return_value=compressed)
|
||||
manager.compressor = mock_compress
|
||||
# Create a mock compressor
|
||||
mock_compressor = AsyncMock()
|
||||
mock_compressor.compression_threshold = 0.82
|
||||
mock_compressor.return_value = compressed
|
||||
|
||||
# Mock should_compress to return True first time, False after
|
||||
call_count = 0
|
||||
|
||||
def mock_should_compress(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return call_count == 1
|
||||
|
||||
mock_compressor.should_compress = mock_should_compress
|
||||
manager.compressor = mock_compressor
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should be called
|
||||
mock_compress.assert_called_once()
|
||||
mock_compressor.assert_called_once()
|
||||
# Result should be the compressed version
|
||||
assert len(result) == len(compressed)
|
||||
assert len(result) <= len(messages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_with_zero_max_tokens(self):
|
||||
@@ -243,7 +260,7 @@ class TestContextManager:
|
||||
) as mock_compress:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should not be called
|
||||
# Compressor should not be called when max_context_tokens is 0
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
|
||||
@@ -277,14 +294,18 @@ class TestContextManager:
|
||||
async def mock_compress(msgs):
|
||||
return msgs # Return same messages (still over limit)
|
||||
|
||||
with patch.object(manager.compressor, "__call__", new=mock_compress):
|
||||
with patch.object(
|
||||
manager, "_compress_by_halving", return_value=long_messages[:5]
|
||||
) as mock_halving:
|
||||
_ = await manager.process(long_messages)
|
||||
# Mock should_compress to return True twice (before and after compression)
|
||||
with patch.object(manager.compressor, "should_compress", return_value=True):
|
||||
with patch.object(manager.compressor, "__call__", new=mock_compress):
|
||||
with patch.object(
|
||||
manager.truncator,
|
||||
"truncate_by_halving",
|
||||
return_value=long_messages[:5],
|
||||
) as mock_halving:
|
||||
_ = await manager.process(long_messages)
|
||||
|
||||
# Halving should be called
|
||||
mock_halving.assert_called_once()
|
||||
# Halving should be called
|
||||
mock_halving.assert_called_once()
|
||||
|
||||
# ==================== Combined Truncation and Compression Tests ====================
|
||||
|
||||
@@ -352,7 +373,10 @@ class TestContextManager:
|
||||
messages = [self.create_message("user", "x" * 300)] # ~90 tokens
|
||||
|
||||
# Replace compressor with one that raises an exception
|
||||
manager.compressor = AsyncMock(side_effect=Exception("Test error"))
|
||||
mock_compressor = AsyncMock(side_effect=Exception("Test error"))
|
||||
mock_compressor.compression_threshold = 0.82
|
||||
mock_compressor.should_compress = MagicMock(return_value=True)
|
||||
manager.compressor = mock_compressor
|
||||
|
||||
with patch("astrbot.core.agent.context.manager.logger") as mock_logger:
|
||||
result = await manager.process(messages)
|
||||
@@ -393,9 +417,10 @@ class TestContextManager:
|
||||
]
|
||||
|
||||
# Should trigger compression due to token count
|
||||
needs_compression, tokens = await manager._initial_token_check(messages)
|
||||
tokens = manager.token_counter.count_tokens(messages)
|
||||
needs_compression = manager.compressor.should_compress(messages, tokens, 50)
|
||||
|
||||
assert tokens is not None # Tokens should be counted
|
||||
assert tokens > 0 # Tokens should be counted
|
||||
assert needs_compression # Should trigger compression
|
||||
|
||||
# ==================== Tool Calls Tests ====================
|
||||
@@ -425,101 +450,73 @@ class TestContextManager:
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
# ==================== Initial Token Check Tests ====================
|
||||
# ==================== Compressor should_compress Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_token_check_empty_messages(self):
|
||||
"""Test _initial_token_check with empty messages."""
|
||||
async def test_should_compress_empty_messages(self):
|
||||
"""Test should_compress with empty messages."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
needs_compression, tokens = await manager._initial_token_check([])
|
||||
|
||||
# Compressor's should_compress should handle empty gracefully
|
||||
needs_compression = manager.compressor.should_compress([], 0, 100)
|
||||
assert not needs_compression
|
||||
assert tokens is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_token_check_no_limit(self):
|
||||
"""Test _initial_token_check when max_context_tokens is 0."""
|
||||
config = ContextConfig(max_context_tokens=0)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "x" * 1000)]
|
||||
needs_compression, tokens = await manager._initial_token_check(messages)
|
||||
|
||||
assert not needs_compression
|
||||
assert tokens is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_token_check_below_threshold(self):
|
||||
"""Test _initial_token_check when below compression threshold."""
|
||||
async def test_should_compress_below_threshold(self):
|
||||
"""Test should_compress when below compression threshold."""
|
||||
config = ContextConfig(max_context_tokens=1000)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
needs_compression, tokens = await manager._initial_token_check(messages)
|
||||
tokens = manager.token_counter.count_tokens(messages)
|
||||
|
||||
needs_compression = manager.compressor.should_compress(messages, tokens, 1000)
|
||||
assert not needs_compression
|
||||
assert tokens is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_token_check_above_threshold(self):
|
||||
"""Test _initial_token_check when above compression threshold."""
|
||||
async def test_should_compress_above_threshold(self):
|
||||
"""Test should_compress when above compression threshold."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create message with ~90 tokens (above 0.82 * 100 = 82)
|
||||
# Create message with many tokens
|
||||
messages = [self.create_message("user", "这是测试" * 50)]
|
||||
needs_compression, tokens = await manager._initial_token_check(messages)
|
||||
tokens = manager.token_counter.count_tokens(messages)
|
||||
|
||||
assert needs_compression
|
||||
assert tokens is not None
|
||||
assert tokens > 82
|
||||
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
|
||||
# Should need compression if tokens > 82 (0.82 * 100)
|
||||
assert needs_compression == (tokens > 82)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_token_check_exactly_at_threshold(self):
|
||||
"""Test _initial_token_check when just above threshold."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
# ==================== Truncator Halving Tests ====================
|
||||
|
||||
# Create message with >82 tokens (0.82 * 100)
|
||||
# 300 chars * 0.3 = 90 tokens > 82 (threshold)
|
||||
messages = [self.create_message("user", "x" * 300)] # ~90 tokens
|
||||
needs_compression, tokens = await manager._initial_token_check(messages)
|
||||
|
||||
# Above threshold should trigger compression
|
||||
assert tokens is not None
|
||||
assert needs_compression
|
||||
|
||||
# ==================== Compression by Halving Tests ====================
|
||||
|
||||
def test_compress_by_halving_basic(self):
|
||||
"""Test _compress_by_halving removes middle 50%."""
|
||||
def test_truncate_by_halving_basic(self):
|
||||
"""Test truncate_by_halving removes middle 50%."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(10)
|
||||
result = manager._compress_by_halving(messages)
|
||||
result = manager.truncator.truncate_by_halving(messages)
|
||||
|
||||
# Should keep roughly half
|
||||
assert len(result) < len(messages)
|
||||
|
||||
def test_compress_by_halving_empty_list(self):
|
||||
"""Test _compress_by_halving with empty list."""
|
||||
def test_truncate_by_halving_empty_list(self):
|
||||
"""Test truncate_by_halving with empty list."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
result = manager._compress_by_halving([])
|
||||
result = manager.truncator.truncate_by_halving([])
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_compress_by_halving_single_message(self):
|
||||
"""Test _compress_by_halving with single message."""
|
||||
def test_truncate_by_halving_single_message(self):
|
||||
"""Test truncate_by_halving with single message."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
result = manager._compress_by_halving(messages)
|
||||
result = manager.truncator.truncate_by_halving(messages)
|
||||
|
||||
assert len(result) <= 1
|
||||
|
||||
@@ -557,19 +554,21 @@ class TestContextManager:
|
||||
assert non_system[0].role == "user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compression_threshold_constant(self):
|
||||
"""Test that COMPRESSION_THRESHOLD is used correctly."""
|
||||
async def test_compression_threshold_default(self):
|
||||
"""Test that compression threshold is used correctly."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Verify the threshold is 0.82
|
||||
assert manager.COMPRESSION_THRESHOLD == 0.82
|
||||
# Verify the default threshold is 0.82
|
||||
assert manager.compressor.compression_threshold == 0.82
|
||||
|
||||
# Create messages just below threshold
|
||||
# Test threshold logic
|
||||
messages = [self.create_message("user", "x" * 81)] # ~24 tokens
|
||||
tokens = manager.token_counter.count_tokens(messages)
|
||||
|
||||
needs_compression, _ = await manager._initial_token_check(messages)
|
||||
assert not needs_compression
|
||||
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
|
||||
# Should not compress if below threshold
|
||||
assert needs_compression == (tokens > 82)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_batch_processing(self):
|
||||
@@ -608,39 +607,29 @@ class TestContextManager:
|
||||
# ==================== Run Compression Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_compression_skip_when_not_needed(self):
|
||||
"""Test _run_compression skips when needs_compression is False."""
|
||||
async def test_run_compression_calls_compressor(self):
|
||||
"""Test _run_compression calls compressor."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(5)
|
||||
compressed = self.create_messages(3)
|
||||
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager._run_compression(messages, needs_compression=False)
|
||||
# Create a mock compressor
|
||||
mock_compressor = AsyncMock()
|
||||
mock_compressor.compression_threshold = 0.82
|
||||
mock_compressor.return_value = compressed
|
||||
mock_compressor.should_compress = MagicMock(return_value=False)
|
||||
manager.compressor = mock_compressor
|
||||
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
result = await manager._run_compression(messages, prev_tokens=100)
|
||||
|
||||
# Compressor __call__ should be invoked
|
||||
mock_compressor.assert_called_once_with(messages)
|
||||
assert result == compressed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_compression_skip_when_zero_limit(self):
|
||||
"""Test _run_compression skips when max_context_tokens is 0."""
|
||||
config = ContextConfig(max_context_tokens=0)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(5)
|
||||
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager._run_compression(messages, needs_compression=True)
|
||||
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_compression_applies_compressor(self):
|
||||
async def test_run_compression_applies_compressor_through_process(self):
|
||||
"""Test _run_compression calls compressor when needed through process()."""
|
||||
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
@@ -649,14 +638,26 @@ class TestContextManager:
|
||||
messages = [self.create_message("user", "x" * 300)] # ~90 tokens > 82 threshold
|
||||
compressed = [self.create_message("user", "short")] # Much smaller
|
||||
|
||||
# Replace compressor with mock
|
||||
mock_compress = AsyncMock(return_value=compressed)
|
||||
manager.compressor = mock_compress
|
||||
# Create a mock compressor
|
||||
mock_compressor = AsyncMock()
|
||||
mock_compressor.compression_threshold = 0.82
|
||||
mock_compressor.return_value = compressed
|
||||
|
||||
# Mock should_compress to return True first time, False after
|
||||
call_count = 0
|
||||
|
||||
def mock_should_compress(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return call_count == 1
|
||||
|
||||
mock_compressor.should_compress = mock_should_compress
|
||||
manager.compressor = mock_compressor
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should have been called
|
||||
mock_compress.assert_called_once()
|
||||
mock_compressor.assert_called_once()
|
||||
assert len(result) <= len(messages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user