4ff07e3c74
* feat: support fallback image parsing for quoted messages * fix: fallback parse quoted images when reply chain has placeholders * style: format network utils with ruff * test: expand quoted parser coverage and improve fallback diagnostics * fix: fallback to text-only retry when image requests fail * fix: tighten image fallback and resolve nested quoted forwards * refactor: simplify quoted message extraction and dedupe images * fix: harden quoted parsing and openai error candidates * fix: harden quoted image ref normalization * refactor: organize quoted parser settings and logging * fix: cap quoted fallback images and avoid retry loops * refactor: split quoted message parser into focused modules * refactor: share onebot segment parsing logic * refactor: unify quoted message parsing flow * feat: move quoted parser tuning to provider settings * fix: add missing i18n metadata for quoted parser settings * chore: refine forwarded message setting labels
383 lines
12 KiB
Python
383 lines
12 KiB
Python
from types import SimpleNamespace
|
|
|
|
import pytest
|
|
|
|
from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial
|
|
|
|
|
|
class _ErrorWithBody(Exception):
|
|
def __init__(self, message: str, body: dict):
|
|
super().__init__(message)
|
|
self.body = body
|
|
|
|
|
|
class _ErrorWithResponse(Exception):
|
|
def __init__(self, message: str, response_text: str):
|
|
super().__init__(message)
|
|
self.response = SimpleNamespace(text=response_text)
|
|
|
|
|
|
def _make_provider(overrides: dict | None = None) -> ProviderOpenAIOfficial:
    """Build a ``ProviderOpenAIOfficial`` with a minimal test configuration.

    Args:
        overrides: Extra provider-config keys merged on top of the defaults;
            on a key collision the override wins. ``None`` or an empty dict
            leaves the defaults untouched.

    Returns:
        A new provider instance. The tests in this file always
        ``await provider.terminate()`` on it when they are done.
    """
    base_config = dict(
        id="test-openai",
        type="openai_chat_completion",
        model="gpt-4o-mini",
        key=["test-key"],
    )
    merged_config = {**base_config, **(overrides or {})}
    return ProviderOpenAIOfficial(
        provider_config=merged_config,
        provider_settings={},
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_removes_images():
    """A moderation error matching a configured pattern drops image parts.

    The handler must report ``success=False`` and leave only the text part of
    the user message in ``payloads["messages"]``.
    """
    provider = _make_provider(
        {"image_moderation_error_patterns": ["file:content-moderated"]}
    )
    try:
        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "hello"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/jpeg;base64,abcd"},
                },
            ],
        }
        payloads = {"messages": [user_message]}
        moderation_error = Exception(
            "Content is moderated [WKE=file:content-moderated]"
        )
        handler_kwargs = dict(
            payloads=payloads,
            context_query=payloads["messages"],
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        success, *_rest = await provider._handle_api_error(
            moderation_error, **handler_kwargs
        )

        assert success is False
        messages_after = payloads["messages"]
        assert isinstance(messages_after, list)
        assert messages_after[0]["content"] == [{"type": "text", "text": "hello"}]
    finally:
        await provider.terminate()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_model_not_vlm_removes_images_and_retries_text_only():
    """A first "not a VLM" error strips image parts from the request.

    ``success`` must be ``False`` and the user message must keep only its
    text part.
    """
    provider = _make_provider()
    try:
        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "hello"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/jpeg;base64,abcd"},
                },
            ],
        }
        payloads = {"messages": [user_message]}
        not_vlm_error = Exception("The model is not a VLM and cannot process images")
        handler_kwargs = dict(
            payloads=payloads,
            context_query=payloads["messages"],
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        success, *_rest = await provider._handle_api_error(
            not_vlm_error, **handler_kwargs
        )

        assert success is False
        messages_after = payloads["messages"]
        assert isinstance(messages_after, list)
        assert messages_after[0]["content"] == [{"type": "text", "text": "hello"}]
    finally:
        await provider.terminate()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_model_not_vlm_after_fallback_raises():
    """With ``image_fallback_used=True``, a "not a VLM" error is raised.

    The handler is expected to raise an exception whose message matches
    ``"not a VLM"``.
    """
    provider = _make_provider()
    try:
        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "hello"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/jpeg;base64,abcd"},
                },
            ],
        }
        payloads = {"messages": [user_message]}
        handler_kwargs = dict(
            payloads=payloads,
            context_query=payloads["messages"],
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=1,
            max_retries=10,
            image_fallback_used=True,
        )

        with pytest.raises(Exception, match="not a VLM"):
            await provider._handle_api_error(
                Exception("The model is not a VLM and cannot process images"),
                **handler_kwargs,
            )
    finally:
        await provider.terminate()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_with_unserializable_body():
    """Pattern matching tolerates a ``body`` containing a non-JSON value.

    The exception message ("upstream error") does not contain the pattern, so
    the match has to come from ``body``, which also carries an ``object()``.
    """
    provider = _make_provider({"image_moderation_error_patterns": ["blocked"]})
    try:
        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "hello"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/jpeg;base64,abcd"},
                },
            ],
        }
        payloads = {"messages": [user_message]}
        error_body = {"error": {"message": "blocked"}, "raw": object()}
        err = _ErrorWithBody("upstream error", error_body)
        handler_kwargs = dict(
            payloads=payloads,
            context_query=payloads["messages"],
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        success, *_rest = await provider._handle_api_error(err, **handler_kwargs)

        assert success is False
        text_only = [{"type": "text", "text": "hello"}]
        assert payloads["messages"][0]["content"] == text_only
    finally:
        await provider.terminate()
|
|
|
|
|
|
def test_extract_error_text_candidates_truncates_long_response_text():
    """No extracted candidate may exceed ``_ERROR_TEXT_CANDIDATE_MAX_CHARS``.

    Feeds a 20000-character ``response.text`` and checks that every returned
    candidate string is capped.
    """
    oversized_error = _ErrorWithResponse("upstream error", "x" * 20000)

    candidates = ProviderOpenAIOfficial._extract_error_text_candidates(
        oversized_error
    )

    assert candidates
    limit = ProviderOpenAIOfficial._ERROR_TEXT_CANDIDATE_MAX_CHARS
    assert all(len(text) <= limit for text in candidates)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_without_images_raises():
    """A matching moderation error on a text-only request is raised.

    The payload has no image parts, and the handler must raise an exception
    whose message matches ``"content-moderated"``.
    """
    provider = _make_provider(
        {"image_moderation_error_patterns": ["file:content-moderated"]}
    )
    try:
        text_only_message = {
            "role": "user",
            "content": [{"type": "text", "text": "hello"}],
        }
        payloads = {"messages": [text_only_message]}
        err = Exception("Content is moderated [WKE=file:content-moderated]")
        handler_kwargs = dict(
            payloads=payloads,
            context_query=payloads["messages"],
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        with pytest.raises(Exception, match="content-moderated"):
            await provider._handle_api_error(err, **handler_kwargs)
    finally:
        await provider.terminate()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_detects_structured_body():
    """Patterns are matched against structured ``body`` fields.

    Only ``body["error"]["code"]`` contains the configured pattern
    ``"content_moderated"``. Image parts must still be stripped.
    """
    provider = _make_provider(
        {"image_moderation_error_patterns": ["content_moderated"]}
    )
    try:
        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "hello"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/jpeg;base64,abcd"},
                },
            ],
        }
        payloads = {"messages": [user_message]}
        error_body = {"error": {"code": "content_moderated", "message": "blocked"}}
        err = _ErrorWithBody("upstream error", error_body)
        handler_kwargs = dict(
            payloads=payloads,
            context_query=payloads["messages"],
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        success, *_rest = await provider._handle_api_error(err, **handler_kwargs)

        assert success is False
        text_only = [{"type": "text", "text": "hello"}]
        assert payloads["messages"][0]["content"] == text_only
    finally:
        await provider.terminate()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_supports_custom_patterns():
    """A custom ``image_moderation_error_patterns`` entry triggers image removal.

    The error message contains the configured token
    ``blocked_by_policy_code_123``. The handler must return ``success=False``
    and keep only the text part.
    """
    provider = _make_provider(
        {"image_moderation_error_patterns": ["blocked_by_policy_code_123"]}
    )
    try:
        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "hello"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/jpeg;base64,abcd"},
                },
            ],
        }
        payloads = {"messages": [user_message]}
        err = Exception("upstream: blocked_by_policy_code_123")
        handler_kwargs = dict(
            payloads=payloads,
            context_query=payloads["messages"],
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        success, *_rest = await provider._handle_api_error(err, **handler_kwargs)

        assert success is False
        text_only = [{"type": "text", "text": "hello"}]
        assert payloads["messages"][0]["content"] == text_only
    finally:
        await provider.terminate()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_without_patterns_raises():
    """Without configured moderation patterns, a moderation error is raised.

    The provider uses the default test config, and the handler must raise an
    exception whose message matches ``"content-moderated"``.
    """
    provider = _make_provider()
    try:
        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "hello"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/jpeg;base64,abcd"},
                },
            ],
        }
        payloads = {"messages": [user_message]}
        err = Exception("Content is moderated [WKE=file:content-moderated]")
        handler_kwargs = dict(
            payloads=payloads,
            context_query=payloads["messages"],
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        with pytest.raises(Exception, match="content-moderated"):
            await provider._handle_api_error(err, **handler_kwargs)
    finally:
        await provider.terminate()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_api_error_unknown_image_error_raises():
    """An image-related error matching no handled case is raised.

    The handler must raise an exception whose message matches
    ``"unknown provider image upload error"``.
    """
    provider = _make_provider()
    try:
        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "hello"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/jpeg;base64,abcd"},
                },
            ],
        }
        payloads = {"messages": [user_message]}
        handler_kwargs = dict(
            payloads=payloads,
            context_query=payloads["messages"],
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        with pytest.raises(Exception, match="unknown provider image upload error"):
            await provider._handle_api_error(
                Exception("some unknown provider image upload error"),
                **handler_kwargs,
            )
    finally:
        await provider.terminate()
|