fix: fix SSLCertVerificationError
This commit is contained in:
@@ -8,6 +8,9 @@ import base64
|
||||
import zipfile
|
||||
import uuid
|
||||
import psutil
|
||||
|
||||
import certifi
|
||||
|
||||
from typing import Union
|
||||
|
||||
from PIL import Image
|
||||
@@ -81,7 +84,9 @@ async def download_image_by_url(
|
||||
下载图片, 返回 path
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context) # 使用 certifi 的根证书
|
||||
async with aiohttp.ClientSession(trust_env=True, connector=connector) as session:
|
||||
if post:
|
||||
async with session.post(url, json=post_data) as resp:
|
||||
if not path:
|
||||
@@ -118,7 +123,9 @@ async def download_file(url: str, path: str, show_progress: bool = False):
|
||||
从指定 url 下载文件到指定路径 path
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
async with aiohttp.ClientSession(trust_env=True, connector=connector) as session:
|
||||
async with session.get(url, timeout=1800) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
@@ -202,6 +209,7 @@ async def download_dashboard():
|
||||
"""下载管理面板文件"""
|
||||
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
|
||||
try:
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书
|
||||
await download_file(
|
||||
dashboard_release_url, "data/dashboard.zip", show_progress=True
|
||||
)
|
||||
|
||||
@@ -2,6 +2,10 @@ import aiohttp
|
||||
import os
|
||||
import zipfile
|
||||
import shutil
|
||||
|
||||
import ssl
|
||||
import certifi
|
||||
|
||||
from astrbot.core.utils.io import on_error, download_file
|
||||
from astrbot.core import logger
|
||||
|
||||
@@ -33,10 +37,18 @@ class RepoZipUpdator:
|
||||
返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where()) # 新增:创建基于 certifi 的 SSL 上下文
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context) # 新增:使用 TCPConnector 指定 SSL 上下文
|
||||
async with aiohttp.ClientSession(trust_env=True, connector=connector) as session:
|
||||
async with session.get(url) as response:
|
||||
# 检查 HTTP 状态码
|
||||
if response.status != 200:
|
||||
text = await response.text()
|
||||
logger.error(f"请求 {url} 失败,状态码: {response.status}, 内容: {text}")
|
||||
raise Exception(f"请求失败,状态码: {response.status}")
|
||||
result = await response.json()
|
||||
if not result:
|
||||
logger.error("返回空的结果喵♡~")
|
||||
return []
|
||||
# if latest:
|
||||
# ret = self.github_api_release_parser([result[0]])
|
||||
@@ -53,7 +65,8 @@ class RepoZipUpdator:
|
||||
"zipball_url": release["zipball_url"],
|
||||
}
|
||||
)
|
||||
except BaseException:
|
||||
except Exception as e:
|
||||
logger.error(f"解析版本信息时发生异常: {e}")
|
||||
raise Exception("解析版本信息失败")
|
||||
return ret
|
||||
|
||||
|
||||
+200
-235
@@ -1,260 +1,225 @@
|
||||
import traceback
|
||||
import os
|
||||
import ssl
|
||||
import shutil
|
||||
import socket
|
||||
import time
|
||||
import aiohttp
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core import logger
|
||||
from quart import request
|
||||
from astrbot.core.star.star_manager import PluginManager
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.star.filter.command import CommandFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from astrbot.core.star.filter.regex import RegexFilter
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
import base64
|
||||
import zipfile
|
||||
import uuid
|
||||
import psutil
|
||||
|
||||
import certifi
|
||||
|
||||
from typing import Union
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class PluginRoute(Route):
|
||||
def __init__(
|
||||
self,
|
||||
context: RouteContext,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
plugin_manager: PluginManager,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
"/plugin/get": ("GET", self.get_plugins),
|
||||
"/plugin/install": ("POST", self.install_plugin),
|
||||
"/plugin/install-upload": ("POST", self.install_plugin_upload),
|
||||
"/plugin/update": ("POST", self.update_plugin),
|
||||
"/plugin/uninstall": ("POST", self.uninstall_plugin),
|
||||
"/plugin/market_list": ("GET", self.get_online_plugins),
|
||||
"/plugin/off": ("POST", self.off_plugin),
|
||||
"/plugin/on": ("POST", self.on_plugin),
|
||||
"/plugin/reload": ("POST", self.reload_plugins),
|
||||
}
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.plugin_manager = plugin_manager
|
||||
self.register_routes()
|
||||
def on_error(func, path, exc_info):
|
||||
"""
|
||||
a callback of the rmtree function.
|
||||
"""
|
||||
print(f"remove {path} failed.")
|
||||
import stat
|
||||
|
||||
self.translated_event_type = {
|
||||
EventType.AdapterMessageEvent: "平台消息下发时",
|
||||
EventType.OnLLMRequestEvent: "LLM 请求时",
|
||||
EventType.OnLLMResponseEvent: "LLM 响应后",
|
||||
EventType.OnDecoratingResultEvent: "回复消息前",
|
||||
EventType.OnCallingFuncToolEvent: "函数工具",
|
||||
EventType.OnAfterMessageSentEvent: "发送消息后",
|
||||
}
|
||||
if not os.access(path, os.W_OK):
|
||||
os.chmod(path, stat.S_IWUSR)
|
||||
func(path)
|
||||
else:
|
||||
raise
|
||||
|
||||
async def reload_plugins(self):
|
||||
data = await request.json
|
||||
plugin_name = data.get("name", None)
|
||||
try:
|
||||
success, message = await self.plugin_manager.reload(plugin_name)
|
||||
if not success:
|
||||
return Response().error(message).__dict__
|
||||
return Response().ok(None, "重载成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/plugin/reload: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def get_online_plugins(self):
|
||||
custom = request.args.get("custom_registry")
|
||||
def remove_dir(file_path) -> bool:
|
||||
if not os.path.exists(file_path):
|
||||
return True
|
||||
try:
|
||||
shutil.rmtree(file_path, onerror=on_error)
|
||||
return True
|
||||
except BaseException:
|
||||
return False
|
||||
|
||||
if custom:
|
||||
urls = [custom]
|
||||
else:
|
||||
urls = ["https://api.soulter.top/astrbot/plugins"]
|
||||
|
||||
for url in urls:
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
return Response().ok(result).__dict__
|
||||
else:
|
||||
logger.error(f"请求 {url} 失败,状态码:{response.status}")
|
||||
except Exception as e:
|
||||
logger.error(f"请求 {url} 失败,错误:{e}")
|
||||
def port_checker(port: int, host: str = "localhost"):
|
||||
sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sk.settimeout(1)
|
||||
try:
|
||||
sk.connect((host, port))
|
||||
sk.close()
|
||||
return True
|
||||
except Exception:
|
||||
sk.close()
|
||||
return False
|
||||
|
||||
return Response().error("获取插件列表失败").__dict__
|
||||
|
||||
async def get_plugins(self):
|
||||
_plugin_resp = []
|
||||
for plugin in self.plugin_manager.context.get_all_stars():
|
||||
_t = {
|
||||
"name": plugin.name,
|
||||
"repo": "" if plugin.repo is None else plugin.repo,
|
||||
"author": plugin.author,
|
||||
"desc": plugin.desc,
|
||||
"version": plugin.version,
|
||||
"reserved": plugin.reserved,
|
||||
"activated": plugin.activated,
|
||||
"online_vesion": "",
|
||||
"handlers": await self.get_plugin_handlers_info(
|
||||
plugin.star_handler_full_names
|
||||
),
|
||||
}
|
||||
_plugin_resp.append(_t)
|
||||
return (
|
||||
Response()
|
||||
.ok(_plugin_resp, message=self.plugin_manager.failed_plugin_info)
|
||||
.__dict__
|
||||
)
|
||||
def save_temp_img(img: Union[Image.Image, str]) -> str:
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
# 获得文件创建时间,清除超过 12 小时的
|
||||
try:
|
||||
for f in os.listdir("data/temp"):
|
||||
path = os.path.join("data/temp", f)
|
||||
if os.path.isfile(path):
|
||||
ctime = os.path.getctime(path)
|
||||
if time.time() - ctime > 3600 * 12:
|
||||
os.remove(path)
|
||||
except Exception as e:
|
||||
print(f"清除临时文件失败: {e}")
|
||||
|
||||
async def get_plugin_handlers_info(self, handler_full_names: list[str]):
|
||||
"""解析插件行为"""
|
||||
handlers = []
|
||||
# 获得时间戳
|
||||
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||
p = f"data/temp/{timestamp}.jpg"
|
||||
|
||||
for handler_full_name in handler_full_names:
|
||||
info = {}
|
||||
handler = star_handlers_registry.star_handlers_map.get(
|
||||
handler_full_name, None
|
||||
)
|
||||
if handler is None:
|
||||
continue
|
||||
info["event_type"] = handler.event_type.name
|
||||
info["event_type_h"] = self.translated_event_type.get(
|
||||
handler.event_type, handler.event_type.name
|
||||
)
|
||||
info["handler_full_name"] = handler.handler_full_name
|
||||
info["desc"] = handler.desc
|
||||
info["handler_name"] = handler.handler_name
|
||||
if isinstance(img, Image.Image):
|
||||
img.save(p)
|
||||
else:
|
||||
with open(p, "wb") as f:
|
||||
f.write(img)
|
||||
return p
|
||||
|
||||
if handler.event_type == EventType.AdapterMessageEvent:
|
||||
# 处理平台适配器消息事件
|
||||
has_admin = False
|
||||
for filter in (
|
||||
handler.event_filters
|
||||
): # 正常handler就只有 1~2 个 filter,因此这里时间复杂度不会太高
|
||||
if isinstance(filter, CommandFilter):
|
||||
info["type"] = "指令"
|
||||
info["cmd"] = (
|
||||
f"{filter.parent_command_names[0]} {filter.command_name}"
|
||||
)
|
||||
info["cmd"] = info["cmd"].strip()
|
||||
if (
|
||||
self.core_lifecycle.astrbot_config["wake_prefix"]
|
||||
and len(self.core_lifecycle.astrbot_config["wake_prefix"])
|
||||
> 0
|
||||
):
|
||||
info["cmd"] = (
|
||||
f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}"
|
||||
)
|
||||
elif isinstance(filter, CommandGroupFilter):
|
||||
info["type"] = "指令组"
|
||||
info["cmd"] = filter.get_complete_command_names()[0]
|
||||
info["cmd"] = info["cmd"].strip()
|
||||
info["sub_command"] = filter.print_cmd_tree(
|
||||
filter.sub_command_filters
|
||||
)
|
||||
if (
|
||||
self.core_lifecycle.astrbot_config["wake_prefix"]
|
||||
and len(self.core_lifecycle.astrbot_config["wake_prefix"])
|
||||
> 0
|
||||
):
|
||||
info["cmd"] = (
|
||||
f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}"
|
||||
)
|
||||
elif isinstance(filter, RegexFilter):
|
||||
info["type"] = "正则匹配"
|
||||
info["cmd"] = filter.regex_str
|
||||
elif isinstance(filter, PermissionTypeFilter):
|
||||
has_admin = True
|
||||
info["has_admin"] = has_admin
|
||||
if "cmd" not in info:
|
||||
info["cmd"] = "未知"
|
||||
if "type" not in info:
|
||||
info["type"] = "事件监听器"
|
||||
|
||||
async def download_image_by_url(
|
||||
url: str, post: bool = False, post_data: dict = None, path=None
|
||||
) -> str:
|
||||
"""
|
||||
下载图片, 返回 path
|
||||
"""
|
||||
try:
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context) # 使用 certifi 的根证书
|
||||
async with aiohttp.ClientSession(trust_env=True, connector=connector) as session:
|
||||
if post:
|
||||
async with session.post(url, json=post_data) as resp:
|
||||
if not path:
|
||||
return save_temp_img(await resp.read())
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return path
|
||||
else:
|
||||
info["cmd"] = "自动触发"
|
||||
info["type"] = "无"
|
||||
async with session.get(url) as resp:
|
||||
if not path:
|
||||
return save_temp_img(await resp.read())
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return path
|
||||
except aiohttp.client.ClientConnectorSSLError:
|
||||
# 关闭SSL验证
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.set_ciphers("DEFAULT")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if post:
|
||||
async with session.get(url, ssl=ssl_context) as resp:
|
||||
return save_temp_img(await resp.read())
|
||||
else:
|
||||
async with session.get(url, ssl=ssl_context) as resp:
|
||||
return save_temp_img(await resp.read())
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
if not info["desc"]:
|
||||
info["desc"] = "无描述"
|
||||
|
||||
handlers.append(info)
|
||||
async def download_file(url: str, path: str, show_progress: bool = False):
|
||||
"""
|
||||
从指定 url 下载文件到指定路径 path
|
||||
"""
|
||||
try:
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
async with aiohttp.ClientSession(trust_env=True, connector=connector) as session:
|
||||
async with session.get(url, timeout=1800) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
total_size = int(resp.headers.get("content-length", 0))
|
||||
downloaded_size = 0
|
||||
start_time = time.time()
|
||||
if show_progress:
|
||||
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
|
||||
with open(path, "wb") as f:
|
||||
while True:
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
if show_progress:
|
||||
elapsed_time = time.time() - start_time
|
||||
speed = downloaded_size / 1024 / elapsed_time # KB/s
|
||||
print(
|
||||
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
|
||||
end="",
|
||||
)
|
||||
except aiohttp.client.ClientConnectorSSLError:
|
||||
# 关闭SSL验证
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.set_ciphers("DEFAULT")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
|
||||
total_size = int(resp.headers.get("content-length", 0))
|
||||
downloaded_size = 0
|
||||
start_time = time.time()
|
||||
if show_progress:
|
||||
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
|
||||
with open(path, "wb") as f:
|
||||
while True:
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
if show_progress:
|
||||
elapsed_time = time.time() - start_time
|
||||
speed = downloaded_size / 1024 / elapsed_time # KB/s
|
||||
print(
|
||||
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
|
||||
end="",
|
||||
)
|
||||
if show_progress:
|
||||
print()
|
||||
|
||||
return handlers
|
||||
|
||||
async def install_plugin(self):
|
||||
post_data = await request.json
|
||||
repo_url = post_data["url"]
|
||||
def file_to_base64(file_path: str) -> str:
|
||||
with open(file_path, "rb") as f:
|
||||
data_bytes = f.read()
|
||||
base64_str = base64.b64encode(data_bytes).decode()
|
||||
return "base64://" + base64_str
|
||||
|
||||
proxy: str = post_data.get("proxy", None)
|
||||
if proxy:
|
||||
proxy = proxy.removesuffix("/")
|
||||
|
||||
try:
|
||||
logger.info(f"正在安装插件 {repo_url}")
|
||||
await self.plugin_manager.install_plugin(repo_url, proxy)
|
||||
# self.core_lifecycle.restart()
|
||||
logger.info(f"安装插件 {repo_url} 成功。")
|
||||
return Response().ok(None, "安装成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
def get_local_ip_addresses():
|
||||
net_interfaces = psutil.net_if_addrs()
|
||||
network_ips = []
|
||||
|
||||
async def install_plugin_upload(self):
|
||||
try:
|
||||
file = await request.files
|
||||
file = file["file"]
|
||||
logger.info(f"正在安装用户上传的插件 {file.filename}")
|
||||
file_path = f"data/temp/{file.filename}"
|
||||
await file.save(file_path)
|
||||
await self.plugin_manager.install_plugin_from_file(file_path)
|
||||
# self.core_lifecycle.restart()
|
||||
logger.info(f"安装插件 {file.filename} 成功")
|
||||
return Response().ok(None, "安装成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
for interface, addrs in net_interfaces.items():
|
||||
for addr in addrs:
|
||||
if addr.family == socket.AF_INET: # 使用 socket.AF_INET 代替 psutil.AF_INET
|
||||
network_ips.append(addr.address)
|
||||
|
||||
async def uninstall_plugin(self):
|
||||
post_data = await request.json
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
logger.info(f"正在卸载插件 {plugin_name}")
|
||||
await self.plugin_manager.uninstall_plugin(plugin_name)
|
||||
logger.info(f"卸载插件 {plugin_name} 成功")
|
||||
return Response().ok(None, "卸载成功").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
return network_ips
|
||||
|
||||
async def update_plugin(self):
|
||||
post_data = await request.json
|
||||
plugin_name = post_data["name"]
|
||||
proxy: str = post_data.get("proxy", None)
|
||||
try:
|
||||
logger.info(f"正在更新插件 {plugin_name}")
|
||||
await self.plugin_manager.update_plugin(plugin_name, proxy)
|
||||
# self.core_lifecycle.restart()
|
||||
await self.plugin_manager.reload(plugin_name)
|
||||
logger.info(f"更新插件 {plugin_name} 成功。")
|
||||
return Response().ok(None, "更新成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/plugin/update: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def off_plugin(self):
|
||||
post_data = await request.json
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
await self.plugin_manager.turn_off_plugin(plugin_name)
|
||||
logger.info(f"停用插件 {plugin_name} 。")
|
||||
return Response().ok(None, "停用成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/plugin/off: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
async def get_dashboard_version():
|
||||
if os.path.exists("data/dist"):
|
||||
if os.path.exists("data/dist/assets/version"):
|
||||
with open("data/dist/assets/version", "r") as f:
|
||||
v = f.read().strip()
|
||||
return v
|
||||
return None
|
||||
|
||||
async def on_plugin(self):
|
||||
post_data = await request.json
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
await self.plugin_manager.turn_on_plugin(plugin_name)
|
||||
logger.info(f"启用插件 {plugin_name} 。")
|
||||
return Response().ok(None, "启用成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/plugin/on: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def download_dashboard():
|
||||
"""下载管理面板文件"""
|
||||
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
|
||||
try:
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书
|
||||
await download_file(
|
||||
dashboard_release_url, "data/dashboard.zip", show_progress=True
|
||||
)
|
||||
except BaseException as _:
|
||||
dashboard_release_url = (
|
||||
"https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip"
|
||||
)
|
||||
await download_file(
|
||||
dashboard_release_url, "data/dashboard.zip", show_progress=True
|
||||
)
|
||||
print("解压管理面板文件中...")
|
||||
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
|
||||
z.extractall("data")
|
||||
|
||||
+2
-1
@@ -25,4 +25,5 @@ dashscope
|
||||
python-telegram-bot
|
||||
wechatpy
|
||||
dingtalk-stream
|
||||
mcp
|
||||
mcp
|
||||
certifi
|
||||
Reference in New Issue
Block a user