diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index ddb285cc9..1e6948355 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -53,7 +53,7 @@ class AstrBotCoreLifecycle: ) self.plugin_manager = PluginManager(self.star_context, self.astrbot_config) - self.plugin_manager.reload() + await self.plugin_manager.reload() '''扫描、注册插件、实例化插件类''' await self.provider_manager.initialize() diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 717a0f81b..50e81c86d 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -54,6 +54,7 @@ class ComponentType(Enum): CardImage = "CardImage" TTS = "TTS" Unknown = "Unknown" + File = "File" class BaseMessageComponent(BaseModel): @@ -415,6 +416,17 @@ class Unknown(BaseMessageComponent): def toString(self): return "" +class File(BaseMessageComponent): + ''' + 目前此消息段只适配了 Napcat。 + ''' + type: ComponentType = "File" + name: T.Optional[str] = "" # 名字 + file: T.Optional[str] = "" # url(本地路径) + + def __init__(self, name: str, file: str): + super().__init__(name=name, file=file) + ComponentTypes = { "plain": Plain, @@ -441,5 +453,6 @@ ComponentTypes = { "json": Json, "cardimage": CardImage, "tts": TTS, - "unknown": Unknown + "unknown": Unknown, + 'file': File, } diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index a1ab3ddef..e8b3db18a 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -43,7 +43,10 @@ class ProcessStage(Stage): yield # 调用提供商相关请求 - if self.ctx.astrbot_config['provider_settings'].get('enable', True) and not event._has_send_oper: + if not self.ctx.astrbot_config['provider_settings'].get('enable', True): + return + + if not event._has_send_oper and event.is_at_or_wake_command: if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result(): provider = self.ctx.plugin_manager.context.get_using_provider() match provider.meta().type: diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index b1aa6059d..253fd9bf4 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -47,6 +47,7 @@ class WakingCheckStage(Stage): # 如果是群聊,且第一个消息段是 At 消息,但不是 At 机器人或 At 全体成员,则不唤醒 break is_wake = True + event.is_at_or_wake_command = True event.is_wake = True event.message_str = event.message_str[len(wake_prefix) :].strip() break @@ -60,11 +61,13 @@ class WakingCheckStage(Stage): is_wake = True event.is_wake = True wake_prefix = "" + event.is_at_or_wake_command = True break # 检查是否是私聊 if event.is_private_chat(): is_wake = True event.is_wake = True + event.is_at_or_wake_command = True wake_prefix = "" # 检查插件的 handler filter diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 65b06149e..a743610a3 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -35,7 +35,8 @@ class AstrMessageEvent(abc.ABC): self.platform_meta = platform_meta self.session_id = session_id self.role = "member" - self.is_wake = False + self.is_wake = False # 是否通过 WakingStage + self.is_at_or_wake_command = False # 是否是 At 机器人或者带有唤醒词或者是私聊(事件监听器会让 is_wake 设为 True) self._extras = {} self.session = MessageSesion( platform_name=platform_meta.name, diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index fa7fca81d..3d1439a98 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -1,3 +1,4 @@ +import os import time import asyncio import logging @@ -5,12 +6,13 @@ from typing import Awaitable, Any from aiocqhttp import CQHttp, Event from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata from astrbot.api.event import MessageChain -from .aiocqhttp_message_event import * -from astrbot.api.message_components import * +from .aiocqhttp_message_event import * # noqa: F403 +from astrbot.api.message_components import * # noqa: F403 from astrbot.api import logger from .aiocqhttp_message_event import AiocqhttpMessageEvent from astrbot.core.platform.astr_message_event import MessageSesion from ...register import register_platform_adapter +from aiocqhttp.exceptions import ActionFailed @register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。") class AiocqhttpAdapter(Platform): @@ -42,7 +44,7 @@ class AiocqhttpAdapter(Platform): await self.bot.send_private_msg(user_id=session.session_id, message=ret) await super().send_by_session(session, message_chain) - def convert_message(self, event: Event) -> AstrBotMessage: + async def convert_message(self, event: Event) -> AstrBotMessage: abm = AstrBotMessage() abm.self_id = str(event.self_id) abm.tag = "aiocqhttp" @@ -78,7 +80,25 @@ class AiocqhttpAdapter(Platform): a = None if t == 'text': message_str += m['data']['text'].strip() - a = ComponentTypes[t](**m['data']) + elif t == 'file': + try: + # Napcat, LLBot + ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id']) + if not ret.get('file', None): + raise ValueError(f"无法解析文件响应: {ret}") + if not os.path.exists(ret['file']): + raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。") + + m['data'] = { + "file": ret['file'], + "name": ret['file_name'] + } + except ActionFailed as e: + logger.error(f"获取文件失败: {e},此消息段将被忽略。") + except BaseException as e: + logger.error(f"获取文件失败: {e},此消息段将被忽略。") + + a = ComponentTypes[t](**m['data']) # noqa: F405 abm.message.append(a) abm.timestamp = int(time.time()) abm.message_str = message_str @@ -91,13 +111,13 @@ class AiocqhttpAdapter(Platform): self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180) @self.bot.on_message('group') async def group(event: Event): - abm = self.convert_message(event) + abm = await self.convert_message(event) if abm: await self.handle_msg(abm) @self.bot.on_message('private') async def private(event: Event): - abm = self.convert_message(event) + abm = await self.convert_message(event) if abm: await self.handle_msg(abm) diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index 5c4a992df..bd077a078 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -31,11 +31,13 @@ class QQOfficialMessageEvent(AstrMessageEvent): if image_base64: media = await self.upload_group_and_c2c_image(image_base64, 1, group_openid=source.group_openid) payload['media'] = media + payload['msg_type'] = 7 await self.bot.api.post_group_message(group_openid=source.group_openid, **payload) case botpy.message.C2CMessage: if image_base64: media = await self.upload_group_and_c2c_image(image_base64, 1, openid=source.author.user_openid) payload['media'] = media + payload['msg_type'] = 7 await self.bot.api.post_c2c_message(openid=source.author.user_openid, **payload) case botpy.message.Message: if image_path: @@ -73,9 +75,9 @@ class QQOfficialMessageEvent(AstrMessageEvent): plain_text += i.text elif isinstance(i, Image) and not image_base64: if i.file and i.file.startswith("file:///"): - image_base64 = file_to_base64(i.file[8:]) + image_base64 = file_to_base64(i.file[8:]).replace("base64://", "") image_file_path = i.file[8:] elif i.file and i.file.startswith("http"): image_file_path = await download_image_by_url(i.file) - image_base64 = file_to_base64(image_file_path) + image_base64 = file_to_base64(image_file_path).replace("base64://", "") return plain_text, image_base64, image_file_path \ No newline at end of file diff --git a/astrbot/core/platform/sources/vchat/vchat_platform_adapter.py b/astrbot/core/platform/sources/vchat/vchat_platform_adapter.py index bfc03c013..f224ecf8f 100644 --- a/astrbot/core/platform/sources/vchat/vchat_platform_adapter.py +++ b/astrbot/core/platform/sources/vchat/vchat_platform_adapter.py @@ -2,6 +2,7 @@ import sys import time import uuid import asyncio +import os from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata from astrbot.api.event import MessageChain @@ -62,7 +63,7 @@ class VChatPlatformAdapter(Platform): self.start_time = int(time.time()) return self._run() - + async def _run(self): await self.client.init() await self.client.auto_login(hot_reload=True, enable_cmd_qr=True) diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index 4ee1897e6..fa797eee6 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -1,4 +1,3 @@ -import base64 from typing import List from .. import Provider from ..entites import LLMResponse diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index a46c05ae3..f0a86b0c0 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -185,7 +185,7 @@ def register_llm_tool(name: str = None): "description": arg.description }) md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent) - llm_tools.add_func(llm_tool_name, args, docstring.short_description, md.handler) + llm_tools.add_func(llm_tool_name, args, docstring.description, md.handler) logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册") return awaitable diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 8d33eb3ab..b62854c7d 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -137,7 +137,7 @@ class PluginManager: return metadata - def reload(self): + async def reload(self): '''扫描并加载所有的 Star''' for smd in star_registry: logger.debug(f"尝试终止插件 {smd.name} ...") @@ -231,6 +231,10 @@ class PluginManager: if metadata.module_path in inactivated_plugins: metadata.activated = False + # 执行 initialize 函数 + if hasattr(metadata.star_cls, "initialize"): + await metadata.star_cls.initialize() + except BaseException as e: traceback.print_exc() fail_rec += f"加载 {path} 插件时出现问题,原因 {str(e)}\n" @@ -247,7 +251,7 @@ class PluginManager: async def install_plugin(self, repo_url: str): plugin_path = await self.updator.install(repo_url) # reload the plugin - self.reload() + await self.reload() return plugin_path async def uninstall_plugin(self, plugin_name: str): @@ -288,7 +292,7 @@ class PluginManager: raise Exception("该插件是 AstrBot 保留插件,无法更新。") await self.updator.update(plugin) - self.reload() + await self.reload() async def turn_off_plugin(self, plugin_name: str): plugin = self.context.get_registered_star(plugin_name) diff --git a/astrbot/core/utils/param_validation_mixin.py b/astrbot/core/utils/param_validation_mixin.py index 5c72b8961..896d5bc56 100644 --- a/astrbot/core/utils/param_validation_mixin.py +++ b/astrbot/core/utils/param_validation_mixin.py @@ -22,6 +22,9 @@ class ParameterValidationMixin: result[param_name] = int(params[i]) else: result[param_name] = params[i] + elif isinstance(param_type_or_default_val, str): + # 如果 param_type_or_default_val 是字符串,直接赋值 + result[param_name] = params[i] else: result[param_name] = param_type_or_default_val(params[i]) except ValueError: diff --git a/packages/python_interpreter/main.py b/packages/python_interpreter/main.py new file mode 100644 index 000000000..9fca65fba --- /dev/null +++ b/packages/python_interpreter/main.py @@ -0,0 +1,382 @@ +import os +import json +import shutil +import aiohttp +import uuid +import asyncio +import re +import astrbot.api.star as star +import aiodocker +from collections import defaultdict +from astrbot.api.event import AstrMessageEvent, MessageEventResult +from astrbot.api import llm_tool, logger +from astrbot.api.event import filter +from astrbot.api.provider import ProviderRequest +from astrbot.api.message_components import Image, File + +PROMPT = """ +## Task +You need to generate python codes to solve user's problem: {prompt} + +{extra_input} + +## Limit +1. Available libraries: + - standard libs + - `Pillow` + - `requests` + - `numpy` + - `matplotlib` + - `scipy` + - `scikit-learn` + - `beautifulsoup4` + - `pandas` + - `opencv-python` + - `python-docx` + - `python-pptx` + - `pymupdf` (Do not use fpdf, reportlab, etc.) + - `mplfonts` + You can only use these libraries and the libraries that they depend on. +2. Do not generate malicious code. +3. Use given `shared.api` package to output the result. + It has 3 functions: `send_text(text: str)`, `send_image(image_path: str)`, `send_file(file_path: str)`. + For Image and file, you must save it to `output` folder. +4. You must only output the code, do not output the result of the code and any other information. +5. The output language is same as user's input language. +6. Please first provide relevant knowledge about user's problem appropriately. + +## Example +1. User's problem: `please solve the fabonacci sequence problem.` +Output: +```python +from shared.api import send_text, send_image, send_file + +def fabonacci(n): + if n <= 1: + return n + else: + return fabonacci(n-1) + fabonacci(n-2) + +result = fabonacci(10) +send_text("The fabonacci sequence is a series of numbers in which each number is the sum of the two preceding ones, starting from 0 and 1.") +send_text("Let's calculate the fabonacci sequence of 10: " + result) # send_text is a function to send pure text to user +``` + +2. User's problem: `please draw a sin(x) function.` +Output: +```python +from shared.api import send_text, send_image, send_file +import numpy as np +import matplotlib.pyplot as plt + +x = np.linspace(0, 2*np.pi, 100) +y = np.sin(x) +plt.plot(x, y) +plt.savefig("output/sin_x.png") +send_text("The sin(x) is a periodic function with a period of 2π, and the value range is [-1, 1]. The following is the image of sin(x).") +send_image("output/sin_x.png") # send_image is a function to send image to user +send_text("If you need more information, please let me know :)") +``` + +{extra_prompt} +""" + +DEFAULT_CONFIG = { + "sandbox": { + "image": "soulter/astrbot-code-interpreter-sandbox", + "docker_mirror": "", # cjie.eu.org + } +} +PATH = "data/config/python_interpreter.json" + +@star.register(name="astrbot-python-interpreter", desc="Python 代码执行器", author="Soulter", version="0.0.1") +class Main(star.Star): + '''基于 Docker 沙箱的 Python 代码执行器''' + def __init__(self, context: star.Context) -> None: + self.context = context + self.curr_dir = os.path.dirname(os.path.abspath(__file__)) + self.workplace_path = os.path.join(self.curr_dir, "workplace") + self.shared_path = os.path.join(self.curr_dir, "shared") + os.makedirs(self.workplace_path, exist_ok=True) + + self.user_file_msg_buffer = defaultdict(list) + '''存放用户上传的文件''' + + # 加载配置 + if not os.path.exists(PATH): + self.config = DEFAULT_CONFIG + self._save_config() + else: + with open(PATH, "r") as f: + self.config = json.load(f) + + async def initialize(self): + ok = await self.is_docker_available() + if not ok: + logger.warning("Docker 不可用,代码解释器将无法使用,astrbot-python-interpreter 将自动禁用。") + await self.context._star_manager.turn_off_plugin("astrbot-python-interpreter") + + async def file_upload(self, file_path: str): + ''' + 上传图像文件到 S3 + ''' + ext = os.path.splitext(file_path)[1] + S3_URL = "https://s3.neko.soulter.top/astrbot-s3" + with open(file_path, "rb") as f: + file = f.read() + + s3_file_url = f"{S3_URL}/{uuid.uuid4().hex}{ext}" + + async with aiohttp.ClientSession(headers = {"Accept": "application/json"}) as session: + async with session.put(s3_file_url, data=file) as resp: + if resp.status != 200: + raise Exception(f"Failed to upload image: {resp.status}") + return s3_file_url + + + async def is_docker_available(self) -> bool: + '''Check if docker is available''' + try: + docker = aiodocker.Docker() + await docker.version() + return True + except aiodocker.exceptions.DockerError as e: + logger.error(f"检查 Docker 可用性时出现问题: {e}") + return False + + async def get_image_name(self) -> str: + '''Get the image name''' + if self.config["sandbox"]["docker_mirror"]: + return f"{self.config['sandbox']['docker_mirror']}/{self.config['sandbox']['image']}" + return self.config["sandbox"]["image"] + + async def _save_config(self): + with open(PATH, "w") as f: + json.dump(self.config, f) + + async def gen_magic_code(self) -> str: + return uuid.uuid4().hex[:8] + + async def download_image(self, image_url: str, workplace_path: str, filename: str) -> str: + '''Download image from url to workplace_path''' + async with aiohttp.ClientSession() as session: + async with session.get(image_url) as resp: + if resp.status != 200: + return "" + image_path = os.path.join(workplace_path, f"{filename}.jpg") + with open(image_path, 'wb') as f: + f.write(await resp.read()) + return f"{filename}.jpg" + + async def tidy_code(self, code: str) -> str: + '''Tidy the code''' + pattern = r"```(?:py|python)?\n(.*?)\n```" + match = re.search(pattern, code, re.DOTALL) + if match is None: + raise ValueError("The code is not in the code block.") + return match.group(1) + + @filter.event_message_type(filter.EventMessageType.ALL) + async def on_message(self, event: AstrMessageEvent): + '''处理消息''' + for comp in event.message_obj.message: + if isinstance(comp, File): + self.user_file_msg_buffer[event.get_session_id()].append(comp.file) + logger.debug(f"User uploaded file: {comp.file}") + break # 一个消息中,文件只能有一个,这里直接 break 减少计算量。 + + @filter.on_llm_request() + async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest): + if event.get_session_id() in self.user_file_msg_buffer: + files = self.user_file_msg_buffer[event.get_session_id()] + request.prompt += f"\nUser provided files: {files}" + + + @filter.command_group("pi") + def pi(self): + pass + + + @pi.command("mirror") + async def pi_mirror(self, event: AstrMessageEvent, url: str = ""): + '''Docker 镜像地址''' + if not url: + yield event.plain_result(f"""当前 Docker 镜像地址: {self.config['sandbox']['docker_mirror']}。 +使用 `pi mirror ` 来设置 Docker 镜像地址。 +您所设置的 Docker 镜像地址将会自动加在 Docker 镜像名前。如: `soulter/astrbot-code-interpreter-sandbox` -> `cjie.eu.org/soulter/astrbot-code-interpreter-sandbox`。 +""") + else: + self.config["sandbox"]["docker_mirror"] = url + await self._save_config() + yield event.plain_result("设置 Docker 镜像地址成功。") + + @pi.command("repull") + async def pi_repull(self, event: AstrMessageEvent): + '''重新拉取沙箱镜像''' + docker = aiodocker.Docker() + image_name = await self.get_image_name() + try: + await docker.images.get(image_name) + await docker.images.delete(image_name, force=True) + except aiodocker.exceptions.DockerError: + pass + await docker.images.pull(image_name) + yield event.plain_result("重新拉取沙箱镜像成功。") + + + @llm_tool("python_interpreter") + async def python_interpreter(self, event: AstrMessageEvent): + '''Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code. + For example, user can use this tool to solve math problems, edit image, docx, pptx, pdf, etc. + ''' + if not await self.is_docker_available(): + yield event.plain_result("Docker 在当前机器不可用,无法沙箱化执行代码。") + + plain_text = event.message_str + + # 创建必要的工作目录和幻术码 + magic_code = await self.gen_magic_code() + workplace_path = os.path.join(self.workplace_path, magic_code) + output_path = os.path.join(workplace_path, "output") + os.makedirs(workplace_path, exist_ok=True) + os.makedirs(output_path, exist_ok=True) + + # 图片 + images = [] + idx = 1 + for comp in event.message_obj.message: + if isinstance(comp, Image): + image_url = comp.url if comp.url else comp.file + if image_url.startswith("http"): + image_path = await self.download_image(image_url, workplace_path, f"img_{idx}") + if image_path: + images.append(image_path) + idx += 1 + # 文件 + files = [] + for file_path in self.user_file_msg_buffer[event.get_session_id()]: + # cp + file_name = os.path.basename(file_path) + shutil.copy(file_path, os.path.join(workplace_path, file_name)) + files.append(file_name) + + logger.debug(f"user query: {plain_text}, images: {images}, files: {files}") + + # 整理额外输入 + extra_inputs = "" + if images: + extra_inputs += f"User provided images: {images}\n" + if files: + extra_inputs += f"User provided files: {files}\n" + + obs = "" + n = 5 + + for i in range(n): + if i > 0: + logger.info(f"Try {i+1}/{n}") + + PROMPT_ = PROMPT.format( + prompt=plain_text, + extra_input=extra_inputs, + extra_prompt=obs, + ) + provider = self.context.get_using_provider() + llm_response = await provider.text_chat(prompt=PROMPT_, session_id=f"{event.session_id}_{magic_code}_{str(i)}") + + logger.debug("code interpreter llm gened code:" + llm_response.completion_text) + + # 整理代码并保存 + code_clean = await self.tidy_code(llm_response.completion_text) + with open(os.path.join(workplace_path, "exec.py"), "w") as f: + f.write(code_clean) + + # 启动容器 + docker = aiodocker.Docker() + + # 检查有没有image + image_name = await self.get_image_name() + try: + await docker.images.get(image_name) + except aiodocker.exceptions.DockerError: + # 拉取镜像 + logger.info(f"未找到沙箱镜像,正在尝试拉取 {image_name}...") + await docker.images.pull(image_name) + + yield event.plain_result(f"使用沙箱执行代码中,请稍等...(尝试次数: {i+1}/{n})") + + container = await docker.containers.run({ + "Image": image_name, + "Cmd": ["python", "exec.py"], + "Memory": 512 * 1024 * 1024, + "NanoCPUs": 1000000000, + "HostConfig": { + "Binds": [ + f"{self.shared_path}:/astrbot_sandbox/shared:ro", + f"{output_path}:/astrbot_sandbox/output:rw", + f"{workplace_path}:/astrbot_sandbox:rw", + ] + }, + "Env": [ + f"MAGIC_CODE={magic_code}" + ], + "AutoRemove": True + }) + + logger.debug(f"Container {container.id} created.") + logs = await self.run_container(container) + + logger.debug(f"Container {container.id} finished.") + logger.debug(f"Container {container.id} logs: {logs}") + + # 发送结果 + pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)" + ok = False + traceback = "" + for idx, log in enumerate(logs): + match = re.match(pattern, log) + if match: + ok = True + if match.group(1) == "TEXT": + yield event.plain_result(match.group(2)) + elif match.group(1) == "IMAGE": + image_path = os.path.join(workplace_path, match.group(2)) + logger.debug(f"Sending image: {image_path}") + yield event.image_result(image_path) + elif match.group(1) == "FILE": + file_path = os.path.join(workplace_path, match.group(2)) + logger.debug(f"Sending file: {file_path}") + file_s3_url = await self.file_upload(file_path) + logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}") + file_name = os.path.basename(file_path) + chain = [File(name=file_name, file=file_s3_url)] + yield event.set_result(MessageEventResult(chain=chain)) + + elif "Traceback (most recent call last)" in log \ + or "[Error]: " in log: + traceback = "\n".join(logs[idx:]) + + if not ok: + if traceback: + obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occured:\n\n{traceback}\n Need to improve/fix the code." + else: + logger.warning(f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}") + break + else: + return + + yield event.plain_result("经过多次尝试后,未从沙箱输出中捕获到合法的输出,请更换问法或者查看日志。") + + + async def run_container(self, container: aiodocker.docker.DockerContainer, timeout: int = 20) -> list[str]: + '''Run the container and get the output''' + try: + await container.wait(timeout=timeout) + logs = await container.log(stdout=True, stderr=True) + return logs + except asyncio.TimeoutError: + logger.warning(f"Container {container.id} timeout.") + await container.kill() + return [f"[Error]: Container has been killed due to timeout ({timeout}s)."] + finally: + await container.delete() \ No newline at end of file diff --git a/packages/python_interpreter/requirements.txt b/packages/python_interpreter/requirements.txt new file mode 100644 index 000000000..e69de29bb diff --git a/packages/python_interpreter/shared/api.py b/packages/python_interpreter/shared/api.py new file mode 100644 index 000000000..9fe27ce67 --- /dev/null +++ b/packages/python_interpreter/shared/api.py @@ -0,0 +1,18 @@ +import os + +def _get_magic_code(): + '''防止注入攻击''' + return os.getenv("MAGIC_CODE") + +def send_text(text: str): + print(f"[ASTRBOT_TEXT_OUTPUT#{_get_magic_code()}]: {text}") + +def send_image(image_path: str): + if not os.path.exists(image_path): + raise Exception(f"Image file not found: {image_path}") + print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}") + +def send_file(file_path: str): + if not os.path.exists(file_path): + raise Exception(f"File not found: {file_path}") + print(f"[ASTRBOT_FILE_OUTPUT#{_get_magic_code()}]: {file_path}") \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c02fc9b0d..965651870 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,4 +15,5 @@ colorlog aiocqhttp pyjwt apscheduler -docstring_parser \ No newline at end of file +docstring_parser +aiodocker \ No newline at end of file diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 4d90fae88..b75c314ee 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -117,9 +117,10 @@ def star_context(event_queue, config, db, platform_manager, provider_manager): return star_context @pytest.fixture(scope="module") -def plugin_manager(star_context, config): +@pytest.mark.asyncio +async def plugin_manager(star_context, config): plugin_manager = PluginManager(star_context, config) - plugin_manager.reload() + await plugin_manager.reload() return plugin_manager @pytest.fixture(scope="module") diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 8d7b35568..1d336a57a 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -27,7 +27,7 @@ def test_plugin_manager_initialization(plugin_manager_pm: PluginManager): @pytest.mark.asyncio async def test_plugin_manager_reload(plugin_manager_pm: PluginManager): - success, err_message = plugin_manager_pm.reload() + success, err_message = await plugin_manager_pm.reload() assert success is True assert err_message is None assert len(star_handlers_registry) > 0 # package