feat: temporary file handling and introduce TempDirCleaner (#5026)

* feat: temporary file handling and introduce TempDirCleaner

- Updated various modules to use `get_astrbot_temp_path()` instead of `get_astrbot_data_path()` for temporary file storage.
- Renamed temporary files for better identification and organization.
- Introduced `TempDirCleaner` to manage the size of the temporary directory, ensuring it does not exceed a specified limit by deleting the oldest files.
- Added configuration option for maximum temporary directory size in the dashboard.
- Implemented tests for `TempDirCleaner` to verify cleanup functionality and size management.

* ruff
This commit is contained in:
Soulter
2026-02-12 01:04:48 +08:00
committed by GitHub
parent a8dda20a30
commit 9d93bda3fe
37 changed files with 388 additions and 112 deletions
@@ -10,7 +10,7 @@ from astrbot.core.provider.entities import (
LLMResponse, LLMResponse,
ProviderRequest, ProviderRequest,
) )
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_file from astrbot.core.utils.io import download_file
from ...hooks import BaseAgentRunHooks from ...hooks import BaseAgentRunHooks
@@ -291,8 +291,8 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
return Comp.Image(file=item["url"], url=item["url"]) return Comp.Image(file=item["url"], url=item["url"])
case "audio": case "audio":
# 仅支持 wav # 仅支持 wav
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"{item['filename']}.wav") path = os.path.join(temp_dir, f"dify_{item['filename']}.wav")
await download_file(item["url"], path) await download_file(item["url"], path)
return Comp.Image(file=item["url"], url=item["url"]) return Comp.Image(file=item["url"], url=item["url"])
case "video": case "video":
+9 -6
View File
@@ -1,6 +1,7 @@
import base64 import base64
import json import json
import os import os
import uuid
from pydantic import Field from pydantic import Field
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
@@ -240,7 +241,9 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
if "_&exists_" in json.dumps(result): if "_&exists_" in json.dumps(result):
# Download the file from sandbox # Download the file from sandbox
name = os.path.basename(path) name = os.path.basename(path)
local_path = os.path.join(get_astrbot_temp_path(), name) local_path = os.path.join(
get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}"
)
await sb.download_file(path, local_path) await sb.download_file(path, local_path)
logger.info(f"Downloaded file from sandbox: {path} -> {local_path}") logger.info(f"Downloaded file from sandbox: {path} -> {local_path}")
return local_path, True return local_path, True
@@ -352,11 +355,11 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
MessageChain(chain=components), MessageChain(chain=components),
) )
if file_from_sandbox: # if file_from_sandbox:
try: # try:
os.remove(local_path) # os.remove(local_path)
except Exception as e: # except Exception as e:
logger.error(f"Error removing temp file {local_path}: {e}") # logger.error(f"Error removing temp file {local_path}: {e}")
return f"Message sent to session {target_session}" return f"Message sent to session {target_session}"
+9 -6
View File
@@ -1,4 +1,5 @@
import os import os
import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from astrbot.api import FunctionTool, logger from astrbot.api import FunctionTool, logger
@@ -167,7 +168,9 @@ class FileDownloadTool(FunctionTool):
try: try:
name = os.path.basename(remote_path) name = os.path.basename(remote_path)
local_path = os.path.join(get_astrbot_temp_path(), name) local_path = os.path.join(
get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}"
)
# Download file from sandbox # Download file from sandbox
await sb.download_file(remote_path, local_path) await sb.download_file(remote_path, local_path)
@@ -183,12 +186,12 @@ class FileDownloadTool(FunctionTool):
logger.error(f"Error sending file message: {e}") logger.error(f"Error sending file message: {e}")
# remove # remove
try: # try:
os.remove(local_path) # os.remove(local_path)
except Exception as e: # except Exception as e:
logger.error(f"Error removing temp file {local_path}: {e}") # logger.error(f"Error removing temp file {local_path}: {e}")
return f"File downloaded successfully to {local_path} and sent to user. The file has been removed from local storage." return f"File downloaded successfully to {local_path} and sent to user."
return f"File downloaded successfully to {local_path}" return f"File downloaded successfully to {local_path}"
except Exception as e: except Exception as e:
+7
View File
@@ -203,6 +203,7 @@ DEFAULT_CONFIG = {
"log_file_enable": False, "log_file_enable": False,
"log_file_path": "logs/astrbot.log", "log_file_path": "logs/astrbot.log",
"log_file_max_mb": 20, "log_file_max_mb": 20,
"temp_dir_max_size": 1024,
"trace_enable": False, "trace_enable": False,
"trace_log_enable": False, "trace_log_enable": False,
"trace_log_path": "logs/astrbot.trace.log", "trace_log_path": "logs/astrbot.trace.log",
@@ -2394,6 +2395,7 @@ CONFIG_METADATA_2 = {
"log_file_enable": {"type": "bool"}, "log_file_enable": {"type": "bool"},
"log_file_path": {"type": "string", "condition": {"log_file_enable": True}}, "log_file_path": {"type": "string", "condition": {"log_file_enable": True}},
"log_file_max_mb": {"type": "int", "condition": {"log_file_enable": True}}, "log_file_max_mb": {"type": "int", "condition": {"log_file_enable": True}},
"temp_dir_max_size": {"type": "int"},
"trace_log_enable": {"type": "bool"}, "trace_log_enable": {"type": "bool"},
"trace_log_path": { "trace_log_path": {
"type": "string", "type": "string",
@@ -3372,6 +3374,11 @@ CONFIG_METADATA_3_SYSTEM = {
"type": "int", "type": "int",
"hint": "超过大小后自动轮转,默认 20MB。", "hint": "超过大小后自动轮转,默认 20MB。",
}, },
"temp_dir_max_size": {
"description": "临时目录大小上限 (MB)",
"type": "int",
"hint": "用于限制 data/temp 目录总大小,单位为 MB。系统每 10 分钟检查一次,超限时按文件修改时间从旧到新删除,释放约 30% 当前体积。",
},
"trace_log_enable": { "trace_log_enable": {
"description": "启用 Trace 文件日志", "description": "启用 Trace 文件日志",
"type": "bool", "type": "bool",
+19
View File
@@ -37,6 +37,7 @@ from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.updator import AstrBotUpdator from astrbot.core.updator import AstrBotUpdator
from astrbot.core.utils.llm_metadata import update_llm_metadata from astrbot.core.utils.llm_metadata import update_llm_metadata
from astrbot.core.utils.migra_helper import migra from astrbot.core.utils.migra_helper import migra
from astrbot.core.utils.temp_dir_cleaner import TempDirCleaner
from . import astrbot_config, html_renderer from . import astrbot_config, html_renderer
from .event_bus import EventBus from .event_bus import EventBus
@@ -57,6 +58,7 @@ class AstrBotCoreLifecycle:
self.subagent_orchestrator: SubAgentOrchestrator | None = None self.subagent_orchestrator: SubAgentOrchestrator | None = None
self.cron_manager: CronJobManager | None = None self.cron_manager: CronJobManager | None = None
self.temp_dir_cleaner: TempDirCleaner | None = None
# 设置代理 # 设置代理
proxy_config = self.astrbot_config.get("http_proxy", "") proxy_config = self.astrbot_config.get("http_proxy", "")
@@ -125,6 +127,12 @@ class AstrBotCoreLifecycle:
ucr=self.umop_config_router, ucr=self.umop_config_router,
sp=sp, sp=sp,
) )
self.temp_dir_cleaner = TempDirCleaner(
max_size_getter=lambda: self.astrbot_config_mgr.default_conf.get(
TempDirCleaner.CONFIG_KEY,
TempDirCleaner.DEFAULT_MAX_SIZE,
),
)
# apply migration # apply migration
try: try:
@@ -238,6 +246,12 @@ class AstrBotCoreLifecycle:
self.cron_manager.start(self.star_context), self.cron_manager.start(self.star_context),
name="cron_manager", name="cron_manager",
) )
temp_dir_cleaner_task = None
if self.temp_dir_cleaner:
temp_dir_cleaner_task = asyncio.create_task(
self.temp_dir_cleaner.run(),
name="temp_dir_cleaner",
)
# 把插件中注册的所有协程函数注册到事件总线中并执行 # 把插件中注册的所有协程函数注册到事件总线中并执行
extra_tasks = [] extra_tasks = []
@@ -247,6 +261,8 @@ class AstrBotCoreLifecycle:
tasks_ = [event_bus_task, *(extra_tasks if extra_tasks else [])] tasks_ = [event_bus_task, *(extra_tasks if extra_tasks else [])]
if cron_task: if cron_task:
tasks_.append(cron_task) tasks_.append(cron_task)
if temp_dir_cleaner_task:
tasks_.append(temp_dir_cleaner_task)
for task in tasks_: for task in tasks_:
self.curr_tasks.append( self.curr_tasks.append(
asyncio.create_task(self._task_wrapper(task), name=task.get_name()), asyncio.create_task(self._task_wrapper(task), name=task.get_name()),
@@ -298,6 +314,9 @@ class AstrBotCoreLifecycle:
async def stop(self) -> None: async def stop(self) -> None:
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器.""" """停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器."""
if self.temp_dir_cleaner:
await self.temp_dir_cleaner.stop()
# 请求停止所有正在运行的异步任务 # 请求停止所有正在运行的异步任务
for task in self.curr_tasks: for task in self.curr_tasks:
task.cancel() task.cancel()
+13 -11
View File
@@ -31,7 +31,7 @@ from enum import Enum
from pydantic.v1 import BaseModel from pydantic.v1 import BaseModel
from astrbot.core import astrbot_config, file_token_service, logger from astrbot.core import astrbot_config, file_token_service, logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64 from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
@@ -156,8 +156,9 @@ class Record(BaseMessageComponent):
if self.file.startswith("base64://"): if self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://") bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data) image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp") file_path = os.path.join(
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg") get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg"
)
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(image_bytes) f.write(image_bytes)
return os.path.abspath(file_path) return os.path.abspath(file_path)
@@ -245,8 +246,9 @@ class Video(BaseMessageComponent):
if url and url.startswith("file:///"): if url and url.startswith("file:///"):
return url[8:] return url[8:]
if url and url.startswith("http"): if url and url.startswith("http"):
download_dir = os.path.join(get_astrbot_data_path(), "temp") video_file_path = os.path.join(
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}") get_astrbot_temp_path(), f"videoseg_{uuid.uuid4().hex}"
)
await download_file(url, video_file_path) await download_file(url, video_file_path)
if os.path.exists(video_file_path): if os.path.exists(video_file_path):
return os.path.abspath(video_file_path) return os.path.abspath(video_file_path)
@@ -445,8 +447,9 @@ class Image(BaseMessageComponent):
if url.startswith("base64://"): if url.startswith("base64://"):
bs64_data = url.removeprefix("base64://") bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data) image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp") image_file_path = os.path.join(
image_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg") get_astrbot_temp_path(), f"imgseg_{uuid.uuid4()}.jpg"
)
with open(image_file_path, "wb") as f: with open(image_file_path, "wb") as f:
f.write(image_bytes) f.write(image_bytes)
return os.path.abspath(image_file_path) return os.path.abspath(image_file_path)
@@ -725,13 +728,12 @@ class File(BaseMessageComponent):
"""下载文件""" """下载文件"""
if not self.url: if not self.url:
raise ValueError("Download failed: No URL provided in File component.") raise ValueError("Download failed: No URL provided in File component.")
download_dir = os.path.join(get_astrbot_data_path(), "temp") download_dir = get_astrbot_temp_path()
os.makedirs(download_dir, exist_ok=True)
if self.name: if self.name:
name, ext = os.path.splitext(self.name) name, ext = os.path.splitext(self.name)
filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}" filename = f"fileseg_{name}_{uuid.uuid4().hex[:8]}{ext}"
else: else:
filename = f"{uuid.uuid4().hex}" filename = f"fileseg_{uuid.uuid4().hex}"
file_path = os.path.join(download_dir, filename) file_path = os.path.join(download_dir, filename)
await download_file(self.url, file_path) await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path) self.file_ = os.path.abspath(file_path)
@@ -21,7 +21,7 @@ from astrbot.api.platform import (
) )
from astrbot.core import sp from astrbot.core import sp
from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_file from astrbot.core.utils.io import download_file
from astrbot.core.utils.media_utils import ( from astrbot.core.utils.media_utils import (
convert_audio_format, convert_audio_format,
@@ -253,9 +253,9 @@ class DingtalkPlatformAdapter(Platform):
"downloadCode": download_code, "downloadCode": download_code,
"robotCode": robot_code, "robotCode": robot_code,
} }
temp_dir = Path(get_astrbot_data_path()) / "temp" temp_dir = Path(get_astrbot_temp_path())
temp_dir.mkdir(parents=True, exist_ok=True) temp_dir.mkdir(parents=True, exist_ok=True)
f_path = temp_dir / f"dingtalk_file_{uuid.uuid4()}.{ext}" f_path = temp_dir / f"dingtalk_{uuid.uuid4()}.{ext}"
async with ( async with (
aiohttp.ClientSession() as session, aiohttp.ClientSession() as session,
session.post( session.post(
@@ -21,7 +21,7 @@ from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import At, File, Plain, Record, Video from astrbot.api.message_components import At, File, Plain, Record, Video
from astrbot.api.message_components import Image as AstrBotImage from astrbot.api.message_components import Image as AstrBotImage
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.io import download_image_by_url
from astrbot.core.utils.media_utils import ( from astrbot.core.utils.media_utils import (
convert_audio_to_opus, convert_audio_to_opus,
@@ -202,8 +202,11 @@ class LarkMessageEvent(AstrMessageEvent):
base64_str = comp.file.removeprefix("base64://") base64_str = comp.file.removeprefix("base64://")
image_data = base64.b64decode(base64_str) image_data = base64.b64decode(base64_str)
# save as temp file # save as temp file
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}_test.jpg") file_path = os.path.join(
temp_dir,
f"lark_image_{uuid.uuid4().hex[:8]}.jpg",
)
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(BytesIO(image_data).getvalue()) f.write(BytesIO(image_data).getvalue())
else: else:
@@ -21,7 +21,7 @@ try:
except Exception: except Exception:
magic = None magic = None
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from .misskey_event import MisskeyPlatformEvent from .misskey_event import MisskeyPlatformEvent
from .misskey_utils import ( from .misskey_utils import (
@@ -498,7 +498,7 @@ class MisskeyPlatformAdapter(Platform):
finally: finally:
# 清理临时文件 # 清理临时文件
if local_path and isinstance(local_path, str): if local_path and isinstance(local_path, str):
data_temp = os.path.join(get_astrbot_data_path(), "temp") data_temp = get_astrbot_temp_path()
if local_path.startswith(data_temp) and os.path.exists( if local_path.startswith(data_temp) and os.path.exists(
local_path, local_path,
): ):
@@ -19,7 +19,7 @@ from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Image, Plain, Record from astrbot.api.message_components import Image, Plain, Record
from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_image_by_url, file_to_base64 from astrbot.core.utils.io import download_image_by_url, file_to_base64
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
@@ -350,10 +350,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
elif isinstance(i, Record): elif isinstance(i, Record):
if i.file: if i.file:
record_wav_path = await i.convert_to_file_path() # wav 路径 record_wav_path = await i.convert_to_file_path() # wav 路径
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
record_tecent_silk_path = os.path.join( record_tecent_silk_path = os.path.join(
temp_dir, temp_dir,
f"{uuid.uuid4()}.silk", f"qqofficial_{uuid.uuid4()}.silk",
) )
try: try:
duration = await wav_to_tencent_silk( duration = await wav_to_tencent_silk(
@@ -25,7 +25,7 @@ from astrbot.api.platform import (
) )
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.media_utils import convert_audio_to_wav from astrbot.core.utils.media_utils import convert_audio_to_wav
from astrbot.core.utils.webhook_utils import log_webhook_info from astrbot.core.utils.webhook_utils import log_webhook_info
@@ -344,7 +344,7 @@ class WecomPlatformAdapter(Platform):
self.client.media.download, self.client.media.download,
msg.media_id, msg.media_id,
) )
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"wecom_{msg.media_id}.amr") path = os.path.join(temp_dir, f"wecom_{msg.media_id}.amr")
with open(path, "wb") as f: with open(path, "wb") as f:
f.write(resp.content) f.write(resp.content)
@@ -400,7 +400,8 @@ class WecomPlatformAdapter(Platform):
self.client.media.download, self.client.media.download,
media_id, media_id,
) )
path = f"data/temp/wechat_kf_{media_id}.jpg" temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"weixinkefu_{media_id}.jpg")
with open(path, "wb") as f: with open(path, "wb") as f:
f.write(resp.content) f.write(resp.content)
abm.message = [Image(file=path, url=path)] abm.message = [Image(file=path, url=path)]
@@ -412,7 +413,7 @@ class WecomPlatformAdapter(Platform):
media_id, media_id,
) )
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"weixinkefu_{media_id}.amr") path = os.path.join(temp_dir, f"weixinkefu_{media_id}.amr")
with open(path, "wb") as f: with open(path, "wb") as f:
f.write(resp.content) f.write(resp.content)
@@ -1,4 +1,5 @@
import asyncio import asyncio
import os
import sys import sys
import uuid import uuid
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
@@ -24,6 +25,7 @@ from astrbot.api.platform import (
) )
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.media_utils import convert_audio_to_wav from astrbot.core.utils.media_utils import convert_audio_to_wav
from astrbot.core.utils.webhook_utils import log_webhook_info from astrbot.core.utils.webhook_utils import log_webhook_info
@@ -290,12 +292,16 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
self.client.media.download, self.client.media.download,
msg.media_id, msg.media_id,
) )
path = f"data/temp/wecom_{msg.media_id}.amr" temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"weixin_offacc_{msg.media_id}.amr")
with open(path, "wb") as f: with open(path, "wb") as f:
f.write(resp.content) f.write(resp.content)
try: try:
path_wav = f"data/temp/wecom_{msg.media_id}.wav" path_wav = os.path.join(
temp_dir,
f"weixin_offacc_{msg.media_id}.wav",
)
path_wav = await convert_audio_to_wav(path, path_wav) path_wav = await convert_audio_to_wav(path, path_wav)
except Exception as e: except Exception as e:
logger.error( logger.error(
@@ -12,12 +12,13 @@ from httpx import AsyncClient, Timeout
from astrbot import logger from astrbot import logger
from astrbot.core.config.default import VERSION from astrbot.core.config.default import VERSION
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType from ..entities import ProviderType
from ..provider import TTSProvider from ..provider import TTSProvider
from ..register import register_provider_adapter from ..register import register_provider_adapter
TEMP_DIR = Path("data/temp/azure_tts") TEMP_DIR = Path(get_astrbot_temp_path()) / "azure_tts"
TEMP_DIR.mkdir(parents=True, exist_ok=True) TEMP_DIR.mkdir(parents=True, exist_ok=True)
@@ -15,7 +15,7 @@ except (
): # pragma: no cover - older dashscope versions without Qwen TTS support ): # pragma: no cover - older dashscope versions without Qwen TTS support
MultiModalConversation = None MultiModalConversation = None
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType from ..entities import ProviderType
from ..provider import TTSProvider from ..provider import TTSProvider
@@ -45,7 +45,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
if not model: if not model:
raise RuntimeError("Dashscope TTS model is not configured.") raise RuntimeError("Dashscope TTS model is not configured.")
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True) os.makedirs(temp_dir, exist_ok=True)
if self._is_qwen_tts_model(model): if self._is_qwen_tts_model(model):
@@ -6,7 +6,7 @@ import uuid
import edge_tts import edge_tts
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType from ..entities import ProviderType
from ..provider import TTSProvider from ..provider import TTSProvider
@@ -46,7 +46,7 @@ class ProviderEdgeTTS(TTSProvider):
self.set_model("edge_tts") self.set_model("edge_tts")
async def get_audio(self, text: str) -> str: async def get_audio(self, text: str) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3") mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3")
wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav") wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav")
@@ -8,7 +8,7 @@ from httpx import AsyncClient
from pydantic import BaseModel, conint from pydantic import BaseModel, conint
from astrbot import logger from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType from ..entities import ProviderType
from ..provider import TTSProvider from ..provider import TTSProvider
@@ -142,7 +142,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
) )
async def get_audio(self, text: str) -> str: async def get_audio(self, text: str) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav") path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav")
self.headers["content-type"] = "application/msgpack" self.headers["content-type"] = "application/msgpack"
request = await self._generate_request(text) request = await self._generate_request(text)
@@ -6,7 +6,7 @@ from google import genai
from google.genai import types from google.genai import types
from astrbot import logger from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType from ..entities import ProviderType
from ..provider import TTSProvider from ..provider import TTSProvider
@@ -49,7 +49,7 @@ class ProviderGeminiTTSAPI(TTSProvider):
self.voice_name: str = provider_config.get("gemini_tts_voice_name", "Leda") self.voice_name: str = provider_config.get("gemini_tts_voice_name", "Leda")
async def get_audio(self, text: str) -> str: async def get_audio(self, text: str) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"gemini_tts_{uuid.uuid4()}.wav") path = os.path.join(temp_dir, f"gemini_tts_{uuid.uuid4()}.wav")
prompt = f"{self.prefix}: {text}" if self.prefix else text prompt = f"{self.prefix}: {text}" if self.prefix else text
response = await self.client.models.generate_content( response = await self.client.models.generate_content(
+3 -3
View File
@@ -6,7 +6,7 @@ from astrbot.core import logger
from astrbot.core.provider.entities import ProviderType from astrbot.core.provider.entities import ProviderType
from astrbot.core.provider.provider import TTSProvider from astrbot.core.provider.provider import TTSProvider
from astrbot.core.provider.register import register_provider_adapter from astrbot.core.provider.register import register_provider_adapter
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
try: try:
import genie_tts as genie # type: ignore import genie_tts as genie # type: ignore
@@ -54,7 +54,7 @@ class GenieTTSProvider(TTSProvider):
return True return True
async def get_audio(self, text: str) -> str: async def get_audio(self, text: str) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True) os.makedirs(temp_dir, exist_ok=True)
filename = f"genie_tts_{uuid.uuid4()}.wav" filename = f"genie_tts_{uuid.uuid4()}.wav"
path = os.path.join(temp_dir, filename) path = os.path.join(temp_dir, filename)
@@ -94,7 +94,7 @@ class GenieTTSProvider(TTSProvider):
break break
try: try:
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True) os.makedirs(temp_dir, exist_ok=True)
filename = f"genie_tts_{uuid.uuid4()}.wav" filename = f"genie_tts_{uuid.uuid4()}.wav"
path = os.path.join(temp_dir, filename) path = os.path.join(temp_dir, filename)
@@ -5,7 +5,7 @@ import uuid
import aiohttp import aiohttp
from astrbot import logger from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType from ..entities import ProviderType
from ..provider import TTSProvider from ..provider import TTSProvider
@@ -121,7 +121,7 @@ class ProviderGSVTTS(TTSProvider):
params = self.build_synthesis_params(text) params = self.build_synthesis_params(text)
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True) os.makedirs(temp_dir, exist_ok=True)
path = os.path.join(temp_dir, f"gsv_tts_{uuid.uuid4().hex}.wav") path = os.path.join(temp_dir, f"gsv_tts_{uuid.uuid4().hex}.wav")
@@ -4,7 +4,7 @@ import uuid
import aiohttp import aiohttp
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType from ..entities import ProviderType
from ..provider import TTSProvider from ..provider import TTSProvider
@@ -29,7 +29,7 @@ class ProviderGSVITTS(TTSProvider):
self.emotion = provider_config.get("emotion") self.emotion = provider_config.get("emotion")
async def get_audio(self, text: str) -> str: async def get_audio(self, text: str) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav") path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav")
params = {"text": text} params = {"text": text}
@@ -6,7 +6,7 @@ from collections.abc import AsyncIterator
import aiohttp import aiohttp
from astrbot.api import logger from astrbot.api import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType from ..entities import ProviderType
from ..provider import TTSProvider from ..provider import TTSProvider
@@ -145,7 +145,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
return b"".join(chunks) return b"".join(chunks)
async def get_audio(self, text: str) -> str: async def get_audio(self, text: str) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True) os.makedirs(temp_dir, exist_ok=True)
path = os.path.join(temp_dir, f"minimax_tts_api_{uuid.uuid4()}.mp3") path = os.path.join(temp_dir, f"minimax_tts_api_{uuid.uuid4()}.mp3")
@@ -5,7 +5,7 @@ import httpx
from openai import NOT_GIVEN, AsyncOpenAI from openai import NOT_GIVEN, AsyncOpenAI
from astrbot import logger from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType from ..entities import ProviderType
from ..provider import TTSProvider from ..provider import TTSProvider
@@ -46,7 +46,7 @@ class ProviderOpenAITTSAPI(TTSProvider):
self.set_model(provider_config.get("model", "")) self.set_model(provider_config.get("model", ""))
async def get_audio(self, text: str) -> str: async def get_audio(self, text: str) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}.wav") path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}.wav")
async with self.client.audio.speech.with_streaming_response.create( async with self.client.audio.speech.with_streaming_response.create(
model=self.model_name, model=self.model_name,
@@ -8,6 +8,7 @@ import uuid
import aiohttp import aiohttp
from astrbot import logger from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType from ..entities import ProviderType
from ..provider import TTSProvider from ..provider import TTSProvider
@@ -92,9 +93,12 @@ class ProviderVolcengineTTS(TTSProvider):
if "data" in resp_data: if "data" in resp_data:
audio_data = base64.b64decode(resp_data["data"]) audio_data = base64.b64decode(resp_data["data"])
os.makedirs("data/temp", exist_ok=True) temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3" file_path = os.path.join(
temp_dir,
f"volcengine_tts_{uuid.uuid4()}.mp3",
)
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
await loop.run_in_executor( await loop.run_in_executor(
@@ -4,7 +4,7 @@ import uuid
from openai import NOT_GIVEN, AsyncOpenAI from openai import NOT_GIVEN, AsyncOpenAI
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_file from astrbot.core.utils.io import download_file
from astrbot.core.utils.tencent_record_helper import ( from astrbot.core.utils.tencent_record_helper import (
convert_to_pcm_wav, convert_to_pcm_wav,
@@ -65,9 +65,11 @@ class ProviderOpenAIWhisperAPI(STTProvider):
if "multimedia.nt.qq.com.cn" in audio_url: if "multimedia.nt.qq.com.cn" in audio_url:
is_tencent = True is_tencent = True
name = str(uuid.uuid4()) temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp") path = os.path.join(
path = os.path.join(temp_dir, name) temp_dir,
f"whisper_api_{uuid.uuid4().hex[:8]}.input",
)
await download_file(audio_url, path) await download_file(audio_url, path)
audio_url = path audio_url = path
@@ -79,8 +81,11 @@ class ProviderOpenAIWhisperAPI(STTProvider):
# 判断是否需要转换 # 判断是否需要转换
if file_format in ["silk", "amr"]: if file_format in ["silk", "amr"]:
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav") output_path = os.path.join(
temp_dir,
f"whisper_api_{uuid.uuid4().hex[:8]}.wav",
)
if file_format == "silk": if file_format == "silk":
logger.info( logger.info(
@@ -6,7 +6,7 @@ from typing import cast
import whisper import whisper
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_file from astrbot.core.utils.io import download_file
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
@@ -58,9 +58,11 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
if "multimedia.nt.qq.com.cn" in audio_url: if "multimedia.nt.qq.com.cn" in audio_url:
is_tencent = True is_tencent = True
name = str(uuid.uuid4()) temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp") path = os.path.join(
path = os.path.join(temp_dir, name) temp_dir,
f"whisper_selfhost_{uuid.uuid4().hex[:8]}.input",
)
await download_file(audio_url, path) await download_file(audio_url, path)
audio_url = path audio_url = path
@@ -71,8 +73,11 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
is_silk = await self._is_silk_file(audio_url) is_silk = await self._is_silk_file(audio_url)
if is_silk: if is_silk:
logger.info("Converting silk file to wav ...") logger.info("Converting silk file to wav ...")
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav") output_path = os.path.join(
temp_dir,
f"whisper_selfhost_{uuid.uuid4().hex[:8]}.wav",
)
await tencent_silk_to_wav(audio_url, output_path) await tencent_silk_to_wav(audio_url, output_path)
audio_url = output_path audio_url = output_path
@@ -7,7 +7,7 @@ from xinference_client.client.restful.async_restful_client import (
) )
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.tencent_record_helper import ( from astrbot.core.utils.tencent_record_helper import (
convert_to_pcm_wav, convert_to_pcm_wav,
tencent_silk_to_wav, tencent_silk_to_wav,
@@ -130,11 +130,17 @@ class ProviderXinferenceSTT(STTProvider):
logger.info( logger.info(
f"Audio requires conversion ({conversion_type}), using temporary files..." f"Audio requires conversion ({conversion_type}), using temporary files..."
) )
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True) os.makedirs(temp_dir, exist_ok=True)
input_path = os.path.join(temp_dir, str(uuid.uuid4())) input_path = os.path.join(
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav") temp_dir,
f"xinference_stt_{uuid.uuid4().hex[:8]}.input",
)
output_path = os.path.join(
temp_dir,
f"xinference_stt_{uuid.uuid4().hex[:8]}.wav",
)
temp_files.extend([input_path, output_path]) temp_files.extend([input_path, output_path])
with open(input_path, "wb") as f: with open(input_path, "wb") as f:
-1
View File
@@ -93,7 +93,6 @@ class SkillManager:
self.skills_root = skills_root or get_astrbot_skills_path() self.skills_root = skills_root or get_astrbot_skills_path()
self.config_path = os.path.join(get_astrbot_data_path(), SKILLS_CONFIG_FILENAME) self.config_path = os.path.join(get_astrbot_data_path(), SKILLS_CONFIG_FILENAME)
os.makedirs(self.skills_root, exist_ok=True) os.makedirs(self.skills_root, exist_ok=True)
os.makedirs(get_astrbot_temp_path(), exist_ok=True)
def _load_config(self) -> dict: def _load_config(self) -> dict:
if not os.path.exists(self.config_path): if not os.path.exists(self.config_path):
+3 -14
View File
@@ -14,7 +14,7 @@ import certifi
import psutil import psutil
from PIL import Image from PIL import Image
from .astrbot_path import get_astrbot_data_path from .astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
logger = logging.getLogger("astrbot") logger = logging.getLogger("astrbot")
@@ -50,21 +50,10 @@ def port_checker(port: int, host: str = "localhost") -> bool:
def save_temp_img(img: Image.Image | bytes) -> str: def save_temp_img(img: Image.Image | bytes) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
# 获得文件创建时间,清除超过 12 小时的
try:
for f in os.listdir(temp_dir):
path = os.path.join(temp_dir, f)
if os.path.isfile(path):
ctime = os.path.getctime(path)
if time.time() - ctime > 3600 * 12:
os.remove(path)
except Exception as e:
print(f"清除临时文件失败: {e}")
# 获得时间戳 # 获得时间戳
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}" timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
p = os.path.join(temp_dir, f"{timestamp}.jpg") p = os.path.join(temp_dir, f"io_temp_img_{timestamp}.jpg")
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
img.save(p) img.save(p)
+12 -9
View File
@@ -10,7 +10,7 @@ import uuid
from pathlib import Path from pathlib import Path
from astrbot import logger from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
async def get_media_duration(file_path: str) -> int | None: async def get_media_duration(file_path: str) -> int | None:
@@ -77,9 +77,9 @@ async def convert_audio_to_opus(audio_path: str, output_path: str | None = None)
# 生成输出文件路径 # 生成输出文件路径
if output_path is None: if output_path is None:
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True) os.makedirs(temp_dir, exist_ok=True)
output_path = os.path.join(temp_dir, f"{uuid.uuid4()}.opus") output_path = os.path.join(temp_dir, f"media_audio_{uuid.uuid4().hex}.opus")
try: try:
# 使用ffmpeg转换为opus格式 # 使用ffmpeg转换为opus格式
@@ -156,9 +156,12 @@ async def convert_video_format(
# 生成输出文件路径 # 生成输出文件路径
if output_path is None: if output_path is None:
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True) os.makedirs(temp_dir, exist_ok=True)
output_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{output_format}") output_path = os.path.join(
temp_dir,
f"media_video_{uuid.uuid4().hex}.{output_format}",
)
try: try:
# 使用ffmpeg转换视频格式 # 使用ffmpeg转换视频格式
@@ -227,9 +230,9 @@ async def convert_audio_format(
return audio_path return audio_path
if output_path is None: if output_path is None:
temp_dir = Path(get_astrbot_data_path()) / "temp" temp_dir = Path(get_astrbot_temp_path())
temp_dir.mkdir(parents=True, exist_ok=True) temp_dir.mkdir(parents=True, exist_ok=True)
output_path = str(temp_dir / f"{uuid.uuid4()}.{output_format}") output_path = str(temp_dir / f"media_audio_{uuid.uuid4().hex}.{output_format}")
args = ["ffmpeg", "-y", "-i", audio_path] args = ["ffmpeg", "-y", "-i", audio_path]
if output_format == "amr": if output_format == "amr":
@@ -283,9 +286,9 @@ async def extract_video_cover(
) -> str: ) -> str:
"""从视频中提取封面图(JPG)。""" """从视频中提取封面图(JPG)。"""
if output_path is None: if output_path is None:
temp_dir = Path(get_astrbot_data_path()) / "temp" temp_dir = Path(get_astrbot_temp_path())
temp_dir.mkdir(parents=True, exist_ok=True) temp_dir.mkdir(parents=True, exist_ok=True)
output_path = str(temp_dir / f"{uuid.uuid4()}.jpg") output_path = str(temp_dir / f"media_cover_{uuid.uuid4().hex}.jpg")
try: try:
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
+150
View File
@@ -0,0 +1,150 @@
import asyncio
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
def parse_size_to_bytes(value: str | int | float | None) -> int:
"""Parse size in MB to bytes."""
if value is None:
return 0
try:
size_mb = float(str(value).strip())
except (TypeError, ValueError):
return 0
if size_mb <= 0:
return 0
return int(size_mb * 1024**2)
@dataclass
class TempFileInfo:
path: Path
size: int
mtime: float
class TempDirCleaner:
CONFIG_KEY = "temp_dir_max_size"
DEFAULT_MAX_SIZE = 1024
CHECK_INTERVAL_SECONDS = 10 * 60
CLEANUP_RATIO = 0.30
def __init__(
self,
max_size_getter: Callable[[], str | int | float | None],
temp_dir: Path | None = None,
) -> None:
self._max_size_getter = max_size_getter
self._temp_dir = temp_dir or Path(get_astrbot_temp_path())
self._stop_event = asyncio.Event()
def _limit_bytes(self) -> int:
configured = self._max_size_getter()
parsed = parse_size_to_bytes(configured)
if parsed <= 0:
fallback = parse_size_to_bytes(self.DEFAULT_MAX_SIZE)
logger.warning(
f"Invalid {self.CONFIG_KEY}={configured!r}, fallback to {self.DEFAULT_MAX_SIZE}MB.",
)
return fallback
return parsed
def _scan_temp_files(self) -> tuple[int, list[TempFileInfo]]:
if not self._temp_dir.exists():
return 0, []
total_size = 0
files: list[TempFileInfo] = []
for path in self._temp_dir.rglob("*"):
if not path.is_file():
continue
try:
stat = path.stat()
except OSError as e:
logger.debug(f"Skip temp file {path} due to stat error: {e}")
continue
total_size += stat.st_size
files.append(
TempFileInfo(path=path, size=stat.st_size, mtime=stat.st_mtime)
)
return total_size, files
def _cleanup_empty_dirs(self) -> None:
if not self._temp_dir.exists():
return
for path in sorted(
self._temp_dir.rglob("*"), key=lambda p: len(p.parts), reverse=True
):
if not path.is_dir():
continue
try:
path.rmdir()
except OSError:
continue
def cleanup_once(self) -> None:
limit = self._limit_bytes()
if limit <= 0:
return
total_size, files = self._scan_temp_files()
if total_size <= limit:
return
target_release = max(int(total_size * self.CLEANUP_RATIO), 1)
released = 0
removed_files = 0
for file_info in sorted(files, key=lambda item: item.mtime):
try:
file_info.path.unlink()
except OSError as e:
logger.warning(f"Failed to delete temp file {file_info.path}: {e}")
continue
released += file_info.size
removed_files += 1
if released >= target_release:
break
self._cleanup_empty_dirs()
logger.warning(
f"Temp dir exceeded limit ({total_size} > {limit}). "
f"Removed {removed_files} files, released {released} bytes "
f"(target {target_release} bytes).",
)
async def run(self) -> None:
logger.info(
f"TempDirCleaner started. interval={self.CHECK_INTERVAL_SECONDS}s "
f"cleanup_ratio={self.CLEANUP_RATIO}",
)
while not self._stop_event.is_set():
try:
# File-system traversal and deletion are blocking operations.
# Run cleanup in a worker thread to avoid blocking the event loop.
await asyncio.to_thread(self.cleanup_once)
except Exception as e:
logger.error(f"TempDirCleaner run failed: {e}", exc_info=True)
try:
await asyncio.wait_for(
self._stop_event.wait(),
timeout=self.CHECK_INTERVAL_SECONDS,
)
except asyncio.TimeoutError:
continue
logger.info("TempDirCleaner stopped.")
async def stop(self) -> None:
self._stop_event.set()
+4 -2
View File
@@ -7,7 +7,7 @@ import wave
from io import BytesIO from io import BytesIO
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str: async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str:
@@ -117,12 +117,13 @@ async def audio_to_tencent_silk_base64(audio_path: str) -> tuple[str, float]:
except ImportError as e: except ImportError as e:
raise Exception("未安装 pilk: pip install pilk") from e raise Exception("未安装 pilk: pip install pilk") from e
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True) os.makedirs(temp_dir, exist_ok=True)
# 是否需要转换为 WAV # 是否需要转换为 WAV
ext = os.path.splitext(audio_path)[1].lower() ext = os.path.splitext(audio_path)[1].lower()
temp_wav = tempfile.NamedTemporaryFile( temp_wav = tempfile.NamedTemporaryFile(
prefix="tencent_record_",
suffix=".wav", suffix=".wav",
delete=False, delete=False,
dir=temp_dir, dir=temp_dir,
@@ -140,6 +141,7 @@ async def audio_to_tencent_silk_base64(audio_path: str) -> tuple[str, float]:
rate = wav_file.getframerate() rate = wav_file.getframerate()
silk_path = tempfile.NamedTemporaryFile( silk_path = tempfile.NamedTemporaryFile(
prefix="tencent_record_",
suffix=".silk", suffix=".silk",
delete=False, delete=False,
dir=temp_dir, dir=temp_dir,
+5 -1
View File
@@ -12,6 +12,7 @@ from quart import request
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..utils import generate_tsne_visualization from ..utils import generate_tsne_visualization
from .route import Response, Route, RouteContext from .route import Response, Route, RouteContext
@@ -703,7 +704,10 @@ class KnowledgeBaseRoute(Route):
file_name = file.filename file_name = file.filename
# 保存到临时文件 # 保存到临时文件
temp_file_path = f"data/temp/{uuid.uuid4()}_{file_name}" temp_file_path = os.path.join(
get_astrbot_temp_path(),
f"kb_upload_{uuid.uuid4()}_{file_name}",
)
await file.save(temp_file_path) await file.save(temp_file_path)
try: try:
+2 -2
View File
@@ -12,7 +12,7 @@ from quart import websocket
from astrbot import logger from astrbot import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from .route import Route, RouteContext from .route import Route, RouteContext
@@ -60,7 +60,7 @@ class LiveChatSession:
# 组装 WAV 文件 # 组装 WAV 文件
try: try:
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True) os.makedirs(temp_dir, exist_ok=True)
audio_path = os.path.join(temp_dir, f"live_audio_{uuid.uuid4()}.wav") audio_path = os.path.join(temp_dir, f"live_audio_{uuid.uuid4()}.wav")
+5 -1
View File
@@ -20,6 +20,7 @@ from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.filter.regex import RegexFilter from astrbot.core.star.filter.regex import RegexFilter
from astrbot.core.star.star_handler import EventType, star_handlers_registry from astrbot.core.star.star_handler import EventType, star_handlers_registry
from astrbot.core.star.star_manager import PluginManager from astrbot.core.star.star_manager import PluginManager
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from .route import Response, Route, RouteContext from .route import Response, Route, RouteContext
@@ -431,7 +432,10 @@ class PluginRoute(Route):
file = await request.files file = await request.files
file = file["file"] file = file["file"]
logger.info(f"正在安装用户上传的插件 {file.filename}") logger.info(f"正在安装用户上传的插件 {file.filename}")
file_path = f"data/temp/{file.filename}" file_path = os.path.join(
get_astrbot_temp_path(),
f"plugin_upload_{file.filename}",
)
await file.save(file_path) await file.save(file_path)
plugin_info = await self.plugin_manager.install_plugin_from_file(file_path) plugin_info = await self.plugin_manager.install_plugin_from_file(file_path)
# self.core_lifecycle.restart() # self.core_lifecycle.restart()
@@ -819,6 +819,10 @@
"description": "Log File Max Size (MB)", "description": "Log File Max Size (MB)",
"hint": "Rotate when exceeding this size; default 20MB." "hint": "Rotate when exceeding this size; default 20MB."
}, },
"temp_dir_max_size": {
"description": "Temp Directory Size Limit (MB)",
"hint": "Limits total size of data/temp in MB. The system checks every 10 minutes, and when exceeded, deletes oldest files first to release about 30% of current size."
},
"trace_log_enable": { "trace_log_enable": {
"description": "Enable Trace File Logging", "description": "Enable Trace File Logging",
"hint": "Write trace events to a separate file (does not change console output)." "hint": "Write trace events to a separate file (does not change console output)."
@@ -822,6 +822,10 @@
"description": "日志文件大小上限 (MB)", "description": "日志文件大小上限 (MB)",
"hint": "超过大小后自动轮转,默认 20MB。" "hint": "超过大小后自动轮转,默认 20MB。"
}, },
"temp_dir_max_size": {
"description": "临时目录大小上限 (MB)",
"hint": "用于限制 data/temp 目录总大小,单位为 MB。系统每 10 分钟检查一次,超限时按文件修改时间从旧到新删除,释放约 30% 当前体积。"
},
"trace_log_enable": { "trace_log_enable": {
"description": "启用 Trace 文件日志", "description": "启用 Trace 文件日志",
"hint": "将 Trace 事件写入独立文件(不影响控制台输出)。" "hint": "将 Trace 事件写入独立文件(不影响控制台输出)。"
+52
View File
@@ -0,0 +1,52 @@
import os
import time
from pathlib import Path
from astrbot.core.utils.temp_dir_cleaner import TempDirCleaner, parse_size_to_bytes
def test_parse_size_to_bytes():
assert parse_size_to_bytes("1024") == 1024 * 1024**2
assert parse_size_to_bytes(2048) == 2048 * 1024**2
assert parse_size_to_bytes("0.5") == int(0.5 * 1024**2)
assert parse_size_to_bytes(0) == 0
assert parse_size_to_bytes("invalid") == 0
def _write_file(path: Path, size: int, mtime: float) -> None:
path.write_bytes(b"x" * size)
os.utime(path, (mtime, mtime))
def test_cleanup_once_releases_30_percent_and_prefers_old_files(tmp_path):
temp_dir = tmp_path / "temp"
temp_dir.mkdir(parents=True, exist_ok=True)
base_time = time.time() - 1000
file_old = temp_dir / "old.bin"
file_mid = temp_dir / "mid.bin"
file_new = temp_dir / "new.bin"
_write_file(file_old, 400, base_time)
_write_file(file_mid, 300, base_time + 10)
_write_file(file_new, 300, base_time + 20)
cleaner = TempDirCleaner(max_size_getter=lambda: "0.0008", temp_dir=temp_dir)
cleaner.cleanup_once()
remaining_size = sum(f.stat().st_size for f in temp_dir.rglob("*") if f.is_file())
assert remaining_size <= 600
assert not file_old.exists()
assert file_mid.exists()
assert file_new.exists()
def test_cleanup_once_noop_when_below_limit(tmp_path):
temp_dir = tmp_path / "temp"
temp_dir.mkdir(parents=True, exist_ok=True)
file_path = temp_dir / "a.bin"
_write_file(file_path, 100, time.time())
cleaner = TempDirCleaner(max_size_getter=lambda: "1", temp_dir=temp_dir)
cleaner.cleanup_once()
assert file_path.exists()