diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index d9ec4b41b..480c06909 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -813,7 +813,8 @@ class File(BaseMessageComponent): """下载文件""" download_dir = os.path.join(get_astrbot_data_path(), "temp") os.makedirs(download_dir, exist_ok=True) - file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}") + fname = self.name if self.name else uuid.uuid4().hex + file_path = os.path.join(download_dir, fname) await download_file(self.url, file_path) self.file_ = os.path.abspath(file_path) diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index d1992b6c3..0e78c45aa 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -232,7 +232,9 @@ class AiocqhttpAdapter(Platform): if m["data"].get("url") and m["data"].get("url").startswith("http"): # Lagrange logger.info("guessing lagrange") - file_name = m["data"].get("file_name", "file") + file_name = m["data"].get( + "file_name", m["data"].get("file", "file") + ) abm.message.append(File(name=file_name, url=m["data"]["url"])) else: try: diff --git a/packages/astrbot_agent/commands/file.py b/packages/astrbot_agent/commands/file.py new file mode 100644 index 000000000..110d8d1d0 --- /dev/null +++ b/packages/astrbot_agent/commands/file.py @@ -0,0 +1,129 @@ +import os +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api import logger +from astrbot.api.all import Context +import astrbot.api.message_components as Comp +from astrbot.core.utils.session_waiter import ( + session_waiter, + SessionController, +) +from ..sandbox_client import SandboxClient + + +class FileCommand: + def __init__(self, context: Context) -> None: + self.context = context + self.user_file_uploads: dict[str, list[str]] = {} # umo -> file_path + self.user_file_uploaded_files: dict[str, list[str]] = {} # umo -> file_path + """记录用户上传过的文件,保存了文件在沙箱中的路径。这个在用户下一次 LLM 请求时会被用到,然后清空。""" + + async def _upload_file_to_sandbox(self, event: AstrMessageEvent) -> list[str]: + """将用户上传的文件上传到沙箱""" + sender_id = event.get_sender_id() + sb = await SandboxClient().get_ship(event.unified_msg_origin) + fpath_ls = self.user_file_uploads[sender_id] + errors = [] + for path in fpath_ls: + try: + fname = os.path.basename(path) + data = await sb.upload_file(path, fname) + success = data.get("success", False) + if not success: + raise Exception(f"Upload failed: {data}") + file_path = data.get("file_path", "") + logger.info(f"File {fname} uploaded to sandbox at {file_path}") + self.user_file_uploaded_files.setdefault(sender_id, []).append( + file_path + ) + except Exception as e: + errors.append((path, str(e))) + logger.error(f"Error uploading file {path}: {e}") + + # clean up files + for path in fpath_ls: + try: + os.remove(path) + except Exception as e: + logger.error(f"Error removing temp file {path}: {e}") + + return errors + + async def file(self, event: AstrMessageEvent): + """等待用户上传文件或图片""" + await event.send( + MessageChain().message( + f"请上传一个或多个文件(或图片),使用 /endupload 结束上传。(请求者 ID: {event.get_sender_id()})" + ) + ) + try: + + @session_waiter(timeout=600, record_history_chains=False) # type: ignore + async def empty_mention_waiter( + controller: SessionController, event: AstrMessageEvent + ): + idiom = event.message_str + sender_id = event.get_sender_id() + + if idiom == "endupload": + files = self.user_file_uploads.get(sender_id, []) + if not files: + await event.send( + event.plain_result("你没有上传任何文件,上传已取消。") + ) + controller.stop() + return + await event.send( + event.plain_result(f"开始上传 {len(files)} 个文件到沙箱...") + ) + errors = await self._upload_file_to_sandbox(event) + if errors: + error_msgs = "\n".join( + [f"{path}: {err}" for path, err in errors] + ) + await event.send( + event.plain_result( + f"上传中出现错误:\n{error_msgs}\n其他文件已成功上传。" + ) + ) + else: + await event.send( + event.plain_result( + f"上传完毕,共上传 {len(files)} 个文件。文件信息已被保存,下一次 LLM 请求时会自动将信息附上。" + ) + ) + self.user_file_uploads.pop(sender_id, None) + controller.stop() + return + + # 解析文件或图片消息 + for comp in event.message_obj.message: + if isinstance(comp, (Comp.File, Comp.Image)): + if isinstance(comp, Comp.File): + path = await comp.get_file() + self.user_file_uploads.setdefault( + event.get_sender_id(), [] + ).append(path) + elif isinstance(comp, Comp.Image): + path = await comp.convert_to_file_path() + self.user_file_uploads.setdefault( + event.get_sender_id(), [] + ).append(path) + fname = os.path.basename(path) + await event.send( + event.plain_result( + f"已接收文件: {fname},继续上传或发送 /endupload 结束。" + ) + ) + + try: + await empty_mention_waiter(event) + except TimeoutError as _: + await event.send(event.plain_result("等待上传超时,上传已取消。")) + except Exception as e: + await event.send( + event.plain_result("发生错误,请联系管理员: " + str(e)) + ) + finally: + event.stop_event() + except Exception as e: + logger.error("handle_empty_mention error: " + str(e)) diff --git a/packages/astrbot_agent/main.py b/packages/astrbot_agent/main.py index 52d589416..89bae42a6 100644 --- a/packages/astrbot_agent/main.py +++ b/packages/astrbot_agent/main.py @@ -1,9 +1,13 @@ import os import astrbot.api.star as star +from astrbot.api import logger +from astrbot.api.event import filter, AstrMessageEvent +from astrbot.api.provider import ProviderRequest from astrbot.api import AstrBotConfig -from .tools.fs import CreateFileTool +from .tools.fs import CreateFileTool, ReadFileTool from .tools.shell import ExecuteShellTool from .tools.python import PythonTool +from .commands.file import FileCommand class Main(star.Star): @@ -17,7 +21,26 @@ class Main(star.Star): os.environ["SHIPYARD_ENDPOINT"] = self.endpoint os.environ["SHIPYARD_ACCESS_TOKEN"] = self.access_token - context.add_llm_tool(CreateFileTool(), ExecuteShellTool(), PythonTool()) + context.add_llm_tool( + CreateFileTool(), ExecuteShellTool(), PythonTool(), ReadFileTool() + ) + + self.file_c = FileCommand(context) async def initialize(self): pass + + @filter.command("fileupload") + async def fileupload(self, event: AstrMessageEvent): + """处理文件上传""" + await self.file_c.file(event) + + @filter.on_llm_request() + async def on_llm_request(self, event: AstrMessageEvent, req: ProviderRequest): + """处理 LLM 请求""" + sender_id = event.get_sender_id() + uploads = self.file_c.user_file_uploaded_files.pop(sender_id, None) + if uploads: + logger.info(f"Attaching uploaded files for user {sender_id}: {uploads}") + + req.system_prompt = f"{req.system_prompt}\n\n\n# User Uploaded Files: {uploads}" diff --git a/packages/astrbot_agent/tools/fs.py b/packages/astrbot_agent/tools/fs.py index 1bb6d987d..b404481d9 100644 --- a/packages/astrbot_agent/tools/fs.py +++ b/packages/astrbot_agent/tools/fs.py @@ -33,3 +33,28 @@ class CreateFileTool(FunctionTool): return json.dumps(result) except Exception as e: return f"Error creating file: {str(e)}" + +@dataclass +class ReadFileTool(FunctionTool): + name: str = "astrbot_read_file" + description: str = "Read the content of a file in the sandbox." + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "The path of the file to read, relative to the sandbox root. Must not use absolute paths or traverse outside the sandbox.", + }, + }, + "required": ["path"], + } + ) + + async def run(self, event: AstrMessageEvent, path: str): + sb = await SandboxClient().get_ship(event.unified_msg_origin) + try: + result = await sb.fs.read_file(path) + return result + except Exception as e: + return f"Error reading file: {str(e)}" diff --git a/pyproject.toml b/pyproject.toml index 78d5c775e..e7a2e49a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ dependencies = [ "wechatpy>=1.8.18", "audioop-lts ; python_full_version >= '3.13'", "click>=8.2.1", - "shipyard-python-sdk>=0.1.0", + "shipyard-python-sdk>=0.2.3", ] [project.scripts]