diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 72e4414b6..a85ec0b9b 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -1,11 +1,9 @@ import abc import asyncio from dataclasses import dataclass -from .astrbot_message import AstrBotMessage -from .platform_metadata import PlatformMetadata -from astrbot.core.message.message_event_result import MessageEventResult, MessageChain -from astrbot.core.platform.message_type import MessageType -from typing import List, Union +from typing import List, Union, Optional + +from astrbot.core.db.po import Conversation from astrbot.core.message.components import ( Plain, Image, @@ -15,9 +13,12 @@ from astrbot.core.message.components import ( AtAll, Forward, ) -from astrbot.core.utils.metrics import Metric +from astrbot.core.message.message_event_result import MessageEventResult, MessageChain +from astrbot.core.platform.message_type import MessageType from astrbot.core.provider.entites import ProviderRequest -from astrbot.core.db.po import Conversation +from astrbot.core.utils.metrics import Metric +from .astrbot_message import AstrBotMessage, Group +from .platform_metadata import PlatformMetadata @dataclass @@ -37,11 +38,11 @@ class MessageSesion: class AstrMessageEvent(abc.ABC): def __init__( - self, - message_str: str, - message_obj: AstrBotMessage, - platform_meta: PlatformMetadata, - session_id: str, + self, + message_str: str, + message_obj: AstrBotMessage, + platform_meta: PlatformMetadata, + session_id: str, ): self.message_str = message_str """纯文本的消息""" @@ -320,14 +321,14 @@ class AstrMessageEvent(abc.ABC): """LLM 请求相关""" def request_llm( - self, - prompt: str, - func_tool_manager=None, - session_id: str = None, - image_urls: List[str] = [], - contexts: List = [], - system_prompt: str = "", - conversation: Conversation = None, + self, + prompt: str, + func_tool_manager=None, + session_id: str = None, + image_urls: List[str] = [], + contexts: List = [], + system_prompt: str = "", + conversation: Conversation = None, ) -> ProviderRequest: """ 创建一个 LLM 请求。 @@ -363,3 +364,37 @@ class AstrMessageEvent(abc.ABC): system_prompt=system_prompt, conversation=conversation, ) + + async def get_group(self, group_id: str = None) -> Optional[Group]: + """ + 获取群聊,如果不填写group_id,且消息是私聊消息,则返回 None + 目前只实现了 GeweChat 协议 + """ + # 确定有效的 group_id + if group_id is None: + group_id = self.message_obj.group_id + + if group_id is None: + return None + + # 检查平台是否为 gewechat + if self.platform_meta.name != "gewechat": + return None + + from astrbot.core.platform.sources.gewechat.gewechat_event import ( + GewechatPlatformEvent, + ) + + assert isinstance(self, GewechatPlatformEvent) + client = self.client + + # 从客户端获取群信息 + res = await client.get_group(group_id) + + data = res["data"] + + # 检查 chatroomId 是否为空 + if data["chatroomId"] == "": + return None + + return Group.from_dict(data)