From 792e348076ff8c9857ee136cb561fe73fffefe8e Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 13 Jan 2026 21:28:56 +0800 Subject: [PATCH] feat: add file download functionality and update shipyard SDK version --- .../method/agent_sub_stages/internal.py | 2 + astrbot/core/pipeline/process_stage/utils.py | 2 + .../platform/sources/webchat/webchat_event.py | 2 +- astrbot/core/sandbox/booters/base.py | 4 ++ astrbot/core/sandbox/booters/shipyard.py | 9 +-- astrbot/core/sandbox/tools/__init__.py | 3 +- astrbot/core/sandbox/tools/fs.py | 58 +++++++++++++++++++ pyproject.toml | 2 +- 8 files changed, 75 insertions(+), 7 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 05d3a01c4..42e0273b0 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -38,6 +38,7 @@ from ...stage import Stage from ...utils import ( CREATE_FILE_TOOL, EXECUTE_SHELL_TOOL, + FILE_DOWNLOAD_TOOL, FILE_UPLOAD_TOOL, KNOWLEDGE_BASE_QUERY_TOOL, LLM_SAFETY_MODE_SYSTEM_PROMPT, @@ -492,6 +493,7 @@ class InternalAgentSubStage(Stage): req.func_tool.add_tool(EXECUTE_SHELL_TOOL) req.func_tool.add_tool(PYTHON_TOOL) req.func_tool.add_tool(FILE_UPLOAD_TOOL) + req.func_tool.add_tool(FILE_DOWNLOAD_TOOL) async def process( self, event: AstrMessageEvent, provider_wake_prefix: str diff --git a/astrbot/core/pipeline/process_stage/utils.py b/astrbot/core/pipeline/process_stage/utils.py index 4826e9695..dc60c8025 100644 --- a/astrbot/core/pipeline/process_stage/utils.py +++ b/astrbot/core/pipeline/process_stage/utils.py @@ -10,6 +10,7 @@ from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.sandbox.tools import ( CreateFileTool, ExecuteShellTool, + FileDownloadTool, FileUploadTool, PythonTool, ReadFileTool, @@ -150,6 +151,7 @@ READ_FILE_TOOL = ReadFileTool() EXECUTE_SHELL_TOOL = ExecuteShellTool() PYTHON_TOOL = PythonTool() FILE_UPLOAD_TOOL = FileUploadTool() +FILE_DOWNLOAD_TOOL = FileDownloadTool() # we prevent astrbot from connecting to known malicious hosts # these hosts are base64 encoded diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 2e529bb1d..0a2c4b8a5 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -94,7 +94,7 @@ class WebChatMessageEvent(AstrMessageEvent): filename = f"{uuid.uuid4()!s}{ext}" dest_path = os.path.join(imgs_dir, filename) shutil.copy2(file_path, dest_path) - data = f"[FILE]{filename}|{original_name}" + data = f"[FILE]{filename}" await web_chat_back_queue.put( { "type": "file", diff --git a/astrbot/core/sandbox/booters/base.py b/astrbot/core/sandbox/booters/base.py index cee5644db..fb2d31bdf 100644 --- a/astrbot/core/sandbox/booters/base.py +++ b/astrbot/core/sandbox/booters/base.py @@ -22,6 +22,10 @@ class SandboxBooter: """ ... + async def download_file(self, remote_path: str, local_path: str): + """Download file from sandbox.""" + ... + async def available(self) -> bool: """Check if the sandbox is available.""" ... diff --git a/astrbot/core/sandbox/booters/shipyard.py b/astrbot/core/sandbox/booters/shipyard.py index 9cd6ce15a..5ca81af23 100644 --- a/astrbot/core/sandbox/booters/shipyard.py +++ b/astrbot/core/sandbox/booters/shipyard.py @@ -1,5 +1,3 @@ -import uuid - from shipyard import ShipyardClient, Spec from astrbot.api import logger @@ -23,12 +21,11 @@ class ShipyardBooter(SandboxBooter): self._session_num = session_num async def boot(self, session_id: str) -> None: - uuid_str = uuid.uuid5(uuid.NAMESPACE_DNS, session_id).hex ship = await self._sandbox_client.create_ship( ttl=self._ttl, spec=Spec(cpus=1.0, memory="512m"), max_session_num=self._session_num, - session_id=uuid_str, + session_id=session_id, ) logger.info(f"Got sandbox ship: {ship.id} for session: {session_id}") self._ship = ship @@ -52,6 +49,10 @@ class ShipyardBooter(SandboxBooter): """Upload file to sandbox""" return await self._ship.upload_file(path, file_name) + async def download_file(self, remote_path: str, local_path: str): + """Download file from sandbox.""" + return await self._ship.download_file(remote_path, local_path) + async def available(self) -> bool: """Check if the sandbox is available.""" try: diff --git a/astrbot/core/sandbox/tools/__init__.py b/astrbot/core/sandbox/tools/__init__.py index 0ff6d7699..2b22479be 100644 --- a/astrbot/core/sandbox/tools/__init__.py +++ b/astrbot/core/sandbox/tools/__init__.py @@ -1,4 +1,4 @@ -from .fs import CreateFileTool, FileUploadTool, ReadFileTool +from .fs import CreateFileTool, FileDownloadTool, FileUploadTool, ReadFileTool from .python import PythonTool from .shell import ExecuteShellTool @@ -8,4 +8,5 @@ __all__ = [ "FileUploadTool", "PythonTool", "ExecuteShellTool", + "FileDownloadTool", ] diff --git a/astrbot/core/sandbox/tools/fs.py b/astrbot/core/sandbox/tools/fs.py index 66fa8d303..3214abfe7 100644 --- a/astrbot/core/sandbox/tools/fs.py +++ b/astrbot/core/sandbox/tools/fs.py @@ -3,9 +3,12 @@ import os from dataclasses import dataclass, field from astrbot.api import FunctionTool, logger +from astrbot.api.event import MessageChain from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.tool import ToolExecResult from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.message.components import File +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from ..sandbox_client import get_booter @@ -130,3 +133,58 @@ class FileUploadTool(FunctionTool): except Exception as e: logger.error(f"Error uploading file {local_path}: {e}") return f"Error uploading file: {str(e)}" + + +@dataclass +class FileDownloadTool(FunctionTool): + name: str = "astrbot_download_file" + description: str = "Download a file from the sandbox. Only call this when user explicitly need you to download a file." + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "remote_path": { + "type": "string", + "description": "The path of the file in the sandbox to download.", + } + }, + "required": ["remote_path"], + } + ) + + async def call( + self, + context: ContextWrapper[AstrAgentContext], + remote_path: str, + ) -> ToolExecResult: + sb = await get_booter( + context.context.context, + context.context.event.unified_msg_origin, + ) + try: + name = os.path.basename(remote_path) + + local_path = os.path.join(get_astrbot_temp_path(), name) + + # Download file from sandbox + await sb.download_file(remote_path, local_path) + logger.info(f"File {remote_path} downloaded from sandbox to {local_path}") + + try: + name = os.path.basename(local_path) + await context.context.event.send( + MessageChain(chain=[File(name=name, file=local_path)]) + ) + except Exception as e: + logger.error(f"Error sending file message: {e}") + + # remove + try: + os.remove(local_path) + except Exception as e: + logger.error(f"Error removing temp file {local_path}: {e}") + + return f"File downloaded successfully to {local_path}" + except Exception as e: + logger.error(f"Error downloading file {remote_path}: {e}") + return f"Error downloading file: {str(e)}" diff --git a/pyproject.toml b/pyproject.toml index 89218c3f1..923deab15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ dependencies = [ "markitdown-no-magika[docx,xls,xlsx]>=0.1.2", "xinference-client", "tenacity>=9.1.2", - "shipyard-python-sdk>=0.2.3", + "shipyard-python-sdk>=0.2.4", ] [dependency-groups]