feat: 支持记录非唤醒状态下群聊历史记录
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
from astrbot.core.provider import Provider, STTProvider, Personality
|
||||
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData
|
||||
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData, LLMResponse
|
||||
@@ -50,6 +50,12 @@ DEFAULT_CONFIG = {
|
||||
"enable": False,
|
||||
"provider_id": "",
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
"group_icl_enable": False,
|
||||
"group_message_max_cnt": 300,
|
||||
"image_caption": False,
|
||||
"image_caption_prompt": "Please describe the image using Chinese.",
|
||||
},
|
||||
"content_safety": {
|
||||
"internal_keywords": {"enable": True, "extra_keywords": []},
|
||||
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
|
||||
@@ -630,6 +636,34 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
"description": "聊天记忆增强(Beta)",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"group_icl_enable": {
|
||||
"description": "群聊内记录各群员对话",
|
||||
"type": "bool",
|
||||
"obvious-hint": True,
|
||||
"hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。",
|
||||
},
|
||||
"group_message_max_cnt": {
|
||||
"description": "群聊消息最大数量",
|
||||
"type": "int",
|
||||
"obvious-hint": True,
|
||||
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
|
||||
},
|
||||
"image_caption": {
|
||||
"description": "启用图像转述(需要模型支持)",
|
||||
"type": "bool",
|
||||
"obvious-hint": True,
|
||||
"hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型。",
|
||||
},
|
||||
"image_caption_prompt": {
|
||||
"description": "图像转述提示词",
|
||||
"type": "string"
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"misc_config_group": {
|
||||
|
||||
@@ -140,5 +140,10 @@ class MessageEventResult(MessageChain):
|
||||
'''
|
||||
return self.result_content_type == ResultContentType.LLM_RESULT
|
||||
|
||||
def get_plain_text(self) -> str:
|
||||
'''获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。
|
||||
'''
|
||||
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
|
||||
|
||||
|
||||
CommandResult = MessageEventResult
|
||||
@@ -39,8 +39,11 @@ class StarRequestSubStage(Stage):
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
|
||||
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
yield
|
||||
event.clear_result()
|
||||
|
||||
if event.is_at_or_wake_command:
|
||||
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
yield
|
||||
event.clear_result()
|
||||
|
||||
event.stop_event()
|
||||
@@ -163,7 +163,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
if kwargs.get("persist", True):
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
import datetime
|
||||
import uuid
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.api.platform import MessageType
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from astrbot import logger
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class LongTermMemory:
|
||||
def __init__(self, config: dict, context: star.Context):
|
||||
self.config = config
|
||||
self.context = context
|
||||
self.session_chats = defaultdict(list)
|
||||
"""记录群成员的群聊记录"""
|
||||
self.max_cnt = self.config["group_message_max_cnt"]
|
||||
self.image_caption = self.config["image_caption"]
|
||||
self.image_caption_prompt = self.config["image_caption_prompt"]
|
||||
|
||||
async def remove_session(self, event: AstrMessageEvent) -> int:
|
||||
cnt = 0
|
||||
if event.unified_msg_origin in self.session_chats:
|
||||
cnt = len(self.session_chats[event.unified_msg_origin])
|
||||
del self.session_chats[event.unified_msg_origin]
|
||||
return cnt
|
||||
|
||||
async def get_image_caption(self, image_url: str) -> str:
|
||||
provider = self.context.get_using_provider()
|
||||
response = await provider.text_chat(
|
||||
prompt=self.image_caption_prompt,
|
||||
session_id=uuid.uuid4().hex,
|
||||
image_urls=[image_url],
|
||||
persist=False,
|
||||
)
|
||||
return response.completion_text
|
||||
|
||||
async def handle_message(self, event: AstrMessageEvent):
|
||||
if event.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
datetime_str = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
|
||||
final_message = f"[{event.message_obj.sender.nickname}/{datetime_str}]: "
|
||||
|
||||
for comp in event.get_messages():
|
||||
if isinstance(comp, Plain):
|
||||
final_message += f" {comp.text}"
|
||||
elif isinstance(comp, Image):
|
||||
# image_urls.append(comp.url if comp.url else comp.file)
|
||||
if self.image_caption:
|
||||
try:
|
||||
caption = await self.get_image_caption(
|
||||
comp.url if comp.url else comp.file
|
||||
)
|
||||
final_message += f" [Image: {caption}]"
|
||||
except Exception as e:
|
||||
logger.error(f"获取图片描述失败: {e}")
|
||||
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
|
||||
self.session_chats[event.unified_msg_origin].append(final_message)
|
||||
if len(self.session_chats[event.unified_msg_origin]) > self.max_cnt:
|
||||
self.session_chats[event.unified_msg_origin].pop(0)
|
||||
|
||||
async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
if event.unified_msg_origin not in self.session_chats:
|
||||
return
|
||||
|
||||
req.system_prompt += f"""You are now in a chatroom. The chat history is as follows.:
|
||||
{'\n---\n'.join(self.session_chats[event.unified_msg_origin])}
|
||||
"""
|
||||
if self.image_caption:
|
||||
req.system_prompt += (
|
||||
"The images sent by the members are displayed in text form above."
|
||||
)
|
||||
|
||||
async def after_req_llm(self, event: AstrMessageEvent):
|
||||
if event.unified_msg_origin not in self.session_chats:
|
||||
return
|
||||
|
||||
if event.get_result() and event.get_result().is_llm_result():
|
||||
final_message = f"[AstrBot/{datetime.datetime.now().strftime('%H:%M:%S')}]: {event.get_result().get_plain_text()}"
|
||||
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
|
||||
self.session_chats[event.unified_msg_origin].append(final_message)
|
||||
if len(self.session_chats[event.unified_msg_origin]) > self.max_cnt:
|
||||
self.session_chats[event.unified_msg_origin].pop(0)
|
||||
+86
-58
@@ -6,12 +6,15 @@ import astrbot.api.event.filter as filter
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import sp
|
||||
from astrbot.api.provider import Personality, ProviderRequest
|
||||
from astrbot.api.platform import MessageType
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.core.config.default import VERSION
|
||||
from collections import defaultdict
|
||||
from .long_term_memory import LongTermMemory
|
||||
|
||||
from typing import Union
|
||||
|
||||
@star.register(name="astrbot", desc="AstrBot 基础指令集合", author="Soulter", version="4.0.0")
|
||||
@star.register(name="astrbot", desc="AstrBot 基础指令结合 + 拓展功能", author="Soulter", version="4.0.0")
|
||||
class Main(star.Star):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
@@ -20,7 +23,9 @@ class Main(star.Star):
|
||||
self.identifier = cfg['provider_settings']['identifier']
|
||||
self.enable_datetime = cfg['provider_settings']["datetime_system_prompt"]
|
||||
|
||||
self.kdb_enabled = False
|
||||
self.ltm = None
|
||||
if self.context.get_config()['provider_ltm_settings']['group_icl_enable']:
|
||||
self.ltm = LongTermMemory(self.context.get_config()['provider_ltm_settings'], self.context)
|
||||
|
||||
async def _query_astrbot_notice(self):
|
||||
try:
|
||||
@@ -219,7 +224,12 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
@filter.command("reset")
|
||||
async def reset(self, message: AstrMessageEvent):
|
||||
await self.context.get_using_provider().forget(message.session_id)
|
||||
message.set_result(MessageEventResult().message("重置成功"))
|
||||
ret = "清除会话 LLM 聊天历史成功。"
|
||||
if self.ltm:
|
||||
cnt = await self.ltm.remove_session(event=message)
|
||||
ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。"
|
||||
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
|
||||
@filter.command("model")
|
||||
async def model_ls(self, message: AstrMessageEvent, idx_or_name: Union[int, str] = None):
|
||||
@@ -355,9 +365,9 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
self.context.provider_manager.personas
|
||||
), None):
|
||||
self.context.get_using_provider().curr_personality = persona
|
||||
message.set_result(MessageEventResult().message(f"设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。"))
|
||||
message.set_result(MessageEventResult().message("设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。"))
|
||||
else:
|
||||
message.set_result(MessageEventResult().message(f"不存在该人格情景。使用 /persona list 查看所有。"))
|
||||
message.set_result(MessageEventResult().message("不存在该人格情景。使用 /persona list 查看所有。"))
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("dashboard_update")
|
||||
@@ -366,31 +376,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
await download_dashboard()
|
||||
yield event.plain_result("管理面板更新完成。")
|
||||
|
||||
@filter.on_llm_request()
|
||||
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
provider = self.context.get_using_provider()
|
||||
if self.prompt_prefix:
|
||||
req.prompt = self.prompt_prefix + req.prompt
|
||||
if self.identifier:
|
||||
user_id = event.message_obj.sender.user_id
|
||||
user_nickname = event.message_obj.sender.nickname
|
||||
user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n"
|
||||
req.prompt = user_info + req.prompt
|
||||
if self.enable_datetime:
|
||||
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
|
||||
|
||||
if persona := provider.curr_personality:
|
||||
if prompt := persona['prompt']:
|
||||
req.system_prompt += prompt
|
||||
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
|
||||
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
|
||||
req.system_prompt += mood_dialogs
|
||||
if begin_dialogs := persona["_begin_dialogs_processed"]:
|
||||
req.contexts[:0] = begin_dialogs
|
||||
|
||||
# if provider.curr_personality['prompt']:
|
||||
# req.system_prompt += f"\n{provider.curr_personality['prompt']}"
|
||||
|
||||
@filter.command("set")
|
||||
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
|
||||
session_id = event.get_session_id()
|
||||
@@ -428,32 +413,75 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
await platform.logout()
|
||||
yield event.plain_result("已登出 gewechat")
|
||||
return
|
||||
|
||||
@filter.command_group("kdb")
|
||||
def kdb(self):
|
||||
pass
|
||||
|
||||
@kdb.command("on")
|
||||
async def on_kdb(self, event: AstrMessageEvent):
|
||||
self.kdb_enabled = True
|
||||
curr_kdb_name = self.context.provider_manager.curr_kdb_name
|
||||
if not curr_kdb_name:
|
||||
yield event.plain_result("未载入任何知识库")
|
||||
else:
|
||||
yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}")
|
||||
|
||||
@kdb.command("off")
|
||||
async def off_kdb(self, event: AstrMessageEvent):
|
||||
self.kdb_enabled = False
|
||||
yield event.plain_result("知识库已关闭")
|
||||
|
||||
|
||||
|
||||
@filter.platform_adapter_type(filter.PlatformAdapterType.ALL)
|
||||
async def on_message(self, event: AstrMessageEvent):
|
||||
'''长期记忆'''
|
||||
if self.ltm:
|
||||
await self.ltm.handle_message(event)
|
||||
|
||||
|
||||
@filter.on_llm_request()
|
||||
async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
curr_kdb_name = self.context.provider_manager.curr_kdb_name
|
||||
if self.kdb_enabled and curr_kdb_name:
|
||||
mgr = self.context.knowledge_db_manager
|
||||
results = await mgr.retrive_records(curr_kdb_name, req.prompt)
|
||||
if results:
|
||||
req.system_prompt += "\nHere are documents that related to user's query: \n"
|
||||
for result in results:
|
||||
req.system_prompt += f"- {result}\n"
|
||||
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
'''在请求 LLM 前注入人格信息、Identifier、时间等 System Prompt'''
|
||||
provider = self.context.get_using_provider()
|
||||
if self.prompt_prefix:
|
||||
req.prompt = self.prompt_prefix + req.prompt
|
||||
|
||||
if self.identifier:
|
||||
user_id = event.message_obj.sender.user_id
|
||||
user_nickname = event.message_obj.sender.nickname
|
||||
user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n"
|
||||
req.prompt = user_info + req.prompt
|
||||
|
||||
if self.enable_datetime:
|
||||
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
|
||||
|
||||
if persona := provider.curr_personality:
|
||||
if prompt := persona['prompt']:
|
||||
req.system_prompt += prompt
|
||||
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
|
||||
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
|
||||
req.system_prompt += mood_dialogs
|
||||
if begin_dialogs := persona["_begin_dialogs_processed"]:
|
||||
req.contexts[:0] = begin_dialogs
|
||||
|
||||
if self.ltm:
|
||||
await self.ltm.on_req_llm(event, req)
|
||||
|
||||
|
||||
@filter.after_message_sent()
|
||||
async def after_llm_req(self, event: AstrMessageEvent):
|
||||
'''在 LLM 请求后记录对话'''
|
||||
if self.ltm:
|
||||
await self.ltm.after_req_llm(event)
|
||||
|
||||
# @filter.command_group("kdb")
|
||||
# def kdb(self):
|
||||
# pass
|
||||
|
||||
# @kdb.command("on")
|
||||
# async def on_kdb(self, event: AstrMessageEvent):
|
||||
# self.kdb_enabled = True
|
||||
# curr_kdb_name = self.context.provider_manager.curr_kdb_name
|
||||
# if not curr_kdb_name:
|
||||
# yield event.plain_result("未载入任何知识库")
|
||||
# else:
|
||||
# yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}")
|
||||
|
||||
# @kdb.command("off")
|
||||
# async def off_kdb(self, event: AstrMessageEvent):
|
||||
# self.kdb_enabled = False
|
||||
# yield event.plain_result("知识库已关闭")
|
||||
|
||||
# @filter.on_llm_request()
|
||||
# async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
# curr_kdb_name = self.context.provider_manager.curr_kdb_name
|
||||
# if self.kdb_enabled and curr_kdb_name:
|
||||
# mgr = self.context.knowledge_db_manager
|
||||
# results = await mgr.retrive_records(curr_kdb_name, req.prompt)
|
||||
# if results:
|
||||
# req.system_prompt += "\nHere are documents that related to user's query: \n"
|
||||
# for result in results:
|
||||
# req.system_prompt += f"- {result}\n"
|
||||
Reference in New Issue
Block a user