From 51ee4fc9b3602e741011396144efc6e0af2ae924 Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Fri, 16 Dec 2022 23:13:15 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E6=8C=81=E4=B9=85?= =?UTF-8?q?=E5=8C=96=E5=8E=86=E5=8F=B2=E8=AE=B0=E5=BD=95=20fix:=20?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BA=86=E4=B8=80=E4=BA=9B=E5=B7=B2=E7=9F=A5?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cores/database/conn.py | 96 ++++++++++++++++++++++++++++++++++-------- cores/qqbot/core.py | 44 +++++++++++++++++-- main.py | 4 ++ 3 files changed, 124 insertions(+), 20 deletions(-) diff --git a/cores/database/conn.py b/cores/database/conn.py index edf9c0353..95851faaf 100644 --- a/cores/database/conn.py +++ b/cores/database/conn.py @@ -1,24 +1,86 @@ -import pymysql +import sqlite3 import yaml # TODO: 数据库缓存prompt class dbConn(): def __init__(self): - with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile: - cfg = yaml.safe_load(ymlfile) - if cfg['database']['host'] != '' or cfg['database']['port'] or cfg['database']['user'] != '' or cfg['database']['password'] != '' or cfg['database']['db'] != '': - print("读取数据库配置成功") - self.db = pymysql.connect( - host=cfg['database']['host'], - port=cfg['database']['port'], - user=cfg['database']['user'], - password=cfg['database']['password'], - db=cfg['database']['db'], - charset='utf8mb4', - ) - else: - raise BaseException("请在config中完善你的数据库配置") + # 读取参数,并支持中文 + conn = sqlite3.connect("data.db") + conn.text_factory=str + self.conn = conn + c = conn.cursor() + c.execute( + ''' + CREATE TABLE IF NOT EXISTS tb_session( + qq_id VARCHAR(32) PRIMARY KEY, + history TEXT + ) + ''' + ) + + conn.commit() - def getCursor(self): - return self.db.cursor() \ No newline at end of file + def insert_session(self, qq_id, history): + conn = self.conn + c = conn.cursor() + c.execute( + ''' + INSERT INTO tb_session(qq_id, history) VALUES (?, ?) + ''', (qq_id, history) + ) + conn.commit() + + def update_session(self, qq_id, history): + conn = self.conn + c = conn.cursor() + c.execute( + ''' + UPDATE tb_session SET history = ? WHERE qq_id = ? + ''', (history, qq_id) + ) + conn.commit() + + def get_session(self, qq_id): + conn = self.conn + c = conn.cursor() + c.execute( + ''' + SELECT * FROM tb_session WHERE qq_id = ? + ''', (qq_id, ) + ) + return c.fetchone() + + def get_all_session(self): + conn = self.conn + c = conn.cursor() + c.execute( + ''' + SELECT * FROM tb_session + ''' + ) + return c.fetchall() + + def check_session(self, qq_id): + conn = self.conn + c = conn.cursor() + c.execute( + ''' + SELECT * FROM tb_session WHERE qq_id = ? + ''', (qq_id, ) + ) + return c.fetchone() is not None + + def delete_session(self, qq_id): + conn = self.conn + c = conn.cursor() + c.execute( + ''' + DELETE FROM tb_session WHERE qq_id = ? + ''', (qq_id, ) + ) + conn.commit() + + def close(self): + self.conn.close() + \ No newline at end of file diff --git a/cores/qqbot/core.py b/cores/qqbot/core.py index 2bb75b518..5d70249d3 100644 --- a/cores/qqbot/core.py +++ b/cores/qqbot/core.py @@ -8,6 +8,8 @@ import json from concurrent.futures import ThreadPoolExecutor import threading import asyncio +import time +from cores.database.conn import dbConn client = '' # executor = ThreadPoolExecutor(max_workers=10) @@ -77,17 +79,46 @@ def toggle_count(at: bool, message): except BaseException: pass +# 转储历史记录的定时器 Soulter +def dump_history(): + time.sleep(10) + global session_dict + db = dbConn() + while True: + try: + print("转储历史记录...") + for key in session_dict: + # print("TEST: "+str(db.get_session(key))) + data = session_dict[key] + data_json = { + 'data': data + } + if db.check_session(key): + db.update_session(key, json.dumps(data_json)) + else: + db.insert_session(key, json.dumps(data_json)) + print("转储历史记录完毕") + except BaseException as e: + print(e) + # 每隔10分钟转储一次 + time.sleep(10) def initBot(chatgpt_inst): global chatgpt chatgpt = chatgpt_inst - # global db - # db = db_inst + global max_tokens max_tokens = int(chatgpt_inst.getConfigs()['total_tokens_limit']) global gpt_config gpt_config = chatgpt_inst.getConfigs() gpt_config['key'] = "***" + global version + + # 读取历史记录 Soulter + db1 = dbConn() + for session in db1.get_all_session(): + session_dict[session[0]] = json.loads(session[1])['data'] + print("历史记录读取完毕") # 读统计信息 global stat_file @@ -102,6 +133,9 @@ def initBot(chatgpt_inst): except BaseException: pass + # 创建转储定时器线程 + threading.Thread(target=dump_history, daemon=True).start() + global uniqueSession with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile: cfg = yaml.safe_load(ymlfile) @@ -109,6 +143,8 @@ def initBot(chatgpt_inst): uniqueSession = True else: uniqueSession = False + if 'version' in cfg: + version = cfg['version'] if cfg['qqbot']['appid'] != '' or cfg['qqbot']['token'] != '': print("读取QQBot appid token 成功") intents = botpy.Intents(public_guild_messages=True, direct_message=True) @@ -286,7 +322,9 @@ def oper_msg(message, at=False, loop=None): cache_prompt += "Human: "+ qq_msg + "\nAI: " # 请求chatGPT获得结果 try: - chatgpt_res, current_usage_tokens = get_chatGPT_response(cache_prompt) + chatgpt_res="test" + current_usage_tokens = 0 + # chatgpt_res, current_usage_tokens = get_chatGPT_response(cache_prompt) except (PromptExceededError) as e: print("出现token超限, 清空对应缓存") # 超过4097tokens错误,清空缓存 diff --git a/main.py b/main.py index 41bae9374..c6f983038 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,7 @@ import cores.qqbot.core as qqBot from cores.openai.core import ChatGPT import asyncio import yaml +import threading # from cores.database.conn import dbConn def main(): # 读取参数 @@ -15,5 +16,8 @@ def main(): # #执行qqBot # qqBot.initBot(chatgpt, db) qqBot.initBot(chatgpt) + if __name__ == "__main__": + # qqbot_thread = threading.Thread(target=main) + # qqbot_thread.start() main() \ No newline at end of file