feat: [beta] 支持群聊内基于概率的主动回复
This commit is contained in:
@@ -56,6 +56,13 @@ DEFAULT_CONFIG = {
|
||||
"group_message_max_cnt": 300,
|
||||
"image_caption": False,
|
||||
"image_caption_prompt": "Please describe the image using Chinese.",
|
||||
"active_reply": {
|
||||
"enable": False,
|
||||
"method": "possibility_reply",
|
||||
"possibility_reply": 0.1,
|
||||
"prompt": "",
|
||||
},
|
||||
"put_history_to_prompt": True,
|
||||
},
|
||||
"content_safety": {
|
||||
"internal_keywords": {"enable": True, "extra_keywords": []},
|
||||
@@ -654,25 +661,61 @@ CONFIG_METADATA_2 = {
|
||||
"group_icl_enable": {
|
||||
"description": "群聊内记录各群员对话",
|
||||
"type": "bool",
|
||||
"obvious-hint": True,
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。",
|
||||
},
|
||||
"group_message_max_cnt": {
|
||||
"description": "群聊消息最大数量",
|
||||
"type": "int",
|
||||
"obvious-hint": True,
|
||||
"obvious_hint": True,
|
||||
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
|
||||
},
|
||||
"image_caption": {
|
||||
"description": "启用图像转述(需要模型支持)",
|
||||
"type": "bool",
|
||||
"obvious-hint": True,
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型。",
|
||||
},
|
||||
"image_caption_prompt": {
|
||||
"description": "图像转述提示词",
|
||||
"type": "string"
|
||||
},
|
||||
"active_reply": {
|
||||
"description": "主动回复",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"description": "启用主动回复",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,会根据触发概率主动回复群聊内的对话。",
|
||||
},
|
||||
"method": {
|
||||
"description": "回复方法",
|
||||
"type": "string",
|
||||
"options": ["possibility_reply"],
|
||||
"hint": "回复方法。possibility_reply 为根据概率回复",
|
||||
},
|
||||
"possibility_reply": {
|
||||
"description": "回复概率",
|
||||
"type": "float",
|
||||
"obvious_hint": True,
|
||||
"hint": "回复概率。当回复方法为 possibility_reply 时有效。当概率 >= 1 时,每条消息都会回复。",
|
||||
},
|
||||
"prompt": {
|
||||
"description": "提示词",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "提示词。当提示词为空时,如果触发回复,prompt是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"put_history_to_prompt": {
|
||||
"description": "将群聊历史记录作为 prompt",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "需要先启用 group_icl_enable。此功能会将群聊历史记录放到 prompt 再请求。如果关闭,则是放在 system_prompt。如果开启了主动回复,建议启用,模型能够更好地完成回复任务。",
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -37,7 +37,11 @@ class ProcessStage(Stage):
|
||||
# Handler 的 LLM 请求
|
||||
logger.debug(f"llm request -> {resp.prompt}")
|
||||
event.set_extra("provider_request", resp)
|
||||
_t = False
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
_t = True
|
||||
yield
|
||||
if not _t:
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
@@ -296,6 +296,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
def request_llm(
|
||||
self,
|
||||
prompt: str,
|
||||
func_tool_manager = None,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
contexts: List = None,
|
||||
@@ -311,11 +312,13 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。
|
||||
contexts: 当指定 contexts 时,将会**只**使用 contexts 作为上下文。
|
||||
func_tool_manager: 函数工具管理器,用于调用函数工具。用 self.context.get_llm_tool_manager() 获取。
|
||||
'''
|
||||
return ProviderRequest(
|
||||
prompt = prompt,
|
||||
session_id = session_id,
|
||||
image_urls = image_urls,
|
||||
func_tool = func_tool_manager,
|
||||
contexts = contexts,
|
||||
system_prompt = system_prompt
|
||||
)
|
||||
@@ -54,4 +54,5 @@ class LLMResponse:
|
||||
tools_call_name: List[str] = field(default_factory=list)
|
||||
'''工具调用名称'''
|
||||
|
||||
raw_completion: ChatCompletion = None
|
||||
raw_completion: ChatCompletion = None
|
||||
_new_record: Dict[str, any] = None
|
||||
@@ -122,6 +122,7 @@ class LLMTunerModelLoader(Provider):
|
||||
|
||||
async def forget(self, session_id):
|
||||
self.session_memory[session_id] = []
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
return True
|
||||
|
||||
async def get_current_key(self):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import datetime
|
||||
import uuid
|
||||
import random
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.api.platform import MessageType
|
||||
@@ -8,7 +9,9 @@ from astrbot.api.message_components import Plain, Image
|
||||
from astrbot import logger
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
'''
|
||||
聊天记忆增强
|
||||
'''
|
||||
class LongTermMemory:
|
||||
def __init__(self, config: dict, context: star.Context):
|
||||
self.config = config
|
||||
@@ -23,6 +26,13 @@ class LongTermMemory:
|
||||
self.image_caption = self.config["image_caption"]
|
||||
self.image_caption_prompt = self.config["image_caption_prompt"]
|
||||
|
||||
self.active_reply = self.config["active_reply"]
|
||||
self.ar_method = self.active_reply["method"]
|
||||
self.ar_possibility = self.active_reply["possibility_reply"]
|
||||
self.ar_prompt = self.active_reply.get("prompt", "")
|
||||
|
||||
self.put_history_to_prompt = self.config["put_history_to_prompt"]
|
||||
|
||||
async def remove_session(self, event: AstrMessageEvent) -> int:
|
||||
cnt = 0
|
||||
if event.unified_msg_origin in self.session_chats:
|
||||
@@ -39,8 +49,26 @@ class LongTermMemory:
|
||||
persist=False,
|
||||
)
|
||||
return response.completion_text
|
||||
|
||||
async def need_active_reply(self, event: AstrMessageEvent) -> bool:
|
||||
if not self.active_reply:
|
||||
return False
|
||||
if event.get_message_type() != MessageType.GROUP_MESSAGE:
|
||||
return False
|
||||
|
||||
if event.is_at_or_wake_command:
|
||||
# if the message is a command, let it pass
|
||||
return False
|
||||
|
||||
match self.ar_method:
|
||||
case "possibility_reply":
|
||||
return random.random() < self.ar_possibility
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def handle_message(self, event: AstrMessageEvent):
|
||||
'''仅支持群聊'''
|
||||
if event.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
datetime_str = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
|
||||
@@ -59,23 +87,33 @@ class LongTermMemory:
|
||||
final_message += f" [Image: {caption}]"
|
||||
except Exception as e:
|
||||
logger.error(f"获取图片描述失败: {e}")
|
||||
else:
|
||||
final_message += " [Image]"
|
||||
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
|
||||
self.session_chats[event.unified_msg_origin].append(final_message)
|
||||
if len(self.session_chats[event.unified_msg_origin]) > self.max_cnt:
|
||||
self.session_chats[event.unified_msg_origin].pop(0)
|
||||
|
||||
async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
'''当触发 LLM 请求前,调用此方法修改 req'''
|
||||
if event.unified_msg_origin not in self.session_chats:
|
||||
return
|
||||
|
||||
chats_str = '\n---\n'.join(self.session_chats[event.unified_msg_origin])
|
||||
req.system_prompt += "You are now in a chatroom. The chat history is as follows: \n"
|
||||
req.system_prompt += chats_str
|
||||
if self.image_caption:
|
||||
req.system_prompt += (
|
||||
"The images sent by the members are displayed in text form above."
|
||||
)
|
||||
|
||||
|
||||
if self.put_history_to_prompt:
|
||||
prompt = req.prompt
|
||||
req.prompt = f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
|
||||
req.prompt += f"\nNow, a new message is coming: `{prompt}`. Please react to it. Only output your response and do not output any other information."
|
||||
req.contexts = [] # 清空上下文,当使用了群聊增强,所有聊天记录都在一个prompt中。
|
||||
else:
|
||||
req.system_prompt += "You are now in a chatroom. The chat history is as follows: \n"
|
||||
req.system_prompt += chats_str
|
||||
if self.image_caption:
|
||||
req.system_prompt += (
|
||||
"The images sent by the members are displayed in text form above."
|
||||
)
|
||||
|
||||
async def after_req_llm(self, event: AstrMessageEvent):
|
||||
if event.unified_msg_origin not in self.session_chats:
|
||||
return
|
||||
|
||||
@@ -5,7 +5,7 @@ import astrbot.api.star as star
|
||||
import astrbot.api.event.filter as filter
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import sp
|
||||
from astrbot.api.provider import Personality, ProviderRequest
|
||||
from astrbot.api.provider import Personality, ProviderRequest, LLMResponse
|
||||
from astrbot.api.platform import MessageType
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.core.config.default import VERSION
|
||||
@@ -25,7 +25,7 @@ class Main(star.Star):
|
||||
self.enable_datetime = cfg['provider_settings']["datetime_system_prompt"]
|
||||
|
||||
self.ltm = None
|
||||
if self.context.get_config()['provider_ltm_settings']['group_icl_enable']:
|
||||
if self.context.get_config()['provider_ltm_settings']['group_icl_enable'] or self.context.get_config()['provider_ltm_settings']['active_reply']['enable']:
|
||||
try:
|
||||
self.ltm = LongTermMemory(self.context.get_config()['provider_ltm_settings'], self.context)
|
||||
except BaseException as e:
|
||||
@@ -452,12 +452,41 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
|
||||
@filter.platform_adapter_type(filter.PlatformAdapterType.ALL)
|
||||
async def on_message(self, event: AstrMessageEvent):
|
||||
'''长期记忆'''
|
||||
'''群聊记忆增强'''
|
||||
if self.ltm:
|
||||
try:
|
||||
await self.ltm.handle_message(event)
|
||||
except BaseException as e:
|
||||
logger.error(e)
|
||||
need_active = await self.ltm.need_active_reply(event)
|
||||
|
||||
print(need_active)
|
||||
|
||||
group_icl_enable = self.context.get_config()['provider_ltm_settings']['group_icl_enable']
|
||||
if group_icl_enable:
|
||||
'''记录对话'''
|
||||
try:
|
||||
await self.ltm.handle_message(event)
|
||||
except BaseException as e:
|
||||
logger.error(e)
|
||||
|
||||
if need_active:
|
||||
'''主动回复'''
|
||||
provider = self.context.get_using_provider()
|
||||
if not provider:
|
||||
logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复")
|
||||
return
|
||||
try:
|
||||
session_provider_context = provider.session_memory.get(event.session_id)
|
||||
|
||||
prompt = self.ltm.ar_prompt
|
||||
if not prompt:
|
||||
prompt = event.message_str
|
||||
|
||||
yield event.request_llm(
|
||||
prompt=prompt,
|
||||
func_tool_manager=self.context.get_llm_tool_manager(),
|
||||
session_id=event.session_id,
|
||||
contexts=session_provider_context if session_provider_context else []
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(f"主动回复失败: {e}")
|
||||
|
||||
|
||||
@filter.on_llm_request()
|
||||
|
||||
Reference in New Issue
Block a user