diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 7f73e5d91..7fdcf4a18 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -25,9 +25,11 @@ SOFTWARE. import base64 import json import os +import uuid import typing as T from enum import Enum from pydantic.v1 import BaseModel +from astrbot.core.utils.io import download_image_by_url, file_to_base64 class ComponentType(Enum): @@ -146,6 +148,51 @@ class Record(BaseMessageComponent): return Record(file=url, **_) raise Exception("not a valid url") + async def convert_to_file_path(self) -> str: + """将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。 + + Returns: + str: 语音的本地路径,以绝对路径表示。 + """ + if self.file and self.file.startswith("file:///"): + file_path = self.file[8:] + return file_path + elif self.file and self.file.startswith("http"): + file_path = await download_image_by_url(self.file) + return os.path.abspath(file_path) + elif self.file and self.file.startswith("base64://"): + bs64_data = self.file.removeprefix("base64://") + image_bytes = base64.b64decode(bs64_data) + file_path = f"data/temp/{uuid.uuid4()}.jpg" + with open(file_path, "wb") as f: + f.write(image_bytes) + return os.path.abspath(file_path) + elif os.path.exists(self.file): + file_path = self.file + return os.path.abspath(file_path) + else: + raise Exception(f"not a valid file: {self.file}") + + async def convert_to_base64(self) -> str: + """将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。 + + Returns: + str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。 + """ + # convert to base64 + if self.file and self.file.startswith("file:///"): + bs64_data = file_to_base64(self.file[8:]) + elif self.file and self.file.startswith("http"): + file_path = await download_image_by_url(self.file) + bs64_data = file_to_base64(file_path) + elif self.file and self.file.startswith("base64://"): + bs64_data = self.file + elif os.path.exists(self.file): + bs64_data = file_to_base64(self.file) + else: + raise Exception(f"not a valid file: {self.file}") + return bs64_data + class Video(BaseMessageComponent): type: ComponentType = "Video" @@ -279,10 +326,6 @@ class Image(BaseMessageComponent): file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识 def __init__(self, file: T.Optional[str], **_): - # for k in _.keys(): - # if (k == "_type" and _[k] not in ["flash", "show", None]) or \ - # (k == "c" and _[k] not in [2, 3]): - # logger.warn(f"Protocol: {k}={_[k]} doesn't match values") super().__init__(file=file, **_) @staticmethod @@ -307,6 +350,51 @@ class Image(BaseMessageComponent): def fromIO(IO): return Image.fromBytes(IO.read()) + async def convert_to_file_path(self) -> str: + """将这个图片统一转换为本地文件路径。这个方法避免了手动判断图片数据类型,直接返回图片数据的本地路径(如果是网络 URL, 则会自动进行下载)。 + + Returns: + str: 图片的本地路径,以绝对路径表示。 + """ + if self.file and self.file.startswith("file:///"): + image_file_path = self.file[8:] + return image_file_path + elif self.file and self.file.startswith("http"): + image_file_path = await download_image_by_url(self.file) + return os.path.abspath(image_file_path) + elif self.file and self.file.startswith("base64://"): + bs64_data = self.file.removeprefix("base64://") + image_bytes = base64.b64decode(bs64_data) + image_file_path = f"data/temp/{uuid.uuid4()}.jpg" + with open(image_file_path, "wb") as f: + f.write(image_bytes) + return os.path.abspath(image_file_path) + elif os.path.exists(self.file): + image_file_path = self.file + return os.path.abspath(image_file_path) + else: + raise Exception(f"not a valid file: {self.file}") + + async def convert_to_base64(self) -> str: + """将这个图片统一转换为 base64 编码。这个方法避免了手动判断图片数据类型,直接返回图片数据的 base64 编码。 + + Returns: + str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。 + """ + # convert to base64 + if self.file and self.file.startswith("file:///"): + bs64_data = file_to_base64(self.file[8:]) + elif self.file and self.file.startswith("http"): + image_file_path = await download_image_by_url(self.file) + bs64_data = file_to_base64(image_file_path) + elif self.file and self.file.startswith("base64://"): + bs64_data = self.file + elif os.path.exists(self.file): + bs64_data = file_to_base64(self.file) + else: + raise Exception(f"not a valid file: {self.file}") + return bs64_data + class Reply(BaseMessageComponent): type: ComponentType = "Reply" diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index e8246805e..210e62a7c 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -64,8 +64,8 @@ class LLMRequestSubStage(Stage): req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() for comp in event.message_obj.message: if isinstance(comp, Image): - image_url = comp.url if comp.url else comp.file - req.image_urls.append(image_url) + image_path = await comp.convert_to_file_path() + req.image_urls.append(image_path) # 获取对话上下文 conversation_id = await self.conv_manager.get_curr_conversation_id( diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 57d0c4f5b..08990015e 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -3,8 +3,6 @@ import asyncio from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes from aiocqhttp import CQHttp -from astrbot.core.utils.io import file_to_base64, download_image_by_url - class AiocqhttpMessageEvent(AstrMessageEvent): def __init__( @@ -24,18 +22,9 @@ class AiocqhttpMessageEvent(AstrMessageEvent): d["data"]["text"] = segment.text.strip() elif isinstance(segment, (Image, Record)): # convert to base64 - if segment.file and segment.file.startswith("file:///"): - bs64_data = file_to_base64(segment.file[8:]) - image_file_path = segment.file[8:] - elif segment.file and segment.file.startswith("http"): - image_file_path = await download_image_by_url(segment.file) - bs64_data = file_to_base64(image_file_path) - elif segment.file and segment.file.startswith("base64://"): - bs64_data = segment.file - else: - bs64_data = file_to_base64(segment.file) + bs64 = await segment.convert_to_base64() d["data"] = { - "file": bs64_data, + "file": bs64, } elif isinstance(segment, At): d["data"] = { diff --git a/astrbot/core/platform/sources/gewechat/gewechat_event.py b/astrbot/core/platform/sources/gewechat/gewechat_event.py index 7668663fb..247a2a6a4 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_event.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_event.py @@ -70,18 +70,10 @@ class GewechatPlatformEvent(AstrMessageEvent): await client.post_text(**payload) elif isinstance(comp, Image): - img_url = comp.file - img_path = "" - if img_url.startswith("file:///"): - img_path = img_url[8:] - elif comp.file and comp.file.startswith("http"): - img_path = await download_image_by_url(comp.file) - else: - img_path = img_url + img_path = await comp.convert_to_file_path() - # 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径 + # 检查 record_path 是否在 data/temp 目录中 temp_directory = os.path.abspath("data/temp") - img_path = os.path.abspath(img_path) if os.path.commonpath([temp_directory, img_path]) != temp_directory: with open(img_path, "rb") as f: img_path = save_temp_img(f.read()) @@ -93,14 +85,7 @@ class GewechatPlatformEvent(AstrMessageEvent): elif isinstance(comp, Record): # 默认已经存在 data/temp 中 record_url = comp.file - record_path = "" - - if record_url.startswith("file:///"): - record_path = record_url[8:] - elif record_url.startswith("http"): - await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav") - else: - record_path = record_url + record_path = await comp.convert_to_file_path() silk_path = f"data/temp/{uuid.uuid4()}.silk" try: diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index fd29b3602..1ee30c482 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -2,6 +2,7 @@ import base64 import asyncio import json import re +import astrbot.api.message_components as Comp from astrbot.api.platform import ( Platform, @@ -11,7 +12,6 @@ from astrbot.api.platform import ( PlatformMetadata, ) from astrbot.api.event import MessageChain -from astrbot.api.message_components import Image, Plain, At from astrbot.core.platform.astr_message_event import MessageSesion from .lark_event import LarkMessageEvent from ...register import register_platform_adapter @@ -92,7 +92,7 @@ class LarkPlatformAdapter(Platform): at_list = {} if message.mentions: for m in message.mentions: - at_list[m.key] = At(qq=m.id.open_id, name=m.name) + at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name) if m.name == self.bot_name: abm.self_id = m.id.open_id @@ -111,7 +111,7 @@ class LarkPlatformAdapter(Platform): if s in at_list: abm.message.append(at_list[s]) else: - abm.message.append(Plain(parts[i].strip())) + abm.message.append(Comp.Plain(parts[i].strip())) elif message.message_type == "post": _ls = [] @@ -132,7 +132,7 @@ class LarkPlatformAdapter(Platform): if comp["tag"] == "at": abm.message.append(at_list[comp["user_id"]]) elif comp["tag"] == "text" and comp["text"].strip(): - abm.message.append(Plain(comp["text"].strip())) + abm.message.append(Comp.Plain(comp["text"].strip())) elif comp["tag"] == "img": image_key = comp["image_key"] request = ( @@ -147,10 +147,10 @@ class LarkPlatformAdapter(Platform): logger.error(f"无法下载飞书图片: {image_key}") image_bytes = response.file.read() image_base64 = base64.b64encode(image_bytes).decode() - abm.message.append(Image.fromBase64(image_base64)) + abm.message.append(Comp.Image.fromBase64(image_base64)) for comp in abm.message: - if isinstance(comp, Plain): + if isinstance(comp, Comp.Plain): abm.message_str += comp.text abm.message_id = message.message_id abm.raw_message = message diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index a8a04e2e1..d19017a4f 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -51,19 +51,8 @@ class TelegramPlatformEvent(AstrMessageEvent): at_flag = True await client.send_message(text=i.text, **payload) elif isinstance(i, Image): - if i.path: - image_path = i.path - else: - image_path = i.file - - if image_path.startswith("base64://"): - import base64 - - base64_data = image_path[9:] - image_bytes = base64.b64decode(base64_data) - await client.send_photo(photo=image_bytes, **payload) - else: - await client.send_photo(photo=image_path, **payload) + image_path = await i.convert_to_file_path() + await client.send_photo(photo=image_path, **payload) elif isinstance(i, File): if i.file.startswith("https://"): path = "data/temp/" + i.name @@ -72,7 +61,8 @@ class TelegramPlatformEvent(AstrMessageEvent): await client.send_document(document=i.file, filename=i.name, **payload) elif isinstance(i, Record): - await client.send_voice(voice=i.file, **payload) + path = await i.convert_to_file_path() + await client.send_voice(voice=path, **payload) async def send(self, message: MessageChain): if self.get_message_type() == MessageType.GROUP_MESSAGE: diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 83e99b5c4..c6f8d6ef6 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -43,14 +43,7 @@ class WecomPlatformEvent(AstrMessageEvent): message_obj.self_id, message_obj.session_id, comp.text ) elif isinstance(comp, Image): - img_url = comp.file - img_path = "" - if img_url.startswith("file:///"): - img_path = img_url[8:] - elif comp.file and comp.file.startswith("http"): - img_path = await download_image_by_url(comp.file) - else: - img_path = img_url + img_path = await comp.convert_to_file_path() with open(img_path, "rb") as f: try: @@ -68,16 +61,7 @@ class WecomPlatformEvent(AstrMessageEvent): response["media_id"], ) elif isinstance(comp, Record): - record_url = comp.file - record_path = "" - - if record_url.startswith("file:///"): - record_path = record_url[8:] - elif record_url.startswith("http"): - await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav") - else: - record_path = record_url - + record_path = await comp.convert_to_file_path() # 转成amr record_path_amr = f"data/temp/{uuid.uuid4()}.amr" pydub.AudioSegment.from_wav(record_path).export( diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index b59cc16b7..8b5890c28 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -33,7 +33,6 @@ class ProviderDify(Provider): if not self.api_key: raise Exception("Dify API Key 不能为空。") api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1") - self.api_client = DifyAPIClient(self.api_key, api_base) self.api_type = provider_config.get("dify_api_type", "") if not self.api_type: raise Exception("Dify API 类型不能为空。") @@ -55,6 +54,8 @@ class ProviderDify(Provider): self.conversation_ids = {} """记录当前 session id 的对话 ID""" + self.api_client = DifyAPIClient(self.api_key, api_base) + async def text_chat( self, prompt: str, @@ -70,26 +71,27 @@ class ProviderDify(Provider): files_payload = [] for image_url in image_urls: - if image_url.startswith("http"): - image_path = await download_image_by_url(image_url) - file_response = await self.api_client.file_upload( - image_path, user=session_id + image_path = ( + await download_image_by_url(image_url) + if image_url.startswith("http") + else image_url + ) + file_response = await self.api_client.file_upload( + image_path, user=session_id + ) + logger.debug(f"Dify 上传图片响应:{file_response}") + if "id" not in file_response: + logger.warning( + f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。" ) - if "id" not in file_response: - logger.warning( - f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。" - ) - continue - files_payload.append( - { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": file_response["id"], - } - ) - else: - # TODO: 处理更多情况 - logger.warning(f"未知的图片链接:{image_url},图片将忽略。") + continue + files_payload.append( + { + "type": "image", + "transfer_method": "local_file", + "upload_file_id": file_response["id"], + } + ) # 获得会话变量 payload_vars = self.variables.copy() diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 443a544bb..088c999f9 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -31,7 +31,6 @@ def validate_config( def validate(data: dict, metadata: dict = schema, path=""): for key, value in data.items(): - print(key, value) if key not in metadata: # 无 schema 的配置项,执行类型猜测 if isinstance(value, str):