fix: 修复 aiocqhttp 运行导致 ctrl+c 无法退出 bot 的问题
perf: 支持通过context注册task
This commit is contained in:
+11
-9
@@ -74,21 +74,23 @@ class AstrBotBootstrap():
|
||||
# load platforms
|
||||
platform_tasks = self.load_platform()
|
||||
# load metrics uploader
|
||||
metrics_upload_task = asyncio.create_task(self.metrics_uploader.upload_metrics())
|
||||
metrics_upload_task = asyncio.create_task(self.metrics_uploader.upload_metrics(), name="metrics-uploader")
|
||||
# load dashboard
|
||||
self.dashboard.run_http_server()
|
||||
dashboard_task = asyncio.create_task(self.dashboard.ws_server())
|
||||
tasks = [metrics_upload_task, dashboard_task, *platform_tasks]
|
||||
dashboard_task = asyncio.create_task(self.dashboard.ws_server(), name="dashboard")
|
||||
tasks = [metrics_upload_task, dashboard_task, *platform_tasks, *self.context.ext_tasks]
|
||||
tasks = [self.handle_task(task) for task in tasks]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def handle_task(self, task: Union[asyncio.Task, asyncio.Future]):
|
||||
try:
|
||||
result = await task
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
while True:
|
||||
try:
|
||||
result = await task
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"{task.get_name()} 任务发生错误,将在 5 秒后重试。")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
def load_llm(self):
|
||||
if 'openai' in self.configs and \
|
||||
|
||||
@@ -21,16 +21,16 @@ class PlatformManager():
|
||||
|
||||
if 'gocqbot' in self.config and self.config['gocqbot']['enable']:
|
||||
logger.info("启用 QQ(nakuru 适配器)")
|
||||
tasks.append(asyncio.create_task(self.gocq_bot()))
|
||||
tasks.append(asyncio.create_task(self.gocq_bot(), name="nakuru-adapter"))
|
||||
|
||||
if 'aiocqhttp' in self.config and self.config['aiocqhttp']['enable']:
|
||||
logger.info("启用 QQ(aiocqhttp 适配器)")
|
||||
tasks.append(asyncio.create_task(self.aiocq_bot()))
|
||||
tasks.append(asyncio.create_task(self.aiocq_bot(), name="aiocqhttp-adapter"))
|
||||
|
||||
# QQ频道
|
||||
if 'qqbot' in self.config and self.config['qqbot']['enable'] and self.config['qqbot']['appid'] != None:
|
||||
logger.info("启用 QQ(官方 API) 机器人消息平台")
|
||||
tasks.append(asyncio.create_task(self.qqchan_bot()))
|
||||
tasks.append(asyncio.create_task(self.qqchan_bot(), name="qqofficial-adapter"))
|
||||
|
||||
return tasks
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
import asyncio
|
||||
import traceback
|
||||
import logging
|
||||
from aiocqhttp import CQHttp, Event
|
||||
@@ -82,7 +83,7 @@ class AIOCQHTTP(Platform):
|
||||
await self.handle_msg(abm)
|
||||
# return {'reply': event.message}
|
||||
|
||||
bot = self.bot.run_task(host=self.host, port=int(self.port))
|
||||
bot = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder)
|
||||
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
@@ -90,6 +91,10 @@ class AIOCQHTTP(Platform):
|
||||
|
||||
return bot
|
||||
|
||||
async def shutdown_trigger_placeholder(self):
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def pre_check(self, message: AstrBotMessage) -> bool:
|
||||
# if message chain contains Plain components or At components which points to self_id, return True
|
||||
if message.type == MessageType.FRIEND_MESSAGE:
|
||||
|
||||
+13
-1
@@ -1,5 +1,7 @@
|
||||
import asyncio
|
||||
from asyncio import Task
|
||||
from type.register import *
|
||||
from typing import List
|
||||
from typing import List, Awaitable
|
||||
from logging import Logger
|
||||
from util.cmd_config import CmdConfig
|
||||
from util.t2i.renderer import TextToImageRenderer
|
||||
@@ -38,6 +40,7 @@ class Context:
|
||||
self.image_renderer = TextToImageRenderer()
|
||||
self.image_uploader = ImageUploader()
|
||||
self.message_handler = None # see astrbot/message/handler.py
|
||||
self.ext_tasks: List[Task] = []
|
||||
|
||||
def register_commands(self,
|
||||
plugin_name: str,
|
||||
@@ -56,6 +59,15 @@ class Context:
|
||||
'''
|
||||
self.plugin_command_bridge.register_command(plugin_name, command_name, description, priority, handler)
|
||||
|
||||
def register_task(self, coro: Awaitable, task_name: str):
|
||||
'''
|
||||
注册任务。适用于需要长时间运行的插件。
|
||||
|
||||
`coro`: 协程对象
|
||||
`task_name`: 任务名,用于标识任务。自定义即可。
|
||||
'''
|
||||
task = asyncio.create_task(coro, name=task_name)
|
||||
self.ext_tasks.append(task)
|
||||
|
||||
def find_platform(self, platform_name: str) -> RegisteredPlatform:
|
||||
for platform in self.platforms:
|
||||
|
||||
@@ -1,6 +1 @@
|
||||
'''
|
||||
大语言模型.
|
||||
|
||||
插件开发者可以继承这个类来做实现。
|
||||
'''
|
||||
from model.provider.provider import Provider as LLMProvider
|
||||
@@ -1,3 +1,4 @@
|
||||
from type.message_event import *
|
||||
from type.astrbot_message import *
|
||||
from type.command import CommandResult
|
||||
from type.command import CommandResult
|
||||
from astrbot.message.handler import MessageHandler
|
||||
@@ -8,13 +8,17 @@ from model.platform import Platform
|
||||
from type.types import Context
|
||||
from type.register import RegisteredPlatform, RegisteredLLM
|
||||
|
||||
def register_platform(platform_name: str, platform_instance: Platform, context: Context) -> None:
|
||||
def register_platform(platform_name: str, context: Context, platform_instance: Platform = None) -> None:
|
||||
'''
|
||||
注册一个消息平台。
|
||||
|
||||
Args:
|
||||
platform_name: 平台名称。
|
||||
platform_instance: 平台实例。
|
||||
platform_instance: 平台实例,可为空。
|
||||
context: 上下文对象。
|
||||
|
||||
Note:
|
||||
当插件类被加载时,AstrBot 会传给插件 context 对象。插件可以通过 context 对象注册指令、长任务等。
|
||||
'''
|
||||
|
||||
# check 是否已经注册
|
||||
|
||||
Reference in New Issue
Block a user