Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8e7b44185d | |||
| ef1c66a92e | |||
| 241f1c26d3 | |||
| 3615b7dde2 | |||
| 9bcf9bf2a0 |
@@ -1 +1 @@
|
||||
__version__ = "4.10.6"
|
||||
__version__ = "4.11.0"
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||
|
||||
from ..message import Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot import logger
|
||||
else:
|
||||
try:
|
||||
from astrbot import logger
|
||||
except ImportError:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.truncator import ContextTruncator
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ContextCompressor(Protocol):
|
||||
"""
|
||||
Protocol for context compressors.
|
||||
Provides an interface for compressing message lists.
|
||||
"""
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens for the model.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
"""Compress the message list.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
|
||||
Returns:
|
||||
The compressed message list.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class TruncateByTurnsCompressor:
|
||||
"""Truncate by turns compressor implementation.
|
||||
Truncates the message list by removing older turns.
|
||||
"""
|
||||
|
||||
def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
|
||||
"""Initialize the truncate by turns compressor.
|
||||
|
||||
Args:
|
||||
truncate_turns: The number of turns to remove when truncating (default: 1).
|
||||
compression_threshold: The compression trigger threshold (default: 0.82).
|
||||
"""
|
||||
self.truncate_turns = truncate_turns
|
||||
self.compression_threshold = compression_threshold
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
if max_tokens <= 0 or current_tokens <= 0:
|
||||
return False
|
||||
usage_rate = current_tokens / max_tokens
|
||||
return usage_rate > self.compression_threshold
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
truncator = ContextTruncator()
|
||||
truncated_messages = truncator.truncate_by_dropping_oldest_turns(
|
||||
messages,
|
||||
drop_turns=self.truncate_turns,
|
||||
)
|
||||
return truncated_messages
|
||||
|
||||
|
||||
def split_history(
|
||||
messages: list[Message], keep_recent: int
|
||||
) -> tuple[list[Message], list[Message], list[Message]]:
|
||||
"""Split the message list into system messages, messages to summarize, and recent messages.
|
||||
|
||||
Ensures that the split point is between complete user-assistant pairs to maintain conversation flow.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
keep_recent: The number of latest messages to keep.
|
||||
|
||||
Returns:
|
||||
tuple: (system_messages, messages_to_summarize, recent_messages)
|
||||
"""
|
||||
# keep the system messages
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) <= keep_recent:
|
||||
return system_messages, [], non_system_messages
|
||||
|
||||
# Find the split point, ensuring recent_messages starts with a user message
|
||||
# This maintains complete conversation turns
|
||||
split_index = len(non_system_messages) - keep_recent
|
||||
|
||||
# Search backward from split_index to find the first user message
|
||||
# This ensures recent_messages starts with a user message (complete turn)
|
||||
while split_index > 0 and non_system_messages[split_index].role != "user":
|
||||
# TODO: +=1 or -=1 ? calculate by tokens
|
||||
split_index -= 1
|
||||
|
||||
# If we couldn't find a user message, keep all messages as recent
|
||||
if split_index == 0:
|
||||
return system_messages, [], non_system_messages
|
||||
|
||||
messages_to_summarize = non_system_messages[:split_index]
|
||||
recent_messages = non_system_messages[split_index:]
|
||||
|
||||
return system_messages, messages_to_summarize, recent_messages
|
||||
|
||||
|
||||
class LLMSummaryCompressor:
|
||||
"""LLM-based summary compressor.
|
||||
Uses LLM to summarize the old conversation history, keeping the latest messages.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: "Provider",
|
||||
keep_recent: int = 4,
|
||||
instruction_text: str | None = None,
|
||||
compression_threshold: float = 0.82,
|
||||
):
|
||||
"""Initialize the LLM summary compressor.
|
||||
|
||||
Args:
|
||||
provider: The LLM provider instance.
|
||||
keep_recent: The number of latest messages to keep (default: 4).
|
||||
instruction_text: Custom instruction for summary generation.
|
||||
compression_threshold: The compression trigger threshold (default: 0.82).
|
||||
"""
|
||||
self.provider = provider
|
||||
self.keep_recent = keep_recent
|
||||
self.compression_threshold = compression_threshold
|
||||
|
||||
self.instruction_text = instruction_text or (
|
||||
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
|
||||
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
|
||||
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
|
||||
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
|
||||
"4. Write the summary in the user's language.\n"
|
||||
)
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
if max_tokens <= 0 or current_tokens <= 0:
|
||||
return False
|
||||
usage_rate = current_tokens / max_tokens
|
||||
return usage_rate > self.compression_threshold
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
"""Use LLM to generate a summary of the conversation history.
|
||||
|
||||
Process:
|
||||
1. Divide messages: keep the system message and the latest N messages.
|
||||
2. Send the old messages + the instruction message to the LLM.
|
||||
3. Reconstruct the message list: [system message, summary message, latest messages].
|
||||
"""
|
||||
if len(messages) <= self.keep_recent + 1:
|
||||
return messages
|
||||
|
||||
system_messages, messages_to_summarize, recent_messages = split_history(
|
||||
messages, self.keep_recent
|
||||
)
|
||||
|
||||
if not messages_to_summarize:
|
||||
return messages
|
||||
|
||||
# build payload
|
||||
instruction_message = Message(role="user", content=self.instruction_text)
|
||||
llm_payload = messages_to_summarize + [instruction_message]
|
||||
|
||||
# generate summary
|
||||
try:
|
||||
response = await self.provider.text_chat(contexts=llm_payload)
|
||||
summary_content = response.completion_text
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate summary: {e}")
|
||||
return messages
|
||||
|
||||
# build result
|
||||
result = []
|
||||
result.extend(system_messages)
|
||||
|
||||
result.append(
|
||||
Message(
|
||||
role="user",
|
||||
content=f"Our previous history conversation summary: {summary_content}",
|
||||
)
|
||||
)
|
||||
result.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content="Acknowledged the summary of our previous conversation history.",
|
||||
)
|
||||
)
|
||||
|
||||
result.extend(recent_messages)
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,35 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .compressor import ContextCompressor
|
||||
from .token_counter import TokenCounter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextConfig:
|
||||
"""Context configuration class."""
|
||||
|
||||
max_context_tokens: int = 0
|
||||
"""Maximum number of context tokens. <= 0 means no limit."""
|
||||
enforce_max_turns: int = -1 # -1 means no limit
|
||||
"""Maximum number of conversation turns to keep. -1 means no limit. Executed before compression."""
|
||||
truncate_turns: int = 1
|
||||
"""Number of conversation turns to discard at once when truncation is triggered.
|
||||
Two processes will use this value:
|
||||
|
||||
1. Enforce max turns truncation.
|
||||
2. Truncation by turns compression strategy.
|
||||
"""
|
||||
llm_compress_instruction: str | None = None
|
||||
"""Instruction prompt for LLM-based compression."""
|
||||
llm_compress_keep_recent: int = 0
|
||||
"""Number of recent messages to keep during LLM-based compression."""
|
||||
llm_compress_provider: "Provider | None" = None
|
||||
"""LLM provider used for compression tasks. If None, truncation strategy is used."""
|
||||
custom_token_counter: TokenCounter | None = None
|
||||
"""Custom token counting method. If None, the default method is used."""
|
||||
custom_compressor: ContextCompressor | None = None
|
||||
"""Custom context compression method. If None, the default method is used."""
|
||||
@@ -0,0 +1,120 @@
|
||||
from astrbot import logger
|
||||
|
||||
from ..message import Message
|
||||
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
|
||||
from .config import ContextConfig
|
||||
from .token_counter import EstimateTokenCounter
|
||||
from .truncator import ContextTruncator
|
||||
|
||||
|
||||
class ContextManager:
|
||||
"""Context compression manager."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ContextConfig,
|
||||
):
|
||||
"""Initialize the context manager.
|
||||
|
||||
There are two strategies to handle context limit reached:
|
||||
1. Truncate by turns: remove older messages by turns.
|
||||
2. LLM-based compression: use LLM to summarize old messages.
|
||||
|
||||
Args:
|
||||
config: The context configuration.
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
self.token_counter = config.custom_token_counter or EstimateTokenCounter()
|
||||
self.truncator = ContextTruncator()
|
||||
|
||||
if config.custom_compressor:
|
||||
self.compressor = config.custom_compressor
|
||||
elif config.llm_compress_provider:
|
||||
self.compressor = LLMSummaryCompressor(
|
||||
provider=config.llm_compress_provider,
|
||||
keep_recent=config.llm_compress_keep_recent,
|
||||
instruction_text=config.llm_compress_instruction,
|
||||
)
|
||||
else:
|
||||
self.compressor = TruncateByTurnsCompressor(
|
||||
truncate_turns=config.truncate_turns
|
||||
)
|
||||
|
||||
async def process(
|
||||
self, messages: list[Message], trusted_token_usage: int = 0
|
||||
) -> list[Message]:
|
||||
"""Process the messages.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
|
||||
Returns:
|
||||
The processed message list.
|
||||
"""
|
||||
try:
|
||||
result = messages
|
||||
|
||||
# 1. 基于轮次的截断 (Enforce max turns)
|
||||
if self.config.enforce_max_turns != -1:
|
||||
result = self.truncator.truncate_by_turns(
|
||||
result,
|
||||
keep_most_recent_turns=self.config.enforce_max_turns,
|
||||
drop_turns=self.config.truncate_turns,
|
||||
)
|
||||
|
||||
# 2. 基于 token 的压缩
|
||||
if self.config.max_context_tokens > 0:
|
||||
total_tokens = self.token_counter.count_tokens(
|
||||
result, trusted_token_usage
|
||||
)
|
||||
|
||||
if self.compressor.should_compress(
|
||||
result, total_tokens, self.config.max_context_tokens
|
||||
):
|
||||
result = await self._run_compression(result, total_tokens)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error during context processing: {e}", exc_info=True)
|
||||
return messages
|
||||
|
||||
async def _run_compression(
|
||||
self, messages: list[Message], prev_tokens: int
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Compress/truncate the messages.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
prev_tokens: The token count before compression.
|
||||
|
||||
Returns:
|
||||
The compressed/truncated message list.
|
||||
"""
|
||||
logger.debug("Compress triggered, starting compression...")
|
||||
|
||||
messages = await self.compressor(messages)
|
||||
|
||||
# double check
|
||||
tokens_after_summary = self.token_counter.count_tokens(messages)
|
||||
|
||||
# calculate compress rate
|
||||
compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100
|
||||
logger.info(
|
||||
f"Compress completed."
|
||||
f" {prev_tokens} -> {tokens_after_summary} tokens,"
|
||||
f" compression rate: {compress_rate:.2f}%.",
|
||||
)
|
||||
|
||||
# last check
|
||||
if self.compressor.should_compress(
|
||||
messages, tokens_after_summary, self.config.max_context_tokens
|
||||
):
|
||||
logger.info(
|
||||
"Context still exceeds max tokens after compression, applying halving truncation..."
|
||||
)
|
||||
# still need compress, truncate by half
|
||||
messages = self.truncator.truncate_by_halving(messages)
|
||||
|
||||
return messages
|
||||
@@ -0,0 +1,64 @@
|
||||
import json
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from ..message import Message, TextPart
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TokenCounter(Protocol):
|
||||
"""
|
||||
Protocol for token counters.
|
||||
Provides an interface for counting tokens in message lists.
|
||||
"""
|
||||
|
||||
def count_tokens(
|
||||
self, messages: list[Message], trusted_token_usage: int = 0
|
||||
) -> int:
|
||||
"""Count the total tokens in the message list.
|
||||
|
||||
Args:
|
||||
messages: The message list.
|
||||
trusted_token_usage: The total token usage that LLM API returned.
|
||||
For some cases, this value is more accurate.
|
||||
But some API does not return it, so the value defaults to 0.
|
||||
|
||||
Returns:
|
||||
The total token count.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class EstimateTokenCounter:
|
||||
"""Estimate token counter implementation.
|
||||
Provides a simple estimation of token count based on character types.
|
||||
"""
|
||||
|
||||
def count_tokens(
|
||||
self, messages: list[Message], trusted_token_usage: int = 0
|
||||
) -> int:
|
||||
if trusted_token_usage > 0:
|
||||
return trusted_token_usage
|
||||
|
||||
total = 0
|
||||
for msg in messages:
|
||||
content = msg.content
|
||||
if isinstance(content, str):
|
||||
total += self._estimate_tokens(content)
|
||||
elif isinstance(content, list):
|
||||
# 处理多模态内容
|
||||
for part in content:
|
||||
if isinstance(part, TextPart):
|
||||
total += self._estimate_tokens(part.text)
|
||||
|
||||
# 处理 Tool Calls
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump())
|
||||
total += self._estimate_tokens(tc_str)
|
||||
|
||||
return total
|
||||
|
||||
def _estimate_tokens(self, text: str) -> int:
|
||||
chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"])
|
||||
other_count = len(text) - chinese_count
|
||||
return int(chinese_count * 0.6 + other_count * 0.3)
|
||||
@@ -0,0 +1,141 @@
|
||||
from ..message import Message
|
||||
|
||||
|
||||
class ContextTruncator:
|
||||
"""Context truncator."""
|
||||
|
||||
def fix_messages(self, messages: list[Message]) -> list[Message]:
|
||||
fixed_messages = []
|
||||
for message in messages:
|
||||
if message.role == "tool":
|
||||
# tool block 前面必须要有 user 和 assistant block
|
||||
if len(fixed_messages) < 2:
|
||||
# 这种情况可能是上下文被截断导致的
|
||||
# 我们直接将之前的上下文都清空
|
||||
fixed_messages = []
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
return fixed_messages
|
||||
|
||||
def truncate_by_turns(
|
||||
self,
|
||||
messages: list[Message],
|
||||
keep_most_recent_turns: int,
|
||||
drop_turns: int = 1,
|
||||
) -> list[Message]:
|
||||
"""截断上下文列表,确保不超过最大长度。
|
||||
一个 turn 包含一个 user 消息和一个 assistant 消息。
|
||||
这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。
|
||||
|
||||
Args:
|
||||
messages: 上下文列表
|
||||
keep_most_recent_turns: 保留最近的对话轮数
|
||||
drop_turns: 一次性丢弃的对话轮数
|
||||
|
||||
Returns:
|
||||
截断后的上下文列表
|
||||
"""
|
||||
if keep_most_recent_turns == -1:
|
||||
return messages
|
||||
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) // 2 <= keep_most_recent_turns:
|
||||
return messages
|
||||
|
||||
num_to_keep = keep_most_recent_turns - drop_turns + 1
|
||||
if num_to_keep <= 0:
|
||||
truncated_contexts = []
|
||||
else:
|
||||
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
|
||||
|
||||
# 找到第一个 role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
truncated_contexts = truncated_contexts[index:]
|
||||
|
||||
result = system_messages + truncated_contexts
|
||||
|
||||
return self.fix_messages(result)
|
||||
|
||||
def truncate_by_dropping_oldest_turns(
|
||||
self,
|
||||
messages: list[Message],
|
||||
drop_turns: int = 1,
|
||||
) -> list[Message]:
|
||||
"""丢弃最旧的 N 个对话轮次。"""
|
||||
if drop_turns <= 0:
|
||||
return messages
|
||||
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) // 2 <= drop_turns:
|
||||
truncated_non_system = []
|
||||
else:
|
||||
truncated_non_system = non_system_messages[drop_turns * 2 :]
|
||||
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None:
|
||||
truncated_non_system = truncated_non_system[index:]
|
||||
elif truncated_non_system:
|
||||
truncated_non_system = []
|
||||
|
||||
result = system_messages + truncated_non_system
|
||||
|
||||
return self.fix_messages(result)
|
||||
|
||||
def truncate_by_halving(
|
||||
self,
|
||||
messages: list[Message],
|
||||
) -> list[Message]:
|
||||
"""对半砍策略,删除 50% 的消息"""
|
||||
if len(messages) <= 2:
|
||||
return messages
|
||||
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
messages_to_delete = len(non_system_messages) // 2
|
||||
if messages_to_delete == 0:
|
||||
return messages
|
||||
|
||||
truncated_non_system = non_system_messages[messages_to_delete:]
|
||||
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None:
|
||||
truncated_non_system = truncated_non_system[index:]
|
||||
|
||||
result = system_messages + truncated_non_system
|
||||
|
||||
return self.fix_messages(result)
|
||||
@@ -25,6 +25,10 @@ from astrbot.core.provider.entities import (
|
||||
)
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.compressor import ContextCompressor
|
||||
from ..context.config import ContextConfig
|
||||
from ..context.manager import ContextManager
|
||||
from ..context.token_counter import TokenCounter
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||
from ..response import AgentResponseData, AgentStats
|
||||
@@ -47,10 +51,47 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
streaming: bool = False,
|
||||
# enforce max turns, will discard older turns when exceeded BEFORE compression
|
||||
# -1 means no limit
|
||||
enforce_max_turns: int = -1,
|
||||
# llm compressor
|
||||
llm_compress_instruction: str | None = None,
|
||||
llm_compress_keep_recent: int = 0,
|
||||
llm_compress_provider: Provider | None = None,
|
||||
# truncate by turns compressor
|
||||
truncate_turns: int = 1,
|
||||
# customize
|
||||
custom_token_counter: TokenCounter | None = None,
|
||||
custom_compressor: ContextCompressor | None = None,
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
self.streaming = streaming
|
||||
self.enforce_max_turns = enforce_max_turns
|
||||
self.llm_compress_instruction = llm_compress_instruction
|
||||
self.llm_compress_keep_recent = llm_compress_keep_recent
|
||||
self.llm_compress_provider = llm_compress_provider
|
||||
self.truncate_turns = truncate_turns
|
||||
self.custom_token_counter = custom_token_counter
|
||||
self.custom_compressor = custom_compressor
|
||||
# we will do compress when:
|
||||
# 1. before requesting LLM
|
||||
# TODO: 2. after LLM output a tool call
|
||||
self.context_config = ContextConfig(
|
||||
# <=0 will never do compress
|
||||
max_context_tokens=provider.provider_config.get("max_context_tokens", 0),
|
||||
# enforce max turns before compression
|
||||
enforce_max_turns=self.enforce_max_turns,
|
||||
truncate_turns=self.truncate_turns,
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider=self.llm_compress_provider,
|
||||
custom_token_counter=self.custom_token_counter,
|
||||
custom_compressor=self.custom_compressor,
|
||||
)
|
||||
self.context_manager = ContextManager(self.context_config)
|
||||
|
||||
self.provider = provider
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
@@ -110,6 +151,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
llm_resp_result = None
|
||||
|
||||
# do truncate and compress
|
||||
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
|
||||
self.run_context.messages = await self.context_manager.process(
|
||||
self.run_context.messages, trusted_token_usage=token_usage
|
||||
)
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
if llm_response.is_chunk:
|
||||
# update ttft
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.10.6"
|
||||
VERSION = "4.11.0"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
@@ -83,6 +83,16 @@ DEFAULT_CONFIG = {
|
||||
"default_personality": "default",
|
||||
"persona_pool": ["*"],
|
||||
"prompt_prefix": "{{prompt}}",
|
||||
"context_limit_reached_strategy": "truncate_by_turns", # or llm_compress
|
||||
"llm_compress_instruction": (
|
||||
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
|
||||
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
|
||||
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
|
||||
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
|
||||
"4. Write the summary in the user's language.\n"
|
||||
),
|
||||
"llm_compress_keep_recent": 4,
|
||||
"llm_compress_provider_id": "",
|
||||
"max_context_length": -1,
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
@@ -179,6 +189,7 @@ class ChatProviderTemplate(TypedDict):
|
||||
model: str
|
||||
modalities: list
|
||||
custom_extra_body: dict[str, Any]
|
||||
max_context_tokens: int
|
||||
|
||||
|
||||
CHAT_PROVIDER_TEMPLATE = {
|
||||
@@ -187,6 +198,7 @@ CHAT_PROVIDER_TEMPLATE = {
|
||||
"model": "",
|
||||
"modalities": [],
|
||||
"custom_extra_body": {},
|
||||
"max_context_tokens": 0,
|
||||
}
|
||||
|
||||
"""
|
||||
@@ -2033,6 +2045,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
|
||||
},
|
||||
"max_context_tokens": {
|
||||
"description": "模型上下文窗口大小",
|
||||
"type": "int",
|
||||
"hint": "模型最大上下文 Token 大小。如果为 0,则会自动从模型元数据填充(如有),也可手动修改。",
|
||||
},
|
||||
"dify_api_key": {
|
||||
"description": "API Key",
|
||||
"type": "string",
|
||||
@@ -2540,6 +2557,66 @@ CONFIG_METADATA_3 = {
|
||||
# "provider_settings.enable": True,
|
||||
# },
|
||||
# },
|
||||
"truncate_and_compress": {
|
||||
"description": "上下文管理策略",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"provider_settings.max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.context_limit_reached_strategy": {
|
||||
"description": "超出模型上下文窗口时的处理方式",
|
||||
"type": "string",
|
||||
"options": ["truncate_by_turns", "llm_compress"],
|
||||
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
"hint": "",
|
||||
},
|
||||
"provider_settings.llm_compress_instruction": {
|
||||
"description": "上下文压缩提示词",
|
||||
"type": "text",
|
||||
"hint": "如果为空则使用默认提示词。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.llm_compress_keep_recent": {
|
||||
"description": "压缩时保留最近对话轮数",
|
||||
"type": "int",
|
||||
"hint": "始终保留的最近 N 轮对话。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.llm_compress_provider_id": {
|
||||
"description": "用于上下文压缩的模型提供商 ID",
|
||||
"type": "string",
|
||||
"_special": "select_provider",
|
||||
"hint": "留空时将降级为“按对话轮数截断”的策略。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
"type": "object",
|
||||
@@ -2604,22 +2681,6 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.streaming_response": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀 ",
|
||||
"type": "string",
|
||||
|
||||
@@ -69,6 +69,7 @@ class ConversationManager:
|
||||
persona_id=conv_v2.persona_id,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
token_usage=conv_v2.token_usage,
|
||||
)
|
||||
|
||||
async def new_conversation(
|
||||
@@ -256,6 +257,7 @@ class ConversationManager:
|
||||
history: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
token_usage: int | None = None,
|
||||
) -> None:
|
||||
"""更新会话的对话.
|
||||
|
||||
@@ -263,6 +265,7 @@ class ConversationManager:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
|
||||
token_usage (int | None): token 使用量。None 表示不更新
|
||||
|
||||
"""
|
||||
if not conversation_id:
|
||||
@@ -274,6 +277,7 @@ class ConversationManager:
|
||||
title=title,
|
||||
persona_id=persona_id,
|
||||
content=history,
|
||||
token_usage=token_usage,
|
||||
)
|
||||
|
||||
async def update_conversation_title(
|
||||
|
||||
@@ -152,6 +152,7 @@ class BaseDatabase(abc.ABC):
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
content: list[dict] | None = None,
|
||||
token_usage: int | None = None,
|
||||
) -> None:
|
||||
"""Update a conversation's history."""
|
||||
...
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
"""Migration script to add token_usage column to conversations table.
|
||||
|
||||
This migration adds the token_usage field to track token consumption for each conversation.
|
||||
|
||||
Changes:
|
||||
- Adds token_usage column to conversations table (default: 0)
|
||||
"""
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
|
||||
async def migrate_token_usage(db_helper: BaseDatabase):
|
||||
"""Add token_usage column to conversations table.
|
||||
|
||||
This migration adds a new column to track token consumption in conversations.
|
||||
"""
|
||||
# 检查是否已经完成迁移
|
||||
migration_done = await db_helper.get_preference(
|
||||
"global", "global", "migration_done_token_usage_1"
|
||||
)
|
||||
if migration_done:
|
||||
return
|
||||
|
||||
logger.info("开始执行数据库迁移(添加 conversations.token_usage 列)...")
|
||||
|
||||
# 这里只适配了 SQLite。因为截止至这一版本,AstrBot 仅支持 SQLite。
|
||||
|
||||
try:
|
||||
async with db_helper.get_db() as session:
|
||||
# 检查列是否已存在
|
||||
result = await session.execute(text("PRAGMA table_info(conversations)"))
|
||||
columns = result.fetchall()
|
||||
column_names = [col[1] for col in columns]
|
||||
|
||||
if "token_usage" in column_names:
|
||||
logger.info("token_usage 列已存在,跳过迁移")
|
||||
await sp.put_async(
|
||||
"global", "global", "migration_done_token_usage_1", True
|
||||
)
|
||||
return
|
||||
|
||||
# 添加 token_usage 列
|
||||
await session.execute(
|
||||
text(
|
||||
"ALTER TABLE conversations ADD COLUMN token_usage INTEGER NOT NULL DEFAULT 0"
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
logger.info("token_usage 列添加成功")
|
||||
|
||||
# 标记迁移完成
|
||||
await sp.put_async("global", "global", "migration_done_token_usage_1", True)
|
||||
logger.info("token_usage 迁移完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
|
||||
raise
|
||||
@@ -54,6 +54,11 @@ class ConversationV2(SQLModel, table=True):
|
||||
)
|
||||
title: str | None = Field(default=None, max_length=255)
|
||||
persona_id: str | None = Field(default=None)
|
||||
token_usage: int = Field(default=0, nullable=False)
|
||||
"""content is a list of OpenAI-formated messages in list[dict] format.
|
||||
token_usage is the total token value of the messages.
|
||||
when 0, will use estimated token counter.
|
||||
"""
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
@@ -313,6 +318,8 @@ class Conversation:
|
||||
persona_id: str | None = ""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
token_usage: int = 0
|
||||
"""对话的总 token 数量。AstrBot 会保留最近一次 LLM 请求返回的总 token 数,方便统计。token_usage 可能为 0,表示未知。"""
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
|
||||
@@ -241,7 +241,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
session.add(new_conversation)
|
||||
return new_conversation
|
||||
|
||||
async def update_conversation(self, cid, title=None, persona_id=None, content=None):
|
||||
async def update_conversation(
|
||||
self, cid, title=None, persona_id=None, content=None, token_usage=None
|
||||
):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
@@ -255,6 +257,8 @@ class SQLiteDatabase(BaseDatabase):
|
||||
values["persona_id"] = persona_id
|
||||
if content is not None:
|
||||
values["content"] = content
|
||||
if token_usage is not None:
|
||||
values["token_usage"] = token_usage
|
||||
if not values:
|
||||
return None
|
||||
query = query.values(**values)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""本地 Agent 模式的 LLM 调用 Stage"""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.response import AgentStats
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
@@ -24,6 +24,7 @@ from astrbot.core.provider.entities import (
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType, star_map
|
||||
from astrbot.core.utils.file_extract import extract_file_moonshotai
|
||||
from astrbot.core.utils.llm_metadata import LLM_METADATAS
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
@@ -41,11 +42,6 @@ class InternalAgentSubStage(Stage):
|
||||
self.ctx = ctx
|
||||
conf = ctx.astrbot_config
|
||||
settings = conf["provider_settings"]
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
)
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.unsupported_streaming_strategy: str = settings[
|
||||
"unsupported_streaming_strategy"
|
||||
@@ -65,6 +61,25 @@ class InternalAgentSubStage(Stage):
|
||||
"moonshotai_api_key", ""
|
||||
)
|
||||
|
||||
# 上下文管理相关
|
||||
self.context_limit_reached_strategy: str = settings.get(
|
||||
"context_limit_reached_strategy", "truncate_by_turns"
|
||||
)
|
||||
self.llm_compress_instruction: str = settings.get(
|
||||
"llm_compress_instruction", ""
|
||||
)
|
||||
self.llm_compress_keep_recent: int = settings.get("llm_compress_keep_recent", 4)
|
||||
self.llm_compress_provider_id: str = settings.get(
|
||||
"llm_compress_provider_id", ""
|
||||
)
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
)
|
||||
if self.dequeue_context_length <= 0:
|
||||
self.dequeue_context_length = 1
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
def _select_provider(self, event: AstrMessageEvent):
|
||||
@@ -167,34 +182,6 @@ class InternalAgentSubStage(Stage):
|
||||
},
|
||||
)
|
||||
|
||||
def _truncate_contexts(
|
||||
self,
|
||||
contexts: list[dict],
|
||||
) -> list[dict]:
|
||||
"""截断上下文列表,确保不超过最大长度"""
|
||||
if self.max_context_length == -1:
|
||||
return contexts
|
||||
|
||||
if len(contexts) // 2 <= self.max_context_length:
|
||||
return contexts
|
||||
|
||||
truncated_contexts = contexts[
|
||||
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
||||
]
|
||||
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(
|
||||
i
|
||||
for i, item in enumerate(truncated_contexts)
|
||||
if item.get("role") == "user"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
truncated_contexts = truncated_contexts[index:]
|
||||
|
||||
return truncated_contexts
|
||||
|
||||
def _modalities_fix(
|
||||
self,
|
||||
provider: Provider,
|
||||
@@ -296,6 +283,7 @@ class InternalAgentSubStage(Stage):
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse | None,
|
||||
all_messages: list[Message],
|
||||
runner_stats: AgentStats | None,
|
||||
):
|
||||
if (
|
||||
not req
|
||||
@@ -322,27 +310,37 @@ class InternalAgentSubStage(Stage):
|
||||
continue
|
||||
message_to_save.append(message.model_dump())
|
||||
|
||||
# get token usage from agent runner stats
|
||||
token_usage = None
|
||||
if runner_stats:
|
||||
token_usage = runner_stats.token_usage.total
|
||||
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
req.conversation.cid,
|
||||
history=message_to_save,
|
||||
token_usage=token_usage,
|
||||
)
|
||||
|
||||
def _fix_messages(self, messages: list[dict]) -> list[dict]:
|
||||
"""验证并且修复上下文"""
|
||||
fixed_messages = []
|
||||
for message in messages:
|
||||
if message.get("role") == "tool":
|
||||
# tool block 前面必须要有 user 和 assistant block
|
||||
if len(fixed_messages) < 2:
|
||||
# 这种情况可能是上下文被截断导致的
|
||||
# 我们直接将之前的上下文都清空
|
||||
fixed_messages = []
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
return fixed_messages
|
||||
def _get_compress_provider(self) -> Provider | None:
|
||||
if not self.llm_compress_provider_id:
|
||||
return None
|
||||
if self.context_limit_reached_strategy != "llm_compress":
|
||||
return None
|
||||
provider = self.ctx.plugin_manager.context.get_provider_by_id(
|
||||
self.llm_compress_provider_id,
|
||||
)
|
||||
if provider is None:
|
||||
logger.warning(
|
||||
f"未找到指定的上下文压缩模型 {self.llm_compress_provider_id},将跳过压缩。",
|
||||
)
|
||||
return None
|
||||
if not isinstance(provider, Provider):
|
||||
logger.warning(
|
||||
f"指定的上下文压缩模型 {self.llm_compress_provider_id} 不是对话模型,将跳过压缩。"
|
||||
)
|
||||
return None
|
||||
return provider
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, provider_wake_prefix: str
|
||||
@@ -426,9 +424,10 @@ class InternalAgentSubStage(Stage):
|
||||
await self._apply_kb(event, req)
|
||||
|
||||
# truncate contexts to fit max length
|
||||
if req.contexts:
|
||||
req.contexts = self._truncate_contexts(req.contexts)
|
||||
self._fix_messages(req.contexts)
|
||||
# NOW moved to ContextManager inside ToolLoopAgentRunner
|
||||
# if req.contexts:
|
||||
# req.contexts = self._truncate_contexts(req.contexts)
|
||||
# self._fix_messages(req.contexts)
|
||||
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
@@ -444,8 +443,6 @@ class InternalAgentSubStage(Stage):
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
# 备份 req.contexts
|
||||
backup_contexts = copy.deepcopy(req.contexts)
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
@@ -456,6 +453,15 @@ class InternalAgentSubStage(Stage):
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
|
||||
# inject model context length limit
|
||||
if provider.provider_config.get("max_context_tokens", 0) <= 0:
|
||||
model = provider.get_model()
|
||||
if model_info := LLM_METADATAS.get(model):
|
||||
provider.provider_config["max_context_tokens"] = model_info[
|
||||
"limit"
|
||||
]["context"]
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
@@ -466,6 +472,11 @@ class InternalAgentSubStage(Stage):
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=streaming_response,
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider=self._get_compress_provider(),
|
||||
truncate_turns=self.dequeue_context_length,
|
||||
enforce_max_turns=self.max_context_length,
|
||||
)
|
||||
|
||||
if streaming_response and not stream_to_general:
|
||||
@@ -511,14 +522,12 @@ class InternalAgentSubStage(Stage):
|
||||
):
|
||||
yield
|
||||
|
||||
# 恢复备份的 contexts
|
||||
req.contexts = backup_contexts
|
||||
|
||||
await self._save_to_history(
|
||||
event,
|
||||
req,
|
||||
agent_runner.get_final_llm_resp(),
|
||||
agent_runner.run_context.messages,
|
||||
agent_runner.stats,
|
||||
)
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
|
||||
@@ -344,6 +344,11 @@ class LLMResponse:
|
||||
self.raw_completion = raw_completion
|
||||
self.is_chunk = is_chunk
|
||||
|
||||
if id is not None:
|
||||
self.id = id
|
||||
if usage is not None:
|
||||
self.usage = usage
|
||||
|
||||
@property
|
||||
def completion_text(self):
|
||||
if self.result_chain:
|
||||
|
||||
@@ -149,9 +149,12 @@ class Context:
|
||||
contexts: context messages for the LLM
|
||||
max_steps: Maximum number of tool calls before stopping the loop
|
||||
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
|
||||
stream: bool - whether to stream the LLM response
|
||||
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
|
||||
agent_context: AstrAgentContext - context to use for the agent
|
||||
|
||||
other kwargs will be DIRECTLY passed to the runner.reset() method
|
||||
|
||||
Returns:
|
||||
The final LLMResponse after tool calls are completed.
|
||||
|
||||
@@ -194,6 +197,15 @@ class Context:
|
||||
)
|
||||
agent_runner = ToolLoopAgentRunner()
|
||||
tool_executor = FunctionToolExecutor()
|
||||
|
||||
streaming = kwargs.get("stream", False)
|
||||
|
||||
other_kwargs = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k not in ["stream", "agent_hooks", "agent_context"]
|
||||
}
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=prov,
|
||||
request=request,
|
||||
@@ -203,7 +215,8 @@ class Context:
|
||||
),
|
||||
tool_executor=tool_executor,
|
||||
agent_hooks=agent_hooks,
|
||||
streaming=kwargs.get("stream", False),
|
||||
streaming=streaming,
|
||||
**other_kwargs,
|
||||
)
|
||||
async for _ in agent_runner.step_until_done(max_steps):
|
||||
pass
|
||||
|
||||
@@ -3,6 +3,7 @@ import traceback
|
||||
from astrbot.core import astrbot_config, logger
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager
|
||||
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
|
||||
from astrbot.core.db.migration.migra_token_usage import migrate_token_usage
|
||||
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
|
||||
|
||||
|
||||
@@ -139,6 +140,13 @@ async def migra(
|
||||
logger.error(f"Migration for webchat session failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# migration for token_usage column
|
||||
try:
|
||||
await migrate_token_usage(db)
|
||||
except Exception as e:
|
||||
logger.error(f"Migration for token_usage column failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# migra third party agent runner configs
|
||||
_c = False
|
||||
providers = astrbot_config["provider"]
|
||||
|
||||
@@ -993,6 +993,7 @@ class BackupRoute(Route):
|
||||
file_path,
|
||||
as_attachment=True,
|
||||
attachment_filename=filename,
|
||||
conditional=True, # 启用 Range 请求支持(断点续传)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"下载备份失败: {e}")
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
## What's Changed
|
||||
|
||||
### 新增
|
||||
|
||||
- 支持上下文自动压缩功能。入口:配置文件 -> 上下文管理策略 -> 超出模型上下文窗口时的处理方式。详情请查看: [自动上下文压缩](https://docs.astrbot.app/use/context-compress.html) ([#4322](https://github.com/AstrBotDevs/AstrBot/issues/4322))
|
||||
- 新增 `on_waiting_llm_request` 事件钩子 ([#4319](https://github.com/AstrBotDevs/AstrBot/issues/4319))
|
||||
- WebUI 支持强制更新插件 ([#4293](https://github.com/AstrBotDevs/AstrBot/issues/4293))
|
||||
- 社区已提供适用于 [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter) 平台的适配器插件
|
||||
|
||||
### 修复
|
||||
|
||||
- 修复微信公众号中由于 msg.id 数据类型不匹配导致的重试失败问题 ([#4292](https://github.com/AstrBotDevs/AstrBot/issues/4292))
|
||||
- 修复调用 TTS 命令时出现的数据库锁定错误 ([#4313](https://github.com/AstrBotDevs/AstrBot/issues/4313))
|
||||
- 修复 Anthropic 提供商中 token 用量始终为 0 的问题 ([#4328](https://github.com/AstrBotDevs/AstrBot/issues/4328))
|
||||
|
||||
### 优化
|
||||
|
||||
- 完善共享组件的国际化支持 ([#4327](https://github.com/AstrBotDevs/AstrBot/issues/4327))
|
||||
- 优化下载大型备份文件时的稳定性,减少失败情况 ([#4329](https://github.com/AstrBotDevs/AstrBot/issues/4329))
|
||||
@@ -233,12 +233,12 @@ function getSpecialSubtype(value) {
|
||||
<div v-if="createSelectorModel(itemKey).value && createSelectorModel(itemKey).value.length > 0"
|
||||
class="selected-plugins-full-width">
|
||||
<div class="plugins-header">
|
||||
<small class="text-grey">已选择的插件:</small>
|
||||
<small class="text-grey">{{ t('core.shared.pluginSetSelector.selectedPluginsLabel') }}</small>
|
||||
</div>
|
||||
<div class="d-flex flex-wrap ga-2 mt-2">
|
||||
<v-chip v-for="plugin in (createSelectorModel(itemKey).value || [])" :key="plugin" size="small" label
|
||||
color="primary" variant="outlined">
|
||||
{{ plugin === '*' ? '所有插件' : plugin }}
|
||||
{{ plugin === '*' ? t('core.shared.pluginSetSelector.allPluginsLabel') : plugin }}
|
||||
</v-chip>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -20,13 +20,13 @@
|
||||
</template>
|
||||
<template v-else-if="itemMeta?._special === 'provider_pool'">
|
||||
<ProviderSelector :model-value="modelValue" @update:model-value="emitUpdate" :provider-type="'chat_completion'"
|
||||
button-text="选择提供商池..." />
|
||||
:button-text="t('core.shared.providerSelector.selectProviderPool')" />
|
||||
</template>
|
||||
<template v-else-if="itemMeta?._special === 'select_persona'">
|
||||
<PersonaSelector :model-value="modelValue" @update:model-value="emitUpdate" />
|
||||
</template>
|
||||
<template v-else-if="itemMeta?._special === 'persona_pool'">
|
||||
<PersonaSelector :model-value="modelValue" @update:model-value="emitUpdate" button-text="选择人格池..." />
|
||||
<PersonaSelector :model-value="modelValue" @update:model-value="emitUpdate" :button-text="t('core.shared.personaSelector.selectPersonaPool')" />
|
||||
</template>
|
||||
<template v-else-if="itemMeta?._special === 'select_knowledgebase'">
|
||||
<KnowledgeBaseSelector :model-value="modelValue" @update:model-value="emitUpdate" />
|
||||
@@ -56,7 +56,7 @@
|
||||
:loading="loading"
|
||||
class="ml-2"
|
||||
>
|
||||
自动检测
|
||||
{{ t('core.common.autoDetect') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</template>
|
||||
@@ -144,7 +144,7 @@
|
||||
color="primary"
|
||||
density="compact"
|
||||
hide-details
|
||||
class="flex-grow-1"
|
||||
style="flex: 1"
|
||||
></v-slider>
|
||||
<v-text-field
|
||||
:model-value="modelValue"
|
||||
@@ -154,7 +154,7 @@
|
||||
class="config-field"
|
||||
type="number"
|
||||
hide-details
|
||||
style="max-width: 140px;"
|
||||
style="flex: 1"
|
||||
></v-text-field>
|
||||
</div>
|
||||
|
||||
@@ -325,4 +325,8 @@ function getSpecialSubtype(value) {
|
||||
.gap-20 {
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
:deep(.v-field__input) {
|
||||
font-size: 14px;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
</div>
|
||||
</div>
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="openDialog" style="flex-shrink: 0;">
|
||||
{{ buttonText }}
|
||||
{{ buttonText || tm('knowledgeBaseSelector.buttonText') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
@@ -105,7 +105,7 @@ const props = defineProps({
|
||||
},
|
||||
buttonText: {
|
||||
type: String,
|
||||
default: '选择知识库...'
|
||||
default: ''
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -175,11 +175,11 @@ const props = defineProps({
|
||||
},
|
||||
buttonText: {
|
||||
type: String,
|
||||
default: '修改'
|
||||
default: ''
|
||||
},
|
||||
dialogTitle: {
|
||||
type: String,
|
||||
default: '修改列表项'
|
||||
default: ''
|
||||
},
|
||||
maxDisplayItems: {
|
||||
type: Number,
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
<template>
|
||||
<div class="d-flex align-center justify-space-between">
|
||||
<span v-if="!modelValue" style="color: rgb(var(--v-theme-primaryText));">
|
||||
未选择
|
||||
{{ tm('personaSelector.notSelected') }}
|
||||
</span>
|
||||
<span v-else>
|
||||
{{ modelValue === 'default' ? '默认人格' : modelValue }}
|
||||
{{ modelValue === 'default' ? tm('personaSelector.defaultPersona') : modelValue }}
|
||||
</span>
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
|
||||
{{ buttonText }}
|
||||
{{ buttonText || tm('personaSelector.buttonText') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
<v-dialog v-model="dialog" max-width="600px">
|
||||
<v-card>
|
||||
<v-card-title class="text-h3 py-4" style="font-weight: normal;">
|
||||
选择人格
|
||||
{{ tm('personaSelector.dialogTitle') }}
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text class="pa-2" style="max-height: 400px; overflow-y: auto;">
|
||||
@@ -30,9 +30,9 @@
|
||||
:active="selectedPersona === persona.persona_id"
|
||||
rounded="md"
|
||||
class="ma-1">
|
||||
<v-list-item-title>{{ persona.persona_id === 'default' ? '默认人格' : persona.persona_id }}</v-list-item-title>
|
||||
<v-list-item-title>{{ persona.persona_id === 'default' ? tm('personaSelector.defaultPersona') : persona.persona_id }}</v-list-item-title>
|
||||
<v-list-item-subtitle>
|
||||
{{ persona.system_prompt ? persona.system_prompt.substring(0, 50) + '...' : '无描述' }}
|
||||
{{ persona.system_prompt ? persona.system_prompt.substring(0, 50) + '...' : tm('personaSelector.noDescription') }}
|
||||
</v-list-item-subtitle>
|
||||
|
||||
<template v-slot:append>
|
||||
@@ -43,21 +43,21 @@
|
||||
|
||||
<div v-else-if="!loading && personaList.length === 0" class="text-center py-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-account-off</v-icon>
|
||||
<p class="text-grey mt-4">暂无可用的人格</p>
|
||||
<p class="text-grey mt-4">{{ tm('personaSelector.noPersonas') }}</p>
|
||||
</div>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions class="pa-4">
|
||||
<v-btn variant="text" color="primary" prepend-icon="mdi-plus" @click="openCreatePersona">
|
||||
创建新人格
|
||||
{{ tm('personaSelector.createPersona') }}
|
||||
</v-btn>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="cancelSelection">取消</v-btn>
|
||||
<v-btn
|
||||
color="primary"
|
||||
<v-btn variant="text" @click="cancelSelection">{{ t('core.common.cancel') }}</v-btn>
|
||||
<v-btn
|
||||
color="primary"
|
||||
@click="confirmSelection"
|
||||
:disabled="!selectedPersona">
|
||||
确认选择
|
||||
{{ t('core.common.confirm') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
@@ -78,6 +78,7 @@
|
||||
import { ref, watch } from 'vue'
|
||||
import axios from 'axios'
|
||||
import PersonaForm from './PersonaForm.vue'
|
||||
import { useI18n, useModuleI18n } from '@/i18n/composables'
|
||||
|
||||
const props = defineProps({
|
||||
modelValue: {
|
||||
@@ -86,11 +87,13 @@ const props = defineProps({
|
||||
},
|
||||
buttonText: {
|
||||
type: String,
|
||||
default: '选择人格...'
|
||||
default: ''
|
||||
}
|
||||
})
|
||||
|
||||
const emit = defineEmits(['update:modelValue'])
|
||||
const { t } = useI18n()
|
||||
const { tm } = useModuleI18n('core.shared')
|
||||
|
||||
const dialog = ref(false)
|
||||
const personaList = ref([])
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
</span>
|
||||
</div>
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
|
||||
{{ buttonText }}
|
||||
{{ buttonText || tm('pluginSetSelector.buttonText') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</div>
|
||||
@@ -113,7 +113,7 @@ const props = defineProps({
|
||||
},
|
||||
buttonText: {
|
||||
type: String,
|
||||
default: '选择插件集合...'
|
||||
default: ''
|
||||
},
|
||||
maxDisplayItems: {
|
||||
type: Number,
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
{{ modelValue }}
|
||||
</span>
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
|
||||
{{ buttonText }}
|
||||
{{ buttonText || tm('providerSelector.buttonText') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
@@ -134,7 +134,7 @@ const props = defineProps({
|
||||
},
|
||||
buttonText: {
|
||||
type: String,
|
||||
default: '选择提供商...'
|
||||
default: ''
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
<template>
|
||||
<h5>GitHub 加速</h5>
|
||||
<h5>{{ tm('network.proxySelector.title') }}</h5>
|
||||
<v-radio-group class="mt-2" v-model="radioValue" hide-details="true">
|
||||
<v-radio label="不使用 GitHub 加速" value="0"></v-radio>
|
||||
<v-radio :label="tm('network.proxySelector.noProxy')" value="0"></v-radio>
|
||||
<v-radio value="1">
|
||||
<template v-slot:label>
|
||||
<span>使用 GitHub 加速</span>
|
||||
<span>{{ tm('network.proxySelector.useProxy') }}</span>
|
||||
<v-btn v-if="radioValue === '1'" class="ml-2" @click="testAllProxies" size="x-small"
|
||||
variant="tonal" :loading="loadingTestingConnection">
|
||||
测试代理连通性
|
||||
{{ tm('network.proxySelector.testConnection') }}
|
||||
</v-btn>
|
||||
</template>
|
||||
</v-radio>
|
||||
@@ -20,15 +20,15 @@
|
||||
<div class="d-flex align-center">
|
||||
<span class="mr-2">{{ proxy }}</span>
|
||||
<div v-if="proxyStatus[idx]">
|
||||
<v-chip
|
||||
:color="proxyStatus[idx].available ? 'success' : 'error'"
|
||||
size="x-small"
|
||||
<v-chip
|
||||
:color="proxyStatus[idx].available ? 'success' : 'error'"
|
||||
size="x-small"
|
||||
class="mr-1">
|
||||
{{ proxyStatus[idx].available ? '可用' : '不可用' }}
|
||||
{{ proxyStatus[idx].available ? tm('network.proxySelector.available') : tm('network.proxySelector.unavailable') }}
|
||||
</v-chip>
|
||||
<v-chip
|
||||
v-if="proxyStatus[idx].available"
|
||||
color="info"
|
||||
<v-chip
|
||||
v-if="proxyStatus[idx].available"
|
||||
color="info"
|
||||
size="x-small">
|
||||
{{ proxyStatus[idx].latency }}ms
|
||||
</v-chip>
|
||||
@@ -36,10 +36,10 @@
|
||||
</div>
|
||||
</template>
|
||||
</v-radio>
|
||||
<v-radio color="primary" value="-1" label="自定义">
|
||||
<v-radio color="primary" value="-1" :label="tm('network.proxySelector.custom')">
|
||||
<template v-slot:label v-if="githubProxyRadioControl === '-1'">
|
||||
<v-text-field density="compact" v-model="selectedGitHubProxy" variant="outlined"
|
||||
style="width: 100vw;" placeholder="自定义" hide-details="true">
|
||||
style="width: 100vw;" :placeholder="tm('network.proxySelector.custom')" hide-details="true">
|
||||
</v-text-field>
|
||||
</template>
|
||||
</v-radio>
|
||||
|
||||
@@ -1,32 +1,32 @@
|
||||
<template>
|
||||
<v-dialog v-model="dialog" max-width="1400px" persistent scrollable>
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn
|
||||
<v-btn
|
||||
v-bind="props"
|
||||
variant="outlined"
|
||||
color="primary"
|
||||
variant="outlined"
|
||||
color="primary"
|
||||
size="small"
|
||||
:loading="loading"
|
||||
>
|
||||
自定义 T2I 模板
|
||||
{{ tm('t2iTemplateEditor.buttonText') }}
|
||||
</v-btn>
|
||||
</template>
|
||||
|
||||
<v-card>
|
||||
<v-card-title class="d-flex align-center justify-space-between">
|
||||
<span>自定义文转图 HTML 模板</span>
|
||||
<span>{{ tm('t2iTemplateEditor.dialogTitle') }}</span>
|
||||
<v-spacer></v-spacer>
|
||||
<div class="d-flex align-center gap-2" style="width: 60%">
|
||||
<v-text-field
|
||||
v-if="isCreatingNew"
|
||||
v-model="editingName"
|
||||
label="输入新模板名称"
|
||||
:label="tm('t2iTemplateEditor.newTemplateNameLabel')"
|
||||
density="compact"
|
||||
hide-details
|
||||
variant="outlined"
|
||||
class="flex-grow-1"
|
||||
autofocus
|
||||
:rules="[v => !!v || '名称不能为空']"
|
||||
:rules="[v => !!v || tm('t2iTemplateEditor.nameRequired')]"
|
||||
></v-text-field>
|
||||
<v-select
|
||||
v-else
|
||||
@@ -34,7 +34,7 @@
|
||||
:items="templates"
|
||||
item-title="name"
|
||||
item-value="name"
|
||||
label="选择模板"
|
||||
:label="tm('t2iTemplateEditor.selectTemplateLabel')"
|
||||
density="compact"
|
||||
hide-details
|
||||
variant="outlined"
|
||||
@@ -51,7 +51,7 @@
|
||||
size="small"
|
||||
class="ml-2"
|
||||
>
|
||||
已应用
|
||||
{{ tm('t2iTemplateEditor.applied') }}
|
||||
</v-chip>
|
||||
<v-btn
|
||||
v-else
|
||||
@@ -62,7 +62,7 @@
|
||||
@click.stop="setActiveTemplate(item.raw.name)"
|
||||
:loading="applyLoading"
|
||||
>
|
||||
应用
|
||||
{{ tm('t2iTemplateEditor.apply') }}
|
||||
</v-btn>
|
||||
</template>
|
||||
</v-list-item>
|
||||
@@ -83,7 +83,7 @@
|
||||
<!-- 左侧编辑器 -->
|
||||
<v-col cols="6" class="d-flex flex-column">
|
||||
<v-toolbar density="compact" color="surface-variant">
|
||||
<v-toolbar-title class="text-subtitle-2">模板编辑器</v-toolbar-title>
|
||||
<v-toolbar-title class="text-subtitle-2">{{ tm('t2iTemplateEditor.templateEditor') }}</v-toolbar-title>
|
||||
<v-spacer></v-spacer>
|
||||
<div class="d-flex align-center pa-1" style="border: 1px solid rgba(0,0,0,0.1); border-radius: 8px;">
|
||||
<v-btn
|
||||
@@ -93,7 +93,7 @@
|
||||
color="success"
|
||||
>
|
||||
<v-icon left>mdi-plus</v-icon>
|
||||
新建
|
||||
{{ tm('t2iTemplateEditor.new') }}
|
||||
</v-btn>
|
||||
<v-divider vertical class="mx-1"></v-divider>
|
||||
<v-btn
|
||||
@@ -103,7 +103,7 @@
|
||||
:loading="resetLoading"
|
||||
color="warning"
|
||||
>
|
||||
重置Base
|
||||
{{ tm('t2iTemplateEditor.resetBase') }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
variant="text"
|
||||
@@ -112,7 +112,7 @@
|
||||
color="error"
|
||||
:disabled="isCreatingNew || selectedTemplate === 'base' || !selectedTemplate"
|
||||
>
|
||||
删除
|
||||
{{ tm('t2iTemplateEditor.delete') }}
|
||||
</v-btn>
|
||||
<v-divider vertical class="mx-1"></v-divider>
|
||||
<v-btn
|
||||
@@ -123,7 +123,7 @@
|
||||
color="primary"
|
||||
:disabled="(isCreatingNew && !editingName) || (!isCreatingNew && !selectedTemplate)"
|
||||
>
|
||||
保存
|
||||
{{ tm('t2iTemplateEditor.save') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-toolbar>
|
||||
@@ -141,15 +141,15 @@
|
||||
<!-- 右侧预览 -->
|
||||
<v-col cols="6" class="d-flex flex-column">
|
||||
<v-toolbar density="compact" color="surface-variant">
|
||||
<v-toolbar-title class="text-subtitle-2">实时预览(可能有差异)</v-toolbar-title>
|
||||
<v-toolbar-title class="text-subtitle-2">{{ tm('t2iTemplateEditor.livePreview') }}</v-toolbar-title>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn
|
||||
variant="text"
|
||||
size="small"
|
||||
<v-btn
|
||||
variant="text"
|
||||
size="small"
|
||||
@click="refreshPreview"
|
||||
:loading="previewLoading"
|
||||
>
|
||||
刷新预览
|
||||
{{ tm('t2iTemplateEditor.refreshPreview') }}
|
||||
</v-btn>
|
||||
</v-toolbar>
|
||||
<div class="flex-grow-1 preview-container">
|
||||
@@ -168,7 +168,7 @@
|
||||
<v-col>
|
||||
<div class="text-caption text-grey">
|
||||
<v-icon size="16" class="mr-1">mdi-information</v-icon>
|
||||
支持 jinja2 语法。可用变量:<code> text | safe </code>(要渲染的文本), <code> version </code>(AstrBot 版本)
|
||||
{{ tm('t2iTemplateEditor.syntaxHint') }}
|
||||
</div>
|
||||
</v-col>
|
||||
<v-col cols="auto">
|
||||
@@ -176,7 +176,7 @@
|
||||
variant="text"
|
||||
@click="closeDialog"
|
||||
>
|
||||
取消
|
||||
{{ t('core.common.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
color="primary"
|
||||
@@ -184,7 +184,7 @@
|
||||
:loading="saveLoading"
|
||||
:disabled="isCreatingNew || !selectedTemplate"
|
||||
>
|
||||
保存应用当前编辑模板
|
||||
{{ tm('t2iTemplateEditor.saveAndApply') }}
|
||||
</v-btn>
|
||||
</v-col>
|
||||
</v-row>
|
||||
@@ -194,14 +194,14 @@
|
||||
<!-- 确认重置对话框 -->
|
||||
<v-dialog v-model="resetDialog" max-width="400px">
|
||||
<v-card>
|
||||
<v-card-title>确认重置</v-card-title>
|
||||
<v-card-title>{{ tm('t2iTemplateEditor.confirmReset') }}</v-card-title>
|
||||
<v-card-text>
|
||||
确定要将 'base' 模板恢复为默认内容吗?当前编辑器中的任何未保存更改将丢失。此操作无法撤销。
|
||||
{{ tm('t2iTemplateEditor.confirmResetMessage') }}
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn text @click="resetDialog = false">取消</v-btn>
|
||||
<v-btn color="warning" @click="confirmReset" :loading="resetLoading">确认重置</v-btn>
|
||||
<v-btn text @click="resetDialog = false">{{ t('core.common.cancel') }}</v-btn>
|
||||
<v-btn color="warning" @click="confirmReset" :loading="resetLoading">{{ tm('t2iTemplateEditor.confirmResetButton') }}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
@@ -209,14 +209,14 @@
|
||||
<!-- 删除确认对话框 -->
|
||||
<v-dialog v-model="deleteDialog" max-width="400px">
|
||||
<v-card>
|
||||
<v-card-title>确认删除</v-card-title>
|
||||
<v-card-title>{{ tm('t2iTemplateEditor.confirmDelete') }}</v-card-title>
|
||||
<v-card-text>
|
||||
确定要删除模板 '{{ selectedTemplate }}' 吗?此操作无法撤销。
|
||||
{{ tm('t2iTemplateEditor.confirmDeleteMessage', { name: selectedTemplate }) }}
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn text @click="deleteDialog = false">取消</v-btn>
|
||||
<v-btn color="error" @click="confirmDelete" :loading="saveLoading">确认删除</v-btn>
|
||||
<v-btn text @click="deleteDialog = false">{{ t('core.common.cancel') }}</v-btn>
|
||||
<v-btn color="error" @click="confirmDelete" :loading="saveLoading">{{ tm('t2iTemplateEditor.confirmDeleteButton') }}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
@@ -224,14 +224,14 @@
|
||||
<!-- 保存并应用确认对话框 -->
|
||||
<v-dialog v-model="applyAndCloseDialog" max-width="500px">
|
||||
<v-card>
|
||||
<v-card-title>确认操作</v-card-title>
|
||||
<v-card-title>{{ tm('t2iTemplateEditor.confirmAction') }}</v-card-title>
|
||||
<v-card-text>
|
||||
确定要保存对 '{{ selectedTemplate }}' 的修改,并将其设为新的活动模板吗?
|
||||
{{ tm('t2iTemplateEditor.confirmApplyMessage', { name: selectedTemplate }) }}
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn text @click="applyAndCloseDialog = false">取消</v-btn>
|
||||
<v-btn color="primary" @click="confirmApplyAndClose" :loading="saveLoading">确认</v-btn>
|
||||
<v-btn text @click="applyAndCloseDialog = false">{{ t('core.common.cancel') }}</v-btn>
|
||||
<v-btn color="primary" @click="confirmApplyAndClose" :loading="saveLoading">{{ t('core.common.confirm') }}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
@@ -242,10 +242,11 @@
|
||||
<script setup>
|
||||
import { ref, computed, nextTick, watch } from 'vue'
|
||||
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
import { useI18n } from '@/i18n/composables'
|
||||
import { useI18n, useModuleI18n } from '@/i18n/composables'
|
||||
import axios from 'axios'
|
||||
|
||||
const { t } = useI18n()
|
||||
const { tm } = useModuleI18n('core.shared')
|
||||
|
||||
// --- 响应式数据 ---
|
||||
const dialog = ref(false)
|
||||
|
||||
@@ -510,7 +510,7 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
|
||||
|
||||
const metadata = getModelMetadata(modelName)
|
||||
let modalities: string[]
|
||||
|
||||
|
||||
if (!metadata) {
|
||||
modalities = ['text', 'image', 'tool_use']
|
||||
} else {
|
||||
@@ -523,13 +523,19 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
|
||||
}
|
||||
}
|
||||
|
||||
let max_context_tokens = 0
|
||||
if (metadata?.limit?.context && typeof metadata.limit.context === 'number') {
|
||||
max_context_tokens = metadata.limit.context
|
||||
}
|
||||
|
||||
const newProvider = {
|
||||
id: newId,
|
||||
enable: false,
|
||||
provider_source_id: sourceId,
|
||||
model: modelName,
|
||||
modalities,
|
||||
custom_extra_body: {}
|
||||
custom_extra_body: {},
|
||||
max_context_tokens: max_context_tokens
|
||||
}
|
||||
|
||||
try {
|
||||
|
||||
@@ -35,6 +35,7 @@
|
||||
"yes": "Yes",
|
||||
"no": "No",
|
||||
"imagePreview": "Image Preview",
|
||||
"autoDetect": "Auto Detect",
|
||||
"dialog": {
|
||||
"confirmTitle": "Confirm Action",
|
||||
"confirmMessage": "Are you sure you want to perform this action?",
|
||||
|
||||
@@ -28,7 +28,9 @@
|
||||
"cancelSelection": "Cancel",
|
||||
"noDescription": "No description",
|
||||
"notActivated": "Not activated",
|
||||
"note": "*System plugins and disabled plugins are not shown."
|
||||
"note": "*System plugins and disabled plugins are not shown.",
|
||||
"selectedPluginsLabel": "Selected Plugins:",
|
||||
"allPluginsLabel": "All Plugins"
|
||||
},
|
||||
"providerSelector": {
|
||||
"notSelected": "Not selected",
|
||||
@@ -42,6 +44,45 @@
|
||||
"clearSelectionSubtitle": "Clear current selection",
|
||||
"unknownType": "Unknown type",
|
||||
"createProvider": "Create Provider",
|
||||
"manageProviders": "Provider Management"
|
||||
"manageProviders": "Provider Management",
|
||||
"selectProviderPool": "Select Provider Pool..."
|
||||
},
|
||||
"personaSelector": {
|
||||
"notSelected": "Not selected",
|
||||
"defaultPersona": "Default Persona",
|
||||
"buttonText": "Select Persona...",
|
||||
"dialogTitle": "Select Persona",
|
||||
"noDescription": "No description",
|
||||
"noPersonas": "No personas available",
|
||||
"createPersona": "Create New Persona",
|
||||
"cancelSelection": "Cancel",
|
||||
"confirmSelection": "Confirm Selection",
|
||||
"selectPersonaPool": "Select Persona Pool..."
|
||||
},
|
||||
"t2iTemplateEditor": {
|
||||
"buttonText": "Customize T2I Template",
|
||||
"dialogTitle": "Customize Text-to-Image HTML Template",
|
||||
"newTemplateNameLabel": "Enter new template name",
|
||||
"nameRequired": "Name is required",
|
||||
"selectTemplateLabel": "Select Template",
|
||||
"applied": "Applied",
|
||||
"apply": "Apply",
|
||||
"templateEditor": "Template Editor",
|
||||
"new": "New",
|
||||
"resetBase": "Reset Base",
|
||||
"delete": "Delete",
|
||||
"save": "Save",
|
||||
"livePreview": "Live Preview (may differ)",
|
||||
"refreshPreview": "Refresh Preview",
|
||||
"syntaxHint": "Supports jinja2 syntax. Available variables: text | safe (text to render), version (AstrBot version)",
|
||||
"saveAndApply": "Save and Apply Current Template",
|
||||
"confirmReset": "Confirm Reset",
|
||||
"confirmResetMessage": "Are you sure you want to reset the 'base' template to default content? Any unsaved changes in the editor will be lost. This action cannot be undone.",
|
||||
"confirmResetButton": "Confirm Reset",
|
||||
"confirmDelete": "Confirm Delete",
|
||||
"confirmDeleteMessage": "Are you sure you want to delete template '{name}'? This action cannot be undone.",
|
||||
"confirmDeleteButton": "Confirm Delete",
|
||||
"confirmAction": "Confirm Action",
|
||||
"confirmApplyMessage": "Are you sure you want to save changes to '{name}' and set it as the active template?"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,12 @@
|
||||
},
|
||||
"agent_runner_type": {
|
||||
"description": "Runner",
|
||||
"labels": ["Built-in Agent", "Dify", "Coze", "Alibaba Cloud Bailian Application"]
|
||||
"labels": [
|
||||
"Built-in Agent",
|
||||
"Dify",
|
||||
"Coze",
|
||||
"Alibaba Cloud Bailian Application"
|
||||
]
|
||||
},
|
||||
"coze_agent_runner_provider_id": {
|
||||
"description": "Coze Agent Runner Provider ID"
|
||||
@@ -128,6 +133,39 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"truncate_and_compress": {
|
||||
"description": "Context Management Strategy",
|
||||
"provider_settings": {
|
||||
"max_context_length": {
|
||||
"description": "Maximum Conversation Turns",
|
||||
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "Dequeue Conversation Turns",
|
||||
"hint": "Number of conversation turns to discard at once when maximum context length is exceeded"
|
||||
},
|
||||
"context_limit_reached_strategy": {
|
||||
"description": "Handling When Model Context Window is Exceeded",
|
||||
"labels": [
|
||||
"Truncate by Turns",
|
||||
"Compress by LLM"
|
||||
],
|
||||
"hint": "When 'Truncate by Turns' is selected, the oldest N conversation turns will be discarded based on the 'Dequeue Conversation Turns' setting above. When 'Compress by LLM' is selected, the specified model will be used for context compression."
|
||||
},
|
||||
"llm_compress_instruction": {
|
||||
"description": "Context Compression Instruction",
|
||||
"hint": "If empty, the default prompt will be used."
|
||||
},
|
||||
"llm_compress_keep_recent": {
|
||||
"description": "Keep Recent Turns When Compressing",
|
||||
"hint": "Always keep the most recent N turns of conversation when compressing context."
|
||||
},
|
||||
"llm_compress_provider_id": {
|
||||
"description": "Model Provider ID for Context Compression",
|
||||
"hint": "When left empty, will fall back to the 'Truncate by Turns' strategy."
|
||||
}
|
||||
}
|
||||
},
|
||||
"others": {
|
||||
"description": "Other Settings",
|
||||
"provider_settings": {
|
||||
@@ -161,15 +199,10 @@
|
||||
"unsupported_streaming_strategy": {
|
||||
"description": "Platforms Without Streaming Support",
|
||||
"hint": "Select the handling method for platforms that don't support streaming responses. Real-time segmented reply sends content immediately when the system detects segment points like punctuation during streaming reception",
|
||||
"labels": ["Real-time Segmented Reply", "Disable Streaming Response"]
|
||||
},
|
||||
"max_context_length": {
|
||||
"description": "Maximum Conversation Rounds",
|
||||
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "Dequeue Conversation Rounds",
|
||||
"hint": "Number of conversation rounds to discard at once when maximum context length is exceeded"
|
||||
"labels": [
|
||||
"Real-time Segmented Reply",
|
||||
"Disable Streaming Response"
|
||||
]
|
||||
},
|
||||
"wake_prefix": {
|
||||
"description": "Additional LLM Chat Wake Prefix",
|
||||
@@ -387,7 +420,10 @@
|
||||
},
|
||||
"split_mode": {
|
||||
"description": "Split Mode",
|
||||
"labels": ["Regex", "Words List"]
|
||||
"labels": [
|
||||
"Regex",
|
||||
"Words List"
|
||||
]
|
||||
},
|
||||
"regex": {
|
||||
"description": "Segmentation Regular Expression"
|
||||
@@ -488,4 +524,4 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,15 @@
|
||||
"title": "GitHub Proxy Address",
|
||||
"subtitle": "Set the GitHub proxy address used when downloading plugins or updating AstrBot. This is effective in mainland China's network environment. Can be customized, input takes effect in real time. All addresses do not guarantee stability. If errors occur when updating plugins/projects, please first check if the proxy address is working properly.",
|
||||
"label": "Select GitHub Proxy Address"
|
||||
},
|
||||
"proxySelector": {
|
||||
"title": "GitHub Proxy",
|
||||
"noProxy": "Don't use GitHub Proxy",
|
||||
"useProxy": "Use GitHub Proxy",
|
||||
"testConnection": "Test Connection",
|
||||
"available": "Available",
|
||||
"unavailable": "Unavailable",
|
||||
"custom": "Custom"
|
||||
}
|
||||
},
|
||||
"system": {
|
||||
|
||||
@@ -35,6 +35,7 @@
|
||||
"yes": "是",
|
||||
"no": "否",
|
||||
"imagePreview": "图片预览",
|
||||
"autoDetect": "自动检测",
|
||||
"dialog": {
|
||||
"confirmTitle": "确认操作",
|
||||
"confirmMessage": "你确定要执行此操作吗?",
|
||||
|
||||
@@ -28,7 +28,9 @@
|
||||
"cancelSelection": "取消",
|
||||
"noDescription": "无描述",
|
||||
"notActivated": "未激活",
|
||||
"note": "*不显示系统插件和已经在插件页禁用的插件。"
|
||||
"note": "*不显示系统插件和已经在插件页禁用的插件。",
|
||||
"selectedPluginsLabel": "已选择的插件:",
|
||||
"allPluginsLabel": "所有插件"
|
||||
},
|
||||
"providerSelector": {
|
||||
"notSelected": "未选择",
|
||||
@@ -42,6 +44,45 @@
|
||||
"clearSelectionSubtitle": "清除当前选择",
|
||||
"unknownType": "未知类型",
|
||||
"createProvider": "创建提供商",
|
||||
"manageProviders": "提供商管理"
|
||||
"manageProviders": "提供商管理",
|
||||
"selectProviderPool": "选择提供商池..."
|
||||
},
|
||||
"personaSelector": {
|
||||
"notSelected": "未选择",
|
||||
"defaultPersona": "默认人格",
|
||||
"buttonText": "选择人格...",
|
||||
"dialogTitle": "选择人格",
|
||||
"noDescription": "无描述",
|
||||
"noPersonas": "暂无可用的人格",
|
||||
"createPersona": "创建新人格",
|
||||
"cancelSelection": "取消",
|
||||
"confirmSelection": "确认选择",
|
||||
"selectPersonaPool": "选择人格池..."
|
||||
},
|
||||
"t2iTemplateEditor": {
|
||||
"buttonText": "自定义 T2I 模板",
|
||||
"dialogTitle": "自定义文转图 HTML 模板",
|
||||
"newTemplateNameLabel": "输入新模板名称",
|
||||
"nameRequired": "名称不能为空",
|
||||
"selectTemplateLabel": "选择模板",
|
||||
"applied": "已应用",
|
||||
"apply": "应用",
|
||||
"templateEditor": "模板编辑器",
|
||||
"new": "新建",
|
||||
"resetBase": "重置Base",
|
||||
"delete": "删除",
|
||||
"save": "保存",
|
||||
"livePreview": "实时预览(可能有差异)",
|
||||
"refreshPreview": "刷新预览",
|
||||
"syntaxHint": "支持 jinja2 语法。可用变量:text | safe(要渲染的文本), version(AstrBot 版本)",
|
||||
"saveAndApply": "保存应用当前编辑模板",
|
||||
"confirmReset": "确认重置",
|
||||
"confirmResetMessage": "确定要将 'base' 模板恢复为默认内容吗?当前编辑器中的任何未保存更改将丢失。此操作无法撤销。",
|
||||
"confirmResetButton": "确认重置",
|
||||
"confirmDelete": "确认删除",
|
||||
"confirmDeleteMessage": "确定要删除模板 '{name}' 吗?此操作无法撤销。",
|
||||
"confirmDeleteButton": "确认删除",
|
||||
"confirmAction": "确认操作",
|
||||
"confirmApplyMessage": "确定要保存对 '{name}' 的修改,并将其设为新的活动模板吗?"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,6 +133,36 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"truncate_and_compress": {
|
||||
"description": "上下文管理策略",
|
||||
"provider_settings": {
|
||||
"max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
|
||||
},
|
||||
"context_limit_reached_strategy": {
|
||||
"description": "超出模型上下文窗口时的处理方式",
|
||||
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
|
||||
"hint": "当按对话轮数截断时,会根据上面\"丢弃对话轮数\"的配置丢弃最旧的 N 轮对话。当由 LLM 压缩上下文时,会使用指定的模型进行上下文压缩。"
|
||||
},
|
||||
"llm_compress_instruction": {
|
||||
"description": "上下文压缩提示词",
|
||||
"hint": "如果为空则使用默认提示词。"
|
||||
},
|
||||
"llm_compress_keep_recent": {
|
||||
"description": "压缩时保留最近对话轮数",
|
||||
"hint": "始终保留的最近 N 轮对话。"
|
||||
},
|
||||
"llm_compress_provider_id": {
|
||||
"description": "用于上下文压缩的模型提供商 ID",
|
||||
"hint": "留空时将降级为\"按对话轮数截断\"的策略。"
|
||||
}
|
||||
}
|
||||
},
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
"provider_settings": {
|
||||
@@ -171,14 +201,7 @@
|
||||
"关闭流式回复"
|
||||
]
|
||||
},
|
||||
"max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
|
||||
},
|
||||
|
||||
"wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀",
|
||||
"hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求"
|
||||
|
||||
@@ -5,6 +5,15 @@
|
||||
"title": "GitHub 加速地址",
|
||||
"subtitle": "设置下载插件或者更新 AstrBot 时所用的 GitHub 加速地址。这在中国大陆的网络环境有效。可以自定义,输入结果实时生效。所有地址均不保证稳定性,如果在更新插件/项目时出现报错,请首先检查加速地址是否能正常使用。",
|
||||
"label": "选择 GitHub 加速地址"
|
||||
},
|
||||
"proxySelector": {
|
||||
"title": "GitHub 加速",
|
||||
"noProxy": "不使用 GitHub 加速",
|
||||
"useProxy": "使用 GitHub 加速",
|
||||
"testConnection": "测试代理连通性",
|
||||
"available": "可用",
|
||||
"unavailable": "不可用",
|
||||
"custom": "自定义"
|
||||
}
|
||||
},
|
||||
"system": {
|
||||
|
||||
@@ -43,4 +43,21 @@ spec:
|
||||
resources:
|
||||
requests:
|
||||
storage: 5Gi
|
||||
# storageClassName: standard
|
||||
|
||||
---
|
||||
# 持久化 machine-id,保持设备标识不变
|
||||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
metadata:
|
||||
name: napcat-machine-id-pvc
|
||||
namespace: astrbot-ns
|
||||
labels:
|
||||
app: astrbot-stack
|
||||
spec:
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
resources:
|
||||
requests:
|
||||
storage: 10Mi # 只需存储一个 32 字节的文件
|
||||
# storageClassName: standard
|
||||
@@ -17,6 +17,32 @@ spec:
|
||||
labels:
|
||||
app: astrbot-stack
|
||||
spec:
|
||||
# 设置固定主机名,避免 Pod 重启后主机名变化触发风控
|
||||
hostname: napcat-host
|
||||
subdomain: astrbot-stack
|
||||
# 优雅关闭时间,给 NapCat 足够时间保存状态
|
||||
terminationGracePeriodSeconds: 60
|
||||
|
||||
# 初始化容器:首次生成随机 machine-id,后续复用
|
||||
initContainers:
|
||||
- name: init-machine-id
|
||||
image: busybox:latest
|
||||
command:
|
||||
- /bin/sh
|
||||
- -c
|
||||
- |
|
||||
# 仅在 machine-id 不存在时随机生成一个
|
||||
if [ ! -f /machine-id-data/machine-id ]; then
|
||||
# 使用 /dev/urandom 生成随机 UUID (32位十六进制)
|
||||
cat /proc/sys/kernel/random/uuid | tr -d '-' > /machine-id-data/machine-id
|
||||
echo "Machine ID generated: $(cat /machine-id-data/machine-id)"
|
||||
else
|
||||
echo "Machine ID exists: $(cat /machine-id-data/machine-id)"
|
||||
fi
|
||||
volumeMounts:
|
||||
- name: machine-id-data
|
||||
mountPath: /machine-id-data
|
||||
|
||||
containers:
|
||||
- name: napcat
|
||||
image: mlikiowa/napcat-docker:latest
|
||||
@@ -28,9 +54,19 @@ spec:
|
||||
value: "1000"
|
||||
- name: MODE
|
||||
value: "astrbot"
|
||||
- name: TZ
|
||||
value: "Asia/Shanghai"
|
||||
ports:
|
||||
- containerPort: 6099
|
||||
name: napcat-web
|
||||
# 资源限制:确保 Guaranteed QoS,减少被驱逐的可能
|
||||
resources:
|
||||
requests:
|
||||
memory: "512Mi"
|
||||
cpu: "250m"
|
||||
limits:
|
||||
memory: "1Gi"
|
||||
cpu: "1000m"
|
||||
volumeMounts:
|
||||
- name: shared-data
|
||||
mountPath: /AstrBot/data
|
||||
@@ -38,6 +74,14 @@ spec:
|
||||
mountPath: /app/napcat/config
|
||||
- name: napcat-qq
|
||||
mountPath: /app/.config/QQ
|
||||
# 挂载持久化的 machine-id
|
||||
- name: machine-id-data
|
||||
mountPath: /etc/machine-id
|
||||
subPath: machine-id
|
||||
readOnly: true
|
||||
- name: localtime
|
||||
mountPath: /etc/localtime
|
||||
readOnly: true
|
||||
|
||||
- name: astrbot
|
||||
image: soulter/astrbot:latest
|
||||
@@ -48,9 +92,19 @@ spec:
|
||||
ports:
|
||||
- containerPort: 6185
|
||||
name: astrbot-web
|
||||
resources:
|
||||
requests:
|
||||
memory: "256Mi"
|
||||
cpu: "100m"
|
||||
limits:
|
||||
memory: "512Mi"
|
||||
cpu: "500m"
|
||||
volumeMounts:
|
||||
- name: shared-data
|
||||
mountPath: /AstrBot/data
|
||||
- name: localtime
|
||||
mountPath: /etc/localtime
|
||||
readOnly: true
|
||||
|
||||
volumes:
|
||||
- name: shared-data
|
||||
@@ -61,4 +115,12 @@ spec:
|
||||
claimName: napcat-config-pvc
|
||||
- name: napcat-qq
|
||||
persistentVolumeClaim:
|
||||
claimName: napcat-qq-pvc
|
||||
claimName: napcat-qq-pvc
|
||||
# 持久化 machine-id(首次随机生成,后续复用)
|
||||
- name: machine-id-data
|
||||
persistentVolumeClaim:
|
||||
claimName: napcat-machine-id-pvc
|
||||
- name: localtime
|
||||
hostPath:
|
||||
path: /etc/localtime
|
||||
type: File
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "4.10.6"
|
||||
version = "4.11.0"
|
||||
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -0,0 +1,774 @@
|
||||
"""Comprehensive tests for ContextManager."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Add parent directory to path to avoid circular import issues
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from astrbot.core.agent.context.config import ContextConfig
|
||||
from astrbot.core.agent.context.manager import ContextManager
|
||||
from astrbot.core.agent.message import Message, TextPart
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
|
||||
class MockProvider:
|
||||
"""模拟 Provider"""
|
||||
|
||||
def __init__(self):
|
||||
self.provider_config = {
|
||||
"id": "test_provider",
|
||||
"model": "gpt-4",
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
}
|
||||
|
||||
async def text_chat(self, **kwargs):
|
||||
"""模拟 LLM 调用,返回摘要"""
|
||||
messages = kwargs.get("messages", [])
|
||||
# 简单的摘要逻辑:返回消息数量统计
|
||||
return LLMResponse(
|
||||
role="assistant",
|
||||
completion_text=f"历史对话包含 {len(messages) - 1} 条消息,主要讨论了技术话题。",
|
||||
)
|
||||
|
||||
def get_model(self):
|
||||
return "gpt-4"
|
||||
|
||||
def meta(self):
|
||||
return MagicMock(id="test_provider", type="openai")
|
||||
|
||||
|
||||
class TestContextManager:
|
||||
"""Test suite for ContextManager."""
|
||||
|
||||
def create_message(
|
||||
self, role: Literal["system", "user", "assistant", "tool"], content: str
|
||||
) -> Message:
|
||||
"""Helper to create a simple text message."""
|
||||
return Message(role=role, content=content)
|
||||
|
||||
def create_messages(self, count: int) -> list[Message]:
|
||||
"""Helper to create alternating user/assistant messages."""
|
||||
messages = []
|
||||
for i in range(count):
|
||||
role = "user" if i % 2 == 0 else "assistant"
|
||||
messages.append(self.create_message(role, f"Message {i}"))
|
||||
return messages
|
||||
|
||||
# ==================== Basic Initialization Tests ====================
|
||||
|
||||
def test_init_with_minimal_config(self):
|
||||
"""Test initialization with minimal configuration."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
assert manager.config == config
|
||||
assert manager.token_counter is not None
|
||||
assert manager.truncator is not None
|
||||
assert manager.compressor is not None
|
||||
|
||||
def test_init_with_llm_compressor(self):
|
||||
"""Test initialization with LLM-based compression."""
|
||||
mock_provider = MockProvider()
|
||||
config = ContextConfig(
|
||||
llm_compress_provider=mock_provider, # type: ignore
|
||||
llm_compress_keep_recent=5,
|
||||
llm_compress_instruction="Summarize the conversation",
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
from astrbot.core.agent.context.compressor import LLMSummaryCompressor
|
||||
|
||||
assert isinstance(manager.compressor, LLMSummaryCompressor)
|
||||
|
||||
def test_init_with_truncate_compressor(self):
|
||||
"""Test initialization with truncate-based compression (default)."""
|
||||
config = ContextConfig(truncate_turns=3)
|
||||
manager = ContextManager(config)
|
||||
|
||||
from astrbot.core.agent.context.compressor import TruncateByTurnsCompressor
|
||||
|
||||
assert isinstance(manager.compressor, TruncateByTurnsCompressor)
|
||||
|
||||
# ==================== Empty and Edge Cases ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_empty_messages(self):
|
||||
"""Test processing an empty message list."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
result = await manager.process([])
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_single_message(self):
|
||||
"""Test processing a single message."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "Hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_with_no_limits(self):
|
||||
"""Test processing when no limits are set (no truncation or compression)."""
|
||||
config = ContextConfig(max_context_tokens=0, enforce_max_turns=-1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(20)
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 20
|
||||
assert result == messages
|
||||
|
||||
# ==================== Enforce Max Turns Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_basic(self):
|
||||
"""Test basic enforce_max_turns functionality."""
|
||||
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create 10 turns (20 messages)
|
||||
messages = self.create_messages(20)
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should keep only 3 most recent turns (6 messages)
|
||||
assert len(result) <= 8 # May vary due to truncation logic
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_zero(self):
|
||||
"""Test enforce_max_turns with value 0 (should keep nothing)."""
|
||||
config = ContextConfig(enforce_max_turns=0, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(10)
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should result in empty or minimal message list
|
||||
assert len(result) <= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_negative(self):
|
||||
"""Test enforce_max_turns with -1 (no limit)."""
|
||||
config = ContextConfig(enforce_max_turns=-1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(20)
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 20
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_with_system_messages(self):
|
||||
"""Test enforce_max_turns preserves system messages."""
|
||||
config = ContextConfig(enforce_max_turns=2, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [
|
||||
self.create_message("system", "System instruction"),
|
||||
*self.create_messages(10),
|
||||
]
|
||||
result = await manager.process(messages)
|
||||
|
||||
# System message should be preserved
|
||||
system_msgs = [m for m in result if m.role == "system"]
|
||||
assert len(system_msgs) >= 1
|
||||
assert system_msgs[0].content == "System instruction"
|
||||
|
||||
# ==================== Token-based Compression Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_not_triggered_below_threshold(self):
|
||||
"""Test that compression is not triggered below threshold."""
|
||||
config = ContextConfig(max_context_tokens=1000)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that total less than threshold
|
||||
messages = [self.create_message("user", "Hi" * 50)] # ~100 tokens
|
||||
|
||||
with patch.object(
|
||||
manager.compressor, "should_compress", return_value=False
|
||||
) as mock_should_compress:
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# should_compress should be called
|
||||
mock_should_compress.assert_called_once()
|
||||
# Compressor should not be called
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_triggered_above_threshold(self):
|
||||
"""Test that compression is triggered above threshold."""
|
||||
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that exceed threshold (0.82 * 100 = 82 tokens)
|
||||
# 300 chars * 0.3 = 90 tokens > 82 threshold
|
||||
long_text = "x" * 300 # ~90 tokens, above threshold
|
||||
messages = [self.create_message("user", long_text)]
|
||||
|
||||
# Mock compressor to return smaller result
|
||||
compressed = [self.create_message("user", "short")]
|
||||
|
||||
# Create a mock compressor
|
||||
mock_compressor = AsyncMock()
|
||||
mock_compressor.compression_threshold = 0.82
|
||||
mock_compressor.return_value = compressed
|
||||
|
||||
# Mock should_compress to return True first time, False after
|
||||
call_count = 0
|
||||
|
||||
def mock_should_compress(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return call_count == 1
|
||||
|
||||
mock_compressor.should_compress = mock_should_compress
|
||||
manager.compressor = mock_compressor
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should be called
|
||||
mock_compressor.assert_called_once()
|
||||
# Result should be the compressed version
|
||||
assert len(result) <= len(messages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_with_zero_max_tokens(self):
|
||||
"""Test that compression is skipped when max_context_tokens is 0."""
|
||||
config = ContextConfig(max_context_tokens=0)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "x" * 10000)]
|
||||
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should not be called when max_context_tokens is 0
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_with_negative_max_tokens(self):
|
||||
"""Test that compression is skipped when max_context_tokens is negative."""
|
||||
config = ContextConfig(max_context_tokens=-100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "x" * 10000)]
|
||||
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should not be called
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_double_check_after_compression(self):
|
||||
"""Test that halving is applied if still over threshold after compression."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that would still be over threshold after compression
|
||||
long_messages = [self.create_message("user", "x" * 200) for _ in range(10)]
|
||||
|
||||
# Mock compressor to return messages still over threshold
|
||||
async def mock_compress(msgs):
|
||||
return msgs # Return same messages (still over limit)
|
||||
|
||||
# Mock should_compress to return True twice (before and after compression)
|
||||
with patch.object(manager.compressor, "should_compress", return_value=True):
|
||||
with patch.object(manager.compressor, "__call__", new=mock_compress):
|
||||
with patch.object(
|
||||
manager.truncator,
|
||||
"truncate_by_halving",
|
||||
return_value=long_messages[:5],
|
||||
) as mock_halving:
|
||||
_ = await manager.process(long_messages)
|
||||
|
||||
# Halving should be called
|
||||
mock_halving.assert_called_once()
|
||||
|
||||
# ==================== Combined Truncation and Compression Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_combined_enforce_turns_and_token_limit(self):
|
||||
"""Test combining enforce_max_turns and token limit."""
|
||||
config = ContextConfig(
|
||||
enforce_max_turns=5, max_context_tokens=500, truncate_turns=1
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create many messages
|
||||
messages = self.create_messages(30)
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should be truncated by both mechanisms
|
||||
assert len(result) < 30
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sequential_processing_order(self):
|
||||
"""Test that enforce_max_turns happens before token compression."""
|
||||
config = ContextConfig(enforce_max_turns=5, max_context_tokens=1000)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(20)
|
||||
|
||||
# Mock the truncator to track calls
|
||||
with patch.object(
|
||||
manager.truncator,
|
||||
"truncate_by_turns",
|
||||
wraps=manager.truncator.truncate_by_turns,
|
||||
) as mock_truncate:
|
||||
await manager.process(messages)
|
||||
|
||||
# Truncator should be called first
|
||||
mock_truncate.assert_called_once()
|
||||
|
||||
# ==================== Error Handling Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_returns_original_messages(self):
|
||||
"""Test that errors during processing return original messages."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(5)
|
||||
|
||||
# Make compressor raise an exception
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", side_effect=Exception("Test error")
|
||||
):
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should return original messages despite error
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_logs_exception(self):
|
||||
"""Test that errors are logged."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that will trigger compression (> 82 tokens)
|
||||
messages = [self.create_message("user", "x" * 300)] # ~90 tokens
|
||||
|
||||
# Replace compressor with one that raises an exception
|
||||
mock_compressor = AsyncMock(side_effect=Exception("Test error"))
|
||||
mock_compressor.compression_threshold = 0.82
|
||||
mock_compressor.should_compress = MagicMock(return_value=True)
|
||||
manager.compressor = mock_compressor
|
||||
|
||||
with patch("astrbot.core.agent.context.manager.logger") as mock_logger:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Logger error method should be called
|
||||
assert mock_logger.error.called
|
||||
# Should return original messages on error
|
||||
assert result == messages
|
||||
|
||||
# ==================== Multi-modal Content Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_messages_with_textpart_content(self):
|
||||
"""Test processing messages with TextPart content."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [
|
||||
Message(role="user", content=[TextPart(text="Hello")]),
|
||||
Message(role="assistant", content=[TextPart(text="Hi there")]),
|
||||
]
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_counting_with_multimodal_content(self):
|
||||
"""Test token counting works with multi-modal content."""
|
||||
config = ContextConfig(max_context_tokens=50)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Need enough tokens to exceed threshold: 50 * 0.82 = 41 tokens
|
||||
# 150 chars * 0.3 = 45 tokens > 41
|
||||
messages = [
|
||||
Message(role="user", content=[TextPart(text="x" * 150)]),
|
||||
]
|
||||
|
||||
# Should trigger compression due to token count
|
||||
tokens = manager.token_counter.count_tokens(messages)
|
||||
needs_compression = manager.compressor.should_compress(messages, tokens, 50)
|
||||
|
||||
assert tokens > 0 # Tokens should be counted
|
||||
assert needs_compression # Should trigger compression
|
||||
|
||||
# ==================== Tool Calls Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_messages_with_tool_calls(self):
|
||||
"""Test processing messages with tool calls."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [
|
||||
Message(
|
||||
role="assistant",
|
||||
content="Let me search for that",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
),
|
||||
Message(role="tool", content="Search result", tool_call_id="call_1"),
|
||||
]
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
# ==================== Compressor should_compress Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_compress_empty_messages(self):
|
||||
"""Test should_compress with empty messages."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Compressor's should_compress should handle empty gracefully
|
||||
needs_compression = manager.compressor.should_compress([], 0, 100)
|
||||
assert not needs_compression
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_compress_below_threshold(self):
|
||||
"""Test should_compress when below compression threshold."""
|
||||
config = ContextConfig(max_context_tokens=1000)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
tokens = manager.token_counter.count_tokens(messages)
|
||||
|
||||
needs_compression = manager.compressor.should_compress(messages, tokens, 1000)
|
||||
assert not needs_compression
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_compress_above_threshold(self):
|
||||
"""Test should_compress when above compression threshold."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create message with many tokens
|
||||
messages = [self.create_message("user", "这是测试" * 50)]
|
||||
tokens = manager.token_counter.count_tokens(messages)
|
||||
|
||||
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
|
||||
# Should need compression if tokens > 82 (0.82 * 100)
|
||||
assert needs_compression == (tokens > 82)
|
||||
|
||||
# ==================== Truncator Halving Tests ====================
|
||||
|
||||
def test_truncate_by_halving_basic(self):
|
||||
"""Test truncate_by_halving removes middle 50%."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(10)
|
||||
result = manager.truncator.truncate_by_halving(messages)
|
||||
|
||||
# Should keep roughly half
|
||||
assert len(result) < len(messages)
|
||||
|
||||
def test_truncate_by_halving_empty_list(self):
|
||||
"""Test truncate_by_halving with empty list."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
result = manager.truncator.truncate_by_halving([])
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_truncate_by_halving_single_message(self):
|
||||
"""Test truncate_by_halving with single message."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
result = manager.truncator.truncate_by_halving(messages)
|
||||
|
||||
assert len(result) <= 1
|
||||
|
||||
# ==================== Complex Scenarios ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_compression_cycles(self):
|
||||
"""Test that compression can be triggered multiple times in sequence."""
|
||||
config = ContextConfig(max_context_tokens=50, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Process messages multiple times
|
||||
messages = self.create_messages(10)
|
||||
|
||||
result1 = await manager.process(messages)
|
||||
result2 = await manager.process(result1)
|
||||
result3 = await manager.process(result2)
|
||||
|
||||
# Each cycle should maintain or reduce message count
|
||||
assert len(result3) <= len(result2) <= len(result1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alternating_roles_preserved(self):
|
||||
"""Test that user/assistant alternation is preserved after processing."""
|
||||
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(20)
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Check that roles still alternate (excluding system messages)
|
||||
non_system = [m for m in result if m.role != "system"]
|
||||
if len(non_system) >= 2:
|
||||
# Should start with user
|
||||
assert non_system[0].role == "user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compression_threshold_default(self):
|
||||
"""Test that compression threshold is used correctly."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Verify the default threshold is 0.82
|
||||
assert manager.compressor.compression_threshold == 0.82
|
||||
|
||||
# Test threshold logic
|
||||
messages = [self.create_message("user", "x" * 81)] # ~24 tokens
|
||||
tokens = manager.token_counter.count_tokens(messages)
|
||||
|
||||
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
|
||||
# Should not compress if below threshold
|
||||
assert needs_compression == (tokens > 82)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_batch_processing(self):
|
||||
"""Test processing a large batch of messages."""
|
||||
config = ContextConfig(
|
||||
enforce_max_turns=10, max_context_tokens=1000, truncate_turns=2
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create 100 messages (50 turns)
|
||||
messages = self.create_messages(100)
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should be significantly reduced
|
||||
assert len(result) < 100
|
||||
assert len(result) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_persistence(self):
|
||||
"""Test that config settings are respected throughout processing."""
|
||||
config = ContextConfig(
|
||||
max_context_tokens=500,
|
||||
enforce_max_turns=5,
|
||||
truncate_turns=2,
|
||||
llm_compress_keep_recent=3,
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Verify config is stored
|
||||
assert manager.config.max_context_tokens == 500
|
||||
assert manager.config.enforce_max_turns == 5
|
||||
assert manager.config.truncate_turns == 2
|
||||
assert manager.config.llm_compress_keep_recent == 3
|
||||
|
||||
# ==================== Run Compression Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_compression_calls_compressor(self):
|
||||
"""Test _run_compression calls compressor."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(5)
|
||||
compressed = self.create_messages(3)
|
||||
|
||||
# Create a mock compressor
|
||||
mock_compressor = AsyncMock()
|
||||
mock_compressor.compression_threshold = 0.82
|
||||
mock_compressor.return_value = compressed
|
||||
mock_compressor.should_compress = MagicMock(return_value=False)
|
||||
manager.compressor = mock_compressor
|
||||
|
||||
result = await manager._run_compression(messages, prev_tokens=100)
|
||||
|
||||
# Compressor __call__ should be invoked
|
||||
mock_compressor.assert_called_once_with(messages)
|
||||
assert result == compressed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_compression_applies_compressor_through_process(self):
|
||||
"""Test _run_compression calls compressor when needed through process()."""
|
||||
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that will trigger compression
|
||||
messages = [self.create_message("user", "x" * 300)] # ~90 tokens > 82 threshold
|
||||
compressed = [self.create_message("user", "short")] # Much smaller
|
||||
|
||||
# Create a mock compressor
|
||||
mock_compressor = AsyncMock()
|
||||
mock_compressor.compression_threshold = 0.82
|
||||
mock_compressor.return_value = compressed
|
||||
|
||||
# Mock should_compress to return True first time, False after
|
||||
call_count = 0
|
||||
|
||||
def mock_should_compress(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return call_count == 1
|
||||
|
||||
mock_compressor.should_compress = mock_should_compress
|
||||
manager.compressor = mock_compressor
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should have been called
|
||||
mock_compressor.assert_called_once()
|
||||
assert len(result) <= len(messages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_compression_with_mock_provider(self):
|
||||
"""Test LLM compression using MockProvider."""
|
||||
mock_provider = MockProvider()
|
||||
config = ContextConfig(
|
||||
llm_compress_provider=mock_provider, # type: ignore
|
||||
llm_compress_keep_recent=3,
|
||||
llm_compress_instruction="请总结对话内容",
|
||||
max_context_tokens=100,
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that will trigger compression
|
||||
messages = [
|
||||
self.create_message("user", "x" * 100),
|
||||
self.create_message("assistant", "y" * 100),
|
||||
self.create_message("user", "z" * 100),
|
||||
]
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should have been compressed
|
||||
assert len(result) <= len(messages)
|
||||
|
||||
# ==================== split_history Tests ====================
|
||||
|
||||
def test_split_history_ensures_user_start(self):
|
||||
"""Test split_history ensures recent_messages starts with user message."""
|
||||
from astrbot.core.agent.context.compressor import split_history
|
||||
|
||||
# Create alternating messages: user, assistant, user, assistant, user, assistant
|
||||
messages = [
|
||||
self.create_message("system", "System prompt"),
|
||||
self.create_message("user", "msg1"),
|
||||
self.create_message("assistant", "msg2"),
|
||||
self.create_message("user", "msg3"),
|
||||
self.create_message("assistant", "msg4"),
|
||||
self.create_message("user", "msg5"),
|
||||
self.create_message("assistant", "msg6"),
|
||||
]
|
||||
|
||||
# Keep recent 3 messages - should adjust to start with user
|
||||
system, to_summarize, recent = split_history(messages, keep_recent=3)
|
||||
|
||||
# recent_messages should start with user message
|
||||
assert len(recent) > 0
|
||||
assert recent[0].role == "user"
|
||||
|
||||
# messages_to_summarize should end with assistant (complete turn)
|
||||
if len(to_summarize) > 0:
|
||||
assert to_summarize[-1].role == "assistant"
|
||||
|
||||
def test_split_history_handles_assistant_at_split_point(self):
|
||||
"""Test split_history when assistant message is at the intended split point."""
|
||||
from astrbot.core.agent.context.compressor import split_history
|
||||
|
||||
messages = [
|
||||
self.create_message("user", "msg1"),
|
||||
self.create_message("assistant", "msg2"),
|
||||
self.create_message("user", "msg3"),
|
||||
self.create_message("assistant", "msg4"), # <- intended split here
|
||||
self.create_message("user", "msg5"),
|
||||
self.create_message("assistant", "msg6"),
|
||||
]
|
||||
|
||||
# keep_recent=2 would normally split at index 4 (assistant msg4)
|
||||
# Should move back to include from msg5 (user)
|
||||
system, to_summarize, recent = split_history(messages, keep_recent=2)
|
||||
|
||||
# recent should start with user message
|
||||
assert recent[0].role == "user"
|
||||
assert recent[0].content == "msg5"
|
||||
|
||||
def test_split_history_all_assistant_messages(self):
|
||||
"""Test split_history when there are consecutive assistant messages."""
|
||||
from astrbot.core.agent.context.compressor import split_history
|
||||
|
||||
messages = [
|
||||
self.create_message("user", "msg1"),
|
||||
self.create_message("assistant", "msg2"),
|
||||
self.create_message("assistant", "msg3"),
|
||||
self.create_message("assistant", "msg4"),
|
||||
]
|
||||
|
||||
system, to_summarize, recent = split_history(messages, keep_recent=2)
|
||||
|
||||
# Should find the user message and keep from there
|
||||
if len(recent) > 0:
|
||||
# Find first user message backwards
|
||||
assert any(m.role == "user" for m in messages)
|
||||
|
||||
def test_split_history_with_system_messages(self):
|
||||
"""Test split_history preserves system messages separately."""
|
||||
from astrbot.core.agent.context.compressor import split_history
|
||||
|
||||
messages = [
|
||||
self.create_message("system", "System 1"),
|
||||
self.create_message("system", "System 2"),
|
||||
self.create_message("user", "msg1"),
|
||||
self.create_message("assistant", "msg2"),
|
||||
self.create_message("user", "msg3"),
|
||||
]
|
||||
|
||||
system, to_summarize, recent = split_history(messages, keep_recent=2)
|
||||
|
||||
# System messages should be separate
|
||||
assert len(system) == 2
|
||||
assert all(m.role == "system" for m in system)
|
||||
|
||||
# Recent should start with user
|
||||
if len(recent) > 0:
|
||||
assert recent[0].role == "user"
|
||||
@@ -0,0 +1,423 @@
|
||||
"""Tests for ContextTruncator."""
|
||||
|
||||
from astrbot.core.agent.context.truncator import ContextTruncator
|
||||
from astrbot.core.agent.message import Message
|
||||
|
||||
|
||||
class TestContextTruncator:
|
||||
"""Test suite for ContextTruncator."""
|
||||
|
||||
def create_message(self, role: str, content: str = "test content") -> Message:
|
||||
"""Helper to create a simple test message."""
|
||||
return Message(role=role, content=content)
|
||||
|
||||
def create_messages(
|
||||
self, count: int, include_system: bool = False
|
||||
) -> list[Message]:
|
||||
"""Helper to create alternating user/assistant messages.
|
||||
|
||||
Args:
|
||||
count: Number of messages to create
|
||||
include_system: Whether to include a system message at the start
|
||||
|
||||
Returns:
|
||||
List of messages
|
||||
"""
|
||||
messages = []
|
||||
if include_system:
|
||||
messages.append(self.create_message("system", "System prompt"))
|
||||
|
||||
for i in range(count):
|
||||
role = "user" if i % 2 == 0 else "assistant"
|
||||
messages.append(self.create_message(role, f"Message {i}"))
|
||||
return messages
|
||||
|
||||
# ==================== fix_messages Tests ====================
|
||||
|
||||
def test_fix_messages_empty_list(self):
|
||||
"""Test fix_messages with an empty list."""
|
||||
truncator = ContextTruncator()
|
||||
result = truncator.fix_messages([])
|
||||
assert result == []
|
||||
|
||||
def test_fix_messages_normal_messages(self):
|
||||
"""Test fix_messages with normal user/assistant messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Hello"),
|
||||
self.create_message("assistant", "Hi"),
|
||||
self.create_message("user", "How are you?"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 3
|
||||
assert result == messages
|
||||
|
||||
def test_fix_messages_tool_with_valid_context(self):
|
||||
"""Test fix_messages with tool message after user+assistant."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 3
|
||||
assert result == messages
|
||||
|
||||
def test_fix_messages_tool_without_context(self):
|
||||
"""Test fix_messages with tool message without enough context."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
# Tool message without context should be removed
|
||||
assert len(result) == 0
|
||||
|
||||
def test_fix_messages_tool_with_only_one_message(self):
|
||||
"""Test fix_messages with tool message after only one message."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Hello"),
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
# Tool message without enough context should be removed
|
||||
assert len(result) == 0
|
||||
|
||||
def test_fix_messages_multiple_tools(self):
|
||||
"""Test fix_messages with multiple tool messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool 1 result"),
|
||||
self.create_message("tool", "Tool 2 result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 4
|
||||
assert result == messages
|
||||
|
||||
def test_fix_messages_mixed_system_tool(self):
|
||||
"""Test fix_messages with system message and tool messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("system", "System prompt"),
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 4
|
||||
assert result == messages
|
||||
|
||||
# ==================== truncate_by_turns Tests ====================
|
||||
|
||||
def test_truncate_by_turns_no_limit(self):
|
||||
"""Test truncate_by_turns with -1 (no limit)."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_turns(messages, keep_most_recent_turns=-1)
|
||||
assert len(result) == 20
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_turns_basic(self):
|
||||
"""Test basic truncate_by_turns functionality."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 10 messages = 5 turns (user/assistant pairs)
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=3, drop_turns=1
|
||||
)
|
||||
|
||||
# Should keep 3 most recent turns (6 messages)
|
||||
assert len(result) <= 8 # (3-1+1)*2 = 6, but may adjust for correct format
|
||||
|
||||
def test_truncate_by_turns_with_system_message(self):
|
||||
"""Test truncate_by_turns preserves system messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10, include_system=True)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=2, drop_turns=1
|
||||
)
|
||||
|
||||
# System message should always be preserved
|
||||
assert result[0].role == "system"
|
||||
assert result[0].content == "System prompt"
|
||||
|
||||
def test_truncate_by_turns_zero_keep(self):
|
||||
"""Test truncate_by_turns with keep_most_recent_turns=0."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=0, drop_turns=1
|
||||
)
|
||||
|
||||
# Should result in empty or minimal list
|
||||
assert len(result) == 0
|
||||
|
||||
def test_truncate_by_turns_below_threshold(self):
|
||||
"""Test truncate_by_turns when messages are below threshold."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 4 messages = 2 turns
|
||||
messages = self.create_messages(4)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=5, drop_turns=1
|
||||
)
|
||||
|
||||
# No truncation should happen
|
||||
assert len(result) == 4
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_turns_exact_threshold(self):
|
||||
"""Test truncate_by_turns when messages exactly match threshold."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 6 messages = 3 turns
|
||||
messages = self.create_messages(6)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=3, drop_turns=1
|
||||
)
|
||||
|
||||
# No truncation should happen
|
||||
assert len(result) == 6
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_turns_ensures_user_first(self):
|
||||
"""Test that truncate_by_turns ensures user message comes first."""
|
||||
truncator = ContextTruncator()
|
||||
# Create scenario where truncation might start with assistant
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=3, drop_turns=1
|
||||
)
|
||||
|
||||
# First non-system message should be user
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_turns_multiple_drop(self):
|
||||
"""Test truncate_by_turns with multiple turns dropped at once."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=5, drop_turns=3
|
||||
)
|
||||
|
||||
# Should drop 3 turns when limit exceeded
|
||||
assert len(result) < len(messages)
|
||||
|
||||
# ==================== truncate_by_dropping_oldest_turns Tests ====================
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_zero(self):
|
||||
"""Test truncate_by_dropping_oldest_turns with drop_turns=0."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=0)
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_negative(self):
|
||||
"""Test truncate_by_dropping_oldest_turns with negative drop_turns."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=-1)
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_basic(self):
|
||||
"""Test basic truncate_by_dropping_oldest_turns functionality."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 10 messages = 5 turns
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
|
||||
|
||||
# Should drop 2 oldest turns (4 messages)
|
||||
assert len(result) == 6
|
||||
# Should start with user message
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_with_system(self):
|
||||
"""Test truncate_by_dropping_oldest_turns preserves system messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10, include_system=True)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
|
||||
|
||||
# System message should be preserved
|
||||
assert result[0].role == "system"
|
||||
assert result[0].content == "System prompt"
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_drop_all(self):
|
||||
"""Test truncate_by_dropping_oldest_turns dropping all turns."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 4 messages = 2 turns
|
||||
messages = self.create_messages(4)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
|
||||
|
||||
# Should drop all turns
|
||||
assert len(result) == 0
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self):
|
||||
"""Test truncate_by_dropping_oldest_turns with drop_turns > available turns."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 4 messages = 2 turns
|
||||
messages = self.create_messages(4)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=5)
|
||||
|
||||
# Should result in empty list
|
||||
assert len(result) == 0
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_ensures_user_first(self):
|
||||
"""Test that result starts with user message after dropping."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=3)
|
||||
|
||||
# First message should be user
|
||||
if len(result) > 0:
|
||||
assert result[0].role == "user"
|
||||
|
||||
# ==================== truncate_by_halving Tests ====================
|
||||
|
||||
def test_truncate_by_halving_empty(self):
|
||||
"""Test truncate_by_halving with empty list."""
|
||||
truncator = ContextTruncator()
|
||||
result = truncator.truncate_by_halving([])
|
||||
assert result == []
|
||||
|
||||
def test_truncate_by_halving_single_message(self):
|
||||
"""Test truncate_by_halving with single message."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
# Should not truncate if <= 2 messages
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_halving_two_messages(self):
|
||||
"""Test truncate_by_halving with two messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(2)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
# Should not truncate if <= 2 messages
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_halving_basic(self):
|
||||
"""Test basic truncate_by_halving functionality."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 20 messages
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# Should delete 50% = 10 messages, keep 10
|
||||
assert len(result) == 10
|
||||
# First message should be user
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_halving_with_system_message(self):
|
||||
"""Test truncate_by_halving preserves system messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(20, include_system=True)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# System message should be preserved
|
||||
assert result[0].role == "system"
|
||||
assert result[0].content == "System prompt"
|
||||
|
||||
def test_truncate_by_halving_odd_count(self):
|
||||
"""Test truncate_by_halving with odd number of messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(11)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# Should delete floor(11/2) = 5 messages, keep 6
|
||||
# But after ensuring user first, may be 5
|
||||
assert len(result) >= 5
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_halving_ensures_user_first(self):
|
||||
"""Test that result starts with user message."""
|
||||
truncator = ContextTruncator()
|
||||
# Create messages starting with user
|
||||
messages = self.create_messages(30)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# First message should be user
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_halving_preserves_recent_messages(self):
|
||||
"""Test that truncate_by_halving keeps the most recent 50%."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Message 0"),
|
||||
self.create_message("assistant", "Message 1"),
|
||||
self.create_message("user", "Message 2"),
|
||||
self.create_message("assistant", "Message 3"),
|
||||
]
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# Should keep last 2 messages
|
||||
assert len(result) == 2
|
||||
assert result[0].content == "Message 2"
|
||||
assert result[1].content == "Message 3"
|
||||
|
||||
# ==================== Integration Tests ====================
|
||||
|
||||
def test_truncate_with_tool_messages(self):
|
||||
"""Test truncation with tool messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool result"),
|
||||
self.create_message("user", "Thanks"),
|
||||
self.create_message("assistant", "Welcome"),
|
||||
]
|
||||
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=1)
|
||||
|
||||
# First turn (user+assistant+tool) should be dropped
|
||||
# Tool message should be cleaned up by fix_messages
|
||||
assert len(result) <= 2
|
||||
|
||||
def test_chain_multiple_truncations(self):
|
||||
"""Test chaining multiple truncation methods."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(40, include_system=True)
|
||||
|
||||
# First: truncate by turns
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=10, drop_turns=2
|
||||
)
|
||||
# Then: halve
|
||||
result = truncator.truncate_by_halving(result)
|
||||
|
||||
# Should have system message + truncated content
|
||||
assert result[0].role == "system"
|
||||
assert len(result) < len(messages)
|
||||
|
||||
def test_empty_after_system_message(self):
|
||||
"""Test truncation when only system message exists."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [self.create_message("system", "System prompt")]
|
||||
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=5, drop_turns=1
|
||||
)
|
||||
|
||||
# Should keep system message
|
||||
assert len(result) == 1
|
||||
assert result[0].role == "system"
|
||||
|
||||
def test_all_system_messages(self):
|
||||
"""Test truncation with only system messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("system", "System 1"),
|
||||
self.create_message("system", "System 2"),
|
||||
]
|
||||
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=0, drop_turns=1
|
||||
)
|
||||
|
||||
# System messages should be preserved, but since there are no non-system
|
||||
# messages and keep_most_recent_turns=0, result should be system messages only
|
||||
assert len(result) >= 0 # May keep system messages or clear all
|
||||
if len(result) > 0:
|
||||
assert all(msg.role == "system" for msg in result)
|
||||
Reference in New Issue
Block a user