diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 81ac5c36e..f9d42bc31 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -51,7 +51,7 @@ class LLMRequestSubStage(Stage): session_provider_context = provider.session_memory.get(event.session_id) req.contexts = session_provider_context if session_provider_context else [] - if not req.prompt: + if not req.prompt and not req.image_urls: return # 执行请求 LLM 前事件。 diff --git a/astrbot/core/platform/sources/gewechat/client.py b/astrbot/core/platform/sources/gewechat/client.py index 886c053f1..cc9eab512 100644 --- a/astrbot/core/platform/sources/gewechat/client.py +++ b/astrbot/core/platform/sources/gewechat/client.py @@ -2,10 +2,14 @@ import threading import asyncio import aiohttp import quart +import base64 from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType -from astrbot.api.message_components import Plain, Image, At +from astrbot.api.message_components import Plain, Image, At, Record from astrbot.api import logger, sp +from .downloader import GeweDownloader +from astrbot.core.utils.io import download_image_by_url + class SimpleGewechatClient(): '''针对 Gewechat 的简单实现。 @@ -17,9 +21,15 @@ class SimpleGewechatClient(): self.base_url = base_url if self.base_url.endswith('/'): self.base_url = self.base_url[:-1] + + self.download_base_url = self.base_url.split(':')[:-1] # 去掉端口 + self.download_base_url = ':'.join(self.download_base_url) + ":2532/download/" self.base_url += "/v2/api" + logger.info(f"Gewechat API: {self.base_url}") + logger.info(f"Gewechat 下载 API: {self.download_base_url}") + if isinstance(port, str): port = int(port) @@ -27,15 +37,19 @@ class SimpleGewechatClient(): self.headers = {} self.nickname = nickname self.appid = sp.get(f"gewechat-appid-{nickname}", "") - self.callback_url = None self.server = quart.Quart(__name__) - self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST']) + self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST']) + self.server.add_url_rule('/astrbot-gewechat/file/', view_func=self.handle_file, methods=['GET']) self.host = host self.port = port + self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback" + self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file" self.event_queue = event_queue + + self.multimedia_downloader = None async def get_token_id(self): async with aiohttp.ClientSession() as session: @@ -54,55 +68,84 @@ class SimpleGewechatClient(): return abm = AstrBotMessage() d = data['Data'] - msg_type = d['MsgType'] + + from_user_name = d['FromUserName']['string'] # 消息来源 + d['to_wxid'] = from_user_name # 用于发信息 - match msg_type: + abm.message_id = str(d.get('MsgId')) + abm.session_id = from_user_name + abm.self_id = data['Wxid'] # 机器人的 wxid + + user_id = "" # 发送人 wxid + content = d['Content']['string'] # 消息内容 + + at_me = False + if "@chatroom" in from_user_name: + abm.type = MessageType.GROUP_MESSAGE + _t = content.split(':\n') + user_id = _t[0] + content = _t[1] + if '\u2005' in content: + # at + content = content.split('\u2005')[1] + abm.group_id = from_user_name + # at + msg_source = d['MsgSource'] + if f'' in msg_source \ + or f'' in msg_source: + at_me = True + else: + abm.type = MessageType.FRIEND_MESSAGE + user_id = from_user_name + + abm.message = [] + if at_me: + abm.message.insert(0, At(qq=abm.self_id)) + + user_real_name = d['PushContent'].split(' : ')[0] \ + .replace('在群聊中@了你', '') \ + .replace('在群聊中发了一段语音', '') # 真实昵称 + abm.sender = MessageMember(user_id, user_real_name) + abm.raw_message = d + abm.message_str = "" + # 不同消息类型 + match d['MsgType']: case 1: - from_user_name = d['FromUserName']['string'] # 消息来源 - d['to_wxid'] = from_user_name # 用于发信息 - - user_id = "" # 发送人 wxid - content = d['Content']['string'] # 消息内容 - user_real_name = d['PushContent'].split(' : ')[0] # 真实昵称 - user_real_name = user_real_name.replace('在群聊中@了你', '') # trick - abm.self_id = data['Wxid'] # 机器人的 wxid - at_me = False - if "@chatroom" in from_user_name: - abm.type = MessageType.GROUP_MESSAGE - _t = content.split(':\n') - user_id = _t[0] - content = _t[1] - if '\u2005' in content: - # at - content = content.split('\u2005')[1] - - abm.group_id = from_user_name - - # at - msg_source = d['MsgSource'] - if f'' in msg_source \ - or f'' in msg_source: - at_me = True - - else: - abm.type = MessageType.FRIEND_MESSAGE - user_id = from_user_name - abm.session_id = from_user_name - abm.sender = MessageMember(user_id, user_real_name) - abm.message = [Plain(content)] - - if at_me: - abm.message.insert(0, At(qq=abm.self_id)) - - abm.message_id = str(d['MsgId']) - abm.raw_message = d + # 文本消息 + abm.message.append(Plain(content)) abm.message_str = content + case 3: + # 图片消息 + file_url = await self.multimedia_downloader.download_image( + self.appid, + content + ) + logger.debug(f"下载图片: {file_url}") + file_path = await download_image_by_url(file_url) + abm.message.append(Image(file=file_path, url=file_path)) + + case 34: + # 语音消息 + # data = await self.multimedia_downloader.download_voice( + # self.appid, + # content, + # abm.message_id + # ) + # print(data) + if 'ImgBuf' in d and 'buffer' in d['ImgBuf']: + voice_data = base64.b64decode(d['ImgBuf']['buffer']) + file_path = f"data/temp/gewe_voice_{abm.message_id}.silk" + with open(file_path, "wb") as f: + f.write(voice_data) + abm.message.append(Record(file=file_path, url=file_path)) - logger.info(f"abm: {abm}") - return abm case _: - logger.error(f"未实现的消息类型: {msg_type}") - + logger.error(f"未实现的消息类型: {d['MsgType']}") + return + + logger.info(f"abm: {abm}") + return abm + async def callback(self): data = await quart.request.json logger.debug(f"收到 gewechat 回调: {data}") @@ -118,25 +161,28 @@ class SimpleGewechatClient(): await coro(abm) return quart.jsonify({"r": "AstrBot ACK"}) + + async def handle_file(self, file_id): + file_path = f"data/temp/{file_id}.jpg" + return await quart.send_file(file_path) async def _set_callback_url(self): logger.info("设置回调,请等待...") await asyncio.sleep(3) - callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback" async with aiohttp.ClientSession() as session: async with session.post( f"{self.base_url}/tools/setCallback", headers=self.headers, json={ "token": self.token, - "callbackUrl": callback_url + "callbackUrl": self.callback_url } ) as resp: json_blob = await resp.json() logger.info(f"设置回调结果: {json_blob}") if json_blob['ret'] != 200: raise Exception(f"设置回调失败: {json_blob}") - logger.info(f"将在 {callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。") + logger.info(f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。") async def start_polling(self): @@ -186,6 +232,8 @@ class SimpleGewechatClient(): async def login(self): if self.token is None: await self.get_token_id() + + self.multimedia_downloader = GeweDownloader(self.base_url, self.download_base_url, self.token) if self.appid: online = await self.check_online(self.appid) @@ -263,4 +311,20 @@ class SimpleGewechatClient(): json=payload ) as resp: json_blob = await resp.json() - logger.info(f"发送消息结果: {json_blob}") \ No newline at end of file + logger.debug(f"发送消息结果: {json_blob}") + + async def post_image(self, to_wxid, image_url: str): + payload = { + "appId": self.appid, + "toWxid": to_wxid, + "imgUrl": image_url, + } + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/message/postImage", + headers=self.headers, + json=payload + ) as resp: + json_blob = await resp.json() + logger.debug(f"发送图片结果: {json_blob}") \ No newline at end of file diff --git a/astrbot/core/platform/sources/gewechat/downloader.py b/astrbot/core/platform/sources/gewechat/downloader.py new file mode 100644 index 000000000..d0f2efce7 --- /dev/null +++ b/astrbot/core/platform/sources/gewechat/downloader.py @@ -0,0 +1,51 @@ +from astrbot import logger +import aiohttp +import json + +class GeweDownloader(): + def __init__(self, base_url: str, download_base_url: str, token: str): + self.base_url = base_url + self.download_base_url = download_base_url + self.headers = { + "Content-Type": "application/json", + "X-GEWE-TOKEN": token + } + + async def _post_json(self, baseurl: str, route: str, payload: dict): + async with aiohttp.ClientSession() as session: + async with session.post( + f"{baseurl}{route}", + headers=self.headers, + json=payload + ) as resp: + return await resp.read() + + async def download_voice(self, appid: str, xml: str, msg_id: str): + payload = { + "appId": appid, + "xml": xml, + "msgId": msg_id + } + return await self._post_json(self.base_url, "/message/downloadVoice", payload) + + async def download_image(self, appid: str, xml: str) -> str: + '''返回一个可下载的 URL''' + choices = [2, 3] # 2:常规图片 3:缩略图 + + for choice in choices: + try: + payload = { + "appId": appid, + "xml": xml, + "type": choice + } + data = await self._post_json(self.base_url, "/message/downloadImage", payload) + json_blob = json.loads(data) + if 'fileUrl' in json_blob['data']: + return self.download_base_url + json_blob['data']['fileUrl'] + + except BaseException as e: + logger.error(f"gewe download image: {e}") + continue + + raise Exception("无法下载图片") \ No newline at end of file diff --git a/astrbot/core/platform/sources/gewechat/gewechat_event.py b/astrbot/core/platform/sources/gewechat/gewechat_event.py index 3c81f65e6..2de45a218 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_event.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_event.py @@ -1,6 +1,7 @@ import random import asyncio -from astrbot.core.utils.io import download_image_by_url +import os +from astrbot.core.utils.io import save_temp_img, download_image_by_url from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import AstrBotMessage, PlatformMetadata @@ -34,5 +35,22 @@ class GewechatPlatformEvent(AstrMessageEvent): for comp in message.chain: if isinstance(comp, Plain): await self.client.post_text(to_wxid, comp.text) + elif isinstance(comp, Image): + img_url = comp.file + img_path = "" + if img_url.startswith("file:///"): + with open(comp.file[8:], "rb") as f: + img_path = save_temp_img(f.read()) + elif comp.file and comp.file.startswith("http"): + img_path = await download_image_by_url(comp.file) + + if not img_path: + logger.error("无法获取到图片路径。") + return + + file_id = os.path.basename(img_path).split(".")[0] + img_url = f"{self.client.file_server_url}/{file_id}" + logger.debug(f"gewe callback img url: {img_url}") + await self.client.post_image(to_wxid, img_url) await super().send(message) \ No newline at end of file diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index 2b9dc7fa9..b603dec7c 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -6,6 +6,8 @@ import time import aiohttp import base64 import zipfile +import uuid +from typing import Union from PIL import Image @@ -41,21 +43,21 @@ def port_checker(port: int, host: str = "localhost"): return False -def save_temp_img(img: Image) -> str: +def save_temp_img(img: Union[Image.Image, str]) -> str: os.makedirs("data/temp", exist_ok=True) - # 获得文件创建时间,清除超过1小时的 + # 获得文件创建时间,清除超过 12 小时的 try: for f in os.listdir("data/temp"): path = os.path.join("data/temp", f) if os.path.isfile(path): ctime = os.path.getctime(path) - if time.time() - ctime > 3600: + if time.time() - ctime > 3600*12: os.remove(path) except Exception as e: print(f"清除临时文件失败: {e}") # 获得时间戳 - timestamp = int(time.time()) + timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}" p = f"data/temp/{timestamp}.jpg" if isinstance(img, Image.Image):