refactor: 代码执行器使用指令来制定上传文件以更好适配全平台;telegram 支持发送文件和语音
This commit is contained in:
@@ -116,6 +116,7 @@ class LLMRequestSubStage(Stage):
|
||||
elif llm_response.role == 'tool':
|
||||
# function calling
|
||||
function_calling_result = {}
|
||||
logger.info(f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}")
|
||||
for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args):
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
logger.info(f"调用工具函数:{func_tool_name},参数:{func_tool_args}")
|
||||
|
||||
@@ -102,13 +102,13 @@ class TelegramPlatformAdapter(Platform):
|
||||
message.message = [Record(file=file.file_path, url=file.file_path),]
|
||||
|
||||
elif update.message.photo:
|
||||
for photo in update.message.photo:
|
||||
file = await photo.get_file()
|
||||
message.message.append(Image(file=file.file_path, url=file.file_path))
|
||||
photo = update.message.photo[-1] # get the largest photo
|
||||
file = await photo.get_file()
|
||||
message.message.append(Image(file=file.file_path, url=file.file_path))
|
||||
|
||||
elif update.message.document:
|
||||
file = await update.message.document.get_file()
|
||||
message.message = [AstrBotFile(file=file.file_path, name="file"),]
|
||||
message.message = [AstrBotFile(file=file.file_path, name=update.message.document.file_name),]
|
||||
|
||||
elif update.message.video:
|
||||
file = await update.message.video.get_file()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType
|
||||
from astrbot.api.message_components import Plain, Image, Reply, At
|
||||
from astrbot.api.message_components import Plain, Image, Reply, At, File, Record
|
||||
from telegram.ext import ExtBot
|
||||
|
||||
class TelegramPlatformEvent(AstrMessageEvent):
|
||||
@@ -48,6 +48,11 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
await client.send_photo(photo=image_bytes, **payload)
|
||||
else:
|
||||
await client.send_photo(photo=image_path, **payload)
|
||||
elif isinstance(i, File):
|
||||
await client.send_document(document=i.file, filename=i.name, **payload)
|
||||
elif isinstance(i, Record):
|
||||
await client.send_voice(voice=i.file, **payload)
|
||||
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
|
||||
@@ -5,14 +5,16 @@ import aiohttp
|
||||
import uuid
|
||||
import asyncio
|
||||
import re
|
||||
import astrbot.api.star as star
|
||||
import aiodocker
|
||||
import time
|
||||
import astrbot.api.star as star
|
||||
from collections import defaultdict
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import llm_tool, logger
|
||||
from astrbot.api.event import filter
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
from astrbot.api.message_components import Image, File
|
||||
from astrbot.core.utils.io import download_image_by_url, download_file
|
||||
|
||||
PROMPT = """
|
||||
## Task
|
||||
@@ -107,7 +109,9 @@ class Main(star.Star):
|
||||
os.makedirs(self.workplace_path, exist_ok=True)
|
||||
|
||||
self.user_file_msg_buffer = defaultdict(list)
|
||||
'''存放用户上传的文件'''
|
||||
'''存放用户上传的文件和图片'''
|
||||
self.user_waiting = {}
|
||||
'''正在等待用户的文件或图片'''
|
||||
|
||||
# 加载配置
|
||||
if not os.path.exists(PATH):
|
||||
@@ -187,11 +191,35 @@ class Main(star.Star):
|
||||
@filter.event_message_type(filter.EventMessageType.ALL)
|
||||
async def on_message(self, event: AstrMessageEvent):
|
||||
'''处理消息'''
|
||||
uid = event.get_sender_id()
|
||||
if uid not in self.user_waiting:
|
||||
return
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, File):
|
||||
self.user_file_msg_buffer[event.get_session_id()].append(comp.file)
|
||||
logger.debug(f"User uploaded file: {comp.file}")
|
||||
break # 一个消息中,文件只能有一个,这里直接 break 减少计算量。
|
||||
if comp.file.startswith("http"):
|
||||
name = comp.name if comp.name else uuid.uuid4().hex[:8]
|
||||
path = f"data/temp/{name}"
|
||||
await download_file(comp.file, path)
|
||||
else:
|
||||
path = comp.file
|
||||
self.user_file_msg_buffer[event.get_session_id()].append(path)
|
||||
logger.debug(f"User {uid} uploaded file: {path}")
|
||||
yield event.plain_result(f"代码执行器: 文件已经上传: {path}")
|
||||
if uid in self.user_waiting:
|
||||
del self.user_waiting[uid]
|
||||
elif isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
else:
|
||||
image_path = image_url
|
||||
self.user_file_msg_buffer[event.get_session_id()].append(image_path)
|
||||
logger.debug(f"User {uid} uploaded image: {image_path}")
|
||||
yield event.plain_result(f"代码执行器: 图片已经上传: {image_path}")
|
||||
if uid in self.user_waiting:
|
||||
del self.user_waiting[uid]
|
||||
|
||||
@filter.on_llm_request()
|
||||
async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest):
|
||||
@@ -239,7 +267,19 @@ class Main(star.Star):
|
||||
pass
|
||||
await docker.images.pull(image_name)
|
||||
yield event.plain_result("重新拉取沙箱镜像成功。")
|
||||
|
||||
|
||||
@pi.command("file")
|
||||
async def pi_file(self, event: AstrMessageEvent):
|
||||
'''在规定秒数(60s)内上传一个文件'''
|
||||
uid = event.get_sender_id()
|
||||
self.user_waiting[uid] = time.time()
|
||||
tip = "文件"
|
||||
yield event.plain_result(f"代码执行器: 请在 60s 内上传一个{tip}。")
|
||||
await asyncio.sleep(60)
|
||||
if uid in self.user_waiting:
|
||||
yield event.plain_result(f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 未在规定时间内上传{tip}。")
|
||||
self.user_waiting.pop(uid)
|
||||
|
||||
|
||||
@llm_tool("python_interpreter")
|
||||
async def python_interpreter(self, event: AstrMessageEvent):
|
||||
@@ -258,31 +298,23 @@ class Main(star.Star):
|
||||
os.makedirs(workplace_path, exist_ok=True)
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
# 图片
|
||||
images = []
|
||||
idx = 1
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
if image_url.startswith("http"):
|
||||
image_path = await self.download_image(image_url, workplace_path, f"img_{idx}")
|
||||
if image_path:
|
||||
images.append(image_path)
|
||||
idx += 1
|
||||
# 文件
|
||||
files = []
|
||||
# 文件
|
||||
for file_path in self.user_file_msg_buffer[event.get_session_id()]:
|
||||
if not file_path:
|
||||
continue
|
||||
elif not os.path.exists(file_path):
|
||||
logger.warning(f"文件 {file_path} 不存在,已忽略。")
|
||||
continue
|
||||
# cp
|
||||
file_name = os.path.basename(file_path)
|
||||
shutil.copy(file_path, os.path.join(workplace_path, file_name))
|
||||
files.append(file_name)
|
||||
|
||||
logger.debug(f"user query: {plain_text}, images: {images}, files: {files}")
|
||||
logger.debug(f"user query: {plain_text}, files: {files}")
|
||||
|
||||
# 整理额外输入
|
||||
extra_inputs = ""
|
||||
if images:
|
||||
extra_inputs += f"User provided images: {images}\n"
|
||||
if files:
|
||||
extra_inputs += f"User provided files: {files}\n"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user