From a18de9de7da7c0fb0f6ed0e3aace4579518767f0 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 2 Mar 2025 20:56:18 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8feat(plugin):=20=E6=B7=BB=E5=8A=A0=20A?= =?UTF-8?q?strBot=20=E5=90=AF=E5=8A=A8=E5=AE=8C=E6=88=90=E6=97=B6=E7=9A=84?= =?UTF-8?q?=E4=BA=8B=E4=BB=B6=E9=92=A9=E5=AD=90=EF=BC=9B=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E8=8E=B7=E5=8F=96=E5=88=B6=E5=AE=9A=E5=B9=B3=E5=8F=B0=E9=80=82?= =?UTF-8?q?=E9=85=8D=E5=99=A8=E7=9A=84=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/api/event/filter/__init__.py | 2 + astrbot/core/core_lifecycle.py | 12 ++++ astrbot/core/event_bus.py | 1 - astrbot/core/platform/astr_message_event.py | 10 +-- astrbot/core/platform/platform.py | 8 ++- .../aiocqhttp/aiocqhttp_platform_adapter.py | 68 ++++++++++--------- .../core/platform/sources/gewechat/client.py | 2 +- .../gewechat/gewechat_platform_adapter.py | 41 +++++------ .../platform/sources/lark/lark_adapter.py | 9 ++- .../qqofficial/qqofficial_platform_adapter.py | 5 +- .../qqofficial_webhook/qo_webhook_adapter.py | 6 +- .../platform/sources/telegram/tg_adapter.py | 28 ++++---- .../platform/sources/wecom/wecom_adapter.py | 34 ++++++---- astrbot/core/star/context.py | 22 +++++- astrbot/core/star/register/__init__.py | 2 + astrbot/core/star/register/star_handler.py | 9 +++ astrbot/core/star/star_handler.py | 2 + 17 files changed, 162 insertions(+), 99 deletions(-) diff --git a/astrbot/api/event/filter/__init__.py b/astrbot/api/event/filter/__init__.py index 7174be0fd..646800f55 100644 --- a/astrbot/api/event/filter/__init__.py +++ b/astrbot/api/event/filter/__init__.py @@ -6,6 +6,7 @@ from astrbot.core.star.register import ( register_platform_adapter_type as platform_adapter_type, register_permission_type as permission_type, register_custom_filter as custom_filter, + register_on_astrbot_loaded as on_astrbot_loaded, register_on_llm_request as on_llm_request, register_on_llm_response as on_llm_response, register_llm_tool as llm_tool, @@ -33,6 +34,7 @@ __all__ = [ 'CustomFilter', 'custom_filter', 'PermissionType', + 'on_astrbot_loaded', 'on_llm_request', 'llm_tool', 'on_decorating_result', diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index c2a5f4838..924dd4137 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -19,6 +19,9 @@ from astrbot.core import logger from astrbot.core.config.default import VERSION from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager from astrbot.core.conversation_mgr import ConversationManager +from astrbot.core.star.star_handler import star_handlers_registry, EventType +from astrbot.core.star.star_handler import star_map + class AstrBotCoreLifecycle: def __init__(self, log_broker: LogBroker, db: BaseDatabase): self.log_broker = log_broker @@ -104,6 +107,15 @@ class AstrBotCoreLifecycle: self._load() logger.info("AstrBot 启动完成。") + # 执行启动完成事件钩子 + handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAstrBotLoadedEvent) + for handler in handlers: + try: + logger.info(f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}") + await handler.handler() + except BaseException: + logger.error(traceback.format_exc()) + await asyncio.gather(*self.curr_tasks, return_exceptions=True) async def stop(self): diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index 35b41f1e8..d6039c15f 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -10,7 +10,6 @@ class EventBus: self.pipeline_scheduler = pipeline_scheduler async def dispatch(self): - logger.info("事件总线已打开。") while True: event: AstrMessageEvent = await self.event_queue.get() self._print_event(event) diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 05ffac6f8..9569e281e 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -33,17 +33,17 @@ class AstrMessageEvent(abc.ABC): self.message_str = message_str '''纯文本的消息''' self.message_obj = message_obj - '''消息对象,AstrBotMessage。带有完整的消息结构。''' + '''消息对象, AstrBotMessage。带有完整的消息结构。''' self.platform_meta = platform_meta '''消息平台的信息, 其中 name 是平台的类型,如 aiocqhttp''' self.session_id = session_id '''用户的会话 ID。可以直接使用下面的 unified_msg_origin''' self.role = "member" '''用户是否是管理员。如果是管理员,这里是 admin''' - self.is_wake = False # 是否通过 WakingStage - '''是否唤醒''' + self.is_wake = False + '''是否唤醒(是否通过 WakingStage)''' self.is_at_or_wake_command = False - '''是否是 At 机器人或者带有唤醒词或者是私聊(事件监听器会让 is_wake 设为 True,但是不会让这个属性置为 True)''' + '''是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)''' self._extras = {} self.session = MessageSesion( platform_name=platform_meta.name, @@ -56,7 +56,7 @@ class AstrMessageEvent(abc.ABC): '''消息事件的结果''' self._has_send_oper = False - '''是否有过至少一次发送操作''' + '''在此次事件中是否有过至少一次发送消息的操作''' self.call_llm = False '''是否在此消息事件中禁止默认的 LLM 请求''' diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index 3526d2802..1ae8afb4d 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -45,4 +45,10 @@ class Platform(abc.ABC): ''' 提交一个事件到事件队列。 ''' - self._event_queue.put_nowait(event) \ No newline at end of file + self._event_queue.put_nowait(event) + + def get_client(self): + ''' + 获取平台的客户端对象。 + ''' + pass \ No newline at end of file diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index cbaa5137b..83689d6b6 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -34,6 +34,36 @@ class AiocqhttpAdapter(Platform): self.stop = False + self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180) + + @self.bot.on_request() + async def request(event: Event): + abm = await self.convert_message(event) + if abm: + await self.handle_msg(abm) + + @self.bot.on_notice() + async def notice(event: Event): + abm = await self.convert_message(event) + if abm: + await self.handle_msg(abm) + + @self.bot.on_message('group') + async def group(event: Event): + abm = await self.convert_message(event) + if abm: + await self.handle_msg(abm) + + @self.bot.on_message('private') + async def private(event: Event): + abm = await self.convert_message(event) + if abm: + await self.handle_msg(abm) + + @self.bot.on_websocket_connection + def on_websocket_connection(_): + logger.info("aiocqhttp(OneBot v11) 适配器已连接。") + async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain) match session.message_type.value: @@ -113,7 +143,6 @@ class AiocqhttpAdapter(Platform): return abm - async def _convert_handle_message_event(self, event: Event) -> AstrBotMessage: '''OneBot V11 消息类事件''' abm = AstrBotMessage() @@ -202,44 +231,14 @@ class AiocqhttpAdapter(Platform): logger.warning("aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port,将使用默认值:http://0.0.0.0:6199") self.host = "0.0.0.0" self.port = 6199 - - self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180) - @self.bot.on_request() - async def request(event: Event): - abm = await self.convert_message(event) - if abm: - await self.handle_msg(abm) - - @self.bot.on_notice() - async def notice(event: Event): - abm = await self.convert_message(event) - if abm: - await self.handle_msg(abm) - - @self.bot.on_message('group') - async def group(event: Event): - abm = await self.convert_message(event) - if abm: - await self.handle_msg(abm) - - @self.bot.on_message('private') - async def private(event: Event): - abm = await self.convert_message(event) - if abm: - await self.handle_msg(abm) - - @self.bot.on_websocket_connection - def on_websocket_connection(_): - logger.info("aiocqhttp(OneBot v11) 适配器已连接。") - - bot = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder) + coro = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder) for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) logging.getLogger('aiocqhttp').setLevel(logging.ERROR) - return bot + return coro async def terminate(self): self.stop = True @@ -264,3 +263,6 @@ class AiocqhttpAdapter(Platform): ) self.commit_event(message_event) + + async def get_client(self): + return self.bot \ No newline at end of file diff --git a/astrbot/core/platform/sources/gewechat/client.py b/astrbot/core/platform/sources/gewechat/client.py index 79f61c774..f26831cd9 100644 --- a/astrbot/core/platform/sources/gewechat/client.py +++ b/astrbot/core/platform/sources/gewechat/client.py @@ -320,7 +320,7 @@ class SimpleGewechatClient(): logger.info(f"使用验证码: {code}") try: os.remove("data/temp/gewe_code") - except: + except Exception: logger.warning("删除验证码文件 data/temp/gewe_code 失败。") async with aiohttp.ClientSession() as session: diff --git a/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py b/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py index 7e325ce8b..ffbfb6c94 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py @@ -26,6 +26,19 @@ class GewechatPlatformAdapter(Platform): self.settingss = platform_settings self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on' self.client = None + + self.client = SimpleGewechatClient( + self.config['base_url'], + self.config['nickname'], + self.config['host'], + self.config['port'], + self._event_queue, + ) + + async def on_event_received(abm: AstrBotMessage): + await self.handle_msg(abm) + + self.client.on_event_received = on_event_received @override async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): @@ -50,32 +63,17 @@ class GewechatPlatformAdapter(Platform): async def terminate(self): self.client.stop = True await asyncio.sleep(1) - - @override - def run(self): - self.client = SimpleGewechatClient( - self.config['base_url'], - self.config['nickname'], - self.config['host'], - self.config['port'], - self._event_queue, - ) - - async def on_event_received(abm: AstrBotMessage): - await self.handle_msg(abm) - - self.client.on_event_received = on_event_received - - return self._run() async def logout(self): await self.client.logout() + + @override + def run(self): + return self._run() async def _run(self): await self.client.login() - await self.client.start_polling() - async def handle_msg(self, message: AstrBotMessage): if message.type == MessageType.GROUP_MESSAGE: @@ -90,4 +88,7 @@ class GewechatPlatformAdapter(Platform): client=self.client ) - self.commit_event(message_event) \ No newline at end of file + self.commit_event(message_event) + + def get_client(self): + return self.client \ No newline at end of file diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index 6a9dc578d..d0ca5a5b8 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -1,17 +1,14 @@ import base64 -import time import asyncio import json import re from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata from astrbot.api.event import MessageChain -from typing import Union, List from astrbot.api.message_components import Image, Plain, At from astrbot.core.platform.astr_message_event import MessageSesion from .lark_event import LarkMessageEvent from ...register import register_platform_adapter -from astrbot.core.message.components import BaseMessageComponent from astrbot import logger import lark_oapi as lark from lark_oapi.api.im.v1 import * @@ -91,7 +88,7 @@ class LarkPlatformAdapter(Platform): if message.message_type == 'text': message_str_raw = content_json_b['text'] # 带有 @ 的消息 at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则 - at_users = re.findall(at_pattern, message_str_raw) + # at_users = re.findall(at_pattern, message_str_raw) # 拆分文本,去掉AT符号部分 parts = re.split(at_pattern, message_str_raw) for i in range(len(parts)): @@ -172,4 +169,6 @@ class LarkPlatformAdapter(Platform): async def run(self): # self.client.start() await self.client._connect() - \ No newline at end of file + + async def get_client(self): + return self.client \ No newline at end of file diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index 4548bbf3c..d0a9a4dcb 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -179,4 +179,7 @@ class QQOfficialPlatformAdapter(Platform): return self.client.start( appid=self.appid, secret=self.secret - ) \ No newline at end of file + ) + + def get_client(self): + return self.client \ No newline at end of file diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py index 90e6f7e5b..1f131cb68 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py @@ -96,4 +96,8 @@ class QQOfficialWebhookPlatformAdapter(Platform): self.client ) await self.webhook_helper.initialize() - await self.webhook_helper.start_polling() \ No newline at end of file + await self.webhook_helper.start_polling() + + + async def get_client(self): + return self.client \ No newline at end of file diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index 82366e5ca..446d9ec38 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -28,6 +28,17 @@ class TelegramPlatformAdapter(Platform): self.config = platform_config self.settings = platform_settings self.client_self_id = uuid.uuid4().hex[:8] + + base_url = self.config.get("telegram_api_base_url", "https://api.telegram.org/bot") + if not base_url: + base_url = "https://api.telegram.org/bot" + self.application = ApplicationBuilder().token(self.config['telegram_token']).base_url(base_url).build() + message_handler = TelegramMessageHandler( + filters=filters.ALL, # receive all messages + callback=self.convert_message + ) + self.application.add_handler(message_handler) + self.client = self.application.bot @override async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): @@ -44,22 +55,10 @@ class TelegramPlatformAdapter(Platform): @override async def run(self): - base_url = self.config.get("telegram_api_base_url", "https://api.telegram.org/bot") - if not base_url: - base_url = "https://api.telegram.org/bot" - - self.application = ApplicationBuilder().token(self.config['telegram_token']).base_url(base_url).build() - message_handler = TelegramMessageHandler( - filters=filters.ALL, # receive all messages - callback=self.convert_message - ) - self.application.add_handler(message_handler) await self.application.initialize() await self.application.start() queue = self.application.updater.start_polling() - self.client = self.application.bot logger.info("Telegram Platform Adapter is running.") - await queue async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE): @@ -125,4 +124,7 @@ class TelegramPlatformAdapter(Platform): session_id=message.session_id, client=self.client ) - self.commit_event(message_event) \ No newline at end of file + self.commit_event(message_event) + + async def get_client(self): + return self.client \ No newline at end of file diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 0127342fe..6bfea54e7 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -116,20 +116,7 @@ class WecomPlatformAdapter(Platform): if not self.api_base_url.endswith("/"): self.api_base_url += "/" - - @override - async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): - await super().send_by_session(session, message_chain) - - @override - def meta(self) -> PlatformMetadata: - return PlatformMetadata( - "wecom", - "wecom 适配器", - ) - - @override - async def run(self): + self.server = WecomServer( self._event_queue, self.config @@ -145,7 +132,21 @@ class WecomPlatformAdapter(Platform): await self.convert_message(msg) self.server.callback = callback + + @override + async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): + await super().send_by_session(session, message_chain) + + @override + def meta(self) -> PlatformMetadata: + return PlatformMetadata( + "wecom", + "wecom 适配器", + ) + + @override + async def run(self): await self.server.start_polling() async def convert_message(self, msg): @@ -227,4 +228,7 @@ class WecomPlatformAdapter(Platform): session_id=message.session_id, client=self.client ) - self.commit_event(message_event) \ No newline at end of file + self.commit_event(message_event) + + def get_client(self): + return self.client \ No newline at end of file diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 18b5200fe..2ebb20f8e 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -9,6 +9,7 @@ from astrbot.core.provider.func_tool_manager import FuncCall from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.manager import ProviderManager +from astrbot.core.platform import Platform from astrbot.core.platform.manager import PlatformManager from .star import star_registry, StarMetadata, star_map from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType @@ -17,6 +18,8 @@ from .filter.regex import RegexFilter from typing import Awaitable from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager from astrbot.core.conversation_mgr import ConversationManager +from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterType, ADAPTER_NAME_2_TYPE + class Context: ''' @@ -169,9 +172,21 @@ class Context: ''' return self._event_queue + def get_platform(self, platform_type: Union[PlatformAdapterType, str]) -> Platform: + ''' + 获取指定类型的平台适配器。 + ''' + for platform in self.platform_manager.platform_insts: + if isinstance(platform_type, str): + if platform.meta().name == platform_type: + return platform + else: + if platform.meta().name == ADAPTER_NAME_2_TYPE[platform_type]: + return platform + async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool: ''' - 根据 session(unified_msg_origin) 发送消息。 + 根据 session(unified_msg_origin) 主动发送消息。 @param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。 @param message_chain: 消息链。 @@ -179,6 +194,8 @@ class Context: @return: 是否找到匹配的平台。 当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。 + + NOTE: qq_official(QQ 官方 API 平台) 不支持此方法 ''' if isinstance(session, str): @@ -192,7 +209,7 @@ class Context: await platform.send_by_session(session, message_chain) return True return False - + ''' 以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。 ''' @@ -224,7 +241,6 @@ class Context: '''删除一个函数调用工具。如果再要启用,需要重新注册。''' self.provider_manager.llm_tools.remove_func(name) - def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False): ''' 注册一个命令。 diff --git a/astrbot/core/star/register/__init__.py b/astrbot/core/star/register/__init__.py index ba51f7ab6..705d026be 100644 --- a/astrbot/core/star/register/__init__.py +++ b/astrbot/core/star/register/__init__.py @@ -7,6 +7,7 @@ from .star_handler import ( register_regex, register_permission_type, register_custom_filter, + register_on_astrbot_loaded, register_on_llm_request, register_on_llm_response, register_llm_tool, @@ -23,6 +24,7 @@ __all__ = [ 'register_regex', 'register_permission_type', 'register_custom_filter', + 'register_on_astrbot_loaded', 'register_on_llm_request', 'register_on_llm_response', 'register_llm_tool', diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index 2cf5aaeeb..999fd562a 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -205,6 +205,15 @@ def register_permission_type(permission_type: PermissionType, raise_error: bool return decorator +def register_on_astrbot_loaded(**kwargs): + '''当 AstrBot 加载完成时 + ''' + def decorator(awaitable): + _ = get_handler_or_create(awaitable, EventType.OnAstrBotLoadedEvent, **kwargs) + return awaitable + + return decorator + def register_on_llm_request(**kwargs): '''当有 LLM 请求时的事件 diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index 7fee47d2b..ee927eac9 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -77,6 +77,8 @@ class EventType(enum.Enum): 用于对 Handler 的职能分组。 ''' + OnAstrBotLoadedEvent = enum.auto() # AstrBot 加载完成 + AdapterMessageEvent = enum.auto() # 收到适配器发来的消息 OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件) OnLLMResponseEvent = enum.auto() # LLM 响应后