Compare commits

...

10 Commits

Author SHA1 Message Date
Soulter 3e3599835e chore: bump version to 4.10.3 2025-12-26 22:39:59 +08:00
Soulter 5255388e2d refactor: move builtin stars to astrbot package (#4209)
* refactor: move builtin stars to astrbot package

fixes: #4202

* chore: ruff format

* chore: remove print
2025-12-26 22:31:22 +08:00
Yokami fbdd60b64c feat: add extra user content block support (#4189)
* feat: 多文本块功能

* FIX

* 传递链

* 重命名

* refactor: unify extra_user_content_parts type to ContentPart across providers and update related handling

* claude额外块支持图片模态

* 已经处理过了不用再处理

* feat: enhance image handling in extra content blocks for multiple providers

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-26 22:08:20 +08:00
Soulter bd1b0a2836 perf: drop unnecessary none-value fields in tool call loop (#4213) 2025-12-26 21:12:34 +08:00
Soulter 19541d9d07 fix: ensure max_tokens is set and validate tool_calls type in ProviderAnthropic (#4212) 2025-12-26 21:01:05 +08:00
大饼鸡蛋 2a5d574394 fix: failed to initialize FishAudio TTS instance (#4200)
fixes: #4172

* fix: 修复 FishAudio 源的配置加载问题并增强请求鲁棒性

- Fix `KeyError: 'model'``: 适配新版配置结构。
- Add `timeout` support: 防止长文本生成时超时。
- Improve response handling: 使用更标准的 Header 检查方式。

* feat: 使用更安全的类型转换并优化错误信息打印
2025-12-26 20:50:45 +08:00
Soulter f2924fbd1b chore: update readme 2025-12-26 18:04:56 +08:00
Gao Jinzhe 703e208947 fix: handle index out of range error when selecting provider (#4206) 2025-12-26 18:02:43 +08:00
NoctuUFO 9a5cc977c2 fix: fix log loss on SSE reconnect using Last-Event-ID (#4205)
* feat: implement last-event-id handing in log route

* perf: better log handling

* chore: ruff format

* perf: log

* Update ConsoleDisplayer.vue

* Update package.json

* Update ConsoleDisplayer.vue

* Update common.js

* chore: ruff format

* fix: ensure last_event_id is required for log replay

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-26 18:01:58 +08:00
RC-CHN aa38fe776a feat: supports data backup (#4105)
* feat: 添加数据迁移功能

* test: 添加迁移相关测试

* feat: 备份插件及相关持久化目录

* fix: 修复版本号比较逻辑,添加相关测试

* fix: 清洗文件名,添加相关测试

* fix: 修复安全文件名测试用例断言

* refactor: 优化代码,为备份模块提取公用常量

* feat: 修改备份版本校验逻辑,允许强制小版本间导入

* fix: 修复备份创建时间读取,修复备份相关i18n

* refactor(backup): 使用 astrbot_path 统一管理备份目录路径

* fix(backup): 清理备份模块中未使用的导入

* refactor(backup): 统一备份路径与参数并移除未用附件目录

- 通过 astrbot_path 动态获取备份/知识库/数据相关路径
- 移除 exporter/importer 未使用的 attachments_dir/data_root 传参
- 更新备份路由与测试用例的构造参数

* fix(dashboard): alias mermaid to dist entry for Vite prebundle

* fix(backup): 放行start-time接口到白名单以处理备份导入后jwt token变化导致无法自动刷新webui的问题

* chore(backup): 统一配置路径以使用动态数据目录

* refactor(backup): 使用 VersionComparator 替代重复的版本比较函数

* style(backup test): format code

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-26 15:47:50 +08:00
72 changed files with 4124 additions and 225 deletions
+1 -2
View File
@@ -15,7 +15,6 @@ Always reference these instructions first and fallback to search or bash command
### Running the Application
- Run main application: `uv run main.py` -- starts in ~3 seconds
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
### Dashboard Build (Vue.js/Node.js)
- **Prerequisites**: Node.js 20+ and npm 10+ required
@@ -35,7 +34,7 @@ Always reference these instructions first and fallback to search or bash command
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
### Plugin Development
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
- Plugins load from `astrbot/builtin_stars/` (built-in) and `data/plugins/` (user-installed)
- Plugin system supports function tools and message handlers
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
+2 -2
View File
@@ -24,9 +24,9 @@ configs/session
configs/config.yaml
cmd_config.json
# Plugins and packages
# Plugins
addons/plugins
packages/python_interpreter/workplace
astrbot/builtin_stars/python_interpreter/workplace
tests/astrbot_plugin_openai
# Dashboard
+1 -1
View File
@@ -1,4 +1,4 @@
![astrbot-banner-xmas](https://github.com/user-attachments/assets/bf2341de-ec7a-45a7-a04a-02ad36450e99)
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
<div align="center">
@@ -7,6 +7,7 @@ from astrbot.api import logger, sp, star
from astrbot.api.event import AstrMessageEvent
from astrbot.api.message_components import Image, Reply
from astrbot.api.provider import Provider, ProviderRequest
from astrbot.core.agent.message import TextPart
from astrbot.core.provider.func_tool_manager import ToolSet
@@ -85,7 +86,9 @@ class ProcessLLMRequest:
req.image_urls,
)
if caption:
req.prompt = f"(Image Caption: {caption})\n\n{req.prompt}"
req.extra_user_content_parts.append(
TextPart(text=f"<image_caption>{caption}</image_caption>")
)
req.image_urls = []
except Exception as e:
logger.error(f"处理图片描述失败: {e}")
@@ -129,13 +132,14 @@ class ProcessLLMRequest:
else:
req.prompt = prefix + req.prompt
# 收集系统提醒信息
system_parts = []
# user identifier
if cfg.get("identifier"):
user_id = event.message_obj.sender.user_id
user_nickname = event.message_obj.sender.nickname
req.prompt = (
f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n{req.prompt}"
)
system_parts.append(f"User ID: {user_id}, Nickname: {user_nickname}")
# group name identifier
if cfg.get("group_name_display") and event.message_obj.group_id:
@@ -146,7 +150,7 @@ class ProcessLLMRequest:
return
group_name = event.message_obj.group.group_name
if group_name:
req.system_prompt += f"\nGroup name: {group_name}\n"
system_parts.append(f"Group name: {group_name}")
# time info
if cfg.get("datetime_system_prompt"):
@@ -162,7 +166,7 @@ class ProcessLLMRequest:
current_time = (
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)")
)
req.system_prompt += f"\nCurrent datetime: {current_time}\n"
system_parts.append(f"Current datetime: {current_time}")
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
if req.conversation:
@@ -225,10 +229,17 @@ class ProcessLLMRequest:
except BaseException as e:
logger.error(f"处理引用图片失败: {e}")
# 3. 将所有部分组合成文本并直接注入到当前消息
# 3. 将所有部分组合成文本并添加到 extra_user_content_parts
# 确保引用内容被正确的标签包裹
quoted_content = "\n".join(content_parts)
# 确保所有内容都在<Quoted Message>标签内
quoted_text = f"<Quoted Message>\n{quoted_content}\n</Quoted Message>"
req.prompt = f"{quoted_text}\n\n{req.prompt}"
req.extra_user_content_parts.append(TextPart(text=quoted_text))
# 统一包裹所有系统提醒
if system_parts:
system_content = (
"<system_reminder>" + "\n".join(system_parts) + "</system_reminder>"
)
req.extra_user_content_parts.append(TextPart(text=system_content))
@@ -184,7 +184,8 @@ class ProviderCommands:
event.set_result(MessageEventResult().message("请输入序号。"))
return
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
event.set_result(MessageEventResult().message("无效的序号。"))
event.set_result(MessageEventResult().message("无效的提供商序号。"))
return
provider = self.context.get_all_tts_providers()[idx2 - 1]
id_ = provider.meta().id
await self.context.provider_manager.set_provider(
@@ -198,7 +199,8 @@ class ProviderCommands:
event.set_result(MessageEventResult().message("请输入序号。"))
return
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
event.set_result(MessageEventResult().message("无效的序号。"))
event.set_result(MessageEventResult().message("无效的提供商序号。"))
return
provider = self.context.get_all_stt_providers()[idx2 - 1]
id_ = provider.meta().id
await self.context.provider_manager.set_provider(
@@ -209,8 +211,8 @@ class ProviderCommands:
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
elif isinstance(idx, int):
if idx > len(self.context.get_all_providers()) or idx < 1:
event.set_result(MessageEventResult().message("无效的序号。"))
event.set_result(MessageEventResult().message("无效的提供商序号。"))
return
provider = self.context.get_all_providers()[idx - 1]
id_ = provider.meta().id
await self.context.provider_manager.set_provider(
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.10.2"
__version__ = "4.10.3"
+9
View File
@@ -169,6 +169,15 @@ class Message(BaseModel):
)
return self
@model_serializer(mode="wrap")
def serialize(self, handler):
data = handler(self)
if self.tool_calls is None:
data.pop("tool_calls", None)
if self.tool_call_id is None:
data.pop("tool_call_id", None)
return data
class AssistantMessageSegment(Message):
"""A message segment from the assistant."""
@@ -77,10 +77,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
payload = {
"contexts": self.run_context.messages,
"contexts": self.run_context.messages, # list[Message]
"func_tool": self.req.func_tool,
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
"session_id": self.req.session_id,
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
}
if self.streaming:
+26
View File
@@ -0,0 +1,26 @@
"""AstrBot 备份与恢复模块
提供数据导出和导入功能,支持用户在服务器迁移时一键备份和恢复所有数据。
"""
# 从 constants 模块导入共享常量
from .constants import (
BACKUP_MANIFEST_VERSION,
KB_METADATA_MODELS,
MAIN_DB_MODELS,
get_backup_directories,
)
# 导入导出器和导入器
from .exporter import AstrBotExporter
from .importer import AstrBotImporter, ImportPreCheckResult
__all__ = [
"AstrBotExporter",
"AstrBotImporter",
"ImportPreCheckResult",
"MAIN_DB_MODELS",
"KB_METADATA_MODELS",
"get_backup_directories",
"BACKUP_MANIFEST_VERSION",
]
+77
View File
@@ -0,0 +1,77 @@
"""AstrBot 备份模块共享常量
此文件定义了导出器和导入器共享的常量,确保两端配置一致。
"""
from sqlmodel import SQLModel
from astrbot.core.db.po import (
Attachment,
CommandConfig,
CommandConflict,
ConversationV2,
Persona,
PlatformMessageHistory,
PlatformSession,
PlatformStat,
Preference,
)
from astrbot.core.knowledge_base.models import (
KBDocument,
KBMedia,
KnowledgeBase,
)
from astrbot.core.utils.astrbot_path import (
get_astrbot_config_path,
get_astrbot_plugin_data_path,
get_astrbot_plugin_path,
get_astrbot_t2i_templates_path,
get_astrbot_temp_path,
get_astrbot_webchat_path,
)
# ============================================================
# 共享常量 - 确保导出和导入端配置一致
# ============================================================
# 主数据库模型类映射
MAIN_DB_MODELS: dict[str, type[SQLModel]] = {
"platform_stats": PlatformStat,
"conversations": ConversationV2,
"personas": Persona,
"preferences": Preference,
"platform_message_history": PlatformMessageHistory,
"platform_sessions": PlatformSession,
"attachments": Attachment,
"command_configs": CommandConfig,
"command_conflicts": CommandConflict,
}
# 知识库元数据模型类映射
KB_METADATA_MODELS: dict[str, type[SQLModel]] = {
"knowledge_bases": KnowledgeBase,
"kb_documents": KBDocument,
"kb_media": KBMedia,
}
def get_backup_directories() -> dict[str, str]:
"""获取需要备份的目录列表
使用 astrbot_path 模块动态获取路径,支持通过环境变量 ASTRBOT_ROOT 自定义根目录。
Returns:
dict: 键为备份文件中的目录名称,值为目录的绝对路径
"""
return {
"plugins": get_astrbot_plugin_path(), # 插件本体
"plugin_data": get_astrbot_plugin_data_path(), # 插件数据
"config": get_astrbot_config_path(), # 配置目录
"t2i_templates": get_astrbot_t2i_templates_path(), # T2I 模板
"webchat": get_astrbot_webchat_path(), # WebChat 数据
"temp": get_astrbot_temp_path(), # 临时文件
}
# 备份清单版本号
BACKUP_MANIFEST_VERSION = "1.1"
+476
View File
@@ -0,0 +1,476 @@
"""AstrBot 数据导出器
负责将所有数据导出为 ZIP 备份文件。
导出格式为 JSON,这是数据库无关的方案,支持未来向 MySQL/PostgreSQL 迁移。
"""
import hashlib
import json
import os
import zipfile
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any
from sqlalchemy import select
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import (
get_astrbot_backups_path,
get_astrbot_data_path,
)
# 从共享常量模块导入
from .constants import (
BACKUP_MANIFEST_VERSION,
KB_METADATA_MODELS,
MAIN_DB_MODELS,
get_backup_directories,
)
if TYPE_CHECKING:
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
class AstrBotExporter:
"""AstrBot 数据导出器
导出内容:
- 主数据库所有表(data/data_v4.db
- 知识库元数据(data/knowledge_base/kb.db
- 每个知识库的向量文档数据
- 配置文件(data/cmd_config.json
- 附件文件
- 知识库多媒体文件
- 插件目录(data/plugins
- 插件数据目录(data/plugin_data
- 配置目录(data/config
- T2I 模板目录(data/t2i_templates
- WebChat 数据目录(data/webchat
- 临时文件目录(data/temp
"""
def __init__(
self,
main_db: BaseDatabase,
kb_manager: "KnowledgeBaseManager | None" = None,
config_path: str = CMD_CONFIG_FILE_PATH,
):
self.main_db = main_db
self.kb_manager = kb_manager
self.config_path = config_path
self._checksums: dict[str, str] = {}
async def export_all(
self,
output_dir: str | None = None,
progress_callback: Any | None = None,
) -> str:
"""导出所有数据到 ZIP 文件
Args:
output_dir: 输出目录
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
Returns:
str: 生成的 ZIP 文件路径
"""
if output_dir is None:
output_dir = get_astrbot_backups_path()
# 确保输出目录存在
Path(output_dir).mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
zip_filename = f"astrbot_backup_{timestamp}.zip"
zip_path = os.path.join(output_dir, zip_filename)
logger.info(f"开始导出备份到 {zip_path}")
try:
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
# 1. 导出主数据库
if progress_callback:
await progress_callback("main_db", 0, 100, "正在导出主数据库...")
main_data = await self._export_main_database()
main_db_json = json.dumps(
main_data, ensure_ascii=False, indent=2, default=str
)
zf.writestr("databases/main_db.json", main_db_json)
self._add_checksum("databases/main_db.json", main_db_json)
if progress_callback:
await progress_callback("main_db", 100, 100, "主数据库导出完成")
# 2. 导出知识库数据
kb_meta_data: dict[str, Any] = {
"knowledge_bases": [],
"kb_documents": [],
"kb_media": [],
}
if self.kb_manager:
if progress_callback:
await progress_callback(
"kb_metadata", 0, 100, "正在导出知识库元数据..."
)
kb_meta_data = await self._export_kb_metadata()
kb_meta_json = json.dumps(
kb_meta_data, ensure_ascii=False, indent=2, default=str
)
zf.writestr("databases/kb_metadata.json", kb_meta_json)
self._add_checksum("databases/kb_metadata.json", kb_meta_json)
if progress_callback:
await progress_callback(
"kb_metadata", 100, 100, "知识库元数据导出完成"
)
# 导出每个知识库的文档数据
kb_insts = self.kb_manager.kb_insts
total_kbs = len(kb_insts)
for idx, (kb_id, kb_helper) in enumerate(kb_insts.items()):
if progress_callback:
await progress_callback(
"kb_documents",
idx,
total_kbs,
f"正在导出知识库 {kb_helper.kb.kb_name} 的文档数据...",
)
doc_data = await self._export_kb_documents(kb_helper)
doc_json = json.dumps(
doc_data, ensure_ascii=False, indent=2, default=str
)
doc_path = f"databases/kb_{kb_id}/documents.json"
zf.writestr(doc_path, doc_json)
self._add_checksum(doc_path, doc_json)
# 导出 FAISS 索引文件
await self._export_faiss_index(zf, kb_helper, kb_id)
# 导出知识库多媒体文件
await self._export_kb_media_files(zf, kb_helper, kb_id)
if progress_callback:
await progress_callback(
"kb_documents", total_kbs, total_kbs, "知识库文档导出完成"
)
# 3. 导出配置文件
if progress_callback:
await progress_callback("config", 0, 100, "正在导出配置文件...")
if os.path.exists(self.config_path):
with open(self.config_path, encoding="utf-8") as f:
config_content = f.read()
zf.writestr("config/cmd_config.json", config_content)
self._add_checksum("config/cmd_config.json", config_content)
if progress_callback:
await progress_callback("config", 100, 100, "配置文件导出完成")
# 4. 导出附件文件
if progress_callback:
await progress_callback("attachments", 0, 100, "正在导出附件...")
await self._export_attachments(zf, main_data.get("attachments", []))
if progress_callback:
await progress_callback("attachments", 100, 100, "附件导出完成")
# 5. 导出插件和其他目录
if progress_callback:
await progress_callback(
"directories", 0, 100, "正在导出插件和数据目录..."
)
dir_stats = await self._export_directories(zf)
if progress_callback:
await progress_callback("directories", 100, 100, "目录导出完成")
# 6. 生成 manifest
if progress_callback:
await progress_callback("manifest", 0, 100, "正在生成清单...")
manifest = self._generate_manifest(main_data, kb_meta_data, dir_stats)
manifest_json = json.dumps(manifest, ensure_ascii=False, indent=2)
zf.writestr("manifest.json", manifest_json)
if progress_callback:
await progress_callback("manifest", 100, 100, "清单生成完成")
logger.info(f"备份导出完成: {zip_path}")
return zip_path
except Exception as e:
logger.error(f"备份导出失败: {e}")
# 清理失败的文件
if os.path.exists(zip_path):
os.remove(zip_path)
raise
async def _export_main_database(self) -> dict[str, list[dict]]:
"""导出主数据库所有表"""
export_data: dict[str, list[dict]] = {}
async with self.main_db.get_db() as session:
for table_name, model_class in MAIN_DB_MODELS.items():
try:
result = await session.execute(select(model_class))
records = result.scalars().all()
export_data[table_name] = [
self._model_to_dict(record) for record in records
]
logger.debug(
f"导出表 {table_name}: {len(export_data[table_name])} 条记录"
)
except Exception as e:
logger.warning(f"导出表 {table_name} 失败: {e}")
export_data[table_name] = []
return export_data
async def _export_kb_metadata(self) -> dict[str, list[dict]]:
"""导出知识库元数据库"""
if not self.kb_manager:
return {"knowledge_bases": [], "kb_documents": [], "kb_media": []}
export_data: dict[str, list[dict]] = {}
async with self.kb_manager.kb_db.get_db() as session:
for table_name, model_class in KB_METADATA_MODELS.items():
try:
result = await session.execute(select(model_class))
records = result.scalars().all()
export_data[table_name] = [
self._model_to_dict(record) for record in records
]
logger.debug(
f"导出知识库表 {table_name}: {len(export_data[table_name])} 条记录"
)
except Exception as e:
logger.warning(f"导出知识库表 {table_name} 失败: {e}")
export_data[table_name] = []
return export_data
async def _export_kb_documents(self, kb_helper: Any) -> dict[str, Any]:
"""导出知识库的文档块数据"""
try:
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
vec_db: FaissVecDB = kb_helper.vec_db
if not vec_db or not vec_db.document_storage:
return {"documents": []}
# 获取所有文档
docs = await vec_db.document_storage.get_documents(
metadata_filters={},
offset=0,
limit=None, # 获取全部
)
return {"documents": docs}
except Exception as e:
logger.warning(f"导出知识库文档失败: {e}")
return {"documents": []}
async def _export_faiss_index(
self,
zf: zipfile.ZipFile,
kb_helper: Any,
kb_id: str,
) -> None:
"""导出 FAISS 索引文件"""
try:
index_path = kb_helper.kb_dir / "index.faiss"
if index_path.exists():
archive_path = f"databases/kb_{kb_id}/index.faiss"
zf.write(str(index_path), archive_path)
logger.debug(f"导出 FAISS 索引: {archive_path}")
except Exception as e:
logger.warning(f"导出 FAISS 索引失败: {e}")
async def _export_kb_media_files(
self, zf: zipfile.ZipFile, kb_helper: Any, kb_id: str
) -> None:
"""导出知识库的多媒体文件"""
try:
media_dir = kb_helper.kb_medias_dir
if not media_dir.exists():
return
for root, _, files in os.walk(media_dir):
for file in files:
file_path = Path(root) / file
# 计算相对路径
rel_path = file_path.relative_to(kb_helper.kb_dir)
archive_path = f"files/kb_media/{kb_id}/{rel_path}"
zf.write(str(file_path), archive_path)
except Exception as e:
logger.warning(f"导出知识库媒体文件失败: {e}")
async def _export_directories(
self, zf: zipfile.ZipFile
) -> dict[str, dict[str, int]]:
"""导出插件和其他数据目录
Returns:
dict: 每个目录的统计信息 {dir_name: {"files": count, "size": bytes}}
"""
stats: dict[str, dict[str, int]] = {}
backup_directories = get_backup_directories()
for dir_name, dir_path in backup_directories.items():
full_path = Path(dir_path)
if not full_path.exists():
logger.debug(f"目录不存在,跳过: {full_path}")
continue
file_count = 0
total_size = 0
try:
for root, dirs, files in os.walk(full_path):
# 跳过 __pycache__ 目录
dirs[:] = [d for d in dirs if d != "__pycache__"]
for file in files:
# 跳过 .pyc 文件
if file.endswith(".pyc"):
continue
file_path = Path(root) / file
try:
# 计算相对路径
rel_path = file_path.relative_to(full_path)
archive_path = f"directories/{dir_name}/{rel_path}"
zf.write(str(file_path), archive_path)
file_count += 1
total_size += file_path.stat().st_size
except Exception as e:
logger.warning(f"导出文件 {file_path} 失败: {e}")
stats[dir_name] = {"files": file_count, "size": total_size}
logger.debug(
f"导出目录 {dir_name}: {file_count} 个文件, {total_size} 字节"
)
except Exception as e:
logger.warning(f"导出目录 {dir_path} 失败: {e}")
stats[dir_name] = {"files": 0, "size": 0}
return stats
async def _export_attachments(
self, zf: zipfile.ZipFile, attachments: list[dict]
) -> None:
"""导出附件文件"""
for attachment in attachments:
try:
file_path = attachment.get("path", "")
if file_path and os.path.exists(file_path):
# 使用 attachment_id 作为文件名
attachment_id = attachment.get("attachment_id", "")
ext = os.path.splitext(file_path)[1]
archive_path = f"files/attachments/{attachment_id}{ext}"
zf.write(file_path, archive_path)
except Exception as e:
logger.warning(f"导出附件失败: {e}")
def _model_to_dict(self, record: Any) -> dict:
"""将 SQLModel 实例转换为字典
这是数据库无关的序列化方式,支持未来迁移到其他数据库。
"""
# 使用 SQLModel 内置的 model_dump 方法(如果可用)
if hasattr(record, "model_dump"):
data = record.model_dump(mode="python")
# 处理 datetime 类型
for key, value in data.items():
if isinstance(value, datetime):
data[key] = value.isoformat()
return data
# 回退到手动提取
data = {}
# 使用 inspect 获取表信息
from sqlalchemy import inspect as sa_inspect
mapper = sa_inspect(record.__class__)
for column in mapper.columns:
value = getattr(record, column.name)
# 处理 datetime 类型 - 统一转为 ISO 格式字符串
if isinstance(value, datetime):
value = value.isoformat()
data[column.name] = value
return data
def _add_checksum(self, path: str, content: str | bytes) -> None:
"""计算并添加文件校验和"""
if isinstance(content, str):
content = content.encode("utf-8")
checksum = hashlib.sha256(content).hexdigest()
self._checksums[path] = f"sha256:{checksum}"
def _generate_manifest(
self,
main_data: dict[str, list[dict]],
kb_meta_data: dict[str, list[dict]],
dir_stats: dict[str, dict[str, int]] | None = None,
) -> dict:
"""生成备份清单"""
if dir_stats is None:
dir_stats = {}
# 收集知识库 ID
kb_document_tables = {}
if self.kb_manager:
for kb_id in self.kb_manager.kb_insts.keys():
kb_document_tables[kb_id] = "documents"
# 收集附件文件列表
attachment_files = []
for attachment in main_data.get("attachments", []):
attachment_id = attachment.get("attachment_id", "")
path = attachment.get("path", "")
if attachment_id and path:
ext = os.path.splitext(path)[1]
attachment_files.append(f"{attachment_id}{ext}")
# 收集知识库媒体文件
kb_media_files: dict[str, list[str]] = {}
if self.kb_manager:
for kb_id, kb_helper in self.kb_manager.kb_insts.items():
media_files: list[str] = []
media_dir = kb_helper.kb_medias_dir
if media_dir.exists():
for root, _, files in os.walk(media_dir):
for file in files:
media_files.append(file)
if media_files:
kb_media_files[kb_id] = media_files
manifest = {
"version": BACKUP_MANIFEST_VERSION,
"astrbot_version": VERSION,
"exported_at": datetime.now(timezone.utc).isoformat(),
"schema_version": {
"main_db": "v4",
"kb_db": "v1",
},
"tables": {
"main_db": list(main_data.keys()),
"kb_metadata": list(kb_meta_data.keys()),
"kb_documents": kb_document_tables,
},
"files": {
"attachments": attachment_files,
"kb_media": kb_media_files,
},
"directories": list(dir_stats.keys()),
"checksums": self._checksums,
"statistics": {
"main_db": {
table: len(records) for table, records in main_data.items()
},
"kb_metadata": {
table: len(records) for table, records in kb_meta_data.items()
},
"directories": dir_stats,
},
}
return manifest
+761
View File
@@ -0,0 +1,761 @@
"""AstrBot 数据导入器
负责从 ZIP 备份文件恢复所有数据。
导入时进行版本校验:
- 主版本(前两位)不同时直接拒绝导入
- 小版本(第三位)不同时提示警告,用户可选择强制导入
- 版本匹配时也需要用户确认
"""
import json
import os
import shutil
import zipfile
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any
from sqlalchemy import delete
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import (
get_astrbot_data_path,
get_astrbot_knowledge_base_path,
)
from astrbot.core.utils.version_comparator import VersionComparator
# 从共享常量模块导入
from .constants import (
KB_METADATA_MODELS,
MAIN_DB_MODELS,
get_backup_directories,
)
if TYPE_CHECKING:
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
def _get_major_version(version_str: str) -> str:
"""提取版本的主版本部分(前两位)
Args:
version_str: 版本字符串,如 "4.9.1", "4.10.0-beta"
Returns:
主版本字符串,如 "4.9", "4.10"
"""
if not version_str:
return "0.0"
# 移除 v 前缀和预发布标签
version = version_str.lower().replace("v", "").split("-")[0].split("+")[0]
parts = [p for p in version.split(".") if p] # 过滤空字符串
if len(parts) >= 2:
return f"{parts[0]}.{parts[1]}"
elif len(parts) == 1 and parts[0]:
return f"{parts[0]}.0"
return "0.0"
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
KB_PATH = get_astrbot_knowledge_base_path()
@dataclass
class ImportPreCheckResult:
"""导入预检查结果
用于在实际导入前检查备份文件的版本兼容性,
并返回确认信息让用户决定是否继续导入。
"""
# 检查是否通过(文件有效且版本可导入)
valid: bool = False
# 是否可以导入(版本兼容)
can_import: bool = False
# 版本状态: match(完全匹配), minor_diff(小版本差异), major_diff(主版本不同,拒绝)
version_status: str = ""
# 备份文件中的 AstrBot 版本
backup_version: str = ""
# 当前运行的 AstrBot 版本
current_version: str = VERSION
# 备份创建时间
backup_time: str = ""
# 确认消息(显示给用户)
confirm_message: str = ""
# 警告消息列表
warnings: list[str] = field(default_factory=list)
# 错误消息(如果检查失败)
error: str = ""
# 备份包含的内容摘要
backup_summary: dict = field(default_factory=dict)
def to_dict(self) -> dict:
return {
"valid": self.valid,
"can_import": self.can_import,
"version_status": self.version_status,
"backup_version": self.backup_version,
"current_version": self.current_version,
"backup_time": self.backup_time,
"confirm_message": self.confirm_message,
"warnings": self.warnings,
"error": self.error,
"backup_summary": self.backup_summary,
}
class ImportResult:
"""导入结果"""
def __init__(self):
self.success = True
self.imported_tables: dict[str, int] = {}
self.imported_files: dict[str, int] = {}
self.imported_directories: dict[str, int] = {}
self.warnings: list[str] = []
self.errors: list[str] = []
def add_warning(self, msg: str) -> None:
self.warnings.append(msg)
logger.warning(msg)
def add_error(self, msg: str) -> None:
self.errors.append(msg)
self.success = False
logger.error(msg)
def to_dict(self) -> dict:
return {
"success": self.success,
"imported_tables": self.imported_tables,
"imported_files": self.imported_files,
"imported_directories": self.imported_directories,
"warnings": self.warnings,
"errors": self.errors,
}
class AstrBotImporter:
"""AstrBot 数据导入器
导入备份文件中的所有数据,包括:
- 主数据库所有表
- 知识库元数据和文档
- 配置文件
- 附件文件
- 知识库多媒体文件
- 插件目录(data/plugins
- 插件数据目录(data/plugin_data
- 配置目录(data/config
- T2I 模板目录(data/t2i_templates
- WebChat 数据目录(data/webchat
- 临时文件目录(data/temp
"""
def __init__(
self,
main_db: BaseDatabase,
kb_manager: "KnowledgeBaseManager | None" = None,
config_path: str = CMD_CONFIG_FILE_PATH,
kb_root_dir: str = KB_PATH,
):
self.main_db = main_db
self.kb_manager = kb_manager
self.config_path = config_path
self.kb_root_dir = kb_root_dir
def pre_check(self, zip_path: str) -> ImportPreCheckResult:
"""预检查备份文件
在实际导入前检查备份文件的有效性和版本兼容性。
返回检查结果供前端显示确认对话框。
Args:
zip_path: ZIP 备份文件路径
Returns:
ImportPreCheckResult: 预检查结果
"""
result = ImportPreCheckResult()
result.current_version = VERSION
if not os.path.exists(zip_path):
result.error = f"备份文件不存在: {zip_path}"
return result
try:
with zipfile.ZipFile(zip_path, "r") as zf:
# 读取 manifest
try:
manifest_data = zf.read("manifest.json")
manifest = json.loads(manifest_data)
except KeyError:
result.error = "备份文件缺少 manifest.json,不是有效的 AstrBot 备份"
return result
except json.JSONDecodeError as e:
result.error = f"manifest.json 格式错误: {e}"
return result
# 提取基本信息
result.backup_version = manifest.get("astrbot_version", "未知")
result.backup_time = manifest.get("exported_at", "未知")
result.valid = True
# 构建备份摘要
result.backup_summary = {
"tables": list(manifest.get("tables", {}).keys()),
"has_knowledge_bases": manifest.get("has_knowledge_bases", False),
"has_config": manifest.get("has_config", False),
"directories": manifest.get("directories", []),
}
# 检查版本兼容性
version_check = self._check_version_compatibility(result.backup_version)
result.version_status = version_check["status"]
result.can_import = version_check["can_import"]
# 版本信息由前端根据 version_status 和 i18n 生成显示
# 不再将版本消息添加到 warnings 列表中,避免中文硬编码
# warnings 列表保留用于其他非版本相关的警告
return result
except zipfile.BadZipFile:
result.error = "无效的 ZIP 文件"
return result
except Exception as e:
result.error = f"检查备份文件失败: {e}"
return result
def _check_version_compatibility(self, backup_version: str) -> dict:
"""检查版本兼容性
规则:
- 主版本(前两位,如 4.9)必须一致,否则拒绝
- 小版本(第三位,如 4.9.1 vs 4.9.2)不同时,警告但允许导入
Returns:
dict: {status, can_import, message}
"""
if not backup_version:
return {
"status": "major_diff",
"can_import": False,
"message": "备份文件缺少版本信息",
}
# 提取主版本(前两位)进行比较
backup_major = _get_major_version(backup_version)
current_major = _get_major_version(VERSION)
# 比较主版本
if VersionComparator.compare_version(backup_major, current_major) != 0:
return {
"status": "major_diff",
"can_import": False,
"message": (
f"主版本不兼容: 备份版本 {backup_version}, 当前版本 {VERSION}"
f"跨主版本导入可能导致数据损坏,请使用相同主版本的 AstrBot。"
),
}
# 比较完整版本
version_cmp = VersionComparator.compare_version(backup_version, VERSION)
if version_cmp != 0:
return {
"status": "minor_diff",
"can_import": True,
"message": (
f"小版本差异: 备份版本 {backup_version}, 当前版本 {VERSION}"
),
}
return {
"status": "match",
"can_import": True,
"message": "版本匹配",
}
async def import_all(
self,
zip_path: str,
mode: str = "replace", # "replace" 清空后导入
progress_callback: Any | None = None,
) -> ImportResult:
"""从 ZIP 文件导入所有数据
Args:
zip_path: ZIP 备份文件路径
mode: 导入模式,目前仅支持 "replace"(清空后导入)
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
Returns:
ImportResult: 导入结果
"""
result = ImportResult()
if not os.path.exists(zip_path):
result.add_error(f"备份文件不存在: {zip_path}")
return result
logger.info(f"开始从 {zip_path} 导入备份")
try:
with zipfile.ZipFile(zip_path, "r") as zf:
# 1. 读取并验证 manifest
if progress_callback:
await progress_callback("validate", 0, 100, "正在验证备份文件...")
try:
manifest_data = zf.read("manifest.json")
manifest = json.loads(manifest_data)
except KeyError:
result.add_error("备份文件缺少 manifest.json")
return result
except json.JSONDecodeError as e:
result.add_error(f"manifest.json 格式错误: {e}")
return result
# 版本校验
try:
self._validate_version(manifest)
except ValueError as e:
result.add_error(str(e))
return result
if progress_callback:
await progress_callback("validate", 100, 100, "验证完成")
# 2. 导入主数据库
if progress_callback:
await progress_callback("main_db", 0, 100, "正在导入主数据库...")
try:
main_data_content = zf.read("databases/main_db.json")
main_data = json.loads(main_data_content)
if mode == "replace":
await self._clear_main_db()
imported = await self._import_main_database(main_data)
result.imported_tables.update(imported)
except Exception as e:
result.add_error(f"导入主数据库失败: {e}")
return result
if progress_callback:
await progress_callback("main_db", 100, 100, "主数据库导入完成")
# 3. 导入知识库
if self.kb_manager and "databases/kb_metadata.json" in zf.namelist():
if progress_callback:
await progress_callback("kb", 0, 100, "正在导入知识库...")
try:
kb_meta_content = zf.read("databases/kb_metadata.json")
kb_meta_data = json.loads(kb_meta_content)
if mode == "replace":
await self._clear_kb_data()
await self._import_knowledge_bases(zf, kb_meta_data, result)
except Exception as e:
result.add_warning(f"导入知识库失败: {e}")
if progress_callback:
await progress_callback("kb", 100, 100, "知识库导入完成")
# 4. 导入配置文件
if progress_callback:
await progress_callback("config", 0, 100, "正在导入配置文件...")
if "config/cmd_config.json" in zf.namelist():
try:
config_content = zf.read("config/cmd_config.json")
# 备份现有配置
if os.path.exists(self.config_path):
backup_path = f"{self.config_path}.bak"
shutil.copy2(self.config_path, backup_path)
with open(self.config_path, "wb") as f:
f.write(config_content)
result.imported_files["config"] = 1
except Exception as e:
result.add_warning(f"导入配置文件失败: {e}")
if progress_callback:
await progress_callback("config", 100, 100, "配置文件导入完成")
# 5. 导入附件文件
if progress_callback:
await progress_callback("attachments", 0, 100, "正在导入附件...")
attachment_count = await self._import_attachments(
zf, main_data.get("attachments", [])
)
result.imported_files["attachments"] = attachment_count
if progress_callback:
await progress_callback("attachments", 100, 100, "附件导入完成")
# 6. 导入插件和其他目录
if progress_callback:
await progress_callback(
"directories", 0, 100, "正在导入插件和数据目录..."
)
dir_stats = await self._import_directories(zf, manifest, result)
result.imported_directories = dir_stats
if progress_callback:
await progress_callback("directories", 100, 100, "目录导入完成")
logger.info(f"备份导入完成: {result.to_dict()}")
return result
except zipfile.BadZipFile:
result.add_error("无效的 ZIP 文件")
return result
except Exception as e:
result.add_error(f"导入失败: {e}")
return result
def _validate_version(self, manifest: dict) -> None:
"""验证版本兼容性 - 仅允许相同主版本导入
注意:此方法仅在 import_all 中调用,用于双重校验。
前端应先调用 pre_check 获取详细的版本信息并让用户确认。
"""
backup_version = manifest.get("astrbot_version")
if not backup_version:
raise ValueError("备份文件缺少版本信息")
# 使用新的版本兼容性检查
version_check = self._check_version_compatibility(backup_version)
if version_check["status"] == "major_diff":
raise ValueError(version_check["message"])
# minor_diff 和 match 都允许导入
if version_check["status"] == "minor_diff":
logger.warning(f"版本差异警告: {version_check['message']}")
async def _clear_main_db(self) -> None:
"""清空主数据库所有表"""
async with self.main_db.get_db() as session:
async with session.begin():
for table_name, model_class in MAIN_DB_MODELS.items():
try:
await session.execute(delete(model_class))
logger.debug(f"已清空表 {table_name}")
except Exception as e:
logger.warning(f"清空表 {table_name} 失败: {e}")
async def _clear_kb_data(self) -> None:
"""清空知识库数据"""
if not self.kb_manager:
return
# 清空知识库元数据表
async with self.kb_manager.kb_db.get_db() as session:
async with session.begin():
for table_name, model_class in KB_METADATA_MODELS.items():
try:
await session.execute(delete(model_class))
logger.debug(f"已清空知识库表 {table_name}")
except Exception as e:
logger.warning(f"清空知识库表 {table_name} 失败: {e}")
# 删除知识库文件目录
for kb_id in list(self.kb_manager.kb_insts.keys()):
try:
kb_helper = self.kb_manager.kb_insts[kb_id]
await kb_helper.terminate()
if kb_helper.kb_dir.exists():
shutil.rmtree(kb_helper.kb_dir)
except Exception as e:
logger.warning(f"清理知识库 {kb_id} 失败: {e}")
self.kb_manager.kb_insts.clear()
async def _import_main_database(
self, data: dict[str, list[dict]]
) -> dict[str, int]:
"""导入主数据库数据"""
imported: dict[str, int] = {}
async with self.main_db.get_db() as session:
async with session.begin():
for table_name, rows in data.items():
model_class = MAIN_DB_MODELS.get(table_name)
if not model_class:
logger.warning(f"未知的表: {table_name}")
continue
count = 0
for row in rows:
try:
# 转换 datetime 字符串为 datetime 对象
row = self._convert_datetime_fields(row, model_class)
obj = model_class(**row)
session.add(obj)
count += 1
except Exception as e:
logger.warning(f"导入记录到 {table_name} 失败: {e}")
imported[table_name] = count
logger.debug(f"导入表 {table_name}: {count} 条记录")
return imported
async def _import_knowledge_bases(
self,
zf: zipfile.ZipFile,
kb_meta_data: dict[str, list[dict]],
result: ImportResult,
) -> None:
"""导入知识库数据"""
if not self.kb_manager:
return
# 1. 导入知识库元数据
async with self.kb_manager.kb_db.get_db() as session:
async with session.begin():
for table_name, rows in kb_meta_data.items():
model_class = KB_METADATA_MODELS.get(table_name)
if not model_class:
continue
count = 0
for row in rows:
try:
row = self._convert_datetime_fields(row, model_class)
obj = model_class(**row)
session.add(obj)
count += 1
except Exception as e:
logger.warning(f"导入知识库记录到 {table_name} 失败: {e}")
result.imported_tables[f"kb_{table_name}"] = count
# 2. 导入每个知识库的文档和文件
for kb_data in kb_meta_data.get("knowledge_bases", []):
kb_id = kb_data.get("kb_id")
if not kb_id:
continue
# 创建知识库目录
kb_dir = Path(self.kb_root_dir) / kb_id
kb_dir.mkdir(parents=True, exist_ok=True)
# 导入文档数据
doc_path = f"databases/kb_{kb_id}/documents.json"
if doc_path in zf.namelist():
try:
doc_content = zf.read(doc_path)
doc_data = json.loads(doc_content)
# 导入到文档存储数据库
await self._import_kb_documents(kb_id, doc_data)
except Exception as e:
result.add_warning(f"导入知识库 {kb_id} 的文档失败: {e}")
# 导入 FAISS 索引
faiss_path = f"databases/kb_{kb_id}/index.faiss"
if faiss_path in zf.namelist():
try:
target_path = kb_dir / "index.faiss"
with zf.open(faiss_path) as src, open(target_path, "wb") as dst:
dst.write(src.read())
except Exception as e:
result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}")
# 导入媒体文件
media_prefix = f"files/kb_media/{kb_id}/"
for name in zf.namelist():
if name.startswith(media_prefix):
try:
rel_path = name[len(media_prefix) :]
target_path = kb_dir / rel_path
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src, open(target_path, "wb") as dst:
dst.write(src.read())
except Exception as e:
result.add_warning(f"导入媒体文件 {name} 失败: {e}")
# 3. 重新加载知识库实例
await self.kb_manager.load_kbs()
async def _import_kb_documents(self, kb_id: str, doc_data: dict) -> None:
"""导入知识库文档到向量数据库"""
from astrbot.core.db.vec_db.faiss_impl.document_storage import DocumentStorage
kb_dir = Path(self.kb_root_dir) / kb_id
doc_db_path = kb_dir / "doc.db"
# 初始化文档存储
doc_storage = DocumentStorage(str(doc_db_path))
await doc_storage.initialize()
try:
documents = doc_data.get("documents", [])
for doc in documents:
try:
await doc_storage.insert_document(
doc_id=doc.get("doc_id", ""),
text=doc.get("text", ""),
metadata=json.loads(doc.get("metadata", "{}")),
)
except Exception as e:
logger.warning(f"导入文档块失败: {e}")
finally:
await doc_storage.close()
async def _import_attachments(
self,
zf: zipfile.ZipFile,
attachments: list[dict],
) -> int:
"""导入附件文件"""
count = 0
attachments_dir = Path(self.config_path).parent / "attachments"
attachments_dir.mkdir(parents=True, exist_ok=True)
attachment_prefix = "files/attachments/"
for name in zf.namelist():
if name.startswith(attachment_prefix) and name != attachment_prefix:
try:
# 从附件记录中找到原始路径
attachment_id = os.path.splitext(os.path.basename(name))[0]
original_path = None
for att in attachments:
if att.get("attachment_id") == attachment_id:
original_path = att.get("path")
break
if original_path:
target_path = Path(original_path)
else:
target_path = attachments_dir / os.path.basename(name)
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src, open(target_path, "wb") as dst:
dst.write(src.read())
count += 1
except Exception as e:
logger.warning(f"导入附件 {name} 失败: {e}")
return count
async def _import_directories(
self,
zf: zipfile.ZipFile,
manifest: dict,
result: ImportResult,
) -> dict[str, int]:
"""导入插件和其他数据目录
Args:
zf: ZIP 文件对象
manifest: 备份清单
result: 导入结果对象
Returns:
dict: 每个目录导入的文件数量
"""
dir_stats: dict[str, int] = {}
# 检查备份版本是否支持目录备份(需要版本 >= 1.1)
backup_version = manifest.get("version", "1.0")
if VersionComparator.compare_version(backup_version, "1.1") < 0:
logger.info("备份版本不支持目录备份,跳过目录导入")
return dir_stats
backed_up_dirs = manifest.get("directories", [])
backup_directories = get_backup_directories()
for dir_name in backed_up_dirs:
if dir_name not in backup_directories:
result.add_warning(f"未知的目录类型: {dir_name}")
continue
target_dir = Path(backup_directories[dir_name])
archive_prefix = f"directories/{dir_name}/"
file_count = 0
try:
# 获取该目录下的所有文件
dir_files = [
name
for name in zf.namelist()
if name.startswith(archive_prefix) and name != archive_prefix
]
if not dir_files:
continue
# 备份现有目录(如果存在)
if target_dir.exists():
backup_path = Path(f"{target_dir}.bak")
if backup_path.exists():
shutil.rmtree(backup_path)
shutil.move(str(target_dir), str(backup_path))
logger.debug(f"已备份现有目录 {target_dir}{backup_path}")
# 创建目标目录
target_dir.mkdir(parents=True, exist_ok=True)
# 解压文件
for name in dir_files:
try:
# 计算相对路径
rel_path = name[len(archive_prefix) :]
if not rel_path: # 跳过目录条目
continue
target_path = target_dir / rel_path
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src, open(target_path, "wb") as dst:
dst.write(src.read())
file_count += 1
except Exception as e:
result.add_warning(f"导入文件 {name} 失败: {e}")
dir_stats[dir_name] = file_count
logger.debug(f"导入目录 {dir_name}: {file_count} 个文件")
except Exception as e:
result.add_warning(f"导入目录 {dir_name} 失败: {e}")
dir_stats[dir_name] = 0
return dir_stats
def _convert_datetime_fields(self, row: dict, model_class: type) -> dict:
"""转换 datetime 字符串字段为 datetime 对象"""
result = row.copy()
# 获取模型的 datetime 字段
from sqlalchemy import inspect as sa_inspect
try:
mapper = sa_inspect(model_class)
for column in mapper.columns:
if column.name in result and result[column.name] is not None:
# 检查是否是 datetime 类型的列
from sqlalchemy import DateTime
if isinstance(column.type, DateTime):
value = result[column.name]
if isinstance(value, str):
# 解析 ISO 格式的日期时间字符串
result[column.name] = datetime.fromisoformat(value)
except Exception:
pass
return result
+1 -1
View File
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.10.2"
VERSION = "4.10.3"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [
+1 -1
View File
@@ -58,7 +58,7 @@ def is_plugin_path(pathname):
return False
norm_path = os.path.normpath(pathname)
return ("data/plugins" in norm_path) or ("packages/" in norm_path)
return ("data/plugins" in norm_path) or ("astrbot/builtin_stars/" in norm_path)
def get_short_level_name(level_name):
@@ -390,7 +390,7 @@ class InternalAgentSubStage(Stage):
return
req.prompt = event.message_str[len(provider_wake_prefix) :]
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
# func_tool selection 现在已经转移到 astrbot/builtin_stars/astrbot 插件中进行选择。
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
+2 -1
View File
@@ -136,7 +136,8 @@ class WakingCheckStage(Stage):
):
if (
self.disable_builtin_commands
and handler.handler_module_path == "packages.builtin_commands.main"
and handler.handler_module_path
== "astrbot.builtin_stars.builtin_commands.main"
):
logger.debug("skipping builtin command")
continue
+32 -9
View File
@@ -14,6 +14,7 @@ import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.core.agent.message import (
AssistantMessageSegment,
ContentPart,
ToolCall,
ToolCallMessageSegment,
)
@@ -92,6 +93,8 @@ class ProviderRequest:
"""会话 ID"""
image_urls: list[str] = field(default_factory=list)
"""图片 URL 列表"""
extra_user_content_parts: list[ContentPart] = field(default_factory=list)
"""额外的用户消息内容部分列表,用于在用户消息后添加额外的内容块(如系统提醒、指令等)。支持 dict 或 ContentPart 对象"""
func_tool: ToolSet | None = None
"""可用的函数工具"""
contexts: list[dict] = field(default_factory=list)
@@ -166,13 +169,23 @@ class ProviderRequest:
async def assemble_context(self) -> dict:
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
# 构建内容块列表
content_blocks = []
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
if self.prompt and self.prompt.strip():
content_blocks.append({"type": "text", "text": self.prompt})
elif self.image_urls:
# 如果没有文本但有图片,添加占位文本
content_blocks.append({"type": "text", "text": "[图片]"})
# 2. 额外的内容块(系统提醒、指令等)
if self.extra_user_content_parts:
for part in self.extra_user_content_parts:
content_blocks.append(part.model_dump())
# 3. 图片内容
if self.image_urls:
user_content = {
"role": "user",
"content": [
{"type": "text", "text": self.prompt if self.prompt else "[图片]"},
],
}
for image_url in self.image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
@@ -185,11 +198,21 @@ class ProviderRequest:
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
content_blocks.append(
{"type": "image_url", "image_url": {"url": image_data}},
)
return user_content
return {"role": "user", "content": self.prompt}
# 只有当只有一个来自 prompt 的文本块且没有额外内容块时,才降级为简单格式以保持向后兼容
if (
len(content_blocks) == 1
and content_blocks[0]["type"] == "text"
and not self.extra_user_content_parts
and not self.image_urls
):
return {"role": "user", "content": content_blocks[0]["text"]}
# 否则返回多模态格式
return {"role": "user", "content": content_blocks}
async def _encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
+3 -1
View File
@@ -4,7 +4,7 @@ import os
from collections.abc import AsyncGenerator
from typing import TypeAlias, Union
from astrbot.core.agent.message import Message
from astrbot.core.agent.message import ContentPart, Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.provider.entities import (
LLMResponse,
@@ -103,6 +103,7 @@ class Provider(AbstractProvider):
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
**kwargs,
) -> LLMResponse:
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
@@ -114,6 +115,7 @@ class Provider(AbstractProvider):
tools: tool set
contexts: 上下文,和 prompt 二选一使用
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
extra_user_content_parts: 额外的内容块列表,用于在用户消息后添加额外的文本块(如系统提醒、指令等)
kwargs: 其他参数
Notes:
@@ -11,6 +11,7 @@ from anthropic.types.usage import Usage
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.agent.message import ContentPart, ImageURLPart, TextPart
from astrbot.core.provider.entities import LLMResponse, TokenUsage
from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.io import download_image_by_url
@@ -68,7 +69,7 @@ class ProviderAnthropic(Provider):
blocks = []
if isinstance(message["content"], str):
blocks.append({"type": "text", "text": message["content"]})
if "tool_calls" in message:
if "tool_calls" in message and isinstance(message["tool_calls"], list):
for tool_call in message["tool_calls"]:
blocks.append( # noqa: PERF401
{
@@ -132,6 +133,9 @@ class ProviderAnthropic(Provider):
extra_body = self.provider_config.get("custom_extra_body", {})
if "max_tokens" not in payloads:
payloads["max_tokens"] = 1024
completion = await self.client.messages.create(
**payloads, stream=False, extra_body=extra_body
)
@@ -181,6 +185,9 @@ class ProviderAnthropic(Provider):
usage = TokenUsage()
extra_body = self.provider_config.get("custom_extra_body", {})
if "max_tokens" not in payloads:
payloads["max_tokens"] = 1024
async with self.client.messages.stream(
**payloads, extra_body=extra_body
) as stream:
@@ -296,13 +303,16 @@ class ProviderAnthropic(Provider):
system_prompt=None,
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
contexts = []
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
@@ -342,21 +352,24 @@ class ProviderAnthropic(Provider):
async def text_chat_stream(
self,
prompt,
prompt=None,
session_id=None,
image_urls=...,
image_urls=None,
func_tool=None,
contexts=...,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
**kwargs,
):
if contexts is None:
contexts = []
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
@@ -388,15 +401,15 @@ class ProviderAnthropic(Provider):
async for llm_response in self._query_stream(payloads, func_tool):
yield llm_response
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
async def assemble_context(
self,
text: str,
image_urls: list[str] | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
):
"""组装上下文,支持文本和图片"""
if not image_urls:
return {"role": "user", "content": text}
content = []
content.append({"type": "text", "text": text})
for image_url in image_urls:
async def resolve_image_url(image_url: str) -> dict | None:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
@@ -408,28 +421,68 @@ class ProviderAnthropic(Provider):
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
return None
# Get mime type for the image
mime_type, _ = guess_type(image_url)
if not mime_type:
mime_type = "image/jpeg" # Default to JPEG if can't determine
content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": (
image_data.split("base64,")[1]
if "base64," in image_data
else image_data
),
},
return {
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": (
image_data.split("base64,")[1]
if "base64," in image_data
else image_data
),
},
)
}
content = []
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
if text:
content.append({"type": "text", "text": text})
elif image_urls:
# 如果没有文本但有图片,添加占位文本
content.append({"type": "text", "text": "[图片]"})
elif extra_user_content_parts:
# 如果只有额外内容块,也需要添加占位文本
content.append({"type": "text", "text": " "})
# 2. 额外的内容块(系统提醒、指令等)
if extra_user_content_parts:
for block in extra_user_content_parts:
if isinstance(block, TextPart):
content.append({"type": "text", "text": block.text})
elif isinstance(block, ImageURLPart):
image_dict = await resolve_image_url(block.image_url.url)
if image_dict:
content.append(image_dict)
else:
raise ValueError(f"不支持的额外内容块类型: {type(block)}")
# 3. 图片内容
if image_urls:
for image_url in image_urls:
image_dict = await resolve_image_url(image_url)
if image_dict:
content.append(image_dict)
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
if (
text
and not extra_user_content_parts
and not image_urls
and len(content) == 1
and content[0]["type"] == "text"
):
return {"role": "user", "content": content[0]["text"]}
# 否则返回多模态格式
return {"role": "user", "content": content}
async def encode_image_bs64(self, image_url: str) -> str:
@@ -56,10 +56,14 @@ class ProviderFishAudioTTSAPI(TTSProvider):
"api_base",
"https://api.fish-audio.cn/v1",
)
try:
self.timeout: int = int(provider_config.get("timeout", 20))
except ValueError:
self.timeout = 20
self.headers = {
"Authorization": f"Bearer {self.chosen_api_key}",
}
self.set_model(provider_config["model"])
self.set_model(provider_config.get("model", None))
async def _get_reference_id_by_character(self, character: str) -> str | None:
"""获取角色的reference_id
@@ -135,17 +139,21 @@ class ProviderFishAudioTTSAPI(TTSProvider):
path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav")
self.headers["content-type"] = "application/msgpack"
request = await self._generate_request(text)
async with AsyncClient(base_url=self.api_base).stream(
async with AsyncClient(base_url=self.api_base, timeout=self.timeout).stream(
"POST",
"/tts",
headers=self.headers,
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
) as response:
if response.headers["content-type"] == "audio/wav":
if response.status_code == 200 and response.headers.get(
"content-type", ""
).startswith("audio/"):
with open(path, "wb") as f:
async for chunk in response.aiter_bytes():
f.write(chunk)
return path
body = await response.aread()
text = body.decode("utf-8", errors="replace")
raise Exception(f"Fish Audio API请求失败: {text}")
error_bytes = await response.aread()
error_text = error_bytes.decode("utf-8", errors="replace")[:1024]
raise Exception(
f"Fish Audio API请求失败: 状态码 {response.status_code}, 响应内容: {error_text}"
)
+75 -26
View File
@@ -13,6 +13,7 @@ from google.genai.errors import APIError
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.agent.message import ContentPart, ImageURLPart, TextPart
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, TokenUsage
from astrbot.core.provider.func_tool_manager import ToolSet
@@ -680,13 +681,16 @@ class ProviderGoogleGenAI(Provider):
system_prompt=None,
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
contexts = []
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
@@ -732,13 +736,16 @@ class ProviderGoogleGenAI(Provider):
system_prompt=None,
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
if contexts is None:
contexts = []
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
@@ -797,33 +804,75 @@ class ProviderGoogleGenAI(Provider):
self.chosen_api_key = key
self._init_client()
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
async def assemble_context(
self,
text: str,
image_urls: list[str] | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
):
"""组装上下文。"""
if image_urls:
user_content = {
"role": "user",
"content": [{"type": "text", "text": text if text else "[图片]"}],
async def resolve_image_part(image_url: str) -> dict | None:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
return None
return {
"type": "image_url",
"image_url": {"url": image_data},
}
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
# 构建内容块列表
content_blocks = []
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
if text:
content_blocks.append({"type": "text", "text": text})
elif image_urls:
# 如果没有文本但有图片,添加占位文本
content_blocks.append({"type": "text", "text": "[图片]"})
elif extra_user_content_parts:
# 如果只有额外内容块,也需要添加占位文本
content_blocks.append({"type": "text", "text": " "})
# 2. 额外的内容块(系统提醒、指令等)
if extra_user_content_parts:
for part in extra_user_content_parts:
if isinstance(part, TextPart):
content_blocks.append({"type": "text", "text": part.text})
elif isinstance(part, ImageURLPart):
image_part = await resolve_image_part(part.image_url.url)
if image_part:
content_blocks.append(image_part)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
{
"type": "image_url",
"image_url": {"url": image_data},
},
)
return user_content
return {"role": "user", "content": text}
raise ValueError(f"不支持的额外内容块类型: {type(part)}")
# 3. 图片内容
if image_urls:
for image_url in image_urls:
image_part = await resolve_image_part(image_url)
if image_part:
content_blocks.append(image_part)
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
if (
text
and not extra_user_content_parts
and not image_urls
and len(content_blocks) == 1
and content_blocks[0]["type"] == "text"
):
return {"role": "user", "content": content_blocks[0]["text"]}
# 否则返回多模态格式
return {"role": "user", "content": content_blocks}
async def encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
+68 -25
View File
@@ -17,7 +17,7 @@ from openai.types.completion_usage import CompletionUsage
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.agent.message import Message
from astrbot.core.agent.message import ContentPart, ImageURLPart, Message, TextPart
from astrbot.core.agent.tool import ToolSet
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
@@ -348,6 +348,7 @@ class ProviderOpenAIOfficial(Provider):
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
**kwargs,
) -> tuple:
"""准备聊天所需的有效载荷和上下文"""
@@ -355,7 +356,9 @@ class ProviderOpenAIOfficial(Provider):
contexts = []
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
@@ -476,6 +479,7 @@ class ProviderOpenAIOfficial(Provider):
system_prompt=None,
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
**kwargs,
) -> LLMResponse:
payloads, context_query = await self._prepare_chat_payload(
@@ -485,6 +489,7 @@ class ProviderOpenAIOfficial(Provider):
system_prompt,
tool_calls_result,
model=model,
extra_user_content_parts=extra_user_content_parts,
**kwargs,
)
@@ -624,33 +629,71 @@ class ProviderOpenAIOfficial(Provider):
self,
text: str,
image_urls: list[str] | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
) -> dict:
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
if image_urls:
user_content = {
"role": "user",
"content": [{"type": "text", "text": text if text else "[图片]"}],
async def resolve_image_part(image_url: str) -> dict | None:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
return None
return {
"type": "image_url",
"image_url": {"url": image_data},
}
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
# 构建内容块列表
content_blocks = []
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
if text:
content_blocks.append({"type": "text", "text": text})
elif image_urls:
# 如果没有文本但有图片,添加占位文本
content_blocks.append({"type": "text", "text": "[图片]"})
elif extra_user_content_parts:
# 如果只有额外内容块,也需要添加占位文本
content_blocks.append({"type": "text", "text": " "})
# 2. 额外的内容块(系统提醒、指令等)
if extra_user_content_parts:
for part in extra_user_content_parts:
if isinstance(part, TextPart):
content_blocks.append({"type": "text", "text": part.text})
elif isinstance(part, ImageURLPart):
image_part = await resolve_image_part(part.image_url.url)
if image_part:
content_blocks.append(image_part)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
{
"type": "image_url",
"image_url": {"url": image_data},
},
)
return user_content
return {"role": "user", "content": text}
raise ValueError(f"不支持的额外内容块类型: {type(part)}")
# 3. 图片内容
if image_urls:
for image_url in image_urls:
image_part = await resolve_image_part(image_url)
if image_part:
content_blocks.append(image_part)
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
if (
text
and not extra_user_content_parts
and not image_urls
and len(content_blocks) == 1
and content_blocks[0]["type"] == "text"
):
return {"role": "user", "content": content_blocks[0]["text"]}
# 否则返回多模态格式
return {"role": "user", "content": content_blocks}
async def encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
+1 -1
View File
@@ -377,7 +377,7 @@ class Context:
if not module_path:
_parts = []
module_part = tool.__module__.split(".")
flags = ["packages", "plugins"]
flags = ["builtin_stars", "plugins"]
for i, part in enumerate(module_part):
_parts.append(part)
if part in flags and i + 1 < len(module_part):
+11 -13
View File
@@ -18,6 +18,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.provider.register import llm_tools
from astrbot.core.utils.astrbot_path import (
get_astrbot_config_path,
get_astrbot_path,
get_astrbot_plugin_path,
)
from astrbot.core.utils.io import remove_dir
@@ -49,13 +50,10 @@ class PluginManager:
"""存储插件的路径。即 data/plugins"""
self.plugin_config_path = get_astrbot_config_path()
"""存储插件配置的路径。data/config"""
self.reserved_plugin_path = os.path.abspath(
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"../../../packages",
),
self.reserved_plugin_path = os.path.join(
get_astrbot_path(), "astrbot", "builtin_stars"
)
"""保留插件的路径。在 packages 目录下"""
"""保留插件的路径。在 astrbot/builtin_stars 目录下"""
self.conf_schema_fname = "_conf_schema.json"
self.logo_fname = "logo.png"
"""插件配置 Schema 文件名"""
@@ -252,7 +250,7 @@ class PluginManager:
list[str]: 与该插件相关的模块名列表
"""
prefix = "packages." if is_reserved else "data.plugins."
prefix = "astrbot.builtin_stars." if is_reserved else "data.plugins."
return [
key
for key in list(sys.modules.keys())
@@ -270,7 +268,7 @@ class PluginManager:
可以基于模块名模式或插件目录名移除模块,用于清理插件相关的模块缓存
Args:
module_patterns: 要移除的模块名模式列表(例如 ["data.plugins", "packages"]
module_patterns: 要移除的模块名模式列表(例如 ["data.plugins", "astrbot.builtin_stars"]
root_dir_name: 插件根目录名,用于移除与该插件相关的所有模块
is_reserved: 插件是否为保留插件(影响模块路径前缀)
@@ -382,9 +380,9 @@ class PluginManager:
reserved = plugin_module.get(
"reserved",
False,
) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。
) # 是否是保留插件。目前在 astrbot/builtin_stars 目录下的都是保留插件。保留插件不可以卸载。
path = "data.plugins." if not reserved else "packages."
path = "data.plugins." if not reserved else "astrbot.builtin_stars."
path += root_dir_name + "." + module_str
# 检查是否需要载入指定的插件
@@ -829,7 +827,7 @@ class PluginManager:
if (
mp
and mp.startswith(plugin_module_path)
and not mp.endswith(("packages", "data.plugins"))
and not mp.endswith(("astrbot.builtin_stars", "data.plugins"))
):
to_remove.append(func_tool)
for func_tool in to_remove:
@@ -884,7 +882,7 @@ class PluginManager:
plugin.module_path
and mp
and plugin.module_path.startswith(mp)
and not mp.endswith(("packages", "data.plugins"))
and not mp.endswith(("astrbot.builtin_stars", "data.plugins"))
):
func_tool.active = False
if func_tool.name not in inactivated_llm_tools:
@@ -933,7 +931,7 @@ class PluginManager:
plugin.module_path
and mp
and plugin.module_path.startswith(mp)
and not mp.endswith(("packages", "data.plugins"))
and not mp.endswith(("astrbot.builtin_stars", "data.plugins"))
and func_tool.name in inactivated_llm_tools
):
inactivated_llm_tools.remove(func_tool.name)
+34
View File
@@ -5,6 +5,10 @@
数据目录路径:固定为根目录下的 data 目录
配置文件路径:固定为数据目录下的 config 目录
插件目录路径:固定为数据目录下的 plugins 目录
插件数据目录路径:固定为数据目录下的 plugin_data 目录
T2I 模板目录路径:固定为数据目录下的 t2i_templates 目录
WebChat 数据目录路径:固定为数据目录下的 webchat 目录
临时文件目录路径:固定为数据目录下的 temp 目录
"""
import os
@@ -37,3 +41,33 @@ def get_astrbot_config_path() -> str:
def get_astrbot_plugin_path() -> str:
"""获取Astrbot插件目录路径"""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins"))
def get_astrbot_plugin_data_path() -> str:
"""获取Astrbot插件数据目录路径"""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugin_data"))
def get_astrbot_t2i_templates_path() -> str:
"""获取Astrbot T2I 模板目录路径"""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "t2i_templates"))
def get_astrbot_webchat_path() -> str:
"""获取Astrbot WebChat 数据目录路径"""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "webchat"))
def get_astrbot_temp_path() -> str:
"""获取Astrbot临时文件目录路径"""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "temp"))
def get_astrbot_knowledge_base_path() -> str:
"""获取Astrbot知识库根目录路径"""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "knowledge_base"))
def get_astrbot_backups_path() -> str:
"""获取Astrbot备份目录路径"""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "backups"))
+2
View File
@@ -1,4 +1,5 @@
from .auth import AuthRoute
from .backup import BackupRoute
from .chat import ChatRoute
from .command import CommandRoute
from .config import ConfigRoute
@@ -17,6 +18,7 @@ from .update import UpdateRoute
__all__ = [
"AuthRoute",
"BackupRoute",
"ChatRoute",
"CommandRoute",
"ConfigRoute",
+589
View File
@@ -0,0 +1,589 @@
"""备份管理 API 路由"""
import asyncio
import os
import re
import traceback
import uuid
from datetime import datetime
from pathlib import Path
from quart import request, send_file
from astrbot.core import logger
from astrbot.core.backup.exporter import AstrBotExporter
from astrbot.core.backup.importer import AstrBotImporter
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import (
get_astrbot_backups_path,
get_astrbot_data_path,
)
from .route import Response, Route, RouteContext
def secure_filename(filename: str) -> str:
"""清洗文件名,移除路径遍历字符和危险字符
Args:
filename: 原始文件名
Returns:
安全的文件名
"""
# 跨平台处理:先将反斜杠替换为正斜杠,再取文件名
filename = filename.replace("\\", "/")
# 仅保留文件名部分,移除路径
filename = os.path.basename(filename)
# 替换路径遍历字符
filename = filename.replace("..", "_")
# 仅保留字母、数字、下划线、连字符、点
filename = re.sub(r"[^\w\-.]", "_", filename)
# 移除前导点(隐藏文件)和尾部点
filename = filename.strip(".")
# 如果文件名为空或只包含下划线,生成一个默认名称
if not filename or filename.replace("_", "") == "":
filename = "backup"
return filename
def generate_unique_filename(original_filename: str) -> str:
"""生成唯一的文件名,添加时间戳前缀
Args:
original_filename: 原始文件名(已清洗)
Returns:
唯一的文件名
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
name, ext = os.path.splitext(original_filename)
return f"uploaded_{timestamp}_{name}{ext}"
class BackupRoute(Route):
"""备份管理路由
提供备份导出、导入、列表等 API 接口
"""
def __init__(
self,
context: RouteContext,
db: BaseDatabase,
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
super().__init__(context)
self.db = db
self.core_lifecycle = core_lifecycle
self.backup_dir = get_astrbot_backups_path()
self.data_dir = get_astrbot_data_path()
# 任务状态跟踪
self.backup_tasks: dict[str, dict] = {}
self.backup_progress: dict[str, dict] = {}
# 注册路由
self.routes = {
"/backup/list": ("GET", self.list_backups),
"/backup/export": ("POST", self.export_backup),
"/backup/upload": ("POST", self.upload_backup), # 上传文件
"/backup/check": ("POST", self.check_backup), # 预检查
"/backup/import": ("POST", self.import_backup), # 确认导入
"/backup/progress": ("GET", self.get_progress),
"/backup/download": ("GET", self.download_backup),
"/backup/delete": ("POST", self.delete_backup),
}
self.register_routes()
def _init_task(self, task_id: str, task_type: str, status: str = "pending") -> None:
"""初始化任务状态"""
self.backup_tasks[task_id] = {
"type": task_type,
"status": status,
"result": None,
"error": None,
}
self.backup_progress[task_id] = {
"status": status,
"stage": "waiting",
"current": 0,
"total": 100,
"message": "",
}
def _set_task_result(
self,
task_id: str,
status: str,
result: dict | None = None,
error: str | None = None,
) -> None:
"""设置任务结果"""
if task_id in self.backup_tasks:
self.backup_tasks[task_id]["status"] = status
self.backup_tasks[task_id]["result"] = result
self.backup_tasks[task_id]["error"] = error
if task_id in self.backup_progress:
self.backup_progress[task_id]["status"] = status
def _update_progress(
self,
task_id: str,
*,
status: str | None = None,
stage: str | None = None,
current: int | None = None,
total: int | None = None,
message: str | None = None,
) -> None:
"""更新任务进度"""
if task_id not in self.backup_progress:
return
p = self.backup_progress[task_id]
if status is not None:
p["status"] = status
if stage is not None:
p["stage"] = stage
if current is not None:
p["current"] = current
if total is not None:
p["total"] = total
if message is not None:
p["message"] = message
def _make_progress_callback(self, task_id: str):
"""创建进度回调函数"""
async def _callback(stage: str, current: int, total: int, message: str = ""):
self._update_progress(
task_id,
status="processing",
stage=stage,
current=current,
total=total,
message=message,
)
return _callback
async def list_backups(self):
"""获取备份列表
Query 参数:
- page: 页码 (默认 1)
- page_size: 每页数量 (默认 20)
"""
try:
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
# 确保备份目录存在
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
# 获取所有备份文件
backup_files = []
for filename in os.listdir(self.backup_dir):
if filename.endswith(".zip") and filename.startswith("astrbot_backup_"):
file_path = os.path.join(self.backup_dir, filename)
stat = os.stat(file_path)
backup_files.append(
{
"filename": filename,
"size": stat.st_size,
"created_at": stat.st_mtime,
}
)
# 按创建时间倒序排序
backup_files.sort(key=lambda x: x["created_at"], reverse=True)
# 分页
start = (page - 1) * page_size
end = start + page_size
items = backup_files[start:end]
return (
Response()
.ok(
{
"items": items,
"total": len(backup_files),
"page": page,
"page_size": page_size,
}
)
.__dict__
)
except Exception as e:
logger.error(f"获取备份列表失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"获取备份列表失败: {e!s}").__dict__
async def export_backup(self):
"""创建备份
返回:
- task_id: 任务ID,用于查询导出进度
"""
try:
# 生成任务ID
task_id = str(uuid.uuid4())
# 初始化任务状态
self._init_task(task_id, "export", "pending")
# 启动后台导出任务
asyncio.create_task(self._background_export_task(task_id))
return (
Response()
.ok(
{
"task_id": task_id,
"message": "export task created, processing in background",
}
)
.__dict__
)
except Exception as e:
logger.error(f"创建备份失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"创建备份失败: {e!s}").__dict__
async def _background_export_task(self, task_id: str):
"""后台导出任务"""
try:
self._update_progress(task_id, status="processing", message="正在初始化...")
# 获取知识库管理器
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
exporter = AstrBotExporter(
main_db=self.db,
kb_manager=kb_manager,
config_path=os.path.join(self.data_dir, "cmd_config.json"),
)
# 创建进度回调
progress_callback = self._make_progress_callback(task_id)
# 执行导出
zip_path = await exporter.export_all(
output_dir=self.backup_dir,
progress_callback=progress_callback,
)
# 设置成功结果
self._set_task_result(
task_id,
"completed",
result={
"filename": os.path.basename(zip_path),
"path": zip_path,
"size": os.path.getsize(zip_path),
},
)
except Exception as e:
logger.error(f"后台导出任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(e))
async def upload_backup(self):
"""上传备份文件
将备份文件上传到服务器,返回保存的文件名。
上传后应调用 check_backup 进行预检查。
Form Data:
- file: 备份文件 (.zip)
返回:
- filename: 保存的文件名
"""
try:
files = await request.files
if "file" not in files:
return Response().error("缺少备份文件").__dict__
file = files["file"]
if not file.filename or not file.filename.endswith(".zip"):
return Response().error("请上传 ZIP 格式的备份文件").__dict__
# 清洗文件名并生成唯一名称,防止路径遍历和覆盖
safe_filename = secure_filename(file.filename)
unique_filename = generate_unique_filename(safe_filename)
# 保存上传的文件
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
zip_path = os.path.join(self.backup_dir, unique_filename)
await file.save(zip_path)
logger.info(
f"上传的备份文件已保存: {unique_filename} (原始名称: {file.filename})"
)
return (
Response()
.ok(
{
"filename": unique_filename,
"original_filename": file.filename,
"size": os.path.getsize(zip_path),
}
)
.__dict__
)
except Exception as e:
logger.error(f"上传备份文件失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"上传备份文件失败: {e!s}").__dict__
async def check_backup(self):
"""预检查备份文件
检查备份文件的版本兼容性,返回确认信息。
用户确认后调用 import_backup 执行导入。
JSON Body:
- filename: 已上传的备份文件名
返回:
- ImportPreCheckResult: 预检查结果
"""
try:
data = await request.json
filename = data.get("filename")
if not filename:
return Response().error("缺少 filename 参数").__dict__
# 安全检查 - 防止路径遍历
if ".." in filename or "/" in filename or "\\" in filename:
return Response().error("无效的文件名").__dict__
zip_path = os.path.join(self.backup_dir, filename)
if not os.path.exists(zip_path):
return Response().error(f"备份文件不存在: {filename}").__dict__
# 获取知识库管理器(用于构造 importer)
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
importer = AstrBotImporter(
main_db=self.db,
kb_manager=kb_manager,
config_path=os.path.join(self.data_dir, "cmd_config.json"),
)
# 执行预检查
check_result = importer.pre_check(zip_path)
return Response().ok(check_result.to_dict()).__dict__
except Exception as e:
logger.error(f"预检查备份文件失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"预检查备份文件失败: {e!s}").__dict__
async def import_backup(self):
"""执行备份导入
在用户确认后执行实际的导入操作。
需要先调用 upload_backup 上传文件,再调用 check_backup 预检查。
JSON Body:
- filename: 已上传的备份文件名(必填)
- confirmed: 用户已确认(必填,必须为 true)
返回:
- task_id: 任务ID,用于查询导入进度
"""
try:
data = await request.json
filename = data.get("filename")
confirmed = data.get("confirmed", False)
if not filename:
return Response().error("缺少 filename 参数").__dict__
if not confirmed:
return (
Response()
.error("请先确认导入。导入将会清空并覆盖现有数据,此操作不可撤销。")
.__dict__
)
# 安全检查 - 防止路径遍历
if ".." in filename or "/" in filename or "\\" in filename:
return Response().error("无效的文件名").__dict__
zip_path = os.path.join(self.backup_dir, filename)
if not os.path.exists(zip_path):
return Response().error(f"备份文件不存在: {filename}").__dict__
# 生成任务ID
task_id = str(uuid.uuid4())
# 初始化任务状态
self._init_task(task_id, "import", "pending")
# 启动后台导入任务
asyncio.create_task(self._background_import_task(task_id, zip_path))
return (
Response()
.ok(
{
"task_id": task_id,
"message": "import task created, processing in background",
}
)
.__dict__
)
except Exception as e:
logger.error(f"导入备份失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"导入备份失败: {e!s}").__dict__
async def _background_import_task(self, task_id: str, zip_path: str):
"""后台导入任务"""
try:
self._update_progress(task_id, status="processing", message="正在初始化...")
# 获取知识库管理器
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
importer = AstrBotImporter(
main_db=self.db,
kb_manager=kb_manager,
config_path=os.path.join(self.data_dir, "cmd_config.json"),
)
# 创建进度回调
progress_callback = self._make_progress_callback(task_id)
# 执行导入
result = await importer.import_all(
zip_path=zip_path,
mode="replace",
progress_callback=progress_callback,
)
# 设置结果
if result.success:
self._set_task_result(
task_id,
"completed",
result=result.to_dict(),
)
else:
self._set_task_result(
task_id,
"failed",
error="; ".join(result.errors),
)
except Exception as e:
logger.error(f"后台导入任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(e))
async def get_progress(self):
"""获取任务进度
Query 参数:
- task_id: 任务 ID (必填)
"""
try:
task_id = request.args.get("task_id")
if not task_id:
return Response().error("缺少参数 task_id").__dict__
if task_id not in self.backup_tasks:
return Response().error("找不到该任务").__dict__
task_info = self.backup_tasks[task_id]
status = task_info["status"]
response_data = {
"task_id": task_id,
"type": task_info["type"],
"status": status,
}
# 如果任务正在处理,返回进度信息
if status == "processing" and task_id in self.backup_progress:
response_data["progress"] = self.backup_progress[task_id]
# 如果任务完成,返回结果
if status == "completed":
response_data["result"] = task_info["result"]
# 如果任务失败,返回错误信息
if status == "failed":
response_data["error"] = task_info["error"]
return Response().ok(response_data).__dict__
except Exception as e:
logger.error(f"获取任务进度失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"获取任务进度失败: {e!s}").__dict__
async def download_backup(self):
"""下载备份文件
Query 参数:
- filename: 备份文件名 (必填)
"""
try:
filename = request.args.get("filename")
if not filename:
return Response().error("缺少参数 filename").__dict__
# 安全检查 - 防止路径遍历
if ".." in filename or "/" in filename or "\\" in filename:
return Response().error("无效的文件名").__dict__
file_path = os.path.join(self.backup_dir, filename)
if not os.path.exists(file_path):
return Response().error("备份文件不存在").__dict__
return await send_file(
file_path,
as_attachment=True,
attachment_filename=filename,
)
except Exception as e:
logger.error(f"下载备份失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"下载备份失败: {e!s}").__dict__
async def delete_backup(self):
"""删除备份文件
Body:
- filename: 备份文件名 (必填)
"""
try:
data = await request.json
filename = data.get("filename")
if not filename:
return Response().error("缺少参数 filename").__dict__
# 安全检查 - 防止路径遍历
if ".." in filename or "/" in filename or "\\" in filename:
return Response().error("无效的文件名").__dict__
file_path = os.path.join(self.backup_dir, filename)
if not os.path.exists(file_path):
return Response().error("备份文件不存在").__dict__
os.remove(file_path)
return Response().ok(message="删除备份成功").__dict__
except Exception as e:
logger.error(f"删除备份失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"删除备份失败: {e!s}").__dict__
+44 -10
View File
@@ -1,15 +1,26 @@
import asyncio
import json
import time
from collections.abc import AsyncGenerator
from typing import cast
from quart import Response as QuartResponse
from quart import make_response
from quart import make_response, request
from astrbot.core import LogBroker, logger
from .route import Response, Route, RouteContext
def _format_log_sse(log: dict, ts: float) -> str:
"""辅助函数:格式化 SSE 消息"""
payload = {
"type": "log",
**log,
}
return f"id: {ts}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"
class LogRoute(Route):
def __init__(self, context: RouteContext, log_broker: LogBroker) -> None:
super().__init__(context)
@@ -21,21 +32,44 @@ class LogRoute(Route):
methods=["GET"],
)
async def log(self):
async def _replay_cached_logs(
self, last_event_id: str
) -> AsyncGenerator[str, None]:
"""辅助生成器:重放缓存的日志"""
try:
last_ts = float(last_event_id)
cached_logs = list(self.log_broker.log_cache)
for log_item in cached_logs:
log_ts = float(log_item.get("time", 0))
if log_ts > last_ts:
yield _format_log_sse(log_item, log_ts)
except ValueError:
pass
except Exception as e:
logger.error(f"Log SSE 补发历史错误: {e}")
async def log(self) -> QuartResponse:
last_event_id = request.headers.get("Last-Event-ID")
async def stream():
queue = None
try:
if last_event_id:
async for event in self._replay_cached_logs(last_event_id):
yield event
queue = self.log_broker.register()
while True:
message = await queue.get()
payload = {
"type": "log",
**message, # see astrbot/core/log.py
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
current_ts = message.get("time", time.time())
yield _format_log_sse(message, current_ts)
except asyncio.CancelledError:
pass
except BaseException as e:
except Exception as e:
logger.error(f"Log SSE 连接错误: {e}")
finally:
if queue:
@@ -53,7 +87,7 @@ class LogRoute(Route):
},
),
)
response.timeout = None
response.timeout = None # type: ignore
return response
async def log_history(self):
@@ -69,6 +103,6 @@ class LogRoute(Route):
)
.__dict__
)
except BaseException as e:
except Exception as e:
logger.error(f"获取日志历史失败: {e}")
return Response().error(f"获取日志历史失败: {e}").__dict__
+8 -1
View File
@@ -19,6 +19,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import get_local_ip_addresses
from .routes import *
from .routes.backup import BackupRoute
from .routes.platform import PlatformRoute
from .routes.route import Response, RouteContext
from .routes.session_management import SessionManagementRoute
@@ -85,6 +86,7 @@ class AstrBotDashboard:
self.t2i_route = T2iRoute(self.context, core_lifecycle)
self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle)
self.platform_route = PlatformRoute(self.context, core_lifecycle)
self.backup_route = BackupRoute(self.context, db, core_lifecycle)
self.app.add_url_rule(
"/api/plug/<path:subpath>",
@@ -108,7 +110,12 @@ class AstrBotDashboard:
async def auth_middleware(self):
if not request.path.startswith("/api"):
return None
allowed_endpoints = ["/api/auth/login", "/api/file", "/api/platform/webhook"]
allowed_endpoints = [
"/api/auth/login",
"/api/file",
"/api/platform/webhook",
"/api/stat/start-time",
]
if any(request.path.startswith(prefix) for prefix in allowed_endpoints):
return None
# 声明 JWT
+18
View File
@@ -0,0 +1,18 @@
## What's Changed
### 修复
1. 修复 FishAudio TTS 不可用的问题;
2. 修复 Anthropic API Chat Provider 部分情况下请求报错的问题;
3. 修复部分情况下 WebUI 日志重建连接之后丢失日志的问题;
4. 修复部分情况下 /provider 指令报错 index out of range 的问题;
5. 修复通过 `uv` 或者 cli 方式启动 AstrBot,缺少所有内置插件的问题。
### 优化
1. 丢弃值为 None 的 `tool_call_id``tool_calls` 字段,提高接口兼容性。
### 新增
1. 支持备份 AstrBot 数据和导入数据功能(Beta)。入口:WebUi -> 设置 -> 备份。
2. text_chat 和 text_chat_stream 接口支持额外用户内容块参数 `extra_user_content_parts`,用于在用户消息后添加额外的内容块(如系统提醒、指令等)。
+1
View File
@@ -22,6 +22,7 @@
"axios-mock-adapter": "^1.22.0",
"chance": "1.1.11",
"date-fns": "2.30.0",
"event-source-polyfill": "^1.0.31",
"highlight.js": "^11.11.1",
"js-md5": "^0.8.3",
"katex": "^0.16.27",
@@ -0,0 +1,673 @@
<template>
<v-dialog v-model="isOpen" persistent max-width="700" scrollable>
<v-card>
<v-card-title class="d-flex align-center">
<v-icon class="mr-2">mdi-backup-restore</v-icon>
{{ t('features.settings.backup.dialog.title') }}
</v-card-title>
<v-card-text class="pa-6">
<!-- 选项卡 -->
<v-tabs v-model="activeTab" color="primary" class="mb-4">
<v-tab value="export">
<v-icon class="mr-2">mdi-export</v-icon>
{{ t('features.settings.backup.tabs.export') }}
</v-tab>
<v-tab value="import">
<v-icon class="mr-2">mdi-import</v-icon>
{{ t('features.settings.backup.tabs.import') }}
</v-tab>
<v-tab value="list">
<v-icon class="mr-2">mdi-format-list-bulleted</v-icon>
{{ t('features.settings.backup.tabs.list') }}
</v-tab>
</v-tabs>
<v-window v-model="activeTab">
<!-- 导出标签页 -->
<v-window-item value="export">
<div v-if="exportStatus === 'idle'" class="text-center py-8">
<v-icon size="64" color="primary" class="mb-4">mdi-cloud-upload</v-icon>
<h3 class="mb-4">{{ t('features.settings.backup.export.title') }}</h3>
<p class="mb-4 text-grey">{{ t('features.settings.backup.export.description') }}</p>
<v-alert type="info" variant="tonal" class="mb-4 text-left">
<template v-slot:prepend>
<v-icon>mdi-information</v-icon>
</template>
{{ t('features.settings.backup.export.includes') }}
</v-alert>
<v-btn color="primary" size="large" @click="startExport" :loading="exportStatus === 'processing'">
<v-icon class="mr-2">mdi-export</v-icon>
{{ t('features.settings.backup.export.button') }}
</v-btn>
</div>
<div v-else-if="exportStatus === 'processing'" class="text-center py-8">
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
<h3 class="mb-4">{{ t('features.settings.backup.export.processing') }}</h3>
<p class="text-grey">{{ exportProgress.message || t('features.settings.backup.export.wait') }}</p>
<v-progress-linear :model-value="exportProgress.current" :max="exportProgress.total" class="mt-4" color="primary"></v-progress-linear>
</div>
<div v-else-if="exportStatus === 'completed'" class="text-center py-8">
<v-icon size="64" color="success" class="mb-4">mdi-check-circle</v-icon>
<h3 class="mb-4">{{ t('features.settings.backup.export.completed') }}</h3>
<p class="mb-4">{{ exportResult?.filename }}</p>
<v-btn color="primary" @click="downloadBackup(exportResult?.filename)" class="mr-2">
<v-icon class="mr-2">mdi-download</v-icon>
{{ t('features.settings.backup.export.download') }}
</v-btn>
<v-btn color="grey" variant="text" @click="resetExport">
{{ t('features.settings.backup.export.another') }}
</v-btn>
</div>
<div v-else-if="exportStatus === 'failed'" class="text-center py-8">
<v-icon size="64" color="error" class="mb-4">mdi-alert-circle</v-icon>
<h3 class="mb-4">{{ t('features.settings.backup.export.failed') }}</h3>
<v-alert type="error" variant="tonal" class="mb-4">
{{ exportError }}
</v-alert>
<v-btn color="primary" @click="resetExport">
{{ t('features.settings.backup.export.retry') }}
</v-btn>
</div>
</v-window-item>
<!-- 导入标签页 -->
<v-window-item value="import">
<!-- 步骤1: 选择文件 -->
<div v-if="importStatus === 'idle'" class="py-4">
<v-alert type="warning" variant="tonal" class="mb-4">
<template v-slot:prepend>
<v-icon>mdi-alert</v-icon>
</template>
{{ t('features.settings.backup.import.warning') }}
</v-alert>
<v-file-input
v-model="importFile"
:label="t('features.settings.backup.import.selectFile')"
accept=".zip"
prepend-icon="mdi-file-upload"
show-size
class="mb-4"
></v-file-input>
<div class="d-flex justify-center">
<v-btn
color="primary"
size="large"
@click="uploadAndCheck"
:disabled="!importFile"
:loading="importStatus === 'uploading'"
>
<v-icon class="mr-2">mdi-upload</v-icon>
{{ t('features.settings.backup.import.uploadAndCheck') }}
</v-btn>
</div>
</div>
<!-- 步骤1.5: 上传中 -->
<div v-else-if="importStatus === 'uploading'" class="text-center py-8">
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
<h3 class="mb-4">{{ t('features.settings.backup.import.uploading') }}</h3>
<p class="text-grey">{{ t('features.settings.backup.import.uploadWait') }}</p>
</div>
<!-- 步骤2: 确认导入 -->
<div v-else-if="importStatus === 'confirm'" class="py-4">
<v-alert
:type="versionAlertType"
variant="tonal"
class="mb-4"
>
<template v-slot:prepend>
<v-icon>{{ versionAlertIcon }}</v-icon>
</template>
<div class="confirm-message">
<div class="text-h6 mb-2">{{ versionAlertTitle }}</div>
<div class="mb-2">
<strong>{{ t('features.settings.backup.import.version.backupVersion') }}:</strong> {{ checkResult?.backup_version }}<br>
<strong>{{ t('features.settings.backup.import.version.currentVersion') }}:</strong> {{ checkResult?.current_version }}
</div>
<div v-if="checkResult?.backup_time && checkResult?.backup_time !== '未知'" class="mb-2">
<strong>{{ t('features.settings.backup.import.version.backupTime') }}:</strong> {{ formatISODate(checkResult?.backup_time) }}
</div>
<div class="mt-3" style="white-space: pre-line;">{{ versionAlertMessage }}</div>
</div>
</v-alert>
<!-- 备份摘要 -->
<v-card variant="outlined" class="mb-4" v-if="checkResult?.backup_summary">
<v-card-title class="text-subtitle-1">
<v-icon class="mr-2">mdi-package-variant</v-icon>
{{ t('features.settings.backup.import.backupContents') }}
</v-card-title>
<v-card-text>
<div class="d-flex flex-wrap ga-2">
<v-chip v-if="checkResult.backup_summary.tables?.length" size="small" color="primary" variant="tonal" :ripple="false" class="non-interactive-chip">
{{ checkResult.backup_summary.tables.length }} {{ t('features.settings.backup.import.tables') }}
</v-chip>
<v-chip v-if="checkResult.backup_summary.has_knowledge_bases" size="small" color="success" variant="tonal" :ripple="false" class="non-interactive-chip">
{{ t('features.settings.backup.import.knowledgeBases') }}
</v-chip>
<v-chip v-if="checkResult.backup_summary.has_config" size="small" color="info" variant="tonal" :ripple="false" class="non-interactive-chip">
{{ t('features.settings.backup.import.configFiles') }}
</v-chip>
<v-chip v-for="dir in (checkResult.backup_summary.directories || [])" :key="dir" size="small" color="warning" variant="tonal" :ripple="false" class="non-interactive-chip">
{{ dir }}
</v-chip>
</div>
</v-card-text>
</v-card>
<!-- 警告信息 -->
<v-alert v-if="checkResult?.warnings?.length" type="warning" variant="tonal" class="mb-4">
<div v-for="(warning, idx) in checkResult.warnings" :key="idx">{{ warning }}</div>
</v-alert>
<div class="d-flex justify-center align-center mt-4" style="gap: 16px;">
<v-btn
color="grey-darken-1"
variant="outlined"
size="large"
@click="resetImport"
>
<v-icon class="mr-2">mdi-close</v-icon>
{{ t('core.common.cancel') }}
</v-btn>
<v-btn
v-if="checkResult?.can_import"
color="error"
size="large"
variant="flat"
@click="confirmImport"
>
<v-icon class="mr-2">mdi-alert</v-icon>
{{ t('features.settings.backup.import.confirmImport') }}
</v-btn>
</div>
</div>
<!-- 步骤3: 导入进行中 -->
<div v-else-if="importStatus === 'processing'" class="text-center py-8">
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
<h3 class="mb-4">{{ t('features.settings.backup.import.processing') }}</h3>
<p class="text-grey">{{ importProgress.message || t('features.settings.backup.import.wait') }}</p>
<v-progress-linear :model-value="importProgress.current" :max="importProgress.total" class="mt-4" color="primary"></v-progress-linear>
</div>
<div v-else-if="importStatus === 'completed'" class="text-center py-8">
<v-icon size="64" color="success" class="mb-4">mdi-check-circle</v-icon>
<h3 class="mb-4">{{ t('features.settings.backup.import.completed') }}</h3>
<v-alert type="info" variant="tonal" class="mb-4">
{{ t('features.settings.backup.import.restartRequired') }}
</v-alert>
<v-btn color="primary" @click="restartAstrBot" class="mr-2">
<v-icon class="mr-2">mdi-restart</v-icon>
{{ t('features.settings.backup.import.restartNow') }}
</v-btn>
<v-btn color="grey" variant="text" @click="resetImport">
{{ t('core.common.close') }}
</v-btn>
</div>
<div v-else-if="importStatus === 'failed'" class="text-center py-8">
<v-icon size="64" color="error" class="mb-4">mdi-alert-circle</v-icon>
<h3 class="mb-4">{{ t('features.settings.backup.import.failed') }}</h3>
<v-alert type="error" variant="tonal" class="mb-4">
{{ importError }}
</v-alert>
<v-btn color="primary" @click="resetImport">
{{ t('features.settings.backup.import.retry') }}
</v-btn>
</div>
</v-window-item>
<!-- 备份列表标签页 -->
<v-window-item value="list">
<div v-if="loadingList" class="text-center py-8">
<v-progress-circular indeterminate color="primary"></v-progress-circular>
</div>
<div v-else-if="backupList.length === 0" class="text-center py-8">
<v-icon size="64" color="grey" class="mb-4">mdi-folder-open-outline</v-icon>
<p class="text-grey">{{ t('features.settings.backup.list.empty') }}</p>
</div>
<v-list v-else lines="two">
<v-list-item
v-for="backup in backupList"
:key="backup.filename"
>
<template v-slot:prepend>
<v-icon color="primary">mdi-zip-box</v-icon>
</template>
<v-list-item-title>{{ backup.filename }}</v-list-item-title>
<v-list-item-subtitle>
{{ formatFileSize(backup.size) }} · {{ formatDate(backup.created_at) }}
</v-list-item-subtitle>
<template v-slot:append>
<v-btn icon="mdi-download" variant="text" size="small" @click="downloadBackup(backup.filename)"></v-btn>
<v-btn icon="mdi-delete" variant="text" size="small" color="error" @click="deleteBackup(backup.filename)"></v-btn>
</template>
</v-list-item>
</v-list>
<div class="d-flex justify-center mt-4">
<v-btn color="primary" variant="text" @click="loadBackupList">
<v-icon class="mr-2">mdi-refresh</v-icon>
{{ t('features.settings.backup.list.refresh') }}
</v-btn>
</div>
</v-window-item>
</v-window>
</v-card-text>
<v-card-actions class="px-6 py-4">
<v-spacer></v-spacer>
<v-btn color="grey" variant="text" @click="handleClose" :disabled="isProcessing">
{{ t('core.common.close') }}
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<WaitingForRestart ref="wfr"></WaitingForRestart>
</template>
<script setup>
import { ref, computed, watch } from 'vue'
import axios from 'axios'
import { useI18n } from '@/i18n/composables'
import WaitingForRestart from './WaitingForRestart.vue'
const { t } = useI18n()
const isOpen = ref(false)
const activeTab = ref('export')
const wfr = ref(null)
// 导出状态
const exportStatus = ref('idle') // idle, processing, completed, failed
const exportTaskId = ref(null)
const exportProgress = ref({ current: 0, total: 100, message: '' })
const exportResult = ref(null)
const exportError = ref('')
// 导入状态
const importStatus = ref('idle') // idle, uploading, confirm, processing, completed, failed
const importFile = ref(null)
const importTaskId = ref(null)
const importProgress = ref({ current: 0, total: 100, message: '' })
const importError = ref('')
const uploadedFilename = ref('') // 已上传的文件名
const checkResult = ref(null) // 预检查结果
// 备份列表
const loadingList = ref(false)
const backupList = ref([])
// 计算属性
const isProcessing = computed(() => {
return exportStatus.value === 'processing' || importStatus.value === 'processing'
})
// 版本检查相关的计算属性
const versionAlertType = computed(() => {
const status = checkResult.value?.version_status
if (status === 'major_diff') return 'error'
if (status === 'minor_diff') return 'warning'
return 'info'
})
const versionAlertIcon = computed(() => {
const status = checkResult.value?.version_status
if (status === 'major_diff') return 'mdi-close-circle'
if (status === 'minor_diff') return 'mdi-alert'
return 'mdi-check-circle'
})
const versionAlertTitle = computed(() => {
const status = checkResult.value?.version_status
if (status === 'major_diff') return t('features.settings.backup.import.version.majorDiffTitle')
if (status === 'minor_diff') return t('features.settings.backup.import.version.minorDiffTitle')
return t('features.settings.backup.import.version.matchTitle')
})
const versionAlertMessage = computed(() => {
const status = checkResult.value?.version_status
if (status === 'major_diff') return t('features.settings.backup.import.version.majorDiffMessage')
if (status === 'minor_diff') return t('features.settings.backup.import.version.minorDiffMessage')
return t('features.settings.backup.import.version.matchMessage')
})
// 监听对话框打开
watch(isOpen, (newVal) => {
if (newVal) {
loadBackupList()
} else {
resetAll()
}
})
// 监听标签页切换
watch(activeTab, (newVal) => {
if (newVal === 'list') {
loadBackupList()
}
})
// 加载备份列表
const loadBackupList = async () => {
loadingList.value = true
try {
const response = await axios.get('/api/backup/list')
if (response.data.status === 'ok') {
backupList.value = response.data.data.items || []
}
} catch (error) {
console.error('Failed to load backup list:', error)
} finally {
loadingList.value = false
}
}
// 开始导出
const startExport = async () => {
exportStatus.value = 'processing'
exportProgress.value = { current: 0, total: 100, message: '' }
try {
const response = await axios.post('/api/backup/export')
if (response.data.status === 'ok') {
exportTaskId.value = response.data.data.task_id
pollExportProgress()
} else {
throw new Error(response.data.message)
}
} catch (error) {
exportStatus.value = 'failed'
exportError.value = error.message || 'Export failed'
}
}
// 轮询导出进度
const pollExportProgress = async () => {
if (!exportTaskId.value) return
try {
const response = await axios.get('/api/backup/progress', {
params: { task_id: exportTaskId.value }
})
if (response.data.status === 'ok') {
const data = response.data.data
if (data.status === 'processing' && data.progress) {
exportProgress.value = {
current: data.progress.current || 0,
total: data.progress.total || 100,
message: data.progress.message || ''
}
setTimeout(pollExportProgress, 1000)
} else if (data.status === 'completed') {
exportStatus.value = 'completed'
exportResult.value = data.result
loadBackupList()
} else if (data.status === 'failed') {
exportStatus.value = 'failed'
exportError.value = data.error || 'Export failed'
} else {
setTimeout(pollExportProgress, 1000)
}
}
} catch (error) {
exportStatus.value = 'failed'
exportError.value = error.message || 'Failed to get export progress'
}
}
// 重置导出状态
const resetExport = () => {
exportStatus.value = 'idle'
exportTaskId.value = null
exportProgress.value = { current: 0, total: 100, message: '' }
exportResult.value = null
exportError.value = ''
}
// 上传并检查
const uploadAndCheck = async () => {
if (!importFile.value) return
importStatus.value = 'uploading'
try {
// 步骤1: 上传文件
const formData = new FormData()
formData.append('file', importFile.value)
const uploadResponse = await axios.post('/api/backup/upload', formData, {
headers: { 'Content-Type': 'multipart/form-data' }
})
if (uploadResponse.data.status !== 'ok') {
throw new Error(uploadResponse.data.message)
}
uploadedFilename.value = uploadResponse.data.data.filename
// 步骤2: 预检查
const checkResponse = await axios.post('/api/backup/check', {
filename: uploadedFilename.value
})
if (checkResponse.data.status !== 'ok') {
throw new Error(checkResponse.data.message)
}
checkResult.value = checkResponse.data.data
// 检查是否有效
if (!checkResult.value.valid) {
importStatus.value = 'failed'
importError.value = checkResult.value.error || t('features.settings.backup.import.invalidBackup')
return
}
// 显示确认对话框
importStatus.value = 'confirm'
} catch (error) {
importStatus.value = 'failed'
importError.value = error.response?.data?.message || error.message || 'Upload failed'
}
}
// 确认导入
const confirmImport = async () => {
if (!uploadedFilename.value) return
importStatus.value = 'processing'
importProgress.value = { current: 0, total: 100, message: '' }
try {
const response = await axios.post('/api/backup/import', {
filename: uploadedFilename.value,
confirmed: true
})
if (response.data.status === 'ok') {
importTaskId.value = response.data.data.task_id
pollImportProgress()
} else {
throw new Error(response.data.message)
}
} catch (error) {
importStatus.value = 'failed'
importError.value = error.response?.data?.message || error.message || 'Import failed'
}
}
// 轮询导入进度
const pollImportProgress = async () => {
if (!importTaskId.value) return
try {
const response = await axios.get('/api/backup/progress', {
params: { task_id: importTaskId.value }
})
if (response.data.status === 'ok') {
const data = response.data.data
if (data.status === 'processing' && data.progress) {
importProgress.value = {
current: data.progress.current || 0,
total: data.progress.total || 100,
message: data.progress.message || ''
}
setTimeout(pollImportProgress, 1000)
} else if (data.status === 'completed') {
importStatus.value = 'completed'
} else if (data.status === 'failed') {
importStatus.value = 'failed'
importError.value = data.error || 'Import failed'
} else {
setTimeout(pollImportProgress, 1000)
}
}
} catch (error) {
importStatus.value = 'failed'
importError.value = error.message || 'Failed to get import progress'
}
}
// 重置导入状态
const resetImport = () => {
importStatus.value = 'idle'
importFile.value = null
importTaskId.value = null
importProgress.value = { current: 0, total: 100, message: '' }
importError.value = ''
uploadedFilename.value = ''
checkResult.value = null
}
// 下载备份
const downloadBackup = async (filename) => {
try {
const response = await axios.get('/api/backup/download', {
params: { filename },
responseType: 'blob'
})
// 创建 Blob URL 并触发下载
const blob = new Blob([response.data], { type: 'application/zip' })
const url = window.URL.createObjectURL(blob)
const link = document.createElement('a')
link.href = url
link.download = filename
document.body.appendChild(link)
link.click()
document.body.removeChild(link)
window.URL.revokeObjectURL(url)
} catch (error) {
console.error('Download failed:', error)
alert(t('features.settings.backup.export.failed') + ': ' + (error.message || 'Unknown error'))
}
}
// 删除备份
const deleteBackup = async (filename) => {
if (!confirm(t('features.settings.backup.list.confirmDelete'))) return
try {
const response = await axios.post('/api/backup/delete', { filename })
if (response.data.status === 'ok') {
loadBackupList()
} else {
alert(response.data.message || 'Delete failed')
}
} catch (error) {
alert(error.message || 'Delete failed')
}
}
// 格式化文件大小
const formatFileSize = (bytes) => {
if (bytes === 0) return '0 B'
const k = 1024
const sizes = ['B', 'KB', 'MB', 'GB']
const i = Math.floor(Math.log(bytes) / Math.log(k))
return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]
}
// 格式化日期(从时间戳)
const formatDate = (timestamp) => {
return new Date(timestamp * 1000).toLocaleString()
}
// 格式化 ISO 日期字符串
const formatISODate = (isoString) => {
if (!isoString) return ''
try {
return new Date(isoString).toLocaleString()
} catch {
return isoString
}
}
// 重启 AstrBot
const restartAstrBot = () => {
axios.post('/api/stat/restart-core').then(() => {
if (wfr.value) {
wfr.value.check()
}
})
}
// 重置所有状态
const resetAll = () => {
resetExport()
resetImport()
activeTab.value = 'export'
}
// 关闭对话框
const handleClose = () => {
if (isProcessing.value) return
isOpen.value = false
}
// 打开对话框
const open = () => {
isOpen.value = true
}
defineExpose({ open })
</script>
<style scoped>
.v-list-item {
border-bottom: 1px solid rgba(0, 0, 0, 0.08);
}
.v-list-item:last-child {
border-bottom: none;
}
/* 禁用 Chip 的交互效果 */
.non-interactive-chip {
pointer-events: none;
cursor: default;
}
.non-interactive-chip:hover {
box-shadow: none !important;
}
</style>
@@ -1,12 +1,11 @@
<script setup>
import { useCommonStore } from '@/stores/common';
import { storeToRefs } from 'pinia';
import axios from 'axios';
import { EventSourcePolyfill } from 'event-source-polyfill';
</script>
<template>
<div>
<!-- 添加筛选级别控件 -->
<div class="filter-controls mb-2" v-if="showLevelBtns">
<v-chip-group v-model="selectedLevels" column multiple>
<v-chip v-for="level in logLevels" :key="level" :color="getLevelColor(level)" filter variant="flat" size="small"
@@ -26,20 +25,19 @@ export default {
name: 'ConsoleDisplayer',
data() {
return {
autoScroll: true, // 默认开启自动滚动
autoScroll: true,
logColorAnsiMap: {
'\u001b[1;34m': 'color: #0000FF; font-weight: bold;', // bold_blue
'\u001b[1;36m': 'color: #00FFFF; font-weight: bold;', // bold_cyan
'\u001b[1;33m': 'color: #FFFF00; font-weight: bold;', // bold_yellow
'\u001b[31m': 'color: #FF0000;', // red
'\u001b[1;31m': 'color: #FF0000; font-weight: bold;', // bold_red
'\u001b[0m': 'color: inherit; font-weight: normal;', // reset
'\u001b[32m': 'color: #00FF00;', // green
'\u001b[1;34m': 'color: #0000FF; font-weight: bold;',
'\u001b[1;36m': 'color: #00FFFF; font-weight: bold;',
'\u001b[1;33m': 'color: #FFFF00; font-weight: bold;',
'\u001b[31m': 'color: #FF0000;',
'\u001b[1;31m': 'color: #FF0000; font-weight: bold;',
'\u001b[0m': 'color: inherit; font-weight: normal;',
'\u001b[32m': 'color: #00FF00;',
'default': 'color: #FFFFFF;'
},
historyNum_: -1,
logLevels: ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
selectedLevels: [0, 1, 2, 3, 4], // 默认选中所有级别
selectedLevels: [0, 1, 2, 3, 4],
levelColors: {
'DEBUG': 'grey',
'INFO': 'blue-lighten-3',
@@ -47,17 +45,19 @@ export default {
'ERROR': 'red',
'CRITICAL': 'purple'
},
lastProcessedTime: 0, // 记录最后处理的日志时间戳
localLogCache: [], // 本地日志缓存
localLogCache: [],
eventSource: null,
retryTimer: null,
retryAttempts: 0,
maxRetryAttempts: 10,
baseRetryDelay: 1000,
lastEventId: null,
}
},
computed: {
commonStore() {
return useCommonStore();
},
logCache() {
return this.commonStore.log_cache;
}
},
props: {
historyNum: {
@@ -70,41 +70,6 @@ export default {
}
},
watch: {
logCache: {
handler(newVal) {
// 基于 timestamp 处理新增的日志
if (newVal && newVal.length > 0) {
// 确保 DOM 已经准备好
this.$nextTick(() => {
// 合并到本地缓存并按时间排序
const newLogs = newVal.filter(log => log.time > this.lastProcessedTime);
if (newLogs.length > 0) {
this.localLogCache.push(...newLogs);
// 按时间戳排序
this.localLogCache.sort((a, b) => a.time - b.time);
// 只保留最新的 log_cache_max_len 条
if (this.localLogCache.length > this.commonStore.log_cache_max_len) {
this.localLogCache.splice(0, this.localLogCache.length - this.commonStore.log_cache_max_len);
}
// 显示新日志
newLogs.forEach(logItem => {
if (this.isLevelSelected(logItem.level)) {
this.printLog(logItem.data);
}
});
// 更新最后处理时间
this.lastProcessedTime = Math.max(...newLogs.map(log => log.time));
}
});
}
},
deep: true,
immediate: false
},
selectedLevels: {
handler() {
this.refreshDisplay();
@@ -113,30 +78,142 @@ export default {
}
},
async mounted() {
// 请求历史日志
await this.fetchLogHistory();
// 等待 DOM 准备好后,显示历史日志
this.$nextTick(() => {
if (this.localLogCache.length > 0) {
this.localLogCache.forEach(logItem => {
if (this.isLevelSelected(logItem.level)) {
this.printLog(logItem.data);
}
});
// 更新最后处理时间
this.lastProcessedTime = Math.max(...this.localLogCache.map(log => log.time));
}
});
this.connectSSE();
},
beforeUnmount() {
if (this.eventSource) {
this.eventSource.close();
this.eventSource = null;
}
if (this.retryTimer) {
clearTimeout(this.retryTimer);
this.retryTimer = null;
}
this.retryAttempts = 0;
},
methods: {
connectSSE() {
if (this.eventSource) {
this.eventSource.close();
this.eventSource = null;
}
console.log(`正在连接日志流... (尝试次数: ${this.retryAttempts})`);
const token = localStorage.getItem('token');
this.eventSource = new EventSourcePolyfill('/api/live-log', {
headers: {
'Authorization': token ? `Bearer ${token}` : ''
},
heartbeatTimeout: 300000,
withCredentials: true
});
this.eventSource.onopen = () => {
console.log('日志流连接成功!');
this.retryAttempts = 0;
if (!this.lastEventId) {
this.fetchLogHistory();
}
};
this.eventSource.onmessage = (event) => {
try {
if (event.lastEventId) {
this.lastEventId = event.lastEventId;
}
const payload = JSON.parse(event.data);
this.processNewLogs([payload]);
} catch (e) {
console.error('解析日志失败:', e);
}
};
this.eventSource.onerror = (err) => {
if (err.status === 401) {
console.error('鉴权失败 (401),可能是 Token 过期了。');
} else {
console.warn('日志流连接错误:', err);
}
if (this.eventSource) {
this.eventSource.close();
this.eventSource = null;
}
if (this.retryAttempts >= this.maxRetryAttempts) {
console.error('❌ 已达到最大重试次数,停止重连。请刷新页面重试。');
return;
}
const delay = Math.min(
this.baseRetryDelay * Math.pow(2, this.retryAttempts),
30000
);
console.log(`${delay}ms 后尝试第 ${this.retryAttempts + 1} 次重连...`);
if (this.retryTimer) {
clearTimeout(this.retryTimer);
this.retryTimer = null;
}
this.retryTimer = setTimeout(async () => {
this.retryAttempts++;
if (!this.lastEventId) {
await this.fetchLogHistory();
}
this.connectSSE();
}, delay);
};
},
processNewLogs(newLogs) {
if (!newLogs || newLogs.length === 0) return;
let hasUpdate = false;
newLogs.forEach(log => {
const exists = this.localLogCache.some(existing =>
existing.time === log.time &&
existing.data === log.data &&
existing.level === log.level
);
if (!exists) {
this.localLogCache.push(log);
hasUpdate = true;
if (this.isLevelSelected(log.level)) {
this.printLog(log.data);
}
}
});
if (hasUpdate) {
this.localLogCache.sort((a, b) => a.time - b.time);
const maxSize = this.commonStore.log_cache_max_len || 200;
if (this.localLogCache.length > maxSize) {
this.localLogCache.splice(0, this.localLogCache.length - maxSize);
}
}
},
async fetchLogHistory() {
try {
const res = await axios.get('/api/log-history');
if (res.data.data.logs && res.data.data.logs.length > 0) {
this.localLogCache = [...res.data.data.logs];
// 按时间戳排序
this.localLogCache.sort((a, b) => a.time - b.time);
this.processNewLogs(res.data.data.logs);
}
} catch (err) {
console.error('Failed to fetch log history:', err);
@@ -162,7 +239,6 @@ export default {
if (termElement) {
termElement.innerHTML = '';
// 重新显示所有符合筛选条件的日志
if (this.localLogCache && this.localLogCache.length > 0) {
this.localLogCache.forEach(logItem => {
if (this.isLevelSelected(logItem.level)) {
@@ -173,16 +249,13 @@ export default {
}
},
toggleAutoScroll() {
this.autoScroll = !this.autoScroll;
},
printLog(log) {
// append 一个 span 标签到 termblock 的方式
let ele = document.getElementById('term')
if (!ele) {
console.warn('term element not found, skipping log print');
return;
}
@@ -196,11 +269,11 @@ export default {
}
}
span.style = style + 'display: block; font-size: 12px; font-family: Consolas, monospace; white-space: pre-wrap;'
span.style = style + 'display: block; font-size: 12px; font-family: Consolas, monospace; white-space: pre-wrap; margin-bottom: 2px;'
span.classList.add('fade-in')
span.innerText = `${log}`;
ele.appendChild(span)
if (this.autoScroll ) {
if (this.autoScroll) {
ele.scrollTop = ele.scrollHeight
}
}
@@ -230,4 +303,4 @@ export default {
opacity: 1;
}
}
</style>
</style>
@@ -18,6 +18,11 @@
"title": "Data Migration to v4.0.0",
"subtitle": "If you encounter data compatibility issues, you can manually start the database migration assistant",
"button": "Start Migration Assistant"
},
"backup": {
"title": "Backup & Restore",
"subtitle": "Export or import all AstrBot data for easy migration to a new server",
"button": "Backup Manager"
}
},
"sidebar": {
@@ -29,5 +34,66 @@
"mainItems": "Main Modules",
"moreItems": "More Features"
}
},
"backup": {
"dialog": {
"title": "Backup Manager"
},
"tabs": {
"export": "Export Backup",
"import": "Import Backup",
"list": "Backup List"
},
"export": {
"title": "Create Backup",
"description": "Export all data as a ZIP backup file, including database, knowledge base, config and attachments.",
"includes": "Backup includes: Main database, Knowledge bases (metadata + vector index + documents), Config files, Attachment files",
"button": "Start Export",
"processing": "Exporting...",
"wait": "Please wait, packaging data...",
"completed": "Export Completed!",
"download": "Download Backup",
"another": "Create New Backup",
"failed": "Export Failed",
"retry": "Retry"
},
"import": {
"title": "Import Backup",
"warning": "⚠️ Import will clear and overwrite existing data! Please make sure you have backed up your current data.",
"selectFile": "Select backup file (.zip)",
"uploadAndCheck": "Upload & Check",
"uploading": "Uploading...",
"uploadWait": "Please wait, uploading backup file...",
"invalidBackup": "Invalid backup file",
"backupContents": "Backup Contents",
"tables": "tables",
"knowledgeBases": "Knowledge Bases",
"configFiles": "Config Files",
"confirmImport": "Confirm Import",
"button": "Start Import",
"processing": "Importing...",
"wait": "Please wait, restoring data...",
"completed": "Import Completed!",
"restartRequired": "Data has been successfully imported. It is recommended to restart AstrBot immediately for all changes to take effect.",
"restartNow": "Restart Now",
"failed": "Import Failed",
"retry": "Retry",
"version": {
"backupVersion": "Backup Version",
"currentVersion": "Current Version",
"backupTime": "Backup Time",
"matchTitle": "✅ Version Match",
"matchMessage": "Import will clear and overwrite all existing data, including:\n• Main database (conversations, settings, etc.)\n• Knowledge bases\n• Plugins and plugin data\n• Configuration files\n\nThis action cannot be undone! Do you want to continue?",
"minorDiffTitle": "⚠️ Version Difference Warning",
"minorDiffMessage": "Minor version differences are usually compatible, but there may be some data structure changes.\nImport will clear and overwrite all existing data!\n\nDo you want to continue?",
"majorDiffTitle": "⛔ Cannot Import",
"majorDiffMessage": "Major version numbers are different. Cross-major-version import may cause data corruption.\nPlease use the same major version of AstrBot for import."
}
},
"list": {
"empty": "No backup files",
"refresh": "Refresh List",
"confirmDelete": "Are you sure you want to delete this backup file? This action cannot be undone."
}
}
}
}
@@ -18,6 +18,11 @@
"title": "数据迁移到 v4.0.0 格式",
"subtitle": "如果您遇到数据兼容性问题,可以手动启动数据库迁移助手",
"button": "启动迁移助手"
},
"backup": {
"title": "数据备份与恢复",
"subtitle": "导出或导入 AstrBot 的所有数据,方便迁移到新服务器",
"button": "备份管理"
}
},
"sidebar": {
@@ -29,5 +34,66 @@
"mainItems": "主要模块",
"moreItems": "更多功能"
}
},
"backup": {
"dialog": {
"title": "备份管理"
},
"tabs": {
"export": "导出备份",
"import": "导入备份",
"list": "备份列表"
},
"export": {
"title": "创建备份",
"description": "将所有数据导出为 ZIP 备份文件,包括数据库、知识库、配置和附件。",
"includes": "备份包含:主数据库、知识库(元数据+向量索引+文档)、配置文件、附件文件",
"button": "开始导出",
"processing": "正在导出...",
"wait": "请稍候,正在打包数据...",
"completed": "导出完成!",
"download": "下载备份",
"another": "创建新备份",
"failed": "导出失败",
"retry": "重试"
},
"import": {
"title": "导入备份",
"warning": "⚠️ 导入将会清空并覆盖现有数据!请确保已备份当前数据。",
"selectFile": "选择备份文件 (.zip)",
"uploadAndCheck": "上传并检查",
"uploading": "正在上传...",
"uploadWait": "请稍候,正在上传备份文件...",
"invalidBackup": "无效的备份文件",
"backupContents": "备份内容",
"tables": "个数据表",
"knowledgeBases": "知识库",
"configFiles": "配置文件",
"confirmImport": "确认导入",
"button": "开始导入",
"processing": "正在导入...",
"wait": "请稍候,正在恢复数据...",
"completed": "导入完成!",
"restartRequired": "数据已成功导入。建议立即重启 AstrBot 以使所有更改生效。",
"restartNow": "立即重启",
"failed": "导入失败",
"retry": "重试",
"version": {
"backupVersion": "备份版本",
"currentVersion": "当前版本",
"backupTime": "备份时间",
"matchTitle": "✅ 版本匹配",
"matchMessage": "导入将会清空并覆盖现有的所有数据,包括:\n• 主数据库(对话记录、配置等)\n• 知识库数据\n• 插件及插件数据\n• 配置文件\n\n此操作不可撤销!是否确认继续?",
"minorDiffTitle": "⚠️ 版本差异警告",
"minorDiffMessage": "小版本差异通常是兼容的,但可能存在少量数据结构变化。\n导入将会清空并覆盖现有的所有数据!\n\n是否确认继续导入?",
"majorDiffTitle": "⛔ 无法导入",
"majorDiffMessage": "主版本号不同,跨主版本导入可能导致数据损坏。\n请使用相同主版本的 AstrBot 进行导入。"
}
},
"list": {
"empty": "暂无备份文件",
"refresh": "刷新列表",
"confirmDelete": "确定要删除这个备份文件吗?此操作不可撤销。"
}
}
}
}
+23 -3
View File
@@ -21,10 +21,14 @@ export const useCommonStore = defineStore({
}
const controller = new AbortController();
const { signal } = controller;
// 注意:这里如果之前改过 Polyfill 的话,可能需要保持原样
// 如果是用 fetch 的话,这里是支持 Authorization Header 的
const headers = {
'Content-Type': 'multipart/form-data',
'Authorization': 'Bearer ' + localStorage.getItem('token')
};
fetch('/api/live-log', {
method: 'GET',
headers,
@@ -72,10 +76,20 @@ export const useCommonStore = defineStore({
try {
const logObject = JSON.parse(logLine);
// give a uuid if not exists
// 修复:兼容 HTTP 环境的 UUID 生成
if (!logObject.uuid) {
logObject.uuid = crypto.randomUUID();
if (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') {
logObject.uuid = crypto.randomUUID();
} else {
// 手动生成 UUID v4
logObject.uuid = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, function(c) {
var r = Math.random() * 16 | 0, v = c == 'x' ? r : (r & 0x3 | 0x8);
return v.toString(16);
});
}
}
this.log_cache.push(logObject);
// Limit log cache size
if (this.log_cache.length > this.log_cache_max_len) {
@@ -93,7 +107,13 @@ export const useCommonStore = defineStore({
}).catch(error => {
console.error('SSE error:', error);
// Attempt to reconnect after a delay
this.log_cache.push('SSE Connection failed, retrying in 5 seconds...');
this.log_cache.push({
type: 'log',
level: 'ERROR',
time: Date.now() / 1000,
data: 'SSE Connection failed, retrying in 5 seconds...',
uuid: 'error-' + Date.now()
});
setTimeout(() => {
this.eventSource = null;
this.createEventSource();
+16
View File
@@ -17,6 +17,13 @@
<v-list-subheader>{{ tm('system.title') }}</v-list-subheader>
<v-list-item :subtitle="tm('system.backup.subtitle')" :title="tm('system.backup.title')">
<v-btn style="margin-top: 16px;" color="primary" @click="openBackupDialog">
<v-icon class="mr-2">mdi-backup-restore</v-icon>
{{ tm('system.backup.button') }}
</v-btn>
</v-list-item>
<v-list-item :subtitle="tm('system.restart.subtitle')" :title="tm('system.restart.title')">
<v-btn style="margin-top: 16px;" color="error" @click="restartAstrBot">{{ tm('system.restart.button') }}</v-btn>
</v-list-item>
@@ -30,6 +37,7 @@
<WaitingForRestart ref="wfr"></WaitingForRestart>
<MigrationDialog ref="migrationDialog"></MigrationDialog>
<BackupDialog ref="backupDialog"></BackupDialog>
</template>
@@ -40,12 +48,14 @@ import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
import ProxySelector from '@/components/shared/ProxySelector.vue';
import MigrationDialog from '@/components/shared/MigrationDialog.vue';
import SidebarCustomizer from '@/components/shared/SidebarCustomizer.vue';
import BackupDialog from '@/components/shared/BackupDialog.vue';
import { useModuleI18n } from '@/i18n/composables';
const { tm } = useModuleI18n('features/settings');
const wfr = ref(null);
const migrationDialog = ref(null);
const backupDialog = ref(null);
const restartAstrBot = () => {
axios.post('/api/stat/restart-core').then(() => {
@@ -65,4 +75,10 @@ const startMigration = async () => {
}
}
}
const openBackupDialog = () => {
if (backupDialog.value) {
backupDialog.value.open();
}
}
</script>
+1
View File
@@ -19,6 +19,7 @@ export default defineConfig({
],
resolve: {
alias: {
mermaid: 'mermaid/dist/mermaid.js',
'@': fileURLToPath(new URL('./src', import.meta.url))
}
},
+2 -2
View File
@@ -1,6 +1,6 @@
[project]
name = "AstrBot"
version = "4.10.2"
version = "4.10.3"
description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md"
requires-python = ">=3.10"
@@ -103,7 +103,7 @@ typeCheckingMode = "basic"
pythonVersion = "3.10"
reportMissingTypeStubs = false
reportMissingImports = false
include = ["astrbot", "packages"]
include = ["astrbot"]
exclude = ["dashboard", "node_modules", "dist", "data", "tests"]
[build-system]
+760
View File
@@ -0,0 +1,760 @@
"""备份功能单元测试"""
import json
import os
import re
import zipfile
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
from astrbot.core.backup import (
BACKUP_MANIFEST_VERSION,
KB_METADATA_MODELS,
MAIN_DB_MODELS,
ImportPreCheckResult,
)
from astrbot.core.backup.exporter import AstrBotExporter
from astrbot.core.backup.importer import (
AstrBotImporter,
ImportResult,
_get_major_version,
)
from astrbot.core.config.default import VERSION
from astrbot.core.db.po import (
ConversationV2,
)
from astrbot.core.utils.version_comparator import VersionComparator
from astrbot.dashboard.routes.backup import (
generate_unique_filename,
secure_filename,
)
@pytest.fixture
def temp_backup_dir(tmp_path):
"""创建临时备份目录"""
backup_dir = tmp_path / "backups"
backup_dir.mkdir()
return backup_dir
@pytest.fixture
def temp_data_dir(tmp_path):
"""创建临时数据目录"""
data_dir = tmp_path / "data"
data_dir.mkdir()
# 创建配置文件
config_path = data_dir / "cmd_config.json"
config_path.write_text(json.dumps({"test": "config"}))
# 创建附件目录
attachments_dir = data_dir / "attachments"
attachments_dir.mkdir()
return data_dir
@pytest.fixture
def mock_main_db():
"""创建模拟的主数据库"""
db = MagicMock()
# 模拟异步上下文管理器
session = AsyncMock()
db.get_db = MagicMock(
return_value=AsyncMock(__aenter__=AsyncMock(return_value=session))
)
return db
@pytest.fixture
def mock_kb_manager():
"""创建模拟的知识库管理器"""
kb_manager = MagicMock()
kb_manager.kb_insts = {}
# 模拟 kb_db
kb_db = MagicMock()
session = AsyncMock()
kb_db.get_db = MagicMock(
return_value=AsyncMock(__aenter__=AsyncMock(return_value=session))
)
kb_manager.kb_db = kb_db
return kb_manager
class TestImportResult:
"""ImportResult 类测试"""
def test_init(self):
"""测试初始化"""
result = ImportResult()
assert result.success is True
assert result.imported_tables == {}
assert result.imported_files == {}
assert result.warnings == []
assert result.errors == []
def test_add_warning(self):
"""测试添加警告"""
result = ImportResult()
result.add_warning("test warning")
assert "test warning" in result.warnings
assert result.success is True # 警告不影响成功状态
def test_add_error(self):
"""测试添加错误"""
result = ImportResult()
result.add_error("test error")
assert "test error" in result.errors
assert result.success is False # 错误会导致失败
def test_to_dict(self):
"""测试转换为字典"""
result = ImportResult()
result.imported_tables = {"test_table": 10}
result.add_warning("warning")
d = result.to_dict()
assert d["success"] is True
assert d["imported_tables"] == {"test_table": 10}
assert "warning" in d["warnings"]
class TestAstrBotExporter:
"""AstrBotExporter 类测试"""
def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir):
"""测试初始化"""
exporter = AstrBotExporter(
main_db=mock_main_db,
kb_manager=mock_kb_manager,
config_path=str(temp_data_dir / "cmd_config.json"),
)
assert exporter.main_db is mock_main_db
assert exporter.kb_manager is mock_kb_manager
def test_model_to_dict_with_model_dump(self):
"""测试 _model_to_dict 使用 model_dump 方法"""
exporter = AstrBotExporter(main_db=MagicMock())
# 创建一个有 model_dump 方法的模拟对象
mock_record = MagicMock()
mock_record.model_dump.return_value = {"id": 1, "name": "test"}
result = exporter._model_to_dict(mock_record)
assert result == {"id": 1, "name": "test"}
def test_model_to_dict_with_datetime(self):
"""测试 _model_to_dict 处理 datetime 字段"""
exporter = AstrBotExporter(main_db=MagicMock())
now = datetime.now()
mock_record = MagicMock()
mock_record.model_dump.return_value = {"id": 1, "created_at": now}
result = exporter._model_to_dict(mock_record)
assert result["created_at"] == now.isoformat()
def test_add_checksum(self):
"""测试添加校验和"""
exporter = AstrBotExporter(main_db=MagicMock())
exporter._add_checksum("test.json", '{"test": "data"}')
assert "test.json" in exporter._checksums
assert exporter._checksums["test.json"].startswith("sha256:")
def test_generate_manifest(self, mock_main_db, mock_kb_manager):
"""测试生成清单"""
exporter = AstrBotExporter(
main_db=mock_main_db,
kb_manager=mock_kb_manager,
)
main_data = {
"platform_stats": [{"id": 1}],
"conversations": [],
"attachments": [],
}
kb_meta_data = {
"knowledge_bases": [],
"kb_documents": [],
}
dir_stats = {
"plugins": {"files": 10, "size": 1024},
"plugin_data": {"files": 5, "size": 512},
}
manifest = exporter._generate_manifest(main_data, kb_meta_data, dir_stats)
assert manifest["version"] == BACKUP_MANIFEST_VERSION
assert manifest["astrbot_version"] == VERSION
assert "exported_at" in manifest
assert "tables" in manifest
assert "statistics" in manifest
assert "directories" in manifest
assert manifest["statistics"]["main_db"]["platform_stats"] == 1
assert manifest["statistics"]["directories"] == dir_stats
@pytest.mark.asyncio
async def test_export_all_creates_zip(
self, mock_main_db, temp_backup_dir, temp_data_dir
):
"""测试导出创建 ZIP 文件"""
# 设置模拟数据库返回空数据
session = AsyncMock()
result = MagicMock()
result.scalars.return_value.all.return_value = []
session.execute = AsyncMock(return_value=result)
mock_main_db.get_db.return_value = AsyncMock(
__aenter__=AsyncMock(return_value=session),
__aexit__=AsyncMock(return_value=None),
)
exporter = AstrBotExporter(
main_db=mock_main_db,
kb_manager=None,
config_path=str(temp_data_dir / "cmd_config.json"),
)
zip_path = await exporter.export_all(output_dir=str(temp_backup_dir))
assert os.path.exists(zip_path)
assert zip_path.endswith(".zip")
assert "astrbot_backup_" in zip_path
# 验证 ZIP 文件内容
with zipfile.ZipFile(zip_path, "r") as zf:
namelist = zf.namelist()
assert "manifest.json" in namelist
assert "databases/main_db.json" in namelist
assert "config/cmd_config.json" in namelist
class TestAstrBotImporter:
"""AstrBotImporter 类测试"""
def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir):
"""测试初始化"""
importer = AstrBotImporter(
main_db=mock_main_db,
kb_manager=mock_kb_manager,
config_path=str(temp_data_dir / "cmd_config.json"),
)
assert importer.main_db is mock_main_db
assert importer.kb_manager is mock_kb_manager
def test_validate_version_match(self):
"""测试版本匹配验证"""
importer = AstrBotImporter(main_db=MagicMock())
manifest = {"astrbot_version": VERSION}
# 不应该抛出异常
importer._validate_version(manifest)
def test_validate_version_major_diff_rejected(self):
"""测试主版本不同被拒绝"""
importer = AstrBotImporter(main_db=MagicMock())
# 使用一个明显不同的主版本
manifest = {"astrbot_version": "0.0.1"}
with pytest.raises(ValueError, match="主版本不兼容"):
importer._validate_version(manifest)
def test_validate_version_minor_diff_allowed(self):
"""测试小版本不同被允许(仅警告)"""
importer = AstrBotImporter(main_db=MagicMock())
# 获取当前主版本
major_version = _get_major_version(VERSION)
# 构造一个同主版本但小版本不同的版本
minor_diff_version = f"{major_version}.999"
manifest = {"astrbot_version": minor_diff_version}
# 不应该抛出异常
importer._validate_version(manifest)
def test_validate_version_missing(self):
"""测试缺少版本信息"""
importer = AstrBotImporter(main_db=MagicMock())
manifest = {}
with pytest.raises(ValueError, match="缺少版本信息"):
importer._validate_version(manifest)
def test_convert_datetime_fields(self):
"""测试 datetime 字段转换"""
importer = AstrBotImporter(main_db=MagicMock())
# 使用 ConversationV2 作为测试模型(它有 created_at 和 updated_at 字段)
row = {
"conversation_id": "test-123",
"platform_id": "test",
"user_id": "user1",
"created_at": "2024-01-01T12:00:00",
"updated_at": "2024-01-01T12:00:00",
}
result = importer._convert_datetime_fields(row, ConversationV2)
# created_at 应该被转换为 datetime 对象
assert isinstance(result["created_at"], datetime)
assert isinstance(result["updated_at"], datetime)
@pytest.mark.asyncio
async def test_import_file_not_exists(self, mock_main_db, tmp_path):
"""测试导入不存在的文件"""
importer = AstrBotImporter(main_db=mock_main_db)
result = await importer.import_all(str(tmp_path / "nonexistent.zip"))
assert result.success is False
assert any("不存在" in err for err in result.errors)
@pytest.mark.asyncio
async def test_import_invalid_zip(self, mock_main_db, tmp_path):
"""测试导入无效的 ZIP 文件"""
# 创建一个无效的文件
invalid_zip = tmp_path / "invalid.zip"
invalid_zip.write_text("not a zip file")
importer = AstrBotImporter(main_db=mock_main_db)
result = await importer.import_all(str(invalid_zip))
assert result.success is False
assert any("无效" in err or "ZIP" in err for err in result.errors)
@pytest.mark.asyncio
async def test_import_missing_manifest(self, mock_main_db, tmp_path):
"""测试导入缺少 manifest 的 ZIP 文件"""
# 创建一个没有 manifest 的 ZIP 文件
zip_path = tmp_path / "no_manifest.zip"
with zipfile.ZipFile(zip_path, "w") as zf:
zf.writestr("test.txt", "test content")
importer = AstrBotImporter(main_db=mock_main_db)
result = await importer.import_all(str(zip_path))
assert result.success is False
assert any("manifest" in err.lower() for err in result.errors)
@pytest.mark.asyncio
async def test_import_major_version_mismatch(self, mock_main_db, tmp_path):
"""测试导入主版本不匹配的备份"""
# 创建一个主版本不匹配的备份
zip_path = tmp_path / "old_version.zip"
manifest = {
"version": "1.0",
"astrbot_version": "0.0.1", # 主版本不同
"tables": {"main_db": []},
}
with zipfile.ZipFile(zip_path, "w") as zf:
zf.writestr("manifest.json", json.dumps(manifest))
importer = AstrBotImporter(main_db=mock_main_db)
result = await importer.import_all(str(zip_path))
assert result.success is False
assert any("主版本不兼容" in err for err in result.errors)
class TestSecureFilename:
"""安全文件名函数测试"""
def test_secure_filename_normal(self):
"""测试正常文件名"""
assert secure_filename("backup.zip") == "backup.zip"
assert secure_filename("my_backup_2024.zip") == "my_backup_2024.zip"
def test_secure_filename_path_traversal(self):
"""测试路径遍历攻击"""
assert ".." not in secure_filename("../../../etc/passwd")
assert "/" not in secure_filename("/etc/passwd")
assert "\\" not in secure_filename("..\\..\\windows\\system32")
def test_secure_filename_with_path(self):
"""测试带路径的文件名"""
result = secure_filename("/path/to/backup.zip")
assert result == "backup.zip"
result = secure_filename("C:\\Users\\test\\backup.zip")
assert result == "backup.zip"
def test_secure_filename_special_chars(self):
"""测试特殊字符"""
result = secure_filename('backup<>:"|?*.zip')
# 特殊字符应被替换为下划线
assert "<" not in result
assert ">" not in result
assert ":" not in result
assert '"' not in result
assert "|" not in result
assert "?" not in result
assert "*" not in result
def test_secure_filename_hidden_file(self):
"""测试隐藏文件(前导点)"""
result = secure_filename(".hidden_backup.zip")
assert not result.startswith(".")
def test_secure_filename_empty(self):
"""测试空文件名"""
assert secure_filename("") == "backup"
assert secure_filename("...") == "backup"
def test_generate_unique_filename(self):
"""测试生成唯一文件名"""
result = generate_unique_filename("backup.zip")
# 应包含 uploaded_ 前缀和时间戳
assert result.startswith("uploaded_")
assert result.endswith("_backup.zip")
# 应包含时间戳格式 YYYYMMDD_HHMMSS
assert re.search(r"uploaded_\d{8}_\d{6}_backup\.zip", result)
class TestVersionComparison:
"""版本比较函数测试 - 使用 VersionComparator"""
def test_get_major_version_simple(self):
"""测试提取简单主版本号"""
assert _get_major_version("1.0") == "1.0"
assert _get_major_version("2.1") == "2.1"
assert _get_major_version("4.9.1") == "4.9"
def test_get_major_version_with_prefix(self):
"""测试带 v 前缀的版本号"""
assert _get_major_version("v1.0") == "1.0"
assert _get_major_version("V4.9.1") == "4.9"
def test_get_major_version_with_prerelease(self):
"""测试带预发布标签的版本号"""
assert _get_major_version("4.9.1-beta") == "4.9"
assert _get_major_version("4.9.1-alpha.1") == "4.9"
assert _get_major_version("4.9.1+build123") == "4.9"
def test_get_major_version_single_part(self):
"""测试单部分版本号"""
assert _get_major_version("1") == "1.0"
def test_get_major_version_empty(self):
"""测试空版本号"""
assert _get_major_version("") == "0.0"
def test_compare_versions_equal(self):
"""测试版本相等"""
assert VersionComparator.compare_version("1.0", "1.0") == 0
assert VersionComparator.compare_version("1.0.0", "1.0") == 0
assert VersionComparator.compare_version("2.10", "2.10") == 0
def test_compare_versions_less_than(self):
"""测试版本小于"""
assert VersionComparator.compare_version("1.0", "1.1") == -1
assert (
VersionComparator.compare_version("1.9", "1.10") == -1
) # 关键测试:多位数版本比较
assert VersionComparator.compare_version("1.2", "1.10") == -1
assert VersionComparator.compare_version("1.0", "2.0") == -1
def test_compare_versions_greater_than(self):
"""测试版本大于"""
assert VersionComparator.compare_version("1.1", "1.0") == 1
assert (
VersionComparator.compare_version("1.10", "1.9") == 1
) # 关键测试:多位数版本比较
assert VersionComparator.compare_version("1.10", "1.2") == 1
assert VersionComparator.compare_version("2.0", "1.0") == 1
def test_compare_versions_different_lengths(self):
"""测试不同长度版本比较"""
assert VersionComparator.compare_version("1.0", "1.0.0") == 0
assert VersionComparator.compare_version("1.0", "1.0.1") == -1
assert VersionComparator.compare_version("1.0.1", "1.0") == 1
def test_compare_versions_prerelease(self):
"""测试预发布版本比较"""
# 预发布版本低于正式版本
assert VersionComparator.compare_version("1.0.0-alpha", "1.0.0") == -1
assert VersionComparator.compare_version("1.0.0", "1.0.0-beta") == 1
# alpha < beta
assert VersionComparator.compare_version("1.0.0-alpha", "1.0.0-beta") == -1
class TestImportPreCheckResult:
"""ImportPreCheckResult 类测试"""
def test_init_default_values(self):
"""测试默认值初始化"""
result = ImportPreCheckResult()
assert result.valid is False
assert result.can_import is False
assert result.version_status == ""
assert result.backup_version == ""
assert result.current_version == VERSION
assert result.confirm_message == ""
assert result.warnings == []
assert result.error == ""
assert result.backup_summary == {}
def test_to_dict(self):
"""测试转换为字典"""
result = ImportPreCheckResult(
valid=True,
can_import=True,
version_status="match",
backup_version="4.9.0",
confirm_message="确认导入?",
warnings=["警告1"],
backup_summary={"tables": ["table1"]},
)
d = result.to_dict()
assert d["valid"] is True
assert d["can_import"] is True
assert d["version_status"] == "match"
assert d["backup_version"] == "4.9.0"
assert d["confirm_message"] == "确认导入?"
assert "警告1" in d["warnings"]
assert d["backup_summary"]["tables"] == ["table1"]
class TestPreCheck:
"""预检查功能测试"""
def test_pre_check_file_not_exists(self, mock_main_db):
"""测试预检查不存在的文件"""
importer = AstrBotImporter(main_db=mock_main_db)
result = importer.pre_check("/nonexistent/file.zip")
assert result.valid is False
assert "不存在" in result.error
def test_pre_check_invalid_zip(self, mock_main_db, tmp_path):
"""测试预检查无效的 ZIP 文件"""
invalid_zip = tmp_path / "invalid.zip"
invalid_zip.write_text("not a zip file")
importer = AstrBotImporter(main_db=mock_main_db)
result = importer.pre_check(str(invalid_zip))
assert result.valid is False
assert "ZIP" in result.error or "无效" in result.error
def test_pre_check_missing_manifest(self, mock_main_db, tmp_path):
"""测试预检查缺少 manifest 的 ZIP 文件"""
zip_path = tmp_path / "no_manifest.zip"
with zipfile.ZipFile(zip_path, "w") as zf:
zf.writestr("test.txt", "test content")
importer = AstrBotImporter(main_db=mock_main_db)
result = importer.pre_check(str(zip_path))
assert result.valid is False
assert "manifest" in result.error.lower()
def test_pre_check_version_match(self, mock_main_db, tmp_path):
"""测试预检查版本匹配"""
zip_path = tmp_path / "backup.zip"
manifest = {
"version": "1.1",
"astrbot_version": VERSION,
"created_at": "2024-01-01T12:00:00",
"tables": {"platform_stats": 1},
"has_knowledge_bases": True,
"has_config": True,
"directories": ["plugins"],
}
with zipfile.ZipFile(zip_path, "w") as zf:
zf.writestr("manifest.json", json.dumps(manifest))
importer = AstrBotImporter(main_db=mock_main_db)
result = importer.pre_check(str(zip_path))
assert result.valid is True
assert result.can_import is True
assert result.version_status == "match"
assert result.backup_version == VERSION
# confirm_message 现在由前端生成,后端不再生成
assert result.backup_summary["has_knowledge_bases"] is True
def test_pre_check_minor_version_diff(self, mock_main_db, tmp_path):
"""测试预检查小版本差异"""
# 构造一个同主版本但小版本不同的版本
major_version = _get_major_version(VERSION)
minor_diff_version = f"{major_version}.999"
zip_path = tmp_path / "backup.zip"
manifest = {
"version": "1.1",
"astrbot_version": minor_diff_version,
"created_at": "2024-01-01T12:00:00",
"tables": {},
}
with zipfile.ZipFile(zip_path, "w") as zf:
zf.writestr("manifest.json", json.dumps(manifest))
importer = AstrBotImporter(main_db=mock_main_db)
result = importer.pre_check(str(zip_path))
assert result.valid is True
assert result.can_import is True
assert result.version_status == "minor_diff"
# 版本消息由前端 i18n 生成,后端 warnings 列表不再包含版本相关消息
# warnings 列表保留用于其他非版本相关的警告
def test_pre_check_major_version_diff(self, mock_main_db, tmp_path):
"""测试预检查主版本差异"""
zip_path = tmp_path / "backup.zip"
manifest = {
"version": "1.1",
"astrbot_version": "0.0.1", # 主版本不同
"created_at": "2024-01-01T12:00:00",
"tables": {},
}
with zipfile.ZipFile(zip_path, "w") as zf:
zf.writestr("manifest.json", json.dumps(manifest))
importer = AstrBotImporter(main_db=mock_main_db)
result = importer.pre_check(str(zip_path))
assert result.valid is True # 文件有效
assert result.can_import is False # 但不能导入
assert result.version_status == "major_diff"
# 版本消息由前端 i18n 生成,后端 warnings 列表不再包含版本相关消息
class TestVersionCompatibility:
"""版本兼容性检查测试"""
def test_check_version_compatibility_match(self, mock_main_db):
"""测试版本完全匹配"""
importer = AstrBotImporter(main_db=mock_main_db)
result = importer._check_version_compatibility(VERSION)
assert result["status"] == "match"
assert result["can_import"] is True
def test_check_version_compatibility_minor_diff(self, mock_main_db):
"""测试小版本差异"""
major_version = _get_major_version(VERSION)
minor_diff_version = f"{major_version}.999"
importer = AstrBotImporter(main_db=mock_main_db)
result = importer._check_version_compatibility(minor_diff_version)
assert result["status"] == "minor_diff"
assert result["can_import"] is True
def test_check_version_compatibility_major_diff(self, mock_main_db):
"""测试主版本差异"""
importer = AstrBotImporter(main_db=mock_main_db)
result = importer._check_version_compatibility("0.0.1")
assert result["status"] == "major_diff"
assert result["can_import"] is False
def test_check_version_compatibility_empty_version(self, mock_main_db):
"""测试空版本号"""
importer = AstrBotImporter(main_db=mock_main_db)
result = importer._check_version_compatibility("")
assert result["status"] == "major_diff"
assert result["can_import"] is False
class TestModelMappings:
"""测试模型映射配置"""
def test_main_db_models_not_empty(self):
"""测试主数据库模型映射非空"""
assert len(MAIN_DB_MODELS) > 0
def test_main_db_models_contain_expected_tables(self):
"""测试主数据库模型映射包含预期的表"""
expected_tables = [
"platform_stats",
"conversations",
"personas",
"preferences",
"attachments",
]
for table in expected_tables:
assert table in MAIN_DB_MODELS, f"Missing table: {table}"
def test_kb_metadata_models_not_empty(self):
"""测试知识库元数据模型映射非空"""
assert len(KB_METADATA_MODELS) > 0
def test_kb_metadata_models_contain_expected_tables(self):
"""测试知识库元数据模型映射包含预期的表"""
expected_tables = [
"knowledge_bases",
"kb_documents",
"kb_media",
]
for table in expected_tables:
assert table in KB_METADATA_MODELS, f"Missing table: {table}"
class TestBackupIntegration:
"""备份集成测试"""
@pytest.mark.asyncio
async def test_export_import_roundtrip(self, tmp_path):
"""测试导出-导入往返"""
backup_dir = tmp_path / "backups"
backup_dir.mkdir()
data_dir = tmp_path / "data"
data_dir.mkdir()
config_path = data_dir / "cmd_config.json"
config_path.write_text(json.dumps({"setting": "value"}))
attachments_dir = data_dir / "attachments"
attachments_dir.mkdir()
# 创建模拟数据库
mock_db = MagicMock()
session = AsyncMock()
result = MagicMock()
result.scalars.return_value.all.return_value = []
session.execute = AsyncMock(return_value=result)
mock_db.get_db.return_value = AsyncMock(
__aenter__=AsyncMock(return_value=session),
__aexit__=AsyncMock(return_value=None),
)
# 导出
exporter = AstrBotExporter(
main_db=mock_db,
kb_manager=None,
config_path=str(config_path),
)
zip_path = await exporter.export_all(output_dir=str(backup_dir))
assert os.path.exists(zip_path)
# 验证 ZIP 内容
with zipfile.ZipFile(zip_path, "r") as zf:
# 读取 manifest
manifest = json.loads(zf.read("manifest.json"))
assert manifest["astrbot_version"] == VERSION
# 读取配置
config = json.loads(zf.read("config/cmd_config.json"))
assert config["setting"] == "value"
# 读取主数据库
main_db = json.loads(zf.read("databases/main_db.json"))
assert "platform_stats" in main_db