feat: context compressor
Co-authored-by: kawayiYokami <289104862@qq.com>
This commit is contained in:
@@ -0,0 +1,140 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from astrbot.api import logger
|
||||
|
||||
from ..message import Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.truncator import ContextTruncator
|
||||
|
||||
|
||||
class ContextCompressor(ABC):
    """Interface implemented by every context-compression strategy.

    A compressor receives a list of messages and returns a list of
    messages; how the list is shortened is up to the concrete subclass.
    """

    @abstractmethod
    async def compress(self, messages: list[Message]) -> list[Message]:
        """Compress the given message list.

        Args:
            messages: The message list to compress.

        Returns:
            The compressed message list.
        """
        ...
|
||||
|
||||
|
||||
class DefaultCompressor(ContextCompressor):
    """No-op compressor: hands the message list back untouched."""

    async def compress(self, messages: list[Message]) -> list[Message]:
        """Return ``messages`` as-is (the same list object, not a copy)."""
        return messages
|
||||
|
||||
|
||||
class TruncateByTurnsCompressor(ContextCompressor):
    """Compressor that drops the oldest conversation turns.

    A turn is taken to start at a ``user`` message and to run up to (but not
    including) the next ``user`` message, so assistant replies and tool
    messages are removed together with the user message they belong to.
    """

    def __init__(self, truncate_turns: int = 1):
        """Initialize the truncate-by-turns compressor.

        Args:
            truncate_turns: The number of oldest turns to remove on each
                ``compress`` call (default: 1). Values below 1 are treated
                as 1 when compressing.
        """
        self.truncate_turns = truncate_turns

    async def compress(self, messages: list[Message]) -> list[Message]:
        """Remove the oldest ``truncate_turns`` turns from ``messages``.

        A leading ``system`` message is always preserved, and the most recent
        turn is never removed. If there is at most one turn, the input list
        is returned unchanged.

        Args:
            messages: The original message list.

        Returns:
            The truncated message list, passed through
            ``ContextTruncator.fix_messages``.
        """
        if not messages:
            return messages

        # Delegating to truncate_by_turns(keep_most_recent_turns=0, ...) does
        # not work: its slice start is -(0 - N + 1) * 2, which is 0 for N == 1
        # (nothing removed) and drops only N - 1 turns otherwise. Drop the
        # turns directly instead. A compressor is only useful if it removes
        # something, so non-positive settings fall back to one turn.
        drop_turns = max(1, self.truncate_turns)

        has_system = messages[0].role == "system"
        head = messages[:1] if has_system else []
        body = messages[1:] if has_system else messages

        turn_starts = [i for i, msg in enumerate(body) if msg.role == "user"]
        if len(turn_starts) < 2:
            # Zero or one turn: there is no older turn to drop.
            return messages

        # Never cut past the start of the latest turn.
        cut = turn_starts[min(drop_turns, len(turn_starts) - 1)]
        return ContextTruncator().fix_messages(head + body[cut:])
|
||||
|
||||
|
||||
class LLMSummaryCompressor(ContextCompressor):
    """LLM-based summary compressor.

    Replaces the older part of the conversation with an LLM-written summary
    while keeping the leading system message and the latest messages verbatim.
    """

    def __init__(
        self,
        provider: "Provider",
        keep_recent: int = 4,
        instruction_text: str | None = None,
    ):
        """Initialize the LLM summary compressor.

        Args:
            provider: The LLM provider instance used to generate the summary.
            keep_recent: The number of latest messages to keep verbatim
                (default: 4). Negative values are treated as 0 when
                compressing.
            instruction_text: The instruction sent as the final user message
                when requesting the summary. ``None`` or an empty string
                selects the built-in default instruction.
        """
        self.provider = provider
        self.keep_recent = keep_recent

        self.instruction_text = instruction_text or (
            "Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
            "1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
            "2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
            "3. If there was an initial user goal, state it first and describe the current progress/status.\n"
            "4. Write the summary in the user's language.\n"
        )

    async def compress(self, messages: list[Message]) -> list[Message]:
        """Use the LLM to summarize the older part of the conversation.

        Process:
        1. Divide messages: keep the system message and the latest N messages.
        2. Send the old messages + the instruction message to the LLM.
        3. Reconstruct the message list: [system message, summary message, latest messages].

        The input list is returned unchanged when there is nothing to
        summarize, when the provider call raises, or when the provider
        returns an empty summary.

        Args:
            messages: The original message list.

        Returns:
            The compressed message list, or ``messages`` itself on fallback.
        """
        keep_recent = max(0, self.keep_recent)
        if len(messages) <= keep_recent + 1:
            return messages

        # keep the system message
        system_msg = messages[0] if messages[0].role == "system" else None
        start_idx = 1 if system_msg is not None else 0

        # Use an explicit index instead of negative slicing: with
        # keep_recent == 0, messages[start:-0] is empty and messages[-0:] is
        # the whole list, which silently disables compression.
        split_idx = len(messages) - keep_recent

        # Do not split between an assistant tool call and its tool results:
        # move the boundary back so the kept tail never starts with a "tool"
        # message and the summarized part never ends on an unanswered call.
        while start_idx < split_idx < len(messages) and (
            messages[split_idx].role == "tool"
        ):
            split_idx -= 1

        messages_to_summarize = messages[start_idx:split_idx]
        recent_messages = messages[split_idx:]

        if not messages_to_summarize:
            return messages

        # build payload
        instruction_message = Message(role="user", content=self.instruction_text)
        llm_payload = messages_to_summarize + [instruction_message]

        # generate summary (best effort: any failure keeps the original context)
        try:
            response = await self.provider.text_chat(contexts=llm_payload)
            summary_content = response.completion_text
        except Exception as e:
            logger.error(f"Failed to generate summary: {e}")
            return messages

        if not summary_content:
            # Replacing real history with an empty summary would only lose
            # information, so fall back to the uncompressed context.
            logger.warning("LLM returned an empty summary, context is left unchanged.")
            return messages

        # build result
        result = []
        if system_msg is not None:
            result.append(system_msg)

        result.append(
            Message(
                role="system",
                content=f"History conversation summary: {summary_content}",
            ),
        )
        result.extend(recent_messages)

        return result
|
||||
@@ -0,0 +1,141 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
from ..message import Message
|
||||
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
|
||||
from .token_counter import TokenCounter
|
||||
from .truncator import ContextTruncator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
|
||||
class ContextManager:
    """Keeps a conversation within the model's context-token budget."""

    # Fraction of ``max_context_tokens`` above which compression is triggered.
    COMPRESSION_THRESHOLD = 0.82

    def __init__(
        self,
        max_context_tokens: int = 0,
        truncate_turns: int = 1,
        llm_compress_instruction: str | None = None,
        llm_compress_keep_recent: int = 4,
        llm_compress_provider: "Provider | None" = None,
    ):
        """Set up the token counter, the truncator and the compressor.

        Two strategies exist for a context that has grown too large:
        1. Truncate by turns: remove older messages by turns.
        2. LLM-based compression: use an LLM to summarize old messages.

        The LLM strategy is selected when ``llm_compress_provider`` is given;
        otherwise truncation by turns is used.

        Args:
            max_context_tokens: The maximum context tokens. <= 0 means no limit.
            truncate_turns: For the truncate strategy. The number of turns to
                discard when truncating.
            llm_compress_instruction: The instruction text for LLM compression.
            llm_compress_keep_recent: The number of recent messages to keep
                during LLM compression.
            llm_compress_provider: The LLM provider for compression.
        """
        self.max_context_tokens = max_context_tokens
        self.truncate_turns = truncate_turns

        self.token_counter = TokenCounter()
        self.truncator = ContextTruncator()

        if not llm_compress_provider:
            self.compressor = TruncateByTurnsCompressor(truncate_turns=truncate_turns)
        else:
            self.compressor = LLMSummaryCompressor(
                provider=llm_compress_provider,
                keep_recent=llm_compress_keep_recent,
                instruction_text=llm_compress_instruction,
            )

    async def process(self, messages: list[Message]) -> list[Message]:
        """Check the token usage of ``messages`` and compress them if needed.

        Args:
            messages: The original message list.

        Returns:
            The processed message list (``messages`` itself when no limit is
            configured or the threshold is not exceeded).
        """
        if self.max_context_tokens <= 0:
            return messages

        should_compress, _ = await self._initial_token_check(messages)
        return await self._run_compression(messages, should_compress)

    async def _initial_token_check(
        self, messages: list[Message]
    ) -> tuple[bool, int | None]:
        """Decide whether the messages exceed the compression threshold.

        Args:
            messages: The original message list.

        Returns:
            tuple: (whether to compress, estimated token count). The count is
            ``None`` whenever compression is not needed.
        """
        if not messages or self.max_context_tokens <= 0:
            return False, None

        total_tokens = self.token_counter.count_tokens(messages)

        logger.debug(
            f"ContextManager: total tokens = {total_tokens}, max_context_tokens = {self.max_context_tokens}"
        )

        if total_tokens / self.max_context_tokens > self.COMPRESSION_THRESHOLD:
            return True, total_tokens
        return False, None

    async def _run_compression(
        self, messages: list[Message], needs_compression: bool
    ) -> list[Message]:
        """Run the configured compressor, then halve if still over threshold.

        Args:
            messages: The original message list.
            needs_compression: Whether to compress.

        Returns:
            The compressed/truncated message list.
        """
        if not needs_compression or self.max_context_tokens <= 0:
            return messages

        compressed = await self.compressor.compress(messages)

        # Second check: the compressor may not have freed enough tokens.
        usage_after = (
            self.token_counter.count_tokens(compressed) / self.max_context_tokens
        )
        if usage_after > self.COMPRESSION_THRESHOLD:
            return self._compress_by_halving(compressed)
        return compressed

    def _compress_by_halving(self, messages: list[Message]) -> list[Message]:
        """Halving strategy: delegate to ``ContextTruncator.truncate_by_halving``.

        Args:
            messages: The original message list.

        Returns:
            The truncated message list.
        """
        return self.truncator.truncate_by_halving(messages)
|
||||
@@ -0,0 +1,30 @@
|
||||
import json
|
||||
|
||||
from ..message import Message, TextPart
|
||||
|
||||
|
||||
class TokenCounter:
    """Tokenizer-free, character-based token estimator for message lists."""

    def count_tokens(self, messages: list[Message]) -> int:
        """Estimate the total number of tokens in ``messages``.

        String content, the text parts of list (multimodal) content, and
        JSON-serialized tool calls are counted. Other content is ignored.

        Args:
            messages: The messages to measure.

        Returns:
            The estimated token count.
        """
        total = 0
        for msg in messages:
            content = msg.content
            if isinstance(content, str):
                total += self._estimate_tokens(content)
            elif isinstance(content, list):
                # Multimodal content: only the text parts are counted.
                for part in content:
                    if isinstance(part, TextPart):
                        total += self._estimate_tokens(part.text)

            # Tool calls are measured on their JSON form.
            if msg.tool_calls:
                for tc in msg.tool_calls:
                    payload = tc if isinstance(tc, dict) else tc.model_dump()
                    # ensure_ascii=False keeps CJK characters as single
                    # characters. The default would expand each of them into a
                    # 6-character "\uXXXX" escape and inflate the estimate.
                    tc_str = json.dumps(payload, ensure_ascii=False)
                    total += self._estimate_tokens(tc_str)

        return total

    def _estimate_tokens(self, text: str) -> int:
        """Estimate the token count of ``text``.

        Characters in U+4E00..U+9FFF count as 0.6 token each, every other
        character as 0.3 token; the sum is truncated to an int.
        """
        chinese_count = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
        other_count = len(text) - chinese_count
        return int(chinese_count * 0.6 + other_count * 0.3)
|
||||
@@ -0,0 +1,94 @@
|
||||
from ..message import Message
|
||||
|
||||
|
||||
class ContextTruncator:
    """Context truncator: shortens message lists while keeping them well-formed."""

    def fix_messages(self, messages: list[Message]) -> list[Message]:
        """Remove tool messages that lost their preceding context.

        A tool message is expected to be preceded by a user and an assistant
        message. If fewer than two messages have been kept when a tool message
        is met (typically because the context was truncated), everything kept
        so far is discarded together with that tool message.

        Args:
            messages: The message list to validate.

        Returns:
            A new, repaired message list.
        """
        fixed_messages = []
        for message in messages:
            if message.role == "tool" and len(fixed_messages) < 2:
                # Orphaned tool message: clear the context gathered so far.
                fixed_messages = []
            else:
                fixed_messages.append(message)
        return fixed_messages

    def truncate_by_turns(
        self,
        messages: list[Message],
        keep_most_recent_turns: int,
        dequeue_turns: int = 1,
    ) -> list[Message]:
        """Truncate the context so it does not exceed the maximum length.

        One turn consists of one user message and one assistant message.
        The truncated list starts at a user message where one exists, keeps a
        leading system message, and is passed through ``fix_messages``.

        NOTE: the kept tail is ``(keep_most_recent_turns - dequeue_turns + 1)``
        turns long. If that value is zero or negative, the slice start is no
        longer negative and counts from the front of the list instead.

        Args:
            messages: The context list.
            keep_most_recent_turns: The number of most recent turns to keep;
                -1 disables truncation.
            dequeue_turns: The number of turns discarded at once.

        Returns:
            The truncated context list.
        """
        # The empty-list guard avoids an IndexError on messages[0] below.
        if not messages or keep_most_recent_turns == -1:
            return messages
        if len(messages) <= keep_most_recent_turns:
            return messages
        if len(messages) // 2 <= keep_most_recent_turns:
            return messages

        system_message = None
        if messages[0].role == "system":
            system_message = messages[0]
            messages = messages[1:]

        truncated_contexts = messages[
            -(keep_most_recent_turns - dequeue_turns + 1) * 2 :
        ]
        # Start at the first user message so the context format stays valid.
        index = next(
            (i for i, item in enumerate(truncated_contexts) if item.role == "user"),
            None,
        )
        if index is not None and index > 0:
            truncated_contexts = truncated_contexts[index:]

        if system_message is not None:
            truncated_contexts = [system_message] + truncated_contexts

        return self.fix_messages(truncated_contexts)

    def truncate_by_halving(
        self,
        messages: list[Message],
    ) -> list[Message]:
        """Halving strategy: delete the older 50% of the non-system messages.

        Leading system messages are always preserved. The remaining tail is
        aligned to start at its first user message (if any) and the result is
        passed through ``fix_messages``.

        Args:
            messages: The original message list.

        Returns:
            The truncated message list (``messages`` itself if it has at most
            two entries).
        """
        if len(messages) <= 2:
            return messages

        # Defaulting to len(messages) means an all-system list loses nothing.
        first_non_system = next(
            (i for i, msg in enumerate(messages) if msg.role != "system"),
            len(messages),
        )
        system_head = messages[:first_non_system]
        tail = messages[first_non_system:]
        tail = tail[len(tail) // 2 :]

        # Align only the tail. Searching the combined list for the first user
        # message would slice off the system messages preserved above.
        index = next(
            (i for i, item in enumerate(tail) if item.role == "user"),
            None,
        )
        if index is not None:
            tail = tail[index:]

        return self.fix_messages(system_head + tail)
|
||||
@@ -25,6 +25,8 @@ from astrbot.core.provider.entities import (
|
||||
)
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.manager import ContextManager
|
||||
from ..context.truncator import ContextTruncator
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||
from ..response import AgentResponseData, AgentStats
|
||||
@@ -51,6 +53,33 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
|
||||
# enforce max turns, will discard older turns when exceeded BEFORE compression
|
||||
# -1 means no limit
|
||||
self.enforce_max_turns = kwargs.get("enforce_max_turns", -1)
|
||||
|
||||
# llm compressor
|
||||
self.llm_compress_instruction = kwargs.get("llm_compress_instruction", None)
|
||||
self.llm_compress_keep_recent = kwargs.get("llm_compress_keep_recent", 0)
|
||||
self.llm_compress_provider: Provider | None = kwargs.get(
|
||||
"llm_compress_provider", None
|
||||
)
|
||||
# truncate by turns compressor
|
||||
self.truncate_turns = kwargs.get("truncate_turns", 1)
|
||||
|
||||
# we will do compress when:
|
||||
# 1. before requesting LLM
|
||||
# TODO: 2. after LLM output a tool call
|
||||
self.context_manager = ContextManager(
|
||||
# <=0 will never trigger context compression
|
||||
max_context_tokens=provider.provider_config.get("max_context_tokens", 0),
|
||||
truncate_turns=self.truncate_turns,
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider=self.llm_compress_provider,
|
||||
)
|
||||
self.context_truncator = ContextTruncator()
|
||||
|
||||
self.provider = provider
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
@@ -92,6 +121,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
else:
|
||||
yield await self.provider.text_chat(**payload)
|
||||
|
||||
async def do_context_compress(self):
    """Check and perform context compression.

    Passes the current ``run_context.messages`` through the context manager
    and stores the returned list back on the run context.
    """
    self.run_context.messages = await self.context_manager.process(
        self.run_context.messages
    )
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
"""Process a single step of the agent.
|
||||
@@ -110,6 +145,24 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
llm_resp_result = None
|
||||
|
||||
# do truncate
|
||||
if self.enforce_max_turns != -1:
|
||||
try:
|
||||
truncated_messages = self.context_truncator.truncate_by_turns(
|
||||
self.run_context.messages,
|
||||
keep_most_recent_turns=self.enforce_max_turns,
|
||||
dequeue_turns=self.truncate_turns,
|
||||
)
|
||||
self.run_context.messages = truncated_messages
|
||||
except Exception as e:
|
||||
logger.error(f"Error during context truncation: {e}", exc_info=True)
|
||||
|
||||
# check compress
|
||||
try:
|
||||
await self.do_context_compress()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during context compression: {e}", exc_info=True)
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
if llm_response.is_chunk:
|
||||
# update ttft
|
||||
|
||||
@@ -83,6 +83,16 @@ DEFAULT_CONFIG = {
|
||||
"default_personality": "default",
|
||||
"persona_pool": ["*"],
|
||||
"prompt_prefix": "{{prompt}}",
|
||||
"context_limit_reached_strategy": "truncate_by_turns", # or llm_compress
|
||||
"llm_compress_instruction": (
|
||||
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
|
||||
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
|
||||
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
|
||||
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
|
||||
"4. Write the summary in the user's language.\n"
|
||||
),
|
||||
"llm_compress_keep_recent": 4,
|
||||
"llm_compress_provider_id": "",
|
||||
"max_context_length": -1,
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
@@ -179,6 +189,7 @@ class ChatProviderTemplate(TypedDict):
|
||||
model: str
|
||||
modalities: list
|
||||
custom_extra_body: dict[str, Any]
|
||||
max_context_tokens: int
|
||||
|
||||
|
||||
CHAT_PROVIDER_TEMPLATE = {
|
||||
@@ -187,6 +198,7 @@ CHAT_PROVIDER_TEMPLATE = {
|
||||
"model": "",
|
||||
"modalities": [],
|
||||
"custom_extra_body": {},
|
||||
"max_context_tokens": 0,
|
||||
}
|
||||
|
||||
"""
|
||||
@@ -2033,6 +2045,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
|
||||
},
|
||||
"max_context_tokens": {
|
||||
"description": "模型上下文窗口大小",
|
||||
"type": "int",
|
||||
"hint": "模型最大上下文 Token 大小。如果为 0,则会自动从模型元数据填充(如有),也可手动修改。",
|
||||
},
|
||||
"dify_api_key": {
|
||||
"description": "API Key",
|
||||
"type": "string",
|
||||
@@ -2540,6 +2557,66 @@ CONFIG_METADATA_3 = {
|
||||
# "provider_settings.enable": True,
|
||||
# },
|
||||
# },
|
||||
"truncate_and_compress": {
|
||||
"description": "上下文管理策略",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"provider_settings.max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.context_limit_reached_strategy": {
|
||||
"description": "超出模型上下文窗口时的处理方式",
|
||||
"type": "string",
|
||||
"options": ["truncate_by_turns", "llm_compress"],
|
||||
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
"hint": "",
|
||||
},
|
||||
"provider_settings.llm_compress_instruction": {
|
||||
"description": "上下文压缩提示词",
|
||||
"type": "text",
|
||||
"hint": "如果为空则使用默认提示词。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.llm_compress_keep_recent": {
|
||||
"description": "压缩时保留最近对话轮数",
|
||||
"type": "int",
|
||||
"hint": "始终保留的最近 N 轮对话。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.llm_compress_provider_id": {
|
||||
"description": "用于上下文压缩的模型提供商 ID",
|
||||
"type": "string",
|
||||
"_special": "select_provider",
|
||||
"hint": "留空时将降级为“按对话轮数截断”的策略。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
"type": "object",
|
||||
@@ -2604,22 +2681,6 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.streaming_response": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀 ",
|
||||
"type": "string",
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""本地 Agent 模式的 LLM 调用 Stage"""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
@@ -24,6 +23,7 @@ from astrbot.core.provider.entities import (
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType, star_map
|
||||
from astrbot.core.utils.file_extract import extract_file_moonshotai
|
||||
from astrbot.core.utils.llm_metadata import LLM_METADATAS
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
@@ -41,11 +41,6 @@ class InternalAgentSubStage(Stage):
|
||||
self.ctx = ctx
|
||||
conf = ctx.astrbot_config
|
||||
settings = conf["provider_settings"]
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
)
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.unsupported_streaming_strategy: str = settings[
|
||||
"unsupported_streaming_strategy"
|
||||
@@ -65,6 +60,23 @@ class InternalAgentSubStage(Stage):
|
||||
"moonshotai_api_key", ""
|
||||
)
|
||||
|
||||
# 上下文管理相关
|
||||
self.context_limit_reached_strategy: str = settings.get(
|
||||
"context_limit_reached_strategy", "truncate_by_turns"
|
||||
)
|
||||
self.llm_compress_instruction: str = settings.get(
|
||||
"llm_compress_instruction", ""
|
||||
)
|
||||
self.llm_compress_keep_recent: int = settings.get("llm_compress_keep_recent", 4)
|
||||
self.llm_compress_provider_id: str = settings.get(
|
||||
"llm_compress_provider_id", ""
|
||||
)
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
)
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
def _select_provider(self, event: AstrMessageEvent):
|
||||
@@ -167,34 +179,6 @@ class InternalAgentSubStage(Stage):
|
||||
},
|
||||
)
|
||||
|
||||
def _truncate_contexts(
|
||||
self,
|
||||
contexts: list[dict],
|
||||
) -> list[dict]:
|
||||
"""截断上下文列表,确保不超过最大长度"""
|
||||
if self.max_context_length == -1:
|
||||
return contexts
|
||||
|
||||
if len(contexts) // 2 <= self.max_context_length:
|
||||
return contexts
|
||||
|
||||
truncated_contexts = contexts[
|
||||
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
||||
]
|
||||
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(
|
||||
i
|
||||
for i, item in enumerate(truncated_contexts)
|
||||
if item.get("role") == "user"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
truncated_contexts = truncated_contexts[index:]
|
||||
|
||||
return truncated_contexts
|
||||
|
||||
def _modalities_fix(
|
||||
self,
|
||||
provider: Provider,
|
||||
@@ -328,21 +312,25 @@ class InternalAgentSubStage(Stage):
|
||||
history=message_to_save,
|
||||
)
|
||||
|
||||
def _fix_messages(self, messages: list[dict]) -> list[dict]:
|
||||
"""验证并且修复上下文"""
|
||||
fixed_messages = []
|
||||
for message in messages:
|
||||
if message.get("role") == "tool":
|
||||
# tool block 前面必须要有 user 和 assistant block
|
||||
if len(fixed_messages) < 2:
|
||||
# 这种情况可能是上下文被截断导致的
|
||||
# 我们直接将之前的上下文都清空
|
||||
fixed_messages = []
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
return fixed_messages
|
||||
def _get_compress_provider(self) -> Provider | None:
|
||||
if not self.llm_compress_provider_id:
|
||||
return None
|
||||
if self.context_limit_reached_strategy != "llm_compress":
|
||||
return None
|
||||
provider = self.ctx.plugin_manager.context.get_provider_by_id(
|
||||
self.llm_compress_provider_id,
|
||||
)
|
||||
if provider is None:
|
||||
logger.warning(
|
||||
f"未找到指定的上下文压缩模型 {self.llm_compress_provider_id},将跳过压缩。",
|
||||
)
|
||||
return None
|
||||
if not isinstance(provider, Provider):
|
||||
logger.warning(
|
||||
f"指定的上下文压缩模型 {self.llm_compress_provider_id} 不是对话模型,将跳过压缩。"
|
||||
)
|
||||
return None
|
||||
return provider
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, provider_wake_prefix: str
|
||||
@@ -426,9 +414,10 @@ class InternalAgentSubStage(Stage):
|
||||
await self._apply_kb(event, req)
|
||||
|
||||
# truncate contexts to fit max length
|
||||
if req.contexts:
|
||||
req.contexts = self._truncate_contexts(req.contexts)
|
||||
self._fix_messages(req.contexts)
|
||||
# NOW moved to ContextManager inside ToolLoopAgentRunner
|
||||
# if req.contexts:
|
||||
# req.contexts = self._truncate_contexts(req.contexts)
|
||||
# self._fix_messages(req.contexts)
|
||||
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
@@ -444,8 +433,6 @@ class InternalAgentSubStage(Stage):
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
# 备份 req.contexts
|
||||
backup_contexts = copy.deepcopy(req.contexts)
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
@@ -456,6 +443,15 @@ class InternalAgentSubStage(Stage):
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
|
||||
# inject model context length limit
|
||||
if provider.provider_config.get("max_context_tokens", 0) <= 0:
|
||||
model = provider.get_model()
|
||||
if model_info := LLM_METADATAS.get(model):
|
||||
provider.provider_config["max_context_tokens"] = model_info[
|
||||
"limit"
|
||||
]["context"]
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
@@ -466,6 +462,11 @@ class InternalAgentSubStage(Stage):
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=streaming_response,
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider=self._get_compress_provider(),
|
||||
truncate_turns=self.dequeue_context_length,
|
||||
enforce_max_turns=self.max_context_length,
|
||||
)
|
||||
|
||||
if streaming_response and not stream_to_general:
|
||||
@@ -511,9 +512,6 @@ class InternalAgentSubStage(Stage):
|
||||
):
|
||||
yield
|
||||
|
||||
# 恢复备份的 contexts
|
||||
req.contexts = backup_contexts
|
||||
|
||||
await self._save_to_history(
|
||||
event,
|
||||
req,
|
||||
|
||||
@@ -144,7 +144,7 @@
|
||||
color="primary"
|
||||
density="compact"
|
||||
hide-details
|
||||
class="flex-grow-1"
|
||||
style="flex: 1"
|
||||
></v-slider>
|
||||
<v-text-field
|
||||
:model-value="modelValue"
|
||||
@@ -154,7 +154,7 @@
|
||||
class="config-field"
|
||||
type="number"
|
||||
hide-details
|
||||
style="max-width: 140px;"
|
||||
style="flex: 1"
|
||||
></v-text-field>
|
||||
</div>
|
||||
|
||||
@@ -325,4 +325,8 @@ function getSpecialSubtype(value) {
|
||||
.gap-20 {
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
:deep(.v-field__input) {
|
||||
font-size: 14px;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -510,7 +510,7 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
|
||||
|
||||
const metadata = getModelMetadata(modelName)
|
||||
let modalities: string[]
|
||||
|
||||
|
||||
if (!metadata) {
|
||||
modalities = ['text', 'image', 'tool_use']
|
||||
} else {
|
||||
@@ -523,13 +523,19 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
|
||||
}
|
||||
}
|
||||
|
||||
let max_context_tokens = 0
|
||||
if (metadata?.limit?.context && typeof metadata.limit.context === 'number') {
|
||||
max_context_tokens = metadata.limit.context
|
||||
}
|
||||
|
||||
const newProvider = {
|
||||
id: newId,
|
||||
enable: false,
|
||||
provider_source_id: sourceId,
|
||||
model: modelName,
|
||||
modalities,
|
||||
custom_extra_body: {}
|
||||
custom_extra_body: {},
|
||||
max_context_tokens: max_context_tokens
|
||||
}
|
||||
|
||||
try {
|
||||
|
||||
@@ -11,7 +11,12 @@
|
||||
},
|
||||
"agent_runner_type": {
|
||||
"description": "Runner",
|
||||
"labels": ["Built-in Agent", "Dify", "Coze", "Alibaba Cloud Bailian Application"]
|
||||
"labels": [
|
||||
"Built-in Agent",
|
||||
"Dify",
|
||||
"Coze",
|
||||
"Alibaba Cloud Bailian Application"
|
||||
]
|
||||
},
|
||||
"coze_agent_runner_provider_id": {
|
||||
"description": "Coze Agent Runner Provider ID"
|
||||
@@ -128,6 +133,39 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"truncate_and_compress": {
|
||||
"description": "Context Management Strategy",
|
||||
"provider_settings": {
|
||||
"max_context_length": {
|
||||
"description": "Maximum Conversation Turns",
|
||||
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "Dequeue Conversation Turns",
|
||||
"hint": "Number of conversation turns to discard at once when maximum context length is exceeded"
|
||||
},
|
||||
"context_limit_reached_strategy": {
|
||||
"description": "Handling When Model Context Window is Exceeded",
|
||||
"labels": [
|
||||
"Truncate by Turns",
|
||||
"Compress by LLM"
|
||||
],
|
||||
"hint": "When 'Truncate by Turns' is selected, the oldest N conversation turns will be discarded based on the 'Dequeue Conversation Turns' setting above. When 'Compress by LLM' is selected, the specified model will be used for context compression."
|
||||
},
|
||||
"llm_compress_instruction": {
|
||||
"description": "Context Compression Instruction",
|
||||
"hint": "If empty, the default prompt will be used."
|
||||
},
|
||||
"llm_compress_keep_recent": {
|
||||
"description": "Keep Recent Turns When Compressing",
|
||||
"hint": "Always keep the most recent N turns of conversation when compressing context."
|
||||
},
|
||||
"llm_compress_provider_id": {
|
||||
"description": "Model Provider ID for Context Compression",
|
||||
"hint": "When left empty, will fall back to the 'Truncate by Turns' strategy."
|
||||
}
|
||||
}
|
||||
},
|
||||
"others": {
|
||||
"description": "Other Settings",
|
||||
"provider_settings": {
|
||||
@@ -161,15 +199,10 @@
|
||||
"unsupported_streaming_strategy": {
|
||||
"description": "Platforms Without Streaming Support",
|
||||
"hint": "Select the handling method for platforms that don't support streaming responses. Real-time segmented reply sends content immediately when the system detects segment points like punctuation during streaming reception",
|
||||
"labels": ["Real-time Segmented Reply", "Disable Streaming Response"]
|
||||
},
|
||||
"max_context_length": {
|
||||
"description": "Maximum Conversation Rounds",
|
||||
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "Dequeue Conversation Rounds",
|
||||
"hint": "Number of conversation rounds to discard at once when maximum context length is exceeded"
|
||||
"labels": [
|
||||
"Real-time Segmented Reply",
|
||||
"Disable Streaming Response"
|
||||
]
|
||||
},
|
||||
"wake_prefix": {
|
||||
"description": "Additional LLM Chat Wake Prefix",
|
||||
@@ -387,7 +420,10 @@
|
||||
},
|
||||
"split_mode": {
|
||||
"description": "Split Mode",
|
||||
"labels": ["Regex", "Words List"]
|
||||
"labels": [
|
||||
"Regex",
|
||||
"Words List"
|
||||
]
|
||||
},
|
||||
"regex": {
|
||||
"description": "Segmentation Regular Expression"
|
||||
@@ -488,4 +524,4 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -133,6 +133,36 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"truncate_and_compress": {
|
||||
"description": "上下文管理策略",
|
||||
"provider_settings": {
|
||||
"max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
|
||||
},
|
||||
"context_limit_reached_strategy": {
|
||||
"description": "超出模型上下文窗口时的处理方式",
|
||||
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
|
||||
"hint": "当按对话轮数截断时,会根据上面\"丢弃对话轮数\"的配置丢弃最旧的 N 轮对话。当由 LLM 压缩上下文时,会使用指定的模型进行上下文压缩。"
|
||||
},
|
||||
"llm_compress_instruction": {
|
||||
"description": "上下文压缩提示词",
|
||||
"hint": "如果为空则使用默认提示词。"
|
||||
},
|
||||
"llm_compress_keep_recent": {
|
||||
"description": "压缩时保留最近对话轮数",
|
||||
"hint": "始终保留的最近 N 轮对话。"
|
||||
},
|
||||
"llm_compress_provider_id": {
|
||||
"description": "用于上下文压缩的模型提供商 ID",
|
||||
"hint": "留空时将降级为\"按对话轮数截断\"的策略。"
|
||||
}
|
||||
}
|
||||
},
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
"provider_settings": {
|
||||
@@ -171,14 +201,7 @@
|
||||
"关闭流式回复"
|
||||
]
|
||||
},
|
||||
"max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
|
||||
},
|
||||
|
||||
"wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀",
|
||||
"hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求"
|
||||
|
||||
Reference in New Issue
Block a user