style:idea默认格式化了部分代码
feat:添加根据消息事件获取群信息的接口
This commit is contained in:
@@ -1,11 +1,9 @@
|
||||
import abc
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from .astrbot_message import AstrBotMessage
|
||||
from .platform_metadata import PlatformMetadata
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Optional
|
||||
|
||||
from astrbot.core.db.po import Conversation
|
||||
from astrbot.core.message.components import (
|
||||
Plain,
|
||||
Image,
|
||||
@@ -15,9 +13,12 @@ from astrbot.core.message.components import (
|
||||
AtAll,
|
||||
Forward,
|
||||
)
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
from astrbot.core.db.po import Conversation
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from .astrbot_message import AstrBotMessage, Group
|
||||
from .platform_metadata import PlatformMetadata
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -37,11 +38,11 @@ class MessageSesion:
|
||||
|
||||
class AstrMessageEvent(abc.ABC):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
):
|
||||
self.message_str = message_str
|
||||
"""纯文本的消息"""
|
||||
@@ -320,14 +321,14 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""LLM 请求相关"""
|
||||
|
||||
def request_llm(
|
||||
self,
|
||||
prompt: str,
|
||||
func_tool_manager=None,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
contexts: List = [],
|
||||
system_prompt: str = "",
|
||||
conversation: Conversation = None,
|
||||
self,
|
||||
prompt: str,
|
||||
func_tool_manager=None,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
contexts: List = [],
|
||||
system_prompt: str = "",
|
||||
conversation: Conversation = None,
|
||||
) -> ProviderRequest:
|
||||
"""
|
||||
创建一个 LLM 请求。
|
||||
@@ -363,3 +364,37 @@ class AstrMessageEvent(abc.ABC):
|
||||
system_prompt=system_prompt,
|
||||
conversation=conversation,
|
||||
)
|
||||
|
||||
async def get_group(self, group_id: str = None) -> Optional[Group]:
|
||||
"""
|
||||
获取群聊,如果不填写group_id,且消息是私聊消息,则返回 None
|
||||
目前只实现了 GeweChat 协议
|
||||
"""
|
||||
# 确定有效的 group_id
|
||||
if group_id is None:
|
||||
group_id = self.message_obj.group_id
|
||||
|
||||
if group_id is None:
|
||||
return None
|
||||
|
||||
# 检查平台是否为 gewechat
|
||||
if self.platform_meta.name != "gewechat":
|
||||
return None
|
||||
|
||||
from astrbot.core.platform.sources.gewechat.gewechat_event import (
|
||||
GewechatPlatformEvent,
|
||||
)
|
||||
|
||||
assert isinstance(self, GewechatPlatformEvent)
|
||||
client = self.client
|
||||
|
||||
# 从客户端获取群信息
|
||||
res = await client.get_group(group_id)
|
||||
|
||||
data = res["data"]
|
||||
|
||||
# 检查 chatroomId 是否为空
|
||||
if data["chatroomId"] == "":
|
||||
return None
|
||||
|
||||
return Group.from_dict(data)
|
||||
|
||||
Reference in New Issue
Block a user