fix: fix error

This commit is contained in:
冰苷晶
2025-03-23 23:02:34 +08:00
parent d10cb84068
commit 1cb2b62f81
+243 -201
View File
@@ -1,225 +1,267 @@
import os
import ssl
import shutil
import socket
import time
import traceback
import aiohttp
import base64
import zipfile
import uuid
import psutil
import ssl
import certifi
from typing import Union
from PIL import Image
from .route import Route, Response, RouteContext
from astrbot.core import logger
from quart import request
from astrbot.core.star.star_manager import PluginManager
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.filter.regex import RegexFilter
from astrbot.core.star.star_handler import EventType
def on_error(func, path, exc_info):
"""
a callback of the rmtree function.
"""
print(f"remove {path} failed.")
import stat
class PluginRoute(Route):
def __init__(
self,
context: RouteContext,
core_lifecycle: AstrBotCoreLifecycle,
plugin_manager: PluginManager,
) -> None:
super().__init__(context)
self.routes = {
"/plugin/get": ("GET", self.get_plugins),
"/plugin/install": ("POST", self.install_plugin),
"/plugin/install-upload": ("POST", self.install_plugin_upload),
"/plugin/update": ("POST", self.update_plugin),
"/plugin/uninstall": ("POST", self.uninstall_plugin),
"/plugin/market_list": ("GET", self.get_online_plugins),
"/plugin/off": ("POST", self.off_plugin),
"/plugin/on": ("POST", self.on_plugin),
"/plugin/reload": ("POST", self.reload_plugins),
}
self.core_lifecycle = core_lifecycle
self.plugin_manager = plugin_manager
self.register_routes()
if not os.access(path, os.W_OK):
os.chmod(path, stat.S_IWUSR)
func(path)
else:
raise
self.translated_event_type = {
EventType.AdapterMessageEvent: "平台消息下发时",
EventType.OnLLMRequestEvent: "LLM 请求时",
EventType.OnLLMResponseEvent: "LLM 响应后",
EventType.OnDecoratingResultEvent: "回复消息前",
EventType.OnCallingFuncToolEvent: "函数工具",
EventType.OnAfterMessageSentEvent: "发送消息后",
}
async def reload_plugins(self):
data = await request.json
plugin_name = data.get("name", None)
try:
success, message = await self.plugin_manager.reload(plugin_name)
if not success:
return Response().error(message).__dict__
return Response().ok(None, "重载成功。").__dict__
except Exception as e:
logger.error(f"/api/plugin/reload: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
def remove_dir(file_path) -> bool:
if not os.path.exists(file_path):
return True
try:
shutil.rmtree(file_path, onerror=on_error)
return True
except BaseException:
return False
async def get_online_plugins(self):
custom = request.args.get("custom_registry")
if custom:
urls = [custom]
else:
urls = ["https://api.soulter.top/astrbot/plugins"]
def port_checker(port: int, host: str = "localhost"):
sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sk.settimeout(1)
try:
sk.connect((host, port))
sk.close()
return True
except Exception:
sk.close()
return False
def save_temp_img(img: Union[Image.Image, str]) -> str:
os.makedirs("data/temp", exist_ok=True)
# 获得文件创建时间,清除超过 12 小时的
try:
for f in os.listdir("data/temp"):
path = os.path.join("data/temp", f)
if os.path.isfile(path):
ctime = os.path.getctime(path)
if time.time() - ctime > 3600 * 12:
os.remove(path)
except Exception as e:
print(f"清除临时文件失败: {e}")
# 获得时间戳
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
p = f"data/temp/{timestamp}.jpg"
if isinstance(img, Image.Image):
img.save(p)
else:
with open(p, "wb") as f:
f.write(img)
return p
async def download_image_by_url(
url: str, post: bool = False, post_data: dict = None, path=None
) -> str:
"""
下载图片, 返回 path
"""
try:
ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书
connector = aiohttp.TCPConnector(ssl=ssl_context) # 使用 certifi 的根证书
async with aiohttp.ClientSession(trust_env=True, connector=connector) as session:
if post:
async with session.post(url, json=post_data) as resp:
if not path:
return save_temp_img(await resp.read())
else:
with open(path, "wb") as f:
f.write(await resp.read())
return path
else:
async with session.get(url) as resp:
if not path:
return save_temp_img(await resp.read())
else:
with open(path, "wb") as f:
f.write(await resp.read())
return path
except aiohttp.client.ClientConnectorSSLError:
# 关闭SSL验证
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("DEFAULT")
async with aiohttp.ClientSession() as session:
if post:
async with session.get(url, ssl=ssl_context) as resp:
return save_temp_img(await resp.read())
else:
async with session.get(url, ssl=ssl_context) as resp:
return save_temp_img(await resp.read())
except Exception as e:
raise e
async def download_file(url: str, path: str, show_progress: bool = False):
"""
从指定 url 下载文件到指定路径 path
"""
try:
ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书
# 新增:创建 SSL 上下文,使用 certifi 提供的根证书
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(trust_env=True, connector=connector) as session:
async with session.get(url, timeout=1800) as resp:
if resp.status != 200:
raise Exception(f"下载文件失败: {resp.status}")
total_size = int(resp.headers.get("content-length", 0))
downloaded_size = 0
start_time = time.time()
if show_progress:
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
with open(path, "wb") as f:
while True:
chunk = await resp.content.read(8192)
if not chunk:
break
f.write(chunk)
downloaded_size += len(chunk)
if show_progress:
elapsed_time = time.time() - start_time
speed = downloaded_size / 1024 / elapsed_time # KB/s
print(
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
end="",
for url in urls:
try:
async with aiohttp.ClientSession(trust_env=True, connector=connector) as session:
async with session.get(url) as response:
if response.status == 200:
result = await response.json()
return Response().ok(result).__dict__
else:
logger.error(f"请求 {url} 失败,状态码:{response.status}")
except Exception as e:
logger.error(f"请求 {url} 失败,错误:{e}")
return Response().error("获取插件列表失败").__dict__
async def get_plugins(self):
_plugin_resp = []
for plugin in self.plugin_manager.context.get_all_stars():
_t = {
"name": plugin.name,
"repo": "" if plugin.repo is None else plugin.repo,
"author": plugin.author,
"desc": plugin.desc,
"version": plugin.version,
"reserved": plugin.reserved,
"activated": plugin.activated,
"online_vesion": "",
"handlers": await self.get_plugin_handlers_info(
plugin.star_handler_full_names
),
}
_plugin_resp.append(_t)
return (
Response()
.ok(_plugin_resp, message=self.plugin_manager.failed_plugin_info)
.__dict__
)
async def get_plugin_handlers_info(self, handler_full_names: list[str]):
"""解析插件行为"""
handlers = []
for handler_full_name in handler_full_names:
info = {}
handler = star_handlers_registry.star_handlers_map.get(
handler_full_name, None
)
if handler is None:
continue
info["event_type"] = handler.event_type.name
info["event_type_h"] = self.translated_event_type.get(
handler.event_type, handler.event_type.name
)
info["handler_full_name"] = handler.handler_full_name
info["desc"] = handler.desc
info["handler_name"] = handler.handler_name
if handler.event_type == EventType.AdapterMessageEvent:
# 处理平台适配器消息事件
has_admin = False
for filter in (
handler.event_filters
): # 正常handler就只有 1~2 个 filter,因此这里时间复杂度不会太高
if isinstance(filter, CommandFilter):
info["type"] = "指令"
info["cmd"] = (
f"{filter.parent_command_names[0]} {filter.command_name}"
)
info["cmd"] = info["cmd"].strip()
if (
self.core_lifecycle.astrbot_config["wake_prefix"]
and len(self.core_lifecycle.astrbot_config["wake_prefix"])
> 0
):
info["cmd"] = (
f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}"
)
except aiohttp.client.ClientConnectorSSLError:
# 关闭SSL验证
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("DEFAULT")
async with aiohttp.ClientSession() as session:
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
total_size = int(resp.headers.get("content-length", 0))
downloaded_size = 0
start_time = time.time()
if show_progress:
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
with open(path, "wb") as f:
while True:
chunk = await resp.content.read(8192)
if not chunk:
break
f.write(chunk)
downloaded_size += len(chunk)
if show_progress:
elapsed_time = time.time() - start_time
speed = downloaded_size / 1024 / elapsed_time # KB/s
print(
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
end="",
elif isinstance(filter, CommandGroupFilter):
info["type"] = "指令组"
info["cmd"] = filter.get_complete_command_names()[0]
info["cmd"] = info["cmd"].strip()
info["sub_command"] = filter.print_cmd_tree(
filter.sub_command_filters
)
if (
self.core_lifecycle.astrbot_config["wake_prefix"]
and len(self.core_lifecycle.astrbot_config["wake_prefix"])
> 0
):
info["cmd"] = (
f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}"
)
if show_progress:
print()
elif isinstance(filter, RegexFilter):
info["type"] = "正则匹配"
info["cmd"] = filter.regex_str
elif isinstance(filter, PermissionTypeFilter):
has_admin = True
info["has_admin"] = has_admin
if "cmd" not in info:
info["cmd"] = "未知"
if "type" not in info:
info["type"] = "事件监听器"
else:
info["cmd"] = "自动触发"
info["type"] = ""
if not info["desc"]:
info["desc"] = "无描述"
def file_to_base64(file_path: str) -> str:
with open(file_path, "rb") as f:
data_bytes = f.read()
base64_str = base64.b64encode(data_bytes).decode()
return "base64://" + base64_str
handlers.append(info)
return handlers
def get_local_ip_addresses():
net_interfaces = psutil.net_if_addrs()
network_ips = []
async def install_plugin(self):
post_data = await request.json
repo_url = post_data["url"]
for interface, addrs in net_interfaces.items():
for addr in addrs:
if addr.family == socket.AF_INET: # 使用 socket.AF_INET 代替 psutil.AF_INET
network_ips.append(addr.address)
proxy: str = post_data.get("proxy", None)
if proxy:
proxy = proxy.removesuffix("/")
return network_ips
try:
logger.info(f"正在安装插件 {repo_url}")
await self.plugin_manager.install_plugin(repo_url, proxy)
# self.core_lifecycle.restart()
logger.info(f"安装插件 {repo_url} 成功。")
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def install_plugin_upload(self):
try:
file = await request.files
file = file["file"]
logger.info(f"正在安装用户上传的插件 {file.filename}")
file_path = f"data/temp/{file.filename}"
await file.save(file_path)
await self.plugin_manager.install_plugin_from_file(file_path)
# self.core_lifecycle.restart()
logger.info(f"安装插件 {file.filename} 成功")
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def get_dashboard_version():
if os.path.exists("data/dist"):
if os.path.exists("data/dist/assets/version"):
with open("data/dist/assets/version", "r") as f:
v = f.read().strip()
return v
return None
async def uninstall_plugin(self):
post_data = await request.json
plugin_name = post_data["name"]
try:
logger.info(f"正在卸载插件 {plugin_name}")
await self.plugin_manager.uninstall_plugin(plugin_name)
logger.info(f"卸载插件 {plugin_name} 成功")
return Response().ok(None, "卸载成功").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def update_plugin(self):
post_data = await request.json
plugin_name = post_data["name"]
proxy: str = post_data.get("proxy", None)
try:
logger.info(f"正在更新插件 {plugin_name}")
await self.plugin_manager.update_plugin(plugin_name, proxy)
# self.core_lifecycle.restart()
await self.plugin_manager.reload(plugin_name)
logger.info(f"更新插件 {plugin_name} 成功。")
return Response().ok(None, "更新成功。").__dict__
except Exception as e:
logger.error(f"/api/plugin/update: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def download_dashboard():
"""下载管理面板文件"""
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
try:
ssl_context = ssl.create_default_context(cafile=certifi.where()) # 使用 certifi 提供的 CA 证书
await download_file(
dashboard_release_url, "data/dashboard.zip", show_progress=True
)
except BaseException as _:
dashboard_release_url = (
"https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip"
)
await download_file(
dashboard_release_url, "data/dashboard.zip", show_progress=True
)
print("解压管理面板文件中...")
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
z.extractall("data")
async def off_plugin(self):
post_data = await request.json
plugin_name = post_data["name"]
try:
await self.plugin_manager.turn_off_plugin(plugin_name)
logger.info(f"停用插件 {plugin_name}")
return Response().ok(None, "停用成功。").__dict__
except Exception as e:
logger.error(f"/api/plugin/off: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def on_plugin(self):
post_data = await request.json
plugin_name = post_data["name"]
try:
await self.plugin_manager.turn_on_plugin(plugin_name)
logger.info(f"启用插件 {plugin_name}")
return Response().ok(None, "启用成功。").__dict__
except Exception as e:
logger.error(f"/api/plugin/on: {traceback.format_exc()}")
return Response().error(str(e)).__dict__