style: format codes
This commit is contained in:
+133
-107
@@ -1,13 +1,7 @@
|
||||
import sqlite3
|
||||
import os
|
||||
import time
|
||||
from astrbot.core.db.po import (
|
||||
Platform,
|
||||
Stats,
|
||||
LLMHistory,
|
||||
ATRIVision,
|
||||
Conversation
|
||||
)
|
||||
from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation
|
||||
from . import BaseDatabase
|
||||
from typing import Tuple
|
||||
|
||||
@@ -16,21 +10,21 @@ class SQLiteDatabase(BaseDatabase):
|
||||
def __init__(self, db_path: str) -> None:
|
||||
super().__init__()
|
||||
self.db_path = db_path
|
||||
|
||||
|
||||
with open(os.path.dirname(__file__) + "/sqlite_init.sql", "r") as f:
|
||||
sql = f.read()
|
||||
|
||||
|
||||
# 初始化数据库
|
||||
self.conn = self._get_conn(self.db_path)
|
||||
c = self.conn.cursor()
|
||||
c.executescript(sql)
|
||||
self.conn.commit()
|
||||
|
||||
|
||||
# 检查 webchat_conversation 的 title 字段是否存在
|
||||
c.execute(
|
||||
'''
|
||||
"""
|
||||
PRAGMA table_info(webchat_conversation)
|
||||
'''
|
||||
"""
|
||||
)
|
||||
res = c.fetchall()
|
||||
has_title = False
|
||||
@@ -42,26 +36,26 @@ class SQLiteDatabase(BaseDatabase):
|
||||
has_persona_id = True
|
||||
if not has_title:
|
||||
c.execute(
|
||||
'''
|
||||
"""
|
||||
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
|
||||
'''
|
||||
"""
|
||||
)
|
||||
self.conn.commit()
|
||||
if not has_persona_id:
|
||||
c.execute(
|
||||
'''
|
||||
"""
|
||||
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
|
||||
'''
|
||||
"""
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
|
||||
c.close()
|
||||
|
||||
|
||||
def _get_conn(self, db_path: str) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.text_factory = str
|
||||
return conn
|
||||
|
||||
|
||||
def _exec_sql(self, sql: str, params: Tuple = None):
|
||||
conn = self.conn
|
||||
try:
|
||||
@@ -69,22 +63,23 @@ class SQLiteDatabase(BaseDatabase):
|
||||
except sqlite3.ProgrammingError:
|
||||
conn = self._get_conn(self.db_path)
|
||||
c = conn.cursor()
|
||||
|
||||
|
||||
if params:
|
||||
c.execute(sql, params)
|
||||
c.close()
|
||||
else:
|
||||
c.execute(sql)
|
||||
c.close()
|
||||
|
||||
|
||||
conn.commit()
|
||||
|
||||
|
||||
def insert_platform_metrics(self, metrics: dict):
|
||||
for k, v in metrics.items():
|
||||
self._exec_sql(
|
||||
'''
|
||||
"""
|
||||
INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?)
|
||||
''', (k, v, int(time.time()))
|
||||
""",
|
||||
(k, v, int(time.time())),
|
||||
)
|
||||
|
||||
def insert_plugin_metrics(self, metrics: dict):
|
||||
@@ -93,40 +88,46 @@ class SQLiteDatabase(BaseDatabase):
|
||||
def insert_command_metrics(self, metrics: dict):
|
||||
for k, v in metrics.items():
|
||||
self._exec_sql(
|
||||
'''
|
||||
"""
|
||||
INSERT INTO command(name, count, timestamp) VALUES (?, ?, ?)
|
||||
''', (k, v, int(time.time()))
|
||||
""",
|
||||
(k, v, int(time.time())),
|
||||
)
|
||||
|
||||
def insert_llm_metrics(self, metrics: dict):
|
||||
for k, v in metrics.items():
|
||||
self._exec_sql(
|
||||
'''
|
||||
"""
|
||||
INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?)
|
||||
''', (k, v, int(time.time()))
|
||||
""",
|
||||
(k, v, int(time.time())),
|
||||
)
|
||||
|
||||
def update_llm_history(self, session_id: str, content: str, provider_type: str):
|
||||
res = self.get_llm_history(session_id, provider_type)
|
||||
if res:
|
||||
self._exec_sql(
|
||||
'''
|
||||
"""
|
||||
UPDATE llm_history SET content = ? WHERE session_id = ? AND provider_type = ?
|
||||
''', (content, session_id, provider_type)
|
||||
""",
|
||||
(content, session_id, provider_type),
|
||||
)
|
||||
else:
|
||||
self._exec_sql(
|
||||
'''
|
||||
"""
|
||||
INSERT INTO llm_history(provider_type, session_id, content) VALUES (?, ?, ?)
|
||||
''', (provider_type, session_id, content)
|
||||
""",
|
||||
(provider_type, session_id, content),
|
||||
)
|
||||
|
||||
def get_llm_history(self, session_id: str = None, provider_type: str = None) -> Tuple:
|
||||
def get_llm_history(
|
||||
self, session_id: str = None, provider_type: str = None
|
||||
) -> Tuple:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
|
||||
where_clause = ""
|
||||
if session_id or provider_type:
|
||||
where_clause += " WHERE "
|
||||
@@ -138,11 +139,12 @@ class SQLiteDatabase(BaseDatabase):
|
||||
if has:
|
||||
where_clause += " AND "
|
||||
where_clause += f"provider_type = '{provider_type}'"
|
||||
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
"""
|
||||
SELECT * FROM llm_history
|
||||
''' + where_clause
|
||||
"""
|
||||
+ where_clause
|
||||
)
|
||||
res = c.fetchall()
|
||||
histories = []
|
||||
@@ -152,129 +154,134 @@ class SQLiteDatabase(BaseDatabase):
|
||||
return histories
|
||||
|
||||
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
'''获取 offset_sec 秒前到现在的基础统计数据'''
|
||||
"""获取 offset_sec 秒前到现在的基础统计数据"""
|
||||
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
|
||||
|
||||
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
"""
|
||||
SELECT * FROM platform
|
||||
''' + where_clause
|
||||
"""
|
||||
+ where_clause
|
||||
)
|
||||
|
||||
|
||||
platform = []
|
||||
for row in c.fetchall():
|
||||
platform.append(Platform(*row))
|
||||
|
||||
|
||||
# c.execute(
|
||||
# '''
|
||||
# SELECT * FROM command
|
||||
# ''' + where_clause
|
||||
# )
|
||||
|
||||
|
||||
# command = []
|
||||
# for row in c.fetchall():
|
||||
# command.append(Command(*row))
|
||||
|
||||
|
||||
# c.execute(
|
||||
# '''
|
||||
# SELECT * FROM llm
|
||||
# ''' + where_clause
|
||||
# )
|
||||
|
||||
|
||||
# llm = []
|
||||
# for row in c.fetchall():
|
||||
# llm.append(Provider(*row))
|
||||
|
||||
|
||||
c.close()
|
||||
|
||||
|
||||
return Stats(platform, [], [])
|
||||
|
||||
|
||||
def get_total_message_count(self) -> int:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
"""
|
||||
SELECT SUM(count) FROM platform
|
||||
'''
|
||||
"""
|
||||
)
|
||||
res = c.fetchone()
|
||||
c.close()
|
||||
return res[0]
|
||||
|
||||
|
||||
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
'''获取 offset_sec 秒前到现在的基础统计数据(合并)'''
|
||||
"""获取 offset_sec 秒前到现在的基础统计数据(合并)"""
|
||||
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
|
||||
|
||||
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
"""
|
||||
SELECT name, SUM(count), timestamp FROM platform
|
||||
''' + where_clause + " GROUP BY name"
|
||||
"""
|
||||
+ where_clause
|
||||
+ " GROUP BY name"
|
||||
)
|
||||
|
||||
|
||||
platform = []
|
||||
for row in c.fetchall():
|
||||
platform.append(Platform(*row))
|
||||
|
||||
|
||||
c.close()
|
||||
|
||||
|
||||
return Stats(platform, [], [])
|
||||
|
||||
|
||||
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
"""
|
||||
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||
''', (user_id, cid)
|
||||
""",
|
||||
(user_id, cid),
|
||||
)
|
||||
|
||||
|
||||
res = c.fetchone()
|
||||
c.close()
|
||||
|
||||
|
||||
if not res:
|
||||
return
|
||||
|
||||
|
||||
return Conversation(*res)
|
||||
|
||||
|
||||
def new_conversation(self, user_id: str, cid: str):
|
||||
history = "[]"
|
||||
updated_at = int(time.time())
|
||||
created_at = updated_at
|
||||
self._exec_sql(
|
||||
'''
|
||||
"""
|
||||
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
|
||||
''', (user_id, cid, history, updated_at, created_at)
|
||||
""",
|
||||
(user_id, cid, history, updated_at, created_at),
|
||||
)
|
||||
|
||||
|
||||
def get_conversations(self, user_id: str) -> Tuple:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
"""
|
||||
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
|
||||
''', (user_id,)
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
|
||||
|
||||
res = c.fetchall()
|
||||
c.close()
|
||||
conversations = []
|
||||
@@ -284,82 +291,101 @@ class SQLiteDatabase(BaseDatabase):
|
||||
updated_at = row[2]
|
||||
title = row[3]
|
||||
persona_id = row[4]
|
||||
conversations.append(Conversation("", cid, '[]', created_at, updated_at, title, persona_id))
|
||||
conversations.append(
|
||||
Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
|
||||
)
|
||||
return conversations
|
||||
|
||||
|
||||
def update_conversation(self, user_id: str, cid: str, history: str):
|
||||
'''更新对话,并且同时更新时间'''
|
||||
"""更新对话,并且同时更新时间"""
|
||||
updated_at = int(time.time())
|
||||
self._exec_sql(
|
||||
'''
|
||||
"""
|
||||
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
|
||||
''', (history, updated_at, user_id, cid)
|
||||
""",
|
||||
(history, updated_at, user_id, cid),
|
||||
)
|
||||
|
||||
|
||||
|
||||
def update_conversation_title(self, user_id: str, cid: str, title: str):
|
||||
self._exec_sql(
|
||||
'''
|
||||
"""
|
||||
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
|
||||
''', (title, user_id, cid)
|
||||
""",
|
||||
(title, user_id, cid),
|
||||
)
|
||||
|
||||
|
||||
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||
self._exec_sql(
|
||||
'''
|
||||
"""
|
||||
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
|
||||
''', (persona_id, user_id, cid)
|
||||
""",
|
||||
(persona_id, user_id, cid),
|
||||
)
|
||||
|
||||
|
||||
def delete_conversation(self, user_id: str, cid: str):
|
||||
self._exec_sql(
|
||||
'''
|
||||
"""
|
||||
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||
''', (user_id, cid)
|
||||
""",
|
||||
(user_id, cid),
|
||||
)
|
||||
|
||||
def insert_atri_vision_data(self, vision: ATRIVision):
|
||||
ts = int(time.time())
|
||||
keywords = ",".join(vision.keywords)
|
||||
self._exec_sql(
|
||||
'''
|
||||
"""
|
||||
INSERT INTO atri_vision(id, url_or_path, caption, is_meme, keywords, platform_name, session_id, sender_nickname, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (vision.id, vision.url_or_path, vision.caption, vision.is_meme, keywords, vision.platform_name, vision.session_id, vision.sender_nickname, ts)
|
||||
""",
|
||||
(
|
||||
vision.id,
|
||||
vision.url_or_path,
|
||||
vision.caption,
|
||||
vision.is_meme,
|
||||
keywords,
|
||||
vision.platform_name,
|
||||
vision.session_id,
|
||||
vision.sender_nickname,
|
||||
ts,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_atri_vision_data(self) -> Tuple:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
"""
|
||||
SELECT * FROM atri_vision
|
||||
'''
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
res = c.fetchall()
|
||||
visions = []
|
||||
for row in res:
|
||||
visions.append(ATRIVision(*row))
|
||||
c.close()
|
||||
return visions
|
||||
|
||||
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
|
||||
|
||||
def get_atri_vision_data_by_path_or_id(
|
||||
self, url_or_path: str, id: str
|
||||
) -> ATRIVision:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
"""
|
||||
SELECT * FROM atri_vision WHERE url_or_path = ? OR id = ?
|
||||
''', (url_or_path, id)
|
||||
""",
|
||||
(url_or_path, id),
|
||||
)
|
||||
|
||||
|
||||
res = c.fetchone()
|
||||
c.close()
|
||||
if res:
|
||||
return ATRIVision(*res)
|
||||
return None
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user