fix: 修复组件 type 属性为枚举值 (#2628)
当前components.py中每个组件的type属性都是直接字符串赋值,IDE会爆红。 修正为使用本就定义好的ComponentType枚举类 用时修正多个组件中当url为空时convert_to_base64检查路径导致的报错
This commit is contained in:
@@ -37,7 +37,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
|
||||
|
||||
|
||||
class ComponentType(Enum):
|
||||
class ComponentType(str, Enum):
|
||||
Plain = "Plain" # 纯文本消息
|
||||
Face = "Face" # QQ表情
|
||||
Record = "Record" # 语音
|
||||
@@ -108,7 +108,7 @@ class BaseMessageComponent(BaseModel):
|
||||
|
||||
|
||||
class Plain(BaseMessageComponent):
|
||||
type: ComponentType = "Plain"
|
||||
type = ComponentType.Plain
|
||||
text: str
|
||||
convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息
|
||||
|
||||
@@ -128,8 +128,9 @@ class Plain(BaseMessageComponent):
|
||||
async def to_dict(self):
|
||||
return {"type": "text", "data": {"text": self.text}}
|
||||
|
||||
|
||||
class Face(BaseMessageComponent):
|
||||
type: ComponentType = "Face"
|
||||
type = ComponentType.Face
|
||||
id: int
|
||||
|
||||
def __init__(self, **_):
|
||||
@@ -137,7 +138,7 @@ class Face(BaseMessageComponent):
|
||||
|
||||
|
||||
class Record(BaseMessageComponent):
|
||||
type: ComponentType = "Record"
|
||||
type = ComponentType.Record
|
||||
file: T.Optional[str] = ""
|
||||
magic: T.Optional[bool] = False
|
||||
url: T.Optional[str] = ""
|
||||
@@ -170,13 +171,14 @@ class Record(BaseMessageComponent):
|
||||
Returns:
|
||||
str: 语音的本地路径,以绝对路径表示。
|
||||
"""
|
||||
if self.file and self.file.startswith("file:///"):
|
||||
file_path = self.file[8:]
|
||||
return file_path
|
||||
elif self.file and self.file.startswith("http"):
|
||||
if not self.file:
|
||||
raise Exception(f"not a valid file: {self.file}")
|
||||
if self.file.startswith("file:///"):
|
||||
return self.file[8:]
|
||||
elif self.file.startswith("http"):
|
||||
file_path = await download_image_by_url(self.file)
|
||||
return os.path.abspath(file_path)
|
||||
elif self.file and self.file.startswith("base64://"):
|
||||
elif self.file.startswith("base64://"):
|
||||
bs64_data = self.file.removeprefix("base64://")
|
||||
image_bytes = base64.b64decode(bs64_data)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
@@ -185,8 +187,7 @@ class Record(BaseMessageComponent):
|
||||
f.write(image_bytes)
|
||||
return os.path.abspath(file_path)
|
||||
elif os.path.exists(self.file):
|
||||
file_path = self.file
|
||||
return os.path.abspath(file_path)
|
||||
return os.path.abspath(self.file)
|
||||
else:
|
||||
raise Exception(f"not a valid file: {self.file}")
|
||||
|
||||
@@ -197,12 +198,14 @@ class Record(BaseMessageComponent):
|
||||
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
||||
"""
|
||||
# convert to base64
|
||||
if self.file and self.file.startswith("file:///"):
|
||||
if not self.file:
|
||||
raise Exception(f"not a valid file: {self.file}")
|
||||
if self.file.startswith("file:///"):
|
||||
bs64_data = file_to_base64(self.file[8:])
|
||||
elif self.file and self.file.startswith("http"):
|
||||
elif self.file.startswith("http"):
|
||||
file_path = await download_image_by_url(self.file)
|
||||
bs64_data = file_to_base64(file_path)
|
||||
elif self.file and self.file.startswith("base64://"):
|
||||
elif self.file.startswith("base64://"):
|
||||
bs64_data = self.file
|
||||
elif os.path.exists(self.file):
|
||||
bs64_data = file_to_base64(self.file)
|
||||
@@ -236,7 +239,7 @@ class Record(BaseMessageComponent):
|
||||
|
||||
|
||||
class Video(BaseMessageComponent):
|
||||
type: ComponentType = "Video"
|
||||
type = ComponentType.Video
|
||||
file: str
|
||||
cover: T.Optional[str] = ""
|
||||
c: T.Optional[int] = 2
|
||||
@@ -322,7 +325,7 @@ class Video(BaseMessageComponent):
|
||||
|
||||
|
||||
class At(BaseMessageComponent):
|
||||
type: ComponentType = "At"
|
||||
type = ComponentType.At
|
||||
qq: T.Union[int, str] # 此处str为all时代表所有人
|
||||
name: T.Optional[str] = ""
|
||||
|
||||
@@ -344,28 +347,28 @@ class AtAll(At):
|
||||
|
||||
|
||||
class RPS(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "RPS"
|
||||
type = ComponentType.RPS
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Dice(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "Dice"
|
||||
type = ComponentType.Dice
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Shake(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "Shake"
|
||||
type = ComponentType.Shake
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Anonymous(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "Anonymous"
|
||||
type = ComponentType.Anonymous
|
||||
ignore: T.Optional[bool] = False
|
||||
|
||||
def __init__(self, **_):
|
||||
@@ -373,7 +376,7 @@ class Anonymous(BaseMessageComponent): # TODO
|
||||
|
||||
|
||||
class Share(BaseMessageComponent):
|
||||
type: ComponentType = "Share"
|
||||
type = ComponentType.Share
|
||||
url: str
|
||||
title: str
|
||||
content: T.Optional[str] = ""
|
||||
@@ -384,7 +387,7 @@ class Share(BaseMessageComponent):
|
||||
|
||||
|
||||
class Contact(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "Contact"
|
||||
type = ComponentType.Contact
|
||||
_type: str # type 字段冲突
|
||||
id: T.Optional[int] = 0
|
||||
|
||||
@@ -393,7 +396,7 @@ class Contact(BaseMessageComponent): # TODO
|
||||
|
||||
|
||||
class Location(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "Location"
|
||||
type = ComponentType.Location
|
||||
lat: float
|
||||
lon: float
|
||||
title: T.Optional[str] = ""
|
||||
@@ -404,7 +407,7 @@ class Location(BaseMessageComponent): # TODO
|
||||
|
||||
|
||||
class Music(BaseMessageComponent):
|
||||
type: ComponentType = "Music"
|
||||
type = ComponentType.Music
|
||||
_type: str
|
||||
id: T.Optional[int] = 0
|
||||
url: T.Optional[str] = ""
|
||||
@@ -421,7 +424,7 @@ class Music(BaseMessageComponent):
|
||||
|
||||
|
||||
class Image(BaseMessageComponent):
|
||||
type: ComponentType = "Image"
|
||||
type = ComponentType.Image
|
||||
file: T.Optional[str] = ""
|
||||
_type: T.Optional[str] = ""
|
||||
subType: T.Optional[int] = 0
|
||||
@@ -464,14 +467,15 @@ class Image(BaseMessageComponent):
|
||||
Returns:
|
||||
str: 图片的本地路径,以绝对路径表示。
|
||||
"""
|
||||
url = self.url if self.url else self.file
|
||||
if url and url.startswith("file:///"):
|
||||
image_file_path = url[8:]
|
||||
return image_file_path
|
||||
elif url and url.startswith("http"):
|
||||
url = self.url or self.file
|
||||
if not url:
|
||||
raise ValueError("No valid file or URL provided")
|
||||
if url.startswith("file:///"):
|
||||
return url[8:]
|
||||
elif url.startswith("http"):
|
||||
image_file_path = await download_image_by_url(url)
|
||||
return os.path.abspath(image_file_path)
|
||||
elif url and url.startswith("base64://"):
|
||||
elif url.startswith("base64://"):
|
||||
bs64_data = url.removeprefix("base64://")
|
||||
image_bytes = base64.b64decode(bs64_data)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
@@ -480,8 +484,7 @@ class Image(BaseMessageComponent):
|
||||
f.write(image_bytes)
|
||||
return os.path.abspath(image_file_path)
|
||||
elif os.path.exists(url):
|
||||
image_file_path = url
|
||||
return os.path.abspath(image_file_path)
|
||||
return os.path.abspath(url)
|
||||
else:
|
||||
raise Exception(f"not a valid file: {url}")
|
||||
|
||||
@@ -492,13 +495,15 @@ class Image(BaseMessageComponent):
|
||||
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
||||
"""
|
||||
# convert to base64
|
||||
url = self.url if self.url else self.file
|
||||
if url and url.startswith("file:///"):
|
||||
url = self.url or self.file
|
||||
if not url:
|
||||
raise ValueError("No valid file or URL provided")
|
||||
if url.startswith("file:///"):
|
||||
bs64_data = file_to_base64(url[8:])
|
||||
elif url and url.startswith("http"):
|
||||
elif url.startswith("http"):
|
||||
image_file_path = await download_image_by_url(url)
|
||||
bs64_data = file_to_base64(image_file_path)
|
||||
elif url and url.startswith("base64://"):
|
||||
elif url.startswith("base64://"):
|
||||
bs64_data = url
|
||||
elif os.path.exists(url):
|
||||
bs64_data = file_to_base64(url)
|
||||
@@ -532,7 +537,7 @@ class Image(BaseMessageComponent):
|
||||
|
||||
|
||||
class Reply(BaseMessageComponent):
|
||||
type: ComponentType = "Reply"
|
||||
type = ComponentType.Reply
|
||||
id: T.Union[str, int]
|
||||
"""所引用的消息 ID"""
|
||||
chain: T.Optional[T.List["BaseMessageComponent"]] = []
|
||||
@@ -558,7 +563,7 @@ class Reply(BaseMessageComponent):
|
||||
|
||||
|
||||
class RedBag(BaseMessageComponent):
|
||||
type: ComponentType = "RedBag"
|
||||
type = ComponentType.RedBag
|
||||
title: str
|
||||
|
||||
def __init__(self, **_):
|
||||
@@ -566,7 +571,7 @@ class RedBag(BaseMessageComponent):
|
||||
|
||||
|
||||
class Poke(BaseMessageComponent):
|
||||
type: str = ""
|
||||
type: str = ComponentType.Poke
|
||||
id: T.Optional[int] = 0
|
||||
qq: T.Optional[int] = 0
|
||||
|
||||
@@ -576,7 +581,7 @@ class Poke(BaseMessageComponent):
|
||||
|
||||
|
||||
class Forward(BaseMessageComponent):
|
||||
type: ComponentType = "Forward"
|
||||
type = ComponentType.Forward
|
||||
id: str
|
||||
|
||||
def __init__(self, **_):
|
||||
@@ -586,7 +591,7 @@ class Forward(BaseMessageComponent):
|
||||
class Node(BaseMessageComponent):
|
||||
"""群合并转发消息"""
|
||||
|
||||
type: ComponentType = "Node"
|
||||
type = ComponentType.Node
|
||||
id: T.Optional[int] = 0 # 忽略
|
||||
name: T.Optional[str] = "" # qq昵称
|
||||
uin: T.Optional[str] = "0" # qq号
|
||||
@@ -638,7 +643,7 @@ class Node(BaseMessageComponent):
|
||||
|
||||
|
||||
class Nodes(BaseMessageComponent):
|
||||
type: ComponentType = "Nodes"
|
||||
type = ComponentType.Nodes
|
||||
nodes: T.List[Node]
|
||||
|
||||
def __init__(self, nodes: T.List[Node], **_):
|
||||
@@ -664,7 +669,7 @@ class Nodes(BaseMessageComponent):
|
||||
|
||||
|
||||
class Xml(BaseMessageComponent):
|
||||
type: ComponentType = "Xml"
|
||||
type = ComponentType.Xml
|
||||
data: str
|
||||
resid: T.Optional[int] = 0
|
||||
|
||||
@@ -673,7 +678,7 @@ class Xml(BaseMessageComponent):
|
||||
|
||||
|
||||
class Json(BaseMessageComponent):
|
||||
type: ComponentType = "Json"
|
||||
type = ComponentType.Json
|
||||
data: T.Union[str, dict]
|
||||
resid: T.Optional[int] = 0
|
||||
|
||||
@@ -684,7 +689,7 @@ class Json(BaseMessageComponent):
|
||||
|
||||
|
||||
class CardImage(BaseMessageComponent):
|
||||
type: ComponentType = "CardImage"
|
||||
type = ComponentType.CardImage
|
||||
file: str
|
||||
cache: T.Optional[bool] = True
|
||||
minwidth: T.Optional[int] = 400
|
||||
@@ -703,7 +708,7 @@ class CardImage(BaseMessageComponent):
|
||||
|
||||
|
||||
class TTS(BaseMessageComponent):
|
||||
type: ComponentType = "TTS"
|
||||
type = ComponentType.TTS
|
||||
text: str
|
||||
|
||||
def __init__(self, **_):
|
||||
@@ -711,7 +716,7 @@ class TTS(BaseMessageComponent):
|
||||
|
||||
|
||||
class Unknown(BaseMessageComponent):
|
||||
type: ComponentType = "Unknown"
|
||||
type = ComponentType.Unknown
|
||||
text: str
|
||||
|
||||
def toString(self):
|
||||
@@ -723,7 +728,7 @@ class File(BaseMessageComponent):
|
||||
文件消息段
|
||||
"""
|
||||
|
||||
type: ComponentType = "File"
|
||||
type = ComponentType.File
|
||||
name: T.Optional[str] = "" # 名字
|
||||
file_: T.Optional[str] = "" # 本地路径
|
||||
url: T.Optional[str] = "" # url
|
||||
@@ -853,7 +858,7 @@ class File(BaseMessageComponent):
|
||||
|
||||
|
||||
class WechatEmoji(BaseMessageComponent):
|
||||
type: ComponentType = "WechatEmoji"
|
||||
type = ComponentType.WechatEmoji
|
||||
md5: T.Optional[str] = ""
|
||||
md5_len: T.Optional[int] = 0
|
||||
cdnurl: T.Optional[str] = ""
|
||||
|
||||
Reference in New Issue
Block a user