Merge branch 'master' into feat/neo-skill-self-iteration
This commit is contained in:
@@ -0,0 +1,382 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial
|
||||
|
||||
|
||||
class _ErrorWithBody(Exception):
|
||||
def __init__(self, message: str, body: dict):
|
||||
super().__init__(message)
|
||||
self.body = body
|
||||
|
||||
|
||||
class _ErrorWithResponse(Exception):
|
||||
def __init__(self, message: str, response_text: str):
|
||||
super().__init__(message)
|
||||
self.response = SimpleNamespace(text=response_text)
|
||||
|
||||
|
||||
def _make_provider(overrides: dict | None = None) -> ProviderOpenAIOfficial:
|
||||
provider_config = {
|
||||
"id": "test-openai",
|
||||
"type": "openai_chat_completion",
|
||||
"model": "gpt-4o-mini",
|
||||
"key": ["test-key"],
|
||||
}
|
||||
if overrides:
|
||||
provider_config.update(overrides)
|
||||
return ProviderOpenAIOfficial(
|
||||
provider_config=provider_config,
|
||||
provider_settings={},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_api_error_content_moderated_removes_images():
|
||||
provider = _make_provider(
|
||||
{"image_moderation_error_patterns": ["file:content-moderated"]}
|
||||
)
|
||||
try:
|
||||
payloads = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,abcd"},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
context_query = payloads["messages"]
|
||||
|
||||
success, *_rest = await provider._handle_api_error(
|
||||
Exception("Content is moderated [WKE=file:content-moderated]"),
|
||||
payloads=payloads,
|
||||
context_query=context_query,
|
||||
func_tool=None,
|
||||
chosen_key="test-key",
|
||||
available_api_keys=["test-key"],
|
||||
retry_cnt=0,
|
||||
max_retries=10,
|
||||
)
|
||||
|
||||
assert success is False
|
||||
updated_context = payloads["messages"]
|
||||
assert isinstance(updated_context, list)
|
||||
assert updated_context[0]["content"] == [{"type": "text", "text": "hello"}]
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_api_error_model_not_vlm_removes_images_and_retries_text_only():
|
||||
provider = _make_provider()
|
||||
try:
|
||||
payloads = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,abcd"},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
context_query = payloads["messages"]
|
||||
|
||||
success, *_rest = await provider._handle_api_error(
|
||||
Exception("The model is not a VLM and cannot process images"),
|
||||
payloads=payloads,
|
||||
context_query=context_query,
|
||||
func_tool=None,
|
||||
chosen_key="test-key",
|
||||
available_api_keys=["test-key"],
|
||||
retry_cnt=0,
|
||||
max_retries=10,
|
||||
)
|
||||
|
||||
assert success is False
|
||||
updated_context = payloads["messages"]
|
||||
assert isinstance(updated_context, list)
|
||||
assert updated_context[0]["content"] == [{"type": "text", "text": "hello"}]
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_api_error_model_not_vlm_after_fallback_raises():
|
||||
provider = _make_provider()
|
||||
try:
|
||||
payloads = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,abcd"},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
context_query = payloads["messages"]
|
||||
|
||||
with pytest.raises(Exception, match="not a VLM"):
|
||||
await provider._handle_api_error(
|
||||
Exception("The model is not a VLM and cannot process images"),
|
||||
payloads=payloads,
|
||||
context_query=context_query,
|
||||
func_tool=None,
|
||||
chosen_key="test-key",
|
||||
available_api_keys=["test-key"],
|
||||
retry_cnt=1,
|
||||
max_retries=10,
|
||||
image_fallback_used=True,
|
||||
)
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_api_error_content_moderated_with_unserializable_body():
|
||||
provider = _make_provider({"image_moderation_error_patterns": ["blocked"]})
|
||||
try:
|
||||
payloads = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,abcd"},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
context_query = payloads["messages"]
|
||||
err = _ErrorWithBody(
|
||||
"upstream error",
|
||||
{"error": {"message": "blocked"}, "raw": object()},
|
||||
)
|
||||
|
||||
success, *_rest = await provider._handle_api_error(
|
||||
err,
|
||||
payloads=payloads,
|
||||
context_query=context_query,
|
||||
func_tool=None,
|
||||
chosen_key="test-key",
|
||||
available_api_keys=["test-key"],
|
||||
retry_cnt=0,
|
||||
max_retries=10,
|
||||
)
|
||||
assert success is False
|
||||
assert payloads["messages"][0]["content"] == [{"type": "text", "text": "hello"}]
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
def test_extract_error_text_candidates_truncates_long_response_text():
|
||||
long_text = "x" * 20000
|
||||
err = _ErrorWithResponse("upstream error", long_text)
|
||||
candidates = ProviderOpenAIOfficial._extract_error_text_candidates(err)
|
||||
assert candidates
|
||||
assert max(len(candidate) for candidate in candidates) <= (
|
||||
ProviderOpenAIOfficial._ERROR_TEXT_CANDIDATE_MAX_CHARS
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_api_error_content_moderated_without_images_raises():
|
||||
provider = _make_provider(
|
||||
{"image_moderation_error_patterns": ["file:content-moderated"]}
|
||||
)
|
||||
try:
|
||||
payloads = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "hello"}],
|
||||
}
|
||||
]
|
||||
}
|
||||
context_query = payloads["messages"]
|
||||
err = Exception("Content is moderated [WKE=file:content-moderated]")
|
||||
|
||||
with pytest.raises(Exception, match="content-moderated"):
|
||||
await provider._handle_api_error(
|
||||
err,
|
||||
payloads=payloads,
|
||||
context_query=context_query,
|
||||
func_tool=None,
|
||||
chosen_key="test-key",
|
||||
available_api_keys=["test-key"],
|
||||
retry_cnt=0,
|
||||
max_retries=10,
|
||||
)
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_api_error_content_moderated_detects_structured_body():
|
||||
provider = _make_provider(
|
||||
{"image_moderation_error_patterns": ["content_moderated"]}
|
||||
)
|
||||
try:
|
||||
payloads = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,abcd"},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
context_query = payloads["messages"]
|
||||
err = _ErrorWithBody(
|
||||
"upstream error",
|
||||
{"error": {"code": "content_moderated", "message": "blocked"}},
|
||||
)
|
||||
|
||||
success, *_rest = await provider._handle_api_error(
|
||||
err,
|
||||
payloads=payloads,
|
||||
context_query=context_query,
|
||||
func_tool=None,
|
||||
chosen_key="test-key",
|
||||
available_api_keys=["test-key"],
|
||||
retry_cnt=0,
|
||||
max_retries=10,
|
||||
)
|
||||
assert success is False
|
||||
assert payloads["messages"][0]["content"] == [{"type": "text", "text": "hello"}]
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_api_error_content_moderated_supports_custom_patterns():
|
||||
provider = _make_provider(
|
||||
{"image_moderation_error_patterns": ["blocked_by_policy_code_123"]}
|
||||
)
|
||||
try:
|
||||
payloads = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,abcd"},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
context_query = payloads["messages"]
|
||||
err = Exception("upstream: blocked_by_policy_code_123")
|
||||
|
||||
success, *_rest = await provider._handle_api_error(
|
||||
err,
|
||||
payloads=payloads,
|
||||
context_query=context_query,
|
||||
func_tool=None,
|
||||
chosen_key="test-key",
|
||||
available_api_keys=["test-key"],
|
||||
retry_cnt=0,
|
||||
max_retries=10,
|
||||
)
|
||||
assert success is False
|
||||
assert payloads["messages"][0]["content"] == [{"type": "text", "text": "hello"}]
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_api_error_content_moderated_without_patterns_raises():
|
||||
provider = _make_provider()
|
||||
try:
|
||||
payloads = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,abcd"},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
context_query = payloads["messages"]
|
||||
err = Exception("Content is moderated [WKE=file:content-moderated]")
|
||||
|
||||
with pytest.raises(Exception, match="content-moderated"):
|
||||
await provider._handle_api_error(
|
||||
err,
|
||||
payloads=payloads,
|
||||
context_query=context_query,
|
||||
func_tool=None,
|
||||
chosen_key="test-key",
|
||||
available_api_keys=["test-key"],
|
||||
retry_cnt=0,
|
||||
max_retries=10,
|
||||
)
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_api_error_unknown_image_error_raises():
|
||||
provider = _make_provider()
|
||||
try:
|
||||
payloads = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,abcd"},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
context_query = payloads["messages"]
|
||||
|
||||
with pytest.raises(Exception, match="unknown provider image upload error"):
|
||||
await provider._handle_api_error(
|
||||
Exception("some unknown provider image upload error"),
|
||||
payloads=payloads,
|
||||
context_query=context_query,
|
||||
func_tool=None,
|
||||
chosen_key="test-key",
|
||||
available_api_keys=["test-key"],
|
||||
retry_cnt=0,
|
||||
max_retries=10,
|
||||
)
|
||||
finally:
|
||||
await provider.terminate()
|
||||
@@ -0,0 +1,494 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.message.components import Image, Plain, Reply
|
||||
from astrbot.core.utils.quoted_message_parser import (
|
||||
extract_quoted_message_images,
|
||||
extract_quoted_message_text,
|
||||
)
|
||||
|
||||
|
||||
class _DummyAPI:
|
||||
def __init__(
|
||||
self,
|
||||
responses: dict[tuple[str, str], dict],
|
||||
param_responses: dict[tuple[str, tuple[tuple[str, str], ...]], dict]
|
||||
| None = None,
|
||||
):
|
||||
self._responses = responses
|
||||
self._param_responses = param_responses or {}
|
||||
|
||||
async def call_action(self, action: str, **params):
|
||||
param_key = (action, tuple(sorted((k, str(v)) for k, v in params.items())))
|
||||
if param_key in self._param_responses:
|
||||
return self._param_responses[param_key]
|
||||
|
||||
msg_id = params.get("message_id")
|
||||
if msg_id is None:
|
||||
msg_id = params.get("id")
|
||||
key = (action, str(msg_id))
|
||||
if key not in self._responses:
|
||||
raise RuntimeError(f"no mock response for {key}")
|
||||
return self._responses[key]
|
||||
|
||||
|
||||
class _FailIfCalledAPI:
|
||||
async def call_action(self, action: str, **params):
|
||||
raise AssertionError(
|
||||
f"call_action should not be called, got action={action}, params={params}"
|
||||
)
|
||||
|
||||
|
||||
def _make_event(
|
||||
reply: Reply,
|
||||
responses: dict[tuple[str, str], dict] | None = None,
|
||||
param_responses: dict[tuple[str, tuple[tuple[str, str], ...]], dict] | None = None,
|
||||
):
|
||||
if responses is None:
|
||||
responses = {}
|
||||
if param_responses is None:
|
||||
param_responses = {}
|
||||
return SimpleNamespace(
|
||||
message_obj=SimpleNamespace(message=[reply]),
|
||||
bot=SimpleNamespace(api=_DummyAPI(responses, param_responses)),
|
||||
get_group_id=lambda: "",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_quoted_message_text_from_reply_chain():
|
||||
reply = Reply(id="1", chain=[Plain(text="quoted content")], message_str="")
|
||||
event = _make_event(reply)
|
||||
text = await extract_quoted_message_text(event)
|
||||
assert text == "quoted content"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_quoted_message_text_no_reply_component():
|
||||
event = SimpleNamespace(
|
||||
message_obj=SimpleNamespace(message=[Plain(text="unquoted message")]),
|
||||
bot=SimpleNamespace(api=_DummyAPI({}, {})),
|
||||
get_group_id=lambda: "",
|
||||
)
|
||||
|
||||
text = await extract_quoted_message_text(event)
|
||||
assert text is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_quoted_message_images_no_reply_component():
|
||||
event = SimpleNamespace(
|
||||
message_obj=SimpleNamespace(message=[Plain(text="unquoted message")]),
|
||||
bot=SimpleNamespace(api=_FailIfCalledAPI()),
|
||||
get_group_id=lambda: "",
|
||||
)
|
||||
|
||||
images = await extract_quoted_message_images(event)
|
||||
assert images == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("reply_id", [None, ""])
|
||||
async def test_extract_quoted_message_text_reply_without_id_does_not_call_get_msg(
|
||||
reply_id: str | None,
|
||||
):
|
||||
reply = Reply(
|
||||
id="placeholder", chain=[Plain(text="quoted content")], message_str=""
|
||||
)
|
||||
object.__setattr__(reply, "id", reply_id)
|
||||
event = SimpleNamespace(
|
||||
message_obj=SimpleNamespace(message=[reply]),
|
||||
bot=SimpleNamespace(api=_FailIfCalledAPI()),
|
||||
get_group_id=lambda: "",
|
||||
)
|
||||
|
||||
text = await extract_quoted_message_text(event)
|
||||
assert text == "quoted content"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_quoted_message_text_fallback_get_msg_and_forward():
|
||||
reply = Reply(id="100", chain=None, message_str="")
|
||||
event = _make_event(
|
||||
reply,
|
||||
responses={
|
||||
(
|
||||
"get_msg",
|
||||
"100",
|
||||
): {
|
||||
"data": {
|
||||
"message": [
|
||||
{"type": "text", "data": {"text": "parent"}},
|
||||
{"type": "forward", "data": {"id": "fwd_1"}},
|
||||
]
|
||||
}
|
||||
},
|
||||
(
|
||||
"get_forward_msg",
|
||||
"fwd_1",
|
||||
): {
|
||||
"data": {
|
||||
"messages": [
|
||||
{
|
||||
"sender": {"nickname": "Alice"},
|
||||
"message": [{"type": "text", "data": {"text": "hello"}}],
|
||||
},
|
||||
{
|
||||
"sender": {"nickname": "Bob"},
|
||||
"message": [
|
||||
{"type": "image", "data": {"url": "http://img"}},
|
||||
{"type": "text", "data": {"text": "world"}},
|
||||
],
|
||||
},
|
||||
]
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
text = await extract_quoted_message_text(event)
|
||||
assert text is not None
|
||||
assert "parent" in text
|
||||
assert "Alice: hello" in text
|
||||
assert "Bob: [Image]world" in text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"placeholder_text",
|
||||
[
|
||||
"[Forward Message]",
|
||||
"[转发消息]",
|
||||
"[合并转发]",
|
||||
"Alice: [Forward Message]",
|
||||
"(Alice): [转发消息]",
|
||||
"[Forward Message]\n[转发消息]",
|
||||
"Alice: [Forward Message]\n(Bob): [合并转发]",
|
||||
"[转发消息]\n\n[合并转发]",
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_quoted_message_text_forward_placeholder_variants_trigger_fallback(
|
||||
placeholder_text: str,
|
||||
):
|
||||
reply = Reply(id="400", chain=[Plain(text=placeholder_text)], message_str="")
|
||||
event = _make_event(
|
||||
reply,
|
||||
responses={
|
||||
("get_msg", "400"): {
|
||||
"data": {
|
||||
"message": [
|
||||
{"type": "text", "data": {"text": "Bob: "}},
|
||||
{"type": "image", "data": {}},
|
||||
{"type": "text", "data": {"text": "world"}},
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
text = await extract_quoted_message_text(event)
|
||||
assert "Bob: [Image]world" in text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_quoted_message_text_mixed_placeholder_does_not_trigger_fallback():
|
||||
reply = Reply(
|
||||
id="402",
|
||||
chain=[Plain(text="Alice: [Forward Message]\nreal text")],
|
||||
message_str="",
|
||||
)
|
||||
event = SimpleNamespace(
|
||||
message_obj=SimpleNamespace(message=[reply]),
|
||||
bot=SimpleNamespace(api=_FailIfCalledAPI()),
|
||||
get_group_id=lambda: "",
|
||||
)
|
||||
|
||||
text = await extract_quoted_message_text(event)
|
||||
assert text is not None
|
||||
assert "[Forward Message]" in text
|
||||
assert "real text" in text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_quoted_message_text_forward_placeholder_fallback_failure():
|
||||
reply = Reply(id="401", chain=[Plain(text="[Forward Message]")], message_str="")
|
||||
event = _make_event(reply, responses={})
|
||||
|
||||
text = await extract_quoted_message_text(event)
|
||||
assert text == "[Forward Message]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_quoted_message_text_multimsg_malformed_config_does_not_raise():
|
||||
reply = Reply(id="402", chain=None, message_str="")
|
||||
event = _make_event(
|
||||
reply,
|
||||
responses={
|
||||
("get_msg", "402"): {
|
||||
"data": {
|
||||
"message": [
|
||||
{
|
||||
"type": "json",
|
||||
"data": {
|
||||
"data": (
|
||||
'{"app":"com.tencent.multimsg",'
|
||||
'"config":"oops","meta":{}}'
|
||||
)
|
||||
},
|
||||
},
|
||||
{"type": "text", "data": {"text": "still works"}},
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
text = await extract_quoted_message_text(event)
|
||||
assert text == "still works"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_quoted_message_images_from_reply_chain():
|
||||
reply = Reply(
|
||||
id="1",
|
||||
chain=[
|
||||
Plain(text="quoted"),
|
||||
Image(file="https://img.example.com/a.jpg"),
|
||||
],
|
||||
message_str="",
|
||||
)
|
||||
event = _make_event(reply)
|
||||
|
||||
images = await extract_quoted_message_images(event)
|
||||
assert images == ["https://img.example.com/a.jpg"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_quoted_message_images_fallback_get_msg_direct_url():
|
||||
reply = Reply(id="200", chain=None, message_str="")
|
||||
event = _make_event(
|
||||
reply,
|
||||
responses={
|
||||
("get_msg", "200"): {
|
||||
"data": {
|
||||
"message": [
|
||||
{
|
||||
"type": "image",
|
||||
"data": {"url": "https://img.example.com/direct.jpg"},
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
images = await extract_quoted_message_images(event)
|
||||
assert images == ["https://img.example.com/direct.jpg"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_quoted_message_images_data_image_ref_normalized_to_base64():
|
||||
data_image_ref = "data:image/png;base64,abcd1234=="
|
||||
reply = Reply(id="201", chain=None, message_str="")
|
||||
event = _make_event(
|
||||
reply,
|
||||
responses={
|
||||
("get_msg", "201"): {
|
||||
"data": {
|
||||
"message": [
|
||||
{"type": "image", "data": {"url": data_image_ref}},
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
images = await extract_quoted_message_images(event)
|
||||
assert images == ["base64://abcd1234=="]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_quoted_message_images_file_url_with_query_string():
|
||||
url_with_query = "https://img.example.com/direct.jpg?token=abc123#frag"
|
||||
reply = Reply(id="205", chain=None, message_str="")
|
||||
event = _make_event(
|
||||
reply,
|
||||
responses={
|
||||
("get_msg", "205"): {
|
||||
"data": {
|
||||
"message": [
|
||||
{
|
||||
"type": "file",
|
||||
"data": {
|
||||
"url": url_with_query,
|
||||
"name": "direct.jpg",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
images = await extract_quoted_message_images(event)
|
||||
assert images == [url_with_query]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_quoted_message_images_non_image_local_path_is_ignored(tmp_path):
|
||||
non_image_file = tmp_path / "secret.txt"
|
||||
non_image_file.write_text("not an image", encoding="utf-8")
|
||||
|
||||
reply = Reply(
|
||||
id="placeholder", chain=[Image(file=str(non_image_file))], message_str=""
|
||||
)
|
||||
object.__setattr__(reply, "id", None)
|
||||
event = SimpleNamespace(
|
||||
message_obj=SimpleNamespace(message=[reply]),
|
||||
bot=SimpleNamespace(api=_FailIfCalledAPI()),
|
||||
get_group_id=lambda: "",
|
||||
)
|
||||
|
||||
images = await extract_quoted_message_images(event)
|
||||
assert images == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_quoted_message_images_chain_placeholder_triggers_fallback():
|
||||
reply = Reply(id="210", chain=[Plain(text="[Forward Message]")], message_str="")
|
||||
event = _make_event(
|
||||
reply,
|
||||
responses={
|
||||
("get_msg", "210"): {
|
||||
"data": {
|
||||
"message": [
|
||||
{
|
||||
"type": "image",
|
||||
"data": {
|
||||
"url": "https://img.example.com/from-fallback.jpg"
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
images = await extract_quoted_message_images(event)
|
||||
assert images == ["https://img.example.com/from-fallback.jpg"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_quoted_message_images_fallback_resolve_file_id_with_get_image():
|
||||
reply = Reply(id="300", chain=None, message_str="")
|
||||
event = _make_event(
|
||||
reply,
|
||||
responses={
|
||||
("get_msg", "300"): {
|
||||
"data": {"message": [{"type": "image", "data": {"file": "abc123.jpg"}}]}
|
||||
}
|
||||
},
|
||||
param_responses={
|
||||
("get_image", (("file", "abc123.jpg"),)): {
|
||||
"data": {"url": "https://img.example.com/resolved.jpg"}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
images = await extract_quoted_message_images(event)
|
||||
assert images == ["https://img.example.com/resolved.jpg"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_quoted_message_images_deduplicates_across_sources():
|
||||
dup_url = "https://img.example.com/dup.jpg"
|
||||
chain_only_url = "https://img.example.com/only-chain.jpg"
|
||||
get_msg_only_url = "https://img.example.com/only-get-msg.jpg"
|
||||
forward_only_url = "https://img.example.com/only-forward.jpg"
|
||||
|
||||
reply = Reply(
|
||||
id="310",
|
||||
chain=[Image(file=dup_url), Image(file=chain_only_url)],
|
||||
message_str="",
|
||||
)
|
||||
|
||||
event = _make_event(
|
||||
reply,
|
||||
responses={
|
||||
("get_msg", "310"): {
|
||||
"data": {
|
||||
"message": [
|
||||
{"type": "image", "data": {"url": dup_url}},
|
||||
{"type": "image", "data": {"url": get_msg_only_url}},
|
||||
{"type": "forward", "data": {"id": "999"}},
|
||||
]
|
||||
}
|
||||
},
|
||||
("get_forward_msg", "999"): {
|
||||
"data": {
|
||||
"messages": [
|
||||
{
|
||||
"sender": {"nickname": "Tester"},
|
||||
"message": [
|
||||
{"type": "image", "data": {"url": dup_url}},
|
||||
{"type": "image", "data": {"url": forward_only_url}},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
images = await extract_quoted_message_images(event)
|
||||
assert images == [
|
||||
dup_url,
|
||||
chain_only_url,
|
||||
get_msg_only_url,
|
||||
forward_only_url,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_quoted_message_nested_forward_id_is_resolved():
|
||||
nested_image = "https://img.example.com/nested.jpg"
|
||||
reply = Reply(id="320", chain=[Plain(text="[Forward Message]")], message_str="")
|
||||
event = _make_event(
|
||||
reply,
|
||||
responses={
|
||||
("get_msg", "320"): {
|
||||
"data": {"message": [{"type": "forward", "data": {"id": "fwd_1"}}]}
|
||||
},
|
||||
("get_forward_msg", "fwd_1"): {
|
||||
"data": {
|
||||
"messages": [
|
||||
{
|
||||
"sender": {"nickname": "Alice"},
|
||||
"message": [{"type": "forward", "data": {"id": "fwd_2"}}],
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
("get_forward_msg", "fwd_2"): {
|
||||
"data": {
|
||||
"messages": [
|
||||
{
|
||||
"sender": {"nickname": "Bob"},
|
||||
"message": [
|
||||
{"type": "text", "data": {"text": "deep"}},
|
||||
{"type": "image", "data": {"url": nested_image}},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
text = await extract_quoted_message_text(event)
|
||||
assert text is not None
|
||||
assert "Bob: deep" in text
|
||||
|
||||
images = await extract_quoted_message_images(event)
|
||||
assert images == [nested_image]
|
||||
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from astrbot.core.utils.temp_dir_cleaner import TempDirCleaner, parse_size_to_bytes
|
||||
|
||||
|
||||
def test_parse_size_to_bytes():
|
||||
assert parse_size_to_bytes("1024") == 1024 * 1024**2
|
||||
assert parse_size_to_bytes(2048) == 2048 * 1024**2
|
||||
assert parse_size_to_bytes("0.5") == int(0.5 * 1024**2)
|
||||
assert parse_size_to_bytes(0) == 0
|
||||
assert parse_size_to_bytes("invalid") == 0
|
||||
|
||||
|
||||
def _write_file(path: Path, size: int, mtime: float) -> None:
|
||||
path.write_bytes(b"x" * size)
|
||||
os.utime(path, (mtime, mtime))
|
||||
|
||||
|
||||
def test_cleanup_once_releases_30_percent_and_prefers_old_files(tmp_path):
|
||||
temp_dir = tmp_path / "temp"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
base_time = time.time() - 1000
|
||||
file_old = temp_dir / "old.bin"
|
||||
file_mid = temp_dir / "mid.bin"
|
||||
file_new = temp_dir / "new.bin"
|
||||
_write_file(file_old, 400, base_time)
|
||||
_write_file(file_mid, 300, base_time + 10)
|
||||
_write_file(file_new, 300, base_time + 20)
|
||||
|
||||
cleaner = TempDirCleaner(max_size_getter=lambda: "0.0008", temp_dir=temp_dir)
|
||||
cleaner.cleanup_once()
|
||||
|
||||
remaining_size = sum(f.stat().st_size for f in temp_dir.rglob("*") if f.is_file())
|
||||
assert remaining_size <= 600
|
||||
assert not file_old.exists()
|
||||
assert file_mid.exists()
|
||||
assert file_new.exists()
|
||||
|
||||
|
||||
def test_cleanup_once_noop_when_below_limit(tmp_path):
|
||||
temp_dir = tmp_path / "temp"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
file_path = temp_dir / "a.bin"
|
||||
_write_file(file_path, 100, time.time())
|
||||
|
||||
cleaner = TempDirCleaner(max_size_getter=lambda: "1", temp_dir=temp_dir)
|
||||
cleaner.cleanup_once()
|
||||
|
||||
assert file_path.exists()
|
||||
@@ -90,6 +90,21 @@ class MockToolExecutor:
|
||||
return generator()
|
||||
|
||||
|
||||
class MockFailingProvider(MockProvider):
|
||||
async def text_chat(self, **kwargs) -> LLMResponse:
|
||||
self.call_count += 1
|
||||
raise RuntimeError("primary provider failed")
|
||||
|
||||
|
||||
class MockErrProvider(MockProvider):
|
||||
async def text_chat(self, **kwargs) -> LLMResponse:
|
||||
self.call_count += 1
|
||||
return LLMResponse(
|
||||
role="err",
|
||||
completion_text="primary provider returned error",
|
||||
)
|
||||
|
||||
|
||||
class MockHooks(BaseAgentRunHooks):
|
||||
"""模拟钩子函数"""
|
||||
|
||||
@@ -321,6 +336,64 @@ async def test_hooks_called_with_max_step(
|
||||
assert mock_hooks.tool_end_called, "on_tool_end应该被调用"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_provider_used_when_primary_raises(
|
||||
runner, provider_request, mock_tool_executor, mock_hooks
|
||||
):
|
||||
primary_provider = MockFailingProvider()
|
||||
fallback_provider = MockProvider()
|
||||
fallback_provider.should_call_tools = False
|
||||
|
||||
await runner.reset(
|
||||
provider=primary_provider,
|
||||
request=provider_request,
|
||||
run_context=ContextWrapper(context=None),
|
||||
tool_executor=mock_tool_executor,
|
||||
agent_hooks=mock_hooks,
|
||||
streaming=False,
|
||||
fallback_providers=[fallback_provider],
|
||||
)
|
||||
|
||||
async for _ in runner.step_until_done(5):
|
||||
pass
|
||||
|
||||
final_resp = runner.get_final_llm_resp()
|
||||
assert final_resp is not None
|
||||
assert final_resp.role == "assistant"
|
||||
assert final_resp.completion_text == "这是我的最终回答"
|
||||
assert primary_provider.call_count == 1
|
||||
assert fallback_provider.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_provider_used_when_primary_returns_err(
|
||||
runner, provider_request, mock_tool_executor, mock_hooks
|
||||
):
|
||||
primary_provider = MockErrProvider()
|
||||
fallback_provider = MockProvider()
|
||||
fallback_provider.should_call_tools = False
|
||||
|
||||
await runner.reset(
|
||||
provider=primary_provider,
|
||||
request=provider_request,
|
||||
run_context=ContextWrapper(context=None),
|
||||
tool_executor=mock_tool_executor,
|
||||
agent_hooks=mock_hooks,
|
||||
streaming=False,
|
||||
fallback_providers=[fallback_provider],
|
||||
)
|
||||
|
||||
async for _ in runner.step_until_done(5):
|
||||
pass
|
||||
|
||||
final_resp = runner.get_final_llm_resp()
|
||||
assert final_resp is not None
|
||||
assert final_resp.role == "assistant"
|
||||
assert final_resp.completion_text == "这是我的最终回答"
|
||||
assert primary_provider.call_count == 1
|
||||
assert fallback_provider.call_count == 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行测试
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
Reference in New Issue
Block a user