Compare commits

...

5 Commits

Author SHA1 Message Date
Soulter 8e7b44185d chore: bump version to 4.11.0 2026-01-05 18:05:12 +08:00
RC-CHN ef1c66a92e feat(webui): enable Range request support for backup downloads (#4329) 2026-01-05 17:27:03 +08:00
Soulter 241f1c26d3 feat: context compress (#4322)
* feat: context compressor

Co-authored-by: kawayiYokami <289104862@qq.com>

* Add comprehensive tests for ContextManager and ContextTruncator

- Implemented a full test suite for ContextManager covering initialization, message processing, token-based compression, and error handling.
- Added tests for ContextTruncator focusing on message fixing, truncation by turns, dropping oldest turns, and halving.
- Ensured that both test suites validate edge cases and maintain expected behavior with various message types, including system and tool messages.

* feat: add MockProvider for LLM compression tests

* chore: remove lock

* ruff fix

* fix

* perf

* feat: enhance context compression with token tracking and logging

* feat: update logging for context compression trigger

* feat: implement context compression logic with dynamic threshold and token tracking

* fix: reorder import statements for consistency

* feat: add token_usage tracking to conversations and update related processing logic

---------

Co-authored-by: kawayiYokami <289104862@qq.com>
2026-01-05 17:26:10 +08:00
Soulter 3615b7dde2 fix: token usage is always 0 in anthropic source (#4328) 2026-01-05 17:06:12 +08:00
RC-CHN 9bcf9bf2a0 fix(dashboard): complete i18n support for shared components (#4327)
* fix(dashboard): complete i18n support for shared components

- Replace hardcoded Chinese strings with i18n translations in:
  - PluginSetSelector.vue
  - ProviderSelector.vue
  - PersonaSelector.vue
  - KnowledgeBaseSelector.vue
  - T2ITemplateEditor.vue
  - AstrBotConfigV4.vue
  - ConfigItemRenderer.vue
  - ProxySelector.vue
  - ListConfigItem.vue

- Add missing translations to locale files:
  - core/shared.json: personaSelector, t2iTemplateEditor
  - core/common.json: autoDetect
  - features/settings.json: network.proxySelector

- Change prop defaults from hardcoded Chinese to empty strings,
  allowing components to use i18n fallback translations

* fix(i18n): 修正插件选择器标签的翻译格式,添加冒号

* fix(deployment): 添加持久化 machine-id PVC 和初始化容器,优化资源限制
2026-01-05 09:45:28 +08:00
42 changed files with 2477 additions and 183 deletions
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.10.6"
__version__ = "4.11.0"
+243
View File
@@ -0,0 +1,243 @@
from typing import TYPE_CHECKING, Protocol, runtime_checkable
from ..message import Message
if TYPE_CHECKING:
from astrbot import logger
else:
try:
from astrbot import logger
except ImportError:
import logging
logger = logging.getLogger("astrbot")
if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider
from ..context.truncator import ContextTruncator
@runtime_checkable
class ContextCompressor(Protocol):
"""
Protocol for context compressors.
Provides an interface for compressing message lists.
"""
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens for the model.
Returns:
True if compression is needed, False otherwise.
"""
...
async def __call__(self, messages: list[Message]) -> list[Message]:
"""Compress the message list.
Args:
messages: The original message list.
Returns:
The compressed message list.
"""
...
class TruncateByTurnsCompressor:
"""Truncate by turns compressor implementation.
Truncates the message list by removing older turns.
"""
def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
"""Initialize the truncate by turns compressor.
Args:
truncate_turns: The number of turns to remove when truncating (default: 1).
compression_threshold: The compression trigger threshold (default: 0.82).
"""
self.truncate_turns = truncate_turns
self.compression_threshold = compression_threshold
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens.
Returns:
True if compression is needed, False otherwise.
"""
if max_tokens <= 0 or current_tokens <= 0:
return False
usage_rate = current_tokens / max_tokens
return usage_rate > self.compression_threshold
async def __call__(self, messages: list[Message]) -> list[Message]:
truncator = ContextTruncator()
truncated_messages = truncator.truncate_by_dropping_oldest_turns(
messages,
drop_turns=self.truncate_turns,
)
return truncated_messages
def split_history(
messages: list[Message], keep_recent: int
) -> tuple[list[Message], list[Message], list[Message]]:
"""Split the message list into system messages, messages to summarize, and recent messages.
Ensures that the split point is between complete user-assistant pairs to maintain conversation flow.
Args:
messages: The original message list.
keep_recent: The number of latest messages to keep.
Returns:
tuple: (system_messages, messages_to_summarize, recent_messages)
"""
# keep the system messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) <= keep_recent:
return system_messages, [], non_system_messages
# Find the split point, ensuring recent_messages starts with a user message
# This maintains complete conversation turns
split_index = len(non_system_messages) - keep_recent
# Search backward from split_index to find the first user message
# This ensures recent_messages starts with a user message (complete turn)
while split_index > 0 and non_system_messages[split_index].role != "user":
# TODO: +=1 or -=1 ? calculate by tokens
split_index -= 1
# If we couldn't find a user message, keep all messages as recent
if split_index == 0:
return system_messages, [], non_system_messages
messages_to_summarize = non_system_messages[:split_index]
recent_messages = non_system_messages[split_index:]
return system_messages, messages_to_summarize, recent_messages
class LLMSummaryCompressor:
"""LLM-based summary compressor.
Uses LLM to summarize the old conversation history, keeping the latest messages.
"""
def __init__(
self,
provider: "Provider",
keep_recent: int = 4,
instruction_text: str | None = None,
compression_threshold: float = 0.82,
):
"""Initialize the LLM summary compressor.
Args:
provider: The LLM provider instance.
keep_recent: The number of latest messages to keep (default: 4).
instruction_text: Custom instruction for summary generation.
compression_threshold: The compression trigger threshold (default: 0.82).
"""
self.provider = provider
self.keep_recent = keep_recent
self.compression_threshold = compression_threshold
self.instruction_text = instruction_text or (
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
"4. Write the summary in the user's language.\n"
)
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens.
Returns:
True if compression is needed, False otherwise.
"""
if max_tokens <= 0 or current_tokens <= 0:
return False
usage_rate = current_tokens / max_tokens
return usage_rate > self.compression_threshold
async def __call__(self, messages: list[Message]) -> list[Message]:
"""Use LLM to generate a summary of the conversation history.
Process:
1. Divide messages: keep the system message and the latest N messages.
2. Send the old messages + the instruction message to the LLM.
3. Reconstruct the message list: [system message, summary message, latest messages].
"""
if len(messages) <= self.keep_recent + 1:
return messages
system_messages, messages_to_summarize, recent_messages = split_history(
messages, self.keep_recent
)
if not messages_to_summarize:
return messages
# build payload
instruction_message = Message(role="user", content=self.instruction_text)
llm_payload = messages_to_summarize + [instruction_message]
# generate summary
try:
response = await self.provider.text_chat(contexts=llm_payload)
summary_content = response.completion_text
except Exception as e:
logger.error(f"Failed to generate summary: {e}")
return messages
# build result
result = []
result.extend(system_messages)
result.append(
Message(
role="user",
content=f"Our previous history conversation summary: {summary_content}",
)
)
result.append(
Message(
role="assistant",
content="Acknowledged the summary of our previous conversation history.",
)
)
result.extend(recent_messages)
return result
+35
View File
@@ -0,0 +1,35 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING
from .compressor import ContextCompressor
from .token_counter import TokenCounter
if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider
@dataclass
class ContextConfig:
"""Context configuration class."""
max_context_tokens: int = 0
"""Maximum number of context tokens. <= 0 means no limit."""
enforce_max_turns: int = -1 # -1 means no limit
"""Maximum number of conversation turns to keep. -1 means no limit. Executed before compression."""
truncate_turns: int = 1
"""Number of conversation turns to discard at once when truncation is triggered.
Two processes will use this value:
1. Enforce max turns truncation.
2. Truncation by turns compression strategy.
"""
llm_compress_instruction: str | None = None
"""Instruction prompt for LLM-based compression."""
llm_compress_keep_recent: int = 0
"""Number of recent messages to keep during LLM-based compression."""
llm_compress_provider: "Provider | None" = None
"""LLM provider used for compression tasks. If None, truncation strategy is used."""
custom_token_counter: TokenCounter | None = None
"""Custom token counting method. If None, the default method is used."""
custom_compressor: ContextCompressor | None = None
"""Custom context compression method. If None, the default method is used."""
+120
View File
@@ -0,0 +1,120 @@
from astrbot import logger
from ..message import Message
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
from .config import ContextConfig
from .token_counter import EstimateTokenCounter
from .truncator import ContextTruncator
class ContextManager:
"""Context compression manager."""
def __init__(
self,
config: ContextConfig,
):
"""Initialize the context manager.
There are two strategies to handle context limit reached:
1. Truncate by turns: remove older messages by turns.
2. LLM-based compression: use LLM to summarize old messages.
Args:
config: The context configuration.
"""
self.config = config
self.token_counter = config.custom_token_counter or EstimateTokenCounter()
self.truncator = ContextTruncator()
if config.custom_compressor:
self.compressor = config.custom_compressor
elif config.llm_compress_provider:
self.compressor = LLMSummaryCompressor(
provider=config.llm_compress_provider,
keep_recent=config.llm_compress_keep_recent,
instruction_text=config.llm_compress_instruction,
)
else:
self.compressor = TruncateByTurnsCompressor(
truncate_turns=config.truncate_turns
)
async def process(
self, messages: list[Message], trusted_token_usage: int = 0
) -> list[Message]:
"""Process the messages.
Args:
messages: The original message list.
Returns:
The processed message list.
"""
try:
result = messages
# 1. 基于轮次的截断 (Enforce max turns)
if self.config.enforce_max_turns != -1:
result = self.truncator.truncate_by_turns(
result,
keep_most_recent_turns=self.config.enforce_max_turns,
drop_turns=self.config.truncate_turns,
)
# 2. 基于 token 的压缩
if self.config.max_context_tokens > 0:
total_tokens = self.token_counter.count_tokens(
result, trusted_token_usage
)
if self.compressor.should_compress(
result, total_tokens, self.config.max_context_tokens
):
result = await self._run_compression(result, total_tokens)
return result
except Exception as e:
logger.error(f"Error during context processing: {e}", exc_info=True)
return messages
async def _run_compression(
self, messages: list[Message], prev_tokens: int
) -> list[Message]:
"""
Compress/truncate the messages.
Args:
messages: The original message list.
prev_tokens: The token count before compression.
Returns:
The compressed/truncated message list.
"""
logger.debug("Compress triggered, starting compression...")
messages = await self.compressor(messages)
# double check
tokens_after_summary = self.token_counter.count_tokens(messages)
# calculate compress rate
compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100
logger.info(
f"Compress completed."
f" {prev_tokens} -> {tokens_after_summary} tokens,"
f" compression rate: {compress_rate:.2f}%.",
)
# last check
if self.compressor.should_compress(
messages, tokens_after_summary, self.config.max_context_tokens
):
logger.info(
"Context still exceeds max tokens after compression, applying halving truncation..."
)
# still need compress, truncate by half
messages = self.truncator.truncate_by_halving(messages)
return messages
@@ -0,0 +1,64 @@
import json
from typing import Protocol, runtime_checkable
from ..message import Message, TextPart
@runtime_checkable
class TokenCounter(Protocol):
"""
Protocol for token counters.
Provides an interface for counting tokens in message lists.
"""
def count_tokens(
self, messages: list[Message], trusted_token_usage: int = 0
) -> int:
"""Count the total tokens in the message list.
Args:
messages: The message list.
trusted_token_usage: The total token usage that LLM API returned.
For some cases, this value is more accurate.
But some API does not return it, so the value defaults to 0.
Returns:
The total token count.
"""
...
class EstimateTokenCounter:
"""Estimate token counter implementation.
Provides a simple estimation of token count based on character types.
"""
def count_tokens(
self, messages: list[Message], trusted_token_usage: int = 0
) -> int:
if trusted_token_usage > 0:
return trusted_token_usage
total = 0
for msg in messages:
content = msg.content
if isinstance(content, str):
total += self._estimate_tokens(content)
elif isinstance(content, list):
# 处理多模态内容
for part in content:
if isinstance(part, TextPart):
total += self._estimate_tokens(part.text)
# 处理 Tool Calls
if msg.tool_calls:
for tc in msg.tool_calls:
tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump())
total += self._estimate_tokens(tc_str)
return total
def _estimate_tokens(self, text: str) -> int:
chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"])
other_count = len(text) - chinese_count
return int(chinese_count * 0.6 + other_count * 0.3)
+141
View File
@@ -0,0 +1,141 @@
from ..message import Message
class ContextTruncator:
"""Context truncator."""
def fix_messages(self, messages: list[Message]) -> list[Message]:
fixed_messages = []
for message in messages:
if message.role == "tool":
# tool block 前面必须要有 user 和 assistant block
if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
# 我们直接将之前的上下文都清空
fixed_messages = []
else:
fixed_messages.append(message)
else:
fixed_messages.append(message)
return fixed_messages
def truncate_by_turns(
self,
messages: list[Message],
keep_most_recent_turns: int,
drop_turns: int = 1,
) -> list[Message]:
"""截断上下文列表,确保不超过最大长度。
一个 turn 包含一个 user 消息和一个 assistant 消息。
这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。
Args:
messages: 上下文列表
keep_most_recent_turns: 保留最近的对话轮数
drop_turns: 一次性丢弃的对话轮数
Returns:
截断后的上下文列表
"""
if keep_most_recent_turns == -1:
return messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) // 2 <= keep_most_recent_turns:
return messages
num_to_keep = keep_most_recent_turns - drop_turns + 1
if num_to_keep <= 0:
truncated_contexts = []
else:
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
# 找到第一个 role 为 user 的索引,确保上下文格式正确
index = next(
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
None,
)
if index is not None and index > 0:
truncated_contexts = truncated_contexts[index:]
result = system_messages + truncated_contexts
return self.fix_messages(result)
def truncate_by_dropping_oldest_turns(
self,
messages: list[Message],
drop_turns: int = 1,
) -> list[Message]:
"""丢弃最旧的 N 个对话轮次。"""
if drop_turns <= 0:
return messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) // 2 <= drop_turns:
truncated_non_system = []
else:
truncated_non_system = non_system_messages[drop_turns * 2 :]
index = next(
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
)
if index is not None:
truncated_non_system = truncated_non_system[index:]
elif truncated_non_system:
truncated_non_system = []
result = system_messages + truncated_non_system
return self.fix_messages(result)
def truncate_by_halving(
self,
messages: list[Message],
) -> list[Message]:
"""对半砍策略,删除 50% 的消息"""
if len(messages) <= 2:
return messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
messages_to_delete = len(non_system_messages) // 2
if messages_to_delete == 0:
return messages
truncated_non_system = non_system_messages[messages_to_delete:]
index = next(
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
)
if index is not None:
truncated_non_system = truncated_non_system[index:]
result = system_messages + truncated_non_system
return self.fix_messages(result)
@@ -25,6 +25,10 @@ from astrbot.core.provider.entities import (
)
from astrbot.core.provider.provider import Provider
from ..context.compressor import ContextCompressor
from ..context.config import ContextConfig
from ..context.manager import ContextManager
from ..context.token_counter import TokenCounter
from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..response import AgentResponseData, AgentStats
@@ -47,10 +51,47 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
streaming: bool = False,
# enforce max turns, will discard older turns when exceeded BEFORE compression
# -1 means no limit
enforce_max_turns: int = -1,
# llm compressor
llm_compress_instruction: str | None = None,
llm_compress_keep_recent: int = 0,
llm_compress_provider: Provider | None = None,
# truncate by turns compressor
truncate_turns: int = 1,
# customize
custom_token_counter: TokenCounter | None = None,
custom_compressor: ContextCompressor | None = None,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.streaming = streaming
self.enforce_max_turns = enforce_max_turns
self.llm_compress_instruction = llm_compress_instruction
self.llm_compress_keep_recent = llm_compress_keep_recent
self.llm_compress_provider = llm_compress_provider
self.truncate_turns = truncate_turns
self.custom_token_counter = custom_token_counter
self.custom_compressor = custom_compressor
# we will do compress when:
# 1. before requesting LLM
# TODO: 2. after LLM output a tool call
self.context_config = ContextConfig(
# <=0 will never do compress
max_context_tokens=provider.provider_config.get("max_context_tokens", 0),
# enforce max turns before compression
enforce_max_turns=self.enforce_max_turns,
truncate_turns=self.truncate_turns,
llm_compress_instruction=self.llm_compress_instruction,
llm_compress_keep_recent=self.llm_compress_keep_recent,
llm_compress_provider=self.llm_compress_provider,
custom_token_counter=self.custom_token_counter,
custom_compressor=self.custom_compressor,
)
self.context_manager = ContextManager(self.context_config)
self.provider = provider
self.final_llm_resp = None
self._state = AgentState.IDLE
@@ -110,6 +151,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self._transition_state(AgentState.RUNNING)
llm_resp_result = None
# do truncate and compress
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
self.run_context.messages = await self.context_manager.process(
self.run_context.messages, trusted_token_usage=token_usage
)
async for llm_response in self._iter_llm_responses():
if llm_response.is_chunk:
# update ttft
+78 -17
View File
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.10.6"
VERSION = "4.11.0"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -83,6 +83,16 @@ DEFAULT_CONFIG = {
"default_personality": "default",
"persona_pool": ["*"],
"prompt_prefix": "{{prompt}}",
"context_limit_reached_strategy": "truncate_by_turns", # or llm_compress
"llm_compress_instruction": (
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
"4. Write the summary in the user's language.\n"
),
"llm_compress_keep_recent": 4,
"llm_compress_provider_id": "",
"max_context_length": -1,
"dequeue_context_length": 1,
"streaming_response": False,
@@ -179,6 +189,7 @@ class ChatProviderTemplate(TypedDict):
model: str
modalities: list
custom_extra_body: dict[str, Any]
max_context_tokens: int
CHAT_PROVIDER_TEMPLATE = {
@@ -187,6 +198,7 @@ CHAT_PROVIDER_TEMPLATE = {
"model": "",
"modalities": [],
"custom_extra_body": {},
"max_context_tokens": 0,
}
"""
@@ -2033,6 +2045,11 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
},
"max_context_tokens": {
"description": "模型上下文窗口大小",
"type": "int",
"hint": "模型最大上下文 Token 大小。如果为 0,则会自动从模型元数据填充(如有),也可手动修改。",
},
"dify_api_key": {
"description": "API Key",
"type": "string",
@@ -2540,6 +2557,66 @@ CONFIG_METADATA_3 = {
# "provider_settings.enable": True,
# },
# },
"truncate_and_compress": {
"description": "上下文管理策略",
"type": "object",
"items": {
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",
"type": "int",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.dequeue_context_length": {
"description": "丢弃对话轮数",
"type": "int",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.context_limit_reached_strategy": {
"description": "超出模型上下文窗口时的处理方式",
"type": "string",
"options": ["truncate_by_turns", "llm_compress"],
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
"condition": {
"provider_settings.agent_runner_type": "local",
},
"hint": "",
},
"provider_settings.llm_compress_instruction": {
"description": "上下文压缩提示词",
"type": "text",
"hint": "如果为空则使用默认提示词。",
"condition": {
"provider_settings.context_limit_reached_strategy": "llm_compress",
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.llm_compress_keep_recent": {
"description": "压缩时保留最近对话轮数",
"type": "int",
"hint": "始终保留的最近 N 轮对话。",
"condition": {
"provider_settings.context_limit_reached_strategy": "llm_compress",
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.llm_compress_provider_id": {
"description": "用于上下文压缩的模型提供商 ID",
"type": "string",
"_special": "select_provider",
"hint": "留空时将降级为“按对话轮数截断”的策略。",
"condition": {
"provider_settings.context_limit_reached_strategy": "llm_compress",
"provider_settings.agent_runner_type": "local",
},
},
},
},
"others": {
"description": "其他配置",
"type": "object",
@@ -2604,22 +2681,6 @@ CONFIG_METADATA_3 = {
"provider_settings.streaming_response": True,
},
},
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",
"type": "int",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.dequeue_context_length": {
"description": "丢弃对话轮数",
"type": "int",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.wake_prefix": {
"description": "LLM 聊天额外唤醒前缀 ",
"type": "string",
+4
View File
@@ -69,6 +69,7 @@ class ConversationManager:
persona_id=conv_v2.persona_id,
created_at=created_at,
updated_at=updated_at,
token_usage=conv_v2.token_usage,
)
async def new_conversation(
@@ -256,6 +257,7 @@ class ConversationManager:
history: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
token_usage: int | None = None,
) -> None:
"""更新会话的对话.
@@ -263,6 +265,7 @@ class ConversationManager:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
token_usage (int | None): token 使用量。None 表示不更新
"""
if not conversation_id:
@@ -274,6 +277,7 @@ class ConversationManager:
title=title,
persona_id=persona_id,
content=history,
token_usage=token_usage,
)
async def update_conversation_title(
+1
View File
@@ -152,6 +152,7 @@ class BaseDatabase(abc.ABC):
title: str | None = None,
persona_id: str | None = None,
content: list[dict] | None = None,
token_usage: int | None = None,
) -> None:
"""Update a conversation's history."""
...
@@ -0,0 +1,61 @@
"""Migration script to add token_usage column to conversations table.
This migration adds the token_usage field to track token consumption for each conversation.
Changes:
- Adds token_usage column to conversations table (default: 0)
"""
from sqlalchemy import text
from astrbot.api import logger, sp
from astrbot.core.db import BaseDatabase
async def migrate_token_usage(db_helper: BaseDatabase):
"""Add token_usage column to conversations table.
This migration adds a new column to track token consumption in conversations.
"""
# 检查是否已经完成迁移
migration_done = await db_helper.get_preference(
"global", "global", "migration_done_token_usage_1"
)
if migration_done:
return
logger.info("开始执行数据库迁移(添加 conversations.token_usage 列)...")
# 这里只适配了 SQLite。因为截止至这一版本,AstrBot 仅支持 SQLite。
try:
async with db_helper.get_db() as session:
# 检查列是否已存在
result = await session.execute(text("PRAGMA table_info(conversations)"))
columns = result.fetchall()
column_names = [col[1] for col in columns]
if "token_usage" in column_names:
logger.info("token_usage 列已存在,跳过迁移")
await sp.put_async(
"global", "global", "migration_done_token_usage_1", True
)
return
# 添加 token_usage 列
await session.execute(
text(
"ALTER TABLE conversations ADD COLUMN token_usage INTEGER NOT NULL DEFAULT 0"
)
)
await session.commit()
logger.info("token_usage 列添加成功")
# 标记迁移完成
await sp.put_async("global", "global", "migration_done_token_usage_1", True)
logger.info("token_usage 迁移完成")
except Exception as e:
logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
raise
+7
View File
@@ -54,6 +54,11 @@ class ConversationV2(SQLModel, table=True):
)
title: str | None = Field(default=None, max_length=255)
persona_id: str | None = Field(default=None)
token_usage: int = Field(default=0, nullable=False)
"""content is a list of OpenAI-formated messages in list[dict] format.
token_usage is the total token value of the messages.
when 0, will use estimated token counter.
"""
__table_args__ = (
UniqueConstraint(
@@ -313,6 +318,8 @@ class Conversation:
persona_id: str | None = ""
created_at: int = 0
updated_at: int = 0
token_usage: int = 0
"""对话的总 token 数量。AstrBot 会保留最近一次 LLM 请求返回的总 token 数,方便统计。token_usage 可能为 0,表示未知。"""
class Personality(TypedDict):
+5 -1
View File
@@ -241,7 +241,9 @@ class SQLiteDatabase(BaseDatabase):
session.add(new_conversation)
return new_conversation
async def update_conversation(self, cid, title=None, persona_id=None, content=None):
async def update_conversation(
self, cid, title=None, persona_id=None, content=None, token_usage=None
):
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
@@ -255,6 +257,8 @@ class SQLiteDatabase(BaseDatabase):
values["persona_id"] = persona_id
if content is not None:
values["content"] = content
if token_usage is not None:
values["token_usage"] = token_usage
if not values:
return None
query = query.values(**values)
@@ -1,12 +1,12 @@
"""本地 Agent 模式的 LLM 调用 Stage"""
import asyncio
import copy
import json
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.agent.message import Message
from astrbot.core.agent.response import AgentStats
from astrbot.core.agent.tool import ToolSet
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.conversation_mgr import Conversation
@@ -24,6 +24,7 @@ from astrbot.core.provider.entities import (
)
from astrbot.core.star.star_handler import EventType, star_map
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.llm_metadata import LLM_METADATAS
from astrbot.core.utils.metrics import Metric
from astrbot.core.utils.session_lock import session_lock_manager
@@ -41,11 +42,6 @@ class InternalAgentSubStage(Stage):
self.ctx = ctx
conf = ctx.astrbot_config
settings = conf["provider_settings"]
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
self.max_context_length - 1,
)
self.streaming_response: bool = settings["streaming_response"]
self.unsupported_streaming_strategy: str = settings[
"unsupported_streaming_strategy"
@@ -65,6 +61,25 @@ class InternalAgentSubStage(Stage):
"moonshotai_api_key", ""
)
# 上下文管理相关
self.context_limit_reached_strategy: str = settings.get(
"context_limit_reached_strategy", "truncate_by_turns"
)
self.llm_compress_instruction: str = settings.get(
"llm_compress_instruction", ""
)
self.llm_compress_keep_recent: int = settings.get("llm_compress_keep_recent", 4)
self.llm_compress_provider_id: str = settings.get(
"llm_compress_provider_id", ""
)
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
self.max_context_length - 1,
)
if self.dequeue_context_length <= 0:
self.dequeue_context_length = 1
self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent):
@@ -167,34 +182,6 @@ class InternalAgentSubStage(Stage):
},
)
def _truncate_contexts(
self,
contexts: list[dict],
) -> list[dict]:
"""截断上下文列表,确保不超过最大长度"""
if self.max_context_length == -1:
return contexts
if len(contexts) // 2 <= self.max_context_length:
return contexts
truncated_contexts = contexts[
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
]
# 找到第一个role 为 user 的索引,确保上下文格式正确
index = next(
(
i
for i, item in enumerate(truncated_contexts)
if item.get("role") == "user"
),
None,
)
if index is not None and index > 0:
truncated_contexts = truncated_contexts[index:]
return truncated_contexts
def _modalities_fix(
self,
provider: Provider,
@@ -296,6 +283,7 @@ class InternalAgentSubStage(Stage):
req: ProviderRequest,
llm_response: LLMResponse | None,
all_messages: list[Message],
runner_stats: AgentStats | None,
):
if (
not req
@@ -322,27 +310,37 @@ class InternalAgentSubStage(Stage):
continue
message_to_save.append(message.model_dump())
# get token usage from agent runner stats
token_usage = None
if runner_stats:
token_usage = runner_stats.token_usage.total
await self.conv_manager.update_conversation(
event.unified_msg_origin,
req.conversation.cid,
history=message_to_save,
token_usage=token_usage,
)
def _fix_messages(self, messages: list[dict]) -> list[dict]:
"""验证并且修复上下文"""
fixed_messages = []
for message in messages:
if message.get("role") == "tool":
# tool block 前面必须要有 user 和 assistant block
if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
# 我们直接将之前的上下文都清空
fixed_messages = []
else:
fixed_messages.append(message)
else:
fixed_messages.append(message)
return fixed_messages
def _get_compress_provider(self) -> Provider | None:
if not self.llm_compress_provider_id:
return None
if self.context_limit_reached_strategy != "llm_compress":
return None
provider = self.ctx.plugin_manager.context.get_provider_by_id(
self.llm_compress_provider_id,
)
if provider is None:
logger.warning(
f"未找到指定的上下文压缩模型 {self.llm_compress_provider_id},将跳过压缩。",
)
return None
if not isinstance(provider, Provider):
logger.warning(
f"指定的上下文压缩模型 {self.llm_compress_provider_id} 不是对话模型,将跳过压缩。"
)
return None
return provider
async def process(
self, event: AstrMessageEvent, provider_wake_prefix: str
@@ -426,9 +424,10 @@ class InternalAgentSubStage(Stage):
await self._apply_kb(event, req)
# truncate contexts to fit max length
if req.contexts:
req.contexts = self._truncate_contexts(req.contexts)
self._fix_messages(req.contexts)
# NOW moved to ContextManager inside ToolLoopAgentRunner
# if req.contexts:
# req.contexts = self._truncate_contexts(req.contexts)
# self._fix_messages(req.contexts)
# session_id
if not req.session_id:
@@ -444,8 +443,6 @@ class InternalAgentSubStage(Stage):
self.unsupported_streaming_strategy == "turn_off"
and not event.platform_meta.support_streaming_message
)
# 备份 req.contexts
backup_contexts = copy.deepcopy(req.contexts)
# run agent
agent_runner = AgentRunner()
@@ -456,6 +453,15 @@ class InternalAgentSubStage(Stage):
context=self.ctx.plugin_manager.context,
event=event,
)
# inject model context length limit
if provider.provider_config.get("max_context_tokens", 0) <= 0:
model = provider.get_model()
if model_info := LLM_METADATAS.get(model):
provider.provider_config["max_context_tokens"] = model_info[
"limit"
]["context"]
await agent_runner.reset(
provider=provider,
request=req,
@@ -466,6 +472,11 @@ class InternalAgentSubStage(Stage):
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=streaming_response,
llm_compress_instruction=self.llm_compress_instruction,
llm_compress_keep_recent=self.llm_compress_keep_recent,
llm_compress_provider=self._get_compress_provider(),
truncate_turns=self.dequeue_context_length,
enforce_max_turns=self.max_context_length,
)
if streaming_response and not stream_to_general:
@@ -511,14 +522,12 @@ class InternalAgentSubStage(Stage):
):
yield
# 恢复备份的 contexts
req.contexts = backup_contexts
await self._save_to_history(
event,
req,
agent_runner.get_final_llm_resp(),
agent_runner.run_context.messages,
agent_runner.stats,
)
# 异步处理 WebChat 特殊情况
+5
View File
@@ -344,6 +344,11 @@ class LLMResponse:
self.raw_completion = raw_completion
self.is_chunk = is_chunk
if id is not None:
self.id = id
if usage is not None:
self.usage = usage
@property
def completion_text(self):
if self.result_chain:
+14 -1
View File
@@ -149,9 +149,12 @@ class Context:
contexts: context messages for the LLM
max_steps: Maximum number of tool calls before stopping the loop
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
stream: bool - whether to stream the LLM response
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
agent_context: AstrAgentContext - context to use for the agent
other kwargs will be DIRECTLY passed to the runner.reset() method
Returns:
The final LLMResponse after tool calls are completed.
@@ -194,6 +197,15 @@ class Context:
)
agent_runner = ToolLoopAgentRunner()
tool_executor = FunctionToolExecutor()
streaming = kwargs.get("stream", False)
other_kwargs = {
k: v
for k, v in kwargs.items()
if k not in ["stream", "agent_hooks", "agent_context"]
}
await agent_runner.reset(
provider=prov,
request=request,
@@ -203,7 +215,8 @@ class Context:
),
tool_executor=tool_executor,
agent_hooks=agent_hooks,
streaming=kwargs.get("stream", False),
streaming=streaming,
**other_kwargs,
)
async for _ in agent_runner.step_until_done(max_steps):
pass
+8
View File
@@ -3,6 +3,7 @@ import traceback
from astrbot.core import astrbot_config, logger
from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
from astrbot.core.db.migration.migra_token_usage import migrate_token_usage
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
@@ -139,6 +140,13 @@ async def migra(
logger.error(f"Migration for webchat session failed: {e!s}")
logger.error(traceback.format_exc())
# migration for token_usage column
try:
await migrate_token_usage(db)
except Exception as e:
logger.error(f"Migration for token_usage column failed: {e!s}")
logger.error(traceback.format_exc())
# migra third party agent runner configs
_c = False
providers = astrbot_config["provider"]
+1
View File
@@ -993,6 +993,7 @@ class BackupRoute(Route):
file_path,
as_attachment=True,
attachment_filename=filename,
conditional=True, # 启用 Range 请求支持(断点续传)
)
except Exception as e:
logger.error(f"下载备份失败: {e}")
+19
View File
@@ -0,0 +1,19 @@
## What's Changed
### 新增
- 支持上下文自动压缩功能。入口:配置文件 -> 上下文管理策略 -> 超出模型上下文窗口时的处理方式。详情请查看: [自动上下文压缩](https://docs.astrbot.app/use/context-compress.html) ([#4322](https://github.com/AstrBotDevs/AstrBot/issues/4322))
- 新增 `on_waiting_llm_request` 事件钩子 ([#4319](https://github.com/AstrBotDevs/AstrBot/issues/4319))
- WebUI 支持强制更新插件 ([#4293](https://github.com/AstrBotDevs/AstrBot/issues/4293))
- 社区已提供适用于 [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter) 平台的适配器插件
### 修复
- 修复微信公众号中由于 msg.id 数据类型不匹配导致的重试失败问题 ([#4292](https://github.com/AstrBotDevs/AstrBot/issues/4292))
- 修复调用 TTS 命令时出现的数据库锁定错误 ([#4313](https://github.com/AstrBotDevs/AstrBot/issues/4313))
- 修复 Anthropic 提供商中 token 用量始终为 0 的问题 ([#4328](https://github.com/AstrBotDevs/AstrBot/issues/4328))
### 优化
- 完善共享组件的国际化支持 ([#4327](https://github.com/AstrBotDevs/AstrBot/issues/4327))
- 优化下载大型备份文件时的稳定性,减少失败情况 ([#4329](https://github.com/AstrBotDevs/AstrBot/issues/4329))
@@ -233,12 +233,12 @@ function getSpecialSubtype(value) {
<div v-if="createSelectorModel(itemKey).value && createSelectorModel(itemKey).value.length > 0"
class="selected-plugins-full-width">
<div class="plugins-header">
<small class="text-grey">已选择的插件</small>
<small class="text-grey">{{ t('core.shared.pluginSetSelector.selectedPluginsLabel') }}</small>
</div>
<div class="d-flex flex-wrap ga-2 mt-2">
<v-chip v-for="plugin in (createSelectorModel(itemKey).value || [])" :key="plugin" size="small" label
color="primary" variant="outlined">
{{ plugin === '*' ? '所有插件' : plugin }}
{{ plugin === '*' ? t('core.shared.pluginSetSelector.allPluginsLabel') : plugin }}
</v-chip>
</div>
</div>
@@ -20,13 +20,13 @@
</template>
<template v-else-if="itemMeta?._special === 'provider_pool'">
<ProviderSelector :model-value="modelValue" @update:model-value="emitUpdate" :provider-type="'chat_completion'"
button-text="选择提供商池..." />
:button-text="t('core.shared.providerSelector.selectProviderPool')" />
</template>
<template v-else-if="itemMeta?._special === 'select_persona'">
<PersonaSelector :model-value="modelValue" @update:model-value="emitUpdate" />
</template>
<template v-else-if="itemMeta?._special === 'persona_pool'">
<PersonaSelector :model-value="modelValue" @update:model-value="emitUpdate" button-text="选择人格池..." />
<PersonaSelector :model-value="modelValue" @update:model-value="emitUpdate" :button-text="t('core.shared.personaSelector.selectPersonaPool')" />
</template>
<template v-else-if="itemMeta?._special === 'select_knowledgebase'">
<KnowledgeBaseSelector :model-value="modelValue" @update:model-value="emitUpdate" />
@@ -56,7 +56,7 @@
:loading="loading"
class="ml-2"
>
自动检测
{{ t('core.common.autoDetect') }}
</v-btn>
</div>
</template>
@@ -144,7 +144,7 @@
color="primary"
density="compact"
hide-details
class="flex-grow-1"
style="flex: 1"
></v-slider>
<v-text-field
:model-value="modelValue"
@@ -154,7 +154,7 @@
class="config-field"
type="number"
hide-details
style="max-width: 140px;"
style="flex: 1"
></v-text-field>
</div>
@@ -325,4 +325,8 @@ function getSpecialSubtype(value) {
.gap-20 {
gap: 20px;
}
:deep(.v-field__input) {
font-size: 14px;
}
</style>
@@ -20,7 +20,7 @@
</div>
</div>
<v-btn size="small" color="primary" variant="tonal" @click="openDialog" style="flex-shrink: 0;">
{{ buttonText }}
{{ buttonText || tm('knowledgeBaseSelector.buttonText') }}
</v-btn>
</div>
@@ -105,7 +105,7 @@ const props = defineProps({
},
buttonText: {
type: String,
default: '选择知识库...'
default: ''
}
})
@@ -175,11 +175,11 @@ const props = defineProps({
},
buttonText: {
type: String,
default: '修改'
default: ''
},
dialogTitle: {
type: String,
default: '修改列表项'
default: ''
},
maxDisplayItems: {
type: Number,
@@ -1,13 +1,13 @@
<template>
<div class="d-flex align-center justify-space-between">
<span v-if="!modelValue" style="color: rgb(var(--v-theme-primaryText));">
未选择
{{ tm('personaSelector.notSelected') }}
</span>
<span v-else>
{{ modelValue === 'default' ? '默认人格' : modelValue }}
{{ modelValue === 'default' ? tm('personaSelector.defaultPersona') : modelValue }}
</span>
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
{{ buttonText }}
{{ buttonText || tm('personaSelector.buttonText') }}
</v-btn>
</div>
@@ -15,7 +15,7 @@
<v-dialog v-model="dialog" max-width="600px">
<v-card>
<v-card-title class="text-h3 py-4" style="font-weight: normal;">
选择人格
{{ tm('personaSelector.dialogTitle') }}
</v-card-title>
<v-card-text class="pa-2" style="max-height: 400px; overflow-y: auto;">
@@ -30,9 +30,9 @@
:active="selectedPersona === persona.persona_id"
rounded="md"
class="ma-1">
<v-list-item-title>{{ persona.persona_id === 'default' ? '默认人格' : persona.persona_id }}</v-list-item-title>
<v-list-item-title>{{ persona.persona_id === 'default' ? tm('personaSelector.defaultPersona') : persona.persona_id }}</v-list-item-title>
<v-list-item-subtitle>
{{ persona.system_prompt ? persona.system_prompt.substring(0, 50) + '...' : '无描述' }}
{{ persona.system_prompt ? persona.system_prompt.substring(0, 50) + '...' : tm('personaSelector.noDescription') }}
</v-list-item-subtitle>
<template v-slot:append>
@@ -43,21 +43,21 @@
<div v-else-if="!loading && personaList.length === 0" class="text-center py-8">
<v-icon size="64" color="grey-lighten-1">mdi-account-off</v-icon>
<p class="text-grey mt-4">暂无可用的人格</p>
<p class="text-grey mt-4">{{ tm('personaSelector.noPersonas') }}</p>
</div>
</v-card-text>
<v-card-actions class="pa-4">
<v-btn variant="text" color="primary" prepend-icon="mdi-plus" @click="openCreatePersona">
创建新人格
{{ tm('personaSelector.createPersona') }}
</v-btn>
<v-spacer></v-spacer>
<v-btn variant="text" @click="cancelSelection">取消</v-btn>
<v-btn
color="primary"
<v-btn variant="text" @click="cancelSelection">{{ t('core.common.cancel') }}</v-btn>
<v-btn
color="primary"
@click="confirmSelection"
:disabled="!selectedPersona">
确认选择
{{ t('core.common.confirm') }}
</v-btn>
</v-card-actions>
</v-card>
@@ -78,6 +78,7 @@
import { ref, watch } from 'vue'
import axios from 'axios'
import PersonaForm from './PersonaForm.vue'
import { useI18n, useModuleI18n } from '@/i18n/composables'
const props = defineProps({
modelValue: {
@@ -86,11 +87,13 @@ const props = defineProps({
},
buttonText: {
type: String,
default: '选择人格...'
default: ''
}
})
const emit = defineEmits(['update:modelValue'])
const { t } = useI18n()
const { tm } = useModuleI18n('core.shared')
const dialog = ref(false)
const personaList = ref([])
@@ -14,7 +14,7 @@
</span>
</div>
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
{{ buttonText }}
{{ buttonText || tm('pluginSetSelector.buttonText') }}
</v-btn>
</div>
</div>
@@ -113,7 +113,7 @@ const props = defineProps({
},
buttonText: {
type: String,
default: '选择插件集合...'
default: ''
},
maxDisplayItems: {
type: Number,
@@ -7,7 +7,7 @@
{{ modelValue }}
</span>
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
{{ buttonText }}
{{ buttonText || tm('providerSelector.buttonText') }}
</v-btn>
</div>
@@ -134,7 +134,7 @@ const props = defineProps({
},
buttonText: {
type: String,
default: '选择提供商...'
default: ''
}
})
@@ -1,13 +1,13 @@
<template>
<h5>GitHub 加速</h5>
<h5>{{ tm('network.proxySelector.title') }}</h5>
<v-radio-group class="mt-2" v-model="radioValue" hide-details="true">
<v-radio label="不使用 GitHub 加速" value="0"></v-radio>
<v-radio :label="tm('network.proxySelector.noProxy')" value="0"></v-radio>
<v-radio value="1">
<template v-slot:label>
<span>使用 GitHub 加速</span>
<span>{{ tm('network.proxySelector.useProxy') }}</span>
<v-btn v-if="radioValue === '1'" class="ml-2" @click="testAllProxies" size="x-small"
variant="tonal" :loading="loadingTestingConnection">
测试代理连通性
{{ tm('network.proxySelector.testConnection') }}
</v-btn>
</template>
</v-radio>
@@ -20,15 +20,15 @@
<div class="d-flex align-center">
<span class="mr-2">{{ proxy }}</span>
<div v-if="proxyStatus[idx]">
<v-chip
:color="proxyStatus[idx].available ? 'success' : 'error'"
size="x-small"
<v-chip
:color="proxyStatus[idx].available ? 'success' : 'error'"
size="x-small"
class="mr-1">
{{ proxyStatus[idx].available ? '可用' : '不可用' }}
{{ proxyStatus[idx].available ? tm('network.proxySelector.available') : tm('network.proxySelector.unavailable') }}
</v-chip>
<v-chip
v-if="proxyStatus[idx].available"
color="info"
<v-chip
v-if="proxyStatus[idx].available"
color="info"
size="x-small">
{{ proxyStatus[idx].latency }}ms
</v-chip>
@@ -36,10 +36,10 @@
</div>
</template>
</v-radio>
<v-radio color="primary" value="-1" label="自定义">
<v-radio color="primary" value="-1" :label="tm('network.proxySelector.custom')">
<template v-slot:label v-if="githubProxyRadioControl === '-1'">
<v-text-field density="compact" v-model="selectedGitHubProxy" variant="outlined"
style="width: 100vw;" placeholder="自定义" hide-details="true">
style="width: 100vw;" :placeholder="tm('network.proxySelector.custom')" hide-details="true">
</v-text-field>
</template>
</v-radio>
@@ -1,32 +1,32 @@
<template>
<v-dialog v-model="dialog" max-width="1400px" persistent scrollable>
<template v-slot:activator="{ props }">
<v-btn
<v-btn
v-bind="props"
variant="outlined"
color="primary"
variant="outlined"
color="primary"
size="small"
:loading="loading"
>
自定义 T2I 模板
{{ tm('t2iTemplateEditor.buttonText') }}
</v-btn>
</template>
<v-card>
<v-card-title class="d-flex align-center justify-space-between">
<span>自定义文转图 HTML 模板</span>
<span>{{ tm('t2iTemplateEditor.dialogTitle') }}</span>
<v-spacer></v-spacer>
<div class="d-flex align-center gap-2" style="width: 60%">
<v-text-field
v-if="isCreatingNew"
v-model="editingName"
label="输入新模板名称"
:label="tm('t2iTemplateEditor.newTemplateNameLabel')"
density="compact"
hide-details
variant="outlined"
class="flex-grow-1"
autofocus
:rules="[v => !!v || '名称不能为空']"
:rules="[v => !!v || tm('t2iTemplateEditor.nameRequired')]"
></v-text-field>
<v-select
v-else
@@ -34,7 +34,7 @@
:items="templates"
item-title="name"
item-value="name"
label="选择模板"
:label="tm('t2iTemplateEditor.selectTemplateLabel')"
density="compact"
hide-details
variant="outlined"
@@ -51,7 +51,7 @@
size="small"
class="ml-2"
>
已应用
{{ tm('t2iTemplateEditor.applied') }}
</v-chip>
<v-btn
v-else
@@ -62,7 +62,7 @@
@click.stop="setActiveTemplate(item.raw.name)"
:loading="applyLoading"
>
应用
{{ tm('t2iTemplateEditor.apply') }}
</v-btn>
</template>
</v-list-item>
@@ -83,7 +83,7 @@
<!-- 左侧编辑器 -->
<v-col cols="6" class="d-flex flex-column">
<v-toolbar density="compact" color="surface-variant">
<v-toolbar-title class="text-subtitle-2">模板编辑器</v-toolbar-title>
<v-toolbar-title class="text-subtitle-2">{{ tm('t2iTemplateEditor.templateEditor') }}</v-toolbar-title>
<v-spacer></v-spacer>
<div class="d-flex align-center pa-1" style="border: 1px solid rgba(0,0,0,0.1); border-radius: 8px;">
<v-btn
@@ -93,7 +93,7 @@
color="success"
>
<v-icon left>mdi-plus</v-icon>
新建
{{ tm('t2iTemplateEditor.new') }}
</v-btn>
<v-divider vertical class="mx-1"></v-divider>
<v-btn
@@ -103,7 +103,7 @@
:loading="resetLoading"
color="warning"
>
重置Base
{{ tm('t2iTemplateEditor.resetBase') }}
</v-btn>
<v-btn
variant="text"
@@ -112,7 +112,7 @@
color="error"
:disabled="isCreatingNew || selectedTemplate === 'base' || !selectedTemplate"
>
删除
{{ tm('t2iTemplateEditor.delete') }}
</v-btn>
<v-divider vertical class="mx-1"></v-divider>
<v-btn
@@ -123,7 +123,7 @@
color="primary"
:disabled="(isCreatingNew && !editingName) || (!isCreatingNew && !selectedTemplate)"
>
保存
{{ tm('t2iTemplateEditor.save') }}
</v-btn>
</div>
</v-toolbar>
@@ -141,15 +141,15 @@
<!-- 右侧预览 -->
<v-col cols="6" class="d-flex flex-column">
<v-toolbar density="compact" color="surface-variant">
<v-toolbar-title class="text-subtitle-2">实时预览(可能有差异)</v-toolbar-title>
<v-toolbar-title class="text-subtitle-2">{{ tm('t2iTemplateEditor.livePreview') }}</v-toolbar-title>
<v-spacer></v-spacer>
<v-btn
variant="text"
size="small"
<v-btn
variant="text"
size="small"
@click="refreshPreview"
:loading="previewLoading"
>
刷新预览
{{ tm('t2iTemplateEditor.refreshPreview') }}
</v-btn>
</v-toolbar>
<div class="flex-grow-1 preview-container">
@@ -168,7 +168,7 @@
<v-col>
<div class="text-caption text-grey">
<v-icon size="16" class="mr-1">mdi-information</v-icon>
支持 jinja2 语法可用变量<code> text | safe </code>要渲染的文本, <code> version </code>AstrBot 版本
{{ tm('t2iTemplateEditor.syntaxHint') }}
</div>
</v-col>
<v-col cols="auto">
@@ -176,7 +176,7 @@
variant="text"
@click="closeDialog"
>
取消
{{ t('core.common.cancel') }}
</v-btn>
<v-btn
color="primary"
@@ -184,7 +184,7 @@
:loading="saveLoading"
:disabled="isCreatingNew || !selectedTemplate"
>
保存应用当前编辑模板
{{ tm('t2iTemplateEditor.saveAndApply') }}
</v-btn>
</v-col>
</v-row>
@@ -194,14 +194,14 @@
<!-- 确认重置对话框 -->
<v-dialog v-model="resetDialog" max-width="400px">
<v-card>
<v-card-title>确认重置</v-card-title>
<v-card-title>{{ tm('t2iTemplateEditor.confirmReset') }}</v-card-title>
<v-card-text>
确定要将 'base' 模板恢复为默认内容吗当前编辑器中的任何未保存更改将丢失此操作无法撤销
{{ tm('t2iTemplateEditor.confirmResetMessage') }}
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn text @click="resetDialog = false">取消</v-btn>
<v-btn color="warning" @click="confirmReset" :loading="resetLoading">确认重置</v-btn>
<v-btn text @click="resetDialog = false">{{ t('core.common.cancel') }}</v-btn>
<v-btn color="warning" @click="confirmReset" :loading="resetLoading">{{ tm('t2iTemplateEditor.confirmResetButton') }}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
@@ -209,14 +209,14 @@
<!-- 删除确认对话框 -->
<v-dialog v-model="deleteDialog" max-width="400px">
<v-card>
<v-card-title>确认删除</v-card-title>
<v-card-title>{{ tm('t2iTemplateEditor.confirmDelete') }}</v-card-title>
<v-card-text>
确定要删除模板 '{{ selectedTemplate }}' 此操作无法撤销
{{ tm('t2iTemplateEditor.confirmDeleteMessage', { name: selectedTemplate }) }}
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn text @click="deleteDialog = false">取消</v-btn>
<v-btn color="error" @click="confirmDelete" :loading="saveLoading">确认删除</v-btn>
<v-btn text @click="deleteDialog = false">{{ t('core.common.cancel') }}</v-btn>
<v-btn color="error" @click="confirmDelete" :loading="saveLoading">{{ tm('t2iTemplateEditor.confirmDeleteButton') }}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
@@ -224,14 +224,14 @@
<!-- 保存并应用确认对话框 -->
<v-dialog v-model="applyAndCloseDialog" max-width="500px">
<v-card>
<v-card-title>确认操作</v-card-title>
<v-card-title>{{ tm('t2iTemplateEditor.confirmAction') }}</v-card-title>
<v-card-text>
确定要保存对 '{{ selectedTemplate }}' 的修改并将其设为新的活动模板吗
{{ tm('t2iTemplateEditor.confirmApplyMessage', { name: selectedTemplate }) }}
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn text @click="applyAndCloseDialog = false">取消</v-btn>
<v-btn color="primary" @click="confirmApplyAndClose" :loading="saveLoading">确认</v-btn>
<v-btn text @click="applyAndCloseDialog = false">{{ t('core.common.cancel') }}</v-btn>
<v-btn color="primary" @click="confirmApplyAndClose" :loading="saveLoading">{{ t('core.common.confirm') }}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
@@ -242,10 +242,11 @@
<script setup>
import { ref, computed, nextTick, watch } from 'vue'
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
import { useI18n } from '@/i18n/composables'
import { useI18n, useModuleI18n } from '@/i18n/composables'
import axios from 'axios'
const { t } = useI18n()
const { tm } = useModuleI18n('core.shared')
// --- ---
const dialog = ref(false)
@@ -510,7 +510,7 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
const metadata = getModelMetadata(modelName)
let modalities: string[]
if (!metadata) {
modalities = ['text', 'image', 'tool_use']
} else {
@@ -523,13 +523,19 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
}
}
let max_context_tokens = 0
if (metadata?.limit?.context && typeof metadata.limit.context === 'number') {
max_context_tokens = metadata.limit.context
}
const newProvider = {
id: newId,
enable: false,
provider_source_id: sourceId,
model: modelName,
modalities,
custom_extra_body: {}
custom_extra_body: {},
max_context_tokens: max_context_tokens
}
try {
@@ -35,6 +35,7 @@
"yes": "Yes",
"no": "No",
"imagePreview": "Image Preview",
"autoDetect": "Auto Detect",
"dialog": {
"confirmTitle": "Confirm Action",
"confirmMessage": "Are you sure you want to perform this action?",
@@ -28,7 +28,9 @@
"cancelSelection": "Cancel",
"noDescription": "No description",
"notActivated": "Not activated",
"note": "*System plugins and disabled plugins are not shown."
"note": "*System plugins and disabled plugins are not shown.",
"selectedPluginsLabel": "Selected Plugins:",
"allPluginsLabel": "All Plugins"
},
"providerSelector": {
"notSelected": "Not selected",
@@ -42,6 +44,45 @@
"clearSelectionSubtitle": "Clear current selection",
"unknownType": "Unknown type",
"createProvider": "Create Provider",
"manageProviders": "Provider Management"
"manageProviders": "Provider Management",
"selectProviderPool": "Select Provider Pool..."
},
"personaSelector": {
"notSelected": "Not selected",
"defaultPersona": "Default Persona",
"buttonText": "Select Persona...",
"dialogTitle": "Select Persona",
"noDescription": "No description",
"noPersonas": "No personas available",
"createPersona": "Create New Persona",
"cancelSelection": "Cancel",
"confirmSelection": "Confirm Selection",
"selectPersonaPool": "Select Persona Pool..."
},
"t2iTemplateEditor": {
"buttonText": "Customize T2I Template",
"dialogTitle": "Customize Text-to-Image HTML Template",
"newTemplateNameLabel": "Enter new template name",
"nameRequired": "Name is required",
"selectTemplateLabel": "Select Template",
"applied": "Applied",
"apply": "Apply",
"templateEditor": "Template Editor",
"new": "New",
"resetBase": "Reset Base",
"delete": "Delete",
"save": "Save",
"livePreview": "Live Preview (may differ)",
"refreshPreview": "Refresh Preview",
"syntaxHint": "Supports jinja2 syntax. Available variables: text | safe (text to render), version (AstrBot version)",
"saveAndApply": "Save and Apply Current Template",
"confirmReset": "Confirm Reset",
"confirmResetMessage": "Are you sure you want to reset the 'base' template to default content? Any unsaved changes in the editor will be lost. This action cannot be undone.",
"confirmResetButton": "Confirm Reset",
"confirmDelete": "Confirm Delete",
"confirmDeleteMessage": "Are you sure you want to delete template '{name}'? This action cannot be undone.",
"confirmDeleteButton": "Confirm Delete",
"confirmAction": "Confirm Action",
"confirmApplyMessage": "Are you sure you want to save changes to '{name}' and set it as the active template?"
}
}
@@ -11,7 +11,12 @@
},
"agent_runner_type": {
"description": "Runner",
"labels": ["Built-in Agent", "Dify", "Coze", "Alibaba Cloud Bailian Application"]
"labels": [
"Built-in Agent",
"Dify",
"Coze",
"Alibaba Cloud Bailian Application"
]
},
"coze_agent_runner_provider_id": {
"description": "Coze Agent Runner Provider ID"
@@ -128,6 +133,39 @@
}
}
},
"truncate_and_compress": {
"description": "Context Management Strategy",
"provider_settings": {
"max_context_length": {
"description": "Maximum Conversation Turns",
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
},
"dequeue_context_length": {
"description": "Dequeue Conversation Turns",
"hint": "Number of conversation turns to discard at once when maximum context length is exceeded"
},
"context_limit_reached_strategy": {
"description": "Handling When Model Context Window is Exceeded",
"labels": [
"Truncate by Turns",
"Compress by LLM"
],
"hint": "When 'Truncate by Turns' is selected, the oldest N conversation turns will be discarded based on the 'Dequeue Conversation Turns' setting above. When 'Compress by LLM' is selected, the specified model will be used for context compression."
},
"llm_compress_instruction": {
"description": "Context Compression Instruction",
"hint": "If empty, the default prompt will be used."
},
"llm_compress_keep_recent": {
"description": "Keep Recent Turns When Compressing",
"hint": "Always keep the most recent N turns of conversation when compressing context."
},
"llm_compress_provider_id": {
"description": "Model Provider ID for Context Compression",
"hint": "When left empty, will fall back to the 'Truncate by Turns' strategy."
}
}
},
"others": {
"description": "Other Settings",
"provider_settings": {
@@ -161,15 +199,10 @@
"unsupported_streaming_strategy": {
"description": "Platforms Without Streaming Support",
"hint": "Select the handling method for platforms that don't support streaming responses. Real-time segmented reply sends content immediately when the system detects segment points like punctuation during streaming reception",
"labels": ["Real-time Segmented Reply", "Disable Streaming Response"]
},
"max_context_length": {
"description": "Maximum Conversation Rounds",
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
},
"dequeue_context_length": {
"description": "Dequeue Conversation Rounds",
"hint": "Number of conversation rounds to discard at once when maximum context length is exceeded"
"labels": [
"Real-time Segmented Reply",
"Disable Streaming Response"
]
},
"wake_prefix": {
"description": "Additional LLM Chat Wake Prefix",
@@ -387,7 +420,10 @@
},
"split_mode": {
"description": "Split Mode",
"labels": ["Regex", "Words List"]
"labels": [
"Regex",
"Words List"
]
},
"regex": {
"description": "Segmentation Regular Expression"
@@ -488,4 +524,4 @@
}
}
}
}
}
@@ -5,6 +5,15 @@
"title": "GitHub Proxy Address",
"subtitle": "Set the GitHub proxy address used when downloading plugins or updating AstrBot. This is effective in mainland China's network environment. Can be customized, input takes effect in real time. All addresses do not guarantee stability. If errors occur when updating plugins/projects, please first check if the proxy address is working properly.",
"label": "Select GitHub Proxy Address"
},
"proxySelector": {
"title": "GitHub Proxy",
"noProxy": "Don't use GitHub Proxy",
"useProxy": "Use GitHub Proxy",
"testConnection": "Test Connection",
"available": "Available",
"unavailable": "Unavailable",
"custom": "Custom"
}
},
"system": {
@@ -35,6 +35,7 @@
"yes": "是",
"no": "否",
"imagePreview": "图片预览",
"autoDetect": "自动检测",
"dialog": {
"confirmTitle": "确认操作",
"confirmMessage": "你确定要执行此操作吗?",
@@ -28,7 +28,9 @@
"cancelSelection": "取消",
"noDescription": "无描述",
"notActivated": "未激活",
"note": "*不显示系统插件和已经在插件页禁用的插件。"
"note": "*不显示系统插件和已经在插件页禁用的插件。",
"selectedPluginsLabel": "已选择的插件:",
"allPluginsLabel": "所有插件"
},
"providerSelector": {
"notSelected": "未选择",
@@ -42,6 +44,45 @@
"clearSelectionSubtitle": "清除当前选择",
"unknownType": "未知类型",
"createProvider": "创建提供商",
"manageProviders": "提供商管理"
"manageProviders": "提供商管理",
"selectProviderPool": "选择提供商池..."
},
"personaSelector": {
"notSelected": "未选择",
"defaultPersona": "默认人格",
"buttonText": "选择人格...",
"dialogTitle": "选择人格",
"noDescription": "无描述",
"noPersonas": "暂无可用的人格",
"createPersona": "创建新人格",
"cancelSelection": "取消",
"confirmSelection": "确认选择",
"selectPersonaPool": "选择人格池..."
},
"t2iTemplateEditor": {
"buttonText": "自定义 T2I 模板",
"dialogTitle": "自定义文转图 HTML 模板",
"newTemplateNameLabel": "输入新模板名称",
"nameRequired": "名称不能为空",
"selectTemplateLabel": "选择模板",
"applied": "已应用",
"apply": "应用",
"templateEditor": "模板编辑器",
"new": "新建",
"resetBase": "重置Base",
"delete": "删除",
"save": "保存",
"livePreview": "实时预览(可能有差异)",
"refreshPreview": "刷新预览",
"syntaxHint": "支持 jinja2 语法。可用变量:text | safe(要渲染的文本), versionAstrBot 版本)",
"saveAndApply": "保存应用当前编辑模板",
"confirmReset": "确认重置",
"confirmResetMessage": "确定要将 'base' 模板恢复为默认内容吗?当前编辑器中的任何未保存更改将丢失。此操作无法撤销。",
"confirmResetButton": "确认重置",
"confirmDelete": "确认删除",
"confirmDeleteMessage": "确定要删除模板 '{name}' 吗?此操作无法撤销。",
"confirmDeleteButton": "确认删除",
"confirmAction": "确认操作",
"confirmApplyMessage": "确定要保存对 '{name}' 的修改,并将其设为新的活动模板吗?"
}
}
@@ -133,6 +133,36 @@
}
}
},
"truncate_and_compress": {
"description": "上下文管理策略",
"provider_settings": {
"max_context_length": {
"description": "最多携带对话轮数",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
},
"dequeue_context_length": {
"description": "丢弃对话轮数",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
},
"context_limit_reached_strategy": {
"description": "超出模型上下文窗口时的处理方式",
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
"hint": "当按对话轮数截断时,会根据上面\"丢弃对话轮数\"的配置丢弃最旧的 N 轮对话。当由 LLM 压缩上下文时,会使用指定的模型进行上下文压缩。"
},
"llm_compress_instruction": {
"description": "上下文压缩提示词",
"hint": "如果为空则使用默认提示词。"
},
"llm_compress_keep_recent": {
"description": "压缩时保留最近对话轮数",
"hint": "始终保留的最近 N 轮对话。"
},
"llm_compress_provider_id": {
"description": "用于上下文压缩的模型提供商 ID",
"hint": "留空时将降级为\"按对话轮数截断\"的策略。"
}
}
},
"others": {
"description": "其他配置",
"provider_settings": {
@@ -171,14 +201,7 @@
"关闭流式回复"
]
},
"max_context_length": {
"description": "最多携带对话轮数",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
},
"dequeue_context_length": {
"description": "丢弃对话轮数",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
},
"wake_prefix": {
"description": "LLM 聊天额外唤醒前缀",
"hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求"
@@ -5,6 +5,15 @@
"title": "GitHub 加速地址",
"subtitle": "设置下载插件或者更新 AstrBot 时所用的 GitHub 加速地址。这在中国大陆的网络环境有效。可以自定义,输入结果实时生效。所有地址均不保证稳定性,如果在更新插件/项目时出现报错,请首先检查加速地址是否能正常使用。",
"label": "选择 GitHub 加速地址"
},
"proxySelector": {
"title": "GitHub 加速",
"noProxy": "不使用 GitHub 加速",
"useProxy": "使用 GitHub 加速",
"testConnection": "测试代理连通性",
"available": "可用",
"unavailable": "不可用",
"custom": "自定义"
}
},
"system": {
+17
View File
@@ -43,4 +43,21 @@ spec:
resources:
requests:
storage: 5Gi
# storageClassName: standard
---
# 持久化 machine-id,保持设备标识不变
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: napcat-machine-id-pvc
namespace: astrbot-ns
labels:
app: astrbot-stack
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: 10Mi # 只需存储一个 32 字节的文件
# storageClassName: standard
+63 -1
View File
@@ -17,6 +17,32 @@ spec:
labels:
app: astrbot-stack
spec:
# 设置固定主机名,避免 Pod 重启后主机名变化触发风控
hostname: napcat-host
subdomain: astrbot-stack
# 优雅关闭时间,给 NapCat 足够时间保存状态
terminationGracePeriodSeconds: 60
# 初始化容器:首次生成随机 machine-id,后续复用
initContainers:
- name: init-machine-id
image: busybox:latest
command:
- /bin/sh
- -c
- |
# 仅在 machine-id 不存在时随机生成一个
if [ ! -f /machine-id-data/machine-id ]; then
# 使用 /dev/urandom 生成随机 UUID (32位十六进制)
cat /proc/sys/kernel/random/uuid | tr -d '-' > /machine-id-data/machine-id
echo "Machine ID generated: $(cat /machine-id-data/machine-id)"
else
echo "Machine ID exists: $(cat /machine-id-data/machine-id)"
fi
volumeMounts:
- name: machine-id-data
mountPath: /machine-id-data
containers:
- name: napcat
image: mlikiowa/napcat-docker:latest
@@ -28,9 +54,19 @@ spec:
value: "1000"
- name: MODE
value: "astrbot"
- name: TZ
value: "Asia/Shanghai"
ports:
- containerPort: 6099
name: napcat-web
# 资源限制:确保 Guaranteed QoS,减少被驱逐的可能
resources:
requests:
memory: "512Mi"
cpu: "250m"
limits:
memory: "1Gi"
cpu: "1000m"
volumeMounts:
- name: shared-data
mountPath: /AstrBot/data
@@ -38,6 +74,14 @@ spec:
mountPath: /app/napcat/config
- name: napcat-qq
mountPath: /app/.config/QQ
# 挂载持久化的 machine-id
- name: machine-id-data
mountPath: /etc/machine-id
subPath: machine-id
readOnly: true
- name: localtime
mountPath: /etc/localtime
readOnly: true
- name: astrbot
image: soulter/astrbot:latest
@@ -48,9 +92,19 @@ spec:
ports:
- containerPort: 6185
name: astrbot-web
resources:
requests:
memory: "256Mi"
cpu: "100m"
limits:
memory: "512Mi"
cpu: "500m"
volumeMounts:
- name: shared-data
mountPath: /AstrBot/data
- name: localtime
mountPath: /etc/localtime
readOnly: true
volumes:
- name: shared-data
@@ -61,4 +115,12 @@ spec:
claimName: napcat-config-pvc
- name: napcat-qq
persistentVolumeClaim:
claimName: napcat-qq-pvc
claimName: napcat-qq-pvc
# 持久化 machine-id(首次随机生成,后续复用)
- name: machine-id-data
persistentVolumeClaim:
claimName: napcat-machine-id-pvc
- name: localtime
hostPath:
path: /etc/localtime
type: File
+1 -1
View File
@@ -1,6 +1,6 @@
[project]
name = "AstrBot"
version = "4.10.6"
version = "4.11.0"
description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md"
requires-python = ">=3.10"
+774
View File
@@ -0,0 +1,774 @@
"""Comprehensive tests for ContextManager."""
import sys
from pathlib import Path
from typing import Literal
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Add parent directory to path to avoid circular import issues
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from astrbot.core.agent.context.config import ContextConfig
from astrbot.core.agent.context.manager import ContextManager
from astrbot.core.agent.message import Message, TextPart
from astrbot.core.provider.entities import LLMResponse
class MockProvider:
"""模拟 Provider"""
def __init__(self):
self.provider_config = {
"id": "test_provider",
"model": "gpt-4",
"modalities": ["text", "image", "tool_use"],
}
async def text_chat(self, **kwargs):
"""模拟 LLM 调用,返回摘要"""
messages = kwargs.get("messages", [])
# 简单的摘要逻辑:返回消息数量统计
return LLMResponse(
role="assistant",
completion_text=f"历史对话包含 {len(messages) - 1} 条消息,主要讨论了技术话题。",
)
def get_model(self):
return "gpt-4"
def meta(self):
return MagicMock(id="test_provider", type="openai")
class TestContextManager:
"""Test suite for ContextManager."""
def create_message(
self, role: Literal["system", "user", "assistant", "tool"], content: str
) -> Message:
"""Helper to create a simple text message."""
return Message(role=role, content=content)
def create_messages(self, count: int) -> list[Message]:
"""Helper to create alternating user/assistant messages."""
messages = []
for i in range(count):
role = "user" if i % 2 == 0 else "assistant"
messages.append(self.create_message(role, f"Message {i}"))
return messages
# ==================== Basic Initialization Tests ====================
def test_init_with_minimal_config(self):
"""Test initialization with minimal configuration."""
config = ContextConfig()
manager = ContextManager(config)
assert manager.config == config
assert manager.token_counter is not None
assert manager.truncator is not None
assert manager.compressor is not None
def test_init_with_llm_compressor(self):
"""Test initialization with LLM-based compression."""
mock_provider = MockProvider()
config = ContextConfig(
llm_compress_provider=mock_provider, # type: ignore
llm_compress_keep_recent=5,
llm_compress_instruction="Summarize the conversation",
)
manager = ContextManager(config)
from astrbot.core.agent.context.compressor import LLMSummaryCompressor
assert isinstance(manager.compressor, LLMSummaryCompressor)
def test_init_with_truncate_compressor(self):
"""Test initialization with truncate-based compression (default)."""
config = ContextConfig(truncate_turns=3)
manager = ContextManager(config)
from astrbot.core.agent.context.compressor import TruncateByTurnsCompressor
assert isinstance(manager.compressor, TruncateByTurnsCompressor)
# ==================== Empty and Edge Cases ====================
@pytest.mark.asyncio
async def test_process_empty_messages(self):
"""Test processing an empty message list."""
config = ContextConfig()
manager = ContextManager(config)
result = await manager.process([])
assert result == []
@pytest.mark.asyncio
async def test_process_single_message(self):
"""Test processing a single message."""
config = ContextConfig()
manager = ContextManager(config)
messages = [self.create_message("user", "Hello")]
result = await manager.process(messages)
assert len(result) == 1
assert result[0].content == "Hello"
@pytest.mark.asyncio
async def test_process_with_no_limits(self):
"""Test processing when no limits are set (no truncation or compression)."""
config = ContextConfig(max_context_tokens=0, enforce_max_turns=-1)
manager = ContextManager(config)
messages = self.create_messages(20)
result = await manager.process(messages)
assert len(result) == 20
assert result == messages
# ==================== Enforce Max Turns Tests ====================
@pytest.mark.asyncio
async def test_enforce_max_turns_basic(self):
"""Test basic enforce_max_turns functionality."""
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
manager = ContextManager(config)
# Create 10 turns (20 messages)
messages = self.create_messages(20)
result = await manager.process(messages)
# Should keep only 3 most recent turns (6 messages)
assert len(result) <= 8 # May vary due to truncation logic
@pytest.mark.asyncio
async def test_enforce_max_turns_zero(self):
"""Test enforce_max_turns with value 0 (should keep nothing)."""
config = ContextConfig(enforce_max_turns=0, truncate_turns=1)
manager = ContextManager(config)
messages = self.create_messages(10)
result = await manager.process(messages)
# Should result in empty or minimal message list
assert len(result) <= 2
@pytest.mark.asyncio
async def test_enforce_max_turns_negative(self):
"""Test enforce_max_turns with -1 (no limit)."""
config = ContextConfig(enforce_max_turns=-1)
manager = ContextManager(config)
messages = self.create_messages(20)
result = await manager.process(messages)
assert len(result) == 20
@pytest.mark.asyncio
async def test_enforce_max_turns_with_system_messages(self):
"""Test enforce_max_turns preserves system messages."""
config = ContextConfig(enforce_max_turns=2, truncate_turns=1)
manager = ContextManager(config)
messages = [
self.create_message("system", "System instruction"),
*self.create_messages(10),
]
result = await manager.process(messages)
# System message should be preserved
system_msgs = [m for m in result if m.role == "system"]
assert len(system_msgs) >= 1
assert system_msgs[0].content == "System instruction"
# ==================== Token-based Compression Tests ====================
@pytest.mark.asyncio
async def test_token_compression_not_triggered_below_threshold(self):
"""Test that compression is not triggered below threshold."""
config = ContextConfig(max_context_tokens=1000)
manager = ContextManager(config)
# Create messages that total less than threshold
messages = [self.create_message("user", "Hi" * 50)] # ~100 tokens
with patch.object(
manager.compressor, "should_compress", return_value=False
) as mock_should_compress:
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager.process(messages)
# should_compress should be called
mock_should_compress.assert_called_once()
# Compressor should not be called
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_token_compression_triggered_above_threshold(self):
"""Test that compression is triggered above threshold."""
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
manager = ContextManager(config)
# Create messages that exceed threshold (0.82 * 100 = 82 tokens)
# 300 chars * 0.3 = 90 tokens > 82 threshold
long_text = "x" * 300 # ~90 tokens, above threshold
messages = [self.create_message("user", long_text)]
# Mock compressor to return smaller result
compressed = [self.create_message("user", "short")]
# Create a mock compressor
mock_compressor = AsyncMock()
mock_compressor.compression_threshold = 0.82
mock_compressor.return_value = compressed
# Mock should_compress to return True first time, False after
call_count = 0
def mock_should_compress(*args, **kwargs):
nonlocal call_count
call_count += 1
return call_count == 1
mock_compressor.should_compress = mock_should_compress
manager.compressor = mock_compressor
result = await manager.process(messages)
# Compressor should be called
mock_compressor.assert_called_once()
# Result should be the compressed version
assert len(result) <= len(messages)
@pytest.mark.asyncio
async def test_token_compression_with_zero_max_tokens(self):
"""Test that compression is skipped when max_context_tokens is 0."""
config = ContextConfig(max_context_tokens=0)
manager = ContextManager(config)
messages = [self.create_message("user", "x" * 10000)]
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager.process(messages)
# Compressor should not be called when max_context_tokens is 0
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_token_compression_with_negative_max_tokens(self):
"""Test that compression is skipped when max_context_tokens is negative."""
config = ContextConfig(max_context_tokens=-100)
manager = ContextManager(config)
messages = [self.create_message("user", "x" * 10000)]
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager.process(messages)
# Compressor should not be called
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_double_check_after_compression(self):
"""Test that halving is applied if still over threshold after compression."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Create messages that would still be over threshold after compression
long_messages = [self.create_message("user", "x" * 200) for _ in range(10)]
# Mock compressor to return messages still over threshold
async def mock_compress(msgs):
return msgs # Return same messages (still over limit)
# Mock should_compress to return True twice (before and after compression)
with patch.object(manager.compressor, "should_compress", return_value=True):
with patch.object(manager.compressor, "__call__", new=mock_compress):
with patch.object(
manager.truncator,
"truncate_by_halving",
return_value=long_messages[:5],
) as mock_halving:
_ = await manager.process(long_messages)
# Halving should be called
mock_halving.assert_called_once()
# ==================== Combined Truncation and Compression Tests ====================
@pytest.mark.asyncio
async def test_combined_enforce_turns_and_token_limit(self):
"""Test combining enforce_max_turns and token limit."""
config = ContextConfig(
enforce_max_turns=5, max_context_tokens=500, truncate_turns=1
)
manager = ContextManager(config)
# Create many messages
messages = self.create_messages(30)
result = await manager.process(messages)
# Should be truncated by both mechanisms
assert len(result) < 30
@pytest.mark.asyncio
async def test_sequential_processing_order(self):
"""Test that enforce_max_turns happens before token compression."""
config = ContextConfig(enforce_max_turns=5, max_context_tokens=1000)
manager = ContextManager(config)
messages = self.create_messages(20)
# Mock the truncator to track calls
with patch.object(
manager.truncator,
"truncate_by_turns",
wraps=manager.truncator.truncate_by_turns,
) as mock_truncate:
await manager.process(messages)
# Truncator should be called first
mock_truncate.assert_called_once()
# ==================== Error Handling Tests ====================
@pytest.mark.asyncio
async def test_error_handling_returns_original_messages(self):
"""Test that errors during processing return original messages."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
messages = self.create_messages(5)
# Make compressor raise an exception
with patch.object(
manager.compressor, "__call__", side_effect=Exception("Test error")
):
result = await manager.process(messages)
# Should return original messages despite error
assert result == messages
@pytest.mark.asyncio
async def test_error_handling_logs_exception(self):
"""Test that errors are logged."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Create messages that will trigger compression (> 82 tokens)
messages = [self.create_message("user", "x" * 300)] # ~90 tokens
# Replace compressor with one that raises an exception
mock_compressor = AsyncMock(side_effect=Exception("Test error"))
mock_compressor.compression_threshold = 0.82
mock_compressor.should_compress = MagicMock(return_value=True)
manager.compressor = mock_compressor
with patch("astrbot.core.agent.context.manager.logger") as mock_logger:
result = await manager.process(messages)
# Logger error method should be called
assert mock_logger.error.called
# Should return original messages on error
assert result == messages
# ==================== Multi-modal Content Tests ====================
@pytest.mark.asyncio
async def test_process_messages_with_textpart_content(self):
"""Test processing messages with TextPart content."""
config = ContextConfig()
manager = ContextManager(config)
messages = [
Message(role="user", content=[TextPart(text="Hello")]),
Message(role="assistant", content=[TextPart(text="Hi there")]),
]
result = await manager.process(messages)
assert len(result) == 2
assert result == messages
@pytest.mark.asyncio
async def test_token_counting_with_multimodal_content(self):
"""Test token counting works with multi-modal content."""
config = ContextConfig(max_context_tokens=50)
manager = ContextManager(config)
# Need enough tokens to exceed threshold: 50 * 0.82 = 41 tokens
# 150 chars * 0.3 = 45 tokens > 41
messages = [
Message(role="user", content=[TextPart(text="x" * 150)]),
]
# Should trigger compression due to token count
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 50)
assert tokens > 0 # Tokens should be counted
assert needs_compression # Should trigger compression
# ==================== Tool Calls Tests ====================
@pytest.mark.asyncio
async def test_process_messages_with_tool_calls(self):
"""Test processing messages with tool calls."""
config = ContextConfig()
manager = ContextManager(config)
messages = [
Message(
role="assistant",
content="Let me search for that",
tool_calls=[
{
"id": "call_1",
"type": "function",
"function": {"name": "search", "arguments": "{}"},
}
],
),
Message(role="tool", content="Search result", tool_call_id="call_1"),
]
result = await manager.process(messages)
assert len(result) == 2
# ==================== Compressor should_compress Tests ====================
@pytest.mark.asyncio
async def test_should_compress_empty_messages(self):
"""Test should_compress with empty messages."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Compressor's should_compress should handle empty gracefully
needs_compression = manager.compressor.should_compress([], 0, 100)
assert not needs_compression
@pytest.mark.asyncio
async def test_should_compress_below_threshold(self):
"""Test should_compress when below compression threshold."""
config = ContextConfig(max_context_tokens=1000)
manager = ContextManager(config)
messages = [self.create_message("user", "Hello")]
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 1000)
assert not needs_compression
@pytest.mark.asyncio
async def test_should_compress_above_threshold(self):
"""Test should_compress when above compression threshold."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Create message with many tokens
messages = [self.create_message("user", "这是测试" * 50)]
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
# Should need compression if tokens > 82 (0.82 * 100)
assert needs_compression == (tokens > 82)
# ==================== Truncator Halving Tests ====================
def test_truncate_by_halving_basic(self):
"""Test truncate_by_halving removes middle 50%."""
config = ContextConfig()
manager = ContextManager(config)
messages = self.create_messages(10)
result = manager.truncator.truncate_by_halving(messages)
# Should keep roughly half
assert len(result) < len(messages)
def test_truncate_by_halving_empty_list(self):
"""Test truncate_by_halving with empty list."""
config = ContextConfig()
manager = ContextManager(config)
result = manager.truncator.truncate_by_halving([])
assert result == []
def test_truncate_by_halving_single_message(self):
"""Test truncate_by_halving with single message."""
config = ContextConfig()
manager = ContextManager(config)
messages = [self.create_message("user", "Hello")]
result = manager.truncator.truncate_by_halving(messages)
assert len(result) <= 1
# ==================== Complex Scenarios ====================
@pytest.mark.asyncio
async def test_multiple_compression_cycles(self):
"""Test that compression can be triggered multiple times in sequence."""
config = ContextConfig(max_context_tokens=50, truncate_turns=1)
manager = ContextManager(config)
# Process messages multiple times
messages = self.create_messages(10)
result1 = await manager.process(messages)
result2 = await manager.process(result1)
result3 = await manager.process(result2)
# Each cycle should maintain or reduce message count
assert len(result3) <= len(result2) <= len(result1)
@pytest.mark.asyncio
async def test_alternating_roles_preserved(self):
"""Test that user/assistant alternation is preserved after processing."""
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
manager = ContextManager(config)
messages = self.create_messages(20)
result = await manager.process(messages)
# Check that roles still alternate (excluding system messages)
non_system = [m for m in result if m.role != "system"]
if len(non_system) >= 2:
# Should start with user
assert non_system[0].role == "user"
@pytest.mark.asyncio
async def test_compression_threshold_default(self):
"""Test that compression threshold is used correctly."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Verify the default threshold is 0.82
assert manager.compressor.compression_threshold == 0.82
# Test threshold logic
messages = [self.create_message("user", "x" * 81)] # ~24 tokens
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
# Should not compress if below threshold
assert needs_compression == (tokens > 82)
@pytest.mark.asyncio
async def test_large_batch_processing(self):
"""Test processing a large batch of messages."""
config = ContextConfig(
enforce_max_turns=10, max_context_tokens=1000, truncate_turns=2
)
manager = ContextManager(config)
# Create 100 messages (50 turns)
messages = self.create_messages(100)
result = await manager.process(messages)
# Should be significantly reduced
assert len(result) < 100
assert len(result) > 0
@pytest.mark.asyncio
async def test_config_persistence(self):
"""Test that config settings are respected throughout processing."""
config = ContextConfig(
max_context_tokens=500,
enforce_max_turns=5,
truncate_turns=2,
llm_compress_keep_recent=3,
)
manager = ContextManager(config)
# Verify config is stored
assert manager.config.max_context_tokens == 500
assert manager.config.enforce_max_turns == 5
assert manager.config.truncate_turns == 2
assert manager.config.llm_compress_keep_recent == 3
# ==================== Run Compression Tests ====================
@pytest.mark.asyncio
async def test_run_compression_calls_compressor(self):
"""Test _run_compression calls compressor."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
messages = self.create_messages(5)
compressed = self.create_messages(3)
# Create a mock compressor
mock_compressor = AsyncMock()
mock_compressor.compression_threshold = 0.82
mock_compressor.return_value = compressed
mock_compressor.should_compress = MagicMock(return_value=False)
manager.compressor = mock_compressor
result = await manager._run_compression(messages, prev_tokens=100)
# Compressor __call__ should be invoked
mock_compressor.assert_called_once_with(messages)
assert result == compressed
@pytest.mark.asyncio
async def test_run_compression_applies_compressor_through_process(self):
"""Test _run_compression calls compressor when needed through process()."""
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
manager = ContextManager(config)
# Create messages that will trigger compression
messages = [self.create_message("user", "x" * 300)] # ~90 tokens > 82 threshold
compressed = [self.create_message("user", "short")] # Much smaller
# Create a mock compressor
mock_compressor = AsyncMock()
mock_compressor.compression_threshold = 0.82
mock_compressor.return_value = compressed
# Mock should_compress to return True first time, False after
call_count = 0
def mock_should_compress(*args, **kwargs):
nonlocal call_count
call_count += 1
return call_count == 1
mock_compressor.should_compress = mock_should_compress
manager.compressor = mock_compressor
result = await manager.process(messages)
# Compressor should have been called
mock_compressor.assert_called_once()
assert len(result) <= len(messages)
@pytest.mark.asyncio
async def test_llm_compression_with_mock_provider(self):
"""Test LLM compression using MockProvider."""
mock_provider = MockProvider()
config = ContextConfig(
llm_compress_provider=mock_provider, # type: ignore
llm_compress_keep_recent=3,
llm_compress_instruction="请总结对话内容",
max_context_tokens=100,
)
manager = ContextManager(config)
# Create messages that will trigger compression
messages = [
self.create_message("user", "x" * 100),
self.create_message("assistant", "y" * 100),
self.create_message("user", "z" * 100),
]
result = await manager.process(messages)
# Should have been compressed
assert len(result) <= len(messages)
# ==================== split_history Tests ====================
def test_split_history_ensures_user_start(self):
"""Test split_history ensures recent_messages starts with user message."""
from astrbot.core.agent.context.compressor import split_history
# Create alternating messages: user, assistant, user, assistant, user, assistant
messages = [
self.create_message("system", "System prompt"),
self.create_message("user", "msg1"),
self.create_message("assistant", "msg2"),
self.create_message("user", "msg3"),
self.create_message("assistant", "msg4"),
self.create_message("user", "msg5"),
self.create_message("assistant", "msg6"),
]
# Keep recent 3 messages - should adjust to start with user
system, to_summarize, recent = split_history(messages, keep_recent=3)
# recent_messages should start with user message
assert len(recent) > 0
assert recent[0].role == "user"
# messages_to_summarize should end with assistant (complete turn)
if len(to_summarize) > 0:
assert to_summarize[-1].role == "assistant"
def test_split_history_handles_assistant_at_split_point(self):
"""Test split_history when assistant message is at the intended split point."""
from astrbot.core.agent.context.compressor import split_history
messages = [
self.create_message("user", "msg1"),
self.create_message("assistant", "msg2"),
self.create_message("user", "msg3"),
self.create_message("assistant", "msg4"), # <- intended split here
self.create_message("user", "msg5"),
self.create_message("assistant", "msg6"),
]
# keep_recent=2 would normally split at index 4 (assistant msg4)
# Should move back to include from msg5 (user)
system, to_summarize, recent = split_history(messages, keep_recent=2)
# recent should start with user message
assert recent[0].role == "user"
assert recent[0].content == "msg5"
def test_split_history_all_assistant_messages(self):
"""Test split_history when there are consecutive assistant messages."""
from astrbot.core.agent.context.compressor import split_history
messages = [
self.create_message("user", "msg1"),
self.create_message("assistant", "msg2"),
self.create_message("assistant", "msg3"),
self.create_message("assistant", "msg4"),
]
system, to_summarize, recent = split_history(messages, keep_recent=2)
# Should find the user message and keep from there
if len(recent) > 0:
# Find first user message backwards
assert any(m.role == "user" for m in messages)
def test_split_history_with_system_messages(self):
"""Test split_history preserves system messages separately."""
from astrbot.core.agent.context.compressor import split_history
messages = [
self.create_message("system", "System 1"),
self.create_message("system", "System 2"),
self.create_message("user", "msg1"),
self.create_message("assistant", "msg2"),
self.create_message("user", "msg3"),
]
system, to_summarize, recent = split_history(messages, keep_recent=2)
# System messages should be separate
assert len(system) == 2
assert all(m.role == "system" for m in system)
# Recent should start with user
if len(recent) > 0:
assert recent[0].role == "user"
+423
View File
@@ -0,0 +1,423 @@
"""Tests for ContextTruncator."""
from astrbot.core.agent.context.truncator import ContextTruncator
from astrbot.core.agent.message import Message
class TestContextTruncator:
"""Test suite for ContextTruncator."""
def create_message(self, role: str, content: str = "test content") -> Message:
"""Helper to create a simple test message."""
return Message(role=role, content=content)
def create_messages(
self, count: int, include_system: bool = False
) -> list[Message]:
"""Helper to create alternating user/assistant messages.
Args:
count: Number of messages to create
include_system: Whether to include a system message at the start
Returns:
List of messages
"""
messages = []
if include_system:
messages.append(self.create_message("system", "System prompt"))
for i in range(count):
role = "user" if i % 2 == 0 else "assistant"
messages.append(self.create_message(role, f"Message {i}"))
return messages
# ==================== fix_messages Tests ====================
def test_fix_messages_empty_list(self):
"""Test fix_messages with an empty list."""
truncator = ContextTruncator()
result = truncator.fix_messages([])
assert result == []
def test_fix_messages_normal_messages(self):
"""Test fix_messages with normal user/assistant messages."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Hello"),
self.create_message("assistant", "Hi"),
self.create_message("user", "How are you?"),
]
result = truncator.fix_messages(messages)
assert len(result) == 3
assert result == messages
def test_fix_messages_tool_with_valid_context(self):
"""Test fix_messages with tool message after user+assistant."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 3
assert result == messages
def test_fix_messages_tool_without_context(self):
"""Test fix_messages with tool message without enough context."""
truncator = ContextTruncator()
messages = [
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
# Tool message without context should be removed
assert len(result) == 0
def test_fix_messages_tool_with_only_one_message(self):
"""Test fix_messages with tool message after only one message."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Hello"),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
# Tool message without enough context should be removed
assert len(result) == 0
def test_fix_messages_multiple_tools(self):
"""Test fix_messages with multiple tool messages."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool 1 result"),
self.create_message("tool", "Tool 2 result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 4
assert result == messages
def test_fix_messages_mixed_system_tool(self):
"""Test fix_messages with system message and tool messages."""
truncator = ContextTruncator()
messages = [
self.create_message("system", "System prompt"),
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 4
assert result == messages
# ==================== truncate_by_turns Tests ====================
def test_truncate_by_turns_no_limit(self):
"""Test truncate_by_turns with -1 (no limit)."""
truncator = ContextTruncator()
messages = self.create_messages(20)
result = truncator.truncate_by_turns(messages, keep_most_recent_turns=-1)
assert len(result) == 20
assert result == messages
def test_truncate_by_turns_basic(self):
"""Test basic truncate_by_turns functionality."""
truncator = ContextTruncator()
# Create 10 messages = 5 turns (user/assistant pairs)
messages = self.create_messages(10)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=3, drop_turns=1
)
# Should keep 3 most recent turns (6 messages)
assert len(result) <= 8 # (3-1+1)*2 = 6, but may adjust for correct format
def test_truncate_by_turns_with_system_message(self):
"""Test truncate_by_turns preserves system messages."""
truncator = ContextTruncator()
messages = self.create_messages(10, include_system=True)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=2, drop_turns=1
)
# System message should always be preserved
assert result[0].role == "system"
assert result[0].content == "System prompt"
def test_truncate_by_turns_zero_keep(self):
"""Test truncate_by_turns with keep_most_recent_turns=0."""
truncator = ContextTruncator()
messages = self.create_messages(10)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=0, drop_turns=1
)
# Should result in empty or minimal list
assert len(result) == 0
def test_truncate_by_turns_below_threshold(self):
"""Test truncate_by_turns when messages are below threshold."""
truncator = ContextTruncator()
# Create 4 messages = 2 turns
messages = self.create_messages(4)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=5, drop_turns=1
)
# No truncation should happen
assert len(result) == 4
assert result == messages
def test_truncate_by_turns_exact_threshold(self):
"""Test truncate_by_turns when messages exactly match threshold."""
truncator = ContextTruncator()
# Create 6 messages = 3 turns
messages = self.create_messages(6)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=3, drop_turns=1
)
# No truncation should happen
assert len(result) == 6
assert result == messages
def test_truncate_by_turns_ensures_user_first(self):
"""Test that truncate_by_turns ensures user message comes first."""
truncator = ContextTruncator()
# Create scenario where truncation might start with assistant
messages = self.create_messages(20)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=3, drop_turns=1
)
# First non-system message should be user
assert result[0].role == "user"
def test_truncate_by_turns_multiple_drop(self):
"""Test truncate_by_turns with multiple turns dropped at once."""
truncator = ContextTruncator()
messages = self.create_messages(20)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=5, drop_turns=3
)
# Should drop 3 turns when limit exceeded
assert len(result) < len(messages)
# ==================== truncate_by_dropping_oldest_turns Tests ====================
def test_truncate_by_dropping_oldest_turns_zero(self):
"""Test truncate_by_dropping_oldest_turns with drop_turns=0."""
truncator = ContextTruncator()
messages = self.create_messages(10)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=0)
assert result == messages
def test_truncate_by_dropping_oldest_turns_negative(self):
"""Test truncate_by_dropping_oldest_turns with negative drop_turns."""
truncator = ContextTruncator()
messages = self.create_messages(10)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=-1)
assert result == messages
def test_truncate_by_dropping_oldest_turns_basic(self):
"""Test basic truncate_by_dropping_oldest_turns functionality."""
truncator = ContextTruncator()
# Create 10 messages = 5 turns
messages = self.create_messages(10)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
# Should drop 2 oldest turns (4 messages)
assert len(result) == 6
# Should start with user message
assert result[0].role == "user"
def test_truncate_by_dropping_oldest_turns_with_system(self):
"""Test truncate_by_dropping_oldest_turns preserves system messages."""
truncator = ContextTruncator()
messages = self.create_messages(10, include_system=True)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
# System message should be preserved
assert result[0].role == "system"
assert result[0].content == "System prompt"
def test_truncate_by_dropping_oldest_turns_drop_all(self):
"""Test truncate_by_dropping_oldest_turns dropping all turns."""
truncator = ContextTruncator()
# Create 4 messages = 2 turns
messages = self.create_messages(4)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
# Should drop all turns
assert len(result) == 0
def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self):
"""Test truncate_by_dropping_oldest_turns with drop_turns > available turns."""
truncator = ContextTruncator()
# Create 4 messages = 2 turns
messages = self.create_messages(4)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=5)
# Should result in empty list
assert len(result) == 0
def test_truncate_by_dropping_oldest_turns_ensures_user_first(self):
"""Test that result starts with user message after dropping."""
truncator = ContextTruncator()
messages = self.create_messages(20)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=3)
# First message should be user
if len(result) > 0:
assert result[0].role == "user"
# ==================== truncate_by_halving Tests ====================
def test_truncate_by_halving_empty(self):
"""Test truncate_by_halving with empty list."""
truncator = ContextTruncator()
result = truncator.truncate_by_halving([])
assert result == []
def test_truncate_by_halving_single_message(self):
"""Test truncate_by_halving with single message."""
truncator = ContextTruncator()
messages = [self.create_message("user", "Hello")]
result = truncator.truncate_by_halving(messages)
# Should not truncate if <= 2 messages
assert result == messages
def test_truncate_by_halving_two_messages(self):
"""Test truncate_by_halving with two messages."""
truncator = ContextTruncator()
messages = self.create_messages(2)
result = truncator.truncate_by_halving(messages)
# Should not truncate if <= 2 messages
assert result == messages
def test_truncate_by_halving_basic(self):
"""Test basic truncate_by_halving functionality."""
truncator = ContextTruncator()
# Create 20 messages
messages = self.create_messages(20)
result = truncator.truncate_by_halving(messages)
# Should delete 50% = 10 messages, keep 10
assert len(result) == 10
# First message should be user
assert result[0].role == "user"
def test_truncate_by_halving_with_system_message(self):
"""Test truncate_by_halving preserves system messages."""
truncator = ContextTruncator()
messages = self.create_messages(20, include_system=True)
result = truncator.truncate_by_halving(messages)
# System message should be preserved
assert result[0].role == "system"
assert result[0].content == "System prompt"
def test_truncate_by_halving_odd_count(self):
"""Test truncate_by_halving with odd number of messages."""
truncator = ContextTruncator()
messages = self.create_messages(11)
result = truncator.truncate_by_halving(messages)
# Should delete floor(11/2) = 5 messages, keep 6
# But after ensuring user first, may be 5
assert len(result) >= 5
assert result[0].role == "user"
def test_truncate_by_halving_ensures_user_first(self):
"""Test that result starts with user message."""
truncator = ContextTruncator()
# Create messages starting with user
messages = self.create_messages(30)
result = truncator.truncate_by_halving(messages)
# First message should be user
assert result[0].role == "user"
def test_truncate_by_halving_preserves_recent_messages(self):
"""Test that truncate_by_halving keeps the most recent 50%."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Message 0"),
self.create_message("assistant", "Message 1"),
self.create_message("user", "Message 2"),
self.create_message("assistant", "Message 3"),
]
result = truncator.truncate_by_halving(messages)
# Should keep last 2 messages
assert len(result) == 2
assert result[0].content == "Message 2"
assert result[1].content == "Message 3"
# ==================== Integration Tests ====================
def test_truncate_with_tool_messages(self):
"""Test truncation with tool messages."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool result"),
self.create_message("user", "Thanks"),
self.create_message("assistant", "Welcome"),
]
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=1)
# First turn (user+assistant+tool) should be dropped
# Tool message should be cleaned up by fix_messages
assert len(result) <= 2
def test_chain_multiple_truncations(self):
"""Test chaining multiple truncation methods."""
truncator = ContextTruncator()
messages = self.create_messages(40, include_system=True)
# First: truncate by turns
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=10, drop_turns=2
)
# Then: halve
result = truncator.truncate_by_halving(result)
# Should have system message + truncated content
assert result[0].role == "system"
assert len(result) < len(messages)
def test_empty_after_system_message(self):
"""Test truncation when only system message exists."""
truncator = ContextTruncator()
messages = [self.create_message("system", "System prompt")]
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=5, drop_turns=1
)
# Should keep system message
assert len(result) == 1
assert result[0].role == "system"
def test_all_system_messages(self):
"""Test truncation with only system messages."""
truncator = ContextTruncator()
messages = [
self.create_message("system", "System 1"),
self.create_message("system", "System 2"),
]
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=0, drop_turns=1
)
# System messages should be preserved, but since there are no non-system
# messages and keep_most_recent_turns=0, result should be system messages only
assert len(result) >= 0 # May keep system messages or clear all
if len(result) > 0:
assert all(msg.role == "system" for msg in result)