diff --git a/cores/openai/core.py b/cores/openai/core.py index 5bcf9115c..588f652e2 100644 --- a/cores/openai/core.py +++ b/cores/openai/core.py @@ -22,8 +22,8 @@ class ChatGPT: global inst inst = self - async def chat(self, prompt): - print("[ChatGPT] 接收到prompt") + def chat(self, prompt): + print("[OpenAI API]收到") try: response = openai.Completion.create( prompt=prompt, @@ -33,7 +33,8 @@ class ChatGPT: raise PromptExceededError("OpenAI遇到错误:输入了一个不合法的请求。\n"+str(e)) # print(response['usage']) - return response["choices"][0]["text"], response['usage']['total_tokens'] + print("[ChatGPT] "+response["choices"][0]["text"]) + return response["choices"][0]["text"].strip(), response['usage']['total_tokens'] def getConfigs(self): return self.openai_configs diff --git a/cores/qqbot/core.py b/cores/qqbot/core.py index 13a2113c1..2bb75b518 100644 --- a/cores/qqbot/core.py +++ b/cores/qqbot/core.py @@ -1,18 +1,25 @@ import botpy from botpy.message import Message import yaml -import asyncio -import cores.openai.core import re from util.errors.errors import PromptExceededError from botpy.message import DirectMessage import json +from concurrent.futures import ThreadPoolExecutor +import threading +import asyncio +client = '' +# executor = ThreadPoolExecutor(max_workers=10) +# ChatGPT的实例 chatgpt = "" -# db = "" +# 缓存的会话 session_dict = {} +# 最大缓存token(在配置里改 configs/config.yaml) max_tokens = 2000 -version = "1.3" +# 版本 +version = "1.4" +# gpt配置(在配置改) gpt_config = { 'engine': '', 'temperature': '', @@ -21,50 +28,55 @@ gpt_config = { 'presence_penalty': '', 'max_tokens': '', } +# 统计信息 count = { } +# 统计信息 stat_file = '' +# 是否是独立会话(在配置改) uniqueSession = False +def new_sub_thread(func, args=()): + thread = threading.Thread(target=func, args=args, daemon=True) + thread.start() + class botClient(botpy.Client): + # 收到At消息 async def on_at_message_create(self, message: Message): - global stat_file - try: - if str(message.guild_id) not in count: - count[str(message.guild_id)] = { - 'count': 1, - 'direct_count': 0, - } - else: - count[str(message.guild_id)]['count'] += 1 - stat_file = open("./configs/stat", 'w', encoding='utf-8') - stat_file.write(json.dumps(count)) - stat_file.flush() - stat_file.close() - except BaseException: - pass - - await oper_msg(message=message, at=True) + toggle_count(at=True, message=message) + # executor.submit(oper_msg, message, True) + new_sub_thread(oper_msg, (message, True)) + # await oper_msg(message=message, at=True) + # 收到私聊消息 async def on_direct_message_create(self, message: DirectMessage): - global stat_file - try: - if str(message.guild_id) not in count: - count[str(message.guild_id)] = { - 'count': 1, - 'direct_count': 1, - } - else: - count[str(message.guild_id)]['count'] += 1 - count[str(message.guild_id)]['direct_count'] += 1 - stat_file = open("./configs/stat", 'w', encoding='utf-8') - stat_file.write(json.dumps(count)) - stat_file.flush() - stat_file.close() - except BaseException: - pass + toggle_count(at=False, message=message) + print("收到私聊消息") + # executor.submit(oper_msg, message, True) + # await oper_msg(message=message, at=False) + new_sub_thread(oper_msg, (message, False)) + print("私聊消息处理完毕") + +# 写入统计信息 +def toggle_count(at: bool, message): + global stat_file + try: + if str(message.guild_id) not in count: + count[str(message.guild_id)] = { + 'count': 1, + 'direct_count': 1, + } + else: + count[str(message.guild_id)]['count'] += 1 + if not at: + count[str(message.guild_id)]['direct_count'] += 1 + stat_file = open("./configs/stat", 'w', encoding='utf-8') + stat_file.write(json.dumps(count)) + stat_file.flush() + stat_file.close() + except BaseException: + pass - await oper_msg(message=message, at=False) def initBot(chatgpt_inst): global chatgpt @@ -91,7 +103,6 @@ def initBot(chatgpt_inst): pass global uniqueSession - with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile: cfg = yaml.safe_load(ymlfile) if 'uniqueSessionMode' in cfg['qqbot'] and cfg['qqbot']['uniqueSessionMode'] == 'true': @@ -101,21 +112,32 @@ def initBot(chatgpt_inst): if cfg['qqbot']['appid'] != '' or cfg['qqbot']['token'] != '': print("读取QQBot appid token 成功") intents = botpy.Intents(public_guild_messages=True, direct_message=True) + global client client = botClient(intents=intents) client.run(appid=cfg['qqbot']['appid'], token=cfg['qqbot']['token']) else: raise BaseException("请在config中完善你的appid和token") - - -async def get_chatGPT_response(prompts_str): +''' +得到OpenAI的回复 +''' +def get_chatGPT_response(prompts_str): res = '' usage = '' - res, usage = await chatgpt.chat(prompts_str) + res, usage = chatgpt.chat(prompts_str) # 处理结果文本 chatgpt_res = res.strip() return res, usage +''' +回复QQ消息 +''' +def send_qq_msg(message, res): + asyncio.run_coroutine_threadsafe(message.reply(content=res), client.loop) + +''' +获取缓存的会话 +''' def get_prompts_by_cache_list(cache_data_list, divide=False, paging=False, size=5, page=1): prompts = "" if paging: @@ -138,9 +160,11 @@ def get_user_usage_tokens(cache_list): usage_tokens += int(item['single_tokens']) return usage_tokens -async def oper_msg(message, at=False): +def oper_msg(message, at=False, loop=None): print("[QQBOT] 接收到消息:"+ message.content) - qq_msg = "" + qq_msg = '' + session_id = '' + name = '' if at: # 过滤用户id @@ -150,193 +174,177 @@ async def oper_msg(message, at=False): result = re.search(pattern, message.content) if result: qq_msg = result.group(1).strip() - else: - qq_msg = message.content - - user_id = message.author.id - if not at: - session_id = message.author.id - else: if uniqueSession: session_id = message.author.id else: session_id = message.guild_id - if session_id: - name = '' - if uniqueSession: - name = message.member.nick + else: + qq_msg = message.content + session_id = message.author.id + + if uniqueSession: + name = message.member.nick + else: + name = "频道" + + # 指令控制 + if qq_msg == "/reset": + msg = '' + session_dict[session_id] = [] + if at: + msg = f"{name}(id: {session_id}) 的历史记录重置成功" else: - name = "频道" - if qq_msg == "/reset": - - session_dict[session_id] = [] - if at: - await message.reply(content=f"{name}(id: {session_id}) 的历史记录重置成功") - else: - await message.reply(content=f"你的历史记录重置成功") - return - if qq_msg[:4] == "/his": - #分页,每页5条 - size_per_page = 3 - page = 1 - if qq_msg[5:]: - page = int(qq_msg[5:]) - # 检查是否有过历史记录 - if session_id not in session_dict: - await message.reply(content=f"{name} 的历史记录为空") - return - l = session_dict[session_id] - max_page = len(l)//size_per_page + 1 if len(l)%size_per_page != 0 else len(l)//size_per_page - p = get_prompts_by_cache_list(session_dict[session_id], divide=True, paging=True, size=size_per_page, page=page) - if at: - await message.reply(content=f"{name} 的历史记录如下:\n{p}\n第{page}页 | 共{max_page}页\n*输入/his 2跳转到第2页") - else: - await message.reply(content=f"历史记录如下:\n{p}\n第{page}页 | 共{max_page}页\n*输入/his 2跳转到第2页") - - return - if qq_msg == "/token": - if at: - await message.reply(content=f"{name} 会话的token数: {get_user_usage_tokens(session_dict[session_id])}\n系统最大缓存token数: {max_tokens}") - else: - await message.reply(content=f"会话的token数: {get_user_usage_tokens(session_dict[session_id])}\n系统最大缓存token数: {max_tokens}") - - return - if qq_msg == "/status": - chatgpt_cfg_str = "" - for k, v in gpt_config.items(): - if k == "key": - continue - chatgpt_cfg_str += f"{k}: {v}" - - await message.reply(content=f"ChatGPT配置:\n - {chatgpt_cfg_str}\n QQChannelChatGPT 版本: {version}") - return - - if qq_msg == "/count": - try: - f = open("./configs/stat", "r", encoding="utf-8") - fjson = json.loads(f.read()) - f.close() - guild_count = 0 - guild_msg_count = 0 - guild_direct_msg_count = 0 - - for k,v in fjson.items(): - guild_count += 1 - guild_msg_count += v['count'] - guild_direct_msg_count += v['direct_count'] - - session_count = 0 - - f = open("./configs/session", "r", encoding="utf-8") - fjson = json.loads(f.read()) - f.close() - for k,v in fjson.items(): - session_count += 1 - except: - pass - - await message.reply(content=f"当前会话数: {len(session_dict)}\n共有频道数: {guild_count} \n共有消息数: {guild_msg_count}\n私信数: {guild_direct_msg_count}\n历史会话数: {session_count}") - return - - if qq_msg == "/help": - await message.reply(content=f"请联系频道管理员或者前往github(仓库名: QQChannelChatGPT)提issue~") - return - + msg = f"你的历史记录重置成功" + send_qq_msg(message, msg) + return + if qq_msg[:4] == "/his": + #分页,每页5条 + msg = '' + size_per_page = 3 + page = 1 + if qq_msg[5:]: + page = int(qq_msg[5:]) + # 检查是否有过历史记录 if session_id not in session_dict: - session_dict[session_id] = [] + msg = f"{name} 的历史记录为空" + l = session_dict[session_id] + max_page = len(l)//size_per_page + 1 if len(l)%size_per_page != 0 else len(l)//size_per_page + p = get_prompts_by_cache_list(session_dict[session_id], divide=True, paging=True, size=size_per_page, page=page) + if at: + msg=f"{name} 的历史记录如下:\n{p}\n第{page}页 | 共{max_page}页\n*输入/his 2跳转到第2页" + else: + msg=f"历史记录如下:\n{p}\n第{page}页 | 共{max_page}页\n*输入/his 2跳转到第2页" + send_qq_msg(message, msg) + return + if qq_msg == "/token": + msg = '' + if at: + msg=f"{name} 会话的token数: {get_user_usage_tokens(session_dict[session_id])}\n系统最大缓存token数: {max_tokens}" + else: + msg=f"会话的token数: {get_user_usage_tokens(session_dict[session_id])}\n系统最大缓存token数: {max_tokens}" + send_qq_msg(message, msg) + return + if qq_msg == "/status": + chatgpt_cfg_str = "" + for k, v in gpt_config.items(): + if k == "key": + continue + chatgpt_cfg_str += f"{k}: {v}" + send_qq_msg(message, f"ChatGPT配置:\n - {chatgpt_cfg_str}\n QQChannelChatGPT 版本: {version}") + return + if qq_msg == "/count": + try: + f = open("./configs/stat", "r", encoding="utf-8") + fjson = json.loads(f.read()) + f.close() + guild_count = 0 + guild_msg_count = 0 + guild_direct_msg_count = 0 - fjson = {} - try: - f = open("./configs/session", "r", encoding="utf-8") - fjson = json.loads(f.read()) - f.close() - except: - pass - finally: - fjson[session_id] = 'true' - f = open("./configs/session", "w", encoding="utf-8") - f.write(json.dumps(fjson)) - f.flush() - f.close() + for k,v in fjson.items(): + guild_count += 1 + guild_msg_count += v['count'] + guild_direct_msg_count += v['direct_count'] + + session_count = 0 + f = open("./configs/session", "r", encoding="utf-8") + fjson = json.loads(f.read()) + f.close() + for k,v in fjson.items(): + session_count += 1 + except: + pass + send_qq_msg(message, f"当前会话数: {len(session_dict)}\n共有频道数: {guild_count} \n共有消息数: {guild_msg_count}\n私信数: {guild_direct_msg_count}\n历史会话数: {session_count}") + return + if qq_msg == "/help": + send_qq_msg(message, "请联系频道管理员或者前往github(仓库名: QQChannelChatGPT)提issue~") + return + + # 统计历史会话 + if session_id not in session_dict: + session_dict[session_id] = [] - # 获取缓存 - cache_prompt = '' + fjson = {} + try: + f = open("./configs/session", "r", encoding="utf-8") + fjson = json.loads(f.read()) + f.close() + except: + pass + finally: + fjson[session_id] = 'true' + f = open("./configs/session", "w", encoding="utf-8") + f.write(json.dumps(fjson)) + f.flush() + f.close() + + # 获取缓存 + cache_prompt = '' + cache_data_list = session_dict[session_id] + cache_prompt = get_prompts_by_cache_list(cache_data_list) + cache_prompt += "Human: "+ qq_msg + "\nAI: " + # 请求chatGPT获得结果 + try: + chatgpt_res, current_usage_tokens = get_chatGPT_response(cache_prompt) + except (PromptExceededError) as e: + print("出现token超限, 清空对应缓存") + # 超过4097tokens错误,清空缓存 + session_dict[session_id] = [] + cache_data_list = [] + cache_prompt = "Human: "+ qq_msg + "\nAI: " + chatgpt_res, current_usage_tokens = get_chatGPT_response(cache_prompt) + except (BaseException) as e: + print("OpenAI API错误:(") + send_qq_msg(message, f"OpenAI API错误:( 原因如下:\n{str(e)} \n*前往github(仓库名: QQChannelChatGPT)反馈~") + return + + # 发送qq信息 + try: + # 防止被qq频道过滤消息 + gap_chatgpt_res = chatgpt_res.replace(".", " . ") + # 发送信息 + send_qq_msg(message, '[ChatGPT]'+gap_chatgpt_res) + except BaseException as e: + print("QQ频道API错误: \n"+str(e)) + f_res = "" + for t in chatgpt_res: + f_res += t + ' ' + try: + pass + # send(message, f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}") + except BaseException as e: + # 如果还是不行则过滤url + f_res = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '', f_res, flags=re.MULTILINE) + f_res = f_res.replace(".", " . ") + # send(message, f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}") + + # 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens + if current_usage_tokens > max_tokens: + t = current_usage_tokens + cache_list = session_dict[session_id] + index = 0 + while t > max_tokens: + if index >= len(cache_list): + break + t -= int(cache_list[index]['single_tokens']) + index += 1 + session_dict[session_id] = cache_list[index:] cache_data_list = session_dict[session_id] cache_prompt = get_prompts_by_cache_list(cache_data_list) - cache_prompt += "Human: "+ qq_msg + "\nAI: " - # 请求chatGPT获得结果 - try: - chatgpt_res, current_usage_tokens = await get_chatGPT_response(cache_prompt) - except (PromptExceededError) as e: - print("出现token超限, 清空对应缓存") - # 超过4097tokens错误,清空缓存 - session_dict[session_id] = [] - cache_data_list = [] - cache_prompt = "Human: "+ qq_msg + "\nAI: " - chatgpt_res, current_usage_tokens = await get_chatGPT_response(cache_prompt) - except (BaseException) as e: - print("OpenAI API错误:(") - await message.reply(content=f"OpenAI API错误:( 原因如下:\n{str(e)} \n*前往github(仓库名: QQChannelChatGPT)反馈~") - # 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens - # print("current_usage_tokens: ", current_usage_tokens) - # print("max_tokens: ", max_tokens) - if current_usage_tokens > max_tokens: - t = current_usage_tokens - cache_list = session_dict[session_id] - index = 0 - while t > max_tokens: - if index >= len(cache_list): - break - t -= int(cache_list[index]['single_tokens']) - index += 1 - session_dict[session_id] = cache_list[index:] - cache_data_list = session_dict[session_id] - cache_prompt = get_prompts_by_cache_list(cache_data_list) - - # cache_prompt += chatgpt_res + "\n"; - # 添加新条目进入缓存的prompt - if len(cache_data_list) > 0: - single_record = { - "prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n', - "usage_tokens": current_usage_tokens, - "single_tokens": current_usage_tokens - int(cache_data_list[-1]['usage_tokens']) - } - else: - single_record = { - "prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n', - "usage_tokens": current_usage_tokens, - "single_tokens": current_usage_tokens - } - # print(single_record) - cache_data_list.append(single_record) - - # 写入数据库 - # try: - # data = { - # "data": cache_data_list - # } - # data_str = json.dumps(data) - # if len(cache_data_list) > 1: - # db.update_session(session_id, data_str) - # else: - # db.insert_session(session_id, data_str) - # except Exception as e: - # print(e) - # print("数据库写入失败") - - session_dict[session_id] = cache_data_list - - # #检测是否存在url,如果存在,则去除url 防止被qq频道过滤 - chatgpt_res = chatgpt_res.replace(".", " . ") - - # 发送qq信息 - try: - await message.reply(content=f"[ChatGPT]{chatgpt_res}") - except BaseException as e: - print("QQ频道API错误: \n"+str(e)) - f_res = "" - for t in chatgpt_res: - f_res += t + ' ' - await message.reply(content=f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}") \ No newline at end of file + # 添加新条目进入缓存的prompt + if len(cache_data_list) > 0: + single_record = { + "prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n', + "usage_tokens": current_usage_tokens, + "single_tokens": current_usage_tokens - int(cache_data_list[-1]['usage_tokens']) + } + else: + single_record = { + "prompt": f'Human: {qq_msg}\nAI: {chatgpt_res}\n', + "usage_tokens": current_usage_tokens, + "single_tokens": current_usage_tokens + } + cache_data_list.append(single_record) + session_dict[session_id] = cache_data_list \ No newline at end of file