Compare commits
10 Commits
multimessage
...
v4.10.3
| Author | SHA1 | Date | |
|---|---|---|---|
| 3e3599835e | |||
| 5255388e2d | |||
| fbdd60b64c | |||
| bd1b0a2836 | |||
| 19541d9d07 | |||
| 2a5d574394 | |||
| f2924fbd1b | |||
| 703e208947 | |||
| 9a5cc977c2 | |||
| aa38fe776a |
@@ -15,7 +15,6 @@ Always reference these instructions first and fallback to search or bash command
|
|||||||
### Running the Application
|
### Running the Application
|
||||||
- Run main application: `uv run main.py` -- starts in ~3 seconds
|
- Run main application: `uv run main.py` -- starts in ~3 seconds
|
||||||
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
|
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
|
||||||
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
|
|
||||||
|
|
||||||
### Dashboard Build (Vue.js/Node.js)
|
### Dashboard Build (Vue.js/Node.js)
|
||||||
- **Prerequisites**: Node.js 20+ and npm 10+ required
|
- **Prerequisites**: Node.js 20+ and npm 10+ required
|
||||||
@@ -35,7 +34,7 @@ Always reference these instructions first and fallback to search or bash command
|
|||||||
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
|
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
|
||||||
|
|
||||||
### Plugin Development
|
### Plugin Development
|
||||||
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
|
- Plugins load from `astrbot/builtin_stars/` (built-in) and `data/plugins/` (user-installed)
|
||||||
- Plugin system supports function tools and message handlers
|
- Plugin system supports function tools and message handlers
|
||||||
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
|
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
|
||||||
|
|
||||||
|
|||||||
+2
-2
@@ -24,9 +24,9 @@ configs/session
|
|||||||
configs/config.yaml
|
configs/config.yaml
|
||||||
cmd_config.json
|
cmd_config.json
|
||||||
|
|
||||||
# Plugins and packages
|
# Plugins
|
||||||
addons/plugins
|
addons/plugins
|
||||||
packages/python_interpreter/workplace
|
astrbot/builtin_stars/python_interpreter/workplace
|
||||||
tests/astrbot_plugin_openai
|
tests/astrbot_plugin_openai
|
||||||
|
|
||||||
# Dashboard
|
# Dashboard
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||

|

|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
|
||||||
|
|||||||
+19
-8
@@ -7,6 +7,7 @@ from astrbot.api import logger, sp, star
|
|||||||
from astrbot.api.event import AstrMessageEvent
|
from astrbot.api.event import AstrMessageEvent
|
||||||
from astrbot.api.message_components import Image, Reply
|
from astrbot.api.message_components import Image, Reply
|
||||||
from astrbot.api.provider import Provider, ProviderRequest
|
from astrbot.api.provider import Provider, ProviderRequest
|
||||||
|
from astrbot.core.agent.message import TextPart
|
||||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||||
|
|
||||||
|
|
||||||
@@ -85,7 +86,9 @@ class ProcessLLMRequest:
|
|||||||
req.image_urls,
|
req.image_urls,
|
||||||
)
|
)
|
||||||
if caption:
|
if caption:
|
||||||
req.prompt = f"(Image Caption: {caption})\n\n{req.prompt}"
|
req.extra_user_content_parts.append(
|
||||||
|
TextPart(text=f"<image_caption>{caption}</image_caption>")
|
||||||
|
)
|
||||||
req.image_urls = []
|
req.image_urls = []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理图片描述失败: {e}")
|
logger.error(f"处理图片描述失败: {e}")
|
||||||
@@ -129,13 +132,14 @@ class ProcessLLMRequest:
|
|||||||
else:
|
else:
|
||||||
req.prompt = prefix + req.prompt
|
req.prompt = prefix + req.prompt
|
||||||
|
|
||||||
|
# 收集系统提醒信息
|
||||||
|
system_parts = []
|
||||||
|
|
||||||
# user identifier
|
# user identifier
|
||||||
if cfg.get("identifier"):
|
if cfg.get("identifier"):
|
||||||
user_id = event.message_obj.sender.user_id
|
user_id = event.message_obj.sender.user_id
|
||||||
user_nickname = event.message_obj.sender.nickname
|
user_nickname = event.message_obj.sender.nickname
|
||||||
req.prompt = (
|
system_parts.append(f"User ID: {user_id}, Nickname: {user_nickname}")
|
||||||
f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n{req.prompt}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# group name identifier
|
# group name identifier
|
||||||
if cfg.get("group_name_display") and event.message_obj.group_id:
|
if cfg.get("group_name_display") and event.message_obj.group_id:
|
||||||
@@ -146,7 +150,7 @@ class ProcessLLMRequest:
|
|||||||
return
|
return
|
||||||
group_name = event.message_obj.group.group_name
|
group_name = event.message_obj.group.group_name
|
||||||
if group_name:
|
if group_name:
|
||||||
req.system_prompt += f"\nGroup name: {group_name}\n"
|
system_parts.append(f"Group name: {group_name}")
|
||||||
|
|
||||||
# time info
|
# time info
|
||||||
if cfg.get("datetime_system_prompt"):
|
if cfg.get("datetime_system_prompt"):
|
||||||
@@ -162,7 +166,7 @@ class ProcessLLMRequest:
|
|||||||
current_time = (
|
current_time = (
|
||||||
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)")
|
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)")
|
||||||
)
|
)
|
||||||
req.system_prompt += f"\nCurrent datetime: {current_time}\n"
|
system_parts.append(f"Current datetime: {current_time}")
|
||||||
|
|
||||||
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
|
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
|
||||||
if req.conversation:
|
if req.conversation:
|
||||||
@@ -225,10 +229,17 @@ class ProcessLLMRequest:
|
|||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error(f"处理引用图片失败: {e}")
|
logger.error(f"处理引用图片失败: {e}")
|
||||||
|
|
||||||
# 3. 将所有部分组合成文本并直接注入到当前消息中
|
# 3. 将所有部分组合成文本并添加到 extra_user_content_parts 中
|
||||||
# 确保引用内容被正确的标签包裹
|
# 确保引用内容被正确的标签包裹
|
||||||
quoted_content = "\n".join(content_parts)
|
quoted_content = "\n".join(content_parts)
|
||||||
# 确保所有内容都在<Quoted Message>标签内
|
# 确保所有内容都在<Quoted Message>标签内
|
||||||
quoted_text = f"<Quoted Message>\n{quoted_content}\n</Quoted Message>"
|
quoted_text = f"<Quoted Message>\n{quoted_content}\n</Quoted Message>"
|
||||||
|
|
||||||
req.prompt = f"{quoted_text}\n\n{req.prompt}"
|
req.extra_user_content_parts.append(TextPart(text=quoted_text))
|
||||||
|
|
||||||
|
# 统一包裹所有系统提醒
|
||||||
|
if system_parts:
|
||||||
|
system_content = (
|
||||||
|
"<system_reminder>" + "\n".join(system_parts) + "</system_reminder>"
|
||||||
|
)
|
||||||
|
req.extra_user_content_parts.append(TextPart(text=system_content))
|
||||||
+6
-4
@@ -184,7 +184,8 @@ class ProviderCommands:
|
|||||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||||
return
|
return
|
||||||
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
|
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
|
||||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||||
|
return
|
||||||
provider = self.context.get_all_tts_providers()[idx2 - 1]
|
provider = self.context.get_all_tts_providers()[idx2 - 1]
|
||||||
id_ = provider.meta().id
|
id_ = provider.meta().id
|
||||||
await self.context.provider_manager.set_provider(
|
await self.context.provider_manager.set_provider(
|
||||||
@@ -198,7 +199,8 @@ class ProviderCommands:
|
|||||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||||
return
|
return
|
||||||
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
|
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
|
||||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||||
|
return
|
||||||
provider = self.context.get_all_stt_providers()[idx2 - 1]
|
provider = self.context.get_all_stt_providers()[idx2 - 1]
|
||||||
id_ = provider.meta().id
|
id_ = provider.meta().id
|
||||||
await self.context.provider_manager.set_provider(
|
await self.context.provider_manager.set_provider(
|
||||||
@@ -209,8 +211,8 @@ class ProviderCommands:
|
|||||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||||
elif isinstance(idx, int):
|
elif isinstance(idx, int):
|
||||||
if idx > len(self.context.get_all_providers()) or idx < 1:
|
if idx > len(self.context.get_all_providers()) or idx < 1:
|
||||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||||
|
return
|
||||||
provider = self.context.get_all_providers()[idx - 1]
|
provider = self.context.get_all_providers()[idx - 1]
|
||||||
id_ = provider.meta().id
|
id_ = provider.meta().id
|
||||||
await self.context.provider_manager.set_provider(
|
await self.context.provider_manager.set_provider(
|
||||||
@@ -1 +1 @@
|
|||||||
__version__ = "4.10.2"
|
__version__ = "4.10.3"
|
||||||
|
|||||||
@@ -169,6 +169,15 @@ class Message(BaseModel):
|
|||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_serializer(mode="wrap")
|
||||||
|
def serialize(self, handler):
|
||||||
|
data = handler(self)
|
||||||
|
if self.tool_calls is None:
|
||||||
|
data.pop("tool_calls", None)
|
||||||
|
if self.tool_call_id is None:
|
||||||
|
data.pop("tool_call_id", None)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class AssistantMessageSegment(Message):
|
class AssistantMessageSegment(Message):
|
||||||
"""A message segment from the assistant."""
|
"""A message segment from the assistant."""
|
||||||
|
|||||||
@@ -77,10 +77,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||||
"""Yields chunks *and* a final LLMResponse."""
|
"""Yields chunks *and* a final LLMResponse."""
|
||||||
payload = {
|
payload = {
|
||||||
"contexts": self.run_context.messages,
|
"contexts": self.run_context.messages, # list[Message]
|
||||||
"func_tool": self.req.func_tool,
|
"func_tool": self.req.func_tool,
|
||||||
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
|
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
|
||||||
"session_id": self.req.session_id,
|
"session_id": self.req.session_id,
|
||||||
|
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.streaming:
|
if self.streaming:
|
||||||
|
|||||||
@@ -0,0 +1,26 @@
|
|||||||
|
"""AstrBot 备份与恢复模块
|
||||||
|
|
||||||
|
提供数据导出和导入功能,支持用户在服务器迁移时一键备份和恢复所有数据。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 从 constants 模块导入共享常量
|
||||||
|
from .constants import (
|
||||||
|
BACKUP_MANIFEST_VERSION,
|
||||||
|
KB_METADATA_MODELS,
|
||||||
|
MAIN_DB_MODELS,
|
||||||
|
get_backup_directories,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 导入导出器和导入器
|
||||||
|
from .exporter import AstrBotExporter
|
||||||
|
from .importer import AstrBotImporter, ImportPreCheckResult
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AstrBotExporter",
|
||||||
|
"AstrBotImporter",
|
||||||
|
"ImportPreCheckResult",
|
||||||
|
"MAIN_DB_MODELS",
|
||||||
|
"KB_METADATA_MODELS",
|
||||||
|
"get_backup_directories",
|
||||||
|
"BACKUP_MANIFEST_VERSION",
|
||||||
|
]
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
"""AstrBot 备份模块共享常量
|
||||||
|
|
||||||
|
此文件定义了导出器和导入器共享的常量,确保两端配置一致。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlmodel import SQLModel
|
||||||
|
|
||||||
|
from astrbot.core.db.po import (
|
||||||
|
Attachment,
|
||||||
|
CommandConfig,
|
||||||
|
CommandConflict,
|
||||||
|
ConversationV2,
|
||||||
|
Persona,
|
||||||
|
PlatformMessageHistory,
|
||||||
|
PlatformSession,
|
||||||
|
PlatformStat,
|
||||||
|
Preference,
|
||||||
|
)
|
||||||
|
from astrbot.core.knowledge_base.models import (
|
||||||
|
KBDocument,
|
||||||
|
KBMedia,
|
||||||
|
KnowledgeBase,
|
||||||
|
)
|
||||||
|
from astrbot.core.utils.astrbot_path import (
|
||||||
|
get_astrbot_config_path,
|
||||||
|
get_astrbot_plugin_data_path,
|
||||||
|
get_astrbot_plugin_path,
|
||||||
|
get_astrbot_t2i_templates_path,
|
||||||
|
get_astrbot_temp_path,
|
||||||
|
get_astrbot_webchat_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 共享常量 - 确保导出和导入端配置一致
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# 主数据库模型类映射
|
||||||
|
MAIN_DB_MODELS: dict[str, type[SQLModel]] = {
|
||||||
|
"platform_stats": PlatformStat,
|
||||||
|
"conversations": ConversationV2,
|
||||||
|
"personas": Persona,
|
||||||
|
"preferences": Preference,
|
||||||
|
"platform_message_history": PlatformMessageHistory,
|
||||||
|
"platform_sessions": PlatformSession,
|
||||||
|
"attachments": Attachment,
|
||||||
|
"command_configs": CommandConfig,
|
||||||
|
"command_conflicts": CommandConflict,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 知识库元数据模型类映射
|
||||||
|
KB_METADATA_MODELS: dict[str, type[SQLModel]] = {
|
||||||
|
"knowledge_bases": KnowledgeBase,
|
||||||
|
"kb_documents": KBDocument,
|
||||||
|
"kb_media": KBMedia,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_backup_directories() -> dict[str, str]:
|
||||||
|
"""获取需要备份的目录列表
|
||||||
|
|
||||||
|
使用 astrbot_path 模块动态获取路径,支持通过环境变量 ASTRBOT_ROOT 自定义根目录。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 键为备份文件中的目录名称,值为目录的绝对路径
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"plugins": get_astrbot_plugin_path(), # 插件本体
|
||||||
|
"plugin_data": get_astrbot_plugin_data_path(), # 插件数据
|
||||||
|
"config": get_astrbot_config_path(), # 配置目录
|
||||||
|
"t2i_templates": get_astrbot_t2i_templates_path(), # T2I 模板
|
||||||
|
"webchat": get_astrbot_webchat_path(), # WebChat 数据
|
||||||
|
"temp": get_astrbot_temp_path(), # 临时文件
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 备份清单版本号
|
||||||
|
BACKUP_MANIFEST_VERSION = "1.1"
|
||||||
@@ -0,0 +1,476 @@
|
|||||||
|
"""AstrBot 数据导出器
|
||||||
|
|
||||||
|
负责将所有数据导出为 ZIP 备份文件。
|
||||||
|
导出格式为 JSON,这是数据库无关的方案,支持未来向 MySQL/PostgreSQL 迁移。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import zipfile
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from astrbot.core import logger
|
||||||
|
from astrbot.core.config.default import VERSION
|
||||||
|
from astrbot.core.db import BaseDatabase
|
||||||
|
from astrbot.core.utils.astrbot_path import (
|
||||||
|
get_astrbot_backups_path,
|
||||||
|
get_astrbot_data_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 从共享常量模块导入
|
||||||
|
from .constants import (
|
||||||
|
BACKUP_MANIFEST_VERSION,
|
||||||
|
KB_METADATA_MODELS,
|
||||||
|
MAIN_DB_MODELS,
|
||||||
|
get_backup_directories,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||||
|
|
||||||
|
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||||
|
|
||||||
|
|
||||||
|
class AstrBotExporter:
|
||||||
|
"""AstrBot 数据导出器
|
||||||
|
|
||||||
|
导出内容:
|
||||||
|
- 主数据库所有表(data/data_v4.db)
|
||||||
|
- 知识库元数据(data/knowledge_base/kb.db)
|
||||||
|
- 每个知识库的向量文档数据
|
||||||
|
- 配置文件(data/cmd_config.json)
|
||||||
|
- 附件文件
|
||||||
|
- 知识库多媒体文件
|
||||||
|
- 插件目录(data/plugins)
|
||||||
|
- 插件数据目录(data/plugin_data)
|
||||||
|
- 配置目录(data/config)
|
||||||
|
- T2I 模板目录(data/t2i_templates)
|
||||||
|
- WebChat 数据目录(data/webchat)
|
||||||
|
- 临时文件目录(data/temp)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
main_db: BaseDatabase,
|
||||||
|
kb_manager: "KnowledgeBaseManager | None" = None,
|
||||||
|
config_path: str = CMD_CONFIG_FILE_PATH,
|
||||||
|
):
|
||||||
|
self.main_db = main_db
|
||||||
|
self.kb_manager = kb_manager
|
||||||
|
self.config_path = config_path
|
||||||
|
self._checksums: dict[str, str] = {}
|
||||||
|
|
||||||
|
async def export_all(
|
||||||
|
self,
|
||||||
|
output_dir: str | None = None,
|
||||||
|
progress_callback: Any | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""导出所有数据到 ZIP 文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dir: 输出目录
|
||||||
|
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 生成的 ZIP 文件路径
|
||||||
|
"""
|
||||||
|
if output_dir is None:
|
||||||
|
output_dir = get_astrbot_backups_path()
|
||||||
|
|
||||||
|
# 确保输出目录存在
|
||||||
|
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
zip_filename = f"astrbot_backup_{timestamp}.zip"
|
||||||
|
zip_path = os.path.join(output_dir, zip_filename)
|
||||||
|
|
||||||
|
logger.info(f"开始导出备份到 {zip_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||||
|
# 1. 导出主数据库
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("main_db", 0, 100, "正在导出主数据库...")
|
||||||
|
main_data = await self._export_main_database()
|
||||||
|
main_db_json = json.dumps(
|
||||||
|
main_data, ensure_ascii=False, indent=2, default=str
|
||||||
|
)
|
||||||
|
zf.writestr("databases/main_db.json", main_db_json)
|
||||||
|
self._add_checksum("databases/main_db.json", main_db_json)
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("main_db", 100, 100, "主数据库导出完成")
|
||||||
|
|
||||||
|
# 2. 导出知识库数据
|
||||||
|
kb_meta_data: dict[str, Any] = {
|
||||||
|
"knowledge_bases": [],
|
||||||
|
"kb_documents": [],
|
||||||
|
"kb_media": [],
|
||||||
|
}
|
||||||
|
if self.kb_manager:
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback(
|
||||||
|
"kb_metadata", 0, 100, "正在导出知识库元数据..."
|
||||||
|
)
|
||||||
|
kb_meta_data = await self._export_kb_metadata()
|
||||||
|
kb_meta_json = json.dumps(
|
||||||
|
kb_meta_data, ensure_ascii=False, indent=2, default=str
|
||||||
|
)
|
||||||
|
zf.writestr("databases/kb_metadata.json", kb_meta_json)
|
||||||
|
self._add_checksum("databases/kb_metadata.json", kb_meta_json)
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback(
|
||||||
|
"kb_metadata", 100, 100, "知识库元数据导出完成"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 导出每个知识库的文档数据
|
||||||
|
kb_insts = self.kb_manager.kb_insts
|
||||||
|
total_kbs = len(kb_insts)
|
||||||
|
for idx, (kb_id, kb_helper) in enumerate(kb_insts.items()):
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback(
|
||||||
|
"kb_documents",
|
||||||
|
idx,
|
||||||
|
total_kbs,
|
||||||
|
f"正在导出知识库 {kb_helper.kb.kb_name} 的文档数据...",
|
||||||
|
)
|
||||||
|
doc_data = await self._export_kb_documents(kb_helper)
|
||||||
|
doc_json = json.dumps(
|
||||||
|
doc_data, ensure_ascii=False, indent=2, default=str
|
||||||
|
)
|
||||||
|
doc_path = f"databases/kb_{kb_id}/documents.json"
|
||||||
|
zf.writestr(doc_path, doc_json)
|
||||||
|
self._add_checksum(doc_path, doc_json)
|
||||||
|
|
||||||
|
# 导出 FAISS 索引文件
|
||||||
|
await self._export_faiss_index(zf, kb_helper, kb_id)
|
||||||
|
|
||||||
|
# 导出知识库多媒体文件
|
||||||
|
await self._export_kb_media_files(zf, kb_helper, kb_id)
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback(
|
||||||
|
"kb_documents", total_kbs, total_kbs, "知识库文档导出完成"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 导出配置文件
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("config", 0, 100, "正在导出配置文件...")
|
||||||
|
if os.path.exists(self.config_path):
|
||||||
|
with open(self.config_path, encoding="utf-8") as f:
|
||||||
|
config_content = f.read()
|
||||||
|
zf.writestr("config/cmd_config.json", config_content)
|
||||||
|
self._add_checksum("config/cmd_config.json", config_content)
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("config", 100, 100, "配置文件导出完成")
|
||||||
|
|
||||||
|
# 4. 导出附件文件
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("attachments", 0, 100, "正在导出附件...")
|
||||||
|
await self._export_attachments(zf, main_data.get("attachments", []))
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("attachments", 100, 100, "附件导出完成")
|
||||||
|
|
||||||
|
# 5. 导出插件和其他目录
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback(
|
||||||
|
"directories", 0, 100, "正在导出插件和数据目录..."
|
||||||
|
)
|
||||||
|
dir_stats = await self._export_directories(zf)
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("directories", 100, 100, "目录导出完成")
|
||||||
|
|
||||||
|
# 6. 生成 manifest
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("manifest", 0, 100, "正在生成清单...")
|
||||||
|
manifest = self._generate_manifest(main_data, kb_meta_data, dir_stats)
|
||||||
|
manifest_json = json.dumps(manifest, ensure_ascii=False, indent=2)
|
||||||
|
zf.writestr("manifest.json", manifest_json)
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("manifest", 100, 100, "清单生成完成")
|
||||||
|
|
||||||
|
logger.info(f"备份导出完成: {zip_path}")
|
||||||
|
return zip_path
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"备份导出失败: {e}")
|
||||||
|
# 清理失败的文件
|
||||||
|
if os.path.exists(zip_path):
|
||||||
|
os.remove(zip_path)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _export_main_database(self) -> dict[str, list[dict]]:
|
||||||
|
"""导出主数据库所有表"""
|
||||||
|
export_data: dict[str, list[dict]] = {}
|
||||||
|
|
||||||
|
async with self.main_db.get_db() as session:
|
||||||
|
for table_name, model_class in MAIN_DB_MODELS.items():
|
||||||
|
try:
|
||||||
|
result = await session.execute(select(model_class))
|
||||||
|
records = result.scalars().all()
|
||||||
|
export_data[table_name] = [
|
||||||
|
self._model_to_dict(record) for record in records
|
||||||
|
]
|
||||||
|
logger.debug(
|
||||||
|
f"导出表 {table_name}: {len(export_data[table_name])} 条记录"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"导出表 {table_name} 失败: {e}")
|
||||||
|
export_data[table_name] = []
|
||||||
|
|
||||||
|
return export_data
|
||||||
|
|
||||||
|
async def _export_kb_metadata(self) -> dict[str, list[dict]]:
|
||||||
|
"""导出知识库元数据库"""
|
||||||
|
if not self.kb_manager:
|
||||||
|
return {"knowledge_bases": [], "kb_documents": [], "kb_media": []}
|
||||||
|
|
||||||
|
export_data: dict[str, list[dict]] = {}
|
||||||
|
|
||||||
|
async with self.kb_manager.kb_db.get_db() as session:
|
||||||
|
for table_name, model_class in KB_METADATA_MODELS.items():
|
||||||
|
try:
|
||||||
|
result = await session.execute(select(model_class))
|
||||||
|
records = result.scalars().all()
|
||||||
|
export_data[table_name] = [
|
||||||
|
self._model_to_dict(record) for record in records
|
||||||
|
]
|
||||||
|
logger.debug(
|
||||||
|
f"导出知识库表 {table_name}: {len(export_data[table_name])} 条记录"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"导出知识库表 {table_name} 失败: {e}")
|
||||||
|
export_data[table_name] = []
|
||||||
|
|
||||||
|
return export_data
|
||||||
|
|
||||||
|
async def _export_kb_documents(self, kb_helper: Any) -> dict[str, Any]:
|
||||||
|
"""导出知识库的文档块数据"""
|
||||||
|
try:
|
||||||
|
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
|
||||||
|
|
||||||
|
vec_db: FaissVecDB = kb_helper.vec_db
|
||||||
|
if not vec_db or not vec_db.document_storage:
|
||||||
|
return {"documents": []}
|
||||||
|
|
||||||
|
# 获取所有文档
|
||||||
|
docs = await vec_db.document_storage.get_documents(
|
||||||
|
metadata_filters={},
|
||||||
|
offset=0,
|
||||||
|
limit=None, # 获取全部
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"documents": docs}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"导出知识库文档失败: {e}")
|
||||||
|
return {"documents": []}
|
||||||
|
|
||||||
|
async def _export_faiss_index(
|
||||||
|
self,
|
||||||
|
zf: zipfile.ZipFile,
|
||||||
|
kb_helper: Any,
|
||||||
|
kb_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""导出 FAISS 索引文件"""
|
||||||
|
try:
|
||||||
|
index_path = kb_helper.kb_dir / "index.faiss"
|
||||||
|
if index_path.exists():
|
||||||
|
archive_path = f"databases/kb_{kb_id}/index.faiss"
|
||||||
|
zf.write(str(index_path), archive_path)
|
||||||
|
logger.debug(f"导出 FAISS 索引: {archive_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"导出 FAISS 索引失败: {e}")
|
||||||
|
|
||||||
|
async def _export_kb_media_files(
|
||||||
|
self, zf: zipfile.ZipFile, kb_helper: Any, kb_id: str
|
||||||
|
) -> None:
|
||||||
|
"""导出知识库的多媒体文件"""
|
||||||
|
try:
|
||||||
|
media_dir = kb_helper.kb_medias_dir
|
||||||
|
if not media_dir.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
for root, _, files in os.walk(media_dir):
|
||||||
|
for file in files:
|
||||||
|
file_path = Path(root) / file
|
||||||
|
# 计算相对路径
|
||||||
|
rel_path = file_path.relative_to(kb_helper.kb_dir)
|
||||||
|
archive_path = f"files/kb_media/{kb_id}/{rel_path}"
|
||||||
|
zf.write(str(file_path), archive_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"导出知识库媒体文件失败: {e}")
|
||||||
|
|
||||||
|
async def _export_directories(
|
||||||
|
self, zf: zipfile.ZipFile
|
||||||
|
) -> dict[str, dict[str, int]]:
|
||||||
|
"""导出插件和其他数据目录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 每个目录的统计信息 {dir_name: {"files": count, "size": bytes}}
|
||||||
|
"""
|
||||||
|
stats: dict[str, dict[str, int]] = {}
|
||||||
|
backup_directories = get_backup_directories()
|
||||||
|
|
||||||
|
for dir_name, dir_path in backup_directories.items():
|
||||||
|
full_path = Path(dir_path)
|
||||||
|
if not full_path.exists():
|
||||||
|
logger.debug(f"目录不存在,跳过: {full_path}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
file_count = 0
|
||||||
|
total_size = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
for root, dirs, files in os.walk(full_path):
|
||||||
|
# 跳过 __pycache__ 目录
|
||||||
|
dirs[:] = [d for d in dirs if d != "__pycache__"]
|
||||||
|
|
||||||
|
for file in files:
|
||||||
|
# 跳过 .pyc 文件
|
||||||
|
if file.endswith(".pyc"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
file_path = Path(root) / file
|
||||||
|
try:
|
||||||
|
# 计算相对路径
|
||||||
|
rel_path = file_path.relative_to(full_path)
|
||||||
|
archive_path = f"directories/{dir_name}/{rel_path}"
|
||||||
|
zf.write(str(file_path), archive_path)
|
||||||
|
file_count += 1
|
||||||
|
total_size += file_path.stat().st_size
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"导出文件 {file_path} 失败: {e}")
|
||||||
|
|
||||||
|
stats[dir_name] = {"files": file_count, "size": total_size}
|
||||||
|
logger.debug(
|
||||||
|
f"导出目录 {dir_name}: {file_count} 个文件, {total_size} 字节"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"导出目录 {dir_path} 失败: {e}")
|
||||||
|
stats[dir_name] = {"files": 0, "size": 0}
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
async def _export_attachments(
|
||||||
|
self, zf: zipfile.ZipFile, attachments: list[dict]
|
||||||
|
) -> None:
|
||||||
|
"""导出附件文件"""
|
||||||
|
for attachment in attachments:
|
||||||
|
try:
|
||||||
|
file_path = attachment.get("path", "")
|
||||||
|
if file_path and os.path.exists(file_path):
|
||||||
|
# 使用 attachment_id 作为文件名
|
||||||
|
attachment_id = attachment.get("attachment_id", "")
|
||||||
|
ext = os.path.splitext(file_path)[1]
|
||||||
|
archive_path = f"files/attachments/{attachment_id}{ext}"
|
||||||
|
zf.write(file_path, archive_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"导出附件失败: {e}")
|
||||||
|
|
||||||
|
def _model_to_dict(self, record: Any) -> dict:
|
||||||
|
"""将 SQLModel 实例转换为字典
|
||||||
|
|
||||||
|
这是数据库无关的序列化方式,支持未来迁移到其他数据库。
|
||||||
|
"""
|
||||||
|
# 使用 SQLModel 内置的 model_dump 方法(如果可用)
|
||||||
|
if hasattr(record, "model_dump"):
|
||||||
|
data = record.model_dump(mode="python")
|
||||||
|
# 处理 datetime 类型
|
||||||
|
for key, value in data.items():
|
||||||
|
if isinstance(value, datetime):
|
||||||
|
data[key] = value.isoformat()
|
||||||
|
return data
|
||||||
|
|
||||||
|
# 回退到手动提取
|
||||||
|
data = {}
|
||||||
|
# 使用 inspect 获取表信息
|
||||||
|
from sqlalchemy import inspect as sa_inspect
|
||||||
|
|
||||||
|
mapper = sa_inspect(record.__class__)
|
||||||
|
for column in mapper.columns:
|
||||||
|
value = getattr(record, column.name)
|
||||||
|
# 处理 datetime 类型 - 统一转为 ISO 格式字符串
|
||||||
|
if isinstance(value, datetime):
|
||||||
|
value = value.isoformat()
|
||||||
|
data[column.name] = value
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _add_checksum(self, path: str, content: str | bytes) -> None:
|
||||||
|
"""计算并添加文件校验和"""
|
||||||
|
if isinstance(content, str):
|
||||||
|
content = content.encode("utf-8")
|
||||||
|
checksum = hashlib.sha256(content).hexdigest()
|
||||||
|
self._checksums[path] = f"sha256:{checksum}"
|
||||||
|
|
||||||
|
def _generate_manifest(
|
||||||
|
self,
|
||||||
|
main_data: dict[str, list[dict]],
|
||||||
|
kb_meta_data: dict[str, list[dict]],
|
||||||
|
dir_stats: dict[str, dict[str, int]] | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""生成备份清单"""
|
||||||
|
if dir_stats is None:
|
||||||
|
dir_stats = {}
|
||||||
|
# 收集知识库 ID
|
||||||
|
kb_document_tables = {}
|
||||||
|
if self.kb_manager:
|
||||||
|
for kb_id in self.kb_manager.kb_insts.keys():
|
||||||
|
kb_document_tables[kb_id] = "documents"
|
||||||
|
|
||||||
|
# 收集附件文件列表
|
||||||
|
attachment_files = []
|
||||||
|
for attachment in main_data.get("attachments", []):
|
||||||
|
attachment_id = attachment.get("attachment_id", "")
|
||||||
|
path = attachment.get("path", "")
|
||||||
|
if attachment_id and path:
|
||||||
|
ext = os.path.splitext(path)[1]
|
||||||
|
attachment_files.append(f"{attachment_id}{ext}")
|
||||||
|
|
||||||
|
# 收集知识库媒体文件
|
||||||
|
kb_media_files: dict[str, list[str]] = {}
|
||||||
|
if self.kb_manager:
|
||||||
|
for kb_id, kb_helper in self.kb_manager.kb_insts.items():
|
||||||
|
media_files: list[str] = []
|
||||||
|
media_dir = kb_helper.kb_medias_dir
|
||||||
|
if media_dir.exists():
|
||||||
|
for root, _, files in os.walk(media_dir):
|
||||||
|
for file in files:
|
||||||
|
media_files.append(file)
|
||||||
|
if media_files:
|
||||||
|
kb_media_files[kb_id] = media_files
|
||||||
|
|
||||||
|
manifest = {
|
||||||
|
"version": BACKUP_MANIFEST_VERSION,
|
||||||
|
"astrbot_version": VERSION,
|
||||||
|
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"schema_version": {
|
||||||
|
"main_db": "v4",
|
||||||
|
"kb_db": "v1",
|
||||||
|
},
|
||||||
|
"tables": {
|
||||||
|
"main_db": list(main_data.keys()),
|
||||||
|
"kb_metadata": list(kb_meta_data.keys()),
|
||||||
|
"kb_documents": kb_document_tables,
|
||||||
|
},
|
||||||
|
"files": {
|
||||||
|
"attachments": attachment_files,
|
||||||
|
"kb_media": kb_media_files,
|
||||||
|
},
|
||||||
|
"directories": list(dir_stats.keys()),
|
||||||
|
"checksums": self._checksums,
|
||||||
|
"statistics": {
|
||||||
|
"main_db": {
|
||||||
|
table: len(records) for table, records in main_data.items()
|
||||||
|
},
|
||||||
|
"kb_metadata": {
|
||||||
|
table: len(records) for table, records in kb_meta_data.items()
|
||||||
|
},
|
||||||
|
"directories": dir_stats,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return manifest
|
||||||
@@ -0,0 +1,761 @@
|
|||||||
|
"""AstrBot 数据导入器
|
||||||
|
|
||||||
|
负责从 ZIP 备份文件恢复所有数据。
|
||||||
|
导入时进行版本校验:
|
||||||
|
- 主版本(前两位)不同时直接拒绝导入
|
||||||
|
- 小版本(第三位)不同时提示警告,用户可选择强制导入
|
||||||
|
- 版本匹配时也需要用户确认
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import zipfile
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from sqlalchemy import delete
|
||||||
|
|
||||||
|
from astrbot.core import logger
|
||||||
|
from astrbot.core.config.default import VERSION
|
||||||
|
from astrbot.core.db import BaseDatabase
|
||||||
|
from astrbot.core.utils.astrbot_path import (
|
||||||
|
get_astrbot_data_path,
|
||||||
|
get_astrbot_knowledge_base_path,
|
||||||
|
)
|
||||||
|
from astrbot.core.utils.version_comparator import VersionComparator
|
||||||
|
|
||||||
|
# 从共享常量模块导入
|
||||||
|
from .constants import (
|
||||||
|
KB_METADATA_MODELS,
|
||||||
|
MAIN_DB_MODELS,
|
||||||
|
get_backup_directories,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||||
|
|
||||||
|
|
||||||
|
def _get_major_version(version_str: str) -> str:
|
||||||
|
"""提取版本的主版本部分(前两位)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version_str: 版本字符串,如 "4.9.1", "4.10.0-beta"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
主版本字符串,如 "4.9", "4.10"
|
||||||
|
"""
|
||||||
|
if not version_str:
|
||||||
|
return "0.0"
|
||||||
|
# 移除 v 前缀和预发布标签
|
||||||
|
version = version_str.lower().replace("v", "").split("-")[0].split("+")[0]
|
||||||
|
parts = [p for p in version.split(".") if p] # 过滤空字符串
|
||||||
|
if len(parts) >= 2:
|
||||||
|
return f"{parts[0]}.{parts[1]}"
|
||||||
|
elif len(parts) == 1 and parts[0]:
|
||||||
|
return f"{parts[0]}.0"
|
||||||
|
return "0.0"
|
||||||
|
|
||||||
|
|
||||||
|
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||||
|
KB_PATH = get_astrbot_knowledge_base_path()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ImportPreCheckResult:
|
||||||
|
"""导入预检查结果
|
||||||
|
|
||||||
|
用于在实际导入前检查备份文件的版本兼容性,
|
||||||
|
并返回确认信息让用户决定是否继续导入。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 检查是否通过(文件有效且版本可导入)
|
||||||
|
valid: bool = False
|
||||||
|
# 是否可以导入(版本兼容)
|
||||||
|
can_import: bool = False
|
||||||
|
# 版本状态: match(完全匹配), minor_diff(小版本差异), major_diff(主版本不同,拒绝)
|
||||||
|
version_status: str = ""
|
||||||
|
# 备份文件中的 AstrBot 版本
|
||||||
|
backup_version: str = ""
|
||||||
|
# 当前运行的 AstrBot 版本
|
||||||
|
current_version: str = VERSION
|
||||||
|
# 备份创建时间
|
||||||
|
backup_time: str = ""
|
||||||
|
# 确认消息(显示给用户)
|
||||||
|
confirm_message: str = ""
|
||||||
|
# 警告消息列表
|
||||||
|
warnings: list[str] = field(default_factory=list)
|
||||||
|
# 错误消息(如果检查失败)
|
||||||
|
error: str = ""
|
||||||
|
# 备份包含的内容摘要
|
||||||
|
backup_summary: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"valid": self.valid,
|
||||||
|
"can_import": self.can_import,
|
||||||
|
"version_status": self.version_status,
|
||||||
|
"backup_version": self.backup_version,
|
||||||
|
"current_version": self.current_version,
|
||||||
|
"backup_time": self.backup_time,
|
||||||
|
"confirm_message": self.confirm_message,
|
||||||
|
"warnings": self.warnings,
|
||||||
|
"error": self.error,
|
||||||
|
"backup_summary": self.backup_summary,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ImportResult:
|
||||||
|
"""导入结果"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.success = True
|
||||||
|
self.imported_tables: dict[str, int] = {}
|
||||||
|
self.imported_files: dict[str, int] = {}
|
||||||
|
self.imported_directories: dict[str, int] = {}
|
||||||
|
self.warnings: list[str] = []
|
||||||
|
self.errors: list[str] = []
|
||||||
|
|
||||||
|
def add_warning(self, msg: str) -> None:
|
||||||
|
self.warnings.append(msg)
|
||||||
|
logger.warning(msg)
|
||||||
|
|
||||||
|
def add_error(self, msg: str) -> None:
|
||||||
|
self.errors.append(msg)
|
||||||
|
self.success = False
|
||||||
|
logger.error(msg)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"success": self.success,
|
||||||
|
"imported_tables": self.imported_tables,
|
||||||
|
"imported_files": self.imported_files,
|
||||||
|
"imported_directories": self.imported_directories,
|
||||||
|
"warnings": self.warnings,
|
||||||
|
"errors": self.errors,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AstrBotImporter:
|
||||||
|
"""AstrBot 数据导入器
|
||||||
|
|
||||||
|
导入备份文件中的所有数据,包括:
|
||||||
|
- 主数据库所有表
|
||||||
|
- 知识库元数据和文档
|
||||||
|
- 配置文件
|
||||||
|
- 附件文件
|
||||||
|
- 知识库多媒体文件
|
||||||
|
- 插件目录(data/plugins)
|
||||||
|
- 插件数据目录(data/plugin_data)
|
||||||
|
- 配置目录(data/config)
|
||||||
|
- T2I 模板目录(data/t2i_templates)
|
||||||
|
- WebChat 数据目录(data/webchat)
|
||||||
|
- 临时文件目录(data/temp)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
main_db: BaseDatabase,
|
||||||
|
kb_manager: "KnowledgeBaseManager | None" = None,
|
||||||
|
config_path: str = CMD_CONFIG_FILE_PATH,
|
||||||
|
kb_root_dir: str = KB_PATH,
|
||||||
|
):
|
||||||
|
self.main_db = main_db
|
||||||
|
self.kb_manager = kb_manager
|
||||||
|
self.config_path = config_path
|
||||||
|
self.kb_root_dir = kb_root_dir
|
||||||
|
|
||||||
|
def pre_check(self, zip_path: str) -> ImportPreCheckResult:
|
||||||
|
"""预检查备份文件
|
||||||
|
|
||||||
|
在实际导入前检查备份文件的有效性和版本兼容性。
|
||||||
|
返回检查结果供前端显示确认对话框。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
zip_path: ZIP 备份文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ImportPreCheckResult: 预检查结果
|
||||||
|
"""
|
||||||
|
result = ImportPreCheckResult()
|
||||||
|
result.current_version = VERSION
|
||||||
|
|
||||||
|
if not os.path.exists(zip_path):
|
||||||
|
result.error = f"备份文件不存在: {zip_path}"
|
||||||
|
return result
|
||||||
|
|
||||||
|
try:
|
||||||
|
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||||
|
# 读取 manifest
|
||||||
|
try:
|
||||||
|
manifest_data = zf.read("manifest.json")
|
||||||
|
manifest = json.loads(manifest_data)
|
||||||
|
except KeyError:
|
||||||
|
result.error = "备份文件缺少 manifest.json,不是有效的 AstrBot 备份"
|
||||||
|
return result
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
result.error = f"manifest.json 格式错误: {e}"
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 提取基本信息
|
||||||
|
result.backup_version = manifest.get("astrbot_version", "未知")
|
||||||
|
result.backup_time = manifest.get("exported_at", "未知")
|
||||||
|
result.valid = True
|
||||||
|
|
||||||
|
# 构建备份摘要
|
||||||
|
result.backup_summary = {
|
||||||
|
"tables": list(manifest.get("tables", {}).keys()),
|
||||||
|
"has_knowledge_bases": manifest.get("has_knowledge_bases", False),
|
||||||
|
"has_config": manifest.get("has_config", False),
|
||||||
|
"directories": manifest.get("directories", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 检查版本兼容性
|
||||||
|
version_check = self._check_version_compatibility(result.backup_version)
|
||||||
|
result.version_status = version_check["status"]
|
||||||
|
result.can_import = version_check["can_import"]
|
||||||
|
|
||||||
|
# 版本信息由前端根据 version_status 和 i18n 生成显示
|
||||||
|
# 不再将版本消息添加到 warnings 列表中,避免中文硬编码
|
||||||
|
# warnings 列表保留用于其他非版本相关的警告
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except zipfile.BadZipFile:
|
||||||
|
result.error = "无效的 ZIP 文件"
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
result.error = f"检查备份文件失败: {e}"
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _check_version_compatibility(self, backup_version: str) -> dict:
|
||||||
|
"""检查版本兼容性
|
||||||
|
|
||||||
|
规则:
|
||||||
|
- 主版本(前两位,如 4.9)必须一致,否则拒绝
|
||||||
|
- 小版本(第三位,如 4.9.1 vs 4.9.2)不同时,警告但允许导入
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: {status, can_import, message}
|
||||||
|
"""
|
||||||
|
if not backup_version:
|
||||||
|
return {
|
||||||
|
"status": "major_diff",
|
||||||
|
"can_import": False,
|
||||||
|
"message": "备份文件缺少版本信息",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 提取主版本(前两位)进行比较
|
||||||
|
backup_major = _get_major_version(backup_version)
|
||||||
|
current_major = _get_major_version(VERSION)
|
||||||
|
|
||||||
|
# 比较主版本
|
||||||
|
if VersionComparator.compare_version(backup_major, current_major) != 0:
|
||||||
|
return {
|
||||||
|
"status": "major_diff",
|
||||||
|
"can_import": False,
|
||||||
|
"message": (
|
||||||
|
f"主版本不兼容: 备份版本 {backup_version}, 当前版本 {VERSION}。"
|
||||||
|
f"跨主版本导入可能导致数据损坏,请使用相同主版本的 AstrBot。"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 比较完整版本
|
||||||
|
version_cmp = VersionComparator.compare_version(backup_version, VERSION)
|
||||||
|
if version_cmp != 0:
|
||||||
|
return {
|
||||||
|
"status": "minor_diff",
|
||||||
|
"can_import": True,
|
||||||
|
"message": (
|
||||||
|
f"小版本差异: 备份版本 {backup_version}, 当前版本 {VERSION}。"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "match",
|
||||||
|
"can_import": True,
|
||||||
|
"message": "版本匹配",
|
||||||
|
}
|
||||||
|
|
||||||
|
async def import_all(
|
||||||
|
self,
|
||||||
|
zip_path: str,
|
||||||
|
mode: str = "replace", # "replace" 清空后导入
|
||||||
|
progress_callback: Any | None = None,
|
||||||
|
) -> ImportResult:
|
||||||
|
"""从 ZIP 文件导入所有数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
zip_path: ZIP 备份文件路径
|
||||||
|
mode: 导入模式,目前仅支持 "replace"(清空后导入)
|
||||||
|
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ImportResult: 导入结果
|
||||||
|
"""
|
||||||
|
result = ImportResult()
|
||||||
|
|
||||||
|
if not os.path.exists(zip_path):
|
||||||
|
result.add_error(f"备份文件不存在: {zip_path}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
logger.info(f"开始从 {zip_path} 导入备份")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||||
|
# 1. 读取并验证 manifest
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("validate", 0, 100, "正在验证备份文件...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
manifest_data = zf.read("manifest.json")
|
||||||
|
manifest = json.loads(manifest_data)
|
||||||
|
except KeyError:
|
||||||
|
result.add_error("备份文件缺少 manifest.json")
|
||||||
|
return result
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
result.add_error(f"manifest.json 格式错误: {e}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 版本校验
|
||||||
|
try:
|
||||||
|
self._validate_version(manifest)
|
||||||
|
except ValueError as e:
|
||||||
|
result.add_error(str(e))
|
||||||
|
return result
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("validate", 100, 100, "验证完成")
|
||||||
|
|
||||||
|
# 2. 导入主数据库
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("main_db", 0, 100, "正在导入主数据库...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
main_data_content = zf.read("databases/main_db.json")
|
||||||
|
main_data = json.loads(main_data_content)
|
||||||
|
|
||||||
|
if mode == "replace":
|
||||||
|
await self._clear_main_db()
|
||||||
|
|
||||||
|
imported = await self._import_main_database(main_data)
|
||||||
|
result.imported_tables.update(imported)
|
||||||
|
except Exception as e:
|
||||||
|
result.add_error(f"导入主数据库失败: {e}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("main_db", 100, 100, "主数据库导入完成")
|
||||||
|
|
||||||
|
# 3. 导入知识库
|
||||||
|
if self.kb_manager and "databases/kb_metadata.json" in zf.namelist():
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("kb", 0, 100, "正在导入知识库...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
kb_meta_content = zf.read("databases/kb_metadata.json")
|
||||||
|
kb_meta_data = json.loads(kb_meta_content)
|
||||||
|
|
||||||
|
if mode == "replace":
|
||||||
|
await self._clear_kb_data()
|
||||||
|
|
||||||
|
await self._import_knowledge_bases(zf, kb_meta_data, result)
|
||||||
|
except Exception as e:
|
||||||
|
result.add_warning(f"导入知识库失败: {e}")
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("kb", 100, 100, "知识库导入完成")
|
||||||
|
|
||||||
|
# 4. 导入配置文件
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("config", 0, 100, "正在导入配置文件...")
|
||||||
|
|
||||||
|
if "config/cmd_config.json" in zf.namelist():
|
||||||
|
try:
|
||||||
|
config_content = zf.read("config/cmd_config.json")
|
||||||
|
# 备份现有配置
|
||||||
|
if os.path.exists(self.config_path):
|
||||||
|
backup_path = f"{self.config_path}.bak"
|
||||||
|
shutil.copy2(self.config_path, backup_path)
|
||||||
|
|
||||||
|
with open(self.config_path, "wb") as f:
|
||||||
|
f.write(config_content)
|
||||||
|
result.imported_files["config"] = 1
|
||||||
|
except Exception as e:
|
||||||
|
result.add_warning(f"导入配置文件失败: {e}")
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("config", 100, 100, "配置文件导入完成")
|
||||||
|
|
||||||
|
# 5. 导入附件文件
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("attachments", 0, 100, "正在导入附件...")
|
||||||
|
|
||||||
|
attachment_count = await self._import_attachments(
|
||||||
|
zf, main_data.get("attachments", [])
|
||||||
|
)
|
||||||
|
result.imported_files["attachments"] = attachment_count
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("attachments", 100, 100, "附件导入完成")
|
||||||
|
|
||||||
|
# 6. 导入插件和其他目录
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback(
|
||||||
|
"directories", 0, 100, "正在导入插件和数据目录..."
|
||||||
|
)
|
||||||
|
|
||||||
|
dir_stats = await self._import_directories(zf, manifest, result)
|
||||||
|
result.imported_directories = dir_stats
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("directories", 100, 100, "目录导入完成")
|
||||||
|
|
||||||
|
logger.info(f"备份导入完成: {result.to_dict()}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
except zipfile.BadZipFile:
|
||||||
|
result.add_error("无效的 ZIP 文件")
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
result.add_error(f"导入失败: {e}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _validate_version(self, manifest: dict) -> None:
|
||||||
|
"""验证版本兼容性 - 仅允许相同主版本导入
|
||||||
|
|
||||||
|
注意:此方法仅在 import_all 中调用,用于双重校验。
|
||||||
|
前端应先调用 pre_check 获取详细的版本信息并让用户确认。
|
||||||
|
"""
|
||||||
|
backup_version = manifest.get("astrbot_version")
|
||||||
|
if not backup_version:
|
||||||
|
raise ValueError("备份文件缺少版本信息")
|
||||||
|
|
||||||
|
# 使用新的版本兼容性检查
|
||||||
|
version_check = self._check_version_compatibility(backup_version)
|
||||||
|
|
||||||
|
if version_check["status"] == "major_diff":
|
||||||
|
raise ValueError(version_check["message"])
|
||||||
|
|
||||||
|
# minor_diff 和 match 都允许导入
|
||||||
|
if version_check["status"] == "minor_diff":
|
||||||
|
logger.warning(f"版本差异警告: {version_check['message']}")
|
||||||
|
|
||||||
|
async def _clear_main_db(self) -> None:
|
||||||
|
"""清空主数据库所有表"""
|
||||||
|
async with self.main_db.get_db() as session:
|
||||||
|
async with session.begin():
|
||||||
|
for table_name, model_class in MAIN_DB_MODELS.items():
|
||||||
|
try:
|
||||||
|
await session.execute(delete(model_class))
|
||||||
|
logger.debug(f"已清空表 {table_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"清空表 {table_name} 失败: {e}")
|
||||||
|
|
||||||
|
async def _clear_kb_data(self) -> None:
|
||||||
|
"""清空知识库数据"""
|
||||||
|
if not self.kb_manager:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 清空知识库元数据表
|
||||||
|
async with self.kb_manager.kb_db.get_db() as session:
|
||||||
|
async with session.begin():
|
||||||
|
for table_name, model_class in KB_METADATA_MODELS.items():
|
||||||
|
try:
|
||||||
|
await session.execute(delete(model_class))
|
||||||
|
logger.debug(f"已清空知识库表 {table_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"清空知识库表 {table_name} 失败: {e}")
|
||||||
|
|
||||||
|
# 删除知识库文件目录
|
||||||
|
for kb_id in list(self.kb_manager.kb_insts.keys()):
|
||||||
|
try:
|
||||||
|
kb_helper = self.kb_manager.kb_insts[kb_id]
|
||||||
|
await kb_helper.terminate()
|
||||||
|
if kb_helper.kb_dir.exists():
|
||||||
|
shutil.rmtree(kb_helper.kb_dir)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"清理知识库 {kb_id} 失败: {e}")
|
||||||
|
|
||||||
|
self.kb_manager.kb_insts.clear()
|
||||||
|
|
||||||
|
async def _import_main_database(
|
||||||
|
self, data: dict[str, list[dict]]
|
||||||
|
) -> dict[str, int]:
|
||||||
|
"""导入主数据库数据"""
|
||||||
|
imported: dict[str, int] = {}
|
||||||
|
|
||||||
|
async with self.main_db.get_db() as session:
|
||||||
|
async with session.begin():
|
||||||
|
for table_name, rows in data.items():
|
||||||
|
model_class = MAIN_DB_MODELS.get(table_name)
|
||||||
|
if not model_class:
|
||||||
|
logger.warning(f"未知的表: {table_name}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for row in rows:
|
||||||
|
try:
|
||||||
|
# 转换 datetime 字符串为 datetime 对象
|
||||||
|
row = self._convert_datetime_fields(row, model_class)
|
||||||
|
obj = model_class(**row)
|
||||||
|
session.add(obj)
|
||||||
|
count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"导入记录到 {table_name} 失败: {e}")
|
||||||
|
|
||||||
|
imported[table_name] = count
|
||||||
|
logger.debug(f"导入表 {table_name}: {count} 条记录")
|
||||||
|
|
||||||
|
return imported
|
||||||
|
|
||||||
|
async def _import_knowledge_bases(
|
||||||
|
self,
|
||||||
|
zf: zipfile.ZipFile,
|
||||||
|
kb_meta_data: dict[str, list[dict]],
|
||||||
|
result: ImportResult,
|
||||||
|
) -> None:
|
||||||
|
"""导入知识库数据"""
|
||||||
|
if not self.kb_manager:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 1. 导入知识库元数据
|
||||||
|
async with self.kb_manager.kb_db.get_db() as session:
|
||||||
|
async with session.begin():
|
||||||
|
for table_name, rows in kb_meta_data.items():
|
||||||
|
model_class = KB_METADATA_MODELS.get(table_name)
|
||||||
|
if not model_class:
|
||||||
|
continue
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for row in rows:
|
||||||
|
try:
|
||||||
|
row = self._convert_datetime_fields(row, model_class)
|
||||||
|
obj = model_class(**row)
|
||||||
|
session.add(obj)
|
||||||
|
count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"导入知识库记录到 {table_name} 失败: {e}")
|
||||||
|
|
||||||
|
result.imported_tables[f"kb_{table_name}"] = count
|
||||||
|
|
||||||
|
# 2. 导入每个知识库的文档和文件
|
||||||
|
for kb_data in kb_meta_data.get("knowledge_bases", []):
|
||||||
|
kb_id = kb_data.get("kb_id")
|
||||||
|
if not kb_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 创建知识库目录
|
||||||
|
kb_dir = Path(self.kb_root_dir) / kb_id
|
||||||
|
kb_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 导入文档数据
|
||||||
|
doc_path = f"databases/kb_{kb_id}/documents.json"
|
||||||
|
if doc_path in zf.namelist():
|
||||||
|
try:
|
||||||
|
doc_content = zf.read(doc_path)
|
||||||
|
doc_data = json.loads(doc_content)
|
||||||
|
|
||||||
|
# 导入到文档存储数据库
|
||||||
|
await self._import_kb_documents(kb_id, doc_data)
|
||||||
|
except Exception as e:
|
||||||
|
result.add_warning(f"导入知识库 {kb_id} 的文档失败: {e}")
|
||||||
|
|
||||||
|
# 导入 FAISS 索引
|
||||||
|
faiss_path = f"databases/kb_{kb_id}/index.faiss"
|
||||||
|
if faiss_path in zf.namelist():
|
||||||
|
try:
|
||||||
|
target_path = kb_dir / "index.faiss"
|
||||||
|
with zf.open(faiss_path) as src, open(target_path, "wb") as dst:
|
||||||
|
dst.write(src.read())
|
||||||
|
except Exception as e:
|
||||||
|
result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}")
|
||||||
|
|
||||||
|
# 导入媒体文件
|
||||||
|
media_prefix = f"files/kb_media/{kb_id}/"
|
||||||
|
for name in zf.namelist():
|
||||||
|
if name.startswith(media_prefix):
|
||||||
|
try:
|
||||||
|
rel_path = name[len(media_prefix) :]
|
||||||
|
target_path = kb_dir / rel_path
|
||||||
|
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||||
|
dst.write(src.read())
|
||||||
|
except Exception as e:
|
||||||
|
result.add_warning(f"导入媒体文件 {name} 失败: {e}")
|
||||||
|
|
||||||
|
# 3. 重新加载知识库实例
|
||||||
|
await self.kb_manager.load_kbs()
|
||||||
|
|
||||||
|
async def _import_kb_documents(self, kb_id: str, doc_data: dict) -> None:
|
||||||
|
"""导入知识库文档到向量数据库"""
|
||||||
|
from astrbot.core.db.vec_db.faiss_impl.document_storage import DocumentStorage
|
||||||
|
|
||||||
|
kb_dir = Path(self.kb_root_dir) / kb_id
|
||||||
|
doc_db_path = kb_dir / "doc.db"
|
||||||
|
|
||||||
|
# 初始化文档存储
|
||||||
|
doc_storage = DocumentStorage(str(doc_db_path))
|
||||||
|
await doc_storage.initialize()
|
||||||
|
|
||||||
|
try:
|
||||||
|
documents = doc_data.get("documents", [])
|
||||||
|
for doc in documents:
|
||||||
|
try:
|
||||||
|
await doc_storage.insert_document(
|
||||||
|
doc_id=doc.get("doc_id", ""),
|
||||||
|
text=doc.get("text", ""),
|
||||||
|
metadata=json.loads(doc.get("metadata", "{}")),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"导入文档块失败: {e}")
|
||||||
|
finally:
|
||||||
|
await doc_storage.close()
|
||||||
|
|
||||||
|
async def _import_attachments(
|
||||||
|
self,
|
||||||
|
zf: zipfile.ZipFile,
|
||||||
|
attachments: list[dict],
|
||||||
|
) -> int:
|
||||||
|
"""导入附件文件"""
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
attachments_dir = Path(self.config_path).parent / "attachments"
|
||||||
|
attachments_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
attachment_prefix = "files/attachments/"
|
||||||
|
for name in zf.namelist():
|
||||||
|
if name.startswith(attachment_prefix) and name != attachment_prefix:
|
||||||
|
try:
|
||||||
|
# 从附件记录中找到原始路径
|
||||||
|
attachment_id = os.path.splitext(os.path.basename(name))[0]
|
||||||
|
original_path = None
|
||||||
|
for att in attachments:
|
||||||
|
if att.get("attachment_id") == attachment_id:
|
||||||
|
original_path = att.get("path")
|
||||||
|
break
|
||||||
|
|
||||||
|
if original_path:
|
||||||
|
target_path = Path(original_path)
|
||||||
|
else:
|
||||||
|
target_path = attachments_dir / os.path.basename(name)
|
||||||
|
|
||||||
|
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||||
|
dst.write(src.read())
|
||||||
|
count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"导入附件 {name} 失败: {e}")
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def _import_directories(
|
||||||
|
self,
|
||||||
|
zf: zipfile.ZipFile,
|
||||||
|
manifest: dict,
|
||||||
|
result: ImportResult,
|
||||||
|
) -> dict[str, int]:
|
||||||
|
"""导入插件和其他数据目录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
zf: ZIP 文件对象
|
||||||
|
manifest: 备份清单
|
||||||
|
result: 导入结果对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 每个目录导入的文件数量
|
||||||
|
"""
|
||||||
|
dir_stats: dict[str, int] = {}
|
||||||
|
|
||||||
|
# 检查备份版本是否支持目录备份(需要版本 >= 1.1)
|
||||||
|
backup_version = manifest.get("version", "1.0")
|
||||||
|
if VersionComparator.compare_version(backup_version, "1.1") < 0:
|
||||||
|
logger.info("备份版本不支持目录备份,跳过目录导入")
|
||||||
|
return dir_stats
|
||||||
|
|
||||||
|
backed_up_dirs = manifest.get("directories", [])
|
||||||
|
backup_directories = get_backup_directories()
|
||||||
|
|
||||||
|
for dir_name in backed_up_dirs:
|
||||||
|
if dir_name not in backup_directories:
|
||||||
|
result.add_warning(f"未知的目录类型: {dir_name}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
target_dir = Path(backup_directories[dir_name])
|
||||||
|
archive_prefix = f"directories/{dir_name}/"
|
||||||
|
|
||||||
|
file_count = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取该目录下的所有文件
|
||||||
|
dir_files = [
|
||||||
|
name
|
||||||
|
for name in zf.namelist()
|
||||||
|
if name.startswith(archive_prefix) and name != archive_prefix
|
||||||
|
]
|
||||||
|
|
||||||
|
if not dir_files:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 备份现有目录(如果存在)
|
||||||
|
if target_dir.exists():
|
||||||
|
backup_path = Path(f"{target_dir}.bak")
|
||||||
|
if backup_path.exists():
|
||||||
|
shutil.rmtree(backup_path)
|
||||||
|
shutil.move(str(target_dir), str(backup_path))
|
||||||
|
logger.debug(f"已备份现有目录 {target_dir} 到 {backup_path}")
|
||||||
|
|
||||||
|
# 创建目标目录
|
||||||
|
target_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 解压文件
|
||||||
|
for name in dir_files:
|
||||||
|
try:
|
||||||
|
# 计算相对路径
|
||||||
|
rel_path = name[len(archive_prefix) :]
|
||||||
|
if not rel_path: # 跳过目录条目
|
||||||
|
continue
|
||||||
|
|
||||||
|
target_path = target_dir / rel_path
|
||||||
|
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||||
|
dst.write(src.read())
|
||||||
|
file_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
result.add_warning(f"导入文件 {name} 失败: {e}")
|
||||||
|
|
||||||
|
dir_stats[dir_name] = file_count
|
||||||
|
logger.debug(f"导入目录 {dir_name}: {file_count} 个文件")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result.add_warning(f"导入目录 {dir_name} 失败: {e}")
|
||||||
|
dir_stats[dir_name] = 0
|
||||||
|
|
||||||
|
return dir_stats
|
||||||
|
|
||||||
|
def _convert_datetime_fields(self, row: dict, model_class: type) -> dict:
|
||||||
|
"""转换 datetime 字符串字段为 datetime 对象"""
|
||||||
|
result = row.copy()
|
||||||
|
|
||||||
|
# 获取模型的 datetime 字段
|
||||||
|
from sqlalchemy import inspect as sa_inspect
|
||||||
|
|
||||||
|
try:
|
||||||
|
mapper = sa_inspect(model_class)
|
||||||
|
for column in mapper.columns:
|
||||||
|
if column.name in result and result[column.name] is not None:
|
||||||
|
# 检查是否是 datetime 类型的列
|
||||||
|
from sqlalchemy import DateTime
|
||||||
|
|
||||||
|
if isinstance(column.type, DateTime):
|
||||||
|
value = result[column.name]
|
||||||
|
if isinstance(value, str):
|
||||||
|
# 解析 ISO 格式的日期时间字符串
|
||||||
|
result[column.name] = datetime.fromisoformat(value)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return result
|
||||||
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
|
|||||||
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
VERSION = "4.10.2"
|
VERSION = "4.10.3"
|
||||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||||
|
|
||||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||||
|
|||||||
+1
-1
@@ -58,7 +58,7 @@ def is_plugin_path(pathname):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
norm_path = os.path.normpath(pathname)
|
norm_path = os.path.normpath(pathname)
|
||||||
return ("data/plugins" in norm_path) or ("packages/" in norm_path)
|
return ("data/plugins" in norm_path) or ("astrbot/builtin_stars/" in norm_path)
|
||||||
|
|
||||||
|
|
||||||
def get_short_level_name(level_name):
|
def get_short_level_name(level_name):
|
||||||
|
|||||||
@@ -390,7 +390,7 @@ class InternalAgentSubStage(Stage):
|
|||||||
return
|
return
|
||||||
|
|
||||||
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
||||||
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
# func_tool selection 现在已经转移到 astrbot/builtin_stars/astrbot 插件中进行选择。
|
||||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||||
for comp in event.message_obj.message:
|
for comp in event.message_obj.message:
|
||||||
if isinstance(comp, Image):
|
if isinstance(comp, Image):
|
||||||
|
|||||||
@@ -136,7 +136,8 @@ class WakingCheckStage(Stage):
|
|||||||
):
|
):
|
||||||
if (
|
if (
|
||||||
self.disable_builtin_commands
|
self.disable_builtin_commands
|
||||||
and handler.handler_module_path == "packages.builtin_commands.main"
|
and handler.handler_module_path
|
||||||
|
== "astrbot.builtin_stars.builtin_commands.main"
|
||||||
):
|
):
|
||||||
logger.debug("skipping builtin command")
|
logger.debug("skipping builtin command")
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import astrbot.core.message.components as Comp
|
|||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.core.agent.message import (
|
from astrbot.core.agent.message import (
|
||||||
AssistantMessageSegment,
|
AssistantMessageSegment,
|
||||||
|
ContentPart,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
ToolCallMessageSegment,
|
ToolCallMessageSegment,
|
||||||
)
|
)
|
||||||
@@ -92,6 +93,8 @@ class ProviderRequest:
|
|||||||
"""会话 ID"""
|
"""会话 ID"""
|
||||||
image_urls: list[str] = field(default_factory=list)
|
image_urls: list[str] = field(default_factory=list)
|
||||||
"""图片 URL 列表"""
|
"""图片 URL 列表"""
|
||||||
|
extra_user_content_parts: list[ContentPart] = field(default_factory=list)
|
||||||
|
"""额外的用户消息内容部分列表,用于在用户消息后添加额外的内容块(如系统提醒、指令等)。支持 dict 或 ContentPart 对象"""
|
||||||
func_tool: ToolSet | None = None
|
func_tool: ToolSet | None = None
|
||||||
"""可用的函数工具"""
|
"""可用的函数工具"""
|
||||||
contexts: list[dict] = field(default_factory=list)
|
contexts: list[dict] = field(default_factory=list)
|
||||||
@@ -166,13 +169,23 @@ class ProviderRequest:
|
|||||||
|
|
||||||
async def assemble_context(self) -> dict:
|
async def assemble_context(self) -> dict:
|
||||||
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
|
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
|
||||||
|
# 构建内容块列表
|
||||||
|
content_blocks = []
|
||||||
|
|
||||||
|
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||||
|
if self.prompt and self.prompt.strip():
|
||||||
|
content_blocks.append({"type": "text", "text": self.prompt})
|
||||||
|
elif self.image_urls:
|
||||||
|
# 如果没有文本但有图片,添加占位文本
|
||||||
|
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||||
|
|
||||||
|
# 2. 额外的内容块(系统提醒、指令等)
|
||||||
|
if self.extra_user_content_parts:
|
||||||
|
for part in self.extra_user_content_parts:
|
||||||
|
content_blocks.append(part.model_dump())
|
||||||
|
|
||||||
|
# 3. 图片内容
|
||||||
if self.image_urls:
|
if self.image_urls:
|
||||||
user_content = {
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": self.prompt if self.prompt else "[图片]"},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
for image_url in self.image_urls:
|
for image_url in self.image_urls:
|
||||||
if image_url.startswith("http"):
|
if image_url.startswith("http"):
|
||||||
image_path = await download_image_by_url(image_url)
|
image_path = await download_image_by_url(image_url)
|
||||||
@@ -185,11 +198,21 @@ class ProviderRequest:
|
|||||||
if not image_data:
|
if not image_data:
|
||||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||||
continue
|
continue
|
||||||
user_content["content"].append(
|
content_blocks.append(
|
||||||
{"type": "image_url", "image_url": {"url": image_data}},
|
{"type": "image_url", "image_url": {"url": image_data}},
|
||||||
)
|
)
|
||||||
return user_content
|
|
||||||
return {"role": "user", "content": self.prompt}
|
# 只有当只有一个来自 prompt 的文本块且没有额外内容块时,才降级为简单格式以保持向后兼容
|
||||||
|
if (
|
||||||
|
len(content_blocks) == 1
|
||||||
|
and content_blocks[0]["type"] == "text"
|
||||||
|
and not self.extra_user_content_parts
|
||||||
|
and not self.image_urls
|
||||||
|
):
|
||||||
|
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||||
|
|
||||||
|
# 否则返回多模态格式
|
||||||
|
return {"role": "user", "content": content_blocks}
|
||||||
|
|
||||||
async def _encode_image_bs64(self, image_url: str) -> str:
|
async def _encode_image_bs64(self, image_url: str) -> str:
|
||||||
"""将图片转换为 base64"""
|
"""将图片转换为 base64"""
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import os
|
|||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import TypeAlias, Union
|
from typing import TypeAlias, Union
|
||||||
|
|
||||||
from astrbot.core.agent.message import Message
|
from astrbot.core.agent.message import ContentPart, Message
|
||||||
from astrbot.core.agent.tool import ToolSet
|
from astrbot.core.agent.tool import ToolSet
|
||||||
from astrbot.core.provider.entities import (
|
from astrbot.core.provider.entities import (
|
||||||
LLMResponse,
|
LLMResponse,
|
||||||
@@ -103,6 +103,7 @@ class Provider(AbstractProvider):
|
|||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
|
extra_user_content_parts: list[ContentPart] | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
||||||
@@ -114,6 +115,7 @@ class Provider(AbstractProvider):
|
|||||||
tools: tool set
|
tools: tool set
|
||||||
contexts: 上下文,和 prompt 二选一使用
|
contexts: 上下文,和 prompt 二选一使用
|
||||||
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
||||||
|
extra_user_content_parts: 额外的内容块列表,用于在用户消息后添加额外的文本块(如系统提醒、指令等)
|
||||||
kwargs: 其他参数
|
kwargs: 其他参数
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from anthropic.types.usage import Usage
|
|||||||
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.api.provider import Provider
|
from astrbot.api.provider import Provider
|
||||||
|
from astrbot.core.agent.message import ContentPart, ImageURLPart, TextPart
|
||||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
@@ -68,7 +69,7 @@ class ProviderAnthropic(Provider):
|
|||||||
blocks = []
|
blocks = []
|
||||||
if isinstance(message["content"], str):
|
if isinstance(message["content"], str):
|
||||||
blocks.append({"type": "text", "text": message["content"]})
|
blocks.append({"type": "text", "text": message["content"]})
|
||||||
if "tool_calls" in message:
|
if "tool_calls" in message and isinstance(message["tool_calls"], list):
|
||||||
for tool_call in message["tool_calls"]:
|
for tool_call in message["tool_calls"]:
|
||||||
blocks.append( # noqa: PERF401
|
blocks.append( # noqa: PERF401
|
||||||
{
|
{
|
||||||
@@ -132,6 +133,9 @@ class ProviderAnthropic(Provider):
|
|||||||
|
|
||||||
extra_body = self.provider_config.get("custom_extra_body", {})
|
extra_body = self.provider_config.get("custom_extra_body", {})
|
||||||
|
|
||||||
|
if "max_tokens" not in payloads:
|
||||||
|
payloads["max_tokens"] = 1024
|
||||||
|
|
||||||
completion = await self.client.messages.create(
|
completion = await self.client.messages.create(
|
||||||
**payloads, stream=False, extra_body=extra_body
|
**payloads, stream=False, extra_body=extra_body
|
||||||
)
|
)
|
||||||
@@ -181,6 +185,9 @@ class ProviderAnthropic(Provider):
|
|||||||
usage = TokenUsage()
|
usage = TokenUsage()
|
||||||
extra_body = self.provider_config.get("custom_extra_body", {})
|
extra_body = self.provider_config.get("custom_extra_body", {})
|
||||||
|
|
||||||
|
if "max_tokens" not in payloads:
|
||||||
|
payloads["max_tokens"] = 1024
|
||||||
|
|
||||||
async with self.client.messages.stream(
|
async with self.client.messages.stream(
|
||||||
**payloads, extra_body=extra_body
|
**payloads, extra_body=extra_body
|
||||||
) as stream:
|
) as stream:
|
||||||
@@ -296,13 +303,16 @@ class ProviderAnthropic(Provider):
|
|||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
tool_calls_result=None,
|
tool_calls_result=None,
|
||||||
model=None,
|
model=None,
|
||||||
|
extra_user_content_parts=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
if contexts is None:
|
if contexts is None:
|
||||||
contexts = []
|
contexts = []
|
||||||
new_record = None
|
new_record = None
|
||||||
if prompt is not None:
|
if prompt is not None:
|
||||||
new_record = await self.assemble_context(prompt, image_urls)
|
new_record = await self.assemble_context(
|
||||||
|
prompt, image_urls, extra_user_content_parts
|
||||||
|
)
|
||||||
context_query = self._ensure_message_to_dicts(contexts)
|
context_query = self._ensure_message_to_dicts(contexts)
|
||||||
if new_record:
|
if new_record:
|
||||||
context_query.append(new_record)
|
context_query.append(new_record)
|
||||||
@@ -342,21 +352,24 @@ class ProviderAnthropic(Provider):
|
|||||||
|
|
||||||
async def text_chat_stream(
|
async def text_chat_stream(
|
||||||
self,
|
self,
|
||||||
prompt,
|
prompt=None,
|
||||||
session_id=None,
|
session_id=None,
|
||||||
image_urls=...,
|
image_urls=None,
|
||||||
func_tool=None,
|
func_tool=None,
|
||||||
contexts=...,
|
contexts=None,
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
tool_calls_result=None,
|
tool_calls_result=None,
|
||||||
model=None,
|
model=None,
|
||||||
|
extra_user_content_parts=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if contexts is None:
|
if contexts is None:
|
||||||
contexts = []
|
contexts = []
|
||||||
new_record = None
|
new_record = None
|
||||||
if prompt is not None:
|
if prompt is not None:
|
||||||
new_record = await self.assemble_context(prompt, image_urls)
|
new_record = await self.assemble_context(
|
||||||
|
prompt, image_urls, extra_user_content_parts
|
||||||
|
)
|
||||||
context_query = self._ensure_message_to_dicts(contexts)
|
context_query = self._ensure_message_to_dicts(contexts)
|
||||||
if new_record:
|
if new_record:
|
||||||
context_query.append(new_record)
|
context_query.append(new_record)
|
||||||
@@ -388,15 +401,15 @@ class ProviderAnthropic(Provider):
|
|||||||
async for llm_response in self._query_stream(payloads, func_tool):
|
async for llm_response in self._query_stream(payloads, func_tool):
|
||||||
yield llm_response
|
yield llm_response
|
||||||
|
|
||||||
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
|
async def assemble_context(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
image_urls: list[str] | None = None,
|
||||||
|
extra_user_content_parts: list[ContentPart] | None = None,
|
||||||
|
):
|
||||||
"""组装上下文,支持文本和图片"""
|
"""组装上下文,支持文本和图片"""
|
||||||
if not image_urls:
|
|
||||||
return {"role": "user", "content": text}
|
|
||||||
|
|
||||||
content = []
|
async def resolve_image_url(image_url: str) -> dict | None:
|
||||||
content.append({"type": "text", "text": text})
|
|
||||||
|
|
||||||
for image_url in image_urls:
|
|
||||||
if image_url.startswith("http"):
|
if image_url.startswith("http"):
|
||||||
image_path = await download_image_by_url(image_url)
|
image_path = await download_image_by_url(image_url)
|
||||||
image_data = await self.encode_image_bs64(image_path)
|
image_data = await self.encode_image_bs64(image_path)
|
||||||
@@ -408,28 +421,68 @@ class ProviderAnthropic(Provider):
|
|||||||
|
|
||||||
if not image_data:
|
if not image_data:
|
||||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||||
continue
|
return None
|
||||||
|
|
||||||
# Get mime type for the image
|
# Get mime type for the image
|
||||||
mime_type, _ = guess_type(image_url)
|
mime_type, _ = guess_type(image_url)
|
||||||
if not mime_type:
|
if not mime_type:
|
||||||
mime_type = "image/jpeg" # Default to JPEG if can't determine
|
mime_type = "image/jpeg" # Default to JPEG if can't determine
|
||||||
|
|
||||||
content.append(
|
return {
|
||||||
{
|
"type": "image",
|
||||||
"type": "image",
|
"source": {
|
||||||
"source": {
|
"type": "base64",
|
||||||
"type": "base64",
|
"media_type": mime_type,
|
||||||
"media_type": mime_type,
|
"data": (
|
||||||
"data": (
|
image_data.split("base64,")[1]
|
||||||
image_data.split("base64,")[1]
|
if "base64," in image_data
|
||||||
if "base64," in image_data
|
else image_data
|
||||||
else image_data
|
),
|
||||||
),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
)
|
}
|
||||||
|
|
||||||
|
content = []
|
||||||
|
|
||||||
|
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||||
|
if text:
|
||||||
|
content.append({"type": "text", "text": text})
|
||||||
|
elif image_urls:
|
||||||
|
# 如果没有文本但有图片,添加占位文本
|
||||||
|
content.append({"type": "text", "text": "[图片]"})
|
||||||
|
elif extra_user_content_parts:
|
||||||
|
# 如果只有额外内容块,也需要添加占位文本
|
||||||
|
content.append({"type": "text", "text": " "})
|
||||||
|
|
||||||
|
# 2. 额外的内容块(系统提醒、指令等)
|
||||||
|
if extra_user_content_parts:
|
||||||
|
for block in extra_user_content_parts:
|
||||||
|
if isinstance(block, TextPart):
|
||||||
|
content.append({"type": "text", "text": block.text})
|
||||||
|
elif isinstance(block, ImageURLPart):
|
||||||
|
image_dict = await resolve_image_url(block.image_url.url)
|
||||||
|
if image_dict:
|
||||||
|
content.append(image_dict)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的额外内容块类型: {type(block)}")
|
||||||
|
|
||||||
|
# 3. 图片内容
|
||||||
|
if image_urls:
|
||||||
|
for image_url in image_urls:
|
||||||
|
image_dict = await resolve_image_url(image_url)
|
||||||
|
if image_dict:
|
||||||
|
content.append(image_dict)
|
||||||
|
|
||||||
|
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||||
|
if (
|
||||||
|
text
|
||||||
|
and not extra_user_content_parts
|
||||||
|
and not image_urls
|
||||||
|
and len(content) == 1
|
||||||
|
and content[0]["type"] == "text"
|
||||||
|
):
|
||||||
|
return {"role": "user", "content": content[0]["text"]}
|
||||||
|
|
||||||
|
# 否则返回多模态格式
|
||||||
return {"role": "user", "content": content}
|
return {"role": "user", "content": content}
|
||||||
|
|
||||||
async def encode_image_bs64(self, image_url: str) -> str:
|
async def encode_image_bs64(self, image_url: str) -> str:
|
||||||
|
|||||||
@@ -56,10 +56,14 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
|||||||
"api_base",
|
"api_base",
|
||||||
"https://api.fish-audio.cn/v1",
|
"https://api.fish-audio.cn/v1",
|
||||||
)
|
)
|
||||||
|
try:
|
||||||
|
self.timeout: int = int(provider_config.get("timeout", 20))
|
||||||
|
except ValueError:
|
||||||
|
self.timeout = 20
|
||||||
self.headers = {
|
self.headers = {
|
||||||
"Authorization": f"Bearer {self.chosen_api_key}",
|
"Authorization": f"Bearer {self.chosen_api_key}",
|
||||||
}
|
}
|
||||||
self.set_model(provider_config["model"])
|
self.set_model(provider_config.get("model", None))
|
||||||
|
|
||||||
async def _get_reference_id_by_character(self, character: str) -> str | None:
|
async def _get_reference_id_by_character(self, character: str) -> str | None:
|
||||||
"""获取角色的reference_id
|
"""获取角色的reference_id
|
||||||
@@ -135,17 +139,21 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
|||||||
path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav")
|
path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav")
|
||||||
self.headers["content-type"] = "application/msgpack"
|
self.headers["content-type"] = "application/msgpack"
|
||||||
request = await self._generate_request(text)
|
request = await self._generate_request(text)
|
||||||
async with AsyncClient(base_url=self.api_base).stream(
|
async with AsyncClient(base_url=self.api_base, timeout=self.timeout).stream(
|
||||||
"POST",
|
"POST",
|
||||||
"/tts",
|
"/tts",
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
||||||
) as response:
|
) as response:
|
||||||
if response.headers["content-type"] == "audio/wav":
|
if response.status_code == 200 and response.headers.get(
|
||||||
|
"content-type", ""
|
||||||
|
).startswith("audio/"):
|
||||||
with open(path, "wb") as f:
|
with open(path, "wb") as f:
|
||||||
async for chunk in response.aiter_bytes():
|
async for chunk in response.aiter_bytes():
|
||||||
f.write(chunk)
|
f.write(chunk)
|
||||||
return path
|
return path
|
||||||
body = await response.aread()
|
error_bytes = await response.aread()
|
||||||
text = body.decode("utf-8", errors="replace")
|
error_text = error_bytes.decode("utf-8", errors="replace")[:1024]
|
||||||
raise Exception(f"Fish Audio API请求失败: {text}")
|
raise Exception(
|
||||||
|
f"Fish Audio API请求失败: 状态码 {response.status_code}, 响应内容: {error_text}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from google.genai.errors import APIError
|
|||||||
import astrbot.core.message.components as Comp
|
import astrbot.core.message.components as Comp
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.api.provider import Provider
|
from astrbot.api.provider import Provider
|
||||||
|
from astrbot.core.agent.message import ContentPart, ImageURLPart, TextPart
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||||
@@ -680,13 +681,16 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
tool_calls_result=None,
|
tool_calls_result=None,
|
||||||
model=None,
|
model=None,
|
||||||
|
extra_user_content_parts=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
if contexts is None:
|
if contexts is None:
|
||||||
contexts = []
|
contexts = []
|
||||||
new_record = None
|
new_record = None
|
||||||
if prompt is not None:
|
if prompt is not None:
|
||||||
new_record = await self.assemble_context(prompt, image_urls)
|
new_record = await self.assemble_context(
|
||||||
|
prompt, image_urls, extra_user_content_parts
|
||||||
|
)
|
||||||
context_query = self._ensure_message_to_dicts(contexts)
|
context_query = self._ensure_message_to_dicts(contexts)
|
||||||
if new_record:
|
if new_record:
|
||||||
context_query.append(new_record)
|
context_query.append(new_record)
|
||||||
@@ -732,13 +736,16 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
tool_calls_result=None,
|
tool_calls_result=None,
|
||||||
model=None,
|
model=None,
|
||||||
|
extra_user_content_parts=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AsyncGenerator[LLMResponse, None]:
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
if contexts is None:
|
if contexts is None:
|
||||||
contexts = []
|
contexts = []
|
||||||
new_record = None
|
new_record = None
|
||||||
if prompt is not None:
|
if prompt is not None:
|
||||||
new_record = await self.assemble_context(prompt, image_urls)
|
new_record = await self.assemble_context(
|
||||||
|
prompt, image_urls, extra_user_content_parts
|
||||||
|
)
|
||||||
context_query = self._ensure_message_to_dicts(contexts)
|
context_query = self._ensure_message_to_dicts(contexts)
|
||||||
if new_record:
|
if new_record:
|
||||||
context_query.append(new_record)
|
context_query.append(new_record)
|
||||||
@@ -797,33 +804,75 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
self.chosen_api_key = key
|
self.chosen_api_key = key
|
||||||
self._init_client()
|
self._init_client()
|
||||||
|
|
||||||
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
|
async def assemble_context(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
image_urls: list[str] | None = None,
|
||||||
|
extra_user_content_parts: list[ContentPart] | None = None,
|
||||||
|
):
|
||||||
"""组装上下文。"""
|
"""组装上下文。"""
|
||||||
if image_urls:
|
|
||||||
user_content = {
|
async def resolve_image_part(image_url: str) -> dict | None:
|
||||||
"role": "user",
|
if image_url.startswith("http"):
|
||||||
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
image_path = await download_image_by_url(image_url)
|
||||||
|
image_data = await self.encode_image_bs64(image_path)
|
||||||
|
elif image_url.startswith("file:///"):
|
||||||
|
image_path = image_url.replace("file:///", "")
|
||||||
|
image_data = await self.encode_image_bs64(image_path)
|
||||||
|
else:
|
||||||
|
image_data = await self.encode_image_bs64(image_url)
|
||||||
|
if not image_data:
|
||||||
|
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": image_data},
|
||||||
}
|
}
|
||||||
for image_url in image_urls:
|
|
||||||
if image_url.startswith("http"):
|
# 构建内容块列表
|
||||||
image_path = await download_image_by_url(image_url)
|
content_blocks = []
|
||||||
image_data = await self.encode_image_bs64(image_path)
|
|
||||||
elif image_url.startswith("file:///"):
|
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||||
image_path = image_url.replace("file:///", "")
|
if text:
|
||||||
image_data = await self.encode_image_bs64(image_path)
|
content_blocks.append({"type": "text", "text": text})
|
||||||
|
elif image_urls:
|
||||||
|
# 如果没有文本但有图片,添加占位文本
|
||||||
|
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||||
|
elif extra_user_content_parts:
|
||||||
|
# 如果只有额外内容块,也需要添加占位文本
|
||||||
|
content_blocks.append({"type": "text", "text": " "})
|
||||||
|
|
||||||
|
# 2. 额外的内容块(系统提醒、指令等)
|
||||||
|
if extra_user_content_parts:
|
||||||
|
for part in extra_user_content_parts:
|
||||||
|
if isinstance(part, TextPart):
|
||||||
|
content_blocks.append({"type": "text", "text": part.text})
|
||||||
|
elif isinstance(part, ImageURLPart):
|
||||||
|
image_part = await resolve_image_part(part.image_url.url)
|
||||||
|
if image_part:
|
||||||
|
content_blocks.append(image_part)
|
||||||
else:
|
else:
|
||||||
image_data = await self.encode_image_bs64(image_url)
|
raise ValueError(f"不支持的额外内容块类型: {type(part)}")
|
||||||
if not image_data:
|
|
||||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
# 3. 图片内容
|
||||||
continue
|
if image_urls:
|
||||||
user_content["content"].append(
|
for image_url in image_urls:
|
||||||
{
|
image_part = await resolve_image_part(image_url)
|
||||||
"type": "image_url",
|
if image_part:
|
||||||
"image_url": {"url": image_data},
|
content_blocks.append(image_part)
|
||||||
},
|
|
||||||
)
|
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||||
return user_content
|
if (
|
||||||
return {"role": "user", "content": text}
|
text
|
||||||
|
and not extra_user_content_parts
|
||||||
|
and not image_urls
|
||||||
|
and len(content_blocks) == 1
|
||||||
|
and content_blocks[0]["type"] == "text"
|
||||||
|
):
|
||||||
|
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||||
|
|
||||||
|
# 否则返回多模态格式
|
||||||
|
return {"role": "user", "content": content_blocks}
|
||||||
|
|
||||||
async def encode_image_bs64(self, image_url: str) -> str:
|
async def encode_image_bs64(self, image_url: str) -> str:
|
||||||
"""将图片转换为 base64"""
|
"""将图片转换为 base64"""
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from openai.types.completion_usage import CompletionUsage
|
|||||||
import astrbot.core.message.components as Comp
|
import astrbot.core.message.components as Comp
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.api.provider import Provider
|
from astrbot.api.provider import Provider
|
||||||
from astrbot.core.agent.message import Message
|
from astrbot.core.agent.message import ContentPart, ImageURLPart, Message, TextPart
|
||||||
from astrbot.core.agent.tool import ToolSet
|
from astrbot.core.agent.tool import ToolSet
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
|
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
|
||||||
@@ -348,6 +348,7 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
|
extra_user_content_parts: list[ContentPart] | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
"""准备聊天所需的有效载荷和上下文"""
|
"""准备聊天所需的有效载荷和上下文"""
|
||||||
@@ -355,7 +356,9 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
contexts = []
|
contexts = []
|
||||||
new_record = None
|
new_record = None
|
||||||
if prompt is not None:
|
if prompt is not None:
|
||||||
new_record = await self.assemble_context(prompt, image_urls)
|
new_record = await self.assemble_context(
|
||||||
|
prompt, image_urls, extra_user_content_parts
|
||||||
|
)
|
||||||
context_query = self._ensure_message_to_dicts(contexts)
|
context_query = self._ensure_message_to_dicts(contexts)
|
||||||
if new_record:
|
if new_record:
|
||||||
context_query.append(new_record)
|
context_query.append(new_record)
|
||||||
@@ -476,6 +479,7 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
tool_calls_result=None,
|
tool_calls_result=None,
|
||||||
model=None,
|
model=None,
|
||||||
|
extra_user_content_parts=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
payloads, context_query = await self._prepare_chat_payload(
|
payloads, context_query = await self._prepare_chat_payload(
|
||||||
@@ -485,6 +489,7 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
system_prompt,
|
system_prompt,
|
||||||
tool_calls_result,
|
tool_calls_result,
|
||||||
model=model,
|
model=model,
|
||||||
|
extra_user_content_parts=extra_user_content_parts,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -624,33 +629,71 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
image_urls: list[str] | None = None,
|
image_urls: list[str] | None = None,
|
||||||
|
extra_user_content_parts: list[ContentPart] | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
|
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
|
||||||
if image_urls:
|
|
||||||
user_content = {
|
async def resolve_image_part(image_url: str) -> dict | None:
|
||||||
"role": "user",
|
if image_url.startswith("http"):
|
||||||
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
image_path = await download_image_by_url(image_url)
|
||||||
|
image_data = await self.encode_image_bs64(image_path)
|
||||||
|
elif image_url.startswith("file:///"):
|
||||||
|
image_path = image_url.replace("file:///", "")
|
||||||
|
image_data = await self.encode_image_bs64(image_path)
|
||||||
|
else:
|
||||||
|
image_data = await self.encode_image_bs64(image_url)
|
||||||
|
if not image_data:
|
||||||
|
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": image_data},
|
||||||
}
|
}
|
||||||
for image_url in image_urls:
|
|
||||||
if image_url.startswith("http"):
|
# 构建内容块列表
|
||||||
image_path = await download_image_by_url(image_url)
|
content_blocks = []
|
||||||
image_data = await self.encode_image_bs64(image_path)
|
|
||||||
elif image_url.startswith("file:///"):
|
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||||
image_path = image_url.replace("file:///", "")
|
if text:
|
||||||
image_data = await self.encode_image_bs64(image_path)
|
content_blocks.append({"type": "text", "text": text})
|
||||||
|
elif image_urls:
|
||||||
|
# 如果没有文本但有图片,添加占位文本
|
||||||
|
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||||
|
elif extra_user_content_parts:
|
||||||
|
# 如果只有额外内容块,也需要添加占位文本
|
||||||
|
content_blocks.append({"type": "text", "text": " "})
|
||||||
|
|
||||||
|
# 2. 额外的内容块(系统提醒、指令等)
|
||||||
|
if extra_user_content_parts:
|
||||||
|
for part in extra_user_content_parts:
|
||||||
|
if isinstance(part, TextPart):
|
||||||
|
content_blocks.append({"type": "text", "text": part.text})
|
||||||
|
elif isinstance(part, ImageURLPart):
|
||||||
|
image_part = await resolve_image_part(part.image_url.url)
|
||||||
|
if image_part:
|
||||||
|
content_blocks.append(image_part)
|
||||||
else:
|
else:
|
||||||
image_data = await self.encode_image_bs64(image_url)
|
raise ValueError(f"不支持的额外内容块类型: {type(part)}")
|
||||||
if not image_data:
|
|
||||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
# 3. 图片内容
|
||||||
continue
|
if image_urls:
|
||||||
user_content["content"].append(
|
for image_url in image_urls:
|
||||||
{
|
image_part = await resolve_image_part(image_url)
|
||||||
"type": "image_url",
|
if image_part:
|
||||||
"image_url": {"url": image_data},
|
content_blocks.append(image_part)
|
||||||
},
|
|
||||||
)
|
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||||
return user_content
|
if (
|
||||||
return {"role": "user", "content": text}
|
text
|
||||||
|
and not extra_user_content_parts
|
||||||
|
and not image_urls
|
||||||
|
and len(content_blocks) == 1
|
||||||
|
and content_blocks[0]["type"] == "text"
|
||||||
|
):
|
||||||
|
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||||
|
|
||||||
|
# 否则返回多模态格式
|
||||||
|
return {"role": "user", "content": content_blocks}
|
||||||
|
|
||||||
async def encode_image_bs64(self, image_url: str) -> str:
|
async def encode_image_bs64(self, image_url: str) -> str:
|
||||||
"""将图片转换为 base64"""
|
"""将图片转换为 base64"""
|
||||||
|
|||||||
@@ -377,7 +377,7 @@ class Context:
|
|||||||
if not module_path:
|
if not module_path:
|
||||||
_parts = []
|
_parts = []
|
||||||
module_part = tool.__module__.split(".")
|
module_part = tool.__module__.split(".")
|
||||||
flags = ["packages", "plugins"]
|
flags = ["builtin_stars", "plugins"]
|
||||||
for i, part in enumerate(module_part):
|
for i, part in enumerate(module_part):
|
||||||
_parts.append(part)
|
_parts.append(part)
|
||||||
if part in flags and i + 1 < len(module_part):
|
if part in flags and i + 1 < len(module_part):
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
|||||||
from astrbot.core.provider.register import llm_tools
|
from astrbot.core.provider.register import llm_tools
|
||||||
from astrbot.core.utils.astrbot_path import (
|
from astrbot.core.utils.astrbot_path import (
|
||||||
get_astrbot_config_path,
|
get_astrbot_config_path,
|
||||||
|
get_astrbot_path,
|
||||||
get_astrbot_plugin_path,
|
get_astrbot_plugin_path,
|
||||||
)
|
)
|
||||||
from astrbot.core.utils.io import remove_dir
|
from astrbot.core.utils.io import remove_dir
|
||||||
@@ -49,13 +50,10 @@ class PluginManager:
|
|||||||
"""存储插件的路径。即 data/plugins"""
|
"""存储插件的路径。即 data/plugins"""
|
||||||
self.plugin_config_path = get_astrbot_config_path()
|
self.plugin_config_path = get_astrbot_config_path()
|
||||||
"""存储插件配置的路径。data/config"""
|
"""存储插件配置的路径。data/config"""
|
||||||
self.reserved_plugin_path = os.path.abspath(
|
self.reserved_plugin_path = os.path.join(
|
||||||
os.path.join(
|
get_astrbot_path(), "astrbot", "builtin_stars"
|
||||||
os.path.dirname(os.path.abspath(__file__)),
|
|
||||||
"../../../packages",
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
"""保留插件的路径。在 packages 目录下"""
|
"""保留插件的路径。在 astrbot/builtin_stars 目录下"""
|
||||||
self.conf_schema_fname = "_conf_schema.json"
|
self.conf_schema_fname = "_conf_schema.json"
|
||||||
self.logo_fname = "logo.png"
|
self.logo_fname = "logo.png"
|
||||||
"""插件配置 Schema 文件名"""
|
"""插件配置 Schema 文件名"""
|
||||||
@@ -252,7 +250,7 @@ class PluginManager:
|
|||||||
list[str]: 与该插件相关的模块名列表
|
list[str]: 与该插件相关的模块名列表
|
||||||
|
|
||||||
"""
|
"""
|
||||||
prefix = "packages." if is_reserved else "data.plugins."
|
prefix = "astrbot.builtin_stars." if is_reserved else "data.plugins."
|
||||||
return [
|
return [
|
||||||
key
|
key
|
||||||
for key in list(sys.modules.keys())
|
for key in list(sys.modules.keys())
|
||||||
@@ -270,7 +268,7 @@ class PluginManager:
|
|||||||
可以基于模块名模式或插件目录名移除模块,用于清理插件相关的模块缓存
|
可以基于模块名模式或插件目录名移除模块,用于清理插件相关的模块缓存
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
module_patterns: 要移除的模块名模式列表(例如 ["data.plugins", "packages"])
|
module_patterns: 要移除的模块名模式列表(例如 ["data.plugins", "astrbot.builtin_stars"])
|
||||||
root_dir_name: 插件根目录名,用于移除与该插件相关的所有模块
|
root_dir_name: 插件根目录名,用于移除与该插件相关的所有模块
|
||||||
is_reserved: 插件是否为保留插件(影响模块路径前缀)
|
is_reserved: 插件是否为保留插件(影响模块路径前缀)
|
||||||
|
|
||||||
@@ -382,9 +380,9 @@ class PluginManager:
|
|||||||
reserved = plugin_module.get(
|
reserved = plugin_module.get(
|
||||||
"reserved",
|
"reserved",
|
||||||
False,
|
False,
|
||||||
) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。
|
) # 是否是保留插件。目前在 astrbot/builtin_stars 目录下的都是保留插件。保留插件不可以卸载。
|
||||||
|
|
||||||
path = "data.plugins." if not reserved else "packages."
|
path = "data.plugins." if not reserved else "astrbot.builtin_stars."
|
||||||
path += root_dir_name + "." + module_str
|
path += root_dir_name + "." + module_str
|
||||||
|
|
||||||
# 检查是否需要载入指定的插件
|
# 检查是否需要载入指定的插件
|
||||||
@@ -829,7 +827,7 @@ class PluginManager:
|
|||||||
if (
|
if (
|
||||||
mp
|
mp
|
||||||
and mp.startswith(plugin_module_path)
|
and mp.startswith(plugin_module_path)
|
||||||
and not mp.endswith(("packages", "data.plugins"))
|
and not mp.endswith(("astrbot.builtin_stars", "data.plugins"))
|
||||||
):
|
):
|
||||||
to_remove.append(func_tool)
|
to_remove.append(func_tool)
|
||||||
for func_tool in to_remove:
|
for func_tool in to_remove:
|
||||||
@@ -884,7 +882,7 @@ class PluginManager:
|
|||||||
plugin.module_path
|
plugin.module_path
|
||||||
and mp
|
and mp
|
||||||
and plugin.module_path.startswith(mp)
|
and plugin.module_path.startswith(mp)
|
||||||
and not mp.endswith(("packages", "data.plugins"))
|
and not mp.endswith(("astrbot.builtin_stars", "data.plugins"))
|
||||||
):
|
):
|
||||||
func_tool.active = False
|
func_tool.active = False
|
||||||
if func_tool.name not in inactivated_llm_tools:
|
if func_tool.name not in inactivated_llm_tools:
|
||||||
@@ -933,7 +931,7 @@ class PluginManager:
|
|||||||
plugin.module_path
|
plugin.module_path
|
||||||
and mp
|
and mp
|
||||||
and plugin.module_path.startswith(mp)
|
and plugin.module_path.startswith(mp)
|
||||||
and not mp.endswith(("packages", "data.plugins"))
|
and not mp.endswith(("astrbot.builtin_stars", "data.plugins"))
|
||||||
and func_tool.name in inactivated_llm_tools
|
and func_tool.name in inactivated_llm_tools
|
||||||
):
|
):
|
||||||
inactivated_llm_tools.remove(func_tool.name)
|
inactivated_llm_tools.remove(func_tool.name)
|
||||||
|
|||||||
@@ -5,6 +5,10 @@
|
|||||||
数据目录路径:固定为根目录下的 data 目录
|
数据目录路径:固定为根目录下的 data 目录
|
||||||
配置文件路径:固定为数据目录下的 config 目录
|
配置文件路径:固定为数据目录下的 config 目录
|
||||||
插件目录路径:固定为数据目录下的 plugins 目录
|
插件目录路径:固定为数据目录下的 plugins 目录
|
||||||
|
插件数据目录路径:固定为数据目录下的 plugin_data 目录
|
||||||
|
T2I 模板目录路径:固定为数据目录下的 t2i_templates 目录
|
||||||
|
WebChat 数据目录路径:固定为数据目录下的 webchat 目录
|
||||||
|
临时文件目录路径:固定为数据目录下的 temp 目录
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -37,3 +41,33 @@ def get_astrbot_config_path() -> str:
|
|||||||
def get_astrbot_plugin_path() -> str:
|
def get_astrbot_plugin_path() -> str:
|
||||||
"""获取Astrbot插件目录路径"""
|
"""获取Astrbot插件目录路径"""
|
||||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins"))
|
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins"))
|
||||||
|
|
||||||
|
|
||||||
|
def get_astrbot_plugin_data_path() -> str:
|
||||||
|
"""获取Astrbot插件数据目录路径"""
|
||||||
|
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugin_data"))
|
||||||
|
|
||||||
|
|
||||||
|
def get_astrbot_t2i_templates_path() -> str:
|
||||||
|
"""获取Astrbot T2I 模板目录路径"""
|
||||||
|
return os.path.realpath(os.path.join(get_astrbot_data_path(), "t2i_templates"))
|
||||||
|
|
||||||
|
|
||||||
|
def get_astrbot_webchat_path() -> str:
|
||||||
|
"""获取Astrbot WebChat 数据目录路径"""
|
||||||
|
return os.path.realpath(os.path.join(get_astrbot_data_path(), "webchat"))
|
||||||
|
|
||||||
|
|
||||||
|
def get_astrbot_temp_path() -> str:
|
||||||
|
"""获取Astrbot临时文件目录路径"""
|
||||||
|
return os.path.realpath(os.path.join(get_astrbot_data_path(), "temp"))
|
||||||
|
|
||||||
|
|
||||||
|
def get_astrbot_knowledge_base_path() -> str:
|
||||||
|
"""获取Astrbot知识库根目录路径"""
|
||||||
|
return os.path.realpath(os.path.join(get_astrbot_data_path(), "knowledge_base"))
|
||||||
|
|
||||||
|
|
||||||
|
def get_astrbot_backups_path() -> str:
|
||||||
|
"""获取Astrbot备份目录路径"""
|
||||||
|
return os.path.realpath(os.path.join(get_astrbot_data_path(), "backups"))
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from .auth import AuthRoute
|
from .auth import AuthRoute
|
||||||
|
from .backup import BackupRoute
|
||||||
from .chat import ChatRoute
|
from .chat import ChatRoute
|
||||||
from .command import CommandRoute
|
from .command import CommandRoute
|
||||||
from .config import ConfigRoute
|
from .config import ConfigRoute
|
||||||
@@ -17,6 +18,7 @@ from .update import UpdateRoute
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AuthRoute",
|
"AuthRoute",
|
||||||
|
"BackupRoute",
|
||||||
"ChatRoute",
|
"ChatRoute",
|
||||||
"CommandRoute",
|
"CommandRoute",
|
||||||
"ConfigRoute",
|
"ConfigRoute",
|
||||||
|
|||||||
@@ -0,0 +1,589 @@
|
|||||||
|
"""备份管理 API 路由"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import traceback
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from quart import request, send_file
|
||||||
|
|
||||||
|
from astrbot.core import logger
|
||||||
|
from astrbot.core.backup.exporter import AstrBotExporter
|
||||||
|
from astrbot.core.backup.importer import AstrBotImporter
|
||||||
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||||
|
from astrbot.core.db import BaseDatabase
|
||||||
|
from astrbot.core.utils.astrbot_path import (
|
||||||
|
get_astrbot_backups_path,
|
||||||
|
get_astrbot_data_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .route import Response, Route, RouteContext
|
||||||
|
|
||||||
|
|
||||||
|
def secure_filename(filename: str) -> str:
|
||||||
|
"""清洗文件名,移除路径遍历字符和危险字符
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: 原始文件名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
安全的文件名
|
||||||
|
"""
|
||||||
|
# 跨平台处理:先将反斜杠替换为正斜杠,再取文件名
|
||||||
|
filename = filename.replace("\\", "/")
|
||||||
|
# 仅保留文件名部分,移除路径
|
||||||
|
filename = os.path.basename(filename)
|
||||||
|
|
||||||
|
# 替换路径遍历字符
|
||||||
|
filename = filename.replace("..", "_")
|
||||||
|
|
||||||
|
# 仅保留字母、数字、下划线、连字符、点
|
||||||
|
filename = re.sub(r"[^\w\-.]", "_", filename)
|
||||||
|
|
||||||
|
# 移除前导点(隐藏文件)和尾部点
|
||||||
|
filename = filename.strip(".")
|
||||||
|
|
||||||
|
# 如果文件名为空或只包含下划线,生成一个默认名称
|
||||||
|
if not filename or filename.replace("_", "") == "":
|
||||||
|
filename = "backup"
|
||||||
|
|
||||||
|
return filename
|
||||||
|
|
||||||
|
|
||||||
|
def generate_unique_filename(original_filename: str) -> str:
|
||||||
|
"""生成唯一的文件名,添加时间戳前缀
|
||||||
|
|
||||||
|
Args:
|
||||||
|
original_filename: 原始文件名(已清洗)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
唯一的文件名
|
||||||
|
"""
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
name, ext = os.path.splitext(original_filename)
|
||||||
|
return f"uploaded_{timestamp}_{name}{ext}"
|
||||||
|
|
||||||
|
|
||||||
|
class BackupRoute(Route):
|
||||||
|
"""备份管理路由
|
||||||
|
|
||||||
|
提供备份导出、导入、列表等 API 接口
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
context: RouteContext,
|
||||||
|
db: BaseDatabase,
|
||||||
|
core_lifecycle: AstrBotCoreLifecycle,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(context)
|
||||||
|
self.db = db
|
||||||
|
self.core_lifecycle = core_lifecycle
|
||||||
|
self.backup_dir = get_astrbot_backups_path()
|
||||||
|
self.data_dir = get_astrbot_data_path()
|
||||||
|
|
||||||
|
# 任务状态跟踪
|
||||||
|
self.backup_tasks: dict[str, dict] = {}
|
||||||
|
self.backup_progress: dict[str, dict] = {}
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
self.routes = {
|
||||||
|
"/backup/list": ("GET", self.list_backups),
|
||||||
|
"/backup/export": ("POST", self.export_backup),
|
||||||
|
"/backup/upload": ("POST", self.upload_backup), # 上传文件
|
||||||
|
"/backup/check": ("POST", self.check_backup), # 预检查
|
||||||
|
"/backup/import": ("POST", self.import_backup), # 确认导入
|
||||||
|
"/backup/progress": ("GET", self.get_progress),
|
||||||
|
"/backup/download": ("GET", self.download_backup),
|
||||||
|
"/backup/delete": ("POST", self.delete_backup),
|
||||||
|
}
|
||||||
|
self.register_routes()
|
||||||
|
|
||||||
|
def _init_task(self, task_id: str, task_type: str, status: str = "pending") -> None:
|
||||||
|
"""初始化任务状态"""
|
||||||
|
self.backup_tasks[task_id] = {
|
||||||
|
"type": task_type,
|
||||||
|
"status": status,
|
||||||
|
"result": None,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
self.backup_progress[task_id] = {
|
||||||
|
"status": status,
|
||||||
|
"stage": "waiting",
|
||||||
|
"current": 0,
|
||||||
|
"total": 100,
|
||||||
|
"message": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
def _set_task_result(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
status: str,
|
||||||
|
result: dict | None = None,
|
||||||
|
error: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""设置任务结果"""
|
||||||
|
if task_id in self.backup_tasks:
|
||||||
|
self.backup_tasks[task_id]["status"] = status
|
||||||
|
self.backup_tasks[task_id]["result"] = result
|
||||||
|
self.backup_tasks[task_id]["error"] = error
|
||||||
|
if task_id in self.backup_progress:
|
||||||
|
self.backup_progress[task_id]["status"] = status
|
||||||
|
|
||||||
|
def _update_progress(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
*,
|
||||||
|
status: str | None = None,
|
||||||
|
stage: str | None = None,
|
||||||
|
current: int | None = None,
|
||||||
|
total: int | None = None,
|
||||||
|
message: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""更新任务进度"""
|
||||||
|
if task_id not in self.backup_progress:
|
||||||
|
return
|
||||||
|
p = self.backup_progress[task_id]
|
||||||
|
if status is not None:
|
||||||
|
p["status"] = status
|
||||||
|
if stage is not None:
|
||||||
|
p["stage"] = stage
|
||||||
|
if current is not None:
|
||||||
|
p["current"] = current
|
||||||
|
if total is not None:
|
||||||
|
p["total"] = total
|
||||||
|
if message is not None:
|
||||||
|
p["message"] = message
|
||||||
|
|
||||||
|
def _make_progress_callback(self, task_id: str):
|
||||||
|
"""创建进度回调函数"""
|
||||||
|
|
||||||
|
async def _callback(stage: str, current: int, total: int, message: str = ""):
|
||||||
|
self._update_progress(
|
||||||
|
task_id,
|
||||||
|
status="processing",
|
||||||
|
stage=stage,
|
||||||
|
current=current,
|
||||||
|
total=total,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _callback
|
||||||
|
|
||||||
|
async def list_backups(self):
|
||||||
|
"""获取备份列表
|
||||||
|
|
||||||
|
Query 参数:
|
||||||
|
- page: 页码 (默认 1)
|
||||||
|
- page_size: 每页数量 (默认 20)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
page = request.args.get("page", 1, type=int)
|
||||||
|
page_size = request.args.get("page_size", 20, type=int)
|
||||||
|
|
||||||
|
# 确保备份目录存在
|
||||||
|
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 获取所有备份文件
|
||||||
|
backup_files = []
|
||||||
|
for filename in os.listdir(self.backup_dir):
|
||||||
|
if filename.endswith(".zip") and filename.startswith("astrbot_backup_"):
|
||||||
|
file_path = os.path.join(self.backup_dir, filename)
|
||||||
|
stat = os.stat(file_path)
|
||||||
|
backup_files.append(
|
||||||
|
{
|
||||||
|
"filename": filename,
|
||||||
|
"size": stat.st_size,
|
||||||
|
"created_at": stat.st_mtime,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 按创建时间倒序排序
|
||||||
|
backup_files.sort(key=lambda x: x["created_at"], reverse=True)
|
||||||
|
|
||||||
|
# 分页
|
||||||
|
start = (page - 1) * page_size
|
||||||
|
end = start + page_size
|
||||||
|
items = backup_files[start:end]
|
||||||
|
|
||||||
|
return (
|
||||||
|
Response()
|
||||||
|
.ok(
|
||||||
|
{
|
||||||
|
"items": items,
|
||||||
|
"total": len(backup_files),
|
||||||
|
"page": page,
|
||||||
|
"page_size": page_size,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取备份列表失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return Response().error(f"获取备份列表失败: {e!s}").__dict__
|
||||||
|
|
||||||
|
async def export_backup(self):
|
||||||
|
"""创建备份
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- task_id: 任务ID,用于查询导出进度
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 生成任务ID
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# 初始化任务状态
|
||||||
|
self._init_task(task_id, "export", "pending")
|
||||||
|
|
||||||
|
# 启动后台导出任务
|
||||||
|
asyncio.create_task(self._background_export_task(task_id))
|
||||||
|
|
||||||
|
return (
|
||||||
|
Response()
|
||||||
|
.ok(
|
||||||
|
{
|
||||||
|
"task_id": task_id,
|
||||||
|
"message": "export task created, processing in background",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建备份失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return Response().error(f"创建备份失败: {e!s}").__dict__
|
||||||
|
|
||||||
|
async def _background_export_task(self, task_id: str):
|
||||||
|
"""后台导出任务"""
|
||||||
|
try:
|
||||||
|
self._update_progress(task_id, status="processing", message="正在初始化...")
|
||||||
|
|
||||||
|
# 获取知识库管理器
|
||||||
|
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
|
||||||
|
|
||||||
|
exporter = AstrBotExporter(
|
||||||
|
main_db=self.db,
|
||||||
|
kb_manager=kb_manager,
|
||||||
|
config_path=os.path.join(self.data_dir, "cmd_config.json"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建进度回调
|
||||||
|
progress_callback = self._make_progress_callback(task_id)
|
||||||
|
|
||||||
|
# 执行导出
|
||||||
|
zip_path = await exporter.export_all(
|
||||||
|
output_dir=self.backup_dir,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 设置成功结果
|
||||||
|
self._set_task_result(
|
||||||
|
task_id,
|
||||||
|
"completed",
|
||||||
|
result={
|
||||||
|
"filename": os.path.basename(zip_path),
|
||||||
|
"path": zip_path,
|
||||||
|
"size": os.path.getsize(zip_path),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"后台导出任务 {task_id} 失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
self._set_task_result(task_id, "failed", error=str(e))
|
||||||
|
|
||||||
|
async def upload_backup(self):
|
||||||
|
"""上传备份文件
|
||||||
|
|
||||||
|
将备份文件上传到服务器,返回保存的文件名。
|
||||||
|
上传后应调用 check_backup 进行预检查。
|
||||||
|
|
||||||
|
Form Data:
|
||||||
|
- file: 备份文件 (.zip)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- filename: 保存的文件名
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
files = await request.files
|
||||||
|
if "file" not in files:
|
||||||
|
return Response().error("缺少备份文件").__dict__
|
||||||
|
|
||||||
|
file = files["file"]
|
||||||
|
if not file.filename or not file.filename.endswith(".zip"):
|
||||||
|
return Response().error("请上传 ZIP 格式的备份文件").__dict__
|
||||||
|
|
||||||
|
# 清洗文件名并生成唯一名称,防止路径遍历和覆盖
|
||||||
|
safe_filename = secure_filename(file.filename)
|
||||||
|
unique_filename = generate_unique_filename(safe_filename)
|
||||||
|
|
||||||
|
# 保存上传的文件
|
||||||
|
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
zip_path = os.path.join(self.backup_dir, unique_filename)
|
||||||
|
await file.save(zip_path)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"上传的备份文件已保存: {unique_filename} (原始名称: {file.filename})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
Response()
|
||||||
|
.ok(
|
||||||
|
{
|
||||||
|
"filename": unique_filename,
|
||||||
|
"original_filename": file.filename,
|
||||||
|
"size": os.path.getsize(zip_path),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"上传备份文件失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return Response().error(f"上传备份文件失败: {e!s}").__dict__
|
||||||
|
|
||||||
|
async def check_backup(self):
|
||||||
|
"""预检查备份文件
|
||||||
|
|
||||||
|
检查备份文件的版本兼容性,返回确认信息。
|
||||||
|
用户确认后调用 import_backup 执行导入。
|
||||||
|
|
||||||
|
JSON Body:
|
||||||
|
- filename: 已上传的备份文件名
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- ImportPreCheckResult: 预检查结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = await request.json
|
||||||
|
filename = data.get("filename")
|
||||||
|
if not filename:
|
||||||
|
return Response().error("缺少 filename 参数").__dict__
|
||||||
|
|
||||||
|
# 安全检查 - 防止路径遍历
|
||||||
|
if ".." in filename or "/" in filename or "\\" in filename:
|
||||||
|
return Response().error("无效的文件名").__dict__
|
||||||
|
|
||||||
|
zip_path = os.path.join(self.backup_dir, filename)
|
||||||
|
if not os.path.exists(zip_path):
|
||||||
|
return Response().error(f"备份文件不存在: {filename}").__dict__
|
||||||
|
|
||||||
|
# 获取知识库管理器(用于构造 importer)
|
||||||
|
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
|
||||||
|
|
||||||
|
importer = AstrBotImporter(
|
||||||
|
main_db=self.db,
|
||||||
|
kb_manager=kb_manager,
|
||||||
|
config_path=os.path.join(self.data_dir, "cmd_config.json"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 执行预检查
|
||||||
|
check_result = importer.pre_check(zip_path)
|
||||||
|
|
||||||
|
return Response().ok(check_result.to_dict()).__dict__
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"预检查备份文件失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return Response().error(f"预检查备份文件失败: {e!s}").__dict__
|
||||||
|
|
||||||
|
async def import_backup(self):
|
||||||
|
"""执行备份导入
|
||||||
|
|
||||||
|
在用户确认后执行实际的导入操作。
|
||||||
|
需要先调用 upload_backup 上传文件,再调用 check_backup 预检查。
|
||||||
|
|
||||||
|
JSON Body:
|
||||||
|
- filename: 已上传的备份文件名(必填)
|
||||||
|
- confirmed: 用户已确认(必填,必须为 true)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- task_id: 任务ID,用于查询导入进度
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = await request.json
|
||||||
|
filename = data.get("filename")
|
||||||
|
confirmed = data.get("confirmed", False)
|
||||||
|
|
||||||
|
if not filename:
|
||||||
|
return Response().error("缺少 filename 参数").__dict__
|
||||||
|
|
||||||
|
if not confirmed:
|
||||||
|
return (
|
||||||
|
Response()
|
||||||
|
.error("请先确认导入。导入将会清空并覆盖现有数据,此操作不可撤销。")
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
|
||||||
|
# 安全检查 - 防止路径遍历
|
||||||
|
if ".." in filename or "/" in filename or "\\" in filename:
|
||||||
|
return Response().error("无效的文件名").__dict__
|
||||||
|
|
||||||
|
zip_path = os.path.join(self.backup_dir, filename)
|
||||||
|
if not os.path.exists(zip_path):
|
||||||
|
return Response().error(f"备份文件不存在: {filename}").__dict__
|
||||||
|
|
||||||
|
# 生成任务ID
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# 初始化任务状态
|
||||||
|
self._init_task(task_id, "import", "pending")
|
||||||
|
|
||||||
|
# 启动后台导入任务
|
||||||
|
asyncio.create_task(self._background_import_task(task_id, zip_path))
|
||||||
|
|
||||||
|
return (
|
||||||
|
Response()
|
||||||
|
.ok(
|
||||||
|
{
|
||||||
|
"task_id": task_id,
|
||||||
|
"message": "import task created, processing in background",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"导入备份失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return Response().error(f"导入备份失败: {e!s}").__dict__
|
||||||
|
|
||||||
|
async def _background_import_task(self, task_id: str, zip_path: str):
|
||||||
|
"""后台导入任务"""
|
||||||
|
try:
|
||||||
|
self._update_progress(task_id, status="processing", message="正在初始化...")
|
||||||
|
|
||||||
|
# 获取知识库管理器
|
||||||
|
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
|
||||||
|
|
||||||
|
importer = AstrBotImporter(
|
||||||
|
main_db=self.db,
|
||||||
|
kb_manager=kb_manager,
|
||||||
|
config_path=os.path.join(self.data_dir, "cmd_config.json"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建进度回调
|
||||||
|
progress_callback = self._make_progress_callback(task_id)
|
||||||
|
|
||||||
|
# 执行导入
|
||||||
|
result = await importer.import_all(
|
||||||
|
zip_path=zip_path,
|
||||||
|
mode="replace",
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 设置结果
|
||||||
|
if result.success:
|
||||||
|
self._set_task_result(
|
||||||
|
task_id,
|
||||||
|
"completed",
|
||||||
|
result=result.to_dict(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._set_task_result(
|
||||||
|
task_id,
|
||||||
|
"failed",
|
||||||
|
error="; ".join(result.errors),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"后台导入任务 {task_id} 失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
self._set_task_result(task_id, "failed", error=str(e))
|
||||||
|
|
||||||
|
async def get_progress(self):
|
||||||
|
"""获取任务进度
|
||||||
|
|
||||||
|
Query 参数:
|
||||||
|
- task_id: 任务 ID (必填)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
task_id = request.args.get("task_id")
|
||||||
|
if not task_id:
|
||||||
|
return Response().error("缺少参数 task_id").__dict__
|
||||||
|
|
||||||
|
if task_id not in self.backup_tasks:
|
||||||
|
return Response().error("找不到该任务").__dict__
|
||||||
|
|
||||||
|
task_info = self.backup_tasks[task_id]
|
||||||
|
status = task_info["status"]
|
||||||
|
|
||||||
|
response_data = {
|
||||||
|
"task_id": task_id,
|
||||||
|
"type": task_info["type"],
|
||||||
|
"status": status,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果任务正在处理,返回进度信息
|
||||||
|
if status == "processing" and task_id in self.backup_progress:
|
||||||
|
response_data["progress"] = self.backup_progress[task_id]
|
||||||
|
|
||||||
|
# 如果任务完成,返回结果
|
||||||
|
if status == "completed":
|
||||||
|
response_data["result"] = task_info["result"]
|
||||||
|
|
||||||
|
# 如果任务失败,返回错误信息
|
||||||
|
if status == "failed":
|
||||||
|
response_data["error"] = task_info["error"]
|
||||||
|
|
||||||
|
return Response().ok(response_data).__dict__
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取任务进度失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return Response().error(f"获取任务进度失败: {e!s}").__dict__
|
||||||
|
|
||||||
|
async def download_backup(self):
|
||||||
|
"""下载备份文件
|
||||||
|
|
||||||
|
Query 参数:
|
||||||
|
- filename: 备份文件名 (必填)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
filename = request.args.get("filename")
|
||||||
|
if not filename:
|
||||||
|
return Response().error("缺少参数 filename").__dict__
|
||||||
|
|
||||||
|
# 安全检查 - 防止路径遍历
|
||||||
|
if ".." in filename or "/" in filename or "\\" in filename:
|
||||||
|
return Response().error("无效的文件名").__dict__
|
||||||
|
|
||||||
|
file_path = os.path.join(self.backup_dir, filename)
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
return Response().error("备份文件不存在").__dict__
|
||||||
|
|
||||||
|
return await send_file(
|
||||||
|
file_path,
|
||||||
|
as_attachment=True,
|
||||||
|
attachment_filename=filename,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"下载备份失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return Response().error(f"下载备份失败: {e!s}").__dict__
|
||||||
|
|
||||||
|
async def delete_backup(self):
|
||||||
|
"""删除备份文件
|
||||||
|
|
||||||
|
Body:
|
||||||
|
- filename: 备份文件名 (必填)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = await request.json
|
||||||
|
filename = data.get("filename")
|
||||||
|
if not filename:
|
||||||
|
return Response().error("缺少参数 filename").__dict__
|
||||||
|
|
||||||
|
# 安全检查 - 防止路径遍历
|
||||||
|
if ".." in filename or "/" in filename or "\\" in filename:
|
||||||
|
return Response().error("无效的文件名").__dict__
|
||||||
|
|
||||||
|
file_path = os.path.join(self.backup_dir, filename)
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
return Response().error("备份文件不存在").__dict__
|
||||||
|
|
||||||
|
os.remove(file_path)
|
||||||
|
return Response().ok(message="删除备份成功").__dict__
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除备份失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return Response().error(f"删除备份失败: {e!s}").__dict__
|
||||||
@@ -1,15 +1,26 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from quart import Response as QuartResponse
|
from quart import Response as QuartResponse
|
||||||
from quart import make_response
|
from quart import make_response, request
|
||||||
|
|
||||||
from astrbot.core import LogBroker, logger
|
from astrbot.core import LogBroker, logger
|
||||||
|
|
||||||
from .route import Response, Route, RouteContext
|
from .route import Response, Route, RouteContext
|
||||||
|
|
||||||
|
|
||||||
|
def _format_log_sse(log: dict, ts: float) -> str:
|
||||||
|
"""辅助函数:格式化 SSE 消息"""
|
||||||
|
payload = {
|
||||||
|
"type": "log",
|
||||||
|
**log,
|
||||||
|
}
|
||||||
|
return f"id: {ts}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
class LogRoute(Route):
|
class LogRoute(Route):
|
||||||
def __init__(self, context: RouteContext, log_broker: LogBroker) -> None:
|
def __init__(self, context: RouteContext, log_broker: LogBroker) -> None:
|
||||||
super().__init__(context)
|
super().__init__(context)
|
||||||
@@ -21,21 +32,44 @@ class LogRoute(Route):
|
|||||||
methods=["GET"],
|
methods=["GET"],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def log(self):
|
async def _replay_cached_logs(
|
||||||
|
self, last_event_id: str
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""辅助生成器:重放缓存的日志"""
|
||||||
|
try:
|
||||||
|
last_ts = float(last_event_id)
|
||||||
|
cached_logs = list(self.log_broker.log_cache)
|
||||||
|
|
||||||
|
for log_item in cached_logs:
|
||||||
|
log_ts = float(log_item.get("time", 0))
|
||||||
|
|
||||||
|
if log_ts > last_ts:
|
||||||
|
yield _format_log_sse(log_item, log_ts)
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Log SSE 补发历史错误: {e}")
|
||||||
|
|
||||||
|
async def log(self) -> QuartResponse:
|
||||||
|
last_event_id = request.headers.get("Last-Event-ID")
|
||||||
|
|
||||||
async def stream():
|
async def stream():
|
||||||
queue = None
|
queue = None
|
||||||
try:
|
try:
|
||||||
|
if last_event_id:
|
||||||
|
async for event in self._replay_cached_logs(last_event_id):
|
||||||
|
yield event
|
||||||
|
|
||||||
queue = self.log_broker.register()
|
queue = self.log_broker.register()
|
||||||
while True:
|
while True:
|
||||||
message = await queue.get()
|
message = await queue.get()
|
||||||
payload = {
|
current_ts = message.get("time", time.time())
|
||||||
"type": "log",
|
yield _format_log_sse(message, current_ts)
|
||||||
**message, # see astrbot/core/log.py
|
|
||||||
}
|
|
||||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
except BaseException as e:
|
except Exception as e:
|
||||||
logger.error(f"Log SSE 连接错误: {e}")
|
logger.error(f"Log SSE 连接错误: {e}")
|
||||||
finally:
|
finally:
|
||||||
if queue:
|
if queue:
|
||||||
@@ -53,7 +87,7 @@ class LogRoute(Route):
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
response.timeout = None
|
response.timeout = None # type: ignore
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def log_history(self):
|
async def log_history(self):
|
||||||
@@ -69,6 +103,6 @@ class LogRoute(Route):
|
|||||||
)
|
)
|
||||||
.__dict__
|
.__dict__
|
||||||
)
|
)
|
||||||
except BaseException as e:
|
except Exception as e:
|
||||||
logger.error(f"获取日志历史失败: {e}")
|
logger.error(f"获取日志历史失败: {e}")
|
||||||
return Response().error(f"获取日志历史失败: {e}").__dict__
|
return Response().error(f"获取日志历史失败: {e}").__dict__
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|||||||
from astrbot.core.utils.io import get_local_ip_addresses
|
from astrbot.core.utils.io import get_local_ip_addresses
|
||||||
|
|
||||||
from .routes import *
|
from .routes import *
|
||||||
|
from .routes.backup import BackupRoute
|
||||||
from .routes.platform import PlatformRoute
|
from .routes.platform import PlatformRoute
|
||||||
from .routes.route import Response, RouteContext
|
from .routes.route import Response, RouteContext
|
||||||
from .routes.session_management import SessionManagementRoute
|
from .routes.session_management import SessionManagementRoute
|
||||||
@@ -85,6 +86,7 @@ class AstrBotDashboard:
|
|||||||
self.t2i_route = T2iRoute(self.context, core_lifecycle)
|
self.t2i_route = T2iRoute(self.context, core_lifecycle)
|
||||||
self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle)
|
self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle)
|
||||||
self.platform_route = PlatformRoute(self.context, core_lifecycle)
|
self.platform_route = PlatformRoute(self.context, core_lifecycle)
|
||||||
|
self.backup_route = BackupRoute(self.context, db, core_lifecycle)
|
||||||
|
|
||||||
self.app.add_url_rule(
|
self.app.add_url_rule(
|
||||||
"/api/plug/<path:subpath>",
|
"/api/plug/<path:subpath>",
|
||||||
@@ -108,7 +110,12 @@ class AstrBotDashboard:
|
|||||||
async def auth_middleware(self):
|
async def auth_middleware(self):
|
||||||
if not request.path.startswith("/api"):
|
if not request.path.startswith("/api"):
|
||||||
return None
|
return None
|
||||||
allowed_endpoints = ["/api/auth/login", "/api/file", "/api/platform/webhook"]
|
allowed_endpoints = [
|
||||||
|
"/api/auth/login",
|
||||||
|
"/api/file",
|
||||||
|
"/api/platform/webhook",
|
||||||
|
"/api/stat/start-time",
|
||||||
|
]
|
||||||
if any(request.path.startswith(prefix) for prefix in allowed_endpoints):
|
if any(request.path.startswith(prefix) for prefix in allowed_endpoints):
|
||||||
return None
|
return None
|
||||||
# 声明 JWT
|
# 声明 JWT
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
## What's Changed
|
||||||
|
|
||||||
|
### 修复
|
||||||
|
|
||||||
|
1. 修复 FishAudio TTS 不可用的问题;
|
||||||
|
2. 修复 Anthropic API Chat Provider 部分情况下请求报错的问题;
|
||||||
|
3. 修复部分情况下 WebUI 日志重建连接之后丢失日志的问题;
|
||||||
|
4. 修复部分情况下 /provider 指令报错 index out of range 的问题;
|
||||||
|
5. 修复通过 `uv` 或者 cli 方式启动 AstrBot,缺少所有内置插件的问题。
|
||||||
|
|
||||||
|
### 优化
|
||||||
|
|
||||||
|
1. 丢弃值为 None 的 `tool_call_id` 和 `tool_calls` 字段,提高接口兼容性。
|
||||||
|
|
||||||
|
### 新增
|
||||||
|
|
||||||
|
1. 支持备份 AstrBot 数据和导入数据功能(Beta)。入口:WebUi -> 设置 -> 备份。
|
||||||
|
2. text_chat 和 text_chat_stream 接口支持额外用户内容块参数 `extra_user_content_parts`,用于在用户消息后添加额外的内容块(如系统提醒、指令等)。
|
||||||
@@ -22,6 +22,7 @@
|
|||||||
"axios-mock-adapter": "^1.22.0",
|
"axios-mock-adapter": "^1.22.0",
|
||||||
"chance": "1.1.11",
|
"chance": "1.1.11",
|
||||||
"date-fns": "2.30.0",
|
"date-fns": "2.30.0",
|
||||||
|
"event-source-polyfill": "^1.0.31",
|
||||||
"highlight.js": "^11.11.1",
|
"highlight.js": "^11.11.1",
|
||||||
"js-md5": "^0.8.3",
|
"js-md5": "^0.8.3",
|
||||||
"katex": "^0.16.27",
|
"katex": "^0.16.27",
|
||||||
|
|||||||
@@ -0,0 +1,673 @@
|
|||||||
|
<template>
|
||||||
|
<v-dialog v-model="isOpen" persistent max-width="700" scrollable>
|
||||||
|
<v-card>
|
||||||
|
<v-card-title class="d-flex align-center">
|
||||||
|
<v-icon class="mr-2">mdi-backup-restore</v-icon>
|
||||||
|
{{ t('features.settings.backup.dialog.title') }}
|
||||||
|
</v-card-title>
|
||||||
|
|
||||||
|
<v-card-text class="pa-6">
|
||||||
|
<!-- 选项卡 -->
|
||||||
|
<v-tabs v-model="activeTab" color="primary" class="mb-4">
|
||||||
|
<v-tab value="export">
|
||||||
|
<v-icon class="mr-2">mdi-export</v-icon>
|
||||||
|
{{ t('features.settings.backup.tabs.export') }}
|
||||||
|
</v-tab>
|
||||||
|
<v-tab value="import">
|
||||||
|
<v-icon class="mr-2">mdi-import</v-icon>
|
||||||
|
{{ t('features.settings.backup.tabs.import') }}
|
||||||
|
</v-tab>
|
||||||
|
<v-tab value="list">
|
||||||
|
<v-icon class="mr-2">mdi-format-list-bulleted</v-icon>
|
||||||
|
{{ t('features.settings.backup.tabs.list') }}
|
||||||
|
</v-tab>
|
||||||
|
</v-tabs>
|
||||||
|
|
||||||
|
<v-window v-model="activeTab">
|
||||||
|
<!-- 导出标签页 -->
|
||||||
|
<v-window-item value="export">
|
||||||
|
<div v-if="exportStatus === 'idle'" class="text-center py-8">
|
||||||
|
<v-icon size="64" color="primary" class="mb-4">mdi-cloud-upload</v-icon>
|
||||||
|
<h3 class="mb-4">{{ t('features.settings.backup.export.title') }}</h3>
|
||||||
|
<p class="mb-4 text-grey">{{ t('features.settings.backup.export.description') }}</p>
|
||||||
|
<v-alert type="info" variant="tonal" class="mb-4 text-left">
|
||||||
|
<template v-slot:prepend>
|
||||||
|
<v-icon>mdi-information</v-icon>
|
||||||
|
</template>
|
||||||
|
{{ t('features.settings.backup.export.includes') }}
|
||||||
|
</v-alert>
|
||||||
|
<v-btn color="primary" size="large" @click="startExport" :loading="exportStatus === 'processing'">
|
||||||
|
<v-icon class="mr-2">mdi-export</v-icon>
|
||||||
|
{{ t('features.settings.backup.export.button') }}
|
||||||
|
</v-btn>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-else-if="exportStatus === 'processing'" class="text-center py-8">
|
||||||
|
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
|
||||||
|
<h3 class="mb-4">{{ t('features.settings.backup.export.processing') }}</h3>
|
||||||
|
<p class="text-grey">{{ exportProgress.message || t('features.settings.backup.export.wait') }}</p>
|
||||||
|
<v-progress-linear :model-value="exportProgress.current" :max="exportProgress.total" class="mt-4" color="primary"></v-progress-linear>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-else-if="exportStatus === 'completed'" class="text-center py-8">
|
||||||
|
<v-icon size="64" color="success" class="mb-4">mdi-check-circle</v-icon>
|
||||||
|
<h3 class="mb-4">{{ t('features.settings.backup.export.completed') }}</h3>
|
||||||
|
<p class="mb-4">{{ exportResult?.filename }}</p>
|
||||||
|
<v-btn color="primary" @click="downloadBackup(exportResult?.filename)" class="mr-2">
|
||||||
|
<v-icon class="mr-2">mdi-download</v-icon>
|
||||||
|
{{ t('features.settings.backup.export.download') }}
|
||||||
|
</v-btn>
|
||||||
|
<v-btn color="grey" variant="text" @click="resetExport">
|
||||||
|
{{ t('features.settings.backup.export.another') }}
|
||||||
|
</v-btn>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-else-if="exportStatus === 'failed'" class="text-center py-8">
|
||||||
|
<v-icon size="64" color="error" class="mb-4">mdi-alert-circle</v-icon>
|
||||||
|
<h3 class="mb-4">{{ t('features.settings.backup.export.failed') }}</h3>
|
||||||
|
<v-alert type="error" variant="tonal" class="mb-4">
|
||||||
|
{{ exportError }}
|
||||||
|
</v-alert>
|
||||||
|
<v-btn color="primary" @click="resetExport">
|
||||||
|
{{ t('features.settings.backup.export.retry') }}
|
||||||
|
</v-btn>
|
||||||
|
</div>
|
||||||
|
</v-window-item>
|
||||||
|
|
||||||
|
<!-- 导入标签页 -->
|
||||||
|
<v-window-item value="import">
|
||||||
|
<!-- 步骤1: 选择文件 -->
|
||||||
|
<div v-if="importStatus === 'idle'" class="py-4">
|
||||||
|
<v-alert type="warning" variant="tonal" class="mb-4">
|
||||||
|
<template v-slot:prepend>
|
||||||
|
<v-icon>mdi-alert</v-icon>
|
||||||
|
</template>
|
||||||
|
{{ t('features.settings.backup.import.warning') }}
|
||||||
|
</v-alert>
|
||||||
|
|
||||||
|
<v-file-input
|
||||||
|
v-model="importFile"
|
||||||
|
:label="t('features.settings.backup.import.selectFile')"
|
||||||
|
accept=".zip"
|
||||||
|
prepend-icon="mdi-file-upload"
|
||||||
|
show-size
|
||||||
|
class="mb-4"
|
||||||
|
></v-file-input>
|
||||||
|
|
||||||
|
<div class="d-flex justify-center">
|
||||||
|
<v-btn
|
||||||
|
color="primary"
|
||||||
|
size="large"
|
||||||
|
@click="uploadAndCheck"
|
||||||
|
:disabled="!importFile"
|
||||||
|
:loading="importStatus === 'uploading'"
|
||||||
|
>
|
||||||
|
<v-icon class="mr-2">mdi-upload</v-icon>
|
||||||
|
{{ t('features.settings.backup.import.uploadAndCheck') }}
|
||||||
|
</v-btn>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 步骤1.5: 上传中 -->
|
||||||
|
<div v-else-if="importStatus === 'uploading'" class="text-center py-8">
|
||||||
|
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
|
||||||
|
<h3 class="mb-4">{{ t('features.settings.backup.import.uploading') }}</h3>
|
||||||
|
<p class="text-grey">{{ t('features.settings.backup.import.uploadWait') }}</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 步骤2: 确认导入 -->
|
||||||
|
<div v-else-if="importStatus === 'confirm'" class="py-4">
|
||||||
|
<v-alert
|
||||||
|
:type="versionAlertType"
|
||||||
|
variant="tonal"
|
||||||
|
class="mb-4"
|
||||||
|
>
|
||||||
|
<template v-slot:prepend>
|
||||||
|
<v-icon>{{ versionAlertIcon }}</v-icon>
|
||||||
|
</template>
|
||||||
|
<div class="confirm-message">
|
||||||
|
<div class="text-h6 mb-2">{{ versionAlertTitle }}</div>
|
||||||
|
<div class="mb-2">
|
||||||
|
<strong>{{ t('features.settings.backup.import.version.backupVersion') }}:</strong> {{ checkResult?.backup_version }}<br>
|
||||||
|
<strong>{{ t('features.settings.backup.import.version.currentVersion') }}:</strong> {{ checkResult?.current_version }}
|
||||||
|
</div>
|
||||||
|
<div v-if="checkResult?.backup_time && checkResult?.backup_time !== '未知'" class="mb-2">
|
||||||
|
<strong>{{ t('features.settings.backup.import.version.backupTime') }}:</strong> {{ formatISODate(checkResult?.backup_time) }}
|
||||||
|
</div>
|
||||||
|
<div class="mt-3" style="white-space: pre-line;">{{ versionAlertMessage }}</div>
|
||||||
|
</div>
|
||||||
|
</v-alert>
|
||||||
|
|
||||||
|
<!-- 备份摘要 -->
|
||||||
|
<v-card variant="outlined" class="mb-4" v-if="checkResult?.backup_summary">
|
||||||
|
<v-card-title class="text-subtitle-1">
|
||||||
|
<v-icon class="mr-2">mdi-package-variant</v-icon>
|
||||||
|
{{ t('features.settings.backup.import.backupContents') }}
|
||||||
|
</v-card-title>
|
||||||
|
<v-card-text>
|
||||||
|
<div class="d-flex flex-wrap ga-2">
|
||||||
|
<v-chip v-if="checkResult.backup_summary.tables?.length" size="small" color="primary" variant="tonal" :ripple="false" class="non-interactive-chip">
|
||||||
|
{{ checkResult.backup_summary.tables.length }} {{ t('features.settings.backup.import.tables') }}
|
||||||
|
</v-chip>
|
||||||
|
<v-chip v-if="checkResult.backup_summary.has_knowledge_bases" size="small" color="success" variant="tonal" :ripple="false" class="non-interactive-chip">
|
||||||
|
{{ t('features.settings.backup.import.knowledgeBases') }}
|
||||||
|
</v-chip>
|
||||||
|
<v-chip v-if="checkResult.backup_summary.has_config" size="small" color="info" variant="tonal" :ripple="false" class="non-interactive-chip">
|
||||||
|
{{ t('features.settings.backup.import.configFiles') }}
|
||||||
|
</v-chip>
|
||||||
|
<v-chip v-for="dir in (checkResult.backup_summary.directories || [])" :key="dir" size="small" color="warning" variant="tonal" :ripple="false" class="non-interactive-chip">
|
||||||
|
{{ dir }}
|
||||||
|
</v-chip>
|
||||||
|
</div>
|
||||||
|
</v-card-text>
|
||||||
|
</v-card>
|
||||||
|
|
||||||
|
<!-- 警告信息 -->
|
||||||
|
<v-alert v-if="checkResult?.warnings?.length" type="warning" variant="tonal" class="mb-4">
|
||||||
|
<div v-for="(warning, idx) in checkResult.warnings" :key="idx">{{ warning }}</div>
|
||||||
|
</v-alert>
|
||||||
|
|
||||||
|
<div class="d-flex justify-center align-center mt-4" style="gap: 16px;">
|
||||||
|
<v-btn
|
||||||
|
color="grey-darken-1"
|
||||||
|
variant="outlined"
|
||||||
|
size="large"
|
||||||
|
@click="resetImport"
|
||||||
|
>
|
||||||
|
<v-icon class="mr-2">mdi-close</v-icon>
|
||||||
|
{{ t('core.common.cancel') }}
|
||||||
|
</v-btn>
|
||||||
|
<v-btn
|
||||||
|
v-if="checkResult?.can_import"
|
||||||
|
color="error"
|
||||||
|
size="large"
|
||||||
|
variant="flat"
|
||||||
|
@click="confirmImport"
|
||||||
|
>
|
||||||
|
<v-icon class="mr-2">mdi-alert</v-icon>
|
||||||
|
{{ t('features.settings.backup.import.confirmImport') }}
|
||||||
|
</v-btn>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 步骤3: 导入进行中 -->
|
||||||
|
<div v-else-if="importStatus === 'processing'" class="text-center py-8">
|
||||||
|
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
|
||||||
|
<h3 class="mb-4">{{ t('features.settings.backup.import.processing') }}</h3>
|
||||||
|
<p class="text-grey">{{ importProgress.message || t('features.settings.backup.import.wait') }}</p>
|
||||||
|
<v-progress-linear :model-value="importProgress.current" :max="importProgress.total" class="mt-4" color="primary"></v-progress-linear>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-else-if="importStatus === 'completed'" class="text-center py-8">
|
||||||
|
<v-icon size="64" color="success" class="mb-4">mdi-check-circle</v-icon>
|
||||||
|
<h3 class="mb-4">{{ t('features.settings.backup.import.completed') }}</h3>
|
||||||
|
<v-alert type="info" variant="tonal" class="mb-4">
|
||||||
|
{{ t('features.settings.backup.import.restartRequired') }}
|
||||||
|
</v-alert>
|
||||||
|
<v-btn color="primary" @click="restartAstrBot" class="mr-2">
|
||||||
|
<v-icon class="mr-2">mdi-restart</v-icon>
|
||||||
|
{{ t('features.settings.backup.import.restartNow') }}
|
||||||
|
</v-btn>
|
||||||
|
<v-btn color="grey" variant="text" @click="resetImport">
|
||||||
|
{{ t('core.common.close') }}
|
||||||
|
</v-btn>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-else-if="importStatus === 'failed'" class="text-center py-8">
|
||||||
|
<v-icon size="64" color="error" class="mb-4">mdi-alert-circle</v-icon>
|
||||||
|
<h3 class="mb-4">{{ t('features.settings.backup.import.failed') }}</h3>
|
||||||
|
<v-alert type="error" variant="tonal" class="mb-4">
|
||||||
|
{{ importError }}
|
||||||
|
</v-alert>
|
||||||
|
<v-btn color="primary" @click="resetImport">
|
||||||
|
{{ t('features.settings.backup.import.retry') }}
|
||||||
|
</v-btn>
|
||||||
|
</div>
|
||||||
|
</v-window-item>
|
||||||
|
|
||||||
|
<!-- 备份列表标签页 -->
|
||||||
|
<v-window-item value="list">
|
||||||
|
<div v-if="loadingList" class="text-center py-8">
|
||||||
|
<v-progress-circular indeterminate color="primary"></v-progress-circular>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-else-if="backupList.length === 0" class="text-center py-8">
|
||||||
|
<v-icon size="64" color="grey" class="mb-4">mdi-folder-open-outline</v-icon>
|
||||||
|
<p class="text-grey">{{ t('features.settings.backup.list.empty') }}</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<v-list v-else lines="two">
|
||||||
|
<v-list-item
|
||||||
|
v-for="backup in backupList"
|
||||||
|
:key="backup.filename"
|
||||||
|
>
|
||||||
|
<template v-slot:prepend>
|
||||||
|
<v-icon color="primary">mdi-zip-box</v-icon>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<v-list-item-title>{{ backup.filename }}</v-list-item-title>
|
||||||
|
<v-list-item-subtitle>
|
||||||
|
{{ formatFileSize(backup.size) }} · {{ formatDate(backup.created_at) }}
|
||||||
|
</v-list-item-subtitle>
|
||||||
|
|
||||||
|
<template v-slot:append>
|
||||||
|
<v-btn icon="mdi-download" variant="text" size="small" @click="downloadBackup(backup.filename)"></v-btn>
|
||||||
|
<v-btn icon="mdi-delete" variant="text" size="small" color="error" @click="deleteBackup(backup.filename)"></v-btn>
|
||||||
|
</template>
|
||||||
|
</v-list-item>
|
||||||
|
</v-list>
|
||||||
|
|
||||||
|
<div class="d-flex justify-center mt-4">
|
||||||
|
<v-btn color="primary" variant="text" @click="loadBackupList">
|
||||||
|
<v-icon class="mr-2">mdi-refresh</v-icon>
|
||||||
|
{{ t('features.settings.backup.list.refresh') }}
|
||||||
|
</v-btn>
|
||||||
|
</div>
|
||||||
|
</v-window-item>
|
||||||
|
</v-window>
|
||||||
|
</v-card-text>
|
||||||
|
|
||||||
|
<v-card-actions class="px-6 py-4">
|
||||||
|
<v-spacer></v-spacer>
|
||||||
|
<v-btn color="grey" variant="text" @click="handleClose" :disabled="isProcessing">
|
||||||
|
{{ t('core.common.close') }}
|
||||||
|
</v-btn>
|
||||||
|
</v-card-actions>
|
||||||
|
</v-card>
|
||||||
|
</v-dialog>
|
||||||
|
|
||||||
|
<WaitingForRestart ref="wfr"></WaitingForRestart>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup>
|
||||||
|
import { ref, computed, watch } from 'vue'
|
||||||
|
import axios from 'axios'
|
||||||
|
import { useI18n } from '@/i18n/composables'
|
||||||
|
import WaitingForRestart from './WaitingForRestart.vue'
|
||||||
|
|
||||||
|
const { t } = useI18n()
|
||||||
|
|
||||||
|
const isOpen = ref(false)
|
||||||
|
const activeTab = ref('export')
|
||||||
|
const wfr = ref(null)
|
||||||
|
|
||||||
|
// 导出状态
|
||||||
|
const exportStatus = ref('idle') // idle, processing, completed, failed
|
||||||
|
const exportTaskId = ref(null)
|
||||||
|
const exportProgress = ref({ current: 0, total: 100, message: '' })
|
||||||
|
const exportResult = ref(null)
|
||||||
|
const exportError = ref('')
|
||||||
|
|
||||||
|
// 导入状态
|
||||||
|
const importStatus = ref('idle') // idle, uploading, confirm, processing, completed, failed
|
||||||
|
const importFile = ref(null)
|
||||||
|
const importTaskId = ref(null)
|
||||||
|
const importProgress = ref({ current: 0, total: 100, message: '' })
|
||||||
|
const importError = ref('')
|
||||||
|
const uploadedFilename = ref('') // 已上传的文件名
|
||||||
|
const checkResult = ref(null) // 预检查结果
|
||||||
|
|
||||||
|
// 备份列表
|
||||||
|
const loadingList = ref(false)
|
||||||
|
const backupList = ref([])
|
||||||
|
|
||||||
|
// 计算属性
|
||||||
|
const isProcessing = computed(() => {
|
||||||
|
return exportStatus.value === 'processing' || importStatus.value === 'processing'
|
||||||
|
})
|
||||||
|
|
||||||
|
// 版本检查相关的计算属性
|
||||||
|
const versionAlertType = computed(() => {
|
||||||
|
const status = checkResult.value?.version_status
|
||||||
|
if (status === 'major_diff') return 'error'
|
||||||
|
if (status === 'minor_diff') return 'warning'
|
||||||
|
return 'info'
|
||||||
|
})
|
||||||
|
|
||||||
|
const versionAlertIcon = computed(() => {
|
||||||
|
const status = checkResult.value?.version_status
|
||||||
|
if (status === 'major_diff') return 'mdi-close-circle'
|
||||||
|
if (status === 'minor_diff') return 'mdi-alert'
|
||||||
|
return 'mdi-check-circle'
|
||||||
|
})
|
||||||
|
|
||||||
|
const versionAlertTitle = computed(() => {
|
||||||
|
const status = checkResult.value?.version_status
|
||||||
|
if (status === 'major_diff') return t('features.settings.backup.import.version.majorDiffTitle')
|
||||||
|
if (status === 'minor_diff') return t('features.settings.backup.import.version.minorDiffTitle')
|
||||||
|
return t('features.settings.backup.import.version.matchTitle')
|
||||||
|
})
|
||||||
|
|
||||||
|
const versionAlertMessage = computed(() => {
|
||||||
|
const status = checkResult.value?.version_status
|
||||||
|
if (status === 'major_diff') return t('features.settings.backup.import.version.majorDiffMessage')
|
||||||
|
if (status === 'minor_diff') return t('features.settings.backup.import.version.minorDiffMessage')
|
||||||
|
return t('features.settings.backup.import.version.matchMessage')
|
||||||
|
})
|
||||||
|
|
||||||
|
// 监听对话框打开
|
||||||
|
watch(isOpen, (newVal) => {
|
||||||
|
if (newVal) {
|
||||||
|
loadBackupList()
|
||||||
|
} else {
|
||||||
|
resetAll()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// 监听标签页切换
|
||||||
|
watch(activeTab, (newVal) => {
|
||||||
|
if (newVal === 'list') {
|
||||||
|
loadBackupList()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// 加载备份列表
|
||||||
|
const loadBackupList = async () => {
|
||||||
|
loadingList.value = true
|
||||||
|
try {
|
||||||
|
const response = await axios.get('/api/backup/list')
|
||||||
|
if (response.data.status === 'ok') {
|
||||||
|
backupList.value = response.data.data.items || []
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to load backup list:', error)
|
||||||
|
} finally {
|
||||||
|
loadingList.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 开始导出
|
||||||
|
const startExport = async () => {
|
||||||
|
exportStatus.value = 'processing'
|
||||||
|
exportProgress.value = { current: 0, total: 100, message: '' }
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await axios.post('/api/backup/export')
|
||||||
|
if (response.data.status === 'ok') {
|
||||||
|
exportTaskId.value = response.data.data.task_id
|
||||||
|
pollExportProgress()
|
||||||
|
} else {
|
||||||
|
throw new Error(response.data.message)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
exportStatus.value = 'failed'
|
||||||
|
exportError.value = error.message || 'Export failed'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 轮询导出进度
|
||||||
|
const pollExportProgress = async () => {
|
||||||
|
if (!exportTaskId.value) return
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await axios.get('/api/backup/progress', {
|
||||||
|
params: { task_id: exportTaskId.value }
|
||||||
|
})
|
||||||
|
|
||||||
|
if (response.data.status === 'ok') {
|
||||||
|
const data = response.data.data
|
||||||
|
|
||||||
|
if (data.status === 'processing' && data.progress) {
|
||||||
|
exportProgress.value = {
|
||||||
|
current: data.progress.current || 0,
|
||||||
|
total: data.progress.total || 100,
|
||||||
|
message: data.progress.message || ''
|
||||||
|
}
|
||||||
|
setTimeout(pollExportProgress, 1000)
|
||||||
|
} else if (data.status === 'completed') {
|
||||||
|
exportStatus.value = 'completed'
|
||||||
|
exportResult.value = data.result
|
||||||
|
loadBackupList()
|
||||||
|
} else if (data.status === 'failed') {
|
||||||
|
exportStatus.value = 'failed'
|
||||||
|
exportError.value = data.error || 'Export failed'
|
||||||
|
} else {
|
||||||
|
setTimeout(pollExportProgress, 1000)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
exportStatus.value = 'failed'
|
||||||
|
exportError.value = error.message || 'Failed to get export progress'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重置导出状态
|
||||||
|
const resetExport = () => {
|
||||||
|
exportStatus.value = 'idle'
|
||||||
|
exportTaskId.value = null
|
||||||
|
exportProgress.value = { current: 0, total: 100, message: '' }
|
||||||
|
exportResult.value = null
|
||||||
|
exportError.value = ''
|
||||||
|
}
|
||||||
|
|
||||||
|
// 上传并检查
|
||||||
|
const uploadAndCheck = async () => {
|
||||||
|
if (!importFile.value) return
|
||||||
|
|
||||||
|
importStatus.value = 'uploading'
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 步骤1: 上传文件
|
||||||
|
const formData = new FormData()
|
||||||
|
formData.append('file', importFile.value)
|
||||||
|
|
||||||
|
const uploadResponse = await axios.post('/api/backup/upload', formData, {
|
||||||
|
headers: { 'Content-Type': 'multipart/form-data' }
|
||||||
|
})
|
||||||
|
|
||||||
|
if (uploadResponse.data.status !== 'ok') {
|
||||||
|
throw new Error(uploadResponse.data.message)
|
||||||
|
}
|
||||||
|
|
||||||
|
uploadedFilename.value = uploadResponse.data.data.filename
|
||||||
|
|
||||||
|
// 步骤2: 预检查
|
||||||
|
const checkResponse = await axios.post('/api/backup/check', {
|
||||||
|
filename: uploadedFilename.value
|
||||||
|
})
|
||||||
|
|
||||||
|
if (checkResponse.data.status !== 'ok') {
|
||||||
|
throw new Error(checkResponse.data.message)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkResult.value = checkResponse.data.data
|
||||||
|
|
||||||
|
// 检查是否有效
|
||||||
|
if (!checkResult.value.valid) {
|
||||||
|
importStatus.value = 'failed'
|
||||||
|
importError.value = checkResult.value.error || t('features.settings.backup.import.invalidBackup')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 显示确认对话框
|
||||||
|
importStatus.value = 'confirm'
|
||||||
|
|
||||||
|
} catch (error) {
|
||||||
|
importStatus.value = 'failed'
|
||||||
|
importError.value = error.response?.data?.message || error.message || 'Upload failed'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确认导入
|
||||||
|
const confirmImport = async () => {
|
||||||
|
if (!uploadedFilename.value) return
|
||||||
|
|
||||||
|
importStatus.value = 'processing'
|
||||||
|
importProgress.value = { current: 0, total: 100, message: '' }
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await axios.post('/api/backup/import', {
|
||||||
|
filename: uploadedFilename.value,
|
||||||
|
confirmed: true
|
||||||
|
})
|
||||||
|
|
||||||
|
if (response.data.status === 'ok') {
|
||||||
|
importTaskId.value = response.data.data.task_id
|
||||||
|
pollImportProgress()
|
||||||
|
} else {
|
||||||
|
throw new Error(response.data.message)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
importStatus.value = 'failed'
|
||||||
|
importError.value = error.response?.data?.message || error.message || 'Import failed'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 轮询导入进度
|
||||||
|
const pollImportProgress = async () => {
|
||||||
|
if (!importTaskId.value) return
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await axios.get('/api/backup/progress', {
|
||||||
|
params: { task_id: importTaskId.value }
|
||||||
|
})
|
||||||
|
|
||||||
|
if (response.data.status === 'ok') {
|
||||||
|
const data = response.data.data
|
||||||
|
|
||||||
|
if (data.status === 'processing' && data.progress) {
|
||||||
|
importProgress.value = {
|
||||||
|
current: data.progress.current || 0,
|
||||||
|
total: data.progress.total || 100,
|
||||||
|
message: data.progress.message || ''
|
||||||
|
}
|
||||||
|
setTimeout(pollImportProgress, 1000)
|
||||||
|
} else if (data.status === 'completed') {
|
||||||
|
importStatus.value = 'completed'
|
||||||
|
} else if (data.status === 'failed') {
|
||||||
|
importStatus.value = 'failed'
|
||||||
|
importError.value = data.error || 'Import failed'
|
||||||
|
} else {
|
||||||
|
setTimeout(pollImportProgress, 1000)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
importStatus.value = 'failed'
|
||||||
|
importError.value = error.message || 'Failed to get import progress'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重置导入状态
|
||||||
|
const resetImport = () => {
|
||||||
|
importStatus.value = 'idle'
|
||||||
|
importFile.value = null
|
||||||
|
importTaskId.value = null
|
||||||
|
importProgress.value = { current: 0, total: 100, message: '' }
|
||||||
|
importError.value = ''
|
||||||
|
uploadedFilename.value = ''
|
||||||
|
checkResult.value = null
|
||||||
|
}
|
||||||
|
|
||||||
|
// 下载备份
|
||||||
|
const downloadBackup = async (filename) => {
|
||||||
|
try {
|
||||||
|
const response = await axios.get('/api/backup/download', {
|
||||||
|
params: { filename },
|
||||||
|
responseType: 'blob'
|
||||||
|
})
|
||||||
|
|
||||||
|
// 创建 Blob URL 并触发下载
|
||||||
|
const blob = new Blob([response.data], { type: 'application/zip' })
|
||||||
|
const url = window.URL.createObjectURL(blob)
|
||||||
|
const link = document.createElement('a')
|
||||||
|
link.href = url
|
||||||
|
link.download = filename
|
||||||
|
document.body.appendChild(link)
|
||||||
|
link.click()
|
||||||
|
document.body.removeChild(link)
|
||||||
|
window.URL.revokeObjectURL(url)
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Download failed:', error)
|
||||||
|
alert(t('features.settings.backup.export.failed') + ': ' + (error.message || 'Unknown error'))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除备份
|
||||||
|
const deleteBackup = async (filename) => {
|
||||||
|
if (!confirm(t('features.settings.backup.list.confirmDelete'))) return
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await axios.post('/api/backup/delete', { filename })
|
||||||
|
if (response.data.status === 'ok') {
|
||||||
|
loadBackupList()
|
||||||
|
} else {
|
||||||
|
alert(response.data.message || 'Delete failed')
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
alert(error.message || 'Delete failed')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 格式化文件大小
|
||||||
|
const formatFileSize = (bytes) => {
|
||||||
|
if (bytes === 0) return '0 B'
|
||||||
|
const k = 1024
|
||||||
|
const sizes = ['B', 'KB', 'MB', 'GB']
|
||||||
|
const i = Math.floor(Math.log(bytes) / Math.log(k))
|
||||||
|
return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 格式化日期(从时间戳)
|
||||||
|
const formatDate = (timestamp) => {
|
||||||
|
return new Date(timestamp * 1000).toLocaleString()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 格式化 ISO 日期字符串
|
||||||
|
const formatISODate = (isoString) => {
|
||||||
|
if (!isoString) return ''
|
||||||
|
try {
|
||||||
|
return new Date(isoString).toLocaleString()
|
||||||
|
} catch {
|
||||||
|
return isoString
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重启 AstrBot
|
||||||
|
const restartAstrBot = () => {
|
||||||
|
axios.post('/api/stat/restart-core').then(() => {
|
||||||
|
if (wfr.value) {
|
||||||
|
wfr.value.check()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重置所有状态
|
||||||
|
const resetAll = () => {
|
||||||
|
resetExport()
|
||||||
|
resetImport()
|
||||||
|
activeTab.value = 'export'
|
||||||
|
}
|
||||||
|
|
||||||
|
// 关闭对话框
|
||||||
|
const handleClose = () => {
|
||||||
|
if (isProcessing.value) return
|
||||||
|
isOpen.value = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 打开对话框
|
||||||
|
const open = () => {
|
||||||
|
isOpen.value = true
|
||||||
|
}
|
||||||
|
|
||||||
|
defineExpose({ open })
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<style scoped>
|
||||||
|
.v-list-item {
|
||||||
|
border-bottom: 1px solid rgba(0, 0, 0, 0.08);
|
||||||
|
}
|
||||||
|
|
||||||
|
.v-list-item:last-child {
|
||||||
|
border-bottom: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 禁用 Chip 的交互效果 */
|
||||||
|
.non-interactive-chip {
|
||||||
|
pointer-events: none;
|
||||||
|
cursor: default;
|
||||||
|
}
|
||||||
|
|
||||||
|
.non-interactive-chip:hover {
|
||||||
|
box-shadow: none !important;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
@@ -1,12 +1,11 @@
|
|||||||
<script setup>
|
<script setup>
|
||||||
import { useCommonStore } from '@/stores/common';
|
import { useCommonStore } from '@/stores/common';
|
||||||
import { storeToRefs } from 'pinia';
|
|
||||||
import axios from 'axios';
|
import axios from 'axios';
|
||||||
|
import { EventSourcePolyfill } from 'event-source-polyfill';
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<template>
|
<template>
|
||||||
<div>
|
<div>
|
||||||
<!-- 添加筛选级别控件 -->
|
|
||||||
<div class="filter-controls mb-2" v-if="showLevelBtns">
|
<div class="filter-controls mb-2" v-if="showLevelBtns">
|
||||||
<v-chip-group v-model="selectedLevels" column multiple>
|
<v-chip-group v-model="selectedLevels" column multiple>
|
||||||
<v-chip v-for="level in logLevels" :key="level" :color="getLevelColor(level)" filter variant="flat" size="small"
|
<v-chip v-for="level in logLevels" :key="level" :color="getLevelColor(level)" filter variant="flat" size="small"
|
||||||
@@ -26,20 +25,19 @@ export default {
|
|||||||
name: 'ConsoleDisplayer',
|
name: 'ConsoleDisplayer',
|
||||||
data() {
|
data() {
|
||||||
return {
|
return {
|
||||||
autoScroll: true, // 默认开启自动滚动
|
autoScroll: true,
|
||||||
logColorAnsiMap: {
|
logColorAnsiMap: {
|
||||||
'\u001b[1;34m': 'color: #0000FF; font-weight: bold;', // bold_blue
|
'\u001b[1;34m': 'color: #0000FF; font-weight: bold;',
|
||||||
'\u001b[1;36m': 'color: #00FFFF; font-weight: bold;', // bold_cyan
|
'\u001b[1;36m': 'color: #00FFFF; font-weight: bold;',
|
||||||
'\u001b[1;33m': 'color: #FFFF00; font-weight: bold;', // bold_yellow
|
'\u001b[1;33m': 'color: #FFFF00; font-weight: bold;',
|
||||||
'\u001b[31m': 'color: #FF0000;', // red
|
'\u001b[31m': 'color: #FF0000;',
|
||||||
'\u001b[1;31m': 'color: #FF0000; font-weight: bold;', // bold_red
|
'\u001b[1;31m': 'color: #FF0000; font-weight: bold;',
|
||||||
'\u001b[0m': 'color: inherit; font-weight: normal;', // reset
|
'\u001b[0m': 'color: inherit; font-weight: normal;',
|
||||||
'\u001b[32m': 'color: #00FF00;', // green
|
'\u001b[32m': 'color: #00FF00;',
|
||||||
'default': 'color: #FFFFFF;'
|
'default': 'color: #FFFFFF;'
|
||||||
},
|
},
|
||||||
historyNum_: -1,
|
|
||||||
logLevels: ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
|
logLevels: ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
|
||||||
selectedLevels: [0, 1, 2, 3, 4], // 默认选中所有级别
|
selectedLevels: [0, 1, 2, 3, 4],
|
||||||
levelColors: {
|
levelColors: {
|
||||||
'DEBUG': 'grey',
|
'DEBUG': 'grey',
|
||||||
'INFO': 'blue-lighten-3',
|
'INFO': 'blue-lighten-3',
|
||||||
@@ -47,17 +45,19 @@ export default {
|
|||||||
'ERROR': 'red',
|
'ERROR': 'red',
|
||||||
'CRITICAL': 'purple'
|
'CRITICAL': 'purple'
|
||||||
},
|
},
|
||||||
lastProcessedTime: 0, // 记录最后处理的日志时间戳
|
localLogCache: [],
|
||||||
localLogCache: [], // 本地日志缓存
|
eventSource: null,
|
||||||
|
retryTimer: null,
|
||||||
|
retryAttempts: 0,
|
||||||
|
maxRetryAttempts: 10,
|
||||||
|
baseRetryDelay: 1000,
|
||||||
|
lastEventId: null,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
computed: {
|
computed: {
|
||||||
commonStore() {
|
commonStore() {
|
||||||
return useCommonStore();
|
return useCommonStore();
|
||||||
},
|
},
|
||||||
logCache() {
|
|
||||||
return this.commonStore.log_cache;
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
props: {
|
props: {
|
||||||
historyNum: {
|
historyNum: {
|
||||||
@@ -70,41 +70,6 @@ export default {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
watch: {
|
watch: {
|
||||||
logCache: {
|
|
||||||
handler(newVal) {
|
|
||||||
// 基于 timestamp 处理新增的日志
|
|
||||||
if (newVal && newVal.length > 0) {
|
|
||||||
// 确保 DOM 已经准备好
|
|
||||||
this.$nextTick(() => {
|
|
||||||
// 合并到本地缓存并按时间排序
|
|
||||||
const newLogs = newVal.filter(log => log.time > this.lastProcessedTime);
|
|
||||||
|
|
||||||
if (newLogs.length > 0) {
|
|
||||||
this.localLogCache.push(...newLogs);
|
|
||||||
// 按时间戳排序
|
|
||||||
this.localLogCache.sort((a, b) => a.time - b.time);
|
|
||||||
|
|
||||||
// 只保留最新的 log_cache_max_len 条
|
|
||||||
if (this.localLogCache.length > this.commonStore.log_cache_max_len) {
|
|
||||||
this.localLogCache.splice(0, this.localLogCache.length - this.commonStore.log_cache_max_len);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 显示新日志
|
|
||||||
newLogs.forEach(logItem => {
|
|
||||||
if (this.isLevelSelected(logItem.level)) {
|
|
||||||
this.printLog(logItem.data);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// 更新最后处理时间
|
|
||||||
this.lastProcessedTime = Math.max(...newLogs.map(log => log.time));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
},
|
|
||||||
deep: true,
|
|
||||||
immediate: false
|
|
||||||
},
|
|
||||||
selectedLevels: {
|
selectedLevels: {
|
||||||
handler() {
|
handler() {
|
||||||
this.refreshDisplay();
|
this.refreshDisplay();
|
||||||
@@ -113,30 +78,142 @@ export default {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
async mounted() {
|
async mounted() {
|
||||||
// 请求历史日志
|
|
||||||
await this.fetchLogHistory();
|
await this.fetchLogHistory();
|
||||||
|
this.connectSSE();
|
||||||
// 等待 DOM 准备好后,显示历史日志
|
},
|
||||||
this.$nextTick(() => {
|
beforeUnmount() {
|
||||||
if (this.localLogCache.length > 0) {
|
if (this.eventSource) {
|
||||||
this.localLogCache.forEach(logItem => {
|
this.eventSource.close();
|
||||||
if (this.isLevelSelected(logItem.level)) {
|
this.eventSource = null;
|
||||||
this.printLog(logItem.data);
|
}
|
||||||
}
|
if (this.retryTimer) {
|
||||||
});
|
clearTimeout(this.retryTimer);
|
||||||
// 更新最后处理时间
|
this.retryTimer = null;
|
||||||
this.lastProcessedTime = Math.max(...this.localLogCache.map(log => log.time));
|
}
|
||||||
}
|
this.retryAttempts = 0;
|
||||||
});
|
|
||||||
},
|
},
|
||||||
methods: {
|
methods: {
|
||||||
|
connectSSE() {
|
||||||
|
if (this.eventSource) {
|
||||||
|
this.eventSource.close();
|
||||||
|
this.eventSource = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(`正在连接日志流... (尝试次数: ${this.retryAttempts})`);
|
||||||
|
|
||||||
|
const token = localStorage.getItem('token');
|
||||||
|
|
||||||
|
this.eventSource = new EventSourcePolyfill('/api/live-log', {
|
||||||
|
headers: {
|
||||||
|
'Authorization': token ? `Bearer ${token}` : ''
|
||||||
|
},
|
||||||
|
heartbeatTimeout: 300000,
|
||||||
|
withCredentials: true
|
||||||
|
});
|
||||||
|
|
||||||
|
this.eventSource.onopen = () => {
|
||||||
|
console.log('日志流连接成功!');
|
||||||
|
this.retryAttempts = 0;
|
||||||
|
|
||||||
|
if (!this.lastEventId) {
|
||||||
|
this.fetchLogHistory();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
this.eventSource.onmessage = (event) => {
|
||||||
|
try {
|
||||||
|
if (event.lastEventId) {
|
||||||
|
this.lastEventId = event.lastEventId;
|
||||||
|
}
|
||||||
|
|
||||||
|
const payload = JSON.parse(event.data);
|
||||||
|
this.processNewLogs([payload]);
|
||||||
|
} catch (e) {
|
||||||
|
console.error('解析日志失败:', e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
this.eventSource.onerror = (err) => {
|
||||||
|
|
||||||
|
if (err.status === 401) {
|
||||||
|
console.error('鉴权失败 (401),可能是 Token 过期了。');
|
||||||
|
|
||||||
|
} else {
|
||||||
|
console.warn('日志流连接错误:', err);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.eventSource) {
|
||||||
|
this.eventSource.close();
|
||||||
|
this.eventSource = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.retryAttempts >= this.maxRetryAttempts) {
|
||||||
|
console.error('❌ 已达到最大重试次数,停止重连。请刷新页面重试。');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const delay = Math.min(
|
||||||
|
this.baseRetryDelay * Math.pow(2, this.retryAttempts),
|
||||||
|
30000
|
||||||
|
);
|
||||||
|
|
||||||
|
console.log(`⏳ ${delay}ms 后尝试第 ${this.retryAttempts + 1} 次重连...`);
|
||||||
|
|
||||||
|
if (this.retryTimer) {
|
||||||
|
clearTimeout(this.retryTimer);
|
||||||
|
this.retryTimer = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.retryTimer = setTimeout(async () => {
|
||||||
|
this.retryAttempts++;
|
||||||
|
|
||||||
|
if (!this.lastEventId) {
|
||||||
|
await this.fetchLogHistory();
|
||||||
|
}
|
||||||
|
|
||||||
|
this.connectSSE();
|
||||||
|
}, delay);
|
||||||
|
};
|
||||||
|
},
|
||||||
|
|
||||||
|
processNewLogs(newLogs) {
|
||||||
|
if (!newLogs || newLogs.length === 0) return;
|
||||||
|
|
||||||
|
let hasUpdate = false;
|
||||||
|
|
||||||
|
newLogs.forEach(log => {
|
||||||
|
|
||||||
|
const exists = this.localLogCache.some(existing =>
|
||||||
|
existing.time === log.time &&
|
||||||
|
existing.data === log.data &&
|
||||||
|
existing.level === log.level
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!exists) {
|
||||||
|
this.localLogCache.push(log);
|
||||||
|
hasUpdate = true;
|
||||||
|
|
||||||
|
if (this.isLevelSelected(log.level)) {
|
||||||
|
this.printLog(log.data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if (hasUpdate) {
|
||||||
|
this.localLogCache.sort((a, b) => a.time - b.time);
|
||||||
|
|
||||||
|
const maxSize = this.commonStore.log_cache_max_len || 200;
|
||||||
|
if (this.localLogCache.length > maxSize) {
|
||||||
|
this.localLogCache.splice(0, this.localLogCache.length - maxSize);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
async fetchLogHistory() {
|
async fetchLogHistory() {
|
||||||
try {
|
try {
|
||||||
const res = await axios.get('/api/log-history');
|
const res = await axios.get('/api/log-history');
|
||||||
if (res.data.data.logs && res.data.data.logs.length > 0) {
|
if (res.data.data.logs && res.data.data.logs.length > 0) {
|
||||||
this.localLogCache = [...res.data.data.logs];
|
this.processNewLogs(res.data.data.logs);
|
||||||
// 按时间戳排序
|
|
||||||
this.localLogCache.sort((a, b) => a.time - b.time);
|
|
||||||
}
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Failed to fetch log history:', err);
|
console.error('Failed to fetch log history:', err);
|
||||||
@@ -162,7 +239,6 @@ export default {
|
|||||||
if (termElement) {
|
if (termElement) {
|
||||||
termElement.innerHTML = '';
|
termElement.innerHTML = '';
|
||||||
|
|
||||||
// 重新显示所有符合筛选条件的日志
|
|
||||||
if (this.localLogCache && this.localLogCache.length > 0) {
|
if (this.localLogCache && this.localLogCache.length > 0) {
|
||||||
this.localLogCache.forEach(logItem => {
|
this.localLogCache.forEach(logItem => {
|
||||||
if (this.isLevelSelected(logItem.level)) {
|
if (this.isLevelSelected(logItem.level)) {
|
||||||
@@ -173,16 +249,13 @@ export default {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
|
||||||
toggleAutoScroll() {
|
toggleAutoScroll() {
|
||||||
this.autoScroll = !this.autoScroll;
|
this.autoScroll = !this.autoScroll;
|
||||||
},
|
},
|
||||||
|
|
||||||
printLog(log) {
|
printLog(log) {
|
||||||
// append 一个 span 标签到 term,block 的方式
|
|
||||||
let ele = document.getElementById('term')
|
let ele = document.getElementById('term')
|
||||||
if (!ele) {
|
if (!ele) {
|
||||||
console.warn('term element not found, skipping log print');
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,11 +269,11 @@ export default {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
span.style = style + 'display: block; font-size: 12px; font-family: Consolas, monospace; white-space: pre-wrap;'
|
span.style = style + 'display: block; font-size: 12px; font-family: Consolas, monospace; white-space: pre-wrap; margin-bottom: 2px;'
|
||||||
span.classList.add('fade-in')
|
span.classList.add('fade-in')
|
||||||
span.innerText = `${log}`;
|
span.innerText = `${log}`;
|
||||||
ele.appendChild(span)
|
ele.appendChild(span)
|
||||||
if (this.autoScroll ) {
|
if (this.autoScroll) {
|
||||||
ele.scrollTop = ele.scrollHeight
|
ele.scrollTop = ele.scrollHeight
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,11 @@
|
|||||||
"title": "Data Migration to v4.0.0",
|
"title": "Data Migration to v4.0.0",
|
||||||
"subtitle": "If you encounter data compatibility issues, you can manually start the database migration assistant",
|
"subtitle": "If you encounter data compatibility issues, you can manually start the database migration assistant",
|
||||||
"button": "Start Migration Assistant"
|
"button": "Start Migration Assistant"
|
||||||
|
},
|
||||||
|
"backup": {
|
||||||
|
"title": "Backup & Restore",
|
||||||
|
"subtitle": "Export or import all AstrBot data for easy migration to a new server",
|
||||||
|
"button": "Backup Manager"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"sidebar": {
|
"sidebar": {
|
||||||
@@ -29,5 +34,66 @@
|
|||||||
"mainItems": "Main Modules",
|
"mainItems": "Main Modules",
|
||||||
"moreItems": "More Features"
|
"moreItems": "More Features"
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"backup": {
|
||||||
|
"dialog": {
|
||||||
|
"title": "Backup Manager"
|
||||||
|
},
|
||||||
|
"tabs": {
|
||||||
|
"export": "Export Backup",
|
||||||
|
"import": "Import Backup",
|
||||||
|
"list": "Backup List"
|
||||||
|
},
|
||||||
|
"export": {
|
||||||
|
"title": "Create Backup",
|
||||||
|
"description": "Export all data as a ZIP backup file, including database, knowledge base, config and attachments.",
|
||||||
|
"includes": "Backup includes: Main database, Knowledge bases (metadata + vector index + documents), Config files, Attachment files",
|
||||||
|
"button": "Start Export",
|
||||||
|
"processing": "Exporting...",
|
||||||
|
"wait": "Please wait, packaging data...",
|
||||||
|
"completed": "Export Completed!",
|
||||||
|
"download": "Download Backup",
|
||||||
|
"another": "Create New Backup",
|
||||||
|
"failed": "Export Failed",
|
||||||
|
"retry": "Retry"
|
||||||
|
},
|
||||||
|
"import": {
|
||||||
|
"title": "Import Backup",
|
||||||
|
"warning": "⚠️ Import will clear and overwrite existing data! Please make sure you have backed up your current data.",
|
||||||
|
"selectFile": "Select backup file (.zip)",
|
||||||
|
"uploadAndCheck": "Upload & Check",
|
||||||
|
"uploading": "Uploading...",
|
||||||
|
"uploadWait": "Please wait, uploading backup file...",
|
||||||
|
"invalidBackup": "Invalid backup file",
|
||||||
|
"backupContents": "Backup Contents",
|
||||||
|
"tables": "tables",
|
||||||
|
"knowledgeBases": "Knowledge Bases",
|
||||||
|
"configFiles": "Config Files",
|
||||||
|
"confirmImport": "Confirm Import",
|
||||||
|
"button": "Start Import",
|
||||||
|
"processing": "Importing...",
|
||||||
|
"wait": "Please wait, restoring data...",
|
||||||
|
"completed": "Import Completed!",
|
||||||
|
"restartRequired": "Data has been successfully imported. It is recommended to restart AstrBot immediately for all changes to take effect.",
|
||||||
|
"restartNow": "Restart Now",
|
||||||
|
"failed": "Import Failed",
|
||||||
|
"retry": "Retry",
|
||||||
|
"version": {
|
||||||
|
"backupVersion": "Backup Version",
|
||||||
|
"currentVersion": "Current Version",
|
||||||
|
"backupTime": "Backup Time",
|
||||||
|
"matchTitle": "✅ Version Match",
|
||||||
|
"matchMessage": "Import will clear and overwrite all existing data, including:\n• Main database (conversations, settings, etc.)\n• Knowledge bases\n• Plugins and plugin data\n• Configuration files\n\nThis action cannot be undone! Do you want to continue?",
|
||||||
|
"minorDiffTitle": "⚠️ Version Difference Warning",
|
||||||
|
"minorDiffMessage": "Minor version differences are usually compatible, but there may be some data structure changes.\nImport will clear and overwrite all existing data!\n\nDo you want to continue?",
|
||||||
|
"majorDiffTitle": "⛔ Cannot Import",
|
||||||
|
"majorDiffMessage": "Major version numbers are different. Cross-major-version import may cause data corruption.\nPlease use the same major version of AstrBot for import."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"list": {
|
||||||
|
"empty": "No backup files",
|
||||||
|
"refresh": "Refresh List",
|
||||||
|
"confirmDelete": "Are you sure you want to delete this backup file? This action cannot be undone."
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -18,6 +18,11 @@
|
|||||||
"title": "数据迁移到 v4.0.0 格式",
|
"title": "数据迁移到 v4.0.0 格式",
|
||||||
"subtitle": "如果您遇到数据兼容性问题,可以手动启动数据库迁移助手",
|
"subtitle": "如果您遇到数据兼容性问题,可以手动启动数据库迁移助手",
|
||||||
"button": "启动迁移助手"
|
"button": "启动迁移助手"
|
||||||
|
},
|
||||||
|
"backup": {
|
||||||
|
"title": "数据备份与恢复",
|
||||||
|
"subtitle": "导出或导入 AstrBot 的所有数据,方便迁移到新服务器",
|
||||||
|
"button": "备份管理"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"sidebar": {
|
"sidebar": {
|
||||||
@@ -29,5 +34,66 @@
|
|||||||
"mainItems": "主要模块",
|
"mainItems": "主要模块",
|
||||||
"moreItems": "更多功能"
|
"moreItems": "更多功能"
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"backup": {
|
||||||
|
"dialog": {
|
||||||
|
"title": "备份管理"
|
||||||
|
},
|
||||||
|
"tabs": {
|
||||||
|
"export": "导出备份",
|
||||||
|
"import": "导入备份",
|
||||||
|
"list": "备份列表"
|
||||||
|
},
|
||||||
|
"export": {
|
||||||
|
"title": "创建备份",
|
||||||
|
"description": "将所有数据导出为 ZIP 备份文件,包括数据库、知识库、配置和附件。",
|
||||||
|
"includes": "备份包含:主数据库、知识库(元数据+向量索引+文档)、配置文件、附件文件",
|
||||||
|
"button": "开始导出",
|
||||||
|
"processing": "正在导出...",
|
||||||
|
"wait": "请稍候,正在打包数据...",
|
||||||
|
"completed": "导出完成!",
|
||||||
|
"download": "下载备份",
|
||||||
|
"another": "创建新备份",
|
||||||
|
"failed": "导出失败",
|
||||||
|
"retry": "重试"
|
||||||
|
},
|
||||||
|
"import": {
|
||||||
|
"title": "导入备份",
|
||||||
|
"warning": "⚠️ 导入将会清空并覆盖现有数据!请确保已备份当前数据。",
|
||||||
|
"selectFile": "选择备份文件 (.zip)",
|
||||||
|
"uploadAndCheck": "上传并检查",
|
||||||
|
"uploading": "正在上传...",
|
||||||
|
"uploadWait": "请稍候,正在上传备份文件...",
|
||||||
|
"invalidBackup": "无效的备份文件",
|
||||||
|
"backupContents": "备份内容",
|
||||||
|
"tables": "个数据表",
|
||||||
|
"knowledgeBases": "知识库",
|
||||||
|
"configFiles": "配置文件",
|
||||||
|
"confirmImport": "确认导入",
|
||||||
|
"button": "开始导入",
|
||||||
|
"processing": "正在导入...",
|
||||||
|
"wait": "请稍候,正在恢复数据...",
|
||||||
|
"completed": "导入完成!",
|
||||||
|
"restartRequired": "数据已成功导入。建议立即重启 AstrBot 以使所有更改生效。",
|
||||||
|
"restartNow": "立即重启",
|
||||||
|
"failed": "导入失败",
|
||||||
|
"retry": "重试",
|
||||||
|
"version": {
|
||||||
|
"backupVersion": "备份版本",
|
||||||
|
"currentVersion": "当前版本",
|
||||||
|
"backupTime": "备份时间",
|
||||||
|
"matchTitle": "✅ 版本匹配",
|
||||||
|
"matchMessage": "导入将会清空并覆盖现有的所有数据,包括:\n• 主数据库(对话记录、配置等)\n• 知识库数据\n• 插件及插件数据\n• 配置文件\n\n此操作不可撤销!是否确认继续?",
|
||||||
|
"minorDiffTitle": "⚠️ 版本差异警告",
|
||||||
|
"minorDiffMessage": "小版本差异通常是兼容的,但可能存在少量数据结构变化。\n导入将会清空并覆盖现有的所有数据!\n\n是否确认继续导入?",
|
||||||
|
"majorDiffTitle": "⛔ 无法导入",
|
||||||
|
"majorDiffMessage": "主版本号不同,跨主版本导入可能导致数据损坏。\n请使用相同主版本的 AstrBot 进行导入。"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"list": {
|
||||||
|
"empty": "暂无备份文件",
|
||||||
|
"refresh": "刷新列表",
|
||||||
|
"confirmDelete": "确定要删除这个备份文件吗?此操作不可撤销。"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -21,10 +21,14 @@ export const useCommonStore = defineStore({
|
|||||||
}
|
}
|
||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
const { signal } = controller;
|
const { signal } = controller;
|
||||||
|
|
||||||
|
// 注意:这里如果之前改过 Polyfill 的话,可能需要保持原样
|
||||||
|
// 如果是用 fetch 的话,这里是支持 Authorization Header 的
|
||||||
const headers = {
|
const headers = {
|
||||||
'Content-Type': 'multipart/form-data',
|
'Content-Type': 'multipart/form-data',
|
||||||
'Authorization': 'Bearer ' + localStorage.getItem('token')
|
'Authorization': 'Bearer ' + localStorage.getItem('token')
|
||||||
};
|
};
|
||||||
|
|
||||||
fetch('/api/live-log', {
|
fetch('/api/live-log', {
|
||||||
method: 'GET',
|
method: 'GET',
|
||||||
headers,
|
headers,
|
||||||
@@ -72,10 +76,20 @@ export const useCommonStore = defineStore({
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
const logObject = JSON.parse(logLine);
|
const logObject = JSON.parse(logLine);
|
||||||
// give a uuid if not exists
|
|
||||||
|
// 修复:兼容 HTTP 环境的 UUID 生成
|
||||||
if (!logObject.uuid) {
|
if (!logObject.uuid) {
|
||||||
logObject.uuid = crypto.randomUUID();
|
if (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') {
|
||||||
|
logObject.uuid = crypto.randomUUID();
|
||||||
|
} else {
|
||||||
|
// 手动生成 UUID v4
|
||||||
|
logObject.uuid = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, function(c) {
|
||||||
|
var r = Math.random() * 16 | 0, v = c == 'x' ? r : (r & 0x3 | 0x8);
|
||||||
|
return v.toString(16);
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
this.log_cache.push(logObject);
|
this.log_cache.push(logObject);
|
||||||
// Limit log cache size
|
// Limit log cache size
|
||||||
if (this.log_cache.length > this.log_cache_max_len) {
|
if (this.log_cache.length > this.log_cache_max_len) {
|
||||||
@@ -93,7 +107,13 @@ export const useCommonStore = defineStore({
|
|||||||
}).catch(error => {
|
}).catch(error => {
|
||||||
console.error('SSE error:', error);
|
console.error('SSE error:', error);
|
||||||
// Attempt to reconnect after a delay
|
// Attempt to reconnect after a delay
|
||||||
this.log_cache.push('SSE Connection failed, retrying in 5 seconds...');
|
this.log_cache.push({
|
||||||
|
type: 'log',
|
||||||
|
level: 'ERROR',
|
||||||
|
time: Date.now() / 1000,
|
||||||
|
data: 'SSE Connection failed, retrying in 5 seconds...',
|
||||||
|
uuid: 'error-' + Date.now()
|
||||||
|
});
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
this.eventSource = null;
|
this.eventSource = null;
|
||||||
this.createEventSource();
|
this.createEventSource();
|
||||||
|
|||||||
@@ -17,6 +17,13 @@
|
|||||||
|
|
||||||
<v-list-subheader>{{ tm('system.title') }}</v-list-subheader>
|
<v-list-subheader>{{ tm('system.title') }}</v-list-subheader>
|
||||||
|
|
||||||
|
<v-list-item :subtitle="tm('system.backup.subtitle')" :title="tm('system.backup.title')">
|
||||||
|
<v-btn style="margin-top: 16px;" color="primary" @click="openBackupDialog">
|
||||||
|
<v-icon class="mr-2">mdi-backup-restore</v-icon>
|
||||||
|
{{ tm('system.backup.button') }}
|
||||||
|
</v-btn>
|
||||||
|
</v-list-item>
|
||||||
|
|
||||||
<v-list-item :subtitle="tm('system.restart.subtitle')" :title="tm('system.restart.title')">
|
<v-list-item :subtitle="tm('system.restart.subtitle')" :title="tm('system.restart.title')">
|
||||||
<v-btn style="margin-top: 16px;" color="error" @click="restartAstrBot">{{ tm('system.restart.button') }}</v-btn>
|
<v-btn style="margin-top: 16px;" color="error" @click="restartAstrBot">{{ tm('system.restart.button') }}</v-btn>
|
||||||
</v-list-item>
|
</v-list-item>
|
||||||
@@ -30,6 +37,7 @@
|
|||||||
|
|
||||||
<WaitingForRestart ref="wfr"></WaitingForRestart>
|
<WaitingForRestart ref="wfr"></WaitingForRestart>
|
||||||
<MigrationDialog ref="migrationDialog"></MigrationDialog>
|
<MigrationDialog ref="migrationDialog"></MigrationDialog>
|
||||||
|
<BackupDialog ref="backupDialog"></BackupDialog>
|
||||||
|
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
@@ -40,12 +48,14 @@ import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
|
|||||||
import ProxySelector from '@/components/shared/ProxySelector.vue';
|
import ProxySelector from '@/components/shared/ProxySelector.vue';
|
||||||
import MigrationDialog from '@/components/shared/MigrationDialog.vue';
|
import MigrationDialog from '@/components/shared/MigrationDialog.vue';
|
||||||
import SidebarCustomizer from '@/components/shared/SidebarCustomizer.vue';
|
import SidebarCustomizer from '@/components/shared/SidebarCustomizer.vue';
|
||||||
|
import BackupDialog from '@/components/shared/BackupDialog.vue';
|
||||||
import { useModuleI18n } from '@/i18n/composables';
|
import { useModuleI18n } from '@/i18n/composables';
|
||||||
|
|
||||||
const { tm } = useModuleI18n('features/settings');
|
const { tm } = useModuleI18n('features/settings');
|
||||||
|
|
||||||
const wfr = ref(null);
|
const wfr = ref(null);
|
||||||
const migrationDialog = ref(null);
|
const migrationDialog = ref(null);
|
||||||
|
const backupDialog = ref(null);
|
||||||
|
|
||||||
const restartAstrBot = () => {
|
const restartAstrBot = () => {
|
||||||
axios.post('/api/stat/restart-core').then(() => {
|
axios.post('/api/stat/restart-core').then(() => {
|
||||||
@@ -65,4 +75,10 @@ const startMigration = async () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const openBackupDialog = () => {
|
||||||
|
if (backupDialog.value) {
|
||||||
|
backupDialog.value.open();
|
||||||
|
}
|
||||||
|
}
|
||||||
</script>
|
</script>
|
||||||
@@ -19,6 +19,7 @@ export default defineConfig({
|
|||||||
],
|
],
|
||||||
resolve: {
|
resolve: {
|
||||||
alias: {
|
alias: {
|
||||||
|
mermaid: 'mermaid/dist/mermaid.js',
|
||||||
'@': fileURLToPath(new URL('./src', import.meta.url))
|
'@': fileURLToPath(new URL('./src', import.meta.url))
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
+2
-2
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "AstrBot"
|
name = "AstrBot"
|
||||||
version = "4.10.2"
|
version = "4.10.3"
|
||||||
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
@@ -103,7 +103,7 @@ typeCheckingMode = "basic"
|
|||||||
pythonVersion = "3.10"
|
pythonVersion = "3.10"
|
||||||
reportMissingTypeStubs = false
|
reportMissingTypeStubs = false
|
||||||
reportMissingImports = false
|
reportMissingImports = false
|
||||||
include = ["astrbot", "packages"]
|
include = ["astrbot"]
|
||||||
exclude = ["dashboard", "node_modules", "dist", "data", "tests"]
|
exclude = ["dashboard", "node_modules", "dist", "data", "tests"]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|||||||
@@ -0,0 +1,760 @@
|
|||||||
|
"""备份功能单元测试"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import zipfile
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from astrbot.core.backup import (
|
||||||
|
BACKUP_MANIFEST_VERSION,
|
||||||
|
KB_METADATA_MODELS,
|
||||||
|
MAIN_DB_MODELS,
|
||||||
|
ImportPreCheckResult,
|
||||||
|
)
|
||||||
|
from astrbot.core.backup.exporter import AstrBotExporter
|
||||||
|
from astrbot.core.backup.importer import (
|
||||||
|
AstrBotImporter,
|
||||||
|
ImportResult,
|
||||||
|
_get_major_version,
|
||||||
|
)
|
||||||
|
from astrbot.core.config.default import VERSION
|
||||||
|
from astrbot.core.db.po import (
|
||||||
|
ConversationV2,
|
||||||
|
)
|
||||||
|
from astrbot.core.utils.version_comparator import VersionComparator
|
||||||
|
from astrbot.dashboard.routes.backup import (
|
||||||
|
generate_unique_filename,
|
||||||
|
secure_filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_backup_dir(tmp_path):
|
||||||
|
"""创建临时备份目录"""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
backup_dir.mkdir()
|
||||||
|
return backup_dir
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_data_dir(tmp_path):
|
||||||
|
"""创建临时数据目录"""
|
||||||
|
data_dir = tmp_path / "data"
|
||||||
|
data_dir.mkdir()
|
||||||
|
|
||||||
|
# 创建配置文件
|
||||||
|
config_path = data_dir / "cmd_config.json"
|
||||||
|
config_path.write_text(json.dumps({"test": "config"}))
|
||||||
|
|
||||||
|
# 创建附件目录
|
||||||
|
attachments_dir = data_dir / "attachments"
|
||||||
|
attachments_dir.mkdir()
|
||||||
|
|
||||||
|
return data_dir
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_main_db():
|
||||||
|
"""创建模拟的主数据库"""
|
||||||
|
db = MagicMock()
|
||||||
|
|
||||||
|
# 模拟异步上下文管理器
|
||||||
|
session = AsyncMock()
|
||||||
|
db.get_db = MagicMock(
|
||||||
|
return_value=AsyncMock(__aenter__=AsyncMock(return_value=session))
|
||||||
|
)
|
||||||
|
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_kb_manager():
|
||||||
|
"""创建模拟的知识库管理器"""
|
||||||
|
kb_manager = MagicMock()
|
||||||
|
kb_manager.kb_insts = {}
|
||||||
|
|
||||||
|
# 模拟 kb_db
|
||||||
|
kb_db = MagicMock()
|
||||||
|
session = AsyncMock()
|
||||||
|
kb_db.get_db = MagicMock(
|
||||||
|
return_value=AsyncMock(__aenter__=AsyncMock(return_value=session))
|
||||||
|
)
|
||||||
|
kb_manager.kb_db = kb_db
|
||||||
|
|
||||||
|
return kb_manager
|
||||||
|
|
||||||
|
|
||||||
|
class TestImportResult:
|
||||||
|
"""ImportResult 类测试"""
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
"""测试初始化"""
|
||||||
|
result = ImportResult()
|
||||||
|
assert result.success is True
|
||||||
|
assert result.imported_tables == {}
|
||||||
|
assert result.imported_files == {}
|
||||||
|
assert result.warnings == []
|
||||||
|
assert result.errors == []
|
||||||
|
|
||||||
|
def test_add_warning(self):
|
||||||
|
"""测试添加警告"""
|
||||||
|
result = ImportResult()
|
||||||
|
result.add_warning("test warning")
|
||||||
|
assert "test warning" in result.warnings
|
||||||
|
assert result.success is True # 警告不影响成功状态
|
||||||
|
|
||||||
|
def test_add_error(self):
|
||||||
|
"""测试添加错误"""
|
||||||
|
result = ImportResult()
|
||||||
|
result.add_error("test error")
|
||||||
|
assert "test error" in result.errors
|
||||||
|
assert result.success is False # 错误会导致失败
|
||||||
|
|
||||||
|
def test_to_dict(self):
|
||||||
|
"""测试转换为字典"""
|
||||||
|
result = ImportResult()
|
||||||
|
result.imported_tables = {"test_table": 10}
|
||||||
|
result.add_warning("warning")
|
||||||
|
|
||||||
|
d = result.to_dict()
|
||||||
|
assert d["success"] is True
|
||||||
|
assert d["imported_tables"] == {"test_table": 10}
|
||||||
|
assert "warning" in d["warnings"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAstrBotExporter:
|
||||||
|
"""AstrBotExporter 类测试"""
|
||||||
|
|
||||||
|
def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir):
|
||||||
|
"""测试初始化"""
|
||||||
|
exporter = AstrBotExporter(
|
||||||
|
main_db=mock_main_db,
|
||||||
|
kb_manager=mock_kb_manager,
|
||||||
|
config_path=str(temp_data_dir / "cmd_config.json"),
|
||||||
|
)
|
||||||
|
assert exporter.main_db is mock_main_db
|
||||||
|
assert exporter.kb_manager is mock_kb_manager
|
||||||
|
|
||||||
|
def test_model_to_dict_with_model_dump(self):
|
||||||
|
"""测试 _model_to_dict 使用 model_dump 方法"""
|
||||||
|
exporter = AstrBotExporter(main_db=MagicMock())
|
||||||
|
|
||||||
|
# 创建一个有 model_dump 方法的模拟对象
|
||||||
|
mock_record = MagicMock()
|
||||||
|
mock_record.model_dump.return_value = {"id": 1, "name": "test"}
|
||||||
|
|
||||||
|
result = exporter._model_to_dict(mock_record)
|
||||||
|
assert result == {"id": 1, "name": "test"}
|
||||||
|
|
||||||
|
def test_model_to_dict_with_datetime(self):
|
||||||
|
"""测试 _model_to_dict 处理 datetime 字段"""
|
||||||
|
exporter = AstrBotExporter(main_db=MagicMock())
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
mock_record = MagicMock()
|
||||||
|
mock_record.model_dump.return_value = {"id": 1, "created_at": now}
|
||||||
|
|
||||||
|
result = exporter._model_to_dict(mock_record)
|
||||||
|
assert result["created_at"] == now.isoformat()
|
||||||
|
|
||||||
|
def test_add_checksum(self):
|
||||||
|
"""测试添加校验和"""
|
||||||
|
exporter = AstrBotExporter(main_db=MagicMock())
|
||||||
|
|
||||||
|
exporter._add_checksum("test.json", '{"test": "data"}')
|
||||||
|
|
||||||
|
assert "test.json" in exporter._checksums
|
||||||
|
assert exporter._checksums["test.json"].startswith("sha256:")
|
||||||
|
|
||||||
|
def test_generate_manifest(self, mock_main_db, mock_kb_manager):
|
||||||
|
"""测试生成清单"""
|
||||||
|
exporter = AstrBotExporter(
|
||||||
|
main_db=mock_main_db,
|
||||||
|
kb_manager=mock_kb_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
main_data = {
|
||||||
|
"platform_stats": [{"id": 1}],
|
||||||
|
"conversations": [],
|
||||||
|
"attachments": [],
|
||||||
|
}
|
||||||
|
kb_meta_data = {
|
||||||
|
"knowledge_bases": [],
|
||||||
|
"kb_documents": [],
|
||||||
|
}
|
||||||
|
dir_stats = {
|
||||||
|
"plugins": {"files": 10, "size": 1024},
|
||||||
|
"plugin_data": {"files": 5, "size": 512},
|
||||||
|
}
|
||||||
|
|
||||||
|
manifest = exporter._generate_manifest(main_data, kb_meta_data, dir_stats)
|
||||||
|
|
||||||
|
assert manifest["version"] == BACKUP_MANIFEST_VERSION
|
||||||
|
assert manifest["astrbot_version"] == VERSION
|
||||||
|
assert "exported_at" in manifest
|
||||||
|
assert "tables" in manifest
|
||||||
|
assert "statistics" in manifest
|
||||||
|
assert "directories" in manifest
|
||||||
|
assert manifest["statistics"]["main_db"]["platform_stats"] == 1
|
||||||
|
assert manifest["statistics"]["directories"] == dir_stats
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_export_all_creates_zip(
|
||||||
|
self, mock_main_db, temp_backup_dir, temp_data_dir
|
||||||
|
):
|
||||||
|
"""测试导出创建 ZIP 文件"""
|
||||||
|
# 设置模拟数据库返回空数据
|
||||||
|
session = AsyncMock()
|
||||||
|
result = MagicMock()
|
||||||
|
result.scalars.return_value.all.return_value = []
|
||||||
|
session.execute = AsyncMock(return_value=result)
|
||||||
|
|
||||||
|
mock_main_db.get_db.return_value = AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=session),
|
||||||
|
__aexit__=AsyncMock(return_value=None),
|
||||||
|
)
|
||||||
|
|
||||||
|
exporter = AstrBotExporter(
|
||||||
|
main_db=mock_main_db,
|
||||||
|
kb_manager=None,
|
||||||
|
config_path=str(temp_data_dir / "cmd_config.json"),
|
||||||
|
)
|
||||||
|
|
||||||
|
zip_path = await exporter.export_all(output_dir=str(temp_backup_dir))
|
||||||
|
|
||||||
|
assert os.path.exists(zip_path)
|
||||||
|
assert zip_path.endswith(".zip")
|
||||||
|
assert "astrbot_backup_" in zip_path
|
||||||
|
|
||||||
|
# 验证 ZIP 文件内容
|
||||||
|
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||||
|
namelist = zf.namelist()
|
||||||
|
assert "manifest.json" in namelist
|
||||||
|
assert "databases/main_db.json" in namelist
|
||||||
|
assert "config/cmd_config.json" in namelist
|
||||||
|
|
||||||
|
|
||||||
|
class TestAstrBotImporter:
|
||||||
|
"""AstrBotImporter 类测试"""
|
||||||
|
|
||||||
|
def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir):
|
||||||
|
"""测试初始化"""
|
||||||
|
importer = AstrBotImporter(
|
||||||
|
main_db=mock_main_db,
|
||||||
|
kb_manager=mock_kb_manager,
|
||||||
|
config_path=str(temp_data_dir / "cmd_config.json"),
|
||||||
|
)
|
||||||
|
assert importer.main_db is mock_main_db
|
||||||
|
assert importer.kb_manager is mock_kb_manager
|
||||||
|
|
||||||
|
def test_validate_version_match(self):
|
||||||
|
"""测试版本匹配验证"""
|
||||||
|
importer = AstrBotImporter(main_db=MagicMock())
|
||||||
|
|
||||||
|
manifest = {"astrbot_version": VERSION}
|
||||||
|
# 不应该抛出异常
|
||||||
|
importer._validate_version(manifest)
|
||||||
|
|
||||||
|
def test_validate_version_major_diff_rejected(self):
|
||||||
|
"""测试主版本不同被拒绝"""
|
||||||
|
importer = AstrBotImporter(main_db=MagicMock())
|
||||||
|
|
||||||
|
# 使用一个明显不同的主版本
|
||||||
|
manifest = {"astrbot_version": "0.0.1"}
|
||||||
|
with pytest.raises(ValueError, match="主版本不兼容"):
|
||||||
|
importer._validate_version(manifest)
|
||||||
|
|
||||||
|
def test_validate_version_minor_diff_allowed(self):
|
||||||
|
"""测试小版本不同被允许(仅警告)"""
|
||||||
|
importer = AstrBotImporter(main_db=MagicMock())
|
||||||
|
|
||||||
|
# 获取当前主版本
|
||||||
|
major_version = _get_major_version(VERSION)
|
||||||
|
# 构造一个同主版本但小版本不同的版本
|
||||||
|
minor_diff_version = f"{major_version}.999"
|
||||||
|
manifest = {"astrbot_version": minor_diff_version}
|
||||||
|
# 不应该抛出异常
|
||||||
|
importer._validate_version(manifest)
|
||||||
|
|
||||||
|
def test_validate_version_missing(self):
|
||||||
|
"""测试缺少版本信息"""
|
||||||
|
importer = AstrBotImporter(main_db=MagicMock())
|
||||||
|
|
||||||
|
manifest = {}
|
||||||
|
with pytest.raises(ValueError, match="缺少版本信息"):
|
||||||
|
importer._validate_version(manifest)
|
||||||
|
|
||||||
|
def test_convert_datetime_fields(self):
|
||||||
|
"""测试 datetime 字段转换"""
|
||||||
|
importer = AstrBotImporter(main_db=MagicMock())
|
||||||
|
|
||||||
|
# 使用 ConversationV2 作为测试模型(它有 created_at 和 updated_at 字段)
|
||||||
|
row = {
|
||||||
|
"conversation_id": "test-123",
|
||||||
|
"platform_id": "test",
|
||||||
|
"user_id": "user1",
|
||||||
|
"created_at": "2024-01-01T12:00:00",
|
||||||
|
"updated_at": "2024-01-01T12:00:00",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = importer._convert_datetime_fields(row, ConversationV2)
|
||||||
|
|
||||||
|
# created_at 应该被转换为 datetime 对象
|
||||||
|
assert isinstance(result["created_at"], datetime)
|
||||||
|
assert isinstance(result["updated_at"], datetime)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_import_file_not_exists(self, mock_main_db, tmp_path):
|
||||||
|
"""测试导入不存在的文件"""
|
||||||
|
importer = AstrBotImporter(main_db=mock_main_db)
|
||||||
|
|
||||||
|
result = await importer.import_all(str(tmp_path / "nonexistent.zip"))
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert any("不存在" in err for err in result.errors)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_import_invalid_zip(self, mock_main_db, tmp_path):
|
||||||
|
"""测试导入无效的 ZIP 文件"""
|
||||||
|
# 创建一个无效的文件
|
||||||
|
invalid_zip = tmp_path / "invalid.zip"
|
||||||
|
invalid_zip.write_text("not a zip file")
|
||||||
|
|
||||||
|
importer = AstrBotImporter(main_db=mock_main_db)
|
||||||
|
result = await importer.import_all(str(invalid_zip))
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert any("无效" in err or "ZIP" in err for err in result.errors)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_import_missing_manifest(self, mock_main_db, tmp_path):
|
||||||
|
"""测试导入缺少 manifest 的 ZIP 文件"""
|
||||||
|
# 创建一个没有 manifest 的 ZIP 文件
|
||||||
|
zip_path = tmp_path / "no_manifest.zip"
|
||||||
|
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||||
|
zf.writestr("test.txt", "test content")
|
||||||
|
|
||||||
|
importer = AstrBotImporter(main_db=mock_main_db)
|
||||||
|
result = await importer.import_all(str(zip_path))
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert any("manifest" in err.lower() for err in result.errors)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_import_major_version_mismatch(self, mock_main_db, tmp_path):
|
||||||
|
"""测试导入主版本不匹配的备份"""
|
||||||
|
# 创建一个主版本不匹配的备份
|
||||||
|
zip_path = tmp_path / "old_version.zip"
|
||||||
|
manifest = {
|
||||||
|
"version": "1.0",
|
||||||
|
"astrbot_version": "0.0.1", # 主版本不同
|
||||||
|
"tables": {"main_db": []},
|
||||||
|
}
|
||||||
|
|
||||||
|
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||||
|
zf.writestr("manifest.json", json.dumps(manifest))
|
||||||
|
|
||||||
|
importer = AstrBotImporter(main_db=mock_main_db)
|
||||||
|
result = await importer.import_all(str(zip_path))
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert any("主版本不兼容" in err for err in result.errors)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSecureFilename:
|
||||||
|
"""安全文件名函数测试"""
|
||||||
|
|
||||||
|
def test_secure_filename_normal(self):
|
||||||
|
"""测试正常文件名"""
|
||||||
|
assert secure_filename("backup.zip") == "backup.zip"
|
||||||
|
assert secure_filename("my_backup_2024.zip") == "my_backup_2024.zip"
|
||||||
|
|
||||||
|
def test_secure_filename_path_traversal(self):
|
||||||
|
"""测试路径遍历攻击"""
|
||||||
|
assert ".." not in secure_filename("../../../etc/passwd")
|
||||||
|
assert "/" not in secure_filename("/etc/passwd")
|
||||||
|
assert "\\" not in secure_filename("..\\..\\windows\\system32")
|
||||||
|
|
||||||
|
def test_secure_filename_with_path(self):
|
||||||
|
"""测试带路径的文件名"""
|
||||||
|
result = secure_filename("/path/to/backup.zip")
|
||||||
|
assert result == "backup.zip"
|
||||||
|
|
||||||
|
result = secure_filename("C:\\Users\\test\\backup.zip")
|
||||||
|
assert result == "backup.zip"
|
||||||
|
|
||||||
|
def test_secure_filename_special_chars(self):
|
||||||
|
"""测试特殊字符"""
|
||||||
|
result = secure_filename('backup<>:"|?*.zip')
|
||||||
|
# 特殊字符应被替换为下划线
|
||||||
|
assert "<" not in result
|
||||||
|
assert ">" not in result
|
||||||
|
assert ":" not in result
|
||||||
|
assert '"' not in result
|
||||||
|
assert "|" not in result
|
||||||
|
assert "?" not in result
|
||||||
|
assert "*" not in result
|
||||||
|
|
||||||
|
def test_secure_filename_hidden_file(self):
|
||||||
|
"""测试隐藏文件(前导点)"""
|
||||||
|
result = secure_filename(".hidden_backup.zip")
|
||||||
|
assert not result.startswith(".")
|
||||||
|
|
||||||
|
def test_secure_filename_empty(self):
|
||||||
|
"""测试空文件名"""
|
||||||
|
assert secure_filename("") == "backup"
|
||||||
|
assert secure_filename("...") == "backup"
|
||||||
|
|
||||||
|
def test_generate_unique_filename(self):
|
||||||
|
"""测试生成唯一文件名"""
|
||||||
|
result = generate_unique_filename("backup.zip")
|
||||||
|
# 应包含 uploaded_ 前缀和时间戳
|
||||||
|
assert result.startswith("uploaded_")
|
||||||
|
assert result.endswith("_backup.zip")
|
||||||
|
# 应包含时间戳格式 YYYYMMDD_HHMMSS
|
||||||
|
assert re.search(r"uploaded_\d{8}_\d{6}_backup\.zip", result)
|
||||||
|
|
||||||
|
|
||||||
|
class TestVersionComparison:
|
||||||
|
"""版本比较函数测试 - 使用 VersionComparator"""
|
||||||
|
|
||||||
|
def test_get_major_version_simple(self):
|
||||||
|
"""测试提取简单主版本号"""
|
||||||
|
assert _get_major_version("1.0") == "1.0"
|
||||||
|
assert _get_major_version("2.1") == "2.1"
|
||||||
|
assert _get_major_version("4.9.1") == "4.9"
|
||||||
|
|
||||||
|
def test_get_major_version_with_prefix(self):
|
||||||
|
"""测试带 v 前缀的版本号"""
|
||||||
|
assert _get_major_version("v1.0") == "1.0"
|
||||||
|
assert _get_major_version("V4.9.1") == "4.9"
|
||||||
|
|
||||||
|
def test_get_major_version_with_prerelease(self):
|
||||||
|
"""测试带预发布标签的版本号"""
|
||||||
|
assert _get_major_version("4.9.1-beta") == "4.9"
|
||||||
|
assert _get_major_version("4.9.1-alpha.1") == "4.9"
|
||||||
|
assert _get_major_version("4.9.1+build123") == "4.9"
|
||||||
|
|
||||||
|
def test_get_major_version_single_part(self):
|
||||||
|
"""测试单部分版本号"""
|
||||||
|
assert _get_major_version("1") == "1.0"
|
||||||
|
|
||||||
|
def test_get_major_version_empty(self):
|
||||||
|
"""测试空版本号"""
|
||||||
|
assert _get_major_version("") == "0.0"
|
||||||
|
|
||||||
|
def test_compare_versions_equal(self):
|
||||||
|
"""测试版本相等"""
|
||||||
|
assert VersionComparator.compare_version("1.0", "1.0") == 0
|
||||||
|
assert VersionComparator.compare_version("1.0.0", "1.0") == 0
|
||||||
|
assert VersionComparator.compare_version("2.10", "2.10") == 0
|
||||||
|
|
||||||
|
def test_compare_versions_less_than(self):
|
||||||
|
"""测试版本小于"""
|
||||||
|
assert VersionComparator.compare_version("1.0", "1.1") == -1
|
||||||
|
assert (
|
||||||
|
VersionComparator.compare_version("1.9", "1.10") == -1
|
||||||
|
) # 关键测试:多位数版本比较
|
||||||
|
assert VersionComparator.compare_version("1.2", "1.10") == -1
|
||||||
|
assert VersionComparator.compare_version("1.0", "2.0") == -1
|
||||||
|
|
||||||
|
def test_compare_versions_greater_than(self):
|
||||||
|
"""测试版本大于"""
|
||||||
|
assert VersionComparator.compare_version("1.1", "1.0") == 1
|
||||||
|
assert (
|
||||||
|
VersionComparator.compare_version("1.10", "1.9") == 1
|
||||||
|
) # 关键测试:多位数版本比较
|
||||||
|
assert VersionComparator.compare_version("1.10", "1.2") == 1
|
||||||
|
assert VersionComparator.compare_version("2.0", "1.0") == 1
|
||||||
|
|
||||||
|
def test_compare_versions_different_lengths(self):
|
||||||
|
"""测试不同长度版本比较"""
|
||||||
|
assert VersionComparator.compare_version("1.0", "1.0.0") == 0
|
||||||
|
assert VersionComparator.compare_version("1.0", "1.0.1") == -1
|
||||||
|
assert VersionComparator.compare_version("1.0.1", "1.0") == 1
|
||||||
|
|
||||||
|
def test_compare_versions_prerelease(self):
|
||||||
|
"""测试预发布版本比较"""
|
||||||
|
# 预发布版本低于正式版本
|
||||||
|
assert VersionComparator.compare_version("1.0.0-alpha", "1.0.0") == -1
|
||||||
|
assert VersionComparator.compare_version("1.0.0", "1.0.0-beta") == 1
|
||||||
|
# alpha < beta
|
||||||
|
assert VersionComparator.compare_version("1.0.0-alpha", "1.0.0-beta") == -1
|
||||||
|
|
||||||
|
|
||||||
|
class TestImportPreCheckResult:
|
||||||
|
"""ImportPreCheckResult 类测试"""
|
||||||
|
|
||||||
|
def test_init_default_values(self):
|
||||||
|
"""测试默认值初始化"""
|
||||||
|
result = ImportPreCheckResult()
|
||||||
|
assert result.valid is False
|
||||||
|
assert result.can_import is False
|
||||||
|
assert result.version_status == ""
|
||||||
|
assert result.backup_version == ""
|
||||||
|
assert result.current_version == VERSION
|
||||||
|
assert result.confirm_message == ""
|
||||||
|
assert result.warnings == []
|
||||||
|
assert result.error == ""
|
||||||
|
assert result.backup_summary == {}
|
||||||
|
|
||||||
|
def test_to_dict(self):
|
||||||
|
"""测试转换为字典"""
|
||||||
|
result = ImportPreCheckResult(
|
||||||
|
valid=True,
|
||||||
|
can_import=True,
|
||||||
|
version_status="match",
|
||||||
|
backup_version="4.9.0",
|
||||||
|
confirm_message="确认导入?",
|
||||||
|
warnings=["警告1"],
|
||||||
|
backup_summary={"tables": ["table1"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
d = result.to_dict()
|
||||||
|
assert d["valid"] is True
|
||||||
|
assert d["can_import"] is True
|
||||||
|
assert d["version_status"] == "match"
|
||||||
|
assert d["backup_version"] == "4.9.0"
|
||||||
|
assert d["confirm_message"] == "确认导入?"
|
||||||
|
assert "警告1" in d["warnings"]
|
||||||
|
assert d["backup_summary"]["tables"] == ["table1"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestPreCheck:
|
||||||
|
"""预检查功能测试"""
|
||||||
|
|
||||||
|
def test_pre_check_file_not_exists(self, mock_main_db):
|
||||||
|
"""测试预检查不存在的文件"""
|
||||||
|
importer = AstrBotImporter(main_db=mock_main_db)
|
||||||
|
result = importer.pre_check("/nonexistent/file.zip")
|
||||||
|
|
||||||
|
assert result.valid is False
|
||||||
|
assert "不存在" in result.error
|
||||||
|
|
||||||
|
def test_pre_check_invalid_zip(self, mock_main_db, tmp_path):
|
||||||
|
"""测试预检查无效的 ZIP 文件"""
|
||||||
|
invalid_zip = tmp_path / "invalid.zip"
|
||||||
|
invalid_zip.write_text("not a zip file")
|
||||||
|
|
||||||
|
importer = AstrBotImporter(main_db=mock_main_db)
|
||||||
|
result = importer.pre_check(str(invalid_zip))
|
||||||
|
|
||||||
|
assert result.valid is False
|
||||||
|
assert "ZIP" in result.error or "无效" in result.error
|
||||||
|
|
||||||
|
def test_pre_check_missing_manifest(self, mock_main_db, tmp_path):
|
||||||
|
"""测试预检查缺少 manifest 的 ZIP 文件"""
|
||||||
|
zip_path = tmp_path / "no_manifest.zip"
|
||||||
|
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||||
|
zf.writestr("test.txt", "test content")
|
||||||
|
|
||||||
|
importer = AstrBotImporter(main_db=mock_main_db)
|
||||||
|
result = importer.pre_check(str(zip_path))
|
||||||
|
|
||||||
|
assert result.valid is False
|
||||||
|
assert "manifest" in result.error.lower()
|
||||||
|
|
||||||
|
def test_pre_check_version_match(self, mock_main_db, tmp_path):
|
||||||
|
"""测试预检查版本匹配"""
|
||||||
|
zip_path = tmp_path / "backup.zip"
|
||||||
|
manifest = {
|
||||||
|
"version": "1.1",
|
||||||
|
"astrbot_version": VERSION,
|
||||||
|
"created_at": "2024-01-01T12:00:00",
|
||||||
|
"tables": {"platform_stats": 1},
|
||||||
|
"has_knowledge_bases": True,
|
||||||
|
"has_config": True,
|
||||||
|
"directories": ["plugins"],
|
||||||
|
}
|
||||||
|
|
||||||
|
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||||
|
zf.writestr("manifest.json", json.dumps(manifest))
|
||||||
|
|
||||||
|
importer = AstrBotImporter(main_db=mock_main_db)
|
||||||
|
result = importer.pre_check(str(zip_path))
|
||||||
|
|
||||||
|
assert result.valid is True
|
||||||
|
assert result.can_import is True
|
||||||
|
assert result.version_status == "match"
|
||||||
|
assert result.backup_version == VERSION
|
||||||
|
# confirm_message 现在由前端生成,后端不再生成
|
||||||
|
assert result.backup_summary["has_knowledge_bases"] is True
|
||||||
|
|
||||||
|
def test_pre_check_minor_version_diff(self, mock_main_db, tmp_path):
|
||||||
|
"""测试预检查小版本差异"""
|
||||||
|
# 构造一个同主版本但小版本不同的版本
|
||||||
|
major_version = _get_major_version(VERSION)
|
||||||
|
minor_diff_version = f"{major_version}.999"
|
||||||
|
|
||||||
|
zip_path = tmp_path / "backup.zip"
|
||||||
|
manifest = {
|
||||||
|
"version": "1.1",
|
||||||
|
"astrbot_version": minor_diff_version,
|
||||||
|
"created_at": "2024-01-01T12:00:00",
|
||||||
|
"tables": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||||
|
zf.writestr("manifest.json", json.dumps(manifest))
|
||||||
|
|
||||||
|
importer = AstrBotImporter(main_db=mock_main_db)
|
||||||
|
result = importer.pre_check(str(zip_path))
|
||||||
|
|
||||||
|
assert result.valid is True
|
||||||
|
assert result.can_import is True
|
||||||
|
assert result.version_status == "minor_diff"
|
||||||
|
# 版本消息由前端 i18n 生成,后端 warnings 列表不再包含版本相关消息
|
||||||
|
# warnings 列表保留用于其他非版本相关的警告
|
||||||
|
|
||||||
|
def test_pre_check_major_version_diff(self, mock_main_db, tmp_path):
|
||||||
|
"""测试预检查主版本差异"""
|
||||||
|
zip_path = tmp_path / "backup.zip"
|
||||||
|
manifest = {
|
||||||
|
"version": "1.1",
|
||||||
|
"astrbot_version": "0.0.1", # 主版本不同
|
||||||
|
"created_at": "2024-01-01T12:00:00",
|
||||||
|
"tables": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||||
|
zf.writestr("manifest.json", json.dumps(manifest))
|
||||||
|
|
||||||
|
importer = AstrBotImporter(main_db=mock_main_db)
|
||||||
|
result = importer.pre_check(str(zip_path))
|
||||||
|
|
||||||
|
assert result.valid is True # 文件有效
|
||||||
|
assert result.can_import is False # 但不能导入
|
||||||
|
assert result.version_status == "major_diff"
|
||||||
|
# 版本消息由前端 i18n 生成,后端 warnings 列表不再包含版本相关消息
|
||||||
|
|
||||||
|
|
||||||
|
class TestVersionCompatibility:
|
||||||
|
"""版本兼容性检查测试"""
|
||||||
|
|
||||||
|
def test_check_version_compatibility_match(self, mock_main_db):
|
||||||
|
"""测试版本完全匹配"""
|
||||||
|
importer = AstrBotImporter(main_db=mock_main_db)
|
||||||
|
result = importer._check_version_compatibility(VERSION)
|
||||||
|
|
||||||
|
assert result["status"] == "match"
|
||||||
|
assert result["can_import"] is True
|
||||||
|
|
||||||
|
def test_check_version_compatibility_minor_diff(self, mock_main_db):
|
||||||
|
"""测试小版本差异"""
|
||||||
|
major_version = _get_major_version(VERSION)
|
||||||
|
minor_diff_version = f"{major_version}.999"
|
||||||
|
|
||||||
|
importer = AstrBotImporter(main_db=mock_main_db)
|
||||||
|
result = importer._check_version_compatibility(minor_diff_version)
|
||||||
|
|
||||||
|
assert result["status"] == "minor_diff"
|
||||||
|
assert result["can_import"] is True
|
||||||
|
|
||||||
|
def test_check_version_compatibility_major_diff(self, mock_main_db):
|
||||||
|
"""测试主版本差异"""
|
||||||
|
importer = AstrBotImporter(main_db=mock_main_db)
|
||||||
|
result = importer._check_version_compatibility("0.0.1")
|
||||||
|
|
||||||
|
assert result["status"] == "major_diff"
|
||||||
|
assert result["can_import"] is False
|
||||||
|
|
||||||
|
def test_check_version_compatibility_empty_version(self, mock_main_db):
|
||||||
|
"""测试空版本号"""
|
||||||
|
importer = AstrBotImporter(main_db=mock_main_db)
|
||||||
|
result = importer._check_version_compatibility("")
|
||||||
|
|
||||||
|
assert result["status"] == "major_diff"
|
||||||
|
assert result["can_import"] is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelMappings:
|
||||||
|
"""测试模型映射配置"""
|
||||||
|
|
||||||
|
def test_main_db_models_not_empty(self):
|
||||||
|
"""测试主数据库模型映射非空"""
|
||||||
|
assert len(MAIN_DB_MODELS) > 0
|
||||||
|
|
||||||
|
def test_main_db_models_contain_expected_tables(self):
|
||||||
|
"""测试主数据库模型映射包含预期的表"""
|
||||||
|
expected_tables = [
|
||||||
|
"platform_stats",
|
||||||
|
"conversations",
|
||||||
|
"personas",
|
||||||
|
"preferences",
|
||||||
|
"attachments",
|
||||||
|
]
|
||||||
|
for table in expected_tables:
|
||||||
|
assert table in MAIN_DB_MODELS, f"Missing table: {table}"
|
||||||
|
|
||||||
|
def test_kb_metadata_models_not_empty(self):
|
||||||
|
"""测试知识库元数据模型映射非空"""
|
||||||
|
assert len(KB_METADATA_MODELS) > 0
|
||||||
|
|
||||||
|
def test_kb_metadata_models_contain_expected_tables(self):
|
||||||
|
"""测试知识库元数据模型映射包含预期的表"""
|
||||||
|
expected_tables = [
|
||||||
|
"knowledge_bases",
|
||||||
|
"kb_documents",
|
||||||
|
"kb_media",
|
||||||
|
]
|
||||||
|
for table in expected_tables:
|
||||||
|
assert table in KB_METADATA_MODELS, f"Missing table: {table}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackupIntegration:
|
||||||
|
"""备份集成测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_export_import_roundtrip(self, tmp_path):
|
||||||
|
"""测试导出-导入往返"""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
backup_dir.mkdir()
|
||||||
|
|
||||||
|
data_dir = tmp_path / "data"
|
||||||
|
data_dir.mkdir()
|
||||||
|
|
||||||
|
config_path = data_dir / "cmd_config.json"
|
||||||
|
config_path.write_text(json.dumps({"setting": "value"}))
|
||||||
|
|
||||||
|
attachments_dir = data_dir / "attachments"
|
||||||
|
attachments_dir.mkdir()
|
||||||
|
|
||||||
|
# 创建模拟数据库
|
||||||
|
mock_db = MagicMock()
|
||||||
|
session = AsyncMock()
|
||||||
|
result = MagicMock()
|
||||||
|
result.scalars.return_value.all.return_value = []
|
||||||
|
session.execute = AsyncMock(return_value=result)
|
||||||
|
|
||||||
|
mock_db.get_db.return_value = AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=session),
|
||||||
|
__aexit__=AsyncMock(return_value=None),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 导出
|
||||||
|
exporter = AstrBotExporter(
|
||||||
|
main_db=mock_db,
|
||||||
|
kb_manager=None,
|
||||||
|
config_path=str(config_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
zip_path = await exporter.export_all(output_dir=str(backup_dir))
|
||||||
|
assert os.path.exists(zip_path)
|
||||||
|
|
||||||
|
# 验证 ZIP 内容
|
||||||
|
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||||
|
# 读取 manifest
|
||||||
|
manifest = json.loads(zf.read("manifest.json"))
|
||||||
|
assert manifest["astrbot_version"] == VERSION
|
||||||
|
|
||||||
|
# 读取配置
|
||||||
|
config = json.loads(zf.read("config/cmd_config.json"))
|
||||||
|
assert config["setting"] == "value"
|
||||||
|
|
||||||
|
# 读取主数据库
|
||||||
|
main_db = json.loads(zf.read("databases/main_db.json"))
|
||||||
|
assert "platform_stats" in main_db
|
||||||
Reference in New Issue
Block a user