perf: plugin manager 重复 reload 释放资源
This commit is contained in:
@@ -37,9 +37,13 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||
|
||||
self.star_context = Context(self.event_queue, self.astrbot_config, self.db)
|
||||
self.star_context.platform_manager = self.platform_manager
|
||||
self.star_context.provider_manager = self.provider_manager
|
||||
self.star_context = Context(
|
||||
self.event_queue,
|
||||
self.astrbot_config,
|
||||
self.db,
|
||||
self.provider_manager,
|
||||
self.platform_manager
|
||||
)
|
||||
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
|
||||
|
||||
self.plugin_manager.reload()
|
||||
|
||||
@@ -24,7 +24,7 @@ class StarRequestSubStage(Stage):
|
||||
for handler in activated_handlers:
|
||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||
try:
|
||||
if handler.handler_module_str not in star_map:
|
||||
if handler.handler_module_path not in star_map:
|
||||
# 孤立无援的 star handler
|
||||
continue
|
||||
|
||||
@@ -36,7 +36,7 @@ class StarRequestSubStage(Stage):
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
|
||||
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_str).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
yield
|
||||
event.clear_result()
|
||||
|
||||
@@ -54,6 +54,8 @@ class ProviderManager():
|
||||
|
||||
if len(self.provider_insts) > 0:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
else:
|
||||
logger.warning("未启用任何大模型提供商适配器。")
|
||||
|
||||
def get_insts(self):
|
||||
return self.provider_insts
|
||||
@@ -39,10 +39,18 @@ class Context:
|
||||
# back compatibility
|
||||
_register_tasks: List[Awaitable] = []
|
||||
|
||||
def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase):
|
||||
def __init__(self,
|
||||
event_queue: Queue,
|
||||
config: AstrBotConfig,
|
||||
db: BaseDatabase,
|
||||
provider_manager: ProviderManager = None,
|
||||
platform_manager: PlatformManager = None
|
||||
):
|
||||
self._event_queue = event_queue
|
||||
self._config = config
|
||||
self._db = db
|
||||
self.provider_manager = provider_manager
|
||||
self.platform_manager = platform_manager
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata:
|
||||
for star in star_registry:
|
||||
@@ -73,7 +81,7 @@ class Context:
|
||||
event_type=EventType.OnLLMRequestEvent,
|
||||
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
|
||||
handler_name=func_obj.__name__,
|
||||
handler_module_str=func_obj.__module__,
|
||||
handler_module_path=func_obj.__module__,
|
||||
handler=func_obj,
|
||||
event_filters=[],
|
||||
desc=desc
|
||||
@@ -125,7 +133,7 @@ class Context:
|
||||
event_type=EventType.AdapterMessageEvent,
|
||||
handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
|
||||
handler_name=awaitable.__name__,
|
||||
handler_module_str=awaitable.__module__,
|
||||
handler_module_path=awaitable.__module__,
|
||||
handler=awaitable,
|
||||
event_filters=[],
|
||||
desc=desc
|
||||
|
||||
@@ -28,7 +28,7 @@ def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add =
|
||||
event_type=event_type,
|
||||
handler_full_name=handler_full_name,
|
||||
handler_name=handler.__name__,
|
||||
handler_module_str=handler.__module__,
|
||||
handler_module_path=handler.__module__,
|
||||
handler=handler,
|
||||
event_filters=[]
|
||||
)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, List, Dict
|
||||
from typing import Awaitable, List, Dict, TypeVar, Generic
|
||||
from .filter import HandlerFilter
|
||||
|
||||
|
||||
class StarHandlerRegistry(List):
|
||||
T = TypeVar('T', bound='StarHandlerMetadata')
|
||||
class StarHandlerRegistry(Generic[T], List[T]):
|
||||
'''用于存储所有的 Star Handler'''
|
||||
|
||||
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
@@ -26,8 +26,7 @@ class StarHandlerRegistry(List):
|
||||
|
||||
def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]:
|
||||
'''通过模块名获取 Handler'''
|
||||
return [handler for handler in self if handler.handler_module_str == module_name]
|
||||
|
||||
return [handler for handler in self if handler.handler_module_path == module_name]
|
||||
|
||||
star_handlers_registry = StarHandlerRegistry()
|
||||
|
||||
@@ -55,7 +54,7 @@ class StarHandlerMetadata():
|
||||
handler_name: str
|
||||
'''Handler 的名字,也就是方法名'''
|
||||
|
||||
handler_module_str: str
|
||||
handler_module_path: str
|
||||
'''Handler 所在的模块路径。'''
|
||||
|
||||
handler: Awaitable
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import inspect
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import yaml
|
||||
import logging
|
||||
@@ -14,6 +15,7 @@ from . import StarMetadata
|
||||
from .updator import PluginUpdator
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
from .star import star_registry, star_map
|
||||
from .star_handler import star_handlers_registry
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
from .star_handler import star_handlers_registry
|
||||
@@ -139,6 +141,12 @@ class PluginManager:
|
||||
def reload(self):
|
||||
'''扫描并加载所有的 Star'''
|
||||
star_handlers_registry.clear()
|
||||
star_handlers_registry.star_handlers_map.clear()
|
||||
star_map.clear()
|
||||
star_registry.clear()
|
||||
for key in list(sys.modules.keys()):
|
||||
if key.startswith("data.plugins") or key.startswith("packages"):
|
||||
del sys.modules[key]
|
||||
|
||||
plugin_modules = self._get_plugin_modules()
|
||||
if plugin_modules is None:
|
||||
@@ -225,10 +233,11 @@ class PluginManager:
|
||||
|
||||
async def install_plugin(self, repo_url: str):
|
||||
plugin_path = await self.updator.install(repo_url)
|
||||
self._check_plugin_dept_update()
|
||||
# reload the plugin
|
||||
self.reload()
|
||||
return plugin_path
|
||||
|
||||
def uninstall_plugin(self, plugin_name: str):
|
||||
async def uninstall_plugin(self, plugin_name: str):
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
@@ -237,7 +246,20 @@ class PluginManager:
|
||||
root_dir_name = plugin.root_dir_name
|
||||
ppath = self.plugin_store_path
|
||||
|
||||
# 从 star_registry 和 star_map 中删除
|
||||
del star_map[plugin.module_path]
|
||||
for i, p in enumerate(star_registry):
|
||||
if p.name == plugin_name:
|
||||
del star_registry[i]
|
||||
break
|
||||
for handler in star_handlers_registry.get_handlers_by_module_name(plugin.module_path):
|
||||
logger.debug(f"unbind handler {handler.handler_name} from {plugin_name}")
|
||||
star_handlers_registry.remove(handler)
|
||||
keys_to_delete = [k for k, v in star_handlers_registry.star_handlers_map.items() if k.startswith(plugin.module_path)]
|
||||
for k in keys_to_delete:
|
||||
v = star_handlers_registry.star_handlers_map[k]
|
||||
logger.debug(f"unbind handler {v.handler_name} from {plugin_name} (map)")
|
||||
del star_handlers_registry.star_handlers_map[k]
|
||||
|
||||
if not remove_dir(os.path.join(ppath, root_dir_name)):
|
||||
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
|
||||
@@ -262,3 +284,4 @@ class PluginManager:
|
||||
logger.warning(f"删除插件压缩包失败: {str(e)}")
|
||||
|
||||
self._check_plugin_dept_update()
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ class PluginRoute(Route):
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
logger.info(f"正在卸载插件 {plugin_name}")
|
||||
self.plugin_manager.uninstall_plugin(plugin_name)
|
||||
await self.plugin_manager.uninstall_plugin(plugin_name)
|
||||
logger.info(f"卸载插件 {plugin_name} 成功")
|
||||
return Response().ok(None, "卸载成功").__dict__
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import sys
|
||||
@@ -42,14 +41,16 @@ async def check_dashboard_files():
|
||||
return
|
||||
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
|
||||
logger.info("开始下载管理面板文件...")
|
||||
ok = False
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(dashboard_release_url) as resp:
|
||||
if resp.status != 200:
|
||||
logger.error(f"下载管理面板文件失败: {resp.status}")
|
||||
with open("data/dashboard.zip", "wb") as f:
|
||||
f.write(await resp.read())
|
||||
logger.info("管理面板文件下载完成。")
|
||||
ok = True
|
||||
else:
|
||||
with open("data/dashboard.zip", "wb") as f:
|
||||
f.write(await resp.read())
|
||||
logger.info("管理面板文件下载完成。")
|
||||
ok = True
|
||||
|
||||
if not ok:
|
||||
logger.critical("下载管理面板文件失败")
|
||||
|
||||
Reference in New Issue
Block a user