From 12f4e1146ff4cc40ebac48fbb9800f194b2099ad Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 5 Feb 2025 13:26:53 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=9B=B4=E5=A5=BD=E7=9A=84=E5=AF=B9?= =?UTF-8?q?=E8=AF=9D=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/conversation_mgr.py | 118 +++++++++++++ astrbot/core/core_lifecycle.py | 5 +- astrbot/core/db/__init__.py | 30 ++-- astrbot/core/db/po.py | 17 +- astrbot/core/db/sqlite.py | 73 ++++++-- astrbot/core/db/sqlite_init.sql | 4 +- .../process_stage/method/llm_request.py | 46 ++++- astrbot/core/provider/entites.py | 3 +- astrbot/core/provider/manager.py | 9 +- astrbot/core/provider/provider.py | 59 +++---- .../core/provider/sources/gemini_source.py | 87 +--------- .../core/provider/sources/llmtuner_source.py | 63 +------ .../core/provider/sources/openai_source.py | 76 +------- astrbot/core/provider/sources/zhipu_source.py | 8 +- astrbot/core/star/context.py | 3 + astrbot/core/utils/param_validation_mixin.py | 4 + astrbot/dashboard/routes/chat.py | 16 +- packages/astrbot/main.py | 164 +++++++++++++----- 18 files changed, 437 insertions(+), 348 deletions(-) create mode 100644 astrbot/core/conversation_mgr.py diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py new file mode 100644 index 000000000..7aca80296 --- /dev/null +++ b/astrbot/core/conversation_mgr.py @@ -0,0 +1,118 @@ +import uuid +import json +import asyncio +from astrbot.core import sp +from typing import Dict, List +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import Conversation + +class ConversationManager(): + '''负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。''' + def __init__(self, db_helper: BaseDatabase): + self.session_conversations: Dict[str, str] = sp.get("session_conversation", {}) + self.db = db_helper + self.save_interval = 60 # 每 60 秒保存一次 + self._start_periodic_save() + + def _start_periodic_save(self): + asyncio.create_task(self._periodic_save()) + + async def _periodic_save(self): + while True: + await asyncio.sleep(self.save_interval) + self._save_to_storage() + + def _save_to_storage(self): + sp.put("session_conversation", self.session_conversations) + + async def new_conversation(self, unified_msg_origin: str) -> str: + '''新建对话,并将当前会话的对话转移到新对话''' + conversation_id = str(uuid.uuid4()) + self.db.new_conversation( + user_id=unified_msg_origin, + cid=conversation_id + ) + self.session_conversations[unified_msg_origin] = conversation_id + sp.put("session_conversation", self.session_conversations) + return conversation_id + + async def switch_conversation(self, unified_msg_origin: str, conversation_id: str): + '''切换会话的对话''' + self.session_conversations[unified_msg_origin] = conversation_id + sp.put("session_conversation", self.session_conversations) + + async def delete_conversation(self, unified_msg_origin: str, conversation_id: str=None): + '''删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话''' + conversation_id = self.session_conversations.get(unified_msg_origin) + if conversation_id: + self.db.delete_conversation( + user_id=unified_msg_origin, + cid=conversation_id + ) + del self.session_conversations[unified_msg_origin] + + async def get_curr_conversation_id(self, unified_msg_origin: str) -> str: + '''获取会话当前的对话 ID''' + return self.session_conversations.get(unified_msg_origin, None) + + async def get_conversation(self, unified_msg_origin: str, conversation_id: str) -> Conversation: + '''获取会话的对话''' + return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id) + + async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]: + '''获取会话的所有对话''' + return self.db.get_conversations(unified_msg_origin) + + async def update_conversation(self, unified_msg_origin: str, conversation_id: str, history: List[Dict]): + '''更新会话的对话''' + if conversation_id: + self.db.update_conversation( + user_id=unified_msg_origin, + cid=conversation_id, + history=json.dumps(history) + ) + + async def update_conversation_title(self, unified_msg_origin: str, title: str): + '''更新会话的对话标题''' + conversation_id = self.session_conversations.get(unified_msg_origin) + if conversation_id: + self.db.update_conversation_title( + user_id=unified_msg_origin, + cid=conversation_id, + title=title + ) + + async def update_conversation_persona_id(self, unified_msg_origin: str, persona_id: str): + '''更新会话的对话 Persona ID''' + conversation_id = self.session_conversations.get(unified_msg_origin) + if conversation_id: + self.db.update_conversation_persona_id( + user_id=unified_msg_origin, + cid=conversation_id, + persona_id=persona_id + ) + + async def get_human_readable_context(self, unified_msg_origin, conversation_id, page=1, page_size=10): + conversation = await self.get_conversation(unified_msg_origin, conversation_id) + history = json.loads(conversation.history) + + contexts = [] + temp_contexts = [] + for record in history: + if record['role'] == "user": + temp_contexts.append(f"User: {record['content']}") + elif record['role'] == "assistant": + temp_contexts.append(f"Assistant: {record['content']}") + contexts.insert(0, temp_contexts) + temp_contexts = [] + + # 展平 contexts 列表 + contexts = [item for sublist in contexts for item in sublist] + + # 计算分页 + paged_contexts = contexts[(page-1)*page_size:page*page_size] + total_pages = len(contexts) // page_size + if len(contexts) % page_size != 0: + total_pages += 1 + + return paged_contexts, total_pages \ No newline at end of file diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 8f9e9b3d7..e5630d017 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -18,7 +18,7 @@ from astrbot.core.updator import AstrBotUpdator from astrbot.core import logger from astrbot.core.config.default import VERSION from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager - +from astrbot.core.conversation_mgr import ConversationManager class AstrBotCoreLifecycle: def __init__(self, log_broker: LogBroker, db: BaseDatabase): self.log_broker = log_broker @@ -43,12 +43,15 @@ class AstrBotCoreLifecycle: self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config) + self.conversation_manager = ConversationManager(self.db) + self.star_context = Context( self.event_queue, self.astrbot_config, self.db, self.provider_manager, self.platform_manager, + self.conversation_manager, self.knowledge_db_manager ) self.plugin_manager = PluginManager(self.star_context, self.astrbot_config) diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 424c1539f..03474ecbf 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -1,7 +1,7 @@ import abc from dataclasses import dataclass from typing import List -from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, WebChatConversation +from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation @dataclass class BaseDatabase(abc.ABC): @@ -79,25 +79,35 @@ class BaseDatabase(abc.ABC): raise NotImplementedError @abc.abstractmethod - def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation: - '''通过 user_id 和 cid 获取 WebChatConversation''' + def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: + '''通过 user_id 和 cid 获取 Conversation''' raise NotImplementedError @abc.abstractmethod - def webchat_new_conversation(self, user_id: str, cid: str): - '''新建 WebChatConversation''' + def new_conversation(self, user_id: str, cid: str): + '''新建 Conversation''' raise NotImplementedError @abc.abstractmethod - def get_webchat_conversations(self, user_id: str) -> List[WebChatConversation]: + def get_conversations(self, user_id: str) -> List[Conversation]: raise NotImplementedError @abc.abstractmethod - def update_webchat_conversation(self, user_id: str, cid: str, history: str): - '''更新 WebChatConversation''' + def update_conversation(self, user_id: str, cid: str, history: str): + '''更新 Conversation''' raise NotImplementedError @abc.abstractmethod - def delete_webchat_conversation(self, user_id: str, cid: str): - '''删除 WebChatConversation''' + def delete_conversation(self, user_id: str, cid: str): + '''删除 Conversation''' + raise NotImplementedError + + @abc.abstractmethod + def update_conversation_title(self, user_id: str, cid: str, title: str): + '''更新 Conversation 标题''' + raise NotImplementedError + + @abc.abstractmethod + def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str): + '''更新 Conversation Persona ID''' raise NotImplementedError \ No newline at end of file diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 72235f4e4..c905a50ba 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -33,16 +33,16 @@ class Stats(): command: List[Command] = field(default_factory=list) llm: List[Provider] = field(default_factory=list) -'''LLM 聊天时持久化的信息''' - @dataclass class LLMHistory(): + '''LLM 聊天时持久化的信息''' provider_type: str session_id: str content: str @dataclass class ATRIVision(): + '''Deprecated''' id: str url_or_path: str caption: str @@ -53,13 +53,18 @@ class ATRIVision(): sender_nickname: str timestamp: int = -1 - - @dataclass -class WebChatConversation(): +class Conversation(): + '''LLM 对话存储 + + 对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。 + 对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。 + ''' user_id: str cid: str history: str = "" + '''字符串格式的列表。''' created_at: int = 0 updated_at: int = 0 - \ No newline at end of file + title: str = "" + persona_id: str = "" \ No newline at end of file diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 94a5a23e4..b81aef773 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -6,7 +6,7 @@ from astrbot.core.db.po import ( Stats, LLMHistory, ATRIVision, - WebChatConversation + Conversation ) from . import BaseDatabase from typing import Tuple @@ -25,6 +25,37 @@ class SQLiteDatabase(BaseDatabase): c = self.conn.cursor() c.executescript(sql) self.conn.commit() + + # 检查 webchat_conversation 的 title 字段是否存在 + c.execute( + ''' + PRAGMA table_info(webchat_conversation) + ''' + ) + res = c.fetchall() + has_title = False + has_persona_id = False + for row in res: + if row[1] == "title": + has_title = True + if row[1] == "persona_id": + has_persona_id = True + if not has_title: + c.execute( + ''' + ALTER TABLE webchat_conversation ADD COLUMN title TEXT; + ''' + ) + self.conn.commit() + if not has_persona_id: + c.execute( + ''' + ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT; + ''' + ) + self.conn.commit() + + c.close() def _get_conn(self, db_path: str) -> sqlite3.Connection: conn = sqlite3.connect(self.db_path) @@ -202,7 +233,7 @@ class SQLiteDatabase(BaseDatabase): return Stats(platform, [], []) - def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation: + def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: try: c = self.conn.cursor() except sqlite3.ProgrammingError: @@ -216,9 +247,9 @@ class SQLiteDatabase(BaseDatabase): res = c.fetchone() c.close() - return WebChatConversation(*res) + return Conversation(*res) - def webchat_new_conversation(self, user_id: str, cid: str): + def new_conversation(self, user_id: str, cid: str): history = "[]" updated_at = int(time.time()) created_at = updated_at @@ -228,7 +259,7 @@ class SQLiteDatabase(BaseDatabase): ''', (user_id, cid, history, updated_at, created_at) ) - def get_webchat_conversations(self, user_id: str) -> Tuple: + def get_conversations(self, user_id: str) -> Tuple: try: c = self.conn.cursor() except sqlite3.ProgrammingError: @@ -236,7 +267,7 @@ class SQLiteDatabase(BaseDatabase): c.execute( ''' - SELECT cid, created_at, updated_at FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC + SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC ''', (user_id,) ) @@ -247,24 +278,42 @@ class SQLiteDatabase(BaseDatabase): cid = row[0] created_at = row[1] updated_at = row[2] - conversations.append(WebChatConversation("", cid, '[]', created_at, updated_at)) + title = row[3] + persona_id = row[4] + conversations.append(Conversation("", cid, '[]', created_at, updated_at, title, persona_id)) return conversations - def update_webchat_conversation(self, user_id: str, cid: str, history: str): + def update_conversation(self, user_id: str, cid: str, history: str): + '''更新对话,并且同时更新时间''' + updated_at = int(time.time()) self._exec_sql( ''' - UPDATE webchat_conversation SET history = ? WHERE user_id = ? AND cid = ? - ''', (history, user_id, cid) + UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ? + ''', (history, updated_at, user_id, cid) + ) + + + def update_conversation_title(self, user_id: str, cid: str, title: str): + self._exec_sql( + ''' + UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ? + ''', (title, user_id, cid) + ) + + def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str): + self._exec_sql( + ''' + UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ? + ''', (persona_id, user_id, cid) ) - def delete_webchat_conversation(self, user_id: str, cid: str): + def delete_conversation(self, user_id: str, cid: str): self._exec_sql( ''' DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ? ''', (user_id, cid) ) - def insert_atri_vision_data(self, vision: ATRIVision): ts = int(time.time()) keywords = ",".join(vision.keywords) diff --git a/astrbot/core/db/sqlite_init.sql b/astrbot/core/db/sqlite_init.sql index e58f8bad9..900f4f2c0 100644 --- a/astrbot/core/db/sqlite_init.sql +++ b/astrbot/core/db/sqlite_init.sql @@ -42,5 +42,7 @@ CREATE TABLE IF NOT EXISTS webchat_conversation( cid TEXT, history TEXT, created_at INTEGER, - updated_at INTEGER + updated_at INTEGER, + title TEXT, + persona_id TEXT ); \ No newline at end of file diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index f983f5da8..dee23211a 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -2,6 +2,7 @@ 本地 Agent 模式的 LLM 调用 Stage ''' import traceback +import json from typing import Union, AsyncGenerator from ...context import PipelineContext from ..stage import Stage @@ -10,7 +11,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, Result from astrbot.core.message.components import Image from astrbot.core import logger from astrbot.core.utils.metrics import Metric -from astrbot.core.provider.entites import ProviderRequest +from astrbot.core.provider.entites import ProviderRequest, LLMResponse from astrbot.core.star.star_handler import star_handlers_registry, EventType class LLMRequestSubStage(Stage): @@ -24,6 +25,8 @@ class LLMRequestSubStage(Stage): if self.provider_wake_prefix.startswith(bwp): logger.info(f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。") self.provider_wake_prefix = self.provider_wake_prefix[len(bwp):] + + self.conv_manager = ctx.plugin_manager.context.conversation_manager async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]: req: ProviderRequest = None @@ -46,10 +49,17 @@ class LLMRequestSubStage(Stage): if isinstance(comp, Image): image_url = comp.url if comp.url else comp.file req.image_urls.append(image_url) - req.session_id = event.session_id + + # 获取对话上下文 + conversation_id = await self.conv_manager.get_curr_conversation_id(event.unified_msg_origin) + if not conversation_id: + conversation_id = await self.conv_manager.new_conversation(event.unified_msg_origin) + req.session_id = conversation_id + conversation = await self.conv_manager.get_conversation(event.unified_msg_origin, conversation_id) + req.conversation = conversation + req.contexts = json.loads(conversation.history) + event.set_extra("provider_request", req) - session_provider_context = provider.session_memory.get(event.session_id) - req.contexts = session_provider_context if session_provider_context else [] if not req.prompt and not req.image_urls: return @@ -62,6 +72,9 @@ class LLMRequestSubStage(Stage): await handler.handler(event, req) except BaseException: logger.error(traceback.format_exc()) + + if isinstance(req.contexts, str): + req.contexts = json.loads(req.contexts) try: logger.debug(f"提供商请求 Payload: {req.__dict__}") @@ -77,6 +90,9 @@ class LLMRequestSubStage(Stage): except BaseException: logger.error(traceback.format_exc()) + # 保存到历史记录 + await self._save_to_history(event, req, llm_response) + await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type) if llm_response.role == 'assistant': @@ -117,4 +133,24 @@ class LLMRequestSubStage(Stage): except BaseException as e: logger.error(traceback.format_exc()) event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}")) - return \ No newline at end of file + return + + async def _save_to_history(self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse): + if llm_response.role == "assistant": + # 文本回复 + contexts = req.contexts + new_record = { + "role": "user", + "content": req.prompt + } + contexts.append(new_record) + contexts.append({ + "role": "assistant", + "content": llm_response.completion_text + }) + contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts)) + await self.conv_manager.update_conversation( + event.unified_msg_origin, + req.session_id, + history=contexts_to_save + ) \ No newline at end of file diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entites.py index 4f203a8bb..d0bc5ec0a 100644 --- a/astrbot/core/provider/entites.py +++ b/astrbot/core/provider/entites.py @@ -3,6 +3,7 @@ from dataclasses import dataclass, field from typing import List, Dict, Type from .func_tool_manager import FuncCall from openai.types.chat.chat_completion import ChatCompletion +from astrbot.core.db.po import Conversation class ProviderType(enum.Enum): @@ -38,9 +39,9 @@ class ProviderRequest(): '''上下文。格式与 openai 的上下文格式一致: 参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages ''' - system_prompt: str = "" '''系统提示词''' + conversation: Conversation = None @dataclass diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index ad8cb8a4f..d702fe156 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -66,7 +66,14 @@ class ProviderManager(): if not self.selected_default_persona and len(self.personas) > 0: # 默认选择第一个 self.selected_default_persona = self.personas[0] - + + if not self.selected_default_persona: + self.selected_default_persona = Personality( + prompt="You are a helpful and friendly assistant.", + name="default", + ) + self.personas.append(self.selected_default_persona) + self.provider_insts: List[Provider] = [] '''加载的 Provider 的实例''' diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index dfa0bc8a1..144a76d56 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -8,6 +8,8 @@ from typing import TypedDict from astrbot.core.provider.func_tool_manager import FuncCall from astrbot.core.provider.entites import LLMResponse from dataclasses import dataclass + + class Personality(TypedDict): prompt: str = "" name: str = "" @@ -60,25 +62,11 @@ class Provider(AbstractProvider): ) -> None: super().__init__(provider_config) - self.session_memory = defaultdict(list) - '''维护了 session_id 的上下文,**不包含 system 指令**。''' - self.provider_settings = provider_settings self.curr_personality: Personality = default_persona '''维护了当前的使用的 persona,即人格。可能为 None''' - - self.db_helper = db_helper - '''用于持久化的数据库操作对象。''' - if persistant_history: - # 读取历史记录 - try: - for history in db_helper.get_llm_history(provider_type=provider_config['id']): - self.session_memory[history.session_id] = json.loads(history.content) - except BaseException as e: - logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。") - @abc.abstractmethod def get_current_key(self) -> str: raise NotImplementedError() @@ -96,22 +84,6 @@ class Provider(AbstractProvider): '''获得支持的模型列表''' raise NotImplementedError() - @abc.abstractmethod - async def get_human_readable_context(self, session_id: str, page: int, page_size: int): - '''获取人类可读的上下文 - - page 从 1 开始 - - Example: - - ["User: 你好", "Assistant: 你好!"] - - Return: - contexts: List[str]: 上下文列表 - total_pages: int: 总页数 - ''' - raise NotImplementedError() - @abc.abstractmethod async def text_chat(self, prompt: str, @@ -125,26 +97,35 @@ class Provider(AbstractProvider): Args: prompt: 提示词 - session_id: 会话 ID + session_id: 会话 ID(此属性已经被废弃) image_urls: 图片 URL 列表 tools: Function-calling 工具 contexts: 上下文 kwargs: 其他参数 Notes: - - 如果传入了 contexts,将会提前加上上下文。否则使用 session_memory 中的上下文。 - - 可以选择性地传入 session_id,如果传入了 session_id,将会使用 session_id 对应的上下文进行对话, - 并且也会记录相应的对话上下文,实现多轮对话。如果不传入则不会记录上下文。 - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 ''' raise NotImplementedError() - - @abc.abstractmethod - async def forget(self, session_id: str) -> bool: - '''重置某一个 session_id 的上下文''' - raise NotImplementedError() + async def pop_record(self, context: List): + ''' + 弹出 context 第一条非系统提示词对话记录 + ''' + poped = 0 + indexs_to_pop = [] + for idx, record in enumerate(context): + if record["role"] == "system": + continue + else: + indexs_to_pop.append(idx) + poped += 1 + if poped == 2: + break + + for idx in reversed(indexs_to_pop): + context.pop(idx) class STTProvider(AbstractProvider): diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 2b2a9916f..3fe580257 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -73,60 +73,10 @@ class ProviderGoogleGenAI(Provider): api_base=provider_config.get("api_base", None) ) self.set_model(provider_config['model_config']['model']) - - async def get_human_readable_context(self, session_id, page, page_size): - if session_id not in self.session_memory: - raise Exception("会话 ID 不存在") - contexts = [] - temp_contexts = [] - for record in self.session_memory[session_id]: - if record['role'] == "user": - temp_contexts.append(f"User: {record['content']}") - elif record['role'] == "assistant": - temp_contexts.append(f"Assistant: {record['content']}") - contexts.insert(0, temp_contexts) - temp_contexts = [] - - # 展平 contexts 列表 - contexts = [item for sublist in contexts for item in sublist] - - # 计算分页 - paged_contexts = contexts[(page-1)*page_size:page*page_size] - total_pages = len(contexts) // page_size - if len(contexts) % page_size != 0: - total_pages += 1 - - return paged_contexts, total_pages async def get_models(self): return await self.client.models_list() - async def pop_record(self, session_id: str, pop_system_prompt: bool = False): - ''' - 弹出第一条记录 - ''' - if session_id not in self.session_memory: - raise Exception("会话 ID 不存在") - - if len(self.session_memory[session_id]) == 0: - return None - - for i in range(len(self.session_memory[session_id])): - # 检查是否是 system prompt - if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system": - # 如果只有一个 system prompt,才不删掉 - f = False - for j in range(i+1, len(self.session_memory[session_id])): - if self.session_memory[session_id][j]['user']['role'] == "system": - f = True - break - if not f: - continue - record = self.session_memory[session_id].pop(i) - break - - return record - async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: tool = None if tools: @@ -197,11 +147,10 @@ class ProviderGoogleGenAI(Provider): llm_response.completion_text = llm_response.completion_text.strip() return llm_response - async def text_chat( self, prompt: str, - session_id: str, + session_id: str = None, image_urls: List[str]=None, func_tool: FuncCall=None, contexts=None, @@ -210,10 +159,7 @@ class ProviderGoogleGenAI(Provider): ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls) context_query = [] - if not contexts: - context_query = [*self.session_memory[session_id], new_record] - else: - context_query = [*contexts, new_record] + context_query = [*contexts, new_record] if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) @@ -234,7 +180,7 @@ class ProviderGoogleGenAI(Provider): while retry_cnt > 0: logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。") try: - self.pop_record(session_id) + await self.pop_record(context_query) llm_response = await self._query(payloads, func_tool) break except Exception as e: @@ -254,34 +200,7 @@ class ProviderGoogleGenAI(Provider): raise e - if kwargs.get("persist", True) and llm_response: - await self.save_history(contexts, new_record, session_id, llm_response) - return llm_response - - async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse): - if llm_response.role == "assistant" and session_id: - # 文本回复 - if not contexts: - # 添加用户 record - self.session_memory[session_id].append(new_record) - # 添加 assistant record - self.session_memory[session_id].append({ - "role": "assistant", - "content": llm_response.completion_text - }) - else: - contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts)) - self.session_memory[session_id] = [*contexts_to_save, new_record, { - "role": "assistant", - "content": llm_response.completion_text - }] - self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id']) - - async def forget(self, session_id: str) -> bool: - self.session_memory[session_id] = [] - self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id']) - return True def get_current_key(self) -> str: return self.client.api_key diff --git a/astrbot/core/provider/sources/llmtuner_source.py b/astrbot/core/provider/sources/llmtuner_source.py index 9aa709b1a..9c3b2ff79 100644 --- a/astrbot/core/provider/sources/llmtuner_source.py +++ b/astrbot/core/provider/sources/llmtuner_source.py @@ -63,14 +63,7 @@ class LLMTunerModelLoader(Provider): ) -> LLMResponse: system_prompt = "" new_record = {"role": "user", "content": prompt} - if not contexts: - query_context = [ - *self.session_memory[session_id], - new_record, - ] - system_prompt = self.curr_personality["prompt"] - else: - query_context = [*contexts, new_record] + query_context = [*contexts, new_record] # 提取出系统提示 system_idxs = [] @@ -96,34 +89,8 @@ class LLMTunerModelLoader(Provider): responses = await self.model.achat(**conf) llm_response = LLMResponse("assistant", responses[-1].response_text) - - await self.save_history(contexts, new_record, session_id, llm_response) - + return llm_response - - async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse): - if llm_response.role == "assistant" and session_id: - # 文本回复 - if not contexts: - # 添加用户 record - self.session_memory[session_id].append(new_record) - # 添加 assistant record - self.session_memory[session_id].append({ - "role": "assistant", - "content": llm_response.completion_text - }) - else: - contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts)) - self.session_memory[session_id] = [*contexts_to_save, new_record, { - "role": "assistant", - "content": llm_response.completion_text - }] - self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id']) - - async def forget(self, session_id): - self.session_memory[session_id] = [] - self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id']) - return True async def get_current_key(self): return "none" @@ -132,28 +99,4 @@ class LLMTunerModelLoader(Provider): pass async def get_models(self): - return [self.get_model()] - - async def get_human_readable_context(self, session_id, page, page_size): - if session_id not in self.session_memory: - raise Exception("会话 ID 不存在") - contexts = [] - temp_contexts = [] - for record in self.session_memory[session_id]: - if record["role"] == "user": - temp_contexts.append(f"User: {record['content']}") - elif record["role"] == "assistant": - temp_contexts.append(f"Assistant: {record['content']}") - contexts.insert(0, temp_contexts) - temp_contexts = [] - - # 展平 contexts 列表 - contexts = [item for sublist in contexts for item in sublist] - - # 计算分页 - paged_contexts = contexts[(page - 1) * page_size : page * page_size] - total_pages = len(contexts) // page_size - if len(contexts) % page_size != 0: - total_pages += 1 - - return paged_contexts, total_pages + return [self.get_model()] \ No newline at end of file diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index d299986e8..ac57bb63e 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -48,30 +48,6 @@ class ProviderOpenAIOfficial(Provider): ) self.set_model(provider_config['model_config']['model']) - - async def get_human_readable_context(self, session_id, page, page_size): - if session_id not in self.session_memory: - raise Exception("会话 ID 不存在") - contexts = [] - temp_contexts = [] - for record in self.session_memory[session_id]: - if record['role'] == "user": - temp_contexts.append(f"User: {record['content']}") - elif record['role'] == "assistant": - temp_contexts.append(f"Assistant: {record['content']}") - contexts.insert(0, temp_contexts) - temp_contexts = [] - - # 展平 contexts 列表 - contexts = [item for sublist in contexts for item in sublist] - - # 计算分页 - paged_contexts = contexts[(page-1)*page_size:page*page_size] - total_pages = len(contexts) // page_size - if len(contexts) % page_size != 0: - total_pages += 1 - - return paged_contexts, total_pages async def get_models(self): try: @@ -84,22 +60,6 @@ class ProviderOpenAIOfficial(Provider): except NotFoundError as e: raise Exception(f"获取模型列表失败:{e}") - async def pop_record(self, session_id: str): - ''' - 弹出最早的一个对话 - ''' - if session_id not in self.session_memory: - raise Exception("会话 ID 不存在") - - if len(self.session_memory[session_id]) < 2: - return - - try: - self.session_memory[session_id].pop(0) - self.session_memory[session_id].pop(0) - except IndexError: - pass - async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: if tools: tool_list = tools.get_func_desc_openai_style() @@ -141,7 +101,7 @@ class ProviderOpenAIOfficial(Provider): async def text_chat( self, prompt: str, - session_id: str, + session_id: str=None, image_urls: List[str]=None, func_tool: FuncCall=None, contexts=None, @@ -149,11 +109,7 @@ class ProviderOpenAIOfficial(Provider): **kwargs ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls) - context_query = [] - if not contexts: - context_query = [*self.session_memory[session_id], new_record] - else: - context_query = [*contexts, new_record] + context_query = [*contexts, new_record] if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) @@ -214,9 +170,6 @@ class ProviderOpenAIOfficial(Provider): raise e - if kwargs.get("persist", True) and llm_response: - await self.save_history(contexts, new_record, session_id, llm_response) - return llm_response async def _remove_image_from_context(self, contexts: List): @@ -244,32 +197,7 @@ class ProviderOpenAIOfficial(Provider): context['content'] = new_content new_contexts.append(context) return new_contexts - - async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse): - if llm_response.role == "assistant" and session_id: - # 文本回复 - if not contexts: - # 添加用户 record - self.session_memory[session_id].append(new_record) - # 添加 assistant record - self.session_memory[session_id].append({ - "role": "assistant", - "content": llm_response.completion_text - }) - else: - contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts)) - self.session_memory[session_id] = [*contexts_to_save, new_record, { - "role": "assistant", - "content": llm_response.completion_text - }] - self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id']) - - async def forget(self, session_id: str) -> bool: - self.session_memory[session_id] = [] - self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id']) - return True - def get_current_key(self) -> str: return self.client.api_key diff --git a/astrbot/core/provider/sources/zhipu_source.py b/astrbot/core/provider/sources/zhipu_source.py index f1d576775..191f8968d 100644 --- a/astrbot/core/provider/sources/zhipu_source.py +++ b/astrbot/core/provider/sources/zhipu_source.py @@ -22,7 +22,7 @@ class ProviderZhipu(ProviderOpenAIOfficial): async def text_chat( self, prompt: str, - session_id: str, + session_id: str = None, image_urls: List[str]=None, func_tool: FuncCall=None, contexts=None, @@ -32,10 +32,7 @@ class ProviderZhipu(ProviderOpenAIOfficial): new_record = await self.assemble_context(prompt, image_urls) context_query = [] - if not contexts: - context_query = [*self.session_memory[session_id], new_record] - else: - context_query = [*contexts, new_record] + context_query = [*contexts, new_record] model_cfgs: dict = self.provider_config.get("model_config", {}) # glm-4v-flash 只支持一张图片 @@ -62,7 +59,6 @@ class ProviderZhipu(ProviderOpenAIOfficial): } try: llm_response = await self._query(payloads, func_tool) - await self.save_history(contexts, new_record, session_id, llm_response) return llm_response except Exception as e: if "maximum context length" in str(e): diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 6b50f951c..472ed0e96 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -16,6 +16,7 @@ from .filter.command import CommandFilter from .filter.regex import RegexFilter from typing import Awaitable from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager +from astrbot.core.conversation_mgr import ConversationManager class Context: ''' @@ -44,6 +45,7 @@ class Context: db: BaseDatabase, provider_manager: ProviderManager = None, platform_manager: PlatformManager = None, + conversation_manager: ConversationManager = None, knowledge_db_manager: KnowledgeDBManager = None ): self._event_queue = event_queue @@ -52,6 +54,7 @@ class Context: self.provider_manager = provider_manager self.platform_manager = platform_manager self.knowledge_db_manager = knowledge_db_manager + self.conversation_manager = conversation_manager def get_registered_star(self, star_name: str) -> StarMetadata: '''根据插件名获取插件的 Metadata''' diff --git a/astrbot/core/utils/param_validation_mixin.py b/astrbot/core/utils/param_validation_mixin.py index 896d5bc56..5c1e864a5 100644 --- a/astrbot/core/utils/param_validation_mixin.py +++ b/astrbot/core/utils/param_validation_mixin.py @@ -25,6 +25,10 @@ class ParameterValidationMixin: elif isinstance(param_type_or_default_val, str): # 如果 param_type_or_default_val 是字符串,直接赋值 result[param_name] = params[i] + elif isinstance(param_type_or_default_val, int): + result[param_name] = int(params[i]) + elif isinstance(param_type_or_default_val, float): + result[param_name] = float(params[i]) else: result[param_name] = param_type_or_default_val(params[i]) except ValueError: diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 9ec348235..329ff6f36 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -121,7 +121,7 @@ class ChatRoute(Route): })) # 持久化 - conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id) + conversation = self.db.get_conversation_by_user_id(username, conversation_id) try: history = json.loads(conversation.history) except BaseException as e: @@ -136,7 +136,7 @@ class ChatRoute(Route): if audio_url: new_his['audio_url'] = audio_url history.append(new_his) - self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history)) + self.db.update_conversation(username, conversation_id, history=json.dumps(history)) return Response().ok().__dict__ @@ -168,7 +168,7 @@ class ChatRoute(Route): continue yield result_text + '\n' - conversation = self.db.get_webchat_conversation_by_user_id(username, cid) + conversation = self.db.get_conversation_by_user_id(username, cid) try: history = json.loads(conversation.history) except BaseException as e: @@ -178,7 +178,7 @@ class ChatRoute(Route): 'type': 'bot', 'message': result_text }) - self.db.update_webchat_conversation(username, cid, history=json.dumps(history)) + self.db.update_conversation(username, cid, history=json.dumps(history)) await asyncio.sleep(0.5) except BaseException as e: @@ -204,20 +204,20 @@ class ChatRoute(Route): if not conversation_id: return Response().error("Missing key: conversation_id").__dict__ - self.db.delete_webchat_conversation(username, conversation_id) + self.db.delete_conversation(username, conversation_id) return Response().ok().__dict__ async def new_conversation(self): username = g.get('username', 'guest') conversation_id = str(uuid.uuid4()) - self.db.webchat_new_conversation(username, conversation_id) + self.db.new_conversation(username, conversation_id) return Response().ok(data={ 'conversation_id': conversation_id }).__dict__ async def get_conversations(self): username = g.get('username', 'guest') - conversations = self.db.get_webchat_conversations(username) + conversations = self.db.get_conversations(username) return Response().ok(data=conversations).__dict__ async def get_conversation(self): @@ -226,7 +226,7 @@ class ChatRoute(Route): if not conversation_id: return Response().error("Missing key: conversation_id").__dict__ - conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id) + conversation = self.db.get_conversation_by_user_id(username, conversation_id) self.curr_user_cid[username] = conversation_id diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 95ceae5fa..1a8945609 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -1,6 +1,7 @@ import aiohttp import datetime import builtins +import json import astrbot.api.star as star import astrbot.api.event.filter as filter from astrbot.api.event import AstrMessageEvent, MessageEventResult @@ -59,24 +60,28 @@ AstrBot 指令: /plugin: 查看插件、插件帮助 /t2i: 开关文本转图片 /sid: 获取会话 ID -/op : 授权管理员 -/deop : 取消管理员 -/wl : 添加白名单 -/dwl : 删除白名单 -/dashboard_update: 更新管理面板 -/alter_cmd: 设置指令权限 +/op : 授权管理员(op) +/deop : 取消管理员(op) +/wl : 添加白名单(op) +/dwl : 删除白名单(op) +/dashboard_update: 更新管理面板(op) +/alter_cmd: 设置指令权限(op) [大模型] /provider: 大模型提供商 /model: 模型列表 -/key: API Key -/reset: 重置 LLM 会话 -/history: 对话记录 -/persona: 人格情景 +/ls: 对话列表 +/new: 创建新对话 +/switch: 切换对话 +/del: 删除当前会话对话(op) +/reset: 重置 LLM 会话(op) +/history: 当前对话的对话记录 +/persona: 人格情景(op) /tool ls: 函数工具 +/key: API Key(op) [其他] -/set <变量名> <值>: 为会话定义一个变量。适用于 Dify 工作流输入。 +/set <变量名> <值>: 为会话定义变量。适用于 Dify 工作流输入。 /unset <变量名>: 删除会话的变量。 提示:如要查看插件指令,请输入 /plugin 查看具体信息。 @@ -273,7 +278,10 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")) return - await self.context.get_using_provider().forget(message.session_id) + await self.context.conversation_manager.update_conversation( + message.unified_msg_origin, message.session_id, [] + ) + ret = "清除会话 LLM 聊天历史成功。" if self.ltm: cnt = await self.ltm.remove_session(event=message) @@ -329,20 +337,22 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo @filter.command("history") async def his(self, message: AstrMessageEvent, page: int = 1): - - + '''查看对话记录''' if not self.context.get_using_provider(): message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")) return - size_per_page = 3 - contexts, total_pages = await self.context.get_using_provider().get_human_readable_context(message.session_id, page, size_per_page) + size_per_page = 6 + session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin) + contexts, total_pages = await self.context.conversation_manager.get_human_readable_context( + message.unified_msg_origin, session_curr_cid, page, size_per_page + ) history = "" for context in contexts: history += f"{context}\n" - ret = f"""历史记录: + ret = f"""当前对话历史记录: {history} 第 {page} 页 | 共 {total_pages} 页 @@ -351,6 +361,64 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo message.set_result(MessageEventResult().message(ret).use_t2i(False)) + @filter.command("ls") + async def convs(self, message: AstrMessageEvent, page: int = 1): + '''查看对话列表''' + size_per_page = 6 + conversations = await self.context.conversation_manager.get_conversations(message.unified_msg_origin) + total_pages = len(conversations) // size_per_page + if len(conversations) % size_per_page != 0: + total_pages += 1 + conversations = conversations[(page-1)*size_per_page:page*size_per_page] + + ret = "\n对话列表:\n" + global_index = (page - 1) * size_per_page + 1 + + for conv in conversations: + + persona_id = conv.persona_id + if not persona_id and not persona_id == "[%None]": + persona_id = self.context.provider_manager.selected_default_persona['name'] + + ret += f"{global_index}. 新对话{conv.cid[:4]}\n 人格情景: {persona_id}\n上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n" + global_index += 1 + + curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin) + if curr_cid: + ret += f"\n当前对话: {curr_cid[:4]}" + else: + ret += "\n当前对话: 无" + ret += f"\n第 {page} 页 | 共 {total_pages} 页" + ret += "\n*输入 /ls 2 跳转到第 2 页" + + message.set_result(MessageEventResult().message(ret).use_t2i(False)) + + @filter.command("new") + async def new_conv(self, message: AstrMessageEvent): + '''创建新对话''' + cid = await self.context.conversation_manager.new_conversation(message.unified_msg_origin) + message.set_result(MessageEventResult().message(f"切换到新对话: {cid[:4]}。")) + + @filter.command("switch") + async def switch_conv(self, message: AstrMessageEvent, index: int): + '''切换对话''' + conversations = await self.context.conversation_manager.get_conversations(message.unified_msg_origin) + if index > len(conversations) or index < 1: + message.set_result(MessageEventResult().message("对话序号错误。")) + else: + conversation = conversations[index-1] + await self.context.conversation_manager.switch_conversation(message.unified_msg_origin, conversation.cid) + message.set_result(MessageEventResult().message(f"切换到对话: {conversation.cid[:4]}。")) + + @filter.permission_type(filter.PermissionType.ADMIN) + @filter.command("del") + async def del_conv(self, message: AstrMessageEvent): + '''删除当前对话''' + session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin) + await self.context.conversation_manager.delete_conversation(message.unified_msg_origin, session_curr_cid) + message.set_result(MessageEventResult().message("删除当前对话成功。")) + + @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("key") async def key(self, message: AstrMessageEvent, index: int=None): @@ -387,28 +455,28 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("persona") async def persona(self, message: AstrMessageEvent): - - if not self.context.get_using_provider(): - message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")) - return - - l = message.message_str.split(" ") curr_persona_name = "无" - if self.context.get_using_provider().curr_personality: - curr_persona_name = self.context.get_using_provider().curr_personality['name'] + cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin) + if cid: + conversation = await self.context.conversation_manager.get_conversation(message.unified_msg_origin, cid) + if not conversation.persona_id and not conversation.persona_id == "[%None]": + curr_persona_name = self.context.provider_manager.selected_default_persona['name'] + else: + curr_persona_name = conversation.persona_id if len(l) == 1: message.set_result( MessageEventResult().message(f"""[Persona] -- 设置人格情景: `/persona 人格名`, 如 /persona 编剧 - 人格情景列表: `/persona list` -- 人格情景详细信息: `/persona view 人格名` +- 设置人格情景: `/persona 人格` +- 人格情景详细信息: `/persona view 人格` - 取消人格: `/persona unset` -当前人格情景: {curr_persona_name} +默认人格情景: {self.context.provider_manager.selected_default_persona['name']} +当前对话 {cid[:4]} 的人格情景: {curr_persona_name} 配置人格情景请前往管理面板-配置页 """).use_t2i(False)) @@ -433,7 +501,10 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo msg = f"人格{ps}不存在" message.set_result(MessageEventResult().message(msg)) elif l[1] == "unset": - self.context.get_using_provider().curr_personality = None + if not cid: + message.set_result(MessageEventResult().message("当前没有对话,无法取消人格。")) + return + await self.context.conversation_manager.update_conversation_persona_id(message.unified_msg_origin, "[%None]") message.set_result(MessageEventResult().message("取消人格成功。")) else: ps = "".join(l[1:]).strip() @@ -441,7 +512,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo lambda persona: persona['name'] == ps, self.context.provider_manager.personas ), None): - self.context.get_using_provider().curr_personality = persona + await self.context.conversation_manager.update_conversation_persona_id(message.unified_msg_origin, ps) message.set_result(MessageEventResult().message("设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。")) else: message.set_result(MessageEventResult().message("不存在该人格情景。使用 /persona list 查看所有。")) @@ -513,7 +584,12 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复") return try: - session_provider_context = provider.session_memory.get(event.session_id) + session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(event.unified_msg_origin) + conv = await self.context.conversation_manager.get_conversation( + event.unified_msg_origin, + session_curr_cid + ) + history = json.loads(conv.history) prompt = self.ltm.ar_prompt if not prompt: @@ -523,7 +599,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo prompt=prompt, func_tool_manager=self.context.get_llm_tool_manager(), session_id=event.session_id, - contexts=session_provider_context if session_provider_context else [] + contexts=history if history else [] ) except BaseException as e: logger.error(f"主动回复失败: {e}") @@ -545,14 +621,22 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo if self.enable_datetime: req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n" - if persona := provider.curr_personality: - if prompt := persona['prompt']: - req.system_prompt += prompt - if mood_dialogs := persona['_mood_imitation_dialogs_processed']: - req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n" - req.system_prompt += mood_dialogs - if begin_dialogs := persona["_begin_dialogs_processed"]: - req.contexts[:0] = begin_dialogs + if req.conversation: + persona_id = req.conversation.persona_id + if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格 + persona_id = self.context.provider_manager.selected_default_persona['name'] + persona = next(builtins.filter( + lambda persona: persona['name'] == persona_id, + self.context.provider_manager.personas + ), None) + if persona: + if prompt := persona['prompt']: + req.system_prompt += prompt + if mood_dialogs := persona['_mood_imitation_dialogs_processed']: + req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n" + req.system_prompt += mood_dialogs + if begin_dialogs := persona["_begin_dialogs_processed"]: + req.contexts[:0] = begin_dialogs if self.ltm: try: