378 lines
14 KiB
Python
378 lines
14 KiB
Python
import json
|
|
import traceback
|
|
from datetime import datetime
|
|
from io import BytesIO
|
|
|
|
from quart import request, send_file
|
|
|
|
from astrbot.core import logger
|
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
|
from astrbot.core.db import BaseDatabase
|
|
|
|
from .route import Response, Route, RouteContext
|
|
|
|
|
|
class ConversationRoute(Route):
    """Dashboard HTTP endpoints for listing, inspecting, editing, deleting and
    exporting conversations.

    Every handler returns ``Response().ok(...).__dict__`` or
    ``Response().error(...).__dict__``, except a successful export, which
    returns the file response produced by ``send_file``.
    """

    def __init__(
        self,
        context: RouteContext,
        db_helper: BaseDatabase,
        core_lifecycle: AstrBotCoreLifecycle,
    ) -> None:
        """Build the URL -> (HTTP method, handler) table and register it.

        Args:
            context: Route context forwarded to the ``Route`` base class.
            db_helper: Database handle, stored on the instance.
            core_lifecycle: Core lifecycle object; its ``conversation_manager``
                is used by the handlers for all conversation access.
        """
        super().__init__(context)
        self.routes = {
            "/conversation/list": ("GET", self.list_conversations),
            "/conversation/detail": (
                "POST",
                self.get_conv_detail,
            ),
            "/conversation/update": ("POST", self.upd_conv),
            "/conversation/delete": ("POST", self.del_conv),
            "/conversation/update_history": (
                "POST",
                self.update_history,
            ),
            "/conversation/export": ("POST", self.export_conversations),
        }
        self.db_helper = db_helper
        self.conv_mgr = core_lifecycle.conversation_manager
        self.core_lifecycle = core_lifecycle
        self.register_routes()

    @staticmethod
    async def _json_body() -> dict:
        """Return the request's JSON body as a dict.

        ``request.get_json()`` may yield ``None`` (no JSON body) or a non-object
        JSON value; both are mapped to an empty dict so the handlers' normal
        "missing parameter" checks produce the error response instead of an
        ``AttributeError`` on ``.get``.
        """
        payload = await request.get_json()
        return payload if isinstance(payload, dict) else {}

    @staticmethod
    def _split_csv(raw: str) -> list:
        """Split a comma-separated query value, dropping empty tokens."""
        return [token for token in raw.split(",") if token]

    @staticmethod
    def _conv_key(item) -> tuple:
        """Return ``(user_id, cid)`` from a batch item.

        A non-dict item yields ``(None, None)`` so the caller reports it as a
        failed item rather than aborting the whole batch.
        """
        if not isinstance(item, dict):
            return None, None
        return item.get("user_id"), item.get("cid")

    async def list_conversations(self) -> dict:
        """Return one page of conversations selected by query-string filters.

        Query parameters:
            page: 1-based page number; values below 1 are raised to 1.
            page_size: items per page; values below 1 fall back to 20 and the
                result is capped at 100.
            platforms, message_types, exclude_ids, exclude_platforms:
                comma-separated lists (empty tokens are ignored).
            search: free-text query forwarded unchanged.

        Returns:
            Response dict whose data holds ``conversations`` and a
            ``pagination`` object (``page``, ``page_size``, ``total``,
            ``total_pages``), or an error response dict.
        """
        try:
            args = request.args

            # Paging parameters, clamped to sane bounds.
            page = max(args.get("page", 1, type=int), 1)
            page_size = args.get("page_size", 20, type=int)
            if page_size < 1:
                page_size = 20
            page_size = min(page_size, 100)

            # Filter parameters; list-valued ones arrive comma-separated.
            platform_list = self._split_csv(args.get("platforms", ""))
            message_type_list = self._split_csv(args.get("message_types", ""))
            exclude_id_list = self._split_csv(args.get("exclude_ids", ""))
            exclude_platform_list = self._split_csv(
                args.get("exclude_platforms", ""),
            )
            search_query = args.get("search", "")

            try:
                (
                    conversations,
                    total_count,
                ) = await self.conv_mgr.get_filtered_conversations(
                    page=page,
                    page_size=page_size,
                    platforms=platform_list,
                    message_types=message_type_list,
                    search_query=search_query,
                    exclude_ids=exclude_id_list,
                    exclude_platforms=exclude_platform_list,
                )
            except Exception as e:
                logger.error(f"数据库查询出错: {e!s}\n{traceback.format_exc()}")
                return Response().error(f"数据库查询出错: {e!s}").__dict__

            # Ceiling division; an empty result set still reports one page.
            if total_count > 0:
                total_pages = (total_count + page_size - 1) // page_size
            else:
                total_pages = 1

            result = {
                "conversations": conversations,
                "pagination": {
                    "page": page,
                    "page_size": page_size,
                    "total": total_count,
                    "total_pages": total_pages,
                },
            }
            return Response().ok(result).__dict__

        except Exception as e:
            error_msg = f"获取对话列表失败: {e!s}\n{traceback.format_exc()}"
            logger.error(error_msg)
            return Response().error(f"获取对话列表失败: {e!s}").__dict__

    async def get_conv_detail(self) -> dict:
        """Return the details of one conversation (POST body: ``user_id``, ``cid``).

        Returns:
            Response dict with the conversation's title, persona id, history and
            timestamps, or an error response dict when a parameter is missing,
            the conversation does not exist, or the lookup raises.
        """
        try:
            data = await self._json_body()
            user_id = data.get("user_id")
            cid = data.get("cid")

            if not user_id or not cid:
                return Response().error("缺少必要参数: user_id 和 cid").__dict__

            conversation = await self.conv_mgr.get_conversation(
                unified_msg_origin=user_id,
                conversation_id=cid,
            )
            if not conversation:
                return Response().error("对话不存在").__dict__

            detail = {
                "user_id": user_id,
                "cid": cid,
                "title": conversation.title,
                "persona_id": conversation.persona_id,
                "history": conversation.history,
                "created_at": conversation.created_at,
                "updated_at": conversation.updated_at,
            }
            return Response().ok(detail).__dict__

        except Exception as e:
            logger.error(f"获取对话详情失败: {e!s}\n{traceback.format_exc()}")
            return Response().error(f"获取对话详情失败: {e!s}").__dict__

    async def upd_conv(self) -> dict:
        """Update a conversation's title and/or persona id.

        POST body: ``user_id`` and ``cid`` (required), ``title`` and
        ``persona_id`` (optional). When ``persona_id`` is absent from the body
        the conversation's current persona id is passed through.
        """
        try:
            data = await self._json_body()
            user_id = data.get("user_id")
            cid = data.get("cid")
            title = data.get("title")

            if not user_id or not cid:
                return Response().error("缺少必要参数: user_id 和 cid").__dict__

            conversation = await self.conv_mgr.get_conversation(
                unified_msg_origin=user_id,
                conversation_id=cid,
            )
            if not conversation:
                return Response().error("对话不存在").__dict__

            persona_id = data.get("persona_id", conversation.persona_id)

            # Skip the write only when there is nothing at all to set.
            if title is not None or persona_id is not None:
                await self.conv_mgr.update_conversation(
                    unified_msg_origin=user_id,
                    conversation_id=cid,
                    title=title,
                    persona_id=persona_id,
                )
            return Response().ok({"message": "对话信息更新成功"}).__dict__

        except Exception as e:
            logger.error(f"更新对话信息失败: {e!s}\n{traceback.format_exc()}")
            return Response().error(f"更新对话信息失败: {e!s}").__dict__

    async def del_conv(self) -> dict:
        """Delete one conversation, or several when ``conversations`` is given.

        POST body, single mode: ``user_id`` and ``cid``.
        POST body, batch mode: ``conversations`` — a list of objects each
        carrying ``user_id`` and ``cid``. Batch mode is best-effort: each item
        is attempted independently and failures are reported in
        ``failed_items`` alongside ``deleted_count`` / ``failed_count``.
        """
        try:
            data = await self._json_body()

            # The presence of the key (even if empty) selects batch mode.
            if "conversations" in data:
                conversations = data.get("conversations", [])
                if not conversations:
                    return (
                        Response().error("批量删除时conversations参数不能为空").__dict__
                    )

                deleted_count = 0
                failed_items = []

                for conv in conversations:
                    user_id, cid = self._conv_key(conv)

                    if not user_id or not cid:
                        failed_items.append(
                            f"user_id:{user_id}, cid:{cid} - 缺少必要参数",
                        )
                        continue

                    try:
                        await self.core_lifecycle.conversation_manager.delete_conversation(
                            unified_msg_origin=user_id,
                            conversation_id=cid,
                        )
                        deleted_count += 1
                    except Exception as e:
                        failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}")

                message = f"成功删除 {deleted_count} 个对话"
                if failed_items:
                    message += f",失败 {len(failed_items)} 个"

                summary = {
                    "message": message,
                    "deleted_count": deleted_count,
                    "failed_count": len(failed_items),
                    "failed_items": failed_items,
                }
                return Response().ok(summary).__dict__

            # Single deletion.
            user_id = data.get("user_id")
            cid = data.get("cid")

            if not user_id or not cid:
                return Response().error("缺少必要参数: user_id 和 cid").__dict__

            await self.core_lifecycle.conversation_manager.delete_conversation(
                unified_msg_origin=user_id,
                conversation_id=cid,
            )
            return Response().ok({"message": "对话删除成功"}).__dict__

        except Exception as e:
            logger.error(f"删除对话失败: {e!s}\n{traceback.format_exc()}")
            return Response().error(f"删除对话失败: {e!s}").__dict__

    async def update_history(self) -> dict:
        """Replace the stored history of a conversation.

        POST body: ``user_id``, ``cid`` and ``history``. ``history`` may be a
        JSON array or a string containing valid JSON; a string is decoded
        before being handed to the conversation manager. Any other type is
        rejected with the validation error.
        """
        try:
            data = await self._json_body()
            user_id = data.get("user_id")
            cid = data.get("cid")
            history = data.get("history")

            if not user_id or not cid:
                return Response().error("缺少必要参数: user_id 和 cid").__dict__

            if history is None:
                return Response().error("缺少必要参数: history").__dict__

            # Accept a list as-is, decode a JSON string, reject everything else.
            if isinstance(history, str):
                try:
                    history = json.loads(history)
                except json.JSONDecodeError:
                    return (
                        Response().error("history 必须是有效的 JSON 字符串或数组").__dict__
                    )
            elif not isinstance(history, list):
                return (
                    Response().error("history 必须是有效的 JSON 字符串或数组").__dict__
                )

            conversation = await self.conv_mgr.get_conversation(
                unified_msg_origin=user_id,
                conversation_id=cid,
            )
            if not conversation:
                return Response().error("对话不存在").__dict__

            await self.conv_mgr.update_conversation(
                unified_msg_origin=user_id,
                conversation_id=cid,
                history=history,
            )

            return Response().ok({"message": "对话历史更新成功"}).__dict__

        except Exception as e:
            logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}")
            return Response().error(f"更新对话历史失败: {e!s}").__dict__

    async def export_conversations(self):
        """Export the requested conversations as a JSONL attachment.

        POST body: ``conversations`` — a list of objects each carrying
        ``user_id`` and ``cid``. Each conversation that can be loaded becomes
        one JSON line; items that fail are logged and skipped.

        Returns:
            The ``send_file`` response for an in-memory ``.jsonl`` file, or an
            error response dict when the list is empty, nothing could be
            exported, or an unexpected error occurs.
        """
        try:
            data = await self._json_body()
            conversations_to_export = data.get("conversations", [])

            if not conversations_to_export:
                return Response().error("导出列表不能为空").__dict__

            jsonl_lines = []
            failed_items = []

            for conv_info in conversations_to_export:
                user_id, cid = self._conv_key(conv_info)

                if not user_id or not cid:
                    failed_items.append(
                        f"user_id:{user_id}, cid:{cid} - 缺少必要参数",
                    )
                    continue

                try:
                    conversation = await self.conv_mgr.get_conversation(
                        unified_msg_origin=user_id,
                        conversation_id=cid,
                    )

                    if not conversation:
                        failed_items.append(
                            f"user_id:{user_id}, cid:{cid} - 对话不存在"
                        )
                        continue

                    # history is always a JSON string
                    # (from _convert_conv_from_v2_to_v1).
                    content = json.loads(conversation.history)

                    export_record = {
                        "cid": cid,
                        "user_id": user_id,
                        "platform_id": conversation.platform_id,
                        "title": conversation.title,
                        "persona_id": conversation.persona_id,
                        "created_at": conversation.created_at,
                        "updated_at": conversation.updated_at,
                        "content": content,
                    }

                    # One JSON document per line; keep non-ASCII text readable.
                    jsonl_lines.append(json.dumps(export_record, ensure_ascii=False))

                except Exception as e:
                    failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}")
                    logger.error(
                        f"导出对话失败: user_id={user_id}, cid={cid}, error={e!s}"
                    )

            if not jsonl_lines:
                return Response().error("没有成功导出任何对话").__dict__

            # A freshly constructed BytesIO is already positioned at offset 0.
            file_obj = BytesIO("\n".join(jsonl_lines).encode("utf-8"))

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"astrbot_conversations_export_{timestamp}.jsonl"

            return await send_file(
                file_obj,
                mimetype="application/jsonl",
                as_attachment=True,
                attachment_filename=filename,
            )

        except Exception as e:
            logger.error(f"批量导出对话失败: {e!s}\n{traceback.format_exc()}")
            return Response().error(f"批量导出对话失败: {e!s}").__dict__
|