Merge remote-tracking branch 'refs/remotes/origin/master'
This commit is contained in:
+7
-2
@@ -641,9 +641,14 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
|
||||
|
||||
chatgpt_res = ""
|
||||
|
||||
if session_id in gocq_bot.waiting and gocq_bot.waiting[session_id] == '':
|
||||
gocq_bot.waiting[session_id] = qq_msg
|
||||
# 如果是等待回复的消息
|
||||
if platform == PLATFORM_GOCQ and session_id in gocq_bot.waiting and gocq_bot.waiting[session_id] == '':
|
||||
gocq_bot.waiting[session_id] = message
|
||||
return
|
||||
if platform == PLATFORM_QQCHAN and session_id in qqchannel_bot.waiting and qqchannel_bot.waiting[session_id] == '':
|
||||
qqchannel_bot.waiting[session_id] = message
|
||||
return
|
||||
|
||||
hit, command_result = llm_command_instance[chosen_provider].check_command(
|
||||
qq_msg,
|
||||
session_id,
|
||||
|
||||
@@ -48,6 +48,7 @@ class AstrMessageEvent():
|
||||
platform: str # `gocq` 或 `qqchan`
|
||||
role: str # `admin` 或 `member`
|
||||
global_object: GlobalObject # 一些公用数据
|
||||
session_id: int # 会话id (可能是群id,也可能是某个user的id。取决于是否开启了 uniqueSession)
|
||||
|
||||
def __init__(self, message_str: str,
|
||||
message_obj: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage],
|
||||
@@ -56,7 +57,8 @@ class AstrMessageEvent():
|
||||
platform: str,
|
||||
role: str,
|
||||
global_object: GlobalObject,
|
||||
llm_provider: Provider = None):
|
||||
llm_provider: Provider = None,
|
||||
session_id: int = None):
|
||||
self.message_str = message_str
|
||||
self.message_obj = message_obj
|
||||
self.gocq_platform = gocq_platform
|
||||
@@ -64,4 +66,5 @@ class AstrMessageEvent():
|
||||
self.platform = platform
|
||||
self.role = role
|
||||
self.global_object = global_object
|
||||
self.llm_provider = llm_provider
|
||||
self.llm_provider = llm_provider
|
||||
self.session_id = session_id
|
||||
@@ -50,7 +50,8 @@ class Command:
|
||||
qq_sdk_platform=self.global_object.platform_qqchan,
|
||||
platform=platform,
|
||||
role=role,
|
||||
global_object=self.global_object
|
||||
global_object=self.global_object,
|
||||
session_id = session_id
|
||||
)
|
||||
for k, v in cached_plugins.items():
|
||||
try:
|
||||
|
||||
+11
-4
@@ -4,8 +4,11 @@ from util.cmd_config import CmdConfig
|
||||
import asyncio
|
||||
from nakuru import (
|
||||
CQHTTP,
|
||||
GuildMessage
|
||||
GuildMessage,
|
||||
GroupMessage,
|
||||
FriendMessage
|
||||
)
|
||||
from typing import Union
|
||||
import time
|
||||
|
||||
|
||||
@@ -155,18 +158,22 @@ class QQ:
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def wait_for_message(self, group_id):
|
||||
def wait_for_message(self, group_id) -> Union[GroupMessage, FriendMessage, GuildMessage]:
|
||||
'''
|
||||
等待下一条消息
|
||||
等待下一条消息,超时 300s 后抛出异常
|
||||
'''
|
||||
self.waiting[group_id] = ''
|
||||
cnt = 0
|
||||
while True:
|
||||
if group_id in self.waiting and self.waiting[group_id] != '':
|
||||
# 去掉
|
||||
ret = self.waiting[group_id]
|
||||
del self.waiting[group_id]
|
||||
return ret
|
||||
time.sleep(0.5)
|
||||
cnt += 1
|
||||
if cnt > 300:
|
||||
raise Exception("等待消息超时。")
|
||||
time.sleep(1)
|
||||
|
||||
def get_client(self):
|
||||
return self.client
|
||||
|
||||
@@ -10,6 +10,7 @@ from util import general_utils as gu
|
||||
from nakuru.entities.components import Plain, At, Image
|
||||
from botpy.types.message import Reference
|
||||
from botpy import Client
|
||||
import time
|
||||
|
||||
class NakuruGuildMember():
|
||||
tiny_id: int # 发送者识别号
|
||||
@@ -38,6 +39,7 @@ class NakuruGuildMessage():
|
||||
class QQChan():
|
||||
def __init__(self, cnt: dict = None) -> None:
|
||||
self.qqchan_cnt = 0
|
||||
self.waiting: dict = {}
|
||||
|
||||
def get_cnt(self):
|
||||
return self.qqchan_cnt
|
||||
@@ -133,7 +135,7 @@ class QQChan():
|
||||
|
||||
try:
|
||||
# reply_res = asyncio.run_coroutine_threadsafe(message.raw_message.reply(content=str(plain_text), message_reference = msg_ref, file_image=image_path), self.client.loop)
|
||||
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=message.channel_id,
|
||||
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=str(message.channel_id),
|
||||
content=str(plain_text),
|
||||
msg_id=message.message_id,
|
||||
file_image=image_path,
|
||||
@@ -146,7 +148,7 @@ class QQChan():
|
||||
split_res.append(plain_text[:len(plain_text)//2])
|
||||
split_res.append(plain_text[len(plain_text)//2:])
|
||||
for i in split_res:
|
||||
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=message.channel_id,
|
||||
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=str(message.channel_id),
|
||||
content=str(i),
|
||||
msg_id=message.message_id,
|
||||
file_image=image_path,
|
||||
@@ -157,7 +159,7 @@ class QQChan():
|
||||
try:
|
||||
# 防止被qq频道过滤消息
|
||||
plain_text = plain_text.replace(".", " . ")
|
||||
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=message.channel_id,
|
||||
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=str(message.channel_id),
|
||||
content=str(plain_text),
|
||||
msg_id=message.message_id,
|
||||
file_image=image_path,
|
||||
@@ -166,7 +168,7 @@ class QQChan():
|
||||
print("QQ频道API错误: \n"+str(e))
|
||||
try:
|
||||
# reply_res = asyncio.run_coroutine_threadsafe(message.raw_message.reply(content=str(str.join(" ", plain_text)), message_reference = msg_ref, file_image=image_path), self.client.loop)
|
||||
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=message.channel_id,
|
||||
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=str(message.channel_id),
|
||||
content=str(str.join(" ", plain_text)),
|
||||
msg_id=message.message_id,
|
||||
file_image=image_path,
|
||||
@@ -174,18 +176,42 @@ class QQChan():
|
||||
except BaseException as e:
|
||||
plain_text = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE)
|
||||
plain_text = plain_text.replace(".", "·")
|
||||
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=message.channel_id,
|
||||
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=str(message.channel_id),
|
||||
content=plain_text,
|
||||
msg_id=message.message_id,
|
||||
file_image=image_path,
|
||||
message_reference=msg_ref), self.client.loop).result()
|
||||
# send(message, f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}")
|
||||
|
||||
def push_message(self, channel_id: int, message_chain: list):
|
||||
def push_message(self, channel_id: int, message_chain: list, message_id: int = None):
|
||||
'''
|
||||
推送消息
|
||||
推送消息, 如果有 message_id,那么就是回复消息。
|
||||
'''
|
||||
_n = NakuruGuildMessage()
|
||||
_n.channel_id = channel_id
|
||||
_n.message_id = message_id
|
||||
self.send_qq_msg(_n, message_chain)
|
||||
|
||||
def send(self, message_obj, message_chain: list):
|
||||
'''
|
||||
发送信息
|
||||
'''
|
||||
self.send_qq_msg(message_obj, message_chain)
|
||||
|
||||
def wait_for_message(self, channel_id: int) -> NakuruGuildMessage:
|
||||
'''
|
||||
等待指定 channel_id 的下一条信息,超时 300s 后抛出异常
|
||||
'''
|
||||
self.waiting[channel_id] = ''
|
||||
cnt = 0
|
||||
while True:
|
||||
if channel_id in self.waiting and self.waiting[channel_id] != '':
|
||||
# 去掉
|
||||
ret = self.waiting[channel_id]
|
||||
del self.waiting[channel_id]
|
||||
return ret
|
||||
cnt += 1
|
||||
if cnt > 300:
|
||||
raise Exception("等待消息超时。")
|
||||
time.sleep(1)
|
||||
|
||||
Reference in New Issue
Block a user