perf: 支持图片输入
This commit is contained in:
@@ -54,6 +54,7 @@ class ComponentType(Enum):
|
||||
CardImage = "CardImage"
|
||||
TTS = "TTS"
|
||||
Unknown = "Unknown"
|
||||
File = "File"
|
||||
|
||||
|
||||
class BaseMessageComponent(BaseModel):
|
||||
@@ -415,6 +416,17 @@ class Unknown(BaseMessageComponent):
|
||||
def toString(self):
|
||||
return ""
|
||||
|
||||
class File(BaseMessageComponent):
|
||||
'''
|
||||
目前此消息段只适配了 Napcat。
|
||||
'''
|
||||
type: ComponentType = "File"
|
||||
name: T.Optional[str] = "" # 名字
|
||||
file: T.Optional[str] = "" # url(本地路径)
|
||||
|
||||
def __init__(self, name: str, file: str):
|
||||
super().__init__(name=name, file=file)
|
||||
|
||||
|
||||
ComponentTypes = {
|
||||
"plain": Plain,
|
||||
@@ -441,5 +453,6 @@ ComponentTypes = {
|
||||
"json": Json,
|
||||
"cardimage": CardImage,
|
||||
"tts": TTS,
|
||||
"unknown": Unknown
|
||||
"unknown": Unknown,
|
||||
'file': File,
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
@@ -5,12 +6,13 @@ from typing import Awaitable, Any
|
||||
from aiocqhttp import CQHttp, Event
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from .aiocqhttp_message_event import *
|
||||
from astrbot.api.message_components import *
|
||||
from .aiocqhttp_message_event import * # noqa: F403
|
||||
from astrbot.api.message_components import * # noqa: F403
|
||||
from astrbot.api import logger
|
||||
from .aiocqhttp_message_event import AiocqhttpMessageEvent
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from aiocqhttp.exceptions import ActionFailed
|
||||
|
||||
@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
|
||||
class AiocqhttpAdapter(Platform):
|
||||
@@ -42,7 +44,7 @@ class AiocqhttpAdapter(Platform):
|
||||
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
def convert_message(self, event: Event) -> AstrBotMessage:
|
||||
async def convert_message(self, event: Event) -> AstrBotMessage:
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.tag = "aiocqhttp"
|
||||
@@ -78,7 +80,25 @@ class AiocqhttpAdapter(Platform):
|
||||
a = None
|
||||
if t == 'text':
|
||||
message_str += m['data']['text'].strip()
|
||||
a = ComponentTypes[t](**m['data'])
|
||||
elif t == 'file':
|
||||
try:
|
||||
# Napcat, LLBot
|
||||
ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id'])
|
||||
if not ret.get('file', None):
|
||||
raise ValueError(f"无法解析文件响应: {ret}")
|
||||
if not os.path.exists(ret['file']):
|
||||
raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。")
|
||||
|
||||
m['data'] = {
|
||||
"file": ret['file'],
|
||||
"name": ret['file_name']
|
||||
}
|
||||
except ActionFailed as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
except BaseException as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
|
||||
a = ComponentTypes[t](**m['data']) # noqa: F405
|
||||
abm.message.append(a)
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
@@ -91,13 +111,13 @@ class AiocqhttpAdapter(Platform):
|
||||
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
|
||||
@self.bot.on_message('group')
|
||||
async def group(event: Event):
|
||||
abm = self.convert_message(event)
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@self.bot.on_message('private')
|
||||
async def private(event: Event):
|
||||
abm = self.convert_message(event)
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import sys
|
||||
import time
|
||||
import uuid
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
@@ -62,7 +63,7 @@ class VChatPlatformAdapter(Platform):
|
||||
self.start_time = int(time.time())
|
||||
return self._run()
|
||||
|
||||
|
||||
|
||||
async def _run(self):
|
||||
await self.client.init()
|
||||
await self.client.auto_login(hot_reload=True, enable_cmd_qr=True)
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
import aiohttp
|
||||
import uuid
|
||||
import asyncio
|
||||
import re
|
||||
import astrbot.api.star as star
|
||||
import aiodocker
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from collections import defaultdict
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import llm_tool, logger
|
||||
from astrbot.api.event import filter
|
||||
from astrbot.api.message_components import Image
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
from astrbot.api.message_components import Image, File
|
||||
|
||||
PROMPT = """
|
||||
## Task
|
||||
@@ -18,19 +21,35 @@ You need to generate python codes to solve user's problem: {prompt}
|
||||
{extra_input}
|
||||
|
||||
## Limit
|
||||
1. Available libraries: standard libs, `Pillow`, `requests`, `numpy`, `matplotlib`, `scipy`, `scikit-learn`, `beautifulsoup4`, `pandas`, `opencv-python`
|
||||
1. Available libraries:
|
||||
- standard libs
|
||||
- `Pillow`
|
||||
- `requests`
|
||||
- `numpy`
|
||||
- `matplotlib`
|
||||
- `scipy`
|
||||
- `scikit-learn`
|
||||
- `beautifulsoup4`
|
||||
- `pandas`
|
||||
- `opencv-python`
|
||||
- `python-docx`
|
||||
- `python-pptx`
|
||||
- `pymupdf` (Do not use fpdf, reportlab, etc.)
|
||||
- `mplfonts`
|
||||
You can only use these libraries and the libraries that they depend on.
|
||||
2. Do not generate malicious code.
|
||||
3. Only output text, image. For Image, you need save it to `output` folder.
|
||||
4. Use given `shared.api` package to output the result.
|
||||
5. You must only output the code, do not output the result of the code and any other information.
|
||||
6. The output language is same as user's input language.
|
||||
7. Please first provide relevant knowledge about user's problem appropriately.
|
||||
3. Use given `shared.api` package to output the result.
|
||||
It has 3 functions: `send_text(text: str)`, `send_image(image_path: str)`, `send_file(file_path: str)`.
|
||||
For Image and file, you must save it to `output` folder.
|
||||
4. You must only output the code, do not output the result of the code and any other information.
|
||||
5. The output language is same as user's input language.
|
||||
6. Please first provide relevant knowledge about user's problem appropriately.
|
||||
|
||||
## Example
|
||||
1. The user's problem is: `please solve the fabonacci sequence problem.`
|
||||
1. User's problem: `please solve the fabonacci sequence problem.`
|
||||
Output:
|
||||
```python
|
||||
from shared.api import send_text, send_image
|
||||
from shared.api import send_text, send_image, send_file
|
||||
|
||||
def fabonacci(n):
|
||||
if n <= 1:
|
||||
@@ -43,10 +62,10 @@ send_text("The fabonacci sequence is a series of numbers in which each number is
|
||||
send_text("Let's calculate the fabonacci sequence of 10: " + result) # send_text is a function to send pure text to user
|
||||
```
|
||||
|
||||
2. The user's problem is: `please draw a sin(x) function.`
|
||||
2. User's problem: `please draw a sin(x) function.`
|
||||
Output:
|
||||
```python
|
||||
from shared.api import send_text, send_image
|
||||
from shared.api import send_text, send_image, send_file
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
@@ -80,6 +99,9 @@ class Main(star.Star):
|
||||
self.shared_path = os.path.join(self.curr_dir, "shared")
|
||||
os.makedirs(self.workplace_path, exist_ok=True)
|
||||
|
||||
self.user_file_msg_buffer = defaultdict(list)
|
||||
'''存放用户上传的文件'''
|
||||
|
||||
# 加载配置
|
||||
if not os.path.exists(PATH):
|
||||
self.config = DEFAULT_CONFIG
|
||||
@@ -88,6 +110,23 @@ class Main(star.Star):
|
||||
with open(PATH, "r") as f:
|
||||
self.config = json.load(f)
|
||||
|
||||
async def file_upload(self, file_path: str):
|
||||
'''
|
||||
上传图像文件到 S3
|
||||
'''
|
||||
ext = os.path.splitext(file_path)[1]
|
||||
S3_URL = "https://s3.neko.soulter.top/astrbot-s3"
|
||||
with open(file_path, "rb") as f:
|
||||
file = f.read()
|
||||
|
||||
s3_file_url = f"{S3_URL}/{uuid.uuid4().hex}{ext}"
|
||||
|
||||
async with aiohttp.ClientSession(headers = {"Accept": "application/json"}) as session:
|
||||
async with session.put(s3_file_url, data=file) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"Failed to upload image: {resp.status}")
|
||||
return s3_file_url
|
||||
|
||||
|
||||
async def is_docker_available(self) -> bool:
|
||||
'''Check if docker is available'''
|
||||
@@ -131,6 +170,21 @@ class Main(star.Star):
|
||||
raise ValueError("The code is not in the code block.")
|
||||
return match.group(1)
|
||||
|
||||
@filter.event_message_type(filter.EventMessageType.ALL)
|
||||
async def on_message(self, event: AstrMessageEvent):
|
||||
'''处理消息'''
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, File):
|
||||
self.user_file_msg_buffer[event.get_session_id()].append(comp.file)
|
||||
logger.debug(f"User uploaded file: {comp.file}")
|
||||
break # 一个消息中,文件只能有一个,这里直接 break 减少计算量。
|
||||
|
||||
@filter.on_llm_request()
|
||||
async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest):
|
||||
if event.get_session_id() in self.user_file_msg_buffer:
|
||||
files = self.user_file_msg_buffer[event.get_session_id()]
|
||||
request.prompt += f"\nUser provided files: {files}"
|
||||
|
||||
|
||||
@filter.command_group("pi")
|
||||
def pi(self):
|
||||
@@ -166,7 +220,8 @@ class Main(star.Star):
|
||||
|
||||
@llm_tool("python_interpreter")
|
||||
async def python_interpreter(self, event: AstrMessageEvent):
|
||||
'''Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code. For example, user can use this tool to solve a math problem, edit Image, etc.
|
||||
'''Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code.
|
||||
For example, user can use this tool to solve math problems, edit image, docx, pptx, pdf, etc.
|
||||
'''
|
||||
if not await self.is_docker_available():
|
||||
yield event.plain_result("Docker 在当前机器不可用,无法沙箱化执行代码。")
|
||||
@@ -191,7 +246,23 @@ class Main(star.Star):
|
||||
if image_path:
|
||||
images.append(image_path)
|
||||
idx += 1
|
||||
|
||||
# 文件
|
||||
files = []
|
||||
for file_path in self.user_file_msg_buffer[event.get_session_id()]:
|
||||
# cp
|
||||
file_name = os.path.basename(file_path)
|
||||
shutil.copy(file_path, os.path.join(workplace_path, file_name))
|
||||
files.append(file_name)
|
||||
|
||||
logger.debug(f"user query: {plain_text}, images: {images}, files: {files}")
|
||||
|
||||
# 整理额外输入
|
||||
extra_inputs = ""
|
||||
if images:
|
||||
extra_inputs += f"User provided images: {images}\n"
|
||||
if files:
|
||||
extra_inputs += f"User provided files: {files}\n"
|
||||
|
||||
obs = ""
|
||||
n = 5
|
||||
|
||||
@@ -201,7 +272,7 @@ class Main(star.Star):
|
||||
|
||||
PROMPT_ = PROMPT.format(
|
||||
prompt=plain_text,
|
||||
extra_input=f"User provided images: {images}",
|
||||
extra_input=extra_inputs,
|
||||
extra_prompt=obs,
|
||||
)
|
||||
provider = self.context.get_using_provider()
|
||||
@@ -253,7 +324,7 @@ class Main(star.Star):
|
||||
logger.debug(f"Container {container.id} logs: {logs}")
|
||||
|
||||
# 发送结果
|
||||
pattern = r"\[ASTRBOT_(TEXT|IMAGE)_OUTPUT#\w+\]: (.*)"
|
||||
pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)"
|
||||
ok = False
|
||||
traceback = ""
|
||||
for idx, log in enumerate(logs):
|
||||
@@ -266,6 +337,15 @@ class Main(star.Star):
|
||||
image_path = os.path.join(workplace_path, match.group(2))
|
||||
logger.debug(f"Sending image: {image_path}")
|
||||
yield event.image_result(image_path)
|
||||
elif match.group(1) == "FILE":
|
||||
file_path = os.path.join(workplace_path, match.group(2))
|
||||
logger.debug(f"Sending file: {file_path}")
|
||||
file_s3_url = await self.file_upload(file_path)
|
||||
logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
|
||||
file_name = os.path.basename(file_path)
|
||||
chain = [File(name=file_name, file=file_s3_url)]
|
||||
yield event.set_result(MessageEventResult(chain=chain))
|
||||
|
||||
elif "Traceback (most recent call last)" in log \
|
||||
or "[Error]: " in log:
|
||||
traceback = "\n".join(logs[idx:])
|
||||
|
||||
@@ -10,4 +10,9 @@ def send_text(text: str):
|
||||
def send_image(image_path: str):
|
||||
if not os.path.exists(image_path):
|
||||
raise Exception(f"Image file not found: {image_path}")
|
||||
print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}")
|
||||
print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}")
|
||||
|
||||
def send_file(file_path: str):
|
||||
if not os.path.exists(file_path):
|
||||
raise Exception(f"File not found: {file_path}")
|
||||
print(f"[ASTRBOT_FILE_OUTPUT#{_get_magic_code()}]: {file_path}")
|
||||
Reference in New Issue
Block a user