From c1756e5767071c828421925a95c2d6c103201bbf Mon Sep 17 00:00:00 2001 From: Zhalslar Date: Sat, 6 Sep 2025 19:22:49 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E7=BB=84=E4=BB=B6=20t?= =?UTF-8?q?ype=20=E5=B1=9E=E6=80=A7=E4=B8=BA=E6=9E=9A=E4=B8=BE=E5=80=BC=20?= =?UTF-8?q?(#2628)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 当前components.py中每个组件的type属性都是直接字符串赋值,IDE会爆红。 修正为使用本就定义好的ComponentType枚举类 用时修正多个组件中当url为空时convert_to_base64检查路径导致的报错 --- astrbot/core/message/components.py | 105 +++++++++++++++-------------- 1 file changed, 55 insertions(+), 50 deletions(-) diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index beb6e40bb..f02d492d0 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -37,7 +37,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64 -class ComponentType(Enum): +class ComponentType(str, Enum): Plain = "Plain" # 纯文本消息 Face = "Face" # QQ表情 Record = "Record" # 语音 @@ -108,7 +108,7 @@ class BaseMessageComponent(BaseModel): class Plain(BaseMessageComponent): - type: ComponentType = "Plain" + type = ComponentType.Plain text: str convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息 @@ -128,8 +128,9 @@ class Plain(BaseMessageComponent): async def to_dict(self): return {"type": "text", "data": {"text": self.text}} + class Face(BaseMessageComponent): - type: ComponentType = "Face" + type = ComponentType.Face id: int def __init__(self, **_): @@ -137,7 +138,7 @@ class Face(BaseMessageComponent): class Record(BaseMessageComponent): - type: ComponentType = "Record" + type = ComponentType.Record file: T.Optional[str] = "" magic: T.Optional[bool] = False url: T.Optional[str] = "" @@ -170,13 +171,14 @@ class Record(BaseMessageComponent): Returns: str: 语音的本地路径,以绝对路径表示。 """ - if self.file and self.file.startswith("file:///"): - file_path = self.file[8:] - return file_path - elif self.file and self.file.startswith("http"): + if not self.file: + raise Exception(f"not a valid file: {self.file}") + if self.file.startswith("file:///"): + return self.file[8:] + elif self.file.startswith("http"): file_path = await download_image_by_url(self.file) return os.path.abspath(file_path) - elif self.file and self.file.startswith("base64://"): + elif self.file.startswith("base64://"): bs64_data = self.file.removeprefix("base64://") image_bytes = base64.b64decode(bs64_data) temp_dir = os.path.join(get_astrbot_data_path(), "temp") @@ -185,8 +187,7 @@ class Record(BaseMessageComponent): f.write(image_bytes) return os.path.abspath(file_path) elif os.path.exists(self.file): - file_path = self.file - return os.path.abspath(file_path) + return os.path.abspath(self.file) else: raise Exception(f"not a valid file: {self.file}") @@ -197,12 +198,14 @@ class Record(BaseMessageComponent): str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。 """ # convert to base64 - if self.file and self.file.startswith("file:///"): + if not self.file: + raise Exception(f"not a valid file: {self.file}") + if self.file.startswith("file:///"): bs64_data = file_to_base64(self.file[8:]) - elif self.file and self.file.startswith("http"): + elif self.file.startswith("http"): file_path = await download_image_by_url(self.file) bs64_data = file_to_base64(file_path) - elif self.file and self.file.startswith("base64://"): + elif self.file.startswith("base64://"): bs64_data = self.file elif os.path.exists(self.file): bs64_data = file_to_base64(self.file) @@ -236,7 +239,7 @@ class Record(BaseMessageComponent): class Video(BaseMessageComponent): - type: ComponentType = "Video" + type = ComponentType.Video file: str cover: T.Optional[str] = "" c: T.Optional[int] = 2 @@ -322,7 +325,7 @@ class Video(BaseMessageComponent): class At(BaseMessageComponent): - type: ComponentType = "At" + type = ComponentType.At qq: T.Union[int, str] # 此处str为all时代表所有人 name: T.Optional[str] = "" @@ -344,28 +347,28 @@ class AtAll(At): class RPS(BaseMessageComponent): # TODO - type: ComponentType = "RPS" + type = ComponentType.RPS def __init__(self, **_): super().__init__(**_) class Dice(BaseMessageComponent): # TODO - type: ComponentType = "Dice" + type = ComponentType.Dice def __init__(self, **_): super().__init__(**_) class Shake(BaseMessageComponent): # TODO - type: ComponentType = "Shake" + type = ComponentType.Shake def __init__(self, **_): super().__init__(**_) class Anonymous(BaseMessageComponent): # TODO - type: ComponentType = "Anonymous" + type = ComponentType.Anonymous ignore: T.Optional[bool] = False def __init__(self, **_): @@ -373,7 +376,7 @@ class Anonymous(BaseMessageComponent): # TODO class Share(BaseMessageComponent): - type: ComponentType = "Share" + type = ComponentType.Share url: str title: str content: T.Optional[str] = "" @@ -384,7 +387,7 @@ class Share(BaseMessageComponent): class Contact(BaseMessageComponent): # TODO - type: ComponentType = "Contact" + type = ComponentType.Contact _type: str # type 字段冲突 id: T.Optional[int] = 0 @@ -393,7 +396,7 @@ class Contact(BaseMessageComponent): # TODO class Location(BaseMessageComponent): # TODO - type: ComponentType = "Location" + type = ComponentType.Location lat: float lon: float title: T.Optional[str] = "" @@ -404,7 +407,7 @@ class Location(BaseMessageComponent): # TODO class Music(BaseMessageComponent): - type: ComponentType = "Music" + type = ComponentType.Music _type: str id: T.Optional[int] = 0 url: T.Optional[str] = "" @@ -421,7 +424,7 @@ class Music(BaseMessageComponent): class Image(BaseMessageComponent): - type: ComponentType = "Image" + type = ComponentType.Image file: T.Optional[str] = "" _type: T.Optional[str] = "" subType: T.Optional[int] = 0 @@ -464,14 +467,15 @@ class Image(BaseMessageComponent): Returns: str: 图片的本地路径,以绝对路径表示。 """ - url = self.url if self.url else self.file - if url and url.startswith("file:///"): - image_file_path = url[8:] - return image_file_path - elif url and url.startswith("http"): + url = self.url or self.file + if not url: + raise ValueError("No valid file or URL provided") + if url.startswith("file:///"): + return url[8:] + elif url.startswith("http"): image_file_path = await download_image_by_url(url) return os.path.abspath(image_file_path) - elif url and url.startswith("base64://"): + elif url.startswith("base64://"): bs64_data = url.removeprefix("base64://") image_bytes = base64.b64decode(bs64_data) temp_dir = os.path.join(get_astrbot_data_path(), "temp") @@ -480,8 +484,7 @@ class Image(BaseMessageComponent): f.write(image_bytes) return os.path.abspath(image_file_path) elif os.path.exists(url): - image_file_path = url - return os.path.abspath(image_file_path) + return os.path.abspath(url) else: raise Exception(f"not a valid file: {url}") @@ -492,13 +495,15 @@ class Image(BaseMessageComponent): str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。 """ # convert to base64 - url = self.url if self.url else self.file - if url and url.startswith("file:///"): + url = self.url or self.file + if not url: + raise ValueError("No valid file or URL provided") + if url.startswith("file:///"): bs64_data = file_to_base64(url[8:]) - elif url and url.startswith("http"): + elif url.startswith("http"): image_file_path = await download_image_by_url(url) bs64_data = file_to_base64(image_file_path) - elif url and url.startswith("base64://"): + elif url.startswith("base64://"): bs64_data = url elif os.path.exists(url): bs64_data = file_to_base64(url) @@ -532,7 +537,7 @@ class Image(BaseMessageComponent): class Reply(BaseMessageComponent): - type: ComponentType = "Reply" + type = ComponentType.Reply id: T.Union[str, int] """所引用的消息 ID""" chain: T.Optional[T.List["BaseMessageComponent"]] = [] @@ -558,7 +563,7 @@ class Reply(BaseMessageComponent): class RedBag(BaseMessageComponent): - type: ComponentType = "RedBag" + type = ComponentType.RedBag title: str def __init__(self, **_): @@ -566,7 +571,7 @@ class RedBag(BaseMessageComponent): class Poke(BaseMessageComponent): - type: str = "" + type: str = ComponentType.Poke id: T.Optional[int] = 0 qq: T.Optional[int] = 0 @@ -576,7 +581,7 @@ class Poke(BaseMessageComponent): class Forward(BaseMessageComponent): - type: ComponentType = "Forward" + type = ComponentType.Forward id: str def __init__(self, **_): @@ -586,7 +591,7 @@ class Forward(BaseMessageComponent): class Node(BaseMessageComponent): """群合并转发消息""" - type: ComponentType = "Node" + type = ComponentType.Node id: T.Optional[int] = 0 # 忽略 name: T.Optional[str] = "" # qq昵称 uin: T.Optional[str] = "0" # qq号 @@ -638,7 +643,7 @@ class Node(BaseMessageComponent): class Nodes(BaseMessageComponent): - type: ComponentType = "Nodes" + type = ComponentType.Nodes nodes: T.List[Node] def __init__(self, nodes: T.List[Node], **_): @@ -664,7 +669,7 @@ class Nodes(BaseMessageComponent): class Xml(BaseMessageComponent): - type: ComponentType = "Xml" + type = ComponentType.Xml data: str resid: T.Optional[int] = 0 @@ -673,7 +678,7 @@ class Xml(BaseMessageComponent): class Json(BaseMessageComponent): - type: ComponentType = "Json" + type = ComponentType.Json data: T.Union[str, dict] resid: T.Optional[int] = 0 @@ -684,7 +689,7 @@ class Json(BaseMessageComponent): class CardImage(BaseMessageComponent): - type: ComponentType = "CardImage" + type = ComponentType.CardImage file: str cache: T.Optional[bool] = True minwidth: T.Optional[int] = 400 @@ -703,7 +708,7 @@ class CardImage(BaseMessageComponent): class TTS(BaseMessageComponent): - type: ComponentType = "TTS" + type = ComponentType.TTS text: str def __init__(self, **_): @@ -711,7 +716,7 @@ class TTS(BaseMessageComponent): class Unknown(BaseMessageComponent): - type: ComponentType = "Unknown" + type = ComponentType.Unknown text: str def toString(self): @@ -723,7 +728,7 @@ class File(BaseMessageComponent): 文件消息段 """ - type: ComponentType = "File" + type = ComponentType.File name: T.Optional[str] = "" # 名字 file_: T.Optional[str] = "" # 本地路径 url: T.Optional[str] = "" # url @@ -853,7 +858,7 @@ class File(BaseMessageComponent): class WechatEmoji(BaseMessageComponent): - type: ComponentType = "WechatEmoji" + type = ComponentType.WechatEmoji md5: T.Optional[str] = "" md5_len: T.Optional[int] = 0 cdnurl: T.Optional[str] = ""