feat: file upload

This commit is contained in:
Soulter
2025-10-04 23:42:59 +08:00
parent 972b5ffb86
commit 9fec29c1a3
6 changed files with 185 additions and 5 deletions
+2 -1
View File
@@ -813,7 +813,8 @@ class File(BaseMessageComponent):
"""下载文件"""
download_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(download_dir, exist_ok=True)
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
fname = self.name if self.name else uuid.uuid4().hex
file_path = os.path.join(download_dir, fname)
await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path)
@@ -232,7 +232,9 @@ class AiocqhttpAdapter(Platform):
if m["data"].get("url") and m["data"].get("url").startswith("http"):
# Lagrange
logger.info("guessing lagrange")
file_name = m["data"].get("file_name", "file")
file_name = m["data"].get(
"file_name", m["data"].get("file", "file")
)
abm.message.append(File(name=file_name, url=m["data"]["url"]))
else:
try:
+129
View File
@@ -0,0 +1,129 @@
import os
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api import logger
from astrbot.api.all import Context
import astrbot.api.message_components as Comp
from astrbot.core.utils.session_waiter import (
session_waiter,
SessionController,
)
from ..sandbox_client import SandboxClient
class FileCommand:
def __init__(self, context: Context) -> None:
self.context = context
self.user_file_uploads: dict[str, list[str]] = {} # umo -> file_path
self.user_file_uploaded_files: dict[str, list[str]] = {} # umo -> file_path
"""记录用户上传过的文件,保存了文件在沙箱中的路径。这个在用户下一次 LLM 请求时会被用到,然后清空。"""
async def _upload_file_to_sandbox(self, event: AstrMessageEvent) -> list[str]:
"""将用户上传的文件上传到沙箱"""
sender_id = event.get_sender_id()
sb = await SandboxClient().get_ship(event.unified_msg_origin)
fpath_ls = self.user_file_uploads[sender_id]
errors = []
for path in fpath_ls:
try:
fname = os.path.basename(path)
data = await sb.upload_file(path, fname)
success = data.get("success", False)
if not success:
raise Exception(f"Upload failed: {data}")
file_path = data.get("file_path", "")
logger.info(f"File {fname} uploaded to sandbox at {file_path}")
self.user_file_uploaded_files.setdefault(sender_id, []).append(
file_path
)
except Exception as e:
errors.append((path, str(e)))
logger.error(f"Error uploading file {path}: {e}")
# clean up files
for path in fpath_ls:
try:
os.remove(path)
except Exception as e:
logger.error(f"Error removing temp file {path}: {e}")
return errors
async def file(self, event: AstrMessageEvent):
"""等待用户上传文件或图片"""
await event.send(
MessageChain().message(
f"请上传一个或多个文件(或图片),使用 /endupload 结束上传。(请求者 ID: {event.get_sender_id()})"
)
)
try:
@session_waiter(timeout=600, record_history_chains=False) # type: ignore
async def empty_mention_waiter(
controller: SessionController, event: AstrMessageEvent
):
idiom = event.message_str
sender_id = event.get_sender_id()
if idiom == "endupload":
files = self.user_file_uploads.get(sender_id, [])
if not files:
await event.send(
event.plain_result("你没有上传任何文件,上传已取消。")
)
controller.stop()
return
await event.send(
event.plain_result(f"开始上传 {len(files)} 个文件到沙箱...")
)
errors = await self._upload_file_to_sandbox(event)
if errors:
error_msgs = "\n".join(
[f"{path}: {err}" for path, err in errors]
)
await event.send(
event.plain_result(
f"上传中出现错误:\n{error_msgs}\n其他文件已成功上传。"
)
)
else:
await event.send(
event.plain_result(
f"上传完毕,共上传 {len(files)} 个文件。文件信息已被保存,下一次 LLM 请求时会自动将信息附上。"
)
)
self.user_file_uploads.pop(sender_id, None)
controller.stop()
return
# 解析文件或图片消息
for comp in event.message_obj.message:
if isinstance(comp, (Comp.File, Comp.Image)):
if isinstance(comp, Comp.File):
path = await comp.get_file()
self.user_file_uploads.setdefault(
event.get_sender_id(), []
).append(path)
elif isinstance(comp, Comp.Image):
path = await comp.convert_to_file_path()
self.user_file_uploads.setdefault(
event.get_sender_id(), []
).append(path)
fname = os.path.basename(path)
await event.send(
event.plain_result(
f"已接收文件: {fname},继续上传或发送 /endupload 结束。"
)
)
try:
await empty_mention_waiter(event)
except TimeoutError as _:
await event.send(event.plain_result("等待上传超时,上传已取消。"))
except Exception as e:
await event.send(
event.plain_result("发生错误,请联系管理员: " + str(e))
)
finally:
event.stop_event()
except Exception as e:
logger.error("handle_empty_mention error: " + str(e))
+25 -2
View File
@@ -1,9 +1,13 @@
import os
import astrbot.api.star as star
from astrbot.api import logger
from astrbot.api.event import filter, AstrMessageEvent
from astrbot.api.provider import ProviderRequest
from astrbot.api import AstrBotConfig
from .tools.fs import CreateFileTool
from .tools.fs import CreateFileTool, ReadFileTool
from .tools.shell import ExecuteShellTool
from .tools.python import PythonTool
from .commands.file import FileCommand
class Main(star.Star):
@@ -17,7 +21,26 @@ class Main(star.Star):
os.environ["SHIPYARD_ENDPOINT"] = self.endpoint
os.environ["SHIPYARD_ACCESS_TOKEN"] = self.access_token
context.add_llm_tool(CreateFileTool(), ExecuteShellTool(), PythonTool())
context.add_llm_tool(
CreateFileTool(), ExecuteShellTool(), PythonTool(), ReadFileTool()
)
self.file_c = FileCommand(context)
async def initialize(self):
pass
@filter.command("fileupload")
async def fileupload(self, event: AstrMessageEvent):
"""处理文件上传"""
await self.file_c.file(event)
@filter.on_llm_request()
async def on_llm_request(self, event: AstrMessageEvent, req: ProviderRequest):
"""处理 LLM 请求"""
sender_id = event.get_sender_id()
uploads = self.file_c.user_file_uploaded_files.pop(sender_id, None)
if uploads:
logger.info(f"Attaching uploaded files for user {sender_id}: {uploads}")
req.system_prompt = f"{req.system_prompt}\n\n\n# User Uploaded Files: {uploads}"
+25
View File
@@ -33,3 +33,28 @@ class CreateFileTool(FunctionTool):
return json.dumps(result)
except Exception as e:
return f"Error creating file: {str(e)}"
@dataclass
class ReadFileTool(FunctionTool):
name: str = "astrbot_read_file"
description: str = "Read the content of a file in the sandbox."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The path of the file to read, relative to the sandbox root. Must not use absolute paths or traverse outside the sandbox.",
},
},
"required": ["path"],
}
)
async def run(self, event: AstrMessageEvent, path: str):
sb = await SandboxClient().get_ship(event.unified_msg_origin)
try:
result = await sb.fs.read_file(path)
return result
except Exception as e:
return f"Error reading file: {str(e)}"
+1 -1
View File
@@ -51,7 +51,7 @@ dependencies = [
"wechatpy>=1.8.18",
"audioop-lts ; python_full_version >= '3.13'",
"click>=8.2.1",
"shipyard-python-sdk>=0.1.0",
"shipyard-python-sdk>=0.2.3",
]
[project.scripts]