From 87f05fce669dbbe1845dd3aaaa6902dfd8d98fe6 Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Sat, 2 Aug 2025 21:38:55 +0800 Subject: [PATCH] =?UTF-8?q?Fix:=20=E5=BD=93=E5=A4=9A=E4=B8=AA=E7=9B=B8?= =?UTF-8?q?=E5=90=8C=E6=B6=88=E6=81=AF=E5=B9=B3=E5=8F=B0=E5=AE=9E=E4=BE=8B?= =?UTF-8?q?=E9=83=A8=E7=BD=B2=E6=97=B6=E4=B8=8A=E4=B8=8B=E6=96=87=E5=8F=AF?= =?UTF-8?q?=E8=83=BD=E6=B7=B7=E4=B9=B1=EF=BC=88=E5=85=B1=E4=BA=AB=EF=BC=89?= =?UTF-8?q?=20(#2298)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * perf: update astrbot event session format, using platfrom id to ensure uniqueness fixes: #1000 * fix: 更新 MessageSession 类以使用 platform_id 作为唯一标识符,并调整相关方法以确保一致性 * fix: 更新 MessageSession 文档以明确 platform_id 的赋值规则,并调整 get_platform 和 get_platform_inst 方法的返回类型 --- astrbot/core/core_lifecycle.py | 2 +- astrbot/core/event_bus.py | 4 +-- astrbot/core/pipeline/respond/stage.py | 2 +- astrbot/core/platform/astr_message_event.py | 27 +++++++++++++++---- astrbot/core/platform/platform_metadata.py | 2 +- astrbot/core/star/context.py | 29 ++++++++++++++++++--- packages/astrbot/main.py | 2 +- 7 files changed, 54 insertions(+), 14 deletions(-) diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 8412d5bea..025bf03e5 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -232,6 +232,6 @@ class AstrBotCoreLifecycle: platform_insts = self.platform_manager.get_insts() for platform_inst in platform_insts: tasks.append( - asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name) + asyncio.create_task(platform_inst.run(), name=f"{platform_inst.meta().id}({platform_inst.meta().name})") ) return tasks diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index d4caa2910..5010a0645 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -48,10 +48,10 @@ class EventBus: # 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要 if event.get_sender_name(): logger.info( - f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}" + f"[{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}" ) # 没有发送者名称: [平台名] 发送者ID: 消息概要 else: logger.info( - f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}" + f"[{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}" ) diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 54ad1e63b..ffdebf276 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -128,7 +128,7 @@ class RespondStage(Stage): use_fallback = self.config.get("provider_settings", {}).get( "streaming_segmented", False ) - logger.info(f"应用流式输出({event.get_platform_name()})") + logger.info(f"应用流式输出({event.get_platform_id()})") await event.send_streaming(result.async_stream, use_fallback) return elif len(result.chain) > 0: diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 8be20be73..5efbfb2f6 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -27,19 +27,29 @@ from .platform_metadata import PlatformMetadata @dataclass class MessageSession: + """描述一条消息在 AstrBot 中对应的会话的唯一标识。 + 如果您需要实例化 MessageSession,请不要给 platform_id 赋值(或者同时给 platform_name 和 platform_id 赋值相同值)。它会在 __post_init__ 中自动设置为 platform_name 的值。""" + platform_name: str + """平台适配器实例的唯一标识符。自 AstrBot v4.0.0 起,该字段实际为 platform_id。""" message_type: MessageType session_id: str + platform_id: str = None def __str__(self): - return f"{self.platform_name}:{self.message_type.value}:{self.session_id}" + return f"{self.platform_id}:{self.message_type.value}:{self.session_id}" + + def __post_init__(self): + self.platform_id = self.platform_name @staticmethod def from_str(session_str: str): - platform_name, message_type, session_id = session_str.split(":") - return MessageSession(platform_name, MessageType(message_type), session_id) + platform_id, message_type, session_id = session_str.split(":") + return MessageSession(platform_id, MessageType(message_type), session_id) + + +MessageSesion = MessageSession # back compatibility -MessageSesion = MessageSession # back compatibility class AstrMessageEvent(abc.ABC): def __init__( @@ -65,7 +75,7 @@ class AstrMessageEvent(abc.ABC): """是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)""" self._extras = {} self.session = MessageSesion( - platform_name=platform_meta.name, + platform_name=platform_meta.id, message_type=message_obj.type, session_id=session_id, ) @@ -83,9 +93,16 @@ class AstrMessageEvent(abc.ABC): self.platform = platform_meta def get_platform_name(self): + """获取这个事件所属的平台的类型(如 aiocqhttp, slack, discord 等)。 + + NOTE: 用户可能会同时运行多个相同类型的平台适配器。""" return self.platform_meta.name def get_platform_id(self): + """获取这个事件所属的平台的 ID。 + + NOTE: 用户可能会同时运行多个相同类型的平台适配器,但能确定的是 ID 是唯一的。 + """ return self.platform_meta.id def get_message_str(self) -> str: diff --git a/astrbot/core/platform/platform_metadata.py b/astrbot/core/platform/platform_metadata.py index dd0e93fec..7fb7f9d3e 100644 --- a/astrbot/core/platform/platform_metadata.py +++ b/astrbot/core/platform/platform_metadata.py @@ -4,7 +4,7 @@ from dataclasses import dataclass @dataclass class PlatformMetadata: name: str - """平台的名称""" + """平台的名称,即平台的类型,如 aiocqhttp, discord, slack""" description: str """平台的描述""" id: str = None diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 0b14525d3..6d89da57b 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -2,7 +2,12 @@ from asyncio import Queue from typing import List, Union from astrbot.core import sp -from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider, EmbeddingProvider +from astrbot.core.provider.provider import ( + Provider, + TTSProvider, + STTProvider, + EmbeddingProvider, +) from astrbot.core.provider.entities import ProviderType from astrbot.core.db import BaseDatabase from astrbot.core.config.astrbot_config import AstrBotConfig @@ -22,6 +27,7 @@ from astrbot.core.star.filter.platform_adapter_type import ( PlatformAdapterType, ADAPTER_NAME_2_TYPE, ) +from deprecated import deprecated class Context: @@ -201,9 +207,12 @@ class Context: """ return self._event_queue - def get_platform(self, platform_type: Union[PlatformAdapterType, str]) -> Platform: + @deprecated(version="4.0.0", reason="Use get_platform_inst instead") + def get_platform(self, platform_type: Union[PlatformAdapterType, str]) -> Platform | None: """ 获取指定类型的平台适配器。 + + 该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0) """ for platform in self.platform_manager.platform_insts: name = platform.meta().name @@ -217,6 +226,20 @@ class Context: ): return platform + def get_platform_inst(self, platform_id: str) -> Platform | None: + """ + 获取指定 ID 的平台适配器实例。 + + Args: + platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。 + + Returns: + Platform: 平台适配器实例,如果未找到则返回 None。 + """ + for platform in self.platform_manager.platform_insts: + if platform.meta().id == platform_id: + return platform + async def send_message( self, session: Union[str, MessageSesion], message_chain: MessageChain ) -> bool: @@ -240,7 +263,7 @@ class Context: raise ValueError("不合法的 session 字符串: " + str(e)) for platform in self.platform_manager.platform_insts: - if platform.meta().name == session.platform_name: + if platform.meta().id == session.platform_name: await platform.send_by_session(session, message_chain) return True return False diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 1dd2cbe2f..8a73945f5 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -820,7 +820,7 @@ UID: {user_id} 此 ID 可用于设置管理员。 if sid: session = str( MessageSesion( - platform_name=message.platform_meta.name, + platform_name=message.platform_meta.id, message_type=MessageType("GroupMessage"), session_id=sid, )