fix: fix SSLCertVerificationError

This commit is contained in:
冰苷晶
2025-03-23 22:55:07 +08:00
parent 4222f8516f
commit d10cb84068
4 changed files with 227 additions and 240 deletions
+10 -2
View File
@@ -8,6 +8,9 @@ import base64
import zipfile
import uuid
import psutil
import certifi
from typing import Union
from PIL import Image
@@ -81,7 +84,9 @@ async def download_image_by_url(
下载图片, 返回 path
"""
try:
async with aiohttp.ClientSession(trust_env=True) as session:
ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书
connector = aiohttp.TCPConnector(ssl=ssl_context) # 使用 certifi 的根证书
async with aiohttp.ClientSession(trust_env=True, connector=connector) as session:
if post:
async with session.post(url, json=post_data) as resp:
if not path:
@@ -118,7 +123,9 @@ async def download_file(url: str, path: str, show_progress: bool = False):
从指定 url 下载文件到指定路径 path
"""
try:
async with aiohttp.ClientSession(trust_env=True) as session:
ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(trust_env=True, connector=connector) as session:
async with session.get(url, timeout=1800) as resp:
if resp.status != 200:
raise Exception(f"下载文件失败: {resp.status}")
@@ -202,6 +209,7 @@ async def download_dashboard():
"""下载管理面板文件"""
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
try:
ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书
await download_file(
dashboard_release_url, "data/dashboard.zip", show_progress=True
)
+15 -2
View File
@@ -2,6 +2,10 @@ import aiohttp
import os
import zipfile
import shutil
import ssl
import certifi
from astrbot.core.utils.io import on_error, download_file
from astrbot.core import logger
@@ -33,10 +37,18 @@ class RepoZipUpdator:
返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。
"""
try:
async with aiohttp.ClientSession(trust_env=True) as session:
ssl_context = ssl.create_default_context(cafile=certifi.where()) # 新增:创建基于 certifi 的 SSL 上下文
connector = aiohttp.TCPConnector(ssl=ssl_context) # 新增:使用 TCPConnector 指定 SSL 上下文
async with aiohttp.ClientSession(trust_env=True, connector=connector) as session:
async with session.get(url) as response:
# 检查 HTTP 状态码
if response.status != 200:
text = await response.text()
logger.error(f"请求 {url} 失败,状态码: {response.status}, 内容: {text}")
raise Exception(f"请求失败,状态码: {response.status}")
result = await response.json()
if not result:
logger.error("返回空的结果喵♡~")
return []
# if latest:
# ret = self.github_api_release_parser([result[0]])
@@ -53,7 +65,8 @@ class RepoZipUpdator:
"zipball_url": release["zipball_url"],
}
)
except BaseException:
except Exception as e:
logger.error(f"解析版本信息时发生异常: {e}")
raise Exception("解析版本信息失败")
return ret
+200 -235
View File
@@ -1,260 +1,225 @@
import traceback
import os
import ssl
import shutil
import socket
import time
import aiohttp
from .route import Route, Response, RouteContext
from astrbot.core import logger
from quart import request
from astrbot.core.star.star_manager import PluginManager
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.filter.regex import RegexFilter
from astrbot.core.star.star_handler import EventType
import base64
import zipfile
import uuid
import psutil
import certifi
from typing import Union
from PIL import Image
class PluginRoute(Route):
def __init__(
self,
context: RouteContext,
core_lifecycle: AstrBotCoreLifecycle,
plugin_manager: PluginManager,
) -> None:
super().__init__(context)
self.routes = {
"/plugin/get": ("GET", self.get_plugins),
"/plugin/install": ("POST", self.install_plugin),
"/plugin/install-upload": ("POST", self.install_plugin_upload),
"/plugin/update": ("POST", self.update_plugin),
"/plugin/uninstall": ("POST", self.uninstall_plugin),
"/plugin/market_list": ("GET", self.get_online_plugins),
"/plugin/off": ("POST", self.off_plugin),
"/plugin/on": ("POST", self.on_plugin),
"/plugin/reload": ("POST", self.reload_plugins),
}
self.core_lifecycle = core_lifecycle
self.plugin_manager = plugin_manager
self.register_routes()
def on_error(func, path, exc_info):
"""
a callback of the rmtree function.
"""
print(f"remove {path} failed.")
import stat
self.translated_event_type = {
EventType.AdapterMessageEvent: "平台消息下发时",
EventType.OnLLMRequestEvent: "LLM 请求时",
EventType.OnLLMResponseEvent: "LLM 响应后",
EventType.OnDecoratingResultEvent: "回复消息前",
EventType.OnCallingFuncToolEvent: "函数工具",
EventType.OnAfterMessageSentEvent: "发送消息后",
}
if not os.access(path, os.W_OK):
os.chmod(path, stat.S_IWUSR)
func(path)
else:
raise
async def reload_plugins(self):
data = await request.json
plugin_name = data.get("name", None)
try:
success, message = await self.plugin_manager.reload(plugin_name)
if not success:
return Response().error(message).__dict__
return Response().ok(None, "重载成功。").__dict__
except Exception as e:
logger.error(f"/api/plugin/reload: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def get_online_plugins(self):
custom = request.args.get("custom_registry")
def remove_dir(file_path) -> bool:
if not os.path.exists(file_path):
return True
try:
shutil.rmtree(file_path, onerror=on_error)
return True
except BaseException:
return False
if custom:
urls = [custom]
else:
urls = ["https://api.soulter.top/astrbot/plugins"]
for url in urls:
try:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url) as response:
if response.status == 200:
result = await response.json()
return Response().ok(result).__dict__
else:
logger.error(f"请求 {url} 失败,状态码:{response.status}")
except Exception as e:
logger.error(f"请求 {url} 失败,错误:{e}")
def port_checker(port: int, host: str = "localhost"):
sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sk.settimeout(1)
try:
sk.connect((host, port))
sk.close()
return True
except Exception:
sk.close()
return False
return Response().error("获取插件列表失败").__dict__
async def get_plugins(self):
_plugin_resp = []
for plugin in self.plugin_manager.context.get_all_stars():
_t = {
"name": plugin.name,
"repo": "" if plugin.repo is None else plugin.repo,
"author": plugin.author,
"desc": plugin.desc,
"version": plugin.version,
"reserved": plugin.reserved,
"activated": plugin.activated,
"online_vesion": "",
"handlers": await self.get_plugin_handlers_info(
plugin.star_handler_full_names
),
}
_plugin_resp.append(_t)
return (
Response()
.ok(_plugin_resp, message=self.plugin_manager.failed_plugin_info)
.__dict__
)
def save_temp_img(img: Union[Image.Image, str]) -> str:
os.makedirs("data/temp", exist_ok=True)
# 获得文件创建时间,清除超过 12 小时的
try:
for f in os.listdir("data/temp"):
path = os.path.join("data/temp", f)
if os.path.isfile(path):
ctime = os.path.getctime(path)
if time.time() - ctime > 3600 * 12:
os.remove(path)
except Exception as e:
print(f"清除临时文件失败: {e}")
async def get_plugin_handlers_info(self, handler_full_names: list[str]):
"""解析插件行为"""
handlers = []
# 获得时间戳
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
p = f"data/temp/{timestamp}.jpg"
for handler_full_name in handler_full_names:
info = {}
handler = star_handlers_registry.star_handlers_map.get(
handler_full_name, None
)
if handler is None:
continue
info["event_type"] = handler.event_type.name
info["event_type_h"] = self.translated_event_type.get(
handler.event_type, handler.event_type.name
)
info["handler_full_name"] = handler.handler_full_name
info["desc"] = handler.desc
info["handler_name"] = handler.handler_name
if isinstance(img, Image.Image):
img.save(p)
else:
with open(p, "wb") as f:
f.write(img)
return p
if handler.event_type == EventType.AdapterMessageEvent:
# 处理平台适配器消息事件
has_admin = False
for filter in (
handler.event_filters
): # 正常handler就只有 1~2 个 filter,因此这里时间复杂度不会太高
if isinstance(filter, CommandFilter):
info["type"] = "指令"
info["cmd"] = (
f"{filter.parent_command_names[0]} {filter.command_name}"
)
info["cmd"] = info["cmd"].strip()
if (
self.core_lifecycle.astrbot_config["wake_prefix"]
and len(self.core_lifecycle.astrbot_config["wake_prefix"])
> 0
):
info["cmd"] = (
f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}"
)
elif isinstance(filter, CommandGroupFilter):
info["type"] = "指令组"
info["cmd"] = filter.get_complete_command_names()[0]
info["cmd"] = info["cmd"].strip()
info["sub_command"] = filter.print_cmd_tree(
filter.sub_command_filters
)
if (
self.core_lifecycle.astrbot_config["wake_prefix"]
and len(self.core_lifecycle.astrbot_config["wake_prefix"])
> 0
):
info["cmd"] = (
f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}"
)
elif isinstance(filter, RegexFilter):
info["type"] = "正则匹配"
info["cmd"] = filter.regex_str
elif isinstance(filter, PermissionTypeFilter):
has_admin = True
info["has_admin"] = has_admin
if "cmd" not in info:
info["cmd"] = "未知"
if "type" not in info:
info["type"] = "事件监听器"
async def download_image_by_url(
url: str, post: bool = False, post_data: dict = None, path=None
) -> str:
"""
下载图片, 返回 path
"""
try:
ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书
connector = aiohttp.TCPConnector(ssl=ssl_context) # 使用 certifi 的根证书
async with aiohttp.ClientSession(trust_env=True, connector=connector) as session:
if post:
async with session.post(url, json=post_data) as resp:
if not path:
return save_temp_img(await resp.read())
else:
with open(path, "wb") as f:
f.write(await resp.read())
return path
else:
info["cmd"] = "自动触发"
info["type"] = ""
async with session.get(url) as resp:
if not path:
return save_temp_img(await resp.read())
else:
with open(path, "wb") as f:
f.write(await resp.read())
return path
except aiohttp.client.ClientConnectorSSLError:
# 关闭SSL验证
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("DEFAULT")
async with aiohttp.ClientSession() as session:
if post:
async with session.get(url, ssl=ssl_context) as resp:
return save_temp_img(await resp.read())
else:
async with session.get(url, ssl=ssl_context) as resp:
return save_temp_img(await resp.read())
except Exception as e:
raise e
if not info["desc"]:
info["desc"] = "无描述"
handlers.append(info)
async def download_file(url: str, path: str, show_progress: bool = False):
"""
从指定 url 下载文件到指定路径 path
"""
try:
ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(trust_env=True, connector=connector) as session:
async with session.get(url, timeout=1800) as resp:
if resp.status != 200:
raise Exception(f"下载文件失败: {resp.status}")
total_size = int(resp.headers.get("content-length", 0))
downloaded_size = 0
start_time = time.time()
if show_progress:
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
with open(path, "wb") as f:
while True:
chunk = await resp.content.read(8192)
if not chunk:
break
f.write(chunk)
downloaded_size += len(chunk)
if show_progress:
elapsed_time = time.time() - start_time
speed = downloaded_size / 1024 / elapsed_time # KB/s
print(
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
end="",
)
except aiohttp.client.ClientConnectorSSLError:
# 关闭SSL验证
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("DEFAULT")
async with aiohttp.ClientSession() as session:
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
total_size = int(resp.headers.get("content-length", 0))
downloaded_size = 0
start_time = time.time()
if show_progress:
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
with open(path, "wb") as f:
while True:
chunk = await resp.content.read(8192)
if not chunk:
break
f.write(chunk)
downloaded_size += len(chunk)
if show_progress:
elapsed_time = time.time() - start_time
speed = downloaded_size / 1024 / elapsed_time # KB/s
print(
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
end="",
)
if show_progress:
print()
return handlers
async def install_plugin(self):
post_data = await request.json
repo_url = post_data["url"]
def file_to_base64(file_path: str) -> str:
with open(file_path, "rb") as f:
data_bytes = f.read()
base64_str = base64.b64encode(data_bytes).decode()
return "base64://" + base64_str
proxy: str = post_data.get("proxy", None)
if proxy:
proxy = proxy.removesuffix("/")
try:
logger.info(f"正在安装插件 {repo_url}")
await self.plugin_manager.install_plugin(repo_url, proxy)
# self.core_lifecycle.restart()
logger.info(f"安装插件 {repo_url} 成功。")
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
def get_local_ip_addresses():
net_interfaces = psutil.net_if_addrs()
network_ips = []
async def install_plugin_upload(self):
try:
file = await request.files
file = file["file"]
logger.info(f"正在安装用户上传的插件 {file.filename}")
file_path = f"data/temp/{file.filename}"
await file.save(file_path)
await self.plugin_manager.install_plugin_from_file(file_path)
# self.core_lifecycle.restart()
logger.info(f"安装插件 {file.filename} 成功")
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
for interface, addrs in net_interfaces.items():
for addr in addrs:
if addr.family == socket.AF_INET: # 使用 socket.AF_INET 代替 psutil.AF_INET
network_ips.append(addr.address)
async def uninstall_plugin(self):
post_data = await request.json
plugin_name = post_data["name"]
try:
logger.info(f"正在卸载插件 {plugin_name}")
await self.plugin_manager.uninstall_plugin(plugin_name)
logger.info(f"卸载插件 {plugin_name} 成功")
return Response().ok(None, "卸载成功").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
return network_ips
async def update_plugin(self):
post_data = await request.json
plugin_name = post_data["name"]
proxy: str = post_data.get("proxy", None)
try:
logger.info(f"正在更新插件 {plugin_name}")
await self.plugin_manager.update_plugin(plugin_name, proxy)
# self.core_lifecycle.restart()
await self.plugin_manager.reload(plugin_name)
logger.info(f"更新插件 {plugin_name} 成功。")
return Response().ok(None, "更新成功。").__dict__
except Exception as e:
logger.error(f"/api/plugin/update: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def off_plugin(self):
post_data = await request.json
plugin_name = post_data["name"]
try:
await self.plugin_manager.turn_off_plugin(plugin_name)
logger.info(f"停用插件 {plugin_name}")
return Response().ok(None, "停用成功。").__dict__
except Exception as e:
logger.error(f"/api/plugin/off: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def get_dashboard_version():
if os.path.exists("data/dist"):
if os.path.exists("data/dist/assets/version"):
with open("data/dist/assets/version", "r") as f:
v = f.read().strip()
return v
return None
async def on_plugin(self):
post_data = await request.json
plugin_name = post_data["name"]
try:
await self.plugin_manager.turn_on_plugin(plugin_name)
logger.info(f"启用插件 {plugin_name}")
return Response().ok(None, "启用成功。").__dict__
except Exception as e:
logger.error(f"/api/plugin/on: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def download_dashboard():
"""下载管理面板文件"""
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
try:
ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书
await download_file(
dashboard_release_url, "data/dashboard.zip", show_progress=True
)
except BaseException as _:
dashboard_release_url = (
"https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip"
)
await download_file(
dashboard_release_url, "data/dashboard.zip", show_progress=True
)
print("解压管理面板文件中...")
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
z.extractall("data")
+2 -1
View File
@@ -25,4 +25,5 @@ dashscope
python-telegram-bot
wechatpy
dingtalk-stream
mcp
mcp
certifi