diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index 318a61835..ba09a9dd2 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -8,6 +8,9 @@ import base64 import zipfile import uuid import psutil + +import certifi + from typing import Union from PIL import Image @@ -81,7 +84,9 @@ async def download_image_by_url( 下载图片, 返回 path """ try: - async with aiohttp.ClientSession(trust_env=True) as session: + ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书 + connector = aiohttp.TCPConnector(ssl=ssl_context) # 使用 certifi 的根证书 + async with aiohttp.ClientSession(trust_env=True, connector=connector) as session: if post: async with session.post(url, json=post_data) as resp: if not path: @@ -118,7 +123,9 @@ async def download_file(url: str, path: str, show_progress: bool = False): 从指定 url 下载文件到指定路径 path """ try: - async with aiohttp.ClientSession(trust_env=True) as session: + ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书 + connector = aiohttp.TCPConnector(ssl=ssl_context) + async with aiohttp.ClientSession(trust_env=True, connector=connector) as session: async with session.get(url, timeout=1800) as resp: if resp.status != 200: raise Exception(f"下载文件失败: {resp.status}") @@ -202,6 +209,7 @@ async def download_dashboard(): """下载管理面板文件""" dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip" try: + ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书 await download_file( dashboard_release_url, "data/dashboard.zip", show_progress=True ) diff --git a/astrbot/core/zip_updator.py b/astrbot/core/zip_updator.py index 29533ea88..cc951d257 100644 --- a/astrbot/core/zip_updator.py +++ b/astrbot/core/zip_updator.py @@ -2,6 +2,10 @@ import aiohttp import os import zipfile import shutil + +import ssl +import certifi + from astrbot.core.utils.io import on_error, download_file from astrbot.core import logger @@ -33,10 +37,18 @@ class RepoZipUpdator: 返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。 """ try: - async with aiohttp.ClientSession(trust_env=True) as session: + ssl_context = ssl.create_default_context(cafile=certifi.where()) # 新增:创建基于 certifi 的 SSL 上下文 + connector = aiohttp.TCPConnector(ssl=ssl_context) # 新增:使用 TCPConnector 指定 SSL 上下文 + async with aiohttp.ClientSession(trust_env=True, connector=connector) as session: async with session.get(url) as response: + # 检查 HTTP 状态码 + if response.status != 200: + text = await response.text() + logger.error(f"请求 {url} 失败,状态码: {response.status}, 内容: {text}") + raise Exception(f"请求失败,状态码: {response.status}") result = await response.json() if not result: + logger.error("返回空的结果喵♡~") return [] # if latest: # ret = self.github_api_release_parser([result[0]]) @@ -53,7 +65,8 @@ class RepoZipUpdator: "zipball_url": release["zipball_url"], } ) - except BaseException: + except Exception as e: + logger.error(f"解析版本信息时发生异常: {e}") raise Exception("解析版本信息失败") return ret diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index 6e90d73e6..ba09a9dd2 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -1,260 +1,225 @@ -import traceback +import os +import ssl +import shutil +import socket +import time import aiohttp -from .route import Route, Response, RouteContext -from astrbot.core import logger -from quart import request -from astrbot.core.star.star_manager import PluginManager -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.star.star_handler import star_handlers_registry -from astrbot.core.star.filter.command import CommandFilter -from astrbot.core.star.filter.command_group import CommandGroupFilter -from astrbot.core.star.filter.permission import PermissionTypeFilter -from astrbot.core.star.filter.regex import RegexFilter -from astrbot.core.star.star_handler import EventType +import base64 +import zipfile +import uuid +import psutil + +import certifi + +from typing import Union + +from PIL import Image -class PluginRoute(Route): - def __init__( - self, - context: RouteContext, - core_lifecycle: AstrBotCoreLifecycle, - plugin_manager: PluginManager, - ) -> None: - super().__init__(context) - self.routes = { - "/plugin/get": ("GET", self.get_plugins), - "/plugin/install": ("POST", self.install_plugin), - "/plugin/install-upload": ("POST", self.install_plugin_upload), - "/plugin/update": ("POST", self.update_plugin), - "/plugin/uninstall": ("POST", self.uninstall_plugin), - "/plugin/market_list": ("GET", self.get_online_plugins), - "/plugin/off": ("POST", self.off_plugin), - "/plugin/on": ("POST", self.on_plugin), - "/plugin/reload": ("POST", self.reload_plugins), - } - self.core_lifecycle = core_lifecycle - self.plugin_manager = plugin_manager - self.register_routes() +def on_error(func, path, exc_info): + """ + a callback of the rmtree function. + """ + print(f"remove {path} failed.") + import stat - self.translated_event_type = { - EventType.AdapterMessageEvent: "平台消息下发时", - EventType.OnLLMRequestEvent: "LLM 请求时", - EventType.OnLLMResponseEvent: "LLM 响应后", - EventType.OnDecoratingResultEvent: "回复消息前", - EventType.OnCallingFuncToolEvent: "函数工具", - EventType.OnAfterMessageSentEvent: "发送消息后", - } + if not os.access(path, os.W_OK): + os.chmod(path, stat.S_IWUSR) + func(path) + else: + raise - async def reload_plugins(self): - data = await request.json - plugin_name = data.get("name", None) - try: - success, message = await self.plugin_manager.reload(plugin_name) - if not success: - return Response().error(message).__dict__ - return Response().ok(None, "重载成功。").__dict__ - except Exception as e: - logger.error(f"/api/plugin/reload: {traceback.format_exc()}") - return Response().error(str(e)).__dict__ - async def get_online_plugins(self): - custom = request.args.get("custom_registry") +def remove_dir(file_path) -> bool: + if not os.path.exists(file_path): + return True + try: + shutil.rmtree(file_path, onerror=on_error) + return True + except BaseException: + return False - if custom: - urls = [custom] - else: - urls = ["https://api.soulter.top/astrbot/plugins"] - for url in urls: - try: - async with aiohttp.ClientSession(trust_env=True) as session: - async with session.get(url) as response: - if response.status == 200: - result = await response.json() - return Response().ok(result).__dict__ - else: - logger.error(f"请求 {url} 失败,状态码:{response.status}") - except Exception as e: - logger.error(f"请求 {url} 失败,错误:{e}") +def port_checker(port: int, host: str = "localhost"): + sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sk.settimeout(1) + try: + sk.connect((host, port)) + sk.close() + return True + except Exception: + sk.close() + return False - return Response().error("获取插件列表失败").__dict__ - async def get_plugins(self): - _plugin_resp = [] - for plugin in self.plugin_manager.context.get_all_stars(): - _t = { - "name": plugin.name, - "repo": "" if plugin.repo is None else plugin.repo, - "author": plugin.author, - "desc": plugin.desc, - "version": plugin.version, - "reserved": plugin.reserved, - "activated": plugin.activated, - "online_vesion": "", - "handlers": await self.get_plugin_handlers_info( - plugin.star_handler_full_names - ), - } - _plugin_resp.append(_t) - return ( - Response() - .ok(_plugin_resp, message=self.plugin_manager.failed_plugin_info) - .__dict__ - ) +def save_temp_img(img: Union[Image.Image, str]) -> str: + os.makedirs("data/temp", exist_ok=True) + # 获得文件创建时间,清除超过 12 小时的 + try: + for f in os.listdir("data/temp"): + path = os.path.join("data/temp", f) + if os.path.isfile(path): + ctime = os.path.getctime(path) + if time.time() - ctime > 3600 * 12: + os.remove(path) + except Exception as e: + print(f"清除临时文件失败: {e}") - async def get_plugin_handlers_info(self, handler_full_names: list[str]): - """解析插件行为""" - handlers = [] + # 获得时间戳 + timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}" + p = f"data/temp/{timestamp}.jpg" - for handler_full_name in handler_full_names: - info = {} - handler = star_handlers_registry.star_handlers_map.get( - handler_full_name, None - ) - if handler is None: - continue - info["event_type"] = handler.event_type.name - info["event_type_h"] = self.translated_event_type.get( - handler.event_type, handler.event_type.name - ) - info["handler_full_name"] = handler.handler_full_name - info["desc"] = handler.desc - info["handler_name"] = handler.handler_name + if isinstance(img, Image.Image): + img.save(p) + else: + with open(p, "wb") as f: + f.write(img) + return p - if handler.event_type == EventType.AdapterMessageEvent: - # 处理平台适配器消息事件 - has_admin = False - for filter in ( - handler.event_filters - ): # 正常handler就只有 1~2 个 filter,因此这里时间复杂度不会太高 - if isinstance(filter, CommandFilter): - info["type"] = "指令" - info["cmd"] = ( - f"{filter.parent_command_names[0]} {filter.command_name}" - ) - info["cmd"] = info["cmd"].strip() - if ( - self.core_lifecycle.astrbot_config["wake_prefix"] - and len(self.core_lifecycle.astrbot_config["wake_prefix"]) - > 0 - ): - info["cmd"] = ( - f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}" - ) - elif isinstance(filter, CommandGroupFilter): - info["type"] = "指令组" - info["cmd"] = filter.get_complete_command_names()[0] - info["cmd"] = info["cmd"].strip() - info["sub_command"] = filter.print_cmd_tree( - filter.sub_command_filters - ) - if ( - self.core_lifecycle.astrbot_config["wake_prefix"] - and len(self.core_lifecycle.astrbot_config["wake_prefix"]) - > 0 - ): - info["cmd"] = ( - f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}" - ) - elif isinstance(filter, RegexFilter): - info["type"] = "正则匹配" - info["cmd"] = filter.regex_str - elif isinstance(filter, PermissionTypeFilter): - has_admin = True - info["has_admin"] = has_admin - if "cmd" not in info: - info["cmd"] = "未知" - if "type" not in info: - info["type"] = "事件监听器" + +async def download_image_by_url( + url: str, post: bool = False, post_data: dict = None, path=None +) -> str: + """ + 下载图片, 返回 path + """ + try: + ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书 + connector = aiohttp.TCPConnector(ssl=ssl_context) # 使用 certifi 的根证书 + async with aiohttp.ClientSession(trust_env=True, connector=connector) as session: + if post: + async with session.post(url, json=post_data) as resp: + if not path: + return save_temp_img(await resp.read()) + else: + with open(path, "wb") as f: + f.write(await resp.read()) + return path else: - info["cmd"] = "自动触发" - info["type"] = "无" + async with session.get(url) as resp: + if not path: + return save_temp_img(await resp.read()) + else: + with open(path, "wb") as f: + f.write(await resp.read()) + return path + except aiohttp.client.ClientConnectorSSLError: + # 关闭SSL验证 + ssl_context = ssl.create_default_context() + ssl_context.set_ciphers("DEFAULT") + async with aiohttp.ClientSession() as session: + if post: + async with session.get(url, ssl=ssl_context) as resp: + return save_temp_img(await resp.read()) + else: + async with session.get(url, ssl=ssl_context) as resp: + return save_temp_img(await resp.read()) + except Exception as e: + raise e - if not info["desc"]: - info["desc"] = "无描述" - handlers.append(info) +async def download_file(url: str, path: str, show_progress: bool = False): + """ + 从指定 url 下载文件到指定路径 path + """ + try: + ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书 + connector = aiohttp.TCPConnector(ssl=ssl_context) + async with aiohttp.ClientSession(trust_env=True, connector=connector) as session: + async with session.get(url, timeout=1800) as resp: + if resp.status != 200: + raise Exception(f"下载文件失败: {resp.status}") + total_size = int(resp.headers.get("content-length", 0)) + downloaded_size = 0 + start_time = time.time() + if show_progress: + print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}") + with open(path, "wb") as f: + while True: + chunk = await resp.content.read(8192) + if not chunk: + break + f.write(chunk) + downloaded_size += len(chunk) + if show_progress: + elapsed_time = time.time() - start_time + speed = downloaded_size / 1024 / elapsed_time # KB/s + print( + f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", + end="", + ) + except aiohttp.client.ClientConnectorSSLError: + # 关闭SSL验证 + ssl_context = ssl.create_default_context() + ssl_context.set_ciphers("DEFAULT") + async with aiohttp.ClientSession() as session: + async with session.get(url, ssl=ssl_context, timeout=120) as resp: + total_size = int(resp.headers.get("content-length", 0)) + downloaded_size = 0 + start_time = time.time() + if show_progress: + print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}") + with open(path, "wb") as f: + while True: + chunk = await resp.content.read(8192) + if not chunk: + break + f.write(chunk) + downloaded_size += len(chunk) + if show_progress: + elapsed_time = time.time() - start_time + speed = downloaded_size / 1024 / elapsed_time # KB/s + print( + f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", + end="", + ) + if show_progress: + print() - return handlers - async def install_plugin(self): - post_data = await request.json - repo_url = post_data["url"] +def file_to_base64(file_path: str) -> str: + with open(file_path, "rb") as f: + data_bytes = f.read() + base64_str = base64.b64encode(data_bytes).decode() + return "base64://" + base64_str - proxy: str = post_data.get("proxy", None) - if proxy: - proxy = proxy.removesuffix("/") - try: - logger.info(f"正在安装插件 {repo_url}") - await self.plugin_manager.install_plugin(repo_url, proxy) - # self.core_lifecycle.restart() - logger.info(f"安装插件 {repo_url} 成功。") - return Response().ok(None, "安装成功。").__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ +def get_local_ip_addresses(): + net_interfaces = psutil.net_if_addrs() + network_ips = [] - async def install_plugin_upload(self): - try: - file = await request.files - file = file["file"] - logger.info(f"正在安装用户上传的插件 {file.filename}") - file_path = f"data/temp/{file.filename}" - await file.save(file_path) - await self.plugin_manager.install_plugin_from_file(file_path) - # self.core_lifecycle.restart() - logger.info(f"安装插件 {file.filename} 成功") - return Response().ok(None, "安装成功。").__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ + for interface, addrs in net_interfaces.items(): + for addr in addrs: + if addr.family == socket.AF_INET: # 使用 socket.AF_INET 代替 psutil.AF_INET + network_ips.append(addr.address) - async def uninstall_plugin(self): - post_data = await request.json - plugin_name = post_data["name"] - try: - logger.info(f"正在卸载插件 {plugin_name}") - await self.plugin_manager.uninstall_plugin(plugin_name) - logger.info(f"卸载插件 {plugin_name} 成功") - return Response().ok(None, "卸载成功").__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ + return network_ips - async def update_plugin(self): - post_data = await request.json - plugin_name = post_data["name"] - proxy: str = post_data.get("proxy", None) - try: - logger.info(f"正在更新插件 {plugin_name}") - await self.plugin_manager.update_plugin(plugin_name, proxy) - # self.core_lifecycle.restart() - await self.plugin_manager.reload(plugin_name) - logger.info(f"更新插件 {plugin_name} 成功。") - return Response().ok(None, "更新成功。").__dict__ - except Exception as e: - logger.error(f"/api/plugin/update: {traceback.format_exc()}") - return Response().error(str(e)).__dict__ - async def off_plugin(self): - post_data = await request.json - plugin_name = post_data["name"] - try: - await self.plugin_manager.turn_off_plugin(plugin_name) - logger.info(f"停用插件 {plugin_name} 。") - return Response().ok(None, "停用成功。").__dict__ - except Exception as e: - logger.error(f"/api/plugin/off: {traceback.format_exc()}") - return Response().error(str(e)).__dict__ +async def get_dashboard_version(): + if os.path.exists("data/dist"): + if os.path.exists("data/dist/assets/version"): + with open("data/dist/assets/version", "r") as f: + v = f.read().strip() + return v + return None - async def on_plugin(self): - post_data = await request.json - plugin_name = post_data["name"] - try: - await self.plugin_manager.turn_on_plugin(plugin_name) - logger.info(f"启用插件 {plugin_name} 。") - return Response().ok(None, "启用成功。").__dict__ - except Exception as e: - logger.error(f"/api/plugin/on: {traceback.format_exc()}") - return Response().error(str(e)).__dict__ + +async def download_dashboard(): + """下载管理面板文件""" + dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip" + try: + ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书 + await download_file( + dashboard_release_url, "data/dashboard.zip", show_progress=True + ) + except BaseException as _: + dashboard_release_url = ( + "https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip" + ) + await download_file( + dashboard_release_url, "data/dashboard.zip", show_progress=True + ) + print("解压管理面板文件中...") + with zipfile.ZipFile("data/dashboard.zip", "r") as z: + z.extractall("data") diff --git a/requirements.txt b/requirements.txt index 313dba0c8..95983a2e8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,4 +25,5 @@ dashscope python-telegram-bot wechatpy dingtalk-stream -mcp \ No newline at end of file +mcp +certifi \ No newline at end of file