Compare commits

..

27 Commits

Author SHA1 Message Date
Soulter 82a96a8cce chore: bump version to 4.11.0 2026-01-05 18:03:35 +08:00
Soulter 343b153263 feat: add token_usage tracking to conversations and update related processing logic 2026-01-05 16:53:37 +08:00
Soulter 3a41b19318 fix: reorder import statements for consistency 2026-01-05 15:49:44 +08:00
Soulter af444ea6cc feat: implement context compression logic with dynamic threshold and token tracking 2026-01-05 14:12:13 +08:00
Soulter cb84db532e feat: update logging for context compression trigger 2026-01-05 11:35:33 +08:00
Soulter 99b82f48ec feat: enhance context compression with token tracking and logging 2026-01-05 11:34:18 +08:00
Soulter 00471f904e perf 2026-01-05 11:19:32 +08:00
Soulter 5df15c60ff fix 2026-01-05 11:01:18 +08:00
Soulter 32e523b7da ruff fix 2026-01-05 10:58:44 +08:00
Soulter 0de4fd9f0d chore: remove lock 2026-01-05 10:57:48 +08:00
Soulter e23a7e2505 feat: add MockProvider for LLM compression tests 2026-01-05 10:57:00 +08:00
Soulter 1ed4d9f484 Add comprehensive tests for ContextManager and ContextTruncator
- Implemented a full test suite for ContextManager covering initialization, message processing, token-based compression, and error handling.
- Added tests for ContextTruncator focusing on message fixing, truncation by turns, dropping oldest turns, and halving.
- Ensured that both test suites validate edge cases and maintain expected behavior with various message types, including system and tool messages.
2026-01-05 10:48:00 +08:00
Soulter d842155770 feat: context compressor
Co-authored-by: kawayiYokami <289104862@qq.com>
2026-01-05 00:28:54 +08:00
Gao Jinzhe 7f5cc7cf1a feat: add on_waiting_llm_request event hook (#4319)
* 加入on_waiting_llm_request钩子

* ruff check
2026-01-04 16:11:12 +08:00
Oscar Shaw f26867c77d ci(stale): 增加 stale action 每次运行的操作限制 (#4256) 2026-01-04 11:20:03 +08:00
Soulter a14d588b44 docs: add Matrix adapter to community maintained section in multiple languages 2026-01-04 10:15:16 +08:00
Soulter e236402d92 chore: update platform adapter name for clarity 2026-01-04 10:12:25 +08:00
Soulter 454841de10 fix: database is locked error when invoking tts command (#4313)
* fix: database is locked error when invoking /tts command

fixes: #4311

* chore: rm pnpm lockfile

* perf: 减少操作数据库的次数
2026-01-03 19:12:39 +08:00
clown145 442b5403df feat(webui): supports force update plugins (#4293) 2026-01-03 15:30:50 +08:00
Soulter 9db7bf59b8 docs: add new community group contact 2026-01-03 00:48:55 +08:00
雪語 3622504021 fix: retry failed due to a mismatch in the msg.id data type of a WeChat Official Account (#4292)
问题描述:
- 控制台显示正常发送消息,但公众号未收到
- 处理时间 > 5秒的消息几乎总是失败(如 AI 图片生成)
- 短消息(<5秒)正常工作

根本原因:
msg.id 是整数类型,但字典 key 使用字符串类型,导致类型不匹配。
检查时整数无法匹配字符串 key,导致每次都创建新的 future,
微信重试时无法重用,最终导致响应失败。

修复内容:
将 msg.id 转换为字符串后再检查字典
  if str(msg.id) in self.wexin_event_workers:

影响范围:
- 修复了微信重试时无法正确重用 future 的问题
- AI 图片生成、长文本生成等耗时操作现在可以正常工作
- 仅影响微信公众号适配器,其他平台不受影响

Fixes #1679
2026-01-02 22:16:04 +08:00
Soulter fc42db40ce chore: bump version to 4.10.6 2026-01-02 12:14:59 +08:00
Soulter e413a002c1 perf: list view mode toggle with localStorage support in ExtensionPage (#4288)
closes: #4253
2026-01-02 11:59:41 +08:00
tjc66666666 6437d759a3 fix: reasoning content inject for openai api (#4284) 2026-01-02 01:09:28 +08:00
Soulter c758b2d888 feat: use shell globbing to match umop config router (#4270)
* feat: use shell globbing to match umop config router

* rf

* fix: use fnmatchcase for case-sensitive matching in UmopConfigRouter
2025-12-31 23:10:12 +08:00
Soulter 510290fe0e chore: bump version to 4.10.5 2025-12-31 17:58:28 +08:00
Soulter c61d62edb6 fix: handle null item-meta in ConfigItemRenderer (#4269)
fixes: #4268
2025-12-31 17:55:49 +08:00
55 changed files with 2480 additions and 218 deletions
+1
View File
@@ -26,6 +26,7 @@ jobs:
- uses: actions/stale@v10
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
operations-per-run: 200
# 只处理带 bug 标签的 Issue
any-of-labels: 'bug'
+2
View File
@@ -132,6 +132,7 @@ uv run main.py
**社区维护**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
@@ -208,6 +209,7 @@ pre-commit install
- 5 群:822130018
- 6 群:753075035
- 7 群:743746109
- 8 群:1030353265
- 开发者群:975206796
### Telegram 群组
+1
View File
@@ -134,6 +134,7 @@ Or refer to the official documentation: [Deploy AstrBot from Source](https://ast
**Community Maintained**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili Direct Messages](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+1
View File
@@ -134,6 +134,7 @@ Ou consultez la documentation officielle : [Déployer AstrBot depuis les sources
**Maintenues par la communauté**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Messages directs Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+1
View File
@@ -134,6 +134,7 @@ uv run main.py
**コミュニティメンテナンス**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili ダイレクトメッセージ](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+1
View File
@@ -134,6 +134,7 @@ uv run main.py
**Поддерживаемые сообществом**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Личные сообщения Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+1
View File
@@ -134,6 +134,7 @@ uv run main.py
**社群維護**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili 私訊](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+4
View File
@@ -21,6 +21,9 @@ from astrbot.core.star.register import (
from astrbot.core.star.register import register_on_llm_request as on_llm_request
from astrbot.core.star.register import register_on_llm_response as on_llm_response
from astrbot.core.star.register import register_on_platform_loaded as on_platform_loaded
from astrbot.core.star.register import (
register_on_waiting_llm_request as on_waiting_llm_request,
)
from astrbot.core.star.register import register_permission_type as permission_type
from astrbot.core.star.register import (
register_platform_adapter_type as platform_adapter_type,
@@ -46,6 +49,7 @@ __all__ = [
"on_llm_request",
"on_llm_response",
"on_platform_loaded",
"on_waiting_llm_request",
"permission_type",
"platform_adapter_type",
"regex",
@@ -14,13 +14,13 @@ class TTSCommand:
async def tts(self, event: AstrMessageEvent):
"""开关文本转语音(会话级别)"""
umo = event.unified_msg_origin
ses_tts = SessionServiceManager.is_tts_enabled_for_session(umo)
ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo)
cfg = self.context.get_config(umo=umo)
tts_enable = cfg["provider_tts_settings"]["enable"]
# 切换状态
new_status = not ses_tts
SessionServiceManager.set_tts_status_for_session(umo, new_status)
await SessionServiceManager.set_tts_status_for_session(umo, new_status)
status_text = "已开启" if new_status else "已关闭"
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.10.4"
__version__ = "4.11.0"
+243
View File
@@ -0,0 +1,243 @@
from typing import TYPE_CHECKING, Protocol, runtime_checkable
from ..message import Message
if TYPE_CHECKING:
from astrbot import logger
else:
try:
from astrbot import logger
except ImportError:
import logging
logger = logging.getLogger("astrbot")
if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider
from ..context.truncator import ContextTruncator
@runtime_checkable
class ContextCompressor(Protocol):
"""
Protocol for context compressors.
Provides an interface for compressing message lists.
"""
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens for the model.
Returns:
True if compression is needed, False otherwise.
"""
...
async def __call__(self, messages: list[Message]) -> list[Message]:
"""Compress the message list.
Args:
messages: The original message list.
Returns:
The compressed message list.
"""
...
class TruncateByTurnsCompressor:
"""Truncate by turns compressor implementation.
Truncates the message list by removing older turns.
"""
def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
"""Initialize the truncate by turns compressor.
Args:
truncate_turns: The number of turns to remove when truncating (default: 1).
compression_threshold: The compression trigger threshold (default: 0.82).
"""
self.truncate_turns = truncate_turns
self.compression_threshold = compression_threshold
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens.
Returns:
True if compression is needed, False otherwise.
"""
if max_tokens <= 0 or current_tokens <= 0:
return False
usage_rate = current_tokens / max_tokens
return usage_rate > self.compression_threshold
async def __call__(self, messages: list[Message]) -> list[Message]:
truncator = ContextTruncator()
truncated_messages = truncator.truncate_by_dropping_oldest_turns(
messages,
drop_turns=self.truncate_turns,
)
return truncated_messages
def split_history(
messages: list[Message], keep_recent: int
) -> tuple[list[Message], list[Message], list[Message]]:
"""Split the message list into system messages, messages to summarize, and recent messages.
Ensures that the split point is between complete user-assistant pairs to maintain conversation flow.
Args:
messages: The original message list.
keep_recent: The number of latest messages to keep.
Returns:
tuple: (system_messages, messages_to_summarize, recent_messages)
"""
# keep the system messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) <= keep_recent:
return system_messages, [], non_system_messages
# Find the split point, ensuring recent_messages starts with a user message
# This maintains complete conversation turns
split_index = len(non_system_messages) - keep_recent
# Search backward from split_index to find the first user message
# This ensures recent_messages starts with a user message (complete turn)
while split_index > 0 and non_system_messages[split_index].role != "user":
# TODO: +=1 or -=1 ? calculate by tokens
split_index -= 1
# If we couldn't find a user message, keep all messages as recent
if split_index == 0:
return system_messages, [], non_system_messages
messages_to_summarize = non_system_messages[:split_index]
recent_messages = non_system_messages[split_index:]
return system_messages, messages_to_summarize, recent_messages
class LLMSummaryCompressor:
"""LLM-based summary compressor.
Uses LLM to summarize the old conversation history, keeping the latest messages.
"""
def __init__(
self,
provider: "Provider",
keep_recent: int = 4,
instruction_text: str | None = None,
compression_threshold: float = 0.82,
):
"""Initialize the LLM summary compressor.
Args:
provider: The LLM provider instance.
keep_recent: The number of latest messages to keep (default: 4).
instruction_text: Custom instruction for summary generation.
compression_threshold: The compression trigger threshold (default: 0.82).
"""
self.provider = provider
self.keep_recent = keep_recent
self.compression_threshold = compression_threshold
self.instruction_text = instruction_text or (
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
"4. Write the summary in the user's language.\n"
)
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens.
Returns:
True if compression is needed, False otherwise.
"""
if max_tokens <= 0 or current_tokens <= 0:
return False
usage_rate = current_tokens / max_tokens
return usage_rate > self.compression_threshold
async def __call__(self, messages: list[Message]) -> list[Message]:
"""Use LLM to generate a summary of the conversation history.
Process:
1. Divide messages: keep the system message and the latest N messages.
2. Send the old messages + the instruction message to the LLM.
3. Reconstruct the message list: [system message, summary message, latest messages].
"""
if len(messages) <= self.keep_recent + 1:
return messages
system_messages, messages_to_summarize, recent_messages = split_history(
messages, self.keep_recent
)
if not messages_to_summarize:
return messages
# build payload
instruction_message = Message(role="user", content=self.instruction_text)
llm_payload = messages_to_summarize + [instruction_message]
# generate summary
try:
response = await self.provider.text_chat(contexts=llm_payload)
summary_content = response.completion_text
except Exception as e:
logger.error(f"Failed to generate summary: {e}")
return messages
# build result
result = []
result.extend(system_messages)
result.append(
Message(
role="user",
content=f"Our previous history conversation summary: {summary_content}",
)
)
result.append(
Message(
role="assistant",
content="Acknowledged the summary of our previous conversation history.",
)
)
result.extend(recent_messages)
return result
+35
View File
@@ -0,0 +1,35 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING
from .compressor import ContextCompressor
from .token_counter import TokenCounter
if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider
@dataclass
class ContextConfig:
"""Context configuration class."""
max_context_tokens: int = 0
"""Maximum number of context tokens. <= 0 means no limit."""
enforce_max_turns: int = -1 # -1 means no limit
"""Maximum number of conversation turns to keep. -1 means no limit. Executed before compression."""
truncate_turns: int = 1
"""Number of conversation turns to discard at once when truncation is triggered.
Two processes will use this value:
1. Enforce max turns truncation.
2. Truncation by turns compression strategy.
"""
llm_compress_instruction: str | None = None
"""Instruction prompt for LLM-based compression."""
llm_compress_keep_recent: int = 0
"""Number of recent messages to keep during LLM-based compression."""
llm_compress_provider: "Provider | None" = None
"""LLM provider used for compression tasks. If None, truncation strategy is used."""
custom_token_counter: TokenCounter | None = None
"""Custom token counting method. If None, the default method is used."""
custom_compressor: ContextCompressor | None = None
"""Custom context compression method. If None, the default method is used."""
+120
View File
@@ -0,0 +1,120 @@
from astrbot import logger
from ..message import Message
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
from .config import ContextConfig
from .token_counter import EstimateTokenCounter
from .truncator import ContextTruncator
class ContextManager:
"""Context compression manager."""
def __init__(
self,
config: ContextConfig,
):
"""Initialize the context manager.
There are two strategies to handle context limit reached:
1. Truncate by turns: remove older messages by turns.
2. LLM-based compression: use LLM to summarize old messages.
Args:
config: The context configuration.
"""
self.config = config
self.token_counter = config.custom_token_counter or EstimateTokenCounter()
self.truncator = ContextTruncator()
if config.custom_compressor:
self.compressor = config.custom_compressor
elif config.llm_compress_provider:
self.compressor = LLMSummaryCompressor(
provider=config.llm_compress_provider,
keep_recent=config.llm_compress_keep_recent,
instruction_text=config.llm_compress_instruction,
)
else:
self.compressor = TruncateByTurnsCompressor(
truncate_turns=config.truncate_turns
)
async def process(
self, messages: list[Message], trusted_token_usage: int = 0
) -> list[Message]:
"""Process the messages.
Args:
messages: The original message list.
Returns:
The processed message list.
"""
try:
result = messages
# 1. 基于轮次的截断 (Enforce max turns)
if self.config.enforce_max_turns != -1:
result = self.truncator.truncate_by_turns(
result,
keep_most_recent_turns=self.config.enforce_max_turns,
drop_turns=self.config.truncate_turns,
)
# 2. 基于 token 的压缩
if self.config.max_context_tokens > 0:
total_tokens = self.token_counter.count_tokens(
result, trusted_token_usage
)
if self.compressor.should_compress(
result, total_tokens, self.config.max_context_tokens
):
result = await self._run_compression(result, total_tokens)
return result
except Exception as e:
logger.error(f"Error during context processing: {e}", exc_info=True)
return messages
async def _run_compression(
self, messages: list[Message], prev_tokens: int
) -> list[Message]:
"""
Compress/truncate the messages.
Args:
messages: The original message list.
prev_tokens: The token count before compression.
Returns:
The compressed/truncated message list.
"""
logger.debug("Compress triggered, starting compression...")
messages = await self.compressor(messages)
# double check
tokens_after_summary = self.token_counter.count_tokens(messages)
# calculate compress rate
compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100
logger.info(
f"Compress completed."
f" {prev_tokens} -> {tokens_after_summary} tokens,"
f" compression rate: {compress_rate:.2f}%.",
)
# last check
if self.compressor.should_compress(
messages, tokens_after_summary, self.config.max_context_tokens
):
logger.info(
"Context still exceeds max tokens after compression, applying halving truncation..."
)
# still need compress, truncate by half
messages = self.truncator.truncate_by_halving(messages)
return messages
@@ -0,0 +1,64 @@
import json
from typing import Protocol, runtime_checkable
from ..message import Message, TextPart
@runtime_checkable
class TokenCounter(Protocol):
"""
Protocol for token counters.
Provides an interface for counting tokens in message lists.
"""
def count_tokens(
self, messages: list[Message], trusted_token_usage: int = 0
) -> int:
"""Count the total tokens in the message list.
Args:
messages: The message list.
trusted_token_usage: The total token usage that LLM API returned.
For some cases, this value is more accurate.
But some API does not return it, so the value defaults to 0.
Returns:
The total token count.
"""
...
class EstimateTokenCounter:
"""Estimate token counter implementation.
Provides a simple estimation of token count based on character types.
"""
def count_tokens(
self, messages: list[Message], trusted_token_usage: int = 0
) -> int:
if trusted_token_usage > 0:
return trusted_token_usage
total = 0
for msg in messages:
content = msg.content
if isinstance(content, str):
total += self._estimate_tokens(content)
elif isinstance(content, list):
# 处理多模态内容
for part in content:
if isinstance(part, TextPart):
total += self._estimate_tokens(part.text)
# 处理 Tool Calls
if msg.tool_calls:
for tc in msg.tool_calls:
tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump())
total += self._estimate_tokens(tc_str)
return total
def _estimate_tokens(self, text: str) -> int:
chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"])
other_count = len(text) - chinese_count
return int(chinese_count * 0.6 + other_count * 0.3)
+141
View File
@@ -0,0 +1,141 @@
from ..message import Message
class ContextTruncator:
"""Context truncator."""
def fix_messages(self, messages: list[Message]) -> list[Message]:
fixed_messages = []
for message in messages:
if message.role == "tool":
# tool block 前面必须要有 user 和 assistant block
if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
# 我们直接将之前的上下文都清空
fixed_messages = []
else:
fixed_messages.append(message)
else:
fixed_messages.append(message)
return fixed_messages
def truncate_by_turns(
self,
messages: list[Message],
keep_most_recent_turns: int,
drop_turns: int = 1,
) -> list[Message]:
"""截断上下文列表,确保不超过最大长度。
一个 turn 包含一个 user 消息和一个 assistant 消息。
这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。
Args:
messages: 上下文列表
keep_most_recent_turns: 保留最近的对话轮数
drop_turns: 一次性丢弃的对话轮数
Returns:
截断后的上下文列表
"""
if keep_most_recent_turns == -1:
return messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) // 2 <= keep_most_recent_turns:
return messages
num_to_keep = keep_most_recent_turns - drop_turns + 1
if num_to_keep <= 0:
truncated_contexts = []
else:
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
# 找到第一个 role 为 user 的索引,确保上下文格式正确
index = next(
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
None,
)
if index is not None and index > 0:
truncated_contexts = truncated_contexts[index:]
result = system_messages + truncated_contexts
return self.fix_messages(result)
def truncate_by_dropping_oldest_turns(
self,
messages: list[Message],
drop_turns: int = 1,
) -> list[Message]:
"""丢弃最旧的 N 个对话轮次。"""
if drop_turns <= 0:
return messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) // 2 <= drop_turns:
truncated_non_system = []
else:
truncated_non_system = non_system_messages[drop_turns * 2 :]
index = next(
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
)
if index is not None:
truncated_non_system = truncated_non_system[index:]
elif truncated_non_system:
truncated_non_system = []
result = system_messages + truncated_non_system
return self.fix_messages(result)
def truncate_by_halving(
self,
messages: list[Message],
) -> list[Message]:
"""对半砍策略,删除 50% 的消息"""
if len(messages) <= 2:
return messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
messages_to_delete = len(non_system_messages) // 2
if messages_to_delete == 0:
return messages
truncated_non_system = non_system_messages[messages_to_delete:]
index = next(
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
)
if index is not None:
truncated_non_system = truncated_non_system[index:]
result = system_messages + truncated_non_system
return self.fix_messages(result)
@@ -25,6 +25,10 @@ from astrbot.core.provider.entities import (
)
from astrbot.core.provider.provider import Provider
from ..context.compressor import ContextCompressor
from ..context.config import ContextConfig
from ..context.manager import ContextManager
from ..context.token_counter import TokenCounter
from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..response import AgentResponseData, AgentStats
@@ -47,10 +51,47 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
streaming: bool = False,
# enforce max turns, will discard older turns when exceeded BEFORE compression
# -1 means no limit
enforce_max_turns: int = -1,
# llm compressor
llm_compress_instruction: str | None = None,
llm_compress_keep_recent: int = 0,
llm_compress_provider: Provider | None = None,
# truncate by turns compressor
truncate_turns: int = 1,
# customize
custom_token_counter: TokenCounter | None = None,
custom_compressor: ContextCompressor | None = None,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.streaming = streaming
self.enforce_max_turns = enforce_max_turns
self.llm_compress_instruction = llm_compress_instruction
self.llm_compress_keep_recent = llm_compress_keep_recent
self.llm_compress_provider = llm_compress_provider
self.truncate_turns = truncate_turns
self.custom_token_counter = custom_token_counter
self.custom_compressor = custom_compressor
# we will do compress when:
# 1. before requesting LLM
# TODO: 2. after LLM output a tool call
self.context_config = ContextConfig(
# <=0 will never do compress
max_context_tokens=provider.provider_config.get("max_context_tokens", 0),
# enforce max turns before compression
enforce_max_turns=self.enforce_max_turns,
truncate_turns=self.truncate_turns,
llm_compress_instruction=self.llm_compress_instruction,
llm_compress_keep_recent=self.llm_compress_keep_recent,
llm_compress_provider=self.llm_compress_provider,
custom_token_counter=self.custom_token_counter,
custom_compressor=self.custom_compressor,
)
self.context_manager = ContextManager(self.context_config)
self.provider = provider
self.final_llm_resp = None
self._state = AgentState.IDLE
@@ -110,6 +151,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self._transition_state(AgentState.RUNNING)
llm_resp_result = None
# do truncate and compress
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
self.run_context.messages = await self.context_manager.process(
self.run_context.messages, trusted_token_usage=token_usage
)
async for llm_response in self._iter_llm_responses():
if llm_response.is_chunk:
# update ttft
+89 -28
View File
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.10.4"
VERSION = "4.11.0"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -83,6 +83,16 @@ DEFAULT_CONFIG = {
"default_personality": "default",
"persona_pool": ["*"],
"prompt_prefix": "{{prompt}}",
"context_limit_reached_strategy": "truncate_by_turns", # or llm_compress
"llm_compress_instruction": (
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
"4. Write the summary in the user's language.\n"
),
"llm_compress_keep_recent": 4,
"llm_compress_provider_id": "",
"max_context_length": -1,
"dequeue_context_length": 1,
"streaming_response": False,
@@ -179,6 +189,7 @@ class ChatProviderTemplate(TypedDict):
model: str
modalities: list
custom_extra_body: dict[str, Any]
max_context_tokens: int
CHAT_PROVIDER_TEMPLATE = {
@@ -187,6 +198,7 @@ CHAT_PROVIDER_TEMPLATE = {
"model": "",
"modalities": [],
"custom_extra_body": {},
"max_context_tokens": 0,
}
"""
@@ -227,7 +239,7 @@ CONFIG_METADATA_2 = {
"callback_server_host": "0.0.0.0",
"port": 6196,
},
"OneBot v11": {
"OneBot v11 (QQ 个人号等)": {
"id": "default",
"type": "aiocqhttp",
"enable": False,
@@ -235,16 +247,6 @@ CONFIG_METADATA_2 = {
"ws_reverse_port": 6199,
"ws_reverse_token": "",
},
"WeChatPadPro": {
"id": "wechatpadpro",
"type": "wechatpadpro",
"enable": False,
"admin_key": "stay33",
"host": "这里填写你的局域网IP或者公网服务器IP",
"port": 8059,
"wpp_active_message_poll": False,
"wpp_active_message_poll_interval": 3,
},
"微信公众平台": {
"id": "weixin_official_account",
"type": "weixin_official_account",
@@ -374,6 +376,16 @@ CONFIG_METADATA_2 = {
"satori_heartbeat_interval": 10,
"satori_reconnect_delay": 5,
},
"WeChatPadPro": {
"id": "wechatpadpro",
"type": "wechatpadpro",
"enable": False,
"admin_key": "stay33",
"host": "这里填写你的局域网IP或者公网服务器IP",
"port": 8059,
"wpp_active_message_poll": False,
"wpp_active_message_poll_interval": 3,
},
# "WebChat": {
# "id": "webchat",
# "type": "webchat",
@@ -2033,6 +2045,11 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
},
"max_context_tokens": {
"description": "模型上下文窗口大小",
"type": "int",
"hint": "模型最大上下文 Token 大小。如果为 0,则会自动从模型元数据填充(如有),也可手动修改。",
},
"dify_api_key": {
"description": "API Key",
"type": "string",
@@ -2540,6 +2557,66 @@ CONFIG_METADATA_3 = {
# "provider_settings.enable": True,
# },
# },
"truncate_and_compress": {
"description": "上下文管理策略",
"type": "object",
"items": {
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",
"type": "int",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.dequeue_context_length": {
"description": "丢弃对话轮数",
"type": "int",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.context_limit_reached_strategy": {
"description": "超出模型上下文窗口时的处理方式",
"type": "string",
"options": ["truncate_by_turns", "llm_compress"],
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
"condition": {
"provider_settings.agent_runner_type": "local",
},
"hint": "",
},
"provider_settings.llm_compress_instruction": {
"description": "上下文压缩提示词",
"type": "text",
"hint": "如果为空则使用默认提示词。",
"condition": {
"provider_settings.context_limit_reached_strategy": "llm_compress",
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.llm_compress_keep_recent": {
"description": "压缩时保留最近对话轮数",
"type": "int",
"hint": "始终保留的最近 N 轮对话。",
"condition": {
"provider_settings.context_limit_reached_strategy": "llm_compress",
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.llm_compress_provider_id": {
"description": "用于上下文压缩的模型提供商 ID",
"type": "string",
"_special": "select_provider",
"hint": "留空时将降级为“按对话轮数截断”的策略。",
"condition": {
"provider_settings.context_limit_reached_strategy": "llm_compress",
"provider_settings.agent_runner_type": "local",
},
},
},
},
"others": {
"description": "其他配置",
"type": "object",
@@ -2604,22 +2681,6 @@ CONFIG_METADATA_3 = {
"provider_settings.streaming_response": True,
},
},
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",
"type": "int",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.dequeue_context_length": {
"description": "丢弃对话轮数",
"type": "int",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.wake_prefix": {
"description": "LLM 聊天额外唤醒前缀 ",
"type": "string",
+4
View File
@@ -69,6 +69,7 @@ class ConversationManager:
persona_id=conv_v2.persona_id,
created_at=created_at,
updated_at=updated_at,
token_usage=conv_v2.token_usage,
)
async def new_conversation(
@@ -256,6 +257,7 @@ class ConversationManager:
history: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
token_usage: int | None = None,
) -> None:
"""更新会话的对话.
@@ -263,6 +265,7 @@ class ConversationManager:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
token_usage (int | None): token 使用量。None 表示不更新
"""
if not conversation_id:
@@ -274,6 +277,7 @@ class ConversationManager:
title=title,
persona_id=persona_id,
content=history,
token_usage=token_usage,
)
async def update_conversation_title(
+1
View File
@@ -90,6 +90,7 @@ class AstrBotCoreLifecycle:
# 初始化 UMOP 配置路由器
self.umop_config_router = UmopConfigRouter(sp=sp)
await self.umop_config_router.initialize()
# 初始化 AstrBot 配置管理器
self.astrbot_config_mgr = AstrBotConfigManager(
+1
View File
@@ -152,6 +152,7 @@ class BaseDatabase(abc.ABC):
title: str | None = None,
persona_id: str | None = None,
content: list[dict] | None = None,
token_usage: int | None = None,
) -> None:
"""Update a conversation's history."""
...
@@ -0,0 +1,61 @@
"""Migration script to add token_usage column to conversations table.
This migration adds the token_usage field to track token consumption for each conversation.
Changes:
- Adds token_usage column to conversations table (default: 0)
"""
from sqlalchemy import text
from astrbot.api import logger, sp
from astrbot.core.db import BaseDatabase
async def migrate_token_usage(db_helper: BaseDatabase):
"""Add token_usage column to conversations table.
This migration adds a new column to track token consumption in conversations.
"""
# 检查是否已经完成迁移
migration_done = await db_helper.get_preference(
"global", "global", "migration_done_token_usage_1"
)
if migration_done:
return
logger.info("开始执行数据库迁移(添加 conversations.token_usage 列)...")
# 这里只适配了 SQLite。因为截止至这一版本,AstrBot 仅支持 SQLite。
try:
async with db_helper.get_db() as session:
# 检查列是否已存在
result = await session.execute(text("PRAGMA table_info(conversations)"))
columns = result.fetchall()
column_names = [col[1] for col in columns]
if "token_usage" in column_names:
logger.info("token_usage 列已存在,跳过迁移")
await sp.put_async(
"global", "global", "migration_done_token_usage_1", True
)
return
# 添加 token_usage 列
await session.execute(
text(
"ALTER TABLE conversations ADD COLUMN token_usage INTEGER NOT NULL DEFAULT 0"
)
)
await session.commit()
logger.info("token_usage 列添加成功")
# 标记迁移完成
await sp.put_async("global", "global", "migration_done_token_usage_1", True)
logger.info("token_usage 迁移完成")
except Exception as e:
logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
raise
+7
View File
@@ -54,6 +54,11 @@ class ConversationV2(SQLModel, table=True):
)
title: str | None = Field(default=None, max_length=255)
persona_id: str | None = Field(default=None)
token_usage: int = Field(default=0, nullable=False)
"""content is a list of OpenAI-formated messages in list[dict] format.
token_usage is the total token value of the messages.
when 0, will use estimated token counter.
"""
__table_args__ = (
UniqueConstraint(
@@ -313,6 +318,8 @@ class Conversation:
persona_id: str | None = ""
created_at: int = 0
updated_at: int = 0
token_usage: int = 0
"""对话的总 token 数量。AstrBot 会保留最近一次 LLM 请求返回的总 token 数,方便统计。token_usage 可能为 0,表示未知。"""
class Personality(TypedDict):
+5 -1
View File
@@ -241,7 +241,9 @@ class SQLiteDatabase(BaseDatabase):
session.add(new_conversation)
return new_conversation
async def update_conversation(self, cid, title=None, persona_id=None, content=None):
async def update_conversation(
self, cid, title=None, persona_id=None, content=None, token_usage=None
):
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
@@ -255,6 +257,8 @@ class SQLiteDatabase(BaseDatabase):
values["persona_id"] = persona_id
if content is not None:
values["content"] = content
if token_usage is not None:
values["token_usage"] = token_usage
if not values:
return None
query = query.values(**values)
@@ -38,7 +38,7 @@ class AgentRequestSubStage(Stage):
)
return
if not SessionServiceManager.should_process_llm_request(event):
if not await SessionServiceManager.should_process_llm_request(event):
logger.debug(
f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing."
)
@@ -1,12 +1,12 @@
"""本地 Agent 模式的 LLM 调用 Stage"""
import asyncio
import copy
import json
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.agent.message import Message
from astrbot.core.agent.response import AgentStats
from astrbot.core.agent.tool import ToolSet
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.conversation_mgr import Conversation
@@ -24,6 +24,7 @@ from astrbot.core.provider.entities import (
)
from astrbot.core.star.star_handler import EventType, star_map
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.llm_metadata import LLM_METADATAS
from astrbot.core.utils.metrics import Metric
from astrbot.core.utils.session_lock import session_lock_manager
@@ -41,11 +42,6 @@ class InternalAgentSubStage(Stage):
self.ctx = ctx
conf = ctx.astrbot_config
settings = conf["provider_settings"]
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
self.max_context_length - 1,
)
self.streaming_response: bool = settings["streaming_response"]
self.unsupported_streaming_strategy: str = settings[
"unsupported_streaming_strategy"
@@ -65,6 +61,25 @@ class InternalAgentSubStage(Stage):
"moonshotai_api_key", ""
)
# 上下文管理相关
self.context_limit_reached_strategy: str = settings.get(
"context_limit_reached_strategy", "truncate_by_turns"
)
self.llm_compress_instruction: str = settings.get(
"llm_compress_instruction", ""
)
self.llm_compress_keep_recent: int = settings.get("llm_compress_keep_recent", 4)
self.llm_compress_provider_id: str = settings.get(
"llm_compress_provider_id", ""
)
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
self.max_context_length - 1,
)
if self.dequeue_context_length <= 0:
self.dequeue_context_length = 1
self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent):
@@ -167,34 +182,6 @@ class InternalAgentSubStage(Stage):
},
)
def _truncate_contexts(
self,
contexts: list[dict],
) -> list[dict]:
"""截断上下文列表,确保不超过最大长度"""
if self.max_context_length == -1:
return contexts
if len(contexts) // 2 <= self.max_context_length:
return contexts
truncated_contexts = contexts[
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
]
# 找到第一个role 为 user 的索引,确保上下文格式正确
index = next(
(
i
for i, item in enumerate(truncated_contexts)
if item.get("role") == "user"
),
None,
)
if index is not None and index > 0:
truncated_contexts = truncated_contexts[index:]
return truncated_contexts
def _modalities_fix(
self,
provider: Provider,
@@ -296,6 +283,7 @@ class InternalAgentSubStage(Stage):
req: ProviderRequest,
llm_response: LLMResponse | None,
all_messages: list[Message],
runner_stats: AgentStats | None,
):
if (
not req
@@ -322,27 +310,37 @@ class InternalAgentSubStage(Stage):
continue
message_to_save.append(message.model_dump())
# get token usage from agent runner stats
token_usage = None
if runner_stats:
token_usage = runner_stats.token_usage.total
await self.conv_manager.update_conversation(
event.unified_msg_origin,
req.conversation.cid,
history=message_to_save,
token_usage=token_usage,
)
def _fix_messages(self, messages: list[dict]) -> list[dict]:
"""验证并且修复上下文"""
fixed_messages = []
for message in messages:
if message.get("role") == "tool":
# tool block 前面必须要有 user 和 assistant block
if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
# 我们直接将之前的上下文都清空
fixed_messages = []
else:
fixed_messages.append(message)
else:
fixed_messages.append(message)
return fixed_messages
def _get_compress_provider(self) -> Provider | None:
if not self.llm_compress_provider_id:
return None
if self.context_limit_reached_strategy != "llm_compress":
return None
provider = self.ctx.plugin_manager.context.get_provider_by_id(
self.llm_compress_provider_id,
)
if provider is None:
logger.warning(
f"未找到指定的上下文压缩模型 {self.llm_compress_provider_id},将跳过压缩。",
)
return None
if not isinstance(provider, Provider):
logger.warning(
f"指定的上下文压缩模型 {self.llm_compress_provider_id} 不是对话模型,将跳过压缩。"
)
return None
return provider
async def process(
self, event: AstrMessageEvent, provider_wake_prefix: str
@@ -364,6 +362,10 @@ class InternalAgentSubStage(Stage):
streaming_response = bool(enable_streaming)
logger.debug("ready to request llm provider")
# 通知等待调用 LLM(在获取锁之前)
await call_event_hook(event, EventType.OnWaitingLLMRequestEvent)
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
logger.debug("acquired session lock for llm request")
if event.get_extra("provider_request"):
@@ -422,9 +424,10 @@ class InternalAgentSubStage(Stage):
await self._apply_kb(event, req)
# truncate contexts to fit max length
if req.contexts:
req.contexts = self._truncate_contexts(req.contexts)
self._fix_messages(req.contexts)
# NOW moved to ContextManager inside ToolLoopAgentRunner
# if req.contexts:
# req.contexts = self._truncate_contexts(req.contexts)
# self._fix_messages(req.contexts)
# session_id
if not req.session_id:
@@ -440,8 +443,6 @@ class InternalAgentSubStage(Stage):
self.unsupported_streaming_strategy == "turn_off"
and not event.platform_meta.support_streaming_message
)
# 备份 req.contexts
backup_contexts = copy.deepcopy(req.contexts)
# run agent
agent_runner = AgentRunner()
@@ -452,6 +453,15 @@ class InternalAgentSubStage(Stage):
context=self.ctx.plugin_manager.context,
event=event,
)
# inject model context length limit
if provider.provider_config.get("max_context_tokens", 0) <= 0:
model = provider.get_model()
if model_info := LLM_METADATAS.get(model):
provider.provider_config["max_context_tokens"] = model_info[
"limit"
]["context"]
await agent_runner.reset(
provider=provider,
request=req,
@@ -462,6 +472,11 @@ class InternalAgentSubStage(Stage):
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=streaming_response,
llm_compress_instruction=self.llm_compress_instruction,
llm_compress_keep_recent=self.llm_compress_keep_recent,
llm_compress_provider=self._get_compress_provider(),
truncate_turns=self.dequeue_context_length,
enforce_max_turns=self.max_context_length,
)
if streaming_response and not stream_to_general:
@@ -507,14 +522,12 @@ class InternalAgentSubStage(Stage):
):
yield
# 恢复备份的 contexts
req.contexts = backup_contexts
await self._save_to_history(
event,
req,
agent_runner.get_final_llm_resp(),
agent_runner.run_context.messages,
agent_runner.stats,
)
# 异步处理 WebChat 特殊情况
@@ -260,7 +260,7 @@ class ResultDecorateStage(Stage):
should_tts = (
bool(self.ctx.astrbot_config["provider_tts_settings"]["enable"])
and result.is_llm_result()
and SessionServiceManager.should_process_tts_request(event)
and await SessionServiceManager.should_process_tts_request(event)
and random.random() <= self.tts_trigger_probability
and tts_provider
)
@@ -21,7 +21,7 @@ class SessionStatusCheckStage(Stage):
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
# 检查会话是否整体启用
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
if not await SessionServiceManager.is_session_enabled(event.unified_msg_origin):
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
# workaround for #2309
+1 -1
View File
@@ -227,7 +227,7 @@ class WakingCheckStage(Stage):
event._extras.pop("parsed_params", None)
# 根据会话配置过滤插件处理器
activated_handlers = SessionPluginManager.filter_handlers_by_session(
activated_handlers = await SessionPluginManager.filter_handlers_by_session(
event,
activated_handlers,
)
@@ -191,7 +191,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
if self.active_send_mode:
await self.convert_message(msg, None)
else:
if msg.id in self.wexin_event_workers:
if str(msg.id) in self.wexin_event_workers:
future = self.wexin_event_workers[str(cast(str | int, msg.id))]
logger.debug(f"duplicate message id checked: {msg.id}")
else:
+27 -12
View File
@@ -119,19 +119,34 @@ class ProviderManager:
TTSProvider,
):
self.curr_tts_provider_inst = prov
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
await sp.put_async(
key="curr_provider_tts",
value=provider_id,
scope="global",
scope_id="global",
)
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
prov,
STTProvider,
):
self.curr_stt_provider_inst = prov
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
await sp.put_async(
key="curr_provider_stt",
value=provider_id,
scope="global",
scope_id="global",
)
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
prov,
Provider,
):
self.curr_provider_inst = prov
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
await sp.put_async(
key="curr_provider",
value=provider_id,
scope="global",
scope_id="global",
)
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
"""根据提供商 ID 获取提供商实例"""
@@ -206,21 +221,21 @@ class ProviderManager:
logger.error(traceback.format_exc())
logger.error(e)
selected_provider_id = sp.get(
"curr_provider",
self.provider_settings.get("default_provider_id"),
selected_provider_id = await sp.get_async(
key="curr_provider",
default=self.provider_settings.get("default_provider_id"),
scope="global",
scope_id="global",
)
selected_stt_provider_id = sp.get(
"curr_provider_stt",
self.provider_stt_settings.get("provider_id"),
selected_stt_provider_id = await sp.get_async(
key="curr_provider_stt",
default=self.provider_stt_settings.get("provider_id"),
scope="global",
scope_id="global",
)
selected_tts_provider_id = sp.get(
"curr_provider_tts",
self.provider_tts_settings.get("provider_id"),
selected_tts_provider_id = await sp.get_async(
key="curr_provider_tts",
default=self.provider_tts_settings.get("provider_id"),
scope="global",
scope_id="global",
)
@@ -378,7 +378,8 @@ class ProviderOpenAIOfficial(Provider):
new_content.append(part)
message["content"] = new_content
# reasoning key is "reasoning_content"
message["reasoning_content"] = reasoning_content
if reasoning_content:
message["reasoning_content"] = reasoning_content
async def _handle_api_error(
self,
+14 -1
View File
@@ -149,9 +149,12 @@ class Context:
contexts: context messages for the LLM
max_steps: Maximum number of tool calls before stopping the loop
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
stream: bool - whether to stream the LLM response
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
agent_context: AstrAgentContext - context to use for the agent
other kwargs will be DIRECTLY passed to the runner.reset() method
Returns:
The final LLMResponse after tool calls are completed.
@@ -194,6 +197,15 @@ class Context:
)
agent_runner = ToolLoopAgentRunner()
tool_executor = FunctionToolExecutor()
streaming = kwargs.get("stream", False)
other_kwargs = {
k: v
for k, v in kwargs.items()
if k not in ["stream", "agent_hooks", "agent_context"]
}
await agent_runner.reset(
provider=prov,
request=request,
@@ -203,7 +215,8 @@ class Context:
),
tool_executor=tool_executor,
agent_hooks=agent_hooks,
streaming=kwargs.get("stream", False),
streaming=streaming,
**other_kwargs,
)
async for _ in agent_runner.step_until_done(max_steps):
pass
+2
View File
@@ -12,6 +12,7 @@ from .star_handler import (
register_on_llm_request,
register_on_llm_response,
register_on_platform_loaded,
register_on_waiting_llm_request,
register_permission_type,
register_platform_adapter_type,
register_regex,
@@ -30,6 +31,7 @@ __all__ = [
"register_on_llm_request",
"register_on_llm_response",
"register_on_platform_loaded",
"register_on_waiting_llm_request",
"register_permission_type",
"register_platform_adapter_type",
"register_regex",
@@ -339,6 +339,30 @@ def register_on_platform_loaded(**kwargs):
return decorator
def register_on_waiting_llm_request(**kwargs):
"""当等待调用 LLM 时的通知事件(在获取锁之前)
此钩子在消息确定要调用 LLM 但还未开始排队等锁时触发
适合用于发送"正在思考中..."等用户反馈提示
Examples:
```py
@on_waiting_llm_request()
async def on_waiting_llm(self, event: AstrMessageEvent) -> None:
await event.send("🤔 正在思考中...")
```
"""
def decorator(awaitable):
_ = get_handler_or_create(
awaitable, EventType.OnWaitingLLMRequestEvent, **kwargs
)
return awaitable
return decorator
def register_on_llm_request(**kwargs):
"""当有 LLM 请求时的事件
+38 -26
View File
@@ -12,7 +12,7 @@ class SessionServiceManager:
# =============================================================================
@staticmethod
def is_llm_enabled_for_session(session_id: str) -> bool:
async def is_llm_enabled_for_session(session_id: str) -> bool:
"""检查LLM是否在指定会话中启用
Args:
@@ -23,11 +23,11 @@ class SessionServiceManager:
"""
# 获取会话服务配置
session_services = sp.get(
"session_service_config",
{},
session_services = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
default={},
)
# 如果配置了该会话的LLM状态,返回该状态
@@ -39,7 +39,7 @@ class SessionServiceManager:
return True
@staticmethod
def set_llm_status_for_session(session_id: str, enabled: bool) -> None:
async def set_llm_status_for_session(session_id: str, enabled: bool) -> None:
"""设置LLM在指定会话中的启停状态
Args:
@@ -48,18 +48,24 @@ class SessionServiceManager:
"""
session_config = (
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
default={},
)
or {}
)
session_config["llm_enabled"] = enabled
sp.put(
"session_service_config",
session_config,
await sp.put_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
value=session_config,
)
@staticmethod
def should_process_llm_request(event: AstrMessageEvent) -> bool:
async def should_process_llm_request(event: AstrMessageEvent) -> bool:
"""检查是否应该处理LLM请求
Args:
@@ -70,14 +76,14 @@ class SessionServiceManager:
"""
session_id = event.unified_msg_origin
return SessionServiceManager.is_llm_enabled_for_session(session_id)
return await SessionServiceManager.is_llm_enabled_for_session(session_id)
# =============================================================================
# TTS 相关方法
# =============================================================================
@staticmethod
def is_tts_enabled_for_session(session_id: str) -> bool:
async def is_tts_enabled_for_session(session_id: str) -> bool:
"""检查TTS是否在指定会话中启用
Args:
@@ -88,11 +94,11 @@ class SessionServiceManager:
"""
# 获取会话服务配置
session_services = sp.get(
"session_service_config",
{},
session_services = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
default={},
)
# 如果配置了该会话的TTS状态,返回该状态
@@ -104,7 +110,7 @@ class SessionServiceManager:
return True
@staticmethod
def set_tts_status_for_session(session_id: str, enabled: bool) -> None:
async def set_tts_status_for_session(session_id: str, enabled: bool) -> None:
"""设置TTS在指定会话中的启停状态
Args:
@@ -113,14 +119,20 @@ class SessionServiceManager:
"""
session_config = (
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
default={},
)
or {}
)
session_config["tts_enabled"] = enabled
sp.put(
"session_service_config",
session_config,
await sp.put_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
value=session_config,
)
logger.info(
@@ -128,7 +140,7 @@ class SessionServiceManager:
)
@staticmethod
def should_process_tts_request(event: AstrMessageEvent) -> bool:
async def should_process_tts_request(event: AstrMessageEvent) -> bool:
"""检查是否应该处理TTS请求
Args:
@@ -139,14 +151,14 @@ class SessionServiceManager:
"""
session_id = event.unified_msg_origin
return SessionServiceManager.is_tts_enabled_for_session(session_id)
return await SessionServiceManager.is_tts_enabled_for_session(session_id)
# =============================================================================
# 会话整体启停相关方法
# =============================================================================
@staticmethod
def is_session_enabled(session_id: str) -> bool:
async def is_session_enabled(session_id: str) -> bool:
"""检查会话是否整体启用
Args:
@@ -157,11 +169,11 @@ class SessionServiceManager:
"""
# 获取会话服务配置
session_services = sp.get(
"session_service_config",
{},
session_services = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
default={},
)
# 如果配置了该会话的整体状态,返回该状态
+23 -11
View File
@@ -8,7 +8,10 @@ class SessionPluginManager:
"""管理会话级别的插件启停状态"""
@staticmethod
def is_plugin_enabled_for_session(session_id: str, plugin_name: str) -> bool:
async def is_plugin_enabled_for_session(
session_id: str,
plugin_name: str,
) -> bool:
"""检查插件是否在指定会话中启用
Args:
@@ -20,11 +23,11 @@ class SessionPluginManager:
"""
# 获取会话插件配置
session_plugin_config = sp.get(
"session_plugin_config",
{},
session_plugin_config = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_plugin_config",
default={},
)
session_config = session_plugin_config.get(session_id, {})
@@ -43,7 +46,10 @@ class SessionPluginManager:
return True
@staticmethod
def filter_handlers_by_session(event: AstrMessageEvent, handlers: list) -> list:
async def filter_handlers_by_session(
event: AstrMessageEvent,
handlers: list,
) -> list:
"""根据会话配置过滤处理器列表
Args:
@@ -59,6 +65,15 @@ class SessionPluginManager:
session_id = event.unified_msg_origin
filtered_handlers = []
session_plugin_config = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_plugin_config",
default={},
)
session_config = session_plugin_config.get(session_id, {})
disabled_plugins = session_config.get("disabled_plugins", [])
for handler in handlers:
# 获取处理器对应的插件
plugin = star_map.get(handler.handler_module_path)
@@ -76,14 +91,11 @@ class SessionPluginManager:
continue
# 检查插件是否在当前会话中启用
if SessionPluginManager.is_plugin_enabled_for_session(
session_id,
plugin.name,
):
filtered_handlers.append(handler)
else:
if plugin.name in disabled_plugins:
logger.debug(
f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}",
)
else:
filtered_handlers.append(handler)
return filtered_handlers
+1
View File
@@ -184,6 +184,7 @@ class EventType(enum.Enum):
OnPlatformLoadedEvent = enum.auto() # 平台加载完成
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
OnWaitingLLMRequestEvent = enum.auto() # 等待调用 LLM(在获取锁之前,仅通知)
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
OnLLMResponseEvent = enum.auto() # LLM 响应后
OnDecoratingResultEvent = enum.auto() # 发送消息前
+9 -6
View File
@@ -1,3 +1,5 @@
import fnmatch
from astrbot.core.utils.shared_preferences import SharedPreferences
@@ -9,14 +11,15 @@ class UmopConfigRouter:
"""UMOP 到配置文件 ID 的映射"""
self.sp = sp
self._load_routing_table()
async def initialize(self):
await self._load_routing_table()
def _load_routing_table(self):
async def _load_routing_table(self):
"""加载路由表"""
# 从 SharedPreferences 中加载 umop_to_conf_id 映射
sp_data = self.sp.get(
"umop_config_routing",
{},
sp_data = await self.sp.get_async(
key="umop_config_routing",
default={},
scope="global",
scope_id="global",
)
@@ -30,7 +33,7 @@ class UmopConfigRouter:
if len(p1_ls) != 3 or len(p2_ls) != 3:
return False # 非法格式
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
return all(p == "" or fnmatch.fnmatchcase(t, p) for p, t in zip(p1_ls, p2_ls))
def get_conf_id_for_umop(self, umo: str) -> str | None:
"""根据 UMO 获取对应的配置文件 ID
+8
View File
@@ -3,6 +3,7 @@ import traceback
from astrbot.core import astrbot_config, logger
from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
from astrbot.core.db.migration.migra_token_usage import migrate_token_usage
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
@@ -139,6 +140,13 @@ async def migra(
logger.error(f"Migration for webchat session failed: {e!s}")
logger.error(traceback.format_exc())
# migration for token_usage column
try:
await migrate_token_usage(db)
except Exception as e:
logger.error(f"Migration for token_usage column failed: {e!s}")
logger.error(traceback.format_exc())
# migra third party agent runner configs
_c = False
providers = astrbot_config["provider"]
+5
View File
@@ -0,0 +1,5 @@
## What's Changed
hotfix of v4.10.4
fix: 部分配置项的输入框不显示,如飞书机器人配置的部分配置项。(#4268
+11
View File
@@ -0,0 +1,11 @@
## What's Changed
hotfix of v4.10.4
fix:
1. ‼️ 部分情况下使用 OpenAI 接口报错与 reasoning_content 有关的问题;
feat:
1. WebUI 已安装插件页支持记忆视图类型(列表/卡片),列表视图显示插件的人类友好名称和 logo。
+19
View File
@@ -0,0 +1,19 @@
## What's Changed
### 新增
- 支持上下文自动压缩功能。入口:配置文件 -> 上下文管理策略 -> 超出模型上下文窗口时的处理方式。详情请查看: [自动上下文压缩](https://docs.astrbot.app/use/context-compress.html) ([#4322](https://github.com/AstrBotDevs/AstrBot/issues/4322))
- 新增 `on_waiting_llm_request` 事件钩子 ([#4319](https://github.com/AstrBotDevs/AstrBot/issues/4319))
- WebUI 支持强制更新插件 ([#4293](https://github.com/AstrBotDevs/AstrBot/issues/4293))
- 社区已提供适用于 [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter) 平台的适配器插件
### 修复
- 修复微信公众号中由于 msg.id 数据类型不匹配导致的重试失败问题 ([#4292](https://github.com/AstrBotDevs/AstrBot/issues/4292))
- 修复调用 TTS 命令时出现的数据库锁定错误 ([#4313](https://github.com/AstrBotDevs/AstrBot/issues/4313))
- 修复 Anthropic 提供商中 token 用量始终为 0 的问题 ([#4328](https://github.com/AstrBotDevs/AstrBot/issues/4328))
### 优化
- 完善共享组件的国际化支持 ([#4327](https://github.com/AstrBotDevs/AstrBot/issues/4327))
- 优化下载大型备份文件时的稳定性,减少失败情况 ([#4329](https://github.com/AstrBotDevs/AstrBot/issues/4329))
@@ -203,9 +203,8 @@ function hasVisibleItemsAfter(items, currentIndex) {
<v-col cols="12" sm="6" class="config-input">
<ConfigItemRenderer
v-if="metadata[metadataKey].items[key]"
v-model="iterable[key]"
:item-meta="metadata[metadataKey].items[key]"
:item-meta="metadata[metadataKey].items[key] || null"
:loading="loadingEmbeddingDim"
:show-fullscreen-btn="!!metadata[metadataKey].items[key]?.editor_mode"
@get-embedding-dim="getEmbeddingDimensions(iterable)"
@@ -219,7 +219,7 @@ function getSpecialSubtype(value) {
<ConfigItemRenderer
v-else
v-model="createSelectorModel(itemKey).value"
:item-meta="itemMeta"
:item-meta="itemMeta || null"
:show-fullscreen-btn="!!itemMeta?.editor_mode"
@open-fullscreen="openEditorDialog(itemKey, iterable, itemMeta?.editor_theme, itemMeta?.editor_language)"
/>
@@ -144,7 +144,7 @@
color="primary"
density="compact"
hide-details
class="flex-grow-1"
style="flex: 1"
></v-slider>
<v-text-field
:model-value="modelValue"
@@ -154,7 +154,7 @@
class="config-field"
type="number"
hide-details
style="max-width: 140px;"
style="flex: 1"
></v-text-field>
</div>
@@ -223,7 +223,7 @@ const props = defineProps({
},
itemMeta: {
type: Object,
required: true
default: null
},
loading: {
type: Boolean,
@@ -325,4 +325,8 @@ function getSpecialSubtype(value) {
.gap-20 {
gap: 20px;
}
:deep(.v-field__input) {
font-size: 14px;
}
</style>
@@ -145,9 +145,11 @@ const viewReadme = () => {
}})</v-list-item-title>
</v-list-item>
<v-list-item @click="updateExtension" :disabled="!extension?.has_update">
<v-list-item @click="updateExtension">
<v-list-item-title>
{{ tm('card.actions.updateTo') }} {{ extension.online_version || extension.version }}
{{ extension.has_update
? tm('card.actions.updateTo') + ' ' + extension.online_version
: tm('card.actions.reinstall') }}
</v-list-item-title>
</v-list-item>
</template>
@@ -510,7 +510,7 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
const metadata = getModelMetadata(modelName)
let modalities: string[]
if (!metadata) {
modalities = ['text', 'image', 'tool_use']
} else {
@@ -523,13 +523,19 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
}
}
let max_context_tokens = 0
if (metadata?.limit?.context && typeof metadata.limit.context === 'number') {
max_context_tokens = metadata.limit.context
}
const newProvider = {
id: newId,
enable: false,
provider_source_id: sourceId,
model: modelName,
modalities,
custom_extra_body: {}
custom_extra_body: {},
max_context_tokens: max_context_tokens
}
try {
@@ -11,7 +11,12 @@
},
"agent_runner_type": {
"description": "Runner",
"labels": ["Built-in Agent", "Dify", "Coze", "Alibaba Cloud Bailian Application"]
"labels": [
"Built-in Agent",
"Dify",
"Coze",
"Alibaba Cloud Bailian Application"
]
},
"coze_agent_runner_provider_id": {
"description": "Coze Agent Runner Provider ID"
@@ -128,6 +133,39 @@
}
}
},
"truncate_and_compress": {
"description": "Context Management Strategy",
"provider_settings": {
"max_context_length": {
"description": "Maximum Conversation Turns",
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
},
"dequeue_context_length": {
"description": "Dequeue Conversation Turns",
"hint": "Number of conversation turns to discard at once when maximum context length is exceeded"
},
"context_limit_reached_strategy": {
"description": "Handling When Model Context Window is Exceeded",
"labels": [
"Truncate by Turns",
"Compress by LLM"
],
"hint": "When 'Truncate by Turns' is selected, the oldest N conversation turns will be discarded based on the 'Dequeue Conversation Turns' setting above. When 'Compress by LLM' is selected, the specified model will be used for context compression."
},
"llm_compress_instruction": {
"description": "Context Compression Instruction",
"hint": "If empty, the default prompt will be used."
},
"llm_compress_keep_recent": {
"description": "Keep Recent Turns When Compressing",
"hint": "Always keep the most recent N turns of conversation when compressing context."
},
"llm_compress_provider_id": {
"description": "Model Provider ID for Context Compression",
"hint": "When left empty, will fall back to the 'Truncate by Turns' strategy."
}
}
},
"others": {
"description": "Other Settings",
"provider_settings": {
@@ -161,15 +199,10 @@
"unsupported_streaming_strategy": {
"description": "Platforms Without Streaming Support",
"hint": "Select the handling method for platforms that don't support streaming responses. Real-time segmented reply sends content immediately when the system detects segment points like punctuation during streaming reception",
"labels": ["Real-time Segmented Reply", "Disable Streaming Response"]
},
"max_context_length": {
"description": "Maximum Conversation Rounds",
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
},
"dequeue_context_length": {
"description": "Dequeue Conversation Rounds",
"hint": "Number of conversation rounds to discard at once when maximum context length is exceeded"
"labels": [
"Real-time Segmented Reply",
"Disable Streaming Response"
]
},
"wake_prefix": {
"description": "Additional LLM Chat Wake Prefix",
@@ -387,7 +420,10 @@
},
"split_mode": {
"description": "Split Mode",
"labels": ["Regex", "Words List"]
"labels": [
"Regex",
"Words List"
]
},
"regex": {
"description": "Segmentation Regular Expression"
@@ -488,4 +524,4 @@
}
}
}
}
}
@@ -145,6 +145,11 @@
"message": "This plugin has been flagged as containing security risks, including unsafe code or functionalities that may cause system malfunctions or data loss. Do you wish to proceed with the installation?",
"confirm": "Continue",
"cancel": "Cancel"
},
"forceUpdate": {
"title": "No New Version Detected",
"message": "No new version detected for this plugin. Do you want to force reinstall? This will pull the latest code from the remote repository.",
"confirm": "Force Update"
}
},
"messages": {
@@ -185,7 +190,8 @@
"reloadPlugin": "Reload Extension",
"togglePlugin": "Extension",
"viewHandlers": "View Handlers",
"updateTo": "Update to"
"updateTo": "Update to",
"reinstall": "Reinstall"
},
"status": {
"hasUpdate": "New version available",
@@ -207,4 +213,4 @@
"goToManage": "Go to Manage",
"later": "Later"
}
}
}
@@ -133,6 +133,36 @@
}
}
},
"truncate_and_compress": {
"description": "上下文管理策略",
"provider_settings": {
"max_context_length": {
"description": "最多携带对话轮数",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
},
"dequeue_context_length": {
"description": "丢弃对话轮数",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
},
"context_limit_reached_strategy": {
"description": "超出模型上下文窗口时的处理方式",
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
"hint": "当按对话轮数截断时,会根据上面\"丢弃对话轮数\"的配置丢弃最旧的 N 轮对话。当由 LLM 压缩上下文时,会使用指定的模型进行上下文压缩。"
},
"llm_compress_instruction": {
"description": "上下文压缩提示词",
"hint": "如果为空则使用默认提示词。"
},
"llm_compress_keep_recent": {
"description": "压缩时保留最近对话轮数",
"hint": "始终保留的最近 N 轮对话。"
},
"llm_compress_provider_id": {
"description": "用于上下文压缩的模型提供商 ID",
"hint": "留空时将降级为\"按对话轮数截断\"的策略。"
}
}
},
"others": {
"description": "其他配置",
"provider_settings": {
@@ -171,14 +201,7 @@
"关闭流式回复"
]
},
"max_context_length": {
"description": "最多携带对话轮数",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
},
"dequeue_context_length": {
"description": "丢弃对话轮数",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
},
"wake_prefix": {
"description": "LLM 聊天额外唤醒前缀",
"hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求"
@@ -145,6 +145,11 @@
"message": "该插件可能包含不安全的代码或功能,可能导致系统异常或数据损失等。请确认是否继续安装?",
"confirm": "继续",
"cancel": "取消"
},
"forceUpdate": {
"title": "未检测到新版本",
"message": "当前插件未检测到新版本,是否强制重新安装?这将从远程仓库拉取最新代码。",
"confirm": "强制更新"
}
},
"messages": {
@@ -185,7 +190,8 @@
"reloadPlugin": "重载插件",
"togglePlugin": "插件",
"viewHandlers": "查看行为",
"updateTo": "更新到"
"updateTo": "更新到",
"reinstall": "重新安装"
},
"status": {
"hasUpdate": "有新版本可用",
+75 -32
View File
@@ -77,14 +77,26 @@ const readmeDialog = reactive({
repoUrl: null
});
//
const forceUpdateDialog = reactive({
show: false,
extensionName: ''
});
//
const isListView = ref(false);
// localStorage false
const getInitialListViewMode = () => {
if (typeof window !== 'undefined' && window.localStorage) {
return localStorage.getItem('pluginListViewMode') === 'true';
}
return false;
};
const isListView = ref(getInitialListViewMode());
const pluginSearch = ref("");
const loading_ = ref(false);
//
const currentPage = ref(1);
const itemsPerPage = ref(6); // 6 (2 x 3)
//
const dangerConfirmDialog = ref(false);
@@ -113,7 +125,6 @@ const uploadTab = ref('file');
const showPluginFullName = ref(false);
const marketSearch = ref("");
const debouncedMarketSearch = ref("");
const filterKeys = ['name', 'desc', 'author'];
const refreshingMarket = ref(false);
const sortBy = ref('default'); // default, stars, author, updated
const sortOrder = ref('desc'); // desc () or asc ()
@@ -162,18 +173,6 @@ const pluginHeaders = computed(() => [
]);
//
const pluginMarketHeaders = computed(() => [
{ title: tm('table.headers.name'), key: 'name', maxWidth: '200px' },
{ title: tm('table.headers.description'), key: 'desc', maxWidth: '250px' },
{ title: tm('table.headers.author'), key: 'author', maxWidth: '90px' },
{ title: tm('table.headers.stars'), key: 'stars', maxWidth: '80px' },
{ title: tm('table.headers.lastUpdate'), key: 'updated_at', maxWidth: '100px' },
{ title: tm('table.headers.tags'), key: 'tags', maxWidth: '100px' },
{ title: tm('table.headers.actions'), key: 'actions', sortable: false }
]);
//
const filteredExtensions = computed(() => {
const data = Array.isArray(extension_data?.data) ? extension_data.data : [];
@@ -197,9 +196,6 @@ const filteredPlugins = computed(() => {
});
});
const pinnedPlugins = computed(() => {
return pluginMarketData.value.filter(plugin => plugin?.pinned);
});
//
const filteredMarketPlugins = computed(() => {
@@ -385,7 +381,17 @@ const handleUninstallConfirm = (options) => {
}
};
const updateExtension = async (extension_name) => {
const updateExtension = async (extension_name, forceUpdate = false) => {
//
const ext = extension_data.data?.find(e => e.name === extension_name);
//
if (!ext?.has_update && !forceUpdate) {
forceUpdateDialog.extensionName = extension_name;
forceUpdateDialog.show = true;
return;
}
loadingDialog.title = tm('status.loading');
loadingDialog.show = true;
try {
@@ -417,6 +423,14 @@ const updateExtension = async (extension_name) => {
}
};
//
const confirmForceUpdate = () => {
const name = forceUpdateDialog.extensionName;
forceUpdateDialog.show = false;
forceUpdateDialog.extensionName = '';
updateExtension(name, true);
};
const updateAllExtensions = async () => {
if (updatingAll.value || updatableExtensions.value.length === 0) return;
updatingAll.value = true;
@@ -552,14 +566,6 @@ const viewReadme = (plugin) => {
readmeDialog.show = true;
};
const open = (link) => {
if (link) {
window.open(link, '_blank');
}
};
//
const handleInstallPlugin = async (plugin) => {
if (plugin.tags && plugin.tags.includes('danger')) {
@@ -918,6 +924,13 @@ watch(marketSearch, (newVal) => {
}, 300); // 300ms
});
// localStorage
watch(isListView, (newVal) => {
if (typeof window !== 'undefined' && window.localStorage) {
localStorage.setItem('pluginListViewMode', String(newVal));
}
});
</script>
@@ -1037,8 +1050,21 @@ watch(marketSearch, (newVal) => {
<template v-slot:item.name="{ item }">
<div class="d-flex align-center py-2">
<div v-if="item.logo" class="mr-3" style="flex-shrink: 0;">
<img :src="item.logo" :alt="item.name"
style="height: 40px; width: 40px; border-radius: 8px; object-fit: cover;" />
</div>
<div v-else class="mr-3" style="flex-shrink: 0;">
<img :src="defaultPluginIcon" :alt="item.name"
style="height: 40px; width: 40px; border-radius: 8px; object-fit: cover;" />
</div>
<div>
<div class="text-subtitle-1 font-weight-medium">{{ item.name }}</div>
<div class="text-subtitle-1 font-weight-medium">
{{ item.display_name && item.display_name.length ? item.display_name : item.name }}
</div>
<div v-if="item.display_name && item.display_name.length" class="text-caption text-medium-emphasis mt-1">
{{ item.name }}
</div>
<div v-if="item.reserved" class="d-flex align-center mt-1">
<v-chip color="primary" size="x-small" class="font-weight-medium">{{ tm('status.system')
}}</v-chip>
@@ -1048,7 +1074,7 @@ watch(marketSearch, (newVal) => {
</template>
<template v-slot:item.desc="{ item }">
<div class="text-body-2 text-medium-emphasis">{{ item.desc }}</div>
<div class="text-body-2 text-medium-emphasis mt-2 mb-2" style="display: -webkit-box; -webkit-line-clamp: 3; line-clamp: 3; -webkit-box-orient: vertical; overflow: hidden; text-overflow: ellipsis;">{{ item.desc }}</div>
</template>
<template v-slot:item.version="{ item }">
@@ -1084,7 +1110,7 @@ watch(marketSearch, (newVal) => {
<v-tooltip activator="parent" location="top">{{ tm('tooltips.disable') }}</v-tooltip>
</v-btn>
<v-btn icon size="small" color="info" @click="reloadPlugin(item.name)">
<v-btn icon size="small" @click="reloadPlugin(item.name)">
<v-icon>mdi-refresh</v-icon>
<v-tooltip activator="parent" location="top">{{ tm('tooltips.reload') }}</v-tooltip>
</v-btn>
@@ -1104,8 +1130,7 @@ watch(marketSearch, (newVal) => {
<v-tooltip activator="parent" location="top">{{ tm('tooltips.viewDocs') }}</v-tooltip>
</v-btn>
<v-btn icon size="small" color="warning" @click="updateExtension(item.name)"
:v-show="item.has_update">
<v-btn icon size="small" @click="updateExtension(item.name)">
<v-icon>mdi-update</v-icon>
<v-tooltip activator="parent" location="top">{{ tm('tooltips.update') }}</v-tooltip>
</v-btn>
@@ -1772,6 +1797,24 @@ watch(marketSearch, (newVal) => {
</v-card-actions>
</v-card>
</v-dialog>
<!-- 强制更新确认对话框 -->
<v-dialog v-model="forceUpdateDialog.show" max-width="420">
<v-card class="rounded-lg">
<v-card-title class="text-h6 d-flex align-center">
<v-icon color="info" class="mr-2">mdi-information-outline</v-icon>
{{ tm('dialogs.forceUpdate.title') }}
</v-card-title>
<v-card-text>
{{ tm('dialogs.forceUpdate.message') }}
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn variant="text" @click="forceUpdateDialog.show = false">{{ tm('buttons.cancel') }}</v-btn>
<v-btn color="primary" variant="flat" @click="confirmForceUpdate">{{ tm('dialogs.forceUpdate.confirm') }}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</template>
<style scoped>
+1 -1
View File
@@ -1,6 +1,6 @@
[project]
name = "AstrBot"
version = "4.10.4"
version = "4.11.0"
description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md"
requires-python = ">=3.10"
+774
View File
@@ -0,0 +1,774 @@
"""Comprehensive tests for ContextManager."""
import sys
from pathlib import Path
from typing import Literal
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Add parent directory to path to avoid circular import issues
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from astrbot.core.agent.context.config import ContextConfig
from astrbot.core.agent.context.manager import ContextManager
from astrbot.core.agent.message import Message, TextPart
from astrbot.core.provider.entities import LLMResponse
class MockProvider:
"""模拟 Provider"""
def __init__(self):
self.provider_config = {
"id": "test_provider",
"model": "gpt-4",
"modalities": ["text", "image", "tool_use"],
}
async def text_chat(self, **kwargs):
"""模拟 LLM 调用,返回摘要"""
messages = kwargs.get("messages", [])
# 简单的摘要逻辑:返回消息数量统计
return LLMResponse(
role="assistant",
completion_text=f"历史对话包含 {len(messages) - 1} 条消息,主要讨论了技术话题。",
)
def get_model(self):
return "gpt-4"
def meta(self):
return MagicMock(id="test_provider", type="openai")
class TestContextManager:
"""Test suite for ContextManager."""
def create_message(
self, role: Literal["system", "user", "assistant", "tool"], content: str
) -> Message:
"""Helper to create a simple text message."""
return Message(role=role, content=content)
def create_messages(self, count: int) -> list[Message]:
"""Helper to create alternating user/assistant messages."""
messages = []
for i in range(count):
role = "user" if i % 2 == 0 else "assistant"
messages.append(self.create_message(role, f"Message {i}"))
return messages
# ==================== Basic Initialization Tests ====================
def test_init_with_minimal_config(self):
"""Test initialization with minimal configuration."""
config = ContextConfig()
manager = ContextManager(config)
assert manager.config == config
assert manager.token_counter is not None
assert manager.truncator is not None
assert manager.compressor is not None
def test_init_with_llm_compressor(self):
"""Test initialization with LLM-based compression."""
mock_provider = MockProvider()
config = ContextConfig(
llm_compress_provider=mock_provider, # type: ignore
llm_compress_keep_recent=5,
llm_compress_instruction="Summarize the conversation",
)
manager = ContextManager(config)
from astrbot.core.agent.context.compressor import LLMSummaryCompressor
assert isinstance(manager.compressor, LLMSummaryCompressor)
def test_init_with_truncate_compressor(self):
"""Test initialization with truncate-based compression (default)."""
config = ContextConfig(truncate_turns=3)
manager = ContextManager(config)
from astrbot.core.agent.context.compressor import TruncateByTurnsCompressor
assert isinstance(manager.compressor, TruncateByTurnsCompressor)
# ==================== Empty and Edge Cases ====================
@pytest.mark.asyncio
async def test_process_empty_messages(self):
"""Test processing an empty message list."""
config = ContextConfig()
manager = ContextManager(config)
result = await manager.process([])
assert result == []
@pytest.mark.asyncio
async def test_process_single_message(self):
"""Test processing a single message."""
config = ContextConfig()
manager = ContextManager(config)
messages = [self.create_message("user", "Hello")]
result = await manager.process(messages)
assert len(result) == 1
assert result[0].content == "Hello"
@pytest.mark.asyncio
async def test_process_with_no_limits(self):
"""Test processing when no limits are set (no truncation or compression)."""
config = ContextConfig(max_context_tokens=0, enforce_max_turns=-1)
manager = ContextManager(config)
messages = self.create_messages(20)
result = await manager.process(messages)
assert len(result) == 20
assert result == messages
# ==================== Enforce Max Turns Tests ====================
@pytest.mark.asyncio
async def test_enforce_max_turns_basic(self):
"""Test basic enforce_max_turns functionality."""
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
manager = ContextManager(config)
# Create 10 turns (20 messages)
messages = self.create_messages(20)
result = await manager.process(messages)
# Should keep only 3 most recent turns (6 messages)
assert len(result) <= 8 # May vary due to truncation logic
@pytest.mark.asyncio
async def test_enforce_max_turns_zero(self):
"""Test enforce_max_turns with value 0 (should keep nothing)."""
config = ContextConfig(enforce_max_turns=0, truncate_turns=1)
manager = ContextManager(config)
messages = self.create_messages(10)
result = await manager.process(messages)
# Should result in empty or minimal message list
assert len(result) <= 2
@pytest.mark.asyncio
async def test_enforce_max_turns_negative(self):
"""Test enforce_max_turns with -1 (no limit)."""
config = ContextConfig(enforce_max_turns=-1)
manager = ContextManager(config)
messages = self.create_messages(20)
result = await manager.process(messages)
assert len(result) == 20
@pytest.mark.asyncio
async def test_enforce_max_turns_with_system_messages(self):
"""Test enforce_max_turns preserves system messages."""
config = ContextConfig(enforce_max_turns=2, truncate_turns=1)
manager = ContextManager(config)
messages = [
self.create_message("system", "System instruction"),
*self.create_messages(10),
]
result = await manager.process(messages)
# System message should be preserved
system_msgs = [m for m in result if m.role == "system"]
assert len(system_msgs) >= 1
assert system_msgs[0].content == "System instruction"
# ==================== Token-based Compression Tests ====================
@pytest.mark.asyncio
async def test_token_compression_not_triggered_below_threshold(self):
"""Test that compression is not triggered below threshold."""
config = ContextConfig(max_context_tokens=1000)
manager = ContextManager(config)
# Create messages that total less than threshold
messages = [self.create_message("user", "Hi" * 50)] # ~100 tokens
with patch.object(
manager.compressor, "should_compress", return_value=False
) as mock_should_compress:
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager.process(messages)
# should_compress should be called
mock_should_compress.assert_called_once()
# Compressor should not be called
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_token_compression_triggered_above_threshold(self):
"""Test that compression is triggered above threshold."""
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
manager = ContextManager(config)
# Create messages that exceed threshold (0.82 * 100 = 82 tokens)
# 300 chars * 0.3 = 90 tokens > 82 threshold
long_text = "x" * 300 # ~90 tokens, above threshold
messages = [self.create_message("user", long_text)]
# Mock compressor to return smaller result
compressed = [self.create_message("user", "short")]
# Create a mock compressor
mock_compressor = AsyncMock()
mock_compressor.compression_threshold = 0.82
mock_compressor.return_value = compressed
# Mock should_compress to return True first time, False after
call_count = 0
def mock_should_compress(*args, **kwargs):
nonlocal call_count
call_count += 1
return call_count == 1
mock_compressor.should_compress = mock_should_compress
manager.compressor = mock_compressor
result = await manager.process(messages)
# Compressor should be called
mock_compressor.assert_called_once()
# Result should be the compressed version
assert len(result) <= len(messages)
@pytest.mark.asyncio
async def test_token_compression_with_zero_max_tokens(self):
"""Test that compression is skipped when max_context_tokens is 0."""
config = ContextConfig(max_context_tokens=0)
manager = ContextManager(config)
messages = [self.create_message("user", "x" * 10000)]
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager.process(messages)
# Compressor should not be called when max_context_tokens is 0
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_token_compression_with_negative_max_tokens(self):
"""Test that compression is skipped when max_context_tokens is negative."""
config = ContextConfig(max_context_tokens=-100)
manager = ContextManager(config)
messages = [self.create_message("user", "x" * 10000)]
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager.process(messages)
# Compressor should not be called
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_double_check_after_compression(self):
"""Test that halving is applied if still over threshold after compression."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Create messages that would still be over threshold after compression
long_messages = [self.create_message("user", "x" * 200) for _ in range(10)]
# Mock compressor to return messages still over threshold
async def mock_compress(msgs):
return msgs # Return same messages (still over limit)
# Mock should_compress to return True twice (before and after compression)
with patch.object(manager.compressor, "should_compress", return_value=True):
with patch.object(manager.compressor, "__call__", new=mock_compress):
with patch.object(
manager.truncator,
"truncate_by_halving",
return_value=long_messages[:5],
) as mock_halving:
_ = await manager.process(long_messages)
# Halving should be called
mock_halving.assert_called_once()
# ==================== Combined Truncation and Compression Tests ====================
@pytest.mark.asyncio
async def test_combined_enforce_turns_and_token_limit(self):
"""Test combining enforce_max_turns and token limit."""
config = ContextConfig(
enforce_max_turns=5, max_context_tokens=500, truncate_turns=1
)
manager = ContextManager(config)
# Create many messages
messages = self.create_messages(30)
result = await manager.process(messages)
# Should be truncated by both mechanisms
assert len(result) < 30
@pytest.mark.asyncio
async def test_sequential_processing_order(self):
"""Test that enforce_max_turns happens before token compression."""
config = ContextConfig(enforce_max_turns=5, max_context_tokens=1000)
manager = ContextManager(config)
messages = self.create_messages(20)
# Mock the truncator to track calls
with patch.object(
manager.truncator,
"truncate_by_turns",
wraps=manager.truncator.truncate_by_turns,
) as mock_truncate:
await manager.process(messages)
# Truncator should be called first
mock_truncate.assert_called_once()
# ==================== Error Handling Tests ====================
@pytest.mark.asyncio
async def test_error_handling_returns_original_messages(self):
"""Test that errors during processing return original messages."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
messages = self.create_messages(5)
# Make compressor raise an exception
with patch.object(
manager.compressor, "__call__", side_effect=Exception("Test error")
):
result = await manager.process(messages)
# Should return original messages despite error
assert result == messages
@pytest.mark.asyncio
async def test_error_handling_logs_exception(self):
"""Test that errors are logged."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Create messages that will trigger compression (> 82 tokens)
messages = [self.create_message("user", "x" * 300)] # ~90 tokens
# Replace compressor with one that raises an exception
mock_compressor = AsyncMock(side_effect=Exception("Test error"))
mock_compressor.compression_threshold = 0.82
mock_compressor.should_compress = MagicMock(return_value=True)
manager.compressor = mock_compressor
with patch("astrbot.core.agent.context.manager.logger") as mock_logger:
result = await manager.process(messages)
# Logger error method should be called
assert mock_logger.error.called
# Should return original messages on error
assert result == messages
# ==================== Multi-modal Content Tests ====================
@pytest.mark.asyncio
async def test_process_messages_with_textpart_content(self):
"""Test processing messages with TextPart content."""
config = ContextConfig()
manager = ContextManager(config)
messages = [
Message(role="user", content=[TextPart(text="Hello")]),
Message(role="assistant", content=[TextPart(text="Hi there")]),
]
result = await manager.process(messages)
assert len(result) == 2
assert result == messages
@pytest.mark.asyncio
async def test_token_counting_with_multimodal_content(self):
"""Test token counting works with multi-modal content."""
config = ContextConfig(max_context_tokens=50)
manager = ContextManager(config)
# Need enough tokens to exceed threshold: 50 * 0.82 = 41 tokens
# 150 chars * 0.3 = 45 tokens > 41
messages = [
Message(role="user", content=[TextPart(text="x" * 150)]),
]
# Should trigger compression due to token count
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 50)
assert tokens > 0 # Tokens should be counted
assert needs_compression # Should trigger compression
# ==================== Tool Calls Tests ====================
@pytest.mark.asyncio
async def test_process_messages_with_tool_calls(self):
"""Test processing messages with tool calls."""
config = ContextConfig()
manager = ContextManager(config)
messages = [
Message(
role="assistant",
content="Let me search for that",
tool_calls=[
{
"id": "call_1",
"type": "function",
"function": {"name": "search", "arguments": "{}"},
}
],
),
Message(role="tool", content="Search result", tool_call_id="call_1"),
]
result = await manager.process(messages)
assert len(result) == 2
# ==================== Compressor should_compress Tests ====================
@pytest.mark.asyncio
async def test_should_compress_empty_messages(self):
"""Test should_compress with empty messages."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Compressor's should_compress should handle empty gracefully
needs_compression = manager.compressor.should_compress([], 0, 100)
assert not needs_compression
@pytest.mark.asyncio
async def test_should_compress_below_threshold(self):
"""Test should_compress when below compression threshold."""
config = ContextConfig(max_context_tokens=1000)
manager = ContextManager(config)
messages = [self.create_message("user", "Hello")]
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 1000)
assert not needs_compression
@pytest.mark.asyncio
async def test_should_compress_above_threshold(self):
"""Test should_compress when above compression threshold."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Create message with many tokens
messages = [self.create_message("user", "这是测试" * 50)]
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
# Should need compression if tokens > 82 (0.82 * 100)
assert needs_compression == (tokens > 82)
# ==================== Truncator Halving Tests ====================
def test_truncate_by_halving_basic(self):
"""Test truncate_by_halving removes middle 50%."""
config = ContextConfig()
manager = ContextManager(config)
messages = self.create_messages(10)
result = manager.truncator.truncate_by_halving(messages)
# Should keep roughly half
assert len(result) < len(messages)
def test_truncate_by_halving_empty_list(self):
"""Test truncate_by_halving with empty list."""
config = ContextConfig()
manager = ContextManager(config)
result = manager.truncator.truncate_by_halving([])
assert result == []
def test_truncate_by_halving_single_message(self):
"""Test truncate_by_halving with single message."""
config = ContextConfig()
manager = ContextManager(config)
messages = [self.create_message("user", "Hello")]
result = manager.truncator.truncate_by_halving(messages)
assert len(result) <= 1
# ==================== Complex Scenarios ====================
@pytest.mark.asyncio
async def test_multiple_compression_cycles(self):
"""Test that compression can be triggered multiple times in sequence."""
config = ContextConfig(max_context_tokens=50, truncate_turns=1)
manager = ContextManager(config)
# Process messages multiple times
messages = self.create_messages(10)
result1 = await manager.process(messages)
result2 = await manager.process(result1)
result3 = await manager.process(result2)
# Each cycle should maintain or reduce message count
assert len(result3) <= len(result2) <= len(result1)
@pytest.mark.asyncio
async def test_alternating_roles_preserved(self):
"""Test that user/assistant alternation is preserved after processing."""
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
manager = ContextManager(config)
messages = self.create_messages(20)
result = await manager.process(messages)
# Check that roles still alternate (excluding system messages)
non_system = [m for m in result if m.role != "system"]
if len(non_system) >= 2:
# Should start with user
assert non_system[0].role == "user"
@pytest.mark.asyncio
async def test_compression_threshold_default(self):
"""Test that compression threshold is used correctly."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Verify the default threshold is 0.82
assert manager.compressor.compression_threshold == 0.82
# Test threshold logic
messages = [self.create_message("user", "x" * 81)] # ~24 tokens
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
# Should not compress if below threshold
assert needs_compression == (tokens > 82)
@pytest.mark.asyncio
async def test_large_batch_processing(self):
"""Test processing a large batch of messages."""
config = ContextConfig(
enforce_max_turns=10, max_context_tokens=1000, truncate_turns=2
)
manager = ContextManager(config)
# Create 100 messages (50 turns)
messages = self.create_messages(100)
result = await manager.process(messages)
# Should be significantly reduced
assert len(result) < 100
assert len(result) > 0
@pytest.mark.asyncio
async def test_config_persistence(self):
"""Test that config settings are respected throughout processing."""
config = ContextConfig(
max_context_tokens=500,
enforce_max_turns=5,
truncate_turns=2,
llm_compress_keep_recent=3,
)
manager = ContextManager(config)
# Verify config is stored
assert manager.config.max_context_tokens == 500
assert manager.config.enforce_max_turns == 5
assert manager.config.truncate_turns == 2
assert manager.config.llm_compress_keep_recent == 3
# ==================== Run Compression Tests ====================
@pytest.mark.asyncio
async def test_run_compression_calls_compressor(self):
"""Test _run_compression calls compressor."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
messages = self.create_messages(5)
compressed = self.create_messages(3)
# Create a mock compressor
mock_compressor = AsyncMock()
mock_compressor.compression_threshold = 0.82
mock_compressor.return_value = compressed
mock_compressor.should_compress = MagicMock(return_value=False)
manager.compressor = mock_compressor
result = await manager._run_compression(messages, prev_tokens=100)
# Compressor __call__ should be invoked
mock_compressor.assert_called_once_with(messages)
assert result == compressed
@pytest.mark.asyncio
async def test_run_compression_applies_compressor_through_process(self):
"""Test _run_compression calls compressor when needed through process()."""
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
manager = ContextManager(config)
# Create messages that will trigger compression
messages = [self.create_message("user", "x" * 300)] # ~90 tokens > 82 threshold
compressed = [self.create_message("user", "short")] # Much smaller
# Create a mock compressor
mock_compressor = AsyncMock()
mock_compressor.compression_threshold = 0.82
mock_compressor.return_value = compressed
# Mock should_compress to return True first time, False after
call_count = 0
def mock_should_compress(*args, **kwargs):
nonlocal call_count
call_count += 1
return call_count == 1
mock_compressor.should_compress = mock_should_compress
manager.compressor = mock_compressor
result = await manager.process(messages)
# Compressor should have been called
mock_compressor.assert_called_once()
assert len(result) <= len(messages)
@pytest.mark.asyncio
async def test_llm_compression_with_mock_provider(self):
"""Test LLM compression using MockProvider."""
mock_provider = MockProvider()
config = ContextConfig(
llm_compress_provider=mock_provider, # type: ignore
llm_compress_keep_recent=3,
llm_compress_instruction="请总结对话内容",
max_context_tokens=100,
)
manager = ContextManager(config)
# Create messages that will trigger compression
messages = [
self.create_message("user", "x" * 100),
self.create_message("assistant", "y" * 100),
self.create_message("user", "z" * 100),
]
result = await manager.process(messages)
# Should have been compressed
assert len(result) <= len(messages)
# ==================== split_history Tests ====================
def test_split_history_ensures_user_start(self):
"""Test split_history ensures recent_messages starts with user message."""
from astrbot.core.agent.context.compressor import split_history
# Create alternating messages: user, assistant, user, assistant, user, assistant
messages = [
self.create_message("system", "System prompt"),
self.create_message("user", "msg1"),
self.create_message("assistant", "msg2"),
self.create_message("user", "msg3"),
self.create_message("assistant", "msg4"),
self.create_message("user", "msg5"),
self.create_message("assistant", "msg6"),
]
# Keep recent 3 messages - should adjust to start with user
system, to_summarize, recent = split_history(messages, keep_recent=3)
# recent_messages should start with user message
assert len(recent) > 0
assert recent[0].role == "user"
# messages_to_summarize should end with assistant (complete turn)
if len(to_summarize) > 0:
assert to_summarize[-1].role == "assistant"
def test_split_history_handles_assistant_at_split_point(self):
"""Test split_history when assistant message is at the intended split point."""
from astrbot.core.agent.context.compressor import split_history
messages = [
self.create_message("user", "msg1"),
self.create_message("assistant", "msg2"),
self.create_message("user", "msg3"),
self.create_message("assistant", "msg4"), # <- intended split here
self.create_message("user", "msg5"),
self.create_message("assistant", "msg6"),
]
# keep_recent=2 would normally split at index 4 (assistant msg4)
# Should move back to include from msg5 (user)
system, to_summarize, recent = split_history(messages, keep_recent=2)
# recent should start with user message
assert recent[0].role == "user"
assert recent[0].content == "msg5"
def test_split_history_all_assistant_messages(self):
"""Test split_history when there are consecutive assistant messages."""
from astrbot.core.agent.context.compressor import split_history
messages = [
self.create_message("user", "msg1"),
self.create_message("assistant", "msg2"),
self.create_message("assistant", "msg3"),
self.create_message("assistant", "msg4"),
]
system, to_summarize, recent = split_history(messages, keep_recent=2)
# Should find the user message and keep from there
if len(recent) > 0:
# Find first user message backwards
assert any(m.role == "user" for m in messages)
def test_split_history_with_system_messages(self):
"""Test split_history preserves system messages separately."""
from astrbot.core.agent.context.compressor import split_history
messages = [
self.create_message("system", "System 1"),
self.create_message("system", "System 2"),
self.create_message("user", "msg1"),
self.create_message("assistant", "msg2"),
self.create_message("user", "msg3"),
]
system, to_summarize, recent = split_history(messages, keep_recent=2)
# System messages should be separate
assert len(system) == 2
assert all(m.role == "system" for m in system)
# Recent should start with user
if len(recent) > 0:
assert recent[0].role == "user"
+423
View File
@@ -0,0 +1,423 @@
"""Tests for ContextTruncator."""
from astrbot.core.agent.context.truncator import ContextTruncator
from astrbot.core.agent.message import Message
class TestContextTruncator:
"""Test suite for ContextTruncator."""
def create_message(self, role: str, content: str = "test content") -> Message:
"""Helper to create a simple test message."""
return Message(role=role, content=content)
def create_messages(
self, count: int, include_system: bool = False
) -> list[Message]:
"""Helper to create alternating user/assistant messages.
Args:
count: Number of messages to create
include_system: Whether to include a system message at the start
Returns:
List of messages
"""
messages = []
if include_system:
messages.append(self.create_message("system", "System prompt"))
for i in range(count):
role = "user" if i % 2 == 0 else "assistant"
messages.append(self.create_message(role, f"Message {i}"))
return messages
# ==================== fix_messages Tests ====================
def test_fix_messages_empty_list(self):
"""Test fix_messages with an empty list."""
truncator = ContextTruncator()
result = truncator.fix_messages([])
assert result == []
def test_fix_messages_normal_messages(self):
"""Test fix_messages with normal user/assistant messages."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Hello"),
self.create_message("assistant", "Hi"),
self.create_message("user", "How are you?"),
]
result = truncator.fix_messages(messages)
assert len(result) == 3
assert result == messages
def test_fix_messages_tool_with_valid_context(self):
"""Test fix_messages with tool message after user+assistant."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 3
assert result == messages
def test_fix_messages_tool_without_context(self):
"""Test fix_messages with tool message without enough context."""
truncator = ContextTruncator()
messages = [
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
# Tool message without context should be removed
assert len(result) == 0
def test_fix_messages_tool_with_only_one_message(self):
"""Test fix_messages with tool message after only one message."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Hello"),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
# Tool message without enough context should be removed
assert len(result) == 0
def test_fix_messages_multiple_tools(self):
"""Test fix_messages with multiple tool messages."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool 1 result"),
self.create_message("tool", "Tool 2 result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 4
assert result == messages
def test_fix_messages_mixed_system_tool(self):
"""Test fix_messages with system message and tool messages."""
truncator = ContextTruncator()
messages = [
self.create_message("system", "System prompt"),
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 4
assert result == messages
# ==================== truncate_by_turns Tests ====================
def test_truncate_by_turns_no_limit(self):
"""Test truncate_by_turns with -1 (no limit)."""
truncator = ContextTruncator()
messages = self.create_messages(20)
result = truncator.truncate_by_turns(messages, keep_most_recent_turns=-1)
assert len(result) == 20
assert result == messages
def test_truncate_by_turns_basic(self):
"""Test basic truncate_by_turns functionality."""
truncator = ContextTruncator()
# Create 10 messages = 5 turns (user/assistant pairs)
messages = self.create_messages(10)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=3, drop_turns=1
)
# Should keep 3 most recent turns (6 messages)
assert len(result) <= 8 # (3-1+1)*2 = 6, but may adjust for correct format
def test_truncate_by_turns_with_system_message(self):
"""Test truncate_by_turns preserves system messages."""
truncator = ContextTruncator()
messages = self.create_messages(10, include_system=True)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=2, drop_turns=1
)
# System message should always be preserved
assert result[0].role == "system"
assert result[0].content == "System prompt"
def test_truncate_by_turns_zero_keep(self):
"""Test truncate_by_turns with keep_most_recent_turns=0."""
truncator = ContextTruncator()
messages = self.create_messages(10)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=0, drop_turns=1
)
# Should result in empty or minimal list
assert len(result) == 0
def test_truncate_by_turns_below_threshold(self):
"""Test truncate_by_turns when messages are below threshold."""
truncator = ContextTruncator()
# Create 4 messages = 2 turns
messages = self.create_messages(4)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=5, drop_turns=1
)
# No truncation should happen
assert len(result) == 4
assert result == messages
def test_truncate_by_turns_exact_threshold(self):
"""Test truncate_by_turns when messages exactly match threshold."""
truncator = ContextTruncator()
# Create 6 messages = 3 turns
messages = self.create_messages(6)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=3, drop_turns=1
)
# No truncation should happen
assert len(result) == 6
assert result == messages
def test_truncate_by_turns_ensures_user_first(self):
"""Test that truncate_by_turns ensures user message comes first."""
truncator = ContextTruncator()
# Create scenario where truncation might start with assistant
messages = self.create_messages(20)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=3, drop_turns=1
)
# First non-system message should be user
assert result[0].role == "user"
def test_truncate_by_turns_multiple_drop(self):
"""Test truncate_by_turns with multiple turns dropped at once."""
truncator = ContextTruncator()
messages = self.create_messages(20)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=5, drop_turns=3
)
# Should drop 3 turns when limit exceeded
assert len(result) < len(messages)
# ==================== truncate_by_dropping_oldest_turns Tests ====================
def test_truncate_by_dropping_oldest_turns_zero(self):
"""Test truncate_by_dropping_oldest_turns with drop_turns=0."""
truncator = ContextTruncator()
messages = self.create_messages(10)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=0)
assert result == messages
def test_truncate_by_dropping_oldest_turns_negative(self):
"""Test truncate_by_dropping_oldest_turns with negative drop_turns."""
truncator = ContextTruncator()
messages = self.create_messages(10)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=-1)
assert result == messages
def test_truncate_by_dropping_oldest_turns_basic(self):
"""Test basic truncate_by_dropping_oldest_turns functionality."""
truncator = ContextTruncator()
# Create 10 messages = 5 turns
messages = self.create_messages(10)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
# Should drop 2 oldest turns (4 messages)
assert len(result) == 6
# Should start with user message
assert result[0].role == "user"
def test_truncate_by_dropping_oldest_turns_with_system(self):
"""Test truncate_by_dropping_oldest_turns preserves system messages."""
truncator = ContextTruncator()
messages = self.create_messages(10, include_system=True)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
# System message should be preserved
assert result[0].role == "system"
assert result[0].content == "System prompt"
def test_truncate_by_dropping_oldest_turns_drop_all(self):
"""Test truncate_by_dropping_oldest_turns dropping all turns."""
truncator = ContextTruncator()
# Create 4 messages = 2 turns
messages = self.create_messages(4)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
# Should drop all turns
assert len(result) == 0
def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self):
"""Test truncate_by_dropping_oldest_turns with drop_turns > available turns."""
truncator = ContextTruncator()
# Create 4 messages = 2 turns
messages = self.create_messages(4)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=5)
# Should result in empty list
assert len(result) == 0
def test_truncate_by_dropping_oldest_turns_ensures_user_first(self):
"""Test that result starts with user message after dropping."""
truncator = ContextTruncator()
messages = self.create_messages(20)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=3)
# First message should be user
if len(result) > 0:
assert result[0].role == "user"
# ==================== truncate_by_halving Tests ====================
def test_truncate_by_halving_empty(self):
"""Test truncate_by_halving with empty list."""
truncator = ContextTruncator()
result = truncator.truncate_by_halving([])
assert result == []
def test_truncate_by_halving_single_message(self):
"""Test truncate_by_halving with single message."""
truncator = ContextTruncator()
messages = [self.create_message("user", "Hello")]
result = truncator.truncate_by_halving(messages)
# Should not truncate if <= 2 messages
assert result == messages
def test_truncate_by_halving_two_messages(self):
"""Test truncate_by_halving with two messages."""
truncator = ContextTruncator()
messages = self.create_messages(2)
result = truncator.truncate_by_halving(messages)
# Should not truncate if <= 2 messages
assert result == messages
def test_truncate_by_halving_basic(self):
"""Test basic truncate_by_halving functionality."""
truncator = ContextTruncator()
# Create 20 messages
messages = self.create_messages(20)
result = truncator.truncate_by_halving(messages)
# Should delete 50% = 10 messages, keep 10
assert len(result) == 10
# First message should be user
assert result[0].role == "user"
def test_truncate_by_halving_with_system_message(self):
"""Test truncate_by_halving preserves system messages."""
truncator = ContextTruncator()
messages = self.create_messages(20, include_system=True)
result = truncator.truncate_by_halving(messages)
# System message should be preserved
assert result[0].role == "system"
assert result[0].content == "System prompt"
def test_truncate_by_halving_odd_count(self):
"""Test truncate_by_halving with odd number of messages."""
truncator = ContextTruncator()
messages = self.create_messages(11)
result = truncator.truncate_by_halving(messages)
# Should delete floor(11/2) = 5 messages, keep 6
# But after ensuring user first, may be 5
assert len(result) >= 5
assert result[0].role == "user"
def test_truncate_by_halving_ensures_user_first(self):
"""Test that result starts with user message."""
truncator = ContextTruncator()
# Create messages starting with user
messages = self.create_messages(30)
result = truncator.truncate_by_halving(messages)
# First message should be user
assert result[0].role == "user"
def test_truncate_by_halving_preserves_recent_messages(self):
"""Test that truncate_by_halving keeps the most recent 50%."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Message 0"),
self.create_message("assistant", "Message 1"),
self.create_message("user", "Message 2"),
self.create_message("assistant", "Message 3"),
]
result = truncator.truncate_by_halving(messages)
# Should keep last 2 messages
assert len(result) == 2
assert result[0].content == "Message 2"
assert result[1].content == "Message 3"
# ==================== Integration Tests ====================
def test_truncate_with_tool_messages(self):
"""Test truncation with tool messages."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool result"),
self.create_message("user", "Thanks"),
self.create_message("assistant", "Welcome"),
]
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=1)
# First turn (user+assistant+tool) should be dropped
# Tool message should be cleaned up by fix_messages
assert len(result) <= 2
def test_chain_multiple_truncations(self):
"""Test chaining multiple truncation methods."""
truncator = ContextTruncator()
messages = self.create_messages(40, include_system=True)
# First: truncate by turns
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=10, drop_turns=2
)
# Then: halve
result = truncator.truncate_by_halving(result)
# Should have system message + truncated content
assert result[0].role == "system"
assert len(result) < len(messages)
def test_empty_after_system_message(self):
"""Test truncation when only system message exists."""
truncator = ContextTruncator()
messages = [self.create_message("system", "System prompt")]
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=5, drop_turns=1
)
# Should keep system message
assert len(result) == 1
assert result[0].role == "system"
def test_all_system_messages(self):
"""Test truncation with only system messages."""
truncator = ContextTruncator()
messages = [
self.create_message("system", "System 1"),
self.create_message("system", "System 2"),
]
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=0, drop_turns=1
)
# System messages should be preserved, but since there are no non-system
# messages and keep_most_recent_turns=0, result should be system messages only
assert len(result) >= 0 # May keep system messages or clear all
if len(result) > 0:
assert all(msg.role == "system" for msg in result)