perf: 支持图片输入

This commit is contained in:
Soulter
2025-01-08 19:56:03 +08:00
parent 75cc4cac5a
commit b1d1a13d5f
5 changed files with 144 additions and 25 deletions
+14 -1
View File
@@ -54,6 +54,7 @@ class ComponentType(Enum):
CardImage = "CardImage"
TTS = "TTS"
Unknown = "Unknown"
File = "File"
class BaseMessageComponent(BaseModel):
@@ -415,6 +416,17 @@ class Unknown(BaseMessageComponent):
def toString(self):
return ""
class File(BaseMessageComponent):
'''
目前此消息段只适配了 Napcat。
'''
type: ComponentType = "File"
name: T.Optional[str] = "" # 名字
file: T.Optional[str] = "" # url(本地路径)
def __init__(self, name: str, file: str):
super().__init__(name=name, file=file)
ComponentTypes = {
"plain": Plain,
@@ -441,5 +453,6 @@ ComponentTypes = {
"json": Json,
"cardimage": CardImage,
"tts": TTS,
"unknown": Unknown
"unknown": Unknown,
'file': File,
}
@@ -1,3 +1,4 @@
import os
import time
import asyncio
import logging
@@ -5,12 +6,13 @@ from typing import Awaitable, Any
from aiocqhttp import CQHttp, Event
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from .aiocqhttp_message_event import *
from astrbot.api.message_components import *
from .aiocqhttp_message_event import * # noqa: F403
from astrbot.api.message_components import * # noqa: F403
from astrbot.api import logger
from .aiocqhttp_message_event import AiocqhttpMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from aiocqhttp.exceptions import ActionFailed
@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
class AiocqhttpAdapter(Platform):
@@ -42,7 +44,7 @@ class AiocqhttpAdapter(Platform):
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
await super().send_by_session(session, message_chain)
def convert_message(self, event: Event) -> AstrBotMessage:
async def convert_message(self, event: Event) -> AstrBotMessage:
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.tag = "aiocqhttp"
@@ -78,7 +80,25 @@ class AiocqhttpAdapter(Platform):
a = None
if t == 'text':
message_str += m['data']['text'].strip()
a = ComponentTypes[t](**m['data'])
elif t == 'file':
try:
# Napcat, LLBot
ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id'])
if not ret.get('file', None):
raise ValueError(f"无法解析文件响应: {ret}")
if not os.path.exists(ret['file']):
raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。")
m['data'] = {
"file": ret['file'],
"name": ret['file_name']
}
except ActionFailed as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
except BaseException as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
a = ComponentTypes[t](**m['data']) # noqa: F405
abm.message.append(a)
abm.timestamp = int(time.time())
abm.message_str = message_str
@@ -91,13 +111,13 @@ class AiocqhttpAdapter(Platform):
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
@self.bot.on_message('group')
async def group(event: Event):
abm = self.convert_message(event)
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_message('private')
async def private(event: Event):
abm = self.convert_message(event)
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@@ -2,6 +2,7 @@ import sys
import time
import uuid
import asyncio
import os
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
@@ -62,7 +63,7 @@ class VChatPlatformAdapter(Platform):
self.start_time = int(time.time())
return self._run()
async def _run(self):
await self.client.init()
await self.client.auto_login(hot_reload=True, enable_cmd_qr=True)
+96 -16
View File
@@ -1,15 +1,18 @@
import os
import json
import shutil
import aiohttp
import uuid
import asyncio
import re
import astrbot.api.star as star
import aiodocker
from astrbot.api.event import AstrMessageEvent
from collections import defaultdict
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api import llm_tool, logger
from astrbot.api.event import filter
from astrbot.api.message_components import Image
from astrbot.api.provider import ProviderRequest
from astrbot.api.message_components import Image, File
PROMPT = """
## Task
@@ -18,19 +21,35 @@ You need to generate python codes to solve user's problem: {prompt}
{extra_input}
## Limit
1. Available libraries: standard libs, `Pillow`, `requests`, `numpy`, `matplotlib`, `scipy`, `scikit-learn`, `beautifulsoup4`, `pandas`, `opencv-python`
1. Available libraries:
- standard libs
- `Pillow`
- `requests`
- `numpy`
- `matplotlib`
- `scipy`
- `scikit-learn`
- `beautifulsoup4`
- `pandas`
- `opencv-python`
- `python-docx`
- `python-pptx`
- `pymupdf` (Do not use fpdf, reportlab, etc.)
- `mplfonts`
You can only use these libraries and the libraries that they depend on.
2. Do not generate malicious code.
3. Only output text, image. For Image, you need save it to `output` folder.
4. Use given `shared.api` package to output the result.
5. You must only output the code, do not output the result of the code and any other information.
6. The output language is same as user's input language.
7. Please first provide relevant knowledge about user's problem appropriately.
3. Use given `shared.api` package to output the result.
It has 3 functions: `send_text(text: str)`, `send_image(image_path: str)`, `send_file(file_path: str)`.
For Image and file, you must save it to `output` folder.
4. You must only output the code, do not output the result of the code and any other information.
5. The output language is same as user's input language.
6. Please first provide relevant knowledge about user's problem appropriately.
## Example
1. The user's problem is: `please solve the fabonacci sequence problem.`
1. User's problem: `please solve the fabonacci sequence problem.`
Output:
```python
from shared.api import send_text, send_image
from shared.api import send_text, send_image, send_file
def fabonacci(n):
if n <= 1:
@@ -43,10 +62,10 @@ send_text("The fabonacci sequence is a series of numbers in which each number is
send_text("Let's calculate the fabonacci sequence of 10: " + result) # send_text is a function to send pure text to user
```
2. The user's problem is: `please draw a sin(x) function.`
2. User's problem: `please draw a sin(x) function.`
Output:
```python
from shared.api import send_text, send_image
from shared.api import send_text, send_image, send_file
import numpy as np
import matplotlib.pyplot as plt
@@ -80,6 +99,9 @@ class Main(star.Star):
self.shared_path = os.path.join(self.curr_dir, "shared")
os.makedirs(self.workplace_path, exist_ok=True)
self.user_file_msg_buffer = defaultdict(list)
'''存放用户上传的文件'''
# 加载配置
if not os.path.exists(PATH):
self.config = DEFAULT_CONFIG
@@ -88,6 +110,23 @@ class Main(star.Star):
with open(PATH, "r") as f:
self.config = json.load(f)
async def file_upload(self, file_path: str):
'''
上传图像文件到 S3
'''
ext = os.path.splitext(file_path)[1]
S3_URL = "https://s3.neko.soulter.top/astrbot-s3"
with open(file_path, "rb") as f:
file = f.read()
s3_file_url = f"{S3_URL}/{uuid.uuid4().hex}{ext}"
async with aiohttp.ClientSession(headers = {"Accept": "application/json"}) as session:
async with session.put(s3_file_url, data=file) as resp:
if resp.status != 200:
raise Exception(f"Failed to upload image: {resp.status}")
return s3_file_url
async def is_docker_available(self) -> bool:
'''Check if docker is available'''
@@ -131,6 +170,21 @@ class Main(star.Star):
raise ValueError("The code is not in the code block.")
return match.group(1)
@filter.event_message_type(filter.EventMessageType.ALL)
async def on_message(self, event: AstrMessageEvent):
'''处理消息'''
for comp in event.message_obj.message:
if isinstance(comp, File):
self.user_file_msg_buffer[event.get_session_id()].append(comp.file)
logger.debug(f"User uploaded file: {comp.file}")
break # 一个消息中,文件只能有一个,这里直接 break 减少计算量。
@filter.on_llm_request()
async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest):
if event.get_session_id() in self.user_file_msg_buffer:
files = self.user_file_msg_buffer[event.get_session_id()]
request.prompt += f"\nUser provided files: {files}"
@filter.command_group("pi")
def pi(self):
@@ -166,7 +220,8 @@ class Main(star.Star):
@llm_tool("python_interpreter")
async def python_interpreter(self, event: AstrMessageEvent):
'''Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code. For example, user can use this tool to solve a math problem, edit Image, etc.
'''Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code.
For example, user can use this tool to solve math problems, edit image, docx, pptx, pdf, etc.
'''
if not await self.is_docker_available():
yield event.plain_result("Docker 在当前机器不可用,无法沙箱化执行代码。")
@@ -191,7 +246,23 @@ class Main(star.Star):
if image_path:
images.append(image_path)
idx += 1
# 文件
files = []
for file_path in self.user_file_msg_buffer[event.get_session_id()]:
# cp
file_name = os.path.basename(file_path)
shutil.copy(file_path, os.path.join(workplace_path, file_name))
files.append(file_name)
logger.debug(f"user query: {plain_text}, images: {images}, files: {files}")
# 整理额外输入
extra_inputs = ""
if images:
extra_inputs += f"User provided images: {images}\n"
if files:
extra_inputs += f"User provided files: {files}\n"
obs = ""
n = 5
@@ -201,7 +272,7 @@ class Main(star.Star):
PROMPT_ = PROMPT.format(
prompt=plain_text,
extra_input=f"User provided images: {images}",
extra_input=extra_inputs,
extra_prompt=obs,
)
provider = self.context.get_using_provider()
@@ -253,7 +324,7 @@ class Main(star.Star):
logger.debug(f"Container {container.id} logs: {logs}")
# 发送结果
pattern = r"\[ASTRBOT_(TEXT|IMAGE)_OUTPUT#\w+\]: (.*)"
pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)"
ok = False
traceback = ""
for idx, log in enumerate(logs):
@@ -266,6 +337,15 @@ class Main(star.Star):
image_path = os.path.join(workplace_path, match.group(2))
logger.debug(f"Sending image: {image_path}")
yield event.image_result(image_path)
elif match.group(1) == "FILE":
file_path = os.path.join(workplace_path, match.group(2))
logger.debug(f"Sending file: {file_path}")
file_s3_url = await self.file_upload(file_path)
logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
file_name = os.path.basename(file_path)
chain = [File(name=file_name, file=file_s3_url)]
yield event.set_result(MessageEventResult(chain=chain))
elif "Traceback (most recent call last)" in log \
or "[Error]: " in log:
traceback = "\n".join(logs[idx:])
+6 -1
View File
@@ -10,4 +10,9 @@ def send_text(text: str):
def send_image(image_path: str):
if not os.path.exists(image_path):
raise Exception(f"Image file not found: {image_path}")
print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}")
print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}")
def send_file(file_path: str):
if not os.path.exists(file_path):
raise Exception(f"File not found: {file_path}")
print(f"[ASTRBOT_FILE_OUTPUT#{_get_magic_code()}]: {file_path}")