perf: 支持图片输入

This commit is contained in:
Soulter
2025-01-08 19:56:03 +08:00
parent 75cc4cac5a
commit b1d1a13d5f
5 changed files with 144 additions and 25 deletions
+96 -16
View File
@@ -1,15 +1,18 @@
import os
import json
import shutil
import aiohttp
import uuid
import asyncio
import re
import astrbot.api.star as star
import aiodocker
from astrbot.api.event import AstrMessageEvent
from collections import defaultdict
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api import llm_tool, logger
from astrbot.api.event import filter
from astrbot.api.message_components import Image
from astrbot.api.provider import ProviderRequest
from astrbot.api.message_components import Image, File
PROMPT = """
## Task
@@ -18,19 +21,35 @@ You need to generate python codes to solve user's problem: {prompt}
{extra_input}
## Limit
1. Available libraries: standard libs, `Pillow`, `requests`, `numpy`, `matplotlib`, `scipy`, `scikit-learn`, `beautifulsoup4`, `pandas`, `opencv-python`
1. Available libraries:
- standard libs
- `Pillow`
- `requests`
- `numpy`
- `matplotlib`
- `scipy`
- `scikit-learn`
- `beautifulsoup4`
- `pandas`
- `opencv-python`
- `python-docx`
- `python-pptx`
- `pymupdf` (Do not use fpdf, reportlab, etc.)
- `mplfonts`
You can only use these libraries and the libraries that they depend on.
2. Do not generate malicious code.
3. Only output text, image. For Image, you need save it to `output` folder.
4. Use given `shared.api` package to output the result.
5. You must only output the code, do not output the result of the code and any other information.
6. The output language is same as user's input language.
7. Please first provide relevant knowledge about user's problem appropriately.
3. Use given `shared.api` package to output the result.
It has 3 functions: `send_text(text: str)`, `send_image(image_path: str)`, `send_file(file_path: str)`.
For Image and file, you must save it to `output` folder.
4. You must only output the code, do not output the result of the code and any other information.
5. The output language is same as user's input language.
6. Please first provide relevant knowledge about user's problem appropriately.
## Example
1. The user's problem is: `please solve the fabonacci sequence problem.`
1. User's problem: `please solve the fabonacci sequence problem.`
Output:
```python
from shared.api import send_text, send_image
from shared.api import send_text, send_image, send_file
def fabonacci(n):
if n <= 1:
@@ -43,10 +62,10 @@ send_text("The fabonacci sequence is a series of numbers in which each number is
send_text("Let's calculate the fabonacci sequence of 10: " + result) # send_text is a function to send pure text to user
```
2. The user's problem is: `please draw a sin(x) function.`
2. User's problem: `please draw a sin(x) function.`
Output:
```python
from shared.api import send_text, send_image
from shared.api import send_text, send_image, send_file
import numpy as np
import matplotlib.pyplot as plt
@@ -80,6 +99,9 @@ class Main(star.Star):
self.shared_path = os.path.join(self.curr_dir, "shared")
os.makedirs(self.workplace_path, exist_ok=True)
self.user_file_msg_buffer = defaultdict(list)
'''存放用户上传的文件'''
# 加载配置
if not os.path.exists(PATH):
self.config = DEFAULT_CONFIG
@@ -88,6 +110,23 @@ class Main(star.Star):
with open(PATH, "r") as f:
self.config = json.load(f)
async def file_upload(self, file_path: str):
'''
上传图像文件到 S3
'''
ext = os.path.splitext(file_path)[1]
S3_URL = "https://s3.neko.soulter.top/astrbot-s3"
with open(file_path, "rb") as f:
file = f.read()
s3_file_url = f"{S3_URL}/{uuid.uuid4().hex}{ext}"
async with aiohttp.ClientSession(headers = {"Accept": "application/json"}) as session:
async with session.put(s3_file_url, data=file) as resp:
if resp.status != 200:
raise Exception(f"Failed to upload image: {resp.status}")
return s3_file_url
async def is_docker_available(self) -> bool:
'''Check if docker is available'''
@@ -131,6 +170,21 @@ class Main(star.Star):
raise ValueError("The code is not in the code block.")
return match.group(1)
@filter.event_message_type(filter.EventMessageType.ALL)
async def on_message(self, event: AstrMessageEvent):
'''处理消息'''
for comp in event.message_obj.message:
if isinstance(comp, File):
self.user_file_msg_buffer[event.get_session_id()].append(comp.file)
logger.debug(f"User uploaded file: {comp.file}")
break # 一个消息中,文件只能有一个,这里直接 break 减少计算量。
@filter.on_llm_request()
async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest):
if event.get_session_id() in self.user_file_msg_buffer:
files = self.user_file_msg_buffer[event.get_session_id()]
request.prompt += f"\nUser provided files: {files}"
@filter.command_group("pi")
def pi(self):
@@ -166,7 +220,8 @@ class Main(star.Star):
@llm_tool("python_interpreter")
async def python_interpreter(self, event: AstrMessageEvent):
'''Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code. For example, user can use this tool to solve a math problem, edit Image, etc.
'''Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code.
For example, user can use this tool to solve math problems, edit image, docx, pptx, pdf, etc.
'''
if not await self.is_docker_available():
yield event.plain_result("Docker 在当前机器不可用,无法沙箱化执行代码。")
@@ -191,7 +246,23 @@ class Main(star.Star):
if image_path:
images.append(image_path)
idx += 1
# 文件
files = []
for file_path in self.user_file_msg_buffer[event.get_session_id()]:
# cp
file_name = os.path.basename(file_path)
shutil.copy(file_path, os.path.join(workplace_path, file_name))
files.append(file_name)
logger.debug(f"user query: {plain_text}, images: {images}, files: {files}")
# 整理额外输入
extra_inputs = ""
if images:
extra_inputs += f"User provided images: {images}\n"
if files:
extra_inputs += f"User provided files: {files}\n"
obs = ""
n = 5
@@ -201,7 +272,7 @@ class Main(star.Star):
PROMPT_ = PROMPT.format(
prompt=plain_text,
extra_input=f"User provided images: {images}",
extra_input=extra_inputs,
extra_prompt=obs,
)
provider = self.context.get_using_provider()
@@ -253,7 +324,7 @@ class Main(star.Star):
logger.debug(f"Container {container.id} logs: {logs}")
# 发送结果
pattern = r"\[ASTRBOT_(TEXT|IMAGE)_OUTPUT#\w+\]: (.*)"
pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)"
ok = False
traceback = ""
for idx, log in enumerate(logs):
@@ -266,6 +337,15 @@ class Main(star.Star):
image_path = os.path.join(workplace_path, match.group(2))
logger.debug(f"Sending image: {image_path}")
yield event.image_result(image_path)
elif match.group(1) == "FILE":
file_path = os.path.join(workplace_path, match.group(2))
logger.debug(f"Sending file: {file_path}")
file_s3_url = await self.file_upload(file_path)
logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
file_name = os.path.basename(file_path)
chain = [File(name=file_name, file=file_s3_url)]
yield event.set_result(MessageEventResult(chain=chain))
elif "Traceback (most recent call last)" in log \
or "[Error]: " in log:
traceback = "\n".join(logs[idx:])
+6 -1
View File
@@ -10,4 +10,9 @@ def send_text(text: str):
def send_image(image_path: str):
if not os.path.exists(image_path):
raise Exception(f"Image file not found: {image_path}")
print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}")
print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}")
def send_file(file_path: str):
if not os.path.exists(file_path):
raise Exception(f"File not found: {file_path}")
print(f"[ASTRBOT_FILE_OUTPUT#{_get_magic_code()}]: {file_path}")