Files
AstrBot/astrbot/core/utils/io.py
T
LIghtJUNction 04aee2890a feat: implement dashboard download caching and fallback mechanism
- Add local file caching for dashboard downloads with version validation
- Implement fallback to 'latest' version if specific version download fails
- Add robust error handling in CLI check_dashboard to prevent crashes
- Remove dashboard caching from smoke tests (backend-only mode)
2026-03-15 19:14:13 +08:00

392 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import base64
import logging
import os
import shutil
import socket
import ssl
import time
import uuid
import zipfile
from ipaddress import IPv4Address, IPv6Address, ip_address
from pathlib import Path
import aiohttp
import certifi
import psutil
from PIL import Image
from .astrbot_path import get_astrbot_data_path, get_astrbot_path, get_astrbot_temp_path
logger = logging.getLogger("astrbot")
def on_error(func, path, exc_info) -> None:
    """Error callback for ``shutil.rmtree``.

    When ``path`` is not writable, owner-write permission is granted and the
    failed operation ``func`` is retried on it. When the path is already
    writable the failure has another cause, so the exception carried in
    ``exc_info`` is re-raised.
    """
    import stat

    if os.access(path, os.W_OK):
        # Write access is not the problem; surface the original error.
        raise exc_info[1]

    os.chmod(path, stat.S_IWUSR)
    func(path)
def remove_dir(file_path: str) -> bool:
    """Recursively delete ``file_path`` if it exists.

    Removal errors are routed through ``on_error``. Always returns True,
    including when the path did not exist in the first place.
    """
    if os.path.exists(file_path):
        shutil.rmtree(file_path, onerror=on_error)
    return True
def port_checker(port: int, host: str = "localhost") -> bool:
    """Check whether a TCP connection to ``host:port`` can be established.

    Uses an IPv4 stream socket with a 1 second timeout. Returns True when the
    connection succeeds and False on any exception.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    probe.settimeout(1)
    try:
        probe.connect((host, port))
    except Exception:
        return False
    else:
        return True
    finally:
        # The socket is released on both the success and the failure path.
        probe.close()
def save_temp_img(img: Image.Image | bytes) -> str:
    """Write an image into the AstrBot temp directory and return its path.

    ``img`` is either a PIL image, saved through ``Image.save``, or raw bytes,
    written verbatim. The file is named ``io_temp_img_<timestamp>_<random>.jpg``.
    """
    # Second-resolution timestamp plus 8 random hex characters.
    unique_tag = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    target = os.path.join(get_astrbot_temp_path(), f"io_temp_img_{unique_tag}.jpg")

    if not isinstance(img, Image.Image):
        with open(target, "wb") as out:
            out.write(img)
        return target

    img.save(target)
    return target
async def _request_image_bytes(
    session: "aiohttp.ClientSession",
    url: str,
    post: bool,
    post_data: dict | None,
    **request_kwargs,
) -> bytes:
    """Send the GET (or JSON POST) request and return the whole response body."""
    if post:
        request = session.post(url, json=post_data, **request_kwargs)
    else:
        request = session.get(url, **request_kwargs)
    async with request as resp:
        return await resp.read()


def _store_image_bytes(data: bytes, path: str | None) -> str:
    """Write ``data`` to ``path``, or to a fresh temp image when no path is given."""
    if not path:
        return save_temp_img(data)
    with open(path, "wb") as f:
        f.write(data)
    return path


async def download_image_by_url(
    url: str,
    post: bool = False,
    post_data: dict | None = None,
    path: str | None = None,
) -> str:
    """Download an image and return the path it was saved to.

    Args:
        url: Address of the image.
        post: When True, issue a POST with ``post_data`` as the JSON body
            instead of a GET.
        post_data: JSON payload for the POST request.
        path: Destination file. When empty or None the image is stored via
            ``save_temp_img`` and that generated path is returned.

    Returns:
        The path of the written file.

    The first attempt verifies TLS against the certifi CA bundle. If that
    fails with a certificate/SSL connector error, the request is retried once
    with verification disabled (a warning is logged). Any other exception
    propagates to the caller. The body is read completely before the target
    file is opened, so a failed transfer does not leave a truncated file.
    """
    try:
        # Verify against the CA bundle shipped with certifi.
        ssl_context = ssl.create_default_context(cafile=certifi.where())
        connector = aiohttp.TCPConnector(ssl=ssl_context)
        async with aiohttp.ClientSession(
            trust_env=True,
            connector=connector,
        ) as session:
            data = await _request_image_bytes(session, url, post, post_data)
    except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
        # Fallback used only when certificate verification failed.
        logger.warning(
            f"SSL certificate verification failed for {url}. "
            "Disabling SSL verification (CERT_NONE) as a fallback. "
            "This is insecure and exposes the application to man-in-the-middle attacks. "
            "Please investigate and resolve certificate issues."
        )
        ssl_context = ssl.create_default_context()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
        async with aiohttp.ClientSession() as session:
            data = await _request_image_bytes(
                session,
                url,
                post,
                post_data,
                ssl=ssl_context,
            )
    return _store_image_bytes(data, path)
async def _stream_response_to_file(
    resp,
    url: str,
    path: str,
    show_progress: bool,
) -> None:
    """Write the body of ``resp`` to ``path`` in 8 KiB chunks, optionally printing progress.

    Raises:
        Exception: if the response status is not 200.
    """
    if resp.status != 200:
        raise Exception(f"下载文件失败: {resp.status}")
    # 0 means the server did not announce a size.
    total_size = int(resp.headers.get("content-length", 0))
    downloaded_size = 0
    start_time = time.time()
    if show_progress:
        print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
    with open(path, "wb") as f:
        while chunk := await resp.content.read(8192):
            f.write(chunk)
            downloaded_size += len(chunk)
            if not show_progress:
                continue
            elapsed_time = time.time() - start_time
            if elapsed_time <= 0:
                # Guard the speed computation against a zero/negative interval.
                elapsed_time = 1
            speed = downloaded_size / 1024 / elapsed_time  # KB/s
            if total_size > 0:
                progress = f"{downloaded_size / total_size:.2%}"
            else:
                # Unknown total size: show the amount received instead of a
                # percentage (which would divide by zero).
                progress = f"{downloaded_size / 1024:.2f} KB"
            print(
                f"\r下载进度: {progress} 速度: {speed:.2f} KB/s",
                end="",
            )


async def download_file(url: str, path: str, show_progress: bool = False) -> None:
    """Download ``url`` and write it to ``path``.

    Args:
        url: Address of the file to download.
        path: Destination file path (overwritten if it exists).
        show_progress: When True, print size, progress and speed to stdout.

    Raises:
        Exception: if the server answers with a status other than 200
            (on both the verified and the unverified attempt), plus any
            network or file error raised while transferring.

    The first attempt verifies TLS against the certifi CA bundle with a
    1800 second total timeout. If it fails with a certificate/SSL connector
    error, the download is retried once without verification (warnings are
    logged) with a 120 second total timeout.
    """
    try:
        # Verify against the CA bundle shipped with certifi.
        ssl_context = ssl.create_default_context(cafile=certifi.where())
        connector = aiohttp.TCPConnector(ssl=ssl_context)
        async with aiohttp.ClientSession(
            trust_env=True,
            connector=connector,
        ) as session:
            async with session.get(
                url,
                timeout=aiohttp.ClientTimeout(total=1800),
            ) as resp:
                await _stream_response_to_file(resp, url, path, show_progress)
    except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
        # Fallback used only when certificate verification failed.
        logger.warning(
            "SSL 证书验证失败,已关闭 SSL 验证(不安全,仅用于临时下载)。请检查目标服务器的证书配置。"
        )
        logger.warning(
            f"SSL certificate verification failed for {url}. "
            "Falling back to unverified connection (CERT_NONE). "
            "This is insecure and exposes the application to man-in-the-middle attacks. "
            "Please investigate certificate issues with the remote server."
        )
        ssl_context = ssl.create_default_context()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
        async with aiohttp.ClientSession() as session:
            async with session.get(
                url,
                ssl=ssl_context,
                timeout=aiohttp.ClientTimeout(total=120),
            ) as resp:
                await _stream_response_to_file(resp, url, path, show_progress)
    if show_progress:
        # Terminate the "\r"-updated progress line.
        print()
def file_to_base64(file_path: str) -> str:
    """Return the contents of ``file_path`` as a ``base64://``-prefixed string."""
    with open(file_path, "rb") as src:
        encoded = base64.b64encode(src.read())
    return "base64://" + encoded.decode()
def get_local_ip_addresses() -> list[IPv4Address | IPv6Address]:
    """Collect the addresses reported by ``psutil.net_if_addrs()``.

    Every IPv4 address is included. IPv6 addresses are included unless they
    are link-local. Entries of any other address family are ignored.
    """
    collected: list[IPv4Address | IPv6Address] = []
    for iface_addrs in psutil.net_if_addrs().values():
        for entry in iface_addrs:
            if entry.family == socket.AF_INET:
                collected.append(ip_address(entry.address))
                continue
            if entry.family != socket.AF_INET6:
                continue
            # Strip a trailing "%zone" index before parsing.
            candidate = ip_address(entry.address.partition("%")[0])
            # Skip IPv6 link-local addresses (fe80::...).
            if candidate.is_link_local:
                continue
            collected.append(candidate)
    return collected
async def get_public_ip_address() -> list[IPv4Address | IPv6Address]:
    """Ask several public IP-echo services concurrently for this host's address.

    At most one address per IP version is kept: the first valid answer for
    that version. A failing service is logged at debug level and otherwise
    ignored, so the returned list may be empty.
    """
    endpoints = [
        "https://api64.ipify.org",
        "https://ident.me",
        "https://ifconfig.me",
        "https://icanhazip.com",
    ]
    by_version: dict[int, IPv4Address | IPv6Address] = {}

    async def probe(session: aiohttp.ClientSession, endpoint: str):
        try:
            async with session.get(endpoint, timeout=3) as resp:
                if resp.status != 200:
                    return
                body = await resp.text()
                address = ip_address(body.strip())
                # Keep only the first address seen for each IP version.
                by_version.setdefault(address.version, address)
        except Exception as exc:
            # One broken endpoint must not prevent the others from answering.
            logger.debug("Failed to fetch public IP from %s: %s", endpoint, exc)

    async with aiohttp.ClientSession() as session:
        await asyncio.gather(*[probe(session, endpoint) for endpoint in endpoints])

    # Return every discovered address object as a list.
    return list(by_version.values())
async def get_dashboard_version():
    """Return the version string of the available dashboard build, or None.

    The ``dist`` directory under the AstrBot data path is preferred. When it
    does not exist, the ``astrbot/dashboard/dist`` directory under the AstrBot
    path is used instead if present. The version is the stripped content of
    ``assets/version`` inside the chosen directory. None is returned when no
    directory or no version file is found.
    """
    dist_dir = os.path.join(get_astrbot_data_path(), "dist")
    if not os.path.exists(dist_dir):
        # No user-data dashboard: try the copy bundled with the installation.
        bundled_dist = Path(get_astrbot_path()) / "astrbot" / "dashboard" / "dist"
        if bundled_dist.exists():
            dist_dir = str(bundled_dist)

    if not os.path.exists(dist_dir):
        return None

    version_file = os.path.join(dist_dir, "assets", "version")
    if not os.path.exists(version_file):
        return None

    with open(version_file, encoding="utf-8") as fh:
        return fh.read().strip()
def _restore_dashboard_from_cache(cache_path: Path, zip_path: Path) -> bool:
    """Copy a valid cached archive to ``zip_path``; delete it and return False if corrupt."""
    if not cache_path.exists():
        return False
    logger.info(f"发现本地缓存的管理面板文件: {cache_path}")
    try:
        with zipfile.ZipFile(cache_path, "r") as z:
            corrupt_member = z.testzip()
    except zipfile.BadZipFile:
        logger.warning("缓存文件损坏 (BadZipFile),将重新下载。")
        os.remove(cache_path)
        return False
    # The archive handle is closed here, so the file can be removed on
    # platforms that refuse to delete open files.
    if corrupt_member is not None:
        logger.warning("缓存文件损坏,将重新下载。")
        os.remove(cache_path)
        return False
    logger.info("缓存文件校验通过,将直接使用缓存。")
    if str(cache_path) != str(zip_path):
        shutil.copy(cache_path, zip_path)
    return True


async def _download_dashboard_release(
    zip_path: Path,
    latest: bool,
    version: str | None,
    proxy: str | None,
) -> None:
    """Download a released build: registry mirror first, GitHub releases second."""
    ver_name = "latest" if latest else version
    dashboard_release_url = f"https://astrbot-registry.soulter.top/download/astrbot-dashboard/{ver_name}/dist.zip"
    logger.info(
        f"准备下载指定发行版本的 AstrBot WebUI 文件: {dashboard_release_url}",
    )
    try:
        await download_file(
            dashboard_release_url,
            str(zip_path),
            show_progress=True,
        )
        return
    except Exception as e:
        # Only ordinary errors trigger the fallback; cancellation and
        # KeyboardInterrupt are BaseException and must propagate.
        logger.warning(f"从镜像站下载失败: {e},尝试从 GitHub 下载。")

    if latest:
        dashboard_release_url = "https://github.com/AstrBotDevs/AstrBot/releases/latest/download/dist.zip"
    else:
        dashboard_release_url = f"https://github.com/AstrBotDevs/AstrBot/releases/download/{version}/dist.zip"
    if proxy:
        dashboard_release_url = f"{proxy}/{dashboard_release_url}"
    await download_file(
        dashboard_release_url,
        str(zip_path),
        show_progress=True,
    )


def _read_dashboard_zip_version(zip_path: Path) -> str | None:
    """Return the version recorded inside the archive, or None when it cannot be read."""
    try:
        with zipfile.ZipFile(zip_path, "r") as z:
            for v_path in ("dist/assets/version", "assets/version"):
                try:
                    with z.open(v_path) as f:
                        return f.read().decode("utf-8").strip() or None
                except KeyError:
                    # This layout is not present in the archive; try the next.
                    continue
    except Exception as e:
        # Best effort: an unreadable archive just means nothing is cached.
        logger.debug("Could not read dashboard version from %s: %s", zip_path, e)
    return None


def _store_dashboard_in_cache(
    zip_path: Path,
    cache_dir: Path,
    latest: bool,
    version: str | None,
) -> None:
    """Best-effort copy of the downloaded archive into the cache directory."""
    try:
        if not latest and version:
            cached_version = version
        else:
            # The version is unknown up front; read it from the archive.
            cached_version = _read_dashboard_zip_version(zip_path)
        if not cached_version:
            return
        cache_save_path = cache_dir / f"dashboard_{cached_version}.zip"
        if str(zip_path) != str(cache_save_path):
            shutil.copy(zip_path, cache_save_path)
            logger.info(f"已缓存管理面板文件至: {cache_save_path}")
    except Exception as e:
        logger.warning(f"缓存管理面板文件失败: {e}")


async def download_dashboard(
    path: str | None = None,
    extract_path: str = "data",
    latest: bool = True,
    version: str | None = None,
    proxy: str | None = None,
) -> None:
    """Download the dashboard (WebUI) archive and extract it.

    Args:
        path: Where to store the downloaded zip. Defaults to
            ``dashboard.zip`` inside the AstrBot data directory.
        extract_path: Directory the archive is extracted into.
        latest: Download the latest release instead of ``version``.
        version: Release tag, or a 40-character commit hash, to download
            when ``latest`` is False.
        proxy: Optional URL prefix put in front of GitHub download URLs.

    Behaviour:
        * A specific version is served from ``<data>/cache`` when a valid
          cached archive exists; a corrupt cache file is deleted.
        * Releases are fetched from the registry mirror first and from
          GitHub releases second. If both fail for a specific version, the
          latest release is downloaded instead.
        * A 40-character ``version`` is treated as a commit build and fetched
          from the release-harbour repository without any fallback.
        * After a successful download the archive is copied into the cache
          on a best-effort basis.

    Raises:
        Exception: whatever ``download_file`` raises once every applicable
            source has failed, and ``zipfile`` errors from the extraction.
    """
    data_dir = Path(get_astrbot_data_path()).absolute()
    zip_path = data_dir / "dashboard.zip" if path is None else Path(path).absolute()

    cache_dir = data_dir / "cache"
    cache_dir.mkdir(parents=True, exist_ok=True)

    if not latest and not version:
        # Nothing to pin: a ".../None/dist.zip" request can only fail and end
        # in the latest-version fallback, so go there directly.
        latest = True

    use_cache = False
    # The cache is keyed by version, which is unknown when asking for "latest".
    if not latest and version:
        use_cache = _restore_dashboard_from_cache(
            cache_dir / f"dashboard_{version}.zip",
            zip_path,
        )

    if not use_cache:
        if latest or len(str(version)) != 40:
            try:
                await _download_dashboard_release(zip_path, latest, version, proxy)
            except Exception as e:
                if latest:
                    raise
                logger.warning(
                    f"下载指定版本({version})失败: {e},尝试下载最新版本。"
                )
                # The nested call downloads, caches and extracts on its own.
                await download_dashboard(
                    path=path,
                    extract_path=extract_path,
                    latest=True,
                    proxy=proxy,
                )
                return
        else:
            url = f"https://github.com/AstrBotDevs/astrbot-release-harbour/releases/download/release-{version}/dist.zip"
            logger.info(f"准备下载指定版本的 AstrBot WebUI: {url}")
            if proxy:
                url = f"{proxy}/{url}"
            await download_file(url, str(zip_path), show_progress=True)

        # Store the fresh download in the cache for later runs.
        _store_dashboard_in_cache(zip_path, cache_dir, latest, version)

    with zipfile.ZipFile(zip_path, "r") as z:
        z.extractall(extract_path)