feat: [beta] 支持群聊内基于概率的主动回复

This commit is contained in:
Soulter
2025-02-02 19:23:46 +08:00
parent aa8c56a688
commit 4611ce15eb
7 changed files with 138 additions and 19 deletions
+46 -3
View File
@@ -56,6 +56,13 @@ DEFAULT_CONFIG = {
"group_message_max_cnt": 300,
"image_caption": False,
"image_caption_prompt": "Please describe the image using Chinese.",
"active_reply": {
"enable": False,
"method": "possibility_reply",
"possibility_reply": 0.1,
"prompt": "",
},
"put_history_to_prompt": True,
},
"content_safety": {
"internal_keywords": {"enable": True, "extra_keywords": []},
@@ -654,25 +661,61 @@ CONFIG_METADATA_2 = {
"group_icl_enable": {
"description": "群聊内记录各群员对话",
"type": "bool",
"obvious-hint": True,
"obvious_hint": True,
"hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。",
},
"group_message_max_cnt": {
"description": "群聊消息最大数量",
"type": "int",
"obvious-hint": True,
"obvious_hint": True,
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
},
"image_caption": {
"description": "启用图像转述(需要模型支持)",
"type": "bool",
"obvious-hint": True,
"obvious_hint": True,
"hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型。",
},
"image_caption_prompt": {
"description": "图像转述提示词",
"type": "string"
},
"active_reply": {
"description": "主动回复",
"type": "object",
"items": {
"enable": {
"description": "启用主动回复",
"type": "bool",
"obvious_hint": True,
"hint": "启用后,会根据触发概率主动回复群聊内的对话。",
},
"method": {
"description": "回复方法",
"type": "string",
"options": ["possibility_reply"],
"hint": "回复方法。possibility_reply 为根据概率回复",
},
"possibility_reply": {
"description": "回复概率",
"type": "float",
"obvious_hint": True,
"hint": "回复概率。当回复方法为 possibility_reply 时有效。当概率 >= 1 时,每条消息都会回复。",
},
"prompt": {
"description": "提示词",
"type": "string",
"obvious_hint": True,
"hint": "提示词。当提示词为空时,如果触发回复,prompt是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。",
},
},
},
"put_history_to_prompt": {
"description": "将群聊历史记录作为 prompt",
"type": "bool",
"obvious_hint": True,
"hint": "需要先启用 group_icl_enable。此功能会将群聊历史记录放到 prompt 再请求。如果关闭,则是放在 system_prompt。如果开启了主动回复,建议启用,模型能够更好地完成回复任务。",
}
},
},
},
@@ -37,7 +37,11 @@ class ProcessStage(Stage):
# Handler 的 LLM 请求
logger.debug(f"llm request -> {resp.prompt}")
event.set_extra("provider_request", resp)
_t = False
async for _ in self.llm_request_sub_stage.process(event):
_t = True
yield
if not _t:
yield
else:
yield
@@ -296,6 +296,7 @@ class AstrMessageEvent(abc.ABC):
def request_llm(
self,
prompt: str,
func_tool_manager = None,
session_id: str = None,
image_urls: List[str] = None,
contexts: List = None,
@@ -311,11 +312,13 @@ class AstrMessageEvent(abc.ABC):
image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。
contexts: 当指定 contexts 时,将会**只**使用 contexts 作为上下文。
func_tool_manager: 函数工具管理器,用于调用函数工具。用 self.context.get_llm_tool_manager() 获取。
'''
return ProviderRequest(
prompt = prompt,
session_id = session_id,
image_urls = image_urls,
func_tool = func_tool_manager,
contexts = contexts,
system_prompt = system_prompt
)
+2 -1
View File
@@ -54,4 +54,5 @@ class LLMResponse:
tools_call_name: List[str] = field(default_factory=list)
'''工具调用名称'''
raw_completion: ChatCompletion = None
raw_completion: ChatCompletion = None
_new_record: Dict[str, any] = None
@@ -122,6 +122,7 @@ class LLMTunerModelLoader(Provider):
async def forget(self, session_id):
self.session_memory[session_id] = []
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
return True
async def get_current_key(self):
+46 -8
View File
@@ -1,5 +1,6 @@
import datetime
import uuid
import random
import astrbot.api.star as star
from astrbot.api.event import AstrMessageEvent
from astrbot.api.platform import MessageType
@@ -8,7 +9,9 @@ from astrbot.api.message_components import Plain, Image
from astrbot import logger
from collections import defaultdict
'''
聊天记忆增强
'''
class LongTermMemory:
def __init__(self, config: dict, context: star.Context):
self.config = config
@@ -23,6 +26,13 @@ class LongTermMemory:
self.image_caption = self.config["image_caption"]
self.image_caption_prompt = self.config["image_caption_prompt"]
self.active_reply = self.config["active_reply"]
self.ar_method = self.active_reply["method"]
self.ar_possibility = self.active_reply["possibility_reply"]
self.ar_prompt = self.active_reply.get("prompt", "")
self.put_history_to_prompt = self.config["put_history_to_prompt"]
async def remove_session(self, event: AstrMessageEvent) -> int:
cnt = 0
if event.unified_msg_origin in self.session_chats:
@@ -39,8 +49,26 @@ class LongTermMemory:
persist=False,
)
return response.completion_text
async def need_active_reply(self, event: AstrMessageEvent) -> bool:
if not self.active_reply:
return False
if event.get_message_type() != MessageType.GROUP_MESSAGE:
return False
if event.is_at_or_wake_command:
# if the message is a command, let it pass
return False
match self.ar_method:
case "possibility_reply":
return random.random() < self.ar_possibility
return False
async def handle_message(self, event: AstrMessageEvent):
'''仅支持群聊'''
if event.get_message_type() == MessageType.GROUP_MESSAGE:
datetime_str = datetime.datetime.now().strftime("%H:%M:%S")
@@ -59,23 +87,33 @@ class LongTermMemory:
final_message += f" [Image: {caption}]"
except Exception as e:
logger.error(f"获取图片描述失败: {e}")
else:
final_message += " [Image]"
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
self.session_chats[event.unified_msg_origin].append(final_message)
if len(self.session_chats[event.unified_msg_origin]) > self.max_cnt:
self.session_chats[event.unified_msg_origin].pop(0)
async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest):
'''当触发 LLM 请求前,调用此方法修改 req'''
if event.unified_msg_origin not in self.session_chats:
return
chats_str = '\n---\n'.join(self.session_chats[event.unified_msg_origin])
req.system_prompt += "You are now in a chatroom. The chat history is as follows: \n"
req.system_prompt += chats_str
if self.image_caption:
req.system_prompt += (
"The images sent by the members are displayed in text form above."
)
if self.put_history_to_prompt:
prompt = req.prompt
req.prompt = f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
req.prompt += f"\nNow, a new message is coming: `{prompt}`. Please react to it. Only output your response and do not output any other information."
req.contexts = [] # 清空上下文,当使用了群聊增强,所有聊天记录都在一个prompt中。
else:
req.system_prompt += "You are now in a chatroom. The chat history is as follows: \n"
req.system_prompt += chats_str
if self.image_caption:
req.system_prompt += (
"The images sent by the members are displayed in text form above."
)
async def after_req_llm(self, event: AstrMessageEvent):
if event.unified_msg_origin not in self.session_chats:
return
+36 -7
View File
@@ -5,7 +5,7 @@ import astrbot.api.star as star
import astrbot.api.event.filter as filter
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api import sp
from astrbot.api.provider import Personality, ProviderRequest
from astrbot.api.provider import Personality, ProviderRequest, LLMResponse
from astrbot.api.platform import MessageType
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
from astrbot.core.config.default import VERSION
@@ -25,7 +25,7 @@ class Main(star.Star):
self.enable_datetime = cfg['provider_settings']["datetime_system_prompt"]
self.ltm = None
if self.context.get_config()['provider_ltm_settings']['group_icl_enable']:
if self.context.get_config()['provider_ltm_settings']['group_icl_enable'] or self.context.get_config()['provider_ltm_settings']['active_reply']['enable']:
try:
self.ltm = LongTermMemory(self.context.get_config()['provider_ltm_settings'], self.context)
except BaseException as e:
@@ -452,12 +452,41 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
@filter.platform_adapter_type(filter.PlatformAdapterType.ALL)
async def on_message(self, event: AstrMessageEvent):
'''长期记忆'''
'''群聊记忆增强'''
if self.ltm:
try:
await self.ltm.handle_message(event)
except BaseException as e:
logger.error(e)
need_active = await self.ltm.need_active_reply(event)
print(need_active)
group_icl_enable = self.context.get_config()['provider_ltm_settings']['group_icl_enable']
if group_icl_enable:
'''记录对话'''
try:
await self.ltm.handle_message(event)
except BaseException as e:
logger.error(e)
if need_active:
'''主动回复'''
provider = self.context.get_using_provider()
if not provider:
logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复")
return
try:
session_provider_context = provider.session_memory.get(event.session_id)
prompt = self.ltm.ar_prompt
if not prompt:
prompt = event.message_str
yield event.request_llm(
prompt=prompt,
func_tool_manager=self.context.get_llm_tool_manager(),
session_id=event.session_id,
contexts=session_provider_context if session_provider_context else []
)
except BaseException as e:
logger.error(f"主动回复失败: {e}")
@filter.on_llm_request()