perf: 更好的插件处理逻辑和更开放的插件功能

This commit is contained in:
Soulter
2023-05-13 10:54:57 +08:00
parent 33793a2053
commit 2bf9c82617
9 changed files with 130 additions and 69 deletions
+29 -9
View File
@@ -1,3 +1,10 @@
from nakuru.entities.components import *
from nakuru import (
GroupMessage,
FriendMessage
)
from botpy.message import Message, DirectMessage
class HelloWorldPlugin:
"""
初始化函数, 可以选择直接pass
@@ -7,20 +14,33 @@ class HelloWorldPlugin:
"""
入口函数,机器人会调用此函数。
参数规范: message: 消息文本; role: 身份; platform: 消息平台
参数规范: message: 消息文本; role: 身份; platform: 消息平台; message_obj: 消息对象
参数详情: role为admin或者member; platform为qqchan或者gocq; message_obj为nakuru的GroupMessage对象或者FriendMessage对象或者频道的Message, DirectMessage对象。
返回规范: bool: 是否hit到此插件(所有的消息均会调用每一个载入的插件, 如果没有hit到, 则应返回False)
Tuple: None或者长度为3的元组。当没有hit到时, 返回None. hit到时, 第1个参数为指令是否调用成功, 第2个参数为返回的消息文本, 第3个参数为指令名称
Tuple: None或者长度为3的元组。当没有hit到时, 返回None. hit到时, 第1个参数为指令是否调用成功, 第2个参数为返回的消息文本或者gocq的消息链列表, 第3个参数为指令名称
例子:做一个名为"yuanshen"的插件;当接收到消息为“原神 可莉”, 如果不想要处理此消息,则返回False, None;如果想要处理,但是执行失败了,返回True, tuple([False, "请求失败啦~", "yuanshen"])
;执行成功了,返回True, tuple([True, "结果文本", "yuanshen"])
"""
def run(self, message: str, role: str, platform: str):
def run(self, message: str, role: str, platform: str, message_obj):
if platform == "gocq":
"""
QQ平台指令处理逻辑
"""
if message == "helloworld":
return True, tuple([True, [Plain("Hello World!!")], "helloworld"])
else:
return False, None
elif platform == "qqchan":
"""
频道处理逻辑(频道暂时只支持回复字符串类型的信息,返回的信息都会被转成字符串,如果不想处理某一个平台的信息,直接返回False, None就行)
"""
if message == "helloworld":
return True, tuple([True, "Hello World!!", "helloworld"])
else:
return False, None
# 这里是插件核心处理逻辑
if message == "helloworld":
return True, tuple([True, "Hello World~", "helloworld"])
else:
return False, None
# 热知识:检测消息开头指令,使用以下方法
# if message.startswith("原神"):
# pass
+21 -18
View File
@@ -1,8 +1,7 @@
import botpy
from botpy.message import Message
from botpy.message import Message, DirectMessage
from botpy.types.message import Reference
import re
from botpy.message import DirectMessage
import json
import threading
import asyncio
@@ -396,20 +395,25 @@ def save_provider_preference(chosen_provider):
def send_message(platform, message, res, msg_ref = None, image = None, gocq_loop = None, qqchannel_bot = None, gocq_bot = None):
if platform == PLATFORM_QQCHAN:
if image != None:
qqchannel_bot.send_qq_msg(message, res, image_mode=True, msg_ref=msg_ref)
qqchannel_bot.send_qq_msg(message, str(res), image_mode=True, msg_ref=msg_ref)
else:
qqchannel_bot.send_qq_msg(message, res, msg_ref=msg_ref)
qqchannel_bot.send_qq_msg(message, str(res), msg_ref=msg_ref)
if platform == PLATFORM_GOCQ:
if image != None:
asyncio.run_coroutine_threadsafe(gocq_bot.send_qq_msg(message, image, image_mode=True), gocq_loop).result()
else:
asyncio.run_coroutine_threadsafe(gocq_bot.send_qq_msg(message, res), gocq_loop).result()
asyncio.run_coroutine_threadsafe(gocq_bot.send_qq_msg(message, res, False, ), gocq_loop).result()
'''
处理消息
group: 群聊模式
'''
def oper_msg(message, group=False, msg_ref = None, platform = None):
def oper_msg(message,
group: bool=False,
msg_ref: Reference = None,
platform: str = None):
"""
处理消息。
group: 群聊模式,
message: 频道是频道的消息对象, QQ是nakuru-gocq的消息对象
"""
global session_dict, provider
qq_msg = ''
session_id = ''
@@ -538,7 +542,7 @@ def oper_msg(message, group=False, msg_ref = None, platform = None):
chatgpt_res = ""
if chosen_provider == OPENAI_OFFICIAL:
hit, command_result = command_openai_official.check_command(qq_msg, session_id, user_name, role, platform=platform)
hit, command_result = command_openai_official.check_command(qq_msg, session_id, user_name, role, platform=platform, message_obj=message)
# hit: 是否触发了指令
if not hit:
# 请求ChatGPT获得结果
@@ -551,7 +555,7 @@ def oper_msg(message, group=False, msg_ref = None, platform = None):
send_message(platform, message, f"OpenAI API错误, 原因: {str(e)}", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot)
elif chosen_provider == REV_CHATGPT:
hit, command_result = command_rev_chatgpt.check_command(qq_msg, role, platform=platform)
hit, command_result = command_rev_chatgpt.check_command(qq_msg, role, platform=platform, message=message)
if not hit:
try:
chatgpt_res = str(rev_chatgpt.text_chat(qq_msg))
@@ -567,7 +571,7 @@ def oper_msg(message, group=False, msg_ref = None, platform = None):
bing_cache_loop = gocq_loop
elif platform == PLATFORM_QQCHAN:
bing_cache_loop = qqchan_loop
hit, command_result = command_rev_edgegpt.check_command(qq_msg, bing_cache_loop, role, platform=platform)
hit, command_result = command_rev_edgegpt.check_command(qq_msg, bing_cache_loop, role, platform=platform, message_obj=message)
if not hit:
try:
while rev_edgegpt.is_busy():
@@ -611,18 +615,17 @@ def oper_msg(message, group=False, msg_ref = None, platform = None):
if command_result[0]:
# 是否是画图指令
if len(command_result) == 3 and command_result[2] == 'draw':
if isinstance(command_result, list) and len(command_result) == 3 and command_result[2] == 'draw':
for i in command_result[1]:
send_message(platform, message, i, msg_ref=msg_ref, image=i, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot)
else:
else:
try:
send_message(platform, message, command_result[1], msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot)
except BaseException as e:
t = command_result[1].replace(".", " . ")
send_message(platform, message, t, msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot)
send_message(platform, message, f"回复消息出错: {str(e)}", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot)
else:
send_message(platform, message, f"指令调用错误: \n{command_result[1]}", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot)
send_message(platform, message, f"指令调用错误: \n{str(command_result[1])}", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot)
return
+33 -28
View File
@@ -1,4 +1,3 @@
import abc
import json
import git.exc
from git.repo import Repo
@@ -8,6 +7,7 @@ import requests
from model.provider.provider import Provider
import json
import util.plugin_util as putil
import importlib
PLATFORM_QQCHAN = 'qqchan'
PLATFORM_GOCQ = 'gocq'
@@ -17,30 +17,40 @@ class Command:
def __init__(self, provider: Provider):
self.provider = Provider
def check_command(self, message, role, platform):
# 插件
def get_plugin_modules(self):
plugins = []
try:
go = True
if os.path.exists("addons/plugins"):
plugins = putil.get_modules("addons/plugins")
return plugins
elif os.path.exists("QQChannelChatGPT/addons/plugins"):
plugins = putil.get_modules("QQChannelChatGPT/addons/plugins")
return plugins
else:
go = False
return None
except BaseException as e:
raise e
if go:
print(f"[DEBUG] 当前加载的插件:{plugins}")
def check_command(self, message, role, platform, message_obj):
# 插件
try:
plugins = self.get_plugin_modules()
if plugins != None:
# print(f"[DEBUG] 当前加载的插件:{plugins}")
for p in plugins:
module = __import__("addons.plugins." + p + "." + p, fromlist=[p])
cls = putil.get_classes(module)
if p in self.cached_plugins:
module = self.cached_plugins[p]
else:
module = __import__("addons.plugins." + p + "." + p, fromlist=[p])
self.cached_plugins[p] = module
cls = putil.get_classes(p, module)
obj = getattr(module, cls[0])()
hit, res = obj.run(message, role, platform)
hit, res = obj.run(message, role, platform, message_obj)
if hit:
return True, res
except BaseException as e:
print(f"[Debug] 插件加载出现问题,原因: {str(e)}\n已安装插件: {plugins}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。")
if self.command_start_with(message, "nick"):
return True, self.set_nick(message, platform)
@@ -56,8 +66,8 @@ class Command:
if role != "admin":
return False, f"你的身份组{role}没有权限操作插件", "plugin"
l = message.split(" ")
if len(l) < 3:
return True, "【安装插件】示例:\n安装插件: \nplugin i 插件Github地址\n卸载插件: \nplugin i 插件名", "plugin"
if len(l) < 2:
return True, "【安装插件】示例:\n安装插件: \nplugin i 插件Github地址\n卸载插件: \nplugin i 插件名 \n重载插件: \nplugin reload", "plugin"
else:
ppath = ""
if os.path.exists("addons/plugins"):
@@ -82,8 +92,14 @@ class Command:
return True, "插件卸载成功~", "plugin"
except BaseException as e:
return False, f"卸载插件失败,原因: {str(e)}", "plugin"
elif l[1] == "reload":
try:
for pm in self.cached_plugins:
importlib.reload(self.cached_plugins[pm])
return True, "插件重载成功!", "plugin"
except BaseException as e:
return False, f"插件重载失败,原因: {str(e)}", "plugin"
'''
nick: 存储机器人的昵称
@@ -121,7 +137,7 @@ class Command:
"keyword": "设置关键词/关键指令回复",
"update": "更新面板",
"update latest": "更新到最新版本",
"update r": "重启程序",
"update r": "重启机器人",
"reset": "重置会话",
"nick": "设置机器人昵称",
"/bing": "切换到bing模型",
@@ -130,6 +146,7 @@ class Command:
"/bing 问题": "临时使用一次bing模型进行会话",
"/gpt 问题": "临时使用一次OpenAI ChatGPT API进行会话",
"/revgpt 问题": "临时使用一次网页版ChatGPT进行会话",
"plugin": "插件安装、卸载和重载"
}
def help_messager(self, commands: dict):
@@ -229,18 +246,6 @@ class Command:
pash_tag = "QQChannelChatGPT"+os.sep
repo.remotes.origin.pull()
# 检查是否是windows环境
# if platform.system().lower() == "windows":
# if os.path.exists("launcher.exe"):
# os.system("start launcher.exe")
# elif os.path.exists("QQChannelChatGPT\\main.py"):
# os.system("start python QQChannelChatGPT\\main.py")
# else:
# return True, "更新成功,未发现启动项,因此需要手动重启程序。"
# exit()
# else:
# py = sys.executable
# os.execl(py, py, *sys.argv)
return True, "更新成功~是否重启?输入update r重启(重启指令不返回任何确认信息)。", "update"
except BaseException as e:
+9 -2
View File
@@ -5,9 +5,16 @@ from cores.qqbot.personality import personalities
class CommandOpenAIOfficial(Command):
def __init__(self, provider: ProviderOpenAIOfficial):
self.provider = provider
self.cached_plugins = {}
def check_command(self, message: str, session_id: str, user_name: str, role, platform: str):
hit, res = super().check_command(message, role, platform)
def check_command(self,
message: str,
session_id: str,
user_name: str,
role: str,
platform: str,
message_obj):
hit, res = super().check_command(message, role, platform, message_obj=message_obj)
if hit:
return True, res
if self.command_start_with(message, "reset", "重置"):
+8 -3
View File
@@ -4,9 +4,14 @@ from model.provider.provider_rev_chatgpt import ProviderRevChatGPT
class CommandRevChatGPT(Command):
def __init__(self, provider: ProviderRevChatGPT):
self.provider = provider
def check_command(self, message: str, role, platform: str):
hit, res = super().check_command(message, role, platform)
self.cached_plugins = {}
def check_command(self,
message: str,
role: str,
platform: str,
message_obj):
hit, res = super().check_command(message, role, platform, message_obj=message_obj)
if hit:
return True, res
if self.command_start_with(message, "help", "帮助"):
+9 -2
View File
@@ -4,9 +4,16 @@ import asyncio
class CommandRevEdgeGPT(Command):
def __init__(self, provider: ProviderRevEdgeGPT):
self.provider = provider
self.cached_plugins = {}
def check_command(self, message: str, loop, role, platform: str):
hit, res = super().check_command(message, role, platform)
def check_command(self,
message: str,
loop,
role: str,
platform: str,
message_obj):
hit, res = super().check_command(message, role, platform, message_obj=message_obj)
if hit:
return True, res
if self.command_start_with(message, "reset"):
+14 -2
View File
@@ -5,8 +5,20 @@ class QQ:
self.client = gocq
self.client.run()
async def send_qq_msg(self, source, res, image_mode = False):
print("[System-Info] 回复QQ消息中..."+res)
async def send_qq_msg(self,
source,
res,
image_mode: bool = False):
"""
res可以是一个数组,也就是gocq的消息链.
"""
# print(res)
print("[System-Info] 回复QQ消息中..."+str(res))
if isinstance(res, list) and len(res) > 0:
await self.client.sendGroupMessage(source.group_id, res)
return
# 通过消息链处理
if not image_mode:
if source.type == "GroupMessage":
+4 -3
View File
@@ -1,6 +1,7 @@
import io
import botpy
from PIL import Image
from botpy.message import Message, DirectMessage
import re
import asyncio
import requests
@@ -15,13 +16,13 @@ class QQChan():
self.client.run(appid=appid, token=token)
def send_qq_msg(self, message, res, image_mode=False, msg_ref = None):
print("[System-Info] 回复QQ频道消息中..."+res)
print("[System-Info] 回复QQ频道消息中..."+str(res))
if not image_mode:
try:
if msg_ref is not None:
reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=res, message_reference = msg_ref), self.client.loop)
reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=str(res), message_reference = msg_ref), self.client.loop)
else:
reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=res), self.client.loop)
reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=str(res)), self.client.loop)
reply_res.result()
except BaseException as e:
# 分割过长的消息
+3 -2
View File
@@ -2,11 +2,12 @@ import os
import inspect
# 找出模块里所有的类名
def get_classes(arg):
def get_classes(p_name, arg):
classes = []
clsmembers = inspect.getmembers(arg, inspect.isclass)
for (name, _) in clsmembers:
classes.append(name)
if p_name == name.lower().replace("plugin", ""):
classes.append(name)
return classes
# 获取一个文件夹下所有的模块