fix: 修复了一些bug。

This commit is contained in:
Soulter
2024-07-24 09:19:43 -04:00
parent a2cf058951
commit 4edd11f2f7
13 changed files with 101 additions and 54 deletions
+15 -6
View File
@@ -1,5 +1,5 @@
import asyncio
import threading
import traceback
from astrbot.message.handler import MessageHandler
from astrbot.persist.helper import dbConn
from dashboard.server import AstrBotDashBoard
@@ -72,13 +72,21 @@ class AstrBotBootstrap():
# load platforms
platform_tasks = self.load_platform()
# load metrics uploader
metrics_upload_task = upload_metrics(self.context)
metrics_upload_task = asyncio.create_task(upload_metrics(self.context))
# load dashboard
self.dashboard.run_http_server()
dashboard_task = self.dashboard.ws_server()
await asyncio.gather(metrics_upload_task, dashboard_task, *platform_tasks)
dashboard_task = asyncio.create_task(self.dashboard.ws_server())
tasks = [metrics_upload_task, dashboard_task, *platform_tasks]
tasks = [self.handle_task(task) for task in tasks]
await asyncio.gather(*tasks)
async def handle_task(self, task: Union[asyncio.Task, asyncio.Future]):
try:
result = await task
return result
except Exception as e:
logger.error(traceback.format_exc())
return None
def load_llm(self):
if 'openai' in self.configs and \
@@ -88,6 +96,7 @@ class AstrBotBootstrap():
from model.command.openai_official_handler import OpenAIOfficialCommandHandler
self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager)
self.llm_instance = ProviderOpenAIOfficial(self.context)
self.openai_command_handler.set_provider(self.llm_instance)
logger.info("已启用 OpenAI API 支持。")
def load_plugins(self):
+3 -3
View File
@@ -419,19 +419,19 @@ class AstrBotDashBoard():
"tag": ""
},
{
"title": "QQ(官方机器人 API)",
"title": "QQ(官方)",
"desc": "QQ官方API。支持频道、群、私聊(需获得群权限)",
"namespace": "internal_platform_qq_official",
"tag": ""
},
{
"title": "QQ(nakuru 适配器)",
"title": "QQ(nakuru)",
"desc": "适用于 go-cqhttp",
"namespace": "internal_platform_qq_gocq",
"tag": ""
},
{
"title": "QQ(aiocqhttp 适配器)",
"title": "QQ(aiocqhttp)",
"desc": "适用于 Lagrange, LLBot, Shamrock 等支持反向WS的协议实现。",
"namespace": "internal_platform_qq_aiocqhttp",
"tag": ""
+1 -1
View File
@@ -23,7 +23,7 @@ class InternalCommandHandler:
self.manager.register("update", "更新 AstrBot", 10, self.update)
self.manager.register("plugin", "插件管理", 10, self.plugin)
self.manager.register("reboot", "重启 AstrBot", 10, self.reboot)
self.manager.register("web_search", "网页搜索开关", 10, self.web_search)
self.manager.register("websearch", "网页搜索开关", 10, self.web_search)
self.manager.register("t2i", "文本转图片开关", 10, self.t2i_toggle)
self.manager.register("myid", "获取你在此平台上的ID", 10, self.myid)
+6 -2
View File
@@ -93,7 +93,11 @@ class CommandManager():
return command_result
except BaseException as e:
logger.error(traceback.format_exc())
if not command_metadata.inner_command:
logger.error(f"执行 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令时发生了异常。")
text = f"执行 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令时发生了异常。{e}"
logger.error(text)
else:
logger.error(f"执行 {command} 指令时发生了异常。")
text = f"执行 {command} 指令时发生了异常。{e}"
logger.error(text)
return CommandResult().message(text)
+3 -5
View File
@@ -145,8 +145,6 @@ class OpenAIOfficialCommandHandler():
async def draw(self, message: AstrMessageEvent, context: Context):
message = message.message_str.removeprefix("")
img_url = await self.provider.image_generate(message)
p = await download_image_by_url(url=img_url)
with open(p, 'rb') as f:
return CommandResult(
message_chain=[Image.fromBytes(f.read())],
)
return CommandResult(
message_chain=[Image.fromURL(img_url)],
)
+5 -2
View File
@@ -64,7 +64,10 @@ class Platform():
if isinstance(i, Plain):
plain_str += i.text
if plain_str and len(plain_str) > 50:
p = await self.context.image_renderer.render(plain_str)
rendered_images.append(Image.fromFileSystem(p))
p = await self.context.image_renderer.render(plain_str, return_url=True)
if p.startswith('http'):
rendered_images.append(Image.fromURL(p))
else:
rendered_images.append(Image.fromFileSystem(p))
return rendered_images
+7 -7
View File
@@ -1,4 +1,4 @@
import time, threading
import asyncio
from util.io import port_checker
from type.register import RegisteredPlatform
@@ -21,20 +21,20 @@ class PlatformManager():
if 'gocqbot' in self.config and self.config['gocqbot']['enable']:
logger.info("启用 QQ(nakuru 适配器)")
tasks.append(self.gocq_bot())
tasks.append(asyncio.create_task(self.gocq_bot()))
if 'aiocqhttp' in self.config and self.config['aiocqhttp']['enable']:
logger.info("启用 QQ(aiocqhttp 适配器)")
tasks.append(self.aiocq_bot())
tasks.append(asyncio.create_task(self.aiocq_bot()))
# QQ频道
if 'qqbot' in self.config and self.config['qqbot']['enable'] and self.config['qqbot']['appid'] != None:
logger.info("启用 QQ(官方 API) 机器人消息平台")
tasks.append(self.qqchan_bot())
tasks.append(asyncio.create_task(self.qqchan_bot()))
return tasks
def gocq_bot(self):
async def gocq_bot(self):
'''
运行 QQ(nakuru 适配器)
'''
@@ -51,7 +51,7 @@ class PlatformManager():
noticed = True
logger.warning(
f"连接到{host}:{port}(或{http_port})失败。程序会每隔 5s 自动重试。")
time.sleep(5)
await asyncio.sleep(5)
else:
logger.info("nakuru 适配器已连接。")
break
@@ -59,7 +59,7 @@ class PlatformManager():
qq_gocq = QQGOCQ(self.context, self.msg_handler)
self.context.platforms.append(RegisteredPlatform(
platform_name="gocq", platform_instance=qq_gocq, origin="internal"))
return qq_gocq.run()
await qq_gocq.run()
except BaseException as e:
logger.error("启动 nakuru 适配器时出现错误: " + str(e))
+47 -17
View File
@@ -22,11 +22,8 @@ class AIOCQHTTP(Platform):
self.announcement = self.context.base_config.get("announcement", "欢迎新人!")
self.host = self.context.base_config['aiocqhttp']['ws_reverse_host']
self.port = self.context.base_config['aiocqhttp']['ws_reverse_port']
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
def compat_onebot2astrbotmsg(self, event: Event) -> AstrBotMessage:
def convert_message(self, event: Event) -> AstrBotMessage:
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
@@ -69,28 +66,48 @@ class AIOCQHTTP(Platform):
def run_aiocqhttp(self):
if not self.host or not self.port:
return
self.bot = CQHttp(use_ws_reverse=True)
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp')
@self.bot.on_message('group')
async def group(event: Event):
abm = self.compat_onebot2astrbotmsg(event)
abm = self.convert_message(event)
if abm:
await self.handle_msg(event, abm)
return {'reply': event.message}
await self.handle_msg(abm)
# return {'reply': event.message}
@self.bot.on_message('private')
async def private(event: Event):
abm = self.compat_onebot2astrbotmsg(event)
abm = self.convert_message(event)
if abm:
await self.handle_msg(event, abm)
return {'reply': event.message}
await self.handle_msg(abm)
# return {'reply': event.message}
return self.bot.run_task(host=self.host, port=int(self.port))
bot = self.bot.run_task(host=self.host, port=int(self.port))
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
logging.getLogger('aiocqhttp').setLevel(logging.ERROR)
return bot
def pre_check(self, message: AstrBotMessage) -> bool:
# if message chain contains Plain components or At components which points to self_id, return True
if message.type == MessageType.FRIEND_MESSAGE:
return True
for comp in message.message:
if isinstance(comp, At) and str(comp.qq) == message.self_id:
return True
# check nicks
if self.check_nick(message.message_str):
return True
return False
async def handle_msg(self, message: AstrBotMessage):
logger.info(
f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}")
if not self.pre_check(message):
return
# 解析 role
sender_id = str(message.sender.user_id)
if sender_id == self.context.config_helper.get('admin_qq', '') or \
@@ -100,7 +117,7 @@ class AIOCQHTTP(Platform):
role = 'member'
# construct astrbot message event
ame = AstrMessageEvent().from_astrbot_message(message, self.context, "gocq", message.session_id, role)
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "aiocqhttp", message.session_id, role)
# transfer control to message handler
message_result = await self.message_handler.handle(ame)
@@ -118,13 +135,14 @@ class AIOCQHTTP(Platform):
async def reply_msg(self,
message: AstrBotMessage,
result_message: list):
await super().reply_msg()
"""
回复用户唤醒机器人的消息。(被动回复)
"""
logger.info(
f"{message.sender.user_id} <- {self.parse_message_outline(message)}")
res = result_message
if isinstance(res, str):
res = [Plain(text=res), ]
@@ -138,9 +156,21 @@ class AIOCQHTTP(Platform):
except BaseException as e:
logger.warn(traceback.format_exc())
logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。")
await self._reply(message, res)
async def _reply(self, message: AstrBotMessage, message_chain: List[BaseMessageComponent]):
if isinstance(message_chain, str):
message_chain = [Plain(text=message_chain), ]
await self.bot.send(message.raw_message, message_chain)
ret = []
for segment in message_chain:
d = segment.toDict()
if isinstance(segment, Plain):
d['type'] = 'text'
if isinstance(segment, Image):
# d['data']['file'] =
pass
ret.append(d)
await self.bot.send(message.raw_message, ret)
+3 -4
View File
@@ -75,8 +75,6 @@ class QQGOCQ(Platform):
if message.type == MessageType.FRIEND_MESSAGE:
return True
for comp in message.message:
if isinstance(comp, Plain):
return True
if isinstance(comp, At) and str(comp.qq) == message.self_id:
return True
# check nicks
@@ -85,7 +83,8 @@ class QQGOCQ(Platform):
return False
def run(self):
return self.client._run()
coro = self.client._run()
return coro
async def handle_msg(self, message: AstrBotMessage):
logger.info(
@@ -119,7 +118,7 @@ class QQGOCQ(Platform):
role = 'member'
# construct astrbot message event
ame = AstrMessageEvent().from_astrbot_message(message, self.context, "gocq", session_id, role)
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "gocq", session_id, role)
# transfer control to message handler
message_result = await self.message_handler.handle(ame)
+7 -3
View File
@@ -309,9 +309,13 @@ class QQOfficial(Platform):
if 'file_image' in kwargs:
file_image_path = kwargs['file_image'].replace("file:///", "")
if file_image_path:
logger.debug(f"上传图片: {file_image_path}")
image_url = await self.context.image_uploader.upload_image(file_image_path)
logger.debug(f"上传成功: {image_url}")
if file_image_path.startswith("http"):
image_url = file_image_path
else:
logger.debug(f"上传图片: {file_image_path}")
image_url = await self.context.image_uploader.upload_image(file_image_path)
logger.debug(f"上传成功: {image_url}")
media = await self.client.api.post_group_file(kwargs['group_openid'], 1, image_url)
del kwargs['file_image']
kwargs['media'] = media
+2 -2
View File
@@ -7,5 +7,5 @@ class RenderContext:
def set_strategy(self, strategy: RenderStrategy):
self._strategy = strategy
async def render(self, text: str) -> str:
return await self._strategy.render(text)
async def render(self, text: str, return_url: bool = False):
return await self._strategy.render(text, return_url)
+1 -1
View File
@@ -15,7 +15,7 @@ class TextToImageRenderer:
async def render(self, text: str, use_network: bool = True, return_url: bool = False):
if use_network:
try:
return await self.context.render(text)
return await self.context.render(text, return_url=return_url)
except BaseException as e:
logger.error(f"Failed to render image via AstrBot API: {e}. Falling back to local rendering.")
self.context.set_strategy(self.local_strategy)
+1 -1
View File
@@ -17,7 +17,7 @@ class LocalRenderStrategy(RenderStrategy):
except Exception as e:
pass
async def render(self, text: str):
async def render(self, text: str, **kwargs):
font_size = 26
image_width = 800
image_height = 600