perf: 结构化插件的表示格式; 优化插件开发接口
This commit is contained in:
@@ -7,7 +7,7 @@ import re
|
||||
import requests
|
||||
from util.cmd_config import CmdConfig
|
||||
import socket
|
||||
from cores.qqbot.global_object import GlobalObject
|
||||
from cores.qqbot.types import GlobalObject
|
||||
import platform
|
||||
import logging
|
||||
import json
|
||||
@@ -537,7 +537,7 @@ def upload(_global_object: GlobalObject):
|
||||
"count": _global_object.cnt_total,
|
||||
"ip": addr_ip,
|
||||
"sys": sys.platform,
|
||||
"admin": _global_object.admin_qq,
|
||||
"admin": "null",
|
||||
}
|
||||
resp = requests.post('https://api.soulter.top/upload', data=json.dumps(res), timeout=5)
|
||||
if resp.status_code == 200:
|
||||
|
||||
@@ -1 +1,11 @@
|
||||
from cores.qqbot.global_object import GlobalObject
|
||||
from cores.qqbot.types import (
|
||||
PluginMetadata,
|
||||
RegisteredLLM,
|
||||
RegisteredPlugin,
|
||||
RegisteredPlatform,
|
||||
RegisteredPlugins,
|
||||
PluginType,
|
||||
GlobalObject,
|
||||
AstrMessageEvent,
|
||||
CommandResult
|
||||
)
|
||||
@@ -1,4 +1,3 @@
|
||||
from cores.qqbot.global_object import GlobalObject
|
||||
from typing import Union
|
||||
import os
|
||||
import json
|
||||
@@ -19,7 +18,6 @@ def load_config(namespace: str) -> Union[dict, bool]:
|
||||
ret[k] = data[k]["value"]
|
||||
return ret
|
||||
|
||||
|
||||
def put_config(namespace: str, name: str, key: str, value, description: str):
|
||||
'''
|
||||
将配置项写入以namespace为名字的配置文件,如果key不存在于目标配置文件中。当前 value 仅支持 str, int, float, bool, list 类型(暂不支持 dict)。
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
'''
|
||||
大语言模型.
|
||||
|
||||
插件开发者可以继承这个类来做实现。
|
||||
'''
|
||||
from model.provider.provider import Provider as LLMProvider
|
||||
@@ -1,5 +1,5 @@
|
||||
from cores.qqbot.core import oper_msg
|
||||
from cores.qqbot.global_object import AstrMessageEvent, CommandResult
|
||||
from cores.qqbot.types import AstrMessageEvent, CommandResult
|
||||
from model.platform._message_result import MessageResult
|
||||
|
||||
'''
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
'''
|
||||
消息平台。
|
||||
|
||||
Platform类是消息平台的抽象类,定义了消息平台的基本接口。
|
||||
消息平台的具体实现类需要继承Platform类,并实现其中的抽象方法。
|
||||
'''
|
||||
|
||||
from model.platform._platfrom import Platform
|
||||
|
||||
from model.platform.qq_gocq import QQGOCQ
|
||||
from model.platform.qq_official import QQOfficial
|
||||
@@ -0,0 +1,77 @@
|
||||
'''
|
||||
允许开发者注册某一个类的实例到 LLM 或者 PLATFORM 中,方便其他插件调用。
|
||||
|
||||
必须分别实现 Platform 和 LLMProvider 中涉及的接口
|
||||
'''
|
||||
from model.provider.provider import Provider as LLMProvider
|
||||
from model.platform._platfrom import Platform
|
||||
from cores.qqbot.types import GlobalObject, RegisteredPlatform, RegisteredLLM
|
||||
|
||||
def register_platform(platform_name: str, platform_instance: Platform, context: GlobalObject) -> None:
|
||||
'''
|
||||
注册一个消息平台。
|
||||
|
||||
Args:
|
||||
platform_name: 平台名称。
|
||||
platform_instance: 平台实例。
|
||||
'''
|
||||
|
||||
# check 是否已经注册
|
||||
for platform in context.platforms:
|
||||
if platform.platform_name == platform_name:
|
||||
raise ValueError(f"Platform {platform_name} has been registered.")
|
||||
|
||||
# check
|
||||
should_attrs = Platform.__dir__()
|
||||
has_attrs = platform_instance.__dir__()
|
||||
|
||||
if not all([attr in has_attrs for attr in should_attrs]):
|
||||
raise ValueError(f"Platform {platform_name} should implement all methods in LLMProvider.")
|
||||
|
||||
context.platforms.append(RegisteredPlatform(platform_name, platform_instance))
|
||||
|
||||
def register_llm(llm_name: str, llm_instance: LLMProvider, context: GlobalObject) -> None:
|
||||
'''
|
||||
注册一个大语言模型。
|
||||
|
||||
Args:
|
||||
llm_name: 大语言模型名称。
|
||||
llm_instance: 大语言模型实例。
|
||||
'''
|
||||
# check 是否已经注册
|
||||
for llm in context.llms:
|
||||
if llm.llm_name == llm_name:
|
||||
raise ValueError(f"LLMProvider {llm_name} has been registered.")
|
||||
|
||||
# check
|
||||
should_attrs = LLMProvider.__dir__()
|
||||
has_attrs = llm_instance.__dir__()
|
||||
|
||||
if not all([attr in has_attrs for attr in should_attrs]):
|
||||
raise ValueError(f"LLMProvider {llm_name} should implement all methods in LLMProvider.")
|
||||
|
||||
context.llms.append(RegisteredLLM(llm_name, llm_instance))
|
||||
|
||||
def unregister_platform(platform_name: str, context: GlobalObject) -> None:
|
||||
'''
|
||||
注销一个消息平台。
|
||||
|
||||
Args:
|
||||
platform_name: 平台名称。
|
||||
'''
|
||||
for i, platform in enumerate(context.platforms):
|
||||
if platform.platform_name == platform_name:
|
||||
context.platforms.pop(i)
|
||||
return
|
||||
|
||||
def unregister_llm(llm_name: str, context: GlobalObject) -> None:
|
||||
'''
|
||||
注销一个大语言模型。
|
||||
|
||||
Args:
|
||||
llm_name: 大语言模型名称。
|
||||
'''
|
||||
for i, llm in enumerate(context.llms):
|
||||
if llm.llm_name == llm_name:
|
||||
context.llms.pop(i)
|
||||
return
|
||||
@@ -0,0 +1,5 @@
|
||||
'''
|
||||
插件类型
|
||||
'''
|
||||
|
||||
from cores.qqbot.types import PluginType
|
||||
+75
-33
@@ -9,11 +9,19 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
import shutil
|
||||
from pip._internal import main as pipmain
|
||||
import importlib
|
||||
import stat
|
||||
import traceback
|
||||
from types import ModuleType
|
||||
from typing import List
|
||||
from pip._internal import main as pipmain
|
||||
from cores.qqbot.types import (
|
||||
PluginMetadata,
|
||||
PluginType,
|
||||
RegisteredPlugin,
|
||||
RegisteredPlugins
|
||||
)
|
||||
|
||||
|
||||
# 找出模块里所有的类名
|
||||
def get_classes(p_name, arg: ModuleType):
|
||||
@@ -45,7 +53,8 @@ def get_modules(path):
|
||||
if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists(os.path.join(path, d, d + ".py")):
|
||||
modules.append({
|
||||
"pname": d,
|
||||
"module": module_str
|
||||
"module": module_str,
|
||||
"module_path": os.path.join(path, d, module_str)
|
||||
})
|
||||
return modules
|
||||
|
||||
@@ -73,39 +82,62 @@ def get_plugin_modules():
|
||||
except BaseException as e:
|
||||
raise e
|
||||
|
||||
def plugin_reload(cached_plugins: dict, target: str = None, all: bool = False):
|
||||
def plugin_reload(cached_plugins: RegisteredPlugins):
|
||||
plugins = get_plugin_modules()
|
||||
if plugins is None:
|
||||
return False, "未找到任何插件模块"
|
||||
fail_rec = ""
|
||||
|
||||
registered_map = {}
|
||||
for p in cached_plugins:
|
||||
registered_map[p.module_path] = None
|
||||
|
||||
for plugin in plugins:
|
||||
try:
|
||||
p = plugin['module']
|
||||
module_path = plugin['module_path']
|
||||
root_dir_name = plugin['pname']
|
||||
if p not in cached_plugins or p == target or all:
|
||||
|
||||
if module_path in registered_map:
|
||||
# 之前注册过
|
||||
module = importlib.reload(module)
|
||||
else:
|
||||
module = __import__("addons.plugins." + root_dir_name + "." + p, fromlist=[p])
|
||||
if p in cached_plugins:
|
||||
module = importlib.reload(module)
|
||||
cls = get_classes(p, module)
|
||||
obj = getattr(module, cls[0])()
|
||||
try:
|
||||
info = obj.info()
|
||||
|
||||
cls = get_classes(p, module)
|
||||
obj = getattr(module, cls[0])()
|
||||
|
||||
metadata = None
|
||||
try:
|
||||
info = obj.info()
|
||||
if isinstance(info, dict):
|
||||
if 'name' not in info or 'desc' not in info or 'version' not in info or 'author' not in info:
|
||||
fail_rec += f"载入插件{p}失败,原因: 插件信息不完整\n"
|
||||
fail_rec += f"注册插件 {module_path} 失败,原因: 插件信息不完整\n"
|
||||
continue
|
||||
if isinstance(info, dict) == False:
|
||||
fail_rec += f"载入插件{p}失败,原因: 插件信息格式不正确\n"
|
||||
continue
|
||||
except BaseException as e:
|
||||
fail_rec += f"调用插件{p} info失败, 原因: {str(e)}\n"
|
||||
else:
|
||||
metadata = PluginMetadata(
|
||||
plugin_name=info['name'],
|
||||
plugin_type=PluginType.COMMON if 'plugin_type' not in info else PluginType(info['plugin_type']),
|
||||
author=info['author'],
|
||||
desc=info['desc'],
|
||||
version=info['version'],
|
||||
repo=info['repo'] if 'repo' in info else None
|
||||
)
|
||||
elif isinstance(info, PluginMetadata):
|
||||
metadata = info
|
||||
else:
|
||||
fail_rec += f"注册插件 {module_path} 失败,原因: info 函数返回值类型错误\n"
|
||||
continue
|
||||
cached_plugins[info['name']] = {
|
||||
"module": module,
|
||||
"clsobj": obj,
|
||||
"info": info,
|
||||
"name": info['name'],
|
||||
"root_dir_name": root_dir_name,
|
||||
}
|
||||
except BaseException as e:
|
||||
fail_rec += f"注册插件 {module_path} 失败, 原因: {str(e)}\n"
|
||||
continue
|
||||
cached_plugins.append(RegisteredPlugin(
|
||||
metadata=metadata,
|
||||
plugin_instance=obj,
|
||||
module=module,
|
||||
module_path=module_path,
|
||||
root_dir_name=root_dir_name
|
||||
))
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
fail_rec += f"加载{p}插件出现问题,原因 {str(e)}\n"
|
||||
@@ -114,7 +146,7 @@ def plugin_reload(cached_plugins: dict, target: str = None, all: bool = False):
|
||||
else:
|
||||
return False, fail_rec
|
||||
|
||||
def install_plugin(repo_url: str, cached_plugins: dict):
|
||||
def install_plugin(repo_url: str, cached_plugins: RegisteredPlugins):
|
||||
ppath = get_plugin_store_path()
|
||||
# 删除末尾的 /
|
||||
if repo_url.endswith("/"):
|
||||
@@ -132,23 +164,33 @@ def install_plugin(repo_url: str, cached_plugins: dict):
|
||||
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
|
||||
if pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt"), '--quiet']) != 0:
|
||||
raise Exception("插件的依赖安装失败, 需要您手动 pip 安装对应插件的依赖。")
|
||||
ok, err = plugin_reload(cached_plugins, target=d)
|
||||
ok, err = plugin_reload(cached_plugins)
|
||||
if not ok: raise Exception(err)
|
||||
|
||||
def get_registered_plugin(plugin_name: str, cached_plugins: RegisteredPlugins) -> RegisteredPlugin:
|
||||
ret = None
|
||||
for p in cached_plugins:
|
||||
if p.metadata.plugin_name == plugin_name:
|
||||
ret = p
|
||||
break
|
||||
return ret
|
||||
|
||||
def uninstall_plugin(plugin_name: str, cached_plugins: dict):
|
||||
if plugin_name not in cached_plugins:
|
||||
def uninstall_plugin(plugin_name: str, cached_plugins: RegisteredPlugins):
|
||||
plugin = get_registered_plugin(plugin_name, cached_plugins)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
root_dir_name = cached_plugins[plugin_name]["root_dir_name"]
|
||||
root_dir_name = plugin.root_dir_name
|
||||
ppath = get_plugin_store_path()
|
||||
del cached_plugins[plugin_name]
|
||||
cached_plugins.remove(plugin)
|
||||
if not remove_dir(os.path.join(ppath, root_dir_name)):
|
||||
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
|
||||
|
||||
def update_plugin(plugin_name: str, cached_plugins: dict):
|
||||
if plugin_name not in cached_plugins:
|
||||
def update_plugin(plugin_name: str, cached_plugins: RegisteredPlugins):
|
||||
plugin = get_registered_plugin(plugin_name, cached_plugins)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
ppath = get_plugin_store_path()
|
||||
root_dir_name = cached_plugins[plugin_name]["root_dir_name"]
|
||||
root_dir_name = plugin.root_dir_name
|
||||
plugin_path = os.path.join(ppath, root_dir_name)
|
||||
repo = Repo(path = plugin_path)
|
||||
repo.remotes.origin.pull()
|
||||
@@ -156,7 +198,7 @@ def update_plugin(plugin_name: str, cached_plugins: dict):
|
||||
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
|
||||
if pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt"), '--quiet']) != 0:
|
||||
raise Exception("插件依赖安装失败, 需要您手动pip安装对应插件的依赖。")
|
||||
ok, err = plugin_reload(cached_plugins, target=plugin_name)
|
||||
ok, err = plugin_reload(cached_plugins)
|
||||
if not ok: raise Exception(err)
|
||||
|
||||
def remove_dir(file_path) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user