diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 5020d6b26..74538d097 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -26,10 +26,12 @@ import base64 import json import os import uuid +import asyncio import typing as T from enum import Enum from pydantic.v1 import BaseModel -from astrbot.core.utils.io import download_image_by_url, file_to_base64 +from astrbot.core import logger +from astrbot.core.utils.io import download_image_by_url, file_to_base64, download_file class ComponentType(Enum): @@ -552,15 +554,91 @@ class Unknown(BaseMessageComponent): class File(BaseMessageComponent): """ - 目前此消息段只适配了 Napcat。 + 文件消息段 """ type: ComponentType = "File" name: T.Optional[str] = "" # 名字 - file: T.Optional[str] = "" # url(本地路径) + _file: T.Optional[str] = "" # 本地路径 + url: T.Optional[str] = "" # url + _downloaded: bool = False # 是否已经下载 - def __init__(self, name: str, file: str): - super().__init__(name=name, file=file) + def __init__(self, name: str = "", file: str = "", url: str = ""): + super().__init__(name=name, _file=file, url=url) + + @property + def file(self) -> str: + """ + 获取文件路径,如果文件不存在但有URL,则同步下载文件 + + Returns: + str: 文件路径 + """ + if self._file and os.path.exists(self._file): + return self._file + + if self.url and not self._downloaded: + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + logger.warning( + "不可以在异步上下文中同步等待下载! 请使用 await get_file() 代替" + ) + return "" + else: + # 等待下载完成 + loop.run_until_complete(self._download_file()) + + if self._file and os.path.exists(self._file): + return self._file + except Exception as e: + logger.error(f"文件下载失败: {e}") + + return "" + + @file.setter + def file(self, value: str): + """ + 向前兼容, 设置file属性, 传入的参数可能是文件路径或URL + + Args: + value (str): 文件路径或URL + """ + if value.startswith("http://") or value.startswith("https://"): + self.url = value + else: + self._file = value + + async def get_file(self) -> str: + """ + 异步获取文件 + To 插件开发者: 请注意在使用后清理下载的文件, 以免占用过多空间 + + Returns: + str: 文件路径 + """ + if self._file and os.path.exists(self._file): + return self._file + + if self.url: + await self._download_file() + return self._file + + return "" + + async def _download_file(self): + """下载文件""" + if self._downloaded: + return + + os.makedirs("data/download", exist_ok=True) + filename = self.name or f"{uuid.uuid4().hex}" + file_path = f"data/download/{filename}" + + await download_file(self.url, file_path) + + self._file = file_path + self._downloaded = True class WechatEmoji(BaseMessageComponent): diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index a1822c5e9..97754f2c9 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -1,4 +1,3 @@ -import os import time import asyncio import logging @@ -21,7 +20,6 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent from astrbot.core.platform.astr_message_event import MessageSesion from ...register import register_platform_adapter from aiocqhttp.exceptions import ActionFailed -from astrbot.core.utils.io import download_file @register_platform_adapter( @@ -167,7 +165,9 @@ class AiocqhttpAdapter(Platform): if "sub_type" in event: if event["sub_type"] == "poke" and "target_id" in event: - abm.message.append(Poke(qq=str(event["target_id"]), type="poke")) # noqa: F405 + abm.message.append( + Poke(qq=str(event["target_id"]), type="poke") + ) # noqa: F405 return abm @@ -227,32 +227,30 @@ class AiocqhttpAdapter(Platform): if m["data"].get("url") and m["data"].get("url").startswith("http"): # Lagrange logger.info("guessing lagrange") - file_name = m["data"].get("file_name", "file") - path = os.path.join("data/temp", file_name) - await download_file(m["data"]["url"], path) - - m["data"] = {"file": path, "name": file_name} - a = ComponentTypes[t](**m["data"]) # noqa: F405 - abm.message.append(a) - + abm.message.append(File(name=file_name, url=m["data"]["url"])) else: try: - # Napcat, LLBot - ret = await self.bot.call_action( - action="get_file", - file_id=event.message[0]["data"]["file_id"], - ) - if not ret.get("file", None): - raise ValueError(f"无法解析文件响应: {ret}") - if not os.path.exists(ret["file"]): - raise FileNotFoundError( - f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot" + # Napcat + ret = None + if abm.type == MessageType.GROUP_MESSAGE: + ret = await self.bot.call_action( + action="get_group_file_url", + file_id=event.message[0]["data"]["file_id"], + group_id=event.group_id, ) + elif abm.type == MessageType.FRIEND_MESSAGE: + ret = await self.bot.call_action( + action="get_private_file_url", + file_id=event.message[0]["data"]["file_id"], + ) + if ret and "url" in ret: + file_url = ret["url"] # https + a = File(name="", url=file_url) + abm.message.append(a) + else: + logger.error(f"获取文件失败: {ret}") - m["data"] = {"file": ret["file"], "name": ret["file_name"]} - a = ComponentTypes[t](**m["data"]) # noqa: F405 - abm.message.append(a) except ActionFailed as e: logger.error(f"获取文件失败: {e},此消息段将被忽略。") except BaseException as e: diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index d227f1f68..a2ce88736 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -58,8 +58,12 @@ class TelegramPlatformAdapter(Platform): self.base_url = base_url - self.enable_command_register = self.config.get("telegram_command_register", True) - self.enable_command_refresh = self.config.get("telegram_command_auto_refresh", True) + self.enable_command_register = self.config.get( + "telegram_command_register", True + ) + self.enable_command_refresh = self.config.get( + "telegram_command_auto_refresh", True + ) self.last_command_hash = None self.application = ( @@ -123,7 +127,9 @@ class TelegramPlatformAdapter(Platform): commands = self.collect_commands() if commands: - current_hash = hash(tuple((cmd.command, cmd.description) for cmd in commands)) + current_hash = hash( + tuple((cmd.command, cmd.description) for cmd in commands) + ) if current_hash == self.last_command_hash: return self.last_command_hash = current_hash