Compare commits

...

23 Commits

Author SHA1 Message Date
Soulter 8e7b44185d chore: bump version to 4.11.0 2026-01-05 18:05:12 +08:00
RC-CHN ef1c66a92e feat(webui): enable Range request support for backup downloads (#4329) 2026-01-05 17:27:03 +08:00
Soulter 241f1c26d3 feat: context compress (#4322)
* feat: context compressor

Co-authored-by: kawayiYokami <289104862@qq.com>

* Add comprehensive tests for ContextManager and ContextTruncator

- Implemented a full test suite for ContextManager covering initialization, message processing, token-based compression, and error handling.
- Added tests for ContextTruncator focusing on message fixing, truncation by turns, dropping oldest turns, and halving.
- Ensured that both test suites validate edge cases and maintain expected behavior with various message types, including system and tool messages.

* feat: add MockProvider for LLM compression tests

* chore: remove lock

* ruff fix

* fix

* perf

* feat: enhance context compression with token tracking and logging

* feat: update logging for context compression trigger

* feat: implement context compression logic with dynamic threshold and token tracking

* fix: reorder import statements for consistency

* feat: add token_usage tracking to conversations and update related processing logic

---------

Co-authored-by: kawayiYokami <289104862@qq.com>
2026-01-05 17:26:10 +08:00
Soulter 3615b7dde2 fix: token usage is always 0 in anthropic source (#4328) 2026-01-05 17:06:12 +08:00
RC-CHN 9bcf9bf2a0 fix(dashboard): complete i18n support for shared components (#4327)
* fix(dashboard): complete i18n support for shared components

- Replace hardcoded Chinese strings with i18n translations in:
  - PluginSetSelector.vue
  - ProviderSelector.vue
  - PersonaSelector.vue
  - KnowledgeBaseSelector.vue
  - T2ITemplateEditor.vue
  - AstrBotConfigV4.vue
  - ConfigItemRenderer.vue
  - ProxySelector.vue
  - ListConfigItem.vue

- Add missing translations to locale files:
  - core/shared.json: personaSelector, t2iTemplateEditor
  - core/common.json: autoDetect
  - features/settings.json: network.proxySelector

- Change prop defaults from hardcoded Chinese to empty strings,
  allowing components to use i18n fallback translations

* fix(i18n): 修正插件选择器标签的翻译格式,添加冒号

* fix(deployment): 添加持久化 machine-id PVC 和初始化容器,优化资源限制
2026-01-05 09:45:28 +08:00
Gao Jinzhe 7f5cc7cf1a feat: add on_waiting_llm_request event hook (#4319)
* 加入on_waiting_llm_request钩子

* ruff check
2026-01-04 16:11:12 +08:00
Oscar Shaw f26867c77d ci(stale): 增加 stale action 每次运行的操作限制 (#4256) 2026-01-04 11:20:03 +08:00
Soulter a14d588b44 docs: add Matrix adapter to community maintained section in multiple languages 2026-01-04 10:15:16 +08:00
Soulter e236402d92 chore: update platform adapter name for clarity 2026-01-04 10:12:25 +08:00
Soulter 454841de10 fix: database is locked error when invoking tts command (#4313)
* fix: database is locked error when invoking /tts command

fixes: #4311

* chore: rm pnpm lockfile

* perf: 减少操作数据库的次数
2026-01-03 19:12:39 +08:00
clown145 442b5403df feat(webui): supports force update plugins (#4293) 2026-01-03 15:30:50 +08:00
Soulter 9db7bf59b8 docs: add new community group contact 2026-01-03 00:48:55 +08:00
雪語 3622504021 fix: retry failed due to a mismatch in the msg.id data type of a WeChat Official Account (#4292)
问题描述:
- 控制台显示正常发送消息,但公众号未收到
- 处理时间 > 5秒的消息几乎总是失败(如 AI 图片生成)
- 短消息(<5秒)正常工作

根本原因:
msg.id 是整数类型,但字典 key 使用字符串类型,导致类型不匹配。
检查时整数无法匹配字符串 key,导致每次都创建新的 future,
微信重试时无法重用,最终导致响应失败。

修复内容:
将 msg.id 转换为字符串后再检查字典
  if str(msg.id) in self.wexin_event_workers:

影响范围:
- 修复了微信重试时无法正确重用 future 的问题
- AI 图片生成、长文本生成等耗时操作现在可以正常工作
- 仅影响微信公众号适配器,其他平台不受影响

Fixes #1679
2026-01-02 22:16:04 +08:00
Soulter fc42db40ce chore: bump version to 4.10.6 2026-01-02 12:14:59 +08:00
Soulter e413a002c1 perf: list view mode toggle with localStorage support in ExtensionPage (#4288)
closes: #4253
2026-01-02 11:59:41 +08:00
tjc66666666 6437d759a3 fix: reasoning content inject for openai api (#4284) 2026-01-02 01:09:28 +08:00
Soulter c758b2d888 feat: use shell globbing to match umop config router (#4270)
* feat: use shell globbing to match umop config router

* rf

* fix: use fnmatchcase for case-sensitive matching in UmopConfigRouter
2025-12-31 23:10:12 +08:00
Soulter 510290fe0e chore: bump version to 4.10.5 2025-12-31 17:58:28 +08:00
Soulter c61d62edb6 fix: handle null item-meta in ConfigItemRenderer (#4269)
fixes: #4268
2025-12-31 17:55:49 +08:00
Soulter 45bce6fe76 chore: bump version to 4.10.4 2025-12-31 12:50:37 +08:00
Soulter f156adddf8 feat: enhance configuration editor with template schema support and UI improvements (#4267)
- Added support for template schemas in the configuration editor, allowing users to define and manage additional parameters like temperature, top_p, and max_tokens.
- Improved UI components in ProviderModelsPanel and ObjectEditor for better user interaction, including new configuration buttons and enhanced input handling.
- Updated localization files to include new configuration options.
2025-12-31 12:19:29 +08:00
Soulter b5a4b80c36 perf: Add list item add button (#4259)
fixes: #4254
2025-12-30 15:27:17 +08:00
Soulter 792fb69d6d perf: allow zero chunk overlap in recursive chunker (#4258)
* Allow zero chunk overlap

* Validate recursive chunking bounds
2025-12-30 15:23:05 +08:00
78 changed files with 3068 additions and 323 deletions
+1
View File
@@ -26,6 +26,7 @@ jobs:
- uses: actions/stale@v10
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
operations-per-run: 200
# 只处理带 bug 标签的 Issue
any-of-labels: 'bug'
+2
View File
@@ -132,6 +132,7 @@ uv run main.py
**社区维护**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
@@ -208,6 +209,7 @@ pre-commit install
- 5 群:822130018
- 6 群:753075035
- 7 群:743746109
- 8 群:1030353265
- 开发者群:975206796
### Telegram 群组
+1
View File
@@ -134,6 +134,7 @@ Or refer to the official documentation: [Deploy AstrBot from Source](https://ast
**Community Maintained**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili Direct Messages](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+1
View File
@@ -134,6 +134,7 @@ Ou consultez la documentation officielle : [Déployer AstrBot depuis les sources
**Maintenues par la communauté**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Messages directs Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+1
View File
@@ -134,6 +134,7 @@ uv run main.py
**コミュニティメンテナンス**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili ダイレクトメッセージ](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+1
View File
@@ -134,6 +134,7 @@ uv run main.py
**Поддерживаемые сообществом**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Личные сообщения Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+1
View File
@@ -134,6 +134,7 @@ uv run main.py
**社群維護**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili 私訊](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+4
View File
@@ -21,6 +21,9 @@ from astrbot.core.star.register import (
from astrbot.core.star.register import register_on_llm_request as on_llm_request
from astrbot.core.star.register import register_on_llm_response as on_llm_response
from astrbot.core.star.register import register_on_platform_loaded as on_platform_loaded
from astrbot.core.star.register import (
register_on_waiting_llm_request as on_waiting_llm_request,
)
from astrbot.core.star.register import register_permission_type as permission_type
from astrbot.core.star.register import (
register_platform_adapter_type as platform_adapter_type,
@@ -46,6 +49,7 @@ __all__ = [
"on_llm_request",
"on_llm_response",
"on_platform_loaded",
"on_waiting_llm_request",
"permission_type",
"platform_adapter_type",
"regex",
@@ -14,13 +14,13 @@ class TTSCommand:
async def tts(self, event: AstrMessageEvent):
"""开关文本转语音(会话级别)"""
umo = event.unified_msg_origin
ses_tts = SessionServiceManager.is_tts_enabled_for_session(umo)
ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo)
cfg = self.context.get_config(umo=umo)
tts_enable = cfg["provider_tts_settings"]["enable"]
# 切换状态
new_status = not ses_tts
SessionServiceManager.set_tts_status_for_session(umo, new_status)
await SessionServiceManager.set_tts_status_for_session(umo, new_status)
status_text = "已开启" if new_status else "已关闭"
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.10.3"
__version__ = "4.11.0"
+243
View File
@@ -0,0 +1,243 @@
from typing import TYPE_CHECKING, Protocol, runtime_checkable
from ..message import Message
if TYPE_CHECKING:
from astrbot import logger
else:
try:
from astrbot import logger
except ImportError:
import logging
logger = logging.getLogger("astrbot")
if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider
from ..context.truncator import ContextTruncator
@runtime_checkable
class ContextCompressor(Protocol):
"""
Protocol for context compressors.
Provides an interface for compressing message lists.
"""
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens for the model.
Returns:
True if compression is needed, False otherwise.
"""
...
async def __call__(self, messages: list[Message]) -> list[Message]:
"""Compress the message list.
Args:
messages: The original message list.
Returns:
The compressed message list.
"""
...
class TruncateByTurnsCompressor:
"""Truncate by turns compressor implementation.
Truncates the message list by removing older turns.
"""
def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
"""Initialize the truncate by turns compressor.
Args:
truncate_turns: The number of turns to remove when truncating (default: 1).
compression_threshold: The compression trigger threshold (default: 0.82).
"""
self.truncate_turns = truncate_turns
self.compression_threshold = compression_threshold
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens.
Returns:
True if compression is needed, False otherwise.
"""
if max_tokens <= 0 or current_tokens <= 0:
return False
usage_rate = current_tokens / max_tokens
return usage_rate > self.compression_threshold
async def __call__(self, messages: list[Message]) -> list[Message]:
truncator = ContextTruncator()
truncated_messages = truncator.truncate_by_dropping_oldest_turns(
messages,
drop_turns=self.truncate_turns,
)
return truncated_messages
def split_history(
messages: list[Message], keep_recent: int
) -> tuple[list[Message], list[Message], list[Message]]:
"""Split the message list into system messages, messages to summarize, and recent messages.
Ensures that the split point is between complete user-assistant pairs to maintain conversation flow.
Args:
messages: The original message list.
keep_recent: The number of latest messages to keep.
Returns:
tuple: (system_messages, messages_to_summarize, recent_messages)
"""
# keep the system messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) <= keep_recent:
return system_messages, [], non_system_messages
# Find the split point, ensuring recent_messages starts with a user message
# This maintains complete conversation turns
split_index = len(non_system_messages) - keep_recent
# Search backward from split_index to find the first user message
# This ensures recent_messages starts with a user message (complete turn)
while split_index > 0 and non_system_messages[split_index].role != "user":
# TODO: +=1 or -=1 ? calculate by tokens
split_index -= 1
# If we couldn't find a user message, keep all messages as recent
if split_index == 0:
return system_messages, [], non_system_messages
messages_to_summarize = non_system_messages[:split_index]
recent_messages = non_system_messages[split_index:]
return system_messages, messages_to_summarize, recent_messages
class LLMSummaryCompressor:
"""LLM-based summary compressor.
Uses LLM to summarize the old conversation history, keeping the latest messages.
"""
def __init__(
self,
provider: "Provider",
keep_recent: int = 4,
instruction_text: str | None = None,
compression_threshold: float = 0.82,
):
"""Initialize the LLM summary compressor.
Args:
provider: The LLM provider instance.
keep_recent: The number of latest messages to keep (default: 4).
instruction_text: Custom instruction for summary generation.
compression_threshold: The compression trigger threshold (default: 0.82).
"""
self.provider = provider
self.keep_recent = keep_recent
self.compression_threshold = compression_threshold
self.instruction_text = instruction_text or (
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
"4. Write the summary in the user's language.\n"
)
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens.
Returns:
True if compression is needed, False otherwise.
"""
if max_tokens <= 0 or current_tokens <= 0:
return False
usage_rate = current_tokens / max_tokens
return usage_rate > self.compression_threshold
async def __call__(self, messages: list[Message]) -> list[Message]:
"""Use LLM to generate a summary of the conversation history.
Process:
1. Divide messages: keep the system message and the latest N messages.
2. Send the old messages + the instruction message to the LLM.
3. Reconstruct the message list: [system message, summary message, latest messages].
"""
if len(messages) <= self.keep_recent + 1:
return messages
system_messages, messages_to_summarize, recent_messages = split_history(
messages, self.keep_recent
)
if not messages_to_summarize:
return messages
# build payload
instruction_message = Message(role="user", content=self.instruction_text)
llm_payload = messages_to_summarize + [instruction_message]
# generate summary
try:
response = await self.provider.text_chat(contexts=llm_payload)
summary_content = response.completion_text
except Exception as e:
logger.error(f"Failed to generate summary: {e}")
return messages
# build result
result = []
result.extend(system_messages)
result.append(
Message(
role="user",
content=f"Our previous history conversation summary: {summary_content}",
)
)
result.append(
Message(
role="assistant",
content="Acknowledged the summary of our previous conversation history.",
)
)
result.extend(recent_messages)
return result
+35
View File
@@ -0,0 +1,35 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING
from .compressor import ContextCompressor
from .token_counter import TokenCounter
if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider
@dataclass
class ContextConfig:
"""Context configuration class."""
max_context_tokens: int = 0
"""Maximum number of context tokens. <= 0 means no limit."""
enforce_max_turns: int = -1 # -1 means no limit
"""Maximum number of conversation turns to keep. -1 means no limit. Executed before compression."""
truncate_turns: int = 1
"""Number of conversation turns to discard at once when truncation is triggered.
Two processes will use this value:
1. Enforce max turns truncation.
2. Truncation by turns compression strategy.
"""
llm_compress_instruction: str | None = None
"""Instruction prompt for LLM-based compression."""
llm_compress_keep_recent: int = 0
"""Number of recent messages to keep during LLM-based compression."""
llm_compress_provider: "Provider | None" = None
"""LLM provider used for compression tasks. If None, truncation strategy is used."""
custom_token_counter: TokenCounter | None = None
"""Custom token counting method. If None, the default method is used."""
custom_compressor: ContextCompressor | None = None
"""Custom context compression method. If None, the default method is used."""
+120
View File
@@ -0,0 +1,120 @@
from astrbot import logger
from ..message import Message
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
from .config import ContextConfig
from .token_counter import EstimateTokenCounter
from .truncator import ContextTruncator
class ContextManager:
"""Context compression manager."""
def __init__(
self,
config: ContextConfig,
):
"""Initialize the context manager.
There are two strategies to handle context limit reached:
1. Truncate by turns: remove older messages by turns.
2. LLM-based compression: use LLM to summarize old messages.
Args:
config: The context configuration.
"""
self.config = config
self.token_counter = config.custom_token_counter or EstimateTokenCounter()
self.truncator = ContextTruncator()
if config.custom_compressor:
self.compressor = config.custom_compressor
elif config.llm_compress_provider:
self.compressor = LLMSummaryCompressor(
provider=config.llm_compress_provider,
keep_recent=config.llm_compress_keep_recent,
instruction_text=config.llm_compress_instruction,
)
else:
self.compressor = TruncateByTurnsCompressor(
truncate_turns=config.truncate_turns
)
async def process(
self, messages: list[Message], trusted_token_usage: int = 0
) -> list[Message]:
"""Process the messages.
Args:
messages: The original message list.
Returns:
The processed message list.
"""
try:
result = messages
# 1. 基于轮次的截断 (Enforce max turns)
if self.config.enforce_max_turns != -1:
result = self.truncator.truncate_by_turns(
result,
keep_most_recent_turns=self.config.enforce_max_turns,
drop_turns=self.config.truncate_turns,
)
# 2. 基于 token 的压缩
if self.config.max_context_tokens > 0:
total_tokens = self.token_counter.count_tokens(
result, trusted_token_usage
)
if self.compressor.should_compress(
result, total_tokens, self.config.max_context_tokens
):
result = await self._run_compression(result, total_tokens)
return result
except Exception as e:
logger.error(f"Error during context processing: {e}", exc_info=True)
return messages
async def _run_compression(
self, messages: list[Message], prev_tokens: int
) -> list[Message]:
"""
Compress/truncate the messages.
Args:
messages: The original message list.
prev_tokens: The token count before compression.
Returns:
The compressed/truncated message list.
"""
logger.debug("Compress triggered, starting compression...")
messages = await self.compressor(messages)
# double check
tokens_after_summary = self.token_counter.count_tokens(messages)
# calculate compress rate
compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100
logger.info(
f"Compress completed."
f" {prev_tokens} -> {tokens_after_summary} tokens,"
f" compression rate: {compress_rate:.2f}%.",
)
# last check
if self.compressor.should_compress(
messages, tokens_after_summary, self.config.max_context_tokens
):
logger.info(
"Context still exceeds max tokens after compression, applying halving truncation..."
)
# still need compress, truncate by half
messages = self.truncator.truncate_by_halving(messages)
return messages
@@ -0,0 +1,64 @@
import json
from typing import Protocol, runtime_checkable
from ..message import Message, TextPart
@runtime_checkable
class TokenCounter(Protocol):
"""
Protocol for token counters.
Provides an interface for counting tokens in message lists.
"""
def count_tokens(
self, messages: list[Message], trusted_token_usage: int = 0
) -> int:
"""Count the total tokens in the message list.
Args:
messages: The message list.
trusted_token_usage: The total token usage that LLM API returned.
For some cases, this value is more accurate.
But some API does not return it, so the value defaults to 0.
Returns:
The total token count.
"""
...
class EstimateTokenCounter:
"""Estimate token counter implementation.
Provides a simple estimation of token count based on character types.
"""
def count_tokens(
self, messages: list[Message], trusted_token_usage: int = 0
) -> int:
if trusted_token_usage > 0:
return trusted_token_usage
total = 0
for msg in messages:
content = msg.content
if isinstance(content, str):
total += self._estimate_tokens(content)
elif isinstance(content, list):
# 处理多模态内容
for part in content:
if isinstance(part, TextPart):
total += self._estimate_tokens(part.text)
# 处理 Tool Calls
if msg.tool_calls:
for tc in msg.tool_calls:
tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump())
total += self._estimate_tokens(tc_str)
return total
def _estimate_tokens(self, text: str) -> int:
chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"])
other_count = len(text) - chinese_count
return int(chinese_count * 0.6 + other_count * 0.3)
+141
View File
@@ -0,0 +1,141 @@
from ..message import Message
class ContextTruncator:
"""Context truncator."""
def fix_messages(self, messages: list[Message]) -> list[Message]:
fixed_messages = []
for message in messages:
if message.role == "tool":
# tool block 前面必须要有 user 和 assistant block
if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
# 我们直接将之前的上下文都清空
fixed_messages = []
else:
fixed_messages.append(message)
else:
fixed_messages.append(message)
return fixed_messages
def truncate_by_turns(
self,
messages: list[Message],
keep_most_recent_turns: int,
drop_turns: int = 1,
) -> list[Message]:
"""截断上下文列表,确保不超过最大长度。
一个 turn 包含一个 user 消息和一个 assistant 消息。
这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。
Args:
messages: 上下文列表
keep_most_recent_turns: 保留最近的对话轮数
drop_turns: 一次性丢弃的对话轮数
Returns:
截断后的上下文列表
"""
if keep_most_recent_turns == -1:
return messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) // 2 <= keep_most_recent_turns:
return messages
num_to_keep = keep_most_recent_turns - drop_turns + 1
if num_to_keep <= 0:
truncated_contexts = []
else:
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
# 找到第一个 role 为 user 的索引,确保上下文格式正确
index = next(
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
None,
)
if index is not None and index > 0:
truncated_contexts = truncated_contexts[index:]
result = system_messages + truncated_contexts
return self.fix_messages(result)
def truncate_by_dropping_oldest_turns(
self,
messages: list[Message],
drop_turns: int = 1,
) -> list[Message]:
"""丢弃最旧的 N 个对话轮次。"""
if drop_turns <= 0:
return messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) // 2 <= drop_turns:
truncated_non_system = []
else:
truncated_non_system = non_system_messages[drop_turns * 2 :]
index = next(
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
)
if index is not None:
truncated_non_system = truncated_non_system[index:]
elif truncated_non_system:
truncated_non_system = []
result = system_messages + truncated_non_system
return self.fix_messages(result)
def truncate_by_halving(
self,
messages: list[Message],
) -> list[Message]:
"""对半砍策略,删除 50% 的消息"""
if len(messages) <= 2:
return messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
messages_to_delete = len(non_system_messages) // 2
if messages_to_delete == 0:
return messages
truncated_non_system = non_system_messages[messages_to_delete:]
index = next(
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
)
if index is not None:
truncated_non_system = truncated_non_system[index:]
result = system_messages + truncated_non_system
return self.fix_messages(result)
@@ -25,6 +25,10 @@ from astrbot.core.provider.entities import (
)
from astrbot.core.provider.provider import Provider
from ..context.compressor import ContextCompressor
from ..context.config import ContextConfig
from ..context.manager import ContextManager
from ..context.token_counter import TokenCounter
from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..response import AgentResponseData, AgentStats
@@ -47,10 +51,47 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
streaming: bool = False,
# enforce max turns, will discard older turns when exceeded BEFORE compression
# -1 means no limit
enforce_max_turns: int = -1,
# llm compressor
llm_compress_instruction: str | None = None,
llm_compress_keep_recent: int = 0,
llm_compress_provider: Provider | None = None,
# truncate by turns compressor
truncate_turns: int = 1,
# customize
custom_token_counter: TokenCounter | None = None,
custom_compressor: ContextCompressor | None = None,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.streaming = streaming
self.enforce_max_turns = enforce_max_turns
self.llm_compress_instruction = llm_compress_instruction
self.llm_compress_keep_recent = llm_compress_keep_recent
self.llm_compress_provider = llm_compress_provider
self.truncate_turns = truncate_turns
self.custom_token_counter = custom_token_counter
self.custom_compressor = custom_compressor
# we will do compress when:
# 1. before requesting LLM
# TODO: 2. after LLM output a tool call
self.context_config = ContextConfig(
# <=0 will never do compress
max_context_tokens=provider.provider_config.get("max_context_tokens", 0),
# enforce max turns before compression
enforce_max_turns=self.enforce_max_turns,
truncate_turns=self.truncate_turns,
llm_compress_instruction=self.llm_compress_instruction,
llm_compress_keep_recent=self.llm_compress_keep_recent,
llm_compress_provider=self.llm_compress_provider,
custom_token_counter=self.custom_token_counter,
custom_compressor=self.custom_compressor,
)
self.context_manager = ContextManager(self.context_config)
self.provider = provider
self.final_llm_resp = None
self._state = AgentState.IDLE
@@ -110,6 +151,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self._transition_state(AgentState.RUNNING)
llm_resp_result = None
# do truncate and compress
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
self.run_context.messages = await self.context_manager.process(
self.run_context.messages, trusted_token_usage=token_usage
)
async for llm_response in self._iter_llm_responses():
if llm_response.is_chunk:
# update ttft
+115 -29
View File
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.10.3"
VERSION = "4.11.0"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -83,6 +83,16 @@ DEFAULT_CONFIG = {
"default_personality": "default",
"persona_pool": ["*"],
"prompt_prefix": "{{prompt}}",
"context_limit_reached_strategy": "truncate_by_turns", # or llm_compress
"llm_compress_instruction": (
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
"4. Write the summary in the user's language.\n"
),
"llm_compress_keep_recent": 4,
"llm_compress_provider_id": "",
"max_context_length": -1,
"dequeue_context_length": 1,
"streaming_response": False,
@@ -179,6 +189,7 @@ class ChatProviderTemplate(TypedDict):
model: str
modalities: list
custom_extra_body: dict[str, Any]
max_context_tokens: int
CHAT_PROVIDER_TEMPLATE = {
@@ -187,6 +198,7 @@ CHAT_PROVIDER_TEMPLATE = {
"model": "",
"modalities": [],
"custom_extra_body": {},
"max_context_tokens": 0,
}
"""
@@ -227,7 +239,7 @@ CONFIG_METADATA_2 = {
"callback_server_host": "0.0.0.0",
"port": 6196,
},
"OneBot v11": {
"OneBot v11 (QQ 个人号等)": {
"id": "default",
"type": "aiocqhttp",
"enable": False,
@@ -235,16 +247,6 @@ CONFIG_METADATA_2 = {
"ws_reverse_port": 6199,
"ws_reverse_token": "",
},
"WeChatPadPro": {
"id": "wechatpadpro",
"type": "wechatpadpro",
"enable": False,
"admin_key": "stay33",
"host": "这里填写你的局域网IP或者公网服务器IP",
"port": 8059,
"wpp_active_message_poll": False,
"wpp_active_message_poll_interval": 3,
},
"微信公众平台": {
"id": "weixin_official_account",
"type": "weixin_official_account",
@@ -374,6 +376,16 @@ CONFIG_METADATA_2 = {
"satori_heartbeat_interval": 10,
"satori_reconnect_delay": 5,
},
"WeChatPadPro": {
"id": "wechatpadpro",
"type": "wechatpadpro",
"enable": False,
"admin_key": "stay33",
"host": "这里填写你的局域网IP或者公网服务器IP",
"port": 8059,
"wpp_active_message_poll": False,
"wpp_active_message_poll_interval": 3,
},
# "WebChat": {
# "id": "webchat",
# "type": "webchat",
@@ -1451,7 +1463,32 @@ CONFIG_METADATA_2 = {
"description": "自定义请求体参数",
"type": "dict",
"items": {},
"hint": "此处添加的键值对将被合并到发送给 API 的 extra_body 中。值可以是字符串、数字或布尔值",
"hint": "用于在请求时添加额外的参数,如 temperature、top_p、max_tokens 等",
"template_schema": {
"temperature": {
"name": "Temperature",
"description": "温度参数",
"hint": "控制输出的随机性,范围通常为 0-2。值越高越随机。",
"type": "float",
"default": 0.6,
"slider": {"min": 0, "max": 2, "step": 0.1},
},
"top_p": {
"name": "Top-p",
"description": "Top-p 采样",
"hint": "核采样参数,范围通常为 0-1。控制模型考虑的概率质量。",
"type": "float",
"default": 1.0,
"slider": {"min": 0, "max": 1, "step": 0.01},
},
"max_tokens": {
"name": "Max Tokens",
"description": "最大令牌数",
"hint": "生成的最大令牌数。",
"type": "int",
"default": 8192,
},
},
},
"provider": {
"type": "string",
@@ -2008,6 +2045,11 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
},
"max_context_tokens": {
"description": "模型上下文窗口大小",
"type": "int",
"hint": "模型最大上下文 Token 大小。如果为 0,则会自动从模型元数据填充(如有),也可手动修改。",
},
"dify_api_key": {
"description": "API Key",
"type": "string",
@@ -2515,6 +2557,66 @@ CONFIG_METADATA_3 = {
# "provider_settings.enable": True,
# },
# },
"truncate_and_compress": {
"description": "上下文管理策略",
"type": "object",
"items": {
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",
"type": "int",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.dequeue_context_length": {
"description": "丢弃对话轮数",
"type": "int",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.context_limit_reached_strategy": {
"description": "超出模型上下文窗口时的处理方式",
"type": "string",
"options": ["truncate_by_turns", "llm_compress"],
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
"condition": {
"provider_settings.agent_runner_type": "local",
},
"hint": "",
},
"provider_settings.llm_compress_instruction": {
"description": "上下文压缩提示词",
"type": "text",
"hint": "如果为空则使用默认提示词。",
"condition": {
"provider_settings.context_limit_reached_strategy": "llm_compress",
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.llm_compress_keep_recent": {
"description": "压缩时保留最近对话轮数",
"type": "int",
"hint": "始终保留的最近 N 轮对话。",
"condition": {
"provider_settings.context_limit_reached_strategy": "llm_compress",
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.llm_compress_provider_id": {
"description": "用于上下文压缩的模型提供商 ID",
"type": "string",
"_special": "select_provider",
"hint": "留空时将降级为“按对话轮数截断”的策略。",
"condition": {
"provider_settings.context_limit_reached_strategy": "llm_compress",
"provider_settings.agent_runner_type": "local",
},
},
},
},
"others": {
"description": "其他配置",
"type": "object",
@@ -2579,22 +2681,6 @@ CONFIG_METADATA_3 = {
"provider_settings.streaming_response": True,
},
},
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",
"type": "int",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.dequeue_context_length": {
"description": "丢弃对话轮数",
"type": "int",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.wake_prefix": {
"description": "LLM 聊天额外唤醒前缀 ",
"type": "string",
+4
View File
@@ -69,6 +69,7 @@ class ConversationManager:
persona_id=conv_v2.persona_id,
created_at=created_at,
updated_at=updated_at,
token_usage=conv_v2.token_usage,
)
async def new_conversation(
@@ -256,6 +257,7 @@ class ConversationManager:
history: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
token_usage: int | None = None,
) -> None:
"""更新会话的对话.
@@ -263,6 +265,7 @@ class ConversationManager:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
token_usage (int | None): token 使用量。None 表示不更新
"""
if not conversation_id:
@@ -274,6 +277,7 @@ class ConversationManager:
title=title,
persona_id=persona_id,
content=history,
token_usage=token_usage,
)
async def update_conversation_title(
+1
View File
@@ -90,6 +90,7 @@ class AstrBotCoreLifecycle:
# 初始化 UMOP 配置路由器
self.umop_config_router = UmopConfigRouter(sp=sp)
await self.umop_config_router.initialize()
# 初始化 AstrBot 配置管理器
self.astrbot_config_mgr = AstrBotConfigManager(
+1
View File
@@ -152,6 +152,7 @@ class BaseDatabase(abc.ABC):
title: str | None = None,
persona_id: str | None = None,
content: list[dict] | None = None,
token_usage: int | None = None,
) -> None:
"""Update a conversation's history."""
...
@@ -0,0 +1,61 @@
"""Migration script to add token_usage column to conversations table.
This migration adds the token_usage field to track token consumption for each conversation.
Changes:
- Adds token_usage column to conversations table (default: 0)
"""
from sqlalchemy import text
from astrbot.api import logger, sp
from astrbot.core.db import BaseDatabase
async def migrate_token_usage(db_helper: BaseDatabase):
"""Add token_usage column to conversations table.
This migration adds a new column to track token consumption in conversations.
"""
# 检查是否已经完成迁移
migration_done = await db_helper.get_preference(
"global", "global", "migration_done_token_usage_1"
)
if migration_done:
return
logger.info("开始执行数据库迁移(添加 conversations.token_usage 列)...")
# 这里只适配了 SQLite。因为截止至这一版本,AstrBot 仅支持 SQLite。
try:
async with db_helper.get_db() as session:
# 检查列是否已存在
result = await session.execute(text("PRAGMA table_info(conversations)"))
columns = result.fetchall()
column_names = [col[1] for col in columns]
if "token_usage" in column_names:
logger.info("token_usage 列已存在,跳过迁移")
await sp.put_async(
"global", "global", "migration_done_token_usage_1", True
)
return
# 添加 token_usage 列
await session.execute(
text(
"ALTER TABLE conversations ADD COLUMN token_usage INTEGER NOT NULL DEFAULT 0"
)
)
await session.commit()
logger.info("token_usage 列添加成功")
# 标记迁移完成
await sp.put_async("global", "global", "migration_done_token_usage_1", True)
logger.info("token_usage 迁移完成")
except Exception as e:
logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
raise
+7
View File
@@ -54,6 +54,11 @@ class ConversationV2(SQLModel, table=True):
)
title: str | None = Field(default=None, max_length=255)
persona_id: str | None = Field(default=None)
token_usage: int = Field(default=0, nullable=False)
"""content is a list of OpenAI-formated messages in list[dict] format.
token_usage is the total token value of the messages.
when 0, will use estimated token counter.
"""
__table_args__ = (
UniqueConstraint(
@@ -313,6 +318,8 @@ class Conversation:
persona_id: str | None = ""
created_at: int = 0
updated_at: int = 0
token_usage: int = 0
"""对话的总 token 数量。AstrBot 会保留最近一次 LLM 请求返回的总 token 数,方便统计。token_usage 可能为 0,表示未知。"""
class Personality(TypedDict):
+5 -1
View File
@@ -241,7 +241,9 @@ class SQLiteDatabase(BaseDatabase):
session.add(new_conversation)
return new_conversation
async def update_conversation(self, cid, title=None, persona_id=None, content=None):
async def update_conversation(
self, cid, title=None, persona_id=None, content=None, token_usage=None
):
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
@@ -255,6 +257,8 @@ class SQLiteDatabase(BaseDatabase):
values["persona_id"] = persona_id
if content is not None:
values["content"] = content
if token_usage is not None:
values["token_usage"] = token_usage
if not values:
return None
query = query.values(**values)
@@ -149,8 +149,16 @@ class RecursiveCharacterChunker(BaseChunker):
分割后的文本块列表
"""
chunk_size = chunk_size or self.chunk_size
overlap = overlap or self.chunk_overlap
if chunk_size is None:
chunk_size = self.chunk_size
if overlap is None:
overlap = self.chunk_overlap
if chunk_size <= 0:
raise ValueError("chunk_size must be greater than 0")
if overlap < 0:
raise ValueError("chunk_overlap must be non-negative")
if overlap >= chunk_size:
raise ValueError("chunk_overlap must be less than chunk_size")
result = []
for i in range(0, len(text), chunk_size - overlap):
end = min(i + chunk_size, len(text))
@@ -38,7 +38,7 @@ class AgentRequestSubStage(Stage):
)
return
if not SessionServiceManager.should_process_llm_request(event):
if not await SessionServiceManager.should_process_llm_request(event):
logger.debug(
f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing."
)
@@ -1,12 +1,12 @@
"""本地 Agent 模式的 LLM 调用 Stage"""
import asyncio
import copy
import json
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.agent.message import Message
from astrbot.core.agent.response import AgentStats
from astrbot.core.agent.tool import ToolSet
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.conversation_mgr import Conversation
@@ -24,6 +24,7 @@ from astrbot.core.provider.entities import (
)
from astrbot.core.star.star_handler import EventType, star_map
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.llm_metadata import LLM_METADATAS
from astrbot.core.utils.metrics import Metric
from astrbot.core.utils.session_lock import session_lock_manager
@@ -41,11 +42,6 @@ class InternalAgentSubStage(Stage):
self.ctx = ctx
conf = ctx.astrbot_config
settings = conf["provider_settings"]
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
self.max_context_length - 1,
)
self.streaming_response: bool = settings["streaming_response"]
self.unsupported_streaming_strategy: str = settings[
"unsupported_streaming_strategy"
@@ -65,6 +61,25 @@ class InternalAgentSubStage(Stage):
"moonshotai_api_key", ""
)
# 上下文管理相关
self.context_limit_reached_strategy: str = settings.get(
"context_limit_reached_strategy", "truncate_by_turns"
)
self.llm_compress_instruction: str = settings.get(
"llm_compress_instruction", ""
)
self.llm_compress_keep_recent: int = settings.get("llm_compress_keep_recent", 4)
self.llm_compress_provider_id: str = settings.get(
"llm_compress_provider_id", ""
)
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
self.max_context_length - 1,
)
if self.dequeue_context_length <= 0:
self.dequeue_context_length = 1
self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent):
@@ -167,34 +182,6 @@ class InternalAgentSubStage(Stage):
},
)
def _truncate_contexts(
self,
contexts: list[dict],
) -> list[dict]:
"""截断上下文列表,确保不超过最大长度"""
if self.max_context_length == -1:
return contexts
if len(contexts) // 2 <= self.max_context_length:
return contexts
truncated_contexts = contexts[
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
]
# 找到第一个role 为 user 的索引,确保上下文格式正确
index = next(
(
i
for i, item in enumerate(truncated_contexts)
if item.get("role") == "user"
),
None,
)
if index is not None and index > 0:
truncated_contexts = truncated_contexts[index:]
return truncated_contexts
def _modalities_fix(
self,
provider: Provider,
@@ -296,6 +283,7 @@ class InternalAgentSubStage(Stage):
req: ProviderRequest,
llm_response: LLMResponse | None,
all_messages: list[Message],
runner_stats: AgentStats | None,
):
if (
not req
@@ -322,27 +310,37 @@ class InternalAgentSubStage(Stage):
continue
message_to_save.append(message.model_dump())
# get token usage from agent runner stats
token_usage = None
if runner_stats:
token_usage = runner_stats.token_usage.total
await self.conv_manager.update_conversation(
event.unified_msg_origin,
req.conversation.cid,
history=message_to_save,
token_usage=token_usage,
)
def _fix_messages(self, messages: list[dict]) -> list[dict]:
"""验证并且修复上下文"""
fixed_messages = []
for message in messages:
if message.get("role") == "tool":
# tool block 前面必须要有 user 和 assistant block
if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
# 我们直接将之前的上下文都清空
fixed_messages = []
else:
fixed_messages.append(message)
else:
fixed_messages.append(message)
return fixed_messages
def _get_compress_provider(self) -> Provider | None:
if not self.llm_compress_provider_id:
return None
if self.context_limit_reached_strategy != "llm_compress":
return None
provider = self.ctx.plugin_manager.context.get_provider_by_id(
self.llm_compress_provider_id,
)
if provider is None:
logger.warning(
f"未找到指定的上下文压缩模型 {self.llm_compress_provider_id},将跳过压缩。",
)
return None
if not isinstance(provider, Provider):
logger.warning(
f"指定的上下文压缩模型 {self.llm_compress_provider_id} 不是对话模型,将跳过压缩。"
)
return None
return provider
async def process(
self, event: AstrMessageEvent, provider_wake_prefix: str
@@ -364,6 +362,10 @@ class InternalAgentSubStage(Stage):
streaming_response = bool(enable_streaming)
logger.debug("ready to request llm provider")
# 通知等待调用 LLM(在获取锁之前)
await call_event_hook(event, EventType.OnWaitingLLMRequestEvent)
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
logger.debug("acquired session lock for llm request")
if event.get_extra("provider_request"):
@@ -422,9 +424,10 @@ class InternalAgentSubStage(Stage):
await self._apply_kb(event, req)
# truncate contexts to fit max length
if req.contexts:
req.contexts = self._truncate_contexts(req.contexts)
self._fix_messages(req.contexts)
# NOW moved to ContextManager inside ToolLoopAgentRunner
# if req.contexts:
# req.contexts = self._truncate_contexts(req.contexts)
# self._fix_messages(req.contexts)
# session_id
if not req.session_id:
@@ -440,8 +443,6 @@ class InternalAgentSubStage(Stage):
self.unsupported_streaming_strategy == "turn_off"
and not event.platform_meta.support_streaming_message
)
# 备份 req.contexts
backup_contexts = copy.deepcopy(req.contexts)
# run agent
agent_runner = AgentRunner()
@@ -452,6 +453,15 @@ class InternalAgentSubStage(Stage):
context=self.ctx.plugin_manager.context,
event=event,
)
# inject model context length limit
if provider.provider_config.get("max_context_tokens", 0) <= 0:
model = provider.get_model()
if model_info := LLM_METADATAS.get(model):
provider.provider_config["max_context_tokens"] = model_info[
"limit"
]["context"]
await agent_runner.reset(
provider=provider,
request=req,
@@ -462,6 +472,11 @@ class InternalAgentSubStage(Stage):
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=streaming_response,
llm_compress_instruction=self.llm_compress_instruction,
llm_compress_keep_recent=self.llm_compress_keep_recent,
llm_compress_provider=self._get_compress_provider(),
truncate_turns=self.dequeue_context_length,
enforce_max_turns=self.max_context_length,
)
if streaming_response and not stream_to_general:
@@ -507,14 +522,12 @@ class InternalAgentSubStage(Stage):
):
yield
# 恢复备份的 contexts
req.contexts = backup_contexts
await self._save_to_history(
event,
req,
agent_runner.get_final_llm_resp(),
agent_runner.run_context.messages,
agent_runner.stats,
)
# 异步处理 WebChat 特殊情况
@@ -260,7 +260,7 @@ class ResultDecorateStage(Stage):
should_tts = (
bool(self.ctx.astrbot_config["provider_tts_settings"]["enable"])
and result.is_llm_result()
and SessionServiceManager.should_process_tts_request(event)
and await SessionServiceManager.should_process_tts_request(event)
and random.random() <= self.tts_trigger_probability
and tts_provider
)
@@ -21,7 +21,7 @@ class SessionStatusCheckStage(Stage):
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
# 检查会话是否整体启用
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
if not await SessionServiceManager.is_session_enabled(event.unified_msg_origin):
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
# workaround for #2309
+1 -1
View File
@@ -227,7 +227,7 @@ class WakingCheckStage(Stage):
event._extras.pop("parsed_params", None)
# 根据会话配置过滤插件处理器
activated_handlers = SessionPluginManager.filter_handlers_by_session(
activated_handlers = await SessionPluginManager.filter_handlers_by_session(
event,
activated_handlers,
)
@@ -191,7 +191,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
if self.active_send_mode:
await self.convert_message(msg, None)
else:
if msg.id in self.wexin_event_workers:
if str(msg.id) in self.wexin_event_workers:
future = self.wexin_event_workers[str(cast(str | int, msg.id))]
logger.debug(f"duplicate message id checked: {msg.id}")
else:
+5
View File
@@ -344,6 +344,11 @@ class LLMResponse:
self.raw_completion = raw_completion
self.is_chunk = is_chunk
if id is not None:
self.id = id
if usage is not None:
self.usage = usage
@property
def completion_text(self):
if self.result_chain:
+27 -12
View File
@@ -119,19 +119,34 @@ class ProviderManager:
TTSProvider,
):
self.curr_tts_provider_inst = prov
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
await sp.put_async(
key="curr_provider_tts",
value=provider_id,
scope="global",
scope_id="global",
)
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
prov,
STTProvider,
):
self.curr_stt_provider_inst = prov
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
await sp.put_async(
key="curr_provider_stt",
value=provider_id,
scope="global",
scope_id="global",
)
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
prov,
Provider,
):
self.curr_provider_inst = prov
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
await sp.put_async(
key="curr_provider",
value=provider_id,
scope="global",
scope_id="global",
)
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
"""根据提供商 ID 获取提供商实例"""
@@ -206,21 +221,21 @@ class ProviderManager:
logger.error(traceback.format_exc())
logger.error(e)
selected_provider_id = sp.get(
"curr_provider",
self.provider_settings.get("default_provider_id"),
selected_provider_id = await sp.get_async(
key="curr_provider",
default=self.provider_settings.get("default_provider_id"),
scope="global",
scope_id="global",
)
selected_stt_provider_id = sp.get(
"curr_provider_stt",
self.provider_stt_settings.get("provider_id"),
selected_stt_provider_id = await sp.get_async(
key="curr_provider_stt",
default=self.provider_stt_settings.get("provider_id"),
scope="global",
scope_id="global",
)
selected_tts_provider_id = sp.get(
"curr_provider_tts",
self.provider_tts_settings.get("provider_id"),
selected_tts_provider_id = await sp.get_async(
key="curr_provider_tts",
default=self.provider_tts_settings.get("provider_id"),
scope="global",
scope_id="global",
)
@@ -378,7 +378,8 @@ class ProviderOpenAIOfficial(Provider):
new_content.append(part)
message["content"] = new_content
# reasoning key is "reasoning_content"
message["reasoning_content"] = reasoning_content
if reasoning_content:
message["reasoning_content"] = reasoning_content
async def _handle_api_error(
self,
+14 -1
View File
@@ -149,9 +149,12 @@ class Context:
contexts: context messages for the LLM
max_steps: Maximum number of tool calls before stopping the loop
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
stream: bool - whether to stream the LLM response
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
agent_context: AstrAgentContext - context to use for the agent
other kwargs will be DIRECTLY passed to the runner.reset() method
Returns:
The final LLMResponse after tool calls are completed.
@@ -194,6 +197,15 @@ class Context:
)
agent_runner = ToolLoopAgentRunner()
tool_executor = FunctionToolExecutor()
streaming = kwargs.get("stream", False)
other_kwargs = {
k: v
for k, v in kwargs.items()
if k not in ["stream", "agent_hooks", "agent_context"]
}
await agent_runner.reset(
provider=prov,
request=request,
@@ -203,7 +215,8 @@ class Context:
),
tool_executor=tool_executor,
agent_hooks=agent_hooks,
streaming=kwargs.get("stream", False),
streaming=streaming,
**other_kwargs,
)
async for _ in agent_runner.step_until_done(max_steps):
pass
+2
View File
@@ -12,6 +12,7 @@ from .star_handler import (
register_on_llm_request,
register_on_llm_response,
register_on_platform_loaded,
register_on_waiting_llm_request,
register_permission_type,
register_platform_adapter_type,
register_regex,
@@ -30,6 +31,7 @@ __all__ = [
"register_on_llm_request",
"register_on_llm_response",
"register_on_platform_loaded",
"register_on_waiting_llm_request",
"register_permission_type",
"register_platform_adapter_type",
"register_regex",
@@ -339,6 +339,30 @@ def register_on_platform_loaded(**kwargs):
return decorator
def register_on_waiting_llm_request(**kwargs):
"""当等待调用 LLM 时的通知事件(在获取锁之前)
此钩子在消息确定要调用 LLM 但还未开始排队等锁时触发
适合用于发送"正在思考中..."等用户反馈提示
Examples:
```py
@on_waiting_llm_request()
async def on_waiting_llm(self, event: AstrMessageEvent) -> None:
await event.send("🤔 正在思考中...")
```
"""
def decorator(awaitable):
_ = get_handler_or_create(
awaitable, EventType.OnWaitingLLMRequestEvent, **kwargs
)
return awaitable
return decorator
def register_on_llm_request(**kwargs):
"""当有 LLM 请求时的事件
+38 -26
View File
@@ -12,7 +12,7 @@ class SessionServiceManager:
# =============================================================================
@staticmethod
def is_llm_enabled_for_session(session_id: str) -> bool:
async def is_llm_enabled_for_session(session_id: str) -> bool:
"""检查LLM是否在指定会话中启用
Args:
@@ -23,11 +23,11 @@ class SessionServiceManager:
"""
# 获取会话服务配置
session_services = sp.get(
"session_service_config",
{},
session_services = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
default={},
)
# 如果配置了该会话的LLM状态,返回该状态
@@ -39,7 +39,7 @@ class SessionServiceManager:
return True
@staticmethod
def set_llm_status_for_session(session_id: str, enabled: bool) -> None:
async def set_llm_status_for_session(session_id: str, enabled: bool) -> None:
"""设置LLM在指定会话中的启停状态
Args:
@@ -48,18 +48,24 @@ class SessionServiceManager:
"""
session_config = (
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
default={},
)
or {}
)
session_config["llm_enabled"] = enabled
sp.put(
"session_service_config",
session_config,
await sp.put_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
value=session_config,
)
@staticmethod
def should_process_llm_request(event: AstrMessageEvent) -> bool:
async def should_process_llm_request(event: AstrMessageEvent) -> bool:
"""检查是否应该处理LLM请求
Args:
@@ -70,14 +76,14 @@ class SessionServiceManager:
"""
session_id = event.unified_msg_origin
return SessionServiceManager.is_llm_enabled_for_session(session_id)
return await SessionServiceManager.is_llm_enabled_for_session(session_id)
# =============================================================================
# TTS 相关方法
# =============================================================================
@staticmethod
def is_tts_enabled_for_session(session_id: str) -> bool:
async def is_tts_enabled_for_session(session_id: str) -> bool:
"""检查TTS是否在指定会话中启用
Args:
@@ -88,11 +94,11 @@ class SessionServiceManager:
"""
# 获取会话服务配置
session_services = sp.get(
"session_service_config",
{},
session_services = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
default={},
)
# 如果配置了该会话的TTS状态,返回该状态
@@ -104,7 +110,7 @@ class SessionServiceManager:
return True
@staticmethod
def set_tts_status_for_session(session_id: str, enabled: bool) -> None:
async def set_tts_status_for_session(session_id: str, enabled: bool) -> None:
"""设置TTS在指定会话中的启停状态
Args:
@@ -113,14 +119,20 @@ class SessionServiceManager:
"""
session_config = (
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
default={},
)
or {}
)
session_config["tts_enabled"] = enabled
sp.put(
"session_service_config",
session_config,
await sp.put_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
value=session_config,
)
logger.info(
@@ -128,7 +140,7 @@ class SessionServiceManager:
)
@staticmethod
def should_process_tts_request(event: AstrMessageEvent) -> bool:
async def should_process_tts_request(event: AstrMessageEvent) -> bool:
"""检查是否应该处理TTS请求
Args:
@@ -139,14 +151,14 @@ class SessionServiceManager:
"""
session_id = event.unified_msg_origin
return SessionServiceManager.is_tts_enabled_for_session(session_id)
return await SessionServiceManager.is_tts_enabled_for_session(session_id)
# =============================================================================
# 会话整体启停相关方法
# =============================================================================
@staticmethod
def is_session_enabled(session_id: str) -> bool:
async def is_session_enabled(session_id: str) -> bool:
"""检查会话是否整体启用
Args:
@@ -157,11 +169,11 @@ class SessionServiceManager:
"""
# 获取会话服务配置
session_services = sp.get(
"session_service_config",
{},
session_services = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_service_config",
default={},
)
# 如果配置了该会话的整体状态,返回该状态
+23 -11
View File
@@ -8,7 +8,10 @@ class SessionPluginManager:
"""管理会话级别的插件启停状态"""
@staticmethod
def is_plugin_enabled_for_session(session_id: str, plugin_name: str) -> bool:
async def is_plugin_enabled_for_session(
session_id: str,
plugin_name: str,
) -> bool:
"""检查插件是否在指定会话中启用
Args:
@@ -20,11 +23,11 @@ class SessionPluginManager:
"""
# 获取会话插件配置
session_plugin_config = sp.get(
"session_plugin_config",
{},
session_plugin_config = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_plugin_config",
default={},
)
session_config = session_plugin_config.get(session_id, {})
@@ -43,7 +46,10 @@ class SessionPluginManager:
return True
@staticmethod
def filter_handlers_by_session(event: AstrMessageEvent, handlers: list) -> list:
async def filter_handlers_by_session(
event: AstrMessageEvent,
handlers: list,
) -> list:
"""根据会话配置过滤处理器列表
Args:
@@ -59,6 +65,15 @@ class SessionPluginManager:
session_id = event.unified_msg_origin
filtered_handlers = []
session_plugin_config = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_plugin_config",
default={},
)
session_config = session_plugin_config.get(session_id, {})
disabled_plugins = session_config.get("disabled_plugins", [])
for handler in handlers:
# 获取处理器对应的插件
plugin = star_map.get(handler.handler_module_path)
@@ -76,14 +91,11 @@ class SessionPluginManager:
continue
# 检查插件是否在当前会话中启用
if SessionPluginManager.is_plugin_enabled_for_session(
session_id,
plugin.name,
):
filtered_handlers.append(handler)
else:
if plugin.name in disabled_plugins:
logger.debug(
f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}",
)
else:
filtered_handlers.append(handler)
return filtered_handlers
+1
View File
@@ -184,6 +184,7 @@ class EventType(enum.Enum):
OnPlatformLoadedEvent = enum.auto() # 平台加载完成
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
OnWaitingLLMRequestEvent = enum.auto() # 等待调用 LLM(在获取锁之前,仅通知)
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
OnLLMResponseEvent = enum.auto() # LLM 响应后
OnDecoratingResultEvent = enum.auto() # 发送消息前
+9 -6
View File
@@ -1,3 +1,5 @@
import fnmatch
from astrbot.core.utils.shared_preferences import SharedPreferences
@@ -9,14 +11,15 @@ class UmopConfigRouter:
"""UMOP 到配置文件 ID 的映射"""
self.sp = sp
self._load_routing_table()
async def initialize(self):
await self._load_routing_table()
def _load_routing_table(self):
async def _load_routing_table(self):
"""加载路由表"""
# 从 SharedPreferences 中加载 umop_to_conf_id 映射
sp_data = self.sp.get(
"umop_config_routing",
{},
sp_data = await self.sp.get_async(
key="umop_config_routing",
default={},
scope="global",
scope_id="global",
)
@@ -30,7 +33,7 @@ class UmopConfigRouter:
if len(p1_ls) != 3 or len(p2_ls) != 3:
return False # 非法格式
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
return all(p == "" or fnmatch.fnmatchcase(t, p) for p, t in zip(p1_ls, p2_ls))
def get_conf_id_for_umop(self, umo: str) -> str | None:
"""根据 UMO 获取对应的配置文件 ID
+8
View File
@@ -3,6 +3,7 @@ import traceback
from astrbot.core import astrbot_config, logger
from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
from astrbot.core.db.migration.migra_token_usage import migrate_token_usage
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
@@ -139,6 +140,13 @@ async def migra(
logger.error(f"Migration for webchat session failed: {e!s}")
logger.error(traceback.format_exc())
# migration for token_usage column
try:
await migrate_token_usage(db)
except Exception as e:
logger.error(f"Migration for token_usage column failed: {e!s}")
logger.error(traceback.format_exc())
# migra third party agent runner configs
_c = False
providers = astrbot_config["provider"]
+1
View File
@@ -993,6 +993,7 @@ class BackupRoute(Route):
file_path,
as_attachment=True,
attachment_filename=filename,
conditional=True, # 启用 Range 请求支持(断点续传)
)
except Exception as e:
logger.error(f"下载备份失败: {e}")
+25
View File
@@ -0,0 +1,25 @@
## What's Changed
### 修复
- 修复钉钉适配器中"回复消息 At 发送人"功能失效的问题
- 修复 Xinference STT 在部分情况下无法使用的问题
- 修复"会话隔离"功能在非默认配置下无法生效的问题
- 修复部分 LLM 中转商因 token 使用情况不符合 OpenAI 标准接口规范导致请求报错的问题
- 修复 Deepseek 模型开启思考模式后工具调用报错的问题
- 修复部分操作系统环境下 pip 安装依赖时出现 `UnicodeDecodeError` 错误的问题
### 优化
- 全面优化对思考型模型的支持(如 Anthropic Extended Thinking、Deepseek 思考模式),完整回传 thinking 内容,提升模型推理性能
- 优化 WebUI 记忆侧边栏中"更多功能"和"平台日志"模块的展开状态记忆
- 为 MiniMax TTS 新增 "auto" 音色情绪选项,支持模型根据文本内容自动选择情绪
- 优化备份功能,支持大文件分片下载
- 为 WebSocket 连接添加 max_size 参数,以处理更大的消息并防止接收来自 Satori 平台的大负载时连接断开
- 优化插件安装流程,通过文件安装插件时,若插件已加载则先终止再重新加载,避免重复加载
- 知识库支持将 overlap 参数设置为 0
### 新增
- 为 `dict` 类型的 Schema 新增 JSON value 和 template schema 功能。详见 [dict-类型的-schema](https://docs.astrbot.app/dev/star/guides/plugin-config.html#dict-%E7%B1%BB%E5%9E%8B%E7%9A%84-schema)。
- 新增 `template_list` 类型的 Schema,支持渲染指定 template 下的列表。详见 [template-list-类型的-schema](https://docs.astrbot.app/dev/star/guides/plugin-config.html#template-list-%E7%B1%BB%E5%9E%8B%E7%9A%84-schema)。
+5
View File
@@ -0,0 +1,5 @@
## What's Changed
hotfix of v4.10.4
fix: 部分配置项的输入框不显示,如飞书机器人配置的部分配置项。(#4268
+11
View File
@@ -0,0 +1,11 @@
## What's Changed
hotfix of v4.10.4
fix:
1. ‼️ 部分情况下使用 OpenAI 接口报错与 reasoning_content 有关的问题;
feat:
1. WebUI 已安装插件页支持记忆视图类型(列表/卡片),列表视图显示插件的人类友好名称和 logo。
+19
View File
@@ -0,0 +1,19 @@
## What's Changed
### 新增
- 支持上下文自动压缩功能。入口:配置文件 -> 上下文管理策略 -> 超出模型上下文窗口时的处理方式。详情请查看: [自动上下文压缩](https://docs.astrbot.app/use/context-compress.html) ([#4322](https://github.com/AstrBotDevs/AstrBot/issues/4322))
- 新增 `on_waiting_llm_request` 事件钩子 ([#4319](https://github.com/AstrBotDevs/AstrBot/issues/4319))
- WebUI 支持强制更新插件 ([#4293](https://github.com/AstrBotDevs/AstrBot/issues/4293))
- 社区已提供适用于 [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter) 平台的适配器插件
### 修复
- 修复微信公众号中由于 msg.id 数据类型不匹配导致的重试失败问题 ([#4292](https://github.com/AstrBotDevs/AstrBot/issues/4292))
- 修复调用 TTS 命令时出现的数据库锁定错误 ([#4313](https://github.com/AstrBotDevs/AstrBot/issues/4313))
- 修复 Anthropic 提供商中 token 用量始终为 0 的问题 ([#4328](https://github.com/AstrBotDevs/AstrBot/issues/4328))
### 优化
- 完善共享组件的国际化支持 ([#4327](https://github.com/AstrBotDevs/AstrBot/issues/4327))
- 优化下载大型备份文件时的稳定性,减少失败情况 ([#4329](https://github.com/AstrBotDevs/AstrBot/issues/4329))
@@ -82,7 +82,7 @@
{{ tm('availability.test') }}
<template #activator="{ props }">
<v-btn
icon="mdi-wrench"
icon="mdi-connection"
size="small"
variant="text"
:disabled="!entry.provider.enable"
@@ -93,6 +93,19 @@
</template>
</v-tooltip>
<v-tooltip location="top" max-width="300">
{{ tm('models.configure') }}
<template #activator="{ props }">
<v-btn
icon="mdi-cog"
size="small"
variant="text"
v-bind="props"
@click.stop="emit('open-provider-edit', entry.provider)"
></v-btn>
</template>
</v-tooltip>
<v-btn icon="mdi-delete" size="small" variant="text" color="error" @click.stop="emit('delete-provider', entry.provider)"></v-btn>
</div>
</template>
@@ -203,9 +203,8 @@ function hasVisibleItemsAfter(items, currentIndex) {
<v-col cols="12" sm="6" class="config-input">
<ConfigItemRenderer
v-if="metadata[metadataKey].items[key]"
v-model="iterable[key]"
:item-meta="metadata[metadataKey].items[key]"
:item-meta="metadata[metadataKey].items[key] || null"
:loading="loadingEmbeddingDim"
:show-fullscreen-btn="!!metadata[metadataKey].items[key]?.editor_mode"
@get-embedding-dim="getEmbeddingDimensions(iterable)"
@@ -219,7 +219,7 @@ function getSpecialSubtype(value) {
<ConfigItemRenderer
v-else
v-model="createSelectorModel(itemKey).value"
:item-meta="itemMeta"
:item-meta="itemMeta || null"
:show-fullscreen-btn="!!itemMeta?.editor_mode"
@open-fullscreen="openEditorDialog(itemKey, iterable, itemMeta?.editor_theme, itemMeta?.editor_language)"
/>
@@ -233,12 +233,12 @@ function getSpecialSubtype(value) {
<div v-if="createSelectorModel(itemKey).value && createSelectorModel(itemKey).value.length > 0"
class="selected-plugins-full-width">
<div class="plugins-header">
<small class="text-grey">已选择的插件</small>
<small class="text-grey">{{ t('core.shared.pluginSetSelector.selectedPluginsLabel') }}</small>
</div>
<div class="d-flex flex-wrap ga-2 mt-2">
<v-chip v-for="plugin in (createSelectorModel(itemKey).value || [])" :key="plugin" size="small" label
color="primary" variant="outlined">
{{ plugin === '*' ? '所有插件' : plugin }}
{{ plugin === '*' ? t('core.shared.pluginSetSelector.allPluginsLabel') : plugin }}
</v-chip>
</div>
</div>
@@ -20,13 +20,13 @@
</template>
<template v-else-if="itemMeta?._special === 'provider_pool'">
<ProviderSelector :model-value="modelValue" @update:model-value="emitUpdate" :provider-type="'chat_completion'"
button-text="选择提供商池..." />
:button-text="t('core.shared.providerSelector.selectProviderPool')" />
</template>
<template v-else-if="itemMeta?._special === 'select_persona'">
<PersonaSelector :model-value="modelValue" @update:model-value="emitUpdate" />
</template>
<template v-else-if="itemMeta?._special === 'persona_pool'">
<PersonaSelector :model-value="modelValue" @update:model-value="emitUpdate" button-text="选择人格池..." />
<PersonaSelector :model-value="modelValue" @update:model-value="emitUpdate" :button-text="t('core.shared.personaSelector.selectPersonaPool')" />
</template>
<template v-else-if="itemMeta?._special === 'select_knowledgebase'">
<KnowledgeBaseSelector :model-value="modelValue" @update:model-value="emitUpdate" />
@@ -56,7 +56,7 @@
:loading="loading"
class="ml-2"
>
自动检测
{{ t('core.common.autoDetect') }}
</v-btn>
</div>
</template>
@@ -144,7 +144,7 @@
color="primary"
density="compact"
hide-details
class="flex-grow-1"
style="flex: 1"
></v-slider>
<v-text-field
:model-value="modelValue"
@@ -154,7 +154,7 @@
class="config-field"
type="number"
hide-details
style="max-width: 140px;"
style="flex: 1"
></v-text-field>
</div>
@@ -188,6 +188,7 @@
<ObjectEditor
v-else-if="itemMeta?.type === 'dict'"
:model-value="modelValue"
:item-meta="itemMeta"
@update:model-value="emitUpdate"
class="config-field"
/>
@@ -222,7 +223,7 @@ const props = defineProps({
},
itemMeta: {
type: Object,
required: true
default: null
},
loading: {
type: Boolean,
@@ -324,4 +325,8 @@ function getSpecialSubtype(value) {
.gap-20 {
gap: 20px;
}
:deep(.v-field__input) {
font-size: 14px;
}
</style>
@@ -145,9 +145,11 @@ const viewReadme = () => {
}})</v-list-item-title>
</v-list-item>
<v-list-item @click="updateExtension" :disabled="!extension?.has_update">
<v-list-item @click="updateExtension">
<v-list-item-title>
{{ tm('card.actions.updateTo') }} {{ extension.online_version || extension.version }}
{{ extension.has_update
? tm('card.actions.updateTo') + ' ' + extension.online_version
: tm('card.actions.reinstall') }}
</v-list-item-title>
</v-list-item>
</template>
@@ -20,7 +20,7 @@
</div>
</div>
<v-btn size="small" color="primary" variant="tonal" @click="openDialog" style="flex-shrink: 0;">
{{ buttonText }}
{{ buttonText || tm('knowledgeBaseSelector.buttonText') }}
</v-btn>
</div>
@@ -105,7 +105,7 @@ const props = defineProps({
},
buttonText: {
type: String,
default: '选择知识库...'
default: ''
}
})
@@ -23,7 +23,7 @@
</div>
</div>
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
{{ preferSingleItem ? '添加更多' : (buttonText || t('core.common.list.modifyButton')) }}
{{ preferSingleItem ? t('core.common.list.addMore') : (buttonText || t('core.common.list.modifyButton')) }}
</v-btn>
</div>
@@ -48,6 +48,14 @@
:placeholder="t('core.common.list.inputPlaceholder')"
class="flex-grow-1">
</v-text-field>
<v-btn
@click="addItem"
variant="tonal"
color="primary"
size="small"
:disabled="!newItem.trim()">
{{ t('core.common.list.addButton') }}
</v-btn>
<v-btn
@click="showBatchImport = true"
variant="tonal"
@@ -167,11 +175,11 @@ const props = defineProps({
},
buttonText: {
type: String,
default: '修改'
default: ''
},
dialogTitle: {
type: String,
default: '修改列表项'
default: ''
},
maxDisplayItems: {
type: Number,
@@ -318,4 +326,4 @@ function cancelBatchImport() {
.v-chip {
margin: 2px;
}
</style>
</style>
+224 -16
View File
@@ -26,8 +26,9 @@
</v-card-title>
<v-card-text class="pa-4" style="max-height: 400px; overflow-y: auto;">
<div v-if="localKeyValuePairs.length > 0">
<div v-for="(pair, index) in localKeyValuePairs" :key="index" class="key-value-pair">
<!-- Regular key-value pairs (non-template) -->
<div v-if="nonTemplatePairs.length > 0">
<div v-for="(pair, index) in nonTemplatePairs" :key="index" class="key-value-pair">
<v-row no-gutters align="center" class="mb-2">
<v-col cols="4">
<v-text-field
@@ -48,15 +49,29 @@
hide-details
placeholder="字符串值"
></v-text-field>
<v-text-field
v-else-if="pair.type === 'number'"
v-model.number="pair.value"
type="number"
density="compact"
variant="outlined"
hide-details
placeholder="数值"
></v-text-field>
<div v-else-if="pair.type === 'number' || pair.type === 'float' || pair.type === 'int'" class="d-flex align-center gap-2 flex-grow-1">
<v-slider
v-if="pair.slider"
:model-value="Number(pair.value) || 0"
@update:model-value="pair.value = $event"
:min="pair.slider.min"
:max="pair.slider.max"
:step="pair.slider.step"
color="primary"
density="compact"
hide-details
class="flex-grow-1"
></v-slider>
<v-text-field
v-model.number="pair.value"
type="number"
density="compact"
variant="outlined"
hide-details
placeholder="数值"
:style="pair.slider ? 'max-width: 120px;' : ''"
></v-text-field>
</div>
<v-switch
v-else-if="pair.type === 'boolean'"
v-model="pair.value"
@@ -81,7 +96,7 @@
variant="text"
size="small"
color="error"
@click="removeKeyValuePair(index)"
@click="removeKeyValuePairByKey(pair.key)"
>
<v-icon>mdi-delete</v-icon>
</v-btn>
@@ -89,7 +104,79 @@
</v-row>
</div>
</div>
<div v-else class="text-center py-8">
<!-- Template schema fields -->
<div v-if="hasTemplateSchema" class="mt-4">
<v-divider class="mb-3"></v-divider>
<div class="text-caption text-grey mb-2">预设</div>
<div v-for="(template, templateKey) in templateSchema" :key="templateKey" class="template-field" :class="{ 'template-field-inactive': !isTemplateKeyAdded(templateKey) }">
<v-row no-gutters align="center" class="mb-2">
<v-col cols="4">
<div class="d-flex flex-column">
<span class="text-caption font-weight-medium">{{ template.name || template.description || templateKey }}</span>
<span v-if="template.hint" class="text-caption text-grey" style="font-size: 0.7rem;">{{ template.hint }}</span>
</div>
</v-col>
<v-col cols="7" class="pl-2 d-flex align-center justify-end">
<v-text-field
v-if="template.type === 'string'"
:model-value="getTemplateValue(templateKey)"
@update:model-value="updateTemplateValue(templateKey, $event)"
density="compact"
variant="outlined"
hide-details
placeholder="字符串值"
></v-text-field>
<div v-else-if="template.type === 'number' || template.type === 'float' || template.type === 'int'" class="d-flex align-center ga-4 flex-grow-1">
<v-slider
v-if="template.slider"
:model-value="Number(getTemplateValue(templateKey)) || 0"
@update:model-value="updateTemplateValue(templateKey, $event)"
:min="template.slider.min"
:max="template.slider.max"
:step="template.slider.step"
color="primary"
density="compact"
hide-details
class="flex-grow-1"
></v-slider>
<v-text-field
:model-value="getTemplateValue(templateKey)"
@update:model-value="updateTemplateValue(templateKey, $event)"
type="number"
density="compact"
variant="outlined"
hide-details
placeholder="数值"
:style="template.slider ? 'max-width: 120px;' : ''"
></v-text-field>
</div>
<v-switch
v-else-if="template.type === 'boolean' || template.type === 'bool'"
:model-value="getTemplateValue(templateKey)"
@update:model-value="updateTemplateValue(templateKey, $event)"
density="compact"
hide-details
color="primary"
></v-switch>
</v-col>
<v-col cols="1" class="pl-2">
<v-btn
v-if="isTemplateKeyAdded(templateKey)"
icon
variant="text"
size="small"
color="error"
@click="removeTemplateKey(templateKey)"
>
<v-icon>mdi-close</v-icon>
</v-btn>
</v-col>
</v-row>
</div>
</div>
<div v-if="localKeyValuePairs.length === 0 && !hasTemplateSchema" class="text-center py-8">
<v-icon size="64" color="grey-lighten-1">mdi-code-json</v-icon>
<p class="text-grey mt-4">暂无参数</p>
</div>
@@ -142,6 +229,10 @@ const props = defineProps({
type: Object,
required: true
},
itemMeta: {
type: Object,
default: null
},
buttonText: {
type: String,
default: '修改'
@@ -164,11 +255,25 @@ const originalKeyValuePairs = ref([])
const newKey = ref('')
const newValueType = ref('string')
// Template schema support
const templateSchema = computed(() => {
return props.itemMeta?.template_schema || {}
})
const hasTemplateSchema = computed(() => {
return Object.keys(templateSchema.value).length > 0
})
//
const displayKeys = computed(() => {
return Object.keys(props.modelValue).slice(0, props.maxDisplayItems)
})
//
const nonTemplatePairs = computed(() => {
return localKeyValuePairs.value.filter(pair => !templateSchema.value[pair.key])
})
// modelValue
watch(() => props.modelValue, (newValue) => {
// This watch is primarily for initialization or external changes
@@ -180,10 +285,24 @@ function initializeLocalKeyValuePairs() {
for (const [key, value] of Object.entries(props.modelValue)) {
let _type = (typeof value) === 'object' ? 'json':(typeof value)
let _value = _type === 'json'?JSON.stringify(value):value
// Check if this key has a template schema
const template = templateSchema.value[key]
if (template) {
// Use template type if available
_type = template.type || _type
// Use template default if value is missing
if (_value === undefined || _value === null) {
_value = template.default !== undefined ? template.default : _value
}
}
localKeyValuePairs.value.push({
key: key,
value: _value,
type: _type
type: _type,
slider: template?.slider,
template: template
})
}
}
@@ -239,8 +358,11 @@ function updateJSON(index, newValue) {
}
}
function removeKeyValuePair(index) {
localKeyValuePairs.value.splice(index, 1)
function removeKeyValuePairByKey(key) {
const index = localKeyValuePairs.value.findIndex(pair => pair.key === key)
if (index >= 0) {
localKeyValuePairs.value.splice(index, 1)
}
}
function updateKey(index, newKey) {
@@ -258,10 +380,83 @@ function updateKey(index, newKey) {
return
}
//
const template = templateSchema.value[newKey]
if (template) {
//
localKeyValuePairs.value[index].type = template.type || localKeyValuePairs.value[index].type
if (localKeyValuePairs.value[index].value === undefined || localKeyValuePairs.value[index].value === null || localKeyValuePairs.value[index].value === '') {
localKeyValuePairs.value[index].value = template.default !== undefined ? template.default : localKeyValuePairs.value[index].value
}
localKeyValuePairs.value[index].slider = template.slider
localKeyValuePairs.value[index].template = template
} else {
//
localKeyValuePairs.value[index].slider = undefined
localKeyValuePairs.value[index].template = undefined
}
//
localKeyValuePairs.value[index].key = newKey
}
function isTemplateKeyAdded(templateKey) {
return localKeyValuePairs.value.some(pair => pair.key === templateKey)
}
function getTemplateValue(templateKey) {
const pair = localKeyValuePairs.value.find(pair => pair.key === templateKey)
if (pair) {
return pair.value
}
const template = templateSchema.value[templateKey]
return template?.default !== undefined ? template.default : getDefaultValueForType(template?.type || 'string')
}
function updateTemplateValue(templateKey, newValue) {
const existingIndex = localKeyValuePairs.value.findIndex(pair => pair.key === templateKey)
const template = templateSchema.value[templateKey]
if (existingIndex >= 0) {
//
localKeyValuePairs.value[existingIndex].value = newValue
} else {
//
let valueType = template?.type || 'string'
localKeyValuePairs.value.push({
key: templateKey,
value: newValue,
type: valueType,
slider: template?.slider,
template: template
})
}
}
function removeTemplateKey(templateKey) {
const index = localKeyValuePairs.value.findIndex(pair => pair.key === templateKey)
if (index >= 0) {
localKeyValuePairs.value.splice(index, 1)
}
}
function getDefaultValueForType(type) {
switch (type) {
case 'int':
case 'float':
case 'number':
return 0
case 'bool':
case 'boolean':
return false
case 'json':
return "{}"
case 'string':
default:
return ""
}
}
function confirmDialog() {
const updatedValue = {}
for (const pair of localKeyValuePairs.value) {
@@ -269,12 +464,17 @@ function confirmDialog() {
let convertedValue = pair.value
//
switch (pair.type) {
case 'int':
convertedValue = parseInt(pair.value) || 0
break
case 'float':
case 'number':
// 0
convertedValue = Number(pair.value)
// 0
// if (isNaN(convertedValue)) convertedValue = 0;
break
case 'bool':
case 'boolean':
// v-switch
// JavaScript false, 0, "", null, undefined, NaN false
@@ -307,4 +507,12 @@ function cancelDialog() {
.key-value-pair {
width: 100%;
}
.template-field {
transition: opacity 0.2s;
}
.template-field-inactive {
opacity: 0.8;
}
</style>
@@ -1,13 +1,13 @@
<template>
<div class="d-flex align-center justify-space-between">
<span v-if="!modelValue" style="color: rgb(var(--v-theme-primaryText));">
未选择
{{ tm('personaSelector.notSelected') }}
</span>
<span v-else>
{{ modelValue === 'default' ? '默认人格' : modelValue }}
{{ modelValue === 'default' ? tm('personaSelector.defaultPersona') : modelValue }}
</span>
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
{{ buttonText }}
{{ buttonText || tm('personaSelector.buttonText') }}
</v-btn>
</div>
@@ -15,7 +15,7 @@
<v-dialog v-model="dialog" max-width="600px">
<v-card>
<v-card-title class="text-h3 py-4" style="font-weight: normal;">
选择人格
{{ tm('personaSelector.dialogTitle') }}
</v-card-title>
<v-card-text class="pa-2" style="max-height: 400px; overflow-y: auto;">
@@ -30,9 +30,9 @@
:active="selectedPersona === persona.persona_id"
rounded="md"
class="ma-1">
<v-list-item-title>{{ persona.persona_id === 'default' ? '默认人格' : persona.persona_id }}</v-list-item-title>
<v-list-item-title>{{ persona.persona_id === 'default' ? tm('personaSelector.defaultPersona') : persona.persona_id }}</v-list-item-title>
<v-list-item-subtitle>
{{ persona.system_prompt ? persona.system_prompt.substring(0, 50) + '...' : '无描述' }}
{{ persona.system_prompt ? persona.system_prompt.substring(0, 50) + '...' : tm('personaSelector.noDescription') }}
</v-list-item-subtitle>
<template v-slot:append>
@@ -43,21 +43,21 @@
<div v-else-if="!loading && personaList.length === 0" class="text-center py-8">
<v-icon size="64" color="grey-lighten-1">mdi-account-off</v-icon>
<p class="text-grey mt-4">暂无可用的人格</p>
<p class="text-grey mt-4">{{ tm('personaSelector.noPersonas') }}</p>
</div>
</v-card-text>
<v-card-actions class="pa-4">
<v-btn variant="text" color="primary" prepend-icon="mdi-plus" @click="openCreatePersona">
创建新人格
{{ tm('personaSelector.createPersona') }}
</v-btn>
<v-spacer></v-spacer>
<v-btn variant="text" @click="cancelSelection">取消</v-btn>
<v-btn
color="primary"
<v-btn variant="text" @click="cancelSelection">{{ t('core.common.cancel') }}</v-btn>
<v-btn
color="primary"
@click="confirmSelection"
:disabled="!selectedPersona">
确认选择
{{ t('core.common.confirm') }}
</v-btn>
</v-card-actions>
</v-card>
@@ -78,6 +78,7 @@
import { ref, watch } from 'vue'
import axios from 'axios'
import PersonaForm from './PersonaForm.vue'
import { useI18n, useModuleI18n } from '@/i18n/composables'
const props = defineProps({
modelValue: {
@@ -86,11 +87,13 @@ const props = defineProps({
},
buttonText: {
type: String,
default: '选择人格...'
default: ''
}
})
const emit = defineEmits(['update:modelValue'])
const { t } = useI18n()
const { tm } = useModuleI18n('core.shared')
const dialog = ref(false)
const personaList = ref([])
@@ -14,7 +14,7 @@
</span>
</div>
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
{{ buttonText }}
{{ buttonText || tm('pluginSetSelector.buttonText') }}
</v-btn>
</div>
</div>
@@ -113,7 +113,7 @@ const props = defineProps({
},
buttonText: {
type: String,
default: '选择插件集合...'
default: ''
},
maxDisplayItems: {
type: Number,
@@ -7,7 +7,7 @@
{{ modelValue }}
</span>
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
{{ buttonText }}
{{ buttonText || tm('providerSelector.buttonText') }}
</v-btn>
</div>
@@ -134,7 +134,7 @@ const props = defineProps({
},
buttonText: {
type: String,
default: '选择提供商...'
default: ''
}
})
@@ -1,13 +1,13 @@
<template>
<h5>GitHub 加速</h5>
<h5>{{ tm('network.proxySelector.title') }}</h5>
<v-radio-group class="mt-2" v-model="radioValue" hide-details="true">
<v-radio label="不使用 GitHub 加速" value="0"></v-radio>
<v-radio :label="tm('network.proxySelector.noProxy')" value="0"></v-radio>
<v-radio value="1">
<template v-slot:label>
<span>使用 GitHub 加速</span>
<span>{{ tm('network.proxySelector.useProxy') }}</span>
<v-btn v-if="radioValue === '1'" class="ml-2" @click="testAllProxies" size="x-small"
variant="tonal" :loading="loadingTestingConnection">
测试代理连通性
{{ tm('network.proxySelector.testConnection') }}
</v-btn>
</template>
</v-radio>
@@ -20,15 +20,15 @@
<div class="d-flex align-center">
<span class="mr-2">{{ proxy }}</span>
<div v-if="proxyStatus[idx]">
<v-chip
:color="proxyStatus[idx].available ? 'success' : 'error'"
size="x-small"
<v-chip
:color="proxyStatus[idx].available ? 'success' : 'error'"
size="x-small"
class="mr-1">
{{ proxyStatus[idx].available ? '可用' : '不可用' }}
{{ proxyStatus[idx].available ? tm('network.proxySelector.available') : tm('network.proxySelector.unavailable') }}
</v-chip>
<v-chip
v-if="proxyStatus[idx].available"
color="info"
<v-chip
v-if="proxyStatus[idx].available"
color="info"
size="x-small">
{{ proxyStatus[idx].latency }}ms
</v-chip>
@@ -36,10 +36,10 @@
</div>
</template>
</v-radio>
<v-radio color="primary" value="-1" label="自定义">
<v-radio color="primary" value="-1" :label="tm('network.proxySelector.custom')">
<template v-slot:label v-if="githubProxyRadioControl === '-1'">
<v-text-field density="compact" v-model="selectedGitHubProxy" variant="outlined"
style="width: 100vw;" placeholder="自定义" hide-details="true">
style="width: 100vw;" :placeholder="tm('network.proxySelector.custom')" hide-details="true">
</v-text-field>
</template>
</v-radio>
@@ -1,32 +1,32 @@
<template>
<v-dialog v-model="dialog" max-width="1400px" persistent scrollable>
<template v-slot:activator="{ props }">
<v-btn
<v-btn
v-bind="props"
variant="outlined"
color="primary"
variant="outlined"
color="primary"
size="small"
:loading="loading"
>
自定义 T2I 模板
{{ tm('t2iTemplateEditor.buttonText') }}
</v-btn>
</template>
<v-card>
<v-card-title class="d-flex align-center justify-space-between">
<span>自定义文转图 HTML 模板</span>
<span>{{ tm('t2iTemplateEditor.dialogTitle') }}</span>
<v-spacer></v-spacer>
<div class="d-flex align-center gap-2" style="width: 60%">
<v-text-field
v-if="isCreatingNew"
v-model="editingName"
label="输入新模板名称"
:label="tm('t2iTemplateEditor.newTemplateNameLabel')"
density="compact"
hide-details
variant="outlined"
class="flex-grow-1"
autofocus
:rules="[v => !!v || '名称不能为空']"
:rules="[v => !!v || tm('t2iTemplateEditor.nameRequired')]"
></v-text-field>
<v-select
v-else
@@ -34,7 +34,7 @@
:items="templates"
item-title="name"
item-value="name"
label="选择模板"
:label="tm('t2iTemplateEditor.selectTemplateLabel')"
density="compact"
hide-details
variant="outlined"
@@ -51,7 +51,7 @@
size="small"
class="ml-2"
>
已应用
{{ tm('t2iTemplateEditor.applied') }}
</v-chip>
<v-btn
v-else
@@ -62,7 +62,7 @@
@click.stop="setActiveTemplate(item.raw.name)"
:loading="applyLoading"
>
应用
{{ tm('t2iTemplateEditor.apply') }}
</v-btn>
</template>
</v-list-item>
@@ -83,7 +83,7 @@
<!-- 左侧编辑器 -->
<v-col cols="6" class="d-flex flex-column">
<v-toolbar density="compact" color="surface-variant">
<v-toolbar-title class="text-subtitle-2">模板编辑器</v-toolbar-title>
<v-toolbar-title class="text-subtitle-2">{{ tm('t2iTemplateEditor.templateEditor') }}</v-toolbar-title>
<v-spacer></v-spacer>
<div class="d-flex align-center pa-1" style="border: 1px solid rgba(0,0,0,0.1); border-radius: 8px;">
<v-btn
@@ -93,7 +93,7 @@
color="success"
>
<v-icon left>mdi-plus</v-icon>
新建
{{ tm('t2iTemplateEditor.new') }}
</v-btn>
<v-divider vertical class="mx-1"></v-divider>
<v-btn
@@ -103,7 +103,7 @@
:loading="resetLoading"
color="warning"
>
重置Base
{{ tm('t2iTemplateEditor.resetBase') }}
</v-btn>
<v-btn
variant="text"
@@ -112,7 +112,7 @@
color="error"
:disabled="isCreatingNew || selectedTemplate === 'base' || !selectedTemplate"
>
删除
{{ tm('t2iTemplateEditor.delete') }}
</v-btn>
<v-divider vertical class="mx-1"></v-divider>
<v-btn
@@ -123,7 +123,7 @@
color="primary"
:disabled="(isCreatingNew && !editingName) || (!isCreatingNew && !selectedTemplate)"
>
保存
{{ tm('t2iTemplateEditor.save') }}
</v-btn>
</div>
</v-toolbar>
@@ -141,15 +141,15 @@
<!-- 右侧预览 -->
<v-col cols="6" class="d-flex flex-column">
<v-toolbar density="compact" color="surface-variant">
<v-toolbar-title class="text-subtitle-2">实时预览(可能有差异)</v-toolbar-title>
<v-toolbar-title class="text-subtitle-2">{{ tm('t2iTemplateEditor.livePreview') }}</v-toolbar-title>
<v-spacer></v-spacer>
<v-btn
variant="text"
size="small"
<v-btn
variant="text"
size="small"
@click="refreshPreview"
:loading="previewLoading"
>
刷新预览
{{ tm('t2iTemplateEditor.refreshPreview') }}
</v-btn>
</v-toolbar>
<div class="flex-grow-1 preview-container">
@@ -168,7 +168,7 @@
<v-col>
<div class="text-caption text-grey">
<v-icon size="16" class="mr-1">mdi-information</v-icon>
支持 jinja2 语法可用变量<code> text | safe </code>要渲染的文本, <code> version </code>AstrBot 版本
{{ tm('t2iTemplateEditor.syntaxHint') }}
</div>
</v-col>
<v-col cols="auto">
@@ -176,7 +176,7 @@
variant="text"
@click="closeDialog"
>
取消
{{ t('core.common.cancel') }}
</v-btn>
<v-btn
color="primary"
@@ -184,7 +184,7 @@
:loading="saveLoading"
:disabled="isCreatingNew || !selectedTemplate"
>
保存应用当前编辑模板
{{ tm('t2iTemplateEditor.saveAndApply') }}
</v-btn>
</v-col>
</v-row>
@@ -194,14 +194,14 @@
<!-- 确认重置对话框 -->
<v-dialog v-model="resetDialog" max-width="400px">
<v-card>
<v-card-title>确认重置</v-card-title>
<v-card-title>{{ tm('t2iTemplateEditor.confirmReset') }}</v-card-title>
<v-card-text>
确定要将 'base' 模板恢复为默认内容吗当前编辑器中的任何未保存更改将丢失此操作无法撤销
{{ tm('t2iTemplateEditor.confirmResetMessage') }}
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn text @click="resetDialog = false">取消</v-btn>
<v-btn color="warning" @click="confirmReset" :loading="resetLoading">确认重置</v-btn>
<v-btn text @click="resetDialog = false">{{ t('core.common.cancel') }}</v-btn>
<v-btn color="warning" @click="confirmReset" :loading="resetLoading">{{ tm('t2iTemplateEditor.confirmResetButton') }}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
@@ -209,14 +209,14 @@
<!-- 删除确认对话框 -->
<v-dialog v-model="deleteDialog" max-width="400px">
<v-card>
<v-card-title>确认删除</v-card-title>
<v-card-title>{{ tm('t2iTemplateEditor.confirmDelete') }}</v-card-title>
<v-card-text>
确定要删除模板 '{{ selectedTemplate }}' 此操作无法撤销
{{ tm('t2iTemplateEditor.confirmDeleteMessage', { name: selectedTemplate }) }}
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn text @click="deleteDialog = false">取消</v-btn>
<v-btn color="error" @click="confirmDelete" :loading="saveLoading">确认删除</v-btn>
<v-btn text @click="deleteDialog = false">{{ t('core.common.cancel') }}</v-btn>
<v-btn color="error" @click="confirmDelete" :loading="saveLoading">{{ tm('t2iTemplateEditor.confirmDeleteButton') }}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
@@ -224,14 +224,14 @@
<!-- 保存并应用确认对话框 -->
<v-dialog v-model="applyAndCloseDialog" max-width="500px">
<v-card>
<v-card-title>确认操作</v-card-title>
<v-card-title>{{ tm('t2iTemplateEditor.confirmAction') }}</v-card-title>
<v-card-text>
确定要保存对 '{{ selectedTemplate }}' 的修改并将其设为新的活动模板吗
{{ tm('t2iTemplateEditor.confirmApplyMessage', { name: selectedTemplate }) }}
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn text @click="applyAndCloseDialog = false">取消</v-btn>
<v-btn color="primary" @click="confirmApplyAndClose" :loading="saveLoading">确认</v-btn>
<v-btn text @click="applyAndCloseDialog = false">{{ t('core.common.cancel') }}</v-btn>
<v-btn color="primary" @click="confirmApplyAndClose" :loading="saveLoading">{{ t('core.common.confirm') }}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
@@ -242,10 +242,11 @@
<script setup>
import { ref, computed, nextTick, watch } from 'vue'
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
import { useI18n } from '@/i18n/composables'
import { useI18n, useModuleI18n } from '@/i18n/composables'
import axios from 'axios'
const { t } = useI18n()
const { tm } = useModuleI18n('core.shared')
// --- ---
const dialog = ref(false)
@@ -510,7 +510,7 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
const metadata = getModelMetadata(modelName)
let modalities: string[]
if (!metadata) {
modalities = ['text', 'image', 'tool_use']
} else {
@@ -523,13 +523,19 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
}
}
let max_context_tokens = 0
if (metadata?.limit?.context && typeof metadata.limit.context === 'number') {
max_context_tokens = metadata.limit.context
}
const newProvider = {
id: newId,
enable: false,
provider_source_id: sourceId,
model: modelName,
modalities,
custom_extra_body: {}
custom_extra_body: {},
max_context_tokens: max_context_tokens
}
try {
@@ -35,6 +35,7 @@
"yes": "Yes",
"no": "No",
"imagePreview": "Image Preview",
"autoDetect": "Auto Detect",
"dialog": {
"confirmTitle": "Confirm Action",
"confirmMessage": "Are you sure you want to perform this action?",
@@ -74,6 +75,7 @@
"list": {
"addItemPlaceholder": "Add new item, press Enter to confirm",
"addButton": "Add",
"addMore": "Add More",
"batchImport": "Batch Import",
"batchImportTitle": "Batch Import",
"batchImportLabel": "One item per line",
@@ -28,7 +28,9 @@
"cancelSelection": "Cancel",
"noDescription": "No description",
"notActivated": "Not activated",
"note": "*System plugins and disabled plugins are not shown."
"note": "*System plugins and disabled plugins are not shown.",
"selectedPluginsLabel": "Selected Plugins:",
"allPluginsLabel": "All Plugins"
},
"providerSelector": {
"notSelected": "Not selected",
@@ -42,6 +44,45 @@
"clearSelectionSubtitle": "Clear current selection",
"unknownType": "Unknown type",
"createProvider": "Create Provider",
"manageProviders": "Provider Management"
"manageProviders": "Provider Management",
"selectProviderPool": "Select Provider Pool..."
},
"personaSelector": {
"notSelected": "Not selected",
"defaultPersona": "Default Persona",
"buttonText": "Select Persona...",
"dialogTitle": "Select Persona",
"noDescription": "No description",
"noPersonas": "No personas available",
"createPersona": "Create New Persona",
"cancelSelection": "Cancel",
"confirmSelection": "Confirm Selection",
"selectPersonaPool": "Select Persona Pool..."
},
"t2iTemplateEditor": {
"buttonText": "Customize T2I Template",
"dialogTitle": "Customize Text-to-Image HTML Template",
"newTemplateNameLabel": "Enter new template name",
"nameRequired": "Name is required",
"selectTemplateLabel": "Select Template",
"applied": "Applied",
"apply": "Apply",
"templateEditor": "Template Editor",
"new": "New",
"resetBase": "Reset Base",
"delete": "Delete",
"save": "Save",
"livePreview": "Live Preview (may differ)",
"refreshPreview": "Refresh Preview",
"syntaxHint": "Supports jinja2 syntax. Available variables: text | safe (text to render), version (AstrBot version)",
"saveAndApply": "Save and Apply Current Template",
"confirmReset": "Confirm Reset",
"confirmResetMessage": "Are you sure you want to reset the 'base' template to default content? Any unsaved changes in the editor will be lost. This action cannot be undone.",
"confirmResetButton": "Confirm Reset",
"confirmDelete": "Confirm Delete",
"confirmDeleteMessage": "Are you sure you want to delete template '{name}'? This action cannot be undone.",
"confirmDeleteButton": "Confirm Delete",
"confirmAction": "Confirm Action",
"confirmApplyMessage": "Are you sure you want to save changes to '{name}' and set it as the active template?"
}
}
@@ -11,7 +11,12 @@
},
"agent_runner_type": {
"description": "Runner",
"labels": ["Built-in Agent", "Dify", "Coze", "Alibaba Cloud Bailian Application"]
"labels": [
"Built-in Agent",
"Dify",
"Coze",
"Alibaba Cloud Bailian Application"
]
},
"coze_agent_runner_provider_id": {
"description": "Coze Agent Runner Provider ID"
@@ -128,6 +133,39 @@
}
}
},
"truncate_and_compress": {
"description": "Context Management Strategy",
"provider_settings": {
"max_context_length": {
"description": "Maximum Conversation Turns",
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
},
"dequeue_context_length": {
"description": "Dequeue Conversation Turns",
"hint": "Number of conversation turns to discard at once when maximum context length is exceeded"
},
"context_limit_reached_strategy": {
"description": "Handling When Model Context Window is Exceeded",
"labels": [
"Truncate by Turns",
"Compress by LLM"
],
"hint": "When 'Truncate by Turns' is selected, the oldest N conversation turns will be discarded based on the 'Dequeue Conversation Turns' setting above. When 'Compress by LLM' is selected, the specified model will be used for context compression."
},
"llm_compress_instruction": {
"description": "Context Compression Instruction",
"hint": "If empty, the default prompt will be used."
},
"llm_compress_keep_recent": {
"description": "Keep Recent Turns When Compressing",
"hint": "Always keep the most recent N turns of conversation when compressing context."
},
"llm_compress_provider_id": {
"description": "Model Provider ID for Context Compression",
"hint": "When left empty, will fall back to the 'Truncate by Turns' strategy."
}
}
},
"others": {
"description": "Other Settings",
"provider_settings": {
@@ -161,15 +199,10 @@
"unsupported_streaming_strategy": {
"description": "Platforms Without Streaming Support",
"hint": "Select the handling method for platforms that don't support streaming responses. Real-time segmented reply sends content immediately when the system detects segment points like punctuation during streaming reception",
"labels": ["Real-time Segmented Reply", "Disable Streaming Response"]
},
"max_context_length": {
"description": "Maximum Conversation Rounds",
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
},
"dequeue_context_length": {
"description": "Dequeue Conversation Rounds",
"hint": "Number of conversation rounds to discard at once when maximum context length is exceeded"
"labels": [
"Real-time Segmented Reply",
"Disable Streaming Response"
]
},
"wake_prefix": {
"description": "Additional LLM Chat Wake Prefix",
@@ -387,7 +420,10 @@
},
"split_mode": {
"description": "Split Mode",
"labels": ["Regex", "Words List"]
"labels": [
"Regex",
"Words List"
]
},
"regex": {
"description": "Segmentation Regular Expression"
@@ -488,4 +524,4 @@
}
}
}
}
}
@@ -145,6 +145,11 @@
"message": "This plugin has been flagged as containing security risks, including unsafe code or functionalities that may cause system malfunctions or data loss. Do you wish to proceed with the installation?",
"confirm": "Continue",
"cancel": "Cancel"
},
"forceUpdate": {
"title": "No New Version Detected",
"message": "No new version detected for this plugin. Do you want to force reinstall? This will pull the latest code from the remote repository.",
"confirm": "Force Update"
}
},
"messages": {
@@ -185,7 +190,8 @@
"reloadPlugin": "Reload Extension",
"togglePlugin": "Extension",
"viewHandlers": "View Handlers",
"updateTo": "Update to"
"updateTo": "Update to",
"reinstall": "Reinstall"
},
"status": {
"hasUpdate": "New version available",
@@ -207,4 +213,4 @@
"goToManage": "Go to Manage",
"later": "Later"
}
}
}
@@ -129,6 +129,7 @@
"manualDialogPreviewLabel": "Display ID (auto generated)",
"manualDialogPreviewHint": "Generated as sourceId/modelId",
"manualModelRequired": "Please enter a model ID",
"manualModelExists": "Model already exists"
"manualModelExists": "Model already exists",
"configure": "Configure"
}
}
@@ -5,6 +5,15 @@
"title": "GitHub Proxy Address",
"subtitle": "Set the GitHub proxy address used when downloading plugins or updating AstrBot. This is effective in mainland China's network environment. Can be customized, input takes effect in real time. All addresses do not guarantee stability. If errors occur when updating plugins/projects, please first check if the proxy address is working properly.",
"label": "Select GitHub Proxy Address"
},
"proxySelector": {
"title": "GitHub Proxy",
"noProxy": "Don't use GitHub Proxy",
"useProxy": "Use GitHub Proxy",
"testConnection": "Test Connection",
"available": "Available",
"unavailable": "Unavailable",
"custom": "Custom"
}
},
"system": {
@@ -35,6 +35,7 @@
"yes": "是",
"no": "否",
"imagePreview": "图片预览",
"autoDetect": "自动检测",
"dialog": {
"confirmTitle": "确认操作",
"confirmMessage": "你确定要执行此操作吗?",
@@ -74,6 +75,7 @@
"list": {
"addItemPlaceholder": "添加新项,按回车确认添加",
"addButton": "添加",
"addMore": "添加更多",
"batchImport": "批量导入",
"batchImportTitle": "批量导入",
"batchImportLabel": "每行一个项目",
@@ -94,4 +96,4 @@
"copy": "复制",
"noData": "暂无数据"
}
}
}
@@ -28,7 +28,9 @@
"cancelSelection": "取消",
"noDescription": "无描述",
"notActivated": "未激活",
"note": "*不显示系统插件和已经在插件页禁用的插件。"
"note": "*不显示系统插件和已经在插件页禁用的插件。",
"selectedPluginsLabel": "已选择的插件:",
"allPluginsLabel": "所有插件"
},
"providerSelector": {
"notSelected": "未选择",
@@ -42,6 +44,45 @@
"clearSelectionSubtitle": "清除当前选择",
"unknownType": "未知类型",
"createProvider": "创建提供商",
"manageProviders": "提供商管理"
"manageProviders": "提供商管理",
"selectProviderPool": "选择提供商池..."
},
"personaSelector": {
"notSelected": "未选择",
"defaultPersona": "默认人格",
"buttonText": "选择人格...",
"dialogTitle": "选择人格",
"noDescription": "无描述",
"noPersonas": "暂无可用的人格",
"createPersona": "创建新人格",
"cancelSelection": "取消",
"confirmSelection": "确认选择",
"selectPersonaPool": "选择人格池..."
},
"t2iTemplateEditor": {
"buttonText": "自定义 T2I 模板",
"dialogTitle": "自定义文转图 HTML 模板",
"newTemplateNameLabel": "输入新模板名称",
"nameRequired": "名称不能为空",
"selectTemplateLabel": "选择模板",
"applied": "已应用",
"apply": "应用",
"templateEditor": "模板编辑器",
"new": "新建",
"resetBase": "重置Base",
"delete": "删除",
"save": "保存",
"livePreview": "实时预览(可能有差异)",
"refreshPreview": "刷新预览",
"syntaxHint": "支持 jinja2 语法。可用变量:text | safe(要渲染的文本), versionAstrBot 版本)",
"saveAndApply": "保存应用当前编辑模板",
"confirmReset": "确认重置",
"confirmResetMessage": "确定要将 'base' 模板恢复为默认内容吗?当前编辑器中的任何未保存更改将丢失。此操作无法撤销。",
"confirmResetButton": "确认重置",
"confirmDelete": "确认删除",
"confirmDeleteMessage": "确定要删除模板 '{name}' 吗?此操作无法撤销。",
"confirmDeleteButton": "确认删除",
"confirmAction": "确认操作",
"confirmApplyMessage": "确定要保存对 '{name}' 的修改,并将其设为新的活动模板吗?"
}
}
@@ -133,6 +133,36 @@
}
}
},
"truncate_and_compress": {
"description": "上下文管理策略",
"provider_settings": {
"max_context_length": {
"description": "最多携带对话轮数",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
},
"dequeue_context_length": {
"description": "丢弃对话轮数",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
},
"context_limit_reached_strategy": {
"description": "超出模型上下文窗口时的处理方式",
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
"hint": "当按对话轮数截断时,会根据上面\"丢弃对话轮数\"的配置丢弃最旧的 N 轮对话。当由 LLM 压缩上下文时,会使用指定的模型进行上下文压缩。"
},
"llm_compress_instruction": {
"description": "上下文压缩提示词",
"hint": "如果为空则使用默认提示词。"
},
"llm_compress_keep_recent": {
"description": "压缩时保留最近对话轮数",
"hint": "始终保留的最近 N 轮对话。"
},
"llm_compress_provider_id": {
"description": "用于上下文压缩的模型提供商 ID",
"hint": "留空时将降级为\"按对话轮数截断\"的策略。"
}
}
},
"others": {
"description": "其他配置",
"provider_settings": {
@@ -171,14 +201,7 @@
"关闭流式回复"
]
},
"max_context_length": {
"description": "最多携带对话轮数",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
},
"dequeue_context_length": {
"description": "丢弃对话轮数",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
},
"wake_prefix": {
"description": "LLM 聊天额外唤醒前缀",
"hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求"
@@ -145,6 +145,11 @@
"message": "该插件可能包含不安全的代码或功能,可能导致系统异常或数据损失等。请确认是否继续安装?",
"confirm": "继续",
"cancel": "取消"
},
"forceUpdate": {
"title": "未检测到新版本",
"message": "当前插件未检测到新版本,是否强制重新安装?这将从远程仓库拉取最新代码。",
"confirm": "强制更新"
}
},
"messages": {
@@ -185,7 +190,8 @@
"reloadPlugin": "重载插件",
"togglePlugin": "插件",
"viewHandlers": "查看行为",
"updateTo": "更新到"
"updateTo": "更新到",
"reinstall": "重新安装"
},
"status": {
"hasUpdate": "有新版本可用",
@@ -130,6 +130,7 @@
"manualDialogPreviewLabel": "显示 ID(自动生成)",
"manualDialogPreviewHint": "生成规则:源ID/模型ID",
"manualModelRequired": "请输入模型 ID",
"manualModelExists": "该模型已存在"
"manualModelExists": "该模型已存在",
"configure": "配置"
}
}
@@ -5,6 +5,15 @@
"title": "GitHub 加速地址",
"subtitle": "设置下载插件或者更新 AstrBot 时所用的 GitHub 加速地址。这在中国大陆的网络环境有效。可以自定义,输入结果实时生效。所有地址均不保证稳定性,如果在更新插件/项目时出现报错,请首先检查加速地址是否能正常使用。",
"label": "选择 GitHub 加速地址"
},
"proxySelector": {
"title": "GitHub 加速",
"noProxy": "不使用 GitHub 加速",
"useProxy": "使用 GitHub 加速",
"testConnection": "测试代理连通性",
"available": "可用",
"unavailable": "不可用",
"custom": "自定义"
}
},
"system": {
+75 -32
View File
@@ -77,14 +77,26 @@ const readmeDialog = reactive({
repoUrl: null
});
//
const forceUpdateDialog = reactive({
show: false,
extensionName: ''
});
//
const isListView = ref(false);
// localStorage false
const getInitialListViewMode = () => {
if (typeof window !== 'undefined' && window.localStorage) {
return localStorage.getItem('pluginListViewMode') === 'true';
}
return false;
};
const isListView = ref(getInitialListViewMode());
const pluginSearch = ref("");
const loading_ = ref(false);
//
const currentPage = ref(1);
const itemsPerPage = ref(6); // 6 (2 x 3)
//
const dangerConfirmDialog = ref(false);
@@ -113,7 +125,6 @@ const uploadTab = ref('file');
const showPluginFullName = ref(false);
const marketSearch = ref("");
const debouncedMarketSearch = ref("");
const filterKeys = ['name', 'desc', 'author'];
const refreshingMarket = ref(false);
const sortBy = ref('default'); // default, stars, author, updated
const sortOrder = ref('desc'); // desc () or asc ()
@@ -162,18 +173,6 @@ const pluginHeaders = computed(() => [
]);
//
const pluginMarketHeaders = computed(() => [
{ title: tm('table.headers.name'), key: 'name', maxWidth: '200px' },
{ title: tm('table.headers.description'), key: 'desc', maxWidth: '250px' },
{ title: tm('table.headers.author'), key: 'author', maxWidth: '90px' },
{ title: tm('table.headers.stars'), key: 'stars', maxWidth: '80px' },
{ title: tm('table.headers.lastUpdate'), key: 'updated_at', maxWidth: '100px' },
{ title: tm('table.headers.tags'), key: 'tags', maxWidth: '100px' },
{ title: tm('table.headers.actions'), key: 'actions', sortable: false }
]);
//
const filteredExtensions = computed(() => {
const data = Array.isArray(extension_data?.data) ? extension_data.data : [];
@@ -197,9 +196,6 @@ const filteredPlugins = computed(() => {
});
});
const pinnedPlugins = computed(() => {
return pluginMarketData.value.filter(plugin => plugin?.pinned);
});
//
const filteredMarketPlugins = computed(() => {
@@ -385,7 +381,17 @@ const handleUninstallConfirm = (options) => {
}
};
const updateExtension = async (extension_name) => {
const updateExtension = async (extension_name, forceUpdate = false) => {
//
const ext = extension_data.data?.find(e => e.name === extension_name);
//
if (!ext?.has_update && !forceUpdate) {
forceUpdateDialog.extensionName = extension_name;
forceUpdateDialog.show = true;
return;
}
loadingDialog.title = tm('status.loading');
loadingDialog.show = true;
try {
@@ -417,6 +423,14 @@ const updateExtension = async (extension_name) => {
}
};
//
const confirmForceUpdate = () => {
const name = forceUpdateDialog.extensionName;
forceUpdateDialog.show = false;
forceUpdateDialog.extensionName = '';
updateExtension(name, true);
};
const updateAllExtensions = async () => {
if (updatingAll.value || updatableExtensions.value.length === 0) return;
updatingAll.value = true;
@@ -552,14 +566,6 @@ const viewReadme = (plugin) => {
readmeDialog.show = true;
};
const open = (link) => {
if (link) {
window.open(link, '_blank');
}
};
//
const handleInstallPlugin = async (plugin) => {
if (plugin.tags && plugin.tags.includes('danger')) {
@@ -918,6 +924,13 @@ watch(marketSearch, (newVal) => {
}, 300); // 300ms
});
// localStorage
watch(isListView, (newVal) => {
if (typeof window !== 'undefined' && window.localStorage) {
localStorage.setItem('pluginListViewMode', String(newVal));
}
});
</script>
@@ -1037,8 +1050,21 @@ watch(marketSearch, (newVal) => {
<template v-slot:item.name="{ item }">
<div class="d-flex align-center py-2">
<div v-if="item.logo" class="mr-3" style="flex-shrink: 0;">
<img :src="item.logo" :alt="item.name"
style="height: 40px; width: 40px; border-radius: 8px; object-fit: cover;" />
</div>
<div v-else class="mr-3" style="flex-shrink: 0;">
<img :src="defaultPluginIcon" :alt="item.name"
style="height: 40px; width: 40px; border-radius: 8px; object-fit: cover;" />
</div>
<div>
<div class="text-subtitle-1 font-weight-medium">{{ item.name }}</div>
<div class="text-subtitle-1 font-weight-medium">
{{ item.display_name && item.display_name.length ? item.display_name : item.name }}
</div>
<div v-if="item.display_name && item.display_name.length" class="text-caption text-medium-emphasis mt-1">
{{ item.name }}
</div>
<div v-if="item.reserved" class="d-flex align-center mt-1">
<v-chip color="primary" size="x-small" class="font-weight-medium">{{ tm('status.system')
}}</v-chip>
@@ -1048,7 +1074,7 @@ watch(marketSearch, (newVal) => {
</template>
<template v-slot:item.desc="{ item }">
<div class="text-body-2 text-medium-emphasis">{{ item.desc }}</div>
<div class="text-body-2 text-medium-emphasis mt-2 mb-2" style="display: -webkit-box; -webkit-line-clamp: 3; line-clamp: 3; -webkit-box-orient: vertical; overflow: hidden; text-overflow: ellipsis;">{{ item.desc }}</div>
</template>
<template v-slot:item.version="{ item }">
@@ -1084,7 +1110,7 @@ watch(marketSearch, (newVal) => {
<v-tooltip activator="parent" location="top">{{ tm('tooltips.disable') }}</v-tooltip>
</v-btn>
<v-btn icon size="small" color="info" @click="reloadPlugin(item.name)">
<v-btn icon size="small" @click="reloadPlugin(item.name)">
<v-icon>mdi-refresh</v-icon>
<v-tooltip activator="parent" location="top">{{ tm('tooltips.reload') }}</v-tooltip>
</v-btn>
@@ -1104,8 +1130,7 @@ watch(marketSearch, (newVal) => {
<v-tooltip activator="parent" location="top">{{ tm('tooltips.viewDocs') }}</v-tooltip>
</v-btn>
<v-btn icon size="small" color="warning" @click="updateExtension(item.name)"
:v-show="item.has_update">
<v-btn icon size="small" @click="updateExtension(item.name)">
<v-icon>mdi-update</v-icon>
<v-tooltip activator="parent" location="top">{{ tm('tooltips.update') }}</v-tooltip>
</v-btn>
@@ -1772,6 +1797,24 @@ watch(marketSearch, (newVal) => {
</v-card-actions>
</v-card>
</v-dialog>
<!-- 强制更新确认对话框 -->
<v-dialog v-model="forceUpdateDialog.show" max-width="420">
<v-card class="rounded-lg">
<v-card-title class="text-h6 d-flex align-center">
<v-icon color="info" class="mr-2">mdi-information-outline</v-icon>
{{ tm('dialogs.forceUpdate.title') }}
</v-card-title>
<v-card-text>
{{ tm('dialogs.forceUpdate.message') }}
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn variant="text" @click="forceUpdateDialog.show = false">{{ tm('buttons.cancel') }}</v-btn>
<v-btn color="primary" variant="flat" @click="confirmForceUpdate">{{ tm('dialogs.forceUpdate.confirm') }}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</template>
<style scoped>
+17
View File
@@ -43,4 +43,21 @@ spec:
resources:
requests:
storage: 5Gi
# storageClassName: standard
---
# 持久化 machine-id,保持设备标识不变
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: napcat-machine-id-pvc
namespace: astrbot-ns
labels:
app: astrbot-stack
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: 10Mi # 只需存储一个 32 字节的文件
# storageClassName: standard
+63 -1
View File
@@ -17,6 +17,32 @@ spec:
labels:
app: astrbot-stack
spec:
# 设置固定主机名,避免 Pod 重启后主机名变化触发风控
hostname: napcat-host
subdomain: astrbot-stack
# 优雅关闭时间,给 NapCat 足够时间保存状态
terminationGracePeriodSeconds: 60
# 初始化容器:首次生成随机 machine-id,后续复用
initContainers:
- name: init-machine-id
image: busybox:latest
command:
- /bin/sh
- -c
- |
# 仅在 machine-id 不存在时随机生成一个
if [ ! -f /machine-id-data/machine-id ]; then
# 使用 /dev/urandom 生成随机 UUID (32位十六进制)
cat /proc/sys/kernel/random/uuid | tr -d '-' > /machine-id-data/machine-id
echo "Machine ID generated: $(cat /machine-id-data/machine-id)"
else
echo "Machine ID exists: $(cat /machine-id-data/machine-id)"
fi
volumeMounts:
- name: machine-id-data
mountPath: /machine-id-data
containers:
- name: napcat
image: mlikiowa/napcat-docker:latest
@@ -28,9 +54,19 @@ spec:
value: "1000"
- name: MODE
value: "astrbot"
- name: TZ
value: "Asia/Shanghai"
ports:
- containerPort: 6099
name: napcat-web
# 资源限制:确保 Guaranteed QoS,减少被驱逐的可能
resources:
requests:
memory: "512Mi"
cpu: "250m"
limits:
memory: "1Gi"
cpu: "1000m"
volumeMounts:
- name: shared-data
mountPath: /AstrBot/data
@@ -38,6 +74,14 @@ spec:
mountPath: /app/napcat/config
- name: napcat-qq
mountPath: /app/.config/QQ
# 挂载持久化的 machine-id
- name: machine-id-data
mountPath: /etc/machine-id
subPath: machine-id
readOnly: true
- name: localtime
mountPath: /etc/localtime
readOnly: true
- name: astrbot
image: soulter/astrbot:latest
@@ -48,9 +92,19 @@ spec:
ports:
- containerPort: 6185
name: astrbot-web
resources:
requests:
memory: "256Mi"
cpu: "100m"
limits:
memory: "512Mi"
cpu: "500m"
volumeMounts:
- name: shared-data
mountPath: /AstrBot/data
- name: localtime
mountPath: /etc/localtime
readOnly: true
volumes:
- name: shared-data
@@ -61,4 +115,12 @@ spec:
claimName: napcat-config-pvc
- name: napcat-qq
persistentVolumeClaim:
claimName: napcat-qq-pvc
claimName: napcat-qq-pvc
# 持久化 machine-id(首次随机生成,后续复用)
- name: machine-id-data
persistentVolumeClaim:
claimName: napcat-machine-id-pvc
- name: localtime
hostPath:
path: /etc/localtime
type: File
+1 -1
View File
@@ -1,6 +1,6 @@
[project]
name = "AstrBot"
version = "4.10.3"
version = "4.11.0"
description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md"
requires-python = ">=3.10"
+774
View File
@@ -0,0 +1,774 @@
"""Comprehensive tests for ContextManager."""
import sys
from pathlib import Path
from typing import Literal
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Add parent directory to path to avoid circular import issues
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from astrbot.core.agent.context.config import ContextConfig
from astrbot.core.agent.context.manager import ContextManager
from astrbot.core.agent.message import Message, TextPart
from astrbot.core.provider.entities import LLMResponse
class MockProvider:
"""模拟 Provider"""
def __init__(self):
self.provider_config = {
"id": "test_provider",
"model": "gpt-4",
"modalities": ["text", "image", "tool_use"],
}
async def text_chat(self, **kwargs):
"""模拟 LLM 调用,返回摘要"""
messages = kwargs.get("messages", [])
# 简单的摘要逻辑:返回消息数量统计
return LLMResponse(
role="assistant",
completion_text=f"历史对话包含 {len(messages) - 1} 条消息,主要讨论了技术话题。",
)
def get_model(self):
return "gpt-4"
def meta(self):
return MagicMock(id="test_provider", type="openai")
class TestContextManager:
"""Test suite for ContextManager."""
def create_message(
self, role: Literal["system", "user", "assistant", "tool"], content: str
) -> Message:
"""Helper to create a simple text message."""
return Message(role=role, content=content)
def create_messages(self, count: int) -> list[Message]:
"""Helper to create alternating user/assistant messages."""
messages = []
for i in range(count):
role = "user" if i % 2 == 0 else "assistant"
messages.append(self.create_message(role, f"Message {i}"))
return messages
# ==================== Basic Initialization Tests ====================
def test_init_with_minimal_config(self):
"""Test initialization with minimal configuration."""
config = ContextConfig()
manager = ContextManager(config)
assert manager.config == config
assert manager.token_counter is not None
assert manager.truncator is not None
assert manager.compressor is not None
def test_init_with_llm_compressor(self):
"""Test initialization with LLM-based compression."""
mock_provider = MockProvider()
config = ContextConfig(
llm_compress_provider=mock_provider, # type: ignore
llm_compress_keep_recent=5,
llm_compress_instruction="Summarize the conversation",
)
manager = ContextManager(config)
from astrbot.core.agent.context.compressor import LLMSummaryCompressor
assert isinstance(manager.compressor, LLMSummaryCompressor)
def test_init_with_truncate_compressor(self):
"""Test initialization with truncate-based compression (default)."""
config = ContextConfig(truncate_turns=3)
manager = ContextManager(config)
from astrbot.core.agent.context.compressor import TruncateByTurnsCompressor
assert isinstance(manager.compressor, TruncateByTurnsCompressor)
# ==================== Empty and Edge Cases ====================
@pytest.mark.asyncio
async def test_process_empty_messages(self):
"""Test processing an empty message list."""
config = ContextConfig()
manager = ContextManager(config)
result = await manager.process([])
assert result == []
@pytest.mark.asyncio
async def test_process_single_message(self):
"""Test processing a single message."""
config = ContextConfig()
manager = ContextManager(config)
messages = [self.create_message("user", "Hello")]
result = await manager.process(messages)
assert len(result) == 1
assert result[0].content == "Hello"
@pytest.mark.asyncio
async def test_process_with_no_limits(self):
"""Test processing when no limits are set (no truncation or compression)."""
config = ContextConfig(max_context_tokens=0, enforce_max_turns=-1)
manager = ContextManager(config)
messages = self.create_messages(20)
result = await manager.process(messages)
assert len(result) == 20
assert result == messages
# ==================== Enforce Max Turns Tests ====================
@pytest.mark.asyncio
async def test_enforce_max_turns_basic(self):
"""Test basic enforce_max_turns functionality."""
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
manager = ContextManager(config)
# Create 10 turns (20 messages)
messages = self.create_messages(20)
result = await manager.process(messages)
# Should keep only 3 most recent turns (6 messages)
assert len(result) <= 8 # May vary due to truncation logic
@pytest.mark.asyncio
async def test_enforce_max_turns_zero(self):
"""Test enforce_max_turns with value 0 (should keep nothing)."""
config = ContextConfig(enforce_max_turns=0, truncate_turns=1)
manager = ContextManager(config)
messages = self.create_messages(10)
result = await manager.process(messages)
# Should result in empty or minimal message list
assert len(result) <= 2
@pytest.mark.asyncio
async def test_enforce_max_turns_negative(self):
"""Test enforce_max_turns with -1 (no limit)."""
config = ContextConfig(enforce_max_turns=-1)
manager = ContextManager(config)
messages = self.create_messages(20)
result = await manager.process(messages)
assert len(result) == 20
@pytest.mark.asyncio
async def test_enforce_max_turns_with_system_messages(self):
"""Test enforce_max_turns preserves system messages."""
config = ContextConfig(enforce_max_turns=2, truncate_turns=1)
manager = ContextManager(config)
messages = [
self.create_message("system", "System instruction"),
*self.create_messages(10),
]
result = await manager.process(messages)
# System message should be preserved
system_msgs = [m for m in result if m.role == "system"]
assert len(system_msgs) >= 1
assert system_msgs[0].content == "System instruction"
# ==================== Token-based Compression Tests ====================
@pytest.mark.asyncio
async def test_token_compression_not_triggered_below_threshold(self):
"""Test that compression is not triggered below threshold."""
config = ContextConfig(max_context_tokens=1000)
manager = ContextManager(config)
# Create messages that total less than threshold
messages = [self.create_message("user", "Hi" * 50)] # ~100 tokens
with patch.object(
manager.compressor, "should_compress", return_value=False
) as mock_should_compress:
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager.process(messages)
# should_compress should be called
mock_should_compress.assert_called_once()
# Compressor should not be called
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_token_compression_triggered_above_threshold(self):
"""Test that compression is triggered above threshold."""
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
manager = ContextManager(config)
# Create messages that exceed threshold (0.82 * 100 = 82 tokens)
# 300 chars * 0.3 = 90 tokens > 82 threshold
long_text = "x" * 300 # ~90 tokens, above threshold
messages = [self.create_message("user", long_text)]
# Mock compressor to return smaller result
compressed = [self.create_message("user", "short")]
# Create a mock compressor
mock_compressor = AsyncMock()
mock_compressor.compression_threshold = 0.82
mock_compressor.return_value = compressed
# Mock should_compress to return True first time, False after
call_count = 0
def mock_should_compress(*args, **kwargs):
nonlocal call_count
call_count += 1
return call_count == 1
mock_compressor.should_compress = mock_should_compress
manager.compressor = mock_compressor
result = await manager.process(messages)
# Compressor should be called
mock_compressor.assert_called_once()
# Result should be the compressed version
assert len(result) <= len(messages)
@pytest.mark.asyncio
async def test_token_compression_with_zero_max_tokens(self):
"""Test that compression is skipped when max_context_tokens is 0."""
config = ContextConfig(max_context_tokens=0)
manager = ContextManager(config)
messages = [self.create_message("user", "x" * 10000)]
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager.process(messages)
# Compressor should not be called when max_context_tokens is 0
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_token_compression_with_negative_max_tokens(self):
"""Test that compression is skipped when max_context_tokens is negative."""
config = ContextConfig(max_context_tokens=-100)
manager = ContextManager(config)
messages = [self.create_message("user", "x" * 10000)]
with patch.object(
manager.compressor, "__call__", new_callable=AsyncMock
) as mock_compress:
result = await manager.process(messages)
# Compressor should not be called
mock_compress.assert_not_called()
assert result == messages
@pytest.mark.asyncio
async def test_double_check_after_compression(self):
"""Test that halving is applied if still over threshold after compression."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Create messages that would still be over threshold after compression
long_messages = [self.create_message("user", "x" * 200) for _ in range(10)]
# Mock compressor to return messages still over threshold
async def mock_compress(msgs):
return msgs # Return same messages (still over limit)
# Mock should_compress to return True twice (before and after compression)
with patch.object(manager.compressor, "should_compress", return_value=True):
with patch.object(manager.compressor, "__call__", new=mock_compress):
with patch.object(
manager.truncator,
"truncate_by_halving",
return_value=long_messages[:5],
) as mock_halving:
_ = await manager.process(long_messages)
# Halving should be called
mock_halving.assert_called_once()
# ==================== Combined Truncation and Compression Tests ====================
@pytest.mark.asyncio
async def test_combined_enforce_turns_and_token_limit(self):
"""Test combining enforce_max_turns and token limit."""
config = ContextConfig(
enforce_max_turns=5, max_context_tokens=500, truncate_turns=1
)
manager = ContextManager(config)
# Create many messages
messages = self.create_messages(30)
result = await manager.process(messages)
# Should be truncated by both mechanisms
assert len(result) < 30
@pytest.mark.asyncio
async def test_sequential_processing_order(self):
"""Test that enforce_max_turns happens before token compression."""
config = ContextConfig(enforce_max_turns=5, max_context_tokens=1000)
manager = ContextManager(config)
messages = self.create_messages(20)
# Mock the truncator to track calls
with patch.object(
manager.truncator,
"truncate_by_turns",
wraps=manager.truncator.truncate_by_turns,
) as mock_truncate:
await manager.process(messages)
# Truncator should be called first
mock_truncate.assert_called_once()
# ==================== Error Handling Tests ====================
@pytest.mark.asyncio
async def test_error_handling_returns_original_messages(self):
"""Test that errors during processing return original messages."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
messages = self.create_messages(5)
# Make compressor raise an exception
with patch.object(
manager.compressor, "__call__", side_effect=Exception("Test error")
):
result = await manager.process(messages)
# Should return original messages despite error
assert result == messages
@pytest.mark.asyncio
async def test_error_handling_logs_exception(self):
"""Test that errors are logged."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Create messages that will trigger compression (> 82 tokens)
messages = [self.create_message("user", "x" * 300)] # ~90 tokens
# Replace compressor with one that raises an exception
mock_compressor = AsyncMock(side_effect=Exception("Test error"))
mock_compressor.compression_threshold = 0.82
mock_compressor.should_compress = MagicMock(return_value=True)
manager.compressor = mock_compressor
with patch("astrbot.core.agent.context.manager.logger") as mock_logger:
result = await manager.process(messages)
# Logger error method should be called
assert mock_logger.error.called
# Should return original messages on error
assert result == messages
# ==================== Multi-modal Content Tests ====================
@pytest.mark.asyncio
async def test_process_messages_with_textpart_content(self):
"""Test processing messages with TextPart content."""
config = ContextConfig()
manager = ContextManager(config)
messages = [
Message(role="user", content=[TextPart(text="Hello")]),
Message(role="assistant", content=[TextPart(text="Hi there")]),
]
result = await manager.process(messages)
assert len(result) == 2
assert result == messages
@pytest.mark.asyncio
async def test_token_counting_with_multimodal_content(self):
"""Test token counting works with multi-modal content."""
config = ContextConfig(max_context_tokens=50)
manager = ContextManager(config)
# Need enough tokens to exceed threshold: 50 * 0.82 = 41 tokens
# 150 chars * 0.3 = 45 tokens > 41
messages = [
Message(role="user", content=[TextPart(text="x" * 150)]),
]
# Should trigger compression due to token count
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 50)
assert tokens > 0 # Tokens should be counted
assert needs_compression # Should trigger compression
# ==================== Tool Calls Tests ====================
@pytest.mark.asyncio
async def test_process_messages_with_tool_calls(self):
"""Test processing messages with tool calls."""
config = ContextConfig()
manager = ContextManager(config)
messages = [
Message(
role="assistant",
content="Let me search for that",
tool_calls=[
{
"id": "call_1",
"type": "function",
"function": {"name": "search", "arguments": "{}"},
}
],
),
Message(role="tool", content="Search result", tool_call_id="call_1"),
]
result = await manager.process(messages)
assert len(result) == 2
# ==================== Compressor should_compress Tests ====================
@pytest.mark.asyncio
async def test_should_compress_empty_messages(self):
"""Test should_compress with empty messages."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Compressor's should_compress should handle empty gracefully
needs_compression = manager.compressor.should_compress([], 0, 100)
assert not needs_compression
@pytest.mark.asyncio
async def test_should_compress_below_threshold(self):
"""Test should_compress when below compression threshold."""
config = ContextConfig(max_context_tokens=1000)
manager = ContextManager(config)
messages = [self.create_message("user", "Hello")]
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 1000)
assert not needs_compression
@pytest.mark.asyncio
async def test_should_compress_above_threshold(self):
"""Test should_compress when above compression threshold."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Create message with many tokens
messages = [self.create_message("user", "这是测试" * 50)]
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
# Should need compression if tokens > 82 (0.82 * 100)
assert needs_compression == (tokens > 82)
# ==================== Truncator Halving Tests ====================
def test_truncate_by_halving_basic(self):
"""Test truncate_by_halving removes middle 50%."""
config = ContextConfig()
manager = ContextManager(config)
messages = self.create_messages(10)
result = manager.truncator.truncate_by_halving(messages)
# Should keep roughly half
assert len(result) < len(messages)
def test_truncate_by_halving_empty_list(self):
"""Test truncate_by_halving with empty list."""
config = ContextConfig()
manager = ContextManager(config)
result = manager.truncator.truncate_by_halving([])
assert result == []
def test_truncate_by_halving_single_message(self):
"""Test truncate_by_halving with single message."""
config = ContextConfig()
manager = ContextManager(config)
messages = [self.create_message("user", "Hello")]
result = manager.truncator.truncate_by_halving(messages)
assert len(result) <= 1
# ==================== Complex Scenarios ====================
@pytest.mark.asyncio
async def test_multiple_compression_cycles(self):
"""Test that compression can be triggered multiple times in sequence."""
config = ContextConfig(max_context_tokens=50, truncate_turns=1)
manager = ContextManager(config)
# Process messages multiple times
messages = self.create_messages(10)
result1 = await manager.process(messages)
result2 = await manager.process(result1)
result3 = await manager.process(result2)
# Each cycle should maintain or reduce message count
assert len(result3) <= len(result2) <= len(result1)
@pytest.mark.asyncio
async def test_alternating_roles_preserved(self):
"""Test that user/assistant alternation is preserved after processing."""
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
manager = ContextManager(config)
messages = self.create_messages(20)
result = await manager.process(messages)
# Check that roles still alternate (excluding system messages)
non_system = [m for m in result if m.role != "system"]
if len(non_system) >= 2:
# Should start with user
assert non_system[0].role == "user"
@pytest.mark.asyncio
async def test_compression_threshold_default(self):
"""Test that compression threshold is used correctly."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
# Verify the default threshold is 0.82
assert manager.compressor.compression_threshold == 0.82
# Test threshold logic
messages = [self.create_message("user", "x" * 81)] # ~24 tokens
tokens = manager.token_counter.count_tokens(messages)
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
# Should not compress if below threshold
assert needs_compression == (tokens > 82)
@pytest.mark.asyncio
async def test_large_batch_processing(self):
"""Test processing a large batch of messages."""
config = ContextConfig(
enforce_max_turns=10, max_context_tokens=1000, truncate_turns=2
)
manager = ContextManager(config)
# Create 100 messages (50 turns)
messages = self.create_messages(100)
result = await manager.process(messages)
# Should be significantly reduced
assert len(result) < 100
assert len(result) > 0
@pytest.mark.asyncio
async def test_config_persistence(self):
"""Test that config settings are respected throughout processing."""
config = ContextConfig(
max_context_tokens=500,
enforce_max_turns=5,
truncate_turns=2,
llm_compress_keep_recent=3,
)
manager = ContextManager(config)
# Verify config is stored
assert manager.config.max_context_tokens == 500
assert manager.config.enforce_max_turns == 5
assert manager.config.truncate_turns == 2
assert manager.config.llm_compress_keep_recent == 3
# ==================== Run Compression Tests ====================
@pytest.mark.asyncio
async def test_run_compression_calls_compressor(self):
"""Test _run_compression calls compressor."""
config = ContextConfig(max_context_tokens=100)
manager = ContextManager(config)
messages = self.create_messages(5)
compressed = self.create_messages(3)
# Create a mock compressor
mock_compressor = AsyncMock()
mock_compressor.compression_threshold = 0.82
mock_compressor.return_value = compressed
mock_compressor.should_compress = MagicMock(return_value=False)
manager.compressor = mock_compressor
result = await manager._run_compression(messages, prev_tokens=100)
# Compressor __call__ should be invoked
mock_compressor.assert_called_once_with(messages)
assert result == compressed
@pytest.mark.asyncio
async def test_run_compression_applies_compressor_through_process(self):
"""Test _run_compression calls compressor when needed through process()."""
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
manager = ContextManager(config)
# Create messages that will trigger compression
messages = [self.create_message("user", "x" * 300)] # ~90 tokens > 82 threshold
compressed = [self.create_message("user", "short")] # Much smaller
# Create a mock compressor
mock_compressor = AsyncMock()
mock_compressor.compression_threshold = 0.82
mock_compressor.return_value = compressed
# Mock should_compress to return True first time, False after
call_count = 0
def mock_should_compress(*args, **kwargs):
nonlocal call_count
call_count += 1
return call_count == 1
mock_compressor.should_compress = mock_should_compress
manager.compressor = mock_compressor
result = await manager.process(messages)
# Compressor should have been called
mock_compressor.assert_called_once()
assert len(result) <= len(messages)
@pytest.mark.asyncio
async def test_llm_compression_with_mock_provider(self):
"""Test LLM compression using MockProvider."""
mock_provider = MockProvider()
config = ContextConfig(
llm_compress_provider=mock_provider, # type: ignore
llm_compress_keep_recent=3,
llm_compress_instruction="请总结对话内容",
max_context_tokens=100,
)
manager = ContextManager(config)
# Create messages that will trigger compression
messages = [
self.create_message("user", "x" * 100),
self.create_message("assistant", "y" * 100),
self.create_message("user", "z" * 100),
]
result = await manager.process(messages)
# Should have been compressed
assert len(result) <= len(messages)
# ==================== split_history Tests ====================
def test_split_history_ensures_user_start(self):
"""Test split_history ensures recent_messages starts with user message."""
from astrbot.core.agent.context.compressor import split_history
# Create alternating messages: user, assistant, user, assistant, user, assistant
messages = [
self.create_message("system", "System prompt"),
self.create_message("user", "msg1"),
self.create_message("assistant", "msg2"),
self.create_message("user", "msg3"),
self.create_message("assistant", "msg4"),
self.create_message("user", "msg5"),
self.create_message("assistant", "msg6"),
]
# Keep recent 3 messages - should adjust to start with user
system, to_summarize, recent = split_history(messages, keep_recent=3)
# recent_messages should start with user message
assert len(recent) > 0
assert recent[0].role == "user"
# messages_to_summarize should end with assistant (complete turn)
if len(to_summarize) > 0:
assert to_summarize[-1].role == "assistant"
def test_split_history_handles_assistant_at_split_point(self):
"""Test split_history when assistant message is at the intended split point."""
from astrbot.core.agent.context.compressor import split_history
messages = [
self.create_message("user", "msg1"),
self.create_message("assistant", "msg2"),
self.create_message("user", "msg3"),
self.create_message("assistant", "msg4"), # <- intended split here
self.create_message("user", "msg5"),
self.create_message("assistant", "msg6"),
]
# keep_recent=2 would normally split at index 4 (assistant msg4)
# Should move back to include from msg5 (user)
system, to_summarize, recent = split_history(messages, keep_recent=2)
# recent should start with user message
assert recent[0].role == "user"
assert recent[0].content == "msg5"
def test_split_history_all_assistant_messages(self):
"""Test split_history when there are consecutive assistant messages."""
from astrbot.core.agent.context.compressor import split_history
messages = [
self.create_message("user", "msg1"),
self.create_message("assistant", "msg2"),
self.create_message("assistant", "msg3"),
self.create_message("assistant", "msg4"),
]
system, to_summarize, recent = split_history(messages, keep_recent=2)
# Should find the user message and keep from there
if len(recent) > 0:
# Find first user message backwards
assert any(m.role == "user" for m in messages)
def test_split_history_with_system_messages(self):
"""Test split_history preserves system messages separately."""
from astrbot.core.agent.context.compressor import split_history
messages = [
self.create_message("system", "System 1"),
self.create_message("system", "System 2"),
self.create_message("user", "msg1"),
self.create_message("assistant", "msg2"),
self.create_message("user", "msg3"),
]
system, to_summarize, recent = split_history(messages, keep_recent=2)
# System messages should be separate
assert len(system) == 2
assert all(m.role == "system" for m in system)
# Recent should start with user
if len(recent) > 0:
assert recent[0].role == "user"
+423
View File
@@ -0,0 +1,423 @@
"""Tests for ContextTruncator."""
from astrbot.core.agent.context.truncator import ContextTruncator
from astrbot.core.agent.message import Message
class TestContextTruncator:
"""Test suite for ContextTruncator."""
def create_message(self, role: str, content: str = "test content") -> Message:
"""Helper to create a simple test message."""
return Message(role=role, content=content)
def create_messages(
self, count: int, include_system: bool = False
) -> list[Message]:
"""Helper to create alternating user/assistant messages.
Args:
count: Number of messages to create
include_system: Whether to include a system message at the start
Returns:
List of messages
"""
messages = []
if include_system:
messages.append(self.create_message("system", "System prompt"))
for i in range(count):
role = "user" if i % 2 == 0 else "assistant"
messages.append(self.create_message(role, f"Message {i}"))
return messages
# ==================== fix_messages Tests ====================
def test_fix_messages_empty_list(self):
"""Test fix_messages with an empty list."""
truncator = ContextTruncator()
result = truncator.fix_messages([])
assert result == []
def test_fix_messages_normal_messages(self):
"""Test fix_messages with normal user/assistant messages."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Hello"),
self.create_message("assistant", "Hi"),
self.create_message("user", "How are you?"),
]
result = truncator.fix_messages(messages)
assert len(result) == 3
assert result == messages
def test_fix_messages_tool_with_valid_context(self):
"""Test fix_messages with tool message after user+assistant."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 3
assert result == messages
def test_fix_messages_tool_without_context(self):
"""Test fix_messages with tool message without enough context."""
truncator = ContextTruncator()
messages = [
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
# Tool message without context should be removed
assert len(result) == 0
def test_fix_messages_tool_with_only_one_message(self):
"""Test fix_messages with tool message after only one message."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Hello"),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
# Tool message without enough context should be removed
assert len(result) == 0
def test_fix_messages_multiple_tools(self):
"""Test fix_messages with multiple tool messages."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool 1 result"),
self.create_message("tool", "Tool 2 result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 4
assert result == messages
def test_fix_messages_mixed_system_tool(self):
"""Test fix_messages with system message and tool messages."""
truncator = ContextTruncator()
messages = [
self.create_message("system", "System prompt"),
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool result"),
]
result = truncator.fix_messages(messages)
assert len(result) == 4
assert result == messages
# ==================== truncate_by_turns Tests ====================
def test_truncate_by_turns_no_limit(self):
"""Test truncate_by_turns with -1 (no limit)."""
truncator = ContextTruncator()
messages = self.create_messages(20)
result = truncator.truncate_by_turns(messages, keep_most_recent_turns=-1)
assert len(result) == 20
assert result == messages
def test_truncate_by_turns_basic(self):
"""Test basic truncate_by_turns functionality."""
truncator = ContextTruncator()
# Create 10 messages = 5 turns (user/assistant pairs)
messages = self.create_messages(10)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=3, drop_turns=1
)
# Should keep 3 most recent turns (6 messages)
assert len(result) <= 8 # (3-1+1)*2 = 6, but may adjust for correct format
def test_truncate_by_turns_with_system_message(self):
"""Test truncate_by_turns preserves system messages."""
truncator = ContextTruncator()
messages = self.create_messages(10, include_system=True)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=2, drop_turns=1
)
# System message should always be preserved
assert result[0].role == "system"
assert result[0].content == "System prompt"
def test_truncate_by_turns_zero_keep(self):
"""Test truncate_by_turns with keep_most_recent_turns=0."""
truncator = ContextTruncator()
messages = self.create_messages(10)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=0, drop_turns=1
)
# Should result in empty or minimal list
assert len(result) == 0
def test_truncate_by_turns_below_threshold(self):
"""Test truncate_by_turns when messages are below threshold."""
truncator = ContextTruncator()
# Create 4 messages = 2 turns
messages = self.create_messages(4)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=5, drop_turns=1
)
# No truncation should happen
assert len(result) == 4
assert result == messages
def test_truncate_by_turns_exact_threshold(self):
"""Test truncate_by_turns when messages exactly match threshold."""
truncator = ContextTruncator()
# Create 6 messages = 3 turns
messages = self.create_messages(6)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=3, drop_turns=1
)
# No truncation should happen
assert len(result) == 6
assert result == messages
def test_truncate_by_turns_ensures_user_first(self):
"""Test that truncate_by_turns ensures user message comes first."""
truncator = ContextTruncator()
# Create scenario where truncation might start with assistant
messages = self.create_messages(20)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=3, drop_turns=1
)
# First non-system message should be user
assert result[0].role == "user"
def test_truncate_by_turns_multiple_drop(self):
"""Test truncate_by_turns with multiple turns dropped at once."""
truncator = ContextTruncator()
messages = self.create_messages(20)
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=5, drop_turns=3
)
# Should drop 3 turns when limit exceeded
assert len(result) < len(messages)
# ==================== truncate_by_dropping_oldest_turns Tests ====================
def test_truncate_by_dropping_oldest_turns_zero(self):
"""Test truncate_by_dropping_oldest_turns with drop_turns=0."""
truncator = ContextTruncator()
messages = self.create_messages(10)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=0)
assert result == messages
def test_truncate_by_dropping_oldest_turns_negative(self):
"""Test truncate_by_dropping_oldest_turns with negative drop_turns."""
truncator = ContextTruncator()
messages = self.create_messages(10)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=-1)
assert result == messages
def test_truncate_by_dropping_oldest_turns_basic(self):
"""Test basic truncate_by_dropping_oldest_turns functionality."""
truncator = ContextTruncator()
# Create 10 messages = 5 turns
messages = self.create_messages(10)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
# Should drop 2 oldest turns (4 messages)
assert len(result) == 6
# Should start with user message
assert result[0].role == "user"
def test_truncate_by_dropping_oldest_turns_with_system(self):
"""Test truncate_by_dropping_oldest_turns preserves system messages."""
truncator = ContextTruncator()
messages = self.create_messages(10, include_system=True)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
# System message should be preserved
assert result[0].role == "system"
assert result[0].content == "System prompt"
def test_truncate_by_dropping_oldest_turns_drop_all(self):
"""Test truncate_by_dropping_oldest_turns dropping all turns."""
truncator = ContextTruncator()
# Create 4 messages = 2 turns
messages = self.create_messages(4)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
# Should drop all turns
assert len(result) == 0
def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self):
"""Test truncate_by_dropping_oldest_turns with drop_turns > available turns."""
truncator = ContextTruncator()
# Create 4 messages = 2 turns
messages = self.create_messages(4)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=5)
# Should result in empty list
assert len(result) == 0
def test_truncate_by_dropping_oldest_turns_ensures_user_first(self):
"""Test that result starts with user message after dropping."""
truncator = ContextTruncator()
messages = self.create_messages(20)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=3)
# First message should be user
if len(result) > 0:
assert result[0].role == "user"
# ==================== truncate_by_halving Tests ====================
def test_truncate_by_halving_empty(self):
"""Test truncate_by_halving with empty list."""
truncator = ContextTruncator()
result = truncator.truncate_by_halving([])
assert result == []
def test_truncate_by_halving_single_message(self):
"""Test truncate_by_halving with single message."""
truncator = ContextTruncator()
messages = [self.create_message("user", "Hello")]
result = truncator.truncate_by_halving(messages)
# Should not truncate if <= 2 messages
assert result == messages
def test_truncate_by_halving_two_messages(self):
"""Test truncate_by_halving with two messages."""
truncator = ContextTruncator()
messages = self.create_messages(2)
result = truncator.truncate_by_halving(messages)
# Should not truncate if <= 2 messages
assert result == messages
def test_truncate_by_halving_basic(self):
"""Test basic truncate_by_halving functionality."""
truncator = ContextTruncator()
# Create 20 messages
messages = self.create_messages(20)
result = truncator.truncate_by_halving(messages)
# Should delete 50% = 10 messages, keep 10
assert len(result) == 10
# First message should be user
assert result[0].role == "user"
def test_truncate_by_halving_with_system_message(self):
"""Test truncate_by_halving preserves system messages."""
truncator = ContextTruncator()
messages = self.create_messages(20, include_system=True)
result = truncator.truncate_by_halving(messages)
# System message should be preserved
assert result[0].role == "system"
assert result[0].content == "System prompt"
def test_truncate_by_halving_odd_count(self):
"""Test truncate_by_halving with odd number of messages."""
truncator = ContextTruncator()
messages = self.create_messages(11)
result = truncator.truncate_by_halving(messages)
# Should delete floor(11/2) = 5 messages, keep 6
# But after ensuring user first, may be 5
assert len(result) >= 5
assert result[0].role == "user"
def test_truncate_by_halving_ensures_user_first(self):
"""Test that result starts with user message."""
truncator = ContextTruncator()
# Create messages starting with user
messages = self.create_messages(30)
result = truncator.truncate_by_halving(messages)
# First message should be user
assert result[0].role == "user"
def test_truncate_by_halving_preserves_recent_messages(self):
"""Test that truncate_by_halving keeps the most recent 50%."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Message 0"),
self.create_message("assistant", "Message 1"),
self.create_message("user", "Message 2"),
self.create_message("assistant", "Message 3"),
]
result = truncator.truncate_by_halving(messages)
# Should keep last 2 messages
assert len(result) == 2
assert result[0].content == "Message 2"
assert result[1].content == "Message 3"
# ==================== Integration Tests ====================
def test_truncate_with_tool_messages(self):
"""Test truncation with tool messages."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run tool"),
self.create_message("assistant", "Running..."),
self.create_message("tool", "Tool result"),
self.create_message("user", "Thanks"),
self.create_message("assistant", "Welcome"),
]
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=1)
# First turn (user+assistant+tool) should be dropped
# Tool message should be cleaned up by fix_messages
assert len(result) <= 2
def test_chain_multiple_truncations(self):
"""Test chaining multiple truncation methods."""
truncator = ContextTruncator()
messages = self.create_messages(40, include_system=True)
# First: truncate by turns
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=10, drop_turns=2
)
# Then: halve
result = truncator.truncate_by_halving(result)
# Should have system message + truncated content
assert result[0].role == "system"
assert len(result) < len(messages)
def test_empty_after_system_message(self):
"""Test truncation when only system message exists."""
truncator = ContextTruncator()
messages = [self.create_message("system", "System prompt")]
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=5, drop_turns=1
)
# Should keep system message
assert len(result) == 1
assert result[0].role == "system"
def test_all_system_messages(self):
"""Test truncation with only system messages."""
truncator = ContextTruncator()
messages = [
self.create_message("system", "System 1"),
self.create_message("system", "System 2"),
]
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=0, drop_turns=1
)
# System messages should be preserved, but since there are no non-system
# messages and keep_most_recent_turns=0, result should be system messages only
assert len(result) >= 0 # May keep system messages or clear all
if len(result) > 0:
assert all(msg.role == "system" for msg in result)