fix: 修复 aiocqhttp 运行导致 ctrl+c 无法退出 bot 的问题

perf: 支持通过context注册task
This commit is contained in:
Soulter
2024-07-26 05:02:29 -04:00
parent bb2164c324
commit af878f2ed3
7 changed files with 41 additions and 22 deletions
+11 -9
View File
@@ -74,21 +74,23 @@ class AstrBotBootstrap():
# load platforms
platform_tasks = self.load_platform()
# load metrics uploader
metrics_upload_task = asyncio.create_task(self.metrics_uploader.upload_metrics())
metrics_upload_task = asyncio.create_task(self.metrics_uploader.upload_metrics(), name="metrics-uploader")
# load dashboard
self.dashboard.run_http_server()
dashboard_task = asyncio.create_task(self.dashboard.ws_server())
tasks = [metrics_upload_task, dashboard_task, *platform_tasks]
dashboard_task = asyncio.create_task(self.dashboard.ws_server(), name="dashboard")
tasks = [metrics_upload_task, dashboard_task, *platform_tasks, *self.context.ext_tasks]
tasks = [self.handle_task(task) for task in tasks]
await asyncio.gather(*tasks)
async def handle_task(self, task: Union[asyncio.Task, asyncio.Future]):
try:
result = await task
return result
except Exception as e:
logger.error(traceback.format_exc())
return None
while True:
try:
result = await task
return result
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"{task.get_name()} 任务发生错误,将在 5 秒后重试。")
await asyncio.sleep(5)
def load_llm(self):
if 'openai' in self.configs and \
+3 -3
View File
@@ -21,16 +21,16 @@ class PlatformManager():
if 'gocqbot' in self.config and self.config['gocqbot']['enable']:
logger.info("启用 QQ(nakuru 适配器)")
tasks.append(asyncio.create_task(self.gocq_bot()))
tasks.append(asyncio.create_task(self.gocq_bot(), name="nakuru-adapter"))
if 'aiocqhttp' in self.config and self.config['aiocqhttp']['enable']:
logger.info("启用 QQ(aiocqhttp 适配器)")
tasks.append(asyncio.create_task(self.aiocq_bot()))
tasks.append(asyncio.create_task(self.aiocq_bot(), name="aiocqhttp-adapter"))
# QQ频道
if 'qqbot' in self.config and self.config['qqbot']['enable'] and self.config['qqbot']['appid'] != None:
logger.info("启用 QQ(官方 API) 机器人消息平台")
tasks.append(asyncio.create_task(self.qqchan_bot()))
tasks.append(asyncio.create_task(self.qqchan_bot(), name="qqofficial-adapter"))
return tasks
+6 -1
View File
@@ -1,4 +1,5 @@
import time
import asyncio
import traceback
import logging
from aiocqhttp import CQHttp, Event
@@ -82,7 +83,7 @@ class AIOCQHTTP(Platform):
await self.handle_msg(abm)
# return {'reply': event.message}
bot = self.bot.run_task(host=self.host, port=int(self.port))
bot = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder)
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
@@ -90,6 +91,10 @@ class AIOCQHTTP(Platform):
return bot
async def shutdown_trigger_placeholder(self):
while True:
await asyncio.sleep(1)
def pre_check(self, message: AstrBotMessage) -> bool:
# if message chain contains Plain components or At components which points to self_id, return True
if message.type == MessageType.FRIEND_MESSAGE:
+13 -1
View File
@@ -1,5 +1,7 @@
import asyncio
from asyncio import Task
from type.register import *
from typing import List
from typing import List, Awaitable
from logging import Logger
from util.cmd_config import CmdConfig
from util.t2i.renderer import TextToImageRenderer
@@ -38,6 +40,7 @@ class Context:
self.image_renderer = TextToImageRenderer()
self.image_uploader = ImageUploader()
self.message_handler = None # see astrbot/message/handler.py
self.ext_tasks: List[Task] = []
def register_commands(self,
plugin_name: str,
@@ -56,6 +59,15 @@ class Context:
'''
self.plugin_command_bridge.register_command(plugin_name, command_name, description, priority, handler)
def register_task(self, coro: Awaitable, task_name: str):
'''
注册任务。适用于需要长时间运行的插件。
`coro`: 协程对象
`task_name`: 任务名,用于标识任务。自定义即可。
'''
task = asyncio.create_task(coro, name=task_name)
self.ext_tasks.append(task)
def find_platform(self, platform_name: str) -> RegisteredPlatform:
for platform in self.platforms:
-5
View File
@@ -1,6 +1 @@
'''
大语言模型.
插件开发者可以继承这个类来做实现。
'''
from model.provider.provider import Provider as LLMProvider
+2 -1
View File
@@ -1,3 +1,4 @@
from type.message_event import *
from type.astrbot_message import *
from type.command import CommandResult
from type.command import CommandResult
from astrbot.message.handler import MessageHandler
+6 -2
View File
@@ -8,13 +8,17 @@ from model.platform import Platform
from type.types import Context
from type.register import RegisteredPlatform, RegisteredLLM
def register_platform(platform_name: str, platform_instance: Platform, context: Context) -> None:
def register_platform(platform_name: str, context: Context, platform_instance: Platform = None) -> None:
'''
注册一个消息平台。
Args:
platform_name: 平台名称。
platform_instance: 平台实例。
platform_instance: 平台实例,可为空
context: 上下文对象。
Note:
当插件类被加载时,AstrBot 会传给插件 context 对象。插件可以通过 context 对象注册指令、长任务等。
'''
# check 是否已经注册