Files
AstrBot/util/updator.py
T
2024-07-07 20:59:12 +08:00

209 lines
7.2 KiB
Python

import sys, os, zipfile, shutil
import requests
import psutil
from type.config import VERSION
from SparkleLogging.utils.core import LogManager
from logging import Logger
from util.general_utils import download_file
# GitHub REST endpoint that lists AstrBot releases.
ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases"
# Mirror of the release list; its responses may be cached for 0-10 minutes.
MIRROR_ASTRBOT_RELEASE_API = "https://api.soulter.top/releases"

# Logger shared by the updater, registered under the core log name.
logger: Logger = LogManager.GetLogger(log_name="astrbot-core")
def get_main_path():
    """Return the absolute path of the parent of the directory holding this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.abspath(os.path.join(module_dir, ".."))
def terminate_child_processes():
    '''
    Terminate every descendant process of the current process.

    Each child is first sent ``terminate()``; all children then share a single
    3-second grace period, after which any survivor is ``kill()``-ed.
    A child that exits on its own at any point is skipped, so one vanished
    child no longer aborts the shutdown of the remaining ones.
    '''
    try:
        parent = psutil.Process(os.getpid())
        children = parent.children(recursive=True)
    except psutil.NoSuchProcess:
        return
    logger.info(f"正在终止 {len(children)} 个子进程。")
    for child in children:
        logger.info(f"正在终止子进程 {child.pid}")
        try:
            child.terminate()
        except psutil.NoSuchProcess:
            # Exited between listing and signalling; nothing left to do.
            continue
    # Wait for all children together instead of up to 3 s per child.
    _, alive = psutil.wait_procs(children, timeout=3)
    for child in alive:
        logger.info(f"子进程 {child.pid} 没有被正常终止, 正在强行杀死。")
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass
def _reboot():
    '''
    Replace the current process with a new run of the same interpreter and argv.

    Child processes are terminated first. ``os.execv`` does not return on
    success.
    '''
    interpreter = sys.executable
    terminate_child_processes()
    os.execv(interpreter, [interpreter, *sys.argv])
def request_release_info(latest: bool = True,
                         url: str = ASTRBOT_RELEASE_API,
                         mirror_url: str = MIRROR_ASTRBOT_RELEASE_API,
                         timeout: float = 10) -> list:
    '''
    Fetch release information, trying the mirror first and ``url`` second.

    :param latest: when True, only the newest release is parsed and returned.
    :param url: releases API endpoint used if the mirror request fails.
    :param mirror_url: endpoint tried first.
    :param timeout: per-request timeout in seconds, so an unreachable host
        cannot block the caller forever.
    :return: list of dicts produced by github_api_release_parser (version,
        published_at, body, commit_hash, tag_name, zipball_url); an empty
        list when the API returns no releases.
    :raises Exception: when the response cannot be parsed.
    '''
    try:
        result = requests.get(mirror_url, timeout=timeout).json()
    except Exception as e:
        # Network error or non-JSON body from the mirror: fall back to GitHub.
        logger.warning(f"镜像源请求失败, 尝试直接请求 GitHub: {e}")
        result = requests.get(url, timeout=timeout).json()
    try:
        if not result:
            return []
        releases = [result[0]] if latest else result
        ret = github_api_release_parser(releases)
    except Exception as e:
        logger.error(f"解析版本信息失败: {result}")
        raise Exception(f"解析版本信息失败: {result}") from e
    return ret
def github_api_release_parser(releases: list) -> list:
    '''
    Parse release objects returned by the GitHub releases API.

    Each output dict has the keys version, published_at, body, commit_hash,
    tag_name and zipball_url. By convention release names look like
    ``v3.0.7.xxxxxx``, where the fourth dot-separated part is the commit hash.
    Names without exactly four parts (or with no name at all) yield an empty
    commit_hash.
    '''
    parsed = []
    for release in releases:
        name = release['name']
        # The GitHub API may return a null release name; treat it as "" so
        # commit-hash extraction does not fail on None.split.
        parts = (name or '').split('.')
        commit_hash = parts[3] if len(parts) == 4 else ''
        parsed.append({
            "version": name,
            "published_at": release['published_at'],
            "body": release['body'],
            "commit_hash": commit_hash,
            "tag_name": release['tag_name'],
            "zipball_url": release['zipball_url'],
        })
    return parsed
def compare_version(v1: str, v2: str) -> int:
    '''
    Compare two version strings by their first three numeric components.

    A leading "v" is ignored. Missing components count as 0, so "v3.1" equals
    "v3.1.0". Only the leading digits of each component are used, so "7-beta"
    compares as 7. Components past the third, such as the commit hash in
    "v3.0.7.abc123", are not compared.

    Returns 1 if v1 > v2, -1 if v1 < v2, and 0 if they are equal.
    '''
    import re

    def _parse(v: str) -> list:
        # Normalize to exactly three integers for a lexicographic comparison.
        nums = []
        for part in v.strip().lstrip('vV').split('.')[:3]:
            m = re.match(r'\d+', part)
            nums.append(int(m.group()) if m else 0)
        nums += [0] * (3 - len(nums))
        return nums

    a, b = _parse(v1), _parse(v2)
    return (a > b) - (a < b)
def check_update() -> str:
    '''
    Check whether a release newer than the running VERSION is available.

    :return: "当前已经是最新版本。" when up to date; otherwise a Markdown
        summary of the current version and the latest release (version,
        publish time, release notes).
    :raises Exception: when no release information is available.
    '''
    update_data = request_release_info()
    if not update_data:
        # request_release_info returns [] for an empty API response; indexing
        # it would otherwise raise an opaque IndexError.
        raise Exception("未获取到版本信息。")
    latest = update_data[0]
    tag_name = latest['tag_name']
    logger.debug(f"当前版本: v{VERSION}")
    logger.debug(f"最新版本: {tag_name}")
    if compare_version(VERSION, tag_name) >= 0:
        return "当前已经是最新版本。"
    update_info = f"""# 当前版本
v{VERSION}
# 最新版本
{latest['version']}
# 发布时间
{latest['published_at']}
# 更新内容
---
{latest['body']}
---"""
    return update_info
def update_project(reboot: bool = False,
                   latest: bool = True,
                   version: str = ''):
    '''
    Download a release and install it over the project directory.

    :param reboot: restart the process after a successful install.
    :param latest: when True, install the newest release; otherwise install
        the release whose tag_name equals ``version``.
    :param version: tag name to install when ``latest`` is False.
    :raises Exception: when already up to date, no release information is
        available, or the requested version is not found. Download/extract
        errors propagate unchanged.
    '''
    update_data = request_release_info(latest)
    if latest:
        if not update_data:
            raise Exception("未获取到版本信息。")
        latest_release = update_data[0]
        if compare_version(VERSION, latest_release['tag_name']) >= 0:
            raise Exception("当前已经是最新版本。")
        _install_release(latest_release['zipball_url'], reboot)
        return
    # Update to a specific version: install the first release with that tag.
    logger.info(f"请求更新到指定版本: {version}")
    for data in update_data:
        if data['tag_name'] == version:
            _install_release(data['zipball_url'], reboot)
            return
    raise Exception("未找到指定版本。")


def _install_release(zipball_url: str, reboot: bool):
    '''Download ``zipball_url`` to temp.zip, unpack it over the project root, and optionally reboot.'''
    download_file(zipball_url, "temp.zip")
    unzip_file("temp.zip", get_main_path())
    if reboot:
        _reboot()
def unzip_file(zip_path: str, target_dir: str):
    '''
    Extract ``zip_path`` into ``target_dir``, then move the contents of the
    archive's **first** top-level folder up into ``target_dir``.

    - Directories in the update whose names are logs, data, configs or
      temp_plugins are not moved, so the local copies are kept.
    - Existing files and directories with the same name as an update entry
      are replaced.
    - addons/plugins is copied to ./temp_plugins (relative to the current
      working directory) before the move and restored afterwards.
    - The zip and the extracted folder are then removed on a best-effort
      basis; a failure there is only logged.
    '''
    os.makedirs(target_dir, exist_ok=True)
    logger.info(f"解压文件: {zip_path}")
    with zipfile.ZipFile(zip_path, 'r') as z:
        # Use only the first path component, e.g. "Owner-Repo-abc123/" -> "Owner-Repo-abc123".
        update_dir = z.namelist()[0].split('/', 1)[0]
        z.extractall(target_dir)
    avoid_dirs = ["logs", "data", "configs", "temp_plugins", update_dir]
    extracted_root = os.path.join(target_dir, update_dir)
    plugins_dir = os.path.join(target_dir, "addons", "plugins")
    backup_dir = "temp_plugins"

    # Back up user plugins, because replacing "addons" below removes them.
    if os.path.exists(plugins_dir):
        logger.info("备份插件目录:从 addons/plugins 到 temp_plugins")
        # dirs_exist_ok: a temp_plugins left by an interrupted update may hold
        # the only copy of the user's plugins, so merge into it instead of failing.
        shutil.copytree(plugins_dir, backup_dir, dirs_exist_ok=True)

    for f in os.listdir(extracted_root):
        logger.info(f"移动更新文件/目录: {f}")
        src = os.path.join(extracted_root, f)
        dst = os.path.join(target_dir, f)
        if os.path.isdir(src):
            if f in avoid_dirs:
                continue
            if os.path.exists(dst):
                shutil.rmtree(dst, onerror=on_error)
        elif os.path.exists(dst):
            os.remove(dst)
        shutil.move(src, target_dir)

    # Move the plugin backup back into place.
    if os.path.exists(backup_dir):
        logger.info("恢复插件目录:从 temp_plugins 到 addons/plugins")
        if os.path.exists(plugins_dir):
            shutil.rmtree(plugins_dir, onerror=on_error)
        # The update may not ship an addons/ directory at all.
        os.makedirs(os.path.dirname(plugins_dir), exist_ok=True)
        shutil.move(backup_dir, plugins_dir)

    try:
        logger.info(f"删除临时更新文件: {zip_path}, {extracted_root}")
        shutil.rmtree(extracted_root, onerror=on_error)
        os.remove(zip_path)
    except Exception:
        # Best-effort cleanup: the update itself already succeeded.
        logger.warning(f"删除更新文件失败,可以手动删除 {zip_path}, {extracted_root}")
def on_error(func, path, exc_info):
    '''
    ``onerror`` callback for ``shutil.rmtree``.

    :param func: the os function that failed (e.g. os.remove, os.rmdir).
    :param path: the path it failed on.
    :param exc_info: (type, value, traceback) of the failure.

    If ``path`` is not writable (for example, a file marked read-only), the
    owner-write bit is added and ``func`` is retried once. Otherwise the
    original exception from ``exc_info`` is re-raised.
    '''
    import stat
    print(f"remove {path} failed.")
    if not os.access(path, os.W_OK):
        # Add the write bit rather than replacing the whole mode, so directories
        # keep their read/execute bits and can still be traversed.
        os.chmod(path, os.stat(path).st_mode | stat.S_IWUSR)
        func(path)
    else:
        # Re-raise explicitly: a bare ``raise`` only works while the caller's
        # exception is still being handled.
        raise exc_info[1]