From 8dc8c5b5dc56dbb38f266358e3ffbef4ae7e6756 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 8 Jan 2025 22:28:20 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=AF=B9=E6=8F=92?= =?UTF-8?q?=E4=BB=B6=E7=A6=81=E7=94=A8/=E5=90=AF=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/func_tool_manager.py | 1 + astrbot/core/provider/manager.py | 3 +- astrbot/core/star/context.py | 14 +++++ astrbot/core/star/star.py | 3 + astrbot/core/star/star_handler.py | 14 ++++- astrbot/core/star/star_manager.py | 68 +++++++++++++++++++--- astrbot/core/utils/io.py | 2 +- packages/astrbot/main.py | 35 ++++++++--- 8 files changed, 120 insertions(+), 20 deletions(-) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 07c672daf..f1d2f7b28 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -14,6 +14,7 @@ class FuncTool: parameters: Dict description: str handler: Awaitable + handler_module_path: str = None # 必须要保留这个,handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools active: bool = True '''是否激活''' diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index d72b4b5aa..e020148d6 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -65,7 +65,8 @@ class ProviderManager(): if len(self.provider_insts) > 0 and not self.curr_provider_inst: self.curr_provider_inst = self.provider_insts[0] - else: + + if not self.curr_provider_inst: logger.warning("未启用任何提供商适配器。") def get_insts(self): diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 39ed5baf6..58810de96 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -1,6 +1,7 @@ from asyncio import Queue from typing import List, TypedDict, Union +from astrbot.core import sp from astrbot.core.provider.provider import Provider from astrbot.core.db import BaseDatabase from astrbot.core.config.astrbot_config import AstrBotConfig @@ -39,6 +40,7 @@ class Context: # back compatibility _register_tasks: List[Awaitable] = [] + _star_manager = None def __init__(self, event_queue: Queue, @@ -105,6 +107,12 @@ class Context: func_tool = self.provider_manager.llm_tools.get_func(name) if func_tool is not None: func_tool.active = True + + inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) + if name in inactivated_llm_tools: + inactivated_llm_tools.remove(name) + sp.put("inactivated_llm_tools", inactivated_llm_tools) + return True return False @@ -116,6 +124,12 @@ class Context: func_tool = self.provider_manager.llm_tools.get_func(name) if func_tool is not None: func_tool.active = False + + inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) + if name not in inactivated_llm_tools: + inactivated_llm_tools.append(name) + sp.put("inactivated_llm_tools", inactivated_llm_tools) + return True return False diff --git a/astrbot/core/star/star.py b/astrbot/core/star/star.py index cee93c5e6..960504f29 100644 --- a/astrbot/core/star/star.py +++ b/astrbot/core/star/star.py @@ -32,6 +32,9 @@ class StarMetadata: '''Star 的根目录名''' reserved: bool = False '''是否是 AstrBot 的保留 Star''' + + activated: bool = True + '''是否被激活''' def __str__(self) -> str: return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})" \ No newline at end of file diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index 9acfa56e0..e0e13db40 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -3,6 +3,7 @@ import enum from dataclasses import dataclass from typing import Awaitable, List, Dict, TypeVar, Generic from .filter import HandlerFilter +from .star import star_map T = TypeVar('T', bound='StarHandlerMetadata') class StarHandlerRegistry(Generic[T], List[T]): @@ -16,9 +17,18 @@ class StarHandlerRegistry(Generic[T], List[T]): super().append(handler) self.star_handlers_map[handler.handler_full_name] = handler - def get_handlers_by_event_type(self, event_type: EventType) -> List[StarHandlerMetadata]: + def get_handlers_by_event_type(self, event_type: EventType, only_activated = True) -> List[StarHandlerMetadata]: '''通过事件类型获取 Handler''' - return [handler for handler in self if handler.event_type == event_type] + if only_activated: + return [ + handler + for handler in self + if handler.event_type == event_type and + star_map[handler.handler_module_path] and + star_map[handler.handler_module_path].activated + ] + else: + return [handler for handler in self if handler.event_type == event_type] def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata: '''通过 Handler 的全名获取 Handler''' diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 1a49f9826..264a6f5c8 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -9,7 +9,7 @@ from types import ModuleType from typing import List from pip import main as pip_main from astrbot.core.config.astrbot_config import AstrBotConfig -from astrbot.core import logger +from astrbot.core import logger, sp from .context import Context from . import StarMetadata from .updator import PluginUpdator @@ -27,6 +27,7 @@ class PluginManager: self.updator = PluginUpdator(config['plugin_repo_mirror']) self.context = context + self.context._star_manager = self # 就这样吧,不想改了 self.config = config self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins")) @@ -156,6 +157,9 @@ class PluginManager: return False, "未找到任何插件模块" fail_rec = "" + inactivated_plugins: list = sp.get("inactivated_plugins", []) + inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) + # 导入 Star 模块,并尝试实例化 Star 类 for plugin_module in plugin_modules: try: @@ -196,7 +200,10 @@ class PluginManager: # llm_tool for func_tool in llm_tools.func_list: if func_tool.handler.__module__ == star_metadata.module_path: + func_tool.handler_module_path = star_metadata.module_path func_tool.handler = functools.partial(func_tool.handler, star_metadata.star_cls) + if func_tool.name in inactivated_llm_tools: + func_tool.active = False else: # v3.4.0 以前的方式注册插件 @@ -221,6 +228,9 @@ class PluginManager: star_registry.append(metadata) logger.debug(f"插件 {root_dir_name} 载入成功。") + if metadata.module_path in inactivated_plugins: + metadata.activated = False + except BaseException as e: traceback.print_exc() fail_rec += f"加载 {path} 插件时出现问题,原因 {str(e)}\n" @@ -250,22 +260,25 @@ class PluginManager: ppath = self.plugin_store_path # 从 star_registry 和 star_map 中删除 - del star_map[plugin.module_path] + await self._unbind_plugin(plugin_name, plugin.module_path) + + if not remove_dir(os.path.join(ppath, root_dir_name)): + raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。") + + async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str): + del star_map[plugin_module_path] for i, p in enumerate(star_registry): if p.name == plugin_name: del star_registry[i] break - for handler in star_handlers_registry.get_handlers_by_module_name(plugin.module_path): + for handler in star_handlers_registry.get_handlers_by_module_name(plugin_module_path): logger.debug(f"unbind handler {handler.handler_name} from {plugin_name}") star_handlers_registry.remove(handler) - keys_to_delete = [k for k, v in star_handlers_registry.star_handlers_map.items() if k.startswith(plugin.module_path)] + keys_to_delete = [k for k, v in star_handlers_registry.star_handlers_map.items() if k.startswith(plugin_module_path)] for k in keys_to_delete: v = star_handlers_registry.star_handlers_map[k] logger.debug(f"unbind handler {v.handler_name} from {plugin_name} (map)") del star_handlers_registry.star_handlers_map[k] - - if not remove_dir(os.path.join(ppath, root_dir_name)): - raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。") async def update_plugin(self, plugin_name: str): plugin = self.context.get_registered_star(plugin_name) @@ -276,6 +289,47 @@ class PluginManager: await self.updator.update(plugin) self.reload() + + async def turn_off_plugin(self, plugin_name: str): + plugin = self.context.get_registered_star(plugin_name) + if not plugin: + raise Exception("插件不存在。") + inactivated_plugins: list = sp.get("inactivated_plugins", []) + if plugin.module_path not in inactivated_plugins: + inactivated_plugins.append(plugin.module_path) + + inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) + + # 禁用插件启用的 llm_tool + for func_tool in llm_tools.func_list: + if func_tool.handler_module_path == plugin.module_path: + func_tool.active = False + inactivated_llm_tools.append(func_tool.name) + + sp.put("inactivated_plugins", inactivated_plugins) + sp.put("inactivated_llm_tools", inactivated_llm_tools) + + plugin.activated = False + + async def turn_on_plugin(self, plugin_name: str): + plugin = self.context.get_registered_star(plugin_name) + if not plugin: + raise Exception("插件已经启用,无需重新启用。") + inactivated_plugins: list = sp.get("inactivated_plugins", []) + inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) + if plugin.module_path in inactivated_plugins: + inactivated_plugins.remove(plugin.module_path) + sp.put("inactivated_plugins", inactivated_plugins) + + # 启用插件启用的 llm_tool + for func_tool in llm_tools.func_list: + if func_tool.handler_module_path == plugin.module_path: + inactivated_llm_tools.remove(func_tool.name) + func_tool.active = True + sp.put("inactivated_llm_tools", inactivated_llm_tools) + + plugin.activated = True + def install_plugin_from_file(self, zip_file_path: str): desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path)) diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index 3c0a6a926..bbc3341bb 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -96,7 +96,7 @@ async def download_file(url: str, path: str): ''' try: async with aiohttp.ClientSession() as session: - async with session.get(url) as resp: + async with session.get(url, timeout=20) as resp: with open(path, 'wb') as f: while True: chunk = await resp.content.read(8192) diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index c81fa8f88..fa5161b7c 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -89,24 +89,41 @@ class Main(star.Star): event.set_result(MessageEventResult().message(f"停用工具 {tool_name} 失败,未找到此工具。")) @filter.command("plugin") - async def plugin(self, event: AstrMessageEvent, oper: str = None): - if oper is None: + async def plugin(self, event: AstrMessageEvent, oper1: str = None, oper2: str = None): + if oper1 is None: plugin_list_info = "已加载的插件:\n" for plugin in self.context.get_all_stars(): plugin_list_info += f"- `{plugin.name}` By {plugin.author}: {plugin.desc}\n" if plugin_list_info.strip() == "": plugin_list_info = "没有加载任何插件。" - plugin_list_info += "\n使用 /plugin <插件名> 查看插件帮助。" + plugin_list_info += "\n使用 /plugin <插件名> 查看插件帮助。\n使用 /plugin on/off <插件名> 启用或者禁用插件。" event.set_result(MessageEventResult().message(f"{plugin_list_info}").use_t2i(False)) else: - plugin = self.context.get_registered_star(oper) - if plugin is None: - event.set_result(MessageEventResult().message("未找到此插件。")) + if oper1 == "off": + # 禁用插件 + if oper2 is None: + event.set_result(MessageEventResult().message("/plugin off <插件名> 禁用插件。")) + return + await self.context._star_manager.turn_off_plugin(oper2) + event.set_result(MessageEventResult().message(f"插件 {oper2} 已禁用。")) + elif oper1 == "on": + # 启用插件 + if oper2 is None: + event.set_result(MessageEventResult().message("/plugin on <插件名> 启用插件。")) + return + await self.context._star_manager.turn_on_plugin(oper2) + event.set_result(MessageEventResult().message(f"插件 {oper2} 已启用。")) + else: - help_msg = plugin.star_cls.__doc__ if plugin.star_cls.__doc__ else "该插件未提供帮助信息" - ret = f"插件 {oper} 帮助信息:\n" + help_msg - event.set_result(MessageEventResult().message(ret).use_t2i(False)) + # 获取插件帮助 + plugin = self.context.get_registered_star(oper1) + if plugin is None: + event.set_result(MessageEventResult().message("未找到此插件。")) + else: + help_msg = plugin.star_cls.__doc__ if plugin.star_cls.__doc__ else "该插件未提供帮助信息" + ret = f"插件 {oper1} 帮助信息:\n" + help_msg + event.set_result(MessageEventResult().message(ret).use_t2i(False)) @filter.command("t2i") async def t2i(self, event: AstrMessageEvent):