feat: context compress (#4322)

* feat: context compressor

Co-authored-by: kawayiYokami <289104862@qq.com>

* Add comprehensive tests for ContextManager and ContextTruncator

- Implemented a full test suite for ContextManager covering initialization, message processing, token-based compression, and error handling.
- Added tests for ContextTruncator focusing on message fixing, truncation by turns, dropping oldest turns, and halving.
- Ensured that both test suites validate edge cases and maintain expected behavior with various message types, including system and tool messages.

* feat: add MockProvider for LLM compression tests

* chore: remove lock

* ruff fix

* fix

* perf

* feat: enhance context compression with token tracking and logging

* feat: update logging for context compression trigger

* feat: implement context compression logic with dynamic threshold and token tracking

* fix: reorder import statements for consistency

* feat: add token_usage tracking to conversations and update related processing logic

---------

Co-authored-by: kawayiYokami <289104862@qq.com>
This commit is contained in:
Soulter
2026-01-05 17:26:10 +08:00
committed by GitHub
parent 3615b7dde2
commit 241f1c26d3
21 changed files with 2184 additions and 100 deletions
+243
View File
@@ -0,0 +1,243 @@
from typing import TYPE_CHECKING, Protocol, runtime_checkable
from ..message import Message
if TYPE_CHECKING:
from astrbot import logger
else:
try:
from astrbot import logger
except ImportError:
import logging
logger = logging.getLogger("astrbot")
if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider
from ..context.truncator import ContextTruncator
@runtime_checkable
class ContextCompressor(Protocol):
"""
Protocol for context compressors.
Provides an interface for compressing message lists.
"""
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens for the model.
Returns:
True if compression is needed, False otherwise.
"""
...
async def __call__(self, messages: list[Message]) -> list[Message]:
"""Compress the message list.
Args:
messages: The original message list.
Returns:
The compressed message list.
"""
...
class TruncateByTurnsCompressor:
"""Truncate by turns compressor implementation.
Truncates the message list by removing older turns.
"""
def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
"""Initialize the truncate by turns compressor.
Args:
truncate_turns: The number of turns to remove when truncating (default: 1).
compression_threshold: The compression trigger threshold (default: 0.82).
"""
self.truncate_turns = truncate_turns
self.compression_threshold = compression_threshold
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens.
Returns:
True if compression is needed, False otherwise.
"""
if max_tokens <= 0 or current_tokens <= 0:
return False
usage_rate = current_tokens / max_tokens
return usage_rate > self.compression_threshold
async def __call__(self, messages: list[Message]) -> list[Message]:
truncator = ContextTruncator()
truncated_messages = truncator.truncate_by_dropping_oldest_turns(
messages,
drop_turns=self.truncate_turns,
)
return truncated_messages
def split_history(
messages: list[Message], keep_recent: int
) -> tuple[list[Message], list[Message], list[Message]]:
"""Split the message list into system messages, messages to summarize, and recent messages.
Ensures that the split point is between complete user-assistant pairs to maintain conversation flow.
Args:
messages: The original message list.
keep_recent: The number of latest messages to keep.
Returns:
tuple: (system_messages, messages_to_summarize, recent_messages)
"""
# keep the system messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) <= keep_recent:
return system_messages, [], non_system_messages
# Find the split point, ensuring recent_messages starts with a user message
# This maintains complete conversation turns
split_index = len(non_system_messages) - keep_recent
# Search backward from split_index to find the first user message
# This ensures recent_messages starts with a user message (complete turn)
while split_index > 0 and non_system_messages[split_index].role != "user":
# TODO: +=1 or -=1 ? calculate by tokens
split_index -= 1
# If we couldn't find a user message, keep all messages as recent
if split_index == 0:
return system_messages, [], non_system_messages
messages_to_summarize = non_system_messages[:split_index]
recent_messages = non_system_messages[split_index:]
return system_messages, messages_to_summarize, recent_messages
class LLMSummaryCompressor:
"""LLM-based summary compressor.
Uses LLM to summarize the old conversation history, keeping the latest messages.
"""
def __init__(
self,
provider: "Provider",
keep_recent: int = 4,
instruction_text: str | None = None,
compression_threshold: float = 0.82,
):
"""Initialize the LLM summary compressor.
Args:
provider: The LLM provider instance.
keep_recent: The number of latest messages to keep (default: 4).
instruction_text: Custom instruction for summary generation.
compression_threshold: The compression trigger threshold (default: 0.82).
"""
self.provider = provider
self.keep_recent = keep_recent
self.compression_threshold = compression_threshold
self.instruction_text = instruction_text or (
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
"4. Write the summary in the user's language.\n"
)
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens.
Returns:
True if compression is needed, False otherwise.
"""
if max_tokens <= 0 or current_tokens <= 0:
return False
usage_rate = current_tokens / max_tokens
return usage_rate > self.compression_threshold
async def __call__(self, messages: list[Message]) -> list[Message]:
"""Use LLM to generate a summary of the conversation history.
Process:
1. Divide messages: keep the system message and the latest N messages.
2. Send the old messages + the instruction message to the LLM.
3. Reconstruct the message list: [system message, summary message, latest messages].
"""
if len(messages) <= self.keep_recent + 1:
return messages
system_messages, messages_to_summarize, recent_messages = split_history(
messages, self.keep_recent
)
if not messages_to_summarize:
return messages
# build payload
instruction_message = Message(role="user", content=self.instruction_text)
llm_payload = messages_to_summarize + [instruction_message]
# generate summary
try:
response = await self.provider.text_chat(contexts=llm_payload)
summary_content = response.completion_text
except Exception as e:
logger.error(f"Failed to generate summary: {e}")
return messages
# build result
result = []
result.extend(system_messages)
result.append(
Message(
role="user",
content=f"Our previous history conversation summary: {summary_content}",
)
)
result.append(
Message(
role="assistant",
content="Acknowledged the summary of our previous conversation history.",
)
)
result.extend(recent_messages)
return result
+35
View File
@@ -0,0 +1,35 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING
from .compressor import ContextCompressor
from .token_counter import TokenCounter
if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider
@dataclass
class ContextConfig:
"""Context configuration class."""
max_context_tokens: int = 0
"""Maximum number of context tokens. <= 0 means no limit."""
enforce_max_turns: int = -1 # -1 means no limit
"""Maximum number of conversation turns to keep. -1 means no limit. Executed before compression."""
truncate_turns: int = 1
"""Number of conversation turns to discard at once when truncation is triggered.
Two processes will use this value:
1. Enforce max turns truncation.
2. Truncation by turns compression strategy.
"""
llm_compress_instruction: str | None = None
"""Instruction prompt for LLM-based compression."""
llm_compress_keep_recent: int = 0
"""Number of recent messages to keep during LLM-based compression."""
llm_compress_provider: "Provider | None" = None
"""LLM provider used for compression tasks. If None, truncation strategy is used."""
custom_token_counter: TokenCounter | None = None
"""Custom token counting method. If None, the default method is used."""
custom_compressor: ContextCompressor | None = None
"""Custom context compression method. If None, the default method is used."""
+120
View File
@@ -0,0 +1,120 @@
from astrbot import logger
from ..message import Message
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
from .config import ContextConfig
from .token_counter import EstimateTokenCounter
from .truncator import ContextTruncator
class ContextManager:
"""Context compression manager."""
def __init__(
self,
config: ContextConfig,
):
"""Initialize the context manager.
There are two strategies to handle context limit reached:
1. Truncate by turns: remove older messages by turns.
2. LLM-based compression: use LLM to summarize old messages.
Args:
config: The context configuration.
"""
self.config = config
self.token_counter = config.custom_token_counter or EstimateTokenCounter()
self.truncator = ContextTruncator()
if config.custom_compressor:
self.compressor = config.custom_compressor
elif config.llm_compress_provider:
self.compressor = LLMSummaryCompressor(
provider=config.llm_compress_provider,
keep_recent=config.llm_compress_keep_recent,
instruction_text=config.llm_compress_instruction,
)
else:
self.compressor = TruncateByTurnsCompressor(
truncate_turns=config.truncate_turns
)
async def process(
self, messages: list[Message], trusted_token_usage: int = 0
) -> list[Message]:
"""Process the messages.
Args:
messages: The original message list.
Returns:
The processed message list.
"""
try:
result = messages
# 1. 基于轮次的截断 (Enforce max turns)
if self.config.enforce_max_turns != -1:
result = self.truncator.truncate_by_turns(
result,
keep_most_recent_turns=self.config.enforce_max_turns,
drop_turns=self.config.truncate_turns,
)
# 2. 基于 token 的压缩
if self.config.max_context_tokens > 0:
total_tokens = self.token_counter.count_tokens(
result, trusted_token_usage
)
if self.compressor.should_compress(
result, total_tokens, self.config.max_context_tokens
):
result = await self._run_compression(result, total_tokens)
return result
except Exception as e:
logger.error(f"Error during context processing: {e}", exc_info=True)
return messages
async def _run_compression(
self, messages: list[Message], prev_tokens: int
) -> list[Message]:
"""
Compress/truncate the messages.
Args:
messages: The original message list.
prev_tokens: The token count before compression.
Returns:
The compressed/truncated message list.
"""
logger.debug("Compress triggered, starting compression...")
messages = await self.compressor(messages)
# double check
tokens_after_summary = self.token_counter.count_tokens(messages)
# calculate compress rate
compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100
logger.info(
f"Compress completed."
f" {prev_tokens} -> {tokens_after_summary} tokens,"
f" compression rate: {compress_rate:.2f}%.",
)
# last check
if self.compressor.should_compress(
messages, tokens_after_summary, self.config.max_context_tokens
):
logger.info(
"Context still exceeds max tokens after compression, applying halving truncation..."
)
# still need compress, truncate by half
messages = self.truncator.truncate_by_halving(messages)
return messages
@@ -0,0 +1,64 @@
import json
from typing import Protocol, runtime_checkable
from ..message import Message, TextPart
@runtime_checkable
class TokenCounter(Protocol):
"""
Protocol for token counters.
Provides an interface for counting tokens in message lists.
"""
def count_tokens(
self, messages: list[Message], trusted_token_usage: int = 0
) -> int:
"""Count the total tokens in the message list.
Args:
messages: The message list.
trusted_token_usage: The total token usage that LLM API returned.
For some cases, this value is more accurate.
But some API does not return it, so the value defaults to 0.
Returns:
The total token count.
"""
...
class EstimateTokenCounter:
"""Estimate token counter implementation.
Provides a simple estimation of token count based on character types.
"""
def count_tokens(
self, messages: list[Message], trusted_token_usage: int = 0
) -> int:
if trusted_token_usage > 0:
return trusted_token_usage
total = 0
for msg in messages:
content = msg.content
if isinstance(content, str):
total += self._estimate_tokens(content)
elif isinstance(content, list):
# 处理多模态内容
for part in content:
if isinstance(part, TextPart):
total += self._estimate_tokens(part.text)
# 处理 Tool Calls
if msg.tool_calls:
for tc in msg.tool_calls:
tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump())
total += self._estimate_tokens(tc_str)
return total
def _estimate_tokens(self, text: str) -> int:
chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"])
other_count = len(text) - chinese_count
return int(chinese_count * 0.6 + other_count * 0.3)
+141
View File
@@ -0,0 +1,141 @@
from ..message import Message
class ContextTruncator:
"""Context truncator."""
def fix_messages(self, messages: list[Message]) -> list[Message]:
fixed_messages = []
for message in messages:
if message.role == "tool":
# tool block 前面必须要有 user 和 assistant block
if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
# 我们直接将之前的上下文都清空
fixed_messages = []
else:
fixed_messages.append(message)
else:
fixed_messages.append(message)
return fixed_messages
def truncate_by_turns(
self,
messages: list[Message],
keep_most_recent_turns: int,
drop_turns: int = 1,
) -> list[Message]:
"""截断上下文列表,确保不超过最大长度。
一个 turn 包含一个 user 消息和一个 assistant 消息。
这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。
Args:
messages: 上下文列表
keep_most_recent_turns: 保留最近的对话轮数
drop_turns: 一次性丢弃的对话轮数
Returns:
截断后的上下文列表
"""
if keep_most_recent_turns == -1:
return messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) // 2 <= keep_most_recent_turns:
return messages
num_to_keep = keep_most_recent_turns - drop_turns + 1
if num_to_keep <= 0:
truncated_contexts = []
else:
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
# 找到第一个 role 为 user 的索引,确保上下文格式正确
index = next(
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
None,
)
if index is not None and index > 0:
truncated_contexts = truncated_contexts[index:]
result = system_messages + truncated_contexts
return self.fix_messages(result)
def truncate_by_dropping_oldest_turns(
self,
messages: list[Message],
drop_turns: int = 1,
) -> list[Message]:
"""丢弃最旧的 N 个对话轮次。"""
if drop_turns <= 0:
return messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) // 2 <= drop_turns:
truncated_non_system = []
else:
truncated_non_system = non_system_messages[drop_turns * 2 :]
index = next(
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
)
if index is not None:
truncated_non_system = truncated_non_system[index:]
elif truncated_non_system:
truncated_non_system = []
result = system_messages + truncated_non_system
return self.fix_messages(result)
def truncate_by_halving(
self,
messages: list[Message],
) -> list[Message]:
"""对半砍策略,删除 50% 的消息"""
if len(messages) <= 2:
return messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
messages_to_delete = len(non_system_messages) // 2
if messages_to_delete == 0:
return messages
truncated_non_system = non_system_messages[messages_to_delete:]
index = next(
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
)
if index is not None:
truncated_non_system = truncated_non_system[index:]
result = system_messages + truncated_non_system
return self.fix_messages(result)
@@ -25,6 +25,10 @@ from astrbot.core.provider.entities import (
)
from astrbot.core.provider.provider import Provider
from ..context.compressor import ContextCompressor
from ..context.config import ContextConfig
from ..context.manager import ContextManager
from ..context.token_counter import TokenCounter
from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..response import AgentResponseData, AgentStats
@@ -47,10 +51,47 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
streaming: bool = False,
# enforce max turns, will discard older turns when exceeded BEFORE compression
# -1 means no limit
enforce_max_turns: int = -1,
# llm compressor
llm_compress_instruction: str | None = None,
llm_compress_keep_recent: int = 0,
llm_compress_provider: Provider | None = None,
# truncate by turns compressor
truncate_turns: int = 1,
# customize
custom_token_counter: TokenCounter | None = None,
custom_compressor: ContextCompressor | None = None,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.streaming = streaming
self.enforce_max_turns = enforce_max_turns
self.llm_compress_instruction = llm_compress_instruction
self.llm_compress_keep_recent = llm_compress_keep_recent
self.llm_compress_provider = llm_compress_provider
self.truncate_turns = truncate_turns
self.custom_token_counter = custom_token_counter
self.custom_compressor = custom_compressor
# we will do compress when:
# 1. before requesting LLM
# TODO: 2. after LLM output a tool call
self.context_config = ContextConfig(
# <=0 will never do compress
max_context_tokens=provider.provider_config.get("max_context_tokens", 0),
# enforce max turns before compression
enforce_max_turns=self.enforce_max_turns,
truncate_turns=self.truncate_turns,
llm_compress_instruction=self.llm_compress_instruction,
llm_compress_keep_recent=self.llm_compress_keep_recent,
llm_compress_provider=self.llm_compress_provider,
custom_token_counter=self.custom_token_counter,
custom_compressor=self.custom_compressor,
)
self.context_manager = ContextManager(self.context_config)
self.provider = provider
self.final_llm_resp = None
self._state = AgentState.IDLE
@@ -110,6 +151,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self._transition_state(AgentState.RUNNING)
llm_resp_result = None
# do truncate and compress
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
self.run_context.messages = await self.context_manager.process(
self.run_context.messages, trusted_token_usage=token_usage
)
async for llm_response in self._iter_llm_responses():
if llm_response.is_chunk:
# update ttft
+77 -16
View File
@@ -83,6 +83,16 @@ DEFAULT_CONFIG = {
"default_personality": "default",
"persona_pool": ["*"],
"prompt_prefix": "{{prompt}}",
"context_limit_reached_strategy": "truncate_by_turns", # or llm_compress
"llm_compress_instruction": (
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
"4. Write the summary in the user's language.\n"
),
"llm_compress_keep_recent": 4,
"llm_compress_provider_id": "",
"max_context_length": -1,
"dequeue_context_length": 1,
"streaming_response": False,
@@ -179,6 +189,7 @@ class ChatProviderTemplate(TypedDict):
model: str
modalities: list
custom_extra_body: dict[str, Any]
max_context_tokens: int
CHAT_PROVIDER_TEMPLATE = {
@@ -187,6 +198,7 @@ CHAT_PROVIDER_TEMPLATE = {
"model": "",
"modalities": [],
"custom_extra_body": {},
"max_context_tokens": 0,
}
"""
@@ -2033,6 +2045,11 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
},
"max_context_tokens": {
"description": "模型上下文窗口大小",
"type": "int",
"hint": "模型最大上下文 Token 大小。如果为 0,则会自动从模型元数据填充(如有),也可手动修改。",
},
"dify_api_key": {
"description": "API Key",
"type": "string",
@@ -2540,6 +2557,66 @@ CONFIG_METADATA_3 = {
# "provider_settings.enable": True,
# },
# },
"truncate_and_compress": {
"description": "上下文管理策略",
"type": "object",
"items": {
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",
"type": "int",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.dequeue_context_length": {
"description": "丢弃对话轮数",
"type": "int",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.context_limit_reached_strategy": {
"description": "超出模型上下文窗口时的处理方式",
"type": "string",
"options": ["truncate_by_turns", "llm_compress"],
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
"condition": {
"provider_settings.agent_runner_type": "local",
},
"hint": "",
},
"provider_settings.llm_compress_instruction": {
"description": "上下文压缩提示词",
"type": "text",
"hint": "如果为空则使用默认提示词。",
"condition": {
"provider_settings.context_limit_reached_strategy": "llm_compress",
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.llm_compress_keep_recent": {
"description": "压缩时保留最近对话轮数",
"type": "int",
"hint": "始终保留的最近 N 轮对话。",
"condition": {
"provider_settings.context_limit_reached_strategy": "llm_compress",
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.llm_compress_provider_id": {
"description": "用于上下文压缩的模型提供商 ID",
"type": "string",
"_special": "select_provider",
"hint": "留空时将降级为“按对话轮数截断”的策略。",
"condition": {
"provider_settings.context_limit_reached_strategy": "llm_compress",
"provider_settings.agent_runner_type": "local",
},
},
},
},
"others": {
"description": "其他配置",
"type": "object",
@@ -2604,22 +2681,6 @@ CONFIG_METADATA_3 = {
"provider_settings.streaming_response": True,
},
},
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",
"type": "int",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.dequeue_context_length": {
"description": "丢弃对话轮数",
"type": "int",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.wake_prefix": {
"description": "LLM 聊天额外唤醒前缀 ",
"type": "string",
+4
View File
@@ -69,6 +69,7 @@ class ConversationManager:
persona_id=conv_v2.persona_id,
created_at=created_at,
updated_at=updated_at,
token_usage=conv_v2.token_usage,
)
async def new_conversation(
@@ -256,6 +257,7 @@ class ConversationManager:
history: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
token_usage: int | None = None,
) -> None:
"""更新会话的对话.
@@ -263,6 +265,7 @@ class ConversationManager:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
token_usage (int | None): token 使用量。None 表示不更新
"""
if not conversation_id:
@@ -274,6 +277,7 @@ class ConversationManager:
title=title,
persona_id=persona_id,
content=history,
token_usage=token_usage,
)
async def update_conversation_title(
+1
View File
@@ -152,6 +152,7 @@ class BaseDatabase(abc.ABC):
title: str | None = None,
persona_id: str | None = None,
content: list[dict] | None = None,
token_usage: int | None = None,
) -> None:
"""Update a conversation's history."""
...
@@ -0,0 +1,61 @@
"""Migration script to add token_usage column to conversations table.
This migration adds the token_usage field to track token consumption for each conversation.
Changes:
- Adds token_usage column to conversations table (default: 0)
"""
from sqlalchemy import text
from astrbot.api import logger, sp
from astrbot.core.db import BaseDatabase
async def migrate_token_usage(db_helper: BaseDatabase):
"""Add token_usage column to conversations table.
This migration adds a new column to track token consumption in conversations.
"""
# 检查是否已经完成迁移
migration_done = await db_helper.get_preference(
"global", "global", "migration_done_token_usage_1"
)
if migration_done:
return
logger.info("开始执行数据库迁移(添加 conversations.token_usage 列)...")
# 这里只适配了 SQLite。因为截止至这一版本,AstrBot 仅支持 SQLite。
try:
async with db_helper.get_db() as session:
# 检查列是否已存在
result = await session.execute(text("PRAGMA table_info(conversations)"))
columns = result.fetchall()
column_names = [col[1] for col in columns]
if "token_usage" in column_names:
logger.info("token_usage 列已存在,跳过迁移")
await sp.put_async(
"global", "global", "migration_done_token_usage_1", True
)
return
# 添加 token_usage 列
await session.execute(
text(
"ALTER TABLE conversations ADD COLUMN token_usage INTEGER NOT NULL DEFAULT 0"
)
)
await session.commit()
logger.info("token_usage 列添加成功")
# 标记迁移完成
await sp.put_async("global", "global", "migration_done_token_usage_1", True)
logger.info("token_usage 迁移完成")
except Exception as e:
logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
raise
+7
View File
@@ -54,6 +54,11 @@ class ConversationV2(SQLModel, table=True):
)
title: str | None = Field(default=None, max_length=255)
persona_id: str | None = Field(default=None)
token_usage: int = Field(default=0, nullable=False)
"""content is a list of OpenAI-formated messages in list[dict] format.
token_usage is the total token value of the messages.
when 0, will use estimated token counter.
"""
__table_args__ = (
UniqueConstraint(
@@ -313,6 +318,8 @@ class Conversation:
persona_id: str | None = ""
created_at: int = 0
updated_at: int = 0
token_usage: int = 0
"""对话的总 token 数量。AstrBot 会保留最近一次 LLM 请求返回的总 token 数,方便统计。token_usage 可能为 0,表示未知。"""
class Personality(TypedDict):
+5 -1
View File
@@ -241,7 +241,9 @@ class SQLiteDatabase(BaseDatabase):
session.add(new_conversation)
return new_conversation
async def update_conversation(self, cid, title=None, persona_id=None, content=None):
async def update_conversation(
self, cid, title=None, persona_id=None, content=None, token_usage=None
):
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
@@ -255,6 +257,8 @@ class SQLiteDatabase(BaseDatabase):
values["persona_id"] = persona_id
if content is not None:
values["content"] = content
if token_usage is not None:
values["token_usage"] = token_usage
if not values:
return None
query = query.values(**values)
@@ -1,12 +1,12 @@
"""本地 Agent 模式的 LLM 调用 Stage"""
import asyncio
import copy
import json
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.agent.message import Message
from astrbot.core.agent.response import AgentStats
from astrbot.core.agent.tool import ToolSet
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.conversation_mgr import Conversation
@@ -24,6 +24,7 @@ from astrbot.core.provider.entities import (
)
from astrbot.core.star.star_handler import EventType, star_map
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.llm_metadata import LLM_METADATAS
from astrbot.core.utils.metrics import Metric
from astrbot.core.utils.session_lock import session_lock_manager
@@ -41,11 +42,6 @@ class InternalAgentSubStage(Stage):
self.ctx = ctx
conf = ctx.astrbot_config
settings = conf["provider_settings"]
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
self.max_context_length - 1,
)
self.streaming_response: bool = settings["streaming_response"]
self.unsupported_streaming_strategy: str = settings[
"unsupported_streaming_strategy"
@@ -65,6 +61,25 @@ class InternalAgentSubStage(Stage):
"moonshotai_api_key", ""
)
# 上下文管理相关
self.context_limit_reached_strategy: str = settings.get(
"context_limit_reached_strategy", "truncate_by_turns"
)
self.llm_compress_instruction: str = settings.get(
"llm_compress_instruction", ""
)
self.llm_compress_keep_recent: int = settings.get("llm_compress_keep_recent", 4)
self.llm_compress_provider_id: str = settings.get(
"llm_compress_provider_id", ""
)
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
self.max_context_length - 1,
)
if self.dequeue_context_length <= 0:
self.dequeue_context_length = 1
self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent):
@@ -167,34 +182,6 @@ class InternalAgentSubStage(Stage):
},
)
def _truncate_contexts(
self,
contexts: list[dict],
) -> list[dict]:
"""截断上下文列表,确保不超过最大长度"""
if self.max_context_length == -1:
return contexts
if len(contexts) // 2 <= self.max_context_length:
return contexts
truncated_contexts = contexts[
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
]
# 找到第一个role 为 user 的索引,确保上下文格式正确
index = next(
(
i
for i, item in enumerate(truncated_contexts)
if item.get("role") == "user"
),
None,
)
if index is not None and index > 0:
truncated_contexts = truncated_contexts[index:]
return truncated_contexts
def _modalities_fix(
self,
provider: Provider,
@@ -296,6 +283,7 @@ class InternalAgentSubStage(Stage):
req: ProviderRequest,
llm_response: LLMResponse | None,
all_messages: list[Message],
runner_stats: AgentStats | None,
):
if (
not req
@@ -322,27 +310,37 @@ class InternalAgentSubStage(Stage):
continue
message_to_save.append(message.model_dump())
# get token usage from agent runner stats
token_usage = None
if runner_stats:
token_usage = runner_stats.token_usage.total
await self.conv_manager.update_conversation(
event.unified_msg_origin,
req.conversation.cid,
history=message_to_save,
token_usage=token_usage,
)
def _fix_messages(self, messages: list[dict]) -> list[dict]:
"""验证并且修复上下文"""
fixed_messages = []
for message in messages:
if message.get("role") == "tool":
# tool block 前面必须要有 user 和 assistant block
if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
# 我们直接将之前的上下文都清空
fixed_messages = []
else:
fixed_messages.append(message)
else:
fixed_messages.append(message)
return fixed_messages
def _get_compress_provider(self) -> Provider | None:
if not self.llm_compress_provider_id:
return None
if self.context_limit_reached_strategy != "llm_compress":
return None
provider = self.ctx.plugin_manager.context.get_provider_by_id(
self.llm_compress_provider_id,
)
if provider is None:
logger.warning(
f"未找到指定的上下文压缩模型 {self.llm_compress_provider_id},将跳过压缩。",
)
return None
if not isinstance(provider, Provider):
logger.warning(
f"指定的上下文压缩模型 {self.llm_compress_provider_id} 不是对话模型,将跳过压缩。"
)
return None
return provider
async def process(
self, event: AstrMessageEvent, provider_wake_prefix: str
@@ -426,9 +424,10 @@ class InternalAgentSubStage(Stage):
await self._apply_kb(event, req)
# truncate contexts to fit max length
if req.contexts:
req.contexts = self._truncate_contexts(req.contexts)
self._fix_messages(req.contexts)
# NOW moved to ContextManager inside ToolLoopAgentRunner
# if req.contexts:
# req.contexts = self._truncate_contexts(req.contexts)
# self._fix_messages(req.contexts)
# session_id
if not req.session_id:
@@ -444,8 +443,6 @@ class InternalAgentSubStage(Stage):
self.unsupported_streaming_strategy == "turn_off"
and not event.platform_meta.support_streaming_message
)
# 备份 req.contexts
backup_contexts = copy.deepcopy(req.contexts)
# run agent
agent_runner = AgentRunner()
@@ -456,6 +453,15 @@ class InternalAgentSubStage(Stage):
context=self.ctx.plugin_manager.context,
event=event,
)
# inject model context length limit
if provider.provider_config.get("max_context_tokens", 0) <= 0:
model = provider.get_model()
if model_info := LLM_METADATAS.get(model):
provider.provider_config["max_context_tokens"] = model_info[
"limit"
]["context"]
await agent_runner.reset(
provider=provider,
request=req,
@@ -466,6 +472,11 @@ class InternalAgentSubStage(Stage):
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=streaming_response,
llm_compress_instruction=self.llm_compress_instruction,
llm_compress_keep_recent=self.llm_compress_keep_recent,
llm_compress_provider=self._get_compress_provider(),
truncate_turns=self.dequeue_context_length,
enforce_max_turns=self.max_context_length,
)
if streaming_response and not stream_to_general:
@@ -511,14 +522,12 @@ class InternalAgentSubStage(Stage):
):
yield
# 恢复备份的 contexts
req.contexts = backup_contexts
await self._save_to_history(
event,
req,
agent_runner.get_final_llm_resp(),
agent_runner.run_context.messages,
agent_runner.stats,
)
# 异步处理 WebChat 特殊情况
+14 -1
View File
@@ -149,9 +149,12 @@ class Context:
contexts: context messages for the LLM
max_steps: Maximum number of tool calls before stopping the loop
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
stream: bool - whether to stream the LLM response
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
agent_context: AstrAgentContext - context to use for the agent
other kwargs will be DIRECTLY passed to the runner.reset() method
Returns:
The final LLMResponse after tool calls are completed.
@@ -194,6 +197,15 @@ class Context:
)
agent_runner = ToolLoopAgentRunner()
tool_executor = FunctionToolExecutor()
streaming = kwargs.get("stream", False)
other_kwargs = {
k: v
for k, v in kwargs.items()
if k not in ["stream", "agent_hooks", "agent_context"]
}
await agent_runner.reset(
provider=prov,
request=request,
@@ -203,7 +215,8 @@ class Context:
),
tool_executor=tool_executor,
agent_hooks=agent_hooks,
streaming=kwargs.get("stream", False),
streaming=streaming,
**other_kwargs,
)
async for _ in agent_runner.step_until_done(max_steps):
pass
+8
View File
@@ -3,6 +3,7 @@ import traceback
from astrbot.core import astrbot_config, logger
from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
from astrbot.core.db.migration.migra_token_usage import migrate_token_usage
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
@@ -139,6 +140,13 @@ async def migra(
logger.error(f"Migration for webchat session failed: {e!s}")
logger.error(traceback.format_exc())
# migration for token_usage column
try:
await migrate_token_usage(db)
except Exception as e:
logger.error(f"Migration for token_usage column failed: {e!s}")
logger.error(traceback.format_exc())
# migra third party agent runner configs
_c = False
providers = astrbot_config["provider"]
@@ -144,7 +144,7 @@
color="primary"
density="compact"
hide-details
class="flex-grow-1"
style="flex: 1"
></v-slider>
<v-text-field
:model-value="modelValue"
@@ -154,7 +154,7 @@
class="config-field"
type="number"
hide-details
style="max-width: 140px;"
style="flex: 1"
></v-text-field>
</div>
@@ -325,4 +325,8 @@ function getSpecialSubtype(value) {
.gap-20 {
gap: 20px;
}
:deep(.v-field__input) {
font-size: 14px;
}
</style>
@@ -510,7 +510,7 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
const metadata = getModelMetadata(modelName)
let modalities: string[]
if (!metadata) {
modalities = ['text', 'image', 'tool_use']
} else {
@@ -523,13 +523,19 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
}
}
let max_context_tokens = 0
if (metadata?.limit?.context && typeof metadata.limit.context === 'number') {
max_context_tokens = metadata.limit.context
}
const newProvider = {
id: newId,
enable: false,
provider_source_id: sourceId,
model: modelName,
modalities,
custom_extra_body: {}
custom_extra_body: {},
max_context_tokens: max_context_tokens
}
try {
@@ -11,7 +11,12 @@
},
"agent_runner_type": {
"description": "Runner",
"labels": ["Built-in Agent", "Dify", "Coze", "Alibaba Cloud Bailian Application"]
"labels": [
"Built-in Agent",
"Dify",
"Coze",
"Alibaba Cloud Bailian Application"
]
},
"coze_agent_runner_provider_id": {
"description": "Coze Agent Runner Provider ID"
@@ -128,6 +133,39 @@
}
}
},
"truncate_and_compress": {
"description": "Context Management Strategy",
"provider_settings": {
"max_context_length": {
"description": "Maximum Conversation Turns",
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
},
"dequeue_context_length": {
"description": "Dequeue Conversation Turns",
"hint": "Number of conversation turns to discard at once when maximum context length is exceeded"
},
"context_limit_reached_strategy": {
"description": "Handling When Model Context Window is Exceeded",
"labels": [
"Truncate by Turns",
"Compress by LLM"
],
"hint": "When 'Truncate by Turns' is selected, the oldest N conversation turns will be discarded based on the 'Dequeue Conversation Turns' setting above. When 'Compress by LLM' is selected, the specified model will be used for context compression."
},
"llm_compress_instruction": {
"description": "Context Compression Instruction",
"hint": "If empty, the default prompt will be used."
},
"llm_compress_keep_recent": {
"description": "Keep Recent Turns When Compressing",
"hint": "Always keep the most recent N turns of conversation when compressing context."
},
"llm_compress_provider_id": {
"description": "Model Provider ID for Context Compression",
"hint": "When left empty, will fall back to the 'Truncate by Turns' strategy."
}
}
},
"others": {
"description": "Other Settings",
"provider_settings": {
@@ -161,15 +199,10 @@
"unsupported_streaming_strategy": {
"description": "Platforms Without Streaming Support",
"hint": "Select the handling method for platforms that don't support streaming responses. Real-time segmented reply sends content immediately when the system detects segment points like punctuation during streaming reception",
"labels": ["Real-time Segmented Reply", "Disable Streaming Response"]
},
"max_context_length": {
"description": "Maximum Conversation Rounds",
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
},
"dequeue_context_length": {
"description": "Dequeue Conversation Rounds",
"hint": "Number of conversation rounds to discard at once when maximum context length is exceeded"
"labels": [
"Real-time Segmented Reply",
"Disable Streaming Response"
]
},
"wake_prefix": {
"description": "Additional LLM Chat Wake Prefix",
@@ -387,7 +420,10 @@
},
"split_mode": {
"description": "Split Mode",
"labels": ["Regex", "Words List"]
"labels": [
"Regex",
"Words List"
]
},
"regex": {
"description": "Segmentation Regular Expression"
@@ -488,4 +524,4 @@
}
}
}
}
}
@@ -133,6 +133,36 @@
}
}
},
"truncate_and_compress": {
"description": "上下文管理策略",
"provider_settings": {
"max_context_length": {
"description": "最多携带对话轮数",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
},
"dequeue_context_length": {
"description": "丢弃对话轮数",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
},
"context_limit_reached_strategy": {
"description": "超出模型上下文窗口时的处理方式",
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
"hint": "当按对话轮数截断时,会根据上面\"丢弃对话轮数\"的配置丢弃最旧的 N 轮对话。当由 LLM 压缩上下文时,会使用指定的模型进行上下文压缩。"
},
"llm_compress_instruction": {
"description": "上下文压缩提示词",
"hint": "如果为空则使用默认提示词。"
},
"llm_compress_keep_recent": {
"description": "压缩时保留最近对话轮数",
"hint": "始终保留的最近 N 轮对话。"
},
"llm_compress_provider_id": {
"description": "用于上下文压缩的模型提供商 ID",
"hint": "留空时将降级为\"按对话轮数截断\"的策略。"
}
}
},
"others": {
"description": "其他配置",
"provider_settings": {
@@ -171,14 +201,7 @@
"关闭流式回复"
]
},
"max_context_length": {
"description": "最多携带对话轮数",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
},
"dequeue_context_length": {
"description": "丢弃对话轮数",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
},
"wake_prefix": {
"description": "LLM 聊天额外唤醒前缀",
"hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求"
+774
View File
@@ -0,0 +1,774 @@
"""Comprehensive tests for ContextManager."""
import sys
from pathlib import Path
from typing import Literal
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Add parent directory to path to avoid circular import issues
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from astrbot.core.agent.context.config import ContextConfig
from astrbot.core.agent.context.manager import ContextManager
from astrbot.core.agent.message import Message, TextPart
from astrbot.core.provider.entities import LLMResponse
class MockProvider:
"""模拟 Provider"""
def __init__(self):
self.provider_config = {
"id": "test_provider",
"model": "gpt-4",
"modalities": ["text", "image", "tool_use"],
}
async def text_chat(self, **kwargs):
"""模拟 LLM 调用,返回摘要"""
messages = kwargs.get("messages", [])
# 简单的摘要逻辑:返回消息数量统计
return LLMResponse(
role="assistant",
completion_text=f"历史对话包含 {len(messages) - 1} 条消息,主要讨论了技术话题。",
)
def get_model(self):
return "gpt-4"
def meta(self):
return MagicMock(id="test_provider", type="openai")
class TestContextManager:
"""Test suite for ContextManager."""
def create_message(
self, role: Literal["system", "user", "assistant", "tool"], content: str
) -> Message:
"""Helper to create a simple text message."""
return Message(role=role, content=content)
def create_messages(self, count: int) -> list[Message]:
"""Helper to create alternating user/assistant messages."""
messages = []
for i in range(count):
role = "user" if i % 2 == 0 else "assistant"
messages.append(self.create_message(role, f"Message {i}"))
return messages
# ==================== Basic Initialization Tests ====================
def test_init_with_minimal_config(self):
"""Test initialization with minimal configuration."""
config = ContextConfig()
manager = ContextManager(config)
assert manager.config == config
assert manager.token_counter is not None
assert manager.truncator is not None
assert manager.compressor is not None
def test_init_with_llm_compressor(self):
"""Test initialization with LLM-based compression."""
mock_provider = MockProvider()
config = ContextConfig(
llm_compress_provider=mock_provider, # type: ignore
llm_compress_keep_recent=5,
llm_compress_instruction="Summarize the conversation",
)
manager = ContextManager(config)
from astrbot.core.agent.context.compressor import LLMSummaryCompressor
assert isinstance(manager.compressor, LLMSummaryCompressor)
def test_init_with_truncate_compressor(self):
"""Test initialization with truncate-based compression (default)."""
config = ContextConfig(truncate_turns=3)
manager = ContextManager(config)
from astrbot.core.agent.context.compressor import TruncateByTurnsCompressor
assert isinstance(manager.compressor, TruncateByTurnsCompressor)
# ==================== Empty and Edge Cases ====================
@pytest.mark.asyncio
async def test_process_empty_messages(self):
"""Test processing an empty message list."""
config = ContextConfig()
manager = ContextManager(config)
result = await manager.process([])
assert result == []
@pytest.mark.asyncio
async def test_process_single_message(self):
"""Test processing a single message."""
config = ContextConfig()
manager = ContextManager(config)
messages = [self.create_message("user", "Hello")]
result = await manager.process(messages)
assert len(result) == 1
assert result[0].content == "Hello"
@pytest.mark.asyncio
async def test_process_with_no_limits(self):
"""Test processing when no limits are set (no truncation or compression)."""
config = ContextConfig(max_context_tokens=0, enforce_max_turns=-1)
manager = ContextManager(config)
messages = self.create_messages(20)
result = await manager.process(messages)
assert len(result) == 20
assert result == messages
# ==================== Enforce Max Turns Tests ====================
@pytest.mark.asyncio
async def test_enforce_max_turns_basic(self):
"""Test basic enforce_max_turns functionality."""
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
manager = ContextManager(config)
# Create 10 turns (20 messages)
messages = self.create_messages(20)
result = await manager.process(messages)
# Should keep only 3 most recent turns (6 messages)
assert len(result) <= 8 # May vary due to truncation logic
@pytest.mark.asyncio
async def test_enforce_max_turns_zero(self):
"""Test enforce_max_turns with value 0 (should keep nothing)."""
config = ContextConfig(enforce_max_turns=0, truncate_turns=1)
manager = ContextManager(config)
messages = self.create_messages(10)
result = await manager.process(messages)
# Should result in empty or minimal message list
assert len(result) <= 2
@pytest.mark.asyncio
async def test_enforce_max_turns_negative(self):
"""Test enforce_max_turns with -1 (no limit)."""
config = ContextConfig(enforce_max_turns=-1)
manager = ContextManager(config)
messages = self.create_messages(20)
result = await manager.process(messages)
assert len(result) == 20
@pytest.mark.asyncio
async def test_enforce_max_turns_with_system_messages(self):
"""Test enforce_max_turns preserves system messages."""
config = ContextConfig(enforce_max_turns=2, truncate_turns=1)
manager = ContextManager(config)
messages = [
self.create_message("system", "System instruction"),
*self.create_messages(10),
]
result = await manager.process(messages)
# System message should be preserved
system_msgs = [m for m in result if m.role == "system"]
assert len(system_msgs) >= 1
assert system_msgs[0].content == "System instruction"
# ==================== Token-based Compression Tests ====================
@pytest.mark.asyncio
async def test_token_compression_not_triggered_below_threshold(self):
"""Test that compression is not triggered below threshold."""
config = ContextConfig(max_context_tokens=1000)
manager = ContextManager(config)
# Create messages that total less than threshold
messages = [self.create_message("user", "Hi" * 50)] # ~100 tokens
with patch.object(
manager.compressor, "should_compress", return_value=False
) as mock_should_compress:
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager.process(messages)
# should_compress should be called
mock_should_compress.assert_called_once()
# Compressor should not be called
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_token_compression_triggered_above_threshold(self):
"""Test that compression is triggered above threshold."""
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
manager = ContextManager(config)
# Create messages that exceed threshold (0.82 * 100 = 82 tokens)
# 300 chars * 0.3 = 90 tokens > 82 threshold
long_text = "x" * 300 # ~90 tokens, above threshold
messages = [self.create_message("user", long_text)]
# Mock compressor to return smaller result
compressed = [self.create_message("user", "short")]
# Create a mock compressor
mock_compressor = AsyncMock()
mock_compressor.compression_threshold = 0.82
mock_compressor.return_value = compressed
# Mock should_compress to return True first time, False after
call_count = 0
def mock_should_compress(*args, **kwargs):
nonlocal call_count
call_count += 1
return call_count == 1
mock_compressor.should_compress = mock_should_compress
manager.compressor = mock_compressor
result = await manager.process(messages)
# Compressor should be called
mock_compressor.assert_called_once()
# Result should be the compressed version
assert len(result) <= len(messages)
@pytest.mark.asyncio
async def test_token_compression_with_zero_max_tokens(self):
"""Test that compression is skipped when max_context_tokens is 0."""
config = ContextConfig(max_context_tokens=0)
manager = ContextManager(config)
messages = [self.create_message("user", "x" * 10000)]
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager.process(messages)
# Compressor should not be called when max_context_tokens is 0
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_token_compression_with_negative_max_tokens(self):
"""Test that compression is skipped when max_context_tokens is negative."""
config = ContextConfig(max_context_tokens=-100)
manager = ContextManager(config)
messages = [self.create_message("user", "x" * 10000)]
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager.process(messages)
# Compressor should not be called
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_double_check_after_compression(self):
"""Test that halving is applied if still over threshold after compression."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Create messages that would still be over threshold after compression
long_messages = [self.create_message("user", "x" * 200) for _ in range(10)]
# Mock compressor to return messages still over threshold
async def mock_compress(msgs):
return msgs # Return same messages (still over limit)
# Mock should_compress to return True twice (before and after compression)
with patch.object(manager.compressor, "should_compress", return_value=True):
with patch.object(manager.compressor, "__call__", new=mock_compress):
with patch.object(
manager.truncator,
"truncate_by_halving",
return_value=long_messages[:5],
) as mock_halving:
_ = await manager.process(long_messages)
# Halving should be called
mock_halving.assert_called_once()
# ==================== Combined Truncation and Compression Tests ====================
@pytest.mark.asyncio
async def test_combined_enforce_turns_and_token_limit(self):
"""Test combining enforce_max_turns and token limit."""
config = ContextConfig(
enforce_max_turns=5, max_context_tokens=500, truncate_turns=1
)
manager = ContextManager(config)
# Create many messages
messages = self.create_messages(30)
result = await manager.process(messages)
# Should be truncated by both mechanisms
assert len(result) < 30
@pytest.mark.asyncio
async def test_sequential_processing_order(self):
"""Test that enforce_max_turns happens before token compression."""
config = ContextConfig(enforce_max_turns=5, max_context_tokens=1000)
manager = ContextManager(config)
messages = self.create_messages(20)
# Mock the truncator to track calls
with patch.object(
manager.truncator,
"truncate_by_turns",
wraps=manager.truncator.truncate_by_turns,
) as mock_truncate:
await manager.process(messages)
# Truncator should be called first
mock_truncate.assert_called_once()
# ==================== Error Handling Tests ====================
@pytest.mark.asyncio
async def test_error_handling_returns_original_messages(self):
"""Test that errors during processing return original messages."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
messages = self.create_messages(5)
# Make compressor raise an exception
with patch.object(
manager.compressor, "__call__", side_effect=Exception("Test error")
):
result = await manager.process(messages)
# Should return original messages despite error
assert result == messages
@pytest.mark.asyncio
async def test_error_handling_logs_exception(self):
"""Test that errors are logged."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Create messages that will trigger compression (> 82 tokens)
messages = [self.create_message("user", "x" * 300)] # ~90 tokens
# Replace compressor with one that raises an exception
mock_compressor = AsyncMock(side_effect=Exception("Test error"))
mock_compressor.compression_threshold = 0.82
mock_compressor.should_compress = MagicMock(return_value=True)
manager.compressor = mock_compressor
with patch("astrbot.core.agent.context.manager.logger") as mock_logger:
result = await manager.process(messages)
# Logger error method should be called
assert mock_logger.error.called
# Should return original messages on error
assert result == messages
# ==================== Multi-modal Content Tests ====================
@pytest.mark.asyncio
async def test_process_messages_with_textpart_content(self):
"""Test processing messages with TextPart content."""
config = ContextConfig()
manager = ContextManager(config)
messages = [
Message(role="user", content=[TextPart(text="Hello")]),
Message(role="assistant", content=[TextPart(text="Hi there")]),
]
result = await manager.process(messages)
assert len(result) == 2
assert result == messages
@pytest.mark.asyncio
async def test_token_counting_with_multimodal_content(self):
"""Test token counting works with multi-modal content."""
config = ContextConfig(max_context_tokens=50)
manager = ContextManager(config)
# Need enough tokens to exceed threshold: 50 * 0.82 = 41 tokens
# 150 chars * 0.3 = 45 tokens > 41
messages = [
Message(role="user", content=[TextPart(text="x" * 150)]),
]
# Should trigger compression due to token count
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 50)
assert tokens > 0 # Tokens should be counted
assert needs_compression # Should trigger compression
# ==================== Tool Calls Tests ====================
@pytest.mark.asyncio
async def test_process_messages_with_tool_calls(self):
"""Test processing messages with tool calls."""
config = ContextConfig()
manager = ContextManager(config)
messages = [
Message(
role="assistant",
content="Let me search for that",
tool_calls=[
{
"id": "call_1",
"type": "function",
"function": {"name": "search", "arguments": "{}"},
}
],
),
Message(role="tool", content="Search result", tool_call_id="call_1"),
]
result = await manager.process(messages)
assert len(result) == 2
# ==================== Compressor should_compress Tests ====================
@pytest.mark.asyncio
async def test_should_compress_empty_messages(self):
"""Test should_compress with empty messages."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Compressor's should_compress should handle empty gracefully
needs_compression = manager.compressor.should_compress([], 0, 100)
assert not needs_compression
@pytest.mark.asyncio
async def test_should_compress_below_threshold(self):
"""Test should_compress when below compression threshold."""
config = ContextConfig(max_context_tokens=1000)
manager = ContextManager(config)
messages = [self.create_message("user", "Hello")]
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 1000)
assert not needs_compression
@pytest.mark.asyncio
async def test_should_compress_above_threshold(self):
"""Test should_compress when above compression threshold."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Create message with many tokens
messages = [self.create_message("user", "这是测试" * 50)]
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
# Should need compression if tokens > 82 (0.82 * 100)
assert needs_compression == (tokens > 82)
# ==================== Truncator Halving Tests ====================
def test_truncate_by_halving_basic(self):
"""Test truncate_by_halving removes middle 50%."""
config = ContextConfig()
manager = ContextManager(config)
messages = self.create_messages(10)
result = manager.truncator.truncate_by_halving(messages)
# Should keep roughly half
assert len(result) < len(messages)
def test_truncate_by_halving_empty_list(self):
"""Test truncate_by_halving with empty list."""
config = ContextConfig()
manager = ContextManager(config)
result = manager.truncator.truncate_by_halving([])
assert result == []
def test_truncate_by_halving_single_message(self):
"""Test truncate_by_halving with single message."""
config = ContextConfig()
manager = ContextManager(config)
messages = [self.create_message("user", "Hello")]
result = manager.truncator.truncate_by_halving(messages)
assert len(result) <= 1
# ==================== Complex Scenarios ====================
@pytest.mark.asyncio
async def test_multiple_compression_cycles(self):
"""Test that compression can be triggered multiple times in sequence."""
config = ContextConfig(max_context_tokens=50, truncate_turns=1)
manager = ContextManager(config)
# Process messages multiple times
messages = self.create_messages(10)
result1 = await manager.process(messages)
result2 = await manager.process(result1)
result3 = await manager.process(result2)
# Each cycle should maintain or reduce message count
assert len(result3) <= len(result2) <= len(result1)
@pytest.mark.asyncio
async def test_alternating_roles_preserved(self):
"""Test that user/assistant alternation is preserved after processing."""
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
manager = ContextManager(config)
messages = self.create_messages(20)
result = await manager.process(messages)
# Check that roles still alternate (excluding system messages)
non_system = [m for m in result if m.role != "system"]
if len(non_system) >= 2:
# Should start with user
assert non_system[0].role == "user"
@pytest.mark.asyncio
async def test_compression_threshold_default(self):
"""Test that compression threshold is used correctly."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Verify the default threshold is 0.82
assert manager.compressor.compression_threshold == 0.82
# Test threshold logic
messages = [self.create_message("user", "x" * 81)] # ~24 tokens
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
# Should not compress if below threshold
assert needs_compression == (tokens > 82)
@pytest.mark.asyncio
async def test_large_batch_processing(self):
"""Test processing a large batch of messages."""
config = ContextConfig(
enforce_max_turns=10, max_context_tokens=1000, truncate_turns=2
)
manager = ContextManager(config)
# Create 100 messages (50 turns)
messages = self.create_messages(100)
result = await manager.process(messages)
# Should be significantly reduced
assert len(result) < 100
assert len(result) > 0
@pytest.mark.asyncio
async def test_config_persistence(self):
"""Test that config settings are respected throughout processing."""
config = ContextConfig(
max_context_tokens=500,
enforce_max_turns=5,
truncate_turns=2,
llm_compress_keep_recent=3,
)
manager = ContextManager(config)
# Verify config is stored
assert manager.config.max_context_tokens == 500
assert manager.config.enforce_max_turns == 5
assert manager.config.truncate_turns == 2
assert manager.config.llm_compress_keep_recent == 3
# ==================== Run Compression Tests ====================
@pytest.mark.asyncio
async def test_run_compression_calls_compressor(self):
"""Test _run_compression calls compressor."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
messages = self.create_messages(5)
compressed = self.create_messages(3)
# Create a mock compressor
mock_compressor = AsyncMock()
mock_compressor.compression_threshold = 0.82
mock_compressor.return_value = compressed
mock_compressor.should_compress = MagicMock(return_value=False)
manager.compressor = mock_compressor
result = await manager._run_compression(messages, prev_tokens=100)
# Compressor __call__ should be invoked
mock_compressor.assert_called_once_with(messages)
assert result == compressed
@pytest.mark.asyncio
async def test_run_compression_applies_compressor_through_process(self):
"""Test _run_compression calls compressor when needed through process()."""
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
manager = ContextManager(config)
# Create messages that will trigger compression
messages = [self.create_message("user", "x" * 300)] # ~90 tokens > 82 threshold
compressed = [self.create_message("user", "short")] # Much smaller
# Create a mock compressor
mock_compressor = AsyncMock()
mock_compressor.compression_threshold = 0.82
mock_compressor.return_value = compressed
# Mock should_compress to return True first time, False after
call_count = 0
def mock_should_compress(*args, **kwargs):
nonlocal call_count
call_count += 1
return call_count == 1
mock_compressor.should_compress = mock_should_compress
manager.compressor = mock_compressor
result = await manager.process(messages)
# Compressor should have been called
mock_compressor.assert_called_once()
assert len(result) <= len(messages)
@pytest.mark.asyncio
async def test_llm_compression_with_mock_provider(self):
"""Test LLM compression using MockProvider."""
mock_provider = MockProvider()
config = ContextConfig(
llm_compress_provider=mock_provider, # type: ignore
llm_compress_keep_recent=3,
llm_compress_instruction="请总结对话内容",
max_context_tokens=100,
)
manager = ContextManager(config)
# Create messages that will trigger compression
messages = [
self.create_message("user", "x" * 100),
self.create_message("assistant", "y" * 100),
self.create_message("user", "z" * 100),
]
result = await manager.process(messages)
# Should have been compressed
assert len(result) <= len(messages)
# ==================== split_history Tests ====================
def test_split_history_ensures_user_start(self):
"""Test split_history ensures recent_messages starts with user message."""
from astrbot.core.agent.context.compressor import split_history
# Create alternating messages: user, assistant, user, assistant, user, assistant
messages = [
self.create_message("system", "System prompt"),
self.create_message("user", "msg1"),
self.create_message("assistant", "msg2"),
self.create_message("user", "msg3"),
self.create_message("assistant", "msg4"),
self.create_message("user", "msg5"),
self.create_message("assistant", "msg6"),
]
# Keep recent 3 messages - should adjust to start with user
system, to_summarize, recent = split_history(messages, keep_recent=3)
# recent_messages should start with user message
assert len(recent) > 0
assert recent[0].role == "user"
# messages_to_summarize should end with assistant (complete turn)
if len(to_summarize) > 0:
assert to_summarize[-1].role == "assistant"
def test_split_history_handles_assistant_at_split_point(self):
"""Test split_history when assistant message is at the intended split point."""
from astrbot.core.agent.context.compressor import split_history
messages = [
self.create_message("user", "msg1"),
self.create_message("assistant", "msg2"),
self.create_message("user", "msg3"),
self.create_message("assistant", "msg4"), # <- intended split here
self.create_message("user", "msg5"),
self.create_message("assistant", "msg6"),
]
# keep_recent=2 would normally split at index 4 (assistant msg4)
# Should move back to include from msg5 (user)
system, to_summarize, recent = split_history(messages, keep_recent=2)
# recent should start with user message
assert recent[0].role == "user"
assert recent[0].content == "msg5"
def test_split_history_all_assistant_messages(self):
"""Test split_history when there are consecutive assistant messages."""
from astrbot.core.agent.context.compressor import split_history
messages = [
self.create_message("user", "msg1"),
self.create_message("assistant", "msg2"),
self.create_message("assistant", "msg3"),
self.create_message("assistant", "msg4"),
]
system, to_summarize, recent = split_history(messages, keep_recent=2)
# Should find the user message and keep from there
if len(recent) > 0:
# Find first user message backwards
assert any(m.role == "user" for m in messages)
def test_split_history_with_system_messages(self):
"""Test split_history preserves system messages separately."""
from astrbot.core.agent.context.compressor import split_history
messages = [
self.create_message("system", "System 1"),
self.create_message("system", "System 2"),
self.create_message("user", "msg1"),
self.create_message("assistant", "msg2"),
self.create_message("user", "msg3"),
]
system, to_summarize, recent = split_history(messages, keep_recent=2)
# System messages should be separate
assert len(system) == 2
assert all(m.role == "system" for m in system)
# Recent should start with user
if len(recent) > 0:
assert recent[0].role == "user"
+423
View File
@@ -0,0 +1,423 @@
"""Tests for ContextTruncator."""
from astrbot.core.agent.context.truncator import ContextTruncator
from astrbot.core.agent.message import Message
class TestContextTruncator:
"""Test suite for ContextTruncator."""
def create_message(self, role: str, content: str = "test content") -> Message:
"""Helper to create a simple test message."""
return Message(role=role, content=content)
def create_messages(
self, count: int, include_system: bool = False
) -> list[Message]:
"""Helper to create alternating user/assistant messages.
Args:
count: Number of messages to create
include_system: Whether to include a system message at the start
Returns:
List of messages
"""
messages = []
if include_system:
messages.append(self.create_message("system", "System prompt"))
for i in range(count):
role = "user" if i % 2 == 0 else "assistant"
messages.append(self.create_message(role, f"Message {i}"))
return messages
# ==================== fix_messages Tests ====================
def test_fix_messages_empty_list(self):
"""Test fix_messages with an empty list."""
truncator = ContextTruncator()
result = truncator.fix_messages([])
assert result == []
def test_fix_messages_normal_messages(self):
"""Test fix_messages with normal user/assistant messages."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Hello"),
self.create_message("assistant", "Hi"),
self.create_message("user", "How are you?"),
]
result = truncator.fix_messages(messages)
assert len(result) == 3
assert result == messages
def test_fix_messages_tool_with_valid_context(self):
"""Test fix_messages with tool message after user+assistant."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 3
assert result == messages
def test_fix_messages_tool_without_context(self):
"""Test fix_messages with tool message without enough context."""
truncator = ContextTruncator()
messages = [
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
# Tool message without context should be removed
assert len(result) == 0
def test_fix_messages_tool_with_only_one_message(self):
"""Test fix_messages with tool message after only one message."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Hello"),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
# Tool message without enough context should be removed
assert len(result) == 0
def test_fix_messages_multiple_tools(self):
"""Test fix_messages with multiple tool messages."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool 1 result"),
self.create_message("tool", "Tool 2 result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 4
assert result == messages
def test_fix_messages_mixed_system_tool(self):
"""Test fix_messages with system message and tool messages."""
truncator = ContextTruncator()
messages = [
self.create_message("system", "System prompt"),
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 4
assert result == messages
# ==================== truncate_by_turns Tests ====================
def test_truncate_by_turns_no_limit(self):
"""Test truncate_by_turns with -1 (no limit)."""
truncator = ContextTruncator()
messages = self.create_messages(20)
result = truncator.truncate_by_turns(messages, keep_most_recent_turns=-1)
assert len(result) == 20
assert result == messages
def test_truncate_by_turns_basic(self):
"""Test basic truncate_by_turns functionality."""
truncator = ContextTruncator()
# Create 10 messages = 5 turns (user/assistant pairs)
messages = self.create_messages(10)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=3, drop_turns=1
)
# Should keep 3 most recent turns (6 messages)
assert len(result) <= 8 # (3-1+1)*2 = 6, but may adjust for correct format
def test_truncate_by_turns_with_system_message(self):
"""Test truncate_by_turns preserves system messages."""
truncator = ContextTruncator()
messages = self.create_messages(10, include_system=True)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=2, drop_turns=1
)
# System message should always be preserved
assert result[0].role == "system"
assert result[0].content == "System prompt"
def test_truncate_by_turns_zero_keep(self):
"""Test truncate_by_turns with keep_most_recent_turns=0."""
truncator = ContextTruncator()
messages = self.create_messages(10)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=0, drop_turns=1
)
# Should result in empty or minimal list
assert len(result) == 0
def test_truncate_by_turns_below_threshold(self):
"""Test truncate_by_turns when messages are below threshold."""
truncator = ContextTruncator()
# Create 4 messages = 2 turns
messages = self.create_messages(4)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=5, drop_turns=1
)
# No truncation should happen
assert len(result) == 4
assert result == messages
def test_truncate_by_turns_exact_threshold(self):
"""Test truncate_by_turns when messages exactly match threshold."""
truncator = ContextTruncator()
# Create 6 messages = 3 turns
messages = self.create_messages(6)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=3, drop_turns=1
)
# No truncation should happen
assert len(result) == 6
assert result == messages
def test_truncate_by_turns_ensures_user_first(self):
"""Test that truncate_by_turns ensures user message comes first."""
truncator = ContextTruncator()
# Create scenario where truncation might start with assistant
messages = self.create_messages(20)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=3, drop_turns=1
)
# First non-system message should be user
assert result[0].role == "user"
def test_truncate_by_turns_multiple_drop(self):
"""Test truncate_by_turns with multiple turns dropped at once."""
truncator = ContextTruncator()
messages = self.create_messages(20)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=5, drop_turns=3
)
# Should drop 3 turns when limit exceeded
assert len(result) < len(messages)
# ==================== truncate_by_dropping_oldest_turns Tests ====================
def test_truncate_by_dropping_oldest_turns_zero(self):
"""Test truncate_by_dropping_oldest_turns with drop_turns=0."""
truncator = ContextTruncator()
messages = self.create_messages(10)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=0)
assert result == messages
def test_truncate_by_dropping_oldest_turns_negative(self):
"""Test truncate_by_dropping_oldest_turns with negative drop_turns."""
truncator = ContextTruncator()
messages = self.create_messages(10)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=-1)
assert result == messages
def test_truncate_by_dropping_oldest_turns_basic(self):
"""Test basic truncate_by_dropping_oldest_turns functionality."""
truncator = ContextTruncator()
# Create 10 messages = 5 turns
messages = self.create_messages(10)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
# Should drop 2 oldest turns (4 messages)
assert len(result) == 6
# Should start with user message
assert result[0].role == "user"
def test_truncate_by_dropping_oldest_turns_with_system(self):
"""Test truncate_by_dropping_oldest_turns preserves system messages."""
truncator = ContextTruncator()
messages = self.create_messages(10, include_system=True)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
# System message should be preserved
assert result[0].role == "system"
assert result[0].content == "System prompt"
def test_truncate_by_dropping_oldest_turns_drop_all(self):
"""Test truncate_by_dropping_oldest_turns dropping all turns."""
truncator = ContextTruncator()
# Create 4 messages = 2 turns
messages = self.create_messages(4)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
# Should drop all turns
assert len(result) == 0
def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self):
"""Test truncate_by_dropping_oldest_turns with drop_turns > available turns."""
truncator = ContextTruncator()
# Create 4 messages = 2 turns
messages = self.create_messages(4)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=5)
# Should result in empty list
assert len(result) == 0
def test_truncate_by_dropping_oldest_turns_ensures_user_first(self):
"""Test that result starts with user message after dropping."""
truncator = ContextTruncator()
messages = self.create_messages(20)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=3)
# First message should be user
if len(result) > 0:
assert result[0].role == "user"
# ==================== truncate_by_halving Tests ====================
def test_truncate_by_halving_empty(self):
"""Test truncate_by_halving with empty list."""
truncator = ContextTruncator()
result = truncator.truncate_by_halving([])
assert result == []
def test_truncate_by_halving_single_message(self):
"""Test truncate_by_halving with single message."""
truncator = ContextTruncator()
messages = [self.create_message("user", "Hello")]
result = truncator.truncate_by_halving(messages)
# Should not truncate if <= 2 messages
assert result == messages
def test_truncate_by_halving_two_messages(self):
"""Test truncate_by_halving with two messages."""
truncator = ContextTruncator()
messages = self.create_messages(2)
result = truncator.truncate_by_halving(messages)
# Should not truncate if <= 2 messages
assert result == messages
def test_truncate_by_halving_basic(self):
"""Test basic truncate_by_halving functionality."""
truncator = ContextTruncator()
# Create 20 messages
messages = self.create_messages(20)
result = truncator.truncate_by_halving(messages)
# Should delete 50% = 10 messages, keep 10
assert len(result) == 10
# First message should be user
assert result[0].role == "user"
def test_truncate_by_halving_with_system_message(self):
"""Test truncate_by_halving preserves system messages."""
truncator = ContextTruncator()
messages = self.create_messages(20, include_system=True)
result = truncator.truncate_by_halving(messages)
# System message should be preserved
assert result[0].role == "system"
assert result[0].content == "System prompt"
def test_truncate_by_halving_odd_count(self):
"""Test truncate_by_halving with odd number of messages."""
truncator = ContextTruncator()
messages = self.create_messages(11)
result = truncator.truncate_by_halving(messages)
# Should delete floor(11/2) = 5 messages, keep 6
# But after ensuring user first, may be 5
assert len(result) >= 5
assert result[0].role == "user"
def test_truncate_by_halving_ensures_user_first(self):
"""Test that result starts with user message."""
truncator = ContextTruncator()
# Create messages starting with user
messages = self.create_messages(30)
result = truncator.truncate_by_halving(messages)
# First message should be user
assert result[0].role == "user"
def test_truncate_by_halving_preserves_recent_messages(self):
"""Test that truncate_by_halving keeps the most recent 50%."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Message 0"),
self.create_message("assistant", "Message 1"),
self.create_message("user", "Message 2"),
self.create_message("assistant", "Message 3"),
]
result = truncator.truncate_by_halving(messages)
# Should keep last 2 messages
assert len(result) == 2
assert result[0].content == "Message 2"
assert result[1].content == "Message 3"
# ==================== Integration Tests ====================
def test_truncate_with_tool_messages(self):
"""Test truncation with tool messages."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool result"),
self.create_message("user", "Thanks"),
self.create_message("assistant", "Welcome"),
]
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=1)
# First turn (user+assistant+tool) should be dropped
# Tool message should be cleaned up by fix_messages
assert len(result) <= 2
def test_chain_multiple_truncations(self):
"""Test chaining multiple truncation methods."""
truncator = ContextTruncator()
messages = self.create_messages(40, include_system=True)
# First: truncate by turns
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=10, drop_turns=2
)
# Then: halve
result = truncator.truncate_by_halving(result)
# Should have system message + truncated content
assert result[0].role == "system"
assert len(result) < len(messages)
def test_empty_after_system_message(self):
"""Test truncation when only system message exists."""
truncator = ContextTruncator()
messages = [self.create_message("system", "System prompt")]
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=5, drop_turns=1
)
# Should keep system message
assert len(result) == 1
assert result[0].role == "system"
def test_all_system_messages(self):
"""Test truncation with only system messages."""
truncator = ContextTruncator()
messages = [
self.create_message("system", "System 1"),
self.create_message("system", "System 2"),
]
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=0, drop_turns=1
)
# System messages should be preserved, but since there are no non-system
# messages and keep_most_recent_turns=0, result should be system messages only
assert len(result) >= 0 # May keep system messages or clear all
if len(result) > 0:
assert all(msg.role == "system" for msg in result)