feat: implement context compression logic with dynamic threshold and token tracking

This commit is contained in:
Soulter
2026-01-05 14:12:13 +08:00
parent cb84db532e
commit af444ea6cc
7 changed files with 273 additions and 193 deletions
+67 -4
View File
@@ -25,6 +25,21 @@ class ContextCompressor(Protocol):
Provides an interface for compressing message lists.
"""
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens for the model.
Returns:
True if compression is needed, False otherwise.
"""
...
async def __call__(self, messages: list[Message]) -> list[Message]:
"""Compress the message list.
@@ -42,13 +57,33 @@ class TruncateByTurnsCompressor:
Truncates the message list by removing older turns.
"""
def __init__(self, truncate_turns: int = 1):
def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
"""Initialize the truncate by turns compressor.
Args:
truncate_turns: The number of turns to remove when truncating (default: 1).
compression_threshold: The compression trigger threshold (default: 0.82).
"""
self.truncate_turns = truncate_turns
self.compression_threshold = compression_threshold
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens.
Returns:
True if compression is needed, False otherwise.
"""
if max_tokens <= 0 or current_tokens <= 0:
return False
usage_rate = current_tokens / max_tokens
return usage_rate > self.compression_threshold
async def __call__(self, messages: list[Message]) -> list[Message]:
truncator = ContextTruncator()
@@ -116,15 +151,19 @@ class LLMSummaryCompressor:
provider: "Provider",
keep_recent: int = 4,
instruction_text: str | None = None,
compression_threshold: float = 0.82,
):
"""Initialize the LLM summary compressor.
Args:
provider: The LLM provider instance.
keep_recent: The number of latest messages to keep (default: 4).
instruction_text: Custom instruction for summary generation.
compression_threshold: The compression trigger threshold (default: 0.82).
"""
self.provider = provider
self.keep_recent = keep_recent
self.compression_threshold = compression_threshold
self.instruction_text = instruction_text or (
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
@@ -134,6 +173,24 @@ class LLMSummaryCompressor:
"4. Write the summary in the user's language.\n"
)
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens.
Returns:
True if compression is needed, False otherwise.
"""
if max_tokens <= 0 or current_tokens <= 0:
return False
usage_rate = current_tokens / max_tokens
return usage_rate > self.compression_threshold
async def __call__(self, messages: list[Message]) -> list[Message]:
"""Use LLM to generate a summary of the conversation history.
@@ -170,9 +227,15 @@ class LLMSummaryCompressor:
result.append(
Message(
role="system",
content=f"History conversation summary: {summary_content}",
),
role="user",
content=f"Our previous history conversation summary: {summary_content}",
)
)
result.append(
Message(
role="assistant",
content="Acknowledged the summary of our previous conversation history.",
)
)
result.extend(recent_messages)
+7
View File
@@ -1,6 +1,9 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING
from .compressor import ContextCompressor
from .token_counter import TokenCounter
if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider
@@ -26,3 +29,7 @@ class ContextConfig:
"""Number of recent messages to keep during LLM-based compression."""
llm_compress_provider: "Provider | None" = None
"""LLM provider used for compression tasks. If None, truncation strategy is used."""
custom_token_counter: TokenCounter | None = None
"""Custom token counting method. If None, the default method is used."""
custom_compressor: ContextCompressor | None = None
"""Custom context compression method. If None, the default method is used."""
+22 -60
View File
@@ -3,16 +3,13 @@ from astrbot import logger
from ..message import Message
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
from .config import ContextConfig
from .token_counter import TokenCounter
from .token_counter import EstimateTokenCounter
from .truncator import ContextTruncator
class ContextManager:
"""Context compression manager."""
COMPRESSION_THRESHOLD = 0.82
"""compression trigger threshold"""
def __init__(
self,
config: ContextConfig,
@@ -28,10 +25,12 @@ class ContextManager:
"""
self.config = config
self.token_counter = TokenCounter()
self.token_counter = config.custom_token_counter or EstimateTokenCounter()
self.truncator = ContextTruncator()
if config.llm_compress_provider:
if config.custom_compressor:
self.compressor = config.custom_compressor
elif config.llm_compress_provider:
self.compressor = LLMSummaryCompressor(
provider=config.llm_compress_provider,
keep_recent=config.llm_compress_keep_recent,
@@ -64,60 +63,32 @@ class ContextManager:
# 2. 基于 token 的压缩
if self.config.max_context_tokens > 0:
# check if the messages need to be compressed
needs_compression, tokens = await self._initial_token_check(result)
total_tokens = self.token_counter.count_tokens(result)
# compress/truncate the messages if needed
result = await self._run_compression(result, tokens, needs_compression)
if self.compressor.should_compress(
result, total_tokens, self.config.max_context_tokens
):
result = await self._run_compression(result, total_tokens)
return result
except Exception as e:
logger.error(f"Error during context processing: {e}", exc_info=True)
return messages
async def _initial_token_check(self, messages: list[Message]) -> tuple[bool, int]:
"""
Check if the messages need to be compressed.
Args:
messages: The original message list.
Returns:
tuple: (whether to compress, initial token count)
"""
if not messages:
return False, 0
if self.config.max_context_tokens <= 0:
return False, 0
total_tokens = self.token_counter.count_tokens(messages)
usage_rate = total_tokens / self.config.max_context_tokens
needs_compression = usage_rate > self.COMPRESSION_THRESHOLD
return needs_compression, total_tokens if needs_compression else 0
async def _run_compression(
self, messages: list[Message], prev_tokens: int, needs_compression: bool
self, messages: list[Message], prev_tokens: int
) -> list[Message]:
"""
Compress/truncate the messages if needed.
Compress/truncate the messages.
Args:
messages: The original message list.
needs_compression: Whether to compress.
prev_tokens: The token count before compression.
Returns:
The compressed/truncated message list.
"""
if not needs_compression:
return messages
if self.config.max_context_tokens <= 0:
return messages
logger.debug(
f"Reached high water mark {self.COMPRESSION_THRESHOLD}, starting compression..."
)
logger.debug("Compress triggered, starting compression...")
messages = await self.compressor(messages)
@@ -132,23 +103,14 @@ class ContextManager:
f" compression rate: {compress_rate:.2f}%.",
)
if (
tokens_after_summary / self.config.max_context_tokens
> self.COMPRESSION_THRESHOLD
# last check
if self.compressor.should_compress(
messages, tokens_after_summary, self.config.max_context_tokens
):
# still over 82%, truncate by half
messages = self._compress_by_halving(messages)
logger.info(
"Context still exceeds max tokens after compression, applying halving truncation..."
)
# still need compress, truncate by half
messages = self.truncator.truncate_by_halving(messages)
return messages
def _compress_by_halving(self, messages: list[Message]) -> list[Message]:
"""
对半砍策略:删除中间50%的消息
Args:
messages: 原始消息列表
Returns:
截断后的消息列表
"""
return self.truncator.truncate_by_halving(messages)
+25 -1
View File
@@ -1,9 +1,33 @@
import json
from typing import Protocol, runtime_checkable
from ..message import Message, TextPart
class TokenCounter:
@runtime_checkable
class TokenCounter(Protocol):
"""
Protocol for token counters.
Provides an interface for counting tokens in message lists.
"""
def count_tokens(self, messages: list[Message]) -> int:
"""Count the total tokens in the message list.
Args:
messages: The message list.
Returns:
The total token count.
"""
...
class EstimateTokenCounter:
"""Estimate token counter implementation.
Provides a simple estimation of token count based on character types.
"""
def count_tokens(self, messages: list[Message]) -> int:
total = 0
for msg in messages:
@@ -25,6 +25,8 @@ from astrbot.core.provider.entities import (
)
from astrbot.core.provider.provider import Provider
from ..context.compressor import ContextCompressor
from ..context.token_counter import TokenCounter
from ..context.config import ContextConfig
from ..context.manager import ContextManager
from ..hooks import BaseAgentRunHooks
@@ -49,24 +51,30 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
streaming: bool = False,
# enforce max turns, will discard older turns when exceeded BEFORE compression
# -1 means no limit
enforce_max_turns: int = -1,
# llm compressor
llm_compress_instruction: str | None = None,
llm_compress_keep_recent: int = 0,
llm_compress_provider: Provider | None = None,
# truncate by turns compressor
truncate_turns: int = 1,
# customize
custom_token_counter: TokenCounter | None = None,
custom_compressor: ContextCompressor | None = None,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
# enforce max turns, will discard older turns when exceeded BEFORE compression
# -1 means no limit
self.enforce_max_turns = kwargs.get("enforce_max_turns", -1)
# llm compressor
self.llm_compress_instruction = kwargs.get("llm_compress_instruction", None)
self.llm_compress_keep_recent = kwargs.get("llm_compress_keep_recent", 0)
self.llm_compress_provider: Provider | None = kwargs.get(
"llm_compress_provider", None
)
# truncate by turns compressor
self.truncate_turns = kwargs.get("truncate_turns", 1)
self.streaming = streaming
self.enforce_max_turns = enforce_max_turns
self.llm_compress_instruction = llm_compress_instruction
self.llm_compress_keep_recent = llm_compress_keep_recent
self.llm_compress_provider = llm_compress_provider
self.truncate_turns = truncate_turns
self.custom_token_counter = custom_token_counter
self.custom_compressor = custom_compressor
# we will do compress when:
# 1. before requesting LLM
# TODO: 2. after LLM output a tool call
@@ -79,6 +87,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
llm_compress_instruction=self.llm_compress_instruction,
llm_compress_keep_recent=self.llm_compress_keep_recent,
llm_compress_provider=self.llm_compress_provider,
custom_token_counter=self.custom_token_counter,
custom_compressor=self.custom_compressor,
)
self.context_manager = ContextManager(self.context_config)
+14 -1
View File
@@ -149,9 +149,12 @@ class Context:
contexts: context messages for the LLM
max_steps: Maximum number of tool calls before stopping the loop
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
stream: bool - whether to stream the LLM response
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
agent_context: AstrAgentContext - context to use for the agent
other kwargs will be DIRECTLY passed to the runner.reset() method
Returns:
The final LLMResponse after tool calls are completed.
@@ -194,6 +197,15 @@ class Context:
)
agent_runner = ToolLoopAgentRunner()
tool_executor = FunctionToolExecutor()
streaming = kwargs.get("stream", False)
other_kwargs = {
k: v
for k, v in kwargs.items()
if k not in ["stream", "agent_hooks", "agent_context"]
}
await agent_runner.reset(
provider=prov,
request=request,
@@ -203,7 +215,8 @@ class Context:
),
tool_executor=tool_executor,
agent_hooks=agent_hooks,
streaming=kwargs.get("stream", False),
streaming=streaming,
**other_kwargs,
)
async for _ in agent_runner.step_until_done(max_steps):
pass
+113 -112
View File
@@ -197,13 +197,18 @@ class TestContextManager:
messages = [self.create_message("user", "Hi" * 50)] # ~100 tokens
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager.process(messages)
manager.compressor, "should_compress", return_value=False
) as mock_should_compress:
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager.process(messages)
# Compressor should not be called
mock_compress.assert_not_called()
assert result == messages
# should_compress should be called
mock_should_compress.assert_called_once()
# Compressor should not be called
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_token_compression_triggered_above_threshold(self):
@@ -219,16 +224,28 @@ class TestContextManager:
# Mock compressor to return smaller result
compressed = [self.create_message("user", "short")]
# Create a mock compressor that we can track
mock_compress = AsyncMock(return_value=compressed)
manager.compressor = mock_compress
# Create a mock compressor
mock_compressor = AsyncMock()
mock_compressor.compression_threshold = 0.82
mock_compressor.return_value = compressed
# Mock should_compress to return True first time, False after
call_count = 0
def mock_should_compress(*args, **kwargs):
nonlocal call_count
call_count += 1
return call_count == 1
mock_compressor.should_compress = mock_should_compress
manager.compressor = mock_compressor
result = await manager.process(messages)
# Compressor should be called
mock_compress.assert_called_once()
mock_compressor.assert_called_once()
# Result should be the compressed version
assert len(result) == len(compressed)
assert len(result) <= len(messages)
@pytest.mark.asyncio
async def test_token_compression_with_zero_max_tokens(self):
@@ -243,7 +260,7 @@ class TestContextManager:
) as mock_compress:
result = await manager.process(messages)
# Compressor should not be called
# Compressor should not be called when max_context_tokens is 0
mock_compress.assert_not_called()
assert result == messages
@@ -277,14 +294,18 @@ class TestContextManager:
async def mock_compress(msgs):
return msgs # Return same messages (still over limit)
with patch.object(manager.compressor, "__call__", new=mock_compress):
with patch.object(
manager, "_compress_by_halving", return_value=long_messages[:5]
) as mock_halving:
_ = await manager.process(long_messages)
# Mock should_compress to return True twice (before and after compression)
with patch.object(manager.compressor, "should_compress", return_value=True):
with patch.object(manager.compressor, "__call__", new=mock_compress):
with patch.object(
manager.truncator,
"truncate_by_halving",
return_value=long_messages[:5],
) as mock_halving:
_ = await manager.process(long_messages)
# Halving should be called
mock_halving.assert_called_once()
# Halving should be called
mock_halving.assert_called_once()
# ==================== Combined Truncation and Compression Tests ====================
@@ -352,7 +373,10 @@ class TestContextManager:
messages = [self.create_message("user", "x" * 300)] # ~90 tokens
# Replace compressor with one that raises an exception
manager.compressor = AsyncMock(side_effect=Exception("Test error"))
mock_compressor = AsyncMock(side_effect=Exception("Test error"))
mock_compressor.compression_threshold = 0.82
mock_compressor.should_compress = MagicMock(return_value=True)
manager.compressor = mock_compressor
with patch("astrbot.core.agent.context.manager.logger") as mock_logger:
result = await manager.process(messages)
@@ -393,9 +417,10 @@ class TestContextManager:
]
# Should trigger compression due to token count
needs_compression, tokens = await manager._initial_token_check(messages)
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 50)
assert tokens is not None # Tokens should be counted
assert tokens > 0 # Tokens should be counted
assert needs_compression # Should trigger compression
# ==================== Tool Calls Tests ====================
@@ -425,101 +450,73 @@ class TestContextManager:
assert len(result) == 2
# ==================== Initial Token Check Tests ====================
# ==================== Compressor should_compress Tests ====================
@pytest.mark.asyncio
async def test_initial_token_check_empty_messages(self):
"""Test _initial_token_check with empty messages."""
async def test_should_compress_empty_messages(self):
"""Test should_compress with empty messages."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
needs_compression, tokens = await manager._initial_token_check([])
# Compressor's should_compress should handle empty gracefully
needs_compression = manager.compressor.should_compress([], 0, 100)
assert not needs_compression
assert tokens is None
@pytest.mark.asyncio
async def test_initial_token_check_no_limit(self):
"""Test _initial_token_check when max_context_tokens is 0."""
config = ContextConfig(max_context_tokens=0)
manager = ContextManager(config)
messages = [self.create_message("user", "x" * 1000)]
needs_compression, tokens = await manager._initial_token_check(messages)
assert not needs_compression
assert tokens is None
@pytest.mark.asyncio
async def test_initial_token_check_below_threshold(self):
"""Test _initial_token_check when below compression threshold."""
async def test_should_compress_below_threshold(self):
"""Test should_compress when below compression threshold."""
config = ContextConfig(max_context_tokens=1000)
manager = ContextManager(config)
messages = [self.create_message("user", "Hello")]
needs_compression, tokens = await manager._initial_token_check(messages)
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 1000)
assert not needs_compression
assert tokens is None
@pytest.mark.asyncio
async def test_initial_token_check_above_threshold(self):
"""Test _initial_token_check when above compression threshold."""
async def test_should_compress_above_threshold(self):
"""Test should_compress when above compression threshold."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Create message with ~90 tokens (above 0.82 * 100 = 82)
# Create message with many tokens
messages = [self.create_message("user", "这是测试" * 50)]
needs_compression, tokens = await manager._initial_token_check(messages)
tokens = manager.token_counter.count_tokens(messages)
assert needs_compression
assert tokens is not None
assert tokens > 82
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
# Should need compression if tokens > 82 (0.82 * 100)
assert needs_compression == (tokens > 82)
@pytest.mark.asyncio
async def test_initial_token_check_exactly_at_threshold(self):
"""Test _initial_token_check when just above threshold."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# ==================== Truncator Halving Tests ====================
# Create message with >82 tokens (0.82 * 100)
# 300 chars * 0.3 = 90 tokens > 82 (threshold)
messages = [self.create_message("user", "x" * 300)] # ~90 tokens
needs_compression, tokens = await manager._initial_token_check(messages)
# Above threshold should trigger compression
assert tokens is not None
assert needs_compression
# ==================== Compression by Halving Tests ====================
def test_compress_by_halving_basic(self):
"""Test _compress_by_halving removes middle 50%."""
def test_truncate_by_halving_basic(self):
"""Test truncate_by_halving removes middle 50%."""
config = ContextConfig()
manager = ContextManager(config)
messages = self.create_messages(10)
result = manager._compress_by_halving(messages)
result = manager.truncator.truncate_by_halving(messages)
# Should keep roughly half
assert len(result) < len(messages)
def test_compress_by_halving_empty_list(self):
"""Test _compress_by_halving with empty list."""
def test_truncate_by_halving_empty_list(self):
"""Test truncate_by_halving with empty list."""
config = ContextConfig()
manager = ContextManager(config)
result = manager._compress_by_halving([])
result = manager.truncator.truncate_by_halving([])
assert result == []
def test_compress_by_halving_single_message(self):
"""Test _compress_by_halving with single message."""
def test_truncate_by_halving_single_message(self):
"""Test truncate_by_halving with single message."""
config = ContextConfig()
manager = ContextManager(config)
messages = [self.create_message("user", "Hello")]
result = manager._compress_by_halving(messages)
result = manager.truncator.truncate_by_halving(messages)
assert len(result) <= 1
@@ -557,19 +554,21 @@ class TestContextManager:
assert non_system[0].role == "user"
@pytest.mark.asyncio
async def test_compression_threshold_constant(self):
"""Test that COMPRESSION_THRESHOLD is used correctly."""
async def test_compression_threshold_default(self):
"""Test that compression threshold is used correctly."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Verify the threshold is 0.82
assert manager.COMPRESSION_THRESHOLD == 0.82
# Verify the default threshold is 0.82
assert manager.compressor.compression_threshold == 0.82
# Create messages just below threshold
# Test threshold logic
messages = [self.create_message("user", "x" * 81)] # ~24 tokens
tokens = manager.token_counter.count_tokens(messages)
needs_compression, _ = await manager._initial_token_check(messages)
assert not needs_compression
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
# Should not compress if below threshold
assert needs_compression == (tokens > 82)
@pytest.mark.asyncio
async def test_large_batch_processing(self):
@@ -608,39 +607,29 @@ class TestContextManager:
# ==================== Run Compression Tests ====================
@pytest.mark.asyncio
async def test_run_compression_skip_when_not_needed(self):
"""Test _run_compression skips when needs_compression is False."""
async def test_run_compression_calls_compressor(self):
"""Test _run_compression calls compressor."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
messages = self.create_messages(5)
compressed = self.create_messages(3)
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager._run_compression(messages, needs_compression=False)
# Create a mock compressor
mock_compressor = AsyncMock()
mock_compressor.compression_threshold = 0.82
mock_compressor.return_value = compressed
mock_compressor.should_compress = MagicMock(return_value=False)
manager.compressor = mock_compressor
mock_compress.assert_not_called()
assert result == messages
result = await manager._run_compression(messages, prev_tokens=100)
# Compressor __call__ should be invoked
mock_compressor.assert_called_once_with(messages)
assert result == compressed
@pytest.mark.asyncio
async def test_run_compression_skip_when_zero_limit(self):
"""Test _run_compression skips when max_context_tokens is 0."""
config = ContextConfig(max_context_tokens=0)
manager = ContextManager(config)
messages = self.create_messages(5)
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager._run_compression(messages, needs_compression=True)
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_run_compression_applies_compressor(self):
async def test_run_compression_applies_compressor_through_process(self):
"""Test _run_compression calls compressor when needed through process()."""
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
manager = ContextManager(config)
@@ -649,14 +638,26 @@ class TestContextManager:
messages = [self.create_message("user", "x" * 300)] # ~90 tokens > 82 threshold
compressed = [self.create_message("user", "short")] # Much smaller
# Replace compressor with mock
mock_compress = AsyncMock(return_value=compressed)
manager.compressor = mock_compress
# Create a mock compressor
mock_compressor = AsyncMock()
mock_compressor.compression_threshold = 0.82
mock_compressor.return_value = compressed
# Mock should_compress to return True first time, False after
call_count = 0
def mock_should_compress(*args, **kwargs):
nonlocal call_count
call_count += 1
return call_count == 1
mock_compressor.should_compress = mock_should_compress
manager.compressor = mock_compressor
result = await manager.process(messages)
# Compressor should have been called
mock_compress.assert_called_once()
mock_compressor.assert_called_once()
assert len(result) <= len(messages)
@pytest.mark.asyncio