✨feat(plugin): 添加 AstrBot 启动完成时的事件钩子;添加获取制定平台适配器的接口
This commit is contained in:
@@ -6,6 +6,7 @@ from astrbot.core.star.register import (
|
||||
register_platform_adapter_type as platform_adapter_type,
|
||||
register_permission_type as permission_type,
|
||||
register_custom_filter as custom_filter,
|
||||
register_on_astrbot_loaded as on_astrbot_loaded,
|
||||
register_on_llm_request as on_llm_request,
|
||||
register_on_llm_response as on_llm_response,
|
||||
register_llm_tool as llm_tool,
|
||||
@@ -33,6 +34,7 @@ __all__ = [
|
||||
'CustomFilter',
|
||||
'custom_filter',
|
||||
'PermissionType',
|
||||
'on_astrbot_loaded',
|
||||
'on_llm_request',
|
||||
'llm_tool',
|
||||
'on_decorating_result',
|
||||
|
||||
@@ -19,6 +19,9 @@ from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
|
||||
class AstrBotCoreLifecycle:
|
||||
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
|
||||
self.log_broker = log_broker
|
||||
@@ -104,6 +107,15 @@ class AstrBotCoreLifecycle:
|
||||
self._load()
|
||||
logger.info("AstrBot 启动完成。")
|
||||
|
||||
# 执行启动完成事件钩子
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAstrBotLoadedEvent)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.info(f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}")
|
||||
await handler.handler()
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
|
||||
|
||||
async def stop(self):
|
||||
|
||||
@@ -10,7 +10,6 @@ class EventBus:
|
||||
self.pipeline_scheduler = pipeline_scheduler
|
||||
|
||||
async def dispatch(self):
|
||||
logger.info("事件总线已打开。")
|
||||
while True:
|
||||
event: AstrMessageEvent = await self.event_queue.get()
|
||||
self._print_event(event)
|
||||
|
||||
@@ -33,17 +33,17 @@ class AstrMessageEvent(abc.ABC):
|
||||
self.message_str = message_str
|
||||
'''纯文本的消息'''
|
||||
self.message_obj = message_obj
|
||||
'''消息对象,AstrBotMessage。带有完整的消息结构。'''
|
||||
'''消息对象, AstrBotMessage。带有完整的消息结构。'''
|
||||
self.platform_meta = platform_meta
|
||||
'''消息平台的信息, 其中 name 是平台的类型,如 aiocqhttp'''
|
||||
self.session_id = session_id
|
||||
'''用户的会话 ID。可以直接使用下面的 unified_msg_origin'''
|
||||
self.role = "member"
|
||||
'''用户是否是管理员。如果是管理员,这里是 admin'''
|
||||
self.is_wake = False # 是否通过 WakingStage
|
||||
'''是否唤醒'''
|
||||
self.is_wake = False
|
||||
'''是否唤醒(是否通过 WakingStage)'''
|
||||
self.is_at_or_wake_command = False
|
||||
'''是否是 At 机器人或者带有唤醒词或者是私聊(事件监听器会让 is_wake 设为 True,但是不会让这个属性置为 True)'''
|
||||
'''是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)'''
|
||||
self._extras = {}
|
||||
self.session = MessageSesion(
|
||||
platform_name=platform_meta.name,
|
||||
@@ -56,7 +56,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
'''消息事件的结果'''
|
||||
|
||||
self._has_send_oper = False
|
||||
'''是否有过至少一次发送操作'''
|
||||
'''在此次事件中是否有过至少一次发送消息的操作'''
|
||||
self.call_llm = False
|
||||
'''是否在此消息事件中禁止默认的 LLM 请求'''
|
||||
|
||||
|
||||
@@ -45,4 +45,10 @@ class Platform(abc.ABC):
|
||||
'''
|
||||
提交一个事件到事件队列。
|
||||
'''
|
||||
self._event_queue.put_nowait(event)
|
||||
self._event_queue.put_nowait(event)
|
||||
|
||||
def get_client(self):
|
||||
'''
|
||||
获取平台的客户端对象。
|
||||
'''
|
||||
pass
|
||||
@@ -34,6 +34,36 @@ class AiocqhttpAdapter(Platform):
|
||||
|
||||
self.stop = False
|
||||
|
||||
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
|
||||
|
||||
@self.bot.on_request()
|
||||
async def request(event: Event):
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@self.bot.on_notice()
|
||||
async def notice(event: Event):
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@self.bot.on_message('group')
|
||||
async def group(event: Event):
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@self.bot.on_message('private')
|
||||
async def private(event: Event):
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@self.bot.on_websocket_connection
|
||||
def on_websocket_connection(_):
|
||||
logger.info("aiocqhttp(OneBot v11) 适配器已连接。")
|
||||
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain)
|
||||
match session.message_type.value:
|
||||
@@ -113,7 +143,6 @@ class AiocqhttpAdapter(Platform):
|
||||
|
||||
return abm
|
||||
|
||||
|
||||
async def _convert_handle_message_event(self, event: Event) -> AstrBotMessage:
|
||||
'''OneBot V11 消息类事件'''
|
||||
abm = AstrBotMessage()
|
||||
@@ -202,44 +231,14 @@ class AiocqhttpAdapter(Platform):
|
||||
logger.warning("aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port,将使用默认值:http://0.0.0.0:6199")
|
||||
self.host = "0.0.0.0"
|
||||
self.port = 6199
|
||||
|
||||
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
|
||||
|
||||
@self.bot.on_request()
|
||||
async def request(event: Event):
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@self.bot.on_notice()
|
||||
async def notice(event: Event):
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@self.bot.on_message('group')
|
||||
async def group(event: Event):
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@self.bot.on_message('private')
|
||||
async def private(event: Event):
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@self.bot.on_websocket_connection
|
||||
def on_websocket_connection(_):
|
||||
logger.info("aiocqhttp(OneBot v11) 适配器已连接。")
|
||||
|
||||
bot = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder)
|
||||
coro = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder)
|
||||
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
logging.getLogger('aiocqhttp').setLevel(logging.ERROR)
|
||||
|
||||
return bot
|
||||
return coro
|
||||
|
||||
async def terminate(self):
|
||||
self.stop = True
|
||||
@@ -264,3 +263,6 @@ class AiocqhttpAdapter(Platform):
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
|
||||
async def get_client(self):
|
||||
return self.bot
|
||||
@@ -320,7 +320,7 @@ class SimpleGewechatClient():
|
||||
logger.info(f"使用验证码: {code}")
|
||||
try:
|
||||
os.remove("data/temp/gewe_code")
|
||||
except:
|
||||
except Exception:
|
||||
logger.warning("删除验证码文件 data/temp/gewe_code 失败。")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
|
||||
@@ -26,6 +26,19 @@ class GewechatPlatformAdapter(Platform):
|
||||
self.settingss = platform_settings
|
||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
||||
self.client = None
|
||||
|
||||
self.client = SimpleGewechatClient(
|
||||
self.config['base_url'],
|
||||
self.config['nickname'],
|
||||
self.config['host'],
|
||||
self.config['port'],
|
||||
self._event_queue,
|
||||
)
|
||||
|
||||
async def on_event_received(abm: AstrBotMessage):
|
||||
await self.handle_msg(abm)
|
||||
|
||||
self.client.on_event_received = on_event_received
|
||||
|
||||
@override
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
@@ -50,32 +63,17 @@ class GewechatPlatformAdapter(Platform):
|
||||
async def terminate(self):
|
||||
self.client.stop = True
|
||||
await asyncio.sleep(1)
|
||||
|
||||
@override
|
||||
def run(self):
|
||||
self.client = SimpleGewechatClient(
|
||||
self.config['base_url'],
|
||||
self.config['nickname'],
|
||||
self.config['host'],
|
||||
self.config['port'],
|
||||
self._event_queue,
|
||||
)
|
||||
|
||||
async def on_event_received(abm: AstrBotMessage):
|
||||
await self.handle_msg(abm)
|
||||
|
||||
self.client.on_event_received = on_event_received
|
||||
|
||||
return self._run()
|
||||
|
||||
async def logout(self):
|
||||
await self.client.logout()
|
||||
|
||||
@override
|
||||
def run(self):
|
||||
return self._run()
|
||||
|
||||
async def _run(self):
|
||||
await self.client.login()
|
||||
|
||||
await self.client.start_polling()
|
||||
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
if message.type == MessageType.GROUP_MESSAGE:
|
||||
@@ -90,4 +88,7 @@ class GewechatPlatformAdapter(Platform):
|
||||
client=self.client
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
self.commit_event(message_event)
|
||||
|
||||
def get_client(self):
|
||||
return self.client
|
||||
@@ -1,17 +1,14 @@
|
||||
import base64
|
||||
import time
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from typing import Union, List
|
||||
from astrbot.api.message_components import Image, Plain, At
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from .lark_event import LarkMessageEvent
|
||||
from ...register import register_platform_adapter
|
||||
from astrbot.core.message.components import BaseMessageComponent
|
||||
from astrbot import logger
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import *
|
||||
@@ -91,7 +88,7 @@ class LarkPlatformAdapter(Platform):
|
||||
if message.message_type == 'text':
|
||||
message_str_raw = content_json_b['text'] # 带有 @ 的消息
|
||||
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
|
||||
at_users = re.findall(at_pattern, message_str_raw)
|
||||
# at_users = re.findall(at_pattern, message_str_raw)
|
||||
# 拆分文本,去掉AT符号部分
|
||||
parts = re.split(at_pattern, message_str_raw)
|
||||
for i in range(len(parts)):
|
||||
@@ -172,4 +169,6 @@ class LarkPlatformAdapter(Platform):
|
||||
async def run(self):
|
||||
# self.client.start()
|
||||
await self.client._connect()
|
||||
|
||||
|
||||
async def get_client(self):
|
||||
return self.client
|
||||
@@ -179,4 +179,7 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
return self.client.start(
|
||||
appid=self.appid,
|
||||
secret=self.secret
|
||||
)
|
||||
)
|
||||
|
||||
def get_client(self):
|
||||
return self.client
|
||||
@@ -96,4 +96,8 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
self.client
|
||||
)
|
||||
await self.webhook_helper.initialize()
|
||||
await self.webhook_helper.start_polling()
|
||||
await self.webhook_helper.start_polling()
|
||||
|
||||
|
||||
async def get_client(self):
|
||||
return self.client
|
||||
@@ -28,6 +28,17 @@ class TelegramPlatformAdapter(Platform):
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.client_self_id = uuid.uuid4().hex[:8]
|
||||
|
||||
base_url = self.config.get("telegram_api_base_url", "https://api.telegram.org/bot")
|
||||
if not base_url:
|
||||
base_url = "https://api.telegram.org/bot"
|
||||
self.application = ApplicationBuilder().token(self.config['telegram_token']).base_url(base_url).build()
|
||||
message_handler = TelegramMessageHandler(
|
||||
filters=filters.ALL, # receive all messages
|
||||
callback=self.convert_message
|
||||
)
|
||||
self.application.add_handler(message_handler)
|
||||
self.client = self.application.bot
|
||||
|
||||
@override
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
@@ -44,22 +55,10 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
base_url = self.config.get("telegram_api_base_url", "https://api.telegram.org/bot")
|
||||
if not base_url:
|
||||
base_url = "https://api.telegram.org/bot"
|
||||
|
||||
self.application = ApplicationBuilder().token(self.config['telegram_token']).base_url(base_url).build()
|
||||
message_handler = TelegramMessageHandler(
|
||||
filters=filters.ALL, # receive all messages
|
||||
callback=self.convert_message
|
||||
)
|
||||
self.application.add_handler(message_handler)
|
||||
await self.application.initialize()
|
||||
await self.application.start()
|
||||
queue = self.application.updater.start_polling()
|
||||
self.client = self.application.bot
|
||||
logger.info("Telegram Platform Adapter is running.")
|
||||
|
||||
await queue
|
||||
|
||||
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
@@ -125,4 +124,7 @@ class TelegramPlatformAdapter(Platform):
|
||||
session_id=message.session_id,
|
||||
client=self.client
|
||||
)
|
||||
self.commit_event(message_event)
|
||||
self.commit_event(message_event)
|
||||
|
||||
async def get_client(self):
|
||||
return self.client
|
||||
@@ -116,20 +116,7 @@ class WecomPlatformAdapter(Platform):
|
||||
|
||||
if not self.api_base_url.endswith("/"):
|
||||
self.api_base_url += "/"
|
||||
|
||||
@override
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"wecom",
|
||||
"wecom 适配器",
|
||||
)
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
|
||||
self.server = WecomServer(
|
||||
self._event_queue,
|
||||
self.config
|
||||
@@ -145,7 +132,21 @@ class WecomPlatformAdapter(Platform):
|
||||
await self.convert_message(msg)
|
||||
|
||||
self.server.callback = callback
|
||||
|
||||
|
||||
@override
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"wecom",
|
||||
"wecom 适配器",
|
||||
)
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
await self.server.start_polling()
|
||||
|
||||
async def convert_message(self, msg):
|
||||
@@ -227,4 +228,7 @@ class WecomPlatformAdapter(Platform):
|
||||
session_id=message.session_id,
|
||||
client=self.client
|
||||
)
|
||||
self.commit_event(message_event)
|
||||
self.commit_event(message_event)
|
||||
|
||||
def get_client(self):
|
||||
return self.client
|
||||
@@ -9,6 +9,7 @@ from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.platform import Platform
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from .star import star_registry, StarMetadata, star_map
|
||||
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
@@ -17,6 +18,8 @@ from .filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterType, ADAPTER_NAME_2_TYPE
|
||||
|
||||
|
||||
class Context:
|
||||
'''
|
||||
@@ -169,9 +172,21 @@ class Context:
|
||||
'''
|
||||
return self._event_queue
|
||||
|
||||
def get_platform(self, platform_type: Union[PlatformAdapterType, str]) -> Platform:
|
||||
'''
|
||||
获取指定类型的平台适配器。
|
||||
'''
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if isinstance(platform_type, str):
|
||||
if platform.meta().name == platform_type:
|
||||
return platform
|
||||
else:
|
||||
if platform.meta().name == ADAPTER_NAME_2_TYPE[platform_type]:
|
||||
return platform
|
||||
|
||||
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
|
||||
'''
|
||||
根据 session(unified_msg_origin) 发送消息。
|
||||
根据 session(unified_msg_origin) 主动发送消息。
|
||||
|
||||
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
|
||||
@param message_chain: 消息链。
|
||||
@@ -179,6 +194,8 @@ class Context:
|
||||
@return: 是否找到匹配的平台。
|
||||
|
||||
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
|
||||
|
||||
NOTE: qq_official(QQ 官方 API 平台) 不支持此方法
|
||||
'''
|
||||
|
||||
if isinstance(session, str):
|
||||
@@ -192,7 +209,7 @@ class Context:
|
||||
await platform.send_by_session(session, message_chain)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
'''
|
||||
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
|
||||
'''
|
||||
@@ -224,7 +241,6 @@ class Context:
|
||||
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
|
||||
self.provider_manager.llm_tools.remove_func(name)
|
||||
|
||||
|
||||
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
|
||||
'''
|
||||
注册一个命令。
|
||||
|
||||
@@ -7,6 +7,7 @@ from .star_handler import (
|
||||
register_regex,
|
||||
register_permission_type,
|
||||
register_custom_filter,
|
||||
register_on_astrbot_loaded,
|
||||
register_on_llm_request,
|
||||
register_on_llm_response,
|
||||
register_llm_tool,
|
||||
@@ -23,6 +24,7 @@ __all__ = [
|
||||
'register_regex',
|
||||
'register_permission_type',
|
||||
'register_custom_filter',
|
||||
'register_on_astrbot_loaded',
|
||||
'register_on_llm_request',
|
||||
'register_on_llm_response',
|
||||
'register_llm_tool',
|
||||
|
||||
@@ -205,6 +205,15 @@ def register_permission_type(permission_type: PermissionType, raise_error: bool
|
||||
|
||||
return decorator
|
||||
|
||||
def register_on_astrbot_loaded(**kwargs):
|
||||
'''当 AstrBot 加载完成时
|
||||
'''
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(awaitable, EventType.OnAstrBotLoadedEvent, **kwargs)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_on_llm_request(**kwargs):
|
||||
'''当有 LLM 请求时的事件
|
||||
|
||||
|
||||
@@ -77,6 +77,8 @@ class EventType(enum.Enum):
|
||||
|
||||
用于对 Handler 的职能分组。
|
||||
'''
|
||||
OnAstrBotLoadedEvent = enum.auto() # AstrBot 加载完成
|
||||
|
||||
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
|
||||
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
|
||||
OnLLMResponseEvent = enum.auto() # LLM 响应后
|
||||
|
||||
Reference in New Issue
Block a user