feat: 支持持久化历史记录
fix: 修复了一些已知问题
This commit is contained in:
+79
-17
@@ -1,24 +1,86 @@
|
||||
import pymysql
|
||||
import sqlite3
|
||||
import yaml
|
||||
|
||||
# TODO: 数据库缓存prompt
|
||||
|
||||
class dbConn():
|
||||
def __init__(self):
|
||||
with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile:
|
||||
cfg = yaml.safe_load(ymlfile)
|
||||
if cfg['database']['host'] != '' or cfg['database']['port'] or cfg['database']['user'] != '' or cfg['database']['password'] != '' or cfg['database']['db'] != '':
|
||||
print("读取数据库配置成功")
|
||||
self.db = pymysql.connect(
|
||||
host=cfg['database']['host'],
|
||||
port=cfg['database']['port'],
|
||||
user=cfg['database']['user'],
|
||||
password=cfg['database']['password'],
|
||||
db=cfg['database']['db'],
|
||||
charset='utf8mb4',
|
||||
)
|
||||
else:
|
||||
raise BaseException("请在config中完善你的数据库配置")
|
||||
# 读取参数,并支持中文
|
||||
conn = sqlite3.connect("data.db")
|
||||
conn.text_factory=str
|
||||
self.conn = conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
CREATE TABLE IF NOT EXISTS tb_session(
|
||||
qq_id VARCHAR(32) PRIMARY KEY,
|
||||
history TEXT
|
||||
)
|
||||
'''
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
def getCursor(self):
|
||||
return self.db.cursor()
|
||||
def insert_session(self, qq_id, history):
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
INSERT INTO tb_session(qq_id, history) VALUES (?, ?)
|
||||
''', (qq_id, history)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def update_session(self, qq_id, history):
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
UPDATE tb_session SET history = ? WHERE qq_id = ?
|
||||
''', (history, qq_id)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def get_session(self, qq_id):
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM tb_session WHERE qq_id = ?
|
||||
''', (qq_id, )
|
||||
)
|
||||
return c.fetchone()
|
||||
|
||||
def get_all_session(self):
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM tb_session
|
||||
'''
|
||||
)
|
||||
return c.fetchall()
|
||||
|
||||
def check_session(self, qq_id):
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM tb_session WHERE qq_id = ?
|
||||
''', (qq_id, )
|
||||
)
|
||||
return c.fetchone() is not None
|
||||
|
||||
def delete_session(self, qq_id):
|
||||
conn = self.conn
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'''
|
||||
DELETE FROM tb_session WHERE qq_id = ?
|
||||
''', (qq_id, )
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def close(self):
|
||||
self.conn.close()
|
||||
|
||||
+41
-3
@@ -8,6 +8,8 @@ import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
import asyncio
|
||||
import time
|
||||
from cores.database.conn import dbConn
|
||||
|
||||
client = ''
|
||||
# executor = ThreadPoolExecutor(max_workers=10)
|
||||
@@ -77,17 +79,46 @@ def toggle_count(at: bool, message):
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
# 转储历史记录的定时器 Soulter
|
||||
def dump_history():
|
||||
time.sleep(10)
|
||||
global session_dict
|
||||
db = dbConn()
|
||||
while True:
|
||||
try:
|
||||
print("转储历史记录...")
|
||||
for key in session_dict:
|
||||
# print("TEST: "+str(db.get_session(key)))
|
||||
data = session_dict[key]
|
||||
data_json = {
|
||||
'data': data
|
||||
}
|
||||
if db.check_session(key):
|
||||
db.update_session(key, json.dumps(data_json))
|
||||
else:
|
||||
db.insert_session(key, json.dumps(data_json))
|
||||
print("转储历史记录完毕")
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
# 每隔10分钟转储一次
|
||||
time.sleep(10)
|
||||
|
||||
def initBot(chatgpt_inst):
|
||||
global chatgpt
|
||||
chatgpt = chatgpt_inst
|
||||
# global db
|
||||
# db = db_inst
|
||||
|
||||
global max_tokens
|
||||
max_tokens = int(chatgpt_inst.getConfigs()['total_tokens_limit'])
|
||||
global gpt_config
|
||||
gpt_config = chatgpt_inst.getConfigs()
|
||||
gpt_config['key'] = "***"
|
||||
global version
|
||||
|
||||
# 读取历史记录 Soulter
|
||||
db1 = dbConn()
|
||||
for session in db1.get_all_session():
|
||||
session_dict[session[0]] = json.loads(session[1])['data']
|
||||
print("历史记录读取完毕")
|
||||
|
||||
# 读统计信息
|
||||
global stat_file
|
||||
@@ -102,6 +133,9 @@ def initBot(chatgpt_inst):
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
# 创建转储定时器线程
|
||||
threading.Thread(target=dump_history, daemon=True).start()
|
||||
|
||||
global uniqueSession
|
||||
with open("./configs/config.yaml", 'r', encoding='utf-8') as ymlfile:
|
||||
cfg = yaml.safe_load(ymlfile)
|
||||
@@ -109,6 +143,8 @@ def initBot(chatgpt_inst):
|
||||
uniqueSession = True
|
||||
else:
|
||||
uniqueSession = False
|
||||
if 'version' in cfg:
|
||||
version = cfg['version']
|
||||
if cfg['qqbot']['appid'] != '' or cfg['qqbot']['token'] != '':
|
||||
print("读取QQBot appid token 成功")
|
||||
intents = botpy.Intents(public_guild_messages=True, direct_message=True)
|
||||
@@ -286,7 +322,9 @@ def oper_msg(message, at=False, loop=None):
|
||||
cache_prompt += "Human: "+ qq_msg + "\nAI: "
|
||||
# 请求chatGPT获得结果
|
||||
try:
|
||||
chatgpt_res, current_usage_tokens = get_chatGPT_response(cache_prompt)
|
||||
chatgpt_res="test"
|
||||
current_usage_tokens = 0
|
||||
# chatgpt_res, current_usage_tokens = get_chatGPT_response(cache_prompt)
|
||||
except (PromptExceededError) as e:
|
||||
print("出现token超限, 清空对应缓存")
|
||||
# 超过4097tokens错误,清空缓存
|
||||
|
||||
@@ -2,6 +2,7 @@ import cores.qqbot.core as qqBot
|
||||
from cores.openai.core import ChatGPT
|
||||
import asyncio
|
||||
import yaml
|
||||
import threading
|
||||
# from cores.database.conn import dbConn
|
||||
def main():
|
||||
# 读取参数
|
||||
@@ -15,5 +16,8 @@ def main():
|
||||
# #执行qqBot
|
||||
# qqBot.initBot(chatgpt, db)
|
||||
qqBot.initBot(chatgpt)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# qqbot_thread = threading.Thread(target=main)
|
||||
# qqbot_thread.start()
|
||||
main()
|
||||
Reference in New Issue
Block a user