feat(plugin): 添加 AstrBot 启动完成时的事件钩子;添加获取制定平台适配器的接口

This commit is contained in:
Soulter
2025-03-02 20:56:18 +08:00
parent 59fbbd5987
commit a18de9de7d
17 changed files with 162 additions and 99 deletions
+2
View File
@@ -6,6 +6,7 @@ from astrbot.core.star.register import (
register_platform_adapter_type as platform_adapter_type,
register_permission_type as permission_type,
register_custom_filter as custom_filter,
register_on_astrbot_loaded as on_astrbot_loaded,
register_on_llm_request as on_llm_request,
register_on_llm_response as on_llm_response,
register_llm_tool as llm_tool,
@@ -33,6 +34,7 @@ __all__ = [
'CustomFilter',
'custom_filter',
'PermissionType',
'on_astrbot_loaded',
'on_llm_request',
'llm_tool',
'on_decorating_result',
+12
View File
@@ -19,6 +19,9 @@ from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star_handler import star_map
class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker
@@ -104,6 +107,15 @@ class AstrBotCoreLifecycle:
self._load()
logger.info("AstrBot 启动完成。")
# 执行启动完成事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAstrBotLoadedEvent)
for handler in handlers:
try:
logger.info(f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}")
await handler.handler()
except BaseException:
logger.error(traceback.format_exc())
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
async def stop(self):
-1
View File
@@ -10,7 +10,6 @@ class EventBus:
self.pipeline_scheduler = pipeline_scheduler
async def dispatch(self):
logger.info("事件总线已打开。")
while True:
event: AstrMessageEvent = await self.event_queue.get()
self._print_event(event)
+5 -5
View File
@@ -33,17 +33,17 @@ class AstrMessageEvent(abc.ABC):
self.message_str = message_str
'''纯文本的消息'''
self.message_obj = message_obj
'''消息对象AstrBotMessage。带有完整的消息结构。'''
'''消息对象, AstrBotMessage。带有完整的消息结构。'''
self.platform_meta = platform_meta
'''消息平台的信息, 其中 name 是平台的类型,如 aiocqhttp'''
self.session_id = session_id
'''用户的会话 ID。可以直接使用下面的 unified_msg_origin'''
self.role = "member"
'''用户是否是管理员。如果是管理员,这里是 admin'''
self.is_wake = False # 是否通过 WakingStage
'''是否唤醒'''
self.is_wake = False
'''是否唤醒(是否通过 WakingStage)'''
self.is_at_or_wake_command = False
'''是否是 At 机器人或者带有唤醒词或者是私聊事件监听器会让 is_wake 设为 True但是不会让这个属性置为 True'''
'''是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)'''
self._extras = {}
self.session = MessageSesion(
platform_name=platform_meta.name,
@@ -56,7 +56,7 @@ class AstrMessageEvent(abc.ABC):
'''消息事件的结果'''
self._has_send_oper = False
'''是否有过至少一次发送操作'''
'''在此次事件中是否有过至少一次发送消息的操作'''
self.call_llm = False
'''是否在此消息事件中禁止默认的 LLM 请求'''
+7 -1
View File
@@ -45,4 +45,10 @@ class Platform(abc.ABC):
'''
提交一个事件到事件队列。
'''
self._event_queue.put_nowait(event)
self._event_queue.put_nowait(event)
def get_client(self):
'''
获取平台的客户端对象。
'''
pass
@@ -34,6 +34,36 @@ class AiocqhttpAdapter(Platform):
self.stop = False
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
@self.bot.on_request()
async def request(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_notice()
async def notice(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_message('group')
async def group(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_message('private')
async def private(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_websocket_connection
def on_websocket_connection(_):
logger.info("aiocqhttp(OneBot v11) 适配器已连接。")
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain)
match session.message_type.value:
@@ -113,7 +143,6 @@ class AiocqhttpAdapter(Platform):
return abm
async def _convert_handle_message_event(self, event: Event) -> AstrBotMessage:
'''OneBot V11 消息类事件'''
abm = AstrBotMessage()
@@ -202,44 +231,14 @@ class AiocqhttpAdapter(Platform):
logger.warning("aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port,将使用默认值:http://0.0.0.0:6199")
self.host = "0.0.0.0"
self.port = 6199
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
@self.bot.on_request()
async def request(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_notice()
async def notice(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_message('group')
async def group(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_message('private')
async def private(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_websocket_connection
def on_websocket_connection(_):
logger.info("aiocqhttp(OneBot v11) 适配器已连接。")
bot = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder)
coro = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder)
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
logging.getLogger('aiocqhttp').setLevel(logging.ERROR)
return bot
return coro
async def terminate(self):
self.stop = True
@@ -264,3 +263,6 @@ class AiocqhttpAdapter(Platform):
)
self.commit_event(message_event)
async def get_client(self):
return self.bot
@@ -320,7 +320,7 @@ class SimpleGewechatClient():
logger.info(f"使用验证码: {code}")
try:
os.remove("data/temp/gewe_code")
except:
except Exception:
logger.warning("删除验证码文件 data/temp/gewe_code 失败。")
async with aiohttp.ClientSession() as session:
@@ -26,6 +26,19 @@ class GewechatPlatformAdapter(Platform):
self.settingss = platform_settings
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
self.client = None
self.client = SimpleGewechatClient(
self.config['base_url'],
self.config['nickname'],
self.config['host'],
self.config['port'],
self._event_queue,
)
async def on_event_received(abm: AstrBotMessage):
await self.handle_msg(abm)
self.client.on_event_received = on_event_received
@override
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
@@ -50,32 +63,17 @@ class GewechatPlatformAdapter(Platform):
async def terminate(self):
self.client.stop = True
await asyncio.sleep(1)
@override
def run(self):
self.client = SimpleGewechatClient(
self.config['base_url'],
self.config['nickname'],
self.config['host'],
self.config['port'],
self._event_queue,
)
async def on_event_received(abm: AstrBotMessage):
await self.handle_msg(abm)
self.client.on_event_received = on_event_received
return self._run()
async def logout(self):
await self.client.logout()
@override
def run(self):
return self._run()
async def _run(self):
await self.client.login()
await self.client.start_polling()
async def handle_msg(self, message: AstrBotMessage):
if message.type == MessageType.GROUP_MESSAGE:
@@ -90,4 +88,7 @@ class GewechatPlatformAdapter(Platform):
client=self.client
)
self.commit_event(message_event)
self.commit_event(message_event)
def get_client(self):
return self.client
@@ -1,17 +1,14 @@
import base64
import time
import asyncio
import json
import re
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from typing import Union, List
from astrbot.api.message_components import Image, Plain, At
from astrbot.core.platform.astr_message_event import MessageSesion
from .lark_event import LarkMessageEvent
from ...register import register_platform_adapter
from astrbot.core.message.components import BaseMessageComponent
from astrbot import logger
import lark_oapi as lark
from lark_oapi.api.im.v1 import *
@@ -91,7 +88,7 @@ class LarkPlatformAdapter(Platform):
if message.message_type == 'text':
message_str_raw = content_json_b['text'] # 带有 @ 的消息
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
at_users = re.findall(at_pattern, message_str_raw)
# at_users = re.findall(at_pattern, message_str_raw)
# 拆分文本,去掉AT符号部分
parts = re.split(at_pattern, message_str_raw)
for i in range(len(parts)):
@@ -172,4 +169,6 @@ class LarkPlatformAdapter(Platform):
async def run(self):
# self.client.start()
await self.client._connect()
async def get_client(self):
return self.client
@@ -179,4 +179,7 @@ class QQOfficialPlatformAdapter(Platform):
return self.client.start(
appid=self.appid,
secret=self.secret
)
)
def get_client(self):
return self.client
@@ -96,4 +96,8 @@ class QQOfficialWebhookPlatformAdapter(Platform):
self.client
)
await self.webhook_helper.initialize()
await self.webhook_helper.start_polling()
await self.webhook_helper.start_polling()
async def get_client(self):
return self.client
@@ -28,6 +28,17 @@ class TelegramPlatformAdapter(Platform):
self.config = platform_config
self.settings = platform_settings
self.client_self_id = uuid.uuid4().hex[:8]
base_url = self.config.get("telegram_api_base_url", "https://api.telegram.org/bot")
if not base_url:
base_url = "https://api.telegram.org/bot"
self.application = ApplicationBuilder().token(self.config['telegram_token']).base_url(base_url).build()
message_handler = TelegramMessageHandler(
filters=filters.ALL, # receive all messages
callback=self.convert_message
)
self.application.add_handler(message_handler)
self.client = self.application.bot
@override
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
@@ -44,22 +55,10 @@ class TelegramPlatformAdapter(Platform):
@override
async def run(self):
base_url = self.config.get("telegram_api_base_url", "https://api.telegram.org/bot")
if not base_url:
base_url = "https://api.telegram.org/bot"
self.application = ApplicationBuilder().token(self.config['telegram_token']).base_url(base_url).build()
message_handler = TelegramMessageHandler(
filters=filters.ALL, # receive all messages
callback=self.convert_message
)
self.application.add_handler(message_handler)
await self.application.initialize()
await self.application.start()
queue = self.application.updater.start_polling()
self.client = self.application.bot
logger.info("Telegram Platform Adapter is running.")
await queue
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
@@ -125,4 +124,7 @@ class TelegramPlatformAdapter(Platform):
session_id=message.session_id,
client=self.client
)
self.commit_event(message_event)
self.commit_event(message_event)
async def get_client(self):
return self.client
@@ -116,20 +116,7 @@ class WecomPlatformAdapter(Platform):
if not self.api_base_url.endswith("/"):
self.api_base_url += "/"
@override
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
await super().send_by_session(session, message_chain)
@override
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"wecom",
"wecom 适配器",
)
@override
async def run(self):
self.server = WecomServer(
self._event_queue,
self.config
@@ -145,7 +132,21 @@ class WecomPlatformAdapter(Platform):
await self.convert_message(msg)
self.server.callback = callback
@override
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
await super().send_by_session(session, message_chain)
@override
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"wecom",
"wecom 适配器",
)
@override
async def run(self):
await self.server.start_polling()
async def convert_message(self, msg):
@@ -227,4 +228,7 @@ class WecomPlatformAdapter(Platform):
session_id=message.session_id,
client=self.client
)
self.commit_event(message_event)
self.commit_event(message_event)
def get_client(self):
return self.client
+19 -3
View File
@@ -9,6 +9,7 @@ from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.platform import Platform
from astrbot.core.platform.manager import PlatformManager
from .star import star_registry, StarMetadata, star_map
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
@@ -17,6 +18,8 @@ from .filter.regex import RegexFilter
from typing import Awaitable
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterType, ADAPTER_NAME_2_TYPE
class Context:
'''
@@ -169,9 +172,21 @@ class Context:
'''
return self._event_queue
def get_platform(self, platform_type: Union[PlatformAdapterType, str]) -> Platform:
'''
获取指定类型的平台适配器。
'''
for platform in self.platform_manager.platform_insts:
if isinstance(platform_type, str):
if platform.meta().name == platform_type:
return platform
else:
if platform.meta().name == ADAPTER_NAME_2_TYPE[platform_type]:
return platform
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
'''
根据 session(unified_msg_origin) 发送消息。
根据 session(unified_msg_origin) 主动发送消息。
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
@param message_chain: 消息链。
@@ -179,6 +194,8 @@ class Context:
@return: 是否找到匹配的平台。
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
NOTE: qq_official(QQ 官方 API 平台) 不支持此方法
'''
if isinstance(session, str):
@@ -192,7 +209,7 @@ class Context:
await platform.send_by_session(session, message_chain)
return True
return False
'''
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
'''
@@ -224,7 +241,6 @@ class Context:
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
self.provider_manager.llm_tools.remove_func(name)
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
'''
注册一个命令。
+2
View File
@@ -7,6 +7,7 @@ from .star_handler import (
register_regex,
register_permission_type,
register_custom_filter,
register_on_astrbot_loaded,
register_on_llm_request,
register_on_llm_response,
register_llm_tool,
@@ -23,6 +24,7 @@ __all__ = [
'register_regex',
'register_permission_type',
'register_custom_filter',
'register_on_astrbot_loaded',
'register_on_llm_request',
'register_on_llm_response',
'register_llm_tool',
@@ -205,6 +205,15 @@ def register_permission_type(permission_type: PermissionType, raise_error: bool
return decorator
def register_on_astrbot_loaded(**kwargs):
'''当 AstrBot 加载完成时
'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnAstrBotLoadedEvent, **kwargs)
return awaitable
return decorator
def register_on_llm_request(**kwargs):
'''当有 LLM 请求时的事件
+2
View File
@@ -77,6 +77,8 @@ class EventType(enum.Enum):
用于对 Handler 的职能分组。
'''
OnAstrBotLoadedEvent = enum.auto() # AstrBot 加载完成
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
OnLLMResponseEvent = enum.auto() # LLM 响应后