1ed4d9f484
- Implemented a full test suite for ContextManager covering initialization, message processing, token-based compression, and error handling. - Added tests for ContextTruncator focusing on message fixing, truncation by turns, dropping oldest turns, and halving. - Ensured that both test suites validate edge cases and maintain expected behavior with various message types, including system and tool messages.
165 lines
5.2 KiB
Python
165 lines
5.2 KiB
Python
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
|
|
|
from ..message import Message
|
|
|
|
if TYPE_CHECKING:
|
|
from astrbot import logger
|
|
else:
|
|
try:
|
|
from astrbot import logger
|
|
except ImportError:
|
|
import logging
|
|
|
|
logger = logging.getLogger("astrbot")
|
|
|
|
if TYPE_CHECKING:
|
|
from astrbot.core.provider.provider import Provider
|
|
|
|
from ..context.truncator import ContextTruncator
|
|
|
|
|
|
@runtime_checkable
class ContextCompressor(Protocol):
    """Structural interface shared by all context compressors.

    Any object exposing an async ``__call__`` that takes a message list and
    returns a message list satisfies this protocol.
    """

    async def __call__(self, messages: list[Message]) -> list[Message]:
        """Return a compressed version of ``messages``.

        Args:
            messages: The message list before compression.

        Returns:
            The message list after compression.
        """
        ...
|
|
|
|
|
|
class TruncateByTurnsCompressor:
    """Compressor that shrinks the context by dropping the oldest turns.

    Delegates the actual work to ``ContextTruncator``.
    """

    def __init__(self, truncate_turns: int = 1):
        """Store how many turns are dropped per compression call.

        Args:
            truncate_turns: Number of oldest turns to drop on each call
                (default: 1).
        """
        self.truncate_turns = truncate_turns

    async def __call__(self, messages: list[Message]) -> list[Message]:
        """Drop the configured number of oldest turns from ``messages``.

        Args:
            messages: The message list to truncate.

        Returns:
            Whatever ``ContextTruncator.truncate_by_dropping_oldest_turns``
            returns for ``messages`` and the configured turn count.
        """
        # A fresh truncator is created on every call, as before.
        return ContextTruncator().truncate_by_dropping_oldest_turns(
            messages,
            drop_turns=self.truncate_turns,
        )
|
|
|
|
|
|
def split_history(
    messages: list[Message], keep_recent: int
) -> tuple[list[Message], list[Message], list[Message]]:
    """Split the message list into system messages, messages to summarize, and recent messages.

    The leading run of messages whose ``role`` is ``"system"`` is always kept
    apart. The remaining messages are divided so that the last ``keep_recent``
    of them are "recent" and everything before them is "to summarize".

    Args:
        messages: The original message list.
        keep_recent: The number of latest non-system messages to keep.
            Values below 0 are treated as 0.

    Returns:
        tuple: (system_messages, messages_to_summarize, recent_messages)
    """
    # Index of the first non-system message. Fall back to len(messages) so a
    # list that is empty or consists only of system messages yields no
    # non-system part (instead of treating every message as non-system).
    first_non_system = next(
        (i for i, msg in enumerate(messages) if msg.role != "system"),
        len(messages),
    )

    system_messages = messages[:first_non_system]
    non_system_messages = messages[first_non_system:]

    # A negative count makes no sense; treat it as "keep nothing".
    keep_recent = max(keep_recent, 0)

    if len(non_system_messages) <= keep_recent:
        return system_messages, [], non_system_messages

    # Use an explicit positive index: slicing with ``-keep_recent`` breaks for
    # keep_recent == 0 because ``-0`` is ``0`` (``[:-0]`` is empty and
    # ``[-0:]`` is the whole list).
    split_at = len(non_system_messages) - keep_recent
    messages_to_summarize = non_system_messages[:split_at]
    recent_messages = non_system_messages[split_at:]

    return system_messages, messages_to_summarize, recent_messages
|
|
|
|
|
|
class LLMSummaryCompressor:
    """LLM-based summary compressor.

    Uses LLM to summarize the old conversation history, keeping the latest messages.
    The leading system messages are preserved, the older messages are replaced
    by a single summary message, and the most recent messages are kept as-is.
    """

    def __init__(
        self,
        provider: "Provider",
        keep_recent: int = 4,
        instruction_text: str | None = None,
    ):
        """Initialize the LLM summary compressor.

        Args:
            provider: The LLM provider instance.
            keep_recent: The number of latest messages to keep (default: 4).
            instruction_text: Instruction sent as the final user message when
                requesting the summary. When None or empty, a built-in
                instruction is used.
        """
        self.provider = provider
        self.keep_recent = keep_recent

        self.instruction_text = instruction_text or (
            "Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
            "1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
            "2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
            "3. If there was an initial user goal, state it first and describe the current progress/status.\n"
            "4. Write the summary in the user's language.\n"
        )

    async def __call__(self, messages: list[Message]) -> list[Message]:
        """Use LLM to generate a summary of the conversation history.

        Process:
        1. Divide messages: keep the system message and the latest N messages.
        2. Send the old messages + the instruction message to the LLM.
        3. Reconstruct the message list: [system message, summary message, latest messages].

        Args:
            messages: The original message list.

        Returns:
            The compressed message list. The original ``messages`` object is
            returned unchanged when there is nothing to summarize, when the
            provider call raises, or when the provider returns an empty
            summary.
        """
        # Cheap early exit: too short to be worth splitting.
        if len(messages) <= self.keep_recent + 1:
            return messages

        system_messages, messages_to_summarize, recent_messages = split_history(
            messages, self.keep_recent
        )

        if not messages_to_summarize:
            return messages

        # build payload
        instruction_message = Message(role="user", content=self.instruction_text)
        llm_payload = [*messages_to_summarize, instruction_message]

        # generate summary
        try:
            response = await self.provider.text_chat(contexts=llm_payload)
            summary_content = response.completion_text
        except Exception as e:
            # Best effort: a failed summary must never break the conversation.
            logger.error(f"Failed to generate summary: {e}")
            return messages

        # An empty summary would discard the old history and replace it with
        # nothing useful, so treat it like a failed summary.
        if summary_content is None or not str(summary_content).strip():
            logger.warning("LLM returned an empty summary; keeping original messages.")
            return messages

        # build result
        return [
            *system_messages,
            Message(
                role="system",
                content=f"History conversation summary: {summary_content}",
            ),
            *recent_messages,
        ]
|