Files
AstrBot/cores/qqbot/core.py
T
2023-12-07 19:18:54 +08:00

907 lines
34 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import botpy
from botpy.message import Message, DirectMessage
import re
import json
import threading
import asyncio
import time
import requests
import util.unfit_words as uw
import os
import sys
from cores.qqbot.personality import personalities
from addons.baidu_aip_judge import BaiduJudge
from model.platform.qqchan import QQChan, NakuruGuildMember, NakuruGuildMessage
from model.platform.qq import QQ
from model.platform.qqgroup import (
UnofficialQQBotSDK,
Event as QQEvent,
Message as QQMessage,
MessageChain,
PlainText
)
from nakuru import (
CQHTTP,
GroupMessage,
GroupMemberIncrease,
FriendMessage,
GuildMessage,
Notify
)
from nakuru.entities.components import Plain,At,Image
from model.provider.provider import Provider
from model.command.command import Command
from util import general_utils as gu
from util.cmd_config import CmdConfig as cc
import util.gplugin as gplugin
from PIL import Image as PILImage
import io
import traceback
from . global_object import GlobalObject
from typing import Union, Callable
# 缓存的会话
session_dict = {}
# 统计信息
count = {}
# 统计信息
stat_file = ''
# 日志记录
# logf = open('log.log', 'a+', encoding='utf-8')
# 用户发言频率
user_frequency = {}
# 时间默认值
frequency_time = 60
# 计数默认值
frequency_count = 2
# 公告(可自定义):
announcement = ""
# 机器人私聊模式
direct_message_mode = True
# 适配pyinstaller
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
# 版本
version = '3.0.4'
# 语言模型
REV_CHATGPT = 'rev_chatgpt'
OPENAI_OFFICIAL = 'openai_official'
REV_ERNIE = 'rev_ernie'
REV_EDGEGPT = 'rev_edgegpt'
NONE_LLM = 'none_llm'
chosen_provider = None
# 语言模型对象
llm_instance: dict[str, Provider] = {}
llm_command_instance: dict[str, Command] = {}
# 百度内容审核实例
baidu_judge = None
# 关键词回复
keywords = {}
# QQ频道机器人
qqchannel_bot: QQChan = None
PLATFORM_QQCHAN = 'qqchan'
qqchan_loop = None
client = None
# QQ群机器人
PLATFROM_QQBOT = 'qqbot'
# CLI
PLATFORM_CLI = 'cli'
# 加载默认配置
cc.init_attributes("qq_forward_threshold", 200)
cc.init_attributes("qq_welcome", "欢迎加入本群!\n欢迎给https://github.com/Soulter/QQChannelChatGPT项目一个Star😊~\n输入help查看帮助~\n")
cc.init_attributes("bing_proxy", "")
cc.init_attributes("qq_pic_mode", False)
cc.init_attributes("rev_chatgpt_model", "")
cc.init_attributes("rev_chatgpt_plugin_ids", [])
cc.init_attributes("rev_chatgpt_PUID", "")
cc.init_attributes("rev_chatgpt_unverified_plugin_domains", [])
cc.init_attributes("gocq_host", "127.0.0.1")
cc.init_attributes("gocq_http_port", 5700)
cc.init_attributes("gocq_websocket_port", 6700)
cc.init_attributes("gocq_react_group", True)
cc.init_attributes("gocq_react_guild", True)
cc.init_attributes("gocq_react_friend", True)
cc.init_attributes("gocq_react_group_increase", True)
cc.init_attributes("gocq_qqchan_admin", "")
cc.init_attributes("other_admins", [])
cc.init_attributes("CHATGPT_BASE_URL", "")
cc.init_attributes("qqbot_appid", "")
cc.init_attributes("qqbot_secret", "")
cc.init_attributes("llm_env_prompt", "> hint: 末尾根据内容和心情添加 1-2 个emoji")
cc.init_attributes("default_personality_str", "")
cc.init_attributes("openai_image_generate", {
"model": "dall-e-3",
"size": "1024x1024",
"style": "vivid",
"quality": "standard",
})
# cc.init_attributes(["qq_forward_mode"], False)
# QQ机器人
gocq_bot = None
PLATFORM_GOCQ = 'gocq'
gocq_app = CQHTTP(
host=cc.get("gocq_host", "127.0.0.1"),
port=cc.get("gocq_websocket_port", 6700),
http_port=cc.get("gocq_http_port", 5700),
)
qq_bot: UnofficialQQBotSDK = UnofficialQQBotSDK(
cc.get("qqbot_appid", None),
cc.get("qqbot_secret", None)
)
gocq_loop: asyncio.AbstractEventLoop = None
qqbot_loop: asyncio.AbstractEventLoop = None
# 全局对象
_global_object: GlobalObject = None
def new_sub_thread(func, args=()):
thread = threading.Thread(target=_runner, args=(func, args), daemon=True)
thread.start()
def _runner(func: Callable, args: tuple):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(func(*args))
loop.close()
# [Deprecated] 写入统计信息
def toggle_count(at: bool, message):
global stat_file
try:
if str(message.guild_id) not in count:
count[str(message.guild_id)] = {
'count': 1,
'direct_count': 1,
}
else:
count[str(message.guild_id)]['count'] += 1
if not at:
count[str(message.guild_id)]['direct_count'] += 1
stat_file = open(abs_path+"configs/stat", 'w', encoding='utf-8')
stat_file.write(json.dumps(count))
stat_file.flush()
stat_file.close()
except BaseException:
pass
# 上传统计信息并检查更新
def upload():
global version, gocq_bot, qqchannel_bot
while True:
addr = ''
addr_ip = ''
session_dict_dump = '{}'
try:
addr = requests.get('http://myip.ipip.net', timeout=5).text
addr_ip = re.findall(r'\d+.\d+.\d+.\d+', addr)[0]
except BaseException as e:
pass
try:
gocq_cnt = 0
qqchan_cnt = 0
if gocq_bot is not None:
gocq_cnt = gocq_bot.get_cnt()
if qqchannel_bot is not None:
qqchan_cnt = qqchannel_bot.get_cnt()
o = {"cnt_total": _global_object.cnt_total,"admin": _global_object.admin_qq,"addr": addr, 's': session_dict_dump}
o_j = json.dumps(o)
res = {"version": version, "count": gocq_cnt+qqchan_cnt, "ip": addr_ip, "others": o_j, "cntqc": qqchan_cnt, "cntgc": gocq_cnt}
gu.log(res, gu.LEVEL_DEBUG, tag="Upload", fg = gu.FG_COLORS['yellow'], bg=gu.BG_COLORS['black'])
resp = requests.post('https://api.soulter.top/upload', data=json.dumps(res), timeout=5)
# print(resp.text)
if resp.status_code == 200:
ok = resp.json()
if ok['status'] == 'ok':
_global_object.cnt_total = 0
if gocq_bot is not None:
gocq_cnt = gocq_bot.set_cnt(0)
if qqchannel_bot is not None:
qqchan_cnt = qqchannel_bot.set_cnt(0)
except BaseException as e:
gu.log("上传统计信息时出现错误: " + str(e), gu.LEVEL_ERROR, tag="Upload")
pass
time.sleep(10*60)
'''
初始化机器人
'''
def initBot(cfg, prov):
global llm_instance, llm_command_instance
global baidu_judge, chosen_provider
global frequency_count, frequency_time, announcement, direct_message_mode, version
global keywords, _global_object
_event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(_event_loop)
# 初始化 global_object
_global_object = GlobalObject()
_global_object.base_config = cfg
if 'reply_prefix' in cfg:
_global_object.reply_prefix = cfg['reply_prefix']
# 语言模型提供商
gu.log("--------加载语言模型--------", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
if REV_CHATGPT in prov:
gu.log("- 逆向ChatGPT库 -", gu.LEVEL_INFO)
if cfg['rev_ChatGPT']['enable']:
if 'account' in cfg['rev_ChatGPT']:
from model.provider.rev_chatgpt import ProviderRevChatGPT
from model.command.rev_chatgpt import CommandRevChatGPT
llm_instance[REV_CHATGPT] = ProviderRevChatGPT(cfg['rev_ChatGPT'], base_url=cc.get("CHATGPT_BASE_URL", None))
llm_command_instance[REV_CHATGPT] = CommandRevChatGPT(llm_instance[REV_CHATGPT], _global_object)
chosen_provider = REV_CHATGPT
else:
input("[System-err] 请退出本程序, 然后在配置文件中填写rev_ChatGPT相关配置")
if REV_EDGEGPT in prov:
gu.log("- New Bing -", gu.LEVEL_INFO)
if not os.path.exists('./cookies.json'):
input("[System-err] 导入Bing模型时发生错误, 没有找到cookies文件或者cookies文件放置位置错误。windows启动器启动的用户请把cookies.json文件放到和启动器相同的目录下。\n如何获取请看https://github.com/Soulter/QQChannelChatGPT仓库介绍。")
else:
if cfg['rev_edgegpt']['enable']:
try:
from model.provider.rev_edgegpt import ProviderRevEdgeGPT
from model.command.rev_edgegpt import CommandRevEdgeGPT
llm_instance[REV_EDGEGPT] = ProviderRevEdgeGPT()
llm_command_instance[REV_EDGEGPT] = CommandRevEdgeGPT(llm_instance[REV_EDGEGPT], _global_object)
chosen_provider = REV_EDGEGPT
except BaseException as e:
print(traceback.format_exc())
gu.log("加载Bing模型时发生错误, 请检查1. cookies文件是否正确放置 2. 是否设置了代理(梯子)。", gu.LEVEL_ERROR, max_len=60)
if OPENAI_OFFICIAL in prov:
gu.log("- OpenAI官方 -", gu.LEVEL_INFO)
if cfg['openai']['key'] is not None and cfg['openai']['key'] != [None]:
from model.provider.openai_official import ProviderOpenAIOfficial
from model.command.openai_official import CommandOpenAIOfficial
llm_instance[OPENAI_OFFICIAL] = ProviderOpenAIOfficial(cfg['openai'])
llm_command_instance[OPENAI_OFFICIAL] = CommandOpenAIOfficial(llm_instance[OPENAI_OFFICIAL], _global_object)
chosen_provider = OPENAI_OFFICIAL
gu.log("--------加载配置--------", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
# 得到关键词
if os.path.exists("keyword.json"):
with open("keyword.json", 'r', encoding='utf-8') as f:
keywords = json.load(f)
# 检查provider设置偏好
p = cc.get("chosen_provider", None)
if p is not None and p in llm_instance:
chosen_provider = p
gu.log(f"将使用 {chosen_provider} 语言模型。", gu.LEVEL_INFO)
# 百度内容审核
if 'baidu_aip' in cfg and 'enable' in cfg['baidu_aip'] and cfg['baidu_aip']['enable']:
try:
baidu_judge = BaiduJudge(cfg['baidu_aip'])
gu.log("百度内容审核初始化成功", gu.LEVEL_INFO)
except BaseException as e:
gu.log("百度内容审核初始化失败", gu.LEVEL_ERROR)
threading.Thread(target=upload, daemon=True).start()
# 得到私聊模式配置
if 'direct_message_mode' in cfg:
direct_message_mode = cfg['direct_message_mode']
gu.log("私聊功能: "+str(direct_message_mode), gu.LEVEL_INFO)
# 得到发言频率配置
if 'limit' in cfg:
gu.log("发言频率配置: "+str(cfg['limit']), gu.LEVEL_INFO)
if 'count' in cfg['limit']:
frequency_count = cfg['limit']['count']
if 'time' in cfg['limit']:
frequency_time = cfg['limit']['time']
# 得到公告配置
if 'notice' in cfg:
if cc.get("qq_welcome", None) != None and cfg['notice'] == '此机器人由Github项目QQChannelChatGPT驱动。':
announcement = cc.get("qq_welcome", None)
else:
announcement = cfg['notice']
gu.log("公告配置: " + announcement, gu.LEVEL_INFO)
try:
if 'uniqueSessionMode' in cfg and cfg['uniqueSessionMode']:
_global_object.uniqueSession = True
else:
_global_object.uniqueSession = False
gu.log("独立会话: "+str(_global_object.uniqueSession), gu.LEVEL_INFO)
except BaseException as e:
gu.log("独立会话配置错误: "+str(e), gu.LEVEL_ERROR)
gu.log(f"QQ开放平台AppID: {cfg['qqbot']['appid']} 令牌: {cfg['qqbot']['token']}")
if chosen_provider is None:
gu.log("检测到没有启动任何语言模型。", gu.LEVEL_CRITICAL)
nick_qq = cc.get("nick_qq", None)
if nick_qq == None:
nick_qq = ("ai","!","")
if isinstance(nick_qq, str):
nick_qq = (nick_qq,)
if isinstance(nick_qq, list):
nick_qq = tuple(nick_qq)
_global_object.nick = nick_qq
thread_inst = None
gu.log("--------加载插件--------", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
# 加载插件
_command = Command(None, _global_object)
ok, err = _command.plugin_reload(_global_object.cached_plugins)
if ok:
gu.log("加载插件完成", gu.LEVEL_INFO)
else:
gu.log(err, gu.LEVEL_ERROR)
if chosen_provider is None:
llm_command_instance[NONE_LLM] = _command
chosen_provider = NONE_LLM
gu.log("--------加载机器人平台--------", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
admin_qq = cc.get('admin_qq', None)
admin_qqchan = cc.get('admin_qqchan', None)
if admin_qq == None:
gu.log("未设置管理者QQ号(管理者才能使用update/plugin等指令),如需设置,请编辑 cmd_config.json 文件", gu.LEVEL_WARNING)
if admin_qqchan == None:
gu.log("未设置管理者QQ频道用户号(管理者才能使用update/plugin等指令),如需设置,请编辑 cmd_config.json 文件。可在频道发送指令 !myid 获取", gu.LEVEL_WARNING)
_global_object.admin_qq = admin_qq
_global_object.admin_qqchan = admin_qqchan
global qq_bot, qqbot_loop
qqbot_loop = asyncio.new_event_loop()
if cc.get("qqbot_appid", '') != '' and cc.get("qqbot_secret", '') != '':
gu.log("- 启用QQ群机器人 -", gu.LEVEL_INFO)
thread_inst = threading.Thread(target=run_qqbot, args=(qqbot_loop, qq_bot,), daemon=True)
thread_inst.start()
# GOCQ
global gocq_bot
if 'gocqbot' in cfg and cfg['gocqbot']['enable']:
gu.log("- 启用QQ机器人 -", gu.LEVEL_INFO)
global gocq_app, gocq_loop
gocq_loop = asyncio.new_event_loop()
gocq_bot = QQ(True, cc, gocq_loop)
thread_inst = threading.Thread(target=run_gocq_bot, args=(gocq_loop, gocq_bot, gocq_app), daemon=True)
thread_inst.start()
else:
gocq_bot = QQ(False)
_global_object.platform_qq = gocq_bot
gu.log("机器人部署教程: https://github.com/Soulter/QQChannelChatGPT/wiki/", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
gu.log("如果有任何问题, 请在 https://github.com/Soulter/QQChannelChatGPT 上提交issue或加群322154837", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
gu.log("请给 https://github.com/Soulter/QQChannelChatGPT 点个star!", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
# QQ频道
if 'qqbot' in cfg and cfg['qqbot']['enable']:
gu.log("- 启用QQ频道机器人 -", gu.LEVEL_INFO)
global qqchannel_bot, qqchan_loop
qqchannel_bot = QQChan()
qqchan_loop = asyncio.new_event_loop()
_global_object.platform_qqchan = qqchannel_bot
thread_inst = threading.Thread(target=run_qqchan_bot, args=(cfg, qqchan_loop, qqchannel_bot), daemon=True)
thread_inst.start()
# thread.join()
if thread_inst == None:
raise Exception("[System-Error] 没有启用/成功启用任何机器人,程序退出")
default_personality_str = cc.get("default_personality_str", "")
if default_personality_str == "":
_global_object.default_personality = None
else:
_global_object.default_personality = {
"name": "default",
"prompt": default_personality_str,
}
gu.log("🎉 项目启动完成。")
asyncio.get_event_loop().run_until_complete(cli())
thread_inst.join()
async def cli():
time.sleep(1)
while True:
try:
prompt = input(">>> ")
if prompt == "":
continue
ngm = await cli_pack_message(prompt)
await oper_msg(ngm, True, PLATFORM_CLI)
except EOFError:
return
async def cli_pack_message(prompt: str) -> NakuruGuildMessage:
ngm = NakuruGuildMessage()
ngm.channel_id = 6180
ngm.user_id = 6180
ngm.message = [Plain(prompt)]
ngm.type = "GuildMessage"
ngm.self_id = 6180
ngm.self_tiny_id = 6180
ngm.guild_id = 6180
ngm.sender = NakuruGuildMember()
ngm.sender.tiny_id = 6180
ngm.sender.user_id = 6180
ngm.sender.nickname = "CLI"
ngm.sender.role = 0
return ngm
'''
运行QQ频道机器人
'''
def run_qqchan_bot(cfg, loop, qqchannel_bot: QQChan):
asyncio.set_event_loop(loop)
intents = botpy.Intents(public_guild_messages=True, direct_message=True)
global client
client = botClient(
intents=intents,
bot_log=False
)
try:
qqchannel_bot.run_bot(client, cfg['qqbot']['appid'], cfg['qqbot']['token'])
except BaseException as e:
gu.log("启动QQ频道机器人时出现错误, 原因如下: " + str(e), gu.LEVEL_CRITICAL, tag="QQ频道")
gu.log(r"如果您是初次启动,请修改配置文件(QQChannelChatGPT/config.yaml)详情请看:https://github.com/Soulter/QQChannelChatGPT/wiki。" + str(e), gu.LEVEL_CRITICAL, tag="System")
i = input("按回车退出程序。\n")
'''
运行GOCQ机器人
'''
def run_gocq_bot(loop, gocq_bot, gocq_app):
asyncio.set_event_loop(loop)
gu.log("正在检查本地GO-CQHTTP连接...端口5700, 6700", tag="QQ")
noticed = False
while True:
if not gu.port_checker(5700, cc.get("gocq_host", "127.0.0.1")) or not gu.port_checker(6700, cc.get("gocq_host", "127.0.0.1")):
if not noticed:
noticed = True
gu.log("与GO-CQHTTP通信失败, 请检查GO-CQHTTP是否启动并正确配置。程序会每隔 5s 自动重试。", gu.LEVEL_CRITICAL, tag="QQ")
time.sleep(5)
else:
gu.log("检查完毕,未发现问题。", tag="QQ")
break
global gocq_client
gocq_client = gocqClient()
try:
gocq_bot.run_bot(gocq_app)
except BaseException as e:
input("启动QQ机器人出现错误"+str(e))
'''
启动QQ群机器人(官方接口)
'''
def run_qqbot(loop: asyncio.AbstractEventLoop, qq_bot: UnofficialQQBotSDK):
asyncio.set_event_loop(loop)
QQBotClient()
qq_bot.run_bot()
'''
检查发言频率
'''
def check_frequency(id) -> bool:
ts = int(time.time())
if id in user_frequency:
if ts-user_frequency[id]['time'] > frequency_time:
user_frequency[id]['time'] = ts
user_frequency[id]['count'] = 1
return True
else:
if user_frequency[id]['count'] >= frequency_count:
return False
else:
user_frequency[id]['count']+=1
return True
else:
t = {'time':ts,'count':1}
user_frequency[id] = t
return True
'''
通用消息回复
'''
async def send_message(platform, message, res, session_id = None):
global qqchannel_bot, qqchannel_bot, gocq_loop, session_dict
if session_id is not None:
if session_id not in session_dict:
session_dict[session_id] = {'cnt': 1}
else:
session_dict[session_id]['cnt'] += 1
else:
session_dict[session_id]['cnt'] += 1
if platform == PLATFORM_QQCHAN:
qqchannel_bot.send_qq_msg(message, res)
if platform == PLATFORM_GOCQ:
await gocq_bot.send_qq_msg(message, res)
if platform == PLATFROM_QQBOT:
message_chain = MessageChain()
message_chain.parse_from_nakuru(res)
await qq_bot.send(message, message_chain)
if platform == PLATFORM_CLI:
print(res)
async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage],
group: bool=False,
platform: str = None):
"""
处理消息。
group: 群聊模式,
message: 频道是频道的消息对象, QQ是nakuru-gocq的消息对象
msg_ref: 引用消息(频道)
platform: 平台(gocq, qqchan)
"""
global chosen_provider, keywords, qqchannel_bot, gocq_bot
global _global_object
qq_msg = ''
session_id = ''
user_id = ''
role = "member" # 角色, member或admin
hit = False # 是否命中指令
command_result = () # 调用指令返回的结果
_global_object.cnt_total += 1
with_tag = False # 是否带有昵称
if platform == PLATFORM_QQCHAN or platform == PLATFROM_QQBOT or platform == PLATFORM_CLI:
with_tag = True
_len = 0
for i in message.message:
if isinstance(i, Plain) or isinstance(i, PlainText):
qq_msg += str(i.text).strip()
if isinstance(i, At):
if message.type == "GuildMessage":
if i.qq == message.user_id or i.qq == message.self_tiny_id:
with_tag = True
if message.type == "FriendMessage":
if i.qq == message.self_id:
with_tag = True
if message.type == "GroupMessage":
if i.qq == message.self_id:
with_tag = True
for i in _global_object.nick:
if i != '' and qq_msg.startswith(i):
_len = len(i)
with_tag = True
break
qq_msg = qq_msg[_len:].strip()
gu.log(f"收到消息:{qq_msg}", gu.LEVEL_INFO, tag="QQ")
user_id = message.user_id
if group:
# 适配GO-CQHTTP的频道功能
if message.type == "GuildMessage":
session_id = message.channel_id
else:
session_id = message.group_id
else:
with_tag = True
session_id = message.user_id
if message.type == "GuildMessage":
sender_id = str(message.sender.tiny_id)
else:
sender_id = str(message.sender.user_id)
if sender_id == _global_object.admin_qq or \
sender_id == _global_object.admin_qqchan or \
sender_id in cc.get("other_admins", []) or \
sender_id == cc.get("gocq_qqchan_admin", "") or \
platform == PLATFORM_CLI:
role = "admin"
if _global_object.uniqueSession:
# 独立会话时,一个用户一个 session
session_id = sender_id
if qq_msg == "":
await send_message(platform, message, f"Hi~", session_id=session_id)
return
if with_tag:
# 检查发言频率
if not check_frequency(user_id):
await send_message(platform, message, f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。', session_id=session_id)
return
# logf.write("[GOCQBOT] "+ qq_msg+'\n')
# logf.flush()
# 关键词回复
for k in keywords:
if qq_msg == k:
plain_text = ""
if 'plain_text' in keywords[k]:
plain_text = keywords[k]['plain_text']
else:
plain_text = keywords[k]
image_url = ""
if 'image_url' in keywords[k]:
image_url = keywords[k]['image_url']
if image_url != "":
res = [Plain(plain_text), Image.fromURL(image_url)]
await send_message(platform, message, res, session_id=session_id)
else:
await send_message(platform, message, plain_text, session_id=session_id)
return
# 检查是否是更换语言模型的请求
temp_switch = ""
if qq_msg.startswith('/bing') or qq_msg.startswith('/gpt') or qq_msg.startswith('/revgpt'):
target = chosen_provider
if qq_msg.startswith('/bing'):
target = REV_EDGEGPT
elif qq_msg.startswith('/gpt'):
target = OPENAI_OFFICIAL
elif qq_msg.startswith('/revgpt'):
target = REV_CHATGPT
l = qq_msg.split(' ')
if len(l) > 1 and l[1] != "":
# 临时对话模式,先记录下之前的语言模型,回答完毕后再切回
temp_switch = chosen_provider
chosen_provider = target
qq_msg = l[1]
else:
chosen_provider = target
cc.put("chosen_provider", chosen_provider)
await send_message(platform, message, f"已切换至【{chosen_provider}", session_id=session_id)
return
chatgpt_res = ""
# 如果是等待回复的消息
if platform == PLATFORM_GOCQ and session_id in gocq_bot.waiting and gocq_bot.waiting[session_id] == '':
gocq_bot.waiting[session_id] = message
return
if platform == PLATFORM_QQCHAN and session_id in qqchannel_bot.waiting and qqchannel_bot.waiting[session_id] == '':
qqchannel_bot.waiting[session_id] = message
return
hit, command_result = llm_command_instance[chosen_provider].check_command(
qq_msg,
session_id,
role,
platform,
message,
)
# 没触发指令
if not hit:
if not with_tag:
return
# 关键词拦截
for i in uw.unfit_words_q:
matches = re.match(i, qq_msg.strip(), re.I | re.M)
if matches:
await send_message(platform, message, f"你的提问得到的回复未通过【自有关键词拦截】服务, 不予回复。", session_id=session_id)
return
if baidu_judge != None:
check, msg = baidu_judge.judge(qq_msg)
if not check:
await send_message(platform, message, f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}", session_id=session_id)
return
if chosen_provider == None:
await send_message(platform, message, f"管理员未启动任何语言模型或者语言模型初始化时失败。", session_id=session_id)
return
try:
# check image url
image_url = None
for comp in message.message:
if isinstance(comp, Image):
if comp.url is None:
image_url = comp.file
break
else:
image_url = comp.url
break
# web search keyword
web_sch_flag = False
if qq_msg.startswith("ws ") and qq_msg != "ws ":
qq_msg = qq_msg[3:]
web_sch_flag = True
else:
qq_msg += " " + cc.get("llm_env_prompt", "")
if chosen_provider == REV_CHATGPT or chosen_provider == OPENAI_OFFICIAL:
if _global_object.web_search or web_sch_flag:
official_fc = chosen_provider == OPENAI_OFFICIAL
chatgpt_res = gplugin.web_search(qq_msg, llm_instance[chosen_provider], session_id, official_fc)
else:
chatgpt_res = str(llm_instance[chosen_provider].text_chat(qq_msg, session_id, image_url, default_personality = _global_object.default_personality))
elif chosen_provider == REV_EDGEGPT:
res, res_code = await llm_instance[chosen_provider].text_chat(qq_msg, platform)
if res_code == 0: # bing不想继续话题,重置会话后重试。
await send_message(platform, message, "Bing不想继续话题了, 正在自动重置会话并重试。", session_id=session_id)
await llm_instance[chosen_provider].forget()
res, res_code = await llm_instance[chosen_provider].text_chat(qq_msg, platform)
if res_code == 0: # bing还是不想继续话题,大概率说明提问有问题。
await llm_instance[chosen_provider].forget()
await send_message(platform, message, "Bing仍然不想继续话题, 会话已重置, 请检查您的提问后重试。", session_id=session_id)
res = ""
chatgpt_res = str(res)
if chosen_provider in _global_object.reply_prefix:
chatgpt_res = _global_object.reply_prefix[chosen_provider] + chatgpt_res
except BaseException as e:
gu.log(f"调用异常:{traceback.format_exc()}", gu.LEVEL_ERROR, max_len=100000)
gu.log("调用语言模型例程时出现异常。原因: "+str(e), gu.LEVEL_ERROR)
await send_message(platform, message, "调用语言模型例程时出现异常。原因: "+str(e), session_id=session_id)
return
# 切换回原来的语言模型
if temp_switch != "":
chosen_provider = temp_switch
# 指令回复
if hit:
# 检查指令. command_result是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型)
if command_result == None:
# await send_message(platform, message, "指令调用未返回任何信息。", session_id=session_id)
return
command = command_result[2]
if command == "keyword":
if os.path.exists("keyword.json"):
with open("keyword.json", "r", encoding="utf-8") as f:
keywords = json.load(f)
else:
try:
await send_message(platform, message, command_result[1], session_id=session_id)
except BaseException as e:
await send_message(platform, message, f"回复消息出错: {str(e)}", session_id=session_id)
if command == "update latest r":
await send_message(platform, message, command_result[1] + "\n\n即将自动重启。", session_id=session_id)
py = sys.executable
os.execl(py, py, *sys.argv)
if not command_result[0]:
await send_message(platform, message, f"指令调用错误: \n{str(command_result[1])}", session_id=session_id)
return
# 画图指令
if isinstance(command_result[1], list) and len(command_result) == 3 and command == 'draw':
for i in command_result[1]:
# i is a link
# 保存到本地
pic_res = requests.get(i, stream = True)
if pic_res.status_code == 200:
image = PILImage.open(io.BytesIO(pic_res.content))
await send_message(platform, message, [Image.fromFileSystem(gu.save_temp_img(image))], session_id=session_id)
# 其他指令
else:
try:
await send_message(platform, message, command_result[1], session_id=session_id)
except BaseException as e:
await send_message(platform, message, f"回复消息出错: {str(e)}", session_id=session_id)
return
# 记录日志
# logf.write(f"{reply_prefix} {str(chatgpt_res)}\n")
# logf.flush()
# 敏感过滤
# 过滤不合适的词
for i in uw.unfit_words:
chatgpt_res = re.sub(i, "***", chatgpt_res)
# 百度内容审核服务二次审核
if baidu_judge != None:
check, msg = baidu_judge.judge(chatgpt_res)
if not check:
await send_message(platform, message, f"你的提问得到的回复【百度内容审核】未通过,不予回复。\n\n{msg}", session_id=session_id)
return
# 发送信息
try:
await send_message(platform, message, chatgpt_res, session_id=session_id)
except BaseException as e:
gu.log("回复消息错误: \n"+str(e), gu.LEVEL_ERROR)
# QQ频道机器人
class botClient(botpy.Client):
# 收到频道消息
async def on_at_message_create(self, message: Message):
gu.log(str(message), gu.LEVEL_DEBUG, max_len=9999)
# 转换层
nakuru_guild_message = qqchannel_bot.gocq_compatible_receive(message)
gu.log(f"转换后: {str(nakuru_guild_message)}", gu.LEVEL_DEBUG, max_len=9999)
new_sub_thread(oper_msg, (nakuru_guild_message, True, PLATFORM_QQCHAN))
# 收到私聊消息
async def on_direct_message_create(self, message: DirectMessage):
if direct_message_mode:
# 转换层
nakuru_guild_message = qqchannel_bot.gocq_compatible_receive(message)
gu.log(f"转换后: {str(nakuru_guild_message)}", gu.LEVEL_DEBUG, max_len=9999)
new_sub_thread(oper_msg, (nakuru_guild_message, False, PLATFORM_QQCHAN))
# QQ机器人
class gocqClient():
# 收到群聊消息
@gocq_app.receiver("GroupMessage")
async def _(app: CQHTTP, source: GroupMessage):
if cc.get("gocq_react_group", True):
if isinstance(source.message[0], Plain):
new_sub_thread(oper_msg, (source, True, PLATFORM_GOCQ))
if isinstance(source.message[0], At):
if source.message[0].qq == source.self_id:
new_sub_thread(oper_msg, (source, True, PLATFORM_GOCQ))
else:
return
@gocq_app.receiver("FriendMessage")
async def _(app: CQHTTP, source: FriendMessage):
if cc.get("gocq_react_friend", True):
if isinstance(source.message[0], Plain):
new_sub_thread(oper_msg, (source, False, PLATFORM_GOCQ))
else:
return
@gocq_app.receiver("GroupMemberIncrease")
async def _(app: CQHTTP, source: GroupMemberIncrease):
if cc.get("gocq_react_group_increase", True):
global announcement
await app.sendGroupMessage(source.group_id, [
Plain(text = announcement),
])
@gocq_app.receiver("Notify")
async def _(app: CQHTTP, source: Notify):
print(source)
if source.sub_type == "poke" and source.target_id == source.self_id:
new_sub_thread(oper_msg, (source, False, PLATFORM_GOCQ))
@gocq_app.receiver("GuildMessage")
async def _(app: CQHTTP, source: GuildMessage):
if cc.get("gocq_react_guild", True):
if isinstance(source.message[0], Plain):
new_sub_thread(oper_msg, (source, True, PLATFORM_GOCQ))
if isinstance(source.message[0], At):
if source.message[0].qq == source.self_tiny_id:
new_sub_thread(oper_msg, (source, True, PLATFORM_GOCQ))
else:
return
class QQBotClient():
@qq_bot.on('GroupMessage')
async def _(bot: UnofficialQQBotSDK, message: QQMessage):
print(message)
new_sub_thread(oper_msg, (message, True, PLATFROM_QQBOT))