fix: 完善转发引用解析与图片回退并支持配置化控制 (#5054)

* feat: support fallback image parsing for quoted messages

* fix: fallback parse quoted images when reply chain has placeholders

* style: format network utils with ruff

* test: expand quoted parser coverage and improve fallback diagnostics

* fix: fallback to text-only retry when image requests fail

* fix: tighten image fallback and resolve nested quoted forwards

* refactor: simplify quoted message extraction and dedupe images

* fix: harden quoted parsing and openai error candidates

* fix: harden quoted image ref normalization

* refactor: organize quoted parser settings and logging

* fix: cap quoted fallback images and avoid retry loops

* refactor: split quoted message parser into focused modules

* refactor: share onebot segment parsing logic

* refactor: unify quoted message parsing flow

* feat: move quoted parser tuning to provider settings

* fix: add missing i18n metadata for quoted parser settings

* chore: refine forwarded message setting labels
This commit is contained in:
エイカク
2026-02-12 23:42:29 +09:00
committed by GitHub
parent 473e01aadd
commit 4ff07e3c74
17 changed files with 2425 additions and 31 deletions
+111 -23
View File
@@ -52,6 +52,17 @@ from astrbot.core.tools.cron_tools import (
)
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.llm_metadata import LLM_METADATAS
from astrbot.core.utils.quoted_message.settings import (
SETTINGS as DEFAULT_QUOTED_MESSAGE_SETTINGS,
)
from astrbot.core.utils.quoted_message.settings import (
QuotedMessageParserSettings,
)
from astrbot.core.utils.quoted_message_parser import (
extract_quoted_message_images,
extract_quoted_message_text,
)
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
@dataclass(slots=True)
@@ -108,6 +119,8 @@ class MainAgentBuildConfig:
provider_settings: dict = field(default_factory=dict)
subagent_orchestrator: dict = field(default_factory=dict)
timezone: str | None = None
max_quoted_fallback_images: int = 20
"""Maximum number of images injected from quoted-message fallback extraction."""
@dataclass(slots=True)
@@ -470,11 +483,29 @@ async def _ensure_img_caption(
logger.error("处理图片描述失败: %s", exc)
def _append_quoted_image_attachment(req: ProviderRequest, image_path: str) -> None:
req.extra_user_content_parts.append(
TextPart(text=f"[Image Attachment in quoted message: path {image_path}]")
)
def _get_quoted_message_parser_settings(
provider_settings: dict[str, object] | None,
) -> QuotedMessageParserSettings:
if not isinstance(provider_settings, dict):
return DEFAULT_QUOTED_MESSAGE_SETTINGS
overrides = provider_settings.get("quoted_message_parser")
if not isinstance(overrides, dict):
return DEFAULT_QUOTED_MESSAGE_SETTINGS
return DEFAULT_QUOTED_MESSAGE_SETTINGS.with_overrides(overrides)
async def _process_quote_message(
event: AstrMessageEvent,
req: ProviderRequest,
img_cap_prov_id: str,
plugin_context: Context,
quoted_message_settings: QuotedMessageParserSettings = DEFAULT_QUOTED_MESSAGE_SETTINGS,
) -> None:
quote = None
for comp in event.message_obj.message:
@@ -486,7 +517,15 @@ async def _process_quote_message(
content_parts = []
sender_info = f"({quote.sender_nickname}): " if quote.sender_nickname else ""
message_str = quote.message_str or "[Empty Text]"
message_str = (
await extract_quoted_message_text(
event,
quote,
settings=quoted_message_settings,
)
or quote.message_str
or "[Empty Text]"
)
content_parts.append(f"{sender_info}{message_str}")
image_seg = None
@@ -592,11 +631,13 @@ async def _decorate_llm_request(
)
img_cap_prov_id = cfg.get("default_image_caption_provider_id") or ""
quoted_message_settings = _get_quoted_message_parser_settings(cfg)
await _process_quote_message(
event,
req,
img_cap_prov_id,
plugin_context,
quoted_message_settings,
)
tz = config.timezone
@@ -886,32 +927,78 @@ async def build_main_agent(
)
# quoted message attachments
reply_comps = [
comp
for comp in event.message_obj.message
if isinstance(comp, Reply) and comp.chain
comp for comp in event.message_obj.message if isinstance(comp, Reply)
]
quoted_message_settings = _get_quoted_message_parser_settings(
config.provider_settings
)
fallback_quoted_image_count = 0
for comp in reply_comps:
if not comp.chain:
continue
for reply_comp in comp.chain:
if isinstance(reply_comp, Image):
image_path = await reply_comp.convert_to_file_path()
req.image_urls.append(image_path)
req.extra_user_content_parts.append(
TextPart(
text=f"[Image Attachment in quoted message: path {image_path}]"
)
)
elif isinstance(reply_comp, File):
file_path = await reply_comp.get_file()
file_name = reply_comp.name or os.path.basename(file_path)
req.extra_user_content_parts.append(
TextPart(
text=(
f"[File Attachment in quoted message: "
f"name {file_name}, path {file_path}]"
has_embedded_image = False
if comp.chain:
for reply_comp in comp.chain:
if isinstance(reply_comp, Image):
has_embedded_image = True
image_path = await reply_comp.convert_to_file_path()
req.image_urls.append(image_path)
_append_quoted_image_attachment(req, image_path)
elif isinstance(reply_comp, File):
file_path = await reply_comp.get_file()
file_name = reply_comp.name or os.path.basename(file_path)
req.extra_user_content_parts.append(
TextPart(
text=(
f"[File Attachment in quoted message: "
f"name {file_name}, path {file_path}]"
)
)
)
# Fallback quoted image extraction for reply-id-only payloads, or when
# embedded reply chain only contains placeholders (e.g. [Forward Message], [Image]).
if not has_embedded_image:
try:
fallback_images = normalize_and_dedupe_strings(
await extract_quoted_message_images(
event,
comp,
settings=quoted_message_settings,
)
)
remaining_limit = max(
config.max_quoted_fallback_images
- fallback_quoted_image_count,
0,
)
if remaining_limit <= 0 and fallback_images:
logger.warning(
"Skip quoted fallback images due to limit=%d for umo=%s",
config.max_quoted_fallback_images,
event.unified_msg_origin,
)
continue
if len(fallback_images) > remaining_limit:
logger.warning(
"Truncate quoted fallback images for umo=%s, reply_id=%s from %d to %d",
event.unified_msg_origin,
getattr(comp, "id", None),
len(fallback_images),
remaining_limit,
)
fallback_images = fallback_images[:remaining_limit]
for image_ref in fallback_images:
if image_ref in req.image_urls:
continue
req.image_urls.append(image_ref)
fallback_quoted_image_count += 1
_append_quoted_image_attachment(req, image_ref)
except Exception as exc: # noqa: BLE001
logger.warning(
"Failed to resolve fallback quoted images for umo=%s, reply_id=%s: %s",
event.unified_msg_origin,
getattr(comp, "id", None),
exc,
exc_info=True,
)
conversation = await _get_session_conv(event, plugin_context)
@@ -921,6 +1008,7 @@ async def build_main_agent(
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
req.image_urls = normalize_and_dedupe_strings(req.image_urls)
if config.file_extract_enabled:
try:
+47
View File
@@ -99,6 +99,13 @@ DEFAULT_CONFIG = {
"streaming_response": False,
"show_tool_use_status": False,
"sanitize_context_by_modalities": False,
"max_quoted_fallback_images": 20,
"quoted_message_parser": {
"max_component_chain_depth": 4,
"max_forward_node_depth": 6,
"max_forward_fetch": 32,
"warn_on_action_failure": False,
},
"agent_runner_type": "local",
"dify_agent_runner_provider_id": "",
"coze_agent_runner_provider_id": "",
@@ -2908,6 +2915,46 @@ CONFIG_METADATA_3 = {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.max_quoted_fallback_images": {
"description": "引用图片回退解析上限",
"type": "int",
"hint": "引用/转发消息回退解析图片时的最大注入数量,超出会截断。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_component_chain_depth": {
"description": "引用解析组件链深度",
"type": "int",
"hint": "解析 Reply 组件链时允许的最大递归深度。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_forward_node_depth": {
"description": "引用解析转发节点深度",
"type": "int",
"hint": "解析合并转发节点时允许的最大递归深度。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_forward_fetch": {
"description": "引用解析转发拉取上限",
"type": "int",
"hint": "递归拉取 get_forward_msg 的最大次数。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.warn_on_action_failure": {
"description": "引用解析 action 失败告警",
"type": "bool",
"hint": "开启后,get_msg/get_forward_msg 全部尝试失败时输出 warning 日志。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.max_agent_step": {
"description": "工具调用轮数上限",
"type": "int",
@@ -123,6 +123,7 @@ class InternalAgentSubStage(Stage):
provider_settings=settings,
subagent_orchestrator=conf.get("subagent_orchestrator", {}),
timezone=self.ctx.plugin_manager.context.get_config().get("timezone"),
max_quoted_fallback_images=settings.get("max_quoted_fallback_images", 20),
)
async def process(
+162 -8
View File
@@ -5,6 +5,7 @@ import json
import random
import re
from collections.abc import AsyncGenerator
from typing import Any
import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI
@@ -27,6 +28,7 @@ from astrbot.core.utils.network_utils import (
is_connection_error,
log_connection_failure,
)
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
from ..register import register_provider_adapter
@@ -36,6 +38,128 @@ from ..register import register_provider_adapter
"OpenAI API Chat Completion 提供商适配器",
)
class ProviderOpenAIOfficial(Provider):
_ERROR_TEXT_CANDIDATE_MAX_CHARS = 4096
@classmethod
def _truncate_error_text_candidate(cls, text: str) -> str:
if len(text) <= cls._ERROR_TEXT_CANDIDATE_MAX_CHARS:
return text
return text[: cls._ERROR_TEXT_CANDIDATE_MAX_CHARS]
@staticmethod
def _safe_json_dump(value: Any) -> str | None:
try:
return json.dumps(value, ensure_ascii=False, default=str)
except Exception:
return None
def _get_image_moderation_error_patterns(self) -> list[str]:
"""Return configured moderation patterns (case-insensitive substring match, not regex)."""
configured = self.provider_config.get("image_moderation_error_patterns", [])
patterns: list[str] = []
if isinstance(configured, str):
configured = [configured]
if isinstance(configured, list):
for pattern in configured:
if not isinstance(pattern, str):
continue
pattern = pattern.strip()
if pattern:
patterns.append(pattern)
return patterns
@staticmethod
def _extract_error_text_candidates(error: Exception) -> list[str]:
candidates: list[str] = []
def _append_candidate(candidate: Any):
if candidate is None:
return
text = str(candidate).strip()
if not text:
return
candidates.append(
ProviderOpenAIOfficial._truncate_error_text_candidate(text)
)
_append_candidate(str(error))
body = getattr(error, "body", None)
if isinstance(body, dict):
err_obj = body.get("error")
body_text = ProviderOpenAIOfficial._safe_json_dump(
{"error": err_obj} if isinstance(err_obj, dict) else body
)
_append_candidate(body_text)
if isinstance(err_obj, dict):
for field in ("message", "type", "code", "param"):
value = err_obj.get(field)
if value is not None:
_append_candidate(value)
elif isinstance(body, str):
_append_candidate(body)
response = getattr(error, "response", None)
if response is not None:
response_text = getattr(response, "text", None)
if isinstance(response_text, str):
_append_candidate(response_text)
return normalize_and_dedupe_strings(candidates)
def _is_content_moderated_upload_error(self, error: Exception) -> bool:
patterns = [
pattern.lower() for pattern in self._get_image_moderation_error_patterns()
]
if not patterns:
return False
candidates = [
candidate.lower()
for candidate in self._extract_error_text_candidates(error)
]
for pattern in patterns:
if any(pattern in candidate for candidate in candidates):
return True
return False
@staticmethod
def _context_contains_image(contexts: list[dict]) -> bool:
for context in contexts:
content = context.get("content")
if not isinstance(content, list):
continue
for item in content:
if isinstance(item, dict) and item.get("type") == "image_url":
return True
return False
async def _fallback_to_text_only_and_retry(
self,
payloads: dict,
context_query: list,
chosen_key: str,
available_api_keys: list[str],
func_tool: ToolSet | None,
reason: str,
*,
image_fallback_used: bool = False,
) -> tuple:
logger.warning(
"检测到图片请求失败(%s),已移除图片并重试(保留文本内容)。",
reason,
)
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
return (
False,
chosen_key,
available_api_keys,
payloads,
new_contexts,
func_tool,
image_fallback_used,
)
def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None:
"""创建带代理的 HTTP 客户端"""
proxy = provider_config.get("proxy", "")
@@ -403,6 +527,7 @@ class ProviderOpenAIOfficial(Provider):
available_api_keys: list[str],
retry_cnt: int,
max_retries: int,
image_fallback_used: bool = False,
) -> tuple:
"""处理API错误并尝试恢复"""
if "429" in str(e):
@@ -422,6 +547,7 @@ class ProviderOpenAIOfficial(Provider):
payloads,
context_query,
func_tool,
image_fallback_used,
)
raise e
if "maximum context length" in str(e):
@@ -437,20 +563,34 @@ class ProviderOpenAIOfficial(Provider):
payloads,
context_query,
func_tool,
image_fallback_used,
)
if "The model is not a VLM" in str(e): # siliconcloud
if image_fallback_used or not self._context_contains_image(context_query):
raise e
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
context_query = new_contexts
return (
False,
chosen_key,
available_api_keys,
return await self._fallback_to_text_only_and_retry(
payloads,
context_query,
chosen_key,
available_api_keys,
func_tool,
"model_not_vlm",
image_fallback_used=True,
)
if self._is_content_moderated_upload_error(e):
if image_fallback_used or not self._context_contains_image(context_query):
raise e
return await self._fallback_to_text_only_and_retry(
payloads,
context_query,
chosen_key,
available_api_keys,
func_tool,
"image_content_moderated",
image_fallback_used=True,
)
if (
"Function calling is not enabled" in str(e)
or ("tool" in str(e).lower() and "support" in str(e).lower())
@@ -461,7 +601,15 @@ class ProviderOpenAIOfficial(Provider):
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。",
)
payloads.pop("tools", None)
return False, chosen_key, available_api_keys, payloads, context_query, None
return (
False,
chosen_key,
available_api_keys,
payloads,
context_query,
None,
image_fallback_used,
)
# logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
if "tool" in str(e).lower() and "support" in str(e).lower():
@@ -501,6 +649,7 @@ class ProviderOpenAIOfficial(Provider):
max_retries = 10
available_api_keys = self.api_keys.copy()
chosen_key = random.choice(available_api_keys)
image_fallback_used = False
last_exception = None
retry_cnt = 0
@@ -518,6 +667,7 @@ class ProviderOpenAIOfficial(Provider):
payloads,
context_query,
func_tool,
image_fallback_used,
) = await self._handle_api_error(
e,
payloads,
@@ -527,6 +677,7 @@ class ProviderOpenAIOfficial(Provider):
available_api_keys,
retry_cnt,
max_retries,
image_fallback_used=image_fallback_used,
)
if success:
break
@@ -564,6 +715,7 @@ class ProviderOpenAIOfficial(Provider):
max_retries = 10
available_api_keys = self.api_keys.copy()
chosen_key = random.choice(available_api_keys)
image_fallback_used = False
last_exception = None
retry_cnt = 0
@@ -582,6 +734,7 @@ class ProviderOpenAIOfficial(Provider):
payloads,
context_query,
func_tool,
image_fallback_used,
) = await self._handle_api_error(
e,
payloads,
@@ -591,6 +744,7 @@ class ProviderOpenAIOfficial(Provider):
available_api_keys,
retry_cnt,
max_retries,
image_fallback_used=image_fallback_used,
)
if success:
break
@@ -0,0 +1,8 @@
from __future__ import annotations
from .extractor import extract_quoted_message_images, extract_quoted_message_text
__all__ = [
"extract_quoted_message_text",
"extract_quoted_message_images",
]
@@ -0,0 +1,505 @@
from __future__ import annotations
import json
import re
from typing import Any, TypedDict
from astrbot.core.message.components import (
At,
AtAll,
File,
Forward,
Image,
Node,
Nodes,
Plain,
Reply,
Video,
)
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
from .image_refs import looks_like_image_file_name, normalize_file_like_url
from .settings import SETTINGS, QuotedMessageParserSettings
_FORWARD_PLACEHOLDER_PATTERN = re.compile(
r"^(?:[\(\[]?[^\]:\)]*[\)\]]?\s*:\s*)?\[(?:forward message|转发消息|合并转发)\]$",
flags=re.IGNORECASE,
)
class ParsedOneBotPayload(TypedDict):
text: str | None
forward_ids: list[str]
image_refs: list[str]
def _build_parsed_payload(
text: str | None,
forward_ids: list[str] | None = None,
image_refs: list[str] | None = None,
) -> ParsedOneBotPayload:
return {
"text": text,
"forward_ids": forward_ids or [],
"image_refs": image_refs or [],
}
def _join_text_parts(parts: list[str]) -> str | None:
text = "".join(parts).strip()
return text or None
def _find_first_reply_component(event: AstrMessageEvent) -> Reply | None:
for comp in event.message_obj.message:
if isinstance(comp, Reply):
return comp
return None
def _is_forward_placeholder_only_text(text: str | None) -> bool:
if not isinstance(text, str):
return False
lines = [line.strip() for line in text.splitlines() if line.strip()]
if not lines:
return False
return all(_FORWARD_PLACEHOLDER_PATTERN.match(line) for line in lines)
def _extract_image_refs_from_component_chain(
chain: list[Any] | None,
*,
depth: int = 0,
settings: QuotedMessageParserSettings = SETTINGS,
) -> list[str]:
if not isinstance(chain, list) or depth > settings.max_component_chain_depth:
return []
image_refs: list[str] = []
for seg in chain:
if isinstance(seg, Image):
for candidate in (seg.url, seg.file, seg.path):
if isinstance(candidate, str) and candidate.strip():
image_refs.append(candidate.strip())
break
elif isinstance(seg, Reply):
image_refs.extend(
_extract_image_refs_from_reply_component(
seg,
depth=depth + 1,
settings=settings,
)
)
elif isinstance(seg, Node):
image_refs.extend(
_extract_image_refs_from_component_chain(
seg.content,
depth=depth + 1,
settings=settings,
)
)
elif isinstance(seg, Nodes):
for node in seg.nodes:
image_refs.extend(
_extract_image_refs_from_component_chain(
node.content,
depth=depth + 1,
settings=settings,
)
)
return normalize_and_dedupe_strings(image_refs)
def _extract_text_from_component_chain(
chain: list[Any] | None,
*,
depth: int = 0,
settings: QuotedMessageParserSettings = SETTINGS,
) -> str | None:
if not isinstance(chain, list) or depth > settings.max_component_chain_depth:
return None
parts: list[str] = []
for seg in chain:
if isinstance(seg, Plain):
if seg.text:
parts.append(seg.text)
elif isinstance(seg, At):
if seg.name:
parts.append(f"@{seg.name}")
elif seg.qq:
parts.append(f"@{seg.qq}")
elif isinstance(seg, AtAll):
parts.append("@all")
elif isinstance(seg, Image):
parts.append("[Image]")
elif isinstance(seg, Video):
parts.append("[Video]")
elif isinstance(seg, File):
file_name = seg.name or "file"
parts.append(f"[File:{file_name}]")
elif isinstance(seg, Forward):
parts.append("[Forward Message]")
elif isinstance(seg, Reply):
nested = _extract_text_from_reply_component(
seg,
depth=depth + 1,
settings=settings,
)
if nested:
parts.append(nested)
elif isinstance(seg, Node):
node_sender = seg.name or seg.uin or "Unknown User"
node_text = _extract_text_from_component_chain(
seg.content,
depth=depth + 1,
settings=settings,
)
if node_text:
parts.append(f"{node_sender}: {node_text}")
elif isinstance(seg, Nodes):
for node in seg.nodes:
node_sender = node.name or node.uin or "Unknown User"
node_text = _extract_text_from_component_chain(
node.content,
depth=depth + 1,
settings=settings,
)
if node_text:
parts.append(f"{node_sender}: {node_text}")
return _join_text_parts(parts)
def _extract_image_refs_from_reply_component(
reply: Reply,
*,
depth: int = 0,
settings: QuotedMessageParserSettings = SETTINGS,
) -> list[str]:
for attr in ("chain", "message", "origin", "content"):
payload = getattr(reply, attr, None)
image_refs = _extract_image_refs_from_component_chain(
payload,
depth=depth,
settings=settings,
)
if image_refs:
return image_refs
return []
def _extract_text_from_reply_component(
reply: Reply,
*,
depth: int = 0,
settings: QuotedMessageParserSettings = SETTINGS,
) -> str | None:
for attr in ("chain", "message", "origin", "content"):
payload = getattr(reply, attr, None)
text = _extract_text_from_component_chain(
payload,
depth=depth,
settings=settings,
)
if text:
return text
if reply.message_str and reply.message_str.strip():
return reply.message_str.strip()
return None
def _unwrap_onebot_data(payload: Any) -> dict[str, Any]:
if not isinstance(payload, dict):
return {}
data = payload.get("data")
if isinstance(data, dict):
return data
return payload
def _extract_text_from_multimsg_json(raw_json: str) -> str | None:
try:
parsed = json.loads(raw_json)
except Exception:
return None
if not isinstance(parsed, dict):
return None
if parsed.get("app") != "com.tencent.multimsg":
return None
config = parsed.get("config")
if not isinstance(config, dict):
return None
if config.get("forward") != 1:
return None
meta = parsed.get("meta")
if not isinstance(meta, dict):
return None
detail = meta.get("detail")
if not isinstance(detail, dict):
return None
news_items = detail.get("news")
if not isinstance(news_items, list):
return None
texts: list[str] = []
for item in news_items:
if not isinstance(item, dict):
continue
text_content = item.get("text")
if not isinstance(text_content, str):
continue
cleaned = text_content.strip().replace("[图片]", "").strip()
if cleaned:
texts.append(cleaned)
return "\n".join(texts).strip() or None
def _parse_onebot_segments(
segments: list[Any],
*,
settings: QuotedMessageParserSettings = SETTINGS,
) -> ParsedOneBotPayload:
text_parts: list[str] = []
forward_ids: list[str] = []
image_refs: list[str] = []
for seg in segments:
if not isinstance(seg, dict):
continue
seg_type = seg.get("type")
seg_data = seg.get("data", {}) if isinstance(seg.get("data"), dict) else {}
if seg_type in ("text", "plain"):
text = seg_data.get("text")
if isinstance(text, str) and text:
text_parts.append(text)
elif seg_type == "image":
text_parts.append("[Image]")
candidate = seg_data.get("url") or seg_data.get("file")
if isinstance(candidate, str) and candidate.strip():
image_refs.append(candidate.strip())
elif seg_type == "video":
text_parts.append("[Video]")
elif seg_type == "file":
file_name = (
seg_data.get("name")
or seg_data.get("file_name")
or seg_data.get("file")
or "file"
)
text_parts.append(f"[File:{file_name}]")
candidate_url = seg_data.get("url")
if (
isinstance(candidate_url, str)
and candidate_url.strip()
and looks_like_image_file_name(normalize_file_like_url(candidate_url))
):
image_refs.append(candidate_url.strip())
candidate_file = seg_data.get("file")
if (
isinstance(candidate_file, str)
and candidate_file.strip()
and looks_like_image_file_name(
normalize_file_like_url(
seg_data.get("name")
or seg_data.get("file_name")
or candidate_file
)
)
):
image_refs.append(candidate_file.strip())
elif seg_type in ("forward", "forward_msg", "nodes"):
fid = seg_data.get("id") or seg_data.get("message_id")
if isinstance(fid, (str, int)) and str(fid):
forward_ids.append(str(fid))
else:
nested_nodes = seg_data.get("content")
nested_text, nested_forward_ids, nested_images = (
_extract_text_forward_ids_and_images_from_forward_nodes(
nested_nodes if isinstance(nested_nodes, list) else [],
depth=1,
settings=settings,
)
)
if nested_text:
text_parts.append(nested_text)
if nested_forward_ids:
forward_ids.extend(nested_forward_ids)
if nested_images:
image_refs.extend(nested_images)
elif seg_type == "json":
raw_json = seg_data.get("data")
if isinstance(raw_json, str) and raw_json.strip():
raw_json = raw_json.replace("&#44;", ",")
multimsg_text = _extract_text_from_multimsg_json(raw_json)
if multimsg_text:
text_parts.append(multimsg_text)
return _build_parsed_payload(
_join_text_parts(text_parts),
forward_ids,
normalize_and_dedupe_strings(image_refs),
)
def _extract_text_forward_ids_and_images_from_forward_nodes(
nodes: list[Any],
*,
depth: int = 0,
settings: QuotedMessageParserSettings = SETTINGS,
) -> tuple[str | None, list[str], list[str]]:
if not isinstance(nodes, list) or depth > settings.max_forward_node_depth:
return None, [], []
texts: list[str] = []
forward_ids: list[str] = []
image_refs: list[str] = []
indent = " " * depth
for node in nodes:
if not isinstance(node, dict):
continue
sender = node.get("sender") if isinstance(node.get("sender"), dict) else {}
sender_name = (
sender.get("nickname")
or sender.get("card")
or sender.get("user_id")
or "Unknown User"
)
raw_content = node.get("message") or node.get("content") or []
chain: list[Any] = []
if isinstance(raw_content, list):
chain = raw_content
elif isinstance(raw_content, str):
raw_content = raw_content.strip()
if raw_content:
try:
parsed = json.loads(raw_content)
except Exception:
parsed = None
if isinstance(parsed, list):
chain = parsed
else:
chain = [{"type": "text", "data": {"text": raw_content}}]
parsed_segments = _parse_onebot_segments(chain, settings=settings)
node_text = parsed_segments["text"]
node_forward_ids = parsed_segments["forward_ids"]
node_images = parsed_segments["image_refs"]
if node_text:
texts.append(f"{indent}{sender_name}: {node_text}")
if node_forward_ids:
forward_ids.extend(node_forward_ids)
if node_images:
image_refs.extend(node_images)
return (
"\n".join(texts).strip() or None,
normalize_and_dedupe_strings(forward_ids),
normalize_and_dedupe_strings(image_refs),
)
def _parse_onebot_get_msg_payload(
payload: dict[str, Any],
*,
settings: QuotedMessageParserSettings = SETTINGS,
) -> ParsedOneBotPayload:
data = _unwrap_onebot_data(payload)
segments = data.get("message") or data.get("messages")
if isinstance(segments, list):
return _parse_onebot_segments(segments, settings=settings)
text: str | None = None
if isinstance(segments, str) and segments.strip():
text = segments.strip()
else:
raw = data.get("raw_message")
if isinstance(raw, str) and raw.strip():
text = raw.strip()
return _build_parsed_payload(text)
def _parse_onebot_get_forward_payload(
payload: dict[str, Any],
*,
settings: QuotedMessageParserSettings = SETTINGS,
) -> ParsedOneBotPayload:
data = _unwrap_onebot_data(payload)
nodes = (
data.get("messages")
or data.get("message")
or data.get("nodes")
or data.get("nodeList")
)
if not isinstance(nodes, list):
return _build_parsed_payload(None)
text, forward_ids, image_refs = (
_extract_text_forward_ids_and_images_from_forward_nodes(
nodes,
settings=settings,
)
)
return _build_parsed_payload(text, forward_ids, image_refs)
class ReplyChainParser:
def __init__(self, settings: QuotedMessageParserSettings = SETTINGS):
self._settings = settings
@staticmethod
def find_first_reply_component(event: AstrMessageEvent) -> Reply | None:
return _find_first_reply_component(event)
@staticmethod
def is_forward_placeholder_only_text(text: str | None) -> bool:
return _is_forward_placeholder_only_text(text)
def extract_text_from_reply_component(
self,
reply: Reply,
*,
depth: int = 0,
) -> str | None:
return _extract_text_from_reply_component(
reply,
depth=depth,
settings=self._settings,
)
def extract_image_refs_from_reply_component(
self,
reply: Reply,
*,
depth: int = 0,
) -> list[str]:
return _extract_image_refs_from_reply_component(
reply,
depth=depth,
settings=self._settings,
)
class OneBotPayloadParser:
def __init__(self, settings: QuotedMessageParserSettings = SETTINGS):
self._settings = settings
def parse_get_msg_payload(self, payload: dict[str, Any]) -> ParsedOneBotPayload:
return _parse_onebot_get_msg_payload(payload, settings=self._settings)
def parse_get_forward_payload(
self,
payload: dict[str, Any],
) -> ParsedOneBotPayload:
return _parse_onebot_get_forward_payload(payload, settings=self._settings)
@@ -0,0 +1,211 @@
from __future__ import annotations
from dataclasses import dataclass
from astrbot import logger
from astrbot.core.message.components import Reply
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
from .chain_parser import OneBotPayloadParser, ReplyChainParser
from .image_resolver import ImageResolver
from .onebot_client import OneBotClient
from .settings import SETTINGS, QuotedMessageParserSettings
async def _collect_text_and_images_from_forward_ids(
onebot_client: OneBotClient,
payload_parser: OneBotPayloadParser,
forward_ids: list[str],
*,
max_fetch: int,
) -> tuple[list[str], list[str]]:
texts: list[str] = []
image_refs: list[str] = []
pending: list[str] = []
seen: set[str] = set()
for fid in forward_ids:
if not isinstance(fid, str):
continue
cleaned = fid.strip()
if cleaned:
pending.append(cleaned)
fetch_count = 0
while pending and fetch_count < max_fetch:
current_id = pending.pop(0)
if current_id in seen:
continue
seen.add(current_id)
fetch_count += 1
forward_payload = await onebot_client.get_forward_msg(current_id)
if not forward_payload:
continue
parsed = payload_parser.parse_get_forward_payload(forward_payload)
if parsed["text"]:
texts.append(parsed["text"])
if parsed["image_refs"]:
image_refs.extend(parsed["image_refs"])
for nested_id in parsed["forward_ids"]:
if nested_id not in seen:
pending.append(nested_id)
if pending:
logger.warning(
"quoted_message_parser: stop fetching nested forward messages after %d hops",
max_fetch,
)
return texts, normalize_and_dedupe_strings(image_refs)
@dataclass(slots=True)
class QuotedMessageContent:
embedded_text: str | None
embedded_image_refs: list[str]
reply_id: str
direct_text: str | None
direct_image_refs: list[str]
forward_texts: list[str]
forward_image_refs: list[str]
class QuotedMessageExtractor:
def __init__(
self,
event: AstrMessageEvent,
settings: QuotedMessageParserSettings = SETTINGS,
):
self._event = event
self._settings = settings
self._reply_parser = ReplyChainParser(settings=settings)
self._payload_parser = OneBotPayloadParser(settings=settings)
self._client = OneBotClient(event, settings=settings)
self._image_resolver = ImageResolver(event, self._client)
async def _fetch_quoted_content(
self,
reply_component: Reply | None = None,
*,
fetch_remote: bool,
) -> QuotedMessageContent | None:
reply = reply_component or self._reply_parser.find_first_reply_component(
self._event
)
if not reply:
return None
embedded_text = self._reply_parser.extract_text_from_reply_component(reply)
embedded_image_refs = list(
self._reply_parser.extract_image_refs_from_reply_component(reply)
)
reply_id = getattr(reply, "id", None)
reply_id_str = str(reply_id).strip() if reply_id is not None else ""
if not fetch_remote or not reply_id_str:
return QuotedMessageContent(
embedded_text=embedded_text,
embedded_image_refs=embedded_image_refs,
reply_id=reply_id_str,
direct_text=None,
direct_image_refs=[],
forward_texts=[],
forward_image_refs=[],
)
msg_payload = await self._client.get_msg(reply_id_str)
if not msg_payload:
return QuotedMessageContent(
embedded_text=embedded_text,
embedded_image_refs=embedded_image_refs,
reply_id=reply_id_str,
direct_text=None,
direct_image_refs=[],
forward_texts=[],
forward_image_refs=[],
)
parsed = self._payload_parser.parse_get_msg_payload(msg_payload)
forward_texts, forward_images = await _collect_text_and_images_from_forward_ids(
self._client,
self._payload_parser,
parsed["forward_ids"],
max_fetch=self._settings.max_forward_fetch,
)
return QuotedMessageContent(
embedded_text=embedded_text,
embedded_image_refs=embedded_image_refs,
reply_id=reply_id_str,
direct_text=parsed["text"],
direct_image_refs=list(parsed["image_refs"]),
forward_texts=forward_texts,
forward_image_refs=forward_images,
)
async def text(self, reply_component: Reply | None = None) -> str | None:
embedded_content = await self._fetch_quoted_content(
reply_component,
fetch_remote=False,
)
if not embedded_content:
return None
if (
embedded_content.embedded_text
and not self._reply_parser.is_forward_placeholder_only_text(
embedded_content.embedded_text
)
):
return embedded_content.embedded_text
if not embedded_content.reply_id:
return embedded_content.embedded_text
fetched_content = await self._fetch_quoted_content(
reply_component,
fetch_remote=True,
)
if not fetched_content:
return embedded_content.embedded_text
text_parts: list[str] = []
if fetched_content.direct_text:
text_parts.append(fetched_content.direct_text)
text_parts.extend(fetched_content.forward_texts)
return "\n".join(text_parts).strip() or embedded_content.embedded_text
async def images(self, reply_component: Reply | None = None) -> list[str]:
content = await self._fetch_quoted_content(reply_component, fetch_remote=True)
if not content:
return []
image_refs: list[str] = []
image_refs.extend(content.embedded_image_refs)
image_refs.extend(content.direct_image_refs)
image_refs.extend(content.forward_image_refs)
return await self._image_resolver.resolve_for_llm(image_refs)
async def extract_quoted_message_text(
event: AstrMessageEvent,
reply_component: Reply | None = None,
settings: QuotedMessageParserSettings | None = None,
) -> str | None:
return await QuotedMessageExtractor(event, settings=settings or SETTINGS).text(
reply_component
)
async def extract_quoted_message_images(
event: AstrMessageEvent,
reply_component: Reply | None = None,
settings: QuotedMessageParserSettings | None = None,
) -> list[str]:
return await QuotedMessageExtractor(event, settings=settings or SETTINGS).images(
reply_component
)
@@ -0,0 +1,94 @@
from __future__ import annotations
import os
from urllib.parse import urlsplit
IMAGE_EXTENSIONS = {
".jpg",
".jpeg",
".png",
".webp",
".bmp",
".tif",
".tiff",
".gif",
}
def normalize_file_like_url(path: str | None) -> str | None:
if path is None:
return None
if not isinstance(path, str):
return None
if "?" not in path and "#" not in path:
return path
try:
split = urlsplit(path)
except Exception:
return path
return split.path or path
def looks_like_image_file_name(name: str) -> bool:
normalized_name = normalize_file_like_url(name)
if not isinstance(normalized_name, str) or not normalized_name.strip():
return False
_, ext = os.path.splitext(normalized_name.strip().lower())
return ext in IMAGE_EXTENSIONS
def convert_data_image_to_base64_ref(image_ref: str) -> str | None:
if not isinstance(image_ref, str):
return None
value = image_ref.strip()
if not value:
return None
lower_value = value.lower()
if not lower_value.startswith("data:image/"):
return None
comma_index = value.find(",")
if comma_index <= 0:
return None
header = value[:comma_index].lower()
payload = value[comma_index + 1 :].strip()
if ";base64" not in header or not payload:
return None
return f"base64://{payload}"
def get_existing_local_path(value: str) -> str | None:
lower_value = value.lower()
if lower_value.startswith("file://"):
file_path = value[7:]
if file_path.startswith("/") and len(file_path) > 3 and file_path[2] == ":":
file_path = file_path[1:]
if file_path and os.path.exists(file_path):
return os.path.abspath(file_path)
return None
if os.path.exists(value):
return os.path.abspath(value)
return None
def normalize_image_ref(image_ref: str) -> str | None:
if not isinstance(image_ref, str):
return None
value = image_ref.strip()
if not value:
return None
lower_value = value.lower()
if lower_value.startswith(("http://", "https://")):
return value
if lower_value.startswith("base64://"):
return value
data_image_ref = convert_data_image_to_base64_ref(value)
if data_image_ref:
return data_image_ref
local_path = get_existing_local_path(value)
if local_path and looks_like_image_file_name(local_path):
return local_path
return None
@@ -0,0 +1,130 @@
from __future__ import annotations
import os
from typing import Any
from astrbot import logger
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
from .image_refs import IMAGE_EXTENSIONS, get_existing_local_path, normalize_image_ref
from .onebot_client import OneBotClient
def _build_image_id_candidates(image_ref: str) -> list[str]:
candidates: list[str] = [image_ref]
base_name, ext = os.path.splitext(image_ref)
if ext and base_name and base_name not in candidates:
if ext.lower() in IMAGE_EXTENSIONS:
candidates.append(base_name)
return candidates
def _build_image_resolve_actions(
event: AstrMessageEvent,
image_ref: str,
) -> list[tuple[str, dict[str, Any]]]:
actions: list[tuple[str, dict[str, Any]]] = []
candidates = _build_image_id_candidates(image_ref)
for candidate in candidates:
actions.extend(
[
("get_image", {"file": candidate}),
("get_image", {"file_id": candidate}),
("get_image", {"id": candidate}),
("get_image", {"image": candidate}),
("get_file", {"file_id": candidate}),
("get_file", {"file": candidate}),
]
)
try:
group_id = event.get_group_id()
except Exception:
group_id = None
group_id_value = group_id
if isinstance(group_id, str) and group_id.isdigit():
group_id_value = int(group_id)
if group_id_value:
for candidate in candidates:
actions.append(
(
"get_group_file_url",
{"group_id": group_id_value, "file_id": candidate},
)
)
for candidate in candidates:
actions.append(("get_private_file_url", {"file_id": candidate}))
return actions
class ImageResolver:
def __init__(
self,
event: AstrMessageEvent,
onebot_client: OneBotClient | None = None,
):
self._event = event
self._client = onebot_client or OneBotClient(event)
async def resolve_for_llm(self, image_refs: list[str]) -> list[str]:
resolved: list[str] = []
unresolved: list[str] = []
for image_ref in normalize_and_dedupe_strings(image_refs):
normalized = normalize_image_ref(image_ref)
if normalized:
resolved.append(normalized)
elif get_existing_local_path(image_ref):
# Drop non-image local paths instead of treating them as remote IDs.
logger.debug(
"quoted_message_parser: skip non-image local path ref=%s",
image_ref[:128],
)
else:
unresolved.append(image_ref)
for image_ref in unresolved:
resolved_ref = await self._resolve_one(image_ref)
if resolved_ref:
resolved.append(resolved_ref)
return normalize_and_dedupe_strings(resolved)
async def _resolve_one(self, image_ref: str) -> str | None:
resolved = normalize_image_ref(image_ref)
if resolved:
return resolved
actions = _build_image_resolve_actions(self._event, image_ref)
for action, params in actions:
data = await self._client.call(
action,
params,
warn_on_all_failed=False,
unwrap_data=True,
)
if not isinstance(data, dict):
continue
url = data.get("url")
if isinstance(url, str):
normalized = normalize_image_ref(url)
if normalized:
return normalized
file_value = data.get("file")
if isinstance(file_value, str):
normalized = normalize_image_ref(file_value)
if normalized:
return normalized
logger.warning(
"quoted_message_parser: failed to resolve quoted image ref=%s after %d actions",
image_ref[:128],
len(actions),
)
return None
@@ -0,0 +1,119 @@
from __future__ import annotations
from typing import Any
from astrbot import logger
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .settings import SETTINGS, QuotedMessageParserSettings
def _unwrap_action_response(ret: dict[str, Any] | None) -> dict[str, Any]:
if not isinstance(ret, dict):
return {}
data = ret.get("data")
if isinstance(data, dict):
return data
return ret
class OneBotClient:
def __init__(
self,
event: AstrMessageEvent,
settings: QuotedMessageParserSettings = SETTINGS,
):
self._call_action = self._resolve_call_action(event)
self._settings = settings
@staticmethod
def _resolve_call_action(event: AstrMessageEvent):
bot = getattr(event, "bot", None)
api = getattr(bot, "api", None)
call_action = getattr(api, "call_action", None)
if not callable(call_action):
return None
return call_action
async def _call_action_try_params(
self,
action: str,
params_list: list[dict[str, Any]],
*,
warn_on_all_failed: bool | None = None,
) -> dict[str, Any] | None:
if self._call_action is None:
return None
if warn_on_all_failed is None:
warn_on_all_failed = self._settings.warn_on_action_failure
last_error: Exception | None = None
last_params: dict[str, Any] | None = None
for params in params_list:
try:
result = await self._call_action(action, **params)
if isinstance(result, dict):
return result
except Exception as exc:
last_error = exc
last_params = params
logger.debug(
"quoted_message_parser: action %s failed with params %s: %s",
action,
{k: str(v)[:64] for k, v in params.items()},
exc,
)
if warn_on_all_failed and last_error is not None:
logger.warning(
"quoted_message_parser: all attempts failed for action %s, "
"last_params=%s, error=%s",
action,
(
{k: str(v)[:64] for k, v in last_params.items()}
if isinstance(last_params, dict)
else None
),
last_error,
)
return None
async def call(
self,
action: str,
params: dict[str, Any],
*,
warn_on_all_failed: bool = False,
unwrap_data: bool = True,
) -> dict[str, Any] | None:
ret = await self._call_action_try_params(
action,
[params],
warn_on_all_failed=warn_on_all_failed,
)
if not unwrap_data:
return ret
return _unwrap_action_response(ret)
async def _call_action_compat(
self,
action: str,
message_id: str | int,
) -> dict[str, Any] | None:
message_id_str = str(message_id).strip()
if not message_id_str:
return None
params_list: list[dict[str, Any]] = [
{"message_id": message_id_str},
{"id": message_id_str},
]
if message_id_str.isdigit():
int_id = int(message_id_str)
params_list.extend([{"message_id": int_id}, {"id": int_id}])
return await self._call_action_try_params(action, params_list)
async def get_msg(self, message_id: str | int) -> dict[str, Any] | None:
return await self._call_action_compat("get_msg", message_id)
async def get_forward_msg(self, forward_id: str | int) -> dict[str, Any] | None:
return await self._call_action_compat("get_forward_msg", forward_id)
@@ -0,0 +1,85 @@
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any
_DEFAULT_MAX_COMPONENT_CHAIN_DEPTH = 4
_DEFAULT_MAX_FORWARD_NODE_DEPTH = 6
_DEFAULT_MAX_FORWARD_FETCH = 32
def _read_int_mapping(
mapping: Mapping[str, Any],
key: str,
default: int,
) -> int:
raw = mapping.get(key)
if raw is None:
return default
try:
value = int(raw)
except (TypeError, ValueError):
return default
if value <= 0:
return default
return value
def _read_bool_mapping(
mapping: Mapping[str, Any],
key: str,
default: bool,
) -> bool:
raw = mapping.get(key)
if raw is None:
return default
if isinstance(raw, bool):
return raw
if isinstance(raw, str):
lowered = raw.strip().lower()
if lowered in {"1", "true", "yes", "on"}:
return True
if lowered in {"0", "false", "no", "off"}:
return False
return default
@dataclass(frozen=True)
class QuotedMessageParserSettings:
max_component_chain_depth: int = _DEFAULT_MAX_COMPONENT_CHAIN_DEPTH
max_forward_node_depth: int = _DEFAULT_MAX_FORWARD_NODE_DEPTH
max_forward_fetch: int = _DEFAULT_MAX_FORWARD_FETCH
warn_on_action_failure: bool = False
def with_overrides(
self,
overrides: Mapping[str, Any] | None = None,
) -> QuotedMessageParserSettings:
if not overrides:
return self
return QuotedMessageParserSettings(
max_component_chain_depth=_read_int_mapping(
overrides,
"max_component_chain_depth",
self.max_component_chain_depth,
),
max_forward_node_depth=_read_int_mapping(
overrides,
"max_forward_node_depth",
self.max_forward_node_depth,
),
max_forward_fetch=_read_int_mapping(
overrides,
"max_forward_fetch",
self.max_forward_fetch,
),
warn_on_action_failure=_read_bool_mapping(
overrides,
"warn_on_action_failure",
self.warn_on_action_failure,
),
)
SETTINGS = QuotedMessageParserSettings()
@@ -0,0 +1,11 @@
from __future__ import annotations
from astrbot.core.utils.quoted_message.extractor import (
extract_quoted_message_images,
extract_quoted_message_text,
)
__all__ = [
"extract_quoted_message_text",
"extract_quoted_message_images",
]
+21
View File
@@ -0,0 +1,21 @@
from __future__ import annotations
from collections.abc import Iterable
from typing import Any
def normalize_and_dedupe_strings(items: Iterable[Any] | None) -> list[str]:
if items is None:
return []
normalized: list[str] = []
seen: set[str] = set()
for item in items:
if not isinstance(item, str):
continue
cleaned = item.strip()
if not cleaned or cleaned in seen:
continue
seen.add(cleaned)
normalized.append(cleaned)
return normalized
@@ -247,6 +247,28 @@
"description": "Sanitize History by Modalities",
"hint": "When enabled, sanitizes contexts before each LLM request by removing image blocks and tool-call structures that the current provider's modalities do not support (this changes what the model sees)."
},
"max_quoted_fallback_images": {
"description": "Forwarded Image Fetch Limit",
"hint": "Maximum number of images injected from forwarded-message parsing; extra images are truncated."
},
"quoted_message_parser": {
"max_component_chain_depth": {
"description": "Forwarded Rich-Text Parse Depth",
"hint": "Maximum recursive depth when parsing rich-text component chains inside forwarded messages."
},
"max_forward_node_depth": {
"description": "Forward Nesting Parse Depth",
"hint": "Maximum recursive depth when parsing nested forwarded nodes."
},
"max_forward_fetch": {
"description": "Forward Recursive Fetch Limit",
"hint": "Maximum number of recursive get_forward_msg fetch operations."
},
"warn_on_action_failure": {
"description": "Warn on Forward Parse Failure",
"hint": "When enabled, log warnings when all get_msg/get_forward_msg attempts fail."
}
},
"max_agent_step": {
"description": "Maximum Tool Call Rounds"
},
@@ -250,6 +250,28 @@
"description": "按模型能力清理历史上下文",
"hint": "开启后,在每次请求 LLM 前会按当前模型提供商中所选择的模型能力删除对话中不支持的图片/工具调用结构(会改变模型看到的历史)"
},
"max_quoted_fallback_images": {
"description": "转发消息中图片获取上限",
"hint": "转发消息解析到的图片最多注入数量,超出部分会截断。"
},
"quoted_message_parser": {
"max_component_chain_depth": {
"description": "转发消息富文本解析深度",
"hint": "解析转发消息中的富文本组件链时允许的最大递归深度。"
},
"max_forward_node_depth": {
"description": "转发消息嵌套解析深度",
"hint": "解析嵌套转发节点时允许的最大递归深度。"
},
"max_forward_fetch": {
"description": "转发消息递归拉取上限",
"hint": "递归调用 get_forward_msg 拉取转发内容的最大次数。"
},
"warn_on_action_failure": {
"description": "转发消息解析失败告警",
"hint": "开启后,get_msg/get_forward_msg 全部尝试失败时输出 warning 日志。"
}
},
"max_agent_step": {
"description": "工具调用轮数上限"
},
+382
View File
@@ -0,0 +1,382 @@
from types import SimpleNamespace
import pytest
from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial
class _ErrorWithBody(Exception):
def __init__(self, message: str, body: dict):
super().__init__(message)
self.body = body
class _ErrorWithResponse(Exception):
def __init__(self, message: str, response_text: str):
super().__init__(message)
self.response = SimpleNamespace(text=response_text)
def _make_provider(overrides: dict | None = None) -> ProviderOpenAIOfficial:
provider_config = {
"id": "test-openai",
"type": "openai_chat_completion",
"model": "gpt-4o-mini",
"key": ["test-key"],
}
if overrides:
provider_config.update(overrides)
return ProviderOpenAIOfficial(
provider_config=provider_config,
provider_settings={},
)
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_removes_images():
provider = _make_provider(
{"image_moderation_error_patterns": ["file:content-moderated"]}
)
try:
payloads = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "hello"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,abcd"},
},
],
}
]
}
context_query = payloads["messages"]
success, *_rest = await provider._handle_api_error(
Exception("Content is moderated [WKE=file:content-moderated]"),
payloads=payloads,
context_query=context_query,
func_tool=None,
chosen_key="test-key",
available_api_keys=["test-key"],
retry_cnt=0,
max_retries=10,
)
assert success is False
updated_context = payloads["messages"]
assert isinstance(updated_context, list)
assert updated_context[0]["content"] == [{"type": "text", "text": "hello"}]
finally:
await provider.terminate()
@pytest.mark.asyncio
async def test_handle_api_error_model_not_vlm_removes_images_and_retries_text_only():
provider = _make_provider()
try:
payloads = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "hello"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,abcd"},
},
],
}
]
}
context_query = payloads["messages"]
success, *_rest = await provider._handle_api_error(
Exception("The model is not a VLM and cannot process images"),
payloads=payloads,
context_query=context_query,
func_tool=None,
chosen_key="test-key",
available_api_keys=["test-key"],
retry_cnt=0,
max_retries=10,
)
assert success is False
updated_context = payloads["messages"]
assert isinstance(updated_context, list)
assert updated_context[0]["content"] == [{"type": "text", "text": "hello"}]
finally:
await provider.terminate()
@pytest.mark.asyncio
async def test_handle_api_error_model_not_vlm_after_fallback_raises():
provider = _make_provider()
try:
payloads = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "hello"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,abcd"},
},
],
}
]
}
context_query = payloads["messages"]
with pytest.raises(Exception, match="not a VLM"):
await provider._handle_api_error(
Exception("The model is not a VLM and cannot process images"),
payloads=payloads,
context_query=context_query,
func_tool=None,
chosen_key="test-key",
available_api_keys=["test-key"],
retry_cnt=1,
max_retries=10,
image_fallback_used=True,
)
finally:
await provider.terminate()
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_with_unserializable_body():
provider = _make_provider({"image_moderation_error_patterns": ["blocked"]})
try:
payloads = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "hello"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,abcd"},
},
],
}
]
}
context_query = payloads["messages"]
err = _ErrorWithBody(
"upstream error",
{"error": {"message": "blocked"}, "raw": object()},
)
success, *_rest = await provider._handle_api_error(
err,
payloads=payloads,
context_query=context_query,
func_tool=None,
chosen_key="test-key",
available_api_keys=["test-key"],
retry_cnt=0,
max_retries=10,
)
assert success is False
assert payloads["messages"][0]["content"] == [{"type": "text", "text": "hello"}]
finally:
await provider.terminate()
def test_extract_error_text_candidates_truncates_long_response_text():
long_text = "x" * 20000
err = _ErrorWithResponse("upstream error", long_text)
candidates = ProviderOpenAIOfficial._extract_error_text_candidates(err)
assert candidates
assert max(len(candidate) for candidate in candidates) <= (
ProviderOpenAIOfficial._ERROR_TEXT_CANDIDATE_MAX_CHARS
)
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_without_images_raises():
provider = _make_provider(
{"image_moderation_error_patterns": ["file:content-moderated"]}
)
try:
payloads = {
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "hello"}],
}
]
}
context_query = payloads["messages"]
err = Exception("Content is moderated [WKE=file:content-moderated]")
with pytest.raises(Exception, match="content-moderated"):
await provider._handle_api_error(
err,
payloads=payloads,
context_query=context_query,
func_tool=None,
chosen_key="test-key",
available_api_keys=["test-key"],
retry_cnt=0,
max_retries=10,
)
finally:
await provider.terminate()
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_detects_structured_body():
provider = _make_provider(
{"image_moderation_error_patterns": ["content_moderated"]}
)
try:
payloads = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "hello"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,abcd"},
},
],
}
]
}
context_query = payloads["messages"]
err = _ErrorWithBody(
"upstream error",
{"error": {"code": "content_moderated", "message": "blocked"}},
)
success, *_rest = await provider._handle_api_error(
err,
payloads=payloads,
context_query=context_query,
func_tool=None,
chosen_key="test-key",
available_api_keys=["test-key"],
retry_cnt=0,
max_retries=10,
)
assert success is False
assert payloads["messages"][0]["content"] == [{"type": "text", "text": "hello"}]
finally:
await provider.terminate()
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_supports_custom_patterns():
provider = _make_provider(
{"image_moderation_error_patterns": ["blocked_by_policy_code_123"]}
)
try:
payloads = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "hello"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,abcd"},
},
],
}
]
}
context_query = payloads["messages"]
err = Exception("upstream: blocked_by_policy_code_123")
success, *_rest = await provider._handle_api_error(
err,
payloads=payloads,
context_query=context_query,
func_tool=None,
chosen_key="test-key",
available_api_keys=["test-key"],
retry_cnt=0,
max_retries=10,
)
assert success is False
assert payloads["messages"][0]["content"] == [{"type": "text", "text": "hello"}]
finally:
await provider.terminate()
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_without_patterns_raises():
provider = _make_provider()
try:
payloads = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "hello"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,abcd"},
},
],
}
]
}
context_query = payloads["messages"]
err = Exception("Content is moderated [WKE=file:content-moderated]")
with pytest.raises(Exception, match="content-moderated"):
await provider._handle_api_error(
err,
payloads=payloads,
context_query=context_query,
func_tool=None,
chosen_key="test-key",
available_api_keys=["test-key"],
retry_cnt=0,
max_retries=10,
)
finally:
await provider.terminate()
@pytest.mark.asyncio
async def test_handle_api_error_unknown_image_error_raises():
provider = _make_provider()
try:
payloads = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "hello"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,abcd"},
},
],
}
]
}
context_query = payloads["messages"]
with pytest.raises(Exception, match="unknown provider image upload error"):
await provider._handle_api_error(
Exception("some unknown provider image upload error"),
payloads=payloads,
context_query=context_query,
func_tool=None,
chosen_key="test-key",
available_api_keys=["test-key"],
retry_cnt=0,
max_retries=10,
)
finally:
await provider.terminate()
+494
View File
@@ -0,0 +1,494 @@
from types import SimpleNamespace
import pytest
from astrbot.core.message.components import Image, Plain, Reply
from astrbot.core.utils.quoted_message_parser import (
extract_quoted_message_images,
extract_quoted_message_text,
)
class _DummyAPI:
def __init__(
self,
responses: dict[tuple[str, str], dict],
param_responses: dict[tuple[str, tuple[tuple[str, str], ...]], dict]
| None = None,
):
self._responses = responses
self._param_responses = param_responses or {}
async def call_action(self, action: str, **params):
param_key = (action, tuple(sorted((k, str(v)) for k, v in params.items())))
if param_key in self._param_responses:
return self._param_responses[param_key]
msg_id = params.get("message_id")
if msg_id is None:
msg_id = params.get("id")
key = (action, str(msg_id))
if key not in self._responses:
raise RuntimeError(f"no mock response for {key}")
return self._responses[key]
class _FailIfCalledAPI:
async def call_action(self, action: str, **params):
raise AssertionError(
f"call_action should not be called, got action={action}, params={params}"
)
def _make_event(
reply: Reply,
responses: dict[tuple[str, str], dict] | None = None,
param_responses: dict[tuple[str, tuple[tuple[str, str], ...]], dict] | None = None,
):
if responses is None:
responses = {}
if param_responses is None:
param_responses = {}
return SimpleNamespace(
message_obj=SimpleNamespace(message=[reply]),
bot=SimpleNamespace(api=_DummyAPI(responses, param_responses)),
get_group_id=lambda: "",
)
@pytest.mark.asyncio
async def test_extract_quoted_message_text_from_reply_chain():
reply = Reply(id="1", chain=[Plain(text="quoted content")], message_str="")
event = _make_event(reply)
text = await extract_quoted_message_text(event)
assert text == "quoted content"
@pytest.mark.asyncio
async def test_extract_quoted_message_text_no_reply_component():
event = SimpleNamespace(
message_obj=SimpleNamespace(message=[Plain(text="unquoted message")]),
bot=SimpleNamespace(api=_DummyAPI({}, {})),
get_group_id=lambda: "",
)
text = await extract_quoted_message_text(event)
assert text is None
@pytest.mark.asyncio
async def test_extract_quoted_message_images_no_reply_component():
event = SimpleNamespace(
message_obj=SimpleNamespace(message=[Plain(text="unquoted message")]),
bot=SimpleNamespace(api=_FailIfCalledAPI()),
get_group_id=lambda: "",
)
images = await extract_quoted_message_images(event)
assert images == []
@pytest.mark.asyncio
@pytest.mark.parametrize("reply_id", [None, ""])
async def test_extract_quoted_message_text_reply_without_id_does_not_call_get_msg(
reply_id: str | None,
):
reply = Reply(
id="placeholder", chain=[Plain(text="quoted content")], message_str=""
)
object.__setattr__(reply, "id", reply_id)
event = SimpleNamespace(
message_obj=SimpleNamespace(message=[reply]),
bot=SimpleNamespace(api=_FailIfCalledAPI()),
get_group_id=lambda: "",
)
text = await extract_quoted_message_text(event)
assert text == "quoted content"
@pytest.mark.asyncio
async def test_extract_quoted_message_text_fallback_get_msg_and_forward():
reply = Reply(id="100", chain=None, message_str="")
event = _make_event(
reply,
responses={
(
"get_msg",
"100",
): {
"data": {
"message": [
{"type": "text", "data": {"text": "parent"}},
{"type": "forward", "data": {"id": "fwd_1"}},
]
}
},
(
"get_forward_msg",
"fwd_1",
): {
"data": {
"messages": [
{
"sender": {"nickname": "Alice"},
"message": [{"type": "text", "data": {"text": "hello"}}],
},
{
"sender": {"nickname": "Bob"},
"message": [
{"type": "image", "data": {"url": "http://img"}},
{"type": "text", "data": {"text": "world"}},
],
},
]
}
},
},
)
text = await extract_quoted_message_text(event)
assert text is not None
assert "parent" in text
assert "Alice: hello" in text
assert "Bob: [Image]world" in text
@pytest.mark.parametrize(
"placeholder_text",
[
"[Forward Message]",
"[转发消息]",
"[合并转发]",
"Alice: [Forward Message]",
"(Alice): [转发消息]",
"[Forward Message]\n[转发消息]",
"Alice: [Forward Message]\n(Bob): [合并转发]",
"[转发消息]\n\n[合并转发]",
],
)
@pytest.mark.asyncio
async def test_extract_quoted_message_text_forward_placeholder_variants_trigger_fallback(
placeholder_text: str,
):
reply = Reply(id="400", chain=[Plain(text=placeholder_text)], message_str="")
event = _make_event(
reply,
responses={
("get_msg", "400"): {
"data": {
"message": [
{"type": "text", "data": {"text": "Bob: "}},
{"type": "image", "data": {}},
{"type": "text", "data": {"text": "world"}},
]
}
}
},
)
text = await extract_quoted_message_text(event)
assert "Bob: [Image]world" in text
@pytest.mark.asyncio
async def test_extract_quoted_message_text_mixed_placeholder_does_not_trigger_fallback():
reply = Reply(
id="402",
chain=[Plain(text="Alice: [Forward Message]\nreal text")],
message_str="",
)
event = SimpleNamespace(
message_obj=SimpleNamespace(message=[reply]),
bot=SimpleNamespace(api=_FailIfCalledAPI()),
get_group_id=lambda: "",
)
text = await extract_quoted_message_text(event)
assert text is not None
assert "[Forward Message]" in text
assert "real text" in text
@pytest.mark.asyncio
async def test_extract_quoted_message_text_forward_placeholder_fallback_failure():
reply = Reply(id="401", chain=[Plain(text="[Forward Message]")], message_str="")
event = _make_event(reply, responses={})
text = await extract_quoted_message_text(event)
assert text == "[Forward Message]"
@pytest.mark.asyncio
async def test_extract_quoted_message_text_multimsg_malformed_config_does_not_raise():
reply = Reply(id="402", chain=None, message_str="")
event = _make_event(
reply,
responses={
("get_msg", "402"): {
"data": {
"message": [
{
"type": "json",
"data": {
"data": (
'{"app":"com.tencent.multimsg",'
'"config":"oops","meta":{}}'
)
},
},
{"type": "text", "data": {"text": "still works"}},
]
}
}
},
)
text = await extract_quoted_message_text(event)
assert text == "still works"
@pytest.mark.asyncio
async def test_extract_quoted_message_images_from_reply_chain():
reply = Reply(
id="1",
chain=[
Plain(text="quoted"),
Image(file="https://img.example.com/a.jpg"),
],
message_str="",
)
event = _make_event(reply)
images = await extract_quoted_message_images(event)
assert images == ["https://img.example.com/a.jpg"]
@pytest.mark.asyncio
async def test_extract_quoted_message_images_fallback_get_msg_direct_url():
reply = Reply(id="200", chain=None, message_str="")
event = _make_event(
reply,
responses={
("get_msg", "200"): {
"data": {
"message": [
{
"type": "image",
"data": {"url": "https://img.example.com/direct.jpg"},
}
]
}
}
},
)
images = await extract_quoted_message_images(event)
assert images == ["https://img.example.com/direct.jpg"]
@pytest.mark.asyncio
async def test_extract_quoted_message_images_data_image_ref_normalized_to_base64():
data_image_ref = "data:image/png;base64,abcd1234=="
reply = Reply(id="201", chain=None, message_str="")
event = _make_event(
reply,
responses={
("get_msg", "201"): {
"data": {
"message": [
{"type": "image", "data": {"url": data_image_ref}},
]
}
}
},
)
images = await extract_quoted_message_images(event)
assert images == ["base64://abcd1234=="]
@pytest.mark.asyncio
async def test_extract_quoted_message_images_file_url_with_query_string():
url_with_query = "https://img.example.com/direct.jpg?token=abc123#frag"
reply = Reply(id="205", chain=None, message_str="")
event = _make_event(
reply,
responses={
("get_msg", "205"): {
"data": {
"message": [
{
"type": "file",
"data": {
"url": url_with_query,
"name": "direct.jpg",
},
}
]
}
}
},
)
images = await extract_quoted_message_images(event)
assert images == [url_with_query]
@pytest.mark.asyncio
async def test_extract_quoted_message_images_non_image_local_path_is_ignored(tmp_path):
non_image_file = tmp_path / "secret.txt"
non_image_file.write_text("not an image", encoding="utf-8")
reply = Reply(
id="placeholder", chain=[Image(file=str(non_image_file))], message_str=""
)
object.__setattr__(reply, "id", None)
event = SimpleNamespace(
message_obj=SimpleNamespace(message=[reply]),
bot=SimpleNamespace(api=_FailIfCalledAPI()),
get_group_id=lambda: "",
)
images = await extract_quoted_message_images(event)
assert images == []
@pytest.mark.asyncio
async def test_extract_quoted_message_images_chain_placeholder_triggers_fallback():
reply = Reply(id="210", chain=[Plain(text="[Forward Message]")], message_str="")
event = _make_event(
reply,
responses={
("get_msg", "210"): {
"data": {
"message": [
{
"type": "image",
"data": {
"url": "https://img.example.com/from-fallback.jpg"
},
}
]
}
}
},
)
images = await extract_quoted_message_images(event)
assert images == ["https://img.example.com/from-fallback.jpg"]
@pytest.mark.asyncio
async def test_extract_quoted_message_images_fallback_resolve_file_id_with_get_image():
reply = Reply(id="300", chain=None, message_str="")
event = _make_event(
reply,
responses={
("get_msg", "300"): {
"data": {"message": [{"type": "image", "data": {"file": "abc123.jpg"}}]}
}
},
param_responses={
("get_image", (("file", "abc123.jpg"),)): {
"data": {"url": "https://img.example.com/resolved.jpg"}
}
},
)
images = await extract_quoted_message_images(event)
assert images == ["https://img.example.com/resolved.jpg"]
@pytest.mark.asyncio
async def test_extract_quoted_message_images_deduplicates_across_sources():
dup_url = "https://img.example.com/dup.jpg"
chain_only_url = "https://img.example.com/only-chain.jpg"
get_msg_only_url = "https://img.example.com/only-get-msg.jpg"
forward_only_url = "https://img.example.com/only-forward.jpg"
reply = Reply(
id="310",
chain=[Image(file=dup_url), Image(file=chain_only_url)],
message_str="",
)
event = _make_event(
reply,
responses={
("get_msg", "310"): {
"data": {
"message": [
{"type": "image", "data": {"url": dup_url}},
{"type": "image", "data": {"url": get_msg_only_url}},
{"type": "forward", "data": {"id": "999"}},
]
}
},
("get_forward_msg", "999"): {
"data": {
"messages": [
{
"sender": {"nickname": "Tester"},
"message": [
{"type": "image", "data": {"url": dup_url}},
{"type": "image", "data": {"url": forward_only_url}},
],
}
]
}
},
},
)
images = await extract_quoted_message_images(event)
assert images == [
dup_url,
chain_only_url,
get_msg_only_url,
forward_only_url,
]
@pytest.mark.asyncio
async def test_extract_quoted_message_nested_forward_id_is_resolved():
nested_image = "https://img.example.com/nested.jpg"
reply = Reply(id="320", chain=[Plain(text="[Forward Message]")], message_str="")
event = _make_event(
reply,
responses={
("get_msg", "320"): {
"data": {"message": [{"type": "forward", "data": {"id": "fwd_1"}}]}
},
("get_forward_msg", "fwd_1"): {
"data": {
"messages": [
{
"sender": {"nickname": "Alice"},
"message": [{"type": "forward", "data": {"id": "fwd_2"}}],
}
]
}
},
("get_forward_msg", "fwd_2"): {
"data": {
"messages": [
{
"sender": {"nickname": "Bob"},
"message": [
{"type": "text", "data": {"text": "deep"}},
{"type": "image", "data": {"url": nested_image}},
],
}
]
}
},
},
)
text = await extract_quoted_message_text(event)
assert text is not None
assert "Bob: deep" in text
images = await extract_quoted_message_images(event)
assert images == [nested_image]