feat: 完善插件在禁用/重载时的逻辑,添加 terminate() Star 父类方法

This commit is contained in:
Soulter
2025-03-02 16:02:47 +08:00
parent 0956f240b3
commit 7e89fbc907
4 changed files with 61 additions and 23 deletions
+13 -1
View File
@@ -16,4 +16,16 @@ class Star(CommandParserMixin):
async def html_render(self, tmpl: str, data: dict, return_url = True) -> str:
'''渲染 HTML'''
return await html_renderer.render_custom_template(tmpl, data, return_url=return_url)
return await html_renderer.render_custom_template(tmpl, data, return_url=return_url)
async def terminate(self):
'''当插件被禁用、重载插件时会调用这个方法'''
pass
__all__ = [
'Star',
'StarMetadata',
'PluginManager',
'Context',
'Provider'
]
+42 -5
View File
@@ -6,6 +6,7 @@ import json
import traceback
import yaml
import logging
import asyncio
from types import ModuleType
from typing import List
from astrbot.core.config.astrbot_config import AstrBotConfig
@@ -151,11 +152,14 @@ class PluginManager:
# 终止插件
if not specified_module_path:
# 重载所有插件
for smd in star_registry:
logger.debug(f"尝试终止插件 {smd.name} ...")
if hasattr(smd.star_cls, "__del__"):
smd.star_cls.__del__()
try:
await self._terminate_plugin(smd)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。")
star_handlers_registry.clear()
star_map.clear()
star_registry.clear()
@@ -166,6 +170,12 @@ class PluginManager:
# 只重载指定插件
smd = star_map.get(specified_module_path)
if smd:
try:
await self._terminate_plugin(smd)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。")
await self._unbind_plugin(smd.name, specified_module_path)
try:
del sys.modules[specified_module_path]
@@ -355,6 +365,13 @@ class PluginManager:
root_dir_name = plugin.root_dir_name
ppath = self.plugin_store_path
# 终止插件
try:
await self._terminate_plugin(plugin)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(f"插件 {plugin_name} 未被正常终止 {str(e)}, 可能会导致资源泄露等问题。")
# 从 star_registry 和 star_map 中删除
await self._unbind_plugin(plugin_name, plugin.module_path)
@@ -377,6 +394,7 @@ class PluginManager:
del star_handlers_registry.star_handlers_map[k]
async def update_plugin(self, plugin_name: str, proxy = ""):
'''升级一个插件'''
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
@@ -387,9 +405,20 @@ class PluginManager:
await self.reload()
async def turn_off_plugin(self, plugin_name: str):
'''
禁用一个插件。
调用插件的 terminate() 方法,
将插件的 module_path 加入到 data/shared_preferences.json 的 inactivated_plugins 列表中。
并且同时将插件启用的 llm_tool 禁用。
'''
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
# 调用插件的终止方法
await self._terminate_plugin(plugin)
# 加入到 shared_preferences 中
inactivated_plugins: list = sp.get("inactivated_plugins", [])
if plugin.module_path not in inactivated_plugins:
inactivated_plugins.append(plugin.module_path)
@@ -408,6 +437,15 @@ class PluginManager:
plugin.activated = False
async def _terminate_plugin(self, star_metadata: StarMetadata):
'''终止插件,调用插件的 terminate() 和 __del__() 方法'''
logging.info(f"正在终止插件 {star_metadata.name} ...")
if hasattr(star_metadata.star_cls, "__del__"):
asyncio.get_event_loop().run_in_executor(star_metadata.star_cls.__del__)
else:
await star_metadata.star_cls.terminate()
async def turn_on_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
inactivated_plugins: list = sp.get("inactivated_plugins", [])
@@ -425,7 +463,6 @@ class PluginManager:
plugin.activated = True
async def install_plugin_from_file(self, zip_file_path: str):
dir_name = os.path.basename(zip_file_path).replace(".zip", "")
dir_name = dir_name.removesuffix("-master").removesuffix("-main").lower()
-16
View File
@@ -980,19 +980,3 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
found_command.event_filters.insert(0, PermissionTypeFilter(filter.PermissionType.ADMIN if cmd_type == "admin" else filter.PermissionType.MEMBER))
yield event.plain_result(f"已将 {cmd_name} 设置为 {cmd_type} 指令")
# @filter.command_group("kdb")
# def kdb(self):
# pass
# @filter.on_llm_request()
# async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
# curr_kdb_name = self.context.provider_manager.curr_kdb_name
# if self.kdb_enabled and curr_kdb_name:
# mgr = self.context.knowledge_db_manager
# results = await mgr.retrive_records(curr_kdb_name, req.prompt)
# if results:
# req.system_prompt += "\nHere are documents that related to user's query: \n"
# for result in results:
# req.system_prompt += f"- {result}\n"7
+6 -1
View File
@@ -193,4 +193,9 @@ class Main(star.Star):
async def _reminder_callback(self, unified_msg_origin: str, d: dict):
'''The callback function of the reminder.'''
logger.info(f"Reminder Activated: {d['text']}, created by {unified_msg_origin}")
await self.context.send_message(unified_msg_origin, MessageEventResult().message("待办提醒: \n\n" + d['text'] + "\n时间: " + d.get("datetime", "") + d.get("cron_h", "")))
await self.context.send_message(unified_msg_origin, MessageEventResult().message("待办提醒: \n\n" + d['text'] + "\n时间: " + d.get("datetime", "") + d.get("cron_h", "")))
async def terminate(self):
self.scheduler.shutdown()
await self._save_data()
logger.info("Reminder plugin terminated.")