perf: 文转图渲染失败时发送纯文本

This commit is contained in:
Soulter
2024-10-06 00:20:42 +08:00
parent ab39dfd254
commit 50f62e66b0
5 changed files with 84 additions and 132 deletions
+3
View File
@@ -5,6 +5,9 @@ from type.astrbot_message import AstrBotMessage
from type.command import CommandResult
from type.astrbot_message import MessageType
class T2IException(Exception):
def __init__(self, message: str = "文本转图片时发生错误") -> None:
super().__init__(message)
class Platform():
def __init__(self, platform_name: str, context) -> None:
+28 -38
View File
@@ -4,7 +4,7 @@ import traceback
import logging
from aiocqhttp import CQHttp, Event
from aiocqhttp.exceptions import ActionFailed
from . import Platform
from . import Platform, T2IException
from type.astrbot_message import *
from type.message_event import *
from type.command import *
@@ -25,13 +25,11 @@ class AIOCQHTTP(Platform):
assert isinstance(platform_config, AiocqhttpPlatformConfig), "aiocqhttp: 无法识别的配置类型。"
self.message_handler = message_handler
self.waiting = {}
self.context = context
self.config = platform_config
self.unique_session = context.config_helper.platform_settings.unique_session
self.host = platform_config.ws_reverse_host
self.port = platform_config.ws_reverse_port
self.admins = context.config_helper.admins_id
def convert_message(self, event: Event) -> AstrBotMessage:
@@ -134,13 +132,6 @@ class AIOCQHTTP(Platform):
ok, reason = await self.pre_check(message)
if not ok:
return
# 解析 role
sender_id = str(message.sender.user_id)
if sender_id in self.admins:
role = 'admin'
else:
role = 'member'
# parse unified message origin
unified_msg_origin = None
@@ -157,7 +148,6 @@ class AIOCQHTTP(Platform):
self.context,
"aiocqhttp",
message.session_id,
role,
unified_msg_origin,
reason == "command") # only_command
@@ -169,10 +159,6 @@ class AIOCQHTTP(Platform):
if message_result.callback:
message_result.callback()
# 如果是等待回复的消息
if message.session_id in self.waiting and self.waiting[message.session_id] == '':
self.waiting[message.session_id] = message
return message_result
@@ -183,36 +169,35 @@ class AIOCQHTTP(Platform):
"""
回复用户唤醒机器人的消息。(被动回复)
"""
res = result_message
if isinstance(res, str):
res = [Plain(text=res), ]
try:
await self._reply(message, result_message, use_t2i)
except T2IException as e:
logger.error(traceback.format_exc())
logger.warning(f"文本转图片时发生错误,将使用纯文本发送。")
await self._reply(message, result_message, False)
return result_message
# if image mode, put all Plain texts into a new picture.
if (use_t2i or (use_t2i == None and self.context.config_helper.t2i)) and isinstance(result_message, list):
rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images:
try:
await self._reply(message, rendered_images)
return rendered_images
except BaseException as e:
logger.warn(traceback.format_exc())
logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。")
await self._reply(message, res)
return res
async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent]):
async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent], use_t2i: bool = None):
await self.record_metrics()
if isinstance(message_chain, str):
message_chain = [Plain(text=message_chain), ]
# 文转图处理
if (use_t2i or (use_t2i == None and self.context.config_helper.t2i)) and isinstance(message_chain, list):
try:
message_chain = await self.convert_to_t2i_chain(message_chain)
if not message_chain: raise T2IException()
except BaseException as e:
raise T2IException()
# log
if isinstance(message, AstrBotMessage):
logger.info(
f"{message.sender.nickname}/{message.sender.user_id} <- {self.parse_message_outline(message_chain)}")
else:
logger.info(f"回复消息: {message_chain}")
# 解析成 OneBot json 格式并发送
ret = []
image_idx = []
for idx, segment in enumerate(message_chain):
@@ -232,10 +217,11 @@ class AIOCQHTTP(Platform):
# ENOENT
if not image_idx:
raise e
logger.warn("回复失败。检测到失败原因为文件未找到,猜测用户的协议端与 AstrBot 位于不同的文件系统上。尝试采用上传图片的方式发图。")
logger.warning("回复失败。检测到失败原因为文件未找到,猜测用户的协议端与 AstrBot 位于不同的文件系统上。尝试采用上传图片的方式发图。")
for idx in image_idx:
if ret[idx]['data']['file'].startswith('file://'):
logger.info(f"正在上传图片: {ret[idx]['data']['path']}")
# 除了上传到图床,想不到更好的办法。
image_url = await self.context.image_uploader.upload_image(ret[idx]['data']['path'])
logger.info(f"上传成功。")
ret[idx]['data']['file'] = image_url
@@ -267,8 +253,12 @@ class AIOCQHTTP(Platform):
- 要发给某个群聊,请添加 key `group_id`,值为 int 类型的 qq 群号;
'''
await self._reply(target, result_message.message_chain)
try:
await self._reply(target, result_message, result_message.is_use_t2i)
except T2IException as e:
logger.error(traceback.format_exc())
logger.warning(f"文本转图片时发生错误,将使用纯文本发送。")
await self._reply(target, result_message, False)
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
if message_type == MessageType.GROUP_MESSAGE:
+37 -60
View File
@@ -11,7 +11,7 @@ from nakuru import (
)
from typing import Union, List, Dict
from type.types import Context
from . import Platform
from . import Platform, T2IException
from type.astrbot_message import *
from type.message_event import *
from type.command import *
@@ -40,11 +40,9 @@ class QQNakuru(Platform):
asyncio.set_event_loop(self.loop)
self.message_handler = message_handler
self.waiting = {}
self.context = context
self.unique_session = context.config_helper.platform_settings.unique_session
self.config = platform_config
self.admins = context.config_helper.admins_id
self.client = CQHTTP(
host=self.config.host,
@@ -113,13 +111,6 @@ class QQNakuru(Platform):
session_id = message.raw_message.user_id
message.session_id = session_id
# 解析 role
sender_id = str(message.raw_message.user_id)
if sender_id in self.admins:
role = 'admin'
else:
role = 'member'
# parse unified message origin
unified_msg_origin = None
@@ -141,7 +132,6 @@ class QQNakuru(Platform):
self.context,
"nakuru",
session_id,
role,
unified_msg_origin,
reason == 'command') # only_command
@@ -153,49 +143,47 @@ class QQNakuru(Platform):
if message_result.callback:
message_result.callback()
# 如果是等待回复的消息
if session_id in self.waiting and self.waiting[session_id] == '':
self.waiting[session_id] = message
async def reply_msg(self,
message: AstrBotMessage,
result_message: List[BaseMessageComponent],
use_t2i: bool = None):
"""
回复用户唤醒机器人的消息。(被动回复)
"""
source = message.raw_message
res = result_message
"""
assert isinstance(message.raw_message, (GroupMessage, FriendMessage, GuildMessage))
try:
await self._reply(message, result_message, use_t2i)
except T2IException as e:
logger.error(traceback.format_exc())
logger.warning(f"文本转图片时发生错误,将使用纯文本发送。")
await self._reply(message, result_message, False)
return result_message
assert isinstance(source,
(GroupMessage, FriendMessage, GuildMessage))
logger.info(
f"{message.sender.nickname}/{message.sender.user_id} <- {self.parse_message_outline(res)}")
if isinstance(res, str):
res = [Plain(text=res), ]
# if image mode, put all Plain texts into a new picture.
if use_t2i or (use_t2i == None and self.context.config_helper.t2i) and isinstance(result_message, list):
rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images:
try:
await self._reply(source, rendered_images)
return
except BaseException as e:
logger.warn(traceback.format_exc())
logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。")
await self._reply(source, res)
async def _reply(self, source, message_chain: List[BaseMessageComponent]):
async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent], use_t2i: bool = None):
await self.record_metrics()
if isinstance(message_chain, str):
message_chain = [Plain(text=message_chain), ]
# 文转图处理
if (use_t2i or (use_t2i == None and self.context.config_helper.t2i)) and isinstance(message_chain, list):
try:
message_chain = await self.convert_to_t2i_chain(message_chain)
if not message_chain: raise T2IException()
except BaseException as e:
raise T2IException()
# log
if isinstance(message, AstrBotMessage):
logger.info(
f"{message.sender.nickname}/{message.sender.user_id} <- {self.parse_message_outline(message_chain)}")
else:
logger.info(f"回复消息: {message_chain}")
source = message.raw_message
is_dict = isinstance(source, dict)
# 发消息
typ = None
if is_dict:
if "group_id" in source:
@@ -250,7 +238,13 @@ class QQNakuru(Platform):
guild_id 不是频道号。
'''
await self._reply(target, result_message.message_chain)
try:
await self._reply(target, result_message, result_message.is_use_t2i)
except T2IException as e:
logger.error(traceback.format_exc())
logger.warning(f"文本转图片时发生错误,将使用纯文本发送。")
await self._reply(target, result_message, False)
return result_message
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
'''
@@ -290,21 +284,4 @@ class QQNakuru(Platform):
)
abm.tag = "nakuru"
abm.message = message.message
return abm
def wait_for_message(self, group_id) -> Union[GroupMessage, FriendMessage, GuildMessage]:
'''
等待下一条消息,超时 300s 后抛出异常
'''
self.waiting[group_id] = ''
cnt = 0
while True:
if group_id in self.waiting and self.waiting[group_id] != '':
# 去掉
ret = self.waiting[group_id]
del self.waiting[group_id]
return ret
cnt += 1
if cnt > 300:
raise Exception("等待消息超时。")
time.sleep(1)
return abm
+8 -33
View File
@@ -63,10 +63,8 @@ class QQOfficial(Platform):
asyncio.set_event_loop(self.loop)
self.message_handler = message_handler
self.waiting: dict = {}
self.context = context
self.config = platform_config
self.admins = context.config_helper.admins_id
self.appid = platform_config.appid
self.secret = platform_config.secret
@@ -201,15 +199,8 @@ class QQOfficial(Platform):
session_id = str(message.raw_message.author.id)
message.session_id = session_id
# 解析出 role
sender_id = message.sender.user_id
if sender_id in self.admins:
role = 'admin'
else:
role = 'member'
# construct astrbot message event
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "qqofficial", session_id, role)
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "qqofficial", session_id)
message_result = await self.message_handler.handle(ame)
if not message_result:
@@ -219,10 +210,6 @@ class QQOfficial(Platform):
if message_result.callback:
message_result.callback()
# 如果是等待回复的消息
if session_id in self.waiting and self.waiting[session_id] == '':
self.waiting[session_id] = message
return ret
async def reply_msg(self,
@@ -241,10 +228,15 @@ class QQOfficial(Platform):
plain_text = ''
image_path = ''
msg_ref = None
rendered_images = []
rendered_images = None
if use_t2i or (use_t2i == None and self.context.config_helper.t2i) and isinstance(result_message, list):
rendered_images = await self.convert_to_t2i_chain(result_message)
try:
rendered_images = await self.convert_to_t2i_chain(result_message)
except BaseException as e:
logger.warning(traceback.format_exc())
logger.warning(f"文本转图片时发生错误: {e},将尝试默认方式。")
rendered_images = None
if isinstance(result_message, list):
plain_text, image_path = await self._parse_to_qqofficial(result_message, message.type == MessageType.GROUP_MESSAGE)
@@ -386,20 +378,3 @@ class QQOfficial(Platform):
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
raise NotImplementedError("qqofficial 不支持此方法。")
def wait_for_message(self, channel_id: int) -> AstrBotMessage:
'''
等待指定 channel_id 的下一条信息,超时 300s 后抛出异常
'''
self.waiting[channel_id] = ''
cnt = 0
while True:
if channel_id in self.waiting and self.waiting[channel_id] != '':
# 去掉
ret = self.waiting[channel_id]
del self.waiting[channel_id]
return ret
cnt += 1
if cnt > 300:
raise Exception("等待消息超时。")
time.sleep(1)
+8 -1
View File
@@ -47,10 +47,17 @@ class AstrMessageEvent():
context: Context,
platform_name: str,
session_id: str,
role: str = "member",
unified_msg_origin: str = None,
only_command: bool = False):
# 解析 role
sender_id = str(message.sender.user_id)
if sender_id in context.config_helper.admins_id:
role = 'admin'
else:
role = 'member'
ame = AstrMessageEvent(message.message_str,
message,
context.find_platform(platform_name),