Compare commits
27 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 82a96a8cce | |||
| 343b153263 | |||
| 3a41b19318 | |||
| af444ea6cc | |||
| cb84db532e | |||
| 99b82f48ec | |||
| 00471f904e | |||
| 5df15c60ff | |||
| 32e523b7da | |||
| 0de4fd9f0d | |||
| e23a7e2505 | |||
| 1ed4d9f484 | |||
| d842155770 | |||
| 7f5cc7cf1a | |||
| f26867c77d | |||
| a14d588b44 | |||
| e236402d92 | |||
| 454841de10 | |||
| 442b5403df | |||
| 9db7bf59b8 | |||
| 3622504021 | |||
| fc42db40ce | |||
| e413a002c1 | |||
| 6437d759a3 | |||
| c758b2d888 | |||
| 510290fe0e | |||
| c61d62edb6 |
@@ -26,6 +26,7 @@ jobs:
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
operations-per-run: 200
|
||||
|
||||
# 只处理带 bug 标签的 Issue
|
||||
any-of-labels: 'bug'
|
||||
|
||||
@@ -132,6 +132,7 @@ uv run main.py
|
||||
|
||||
**社区维护**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
@@ -208,6 +209,7 @@ pre-commit install
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 7 群:743746109
|
||||
- 8 群:1030353265
|
||||
- 开发者群:975206796
|
||||
|
||||
### Telegram 群组
|
||||
|
||||
@@ -134,6 +134,7 @@ Or refer to the official documentation: [Deploy AstrBot from Source](https://ast
|
||||
|
||||
**Community Maintained**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili Direct Messages](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
|
||||
@@ -134,6 +134,7 @@ Ou consultez la documentation officielle : [Déployer AstrBot depuis les sources
|
||||
|
||||
**Maintenues par la communauté**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Messages directs Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
|
||||
@@ -134,6 +134,7 @@ uv run main.py
|
||||
|
||||
**コミュニティメンテナンス**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili ダイレクトメッセージ](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
|
||||
@@ -134,6 +134,7 @@ uv run main.py
|
||||
|
||||
**Поддерживаемые сообществом**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Личные сообщения Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
|
||||
@@ -134,6 +134,7 @@ uv run main.py
|
||||
|
||||
**社群維護**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili 私訊](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
|
||||
@@ -21,6 +21,9 @@ from astrbot.core.star.register import (
|
||||
from astrbot.core.star.register import register_on_llm_request as on_llm_request
|
||||
from astrbot.core.star.register import register_on_llm_response as on_llm_response
|
||||
from astrbot.core.star.register import register_on_platform_loaded as on_platform_loaded
|
||||
from astrbot.core.star.register import (
|
||||
register_on_waiting_llm_request as on_waiting_llm_request,
|
||||
)
|
||||
from astrbot.core.star.register import register_permission_type as permission_type
|
||||
from astrbot.core.star.register import (
|
||||
register_platform_adapter_type as platform_adapter_type,
|
||||
@@ -46,6 +49,7 @@ __all__ = [
|
||||
"on_llm_request",
|
||||
"on_llm_response",
|
||||
"on_platform_loaded",
|
||||
"on_waiting_llm_request",
|
||||
"permission_type",
|
||||
"platform_adapter_type",
|
||||
"regex",
|
||||
|
||||
@@ -14,13 +14,13 @@ class TTSCommand:
|
||||
async def tts(self, event: AstrMessageEvent):
|
||||
"""开关文本转语音(会话级别)"""
|
||||
umo = event.unified_msg_origin
|
||||
ses_tts = SessionServiceManager.is_tts_enabled_for_session(umo)
|
||||
ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo)
|
||||
cfg = self.context.get_config(umo=umo)
|
||||
tts_enable = cfg["provider_tts_settings"]["enable"]
|
||||
|
||||
# 切换状态
|
||||
new_status = not ses_tts
|
||||
SessionServiceManager.set_tts_status_for_session(umo, new_status)
|
||||
await SessionServiceManager.set_tts_status_for_session(umo, new_status)
|
||||
|
||||
status_text = "已开启" if new_status else "已关闭"
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.10.4"
|
||||
__version__ = "4.11.0"
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||
|
||||
from ..message import Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot import logger
|
||||
else:
|
||||
try:
|
||||
from astrbot import logger
|
||||
except ImportError:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.truncator import ContextTruncator
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ContextCompressor(Protocol):
|
||||
"""
|
||||
Protocol for context compressors.
|
||||
Provides an interface for compressing message lists.
|
||||
"""
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens for the model.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
"""Compress the message list.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
|
||||
Returns:
|
||||
The compressed message list.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class TruncateByTurnsCompressor:
|
||||
"""Truncate by turns compressor implementation.
|
||||
Truncates the message list by removing older turns.
|
||||
"""
|
||||
|
||||
def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
|
||||
"""Initialize the truncate by turns compressor.
|
||||
|
||||
Args:
|
||||
truncate_turns: The number of turns to remove when truncating (default: 1).
|
||||
compression_threshold: The compression trigger threshold (default: 0.82).
|
||||
"""
|
||||
self.truncate_turns = truncate_turns
|
||||
self.compression_threshold = compression_threshold
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
if max_tokens <= 0 or current_tokens <= 0:
|
||||
return False
|
||||
usage_rate = current_tokens / max_tokens
|
||||
return usage_rate > self.compression_threshold
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
truncator = ContextTruncator()
|
||||
truncated_messages = truncator.truncate_by_dropping_oldest_turns(
|
||||
messages,
|
||||
drop_turns=self.truncate_turns,
|
||||
)
|
||||
return truncated_messages
|
||||
|
||||
|
||||
def split_history(
|
||||
messages: list[Message], keep_recent: int
|
||||
) -> tuple[list[Message], list[Message], list[Message]]:
|
||||
"""Split the message list into system messages, messages to summarize, and recent messages.
|
||||
|
||||
Ensures that the split point is between complete user-assistant pairs to maintain conversation flow.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
keep_recent: The number of latest messages to keep.
|
||||
|
||||
Returns:
|
||||
tuple: (system_messages, messages_to_summarize, recent_messages)
|
||||
"""
|
||||
# keep the system messages
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) <= keep_recent:
|
||||
return system_messages, [], non_system_messages
|
||||
|
||||
# Find the split point, ensuring recent_messages starts with a user message
|
||||
# This maintains complete conversation turns
|
||||
split_index = len(non_system_messages) - keep_recent
|
||||
|
||||
# Search backward from split_index to find the first user message
|
||||
# This ensures recent_messages starts with a user message (complete turn)
|
||||
while split_index > 0 and non_system_messages[split_index].role != "user":
|
||||
# TODO: +=1 or -=1 ? calculate by tokens
|
||||
split_index -= 1
|
||||
|
||||
# If we couldn't find a user message, keep all messages as recent
|
||||
if split_index == 0:
|
||||
return system_messages, [], non_system_messages
|
||||
|
||||
messages_to_summarize = non_system_messages[:split_index]
|
||||
recent_messages = non_system_messages[split_index:]
|
||||
|
||||
return system_messages, messages_to_summarize, recent_messages
|
||||
|
||||
|
||||
class LLMSummaryCompressor:
|
||||
"""LLM-based summary compressor.
|
||||
Uses LLM to summarize the old conversation history, keeping the latest messages.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: "Provider",
|
||||
keep_recent: int = 4,
|
||||
instruction_text: str | None = None,
|
||||
compression_threshold: float = 0.82,
|
||||
):
|
||||
"""Initialize the LLM summary compressor.
|
||||
|
||||
Args:
|
||||
provider: The LLM provider instance.
|
||||
keep_recent: The number of latest messages to keep (default: 4).
|
||||
instruction_text: Custom instruction for summary generation.
|
||||
compression_threshold: The compression trigger threshold (default: 0.82).
|
||||
"""
|
||||
self.provider = provider
|
||||
self.keep_recent = keep_recent
|
||||
self.compression_threshold = compression_threshold
|
||||
|
||||
self.instruction_text = instruction_text or (
|
||||
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
|
||||
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
|
||||
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
|
||||
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
|
||||
"4. Write the summary in the user's language.\n"
|
||||
)
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
if max_tokens <= 0 or current_tokens <= 0:
|
||||
return False
|
||||
usage_rate = current_tokens / max_tokens
|
||||
return usage_rate > self.compression_threshold
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
"""Use LLM to generate a summary of the conversation history.
|
||||
|
||||
Process:
|
||||
1. Divide messages: keep the system message and the latest N messages.
|
||||
2. Send the old messages + the instruction message to the LLM.
|
||||
3. Reconstruct the message list: [system message, summary message, latest messages].
|
||||
"""
|
||||
if len(messages) <= self.keep_recent + 1:
|
||||
return messages
|
||||
|
||||
system_messages, messages_to_summarize, recent_messages = split_history(
|
||||
messages, self.keep_recent
|
||||
)
|
||||
|
||||
if not messages_to_summarize:
|
||||
return messages
|
||||
|
||||
# build payload
|
||||
instruction_message = Message(role="user", content=self.instruction_text)
|
||||
llm_payload = messages_to_summarize + [instruction_message]
|
||||
|
||||
# generate summary
|
||||
try:
|
||||
response = await self.provider.text_chat(contexts=llm_payload)
|
||||
summary_content = response.completion_text
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate summary: {e}")
|
||||
return messages
|
||||
|
||||
# build result
|
||||
result = []
|
||||
result.extend(system_messages)
|
||||
|
||||
result.append(
|
||||
Message(
|
||||
role="user",
|
||||
content=f"Our previous history conversation summary: {summary_content}",
|
||||
)
|
||||
)
|
||||
result.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content="Acknowledged the summary of our previous conversation history.",
|
||||
)
|
||||
)
|
||||
|
||||
result.extend(recent_messages)
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,35 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .compressor import ContextCompressor
|
||||
from .token_counter import TokenCounter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextConfig:
|
||||
"""Context configuration class."""
|
||||
|
||||
max_context_tokens: int = 0
|
||||
"""Maximum number of context tokens. <= 0 means no limit."""
|
||||
enforce_max_turns: int = -1 # -1 means no limit
|
||||
"""Maximum number of conversation turns to keep. -1 means no limit. Executed before compression."""
|
||||
truncate_turns: int = 1
|
||||
"""Number of conversation turns to discard at once when truncation is triggered.
|
||||
Two processes will use this value:
|
||||
|
||||
1. Enforce max turns truncation.
|
||||
2. Truncation by turns compression strategy.
|
||||
"""
|
||||
llm_compress_instruction: str | None = None
|
||||
"""Instruction prompt for LLM-based compression."""
|
||||
llm_compress_keep_recent: int = 0
|
||||
"""Number of recent messages to keep during LLM-based compression."""
|
||||
llm_compress_provider: "Provider | None" = None
|
||||
"""LLM provider used for compression tasks. If None, truncation strategy is used."""
|
||||
custom_token_counter: TokenCounter | None = None
|
||||
"""Custom token counting method. If None, the default method is used."""
|
||||
custom_compressor: ContextCompressor | None = None
|
||||
"""Custom context compression method. If None, the default method is used."""
|
||||
@@ -0,0 +1,120 @@
|
||||
from astrbot import logger
|
||||
|
||||
from ..message import Message
|
||||
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
|
||||
from .config import ContextConfig
|
||||
from .token_counter import EstimateTokenCounter
|
||||
from .truncator import ContextTruncator
|
||||
|
||||
|
||||
class ContextManager:
|
||||
"""Context compression manager."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ContextConfig,
|
||||
):
|
||||
"""Initialize the context manager.
|
||||
|
||||
There are two strategies to handle context limit reached:
|
||||
1. Truncate by turns: remove older messages by turns.
|
||||
2. LLM-based compression: use LLM to summarize old messages.
|
||||
|
||||
Args:
|
||||
config: The context configuration.
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
self.token_counter = config.custom_token_counter or EstimateTokenCounter()
|
||||
self.truncator = ContextTruncator()
|
||||
|
||||
if config.custom_compressor:
|
||||
self.compressor = config.custom_compressor
|
||||
elif config.llm_compress_provider:
|
||||
self.compressor = LLMSummaryCompressor(
|
||||
provider=config.llm_compress_provider,
|
||||
keep_recent=config.llm_compress_keep_recent,
|
||||
instruction_text=config.llm_compress_instruction,
|
||||
)
|
||||
else:
|
||||
self.compressor = TruncateByTurnsCompressor(
|
||||
truncate_turns=config.truncate_turns
|
||||
)
|
||||
|
||||
async def process(
|
||||
self, messages: list[Message], trusted_token_usage: int = 0
|
||||
) -> list[Message]:
|
||||
"""Process the messages.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
|
||||
Returns:
|
||||
The processed message list.
|
||||
"""
|
||||
try:
|
||||
result = messages
|
||||
|
||||
# 1. 基于轮次的截断 (Enforce max turns)
|
||||
if self.config.enforce_max_turns != -1:
|
||||
result = self.truncator.truncate_by_turns(
|
||||
result,
|
||||
keep_most_recent_turns=self.config.enforce_max_turns,
|
||||
drop_turns=self.config.truncate_turns,
|
||||
)
|
||||
|
||||
# 2. 基于 token 的压缩
|
||||
if self.config.max_context_tokens > 0:
|
||||
total_tokens = self.token_counter.count_tokens(
|
||||
result, trusted_token_usage
|
||||
)
|
||||
|
||||
if self.compressor.should_compress(
|
||||
result, total_tokens, self.config.max_context_tokens
|
||||
):
|
||||
result = await self._run_compression(result, total_tokens)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error during context processing: {e}", exc_info=True)
|
||||
return messages
|
||||
|
||||
async def _run_compression(
|
||||
self, messages: list[Message], prev_tokens: int
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Compress/truncate the messages.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
prev_tokens: The token count before compression.
|
||||
|
||||
Returns:
|
||||
The compressed/truncated message list.
|
||||
"""
|
||||
logger.debug("Compress triggered, starting compression...")
|
||||
|
||||
messages = await self.compressor(messages)
|
||||
|
||||
# double check
|
||||
tokens_after_summary = self.token_counter.count_tokens(messages)
|
||||
|
||||
# calculate compress rate
|
||||
compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100
|
||||
logger.info(
|
||||
f"Compress completed."
|
||||
f" {prev_tokens} -> {tokens_after_summary} tokens,"
|
||||
f" compression rate: {compress_rate:.2f}%.",
|
||||
)
|
||||
|
||||
# last check
|
||||
if self.compressor.should_compress(
|
||||
messages, tokens_after_summary, self.config.max_context_tokens
|
||||
):
|
||||
logger.info(
|
||||
"Context still exceeds max tokens after compression, applying halving truncation..."
|
||||
)
|
||||
# still need compress, truncate by half
|
||||
messages = self.truncator.truncate_by_halving(messages)
|
||||
|
||||
return messages
|
||||
@@ -0,0 +1,64 @@
|
||||
import json
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from ..message import Message, TextPart
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TokenCounter(Protocol):
|
||||
"""
|
||||
Protocol for token counters.
|
||||
Provides an interface for counting tokens in message lists.
|
||||
"""
|
||||
|
||||
def count_tokens(
|
||||
self, messages: list[Message], trusted_token_usage: int = 0
|
||||
) -> int:
|
||||
"""Count the total tokens in the message list.
|
||||
|
||||
Args:
|
||||
messages: The message list.
|
||||
trusted_token_usage: The total token usage that LLM API returned.
|
||||
For some cases, this value is more accurate.
|
||||
But some API does not return it, so the value defaults to 0.
|
||||
|
||||
Returns:
|
||||
The total token count.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class EstimateTokenCounter:
|
||||
"""Estimate token counter implementation.
|
||||
Provides a simple estimation of token count based on character types.
|
||||
"""
|
||||
|
||||
def count_tokens(
|
||||
self, messages: list[Message], trusted_token_usage: int = 0
|
||||
) -> int:
|
||||
if trusted_token_usage > 0:
|
||||
return trusted_token_usage
|
||||
|
||||
total = 0
|
||||
for msg in messages:
|
||||
content = msg.content
|
||||
if isinstance(content, str):
|
||||
total += self._estimate_tokens(content)
|
||||
elif isinstance(content, list):
|
||||
# 处理多模态内容
|
||||
for part in content:
|
||||
if isinstance(part, TextPart):
|
||||
total += self._estimate_tokens(part.text)
|
||||
|
||||
# 处理 Tool Calls
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump())
|
||||
total += self._estimate_tokens(tc_str)
|
||||
|
||||
return total
|
||||
|
||||
def _estimate_tokens(self, text: str) -> int:
|
||||
chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"])
|
||||
other_count = len(text) - chinese_count
|
||||
return int(chinese_count * 0.6 + other_count * 0.3)
|
||||
@@ -0,0 +1,141 @@
|
||||
from ..message import Message
|
||||
|
||||
|
||||
class ContextTruncator:
|
||||
"""Context truncator."""
|
||||
|
||||
def fix_messages(self, messages: list[Message]) -> list[Message]:
|
||||
fixed_messages = []
|
||||
for message in messages:
|
||||
if message.role == "tool":
|
||||
# tool block 前面必须要有 user 和 assistant block
|
||||
if len(fixed_messages) < 2:
|
||||
# 这种情况可能是上下文被截断导致的
|
||||
# 我们直接将之前的上下文都清空
|
||||
fixed_messages = []
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
return fixed_messages
|
||||
|
||||
def truncate_by_turns(
|
||||
self,
|
||||
messages: list[Message],
|
||||
keep_most_recent_turns: int,
|
||||
drop_turns: int = 1,
|
||||
) -> list[Message]:
|
||||
"""截断上下文列表,确保不超过最大长度。
|
||||
一个 turn 包含一个 user 消息和一个 assistant 消息。
|
||||
这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。
|
||||
|
||||
Args:
|
||||
messages: 上下文列表
|
||||
keep_most_recent_turns: 保留最近的对话轮数
|
||||
drop_turns: 一次性丢弃的对话轮数
|
||||
|
||||
Returns:
|
||||
截断后的上下文列表
|
||||
"""
|
||||
if keep_most_recent_turns == -1:
|
||||
return messages
|
||||
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) // 2 <= keep_most_recent_turns:
|
||||
return messages
|
||||
|
||||
num_to_keep = keep_most_recent_turns - drop_turns + 1
|
||||
if num_to_keep <= 0:
|
||||
truncated_contexts = []
|
||||
else:
|
||||
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
|
||||
|
||||
# 找到第一个 role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
truncated_contexts = truncated_contexts[index:]
|
||||
|
||||
result = system_messages + truncated_contexts
|
||||
|
||||
return self.fix_messages(result)
|
||||
|
||||
def truncate_by_dropping_oldest_turns(
|
||||
self,
|
||||
messages: list[Message],
|
||||
drop_turns: int = 1,
|
||||
) -> list[Message]:
|
||||
"""丢弃最旧的 N 个对话轮次。"""
|
||||
if drop_turns <= 0:
|
||||
return messages
|
||||
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) // 2 <= drop_turns:
|
||||
truncated_non_system = []
|
||||
else:
|
||||
truncated_non_system = non_system_messages[drop_turns * 2 :]
|
||||
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None:
|
||||
truncated_non_system = truncated_non_system[index:]
|
||||
elif truncated_non_system:
|
||||
truncated_non_system = []
|
||||
|
||||
result = system_messages + truncated_non_system
|
||||
|
||||
return self.fix_messages(result)
|
||||
|
||||
def truncate_by_halving(
|
||||
self,
|
||||
messages: list[Message],
|
||||
) -> list[Message]:
|
||||
"""对半砍策略,删除 50% 的消息"""
|
||||
if len(messages) <= 2:
|
||||
return messages
|
||||
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
messages_to_delete = len(non_system_messages) // 2
|
||||
if messages_to_delete == 0:
|
||||
return messages
|
||||
|
||||
truncated_non_system = non_system_messages[messages_to_delete:]
|
||||
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None:
|
||||
truncated_non_system = truncated_non_system[index:]
|
||||
|
||||
result = system_messages + truncated_non_system
|
||||
|
||||
return self.fix_messages(result)
|
||||
@@ -25,6 +25,10 @@ from astrbot.core.provider.entities import (
|
||||
)
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.compressor import ContextCompressor
|
||||
from ..context.config import ContextConfig
|
||||
from ..context.manager import ContextManager
|
||||
from ..context.token_counter import TokenCounter
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||
from ..response import AgentResponseData, AgentStats
|
||||
@@ -47,10 +51,47 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
streaming: bool = False,
|
||||
# enforce max turns, will discard older turns when exceeded BEFORE compression
|
||||
# -1 means no limit
|
||||
enforce_max_turns: int = -1,
|
||||
# llm compressor
|
||||
llm_compress_instruction: str | None = None,
|
||||
llm_compress_keep_recent: int = 0,
|
||||
llm_compress_provider: Provider | None = None,
|
||||
# truncate by turns compressor
|
||||
truncate_turns: int = 1,
|
||||
# customize
|
||||
custom_token_counter: TokenCounter | None = None,
|
||||
custom_compressor: ContextCompressor | None = None,
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
self.streaming = streaming
|
||||
self.enforce_max_turns = enforce_max_turns
|
||||
self.llm_compress_instruction = llm_compress_instruction
|
||||
self.llm_compress_keep_recent = llm_compress_keep_recent
|
||||
self.llm_compress_provider = llm_compress_provider
|
||||
self.truncate_turns = truncate_turns
|
||||
self.custom_token_counter = custom_token_counter
|
||||
self.custom_compressor = custom_compressor
|
||||
# we will do compress when:
|
||||
# 1. before requesting LLM
|
||||
# TODO: 2. after LLM output a tool call
|
||||
self.context_config = ContextConfig(
|
||||
# <=0 will never do compress
|
||||
max_context_tokens=provider.provider_config.get("max_context_tokens", 0),
|
||||
# enforce max turns before compression
|
||||
enforce_max_turns=self.enforce_max_turns,
|
||||
truncate_turns=self.truncate_turns,
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider=self.llm_compress_provider,
|
||||
custom_token_counter=self.custom_token_counter,
|
||||
custom_compressor=self.custom_compressor,
|
||||
)
|
||||
self.context_manager = ContextManager(self.context_config)
|
||||
|
||||
self.provider = provider
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
@@ -110,6 +151,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
llm_resp_result = None
|
||||
|
||||
# do truncate and compress
|
||||
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
|
||||
self.run_context.messages = await self.context_manager.process(
|
||||
self.run_context.messages, trusted_token_usage=token_usage
|
||||
)
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
if llm_response.is_chunk:
|
||||
# update ttft
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.10.4"
|
||||
VERSION = "4.11.0"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
@@ -83,6 +83,16 @@ DEFAULT_CONFIG = {
|
||||
"default_personality": "default",
|
||||
"persona_pool": ["*"],
|
||||
"prompt_prefix": "{{prompt}}",
|
||||
"context_limit_reached_strategy": "truncate_by_turns", # or llm_compress
|
||||
"llm_compress_instruction": (
|
||||
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
|
||||
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
|
||||
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
|
||||
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
|
||||
"4. Write the summary in the user's language.\n"
|
||||
),
|
||||
"llm_compress_keep_recent": 4,
|
||||
"llm_compress_provider_id": "",
|
||||
"max_context_length": -1,
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
@@ -179,6 +189,7 @@ class ChatProviderTemplate(TypedDict):
|
||||
model: str
|
||||
modalities: list
|
||||
custom_extra_body: dict[str, Any]
|
||||
max_context_tokens: int
|
||||
|
||||
|
||||
CHAT_PROVIDER_TEMPLATE = {
|
||||
@@ -187,6 +198,7 @@ CHAT_PROVIDER_TEMPLATE = {
|
||||
"model": "",
|
||||
"modalities": [],
|
||||
"custom_extra_body": {},
|
||||
"max_context_tokens": 0,
|
||||
}
|
||||
|
||||
"""
|
||||
@@ -227,7 +239,7 @@ CONFIG_METADATA_2 = {
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6196,
|
||||
},
|
||||
"OneBot v11": {
|
||||
"OneBot v11 (QQ 个人号等)": {
|
||||
"id": "default",
|
||||
"type": "aiocqhttp",
|
||||
"enable": False,
|
||||
@@ -235,16 +247,6 @@ CONFIG_METADATA_2 = {
|
||||
"ws_reverse_port": 6199,
|
||||
"ws_reverse_token": "",
|
||||
},
|
||||
"WeChatPadPro": {
|
||||
"id": "wechatpadpro",
|
||||
"type": "wechatpadpro",
|
||||
"enable": False,
|
||||
"admin_key": "stay33",
|
||||
"host": "这里填写你的局域网IP或者公网服务器IP",
|
||||
"port": 8059,
|
||||
"wpp_active_message_poll": False,
|
||||
"wpp_active_message_poll_interval": 3,
|
||||
},
|
||||
"微信公众平台": {
|
||||
"id": "weixin_official_account",
|
||||
"type": "weixin_official_account",
|
||||
@@ -374,6 +376,16 @@ CONFIG_METADATA_2 = {
|
||||
"satori_heartbeat_interval": 10,
|
||||
"satori_reconnect_delay": 5,
|
||||
},
|
||||
"WeChatPadPro": {
|
||||
"id": "wechatpadpro",
|
||||
"type": "wechatpadpro",
|
||||
"enable": False,
|
||||
"admin_key": "stay33",
|
||||
"host": "这里填写你的局域网IP或者公网服务器IP",
|
||||
"port": 8059,
|
||||
"wpp_active_message_poll": False,
|
||||
"wpp_active_message_poll_interval": 3,
|
||||
},
|
||||
# "WebChat": {
|
||||
# "id": "webchat",
|
||||
# "type": "webchat",
|
||||
@@ -2033,6 +2045,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
|
||||
},
|
||||
"max_context_tokens": {
|
||||
"description": "模型上下文窗口大小",
|
||||
"type": "int",
|
||||
"hint": "模型最大上下文 Token 大小。如果为 0,则会自动从模型元数据填充(如有),也可手动修改。",
|
||||
},
|
||||
"dify_api_key": {
|
||||
"description": "API Key",
|
||||
"type": "string",
|
||||
@@ -2540,6 +2557,66 @@ CONFIG_METADATA_3 = {
|
||||
# "provider_settings.enable": True,
|
||||
# },
|
||||
# },
|
||||
"truncate_and_compress": {
|
||||
"description": "上下文管理策略",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"provider_settings.max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.context_limit_reached_strategy": {
|
||||
"description": "超出模型上下文窗口时的处理方式",
|
||||
"type": "string",
|
||||
"options": ["truncate_by_turns", "llm_compress"],
|
||||
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
"hint": "",
|
||||
},
|
||||
"provider_settings.llm_compress_instruction": {
|
||||
"description": "上下文压缩提示词",
|
||||
"type": "text",
|
||||
"hint": "如果为空则使用默认提示词。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.llm_compress_keep_recent": {
|
||||
"description": "压缩时保留最近对话轮数",
|
||||
"type": "int",
|
||||
"hint": "始终保留的最近 N 轮对话。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.llm_compress_provider_id": {
|
||||
"description": "用于上下文压缩的模型提供商 ID",
|
||||
"type": "string",
|
||||
"_special": "select_provider",
|
||||
"hint": "留空时将降级为“按对话轮数截断”的策略。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
"type": "object",
|
||||
@@ -2604,22 +2681,6 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.streaming_response": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀 ",
|
||||
"type": "string",
|
||||
|
||||
@@ -69,6 +69,7 @@ class ConversationManager:
|
||||
persona_id=conv_v2.persona_id,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
token_usage=conv_v2.token_usage,
|
||||
)
|
||||
|
||||
async def new_conversation(
|
||||
@@ -256,6 +257,7 @@ class ConversationManager:
|
||||
history: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
token_usage: int | None = None,
|
||||
) -> None:
|
||||
"""更新会话的对话.
|
||||
|
||||
@@ -263,6 +265,7 @@ class ConversationManager:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
|
||||
token_usage (int | None): token 使用量。None 表示不更新
|
||||
|
||||
"""
|
||||
if not conversation_id:
|
||||
@@ -274,6 +277,7 @@ class ConversationManager:
|
||||
title=title,
|
||||
persona_id=persona_id,
|
||||
content=history,
|
||||
token_usage=token_usage,
|
||||
)
|
||||
|
||||
async def update_conversation_title(
|
||||
|
||||
@@ -90,6 +90,7 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
# 初始化 UMOP 配置路由器
|
||||
self.umop_config_router = UmopConfigRouter(sp=sp)
|
||||
await self.umop_config_router.initialize()
|
||||
|
||||
# 初始化 AstrBot 配置管理器
|
||||
self.astrbot_config_mgr = AstrBotConfigManager(
|
||||
|
||||
@@ -152,6 +152,7 @@ class BaseDatabase(abc.ABC):
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
content: list[dict] | None = None,
|
||||
token_usage: int | None = None,
|
||||
) -> None:
|
||||
"""Update a conversation's history."""
|
||||
...
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
"""Migration script to add token_usage column to conversations table.
|
||||
|
||||
This migration adds the token_usage field to track token consumption for each conversation.
|
||||
|
||||
Changes:
|
||||
- Adds token_usage column to conversations table (default: 0)
|
||||
"""
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
|
||||
async def migrate_token_usage(db_helper: BaseDatabase):
|
||||
"""Add token_usage column to conversations table.
|
||||
|
||||
This migration adds a new column to track token consumption in conversations.
|
||||
"""
|
||||
# 检查是否已经完成迁移
|
||||
migration_done = await db_helper.get_preference(
|
||||
"global", "global", "migration_done_token_usage_1"
|
||||
)
|
||||
if migration_done:
|
||||
return
|
||||
|
||||
logger.info("开始执行数据库迁移(添加 conversations.token_usage 列)...")
|
||||
|
||||
# 这里只适配了 SQLite。因为截止至这一版本,AstrBot 仅支持 SQLite。
|
||||
|
||||
try:
|
||||
async with db_helper.get_db() as session:
|
||||
# 检查列是否已存在
|
||||
result = await session.execute(text("PRAGMA table_info(conversations)"))
|
||||
columns = result.fetchall()
|
||||
column_names = [col[1] for col in columns]
|
||||
|
||||
if "token_usage" in column_names:
|
||||
logger.info("token_usage 列已存在,跳过迁移")
|
||||
await sp.put_async(
|
||||
"global", "global", "migration_done_token_usage_1", True
|
||||
)
|
||||
return
|
||||
|
||||
# 添加 token_usage 列
|
||||
await session.execute(
|
||||
text(
|
||||
"ALTER TABLE conversations ADD COLUMN token_usage INTEGER NOT NULL DEFAULT 0"
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
logger.info("token_usage 列添加成功")
|
||||
|
||||
# 标记迁移完成
|
||||
await sp.put_async("global", "global", "migration_done_token_usage_1", True)
|
||||
logger.info("token_usage 迁移完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
|
||||
raise
|
||||
@@ -54,6 +54,11 @@ class ConversationV2(SQLModel, table=True):
|
||||
)
|
||||
title: str | None = Field(default=None, max_length=255)
|
||||
persona_id: str | None = Field(default=None)
|
||||
token_usage: int = Field(default=0, nullable=False)
|
||||
"""content is a list of OpenAI-formated messages in list[dict] format.
|
||||
token_usage is the total token value of the messages.
|
||||
when 0, will use estimated token counter.
|
||||
"""
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
@@ -313,6 +318,8 @@ class Conversation:
|
||||
persona_id: str | None = ""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
token_usage: int = 0
|
||||
"""对话的总 token 数量。AstrBot 会保留最近一次 LLM 请求返回的总 token 数,方便统计。token_usage 可能为 0,表示未知。"""
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
|
||||
@@ -241,7 +241,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
session.add(new_conversation)
|
||||
return new_conversation
|
||||
|
||||
async def update_conversation(self, cid, title=None, persona_id=None, content=None):
|
||||
async def update_conversation(
|
||||
self, cid, title=None, persona_id=None, content=None, token_usage=None
|
||||
):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
@@ -255,6 +257,8 @@ class SQLiteDatabase(BaseDatabase):
|
||||
values["persona_id"] = persona_id
|
||||
if content is not None:
|
||||
values["content"] = content
|
||||
if token_usage is not None:
|
||||
values["token_usage"] = token_usage
|
||||
if not values:
|
||||
return None
|
||||
query = query.values(**values)
|
||||
|
||||
@@ -38,7 +38,7 @@ class AgentRequestSubStage(Stage):
|
||||
)
|
||||
return
|
||||
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
if not await SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(
|
||||
f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing."
|
||||
)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""本地 Agent 模式的 LLM 调用 Stage"""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.response import AgentStats
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
@@ -24,6 +24,7 @@ from astrbot.core.provider.entities import (
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType, star_map
|
||||
from astrbot.core.utils.file_extract import extract_file_moonshotai
|
||||
from astrbot.core.utils.llm_metadata import LLM_METADATAS
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
@@ -41,11 +42,6 @@ class InternalAgentSubStage(Stage):
|
||||
self.ctx = ctx
|
||||
conf = ctx.astrbot_config
|
||||
settings = conf["provider_settings"]
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
)
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.unsupported_streaming_strategy: str = settings[
|
||||
"unsupported_streaming_strategy"
|
||||
@@ -65,6 +61,25 @@ class InternalAgentSubStage(Stage):
|
||||
"moonshotai_api_key", ""
|
||||
)
|
||||
|
||||
# 上下文管理相关
|
||||
self.context_limit_reached_strategy: str = settings.get(
|
||||
"context_limit_reached_strategy", "truncate_by_turns"
|
||||
)
|
||||
self.llm_compress_instruction: str = settings.get(
|
||||
"llm_compress_instruction", ""
|
||||
)
|
||||
self.llm_compress_keep_recent: int = settings.get("llm_compress_keep_recent", 4)
|
||||
self.llm_compress_provider_id: str = settings.get(
|
||||
"llm_compress_provider_id", ""
|
||||
)
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
)
|
||||
if self.dequeue_context_length <= 0:
|
||||
self.dequeue_context_length = 1
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
def _select_provider(self, event: AstrMessageEvent):
|
||||
@@ -167,34 +182,6 @@ class InternalAgentSubStage(Stage):
|
||||
},
|
||||
)
|
||||
|
||||
def _truncate_contexts(
|
||||
self,
|
||||
contexts: list[dict],
|
||||
) -> list[dict]:
|
||||
"""截断上下文列表,确保不超过最大长度"""
|
||||
if self.max_context_length == -1:
|
||||
return contexts
|
||||
|
||||
if len(contexts) // 2 <= self.max_context_length:
|
||||
return contexts
|
||||
|
||||
truncated_contexts = contexts[
|
||||
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
||||
]
|
||||
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(
|
||||
i
|
||||
for i, item in enumerate(truncated_contexts)
|
||||
if item.get("role") == "user"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
truncated_contexts = truncated_contexts[index:]
|
||||
|
||||
return truncated_contexts
|
||||
|
||||
def _modalities_fix(
|
||||
self,
|
||||
provider: Provider,
|
||||
@@ -296,6 +283,7 @@ class InternalAgentSubStage(Stage):
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse | None,
|
||||
all_messages: list[Message],
|
||||
runner_stats: AgentStats | None,
|
||||
):
|
||||
if (
|
||||
not req
|
||||
@@ -322,27 +310,37 @@ class InternalAgentSubStage(Stage):
|
||||
continue
|
||||
message_to_save.append(message.model_dump())
|
||||
|
||||
# get token usage from agent runner stats
|
||||
token_usage = None
|
||||
if runner_stats:
|
||||
token_usage = runner_stats.token_usage.total
|
||||
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
req.conversation.cid,
|
||||
history=message_to_save,
|
||||
token_usage=token_usage,
|
||||
)
|
||||
|
||||
def _fix_messages(self, messages: list[dict]) -> list[dict]:
|
||||
"""验证并且修复上下文"""
|
||||
fixed_messages = []
|
||||
for message in messages:
|
||||
if message.get("role") == "tool":
|
||||
# tool block 前面必须要有 user 和 assistant block
|
||||
if len(fixed_messages) < 2:
|
||||
# 这种情况可能是上下文被截断导致的
|
||||
# 我们直接将之前的上下文都清空
|
||||
fixed_messages = []
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
return fixed_messages
|
||||
def _get_compress_provider(self) -> Provider | None:
|
||||
if not self.llm_compress_provider_id:
|
||||
return None
|
||||
if self.context_limit_reached_strategy != "llm_compress":
|
||||
return None
|
||||
provider = self.ctx.plugin_manager.context.get_provider_by_id(
|
||||
self.llm_compress_provider_id,
|
||||
)
|
||||
if provider is None:
|
||||
logger.warning(
|
||||
f"未找到指定的上下文压缩模型 {self.llm_compress_provider_id},将跳过压缩。",
|
||||
)
|
||||
return None
|
||||
if not isinstance(provider, Provider):
|
||||
logger.warning(
|
||||
f"指定的上下文压缩模型 {self.llm_compress_provider_id} 不是对话模型,将跳过压缩。"
|
||||
)
|
||||
return None
|
||||
return provider
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, provider_wake_prefix: str
|
||||
@@ -364,6 +362,10 @@ class InternalAgentSubStage(Stage):
|
||||
streaming_response = bool(enable_streaming)
|
||||
|
||||
logger.debug("ready to request llm provider")
|
||||
|
||||
# 通知等待调用 LLM(在获取锁之前)
|
||||
await call_event_hook(event, EventType.OnWaitingLLMRequestEvent)
|
||||
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
logger.debug("acquired session lock for llm request")
|
||||
if event.get_extra("provider_request"):
|
||||
@@ -422,9 +424,10 @@ class InternalAgentSubStage(Stage):
|
||||
await self._apply_kb(event, req)
|
||||
|
||||
# truncate contexts to fit max length
|
||||
if req.contexts:
|
||||
req.contexts = self._truncate_contexts(req.contexts)
|
||||
self._fix_messages(req.contexts)
|
||||
# NOW moved to ContextManager inside ToolLoopAgentRunner
|
||||
# if req.contexts:
|
||||
# req.contexts = self._truncate_contexts(req.contexts)
|
||||
# self._fix_messages(req.contexts)
|
||||
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
@@ -440,8 +443,6 @@ class InternalAgentSubStage(Stage):
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
# 备份 req.contexts
|
||||
backup_contexts = copy.deepcopy(req.contexts)
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
@@ -452,6 +453,15 @@ class InternalAgentSubStage(Stage):
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
|
||||
# inject model context length limit
|
||||
if provider.provider_config.get("max_context_tokens", 0) <= 0:
|
||||
model = provider.get_model()
|
||||
if model_info := LLM_METADATAS.get(model):
|
||||
provider.provider_config["max_context_tokens"] = model_info[
|
||||
"limit"
|
||||
]["context"]
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
@@ -462,6 +472,11 @@ class InternalAgentSubStage(Stage):
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=streaming_response,
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider=self._get_compress_provider(),
|
||||
truncate_turns=self.dequeue_context_length,
|
||||
enforce_max_turns=self.max_context_length,
|
||||
)
|
||||
|
||||
if streaming_response and not stream_to_general:
|
||||
@@ -507,14 +522,12 @@ class InternalAgentSubStage(Stage):
|
||||
):
|
||||
yield
|
||||
|
||||
# 恢复备份的 contexts
|
||||
req.contexts = backup_contexts
|
||||
|
||||
await self._save_to_history(
|
||||
event,
|
||||
req,
|
||||
agent_runner.get_final_llm_resp(),
|
||||
agent_runner.run_context.messages,
|
||||
agent_runner.stats,
|
||||
)
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
|
||||
@@ -260,7 +260,7 @@ class ResultDecorateStage(Stage):
|
||||
should_tts = (
|
||||
bool(self.ctx.astrbot_config["provider_tts_settings"]["enable"])
|
||||
and result.is_llm_result()
|
||||
and SessionServiceManager.should_process_tts_request(event)
|
||||
and await SessionServiceManager.should_process_tts_request(event)
|
||||
and random.random() <= self.tts_trigger_probability
|
||||
and tts_provider
|
||||
)
|
||||
|
||||
@@ -21,7 +21,7 @@ class SessionStatusCheckStage(Stage):
|
||||
event: AstrMessageEvent,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
# 检查会话是否整体启用
|
||||
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||
if not await SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
||||
|
||||
# workaround for #2309
|
||||
|
||||
@@ -227,7 +227,7 @@ class WakingCheckStage(Stage):
|
||||
event._extras.pop("parsed_params", None)
|
||||
|
||||
# 根据会话配置过滤插件处理器
|
||||
activated_handlers = SessionPluginManager.filter_handlers_by_session(
|
||||
activated_handlers = await SessionPluginManager.filter_handlers_by_session(
|
||||
event,
|
||||
activated_handlers,
|
||||
)
|
||||
|
||||
@@ -191,7 +191,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
if self.active_send_mode:
|
||||
await self.convert_message(msg, None)
|
||||
else:
|
||||
if msg.id in self.wexin_event_workers:
|
||||
if str(msg.id) in self.wexin_event_workers:
|
||||
future = self.wexin_event_workers[str(cast(str | int, msg.id))]
|
||||
logger.debug(f"duplicate message id checked: {msg.id}")
|
||||
else:
|
||||
|
||||
@@ -119,19 +119,34 @@ class ProviderManager:
|
||||
TTSProvider,
|
||||
):
|
||||
self.curr_tts_provider_inst = prov
|
||||
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
|
||||
await sp.put_async(
|
||||
key="curr_provider_tts",
|
||||
value=provider_id,
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
|
||||
prov,
|
||||
STTProvider,
|
||||
):
|
||||
self.curr_stt_provider_inst = prov
|
||||
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
|
||||
await sp.put_async(
|
||||
key="curr_provider_stt",
|
||||
value=provider_id,
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
|
||||
prov,
|
||||
Provider,
|
||||
):
|
||||
self.curr_provider_inst = prov
|
||||
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
||||
await sp.put_async(
|
||||
key="curr_provider",
|
||||
value=provider_id,
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
|
||||
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
|
||||
"""根据提供商 ID 获取提供商实例"""
|
||||
@@ -206,21 +221,21 @@ class ProviderManager:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(e)
|
||||
|
||||
selected_provider_id = sp.get(
|
||||
"curr_provider",
|
||||
self.provider_settings.get("default_provider_id"),
|
||||
selected_provider_id = await sp.get_async(
|
||||
key="curr_provider",
|
||||
default=self.provider_settings.get("default_provider_id"),
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
selected_stt_provider_id = sp.get(
|
||||
"curr_provider_stt",
|
||||
self.provider_stt_settings.get("provider_id"),
|
||||
selected_stt_provider_id = await sp.get_async(
|
||||
key="curr_provider_stt",
|
||||
default=self.provider_stt_settings.get("provider_id"),
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
selected_tts_provider_id = sp.get(
|
||||
"curr_provider_tts",
|
||||
self.provider_tts_settings.get("provider_id"),
|
||||
selected_tts_provider_id = await sp.get_async(
|
||||
key="curr_provider_tts",
|
||||
default=self.provider_tts_settings.get("provider_id"),
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
|
||||
@@ -378,7 +378,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
new_content.append(part)
|
||||
message["content"] = new_content
|
||||
# reasoning key is "reasoning_content"
|
||||
message["reasoning_content"] = reasoning_content
|
||||
if reasoning_content:
|
||||
message["reasoning_content"] = reasoning_content
|
||||
|
||||
async def _handle_api_error(
|
||||
self,
|
||||
|
||||
@@ -149,9 +149,12 @@ class Context:
|
||||
contexts: context messages for the LLM
|
||||
max_steps: Maximum number of tool calls before stopping the loop
|
||||
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
|
||||
stream: bool - whether to stream the LLM response
|
||||
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
|
||||
agent_context: AstrAgentContext - context to use for the agent
|
||||
|
||||
other kwargs will be DIRECTLY passed to the runner.reset() method
|
||||
|
||||
Returns:
|
||||
The final LLMResponse after tool calls are completed.
|
||||
|
||||
@@ -194,6 +197,15 @@ class Context:
|
||||
)
|
||||
agent_runner = ToolLoopAgentRunner()
|
||||
tool_executor = FunctionToolExecutor()
|
||||
|
||||
streaming = kwargs.get("stream", False)
|
||||
|
||||
other_kwargs = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k not in ["stream", "agent_hooks", "agent_context"]
|
||||
}
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=prov,
|
||||
request=request,
|
||||
@@ -203,7 +215,8 @@ class Context:
|
||||
),
|
||||
tool_executor=tool_executor,
|
||||
agent_hooks=agent_hooks,
|
||||
streaming=kwargs.get("stream", False),
|
||||
streaming=streaming,
|
||||
**other_kwargs,
|
||||
)
|
||||
async for _ in agent_runner.step_until_done(max_steps):
|
||||
pass
|
||||
|
||||
@@ -12,6 +12,7 @@ from .star_handler import (
|
||||
register_on_llm_request,
|
||||
register_on_llm_response,
|
||||
register_on_platform_loaded,
|
||||
register_on_waiting_llm_request,
|
||||
register_permission_type,
|
||||
register_platform_adapter_type,
|
||||
register_regex,
|
||||
@@ -30,6 +31,7 @@ __all__ = [
|
||||
"register_on_llm_request",
|
||||
"register_on_llm_response",
|
||||
"register_on_platform_loaded",
|
||||
"register_on_waiting_llm_request",
|
||||
"register_permission_type",
|
||||
"register_platform_adapter_type",
|
||||
"register_regex",
|
||||
|
||||
@@ -339,6 +339,30 @@ def register_on_platform_loaded(**kwargs):
|
||||
return decorator
|
||||
|
||||
|
||||
def register_on_waiting_llm_request(**kwargs):
|
||||
"""当等待调用 LLM 时的通知事件(在获取锁之前)
|
||||
|
||||
此钩子在消息确定要调用 LLM 但还未开始排队等锁时触发,
|
||||
适合用于发送"正在思考中..."等用户反馈提示。
|
||||
|
||||
Examples:
|
||||
```py
|
||||
@on_waiting_llm_request()
|
||||
async def on_waiting_llm(self, event: AstrMessageEvent) -> None:
|
||||
await event.send("🤔 正在思考中...")
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(
|
||||
awaitable, EventType.OnWaitingLLMRequestEvent, **kwargs
|
||||
)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_on_llm_request(**kwargs):
|
||||
"""当有 LLM 请求时的事件
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ class SessionServiceManager:
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def is_llm_enabled_for_session(session_id: str) -> bool:
|
||||
async def is_llm_enabled_for_session(session_id: str) -> bool:
|
||||
"""检查LLM是否在指定会话中启用
|
||||
|
||||
Args:
|
||||
@@ -23,11 +23,11 @@ class SessionServiceManager:
|
||||
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_services = sp.get(
|
||||
"session_service_config",
|
||||
{},
|
||||
session_services = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_service_config",
|
||||
default={},
|
||||
)
|
||||
|
||||
# 如果配置了该会话的LLM状态,返回该状态
|
||||
@@ -39,7 +39,7 @@ class SessionServiceManager:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_llm_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
async def set_llm_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
"""设置LLM在指定会话中的启停状态
|
||||
|
||||
Args:
|
||||
@@ -48,18 +48,24 @@ class SessionServiceManager:
|
||||
|
||||
"""
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_service_config",
|
||||
default={},
|
||||
)
|
||||
or {}
|
||||
)
|
||||
session_config["llm_enabled"] = enabled
|
||||
sp.put(
|
||||
"session_service_config",
|
||||
session_config,
|
||||
await sp.put_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_service_config",
|
||||
value=session_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_llm_request(event: AstrMessageEvent) -> bool:
|
||||
async def should_process_llm_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理LLM请求
|
||||
|
||||
Args:
|
||||
@@ -70,14 +76,14 @@ class SessionServiceManager:
|
||||
|
||||
"""
|
||||
session_id = event.unified_msg_origin
|
||||
return SessionServiceManager.is_llm_enabled_for_session(session_id)
|
||||
return await SessionServiceManager.is_llm_enabled_for_session(session_id)
|
||||
|
||||
# =============================================================================
|
||||
# TTS 相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def is_tts_enabled_for_session(session_id: str) -> bool:
|
||||
async def is_tts_enabled_for_session(session_id: str) -> bool:
|
||||
"""检查TTS是否在指定会话中启用
|
||||
|
||||
Args:
|
||||
@@ -88,11 +94,11 @@ class SessionServiceManager:
|
||||
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_services = sp.get(
|
||||
"session_service_config",
|
||||
{},
|
||||
session_services = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_service_config",
|
||||
default={},
|
||||
)
|
||||
|
||||
# 如果配置了该会话的TTS状态,返回该状态
|
||||
@@ -104,7 +110,7 @@ class SessionServiceManager:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_tts_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
async def set_tts_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
"""设置TTS在指定会话中的启停状态
|
||||
|
||||
Args:
|
||||
@@ -113,14 +119,20 @@ class SessionServiceManager:
|
||||
|
||||
"""
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_service_config",
|
||||
default={},
|
||||
)
|
||||
or {}
|
||||
)
|
||||
session_config["tts_enabled"] = enabled
|
||||
sp.put(
|
||||
"session_service_config",
|
||||
session_config,
|
||||
await sp.put_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_service_config",
|
||||
value=session_config,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -128,7 +140,7 @@ class SessionServiceManager:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_tts_request(event: AstrMessageEvent) -> bool:
|
||||
async def should_process_tts_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理TTS请求
|
||||
|
||||
Args:
|
||||
@@ -139,14 +151,14 @@ class SessionServiceManager:
|
||||
|
||||
"""
|
||||
session_id = event.unified_msg_origin
|
||||
return SessionServiceManager.is_tts_enabled_for_session(session_id)
|
||||
return await SessionServiceManager.is_tts_enabled_for_session(session_id)
|
||||
|
||||
# =============================================================================
|
||||
# 会话整体启停相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def is_session_enabled(session_id: str) -> bool:
|
||||
async def is_session_enabled(session_id: str) -> bool:
|
||||
"""检查会话是否整体启用
|
||||
|
||||
Args:
|
||||
@@ -157,11 +169,11 @@ class SessionServiceManager:
|
||||
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_services = sp.get(
|
||||
"session_service_config",
|
||||
{},
|
||||
session_services = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_service_config",
|
||||
default={},
|
||||
)
|
||||
|
||||
# 如果配置了该会话的整体状态,返回该状态
|
||||
|
||||
@@ -8,7 +8,10 @@ class SessionPluginManager:
|
||||
"""管理会话级别的插件启停状态"""
|
||||
|
||||
@staticmethod
|
||||
def is_plugin_enabled_for_session(session_id: str, plugin_name: str) -> bool:
|
||||
async def is_plugin_enabled_for_session(
|
||||
session_id: str,
|
||||
plugin_name: str,
|
||||
) -> bool:
|
||||
"""检查插件是否在指定会话中启用
|
||||
|
||||
Args:
|
||||
@@ -20,11 +23,11 @@ class SessionPluginManager:
|
||||
|
||||
"""
|
||||
# 获取会话插件配置
|
||||
session_plugin_config = sp.get(
|
||||
"session_plugin_config",
|
||||
{},
|
||||
session_plugin_config = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_plugin_config",
|
||||
default={},
|
||||
)
|
||||
session_config = session_plugin_config.get(session_id, {})
|
||||
|
||||
@@ -43,7 +46,10 @@ class SessionPluginManager:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def filter_handlers_by_session(event: AstrMessageEvent, handlers: list) -> list:
|
||||
async def filter_handlers_by_session(
|
||||
event: AstrMessageEvent,
|
||||
handlers: list,
|
||||
) -> list:
|
||||
"""根据会话配置过滤处理器列表
|
||||
|
||||
Args:
|
||||
@@ -59,6 +65,15 @@ class SessionPluginManager:
|
||||
session_id = event.unified_msg_origin
|
||||
filtered_handlers = []
|
||||
|
||||
session_plugin_config = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_plugin_config",
|
||||
default={},
|
||||
)
|
||||
session_config = session_plugin_config.get(session_id, {})
|
||||
disabled_plugins = session_config.get("disabled_plugins", [])
|
||||
|
||||
for handler in handlers:
|
||||
# 获取处理器对应的插件
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
@@ -76,14 +91,11 @@ class SessionPluginManager:
|
||||
continue
|
||||
|
||||
# 检查插件是否在当前会话中启用
|
||||
if SessionPluginManager.is_plugin_enabled_for_session(
|
||||
session_id,
|
||||
plugin.name,
|
||||
):
|
||||
filtered_handlers.append(handler)
|
||||
else:
|
||||
if plugin.name in disabled_plugins:
|
||||
logger.debug(
|
||||
f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}",
|
||||
)
|
||||
else:
|
||||
filtered_handlers.append(handler)
|
||||
|
||||
return filtered_handlers
|
||||
|
||||
@@ -184,6 +184,7 @@ class EventType(enum.Enum):
|
||||
OnPlatformLoadedEvent = enum.auto() # 平台加载完成
|
||||
|
||||
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
|
||||
OnWaitingLLMRequestEvent = enum.auto() # 等待调用 LLM(在获取锁之前,仅通知)
|
||||
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
|
||||
OnLLMResponseEvent = enum.auto() # LLM 响应后
|
||||
OnDecoratingResultEvent = enum.auto() # 发送消息前
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import fnmatch
|
||||
|
||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
|
||||
|
||||
@@ -9,14 +11,15 @@ class UmopConfigRouter:
|
||||
"""UMOP 到配置文件 ID 的映射"""
|
||||
self.sp = sp
|
||||
|
||||
self._load_routing_table()
|
||||
async def initialize(self):
|
||||
await self._load_routing_table()
|
||||
|
||||
def _load_routing_table(self):
|
||||
async def _load_routing_table(self):
|
||||
"""加载路由表"""
|
||||
# 从 SharedPreferences 中加载 umop_to_conf_id 映射
|
||||
sp_data = self.sp.get(
|
||||
"umop_config_routing",
|
||||
{},
|
||||
sp_data = await self.sp.get_async(
|
||||
key="umop_config_routing",
|
||||
default={},
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
@@ -30,7 +33,7 @@ class UmopConfigRouter:
|
||||
if len(p1_ls) != 3 or len(p2_ls) != 3:
|
||||
return False # 非法格式
|
||||
|
||||
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
|
||||
return all(p == "" or fnmatch.fnmatchcase(t, p) for p, t in zip(p1_ls, p2_ls))
|
||||
|
||||
def get_conf_id_for_umop(self, umo: str) -> str | None:
|
||||
"""根据 UMO 获取对应的配置文件 ID
|
||||
|
||||
@@ -3,6 +3,7 @@ import traceback
|
||||
from astrbot.core import astrbot_config, logger
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager
|
||||
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
|
||||
from astrbot.core.db.migration.migra_token_usage import migrate_token_usage
|
||||
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
|
||||
|
||||
|
||||
@@ -139,6 +140,13 @@ async def migra(
|
||||
logger.error(f"Migration for webchat session failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# migration for token_usage column
|
||||
try:
|
||||
await migrate_token_usage(db)
|
||||
except Exception as e:
|
||||
logger.error(f"Migration for token_usage column failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# migra third party agent runner configs
|
||||
_c = False
|
||||
providers = astrbot_config["provider"]
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
## What's Changed
|
||||
|
||||
hotfix of v4.10.4
|
||||
|
||||
fix: 部分配置项的输入框不显示,如飞书机器人配置的部分配置项。(#4268)
|
||||
@@ -0,0 +1,11 @@
|
||||
## What's Changed
|
||||
|
||||
hotfix of v4.10.4
|
||||
|
||||
fix:
|
||||
|
||||
1. ‼️ 部分情况下使用 OpenAI 接口报错与 reasoning_content 有关的问题;
|
||||
|
||||
feat:
|
||||
|
||||
1. WebUI 已安装插件页支持记忆视图类型(列表/卡片),列表视图显示插件的人类友好名称和 logo。
|
||||
@@ -0,0 +1,19 @@
|
||||
## What's Changed
|
||||
|
||||
### 新增
|
||||
|
||||
- 支持上下文自动压缩功能。入口:配置文件 -> 上下文管理策略 -> 超出模型上下文窗口时的处理方式。详情请查看: [自动上下文压缩](https://docs.astrbot.app/use/context-compress.html) ([#4322](https://github.com/AstrBotDevs/AstrBot/issues/4322))
|
||||
- 新增 `on_waiting_llm_request` 事件钩子 ([#4319](https://github.com/AstrBotDevs/AstrBot/issues/4319))
|
||||
- WebUI 支持强制更新插件 ([#4293](https://github.com/AstrBotDevs/AstrBot/issues/4293))
|
||||
- 社区已提供适用于 [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter) 平台的适配器插件
|
||||
|
||||
### 修复
|
||||
|
||||
- 修复微信公众号中由于 msg.id 数据类型不匹配导致的重试失败问题 ([#4292](https://github.com/AstrBotDevs/AstrBot/issues/4292))
|
||||
- 修复调用 TTS 命令时出现的数据库锁定错误 ([#4313](https://github.com/AstrBotDevs/AstrBot/issues/4313))
|
||||
- 修复 Anthropic 提供商中 token 用量始终为 0 的问题 ([#4328](https://github.com/AstrBotDevs/AstrBot/issues/4328))
|
||||
|
||||
### 优化
|
||||
|
||||
- 完善共享组件的国际化支持 ([#4327](https://github.com/AstrBotDevs/AstrBot/issues/4327))
|
||||
- 优化下载大型备份文件时的稳定性,减少失败情况 ([#4329](https://github.com/AstrBotDevs/AstrBot/issues/4329))
|
||||
@@ -203,9 +203,8 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
|
||||
<v-col cols="12" sm="6" class="config-input">
|
||||
<ConfigItemRenderer
|
||||
v-if="metadata[metadataKey].items[key]"
|
||||
v-model="iterable[key]"
|
||||
:item-meta="metadata[metadataKey].items[key]"
|
||||
:item-meta="metadata[metadataKey].items[key] || null"
|
||||
:loading="loadingEmbeddingDim"
|
||||
:show-fullscreen-btn="!!metadata[metadataKey].items[key]?.editor_mode"
|
||||
@get-embedding-dim="getEmbeddingDimensions(iterable)"
|
||||
|
||||
@@ -219,7 +219,7 @@ function getSpecialSubtype(value) {
|
||||
<ConfigItemRenderer
|
||||
v-else
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
:item-meta="itemMeta"
|
||||
:item-meta="itemMeta || null"
|
||||
:show-fullscreen-btn="!!itemMeta?.editor_mode"
|
||||
@open-fullscreen="openEditorDialog(itemKey, iterable, itemMeta?.editor_theme, itemMeta?.editor_language)"
|
||||
/>
|
||||
|
||||
@@ -144,7 +144,7 @@
|
||||
color="primary"
|
||||
density="compact"
|
||||
hide-details
|
||||
class="flex-grow-1"
|
||||
style="flex: 1"
|
||||
></v-slider>
|
||||
<v-text-field
|
||||
:model-value="modelValue"
|
||||
@@ -154,7 +154,7 @@
|
||||
class="config-field"
|
||||
type="number"
|
||||
hide-details
|
||||
style="max-width: 140px;"
|
||||
style="flex: 1"
|
||||
></v-text-field>
|
||||
</div>
|
||||
|
||||
@@ -223,7 +223,7 @@ const props = defineProps({
|
||||
},
|
||||
itemMeta: {
|
||||
type: Object,
|
||||
required: true
|
||||
default: null
|
||||
},
|
||||
loading: {
|
||||
type: Boolean,
|
||||
@@ -325,4 +325,8 @@ function getSpecialSubtype(value) {
|
||||
.gap-20 {
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
:deep(.v-field__input) {
|
||||
font-size: 14px;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -145,9 +145,11 @@ const viewReadme = () => {
|
||||
}})</v-list-item-title>
|
||||
</v-list-item>
|
||||
|
||||
<v-list-item @click="updateExtension" :disabled="!extension?.has_update">
|
||||
<v-list-item @click="updateExtension">
|
||||
<v-list-item-title>
|
||||
{{ tm('card.actions.updateTo') }} {{ extension.online_version || extension.version }}
|
||||
{{ extension.has_update
|
||||
? tm('card.actions.updateTo') + ' ' + extension.online_version
|
||||
: tm('card.actions.reinstall') }}
|
||||
</v-list-item-title>
|
||||
</v-list-item>
|
||||
</template>
|
||||
|
||||
@@ -510,7 +510,7 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
|
||||
|
||||
const metadata = getModelMetadata(modelName)
|
||||
let modalities: string[]
|
||||
|
||||
|
||||
if (!metadata) {
|
||||
modalities = ['text', 'image', 'tool_use']
|
||||
} else {
|
||||
@@ -523,13 +523,19 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
|
||||
}
|
||||
}
|
||||
|
||||
let max_context_tokens = 0
|
||||
if (metadata?.limit?.context && typeof metadata.limit.context === 'number') {
|
||||
max_context_tokens = metadata.limit.context
|
||||
}
|
||||
|
||||
const newProvider = {
|
||||
id: newId,
|
||||
enable: false,
|
||||
provider_source_id: sourceId,
|
||||
model: modelName,
|
||||
modalities,
|
||||
custom_extra_body: {}
|
||||
custom_extra_body: {},
|
||||
max_context_tokens: max_context_tokens
|
||||
}
|
||||
|
||||
try {
|
||||
|
||||
@@ -11,7 +11,12 @@
|
||||
},
|
||||
"agent_runner_type": {
|
||||
"description": "Runner",
|
||||
"labels": ["Built-in Agent", "Dify", "Coze", "Alibaba Cloud Bailian Application"]
|
||||
"labels": [
|
||||
"Built-in Agent",
|
||||
"Dify",
|
||||
"Coze",
|
||||
"Alibaba Cloud Bailian Application"
|
||||
]
|
||||
},
|
||||
"coze_agent_runner_provider_id": {
|
||||
"description": "Coze Agent Runner Provider ID"
|
||||
@@ -128,6 +133,39 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"truncate_and_compress": {
|
||||
"description": "Context Management Strategy",
|
||||
"provider_settings": {
|
||||
"max_context_length": {
|
||||
"description": "Maximum Conversation Turns",
|
||||
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "Dequeue Conversation Turns",
|
||||
"hint": "Number of conversation turns to discard at once when maximum context length is exceeded"
|
||||
},
|
||||
"context_limit_reached_strategy": {
|
||||
"description": "Handling When Model Context Window is Exceeded",
|
||||
"labels": [
|
||||
"Truncate by Turns",
|
||||
"Compress by LLM"
|
||||
],
|
||||
"hint": "When 'Truncate by Turns' is selected, the oldest N conversation turns will be discarded based on the 'Dequeue Conversation Turns' setting above. When 'Compress by LLM' is selected, the specified model will be used for context compression."
|
||||
},
|
||||
"llm_compress_instruction": {
|
||||
"description": "Context Compression Instruction",
|
||||
"hint": "If empty, the default prompt will be used."
|
||||
},
|
||||
"llm_compress_keep_recent": {
|
||||
"description": "Keep Recent Turns When Compressing",
|
||||
"hint": "Always keep the most recent N turns of conversation when compressing context."
|
||||
},
|
||||
"llm_compress_provider_id": {
|
||||
"description": "Model Provider ID for Context Compression",
|
||||
"hint": "When left empty, will fall back to the 'Truncate by Turns' strategy."
|
||||
}
|
||||
}
|
||||
},
|
||||
"others": {
|
||||
"description": "Other Settings",
|
||||
"provider_settings": {
|
||||
@@ -161,15 +199,10 @@
|
||||
"unsupported_streaming_strategy": {
|
||||
"description": "Platforms Without Streaming Support",
|
||||
"hint": "Select the handling method for platforms that don't support streaming responses. Real-time segmented reply sends content immediately when the system detects segment points like punctuation during streaming reception",
|
||||
"labels": ["Real-time Segmented Reply", "Disable Streaming Response"]
|
||||
},
|
||||
"max_context_length": {
|
||||
"description": "Maximum Conversation Rounds",
|
||||
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "Dequeue Conversation Rounds",
|
||||
"hint": "Number of conversation rounds to discard at once when maximum context length is exceeded"
|
||||
"labels": [
|
||||
"Real-time Segmented Reply",
|
||||
"Disable Streaming Response"
|
||||
]
|
||||
},
|
||||
"wake_prefix": {
|
||||
"description": "Additional LLM Chat Wake Prefix",
|
||||
@@ -387,7 +420,10 @@
|
||||
},
|
||||
"split_mode": {
|
||||
"description": "Split Mode",
|
||||
"labels": ["Regex", "Words List"]
|
||||
"labels": [
|
||||
"Regex",
|
||||
"Words List"
|
||||
]
|
||||
},
|
||||
"regex": {
|
||||
"description": "Segmentation Regular Expression"
|
||||
@@ -488,4 +524,4 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -145,6 +145,11 @@
|
||||
"message": "This plugin has been flagged as containing security risks, including unsafe code or functionalities that may cause system malfunctions or data loss. Do you wish to proceed with the installation?",
|
||||
"confirm": "Continue",
|
||||
"cancel": "Cancel"
|
||||
},
|
||||
"forceUpdate": {
|
||||
"title": "No New Version Detected",
|
||||
"message": "No new version detected for this plugin. Do you want to force reinstall? This will pull the latest code from the remote repository.",
|
||||
"confirm": "Force Update"
|
||||
}
|
||||
},
|
||||
"messages": {
|
||||
@@ -185,7 +190,8 @@
|
||||
"reloadPlugin": "Reload Extension",
|
||||
"togglePlugin": "Extension",
|
||||
"viewHandlers": "View Handlers",
|
||||
"updateTo": "Update to"
|
||||
"updateTo": "Update to",
|
||||
"reinstall": "Reinstall"
|
||||
},
|
||||
"status": {
|
||||
"hasUpdate": "New version available",
|
||||
@@ -207,4 +213,4 @@
|
||||
"goToManage": "Go to Manage",
|
||||
"later": "Later"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -133,6 +133,36 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"truncate_and_compress": {
|
||||
"description": "上下文管理策略",
|
||||
"provider_settings": {
|
||||
"max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
|
||||
},
|
||||
"context_limit_reached_strategy": {
|
||||
"description": "超出模型上下文窗口时的处理方式",
|
||||
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
|
||||
"hint": "当按对话轮数截断时,会根据上面\"丢弃对话轮数\"的配置丢弃最旧的 N 轮对话。当由 LLM 压缩上下文时,会使用指定的模型进行上下文压缩。"
|
||||
},
|
||||
"llm_compress_instruction": {
|
||||
"description": "上下文压缩提示词",
|
||||
"hint": "如果为空则使用默认提示词。"
|
||||
},
|
||||
"llm_compress_keep_recent": {
|
||||
"description": "压缩时保留最近对话轮数",
|
||||
"hint": "始终保留的最近 N 轮对话。"
|
||||
},
|
||||
"llm_compress_provider_id": {
|
||||
"description": "用于上下文压缩的模型提供商 ID",
|
||||
"hint": "留空时将降级为\"按对话轮数截断\"的策略。"
|
||||
}
|
||||
}
|
||||
},
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
"provider_settings": {
|
||||
@@ -171,14 +201,7 @@
|
||||
"关闭流式回复"
|
||||
]
|
||||
},
|
||||
"max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
|
||||
},
|
||||
|
||||
"wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀",
|
||||
"hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求"
|
||||
|
||||
@@ -145,6 +145,11 @@
|
||||
"message": "该插件可能包含不安全的代码或功能,可能导致系统异常或数据损失等。请确认是否继续安装?",
|
||||
"confirm": "继续",
|
||||
"cancel": "取消"
|
||||
},
|
||||
"forceUpdate": {
|
||||
"title": "未检测到新版本",
|
||||
"message": "当前插件未检测到新版本,是否强制重新安装?这将从远程仓库拉取最新代码。",
|
||||
"confirm": "强制更新"
|
||||
}
|
||||
},
|
||||
"messages": {
|
||||
@@ -185,7 +190,8 @@
|
||||
"reloadPlugin": "重载插件",
|
||||
"togglePlugin": "插件",
|
||||
"viewHandlers": "查看行为",
|
||||
"updateTo": "更新到"
|
||||
"updateTo": "更新到",
|
||||
"reinstall": "重新安装"
|
||||
},
|
||||
"status": {
|
||||
"hasUpdate": "有新版本可用",
|
||||
|
||||
@@ -77,14 +77,26 @@ const readmeDialog = reactive({
|
||||
repoUrl: null
|
||||
});
|
||||
|
||||
// 强制更新确认对话框
|
||||
const forceUpdateDialog = reactive({
|
||||
show: false,
|
||||
extensionName: ''
|
||||
});
|
||||
|
||||
// 新增变量支持列表视图
|
||||
const isListView = ref(false);
|
||||
// 从 localStorage 恢复显示模式,默认为 false(卡片视图)
|
||||
const getInitialListViewMode = () => {
|
||||
if (typeof window !== 'undefined' && window.localStorage) {
|
||||
return localStorage.getItem('pluginListViewMode') === 'true';
|
||||
}
|
||||
return false;
|
||||
};
|
||||
const isListView = ref(getInitialListViewMode());
|
||||
const pluginSearch = ref("");
|
||||
const loading_ = ref(false);
|
||||
|
||||
// 分页相关
|
||||
const currentPage = ref(1);
|
||||
const itemsPerPage = ref(6); // 每页显示6个卡片 (2行 x 3列,避免滚动)
|
||||
|
||||
// 危险插件确认对话框
|
||||
const dangerConfirmDialog = ref(false);
|
||||
@@ -113,7 +125,6 @@ const uploadTab = ref('file');
|
||||
const showPluginFullName = ref(false);
|
||||
const marketSearch = ref("");
|
||||
const debouncedMarketSearch = ref("");
|
||||
const filterKeys = ['name', 'desc', 'author'];
|
||||
const refreshingMarket = ref(false);
|
||||
const sortBy = ref('default'); // default, stars, author, updated
|
||||
const sortOrder = ref('desc'); // desc (降序) or asc (升序)
|
||||
@@ -162,18 +173,6 @@ const pluginHeaders = computed(() => [
|
||||
]);
|
||||
|
||||
|
||||
// 插件市场表头
|
||||
const pluginMarketHeaders = computed(() => [
|
||||
{ title: tm('table.headers.name'), key: 'name', maxWidth: '200px' },
|
||||
{ title: tm('table.headers.description'), key: 'desc', maxWidth: '250px' },
|
||||
{ title: tm('table.headers.author'), key: 'author', maxWidth: '90px' },
|
||||
{ title: tm('table.headers.stars'), key: 'stars', maxWidth: '80px' },
|
||||
{ title: tm('table.headers.lastUpdate'), key: 'updated_at', maxWidth: '100px' },
|
||||
{ title: tm('table.headers.tags'), key: 'tags', maxWidth: '100px' },
|
||||
{ title: tm('table.headers.actions'), key: 'actions', sortable: false }
|
||||
]);
|
||||
|
||||
|
||||
// 过滤要显示的插件
|
||||
const filteredExtensions = computed(() => {
|
||||
const data = Array.isArray(extension_data?.data) ? extension_data.data : [];
|
||||
@@ -197,9 +196,6 @@ const filteredPlugins = computed(() => {
|
||||
});
|
||||
});
|
||||
|
||||
const pinnedPlugins = computed(() => {
|
||||
return pluginMarketData.value.filter(plugin => plugin?.pinned);
|
||||
});
|
||||
|
||||
// 过滤后的插件市场数据(带搜索)
|
||||
const filteredMarketPlugins = computed(() => {
|
||||
@@ -385,7 +381,17 @@ const handleUninstallConfirm = (options) => {
|
||||
}
|
||||
};
|
||||
|
||||
const updateExtension = async (extension_name) => {
|
||||
const updateExtension = async (extension_name, forceUpdate = false) => {
|
||||
// 查找插件信息
|
||||
const ext = extension_data.data?.find(e => e.name === extension_name);
|
||||
|
||||
// 如果没有检测到更新且不是强制更新,则弹窗确认
|
||||
if (!ext?.has_update && !forceUpdate) {
|
||||
forceUpdateDialog.extensionName = extension_name;
|
||||
forceUpdateDialog.show = true;
|
||||
return;
|
||||
}
|
||||
|
||||
loadingDialog.title = tm('status.loading');
|
||||
loadingDialog.show = true;
|
||||
try {
|
||||
@@ -417,6 +423,14 @@ const updateExtension = async (extension_name) => {
|
||||
}
|
||||
};
|
||||
|
||||
// 确认强制更新
|
||||
const confirmForceUpdate = () => {
|
||||
const name = forceUpdateDialog.extensionName;
|
||||
forceUpdateDialog.show = false;
|
||||
forceUpdateDialog.extensionName = '';
|
||||
updateExtension(name, true);
|
||||
};
|
||||
|
||||
const updateAllExtensions = async () => {
|
||||
if (updatingAll.value || updatableExtensions.value.length === 0) return;
|
||||
updatingAll.value = true;
|
||||
@@ -552,14 +566,6 @@ const viewReadme = (plugin) => {
|
||||
readmeDialog.show = true;
|
||||
};
|
||||
|
||||
|
||||
|
||||
const open = (link) => {
|
||||
if (link) {
|
||||
window.open(link, '_blank');
|
||||
}
|
||||
};
|
||||
|
||||
// 为表格视图创建一个处理安装插件的函数
|
||||
const handleInstallPlugin = async (plugin) => {
|
||||
if (plugin.tags && plugin.tags.includes('danger')) {
|
||||
@@ -918,6 +924,13 @@ watch(marketSearch, (newVal) => {
|
||||
}, 300); // 300ms 防抖延迟
|
||||
});
|
||||
|
||||
// 监听显示模式变化并保存到 localStorage
|
||||
watch(isListView, (newVal) => {
|
||||
if (typeof window !== 'undefined' && window.localStorage) {
|
||||
localStorage.setItem('pluginListViewMode', String(newVal));
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
</script>
|
||||
|
||||
@@ -1037,8 +1050,21 @@ watch(marketSearch, (newVal) => {
|
||||
|
||||
<template v-slot:item.name="{ item }">
|
||||
<div class="d-flex align-center py-2">
|
||||
<div v-if="item.logo" class="mr-3" style="flex-shrink: 0;">
|
||||
<img :src="item.logo" :alt="item.name"
|
||||
style="height: 40px; width: 40px; border-radius: 8px; object-fit: cover;" />
|
||||
</div>
|
||||
<div v-else class="mr-3" style="flex-shrink: 0;">
|
||||
<img :src="defaultPluginIcon" :alt="item.name"
|
||||
style="height: 40px; width: 40px; border-radius: 8px; object-fit: cover;" />
|
||||
</div>
|
||||
<div>
|
||||
<div class="text-subtitle-1 font-weight-medium">{{ item.name }}</div>
|
||||
<div class="text-subtitle-1 font-weight-medium">
|
||||
{{ item.display_name && item.display_name.length ? item.display_name : item.name }}
|
||||
</div>
|
||||
<div v-if="item.display_name && item.display_name.length" class="text-caption text-medium-emphasis mt-1">
|
||||
{{ item.name }}
|
||||
</div>
|
||||
<div v-if="item.reserved" class="d-flex align-center mt-1">
|
||||
<v-chip color="primary" size="x-small" class="font-weight-medium">{{ tm('status.system')
|
||||
}}</v-chip>
|
||||
@@ -1048,7 +1074,7 @@ watch(marketSearch, (newVal) => {
|
||||
</template>
|
||||
|
||||
<template v-slot:item.desc="{ item }">
|
||||
<div class="text-body-2 text-medium-emphasis">{{ item.desc }}</div>
|
||||
<div class="text-body-2 text-medium-emphasis mt-2 mb-2" style="display: -webkit-box; -webkit-line-clamp: 3; line-clamp: 3; -webkit-box-orient: vertical; overflow: hidden; text-overflow: ellipsis;">{{ item.desc }}</div>
|
||||
</template>
|
||||
|
||||
<template v-slot:item.version="{ item }">
|
||||
@@ -1084,7 +1110,7 @@ watch(marketSearch, (newVal) => {
|
||||
<v-tooltip activator="parent" location="top">{{ tm('tooltips.disable') }}</v-tooltip>
|
||||
</v-btn>
|
||||
|
||||
<v-btn icon size="small" color="info" @click="reloadPlugin(item.name)">
|
||||
<v-btn icon size="small" @click="reloadPlugin(item.name)">
|
||||
<v-icon>mdi-refresh</v-icon>
|
||||
<v-tooltip activator="parent" location="top">{{ tm('tooltips.reload') }}</v-tooltip>
|
||||
</v-btn>
|
||||
@@ -1104,8 +1130,7 @@ watch(marketSearch, (newVal) => {
|
||||
<v-tooltip activator="parent" location="top">{{ tm('tooltips.viewDocs') }}</v-tooltip>
|
||||
</v-btn>
|
||||
|
||||
<v-btn icon size="small" color="warning" @click="updateExtension(item.name)"
|
||||
:v-show="item.has_update">
|
||||
<v-btn icon size="small" @click="updateExtension(item.name)">
|
||||
<v-icon>mdi-update</v-icon>
|
||||
<v-tooltip activator="parent" location="top">{{ tm('tooltips.update') }}</v-tooltip>
|
||||
</v-btn>
|
||||
@@ -1772,6 +1797,24 @@ watch(marketSearch, (newVal) => {
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 强制更新确认对话框 -->
|
||||
<v-dialog v-model="forceUpdateDialog.show" max-width="420">
|
||||
<v-card class="rounded-lg">
|
||||
<v-card-title class="text-h6 d-flex align-center">
|
||||
<v-icon color="info" class="mr-2">mdi-information-outline</v-icon>
|
||||
{{ tm('dialogs.forceUpdate.title') }}
|
||||
</v-card-title>
|
||||
<v-card-text>
|
||||
{{ tm('dialogs.forceUpdate.message') }}
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="forceUpdateDialog.show = false">{{ tm('buttons.cancel') }}</v-btn>
|
||||
<v-btn color="primary" variant="flat" @click="confirmForceUpdate">{{ tm('dialogs.forceUpdate.confirm') }}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "4.10.4"
|
||||
version = "4.11.0"
|
||||
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -0,0 +1,774 @@
|
||||
"""Comprehensive tests for ContextManager."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Add parent directory to path to avoid circular import issues
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from astrbot.core.agent.context.config import ContextConfig
|
||||
from astrbot.core.agent.context.manager import ContextManager
|
||||
from astrbot.core.agent.message import Message, TextPart
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
|
||||
class MockProvider:
|
||||
"""模拟 Provider"""
|
||||
|
||||
def __init__(self):
|
||||
self.provider_config = {
|
||||
"id": "test_provider",
|
||||
"model": "gpt-4",
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
}
|
||||
|
||||
async def text_chat(self, **kwargs):
|
||||
"""模拟 LLM 调用,返回摘要"""
|
||||
messages = kwargs.get("messages", [])
|
||||
# 简单的摘要逻辑:返回消息数量统计
|
||||
return LLMResponse(
|
||||
role="assistant",
|
||||
completion_text=f"历史对话包含 {len(messages) - 1} 条消息,主要讨论了技术话题。",
|
||||
)
|
||||
|
||||
def get_model(self):
|
||||
return "gpt-4"
|
||||
|
||||
def meta(self):
|
||||
return MagicMock(id="test_provider", type="openai")
|
||||
|
||||
|
||||
class TestContextManager:
|
||||
"""Test suite for ContextManager."""
|
||||
|
||||
def create_message(
|
||||
self, role: Literal["system", "user", "assistant", "tool"], content: str
|
||||
) -> Message:
|
||||
"""Helper to create a simple text message."""
|
||||
return Message(role=role, content=content)
|
||||
|
||||
def create_messages(self, count: int) -> list[Message]:
|
||||
"""Helper to create alternating user/assistant messages."""
|
||||
messages = []
|
||||
for i in range(count):
|
||||
role = "user" if i % 2 == 0 else "assistant"
|
||||
messages.append(self.create_message(role, f"Message {i}"))
|
||||
return messages
|
||||
|
||||
# ==================== Basic Initialization Tests ====================
|
||||
|
||||
def test_init_with_minimal_config(self):
|
||||
"""Test initialization with minimal configuration."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
assert manager.config == config
|
||||
assert manager.token_counter is not None
|
||||
assert manager.truncator is not None
|
||||
assert manager.compressor is not None
|
||||
|
||||
def test_init_with_llm_compressor(self):
|
||||
"""Test initialization with LLM-based compression."""
|
||||
mock_provider = MockProvider()
|
||||
config = ContextConfig(
|
||||
llm_compress_provider=mock_provider, # type: ignore
|
||||
llm_compress_keep_recent=5,
|
||||
llm_compress_instruction="Summarize the conversation",
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
from astrbot.core.agent.context.compressor import LLMSummaryCompressor
|
||||
|
||||
assert isinstance(manager.compressor, LLMSummaryCompressor)
|
||||
|
||||
def test_init_with_truncate_compressor(self):
|
||||
"""Test initialization with truncate-based compression (default)."""
|
||||
config = ContextConfig(truncate_turns=3)
|
||||
manager = ContextManager(config)
|
||||
|
||||
from astrbot.core.agent.context.compressor import TruncateByTurnsCompressor
|
||||
|
||||
assert isinstance(manager.compressor, TruncateByTurnsCompressor)
|
||||
|
||||
# ==================== Empty and Edge Cases ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_empty_messages(self):
|
||||
"""Test processing an empty message list."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
result = await manager.process([])
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_single_message(self):
|
||||
"""Test processing a single message."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "Hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_with_no_limits(self):
|
||||
"""Test processing when no limits are set (no truncation or compression)."""
|
||||
config = ContextConfig(max_context_tokens=0, enforce_max_turns=-1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(20)
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 20
|
||||
assert result == messages
|
||||
|
||||
# ==================== Enforce Max Turns Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_basic(self):
|
||||
"""Test basic enforce_max_turns functionality."""
|
||||
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create 10 turns (20 messages)
|
||||
messages = self.create_messages(20)
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should keep only 3 most recent turns (6 messages)
|
||||
assert len(result) <= 8 # May vary due to truncation logic
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_zero(self):
|
||||
"""Test enforce_max_turns with value 0 (should keep nothing)."""
|
||||
config = ContextConfig(enforce_max_turns=0, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(10)
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should result in empty or minimal message list
|
||||
assert len(result) <= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_negative(self):
|
||||
"""Test enforce_max_turns with -1 (no limit)."""
|
||||
config = ContextConfig(enforce_max_turns=-1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(20)
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 20
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_with_system_messages(self):
|
||||
"""Test enforce_max_turns preserves system messages."""
|
||||
config = ContextConfig(enforce_max_turns=2, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [
|
||||
self.create_message("system", "System instruction"),
|
||||
*self.create_messages(10),
|
||||
]
|
||||
result = await manager.process(messages)
|
||||
|
||||
# System message should be preserved
|
||||
system_msgs = [m for m in result if m.role == "system"]
|
||||
assert len(system_msgs) >= 1
|
||||
assert system_msgs[0].content == "System instruction"
|
||||
|
||||
# ==================== Token-based Compression Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_not_triggered_below_threshold(self):
|
||||
"""Test that compression is not triggered below threshold."""
|
||||
config = ContextConfig(max_context_tokens=1000)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that total less than threshold
|
||||
messages = [self.create_message("user", "Hi" * 50)] # ~100 tokens
|
||||
|
||||
with patch.object(
|
||||
manager.compressor, "should_compress", return_value=False
|
||||
) as mock_should_compress:
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# should_compress should be called
|
||||
mock_should_compress.assert_called_once()
|
||||
# Compressor should not be called
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_triggered_above_threshold(self):
|
||||
"""Test that compression is triggered above threshold."""
|
||||
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that exceed threshold (0.82 * 100 = 82 tokens)
|
||||
# 300 chars * 0.3 = 90 tokens > 82 threshold
|
||||
long_text = "x" * 300 # ~90 tokens, above threshold
|
||||
messages = [self.create_message("user", long_text)]
|
||||
|
||||
# Mock compressor to return smaller result
|
||||
compressed = [self.create_message("user", "short")]
|
||||
|
||||
# Create a mock compressor
|
||||
mock_compressor = AsyncMock()
|
||||
mock_compressor.compression_threshold = 0.82
|
||||
mock_compressor.return_value = compressed
|
||||
|
||||
# Mock should_compress to return True first time, False after
|
||||
call_count = 0
|
||||
|
||||
def mock_should_compress(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return call_count == 1
|
||||
|
||||
mock_compressor.should_compress = mock_should_compress
|
||||
manager.compressor = mock_compressor
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should be called
|
||||
mock_compressor.assert_called_once()
|
||||
# Result should be the compressed version
|
||||
assert len(result) <= len(messages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_with_zero_max_tokens(self):
|
||||
"""Test that compression is skipped when max_context_tokens is 0."""
|
||||
config = ContextConfig(max_context_tokens=0)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "x" * 10000)]
|
||||
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should not be called when max_context_tokens is 0
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_with_negative_max_tokens(self):
|
||||
"""Test that compression is skipped when max_context_tokens is negative."""
|
||||
config = ContextConfig(max_context_tokens=-100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "x" * 10000)]
|
||||
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should not be called
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_double_check_after_compression(self):
|
||||
"""Test that halving is applied if still over threshold after compression."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that would still be over threshold after compression
|
||||
long_messages = [self.create_message("user", "x" * 200) for _ in range(10)]
|
||||
|
||||
# Mock compressor to return messages still over threshold
|
||||
async def mock_compress(msgs):
|
||||
return msgs # Return same messages (still over limit)
|
||||
|
||||
# Mock should_compress to return True twice (before and after compression)
|
||||
with patch.object(manager.compressor, "should_compress", return_value=True):
|
||||
with patch.object(manager.compressor, "__call__", new=mock_compress):
|
||||
with patch.object(
|
||||
manager.truncator,
|
||||
"truncate_by_halving",
|
||||
return_value=long_messages[:5],
|
||||
) as mock_halving:
|
||||
_ = await manager.process(long_messages)
|
||||
|
||||
# Halving should be called
|
||||
mock_halving.assert_called_once()
|
||||
|
||||
# ==================== Combined Truncation and Compression Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_combined_enforce_turns_and_token_limit(self):
|
||||
"""Test combining enforce_max_turns and token limit."""
|
||||
config = ContextConfig(
|
||||
enforce_max_turns=5, max_context_tokens=500, truncate_turns=1
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create many messages
|
||||
messages = self.create_messages(30)
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should be truncated by both mechanisms
|
||||
assert len(result) < 30
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sequential_processing_order(self):
|
||||
"""Test that enforce_max_turns happens before token compression."""
|
||||
config = ContextConfig(enforce_max_turns=5, max_context_tokens=1000)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(20)
|
||||
|
||||
# Mock the truncator to track calls
|
||||
with patch.object(
|
||||
manager.truncator,
|
||||
"truncate_by_turns",
|
||||
wraps=manager.truncator.truncate_by_turns,
|
||||
) as mock_truncate:
|
||||
await manager.process(messages)
|
||||
|
||||
# Truncator should be called first
|
||||
mock_truncate.assert_called_once()
|
||||
|
||||
# ==================== Error Handling Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_returns_original_messages(self):
|
||||
"""Test that errors during processing return original messages."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(5)
|
||||
|
||||
# Make compressor raise an exception
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", side_effect=Exception("Test error")
|
||||
):
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should return original messages despite error
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_logs_exception(self):
|
||||
"""Test that errors are logged."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that will trigger compression (> 82 tokens)
|
||||
messages = [self.create_message("user", "x" * 300)] # ~90 tokens
|
||||
|
||||
# Replace compressor with one that raises an exception
|
||||
mock_compressor = AsyncMock(side_effect=Exception("Test error"))
|
||||
mock_compressor.compression_threshold = 0.82
|
||||
mock_compressor.should_compress = MagicMock(return_value=True)
|
||||
manager.compressor = mock_compressor
|
||||
|
||||
with patch("astrbot.core.agent.context.manager.logger") as mock_logger:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Logger error method should be called
|
||||
assert mock_logger.error.called
|
||||
# Should return original messages on error
|
||||
assert result == messages
|
||||
|
||||
# ==================== Multi-modal Content Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_messages_with_textpart_content(self):
|
||||
"""Test processing messages with TextPart content."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [
|
||||
Message(role="user", content=[TextPart(text="Hello")]),
|
||||
Message(role="assistant", content=[TextPart(text="Hi there")]),
|
||||
]
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_counting_with_multimodal_content(self):
|
||||
"""Test token counting works with multi-modal content."""
|
||||
config = ContextConfig(max_context_tokens=50)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Need enough tokens to exceed threshold: 50 * 0.82 = 41 tokens
|
||||
# 150 chars * 0.3 = 45 tokens > 41
|
||||
messages = [
|
||||
Message(role="user", content=[TextPart(text="x" * 150)]),
|
||||
]
|
||||
|
||||
# Should trigger compression due to token count
|
||||
tokens = manager.token_counter.count_tokens(messages)
|
||||
needs_compression = manager.compressor.should_compress(messages, tokens, 50)
|
||||
|
||||
assert tokens > 0 # Tokens should be counted
|
||||
assert needs_compression # Should trigger compression
|
||||
|
||||
# ==================== Tool Calls Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_messages_with_tool_calls(self):
|
||||
"""Test processing messages with tool calls."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [
|
||||
Message(
|
||||
role="assistant",
|
||||
content="Let me search for that",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
),
|
||||
Message(role="tool", content="Search result", tool_call_id="call_1"),
|
||||
]
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
# ==================== Compressor should_compress Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_compress_empty_messages(self):
|
||||
"""Test should_compress with empty messages."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Compressor's should_compress should handle empty gracefully
|
||||
needs_compression = manager.compressor.should_compress([], 0, 100)
|
||||
assert not needs_compression
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_compress_below_threshold(self):
|
||||
"""Test should_compress when below compression threshold."""
|
||||
config = ContextConfig(max_context_tokens=1000)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
tokens = manager.token_counter.count_tokens(messages)
|
||||
|
||||
needs_compression = manager.compressor.should_compress(messages, tokens, 1000)
|
||||
assert not needs_compression
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_compress_above_threshold(self):
|
||||
"""Test should_compress when above compression threshold."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create message with many tokens
|
||||
messages = [self.create_message("user", "这是测试" * 50)]
|
||||
tokens = manager.token_counter.count_tokens(messages)
|
||||
|
||||
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
|
||||
# Should need compression if tokens > 82 (0.82 * 100)
|
||||
assert needs_compression == (tokens > 82)
|
||||
|
||||
# ==================== Truncator Halving Tests ====================
|
||||
|
||||
def test_truncate_by_halving_basic(self):
|
||||
"""Test truncate_by_halving removes middle 50%."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(10)
|
||||
result = manager.truncator.truncate_by_halving(messages)
|
||||
|
||||
# Should keep roughly half
|
||||
assert len(result) < len(messages)
|
||||
|
||||
def test_truncate_by_halving_empty_list(self):
|
||||
"""Test truncate_by_halving with empty list."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
result = manager.truncator.truncate_by_halving([])
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_truncate_by_halving_single_message(self):
|
||||
"""Test truncate_by_halving with single message."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
result = manager.truncator.truncate_by_halving(messages)
|
||||
|
||||
assert len(result) <= 1
|
||||
|
||||
# ==================== Complex Scenarios ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_compression_cycles(self):
|
||||
"""Test that compression can be triggered multiple times in sequence."""
|
||||
config = ContextConfig(max_context_tokens=50, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Process messages multiple times
|
||||
messages = self.create_messages(10)
|
||||
|
||||
result1 = await manager.process(messages)
|
||||
result2 = await manager.process(result1)
|
||||
result3 = await manager.process(result2)
|
||||
|
||||
# Each cycle should maintain or reduce message count
|
||||
assert len(result3) <= len(result2) <= len(result1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alternating_roles_preserved(self):
|
||||
"""Test that user/assistant alternation is preserved after processing."""
|
||||
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(20)
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Check that roles still alternate (excluding system messages)
|
||||
non_system = [m for m in result if m.role != "system"]
|
||||
if len(non_system) >= 2:
|
||||
# Should start with user
|
||||
assert non_system[0].role == "user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compression_threshold_default(self):
|
||||
"""Test that compression threshold is used correctly."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Verify the default threshold is 0.82
|
||||
assert manager.compressor.compression_threshold == 0.82
|
||||
|
||||
# Test threshold logic
|
||||
messages = [self.create_message("user", "x" * 81)] # ~24 tokens
|
||||
tokens = manager.token_counter.count_tokens(messages)
|
||||
|
||||
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
|
||||
# Should not compress if below threshold
|
||||
assert needs_compression == (tokens > 82)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_batch_processing(self):
|
||||
"""Test processing a large batch of messages."""
|
||||
config = ContextConfig(
|
||||
enforce_max_turns=10, max_context_tokens=1000, truncate_turns=2
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create 100 messages (50 turns)
|
||||
messages = self.create_messages(100)
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should be significantly reduced
|
||||
assert len(result) < 100
|
||||
assert len(result) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_persistence(self):
|
||||
"""Test that config settings are respected throughout processing."""
|
||||
config = ContextConfig(
|
||||
max_context_tokens=500,
|
||||
enforce_max_turns=5,
|
||||
truncate_turns=2,
|
||||
llm_compress_keep_recent=3,
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Verify config is stored
|
||||
assert manager.config.max_context_tokens == 500
|
||||
assert manager.config.enforce_max_turns == 5
|
||||
assert manager.config.truncate_turns == 2
|
||||
assert manager.config.llm_compress_keep_recent == 3
|
||||
|
||||
# ==================== Run Compression Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_compression_calls_compressor(self):
|
||||
"""Test _run_compression calls compressor."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(5)
|
||||
compressed = self.create_messages(3)
|
||||
|
||||
# Create a mock compressor
|
||||
mock_compressor = AsyncMock()
|
||||
mock_compressor.compression_threshold = 0.82
|
||||
mock_compressor.return_value = compressed
|
||||
mock_compressor.should_compress = MagicMock(return_value=False)
|
||||
manager.compressor = mock_compressor
|
||||
|
||||
result = await manager._run_compression(messages, prev_tokens=100)
|
||||
|
||||
# Compressor __call__ should be invoked
|
||||
mock_compressor.assert_called_once_with(messages)
|
||||
assert result == compressed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_compression_applies_compressor_through_process(self):
|
||||
"""Test _run_compression calls compressor when needed through process()."""
|
||||
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that will trigger compression
|
||||
messages = [self.create_message("user", "x" * 300)] # ~90 tokens > 82 threshold
|
||||
compressed = [self.create_message("user", "short")] # Much smaller
|
||||
|
||||
# Create a mock compressor
|
||||
mock_compressor = AsyncMock()
|
||||
mock_compressor.compression_threshold = 0.82
|
||||
mock_compressor.return_value = compressed
|
||||
|
||||
# Mock should_compress to return True first time, False after
|
||||
call_count = 0
|
||||
|
||||
def mock_should_compress(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return call_count == 1
|
||||
|
||||
mock_compressor.should_compress = mock_should_compress
|
||||
manager.compressor = mock_compressor
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should have been called
|
||||
mock_compressor.assert_called_once()
|
||||
assert len(result) <= len(messages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_compression_with_mock_provider(self):
|
||||
"""Test LLM compression using MockProvider."""
|
||||
mock_provider = MockProvider()
|
||||
config = ContextConfig(
|
||||
llm_compress_provider=mock_provider, # type: ignore
|
||||
llm_compress_keep_recent=3,
|
||||
llm_compress_instruction="请总结对话内容",
|
||||
max_context_tokens=100,
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that will trigger compression
|
||||
messages = [
|
||||
self.create_message("user", "x" * 100),
|
||||
self.create_message("assistant", "y" * 100),
|
||||
self.create_message("user", "z" * 100),
|
||||
]
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should have been compressed
|
||||
assert len(result) <= len(messages)
|
||||
|
||||
# ==================== split_history Tests ====================
|
||||
|
||||
def test_split_history_ensures_user_start(self):
|
||||
"""Test split_history ensures recent_messages starts with user message."""
|
||||
from astrbot.core.agent.context.compressor import split_history
|
||||
|
||||
# Create alternating messages: user, assistant, user, assistant, user, assistant
|
||||
messages = [
|
||||
self.create_message("system", "System prompt"),
|
||||
self.create_message("user", "msg1"),
|
||||
self.create_message("assistant", "msg2"),
|
||||
self.create_message("user", "msg3"),
|
||||
self.create_message("assistant", "msg4"),
|
||||
self.create_message("user", "msg5"),
|
||||
self.create_message("assistant", "msg6"),
|
||||
]
|
||||
|
||||
# Keep recent 3 messages - should adjust to start with user
|
||||
system, to_summarize, recent = split_history(messages, keep_recent=3)
|
||||
|
||||
# recent_messages should start with user message
|
||||
assert len(recent) > 0
|
||||
assert recent[0].role == "user"
|
||||
|
||||
# messages_to_summarize should end with assistant (complete turn)
|
||||
if len(to_summarize) > 0:
|
||||
assert to_summarize[-1].role == "assistant"
|
||||
|
||||
def test_split_history_handles_assistant_at_split_point(self):
|
||||
"""Test split_history when assistant message is at the intended split point."""
|
||||
from astrbot.core.agent.context.compressor import split_history
|
||||
|
||||
messages = [
|
||||
self.create_message("user", "msg1"),
|
||||
self.create_message("assistant", "msg2"),
|
||||
self.create_message("user", "msg3"),
|
||||
self.create_message("assistant", "msg4"), # <- intended split here
|
||||
self.create_message("user", "msg5"),
|
||||
self.create_message("assistant", "msg6"),
|
||||
]
|
||||
|
||||
# keep_recent=2 would normally split at index 4 (assistant msg4)
|
||||
# Should move back to include from msg5 (user)
|
||||
system, to_summarize, recent = split_history(messages, keep_recent=2)
|
||||
|
||||
# recent should start with user message
|
||||
assert recent[0].role == "user"
|
||||
assert recent[0].content == "msg5"
|
||||
|
||||
def test_split_history_all_assistant_messages(self):
|
||||
"""Test split_history when there are consecutive assistant messages."""
|
||||
from astrbot.core.agent.context.compressor import split_history
|
||||
|
||||
messages = [
|
||||
self.create_message("user", "msg1"),
|
||||
self.create_message("assistant", "msg2"),
|
||||
self.create_message("assistant", "msg3"),
|
||||
self.create_message("assistant", "msg4"),
|
||||
]
|
||||
|
||||
system, to_summarize, recent = split_history(messages, keep_recent=2)
|
||||
|
||||
# Should find the user message and keep from there
|
||||
if len(recent) > 0:
|
||||
# Find first user message backwards
|
||||
assert any(m.role == "user" for m in messages)
|
||||
|
||||
def test_split_history_with_system_messages(self):
|
||||
"""Test split_history preserves system messages separately."""
|
||||
from astrbot.core.agent.context.compressor import split_history
|
||||
|
||||
messages = [
|
||||
self.create_message("system", "System 1"),
|
||||
self.create_message("system", "System 2"),
|
||||
self.create_message("user", "msg1"),
|
||||
self.create_message("assistant", "msg2"),
|
||||
self.create_message("user", "msg3"),
|
||||
]
|
||||
|
||||
system, to_summarize, recent = split_history(messages, keep_recent=2)
|
||||
|
||||
# System messages should be separate
|
||||
assert len(system) == 2
|
||||
assert all(m.role == "system" for m in system)
|
||||
|
||||
# Recent should start with user
|
||||
if len(recent) > 0:
|
||||
assert recent[0].role == "user"
|
||||
@@ -0,0 +1,423 @@
|
||||
"""Tests for ContextTruncator."""
|
||||
|
||||
from astrbot.core.agent.context.truncator import ContextTruncator
|
||||
from astrbot.core.agent.message import Message
|
||||
|
||||
|
||||
class TestContextTruncator:
|
||||
"""Test suite for ContextTruncator."""
|
||||
|
||||
def create_message(self, role: str, content: str = "test content") -> Message:
|
||||
"""Helper to create a simple test message."""
|
||||
return Message(role=role, content=content)
|
||||
|
||||
def create_messages(
|
||||
self, count: int, include_system: bool = False
|
||||
) -> list[Message]:
|
||||
"""Helper to create alternating user/assistant messages.
|
||||
|
||||
Args:
|
||||
count: Number of messages to create
|
||||
include_system: Whether to include a system message at the start
|
||||
|
||||
Returns:
|
||||
List of messages
|
||||
"""
|
||||
messages = []
|
||||
if include_system:
|
||||
messages.append(self.create_message("system", "System prompt"))
|
||||
|
||||
for i in range(count):
|
||||
role = "user" if i % 2 == 0 else "assistant"
|
||||
messages.append(self.create_message(role, f"Message {i}"))
|
||||
return messages
|
||||
|
||||
# ==================== fix_messages Tests ====================
|
||||
|
||||
def test_fix_messages_empty_list(self):
|
||||
"""Test fix_messages with an empty list."""
|
||||
truncator = ContextTruncator()
|
||||
result = truncator.fix_messages([])
|
||||
assert result == []
|
||||
|
||||
def test_fix_messages_normal_messages(self):
|
||||
"""Test fix_messages with normal user/assistant messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Hello"),
|
||||
self.create_message("assistant", "Hi"),
|
||||
self.create_message("user", "How are you?"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 3
|
||||
assert result == messages
|
||||
|
||||
def test_fix_messages_tool_with_valid_context(self):
|
||||
"""Test fix_messages with tool message after user+assistant."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 3
|
||||
assert result == messages
|
||||
|
||||
def test_fix_messages_tool_without_context(self):
|
||||
"""Test fix_messages with tool message without enough context."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
# Tool message without context should be removed
|
||||
assert len(result) == 0
|
||||
|
||||
def test_fix_messages_tool_with_only_one_message(self):
|
||||
"""Test fix_messages with tool message after only one message."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Hello"),
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
# Tool message without enough context should be removed
|
||||
assert len(result) == 0
|
||||
|
||||
def test_fix_messages_multiple_tools(self):
|
||||
"""Test fix_messages with multiple tool messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool 1 result"),
|
||||
self.create_message("tool", "Tool 2 result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 4
|
||||
assert result == messages
|
||||
|
||||
def test_fix_messages_mixed_system_tool(self):
|
||||
"""Test fix_messages with system message and tool messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("system", "System prompt"),
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 4
|
||||
assert result == messages
|
||||
|
||||
# ==================== truncate_by_turns Tests ====================
|
||||
|
||||
def test_truncate_by_turns_no_limit(self):
|
||||
"""Test truncate_by_turns with -1 (no limit)."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_turns(messages, keep_most_recent_turns=-1)
|
||||
assert len(result) == 20
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_turns_basic(self):
|
||||
"""Test basic truncate_by_turns functionality."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 10 messages = 5 turns (user/assistant pairs)
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=3, drop_turns=1
|
||||
)
|
||||
|
||||
# Should keep 3 most recent turns (6 messages)
|
||||
assert len(result) <= 8 # (3-1+1)*2 = 6, but may adjust for correct format
|
||||
|
||||
def test_truncate_by_turns_with_system_message(self):
|
||||
"""Test truncate_by_turns preserves system messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10, include_system=True)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=2, drop_turns=1
|
||||
)
|
||||
|
||||
# System message should always be preserved
|
||||
assert result[0].role == "system"
|
||||
assert result[0].content == "System prompt"
|
||||
|
||||
def test_truncate_by_turns_zero_keep(self):
|
||||
"""Test truncate_by_turns with keep_most_recent_turns=0."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=0, drop_turns=1
|
||||
)
|
||||
|
||||
# Should result in empty or minimal list
|
||||
assert len(result) == 0
|
||||
|
||||
def test_truncate_by_turns_below_threshold(self):
|
||||
"""Test truncate_by_turns when messages are below threshold."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 4 messages = 2 turns
|
||||
messages = self.create_messages(4)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=5, drop_turns=1
|
||||
)
|
||||
|
||||
# No truncation should happen
|
||||
assert len(result) == 4
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_turns_exact_threshold(self):
|
||||
"""Test truncate_by_turns when messages exactly match threshold."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 6 messages = 3 turns
|
||||
messages = self.create_messages(6)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=3, drop_turns=1
|
||||
)
|
||||
|
||||
# No truncation should happen
|
||||
assert len(result) == 6
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_turns_ensures_user_first(self):
|
||||
"""Test that truncate_by_turns ensures user message comes first."""
|
||||
truncator = ContextTruncator()
|
||||
# Create scenario where truncation might start with assistant
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=3, drop_turns=1
|
||||
)
|
||||
|
||||
# First non-system message should be user
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_turns_multiple_drop(self):
|
||||
"""Test truncate_by_turns with multiple turns dropped at once."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=5, drop_turns=3
|
||||
)
|
||||
|
||||
# Should drop 3 turns when limit exceeded
|
||||
assert len(result) < len(messages)
|
||||
|
||||
# ==================== truncate_by_dropping_oldest_turns Tests ====================
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_zero(self):
|
||||
"""Test truncate_by_dropping_oldest_turns with drop_turns=0."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=0)
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_negative(self):
|
||||
"""Test truncate_by_dropping_oldest_turns with negative drop_turns."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=-1)
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_basic(self):
|
||||
"""Test basic truncate_by_dropping_oldest_turns functionality."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 10 messages = 5 turns
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
|
||||
|
||||
# Should drop 2 oldest turns (4 messages)
|
||||
assert len(result) == 6
|
||||
# Should start with user message
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_with_system(self):
|
||||
"""Test truncate_by_dropping_oldest_turns preserves system messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10, include_system=True)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
|
||||
|
||||
# System message should be preserved
|
||||
assert result[0].role == "system"
|
||||
assert result[0].content == "System prompt"
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_drop_all(self):
|
||||
"""Test truncate_by_dropping_oldest_turns dropping all turns."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 4 messages = 2 turns
|
||||
messages = self.create_messages(4)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
|
||||
|
||||
# Should drop all turns
|
||||
assert len(result) == 0
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self):
|
||||
"""Test truncate_by_dropping_oldest_turns with drop_turns > available turns."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 4 messages = 2 turns
|
||||
messages = self.create_messages(4)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=5)
|
||||
|
||||
# Should result in empty list
|
||||
assert len(result) == 0
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_ensures_user_first(self):
|
||||
"""Test that result starts with user message after dropping."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=3)
|
||||
|
||||
# First message should be user
|
||||
if len(result) > 0:
|
||||
assert result[0].role == "user"
|
||||
|
||||
# ==================== truncate_by_halving Tests ====================
|
||||
|
||||
def test_truncate_by_halving_empty(self):
|
||||
"""Test truncate_by_halving with empty list."""
|
||||
truncator = ContextTruncator()
|
||||
result = truncator.truncate_by_halving([])
|
||||
assert result == []
|
||||
|
||||
def test_truncate_by_halving_single_message(self):
|
||||
"""Test truncate_by_halving with single message."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
# Should not truncate if <= 2 messages
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_halving_two_messages(self):
|
||||
"""Test truncate_by_halving with two messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(2)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
# Should not truncate if <= 2 messages
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_halving_basic(self):
|
||||
"""Test basic truncate_by_halving functionality."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 20 messages
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# Should delete 50% = 10 messages, keep 10
|
||||
assert len(result) == 10
|
||||
# First message should be user
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_halving_with_system_message(self):
|
||||
"""Test truncate_by_halving preserves system messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(20, include_system=True)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# System message should be preserved
|
||||
assert result[0].role == "system"
|
||||
assert result[0].content == "System prompt"
|
||||
|
||||
def test_truncate_by_halving_odd_count(self):
|
||||
"""Test truncate_by_halving with odd number of messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(11)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# Should delete floor(11/2) = 5 messages, keep 6
|
||||
# But after ensuring user first, may be 5
|
||||
assert len(result) >= 5
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_halving_ensures_user_first(self):
|
||||
"""Test that result starts with user message."""
|
||||
truncator = ContextTruncator()
|
||||
# Create messages starting with user
|
||||
messages = self.create_messages(30)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# First message should be user
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_halving_preserves_recent_messages(self):
|
||||
"""Test that truncate_by_halving keeps the most recent 50%."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Message 0"),
|
||||
self.create_message("assistant", "Message 1"),
|
||||
self.create_message("user", "Message 2"),
|
||||
self.create_message("assistant", "Message 3"),
|
||||
]
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# Should keep last 2 messages
|
||||
assert len(result) == 2
|
||||
assert result[0].content == "Message 2"
|
||||
assert result[1].content == "Message 3"
|
||||
|
||||
# ==================== Integration Tests ====================
|
||||
|
||||
def test_truncate_with_tool_messages(self):
|
||||
"""Test truncation with tool messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool result"),
|
||||
self.create_message("user", "Thanks"),
|
||||
self.create_message("assistant", "Welcome"),
|
||||
]
|
||||
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=1)
|
||||
|
||||
# First turn (user+assistant+tool) should be dropped
|
||||
# Tool message should be cleaned up by fix_messages
|
||||
assert len(result) <= 2
|
||||
|
||||
def test_chain_multiple_truncations(self):
|
||||
"""Test chaining multiple truncation methods."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(40, include_system=True)
|
||||
|
||||
# First: truncate by turns
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=10, drop_turns=2
|
||||
)
|
||||
# Then: halve
|
||||
result = truncator.truncate_by_halving(result)
|
||||
|
||||
# Should have system message + truncated content
|
||||
assert result[0].role == "system"
|
||||
assert len(result) < len(messages)
|
||||
|
||||
def test_empty_after_system_message(self):
|
||||
"""Test truncation when only system message exists."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [self.create_message("system", "System prompt")]
|
||||
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=5, drop_turns=1
|
||||
)
|
||||
|
||||
# Should keep system message
|
||||
assert len(result) == 1
|
||||
assert result[0].role == "system"
|
||||
|
||||
def test_all_system_messages(self):
|
||||
"""Test truncation with only system messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("system", "System 1"),
|
||||
self.create_message("system", "System 2"),
|
||||
]
|
||||
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=0, drop_turns=1
|
||||
)
|
||||
|
||||
# System messages should be preserved, but since there are no non-system
|
||||
# messages and keep_most_recent_turns=0, result should be system messages only
|
||||
assert len(result) >= 0 # May keep system messages or clear all
|
||||
if len(result) > 0:
|
||||
assert all(msg.role == "system" for msg in result)
|
||||
Reference in New Issue
Block a user