feat: add file download functionality and update shipyard SDK version

This commit is contained in:
Soulter
2026-01-13 21:28:56 +08:00
parent 068094708e
commit 792e348076
8 changed files with 75 additions and 7 deletions
@@ -38,6 +38,7 @@ from ...stage import Stage
from ...utils import (
CREATE_FILE_TOOL,
EXECUTE_SHELL_TOOL,
FILE_DOWNLOAD_TOOL,
FILE_UPLOAD_TOOL,
KNOWLEDGE_BASE_QUERY_TOOL,
LLM_SAFETY_MODE_SYSTEM_PROMPT,
@@ -492,6 +493,7 @@ class InternalAgentSubStage(Stage):
req.func_tool.add_tool(EXECUTE_SHELL_TOOL)
req.func_tool.add_tool(PYTHON_TOOL)
req.func_tool.add_tool(FILE_UPLOAD_TOOL)
req.func_tool.add_tool(FILE_DOWNLOAD_TOOL)
async def process(
self, event: AstrMessageEvent, provider_wake_prefix: str
@@ -10,6 +10,7 @@ from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.sandbox.tools import (
CreateFileTool,
ExecuteShellTool,
FileDownloadTool,
FileUploadTool,
PythonTool,
ReadFileTool,
@@ -150,6 +151,7 @@ READ_FILE_TOOL = ReadFileTool()
EXECUTE_SHELL_TOOL = ExecuteShellTool()
PYTHON_TOOL = PythonTool()
FILE_UPLOAD_TOOL = FileUploadTool()
FILE_DOWNLOAD_TOOL = FileDownloadTool()
# we prevent astrbot from connecting to known malicious hosts
# these hosts are base64 encoded
@@ -94,7 +94,7 @@ class WebChatMessageEvent(AstrMessageEvent):
filename = f"{uuid.uuid4()!s}{ext}"
dest_path = os.path.join(imgs_dir, filename)
shutil.copy2(file_path, dest_path)
data = f"[FILE]{filename}|{original_name}"
data = f"[FILE]{filename}"
await web_chat_back_queue.put(
{
"type": "file",
+4
View File
@@ -22,6 +22,10 @@ class SandboxBooter:
"""
...
async def download_file(self, remote_path: str, local_path: str):
"""Download file from sandbox."""
...
async def available(self) -> bool:
"""Check if the sandbox is available."""
...
+5 -4
View File
@@ -1,5 +1,3 @@
import uuid
from shipyard import ShipyardClient, Spec
from astrbot.api import logger
@@ -23,12 +21,11 @@ class ShipyardBooter(SandboxBooter):
self._session_num = session_num
async def boot(self, session_id: str) -> None:
uuid_str = uuid.uuid5(uuid.NAMESPACE_DNS, session_id).hex
ship = await self._sandbox_client.create_ship(
ttl=self._ttl,
spec=Spec(cpus=1.0, memory="512m"),
max_session_num=self._session_num,
session_id=uuid_str,
session_id=session_id,
)
logger.info(f"Got sandbox ship: {ship.id} for session: {session_id}")
self._ship = ship
@@ -52,6 +49,10 @@ class ShipyardBooter(SandboxBooter):
"""Upload file to sandbox"""
return await self._ship.upload_file(path, file_name)
async def download_file(self, remote_path: str, local_path: str):
"""Download file from sandbox."""
return await self._ship.download_file(remote_path, local_path)
async def available(self) -> bool:
"""Check if the sandbox is available."""
try:
+2 -1
View File
@@ -1,4 +1,4 @@
from .fs import CreateFileTool, FileUploadTool, ReadFileTool
from .fs import CreateFileTool, FileDownloadTool, FileUploadTool, ReadFileTool
from .python import PythonTool
from .shell import ExecuteShellTool
@@ -8,4 +8,5 @@ __all__ = [
"FileUploadTool",
"PythonTool",
"ExecuteShellTool",
"FileDownloadTool",
]
+58
View File
@@ -3,9 +3,12 @@ import os
from dataclasses import dataclass, field
from astrbot.api import FunctionTool, logger
from astrbot.api.event import MessageChain
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.components import File
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..sandbox_client import get_booter
@@ -130,3 +133,58 @@ class FileUploadTool(FunctionTool):
except Exception as e:
logger.error(f"Error uploading file {local_path}: {e}")
return f"Error uploading file: {str(e)}"
@dataclass
class FileDownloadTool(FunctionTool):
name: str = "astrbot_download_file"
description: str = "Download a file from the sandbox. Only call this when user explicitly need you to download a file."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"remote_path": {
"type": "string",
"description": "The path of the file in the sandbox to download.",
}
},
"required": ["remote_path"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
remote_path: str,
) -> ToolExecResult:
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
try:
name = os.path.basename(remote_path)
local_path = os.path.join(get_astrbot_temp_path(), name)
# Download file from sandbox
await sb.download_file(remote_path, local_path)
logger.info(f"File {remote_path} downloaded from sandbox to {local_path}")
try:
name = os.path.basename(local_path)
await context.context.event.send(
MessageChain(chain=[File(name=name, file=local_path)])
)
except Exception as e:
logger.error(f"Error sending file message: {e}")
# remove
try:
os.remove(local_path)
except Exception as e:
logger.error(f"Error removing temp file {local_path}: {e}")
return f"File downloaded successfully to {local_path}"
except Exception as e:
logger.error(f"Error downloading file {remote_path}: {e}")
return f"Error downloading file: {str(e)}"
+1 -1
View File
@@ -60,7 +60,7 @@ dependencies = [
"markitdown-no-magika[docx,xls,xlsx]>=0.1.2",
"xinference-client",
"tenacity>=9.1.2",
"shipyard-python-sdk>=0.2.3",
"shipyard-python-sdk>=0.2.4",
]
[dependency-groups]