perf: clean codes; 将 keyword 功能转移至 helloworld 插件下

This commit is contained in:
Soulter
2024-02-12 23:17:55 +08:00
parent acb68a4a1e
commit 3c87fc5b31
10 changed files with 170 additions and 269 deletions
+106 -11
View File
@@ -8,6 +8,8 @@ from cores.qqbot.global_object import (
AstrMessageEvent,
CommandResult
)
import os
import shutil
'''
注意改插件名噢!格式:XXXPlugin 或 Main
@@ -18,7 +20,14 @@ class HelloWorldPlugin:
初始化函数, 可以选择直接pass
"""
def __init__(self) -> None:
print("hello, world!")
# 复制旧配置文件到 data 目录下。
if os.path.exists("keyword.json"):
shutil.move("keyword.json", "data/keyword.json")
self.keywords = {}
if os.path.exists("data/keyword.json"):
self.keywords = json.load(open("data/keyword.json", "r"))
else:
self.save_keyword()
"""
机器人程序会调用此函数。
@@ -28,20 +37,106 @@ class HelloWorldPlugin:
"""
def run(self, ame: AstrMessageEvent):
if ame.message_str == "helloworld":
# return True, tuple([True, "Hello World!!", "helloworld"])
return CommandResult(
hit=True,
success=True,
message_chain=[Plain("Hello World!!")],
command_name="helloworld"
)
if ame.message_str.startswith("/keyword") or ame.message_str.startswith("keyword"):
return self.handle_keyword_command(ame)
ret = self.check_keyword(ame.message_str)
if ret: return ret
return CommandResult(
hit=False,
success=False,
message_chain=None,
command_name=None
)
def handle_keyword_command(self, ame: AstrMessageEvent):
l = ame.message_str.split(" ")
# 获取图片
image_url = ""
for comp in ame.message_obj.message:
if isinstance(comp, Image) and image_url == "":
if comp.url is None:
image_url = comp.file
else:
image_url = comp.url
command_result = CommandResult(
hit=True,
success=False,
message_chain=None,
command_name="keyword"
)
if len(l) == 1 or (len(l) == 2 and image_url == ""):
ret = """【设置关键词回复】
示例:
1. keyword <触发词> <回复词>
keyword hi 你好
发送 hi 回复你好
* 回复词支持图片
2. keyword d <触发词>
keyword d hi
删除 hi 触发词产生的回复"""
command_result.success = True
command_result.message_chain = [Plain(ret)]
return command_result
elif len(l) == 3 and l[1] == "d":
if l[2] not in self.keywords:
command_result.message_chain = [Plain(f"关键词 {l[2]} 不存在")]
return command_result
self.keywords.pop(l[2])
self.save_keyword()
command_result.success = True
command_result.message_chain = [Plain("删除成功")]
return command_result
else:
return CommandResult(
hit=False,
success=False,
message_chain=None,
command_name=None
)
self.keywords[l[1]] = {
"plain_text": " ".join(l[2:]),
"image_url": image_url
}
self.save_keyword()
command_result.success = True
command_result.message_chain = [Plain("设置成功")]
return command_result
def save_keyword(self):
json.dump(self.keywords, open("data/keyword.json", "w"), ensure_ascii=False)
def check_keyword(self, message_str: str):
for k in self.keywords:
if message_str == k:
plain_text = ""
if 'plain_text' in self.keywords[k]:
plain_text = self.keywords[k]['plain_text']
else:
plain_text = self.keywords[k]
image_url = ""
if 'image_url' in self.keywords[k]:
image_url = self.keywords[k]['image_url']
if image_url != "":
res = [Plain(plain_text), Image.fromURL(image_url)]
return CommandResult(
hit=True,
success=True,
message_chain=res,
command_name="keyword"
)
return CommandResult(
hit=True,
success=True,
message_chain=[Plain(plain_text)],
command_name="keyword"
)
"""
插件元信息。
当用户输入 plugin v 插件名称 时,会调用此函数,返回帮助信息。
@@ -58,8 +153,8 @@ class HelloWorldPlugin:
def info(self):
return {
"name": "helloworld",
"desc": "测试插件",
"help": "测试插件, 回复 helloworld 即可触发",
"version": "v1.2",
"desc": "这是 AstrBot 的默认插件,支持关键词回复。",
"help": "输入 /keyword 查看关键词回复帮助。",
"version": "v1.3",
"author": "Soulter"
}
-23
View File
@@ -1,23 +0,0 @@
'''
监测机器性能
- Bot 内存使用量
- CPU 占用率
'''
import psutil
from cores.qqbot.global_object import GlobalObject
import time
def run_monitor(global_object: GlobalObject):
'''运行监测'''
start_time = time.time()
while True:
stat = global_object.dashboard_data.stats
# 程序占用的内存大小
mem = psutil.Process().memory_info().rss / 1024 / 1024 # MB
stat['sys_perf'] = {
'memory': mem,
'cpu': psutil.cpu_percent()
}
stat['sys_start_time'] = start_time
time.sleep(30)
+8 -101
View File
@@ -7,19 +7,18 @@ import requests
import util.unfit_words as uw
import os
import sys
from cores.qqbot.personality import personalities
from addons.baidu_aip_judge import BaiduJudge
from nakuru import (
GroupMessage,
FriendMessage,
GuildMessage,
)
from model.platform._nakuru_translation_layer import NakuruGuildMember, NakuruGuildMessage
from model.platform._nakuru_translation_layer import NakuruGuildMessage
from nakuru.entities.components import Plain,At,Image
from model.provider.provider import Provider
from model.command.command import Command
from util import general_utils as gu
from util.general_utils import Logger
from util.general_utils import Logger, upload, run_monitor
from util.cmd_config import CmdConfig as cc
from util.cmd_config import init_astrbot_config_items
import util.function_calling.gplugin as gplugin
@@ -31,7 +30,6 @@ from . global_object import GlobalObject
from typing import Union
from addons.dashboard.helper import DashBoardHelper
from addons.dashboard.server import DashBoardData
from cores.monitor.perf import run_monitor
from cores.database.conn import dbConn
from model.platform._message_result import MessageResult
@@ -40,7 +38,7 @@ user_frequency = {}
# 时间默认值
frequency_time = 60
# 计数默认值
frequency_count = 2
frequency_count = 10
# 版本
version = '3.1.5'
@@ -57,8 +55,6 @@ llm_wake_prefix = ""
# 百度内容审核实例
baidu_judge = None
# 关键词回复
keywords = {}
# CLI
PLATFORM_CLI = 'cli'
@@ -69,36 +65,6 @@ init_astrbot_config_items()
_global_object: GlobalObject = None
logger: Logger = Logger()
# 统计消息数据
def upload():
global version
while True:
addr_ip = ''
try:
o = {
"cnt_total": _global_object.cnt_total,
"admin": _global_object.admin_qq,
}
o_j = json.dumps(o)
res = {
"version": version,
"count": _global_object.cnt_total,
"cntqc": -1,
"cntgc": -1,
"ip": addr_ip,
"others": o_j,
"sys": sys.platform,
}
logger.log(res, gu.LEVEL_DEBUG, tag="Uploader")
resp = requests.post('https://api.soulter.top/upload', data=json.dumps(res), timeout=5)
if resp.status_code == 200:
ok = resp.json()
if ok['status'] == 'ok':
_global_object.cnt_total = 0
except BaseException as e:
pass
time.sleep(10*60)
# 语言模型选择
def privider_chooser(cfg):
l = []
@@ -111,11 +77,11 @@ def privider_chooser(cfg):
'''
初始化机器人
'''
def initBot(cfg):
def init(cfg):
global llm_instance, llm_command_instance
global baidu_judge, chosen_provider
global frequency_count, frequency_time
global keywords, _global_object
global _global_object
global logger
# 迁移旧配置
@@ -128,6 +94,7 @@ def initBot(cfg):
# 初始化 global_object
_global_object = GlobalObject()
_global_object.version = version
_global_object.base_config = cfg
_global_object.stat['session'] = {}
_global_object.stat['message'] = {}
@@ -167,11 +134,6 @@ def initBot(cfg):
llm_command_instance[OPENAI_OFFICIAL] = CommandOpenAIOfficial(llm_instance[OPENAI_OFFICIAL], _global_object)
chosen_provider = OPENAI_OFFICIAL
# 得到关键词
if os.path.exists("keyword.json"):
with open("keyword.json", 'r', encoding='utf-8') as f:
keywords = json.load(f)
# 检查provider设置偏好
p = cc.get("chosen_provider", None)
if p is not None and p in llm_instance:
@@ -185,7 +147,7 @@ def initBot(cfg):
except BaseException as e:
logger.log("百度内容审核初始化失败", gu.LEVEL_ERROR)
threading.Thread(target=upload, daemon=True).start()
threading.Thread(target=upload, args=(_global_object, ), daemon=True).start()
# 得到发言频率配置
if 'limit' in cfg:
@@ -273,34 +235,6 @@ def initBot(cfg):
dashboard_thread.join()
async def cli():
time.sleep(1)
while True:
try:
prompt = input(">>> ")
if prompt == "":
continue
ngm = await cli_pack_message(prompt)
await oper_msg(ngm, True, PLATFORM_CLI)
except EOFError:
return
async def cli_pack_message(prompt: str) -> NakuruGuildMessage:
ngm = NakuruGuildMessage()
ngm.channel_id = 6180
ngm.user_id = 6180
ngm.message = [Plain(prompt)]
ngm.type = "GuildMessage"
ngm.self_id = 6180
ngm.self_tiny_id = 6180
ngm.guild_id = 6180
ngm.sender = NakuruGuildMember()
ngm.sender.tiny_id = 6180
ngm.sender.user_id = 6180
ngm.sender.nickname = "CLI"
ngm.sender.role = 0
return ngm
'''
运行 QQ_OFFICIAL 机器人
'''
@@ -338,7 +272,6 @@ def run_gocq_bot(cfg: dict, _global_object: GlobalObject):
except BaseException as e:
input("启动QQ机器人出现错误"+str(e))
'''
检查发言频率
'''
@@ -381,7 +314,7 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
role: member | admin
platform: 平台(gocq, qqchan)
"""
global chosen_provider, keywords, _global_object
global chosen_provider, _global_object
message_str = ''
session_id = session_id
role = role
@@ -401,22 +334,6 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
user_id = message.user_id
if not check_frequency(user_id):
return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。')
# 关键词回复
for k in keywords:
if message_str == k:
plain_text = ""
if 'plain_text' in keywords[k]:
plain_text = keywords[k]['plain_text']
else:
plain_text = keywords[k]
image_url = ""
if 'image_url' in keywords[k]:
image_url = keywords[k]['image_url']
if image_url != "":
res = [Plain(plain_text), Image.fromURL(image_url)]
return MessageResult(res)
return MessageResult(plain_text)
# 检查是否是更换语言模型的请求
temp_switch = ""
@@ -503,16 +420,6 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
return
command = command_result[2]
if command == "keyword":
if os.path.exists("keyword.json"):
with open("keyword.json", "r", encoding="utf-8") as f:
keywords = json.load(f)
else:
try:
return MessageResult(command_result[1])
except BaseException as e:
return MessageResult(f"回复消息出错: {str(e)}")
if command == "update latest r":
def update_restart():
py = sys.executable
+2 -4
View File
@@ -1,14 +1,11 @@
from model.platform.qq_official import QQOfficial, NakuruGuildMember, NakuruGuildMessage
from model.platform.qq_official import QQOfficial, NakuruGuildMessage
from model.platform.qq_gocq import QQGOCQ
from model.provider.provider import Provider
from addons.dashboard.server import DashBoardData
from nakuru import (
CQHTTP,
GroupMessage,
GroupMemberIncrease,
FriendMessage,
GuildMessage,
Notify
)
from typing import Union
@@ -16,6 +13,7 @@ class GlobalObject:
'''
存放一些公用的数据,用于在不同模块(如core与command)之间传递
'''
version: str
nick: str # gocq 的昵称
base_config: dict # config.json
cached_plugins: dict # 缓存的插件
+1 -1
View File
@@ -42,7 +42,7 @@ def main():
os.mkdir(abs_path+"data/config")
# 启动主程序(cores/qqbot/core.py
qqBot.initBot(cfg)
qqBot.init(cfg)
def check_env(ch_mirror=False):
if not (sys.version_info.major == 3 and sys.version_info.minor >= 9):
+1 -72
View File
@@ -83,8 +83,6 @@ class Command:
return True, self.get_new_conf(message, role)
if self.command_start_with(message, "web"): # 网页搜索
return True, self.web_search(message)
if self.command_start_with(message, "keyword"):
return True, self.keyword(message_obj, role)
if self.command_start_with(message, "ip"):
ip = requests.get("https://myip.ipip.net", timeout=5).text
return True, f"机器人 IP 信息:{ip}", "ip"
@@ -237,78 +235,9 @@ class Command:
return True
return False
# keyword: 关键字
def keyword(self, message_obj, role: str):
if role != "admin":
return True, "你没有权限使用该指令", "keyword"
plain_text = ""
image_url = ""
for comp in message_obj.message:
if isinstance(comp, Plain):
plain_text += comp.text
elif isinstance(comp, Image) and image_url == "":
if comp.url is None:
image_url = comp.file
else:
image_url = comp.url
l = plain_text.split(" ")
if len(l) < 3 and image_url == "":
return True, """【设置关键词回复】示例:
1. keyword hi 你好
当发送hi的时候会回复你好
2. keyword /hi 你好
当发送/hi时会回复你好
3. keyword d hi
删除hi关键词的回复
4. keyword hi <图片>
当发送hi时会回复图片""", "keyword"
del_mode = False
if l[1] == "d":
del_mode = True
try:
if os.path.exists("keyword.json"):
with open("keyword.json", "r", encoding="utf-8") as f:
keyword = json.load(f)
if del_mode:
# 删除关键词
if l[2] not in keyword:
return False, "该关键词不存在", "keyword"
else: del keyword[l[2]]
else:
keyword[l[1]] = {
"plain_text": " ".join(l[2:]),
"image_url": image_url
}
else:
if del_mode:
return False, "该关键词不存在", "keyword"
keyword = {
l[1]: {
"plain_text": " ".join(l[2:]),
"image_url": image_url
}
}
with open("keyword.json", "w", encoding="utf-8") as f:
json.dump(keyword, f, ensure_ascii=False, indent=4)
f.flush()
if del_mode:
return True, "删除成功: "+l[2], "keyword"
if image_url == "":
return True, "设置成功: "+l[1]+" "+" ".join(l[2:]), "keyword"
else:
return True, [Plain("设置成功: "+l[1]+" "+" ".join(l[2:])), Image.fromURL(image_url)], "keyword"
except BaseException as e:
return False, "设置失败: "+str(e), "keyword"
def update(self, message: str, role: str):
if role != "admin":
return True, "你没有权限使用该指令", "keyword"
return True, "你没有权限使用该指令", "update"
l = message.split(" ")
if len(l) == 1:
try:
+5 -4
View File
@@ -10,7 +10,7 @@ from nakuru import (
from ._nakuru_translation_layer import (
NakuruGuildMessage,
)
from nakuru.entities.components import Plain, At, Image, Node
from nakuru.entities.components import Plain, At, Image
class Platform():
@@ -49,7 +49,7 @@ class Platform():
'''
pass
def parse_message_outline(self, message: Union[GuildMessage, GroupMessage, FriendMessage, str]) -> NakuruGuildMessage:
def parse_message_outline(self, message: Union[GuildMessage, GroupMessage, FriendMessage, str, list]) -> str:
'''
将消息解析成大纲消息形式。
如: xxxxx[图片]xxxxx
@@ -57,14 +57,15 @@ class Platform():
if isinstance(message, str):
return message
ret = ''
ls_to_parse = message if isinstance(message, list) else message.message
try:
for node in message.message:
for node in ls_to_parse:
if isinstance(node, Plain):
ret += node.text
elif isinstance(node, At):
ret += f'[At: {node.name}/{node.qq}]'
elif isinstance(node, Image):
ret += f'[图片]'
ret += '[图片]'
except Exception as e:
pass
ret.replace('\n', '')
+2 -11
View File
@@ -8,8 +8,7 @@ from nakuru import (
GroupMessage,
FriendMessage,
GroupMemberIncrease,
Notify,
Member
Notify
)
from typing import Union
import time
@@ -31,7 +30,6 @@ class QQGOCQ(Platform):
asyncio.set_event_loop(self.loop)
self.waiting = {}
self.gocq_cnt = 0
self.cc = CmdConfig()
self.cfg = cfg
self.logger: gu.Logger = global_object.logger
@@ -175,8 +173,6 @@ class QQGOCQ(Platform):
"""
source = message
res = result_message
self.gocq_cnt += 1
self.logger.log(f"{source.user_id} <- {self.parse_message_outline(res)}", tag="QQ_GOCQ")
@@ -315,9 +311,4 @@ class QQGOCQ(Platform):
return ret
except BaseException as e:
raise e
def get_cnt(self):
return self.gocq_cnt
def set_cnt(self, cnt):
self.gocq_cnt = cnt
-9
View File
@@ -13,7 +13,6 @@ import time
from ._platfrom import Platform
from ._nakuru_translation_layer import(
NakuruGuildMessage,
NakuruGuildMember,
gocq_compatible_receive,
gocq_compatible_send
)
@@ -44,7 +43,6 @@ class QQOfficial(Platform):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.qqchan_cnt = 0
self.waiting: dict = {}
self.cfg = cfg
@@ -125,7 +123,6 @@ class QQOfficial(Platform):
回复频道消息
'''
self.logger.log(f"{message.sender.nickname}({message.sender.tiny_id}) <- {self.parse_message_outline(res)}", tag="QQ_OFFICIAL")
self.qqchan_cnt += 1
plain_text = ''
image_path = ''
@@ -246,9 +243,3 @@ class QQOfficial(Platform):
if cnt > 300:
raise Exception("等待消息超时。")
time.sleep(1)
def get_cnt(self):
return self.qqchan_cnt
def set_cnt(self, cnt):
self.qqchan_cnt = cnt
+45 -33
View File
@@ -11,6 +11,9 @@ from cores.qqbot.global_object import GlobalObject
import platform
import requests
import logging
import json
import sys
import psutil
PLATFORM_GOCQ = 'gocq'
PLATFORM_QQCHAN = 'qqchan'
@@ -193,7 +196,6 @@ def word2img(title: str, text: str, max_width=30, font_size=20):
return image
def render_markdown(markdown_text, image_width=800, image_height=600, font_size=26, font_color=(0, 0, 0), bg_color=(255, 255, 255)):
HEADER_MARGIN = 20
@@ -322,7 +324,6 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
height += font_size + TEXT_LINE_MARGIN*2
markdown_text = '\n'.join(pre_lines)
print("Pre process done, height: ", height)
image_height = height
if image_height < 100:
image_height = 100
@@ -332,13 +333,6 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
image = Image.new('RGB', (image_width, image_height), bg_color)
draw = ImageDraw.Draw(image)
# # get all the emojis unicode in the markdown text
# unicode_text = markdown_text.encode('unicode_escape').decode()
# # print(unicode_text)
# unicode_emojis = re.findall(r'\\U\w{8}', unicode_text)
# emoji_base_url = "https://abs.twimg.com/emoji/v1/72x72/{unicode_emoji}.png"
# 设置初始位置
x, y = 10, 10
@@ -360,25 +354,10 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
line = line.strip()
if line.startswith("#"):
# unicode_emojis = re.findall(r'\\U0001\w{4}', line)
# for unicode_emoji in unicode_emojis:
# line = line.replace(unicode_emoji, "")
# unicode_emoji = ""
# if len(unicode_emojis) > 0:
# unicode_emoji = unicode_emojis[0]
# 处理标题
header_level = line.count("#")
line = line.strip("#").strip()
font_size_header = HEADER_FONT_STANDARD_SIZE - header_level * 4
# if unicode_emoji != "":
# emoji_url = emoji_base_url.format(unicode_emoji=unicode_emoji[-5:])
# emoji = Image.open(requests.get(emoji_url, stream=True).raw)
# emoji = emoji.resize((font_size, font_size))
# image.paste(emoji, (x, y))
# x += font_size
font = ImageFont.truetype(font_path, font_size_header)
y += HEADER_MARGIN # 上边距
# 字间距
@@ -389,10 +368,6 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
elif line.startswith(">"):
# 处理引用
quote_text = line.strip(">")
# quote_width = image_width - 20 # 引用框的宽度为图像宽度减去左右边距
# quote_height = font_size + 10 # 引用框的高度为字体大小加上上下边距
# quote_box = (x, y, x + quote_width, y + quote_height)
# draw.rounded_rectangle(quote_box, radius=5, fill=(230, 230, 230), width=2) # 使用灰色填充矩形框作为引用背景
y+=QUOTE_LEFT_LINE_MARGIN
draw.line((x, y, x, y + QUOTE_LEFT_LINE_HEIGHT), fill=QUOTE_LEFT_LINE_COLOR, width=QUOTE_LEFT_LINE_WIDTH)
font = ImageFont.truetype(font_path, QUOTE_FONT_SIZE)
@@ -468,7 +443,6 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
y += image_res.size[1] + IMAGE_MARGIN*2
return image
def save_temp_img(img: Image) -> str:
if not os.path.exists("temp"):
os.makedirs("temp")
@@ -490,7 +464,6 @@ def save_temp_img(img: Image) -> str:
img.save(p)
return p
def create_text_image(title: str, text: str, max_width=30, font_size=20):
'''
文本转图片。
@@ -520,9 +493,10 @@ def create_markdown_image(text: str):
except Exception as e:
raise e
# 迁移配置文件到 cmd_config.json
def try_migrate_config(old_config: dict):
'''
迁移配置文件到 cmd_config.json
'''
cc = CmdConfig()
if cc.get("qqbot", None) is None:
# 未迁移过
@@ -553,4 +527,42 @@ def get_sys_info(global_object: GlobalObject):
'mem': mem,
'os': os_name + '_' + os_version,
'py': platform.python_version(),
}
}
def upload(_global_object: GlobalObject):
while True:
addr_ip = ''
try:
res = {
"version": _global_object.version,
"count": _global_object.cnt_total,
"ip": addr_ip,
"sys": sys.platform,
"admin": _global_object.admin_qq,
}
resp = requests.post('https://api.soulter.top/upload', data=json.dumps(res), timeout=5)
if resp.status_code == 200:
ok = resp.json()
if ok['status'] == 'ok':
_global_object.cnt_total = 0
except BaseException as e:
pass
time.sleep(10*60)
def run_monitor(global_object: GlobalObject):
'''
监测机器性能
- Bot 内存使用量
- CPU 占用率
'''
start_time = time.time()
while True:
stat = global_object.dashboard_data.stats
# 程序占用的内存大小
mem = psutil.Process().memory_info().rss / 1024 / 1024 # MB
stat['sys_perf'] = {
'memory': mem,
'cpu': psutil.cpu_percent()
}
stat['sys_start_time'] = start_time
time.sleep(30)