feat: context compress (#4322)
* feat: context compressor Co-authored-by: kawayiYokami <289104862@qq.com> * Add comprehensive tests for ContextManager and ContextTruncator - Implemented a full test suite for ContextManager covering initialization, message processing, token-based compression, and error handling. - Added tests for ContextTruncator focusing on message fixing, truncation by turns, dropping oldest turns, and halving. - Ensured that both test suites validate edge cases and maintain expected behavior with various message types, including system and tool messages. * feat: add MockProvider for LLM compression tests * chore: remove lock * ruff fix * fix * perf * feat: enhance context compression with token tracking and logging * feat: update logging for context compression trigger * feat: implement context compression logic with dynamic threshold and token tracking * fix: reorder import statements for consistency * feat: add token_usage tracking to conversations and update related processing logic --------- Co-authored-by: kawayiYokami <289104862@qq.com>
This commit is contained in:
@@ -0,0 +1,243 @@
|
||||
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||
|
||||
from ..message import Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot import logger
|
||||
else:
|
||||
try:
|
||||
from astrbot import logger
|
||||
except ImportError:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.truncator import ContextTruncator
|
||||
|
||||
|
||||
@runtime_checkable
class ContextCompressor(Protocol):
    """Structural interface that every context compressor satisfies.

    A compressor offers a synchronous predicate deciding whether the history
    has to shrink, plus an awaitable call that performs the shrinking.
    """

    def should_compress(
        self, messages: list[Message], current_tokens: int, max_tokens: int
    ) -> bool:
        """Decide whether the message list has to be compressed.

        Args:
            messages: The message list under evaluation.
            current_tokens: Number of tokens the context currently occupies.
            max_tokens: Upper token bound allowed for the model.

        Returns:
            True when compression should run, False otherwise.
        """
        ...

    async def __call__(self, messages: list[Message]) -> list[Message]:
        """Produce a compressed version of the message list.

        Args:
            messages: The uncompressed message list.

        Returns:
            The compressed message list.
        """
        ...
|
||||
|
||||
|
||||
class TruncateByTurnsCompressor:
    """Compressor that shrinks the history by discarding its oldest turns."""

    def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
        """Store the truncation settings.

        Args:
            truncate_turns: How many turns are removed per compression (default: 1).
            compression_threshold: Usage ratio (current / max tokens) above
                which compression is triggered (default: 0.82).
        """
        self.truncate_turns = truncate_turns
        self.compression_threshold = compression_threshold

    def should_compress(
        self, messages: list[Message], current_tokens: int, max_tokens: int
    ) -> bool:
        """Report whether token usage exceeds the configured threshold.

        Args:
            messages: The message list under evaluation (not inspected here).
            current_tokens: Number of tokens the context currently occupies.
            max_tokens: Upper token bound allowed for the model.

        Returns:
            True when ``current_tokens / max_tokens`` is strictly greater than
            the threshold; False when either figure is non-positive.
        """
        if current_tokens > 0 and max_tokens > 0:
            return current_tokens / max_tokens > self.compression_threshold
        # A non-positive figure gives no meaningful ratio: never compress.
        return False

    async def __call__(self, messages: list[Message]) -> list[Message]:
        """Drop the oldest ``truncate_turns`` turns and return the remainder."""
        return ContextTruncator().truncate_by_dropping_oldest_turns(
            messages,
            drop_turns=self.truncate_turns,
        )
|
||||
|
||||
|
||||
def split_history(
    messages: list[Message], keep_recent: int
) -> tuple[list[Message], list[Message], list[Message]]:
    """Split messages into system messages, messages to summarize, and recent messages.

    The split point is moved backward until the recent part starts with a
    user message, so the recent part may hold more than ``keep_recent``
    messages.

    Args:
        messages: The original message list.
        keep_recent: The number of latest messages to keep. Values below 1
            are treated as 1, which keeps the latest turn and keeps the
            split index inside the list.

    Returns:
        tuple: (system_messages, messages_to_summarize, recent_messages).
        ``messages_to_summarize`` is empty when there is nothing old enough
        to summarize or when no user message can start the recent part.
    """
    # Leading system messages are always kept. Defaulting to len(messages)
    # stops an all-system list from being treated as conversation history.
    first_non_system = next(
        (i for i, msg in enumerate(messages) if msg.role != "system"),
        len(messages),
    )
    system_messages = messages[:first_non_system]
    non_system_messages = messages[first_non_system:]

    # With keep_recent <= 0 the split index would equal (or exceed) the list
    # length and the lookup below would raise IndexError.
    keep_recent = max(keep_recent, 1)

    if len(non_system_messages) <= keep_recent:
        return system_messages, [], non_system_messages

    split_index = len(non_system_messages) - keep_recent

    # Walk backward to the nearest user message so that the recent part
    # begins with a complete turn.
    # TODO: +=1 or -=1 ? calculate by tokens
    while split_index > 0 and non_system_messages[split_index].role != "user":
        split_index -= 1

    # Reaching index 0 leaves nothing to summarize: keep everything as recent.
    if split_index == 0:
        return system_messages, [], non_system_messages

    messages_to_summarize = non_system_messages[:split_index]
    recent_messages = non_system_messages[split_index:]
    return system_messages, messages_to_summarize, recent_messages
|
||||
|
||||
|
||||
class LLMSummaryCompressor:
    """LLM-based summary compressor.

    Asks an LLM to summarize the old conversation history and keeps the
    latest messages verbatim.
    """

    def __init__(
        self,
        provider: "Provider",
        keep_recent: int = 4,
        instruction_text: str | None = None,
        compression_threshold: float = 0.82,
    ):
        """Initialize the LLM summary compressor.

        Args:
            provider: The LLM provider instance used to generate the summary.
            keep_recent: The number of latest messages to keep (default: 4).
            instruction_text: Custom instruction for summary generation. A
                built-in instruction is used when this is None or empty.
            compression_threshold: Usage ratio (current / max tokens) above
                which compression is triggered (default: 0.82).
        """
        self.provider = provider
        self.keep_recent = keep_recent
        self.compression_threshold = compression_threshold

        self.instruction_text = instruction_text or (
            "Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
            "1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
            "2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
            "3. If there was an initial user goal, state it first and describe the current progress/status.\n"
            "4. Write the summary in the user's language.\n"
        )

    def should_compress(
        self, messages: list[Message], current_tokens: int, max_tokens: int
    ) -> bool:
        """Check if compression is needed.

        Args:
            messages: The message list to evaluate (not inspected here).
            current_tokens: The current token count.
            max_tokens: The maximum allowed tokens.

        Returns:
            True when ``current_tokens / max_tokens`` is strictly greater than
            the threshold; False when either figure is non-positive.
        """
        if max_tokens <= 0 or current_tokens <= 0:
            return False
        usage_rate = current_tokens / max_tokens
        return usage_rate > self.compression_threshold

    async def __call__(self, messages: list[Message]) -> list[Message]:
        """Use the LLM to generate a summary of the conversation history.

        Process:
            1. Divide messages: keep the system messages and the latest N messages.
            2. Send the old messages + the instruction message to the LLM.
            3. Reconstruct the message list:
               [system messages, summary message, acknowledgement, latest messages].

        Args:
            messages: The original message list.

        Returns:
            The compressed message list, or ``messages`` unchanged when there
            is nothing to summarize, the provider call fails, or the provider
            returns an empty summary.
        """
        if len(messages) <= self.keep_recent + 1:
            return messages

        system_messages, messages_to_summarize, recent_messages = split_history(
            messages, self.keep_recent
        )

        if not messages_to_summarize:
            return messages

        # build payload
        instruction_message = Message(role="user", content=self.instruction_text)
        llm_payload = messages_to_summarize + [instruction_message]

        # generate summary
        try:
            response = await self.provider.text_chat(contexts=llm_payload)
            summary_content = response.completion_text
        except Exception as e:
            # Best effort: a failed summary must not break the conversation.
            logger.error(f"Failed to generate summary: {e}", exc_info=True)
            return messages

        # An empty summary would silently discard the old history, so keep
        # the original context and let the caller fall back to truncation.
        if not summary_content or not str(summary_content).strip():
            logger.warning(
                "LLM returned an empty summary, keeping the original context."
            )
            return messages

        # build result
        result = []
        result.extend(system_messages)
        result.append(
            Message(
                role="user",
                content=f"Our previous history conversation summary: {summary_content}",
            )
        )
        # The acknowledgement keeps the user/assistant alternation intact
        # before the recent messages, which start with a user message.
        result.append(
            Message(
                role="assistant",
                content="Acknowledged the summary of our previous conversation history.",
            )
        )
        result.extend(recent_messages)

        return result
|
||||
@@ -0,0 +1,35 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .compressor import ContextCompressor
|
||||
from .token_counter import TokenCounter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
|
||||
@dataclass
class ContextConfig:
    """Settings that drive context truncation and compression."""

    # Maximum number of context tokens. A value <= 0 means no limit.
    max_context_tokens: int = 0

    # Maximum number of conversation turns to keep; -1 means no limit.
    # This limit is enforced before compression.
    enforce_max_turns: int = -1

    # Number of conversation turns discarded at once when truncation is
    # triggered. Two processes use this value:
    #   1. Enforce-max-turns truncation.
    #   2. The truncate-by-turns compression strategy.
    truncate_turns: int = 1

    # Instruction prompt for LLM-based compression.
    llm_compress_instruction: str | None = None

    # Number of recent messages to keep during LLM-based compression.
    llm_compress_keep_recent: int = 0

    # LLM provider used for compression tasks. When None, the truncation
    # strategy is used instead.
    llm_compress_provider: "Provider | None" = None

    # Custom token counting method. When None, the default method is used.
    custom_token_counter: TokenCounter | None = None

    # Custom context compression method. When None, the default method is used.
    custom_compressor: ContextCompressor | None = None
|
||||
@@ -0,0 +1,120 @@
|
||||
from astrbot import logger
|
||||
|
||||
from ..message import Message
|
||||
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
|
||||
from .config import ContextConfig
|
||||
from .token_counter import EstimateTokenCounter
|
||||
from .truncator import ContextTruncator
|
||||
|
||||
|
||||
class ContextManager:
    """Coordinates turn-based truncation and token-based compression."""

    def __init__(
        self,
        config: ContextConfig,
    ):
        """Wire up the token counter, truncator and compressor.

        Two strategies exist for handling a reached context limit:
            1. Truncate by turns: remove older messages turn by turn.
            2. LLM-based compression: let an LLM summarize old messages.

        Args:
            config: The context configuration.
        """
        self.config = config
        self.token_counter = config.custom_token_counter or EstimateTokenCounter()
        self.truncator = ContextTruncator()
        self.compressor = self._build_compressor(config)

    @staticmethod
    def _build_compressor(config: ContextConfig):
        """Pick the compressor: custom first, then LLM summary, then truncation."""
        if config.custom_compressor:
            return config.custom_compressor
        if config.llm_compress_provider:
            return LLMSummaryCompressor(
                provider=config.llm_compress_provider,
                keep_recent=config.llm_compress_keep_recent,
                instruction_text=config.llm_compress_instruction,
            )
        return TruncateByTurnsCompressor(truncate_turns=config.truncate_turns)

    async def process(
        self, messages: list[Message], trusted_token_usage: int = 0
    ) -> list[Message]:
        """Truncate and, if required, compress the messages.

        Args:
            messages: The original message list.
            trusted_token_usage: Token usage handed to the token counter
                together with the messages (default: 0).

        Returns:
            The processed message list. When any step raises, the error is
            logged and the original ``messages`` are returned.
        """
        try:
            settings = self.config
            current = messages

            # 1. Turn-based truncation (enforce max turns).
            if settings.enforce_max_turns != -1:
                current = self.truncator.truncate_by_turns(
                    current,
                    keep_most_recent_turns=settings.enforce_max_turns,
                    drop_turns=settings.truncate_turns,
                )

            # 2. Token-based compression, only when a token limit is set.
            if settings.max_context_tokens <= 0:
                return current

            used_tokens = self.token_counter.count_tokens(
                current, trusted_token_usage
            )
            needs_compression = self.compressor.should_compress(
                current, used_tokens, settings.max_context_tokens
            )
            if not needs_compression:
                return current

            return await self._run_compression(current, used_tokens)
        except Exception as e:
            logger.error(f"Error during context processing: {e}", exc_info=True)
            return messages

    async def _run_compression(
        self, messages: list[Message], prev_tokens: int
    ) -> list[Message]:
        """Compress/truncate the messages.

        Args:
            messages: The original message list.
            prev_tokens: The token count before compression.

        Returns:
            The compressed/truncated message list.
        """
        logger.debug("Compress triggered, starting compression...")

        compressed = await self.compressor(messages)

        # Re-count after compression; no trusted usage is passed here.
        tokens_after_summary = self.token_counter.count_tokens(compressed)
        limit = self.config.max_context_tokens

        # The logged percentage is tokens-after-compression over the limit.
        compress_rate = (tokens_after_summary / limit) * 100
        logger.info(
            f"Compress completed."
            f" {prev_tokens} -> {tokens_after_summary} tokens,"
            f" compression rate: {compress_rate:.2f}%.",
        )

        # Last check: fall back to halving when still over the threshold.
        still_too_large = self.compressor.should_compress(
            compressed, tokens_after_summary, limit
        )
        if still_too_large:
            logger.info(
                "Context still exceeds max tokens after compression, applying halving truncation..."
            )
            compressed = self.truncator.truncate_by_halving(compressed)

        return compressed
|
||||
@@ -0,0 +1,64 @@
|
||||
import json
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from ..message import Message, TextPart
|
||||
|
||||
|
||||
@runtime_checkable
class TokenCounter(Protocol):
    """Structural interface for objects that count tokens in message lists."""

    def count_tokens(
        self, messages: list[Message], trusted_token_usage: int = 0
    ) -> int:
        """Return the total number of tokens in the message list.

        Args:
            messages: The message list.
            trusted_token_usage: The total token usage reported by the LLM
                API. In some cases this value is more accurate, but not every
                API returns it, hence the default of 0.

        Returns:
            The total token count.
        """
        ...
|
||||
|
||||
|
||||
class EstimateTokenCounter:
    """Estimate token counter implementation.

    Provides a simple estimation of token count based on character types.
    """

    def count_tokens(
        self, messages: list[Message], trusted_token_usage: int = 0
    ) -> int:
        """Estimate the total tokens of the message list.

        Args:
            messages: The message list.
            trusted_token_usage: When positive, returned as-is and the
                messages are not inspected.

        Returns:
            ``trusted_token_usage`` when it is positive; otherwise the summed
            estimate over string content, text parts of list content, and
            JSON-serialized tool calls.
        """
        if trusted_token_usage > 0:
            return trusted_token_usage

        total = 0
        for msg in messages:
            content = msg.content
            if isinstance(content, str):
                total += self._estimate_tokens(content)
            elif isinstance(content, list):
                # Multi-part content: only text parts are counted.
                total += sum(
                    self._estimate_tokens(part.text)
                    for part in content
                    if isinstance(part, TextPart)
                )

            # Tool calls are counted through their JSON representation.
            if msg.tool_calls:
                for tc in msg.tool_calls:
                    payload = tc if isinstance(tc, dict) else tc.model_dump()
                    # ensure_ascii=False keeps CJK characters intact; the
                    # default would escape each one into six ASCII characters
                    # and inflate the estimate.
                    tc_str = json.dumps(payload, ensure_ascii=False)
                    total += self._estimate_tokens(tc_str)

        return total

    def _estimate_tokens(self, text: str) -> int:
        """Estimate tokens as 0.6 per CJK ideograph (U+4E00-U+9FFF) and 0.3 per other character."""
        chinese_count = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
        other_count = len(text) - chinese_count
        return int(chinese_count * 0.6 + other_count * 0.3)
|
||||
@@ -0,0 +1,141 @@
|
||||
from ..message import Message
|
||||
|
||||
|
||||
class ContextTruncator:
    """Context truncator.

    Offers several strategies for dropping old messages while keeping the
    leading system messages and starting the remainder at a user message.
    """

    @staticmethod
    def _split_system(
        messages: list[Message],
    ) -> tuple[list[Message], list[Message]]:
        """Split messages into (leading system messages, everything after)."""
        # Defaulting to len(messages) keeps an all-system list classified as
        # system messages instead of conversation history.
        boundary = next(
            (i for i, msg in enumerate(messages) if msg.role != "system"),
            len(messages),
        )
        return messages[:boundary], messages[boundary:]

    @staticmethod
    def _first_user_index(messages: list[Message]) -> int | None:
        """Return the index of the first user message, or None if there is none."""
        return next(
            (i for i, item in enumerate(messages) if item.role == "user"),
            None,
        )

    def fix_messages(self, messages: list[Message]) -> list[Message]:
        """Remove tool messages that lack enough preceding context.

        A tool message is kept only when at least two messages have already
        been kept before it. Otherwise everything collected so far is
        discarded together with that tool message.

        Args:
            messages: The message list to repair.

        Returns:
            The repaired message list.
        """
        fixed_messages = []
        for message in messages:
            if message.role == "tool":
                # A tool block must be preceded by a user and an assistant block.
                if len(fixed_messages) < 2:
                    # This situation is probably caused by truncation, so the
                    # context gathered so far is cleared.
                    # NOTE(review): this also discards system messages
                    # collected so far - confirm that this is intended.
                    fixed_messages = []
                else:
                    fixed_messages.append(message)
            else:
                fixed_messages.append(message)
        return fixed_messages

    def truncate_by_turns(
        self,
        messages: list[Message],
        keep_most_recent_turns: int,
        drop_turns: int = 1,
    ) -> list[Message]:
        """Truncate the context so that it does not exceed the turn limit.

        One turn consists of one user message and one assistant message.
        The truncated part is aligned to start with a user message when one
        exists, and the result is passed through ``fix_messages``.

        Args:
            messages: The context list.
            keep_most_recent_turns: Number of most recent turns to keep;
                -1 disables truncation.
            drop_turns: Number of turns dropped at once when the limit is
                exceeded.

        Returns:
            The truncated context list. ``messages`` is returned unchanged
            when truncation is disabled or the limit is not exceeded.
        """
        if keep_most_recent_turns == -1:
            return messages

        system_messages, non_system_messages = self._split_system(messages)

        if len(non_system_messages) // 2 <= keep_most_recent_turns:
            return messages

        # Drop extra turns beyond the limit so that truncation does not run
        # again on every following turn.
        num_to_keep = keep_most_recent_turns - drop_turns + 1
        if num_to_keep <= 0:
            truncated_contexts = []
        else:
            truncated_contexts = non_system_messages[-num_to_keep * 2 :]

        # Start at the first user message to keep the context well-formed.
        index = self._first_user_index(truncated_contexts)
        if index is not None and index > 0:
            truncated_contexts = truncated_contexts[index:]

        return self.fix_messages(system_messages + truncated_contexts)

    def truncate_by_dropping_oldest_turns(
        self,
        messages: list[Message],
        drop_turns: int = 1,
    ) -> list[Message]:
        """Drop the oldest N conversation turns.

        Args:
            messages: The context list.
            drop_turns: Number of oldest turns to drop; values <= 0 return
                ``messages`` unchanged.

        Returns:
            The system messages followed by the remaining history starting
            at its first user message. The history is empty when it holds no
            more than ``drop_turns`` turns or no user message remains.
        """
        if drop_turns <= 0:
            return messages

        system_messages, non_system_messages = self._split_system(messages)

        if len(non_system_messages) // 2 <= drop_turns:
            truncated_non_system = []
        else:
            truncated_non_system = non_system_messages[drop_turns * 2 :]

        # Without any user message nothing usable is left.
        index = self._first_user_index(truncated_non_system)
        truncated_non_system = (
            truncated_non_system[index:] if index is not None else []
        )

        return self.fix_messages(system_messages + truncated_non_system)

    def truncate_by_halving(
        self,
        messages: list[Message],
    ) -> list[Message]:
        """Halving strategy: delete the older 50% of the non-system messages.

        Args:
            messages: The context list.

        Returns:
            The system messages followed by the newer half of the history,
            aligned to its first user message when one exists. ``messages``
            is returned unchanged when it has at most two entries or there
            is nothing to delete.
        """
        if len(messages) <= 2:
            return messages

        system_messages, non_system_messages = self._split_system(messages)

        messages_to_delete = len(non_system_messages) // 2
        if messages_to_delete == 0:
            return messages

        truncated_non_system = non_system_messages[messages_to_delete:]

        index = self._first_user_index(truncated_non_system)
        if index is not None:
            truncated_non_system = truncated_non_system[index:]

        return self.fix_messages(system_messages + truncated_non_system)
|
||||
@@ -25,6 +25,10 @@ from astrbot.core.provider.entities import (
|
||||
)
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.compressor import ContextCompressor
|
||||
from ..context.config import ContextConfig
|
||||
from ..context.manager import ContextManager
|
||||
from ..context.token_counter import TokenCounter
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||
from ..response import AgentResponseData, AgentStats
|
||||
@@ -47,10 +51,47 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
streaming: bool = False,
|
||||
# enforce max turns, will discard older turns when exceeded BEFORE compression
|
||||
# -1 means no limit
|
||||
enforce_max_turns: int = -1,
|
||||
# llm compressor
|
||||
llm_compress_instruction: str | None = None,
|
||||
llm_compress_keep_recent: int = 0,
|
||||
llm_compress_provider: Provider | None = None,
|
||||
# truncate by turns compressor
|
||||
truncate_turns: int = 1,
|
||||
# customize
|
||||
custom_token_counter: TokenCounter | None = None,
|
||||
custom_compressor: ContextCompressor | None = None,
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
self.streaming = streaming
|
||||
self.enforce_max_turns = enforce_max_turns
|
||||
self.llm_compress_instruction = llm_compress_instruction
|
||||
self.llm_compress_keep_recent = llm_compress_keep_recent
|
||||
self.llm_compress_provider = llm_compress_provider
|
||||
self.truncate_turns = truncate_turns
|
||||
self.custom_token_counter = custom_token_counter
|
||||
self.custom_compressor = custom_compressor
|
||||
# we will do compress when:
|
||||
# 1. before requesting LLM
|
||||
# TODO: 2. after LLM output a tool call
|
||||
self.context_config = ContextConfig(
|
||||
# <=0 will never do compress
|
||||
max_context_tokens=provider.provider_config.get("max_context_tokens", 0),
|
||||
# enforce max turns before compression
|
||||
enforce_max_turns=self.enforce_max_turns,
|
||||
truncate_turns=self.truncate_turns,
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider=self.llm_compress_provider,
|
||||
custom_token_counter=self.custom_token_counter,
|
||||
custom_compressor=self.custom_compressor,
|
||||
)
|
||||
self.context_manager = ContextManager(self.context_config)
|
||||
|
||||
self.provider = provider
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
@@ -110,6 +151,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
llm_resp_result = None
|
||||
|
||||
# do truncate and compress
|
||||
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
|
||||
self.run_context.messages = await self.context_manager.process(
|
||||
self.run_context.messages, trusted_token_usage=token_usage
|
||||
)
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
if llm_response.is_chunk:
|
||||
# update ttft
|
||||
|
||||
@@ -83,6 +83,16 @@ DEFAULT_CONFIG = {
|
||||
"default_personality": "default",
|
||||
"persona_pool": ["*"],
|
||||
"prompt_prefix": "{{prompt}}",
|
||||
"context_limit_reached_strategy": "truncate_by_turns", # or llm_compress
|
||||
"llm_compress_instruction": (
|
||||
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
|
||||
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
|
||||
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
|
||||
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
|
||||
"4. Write the summary in the user's language.\n"
|
||||
),
|
||||
"llm_compress_keep_recent": 4,
|
||||
"llm_compress_provider_id": "",
|
||||
"max_context_length": -1,
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
@@ -179,6 +189,7 @@ class ChatProviderTemplate(TypedDict):
|
||||
model: str
|
||||
modalities: list
|
||||
custom_extra_body: dict[str, Any]
|
||||
max_context_tokens: int
|
||||
|
||||
|
||||
CHAT_PROVIDER_TEMPLATE = {
|
||||
@@ -187,6 +198,7 @@ CHAT_PROVIDER_TEMPLATE = {
|
||||
"model": "",
|
||||
"modalities": [],
|
||||
"custom_extra_body": {},
|
||||
"max_context_tokens": 0,
|
||||
}
|
||||
|
||||
"""
|
||||
@@ -2033,6 +2045,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
|
||||
},
|
||||
"max_context_tokens": {
|
||||
"description": "模型上下文窗口大小",
|
||||
"type": "int",
|
||||
"hint": "模型最大上下文 Token 大小。如果为 0,则会自动从模型元数据填充(如有),也可手动修改。",
|
||||
},
|
||||
"dify_api_key": {
|
||||
"description": "API Key",
|
||||
"type": "string",
|
||||
@@ -2540,6 +2557,66 @@ CONFIG_METADATA_3 = {
|
||||
# "provider_settings.enable": True,
|
||||
# },
|
||||
# },
|
||||
"truncate_and_compress": {
|
||||
"description": "上下文管理策略",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"provider_settings.max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.context_limit_reached_strategy": {
|
||||
"description": "超出模型上下文窗口时的处理方式",
|
||||
"type": "string",
|
||||
"options": ["truncate_by_turns", "llm_compress"],
|
||||
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
"hint": "",
|
||||
},
|
||||
"provider_settings.llm_compress_instruction": {
|
||||
"description": "上下文压缩提示词",
|
||||
"type": "text",
|
||||
"hint": "如果为空则使用默认提示词。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.llm_compress_keep_recent": {
|
||||
"description": "压缩时保留最近对话轮数",
|
||||
"type": "int",
|
||||
"hint": "始终保留的最近 N 轮对话。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.llm_compress_provider_id": {
|
||||
"description": "用于上下文压缩的模型提供商 ID",
|
||||
"type": "string",
|
||||
"_special": "select_provider",
|
||||
"hint": "留空时将降级为“按对话轮数截断”的策略。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
"type": "object",
|
||||
@@ -2604,22 +2681,6 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.streaming_response": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀 ",
|
||||
"type": "string",
|
||||
|
||||
@@ -69,6 +69,7 @@ class ConversationManager:
|
||||
persona_id=conv_v2.persona_id,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
token_usage=conv_v2.token_usage,
|
||||
)
|
||||
|
||||
async def new_conversation(
|
||||
@@ -256,6 +257,7 @@ class ConversationManager:
|
||||
history: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
token_usage: int | None = None,
|
||||
) -> None:
|
||||
"""更新会话的对话.
|
||||
|
||||
@@ -263,6 +265,7 @@ class ConversationManager:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
|
||||
token_usage (int | None): token 使用量。None 表示不更新
|
||||
|
||||
"""
|
||||
if not conversation_id:
|
||||
@@ -274,6 +277,7 @@ class ConversationManager:
|
||||
title=title,
|
||||
persona_id=persona_id,
|
||||
content=history,
|
||||
token_usage=token_usage,
|
||||
)
|
||||
|
||||
async def update_conversation_title(
|
||||
|
||||
@@ -152,6 +152,7 @@ class BaseDatabase(abc.ABC):
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
content: list[dict] | None = None,
|
||||
token_usage: int | None = None,
|
||||
) -> None:
|
||||
"""Update a conversation's history."""
|
||||
...
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
"""Migration script to add token_usage column to conversations table.
|
||||
|
||||
This migration adds the token_usage field to track token consumption for each conversation.
|
||||
|
||||
Changes:
|
||||
- Adds token_usage column to conversations table (default: 0)
|
||||
"""
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
|
||||
async def migrate_token_usage(db_helper: BaseDatabase):
    """Add the token_usage column to the conversations table.

    The migration is skipped when the completion flag is already stored.
    Otherwise the table schema is inspected and the column is added only if
    missing; the completion flag is stored afterwards. Errors are logged
    and re-raised.

    Args:
        db_helper: Database helper providing ``get_preference`` and ``get_db``.
    """
    done_key = "migration_done_token_usage_1"

    # Skip when the migration has already been completed.
    if await db_helper.get_preference("global", "global", done_key):
        return

    logger.info("开始执行数据库迁移(添加 conversations.token_usage 列)...")

    # Only SQLite is handled here, since AstrBot supports only SQLite as of
    # this version.
    try:
        async with db_helper.get_db() as session:
            # Inspect the schema; index 1 of every row is read as the
            # column name.
            table_info = await session.execute(
                text("PRAGMA table_info(conversations)")
            )
            existing_columns = [row[1] for row in table_info.fetchall()]

            if "token_usage" in existing_columns:
                logger.info("token_usage 列已存在,跳过迁移")
                await sp.put_async("global", "global", done_key, True)
                return

            # Add the token_usage column.
            add_column = text(
                "ALTER TABLE conversations ADD COLUMN token_usage INTEGER NOT NULL DEFAULT 0"
            )
            await session.execute(add_column)
            await session.commit()

            logger.info("token_usage 列添加成功")

        # Mark the migration as completed.
        await sp.put_async("global", "global", done_key, True)
        logger.info("token_usage 迁移完成")

    except Exception as e:
        logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
        raise
|
||||
@@ -54,6 +54,11 @@ class ConversationV2(SQLModel, table=True):
|
||||
)
|
||||
title: str | None = Field(default=None, max_length=255)
|
||||
persona_id: str | None = Field(default=None)
|
||||
token_usage: int = Field(default=0, nullable=False)
|
||||
"""content is a list of OpenAI-formated messages in list[dict] format.
|
||||
token_usage is the total token value of the messages.
|
||||
when 0, will use estimated token counter.
|
||||
"""
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
@@ -313,6 +318,8 @@ class Conversation:
|
||||
persona_id: str | None = ""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
token_usage: int = 0
|
||||
"""对话的总 token 数量。AstrBot 会保留最近一次 LLM 请求返回的总 token 数,方便统计。token_usage 可能为 0,表示未知。"""
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
|
||||
@@ -241,7 +241,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
session.add(new_conversation)
|
||||
return new_conversation
|
||||
|
||||
async def update_conversation(self, cid, title=None, persona_id=None, content=None):
|
||||
async def update_conversation(
|
||||
self, cid, title=None, persona_id=None, content=None, token_usage=None
|
||||
):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
@@ -255,6 +257,8 @@ class SQLiteDatabase(BaseDatabase):
|
||||
values["persona_id"] = persona_id
|
||||
if content is not None:
|
||||
values["content"] = content
|
||||
if token_usage is not None:
|
||||
values["token_usage"] = token_usage
|
||||
if not values:
|
||||
return None
|
||||
query = query.values(**values)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""本地 Agent 模式的 LLM 调用 Stage"""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.response import AgentStats
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
@@ -24,6 +24,7 @@ from astrbot.core.provider.entities import (
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType, star_map
|
||||
from astrbot.core.utils.file_extract import extract_file_moonshotai
|
||||
from astrbot.core.utils.llm_metadata import LLM_METADATAS
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
@@ -41,11 +42,6 @@ class InternalAgentSubStage(Stage):
|
||||
self.ctx = ctx
|
||||
conf = ctx.astrbot_config
|
||||
settings = conf["provider_settings"]
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
)
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.unsupported_streaming_strategy: str = settings[
|
||||
"unsupported_streaming_strategy"
|
||||
@@ -65,6 +61,25 @@ class InternalAgentSubStage(Stage):
|
||||
"moonshotai_api_key", ""
|
||||
)
|
||||
|
||||
# 上下文管理相关
|
||||
self.context_limit_reached_strategy: str = settings.get(
|
||||
"context_limit_reached_strategy", "truncate_by_turns"
|
||||
)
|
||||
self.llm_compress_instruction: str = settings.get(
|
||||
"llm_compress_instruction", ""
|
||||
)
|
||||
self.llm_compress_keep_recent: int = settings.get("llm_compress_keep_recent", 4)
|
||||
self.llm_compress_provider_id: str = settings.get(
|
||||
"llm_compress_provider_id", ""
|
||||
)
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
)
|
||||
if self.dequeue_context_length <= 0:
|
||||
self.dequeue_context_length = 1
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
def _select_provider(self, event: AstrMessageEvent):
|
||||
@@ -167,34 +182,6 @@ class InternalAgentSubStage(Stage):
|
||||
},
|
||||
)
|
||||
|
||||
def _truncate_contexts(
|
||||
self,
|
||||
contexts: list[dict],
|
||||
) -> list[dict]:
|
||||
"""截断上下文列表,确保不超过最大长度"""
|
||||
if self.max_context_length == -1:
|
||||
return contexts
|
||||
|
||||
if len(contexts) // 2 <= self.max_context_length:
|
||||
return contexts
|
||||
|
||||
truncated_contexts = contexts[
|
||||
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
||||
]
|
||||
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(
|
||||
i
|
||||
for i, item in enumerate(truncated_contexts)
|
||||
if item.get("role") == "user"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
truncated_contexts = truncated_contexts[index:]
|
||||
|
||||
return truncated_contexts
|
||||
|
||||
def _modalities_fix(
|
||||
self,
|
||||
provider: Provider,
|
||||
@@ -296,6 +283,7 @@ class InternalAgentSubStage(Stage):
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse | None,
|
||||
all_messages: list[Message],
|
||||
runner_stats: AgentStats | None,
|
||||
):
|
||||
if (
|
||||
not req
|
||||
@@ -322,27 +310,37 @@ class InternalAgentSubStage(Stage):
|
||||
continue
|
||||
message_to_save.append(message.model_dump())
|
||||
|
||||
# get token usage from agent runner stats
|
||||
token_usage = None
|
||||
if runner_stats:
|
||||
token_usage = runner_stats.token_usage.total
|
||||
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
req.conversation.cid,
|
||||
history=message_to_save,
|
||||
token_usage=token_usage,
|
||||
)
|
||||
|
||||
def _fix_messages(self, messages: list[dict]) -> list[dict]:
|
||||
"""验证并且修复上下文"""
|
||||
fixed_messages = []
|
||||
for message in messages:
|
||||
if message.get("role") == "tool":
|
||||
# tool block 前面必须要有 user 和 assistant block
|
||||
if len(fixed_messages) < 2:
|
||||
# 这种情况可能是上下文被截断导致的
|
||||
# 我们直接将之前的上下文都清空
|
||||
fixed_messages = []
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
return fixed_messages
|
||||
def _get_compress_provider(self) -> Provider | None:
|
||||
if not self.llm_compress_provider_id:
|
||||
return None
|
||||
if self.context_limit_reached_strategy != "llm_compress":
|
||||
return None
|
||||
provider = self.ctx.plugin_manager.context.get_provider_by_id(
|
||||
self.llm_compress_provider_id,
|
||||
)
|
||||
if provider is None:
|
||||
logger.warning(
|
||||
f"未找到指定的上下文压缩模型 {self.llm_compress_provider_id},将跳过压缩。",
|
||||
)
|
||||
return None
|
||||
if not isinstance(provider, Provider):
|
||||
logger.warning(
|
||||
f"指定的上下文压缩模型 {self.llm_compress_provider_id} 不是对话模型,将跳过压缩。"
|
||||
)
|
||||
return None
|
||||
return provider
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, provider_wake_prefix: str
|
||||
@@ -426,9 +424,10 @@ class InternalAgentSubStage(Stage):
|
||||
await self._apply_kb(event, req)
|
||||
|
||||
# truncate contexts to fit max length
|
||||
if req.contexts:
|
||||
req.contexts = self._truncate_contexts(req.contexts)
|
||||
self._fix_messages(req.contexts)
|
||||
# NOW moved to ContextManager inside ToolLoopAgentRunner
|
||||
# if req.contexts:
|
||||
# req.contexts = self._truncate_contexts(req.contexts)
|
||||
# self._fix_messages(req.contexts)
|
||||
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
@@ -444,8 +443,6 @@ class InternalAgentSubStage(Stage):
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
# 备份 req.contexts
|
||||
backup_contexts = copy.deepcopy(req.contexts)
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
@@ -456,6 +453,15 @@ class InternalAgentSubStage(Stage):
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
|
||||
# inject model context length limit
|
||||
if provider.provider_config.get("max_context_tokens", 0) <= 0:
|
||||
model = provider.get_model()
|
||||
if model_info := LLM_METADATAS.get(model):
|
||||
provider.provider_config["max_context_tokens"] = model_info[
|
||||
"limit"
|
||||
]["context"]
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
@@ -466,6 +472,11 @@ class InternalAgentSubStage(Stage):
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=streaming_response,
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider=self._get_compress_provider(),
|
||||
truncate_turns=self.dequeue_context_length,
|
||||
enforce_max_turns=self.max_context_length,
|
||||
)
|
||||
|
||||
if streaming_response and not stream_to_general:
|
||||
@@ -511,14 +522,12 @@ class InternalAgentSubStage(Stage):
|
||||
):
|
||||
yield
|
||||
|
||||
# 恢复备份的 contexts
|
||||
req.contexts = backup_contexts
|
||||
|
||||
await self._save_to_history(
|
||||
event,
|
||||
req,
|
||||
agent_runner.get_final_llm_resp(),
|
||||
agent_runner.run_context.messages,
|
||||
agent_runner.stats,
|
||||
)
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
|
||||
@@ -149,9 +149,12 @@ class Context:
|
||||
contexts: context messages for the LLM
|
||||
max_steps: Maximum number of tool calls before stopping the loop
|
||||
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
|
||||
stream: bool - whether to stream the LLM response
|
||||
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
|
||||
agent_context: AstrAgentContext - context to use for the agent
|
||||
|
||||
other kwargs will be DIRECTLY passed to the runner.reset() method
|
||||
|
||||
Returns:
|
||||
The final LLMResponse after tool calls are completed.
|
||||
|
||||
@@ -194,6 +197,15 @@ class Context:
|
||||
)
|
||||
agent_runner = ToolLoopAgentRunner()
|
||||
tool_executor = FunctionToolExecutor()
|
||||
|
||||
streaming = kwargs.get("stream", False)
|
||||
|
||||
other_kwargs = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k not in ["stream", "agent_hooks", "agent_context"]
|
||||
}
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=prov,
|
||||
request=request,
|
||||
@@ -203,7 +215,8 @@ class Context:
|
||||
),
|
||||
tool_executor=tool_executor,
|
||||
agent_hooks=agent_hooks,
|
||||
streaming=kwargs.get("stream", False),
|
||||
streaming=streaming,
|
||||
**other_kwargs,
|
||||
)
|
||||
async for _ in agent_runner.step_until_done(max_steps):
|
||||
pass
|
||||
|
||||
@@ -3,6 +3,7 @@ import traceback
|
||||
from astrbot.core import astrbot_config, logger
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager
|
||||
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
|
||||
from astrbot.core.db.migration.migra_token_usage import migrate_token_usage
|
||||
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
|
||||
|
||||
|
||||
@@ -139,6 +140,13 @@ async def migra(
|
||||
logger.error(f"Migration for webchat session failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# migration for token_usage column
|
||||
try:
|
||||
await migrate_token_usage(db)
|
||||
except Exception as e:
|
||||
logger.error(f"Migration for token_usage column failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# migra third party agent runner configs
|
||||
_c = False
|
||||
providers = astrbot_config["provider"]
|
||||
|
||||
@@ -144,7 +144,7 @@
|
||||
color="primary"
|
||||
density="compact"
|
||||
hide-details
|
||||
class="flex-grow-1"
|
||||
style="flex: 1"
|
||||
></v-slider>
|
||||
<v-text-field
|
||||
:model-value="modelValue"
|
||||
@@ -154,7 +154,7 @@
|
||||
class="config-field"
|
||||
type="number"
|
||||
hide-details
|
||||
style="max-width: 140px;"
|
||||
style="flex: 1"
|
||||
></v-text-field>
|
||||
</div>
|
||||
|
||||
@@ -325,4 +325,8 @@ function getSpecialSubtype(value) {
|
||||
.gap-20 {
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
:deep(.v-field__input) {
|
||||
font-size: 14px;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -510,7 +510,7 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
|
||||
|
||||
const metadata = getModelMetadata(modelName)
|
||||
let modalities: string[]
|
||||
|
||||
|
||||
if (!metadata) {
|
||||
modalities = ['text', 'image', 'tool_use']
|
||||
} else {
|
||||
@@ -523,13 +523,19 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
|
||||
}
|
||||
}
|
||||
|
||||
let max_context_tokens = 0
|
||||
if (metadata?.limit?.context && typeof metadata.limit.context === 'number') {
|
||||
max_context_tokens = metadata.limit.context
|
||||
}
|
||||
|
||||
const newProvider = {
|
||||
id: newId,
|
||||
enable: false,
|
||||
provider_source_id: sourceId,
|
||||
model: modelName,
|
||||
modalities,
|
||||
custom_extra_body: {}
|
||||
custom_extra_body: {},
|
||||
max_context_tokens: max_context_tokens
|
||||
}
|
||||
|
||||
try {
|
||||
|
||||
@@ -11,7 +11,12 @@
|
||||
},
|
||||
"agent_runner_type": {
|
||||
"description": "Runner",
|
||||
"labels": ["Built-in Agent", "Dify", "Coze", "Alibaba Cloud Bailian Application"]
|
||||
"labels": [
|
||||
"Built-in Agent",
|
||||
"Dify",
|
||||
"Coze",
|
||||
"Alibaba Cloud Bailian Application"
|
||||
]
|
||||
},
|
||||
"coze_agent_runner_provider_id": {
|
||||
"description": "Coze Agent Runner Provider ID"
|
||||
@@ -128,6 +133,39 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"truncate_and_compress": {
|
||||
"description": "Context Management Strategy",
|
||||
"provider_settings": {
|
||||
"max_context_length": {
|
||||
"description": "Maximum Conversation Turns",
|
||||
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "Dequeue Conversation Turns",
|
||||
"hint": "Number of conversation turns to discard at once when maximum context length is exceeded"
|
||||
},
|
||||
"context_limit_reached_strategy": {
|
||||
"description": "Handling When Model Context Window is Exceeded",
|
||||
"labels": [
|
||||
"Truncate by Turns",
|
||||
"Compress by LLM"
|
||||
],
|
||||
"hint": "When 'Truncate by Turns' is selected, the oldest N conversation turns will be discarded based on the 'Dequeue Conversation Turns' setting above. When 'Compress by LLM' is selected, the specified model will be used for context compression."
|
||||
},
|
||||
"llm_compress_instruction": {
|
||||
"description": "Context Compression Instruction",
|
||||
"hint": "If empty, the default prompt will be used."
|
||||
},
|
||||
"llm_compress_keep_recent": {
|
||||
"description": "Keep Recent Turns When Compressing",
|
||||
"hint": "Always keep the most recent N turns of conversation when compressing context."
|
||||
},
|
||||
"llm_compress_provider_id": {
|
||||
"description": "Model Provider ID for Context Compression",
|
||||
"hint": "When left empty, will fall back to the 'Truncate by Turns' strategy."
|
||||
}
|
||||
}
|
||||
},
|
||||
"others": {
|
||||
"description": "Other Settings",
|
||||
"provider_settings": {
|
||||
@@ -161,15 +199,10 @@
|
||||
"unsupported_streaming_strategy": {
|
||||
"description": "Platforms Without Streaming Support",
|
||||
"hint": "Select the handling method for platforms that don't support streaming responses. Real-time segmented reply sends content immediately when the system detects segment points like punctuation during streaming reception",
|
||||
"labels": ["Real-time Segmented Reply", "Disable Streaming Response"]
|
||||
},
|
||||
"max_context_length": {
|
||||
"description": "Maximum Conversation Rounds",
|
||||
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "Dequeue Conversation Rounds",
|
||||
"hint": "Number of conversation rounds to discard at once when maximum context length is exceeded"
|
||||
"labels": [
|
||||
"Real-time Segmented Reply",
|
||||
"Disable Streaming Response"
|
||||
]
|
||||
},
|
||||
"wake_prefix": {
|
||||
"description": "Additional LLM Chat Wake Prefix",
|
||||
@@ -387,7 +420,10 @@
|
||||
},
|
||||
"split_mode": {
|
||||
"description": "Split Mode",
|
||||
"labels": ["Regex", "Words List"]
|
||||
"labels": [
|
||||
"Regex",
|
||||
"Words List"
|
||||
]
|
||||
},
|
||||
"regex": {
|
||||
"description": "Segmentation Regular Expression"
|
||||
@@ -488,4 +524,4 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -133,6 +133,36 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"truncate_and_compress": {
|
||||
"description": "上下文管理策略",
|
||||
"provider_settings": {
|
||||
"max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
|
||||
},
|
||||
"context_limit_reached_strategy": {
|
||||
"description": "超出模型上下文窗口时的处理方式",
|
||||
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
|
||||
"hint": "当按对话轮数截断时,会根据上面\"丢弃对话轮数\"的配置丢弃最旧的 N 轮对话。当由 LLM 压缩上下文时,会使用指定的模型进行上下文压缩。"
|
||||
},
|
||||
"llm_compress_instruction": {
|
||||
"description": "上下文压缩提示词",
|
||||
"hint": "如果为空则使用默认提示词。"
|
||||
},
|
||||
"llm_compress_keep_recent": {
|
||||
"description": "压缩时保留最近对话轮数",
|
||||
"hint": "始终保留的最近 N 轮对话。"
|
||||
},
|
||||
"llm_compress_provider_id": {
|
||||
"description": "用于上下文压缩的模型提供商 ID",
|
||||
"hint": "留空时将降级为\"按对话轮数截断\"的策略。"
|
||||
}
|
||||
}
|
||||
},
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
"provider_settings": {
|
||||
@@ -171,14 +201,7 @@
|
||||
"关闭流式回复"
|
||||
]
|
||||
},
|
||||
"max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
|
||||
},
|
||||
|
||||
"wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀",
|
||||
"hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求"
|
||||
|
||||
@@ -0,0 +1,774 @@
|
||||
"""Comprehensive tests for ContextManager."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Add parent directory to path to avoid circular import issues
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from astrbot.core.agent.context.config import ContextConfig
|
||||
from astrbot.core.agent.context.manager import ContextManager
|
||||
from astrbot.core.agent.message import Message, TextPart
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
|
||||
class MockProvider:
    """Minimal stand-in for a chat provider, used by the compression tests."""

    def __init__(self):
        self.provider_config = dict(
            id="test_provider",
            model="gpt-4",
            modalities=["text", "image", "tool_use"],
        )

    async def text_chat(self, **kwargs):
        """Fake an LLM call by returning a canned summary.

        The summary text embeds the number of entries in the ``messages``
        keyword argument minus one.
        """
        messages = kwargs.get("messages", [])
        summary = f"历史对话包含 {len(messages) - 1} 条消息,主要讨论了技术话题。"
        return LLMResponse(role="assistant", completion_text=summary)

    def get_model(self):
        """Return the fixed model name reported by this mock."""
        return "gpt-4"

    def meta(self):
        """Return a mock metadata object exposing ``id`` and ``type``."""
        return MagicMock(id="test_provider", type="openai")
|
||||
|
||||
|
||||
class TestContextManager:
|
||||
"""Test suite for ContextManager."""
|
||||
|
||||
def create_message(
    self, role: Literal["system", "user", "assistant", "tool"], content: str
) -> Message:
    """Build a plain-text ``Message`` with the given role and content."""
    message = Message(role=role, content=content)
    return message
|
||||
|
||||
def create_messages(self, count: int) -> list[Message]:
    """Return ``count`` messages alternating user/assistant, starting with user.

    Message ``i`` carries the content ``"Message i"``.
    """
    roles = ("user", "assistant")
    return [
        self.create_message(roles[index % 2], f"Message {index}")
        for index in range(count)
    ]
|
||||
|
||||
# ==================== Basic Initialization Tests ====================
|
||||
|
||||
def test_init_with_minimal_config(self):
    """A default ``ContextConfig`` yields a manager with every component set."""
    default_config = ContextConfig()
    manager = ContextManager(default_config)

    assert manager.config == default_config
    # The three collaborators must all be populated by the constructor.
    for component in ("token_counter", "truncator", "compressor"):
        assert getattr(manager, component) is not None
|
||||
|
||||
def test_init_with_llm_compressor(self):
    """Supplying a compression provider selects ``LLMSummaryCompressor``."""
    llm_settings = dict(
        llm_compress_provider=MockProvider(),  # type: ignore
        llm_compress_keep_recent=5,
        llm_compress_instruction="Summarize the conversation",
    )
    manager = ContextManager(ContextConfig(**llm_settings))

    from astrbot.core.agent.context.compressor import LLMSummaryCompressor

    chosen = manager.compressor
    assert isinstance(chosen, LLMSummaryCompressor)
|
||||
|
||||
def test_init_with_truncate_compressor(self):
    """Without a compression provider the manager uses ``TruncateByTurnsCompressor``."""
    manager = ContextManager(ContextConfig(truncate_turns=3))

    from astrbot.core.agent.context.compressor import TruncateByTurnsCompressor

    chosen = manager.compressor
    assert isinstance(chosen, TruncateByTurnsCompressor)
|
||||
|
||||
# ==================== Empty and Edge Cases ====================
|
||||
|
||||
@pytest.mark.asyncio
async def test_process_empty_messages(self):
    """Processing an empty list returns an empty list."""
    manager = ContextManager(ContextConfig())

    processed = await manager.process([])

    assert processed == []
|
||||
|
||||
@pytest.mark.asyncio
async def test_process_single_message(self):
    """A lone user message passes through with its content intact."""
    manager = ContextManager(ContextConfig())
    greeting = self.create_message("user", "Hello")

    processed = await manager.process([greeting])

    assert len(processed) == 1
    assert processed[0].content == "Hello"
|
||||
|
||||
@pytest.mark.asyncio
async def test_process_with_no_limits(self):
    """With token and turn limits both disabled the messages come back unchanged."""
    unlimited = ContextConfig(max_context_tokens=0, enforce_max_turns=-1)
    manager = ContextManager(unlimited)
    history = self.create_messages(20)

    processed = await manager.process(history)

    assert len(processed) == 20
    assert processed == history
|
||||
|
||||
# ==================== Enforce Max Turns Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
async def test_enforce_max_turns_basic(self):
    """A turn limit of 3 shrinks a 10-turn history to at most 8 messages."""
    manager = ContextManager(ContextConfig(enforce_max_turns=3, truncate_turns=1))

    # 20 alternating messages make up 10 turns.
    ten_turns = self.create_messages(20)
    trimmed = await manager.process(ten_turns)

    # The bound is intentionally loose; the exact count depends on the truncation logic.
    assert len(trimmed) <= 8
|
||||
|
||||
@pytest.mark.asyncio
async def test_enforce_max_turns_zero(self):
    """A turn limit of 0 leaves at most two messages."""
    manager = ContextManager(ContextConfig(enforce_max_turns=0, truncate_turns=1))
    history = self.create_messages(10)

    trimmed = await manager.process(history)

    # Expect an empty or near-empty result.
    assert len(trimmed) <= 2
|
||||
|
||||
@pytest.mark.asyncio
async def test_enforce_max_turns_negative(self):
    """A turn limit of -1 disables turn-based truncation."""
    manager = ContextManager(ContextConfig(enforce_max_turns=-1))
    history = self.create_messages(20)

    processed = await manager.process(history)

    assert len(processed) == 20
|
||||
|
||||
@pytest.mark.asyncio
async def test_enforce_max_turns_with_system_messages(self):
    """Turn-based truncation keeps the leading system message."""
    manager = ContextManager(ContextConfig(enforce_max_turns=2, truncate_turns=1))

    history = [self.create_message("system", "System instruction")]
    history.extend(self.create_messages(10))

    processed = await manager.process(history)

    # The system instruction must still be present, with its original content.
    surviving_system = [msg for msg in processed if msg.role == "system"]
    assert len(surviving_system) >= 1
    assert surviving_system[0].content == "System instruction"
|
||||
|
||||
# ==================== Token-based Compression Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
async def test_token_compression_not_triggered_below_threshold(self):
    """Test that compression is not triggered below threshold.

    ``should_compress`` is an ordinary method and can be stubbed on the
    instance. ``__call__`` is different: ``manager.compressor(...)`` resolves
    it on the compressor's type, so the stub is installed on the class.
    A stub set on the instance would never run and ``assert_not_called``
    would pass unconditionally.
    """
    config = ContextConfig(max_context_tokens=1000)
    manager = ContextManager(config)

    # 100 characters of content, far below the 1000-token limit.
    messages = [self.create_message("user", "Hi" * 50)]

    with patch.object(
        manager.compressor, "should_compress", return_value=False
    ) as mock_should_compress:
        with patch.object(
            type(manager.compressor), "__call__", new_callable=AsyncMock
        ) as mock_compress:
            result = await manager.process(messages)

    # should_compress should be called
    mock_should_compress.assert_called_once()
    # Compressor should not be called
    mock_compress.assert_not_called()
    assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
async def test_token_compression_triggered_above_threshold(self):
    """Compression runs once when ``should_compress`` reports the threshold is exceeded."""
    manager = ContextManager(ContextConfig(max_context_tokens=100, truncate_turns=1))

    # A 300-character message, sized to land above the 0.82 * 100 token threshold.
    messages = [self.create_message("user", "x" * 300)]
    compressed = [self.create_message("user", "short")]

    # Stand-in compressor: yields the short list, asks for compression only once.
    verdicts = iter([True])
    mock_compressor = AsyncMock(return_value=compressed)
    mock_compressor.compression_threshold = 0.82
    mock_compressor.should_compress = MagicMock(
        side_effect=lambda *args, **kwargs: next(verdicts, False)
    )
    manager.compressor = mock_compressor

    result = await manager.process(messages)

    # The compressor was invoked exactly once and the result did not grow.
    mock_compressor.assert_called_once()
    assert len(result) <= len(messages)
|
||||
|
||||
@pytest.mark.asyncio
async def test_token_compression_with_zero_max_tokens(self):
    """Test that compression is skipped when max_context_tokens is 0.

    ``manager.compressor(...)`` resolves ``__call__`` on the compressor's
    type, so the stub is installed on the class. A stub set on the instance
    would never be invoked and ``assert_not_called`` would be vacuous.
    """
    config = ContextConfig(max_context_tokens=0)
    manager = ContextManager(config)

    messages = [self.create_message("user", "x" * 10000)]

    with patch.object(
        type(manager.compressor), "__call__", new_callable=AsyncMock
    ) as mock_compress:
        result = await manager.process(messages)

    # Compressor should not be called when max_context_tokens is 0
    mock_compress.assert_not_called()
    assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
async def test_token_compression_with_negative_max_tokens(self):
    """Test that compression is skipped when max_context_tokens is negative.

    ``manager.compressor(...)`` resolves ``__call__`` on the compressor's
    type, so the stub is installed on the class. A stub set on the instance
    would never be invoked and ``assert_not_called`` would be vacuous.
    """
    config = ContextConfig(max_context_tokens=-100)
    manager = ContextManager(config)

    messages = [self.create_message("user", "x" * 10000)]

    with patch.object(
        type(manager.compressor), "__call__", new_callable=AsyncMock
    ) as mock_compress:
        result = await manager.process(messages)

    # Compressor should not be called
    mock_compress.assert_not_called()
    assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
async def test_double_check_after_compression(self):
    """Test that halving is applied if still over threshold after compression.

    The compressor stub is installed on the compressor's type because
    ``manager.compressor(...)`` resolves ``__call__`` there. A stub set on
    the instance would be ignored and the real compressor would run.
    """
    config = ContextConfig(max_context_tokens=100)
    manager = ContextManager(config)

    # Create messages that would still be over threshold after compression
    long_messages = [self.create_message("user", "x" * 200) for _ in range(10)]

    # should_compress reports "over threshold" both before and after compression.
    with patch.object(manager.compressor, "should_compress", return_value=True):
        # The stub compressor hands the same messages back (still over limit).
        with patch.object(
            type(manager.compressor),
            "__call__",
            new_callable=AsyncMock,
            return_value=long_messages,
        ) as mock_compress:
            with patch.object(
                manager.truncator,
                "truncate_by_halving",
                return_value=long_messages[:5],
            ) as mock_halving:
                _ = await manager.process(long_messages)

    # The stubbed compressor must actually have been used...
    mock_compress.assert_awaited()
    # ...and halving should be called as the fallback
    mock_halving.assert_called_once()
|
||||
|
||||
# ==================== Combined Truncation and Compression Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
async def test_combined_enforce_turns_and_token_limit(self):
    """A turn limit combined with a token limit shortens a long history."""
    limits = dict(enforce_max_turns=5, max_context_tokens=500, truncate_turns=1)
    manager = ContextManager(ContextConfig(**limits))

    history = self.create_messages(30)
    processed = await manager.process(history)

    # Fewer messages must remain than were supplied.
    assert len(processed) < 30
|
||||
|
||||
@pytest.mark.asyncio
async def test_sequential_processing_order(self):
    """Turn-based truncation is invoked exactly once while processing."""
    manager = ContextManager(
        ContextConfig(enforce_max_turns=5, max_context_tokens=1000)
    )
    history = self.create_messages(20)

    # Wrap the real truncator method so calls are recorded but still executed.
    real_truncate = manager.truncator.truncate_by_turns
    with patch.object(
        manager.truncator, "truncate_by_turns", wraps=real_truncate
    ) as spy_truncate:
        await manager.process(history)

    spy_truncate.assert_called_once()
|
||||
|
||||
# ==================== Error Handling Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
async def test_error_handling_returns_original_messages(self):
    """A failing compressor must not lose the caller's messages.

    The compressor is replaced with an ``AsyncMock`` instead of patching
    ``__call__`` on the instance: Python resolves special methods on the
    type, so an instance-level ``__call__`` patch is never reached by
    ``compressor(...)`` and the error path would go untested.
    """
    manager = ContextManager(ContextConfig(max_context_tokens=100))

    # Same oversized input as the logging test (~90 tokens vs. an 82-token
    # trigger), so compression is really attempted.
    messages = [self.create_message("user", "x" * 300)]

    # Stand-in compressor: always asks to compress, always raises.
    failing_compressor = AsyncMock(side_effect=Exception("Test error"))
    failing_compressor.compression_threshold = 0.82
    failing_compressor.should_compress = MagicMock(return_value=True)
    manager.compressor = failing_compressor

    result = await manager.process(messages)

    # The failure path must actually have been exercised...
    failing_compressor.assert_awaited()
    # ...and the original messages handed back despite the error.
    assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
async def test_error_handling_logs_exception(self):
    """A compressor failure is reported through ``logger.error``."""
    manager = ContextManager(ContextConfig(max_context_tokens=100))

    # One ~90-token message: enough to cross the 82-token trigger point.
    history = [self.create_message("user", "x" * 300)]

    # Stand-in compressor that always wants to compress and always fails.
    broken = AsyncMock(side_effect=Exception("Test error"))
    broken.compression_threshold = 0.82
    broken.should_compress = MagicMock(return_value=True)
    manager.compressor = broken

    with patch("astrbot.core.agent.context.manager.logger") as mock_logger:
        outcome = await manager.process(history)

    # The error was logged and the input came back unchanged.
    assert mock_logger.error.called
    assert outcome == history
|
||||
|
||||
# ==================== Multi-modal Content Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
async def test_process_messages_with_textpart_content(self):
    """Messages whose content is a list of TextPart pass through unchanged."""
    manager = ContextManager(ContextConfig())

    conversation = [
        Message(role=role, content=[TextPart(text=text)])
        for role, text in (("user", "Hello"), ("assistant", "Hi there"))
    ]

    processed = await manager.process(conversation)

    assert len(processed) == 2
    assert processed == conversation
|
||||
|
||||
@pytest.mark.asyncio
async def test_token_counting_with_multimodal_content(self):
    """TextPart content is counted and can push the context over the threshold."""
    manager = ContextManager(ContextConfig(max_context_tokens=50))

    # Threshold is 50 * 0.82 = 41 tokens; 150 chars * 0.3 = 45 tokens > 41.
    payload = [Message(role="user", content=[TextPart(text="x" * 150)])]

    counted = manager.token_counter.count_tokens(payload)
    over_threshold = manager.compressor.should_compress(payload, counted, 50)

    assert counted > 0  # the TextPart contributed tokens
    assert over_threshold  # and they exceed the trigger point
|
||||
|
||||
# ==================== Tool Calls Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
async def test_process_messages_with_tool_calls(self):
    """An assistant tool call and its tool reply both survive processing."""
    manager = ContextManager(ContextConfig())

    search_call = {
        "id": "call_1",
        "type": "function",
        "function": {"name": "search", "arguments": "{}"},
    }
    exchange = [
        Message(
            role="assistant",
            content="Let me search for that",
            tool_calls=[search_call],
        ),
        Message(role="tool", content="Search result", tool_call_id="call_1"),
    ]

    processed = await manager.process(exchange)

    assert len(processed) == 2
|
||||
|
||||
# ==================== Compressor should_compress Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
async def test_should_compress_empty_messages(self):
    """An empty history with zero tokens never asks for compression."""
    manager = ContextManager(ContextConfig(max_context_tokens=100))

    # No messages, no tokens: the check must simply answer "no".
    verdict = manager.compressor.should_compress([], 0, 100)

    assert not verdict
|
||||
|
||||
@pytest.mark.asyncio
async def test_should_compress_below_threshold(self):
    """A tiny message against a 1000-token budget does not need compression."""
    manager = ContextManager(ContextConfig(max_context_tokens=1000))

    greeting = [self.create_message("user", "Hello")]
    counted = manager.token_counter.count_tokens(greeting)

    verdict = manager.compressor.should_compress(greeting, counted, 1000)

    assert not verdict
|
||||
|
||||
@pytest.mark.asyncio
async def test_should_compress_above_threshold(self):
    """should_compress agrees with the 82-token cut-off for a 100-token budget."""
    manager = ContextManager(ContextConfig(max_context_tokens=100))

    # A long CJK message so the token count is substantial.
    bulky = [self.create_message("user", "这是测试" * 50)]
    counted = manager.token_counter.count_tokens(bulky)

    verdict = manager.compressor.should_compress(bulky, counted, 100)

    # Compression is expected exactly when tokens exceed 0.82 * 100.
    assert verdict == (counted > 82)
|
||||
|
||||
# ==================== Truncator Halving Tests ====================
|
||||
|
||||
def test_truncate_by_halving_basic(self):
    """Halving a 10-message history through the manager's truncator shrinks it."""
    manager = ContextManager(ContextConfig())
    history = self.create_messages(10)

    halved = manager.truncator.truncate_by_halving(history)

    # Only "strictly fewer" is required here; roughly half should remain.
    assert len(halved) < len(history)
|
||||
|
||||
def test_truncate_by_halving_empty_list(self):
    """Halving an empty history yields an empty history."""
    manager = ContextManager(ContextConfig())

    halved = manager.truncator.truncate_by_halving([])

    assert halved == []
|
||||
|
||||
def test_truncate_by_halving_single_message(self):
    """Halving a single-message history never produces more than one message."""
    manager = ContextManager(ContextConfig())
    lone = [self.create_message("user", "Hello")]

    halved = manager.truncator.truncate_by_halving(lone)

    assert len(halved) <= 1
|
||||
|
||||
# ==================== Complex Scenarios ====================
|
||||
|
||||
@pytest.mark.asyncio
async def test_multiple_compression_cycles(self):
    """Feeding process() its own output never makes the history grow."""
    manager = ContextManager(ContextConfig(max_context_tokens=50, truncate_turns=1))

    # Run three passes, each consuming the previous pass's result.
    first = await manager.process(self.create_messages(10))
    second = await manager.process(first)
    third = await manager.process(second)

    # Message count is monotonically non-increasing across passes.
    assert len(third) <= len(second) <= len(first)
|
||||
|
||||
@pytest.mark.asyncio
async def test_alternating_roles_preserved(self):
    """After turn truncation the first non-system message is a user message.

    Only the leading role is asserted; full user/assistant alternation is
    not verified by this test.
    """
    manager = ContextManager(ContextConfig(enforce_max_turns=3, truncate_turns=1))

    processed = await manager.process(self.create_messages(20))

    # System messages are not part of the user/assistant sequence.
    dialogue = [msg for msg in processed if msg.role != "system"]
    if len(dialogue) >= 2:
        assert dialogue[0].role == "user"
|
||||
|
||||
@pytest.mark.asyncio
async def test_compression_threshold_default(self):
    """The default threshold is 0.82 and should_compress honours it."""
    manager = ContextManager(ContextConfig(max_context_tokens=100))

    # Default ratio of the budget at which compression kicks in.
    assert manager.compressor.compression_threshold == 0.82

    # An 81-character message (~24 tokens) against a 100-token budget.
    sample = [self.create_message("user", "x" * 81)]
    counted = manager.token_counter.count_tokens(sample)

    verdict = manager.compressor.should_compress(sample, counted, 100)

    # Compression is requested only when the count exceeds 82 tokens.
    assert verdict == (counted > 82)
|
||||
|
||||
@pytest.mark.asyncio
async def test_large_batch_processing(self):
    """A 100-message batch is reduced but not emptied."""
    manager = ContextManager(
        ContextConfig(enforce_max_turns=10, max_context_tokens=1000, truncate_turns=2)
    )

    # 100 messages, i.e. 50 user/assistant turns.
    batch = self.create_messages(100)

    processed = await manager.process(batch)

    # Something was dropped, yet something remains.
    assert len(processed) < 100
    assert len(processed) > 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_config_persistence(self):
    """The manager exposes the exact config values it was constructed with."""
    manager = ContextManager(
        ContextConfig(
            max_context_tokens=500,
            enforce_max_turns=5,
            truncate_turns=2,
            llm_compress_keep_recent=3,
        )
    )

    stored = manager.config
    assert stored.max_context_tokens == 500
    assert stored.enforce_max_turns == 5
    assert stored.truncate_turns == 2
    assert stored.llm_compress_keep_recent == 3
|
||||
|
||||
# ==================== Run Compression Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_compression_calls_compressor(self):
    """_run_compression awaits the compressor and returns its output."""
    manager = ContextManager(ContextConfig(max_context_tokens=100))

    original = self.create_messages(5)
    shrunk = self.create_messages(3)

    # Fake compressor: returns `shrunk` and reports no further compression.
    fake_compressor = AsyncMock()
    fake_compressor.compression_threshold = 0.82
    fake_compressor.return_value = shrunk
    fake_compressor.should_compress = MagicMock(return_value=False)
    manager.compressor = fake_compressor

    outcome = await manager._run_compression(original, prev_tokens=100)

    # The compressor was called with the untouched input.
    fake_compressor.assert_called_once_with(original)
    assert outcome == shrunk
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_compression_applies_compressor_through_process(self):
    """process() invokes the compressor once when should_compress says so."""
    manager = ContextManager(ContextConfig(max_context_tokens=100, truncate_turns=1))

    oversized = [self.create_message("user", "x" * 300)]  # ~90 tokens > 82 threshold
    shrunk = [self.create_message("user", "short")]  # much smaller replacement

    fake_compressor = AsyncMock()
    fake_compressor.compression_threshold = 0.82
    fake_compressor.return_value = shrunk

    # Answer True on the first query only; every later query answers False.
    queries = []

    def mock_should_compress(*args, **kwargs):
        queries.append(None)
        return len(queries) == 1

    fake_compressor.should_compress = mock_should_compress
    manager.compressor = fake_compressor

    processed = await manager.process(oversized)

    # Exactly one compression pass happened.
    fake_compressor.assert_called_once()
    assert len(processed) <= len(oversized)
|
||||
|
||||
@pytest.mark.asyncio
async def test_llm_compression_with_mock_provider(self):
    """With a MockProvider configured, processing never grows the history."""
    provider = MockProvider()
    manager = ContextManager(
        ContextConfig(
            llm_compress_provider=provider,  # type: ignore
            llm_compress_keep_recent=3,
            llm_compress_instruction="请总结对话内容",
            max_context_tokens=100,
        )
    )

    # Three 100-character messages, intended to exceed the token budget.
    history = [
        self.create_message(role, filler * 100)
        for role, filler in (("user", "x"), ("assistant", "y"), ("user", "z"))
    ]

    processed = await manager.process(history)

    assert len(processed) <= len(history)
|
||||
|
||||
# ==================== split_history Tests ====================
|
||||
|
||||
def test_split_history_ensures_user_start(self):
    """The kept-recent slice begins with a user message."""
    from astrbot.core.agent.context.compressor import split_history

    # System prompt followed by three complete user/assistant turns.
    history = [self.create_message("system", "System prompt")]
    for index in range(1, 7):
        speaker = "user" if index % 2 == 1 else "assistant"
        history.append(self.create_message(speaker, f"msg{index}"))

    # Asking for 3 recent messages would start on an assistant message, so
    # the split is expected to be adjusted.
    system, to_summarize, recent = split_history(history, keep_recent=3)

    assert len(recent) > 0
    assert recent[0].role == "user"

    # Whatever is summarised should end on a completed turn.
    if len(to_summarize) > 0:
        assert to_summarize[-1].role == "assistant"
|
||||
|
||||
def test_split_history_handles_assistant_at_split_point(self):
    """With keep_recent=2 on three full turns, the recent slice starts at msg5."""
    from astrbot.core.agent.context.compressor import split_history

    # user/assistant pairs msg1..msg6, no system message.
    history = [
        self.create_message("assistant" if index % 2 == 0 else "user", f"msg{index}")
        for index in range(1, 7)
    ]

    system, to_summarize, recent = split_history(history, keep_recent=2)

    # The recent slice opens with the last user message.
    opening = recent[0]
    assert opening.role == "user"
    assert opening.content == "msg5"
|
||||
|
||||
def test_split_history_all_assistant_messages(self):
    """split_history runs on a history ending in consecutive assistant messages."""
    from astrbot.core.agent.context.compressor import split_history

    # One user message followed by three assistant messages in a row.
    history = [self.create_message("user", "msg1")]
    history += [self.create_message("assistant", f"msg{n}") for n in (2, 3, 4)]

    system, to_summarize, recent = split_history(history, keep_recent=2)

    # NOTE(review): this only confirms the *input* contains a user message;
    # it does not pin what ends up in `recent`.
    if len(recent) > 0:
        assert any(m.role == "user" for m in history)
|
||||
|
||||
def test_split_history_with_system_messages(self):
    """System messages are returned in their own bucket by split_history."""
    from astrbot.core.agent.context.compressor import split_history

    layout = (
        ("system", "System 1"),
        ("system", "System 2"),
        ("user", "msg1"),
        ("assistant", "msg2"),
        ("user", "msg3"),
    )
    history = [self.create_message(role, text) for role, text in layout]

    system, to_summarize, recent = split_history(history, keep_recent=2)

    # Both system prompts, and nothing else, land in the system bucket.
    assert len(system) == 2
    assert all(m.role == "system" for m in system)

    # When anything is kept as recent, it opens with a user message.
    if len(recent) > 0:
        assert recent[0].role == "user"
|
||||
@@ -0,0 +1,423 @@
|
||||
"""Tests for ContextTruncator."""
|
||||
|
||||
from astrbot.core.agent.context.truncator import ContextTruncator
|
||||
from astrbot.core.agent.message import Message
|
||||
|
||||
|
||||
class TestContextTruncator:
    """Unit tests for ContextTruncator.

    Covers ``fix_messages`` and the three truncation strategies
    (``truncate_by_turns``, ``truncate_by_dropping_oldest_turns`` and
    ``truncate_by_halving``), plus a few combined scenarios.
    """

    def create_message(self, role: str, content: str = "test content") -> Message:
        """Build a plain-text ``Message`` with the given role."""
        return Message(role=role, content=content)

    def create_messages(
        self, count: int, include_system: bool = False
    ) -> list[Message]:
        """Build ``count`` alternating user/assistant messages.

        Args:
            count: Number of user/assistant messages to create.
            include_system: Whether to put a system message first
                (it is not counted in ``count``).

        Returns:
            The generated messages; even indices are user messages, odd
            indices are assistant messages, with content ``"Message <i>"``.
        """
        messages = []
        if include_system:
            messages.append(self.create_message("system", "System prompt"))
        messages.extend(
            self.create_message("assistant" if i % 2 else "user", f"Message {i}")
            for i in range(count)
        )
        return messages

    # ==================== fix_messages Tests ====================

    def test_fix_messages_empty_list(self):
        """fix_messages returns an empty list for empty input."""
        assert ContextTruncator().fix_messages([]) == []

    def test_fix_messages_normal_messages(self):
        """Plain user/assistant messages are left untouched."""
        messages = [
            self.create_message("user", "Hello"),
            self.create_message("assistant", "Hi"),
            self.create_message("user", "How are you?"),
        ]

        result = ContextTruncator().fix_messages(messages)

        assert len(result) == 3
        assert result == messages

    def test_fix_messages_tool_with_valid_context(self):
        """A tool message preceded by user + assistant is kept."""
        messages = [
            self.create_message("user", "Run tool"),
            self.create_message("assistant", "Running..."),
            self.create_message("tool", "Tool result"),
        ]

        result = ContextTruncator().fix_messages(messages)

        assert len(result) == 3
        assert result == messages

    def test_fix_messages_tool_without_context(self):
        """A lone tool message is removed."""
        messages = [self.create_message("tool", "Tool result")]

        result = ContextTruncator().fix_messages(messages)

        # Nothing precedes the tool message, so nothing survives.
        assert len(result) == 0

    def test_fix_messages_tool_with_only_one_message(self):
        """A tool message preceded by a single user message empties the list."""
        messages = [
            self.create_message("user", "Hello"),
            self.create_message("tool", "Tool result"),
        ]

        result = ContextTruncator().fix_messages(messages)

        # One preceding message is not enough context for the tool message.
        assert len(result) == 0

    def test_fix_messages_multiple_tools(self):
        """Several consecutive tool messages after user + assistant are kept."""
        messages = [
            self.create_message("user", "Run tool"),
            self.create_message("assistant", "Running..."),
            self.create_message("tool", "Tool 1 result"),
            self.create_message("tool", "Tool 2 result"),
        ]

        result = ContextTruncator().fix_messages(messages)

        assert len(result) == 4
        assert result == messages

    def test_fix_messages_mixed_system_tool(self):
        """A leading system message does not disturb a valid tool exchange."""
        messages = [
            self.create_message("system", "System prompt"),
            self.create_message("user", "Run tool"),
            self.create_message("assistant", "Running..."),
            self.create_message("tool", "Tool result"),
        ]

        result = ContextTruncator().fix_messages(messages)

        assert len(result) == 4
        assert result == messages

    # ==================== truncate_by_turns Tests ====================

    def test_truncate_by_turns_no_limit(self):
        """keep_most_recent_turns=-1 means no truncation."""
        messages = self.create_messages(20)

        result = ContextTruncator().truncate_by_turns(
            messages, keep_most_recent_turns=-1
        )

        assert len(result) == 20
        assert result == messages

    def test_truncate_by_turns_basic(self):
        """Exceeding the turn limit drops at least one turn."""
        # 10 messages = 5 user/assistant turns, limit is 3 turns.
        messages = self.create_messages(10)

        result = ContextTruncator().truncate_by_turns(
            messages, keep_most_recent_turns=3, drop_turns=1
        )

        # Loose upper bound: at least one turn (2 messages) must be gone.
        # The exact size after truncation is not pinned by this test.
        assert len(result) <= 8

    def test_truncate_by_turns_with_system_message(self):
        """The system message stays at the front after turn truncation."""
        messages = self.create_messages(10, include_system=True)

        result = ContextTruncator().truncate_by_turns(
            messages, keep_most_recent_turns=2, drop_turns=1
        )

        assert result[0].role == "system"
        assert result[0].content == "System prompt"

    def test_truncate_by_turns_zero_keep(self):
        """keep_most_recent_turns=0 leaves no messages."""
        messages = self.create_messages(10)

        result = ContextTruncator().truncate_by_turns(
            messages, keep_most_recent_turns=0, drop_turns=1
        )

        assert len(result) == 0

    def test_truncate_by_turns_below_threshold(self):
        """Histories shorter than the limit are returned unchanged."""
        # 4 messages = 2 turns, limit is 5 turns.
        messages = self.create_messages(4)

        result = ContextTruncator().truncate_by_turns(
            messages, keep_most_recent_turns=5, drop_turns=1
        )

        assert len(result) == 4
        assert result == messages

    def test_truncate_by_turns_exact_threshold(self):
        """Histories exactly at the limit are returned unchanged."""
        # 6 messages = 3 turns, limit is 3 turns.
        messages = self.create_messages(6)

        result = ContextTruncator().truncate_by_turns(
            messages, keep_most_recent_turns=3, drop_turns=1
        )

        assert len(result) == 6
        assert result == messages

    def test_truncate_by_turns_ensures_user_first(self):
        """After turn truncation the history starts with a user message."""
        messages = self.create_messages(20)

        result = ContextTruncator().truncate_by_turns(
            messages, keep_most_recent_turns=3, drop_turns=1
        )

        assert result[0].role == "user"

    def test_truncate_by_turns_multiple_drop(self):
        """Dropping several turns at once shortens the history."""
        messages = self.create_messages(20)

        result = ContextTruncator().truncate_by_turns(
            messages, keep_most_recent_turns=5, drop_turns=3
        )

        assert len(result) < len(messages)

    # ==================== truncate_by_dropping_oldest_turns Tests ====================

    def test_truncate_by_dropping_oldest_turns_zero(self):
        """drop_turns=0 is a no-op."""
        messages = self.create_messages(10)

        result = ContextTruncator().truncate_by_dropping_oldest_turns(
            messages, drop_turns=0
        )

        assert result == messages

    def test_truncate_by_dropping_oldest_turns_negative(self):
        """A negative drop_turns is a no-op."""
        messages = self.create_messages(10)

        result = ContextTruncator().truncate_by_dropping_oldest_turns(
            messages, drop_turns=-1
        )

        assert result == messages

    def test_truncate_by_dropping_oldest_turns_basic(self):
        """Dropping 2 of 5 turns leaves 6 messages starting with a user message."""
        # 10 messages = 5 turns.
        messages = self.create_messages(10)

        result = ContextTruncator().truncate_by_dropping_oldest_turns(
            messages, drop_turns=2
        )

        # 2 turns = 4 messages removed.
        assert len(result) == 6
        assert result[0].role == "user"

    def test_truncate_by_dropping_oldest_turns_with_system(self):
        """The system message survives dropping the oldest turns."""
        messages = self.create_messages(10, include_system=True)

        result = ContextTruncator().truncate_by_dropping_oldest_turns(
            messages, drop_turns=2
        )

        assert result[0].role == "system"
        assert result[0].content == "System prompt"

    def test_truncate_by_dropping_oldest_turns_drop_all(self):
        """Dropping as many turns as exist empties the history."""
        # 4 messages = 2 turns.
        messages = self.create_messages(4)

        result = ContextTruncator().truncate_by_dropping_oldest_turns(
            messages, drop_turns=2
        )

        assert len(result) == 0

    def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self):
        """Asking to drop more turns than exist empties the history."""
        # 4 messages = 2 turns, but 5 turns are requested.
        messages = self.create_messages(4)

        result = ContextTruncator().truncate_by_dropping_oldest_turns(
            messages, drop_turns=5
        )

        assert len(result) == 0

    def test_truncate_by_dropping_oldest_turns_ensures_user_first(self):
        """Whatever remains after dropping starts with a user message."""
        messages = self.create_messages(20)

        result = ContextTruncator().truncate_by_dropping_oldest_turns(
            messages, drop_turns=3
        )

        if result:
            assert result[0].role == "user"

    # ==================== truncate_by_halving Tests ====================

    def test_truncate_by_halving_empty(self):
        """Halving an empty list returns an empty list."""
        assert ContextTruncator().truncate_by_halving([]) == []

    def test_truncate_by_halving_single_message(self):
        """A single message is too short to halve and is returned as is."""
        messages = [self.create_message("user", "Hello")]

        result = ContextTruncator().truncate_by_halving(messages)

        # Histories of <= 2 messages are not truncated.
        assert result == messages

    def test_truncate_by_halving_two_messages(self):
        """Two messages are too short to halve and are returned as is."""
        messages = self.create_messages(2)

        result = ContextTruncator().truncate_by_halving(messages)

        # Histories of <= 2 messages are not truncated.
        assert result == messages

    def test_truncate_by_halving_basic(self):
        """Halving 20 messages keeps 10, starting with a user message."""
        messages = self.create_messages(20)

        result = ContextTruncator().truncate_by_halving(messages)

        # 50% of 20 deleted, 10 kept.
        assert len(result) == 10
        assert result[0].role == "user"

    def test_truncate_by_halving_with_system_message(self):
        """The system message survives halving."""
        messages = self.create_messages(20, include_system=True)

        result = ContextTruncator().truncate_by_halving(messages)

        assert result[0].role == "system"
        assert result[0].content == "System prompt"

    def test_truncate_by_halving_odd_count(self):
        """Halving an odd-length history keeps at least 5 of 11 messages."""
        messages = self.create_messages(11)

        result = ContextTruncator().truncate_by_halving(messages)

        # floor(11 / 2) = 5 deleted leaves 6; realigning to a user message
        # may remove one more, hence the lower bound of 5.
        assert len(result) >= 5
        assert result[0].role == "user"

    def test_truncate_by_halving_ensures_user_first(self):
        """A halved history starts with a user message."""
        messages = self.create_messages(30)

        result = ContextTruncator().truncate_by_halving(messages)

        assert result[0].role == "user"

    def test_truncate_by_halving_preserves_recent_messages(self):
        """Halving keeps the most recent half, in order."""
        messages = [
            self.create_message("user", "Message 0"),
            self.create_message("assistant", "Message 1"),
            self.create_message("user", "Message 2"),
            self.create_message("assistant", "Message 3"),
        ]

        result = ContextTruncator().truncate_by_halving(messages)

        # Only the last two messages remain.
        assert len(result) == 2
        assert result[0].content == "Message 2"
        assert result[1].content == "Message 3"

    # ==================== Integration Tests ====================

    def test_truncate_with_tool_messages(self):
        """Dropping the first turn also disposes of its tool message."""
        messages = [
            self.create_message("user", "Run tool"),
            self.create_message("assistant", "Running..."),
            self.create_message("tool", "Tool result"),
            self.create_message("user", "Thanks"),
            self.create_message("assistant", "Welcome"),
        ]

        result = ContextTruncator().truncate_by_dropping_oldest_turns(
            messages, drop_turns=1
        )

        # The first turn (user + assistant + tool) is gone; at most the
        # second turn's two messages remain.
        assert len(result) <= 2

    def test_chain_multiple_truncations(self):
        """Turn truncation followed by halving keeps the system message."""
        truncator = ContextTruncator()
        messages = self.create_messages(40, include_system=True)

        # Step 1: enforce the turn limit. Step 2: halve what is left.
        by_turns = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=10, drop_turns=2
        )
        result = truncator.truncate_by_halving(by_turns)

        assert result[0].role == "system"
        assert len(result) < len(messages)

    def test_empty_after_system_message(self):
        """A history holding only a system message is returned intact."""
        messages = [self.create_message("system", "System prompt")]

        result = ContextTruncator().truncate_by_turns(
            messages, keep_most_recent_turns=5, drop_turns=1
        )

        assert len(result) == 1
        assert result[0].role == "system"

    def test_all_system_messages(self):
        """With only system messages, nothing but system messages may remain."""
        messages = [
            self.create_message("system", "System 1"),
            self.create_message("system", "System 2"),
        ]

        result = ContextTruncator().truncate_by_turns(
            messages, keep_most_recent_turns=0, drop_turns=1
        )

        # The result may be empty or hold the system messages; either way no
        # other role may appear (all() is True for an empty result).
        assert all(msg.role == "system" for msg in result)
|
||||
Reference in New Issue
Block a user