diff --git a/tests/agent/test_context_manager.py b/tests/agent/test_context_manager.py index d1577f64d..9e7b6b425 100644 --- a/tests/agent/test_context_manager.py +++ b/tests/agent/test_context_manager.py @@ -13,6 +13,33 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from astrbot.core.agent.context.config import ContextConfig from astrbot.core.agent.context.manager import ContextManager from astrbot.core.agent.message import Message, TextPart +from astrbot.core.provider.entities import LLMResponse + + +class MockProvider: + """模拟 Provider""" + + def __init__(self): + self.provider_config = { + "id": "test_provider", + "model": "gpt-4", + "modalities": ["text", "image", "tool_use"], + } + + async def text_chat(self, **kwargs): + """模拟 LLM 调用,返回摘要""" + messages = kwargs.get("messages", []) + # 简单的摘要逻辑:返回消息数量统计 + return LLMResponse( + role="assistant", + completion_text=f"历史对话包含 {len(messages) - 1} 条消息,主要讨论了技术话题。", + ) + + def get_model(self): + return "gpt-4" + + def meta(self): + return MagicMock(id="test_provider", type="openai") class TestContextManager: @@ -46,9 +73,9 @@ class TestContextManager: def test_init_with_llm_compressor(self): """Test initialization with LLM-based compression.""" - mock_provider = MagicMock() + mock_provider = MockProvider() config = ContextConfig( - llm_compress_provider=mock_provider, + llm_compress_provider=mock_provider, # type: ignore llm_compress_keep_recent=5, llm_compress_instruction="Summarize the conversation", ) @@ -631,3 +658,27 @@ class TestContextManager: # Compressor should have been called mock_compress.assert_called_once() assert len(result) <= len(messages) + + @pytest.mark.asyncio + async def test_llm_compression_with_mock_provider(self): + """Test LLM compression using MockProvider.""" + mock_provider = MockProvider() + config = ContextConfig( + llm_compress_provider=mock_provider, # type: ignore + llm_compress_keep_recent=3, + llm_compress_instruction="请总结对话内容", + max_context_tokens=100, + ) + manager = ContextManager(config) + + # Create messages that will trigger compression + messages = [ + self.create_message("user", "x" * 100), + self.create_message("assistant", "y" * 100), + self.create_message("user", "z" * 100), + ] + + result = await manager.process(messages) + + # Should have been compressed + assert len(result) <= len(messages)