From 1cb2b62f81db90e62b9f1574b9231d08bc41ee7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=86=B0=E8=8B=B7=E6=99=B6?= <2749332490@qq.com> Date: Sun, 23 Mar 2025 23:02:34 +0800 Subject: [PATCH] fix: fix error --- astrbot/dashboard/routes/plugin.py | 444 ++++++++++++++++------------- 1 file changed, 243 insertions(+), 201 deletions(-) diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index ba09a9dd2..af4d0db31 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -1,225 +1,267 @@ -import os -import ssl -import shutil -import socket -import time +import traceback import aiohttp -import base64 -import zipfile -import uuid -import psutil +import ssl import certifi -from typing import Union - -from PIL import Image +from .route import Route, Response, RouteContext +from astrbot.core import logger +from quart import request +from astrbot.core.star.star_manager import PluginManager +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.star.star_handler import star_handlers_registry +from astrbot.core.star.filter.command import CommandFilter +from astrbot.core.star.filter.command_group import CommandGroupFilter +from astrbot.core.star.filter.permission import PermissionTypeFilter +from astrbot.core.star.filter.regex import RegexFilter +from astrbot.core.star.star_handler import EventType -def on_error(func, path, exc_info): - """ - a callback of the rmtree function. - """ - print(f"remove {path} failed.") - import stat +class PluginRoute(Route): + def __init__( + self, + context: RouteContext, + core_lifecycle: AstrBotCoreLifecycle, + plugin_manager: PluginManager, + ) -> None: + super().__init__(context) + self.routes = { + "/plugin/get": ("GET", self.get_plugins), + "/plugin/install": ("POST", self.install_plugin), + "/plugin/install-upload": ("POST", self.install_plugin_upload), + "/plugin/update": ("POST", self.update_plugin), + "/plugin/uninstall": ("POST", self.uninstall_plugin), + "/plugin/market_list": ("GET", self.get_online_plugins), + "/plugin/off": ("POST", self.off_plugin), + "/plugin/on": ("POST", self.on_plugin), + "/plugin/reload": ("POST", self.reload_plugins), + } + self.core_lifecycle = core_lifecycle + self.plugin_manager = plugin_manager + self.register_routes() - if not os.access(path, os.W_OK): - os.chmod(path, stat.S_IWUSR) - func(path) - else: - raise + self.translated_event_type = { + EventType.AdapterMessageEvent: "平台消息下发时", + EventType.OnLLMRequestEvent: "LLM 请求时", + EventType.OnLLMResponseEvent: "LLM 响应后", + EventType.OnDecoratingResultEvent: "回复消息前", + EventType.OnCallingFuncToolEvent: "函数工具", + EventType.OnAfterMessageSentEvent: "发送消息后", + } + async def reload_plugins(self): + data = await request.json + plugin_name = data.get("name", None) + try: + success, message = await self.plugin_manager.reload(plugin_name) + if not success: + return Response().error(message).__dict__ + return Response().ok(None, "重载成功。").__dict__ + except Exception as e: + logger.error(f"/api/plugin/reload: {traceback.format_exc()}") + return Response().error(str(e)).__dict__ -def remove_dir(file_path) -> bool: - if not os.path.exists(file_path): - return True - try: - shutil.rmtree(file_path, onerror=on_error) - return True - except BaseException: - return False + async def get_online_plugins(self): + custom = request.args.get("custom_registry") + if custom: + urls = [custom] + else: + urls = ["https://api.soulter.top/astrbot/plugins"] -def port_checker(port: int, host: str = "localhost"): - sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sk.settimeout(1) - try: - sk.connect((host, port)) - sk.close() - return True - except Exception: - sk.close() - return False - - -def save_temp_img(img: Union[Image.Image, str]) -> str: - os.makedirs("data/temp", exist_ok=True) - # 获得文件创建时间,清除超过 12 小时的 - try: - for f in os.listdir("data/temp"): - path = os.path.join("data/temp", f) - if os.path.isfile(path): - ctime = os.path.getctime(path) - if time.time() - ctime > 3600 * 12: - os.remove(path) - except Exception as e: - print(f"清除临时文件失败: {e}") - - # 获得时间戳 - timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}" - p = f"data/temp/{timestamp}.jpg" - - if isinstance(img, Image.Image): - img.save(p) - else: - with open(p, "wb") as f: - f.write(img) - return p - - -async def download_image_by_url( - url: str, post: bool = False, post_data: dict = None, path=None -) -> str: - """ - 下载图片, 返回 path - """ - try: - ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书 - connector = aiohttp.TCPConnector(ssl=ssl_context) # 使用 certifi 的根证书 - async with aiohttp.ClientSession(trust_env=True, connector=connector) as session: - if post: - async with session.post(url, json=post_data) as resp: - if not path: - return save_temp_img(await resp.read()) - else: - with open(path, "wb") as f: - f.write(await resp.read()) - return path - else: - async with session.get(url) as resp: - if not path: - return save_temp_img(await resp.read()) - else: - with open(path, "wb") as f: - f.write(await resp.read()) - return path - except aiohttp.client.ClientConnectorSSLError: - # 关闭SSL验证 - ssl_context = ssl.create_default_context() - ssl_context.set_ciphers("DEFAULT") - async with aiohttp.ClientSession() as session: - if post: - async with session.get(url, ssl=ssl_context) as resp: - return save_temp_img(await resp.read()) - else: - async with session.get(url, ssl=ssl_context) as resp: - return save_temp_img(await resp.read()) - except Exception as e: - raise e - - -async def download_file(url: str, path: str, show_progress: bool = False): - """ - 从指定 url 下载文件到指定路径 path - """ - try: - ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书 + # 新增:创建 SSL 上下文,使用 certifi 提供的根证书 + ssl_context = ssl.create_default_context(cafile=certifi.where()) connector = aiohttp.TCPConnector(ssl=ssl_context) - async with aiohttp.ClientSession(trust_env=True, connector=connector) as session: - async with session.get(url, timeout=1800) as resp: - if resp.status != 200: - raise Exception(f"下载文件失败: {resp.status}") - total_size = int(resp.headers.get("content-length", 0)) - downloaded_size = 0 - start_time = time.time() - if show_progress: - print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}") - with open(path, "wb") as f: - while True: - chunk = await resp.content.read(8192) - if not chunk: - break - f.write(chunk) - downloaded_size += len(chunk) - if show_progress: - elapsed_time = time.time() - start_time - speed = downloaded_size / 1024 / elapsed_time # KB/s - print( - f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", - end="", + for url in urls: + try: + async with aiohttp.ClientSession(trust_env=True, connector=connector) as session: + async with session.get(url) as response: + if response.status == 200: + result = await response.json() + return Response().ok(result).__dict__ + else: + logger.error(f"请求 {url} 失败,状态码:{response.status}") + except Exception as e: + logger.error(f"请求 {url} 失败,错误:{e}") + + return Response().error("获取插件列表失败").__dict__ + + async def get_plugins(self): + _plugin_resp = [] + for plugin in self.plugin_manager.context.get_all_stars(): + _t = { + "name": plugin.name, + "repo": "" if plugin.repo is None else plugin.repo, + "author": plugin.author, + "desc": plugin.desc, + "version": plugin.version, + "reserved": plugin.reserved, + "activated": plugin.activated, + "online_vesion": "", + "handlers": await self.get_plugin_handlers_info( + plugin.star_handler_full_names + ), + } + _plugin_resp.append(_t) + return ( + Response() + .ok(_plugin_resp, message=self.plugin_manager.failed_plugin_info) + .__dict__ + ) + + async def get_plugin_handlers_info(self, handler_full_names: list[str]): + """解析插件行为""" + handlers = [] + + for handler_full_name in handler_full_names: + info = {} + handler = star_handlers_registry.star_handlers_map.get( + handler_full_name, None + ) + if handler is None: + continue + info["event_type"] = handler.event_type.name + info["event_type_h"] = self.translated_event_type.get( + handler.event_type, handler.event_type.name + ) + info["handler_full_name"] = handler.handler_full_name + info["desc"] = handler.desc + info["handler_name"] = handler.handler_name + + if handler.event_type == EventType.AdapterMessageEvent: + # 处理平台适配器消息事件 + has_admin = False + for filter in ( + handler.event_filters + ): # 正常handler就只有 1~2 个 filter,因此这里时间复杂度不会太高 + if isinstance(filter, CommandFilter): + info["type"] = "指令" + info["cmd"] = ( + f"{filter.parent_command_names[0]} {filter.command_name}" + ) + info["cmd"] = info["cmd"].strip() + if ( + self.core_lifecycle.astrbot_config["wake_prefix"] + and len(self.core_lifecycle.astrbot_config["wake_prefix"]) + > 0 + ): + info["cmd"] = ( + f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}" ) - except aiohttp.client.ClientConnectorSSLError: - # 关闭SSL验证 - ssl_context = ssl.create_default_context() - ssl_context.set_ciphers("DEFAULT") - async with aiohttp.ClientSession() as session: - async with session.get(url, ssl=ssl_context, timeout=120) as resp: - total_size = int(resp.headers.get("content-length", 0)) - downloaded_size = 0 - start_time = time.time() - if show_progress: - print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}") - with open(path, "wb") as f: - while True: - chunk = await resp.content.read(8192) - if not chunk: - break - f.write(chunk) - downloaded_size += len(chunk) - if show_progress: - elapsed_time = time.time() - start_time - speed = downloaded_size / 1024 / elapsed_time # KB/s - print( - f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", - end="", + elif isinstance(filter, CommandGroupFilter): + info["type"] = "指令组" + info["cmd"] = filter.get_complete_command_names()[0] + info["cmd"] = info["cmd"].strip() + info["sub_command"] = filter.print_cmd_tree( + filter.sub_command_filters + ) + if ( + self.core_lifecycle.astrbot_config["wake_prefix"] + and len(self.core_lifecycle.astrbot_config["wake_prefix"]) + > 0 + ): + info["cmd"] = ( + f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}" ) - if show_progress: - print() + elif isinstance(filter, RegexFilter): + info["type"] = "正则匹配" + info["cmd"] = filter.regex_str + elif isinstance(filter, PermissionTypeFilter): + has_admin = True + info["has_admin"] = has_admin + if "cmd" not in info: + info["cmd"] = "未知" + if "type" not in info: + info["type"] = "事件监听器" + else: + info["cmd"] = "自动触发" + info["type"] = "无" + if not info["desc"]: + info["desc"] = "无描述" -def file_to_base64(file_path: str) -> str: - with open(file_path, "rb") as f: - data_bytes = f.read() - base64_str = base64.b64encode(data_bytes).decode() - return "base64://" + base64_str + handlers.append(info) + return handlers -def get_local_ip_addresses(): - net_interfaces = psutil.net_if_addrs() - network_ips = [] + async def install_plugin(self): + post_data = await request.json + repo_url = post_data["url"] - for interface, addrs in net_interfaces.items(): - for addr in addrs: - if addr.family == socket.AF_INET: # 使用 socket.AF_INET 代替 psutil.AF_INET - network_ips.append(addr.address) + proxy: str = post_data.get("proxy", None) + if proxy: + proxy = proxy.removesuffix("/") - return network_ips + try: + logger.info(f"正在安装插件 {repo_url}") + await self.plugin_manager.install_plugin(repo_url, proxy) + # self.core_lifecycle.restart() + logger.info(f"安装插件 {repo_url} 成功。") + return Response().ok(None, "安装成功。").__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(str(e)).__dict__ + async def install_plugin_upload(self): + try: + file = await request.files + file = file["file"] + logger.info(f"正在安装用户上传的插件 {file.filename}") + file_path = f"data/temp/{file.filename}" + await file.save(file_path) + await self.plugin_manager.install_plugin_from_file(file_path) + # self.core_lifecycle.restart() + logger.info(f"安装插件 {file.filename} 成功") + return Response().ok(None, "安装成功。").__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(str(e)).__dict__ -async def get_dashboard_version(): - if os.path.exists("data/dist"): - if os.path.exists("data/dist/assets/version"): - with open("data/dist/assets/version", "r") as f: - v = f.read().strip() - return v - return None + async def uninstall_plugin(self): + post_data = await request.json + plugin_name = post_data["name"] + try: + logger.info(f"正在卸载插件 {plugin_name}") + await self.plugin_manager.uninstall_plugin(plugin_name) + logger.info(f"卸载插件 {plugin_name} 成功") + return Response().ok(None, "卸载成功").__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(str(e)).__dict__ + async def update_plugin(self): + post_data = await request.json + plugin_name = post_data["name"] + proxy: str = post_data.get("proxy", None) + try: + logger.info(f"正在更新插件 {plugin_name}") + await self.plugin_manager.update_plugin(plugin_name, proxy) + # self.core_lifecycle.restart() + await self.plugin_manager.reload(plugin_name) + logger.info(f"更新插件 {plugin_name} 成功。") + return Response().ok(None, "更新成功。").__dict__ + except Exception as e: + logger.error(f"/api/plugin/update: {traceback.format_exc()}") + return Response().error(str(e)).__dict__ -async def download_dashboard(): - """下载管理面板文件""" - dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip" - try: - ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书 - await download_file( - dashboard_release_url, "data/dashboard.zip", show_progress=True - ) - except BaseException as _: - dashboard_release_url = ( - "https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip" - ) - await download_file( - dashboard_release_url, "data/dashboard.zip", show_progress=True - ) - print("解压管理面板文件中...") - with zipfile.ZipFile("data/dashboard.zip", "r") as z: - z.extractall("data") + async def off_plugin(self): + post_data = await request.json + plugin_name = post_data["name"] + try: + await self.plugin_manager.turn_off_plugin(plugin_name) + logger.info(f"停用插件 {plugin_name} 。") + return Response().ok(None, "停用成功。").__dict__ + except Exception as e: + logger.error(f"/api/plugin/off: {traceback.format_exc()}") + return Response().error(str(e)).__dict__ + + async def on_plugin(self): + post_data = await request.json + plugin_name = post_data["name"] + try: + await self.plugin_manager.turn_on_plugin(plugin_name) + logger.info(f"启用插件 {plugin_name} 。") + return Response().ok(None, "启用成功。").__dict__ + except Exception as e: + logger.error(f"/api/plugin/on: {traceback.format_exc()}") + return Response().error(str(e)).__dict__