feat: 使用压缩包文件的更新方式

This commit is contained in:
Soulter
2024-07-07 18:26:58 +08:00
parent 465d283cad
commit 77067c545c
6 changed files with 236 additions and 110 deletions
+1 -2
View File
@@ -288,8 +288,7 @@ class AstrBotDashBoard():
else:
latest = False
try:
update_project(request_release_info(latest),
latest=latest, version=version)
update_project(latest=latest, version=version)
threading.Thread(target=self.shutdown_bot, args=(3,)).start()
return Response(
status="success",
+2 -6
View File
@@ -274,8 +274,7 @@ class Command:
else:
if l[1] == "latest":
try:
release_data = util.updator.request_release_info()
util.updator.update_project(release_data)
util.updator.update_project()
return True, "更新成功,重启生效。可输入「update r」重启", "update"
except BaseException as e:
return False, "更新失败: "+str(e), "update"
@@ -284,10 +283,7 @@ class Command:
else:
if l[1].lower().startswith('v'):
try:
release_data = util.updator.request_release_info(
latest=False)
util.updator.update_project(
release_data, latest=False, version=l[1])
util.updator.update_project(latest=False, version=l[1])
return True, "更新成功,重启生效。可输入「update r」重启", "update"
except BaseException as e:
return False, "更新失败: "+str(e), "update"
-1
View File
@@ -5,7 +5,6 @@ openai~=1.2.3
qq-botpy
chardet~=5.1.0
Pillow
GitPython
nakuru-project
beautifulsoup4
googlesearch-python
+38 -3
View File
@@ -5,12 +5,13 @@ import re
import requests
import aiohttp
import socket
import platform
import json
import sys
import psutil
import ssl
import base64
import zipfile
import shutil
import stat
from PIL import Image, ImageDraw, ImageFont
from type.types import GlobalObject
@@ -362,7 +363,20 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
return save_temp_img(await resp.read())
except Exception as e:
raise e
def download_file(url: str, path: str):
'''
从指定 url 下载文件到指定路径 path
'''
try:
logger.info(f"下载文件: {url}")
with requests.get(url, stream=True) as r:
with open(path, 'wb') as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
except Exception as e:
raise e
def create_markdown_image(text: str):
'''
@@ -469,3 +483,24 @@ def run_monitor(global_object: GlobalObject):
}
stat['sys_start_time'] = start_time
time.sleep(30)
def remove_dir(file_path) -> bool:
if not os.path.exists(file_path): return True
try:
shutil.rmtree(file_path, onerror=on_error)
return True
except BaseException as e:
logger.error(f"删除文件/文件夹 {file_path} 失败: {str(e)}")
return False
def on_error(func, path, exc_info):
'''
a callback of the rmtree function.
'''
print(f"remove {path} failed.")
import stat
if not os.access(path, os.W_OK):
os.chmod(path, stat.S_IWUSR)
func(path)
else:
raise
+92 -34
View File
@@ -1,22 +1,18 @@
'''
插件工具函数
'''
import os, sys
import os, sys, zipfile, shutil
import inspect
import shutil
import stat
import traceback
try:
from git.repo import Repo
except ImportError:
pass
from types import ModuleType
from type.plugin import *
from type.register import *
from SparkleLogging.utils.core import LogManager
from logging import Logger
from type.types import GlobalObject
from util.general_utils import download_file, remove_dir
from util.updator import request_release_info
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
@@ -61,7 +57,7 @@ def get_modules(path):
def get_plugin_store_path():
plugin_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../addons/plugins"))
plugin_dir = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../addons/plugins"))
return plugin_dir
def get_plugin_modules():
@@ -182,22 +178,44 @@ def update_plugin_dept(path):
def install_plugin(repo_url: str, ctx: GlobalObject):
ppath = get_plugin_store_path()
# 删除末尾的 /
if repo_url.endswith("/"):
repo_url = repo_url[:-1]
# 得到 url 的最后一段
d = repo_url.split("/")[-1]
# 转换非法字符:-
d = d.replace("-", "_")
d = d.lower() # 转换为小写
# 创建文件夹
plugin_path = os.path.join(ppath, d)
if os.path.exists(plugin_path):
remove_dir(plugin_path)
Repo.clone_from(repo_url, to_path=plugin_path, branch='master')
repo_namespace = repo_url.split("/")[-2:]
repo = repo_namespace[1]
plugin_path = os.path.join(ppath, repo.replace("-", "_").lower())
if os.path.exists(plugin_path): remove_dir(plugin_path)
# we no longer use Git anymore :)
# Repo.clone_from(repo_url, to_path=plugin_path, branch='master')
download_from_repo_url(plugin_path, repo_url)
unzip_file(plugin_path + ".zip", plugin_path)
with open(os.path.join(plugin_path, "REPO"), "w") as f:
f.write(repo_url)
ok, err = plugin_reload(ctx)
if not ok:
raise Exception(err)
def download_from_repo_url(target_path: str, repo_url: str):
repo_namespace = repo_url.split("/")[-2:]
author = repo_namespace[0]
repo = repo_namespace[1]
logger.info(f"正在下载插件 {repo} ...")
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
releases = request_release_info(latest=True, url=release_url, mirror_url=release_url)
if not releases:
# download from the default branch directly.
logger.warn(f"未在插件 {author}/{repo} 中找到任何发布版本,将从默认分支下载。")
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
else:
release_url = releases[0]['zipball_url']
download_file(release_url, target_path + ".zip")
def get_registered_plugin(plugin_name: str, cached_plugins: RegisteredPlugins) -> RegisteredPlugin:
@@ -227,23 +245,63 @@ def update_plugin(plugin_name: str, ctx: GlobalObject):
ppath = get_plugin_store_path()
root_dir_name = plugin.root_dir_name
plugin_path = os.path.join(ppath, root_dir_name)
repo = Repo(path=plugin_path)
repo.remotes.origin.pull()
if not os.path.exists(os.path.join(plugin_path, "REPO")):
raise Exception("插件更新信息文件 `REPO` 不存在,请手动升级,或者先卸载然后重新安装该插件。")
repo_url = None
with open(os.path.join(plugin_path, "REPO"), "r") as f:
repo_url = f.read()
download_from_repo_url(plugin_path, repo_url)
try:
remove_dir(plugin_path)
except BaseException as e:
logger.error(f"删除旧版本插件 {plugin_name} 文件夹失败: {str(e)},使用覆盖安装。")
unzip_file(plugin_path + ".zip", plugin_path)
ok, err = plugin_reload(ctx)
if not ok:
raise Exception(err)
def unzip_file(zip_path: str, target_dir: str):
'''
解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir
'''
os.makedirs(target_dir, exist_ok=True)
update_dir = ""
logger.info(f"解压文件: {zip_path}")
with zipfile.ZipFile(zip_path, 'r') as z:
update_dir = z.namelist()[0]
z.extractall(target_dir)
def remove_dir(file_path) -> bool:
try_cnt = 50
while try_cnt > 0:
if not os.path.exists(file_path):
return False
try:
shutil.rmtree(file_path)
return True
except PermissionError as e:
err_file_path = str(e).split("\'", 2)[1]
if os.path.exists(err_file_path):
os.chmod(err_file_path, stat.S_IWUSR)
try_cnt -= 1
files = os.listdir(os.path.join(target_dir, update_dir))
for f in files:
logger.info(f"移动更新文件/目录: {f}")
if os.path.isdir(os.path.join(target_dir, update_dir, f)):
if os.path.exists(os.path.join(target_dir, f)):
shutil.rmtree(os.path.join(target_dir, f), onerror=on_error)
else:
if os.path.exists(os.path.join(target_dir, f)):
os.remove(os.path.join(target_dir, f))
shutil.move(os.path.join(target_dir, update_dir, f), target_dir)
try:
logger.info(f"删除临时更新文件: {zip_path}{os.path.join(target_dir, update_dir)}")
shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error)
os.remove(zip_path)
except:
logger.warn(f"删除更新文件失败,可以手动删除 {zip_path}{os.path.join(target_dir, update_dir)}")
def on_error(func, path, exc_info):
'''
a callback of the rmtree function.
'''
print(f"remove {path} failed.")
import stat
if not os.access(path, os.W_OK):
os.chmod(path, stat.S_IWUSR)
func(path)
else:
raise
+103 -64
View File
@@ -1,18 +1,20 @@
has_git = True
try:
import git.exc
from git.repo import Repo
except BaseException as e:
has_git = False
import sys, os
import sys, os, zipfile, shutil
import requests
import psutil
from type.config import VERSION
from SparkleLogging.utils.core import LogManager
from logging import Logger
from util.general_utils import download_file
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases"
MIRROR_ASTRBOT_RELEASE_API = "https://api.soulter.top/releases" # 0-10 分钟的缓存时间
def get_main_path():
ret = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
return ret
def terminate_child_processes():
try:
@@ -37,40 +39,23 @@ def _reboot():
terminate_child_processes()
os.execl(py, py, *sys.argv)
def find_repo() -> Repo:
if not has_git:
raise Exception("未安装 GitPython 库,无法进行更新。")
repo = None
# 由于项目更名过,因此这里需要多次尝试。
try:
repo = Repo()
except git.exc.InvalidGitRepositoryError:
try:
repo = Repo(path="QQChannelChatGPT")
except git.exc.InvalidGitRepositoryError:
repo = Repo(path="AstrBot")
if not repo:
raise Exception("在已知的目录下未找到项目位置。请联系项目维护者。")
return repo
def request_release_info(latest: bool = True) -> list:
def request_release_info(latest: bool = True, url: str = ASTRBOT_RELEASE_API, mirror_url: str = MIRROR_ASTRBOT_RELEASE_API) -> list:
'''
请求版本信息。
返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。
'''
api_url1 = "https://api.github.com/repos/Soulter/AstrBot/releases"
api_url2 = "https://api.soulter.top/releases" # 0-10 分钟的缓存时间
try:
result = requests.get(api_url2).json()
result = requests.get(mirror_url).json()
except BaseException as e:
result = requests.get(api_url1).json()
result = requests.get(url).json()
try:
if not result: return []
if latest:
ret = github_api_release_parser([result[0]])
else:
ret = github_api_release_parser(result)
except BaseException as e:
logger.error(f"解析版本信息失败: {result}")
raise Exception(f"解析版本信息失败: {result}")
return ret
@@ -92,22 +77,38 @@ def github_api_release_parser(releases: list) -> list:
"published_at": release['published_at'],
"body": release['body'],
"commit_hash": commit_hash,
"tag_name": release['tag_name']
"tag_name": release['tag_name'],
"zipball_url": release['zipball_url']
})
return ret
def compare_version(v1: str, v2: str) -> int:
'''
比较两个版本号的大小。
返回 1 表示 v1 > v2,返回 -1 表示 v1 < v2,返回 0 表示 v1 = v2。
'''
v1 = v1.replace('v', '')
v2 = v2.replace('v', '')
v1 = v1.split('.')
v2 = v2.split('.')
for i in range(3):
if int(v1[i]) > int(v2[i]):
return 1
elif int(v1[i]) < int(v2[i]):
return -1
return 0
def check_update() -> str:
repo = find_repo()
curr_commit = repo.commit().hexsha
update_data = request_release_info()
new_commit = update_data[0]['commit_hash']
print(f"当前版本: {curr_commit}")
print(f"最新版本: {new_commit}")
if curr_commit.startswith(new_commit):
return f"当前已经是最新版本: v{VERSION}"
else:
update_info = f"""> 有新版本可用,请及时更新。
# 当前版本
tag_name = update_data[0]['tag_name']
logger.debug(f"当前版本: v{VERSION}")
logger.debug(f"最新版本: {tag_name}")
if compare_version(VERSION, tag_name) >= 0:
return "当前已经是最新版本。"
update_info = f"""# 当前版本
v{VERSION}
# 最新版本
@@ -120,27 +121,20 @@ v{VERSION}
---
{update_data[0]['body']}
---"""
return update_info
return update_info
def update_project(update_data: list,
reboot: bool = False,
def update_project(reboot: bool = False,
latest: bool = True,
version: str = ''):
repo = find_repo()
# update_data = request_release_info(latest)
update_data = request_release_info(latest)
if latest:
# 检查本地commit和最新commit是否一致
curr_commit = repo.head.commit.hexsha
new_commit = update_data[0]['commit_hash']
if curr_commit == '':
raise Exception("无法获取当前版本号对应的版本位置。请联系项目维护者。")
if curr_commit.startswith(new_commit):
latest_version = update_data[0]['tag_name']
if compare_version(VERSION, latest_version) >= 0:
raise Exception("当前已经是最新版本。")
else:
# 更新到最新版本对应的commit
try:
repo.git.fetch()
repo.git.checkout(update_data[0]['tag_name'], "-f")
download_file(update_data[0]['zipball_url'], "temp.zip")
unzip_file("temp.zip", get_main_path())
if reboot: _reboot()
except BaseException as e:
raise e
@@ -151,21 +145,66 @@ def update_project(update_data: list,
for data in update_data:
if data['tag_name'] == version:
try:
repo.git.fetch()
repo.git.checkout(data['tag_name'], "-f")
download_file(data['zipball_url'], "temp.zip")
unzip_file("temp.zip", get_main_path())
flag = True
if reboot: _reboot()
except BaseException as e:
raise e
if not flag:
raise Exception("未找到指定版本。")
def unzip_file(zip_path: str, target_dir: str):
'''
解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir
'''
os.makedirs(target_dir, exist_ok=True)
update_dir = ""
logger.info(f"解压文件: {zip_path}")
with zipfile.ZipFile(zip_path, 'r') as z:
update_dir = z.namelist()[0]
z.extractall(target_dir)
avoid_dirs = ["logs", "data", "configs", "temp_plugins", update_dir]
# copy addons/plugins to the target_dir temporarily
if os.path.exists(os.path.join(target_dir, "addons/plugins")):
logger.info("备份插件目录:从 addons/plugins 到 temp_plugins")
shutil.copytree(os.path.join(target_dir, "addons/plugins"), "temp_plugins")
files = os.listdir(os.path.join(target_dir, update_dir))
for f in files:
logger.info(f"移动更新文件/目录: {f}")
if os.path.isdir(os.path.join(target_dir, update_dir, f)):
if f in avoid_dirs: continue
if os.path.exists(os.path.join(target_dir, f)):
shutil.rmtree(os.path.join(target_dir, f), onerror=on_error)
else:
if os.path.exists(os.path.join(target_dir, f)):
os.remove(os.path.join(target_dir, f))
shutil.move(os.path.join(target_dir, update_dir, f), target_dir)
# move back
if os.path.exists("temp_plugins"):
logger.info("恢复插件目录:从 temp_plugins 到 addons/plugins")
shutil.rmtree(os.path.join(target_dir, "addons/plugins"), onerror=on_error)
shutil.move("temp_plugins", os.path.join(target_dir, "addons/plugins"))
def checkout_branch(branch_name: str):
repo = find_repo()
try:
repo.git.fetch()
repo.git.checkout(branch_name, "-f")
repo.git.pull("origin", branch_name, "-f")
return True
except BaseException as e:
raise e
logger.info(f"删除临时更新文件: {zip_path}{os.path.join(target_dir, update_dir)}")
shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error)
os.remove(zip_path)
except:
logger.warn(f"删除更新文件失败,可以手动删除 {zip_path}{os.path.join(target_dir, update_dir)}")
def on_error(func, path, exc_info):
'''
a callback of the rmtree function.
'''
print(f"remove {path} failed.")
import stat
if not os.access(path, os.W_OK):
os.chmod(path, stat.S_IWUSR)
func(path)
else:
raise