209 lines
7.2 KiB
Python
209 lines
7.2 KiB
Python
import sys, os, zipfile, shutil
|
|
import requests
|
|
import psutil
|
|
from type.config import VERSION
|
|
from SparkleLogging.utils.core import LogManager
|
|
from logging import Logger
|
|
|
|
from util.general_utils import download_file
|
|
|
|
# Module-level logger, obtained from LogManager under the name 'astrbot-core'.
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
|
|
|
|
# GitHub REST API endpoint for the releases of the Soulter/AstrBot repository.
ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases"
|
|
MIRROR_ASTRBOT_RELEASE_API = "https://api.soulter.top/releases" # responses are cached for 0-10 minutes
|
|
|
|
def get_main_path():
    '''
    Return the absolute path of the directory one level above the
    directory that contains this file.
    '''
    here = os.path.dirname(os.path.abspath(__file__))
    return os.path.abspath(os.path.join(here, os.pardir))
|
|
|
|
def terminate_child_processes():
    '''
    Terminate every descendant process of the current process.

    Each child is first asked to exit with terminate(); if it has not
    exited within 3 seconds it is killed. A child that has already exited
    is skipped, and the remaining children are still processed.
    '''
    try:
        parent = psutil.Process(os.getpid())
        children = parent.children(recursive=True)
    except psutil.NoSuchProcess:
        return

    logger.info(f"正在终止 {len(children)} 个子进程。")
    for child in children:
        logger.info(f"正在终止子进程 {child.pid}")
        try:
            # terminate() is inside the handler so that a child which has
            # already exited does not abort the loop for its siblings.
            child.terminate()
            child.wait(timeout=3)
        except psutil.NoSuchProcess:
            continue
        except psutil.TimeoutExpired:
            logger.info(f"子进程 {child.pid} 没有被正常终止, 正在强行杀死。")
            try:
                child.kill()
            except psutil.NoSuchProcess:
                # The child exited between the timeout and the kill.
                pass
|
|
|
|
def _reboot():
    '''
    Restart the program in place.

    Terminates all child processes, then replaces the current process with
    a new run of the same interpreter using the original command-line
    arguments. This function does not return on success.
    '''
    interpreter = sys.executable
    terminate_child_processes()
    os.execv(interpreter, [interpreter, *sys.argv])
|
|
|
|
def request_release_info(latest: bool = True, url: str = ASTRBOT_RELEASE_API, mirror_url: str = MIRROR_ASTRBOT_RELEASE_API, timeout: float = 10.0) -> list:
    '''
    Request release information.

    The mirror is queried first; if that request fails for any reason the
    GitHub API at `url` is queried instead (a failure of that second
    request propagates to the caller).

    Args:
        latest: when True, only the first release in the response is parsed.
        url: release API endpoint used as the fallback.
        mirror_url: release API endpoint tried first.
        timeout: seconds passed to requests as the request timeout, so a
            stalled connection cannot block forever.

    Returns:
        A list of dictionaries, each containing the version, publish time,
        release notes, commit hash, tag name and zipball URL. An empty list
        is returned when the response is empty.

    Raises:
        Exception: if the response cannot be parsed.
    '''
    try:
        result = requests.get(mirror_url, timeout=timeout).json()
    except Exception as e:
        # Exception (not BaseException) so Ctrl-C / SystemExit still propagate.
        logger.warning(f"从镜像获取版本信息失败, 改用 GitHub API: {e}")
        result = requests.get(url, timeout=timeout).json()

    if not result:
        return []

    try:
        releases = [result[0]] if latest else result
        return github_api_release_parser(releases)
    except Exception as e:
        logger.error(f"解析版本信息失败: {result}")
        raise Exception(f"解析版本信息失败: {result}") from e
|
|
|
|
def github_api_release_parser(releases: list) -> list:
    '''
    Parse the release objects returned by the GitHub API.

    Returns a list of dictionaries, each containing the version, publish
    time, release notes, commit hash, tag name and zipball URL.
    '''
    def _simplify(item: dict) -> dict:
        name = item['name']
        # Naming convention: v3.0.7.xxxxxx, where xxxxxx is the commit hash.
        pieces = name.split(".")
        commit_hash = pieces[3] if len(pieces) == 4 else ''
        return {
            "version": name,
            "published_at": item['published_at'],
            "body": item['body'],
            "commit_hash": commit_hash,
            "tag_name": item['tag_name'],
            "zipball_url": item['zipball_url'],
        }

    return [_simplify(item) for item in releases]
|
|
|
|
def compare_version(v1: str, v2: str) -> int:
    '''
    Compare two version strings.

    Only the first three dot-separated components are compared; anything
    after them (such as a trailing commit hash) is ignored. Every 'v'
    character is removed before parsing. A version with fewer than three
    components is treated as if the missing ones were 0, so "v3.1" equals
    "v3.1.0".

    Returns 1 if v1 > v2, -1 if v1 < v2 and 0 if they are equal.

    Raises:
        ValueError: if a component that has to be compared is not an integer.
    '''
    parts1 = v1.replace('v', '').split('.')
    parts2 = v2.replace('v', '').split('.')

    for i in range(3):
        # Components are parsed lazily, so a non-numeric later component is
        # never touched once an earlier one has decided the result.
        n1 = int(parts1[i]) if i < len(parts1) else 0
        n2 = int(parts2[i]) if i < len(parts2) else 0
        if n1 != n2:
            return 1 if n1 > n2 else -1
    return 0
|
|
|
|
def check_update() -> str:
    '''
    Check whether a release newer than the running version is available.

    Returns a human-readable string: a short notice when no release
    information was returned or when the current version is already the
    latest; otherwise a report with the current version, the latest
    version, its publish time and its release notes.
    '''
    update_data = request_release_info()
    if not update_data:
        # request_release_info() returns [] for an empty response.
        return "未获取到版本信息。"

    newest = update_data[0]
    tag_name = newest['tag_name']
    logger.debug(f"当前版本: v{VERSION}")
    logger.debug(f"最新版本: {tag_name}")

    if compare_version(VERSION, tag_name) >= 0:
        return "当前已经是最新版本。"

    update_info = f"""# 当前版本
v{VERSION}

# 最新版本
{newest['version']}

# 发布时间
{newest['published_at']}

# 更新内容
---
{newest['body']}
---"""
    return update_info
|
|
|
|
def update_project(reboot: bool = False,
                   latest: bool = True,
                   version: str = ''):
    '''
    Download a release and install it over the current project files.

    Args:
        reboot: when True, restart the program after the files are installed.
        latest: when True, update to the newest release; otherwise update to
            the release whose tag name equals `version`.
        version: tag name of the release to install when `latest` is False.

    Raises:
        Exception: if no release information was returned, if the running
            version is already the latest, or if the requested version is
            not found. Download and extraction errors propagate unchanged.
    '''
    update_data = request_release_info(latest)

    if latest:
        if not update_data:
            # request_release_info() returns [] for an empty response.
            raise Exception("未获取到版本信息。")
        newest = update_data[0]
        if compare_version(VERSION, newest['tag_name']) >= 0:
            raise Exception("当前已经是最新版本。")
        _download_and_install(newest['zipball_url'], reboot)
        return

    # Update to an explicitly requested version.
    logger.info(f"请求更新到指定版本: {version}")
    for data in update_data:
        if data['tag_name'] == version:
            _download_and_install(data['zipball_url'], reboot)
            return
    raise Exception("未找到指定版本。")


def _download_and_install(zipball_url: str, reboot: bool):
    '''Download the release archive to temp.zip, unpack it into the project root and optionally reboot.'''
    download_file(zipball_url, "temp.zip")
    unzip_file("temp.zip", get_main_path())
    if reboot:
        _reboot()
|
|
|
|
def unzip_file(zip_path: str, target_dir: str):
    '''
    Extract an update archive and move the contents of its **first**
    entry (taken to be the top-level folder) into target_dir.

    Steps:
      1. Extract the whole archive into target_dir.
      2. If target_dir/addons/plugins exists, copy it to "temp_plugins"
         (a path relative to the current working directory).
      3. Move every entry of the extracted folder into target_dir,
         replacing an existing file or directory of the same name.
         Directories named logs, data, configs or temp_plugins are skipped.
      4. If "temp_plugins" exists, put it back as target_dir/addons/plugins.
      5. Best-effort removal of the extracted folder and the archive.

    Args:
        zip_path: path of the archive to extract.
        target_dir: directory that receives the updated files; created if
            it does not exist.
    '''
    os.makedirs(target_dir, exist_ok=True)
    logger.info(f"解压文件: {zip_path}")
    with zipfile.ZipFile(zip_path, 'r') as z:
        update_dir = z.namelist()[0]
        z.extractall(target_dir)

    extracted_root = os.path.join(target_dir, update_dir)
    plugins_dir = os.path.join(target_dir, "addons/plugins")
    avoid_dirs = ["logs", "data", "configs", "temp_plugins", update_dir]

    # Back up the plugin directory before the update overwrites it.
    # NOTE(review): copytree fails if a "temp_plugins" left over from an
    # earlier interrupted update still exists; confirm how that should be
    # recovered before deleting it automatically.
    if os.path.exists(plugins_dir):
        logger.info("备份插件目录:从 addons/plugins 到 temp_plugins")
        shutil.copytree(plugins_dir, "temp_plugins")

    for f in os.listdir(extracted_root):
        logger.info(f"移动更新文件/目录: {f}")
        source = os.path.join(extracted_root, f)
        destination = os.path.join(target_dir, f)
        if os.path.isdir(source):
            if f in avoid_dirs:
                continue
            if os.path.exists(destination):
                shutil.rmtree(destination, onerror=on_error)
        elif os.path.exists(destination):
            os.remove(destination)
        shutil.move(source, target_dir)

    # Restore the plugin backup.
    if os.path.exists("temp_plugins"):
        logger.info("恢复插件目录:从 temp_plugins 到 addons/plugins")
        # The new release may not ship addons/plugins at all; only remove
        # the directory when it is actually there.
        if os.path.exists(plugins_dir):
            shutil.rmtree(plugins_dir, onerror=on_error)
        shutil.move("temp_plugins", plugins_dir)

    try:
        logger.info(f"删除临时更新文件: {zip_path} 和 {extracted_root}")
        shutil.rmtree(extracted_root, onerror=on_error)
        os.remove(zip_path)
    except Exception:
        # Cleanup is best-effort: the update itself has already succeeded.
        logger.warning(f"删除更新文件失败,可以手动删除 {zip_path} 和 {extracted_root}")
|
|
|
|
def on_error(func, path, exc_info):
    '''
    Error callback for shutil.rmtree.

    If the path is not writable, give the owner write permission and retry
    the failed operation; otherwise re-raise the exception currently being
    handled.
    '''
    print(f"remove {path} failed.")
    import stat
    if os.access(path, os.W_OK):
        # Not a permission problem: propagate the active exception.
        raise
    os.chmod(path, stat.S_IWUSR)
    func(path)