perf: 优化webui和主程序更新的协调

fix: 修复某些请求不能正确应用代理的问题
This commit is contained in:
Soulter
2025-01-21 01:08:15 +08:00
parent 529cd64d82
commit 5dd1488b5d
15 changed files with 67 additions and 50 deletions
@@ -18,7 +18,7 @@ class SimpleGoogleGenAIClient():
self.api_base = api_base[:-1]
else:
self.api_base = api_base
self.client = aiohttp.ClientSession()
self.client = aiohttp.ClientSession(trust_env=True)
async def models_list(self) -> List[str]:
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
+1 -1
View File
@@ -11,7 +11,7 @@ class AstrBotUpdator(RepoZipUpdator):
def __init__(self, repo_mirror: str = "") -> None:
super().__init__(repo_mirror)
self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../"))
self.ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases"
self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases"
def terminate_child_processes(self):
try:
+6 -6
View File
@@ -70,7 +70,7 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
下载图片, 返回 path
'''
try:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
if post:
async with session.post(url, json=post_data) as resp:
if not path:
@@ -91,7 +91,7 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
# 关闭SSL验证
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers('DEFAULT')
async with aiohttp.ClientSession(trust_env=False) as session:
async with aiohttp.ClientSession() as session:
if post:
async with session.get(url, ssl=ssl_context) as resp:
return save_temp_img(await resp.read())
@@ -106,8 +106,8 @@ async def download_file(url: str, path: str, show_progress: bool = False):
从指定 url 下载文件到指定路径 path
'''
try:
async with aiohttp.ClientSession() as session:
async with session.get(url, timeout=300) as resp:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url, timeout=120) as resp:
if resp.status != 200:
raise Exception(f"下载文件失败: {resp.status}")
total_size = int(resp.headers.get('content-length', 0))
@@ -130,8 +130,8 @@ async def download_file(url: str, path: str, show_progress: bool = False):
# 关闭SSL验证
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers('DEFAULT')
async with aiohttp.ClientSession(trust_env=False) as session:
async with session.get(url, ssl=ssl_context, timeout=300) as resp:
async with aiohttp.ClientSession() as session:
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
total_size = int(resp.headers.get('content-length', 0))
downloaded_size = 0
start_time = time.time()
+1 -1
View File
@@ -30,7 +30,7 @@ class Metric():
pass
try:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(base_url, json=payload, timeout=3) as response:
if response.status != 200:
pass
+1 -1
View File
@@ -83,7 +83,7 @@ class LocalRenderStrategy(RenderStrategy):
try:
image_url = re.findall(IMAGE_REGEX, line)[0]
print(image_url)
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(image_url) as resp:
image_res = Image.open(BytesIO(await resp.read()))
images[i] = image_res
+1 -1
View File
@@ -33,7 +33,7 @@ class NetworkRenderStrategy(RenderStrategy):
}
}
if return_url:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(f"{self.BASE_RENDER_URL}/generate", json=post_data) as resp:
ret = await resp.json()
return f"{self.BASE_RENDER_URL}/{ret['data']['id']}"
+2 -2
View File
@@ -29,7 +29,7 @@ class RepoZipUpdator():
返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。
'''
try:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url) as response:
result = await response.json()
if not result:
@@ -111,7 +111,7 @@ class RepoZipUpdator():
releases = await self.fetch_release_info(url=release_url)
if not releases:
# download from the default branch directly.
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
logger.info(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
else:
release_url = releases[0]['zipball_url']
+1 -1
View File
@@ -27,7 +27,7 @@ class PluginRoute(Route):
async def get_online_plugins(self):
url = "https://soulter.github.io/AstrBot_Plugins_Collection/plugins.json"
try:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url) as response:
result = await response.json()
return Response().ok(result).__dict__
-11
View File
@@ -15,7 +15,6 @@ class StatRoute(Route):
self.routes = {
'/stat/get': ('GET', self.get_stat),
'/stat/version': ('GET', self.get_version),
'/stat/dashboard-version': ('GET', self.get_dashboard_version),
'/stat/start-time': ('GET', self.get_start_time),
'/stat/restart-core': ('GET', self.restart_core)
}
@@ -37,16 +36,6 @@ class StatRoute(Route):
"version": VERSION
}).__dict__
async def get_dashboard_version(self):
async with aiohttp.ClientSession() as session:
async with session.get('https://api.github.com/repos/Soulter/Astrbot-dashboard/actions/artifacts') as resp:
data = await resp.json()
return Response().ok({
"data": data,
"mark": "unimplemented feature"
}).__dict__
async def get_start_time(self):
return Response().ok({
"start_time": self.core_lifecycle.start_time
+26 -8
View File
@@ -4,6 +4,8 @@ from .route import Route, Response, RouteContext
from quart import request
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger, pip_installer
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
from astrbot.core.config.default import VERSION
class UpdateRoute(Route):
def __init__(self, context: RouteContext, astrbot_updator: AstrBotUpdator) -> None:
@@ -17,15 +19,24 @@ class UpdateRoute(Route):
self.register_routes()
async def check_update(self):
type_ = request.args.get('type', None)
try:
ret = await self.astrbot_updator.check_update(None, None)
return Response(
status="success",
message=str(ret) if ret is not None else "已经是最新版本了。",
data={
"has_new_version": ret is not None
}
).__dict__
if type_ == 'dashboard':
dv = await get_dashboard_version()
return Response().ok({
"has_new_version": dv != f"v{VERSION}",
"current_version": dv
}).__dict__
else:
ret = await self.astrbot_updator.check_update(None, None)
return Response(
status="success",
message=str(ret) if ret is not None else "已经是最新版本了。",
data={
"has_new_version": ret is not None
}
).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(e.__str__()).__dict__
@@ -41,6 +52,13 @@ class UpdateRoute(Route):
latest = False
try:
await self.astrbot_updator.update(latest=latest, version=version)
if latest:
try:
await download_dashboard()
except Exception as e:
logger.error(f"下载管理面板文件失败: {e}")
if reboot:
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
@@ -136,11 +136,15 @@ commonStore.getStartTime();
</template>
<v-card>
<v-card-title>
<span class="text-h5">更新项目</span>
<span class="text-h5">更新 AstrBot</span>
</v-card-title>
<v-card-text>
<v-container>
<h3 class="mb-4">升级到最新版本</h3>
<div class="mb-4">
<small>会同时尝试更新机器人主程序和管理面板如果您正在使用 Docker 部署也可以重新拉取镜像或者使用 <a href="https://containrrr.dev/watchtower/usage-overview/">watchtower</a> 来自动监控拉取</small>
</div>
<p>{{ updateStatus }}</p>
<v-btn class="mt-4 mb-4" @click="switchVersion('latest')" color="primary" style="border-radius: 10px;"
:disabled="!hasNewVersion">
@@ -149,6 +153,9 @@ commonStore.getStartTime();
<v-divider></v-divider>
<div style="margin-top: 16px;">
<h3 class="mb-4">切换到指定版本或指定提交</h3>
<div class="mb-4">
<small>跳到旧版本不会重新下载管理面板文件这可能会造成部分数据显示错误您可在 <a href="https://github.com/Soulter/AstrBot/releases">此处</a> 找到对应的面板文件 dist.zip解压后替换 data/dist 文件夹即可</small>
</div>
<v-text-field label="输入版本号或 master 分支下的 commit hash。" v-model="version" required
variant="outlined"></v-text-field>
<div class="mb-4">
@@ -27,10 +27,10 @@ const sidebarMenu = shallowRef(sidebarItems);
</v-btn>
</v-list-item>
<small style="display: block;" v-if="buildVer">构建: {{ buildVer }}</small>
<small style="display: block;" v-else="buildVer">构建: embedded</small>
<small style="display: block;" v-else>构建: embedded</small>
<v-tooltip text="使用 /dashbord_update 指令更新管理面板">
<template v-slot:activator="{ props }">
<small v-bind="props" v-if="buildVer != version" style="display: block; margin-top: 4px;">面板有更新</small>
<small v-bind="props" v-if="hasWebUIUpdate" style="display: block; margin-top: 4px;">面板有更新</small>
</template>
</v-tooltip>
@@ -50,19 +50,12 @@ export default {
},
data: () => ({
version: "",
buildVer: ""
buildVer: "",
hasWebUIUpdate: false,
}),
mounted() {
this.get_version()
fetch('/assets/version').then((res) => {
return res.text()
}).then((res) => {
if (res.length > 10) {
// 不是版本,不显示 😎
return
}
this.buildVer = res.replace(/\s+/g, '')
})
this.check_webui_update()
},
methods: {
get_version() {
@@ -73,6 +66,16 @@ export default {
.catch((err) => {
console.log(err);
});
},
check_webui_update() {
axios.get('/api/update/check?type=dashboard')
.then((res) => {
this.hasWebUIUpdate = res.data.data.has_new_version;
this.buildVer = res.data.data.current_version;
})
.catch((err) => {
console.log(err);
});
}
},
};
+1 -1
View File
@@ -24,7 +24,7 @@ class Main(star.Star):
async def _query_astrbot_notice(self):
try:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get("https://astrbot.soulter.top/notice.json", timeout=2) as resp:
return (await resp.json())["notice"]
except BaseException:
+2 -2
View File
@@ -127,7 +127,7 @@ class Main(star.Star):
s3_file_url = f"{S3_URL}/{uuid.uuid4().hex}{ext}"
async with aiohttp.ClientSession(headers = {"Accept": "application/json"}) as session:
async with aiohttp.ClientSession(headers = {"Accept": "application/json"}, trust_env=True) as session:
async with session.put(s3_file_url, data=file) as resp:
if resp.status != 200:
raise Exception(f"Failed to upload image: {resp.status}")
@@ -159,7 +159,7 @@ class Main(star.Star):
async def download_image(self, image_url: str, workplace_path: str, filename: str) -> str:
'''Download image from url to workplace_path'''
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(image_url) as resp:
if resp.status != 200:
return ""
+1 -1
View File
@@ -39,7 +39,7 @@ class Main(star.Star):
'''获取网页内容'''
header = HEADERS
header.update({'User-Agent': random.choice(USER_AGENTS)})
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url, headers=header, timeout=6) as response:
html = await response.text(encoding="utf-8")
doc = Document(html)