Compare commits

...

9 Commits

Author SHA1 Message Date
Soulter 353b6ed761 feat: 支持自定义文转图服务地址 2024-09-22 10:50:47 -04:00
Soulter 90815b1ac5 chore: update version to 3.3.14 2024-09-22 10:25:26 -04:00
Soulter 8a50786e61 feat: 支持设置控制台日志级别;
refactor: 重写了后端与仪表盘的日志通信
2024-09-22 10:23:26 -04:00
Soulter 3b77df0556 fix: 修复下载更新后压缩包不解压的问题 2024-09-21 12:37:05 -04:00
Soulter 1fa11062de fix: /plugin u 指令异常 2024-09-21 12:33:00 -04:00
Soulter 6883de0f1c feat: partially test http server api 2024-09-21 12:19:49 -04:00
Soulter bdde0fe094 refactor: HTTP 请求全部异步化,移除了 baidu_aip, request 依赖 2024-09-21 11:36:02 -04:00
Soulter ab22b8103e Merge pull request #208 from Soulter/fix-issue-207
fix: 修复仪表盘保存配置递归校验失效的问题
2024-09-21 22:42:16 +08:00
Soulter 641d5cd67b fix: 修复仪表盘保存配置递归校验失效的问题 2024-09-21 10:40:32 -04:00
33 changed files with 358 additions and 165 deletions
+19 -7
View File
@@ -11,15 +11,14 @@ from model.platform.manager import PlatformManager
from typing import Union
from type.types import Context
from type.config import VERSION
from SparkleLogging.utils.core import LogManager
from logging import Logger
from util.cmd_config import AstrBotConfig, try_migrate
from util.metrics import MetricUploader
from util.updator.astrbot_updator import AstrBotUpdator
from util.log import LogManager
logger: Logger = LogManager.GetLogger(log_name='astrbot')
class AstrBotBootstrap():
def __init__(self) -> None:
self.context = Context()
@@ -28,7 +27,11 @@ class AstrBotBootstrap():
try_migrate()
self.config_helper = AstrBotConfig()
self.context.config_helper = self.config_helper
# set log queue handler
LogManager.set_queue_handler(logger, self.context._log_queue)
logger.info("AstrBot v" + VERSION)
# set log level
logger.setLevel(self.config_helper.log_level)
# apply proxy settings
http_proxy = self.context.config_helper.http_proxy
https_proxy = self.context.config_helper.https_proxy
@@ -44,6 +47,12 @@ class AstrBotBootstrap():
logger.info("未使用代理。")
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
# set t2i endpoint
if self.context.config_helper.t2i_endpoint:
self.context.image_renderer.set_network_endpoint(
self.context.config_helper.t2i_endpoint
)
async def run(self):
self.command_manager = CommandManager()
@@ -65,7 +74,12 @@ class AstrBotBootstrap():
self.context.plugin_updator = self.plugin_manager.updator
self.context.message_handler = self.message_handler
self.context.command_manager = self.command_manager
# load dashboard
self.dashboard.run_http_server()
dashboard_ws_task = asyncio.create_task(self.dashboard.ws_server(), name="dashboard")
dashboard_log_task = asyncio.create_task(self.dashboard.log_consumer(), name="log")
if self.test_mode:
return
@@ -77,10 +91,8 @@ class AstrBotBootstrap():
platform_tasks = self.load_platform()
# load metrics uploader
metrics_upload_task = asyncio.create_task(self.metrics_uploader.upload_metrics(), name="metrics-uploader")
# load dashboard
self.dashboard.run_http_server()
dashboard_task = asyncio.create_task(self.dashboard.ws_server(), name="dashboard")
tasks = [metrics_upload_task, dashboard_task, *platform_tasks, *self.context.ext_tasks]
tasks = [metrics_upload_task, dashboard_ws_task, dashboard_log_task, *platform_tasks, *self.context.ext_tasks]
tasks = [self.handle_task(task) for task in tasks]
await asyncio.gather(*tasks)
+4 -1
View File
@@ -11,7 +11,7 @@ from model.command.manager import CommandManager
from type.message_event import AstrMessageEvent, MessageResult
from type.types import Context
from type.command import CommandResult
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
from nakuru.entities.components import Image
from util.agent.func_call import FuncCall
@@ -60,6 +60,9 @@ class ContentSafetyHelper():
from astrbot.message.baidu_aip_judge import BaiduJudge
self.baidu_judge = BaiduJudge(aip)
logger.info("已启用百度 AI 内容审核。")
except ImportError as e:
logger.error("检测到库依赖不完整,将不会启用百度 AI 内容审核。请先使用 pip 安装 `baidu_aip` 包。")
logger.error(e)
except BaseException as e:
logger.error("百度 AI 内容审核初始化失败。")
logger.error(e)
+37 -18
View File
@@ -2,7 +2,7 @@ from . import DashBoardData
from util.cmd_config import AstrBotConfig
from dataclasses import dataclass, asdict
from util.plugin_dev.api.v1.config import update_config
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
from type.types import Context
from type.config import CONFIG_METADATA_2
@@ -15,29 +15,48 @@ class DashBoardHelper():
self.context = context
self.config_key_dont_show = ['dashboard', 'config_version']
def try_cast(self, value: str, type_: str):
if type_ == "int" and value.isdigit():
return int(value)
elif type_ == "float" and isinstance(value, str) \
and value.replace(".", "", 1).isdigit():
return float(value)
elif type_ == "float" and isinstance(value, int):
return float(value)
def validate_config(self, data):
errors = []
# 递归验证数据
def validate(data, path=""):
for key, meta in CONFIG_METADATA_2.items():
def validate(data, metadata=CONFIG_METADATA_2, path=""):
for key, meta in metadata.items():
if key not in data:
if key not in self.config_key_dont_show:
# 这些key不会传给前端,所以不需要验证
errors.append(f"Missing key: {path}{key}")
continue
value = data[key]
if meta["type"] == "int" and not isinstance(value, int):
errors.append(f"Invalid type for {path}{key}: expected int, got {type(value).__name__}")
elif meta["type"] == "bool" and not isinstance(value, bool):
errors.append(f"Invalid type for {path}{key}: expected bool, got {type(value).__name__}")
elif meta["type"] == "string" and not isinstance(value, str):
errors.append(f"Invalid type for {path}{key}: expected string, got {type(value).__name__}")
elif meta["type"] == "list" and not isinstance(value, list):
errors.append(f"Invalid type for {path}{key}: expected list, got {type(value).__name__}")
# 递归验证
if meta["type"] == "list" and isinstance(value, list):
for item in value:
validate(item, meta["items"], path=f"{path}{key}.")
elif meta["type"] == "dict" and not isinstance(value, dict):
errors.append(f"Invalid type for {path}{key}: expected dict, got {type(value).__name__}")
elif meta["type"] == "object" and isinstance(value, dict):
validate(value, meta["items"], path=f"{path}{key}.")
if meta["type"] == "int" and not isinstance(value, int):
casted = self.try_cast(value, "int")
if casted is None:
errors.append(f"错误的类型 {path}{key}: 期望是 int, 得到了 {type(value).__name__}")
data[key] = casted
elif meta["type"] == "float" and not isinstance(value, float):
casted = self.try_cast(value, "float")
if casted is None:
errors.append(f"错误的类型 {path}{key}: 期望是 float, 得到了 {type(value).__name__}")
data[key] = casted
elif meta["type"] == "bool" and not isinstance(value, bool):
errors.append(f"错误的类型 {path}{key}: 期望是 bool, 得到了 {type(value).__name__}")
elif meta["type"] == "string" and not isinstance(value, str):
errors.append(f"错误的类型 {path}{key}: 期望是 string, 得到了 {type(value).__name__}")
elif meta["type"] == "list" and not isinstance(value, list):
errors.append(f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}")
elif meta["type"] == "object" and not isinstance(value, dict):
errors.append(f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}")
validate(value, meta["items"], path=f"{path}{key}.")
validate(data)
@@ -68,6 +87,6 @@ class DashBoardHelper():
typ = item['val_type']
if typ == 'int':
if not value.isdigit():
raise ValueError(f"Invalid type for {namespace}.{key}: expected int, got {type(value).__name__}")
raise ValueError(f"错误的类型 {namespace}.{key}: 期望是 int, 得到了 {type(value).__name__}")
value = int(value)
update_config(namespace, key, value)
+29 -21
View File
@@ -13,7 +13,7 @@ from werkzeug.serving import make_server
from astrbot.persist.helper import dbConn
from type.types import Context
from typing import List
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
from dashboard.helper import DashBoardHelper
from util.io import get_local_ip_addresses
@@ -206,7 +206,8 @@ class AstrBotDashBoard():
repo_url = post_data["url"]
try:
logger.info(f"正在安装插件 {repo_url}")
self.plugin_manager.install_plugin(repo_url)
# self.plugin_manager.install_plugin(repo_url)
asyncio.run_coroutine_threadsafe(self.plugin_manager.install_plugin(repo_url), self.loop).result()
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
logger.info(f"安装插件 {repo_url} 成功,2秒后重启")
return Response(
@@ -272,7 +273,8 @@ class AstrBotDashBoard():
plugin_name = post_data["name"]
try:
logger.info(f"正在更新插件 {plugin_name}")
self.plugin_manager.update_plugin(plugin_name)
# self.plugin_manager.update_plugin(plugin_name)
asyncio.run_coroutine_threadsafe(self.plugin_manager.update_plugin(plugin_name), self.loop).result()
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
return Response(
@@ -288,20 +290,12 @@ class AstrBotDashBoard():
data=None
).__dict__
@self.dashboard_be.post("/api/log")
def log():
for item in self.ws_clients:
try:
asyncio.run_coroutine_threadsafe(
self.ws_clients[item].send(request.data.decode()), self.loop).result()
except Exception as e:
pass
return 'ok'
@self.dashboard_be.get("/api/check_update")
def get_update_info():
try:
ret = self.astrbot_updator.check_update(None, None)
# ret = self.astrbot_updator.check_update(None, None)
ret = asyncio.run_coroutine_threadsafe(
self.astrbot_updator.check_update(None, None), self.loop).result()
return Response(
status="success",
message=str(ret) if ret is not None else "已经是最新版本了。",
@@ -326,7 +320,8 @@ class AstrBotDashBoard():
else:
latest = False
try:
self.astrbot_updator.update(latest=latest, version=version)
# await self.astrbot_updator.update(latest=latest, version=version)
asyncio.run_coroutine_threadsafe(self.astrbot_updator.update(latest=latest, version=version), self.loop).result()
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
return Response(
status="success",
@@ -417,18 +412,22 @@ class AstrBotDashBoard():
async def get_log_history(self):
try:
with open("logs/astrbot/astrbot.log", "r", encoding="utf-8") as f:
return f.readlines()[-100:]
dq = self.context._log_queue.get_cache()
ret = ""
for log in dq:
ret += log + "\n\r"
return ret
except Exception as e:
logger.warning(f"读取日志历史失败: {e.__str__()}")
return []
return ""
async def __handle_msg(self, websocket, path):
address = websocket.remote_address
self.ws_clients[address] = websocket
data = await self.get_log_history()
data = ''.join(data).replace('\n', '\r\n')
await websocket.send(data)
# 发送日志历史
await websocket.send(await self.get_log_history())
while True:
try:
msg = await websocket.recv()
@@ -445,6 +444,15 @@ class AstrBotDashBoard():
ws_server = websockets.serve(self.__handle_msg, "0.0.0.0", 6186)
logger.info("WebSocket 服务器已启动。")
await ws_server
async def log_consumer(self):
while True:
log = await self.context._log_queue.get()
for ws in self.ws_clients.values():
try:
await ws.send(log)
except Exception as e:
pass
def http_server(self):
http_server = make_server(
+4 -9
View File
@@ -6,8 +6,7 @@ import warnings
import traceback
import mimetypes
from astrbot.bootstrap import AstrBotBootstrap
from SparkleLogging.utils.core import LogManager
from logging import Formatter
from util.log import LogManager
warnings.filterwarnings("ignore")
logo_tmpl = r"""
@@ -27,6 +26,8 @@ def main():
# delete qqbotpy's logger
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
logger.info(logo_tmpl)
bootstrap = AstrBotBootstrap()
asyncio.run(bootstrap.run())
@@ -52,11 +53,5 @@ def check_env():
if __name__ == "__main__":
check_env()
logger = LogManager.GetLogger(
log_name='astrbot',
out_to_console=True,
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
)
logger.info(logo_tmpl)
logger = LogManager.GetLogger(log_name='astrbot')
main()
+22 -30
View File
@@ -1,4 +1,4 @@
import aiohttp
import aiohttp, os
from model.command.manager import CommandManager
from model.plugin.manager import PluginManager
@@ -6,7 +6,7 @@ from type.message_event import AstrMessageEvent
from type.command import CommandResult
from type.types import Context
from type.config import VERSION
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
from util.agent.web_searcher import search_from_bing, fetch_website_content
@@ -27,6 +27,13 @@ class InternalCommandHandler:
self.manager.register("t2i", "文转图", 10, self.t2i_toggle)
self.manager.register("myid", "用户ID", 10, self.myid)
self.manager.register("provider", "LLM 接入源", 10, self.provider)
def _check_auth(self, message: AstrMessageEvent, context: Context):
if os.environ.get("TEST_MODE", "off") == "on":
return
if message.role != "admin":
user_id = message.message_obj.sender.user_id
raise Exception(f"用户(ID: {user_id}) 没有足够的权限使用该指令。")
def provider(self, message: AstrMessageEvent, context: Context):
if len(context.llms) == 0:
@@ -57,9 +64,8 @@ class InternalCommandHandler:
return CommandResult().message("provider: 参数错误。")
def set_nick(self, message: AstrMessageEvent, context: Context):
self._check_auth(message, context)
message_str = message.message_str
if message.role != "admin":
return CommandResult().message("你没有权限使用该指令。")
l = message_str.split(" ")
if len(l) == 1:
return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词是:{context.config_helper.wake_prefix[0]}")
@@ -74,15 +80,10 @@ class InternalCommandHandler:
message_chain=f"已经成功将唤醒前缀设定为 {nick}",
)
def update(self, message: AstrMessageEvent, context: Context):
async def update(self, message: AstrMessageEvent, context: Context):
self._check_auth(message, context)
tokens = self.manager.command_parser.parse(message.message_str)
if message.role != "admin":
return CommandResult(
hit=True,
success=False,
message_chain="你没有权限使用该指令",
)
update_info = context.updator.check_update(None, None)
update_info = await context.updator.check_update(None, None)
if tokens.len == 1:
ret = ""
if not update_info:
@@ -93,13 +94,13 @@ class InternalCommandHandler:
else:
if tokens.get(1) == "latest":
try:
context.updator.update()
await context.updator.update()
return CommandResult().message(f"已经成功更新到最新版本 v{update_info.version}。要应用更新,请重启 AstrBot。输入 /reboot 即可重启")
except BaseException as e:
return CommandResult().message(f"更新失败。原因:{str(e)}")
elif tokens.get(1).startswith("v"):
try:
context.updator.update(version=tokens.get(1))
await context.updator.update(version=tokens.get(1))
return CommandResult().message(f"已经成功更新到版本 v{tokens.get(1)}。要应用更新,请重启 AstrBot。输入 /reboot 即可重启")
except BaseException as e:
return CommandResult().message(f"更新失败。原因:{str(e)}")
@@ -107,12 +108,7 @@ class InternalCommandHandler:
return CommandResult().message("update: 参数错误。")
def reboot(self, message: AstrMessageEvent, context: Context):
if message.role != "admin":
return CommandResult(
hit=True,
success=False,
message_chain="你没有权限使用该指令",
)
self._check_auth(message, context)
context.updator._reboot(3, context)
return CommandResult(
hit=True,
@@ -120,7 +116,7 @@ class InternalCommandHandler:
message_chain="AstrBot 将在 3s 后重启。",
)
def plugin(self, message: AstrMessageEvent, context: Context):
async def plugin(self, message: AstrMessageEvent, context: Context):
tokens = self.manager.command_parser.parse(message.message_str)
if tokens.len == 1:
ret = "# 插件指令面板 \n- 安装插件: `plugin i 插件Github地址`\n- 卸载插件: `plugin d 插件名`\n- 查看插件列表:`plugin l`\n - 更新插件: `plugin u 插件名`\n"
@@ -133,10 +129,10 @@ class InternalCommandHandler:
if plugin_list_info.strip() == "":
return CommandResult().message("plugin v: 没有找到插件。")
return CommandResult().message(plugin_list_info)
self._check_auth(message, context)
elif tokens.get(1) == "d":
if message.role != "admin":
return CommandResult().message("plugin d: 你没有权限使用该指令。")
if tokens.get(1) == "d":
if tokens.len == 2:
return CommandResult().message("plugin d: 请指定要卸载的插件名。")
plugin_name = tokens.get(2)
@@ -147,25 +143,21 @@ class InternalCommandHandler:
return CommandResult().message(f"plugin d: 已经成功卸载插件 {plugin_name}")
elif tokens.get(1) == "i":
if message.role != "admin":
return CommandResult().message("plugin i: 你没有权限使用该指令。")
if tokens.len == 2:
return CommandResult().message("plugin i: 请指定要安装的插件的 Github 地址,或者前往可视化面板安装。")
plugin_url = tokens.get(2)
try:
self.plugin_manager.install_plugin(plugin_url)
await self.plugin_manager.install_plugin(plugin_url)
except BaseException as e:
return CommandResult().message(f"plugin i: 安装插件失败。原因:{str(e)}")
return CommandResult().message("plugin i: 已经成功安装插件。")
elif tokens.get(1) == "u":
if message.role != "admin":
return CommandResult().message("plugin u: 你没有权限使用该指令。")
if tokens.len == 2:
return CommandResult().message("plugin u: 请指定要更新的插件名。")
plugin_name = tokens.get(2)
try:
self.plugin_manager.update_plugin(plugin_name)
await context.plugin_updator.update(plugin_name)
except BaseException as e:
return CommandResult().message(f"plugin u: 更新插件失败。原因:{str(e)}")
return CommandResult().message(f"plugin u: 已经成功更新插件 {plugin_name}")
+1 -1
View File
@@ -9,7 +9,7 @@ from type.command import CommandResult
from type.register import RegisteredPlugins
from model.command.parser import CommandParser
from model.plugin.command import PluginCommandBridge
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
from dataclasses import dataclass
+1 -1
View File
@@ -2,7 +2,7 @@ from model.command.manager import CommandManager
from type.message_event import AstrMessageEvent
from type.command import CommandResult
from type.types import Context
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
from nakuru.entities.components import Image
from model.provider.openai_official import ProviderOpenAIOfficial, MODELS
+8 -4
View File
@@ -40,14 +40,18 @@ class Platform():
'''
pass
def parse_message_outline(self, message: AstrBotMessage) -> str:
def parse_message_outline(self, message: Union[AstrBotMessage, list]) -> str:
'''
将消息解析成大纲消息形式,如: xxxxx[图片]xxxxx。用于输出日志等。
'''
if isinstance(message, str):
return message
ret = ''
parsed = message if isinstance(message, list) else message.message
if isinstance(message, list):
parsed = message
elif isinstance(message, AstrBotMessage):
parsed = message.message
elif isinstance(message, str):
return message
try:
for node in parsed:
if isinstance(node, Plain):
+1 -1
View File
@@ -3,7 +3,7 @@ import asyncio
from util.io import port_checker
from type.register import RegisteredPlatform
from type.types import Context
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
from astrbot.message.handler import MessageHandler
from util.cmd_config import (
+2 -2
View File
@@ -10,7 +10,7 @@ from type.message_event import *
from type.command import *
from typing import Union, List, Dict
from nakuru.entities.components import *
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
from astrbot.message.handler import MessageHandler
from util.cmd_config import PlatformConfig, AiocqhttpPlatformConfig
@@ -209,7 +209,7 @@ class AIOCQHTTP(Platform):
if isinstance(message, AstrBotMessage):
logger.info(
f"{message.sender.user_id} <- {self.parse_message_outline(message)}")
f"{message.sender.nickname}/{message.sender.user_id} <- {self.parse_message_outline(message_chain)}")
else:
logger.info(f"回复消息: {message_chain}")
+2 -2
View File
@@ -15,7 +15,7 @@ from . import Platform
from type.astrbot_message import *
from type.message_event import *
from type.command import *
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
from astrbot.message.handler import MessageHandler
from util.cmd_config import PlatformConfig, NakuruPlatformConfig
@@ -171,7 +171,7 @@ class QQNakuru(Platform):
(GroupMessage, FriendMessage, GuildMessage))
logger.info(
f"{source.user_id} <- {self.parse_message_outline(res)}")
f"{message.sender.nickname}/{message.sender.user_id} <- {self.parse_message_outline(res)}")
if isinstance(res, str):
res = [Plain(text=res), ]
+1 -1
View File
@@ -16,7 +16,7 @@ from type.message_event import *
from type.command import *
from typing import Union, List, Dict
from nakuru.entities.components import *
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
from astrbot.message.handler import MessageHandler
from util.cmd_config import PlatformConfig, QQOfficialPlatformConfig
+1 -1
View File
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from type.register import RegisteredPlugins
from typing import List, Union, Callable
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot')
+8 -8
View File
@@ -13,7 +13,7 @@ from types import ModuleType
from type.types import Context
from type.plugin import *
from type.register import *
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -107,13 +107,13 @@ class PluginManager():
rc = process.poll()
def install_plugin(self, repo_url: str):
async def install_plugin(self, repo_url: str):
ppath = self.plugin_store_path
# we no longer use Git anymore :)
# Repo.clone_from(repo_url, to_path=plugin_path, branch='master')
plugin_path = self.updator.update(repo_url)
plugin_path = await self.updator.update(repo_url)
with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f:
f.write(repo_url)
@@ -124,14 +124,14 @@ class PluginManager():
# if not ok:
# raise Exception(err)
def download_from_repo_url(self, target_path: str, repo_url: str):
async def download_from_repo_url(self, target_path: str, repo_url: str):
repo_namespace = repo_url.split("/")[-2:]
author = repo_namespace[0]
repo = repo_namespace[1]
logger.info(f"正在下载插件 {repo} ...")
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
releases = self.updator.fetch_release_info(url=release_url)
releases = await self.updator.fetch_release_info(url=release_url)
if not releases:
# download from the default branch directly.
logger.warn(f"未在插件 {author}/{repo} 中找到任何发布版本,将从默认分支下载。")
@@ -139,7 +139,7 @@ class PluginManager():
else:
release_url = releases[0]['zipball_url']
download_file(release_url, target_path + ".zip")
await download_file(release_url, target_path + ".zip")
def get_registered_plugin(self, plugin_name: str) -> RegisteredPlugin:
for p in self.context.cached_plugins:
@@ -156,12 +156,12 @@ class PluginManager():
if not remove_dir(os.path.join(ppath, root_dir_name)):
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
def update_plugin(self, plugin_name: str):
async def update_plugin(self, plugin_name: str):
plugin = self.get_registered_plugin(plugin_name)
if not plugin:
raise Exception("插件不存在。")
self.updator.update(plugin)
await self.updator.update(plugin)
def plugin_reload(self):
cached_plugins = self.context.cached_plugins
+1 -1
View File
@@ -15,7 +15,7 @@ from util.io import download_image_by_url
from astrbot.persist.helper import dbConn
from model.provider.provider import Provider
from util.cmd_config import LLMConfig
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
from typing import List, Dict
+1 -3
View File
@@ -1,6 +1,5 @@
pydantic~=1.10.4
aiohttp
requests
openai
qq-botpy
chardet~=5.1.0
@@ -10,10 +9,9 @@ beautifulsoup4
googlesearch-python
tiktoken
readability-lxml
baidu-aip
websockets
flask
psutil
lxml_html_clean
SparkleLogging
colorlog
aiocqhttp
+51
View File
@@ -0,0 +1,51 @@
import aiohttp
import pytest
BASE_URL = "http://0.0.0.0:6185/api"
async def get_url(url):
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
return await response.json()
async def post_url(url, data):
async with aiohttp.ClientSession() as session:
async with session.post(url, json=data) as response:
return await response.json()
class TestHTTPServer:
@pytest.mark.asyncio
async def test_config(self):
configs = await get_url(f"{BASE_URL}/configs")
assert 'data' in configs and 'metadata' in configs['data'] \
and 'config' in configs['data']
config = configs['data']['config']
# test post config
await post_url(f"{BASE_URL}/astrbot-configs", config)
# text post config with invalid data
assert 'rate_limit' in config['platform_settings']
config['platform_settings']['rate_limit'] = "invalid"
ret = await post_url(f"{BASE_URL}/astrbot-configs", config)
assert 'status' in ret and ret['status'] == 'error'
@pytest.mark.asyncio
async def test_update(self):
await get_url(f"{BASE_URL}/check_update")
@pytest.mark.asyncio
async def test_plugins(self):
pname = "astrbot_plugin_bilibili"
url = f"https://github.com/Soulter/{pname}"
await get_url(f"{BASE_URL}/extensions")
# test install plugin
await post_url(f"{BASE_URL}/extensions/install", {
"url": url
})
# test uninstall plugin
await post_url(f"{BASE_URL}/extensions/uninstall", {
"name": pname
})
+13 -8
View File
@@ -11,16 +11,11 @@ from model.platform.qq_aiocqhttp import AIOCQHTTP
from model.provider.openai_official import ProviderOpenAIOfficial
from type.astrbot_message import *
from type.message_event import *
from SparkleLogging.utils.core import LogManager
from logging import Formatter
from util.log import LogManager
from util.cmd_config import QQOfficialPlatformConfig, AiocqhttpPlatformConfig
logger = LogManager.GetLogger(
log_name='astrbot',
out_to_console=True,
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
)
logger = LogManager.GetLogger(log_name='astrbot')
pytest_plugins = ('pytest_asyncio',)
os.environ['TEST_MODE'] = 'on'
@@ -135,7 +130,17 @@ class TestInteralCommandHsandle():
abm = self.create("/t2i")
await aiocqhttp.handle_msg(abm)
await self.fast_test("/help")
@pytest.mark.asyncio
async def test_plugin(self):
pname = "astrbot_plugin_bilibili"
url = f"https://github.com/Soulter/{pname}"
await self.fast_test("/plugin")
await self.fast_test(f"/plugin l")
await self.fast_test(f"/plugin i {url}")
await self.fast_test(f"/plugin u {url}")
await self.fast_test(f"/plugin d {pname}")
class TestLLMChat():
@pytest.mark.asyncio
async def test_llm_chat(self):
+28
View File
@@ -0,0 +1,28 @@
from asyncio import Queue
from collections import deque
from typing import Deque
class CachedQueue(Queue):
def __init__(self, maxsize: int = 0, cachesize: int = 200):
super().__init__(maxsize)
self.cache = deque(maxlen=cachesize)
def put_nowait(self, item):
self.cache.append(item)
super().put_nowait(item)
def get_nowait(self):
item = super().get_nowait()
return item
def get(self):
item = super().get()
return item
def clear(self):
self.cache.clear()
with self.mutex:
self._queue.clear()
def get_cache(self) -> Deque:
return self.cache
+5 -1
View File
@@ -1,4 +1,4 @@
VERSION = '3.3.12'
VERSION = '3.3.14'
DEFAULT_CONFIG = {
"qqbot": {
@@ -169,6 +169,8 @@ DEFAULT_CONFIG_VERSION_2 = {
"username": "",
"password": "",
},
"log_level": "INFO",
"t2i_endpoint": "",
}
# 这个是用于迁移旧版本配置文件的映射表
@@ -350,4 +352,6 @@ CONFIG_METADATA_2 = {
"password": {"description": "密码", "type": "string"},
}
},
"log_level": {"description": "控制台日志级别(DEBUG, INFO, WARNING, ERROR)", "type": "string"},
"t2i_endpoint": {"description": "文本转图像服务接口(为空时使用公共服务器)", "type": "string"},
}
+3 -1
View File
@@ -13,7 +13,7 @@ from type.middleware import Middleware
from type.astrbot_message import MessageType
from model.plugin.command import PluginCommandBridge
from model.provider.provider import Provider
from util.agent.func_call import FuncCall
from type.cached_queue import CachedQueue
class Context:
@@ -49,6 +49,8 @@ class Context:
self.command_manager = None
self.running = True
self._log_queue = CachedQueue()
# useless
# self.reply_prefix = ""
+1 -1
View File
@@ -12,7 +12,7 @@ from util.websearch.bing import Bing
from util.websearch.sogo import Sogo
from util.websearch.google import Google
from model.provider.provider import Provider
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
from type.types import Context
from type.message_event import AstrMessageEvent
+4
View File
@@ -133,6 +133,8 @@ class AstrBotConfig():
dashboard: DashboardConfig = field(default_factory=DashboardConfig)
platform: List[PlatformConfig] = field(default_factory=list)
wake_prefix: List[str] = field(default_factory=list)
log_level: str = "INFO"
t2i_endpoint: str = ""
def __init__(self) -> None:
self.init_configs()
@@ -174,6 +176,8 @@ class AstrBotConfig():
self.http_proxy=data.get("http_proxy", "")
self.dashboard=DashboardConfig(**data.get("dashboard", {}))
self.wake_prefix=data.get("wake_prefix", [])
self.log_level=data.get("log_level", "INFO")
self.t2i_endpoint=data.get("t2i_endpoint", "")
def migrate_config_1_2(self, old: dict) -> dict:
'''将配置文件从版本 1 迁移至版本 2'''
+10 -7
View File
@@ -4,10 +4,9 @@ import shutil
import socket
import time
import aiohttp
import requests
from PIL import Image
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -99,16 +98,20 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
except Exception as e:
raise e
def download_file(url: str, path: str):
async def download_file(url: str, path: str):
'''
从指定 url 下载文件到指定路径 path
'''
try:
logger.info(f"下载文件: {url}")
with requests.get(url, stream=True) as r:
with open(path, 'wb') as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
with open(path, 'wb') as f:
while True:
chunk = await resp.content.read(8192)
if not chunk:
break
f.write(chunk)
except Exception as e:
raise e
+51
View File
@@ -0,0 +1,51 @@
import logging, asyncio, colorlog
from type.cached_queue import CachedQueue
log_color_config = {
'DEBUG': 'bold_blue', 'INFO': 'bold_cyan',
'WARNING': 'bold_yellow', 'ERROR': 'red',
'CRITICAL': 'bold_red', 'RESET': 'reset',
'asctime': 'green'
}
class LogQueueHandler(logging.Handler):
def __init__(self, log_queue: CachedQueue):
super().__init__()
self.log_queue = log_queue
def emit(self, record):
log_entry = self.format(record)
try:
self.log_queue.put_nowait(log_entry)
except Exception:
pass
class LogManager:
@classmethod
def GetLogger(cls, log_name: str = 'default'):
logger = logging.getLogger(log_name)
if logger.hasHandlers():
return logger
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
console_formatter = colorlog.ColoredFormatter(
fmt='%(log_color)s [%(asctime)s| %(levelname)s] [%(funcName)s|%(filename)s:%(lineno)d]: %(message)s %(reset)s',
datefmt='%H:%M:%S',
log_colors=log_color_config
)
console_handler.setFormatter(console_formatter)
logger.setLevel(logging.DEBUG)
logger.addHandler(console_handler)
return logger
@classmethod
def set_queue_handler(cls, logger: logging.Logger, log_queue: CachedQueue):
handler = LogQueueHandler(log_queue)
handler.setLevel(logging.DEBUG)
if logger.handlers:
handler.setFormatter(logger.handlers[0].formatter)
else:
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
+4 -3
View File
@@ -1,5 +1,5 @@
import asyncio
import requests
import aiohttp
import json
import sys
@@ -57,8 +57,9 @@ class MetricUploader():
"command_stats": self.command_stats,
"sys": sys.platform, # 系统版本
}
resp = requests.post(
'https://api.soulter.top/upload', data=json.dumps(res), timeout=5)
async with aiohttp.ClientSession() as session:
async with session.post('https://api.soulter.top/upload', data=json.dumps(res), timeout=5) as resp:
pass
if resp.status_code == 200:
ok = resp.json()
if ok['status'] == 'ok':
+9 -3
View File
@@ -1,17 +1,23 @@
from util.t2i.strategies.local_strategy import LocalRenderStrategy
from util.t2i.strategies.network_strategy import NetworkRenderStrategy
from util.t2i.context import RenderContext
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot')
class TextToImageRenderer:
def __init__(self):
self.network_strategy = NetworkRenderStrategy()
def __init__(self, endpoint_url: str = None):
self.network_strategy = NetworkRenderStrategy(endpoint_url)
self.local_strategy = LocalRenderStrategy()
self.context = RenderContext(self.network_strategy)
def set_network_endpoint(self, endpoint_url: str):
'''设置 t2i 的网络端点。
'''
logger.info("文本转图像服务接口: " + endpoint_url)
self.network_strategy.set_endpoint(endpoint_url)
async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool = False):
'''使用自定义文转图模板。该方法会通过网络调用 t2i 终结点图文渲染API。
@param tmpl_str: HTML Jinja2 模板。
+5 -3
View File
@@ -1,5 +1,6 @@
import re
import requests
import aiohttp
from io import BytesIO
from .base_strategy import RenderStrategy
from PIL import ImageFont, Image, ImageDraw
@@ -82,8 +83,9 @@ class LocalRenderStrategy(RenderStrategy):
try:
image_url = re.findall(IMAGE_REGEX, line)[0]
print(image_url)
image_res = Image.open(requests.get(
image_url, stream=True, timeout=5).raw)
async with aiohttp.ClientSession() as session:
async with session.get(image_url) as resp:
image_res = Image.open(BytesIO(await resp.read()))
images[i] = image_res
# 最大不得超过image_width的50%
img_height = image_res.size[1]
+7
View File
@@ -10,9 +10,16 @@ ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img"
class NetworkRenderStrategy(RenderStrategy):
def __init__(self, base_url: str = ASTRBOT_T2I_DEFAULT_ENDPOINT) -> None:
super().__init__()
if not base_url:
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
self.BASE_RENDER_URL = base_url
self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template")
def set_endpoint(self, base_url: str):
if not base_url:
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
self.BASE_RENDER_URL = base_url
async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool=True) -> str:
'''使用自定义文转图模板'''
post_data = {
+9 -12
View File
@@ -1,6 +1,6 @@
import os, psutil, sys, time
from util.updator.zip_updator import ReleaseInfo, RepoZipUpdator
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
from type.config import VERSION
from util.io import on_error, download_file
@@ -31,6 +31,9 @@ class AstrBotUpdator(RepoZipUpdator):
pass
def _reboot(self, delay: int = None, context = None):
if os.environ.get('TEST_MODE', 'off') == 'on':
logger.info("测试模式下不会重启。")
return
# if delay: time.sleep(delay)
py = sys.executable
context.running = False
@@ -43,11 +46,11 @@ class AstrBotUpdator(RepoZipUpdator):
logger.error(f"重启失败({py}, {e}),请尝试手动重启。")
raise e
def check_update(self, url: str, current_version: str) -> ReleaseInfo:
return super().check_update(self.ASTRBOT_RELEASE_API, VERSION)
async def check_update(self, url: str, current_version: str) -> ReleaseInfo:
return await super().check_update(self.ASTRBOT_RELEASE_API, VERSION)
def update(self, reboot = False, latest = True, version = None):
update_data = self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest)
async def update(self, reboot = False, latest = True, version = None):
update_data = await self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest)
file_url = None
if latest:
@@ -65,16 +68,10 @@ class AstrBotUpdator(RepoZipUpdator):
raise Exception(f"未找到版本号为 {version} 的更新文件。")
try:
download_file(file_url, "temp.zip")
await download_file(file_url, "temp.zip")
self.unzip_file("temp.zip", self.MAIN_PATH)
except BaseException as e:
raise e
if reboot:
self._reboot()
def unzip_file(self, zip_path: str, target_dir: str):
'''
解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir
'''
pass
+3 -3
View File
@@ -4,7 +4,7 @@ from util.updator.zip_updator import RepoZipUpdator
from util.io import remove_dir
from type.register import RegisteredPlugin
from typing import Union
from SparkleLogging.utils.core import LogManager
from util.log import LogManager
from logging import Logger
from util.io import on_error
@@ -18,7 +18,7 @@ class PluginUpdator(RepoZipUpdator):
def get_plugin_store_path(self) -> str:
return self.plugin_store_path
def update(self, plugin: Union[RegisteredPlugin, str]) -> str:
async def update(self, plugin: Union[RegisteredPlugin, str]) -> str:
repo_url = None
if not isinstance(plugin, str):
@@ -33,7 +33,7 @@ class PluginUpdator(RepoZipUpdator):
plugin_path = os.path.join(self.plugin_store_path, self.format_repo_name(repo_url))
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
self.download_from_repo_url(plugin_path, repo_url)
await self.download_from_repo_url(plugin_path, repo_url)
try:
remove_dir(plugin_path)
+13 -12
View File
@@ -1,5 +1,5 @@
import requests, os, zipfile, shutil
from SparkleLogging.utils.core import LogManager
import aiohttp, os, zipfile, shutil
from util.log import LogManager
from logging import Logger
from util.io import on_error, download_file
@@ -23,14 +23,15 @@ class RepoZipUpdator():
self.path = path
self.rm_on_error = on_error
def fetch_release_info(self, url: str, latest: bool = True) -> list:
async def fetch_release_info(self, url: str, latest: bool = True) -> list:
'''
请求版本信息。
返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。
'''
result = requests.get(url).json()
try:
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
result = await response.json()
if not result: return []
if latest:
ret = self.github_api_release_parser([result[0]])
@@ -66,7 +67,7 @@ class RepoZipUpdator():
def unzip(self):
raise NotImplementedError()
def update(self):
async def update(self):
raise NotImplementedError()
def compare_version(self, v1: str, v2: str) -> int:
@@ -86,8 +87,8 @@ class RepoZipUpdator():
return -1
return 0
def check_update(self, url: str, current_version: str) -> ReleaseInfo:
update_data = self.fetch_release_info(url)
async def check_update(self, url: str, current_version: str) -> ReleaseInfo:
update_data = await self.fetch_release_info(url)
tag_name = update_data[0]['tag_name']
if self.compare_version(current_version, tag_name) >= 0:
@@ -98,22 +99,22 @@ class RepoZipUpdator():
body=update_data[0]['body']
)
def download_from_repo_url(self, target_path: str, repo_url: str):
async def download_from_repo_url(self, target_path: str, repo_url: str):
repo_namespace = repo_url.split("/")[-2:]
author = repo_namespace[0]
repo = repo_namespace[1]
logger.info(f"正在下载更新 {repo} ...")
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
releases = self.fetch_release_info(url=release_url)
releases = await self.fetch_release_info(url=release_url)
if not releases:
# download from the default branch directly.
logger.warn(f"未在仓库 {author}/{repo} 中找到任何发布版本,将从默认分支下载。")
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,将从默认分支下载。")
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
else:
release_url = releases[0]['zipball_url']
download_file(release_url, target_path + ".zip")
await download_file(release_url, target_path + ".zip")
def unzip_file(self, zip_path: str, target_dir: str):