1514 lines
55 KiB
Python
1514 lines
55 KiB
Python
"""Tests for astr_main_agent module."""
|
|
|
|
import os
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from astrbot.core import astr_main_agent as ama
|
|
from astrbot.core.agent.mcp_client import MCPTool
|
|
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
|
from astrbot.core.conversation_mgr import Conversation
|
|
from astrbot.core.message.components import File, Image, Plain, Reply
|
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|
from astrbot.core.platform.platform_metadata import PlatformMetadata
|
|
from astrbot.core.provider import Provider
|
|
from astrbot.core.provider.entities import ProviderRequest
|
|
|
|
|
|
@pytest.fixture
def mock_provider():
    """Build a Provider mock advertising the image and tool_use modalities.

    The mock is spec'd against ``Provider`` so it passes
    ``isinstance(..., Provider)`` checks; ``get_model()`` reports ``"gpt-4"``.
    """
    fake = MagicMock(spec=Provider)
    fake.provider_config = dict(
        id="test-provider",
        modalities=["image", "tool_use"],
    )
    fake.get_model.return_value = "gpt-4"
    return fake
|
|
|
|
|
|
@pytest.fixture
def mock_context():
    """Build a plugin Context mock with an empty config and no personas.

    ``persona_manager.resolve_selected_persona`` is an AsyncMock returning
    ``(None, None, None, False)``; individual tests override it when they
    need a concrete persona. No subagent orchestrator is attached.
    """
    context = MagicMock()
    context.get_config.return_value = {}
    context.conversation_manager = MagicMock()

    personas = MagicMock()
    personas.personas_v3 = []
    personas.resolve_selected_persona = AsyncMock(
        return_value=(None, None, None, False)
    )
    context.persona_manager = personas

    context.get_llm_tool_manager.return_value = MagicMock()
    context.subagent_orchestrator = None
    return context
|
|
|
|
|
|
@pytest.fixture
def mock_event():
    """Build an AstrMessageEvent mock for a private "Hello" message.

    The event belongs to ``test_platform``, has no group, and is sent by
    ``user123`` / ``TestUser``. ``get_extra`` returns ``None`` for every key.
    """
    meta = PlatformMetadata(
        id="test_platform",
        name="test_platform",
        description="Test platform",
    )

    raw_message = MagicMock()
    raw_message.message = [Plain(text="Hello")]
    raw_message.sender = MagicMock(user_id="user123", nickname="TestUser")
    raw_message.group_id = None
    raw_message.group = None

    evt = MagicMock(spec=AstrMessageEvent)
    evt.message_str = "Hello"
    evt.message_obj = raw_message
    evt.platform_meta = meta
    evt.session_id = "session123"
    evt.unified_msg_origin = "test_platform:private:session123"
    evt.trace = MagicMock()
    evt.plugins_name = None

    # Stubbed accessor methods.
    evt.get_extra.return_value = None
    evt.get_platform_name.return_value = "test_platform"
    evt.get_platform_id.return_value = "test_platform"
    evt.get_group_id.return_value = None
    evt.get_sender_name.return_value = "TestUser"
    return evt
|
|
|
|
|
|
@pytest.fixture
def mock_conversation():
    """Provide a spec'd Conversation mock with cid ``"conv-id"``.

    The mock has no persona and an empty ``"[]"`` history. It delegates to
    ``_new_mock_conversation`` so this fixture and the helper used by the
    build tests cannot drift apart.
    """
    return _new_mock_conversation()
|
|
|
|
|
|
@pytest.fixture
def sample_config():
    """MainAgentBuildConfig with streaming and Moonshot file extraction enabled."""
    options = {
        "tool_call_timeout": 60,
        "streaming_response": True,
        "file_extract_enabled": True,
        "file_extract_prov": "moonshotai",
        "file_extract_msh_api_key": "test-api-key",
    }
    return ama.MainAgentBuildConfig(**options)
|
|
|
|
|
|
def _new_mock_conversation(cid: str = "conv-id") -> MagicMock:
    """Return a spec'd Conversation mock for ``cid``.

    The mock has no persona and its ``history`` is the string ``"[]"``.
    """
    mocked = MagicMock(spec=Conversation)
    mocked.cid, mocked.persona_id, mocked.history = cid, None, "[]"
    return mocked
|
|
|
|
|
|
def _setup_conversation_for_build(conv_mgr, cid: str = "conv-id") -> MagicMock:
    """Wire ``conv_mgr`` so a fresh conversation ``cid`` is created and fetched.

    No current conversation is reported, ``new_conversation`` yields ``cid``,
    and ``get_conversation`` returns a Conversation mock for it.

    Returns:
        The Conversation mock that ``get_conversation`` resolves to.
    """
    created = _new_mock_conversation(cid=cid)
    conv_mgr.get_curr_conversation_id = AsyncMock(return_value=None)
    conv_mgr.new_conversation = AsyncMock(return_value=cid)
    conv_mgr.get_conversation = AsyncMock(return_value=created)
    return created
|
|
|
|
|
|
class TestMainAgentBuildConfig:
    """Construction and field defaults of MainAgentBuildConfig."""

    def test_config_initialization(self):
        """With only tool_call_timeout given, every other field keeps its default."""
        cfg = ama.MainAgentBuildConfig(tool_call_timeout=60)

        assert cfg.tool_call_timeout == 60
        assert cfg.tool_schema_mode == "full"
        assert cfg.provider_wake_prefix == ""
        assert cfg.streaming_response is True
        assert cfg.sanitize_context_by_modalities is False
        assert cfg.kb_agentic_mode is False
        assert cfg.file_extract_enabled is False
        assert cfg.llm_safety_mode is True

    def test_config_with_custom_values(self):
        """Explicitly supplied values are stored unchanged."""
        overrides = dict(
            tool_call_timeout=120,
            tool_schema_mode="skills-like",
            provider_wake_prefix="/",
            streaming_response=False,
            kb_agentic_mode=True,
            file_extract_enabled=True,
            computer_use_runtime="sandbox",
            add_cron_tools=False,
        )
        cfg = ama.MainAgentBuildConfig(**overrides)

        assert cfg.tool_call_timeout == 120
        assert cfg.tool_schema_mode == "skills-like"
        assert cfg.provider_wake_prefix == "/"
        assert cfg.streaming_response is False
        assert cfg.kb_agentic_mode is True
        assert cfg.file_extract_enabled is True
        assert cfg.computer_use_runtime == "sandbox"
        assert cfg.add_cron_tools is False
|
|
|
|
|
|
class TestSelectProvider:
    """Provider resolution performed by ``_select_provider``."""

    @staticmethod
    def _extra_with_selected(provider_id):
        """Return a ``get_extra`` side effect exposing only ``selected_provider``."""

        def _get_extra(k):
            return provider_id if k == "selected_provider" else None

        return _get_extra

    def test_select_provider_by_id(self, mock_event, mock_context, mock_provider):
        """A ``selected_provider`` extra is resolved through get_provider_by_id."""
        mock_event.get_extra.side_effect = self._extra_with_selected("test-provider")
        mock_context.get_provider_by_id.return_value = mock_provider

        chosen = ama._select_provider(mock_event, mock_context)

        assert chosen == mock_provider
        mock_context.get_provider_by_id.assert_called_once_with("test-provider")

    def test_select_provider_not_found(self, mock_event, mock_context):
        """An unknown selected id yields None."""
        mock_event.get_extra.side_effect = self._extra_with_selected("non-existent")
        mock_context.get_provider_by_id.return_value = None

        assert ama._select_provider(mock_event, mock_context) is None

    def test_select_provider_invalid_type(self, mock_event, mock_context):
        """A lookup result that is not a Provider instance yields None."""
        mock_event.get_extra.side_effect = self._extra_with_selected("invalid")
        mock_context.get_provider_by_id.return_value = "not a provider"

        assert ama._select_provider(mock_event, mock_context) is None

    def test_select_provider_fallback(self, mock_event, mock_context, mock_provider):
        """Without a selection, the session's in-use provider is returned."""
        mock_event.get_extra.return_value = None
        mock_context.get_using_provider.return_value = mock_provider

        chosen = ama._select_provider(mock_event, mock_context)

        assert chosen == mock_provider
        mock_context.get_using_provider.assert_called_once_with(
            umo=mock_event.unified_msg_origin
        )

    def test_select_provider_fallback_error(self, mock_event, mock_context):
        """A ValueError from the fallback lookup yields None."""
        mock_event.get_extra.return_value = None
        mock_context.get_using_provider.side_effect = ValueError("Test error")

        assert ama._select_provider(mock_event, mock_context) is None
|
|
|
|
|
|
class TestGetSessionConv:
    """Tests for ``_get_session_conv``: lookup, creation, retry and failure."""

    @pytest.mark.asyncio
    async def test_get_session_conv_existing(
        self, mock_event, mock_context, mock_conversation
    ):
        """An existing current conversation id is looked up and returned."""
        conv_mgr = mock_context.conversation_manager
        conv_mgr.get_curr_conversation_id = AsyncMock(return_value="existing-conv-id")
        conv_mgr.get_conversation = AsyncMock(return_value=mock_conversation)

        result = await ama._get_session_conv(mock_event, mock_context)

        assert result == mock_conversation
        conv_mgr.get_curr_conversation_id.assert_called_once_with(
            mock_event.unified_msg_origin
        )
        conv_mgr.get_conversation.assert_called_once_with(
            mock_event.unified_msg_origin, "existing-conv-id"
        )

    @pytest.mark.asyncio
    async def test_get_session_conv_create_new(self, mock_event, mock_context):
        """With no current conversation, a new one is created and then fetched."""
        conv_mgr = mock_context.conversation_manager
        conv_mgr.get_curr_conversation_id = AsyncMock(return_value=None)
        conv_mgr.new_conversation = AsyncMock(return_value="new-conv-id")
        created = _new_mock_conversation(cid="new-conv-id")
        conv_mgr.get_conversation = AsyncMock(return_value=created)

        result = await ama._get_session_conv(mock_event, mock_context)

        assert result == created
        conv_mgr.new_conversation.assert_called_once_with(
            mock_event.unified_msg_origin, mock_event.get_platform_id()
        )

    @pytest.mark.asyncio
    async def test_get_session_conv_retry(self, mock_event, mock_context):
        """A current id that resolves to nothing triggers one re-creation."""
        conv_mgr = mock_context.conversation_manager
        conv_mgr.get_curr_conversation_id = AsyncMock(return_value="conv-id")
        conv_mgr.new_conversation = AsyncMock(return_value="retry-conv-id")
        recreated = _new_mock_conversation(cid="retry-conv-id")
        # First lookup misses; the lookup after re-creation succeeds.
        conv_mgr.get_conversation = AsyncMock(side_effect=[None, recreated])

        result = await ama._get_session_conv(mock_event, mock_context)

        assert result == recreated
        assert conv_mgr.new_conversation.call_count == 1
        assert conv_mgr.get_conversation.call_count == 2

    @pytest.mark.asyncio
    async def test_get_session_conv_failure(self, mock_event, mock_context):
        """RuntimeError is raised when a freshly created conversation can't be fetched."""
        conv_mgr = mock_context.conversation_manager
        conv_mgr.get_curr_conversation_id = AsyncMock(return_value=None)
        conv_mgr.new_conversation = AsyncMock(return_value="new-conv-id")
        conv_mgr.get_conversation = AsyncMock(return_value=None)

        with pytest.raises(RuntimeError, match="无法创建新的对话。"):
            await ama._get_session_conv(mock_event, mock_context)
|
|
|
|
|
|
class TestApplyKb:
    """Knowledge-base handling performed by ``_apply_kb``."""

    _RETRIEVE = "astrbot.core.astr_main_agent.retrieve_knowledge_base"

    @staticmethod
    def _cfg(agentic):
        """Build a config with ``kb_agentic_mode`` set to ``agentic``."""
        return ama.MainAgentBuildConfig(tool_call_timeout=60, kb_agentic_mode=agentic)

    @pytest.mark.asyncio
    async def test_apply_kb_without_agentic_mode(self, mock_event, mock_context):
        """Non-agentic mode appends retrieved results to the system prompt."""
        request = ProviderRequest(
            prompt="test question", system_prompt="System prompt"
        )

        with patch(self._RETRIEVE, AsyncMock(return_value="KB result")):
            await ama._apply_kb(mock_event, request, mock_context, self._cfg(False))

        assert "[Related Knowledge Base Results]:" in request.system_prompt
        assert "KB result" in request.system_prompt

    @pytest.mark.asyncio
    async def test_apply_kb_with_agentic_mode(self, mock_event, mock_context):
        """Agentic mode leaves ``func_tool`` populated."""
        request = ProviderRequest(prompt="test question")

        await ama._apply_kb(mock_event, request, mock_context, self._cfg(True))

        assert request.func_tool is not None

    @pytest.mark.asyncio
    async def test_apply_kb_no_prompt(self, mock_event, mock_context):
        """With no prompt, the system prompt is left unchanged."""
        request = ProviderRequest(prompt=None, system_prompt="System")

        await ama._apply_kb(mock_event, request, mock_context, self._cfg(False))

        assert request.system_prompt == "System"

    @pytest.mark.asyncio
    async def test_apply_kb_no_result(self, mock_event, mock_context):
        """An empty retrieval result leaves the system prompt unchanged."""
        request = ProviderRequest(prompt="test", system_prompt="System")

        with patch(self._RETRIEVE, AsyncMock(return_value=None)):
            await ama._apply_kb(mock_event, request, mock_context, self._cfg(False))

        assert request.system_prompt == "System"

    @pytest.mark.asyncio
    async def test_apply_kb_with_existing_tools(self, mock_event, mock_context):
        """Agentic mode works when the request already carries a ToolSet."""
        preset = ToolSet()
        request = ProviderRequest(prompt="test", func_tool=preset)

        await ama._apply_kb(mock_event, request, mock_context, self._cfg(True))

        assert request.func_tool is not None
|
|
|
|
|
|
class TestApplyFileExtract:
    """File-content extraction performed by ``_apply_file_extract``."""

    _EXTRACT = "astrbot.core.astr_main_agent.extract_file_moonshotai"

    @staticmethod
    def _file(name):
        """Return a File mock whose ``get_file`` resolves to ``/path/to/<name>``."""
        attachment = MagicMock(spec=File)
        attachment.name = name
        attachment.get_file = AsyncMock(return_value=f"/path/to/{name}")
        return attachment

    @pytest.mark.asyncio
    async def test_file_extract_basic(self, mock_event, sample_config):
        """A top-level File yields one context entry containing the results."""
        mock_event.message_obj.message = [self._file("test.pdf")]
        request = ProviderRequest(prompt="Summarize")

        with patch(self._EXTRACT) as extract:
            extract.return_value = "File content"
            await ama._apply_file_extract(mock_event, request, sample_config)

        assert len(request.contexts) == 1
        assert "File Extract Results" in request.contexts[0]["content"]

    @pytest.mark.asyncio
    async def test_file_extract_no_files(self, mock_event, sample_config):
        """A message without files adds no context."""
        mock_event.message_obj.message = [Plain(text="Hello")]
        request = ProviderRequest(prompt="Hello")

        await ama._apply_file_extract(mock_event, request, sample_config)

        assert len(request.contexts) == 0

    @pytest.mark.asyncio
    async def test_file_extract_in_reply(self, mock_event, sample_config):
        """A File nested in a Reply chain is also extracted."""
        quoted = MagicMock(spec=Reply)
        quoted.chain = [self._file("reply.pdf")]
        mock_event.message_obj.message = [quoted]
        request = ProviderRequest(prompt="Summarize")

        with patch(self._EXTRACT) as extract:
            extract.return_value = "Reply content"
            await ama._apply_file_extract(mock_event, request, sample_config)

        assert len(request.contexts) == 1

    @pytest.mark.asyncio
    async def test_file_extract_no_prompt(self, mock_event, sample_config):
        """A missing prompt is replaced with the default summarise request."""
        mock_event.message_obj.message = [self._file("test.pdf")]
        request = ProviderRequest(prompt=None)

        with patch(self._EXTRACT) as extract:
            extract.return_value = "Content"
            await ama._apply_file_extract(mock_event, request, sample_config)

        assert request.prompt == "总结一下文件里面讲了什么?"

    @pytest.mark.asyncio
    async def test_file_extract_no_api_key(self, mock_event):
        """Without a Moonshot API key no extraction context is added."""
        keyless = ama.MainAgentBuildConfig(
            tool_call_timeout=60,
            file_extract_enabled=True,
            file_extract_msh_api_key="",
        )
        mock_event.message_obj.message = [self._file("test.pdf")]
        request = ProviderRequest(prompt="Summarize")

        await ama._apply_file_extract(mock_event, request, keyless)

        assert len(request.contexts) == 0
|
|
|
|
|
|
class TestEnsurePersonaAndSkills:
    """Persona prompt and persona-tool handling in ``_ensure_persona_and_skills``."""

    @staticmethod
    def _stub_persona(ctx, persona, resolved):
        """Register ``persona`` (if given) and make resolution return ``resolved``."""
        manager = ctx.persona_manager
        manager.personas_v3 = [persona] if persona is not None else []
        manager.resolve_selected_persona = AsyncMock(return_value=resolved)

    @pytest.mark.asyncio
    async def test_ensure_persona_from_session(self, mock_event, mock_context):
        """A persona selected by session config contributes its prompt."""
        persona = {"name": "test-persona", "prompt": "You are helpful."}
        self._stub_persona(
            mock_context, persona, ("test-persona", persona, "test-persona", False)
        )
        mock_event.trace = MagicMock(record=MagicMock())
        request = ProviderRequest()
        request.conversation = MagicMock(persona_id=None)

        await ama._ensure_persona_and_skills(request, {}, mock_context, mock_event)

        assert "You are helpful." in request.system_prompt

    @pytest.mark.asyncio
    async def test_ensure_persona_from_conversation(self, mock_event, mock_context):
        """A persona set on the conversation contributes its prompt."""
        persona = {"name": "conv-persona", "prompt": "Custom persona."}
        self._stub_persona(mock_context, persona, ("conv-persona", persona, None, False))
        request = ProviderRequest()
        request.conversation = MagicMock(persona_id="conv-persona")

        await ama._ensure_persona_and_skills(request, {}, mock_context, mock_event)

        assert "Custom persona." in request.system_prompt

    @pytest.mark.asyncio
    async def test_ensure_persona_none_explicit(self, mock_event, mock_context):
        """The ``[%None]`` id explicitly disables persona instructions."""
        self._stub_persona(mock_context, None, ("[%None]", None, None, False))
        request = ProviderRequest()
        request.conversation = MagicMock(persona_id="[%None]")

        await ama._ensure_persona_and_skills(request, {}, mock_context, mock_event)

        assert "Persona Instructions" not in request.system_prompt

    @pytest.mark.asyncio
    async def test_ensure_tools_from_persona(self, mock_event, mock_context):
        """Tools listed by the persona end up in ``func_tool``."""
        tool = MagicMock()
        tool.name = "test_tool"
        tool.active = True
        persona = {"name": "persona", "prompt": "Test", "tools": ["test_tool"]}
        self._stub_persona(mock_context, persona, ("persona", persona, None, False))
        mock_context.get_llm_tool_manager.return_value.get_func.return_value = tool
        request = ProviderRequest()
        request.conversation = MagicMock(persona_id="persona")

        await ama._ensure_persona_and_skills(request, {}, mock_context, mock_event)

        assert request.func_tool is not None
|
|
|
|
|
|
class TestDecorateLlmRequest:
    """Prompt decoration performed by ``_decorate_llm_request``.

    The ``mock_context`` fixture already makes ``get_config()`` return ``{}``,
    so no per-test patching of it is required.
    """

    @pytest.mark.asyncio
    async def test_decorate_llm_request_basic(
        self, mock_event, mock_context, sample_config
    ):
        """Without a prefix, prompt and system prompt pass through untouched."""
        request = ProviderRequest(prompt="Hello", system_prompt="System")

        await ama._decorate_llm_request(
            mock_event, request, mock_context, sample_config
        )

        assert request.prompt == "Hello"
        assert request.system_prompt == "System"

    @pytest.mark.asyncio
    async def test_decorate_llm_request_with_prefix(self, mock_event, mock_context):
        """A plain ``prompt_prefix`` is prepended to the prompt."""
        request = ProviderRequest(prompt="Hello")
        cfg = ama.MainAgentBuildConfig(
            tool_call_timeout=60, provider_settings={"prompt_prefix": "AI: "}
        )

        await ama._decorate_llm_request(mock_event, request, mock_context, cfg)

        assert request.prompt == "AI: Hello"

    @pytest.mark.asyncio
    async def test_decorate_llm_request_prefix_with_placeholder(
        self, mock_event, mock_context
    ):
        """A ``{{prompt}}`` placeholder in the prefix is replaced by the prompt."""
        request = ProviderRequest(prompt="Hello")
        cfg = ama.MainAgentBuildConfig(
            tool_call_timeout=60,
            provider_settings={"prompt_prefix": "AI {{prompt}} - Please respond:"},
        )

        await ama._decorate_llm_request(mock_event, request, mock_context, cfg)

        assert request.prompt == "AI Hello - Please respond:"

    @pytest.mark.asyncio
    async def test_decorate_llm_request_no_conversation(self, mock_event, mock_context):
        """A request without a conversation is decorated without error."""
        request = ProviderRequest(prompt="Hello")
        request.conversation = None
        cfg = ama.MainAgentBuildConfig(tool_call_timeout=60)

        await ama._decorate_llm_request(mock_event, request, mock_context, cfg)

        assert request.prompt == "Hello"
|
|
|
|
|
|
class TestModalitiesFix:
    """Capability downgrades applied by ``_modalities_fix``."""

    @staticmethod
    def _dummy_tools():
        """Return a ToolSet containing one argument-less FunctionTool."""
        tools = ToolSet()
        tools.add_tool(
            FunctionTool(
                name="dummy_tool",
                description="dummy",
                parameters={"type": "object", "properties": {}},
            )
        )
        return tools

    def test_modalities_fix_image_not_supported(self, mock_provider):
        """Images are dropped and replaced by a placeholder in the prompt."""
        mock_provider.provider_config = {"modalities": ["text"]}
        request = ProviderRequest(prompt="Hello", image_urls=["/path/to/image.jpg"])

        ama._modalities_fix(mock_provider, request)

        assert "[图片]" in request.prompt
        assert request.image_urls == []

    def test_modalities_fix_tool_not_supported(self, mock_provider):
        """Tools are removed when tool_use is not an advertised modality."""
        mock_provider.provider_config = {"modalities": ["text", "image"]}
        request = ProviderRequest(prompt="Hello")
        request.func_tool = self._dummy_tools()

        ama._modalities_fix(mock_provider, request)

        assert request.func_tool is None

    def test_modalities_fix_all_supported(self, mock_provider):
        """Nothing changes when every modality is supported."""
        mock_provider.provider_config = {"modalities": ["image", "tool_use"]}
        request = ProviderRequest(
            prompt="Hello",
            image_urls=["/path/to/image.jpg"],
            func_tool=self._dummy_tools(),
        )

        ama._modalities_fix(mock_provider, request)

        assert request.prompt == "Hello"
        assert len(request.image_urls) == 1
        assert request.func_tool is not None
|
|
|
|
|
|
class TestSanitizeContextByModalities:
    """History sanitisation performed by ``_sanitize_context_by_modalities``."""

    @staticmethod
    def _cfg(enabled):
        """Build a config with ``sanitize_context_by_modalities`` set to ``enabled``."""
        return ama.MainAgentBuildConfig(
            tool_call_timeout=60, sanitize_context_by_modalities=enabled
        )

    def test_sanitize_no_op(self, mock_provider):
        """Disabled sanitisation leaves the contexts intact."""
        mock_provider.provider_config = {"modalities": ["image", "tool_use"]}
        request = ProviderRequest(contexts=[{"role": "user", "content": "Hello"}])

        ama._sanitize_context_by_modalities(self._cfg(False), mock_provider, request)

        assert len(request.contexts) == 1

    def test_sanitize_removes_tool_messages(self, mock_provider):
        """``tool`` role messages are dropped when tool_use is unsupported."""
        mock_provider.provider_config = {"modalities": ["image"]}
        history = [
            {"role": "user", "content": "Hello"},
            {"role": "tool", "content": "Tool result"},
        ]
        request = ProviderRequest(contexts=history)

        ama._sanitize_context_by_modalities(self._cfg(True), mock_provider, request)

        assert len(request.contexts) == 1
        assert request.contexts[0]["role"] == "user"

    def test_sanitize_removes_tool_calls(self, mock_provider):
        """``tool_calls`` are stripped from assistant messages."""
        mock_provider.provider_config = {"modalities": ["image"]}
        assistant_turn = {
            "role": "assistant",
            "content": "Response",
            "tool_calls": [{"name": "tool"}],
        }
        request = ProviderRequest(contexts=[assistant_turn])

        ama._sanitize_context_by_modalities(self._cfg(True), mock_provider, request)

        assert "tool_calls" not in request.contexts[0]

    def test_sanitize_removes_image_blocks(self, mock_provider):
        """Image content parts are dropped when image is unsupported."""
        mock_provider.provider_config = {"modalities": ["tool_use"]}
        parts = [
            {"type": "text", "text": "Hello"},
            {"type": "image_url", "url": "image.jpg"},
        ]
        request = ProviderRequest(contexts=[{"role": "user", "content": parts}])

        ama._sanitize_context_by_modalities(self._cfg(True), mock_provider, request)

        remaining = request.contexts[0]["content"]
        assert len(remaining) == 1
        assert remaining[0]["type"] == "text"
|
|
|
|
|
|
class TestPluginToolFix:
    """Plugin-based tool filtering performed by ``_plugin_tool_fix``."""

    _STAR_MAP = "astrbot.core.astr_main_agent.star_map"

    def test_plugin_tool_fix_none_plugins(self, mock_event):
        """With no plugin restriction, the toolset is kept."""
        request = ProviderRequest(func_tool=ToolSet())
        mock_event.plugins_name = None

        ama._plugin_tool_fix(mock_event, request)

        assert request.func_tool is not None

    def test_plugin_tool_fix_filters_by_plugin(self, mock_event):
        """Tools of an enabled plugin survive alongside MCP tools."""
        mcp_tool = MagicMock(spec=MCPTool)
        mcp_tool.name = "mcp_tool"

        plugin_tool = MagicMock()
        plugin_tool.name = "plugin_tool"
        plugin_tool.handler_module_path = "test_plugin"
        plugin_tool.active = True

        tools = ToolSet()
        for tool in (mcp_tool, plugin_tool):
            tools.add_tool(tool)

        request = ProviderRequest(func_tool=tools)
        mock_event.plugins_name = ["test_plugin"]

        owner = MagicMock()
        owner.name = "test_plugin"
        owner.reserved = False

        with patch(self._STAR_MAP) as star_map:
            star_map.get.return_value = owner
            ama._plugin_tool_fix(mock_event, request)

        kept = request.func_tool.names()
        assert "mcp_tool" in kept
        assert "plugin_tool" in kept

    def test_plugin_tool_fix_mcp_preserved(self, mock_event):
        """MCP tools are kept even when their plugin is not enabled."""
        mcp_tool = MagicMock(spec=MCPTool)
        mcp_tool.name = "mcp_tool"
        mcp_tool.active = True

        tools = ToolSet()
        tools.add_tool(mcp_tool)
        request = ProviderRequest(func_tool=tools)
        mock_event.plugins_name = ["other_plugin"]

        with patch(self._STAR_MAP):
            ama._plugin_tool_fix(mock_event, request)

        assert "mcp_tool" in request.func_tool.names()
|
|
|
|
|
|
class TestBuildMainAgent:
|
|
"""Tests for build_main_agent function."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_build_main_agent_basic(
|
|
self, mock_event, mock_context, mock_provider
|
|
):
|
|
"""Test basic main agent building."""
|
|
module = ama
|
|
mock_context.get_provider_by_id.return_value = None
|
|
mock_context.get_using_provider.return_value = mock_provider
|
|
mock_context.get_config.return_value = {}
|
|
|
|
conv_mgr = mock_context.conversation_manager
|
|
_setup_conversation_for_build(conv_mgr)
|
|
|
|
with (
|
|
patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
|
|
patch("astrbot.core.astr_main_agent.AstrAgentContext"),
|
|
):
|
|
mock_runner = MagicMock()
|
|
mock_runner.reset = AsyncMock()
|
|
mock_runner_cls.return_value = mock_runner
|
|
|
|
result = await module.build_main_agent(
|
|
event=mock_event,
|
|
plugin_context=mock_context,
|
|
config=module.MainAgentBuildConfig(tool_call_timeout=60),
|
|
)
|
|
|
|
assert result is not None
|
|
assert isinstance(result, module.MainAgentBuildResult)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_build_main_agent_no_provider(self, mock_event, mock_context):
|
|
"""Test building main agent when no provider is available."""
|
|
module = ama
|
|
mock_context.get_provider_by_id.return_value = None
|
|
mock_context.get_using_provider.side_effect = ValueError("No provider")
|
|
|
|
result = await module.build_main_agent(
|
|
event=mock_event,
|
|
plugin_context=mock_context,
|
|
config=module.MainAgentBuildConfig(tool_call_timeout=60),
|
|
)
|
|
|
|
assert result is None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_build_main_agent_with_wake_prefix(
|
|
self, mock_event, mock_context, mock_provider
|
|
):
|
|
"""Test building main agent with wake prefix."""
|
|
module = ama
|
|
mock_event.message_str = "/command"
|
|
mock_context.get_provider_by_id.return_value = None
|
|
mock_context.get_using_provider.return_value = mock_provider
|
|
mock_context.get_config.return_value = {}
|
|
|
|
conv_mgr = mock_context.conversation_manager
|
|
_setup_conversation_for_build(conv_mgr)
|
|
|
|
with (
|
|
patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
|
|
patch("astrbot.core.astr_main_agent.AstrAgentContext"),
|
|
):
|
|
mock_runner = MagicMock()
|
|
mock_runner.reset = AsyncMock()
|
|
mock_runner_cls.return_value = mock_runner
|
|
|
|
result = await module.build_main_agent(
|
|
event=mock_event,
|
|
plugin_context=mock_context,
|
|
config=module.MainAgentBuildConfig(
|
|
tool_call_timeout=60, provider_wake_prefix="/"
|
|
),
|
|
)
|
|
|
|
assert result is not None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_build_main_agent_no_wake_prefix(
|
|
self, mock_event, mock_context, mock_provider
|
|
):
|
|
"""Test building main agent without matching wake prefix."""
|
|
module = ama
|
|
mock_event.message_str = "hello"
|
|
mock_context.get_provider_by_id.return_value = None
|
|
mock_context.get_using_provider.return_value = mock_provider
|
|
|
|
result = await module.build_main_agent(
|
|
event=mock_event,
|
|
plugin_context=mock_context,
|
|
config=module.MainAgentBuildConfig(
|
|
tool_call_timeout=60, provider_wake_prefix="/"
|
|
),
|
|
)
|
|
|
|
assert result is None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_build_main_agent_with_images(
|
|
self, mock_event, mock_context, mock_provider
|
|
):
|
|
"""Test building main agent with image attachments."""
|
|
module = ama
|
|
mock_image = MagicMock(spec=Image)
|
|
mock_image.convert_to_file_path = AsyncMock(return_value="/path/to/image.jpg")
|
|
mock_event.message_obj.message = [mock_image]
|
|
|
|
mock_context.get_provider_by_id.return_value = None
|
|
mock_context.get_using_provider.return_value = mock_provider
|
|
mock_context.get_config.return_value = {}
|
|
|
|
conv_mgr = mock_context.conversation_manager
|
|
_setup_conversation_for_build(conv_mgr)
|
|
|
|
with (
|
|
patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
|
|
patch("astrbot.core.astr_main_agent.AstrAgentContext"),
|
|
):
|
|
mock_runner = MagicMock()
|
|
mock_runner.reset = AsyncMock()
|
|
mock_runner_cls.return_value = mock_runner
|
|
|
|
result = await module.build_main_agent(
|
|
event=mock_event,
|
|
plugin_context=mock_context,
|
|
config=module.MainAgentBuildConfig(tool_call_timeout=60),
|
|
)
|
|
|
|
assert result is not None
|
|
|
|
@pytest.mark.asyncio
async def test_build_main_agent_no_prompt_no_images(
    self, mock_event, mock_context, mock_provider
):
    """An event with neither text nor message components builds no agent."""
    # Blank out both the text and the component list of the incoming message.
    mock_event.message_str = ""
    mock_event.message_obj.message = []

    mock_context.get_provider_by_id.return_value = None
    mock_context.get_using_provider.return_value = mock_provider
    mock_context.get_config.return_value = {}
    _setup_conversation_for_build(mock_context.conversation_manager)

    outcome = await ama.build_main_agent(
        event=mock_event,
        plugin_context=mock_context,
        config=ama.MainAgentBuildConfig(tool_call_timeout=60),
    )

    assert outcome is None
|
|
|
|
@pytest.mark.asyncio
async def test_build_main_agent_apply_reset_false(
    self, mock_event, mock_context, mock_provider
):
    """Test building main agent without applying reset.

    With ``apply_reset=False`` the runner's ``reset`` coroutine is created and
    handed back as ``result.reset_coro`` for the caller to await. It must not
    be awaited during the build itself.
    """
    module = ama
    mock_context.get_provider_by_id.return_value = None
    mock_context.get_using_provider.return_value = mock_provider
    mock_context.get_config.return_value = {}

    conv_mgr = mock_context.conversation_manager
    _setup_conversation_for_build(conv_mgr)

    with (
        patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
        patch("astrbot.core.astr_main_agent.AstrAgentContext"),
    ):
        mock_runner = MagicMock()
        mock_runner.reset = AsyncMock()
        mock_runner_cls.return_value = mock_runner

        result = await module.build_main_agent(
            event=mock_event,
            plugin_context=mock_context,
            config=module.MainAgentBuildConfig(tool_call_timeout=60),
            apply_reset=False,
        )

    reset_coro = getattr(result, "reset_coro", None)
    try:
        assert result is not None
        assert reset_coro is not None
        mock_runner.reset.assert_called_once()
        # The reset coroutine was created but must not have been awaited:
        # that is the whole point of apply_reset=False.
        mock_runner.reset.assert_not_awaited()
    finally:
        # Close the un-awaited coroutine even when an assertion above fails,
        # so Python does not emit a "coroutine was never awaited" warning.
        if reset_coro is not None:
            reset_coro.close()
|
|
|
|
@pytest.mark.asyncio
async def test_build_main_agent_with_existing_request(
    self, mock_event, mock_context, mock_provider
):
    """A pre-built ProviderRequest passed via ``req`` is exposed on the result."""
    existing_req = ProviderRequest(prompt="Existing prompt")

    def _fake_get_extra(k):
        # The event reports the existing request under "provider_request" only.
        if k == "provider_request":
            return existing_req
        return None

    mock_event.get_extra.side_effect = _fake_get_extra

    # Runner stand-in returned by the patched AgentRunner constructor.
    runner = MagicMock()
    runner.reset = AsyncMock()
    with (
        patch("astrbot.core.astr_main_agent.AgentRunner", return_value=runner),
        patch("astrbot.core.astr_main_agent.AstrAgentContext"),
    ):
        result = await ama.build_main_agent(
            event=mock_event,
            plugin_context=mock_context,
            config=ama.MainAgentBuildConfig(tool_call_timeout=60),
            provider=mock_provider,
            req=existing_req,
        )

    assert result is not None
    assert result.provider_request == existing_req
|
|
|
|
|
|
class TestHandleWebchat:
    """Tests for _handle_webchat function."""

    # Event session id; _handle_webchat looks the session up without the
    # "platform!" prefix (see test_handle_webchat_generates_title).
    SESSION_ID = "platform!webchat-session-123"

    @staticmethod
    def _make_session(display_name=None):
        """Return a platform-session mock with the given ``display_name``."""
        session = MagicMock()
        session.display_name = display_name
        return session

    @staticmethod
    def _make_title_provider(completion_text):
        """Return a Provider mock whose ``text_chat`` resolves to a response
        whose ``completion_text`` is the given value."""
        prov = MagicMock(spec=Provider)
        llm_response = MagicMock()
        llm_response.completion_text = completion_text
        prov.text_chat = AsyncMock(return_value=llm_response)
        return prov

    @staticmethod
    def _stub_db(mock_db, session):
        """Install async stubs on the patched ``db_helper``.

        ``update_platform_session`` is stubbed in every test, not only in the
        tests that expect a write, so the skip-path tests can assert it was
        never called. Left as a plain MagicMock, an unexpected call would fail
        on ``await``, and that error might be swallowed by the function's own
        exception handling, hiding the regression.
        """
        mock_db.get_platform_session_by_id = AsyncMock(return_value=session)
        mock_db.update_platform_session = AsyncMock()

    @pytest.mark.asyncio
    async def test_handle_webchat_generates_title(self, mock_event):
        """Test generating title for webchat session without display name."""
        mock_event.session_id = self.SESSION_ID
        req = ProviderRequest(prompt="What is machine learning?")
        prov = self._make_title_provider("Machine Learning Introduction")

        with patch("astrbot.core.db_helper") as mock_db:
            self._stub_db(mock_db, self._make_session())
            await ama._handle_webchat(mock_event, req, prov)

        prov.text_chat.assert_awaited_once()
        # The "platform!" prefix is stripped before the session lookup/update.
        mock_db.get_platform_session_by_id.assert_called_once_with(
            "webchat-session-123"
        )
        mock_db.update_platform_session.assert_called_once_with(
            session_id="webchat-session-123",
            display_name="Machine Learning Introduction",
        )

    @pytest.mark.asyncio
    async def test_handle_webchat_no_user_prompt(self, mock_event):
        """Test that title generation is skipped when no user prompt."""
        mock_event.session_id = self.SESSION_ID
        req = ProviderRequest(prompt=None)
        prov = MagicMock(spec=Provider)

        with patch("astrbot.core.db_helper") as mock_db:
            self._stub_db(mock_db, self._make_session())
            await ama._handle_webchat(mock_event, req, prov)

        prov.text_chat.assert_not_called()
        mock_db.update_platform_session.assert_not_called()

    @pytest.mark.asyncio
    async def test_handle_webchat_empty_user_prompt(self, mock_event):
        """Test that title generation is skipped when user prompt is empty."""
        mock_event.session_id = self.SESSION_ID
        req = ProviderRequest(prompt="")
        prov = MagicMock(spec=Provider)

        with patch("astrbot.core.db_helper") as mock_db:
            self._stub_db(mock_db, self._make_session())
            await ama._handle_webchat(mock_event, req, prov)

        prov.text_chat.assert_not_called()
        mock_db.update_platform_session.assert_not_called()

    @pytest.mark.asyncio
    async def test_handle_webchat_session_already_has_display_name(self, mock_event):
        """Test that title generation is skipped when session already has display name."""
        mock_event.session_id = self.SESSION_ID
        req = ProviderRequest(prompt="What is AI?")
        prov = MagicMock(spec=Provider)

        with patch("astrbot.core.db_helper") as mock_db:
            self._stub_db(mock_db, self._make_session("Existing Title"))
            await ama._handle_webchat(mock_event, req, prov)

        prov.text_chat.assert_not_called()
        mock_db.update_platform_session.assert_not_called()

    @pytest.mark.asyncio
    async def test_handle_webchat_no_session_found(self, mock_event):
        """Test that title generation is skipped when session is not found."""
        mock_event.session_id = self.SESSION_ID
        req = ProviderRequest(prompt="What is AI?")
        prov = MagicMock(spec=Provider)

        with patch("astrbot.core.db_helper") as mock_db:
            self._stub_db(mock_db, None)
            await ama._handle_webchat(mock_event, req, prov)

        prov.text_chat.assert_not_called()
        mock_db.update_platform_session.assert_not_called()

    @pytest.mark.asyncio
    async def test_handle_webchat_llm_returns_none_title(self, mock_event):
        """Test that title is not updated when LLM returns <None>."""
        mock_event.session_id = self.SESSION_ID
        req = ProviderRequest(prompt="hi")
        prov = self._make_title_provider("<None>")

        with patch("astrbot.core.db_helper") as mock_db:
            self._stub_db(mock_db, self._make_session())
            await ama._handle_webchat(mock_event, req, prov)

        mock_db.update_platform_session.assert_not_called()

    @pytest.mark.asyncio
    async def test_handle_webchat_llm_returns_empty_title(self, mock_event):
        """Test that title is not updated when LLM returns a whitespace-only string."""
        mock_event.session_id = self.SESSION_ID
        req = ProviderRequest(prompt="hello")
        prov = self._make_title_provider(" ")

        with patch("astrbot.core.db_helper") as mock_db:
            self._stub_db(mock_db, self._make_session())
            await ama._handle_webchat(mock_event, req, prov)

        mock_db.update_platform_session.assert_not_called()

    @pytest.mark.asyncio
    async def test_handle_webchat_llm_returns_none_response(self, mock_event):
        """Test handling when LLM returns None response."""
        mock_event.session_id = self.SESSION_ID
        req = ProviderRequest(prompt="test question")
        prov = MagicMock(spec=Provider)
        prov.text_chat = AsyncMock(return_value=None)

        with patch("astrbot.core.db_helper") as mock_db:
            self._stub_db(mock_db, self._make_session())
            await ama._handle_webchat(mock_event, req, prov)

        mock_db.update_platform_session.assert_not_called()

    @pytest.mark.asyncio
    async def test_handle_webchat_llm_returns_no_completion_text(self, mock_event):
        """Test handling when LLM response has no completion_text."""
        mock_event.session_id = self.SESSION_ID
        req = ProviderRequest(prompt="test question")
        prov = self._make_title_provider(None)

        with patch("astrbot.core.db_helper") as mock_db:
            self._stub_db(mock_db, self._make_session())
            await ama._handle_webchat(mock_event, req, prov)

        mock_db.update_platform_session.assert_not_called()

    @pytest.mark.asyncio
    async def test_handle_webchat_strips_title_whitespace(self, mock_event):
        """Test that generated title has whitespace stripped."""
        mock_event.session_id = self.SESSION_ID
        req = ProviderRequest(prompt="What is Python?")
        prov = self._make_title_provider(" Python Programming Guide ")

        with patch("astrbot.core.db_helper") as mock_db:
            self._stub_db(mock_db, self._make_session())
            await ama._handle_webchat(mock_event, req, prov)

        mock_db.update_platform_session.assert_called_once_with(
            session_id="webchat-session-123",
            display_name="Python Programming Guide",
        )

    @pytest.mark.asyncio
    async def test_handle_webchat_provider_exception_is_handled(self, mock_event):
        """Test that provider exception during title generation is handled."""
        mock_event.session_id = self.SESSION_ID
        req = ProviderRequest(prompt="What is Python?")
        prov = MagicMock(spec=Provider)
        prov.text_chat = AsyncMock(side_effect=RuntimeError("provider failed"))

        with (
            patch("astrbot.core.db_helper") as mock_db,
            patch("astrbot.core.astr_main_agent.logger") as mock_logger,
        ):
            self._stub_db(mock_db, self._make_session())
            await ama._handle_webchat(mock_event, req, prov)

        mock_logger.exception.assert_called_once()
        mock_db.update_platform_session.assert_not_called()
|
|
|
|
|
|
class TestApplyLlmSafetyMode:
    """Tests for the _apply_llm_safety_mode helper of astr_main_agent."""

    # Text that the "system_prompt" strategy is expected to inject.
    SAFE_MODE_MARKER = "You are running in Safe Mode"

    @staticmethod
    def _config(strategy="system_prompt", **extra):
        """Build a MainAgentBuildConfig with the given safety strategy."""
        return ama.MainAgentBuildConfig(
            tool_call_timeout=60,
            safety_mode_strategy=strategy,
            **extra,
        )

    def test_apply_llm_safety_mode_system_prompt_strategy(self):
        """The system_prompt strategy keeps the original prompt and adds Safe Mode."""
        cfg = self._config(llm_safety_mode=True)
        request = ProviderRequest(prompt="Test", system_prompt="Original prompt")

        ama._apply_llm_safety_mode(cfg, request)

        assert self.SAFE_MODE_MARKER in request.system_prompt
        assert "Original prompt" in request.system_prompt

    def test_apply_llm_safety_mode_prepends_safety_prompt(self):
        """The Safe Mode text is placed before the original system prompt."""
        cfg = self._config()
        request = ProviderRequest(prompt="Test", system_prompt="My custom prompt")

        ama._apply_llm_safety_mode(cfg, request)

        assert request.system_prompt.startswith(self.SAFE_MODE_MARKER)
        assert "My custom prompt" in request.system_prompt

    def test_apply_llm_safety_mode_with_none_system_prompt(self):
        """A missing (None) system prompt still receives the Safe Mode text."""
        cfg = self._config()
        request = ProviderRequest(prompt="Test", system_prompt=None)

        ama._apply_llm_safety_mode(cfg, request)

        assert self.SAFE_MODE_MARKER in request.system_prompt

    def test_apply_llm_safety_mode_unsupported_strategy(self):
        """An unknown strategy logs one warning and leaves the prompt untouched."""
        cfg = self._config("unsupported_strategy")
        request = ProviderRequest(prompt="Test", system_prompt="Original")

        with patch("astrbot.core.astr_main_agent.logger") as logger_mock:
            ama._apply_llm_safety_mode(cfg, request)

        logger_mock.warning.assert_called_once()
        warning_text = logger_mock.warning.call_args[0][0]
        assert "Unsupported llm_safety_mode strategy" in warning_text
        assert request.system_prompt == "Original"

    def test_apply_llm_safety_mode_empty_system_prompt(self):
        """An empty system prompt still receives the Safe Mode text."""
        cfg = self._config()
        request = ProviderRequest(prompt="Test", system_prompt="")

        ama._apply_llm_safety_mode(cfg, request)

        assert self.SAFE_MODE_MARKER in request.system_prompt
|
|
|
|
|
|
class TestApplySandboxTools:
    """Tests for _apply_sandbox_tools function."""

    @staticmethod
    def _sandbox_config(sandbox_cfg=None):
        """Return a sandbox-runtime build config using ``sandbox_cfg`` (default ``{}``)."""
        return ama.MainAgentBuildConfig(
            tool_call_timeout=60,
            computer_use_runtime="sandbox",
            sandbox_cfg={} if sandbox_cfg is None else sandbox_cfg,
        )

    def test_apply_sandbox_tools_creates_toolset_if_none(self):
        """Test that ToolSet is created when func_tool is None."""
        config = self._sandbox_config()
        req = ProviderRequest(prompt="Test", func_tool=None)

        ama._apply_sandbox_tools(config, req, "session-123")

        assert req.func_tool is not None
        assert isinstance(req.func_tool, ToolSet)

    def test_apply_sandbox_tools_adds_required_tools(self):
        """Test that all required sandbox tools are added."""
        config = self._sandbox_config()
        req = ProviderRequest(prompt="Test", func_tool=None)

        ama._apply_sandbox_tools(config, req, "session-123")

        tool_names = req.func_tool.names()
        assert "astrbot_execute_shell" in tool_names
        assert "astrbot_execute_ipython" in tool_names
        assert "astrbot_upload_file" in tool_names
        assert "astrbot_download_file" in tool_names

    def test_apply_sandbox_tools_adds_sandbox_prompt(self):
        """Test that sandbox mode prompt is added to system_prompt."""
        config = self._sandbox_config()
        req = ProviderRequest(prompt="Test", system_prompt="Original prompt")

        ama._apply_sandbox_tools(config, req, "session-123")

        assert "sandboxed environment" in req.system_prompt

    def test_apply_sandbox_tools_with_shipyard_booter(self):
        """Test sandbox tools with shipyard booter configuration."""
        config = self._sandbox_config(
            {
                "booter": "shipyard",
                "shipyard_endpoint": "https://shipyard.example.com",
                "shipyard_access_token": "test-token",
            }
        )
        req = ProviderRequest(prompt="Test", func_tool=None)

        # _apply_sandbox_tools writes the Shipyard credentials into os.environ.
        # monkeypatch.delenv(..., raising=False) records nothing for variables
        # that were absent, so it would never remove the values written here
        # and they would leak into later tests. patch.dict snapshots
        # os.environ and restores it exactly when the block exits.
        with patch.dict(os.environ):
            os.environ.pop("SHIPYARD_ENDPOINT", None)
            os.environ.pop("SHIPYARD_ACCESS_TOKEN", None)

            ama._apply_sandbox_tools(config, req, "session-123")

            assert (
                os.environ.get("SHIPYARD_ENDPOINT") == "https://shipyard.example.com"
            )
            assert os.environ.get("SHIPYARD_ACCESS_TOKEN") == "test-token"

    def test_apply_sandbox_tools_shipyard_missing_endpoint(self):
        """Test that shipyard config is skipped when endpoint is missing."""
        config = self._sandbox_config(
            {
                "booter": "shipyard",
                "shipyard_endpoint": "",
                "shipyard_access_token": "test-token",
            }
        )
        req = ProviderRequest(prompt="Test", func_tool=None)

        # Snapshot os.environ so any partial credential write cannot leak.
        with (
            patch.dict(os.environ),
            patch("astrbot.core.astr_main_agent.logger") as mock_logger,
        ):
            ama._apply_sandbox_tools(config, req, "session-123")

        mock_logger.error.assert_called_once()
        assert (
            "Shipyard sandbox configuration is incomplete"
            in mock_logger.error.call_args[0][0]
        )

    def test_apply_sandbox_tools_shipyard_missing_access_token(self):
        """Test that shipyard config is skipped when access token is missing."""
        config = self._sandbox_config(
            {
                "booter": "shipyard",
                "shipyard_endpoint": "https://shipyard.example.com",
                "shipyard_access_token": "",
            }
        )
        req = ProviderRequest(prompt="Test", func_tool=None)

        # Snapshot os.environ so any partial credential write cannot leak.
        with (
            patch.dict(os.environ),
            patch("astrbot.core.astr_main_agent.logger") as mock_logger,
        ):
            ama._apply_sandbox_tools(config, req, "session-123")

        mock_logger.error.assert_called_once()

    def test_apply_sandbox_tools_preserves_existing_toolset(self):
        """Test that existing tools are preserved when adding sandbox tools."""
        config = self._sandbox_config()
        existing_toolset = ToolSet()
        existing_tool = MagicMock()
        existing_tool.name = "existing_tool"
        existing_toolset.add_tool(existing_tool)
        req = ProviderRequest(prompt="Test", func_tool=existing_toolset)

        ama._apply_sandbox_tools(config, req, "session-123")

        assert "existing_tool" in req.func_tool.names()
        assert "astrbot_execute_shell" in req.func_tool.names()

    def test_apply_sandbox_tools_appends_to_existing_system_prompt(self):
        """Test that sandbox prompt is appended to existing system prompt."""
        config = self._sandbox_config()
        req = ProviderRequest(prompt="Test", system_prompt="Base prompt")

        ama._apply_sandbox_tools(config, req, "session-123")

        assert req.system_prompt.startswith("Base prompt")
        assert "sandboxed environment" in req.system_prompt

    def test_apply_sandbox_tools_with_none_system_prompt(self):
        """Test that sandbox prompt is applied when system_prompt is None."""
        config = self._sandbox_config()
        req = ProviderRequest(prompt="Test", system_prompt=None)

        ama._apply_sandbox_tools(config, req, "session-123")

        assert isinstance(req.system_prompt, str)
        assert "sandboxed environment" in req.system_prompt
|