7dd95d8a59
* chore: auto fix by ruff * refactor: 统一修正返回类型注解为 None/bool 以匹配实现 * refactor: 将 _get_next_page 改为异步并移除多余的请求错误抛出 * refactor: 将 get_client 的返回类型改为 object * style: 为 LarkMessageEvent 的相关方法添加返回类型注解 None --------- Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
416 lines
16 KiB
Python
416 lines
16 KiB
Python
"""AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库.
|
|
|
|
在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话,
|
|
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
|
|
"""
|
|
|
|
import json
|
|
from collections.abc import Awaitable, Callable
|
|
|
|
from astrbot.core import sp
|
|
from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment
|
|
from astrbot.core.db import BaseDatabase
|
|
from astrbot.core.db.po import Conversation, ConversationV2
|
|
|
|
|
|
class ConversationManager:
|
|
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
|
|
|
|
def __init__(self, db_helper: BaseDatabase) -> None:
|
|
self.session_conversations: dict[str, str] = {}
|
|
self.db = db_helper
|
|
self.save_interval = 60 # 每 60 秒保存一次
|
|
|
|
# 会话删除回调函数列表(用于级联清理,如知识库配置)
|
|
self._on_session_deleted_callbacks: list[Callable[[str], Awaitable[None]]] = []
|
|
|
|
def register_on_session_deleted(
|
|
self,
|
|
callback: Callable[[str], Awaitable[None]],
|
|
) -> None:
|
|
"""注册会话删除回调函数.
|
|
|
|
其他模块可以注册回调来响应会话删除事件,实现级联清理。
|
|
例如:知识库模块可以注册回调来清理会话的知识库配置。
|
|
|
|
Args:
|
|
callback: 回调函数,接收会话ID (unified_msg_origin) 作为参数
|
|
|
|
"""
|
|
self._on_session_deleted_callbacks.append(callback)
|
|
|
|
async def _trigger_session_deleted(self, unified_msg_origin: str) -> None:
|
|
"""触发会话删除回调.
|
|
|
|
Args:
|
|
unified_msg_origin: 会话ID
|
|
|
|
"""
|
|
for callback in self._on_session_deleted_callbacks:
|
|
try:
|
|
await callback(unified_msg_origin)
|
|
except Exception as e:
|
|
from astrbot.core import logger
|
|
|
|
logger.error(
|
|
f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}",
|
|
)
|
|
|
|
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
|
|
"""将 ConversationV2 对象转换为 Conversation 对象"""
|
|
created_at = int(conv_v2.created_at.timestamp())
|
|
updated_at = int(conv_v2.updated_at.timestamp())
|
|
return Conversation(
|
|
platform_id=conv_v2.platform_id,
|
|
user_id=conv_v2.user_id,
|
|
cid=conv_v2.conversation_id,
|
|
history=json.dumps(conv_v2.content or []),
|
|
title=conv_v2.title,
|
|
persona_id=conv_v2.persona_id,
|
|
created_at=created_at,
|
|
updated_at=updated_at,
|
|
token_usage=conv_v2.token_usage,
|
|
)
|
|
|
|
async def new_conversation(
|
|
self,
|
|
unified_msg_origin: str,
|
|
platform_id: str | None = None,
|
|
content: list[dict] | None = None,
|
|
title: str | None = None,
|
|
persona_id: str | None = None,
|
|
) -> str:
|
|
"""新建对话,并将当前会话的对话转移到新对话.
|
|
|
|
Args:
|
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
|
Returns:
|
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
|
|
|
"""
|
|
if not platform_id:
|
|
# 如果没有提供 platform_id,则从 unified_msg_origin 中解析
|
|
parts = unified_msg_origin.split(":")
|
|
if len(parts) >= 3:
|
|
platform_id = parts[0]
|
|
if not platform_id:
|
|
platform_id = "unknown"
|
|
conv = await self.db.create_conversation(
|
|
user_id=unified_msg_origin,
|
|
platform_id=platform_id,
|
|
content=content,
|
|
title=title,
|
|
persona_id=persona_id,
|
|
)
|
|
self.session_conversations[unified_msg_origin] = conv.conversation_id
|
|
await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id)
|
|
return conv.conversation_id
|
|
|
|
async def switch_conversation(
|
|
self, unified_msg_origin: str, conversation_id: str
|
|
) -> None:
|
|
"""切换会话的对话
|
|
|
|
Args:
|
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
|
|
|
"""
|
|
self.session_conversations[unified_msg_origin] = conversation_id
|
|
await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id)
|
|
|
|
async def delete_conversation(
|
|
self,
|
|
unified_msg_origin: str,
|
|
conversation_id: str | None = None,
|
|
) -> None:
|
|
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
|
|
|
|
Args:
|
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
|
|
|
"""
|
|
if not conversation_id:
|
|
conversation_id = self.session_conversations.get(unified_msg_origin)
|
|
if conversation_id:
|
|
await self.db.delete_conversation(cid=conversation_id)
|
|
curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
|
|
if curr_cid == conversation_id:
|
|
self.session_conversations.pop(unified_msg_origin, None)
|
|
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
|
|
|
async def delete_conversations_by_user_id(self, unified_msg_origin: str) -> None:
|
|
"""删除会话的所有对话
|
|
|
|
Args:
|
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
|
|
|
"""
|
|
await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
|
|
self.session_conversations.pop(unified_msg_origin, None)
|
|
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
|
|
|
# 触发会话删除回调(级联清理)
|
|
await self._trigger_session_deleted(unified_msg_origin)
|
|
|
|
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
|
|
"""获取会话当前的对话 ID
|
|
|
|
Args:
|
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
|
Returns:
|
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
|
|
|
"""
|
|
ret = self.session_conversations.get(unified_msg_origin, None)
|
|
if not ret:
|
|
ret = await sp.session_get(unified_msg_origin, "sel_conv_id", None)
|
|
if ret:
|
|
self.session_conversations[unified_msg_origin] = ret
|
|
return ret
|
|
|
|
async def get_conversation(
|
|
self,
|
|
unified_msg_origin: str,
|
|
conversation_id: str,
|
|
create_if_not_exists: bool = False,
|
|
) -> Conversation | None:
|
|
"""获取会话的对话.
|
|
|
|
Args:
|
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
|
create_if_not_exists (bool): 如果对话不存在,是否创建一个新的对话
|
|
Returns:
|
|
conversation (Conversation): 对话对象
|
|
|
|
"""
|
|
conv = await self.db.get_conversation_by_id(cid=conversation_id)
|
|
if not conv and create_if_not_exists:
|
|
# 如果对话不存在且需要创建,则新建一个对话
|
|
conversation_id = await self.new_conversation(unified_msg_origin)
|
|
conv = await self.db.get_conversation_by_id(cid=conversation_id)
|
|
conv_res = None
|
|
if conv:
|
|
conv_res = self._convert_conv_from_v2_to_v1(conv)
|
|
return conv_res
|
|
|
|
async def get_conversations(
|
|
self,
|
|
unified_msg_origin: str | None = None,
|
|
platform_id: str | None = None,
|
|
) -> list[Conversation]:
|
|
"""获取对话列表.
|
|
|
|
Args:
|
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id,可选
|
|
platform_id (str): 平台 ID, 可选参数, 用于过滤对话
|
|
Returns:
|
|
conversations (List[Conversation]): 对话对象列表
|
|
|
|
"""
|
|
convs = await self.db.get_conversations(
|
|
user_id=unified_msg_origin,
|
|
platform_id=platform_id,
|
|
)
|
|
convs_res = []
|
|
for conv in convs:
|
|
conv_res = self._convert_conv_from_v2_to_v1(conv)
|
|
convs_res.append(conv_res)
|
|
return convs_res
|
|
|
|
async def get_filtered_conversations(
|
|
self,
|
|
page: int = 1,
|
|
page_size: int = 20,
|
|
platform_ids: list[str] | None = None,
|
|
search_query: str = "",
|
|
**kwargs,
|
|
) -> tuple[list[Conversation], int]:
|
|
"""获取过滤后的对话列表.
|
|
|
|
Args:
|
|
page (int): 页码, 默认为 1
|
|
page_size (int): 每页大小, 默认为 20
|
|
platform_ids (list[str]): 平台 ID 列表, 可选
|
|
search_query (str): 搜索查询字符串, 可选
|
|
Returns:
|
|
conversations (list[Conversation]): 对话对象列表
|
|
|
|
"""
|
|
convs, cnt = await self.db.get_filtered_conversations(
|
|
page=page,
|
|
page_size=page_size,
|
|
platform_ids=platform_ids,
|
|
search_query=search_query,
|
|
**kwargs,
|
|
)
|
|
convs_res = []
|
|
for conv in convs:
|
|
conv_res = self._convert_conv_from_v2_to_v1(conv)
|
|
convs_res.append(conv_res)
|
|
return convs_res, cnt
|
|
|
|
async def update_conversation(
|
|
self,
|
|
unified_msg_origin: str,
|
|
conversation_id: str | None = None,
|
|
history: list[dict] | None = None,
|
|
title: str | None = None,
|
|
persona_id: str | None = None,
|
|
token_usage: int | None = None,
|
|
) -> None:
|
|
"""更新会话的对话.
|
|
|
|
Args:
|
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
|
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
|
|
token_usage (int | None): token 使用量。None 表示不更新
|
|
|
|
"""
|
|
if not conversation_id:
|
|
# 如果没有提供 conversation_id,则获取当前的
|
|
conversation_id = await self.get_curr_conversation_id(unified_msg_origin)
|
|
if conversation_id:
|
|
await self.db.update_conversation(
|
|
cid=conversation_id,
|
|
title=title,
|
|
persona_id=persona_id,
|
|
content=history,
|
|
token_usage=token_usage,
|
|
)
|
|
|
|
async def update_conversation_title(
|
|
self,
|
|
unified_msg_origin: str,
|
|
title: str,
|
|
conversation_id: str | None = None,
|
|
) -> None:
|
|
"""更新会话的对话标题.
|
|
|
|
Args:
|
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
|
title (str): 对话标题
|
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
|
Deprecated:
|
|
Use `update_conversation` with `title` parameter instead.
|
|
|
|
"""
|
|
await self.update_conversation(
|
|
unified_msg_origin=unified_msg_origin,
|
|
conversation_id=conversation_id,
|
|
title=title,
|
|
)
|
|
|
|
async def update_conversation_persona_id(
|
|
self,
|
|
unified_msg_origin: str,
|
|
persona_id: str,
|
|
conversation_id: str | None = None,
|
|
) -> None:
|
|
"""更新会话的对话 Persona ID.
|
|
|
|
Args:
|
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
|
persona_id (str): 对话 Persona ID
|
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
|
Deprecated:
|
|
Use `update_conversation` with `persona_id` parameter instead.
|
|
|
|
"""
|
|
await self.update_conversation(
|
|
unified_msg_origin=unified_msg_origin,
|
|
conversation_id=conversation_id,
|
|
persona_id=persona_id,
|
|
)
|
|
|
|
async def add_message_pair(
|
|
self,
|
|
cid: str,
|
|
user_message: UserMessageSegment | dict,
|
|
assistant_message: AssistantMessageSegment | dict,
|
|
) -> None:
|
|
"""Add a user-assistant message pair to the conversation history.
|
|
|
|
Args:
|
|
cid (str): Conversation ID
|
|
user_message (UserMessageSegment | dict): OpenAI-format user message object or dict
|
|
assistant_message (AssistantMessageSegment | dict): OpenAI-format assistant message object or dict
|
|
|
|
Raises:
|
|
Exception: If the conversation with the given ID is not found
|
|
"""
|
|
conv = await self.db.get_conversation_by_id(cid=cid)
|
|
if not conv:
|
|
raise Exception(f"Conversation with id {cid} not found")
|
|
history = conv.content or []
|
|
if isinstance(user_message, UserMessageSegment):
|
|
user_msg_dict = user_message.model_dump()
|
|
else:
|
|
user_msg_dict = user_message
|
|
if isinstance(assistant_message, AssistantMessageSegment):
|
|
assistant_msg_dict = assistant_message.model_dump()
|
|
else:
|
|
assistant_msg_dict = assistant_message
|
|
history.append(user_msg_dict)
|
|
history.append(assistant_msg_dict)
|
|
await self.db.update_conversation(
|
|
cid=cid,
|
|
content=history,
|
|
)
|
|
|
|
async def get_human_readable_context(
|
|
self,
|
|
unified_msg_origin: str,
|
|
conversation_id: str,
|
|
page: int = 1,
|
|
page_size: int = 10,
|
|
) -> tuple[list[str], int]:
|
|
"""获取人类可读的上下文.
|
|
|
|
Args:
|
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
|
page (int): 页码
|
|
page_size (int): 每页大小
|
|
|
|
"""
|
|
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
|
|
if not conversation:
|
|
return [], 0
|
|
history = json.loads(conversation.history)
|
|
|
|
# contexts_groups 存放按顺序的段落(每个段落是一个 str 列表),
|
|
# 之后会被展平成一个扁平的 str 列表返回。
|
|
contexts_groups: list[list[str]] = []
|
|
temp_contexts: list[str] = []
|
|
for record in history:
|
|
if record["role"] == "user":
|
|
temp_contexts.append(f"User: {record['content']}")
|
|
elif record["role"] == "assistant":
|
|
if record.get("content"):
|
|
temp_contexts.append(f"Assistant: {record['content']}")
|
|
elif "tool_calls" in record:
|
|
tool_calls_str = json.dumps(
|
|
record["tool_calls"],
|
|
ensure_ascii=False,
|
|
)
|
|
temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}")
|
|
else:
|
|
temp_contexts.append("Assistant: [未知的内容]")
|
|
contexts_groups.insert(0, temp_contexts)
|
|
temp_contexts = []
|
|
|
|
# 展平分组后的 contexts 列表为单层字符串列表
|
|
contexts = [item for sublist in contexts_groups for item in sublist]
|
|
|
|
# 计算分页
|
|
paged_contexts = contexts[(page - 1) * page_size : page * page_size]
|
|
total_pages = len(contexts) // page_size
|
|
if len(contexts) % page_size != 0:
|
|
total_pages += 1
|
|
|
|
return paged_contexts, total_pages
|