Merge pull request #232 from Soulter/feat-python-interpreter

初步实现代码执行器
This commit is contained in:
Soulter
2025-01-09 15:43:40 +08:00
committed by GitHub
18 changed files with 473 additions and 22 deletions
+1 -1
View File
@@ -53,7 +53,7 @@ class AstrBotCoreLifecycle:
)
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
self.plugin_manager.reload()
await self.plugin_manager.reload()
'''扫描、注册插件、实例化插件类'''
await self.provider_manager.initialize()
+14 -1
View File
@@ -54,6 +54,7 @@ class ComponentType(Enum):
CardImage = "CardImage"
TTS = "TTS"
Unknown = "Unknown"
File = "File"
class BaseMessageComponent(BaseModel):
@@ -415,6 +416,17 @@ class Unknown(BaseMessageComponent):
def toString(self):
return ""
class File(BaseMessageComponent):
    '''
    File message segment.

    Currently this segment is only adapted for Napcat.
    '''
    type: ComponentType = "File"
    name: T.Optional[str] = ""  # display name of the file
    file: T.Optional[str] = ""  # URL (or local path) of the file

    def __init__(self, name: str = "", file: str = "", **_):
        # The defaults mirror the field defaults declared above, so File() and
        # File(file=...) are valid. Unknown keyword arguments are discarded:
        # segments are built with ComponentTypes[t](**m['data']), and a raw
        # protocol payload may carry keys other than `name` and `file`.
        super().__init__(name=name, file=file)
ComponentTypes = {
"plain": Plain,
@@ -441,5 +453,6 @@ ComponentTypes = {
"json": Json,
"cardimage": CardImage,
"tts": TTS,
"unknown": Unknown
"unknown": Unknown,
'file': File,
}
+4 -1
View File
@@ -43,7 +43,10 @@ class ProcessStage(Stage):
yield
# 调用提供商相关请求
if self.ctx.astrbot_config['provider_settings'].get('enable', True) and not event._has_send_oper:
if not self.ctx.astrbot_config['provider_settings'].get('enable', True):
return
if not event._has_send_oper and event.is_at_or_wake_command:
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
provider = self.ctx.plugin_manager.context.get_using_provider()
match provider.meta().type:
@@ -47,6 +47,7 @@ class WakingCheckStage(Stage):
# 如果是群聊,且第一个消息段是 At 消息,但不是 At 机器人或 At 全体成员,则不唤醒
break
is_wake = True
event.is_at_or_wake_command = True
event.is_wake = True
event.message_str = event.message_str[len(wake_prefix) :].strip()
break
@@ -60,11 +61,13 @@ class WakingCheckStage(Stage):
is_wake = True
event.is_wake = True
wake_prefix = ""
event.is_at_or_wake_command = True
break
# 检查是否是私聊
if event.is_private_chat():
is_wake = True
event.is_wake = True
event.is_at_or_wake_command = True
wake_prefix = ""
# 检查插件的 handler filter
+2 -1
View File
@@ -35,7 +35,8 @@ class AstrMessageEvent(abc.ABC):
self.platform_meta = platform_meta
self.session_id = session_id
self.role = "member"
self.is_wake = False
self.is_wake = False # 是否通过 WakingStage
self.is_at_or_wake_command = False # 是否是 At 机器人或者带有唤醒词或者是私聊(事件监听器会让 is_wake 设为 True
self._extras = {}
self.session = MessageSesion(
platform_name=platform_meta.name,
@@ -1,3 +1,4 @@
import os
import time
import asyncio
import logging
@@ -5,12 +6,13 @@ from typing import Awaitable, Any
from aiocqhttp import CQHttp, Event
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from .aiocqhttp_message_event import *
from astrbot.api.message_components import *
from .aiocqhttp_message_event import * # noqa: F403
from astrbot.api.message_components import * # noqa: F403
from astrbot.api import logger
from .aiocqhttp_message_event import AiocqhttpMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from aiocqhttp.exceptions import ActionFailed
@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
class AiocqhttpAdapter(Platform):
@@ -42,7 +44,7 @@ class AiocqhttpAdapter(Platform):
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
await super().send_by_session(session, message_chain)
def convert_message(self, event: Event) -> AstrBotMessage:
async def convert_message(self, event: Event) -> AstrBotMessage:
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.tag = "aiocqhttp"
@@ -78,7 +80,25 @@ class AiocqhttpAdapter(Platform):
a = None
if t == 'text':
message_str += m['data']['text'].strip()
a = ComponentTypes[t](**m['data'])
elif t == 'file':
try:
# Napcat, LLBot
ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id'])
if not ret.get('file', None):
raise ValueError(f"无法解析文件响应: {ret}")
if not os.path.exists(ret['file']):
raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。")
m['data'] = {
"file": ret['file'],
"name": ret['file_name']
}
except ActionFailed as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
except BaseException as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
a = ComponentTypes[t](**m['data']) # noqa: F405
abm.message.append(a)
abm.timestamp = int(time.time())
abm.message_str = message_str
@@ -91,13 +111,13 @@ class AiocqhttpAdapter(Platform):
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
@self.bot.on_message('group')
async def group(event: Event):
abm = self.convert_message(event)
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_message('private')
async def private(event: Event):
abm = self.convert_message(event)
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@@ -31,11 +31,13 @@ class QQOfficialMessageEvent(AstrMessageEvent):
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, group_openid=source.group_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_group_message(group_openid=source.group_openid, **payload)
case botpy.message.C2CMessage:
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, openid=source.author.user_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_c2c_message(openid=source.author.user_openid, **payload)
case botpy.message.Message:
if image_path:
@@ -73,9 +75,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
plain_text += i.text
elif isinstance(i, Image) and not image_base64:
if i.file and i.file.startswith("file:///"):
image_base64 = file_to_base64(i.file[8:])
image_base64 = file_to_base64(i.file[8:]).replace("base64://", "")
image_file_path = i.file[8:]
elif i.file and i.file.startswith("http"):
image_file_path = await download_image_by_url(i.file)
image_base64 = file_to_base64(image_file_path)
image_base64 = file_to_base64(image_file_path).replace("base64://", "")
return plain_text, image_base64, image_file_path
@@ -2,6 +2,7 @@ import sys
import time
import uuid
import asyncio
import os
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
@@ -62,7 +63,7 @@ class VChatPlatformAdapter(Platform):
self.start_time = int(time.time())
return self._run()
async def _run(self):
await self.client.init()
await self.client.auto_login(hot_reload=True, enable_cmd_qr=True)
@@ -1,4 +1,3 @@
import base64
from typing import List
from .. import Provider
from ..entites import LLMResponse
+1 -1
View File
@@ -185,7 +185,7 @@ def register_llm_tool(name: str = None):
"description": arg.description
})
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
llm_tools.add_func(llm_tool_name, args, docstring.short_description, md.handler)
llm_tools.add_func(llm_tool_name, args, docstring.description, md.handler)
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
return awaitable
+7 -3
View File
@@ -137,7 +137,7 @@ class PluginManager:
return metadata
def reload(self):
async def reload(self):
'''扫描并加载所有的 Star'''
for smd in star_registry:
logger.debug(f"尝试终止插件 {smd.name} ...")
@@ -231,6 +231,10 @@ class PluginManager:
if metadata.module_path in inactivated_plugins:
metadata.activated = False
# 执行 initialize 函数
if hasattr(metadata.star_cls, "initialize"):
await metadata.star_cls.initialize()
except BaseException as e:
traceback.print_exc()
fail_rec += f"加载 {path} 插件时出现问题,原因 {str(e)}\n"
@@ -247,7 +251,7 @@ class PluginManager:
async def install_plugin(self, repo_url: str):
plugin_path = await self.updator.install(repo_url)
# reload the plugin
self.reload()
await self.reload()
return plugin_path
async def uninstall_plugin(self, plugin_name: str):
@@ -288,7 +292,7 @@ class PluginManager:
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
await self.updator.update(plugin)
self.reload()
await self.reload()
async def turn_off_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
@@ -22,6 +22,9 @@ class ParameterValidationMixin:
result[param_name] = int(params[i])
else:
result[param_name] = params[i]
elif isinstance(param_type_or_default_val, str):
# 如果 param_type_or_default_val 是字符串,直接赋值
result[param_name] = params[i]
else:
result[param_name] = param_type_or_default_val(params[i])
except ValueError:
+382
View File
@@ -0,0 +1,382 @@
import os
import json
import shutil
import aiohttp
import uuid
import asyncio
import re
import astrbot.api.star as star
import aiodocker
from collections import defaultdict
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api import llm_tool, logger
from astrbot.api.event import filter
from astrbot.api.provider import ProviderRequest
from astrbot.api.message_components import Image, File
# Prompt template sent to the LLM to obtain sandbox code. It is filled with
# str.format() and expects exactly three fields: `prompt` (the user's request),
# `extra_input` (notes about provided images/files) and `extra_prompt`
# (the observation from a previous failed attempt, or an empty string).
# The embedded examples must be valid Python because the model imitates them.
PROMPT = """
## Task
You need to generate python codes to solve user's problem: {prompt}
{extra_input}
## Limit
1. Available libraries:
- standard libs
- `Pillow`
- `requests`
- `numpy`
- `matplotlib`
- `scipy`
- `scikit-learn`
- `beautifulsoup4`
- `pandas`
- `opencv-python`
- `python-docx`
- `python-pptx`
- `pymupdf` (Do not use fpdf, reportlab, etc.)
- `mplfonts`
You can only use these libraries and the libraries that they depend on.
2. Do not generate malicious code.
3. Use given `shared.api` package to output the result.
It has 3 functions: `send_text(text: str)`, `send_image(image_path: str)`, `send_file(file_path: str)`.
For Image and file, you must save it to `output` folder.
4. You must only output the code, do not output the result of the code and any other information.
5. The output language is same as user's input language.
6. Please first provide relevant knowledge about user's problem appropriately.
## Example
1. User's problem: `please solve the fabonacci sequence problem.`
Output:
```python
from shared.api import send_text, send_image, send_file
def fabonacci(n):
    if n <= 1:
        return n
    else:
        return fabonacci(n-1) + fabonacci(n-2)
result = fabonacci(10)
send_text("The fabonacci sequence is a series of numbers in which each number is the sum of the two preceding ones, starting from 0 and 1.")
send_text("Let's calculate the fabonacci sequence of 10: " + str(result)) # send_text is a function to send pure text to user
```
2. User's problem: `please draw a sin(x) function.`
Output:
```python
from shared.api import send_text, send_image, send_file
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 2*np.pi, 100)
y = np.sin(x)
plt.plot(x, y)
plt.savefig("output/sin_x.png")
send_text("The sin(x) is a periodic function with a period of 2π, and the value range is [-1, 1]. The following is the image of sin(x).")
send_image("output/sin_x.png") # send_image is a function to send image to user
send_text("If you need more information, please let me know :)")
```
{extra_prompt}
"""
# Location of the JSON file in which the plugin configuration is persisted.
PATH = "data/config/python_interpreter.json"

# Configuration used when no configuration file exists yet.
DEFAULT_CONFIG = {
    "sandbox": {
        # Docker image that runs the generated code.
        "image": "soulter/astrbot-code-interpreter-sandbox",
        # Optional registry mirror prefixed to the image name, e.g. cjie.eu.org
        "docker_mirror": "",
    },
}
@star.register(name="astrbot-python-interpreter", desc="Python 代码执行器", author="Soulter", version="0.0.1")
class Main(star.Star):
'''基于 Docker 沙箱的 Python 代码执行器'''
def __init__(self, context: star.Context) -> None:
    '''Prepare the working directories and load (or create) the plugin config.'''
    self.context = context
    self.curr_dir = os.path.dirname(os.path.abspath(__file__))
    self.workplace_path = os.path.join(self.curr_dir, "workplace")
    self.shared_path = os.path.join(self.curr_dir, "shared")
    os.makedirs(self.workplace_path, exist_ok=True)
    # Maps a session id to the list of file paths uploaded in that session.
    self.user_file_msg_buffer = defaultdict(list)

    # Load the configuration.
    if os.path.exists(PATH):
        with open(PATH, "r") as f:
            self.config = json.load(f)
    else:
        # A JSON round trip gives a deep copy, so later edits to self.config
        # never mutate the module-level DEFAULT_CONFIG.
        self.config = json.loads(json.dumps(DEFAULT_CONFIG))
        # _save_config() is a coroutine function and cannot be awaited from
        # __init__; calling it without `await` would write nothing. Persist the
        # defaults synchronously instead.
        os.makedirs(os.path.dirname(PATH) or ".", exist_ok=True)
        with open(PATH, "w") as f:
            json.dump(self.config, f)
async def initialize(self):
    '''Turn this plugin off when Docker cannot be used on this machine.'''
    docker_ready = await self.is_docker_available()
    if docker_ready:
        return
    logger.warning("Docker 不可用,代码解释器将无法使用,astrbot-python-interpreter 将自动禁用。")
    await self.context._star_manager.turn_off_plugin("astrbot-python-interpreter")
async def file_upload(self, file_path: str):
    '''
    Upload a local file to the S3 endpoint with an HTTP PUT.

    The object name is a random UUID plus the original file extension.
    Returns the URL the file was uploaded to; raises Exception when the
    server does not answer with status 200.
    '''
    S3_URL = "https://s3.neko.soulter.top/astrbot-s3"
    suffix = os.path.splitext(file_path)[1]
    with open(file_path, "rb") as fp:
        payload = fp.read()
    target_url = f"{S3_URL}/{uuid.uuid4().hex}{suffix}"
    request_headers = {"Accept": "application/json"}
    async with aiohttp.ClientSession(headers=request_headers) as session:
        async with session.put(target_url, data=payload) as resp:
            if resp.status != 200:
                raise Exception(f"Failed to upload image: {resp.status}")
    return target_url
async def is_docker_available(self) -> bool:
    '''Check if docker is available.

    Returns True when a Docker client can be created and the daemon answers
    a version request, False otherwise. The probe client is always closed.
    '''
    docker = None
    try:
        docker = aiodocker.Docker()
        await docker.version()
        return True
    except Exception as e:
        # Deliberately broad: this is a yes/no probe, and an unreachable daemon
        # is presumably reported as a connection-level error rather than as
        # aiodocker.exceptions.DockerError (TODO confirm against aiodocker).
        logger.error(f"检查 Docker 可用性时出现问题: {e}")
        return False
    finally:
        # The client owns an HTTP session; release it instead of leaking one
        # per availability check.
        if docker is not None:
            await docker.close()
async def get_image_name(self) -> str:
    '''Return the sandbox image name, prefixed with the mirror when one is set.'''
    sandbox_cfg = self.config["sandbox"]
    mirror = sandbox_cfg["docker_mirror"]
    if not mirror:
        return sandbox_cfg["image"]
    return f"{mirror}/{sandbox_cfg['image']}"
async def _save_config(self):
    '''Write self.config to PATH as JSON, creating the parent directory if needed.'''
    # open(..., "w") fails with FileNotFoundError when the directory is missing.
    os.makedirs(os.path.dirname(PATH) or ".", exist_ok=True)
    with open(PATH, "w") as f:
        json.dump(self.config, f)
async def gen_magic_code(self) -> str:
    '''Return the first 8 hexadecimal characters of a freshly generated UUID4.'''
    token = uuid.uuid4().hex
    return token[:8]
async def download_image(self, image_url: str, workplace_path: str, filename: str) -> str:
    '''
    Download image_url into workplace_path as "<filename>.jpg".

    Returns the saved file name (not the full path), or an empty string when
    the server answers with a status other than 200.
    '''
    saved_name = f"{filename}.jpg"
    async with aiohttp.ClientSession() as session:
        async with session.get(image_url) as resp:
            if resp.status != 200:
                return ""
            destination = os.path.join(workplace_path, saved_name)
            with open(destination, 'wb') as out:
                out.write(await resp.read())
            return saved_name
async def tidy_code(self, code: str) -> str:
    '''
    Return the contents of the first fenced code block found in `code`.

    The fence may be bare or tagged `py` / `python`. Raises ValueError when
    no such block exists.
    '''
    fence = re.compile(r"```(?:py|python)?\n(.*?)\n```", re.DOTALL)
    found = fence.search(code)
    if found is not None:
        return found.group(1)
    raise ValueError("The code is not in the code block.")
@filter.event_message_type(filter.EventMessageType.ALL)
async def on_message(self, event: AstrMessageEvent):
    '''处理消息'''
    # Remember the file attached to this message (if any) under the sender's
    # session id. Only the first File segment is recorded: a message is
    # expected to carry at most one file, so the scan stops at the first hit.
    # The docstring above is kept verbatim because the framework may read
    # handler docstrings at runtime.
    uploaded = next(
        (seg for seg in event.message_obj.message if isinstance(seg, File)),
        None,
    )
    if uploaded is None:
        return
    self.user_file_msg_buffer[event.get_session_id()].append(uploaded.file)
    logger.debug(f"User uploaded file: {uploaded.file}")
@filter.on_llm_request()
async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest):
    # Tell the LLM which files the user uploaded in this session.
    # The buffer is a defaultdict, so a session key can exist with an empty
    # list (indexing it for a session without uploads creates the entry).
    # Mentioning "User provided files: []" would only add noise to the prompt,
    # so the note is appended for non-empty lists only.
    files = self.user_file_msg_buffer.get(event.get_session_id())
    if files:
        request.prompt += f"\nUser provided files: {files}"
# Root of the `pi` command group. The function body is intentionally empty:
# sub-commands attach themselves through the `@pi.command(...)` decorator.
@filter.command_group("pi")
def pi(self):
pass
@pi.command("mirror")
async def pi_mirror(self, event: AstrMessageEvent, url: str = ""):
    '''Docker 镜像地址'''
    # With a url: store it as the Docker mirror and persist the config.
    # Without one: show the current mirror together with usage help.
    if url:
        self.config["sandbox"]["docker_mirror"] = url
        await self._save_config()
        yield event.plain_result("设置 Docker 镜像地址成功。")
        return
    current_mirror = self.config['sandbox']['docker_mirror']
    yield event.plain_result(f"""当前 Docker 镜像地址: {current_mirror}
使用 `pi mirror <url>` 来设置 Docker 镜像地址。
您所设置的 Docker 镜像地址将会自动加在 Docker 镜像名前。如: `soulter/astrbot-code-interpreter-sandbox` -> `cjie.eu.org/soulter/astrbot-code-interpreter-sandbox`。
""")
@pi.command("repull")
async def pi_repull(self, event: AstrMessageEvent):
    '''重新拉取沙箱镜像'''
    # Remove the local sandbox image (when present) and pull it again.
    docker = aiodocker.Docker()
    try:
        image_name = await self.get_image_name()
        try:
            await docker.images.get(image_name)
            await docker.images.delete(image_name, force=True)
        except aiodocker.exceptions.DockerError:
            # Best effort: if the image cannot be inspected or removed
            # (e.g. it is not present locally), go straight to pulling it.
            pass
        await docker.images.pull(image_name)
    finally:
        # Always release the client's HTTP session, also when the pull fails.
        await docker.close()
    yield event.plain_result("重新拉取沙箱镜像成功。")
@llm_tool("python_interpreter")
async def python_interpreter(self, event: AstrMessageEvent):
    '''Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code.
    For example, user can use this tool to solve math problems, edit image, docx, pptx, pdf, etc.
    '''
    # NOTE: the docstring above is handed to the LLM as the tool description,
    # so it is part of runtime behaviour and must not be reworded casually.
    #
    # Flow: ask the LLM for code, run it in a Docker sandbox, relay the tagged
    # output lines to the user. When the run ends with an error and produced no
    # output, the error is fed back to the LLM and the run is retried (up to
    # n attempts).
    if not await self.is_docker_available():
        yield event.plain_result("Docker 在当前机器不可用,无法沙箱化执行代码。")
        # Without Docker nothing below can work; stop here.
        return

    plain_text = event.message_str

    # Create a per-invocation workspace named after a fresh magic code.
    magic_code = await self.gen_magic_code()
    workplace_path = os.path.join(self.workplace_path, magic_code)
    output_path = os.path.join(workplace_path, "output")
    os.makedirs(output_path, exist_ok=True)  # also creates workplace_path
    workplace_root = os.path.realpath(workplace_path)

    def resolve_output(reported_path: str):
        '''Map a path printed by the sandbox to a host path inside the workspace.

        Returns None when the path resolves outside the workspace. The text
        comes from untrusted code: an absolute path or a symlink could
        otherwise make the host send or upload arbitrary local files.
        '''
        candidate = os.path.realpath(os.path.join(workplace_root, reported_path.strip()))
        try:
            inside = os.path.commonpath([workplace_root, candidate]) == workplace_root
        except ValueError:
            # e.g. paths on different drives
            inside = False
        return candidate if inside else None

    # Images attached to the triggering message (only http(s) URLs are fetched).
    images = []
    idx = 1
    for comp in event.message_obj.message:
        if isinstance(comp, Image):
            image_url = comp.url if comp.url else comp.file
            if image_url and image_url.startswith("http"):
                image_path = await self.download_image(image_url, workplace_path, f"img_{idx}")
                if image_path:
                    images.append(image_path)
                    idx += 1

    # Files previously uploaded in this session. `.get` avoids creating an
    # empty defaultdict entry for sessions that never uploaded anything.
    files = []
    for file_path in self.user_file_msg_buffer.get(event.get_session_id(), []):
        file_name = os.path.basename(file_path)
        try:
            shutil.copy(file_path, os.path.join(workplace_path, file_name))
        except OSError as e:
            # A vanished upload must not break the tool for the whole session.
            logger.warning(f"Failed to copy user file {file_path}: {e}")
            continue
        files.append(file_name)

    logger.debug(f"user query: {plain_text}, images: {images}, files: {files}")

    # Describe the extra inputs to the LLM.
    extra_inputs = ""
    if images:
        extra_inputs += f"User provided images: {images}\n"
    if files:
        extra_inputs += f"User provided files: {files}\n"

    # Only lines tagged with THIS run's magic code count as output; text the
    # code merely echoes (e.g. downloaded content) cannot forge a result line.
    output_pattern = re.compile(
        r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#" + re.escape(magic_code) + r"\]: (.*)"
    )

    obs = ""
    n = 5
    image_name = await self.get_image_name()
    # One client for all attempts, closed in `finally`.
    docker = aiodocker.Docker()
    try:
        for i in range(n):
            if i > 0:
                logger.info(f"Try {i+1}/{n}")

            PROMPT_ = PROMPT.format(
                prompt=plain_text,
                extra_input=extra_inputs,
                extra_prompt=obs,
            )
            provider = self.context.get_using_provider()
            llm_response = await provider.text_chat(prompt=PROMPT_, session_id=f"{event.session_id}_{magic_code}_{str(i)}")

            logger.debug("code interpreter llm gened code:" + llm_response.completion_text)

            # Extract the code and save it for the sandbox.
            try:
                code_clean = await self.tidy_code(llm_response.completion_text)
            except ValueError as e:
                # No fenced code block in the reply: count it as a failed
                # attempt and tell the model what went wrong.
                obs = f"## Observation \n Your previous reply could not be used: {e} Output the code inside a single ```python code block."
                continue
            # UTF-8 explicitly: the code may contain non-ASCII text and is read
            # by the interpreter inside the container, not by this process.
            with open(os.path.join(workplace_path, "exec.py"), "w", encoding="utf-8") as f:
                f.write(code_clean)

            # Make sure the sandbox image is available locally.
            try:
                await docker.images.get(image_name)
            except aiodocker.exceptions.DockerError:
                logger.info(f"未找到沙箱镜像,正在尝试拉取 {image_name}...")
                await docker.images.pull(image_name)

            yield event.plain_result(f"使用沙箱执行代码中,请稍等...(尝试次数: {i+1}/{n})")

            container = await docker.containers.run({
                "Image": image_name,
                "Cmd": ["python", "exec.py"],
                "Env": [
                    f"MAGIC_CODE={magic_code}"
                ],
                "HostConfig": {
                    "Binds": [
                        f"{self.shared_path}:/astrbot_sandbox/shared:ro",
                        f"{output_path}:/astrbot_sandbox/output:rw",
                        f"{workplace_path}:/astrbot_sandbox:rw",
                    ],
                    # Resource limits belong to HostConfig in the Docker Engine
                    # API; at the top level of the create payload they do not
                    # limit the container.
                    "Memory": 512 * 1024 * 1024,
                    "NanoCpus": 1000000000,
                    # AutoRemove is intentionally not set: run_container() reads
                    # the logs after the container exits and deletes it itself.
                },
            })

            logger.debug(f"Container {container.id} created.")
            logs = await self.run_container(container)
            logger.debug(f"Container {container.id} finished.")
            logger.debug(f"Container {container.id} logs: {logs}")

            # Relay the results.
            ok = False
            error_log = ""
            for line_no, log in enumerate(logs):
                match = output_pattern.match(log)
                if match:
                    kind, content = match.group(1), match.group(2)
                    if kind == "TEXT":
                        ok = True
                        yield event.plain_result(content)
                        continue
                    host_path = resolve_output(content)
                    if host_path is None:
                        logger.warning(f"Ignoring sandbox output outside the workspace: {content}")
                        continue
                    ok = True
                    if kind == "IMAGE":
                        logger.debug(f"Sending image: {host_path}")
                        yield event.image_result(host_path)
                    else:  # FILE
                        logger.debug(f"Sending file: {host_path}")
                        file_s3_url = await self.file_upload(host_path)
                        logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
                        file_name = os.path.basename(host_path)
                        chain = [File(name=file_name, file=file_s3_url)]
                        yield event.set_result(MessageEventResult(chain=chain))
                elif "Traceback (most recent call last)" in log \
                        or "[Error]: " in log:
                    error_log = "\n".join(logs[line_no:])

            if ok:
                return
            if not error_log:
                # No output and no error to learn from: retrying is pointless.
                logger.warning(f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}")
                break
            obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occured:\n\n{error_log}\n Need to improve/fix the code."
    finally:
        await docker.close()

    yield event.plain_result("经过多次尝试后,未从沙箱输出中捕获到合法的输出,请更换问法或者查看日志。")
async def run_container(self, container: aiodocker.docker.DockerContainer, timeout: int = 20) -> list[str]:
    '''Wait for the container to finish and return its output.

    Returns what `container.log(stdout=True, stderr=True)` yields, or a
    single "[Error]: ..." entry when the container did not finish within
    `timeout` seconds and was killed. The container is removed in either case.
    '''
    try:
        await container.wait(timeout=timeout)
        return await container.log(stdout=True, stderr=True)
    except asyncio.TimeoutError:
        logger.warning(f"Container {container.id} timeout.")
        try:
            await container.kill()
        except aiodocker.exceptions.DockerError as e:
            # The container may have exited on its own in the meantime.
            logger.warning(f"Failed to kill container {container.id}: {e}")
        return [f"[Error]: Container has been killed due to timeout ({timeout}s)."]
    finally:
        # Forced removal also covers a container that is still shutting down.
        # A failing cleanup is logged instead of raised, so it cannot replace
        # the logs (or the exception) produced above.
        try:
            await container.delete(force=True)
        except aiodocker.exceptions.DockerError as e:
            logger.warning(f"Failed to delete container {container.id}: {e}")
+18
View File
@@ -0,0 +1,18 @@
import os
def _get_magic_code():
    '''Return the MAGIC_CODE environment variable (None when it is unset).

    The code is embedded in every output tag as a guard against injection.
    '''
    return os.environ.get("MAGIC_CODE")


def _emit(kind: str, payload) -> None:
    '''Print one tagged output line of the given kind (TEXT, IMAGE or FILE).'''
    print(f"[ASTRBOT_{kind}_OUTPUT#{_get_magic_code()}]: {payload}")


def send_text(text: str):
    '''Send plain text to the user.'''
    # NOTE(review): the tag is printed once, in front of the whole text. If the
    # consumer matches output line by line, only the first line of a multi-line
    # text carries the tag - confirm against the host-side parser.
    _emit("TEXT", text)


def send_image(image_path: str):
    '''Send an image to the user; raises Exception when the file does not exist.'''
    if os.path.exists(image_path):
        _emit("IMAGE", image_path)
        return
    raise Exception(f"Image file not found: {image_path}")


def send_file(file_path: str):
    '''Send a file to the user; raises Exception when the file does not exist.'''
    if os.path.exists(file_path):
        _emit("FILE", file_path)
        return
    raise Exception(f"File not found: {file_path}")
+2 -1
View File
@@ -15,4 +15,5 @@ colorlog
aiocqhttp
pyjwt
apscheduler
docstring_parser
docstring_parser
aiodocker
+3 -2
View File
@@ -117,9 +117,10 @@ def star_context(event_queue, config, db, platform_manager, provider_manager):
return star_context
@pytest.fixture(scope="module")
def plugin_manager(star_context, config):
@pytest.mark.asyncio
async def plugin_manager(star_context, config):
plugin_manager = PluginManager(star_context, config)
plugin_manager.reload()
await plugin_manager.reload()
return plugin_manager
@pytest.fixture(scope="module")
+1 -1
View File
@@ -27,7 +27,7 @@ def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
@pytest.mark.asyncio
async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
success, err_message = plugin_manager_pm.reload()
success, err_message = await plugin_manager_pm.reload()
assert success is True
assert err_message is None
assert len(star_handlers_registry) > 0 # package