Files
AstrBot/astrbot/core/star/star_manager.py
T

605 lines
25 KiB
Python

"""
插件的重载、启停、安装、卸载等操作。
"""
import inspect
import functools
import os
import sys
import json
import traceback
import yaml
import logging
import asyncio
from types import ModuleType
from typing import List
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core import logger, sp, pip_installer
from .context import Context
from . import StarMetadata
from .updator import PluginUpdator
from astrbot.core.utils.io import remove_dir
from .star import star_registry, star_map
from .star_handler import star_handlers_registry
from astrbot.core.provider.register import llm_tools
from .filter.permission import PermissionTypeFilter, PermissionType
class PluginManager:
def __init__(self, context: Context, config: AstrBotConfig):
self.updator = PluginUpdator(config["plugin_repo_mirror"])
self.context = context
self.context._star_manager = self
self.config = config
self.plugin_store_path = os.path.abspath(
os.path.join(
os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"
)
)
"""存储插件的路径。即 data/plugins"""
self.plugin_config_path = os.path.abspath(
os.path.join(
os.path.dirname(os.path.abspath(__file__)), "../../../data/config"
)
)
"""存储插件配置的路径。data/config"""
self.reserved_plugin_path = os.path.abspath(
os.path.join(
os.path.dirname(os.path.abspath(__file__)), "../../../packages"
)
)
"""保留插件的路径。在 packages 目录下"""
self.conf_schema_fname = "_conf_schema.json"
"""插件配置 Schema 文件名"""
self.failed_plugin_info = ""
def _get_classes(self, arg: ModuleType):
"""获取指定模块(可以理解为一个 python 文件)下所有的类"""
classes = []
clsmembers = inspect.getmembers(arg, inspect.isclass)
for name, _ in clsmembers:
if name.lower().endswith("plugin") or name.lower() == "main":
classes.append(name)
break
return classes
def _get_modules(self, path):
modules = []
dirs = os.listdir(path)
# 遍历文件夹,找到 main.py 或者和文件夹同名的文件
for d in dirs:
if os.path.isdir(os.path.join(path, d)):
if os.path.exists(os.path.join(path, d, "main.py")):
module_str = "main"
elif os.path.exists(os.path.join(path, d, d + ".py")):
module_str = d
else:
logger.info(f"插件 {d} 未找到 main.py 或者 {d}.py,跳过。")
continue
if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists(
os.path.join(path, d, d + ".py")
):
modules.append(
{
"pname": d,
"module": module_str,
"module_path": os.path.join(path, d, module_str),
}
)
return modules
def _get_plugin_modules(self) -> List[dict]:
plugins = []
if os.path.exists(self.plugin_store_path):
plugins.extend(self._get_modules(self.plugin_store_path))
if os.path.exists(self.reserved_plugin_path):
_p = self._get_modules(self.reserved_plugin_path)
for p in _p:
p["reserved"] = True
plugins.extend(_p)
return plugins
def _check_plugin_dept_update(self, target_plugin: str = None):
"""检查插件的依赖
如果 target_plugin 为 None,则检查所有插件的依赖
"""
plugin_dir = self.plugin_store_path
if not os.path.exists(plugin_dir):
return False
to_update = []
if target_plugin:
to_update.append(target_plugin)
else:
for p in self.context.get_all_stars():
to_update.append(p.root_dir_name)
for p in to_update:
plugin_path = os.path.join(plugin_dir, p)
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
pth = os.path.join(plugin_path, "requirements.txt")
logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}")
try:
pip_installer.install(requirements_path=pth)
except Exception as e:
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
def _load_plugin_metadata(self, plugin_path: str, plugin_obj=None) -> StarMetadata:
"""v3.4.0 以前的方式载入插件元数据
先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。
"""
metadata = None
if not os.path.exists(plugin_path):
raise Exception("插件不存在。")
if os.path.exists(os.path.join(plugin_path, "metadata.yaml")):
with open(
os.path.join(plugin_path, "metadata.yaml"), "r", encoding="utf-8"
) as f:
metadata = yaml.safe_load(f)
elif plugin_obj:
# 使用 info() 函数
metadata = plugin_obj.info()
if isinstance(metadata, dict):
if (
"name" not in metadata
or "desc" not in metadata
or "version" not in metadata
or "author" not in metadata
):
raise Exception(
"插件元数据信息不完整。name, desc, version, author 是必须的字段。"
)
metadata = StarMetadata(
name=metadata["name"],
author=metadata["author"],
desc=metadata["desc"],
version=metadata["version"],
repo=metadata["repo"] if "repo" in metadata else None,
)
return metadata
async def reload(self, specified_plugin_name=None):
"""扫描并加载所有的插件 当 specified_module_path 指定时,重载指定插件"""
specified_module_path = None
if specified_plugin_name:
for smd in star_registry:
if smd.name == specified_plugin_name:
specified_module_path = smd.module_path
break
# 终止插件
if not specified_module_path:
# 重载所有插件
for smd in star_registry:
try:
await self._terminate_plugin(smd)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
)
await self._unbind_plugin(smd.name, smd.module_path)
star_handlers_registry.clear()
star_map.clear()
star_registry.clear()
for key in list(sys.modules.keys()):
if key.startswith("data.plugins") or key.startswith("packages"):
del sys.modules[key]
else:
# 只重载指定插件
smd = star_map.get(specified_module_path)
if smd:
try:
await self._terminate_plugin(smd)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
)
await self._unbind_plugin(smd.name, specified_module_path)
return await self.load(specified_module_path)
async def load(self, specified_module_path=None, specified_dir_name=None):
"""载入插件。
当 specified_module_path 或者 specified_dir_name 不为 None 时,只载入指定的插件。
"""
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
alter_cmd = sp.get("alter_cmd", {})
plugin_modules = self._get_plugin_modules()
if plugin_modules is None:
return False, "未找到任何插件模块"
fail_rec = ""
# 导入插件模块,并尝试实例化插件类
for plugin_module in plugin_modules:
try:
module_str = plugin_module["module"]
# module_path = plugin_module['module_path']
root_dir_name = plugin_module["pname"] # 插件的目录名
reserved = plugin_module.get(
"reserved", False
) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。
path = "data.plugins." if not reserved else "packages."
path += root_dir_name + "." + module_str
# 检查是否需要载入指定的插件
if specified_module_path and path != specified_module_path:
continue
if specified_dir_name and root_dir_name != specified_dir_name:
continue
logger.info(f"正在载入插件 {root_dir_name} ...")
# 尝试导入模块
try:
module = __import__(path, fromlist=[module_str])
except (ModuleNotFoundError, ImportError):
# 尝试安装依赖
self._check_plugin_dept_update(target_plugin=root_dir_name)
module = __import__(path, fromlist=[module_str])
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"插件 {root_dir_name} 导入失败。原因:{str(e)}")
continue
# 检查 _conf_schema.json
plugin_config = None
plugin_dir_path = (
os.path.join(self.plugin_store_path, root_dir_name)
if not reserved
else os.path.join(self.reserved_plugin_path, root_dir_name)
)
plugin_schema_path = os.path.join(
plugin_dir_path, self.conf_schema_fname
)
if os.path.exists(plugin_schema_path):
# 加载插件配置
with open(plugin_schema_path, "r", encoding="utf-8") as f:
plugin_config = AstrBotConfig(
config_path=os.path.join(
self.plugin_config_path, f"{root_dir_name}_config.json"
),
schema=json.loads(f.read()),
)
if path in star_map:
# 通过装饰器的方式注册插件
metadata = star_map[path]
try:
# yaml 文件的元数据优先
metadata_yaml = self._load_plugin_metadata(
plugin_path=plugin_dir_path
)
if metadata_yaml:
metadata.name = metadata_yaml.name
metadata.author = metadata_yaml.author
metadata.desc = metadata_yaml.desc
metadata.version = metadata_yaml.version
metadata.repo = metadata_yaml.repo
except Exception:
pass
if path not in inactivated_plugins:
# 只有没有禁用插件时才实例化插件类
if plugin_config:
metadata.config = plugin_config
try:
metadata.star_cls = metadata.star_cls_type(
context=self.context, config=plugin_config
)
except TypeError as _:
metadata.star_cls = metadata.star_cls_type(
context=self.context
)
else:
metadata.star_cls = metadata.star_cls_type(
context=self.context
)
else:
logger.info(f"插件 {metadata.name} 已被禁用。")
metadata.module = module
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
# 绑定 handler
related_handlers = (
star_handlers_registry.get_handlers_by_module_name(
metadata.module_path
)
)
for handler in related_handlers:
handler.handler = functools.partial(
handler.handler, metadata.star_cls
)
# 绑定 llm_tool handler
for func_tool in llm_tools.func_list:
if func_tool.handler.__module__ == metadata.module_path:
func_tool.handler_module_path = metadata.module_path
func_tool.handler = functools.partial(
func_tool.handler, metadata.star_cls
)
if func_tool.name in inactivated_llm_tools:
func_tool.active = False
else:
# v3.4.0 以前的方式注册插件
logger.debug(
f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。"
)
classes = self._get_classes(module)
if path not in inactivated_plugins:
# 只有没有禁用插件时才实例化插件类
if plugin_config:
try:
obj = getattr(module, classes[0])(
context=self.context, config=plugin_config
) # 实例化插件类
except TypeError as _:
obj = getattr(module, classes[0])(
context=self.context
) # 实例化插件类
else:
obj = getattr(module, classes[0])(
context=self.context
) # 实例化插件类
else:
logger.info(f"插件 {metadata.name} 已被禁用。")
metadata = None
metadata = self._load_plugin_metadata(
plugin_path=plugin_dir_path, plugin_obj=obj
)
metadata.star_cls = obj
metadata.config = plugin_config
metadata.module = module
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
metadata.star_cls_type = obj.__class__
metadata.module_path = path
star_map[path] = metadata
star_registry.append(metadata)
# 禁用/启用插件
if metadata.module_path in inactivated_plugins:
metadata.activated = False
full_names = []
for handler in star_handlers_registry.get_handlers_by_module_name(
metadata.module_path
):
full_names.append(handler.handler_full_name)
# 检查并且植入自定义的权限过滤器(alter_cmd)
if (
metadata.name in alter_cmd
and handler.handler_name in alter_cmd[metadata.name]
):
cmd_type = alter_cmd[metadata.name][handler.handler_name].get(
"permission", "member"
)
found_permission_filter = False
for filter_ in handler.event_filters:
if isinstance(filter_, PermissionTypeFilter):
if cmd_type == "admin":
filter_.permission_type = PermissionType.ADMIN
else:
filter_.permission_type = PermissionType.MEMBER
found_permission_filter = True
break
if not found_permission_filter:
handler.event_filters.append(
PermissionTypeFilter(
PermissionType.ADMIN
if cmd_type == "admin"
else PermissionType.MEMBER
)
)
logger.debug(
f"插入权限过滤器 {cmd_type}{metadata.name}{handler.handler_name} 方法。"
)
metadata.star_handler_full_names = full_names
# 执行 initialize() 方法
if hasattr(metadata.star_cls, "initialize"):
await metadata.star_cls.initialize()
except BaseException as e:
logger.error(f"----- 插件 {root_dir_name} 载入失败 -----")
errors = traceback.format_exc()
for line in errors.split("\n"):
logger.error(f"| {line}")
logger.error("----------------------------------")
fail_rec += f"加载 {root_dir_name} 插件时出现问题,原因 {str(e)}\n"
# 清除 pip.main 导致的多余的 logging handlers
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
if not fail_rec:
return True, None
else:
self.failed_plugin_info = fail_rec
return False, fail_rec
async def install_plugin(self, repo_url: str, proxy=""):
plugin_path = await self.updator.install(repo_url, proxy)
# reload the plugin
dir_name = os.path.basename(plugin_path)
await self.load(specified_dir_name=dir_name)
return plugin_path
async def uninstall_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
if plugin.reserved:
raise Exception("该插件是 AstrBot 保留插件,无法卸载。")
root_dir_name = plugin.root_dir_name
ppath = self.plugin_store_path
# 终止插件
try:
await self._terminate_plugin(plugin)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(
f"插件 {plugin_name} 未被正常终止 {str(e)}, 可能会导致资源泄露等问题。"
)
# 从 star_registry 和 star_map 中删除
await self._unbind_plugin(plugin_name, plugin.module_path)
if not remove_dir(os.path.join(ppath, root_dir_name)):
raise Exception(
"移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。"
)
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
del star_map[plugin_module_path]
for i, p in enumerate(star_registry):
if p.name == plugin_name:
del star_registry[i]
break
for handler in star_handlers_registry.get_handlers_by_module_name(
plugin_module_path
):
logger.debug(f"unbind handler {handler.handler_name} from {plugin_name}")
star_handlers_registry.remove(handler)
keys_to_delete = [
k
for k, v in star_handlers_registry.star_handlers_map.items()
if k.startswith(plugin_module_path)
]
for k in keys_to_delete:
v = star_handlers_registry.star_handlers_map[k]
logger.debug(f"unbind handler {v.handler_name} from {plugin_name} (map)")
try:
del star_handlers_registry.star_handlers_map[k]
except KeyError:
pass
try:
del sys.modules[plugin_module_path]
except KeyError:
logger.warning(f"模块 {plugin_module_path} 未载入")
async def update_plugin(self, plugin_name: str, proxy=""):
"""升级一个插件"""
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
if plugin.reserved:
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
await self.updator.update(plugin, proxy=proxy)
await self.reload(plugin_name)
async def turn_off_plugin(self, plugin_name: str):
"""
禁用一个插件。
调用插件的 terminate() 方法,
将插件的 module_path 加入到 data/shared_preferences.json 的 inactivated_plugins 列表中。
并且同时将插件启用的 llm_tool 禁用。
"""
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
# 调用插件的终止方法
await self._terminate_plugin(plugin)
# 加入到 shared_preferences 中
inactivated_plugins: list = sp.get("inactivated_plugins", [])
if plugin.module_path not in inactivated_plugins:
inactivated_plugins.append(plugin.module_path)
inactivated_llm_tools: list = list(
set(sp.get("inactivated_llm_tools", []))
) # 后向兼容
# 禁用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler_module_path == plugin.module_path:
func_tool.active = False
if func_tool.name not in inactivated_llm_tools:
inactivated_llm_tools.append(func_tool.name)
sp.put("inactivated_plugins", inactivated_plugins)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
plugin.activated = False
async def _terminate_plugin(self, star_metadata: StarMetadata):
"""终止插件,调用插件的 terminate() 和 __del__() 方法"""
logging.info(f"正在终止插件 {star_metadata.name} ...")
if not star_metadata.activated:
# 说明之前已经被禁用了
logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。")
return
if hasattr(star_metadata.star_cls, "__del__"):
asyncio.get_event_loop().run_in_executor(
None, star_metadata.star_cls.__del__
)
else:
await star_metadata.star_cls.terminate()
async def turn_on_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if plugin.module_path in inactivated_plugins:
inactivated_plugins.remove(plugin.module_path)
sp.put("inactivated_plugins", inactivated_plugins)
# 启用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
if (
func_tool.handler_module_path == plugin.module_path
and func_tool.name in inactivated_llm_tools
):
inactivated_llm_tools.remove(func_tool.name)
func_tool.active = True
sp.put("inactivated_llm_tools", inactivated_llm_tools)
await self.reload(plugin_name)
# plugin.activated = True
async def install_plugin_from_file(self, zip_file_path: str):
dir_name = os.path.basename(zip_file_path).replace(".zip", "")
dir_name = dir_name.removesuffix("-master").removesuffix("-main").lower()
desti_dir = os.path.join(self.plugin_store_path, dir_name)
self.updator.unzip_file(zip_file_path, desti_dir)
# remove the zip
try:
os.remove(zip_file_path)
except BaseException as e:
logger.warning(f"删除插件压缩包失败: {str(e)}")
# await self.reload()
await self.load(desti_dir)