refactor: 修改框架路径获取方式,规范化路径拼接
This commit is contained in:
@@ -7,9 +7,10 @@ from astrbot.core.utils.pip_installer import PipInstaller
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.config.default import DB_PATH
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from .utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
# 初始化数据存储文件夹
|
||||
os.makedirs("data", exist_ok=True)
|
||||
os.makedirs(get_astrbot_data_path(), exist_ok=True)
|
||||
|
||||
astrbot_config = AstrBotConfig()
|
||||
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
|
||||
|
||||
@@ -4,8 +4,9 @@ import logging
|
||||
import enum
|
||||
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
|
||||
from typing import Dict
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
|
||||
ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
|
||||
@@ -2,8 +2,11 @@
|
||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
"""
|
||||
|
||||
import os
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "3.5.6"
|
||||
DB_PATH = "data/data_v3.db"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db")
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_CONFIG = {
|
||||
|
||||
@@ -3,8 +3,9 @@ import aiosqlite
|
||||
import os
|
||||
from typing import Any
|
||||
from .plugin_storage import PluginStorage
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
DBPATH = "data/plugin_data/sqlite/plugin_data.db"
|
||||
DBPATH = os.path.join(get_astrbot_data_path(), "plugin_data", "sqlite", "plugin_data.db")
|
||||
|
||||
|
||||
class SQLitePluginStorage(PluginStorage):
|
||||
|
||||
@@ -32,6 +32,7 @@ from enum import Enum
|
||||
from pydantic.v1 import BaseModel
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.io import download_image_by_url, file_to_base64, download_file
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class ComponentType(Enum):
|
||||
@@ -167,7 +168,8 @@ class Record(BaseMessageComponent):
|
||||
elif self.file and self.file.startswith("base64://"):
|
||||
bs64_data = self.file.removeprefix("base64://")
|
||||
image_bytes = base64.b64decode(bs64_data)
|
||||
file_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
return os.path.abspath(file_path)
|
||||
@@ -371,7 +373,9 @@ class Image(BaseMessageComponent):
|
||||
elif url and url.startswith("base64://"):
|
||||
bs64_data = url.removeprefix("base64://")
|
||||
image_bytes = base64.b64decode(bs64_data)
|
||||
image_file_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
image_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
|
||||
with open(image_file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
return os.path.abspath(image_file_path)
|
||||
@@ -631,9 +635,9 @@ class File(BaseMessageComponent):
|
||||
if self._downloaded:
|
||||
return
|
||||
|
||||
os.makedirs("data/download", exist_ok=True)
|
||||
filename = self.name or f"{uuid.uuid4().hex}"
|
||||
file_path = f"data/download/{filename}"
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "download")
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
|
||||
await download_file(self.url, file_path)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
import aiohttp
|
||||
import dingtalk_stream
|
||||
@@ -19,6 +20,7 @@ from ...register import register_platform_adapter
|
||||
from astrbot import logger
|
||||
from dingtalk_stream import AckMessage
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class MyEventHandler(dingtalk_stream.EventHandler):
|
||||
@@ -152,7 +154,8 @@ class DingtalkPlatformAdapter(Platform):
|
||||
"downloadCode": download_code,
|
||||
"robotCode": robot_code,
|
||||
}
|
||||
f_path = f"data/dingtalk_file_{uuid.uuid4()}.{ext}"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
f_path = os.path.join(temp_dir, f"dingtalk_file_{uuid.uuid4()}.{ext}")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
"https://api.dingtalk.com/v1.0/robot/messageFiles/download",
|
||||
|
||||
@@ -15,6 +15,7 @@ from astrbot.api.message_components import Plain, Image, At, Record, Video
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from .downloader import GeweDownloader
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
try:
|
||||
from .xml_data_parser import GeweDataParser
|
||||
@@ -250,7 +251,10 @@ class SimpleGewechatClient:
|
||||
# 语音消息
|
||||
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
|
||||
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
|
||||
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(
|
||||
temp_dir, f"gewe_voice_{abm.message_id}.silk"
|
||||
)
|
||||
|
||||
async with await anyio.open_file(file_path, "wb") as f:
|
||||
await f.write(voice_data)
|
||||
@@ -458,8 +462,10 @@ class SimpleGewechatClient:
|
||||
retry_cnt -= 1
|
||||
|
||||
# 需要验证码
|
||||
if os.path.exists("data/temp/gewe_code"):
|
||||
with open("data/temp/gewe_code", "r") as f:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
code_file_path = os.path.join(temp_dir, "gewe_code")
|
||||
if os.path.exists(code_file_path):
|
||||
with open(code_file_path, "r") as f:
|
||||
code = f.read().strip()
|
||||
if not code:
|
||||
logger.warning(
|
||||
@@ -470,9 +476,9 @@ class SimpleGewechatClient:
|
||||
payload["captchCode"] = code
|
||||
logger.info(f"使用验证码: {code}")
|
||||
try:
|
||||
os.remove("data/temp/gewe_code")
|
||||
os.remove(code_file_path)
|
||||
except Exception:
|
||||
logger.warning("删除验证码文件 data/temp/gewe_code 失败。")
|
||||
logger.warning(f"删除验证码文件 {code_file_path} 失败。")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
|
||||
@@ -6,7 +6,7 @@ import traceback
|
||||
import os
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from astrbot.core.utils.io import save_temp_img, download_file
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -21,6 +21,7 @@ from astrbot.api.message_components import (
|
||||
WechatEmoji as Emoji,
|
||||
)
|
||||
from .client import SimpleGewechatClient
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
def get_wav_duration(file_path):
|
||||
@@ -106,7 +107,8 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
# 根据 url 下载视频
|
||||
if video_url.startswith("http"):
|
||||
video_filename = f"{uuid.uuid4()}.mp4"
|
||||
video_path = f"data/temp/{video_filename}"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
video_path = os.path.join(temp_dir, video_filename)
|
||||
await download_file(video_url, video_path)
|
||||
else:
|
||||
video_path = video_url
|
||||
@@ -115,7 +117,10 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
video_callback_url = f"{client.file_server_url}/{video_token}"
|
||||
|
||||
# 获取视频第一帧
|
||||
thumb_path = f"data/temp/gewechat_video_thumb_{uuid.uuid4()}.jpg"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
thumb_path = os.path.join(
|
||||
temp_dir, f"gewechat_video_thumb_{uuid.uuid4()}.jpg"
|
||||
)
|
||||
|
||||
video_path = video_path.replace(" ", "\\ ")
|
||||
try:
|
||||
@@ -154,7 +159,8 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
record_url = comp.file
|
||||
record_path = await comp.convert_to_file_path()
|
||||
|
||||
silk_path = f"data/temp/{uuid.uuid4()}.silk"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
|
||||
try:
|
||||
duration = await wav_to_tencent_silk(record_path, silk_path)
|
||||
except Exception as e:
|
||||
@@ -173,7 +179,10 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
if file_path.startswith("file:///"):
|
||||
file_path = file_path[8:]
|
||||
elif file_path.startswith("http"):
|
||||
await download_file(file_path, f"data/temp/{file_name}")
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
temp_file_path = os.path.join(temp_dir, file_name)
|
||||
await download_file(file_path, temp_file_path)
|
||||
file_path = temp_file_path
|
||||
else:
|
||||
file_path = file_path
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import base64
|
||||
import lark_oapi as lark
|
||||
@@ -9,6 +10,7 @@ from astrbot.api.message_components import Plain, Image as AstrBotImage, At
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from lark_oapi.api.im.v1 import *
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class LarkMessageEvent(AstrMessageEvent):
|
||||
@@ -40,7 +42,8 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
base64_str = comp.file.removeprefix("base64://")
|
||||
image_data = base64.b64decode(base64_str)
|
||||
# save as temp file
|
||||
file_path = f"data/temp/{uuid.uuid4()}_test.jpg"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}_test.jpg")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(BytesIO(image_data).getvalue())
|
||||
else:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import asyncio
|
||||
import telegramify_markdown
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -13,6 +14,7 @@ from astrbot.api.message_components import (
|
||||
from telegram.ext import ExtBot
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class TelegramPlatformEvent(AstrMessageEvent):
|
||||
@@ -75,7 +77,8 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
await client.send_photo(photo=image_path, **payload)
|
||||
elif isinstance(i, File):
|
||||
if i.file.startswith("https://"):
|
||||
path = "data/temp/" + i.name
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, i.name)
|
||||
await download_file(i.file, path)
|
||||
i.file = path
|
||||
|
||||
@@ -126,7 +129,8 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
continue
|
||||
elif isinstance(i, File):
|
||||
if i.file.startswith("https://"):
|
||||
path = "data/temp/" + i.name
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, i.name)
|
||||
await download_file(i.file, path)
|
||||
i.file = path
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from astrbot.core import web_chat_queue
|
||||
from .webchat_event import WebChatMessageEvent
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class QueueListener:
|
||||
@@ -40,7 +41,8 @@ class WebChatAdapter(Platform):
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.imgs_dir = "data/webchat/imgs"
|
||||
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="webchat", description="webchat", id=self.config.get("id")
|
||||
|
||||
@@ -6,8 +6,9 @@ from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core import web_chat_back_queue
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
imgs_dir = "data/webchat/imgs"
|
||||
imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
|
||||
|
||||
class WebChatMessageEvent(AstrMessageEvent):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import asyncio
|
||||
@@ -23,6 +24,7 @@ from wechatpy.enterprise.messages import TextMessage, ImageMessage, VoiceMessage
|
||||
from wechatpy.exceptions import InvalidSignatureException
|
||||
from wechatpy.enterprise import parse_message
|
||||
from .wecom_event import WecomPlatformEvent
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
@@ -191,14 +193,15 @@ class WecomPlatformAdapter(Platform):
|
||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.client.media.download, msg.media_id
|
||||
)
|
||||
path = f"data/temp/wecom_{msg.media_id}.amr"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"wecom_{msg.media_id}.amr")
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
|
||||
try:
|
||||
from pydub import AudioSegment
|
||||
|
||||
path_wav = f"data/temp/wecom_{msg.media_id}.wav"
|
||||
path_wav = os.path.join(temp_dir, f"wecom_{msg.media_id}.wav")
|
||||
audio = AudioSegment.from_file(path)
|
||||
audio.export(path_wav, format="wav")
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import uuid
|
||||
import asyncio
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -6,6 +7,7 @@ from astrbot.api.message_components import Plain, Image, Record
|
||||
from wechatpy.enterprise import WeChatClient
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
try:
|
||||
import pydub
|
||||
@@ -52,19 +54,29 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
if start + 2048 >= len(plain):
|
||||
result.append(plain[start:])
|
||||
break
|
||||
|
||||
|
||||
# 向前搜索分割标点符号
|
||||
end = min(start + 2048, len(plain))
|
||||
cut_position = end
|
||||
for i in range(end, start, -1):
|
||||
if i < len(plain) and plain[i-1] in ["。", "!", "?", ".", "!", "?", "\n", ";", ";"]:
|
||||
if i < len(plain) and plain[i - 1] in [
|
||||
"。",
|
||||
"!",
|
||||
"?",
|
||||
".",
|
||||
"!",
|
||||
"?",
|
||||
"\n",
|
||||
";",
|
||||
";",
|
||||
]:
|
||||
cut_position = i
|
||||
break
|
||||
|
||||
|
||||
# 没找到合适的位置分割, 直接切分
|
||||
if cut_position == end and end < len(plain):
|
||||
cut_position = end
|
||||
|
||||
|
||||
result.append(plain[start:cut_position])
|
||||
start = cut_position
|
||||
|
||||
@@ -103,7 +115,8 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
elif isinstance(comp, Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
# 转成amr
|
||||
record_path_amr = f"data/temp/{uuid.uuid4()}.amr"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
record_path_amr = os.path.join(temp_dir, f"{uuid.uuid4()}.amr")
|
||||
pydub.AudioSegment.from_wav(record_path).export(
|
||||
record_path_amr, format="amr"
|
||||
)
|
||||
|
||||
@@ -12,6 +12,8 @@ from contextlib import AsyncExitStack
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.log_pipe import LogPipe
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
try:
|
||||
import mcp
|
||||
from mcp.client.sse import sse_client
|
||||
@@ -238,8 +240,7 @@ class FuncCall:
|
||||
}
|
||||
```
|
||||
"""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
data_dir = os.path.abspath(os.path.join(current_dir, "../../../data"))
|
||||
data_dir = get_astrbot_data_path()
|
||||
|
||||
mcp_json_file = os.path.join(data_dir, "mcp_server.json")
|
||||
if not os.path.exists(mcp_json_file):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import dashscope
|
||||
import uuid
|
||||
import asyncio
|
||||
@@ -5,6 +6,7 @@ from dashscope.audio.tts_v2 import *
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -24,7 +26,8 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
dashscope.api_key = self.chosen_api_key
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/dashscope_tts_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}.wav")
|
||||
self.synthesizer = SpeechSynthesizer(
|
||||
model=self.get_model(),
|
||||
voice=self.voice,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import astrbot.core.message.components as Comp
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
from .. import Provider, Personality
|
||||
from ..entities import LLMResponse
|
||||
@@ -10,6 +10,7 @@ from astrbot.core.utils.dify_api_client import DifyAPIClient
|
||||
from astrbot.core.utils.io import download_image_by_url, download_file
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter("dify", "Dify APP 适配器。")
|
||||
@@ -227,7 +228,8 @@ class ProviderDify(Provider):
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "audio":
|
||||
# 仅支持 wav
|
||||
path = f"data/temp/{item['filename']}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"{item['filename']}.wav")
|
||||
await download_file(item["url"], path)
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "video":
|
||||
|
||||
@@ -7,6 +7,7 @@ from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
"""
|
||||
edge_tts 方式,能够免费、快速生成语音,使用需要先安装edge-tts库
|
||||
@@ -40,9 +41,9 @@ class ProviderEdgeTTS(TTSProvider):
|
||||
self.set_model("edge_tts")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
mp3_path = f"data/temp/edge_tts_temp_{uuid.uuid4()}.mp3"
|
||||
wav_path = f"data/temp/edge_tts_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3")
|
||||
wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav")
|
||||
|
||||
# 构建 Edge TTS 参数
|
||||
kwargs = {"text": text, "voice": self.voice}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import uuid
|
||||
import ormsgpack
|
||||
from pydantic import BaseModel, conint
|
||||
@@ -6,6 +7,7 @@ from typing import Annotated, Literal
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class ServeReferenceAudio(BaseModel):
|
||||
@@ -87,7 +89,8 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
)
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/fishaudio_tts_api_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav")
|
||||
self.headers["content-type"] = "application/msgpack"
|
||||
request = await self._generate_request(text)
|
||||
async with AsyncClient(base_url=self.api_base).stream(
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
import uuid
|
||||
import aiohttp
|
||||
import urllib.parse
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -23,7 +25,8 @@ class ProviderGSVITTS(TTSProvider):
|
||||
self.emotion = provider_config.get("emotion")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/gsvi_tts_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav")
|
||||
params = {"text": text}
|
||||
|
||||
if self.character:
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import uuid
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -31,7 +33,8 @@ class ProviderOpenAITTSAPI(TTSProvider):
|
||||
self.set_model(provider_config.get("model", None))
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/openai_tts_api_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}.wav")
|
||||
async with self.client.audio.speech.with_streaming_response.create(
|
||||
model=self.model_name, voice=self.voice, response_format="wav", input=text
|
||||
) as response:
|
||||
|
||||
@@ -7,6 +7,7 @@ from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -50,7 +51,8 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
is_tencent = True
|
||||
|
||||
name = str(uuid.uuid4())
|
||||
path = os.path.join("data/temp", name)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, name)
|
||||
await download_file(audio_url, path)
|
||||
audio_url = path
|
||||
|
||||
@@ -61,7 +63,8 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
output_path = os.path.join("data/temp", str(uuid.uuid4()) + ".wav")
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
audio_url = output_path
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -53,7 +54,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
is_tencent = True
|
||||
|
||||
name = str(uuid.uuid4())
|
||||
path = os.path.join("data/temp", name)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, name)
|
||||
await download_file(audio_url, path)
|
||||
audio_url = path
|
||||
|
||||
@@ -64,7 +66,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
output_path = os.path.join("data/temp", str(uuid.uuid4()) + ".wav")
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
audio_url = output_path
|
||||
|
||||
|
||||
@@ -3,11 +3,12 @@ from typing import List, Dict
|
||||
from astrbot.core import logger
|
||||
from .store import Store
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class KnowledgeDBManager:
|
||||
def __init__(self, astrbot_config: AstrBotConfig) -> None:
|
||||
self.db_path = "data/knowledge_db/"
|
||||
self.db_path = os.path.join(get_astrbot_data_path(), "knowledge_db")
|
||||
self.config = astrbot_config.get("knowledge_db", {})
|
||||
self.astrbot_config = astrbot_config
|
||||
if not os.path.exists(self.db_path):
|
||||
|
||||
@@ -4,12 +4,14 @@ from typing import List, Dict
|
||||
from astrbot.api import logger
|
||||
from ..embedding.openai_source import SimpleOpenAIEmbedding
|
||||
from . import Store
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class ChromaVectorStore(Store):
|
||||
def __init__(self, name: str, embedding_cfg: Dict) -> None:
|
||||
import os
|
||||
self.chroma_client = chromadb.PersistentClient(
|
||||
path="data/long_term_memory_chroma.db"
|
||||
path=os.path.join(get_astrbot_data_path(), "long_term_memory_chroma.db")
|
||||
)
|
||||
self.collection = self.chroma_client.get_or_create_collection(name=name)
|
||||
self.embedding = None
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from typing import Union
|
||||
import os
|
||||
import json
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
def load_config(namespace: str) -> Union[dict, bool]:
|
||||
@@ -13,7 +14,7 @@ def load_config(namespace: str) -> Union[dict, bool]:
|
||||
namespace: str, 配置的唯一识别符,也就是配置文件的名字。
|
||||
返回值: 当配置文件存在时,返回 namespace 对应配置文件的内容dict,否则返回 False。
|
||||
"""
|
||||
path = f"data/config/{namespace}.json"
|
||||
path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json")
|
||||
if not os.path.exists(path):
|
||||
return False
|
||||
with open(path, "r", encoding="utf-8-sig") as f:
|
||||
@@ -43,7 +44,10 @@ def put_config(namespace: str, name: str, key: str, value, description: str):
|
||||
raise ValueError("key 只支持 str 类型。")
|
||||
if not isinstance(value, (str, int, float, bool, list)):
|
||||
raise ValueError("value 只支持 str, int, float, bool, list 类型。")
|
||||
path = f"data/config/{namespace}.json"
|
||||
|
||||
config_dir = os.path.join(get_astrbot_data_path(), "config")
|
||||
path = os.path.join(config_dir, f"{namespace}.json")
|
||||
|
||||
if not os.path.exists(path):
|
||||
with open(path, "w", encoding="utf-8-sig") as f:
|
||||
f.write("{}")
|
||||
@@ -71,7 +75,7 @@ def update_config(namespace: str, key: str, value):
|
||||
key: str, 配置项的键。
|
||||
value: str, int, float, bool, list, 配置项的值。
|
||||
"""
|
||||
path = f"data/config/{namespace}.json"
|
||||
path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json")
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError(f"配置文件 {namespace}.json 不存在。")
|
||||
with open(path, "r", encoding="utf-8-sig") as f:
|
||||
|
||||
@@ -22,6 +22,10 @@ from astrbot.core.utils.io import remove_dir
|
||||
from .star import star_registry, star_map
|
||||
from .star_handler import star_handlers_registry
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_plugin_path,
|
||||
get_astrbot_config_path,
|
||||
)
|
||||
|
||||
from .filter.permission import PermissionTypeFilter, PermissionType
|
||||
|
||||
@@ -34,17 +38,9 @@ class PluginManager:
|
||||
self.context._star_manager = self
|
||||
|
||||
self.config = config
|
||||
self.plugin_store_path = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"
|
||||
)
|
||||
)
|
||||
self.plugin_store_path = get_astrbot_plugin_path()
|
||||
"""存储插件的路径。即 data/plugins"""
|
||||
self.plugin_config_path = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "../../../data/config"
|
||||
)
|
||||
)
|
||||
self.plugin_config_path = get_astrbot_config_path()
|
||||
"""存储插件配置的路径。data/config"""
|
||||
self.reserved_plugin_path = os.path.abspath(
|
||||
os.path.join(
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union, Awaitable, List, Optional, ClassVar
|
||||
from astrbot.core.message.components import BaseMessageComponent
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
@@ -6,7 +8,7 @@ from astrbot.api.platform import MessageMember, AstrBotMessage
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.star.star import star_map
|
||||
from pathlib import Path
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class StarTools:
|
||||
@@ -180,7 +182,7 @@ class StarTools:
|
||||
|
||||
plugin_name = metadata.name
|
||||
|
||||
data_dir = Path("data/plugin_data") / plugin_name
|
||||
data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name))
|
||||
|
||||
try:
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -6,16 +6,13 @@ from ..updator import RepoZipUpdator
|
||||
from astrbot.core.utils.io import remove_dir, on_error
|
||||
from ..star.star import StarMetadata
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_plugin_path
|
||||
|
||||
|
||||
class PluginUpdator(RepoZipUpdator):
|
||||
def __init__(self, repo_mirror: str = "") -> None:
|
||||
super().__init__(repo_mirror)
|
||||
self.plugin_store_path = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"
|
||||
)
|
||||
)
|
||||
self.plugin_store_path = get_astrbot_plugin_path()
|
||||
|
||||
def get_plugin_store_path(self) -> str:
|
||||
return self.plugin_store_path
|
||||
|
||||
+12
-4
@@ -6,6 +6,7 @@ from .zip_updator import ReleaseInfo, RepoZipUpdator
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
|
||||
|
||||
class AstrBotUpdator(RepoZipUpdator):
|
||||
@@ -16,9 +17,7 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
|
||||
def __init__(self, repo_mirror: str = "") -> None:
|
||||
super().__init__(repo_mirror)
|
||||
self.MAIN_PATH = os.path.abspath(
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")
|
||||
)
|
||||
self.MAIN_PATH = get_astrbot_path()
|
||||
self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases"
|
||||
|
||||
def terminate_child_processes(self):
|
||||
@@ -51,7 +50,13 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
self.terminate_child_processes()
|
||||
py = py.replace(" ", "\\ ")
|
||||
try:
|
||||
os.execl(py, py, *sys.argv)
|
||||
if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli
|
||||
args = [
|
||||
f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]
|
||||
]
|
||||
os.execl(py, py, "-m", "astrbot.cli.__main__", *args)
|
||||
else:
|
||||
os.execl(py, py, *sys.argv)
|
||||
except Exception as e:
|
||||
logger.error(f"重启失败({py}, {e}),请尝试手动重启。")
|
||||
raise e
|
||||
@@ -67,6 +72,9 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
update_data = await self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest)
|
||||
file_url = None
|
||||
|
||||
if os.environ.get("ASTRBOT_CLI"):
|
||||
raise Exception("不支持更新CLI启动的AstrBot") # 避免版本管理混乱
|
||||
|
||||
if latest:
|
||||
latest_version = update_data[0]["tag_name"]
|
||||
if self.compare_version(VERSION, latest_version) >= 0:
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Astrbot统一路径获取
|
||||
|
||||
项目路径:固定为源码所在路径
|
||||
根目录路径:默认为当前工作目录,可通过环境变量 ASTRBOT_ROOT 指定
|
||||
数据目录路径:固定为根目录下的 data 目录
|
||||
配置文件路径:固定为数据目录下的 config 目录
|
||||
插件目录路径:固定为数据目录下的 plugins 目录
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def get_astrbot_path() -> str:
|
||||
"""获取Astrbot项目路径"""
|
||||
return os.path.realpath(
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../")
|
||||
)
|
||||
|
||||
|
||||
def get_astrbot_root() -> str:
|
||||
"""获取Astrbot根目录路径"""
|
||||
if path := os.environ.get("ASTRBOT_ROOT"):
|
||||
return os.path.realpath(path)
|
||||
else:
|
||||
return os.path.realpath(os.getcwd())
|
||||
|
||||
|
||||
def get_astrbot_data_path() -> str:
|
||||
"""获取Astrbot数据目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_root(), "data"))
|
||||
|
||||
|
||||
def get_astrbot_config_path() -> str:
|
||||
"""获取Astrbot配置文件路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "config"))
|
||||
|
||||
|
||||
def get_astrbot_plugin_path() -> str:
|
||||
"""获取Astrbot插件目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins"))
|
||||
+16
-14
@@ -14,6 +14,7 @@ import certifi
|
||||
from typing import Union
|
||||
|
||||
from PIL import Image
|
||||
from .astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
def on_error(func, path, exc_info):
|
||||
@@ -49,11 +50,11 @@ def port_checker(port: int, host: str = "localhost"):
|
||||
|
||||
|
||||
def save_temp_img(img: Union[Image.Image, str]) -> str:
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
# 获得文件创建时间,清除超过 12 小时的
|
||||
try:
|
||||
for f in os.listdir("data/temp"):
|
||||
path = os.path.join("data/temp", f)
|
||||
for f in os.listdir(temp_dir):
|
||||
path = os.path.join(temp_dir, f)
|
||||
if os.path.isfile(path):
|
||||
ctime = os.path.getctime(path)
|
||||
if time.time() - ctime > 3600 * 12:
|
||||
@@ -63,7 +64,7 @@ def save_temp_img(img: Union[Image.Image, str]) -> str:
|
||||
|
||||
# 获得时间戳
|
||||
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||
p = f"data/temp/{timestamp}.jpg"
|
||||
p = os.path.join(temp_dir, f"{timestamp}.jpg")
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
img.save(p)
|
||||
@@ -201,28 +202,29 @@ def get_local_ip_addresses():
|
||||
|
||||
|
||||
async def get_dashboard_version():
|
||||
if os.path.exists("data/dist"):
|
||||
if os.path.exists("data/dist/assets/version"):
|
||||
with open("data/dist/assets/version", "r") as f:
|
||||
dist_dir = os.path.join(get_astrbot_data_path(), "dist")
|
||||
if os.path.exists(dist_dir):
|
||||
version_file = os.path.join(dist_dir, "assets", "version")
|
||||
if os.path.exists(version_file):
|
||||
with open(version_file, "r") as f:
|
||||
v = f.read().strip()
|
||||
return v
|
||||
return None
|
||||
|
||||
|
||||
async def download_dashboard(path: str = "data/dashboard.zip", extract_path: str = "data"):
|
||||
async def download_dashboard(path: str = None, extract_path: str = "data"):
|
||||
"""下载管理面板文件"""
|
||||
if path is None:
|
||||
path = os.path.join(get_astrbot_data_path(), "dashboard.zip")
|
||||
|
||||
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
|
||||
try:
|
||||
await download_file(
|
||||
dashboard_release_url, path, show_progress=True
|
||||
)
|
||||
await download_file(dashboard_release_url, path, show_progress=True)
|
||||
except BaseException as _:
|
||||
dashboard_release_url = (
|
||||
"https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip"
|
||||
)
|
||||
await download_file(
|
||||
dashboard_release_url, path, show_progress=True
|
||||
)
|
||||
await download_file(dashboard_release_url, path, show_progress=True)
|
||||
print("解压管理面板文件中...")
|
||||
with zipfile.ZipFile(path, "r") as z:
|
||||
z.extractall(extract_path)
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import json
|
||||
import os
|
||||
from .astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class SharedPreferences:
|
||||
def __init__(self, path="data/shared_preferences.json"):
|
||||
def __init__(self, path=None):
|
||||
if path is None:
|
||||
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
|
||||
self.path = path
|
||||
self._data = self._load_preferences()
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,6 +8,7 @@ from astrbot.core.db import BaseDatabase
|
||||
import asyncio
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class ChatRoute(Route):
|
||||
@@ -33,7 +34,8 @@ class ChatRoute(Route):
|
||||
self.db = db
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.register_routes()
|
||||
self.imgs_dir = "data/webchat/imgs"
|
||||
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
|
||||
self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"]
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import aiohttp
|
||||
import datetime
|
||||
import builtins
|
||||
@@ -13,6 +14,7 @@ from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.provider.sources.dify_source import ProviderDify
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, StarHandlerMetadata
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_manager import PluginManager
|
||||
@@ -1159,7 +1161,8 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
@filter.command("gewe_code")
|
||||
async def gewe_code(self, event: AstrMessageEvent, code: str):
|
||||
"""保存 gewechat 验证码"""
|
||||
with open("data/temp/gewe_code", "w", encoding="utf-8") as f:
|
||||
code_path = os.path.join(get_astrbot_data_path(), "temp","gewe_code")
|
||||
with open(code_path, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
yield event.plain_result("验证码已保存。")
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from astrbot.api.event import filter
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
from astrbot.api.message_components import Image, File
|
||||
from astrbot.core.utils.io import download_image_by_url, download_file
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
PROMPT = """
|
||||
## Task
|
||||
@@ -90,7 +91,7 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
"docker_host_astrbot_abs_path": "",
|
||||
}
|
||||
PATH = "data/config/python_interpreter.json"
|
||||
PATH = os.path.join(get_astrbot_data_path(), "config", "python_interpreter.json")
|
||||
|
||||
|
||||
@star.register(
|
||||
@@ -212,7 +213,8 @@ class Main(star.Star):
|
||||
if isinstance(comp, File):
|
||||
if comp.file.startswith("http"):
|
||||
name = comp.name if comp.name else uuid.uuid4().hex[:8]
|
||||
path = f"data/temp/{name}"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, name)
|
||||
await download_file(comp.file, path)
|
||||
else:
|
||||
path = comp.file
|
||||
|
||||
@@ -8,6 +8,7 @@ from astrbot.api.event import filter
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import llm_tool, logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@star.register(
|
||||
@@ -29,10 +30,11 @@ class Main(star.Star):
|
||||
self.scheduler = AsyncIOScheduler(timezone=self.timezone)
|
||||
|
||||
# set and load config
|
||||
if not os.path.exists("data/astrbot-reminder.json"):
|
||||
with open("data/astrbot-reminder.json", "w", encoding="utf-8") as f:
|
||||
reminder_file = os.path.join(get_astrbot_data_path(), "astrbot-reminder.json")
|
||||
if not os.path.exists(reminder_file):
|
||||
with open(reminder_file, "w", encoding="utf-8") as f:
|
||||
f.write("{}")
|
||||
with open("data/astrbot-reminder.json", "r", encoding="utf-8") as f:
|
||||
with open(reminder_file, "r", encoding="utf-8") as f:
|
||||
self.reminder_data = json.load(f)
|
||||
|
||||
self._init_scheduler()
|
||||
@@ -82,7 +84,8 @@ class Main(star.Star):
|
||||
|
||||
async def _save_data(self):
|
||||
"""Save the reminder data."""
|
||||
with open("data/astrbot-reminder.json", "w", encoding="utf-8") as f:
|
||||
reminder_file = os.path.join(get_astrbot_data_path(), "astrbot-reminder.json")
|
||||
with open(reminder_file, "w", encoding="utf-8") as f:
|
||||
json.dump(self.reminder_data, f, ensure_ascii=False)
|
||||
|
||||
def _parse_cron_expr(self, cron_expr: str):
|
||||
|
||||
Reference in New Issue
Block a user