diff --git a/astrbot/message/handler.py b/astrbot/message/handler.py index 5e862ce1d..473ab4349 100644 --- a/astrbot/message/handler.py +++ b/astrbot/message/handler.py @@ -60,6 +60,9 @@ class ContentSafetyHelper(): from astrbot.message.baidu_aip_judge import BaiduJudge self.baidu_judge = BaiduJudge(aip) logger.info("已启用百度 AI 内容审核。") + except ImportError as e: + logger.error("检测到库依赖不完整,将不会启用百度 AI 内容审核。请先使用 pip 安装 `baidu_aip` 包。") + logger.error(e) except BaseException as e: logger.error("百度 AI 内容审核初始化失败。") logger.error(e) diff --git a/dashboard/server.py b/dashboard/server.py index 33b852c89..4736c5edb 100644 --- a/dashboard/server.py +++ b/dashboard/server.py @@ -206,7 +206,8 @@ class AstrBotDashBoard(): repo_url = post_data["url"] try: logger.info(f"正在安装插件 {repo_url}") - self.plugin_manager.install_plugin(repo_url) + # self.plugin_manager.install_plugin(repo_url) + asyncio.run_coroutine_threadsafe(self.plugin_manager.install_plugin(repo_url), self.loop).result() threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start() logger.info(f"安装插件 {repo_url} 成功,2秒后重启") return Response( @@ -272,7 +273,8 @@ class AstrBotDashBoard(): plugin_name = post_data["name"] try: logger.info(f"正在更新插件 {plugin_name}") - self.plugin_manager.update_plugin(plugin_name) + # self.plugin_manager.update_plugin(plugin_name) + asyncio.run_coroutine_threadsafe(self.plugin_manager.update_plugin(plugin_name), self.loop).result() threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start() logger.info(f"更新插件 {plugin_name} 成功,2秒后重启") return Response( @@ -301,7 +303,9 @@ class AstrBotDashBoard(): @self.dashboard_be.get("/api/check_update") def get_update_info(): try: - ret = self.astrbot_updator.check_update(None, None) + # ret = self.astrbot_updator.check_update(None, None) + ret = asyncio.run_coroutine_threadsafe( + self.astrbot_updator.check_update(None, None), self.loop).result() return Response( status="success", message=str(ret) if ret is not None else "已经是最新版本了。", @@ -326,7 +330,8 @@ class AstrBotDashBoard(): else: latest = False try: - self.astrbot_updator.update(latest=latest, version=version) + # await self.astrbot_updator.update(latest=latest, version=version) + asyncio.run_coroutine_threadsafe(self.astrbot_updator.update(latest=latest, version=version), self.loop).result() threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start() return Response( status="success", diff --git a/main.py b/main.py index 7f408dec3..c334a9704 100644 --- a/main.py +++ b/main.py @@ -27,6 +27,8 @@ def main(): # delete qqbotpy's logger for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) + + logger.info(logo_tmpl) bootstrap = AstrBotBootstrap() asyncio.run(bootstrap.run()) @@ -58,5 +60,4 @@ if __name__ == "__main__": out_to_console=True, custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S") ) - logger.info(logo_tmpl) main() diff --git a/model/command/internal_handler.py b/model/command/internal_handler.py index e05feccbe..99f4834d8 100644 --- a/model/command/internal_handler.py +++ b/model/command/internal_handler.py @@ -1,4 +1,4 @@ -import aiohttp +import aiohttp, os from model.command.manager import CommandManager from model.plugin.manager import PluginManager @@ -27,6 +27,13 @@ class InternalCommandHandler: self.manager.register("t2i", "文转图", 10, self.t2i_toggle) self.manager.register("myid", "用户ID", 10, self.myid) self.manager.register("provider", "LLM 接入源", 10, self.provider) + + def _check_auth(self, message: AstrMessageEvent, context: Context): + if os.environ.get("TEST_MODE", "off") == "on": + return + if message.role != "admin": + user_id = message.message_obj.sender.user_id + raise Exception(f"用户(ID: {user_id}) 没有足够的权限使用该指令。") def provider(self, message: AstrMessageEvent, context: Context): if len(context.llms) == 0: @@ -57,9 +64,8 @@ class InternalCommandHandler: return CommandResult().message("provider: 参数错误。") def set_nick(self, message: AstrMessageEvent, context: Context): + self._check_auth(message, context) message_str = message.message_str - if message.role != "admin": - return CommandResult().message("你没有权限使用该指令。") l = message_str.split(" ") if len(l) == 1: return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词是:{context.config_helper.wake_prefix[0]}") @@ -74,15 +80,10 @@ class InternalCommandHandler: message_chain=f"已经成功将唤醒前缀设定为 {nick}。", ) - def update(self, message: AstrMessageEvent, context: Context): + async def update(self, message: AstrMessageEvent, context: Context): + self._check_auth(message, context) tokens = self.manager.command_parser.parse(message.message_str) - if message.role != "admin": - return CommandResult( - hit=True, - success=False, - message_chain="你没有权限使用该指令", - ) - update_info = context.updator.check_update(None, None) + update_info = await context.updator.check_update(None, None) if tokens.len == 1: ret = "" if not update_info: @@ -93,13 +94,13 @@ class InternalCommandHandler: else: if tokens.get(1) == "latest": try: - context.updator.update() + await context.updator.update() return CommandResult().message(f"已经成功更新到最新版本 v{update_info.version}。要应用更新,请重启 AstrBot。输入 /reboot 即可重启") except BaseException as e: return CommandResult().message(f"更新失败。原因:{str(e)}") elif tokens.get(1).startswith("v"): try: - context.updator.update(version=tokens.get(1)) + await context.updator.update(version=tokens.get(1)) return CommandResult().message(f"已经成功更新到版本 v{tokens.get(1)}。要应用更新,请重启 AstrBot。输入 /reboot 即可重启") except BaseException as e: return CommandResult().message(f"更新失败。原因:{str(e)}") @@ -107,12 +108,7 @@ class InternalCommandHandler: return CommandResult().message("update: 参数错误。") def reboot(self, message: AstrMessageEvent, context: Context): - if message.role != "admin": - return CommandResult( - hit=True, - success=False, - message_chain="你没有权限使用该指令", - ) + self._check_auth(message, context) context.updator._reboot(3, context) return CommandResult( hit=True, @@ -120,7 +116,7 @@ class InternalCommandHandler: message_chain="AstrBot 将在 3s 后重启。", ) - def plugin(self, message: AstrMessageEvent, context: Context): + async def plugin(self, message: AstrMessageEvent, context: Context): tokens = self.manager.command_parser.parse(message.message_str) if tokens.len == 1: ret = "# 插件指令面板 \n- 安装插件: `plugin i 插件Github地址`\n- 卸载插件: `plugin d 插件名`\n- 查看插件列表:`plugin l`\n - 更新插件: `plugin u 插件名`\n" @@ -133,10 +129,10 @@ class InternalCommandHandler: if plugin_list_info.strip() == "": return CommandResult().message("plugin v: 没有找到插件。") return CommandResult().message(plugin_list_info) + + self._check_auth(message, context) - elif tokens.get(1) == "d": - if message.role != "admin": - return CommandResult().message("plugin d: 你没有权限使用该指令。") + if tokens.get(1) == "d": if tokens.len == 2: return CommandResult().message("plugin d: 请指定要卸载的插件名。") plugin_name = tokens.get(2) @@ -147,25 +143,21 @@ class InternalCommandHandler: return CommandResult().message(f"plugin d: 已经成功卸载插件 {plugin_name}。") elif tokens.get(1) == "i": - if message.role != "admin": - return CommandResult().message("plugin i: 你没有权限使用该指令。") if tokens.len == 2: return CommandResult().message("plugin i: 请指定要安装的插件的 Github 地址,或者前往可视化面板安装。") plugin_url = tokens.get(2) try: - self.plugin_manager.install_plugin(plugin_url) + await self.plugin_manager.install_plugin(plugin_url) except BaseException as e: return CommandResult().message(f"plugin i: 安装插件失败。原因:{str(e)}") return CommandResult().message("plugin i: 已经成功安装插件。") elif tokens.get(1) == "u": - if message.role != "admin": - return CommandResult().message("plugin u: 你没有权限使用该指令。") if tokens.len == 2: return CommandResult().message("plugin u: 请指定要更新的插件名。") plugin_name = tokens.get(2) try: - self.plugin_manager.update_plugin(plugin_name) + await context.updator.update(plugin_name) except BaseException as e: return CommandResult().message(f"plugin u: 更新插件失败。原因:{str(e)}") return CommandResult().message(f"plugin u: 已经成功更新插件 {plugin_name}。") diff --git a/model/plugin/manager.py b/model/plugin/manager.py index fb5275b5d..952ebbdc2 100644 --- a/model/plugin/manager.py +++ b/model/plugin/manager.py @@ -107,13 +107,13 @@ class PluginManager(): rc = process.poll() - def install_plugin(self, repo_url: str): + async def install_plugin(self, repo_url: str): ppath = self.plugin_store_path # we no longer use Git anymore :) # Repo.clone_from(repo_url, to_path=plugin_path, branch='master') - plugin_path = self.updator.update(repo_url) + plugin_path = await self.updator.update(repo_url) with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f: f.write(repo_url) @@ -124,14 +124,14 @@ class PluginManager(): # if not ok: # raise Exception(err) - def download_from_repo_url(self, target_path: str, repo_url: str): + async def download_from_repo_url(self, target_path: str, repo_url: str): repo_namespace = repo_url.split("/")[-2:] author = repo_namespace[0] repo = repo_namespace[1] logger.info(f"正在下载插件 {repo} ...") release_url = f"https://api.github.com/repos/{author}/{repo}/releases" - releases = self.updator.fetch_release_info(url=release_url) + releases = await self.updator.fetch_release_info(url=release_url) if not releases: # download from the default branch directly. logger.warn(f"未在插件 {author}/{repo} 中找到任何发布版本,将从默认分支下载。") @@ -139,7 +139,7 @@ class PluginManager(): else: release_url = releases[0]['zipball_url'] - download_file(release_url, target_path + ".zip") + await download_file(release_url, target_path + ".zip") def get_registered_plugin(self, plugin_name: str) -> RegisteredPlugin: for p in self.context.cached_plugins: @@ -156,12 +156,12 @@ class PluginManager(): if not remove_dir(os.path.join(ppath, root_dir_name)): raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。") - def update_plugin(self, plugin_name: str): + async def update_plugin(self, plugin_name: str): plugin = self.get_registered_plugin(plugin_name) if not plugin: raise Exception("插件不存在。") - self.updator.update(plugin) + await self.updator.update(plugin) def plugin_reload(self): cached_plugins = self.context.cached_plugins diff --git a/requirements.txt b/requirements.txt index 270244f6a..79529299d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ pydantic~=1.10.4 aiohttp -requests openai qq-botpy chardet~=5.1.0 @@ -10,7 +9,6 @@ beautifulsoup4 googlesearch-python tiktoken readability-lxml -baidu-aip websockets flask psutil diff --git a/tests/test_message.py b/tests/test_message.py index 7015785e1..cc7764d4f 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -135,7 +135,17 @@ class TestInteralCommandHsandle(): abm = self.create("/t2i") await aiocqhttp.handle_msg(abm) await self.fast_test("/help") - + + @pytest.mark.asyncio + async def test_plugin(self): + pname = "astrbot_plugin_bilibili" + url = f"https://github.com/Soulter/{pname}" + await self.fast_test("/plugin") + await self.fast_test(f"/plugin l") + await self.fast_test(f"/plugin i {url}") + await self.fast_test(f"/plugin u {url}") + await self.fast_test(f"/plugin d {pname}") + class TestLLMChat(): @pytest.mark.asyncio async def test_llm_chat(self): diff --git a/util/io.py b/util/io.py index 6125d620c..d2ad3312c 100644 --- a/util/io.py +++ b/util/io.py @@ -4,7 +4,6 @@ import shutil import socket import time import aiohttp -import requests from PIL import Image from SparkleLogging.utils.core import LogManager @@ -99,16 +98,20 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict = except Exception as e: raise e -def download_file(url: str, path: str): +async def download_file(url: str, path: str): ''' 从指定 url 下载文件到指定路径 path ''' try: logger.info(f"下载文件: {url}") - with requests.get(url, stream=True) as r: - with open(path, 'wb') as f: - for chunk in r.iter_content(chunk_size=8192): - f.write(chunk) + async with aiohttp.ClientSession() as session: + async with session.get(url) as resp: + with open(path, 'wb') as f: + while True: + chunk = await resp.content.read(8192) + if not chunk: + break + f.write(chunk) except Exception as e: raise e diff --git a/util/metrics.py b/util/metrics.py index 84472ba45..d5f98be2a 100644 --- a/util/metrics.py +++ b/util/metrics.py @@ -1,5 +1,5 @@ import asyncio -import requests +import aiohttp import json import sys @@ -57,8 +57,9 @@ class MetricUploader(): "command_stats": self.command_stats, "sys": sys.platform, # 系统版本 } - resp = requests.post( - 'https://api.soulter.top/upload', data=json.dumps(res), timeout=5) + async with aiohttp.ClientSession() as session: + async with session.post('https://api.soulter.top/upload', data=json.dumps(res), timeout=5) as resp: + pass if resp.status_code == 200: ok = resp.json() if ok['status'] == 'ok': diff --git a/util/t2i/strategies/local_strategy.py b/util/t2i/strategies/local_strategy.py index 061ddc726..764ab9873 100644 --- a/util/t2i/strategies/local_strategy.py +++ b/util/t2i/strategies/local_strategy.py @@ -1,5 +1,6 @@ import re -import requests +import aiohttp +from io import BytesIO from .base_strategy import RenderStrategy from PIL import ImageFont, Image, ImageDraw @@ -82,8 +83,9 @@ class LocalRenderStrategy(RenderStrategy): try: image_url = re.findall(IMAGE_REGEX, line)[0] print(image_url) - image_res = Image.open(requests.get( - image_url, stream=True, timeout=5).raw) + async with aiohttp.ClientSession() as session: + async with session.get(image_url) as resp: + image_res = Image.open(BytesIO(await resp.read())) images[i] = image_res # 最大不得超过image_width的50% img_height = image_res.size[1] diff --git a/util/updator/astrbot_updator.py b/util/updator/astrbot_updator.py index 7db2705a9..a0245f760 100644 --- a/util/updator/astrbot_updator.py +++ b/util/updator/astrbot_updator.py @@ -43,11 +43,11 @@ class AstrBotUpdator(RepoZipUpdator): logger.error(f"重启失败({py}, {e}),请尝试手动重启。") raise e - def check_update(self, url: str, current_version: str) -> ReleaseInfo: - return super().check_update(self.ASTRBOT_RELEASE_API, VERSION) + async def check_update(self, url: str, current_version: str) -> ReleaseInfo: + return await super().check_update(self.ASTRBOT_RELEASE_API, VERSION) - def update(self, reboot = False, latest = True, version = None): - update_data = self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest) + async def update(self, reboot = False, latest = True, version = None): + update_data = await self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest) file_url = None if latest: @@ -65,7 +65,7 @@ class AstrBotUpdator(RepoZipUpdator): raise Exception(f"未找到版本号为 {version} 的更新文件。") try: - download_file(file_url, "temp.zip") + await download_file(file_url, "temp.zip") self.unzip_file("temp.zip", self.MAIN_PATH) except BaseException as e: raise e diff --git a/util/updator/plugin_updator.py b/util/updator/plugin_updator.py index 97b14550b..ac9a5d17a 100644 --- a/util/updator/plugin_updator.py +++ b/util/updator/plugin_updator.py @@ -18,7 +18,7 @@ class PluginUpdator(RepoZipUpdator): def get_plugin_store_path(self) -> str: return self.plugin_store_path - def update(self, plugin: Union[RegisteredPlugin, str]) -> str: + async def update(self, plugin: Union[RegisteredPlugin, str]) -> str: repo_url = None if not isinstance(plugin, str): @@ -33,7 +33,7 @@ class PluginUpdator(RepoZipUpdator): plugin_path = os.path.join(self.plugin_store_path, self.format_repo_name(repo_url)) logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}") - self.download_from_repo_url(plugin_path, repo_url) + await self.download_from_repo_url(plugin_path, repo_url) try: remove_dir(plugin_path) diff --git a/util/updator/zip_updator.py b/util/updator/zip_updator.py index 102c9ee5a..b0e2889d7 100644 --- a/util/updator/zip_updator.py +++ b/util/updator/zip_updator.py @@ -1,4 +1,4 @@ -import requests, os, zipfile, shutil +import aiohttp, os, zipfile, shutil from SparkleLogging.utils.core import LogManager from logging import Logger from util.io import on_error, download_file @@ -23,14 +23,15 @@ class RepoZipUpdator(): self.path = path self.rm_on_error = on_error - def fetch_release_info(self, url: str, latest: bool = True) -> list: + async def fetch_release_info(self, url: str, latest: bool = True) -> list: ''' 请求版本信息。 返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。 ''' - result = requests.get(url).json() - try: + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + result = await response.json() if not result: return [] if latest: ret = self.github_api_release_parser([result[0]]) @@ -66,7 +67,7 @@ class RepoZipUpdator(): def unzip(self): raise NotImplementedError() - def update(self): + async def update(self): raise NotImplementedError() def compare_version(self, v1: str, v2: str) -> int: @@ -86,8 +87,8 @@ class RepoZipUpdator(): return -1 return 0 - def check_update(self, url: str, current_version: str) -> ReleaseInfo: - update_data = self.fetch_release_info(url) + async def check_update(self, url: str, current_version: str) -> ReleaseInfo: + update_data = await self.fetch_release_info(url) tag_name = update_data[0]['tag_name'] if self.compare_version(current_version, tag_name) >= 0: @@ -98,14 +99,14 @@ class RepoZipUpdator(): body=update_data[0]['body'] ) - def download_from_repo_url(self, target_path: str, repo_url: str): + async def download_from_repo_url(self, target_path: str, repo_url: str): repo_namespace = repo_url.split("/")[-2:] author = repo_namespace[0] repo = repo_namespace[1] logger.info(f"正在下载更新 {repo} ...") release_url = f"https://api.github.com/repos/{author}/{repo}/releases" - releases = self.fetch_release_info(url=release_url) + releases = await self.fetch_release_info(url=release_url) if not releases: # download from the default branch directly. logger.warn(f"未在仓库 {author}/{repo} 中找到任何发布版本,将从默认分支下载。") @@ -113,7 +114,7 @@ class RepoZipUpdator(): else: release_url = releases[0]['zipball_url'] - download_file(release_url, target_path + ".zip") + await download_file(release_url, target_path + ".zip") def unzip_file(self, zip_path: str, target_dir: str):