diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 717a0f81b..50e81c86d 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -54,6 +54,7 @@ class ComponentType(Enum): CardImage = "CardImage" TTS = "TTS" Unknown = "Unknown" + File = "File" class BaseMessageComponent(BaseModel): @@ -415,6 +416,17 @@ class Unknown(BaseMessageComponent): def toString(self): return "" +class File(BaseMessageComponent): + ''' + 目前此消息段只适配了 Napcat。 + ''' + type: ComponentType = "File" + name: T.Optional[str] = "" # 名字 + file: T.Optional[str] = "" # url(本地路径) + + def __init__(self, name: str, file: str): + super().__init__(name=name, file=file) + ComponentTypes = { "plain": Plain, @@ -441,5 +453,6 @@ ComponentTypes = { "json": Json, "cardimage": CardImage, "tts": TTS, - "unknown": Unknown + "unknown": Unknown, + 'file': File, } diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index fa7fca81d..3d1439a98 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -1,3 +1,4 @@ +import os import time import asyncio import logging @@ -5,12 +6,13 @@ from typing import Awaitable, Any from aiocqhttp import CQHttp, Event from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata from astrbot.api.event import MessageChain -from .aiocqhttp_message_event import * -from astrbot.api.message_components import * +from .aiocqhttp_message_event import * # noqa: F403 +from astrbot.api.message_components import * # noqa: F403 from astrbot.api import logger from .aiocqhttp_message_event import AiocqhttpMessageEvent from astrbot.core.platform.astr_message_event import MessageSesion from ...register import register_platform_adapter +from aiocqhttp.exceptions import ActionFailed @register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。") class AiocqhttpAdapter(Platform): @@ -42,7 +44,7 @@ class AiocqhttpAdapter(Platform): await self.bot.send_private_msg(user_id=session.session_id, message=ret) await super().send_by_session(session, message_chain) - def convert_message(self, event: Event) -> AstrBotMessage: + async def convert_message(self, event: Event) -> AstrBotMessage: abm = AstrBotMessage() abm.self_id = str(event.self_id) abm.tag = "aiocqhttp" @@ -78,7 +80,25 @@ class AiocqhttpAdapter(Platform): a = None if t == 'text': message_str += m['data']['text'].strip() - a = ComponentTypes[t](**m['data']) + elif t == 'file': + try: + # Napcat, LLBot + ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id']) + if not ret.get('file', None): + raise ValueError(f"无法解析文件响应: {ret}") + if not os.path.exists(ret['file']): + raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。") + + m['data'] = { + "file": ret['file'], + "name": ret['file_name'] + } + except ActionFailed as e: + logger.error(f"获取文件失败: {e},此消息段将被忽略。") + except BaseException as e: + logger.error(f"获取文件失败: {e},此消息段将被忽略。") + + a = ComponentTypes[t](**m['data']) # noqa: F405 abm.message.append(a) abm.timestamp = int(time.time()) abm.message_str = message_str @@ -91,13 +111,13 @@ class AiocqhttpAdapter(Platform): self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180) @self.bot.on_message('group') async def group(event: Event): - abm = self.convert_message(event) + abm = await self.convert_message(event) if abm: await self.handle_msg(abm) @self.bot.on_message('private') async def private(event: Event): - abm = self.convert_message(event) + abm = await self.convert_message(event) if abm: await self.handle_msg(abm) diff --git a/astrbot/core/platform/sources/vchat/vchat_platform_adapter.py b/astrbot/core/platform/sources/vchat/vchat_platform_adapter.py index bfc03c013..f224ecf8f 100644 --- a/astrbot/core/platform/sources/vchat/vchat_platform_adapter.py +++ b/astrbot/core/platform/sources/vchat/vchat_platform_adapter.py @@ -2,6 +2,7 @@ import sys import time import uuid import asyncio +import os from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata from astrbot.api.event import MessageChain @@ -62,7 +63,7 @@ class VChatPlatformAdapter(Platform): self.start_time = int(time.time()) return self._run() - + async def _run(self): await self.client.init() await self.client.auto_login(hot_reload=True, enable_cmd_qr=True) diff --git a/packages/python_interpreter/main.py b/packages/python_interpreter/main.py index d88c9676a..9768398c7 100644 --- a/packages/python_interpreter/main.py +++ b/packages/python_interpreter/main.py @@ -1,15 +1,18 @@ import os import json +import shutil import aiohttp import uuid import asyncio import re import astrbot.api.star as star import aiodocker -from astrbot.api.event import AstrMessageEvent +from collections import defaultdict +from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.api import llm_tool, logger from astrbot.api.event import filter -from astrbot.api.message_components import Image +from astrbot.api.provider import ProviderRequest +from astrbot.api.message_components import Image, File PROMPT = """ ## Task @@ -18,19 +21,35 @@ You need to generate python codes to solve user's problem: {prompt} {extra_input} ## Limit -1. Available libraries: standard libs, `Pillow`, `requests`, `numpy`, `matplotlib`, `scipy`, `scikit-learn`, `beautifulsoup4`, `pandas`, `opencv-python` +1. Available libraries: + - standard libs + - `Pillow` + - `requests` + - `numpy` + - `matplotlib` + - `scipy` + - `scikit-learn` + - `beautifulsoup4` + - `pandas` + - `opencv-python` + - `python-docx` + - `python-pptx` + - `pymupdf` (Do not use fpdf, reportlab, etc.) + - `mplfonts` + You can only use these libraries and the libraries that they depend on. 2. Do not generate malicious code. -3. Only output text, image. For Image, you need save it to `output` folder. -4. Use given `shared.api` package to output the result. -5. You must only output the code, do not output the result of the code and any other information. -6. The output language is same as user's input language. -7. Please first provide relevant knowledge about user's problem appropriately. +3. Use given `shared.api` package to output the result. + It has 3 functions: `send_text(text: str)`, `send_image(image_path: str)`, `send_file(file_path: str)`. + For Image and file, you must save it to `output` folder. +4. You must only output the code, do not output the result of the code and any other information. +5. The output language is same as user's input language. +6. Please first provide relevant knowledge about user's problem appropriately. ## Example -1. The user's problem is: `please solve the fabonacci sequence problem.` +1. User's problem: `please solve the fabonacci sequence problem.` Output: ```python -from shared.api import send_text, send_image +from shared.api import send_text, send_image, send_file def fabonacci(n): if n <= 1: @@ -43,10 +62,10 @@ send_text("The fabonacci sequence is a series of numbers in which each number is send_text("Let's calculate the fabonacci sequence of 10: " + result) # send_text is a function to send pure text to user ``` -2. The user's problem is: `please draw a sin(x) function.` +2. User's problem: `please draw a sin(x) function.` Output: ```python -from shared.api import send_text, send_image +from shared.api import send_text, send_image, send_file import numpy as np import matplotlib.pyplot as plt @@ -80,6 +99,9 @@ class Main(star.Star): self.shared_path = os.path.join(self.curr_dir, "shared") os.makedirs(self.workplace_path, exist_ok=True) + self.user_file_msg_buffer = defaultdict(list) + '''存放用户上传的文件''' + # 加载配置 if not os.path.exists(PATH): self.config = DEFAULT_CONFIG @@ -88,6 +110,23 @@ class Main(star.Star): with open(PATH, "r") as f: self.config = json.load(f) + async def file_upload(self, file_path: str): + ''' + 上传图像文件到 S3 + ''' + ext = os.path.splitext(file_path)[1] + S3_URL = "https://s3.neko.soulter.top/astrbot-s3" + with open(file_path, "rb") as f: + file = f.read() + + s3_file_url = f"{S3_URL}/{uuid.uuid4().hex}{ext}" + + async with aiohttp.ClientSession(headers = {"Accept": "application/json"}) as session: + async with session.put(s3_file_url, data=file) as resp: + if resp.status != 200: + raise Exception(f"Failed to upload image: {resp.status}") + return s3_file_url + async def is_docker_available(self) -> bool: '''Check if docker is available''' @@ -131,6 +170,21 @@ class Main(star.Star): raise ValueError("The code is not in the code block.") return match.group(1) + @filter.event_message_type(filter.EventMessageType.ALL) + async def on_message(self, event: AstrMessageEvent): + '''处理消息''' + for comp in event.message_obj.message: + if isinstance(comp, File): + self.user_file_msg_buffer[event.get_session_id()].append(comp.file) + logger.debug(f"User uploaded file: {comp.file}") + break # 一个消息中,文件只能有一个,这里直接 break 减少计算量。 + + @filter.on_llm_request() + async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest): + if event.get_session_id() in self.user_file_msg_buffer: + files = self.user_file_msg_buffer[event.get_session_id()] + request.prompt += f"\nUser provided files: {files}" + @filter.command_group("pi") def pi(self): @@ -166,7 +220,8 @@ class Main(star.Star): @llm_tool("python_interpreter") async def python_interpreter(self, event: AstrMessageEvent): - '''Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code. For example, user can use this tool to solve a math problem, edit Image, etc. + '''Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code. + For example, user can use this tool to solve math problems, edit image, docx, pptx, pdf, etc. ''' if not await self.is_docker_available(): yield event.plain_result("Docker 在当前机器不可用,无法沙箱化执行代码。") @@ -191,7 +246,23 @@ class Main(star.Star): if image_path: images.append(image_path) idx += 1 - + # 文件 + files = [] + for file_path in self.user_file_msg_buffer[event.get_session_id()]: + # cp + file_name = os.path.basename(file_path) + shutil.copy(file_path, os.path.join(workplace_path, file_name)) + files.append(file_name) + + logger.debug(f"user query: {plain_text}, images: {images}, files: {files}") + + # 整理额外输入 + extra_inputs = "" + if images: + extra_inputs += f"User provided images: {images}\n" + if files: + extra_inputs += f"User provided files: {files}\n" + obs = "" n = 5 @@ -201,7 +272,7 @@ class Main(star.Star): PROMPT_ = PROMPT.format( prompt=plain_text, - extra_input=f"User provided images: {images}", + extra_input=extra_inputs, extra_prompt=obs, ) provider = self.context.get_using_provider() @@ -253,7 +324,7 @@ class Main(star.Star): logger.debug(f"Container {container.id} logs: {logs}") # 发送结果 - pattern = r"\[ASTRBOT_(TEXT|IMAGE)_OUTPUT#\w+\]: (.*)" + pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)" ok = False traceback = "" for idx, log in enumerate(logs): @@ -266,6 +337,15 @@ class Main(star.Star): image_path = os.path.join(workplace_path, match.group(2)) logger.debug(f"Sending image: {image_path}") yield event.image_result(image_path) + elif match.group(1) == "FILE": + file_path = os.path.join(workplace_path, match.group(2)) + logger.debug(f"Sending file: {file_path}") + file_s3_url = await self.file_upload(file_path) + logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}") + file_name = os.path.basename(file_path) + chain = [File(name=file_name, file=file_s3_url)] + yield event.set_result(MessageEventResult(chain=chain)) + elif "Traceback (most recent call last)" in log \ or "[Error]: " in log: traceback = "\n".join(logs[idx:]) diff --git a/packages/python_interpreter/shared/api.py b/packages/python_interpreter/shared/api.py index 320936e01..9fe27ce67 100644 --- a/packages/python_interpreter/shared/api.py +++ b/packages/python_interpreter/shared/api.py @@ -10,4 +10,9 @@ def send_text(text: str): def send_image(image_path: str): if not os.path.exists(image_path): raise Exception(f"Image file not found: {image_path}") - print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}") \ No newline at end of file + print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}") + +def send_file(file_path: str): + if not os.path.exists(file_path): + raise Exception(f"File not found: {file_path}") + print(f"[ASTRBOT_FILE_OUTPUT#{_get_magic_code()}]: {file_path}") \ No newline at end of file