style:idea默认格式化了部分代码

feat:添加根据消息事件获取群信息的接口
This commit is contained in:
Moyuyanli
2025-03-11 17:10:55 +08:00
parent 76cfc31a1d
commit 4a8309ed1f
+55 -20
View File
@@ -1,11 +1,9 @@
import abc
import asyncio
from dataclasses import dataclass
from .astrbot_message import AstrBotMessage
from .platform_metadata import PlatformMetadata
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.platform.message_type import MessageType
from typing import List, Union
from typing import List, Union, Optional
from astrbot.core.db.po import Conversation
from astrbot.core.message.components import (
Plain,
Image,
@@ -15,9 +13,12 @@ from astrbot.core.message.components import (
AtAll,
Forward,
)
from astrbot.core.utils.metrics import Metric
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.platform.message_type import MessageType
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.db.po import Conversation
from astrbot.core.utils.metrics import Metric
from .astrbot_message import AstrBotMessage, Group
from .platform_metadata import PlatformMetadata
@dataclass
@@ -37,11 +38,11 @@ class MessageSesion:
class AstrMessageEvent(abc.ABC):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
):
self.message_str = message_str
"""纯文本的消息"""
@@ -320,14 +321,14 @@ class AstrMessageEvent(abc.ABC):
"""LLM 请求相关"""
def request_llm(
self,
prompt: str,
func_tool_manager=None,
session_id: str = None,
image_urls: List[str] = [],
contexts: List = [],
system_prompt: str = "",
conversation: Conversation = None,
self,
prompt: str,
func_tool_manager=None,
session_id: str = None,
image_urls: List[str] = [],
contexts: List = [],
system_prompt: str = "",
conversation: Conversation = None,
) -> ProviderRequest:
"""
创建一个 LLM 请求。
@@ -363,3 +364,37 @@ class AstrMessageEvent(abc.ABC):
system_prompt=system_prompt,
conversation=conversation,
)
async def get_group(self, group_id: str = None) -> Optional[Group]:
"""
获取群聊,如果不填写group_id,且消息是私聊消息,则返回 None
目前只实现了 GeweChat 协议
"""
# 确定有效的 group_id
if group_id is None:
group_id = self.message_obj.group_id
if group_id is None:
return None
# 检查平台是否为 gewechat
if self.platform_meta.name != "gewechat":
return None
from astrbot.core.platform.sources.gewechat.gewechat_event import (
GewechatPlatformEvent,
)
assert isinstance(self, GewechatPlatformEvent)
client = self.client
# 从客户端获取群信息
res = await client.get_group(group_id)
data = res["data"]
# 检查 chatroomId 是否为空
if data["chatroomId"] == "":
return None
return Group.from_dict(data)