feat: implement singleton pattern for ShipyardSandboxClient and add FileUploadTool for file uploads
This commit is contained in:
@@ -22,6 +22,11 @@ class ShipyardSandboxClient:
|
||||
)
|
||||
ShipyardSandboxClient._initialized = True
|
||||
|
||||
def __new__(cls) -> "ShipyardSandboxClient":
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
|
||||
class ShipyardBooter(SandboxBooter):
|
||||
def __init__(self):
|
||||
|
||||
@@ -1,131 +0,0 @@
|
||||
import os
|
||||
|
||||
import astrbot.api.message_components as Comp
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.all import Context
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.core.utils.session_waiter import (
|
||||
SessionController,
|
||||
session_waiter,
|
||||
)
|
||||
|
||||
from ..sandbox_client import SandboxClient
|
||||
|
||||
|
||||
class FileCommand:
|
||||
def __init__(self, context: Context) -> None:
|
||||
self.context = context
|
||||
self.user_file_uploads: dict[str, list[str]] = {} # umo -> file_path
|
||||
self.user_file_uploaded_files: dict[str, list[str]] = {} # umo -> file_path
|
||||
"""记录用户上传过的文件,保存了文件在沙箱中的路径。这个在用户下一次 LLM 请求时会被用到,然后清空。"""
|
||||
|
||||
async def _upload_file_to_sandbox(self, event: AstrMessageEvent) -> list[str]:
|
||||
"""将用户上传的文件上传到沙箱"""
|
||||
sender_id = event.get_sender_id()
|
||||
sb = await SandboxClient().get_booter(event.unified_msg_origin)
|
||||
fpath_ls = self.user_file_uploads[sender_id]
|
||||
errors = []
|
||||
for path in fpath_ls:
|
||||
try:
|
||||
fname = os.path.basename(path)
|
||||
data = await sb.upload_file(path, fname)
|
||||
success = data.get("success", False)
|
||||
if not success:
|
||||
raise Exception(f"Upload failed: {data}")
|
||||
file_path = data.get("file_path", "")
|
||||
logger.info(f"File {fname} uploaded to sandbox at {file_path}")
|
||||
self.user_file_uploaded_files.setdefault(sender_id, []).append(
|
||||
file_path
|
||||
)
|
||||
except Exception as e:
|
||||
errors.append((path, str(e)))
|
||||
logger.error(f"Error uploading file {path}: {e}")
|
||||
|
||||
# clean up files
|
||||
for path in fpath_ls:
|
||||
try:
|
||||
os.remove(path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing temp file {path}: {e}")
|
||||
|
||||
return errors
|
||||
|
||||
async def file(self, event: AstrMessageEvent):
|
||||
"""等待用户上传文件或图片"""
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
f"请上传一个或多个文件(或图片),使用 /endupload 结束上传。(请求者 ID: {event.get_sender_id()})"
|
||||
)
|
||||
)
|
||||
try:
|
||||
|
||||
@session_waiter(timeout=600, record_history_chains=False) # type: ignore
|
||||
async def empty_mention_waiter(
|
||||
controller: SessionController, event: AstrMessageEvent
|
||||
):
|
||||
idiom = event.message_str
|
||||
sender_id = event.get_sender_id()
|
||||
|
||||
if idiom == "endupload":
|
||||
files = self.user_file_uploads.get(sender_id, [])
|
||||
if not files:
|
||||
await event.send(
|
||||
event.plain_result("你没有上传任何文件,上传已取消。")
|
||||
)
|
||||
controller.stop()
|
||||
return
|
||||
await event.send(
|
||||
event.plain_result(f"开始上传 {len(files)} 个文件到沙箱...")
|
||||
)
|
||||
errors = await self._upload_file_to_sandbox(event)
|
||||
if errors:
|
||||
error_msgs = "\n".join(
|
||||
[f"{path}: {err}" for path, err in errors]
|
||||
)
|
||||
await event.send(
|
||||
event.plain_result(
|
||||
f"上传中出现错误:\n{error_msgs}\n其他文件已成功上传。"
|
||||
)
|
||||
)
|
||||
else:
|
||||
await event.send(
|
||||
event.plain_result(
|
||||
f"上传完毕,共上传 {len(files)} 个文件。文件信息已被保存,下一次 LLM 请求时会自动将信息附上。"
|
||||
)
|
||||
)
|
||||
self.user_file_uploads.pop(sender_id, None)
|
||||
controller.stop()
|
||||
return
|
||||
|
||||
# 解析文件或图片消息
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, (Comp.File, Comp.Image)):
|
||||
if isinstance(comp, Comp.File):
|
||||
path = await comp.get_file()
|
||||
self.user_file_uploads.setdefault(
|
||||
event.get_sender_id(), []
|
||||
).append(path)
|
||||
elif isinstance(comp, Comp.Image):
|
||||
path = await comp.convert_to_file_path()
|
||||
self.user_file_uploads.setdefault(
|
||||
event.get_sender_id(), []
|
||||
).append(path)
|
||||
fname = os.path.basename(path)
|
||||
await event.send(
|
||||
event.plain_result(
|
||||
f"已接收文件: {fname},继续上传或发送 /endupload 结束。"
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
await empty_mention_waiter(event)
|
||||
except TimeoutError as _:
|
||||
await event.send(event.plain_result("等待上传超时,上传已取消。"))
|
||||
except Exception as e:
|
||||
await event.send(
|
||||
event.plain_result("发生错误,请联系管理员: " + str(e))
|
||||
)
|
||||
finally:
|
||||
event.stop_event()
|
||||
except Exception as e:
|
||||
logger.error("handle_empty_mention error: " + str(e))
|
||||
@@ -1,12 +1,9 @@
|
||||
import os
|
||||
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api import AstrBotConfig, logger
|
||||
from astrbot.api.event import AstrMessageEvent, filter
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
from astrbot.api import AstrBotConfig
|
||||
|
||||
from .commands.file import FileCommand
|
||||
from .tools.fs import CreateFileTool, ReadFileTool
|
||||
from .tools.fs import CreateFileTool, FileUploadTool, ReadFileTool
|
||||
from .tools.python import PythonTool
|
||||
from .tools.shell import ExecuteShellTool
|
||||
|
||||
@@ -24,25 +21,9 @@ class Main(star.Star):
|
||||
os.environ["SHIPYARD_ACCESS_TOKEN"] = self.access_token
|
||||
|
||||
context.add_llm_tools(
|
||||
CreateFileTool(), ExecuteShellTool(), PythonTool(), ReadFileTool()
|
||||
CreateFileTool(),
|
||||
ExecuteShellTool(),
|
||||
PythonTool(),
|
||||
ReadFileTool(),
|
||||
FileUploadTool(),
|
||||
)
|
||||
|
||||
self.file_c = FileCommand(context)
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@filter.command("fileupload")
|
||||
async def fileupload(self, event: AstrMessageEvent):
|
||||
"""处理文件上传"""
|
||||
await self.file_c.file(event)
|
||||
|
||||
@filter.on_llm_request()
|
||||
async def on_llm_request(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
"""处理 LLM 请求"""
|
||||
sender_id = event.get_sender_id()
|
||||
uploads = self.file_c.user_file_uploaded_files.pop(sender_id, None)
|
||||
if uploads:
|
||||
logger.info(f"Attaching uploaded files for user {sender_id}: {uploads}")
|
||||
|
||||
req.system_prompt = f"{req.system_prompt}\n\n\n# User Uploaded Files: {uploads}"
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from astrbot.api import FunctionTool
|
||||
from astrbot.api import FunctionTool, logger
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
|
||||
from ..sandbox_client import SandboxClient
|
||||
@@ -61,3 +62,58 @@ class ReadFileTool(FunctionTool):
|
||||
return result
|
||||
except Exception as e:
|
||||
return f"Error reading file: {str(e)}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileUploadTool(FunctionTool):
|
||||
name: str = "astrbot_upload_file"
|
||||
description: str = "Upload a local file to the sandbox. The file must exist on the local filesystem."
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"local_path": {
|
||||
"type": "string",
|
||||
"description": "The local file path to upload. This must be an absolute path to an existing file on the local filesystem.",
|
||||
},
|
||||
# "remote_path": {
|
||||
# "type": "string",
|
||||
# "description": "The filename to use in the sandbox. If not provided, file will be saved to the working directory with the same name as the local file.",
|
||||
# },
|
||||
},
|
||||
"required": ["local_path"],
|
||||
}
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
local_path: str,
|
||||
):
|
||||
sb = await SandboxClient().get_booter(event.unified_msg_origin)
|
||||
try:
|
||||
# Check if file exists
|
||||
if not os.path.exists(local_path):
|
||||
return f"Error: File does not exist: {local_path}"
|
||||
|
||||
if not os.path.isfile(local_path):
|
||||
return f"Error: Path is not a file: {local_path}"
|
||||
|
||||
# Use basename if sandbox_filename is not provided
|
||||
remote_path = os.path.basename(local_path)
|
||||
|
||||
# Upload file to sandbox
|
||||
result = await sb.upload_file(local_path, remote_path)
|
||||
logger.debug(f"Upload result: {result}")
|
||||
success = result.get("success", False)
|
||||
|
||||
if not success:
|
||||
return f"Error uploading file: {result.get('message', 'Unknown error')}"
|
||||
|
||||
file_path = result.get("file_path", "")
|
||||
logger.info(f"File {local_path} uploaded to sandbox at {file_path}")
|
||||
|
||||
return f"File uploaded successfully to {file_path}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading file {local_path}: {e}")
|
||||
return f"Error uploading file: {str(e)}"
|
||||
|
||||
@@ -5,7 +5,7 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.message import Message, TextPart
|
||||
from astrbot.core.agent.response import AgentStats
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
@@ -536,6 +536,10 @@ class InternalAgentSubStage(Stage):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
req.image_urls.append(image_path)
|
||||
|
||||
req.extra_user_content_parts.append(
|
||||
TextPart(text=f"[Image Attachment: path {image_path}]")
|
||||
)
|
||||
|
||||
conversation = await self._get_session_conv(event)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
|
||||
@@ -28,7 +28,7 @@
|
||||
<div v-else-if="part.type === 'image' && part.embedded_url" class="image-attachments">
|
||||
<div class="image-attachment">
|
||||
<img :src="part.embedded_url" class="attached-image"
|
||||
@click="$emit('openImagePreview', part.embedded_url)" />
|
||||
@click="openImagePreview(part.embedded_url)" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -178,7 +178,7 @@
|
||||
<div v-else-if="part.type === 'image' && part.embedded_url" class="embedded-images">
|
||||
<div class="embedded-image">
|
||||
<img :src="part.embedded_url" class="bot-embedded-image"
|
||||
@click="$emit('openImagePreview', part.embedded_url)" />
|
||||
@click="openImagePreview(part.embedded_url)" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -289,6 +289,13 @@
|
||||
</v-btn>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 图片预览 Overlay -->
|
||||
<v-overlay v-model="imagePreview.show" class="image-preview-overlay" @click="closeImagePreview">
|
||||
<div class="image-preview-container" @click.stop>
|
||||
<img :src="imagePreview.url" class="preview-image" @click="closeImagePreview" />
|
||||
</div>
|
||||
</v-overlay>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
@@ -351,6 +358,11 @@ export default {
|
||||
content: '',
|
||||
messageIndex: null,
|
||||
position: { top: 0, left: 0 }
|
||||
},
|
||||
// 图片预览
|
||||
imagePreview: {
|
||||
show: false,
|
||||
url: ''
|
||||
}
|
||||
};
|
||||
},
|
||||
@@ -676,7 +688,7 @@ export default {
|
||||
if (!img.hasAttribute('data-click-enabled')) {
|
||||
img.style.cursor = 'pointer';
|
||||
img.setAttribute('data-click-enabled', 'true');
|
||||
img.onclick = () => this.$emit('openImagePreview', img.src);
|
||||
img.onclick = () => this.openImagePreview(img.src);
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -877,6 +889,20 @@ export default {
|
||||
formatTTFT(ttft) {
|
||||
if (!ttft || ttft <= 0) return '';
|
||||
return this.formatDuration(ttft);
|
||||
},
|
||||
|
||||
// 打开图片预览
|
||||
openImagePreview(url) {
|
||||
this.imagePreview.url = url;
|
||||
this.imagePreview.show = true;
|
||||
},
|
||||
|
||||
// 关闭图片预览
|
||||
closeImagePreview() {
|
||||
this.imagePreview.show = false;
|
||||
setTimeout(() => {
|
||||
this.imagePreview.url = '';
|
||||
}, 300);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1268,10 +1294,10 @@ export default {
|
||||
}
|
||||
|
||||
.bot-embedded-image {
|
||||
max-width: 40%;
|
||||
max-width: 55%;
|
||||
width: auto;
|
||||
height: auto;
|
||||
border-radius: 8px;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
transition: transform 0.2s ease;
|
||||
}
|
||||
@@ -1423,12 +1449,14 @@ export default {
|
||||
overflow: hidden;
|
||||
background-color: #eff3f6;
|
||||
margin: 8px 0px;
|
||||
max-width: 300px;
|
||||
transition: max-width 0.1s ease;
|
||||
width: fit-content;
|
||||
min-width: 320px;
|
||||
max-width: 100%;
|
||||
transition: all 0.1s ease;
|
||||
}
|
||||
|
||||
.tool-call-card.expanded {
|
||||
max-width: 100%;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.tool-call-header {
|
||||
@@ -1635,4 +1663,36 @@ export default {
|
||||
font-family: 'Fira Code', 'Consolas', monospace;
|
||||
color: var(--v-theme-primaryText);
|
||||
}
|
||||
|
||||
/* 图片预览样式 */
|
||||
.image-preview-overlay {
|
||||
z-index: 9999;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.image-preview-container {
|
||||
position: relative;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
.preview-image {
|
||||
max-width: 90vw;
|
||||
max-height: 90vh;
|
||||
object-fit: contain;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.3);
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.close-preview-btn {
|
||||
position: fixed;
|
||||
top: 20px;
|
||||
right: 20px;
|
||||
}
|
||||
</style>
|
||||
|
||||
Reference in New Issue
Block a user