diff --git a/cores/qqbot/core.py b/cores/qqbot/core.py index 262483622..1ef1d6eb1 100644 --- a/cores/qqbot/core.py +++ b/cores/qqbot/core.py @@ -603,9 +603,14 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak chatgpt_res = "" - if session_id in gocq_bot.waiting and gocq_bot.waiting[session_id] == '': - gocq_bot.waiting[session_id] = qq_msg + # 如果是等待回复的消息 + if platform == PLATFORM_GOCQ and session_id in gocq_bot.waiting and gocq_bot.waiting[session_id] == '': + gocq_bot.waiting[session_id] = message return + if platform == PLATFORM_QQCHAN and session_id in qqchannel_bot.waiting and qqchannel_bot.waiting[session_id] == '': + qqchannel_bot.waiting[session_id] = message + return + hit, command_result = llm_command_instance[chosen_provider].check_command( qq_msg, session_id, diff --git a/cores/qqbot/global_object.py b/cores/qqbot/global_object.py index ace53fe2c..cc6873450 100644 --- a/cores/qqbot/global_object.py +++ b/cores/qqbot/global_object.py @@ -48,6 +48,7 @@ class AstrMessageEvent(): platform: str # `gocq` 或 `qqchan` role: str # `admin` 或 `member` global_object: GlobalObject # 一些公用数据 + session_id: int # 会话id (可能是群id,也可能是某个user的id。取决于是否开启了 uniqueSession) def __init__(self, message_str: str, message_obj: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage], @@ -56,7 +57,8 @@ class AstrMessageEvent(): platform: str, role: str, global_object: GlobalObject, - llm_provider: Provider = None): + llm_provider: Provider = None, + session_id: int = None): self.message_str = message_str self.message_obj = message_obj self.gocq_platform = gocq_platform @@ -64,4 +66,5 @@ class AstrMessageEvent(): self.platform = platform self.role = role self.global_object = global_object - self.llm_provider = llm_provider \ No newline at end of file + self.llm_provider = llm_provider + self.session_id = session_id \ No newline at end of file diff --git a/model/command/command.py b/model/command/command.py index b23c4ec05..7fc068c61 100644 --- a/model/command/command.py +++ b/model/command/command.py @@ -50,7 +50,8 @@ class Command: qq_sdk_platform=self.global_object.platform_qqchan, platform=platform, role=role, - global_object=self.global_object + global_object=self.global_object, + session_id = session_id ) for k, v in cached_plugins.items(): try: diff --git a/model/platform/qq.py b/model/platform/qq.py index a52342628..1b7a910d8 100644 --- a/model/platform/qq.py +++ b/model/platform/qq.py @@ -4,8 +4,11 @@ from util.cmd_config import CmdConfig import asyncio from nakuru import ( CQHTTP, - GuildMessage + GuildMessage, + GroupMessage, + FriendMessage ) +from typing import Union import time @@ -155,18 +158,22 @@ class QQ: except Exception as e: raise e - def wait_for_message(self, group_id): + def wait_for_message(self, group_id) -> Union[GroupMessage, FriendMessage, GuildMessage]: ''' - 等待下一条消息 + 等待下一条消息,超时 300s 后抛出异常 ''' self.waiting[group_id] = '' + cnt = 0 while True: if group_id in self.waiting and self.waiting[group_id] != '': # 去掉 ret = self.waiting[group_id] del self.waiting[group_id] return ret - time.sleep(0.5) + cnt += 1 + if cnt > 300: + raise Exception("等待消息超时。") + time.sleep(1) def get_client(self): return self.client diff --git a/model/platform/qqchan.py b/model/platform/qqchan.py index 486b45b42..1f891b6f4 100644 --- a/model/platform/qqchan.py +++ b/model/platform/qqchan.py @@ -10,6 +10,7 @@ from util import general_utils as gu from nakuru.entities.components import Plain, At, Image from botpy.types.message import Reference from botpy import Client +import time class NakuruGuildMember(): tiny_id: int # 发送者识别号 @@ -38,6 +39,7 @@ class NakuruGuildMessage(): class QQChan(): def __init__(self, cnt: dict = None) -> None: self.qqchan_cnt = 0 + self.waiting: dict = {} def get_cnt(self): return self.qqchan_cnt @@ -181,11 +183,35 @@ class QQChan(): message_reference=msg_ref), self.client.loop).result() # send(message, f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}") - def push_message(self, channel_id: int, message_chain: list): + def push_message(self, channel_id: int, message_chain: list, message_id: int = None): ''' - 推送消息 + 推送消息, 如果有 message_id,那么就是回复消息。 ''' _n = NakuruGuildMessage() _n.channel_id = channel_id + _n.message_id = message_id self.send_qq_msg(_n, message_chain) + + def send(self, message_obj, message_chain: list): + ''' + 发送信息 + ''' + self.send_qq_msg(message_obj, message_chain) + + def wait_for_message(self, channel_id: int) -> NakuruGuildMessage: + ''' + 等待指定 channel_id 的下一条信息,超时 300s 后抛出异常 + ''' + self.waiting[channel_id] = '' + cnt = 0 + while True: + if channel_id in self.waiting and self.waiting[channel_id] != '': + # 去掉 + ret = self.waiting[channel_id] + del self.waiting[channel_id] + return ret + cnt += 1 + if cnt > 300: + raise Exception("等待消息超时。") + time.sleep(1) \ No newline at end of file