diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 59e61d73b..a9b1fafd5 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -7,27 +7,27 @@ from astrbot.core.utils.pip_installer import PipInstaller from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.config.default import DB_PATH from astrbot.core.config import AstrBotConfig +from astrbot.core.file_token_service import FileTokenService # 初始化数据存储文件夹 os.makedirs("data", exist_ok=True) +WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool" +DEMO_MODE = os.getenv("DEMO_MODE", False) + astrbot_config = AstrBotConfig() t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img") html_renderer = HtmlRenderer(t2i_base_url) logger = LogManager.GetLogger(log_name="astrbot") - -if os.environ.get("TESTING", ""): - logger.setLevel("DEBUG") - db_helper = SQLiteDatabase(DB_PATH) -sp = ( - SharedPreferences() -) # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中 +# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中 +sp = SharedPreferences() +# 文件令牌服务 +file_token_service = FileTokenService() pip_installer = PipInstaller( astrbot_config.get("pip_install_arg", ""), astrbot_config.get("pypi_index_url", None), ) web_chat_queue = asyncio.Queue(maxsize=32) web_chat_back_queue = asyncio.Queue(maxsize=32) -WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool" -DEMO_MODE = os.getenv("DEMO_MODE", False) + diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 0aaf2a6d6..341d105a8 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -104,6 +104,7 @@ DEFAULT_CONFIG = { "knowledge_db": {}, "persona": [], "timezone": "", + "callback_api_base": "", } @@ -1283,6 +1284,12 @@ CONFIG_METADATA_2 = { "obvious_hint": True, "hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab", }, + "callback_api_base": { + "description": "对外可达的回调接口地址", + "type": "string", + "obvious_hint": True, + "hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址(host),因此需要通过此配置项显式指定 “外部服务如何访问 AstrBot” 的地址。如 http://localhost:6185,https://example.com 等。" + }, "log_level": { "description": "控制台日志级别", "type": "string", diff --git a/astrbot/core/file_token_service.py b/astrbot/core/file_token_service.py new file mode 100644 index 000000000..2ed46d433 --- /dev/null +++ b/astrbot/core/file_token_service.py @@ -0,0 +1,68 @@ +import asyncio +import os +import uuid +import time + + +class FileTokenService: + """维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。""" + + def __init__(self, default_timeout: float = 300): + self.lock = asyncio.Lock() + self.staged_files = {} # token: (file_path, expire_time) + self.default_timeout = default_timeout + + async def _cleanup_expired_tokens(self): + """清理过期的令牌""" + now = time.time() + expired_tokens = [token for token, (_, expire) in self.staged_files.items() if expire < now] + for token in expired_tokens: + self.staged_files.pop(token, None) + + async def register_file(self, file_path: str, timeout: float = None) -> str: + """向令牌服务注册一个文件。 + + Args: + file_path(str): 文件路径 + timeout(float): 超时时间,单位秒(可选) + + Returns: + str: 一个单次令牌 + + Raises: + FileNotFoundError: 当路径不存在时抛出 + """ + async with self.lock: + await self._cleanup_expired_tokens() + + if not os.path.exists(file_path): + raise FileNotFoundError(f"文件不存在: {file_path}") + + file_token = str(uuid.uuid4()) + expire_time = time.time() + (timeout if timeout is not None else self.default_timeout) + self.staged_files[file_token] = (file_path, expire_time) + return file_token + + async def handle_file(self, file_token: str) -> str: + """根据令牌获取文件路径,使用后令牌失效。 + + Args: + file_token(str): 注册时返回的令牌 + + Returns: + str: 文件路径 + + Raises: + KeyError: 当令牌不存在或已过期时抛出 + FileNotFoundError: 当文件本身已被删除时抛出 + """ + async with self.lock: + await self._cleanup_expired_tokens() + + if file_token not in self.staged_files: + raise KeyError(f"无效或过期的文件 token: {file_token}") + + file_path, _ = self.staged_files.pop(file_token) + if not os.path.exists(file_path): + raise FileNotFoundError(f"文件不存在: {file_path}") + return file_path diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 73bfa7279..718fd30fb 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -462,10 +462,10 @@ class Node(BaseMessageComponent): type: ComponentType = "Node" id: T.Optional[int] = 0 # 忽略 name: T.Optional[str] = "" # qq昵称 - uin: T.Optional[int] = 0 # qq号 + uin: T.Optional[str] = "0" # qq号 content: T.Optional[T.Union[str, list, dict]] = "" # 子消息段列表 seq: T.Optional[T.Union[str, list]] = "" # 忽略 - time: T.Optional[int] = 0 + time: T.Optional[int] = 0 # 忽略 def __init__(self, content: T.Union[str, list, dict, "Node", T.List["Node"]], **_): if isinstance(content, list): @@ -494,8 +494,14 @@ class Nodes(BaseMessageComponent): super().__init__(nodes=nodes, **_) def toDict(self): - return {"messages": [node.toDict() for node in self.nodes]} - + ret = { + "messages": [], + } + for node in self.nodes: + d = node.toDict() + d["data"]["uin"] = str(node.uin) # 转为字符串 + ret["messages"].append(d) + return ret class Xml(BaseMessageComponent): type: ComponentType = "Xml" @@ -561,10 +567,9 @@ class File(BaseMessageComponent): name: T.Optional[str] = "" # 名字 file_: T.Optional[str] = "" # 本地路径 url: T.Optional[str] = "" # url - _downloaded: bool = False # 是否已经下载 - def __init__(self, name: str, file: str, url: str = ""): - """文件消息段。一般情况下请直接使用 file 参数即可,可以传入文件路径或 URL,AstrBot 会自动识别。""" + def __init__(self, name: str, file: str = "", url: str = ""): + """文件消息段。""" super().__init__(name=name, file_=file, url=url) @property @@ -576,22 +581,24 @@ class File(BaseMessageComponent): str: 文件路径 """ if self.file_ and os.path.exists(self.file_): - return self.file_ + return os.path.abspath(self.file_) - if self.url and not self._downloaded: + if self.url: try: loop = asyncio.get_event_loop() if loop.is_running(): - logger.warning( - "不可以在异步上下文中同步等待下载! 请使用 await get_file() 代替" - ) + logger.warning(( + "不可以在异步上下文中同步等待下载! " + "这个警告通常发生于某些逻辑试图通过 .file 获取文件消息段的文件内容。" + "请使用 await get_file() 代替直接获取 .file 字段" + )) return "" else: # 等待下载完成 loop.run_until_complete(self._download_file()) if self.file_ and os.path.exists(self.file_): - return self.file_ + return os.path.abspath(self.file_) except Exception as e: logger.error(f"文件下载失败: {e}") @@ -610,36 +617,31 @@ class File(BaseMessageComponent): else: self.file_ = value - async def get_file(self) -> str: - """ - 异步获取文件 - To 插件开发者: 请注意在使用后清理下载的文件, 以免占用过多空间 + async def get_file(self, allow_return_url: bool=False) -> str: + """异步获取文件。请注意在使用后清理下载的文件, 以免占用过多空间 + Args: + allow_return_url: 是否允许以文件 http 下载链接的形式返回,这允许您自行控制是否需要下载文件。 + 注意,如果为 True,也可能返回文件路径。 Returns: - str: 文件路径 + str: 文件路径或者 http 下载链接 """ if self.file_ and os.path.exists(self.file_): - return self.file_ + return os.path.abspath(self.file_) if self.url: await self._download_file() - return self.file_ + return os.path.abspath(self.file_) return "" async def _download_file(self): """下载文件""" - if self._downloaded: - return - - os.makedirs("data/download", exist_ok=True) + os.makedirs("data/temp", exist_ok=True) filename = self.name or f"{uuid.uuid4().hex}" - file_path = f"data/download/{filename}" - + file_path = f"data/temp/{filename}" await download_file(self.url, file_path) - - self.file_ = file_path - self._downloaded = True + self.file_ = os.path.abspath(file_path) class WechatEmoji(BaseMessageComponent): diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 776f4a625..bff94a64d 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -26,33 +26,14 @@ class RespondStage(Stage): Comp.Record: lambda comp: bool(comp.file), # 语音 Comp.Video: lambda comp: bool(comp.file), # 视频 Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @ - Comp.AtAll: lambda comp: True, # @所有人 - Comp.RPS: lambda comp: True, # 不知道是啥(未完成) - Comp.Dice: lambda comp: True, # 骰子(未完成) - Comp.Shake: lambda comp: True, # 摇一摇(未完成) - Comp.Anonymous: lambda comp: True, # 匿名(未完成) - Comp.Share: lambda comp: bool(comp.url) and bool(comp.title), # 分享 - Comp.Contact: lambda comp: True, # 联系人(未完成) - Comp.Location: lambda comp: bool(comp.lat and comp.lon), # 位置 - Comp.Music: lambda comp: bool(comp._type) - and bool(comp.url) - and bool(comp.audio), # 音乐 Comp.Image: lambda comp: bool(comp.file), # 图片 Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复 - Comp.RedBag: lambda comp: bool(comp.title), # 红包 Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳 - Comp.Forward: lambda comp: bool(comp.id and comp.id.strip()), # 转发 Comp.Node: lambda comp: bool(comp.name) and comp.uin != 0 and bool(comp.content), # 一个转发节点 Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点 - Comp.Xml: lambda comp: bool(comp.data and comp.data.strip()), # XML - Comp.Json: lambda comp: bool(comp.data), # JSON - Comp.CardImage: lambda comp: bool(comp.file), # 卡片图片 - Comp.TTS: lambda comp: bool(comp.text and comp.text.strip()), # 语音合成 - Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), # 未知消息 - Comp.File: lambda comp: bool(comp.file), # 文件 - Comp.WechatEmoji: lambda comp: bool(comp.md5), # 微信表情 + Comp.File: lambda comp: bool(comp.file_ or comp.url), } async def initialize(self, ctx: PipelineContext): @@ -129,8 +110,6 @@ class RespondStage(Stage): if comp_type in self._component_validators: if self._component_validators[comp_type](comp): return False - else: - logger.info(f"空内容检查: 无法识别的组件类型: {comp_type.__name__}") # 如果所有组件都为空 return True diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 4acb677dd..068a8bf3c 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -3,8 +3,9 @@ import re from typing import AsyncGenerator, Dict, List from aiocqhttp import CQHttp from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.message_components import At, Image, Node, Nodes, Plain, Record +from astrbot.api.message_components import At, Image, Node, Nodes, Plain, Record, File from astrbot.api.platform import Group, MessageMember +from astrbot.core import file_token_service, astrbot_config, logger class AiocqhttpMessageEvent(AstrMessageEvent): @@ -34,24 +35,16 @@ class AiocqhttpMessageEvent(AstrMessageEvent): } elif isinstance(segment, At): d["data"] = { - "qq": str(segment.qq) # 转换为字符串 + "qq": str(segment.qq), # 转换为字符串 } ret.append(d) return ret async def send(self, message: MessageChain): - ret = await AiocqhttpMessageEvent._parse_onebot_json(message) - - if not ret: - return - - send_one_by_one = False - for seg in message.chain: - if isinstance(seg, (Node, Nodes)): - # 转发消息不能和普通消息混在一起发送 - send_one_by_one = True - break - + # 转发消息、文件消息不能和普通消息混在一起发送 + send_one_by_one = any( + isinstance(seg, (Node, Nodes, File)) for seg in message.chain + ) if send_one_by_one: for seg in message.chain: if isinstance(seg, (Node, Nodes)): @@ -70,6 +63,26 @@ class AiocqhttpMessageEvent(AstrMessageEvent): await self.bot.call_action( "send_private_forward_msg", **payload ) + elif isinstance(seg, File): + d = seg.toDict() + url_or_path = await seg.get_file(allow_return_url=True) + if url_or_path.startswith("http"): + payload_file = url_or_path + elif callback_host := astrbot_config.get("callback_api_base"): + callback_host = str(callback_host).removesuffix("/") + token = await file_token_service.register_file(url_or_path) + payload_file = f"{callback_host}/api/file/{token}" + logger.debug(f"Generated file callback link: {payload_file}") + else: + payload_file = url_or_path + d["data"] = { + "name": seg.name, + "file": payload_file, + } + await self.bot.send( + self.message_obj.raw_message, + [d], + ) else: await self.bot.send( self.message_obj.raw_message, @@ -79,6 +92,9 @@ class AiocqhttpMessageEvent(AstrMessageEvent): ) await asyncio.sleep(0.5) else: + ret = await AiocqhttpMessageEvent._parse_onebot_json(message) + if not ret: + return await self.bot.send(self.message_obj.raw_message, ret) await super().send(message) diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py index 3e24583ed..f9309c3eb 100644 --- a/astrbot/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -8,6 +8,7 @@ from .static_file import StaticFileRoute from .chat import ChatRoute from .tools import ToolsRoute # 导入新的ToolsRoute from .conversation import ConversationRoute +from .file import FileRoute __all__ = [ @@ -19,6 +20,7 @@ __all__ = [ "LogRoute", "StaticFileRoute", "ChatRoute", - "ToolsRoute", # 添加新的ToolsRoute + "ToolsRoute", "ConversationRoute", + "FileRoute", ] diff --git a/astrbot/dashboard/routes/file.py b/astrbot/dashboard/routes/file.py new file mode 100644 index 000000000..8ea73d084 --- /dev/null +++ b/astrbot/dashboard/routes/file.py @@ -0,0 +1,24 @@ +from .route import Route, RouteContext +from astrbot import logger +from quart import abort, send_file +from astrbot.core import file_token_service + + +class FileRoute(Route): + def __init__( + self, + context: RouteContext, + ) -> None: + super().__init__(context) + self.routes = { + "/file/": ("GET", self.serve_file), + } + self.register_routes() + + async def serve_file(self, file_token: str): + try: + file_path = await file_token_service.handle_file(file_token) + return await send_file(file_path) + except (FileNotFoundError, KeyError) as e: + logger.warning(str(e)) + return abort(404) diff --git a/astrbot/dashboard/routes/static_file.py b/astrbot/dashboard/routes/static_file.py index 4503a28e5..729fe8547 100644 --- a/astrbot/dashboard/routes/static_file.py +++ b/astrbot/dashboard/routes/static_file.py @@ -28,7 +28,7 @@ class StaticFileRoute(Route): @self.app.errorhandler(404) async def page_not_found(e): - return "404 Not found。如果你初次使用打开面板发现 404, 请参考文档: https://astrbot.app/faq.html。" + return "404 Not found。如果你初次使用打开面板发现 404, 请参考文档: https://astrbot.app/faq.html。如果你正在测试回调地址可达性,显示这段文字说明测试成功了。" async def index(self): return await self.app.send_static_file("index.html") diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 5d1310807..c85ada4e2 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -52,15 +52,15 @@ class AstrBotDashboard: self.chat_route = ChatRoute(self.context, db, core_lifecycle) self.tools_root = ToolsRoute(self.context, core_lifecycle) self.conversation_route = ConversationRoute(self.context, db, core_lifecycle) + self.file_route = FileRoute(self.context) self.shutdown_event = shutdown_event async def auth_middleware(self): if not request.path.startswith("/api"): return - if request.path == "/api/auth/login": - return - if request.path == "/api/chat/get_file": + allowed_endpoints = ["/api/auth/login", "/api/chat/get_file", "/api/file"] + if any(request.path.startswith(prefix) for prefix in allowed_endpoints): return # claim jwt token = request.headers.get("Authorization")