From 2bc433a30bb4d248ce2dfec017a6555ae61a5f98 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 27 Jan 2025 20:00:32 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E8=AE=B0=E5=BD=95?= =?UTF-8?q?=E9=9D=9E=E5=94=A4=E9=86=92=E7=8A=B6=E6=80=81=E4=B8=8B=E7=BE=A4?= =?UTF-8?q?=E8=81=8A=E5=8E=86=E5=8F=B2=E8=AE=B0=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/api/provider/__init__.py | 2 +- astrbot/core/config/default.py | 34 +++++ astrbot/core/message/message_event_result.py | 5 + .../process_stage/method/star_request.py | 11 +- .../core/provider/sources/openai_source.py | 3 +- packages/astrbot/long_term_memory.py | 84 ++++++++++ packages/astrbot/main.py | 144 +++++++++++------- 7 files changed, 219 insertions(+), 64 deletions(-) create mode 100644 packages/astrbot/long_term_memory.py diff --git a/astrbot/api/provider/__init__.py b/astrbot/api/provider/__init__.py index 17e379478..0158d2a89 100644 --- a/astrbot/api/provider/__init__.py +++ b/astrbot/api/provider/__init__.py @@ -1,2 +1,2 @@ from astrbot.core.provider import Provider, STTProvider, Personality -from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData \ No newline at end of file +from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData, LLMResponse \ No newline at end of file diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 6a5820640..5eb29faf7 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -50,6 +50,12 @@ DEFAULT_CONFIG = { "enable": False, "provider_id": "", }, + "provider_ltm_settings": { + "group_icl_enable": False, + "group_message_max_cnt": 300, + "image_caption": False, + "image_caption_prompt": "Please describe the image using Chinese.", + }, "content_safety": { "internal_keywords": {"enable": True, "extra_keywords": []}, "baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""}, @@ -630,6 +636,34 @@ CONFIG_METADATA_2 = { }, }, }, + "provider_ltm_settings": { + "description": "聊天记忆增强(Beta)", + "type": "object", + "items": { + "group_icl_enable": { + "description": "群聊内记录各群员对话", + "type": "bool", + "obvious-hint": True, + "hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。", + }, + "group_message_max_cnt": { + "description": "群聊消息最大数量", + "type": "int", + "obvious-hint": True, + "hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。", + }, + "image_caption": { + "description": "启用图像转述(需要模型支持)", + "type": "bool", + "obvious-hint": True, + "hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型。", + }, + "image_caption_prompt": { + "description": "图像转述提示词", + "type": "string" + }, + }, + }, }, }, "misc_config_group": { diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index c9c13ec9c..33eae64bc 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -140,5 +140,10 @@ class MessageEventResult(MessageChain): ''' return self.result_content_type == ResultContentType.LLM_RESULT + def get_plain_text(self) -> str: + '''获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。 + ''' + return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)]) + CommandResult = MessageEventResult \ No newline at end of file diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index 6863df473..109e1ea81 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -39,8 +39,11 @@ class StarRequestSubStage(Stage): except Exception as e: logger.error(traceback.format_exc()) logger.error(f"Star {handler.handler_full_name} handle error: {e}") - ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}" - event.set_result(MessageEventResult().message(ret)) - yield - event.clear_result() + + if event.is_at_or_wake_command: + ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}" + event.set_result(MessageEventResult().message(ret)) + yield + event.clear_result() + event.stop_event() \ No newline at end of file diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 9a0e72fb2..dc4b4ed87 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -163,7 +163,8 @@ class ProviderOpenAIOfficial(Provider): try: llm_response = await self._query(payloads, func_tool) - await self.save_history(contexts, new_record, session_id, llm_response) + if kwargs.get("persist", True): + await self.save_history(contexts, new_record, session_id, llm_response) return llm_response except Exception as e: if "maximum context length" in str(e): diff --git a/packages/astrbot/long_term_memory.py b/packages/astrbot/long_term_memory.py new file mode 100644 index 000000000..ef91a1309 --- /dev/null +++ b/packages/astrbot/long_term_memory.py @@ -0,0 +1,84 @@ +import datetime +import uuid +import astrbot.api.star as star +from astrbot.api.event import AstrMessageEvent +from astrbot.api.platform import MessageType +from astrbot.api.provider import ProviderRequest +from astrbot.api.message_components import Plain, Image +from astrbot import logger +from collections import defaultdict + + +class LongTermMemory: + def __init__(self, config: dict, context: star.Context): + self.config = config + self.context = context + self.session_chats = defaultdict(list) + """记录群成员的群聊记录""" + self.max_cnt = self.config["group_message_max_cnt"] + self.image_caption = self.config["image_caption"] + self.image_caption_prompt = self.config["image_caption_prompt"] + + async def remove_session(self, event: AstrMessageEvent) -> int: + cnt = 0 + if event.unified_msg_origin in self.session_chats: + cnt = len(self.session_chats[event.unified_msg_origin]) + del self.session_chats[event.unified_msg_origin] + return cnt + + async def get_image_caption(self, image_url: str) -> str: + provider = self.context.get_using_provider() + response = await provider.text_chat( + prompt=self.image_caption_prompt, + session_id=uuid.uuid4().hex, + image_urls=[image_url], + persist=False, + ) + return response.completion_text + + async def handle_message(self, event: AstrMessageEvent): + if event.get_message_type() == MessageType.GROUP_MESSAGE: + datetime_str = datetime.datetime.now().strftime("%H:%M:%S") + + final_message = f"[{event.message_obj.sender.nickname}/{datetime_str}]: " + + for comp in event.get_messages(): + if isinstance(comp, Plain): + final_message += f" {comp.text}" + elif isinstance(comp, Image): + # image_urls.append(comp.url if comp.url else comp.file) + if self.image_caption: + try: + caption = await self.get_image_caption( + comp.url if comp.url else comp.file + ) + final_message += f" [Image: {caption}]" + except Exception as e: + logger.error(f"获取图片描述失败: {e}") + logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}") + self.session_chats[event.unified_msg_origin].append(final_message) + if len(self.session_chats[event.unified_msg_origin]) > self.max_cnt: + self.session_chats[event.unified_msg_origin].pop(0) + + async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest): + if event.unified_msg_origin not in self.session_chats: + return + + req.system_prompt += f"""You are now in a chatroom. The chat history is as follows.: + {'\n---\n'.join(self.session_chats[event.unified_msg_origin])} +""" + if self.image_caption: + req.system_prompt += ( + "The images sent by the members are displayed in text form above." + ) + + async def after_req_llm(self, event: AstrMessageEvent): + if event.unified_msg_origin not in self.session_chats: + return + + if event.get_result() and event.get_result().is_llm_result(): + final_message = f"[AstrBot/{datetime.datetime.now().strftime('%H:%M:%S')}]: {event.get_result().get_plain_text()}" + logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}") + self.session_chats[event.unified_msg_origin].append(final_message) + if len(self.session_chats[event.unified_msg_origin]) > self.max_cnt: + self.session_chats[event.unified_msg_origin].pop(0) diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index bb533ee22..cb0955d33 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -6,12 +6,15 @@ import astrbot.api.event.filter as filter from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.api import sp from astrbot.api.provider import Personality, ProviderRequest +from astrbot.api.platform import MessageType from astrbot.core.utils.io import download_dashboard, get_dashboard_version from astrbot.core.config.default import VERSION +from collections import defaultdict +from .long_term_memory import LongTermMemory from typing import Union -@star.register(name="astrbot", desc="AstrBot 基础指令集合", author="Soulter", version="4.0.0") +@star.register(name="astrbot", desc="AstrBot 基础指令结合 + 拓展功能", author="Soulter", version="4.0.0") class Main(star.Star): def __init__(self, context: star.Context) -> None: self.context = context @@ -20,7 +23,9 @@ class Main(star.Star): self.identifier = cfg['provider_settings']['identifier'] self.enable_datetime = cfg['provider_settings']["datetime_system_prompt"] - self.kdb_enabled = False + self.ltm = None + if self.context.get_config()['provider_ltm_settings']['group_icl_enable']: + self.ltm = LongTermMemory(self.context.get_config()['provider_ltm_settings'], self.context) async def _query_astrbot_notice(self): try: @@ -219,7 +224,12 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo @filter.command("reset") async def reset(self, message: AstrMessageEvent): await self.context.get_using_provider().forget(message.session_id) - message.set_result(MessageEventResult().message("重置成功")) + ret = "清除会话 LLM 聊天历史成功。" + if self.ltm: + cnt = await self.ltm.remove_session(event=message) + ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。" + + message.set_result(MessageEventResult().message(ret)) @filter.command("model") async def model_ls(self, message: AstrMessageEvent, idx_or_name: Union[int, str] = None): @@ -355,9 +365,9 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo self.context.provider_manager.personas ), None): self.context.get_using_provider().curr_personality = persona - message.set_result(MessageEventResult().message(f"设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。")) + message.set_result(MessageEventResult().message("设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。")) else: - message.set_result(MessageEventResult().message(f"不存在该人格情景。使用 /persona list 查看所有。")) + message.set_result(MessageEventResult().message("不存在该人格情景。使用 /persona list 查看所有。")) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("dashboard_update") @@ -366,31 +376,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo await download_dashboard() yield event.plain_result("管理面板更新完成。") - @filter.on_llm_request() - async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest): - provider = self.context.get_using_provider() - if self.prompt_prefix: - req.prompt = self.prompt_prefix + req.prompt - if self.identifier: - user_id = event.message_obj.sender.user_id - user_nickname = event.message_obj.sender.nickname - user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n" - req.prompt = user_info + req.prompt - if self.enable_datetime: - req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n" - - if persona := provider.curr_personality: - if prompt := persona['prompt']: - req.system_prompt += prompt - if mood_dialogs := persona['_mood_imitation_dialogs_processed']: - req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n" - req.system_prompt += mood_dialogs - if begin_dialogs := persona["_begin_dialogs_processed"]: - req.contexts[:0] = begin_dialogs - - # if provider.curr_personality['prompt']: - # req.system_prompt += f"\n{provider.curr_personality['prompt']}" - @filter.command("set") async def set_variable(self, event: AstrMessageEvent, key: str, value: str): session_id = event.get_session_id() @@ -428,32 +413,75 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo await platform.logout() yield event.plain_result("已登出 gewechat") return - - @filter.command_group("kdb") - def kdb(self): - pass - - @kdb.command("on") - async def on_kdb(self, event: AstrMessageEvent): - self.kdb_enabled = True - curr_kdb_name = self.context.provider_manager.curr_kdb_name - if not curr_kdb_name: - yield event.plain_result("未载入任何知识库") - else: - yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}") - - @kdb.command("off") - async def off_kdb(self, event: AstrMessageEvent): - self.kdb_enabled = False - yield event.plain_result("知识库已关闭") - + + + @filter.platform_adapter_type(filter.PlatformAdapterType.ALL) + async def on_message(self, event: AstrMessageEvent): + '''长期记忆''' + if self.ltm: + await self.ltm.handle_message(event) + + @filter.on_llm_request() - async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest): - curr_kdb_name = self.context.provider_manager.curr_kdb_name - if self.kdb_enabled and curr_kdb_name: - mgr = self.context.knowledge_db_manager - results = await mgr.retrive_records(curr_kdb_name, req.prompt) - if results: - req.system_prompt += "\nHere are documents that related to user's query: \n" - for result in results: - req.system_prompt += f"- {result}\n" \ No newline at end of file + async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest): + '''在请求 LLM 前注入人格信息、Identifier、时间等 System Prompt''' + provider = self.context.get_using_provider() + if self.prompt_prefix: + req.prompt = self.prompt_prefix + req.prompt + + if self.identifier: + user_id = event.message_obj.sender.user_id + user_nickname = event.message_obj.sender.nickname + user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n" + req.prompt = user_info + req.prompt + + if self.enable_datetime: + req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n" + + if persona := provider.curr_personality: + if prompt := persona['prompt']: + req.system_prompt += prompt + if mood_dialogs := persona['_mood_imitation_dialogs_processed']: + req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n" + req.system_prompt += mood_dialogs + if begin_dialogs := persona["_begin_dialogs_processed"]: + req.contexts[:0] = begin_dialogs + + if self.ltm: + await self.ltm.on_req_llm(event, req) + + + @filter.after_message_sent() + async def after_llm_req(self, event: AstrMessageEvent): + '''在 LLM 请求后记录对话''' + if self.ltm: + await self.ltm.after_req_llm(event) + + # @filter.command_group("kdb") + # def kdb(self): + # pass + + # @kdb.command("on") + # async def on_kdb(self, event: AstrMessageEvent): + # self.kdb_enabled = True + # curr_kdb_name = self.context.provider_manager.curr_kdb_name + # if not curr_kdb_name: + # yield event.plain_result("未载入任何知识库") + # else: + # yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}") + + # @kdb.command("off") + # async def off_kdb(self, event: AstrMessageEvent): + # self.kdb_enabled = False + # yield event.plain_result("知识库已关闭") + + # @filter.on_llm_request() + # async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest): + # curr_kdb_name = self.context.provider_manager.curr_kdb_name + # if self.kdb_enabled and curr_kdb_name: + # mgr = self.context.knowledge_db_manager + # results = await mgr.retrive_records(curr_kdb_name, req.prompt) + # if results: + # req.system_prompt += "\nHere are documents that related to user's query: \n" + # for result in results: + # req.system_prompt += f"- {result}\n" \ No newline at end of file