Compare commits

...

13 Commits

Author SHA1 Message Date
Soulter edfde51434 fix: 修复频道平台下未找到平台 qqchan 的实例的错误 2024-03-13 19:53:36 +08:00
Soulter 3fc1347bba fix: plugin register management 2024-03-12 20:00:02 +08:00
Soulter e643eea365 perf: 结构化插件的表示格式; 优化插件开发接口 2024-03-12 18:50:50 +08:00
Soulter 1af481f5f9 fix: function call with newer version 2024-03-07 17:35:21 +08:00
Soulter 317d1c4c41 fix: onebot protocol connection error 2024-03-05 14:03:46 +08:00
Soulter a703860512 fix: plugin call 2024-03-05 13:52:44 +08:00
Soulter 1cd1c8ea0d feat: 异步重写
perf: 优化网页搜索回答规范
2024-03-03 18:54:50 +08:00
Soulter 53ef3bbf4f fix: 修复修改cqhttp端口后仍检测失败的问题 2024-02-19 19:04:40 +08:00
Soulter ab7b8aad7c chore: delete llms 2024-02-12 23:28:12 +08:00
Soulter c49213282b Merge remote-tracking branch 'refs/remotes/origin/master' 2024-02-12 23:18:11 +08:00
Soulter 3c87fc5b31 perf: clean codes; 将 keyword 功能转移至 helloworld 插件下 2024-02-12 23:17:55 +08:00
Soulter 9684508e1d Update README.md 2024-02-11 13:47:09 +08:00
Soulter bb0edae200 Update README.md 2024-02-08 00:40:48 +08:00
31 changed files with 821 additions and 700 deletions
+1
View File
@@ -9,3 +9,4 @@ temp
cmd_config.json
addons/plugins/
data/
cookies.json
+6
View File
@@ -32,6 +32,11 @@
1. 可视化面板
2. Docker 一键部署项目:[链接](https://astrbot.soulter.top/center/docs/%E9%83%A8%E7%BD%B2/%E9%80%9A%E8%BF%87Docker%E9%83%A8%E7%BD%B2)
🌍支持的消息平台/接口
- go-cqhttpQQ、QQ频道)
- QQ 官方机器人接口
- Telegram(由 [astrbot_plugin_telegram](https://github.com/Soulter/astrbot_plugin_telegram) 插件支持)
🌍支持的AI语言模型一览:
**文字模型/图片理解**
@@ -41,6 +46,7 @@
- OpenAI GPT-4(原生支持)
- Claude(免费,由[LLMs插件](https://github.com/Soulter/llms)支持)
- HuggingChat(免费,由[LLMs插件](https://github.com/Soulter/llms)支持)
- Gemini(免费,由[LLMs插件](https://github.com/Soulter/llms)支持)
**图片生成**
- OpenAI Dalle 接口
+4 -4
View File
@@ -104,14 +104,14 @@ class DashBoardHelper():
)
qq_gocq_platform_group = DashBoardConfig(
config_type="group",
name="GO-CQHTTP 平台配置",
name="OneBot协议平台配置",
description="",
body=[
DashBoardConfig(
config_type="item",
val_type="bool",
name="启用 GO-CQHTTP 平台",
description="gocq 是一个基于 HTTP 协议的 CQHTTP 协议的实现。详见 github.com/Mrs4s/go-cqhttp",
name="启用",
description="支持cq-http、shamrock等(目前仅支持QQ平台)",
value=config['gocqbot']['enable'],
path="gocqbot.enable",
),
@@ -135,7 +135,7 @@ class DashBoardHelper():
config_type="item",
val_type="int",
name="WebSocket 服务器端口",
description="",
description="目前仅支持正向 WebSocket",
value=config['gocq_websocket_port'],
path="gocq_websocket_port",
),
+19 -21
View File
@@ -7,6 +7,7 @@ import logging
from cores.database.conn import dbConn
from util.cmd_config import CmdConfig
from util.updator import check_update, update_project, request_release_info
from cores.qqbot.types import *
import util.plugin_util as putil
import websockets
import json
@@ -20,7 +21,7 @@ class DashBoardData():
stats: dict
configs: dict
logs: dict
plugins: list[dict]
plugins: List[RegisteredPlugin]
@dataclass
class Response():
@@ -33,7 +34,7 @@ class AstrBotDashBoard():
self.global_object = global_object
self.loop = asyncio.get_event_loop()
asyncio.set_event_loop(self.loop)
self.dashboard_data = global_object.dashboard_data
self.dashboard_data: DashBoardData = global_object.dashboard_data
self.dashboard_be = Flask(__name__, static_folder="dist", static_url_path="/")
log = logging.getLogger('werkzeug')
log.setLevel(logging.ERROR)
@@ -151,13 +152,13 @@ class AstrBotDashBoard():
def get_plugins():
_plugin_resp = []
for plugin in self.dashboard_data.plugins:
_p = self.dashboard_data.plugins[plugin]
_p = plugin.metadata
_t = {
"name": _p["info"]["name"],
"repo": '' if "repo" not in _p["info"] else _p["info"]["repo"],
"author": _p["info"]["author"],
"desc": _p["info"]["desc"],
"version": _p["info"]["version"]
"name": _p.plugin_name,
"repo": '' if _p.repo is None else _p.repo,
"author": _p.author,
"desc": _p.desc,
"version": _p.version
}
_plugin_resp.append(_t)
return Response(
@@ -332,8 +333,8 @@ class AstrBotDashBoard():
"tag": ""
},
{
"title": "QQ_GOCQ",
"desc": "go-cqhttp",
"title": "OneBot协议",
"desc": "支持cq-http、shamrock等(目前仅支持QQ平台)",
"namespace": "internal_platform_qq_gocq",
"tag": ""
}
@@ -359,17 +360,14 @@ class AstrBotDashBoard():
}
]
for plugin in self.global_object.cached_plugins:
# 从插件信息中获取 plugin_type 字段,如果有则归类到对应的大纲中
if "plugin_type" in self.global_object.cached_plugins[plugin]["info"]:
_t = self.global_object.cached_plugins[plugin]["info"]["plugin_type"]
for item in outline:
if item["type"] == _t:
item["body"].append({
"title": self.global_object.cached_plugins[plugin]["info"]["name"],
"desc": self.global_object.cached_plugins[plugin]["info"]["desc"],
"namespace": plugin,
"tag": plugin,
})
for item in outline:
if item['type'] == plugin.metadata.plugin_type:
item['body'].append({
"title": plugin.metadata.plugin_name,
"desc": plugin.metadata.desc,
"namespace": plugin.metadata.plugin_name,
"tag": plugin.metadata.plugin_name
})
return outline
def register(self, name: str):
-5
View File
@@ -1,5 +0,0 @@
# helloworld
QQChannelChatGPT项目的测试插件
A test plugin for QQChannelChatGPT plugin feature
+120 -15
View File
@@ -1,13 +1,25 @@
import os
import shutil
from nakuru.entities.components import *
from nakuru import (
GroupMessage,
FriendMessage
)
from botpy.message import Message, DirectMessage
from cores.qqbot.global_object import (
AstrMessageEvent,
CommandResult
)
flag_not_support = False
try:
from util.plugin_dev.api.v1.config import *
from util.plugin_dev.api.v1.bot import (
PluginMetadata,
PluginType,
AstrMessageEvent,
CommandResult,
)
from util.plugin_dev.api.v1.register import register_llm, unregister_llm
except ImportError:
flag_not_support = True
print("llms: 导入接口失败。请升级到 AstrBot 最新版本。")
'''
注意改插件名噢!格式:XXXPlugin 或 Main
@@ -18,7 +30,14 @@ class HelloWorldPlugin:
初始化函数, 可以选择直接pass
"""
def __init__(self) -> None:
print("hello, world!")
# 复制旧配置文件到 data 目录下。
if os.path.exists("keyword.json"):
shutil.move("keyword.json", "data/keyword.json")
self.keywords = {}
if os.path.exists("data/keyword.json"):
self.keywords = json.load(open("data/keyword.json", "r"))
else:
self.save_keyword()
"""
机器人程序会调用此函数。
@@ -28,20 +47,106 @@ class HelloWorldPlugin:
"""
def run(self, ame: AstrMessageEvent):
if ame.message_str == "helloworld":
# return True, tuple([True, "Hello World!!", "helloworld"])
return CommandResult(
hit=True,
success=True,
message_chain=[Plain("Hello World!!")],
command_name="helloworld"
)
if ame.message_str.startswith("/keyword") or ame.message_str.startswith("keyword"):
return self.handle_keyword_command(ame)
ret = self.check_keyword(ame.message_str)
if ret: return ret
return CommandResult(
hit=False,
success=False,
message_chain=None,
command_name=None
)
def handle_keyword_command(self, ame: AstrMessageEvent):
l = ame.message_str.split(" ")
# 获取图片
image_url = ""
for comp in ame.message_obj.message:
if isinstance(comp, Image) and image_url == "":
if comp.url is None:
image_url = comp.file
else:
image_url = comp.url
command_result = CommandResult(
hit=True,
success=False,
message_chain=None,
command_name="keyword"
)
if len(l) == 1 or (len(l) == 2 and image_url == ""):
ret = """【设置关键词回复】
示例:
1. keyword <触发词> <回复词>
keyword hi 你好
发送 hi 回复你好
* 回复词支持图片
2. keyword d <触发词>
keyword d hi
删除 hi 触发词产生的回复"""
command_result.success = True
command_result.message_chain = [Plain(ret)]
return command_result
elif len(l) == 3 and l[1] == "d":
if l[2] not in self.keywords:
command_result.message_chain = [Plain(f"关键词 {l[2]} 不存在")]
return command_result
self.keywords.pop(l[2])
self.save_keyword()
command_result.success = True
command_result.message_chain = [Plain("删除成功")]
return command_result
else:
return CommandResult(
hit=False,
success=False,
message_chain=None,
command_name=None
)
self.keywords[l[1]] = {
"plain_text": " ".join(l[2:]),
"image_url": image_url
}
self.save_keyword()
command_result.success = True
command_result.message_chain = [Plain("设置成功")]
return command_result
def save_keyword(self):
json.dump(self.keywords, open("data/keyword.json", "w"), ensure_ascii=False)
def check_keyword(self, message_str: str):
for k in self.keywords:
if message_str == k:
plain_text = ""
if 'plain_text' in self.keywords[k]:
plain_text = self.keywords[k]['plain_text']
else:
plain_text = self.keywords[k]
image_url = ""
if 'image_url' in self.keywords[k]:
image_url = self.keywords[k]['image_url']
if image_url != "":
res = [Plain(plain_text), Image.fromURL(image_url)]
return CommandResult(
hit=True,
success=True,
message_chain=res,
command_name="keyword"
)
return CommandResult(
hit=True,
success=True,
message_chain=[Plain(plain_text)],
command_name="keyword"
)
"""
插件元信息。
当用户输入 plugin v 插件名称 时,会调用此函数,返回帮助信息。
@@ -58,8 +163,8 @@ class HelloWorldPlugin:
def info(self):
return {
"name": "helloworld",
"desc": "测试插件",
"help": "测试插件, 回复 helloworld 即可触发",
"version": "v1.2",
"desc": "这是 AstrBot 的默认插件,支持关键词回复。",
"help": "输入 /keyword 查看关键词回复帮助。",
"version": "v1.3",
"author": "Soulter"
}
Submodule addons/plugins/llms deleted from ec088771a3
-23
View File
@@ -1,23 +0,0 @@
'''
监测机器性能
- Bot 内存使用量
- CPU 占用率
'''
import psutil
from cores.qqbot.global_object import GlobalObject
import time
def run_monitor(global_object: GlobalObject):
'''运行监测'''
start_time = time.time()
while True:
stat = global_object.dashboard_data.stats
# 程序占用的内存大小
mem = psutil.Process().memory_info().rss / 1024 / 1024 # MB
stat['sys_perf'] = {
'memory': mem,
'cpu': psutil.cpu_percent()
}
stat['sys_start_time'] = start_time
time.sleep(30)
+63 -140
View File
@@ -1,49 +1,50 @@
import re
import json
import threading
import asyncio
import time
import requests
import aiohttp
import util.unfit_words as uw
import os
import sys
from cores.qqbot.personality import personalities
from addons.baidu_aip_judge import BaiduJudge
import io
import traceback
import util.function_calling.gplugin as gplugin
import util.plugin_util as putil
from PIL import Image as PILImage
from typing import Union
from nakuru import (
GroupMessage,
FriendMessage,
GuildMessage,
)
from model.platform._nakuru_translation_layer import NakuruGuildMember, NakuruGuildMessage
from nakuru.entities.components import Plain,At,Image
from nakuru.entities.components import Plain, At, Image
from addons.baidu_aip_judge import BaiduJudge
from model.platform._nakuru_translation_layer import NakuruGuildMessage
from model.provider.provider import Provider
from model.command.command import Command
from util import general_utils as gu
from util.general_utils import Logger
from util.general_utils import Logger, upload, run_monitor
from util.cmd_config import CmdConfig as cc
from util.cmd_config import init_astrbot_config_items
import util.function_calling.gplugin as gplugin
import util.plugin_util as putil
from PIL import Image as PILImage
import io
import traceback
from . global_object import GlobalObject
from typing import Union
from .types import *
from addons.dashboard.helper import DashBoardHelper
from addons.dashboard.server import DashBoardData
from cores.monitor.perf import run_monitor
from cores.database.conn import dbConn
from model.platform._message_result import MessageResult
# 用户发言频率
user_frequency = {}
# 时间默认值
frequency_time = 60
# 计数默认值
frequency_count = 2
frequency_count = 10
# 版本
version = '3.1.5'
version = '3.1.8'
# 语言模型
REV_CHATGPT = 'rev_chatgpt'
@@ -57,8 +58,6 @@ llm_wake_prefix = ""
# 百度内容审核实例
baidu_judge = None
# 关键词回复
keywords = {}
# CLI
PLATFORM_CLI = 'cli'
@@ -69,36 +68,6 @@ init_astrbot_config_items()
_global_object: GlobalObject = None
logger: Logger = Logger()
# 统计消息数据
def upload():
global version
while True:
addr_ip = ''
try:
o = {
"cnt_total": _global_object.cnt_total,
"admin": _global_object.admin_qq,
}
o_j = json.dumps(o)
res = {
"version": version,
"count": _global_object.cnt_total,
"cntqc": -1,
"cntgc": -1,
"ip": addr_ip,
"others": o_j,
"sys": sys.platform,
}
logger.log(res, gu.LEVEL_DEBUG, tag="Uploader")
resp = requests.post('https://api.soulter.top/upload', data=json.dumps(res), timeout=5)
if resp.status_code == 200:
ok = resp.json()
if ok['status'] == 'ok':
_global_object.cnt_total = 0
except BaseException as e:
pass
time.sleep(10*60)
# 语言模型选择
def privider_chooser(cfg):
l = []
@@ -111,11 +80,11 @@ def privider_chooser(cfg):
'''
初始化机器人
'''
def initBot(cfg):
def init(cfg):
global llm_instance, llm_command_instance
global baidu_judge, chosen_provider
global frequency_count, frequency_time
global keywords, _global_object
global _global_object
global logger
# 迁移旧配置
@@ -128,10 +97,8 @@ def initBot(cfg):
# 初始化 global_object
_global_object = GlobalObject()
_global_object.version = version
_global_object.base_config = cfg
_global_object.stat['session'] = {}
_global_object.stat['message'] = {}
_global_object.stat['platform'] = {}
_global_object.logger = logger
logger.log("AstrBot v"+version, gu.LEVEL_INFO)
@@ -156,6 +123,7 @@ def initBot(cfg):
llm_instance[REV_CHATGPT] = ProviderRevChatGPT(cfg['rev_ChatGPT'], base_url=cc.get("CHATGPT_BASE_URL", None))
llm_command_instance[REV_CHATGPT] = CommandRevChatGPT(llm_instance[REV_CHATGPT], _global_object)
chosen_provider = REV_CHATGPT
_global_object.llms.append(RegisteredLLM(llm_name=REV_CHATGPT, llm_instance=llm_instance[REV_CHATGPT], origin="internal"))
else:
input("请退出本程序, 然后在配置文件中填写rev_ChatGPT相关配置")
if OPENAI_OFFICIAL in prov:
@@ -165,13 +133,9 @@ def initBot(cfg):
from model.command.openai_official import CommandOpenAIOfficial
llm_instance[OPENAI_OFFICIAL] = ProviderOpenAIOfficial(cfg['openai'])
llm_command_instance[OPENAI_OFFICIAL] = CommandOpenAIOfficial(llm_instance[OPENAI_OFFICIAL], _global_object)
_global_object.llms.append(RegisteredLLM(llm_name=OPENAI_OFFICIAL, llm_instance=llm_instance[OPENAI_OFFICIAL], origin="internal"))
chosen_provider = OPENAI_OFFICIAL
# 得到关键词
if os.path.exists("keyword.json"):
with open("keyword.json", 'r', encoding='utf-8') as f:
keywords = json.load(f)
# 检查provider设置偏好
p = cc.get("chosen_provider", None)
if p is not None and p in llm_instance:
@@ -185,7 +149,7 @@ def initBot(cfg):
except BaseException as e:
logger.log("百度内容审核初始化失败", gu.LEVEL_ERROR)
threading.Thread(target=upload, daemon=True).start()
threading.Thread(target=upload, args=(_global_object, ), daemon=True).start()
# 得到发言频率配置
if 'limit' in cfg:
@@ -220,7 +184,7 @@ def initBot(cfg):
_command = Command(None, _global_object)
ok, err = putil.plugin_reload(_global_object.cached_plugins)
if ok:
logger.log(f"成功载入{len(_global_object.cached_plugins)}个插件", gu.LEVEL_INFO)
logger.log(f"成功载入 {len(_global_object.cached_plugins)} 个插件", gu.LEVEL_INFO)
else:
logger.log(err, gu.LEVEL_ERROR)
@@ -273,34 +237,6 @@ def initBot(cfg):
dashboard_thread.join()
async def cli():
time.sleep(1)
while True:
try:
prompt = input(">>> ")
if prompt == "":
continue
ngm = await cli_pack_message(prompt)
await oper_msg(ngm, True, PLATFORM_CLI)
except EOFError:
return
async def cli_pack_message(prompt: str) -> NakuruGuildMessage:
ngm = NakuruGuildMessage()
ngm.channel_id = 6180
ngm.user_id = 6180
ngm.message = [Plain(prompt)]
ngm.type = "GuildMessage"
ngm.self_id = 6180
ngm.self_tiny_id = 6180
ngm.guild_id = 6180
ngm.sender = NakuruGuildMember()
ngm.sender.tiny_id = 6180
ngm.sender.user_id = 6180
ngm.sender.nickname = "CLI"
ngm.sender.role = 0
return ngm
'''
运行 QQ_OFFICIAL 机器人
'''
@@ -308,7 +244,7 @@ def run_qqchan_bot(cfg: dict, global_object: GlobalObject):
try:
from model.platform.qq_official import QQOfficial
qqchannel_bot = QQOfficial(cfg=cfg, message_handler=oper_msg, global_object=global_object)
global_object.platform_qqchan = qqchannel_bot
global_object.platforms.append(RegisteredPlatform(platform_name="qqchan", platform_instance=qqchannel_bot, origin="internal"))
qqchannel_bot.run()
except BaseException as e:
logger.log("启动QQ频道机器人时出现错误, 原因如下: " + str(e), gu.LEVEL_CRITICAL, tag="QQ频道")
@@ -320,25 +256,27 @@ def run_qqchan_bot(cfg: dict, global_object: GlobalObject):
def run_gocq_bot(cfg: dict, _global_object: GlobalObject):
from model.platform.qq_gocq import QQGOCQ
logger.log("正在检查本地GO-CQHTTP连接...端口5700, 6700", tag="QQ")
noticed = False
host = cc.get("gocq_host", "127.0.0.1")
port = cc.get("gocq_websocket_port", 6700)
http_port = cc.get("gocq_http_port", 5700)
logger.log(f"正在检查连接...host: {host}, ws port: {port}, http port: {http_port}", tag="QQ")
while True:
if not gu.port_checker(5700, cc.get("gocq_host", "127.0.0.1")) or not gu.port_checker(6700, cc.get("gocq_host", "127.0.0.1")):
if not gu.port_checker(port=port, host=host) or not gu.port_checker(port=http_port, host=host):
if not noticed:
noticed = True
logger.log("与GO-CQHTTP通信失败, 请检查GO-CQHTTP是否启动并正确配置。程序会每隔 5s 自动重试。", gu.LEVEL_CRITICAL, tag="QQ")
logger.log(f"连接到{host}:{port}(或{http_port})失败。程序会每隔 5s 自动重试。", gu.LEVEL_CRITICAL, tag="QQ")
time.sleep(5)
else:
logger.log("检查完毕,未发现问题。", tag="QQ")
break
try:
qq_gocq = QQGOCQ(cfg=cfg, message_handler=oper_msg, global_object=_global_object)
_global_object.platform_qq = qq_gocq
_global_object.platforms.append(RegisteredPlatform(platform_name="gocq", platform_instance=qq_gocq, origin="internal"))
qq_gocq.run()
except BaseException as e:
input("启动QQ机器人出现错误"+str(e))
'''
检查发言频率
'''
@@ -379,17 +317,27 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
message: 消息对象
session_id: 该消息源的唯一识别号
role: member | admin
platform: 平台(gocq, qqchan)
platform: str 所注册的平台的名称。如果没有注册,将抛出一个异常。
"""
global chosen_provider, keywords, _global_object
global chosen_provider, _global_object
message_str = ''
session_id = session_id
role = role
hit = False # 是否命中指令
command_result = () # 调用指令返回的结果
# 获取平台实例
reg_platform: RegisteredPlatform = None
for p in _global_object.platforms:
if p.platform_name == platform:
reg_platform = p
break
if not reg_platform:
_global_object.logger.log(f"未找到平台 {platform} 的实例。", gu.LEVEL_ERROR)
raise Exception(f"未找到平台 {platform} 的实例。")
# 统计数据,如频道消息量
record_message(platform, session_id)
await record_message(platform, session_id)
for i in message.message:
if isinstance(i, Plain):
@@ -398,25 +346,8 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
return MessageResult("Hi~")
# 检查发言频率
user_id = message.user_id
if not check_frequency(user_id):
if not check_frequency(message.user_id):
return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。')
# 关键词回复
for k in keywords:
if message_str == k:
plain_text = ""
if 'plain_text' in keywords[k]:
plain_text = keywords[k]['plain_text']
else:
plain_text = keywords[k]
image_url = ""
if 'image_url' in keywords[k]:
image_url = keywords[k]['image_url']
if image_url != "":
res = [Plain(plain_text), Image.fromURL(image_url)]
return MessageResult(res)
return MessageResult(plain_text)
# 检查是否是更换语言模型的请求
temp_switch = ""
@@ -439,11 +370,12 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
llm_result_str = ""
hit, command_result = llm_command_instance[chosen_provider].check_command(
# check commands and plugins
hit, command_result = await llm_command_instance[chosen_provider].check_command(
message_str,
session_id,
role,
platform,
reg_platform,
message,
)
@@ -455,11 +387,12 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
if matches:
return MessageResult(f"你的提问得到的回复未通过【默认关键词拦截】服务, 不予回复。")
if baidu_judge != None:
check, msg = baidu_judge.judge(message_str)
check, msg = await asyncio.to_thread(baidu_judge.judge, message_str)
if not check:
return MessageResult(f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}")
if chosen_provider == NONE_LLM:
return MessageResult("没有启动任何 LLM 并且未触发任何指令。")
logger.log("一条消息由于 Bot 未启动任何语言模型并且未触发指令而将被忽略。", gu.LEVEL_WARNING)
return
try:
if llm_wake_prefix != "" and not message_str.startswith(llm_wake_prefix):
return
@@ -483,9 +416,9 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
if chosen_provider == REV_CHATGPT or chosen_provider == OPENAI_OFFICIAL:
if _global_object.web_search or web_sch_flag:
official_fc = chosen_provider == OPENAI_OFFICIAL
llm_result_str = gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc)
llm_result_str = await gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc)
else:
llm_result_str = str(llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality = _global_object.default_personality))
llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality = _global_object.default_personality)
llm_result_str = _global_object.reply_prefix + llm_result_str
except BaseException as e:
@@ -496,23 +429,13 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
if temp_switch != "":
chosen_provider = temp_switch
# 指令回复
if hit:
# 检查指令。command_result 是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型)
# 有指令或者插件触发
# command_result 是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型)
if command_result == None:
return
command = command_result[2]
if command == "keyword":
if os.path.exists("keyword.json"):
with open("keyword.json", "r", encoding="utf-8") as f:
keywords = json.load(f)
else:
try:
return MessageResult(command_result[1])
except BaseException as e:
return MessageResult(f"回复消息出错: {str(e)}")
if command == "update latest r":
def update_restart():
py = sys.executable
@@ -526,11 +449,11 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
if isinstance(command_result[1], list) and len(command_result) == 3 and command == 'draw':
for i in command_result[1]:
# 保存到本地
pic_res = requests.get(i, stream = True)
if pic_res.status_code == 200:
image = PILImage.open(io.BytesIO(pic_res.content))
return MessageResult([Image.fromFileSystem(gu.save_temp_img(image))])
async with aiohttp.ClientSession() as session:
async with session.get(i) as resp:
if resp.status == 200:
image = PILImage.open(io.BytesIO(await resp.read()))
return MessageResult([Image.fromFileSystem(gu.save_temp_img(image))])
# 其他指令
else:
try:
@@ -545,7 +468,7 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
llm_result_str = re.sub(i, "***", llm_result_str)
# 百度内容审核服务二次审核
if baidu_judge != None:
check, msg = baidu_judge.judge(llm_result_str)
check, msg = await asyncio.to_thread(baidu_judge.judge, llm_result_str)
if not check:
return MessageResult(f"你的提问得到的回复【百度内容审核】未通过,不予回复。\n\n{msg}")
# 发送信息
-92
View File
@@ -1,92 +0,0 @@
from model.platform.qq_official import QQOfficial, NakuruGuildMember, NakuruGuildMessage
from model.platform.qq_gocq import QQGOCQ
from model.provider.provider import Provider
from addons.dashboard.server import DashBoardData
from nakuru import (
CQHTTP,
GroupMessage,
GroupMemberIncrease,
FriendMessage,
GuildMessage,
Notify
)
from typing import Union
class GlobalObject:
'''
存放一些公用的数据,用于在不同模块(如core与command)之间传递
'''
nick: str # gocq 的昵称
base_config: dict # config.json
cached_plugins: dict # 缓存的插件
web_search: bool # 是否开启了网页搜索
reply_prefix: str
admin_qq: str
admin_qqchan: str
unique_session: bool
cnt_total: int
platform_qq: QQGOCQ
platform_qqchan: QQOfficial
default_personality: dict
dashboard_data: DashBoardData
stat: dict
logger: None
def __init__(self):
self.nick = None # gocq 的昵称
self.base_config = None # config.yaml
self.cached_plugins = {} # 缓存的插件
self.web_search = False # 是否开启了网页搜索
self.reply_prefix = None
self.admin_qq = "123456"
self.admin_qqchan = "123456"
self.unique_session = False
self.cnt_total = 0
self.platform_qq = None
self.platform_qqchan = None
self.default_personality = None
self.dashboard_data = None
self.stat = {}
class AstrMessageEvent():
message_str: str # 纯消息字符串
message_obj: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage] # 消息对象
gocq_platform: QQGOCQ
qq_sdk_platform: QQOfficial
platform: str # `gocq` 或 `qqchan`
role: str # `admin` 或 `member`
global_object: GlobalObject # 一些公用数据
session_id: int # 会话id (可能是群id,也可能是某个user的id。取决于是否开启了 unique_session)
def __init__(self, message_str: str,
message_obj: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage],
gocq_platform: QQGOCQ,
qq_sdk_platform: QQOfficial,
platform: str,
role: str,
global_object: GlobalObject,
llm_provider: Provider = None,
session_id: int = None):
self.message_str = message_str
self.message_obj = message_obj
self.gocq_platform = gocq_platform
self.qq_sdk_platform = qq_sdk_platform
self.platform = platform
self.role = role
self.global_object = global_object
self.llm_provider = llm_provider
self.session_id = session_id
class CommandResult():
'''
用于在Command中返回多个值
'''
def __init__(self, hit: bool, success: bool, message_chain: list, command_name: str = "unknown_command") -> None:
self.hit = hit
self.success = success
self.message_chain = message_chain
self.command_name = command_name
def _result_tuple(self):
return (self.success, self.message_chain, self.command_name)
+140
View File
@@ -0,0 +1,140 @@
from model.platform.qq_official import NakuruGuildMessage
from model.provider.provider import Provider as LLMProvider
from model.platform._platfrom import Platform
from nakuru import (
GroupMessage,
FriendMessage,
GuildMessage,
)
from typing import Union, List, ClassVar
from types import ModuleType
from enum import Enum
from dataclasses import dataclass
class PluginType(Enum):
PLATFORM = 'platfrom' # 平台类插件。
LLM = 'llm' # 大语言模型类插件
COMMON = 'common' # 其他插件
@dataclass
class PluginMetadata:
'''
插件的元数据。
'''
# required
plugin_name: str
plugin_type: PluginType
author: str # 插件作者
desc: str # 插件简介
version: str # 插件版本
# optional
repo: str = None # 插件仓库地址
def __str__(self) -> str:
return f"PluginMetadata({self.plugin_name}, {self.plugin_type}, {self.desc}, {self.version}, {self.repo})"
@dataclass
class RegisteredPlugin:
'''
注册在 AstrBot 中的插件。
'''
metadata: PluginMetadata
plugin_instance: object
module_path: str
module: ModuleType
root_dir_name: str
def __str__(self) -> str:
return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})"
RegisteredPlugins = List[RegisteredPlugin]
@dataclass
class RegisteredPlatform:
'''
注册在 AstrBot 中的平台。平台应当实现 Platform 接口。
'''
platform_name: str
platform_instance: Platform
origin: str = None # 注册来源
@dataclass
class RegisteredLLM:
'''
注册在 AstrBot 中的大语言模型调用。大语言模型应当实现 LLMProvider 接口。
'''
llm_name: str
llm_instance: LLMProvider
origin: str = None # 注册来源
class GlobalObject:
'''
存放一些公用的数据,用于在不同模块(如core与command)之间传递
'''
version: str # 机器人版本
nick: str # 用户定义的机器人的别名
base_config: dict # config.json 中导出的配置
cached_plugins: List[RegisteredPlugin] # 加载的插件
platforms: List[RegisteredPlatform]
llms: List[RegisteredLLM]
web_search: bool # 是否开启了网页搜索
reply_prefix: str # 回复前缀
unique_session: bool # 是否开启了独立会话
cnt_total: int # 总消息数
default_personality: dict
dashboard_data = None
logger: None
def __init__(self):
self.nick = None # gocq 的昵称
self.base_config = None # config.yaml
self.cached_plugins = [] # 缓存的插件
self.web_search = False # 是否开启了网页搜索
self.reply_prefix = None
self.unique_session = False
self.cnt_total = 0
self.platforms = []
self.llms = []
self.default_personality = None
self.dashboard_data = None
self.stat = {}
class AstrMessageEvent():
'''
消息事件。
'''
context: GlobalObject # 一些公用数据
message_str: str # 纯消息字符串
message_obj: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage] # 消息对象
platform: RegisteredPlatform # 来源平台
role: str # 基本身份。`admin` 或 `member`
session_id: int # 会话 id
def __init__(self,
message_str: str,
message_obj: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage],
platform: RegisteredPlatform,
role: str,
context: GlobalObject,
session_id: str = None):
self.context = context
self.message_str = message_str
self.message_obj = message_obj
self.platform = platform
self.role = role
self.session_id = session_id
class CommandResult():
'''
用于在Command中返回多个值
'''
def __init__(self, hit: bool, success: bool, message_chain: list, command_name: str = "unknown_command") -> None:
self.hit = hit
self.success = success
self.message_chain = message_chain
self.command_name = command_name
def _result_tuple(self):
return (self.success, self.message_chain, self.command_name)
+2 -2
View File
@@ -23,7 +23,7 @@ def main():
print(file_not_found)
input("配置文件不存在,请检查是否已经下载配置文件。")
except BaseException as e:
print(e)
raise e
# 设置代理
if 'http_proxy' in cfg and cfg['http_proxy'] != '':
@@ -42,7 +42,7 @@ def main():
os.mkdir(abs_path+"data/config")
# 启动主程序(cores/qqbot/core.py
qqBot.initBot(cfg)
qqBot.init(cfg)
def check_env(ch_mirror=False):
if not (sys.version_info.major == 3 and sys.version_info.minor >= 9):
+58 -115
View File
@@ -1,19 +1,29 @@
import json
from util import general_utils as gu
import os
import requests
from model.provider.provider import Provider
import inspect
import aiohttp
import asyncio
import json
import util.plugin_util as putil
from util.cmd_config import CmdConfig as cc
from util.general_utils import Logger
import util.updator
from nakuru.entities.components import (
Plain,
Image
)
from cores.qqbot.global_object import GlobalObject, AstrMessageEvent
from cores.qqbot.global_object import CommandResult
from util import general_utils as gu
from model.provider.provider import Provider
from util.cmd_config import CmdConfig as cc
from util.general_utils import Logger
from cores.qqbot.types import (
GlobalObject,
AstrMessageEvent,
PluginType,
CommandResult,
RegisteredPlugin,
RegisteredPlatform
)
from typing import List, Tuple
PLATFORM_QQCHAN = 'qqchan'
PLATFORM_GOCQ = 'gocq'
@@ -25,11 +35,11 @@ class Command:
self.global_object = global_object
self.logger: Logger = global_object.logger
def check_command(self,
async def check_command(self,
message,
session_id: str,
role,
platform,
role: str,
platform: RegisteredPlatform,
message_obj):
self.platform = platform
# 插件
@@ -38,20 +48,21 @@ class Command:
ame = AstrMessageEvent(
message_str=message,
message_obj=message_obj,
gocq_platform=self.global_object.platform_qq,
qq_sdk_platform=self.global_object.platform_qqchan,
platform=platform,
role=role,
global_object=self.global_object,
context=self.global_object,
session_id = session_id
)
# 从已启动的插件中查找是否有匹配的指令
for k, v in cached_plugins.items():
for plugin in cached_plugins:
# 过滤掉平台类插件
if "type" in v["info"] and v["info"]["plugin_type"] == "platform":
if plugin.metadata.plugin_type == PluginType.PLATFORM:
continue
try:
result = v["clsobj"].run(ame)
if inspect.iscoroutinefunction(plugin.plugin_instance.run):
result = await plugin.plugin_instance.run(ame)
else:
result = await asyncio.to_thread(plugin.plugin_instance.run, ame)
if isinstance(result, CommandResult):
hit = result.hit
res = result._result_tuple()
@@ -65,13 +76,16 @@ class Command:
except TypeError as e:
# 参数不匹配,尝试使用旧的参数方案
try:
hit, res = v["clsobj"].run(message, role, platform, message_obj, self.global_object.platform_qq)
if inspect.iscoroutinefunction(plugin.plugin_instance.run):
hit, res = await plugin.plugin_instance.run(message, role, platform, message_obj, self.global_object.platform_qq)
else:
hit, res = await asyncio.to_thread(plugin.plugin_instance.run, message, role, platform, message_obj, self.global_object.platform_qq)
if hit:
return True, res
except BaseException as e:
self.logger.log(f"{k}插件异常,原因: {str(e)}\n已安装插件: {cached_plugins.keys}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
self.logger.log(f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
except BaseException as e:
self.logger.log(f"{k} 插件异常,原因: {str(e)}\n已安装插件: {cached_plugins.keys}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
self.logger.log(f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
if self.command_start_with(message, "nick"):
return True, self.set_nick(message, platform, role)
@@ -79,20 +93,13 @@ class Command:
return True, self.plugin_oper(message, role, cached_plugins, platform)
if self.command_start_with(message, "myid") or self.command_start_with(message, "!myid"):
return True, self.get_my_id(message_obj, platform)
if self.command_start_with(message, "nconf") or self.command_start_with(message, "newconf"):
return True, self.get_new_conf(message, role)
if self.command_start_with(message, "web"): # 网页搜索
return True, self.web_search(message)
if self.command_start_with(message, "keyword"):
return True, self.keyword(message_obj, role)
if self.command_start_with(message, "ip"):
ip = requests.get("https://myip.ipip.net", timeout=5).text
return True, f"机器人 IP 信息:{ip}", "ip"
if not self.provider and self.command_start_with(message, "help"):
return True, self.help()
return True, await self.help()
return False, None
def web_search(self, message):
l = message.split(' ')
if len(l) == 1:
@@ -123,7 +130,7 @@ class Command:
'''
插件指令
'''
def plugin_oper(self, message: str, role: str, cached_plugins: dict, platform: str):
def plugin_oper(self, message: str, role: str, cached_plugins: List[RegisteredPlugin], platform: str):
l = message.split(" ")
if len(l) < 2:
p = gu.create_text_image("【插件指令面板】", "安装插件: \nplugin i 插件Github地址\n卸载插件: \nplugin d 插件名 \n重载插件: \nplugin reload\n查看插件列表:\nplugin l\n更新插件: plugin u 插件名\n")
@@ -153,25 +160,27 @@ class Command:
return False, f"更新插件失败,原因: {str(e)}\n建议: 使用 plugin i 指令进行覆盖安装(插件数据可能会丢失)", "plugin"
elif l[1] == "l":
try:
plugin_list_info = "\n".join([f"{k}: \n名称: {v['info']['name']}\n简介: {v['info']['desc']}\n版本: {v['info']['version']}\n作者: {v['info']['author']}\n" for k, v in cached_plugins.items()])
plugin_list_info = ""
for plugin in cached_plugins:
plugin_list_info += f"{plugin.metadata.plugin_name}: \n名称: {plugin.metadata.plugin_name}\n简介: {plugin.metadata.plugin_desc}\n版本: {plugin.metadata.version}\n作者: {plugin.metadata.author}\n"
p = gu.create_text_image("【已激活插件列表】", plugin_list_info + "\n使用plugin v 插件名 查看插件帮助\n")
return True, [Image.fromFileSystem(p)], "plugin"
except BaseException as e:
return False, f"获取插件列表失败,原因: {str(e)}", "plugin"
elif l[1] == "v":
try:
if l[2] in cached_plugins:
info = cached_plugins[l[2]]["info"]
info = None
for i in cached_plugins:
if i.metadata.plugin_name == l[2]:
info = i.metadata
break
if info:
p = gu.create_text_image(f"【插件信息】", f"名称: {info['name']}\n{info['desc']}\n版本: {info['version']}\n作者: {info['author']}\n\n帮助:\n{info['help']}")
return True, [Image.fromFileSystem(p)], "plugin"
else:
return False, "未找到该插件", "plugin"
except BaseException as e:
return False, f"获取插件信息失败,原因: {str(e)}", "plugin"
elif l[1] == "dev":
if role != "admin":
return False, f"你的身份组{role}没有权限开发者模式", "plugin"
return True, "cached_plugins: \n" + str(cached_plugins), "plugin"
'''
nick: 存储机器人的昵称
@@ -204,10 +213,11 @@ class Command:
"/revgpt": "切换到网页版ChatGPT",
}
def help_messager(self, commands: dict, platform: str, cached_plugins: dict = None):
async def help_messager(self, commands: dict, platform: str, cached_plugins: List[RegisteredPlugin] = None):
try:
resp = requests.get("https://soulter.top/channelbot/notice.json").text
notice = json.loads(resp)["notice"]
async with aiohttp.ClientSession() as session:
async with session.get("https://soulter.top/channelbot/notice.json") as resp:
notice = (await resp.json())["notice"]
except BaseException as e:
notice = ""
msg = "# Help Center\n## 指令列表\n"
@@ -215,7 +225,9 @@ class Command:
msg += f"`{key}` - {value}\n"
# plugins
if cached_plugins != None:
plugin_list_info = "\n".join([f"`{k}` {v['info']['name']}\n{v['info']['desc']}\n" for k, v in cached_plugins.items()])
plugin_list_info = ""
for plugin in cached_plugins:
plugin_list_info += f"`{plugin.metadata.plugin_name}` {plugin.metadata.desc}\n"
if plugin_list_info.strip() != "":
msg += "\n## 插件列表\n> 使用plugin v 插件名 查看插件帮助\n"
msg += plugin_list_info
@@ -237,78 +249,9 @@ class Command:
return True
return False
# keyword: 关键字
def keyword(self, message_obj, role: str):
if role != "admin":
return True, "你没有权限使用该指令", "keyword"
plain_text = ""
image_url = ""
for comp in message_obj.message:
if isinstance(comp, Plain):
plain_text += comp.text
elif isinstance(comp, Image) and image_url == "":
if comp.url is None:
image_url = comp.file
else:
image_url = comp.url
l = plain_text.split(" ")
if len(l) < 3 and image_url == "":
return True, """【设置关键词回复】示例:
1. keyword hi 你好
当发送hi的时候会回复你好
2. keyword /hi 你好
当发送/hi时会回复你好
3. keyword d hi
删除hi关键词的回复
4. keyword hi <图片>
当发送hi时会回复图片""", "keyword"
del_mode = False
if l[1] == "d":
del_mode = True
try:
if os.path.exists("keyword.json"):
with open("keyword.json", "r", encoding="utf-8") as f:
keyword = json.load(f)
if del_mode:
# 删除关键词
if l[2] not in keyword:
return False, "该关键词不存在", "keyword"
else: del keyword[l[2]]
else:
keyword[l[1]] = {
"plain_text": " ".join(l[2:]),
"image_url": image_url
}
else:
if del_mode:
return False, "该关键词不存在", "keyword"
keyword = {
l[1]: {
"plain_text": " ".join(l[2:]),
"image_url": image_url
}
}
with open("keyword.json", "w", encoding="utf-8") as f:
json.dump(keyword, f, ensure_ascii=False, indent=4)
f.flush()
if del_mode:
return True, "删除成功: "+l[2], "keyword"
if image_url == "":
return True, "设置成功: "+l[1]+" "+" ".join(l[2:]), "keyword"
else:
return True, [Plain("设置成功: "+l[1]+" "+" ".join(l[2:])), Image.fromURL(image_url)], "keyword"
except BaseException as e:
return False, "设置失败: "+str(e), "keyword"
def update(self, message: str, role: str):
if role != "admin":
return True, "你没有权限使用该指令", "keyword"
return True, "你没有权限使用该指令", "update"
l = message.split(" ")
if len(l) == 1:
try:
@@ -350,9 +293,9 @@ class Command:
def key(self):
return False
def help(self):
return True, self.help_messager(self.general_commands(), self.platform, self.global_object.cached_plugins), "help"
async def help(self):
ret = await self.help_messager(self.general_commands(), self.platform, self.global_object.cached_plugins)
return True, ret, "help"
def status(self):
return False
+12 -13
View File
@@ -1,17 +1,16 @@
from model.command.command import Command
from model.provider.openai_official import ProviderOpenAIOfficial
from cores.qqbot.personality import personalities
from cores.qqbot.global_object import GlobalObject
from cores.qqbot.types import GlobalObject
class CommandOpenAIOfficial(Command):
def __init__(self, provider: ProviderOpenAIOfficial, global_object: GlobalObject):
self.provider = provider
self.cached_plugins = {}
self.global_object = global_object
self.personality_str = ""
super().__init__(provider, global_object)
def check_command(self,
async def check_command(self,
message: str,
session_id: str,
role: str,
@@ -20,7 +19,7 @@ class CommandOpenAIOfficial(Command):
self.platform = platform
# 检查基础指令
hit, res = super().check_command(
hit, res = await super().check_command(
message,
session_id,
role,
@@ -32,7 +31,7 @@ class CommandOpenAIOfficial(Command):
if hit:
return True, res
if self.command_start_with(message, "reset", "重置"):
return True, self.reset(session_id, message)
return True, await self.reset(session_id, message)
elif self.command_start_with(message, "his", "历史"):
return True, self.his(message, session_id)
elif self.command_start_with(message, "token"):
@@ -42,7 +41,7 @@ class CommandOpenAIOfficial(Command):
elif self.command_start_with(message, "status"):
return True, self.status()
elif self.command_start_with(message, "help", "帮助"):
return True, self.help()
return True, await self.help()
elif self.command_start_with(message, "unset"):
return True, self.unset(session_id)
elif self.command_start_with(message, "set"):
@@ -54,11 +53,11 @@ class CommandOpenAIOfficial(Command):
elif self.command_start_with(message, "key"):
return True, self.key(message)
elif self.command_start_with(message, "switch"):
return True, self.switch(message)
return True, await self.switch(message)
return False, None
def help(self):
async def help(self):
commands = super().general_commands()
commands[''] = '画画'
commands['key'] = '添加OpenAI key'
@@ -66,15 +65,15 @@ class CommandOpenAIOfficial(Command):
commands['gpt'] = '查看gpt配置信息'
commands['status'] = '查看key使用状态'
commands['token'] = '查看本轮会话token'
return True, super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
def reset(self, session_id: str, message: str = "reset"):
async def reset(self, session_id: str, message: str = "reset"):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "reset"
l = message.split(" ")
if len(l) == 1:
self.provider.forget(session_id)
await self.provider.forget(session_id)
return True, "重置成功", "reset"
if len(l) == 2 and l[1] == "p":
self.provider.forget(session_id)
@@ -146,7 +145,7 @@ class CommandOpenAIOfficial(Command):
else:
return True, "该Key被验证为无效。也许是输入错误了,或者重试。", "key"
def switch(self, message: str):
async def switch(self, message: str):
'''
切换账号
'''
@@ -168,7 +167,7 @@ class CommandOpenAIOfficial(Command):
else:
try:
new_key = list(key_stat.keys())[index-1]
ret = self.provider.check_key(new_key)
ret = await self.provider.check_key(new_key)
self.provider.set_key(new_key)
except BaseException as e:
return True, "账号切换失败,原因: " + str(e), "switch"
+6 -7
View File
@@ -1,24 +1,23 @@
from model.command.command import Command
from model.provider.rev_chatgpt import ProviderRevChatGPT
from cores.qqbot.personality import personalities
from cores.qqbot.global_object import GlobalObject
from cores.qqbot.types import GlobalObject
class CommandRevChatGPT(Command):
def __init__(self, provider: ProviderRevChatGPT, global_object: GlobalObject):
self.provider = provider
self.cached_plugins = {}
self.global_object = global_object
self.personality_str = ""
super().__init__(provider, global_object)
def check_command(self,
async def check_command(self,
message: str,
session_id: str,
role: str,
platform: str,
message_obj):
self.platform = platform
hit, res = super().check_command(
hit, res = await super().check_command(
message,
session_id,
role,
@@ -29,7 +28,7 @@ class CommandRevChatGPT(Command):
if hit:
return True, res
if self.command_start_with(message, "help", "帮助"):
return True, self.help()
return True, await self.help()
elif self.command_start_with(message, "reset"):
return True, self.reset(session_id, message)
elif self.command_start_with(message, "update"):
@@ -127,7 +126,7 @@ class CommandRevChatGPT(Command):
else:
return True, "参数过多。", "switch"
def help(self):
async def help(self):
commands = super().general_commands()
commands['set'] = '设置人格'
return True, super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
+10 -22
View File
@@ -1,7 +1,5 @@
import abc
import threading
import asyncio
from typing import Callable, Union
from typing import Union
from nakuru import (
GuildMessage,
GroupMessage,
@@ -10,7 +8,7 @@ from nakuru import (
from ._nakuru_translation_layer import (
NakuruGuildMessage,
)
from nakuru.entities.components import Plain, At, Image, Node
from nakuru.entities.components import Plain, At, Image
class Platform():
@@ -22,34 +20,34 @@ class Platform():
pass
@abc.abstractmethod
def handle_msg():
async def handle_msg():
'''
处理到来的消息
'''
pass
@abc.abstractmethod
def reply_msg():
async def reply_msg():
'''
回复消息(被动发送)
'''
pass
@abc.abstractmethod
def send_msg():
async def send_msg(target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]):
'''
发送消息(主动发送)
'''
pass
@abc.abstractmethod
def send():
async def send(target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]):
'''
发送消息(主动发送)同 send_msg()
'''
pass
def parse_message_outline(self, message: Union[GuildMessage, GroupMessage, FriendMessage, str]) -> NakuruGuildMessage:
def parse_message_outline(self, message: Union[GuildMessage, GroupMessage, FriendMessage, str, list]) -> str:
'''
将消息解析成大纲消息形式。
如: xxxxx[图片]xxxxx
@@ -57,26 +55,16 @@ class Platform():
if isinstance(message, str):
return message
ret = ''
ls_to_parse = message if isinstance(message, list) else message.message
try:
for node in message.message:
for node in ls_to_parse:
if isinstance(node, Plain):
ret += node.text
elif isinstance(node, At):
ret += f'[At: {node.name}/{node.qq}]'
elif isinstance(node, Image):
ret += f'[图片]'
ret += '[图片]'
except Exception as e:
pass
ret.replace('\n', '')
return ret
def new_sub_thread(self, func, args=()):
thread = threading.Thread(target=self._runner, args=(func, args), daemon=True)
thread.start()
def _runner(self, func: Callable, args: tuple):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(func(*args))
loop.close()
+21 -44
View File
@@ -8,8 +8,7 @@ from nakuru import (
GroupMessage,
FriendMessage,
GroupMemberIncrease,
Notify,
Member
Notify
)
from typing import Union
import time
@@ -31,7 +30,6 @@ class QQGOCQ(Platform):
asyncio.set_event_loop(self.loop)
self.waiting = {}
self.gocq_cnt = 0
self.cc = CmdConfig()
self.cfg = cfg
self.logger: gu.Logger = global_object.logger
@@ -60,10 +58,10 @@ class QQGOCQ(Platform):
async def _(app: CQHTTP, source: GroupMessage):
if self.cc.get("gocq_react_group", True):
if isinstance(source.message[0], Plain):
self.new_sub_thread(self.handle_msg, (source, True))
await self.handle_msg(source, True)
elif isinstance(source.message[0], At):
if source.message[0].qq == source.self_id:
self.new_sub_thread(self.handle_msg, (source, True))
await self.handle_msg(source, True)
else:
return
@@ -71,7 +69,7 @@ class QQGOCQ(Platform):
async def _(app: CQHTTP, source: FriendMessage):
if self.cc.get("gocq_react_friend", True):
if isinstance(source.message[0], Plain):
self.new_sub_thread(self.handle_msg, (source, False))
await self.handle_msg(source, False)
else:
return
@@ -87,19 +85,16 @@ class QQGOCQ(Platform):
async def _(app: CQHTTP, source: Notify):
print(source)
if source.sub_type == "poke" and source.target_id == source.self_id:
# await self.handle_msg(source, False)
self.new_sub_thread(self.handle_msg, (source, False))
await self.handle_msg(source, False)
@gocq_app.receiver("GuildMessage")
async def _(app: CQHTTP, source: GuildMessage):
if self.cc.get("gocq_react_guild", True):
if isinstance(source.message[0], Plain):
# await self.handle_msg(source, True)
self.new_sub_thread(self.handle_msg, (source, True))
await self.handle_msg(source, True)
elif isinstance(source.message[0], At):
if source.message[0].qq == source.self_tiny_id:
# await self.handle_msg(source, True)
self.new_sub_thread(self.handle_msg, (source, True))
await self.handle_msg(source, True)
else:
return
@@ -159,7 +154,7 @@ class QQGOCQ(Platform):
if message_result is None:
return
self.reply_msg(message, message_result.result_message)
await self.reply_msg(message, message_result.result_message)
if message_result.callback is not None:
message_result.callback()
@@ -167,16 +162,14 @@ class QQGOCQ(Platform):
if session_id in self.waiting and self.waiting[session_id] == '':
self.waiting[session_id] = message
def reply_msg(self,
async def reply_msg(self,
message: Union[GroupMessage, FriendMessage, GuildMessage, Notify],
result_message: list):
"""
插件开发者请使用send方法, 可以不用直接调用这个方法。
插件开发者请使用send方法, 可以不用直接调用这个方法。
"""
source = message
res = result_message
self.gocq_cnt += 1
self.logger.log(f"{source.user_id} <- {self.parse_message_outline(res)}", tag="QQ_GOCQ")
@@ -209,12 +202,10 @@ class QQGOCQ(Platform):
# 回复消息链
if isinstance(res, list) and len(res) > 0:
if source.type == "GuildMessage":
# await self.client.sendGuildChannelMessage(source.guild_id, source.channel_id, res)
asyncio.run_coroutine_threadsafe(self.client.sendGuildChannelMessage(source.guild_id, source.channel_id, res), self.loop).result()
await self.client.sendGuildChannelMessage(source.guild_id, source.channel_id, res)
return
elif source.type == "FriendMessage":
# await self.client.sendFriendMessage(source.user_id, res)
asyncio.run_coroutine_threadsafe(self.client.sendFriendMessage(source.user_id, res), self.loop).result()
await self.client.sendFriendMessage(source.user_id, res)
return
elif source.type == "GroupMessage":
# 过长时forward发送
@@ -237,37 +228,28 @@ class QQGOCQ(Platform):
node.time = int(time.time())
# print(node)
nodes=[node]
# await self.client.sendGroupForwardMessage(source.group_id, nodes)
asyncio.run_coroutine_threadsafe(self.client.sendGroupForwardMessage(source.group_id, nodes), self.loop).result()
await self.client.sendGroupForwardMessage(source.group_id, nodes)
return
# await self.client.sendGroupMessage(source.group_id, res)
asyncio.run_coroutine_threadsafe(self.client.sendGroupMessage(source.group_id, res), self.loop).result()
await self.client.sendGroupMessage(source.group_id, res)
return
def send_msg(self, message: Union[GroupMessage, FriendMessage, GuildMessage, Notify], result_message: list):
async def send_msg(self, message: Union[GroupMessage, FriendMessage, GuildMessage, Notify], result_message: list):
'''
提供给插件的发送QQ消息接口。
参数说明:第一个参数可以是消息对象,也可以是QQ群号。第二个参数是消息内容(消息内容可以是消息链列表,也可以是纯文字信息)。
非异步
'''
try:
# await self.reply_msg(message, result_message)
self.reply_msg(message, result_message)
await self.reply_msg(message, result_message)
except BaseException as e:
raise e
def send(self,
async def send(self,
to,
res):
'''
同 send_msg()
非异步
'''
try:
# await self.send_msg(to, res)
self.reply_msg(to, res)
except BaseException as e:
raise e
await self.reply_msg(to, res)
def create_text_image(title: str, text: str, max_width=30, font_size=20):
'''
@@ -306,18 +288,13 @@ class QQGOCQ(Platform):
def get_client(self):
return self.client
def nakuru_method_invoker(self, func, *args, **kwargs):
async def nakuru_method_invoker(self, func, *args, **kwargs):
"""
返回一个方法调用器,可以用来立即调用nakuru的方法。
"""
try:
ret = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self.loop).result()
ret = func(*args, **kwargs)
return ret
except BaseException as e:
raise e
def get_cnt(self):
return self.gocq_cnt
def set_cnt(self, cnt):
self.gocq_cnt = cnt
+25 -41
View File
@@ -4,7 +4,7 @@ from PIL import Image as PILImage
from botpy.message import Message, DirectMessage
import re
import asyncio
import requests
import aiohttp
from util import general_utils as gu
from botpy.types.message import Reference
@@ -13,10 +13,9 @@ import time
from ._platfrom import Platform
from ._nakuru_translation_layer import(
NakuruGuildMessage,
NakuruGuildMember,
gocq_compatible_receive,
gocq_compatible_send
)
)
from typing import Union
# QQ 机器人官方框架
@@ -28,13 +27,13 @@ class botClient(Client):
async def on_at_message_create(self, message: Message):
# 转换层
nakuru_guild_message = gocq_compatible_receive(message)
self.platform.new_sub_thread(self.platform.handle_msg, (nakuru_guild_message, True))
await self.platform.handle_msg(nakuru_guild_message, True)
# 收到私聊消息
async def on_direct_message_create(self, message: DirectMessage):
# 转换层
nakuru_guild_message = gocq_compatible_receive(message)
self.platform.new_sub_thread(self.platform.handle_msg, (nakuru_guild_message, False))
await self.platform.handle_msg(nakuru_guild_message, False)
class QQOfficial(Platform):
@@ -44,7 +43,6 @@ class QQOfficial(Platform):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.qqchan_cnt = 0
self.waiting: dict = {}
self.cfg = cfg
@@ -109,7 +107,7 @@ class QQOfficial(Platform):
if message_result is None:
return
self.reply_msg(is_group, message, message_result.result_message)
await self.reply_msg(is_group, message, message_result.result_message)
if message_result.callback is not None:
message_result.callback()
@@ -117,7 +115,7 @@ class QQOfficial(Platform):
if session_id in self.waiting and self.waiting[session_id] == '':
self.waiting[session_id] = message
def reply_msg(self,
async def reply_msg(self,
is_group: bool,
message: NakuruGuildMessage,
res: Union[str, list]):
@@ -125,7 +123,6 @@ class QQOfficial(Platform):
回复频道消息
'''
self.logger.log(f"{message.sender.nickname}({message.sender.tiny_id}) <- {self.parse_message_outline(res)}", tag="QQ_OFFICIAL")
self.qqchan_cnt += 1
plain_text = ''
image_path = ''
@@ -151,10 +148,11 @@ class QQOfficial(Platform):
if image_path is not None and image_path != '':
msg_ref = None
if image_path.startswith("http"):
pic_res = requests.get(image_path, stream = True)
if pic_res.status_code == 200:
image = PILImage.open(io.BytesIO(pic_res.content))
image_path = gu.save_temp_img(image)
async with aiohttp.ClientSession() as session:
async with session.get(image_path) as response:
if response.status == 200:
image = PILImage.open(io.BytesIO(await response.read()))
image_path = gu.save_temp_img(image)
if message.raw_message is not None and image_path == '': # file_image与message_reference不能同时传入
msg_ref = Reference(message_id=message.raw_message.id, ignore_get_message_error=False)
@@ -173,8 +171,7 @@ class QQOfficial(Platform):
data['file_image'] = image_path
try:
# await self._send_wrapper(**data)
self._send_wrapper(**data)
await self._send_wrapper(**data)
except BaseException as e:
print(e)
# 分割过长的消息
@@ -184,51 +181,44 @@ class QQOfficial(Platform):
split_res.append(plain_text[len(plain_text)//2:])
for i in split_res:
data['content'] = i
# await self._send_wrapper(**data)
self._send_wrapper(**data)
await self._send_wrapper(**data)
else:
# 发送qq信息
try:
# 防止被qq频道过滤消息
plain_text = plain_text.replace(".", " . ")
# await self._send_wrapper(**data)
self._send_wrapper(**data)
await self._send_wrapper(**data)
except BaseException as e:
try:
data['content'] = str.join(" ", plain_text)
# await self._send_wrapper(**data)
self._send_wrapper(**data)
await self._send_wrapper(**data)
except BaseException as e:
plain_text = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE)
plain_text = plain_text.replace(".", "·")
data['content'] = plain_text
# await self._send_wrapper(**data)
self._send_wrapper(**data)
await self._send_wrapper(**data)
def _send_wrapper(self, **kwargs):
async def _send_wrapper(self, **kwargs):
if 'channel_id' in kwargs:
asyncio.run_coroutine_threadsafe(self.client.api.post_message(**kwargs), self.loop).result()
await self.client.api.post_message(**kwargs)
else:
asyncio.run_coroutine_threadsafe(self.client.api.post_dms(**kwargs), self.loop).result()
await self.client.api.post_dms(**kwargs)
def send_msg(self, channel_id: int, message_chain: list, message_id: int = None):
async def send_msg(self, channel_id: int, message_chain: list, message_id: int = None):
'''
推送消息, 如果有 message_id,那么就是回复消息。非异步。
推送消息, 如果有 message_id,那么就是回复消息。
'''
_n = NakuruGuildMessage()
_n.channel_id = channel_id
_n.message_id = message_id
# await self.reply_msg(_n, message_chain)
self.reply_msg(_n, message_chain)
await self.reply_msg(_n, message_chain)
def send(self, message_obj, message_chain: list):
async def send(self, message_obj, message_chain: list):
'''
发送信息。内容同 reply_msg。非异步。
发送信息。内容同 reply_msg。
'''
# await self.reply_msg(message_obj, message_chain)
self.reply_msg(message_obj, message_chain)
await self.reply_msg(message_obj, message_chain)
def wait_for_message(self, channel_id: int) -> NakuruGuildMessage:
'''
@@ -246,9 +236,3 @@ class QQOfficial(Platform):
if cnt > 300:
raise Exception("等待消息超时。")
time.sleep(1)
def get_cnt(self):
return self.qqchan_cnt
def set_cnt(self, cnt):
self.qqchan_cnt = cnt
+28 -38
View File
@@ -1,18 +1,21 @@
from openai import OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.images_response import ImagesResponse
import json
import time
import os
import sys
import json
import time
import tiktoken
import threading
import traceback
from openai import AsyncOpenAI
from openai.types.images_response import ImagesResponse
from openai.types.chat.chat_completion import ChatCompletion
from cores.database.conn import dbConn
from model.provider.provider import Provider
import threading
from util import general_utils as gu
from util.cmd_config import CmdConfig
from util.general_utils import Logger
import traceback
import tiktoken
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
@@ -42,7 +45,7 @@ class ProviderOpenAIOfficial(Provider):
self.logger.log(f"设置 api_base 为: {self.api_base}", tag="OpenAI")
# 创建 OpenAI Client
self.client = OpenAI(
self.client = AsyncOpenAI(
api_key=self.key_list[0],
base_url=self.api_base
)
@@ -113,7 +116,7 @@ class ProviderOpenAIOfficial(Provider):
}
self.session_dict[session_id].append(new_record)
def text_chat(self, prompt,
async def text_chat(self, prompt,
session_id = None,
image_url = None,
function_call=None,
@@ -132,7 +135,6 @@ class ProviderOpenAIOfficial(Provider):
if default_personality is not None:
self.personality_set(default_personality, session_id)
# 使用 tictoken 截断消息
_encoded_prompt = self.enc.encode(prompt)
if self.openai_model_configs['max_tokens'] < len(_encoded_prompt):
@@ -140,8 +142,8 @@ class ProviderOpenAIOfficial(Provider):
self.logger.log(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。", level=gu.LEVEL_WARNING, tag="OpenAI")
cache_data_list, new_record, req = self.wrap(prompt, session_id, image_url)
self.logger.log(f"CACHE_DATA_: {str(cache_data_list)}", level=gu.LEVEL_DEBUG, tag="OpenAI")
self.logger.log(f"OPENAI REQUEST: {str(req)}", level=gu.LEVEL_DEBUG, tag="OpenAI")
self.logger.log(f"cache: {str(cache_data_list)}", level=gu.LEVEL_DEBUG, tag="OpenAI")
self.logger.log(f"request: {str(req)}", level=gu.LEVEL_DEBUG, tag="OpenAI")
retry = 0
response = None
err = ''
@@ -168,19 +170,19 @@ class ProviderOpenAIOfficial(Provider):
while retry < 10:
try:
if function_call is None:
response = self.client.chat.completions.create(
response = await self.client.chat.completions.create(
messages=req,
**conf
)
else:
response = self.client.chat.completions.create(
response = await self.client.chat.completions.create(
messages=req,
tools = function_call,
**conf
)
break
except Exception as e:
print(traceback.format_exc())
traceback.print_exc()
if 'Invalid content type. image_url is only supported by certain models.' in str(e):
raise e
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
@@ -188,7 +190,6 @@ class ProviderOpenAIOfficial(Provider):
self.key_stat[self.client.api_key]['exceed'] = True
is_switched = self.handle_switch_key()
if not is_switched:
# 所有Key都超额或不正常
raise e
retry -= 1
elif 'maximum context length' in str(e):
@@ -239,7 +240,6 @@ class ProviderOpenAIOfficial(Provider):
index += 1
# 删除完后更新相关字段
self.session_dict[session_id] = cache_data_list
# cache_prompt = get_prompts_by_cache_list(cache_data_list)
# 添加新条目进入缓存的prompt
new_record['AI'] = {
@@ -258,7 +258,7 @@ class ProviderOpenAIOfficial(Provider):
return chatgpt_res
def image_chat(self, prompt, img_num = 1, img_size = "1024x1024"):
async def image_chat(self, prompt, img_num = 1, img_size = "1024x1024"):
retry = 0
image_url = ''
@@ -266,7 +266,7 @@ class ProviderOpenAIOfficial(Provider):
while retry < 5:
try:
response: ImagesResponse = self.client.images.generate(
response: ImagesResponse = await self.client.images.generate(
prompt=prompt,
**image_generate_configs
)
@@ -282,7 +282,6 @@ class ProviderOpenAIOfficial(Provider):
self.key_stat[self.client.api_key]['exceed'] = True
is_switched = self.handle_switch_key()
if not is_switched:
# 所有Key都超额或不正常
raise e
elif 'Your request was rejected as a result of our safety system.' in str(e):
self.logger.log("您的请求被 OpenAI 安全系统拒绝, 请稍后再试", level=gu.LEVEL_WARNING, tag="OpenAI")
@@ -294,16 +293,16 @@ class ProviderOpenAIOfficial(Provider):
return image_url
def forget(self, session_id = None) -> bool:
async def forget(self, session_id = None) -> bool:
if session_id is None:
return False
self.session_dict[session_id] = []
return True
'''
获取缓存的会话
'''
def get_prompts_by_cache_list(self, cache_data_list, divide=False, paging=False, size=5, page=1):
'''
获取缓存的会话
'''
prompts = ""
if paging:
page_begin = (page-1)*size
@@ -320,15 +319,7 @@ class ProviderOpenAIOfficial(Provider):
if divide:
prompts += "----------\n"
return prompts
def get_user_usage_tokens(self,cache_list):
usage_tokens = 0
for item in cache_list:
usage_tokens += int(item['single_tokens'])
return usage_tokens
# 包装信息
def wrap(self, prompt, session_id, image_url = None):
if image_url is not None:
prompt = [
@@ -364,7 +355,6 @@ class ProviderOpenAIOfficial(Provider):
return context, new_record, req_list
def handle_switch_key(self):
# messages = [{"role": "user", "content": prompt}]
is_all_exceed = True
for key in self.key_stat:
if key == None or self.key_stat[key]['exceed']:
@@ -399,13 +389,13 @@ class ProviderOpenAIOfficial(Provider):
self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor}
# 检查key是否可用
def check_key(self, key):
client_ = OpenAI(
async def check_key(self, key):
client_ = AsyncOpenAI(
api_key=key,
base_url=self.api_base
)
messages = [{"role": "user", "content": "please just echo `test`"}]
client_.chat.completions.create(
await client_.chat.completions.create(
messages=messages,
**self.openai_model_configs
)
+32 -10
View File
@@ -1,13 +1,35 @@
import abc
class Provider:
def __init__(self, cfg):
pass
async def text_chat(self,
prompt: str,
session_id: str,
image_url: None,
function_call: None,
extra_conf: dict = None,
default_personality: dict = None,
**kwargs) -> str:
'''
[require]
prompt: 提示词
session_id: 会话id
[optional]
image_url: 图片url(识图)
function_call: 函数调用
extra_conf: 额外配置
default_personality: 默认人格
'''
raise NotImplementedError
@abc.abstractmethod
def text_chat(self, prompt, session_id, image_url: None, function_call: None, extra_conf: dict = None, default_personality: dict = None) -> str:
pass
async def image_generate(self, prompt, session_id, **kwargs) -> str:
'''
[require]
prompt: 提示词
session_id: 会话id
'''
raise NotImplementedError
@abc.abstractmethod
def forget(self, session_id = None) -> bool:
pass
async def forget(self, session_id = None) -> bool:
'''
重置会话
'''
raise NotImplementedError
+2 -1
View File
@@ -1,5 +1,6 @@
pydantic~=1.10.4
requests~=2.28.1
aiohttp
requests
openai~=1.2.3
qq-botpy
chardet~=5.1.0
+54 -34
View File
@@ -1,18 +1,19 @@
import requests
import util.general_utils as gu
from bs4 import BeautifulSoup
import traceback
import time
import json
import asyncio
from googlesearch import search, SearchResult
from readability import Document
from bs4 import BeautifulSoup
from openai.types.chat.chat_completion_message_tool_call import Function
from util.function_calling.func_call import (
FuncCall,
FuncCallJsonFormatError,
FuncNotFoundError
)
from openai.types.chat.chat_completion_message_tool_call import Function
import traceback
from googlesearch import search, SearchResult
from model.provider.provider import Provider
import json
from readability import Document
def tidy_text(text: str) -> str:
@@ -53,11 +54,11 @@ def google_web_search(keyword) -> str:
for i in ls:
desc = i.description
try:
gu.log(f"搜索网页: {i.url}", tag="网页搜索", level=gu.LEVEL_INFO)
# gu.log(f"搜索网页: {i.url}", tag="网页搜索", level=gu.LEVEL_INFO)
desc = fetch_website_content(i.url)
except BaseException as e:
print(f"(google) fetch_website_content err: {str(e)}")
gu.log(f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n", level=gu.LEVEL_DEBUG, max_len=9999)
# gu.log(f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n", level=gu.LEVEL_DEBUG, max_len=9999)
ret += f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n"
index += 1
except Exception as e:
@@ -80,7 +81,7 @@ def web_keyword_search_via_bing(keyword) -> str:
try:
response = requests.get(url, headers=headers)
response.encoding = "utf-8"
gu.log(f"bing response: {response.text}", tag="bing", level=gu.LEVEL_DEBUG, max_len=9999)
# gu.log(f"bing response: {response.text}", tag="bing", level=gu.LEVEL_DEBUG, max_len=9999)
soup = BeautifulSoup(response.text, "html.parser")
res = ""
result_cnt = 0
@@ -96,7 +97,7 @@ def web_keyword_search_via_bing(keyword) -> str:
# "link": link,
# })
try:
gu.log(f"搜索网页: {link}", tag="网页搜索", level=gu.LEVEL_INFO)
# gu.log(f"搜索网页: {link}", tag="网页搜索", level=gu.LEVEL_INFO)
desc = fetch_website_content(link)
except BaseException as e:
print(f"(bing) fetch_website_content err: {str(e)}")
@@ -124,11 +125,11 @@ def web_keyword_search_via_bing(keyword) -> str:
if result_cnt == 0: break
return res
except Exception as e:
gu.log(f"bing fetch err: {str(e)}")
# gu.log(f"bing fetch err: {str(e)}")
_cnt += 1
time.sleep(1)
gu.log("fail to fetch bing info, using sougou.")
# gu.log("fail to fetch bing info, using sougou.")
return web_keyword_search_via_sougou(keyword)
def web_keyword_search_via_sougou(keyword) -> str:
@@ -157,7 +158,7 @@ def web_keyword_search_via_sougou(keyword) -> str:
break
except Exception as e:
pass
gu.log(f"sougou parse err: {str(e)}", tag="web_keyword_search_via_sougou", level=gu.LEVEL_ERROR)
# gu.log(f"sougou parse err: {str(e)}", tag="web_keyword_search_via_sougou", level=gu.LEVEL_ERROR)
# 爬取网页内容
_detail_store = []
for i in res:
@@ -173,7 +174,7 @@ def web_keyword_search_via_sougou(keyword) -> str:
return ret
def fetch_website_content(url):
gu.log(f"fetch_website_content: {url}", tag="fetch_website_content", level=gu.LEVEL_DEBUG)
# gu.log(f"fetch_website_content: {url}", tag="fetch_website_content", level=gu.LEVEL_DEBUG)
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
@@ -187,7 +188,7 @@ def fetch_website_content(url):
ret = tidy_text(soup.get_text())
return ret
def web_search(question, provider: Provider, session_id, official_fc=False):
async def web_search(question, provider: Provider, session_id, official_fc=False):
'''
official_fc: 使用官方 function-calling
'''
@@ -197,7 +198,7 @@ def web_search(question, provider: Provider, session_id, official_fc=False):
"name": "keyword",
"description": "google search query (分词,尽量保留所有信息)"
}],
"通过搜索引擎搜索。如果问题需要在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
"通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
web_keyword_search_via_bing
)
new_func_call.add_func("fetch_website_content", [{
@@ -205,16 +206,16 @@ def web_search(question, provider: Provider, session_id, official_fc=False):
"name": "url",
"description": "网址"
}],
"获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下https://github.com的内容`), 就调用此函数。如果没有,不要调用此函数。",
"获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
fetch_website_content
)
question1 = f"{question} \n> hint: 最多只能调用1个function, 并且存在不会调用任何function的可能性。"
has_func = False
function_invoked_ret = ""
if official_fc:
func = provider.text_chat(question1, session_id, function_call=new_func_call.get_func())
# we use official function-calling
func = await provider.text_chat(question1, session_id, function_call=new_func_call.get_func())
if isinstance(func, Function):
# arguments='{\n "keyword": "北京今天的天气"\n}', name='google_web_search'
# 执行对应的结果:
func_obj = None
for i in new_func_call.func_list:
@@ -222,49 +223,68 @@ def web_search(question, provider: Provider, session_id, official_fc=False):
func_obj = i["func_obj"]
break
if not func_obj:
gu.log("找不到返回的 func name " + func.name, level=gu.LEVEL_ERROR)
return provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
# gu.log("找不到返回的 func name " + func.name, level=gu.LEVEL_ERROR)
return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
try:
args = json.loads(func.arguments)
function_invoked_ret = func_obj(**args)
# we use to_thread to avoid blocking the event loop
function_invoked_ret = await asyncio.to_thread(func_obj, **args)
has_func = True
except BaseException as e:
traceback.print_exc()
return provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
else:
# now func is a string
return func
else:
# we use our own function-calling
try:
function_invoked_ret, has_func = new_func_call.func_call(question1, new_func_call.func_dump(), is_task=False, is_summary=False)
args = {
'question': question1,
'func_definition': new_func_call.func_dump(),
'is_task': False,
'is_summary': False,
}
function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args)
except BaseException as e:
res = provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)"
res = await provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)"
return res
has_func = True
if has_func:
provider.forget(session_id)
await provider.forget(session_id)
question3 = f"""
以下是相关材料,你的任务是:
1. 根据材料对问题`{question}`做切题的总结回答;
2. 发表你对这个问题的看法.
你的任务是:
1. 根据末尾的材料对问题`{question}`做切题的总结(详细);
2. 简单地发表你对这个问题的看法(简略)。
你的总结末尾应当有对材料的引用, 如果有链接, 请在末尾附上引用网页链接。引用格式严格按照 `\n[1] title url \n`。
不要提到任何函数调用的信息。以下是相关材料:
不要提到任何函数调用的信息。
一些回复的消息模板:
模板1:
```
从网上的信息来看,可以知道...我个人认为...你觉得呢?
```
模板2:
```
根据网上的最新信息,可以得知...我觉得...你怎么看?
```
你可以根据这些模板来组织回答,但可以不照搬,要根据问题的内容来回答。
以下是相关材料:
"""
gu.log(f"web_search: {question3}", tag="web_search", level=gu.LEVEL_DEBUG, max_len=99999)
_c = 0
while _c < 3:
try:
print('text chat')
final_ret = provider.text_chat(question3 + "```" + function_invoked_ret + "```", session_id)
final_ret = await provider.text_chat(question3 + "```" + function_invoked_ret + "```", session_id)
return final_ret
except Exception as e:
print(e)
_c += 1
if _c == 3: raise e
if "The message you submitted was too long" in str(e):
provider.forget(session_id)
await provider.forget(session_id)
function_invoked_ret = function_invoked_ret[:int(len(function_invoked_ret) / 2)]
time.sleep(3)
return function_invoked_ret
+46 -35
View File
@@ -7,10 +7,12 @@ import re
import requests
from util.cmd_config import CmdConfig
import socket
from cores.qqbot.global_object import GlobalObject
from cores.qqbot.types import GlobalObject
import platform
import requests
import logging
import json
import sys
import psutil
PLATFORM_GOCQ = 'gocq'
PLATFORM_QQCHAN = 'qqchan'
@@ -193,7 +195,6 @@ def word2img(title: str, text: str, max_width=30, font_size=20):
return image
def render_markdown(markdown_text, image_width=800, image_height=600, font_size=26, font_color=(0, 0, 0), bg_color=(255, 255, 255)):
HEADER_MARGIN = 20
@@ -322,7 +323,6 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
height += font_size + TEXT_LINE_MARGIN*2
markdown_text = '\n'.join(pre_lines)
print("Pre process done, height: ", height)
image_height = height
if image_height < 100:
image_height = 100
@@ -332,13 +332,6 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
image = Image.new('RGB', (image_width, image_height), bg_color)
draw = ImageDraw.Draw(image)
# # get all the emojis unicode in the markdown text
# unicode_text = markdown_text.encode('unicode_escape').decode()
# # print(unicode_text)
# unicode_emojis = re.findall(r'\\U\w{8}', unicode_text)
# emoji_base_url = "https://abs.twimg.com/emoji/v1/72x72/{unicode_emoji}.png"
# 设置初始位置
x, y = 10, 10
@@ -360,25 +353,10 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
line = line.strip()
if line.startswith("#"):
# unicode_emojis = re.findall(r'\\U0001\w{4}', line)
# for unicode_emoji in unicode_emojis:
# line = line.replace(unicode_emoji, "")
# unicode_emoji = ""
# if len(unicode_emojis) > 0:
# unicode_emoji = unicode_emojis[0]
# 处理标题
header_level = line.count("#")
line = line.strip("#").strip()
font_size_header = HEADER_FONT_STANDARD_SIZE - header_level * 4
# if unicode_emoji != "":
# emoji_url = emoji_base_url.format(unicode_emoji=unicode_emoji[-5:])
# emoji = Image.open(requests.get(emoji_url, stream=True).raw)
# emoji = emoji.resize((font_size, font_size))
# image.paste(emoji, (x, y))
# x += font_size
font = ImageFont.truetype(font_path, font_size_header)
y += HEADER_MARGIN # 上边距
# 字间距
@@ -389,10 +367,6 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
elif line.startswith(">"):
# 处理引用
quote_text = line.strip(">")
# quote_width = image_width - 20 # 引用框的宽度为图像宽度减去左右边距
# quote_height = font_size + 10 # 引用框的高度为字体大小加上上下边距
# quote_box = (x, y, x + quote_width, y + quote_height)
# draw.rounded_rectangle(quote_box, radius=5, fill=(230, 230, 230), width=2) # 使用灰色填充矩形框作为引用背景
y+=QUOTE_LEFT_LINE_MARGIN
draw.line((x, y, x, y + QUOTE_LEFT_LINE_HEIGHT), fill=QUOTE_LEFT_LINE_COLOR, width=QUOTE_LEFT_LINE_WIDTH)
font = ImageFont.truetype(font_path, QUOTE_FONT_SIZE)
@@ -468,7 +442,6 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
y += image_res.size[1] + IMAGE_MARGIN*2
return image
def save_temp_img(img: Image) -> str:
if not os.path.exists("temp"):
os.makedirs("temp")
@@ -490,7 +463,6 @@ def save_temp_img(img: Image) -> str:
img.save(p)
return p
def create_text_image(title: str, text: str, max_width=30, font_size=20):
'''
文本转图片。
@@ -520,9 +492,10 @@ def create_markdown_image(text: str):
except Exception as e:
raise e
# 迁移配置文件到 cmd_config.json
def try_migrate_config(old_config: dict):
'''
迁移配置文件到 cmd_config.json
'''
cc = CmdConfig()
if cc.get("qqbot", None) is None:
# 未迁移过
@@ -553,4 +526,42 @@ def get_sys_info(global_object: GlobalObject):
'mem': mem,
'os': os_name + '_' + os_version,
'py': platform.python_version(),
}
}
def upload(_global_object: GlobalObject):
while True:
addr_ip = ''
try:
res = {
"version": _global_object.version,
"count": _global_object.cnt_total,
"ip": addr_ip,
"sys": sys.platform,
"admin": "null",
}
resp = requests.post('https://api.soulter.top/upload', data=json.dumps(res), timeout=5)
if resp.status_code == 200:
ok = resp.json()
if ok['status'] == 'ok':
_global_object.cnt_total = 0
except BaseException as e:
pass
time.sleep(10*60)
def run_monitor(global_object: GlobalObject):
'''
监测机器性能
- Bot 内存使用量
- CPU 占用率
'''
start_time = time.time()
while True:
stat = global_object.dashboard_data.stats
# 程序占用的内存大小
mem = psutil.Process().memory_info().rss / 1024 / 1024 # MB
stat['sys_perf'] = {
'memory': mem,
'cpu': psutil.cpu_percent()
}
stat['sys_start_time'] = start_time
time.sleep(30)
+11 -1
View File
@@ -1 +1,11 @@
from cores.qqbot.global_object import GlobalObject
from cores.qqbot.types import (
PluginMetadata,
RegisteredLLM,
RegisteredPlugin,
RegisteredPlatform,
RegisteredPlugins,
PluginType,
GlobalObject,
AstrMessageEvent,
CommandResult
)
-2
View File
@@ -1,4 +1,3 @@
from cores.qqbot.global_object import GlobalObject
from typing import Union
import os
import json
@@ -19,7 +18,6 @@ def load_config(namespace: str) -> Union[dict, bool]:
ret[k] = data[k]["value"]
return ret
def put_config(namespace: str, name: str, key: str, value, description: str):
'''
将配置项写入以namespace为名字的配置文件,如果key不存在于目标配置文件中。当前 value 仅支持 str, int, float, bool, list 类型(暂不支持 dict)。
+6
View File
@@ -0,0 +1,6 @@
'''
大语言模型.
插件开发者可以继承这个类来做实现。
'''
from model.provider.provider import Provider as LLMProvider
+1 -1
View File
@@ -1,5 +1,5 @@
from cores.qqbot.core import oper_msg
from cores.qqbot.global_object import AstrMessageEvent, CommandResult
from cores.qqbot.types import AstrMessageEvent, CommandResult
from model.platform._message_result import MessageResult
'''
+11
View File
@@ -0,0 +1,11 @@
'''
消息平台。
Platform类是消息平台的抽象类,定义了消息平台的基本接口。
消息平台的具体实现类需要继承Platform类,并实现其中的抽象方法。
'''
from model.platform._platfrom import Platform
from model.platform.qq_gocq import QQGOCQ
from model.platform.qq_official import QQOfficial
+63
View File
@@ -0,0 +1,63 @@
'''
允许开发者注册某一个类的实例到 LLM 或者 PLATFORM 中,方便其他插件调用。
必须分别实现 Platform 和 LLMProvider 中涉及的接口
'''
from model.provider.provider import Provider as LLMProvider
from model.platform._platfrom import Platform
from cores.qqbot.types import GlobalObject, RegisteredPlatform, RegisteredLLM
def register_platform(platform_name: str, platform_instance: Platform, context: GlobalObject) -> None:
'''
注册一个消息平台。
Args:
platform_name: 平台名称。
platform_instance: 平台实例。
'''
# check 是否已经注册
for platform in context.platforms:
if platform.platform_name == platform_name:
raise ValueError(f"Platform {platform_name} has been registered.")
context.platforms.append(RegisteredPlatform(platform_name, platform_instance))
def register_llm(llm_name: str, llm_instance: LLMProvider, context: GlobalObject) -> None:
'''
注册一个大语言模型。
Args:
llm_name: 大语言模型名称。
llm_instance: 大语言模型实例。
'''
# check 是否已经注册
for llm in context.llms:
if llm.llm_name == llm_name:
raise ValueError(f"LLMProvider {llm_name} has been registered.")
context.llms.append(RegisteredLLM(llm_name, llm_instance))
def unregister_platform(platform_name: str, context: GlobalObject) -> None:
'''
注销一个消息平台。
Args:
platform_name: 平台名称。
'''
for i, platform in enumerate(context.platforms):
if platform.platform_name == platform_name:
context.platforms.pop(i)
return
def unregister_llm(llm_name: str, context: GlobalObject) -> None:
'''
注销一个大语言模型。
Args:
llm_name: 大语言模型名称。
'''
for i, llm in enumerate(context.llms):
if llm.llm_name == llm_name:
context.llms.pop(i)
return
+5
View File
@@ -0,0 +1,5 @@
'''
插件类型
'''
from cores.qqbot.types import PluginType
+75 -33
View File
@@ -9,11 +9,19 @@ try:
except ImportError:
pass
import shutil
from pip._internal import main as pipmain
import importlib
import stat
import traceback
from types import ModuleType
from typing import List
from pip._internal import main as pipmain
from cores.qqbot.types import (
PluginMetadata,
PluginType,
RegisteredPlugin,
RegisteredPlugins
)
# 找出模块里所有的类名
def get_classes(p_name, arg: ModuleType):
@@ -45,7 +53,8 @@ def get_modules(path):
if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists(os.path.join(path, d, d + ".py")):
modules.append({
"pname": d,
"module": module_str
"module": module_str,
"module_path": os.path.join(path, d, module_str)
})
return modules
@@ -73,39 +82,62 @@ def get_plugin_modules():
except BaseException as e:
raise e
def plugin_reload(cached_plugins: dict, target: str = None, all: bool = False):
def plugin_reload(cached_plugins: RegisteredPlugins):
plugins = get_plugin_modules()
if plugins is None:
return False, "未找到任何插件模块"
fail_rec = ""
registered_map = {}
for p in cached_plugins:
registered_map[p.module_path] = None
for plugin in plugins:
try:
p = plugin['module']
module_path = plugin['module_path']
root_dir_name = plugin['pname']
if p not in cached_plugins or p == target or all:
if module_path in registered_map:
# 之前注册过
module = importlib.reload(module)
else:
module = __import__("addons.plugins." + root_dir_name + "." + p, fromlist=[p])
if p in cached_plugins:
module = importlib.reload(module)
cls = get_classes(p, module)
obj = getattr(module, cls[0])()
try:
info = obj.info()
cls = get_classes(p, module)
obj = getattr(module, cls[0])()
metadata = None
try:
info = obj.info()
if isinstance(info, dict):
if 'name' not in info or 'desc' not in info or 'version' not in info or 'author' not in info:
fail_rec += f"载入插件{p}失败,原因: 插件信息不完整\n"
fail_rec += f"注册插件 {module_path} 失败,原因: 插件信息不完整\n"
continue
if isinstance(info, dict) == False:
fail_rec += f"载入插件{p}失败,原因: 插件信息格式不正确\n"
continue
except BaseException as e:
fail_rec += f"调用插件{p} info失败, 原因: {str(e)}\n"
else:
metadata = PluginMetadata(
plugin_name=info['name'],
plugin_type=PluginType.COMMON if 'plugin_type' not in info else PluginType(info['plugin_type']),
author=info['author'],
desc=info['desc'],
version=info['version'],
repo=info['repo'] if 'repo' in info else None
)
elif isinstance(info, PluginMetadata):
metadata = info
else:
fail_rec += f"注册插件 {module_path} 失败,原因: info 函数返回值类型错误\n"
continue
cached_plugins[info['name']] = {
"module": module,
"clsobj": obj,
"info": info,
"name": info['name'],
"root_dir_name": root_dir_name,
}
except BaseException as e:
fail_rec += f"注册插件 {module_path} 失败, 原因: {str(e)}\n"
continue
cached_plugins.append(RegisteredPlugin(
metadata=metadata,
plugin_instance=obj,
module=module,
module_path=module_path,
root_dir_name=root_dir_name
))
except BaseException as e:
traceback.print_exc()
fail_rec += f"加载{p}插件出现问题,原因 {str(e)}\n"
@@ -114,7 +146,7 @@ def plugin_reload(cached_plugins: dict, target: str = None, all: bool = False):
else:
return False, fail_rec
def install_plugin(repo_url: str, cached_plugins: dict):
def install_plugin(repo_url: str, cached_plugins: RegisteredPlugins):
ppath = get_plugin_store_path()
# 删除末尾的 /
if repo_url.endswith("/"):
@@ -132,23 +164,33 @@ def install_plugin(repo_url: str, cached_plugins: dict):
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
if pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt"), '--quiet']) != 0:
raise Exception("插件的依赖安装失败, 需要您手动 pip 安装对应插件的依赖。")
ok, err = plugin_reload(cached_plugins, target=d)
ok, err = plugin_reload(cached_plugins)
if not ok: raise Exception(err)
def get_registered_plugin(plugin_name: str, cached_plugins: RegisteredPlugins) -> RegisteredPlugin:
ret = None
for p in cached_plugins:
if p.metadata.plugin_name == plugin_name:
ret = p
break
return ret
def uninstall_plugin(plugin_name: str, cached_plugins: dict):
if plugin_name not in cached_plugins:
def uninstall_plugin(plugin_name: str, cached_plugins: RegisteredPlugins):
plugin = get_registered_plugin(plugin_name, cached_plugins)
if not plugin:
raise Exception("插件不存在。")
root_dir_name = cached_plugins[plugin_name]["root_dir_name"]
root_dir_name = plugin.root_dir_name
ppath = get_plugin_store_path()
del cached_plugins[plugin_name]
cached_plugins.remove(plugin)
if not remove_dir(os.path.join(ppath, root_dir_name)):
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
def update_plugin(plugin_name: str, cached_plugins: dict):
if plugin_name not in cached_plugins:
def update_plugin(plugin_name: str, cached_plugins: RegisteredPlugins):
plugin = get_registered_plugin(plugin_name, cached_plugins)
if not plugin:
raise Exception("插件不存在。")
ppath = get_plugin_store_path()
root_dir_name = cached_plugins[plugin_name]["root_dir_name"]
root_dir_name = plugin.root_dir_name
plugin_path = os.path.join(ppath, root_dir_name)
repo = Repo(path = plugin_path)
repo.remotes.origin.pull()
@@ -156,7 +198,7 @@ def update_plugin(plugin_name: str, cached_plugins: dict):
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
if pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt"), '--quiet']) != 0:
raise Exception("插件依赖安装失败, 需要您手动pip安装对应插件的依赖。")
ok, err = plugin_reload(cached_plugins, target=plugin_name)
ok, err = plugin_reload(cached_plugins)
if not ok: raise Exception(err)
def remove_dir(file_path) -> bool: