fix: 修复了一些bug。
This commit is contained in:
+15
-6
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import traceback
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from astrbot.persist.helper import dbConn
|
||||
from dashboard.server import AstrBotDashBoard
|
||||
@@ -72,13 +72,21 @@ class AstrBotBootstrap():
|
||||
# load platforms
|
||||
platform_tasks = self.load_platform()
|
||||
# load metrics uploader
|
||||
metrics_upload_task = upload_metrics(self.context)
|
||||
metrics_upload_task = asyncio.create_task(upload_metrics(self.context))
|
||||
# load dashboard
|
||||
self.dashboard.run_http_server()
|
||||
dashboard_task = self.dashboard.ws_server()
|
||||
|
||||
await asyncio.gather(metrics_upload_task, dashboard_task, *platform_tasks)
|
||||
|
||||
dashboard_task = asyncio.create_task(self.dashboard.ws_server())
|
||||
tasks = [metrics_upload_task, dashboard_task, *platform_tasks]
|
||||
tasks = [self.handle_task(task) for task in tasks]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def handle_task(self, task: Union[asyncio.Task, asyncio.Future]):
|
||||
try:
|
||||
result = await task
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
def load_llm(self):
|
||||
if 'openai' in self.configs and \
|
||||
@@ -88,6 +96,7 @@ class AstrBotBootstrap():
|
||||
from model.command.openai_official_handler import OpenAIOfficialCommandHandler
|
||||
self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager)
|
||||
self.llm_instance = ProviderOpenAIOfficial(self.context)
|
||||
self.openai_command_handler.set_provider(self.llm_instance)
|
||||
logger.info("已启用 OpenAI API 支持。")
|
||||
|
||||
def load_plugins(self):
|
||||
|
||||
+3
-3
@@ -419,19 +419,19 @@ class AstrBotDashBoard():
|
||||
"tag": ""
|
||||
},
|
||||
{
|
||||
"title": "QQ(官方机器人 API)",
|
||||
"title": "QQ(官方)",
|
||||
"desc": "QQ官方API。支持频道、群、私聊(需获得群权限)",
|
||||
"namespace": "internal_platform_qq_official",
|
||||
"tag": ""
|
||||
},
|
||||
{
|
||||
"title": "QQ(nakuru 适配器)",
|
||||
"title": "QQ(nakuru)",
|
||||
"desc": "适用于 go-cqhttp",
|
||||
"namespace": "internal_platform_qq_gocq",
|
||||
"tag": ""
|
||||
},
|
||||
{
|
||||
"title": "QQ(aiocqhttp 适配器)",
|
||||
"title": "QQ(aiocqhttp)",
|
||||
"desc": "适用于 Lagrange, LLBot, Shamrock 等支持反向WS的协议实现。",
|
||||
"namespace": "internal_platform_qq_aiocqhttp",
|
||||
"tag": ""
|
||||
|
||||
@@ -23,7 +23,7 @@ class InternalCommandHandler:
|
||||
self.manager.register("update", "更新 AstrBot", 10, self.update)
|
||||
self.manager.register("plugin", "插件管理", 10, self.plugin)
|
||||
self.manager.register("reboot", "重启 AstrBot", 10, self.reboot)
|
||||
self.manager.register("web_search", "网页搜索开关", 10, self.web_search)
|
||||
self.manager.register("websearch", "网页搜索开关", 10, self.web_search)
|
||||
self.manager.register("t2i", "文本转图片开关", 10, self.t2i_toggle)
|
||||
self.manager.register("myid", "获取你在此平台上的ID", 10, self.myid)
|
||||
|
||||
|
||||
@@ -93,7 +93,11 @@ class CommandManager():
|
||||
return command_result
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if not command_metadata.inner_command:
|
||||
logger.error(f"当执行 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令时,发生了异常。")
|
||||
text = f"执行 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令时发生了异常。{e}"
|
||||
logger.error(text)
|
||||
else:
|
||||
logger.error(f"当执行 {command} 指令时,发生了异常。")
|
||||
text = f"执行 {command} 指令时发生了异常。{e}"
|
||||
logger.error(text)
|
||||
return CommandResult().message(text)
|
||||
@@ -145,8 +145,6 @@ class OpenAIOfficialCommandHandler():
|
||||
async def draw(self, message: AstrMessageEvent, context: Context):
|
||||
message = message.message_str.removeprefix("画")
|
||||
img_url = await self.provider.image_generate(message)
|
||||
p = await download_image_by_url(url=img_url)
|
||||
with open(p, 'rb') as f:
|
||||
return CommandResult(
|
||||
message_chain=[Image.fromBytes(f.read())],
|
||||
)
|
||||
return CommandResult(
|
||||
message_chain=[Image.fromURL(img_url)],
|
||||
)
|
||||
@@ -64,7 +64,10 @@ class Platform():
|
||||
if isinstance(i, Plain):
|
||||
plain_str += i.text
|
||||
if plain_str and len(plain_str) > 50:
|
||||
p = await self.context.image_renderer.render(plain_str)
|
||||
rendered_images.append(Image.fromFileSystem(p))
|
||||
p = await self.context.image_renderer.render(plain_str, return_url=True)
|
||||
if p.startswith('http'):
|
||||
rendered_images.append(Image.fromURL(p))
|
||||
else:
|
||||
rendered_images.append(Image.fromFileSystem(p))
|
||||
return rendered_images
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import time, threading
|
||||
import asyncio
|
||||
|
||||
from util.io import port_checker
|
||||
from type.register import RegisteredPlatform
|
||||
@@ -21,20 +21,20 @@ class PlatformManager():
|
||||
|
||||
if 'gocqbot' in self.config and self.config['gocqbot']['enable']:
|
||||
logger.info("启用 QQ(nakuru 适配器)")
|
||||
tasks.append(self.gocq_bot())
|
||||
tasks.append(asyncio.create_task(self.gocq_bot()))
|
||||
|
||||
if 'aiocqhttp' in self.config and self.config['aiocqhttp']['enable']:
|
||||
logger.info("启用 QQ(aiocqhttp 适配器)")
|
||||
tasks.append(self.aiocq_bot())
|
||||
tasks.append(asyncio.create_task(self.aiocq_bot()))
|
||||
|
||||
# QQ频道
|
||||
if 'qqbot' in self.config and self.config['qqbot']['enable'] and self.config['qqbot']['appid'] != None:
|
||||
logger.info("启用 QQ(官方 API) 机器人消息平台")
|
||||
tasks.append(self.qqchan_bot())
|
||||
tasks.append(asyncio.create_task(self.qqchan_bot()))
|
||||
|
||||
return tasks
|
||||
|
||||
def gocq_bot(self):
|
||||
async def gocq_bot(self):
|
||||
'''
|
||||
运行 QQ(nakuru 适配器)
|
||||
'''
|
||||
@@ -51,7 +51,7 @@ class PlatformManager():
|
||||
noticed = True
|
||||
logger.warning(
|
||||
f"连接到{host}:{port}(或{http_port})失败。程序会每隔 5s 自动重试。")
|
||||
time.sleep(5)
|
||||
await asyncio.sleep(5)
|
||||
else:
|
||||
logger.info("nakuru 适配器已连接。")
|
||||
break
|
||||
@@ -59,7 +59,7 @@ class PlatformManager():
|
||||
qq_gocq = QQGOCQ(self.context, self.msg_handler)
|
||||
self.context.platforms.append(RegisteredPlatform(
|
||||
platform_name="gocq", platform_instance=qq_gocq, origin="internal"))
|
||||
return qq_gocq.run()
|
||||
await qq_gocq.run()
|
||||
except BaseException as e:
|
||||
logger.error("启动 nakuru 适配器时出现错误: " + str(e))
|
||||
|
||||
|
||||
@@ -22,11 +22,8 @@ class AIOCQHTTP(Platform):
|
||||
self.announcement = self.context.base_config.get("announcement", "欢迎新人!")
|
||||
self.host = self.context.base_config['aiocqhttp']['ws_reverse_host']
|
||||
self.port = self.context.base_config['aiocqhttp']['ws_reverse_port']
|
||||
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
def compat_onebot2astrbotmsg(self, event: Event) -> AstrBotMessage:
|
||||
def convert_message(self, event: Event) -> AstrBotMessage:
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
@@ -69,28 +66,48 @@ class AIOCQHTTP(Platform):
|
||||
def run_aiocqhttp(self):
|
||||
if not self.host or not self.port:
|
||||
return
|
||||
self.bot = CQHttp(use_ws_reverse=True)
|
||||
|
||||
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp')
|
||||
@self.bot.on_message('group')
|
||||
async def group(event: Event):
|
||||
abm = self.compat_onebot2astrbotmsg(event)
|
||||
abm = self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(event, abm)
|
||||
return {'reply': event.message}
|
||||
await self.handle_msg(abm)
|
||||
# return {'reply': event.message}
|
||||
|
||||
@self.bot.on_message('private')
|
||||
async def private(event: Event):
|
||||
abm = self.compat_onebot2astrbotmsg(event)
|
||||
abm = self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(event, abm)
|
||||
return {'reply': event.message}
|
||||
await self.handle_msg(abm)
|
||||
# return {'reply': event.message}
|
||||
|
||||
return self.bot.run_task(host=self.host, port=int(self.port))
|
||||
bot = self.bot.run_task(host=self.host, port=int(self.port))
|
||||
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
logging.getLogger('aiocqhttp').setLevel(logging.ERROR)
|
||||
|
||||
return bot
|
||||
|
||||
def pre_check(self, message: AstrBotMessage) -> bool:
|
||||
# if message chain contains Plain components or At components which points to self_id, return True
|
||||
if message.type == MessageType.FRIEND_MESSAGE:
|
||||
return True
|
||||
for comp in message.message:
|
||||
if isinstance(comp, At) and str(comp.qq) == message.self_id:
|
||||
return True
|
||||
# check nicks
|
||||
if self.check_nick(message.message_str):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
logger.info(
|
||||
f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}")
|
||||
|
||||
if not self.pre_check(message):
|
||||
return
|
||||
|
||||
# 解析 role
|
||||
sender_id = str(message.sender.user_id)
|
||||
if sender_id == self.context.config_helper.get('admin_qq', '') or \
|
||||
@@ -100,7 +117,7 @@ class AIOCQHTTP(Platform):
|
||||
role = 'member'
|
||||
|
||||
# construct astrbot message event
|
||||
ame = AstrMessageEvent().from_astrbot_message(message, self.context, "gocq", message.session_id, role)
|
||||
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "aiocqhttp", message.session_id, role)
|
||||
|
||||
# transfer control to message handler
|
||||
message_result = await self.message_handler.handle(ame)
|
||||
@@ -118,13 +135,14 @@ class AIOCQHTTP(Platform):
|
||||
async def reply_msg(self,
|
||||
message: AstrBotMessage,
|
||||
result_message: list):
|
||||
await super().reply_msg()
|
||||
"""
|
||||
回复用户唤醒机器人的消息。(被动回复)
|
||||
"""
|
||||
logger.info(
|
||||
f"{message.sender.user_id} <- {self.parse_message_outline(message)}")
|
||||
|
||||
res = result_message
|
||||
|
||||
if isinstance(res, str):
|
||||
res = [Plain(text=res), ]
|
||||
|
||||
@@ -138,9 +156,21 @@ class AIOCQHTTP(Platform):
|
||||
except BaseException as e:
|
||||
logger.warn(traceback.format_exc())
|
||||
logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。")
|
||||
|
||||
await self._reply(message, res)
|
||||
|
||||
async def _reply(self, message: AstrBotMessage, message_chain: List[BaseMessageComponent]):
|
||||
if isinstance(message_chain, str):
|
||||
message_chain = [Plain(text=message_chain), ]
|
||||
|
||||
await self.bot.send(message.raw_message, message_chain)
|
||||
|
||||
ret = []
|
||||
for segment in message_chain:
|
||||
d = segment.toDict()
|
||||
if isinstance(segment, Plain):
|
||||
d['type'] = 'text'
|
||||
if isinstance(segment, Image):
|
||||
# d['data']['file'] =
|
||||
pass
|
||||
ret.append(d)
|
||||
|
||||
await self.bot.send(message.raw_message, ret)
|
||||
@@ -75,8 +75,6 @@ class QQGOCQ(Platform):
|
||||
if message.type == MessageType.FRIEND_MESSAGE:
|
||||
return True
|
||||
for comp in message.message:
|
||||
if isinstance(comp, Plain):
|
||||
return True
|
||||
if isinstance(comp, At) and str(comp.qq) == message.self_id:
|
||||
return True
|
||||
# check nicks
|
||||
@@ -85,7 +83,8 @@ class QQGOCQ(Platform):
|
||||
return False
|
||||
|
||||
def run(self):
|
||||
return self.client._run()
|
||||
coro = self.client._run()
|
||||
return coro
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
logger.info(
|
||||
@@ -119,7 +118,7 @@ class QQGOCQ(Platform):
|
||||
role = 'member'
|
||||
|
||||
# construct astrbot message event
|
||||
ame = AstrMessageEvent().from_astrbot_message(message, self.context, "gocq", session_id, role)
|
||||
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "gocq", session_id, role)
|
||||
|
||||
# transfer control to message handler
|
||||
message_result = await self.message_handler.handle(ame)
|
||||
|
||||
@@ -309,9 +309,13 @@ class QQOfficial(Platform):
|
||||
if 'file_image' in kwargs:
|
||||
file_image_path = kwargs['file_image'].replace("file:///", "")
|
||||
if file_image_path:
|
||||
logger.debug(f"上传图片: {file_image_path}")
|
||||
image_url = await self.context.image_uploader.upload_image(file_image_path)
|
||||
logger.debug(f"上传成功: {image_url}")
|
||||
|
||||
if file_image_path.startswith("http"):
|
||||
image_url = file_image_path
|
||||
else:
|
||||
logger.debug(f"上传图片: {file_image_path}")
|
||||
image_url = await self.context.image_uploader.upload_image(file_image_path)
|
||||
logger.debug(f"上传成功: {image_url}")
|
||||
media = await self.client.api.post_group_file(kwargs['group_openid'], 1, image_url)
|
||||
del kwargs['file_image']
|
||||
kwargs['media'] = media
|
||||
|
||||
+2
-2
@@ -7,5 +7,5 @@ class RenderContext:
|
||||
def set_strategy(self, strategy: RenderStrategy):
|
||||
self._strategy = strategy
|
||||
|
||||
async def render(self, text: str) -> str:
|
||||
return await self._strategy.render(text)
|
||||
async def render(self, text: str, return_url: bool = False):
|
||||
return await self._strategy.render(text, return_url)
|
||||
|
||||
@@ -15,7 +15,7 @@ class TextToImageRenderer:
|
||||
async def render(self, text: str, use_network: bool = True, return_url: bool = False):
|
||||
if use_network:
|
||||
try:
|
||||
return await self.context.render(text)
|
||||
return await self.context.render(text, return_url=return_url)
|
||||
except BaseException as e:
|
||||
logger.error(f"Failed to render image via AstrBot API: {e}. Falling back to local rendering.")
|
||||
self.context.set_strategy(self.local_strategy)
|
||||
|
||||
@@ -17,7 +17,7 @@ class LocalRenderStrategy(RenderStrategy):
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
async def render(self, text: str):
|
||||
async def render(self, text: str, **kwargs):
|
||||
font_size = 26
|
||||
image_width = 800
|
||||
image_height = 600
|
||||
|
||||
Reference in New Issue
Block a user