feat: add MockProvider for LLM compression tests

This commit is contained in:
Soulter
2026-01-05 10:57:00 +08:00
parent 1ed4d9f484
commit e23a7e2505
+53 -2
View File
@@ -13,6 +13,33 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from astrbot.core.agent.context.config import ContextConfig
from astrbot.core.agent.context.manager import ContextManager
from astrbot.core.agent.message import Message, TextPart
from astrbot.core.provider.entities import LLMResponse
class MockProvider:
"""模拟 Provider"""
def __init__(self):
self.provider_config = {
"id": "test_provider",
"model": "gpt-4",
"modalities": ["text", "image", "tool_use"],
}
async def text_chat(self, **kwargs):
"""模拟 LLM 调用,返回摘要"""
messages = kwargs.get("messages", [])
# 简单的摘要逻辑:返回消息数量统计
return LLMResponse(
role="assistant",
completion_text=f"历史对话包含 {len(messages) - 1} 条消息,主要讨论了技术话题。",
)
def get_model(self):
return "gpt-4"
def meta(self):
return MagicMock(id="test_provider", type="openai")
class TestContextManager:
@@ -46,9 +73,9 @@ class TestContextManager:
def test_init_with_llm_compressor(self):
"""Test initialization with LLM-based compression."""
mock_provider = MagicMock()
mock_provider = MockProvider()
config = ContextConfig(
llm_compress_provider=mock_provider,
llm_compress_provider=mock_provider, # type: ignore
llm_compress_keep_recent=5,
llm_compress_instruction="Summarize the conversation",
)
@@ -631,3 +658,27 @@ class TestContextManager:
# Compressor should have been called
mock_compress.assert_called_once()
assert len(result) <= len(messages)
@pytest.mark.asyncio
async def test_llm_compression_with_mock_provider(self):
"""Test LLM compression using MockProvider."""
mock_provider = MockProvider()
config = ContextConfig(
llm_compress_provider=mock_provider, # type: ignore
llm_compress_keep_recent=3,
llm_compress_instruction="请总结对话内容",
max_context_tokens=100,
)
manager = ContextManager(config)
# Create messages that will trigger compression
messages = [
self.create_message("user", "x" * 100),
self.create_message("assistant", "y" * 100),
self.create_message("user", "z" * 100),
]
result = await manager.process(messages)
# Should have been compressed
assert len(result) <= len(messages)