Merge remote-tracking branch 'refs/remotes/origin/master'

This commit is contained in:
Soulter
2023-11-25 11:51:47 +08:00
5 changed files with 58 additions and 16 deletions
+7 -2
View File
@@ -641,9 +641,14 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
chatgpt_res = ""
if session_id in gocq_bot.waiting and gocq_bot.waiting[session_id] == '':
gocq_bot.waiting[session_id] = qq_msg
# 如果是等待回复的消息
if platform == PLATFORM_GOCQ and session_id in gocq_bot.waiting and gocq_bot.waiting[session_id] == '':
gocq_bot.waiting[session_id] = message
return
if platform == PLATFORM_QQCHAN and session_id in qqchannel_bot.waiting and qqchannel_bot.waiting[session_id] == '':
qqchannel_bot.waiting[session_id] = message
return
hit, command_result = llm_command_instance[chosen_provider].check_command(
qq_msg,
session_id,
+5 -2
View File
@@ -48,6 +48,7 @@ class AstrMessageEvent():
platform: str # `gocq` 或 `qqchan`
role: str # `admin` 或 `member`
global_object: GlobalObject # 一些公用数据
session_id: int # 会话id (可能是群id,也可能是某个user的id。取决于是否开启了 uniqueSession)
def __init__(self, message_str: str,
message_obj: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage],
@@ -56,7 +57,8 @@ class AstrMessageEvent():
platform: str,
role: str,
global_object: GlobalObject,
llm_provider: Provider = None):
llm_provider: Provider = None,
session_id: int = None):
self.message_str = message_str
self.message_obj = message_obj
self.gocq_platform = gocq_platform
@@ -64,4 +66,5 @@ class AstrMessageEvent():
self.platform = platform
self.role = role
self.global_object = global_object
self.llm_provider = llm_provider
self.llm_provider = llm_provider
self.session_id = session_id
+2 -1
View File
@@ -50,7 +50,8 @@ class Command:
qq_sdk_platform=self.global_object.platform_qqchan,
platform=platform,
role=role,
global_object=self.global_object
global_object=self.global_object,
session_id = session_id
)
for k, v in cached_plugins.items():
try:
+11 -4
View File
@@ -4,8 +4,11 @@ from util.cmd_config import CmdConfig
import asyncio
from nakuru import (
CQHTTP,
GuildMessage
GuildMessage,
GroupMessage,
FriendMessage
)
from typing import Union
import time
@@ -155,18 +158,22 @@ class QQ:
except Exception as e:
raise e
def wait_for_message(self, group_id):
def wait_for_message(self, group_id) -> Union[GroupMessage, FriendMessage, GuildMessage]:
'''
等待下一条消息
等待下一条消息,超时 300s 后抛出异常
'''
self.waiting[group_id] = ''
cnt = 0
while True:
if group_id in self.waiting and self.waiting[group_id] != '':
# 去掉
ret = self.waiting[group_id]
del self.waiting[group_id]
return ret
time.sleep(0.5)
cnt += 1
if cnt > 300:
raise Exception("等待消息超时。")
time.sleep(1)
def get_client(self):
return self.client
+33 -7
View File
@@ -10,6 +10,7 @@ from util import general_utils as gu
from nakuru.entities.components import Plain, At, Image
from botpy.types.message import Reference
from botpy import Client
import time
class NakuruGuildMember():
tiny_id: int # 发送者识别号
@@ -38,6 +39,7 @@ class NakuruGuildMessage():
class QQChan():
def __init__(self, cnt: dict = None) -> None:
self.qqchan_cnt = 0
self.waiting: dict = {}
def get_cnt(self):
return self.qqchan_cnt
@@ -133,7 +135,7 @@ class QQChan():
try:
# reply_res = asyncio.run_coroutine_threadsafe(message.raw_message.reply(content=str(plain_text), message_reference = msg_ref, file_image=image_path), self.client.loop)
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=message.channel_id,
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=str(message.channel_id),
content=str(plain_text),
msg_id=message.message_id,
file_image=image_path,
@@ -146,7 +148,7 @@ class QQChan():
split_res.append(plain_text[:len(plain_text)//2])
split_res.append(plain_text[len(plain_text)//2:])
for i in split_res:
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=message.channel_id,
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=str(message.channel_id),
content=str(i),
msg_id=message.message_id,
file_image=image_path,
@@ -157,7 +159,7 @@ class QQChan():
try:
# 防止被qq频道过滤消息
plain_text = plain_text.replace(".", " . ")
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=message.channel_id,
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=str(message.channel_id),
content=str(plain_text),
msg_id=message.message_id,
file_image=image_path,
@@ -166,7 +168,7 @@ class QQChan():
print("QQ频道API错误: \n"+str(e))
try:
# reply_res = asyncio.run_coroutine_threadsafe(message.raw_message.reply(content=str(str.join(" ", plain_text)), message_reference = msg_ref, file_image=image_path), self.client.loop)
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=message.channel_id,
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=str(message.channel_id),
content=str(str.join(" ", plain_text)),
msg_id=message.message_id,
file_image=image_path,
@@ -174,18 +176,42 @@ class QQChan():
except BaseException as e:
plain_text = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE)
plain_text = plain_text.replace(".", "·")
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=message.channel_id,
reply_res = asyncio.run_coroutine_threadsafe(self.client.api.post_message(channel_id=str(message.channel_id),
content=plain_text,
msg_id=message.message_id,
file_image=image_path,
message_reference=msg_ref), self.client.loop).result()
# send(message, f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}")
def push_message(self, channel_id: int, message_chain: list):
def push_message(self, channel_id: int, message_chain: list, message_id: int = None):
'''
推送消息
推送消息, 如果有 message_id,那么就是回复消息。
'''
_n = NakuruGuildMessage()
_n.channel_id = channel_id
_n.message_id = message_id
self.send_qq_msg(_n, message_chain)
def send(self, message_obj, message_chain: list):
'''
发送信息
'''
self.send_qq_msg(message_obj, message_chain)
def wait_for_message(self, channel_id: int) -> NakuruGuildMessage:
'''
等待指定 channel_id 的下一条信息,超时 300s 后抛出异常
'''
self.waiting[channel_id] = ''
cnt = 0
while True:
if channel_id in self.waiting and self.waiting[channel_id] != '':
# 去掉
ret = self.waiting[channel_id]
del self.waiting[channel_id]
return ret
cnt += 1
if cnt > 300:
raise Exception("等待消息超时。")
time.sleep(1)