From 4ff07e3c743b89d804f572afa78d4aacc008903d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=82=A8=E3=82=A4=E3=82=AB=E3=82=AF?= <62183434+zouyonghe@users.noreply.github.com> Date: Thu, 12 Feb 2026 23:42:29 +0900 Subject: [PATCH] =?UTF-8?q?fix:=20=E5=AE=8C=E5=96=84=E8=BD=AC=E5=8F=91?= =?UTF-8?q?=E5=BC=95=E7=94=A8=E8=A7=A3=E6=9E=90=E4=B8=8E=E5=9B=BE=E7=89=87?= =?UTF-8?q?=E5=9B=9E=E9=80=80=E5=B9=B6=E6=94=AF=E6=8C=81=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E5=8C=96=E6=8E=A7=E5=88=B6=20(#5054)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: support fallback image parsing for quoted messages * fix: fallback parse quoted images when reply chain has placeholders * style: format network utils with ruff * test: expand quoted parser coverage and improve fallback diagnostics * fix: fallback to text-only retry when image requests fail * fix: tighten image fallback and resolve nested quoted forwards * refactor: simplify quoted message extraction and dedupe images * fix: harden quoted parsing and openai error candidates * fix: harden quoted image ref normalization * refactor: organize quoted parser settings and logging * fix: cap quoted fallback images and avoid retry loops * refactor: split quoted message parser into focused modules * refactor: share onebot segment parsing logic * refactor: unify quoted message parsing flow * feat: move quoted parser tuning to provider settings * fix: add missing i18n metadata for quoted parser settings * chore: refine forwarded message setting labels --- astrbot/core/astr_main_agent.py | 134 ++++- astrbot/core/config/default.py | 47 ++ .../method/agent_sub_stages/internal.py | 1 + .../core/provider/sources/openai_source.py | 170 +++++- astrbot/core/utils/quoted_message/__init__.py | 8 + .../core/utils/quoted_message/chain_parser.py | 505 ++++++++++++++++++ .../core/utils/quoted_message/extractor.py | 211 ++++++++ .../core/utils/quoted_message/image_refs.py | 94 ++++ 
.../utils/quoted_message/image_resolver.py | 130 +++++ .../utils/quoted_message/onebot_client.py | 119 +++++ astrbot/core/utils/quoted_message/settings.py | 85 +++ astrbot/core/utils/quoted_message_parser.py | 11 + astrbot/core/utils/string_utils.py | 21 + .../en-US/features/config-metadata.json | 22 + .../zh-CN/features/config-metadata.json | 22 + tests/test_openai_source.py | 382 +++++++++++++ tests/test_quoted_message_parser.py | 494 +++++++++++++++++ 17 files changed, 2425 insertions(+), 31 deletions(-) create mode 100644 astrbot/core/utils/quoted_message/__init__.py create mode 100644 astrbot/core/utils/quoted_message/chain_parser.py create mode 100644 astrbot/core/utils/quoted_message/extractor.py create mode 100644 astrbot/core/utils/quoted_message/image_refs.py create mode 100644 astrbot/core/utils/quoted_message/image_resolver.py create mode 100644 astrbot/core/utils/quoted_message/onebot_client.py create mode 100644 astrbot/core/utils/quoted_message/settings.py create mode 100644 astrbot/core/utils/quoted_message_parser.py create mode 100644 astrbot/core/utils/string_utils.py create mode 100644 tests/test_openai_source.py create mode 100644 tests/test_quoted_message_parser.py diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 6a14f48e8..12c4fde1d 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -52,6 +52,17 @@ from astrbot.core.tools.cron_tools import ( ) from astrbot.core.utils.file_extract import extract_file_moonshotai from astrbot.core.utils.llm_metadata import LLM_METADATAS +from astrbot.core.utils.quoted_message.settings import ( + SETTINGS as DEFAULT_QUOTED_MESSAGE_SETTINGS, +) +from astrbot.core.utils.quoted_message.settings import ( + QuotedMessageParserSettings, +) +from astrbot.core.utils.quoted_message_parser import ( + extract_quoted_message_images, + extract_quoted_message_text, +) +from astrbot.core.utils.string_utils import normalize_and_dedupe_strings 
@dataclass(slots=True) @@ -108,6 +119,8 @@ class MainAgentBuildConfig: provider_settings: dict = field(default_factory=dict) subagent_orchestrator: dict = field(default_factory=dict) timezone: str | None = None + max_quoted_fallback_images: int = 20 + """Maximum number of images injected from quoted-message fallback extraction.""" @dataclass(slots=True) @@ -470,11 +483,29 @@ async def _ensure_img_caption( logger.error("处理图片描述失败: %s", exc) +def _append_quoted_image_attachment(req: ProviderRequest, image_path: str) -> None: + req.extra_user_content_parts.append( + TextPart(text=f"[Image Attachment in quoted message: path {image_path}]") + ) + + +def _get_quoted_message_parser_settings( + provider_settings: dict[str, object] | None, +) -> QuotedMessageParserSettings: + if not isinstance(provider_settings, dict): + return DEFAULT_QUOTED_MESSAGE_SETTINGS + overrides = provider_settings.get("quoted_message_parser") + if not isinstance(overrides, dict): + return DEFAULT_QUOTED_MESSAGE_SETTINGS + return DEFAULT_QUOTED_MESSAGE_SETTINGS.with_overrides(overrides) + + async def _process_quote_message( event: AstrMessageEvent, req: ProviderRequest, img_cap_prov_id: str, plugin_context: Context, + quoted_message_settings: QuotedMessageParserSettings = DEFAULT_QUOTED_MESSAGE_SETTINGS, ) -> None: quote = None for comp in event.message_obj.message: @@ -486,7 +517,15 @@ async def _process_quote_message( content_parts = [] sender_info = f"({quote.sender_nickname}): " if quote.sender_nickname else "" - message_str = quote.message_str or "[Empty Text]" + message_str = ( + await extract_quoted_message_text( + event, + quote, + settings=quoted_message_settings, + ) + or quote.message_str + or "[Empty Text]" + ) content_parts.append(f"{sender_info}{message_str}") image_seg = None @@ -592,11 +631,13 @@ async def _decorate_llm_request( ) img_cap_prov_id = cfg.get("default_image_caption_provider_id") or "" + quoted_message_settings = _get_quoted_message_parser_settings(cfg) await 
_process_quote_message( event, req, img_cap_prov_id, plugin_context, + quoted_message_settings, ) tz = config.timezone @@ -886,32 +927,78 @@ async def build_main_agent( ) # quoted message attachments reply_comps = [ - comp - for comp in event.message_obj.message - if isinstance(comp, Reply) and comp.chain + comp for comp in event.message_obj.message if isinstance(comp, Reply) ] + quoted_message_settings = _get_quoted_message_parser_settings( + config.provider_settings + ) + fallback_quoted_image_count = 0 for comp in reply_comps: - if not comp.chain: - continue - for reply_comp in comp.chain: - if isinstance(reply_comp, Image): - image_path = await reply_comp.convert_to_file_path() - req.image_urls.append(image_path) - req.extra_user_content_parts.append( - TextPart( - text=f"[Image Attachment in quoted message: path {image_path}]" - ) - ) - elif isinstance(reply_comp, File): - file_path = await reply_comp.get_file() - file_name = reply_comp.name or os.path.basename(file_path) - req.extra_user_content_parts.append( - TextPart( - text=( - f"[File Attachment in quoted message: " - f"name {file_name}, path {file_path}]" + has_embedded_image = False + if comp.chain: + for reply_comp in comp.chain: + if isinstance(reply_comp, Image): + has_embedded_image = True + image_path = await reply_comp.convert_to_file_path() + req.image_urls.append(image_path) + _append_quoted_image_attachment(req, image_path) + elif isinstance(reply_comp, File): + file_path = await reply_comp.get_file() + file_name = reply_comp.name or os.path.basename(file_path) + req.extra_user_content_parts.append( + TextPart( + text=( + f"[File Attachment in quoted message: " + f"name {file_name}, path {file_path}]" + ) ) ) + + # Fallback quoted image extraction for reply-id-only payloads, or when + # embedded reply chain only contains placeholders (e.g. [Forward Message], [Image]). 
+ if not has_embedded_image: + try: + fallback_images = normalize_and_dedupe_strings( + await extract_quoted_message_images( + event, + comp, + settings=quoted_message_settings, + ) + ) + remaining_limit = max( + config.max_quoted_fallback_images + - fallback_quoted_image_count, + 0, + ) + if remaining_limit <= 0 and fallback_images: + logger.warning( + "Skip quoted fallback images due to limit=%d for umo=%s", + config.max_quoted_fallback_images, + event.unified_msg_origin, + ) + continue + if len(fallback_images) > remaining_limit: + logger.warning( + "Truncate quoted fallback images for umo=%s, reply_id=%s from %d to %d", + event.unified_msg_origin, + getattr(comp, "id", None), + len(fallback_images), + remaining_limit, + ) + fallback_images = fallback_images[:remaining_limit] + for image_ref in fallback_images: + if image_ref in req.image_urls: + continue + req.image_urls.append(image_ref) + fallback_quoted_image_count += 1 + _append_quoted_image_attachment(req, image_ref) + except Exception as exc: # noqa: BLE001 + logger.warning( + "Failed to resolve fallback quoted images for umo=%s, reply_id=%s: %s", + event.unified_msg_origin, + getattr(comp, "id", None), + exc, + exc_info=True, ) conversation = await _get_session_conv(event, plugin_context) @@ -921,6 +1008,7 @@ async def build_main_agent( if isinstance(req.contexts, str): req.contexts = json.loads(req.contexts) + req.image_urls = normalize_and_dedupe_strings(req.image_urls) if config.file_extract_enabled: try: diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 411384c1d..6c3545d87 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -99,6 +99,13 @@ DEFAULT_CONFIG = { "streaming_response": False, "show_tool_use_status": False, "sanitize_context_by_modalities": False, + "max_quoted_fallback_images": 20, + "quoted_message_parser": { + "max_component_chain_depth": 4, + "max_forward_node_depth": 6, + "max_forward_fetch": 32, + 
"warn_on_action_failure": False, + }, "agent_runner_type": "local", "dify_agent_runner_provider_id": "", "coze_agent_runner_provider_id": "", @@ -2908,6 +2915,46 @@ CONFIG_METADATA_3 = { "provider_settings.agent_runner_type": "local", }, }, + "provider_settings.max_quoted_fallback_images": { + "description": "引用图片回退解析上限", + "type": "int", + "hint": "引用/转发消息回退解析图片时的最大注入数量,超出会截断。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.quoted_message_parser.max_component_chain_depth": { + "description": "引用解析组件链深度", + "type": "int", + "hint": "解析 Reply 组件链时允许的最大递归深度。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.quoted_message_parser.max_forward_node_depth": { + "description": "引用解析转发节点深度", + "type": "int", + "hint": "解析合并转发节点时允许的最大递归深度。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.quoted_message_parser.max_forward_fetch": { + "description": "引用解析转发拉取上限", + "type": "int", + "hint": "递归拉取 get_forward_msg 的最大次数。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.quoted_message_parser.warn_on_action_failure": { + "description": "引用解析 action 失败告警", + "type": "bool", + "hint": "开启后,get_msg/get_forward_msg 全部尝试失败时输出 warning 日志。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, "provider_settings.max_agent_step": { "description": "工具调用轮数上限", "type": "int", diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index db61ce0ec..be517dba9 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -123,6 +123,7 @@ class InternalAgentSubStage(Stage): provider_settings=settings, subagent_orchestrator=conf.get("subagent_orchestrator", {}), 
timezone=self.ctx.plugin_manager.context.get_config().get("timezone"), + max_quoted_fallback_images=settings.get("max_quoted_fallback_images", 20), ) async def process( diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 0708c09c7..328da2573 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -5,6 +5,7 @@ import json import random import re from collections.abc import AsyncGenerator +from typing import Any import httpx from openai import AsyncAzureOpenAI, AsyncOpenAI @@ -27,6 +28,7 @@ from astrbot.core.utils.network_utils import ( is_connection_error, log_connection_failure, ) +from astrbot.core.utils.string_utils import normalize_and_dedupe_strings from ..register import register_provider_adapter @@ -36,6 +38,128 @@ from ..register import register_provider_adapter "OpenAI API Chat Completion 提供商适配器", ) class ProviderOpenAIOfficial(Provider): + _ERROR_TEXT_CANDIDATE_MAX_CHARS = 4096 + + @classmethod + def _truncate_error_text_candidate(cls, text: str) -> str: + if len(text) <= cls._ERROR_TEXT_CANDIDATE_MAX_CHARS: + return text + return text[: cls._ERROR_TEXT_CANDIDATE_MAX_CHARS] + + @staticmethod + def _safe_json_dump(value: Any) -> str | None: + try: + return json.dumps(value, ensure_ascii=False, default=str) + except Exception: + return None + + def _get_image_moderation_error_patterns(self) -> list[str]: + """Return configured moderation patterns (case-insensitive substring match, not regex).""" + configured = self.provider_config.get("image_moderation_error_patterns", []) + patterns: list[str] = [] + if isinstance(configured, str): + configured = [configured] + if isinstance(configured, list): + for pattern in configured: + if not isinstance(pattern, str): + continue + pattern = pattern.strip() + if pattern: + patterns.append(pattern) + return patterns + + @staticmethod + def _extract_error_text_candidates(error: Exception) -> 
list[str]: + candidates: list[str] = [] + + def _append_candidate(candidate: Any): + if candidate is None: + return + text = str(candidate).strip() + if not text: + return + candidates.append( + ProviderOpenAIOfficial._truncate_error_text_candidate(text) + ) + + _append_candidate(str(error)) + + body = getattr(error, "body", None) + if isinstance(body, dict): + err_obj = body.get("error") + body_text = ProviderOpenAIOfficial._safe_json_dump( + {"error": err_obj} if isinstance(err_obj, dict) else body + ) + _append_candidate(body_text) + if isinstance(err_obj, dict): + for field in ("message", "type", "code", "param"): + value = err_obj.get(field) + if value is not None: + _append_candidate(value) + elif isinstance(body, str): + _append_candidate(body) + + response = getattr(error, "response", None) + if response is not None: + response_text = getattr(response, "text", None) + if isinstance(response_text, str): + _append_candidate(response_text) + + return normalize_and_dedupe_strings(candidates) + + def _is_content_moderated_upload_error(self, error: Exception) -> bool: + patterns = [ + pattern.lower() for pattern in self._get_image_moderation_error_patterns() + ] + if not patterns: + return False + candidates = [ + candidate.lower() + for candidate in self._extract_error_text_candidates(error) + ] + for pattern in patterns: + if any(pattern in candidate for candidate in candidates): + return True + return False + + @staticmethod + def _context_contains_image(contexts: list[dict]) -> bool: + for context in contexts: + content = context.get("content") + if not isinstance(content, list): + continue + for item in content: + if isinstance(item, dict) and item.get("type") == "image_url": + return True + return False + + async def _fallback_to_text_only_and_retry( + self, + payloads: dict, + context_query: list, + chosen_key: str, + available_api_keys: list[str], + func_tool: ToolSet | None, + reason: str, + *, + image_fallback_used: bool = False, + ) -> tuple: + 
logger.warning( + "检测到图片请求失败(%s),已移除图片并重试(保留文本内容)。", + reason, + ) + new_contexts = await self._remove_image_from_context(context_query) + payloads["messages"] = new_contexts + return ( + False, + chosen_key, + available_api_keys, + payloads, + new_contexts, + func_tool, + image_fallback_used, + ) + def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None: """创建带代理的 HTTP 客户端""" proxy = provider_config.get("proxy", "") @@ -403,6 +527,7 @@ class ProviderOpenAIOfficial(Provider): available_api_keys: list[str], retry_cnt: int, max_retries: int, + image_fallback_used: bool = False, ) -> tuple: """处理API错误并尝试恢复""" if "429" in str(e): @@ -422,6 +547,7 @@ class ProviderOpenAIOfficial(Provider): payloads, context_query, func_tool, + image_fallback_used, ) raise e if "maximum context length" in str(e): @@ -437,20 +563,34 @@ class ProviderOpenAIOfficial(Provider): payloads, context_query, func_tool, + image_fallback_used, ) if "The model is not a VLM" in str(e): # siliconcloud + if image_fallback_used or not self._context_contains_image(context_query): + raise e # 尝试删除所有 image - new_contexts = await self._remove_image_from_context(context_query) - payloads["messages"] = new_contexts - context_query = new_contexts - return ( - False, - chosen_key, - available_api_keys, + return await self._fallback_to_text_only_and_retry( payloads, context_query, + chosen_key, + available_api_keys, func_tool, + "model_not_vlm", + image_fallback_used=True, ) + if self._is_content_moderated_upload_error(e): + if image_fallback_used or not self._context_contains_image(context_query): + raise e + return await self._fallback_to_text_only_and_retry( + payloads, + context_query, + chosen_key, + available_api_keys, + func_tool, + "image_content_moderated", + image_fallback_used=True, + ) + if ( "Function calling is not enabled" in str(e) or ("tool" in str(e).lower() and "support" in str(e).lower()) @@ -461,7 +601,15 @@ class ProviderOpenAIOfficial(Provider): f"{self.get_model()} 
不支持函数工具调用,已自动去除,不影响使用。", ) payloads.pop("tools", None) - return False, chosen_key, available_api_keys, payloads, context_query, None + return ( + False, + chosen_key, + available_api_keys, + payloads, + context_query, + None, + image_fallback_used, + ) # logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}") if "tool" in str(e).lower() and "support" in str(e).lower(): @@ -501,6 +649,7 @@ class ProviderOpenAIOfficial(Provider): max_retries = 10 available_api_keys = self.api_keys.copy() chosen_key = random.choice(available_api_keys) + image_fallback_used = False last_exception = None retry_cnt = 0 @@ -518,6 +667,7 @@ class ProviderOpenAIOfficial(Provider): payloads, context_query, func_tool, + image_fallback_used, ) = await self._handle_api_error( e, payloads, @@ -527,6 +677,7 @@ class ProviderOpenAIOfficial(Provider): available_api_keys, retry_cnt, max_retries, + image_fallback_used=image_fallback_used, ) if success: break @@ -564,6 +715,7 @@ class ProviderOpenAIOfficial(Provider): max_retries = 10 available_api_keys = self.api_keys.copy() chosen_key = random.choice(available_api_keys) + image_fallback_used = False last_exception = None retry_cnt = 0 @@ -582,6 +734,7 @@ class ProviderOpenAIOfficial(Provider): payloads, context_query, func_tool, + image_fallback_used, ) = await self._handle_api_error( e, payloads, @@ -591,6 +744,7 @@ class ProviderOpenAIOfficial(Provider): available_api_keys, retry_cnt, max_retries, + image_fallback_used=image_fallback_used, ) if success: break diff --git a/astrbot/core/utils/quoted_message/__init__.py b/astrbot/core/utils/quoted_message/__init__.py new file mode 100644 index 000000000..8421898fd --- /dev/null +++ b/astrbot/core/utils/quoted_message/__init__.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +from .extractor import extract_quoted_message_images, extract_quoted_message_text + +__all__ = [ + "extract_quoted_message_text", + "extract_quoted_message_images", +] diff --git 
a/astrbot/core/utils/quoted_message/chain_parser.py b/astrbot/core/utils/quoted_message/chain_parser.py new file mode 100644 index 000000000..5fe89302e --- /dev/null +++ b/astrbot/core/utils/quoted_message/chain_parser.py @@ -0,0 +1,505 @@ +from __future__ import annotations + +import json +import re +from typing import Any, TypedDict + +from astrbot.core.message.components import ( + At, + AtAll, + File, + Forward, + Image, + Node, + Nodes, + Plain, + Reply, + Video, +) +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.utils.string_utils import normalize_and_dedupe_strings + +from .image_refs import looks_like_image_file_name, normalize_file_like_url +from .settings import SETTINGS, QuotedMessageParserSettings + +_FORWARD_PLACEHOLDER_PATTERN = re.compile( + r"^(?:[\(\[]?[^\]:\)]*[\)\]]?\s*:\s*)?\[(?:forward message|转发消息|合并转发)\]$", + flags=re.IGNORECASE, +) + + +class ParsedOneBotPayload(TypedDict): + text: str | None + forward_ids: list[str] + image_refs: list[str] + + +def _build_parsed_payload( + text: str | None, + forward_ids: list[str] | None = None, + image_refs: list[str] | None = None, +) -> ParsedOneBotPayload: + return { + "text": text, + "forward_ids": forward_ids or [], + "image_refs": image_refs or [], + } + + +def _join_text_parts(parts: list[str]) -> str | None: + text = "".join(parts).strip() + return text or None + + +def _find_first_reply_component(event: AstrMessageEvent) -> Reply | None: + for comp in event.message_obj.message: + if isinstance(comp, Reply): + return comp + return None + + +def _is_forward_placeholder_only_text(text: str | None) -> bool: + if not isinstance(text, str): + return False + lines = [line.strip() for line in text.splitlines() if line.strip()] + if not lines: + return False + return all(_FORWARD_PLACEHOLDER_PATTERN.match(line) for line in lines) + + +def _extract_image_refs_from_component_chain( + chain: list[Any] | None, + *, + depth: int = 0, + settings: 
QuotedMessageParserSettings = SETTINGS, +) -> list[str]: + if not isinstance(chain, list) or depth > settings.max_component_chain_depth: + return [] + + image_refs: list[str] = [] + for seg in chain: + if isinstance(seg, Image): + for candidate in (seg.url, seg.file, seg.path): + if isinstance(candidate, str) and candidate.strip(): + image_refs.append(candidate.strip()) + break + elif isinstance(seg, Reply): + image_refs.extend( + _extract_image_refs_from_reply_component( + seg, + depth=depth + 1, + settings=settings, + ) + ) + elif isinstance(seg, Node): + image_refs.extend( + _extract_image_refs_from_component_chain( + seg.content, + depth=depth + 1, + settings=settings, + ) + ) + elif isinstance(seg, Nodes): + for node in seg.nodes: + image_refs.extend( + _extract_image_refs_from_component_chain( + node.content, + depth=depth + 1, + settings=settings, + ) + ) + + return normalize_and_dedupe_strings(image_refs) + + +def _extract_text_from_component_chain( + chain: list[Any] | None, + *, + depth: int = 0, + settings: QuotedMessageParserSettings = SETTINGS, +) -> str | None: + if not isinstance(chain, list) or depth > settings.max_component_chain_depth: + return None + + parts: list[str] = [] + for seg in chain: + if isinstance(seg, Plain): + if seg.text: + parts.append(seg.text) + elif isinstance(seg, At): + if seg.name: + parts.append(f"@{seg.name}") + elif seg.qq: + parts.append(f"@{seg.qq}") + elif isinstance(seg, AtAll): + parts.append("@all") + elif isinstance(seg, Image): + parts.append("[Image]") + elif isinstance(seg, Video): + parts.append("[Video]") + elif isinstance(seg, File): + file_name = seg.name or "file" + parts.append(f"[File:{file_name}]") + elif isinstance(seg, Forward): + parts.append("[Forward Message]") + elif isinstance(seg, Reply): + nested = _extract_text_from_reply_component( + seg, + depth=depth + 1, + settings=settings, + ) + if nested: + parts.append(nested) + elif isinstance(seg, Node): + node_sender = seg.name or seg.uin or "Unknown 
User" + node_text = _extract_text_from_component_chain( + seg.content, + depth=depth + 1, + settings=settings, + ) + if node_text: + parts.append(f"{node_sender}: {node_text}") + elif isinstance(seg, Nodes): + for node in seg.nodes: + node_sender = node.name or node.uin or "Unknown User" + node_text = _extract_text_from_component_chain( + node.content, + depth=depth + 1, + settings=settings, + ) + if node_text: + parts.append(f"{node_sender}: {node_text}") + + return _join_text_parts(parts) + + +def _extract_image_refs_from_reply_component( + reply: Reply, + *, + depth: int = 0, + settings: QuotedMessageParserSettings = SETTINGS, +) -> list[str]: + for attr in ("chain", "message", "origin", "content"): + payload = getattr(reply, attr, None) + image_refs = _extract_image_refs_from_component_chain( + payload, + depth=depth, + settings=settings, + ) + if image_refs: + return image_refs + return [] + + +def _extract_text_from_reply_component( + reply: Reply, + *, + depth: int = 0, + settings: QuotedMessageParserSettings = SETTINGS, +) -> str | None: + for attr in ("chain", "message", "origin", "content"): + payload = getattr(reply, attr, None) + text = _extract_text_from_component_chain( + payload, + depth=depth, + settings=settings, + ) + if text: + return text + + if reply.message_str and reply.message_str.strip(): + return reply.message_str.strip() + return None + + +def _unwrap_onebot_data(payload: Any) -> dict[str, Any]: + if not isinstance(payload, dict): + return {} + data = payload.get("data") + if isinstance(data, dict): + return data + return payload + + +def _extract_text_from_multimsg_json(raw_json: str) -> str | None: + try: + parsed = json.loads(raw_json) + except Exception: + return None + + if not isinstance(parsed, dict): + return None + if parsed.get("app") != "com.tencent.multimsg": + return None + config = parsed.get("config") + if not isinstance(config, dict): + return None + if config.get("forward") != 1: + return None + + meta = 
parsed.get("meta") + if not isinstance(meta, dict): + return None + detail = meta.get("detail") + if not isinstance(detail, dict): + return None + news_items = detail.get("news") + if not isinstance(news_items, list): + return None + + texts: list[str] = [] + for item in news_items: + if not isinstance(item, dict): + continue + text_content = item.get("text") + if not isinstance(text_content, str): + continue + cleaned = text_content.strip().replace("[图片]", "").strip() + if cleaned: + texts.append(cleaned) + + return "\n".join(texts).strip() or None + + +def _parse_onebot_segments( + segments: list[Any], + *, + settings: QuotedMessageParserSettings = SETTINGS, +) -> ParsedOneBotPayload: + text_parts: list[str] = [] + forward_ids: list[str] = [] + image_refs: list[str] = [] + + for seg in segments: + if not isinstance(seg, dict): + continue + + seg_type = seg.get("type") + seg_data = seg.get("data", {}) if isinstance(seg.get("data"), dict) else {} + + if seg_type in ("text", "plain"): + text = seg_data.get("text") + if isinstance(text, str) and text: + text_parts.append(text) + elif seg_type == "image": + text_parts.append("[Image]") + candidate = seg_data.get("url") or seg_data.get("file") + if isinstance(candidate, str) and candidate.strip(): + image_refs.append(candidate.strip()) + elif seg_type == "video": + text_parts.append("[Video]") + elif seg_type == "file": + file_name = ( + seg_data.get("name") + or seg_data.get("file_name") + or seg_data.get("file") + or "file" + ) + text_parts.append(f"[File:{file_name}]") + candidate_url = seg_data.get("url") + if ( + isinstance(candidate_url, str) + and candidate_url.strip() + and looks_like_image_file_name(normalize_file_like_url(candidate_url)) + ): + image_refs.append(candidate_url.strip()) + candidate_file = seg_data.get("file") + if ( + isinstance(candidate_file, str) + and candidate_file.strip() + and looks_like_image_file_name( + normalize_file_like_url( + seg_data.get("name") + or seg_data.get("file_name") + 
or candidate_file + ) + ) + ): + image_refs.append(candidate_file.strip()) + elif seg_type in ("forward", "forward_msg", "nodes"): + fid = seg_data.get("id") or seg_data.get("message_id") + if isinstance(fid, (str, int)) and str(fid): + forward_ids.append(str(fid)) + else: + nested_nodes = seg_data.get("content") + nested_text, nested_forward_ids, nested_images = ( + _extract_text_forward_ids_and_images_from_forward_nodes( + nested_nodes if isinstance(nested_nodes, list) else [], + depth=1, + settings=settings, + ) + ) + if nested_text: + text_parts.append(nested_text) + if nested_forward_ids: + forward_ids.extend(nested_forward_ids) + if nested_images: + image_refs.extend(nested_images) + elif seg_type == "json": + raw_json = seg_data.get("data") + if isinstance(raw_json, str) and raw_json.strip(): + raw_json = raw_json.replace("&#44;", ",") + multimsg_text = _extract_text_from_multimsg_json(raw_json) + if multimsg_text: + text_parts.append(multimsg_text) + + return _build_parsed_payload( + _join_text_parts(text_parts), + forward_ids, + normalize_and_dedupe_strings(image_refs), + ) + + +def _extract_text_forward_ids_and_images_from_forward_nodes( + nodes: list[Any], + *, + depth: int = 0, + settings: QuotedMessageParserSettings = SETTINGS, +) -> tuple[str | None, list[str], list[str]]: + if not isinstance(nodes, list) or depth > settings.max_forward_node_depth: + return None, [], [] + + texts: list[str] = [] + forward_ids: list[str] = [] + image_refs: list[str] = [] + indent = " " * depth + + for node in nodes: + if not isinstance(node, dict): + continue + + sender = node.get("sender") if isinstance(node.get("sender"), dict) else {} + sender_name = ( + sender.get("nickname") + or sender.get("card") + or sender.get("user_id") + or "Unknown User" + ) + + raw_content = node.get("message") or node.get("content") or [] + chain: list[Any] = [] + if isinstance(raw_content, list): + chain = raw_content + elif isinstance(raw_content, str): + raw_content = raw_content.strip() +
if raw_content: + try: + parsed = json.loads(raw_content) + except Exception: + parsed = None + if isinstance(parsed, list): + chain = parsed + else: + chain = [{"type": "text", "data": {"text": raw_content}}] + + parsed_segments = _parse_onebot_segments(chain, settings=settings) + node_text = parsed_segments["text"] + node_forward_ids = parsed_segments["forward_ids"] + node_images = parsed_segments["image_refs"] + if node_text: + texts.append(f"{indent}{sender_name}: {node_text}") + if node_forward_ids: + forward_ids.extend(node_forward_ids) + if node_images: + image_refs.extend(node_images) + + return ( + "\n".join(texts).strip() or None, + normalize_and_dedupe_strings(forward_ids), + normalize_and_dedupe_strings(image_refs), + ) + + +def _parse_onebot_get_msg_payload( + payload: dict[str, Any], + *, + settings: QuotedMessageParserSettings = SETTINGS, +) -> ParsedOneBotPayload: + data = _unwrap_onebot_data(payload) + segments = data.get("message") or data.get("messages") + if isinstance(segments, list): + return _parse_onebot_segments(segments, settings=settings) + + text: str | None = None + if isinstance(segments, str) and segments.strip(): + text = segments.strip() + else: + raw = data.get("raw_message") + if isinstance(raw, str) and raw.strip(): + text = raw.strip() + return _build_parsed_payload(text) + + +def _parse_onebot_get_forward_payload( + payload: dict[str, Any], + *, + settings: QuotedMessageParserSettings = SETTINGS, +) -> ParsedOneBotPayload: + data = _unwrap_onebot_data(payload) + nodes = ( + data.get("messages") + or data.get("message") + or data.get("nodes") + or data.get("nodeList") + ) + if not isinstance(nodes, list): + return _build_parsed_payload(None) + + text, forward_ids, image_refs = ( + _extract_text_forward_ids_and_images_from_forward_nodes( + nodes, + settings=settings, + ) + ) + return _build_parsed_payload(text, forward_ids, image_refs) + + +class ReplyChainParser: + def __init__(self, settings: QuotedMessageParserSettings = 
SETTINGS): + self._settings = settings + + @staticmethod + def find_first_reply_component(event: AstrMessageEvent) -> Reply | None: + return _find_first_reply_component(event) + + @staticmethod + def is_forward_placeholder_only_text(text: str | None) -> bool: + return _is_forward_placeholder_only_text(text) + + def extract_text_from_reply_component( + self, + reply: Reply, + *, + depth: int = 0, + ) -> str | None: + return _extract_text_from_reply_component( + reply, + depth=depth, + settings=self._settings, + ) + + def extract_image_refs_from_reply_component( + self, + reply: Reply, + *, + depth: int = 0, + ) -> list[str]: + return _extract_image_refs_from_reply_component( + reply, + depth=depth, + settings=self._settings, + ) + + +class OneBotPayloadParser: + def __init__(self, settings: QuotedMessageParserSettings = SETTINGS): + self._settings = settings + + def parse_get_msg_payload(self, payload: dict[str, Any]) -> ParsedOneBotPayload: + return _parse_onebot_get_msg_payload(payload, settings=self._settings) + + def parse_get_forward_payload( + self, + payload: dict[str, Any], + ) -> ParsedOneBotPayload: + return _parse_onebot_get_forward_payload(payload, settings=self._settings) diff --git a/astrbot/core/utils/quoted_message/extractor.py b/astrbot/core/utils/quoted_message/extractor.py new file mode 100644 index 000000000..83570d66c --- /dev/null +++ b/astrbot/core/utils/quoted_message/extractor.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from astrbot import logger +from astrbot.core.message.components import Reply +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.utils.string_utils import normalize_and_dedupe_strings + +from .chain_parser import OneBotPayloadParser, ReplyChainParser +from .image_resolver import ImageResolver +from .onebot_client import OneBotClient +from .settings import SETTINGS, QuotedMessageParserSettings + + +async def 
_collect_text_and_images_from_forward_ids( + onebot_client: OneBotClient, + payload_parser: OneBotPayloadParser, + forward_ids: list[str], + *, + max_fetch: int, +) -> tuple[list[str], list[str]]: + texts: list[str] = [] + image_refs: list[str] = [] + pending: list[str] = [] + seen: set[str] = set() + + for fid in forward_ids: + if not isinstance(fid, str): + continue + cleaned = fid.strip() + if cleaned: + pending.append(cleaned) + + fetch_count = 0 + while pending and fetch_count < max_fetch: + current_id = pending.pop(0) + if current_id in seen: + continue + seen.add(current_id) + fetch_count += 1 + + forward_payload = await onebot_client.get_forward_msg(current_id) + if not forward_payload: + continue + + parsed = payload_parser.parse_get_forward_payload(forward_payload) + if parsed["text"]: + texts.append(parsed["text"]) + if parsed["image_refs"]: + image_refs.extend(parsed["image_refs"]) + for nested_id in parsed["forward_ids"]: + if nested_id not in seen: + pending.append(nested_id) + + if pending: + logger.warning( + "quoted_message_parser: stop fetching nested forward messages after %d hops", + max_fetch, + ) + + return texts, normalize_and_dedupe_strings(image_refs) + + +@dataclass(slots=True) +class QuotedMessageContent: + embedded_text: str | None + embedded_image_refs: list[str] + reply_id: str + direct_text: str | None + direct_image_refs: list[str] + forward_texts: list[str] + forward_image_refs: list[str] + + +class QuotedMessageExtractor: + def __init__( + self, + event: AstrMessageEvent, + settings: QuotedMessageParserSettings = SETTINGS, + ): + self._event = event + self._settings = settings + self._reply_parser = ReplyChainParser(settings=settings) + self._payload_parser = OneBotPayloadParser(settings=settings) + self._client = OneBotClient(event, settings=settings) + self._image_resolver = ImageResolver(event, self._client) + + async def _fetch_quoted_content( + self, + reply_component: Reply | None = None, + *, + fetch_remote: bool, + ) -> 
QuotedMessageContent | None: + reply = reply_component or self._reply_parser.find_first_reply_component( + self._event + ) + if not reply: + return None + + embedded_text = self._reply_parser.extract_text_from_reply_component(reply) + embedded_image_refs = list( + self._reply_parser.extract_image_refs_from_reply_component(reply) + ) + + reply_id = getattr(reply, "id", None) + reply_id_str = str(reply_id).strip() if reply_id is not None else "" + if not fetch_remote or not reply_id_str: + return QuotedMessageContent( + embedded_text=embedded_text, + embedded_image_refs=embedded_image_refs, + reply_id=reply_id_str, + direct_text=None, + direct_image_refs=[], + forward_texts=[], + forward_image_refs=[], + ) + + msg_payload = await self._client.get_msg(reply_id_str) + if not msg_payload: + return QuotedMessageContent( + embedded_text=embedded_text, + embedded_image_refs=embedded_image_refs, + reply_id=reply_id_str, + direct_text=None, + direct_image_refs=[], + forward_texts=[], + forward_image_refs=[], + ) + + parsed = self._payload_parser.parse_get_msg_payload(msg_payload) + forward_texts, forward_images = await _collect_text_and_images_from_forward_ids( + self._client, + self._payload_parser, + parsed["forward_ids"], + max_fetch=self._settings.max_forward_fetch, + ) + return QuotedMessageContent( + embedded_text=embedded_text, + embedded_image_refs=embedded_image_refs, + reply_id=reply_id_str, + direct_text=parsed["text"], + direct_image_refs=list(parsed["image_refs"]), + forward_texts=forward_texts, + forward_image_refs=forward_images, + ) + + async def text(self, reply_component: Reply | None = None) -> str | None: + embedded_content = await self._fetch_quoted_content( + reply_component, + fetch_remote=False, + ) + if not embedded_content: + return None + + if ( + embedded_content.embedded_text + and not self._reply_parser.is_forward_placeholder_only_text( + embedded_content.embedded_text + ) + ): + return embedded_content.embedded_text + + if not 
embedded_content.reply_id: + return embedded_content.embedded_text + + fetched_content = await self._fetch_quoted_content( + reply_component, + fetch_remote=True, + ) + if not fetched_content: + return embedded_content.embedded_text + + text_parts: list[str] = [] + if fetched_content.direct_text: + text_parts.append(fetched_content.direct_text) + text_parts.extend(fetched_content.forward_texts) + + return "\n".join(text_parts).strip() or embedded_content.embedded_text + + async def images(self, reply_component: Reply | None = None) -> list[str]: + content = await self._fetch_quoted_content(reply_component, fetch_remote=True) + if not content: + return [] + + image_refs: list[str] = [] + image_refs.extend(content.embedded_image_refs) + image_refs.extend(content.direct_image_refs) + image_refs.extend(content.forward_image_refs) + + return await self._image_resolver.resolve_for_llm(image_refs) + + +async def extract_quoted_message_text( + event: AstrMessageEvent, + reply_component: Reply | None = None, + settings: QuotedMessageParserSettings | None = None, +) -> str | None: + return await QuotedMessageExtractor(event, settings=settings or SETTINGS).text( + reply_component + ) + + +async def extract_quoted_message_images( + event: AstrMessageEvent, + reply_component: Reply | None = None, + settings: QuotedMessageParserSettings | None = None, +) -> list[str]: + return await QuotedMessageExtractor(event, settings=settings or SETTINGS).images( + reply_component + ) diff --git a/astrbot/core/utils/quoted_message/image_refs.py b/astrbot/core/utils/quoted_message/image_refs.py new file mode 100644 index 000000000..009d6844a --- /dev/null +++ b/astrbot/core/utils/quoted_message/image_refs.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import os +from urllib.parse import urlsplit + +IMAGE_EXTENSIONS = { + ".jpg", + ".jpeg", + ".png", + ".webp", + ".bmp", + ".tif", + ".tiff", + ".gif", +} + + +def normalize_file_like_url(path: str | None) -> str | None: + if path is 
None: + return None + if not isinstance(path, str): + return None + if "?" not in path and "#" not in path: + return path + try: + split = urlsplit(path) + except Exception: + return path + return split.path or path + + +def looks_like_image_file_name(name: str) -> bool: + normalized_name = normalize_file_like_url(name) + if not isinstance(normalized_name, str) or not normalized_name.strip(): + return False + _, ext = os.path.splitext(normalized_name.strip().lower()) + return ext in IMAGE_EXTENSIONS + + +def convert_data_image_to_base64_ref(image_ref: str) -> str | None: + if not isinstance(image_ref, str): + return None + value = image_ref.strip() + if not value: + return None + lower_value = value.lower() + if not lower_value.startswith("data:image/"): + return None + + comma_index = value.find(",") + if comma_index <= 0: + return None + header = value[:comma_index].lower() + payload = value[comma_index + 1 :].strip() + if ";base64" not in header or not payload: + return None + return f"base64://{payload}" + + +def get_existing_local_path(value: str) -> str | None: + lower_value = value.lower() + if lower_value.startswith("file://"): + file_path = value[7:] + if file_path.startswith("/") and len(file_path) > 3 and file_path[2] == ":": + file_path = file_path[1:] + if file_path and os.path.exists(file_path): + return os.path.abspath(file_path) + return None + if os.path.exists(value): + return os.path.abspath(value) + return None + + +def normalize_image_ref(image_ref: str) -> str | None: + if not isinstance(image_ref, str): + return None + value = image_ref.strip() + if not value: + return None + lower_value = value.lower() + + if lower_value.startswith(("http://", "https://")): + return value + if lower_value.startswith("base64://"): + return value + + data_image_ref = convert_data_image_to_base64_ref(value) + if data_image_ref: + return data_image_ref + + local_path = get_existing_local_path(value) + if local_path and looks_like_image_file_name(local_path): + 
return local_path + return None diff --git a/astrbot/core/utils/quoted_message/image_resolver.py b/astrbot/core/utils/quoted_message/image_resolver.py new file mode 100644 index 000000000..5a4c21fb2 --- /dev/null +++ b/astrbot/core/utils/quoted_message/image_resolver.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import os +from typing import Any + +from astrbot import logger +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.utils.string_utils import normalize_and_dedupe_strings + +from .image_refs import IMAGE_EXTENSIONS, get_existing_local_path, normalize_image_ref +from .onebot_client import OneBotClient + + +def _build_image_id_candidates(image_ref: str) -> list[str]: + candidates: list[str] = [image_ref] + base_name, ext = os.path.splitext(image_ref) + if ext and base_name and base_name not in candidates: + if ext.lower() in IMAGE_EXTENSIONS: + candidates.append(base_name) + return candidates + + +def _build_image_resolve_actions( + event: AstrMessageEvent, + image_ref: str, +) -> list[tuple[str, dict[str, Any]]]: + actions: list[tuple[str, dict[str, Any]]] = [] + candidates = _build_image_id_candidates(image_ref) + + for candidate in candidates: + actions.extend( + [ + ("get_image", {"file": candidate}), + ("get_image", {"file_id": candidate}), + ("get_image", {"id": candidate}), + ("get_image", {"image": candidate}), + ("get_file", {"file_id": candidate}), + ("get_file", {"file": candidate}), + ] + ) + + try: + group_id = event.get_group_id() + except Exception: + group_id = None + group_id_value = group_id + if isinstance(group_id, str) and group_id.isdigit(): + group_id_value = int(group_id) + + if group_id_value: + for candidate in candidates: + actions.append( + ( + "get_group_file_url", + {"group_id": group_id_value, "file_id": candidate}, + ) + ) + for candidate in candidates: + actions.append(("get_private_file_url", {"file_id": candidate})) + + return actions + + +class ImageResolver: + def 
__init__( + self, + event: AstrMessageEvent, + onebot_client: OneBotClient | None = None, + ): + self._event = event + self._client = onebot_client or OneBotClient(event) + + async def resolve_for_llm(self, image_refs: list[str]) -> list[str]: + resolved: list[str] = [] + unresolved: list[str] = [] + + for image_ref in normalize_and_dedupe_strings(image_refs): + normalized = normalize_image_ref(image_ref) + if normalized: + resolved.append(normalized) + elif get_existing_local_path(image_ref): + # Drop non-image local paths instead of treating them as remote IDs. + logger.debug( + "quoted_message_parser: skip non-image local path ref=%s", + image_ref[:128], + ) + else: + unresolved.append(image_ref) + + for image_ref in unresolved: + resolved_ref = await self._resolve_one(image_ref) + if resolved_ref: + resolved.append(resolved_ref) + + return normalize_and_dedupe_strings(resolved) + + async def _resolve_one(self, image_ref: str) -> str | None: + resolved = normalize_image_ref(image_ref) + if resolved: + return resolved + + actions = _build_image_resolve_actions(self._event, image_ref) + for action, params in actions: + data = await self._client.call( + action, + params, + warn_on_all_failed=False, + unwrap_data=True, + ) + if not isinstance(data, dict): + continue + + url = data.get("url") + if isinstance(url, str): + normalized = normalize_image_ref(url) + if normalized: + return normalized + + file_value = data.get("file") + if isinstance(file_value, str): + normalized = normalize_image_ref(file_value) + if normalized: + return normalized + + logger.warning( + "quoted_message_parser: failed to resolve quoted image ref=%s after %d actions", + image_ref[:128], + len(actions), + ) + return None diff --git a/astrbot/core/utils/quoted_message/onebot_client.py b/astrbot/core/utils/quoted_message/onebot_client.py new file mode 100644 index 000000000..c48785d76 --- /dev/null +++ b/astrbot/core/utils/quoted_message/onebot_client.py @@ -0,0 +1,119 @@ +from __future__ 
import annotations + +from typing import Any + +from astrbot import logger +from astrbot.core.platform.astr_message_event import AstrMessageEvent + +from .settings import SETTINGS, QuotedMessageParserSettings + + +def _unwrap_action_response(ret: dict[str, Any] | None) -> dict[str, Any]: + if not isinstance(ret, dict): + return {} + data = ret.get("data") + if isinstance(data, dict): + return data + return ret + + +class OneBotClient: + def __init__( + self, + event: AstrMessageEvent, + settings: QuotedMessageParserSettings = SETTINGS, + ): + self._call_action = self._resolve_call_action(event) + self._settings = settings + + @staticmethod + def _resolve_call_action(event: AstrMessageEvent): + bot = getattr(event, "bot", None) + api = getattr(bot, "api", None) + call_action = getattr(api, "call_action", None) + if not callable(call_action): + return None + return call_action + + async def _call_action_try_params( + self, + action: str, + params_list: list[dict[str, Any]], + *, + warn_on_all_failed: bool | None = None, + ) -> dict[str, Any] | None: + if self._call_action is None: + return None + if warn_on_all_failed is None: + warn_on_all_failed = self._settings.warn_on_action_failure + + last_error: Exception | None = None + last_params: dict[str, Any] | None = None + for params in params_list: + try: + result = await self._call_action(action, **params) + if isinstance(result, dict): + return result + except Exception as exc: + last_error = exc + last_params = params + logger.debug( + "quoted_message_parser: action %s failed with params %s: %s", + action, + {k: str(v)[:64] for k, v in params.items()}, + exc, + ) + if warn_on_all_failed and last_error is not None: + logger.warning( + "quoted_message_parser: all attempts failed for action %s, " + "last_params=%s, error=%s", + action, + ( + {k: str(v)[:64] for k, v in last_params.items()} + if isinstance(last_params, dict) + else None + ), + last_error, + ) + return None + + async def call( + self, + action: str, + 
params: dict[str, Any], + *, + warn_on_all_failed: bool = False, + unwrap_data: bool = True, + ) -> dict[str, Any] | None: + ret = await self._call_action_try_params( + action, + [params], + warn_on_all_failed=warn_on_all_failed, + ) + if not unwrap_data: + return ret + return _unwrap_action_response(ret) + + async def _call_action_compat( + self, + action: str, + message_id: str | int, + ) -> dict[str, Any] | None: + message_id_str = str(message_id).strip() + if not message_id_str: + return None + + params_list: list[dict[str, Any]] = [ + {"message_id": message_id_str}, + {"id": message_id_str}, + ] + if message_id_str.isdigit(): + int_id = int(message_id_str) + params_list.extend([{"message_id": int_id}, {"id": int_id}]) + return await self._call_action_try_params(action, params_list) + + async def get_msg(self, message_id: str | int) -> dict[str, Any] | None: + return await self._call_action_compat("get_msg", message_id) + + async def get_forward_msg(self, forward_id: str | int) -> dict[str, Any] | None: + return await self._call_action_compat("get_forward_msg", forward_id) diff --git a/astrbot/core/utils/quoted_message/settings.py b/astrbot/core/utils/quoted_message/settings.py new file mode 100644 index 000000000..2f74f41b6 --- /dev/null +++ b/astrbot/core/utils/quoted_message/settings.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +_DEFAULT_MAX_COMPONENT_CHAIN_DEPTH = 4 +_DEFAULT_MAX_FORWARD_NODE_DEPTH = 6 +_DEFAULT_MAX_FORWARD_FETCH = 32 + + +def _read_int_mapping( + mapping: Mapping[str, Any], + key: str, + default: int, +) -> int: + raw = mapping.get(key) + if raw is None: + return default + try: + value = int(raw) + except (TypeError, ValueError): + return default + if value <= 0: + return default + return value + + +def _read_bool_mapping( + mapping: Mapping[str, Any], + key: str, + default: bool, +) -> bool: + raw = mapping.get(key) + if raw is 
None: + return default + if isinstance(raw, bool): + return raw + if isinstance(raw, str): + lowered = raw.strip().lower() + if lowered in {"1", "true", "yes", "on"}: + return True + if lowered in {"0", "false", "no", "off"}: + return False + return default + + +@dataclass(frozen=True) +class QuotedMessageParserSettings: + max_component_chain_depth: int = _DEFAULT_MAX_COMPONENT_CHAIN_DEPTH + max_forward_node_depth: int = _DEFAULT_MAX_FORWARD_NODE_DEPTH + max_forward_fetch: int = _DEFAULT_MAX_FORWARD_FETCH + warn_on_action_failure: bool = False + + def with_overrides( + self, + overrides: Mapping[str, Any] | None = None, + ) -> QuotedMessageParserSettings: + if not overrides: + return self + return QuotedMessageParserSettings( + max_component_chain_depth=_read_int_mapping( + overrides, + "max_component_chain_depth", + self.max_component_chain_depth, + ), + max_forward_node_depth=_read_int_mapping( + overrides, + "max_forward_node_depth", + self.max_forward_node_depth, + ), + max_forward_fetch=_read_int_mapping( + overrides, + "max_forward_fetch", + self.max_forward_fetch, + ), + warn_on_action_failure=_read_bool_mapping( + overrides, + "warn_on_action_failure", + self.warn_on_action_failure, + ), + ) + + +SETTINGS = QuotedMessageParserSettings() diff --git a/astrbot/core/utils/quoted_message_parser.py b/astrbot/core/utils/quoted_message_parser.py new file mode 100644 index 000000000..fa6ac18dd --- /dev/null +++ b/astrbot/core/utils/quoted_message_parser.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from astrbot.core.utils.quoted_message.extractor import ( + extract_quoted_message_images, + extract_quoted_message_text, +) + +__all__ = [ + "extract_quoted_message_text", + "extract_quoted_message_images", +] diff --git a/astrbot/core/utils/string_utils.py b/astrbot/core/utils/string_utils.py new file mode 100644 index 000000000..8c2aacb4d --- /dev/null +++ b/astrbot/core/utils/string_utils.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from 
collections.abc import Iterable +from typing import Any + + +def normalize_and_dedupe_strings(items: Iterable[Any] | None) -> list[str]: + if items is None: + return [] + + normalized: list[str] = [] + seen: set[str] = set() + for item in items: + if not isinstance(item, str): + continue + cleaned = item.strip() + if not cleaned or cleaned in seen: + continue + seen.add(cleaned) + normalized.append(cleaned) + return normalized diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index e1019cc8a..1fdd2ca91 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -247,6 +247,28 @@ "description": "Sanitize History by Modalities", "hint": "When enabled, sanitizes contexts before each LLM request by removing image blocks and tool-call structures that the current provider's modalities do not support (this changes what the model sees)." }, + "max_quoted_fallback_images": { + "description": "Forwarded Image Fetch Limit", + "hint": "Maximum number of images injected from forwarded-message parsing; extra images are truncated." + }, + "quoted_message_parser": { + "max_component_chain_depth": { + "description": "Forwarded Rich-Text Parse Depth", + "hint": "Maximum recursive depth when parsing rich-text component chains inside forwarded messages." + }, + "max_forward_node_depth": { + "description": "Forward Nesting Parse Depth", + "hint": "Maximum recursive depth when parsing nested forwarded nodes." + }, + "max_forward_fetch": { + "description": "Forward Recursive Fetch Limit", + "hint": "Maximum number of recursive get_forward_msg fetch operations." + }, + "warn_on_action_failure": { + "description": "Warn on Forward Parse Failure", + "hint": "When enabled, log warnings when all get_msg/get_forward_msg attempts fail." 
+ } + }, "max_agent_step": { "description": "Maximum Tool Call Rounds" }, diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index 67681aa1d..67c1d490e 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -250,6 +250,28 @@ "description": "按模型能力清理历史上下文", "hint": "开启后,在每次请求 LLM 前会按当前模型提供商中所选择的模型能力删除对话中不支持的图片/工具调用结构(会改变模型看到的历史)" }, + "max_quoted_fallback_images": { + "description": "转发消息中图片获取上限", + "hint": "转发消息解析到的图片最多注入数量,超出部分会截断。" + }, + "quoted_message_parser": { + "max_component_chain_depth": { + "description": "转发消息富文本解析深度", + "hint": "解析转发消息中的富文本组件链时允许的最大递归深度。" + }, + "max_forward_node_depth": { + "description": "转发消息嵌套解析深度", + "hint": "解析嵌套转发节点时允许的最大递归深度。" + }, + "max_forward_fetch": { + "description": "转发消息递归拉取上限", + "hint": "递归调用 get_forward_msg 拉取转发内容的最大次数。" + }, + "warn_on_action_failure": { + "description": "转发消息解析失败告警", + "hint": "开启后,get_msg/get_forward_msg 全部尝试失败时输出 warning 日志。" + } + }, "max_agent_step": { "description": "工具调用轮数上限" }, diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py new file mode 100644 index 000000000..3172097c7 --- /dev/null +++ b/tests/test_openai_source.py @@ -0,0 +1,382 @@ +from types import SimpleNamespace + +import pytest + +from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial + + +class _ErrorWithBody(Exception): + def __init__(self, message: str, body: dict): + super().__init__(message) + self.body = body + + +class _ErrorWithResponse(Exception): + def __init__(self, message: str, response_text: str): + super().__init__(message) + self.response = SimpleNamespace(text=response_text) + + +def _make_provider(overrides: dict | None = None) -> ProviderOpenAIOfficial: + provider_config = { + "id": "test-openai", + "type": "openai_chat_completion", + "model": "gpt-4o-mini", + "key": ["test-key"], + } + 
if overrides: + provider_config.update(overrides) + return ProviderOpenAIOfficial( + provider_config=provider_config, + provider_settings={}, + ) + + +@pytest.mark.asyncio +async def test_handle_api_error_content_moderated_removes_images(): + provider = _make_provider( + {"image_moderation_error_patterns": ["file:content-moderated"]} + ) + try: + payloads = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "hello"}, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,abcd"}, + }, + ], + } + ] + } + context_query = payloads["messages"] + + success, *_rest = await provider._handle_api_error( + Exception("Content is moderated [WKE=file:content-moderated]"), + payloads=payloads, + context_query=context_query, + func_tool=None, + chosen_key="test-key", + available_api_keys=["test-key"], + retry_cnt=0, + max_retries=10, + ) + + assert success is False + updated_context = payloads["messages"] + assert isinstance(updated_context, list) + assert updated_context[0]["content"] == [{"type": "text", "text": "hello"}] + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_handle_api_error_model_not_vlm_removes_images_and_retries_text_only(): + provider = _make_provider() + try: + payloads = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "hello"}, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,abcd"}, + }, + ], + } + ] + } + context_query = payloads["messages"] + + success, *_rest = await provider._handle_api_error( + Exception("The model is not a VLM and cannot process images"), + payloads=payloads, + context_query=context_query, + func_tool=None, + chosen_key="test-key", + available_api_keys=["test-key"], + retry_cnt=0, + max_retries=10, + ) + + assert success is False + updated_context = payloads["messages"] + assert isinstance(updated_context, list) + assert updated_context[0]["content"] == [{"type": "text", "text": "hello"}] + 
finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_handle_api_error_model_not_vlm_after_fallback_raises(): + provider = _make_provider() + try: + payloads = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "hello"}, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,abcd"}, + }, + ], + } + ] + } + context_query = payloads["messages"] + + with pytest.raises(Exception, match="not a VLM"): + await provider._handle_api_error( + Exception("The model is not a VLM and cannot process images"), + payloads=payloads, + context_query=context_query, + func_tool=None, + chosen_key="test-key", + available_api_keys=["test-key"], + retry_cnt=1, + max_retries=10, + image_fallback_used=True, + ) + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_handle_api_error_content_moderated_with_unserializable_body(): + provider = _make_provider({"image_moderation_error_patterns": ["blocked"]}) + try: + payloads = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "hello"}, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,abcd"}, + }, + ], + } + ] + } + context_query = payloads["messages"] + err = _ErrorWithBody( + "upstream error", + {"error": {"message": "blocked"}, "raw": object()}, + ) + + success, *_rest = await provider._handle_api_error( + err, + payloads=payloads, + context_query=context_query, + func_tool=None, + chosen_key="test-key", + available_api_keys=["test-key"], + retry_cnt=0, + max_retries=10, + ) + assert success is False + assert payloads["messages"][0]["content"] == [{"type": "text", "text": "hello"}] + finally: + await provider.terminate() + + +def test_extract_error_text_candidates_truncates_long_response_text(): + long_text = "x" * 20000 + err = _ErrorWithResponse("upstream error", long_text) + candidates = ProviderOpenAIOfficial._extract_error_text_candidates(err) + assert candidates + assert 
max(len(candidate) for candidate in candidates) <= ( + ProviderOpenAIOfficial._ERROR_TEXT_CANDIDATE_MAX_CHARS + ) + + +@pytest.mark.asyncio +async def test_handle_api_error_content_moderated_without_images_raises(): + provider = _make_provider( + {"image_moderation_error_patterns": ["file:content-moderated"]} + ) + try: + payloads = { + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "hello"}], + } + ] + } + context_query = payloads["messages"] + err = Exception("Content is moderated [WKE=file:content-moderated]") + + with pytest.raises(Exception, match="content-moderated"): + await provider._handle_api_error( + err, + payloads=payloads, + context_query=context_query, + func_tool=None, + chosen_key="test-key", + available_api_keys=["test-key"], + retry_cnt=0, + max_retries=10, + ) + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_handle_api_error_content_moderated_detects_structured_body(): + provider = _make_provider( + {"image_moderation_error_patterns": ["content_moderated"]} + ) + try: + payloads = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "hello"}, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,abcd"}, + }, + ], + } + ] + } + context_query = payloads["messages"] + err = _ErrorWithBody( + "upstream error", + {"error": {"code": "content_moderated", "message": "blocked"}}, + ) + + success, *_rest = await provider._handle_api_error( + err, + payloads=payloads, + context_query=context_query, + func_tool=None, + chosen_key="test-key", + available_api_keys=["test-key"], + retry_cnt=0, + max_retries=10, + ) + assert success is False + assert payloads["messages"][0]["content"] == [{"type": "text", "text": "hello"}] + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_handle_api_error_content_moderated_supports_custom_patterns(): + provider = _make_provider( + {"image_moderation_error_patterns": 
["blocked_by_policy_code_123"]} + ) + try: + payloads = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "hello"}, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,abcd"}, + }, + ], + } + ] + } + context_query = payloads["messages"] + err = Exception("upstream: blocked_by_policy_code_123") + + success, *_rest = await provider._handle_api_error( + err, + payloads=payloads, + context_query=context_query, + func_tool=None, + chosen_key="test-key", + available_api_keys=["test-key"], + retry_cnt=0, + max_retries=10, + ) + assert success is False + assert payloads["messages"][0]["content"] == [{"type": "text", "text": "hello"}] + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_handle_api_error_content_moderated_without_patterns_raises(): + provider = _make_provider() + try: + payloads = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "hello"}, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,abcd"}, + }, + ], + } + ] + } + context_query = payloads["messages"] + err = Exception("Content is moderated [WKE=file:content-moderated]") + + with pytest.raises(Exception, match="content-moderated"): + await provider._handle_api_error( + err, + payloads=payloads, + context_query=context_query, + func_tool=None, + chosen_key="test-key", + available_api_keys=["test-key"], + retry_cnt=0, + max_retries=10, + ) + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_handle_api_error_unknown_image_error_raises(): + provider = _make_provider() + try: + payloads = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "hello"}, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,abcd"}, + }, + ], + } + ] + } + context_query = payloads["messages"] + + with pytest.raises(Exception, match="unknown provider image upload error"): + await provider._handle_api_error( + 
Exception("some unknown provider image upload error"), + payloads=payloads, + context_query=context_query, + func_tool=None, + chosen_key="test-key", + available_api_keys=["test-key"], + retry_cnt=0, + max_retries=10, + ) + finally: + await provider.terminate() diff --git a/tests/test_quoted_message_parser.py b/tests/test_quoted_message_parser.py new file mode 100644 index 000000000..0a0e126d5 --- /dev/null +++ b/tests/test_quoted_message_parser.py @@ -0,0 +1,494 @@ +from types import SimpleNamespace + +import pytest + +from astrbot.core.message.components import Image, Plain, Reply +from astrbot.core.utils.quoted_message_parser import ( + extract_quoted_message_images, + extract_quoted_message_text, +) + + +class _DummyAPI: + def __init__( + self, + responses: dict[tuple[str, str], dict], + param_responses: dict[tuple[str, tuple[tuple[str, str], ...]], dict] + | None = None, + ): + self._responses = responses + self._param_responses = param_responses or {} + + async def call_action(self, action: str, **params): + param_key = (action, tuple(sorted((k, str(v)) for k, v in params.items()))) + if param_key in self._param_responses: + return self._param_responses[param_key] + + msg_id = params.get("message_id") + if msg_id is None: + msg_id = params.get("id") + key = (action, str(msg_id)) + if key not in self._responses: + raise RuntimeError(f"no mock response for {key}") + return self._responses[key] + + +class _FailIfCalledAPI: + async def call_action(self, action: str, **params): + raise AssertionError( + f"call_action should not be called, got action={action}, params={params}" + ) + + +def _make_event( + reply: Reply, + responses: dict[tuple[str, str], dict] | None = None, + param_responses: dict[tuple[str, tuple[tuple[str, str], ...]], dict] | None = None, +): + if responses is None: + responses = {} + if param_responses is None: + param_responses = {} + return SimpleNamespace( + message_obj=SimpleNamespace(message=[reply]), + 
bot=SimpleNamespace(api=_DummyAPI(responses, param_responses)), + get_group_id=lambda: "", + ) + + +@pytest.mark.asyncio +async def test_extract_quoted_message_text_from_reply_chain(): + reply = Reply(id="1", chain=[Plain(text="quoted content")], message_str="") + event = _make_event(reply) + text = await extract_quoted_message_text(event) + assert text == "quoted content" + + +@pytest.mark.asyncio +async def test_extract_quoted_message_text_no_reply_component(): + event = SimpleNamespace( + message_obj=SimpleNamespace(message=[Plain(text="unquoted message")]), + bot=SimpleNamespace(api=_DummyAPI({}, {})), + get_group_id=lambda: "", + ) + + text = await extract_quoted_message_text(event) + assert text is None + + +@pytest.mark.asyncio +async def test_extract_quoted_message_images_no_reply_component(): + event = SimpleNamespace( + message_obj=SimpleNamespace(message=[Plain(text="unquoted message")]), + bot=SimpleNamespace(api=_FailIfCalledAPI()), + get_group_id=lambda: "", + ) + + images = await extract_quoted_message_images(event) + assert images == [] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("reply_id", [None, ""]) +async def test_extract_quoted_message_text_reply_without_id_does_not_call_get_msg( + reply_id: str | None, +): + reply = Reply( + id="placeholder", chain=[Plain(text="quoted content")], message_str="" + ) + object.__setattr__(reply, "id", reply_id) + event = SimpleNamespace( + message_obj=SimpleNamespace(message=[reply]), + bot=SimpleNamespace(api=_FailIfCalledAPI()), + get_group_id=lambda: "", + ) + + text = await extract_quoted_message_text(event) + assert text == "quoted content" + + +@pytest.mark.asyncio +async def test_extract_quoted_message_text_fallback_get_msg_and_forward(): + reply = Reply(id="100", chain=None, message_str="") + event = _make_event( + reply, + responses={ + ( + "get_msg", + "100", + ): { + "data": { + "message": [ + {"type": "text", "data": {"text": "parent"}}, + {"type": "forward", "data": {"id": "fwd_1"}}, + ] + } 
+ }, + ( + "get_forward_msg", + "fwd_1", + ): { + "data": { + "messages": [ + { + "sender": {"nickname": "Alice"}, + "message": [{"type": "text", "data": {"text": "hello"}}], + }, + { + "sender": {"nickname": "Bob"}, + "message": [ + {"type": "image", "data": {"url": "http://img"}}, + {"type": "text", "data": {"text": "world"}}, + ], + }, + ] + } + }, + }, + ) + + text = await extract_quoted_message_text(event) + assert text is not None + assert "parent" in text + assert "Alice: hello" in text + assert "Bob: [Image]world" in text + + +@pytest.mark.parametrize( + "placeholder_text", + [ + "[Forward Message]", + "[转发消息]", + "[合并转发]", + "Alice: [Forward Message]", + "(Alice): [转发消息]", + "[Forward Message]\n[转发消息]", + "Alice: [Forward Message]\n(Bob): [合并转发]", + "[转发消息]\n\n[合并转发]", + ], +) +@pytest.mark.asyncio +async def test_extract_quoted_message_text_forward_placeholder_variants_trigger_fallback( + placeholder_text: str, +): + reply = Reply(id="400", chain=[Plain(text=placeholder_text)], message_str="") + event = _make_event( + reply, + responses={ + ("get_msg", "400"): { + "data": { + "message": [ + {"type": "text", "data": {"text": "Bob: "}}, + {"type": "image", "data": {}}, + {"type": "text", "data": {"text": "world"}}, + ] + } + } + }, + ) + + text = await extract_quoted_message_text(event) + assert "Bob: [Image]world" in text + + +@pytest.mark.asyncio +async def test_extract_quoted_message_text_mixed_placeholder_does_not_trigger_fallback(): + reply = Reply( + id="402", + chain=[Plain(text="Alice: [Forward Message]\nreal text")], + message_str="", + ) + event = SimpleNamespace( + message_obj=SimpleNamespace(message=[reply]), + bot=SimpleNamespace(api=_FailIfCalledAPI()), + get_group_id=lambda: "", + ) + + text = await extract_quoted_message_text(event) + assert text is not None + assert "[Forward Message]" in text + assert "real text" in text + + +@pytest.mark.asyncio +async def test_extract_quoted_message_text_forward_placeholder_fallback_failure(): + reply = 
Reply(id="401", chain=[Plain(text="[Forward Message]")], message_str="") + event = _make_event(reply, responses={}) + + text = await extract_quoted_message_text(event) + assert text == "[Forward Message]" + + +@pytest.mark.asyncio +async def test_extract_quoted_message_text_multimsg_malformed_config_does_not_raise(): + reply = Reply(id="402", chain=None, message_str="") + event = _make_event( + reply, + responses={ + ("get_msg", "402"): { + "data": { + "message": [ + { + "type": "json", + "data": { + "data": ( + '{"app":"com.tencent.multimsg",' + '"config":"oops","meta":{}}' + ) + }, + }, + {"type": "text", "data": {"text": "still works"}}, + ] + } + } + }, + ) + + text = await extract_quoted_message_text(event) + assert text == "still works" + + +@pytest.mark.asyncio +async def test_extract_quoted_message_images_from_reply_chain(): + reply = Reply( + id="1", + chain=[ + Plain(text="quoted"), + Image(file="https://img.example.com/a.jpg"), + ], + message_str="", + ) + event = _make_event(reply) + + images = await extract_quoted_message_images(event) + assert images == ["https://img.example.com/a.jpg"] + + +@pytest.mark.asyncio +async def test_extract_quoted_message_images_fallback_get_msg_direct_url(): + reply = Reply(id="200", chain=None, message_str="") + event = _make_event( + reply, + responses={ + ("get_msg", "200"): { + "data": { + "message": [ + { + "type": "image", + "data": {"url": "https://img.example.com/direct.jpg"}, + } + ] + } + } + }, + ) + + images = await extract_quoted_message_images(event) + assert images == ["https://img.example.com/direct.jpg"] + + +@pytest.mark.asyncio +async def test_extract_quoted_message_images_data_image_ref_normalized_to_base64(): + data_image_ref = "data:image/png;base64,abcd1234==" + reply = Reply(id="201", chain=None, message_str="") + event = _make_event( + reply, + responses={ + ("get_msg", "201"): { + "data": { + "message": [ + {"type": "image", "data": {"url": data_image_ref}}, + ] + } + } + }, + ) + + images = 
await extract_quoted_message_images(event) + assert images == ["base64://abcd1234=="] + + +@pytest.mark.asyncio +async def test_extract_quoted_message_images_file_url_with_query_string(): + url_with_query = "https://img.example.com/direct.jpg?token=abc123#frag" + reply = Reply(id="205", chain=None, message_str="") + event = _make_event( + reply, + responses={ + ("get_msg", "205"): { + "data": { + "message": [ + { + "type": "file", + "data": { + "url": url_with_query, + "name": "direct.jpg", + }, + } + ] + } + } + }, + ) + + images = await extract_quoted_message_images(event) + assert images == [url_with_query] + + +@pytest.mark.asyncio +async def test_extract_quoted_message_images_non_image_local_path_is_ignored(tmp_path): + non_image_file = tmp_path / "secret.txt" + non_image_file.write_text("not an image", encoding="utf-8") + + reply = Reply( + id="placeholder", chain=[Image(file=str(non_image_file))], message_str="" + ) + object.__setattr__(reply, "id", None) + event = SimpleNamespace( + message_obj=SimpleNamespace(message=[reply]), + bot=SimpleNamespace(api=_FailIfCalledAPI()), + get_group_id=lambda: "", + ) + + images = await extract_quoted_message_images(event) + assert images == [] + + +@pytest.mark.asyncio +async def test_extract_quoted_message_images_chain_placeholder_triggers_fallback(): + reply = Reply(id="210", chain=[Plain(text="[Forward Message]")], message_str="") + event = _make_event( + reply, + responses={ + ("get_msg", "210"): { + "data": { + "message": [ + { + "type": "image", + "data": { + "url": "https://img.example.com/from-fallback.jpg" + }, + } + ] + } + } + }, + ) + + images = await extract_quoted_message_images(event) + assert images == ["https://img.example.com/from-fallback.jpg"] + + +@pytest.mark.asyncio +async def test_extract_quoted_message_images_fallback_resolve_file_id_with_get_image(): + reply = Reply(id="300", chain=None, message_str="") + event = _make_event( + reply, + responses={ + ("get_msg", "300"): { + "data": {"message": 
[{"type": "image", "data": {"file": "abc123.jpg"}}]} + } + }, + param_responses={ + ("get_image", (("file", "abc123.jpg"),)): { + "data": {"url": "https://img.example.com/resolved.jpg"} + } + }, + ) + + images = await extract_quoted_message_images(event) + assert images == ["https://img.example.com/resolved.jpg"] + + +@pytest.mark.asyncio +async def test_extract_quoted_message_images_deduplicates_across_sources(): + dup_url = "https://img.example.com/dup.jpg" + chain_only_url = "https://img.example.com/only-chain.jpg" + get_msg_only_url = "https://img.example.com/only-get-msg.jpg" + forward_only_url = "https://img.example.com/only-forward.jpg" + + reply = Reply( + id="310", + chain=[Image(file=dup_url), Image(file=chain_only_url)], + message_str="", + ) + + event = _make_event( + reply, + responses={ + ("get_msg", "310"): { + "data": { + "message": [ + {"type": "image", "data": {"url": dup_url}}, + {"type": "image", "data": {"url": get_msg_only_url}}, + {"type": "forward", "data": {"id": "999"}}, + ] + } + }, + ("get_forward_msg", "999"): { + "data": { + "messages": [ + { + "sender": {"nickname": "Tester"}, + "message": [ + {"type": "image", "data": {"url": dup_url}}, + {"type": "image", "data": {"url": forward_only_url}}, + ], + } + ] + } + }, + }, + ) + + images = await extract_quoted_message_images(event) + assert images == [ + dup_url, + chain_only_url, + get_msg_only_url, + forward_only_url, + ] + + +@pytest.mark.asyncio +async def test_extract_quoted_message_nested_forward_id_is_resolved(): + nested_image = "https://img.example.com/nested.jpg" + reply = Reply(id="320", chain=[Plain(text="[Forward Message]")], message_str="") + event = _make_event( + reply, + responses={ + ("get_msg", "320"): { + "data": {"message": [{"type": "forward", "data": {"id": "fwd_1"}}]} + }, + ("get_forward_msg", "fwd_1"): { + "data": { + "messages": [ + { + "sender": {"nickname": "Alice"}, + "message": [{"type": "forward", "data": {"id": "fwd_2"}}], + } + ] + } + }, + 
("get_forward_msg", "fwd_2"): { + "data": { + "messages": [ + { + "sender": {"nickname": "Bob"}, + "message": [ + {"type": "text", "data": {"text": "deep"}}, + {"type": "image", "data": {"url": nested_image}}, + ], + } + ] + } + }, + }, + ) + + text = await extract_quoted_message_text(event) + assert text is not None + assert "Bob: deep" in text + + images = await extract_quoted_message_images(event) + assert images == [nested_image]