diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py index 200d25961..ef548dd39 100644 --- a/astrbot/bootstrap.py +++ b/astrbot/bootstrap.py @@ -74,21 +74,23 @@ class AstrBotBootstrap(): # load platforms platform_tasks = self.load_platform() # load metrics uploader - metrics_upload_task = asyncio.create_task(self.metrics_uploader.upload_metrics()) + metrics_upload_task = asyncio.create_task(self.metrics_uploader.upload_metrics(), name="metrics-uploader") # load dashboard self.dashboard.run_http_server() - dashboard_task = asyncio.create_task(self.dashboard.ws_server()) - tasks = [metrics_upload_task, dashboard_task, *platform_tasks] + dashboard_task = asyncio.create_task(self.dashboard.ws_server(), name="dashboard") + tasks = [metrics_upload_task, dashboard_task, *platform_tasks, *self.context.ext_tasks] tasks = [self.handle_task(task) for task in tasks] await asyncio.gather(*tasks) async def handle_task(self, task: Union[asyncio.Task, asyncio.Future]): - try: - result = await task - return result - except Exception as e: - logger.error(traceback.format_exc()) - return None + while True: + try: + result = await task + return result + except Exception as e: + logger.error(traceback.format_exc()) + logger.error(f"{task.get_name()} 任务发生错误,将在 5 秒后重试。") + await asyncio.sleep(5) def load_llm(self): if 'openai' in self.configs and \ diff --git a/model/platform/manager.py b/model/platform/manager.py index 1f287f466..5ca217346 100644 --- a/model/platform/manager.py +++ b/model/platform/manager.py @@ -21,16 +21,16 @@ class PlatformManager(): if 'gocqbot' in self.config and self.config['gocqbot']['enable']: logger.info("启用 QQ(nakuru 适配器)") - tasks.append(asyncio.create_task(self.gocq_bot())) + tasks.append(asyncio.create_task(self.gocq_bot(), name="nakuru-adapter")) if 'aiocqhttp' in self.config and self.config['aiocqhttp']['enable']: logger.info("启用 QQ(aiocqhttp 适配器)") - tasks.append(asyncio.create_task(self.aiocq_bot())) + tasks.append(asyncio.create_task(self.aiocq_bot(), name="aiocqhttp-adapter")) # QQ频道 if 'qqbot' in self.config and self.config['qqbot']['enable'] and self.config['qqbot']['appid'] != None: logger.info("启用 QQ(官方 API) 机器人消息平台") - tasks.append(asyncio.create_task(self.qqchan_bot())) + tasks.append(asyncio.create_task(self.qqchan_bot(), name="qqofficial-adapter")) return tasks diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py index 7bb1c4937..e0d81053b 100644 --- a/model/platform/qq_aiocqhttp.py +++ b/model/platform/qq_aiocqhttp.py @@ -1,4 +1,5 @@ import time +import asyncio import traceback import logging from aiocqhttp import CQHttp, Event @@ -82,7 +83,7 @@ class AIOCQHTTP(Platform): await self.handle_msg(abm) # return {'reply': event.message} - bot = self.bot.run_task(host=self.host, port=int(self.port)) + bot = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder) for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) @@ -90,6 +91,10 @@ class AIOCQHTTP(Platform): return bot + async def shutdown_trigger_placeholder(self): + while True: + await asyncio.sleep(1) + def pre_check(self, message: AstrBotMessage) -> bool: # if message chain contains Plain components or At components which points to self_id, return True if message.type == MessageType.FRIEND_MESSAGE: diff --git a/type/types.py b/type/types.py index 78b058970..4acf0fc3a 100644 --- a/type/types.py +++ b/type/types.py @@ -1,5 +1,7 @@ +import asyncio +from asyncio import Task from type.register import * -from typing import List +from typing import List, Awaitable from logging import Logger from util.cmd_config import CmdConfig from util.t2i.renderer import TextToImageRenderer @@ -38,6 +40,7 @@ class Context: self.image_renderer = TextToImageRenderer() self.image_uploader = ImageUploader() self.message_handler = None # see astrbot/message/handler.py + self.ext_tasks: List[Task] = [] def register_commands(self, plugin_name: str, @@ -56,6 +59,15 @@ class Context: ''' self.plugin_command_bridge.register_command(plugin_name, command_name, description, priority, handler) + def register_task(self, coro: Awaitable, task_name: str): + ''' + 注册任务。适用于需要长时间运行的插件。 + + `coro`: 协程对象 + `task_name`: 任务名,用于标识任务。自定义即可。 + ''' + task = asyncio.create_task(coro, name=task_name) + self.ext_tasks.append(task) def find_platform(self, platform_name: str) -> RegisteredPlatform: for platform in self.platforms: diff --git a/util/plugin_dev/api/v1/llm.py b/util/plugin_dev/api/v1/llm.py index e585689ca..11d2be559 100644 --- a/util/plugin_dev/api/v1/llm.py +++ b/util/plugin_dev/api/v1/llm.py @@ -1,6 +1 @@ -''' -大语言模型. - -插件开发者可以继承这个类来做实现。 -''' from model.provider.provider import Provider as LLMProvider \ No newline at end of file diff --git a/util/plugin_dev/api/v1/message.py b/util/plugin_dev/api/v1/message.py index 182e9114d..218498c4b 100644 --- a/util/plugin_dev/api/v1/message.py +++ b/util/plugin_dev/api/v1/message.py @@ -1,3 +1,4 @@ from type.message_event import * from type.astrbot_message import * -from type.command import CommandResult \ No newline at end of file +from type.command import CommandResult +from astrbot.message.handler import MessageHandler \ No newline at end of file diff --git a/util/plugin_dev/api/v1/register.py b/util/plugin_dev/api/v1/register.py index 6ebef89c9..da5881922 100644 --- a/util/plugin_dev/api/v1/register.py +++ b/util/plugin_dev/api/v1/register.py @@ -8,13 +8,17 @@ from model.platform import Platform from type.types import Context from type.register import RegisteredPlatform, RegisteredLLM -def register_platform(platform_name: str, platform_instance: Platform, context: Context) -> None: +def register_platform(platform_name: str, context: Context, platform_instance: Platform = None) -> None: ''' 注册一个消息平台。 Args: platform_name: 平台名称。 - platform_instance: 平台实例。 + platform_instance: 平台实例,可为空。 + context: 上下文对象。 + + Note: + 当插件类被加载时,AstrBot 会传给插件 context 对象。插件可以通过 context 对象注册指令、长任务等。 ''' # check 是否已经注册