Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 353b6ed761 | |||
| 90815b1ac5 | |||
| 8a50786e61 | |||
| 3b77df0556 | |||
| 1fa11062de | |||
| 6883de0f1c | |||
| bdde0fe094 | |||
| ab22b8103e | |||
| 641d5cd67b |
+19
-7
@@ -11,15 +11,14 @@ from model.platform.manager import PlatformManager
|
||||
from typing import Union
|
||||
from type.types import Context
|
||||
from type.config import VERSION
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from util.cmd_config import AstrBotConfig, try_migrate
|
||||
from util.metrics import MetricUploader
|
||||
from util.updator.astrbot_updator import AstrBotUpdator
|
||||
from util.log import LogManager
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
|
||||
class AstrBotBootstrap():
|
||||
def __init__(self) -> None:
|
||||
self.context = Context()
|
||||
@@ -28,7 +27,11 @@ class AstrBotBootstrap():
|
||||
try_migrate()
|
||||
self.config_helper = AstrBotConfig()
|
||||
self.context.config_helper = self.config_helper
|
||||
# set log queue handler
|
||||
LogManager.set_queue_handler(logger, self.context._log_queue)
|
||||
logger.info("AstrBot v" + VERSION)
|
||||
# set log level
|
||||
logger.setLevel(self.config_helper.log_level)
|
||||
# apply proxy settings
|
||||
http_proxy = self.context.config_helper.http_proxy
|
||||
https_proxy = self.context.config_helper.https_proxy
|
||||
@@ -44,6 +47,12 @@ class AstrBotBootstrap():
|
||||
logger.info("未使用代理。")
|
||||
|
||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
||||
|
||||
# set t2i endpoint
|
||||
if self.context.config_helper.t2i_endpoint:
|
||||
self.context.image_renderer.set_network_endpoint(
|
||||
self.context.config_helper.t2i_endpoint
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
self.command_manager = CommandManager()
|
||||
@@ -65,7 +74,12 @@ class AstrBotBootstrap():
|
||||
self.context.plugin_updator = self.plugin_manager.updator
|
||||
self.context.message_handler = self.message_handler
|
||||
self.context.command_manager = self.command_manager
|
||||
|
||||
|
||||
# load dashboard
|
||||
self.dashboard.run_http_server()
|
||||
dashboard_ws_task = asyncio.create_task(self.dashboard.ws_server(), name="dashboard")
|
||||
dashboard_log_task = asyncio.create_task(self.dashboard.log_consumer(), name="log")
|
||||
|
||||
if self.test_mode:
|
||||
return
|
||||
|
||||
@@ -77,10 +91,8 @@ class AstrBotBootstrap():
|
||||
platform_tasks = self.load_platform()
|
||||
# load metrics uploader
|
||||
metrics_upload_task = asyncio.create_task(self.metrics_uploader.upload_metrics(), name="metrics-uploader")
|
||||
# load dashboard
|
||||
self.dashboard.run_http_server()
|
||||
dashboard_task = asyncio.create_task(self.dashboard.ws_server(), name="dashboard")
|
||||
tasks = [metrics_upload_task, dashboard_task, *platform_tasks, *self.context.ext_tasks]
|
||||
|
||||
tasks = [metrics_upload_task, dashboard_ws_task, dashboard_log_task, *platform_tasks, *self.context.ext_tasks]
|
||||
tasks = [self.handle_task(task) for task in tasks]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from model.command.manager import CommandManager
|
||||
from type.message_event import AstrMessageEvent, MessageResult
|
||||
from type.types import Context
|
||||
from type.command import CommandResult
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from nakuru.entities.components import Image
|
||||
from util.agent.func_call import FuncCall
|
||||
@@ -60,6 +60,9 @@ class ContentSafetyHelper():
|
||||
from astrbot.message.baidu_aip_judge import BaiduJudge
|
||||
self.baidu_judge = BaiduJudge(aip)
|
||||
logger.info("已启用百度 AI 内容审核。")
|
||||
except ImportError as e:
|
||||
logger.error("检测到库依赖不完整,将不会启用百度 AI 内容审核。请先使用 pip 安装 `baidu_aip` 包。")
|
||||
logger.error(e)
|
||||
except BaseException as e:
|
||||
logger.error("百度 AI 内容审核初始化失败。")
|
||||
logger.error(e)
|
||||
|
||||
+37
-18
@@ -2,7 +2,7 @@ from . import DashBoardData
|
||||
from util.cmd_config import AstrBotConfig
|
||||
from dataclasses import dataclass, asdict
|
||||
from util.plugin_dev.api.v1.config import update_config
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from type.types import Context
|
||||
from type.config import CONFIG_METADATA_2
|
||||
@@ -15,29 +15,48 @@ class DashBoardHelper():
|
||||
self.context = context
|
||||
self.config_key_dont_show = ['dashboard', 'config_version']
|
||||
|
||||
def try_cast(self, value: str, type_: str):
|
||||
if type_ == "int" and value.isdigit():
|
||||
return int(value)
|
||||
elif type_ == "float" and isinstance(value, str) \
|
||||
and value.replace(".", "", 1).isdigit():
|
||||
return float(value)
|
||||
elif type_ == "float" and isinstance(value, int):
|
||||
return float(value)
|
||||
|
||||
|
||||
def validate_config(self, data):
|
||||
errors = []
|
||||
# 递归验证数据
|
||||
def validate(data, path=""):
|
||||
for key, meta in CONFIG_METADATA_2.items():
|
||||
def validate(data, metadata=CONFIG_METADATA_2, path=""):
|
||||
for key, meta in metadata.items():
|
||||
if key not in data:
|
||||
if key not in self.config_key_dont_show:
|
||||
# 这些key不会传给前端,所以不需要验证
|
||||
errors.append(f"Missing key: {path}{key}")
|
||||
continue
|
||||
value = data[key]
|
||||
if meta["type"] == "int" and not isinstance(value, int):
|
||||
errors.append(f"Invalid type for {path}{key}: expected int, got {type(value).__name__}")
|
||||
elif meta["type"] == "bool" and not isinstance(value, bool):
|
||||
errors.append(f"Invalid type for {path}{key}: expected bool, got {type(value).__name__}")
|
||||
elif meta["type"] == "string" and not isinstance(value, str):
|
||||
errors.append(f"Invalid type for {path}{key}: expected string, got {type(value).__name__}")
|
||||
elif meta["type"] == "list" and not isinstance(value, list):
|
||||
errors.append(f"Invalid type for {path}{key}: expected list, got {type(value).__name__}")
|
||||
# 递归验证
|
||||
if meta["type"] == "list" and isinstance(value, list):
|
||||
for item in value:
|
||||
validate(item, meta["items"], path=f"{path}{key}.")
|
||||
elif meta["type"] == "dict" and not isinstance(value, dict):
|
||||
errors.append(f"Invalid type for {path}{key}: expected dict, got {type(value).__name__}")
|
||||
elif meta["type"] == "object" and isinstance(value, dict):
|
||||
validate(value, meta["items"], path=f"{path}{key}.")
|
||||
|
||||
if meta["type"] == "int" and not isinstance(value, int):
|
||||
casted = self.try_cast(value, "int")
|
||||
if casted is None:
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 int, 得到了 {type(value).__name__}")
|
||||
data[key] = casted
|
||||
elif meta["type"] == "float" and not isinstance(value, float):
|
||||
casted = self.try_cast(value, "float")
|
||||
if casted is None:
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 float, 得到了 {type(value).__name__}")
|
||||
data[key] = casted
|
||||
elif meta["type"] == "bool" and not isinstance(value, bool):
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 bool, 得到了 {type(value).__name__}")
|
||||
elif meta["type"] == "string" and not isinstance(value, str):
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 string, 得到了 {type(value).__name__}")
|
||||
elif meta["type"] == "list" and not isinstance(value, list):
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}")
|
||||
elif meta["type"] == "object" and not isinstance(value, dict):
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}")
|
||||
validate(value, meta["items"], path=f"{path}{key}.")
|
||||
validate(data)
|
||||
|
||||
@@ -68,6 +87,6 @@ class DashBoardHelper():
|
||||
typ = item['val_type']
|
||||
if typ == 'int':
|
||||
if not value.isdigit():
|
||||
raise ValueError(f"Invalid type for {namespace}.{key}: expected int, got {type(value).__name__}")
|
||||
raise ValueError(f"错误的类型 {namespace}.{key}: 期望是 int, 得到了 {type(value).__name__}")
|
||||
value = int(value)
|
||||
update_config(namespace, key, value)
|
||||
|
||||
+29
-21
@@ -13,7 +13,7 @@ from werkzeug.serving import make_server
|
||||
from astrbot.persist.helper import dbConn
|
||||
from type.types import Context
|
||||
from typing import List
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from dashboard.helper import DashBoardHelper
|
||||
from util.io import get_local_ip_addresses
|
||||
@@ -206,7 +206,8 @@ class AstrBotDashBoard():
|
||||
repo_url = post_data["url"]
|
||||
try:
|
||||
logger.info(f"正在安装插件 {repo_url}")
|
||||
self.plugin_manager.install_plugin(repo_url)
|
||||
# self.plugin_manager.install_plugin(repo_url)
|
||||
asyncio.run_coroutine_threadsafe(self.plugin_manager.install_plugin(repo_url), self.loop).result()
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
|
||||
logger.info(f"安装插件 {repo_url} 成功,2秒后重启")
|
||||
return Response(
|
||||
@@ -272,7 +273,8 @@ class AstrBotDashBoard():
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
logger.info(f"正在更新插件 {plugin_name}")
|
||||
self.plugin_manager.update_plugin(plugin_name)
|
||||
# self.plugin_manager.update_plugin(plugin_name)
|
||||
asyncio.run_coroutine_threadsafe(self.plugin_manager.update_plugin(plugin_name), self.loop).result()
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
|
||||
logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
|
||||
return Response(
|
||||
@@ -288,20 +290,12 @@ class AstrBotDashBoard():
|
||||
data=None
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.post("/api/log")
|
||||
def log():
|
||||
for item in self.ws_clients:
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.ws_clients[item].send(request.data.decode()), self.loop).result()
|
||||
except Exception as e:
|
||||
pass
|
||||
return 'ok'
|
||||
|
||||
@self.dashboard_be.get("/api/check_update")
|
||||
def get_update_info():
|
||||
try:
|
||||
ret = self.astrbot_updator.check_update(None, None)
|
||||
# ret = self.astrbot_updator.check_update(None, None)
|
||||
ret = asyncio.run_coroutine_threadsafe(
|
||||
self.astrbot_updator.check_update(None, None), self.loop).result()
|
||||
return Response(
|
||||
status="success",
|
||||
message=str(ret) if ret is not None else "已经是最新版本了。",
|
||||
@@ -326,7 +320,8 @@ class AstrBotDashBoard():
|
||||
else:
|
||||
latest = False
|
||||
try:
|
||||
self.astrbot_updator.update(latest=latest, version=version)
|
||||
# await self.astrbot_updator.update(latest=latest, version=version)
|
||||
asyncio.run_coroutine_threadsafe(self.astrbot_updator.update(latest=latest, version=version), self.loop).result()
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
|
||||
return Response(
|
||||
status="success",
|
||||
@@ -417,18 +412,22 @@ class AstrBotDashBoard():
|
||||
|
||||
async def get_log_history(self):
|
||||
try:
|
||||
with open("logs/astrbot/astrbot.log", "r", encoding="utf-8") as f:
|
||||
return f.readlines()[-100:]
|
||||
dq = self.context._log_queue.get_cache()
|
||||
ret = ""
|
||||
for log in dq:
|
||||
ret += log + "\n\r"
|
||||
return ret
|
||||
except Exception as e:
|
||||
logger.warning(f"读取日志历史失败: {e.__str__()}")
|
||||
return []
|
||||
return ""
|
||||
|
||||
async def __handle_msg(self, websocket, path):
|
||||
address = websocket.remote_address
|
||||
self.ws_clients[address] = websocket
|
||||
data = await self.get_log_history()
|
||||
data = ''.join(data).replace('\n', '\r\n')
|
||||
await websocket.send(data)
|
||||
|
||||
# 发送日志历史
|
||||
await websocket.send(await self.get_log_history())
|
||||
|
||||
while True:
|
||||
try:
|
||||
msg = await websocket.recv()
|
||||
@@ -445,6 +444,15 @@ class AstrBotDashBoard():
|
||||
ws_server = websockets.serve(self.__handle_msg, "0.0.0.0", 6186)
|
||||
logger.info("WebSocket 服务器已启动。")
|
||||
await ws_server
|
||||
|
||||
async def log_consumer(self):
|
||||
while True:
|
||||
log = await self.context._log_queue.get()
|
||||
for ws in self.ws_clients.values():
|
||||
try:
|
||||
await ws.send(log)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def http_server(self):
|
||||
http_server = make_server(
|
||||
|
||||
@@ -6,8 +6,7 @@ import warnings
|
||||
import traceback
|
||||
import mimetypes
|
||||
from astrbot.bootstrap import AstrBotBootstrap
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Formatter
|
||||
from util.log import LogManager
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
logo_tmpl = r"""
|
||||
@@ -27,6 +26,8 @@ def main():
|
||||
# delete qqbotpy's logger
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
logger.info(logo_tmpl)
|
||||
|
||||
bootstrap = AstrBotBootstrap()
|
||||
asyncio.run(bootstrap.run())
|
||||
@@ -52,11 +53,5 @@ def check_env():
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_env()
|
||||
|
||||
logger = LogManager.GetLogger(
|
||||
log_name='astrbot',
|
||||
out_to_console=True,
|
||||
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
|
||||
)
|
||||
logger.info(logo_tmpl)
|
||||
logger = LogManager.GetLogger(log_name='astrbot')
|
||||
main()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import aiohttp
|
||||
import aiohttp, os
|
||||
|
||||
from model.command.manager import CommandManager
|
||||
from model.plugin.manager import PluginManager
|
||||
@@ -6,7 +6,7 @@ from type.message_event import AstrMessageEvent
|
||||
from type.command import CommandResult
|
||||
from type.types import Context
|
||||
from type.config import VERSION
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from util.agent.web_searcher import search_from_bing, fetch_website_content
|
||||
|
||||
@@ -27,6 +27,13 @@ class InternalCommandHandler:
|
||||
self.manager.register("t2i", "文转图", 10, self.t2i_toggle)
|
||||
self.manager.register("myid", "用户ID", 10, self.myid)
|
||||
self.manager.register("provider", "LLM 接入源", 10, self.provider)
|
||||
|
||||
def _check_auth(self, message: AstrMessageEvent, context: Context):
|
||||
if os.environ.get("TEST_MODE", "off") == "on":
|
||||
return
|
||||
if message.role != "admin":
|
||||
user_id = message.message_obj.sender.user_id
|
||||
raise Exception(f"用户(ID: {user_id}) 没有足够的权限使用该指令。")
|
||||
|
||||
def provider(self, message: AstrMessageEvent, context: Context):
|
||||
if len(context.llms) == 0:
|
||||
@@ -57,9 +64,8 @@ class InternalCommandHandler:
|
||||
return CommandResult().message("provider: 参数错误。")
|
||||
|
||||
def set_nick(self, message: AstrMessageEvent, context: Context):
|
||||
self._check_auth(message, context)
|
||||
message_str = message.message_str
|
||||
if message.role != "admin":
|
||||
return CommandResult().message("你没有权限使用该指令。")
|
||||
l = message_str.split(" ")
|
||||
if len(l) == 1:
|
||||
return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词是:{context.config_helper.wake_prefix[0]}")
|
||||
@@ -74,15 +80,10 @@ class InternalCommandHandler:
|
||||
message_chain=f"已经成功将唤醒前缀设定为 {nick}。",
|
||||
)
|
||||
|
||||
def update(self, message: AstrMessageEvent, context: Context):
|
||||
async def update(self, message: AstrMessageEvent, context: Context):
|
||||
self._check_auth(message, context)
|
||||
tokens = self.manager.command_parser.parse(message.message_str)
|
||||
if message.role != "admin":
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=False,
|
||||
message_chain="你没有权限使用该指令",
|
||||
)
|
||||
update_info = context.updator.check_update(None, None)
|
||||
update_info = await context.updator.check_update(None, None)
|
||||
if tokens.len == 1:
|
||||
ret = ""
|
||||
if not update_info:
|
||||
@@ -93,13 +94,13 @@ class InternalCommandHandler:
|
||||
else:
|
||||
if tokens.get(1) == "latest":
|
||||
try:
|
||||
context.updator.update()
|
||||
await context.updator.update()
|
||||
return CommandResult().message(f"已经成功更新到最新版本 v{update_info.version}。要应用更新,请重启 AstrBot。输入 /reboot 即可重启")
|
||||
except BaseException as e:
|
||||
return CommandResult().message(f"更新失败。原因:{str(e)}")
|
||||
elif tokens.get(1).startswith("v"):
|
||||
try:
|
||||
context.updator.update(version=tokens.get(1))
|
||||
await context.updator.update(version=tokens.get(1))
|
||||
return CommandResult().message(f"已经成功更新到版本 v{tokens.get(1)}。要应用更新,请重启 AstrBot。输入 /reboot 即可重启")
|
||||
except BaseException as e:
|
||||
return CommandResult().message(f"更新失败。原因:{str(e)}")
|
||||
@@ -107,12 +108,7 @@ class InternalCommandHandler:
|
||||
return CommandResult().message("update: 参数错误。")
|
||||
|
||||
def reboot(self, message: AstrMessageEvent, context: Context):
|
||||
if message.role != "admin":
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=False,
|
||||
message_chain="你没有权限使用该指令",
|
||||
)
|
||||
self._check_auth(message, context)
|
||||
context.updator._reboot(3, context)
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
@@ -120,7 +116,7 @@ class InternalCommandHandler:
|
||||
message_chain="AstrBot 将在 3s 后重启。",
|
||||
)
|
||||
|
||||
def plugin(self, message: AstrMessageEvent, context: Context):
|
||||
async def plugin(self, message: AstrMessageEvent, context: Context):
|
||||
tokens = self.manager.command_parser.parse(message.message_str)
|
||||
if tokens.len == 1:
|
||||
ret = "# 插件指令面板 \n- 安装插件: `plugin i 插件Github地址`\n- 卸载插件: `plugin d 插件名`\n- 查看插件列表:`plugin l`\n - 更新插件: `plugin u 插件名`\n"
|
||||
@@ -133,10 +129,10 @@ class InternalCommandHandler:
|
||||
if plugin_list_info.strip() == "":
|
||||
return CommandResult().message("plugin v: 没有找到插件。")
|
||||
return CommandResult().message(plugin_list_info)
|
||||
|
||||
self._check_auth(message, context)
|
||||
|
||||
elif tokens.get(1) == "d":
|
||||
if message.role != "admin":
|
||||
return CommandResult().message("plugin d: 你没有权限使用该指令。")
|
||||
if tokens.get(1) == "d":
|
||||
if tokens.len == 2:
|
||||
return CommandResult().message("plugin d: 请指定要卸载的插件名。")
|
||||
plugin_name = tokens.get(2)
|
||||
@@ -147,25 +143,21 @@ class InternalCommandHandler:
|
||||
return CommandResult().message(f"plugin d: 已经成功卸载插件 {plugin_name}。")
|
||||
|
||||
elif tokens.get(1) == "i":
|
||||
if message.role != "admin":
|
||||
return CommandResult().message("plugin i: 你没有权限使用该指令。")
|
||||
if tokens.len == 2:
|
||||
return CommandResult().message("plugin i: 请指定要安装的插件的 Github 地址,或者前往可视化面板安装。")
|
||||
plugin_url = tokens.get(2)
|
||||
try:
|
||||
self.plugin_manager.install_plugin(plugin_url)
|
||||
await self.plugin_manager.install_plugin(plugin_url)
|
||||
except BaseException as e:
|
||||
return CommandResult().message(f"plugin i: 安装插件失败。原因:{str(e)}")
|
||||
return CommandResult().message("plugin i: 已经成功安装插件。")
|
||||
|
||||
elif tokens.get(1) == "u":
|
||||
if message.role != "admin":
|
||||
return CommandResult().message("plugin u: 你没有权限使用该指令。")
|
||||
if tokens.len == 2:
|
||||
return CommandResult().message("plugin u: 请指定要更新的插件名。")
|
||||
plugin_name = tokens.get(2)
|
||||
try:
|
||||
self.plugin_manager.update_plugin(plugin_name)
|
||||
await context.plugin_updator.update(plugin_name)
|
||||
except BaseException as e:
|
||||
return CommandResult().message(f"plugin u: 更新插件失败。原因:{str(e)}")
|
||||
return CommandResult().message(f"plugin u: 已经成功更新插件 {plugin_name}。")
|
||||
|
||||
@@ -9,7 +9,7 @@ from type.command import CommandResult
|
||||
from type.register import RegisteredPlugins
|
||||
from model.command.parser import CommandParser
|
||||
from model.plugin.command import PluginCommandBridge
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from model.command.manager import CommandManager
|
||||
from type.message_event import AstrMessageEvent
|
||||
from type.command import CommandResult
|
||||
from type.types import Context
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from nakuru.entities.components import Image
|
||||
from model.provider.openai_official import ProviderOpenAIOfficial, MODELS
|
||||
|
||||
@@ -40,14 +40,18 @@ class Platform():
|
||||
'''
|
||||
pass
|
||||
|
||||
def parse_message_outline(self, message: AstrBotMessage) -> str:
|
||||
def parse_message_outline(self, message: Union[AstrBotMessage, list]) -> str:
|
||||
'''
|
||||
将消息解析成大纲消息形式,如: xxxxx[图片]xxxxx。用于输出日志等。
|
||||
'''
|
||||
if isinstance(message, str):
|
||||
return message
|
||||
ret = ''
|
||||
parsed = message if isinstance(message, list) else message.message
|
||||
if isinstance(message, list):
|
||||
parsed = message
|
||||
elif isinstance(message, AstrBotMessage):
|
||||
parsed = message.message
|
||||
elif isinstance(message, str):
|
||||
return message
|
||||
|
||||
try:
|
||||
for node in parsed:
|
||||
if isinstance(node, Plain):
|
||||
|
||||
@@ -3,7 +3,7 @@ import asyncio
|
||||
from util.io import port_checker
|
||||
from type.register import RegisteredPlatform
|
||||
from type.types import Context
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from util.cmd_config import (
|
||||
|
||||
@@ -10,7 +10,7 @@ from type.message_event import *
|
||||
from type.command import *
|
||||
from typing import Union, List, Dict
|
||||
from nakuru.entities.components import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from util.cmd_config import PlatformConfig, AiocqhttpPlatformConfig
|
||||
@@ -209,7 +209,7 @@ class AIOCQHTTP(Platform):
|
||||
|
||||
if isinstance(message, AstrBotMessage):
|
||||
logger.info(
|
||||
f"{message.sender.user_id} <- {self.parse_message_outline(message)}")
|
||||
f"{message.sender.nickname}/{message.sender.user_id} <- {self.parse_message_outline(message_chain)}")
|
||||
else:
|
||||
logger.info(f"回复消息: {message_chain}")
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from . import Platform
|
||||
from type.astrbot_message import *
|
||||
from type.message_event import *
|
||||
from type.command import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from util.cmd_config import PlatformConfig, NakuruPlatformConfig
|
||||
@@ -171,7 +171,7 @@ class QQNakuru(Platform):
|
||||
(GroupMessage, FriendMessage, GuildMessage))
|
||||
|
||||
logger.info(
|
||||
f"{source.user_id} <- {self.parse_message_outline(res)}")
|
||||
f"{message.sender.nickname}/{message.sender.user_id} <- {self.parse_message_outline(res)}")
|
||||
|
||||
if isinstance(res, str):
|
||||
res = [Plain(text=res), ]
|
||||
|
||||
@@ -16,7 +16,7 @@ from type.message_event import *
|
||||
from type.command import *
|
||||
from typing import Union, List, Dict
|
||||
from nakuru.entities.components import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from util.cmd_config import PlatformConfig, QQOfficialPlatformConfig
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from type.register import RegisteredPlugins
|
||||
from typing import List, Union, Callable
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -13,7 +13,7 @@ from types import ModuleType
|
||||
from type.types import Context
|
||||
from type.plugin import *
|
||||
from type.register import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
@@ -107,13 +107,13 @@ class PluginManager():
|
||||
rc = process.poll()
|
||||
|
||||
|
||||
def install_plugin(self, repo_url: str):
|
||||
async def install_plugin(self, repo_url: str):
|
||||
ppath = self.plugin_store_path
|
||||
|
||||
# we no longer use Git anymore :)
|
||||
# Repo.clone_from(repo_url, to_path=plugin_path, branch='master')
|
||||
|
||||
plugin_path = self.updator.update(repo_url)
|
||||
plugin_path = await self.updator.update(repo_url)
|
||||
with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f:
|
||||
f.write(repo_url)
|
||||
|
||||
@@ -124,14 +124,14 @@ class PluginManager():
|
||||
# if not ok:
|
||||
# raise Exception(err)
|
||||
|
||||
def download_from_repo_url(self, target_path: str, repo_url: str):
|
||||
async def download_from_repo_url(self, target_path: str, repo_url: str):
|
||||
repo_namespace = repo_url.split("/")[-2:]
|
||||
author = repo_namespace[0]
|
||||
repo = repo_namespace[1]
|
||||
|
||||
logger.info(f"正在下载插件 {repo} ...")
|
||||
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
|
||||
releases = self.updator.fetch_release_info(url=release_url)
|
||||
releases = await self.updator.fetch_release_info(url=release_url)
|
||||
if not releases:
|
||||
# download from the default branch directly.
|
||||
logger.warn(f"未在插件 {author}/{repo} 中找到任何发布版本,将从默认分支下载。")
|
||||
@@ -139,7 +139,7 @@ class PluginManager():
|
||||
else:
|
||||
release_url = releases[0]['zipball_url']
|
||||
|
||||
download_file(release_url, target_path + ".zip")
|
||||
await download_file(release_url, target_path + ".zip")
|
||||
|
||||
def get_registered_plugin(self, plugin_name: str) -> RegisteredPlugin:
|
||||
for p in self.context.cached_plugins:
|
||||
@@ -156,12 +156,12 @@ class PluginManager():
|
||||
if not remove_dir(os.path.join(ppath, root_dir_name)):
|
||||
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
|
||||
|
||||
def update_plugin(self, plugin_name: str):
|
||||
async def update_plugin(self, plugin_name: str):
|
||||
plugin = self.get_registered_plugin(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
|
||||
self.updator.update(plugin)
|
||||
await self.updator.update(plugin)
|
||||
|
||||
def plugin_reload(self):
|
||||
cached_plugins = self.context.cached_plugins
|
||||
|
||||
@@ -15,7 +15,7 @@ from util.io import download_image_by_url
|
||||
from astrbot.persist.helper import dbConn
|
||||
from model.provider.provider import Provider
|
||||
from util.cmd_config import LLMConfig
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
+1
-3
@@ -1,6 +1,5 @@
|
||||
pydantic~=1.10.4
|
||||
aiohttp
|
||||
requests
|
||||
openai
|
||||
qq-botpy
|
||||
chardet~=5.1.0
|
||||
@@ -10,10 +9,9 @@ beautifulsoup4
|
||||
googlesearch-python
|
||||
tiktoken
|
||||
readability-lxml
|
||||
baidu-aip
|
||||
websockets
|
||||
flask
|
||||
psutil
|
||||
lxml_html_clean
|
||||
SparkleLogging
|
||||
colorlog
|
||||
aiocqhttp
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
BASE_URL = "http://0.0.0.0:6185/api"
|
||||
|
||||
async def get_url(url):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
return await response.json()
|
||||
|
||||
async def post_url(url, data):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, json=data) as response:
|
||||
return await response.json()
|
||||
|
||||
class TestHTTPServer:
|
||||
@pytest.mark.asyncio
|
||||
async def test_config(self):
|
||||
configs = await get_url(f"{BASE_URL}/configs")
|
||||
assert 'data' in configs and 'metadata' in configs['data'] \
|
||||
and 'config' in configs['data']
|
||||
config = configs['data']['config']
|
||||
# test post config
|
||||
await post_url(f"{BASE_URL}/astrbot-configs", config)
|
||||
# text post config with invalid data
|
||||
assert 'rate_limit' in config['platform_settings']
|
||||
config['platform_settings']['rate_limit'] = "invalid"
|
||||
ret = await post_url(f"{BASE_URL}/astrbot-configs", config)
|
||||
assert 'status' in ret and ret['status'] == 'error'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update(self):
|
||||
await get_url(f"{BASE_URL}/check_update")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugins(self):
|
||||
pname = "astrbot_plugin_bilibili"
|
||||
url = f"https://github.com/Soulter/{pname}"
|
||||
|
||||
await get_url(f"{BASE_URL}/extensions")
|
||||
|
||||
# test install plugin
|
||||
await post_url(f"{BASE_URL}/extensions/install", {
|
||||
"url": url
|
||||
})
|
||||
|
||||
# test uninstall plugin
|
||||
await post_url(f"{BASE_URL}/extensions/uninstall", {
|
||||
"name": pname
|
||||
})
|
||||
+13
-8
@@ -11,16 +11,11 @@ from model.platform.qq_aiocqhttp import AIOCQHTTP
|
||||
from model.provider.openai_official import ProviderOpenAIOfficial
|
||||
from type.astrbot_message import *
|
||||
from type.message_event import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Formatter
|
||||
from util.log import LogManager
|
||||
|
||||
from util.cmd_config import QQOfficialPlatformConfig, AiocqhttpPlatformConfig
|
||||
|
||||
logger = LogManager.GetLogger(
|
||||
log_name='astrbot',
|
||||
out_to_console=True,
|
||||
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
|
||||
)
|
||||
logger = LogManager.GetLogger(log_name='astrbot')
|
||||
pytest_plugins = ('pytest_asyncio',)
|
||||
|
||||
os.environ['TEST_MODE'] = 'on'
|
||||
@@ -135,7 +130,17 @@ class TestInteralCommandHsandle():
|
||||
abm = self.create("/t2i")
|
||||
await aiocqhttp.handle_msg(abm)
|
||||
await self.fast_test("/help")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin(self):
|
||||
pname = "astrbot_plugin_bilibili"
|
||||
url = f"https://github.com/Soulter/{pname}"
|
||||
await self.fast_test("/plugin")
|
||||
await self.fast_test(f"/plugin l")
|
||||
await self.fast_test(f"/plugin i {url}")
|
||||
await self.fast_test(f"/plugin u {url}")
|
||||
await self.fast_test(f"/plugin d {pname}")
|
||||
|
||||
class TestLLMChat():
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_chat(self):
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
from asyncio import Queue
|
||||
from collections import deque
|
||||
from typing import Deque
|
||||
|
||||
class CachedQueue(Queue):
|
||||
def __init__(self, maxsize: int = 0, cachesize: int = 200):
|
||||
super().__init__(maxsize)
|
||||
self.cache = deque(maxlen=cachesize)
|
||||
|
||||
def put_nowait(self, item):
|
||||
self.cache.append(item)
|
||||
super().put_nowait(item)
|
||||
|
||||
def get_nowait(self):
|
||||
item = super().get_nowait()
|
||||
return item
|
||||
|
||||
def get(self):
|
||||
item = super().get()
|
||||
return item
|
||||
|
||||
def clear(self):
|
||||
self.cache.clear()
|
||||
with self.mutex:
|
||||
self._queue.clear()
|
||||
|
||||
def get_cache(self) -> Deque:
|
||||
return self.cache
|
||||
+5
-1
@@ -1,4 +1,4 @@
|
||||
VERSION = '3.3.12'
|
||||
VERSION = '3.3.14'
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"qqbot": {
|
||||
@@ -169,6 +169,8 @@ DEFAULT_CONFIG_VERSION_2 = {
|
||||
"username": "",
|
||||
"password": "",
|
||||
},
|
||||
"log_level": "INFO",
|
||||
"t2i_endpoint": "",
|
||||
}
|
||||
|
||||
# 这个是用于迁移旧版本配置文件的映射表
|
||||
@@ -350,4 +352,6 @@ CONFIG_METADATA_2 = {
|
||||
"password": {"description": "密码", "type": "string"},
|
||||
}
|
||||
},
|
||||
"log_level": {"description": "控制台日志级别(DEBUG, INFO, WARNING, ERROR)", "type": "string"},
|
||||
"t2i_endpoint": {"description": "文本转图像服务接口(为空时使用公共服务器)", "type": "string"},
|
||||
}
|
||||
|
||||
+3
-1
@@ -13,7 +13,7 @@ from type.middleware import Middleware
|
||||
from type.astrbot_message import MessageType
|
||||
from model.plugin.command import PluginCommandBridge
|
||||
from model.provider.provider import Provider
|
||||
from util.agent.func_call import FuncCall
|
||||
from type.cached_queue import CachedQueue
|
||||
|
||||
|
||||
class Context:
|
||||
@@ -49,6 +49,8 @@ class Context:
|
||||
self.command_manager = None
|
||||
self.running = True
|
||||
|
||||
self._log_queue = CachedQueue()
|
||||
|
||||
# useless
|
||||
# self.reply_prefix = ""
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from util.websearch.bing import Bing
|
||||
from util.websearch.sogo import Sogo
|
||||
from util.websearch.google import Google
|
||||
from model.provider.provider import Provider
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from type.types import Context
|
||||
from type.message_event import AstrMessageEvent
|
||||
|
||||
@@ -133,6 +133,8 @@ class AstrBotConfig():
|
||||
dashboard: DashboardConfig = field(default_factory=DashboardConfig)
|
||||
platform: List[PlatformConfig] = field(default_factory=list)
|
||||
wake_prefix: List[str] = field(default_factory=list)
|
||||
log_level: str = "INFO"
|
||||
t2i_endpoint: str = ""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.init_configs()
|
||||
@@ -174,6 +176,8 @@ class AstrBotConfig():
|
||||
self.http_proxy=data.get("http_proxy", "")
|
||||
self.dashboard=DashboardConfig(**data.get("dashboard", {}))
|
||||
self.wake_prefix=data.get("wake_prefix", [])
|
||||
self.log_level=data.get("log_level", "INFO")
|
||||
self.t2i_endpoint=data.get("t2i_endpoint", "")
|
||||
|
||||
def migrate_config_1_2(self, old: dict) -> dict:
|
||||
'''将配置文件从版本 1 迁移至版本 2'''
|
||||
|
||||
+10
-7
@@ -4,10 +4,9 @@ import shutil
|
||||
import socket
|
||||
import time
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from PIL import Image
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
@@ -99,16 +98,20 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def download_file(url: str, path: str):
|
||||
async def download_file(url: str, path: str):
|
||||
'''
|
||||
从指定 url 下载文件到指定路径 path
|
||||
'''
|
||||
try:
|
||||
logger.info(f"下载文件: {url}")
|
||||
with requests.get(url, stream=True) as r:
|
||||
with open(path, 'wb') as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as resp:
|
||||
with open(path, 'wb') as f:
|
||||
while True:
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
+51
@@ -0,0 +1,51 @@
|
||||
import logging, asyncio, colorlog
|
||||
from type.cached_queue import CachedQueue
|
||||
|
||||
log_color_config = {
|
||||
'DEBUG': 'bold_blue', 'INFO': 'bold_cyan',
|
||||
'WARNING': 'bold_yellow', 'ERROR': 'red',
|
||||
'CRITICAL': 'bold_red', 'RESET': 'reset',
|
||||
'asctime': 'green'
|
||||
}
|
||||
|
||||
class LogQueueHandler(logging.Handler):
|
||||
def __init__(self, log_queue: CachedQueue):
|
||||
super().__init__()
|
||||
self.log_queue = log_queue
|
||||
|
||||
def emit(self, record):
|
||||
log_entry = self.format(record)
|
||||
try:
|
||||
self.log_queue.put_nowait(log_entry)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
class LogManager:
|
||||
|
||||
@classmethod
|
||||
def GetLogger(cls, log_name: str = 'default'):
|
||||
logger = logging.getLogger(log_name)
|
||||
if logger.hasHandlers():
|
||||
return logger
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
console_formatter = colorlog.ColoredFormatter(
|
||||
fmt='%(log_color)s [%(asctime)s| %(levelname)s] [%(funcName)s|%(filename)s:%(lineno)d]: %(message)s %(reset)s',
|
||||
datefmt='%H:%M:%S',
|
||||
log_colors=log_color_config
|
||||
)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
return logger
|
||||
|
||||
@classmethod
|
||||
def set_queue_handler(cls, logger: logging.Logger, log_queue: CachedQueue):
|
||||
handler = LogQueueHandler(log_queue)
|
||||
handler.setLevel(logging.DEBUG)
|
||||
if logger.handlers:
|
||||
handler.setFormatter(logger.handlers[0].formatter)
|
||||
else:
|
||||
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
|
||||
logger.addHandler(handler)
|
||||
+4
-3
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
import requests
|
||||
import aiohttp
|
||||
import json
|
||||
import sys
|
||||
|
||||
@@ -57,8 +57,9 @@ class MetricUploader():
|
||||
"command_stats": self.command_stats,
|
||||
"sys": sys.platform, # 系统版本
|
||||
}
|
||||
resp = requests.post(
|
||||
'https://api.soulter.top/upload', data=json.dumps(res), timeout=5)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post('https://api.soulter.top/upload', data=json.dumps(res), timeout=5) as resp:
|
||||
pass
|
||||
if resp.status_code == 200:
|
||||
ok = resp.json()
|
||||
if ok['status'] == 'ok':
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
from util.t2i.strategies.local_strategy import LocalRenderStrategy
|
||||
from util.t2i.strategies.network_strategy import NetworkRenderStrategy
|
||||
from util.t2i.context import RenderContext
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
class TextToImageRenderer:
|
||||
def __init__(self):
|
||||
self.network_strategy = NetworkRenderStrategy()
|
||||
def __init__(self, endpoint_url: str = None):
|
||||
self.network_strategy = NetworkRenderStrategy(endpoint_url)
|
||||
self.local_strategy = LocalRenderStrategy()
|
||||
self.context = RenderContext(self.network_strategy)
|
||||
|
||||
def set_network_endpoint(self, endpoint_url: str):
|
||||
'''设置 t2i 的网络端点。
|
||||
'''
|
||||
logger.info("文本转图像服务接口: " + endpoint_url)
|
||||
self.network_strategy.set_endpoint(endpoint_url)
|
||||
|
||||
async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool = False):
|
||||
'''使用自定义文转图模板。该方法会通过网络调用 t2i 终结点图文渲染API。
|
||||
@param tmpl_str: HTML Jinja2 模板。
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import re
|
||||
import requests
|
||||
import aiohttp
|
||||
from io import BytesIO
|
||||
|
||||
from .base_strategy import RenderStrategy
|
||||
from PIL import ImageFont, Image, ImageDraw
|
||||
@@ -82,8 +83,9 @@ class LocalRenderStrategy(RenderStrategy):
|
||||
try:
|
||||
image_url = re.findall(IMAGE_REGEX, line)[0]
|
||||
print(image_url)
|
||||
image_res = Image.open(requests.get(
|
||||
image_url, stream=True, timeout=5).raw)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url) as resp:
|
||||
image_res = Image.open(BytesIO(await resp.read()))
|
||||
images[i] = image_res
|
||||
# 最大不得超过image_width的50%
|
||||
img_height = image_res.size[1]
|
||||
|
||||
@@ -10,9 +10,16 @@ ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img"
|
||||
class NetworkRenderStrategy(RenderStrategy):
|
||||
def __init__(self, base_url: str = ASTRBOT_T2I_DEFAULT_ENDPOINT) -> None:
|
||||
super().__init__()
|
||||
if not base_url:
|
||||
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
||||
self.BASE_RENDER_URL = base_url
|
||||
self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template")
|
||||
|
||||
def set_endpoint(self, base_url: str):
|
||||
if not base_url:
|
||||
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
||||
self.BASE_RENDER_URL = base_url
|
||||
|
||||
async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool=True) -> str:
|
||||
'''使用自定义文转图模板'''
|
||||
post_data = {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os, psutil, sys, time
|
||||
from util.updator.zip_updator import ReleaseInfo, RepoZipUpdator
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from type.config import VERSION
|
||||
from util.io import on_error, download_file
|
||||
@@ -31,6 +31,9 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
pass
|
||||
|
||||
def _reboot(self, delay: int = None, context = None):
|
||||
if os.environ.get('TEST_MODE', 'off') == 'on':
|
||||
logger.info("测试模式下不会重启。")
|
||||
return
|
||||
# if delay: time.sleep(delay)
|
||||
py = sys.executable
|
||||
context.running = False
|
||||
@@ -43,11 +46,11 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
logger.error(f"重启失败({py}, {e}),请尝试手动重启。")
|
||||
raise e
|
||||
|
||||
def check_update(self, url: str, current_version: str) -> ReleaseInfo:
|
||||
return super().check_update(self.ASTRBOT_RELEASE_API, VERSION)
|
||||
async def check_update(self, url: str, current_version: str) -> ReleaseInfo:
|
||||
return await super().check_update(self.ASTRBOT_RELEASE_API, VERSION)
|
||||
|
||||
def update(self, reboot = False, latest = True, version = None):
|
||||
update_data = self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest)
|
||||
async def update(self, reboot = False, latest = True, version = None):
|
||||
update_data = await self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest)
|
||||
file_url = None
|
||||
|
||||
if latest:
|
||||
@@ -65,16 +68,10 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
raise Exception(f"未找到版本号为 {version} 的更新文件。")
|
||||
|
||||
try:
|
||||
download_file(file_url, "temp.zip")
|
||||
await download_file(file_url, "temp.zip")
|
||||
self.unzip_file("temp.zip", self.MAIN_PATH)
|
||||
except BaseException as e:
|
||||
raise e
|
||||
|
||||
if reboot:
|
||||
self._reboot()
|
||||
|
||||
def unzip_file(self, zip_path: str, target_dir: str):
|
||||
'''
|
||||
解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir
|
||||
'''
|
||||
pass
|
||||
@@ -4,7 +4,7 @@ from util.updator.zip_updator import RepoZipUpdator
|
||||
from util.io import remove_dir
|
||||
from type.register import RegisteredPlugin
|
||||
from typing import Union
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from util.io import on_error
|
||||
|
||||
@@ -18,7 +18,7 @@ class PluginUpdator(RepoZipUpdator):
|
||||
def get_plugin_store_path(self) -> str:
|
||||
return self.plugin_store_path
|
||||
|
||||
def update(self, plugin: Union[RegisteredPlugin, str]) -> str:
|
||||
async def update(self, plugin: Union[RegisteredPlugin, str]) -> str:
|
||||
repo_url = None
|
||||
|
||||
if not isinstance(plugin, str):
|
||||
@@ -33,7 +33,7 @@ class PluginUpdator(RepoZipUpdator):
|
||||
plugin_path = os.path.join(self.plugin_store_path, self.format_repo_name(repo_url))
|
||||
|
||||
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
|
||||
self.download_from_repo_url(plugin_path, repo_url)
|
||||
await self.download_from_repo_url(plugin_path, repo_url)
|
||||
|
||||
try:
|
||||
remove_dir(plugin_path)
|
||||
|
||||
+13
-12
@@ -1,5 +1,5 @@
|
||||
import requests, os, zipfile, shutil
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
import aiohttp, os, zipfile, shutil
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from util.io import on_error, download_file
|
||||
|
||||
@@ -23,14 +23,15 @@ class RepoZipUpdator():
|
||||
self.path = path
|
||||
self.rm_on_error = on_error
|
||||
|
||||
def fetch_release_info(self, url: str, latest: bool = True) -> list:
|
||||
async def fetch_release_info(self, url: str, latest: bool = True) -> list:
|
||||
'''
|
||||
请求版本信息。
|
||||
返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。
|
||||
'''
|
||||
result = requests.get(url).json()
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
result = await response.json()
|
||||
if not result: return []
|
||||
if latest:
|
||||
ret = self.github_api_release_parser([result[0]])
|
||||
@@ -66,7 +67,7 @@ class RepoZipUpdator():
|
||||
def unzip(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def update(self):
|
||||
async def update(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def compare_version(self, v1: str, v2: str) -> int:
|
||||
@@ -86,8 +87,8 @@ class RepoZipUpdator():
|
||||
return -1
|
||||
return 0
|
||||
|
||||
def check_update(self, url: str, current_version: str) -> ReleaseInfo:
|
||||
update_data = self.fetch_release_info(url)
|
||||
async def check_update(self, url: str, current_version: str) -> ReleaseInfo:
|
||||
update_data = await self.fetch_release_info(url)
|
||||
tag_name = update_data[0]['tag_name']
|
||||
|
||||
if self.compare_version(current_version, tag_name) >= 0:
|
||||
@@ -98,22 +99,22 @@ class RepoZipUpdator():
|
||||
body=update_data[0]['body']
|
||||
)
|
||||
|
||||
def download_from_repo_url(self, target_path: str, repo_url: str):
|
||||
async def download_from_repo_url(self, target_path: str, repo_url: str):
|
||||
repo_namespace = repo_url.split("/")[-2:]
|
||||
author = repo_namespace[0]
|
||||
repo = repo_namespace[1]
|
||||
|
||||
logger.info(f"正在下载更新 {repo} ...")
|
||||
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
|
||||
releases = self.fetch_release_info(url=release_url)
|
||||
releases = await self.fetch_release_info(url=release_url)
|
||||
if not releases:
|
||||
# download from the default branch directly.
|
||||
logger.warn(f"未在仓库 {author}/{repo} 中找到任何发布版本,将从默认分支下载。")
|
||||
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,将从默认分支下载。")
|
||||
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
|
||||
else:
|
||||
release_url = releases[0]['zipball_url']
|
||||
|
||||
download_file(release_url, target_path + ".zip")
|
||||
await download_file(release_url, target_path + ".zip")
|
||||
|
||||
|
||||
def unzip_file(self, zip_path: str, target_dir: str):
|
||||
|
||||
Reference in New Issue
Block a user