feat: 支持持久化历史记录

fix: 修复了一些已知问题
This commit is contained in:
Soulter
2022-12-16 23:13:15 +08:00
parent a5fbc82e11
commit 51ee4fc9b3
3 changed files with 124 additions and 20 deletions
+79 -17
View File
@@ -1,24 +1,86 @@
import pymysql
import sqlite3
import yaml
# TODO: 数据库缓存prompt
class dbConn():
def __init__(self):
with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile:
cfg = yaml.safe_load(ymlfile)
if cfg['database']['host'] != '' or cfg['database']['port'] or cfg['database']['user'] != '' or cfg['database']['password'] != '' or cfg['database']['db'] != '':
print("读取数据库配置成功")
self.db = pymysql.connect(
host=cfg['database']['host'],
port=cfg['database']['port'],
user=cfg['database']['user'],
password=cfg['database']['password'],
db=cfg['database']['db'],
charset='utf8mb4',
)
else:
raise BaseException("请在config中完善你的数据库配置")
# 读取参数,并支持中文
conn = sqlite3.connect("data.db")
conn.text_factory=str
self.conn = conn
c = conn.cursor()
c.execute(
'''
CREATE TABLE IF NOT EXISTS tb_session(
qq_id VARCHAR(32) PRIMARY KEY,
history TEXT
)
'''
)
conn.commit()
def getCursor(self):
return self.db.cursor()
def insert_session(self, qq_id, history):
conn = self.conn
c = conn.cursor()
c.execute(
'''
INSERT INTO tb_session(qq_id, history) VALUES (?, ?)
''', (qq_id, history)
)
conn.commit()
def update_session(self, qq_id, history):
conn = self.conn
c = conn.cursor()
c.execute(
'''
UPDATE tb_session SET history = ? WHERE qq_id = ?
''', (history, qq_id)
)
conn.commit()
def get_session(self, qq_id):
conn = self.conn
c = conn.cursor()
c.execute(
'''
SELECT * FROM tb_session WHERE qq_id = ?
''', (qq_id, )
)
return c.fetchone()
def get_all_session(self):
conn = self.conn
c = conn.cursor()
c.execute(
'''
SELECT * FROM tb_session
'''
)
return c.fetchall()
def check_session(self, qq_id):
conn = self.conn
c = conn.cursor()
c.execute(
'''
SELECT * FROM tb_session WHERE qq_id = ?
''', (qq_id, )
)
return c.fetchone() is not None
def delete_session(self, qq_id):
conn = self.conn
c = conn.cursor()
c.execute(
'''
DELETE FROM tb_session WHERE qq_id = ?
''', (qq_id, )
)
conn.commit()
def close(self):
self.conn.close()
+41 -3
View File
@@ -8,6 +8,8 @@ import json
from concurrent.futures import ThreadPoolExecutor
import threading
import asyncio
import time
from cores.database.conn import dbConn
client = ''
# executor = ThreadPoolExecutor(max_workers=10)
@@ -77,17 +79,46 @@ def toggle_count(at: bool, message):
except BaseException:
pass
# 转储历史记录的定时器 Soulter
def dump_history():
time.sleep(10)
global session_dict
db = dbConn()
while True:
try:
print("转储历史记录...")
for key in session_dict:
# print("TEST: "+str(db.get_session(key)))
data = session_dict[key]
data_json = {
'data': data
}
if db.check_session(key):
db.update_session(key, json.dumps(data_json))
else:
db.insert_session(key, json.dumps(data_json))
print("转储历史记录完毕")
except BaseException as e:
print(e)
# 每隔10分钟转储一次
time.sleep(10)
def initBot(chatgpt_inst):
global chatgpt
chatgpt = chatgpt_inst
# global db
# db = db_inst
global max_tokens
max_tokens = int(chatgpt_inst.getConfigs()['total_tokens_limit'])
global gpt_config
gpt_config = chatgpt_inst.getConfigs()
gpt_config['key'] = "***"
global version
# 读取历史记录 Soulter
db1 = dbConn()
for session in db1.get_all_session():
session_dict[session[0]] = json.loads(session[1])['data']
print("历史记录读取完毕")
# 读统计信息
global stat_file
@@ -102,6 +133,9 @@ def initBot(chatgpt_inst):
except BaseException:
pass
# 创建转储定时器线程
threading.Thread(target=dump_history, daemon=True).start()
global uniqueSession
with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile:
cfg = yaml.safe_load(ymlfile)
@@ -109,6 +143,8 @@ def initBot(chatgpt_inst):
uniqueSession = True
else:
uniqueSession = False
if 'version' in cfg:
version = cfg['version']
if cfg['qqbot']['appid'] != '' or cfg['qqbot']['token'] != '':
print("读取QQBot appid token 成功")
intents = botpy.Intents(public_guild_messages=True, direct_message=True)
@@ -286,7 +322,9 @@ def oper_msg(message, at=False, loop=None):
cache_prompt += "Human: "+ qq_msg + "\nAI: "
# 请求chatGPT获得结果
try:
chatgpt_res, current_usage_tokens = get_chatGPT_response(cache_prompt)
chatgpt_res="test"
current_usage_tokens = 0
# chatgpt_res, current_usage_tokens = get_chatGPT_response(cache_prompt)
except (PromptExceededError) as e:
print("出现token超限, 清空对应缓存")
# 超过4097tokens错误,清空缓存
+4
View File
@@ -2,6 +2,7 @@ import cores.qqbot.core as qqBot
from cores.openai.core import ChatGPT
import asyncio
import yaml
import threading
# from cores.database.conn import dbConn
def main():
# 读取参数
@@ -15,5 +16,8 @@ def main():
# #执行qqBot
# qqBot.initBot(chatgpt, db)
qqBot.initBot(chatgpt)
if __name__ == "__main__":
# qqbot_thread = threading.Thread(target=main)
# qqbot_thread.start()
main()