Files
AstrBot/astrbot/core/db/migration/sqlite_v3.py
T
Dt8333 f624971613 chore: fix bunches of type checking errors (#3213)
* chore(core.utils): 🚨 修正错误Lint

* chore(core.provider): 🚨 修复基类错误Lint

* chore(core.utils): 补全session_get()的重载

* chore(core.provider): 🚨 修正实现错误Lint

* chore(core.platform): 🚨 修正platform基类和webchat的错误Lint

* chore(core.platform): 修正错误实现Lint

* fix(core.provider): 修复循环调用和错误assert

* chore(core.platform): 修复部分实现Lint

* chore(core.provider): 补充Dify.text_chat_stream的参数类型

* chore(core.pipeline): 🚨 修复错误Lint

* fix(core.slack): 补充遗漏导入

* chore(core.utils): 修复错误的session_get声明

* chore(core.platform): 移除Lark adapter import中的wildcard

* chore(core.db): 修复声明和部分逻辑

* chore(core.db): 添加typings,使faiss参数能被正确识别。

* chore(core): 修复声明

* chore(core): 修改声明

* chore: 补充faiss声明

* chore(dashboard): 修改实现,减少报错

* chore(package): 修改部分声明与实现,减少报错

* chore(core): 添加Handler的overload,以去除部分assert同时通过类型检查

* chore(core.pipeline): 修改Pipeline Scheduler的execute,将判断属性改为判断类型,通过静态类型检查

* chore(core.config): 添加类型标注,通过类型检查

* chore(core.message): 为File._download_file添加检查,通过类型检查

* fix: 将断言改为条件判断以实现优雅关闭的容错性

* refactor: 移除 discord 客户端中的 assert,改用 if None 判断并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: DiscordPlatformAdapter 对 self.client.user 为 None 做日志并返回,移除断言

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 增强 Lark 相关空值/异常检查并完善日志输出

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将断言替换为条件检查并加入日志与错误处理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* chore: 移除LLM生成的无用注释

* refactor: 使用 File.get_file 替换下载逻辑并移除 assert,提供默认 filename

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: Slack Socket 未初始化抛出运行时异常,图片 URL 判空改为非空判断

* refactor: 将 WeChatPadProAdapter 的断言改为空值判断并添加日志

* refactor: 使用 isinstance 替代断言实现类型判断,便于静态检查

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 去除cast,直接使用字段与字典访问,修正端口解析

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 使用 match-case 重构 ProviderManager 加载并通过类型检查抛出 TypeError

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: group_name_display 时若 group 对象为空则记录错误并返回

* fix: 将 _get_current_persona_id 的 assert 替换成 if guard 并返回 None

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 优化插件目录存在性检查及图片URL非空验证,更新JSON排序配置

* fix: 将 datetime_str 的 assert 替换为显式检查并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改为运行时检查并在找不到调度器时跳过

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改用 isinstance 检查 FaissVecDB 并警告

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 删除 typing.cast 导入,并在获取文件绝对路径前校验 file_

* refactor: 移除 typing.cast,简化内容安全检查调用

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将 PlatformMetadata.id 设为必填并在注册时传入 id,移除 cast

* refactor: 移除 cast,改用 HasInitialize 与 isinstance 进行初始化

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 ProviderManager.initialize 增加ID类型判断,避免 None 导致 get 失败

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 为 OTTSProvider 与 AzureNativeProvider 引入 _client 与 client 属性改进上下文管理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 Whisper 自托管源添加模型未初始化校验并直接调用 transcribe

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除未使用的 cast 导入并简化 platform_name 赋值

* refactor: 引入 cast 并对 id 使用 cast(str, ...) 提升类型安全

* fix: 将 _id_to_sid 返回改为 str,空值返回空串;对 id 与 message_id 使用 cast

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 重构 Discord 处理逻辑:强制 类型转换、优先斜杠指令并优化提及判断

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 统一对 id 获取执行 cast,并在微信消息解析失败时抛错

* Revert "fix: 去除cast,直接使用字段与字典访问,修正端口解析"

This reverts commit 1cbfdf9d1b.

* fix: 百炼 Rerank 会话关闭时返回空结果;初始化 request.prompt 避免空值拼接

* fix: 统一处理搜索结果链接为字符串,新增 _get_url 助手并适配 Bing/Sogo

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 调整 call_handler 泛型、Discord 通道注解及 FishAudioTTS API 请求类型

* refactor: 使用 col(...) 替代列引用并对结果进行 CursorResult 强转

* chore: ruff format

---------

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
2025-12-09 14:13:47 +08:00

500 lines
15 KiB
Python

import sqlite3
import time
from dataclasses import dataclass
from typing import Any
from astrbot.core.db.po import Platform, Stats
@dataclass
class Conversation:
"""LLM 对话存储
对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
"""
user_id: str
cid: str
history: str = ""
"""字符串格式的列表。"""
created_at: int = 0
updated_at: int = 0
title: str = ""
persona_id: str = ""
INIT_SQL = """
CREATE TABLE IF NOT EXISTS platform(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS plugin(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS command(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm_history(
provider_type VARCHAR(32),
session_id VARCHAR(32),
content TEXT
);
-- ATRI
CREATE TABLE IF NOT EXISTS atri_vision(
id TEXT,
url_or_path TEXT,
caption TEXT,
is_meme BOOLEAN,
keywords TEXT,
platform_name VARCHAR(32),
session_id VARCHAR(32),
sender_nickname VARCHAR(32),
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT, -- 会话 id
cid TEXT, -- 对话 id
history TEXT,
created_at INTEGER,
updated_at INTEGER,
title TEXT,
persona_id TEXT
);
PRAGMA encoding = 'UTF-8';
"""
class SQLiteDatabase:
def __init__(self, db_path: str) -> None:
super().__init__()
self.db_path = db_path
sql = INIT_SQL
# 初始化数据库
self.conn = self._get_conn(self.db_path)
c = self.conn.cursor()
c.executescript(sql)
self.conn.commit()
# 检查 webchat_conversation 的 title 字段是否存在
c.execute(
"""
PRAGMA table_info(webchat_conversation)
""",
)
res = c.fetchall()
has_title = False
has_persona_id = False
for row in res:
if row[1] == "title":
has_title = True
if row[1] == "persona_id":
has_persona_id = True
if not has_title:
c.execute(
"""
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
""",
)
self.conn.commit()
if not has_persona_id:
c.execute(
"""
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
""",
)
self.conn.commit()
c.close()
def _get_conn(self, db_path: str) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
conn.text_factory = str
return conn
def _exec_sql(self, sql: str, params: tuple | None = None):
conn = self.conn
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
conn = self._get_conn(self.db_path)
c = conn.cursor()
if params:
c.execute(sql, params)
c.close()
else:
c.execute(sql)
c.close()
conn.commit()
def insert_platform_metrics(self, metrics: dict):
for k, v in metrics.items():
self._exec_sql(
"""
INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?)
""",
(k, v, int(time.time())),
)
def insert_llm_metrics(self, metrics: dict):
for k, v in metrics.items():
self._exec_sql(
"""
INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?)
""",
(k, v, int(time.time())),
)
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取 offset_sec 秒前到现在的基础统计数据"""
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT * FROM platform
"""
+ where_clause,
)
platform = []
for row in c.fetchall():
platform.append(Platform(*row))
c.close()
return Stats(platform=platform)
def get_total_message_count(self) -> int:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT SUM(count) FROM platform
""",
)
res = c.fetchone()
c.close()
return res[0]
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取 offset_sec 秒前到现在的基础统计数据(合并)"""
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT name, SUM(count), timestamp FROM platform
"""
+ where_clause
+ " GROUP BY name",
)
platform = []
for row in c.fetchall():
platform.append(Platform(*row))
c.close()
return Stats(platform)
def get_conversation_by_user_id(
self, user_id: str, cid: str
) -> Conversation | None:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
""",
(user_id, cid),
)
res = c.fetchone()
c.close()
if not res:
return None
return Conversation(*res)
def new_conversation(self, user_id: str, cid: str):
history = "[]"
updated_at = int(time.time())
created_at = updated_at
self._exec_sql(
"""
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
""",
(user_id, cid, history, updated_at, created_at),
)
def get_conversations(self, user_id: str) -> list[Conversation]:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
""",
(user_id,),
)
res = c.fetchall()
c.close()
conversations = []
for row in res:
cid = row[0]
created_at = row[1]
updated_at = row[2]
title = row[3]
persona_id = row[4]
conversations.append(
Conversation("", cid, "[]", created_at, updated_at, title, persona_id),
)
return conversations
def update_conversation(self, user_id: str, cid: str, history: str):
"""更新对话,并且同时更新时间"""
updated_at = int(time.time())
self._exec_sql(
"""
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
""",
(history, updated_at, user_id, cid),
)
def update_conversation_title(self, user_id: str, cid: str, title: str):
self._exec_sql(
"""
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
""",
(title, user_id, cid),
)
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
self._exec_sql(
"""
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
""",
(persona_id, user_id, cid),
)
def delete_conversation(self, user_id: str, cid: str):
self._exec_sql(
"""
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
""",
(user_id, cid),
)
def get_all_conversations(
self,
page: int = 1,
page_size: int = 20,
) -> tuple[list[dict[str, Any]], int]:
"""获取所有对话,支持分页,按更新时间降序排序"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 获取总记录数
c.execute("""
SELECT COUNT(*) FROM webchat_conversation
""")
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 获取分页数据,按更新时间降序排序
c.execute(
"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
""",
(page_size, offset),
)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型且至少有8个字符,否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
},
)
return conversations, total_count
except Exception as _:
# 返回空列表和0,确保即使出错也有有效的返回值
return [], 0
finally:
c.close()
def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: list[str] | None = None,
message_types: list[str] | None = None,
search_query: str | None = None,
exclude_ids: list[str] | None = None,
exclude_platforms: list[str] | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""获取筛选后的对话列表"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 构建查询条件
where_clauses = []
params = []
# 平台筛选
if platforms and len(platforms) > 0:
platform_conditions = []
for platform in platforms:
platform_conditions.append("user_id LIKE ?")
params.append(f"{platform}:%")
if platform_conditions:
where_clauses.append(f"({' OR '.join(platform_conditions)})")
# 消息类型筛选
if message_types and len(message_types) > 0:
message_type_conditions = []
for msg_type in message_types:
message_type_conditions.append("user_id LIKE ?")
params.append(f"%:{msg_type}:%")
if message_type_conditions:
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
# 搜索关键词
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
where_clauses.append(
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)",
)
search_param = f"%{search_query}%"
params.extend([search_param, search_param, search_param, search_param])
# 排除特定用户ID
if exclude_ids and len(exclude_ids) > 0:
for exclude_id in exclude_ids:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_id}%")
# 排除特定平台
if exclude_platforms and len(exclude_platforms) > 0:
for exclude_platform in exclude_platforms:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_platform}:%")
# 构建完整的 WHERE 子句
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
# 构建计数查询
count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
# 获取总记录数
c.execute(count_sql, params)
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 构建分页数据查询
data_sql = f"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
{where_sql}
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
"""
query_params = params + [page_size, offset]
# 获取分页数据
c.execute(data_sql, query_params)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型,否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
},
)
return conversations, total_count
except Exception as _:
# 返回空列表和0,确保即使出错也有有效的返回值
return [], 0
finally:
c.close()