feat: file upload
This commit is contained in:
@@ -813,7 +813,8 @@ class File(BaseMessageComponent):
|
||||
"""下载文件"""
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
fname = self.name if self.name else uuid.uuid4().hex
|
||||
file_path = os.path.join(download_dir, fname)
|
||||
await download_file(self.url, file_path)
|
||||
self.file_ = os.path.abspath(file_path)
|
||||
|
||||
|
||||
@@ -232,7 +232,9 @@ class AiocqhttpAdapter(Platform):
|
||||
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||
# Lagrange
|
||||
logger.info("guessing lagrange")
|
||||
file_name = m["data"].get("file_name", "file")
|
||||
file_name = m["data"].get(
|
||||
"file_name", m["data"].get("file", "file")
|
||||
)
|
||||
abm.message.append(File(name=file_name, url=m["data"]["url"]))
|
||||
else:
|
||||
try:
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
import os
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.all import Context
|
||||
import astrbot.api.message_components as Comp
|
||||
from astrbot.core.utils.session_waiter import (
|
||||
session_waiter,
|
||||
SessionController,
|
||||
)
|
||||
from ..sandbox_client import SandboxClient
|
||||
|
||||
|
||||
class FileCommand:
|
||||
def __init__(self, context: Context) -> None:
|
||||
self.context = context
|
||||
self.user_file_uploads: dict[str, list[str]] = {} # umo -> file_path
|
||||
self.user_file_uploaded_files: dict[str, list[str]] = {} # umo -> file_path
|
||||
"""记录用户上传过的文件,保存了文件在沙箱中的路径。这个在用户下一次 LLM 请求时会被用到,然后清空。"""
|
||||
|
||||
async def _upload_file_to_sandbox(self, event: AstrMessageEvent) -> list[str]:
|
||||
"""将用户上传的文件上传到沙箱"""
|
||||
sender_id = event.get_sender_id()
|
||||
sb = await SandboxClient().get_ship(event.unified_msg_origin)
|
||||
fpath_ls = self.user_file_uploads[sender_id]
|
||||
errors = []
|
||||
for path in fpath_ls:
|
||||
try:
|
||||
fname = os.path.basename(path)
|
||||
data = await sb.upload_file(path, fname)
|
||||
success = data.get("success", False)
|
||||
if not success:
|
||||
raise Exception(f"Upload failed: {data}")
|
||||
file_path = data.get("file_path", "")
|
||||
logger.info(f"File {fname} uploaded to sandbox at {file_path}")
|
||||
self.user_file_uploaded_files.setdefault(sender_id, []).append(
|
||||
file_path
|
||||
)
|
||||
except Exception as e:
|
||||
errors.append((path, str(e)))
|
||||
logger.error(f"Error uploading file {path}: {e}")
|
||||
|
||||
# clean up files
|
||||
for path in fpath_ls:
|
||||
try:
|
||||
os.remove(path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing temp file {path}: {e}")
|
||||
|
||||
return errors
|
||||
|
||||
async def file(self, event: AstrMessageEvent):
|
||||
"""等待用户上传文件或图片"""
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
f"请上传一个或多个文件(或图片),使用 /endupload 结束上传。(请求者 ID: {event.get_sender_id()})"
|
||||
)
|
||||
)
|
||||
try:
|
||||
|
||||
@session_waiter(timeout=600, record_history_chains=False) # type: ignore
|
||||
async def empty_mention_waiter(
|
||||
controller: SessionController, event: AstrMessageEvent
|
||||
):
|
||||
idiom = event.message_str
|
||||
sender_id = event.get_sender_id()
|
||||
|
||||
if idiom == "endupload":
|
||||
files = self.user_file_uploads.get(sender_id, [])
|
||||
if not files:
|
||||
await event.send(
|
||||
event.plain_result("你没有上传任何文件,上传已取消。")
|
||||
)
|
||||
controller.stop()
|
||||
return
|
||||
await event.send(
|
||||
event.plain_result(f"开始上传 {len(files)} 个文件到沙箱...")
|
||||
)
|
||||
errors = await self._upload_file_to_sandbox(event)
|
||||
if errors:
|
||||
error_msgs = "\n".join(
|
||||
[f"{path}: {err}" for path, err in errors]
|
||||
)
|
||||
await event.send(
|
||||
event.plain_result(
|
||||
f"上传中出现错误:\n{error_msgs}\n其他文件已成功上传。"
|
||||
)
|
||||
)
|
||||
else:
|
||||
await event.send(
|
||||
event.plain_result(
|
||||
f"上传完毕,共上传 {len(files)} 个文件。文件信息已被保存,下一次 LLM 请求时会自动将信息附上。"
|
||||
)
|
||||
)
|
||||
self.user_file_uploads.pop(sender_id, None)
|
||||
controller.stop()
|
||||
return
|
||||
|
||||
# 解析文件或图片消息
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, (Comp.File, Comp.Image)):
|
||||
if isinstance(comp, Comp.File):
|
||||
path = await comp.get_file()
|
||||
self.user_file_uploads.setdefault(
|
||||
event.get_sender_id(), []
|
||||
).append(path)
|
||||
elif isinstance(comp, Comp.Image):
|
||||
path = await comp.convert_to_file_path()
|
||||
self.user_file_uploads.setdefault(
|
||||
event.get_sender_id(), []
|
||||
).append(path)
|
||||
fname = os.path.basename(path)
|
||||
await event.send(
|
||||
event.plain_result(
|
||||
f"已接收文件: {fname},继续上传或发送 /endupload 结束。"
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
await empty_mention_waiter(event)
|
||||
except TimeoutError as _:
|
||||
await event.send(event.plain_result("等待上传超时,上传已取消。"))
|
||||
except Exception as e:
|
||||
await event.send(
|
||||
event.plain_result("发生错误,请联系管理员: " + str(e))
|
||||
)
|
||||
finally:
|
||||
event.stop_event()
|
||||
except Exception as e:
|
||||
logger.error("handle_empty_mention error: " + str(e))
|
||||
@@ -1,9 +1,13 @@
|
||||
import os
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import filter, AstrMessageEvent
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
from astrbot.api import AstrBotConfig
|
||||
from .tools.fs import CreateFileTool
|
||||
from .tools.fs import CreateFileTool, ReadFileTool
|
||||
from .tools.shell import ExecuteShellTool
|
||||
from .tools.python import PythonTool
|
||||
from .commands.file import FileCommand
|
||||
|
||||
|
||||
class Main(star.Star):
|
||||
@@ -17,7 +21,26 @@ class Main(star.Star):
|
||||
os.environ["SHIPYARD_ENDPOINT"] = self.endpoint
|
||||
os.environ["SHIPYARD_ACCESS_TOKEN"] = self.access_token
|
||||
|
||||
context.add_llm_tool(CreateFileTool(), ExecuteShellTool(), PythonTool())
|
||||
context.add_llm_tool(
|
||||
CreateFileTool(), ExecuteShellTool(), PythonTool(), ReadFileTool()
|
||||
)
|
||||
|
||||
self.file_c = FileCommand(context)
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@filter.command("fileupload")
|
||||
async def fileupload(self, event: AstrMessageEvent):
|
||||
"""处理文件上传"""
|
||||
await self.file_c.file(event)
|
||||
|
||||
@filter.on_llm_request()
|
||||
async def on_llm_request(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
"""处理 LLM 请求"""
|
||||
sender_id = event.get_sender_id()
|
||||
uploads = self.file_c.user_file_uploaded_files.pop(sender_id, None)
|
||||
if uploads:
|
||||
logger.info(f"Attaching uploaded files for user {sender_id}: {uploads}")
|
||||
|
||||
req.system_prompt = f"{req.system_prompt}\n\n\n# User Uploaded Files: {uploads}"
|
||||
|
||||
@@ -33,3 +33,28 @@ class CreateFileTool(FunctionTool):
|
||||
return json.dumps(result)
|
||||
except Exception as e:
|
||||
return f"Error creating file: {str(e)}"
|
||||
|
||||
@dataclass
|
||||
class ReadFileTool(FunctionTool):
|
||||
name: str = "astrbot_read_file"
|
||||
description: str = "Read the content of a file in the sandbox."
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The path of the file to read, relative to the sandbox root. Must not use absolute paths or traverse outside the sandbox.",
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
}
|
||||
)
|
||||
|
||||
async def run(self, event: AstrMessageEvent, path: str):
|
||||
sb = await SandboxClient().get_ship(event.unified_msg_origin)
|
||||
try:
|
||||
result = await sb.fs.read_file(path)
|
||||
return result
|
||||
except Exception as e:
|
||||
return f"Error reading file: {str(e)}"
|
||||
|
||||
+1
-1
@@ -51,7 +51,7 @@ dependencies = [
|
||||
"wechatpy>=1.8.18",
|
||||
"audioop-lts ; python_full_version >= '3.13'",
|
||||
"click>=8.2.1",
|
||||
"shipyard-python-sdk>=0.1.0",
|
||||
"shipyard-python-sdk>=0.2.3",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
Reference in New Issue
Block a user