perf: 升级插件协议簇

This commit is contained in:
Soulter
2023-05-15 20:03:17 +08:00
parent 9f36e5ae05
commit 9c284b84b1
6 changed files with 60 additions and 17 deletions
+10 -6
View File
@@ -309,6 +309,7 @@ def initBot(cfg, prov):
gu.log("--------加载平台--------", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
# GOCQ
global gocq_bot
if 'gocqbot' in cfg and cfg['gocqbot']['enable']:
gu.log("- 启用QQ机器人 -", gu.LEVEL_INFO)
@@ -326,11 +327,14 @@ def initBot(cfg, prov):
with open("cmd_config.json", 'w', encoding='utf-8') as f:
json.dump(cmd_config, f, indent=4)
f.flush()
global gocq_app, gocq_bot, gocq_loop
gocq_bot = QQ()
global gocq_app, gocq_loop
gocq_loop = asyncio.new_event_loop()
gocq_bot = QQ(True, gocq_loop)
thread_inst = threading.Thread(target=run_gocq_bot, args=(gocq_loop, gocq_bot, gocq_app), daemon=False)
thread_inst.start()
else:
gocq_bot = QQ(False)
# QQ频道
if 'qqbot' in cfg and cfg['qqbot']['enable']:
@@ -437,7 +441,7 @@ def oper_msg(message,
role = "member" # 角色
hit = False # 是否命中指令
command_result = () # 调用指令返回的结果
global admin_qq, cached_plugins
global admin_qq, cached_plugins, gocq_bot
if platform == PLATFORM_QQCHAN:
gu.log(f"接收到消息:{message.content}", gu.LEVEL_INFO, tag="QQ频道")
@@ -556,7 +560,7 @@ def oper_msg(message,
chatgpt_res = ""
if chosen_provider == OPENAI_OFFICIAL:
hit, command_result = command_openai_official.check_command(qq_msg, session_id, user_name, role, platform=platform, message_obj=message, cached_plugins=cached_plugins)
hit, command_result = command_openai_official.check_command(qq_msg, session_id, user_name, role, platform=platform, message_obj=message, cached_plugins=cached_plugins, qq_platform=gocq_bot)
# hit: 是否触发了指令
if not hit:
# 请求ChatGPT获得结果
@@ -569,7 +573,7 @@ def oper_msg(message,
send_message(platform, message, f"OpenAI API错误, 原因: {str(e)}", msg_ref=msg_ref, gocq_loop=gocq_loop, qqchannel_bot=qqchannel_bot, gocq_bot=gocq_bot)
elif chosen_provider == REV_CHATGPT:
hit, command_result = command_rev_chatgpt.check_command(qq_msg, role, platform=platform, message_obj=message, cached_plugins=cached_plugins)
hit, command_result = command_rev_chatgpt.check_command(qq_msg, role, platform=platform, message_obj=message, cached_plugins=cached_plugins, qq_platform=gocq_bot)
if not hit:
try:
chatgpt_res = str(rev_chatgpt.text_chat(qq_msg))
@@ -585,7 +589,7 @@ def oper_msg(message,
bing_cache_loop = gocq_loop
elif platform == PLATFORM_QQCHAN:
bing_cache_loop = qqchan_loop
hit, command_result = command_rev_edgegpt.check_command(qq_msg, bing_cache_loop, role, platform=platform, message_obj=message, cached_plugins=cached_plugins)
hit, command_result = command_rev_edgegpt.check_command(qq_msg, bing_cache_loop, role, platform=platform, message_obj=message, cached_plugins=cached_plugins, qq_platform=gocq_bot)
if not hit:
try:
while rev_edgegpt.is_busy():
+3 -3
View File
@@ -10,7 +10,7 @@ import util.plugin_util as putil
import shutil
import importlib
from util import general_utils as gu
from model.platform.qq import QQ
PLATFORM_QQCHAN = 'qqchan'
PLATFORM_GOCQ = 'gocq'
@@ -34,12 +34,12 @@ class Command:
except BaseException as e:
raise e
def check_command(self, message, role, platform, message_obj, cached_plugins: dict):
def check_command(self, message, role, platform, message_obj, cached_plugins: dict, qq_platform: QQ):
# 插件
for k, v in cached_plugins.items():
try:
hit, res = v["clsobj"].run(message, role, platform, message_obj)
hit, res = v["clsobj"].run(message, role, platform, message_obj, qq_platform)
if hit:
return True, res
except BaseException as e:
+4 -2
View File
@@ -1,6 +1,7 @@
from model.command.command import Command
from model.provider.provider_openai_official import ProviderOpenAIOfficial
from cores.qqbot.personality import personalities
from model.platform.qq import QQ
class CommandOpenAIOfficial(Command):
def __init__(self, provider: ProviderOpenAIOfficial):
@@ -14,8 +15,9 @@ class CommandOpenAIOfficial(Command):
role: str,
platform: str,
message_obj,
cached_plugins: dict):
hit, res = super().check_command(message, role, platform, message_obj=message_obj, cached_plugins=cached_plugins)
cached_plugins: dict,
qq_platform: QQ):
hit, res = super().check_command(message, role, platform, message_obj=message_obj, cached_plugins=cached_plugins, qq_platform=qq_platform)
if hit:
return True, res
if self.command_start_with(message, "reset", "重置"):
+4 -2
View File
@@ -1,5 +1,6 @@
from model.command.command import Command
from model.provider.provider_rev_chatgpt import ProviderRevChatGPT
from model.platform.qq import QQ
class CommandRevChatGPT(Command):
def __init__(self, provider: ProviderRevChatGPT):
@@ -11,8 +12,9 @@ class CommandRevChatGPT(Command):
role: str,
platform: str,
message_obj,
cached_plugins: dict):
hit, res = super().check_command(message, role, platform, message_obj=message_obj, cached_plugins=cached_plugins)
cached_plugins: dict,
qq_platform: QQ):
hit, res = super().check_command(message, role, platform, message_obj=message_obj, cached_plugins=cached_plugins, qq_platform=qq_platform)
if hit:
return True, res
if self.command_start_with(message, "help", "帮助"):
+5 -2
View File
@@ -1,6 +1,8 @@
from model.command.command import Command
from model.provider.provider_rev_edgegpt import ProviderRevEdgeGPT
import asyncio
from model.platform.qq import QQ
class CommandRevEdgeGPT(Command):
def __init__(self, provider: ProviderRevEdgeGPT):
self.provider = provider
@@ -13,8 +15,9 @@ class CommandRevEdgeGPT(Command):
role: str,
platform: str,
message_obj,
cached_plugins: dict):
hit, res = super().check_command(message, role, platform, message_obj=message_obj, cached_plugins=cached_plugins)
cached_plugins: dict,
qq_platform: QQ):
hit, res = super().check_command(message, role, platform, message_obj=message_obj, cached_plugins=cached_plugins, qq_platform=qq_platform)
if hit:
return True, res
if self.command_start_with(message, "reset"):
+34 -2
View File
@@ -1,22 +1,40 @@
from nakuru.entities.components import Plain, At, Image
from util import general_utils as gu
import asyncio
class QQ:
def __init__(self, is_start: bool, gocq_loop = None) -> None:
self.is_start = is_start
self.gocq_loop = gocq_loop
def run_bot(self, gocq):
self.client = gocq
self.client.run()
def get_msg_loop(self):
return self.gocq_loop
async def send_qq_msg(self,
source,
res,
image_mode: bool = False):
if not self.is_start:
raise Exception("管理员未启动QQ平台")
"""
res可以是一个数组也就是gocq的消息链.
res可以是一个数组, 也就是gocq的消息链
插件开发者请使用send方法, 可以不用直接调用这个方法。
"""
gu.log("回复QQ消息: "+str(res), level=gu.LEVEL_INFO, tag="QQ", max_len=30)
if isinstance(source, int):
source = {
"type": "GroupMessage",
"group_id": source
}
if isinstance(res, list) and len(res) > 0:
await self.client.sendGroupMessage(source.group_id, res)
return
# 通过消息链处理
if not image_mode:
if source.type == "GroupMessage":
@@ -39,4 +57,18 @@ class QQ:
await self.client.sendFriendMessage(source.user_id, [
Plain(text="好的,我根据你的需要为你生成了一张图片😊"),
Image.fromURL(url=res)
])
])
def send(self,
to,
res):
'''
提供给插件的发送QQ消息接口, 不用在外部await。
参数说明:第一个参数可以是消息对象,也可以是QQ群号。第二个参数是消息内容(消息内容可以是消息链列表,也可以是纯文字信息)。
'''
if isinstance(to, int):
try:
asyncio.run_coroutine_threadsafe(self.send_qq_msg(message_obj, res), self.gocq_loop).result()
except BaseException as e:
raise e