Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3e3599835e | |||
| 5255388e2d | |||
| fbdd60b64c | |||
| bd1b0a2836 | |||
| 19541d9d07 | |||
| 2a5d574394 | |||
| f2924fbd1b | |||
| 703e208947 | |||
| 9a5cc977c2 | |||
| aa38fe776a | |||
| 701399c00c |
@@ -15,7 +15,6 @@ Always reference these instructions first and fallback to search or bash command
|
||||
### Running the Application
|
||||
- Run main application: `uv run main.py` -- starts in ~3 seconds
|
||||
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
|
||||
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
|
||||
|
||||
### Dashboard Build (Vue.js/Node.js)
|
||||
- **Prerequisites**: Node.js 20+ and npm 10+ required
|
||||
@@ -35,7 +34,7 @@ Always reference these instructions first and fallback to search or bash command
|
||||
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
|
||||
|
||||
### Plugin Development
|
||||
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
|
||||
- Plugins load from `astrbot/builtin_stars/` (built-in) and `data/plugins/` (user-installed)
|
||||
- Plugin system supports function tools and message handlers
|
||||
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
|
||||
|
||||
|
||||
+2
-2
@@ -24,9 +24,9 @@ configs/session
|
||||
configs/config.yaml
|
||||
cmd_config.json
|
||||
|
||||
# Plugins and packages
|
||||
# Plugins
|
||||
addons/plugins
|
||||
packages/python_interpreter/workplace
|
||||
astrbot/builtin_stars/python_interpreter/workplace
|
||||
tests/astrbot_plugin_openai
|
||||
|
||||
# Dashboard
|
||||
|
||||
+19
-8
@@ -7,6 +7,7 @@ from astrbot.api import logger, sp, star
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.api.message_components import Image, Reply
|
||||
from astrbot.api.provider import Provider, ProviderRequest
|
||||
from astrbot.core.agent.message import TextPart
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
|
||||
|
||||
@@ -85,7 +86,9 @@ class ProcessLLMRequest:
|
||||
req.image_urls,
|
||||
)
|
||||
if caption:
|
||||
req.prompt = f"(Image Caption: {caption})\n\n{req.prompt}"
|
||||
req.extra_user_content_parts.append(
|
||||
TextPart(text=f"<image_caption>{caption}</image_caption>")
|
||||
)
|
||||
req.image_urls = []
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片描述失败: {e}")
|
||||
@@ -129,13 +132,14 @@ class ProcessLLMRequest:
|
||||
else:
|
||||
req.prompt = prefix + req.prompt
|
||||
|
||||
# 收集系统提醒信息
|
||||
system_parts = []
|
||||
|
||||
# user identifier
|
||||
if cfg.get("identifier"):
|
||||
user_id = event.message_obj.sender.user_id
|
||||
user_nickname = event.message_obj.sender.nickname
|
||||
req.prompt = (
|
||||
f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n{req.prompt}"
|
||||
)
|
||||
system_parts.append(f"User ID: {user_id}, Nickname: {user_nickname}")
|
||||
|
||||
# group name identifier
|
||||
if cfg.get("group_name_display") and event.message_obj.group_id:
|
||||
@@ -146,7 +150,7 @@ class ProcessLLMRequest:
|
||||
return
|
||||
group_name = event.message_obj.group.group_name
|
||||
if group_name:
|
||||
req.system_prompt += f"\nGroup name: {group_name}\n"
|
||||
system_parts.append(f"Group name: {group_name}")
|
||||
|
||||
# time info
|
||||
if cfg.get("datetime_system_prompt"):
|
||||
@@ -162,7 +166,7 @@ class ProcessLLMRequest:
|
||||
current_time = (
|
||||
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)")
|
||||
)
|
||||
req.system_prompt += f"\nCurrent datetime: {current_time}\n"
|
||||
system_parts.append(f"Current datetime: {current_time}")
|
||||
|
||||
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
|
||||
if req.conversation:
|
||||
@@ -225,10 +229,17 @@ class ProcessLLMRequest:
|
||||
except BaseException as e:
|
||||
logger.error(f"处理引用图片失败: {e}")
|
||||
|
||||
# 3. 将所有部分组合成文本并直接注入到当前消息中
|
||||
# 3. 将所有部分组合成文本并添加到 extra_user_content_parts 中
|
||||
# 确保引用内容被正确的标签包裹
|
||||
quoted_content = "\n".join(content_parts)
|
||||
# 确保所有内容都在<Quoted Message>标签内
|
||||
quoted_text = f"<Quoted Message>\n{quoted_content}\n</Quoted Message>"
|
||||
|
||||
req.prompt = f"{quoted_text}\n\n{req.prompt}"
|
||||
req.extra_user_content_parts.append(TextPart(text=quoted_text))
|
||||
|
||||
# 统一包裹所有系统提醒
|
||||
if system_parts:
|
||||
system_content = (
|
||||
"<system_reminder>" + "\n".join(system_parts) + "</system_reminder>"
|
||||
)
|
||||
req.extra_user_content_parts.append(TextPart(text=system_content))
|
||||
+6
-4
@@ -184,7 +184,8 @@ class ProviderCommands:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
return
|
||||
provider = self.context.get_all_tts_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
@@ -198,7 +199,8 @@ class ProviderCommands:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
return
|
||||
provider = self.context.get_all_stt_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
@@ -209,8 +211,8 @@ class ProviderCommands:
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
elif isinstance(idx, int):
|
||||
if idx > len(self.context.get_all_providers()) or idx < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
return
|
||||
provider = self.context.get_all_providers()[idx - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.10.2"
|
||||
__version__ = "4.10.3"
|
||||
|
||||
@@ -169,6 +169,15 @@ class Message(BaseModel):
|
||||
)
|
||||
return self
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize(self, handler):
|
||||
data = handler(self)
|
||||
if self.tool_calls is None:
|
||||
data.pop("tool_calls", None)
|
||||
if self.tool_call_id is None:
|
||||
data.pop("tool_call_id", None)
|
||||
return data
|
||||
|
||||
|
||||
class AssistantMessageSegment(Message):
|
||||
"""A message segment from the assistant."""
|
||||
|
||||
@@ -77,10 +77,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
"""Yields chunks *and* a final LLMResponse."""
|
||||
payload = {
|
||||
"contexts": self.run_context.messages,
|
||||
"contexts": self.run_context.messages, # list[Message]
|
||||
"func_tool": self.req.func_tool,
|
||||
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
|
||||
"session_id": self.req.session_id,
|
||||
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
|
||||
}
|
||||
|
||||
if self.streaming:
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
"""AstrBot 备份与恢复模块
|
||||
|
||||
提供数据导出和导入功能,支持用户在服务器迁移时一键备份和恢复所有数据。
|
||||
"""
|
||||
|
||||
# 从 constants 模块导入共享常量
|
||||
from .constants import (
|
||||
BACKUP_MANIFEST_VERSION,
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
# 导入导出器和导入器
|
||||
from .exporter import AstrBotExporter
|
||||
from .importer import AstrBotImporter, ImportPreCheckResult
|
||||
|
||||
__all__ = [
|
||||
"AstrBotExporter",
|
||||
"AstrBotImporter",
|
||||
"ImportPreCheckResult",
|
||||
"MAIN_DB_MODELS",
|
||||
"KB_METADATA_MODELS",
|
||||
"get_backup_directories",
|
||||
"BACKUP_MANIFEST_VERSION",
|
||||
]
|
||||
@@ -0,0 +1,77 @@
|
||||
"""AstrBot 备份模块共享常量
|
||||
|
||||
此文件定义了导出器和导入器共享的常量,确保两端配置一致。
|
||||
"""
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
CommandConfig,
|
||||
CommandConflict,
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
PlatformSession,
|
||||
PlatformStat,
|
||||
Preference,
|
||||
)
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KnowledgeBase,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_config_path,
|
||||
get_astrbot_plugin_data_path,
|
||||
get_astrbot_plugin_path,
|
||||
get_astrbot_t2i_templates_path,
|
||||
get_astrbot_temp_path,
|
||||
get_astrbot_webchat_path,
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# 共享常量 - 确保导出和导入端配置一致
|
||||
# ============================================================
|
||||
|
||||
# 主数据库模型类映射
|
||||
MAIN_DB_MODELS: dict[str, type[SQLModel]] = {
|
||||
"platform_stats": PlatformStat,
|
||||
"conversations": ConversationV2,
|
||||
"personas": Persona,
|
||||
"preferences": Preference,
|
||||
"platform_message_history": PlatformMessageHistory,
|
||||
"platform_sessions": PlatformSession,
|
||||
"attachments": Attachment,
|
||||
"command_configs": CommandConfig,
|
||||
"command_conflicts": CommandConflict,
|
||||
}
|
||||
|
||||
# 知识库元数据模型类映射
|
||||
KB_METADATA_MODELS: dict[str, type[SQLModel]] = {
|
||||
"knowledge_bases": KnowledgeBase,
|
||||
"kb_documents": KBDocument,
|
||||
"kb_media": KBMedia,
|
||||
}
|
||||
|
||||
|
||||
def get_backup_directories() -> dict[str, str]:
|
||||
"""获取需要备份的目录列表
|
||||
|
||||
使用 astrbot_path 模块动态获取路径,支持通过环境变量 ASTRBOT_ROOT 自定义根目录。
|
||||
|
||||
Returns:
|
||||
dict: 键为备份文件中的目录名称,值为目录的绝对路径
|
||||
"""
|
||||
return {
|
||||
"plugins": get_astrbot_plugin_path(), # 插件本体
|
||||
"plugin_data": get_astrbot_plugin_data_path(), # 插件数据
|
||||
"config": get_astrbot_config_path(), # 配置目录
|
||||
"t2i_templates": get_astrbot_t2i_templates_path(), # T2I 模板
|
||||
"webchat": get_astrbot_webchat_path(), # WebChat 数据
|
||||
"temp": get_astrbot_temp_path(), # 临时文件
|
||||
}
|
||||
|
||||
|
||||
# 备份清单版本号
|
||||
BACKUP_MANIFEST_VERSION = "1.1"
|
||||
@@ -0,0 +1,476 @@
|
||||
"""AstrBot 数据导出器
|
||||
|
||||
负责将所有数据导出为 ZIP 备份文件。
|
||||
导出格式为 JSON,这是数据库无关的方案,支持未来向 MySQL/PostgreSQL 迁移。
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import zipfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_backups_path,
|
||||
get_astrbot_data_path,
|
||||
)
|
||||
|
||||
# 从共享常量模块导入
|
||||
from .constants import (
|
||||
BACKUP_MANIFEST_VERSION,
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
|
||||
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
|
||||
|
||||
class AstrBotExporter:
|
||||
"""AstrBot 数据导出器
|
||||
|
||||
导出内容:
|
||||
- 主数据库所有表(data/data_v4.db)
|
||||
- 知识库元数据(data/knowledge_base/kb.db)
|
||||
- 每个知识库的向量文档数据
|
||||
- 配置文件(data/cmd_config.json)
|
||||
- 附件文件
|
||||
- 知识库多媒体文件
|
||||
- 插件目录(data/plugins)
|
||||
- 插件数据目录(data/plugin_data)
|
||||
- 配置目录(data/config)
|
||||
- T2I 模板目录(data/t2i_templates)
|
||||
- WebChat 数据目录(data/webchat)
|
||||
- 临时文件目录(data/temp)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
main_db: BaseDatabase,
|
||||
kb_manager: "KnowledgeBaseManager | None" = None,
|
||||
config_path: str = CMD_CONFIG_FILE_PATH,
|
||||
):
|
||||
self.main_db = main_db
|
||||
self.kb_manager = kb_manager
|
||||
self.config_path = config_path
|
||||
self._checksums: dict[str, str] = {}
|
||||
|
||||
async def export_all(
|
||||
self,
|
||||
output_dir: str | None = None,
|
||||
progress_callback: Any | None = None,
|
||||
) -> str:
|
||||
"""导出所有数据到 ZIP 文件
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录
|
||||
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
|
||||
|
||||
Returns:
|
||||
str: 生成的 ZIP 文件路径
|
||||
"""
|
||||
if output_dir is None:
|
||||
output_dir = get_astrbot_backups_path()
|
||||
|
||||
# 确保输出目录存在
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
zip_filename = f"astrbot_backup_{timestamp}.zip"
|
||||
zip_path = os.path.join(output_dir, zip_filename)
|
||||
|
||||
logger.info(f"开始导出备份到 {zip_path}")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
# 1. 导出主数据库
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 0, 100, "正在导出主数据库...")
|
||||
main_data = await self._export_main_database()
|
||||
main_db_json = json.dumps(
|
||||
main_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
zf.writestr("databases/main_db.json", main_db_json)
|
||||
self._add_checksum("databases/main_db.json", main_db_json)
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 100, 100, "主数据库导出完成")
|
||||
|
||||
# 2. 导出知识库数据
|
||||
kb_meta_data: dict[str, Any] = {
|
||||
"knowledge_bases": [],
|
||||
"kb_documents": [],
|
||||
"kb_media": [],
|
||||
}
|
||||
if self.kb_manager:
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_metadata", 0, 100, "正在导出知识库元数据..."
|
||||
)
|
||||
kb_meta_data = await self._export_kb_metadata()
|
||||
kb_meta_json = json.dumps(
|
||||
kb_meta_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
zf.writestr("databases/kb_metadata.json", kb_meta_json)
|
||||
self._add_checksum("databases/kb_metadata.json", kb_meta_json)
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_metadata", 100, 100, "知识库元数据导出完成"
|
||||
)
|
||||
|
||||
# 导出每个知识库的文档数据
|
||||
kb_insts = self.kb_manager.kb_insts
|
||||
total_kbs = len(kb_insts)
|
||||
for idx, (kb_id, kb_helper) in enumerate(kb_insts.items()):
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_documents",
|
||||
idx,
|
||||
total_kbs,
|
||||
f"正在导出知识库 {kb_helper.kb.kb_name} 的文档数据...",
|
||||
)
|
||||
doc_data = await self._export_kb_documents(kb_helper)
|
||||
doc_json = json.dumps(
|
||||
doc_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
doc_path = f"databases/kb_{kb_id}/documents.json"
|
||||
zf.writestr(doc_path, doc_json)
|
||||
self._add_checksum(doc_path, doc_json)
|
||||
|
||||
# 导出 FAISS 索引文件
|
||||
await self._export_faiss_index(zf, kb_helper, kb_id)
|
||||
|
||||
# 导出知识库多媒体文件
|
||||
await self._export_kb_media_files(zf, kb_helper, kb_id)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_documents", total_kbs, total_kbs, "知识库文档导出完成"
|
||||
)
|
||||
|
||||
# 3. 导出配置文件
|
||||
if progress_callback:
|
||||
await progress_callback("config", 0, 100, "正在导出配置文件...")
|
||||
if os.path.exists(self.config_path):
|
||||
with open(self.config_path, encoding="utf-8") as f:
|
||||
config_content = f.read()
|
||||
zf.writestr("config/cmd_config.json", config_content)
|
||||
self._add_checksum("config/cmd_config.json", config_content)
|
||||
if progress_callback:
|
||||
await progress_callback("config", 100, 100, "配置文件导出完成")
|
||||
|
||||
# 4. 导出附件文件
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 0, 100, "正在导出附件...")
|
||||
await self._export_attachments(zf, main_data.get("attachments", []))
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 100, 100, "附件导出完成")
|
||||
|
||||
# 5. 导出插件和其他目录
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"directories", 0, 100, "正在导出插件和数据目录..."
|
||||
)
|
||||
dir_stats = await self._export_directories(zf)
|
||||
if progress_callback:
|
||||
await progress_callback("directories", 100, 100, "目录导出完成")
|
||||
|
||||
# 6. 生成 manifest
|
||||
if progress_callback:
|
||||
await progress_callback("manifest", 0, 100, "正在生成清单...")
|
||||
manifest = self._generate_manifest(main_data, kb_meta_data, dir_stats)
|
||||
manifest_json = json.dumps(manifest, ensure_ascii=False, indent=2)
|
||||
zf.writestr("manifest.json", manifest_json)
|
||||
if progress_callback:
|
||||
await progress_callback("manifest", 100, 100, "清单生成完成")
|
||||
|
||||
logger.info(f"备份导出完成: {zip_path}")
|
||||
return zip_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"备份导出失败: {e}")
|
||||
# 清理失败的文件
|
||||
if os.path.exists(zip_path):
|
||||
os.remove(zip_path)
|
||||
raise
|
||||
|
||||
async def _export_main_database(self) -> dict[str, list[dict]]:
|
||||
"""导出主数据库所有表"""
|
||||
export_data: dict[str, list[dict]] = {}
|
||||
|
||||
async with self.main_db.get_db() as session:
|
||||
for table_name, model_class in MAIN_DB_MODELS.items():
|
||||
try:
|
||||
result = await session.execute(select(model_class))
|
||||
records = result.scalars().all()
|
||||
export_data[table_name] = [
|
||||
self._model_to_dict(record) for record in records
|
||||
]
|
||||
logger.debug(
|
||||
f"导出表 {table_name}: {len(export_data[table_name])} 条记录"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出表 {table_name} 失败: {e}")
|
||||
export_data[table_name] = []
|
||||
|
||||
return export_data
|
||||
|
||||
async def _export_kb_metadata(self) -> dict[str, list[dict]]:
|
||||
"""导出知识库元数据库"""
|
||||
if not self.kb_manager:
|
||||
return {"knowledge_bases": [], "kb_documents": [], "kb_media": []}
|
||||
|
||||
export_data: dict[str, list[dict]] = {}
|
||||
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
for table_name, model_class in KB_METADATA_MODELS.items():
|
||||
try:
|
||||
result = await session.execute(select(model_class))
|
||||
records = result.scalars().all()
|
||||
export_data[table_name] = [
|
||||
self._model_to_dict(record) for record in records
|
||||
]
|
||||
logger.debug(
|
||||
f"导出知识库表 {table_name}: {len(export_data[table_name])} 条记录"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库表 {table_name} 失败: {e}")
|
||||
export_data[table_name] = []
|
||||
|
||||
return export_data
|
||||
|
||||
async def _export_kb_documents(self, kb_helper: Any) -> dict[str, Any]:
|
||||
"""导出知识库的文档块数据"""
|
||||
try:
|
||||
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
|
||||
|
||||
vec_db: FaissVecDB = kb_helper.vec_db
|
||||
if not vec_db or not vec_db.document_storage:
|
||||
return {"documents": []}
|
||||
|
||||
# 获取所有文档
|
||||
docs = await vec_db.document_storage.get_documents(
|
||||
metadata_filters={},
|
||||
offset=0,
|
||||
limit=None, # 获取全部
|
||||
)
|
||||
|
||||
return {"documents": docs}
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库文档失败: {e}")
|
||||
return {"documents": []}
|
||||
|
||||
async def _export_faiss_index(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
kb_helper: Any,
|
||||
kb_id: str,
|
||||
) -> None:
|
||||
"""导出 FAISS 索引文件"""
|
||||
try:
|
||||
index_path = kb_helper.kb_dir / "index.faiss"
|
||||
if index_path.exists():
|
||||
archive_path = f"databases/kb_{kb_id}/index.faiss"
|
||||
zf.write(str(index_path), archive_path)
|
||||
logger.debug(f"导出 FAISS 索引: {archive_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"导出 FAISS 索引失败: {e}")
|
||||
|
||||
async def _export_kb_media_files(
|
||||
self, zf: zipfile.ZipFile, kb_helper: Any, kb_id: str
|
||||
) -> None:
|
||||
"""导出知识库的多媒体文件"""
|
||||
try:
|
||||
media_dir = kb_helper.kb_medias_dir
|
||||
if not media_dir.exists():
|
||||
return
|
||||
|
||||
for root, _, files in os.walk(media_dir):
|
||||
for file in files:
|
||||
file_path = Path(root) / file
|
||||
# 计算相对路径
|
||||
rel_path = file_path.relative_to(kb_helper.kb_dir)
|
||||
archive_path = f"files/kb_media/{kb_id}/{rel_path}"
|
||||
zf.write(str(file_path), archive_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库媒体文件失败: {e}")
|
||||
|
||||
async def _export_directories(
|
||||
self, zf: zipfile.ZipFile
|
||||
) -> dict[str, dict[str, int]]:
|
||||
"""导出插件和其他数据目录
|
||||
|
||||
Returns:
|
||||
dict: 每个目录的统计信息 {dir_name: {"files": count, "size": bytes}}
|
||||
"""
|
||||
stats: dict[str, dict[str, int]] = {}
|
||||
backup_directories = get_backup_directories()
|
||||
|
||||
for dir_name, dir_path in backup_directories.items():
|
||||
full_path = Path(dir_path)
|
||||
if not full_path.exists():
|
||||
logger.debug(f"目录不存在,跳过: {full_path}")
|
||||
continue
|
||||
|
||||
file_count = 0
|
||||
total_size = 0
|
||||
|
||||
try:
|
||||
for root, dirs, files in os.walk(full_path):
|
||||
# 跳过 __pycache__ 目录
|
||||
dirs[:] = [d for d in dirs if d != "__pycache__"]
|
||||
|
||||
for file in files:
|
||||
# 跳过 .pyc 文件
|
||||
if file.endswith(".pyc"):
|
||||
continue
|
||||
|
||||
file_path = Path(root) / file
|
||||
try:
|
||||
# 计算相对路径
|
||||
rel_path = file_path.relative_to(full_path)
|
||||
archive_path = f"directories/{dir_name}/{rel_path}"
|
||||
zf.write(str(file_path), archive_path)
|
||||
file_count += 1
|
||||
total_size += file_path.stat().st_size
|
||||
except Exception as e:
|
||||
logger.warning(f"导出文件 {file_path} 失败: {e}")
|
||||
|
||||
stats[dir_name] = {"files": file_count, "size": total_size}
|
||||
logger.debug(
|
||||
f"导出目录 {dir_name}: {file_count} 个文件, {total_size} 字节"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出目录 {dir_path} 失败: {e}")
|
||||
stats[dir_name] = {"files": 0, "size": 0}
|
||||
|
||||
return stats
|
||||
|
||||
async def _export_attachments(
|
||||
self, zf: zipfile.ZipFile, attachments: list[dict]
|
||||
) -> None:
|
||||
"""导出附件文件"""
|
||||
for attachment in attachments:
|
||||
try:
|
||||
file_path = attachment.get("path", "")
|
||||
if file_path and os.path.exists(file_path):
|
||||
# 使用 attachment_id 作为文件名
|
||||
attachment_id = attachment.get("attachment_id", "")
|
||||
ext = os.path.splitext(file_path)[1]
|
||||
archive_path = f"files/attachments/{attachment_id}{ext}"
|
||||
zf.write(file_path, archive_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出附件失败: {e}")
|
||||
|
||||
def _model_to_dict(self, record: Any) -> dict:
|
||||
"""将 SQLModel 实例转换为字典
|
||||
|
||||
这是数据库无关的序列化方式,支持未来迁移到其他数据库。
|
||||
"""
|
||||
# 使用 SQLModel 内置的 model_dump 方法(如果可用)
|
||||
if hasattr(record, "model_dump"):
|
||||
data = record.model_dump(mode="python")
|
||||
# 处理 datetime 类型
|
||||
for key, value in data.items():
|
||||
if isinstance(value, datetime):
|
||||
data[key] = value.isoformat()
|
||||
return data
|
||||
|
||||
# 回退到手动提取
|
||||
data = {}
|
||||
# 使用 inspect 获取表信息
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
|
||||
mapper = sa_inspect(record.__class__)
|
||||
for column in mapper.columns:
|
||||
value = getattr(record, column.name)
|
||||
# 处理 datetime 类型 - 统一转为 ISO 格式字符串
|
||||
if isinstance(value, datetime):
|
||||
value = value.isoformat()
|
||||
data[column.name] = value
|
||||
return data
|
||||
|
||||
def _add_checksum(self, path: str, content: str | bytes) -> None:
|
||||
"""计算并添加文件校验和"""
|
||||
if isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
checksum = hashlib.sha256(content).hexdigest()
|
||||
self._checksums[path] = f"sha256:{checksum}"
|
||||
|
||||
def _generate_manifest(
|
||||
self,
|
||||
main_data: dict[str, list[dict]],
|
||||
kb_meta_data: dict[str, list[dict]],
|
||||
dir_stats: dict[str, dict[str, int]] | None = None,
|
||||
) -> dict:
|
||||
"""生成备份清单"""
|
||||
if dir_stats is None:
|
||||
dir_stats = {}
|
||||
# 收集知识库 ID
|
||||
kb_document_tables = {}
|
||||
if self.kb_manager:
|
||||
for kb_id in self.kb_manager.kb_insts.keys():
|
||||
kb_document_tables[kb_id] = "documents"
|
||||
|
||||
# 收集附件文件列表
|
||||
attachment_files = []
|
||||
for attachment in main_data.get("attachments", []):
|
||||
attachment_id = attachment.get("attachment_id", "")
|
||||
path = attachment.get("path", "")
|
||||
if attachment_id and path:
|
||||
ext = os.path.splitext(path)[1]
|
||||
attachment_files.append(f"{attachment_id}{ext}")
|
||||
|
||||
# 收集知识库媒体文件
|
||||
kb_media_files: dict[str, list[str]] = {}
|
||||
if self.kb_manager:
|
||||
for kb_id, kb_helper in self.kb_manager.kb_insts.items():
|
||||
media_files: list[str] = []
|
||||
media_dir = kb_helper.kb_medias_dir
|
||||
if media_dir.exists():
|
||||
for root, _, files in os.walk(media_dir):
|
||||
for file in files:
|
||||
media_files.append(file)
|
||||
if media_files:
|
||||
kb_media_files[kb_id] = media_files
|
||||
|
||||
manifest = {
|
||||
"version": BACKUP_MANIFEST_VERSION,
|
||||
"astrbot_version": VERSION,
|
||||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||
"schema_version": {
|
||||
"main_db": "v4",
|
||||
"kb_db": "v1",
|
||||
},
|
||||
"tables": {
|
||||
"main_db": list(main_data.keys()),
|
||||
"kb_metadata": list(kb_meta_data.keys()),
|
||||
"kb_documents": kb_document_tables,
|
||||
},
|
||||
"files": {
|
||||
"attachments": attachment_files,
|
||||
"kb_media": kb_media_files,
|
||||
},
|
||||
"directories": list(dir_stats.keys()),
|
||||
"checksums": self._checksums,
|
||||
"statistics": {
|
||||
"main_db": {
|
||||
table: len(records) for table, records in main_data.items()
|
||||
},
|
||||
"kb_metadata": {
|
||||
table: len(records) for table, records in kb_meta_data.items()
|
||||
},
|
||||
"directories": dir_stats,
|
||||
},
|
||||
}
|
||||
|
||||
return manifest
|
||||
@@ -0,0 +1,761 @@
|
||||
"""AstrBot 数据导入器
|
||||
|
||||
负责从 ZIP 备份文件恢复所有数据。
|
||||
导入时进行版本校验:
|
||||
- 主版本(前两位)不同时直接拒绝导入
|
||||
- 小版本(第三位)不同时提示警告,用户可选择强制导入
|
||||
- 版本匹配时也需要用户确认
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import delete
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_data_path,
|
||||
get_astrbot_knowledge_base_path,
|
||||
)
|
||||
from astrbot.core.utils.version_comparator import VersionComparator
|
||||
|
||||
# 从共享常量模块导入
|
||||
from .constants import (
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
|
||||
|
||||
def _get_major_version(version_str: str) -> str:
|
||||
"""提取版本的主版本部分(前两位)
|
||||
|
||||
Args:
|
||||
version_str: 版本字符串,如 "4.9.1", "4.10.0-beta"
|
||||
|
||||
Returns:
|
||||
主版本字符串,如 "4.9", "4.10"
|
||||
"""
|
||||
if not version_str:
|
||||
return "0.0"
|
||||
# 移除 v 前缀和预发布标签
|
||||
version = version_str.lower().replace("v", "").split("-")[0].split("+")[0]
|
||||
parts = [p for p in version.split(".") if p] # 过滤空字符串
|
||||
if len(parts) >= 2:
|
||||
return f"{parts[0]}.{parts[1]}"
|
||||
elif len(parts) == 1 and parts[0]:
|
||||
return f"{parts[0]}.0"
|
||||
return "0.0"
|
||||
|
||||
|
||||
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
KB_PATH = get_astrbot_knowledge_base_path()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImportPreCheckResult:
|
||||
"""导入预检查结果
|
||||
|
||||
用于在实际导入前检查备份文件的版本兼容性,
|
||||
并返回确认信息让用户决定是否继续导入。
|
||||
"""
|
||||
|
||||
# 检查是否通过(文件有效且版本可导入)
|
||||
valid: bool = False
|
||||
# 是否可以导入(版本兼容)
|
||||
can_import: bool = False
|
||||
# 版本状态: match(完全匹配), minor_diff(小版本差异), major_diff(主版本不同,拒绝)
|
||||
version_status: str = ""
|
||||
# 备份文件中的 AstrBot 版本
|
||||
backup_version: str = ""
|
||||
# 当前运行的 AstrBot 版本
|
||||
current_version: str = VERSION
|
||||
# 备份创建时间
|
||||
backup_time: str = ""
|
||||
# 确认消息(显示给用户)
|
||||
confirm_message: str = ""
|
||||
# 警告消息列表
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
# 错误消息(如果检查失败)
|
||||
error: str = ""
|
||||
# 备份包含的内容摘要
|
||||
backup_summary: dict = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"valid": self.valid,
|
||||
"can_import": self.can_import,
|
||||
"version_status": self.version_status,
|
||||
"backup_version": self.backup_version,
|
||||
"current_version": self.current_version,
|
||||
"backup_time": self.backup_time,
|
||||
"confirm_message": self.confirm_message,
|
||||
"warnings": self.warnings,
|
||||
"error": self.error,
|
||||
"backup_summary": self.backup_summary,
|
||||
}
|
||||
|
||||
|
||||
class ImportResult:
|
||||
"""导入结果"""
|
||||
|
||||
def __init__(self):
|
||||
self.success = True
|
||||
self.imported_tables: dict[str, int] = {}
|
||||
self.imported_files: dict[str, int] = {}
|
||||
self.imported_directories: dict[str, int] = {}
|
||||
self.warnings: list[str] = []
|
||||
self.errors: list[str] = []
|
||||
|
||||
def add_warning(self, msg: str) -> None:
|
||||
self.warnings.append(msg)
|
||||
logger.warning(msg)
|
||||
|
||||
def add_error(self, msg: str) -> None:
|
||||
self.errors.append(msg)
|
||||
self.success = False
|
||||
logger.error(msg)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"success": self.success,
|
||||
"imported_tables": self.imported_tables,
|
||||
"imported_files": self.imported_files,
|
||||
"imported_directories": self.imported_directories,
|
||||
"warnings": self.warnings,
|
||||
"errors": self.errors,
|
||||
}
|
||||
|
||||
|
||||
class AstrBotImporter:
|
||||
"""AstrBot 数据导入器
|
||||
|
||||
导入备份文件中的所有数据,包括:
|
||||
- 主数据库所有表
|
||||
- 知识库元数据和文档
|
||||
- 配置文件
|
||||
- 附件文件
|
||||
- 知识库多媒体文件
|
||||
- 插件目录(data/plugins)
|
||||
- 插件数据目录(data/plugin_data)
|
||||
- 配置目录(data/config)
|
||||
- T2I 模板目录(data/t2i_templates)
|
||||
- WebChat 数据目录(data/webchat)
|
||||
- 临时文件目录(data/temp)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
main_db: BaseDatabase,
|
||||
kb_manager: "KnowledgeBaseManager | None" = None,
|
||||
config_path: str = CMD_CONFIG_FILE_PATH,
|
||||
kb_root_dir: str = KB_PATH,
|
||||
):
|
||||
self.main_db = main_db
|
||||
self.kb_manager = kb_manager
|
||||
self.config_path = config_path
|
||||
self.kb_root_dir = kb_root_dir
|
||||
|
||||
def pre_check(self, zip_path: str) -> ImportPreCheckResult:
|
||||
"""预检查备份文件
|
||||
|
||||
在实际导入前检查备份文件的有效性和版本兼容性。
|
||||
返回检查结果供前端显示确认对话框。
|
||||
|
||||
Args:
|
||||
zip_path: ZIP 备份文件路径
|
||||
|
||||
Returns:
|
||||
ImportPreCheckResult: 预检查结果
|
||||
"""
|
||||
result = ImportPreCheckResult()
|
||||
result.current_version = VERSION
|
||||
|
||||
if not os.path.exists(zip_path):
|
||||
result.error = f"备份文件不存在: {zip_path}"
|
||||
return result
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
# 读取 manifest
|
||||
try:
|
||||
manifest_data = zf.read("manifest.json")
|
||||
manifest = json.loads(manifest_data)
|
||||
except KeyError:
|
||||
result.error = "备份文件缺少 manifest.json,不是有效的 AstrBot 备份"
|
||||
return result
|
||||
except json.JSONDecodeError as e:
|
||||
result.error = f"manifest.json 格式错误: {e}"
|
||||
return result
|
||||
|
||||
# 提取基本信息
|
||||
result.backup_version = manifest.get("astrbot_version", "未知")
|
||||
result.backup_time = manifest.get("exported_at", "未知")
|
||||
result.valid = True
|
||||
|
||||
# 构建备份摘要
|
||||
result.backup_summary = {
|
||||
"tables": list(manifest.get("tables", {}).keys()),
|
||||
"has_knowledge_bases": manifest.get("has_knowledge_bases", False),
|
||||
"has_config": manifest.get("has_config", False),
|
||||
"directories": manifest.get("directories", []),
|
||||
}
|
||||
|
||||
# 检查版本兼容性
|
||||
version_check = self._check_version_compatibility(result.backup_version)
|
||||
result.version_status = version_check["status"]
|
||||
result.can_import = version_check["can_import"]
|
||||
|
||||
# 版本信息由前端根据 version_status 和 i18n 生成显示
|
||||
# 不再将版本消息添加到 warnings 列表中,避免中文硬编码
|
||||
# warnings 列表保留用于其他非版本相关的警告
|
||||
|
||||
return result
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
result.error = "无效的 ZIP 文件"
|
||||
return result
|
||||
except Exception as e:
|
||||
result.error = f"检查备份文件失败: {e}"
|
||||
return result
|
||||
|
||||
def _check_version_compatibility(self, backup_version: str) -> dict:
|
||||
"""检查版本兼容性
|
||||
|
||||
规则:
|
||||
- 主版本(前两位,如 4.9)必须一致,否则拒绝
|
||||
- 小版本(第三位,如 4.9.1 vs 4.9.2)不同时,警告但允许导入
|
||||
|
||||
Returns:
|
||||
dict: {status, can_import, message}
|
||||
"""
|
||||
if not backup_version:
|
||||
return {
|
||||
"status": "major_diff",
|
||||
"can_import": False,
|
||||
"message": "备份文件缺少版本信息",
|
||||
}
|
||||
|
||||
# 提取主版本(前两位)进行比较
|
||||
backup_major = _get_major_version(backup_version)
|
||||
current_major = _get_major_version(VERSION)
|
||||
|
||||
# 比较主版本
|
||||
if VersionComparator.compare_version(backup_major, current_major) != 0:
|
||||
return {
|
||||
"status": "major_diff",
|
||||
"can_import": False,
|
||||
"message": (
|
||||
f"主版本不兼容: 备份版本 {backup_version}, 当前版本 {VERSION}。"
|
||||
f"跨主版本导入可能导致数据损坏,请使用相同主版本的 AstrBot。"
|
||||
),
|
||||
}
|
||||
|
||||
# 比较完整版本
|
||||
version_cmp = VersionComparator.compare_version(backup_version, VERSION)
|
||||
if version_cmp != 0:
|
||||
return {
|
||||
"status": "minor_diff",
|
||||
"can_import": True,
|
||||
"message": (
|
||||
f"小版本差异: 备份版本 {backup_version}, 当前版本 {VERSION}。"
|
||||
),
|
||||
}
|
||||
|
||||
return {
|
||||
"status": "match",
|
||||
"can_import": True,
|
||||
"message": "版本匹配",
|
||||
}
|
||||
|
||||
async def import_all(
|
||||
self,
|
||||
zip_path: str,
|
||||
mode: str = "replace", # "replace" 清空后导入
|
||||
progress_callback: Any | None = None,
|
||||
) -> ImportResult:
|
||||
"""从 ZIP 文件导入所有数据
|
||||
|
||||
Args:
|
||||
zip_path: ZIP 备份文件路径
|
||||
mode: 导入模式,目前仅支持 "replace"(清空后导入)
|
||||
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
|
||||
|
||||
Returns:
|
||||
ImportResult: 导入结果
|
||||
"""
|
||||
result = ImportResult()
|
||||
|
||||
if not os.path.exists(zip_path):
|
||||
result.add_error(f"备份文件不存在: {zip_path}")
|
||||
return result
|
||||
|
||||
logger.info(f"开始从 {zip_path} 导入备份")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
# 1. 读取并验证 manifest
|
||||
if progress_callback:
|
||||
await progress_callback("validate", 0, 100, "正在验证备份文件...")
|
||||
|
||||
try:
|
||||
manifest_data = zf.read("manifest.json")
|
||||
manifest = json.loads(manifest_data)
|
||||
except KeyError:
|
||||
result.add_error("备份文件缺少 manifest.json")
|
||||
return result
|
||||
except json.JSONDecodeError as e:
|
||||
result.add_error(f"manifest.json 格式错误: {e}")
|
||||
return result
|
||||
|
||||
# 版本校验
|
||||
try:
|
||||
self._validate_version(manifest)
|
||||
except ValueError as e:
|
||||
result.add_error(str(e))
|
||||
return result
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("validate", 100, 100, "验证完成")
|
||||
|
||||
# 2. 导入主数据库
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 0, 100, "正在导入主数据库...")
|
||||
|
||||
try:
|
||||
main_data_content = zf.read("databases/main_db.json")
|
||||
main_data = json.loads(main_data_content)
|
||||
|
||||
if mode == "replace":
|
||||
await self._clear_main_db()
|
||||
|
||||
imported = await self._import_main_database(main_data)
|
||||
result.imported_tables.update(imported)
|
||||
except Exception as e:
|
||||
result.add_error(f"导入主数据库失败: {e}")
|
||||
return result
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 100, 100, "主数据库导入完成")
|
||||
|
||||
# 3. 导入知识库
|
||||
if self.kb_manager and "databases/kb_metadata.json" in zf.namelist():
|
||||
if progress_callback:
|
||||
await progress_callback("kb", 0, 100, "正在导入知识库...")
|
||||
|
||||
try:
|
||||
kb_meta_content = zf.read("databases/kb_metadata.json")
|
||||
kb_meta_data = json.loads(kb_meta_content)
|
||||
|
||||
if mode == "replace":
|
||||
await self._clear_kb_data()
|
||||
|
||||
await self._import_knowledge_bases(zf, kb_meta_data, result)
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库失败: {e}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("kb", 100, 100, "知识库导入完成")
|
||||
|
||||
# 4. 导入配置文件
|
||||
if progress_callback:
|
||||
await progress_callback("config", 0, 100, "正在导入配置文件...")
|
||||
|
||||
if "config/cmd_config.json" in zf.namelist():
|
||||
try:
|
||||
config_content = zf.read("config/cmd_config.json")
|
||||
# 备份现有配置
|
||||
if os.path.exists(self.config_path):
|
||||
backup_path = f"{self.config_path}.bak"
|
||||
shutil.copy2(self.config_path, backup_path)
|
||||
|
||||
with open(self.config_path, "wb") as f:
|
||||
f.write(config_content)
|
||||
result.imported_files["config"] = 1
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入配置文件失败: {e}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("config", 100, 100, "配置文件导入完成")
|
||||
|
||||
# 5. 导入附件文件
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 0, 100, "正在导入附件...")
|
||||
|
||||
attachment_count = await self._import_attachments(
|
||||
zf, main_data.get("attachments", [])
|
||||
)
|
||||
result.imported_files["attachments"] = attachment_count
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 100, 100, "附件导入完成")
|
||||
|
||||
# 6. 导入插件和其他目录
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"directories", 0, 100, "正在导入插件和数据目录..."
|
||||
)
|
||||
|
||||
dir_stats = await self._import_directories(zf, manifest, result)
|
||||
result.imported_directories = dir_stats
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("directories", 100, 100, "目录导入完成")
|
||||
|
||||
logger.info(f"备份导入完成: {result.to_dict()}")
|
||||
return result
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
result.add_error("无效的 ZIP 文件")
|
||||
return result
|
||||
except Exception as e:
|
||||
result.add_error(f"导入失败: {e}")
|
||||
return result
|
||||
|
||||
def _validate_version(self, manifest: dict) -> None:
|
||||
"""验证版本兼容性 - 仅允许相同主版本导入
|
||||
|
||||
注意:此方法仅在 import_all 中调用,用于双重校验。
|
||||
前端应先调用 pre_check 获取详细的版本信息并让用户确认。
|
||||
"""
|
||||
backup_version = manifest.get("astrbot_version")
|
||||
if not backup_version:
|
||||
raise ValueError("备份文件缺少版本信息")
|
||||
|
||||
# 使用新的版本兼容性检查
|
||||
version_check = self._check_version_compatibility(backup_version)
|
||||
|
||||
if version_check["status"] == "major_diff":
|
||||
raise ValueError(version_check["message"])
|
||||
|
||||
# minor_diff 和 match 都允许导入
|
||||
if version_check["status"] == "minor_diff":
|
||||
logger.warning(f"版本差异警告: {version_check['message']}")
|
||||
|
||||
async def _clear_main_db(self) -> None:
|
||||
"""清空主数据库所有表"""
|
||||
async with self.main_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, model_class in MAIN_DB_MODELS.items():
|
||||
try:
|
||||
await session.execute(delete(model_class))
|
||||
logger.debug(f"已清空表 {table_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清空表 {table_name} 失败: {e}")
|
||||
|
||||
async def _clear_kb_data(self) -> None:
|
||||
"""清空知识库数据"""
|
||||
if not self.kb_manager:
|
||||
return
|
||||
|
||||
# 清空知识库元数据表
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, model_class in KB_METADATA_MODELS.items():
|
||||
try:
|
||||
await session.execute(delete(model_class))
|
||||
logger.debug(f"已清空知识库表 {table_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清空知识库表 {table_name} 失败: {e}")
|
||||
|
||||
# 删除知识库文件目录
|
||||
for kb_id in list(self.kb_manager.kb_insts.keys()):
|
||||
try:
|
||||
kb_helper = self.kb_manager.kb_insts[kb_id]
|
||||
await kb_helper.terminate()
|
||||
if kb_helper.kb_dir.exists():
|
||||
shutil.rmtree(kb_helper.kb_dir)
|
||||
except Exception as e:
|
||||
logger.warning(f"清理知识库 {kb_id} 失败: {e}")
|
||||
|
||||
self.kb_manager.kb_insts.clear()
|
||||
|
||||
async def _import_main_database(
|
||||
self, data: dict[str, list[dict]]
|
||||
) -> dict[str, int]:
|
||||
"""导入主数据库数据"""
|
||||
imported: dict[str, int] = {}
|
||||
|
||||
async with self.main_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, rows in data.items():
|
||||
model_class = MAIN_DB_MODELS.get(table_name)
|
||||
if not model_class:
|
||||
logger.warning(f"未知的表: {table_name}")
|
||||
continue
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
try:
|
||||
# 转换 datetime 字符串为 datetime 对象
|
||||
row = self._convert_datetime_fields(row, model_class)
|
||||
obj = model_class(**row)
|
||||
session.add(obj)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入记录到 {table_name} 失败: {e}")
|
||||
|
||||
imported[table_name] = count
|
||||
logger.debug(f"导入表 {table_name}: {count} 条记录")
|
||||
|
||||
return imported
|
||||
|
||||
async def _import_knowledge_bases(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
kb_meta_data: dict[str, list[dict]],
|
||||
result: ImportResult,
|
||||
) -> None:
|
||||
"""导入知识库数据"""
|
||||
if not self.kb_manager:
|
||||
return
|
||||
|
||||
# 1. 导入知识库元数据
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, rows in kb_meta_data.items():
|
||||
model_class = KB_METADATA_MODELS.get(table_name)
|
||||
if not model_class:
|
||||
continue
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
try:
|
||||
row = self._convert_datetime_fields(row, model_class)
|
||||
obj = model_class(**row)
|
||||
session.add(obj)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入知识库记录到 {table_name} 失败: {e}")
|
||||
|
||||
result.imported_tables[f"kb_{table_name}"] = count
|
||||
|
||||
# 2. 导入每个知识库的文档和文件
|
||||
for kb_data in kb_meta_data.get("knowledge_bases", []):
|
||||
kb_id = kb_data.get("kb_id")
|
||||
if not kb_id:
|
||||
continue
|
||||
|
||||
# 创建知识库目录
|
||||
kb_dir = Path(self.kb_root_dir) / kb_id
|
||||
kb_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 导入文档数据
|
||||
doc_path = f"databases/kb_{kb_id}/documents.json"
|
||||
if doc_path in zf.namelist():
|
||||
try:
|
||||
doc_content = zf.read(doc_path)
|
||||
doc_data = json.loads(doc_content)
|
||||
|
||||
# 导入到文档存储数据库
|
||||
await self._import_kb_documents(kb_id, doc_data)
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库 {kb_id} 的文档失败: {e}")
|
||||
|
||||
# 导入 FAISS 索引
|
||||
faiss_path = f"databases/kb_{kb_id}/index.faiss"
|
||||
if faiss_path in zf.namelist():
|
||||
try:
|
||||
target_path = kb_dir / "index.faiss"
|
||||
with zf.open(faiss_path) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}")
|
||||
|
||||
# 导入媒体文件
|
||||
media_prefix = f"files/kb_media/{kb_id}/"
|
||||
for name in zf.namelist():
|
||||
if name.startswith(media_prefix):
|
||||
try:
|
||||
rel_path = name[len(media_prefix) :]
|
||||
target_path = kb_dir / rel_path
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入媒体文件 {name} 失败: {e}")
|
||||
|
||||
# 3. 重新加载知识库实例
|
||||
await self.kb_manager.load_kbs()
|
||||
|
||||
async def _import_kb_documents(self, kb_id: str, doc_data: dict) -> None:
|
||||
"""导入知识库文档到向量数据库"""
|
||||
from astrbot.core.db.vec_db.faiss_impl.document_storage import DocumentStorage
|
||||
|
||||
kb_dir = Path(self.kb_root_dir) / kb_id
|
||||
doc_db_path = kb_dir / "doc.db"
|
||||
|
||||
# 初始化文档存储
|
||||
doc_storage = DocumentStorage(str(doc_db_path))
|
||||
await doc_storage.initialize()
|
||||
|
||||
try:
|
||||
documents = doc_data.get("documents", [])
|
||||
for doc in documents:
|
||||
try:
|
||||
await doc_storage.insert_document(
|
||||
doc_id=doc.get("doc_id", ""),
|
||||
text=doc.get("text", ""),
|
||||
metadata=json.loads(doc.get("metadata", "{}")),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导入文档块失败: {e}")
|
||||
finally:
|
||||
await doc_storage.close()
|
||||
|
||||
async def _import_attachments(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
attachments: list[dict],
|
||||
) -> int:
|
||||
"""导入附件文件"""
|
||||
count = 0
|
||||
|
||||
attachments_dir = Path(self.config_path).parent / "attachments"
|
||||
attachments_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
attachment_prefix = "files/attachments/"
|
||||
for name in zf.namelist():
|
||||
if name.startswith(attachment_prefix) and name != attachment_prefix:
|
||||
try:
|
||||
# 从附件记录中找到原始路径
|
||||
attachment_id = os.path.splitext(os.path.basename(name))[0]
|
||||
original_path = None
|
||||
for att in attachments:
|
||||
if att.get("attachment_id") == attachment_id:
|
||||
original_path = att.get("path")
|
||||
break
|
||||
|
||||
if original_path:
|
||||
target_path = Path(original_path)
|
||||
else:
|
||||
target_path = attachments_dir / os.path.basename(name)
|
||||
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入附件 {name} 失败: {e}")
|
||||
|
||||
return count
|
||||
|
||||
async def _import_directories(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
manifest: dict,
|
||||
result: ImportResult,
|
||||
) -> dict[str, int]:
|
||||
"""导入插件和其他数据目录
|
||||
|
||||
Args:
|
||||
zf: ZIP 文件对象
|
||||
manifest: 备份清单
|
||||
result: 导入结果对象
|
||||
|
||||
Returns:
|
||||
dict: 每个目录导入的文件数量
|
||||
"""
|
||||
dir_stats: dict[str, int] = {}
|
||||
|
||||
# 检查备份版本是否支持目录备份(需要版本 >= 1.1)
|
||||
backup_version = manifest.get("version", "1.0")
|
||||
if VersionComparator.compare_version(backup_version, "1.1") < 0:
|
||||
logger.info("备份版本不支持目录备份,跳过目录导入")
|
||||
return dir_stats
|
||||
|
||||
backed_up_dirs = manifest.get("directories", [])
|
||||
backup_directories = get_backup_directories()
|
||||
|
||||
for dir_name in backed_up_dirs:
|
||||
if dir_name not in backup_directories:
|
||||
result.add_warning(f"未知的目录类型: {dir_name}")
|
||||
continue
|
||||
|
||||
target_dir = Path(backup_directories[dir_name])
|
||||
archive_prefix = f"directories/{dir_name}/"
|
||||
|
||||
file_count = 0
|
||||
|
||||
try:
|
||||
# 获取该目录下的所有文件
|
||||
dir_files = [
|
||||
name
|
||||
for name in zf.namelist()
|
||||
if name.startswith(archive_prefix) and name != archive_prefix
|
||||
]
|
||||
|
||||
if not dir_files:
|
||||
continue
|
||||
|
||||
# 备份现有目录(如果存在)
|
||||
if target_dir.exists():
|
||||
backup_path = Path(f"{target_dir}.bak")
|
||||
if backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
shutil.move(str(target_dir), str(backup_path))
|
||||
logger.debug(f"已备份现有目录 {target_dir} 到 {backup_path}")
|
||||
|
||||
# 创建目标目录
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 解压文件
|
||||
for name in dir_files:
|
||||
try:
|
||||
# 计算相对路径
|
||||
rel_path = name[len(archive_prefix) :]
|
||||
if not rel_path: # 跳过目录条目
|
||||
continue
|
||||
|
||||
target_path = target_dir / rel_path
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
file_count += 1
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入文件 {name} 失败: {e}")
|
||||
|
||||
dir_stats[dir_name] = file_count
|
||||
logger.debug(f"导入目录 {dir_name}: {file_count} 个文件")
|
||||
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入目录 {dir_name} 失败: {e}")
|
||||
dir_stats[dir_name] = 0
|
||||
|
||||
return dir_stats
|
||||
|
||||
def _convert_datetime_fields(self, row: dict, model_class: type) -> dict:
|
||||
"""转换 datetime 字符串字段为 datetime 对象"""
|
||||
result = row.copy()
|
||||
|
||||
# 获取模型的 datetime 字段
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
|
||||
try:
|
||||
mapper = sa_inspect(model_class)
|
||||
for column in mapper.columns:
|
||||
if column.name in result and result[column.name] is not None:
|
||||
# 检查是否是 datetime 类型的列
|
||||
from sqlalchemy import DateTime
|
||||
|
||||
if isinstance(column.type, DateTime):
|
||||
value = result[column.name]
|
||||
if isinstance(value, str):
|
||||
# 解析 ISO 格式的日期时间字符串
|
||||
result[column.name] = datetime.fromisoformat(value)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return result
|
||||
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.10.2"
|
||||
VERSION = "4.10.3"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
|
||||
+1
-1
@@ -58,7 +58,7 @@ def is_plugin_path(pathname):
|
||||
return False
|
||||
|
||||
norm_path = os.path.normpath(pathname)
|
||||
return ("data/plugins" in norm_path) or ("packages/" in norm_path)
|
||||
return ("data/plugins" in norm_path) or ("astrbot/builtin_stars/" in norm_path)
|
||||
|
||||
|
||||
def get_short_level_name(level_name):
|
||||
|
||||
@@ -390,7 +390,7 @@ class InternalAgentSubStage(Stage):
|
||||
return
|
||||
|
||||
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
||||
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
||||
# func_tool selection 现在已经转移到 astrbot/builtin_stars/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
|
||||
@@ -136,7 +136,8 @@ class WakingCheckStage(Stage):
|
||||
):
|
||||
if (
|
||||
self.disable_builtin_commands
|
||||
and handler.handler_module_path == "packages.builtin_commands.main"
|
||||
and handler.handler_module_path
|
||||
== "astrbot.builtin_stars.builtin_commands.main"
|
||||
):
|
||||
logger.debug("skipping builtin command")
|
||||
continue
|
||||
|
||||
@@ -14,6 +14,7 @@ import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.message import (
|
||||
AssistantMessageSegment,
|
||||
ContentPart,
|
||||
ToolCall,
|
||||
ToolCallMessageSegment,
|
||||
)
|
||||
@@ -92,6 +93,8 @@ class ProviderRequest:
|
||||
"""会话 ID"""
|
||||
image_urls: list[str] = field(default_factory=list)
|
||||
"""图片 URL 列表"""
|
||||
extra_user_content_parts: list[ContentPart] = field(default_factory=list)
|
||||
"""额外的用户消息内容部分列表,用于在用户消息后添加额外的内容块(如系统提醒、指令等)。支持 dict 或 ContentPart 对象"""
|
||||
func_tool: ToolSet | None = None
|
||||
"""可用的函数工具"""
|
||||
contexts: list[dict] = field(default_factory=list)
|
||||
@@ -166,13 +169,23 @@ class ProviderRequest:
|
||||
|
||||
async def assemble_context(self) -> dict:
|
||||
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
|
||||
# 构建内容块列表
|
||||
content_blocks = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if self.prompt and self.prompt.strip():
|
||||
content_blocks.append({"type": "text", "text": self.prompt})
|
||||
elif self.image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if self.extra_user_content_parts:
|
||||
for part in self.extra_user_content_parts:
|
||||
content_blocks.append(part.model_dump())
|
||||
|
||||
# 3. 图片内容
|
||||
if self.image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": self.prompt if self.prompt else "[图片]"},
|
||||
],
|
||||
}
|
||||
for image_url in self.image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
@@ -185,11 +198,21 @@ class ProviderRequest:
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
content_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": image_data}},
|
||||
)
|
||||
return user_content
|
||||
return {"role": "user", "content": self.prompt}
|
||||
|
||||
# 只有当只有一个来自 prompt 的文本块且没有额外内容块时,才降级为简单格式以保持向后兼容
|
||||
if (
|
||||
len(content_blocks) == 1
|
||||
and content_blocks[0]["type"] == "text"
|
||||
and not self.extra_user_content_parts
|
||||
and not self.image_urls
|
||||
):
|
||||
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content_blocks}
|
||||
|
||||
async def _encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TypeAlias, Union
|
||||
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.message import ContentPart, Message
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
@@ -103,6 +103,7 @@ class Provider(AbstractProvider):
|
||||
system_prompt: str | None = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||
model: str | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
||||
@@ -114,6 +115,7 @@ class Provider(AbstractProvider):
|
||||
tools: tool set
|
||||
contexts: 上下文,和 prompt 二选一使用
|
||||
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
||||
extra_user_content_parts: 额外的内容块列表,用于在用户消息后添加额外的文本块(如系统提醒、指令等)
|
||||
kwargs: 其他参数
|
||||
|
||||
Notes:
|
||||
|
||||
@@ -11,6 +11,7 @@ from anthropic.types.usage import Usage
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.agent.message import ContentPart, ImageURLPart, TextPart
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
@@ -68,7 +69,7 @@ class ProviderAnthropic(Provider):
|
||||
blocks = []
|
||||
if isinstance(message["content"], str):
|
||||
blocks.append({"type": "text", "text": message["content"]})
|
||||
if "tool_calls" in message:
|
||||
if "tool_calls" in message and isinstance(message["tool_calls"], list):
|
||||
for tool_call in message["tool_calls"]:
|
||||
blocks.append( # noqa: PERF401
|
||||
{
|
||||
@@ -132,6 +133,9 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
extra_body = self.provider_config.get("custom_extra_body", {})
|
||||
|
||||
if "max_tokens" not in payloads:
|
||||
payloads["max_tokens"] = 1024
|
||||
|
||||
completion = await self.client.messages.create(
|
||||
**payloads, stream=False, extra_body=extra_body
|
||||
)
|
||||
@@ -181,6 +185,9 @@ class ProviderAnthropic(Provider):
|
||||
usage = TokenUsage()
|
||||
extra_body = self.provider_config.get("custom_extra_body", {})
|
||||
|
||||
if "max_tokens" not in payloads:
|
||||
payloads["max_tokens"] = 1024
|
||||
|
||||
async with self.client.messages.stream(
|
||||
**payloads, extra_body=extra_body
|
||||
) as stream:
|
||||
@@ -296,13 +303,16 @@ class ProviderAnthropic(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -342,21 +352,24 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
prompt=None,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
):
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -388,15 +401,15 @@ class ProviderAnthropic(Provider):
|
||||
async for llm_response in self._query_stream(payloads, func_tool):
|
||||
yield llm_response
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
|
||||
async def assemble_context(
|
||||
self,
|
||||
text: str,
|
||||
image_urls: list[str] | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
):
|
||||
"""组装上下文,支持文本和图片"""
|
||||
if not image_urls:
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
content = []
|
||||
content.append({"type": "text", "text": text})
|
||||
|
||||
for image_url in image_urls:
|
||||
async def resolve_image_url(image_url: str) -> dict | None:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
@@ -408,28 +421,68 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
return None
|
||||
|
||||
# Get mime type for the image
|
||||
mime_type, _ = guess_type(image_url)
|
||||
if not mime_type:
|
||||
mime_type = "image/jpeg" # Default to JPEG if can't determine
|
||||
|
||||
content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": (
|
||||
image_data.split("base64,")[1]
|
||||
if "base64," in image_data
|
||||
else image_data
|
||||
),
|
||||
},
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": (
|
||||
image_data.split("base64,")[1]
|
||||
if "base64," in image_data
|
||||
else image_data
|
||||
),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
content = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if text:
|
||||
content.append({"type": "text", "text": text})
|
||||
elif image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content.append({"type": "text", "text": "[图片]"})
|
||||
elif extra_user_content_parts:
|
||||
# 如果只有额外内容块,也需要添加占位文本
|
||||
content.append({"type": "text", "text": " "})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if extra_user_content_parts:
|
||||
for block in extra_user_content_parts:
|
||||
if isinstance(block, TextPart):
|
||||
content.append({"type": "text", "text": block.text})
|
||||
elif isinstance(block, ImageURLPart):
|
||||
image_dict = await resolve_image_url(block.image_url.url)
|
||||
if image_dict:
|
||||
content.append(image_dict)
|
||||
else:
|
||||
raise ValueError(f"不支持的额外内容块类型: {type(block)}")
|
||||
|
||||
# 3. 图片内容
|
||||
if image_urls:
|
||||
for image_url in image_urls:
|
||||
image_dict = await resolve_image_url(image_url)
|
||||
if image_dict:
|
||||
content.append(image_dict)
|
||||
|
||||
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||
if (
|
||||
text
|
||||
and not extra_user_content_parts
|
||||
and not image_urls
|
||||
and len(content) == 1
|
||||
and content[0]["type"] == "text"
|
||||
):
|
||||
return {"role": "user", "content": content[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
|
||||
@@ -56,10 +56,14 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
"api_base",
|
||||
"https://api.fish-audio.cn/v1",
|
||||
)
|
||||
try:
|
||||
self.timeout: int = int(provider_config.get("timeout", 20))
|
||||
except ValueError:
|
||||
self.timeout = 20
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.chosen_api_key}",
|
||||
}
|
||||
self.set_model(provider_config["model"])
|
||||
self.set_model(provider_config.get("model", None))
|
||||
|
||||
async def _get_reference_id_by_character(self, character: str) -> str | None:
|
||||
"""获取角色的reference_id
|
||||
@@ -135,17 +139,21 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav")
|
||||
self.headers["content-type"] = "application/msgpack"
|
||||
request = await self._generate_request(text)
|
||||
async with AsyncClient(base_url=self.api_base).stream(
|
||||
async with AsyncClient(base_url=self.api_base, timeout=self.timeout).stream(
|
||||
"POST",
|
||||
"/tts",
|
||||
headers=self.headers,
|
||||
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
||||
) as response:
|
||||
if response.headers["content-type"] == "audio/wav":
|
||||
if response.status_code == 200 and response.headers.get(
|
||||
"content-type", ""
|
||||
).startswith("audio/"):
|
||||
with open(path, "wb") as f:
|
||||
async for chunk in response.aiter_bytes():
|
||||
f.write(chunk)
|
||||
return path
|
||||
body = await response.aread()
|
||||
text = body.decode("utf-8", errors="replace")
|
||||
raise Exception(f"Fish Audio API请求失败: {text}")
|
||||
error_bytes = await response.aread()
|
||||
error_text = error_bytes.decode("utf-8", errors="replace")[:1024]
|
||||
raise Exception(
|
||||
f"Fish Audio API请求失败: 状态码 {response.status_code}, 响应内容: {error_text}"
|
||||
)
|
||||
|
||||
@@ -13,6 +13,7 @@ from google.genai.errors import APIError
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.agent.message import ContentPart, ImageURLPart, TextPart
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
@@ -680,13 +681,16 @@ class ProviderGoogleGenAI(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -732,13 +736,16 @@ class ProviderGoogleGenAI(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -797,33 +804,75 @@ class ProviderGoogleGenAI(Provider):
|
||||
self.chosen_api_key = key
|
||||
self._init_client()
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
|
||||
async def assemble_context(
|
||||
self,
|
||||
text: str,
|
||||
image_urls: list[str] | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
):
|
||||
"""组装上下文。"""
|
||||
if image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
||||
|
||||
async def resolve_image_part(image_url: str) -> dict | None:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
return None
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
|
||||
# 构建内容块列表
|
||||
content_blocks = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if text:
|
||||
content_blocks.append({"type": "text", "text": text})
|
||||
elif image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||
elif extra_user_content_parts:
|
||||
# 如果只有额外内容块,也需要添加占位文本
|
||||
content_blocks.append({"type": "text", "text": " "})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if extra_user_content_parts:
|
||||
for part in extra_user_content_parts:
|
||||
if isinstance(part, TextPart):
|
||||
content_blocks.append({"type": "text", "text": part.text})
|
||||
elif isinstance(part, ImageURLPart):
|
||||
image_part = await resolve_image_part(part.image_url.url)
|
||||
if image_part:
|
||||
content_blocks.append(image_part)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
},
|
||||
)
|
||||
return user_content
|
||||
return {"role": "user", "content": text}
|
||||
raise ValueError(f"不支持的额外内容块类型: {type(part)}")
|
||||
|
||||
# 3. 图片内容
|
||||
if image_urls:
|
||||
for image_url in image_urls:
|
||||
image_part = await resolve_image_part(image_url)
|
||||
if image_part:
|
||||
content_blocks.append(image_part)
|
||||
|
||||
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||
if (
|
||||
text
|
||||
and not extra_user_content_parts
|
||||
and not image_urls
|
||||
and len(content_blocks) == 1
|
||||
and content_blocks[0]["type"] == "text"
|
||||
):
|
||||
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content_blocks}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
|
||||
@@ -17,7 +17,7 @@ from openai.types.completion_usage import CompletionUsage
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.message import ContentPart, ImageURLPart, Message, TextPart
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
|
||||
@@ -348,6 +348,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt: str | None = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||
model: str | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
**kwargs,
|
||||
) -> tuple:
|
||||
"""准备聊天所需的有效载荷和上下文"""
|
||||
@@ -355,7 +356,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -476,6 +479,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
payloads, context_query = await self._prepare_chat_payload(
|
||||
@@ -485,6 +489,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt,
|
||||
tool_calls_result,
|
||||
model=model,
|
||||
extra_user_content_parts=extra_user_content_parts,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -624,33 +629,71 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self,
|
||||
text: str,
|
||||
image_urls: list[str] | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
) -> dict:
|
||||
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
|
||||
if image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
||||
|
||||
async def resolve_image_part(image_url: str) -> dict | None:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
return None
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
|
||||
# 构建内容块列表
|
||||
content_blocks = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if text:
|
||||
content_blocks.append({"type": "text", "text": text})
|
||||
elif image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||
elif extra_user_content_parts:
|
||||
# 如果只有额外内容块,也需要添加占位文本
|
||||
content_blocks.append({"type": "text", "text": " "})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if extra_user_content_parts:
|
||||
for part in extra_user_content_parts:
|
||||
if isinstance(part, TextPart):
|
||||
content_blocks.append({"type": "text", "text": part.text})
|
||||
elif isinstance(part, ImageURLPart):
|
||||
image_part = await resolve_image_part(part.image_url.url)
|
||||
if image_part:
|
||||
content_blocks.append(image_part)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
},
|
||||
)
|
||||
return user_content
|
||||
return {"role": "user", "content": text}
|
||||
raise ValueError(f"不支持的额外内容块类型: {type(part)}")
|
||||
|
||||
# 3. 图片内容
|
||||
if image_urls:
|
||||
for image_url in image_urls:
|
||||
image_part = await resolve_image_part(image_url)
|
||||
if image_part:
|
||||
content_blocks.append(image_part)
|
||||
|
||||
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||
if (
|
||||
text
|
||||
and not extra_user_content_parts
|
||||
and not image_urls
|
||||
and len(content_blocks) == 1
|
||||
and content_blocks[0]["type"] == "text"
|
||||
):
|
||||
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content_blocks}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
|
||||
@@ -377,7 +377,7 @@ class Context:
|
||||
if not module_path:
|
||||
_parts = []
|
||||
module_part = tool.__module__.split(".")
|
||||
flags = ["packages", "plugins"]
|
||||
flags = ["builtin_stars", "plugins"]
|
||||
for i, part in enumerate(module_part):
|
||||
_parts.append(part)
|
||||
if part in flags and i + 1 < len(module_part):
|
||||
|
||||
@@ -18,6 +18,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_config_path,
|
||||
get_astrbot_path,
|
||||
get_astrbot_plugin_path,
|
||||
)
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
@@ -49,13 +50,10 @@ class PluginManager:
|
||||
"""存储插件的路径。即 data/plugins"""
|
||||
self.plugin_config_path = get_astrbot_config_path()
|
||||
"""存储插件配置的路径。data/config"""
|
||||
self.reserved_plugin_path = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"../../../packages",
|
||||
),
|
||||
self.reserved_plugin_path = os.path.join(
|
||||
get_astrbot_path(), "astrbot", "builtin_stars"
|
||||
)
|
||||
"""保留插件的路径。在 packages 目录下"""
|
||||
"""保留插件的路径。在 astrbot/builtin_stars 目录下"""
|
||||
self.conf_schema_fname = "_conf_schema.json"
|
||||
self.logo_fname = "logo.png"
|
||||
"""插件配置 Schema 文件名"""
|
||||
@@ -252,7 +250,7 @@ class PluginManager:
|
||||
list[str]: 与该插件相关的模块名列表
|
||||
|
||||
"""
|
||||
prefix = "packages." if is_reserved else "data.plugins."
|
||||
prefix = "astrbot.builtin_stars." if is_reserved else "data.plugins."
|
||||
return [
|
||||
key
|
||||
for key in list(sys.modules.keys())
|
||||
@@ -270,7 +268,7 @@ class PluginManager:
|
||||
可以基于模块名模式或插件目录名移除模块,用于清理插件相关的模块缓存
|
||||
|
||||
Args:
|
||||
module_patterns: 要移除的模块名模式列表(例如 ["data.plugins", "packages"])
|
||||
module_patterns: 要移除的模块名模式列表(例如 ["data.plugins", "astrbot.builtin_stars"])
|
||||
root_dir_name: 插件根目录名,用于移除与该插件相关的所有模块
|
||||
is_reserved: 插件是否为保留插件(影响模块路径前缀)
|
||||
|
||||
@@ -382,9 +380,9 @@ class PluginManager:
|
||||
reserved = plugin_module.get(
|
||||
"reserved",
|
||||
False,
|
||||
) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。
|
||||
) # 是否是保留插件。目前在 astrbot/builtin_stars 目录下的都是保留插件。保留插件不可以卸载。
|
||||
|
||||
path = "data.plugins." if not reserved else "packages."
|
||||
path = "data.plugins." if not reserved else "astrbot.builtin_stars."
|
||||
path += root_dir_name + "." + module_str
|
||||
|
||||
# 检查是否需要载入指定的插件
|
||||
@@ -829,7 +827,7 @@ class PluginManager:
|
||||
if (
|
||||
mp
|
||||
and mp.startswith(plugin_module_path)
|
||||
and not mp.endswith(("packages", "data.plugins"))
|
||||
and not mp.endswith(("astrbot.builtin_stars", "data.plugins"))
|
||||
):
|
||||
to_remove.append(func_tool)
|
||||
for func_tool in to_remove:
|
||||
@@ -884,7 +882,7 @@ class PluginManager:
|
||||
plugin.module_path
|
||||
and mp
|
||||
and plugin.module_path.startswith(mp)
|
||||
and not mp.endswith(("packages", "data.plugins"))
|
||||
and not mp.endswith(("astrbot.builtin_stars", "data.plugins"))
|
||||
):
|
||||
func_tool.active = False
|
||||
if func_tool.name not in inactivated_llm_tools:
|
||||
@@ -933,7 +931,7 @@ class PluginManager:
|
||||
plugin.module_path
|
||||
and mp
|
||||
and plugin.module_path.startswith(mp)
|
||||
and not mp.endswith(("packages", "data.plugins"))
|
||||
and not mp.endswith(("astrbot.builtin_stars", "data.plugins"))
|
||||
and func_tool.name in inactivated_llm_tools
|
||||
):
|
||||
inactivated_llm_tools.remove(func_tool.name)
|
||||
|
||||
@@ -5,6 +5,10 @@
|
||||
数据目录路径:固定为根目录下的 data 目录
|
||||
配置文件路径:固定为数据目录下的 config 目录
|
||||
插件目录路径:固定为数据目录下的 plugins 目录
|
||||
插件数据目录路径:固定为数据目录下的 plugin_data 目录
|
||||
T2I 模板目录路径:固定为数据目录下的 t2i_templates 目录
|
||||
WebChat 数据目录路径:固定为数据目录下的 webchat 目录
|
||||
临时文件目录路径:固定为数据目录下的 temp 目录
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -37,3 +41,33 @@ def get_astrbot_config_path() -> str:
|
||||
def get_astrbot_plugin_path() -> str:
|
||||
"""获取Astrbot插件目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins"))
|
||||
|
||||
|
||||
def get_astrbot_plugin_data_path() -> str:
|
||||
"""获取Astrbot插件数据目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugin_data"))
|
||||
|
||||
|
||||
def get_astrbot_t2i_templates_path() -> str:
|
||||
"""获取Astrbot T2I 模板目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "t2i_templates"))
|
||||
|
||||
|
||||
def get_astrbot_webchat_path() -> str:
|
||||
"""获取Astrbot WebChat 数据目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "webchat"))
|
||||
|
||||
|
||||
def get_astrbot_temp_path() -> str:
|
||||
"""获取Astrbot临时文件目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "temp"))
|
||||
|
||||
|
||||
def get_astrbot_knowledge_base_path() -> str:
|
||||
"""获取Astrbot知识库根目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "knowledge_base"))
|
||||
|
||||
|
||||
def get_astrbot_backups_path() -> str:
|
||||
"""获取Astrbot备份目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "backups"))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .auth import AuthRoute
|
||||
from .backup import BackupRoute
|
||||
from .chat import ChatRoute
|
||||
from .command import CommandRoute
|
||||
from .config import ConfigRoute
|
||||
@@ -17,6 +18,7 @@ from .update import UpdateRoute
|
||||
|
||||
__all__ = [
|
||||
"AuthRoute",
|
||||
"BackupRoute",
|
||||
"ChatRoute",
|
||||
"CommandRoute",
|
||||
"ConfigRoute",
|
||||
|
||||
@@ -0,0 +1,589 @@
|
||||
"""备份管理 API 路由"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from quart import request, send_file
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.backup.exporter import AstrBotExporter
|
||||
from astrbot.core.backup.importer import AstrBotImporter
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_backups_path,
|
||||
get_astrbot_data_path,
|
||||
)
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
def secure_filename(filename: str) -> str:
|
||||
"""清洗文件名,移除路径遍历字符和危险字符
|
||||
|
||||
Args:
|
||||
filename: 原始文件名
|
||||
|
||||
Returns:
|
||||
安全的文件名
|
||||
"""
|
||||
# 跨平台处理:先将反斜杠替换为正斜杠,再取文件名
|
||||
filename = filename.replace("\\", "/")
|
||||
# 仅保留文件名部分,移除路径
|
||||
filename = os.path.basename(filename)
|
||||
|
||||
# 替换路径遍历字符
|
||||
filename = filename.replace("..", "_")
|
||||
|
||||
# 仅保留字母、数字、下划线、连字符、点
|
||||
filename = re.sub(r"[^\w\-.]", "_", filename)
|
||||
|
||||
# 移除前导点(隐藏文件)和尾部点
|
||||
filename = filename.strip(".")
|
||||
|
||||
# 如果文件名为空或只包含下划线,生成一个默认名称
|
||||
if not filename or filename.replace("_", "") == "":
|
||||
filename = "backup"
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def generate_unique_filename(original_filename: str) -> str:
|
||||
"""生成唯一的文件名,添加时间戳前缀
|
||||
|
||||
Args:
|
||||
original_filename: 原始文件名(已清洗)
|
||||
|
||||
Returns:
|
||||
唯一的文件名
|
||||
"""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
name, ext = os.path.splitext(original_filename)
|
||||
return f"uploaded_{timestamp}_{name}{ext}"
|
||||
|
||||
|
||||
class BackupRoute(Route):
|
||||
"""备份管理路由
|
||||
|
||||
提供备份导出、导入、列表等 API 接口
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RouteContext,
|
||||
db: BaseDatabase,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.db = db
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.backup_dir = get_astrbot_backups_path()
|
||||
self.data_dir = get_astrbot_data_path()
|
||||
|
||||
# 任务状态跟踪
|
||||
self.backup_tasks: dict[str, dict] = {}
|
||||
self.backup_progress: dict[str, dict] = {}
|
||||
|
||||
# 注册路由
|
||||
self.routes = {
|
||||
"/backup/list": ("GET", self.list_backups),
|
||||
"/backup/export": ("POST", self.export_backup),
|
||||
"/backup/upload": ("POST", self.upload_backup), # 上传文件
|
||||
"/backup/check": ("POST", self.check_backup), # 预检查
|
||||
"/backup/import": ("POST", self.import_backup), # 确认导入
|
||||
"/backup/progress": ("GET", self.get_progress),
|
||||
"/backup/download": ("GET", self.download_backup),
|
||||
"/backup/delete": ("POST", self.delete_backup),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
def _init_task(self, task_id: str, task_type: str, status: str = "pending") -> None:
|
||||
"""初始化任务状态"""
|
||||
self.backup_tasks[task_id] = {
|
||||
"type": task_type,
|
||||
"status": status,
|
||||
"result": None,
|
||||
"error": None,
|
||||
}
|
||||
self.backup_progress[task_id] = {
|
||||
"status": status,
|
||||
"stage": "waiting",
|
||||
"current": 0,
|
||||
"total": 100,
|
||||
"message": "",
|
||||
}
|
||||
|
||||
def _set_task_result(
|
||||
self,
|
||||
task_id: str,
|
||||
status: str,
|
||||
result: dict | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""设置任务结果"""
|
||||
if task_id in self.backup_tasks:
|
||||
self.backup_tasks[task_id]["status"] = status
|
||||
self.backup_tasks[task_id]["result"] = result
|
||||
self.backup_tasks[task_id]["error"] = error
|
||||
if task_id in self.backup_progress:
|
||||
self.backup_progress[task_id]["status"] = status
|
||||
|
||||
def _update_progress(
|
||||
self,
|
||||
task_id: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
stage: str | None = None,
|
||||
current: int | None = None,
|
||||
total: int | None = None,
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
"""更新任务进度"""
|
||||
if task_id not in self.backup_progress:
|
||||
return
|
||||
p = self.backup_progress[task_id]
|
||||
if status is not None:
|
||||
p["status"] = status
|
||||
if stage is not None:
|
||||
p["stage"] = stage
|
||||
if current is not None:
|
||||
p["current"] = current
|
||||
if total is not None:
|
||||
p["total"] = total
|
||||
if message is not None:
|
||||
p["message"] = message
|
||||
|
||||
def _make_progress_callback(self, task_id: str):
|
||||
"""创建进度回调函数"""
|
||||
|
||||
async def _callback(stage: str, current: int, total: int, message: str = ""):
|
||||
self._update_progress(
|
||||
task_id,
|
||||
status="processing",
|
||||
stage=stage,
|
||||
current=current,
|
||||
total=total,
|
||||
message=message,
|
||||
)
|
||||
|
||||
return _callback
|
||||
|
||||
async def list_backups(self):
|
||||
"""获取备份列表
|
||||
|
||||
Query 参数:
|
||||
- page: 页码 (默认 1)
|
||||
- page_size: 每页数量 (默认 20)
|
||||
"""
|
||||
try:
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 20, type=int)
|
||||
|
||||
# 确保备份目录存在
|
||||
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 获取所有备份文件
|
||||
backup_files = []
|
||||
for filename in os.listdir(self.backup_dir):
|
||||
if filename.endswith(".zip") and filename.startswith("astrbot_backup_"):
|
||||
file_path = os.path.join(self.backup_dir, filename)
|
||||
stat = os.stat(file_path)
|
||||
backup_files.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"size": stat.st_size,
|
||||
"created_at": stat.st_mtime,
|
||||
}
|
||||
)
|
||||
|
||||
# 按创建时间倒序排序
|
||||
backup_files.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
|
||||
# 分页
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
items = backup_files[start:end]
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"items": items,
|
||||
"total": len(backup_files),
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取备份列表失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"获取备份列表失败: {e!s}").__dict__
|
||||
|
||||
async def export_backup(self):
|
||||
"""创建备份
|
||||
|
||||
返回:
|
||||
- task_id: 任务ID,用于查询导出进度
|
||||
"""
|
||||
try:
|
||||
# 生成任务ID
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# 初始化任务状态
|
||||
self._init_task(task_id, "export", "pending")
|
||||
|
||||
# 启动后台导出任务
|
||||
asyncio.create_task(self._background_export_task(task_id))
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"message": "export task created, processing in background",
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建备份失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"创建备份失败: {e!s}").__dict__
|
||||
|
||||
async def _background_export_task(self, task_id: str):
|
||||
"""后台导出任务"""
|
||||
try:
|
||||
self._update_progress(task_id, status="processing", message="正在初始化...")
|
||||
|
||||
# 获取知识库管理器
|
||||
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
|
||||
|
||||
exporter = AstrBotExporter(
|
||||
main_db=self.db,
|
||||
kb_manager=kb_manager,
|
||||
config_path=os.path.join(self.data_dir, "cmd_config.json"),
|
||||
)
|
||||
|
||||
# 创建进度回调
|
||||
progress_callback = self._make_progress_callback(task_id)
|
||||
|
||||
# 执行导出
|
||||
zip_path = await exporter.export_all(
|
||||
output_dir=self.backup_dir,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# 设置成功结果
|
||||
self._set_task_result(
|
||||
task_id,
|
||||
"completed",
|
||||
result={
|
||||
"filename": os.path.basename(zip_path),
|
||||
"path": zip_path,
|
||||
"size": os.path.getsize(zip_path),
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"后台导出任务 {task_id} 失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
self._set_task_result(task_id, "failed", error=str(e))
|
||||
|
||||
async def upload_backup(self):
|
||||
"""上传备份文件
|
||||
|
||||
将备份文件上传到服务器,返回保存的文件名。
|
||||
上传后应调用 check_backup 进行预检查。
|
||||
|
||||
Form Data:
|
||||
- file: 备份文件 (.zip)
|
||||
|
||||
返回:
|
||||
- filename: 保存的文件名
|
||||
"""
|
||||
try:
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return Response().error("缺少备份文件").__dict__
|
||||
|
||||
file = files["file"]
|
||||
if not file.filename or not file.filename.endswith(".zip"):
|
||||
return Response().error("请上传 ZIP 格式的备份文件").__dict__
|
||||
|
||||
# 清洗文件名并生成唯一名称,防止路径遍历和覆盖
|
||||
safe_filename = secure_filename(file.filename)
|
||||
unique_filename = generate_unique_filename(safe_filename)
|
||||
|
||||
# 保存上传的文件
|
||||
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
|
||||
zip_path = os.path.join(self.backup_dir, unique_filename)
|
||||
await file.save(zip_path)
|
||||
|
||||
logger.info(
|
||||
f"上传的备份文件已保存: {unique_filename} (原始名称: {file.filename})"
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"filename": unique_filename,
|
||||
"original_filename": file.filename,
|
||||
"size": os.path.getsize(zip_path),
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"上传备份文件失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"上传备份文件失败: {e!s}").__dict__
|
||||
|
||||
async def check_backup(self):
|
||||
"""预检查备份文件
|
||||
|
||||
检查备份文件的版本兼容性,返回确认信息。
|
||||
用户确认后调用 import_backup 执行导入。
|
||||
|
||||
JSON Body:
|
||||
- filename: 已上传的备份文件名
|
||||
|
||||
返回:
|
||||
- ImportPreCheckResult: 预检查结果
|
||||
"""
|
||||
try:
|
||||
data = await request.json
|
||||
filename = data.get("filename")
|
||||
if not filename:
|
||||
return Response().error("缺少 filename 参数").__dict__
|
||||
|
||||
# 安全检查 - 防止路径遍历
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
return Response().error("无效的文件名").__dict__
|
||||
|
||||
zip_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(zip_path):
|
||||
return Response().error(f"备份文件不存在: {filename}").__dict__
|
||||
|
||||
# 获取知识库管理器(用于构造 importer)
|
||||
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
|
||||
|
||||
importer = AstrBotImporter(
|
||||
main_db=self.db,
|
||||
kb_manager=kb_manager,
|
||||
config_path=os.path.join(self.data_dir, "cmd_config.json"),
|
||||
)
|
||||
|
||||
# 执行预检查
|
||||
check_result = importer.pre_check(zip_path)
|
||||
|
||||
return Response().ok(check_result.to_dict()).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"预检查备份文件失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"预检查备份文件失败: {e!s}").__dict__
|
||||
|
||||
async def import_backup(self):
|
||||
"""执行备份导入
|
||||
|
||||
在用户确认后执行实际的导入操作。
|
||||
需要先调用 upload_backup 上传文件,再调用 check_backup 预检查。
|
||||
|
||||
JSON Body:
|
||||
- filename: 已上传的备份文件名(必填)
|
||||
- confirmed: 用户已确认(必填,必须为 true)
|
||||
|
||||
返回:
|
||||
- task_id: 任务ID,用于查询导入进度
|
||||
"""
|
||||
try:
|
||||
data = await request.json
|
||||
filename = data.get("filename")
|
||||
confirmed = data.get("confirmed", False)
|
||||
|
||||
if not filename:
|
||||
return Response().error("缺少 filename 参数").__dict__
|
||||
|
||||
if not confirmed:
|
||||
return (
|
||||
Response()
|
||||
.error("请先确认导入。导入将会清空并覆盖现有数据,此操作不可撤销。")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
# 安全检查 - 防止路径遍历
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
return Response().error("无效的文件名").__dict__
|
||||
|
||||
zip_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(zip_path):
|
||||
return Response().error(f"备份文件不存在: {filename}").__dict__
|
||||
|
||||
# 生成任务ID
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# 初始化任务状态
|
||||
self._init_task(task_id, "import", "pending")
|
||||
|
||||
# 启动后台导入任务
|
||||
asyncio.create_task(self._background_import_task(task_id, zip_path))
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"message": "import task created, processing in background",
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"导入备份失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"导入备份失败: {e!s}").__dict__
|
||||
|
||||
async def _background_import_task(self, task_id: str, zip_path: str):
|
||||
"""后台导入任务"""
|
||||
try:
|
||||
self._update_progress(task_id, status="processing", message="正在初始化...")
|
||||
|
||||
# 获取知识库管理器
|
||||
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
|
||||
|
||||
importer = AstrBotImporter(
|
||||
main_db=self.db,
|
||||
kb_manager=kb_manager,
|
||||
config_path=os.path.join(self.data_dir, "cmd_config.json"),
|
||||
)
|
||||
|
||||
# 创建进度回调
|
||||
progress_callback = self._make_progress_callback(task_id)
|
||||
|
||||
# 执行导入
|
||||
result = await importer.import_all(
|
||||
zip_path=zip_path,
|
||||
mode="replace",
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# 设置结果
|
||||
if result.success:
|
||||
self._set_task_result(
|
||||
task_id,
|
||||
"completed",
|
||||
result=result.to_dict(),
|
||||
)
|
||||
else:
|
||||
self._set_task_result(
|
||||
task_id,
|
||||
"failed",
|
||||
error="; ".join(result.errors),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"后台导入任务 {task_id} 失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
self._set_task_result(task_id, "failed", error=str(e))
|
||||
|
||||
async def get_progress(self):
|
||||
"""获取任务进度
|
||||
|
||||
Query 参数:
|
||||
- task_id: 任务 ID (必填)
|
||||
"""
|
||||
try:
|
||||
task_id = request.args.get("task_id")
|
||||
if not task_id:
|
||||
return Response().error("缺少参数 task_id").__dict__
|
||||
|
||||
if task_id not in self.backup_tasks:
|
||||
return Response().error("找不到该任务").__dict__
|
||||
|
||||
task_info = self.backup_tasks[task_id]
|
||||
status = task_info["status"]
|
||||
|
||||
response_data = {
|
||||
"task_id": task_id,
|
||||
"type": task_info["type"],
|
||||
"status": status,
|
||||
}
|
||||
|
||||
# 如果任务正在处理,返回进度信息
|
||||
if status == "processing" and task_id in self.backup_progress:
|
||||
response_data["progress"] = self.backup_progress[task_id]
|
||||
|
||||
# 如果任务完成,返回结果
|
||||
if status == "completed":
|
||||
response_data["result"] = task_info["result"]
|
||||
|
||||
# 如果任务失败,返回错误信息
|
||||
if status == "failed":
|
||||
response_data["error"] = task_info["error"]
|
||||
|
||||
return Response().ok(response_data).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"获取任务进度失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"获取任务进度失败: {e!s}").__dict__
|
||||
|
||||
async def download_backup(self):
|
||||
"""下载备份文件
|
||||
|
||||
Query 参数:
|
||||
- filename: 备份文件名 (必填)
|
||||
"""
|
||||
try:
|
||||
filename = request.args.get("filename")
|
||||
if not filename:
|
||||
return Response().error("缺少参数 filename").__dict__
|
||||
|
||||
# 安全检查 - 防止路径遍历
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
return Response().error("无效的文件名").__dict__
|
||||
|
||||
file_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(file_path):
|
||||
return Response().error("备份文件不存在").__dict__
|
||||
|
||||
return await send_file(
|
||||
file_path,
|
||||
as_attachment=True,
|
||||
attachment_filename=filename,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"下载备份失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"下载备份失败: {e!s}").__dict__
|
||||
|
||||
async def delete_backup(self):
|
||||
"""删除备份文件
|
||||
|
||||
Body:
|
||||
- filename: 备份文件名 (必填)
|
||||
"""
|
||||
try:
|
||||
data = await request.json
|
||||
filename = data.get("filename")
|
||||
if not filename:
|
||||
return Response().error("缺少参数 filename").__dict__
|
||||
|
||||
# 安全检查 - 防止路径遍历
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
return Response().error("无效的文件名").__dict__
|
||||
|
||||
file_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(file_path):
|
||||
return Response().error("备份文件不存在").__dict__
|
||||
|
||||
os.remove(file_path)
|
||||
return Response().ok(message="删除备份成功").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"删除备份失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"删除备份失败: {e!s}").__dict__
|
||||
@@ -1,15 +1,26 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import cast
|
||||
|
||||
from quart import Response as QuartResponse
|
||||
from quart import make_response
|
||||
from quart import make_response, request
|
||||
|
||||
from astrbot.core import LogBroker, logger
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
def _format_log_sse(log: dict, ts: float) -> str:
|
||||
"""辅助函数:格式化 SSE 消息"""
|
||||
payload = {
|
||||
"type": "log",
|
||||
**log,
|
||||
}
|
||||
return f"id: {ts}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
class LogRoute(Route):
|
||||
def __init__(self, context: RouteContext, log_broker: LogBroker) -> None:
|
||||
super().__init__(context)
|
||||
@@ -21,21 +32,44 @@ class LogRoute(Route):
|
||||
methods=["GET"],
|
||||
)
|
||||
|
||||
async def log(self):
|
||||
async def _replay_cached_logs(
|
||||
self, last_event_id: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""辅助生成器:重放缓存的日志"""
|
||||
try:
|
||||
last_ts = float(last_event_id)
|
||||
cached_logs = list(self.log_broker.log_cache)
|
||||
|
||||
for log_item in cached_logs:
|
||||
log_ts = float(log_item.get("time", 0))
|
||||
|
||||
if log_ts > last_ts:
|
||||
yield _format_log_sse(log_item, log_ts)
|
||||
|
||||
except ValueError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Log SSE 补发历史错误: {e}")
|
||||
|
||||
async def log(self) -> QuartResponse:
|
||||
last_event_id = request.headers.get("Last-Event-ID")
|
||||
|
||||
async def stream():
|
||||
queue = None
|
||||
try:
|
||||
if last_event_id:
|
||||
async for event in self._replay_cached_logs(last_event_id):
|
||||
yield event
|
||||
|
||||
queue = self.log_broker.register()
|
||||
while True:
|
||||
message = await queue.get()
|
||||
payload = {
|
||||
"type": "log",
|
||||
**message, # see astrbot/core/log.py
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
current_ts = message.get("time", time.time())
|
||||
yield _format_log_sse(message, current_ts)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except BaseException as e:
|
||||
except Exception as e:
|
||||
logger.error(f"Log SSE 连接错误: {e}")
|
||||
finally:
|
||||
if queue:
|
||||
@@ -53,7 +87,7 @@ class LogRoute(Route):
|
||||
},
|
||||
),
|
||||
)
|
||||
response.timeout = None
|
||||
response.timeout = None # type: ignore
|
||||
return response
|
||||
|
||||
async def log_history(self):
|
||||
@@ -69,6 +103,6 @@ class LogRoute(Route):
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except BaseException as e:
|
||||
except Exception as e:
|
||||
logger.error(f"获取日志历史失败: {e}")
|
||||
return Response().error(f"获取日志历史失败: {e}").__dict__
|
||||
|
||||
@@ -19,6 +19,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import get_local_ip_addresses
|
||||
|
||||
from .routes import *
|
||||
from .routes.backup import BackupRoute
|
||||
from .routes.platform import PlatformRoute
|
||||
from .routes.route import Response, RouteContext
|
||||
from .routes.session_management import SessionManagementRoute
|
||||
@@ -85,6 +86,7 @@ class AstrBotDashboard:
|
||||
self.t2i_route = T2iRoute(self.context, core_lifecycle)
|
||||
self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle)
|
||||
self.platform_route = PlatformRoute(self.context, core_lifecycle)
|
||||
self.backup_route = BackupRoute(self.context, db, core_lifecycle)
|
||||
|
||||
self.app.add_url_rule(
|
||||
"/api/plug/<path:subpath>",
|
||||
@@ -108,7 +110,12 @@ class AstrBotDashboard:
|
||||
async def auth_middleware(self):
|
||||
if not request.path.startswith("/api"):
|
||||
return None
|
||||
allowed_endpoints = ["/api/auth/login", "/api/file", "/api/platform/webhook"]
|
||||
allowed_endpoints = [
|
||||
"/api/auth/login",
|
||||
"/api/file",
|
||||
"/api/platform/webhook",
|
||||
"/api/stat/start-time",
|
||||
]
|
||||
if any(request.path.startswith(prefix) for prefix in allowed_endpoints):
|
||||
return None
|
||||
# 声明 JWT
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
## What's Changed
|
||||
|
||||
### 修复
|
||||
|
||||
1. 修复 FishAudio TTS 不可用的问题;
|
||||
2. 修复 Anthropic API Chat Provider 部分情况下请求报错的问题;
|
||||
3. 修复部分情况下 WebUI 日志重建连接之后丢失日志的问题;
|
||||
4. 修复部分情况下 /provider 指令报错 index out of range 的问题;
|
||||
5. 修复通过 `uv` 或者 cli 方式启动 AstrBot,缺少所有内置插件的问题。
|
||||
|
||||
### 优化
|
||||
|
||||
1. 丢弃值为 None 的 `tool_call_id` 和 `tool_calls` 字段,提高接口兼容性。
|
||||
|
||||
### 新增
|
||||
|
||||
1. 支持备份 AstrBot 数据和导入数据功能(Beta)。入口:WebUi -> 设置 -> 备份。
|
||||
2. text_chat 和 text_chat_stream 接口支持额外用户内容块参数 `extra_user_content_parts`,用于在用户消息后添加额外的内容块(如系统提醒、指令等)。
|
||||
@@ -22,6 +22,7 @@
|
||||
"axios-mock-adapter": "^1.22.0",
|
||||
"chance": "1.1.11",
|
||||
"date-fns": "2.30.0",
|
||||
"event-source-polyfill": "^1.0.31",
|
||||
"highlight.js": "^11.11.1",
|
||||
"js-md5": "^0.8.3",
|
||||
"katex": "^0.16.27",
|
||||
|
||||
@@ -0,0 +1,673 @@
|
||||
<template>
|
||||
<v-dialog v-model="isOpen" persistent max-width="700" scrollable>
|
||||
<v-card>
|
||||
<v-card-title class="d-flex align-center">
|
||||
<v-icon class="mr-2">mdi-backup-restore</v-icon>
|
||||
{{ t('features.settings.backup.dialog.title') }}
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text class="pa-6">
|
||||
<!-- 选项卡 -->
|
||||
<v-tabs v-model="activeTab" color="primary" class="mb-4">
|
||||
<v-tab value="export">
|
||||
<v-icon class="mr-2">mdi-export</v-icon>
|
||||
{{ t('features.settings.backup.tabs.export') }}
|
||||
</v-tab>
|
||||
<v-tab value="import">
|
||||
<v-icon class="mr-2">mdi-import</v-icon>
|
||||
{{ t('features.settings.backup.tabs.import') }}
|
||||
</v-tab>
|
||||
<v-tab value="list">
|
||||
<v-icon class="mr-2">mdi-format-list-bulleted</v-icon>
|
||||
{{ t('features.settings.backup.tabs.list') }}
|
||||
</v-tab>
|
||||
</v-tabs>
|
||||
|
||||
<v-window v-model="activeTab">
|
||||
<!-- 导出标签页 -->
|
||||
<v-window-item value="export">
|
||||
<div v-if="exportStatus === 'idle'" class="text-center py-8">
|
||||
<v-icon size="64" color="primary" class="mb-4">mdi-cloud-upload</v-icon>
|
||||
<h3 class="mb-4">{{ t('features.settings.backup.export.title') }}</h3>
|
||||
<p class="mb-4 text-grey">{{ t('features.settings.backup.export.description') }}</p>
|
||||
<v-alert type="info" variant="tonal" class="mb-4 text-left">
|
||||
<template v-slot:prepend>
|
||||
<v-icon>mdi-information</v-icon>
|
||||
</template>
|
||||
{{ t('features.settings.backup.export.includes') }}
|
||||
</v-alert>
|
||||
<v-btn color="primary" size="large" @click="startExport" :loading="exportStatus === 'processing'">
|
||||
<v-icon class="mr-2">mdi-export</v-icon>
|
||||
{{ t('features.settings.backup.export.button') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<div v-else-if="exportStatus === 'processing'" class="text-center py-8">
|
||||
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
|
||||
<h3 class="mb-4">{{ t('features.settings.backup.export.processing') }}</h3>
|
||||
<p class="text-grey">{{ exportProgress.message || t('features.settings.backup.export.wait') }}</p>
|
||||
<v-progress-linear :model-value="exportProgress.current" :max="exportProgress.total" class="mt-4" color="primary"></v-progress-linear>
|
||||
</div>
|
||||
|
||||
<div v-else-if="exportStatus === 'completed'" class="text-center py-8">
|
||||
<v-icon size="64" color="success" class="mb-4">mdi-check-circle</v-icon>
|
||||
<h3 class="mb-4">{{ t('features.settings.backup.export.completed') }}</h3>
|
||||
<p class="mb-4">{{ exportResult?.filename }}</p>
|
||||
<v-btn color="primary" @click="downloadBackup(exportResult?.filename)" class="mr-2">
|
||||
<v-icon class="mr-2">mdi-download</v-icon>
|
||||
{{ t('features.settings.backup.export.download') }}
|
||||
</v-btn>
|
||||
<v-btn color="grey" variant="text" @click="resetExport">
|
||||
{{ t('features.settings.backup.export.another') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<div v-else-if="exportStatus === 'failed'" class="text-center py-8">
|
||||
<v-icon size="64" color="error" class="mb-4">mdi-alert-circle</v-icon>
|
||||
<h3 class="mb-4">{{ t('features.settings.backup.export.failed') }}</h3>
|
||||
<v-alert type="error" variant="tonal" class="mb-4">
|
||||
{{ exportError }}
|
||||
</v-alert>
|
||||
<v-btn color="primary" @click="resetExport">
|
||||
{{ t('features.settings.backup.export.retry') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-window-item>
|
||||
|
||||
<!-- 导入标签页 -->
|
||||
<v-window-item value="import">
|
||||
<!-- 步骤1: 选择文件 -->
|
||||
<div v-if="importStatus === 'idle'" class="py-4">
|
||||
<v-alert type="warning" variant="tonal" class="mb-4">
|
||||
<template v-slot:prepend>
|
||||
<v-icon>mdi-alert</v-icon>
|
||||
</template>
|
||||
{{ t('features.settings.backup.import.warning') }}
|
||||
</v-alert>
|
||||
|
||||
<v-file-input
|
||||
v-model="importFile"
|
||||
:label="t('features.settings.backup.import.selectFile')"
|
||||
accept=".zip"
|
||||
prepend-icon="mdi-file-upload"
|
||||
show-size
|
||||
class="mb-4"
|
||||
></v-file-input>
|
||||
|
||||
<div class="d-flex justify-center">
|
||||
<v-btn
|
||||
color="primary"
|
||||
size="large"
|
||||
@click="uploadAndCheck"
|
||||
:disabled="!importFile"
|
||||
:loading="importStatus === 'uploading'"
|
||||
>
|
||||
<v-icon class="mr-2">mdi-upload</v-icon>
|
||||
{{ t('features.settings.backup.import.uploadAndCheck') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 步骤1.5: 上传中 -->
|
||||
<div v-else-if="importStatus === 'uploading'" class="text-center py-8">
|
||||
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
|
||||
<h3 class="mb-4">{{ t('features.settings.backup.import.uploading') }}</h3>
|
||||
<p class="text-grey">{{ t('features.settings.backup.import.uploadWait') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- 步骤2: 确认导入 -->
|
||||
<div v-else-if="importStatus === 'confirm'" class="py-4">
|
||||
<v-alert
|
||||
:type="versionAlertType"
|
||||
variant="tonal"
|
||||
class="mb-4"
|
||||
>
|
||||
<template v-slot:prepend>
|
||||
<v-icon>{{ versionAlertIcon }}</v-icon>
|
||||
</template>
|
||||
<div class="confirm-message">
|
||||
<div class="text-h6 mb-2">{{ versionAlertTitle }}</div>
|
||||
<div class="mb-2">
|
||||
<strong>{{ t('features.settings.backup.import.version.backupVersion') }}:</strong> {{ checkResult?.backup_version }}<br>
|
||||
<strong>{{ t('features.settings.backup.import.version.currentVersion') }}:</strong> {{ checkResult?.current_version }}
|
||||
</div>
|
||||
<div v-if="checkResult?.backup_time && checkResult?.backup_time !== '未知'" class="mb-2">
|
||||
<strong>{{ t('features.settings.backup.import.version.backupTime') }}:</strong> {{ formatISODate(checkResult?.backup_time) }}
|
||||
</div>
|
||||
<div class="mt-3" style="white-space: pre-line;">{{ versionAlertMessage }}</div>
|
||||
</div>
|
||||
</v-alert>
|
||||
|
||||
<!-- 备份摘要 -->
|
||||
<v-card variant="outlined" class="mb-4" v-if="checkResult?.backup_summary">
|
||||
<v-card-title class="text-subtitle-1">
|
||||
<v-icon class="mr-2">mdi-package-variant</v-icon>
|
||||
{{ t('features.settings.backup.import.backupContents') }}
|
||||
</v-card-title>
|
||||
<v-card-text>
|
||||
<div class="d-flex flex-wrap ga-2">
|
||||
<v-chip v-if="checkResult.backup_summary.tables?.length" size="small" color="primary" variant="tonal" :ripple="false" class="non-interactive-chip">
|
||||
{{ checkResult.backup_summary.tables.length }} {{ t('features.settings.backup.import.tables') }}
|
||||
</v-chip>
|
||||
<v-chip v-if="checkResult.backup_summary.has_knowledge_bases" size="small" color="success" variant="tonal" :ripple="false" class="non-interactive-chip">
|
||||
{{ t('features.settings.backup.import.knowledgeBases') }}
|
||||
</v-chip>
|
||||
<v-chip v-if="checkResult.backup_summary.has_config" size="small" color="info" variant="tonal" :ripple="false" class="non-interactive-chip">
|
||||
{{ t('features.settings.backup.import.configFiles') }}
|
||||
</v-chip>
|
||||
<v-chip v-for="dir in (checkResult.backup_summary.directories || [])" :key="dir" size="small" color="warning" variant="tonal" :ripple="false" class="non-interactive-chip">
|
||||
{{ dir }}
|
||||
</v-chip>
|
||||
</div>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
|
||||
<!-- 警告信息 -->
|
||||
<v-alert v-if="checkResult?.warnings?.length" type="warning" variant="tonal" class="mb-4">
|
||||
<div v-for="(warning, idx) in checkResult.warnings" :key="idx">{{ warning }}</div>
|
||||
</v-alert>
|
||||
|
||||
<div class="d-flex justify-center align-center mt-4" style="gap: 16px;">
|
||||
<v-btn
|
||||
color="grey-darken-1"
|
||||
variant="outlined"
|
||||
size="large"
|
||||
@click="resetImport"
|
||||
>
|
||||
<v-icon class="mr-2">mdi-close</v-icon>
|
||||
{{ t('core.common.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
v-if="checkResult?.can_import"
|
||||
color="error"
|
||||
size="large"
|
||||
variant="flat"
|
||||
@click="confirmImport"
|
||||
>
|
||||
<v-icon class="mr-2">mdi-alert</v-icon>
|
||||
{{ t('features.settings.backup.import.confirmImport') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 步骤3: 导入进行中 -->
|
||||
<div v-else-if="importStatus === 'processing'" class="text-center py-8">
|
||||
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
|
||||
<h3 class="mb-4">{{ t('features.settings.backup.import.processing') }}</h3>
|
||||
<p class="text-grey">{{ importProgress.message || t('features.settings.backup.import.wait') }}</p>
|
||||
<v-progress-linear :model-value="importProgress.current" :max="importProgress.total" class="mt-4" color="primary"></v-progress-linear>
|
||||
</div>
|
||||
|
||||
<div v-else-if="importStatus === 'completed'" class="text-center py-8">
|
||||
<v-icon size="64" color="success" class="mb-4">mdi-check-circle</v-icon>
|
||||
<h3 class="mb-4">{{ t('features.settings.backup.import.completed') }}</h3>
|
||||
<v-alert type="info" variant="tonal" class="mb-4">
|
||||
{{ t('features.settings.backup.import.restartRequired') }}
|
||||
</v-alert>
|
||||
<v-btn color="primary" @click="restartAstrBot" class="mr-2">
|
||||
<v-icon class="mr-2">mdi-restart</v-icon>
|
||||
{{ t('features.settings.backup.import.restartNow') }}
|
||||
</v-btn>
|
||||
<v-btn color="grey" variant="text" @click="resetImport">
|
||||
{{ t('core.common.close') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<div v-else-if="importStatus === 'failed'" class="text-center py-8">
|
||||
<v-icon size="64" color="error" class="mb-4">mdi-alert-circle</v-icon>
|
||||
<h3 class="mb-4">{{ t('features.settings.backup.import.failed') }}</h3>
|
||||
<v-alert type="error" variant="tonal" class="mb-4">
|
||||
{{ importError }}
|
||||
</v-alert>
|
||||
<v-btn color="primary" @click="resetImport">
|
||||
{{ t('features.settings.backup.import.retry') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-window-item>
|
||||
|
||||
<!-- 备份列表标签页 -->
|
||||
<v-window-item value="list">
|
||||
<div v-if="loadingList" class="text-center py-8">
|
||||
<v-progress-circular indeterminate color="primary"></v-progress-circular>
|
||||
</div>
|
||||
|
||||
<div v-else-if="backupList.length === 0" class="text-center py-8">
|
||||
<v-icon size="64" color="grey" class="mb-4">mdi-folder-open-outline</v-icon>
|
||||
<p class="text-grey">{{ t('features.settings.backup.list.empty') }}</p>
|
||||
</div>
|
||||
|
||||
<v-list v-else lines="two">
|
||||
<v-list-item
|
||||
v-for="backup in backupList"
|
||||
:key="backup.filename"
|
||||
>
|
||||
<template v-slot:prepend>
|
||||
<v-icon color="primary">mdi-zip-box</v-icon>
|
||||
</template>
|
||||
|
||||
<v-list-item-title>{{ backup.filename }}</v-list-item-title>
|
||||
<v-list-item-subtitle>
|
||||
{{ formatFileSize(backup.size) }} · {{ formatDate(backup.created_at) }}
|
||||
</v-list-item-subtitle>
|
||||
|
||||
<template v-slot:append>
|
||||
<v-btn icon="mdi-download" variant="text" size="small" @click="downloadBackup(backup.filename)"></v-btn>
|
||||
<v-btn icon="mdi-delete" variant="text" size="small" color="error" @click="deleteBackup(backup.filename)"></v-btn>
|
||||
</template>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
|
||||
<div class="d-flex justify-center mt-4">
|
||||
<v-btn color="primary" variant="text" @click="loadBackupList">
|
||||
<v-icon class="mr-2">mdi-refresh</v-icon>
|
||||
{{ t('features.settings.backup.list.refresh') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-window-item>
|
||||
</v-window>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions class="px-6 py-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="grey" variant="text" @click="handleClose" :disabled="isProcessing">
|
||||
{{ t('core.common.close') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<WaitingForRestart ref="wfr"></WaitingForRestart>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ref, computed, watch } from 'vue'
|
||||
import axios from 'axios'
|
||||
import { useI18n } from '@/i18n/composables'
|
||||
import WaitingForRestart from './WaitingForRestart.vue'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const isOpen = ref(false)
|
||||
const activeTab = ref('export')
|
||||
const wfr = ref(null)
|
||||
|
||||
// 导出状态
|
||||
const exportStatus = ref('idle') // idle, processing, completed, failed
|
||||
const exportTaskId = ref(null)
|
||||
const exportProgress = ref({ current: 0, total: 100, message: '' })
|
||||
const exportResult = ref(null)
|
||||
const exportError = ref('')
|
||||
|
||||
// 导入状态
|
||||
const importStatus = ref('idle') // idle, uploading, confirm, processing, completed, failed
|
||||
const importFile = ref(null)
|
||||
const importTaskId = ref(null)
|
||||
const importProgress = ref({ current: 0, total: 100, message: '' })
|
||||
const importError = ref('')
|
||||
const uploadedFilename = ref('') // 已上传的文件名
|
||||
const checkResult = ref(null) // 预检查结果
|
||||
|
||||
// 备份列表
|
||||
const loadingList = ref(false)
|
||||
const backupList = ref([])
|
||||
|
||||
// 计算属性
|
||||
const isProcessing = computed(() => {
|
||||
return exportStatus.value === 'processing' || importStatus.value === 'processing'
|
||||
})
|
||||
|
||||
// 版本检查相关的计算属性
|
||||
const versionAlertType = computed(() => {
|
||||
const status = checkResult.value?.version_status
|
||||
if (status === 'major_diff') return 'error'
|
||||
if (status === 'minor_diff') return 'warning'
|
||||
return 'info'
|
||||
})
|
||||
|
||||
const versionAlertIcon = computed(() => {
|
||||
const status = checkResult.value?.version_status
|
||||
if (status === 'major_diff') return 'mdi-close-circle'
|
||||
if (status === 'minor_diff') return 'mdi-alert'
|
||||
return 'mdi-check-circle'
|
||||
})
|
||||
|
||||
const versionAlertTitle = computed(() => {
|
||||
const status = checkResult.value?.version_status
|
||||
if (status === 'major_diff') return t('features.settings.backup.import.version.majorDiffTitle')
|
||||
if (status === 'minor_diff') return t('features.settings.backup.import.version.minorDiffTitle')
|
||||
return t('features.settings.backup.import.version.matchTitle')
|
||||
})
|
||||
|
||||
const versionAlertMessage = computed(() => {
|
||||
const status = checkResult.value?.version_status
|
||||
if (status === 'major_diff') return t('features.settings.backup.import.version.majorDiffMessage')
|
||||
if (status === 'minor_diff') return t('features.settings.backup.import.version.minorDiffMessage')
|
||||
return t('features.settings.backup.import.version.matchMessage')
|
||||
})
|
||||
|
||||
// 监听对话框打开
|
||||
watch(isOpen, (newVal) => {
|
||||
if (newVal) {
|
||||
loadBackupList()
|
||||
} else {
|
||||
resetAll()
|
||||
}
|
||||
})
|
||||
|
||||
// 监听标签页切换
|
||||
watch(activeTab, (newVal) => {
|
||||
if (newVal === 'list') {
|
||||
loadBackupList()
|
||||
}
|
||||
})
|
||||
|
||||
// 加载备份列表
|
||||
const loadBackupList = async () => {
|
||||
loadingList.value = true
|
||||
try {
|
||||
const response = await axios.get('/api/backup/list')
|
||||
if (response.data.status === 'ok') {
|
||||
backupList.value = response.data.data.items || []
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load backup list:', error)
|
||||
} finally {
|
||||
loadingList.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 开始导出
|
||||
const startExport = async () => {
|
||||
exportStatus.value = 'processing'
|
||||
exportProgress.value = { current: 0, total: 100, message: '' }
|
||||
|
||||
try {
|
||||
const response = await axios.post('/api/backup/export')
|
||||
if (response.data.status === 'ok') {
|
||||
exportTaskId.value = response.data.data.task_id
|
||||
pollExportProgress()
|
||||
} else {
|
||||
throw new Error(response.data.message)
|
||||
}
|
||||
} catch (error) {
|
||||
exportStatus.value = 'failed'
|
||||
exportError.value = error.message || 'Export failed'
|
||||
}
|
||||
}
|
||||
|
||||
// 轮询导出进度
|
||||
const pollExportProgress = async () => {
|
||||
if (!exportTaskId.value) return
|
||||
|
||||
try {
|
||||
const response = await axios.get('/api/backup/progress', {
|
||||
params: { task_id: exportTaskId.value }
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
const data = response.data.data
|
||||
|
||||
if (data.status === 'processing' && data.progress) {
|
||||
exportProgress.value = {
|
||||
current: data.progress.current || 0,
|
||||
total: data.progress.total || 100,
|
||||
message: data.progress.message || ''
|
||||
}
|
||||
setTimeout(pollExportProgress, 1000)
|
||||
} else if (data.status === 'completed') {
|
||||
exportStatus.value = 'completed'
|
||||
exportResult.value = data.result
|
||||
loadBackupList()
|
||||
} else if (data.status === 'failed') {
|
||||
exportStatus.value = 'failed'
|
||||
exportError.value = data.error || 'Export failed'
|
||||
} else {
|
||||
setTimeout(pollExportProgress, 1000)
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
exportStatus.value = 'failed'
|
||||
exportError.value = error.message || 'Failed to get export progress'
|
||||
}
|
||||
}
|
||||
|
||||
// 重置导出状态
|
||||
const resetExport = () => {
|
||||
exportStatus.value = 'idle'
|
||||
exportTaskId.value = null
|
||||
exportProgress.value = { current: 0, total: 100, message: '' }
|
||||
exportResult.value = null
|
||||
exportError.value = ''
|
||||
}
|
||||
|
||||
// 上传并检查
|
||||
const uploadAndCheck = async () => {
|
||||
if (!importFile.value) return
|
||||
|
||||
importStatus.value = 'uploading'
|
||||
|
||||
try {
|
||||
// 步骤1: 上传文件
|
||||
const formData = new FormData()
|
||||
formData.append('file', importFile.value)
|
||||
|
||||
const uploadResponse = await axios.post('/api/backup/upload', formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' }
|
||||
})
|
||||
|
||||
if (uploadResponse.data.status !== 'ok') {
|
||||
throw new Error(uploadResponse.data.message)
|
||||
}
|
||||
|
||||
uploadedFilename.value = uploadResponse.data.data.filename
|
||||
|
||||
// 步骤2: 预检查
|
||||
const checkResponse = await axios.post('/api/backup/check', {
|
||||
filename: uploadedFilename.value
|
||||
})
|
||||
|
||||
if (checkResponse.data.status !== 'ok') {
|
||||
throw new Error(checkResponse.data.message)
|
||||
}
|
||||
|
||||
checkResult.value = checkResponse.data.data
|
||||
|
||||
// 检查是否有效
|
||||
if (!checkResult.value.valid) {
|
||||
importStatus.value = 'failed'
|
||||
importError.value = checkResult.value.error || t('features.settings.backup.import.invalidBackup')
|
||||
return
|
||||
}
|
||||
|
||||
// 显示确认对话框
|
||||
importStatus.value = 'confirm'
|
||||
|
||||
} catch (error) {
|
||||
importStatus.value = 'failed'
|
||||
importError.value = error.response?.data?.message || error.message || 'Upload failed'
|
||||
}
|
||||
}
|
||||
|
||||
// 确认导入
|
||||
const confirmImport = async () => {
|
||||
if (!uploadedFilename.value) return
|
||||
|
||||
importStatus.value = 'processing'
|
||||
importProgress.value = { current: 0, total: 100, message: '' }
|
||||
|
||||
try {
|
||||
const response = await axios.post('/api/backup/import', {
|
||||
filename: uploadedFilename.value,
|
||||
confirmed: true
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
importTaskId.value = response.data.data.task_id
|
||||
pollImportProgress()
|
||||
} else {
|
||||
throw new Error(response.data.message)
|
||||
}
|
||||
} catch (error) {
|
||||
importStatus.value = 'failed'
|
||||
importError.value = error.response?.data?.message || error.message || 'Import failed'
|
||||
}
|
||||
}
|
||||
|
||||
// 轮询导入进度
|
||||
const pollImportProgress = async () => {
|
||||
if (!importTaskId.value) return
|
||||
|
||||
try {
|
||||
const response = await axios.get('/api/backup/progress', {
|
||||
params: { task_id: importTaskId.value }
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
const data = response.data.data
|
||||
|
||||
if (data.status === 'processing' && data.progress) {
|
||||
importProgress.value = {
|
||||
current: data.progress.current || 0,
|
||||
total: data.progress.total || 100,
|
||||
message: data.progress.message || ''
|
||||
}
|
||||
setTimeout(pollImportProgress, 1000)
|
||||
} else if (data.status === 'completed') {
|
||||
importStatus.value = 'completed'
|
||||
} else if (data.status === 'failed') {
|
||||
importStatus.value = 'failed'
|
||||
importError.value = data.error || 'Import failed'
|
||||
} else {
|
||||
setTimeout(pollImportProgress, 1000)
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
importStatus.value = 'failed'
|
||||
importError.value = error.message || 'Failed to get import progress'
|
||||
}
|
||||
}
|
||||
|
||||
// 重置导入状态
|
||||
const resetImport = () => {
|
||||
importStatus.value = 'idle'
|
||||
importFile.value = null
|
||||
importTaskId.value = null
|
||||
importProgress.value = { current: 0, total: 100, message: '' }
|
||||
importError.value = ''
|
||||
uploadedFilename.value = ''
|
||||
checkResult.value = null
|
||||
}
|
||||
|
||||
// 下载备份
|
||||
const downloadBackup = async (filename) => {
|
||||
try {
|
||||
const response = await axios.get('/api/backup/download', {
|
||||
params: { filename },
|
||||
responseType: 'blob'
|
||||
})
|
||||
|
||||
// 创建 Blob URL 并触发下载
|
||||
const blob = new Blob([response.data], { type: 'application/zip' })
|
||||
const url = window.URL.createObjectURL(blob)
|
||||
const link = document.createElement('a')
|
||||
link.href = url
|
||||
link.download = filename
|
||||
document.body.appendChild(link)
|
||||
link.click()
|
||||
document.body.removeChild(link)
|
||||
window.URL.revokeObjectURL(url)
|
||||
} catch (error) {
|
||||
console.error('Download failed:', error)
|
||||
alert(t('features.settings.backup.export.failed') + ': ' + (error.message || 'Unknown error'))
|
||||
}
|
||||
}
|
||||
|
||||
// 删除备份
|
||||
const deleteBackup = async (filename) => {
|
||||
if (!confirm(t('features.settings.backup.list.confirmDelete'))) return
|
||||
|
||||
try {
|
||||
const response = await axios.post('/api/backup/delete', { filename })
|
||||
if (response.data.status === 'ok') {
|
||||
loadBackupList()
|
||||
} else {
|
||||
alert(response.data.message || 'Delete failed')
|
||||
}
|
||||
} catch (error) {
|
||||
alert(error.message || 'Delete failed')
|
||||
}
|
||||
}
|
||||
|
||||
// 格式化文件大小
|
||||
const formatFileSize = (bytes) => {
|
||||
if (bytes === 0) return '0 B'
|
||||
const k = 1024
|
||||
const sizes = ['B', 'KB', 'MB', 'GB']
|
||||
const i = Math.floor(Math.log(bytes) / Math.log(k))
|
||||
return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]
|
||||
}
|
||||
|
||||
// 格式化日期(从时间戳)
|
||||
const formatDate = (timestamp) => {
|
||||
return new Date(timestamp * 1000).toLocaleString()
|
||||
}
|
||||
|
||||
// 格式化 ISO 日期字符串
|
||||
const formatISODate = (isoString) => {
|
||||
if (!isoString) return ''
|
||||
try {
|
||||
return new Date(isoString).toLocaleString()
|
||||
} catch {
|
||||
return isoString
|
||||
}
|
||||
}
|
||||
|
||||
// 重启 AstrBot
|
||||
const restartAstrBot = () => {
|
||||
axios.post('/api/stat/restart-core').then(() => {
|
||||
if (wfr.value) {
|
||||
wfr.value.check()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 重置所有状态
|
||||
const resetAll = () => {
|
||||
resetExport()
|
||||
resetImport()
|
||||
activeTab.value = 'export'
|
||||
}
|
||||
|
||||
// 关闭对话框
|
||||
const handleClose = () => {
|
||||
if (isProcessing.value) return
|
||||
isOpen.value = false
|
||||
}
|
||||
|
||||
// 打开对话框
|
||||
const open = () => {
|
||||
isOpen.value = true
|
||||
}
|
||||
|
||||
defineExpose({ open })
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.v-list-item {
|
||||
border-bottom: 1px solid rgba(0, 0, 0, 0.08);
|
||||
}
|
||||
|
||||
.v-list-item:last-child {
|
||||
border-bottom: none;
|
||||
}
|
||||
|
||||
/* 禁用 Chip 的交互效果 */
|
||||
.non-interactive-chip {
|
||||
pointer-events: none;
|
||||
cursor: default;
|
||||
}
|
||||
|
||||
.non-interactive-chip:hover {
|
||||
box-shadow: none !important;
|
||||
}
|
||||
</style>
|
||||
@@ -1,12 +1,11 @@
|
||||
<script setup>
|
||||
import { useCommonStore } from '@/stores/common';
|
||||
import { storeToRefs } from 'pinia';
|
||||
import axios from 'axios';
|
||||
import { EventSourcePolyfill } from 'event-source-polyfill';
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div>
|
||||
<!-- 添加筛选级别控件 -->
|
||||
<div class="filter-controls mb-2" v-if="showLevelBtns">
|
||||
<v-chip-group v-model="selectedLevels" column multiple>
|
||||
<v-chip v-for="level in logLevels" :key="level" :color="getLevelColor(level)" filter variant="flat" size="small"
|
||||
@@ -26,20 +25,19 @@ export default {
|
||||
name: 'ConsoleDisplayer',
|
||||
data() {
|
||||
return {
|
||||
autoScroll: true, // 默认开启自动滚动
|
||||
autoScroll: true,
|
||||
logColorAnsiMap: {
|
||||
'\u001b[1;34m': 'color: #0000FF; font-weight: bold;', // bold_blue
|
||||
'\u001b[1;36m': 'color: #00FFFF; font-weight: bold;', // bold_cyan
|
||||
'\u001b[1;33m': 'color: #FFFF00; font-weight: bold;', // bold_yellow
|
||||
'\u001b[31m': 'color: #FF0000;', // red
|
||||
'\u001b[1;31m': 'color: #FF0000; font-weight: bold;', // bold_red
|
||||
'\u001b[0m': 'color: inherit; font-weight: normal;', // reset
|
||||
'\u001b[32m': 'color: #00FF00;', // green
|
||||
'\u001b[1;34m': 'color: #0000FF; font-weight: bold;',
|
||||
'\u001b[1;36m': 'color: #00FFFF; font-weight: bold;',
|
||||
'\u001b[1;33m': 'color: #FFFF00; font-weight: bold;',
|
||||
'\u001b[31m': 'color: #FF0000;',
|
||||
'\u001b[1;31m': 'color: #FF0000; font-weight: bold;',
|
||||
'\u001b[0m': 'color: inherit; font-weight: normal;',
|
||||
'\u001b[32m': 'color: #00FF00;',
|
||||
'default': 'color: #FFFFFF;'
|
||||
},
|
||||
historyNum_: -1,
|
||||
logLevels: ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
|
||||
selectedLevels: [0, 1, 2, 3, 4], // 默认选中所有级别
|
||||
selectedLevels: [0, 1, 2, 3, 4],
|
||||
levelColors: {
|
||||
'DEBUG': 'grey',
|
||||
'INFO': 'blue-lighten-3',
|
||||
@@ -47,17 +45,19 @@ export default {
|
||||
'ERROR': 'red',
|
||||
'CRITICAL': 'purple'
|
||||
},
|
||||
lastProcessedTime: 0, // 记录最后处理的日志时间戳
|
||||
localLogCache: [], // 本地日志缓存
|
||||
localLogCache: [],
|
||||
eventSource: null,
|
||||
retryTimer: null,
|
||||
retryAttempts: 0,
|
||||
maxRetryAttempts: 10,
|
||||
baseRetryDelay: 1000,
|
||||
lastEventId: null,
|
||||
}
|
||||
},
|
||||
computed: {
|
||||
commonStore() {
|
||||
return useCommonStore();
|
||||
},
|
||||
logCache() {
|
||||
return this.commonStore.log_cache;
|
||||
}
|
||||
},
|
||||
props: {
|
||||
historyNum: {
|
||||
@@ -70,41 +70,6 @@ export default {
|
||||
}
|
||||
},
|
||||
watch: {
|
||||
logCache: {
|
||||
handler(newVal) {
|
||||
// 基于 timestamp 处理新增的日志
|
||||
if (newVal && newVal.length > 0) {
|
||||
// 确保 DOM 已经准备好
|
||||
this.$nextTick(() => {
|
||||
// 合并到本地缓存并按时间排序
|
||||
const newLogs = newVal.filter(log => log.time > this.lastProcessedTime);
|
||||
|
||||
if (newLogs.length > 0) {
|
||||
this.localLogCache.push(...newLogs);
|
||||
// 按时间戳排序
|
||||
this.localLogCache.sort((a, b) => a.time - b.time);
|
||||
|
||||
// 只保留最新的 log_cache_max_len 条
|
||||
if (this.localLogCache.length > this.commonStore.log_cache_max_len) {
|
||||
this.localLogCache.splice(0, this.localLogCache.length - this.commonStore.log_cache_max_len);
|
||||
}
|
||||
|
||||
// 显示新日志
|
||||
newLogs.forEach(logItem => {
|
||||
if (this.isLevelSelected(logItem.level)) {
|
||||
this.printLog(logItem.data);
|
||||
}
|
||||
});
|
||||
|
||||
// 更新最后处理时间
|
||||
this.lastProcessedTime = Math.max(...newLogs.map(log => log.time));
|
||||
}
|
||||
});
|
||||
}
|
||||
},
|
||||
deep: true,
|
||||
immediate: false
|
||||
},
|
||||
selectedLevels: {
|
||||
handler() {
|
||||
this.refreshDisplay();
|
||||
@@ -113,30 +78,142 @@ export default {
|
||||
}
|
||||
},
|
||||
async mounted() {
|
||||
// 请求历史日志
|
||||
await this.fetchLogHistory();
|
||||
|
||||
// 等待 DOM 准备好后,显示历史日志
|
||||
this.$nextTick(() => {
|
||||
if (this.localLogCache.length > 0) {
|
||||
this.localLogCache.forEach(logItem => {
|
||||
if (this.isLevelSelected(logItem.level)) {
|
||||
this.printLog(logItem.data);
|
||||
}
|
||||
});
|
||||
// 更新最后处理时间
|
||||
this.lastProcessedTime = Math.max(...this.localLogCache.map(log => log.time));
|
||||
}
|
||||
});
|
||||
this.connectSSE();
|
||||
},
|
||||
beforeUnmount() {
|
||||
if (this.eventSource) {
|
||||
this.eventSource.close();
|
||||
this.eventSource = null;
|
||||
}
|
||||
if (this.retryTimer) {
|
||||
clearTimeout(this.retryTimer);
|
||||
this.retryTimer = null;
|
||||
}
|
||||
this.retryAttempts = 0;
|
||||
},
|
||||
methods: {
|
||||
connectSSE() {
|
||||
if (this.eventSource) {
|
||||
this.eventSource.close();
|
||||
this.eventSource = null;
|
||||
}
|
||||
|
||||
console.log(`正在连接日志流... (尝试次数: ${this.retryAttempts})`);
|
||||
|
||||
const token = localStorage.getItem('token');
|
||||
|
||||
this.eventSource = new EventSourcePolyfill('/api/live-log', {
|
||||
headers: {
|
||||
'Authorization': token ? `Bearer ${token}` : ''
|
||||
},
|
||||
heartbeatTimeout: 300000,
|
||||
withCredentials: true
|
||||
});
|
||||
|
||||
this.eventSource.onopen = () => {
|
||||
console.log('日志流连接成功!');
|
||||
this.retryAttempts = 0;
|
||||
|
||||
if (!this.lastEventId) {
|
||||
this.fetchLogHistory();
|
||||
}
|
||||
};
|
||||
|
||||
this.eventSource.onmessage = (event) => {
|
||||
try {
|
||||
if (event.lastEventId) {
|
||||
this.lastEventId = event.lastEventId;
|
||||
}
|
||||
|
||||
const payload = JSON.parse(event.data);
|
||||
this.processNewLogs([payload]);
|
||||
} catch (e) {
|
||||
console.error('解析日志失败:', e);
|
||||
}
|
||||
};
|
||||
|
||||
this.eventSource.onerror = (err) => {
|
||||
|
||||
if (err.status === 401) {
|
||||
console.error('鉴权失败 (401),可能是 Token 过期了。');
|
||||
|
||||
} else {
|
||||
console.warn('日志流连接错误:', err);
|
||||
}
|
||||
|
||||
if (this.eventSource) {
|
||||
this.eventSource.close();
|
||||
this.eventSource = null;
|
||||
}
|
||||
|
||||
if (this.retryAttempts >= this.maxRetryAttempts) {
|
||||
console.error('❌ 已达到最大重试次数,停止重连。请刷新页面重试。');
|
||||
return;
|
||||
}
|
||||
|
||||
const delay = Math.min(
|
||||
this.baseRetryDelay * Math.pow(2, this.retryAttempts),
|
||||
30000
|
||||
);
|
||||
|
||||
console.log(`⏳ ${delay}ms 后尝试第 ${this.retryAttempts + 1} 次重连...`);
|
||||
|
||||
if (this.retryTimer) {
|
||||
clearTimeout(this.retryTimer);
|
||||
this.retryTimer = null;
|
||||
}
|
||||
|
||||
this.retryTimer = setTimeout(async () => {
|
||||
this.retryAttempts++;
|
||||
|
||||
if (!this.lastEventId) {
|
||||
await this.fetchLogHistory();
|
||||
}
|
||||
|
||||
this.connectSSE();
|
||||
}, delay);
|
||||
};
|
||||
},
|
||||
|
||||
processNewLogs(newLogs) {
|
||||
if (!newLogs || newLogs.length === 0) return;
|
||||
|
||||
let hasUpdate = false;
|
||||
|
||||
newLogs.forEach(log => {
|
||||
|
||||
const exists = this.localLogCache.some(existing =>
|
||||
existing.time === log.time &&
|
||||
existing.data === log.data &&
|
||||
existing.level === log.level
|
||||
);
|
||||
|
||||
if (!exists) {
|
||||
this.localLogCache.push(log);
|
||||
hasUpdate = true;
|
||||
|
||||
if (this.isLevelSelected(log.level)) {
|
||||
this.printLog(log.data);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (hasUpdate) {
|
||||
this.localLogCache.sort((a, b) => a.time - b.time);
|
||||
|
||||
const maxSize = this.commonStore.log_cache_max_len || 200;
|
||||
if (this.localLogCache.length > maxSize) {
|
||||
this.localLogCache.splice(0, this.localLogCache.length - maxSize);
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
async fetchLogHistory() {
|
||||
try {
|
||||
const res = await axios.get('/api/log-history');
|
||||
if (res.data.data.logs && res.data.data.logs.length > 0) {
|
||||
this.localLogCache = [...res.data.data.logs];
|
||||
// 按时间戳排序
|
||||
this.localLogCache.sort((a, b) => a.time - b.time);
|
||||
this.processNewLogs(res.data.data.logs);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to fetch log history:', err);
|
||||
@@ -162,7 +239,6 @@ export default {
|
||||
if (termElement) {
|
||||
termElement.innerHTML = '';
|
||||
|
||||
// 重新显示所有符合筛选条件的日志
|
||||
if (this.localLogCache && this.localLogCache.length > 0) {
|
||||
this.localLogCache.forEach(logItem => {
|
||||
if (this.isLevelSelected(logItem.level)) {
|
||||
@@ -173,16 +249,13 @@ export default {
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
toggleAutoScroll() {
|
||||
this.autoScroll = !this.autoScroll;
|
||||
},
|
||||
|
||||
printLog(log) {
|
||||
// append 一个 span 标签到 term,block 的方式
|
||||
let ele = document.getElementById('term')
|
||||
if (!ele) {
|
||||
console.warn('term element not found, skipping log print');
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -196,11 +269,11 @@ export default {
|
||||
}
|
||||
}
|
||||
|
||||
span.style = style + 'display: block; font-size: 12px; font-family: Consolas, monospace; white-space: pre-wrap;'
|
||||
span.style = style + 'display: block; font-size: 12px; font-family: Consolas, monospace; white-space: pre-wrap; margin-bottom: 2px;'
|
||||
span.classList.add('fade-in')
|
||||
span.innerText = `${log}`;
|
||||
ele.appendChild(span)
|
||||
if (this.autoScroll ) {
|
||||
if (this.autoScroll) {
|
||||
ele.scrollTop = ele.scrollHeight
|
||||
}
|
||||
}
|
||||
@@ -230,4 +303,4 @@ export default {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</style>
|
||||
|
||||
@@ -18,6 +18,11 @@
|
||||
"title": "Data Migration to v4.0.0",
|
||||
"subtitle": "If you encounter data compatibility issues, you can manually start the database migration assistant",
|
||||
"button": "Start Migration Assistant"
|
||||
},
|
||||
"backup": {
|
||||
"title": "Backup & Restore",
|
||||
"subtitle": "Export or import all AstrBot data for easy migration to a new server",
|
||||
"button": "Backup Manager"
|
||||
}
|
||||
},
|
||||
"sidebar": {
|
||||
@@ -29,5 +34,66 @@
|
||||
"mainItems": "Main Modules",
|
||||
"moreItems": "More Features"
|
||||
}
|
||||
},
|
||||
"backup": {
|
||||
"dialog": {
|
||||
"title": "Backup Manager"
|
||||
},
|
||||
"tabs": {
|
||||
"export": "Export Backup",
|
||||
"import": "Import Backup",
|
||||
"list": "Backup List"
|
||||
},
|
||||
"export": {
|
||||
"title": "Create Backup",
|
||||
"description": "Export all data as a ZIP backup file, including database, knowledge base, config and attachments.",
|
||||
"includes": "Backup includes: Main database, Knowledge bases (metadata + vector index + documents), Config files, Attachment files",
|
||||
"button": "Start Export",
|
||||
"processing": "Exporting...",
|
||||
"wait": "Please wait, packaging data...",
|
||||
"completed": "Export Completed!",
|
||||
"download": "Download Backup",
|
||||
"another": "Create New Backup",
|
||||
"failed": "Export Failed",
|
||||
"retry": "Retry"
|
||||
},
|
||||
"import": {
|
||||
"title": "Import Backup",
|
||||
"warning": "⚠️ Import will clear and overwrite existing data! Please make sure you have backed up your current data.",
|
||||
"selectFile": "Select backup file (.zip)",
|
||||
"uploadAndCheck": "Upload & Check",
|
||||
"uploading": "Uploading...",
|
||||
"uploadWait": "Please wait, uploading backup file...",
|
||||
"invalidBackup": "Invalid backup file",
|
||||
"backupContents": "Backup Contents",
|
||||
"tables": "tables",
|
||||
"knowledgeBases": "Knowledge Bases",
|
||||
"configFiles": "Config Files",
|
||||
"confirmImport": "Confirm Import",
|
||||
"button": "Start Import",
|
||||
"processing": "Importing...",
|
||||
"wait": "Please wait, restoring data...",
|
||||
"completed": "Import Completed!",
|
||||
"restartRequired": "Data has been successfully imported. It is recommended to restart AstrBot immediately for all changes to take effect.",
|
||||
"restartNow": "Restart Now",
|
||||
"failed": "Import Failed",
|
||||
"retry": "Retry",
|
||||
"version": {
|
||||
"backupVersion": "Backup Version",
|
||||
"currentVersion": "Current Version",
|
||||
"backupTime": "Backup Time",
|
||||
"matchTitle": "✅ Version Match",
|
||||
"matchMessage": "Import will clear and overwrite all existing data, including:\n• Main database (conversations, settings, etc.)\n• Knowledge bases\n• Plugins and plugin data\n• Configuration files\n\nThis action cannot be undone! Do you want to continue?",
|
||||
"minorDiffTitle": "⚠️ Version Difference Warning",
|
||||
"minorDiffMessage": "Minor version differences are usually compatible, but there may be some data structure changes.\nImport will clear and overwrite all existing data!\n\nDo you want to continue?",
|
||||
"majorDiffTitle": "⛔ Cannot Import",
|
||||
"majorDiffMessage": "Major version numbers are different. Cross-major-version import may cause data corruption.\nPlease use the same major version of AstrBot for import."
|
||||
}
|
||||
},
|
||||
"list": {
|
||||
"empty": "No backup files",
|
||||
"refresh": "Refresh List",
|
||||
"confirmDelete": "Are you sure you want to delete this backup file? This action cannot be undone."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,11 @@
|
||||
"title": "数据迁移到 v4.0.0 格式",
|
||||
"subtitle": "如果您遇到数据兼容性问题,可以手动启动数据库迁移助手",
|
||||
"button": "启动迁移助手"
|
||||
},
|
||||
"backup": {
|
||||
"title": "数据备份与恢复",
|
||||
"subtitle": "导出或导入 AstrBot 的所有数据,方便迁移到新服务器",
|
||||
"button": "备份管理"
|
||||
}
|
||||
},
|
||||
"sidebar": {
|
||||
@@ -29,5 +34,66 @@
|
||||
"mainItems": "主要模块",
|
||||
"moreItems": "更多功能"
|
||||
}
|
||||
},
|
||||
"backup": {
|
||||
"dialog": {
|
||||
"title": "备份管理"
|
||||
},
|
||||
"tabs": {
|
||||
"export": "导出备份",
|
||||
"import": "导入备份",
|
||||
"list": "备份列表"
|
||||
},
|
||||
"export": {
|
||||
"title": "创建备份",
|
||||
"description": "将所有数据导出为 ZIP 备份文件,包括数据库、知识库、配置和附件。",
|
||||
"includes": "备份包含:主数据库、知识库(元数据+向量索引+文档)、配置文件、附件文件",
|
||||
"button": "开始导出",
|
||||
"processing": "正在导出...",
|
||||
"wait": "请稍候,正在打包数据...",
|
||||
"completed": "导出完成!",
|
||||
"download": "下载备份",
|
||||
"another": "创建新备份",
|
||||
"failed": "导出失败",
|
||||
"retry": "重试"
|
||||
},
|
||||
"import": {
|
||||
"title": "导入备份",
|
||||
"warning": "⚠️ 导入将会清空并覆盖现有数据!请确保已备份当前数据。",
|
||||
"selectFile": "选择备份文件 (.zip)",
|
||||
"uploadAndCheck": "上传并检查",
|
||||
"uploading": "正在上传...",
|
||||
"uploadWait": "请稍候,正在上传备份文件...",
|
||||
"invalidBackup": "无效的备份文件",
|
||||
"backupContents": "备份内容",
|
||||
"tables": "个数据表",
|
||||
"knowledgeBases": "知识库",
|
||||
"configFiles": "配置文件",
|
||||
"confirmImport": "确认导入",
|
||||
"button": "开始导入",
|
||||
"processing": "正在导入...",
|
||||
"wait": "请稍候,正在恢复数据...",
|
||||
"completed": "导入完成!",
|
||||
"restartRequired": "数据已成功导入。建议立即重启 AstrBot 以使所有更改生效。",
|
||||
"restartNow": "立即重启",
|
||||
"failed": "导入失败",
|
||||
"retry": "重试",
|
||||
"version": {
|
||||
"backupVersion": "备份版本",
|
||||
"currentVersion": "当前版本",
|
||||
"backupTime": "备份时间",
|
||||
"matchTitle": "✅ 版本匹配",
|
||||
"matchMessage": "导入将会清空并覆盖现有的所有数据,包括:\n• 主数据库(对话记录、配置等)\n• 知识库数据\n• 插件及插件数据\n• 配置文件\n\n此操作不可撤销!是否确认继续?",
|
||||
"minorDiffTitle": "⚠️ 版本差异警告",
|
||||
"minorDiffMessage": "小版本差异通常是兼容的,但可能存在少量数据结构变化。\n导入将会清空并覆盖现有的所有数据!\n\n是否确认继续导入?",
|
||||
"majorDiffTitle": "⛔ 无法导入",
|
||||
"majorDiffMessage": "主版本号不同,跨主版本导入可能导致数据损坏。\n请使用相同主版本的 AstrBot 进行导入。"
|
||||
}
|
||||
},
|
||||
"list": {
|
||||
"empty": "暂无备份文件",
|
||||
"refresh": "刷新列表",
|
||||
"confirmDelete": "确定要删除这个备份文件吗?此操作不可撤销。"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -21,10 +21,14 @@ export const useCommonStore = defineStore({
|
||||
}
|
||||
const controller = new AbortController();
|
||||
const { signal } = controller;
|
||||
|
||||
// 注意:这里如果之前改过 Polyfill 的话,可能需要保持原样
|
||||
// 如果是用 fetch 的话,这里是支持 Authorization Header 的
|
||||
const headers = {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
'Authorization': 'Bearer ' + localStorage.getItem('token')
|
||||
};
|
||||
|
||||
fetch('/api/live-log', {
|
||||
method: 'GET',
|
||||
headers,
|
||||
@@ -72,10 +76,20 @@ export const useCommonStore = defineStore({
|
||||
|
||||
try {
|
||||
const logObject = JSON.parse(logLine);
|
||||
// give a uuid if not exists
|
||||
|
||||
// 修复:兼容 HTTP 环境的 UUID 生成
|
||||
if (!logObject.uuid) {
|
||||
logObject.uuid = crypto.randomUUID();
|
||||
if (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') {
|
||||
logObject.uuid = crypto.randomUUID();
|
||||
} else {
|
||||
// 手动生成 UUID v4
|
||||
logObject.uuid = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, function(c) {
|
||||
var r = Math.random() * 16 | 0, v = c == 'x' ? r : (r & 0x3 | 0x8);
|
||||
return v.toString(16);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
this.log_cache.push(logObject);
|
||||
// Limit log cache size
|
||||
if (this.log_cache.length > this.log_cache_max_len) {
|
||||
@@ -93,7 +107,13 @@ export const useCommonStore = defineStore({
|
||||
}).catch(error => {
|
||||
console.error('SSE error:', error);
|
||||
// Attempt to reconnect after a delay
|
||||
this.log_cache.push('SSE Connection failed, retrying in 5 seconds...');
|
||||
this.log_cache.push({
|
||||
type: 'log',
|
||||
level: 'ERROR',
|
||||
time: Date.now() / 1000,
|
||||
data: 'SSE Connection failed, retrying in 5 seconds...',
|
||||
uuid: 'error-' + Date.now()
|
||||
});
|
||||
setTimeout(() => {
|
||||
this.eventSource = null;
|
||||
this.createEventSource();
|
||||
|
||||
@@ -17,6 +17,13 @@
|
||||
|
||||
<v-list-subheader>{{ tm('system.title') }}</v-list-subheader>
|
||||
|
||||
<v-list-item :subtitle="tm('system.backup.subtitle')" :title="tm('system.backup.title')">
|
||||
<v-btn style="margin-top: 16px;" color="primary" @click="openBackupDialog">
|
||||
<v-icon class="mr-2">mdi-backup-restore</v-icon>
|
||||
{{ tm('system.backup.button') }}
|
||||
</v-btn>
|
||||
</v-list-item>
|
||||
|
||||
<v-list-item :subtitle="tm('system.restart.subtitle')" :title="tm('system.restart.title')">
|
||||
<v-btn style="margin-top: 16px;" color="error" @click="restartAstrBot">{{ tm('system.restart.button') }}</v-btn>
|
||||
</v-list-item>
|
||||
@@ -30,6 +37,7 @@
|
||||
|
||||
<WaitingForRestart ref="wfr"></WaitingForRestart>
|
||||
<MigrationDialog ref="migrationDialog"></MigrationDialog>
|
||||
<BackupDialog ref="backupDialog"></BackupDialog>
|
||||
|
||||
</template>
|
||||
|
||||
@@ -40,12 +48,14 @@ import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
|
||||
import ProxySelector from '@/components/shared/ProxySelector.vue';
|
||||
import MigrationDialog from '@/components/shared/MigrationDialog.vue';
|
||||
import SidebarCustomizer from '@/components/shared/SidebarCustomizer.vue';
|
||||
import BackupDialog from '@/components/shared/BackupDialog.vue';
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
|
||||
const { tm } = useModuleI18n('features/settings');
|
||||
|
||||
const wfr = ref(null);
|
||||
const migrationDialog = ref(null);
|
||||
const backupDialog = ref(null);
|
||||
|
||||
const restartAstrBot = () => {
|
||||
axios.post('/api/stat/restart-core').then(() => {
|
||||
@@ -65,4 +75,10 @@ const startMigration = async () => {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const openBackupDialog = () => {
|
||||
if (backupDialog.value) {
|
||||
backupDialog.value.open();
|
||||
}
|
||||
}
|
||||
</script>
|
||||
@@ -19,6 +19,7 @@ export default defineConfig({
|
||||
],
|
||||
resolve: {
|
||||
alias: {
|
||||
mermaid: 'mermaid/dist/mermaid.js',
|
||||
'@': fileURLToPath(new URL('./src', import.meta.url))
|
||||
}
|
||||
},
|
||||
|
||||
+2
-2
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "4.10.2"
|
||||
version = "4.10.3"
|
||||
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
@@ -103,7 +103,7 @@ typeCheckingMode = "basic"
|
||||
pythonVersion = "3.10"
|
||||
reportMissingTypeStubs = false
|
||||
reportMissingImports = false
|
||||
include = ["astrbot", "packages"]
|
||||
include = ["astrbot"]
|
||||
exclude = ["dashboard", "node_modules", "dist", "data", "tests"]
|
||||
|
||||
[build-system]
|
||||
|
||||
@@ -0,0 +1,760 @@
|
||||
"""备份功能单元测试"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.backup import (
|
||||
BACKUP_MANIFEST_VERSION,
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
ImportPreCheckResult,
|
||||
)
|
||||
from astrbot.core.backup.exporter import AstrBotExporter
|
||||
from astrbot.core.backup.importer import (
|
||||
AstrBotImporter,
|
||||
ImportResult,
|
||||
_get_major_version,
|
||||
)
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.db.po import (
|
||||
ConversationV2,
|
||||
)
|
||||
from astrbot.core.utils.version_comparator import VersionComparator
|
||||
from astrbot.dashboard.routes.backup import (
|
||||
generate_unique_filename,
|
||||
secure_filename,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_backup_dir(tmp_path):
|
||||
"""创建临时备份目录"""
|
||||
backup_dir = tmp_path / "backups"
|
||||
backup_dir.mkdir()
|
||||
return backup_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_data_dir(tmp_path):
|
||||
"""创建临时数据目录"""
|
||||
data_dir = tmp_path / "data"
|
||||
data_dir.mkdir()
|
||||
|
||||
# 创建配置文件
|
||||
config_path = data_dir / "cmd_config.json"
|
||||
config_path.write_text(json.dumps({"test": "config"}))
|
||||
|
||||
# 创建附件目录
|
||||
attachments_dir = data_dir / "attachments"
|
||||
attachments_dir.mkdir()
|
||||
|
||||
return data_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_main_db():
|
||||
"""创建模拟的主数据库"""
|
||||
db = MagicMock()
|
||||
|
||||
# 模拟异步上下文管理器
|
||||
session = AsyncMock()
|
||||
db.get_db = MagicMock(
|
||||
return_value=AsyncMock(__aenter__=AsyncMock(return_value=session))
|
||||
)
|
||||
|
||||
return db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_kb_manager():
|
||||
"""创建模拟的知识库管理器"""
|
||||
kb_manager = MagicMock()
|
||||
kb_manager.kb_insts = {}
|
||||
|
||||
# 模拟 kb_db
|
||||
kb_db = MagicMock()
|
||||
session = AsyncMock()
|
||||
kb_db.get_db = MagicMock(
|
||||
return_value=AsyncMock(__aenter__=AsyncMock(return_value=session))
|
||||
)
|
||||
kb_manager.kb_db = kb_db
|
||||
|
||||
return kb_manager
|
||||
|
||||
|
||||
class TestImportResult:
|
||||
"""ImportResult 类测试"""
|
||||
|
||||
def test_init(self):
|
||||
"""测试初始化"""
|
||||
result = ImportResult()
|
||||
assert result.success is True
|
||||
assert result.imported_tables == {}
|
||||
assert result.imported_files == {}
|
||||
assert result.warnings == []
|
||||
assert result.errors == []
|
||||
|
||||
def test_add_warning(self):
|
||||
"""测试添加警告"""
|
||||
result = ImportResult()
|
||||
result.add_warning("test warning")
|
||||
assert "test warning" in result.warnings
|
||||
assert result.success is True # 警告不影响成功状态
|
||||
|
||||
def test_add_error(self):
|
||||
"""测试添加错误"""
|
||||
result = ImportResult()
|
||||
result.add_error("test error")
|
||||
assert "test error" in result.errors
|
||||
assert result.success is False # 错误会导致失败
|
||||
|
||||
def test_to_dict(self):
|
||||
"""测试转换为字典"""
|
||||
result = ImportResult()
|
||||
result.imported_tables = {"test_table": 10}
|
||||
result.add_warning("warning")
|
||||
|
||||
d = result.to_dict()
|
||||
assert d["success"] is True
|
||||
assert d["imported_tables"] == {"test_table": 10}
|
||||
assert "warning" in d["warnings"]
|
||||
|
||||
|
||||
class TestAstrBotExporter:
|
||||
"""AstrBotExporter 类测试"""
|
||||
|
||||
def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir):
|
||||
"""测试初始化"""
|
||||
exporter = AstrBotExporter(
|
||||
main_db=mock_main_db,
|
||||
kb_manager=mock_kb_manager,
|
||||
config_path=str(temp_data_dir / "cmd_config.json"),
|
||||
)
|
||||
assert exporter.main_db is mock_main_db
|
||||
assert exporter.kb_manager is mock_kb_manager
|
||||
|
||||
def test_model_to_dict_with_model_dump(self):
|
||||
"""测试 _model_to_dict 使用 model_dump 方法"""
|
||||
exporter = AstrBotExporter(main_db=MagicMock())
|
||||
|
||||
# 创建一个有 model_dump 方法的模拟对象
|
||||
mock_record = MagicMock()
|
||||
mock_record.model_dump.return_value = {"id": 1, "name": "test"}
|
||||
|
||||
result = exporter._model_to_dict(mock_record)
|
||||
assert result == {"id": 1, "name": "test"}
|
||||
|
||||
def test_model_to_dict_with_datetime(self):
|
||||
"""测试 _model_to_dict 处理 datetime 字段"""
|
||||
exporter = AstrBotExporter(main_db=MagicMock())
|
||||
|
||||
now = datetime.now()
|
||||
mock_record = MagicMock()
|
||||
mock_record.model_dump.return_value = {"id": 1, "created_at": now}
|
||||
|
||||
result = exporter._model_to_dict(mock_record)
|
||||
assert result["created_at"] == now.isoformat()
|
||||
|
||||
def test_add_checksum(self):
|
||||
"""测试添加校验和"""
|
||||
exporter = AstrBotExporter(main_db=MagicMock())
|
||||
|
||||
exporter._add_checksum("test.json", '{"test": "data"}')
|
||||
|
||||
assert "test.json" in exporter._checksums
|
||||
assert exporter._checksums["test.json"].startswith("sha256:")
|
||||
|
||||
def test_generate_manifest(self, mock_main_db, mock_kb_manager):
|
||||
"""测试生成清单"""
|
||||
exporter = AstrBotExporter(
|
||||
main_db=mock_main_db,
|
||||
kb_manager=mock_kb_manager,
|
||||
)
|
||||
|
||||
main_data = {
|
||||
"platform_stats": [{"id": 1}],
|
||||
"conversations": [],
|
||||
"attachments": [],
|
||||
}
|
||||
kb_meta_data = {
|
||||
"knowledge_bases": [],
|
||||
"kb_documents": [],
|
||||
}
|
||||
dir_stats = {
|
||||
"plugins": {"files": 10, "size": 1024},
|
||||
"plugin_data": {"files": 5, "size": 512},
|
||||
}
|
||||
|
||||
manifest = exporter._generate_manifest(main_data, kb_meta_data, dir_stats)
|
||||
|
||||
assert manifest["version"] == BACKUP_MANIFEST_VERSION
|
||||
assert manifest["astrbot_version"] == VERSION
|
||||
assert "exported_at" in manifest
|
||||
assert "tables" in manifest
|
||||
assert "statistics" in manifest
|
||||
assert "directories" in manifest
|
||||
assert manifest["statistics"]["main_db"]["platform_stats"] == 1
|
||||
assert manifest["statistics"]["directories"] == dir_stats
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_all_creates_zip(
|
||||
self, mock_main_db, temp_backup_dir, temp_data_dir
|
||||
):
|
||||
"""测试导出创建 ZIP 文件"""
|
||||
# 设置模拟数据库返回空数据
|
||||
session = AsyncMock()
|
||||
result = MagicMock()
|
||||
result.scalars.return_value.all.return_value = []
|
||||
session.execute = AsyncMock(return_value=result)
|
||||
|
||||
mock_main_db.get_db.return_value = AsyncMock(
|
||||
__aenter__=AsyncMock(return_value=session),
|
||||
__aexit__=AsyncMock(return_value=None),
|
||||
)
|
||||
|
||||
exporter = AstrBotExporter(
|
||||
main_db=mock_main_db,
|
||||
kb_manager=None,
|
||||
config_path=str(temp_data_dir / "cmd_config.json"),
|
||||
)
|
||||
|
||||
zip_path = await exporter.export_all(output_dir=str(temp_backup_dir))
|
||||
|
||||
assert os.path.exists(zip_path)
|
||||
assert zip_path.endswith(".zip")
|
||||
assert "astrbot_backup_" in zip_path
|
||||
|
||||
# 验证 ZIP 文件内容
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
namelist = zf.namelist()
|
||||
assert "manifest.json" in namelist
|
||||
assert "databases/main_db.json" in namelist
|
||||
assert "config/cmd_config.json" in namelist
|
||||
|
||||
|
||||
class TestAstrBotImporter:
|
||||
"""AstrBotImporter 类测试"""
|
||||
|
||||
def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir):
|
||||
"""测试初始化"""
|
||||
importer = AstrBotImporter(
|
||||
main_db=mock_main_db,
|
||||
kb_manager=mock_kb_manager,
|
||||
config_path=str(temp_data_dir / "cmd_config.json"),
|
||||
)
|
||||
assert importer.main_db is mock_main_db
|
||||
assert importer.kb_manager is mock_kb_manager
|
||||
|
||||
def test_validate_version_match(self):
|
||||
"""测试版本匹配验证"""
|
||||
importer = AstrBotImporter(main_db=MagicMock())
|
||||
|
||||
manifest = {"astrbot_version": VERSION}
|
||||
# 不应该抛出异常
|
||||
importer._validate_version(manifest)
|
||||
|
||||
def test_validate_version_major_diff_rejected(self):
|
||||
"""测试主版本不同被拒绝"""
|
||||
importer = AstrBotImporter(main_db=MagicMock())
|
||||
|
||||
# 使用一个明显不同的主版本
|
||||
manifest = {"astrbot_version": "0.0.1"}
|
||||
with pytest.raises(ValueError, match="主版本不兼容"):
|
||||
importer._validate_version(manifest)
|
||||
|
||||
def test_validate_version_minor_diff_allowed(self):
|
||||
"""测试小版本不同被允许(仅警告)"""
|
||||
importer = AstrBotImporter(main_db=MagicMock())
|
||||
|
||||
# 获取当前主版本
|
||||
major_version = _get_major_version(VERSION)
|
||||
# 构造一个同主版本但小版本不同的版本
|
||||
minor_diff_version = f"{major_version}.999"
|
||||
manifest = {"astrbot_version": minor_diff_version}
|
||||
# 不应该抛出异常
|
||||
importer._validate_version(manifest)
|
||||
|
||||
def test_validate_version_missing(self):
|
||||
"""测试缺少版本信息"""
|
||||
importer = AstrBotImporter(main_db=MagicMock())
|
||||
|
||||
manifest = {}
|
||||
with pytest.raises(ValueError, match="缺少版本信息"):
|
||||
importer._validate_version(manifest)
|
||||
|
||||
def test_convert_datetime_fields(self):
|
||||
"""测试 datetime 字段转换"""
|
||||
importer = AstrBotImporter(main_db=MagicMock())
|
||||
|
||||
# 使用 ConversationV2 作为测试模型(它有 created_at 和 updated_at 字段)
|
||||
row = {
|
||||
"conversation_id": "test-123",
|
||||
"platform_id": "test",
|
||||
"user_id": "user1",
|
||||
"created_at": "2024-01-01T12:00:00",
|
||||
"updated_at": "2024-01-01T12:00:00",
|
||||
}
|
||||
|
||||
result = importer._convert_datetime_fields(row, ConversationV2)
|
||||
|
||||
# created_at 应该被转换为 datetime 对象
|
||||
assert isinstance(result["created_at"], datetime)
|
||||
assert isinstance(result["updated_at"], datetime)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_import_file_not_exists(self, mock_main_db, tmp_path):
|
||||
"""测试导入不存在的文件"""
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
|
||||
result = await importer.import_all(str(tmp_path / "nonexistent.zip"))
|
||||
|
||||
assert result.success is False
|
||||
assert any("不存在" in err for err in result.errors)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_import_invalid_zip(self, mock_main_db, tmp_path):
|
||||
"""测试导入无效的 ZIP 文件"""
|
||||
# 创建一个无效的文件
|
||||
invalid_zip = tmp_path / "invalid.zip"
|
||||
invalid_zip.write_text("not a zip file")
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = await importer.import_all(str(invalid_zip))
|
||||
|
||||
assert result.success is False
|
||||
assert any("无效" in err or "ZIP" in err for err in result.errors)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_import_missing_manifest(self, mock_main_db, tmp_path):
|
||||
"""测试导入缺少 manifest 的 ZIP 文件"""
|
||||
# 创建一个没有 manifest 的 ZIP 文件
|
||||
zip_path = tmp_path / "no_manifest.zip"
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
zf.writestr("test.txt", "test content")
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = await importer.import_all(str(zip_path))
|
||||
|
||||
assert result.success is False
|
||||
assert any("manifest" in err.lower() for err in result.errors)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_import_major_version_mismatch(self, mock_main_db, tmp_path):
|
||||
"""测试导入主版本不匹配的备份"""
|
||||
# 创建一个主版本不匹配的备份
|
||||
zip_path = tmp_path / "old_version.zip"
|
||||
manifest = {
|
||||
"version": "1.0",
|
||||
"astrbot_version": "0.0.1", # 主版本不同
|
||||
"tables": {"main_db": []},
|
||||
}
|
||||
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
zf.writestr("manifest.json", json.dumps(manifest))
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = await importer.import_all(str(zip_path))
|
||||
|
||||
assert result.success is False
|
||||
assert any("主版本不兼容" in err for err in result.errors)
|
||||
|
||||
|
||||
class TestSecureFilename:
|
||||
"""安全文件名函数测试"""
|
||||
|
||||
def test_secure_filename_normal(self):
|
||||
"""测试正常文件名"""
|
||||
assert secure_filename("backup.zip") == "backup.zip"
|
||||
assert secure_filename("my_backup_2024.zip") == "my_backup_2024.zip"
|
||||
|
||||
def test_secure_filename_path_traversal(self):
|
||||
"""测试路径遍历攻击"""
|
||||
assert ".." not in secure_filename("../../../etc/passwd")
|
||||
assert "/" not in secure_filename("/etc/passwd")
|
||||
assert "\\" not in secure_filename("..\\..\\windows\\system32")
|
||||
|
||||
def test_secure_filename_with_path(self):
|
||||
"""测试带路径的文件名"""
|
||||
result = secure_filename("/path/to/backup.zip")
|
||||
assert result == "backup.zip"
|
||||
|
||||
result = secure_filename("C:\\Users\\test\\backup.zip")
|
||||
assert result == "backup.zip"
|
||||
|
||||
def test_secure_filename_special_chars(self):
|
||||
"""测试特殊字符"""
|
||||
result = secure_filename('backup<>:"|?*.zip')
|
||||
# 特殊字符应被替换为下划线
|
||||
assert "<" not in result
|
||||
assert ">" not in result
|
||||
assert ":" not in result
|
||||
assert '"' not in result
|
||||
assert "|" not in result
|
||||
assert "?" not in result
|
||||
assert "*" not in result
|
||||
|
||||
def test_secure_filename_hidden_file(self):
|
||||
"""测试隐藏文件(前导点)"""
|
||||
result = secure_filename(".hidden_backup.zip")
|
||||
assert not result.startswith(".")
|
||||
|
||||
def test_secure_filename_empty(self):
|
||||
"""测试空文件名"""
|
||||
assert secure_filename("") == "backup"
|
||||
assert secure_filename("...") == "backup"
|
||||
|
||||
def test_generate_unique_filename(self):
|
||||
"""测试生成唯一文件名"""
|
||||
result = generate_unique_filename("backup.zip")
|
||||
# 应包含 uploaded_ 前缀和时间戳
|
||||
assert result.startswith("uploaded_")
|
||||
assert result.endswith("_backup.zip")
|
||||
# 应包含时间戳格式 YYYYMMDD_HHMMSS
|
||||
assert re.search(r"uploaded_\d{8}_\d{6}_backup\.zip", result)
|
||||
|
||||
|
||||
class TestVersionComparison:
|
||||
"""版本比较函数测试 - 使用 VersionComparator"""
|
||||
|
||||
def test_get_major_version_simple(self):
|
||||
"""测试提取简单主版本号"""
|
||||
assert _get_major_version("1.0") == "1.0"
|
||||
assert _get_major_version("2.1") == "2.1"
|
||||
assert _get_major_version("4.9.1") == "4.9"
|
||||
|
||||
def test_get_major_version_with_prefix(self):
|
||||
"""测试带 v 前缀的版本号"""
|
||||
assert _get_major_version("v1.0") == "1.0"
|
||||
assert _get_major_version("V4.9.1") == "4.9"
|
||||
|
||||
def test_get_major_version_with_prerelease(self):
|
||||
"""测试带预发布标签的版本号"""
|
||||
assert _get_major_version("4.9.1-beta") == "4.9"
|
||||
assert _get_major_version("4.9.1-alpha.1") == "4.9"
|
||||
assert _get_major_version("4.9.1+build123") == "4.9"
|
||||
|
||||
def test_get_major_version_single_part(self):
|
||||
"""测试单部分版本号"""
|
||||
assert _get_major_version("1") == "1.0"
|
||||
|
||||
def test_get_major_version_empty(self):
|
||||
"""测试空版本号"""
|
||||
assert _get_major_version("") == "0.0"
|
||||
|
||||
def test_compare_versions_equal(self):
|
||||
"""测试版本相等"""
|
||||
assert VersionComparator.compare_version("1.0", "1.0") == 0
|
||||
assert VersionComparator.compare_version("1.0.0", "1.0") == 0
|
||||
assert VersionComparator.compare_version("2.10", "2.10") == 0
|
||||
|
||||
def test_compare_versions_less_than(self):
|
||||
"""测试版本小于"""
|
||||
assert VersionComparator.compare_version("1.0", "1.1") == -1
|
||||
assert (
|
||||
VersionComparator.compare_version("1.9", "1.10") == -1
|
||||
) # 关键测试:多位数版本比较
|
||||
assert VersionComparator.compare_version("1.2", "1.10") == -1
|
||||
assert VersionComparator.compare_version("1.0", "2.0") == -1
|
||||
|
||||
def test_compare_versions_greater_than(self):
|
||||
"""测试版本大于"""
|
||||
assert VersionComparator.compare_version("1.1", "1.0") == 1
|
||||
assert (
|
||||
VersionComparator.compare_version("1.10", "1.9") == 1
|
||||
) # 关键测试:多位数版本比较
|
||||
assert VersionComparator.compare_version("1.10", "1.2") == 1
|
||||
assert VersionComparator.compare_version("2.0", "1.0") == 1
|
||||
|
||||
def test_compare_versions_different_lengths(self):
|
||||
"""测试不同长度版本比较"""
|
||||
assert VersionComparator.compare_version("1.0", "1.0.0") == 0
|
||||
assert VersionComparator.compare_version("1.0", "1.0.1") == -1
|
||||
assert VersionComparator.compare_version("1.0.1", "1.0") == 1
|
||||
|
||||
def test_compare_versions_prerelease(self):
|
||||
"""测试预发布版本比较"""
|
||||
# 预发布版本低于正式版本
|
||||
assert VersionComparator.compare_version("1.0.0-alpha", "1.0.0") == -1
|
||||
assert VersionComparator.compare_version("1.0.0", "1.0.0-beta") == 1
|
||||
# alpha < beta
|
||||
assert VersionComparator.compare_version("1.0.0-alpha", "1.0.0-beta") == -1
|
||||
|
||||
|
||||
class TestImportPreCheckResult:
|
||||
"""ImportPreCheckResult 类测试"""
|
||||
|
||||
def test_init_default_values(self):
|
||||
"""测试默认值初始化"""
|
||||
result = ImportPreCheckResult()
|
||||
assert result.valid is False
|
||||
assert result.can_import is False
|
||||
assert result.version_status == ""
|
||||
assert result.backup_version == ""
|
||||
assert result.current_version == VERSION
|
||||
assert result.confirm_message == ""
|
||||
assert result.warnings == []
|
||||
assert result.error == ""
|
||||
assert result.backup_summary == {}
|
||||
|
||||
def test_to_dict(self):
|
||||
"""测试转换为字典"""
|
||||
result = ImportPreCheckResult(
|
||||
valid=True,
|
||||
can_import=True,
|
||||
version_status="match",
|
||||
backup_version="4.9.0",
|
||||
confirm_message="确认导入?",
|
||||
warnings=["警告1"],
|
||||
backup_summary={"tables": ["table1"]},
|
||||
)
|
||||
|
||||
d = result.to_dict()
|
||||
assert d["valid"] is True
|
||||
assert d["can_import"] is True
|
||||
assert d["version_status"] == "match"
|
||||
assert d["backup_version"] == "4.9.0"
|
||||
assert d["confirm_message"] == "确认导入?"
|
||||
assert "警告1" in d["warnings"]
|
||||
assert d["backup_summary"]["tables"] == ["table1"]
|
||||
|
||||
|
||||
class TestPreCheck:
|
||||
"""预检查功能测试"""
|
||||
|
||||
def test_pre_check_file_not_exists(self, mock_main_db):
|
||||
"""测试预检查不存在的文件"""
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer.pre_check("/nonexistent/file.zip")
|
||||
|
||||
assert result.valid is False
|
||||
assert "不存在" in result.error
|
||||
|
||||
def test_pre_check_invalid_zip(self, mock_main_db, tmp_path):
|
||||
"""测试预检查无效的 ZIP 文件"""
|
||||
invalid_zip = tmp_path / "invalid.zip"
|
||||
invalid_zip.write_text("not a zip file")
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer.pre_check(str(invalid_zip))
|
||||
|
||||
assert result.valid is False
|
||||
assert "ZIP" in result.error or "无效" in result.error
|
||||
|
||||
def test_pre_check_missing_manifest(self, mock_main_db, tmp_path):
|
||||
"""测试预检查缺少 manifest 的 ZIP 文件"""
|
||||
zip_path = tmp_path / "no_manifest.zip"
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
zf.writestr("test.txt", "test content")
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer.pre_check(str(zip_path))
|
||||
|
||||
assert result.valid is False
|
||||
assert "manifest" in result.error.lower()
|
||||
|
||||
def test_pre_check_version_match(self, mock_main_db, tmp_path):
|
||||
"""测试预检查版本匹配"""
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
manifest = {
|
||||
"version": "1.1",
|
||||
"astrbot_version": VERSION,
|
||||
"created_at": "2024-01-01T12:00:00",
|
||||
"tables": {"platform_stats": 1},
|
||||
"has_knowledge_bases": True,
|
||||
"has_config": True,
|
||||
"directories": ["plugins"],
|
||||
}
|
||||
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
zf.writestr("manifest.json", json.dumps(manifest))
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer.pre_check(str(zip_path))
|
||||
|
||||
assert result.valid is True
|
||||
assert result.can_import is True
|
||||
assert result.version_status == "match"
|
||||
assert result.backup_version == VERSION
|
||||
# confirm_message 现在由前端生成,后端不再生成
|
||||
assert result.backup_summary["has_knowledge_bases"] is True
|
||||
|
||||
def test_pre_check_minor_version_diff(self, mock_main_db, tmp_path):
|
||||
"""测试预检查小版本差异"""
|
||||
# 构造一个同主版本但小版本不同的版本
|
||||
major_version = _get_major_version(VERSION)
|
||||
minor_diff_version = f"{major_version}.999"
|
||||
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
manifest = {
|
||||
"version": "1.1",
|
||||
"astrbot_version": minor_diff_version,
|
||||
"created_at": "2024-01-01T12:00:00",
|
||||
"tables": {},
|
||||
}
|
||||
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
zf.writestr("manifest.json", json.dumps(manifest))
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer.pre_check(str(zip_path))
|
||||
|
||||
assert result.valid is True
|
||||
assert result.can_import is True
|
||||
assert result.version_status == "minor_diff"
|
||||
# 版本消息由前端 i18n 生成,后端 warnings 列表不再包含版本相关消息
|
||||
# warnings 列表保留用于其他非版本相关的警告
|
||||
|
||||
def test_pre_check_major_version_diff(self, mock_main_db, tmp_path):
|
||||
"""测试预检查主版本差异"""
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
manifest = {
|
||||
"version": "1.1",
|
||||
"astrbot_version": "0.0.1", # 主版本不同
|
||||
"created_at": "2024-01-01T12:00:00",
|
||||
"tables": {},
|
||||
}
|
||||
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
zf.writestr("manifest.json", json.dumps(manifest))
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer.pre_check(str(zip_path))
|
||||
|
||||
assert result.valid is True # 文件有效
|
||||
assert result.can_import is False # 但不能导入
|
||||
assert result.version_status == "major_diff"
|
||||
# 版本消息由前端 i18n 生成,后端 warnings 列表不再包含版本相关消息
|
||||
|
||||
|
||||
class TestVersionCompatibility:
|
||||
"""版本兼容性检查测试"""
|
||||
|
||||
def test_check_version_compatibility_match(self, mock_main_db):
|
||||
"""测试版本完全匹配"""
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer._check_version_compatibility(VERSION)
|
||||
|
||||
assert result["status"] == "match"
|
||||
assert result["can_import"] is True
|
||||
|
||||
def test_check_version_compatibility_minor_diff(self, mock_main_db):
|
||||
"""测试小版本差异"""
|
||||
major_version = _get_major_version(VERSION)
|
||||
minor_diff_version = f"{major_version}.999"
|
||||
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer._check_version_compatibility(minor_diff_version)
|
||||
|
||||
assert result["status"] == "minor_diff"
|
||||
assert result["can_import"] is True
|
||||
|
||||
def test_check_version_compatibility_major_diff(self, mock_main_db):
|
||||
"""测试主版本差异"""
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer._check_version_compatibility("0.0.1")
|
||||
|
||||
assert result["status"] == "major_diff"
|
||||
assert result["can_import"] is False
|
||||
|
||||
def test_check_version_compatibility_empty_version(self, mock_main_db):
|
||||
"""测试空版本号"""
|
||||
importer = AstrBotImporter(main_db=mock_main_db)
|
||||
result = importer._check_version_compatibility("")
|
||||
|
||||
assert result["status"] == "major_diff"
|
||||
assert result["can_import"] is False
|
||||
|
||||
|
||||
class TestModelMappings:
|
||||
"""测试模型映射配置"""
|
||||
|
||||
def test_main_db_models_not_empty(self):
|
||||
"""测试主数据库模型映射非空"""
|
||||
assert len(MAIN_DB_MODELS) > 0
|
||||
|
||||
def test_main_db_models_contain_expected_tables(self):
|
||||
"""测试主数据库模型映射包含预期的表"""
|
||||
expected_tables = [
|
||||
"platform_stats",
|
||||
"conversations",
|
||||
"personas",
|
||||
"preferences",
|
||||
"attachments",
|
||||
]
|
||||
for table in expected_tables:
|
||||
assert table in MAIN_DB_MODELS, f"Missing table: {table}"
|
||||
|
||||
def test_kb_metadata_models_not_empty(self):
|
||||
"""测试知识库元数据模型映射非空"""
|
||||
assert len(KB_METADATA_MODELS) > 0
|
||||
|
||||
def test_kb_metadata_models_contain_expected_tables(self):
|
||||
"""测试知识库元数据模型映射包含预期的表"""
|
||||
expected_tables = [
|
||||
"knowledge_bases",
|
||||
"kb_documents",
|
||||
"kb_media",
|
||||
]
|
||||
for table in expected_tables:
|
||||
assert table in KB_METADATA_MODELS, f"Missing table: {table}"
|
||||
|
||||
|
||||
class TestBackupIntegration:
|
||||
"""备份集成测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_import_roundtrip(self, tmp_path):
|
||||
"""测试导出-导入往返"""
|
||||
backup_dir = tmp_path / "backups"
|
||||
backup_dir.mkdir()
|
||||
|
||||
data_dir = tmp_path / "data"
|
||||
data_dir.mkdir()
|
||||
|
||||
config_path = data_dir / "cmd_config.json"
|
||||
config_path.write_text(json.dumps({"setting": "value"}))
|
||||
|
||||
attachments_dir = data_dir / "attachments"
|
||||
attachments_dir.mkdir()
|
||||
|
||||
# 创建模拟数据库
|
||||
mock_db = MagicMock()
|
||||
session = AsyncMock()
|
||||
result = MagicMock()
|
||||
result.scalars.return_value.all.return_value = []
|
||||
session.execute = AsyncMock(return_value=result)
|
||||
|
||||
mock_db.get_db.return_value = AsyncMock(
|
||||
__aenter__=AsyncMock(return_value=session),
|
||||
__aexit__=AsyncMock(return_value=None),
|
||||
)
|
||||
|
||||
# 导出
|
||||
exporter = AstrBotExporter(
|
||||
main_db=mock_db,
|
||||
kb_manager=None,
|
||||
config_path=str(config_path),
|
||||
)
|
||||
|
||||
zip_path = await exporter.export_all(output_dir=str(backup_dir))
|
||||
assert os.path.exists(zip_path)
|
||||
|
||||
# 验证 ZIP 内容
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
# 读取 manifest
|
||||
manifest = json.loads(zf.read("manifest.json"))
|
||||
assert manifest["astrbot_version"] == VERSION
|
||||
|
||||
# 读取配置
|
||||
config = json.loads(zf.read("config/cmd_config.json"))
|
||||
assert config["setting"] == "value"
|
||||
|
||||
# 读取主数据库
|
||||
main_db = json.loads(zf.read("databases/main_db.json"))
|
||||
assert "platform_stats" in main_db
|
||||
Reference in New Issue
Block a user