@@ -14,22 +18,17 @@
-

-

-

-

-

-

+

+

+

+

+

+

+
-
中文 |
-
English |
-
日本語 |
-
繁體中文 |
-
Français
-
Документация |
Блог |
Дорожная карта |
@@ -42,13 +41,32 @@ AstrBot — это универсальная платформа Agent-чатб
## Основные возможности
-1. 💯 Бесплатно и с открытым исходным кодом.
-2. ✨ ИИ-диалоги с LLM, мультимодальность, Agent, MCP, база знаний, настройки личности.
-3. 🤖 Поддержка интеграции с Dify, Alibaba Cloud Bailian, Coze и другими платформами агентов.
-4. 🌐 Мультиплатформенность: QQ, WeChat Work, Feishu, DingTalk, официальные аккаунты WeChat, Telegram, Slack и [другие](#поддерживаемые-платформы-обмена-сообщениями).
-5. 📦 Расширения плагинов с почти 800 плагинами, доступными для установки в один клик.
-6. 💻 Поддержка WebUI.
-7. 🌐 Поддержка интернационализации (i18n).
+1. 💯 Бесплатно & Открытый исходный код.
+2. ✨ Диалоги с ИИ-моделями, мультимодальность, Agent, MCP, Skills, База знаний, Настройка личности, автоматическое сжатие диалогов.
+3. 🤖 Поддержка интеграции с платформами Agents, такими как Dify, Alibaba Cloud Bailian, Coze и др.
+4. 🌐 Мультиплатформенность: поддержка QQ, WeChat для предприятий, Feishu, DingTalk, публичных аккаунтов WeChat, Telegram, Slack и [других](#Поддерживаемые-платформы-обмена-сообщениями).
+5. 📦 Расширение плагинами: доступно почти 800 плагинов для установки в один клик.
+6. 🛡️ Изолированная среда[Agent Sandbox](https://docs.astrbot.app/use/astrbot-agent-sandbox.html): безопасное выполнение любого кода, вызов Shell, повторное использование ресурсов на уровне сессии.
+7. 💻 Поддержка WebUI.
+8. 🌈 Поддержка Web ChatUI: встроенная песочница агента, веб-поиск и др.
+9. 🌐 Поддержка интернационализации (i18n).
+
+
+
+
+
+ | 💙 Ролевые игры & Эмоциональная поддержка |
+ ✨ Проактивный Агент(Agent) |
+ 🚀 Универсальные Агентные возможности |
+ 🧩 Универсальные Агентные (Agentic) возможности |
+
+
+ 
|
+ 
|
+ 
|
+ 
|
+
+
## Быстрый старт
@@ -61,7 +79,8 @@ AstrBot — это универсальная платформа Agent-чатб
#### Развёртывание uv
```bash
-uvx astrbot
+uv tool install astrbot
+astrbot
```
#### Развёртывание BT-Panel
@@ -115,6 +134,16 @@ uv run main.py
Или см. официальную документацию: [Развёртывание AstrBot из исходного кода](https://astrbot.app/deploy/astrbot/cli.html).
+#### Установка через системный пакетный менеджер
+
+##### Arch Linux
+
+```bash
+yay -S astrbot-git
+# или используйте paru
+paru -S astrbot-git
+```
+
## Поддерживаемые платформы обмена сообщениями
**Официально поддерживаемые**
@@ -153,7 +182,7 @@ uv run main.py
- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
- [302.AI](https://share.302.ai/rr1M3l)
- [TokenPony](https://www.tokenpony.cn/3YPyf)
-- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
+- [SiliconFlow](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
- ModelScope
- OneAPI
@@ -235,13 +264,19 @@ pre-commit install
> [!TIP]
> Если этот проект помог вам в жизни или работе, или если вас интересует его будущее развитие, пожалуйста, поставьте проекту звезду. Это движущая сила поддержки этого проекта с открытым исходным кодом <3
+
[](https://star-history.com/#astrbotdevs/astrbot&Date)
-
+
+
+_Сопровождение и способности никогда не должны быть противоположностями. Мы стремимся создать робота, который сможет как понимать эмоции, оказывать душевную поддержку, так и надёжно выполнять работу._
_私は、高性能ですから!_
+

+
+
diff --git a/README_zh-TW.md b/README_zh-TW.md
index c6df22ea2..7232d8cc7 100644
--- a/README_zh-TW.md
+++ b/README_zh-TW.md
@@ -1,9 +1,13 @@

-
-
+
简体中文 |
+
English |
+
日本語 |
+
Français |
+
Русский
+
@@ -14,22 +18,17 @@
-

-

-

-

-

-

+

+

+

+

+

+

+
-
简体中文 |
-
English |
-
日本語 |
-
Français |
-
Русский
-
文件 |
Blog |
路線圖 |
@@ -43,12 +42,31 @@ AstrBot 是一個開源的一站式 Agent 聊天機器人平台,可接入主
## 主要功能
1. 💯 免費 & 開源。
-2. ✨ AI 大型模型對話,多模態,Agent,MCP,知識庫,人格設定。
-3. 🤖 支援接入 Dify、阿里雲百煉、Coze 等智慧體平台。
-4. 🌐 多平台:QQ、企業微信、飛書、釘釘、微信公眾號、Telegram、Slack 以及[更多](#支援的訊息平台)。
-5. 📦 外掛擴充,已有近 800 個外掛可一鍵安裝。
-6. 💻 WebUI 支援。
-7. 🌐 國際化(i18n)支援。
+2. ✨ AI 大模型對話,多模態,Agent,MCP,Skills,知識庫,人格設定,自動壓縮對話。
+3. 🤖 支援接入 Dify、阿里雲百煉、Coze 等智慧體 (Agent) 平台。
+4. 🌐 多平台,支援 QQ、企業微信、飛書、釘釘、微信公眾號、Telegram、Slack 以及[更多](#支援的訊息平台)。
+5. 📦 插件擴展,已有近 800 個插件可一鍵安裝。
+6. 🛡️ [Agent Sandbox](https://docs.astrbot.app/use/astrbot-agent-sandbox.html) 隔離化環境,安全地執行任何代碼、調用 Shell、會話級資源複用。
+7. 💻 WebUI 支援。
+8. 🌈 Web ChatUI 支援,ChatUI 內置代理沙盒 (Agent Sandbox)、網頁搜尋等。
+9. 🌐 國際化(i18n)支援。
+
+
+
+
+
+ | 💙 角色扮演 & 情感陪伴 |
+ ✨ 主動式 Agent |
+ 🚀 通用 Agentic 能力 |
+ 🧩 900+ 社區外掛程式 |
+
+
+ 
|
+ 
|
+ 
|
+ 
|
+
+
## 快速開始
@@ -61,7 +79,8 @@ AstrBot 是一個開源的一站式 Agent 聊天機器人平台,可接入主
#### uv 部署
```bash
-uvx astrbot
+uv tool install astrbot
+astrbot
```
#### 寶塔面板部署
@@ -115,6 +134,16 @@ uv run main.py
或者請參閱官方文件 [透過原始碼部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html)。
+#### 系統套件管理員安裝
+
+##### Arch Linux
+
+```bash
+yay -S astrbot-git
+# 或者使用 paru
+paru -S astrbot-git
+```
+
## 支援的訊息平台
**官方維護**
@@ -241,7 +270,12 @@ pre-commit install
-
+
+
+_陪伴與能力從來不應該是對立面。我們希望創造的是一個既能理解情緒、給予陪伴,也能可靠完成工作的機器人。_
_私は、高性能ですから!_
+

+
+
diff --git a/astrbot/cli/__init__.py b/astrbot/cli/__init__.py
index 54ddc8fdd..10b26dff5 100644
--- a/astrbot/cli/__init__.py
+++ b/astrbot/cli/__init__.py
@@ -1 +1 @@
-__version__ = "4.15.0"
+__version__ = "4.17.3"
diff --git a/astrbot/core/agent/runners/dify/dify_agent_runner.py b/astrbot/core/agent/runners/dify/dify_agent_runner.py
index d9a8b7cd6..93f8d3570 100644
--- a/astrbot/core/agent/runners/dify/dify_agent_runner.py
+++ b/astrbot/core/agent/runners/dify/dify_agent_runner.py
@@ -10,7 +10,7 @@ from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_file
from ...hooks import BaseAgentRunHooks
@@ -291,8 +291,8 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
return Comp.Image(file=item["url"], url=item["url"])
case "audio":
# 仅支持 wav
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
- path = os.path.join(temp_dir, f"{item['filename']}.wav")
+ temp_dir = get_astrbot_temp_path()
+ path = os.path.join(temp_dir, f"dify_{item['filename']}.wav")
await download_file(item["url"], path)
return Comp.Image(file=item["url"], url=item["url"])
case "video":
diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py
index 8fb01bfcb..8309e6674 100644
--- a/astrbot/core/agent/runners/tool_loop_agent_runner.py
+++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py
@@ -91,6 +91,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
custom_token_counter: TokenCounter | None = None,
custom_compressor: ContextCompressor | None = None,
tool_schema_mode: str | None = "full",
+ fallback_providers: list[Provider] | None = None,
**kwargs: T.Any,
) -> None:
self.req = request
@@ -120,6 +121,17 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.context_manager = ContextManager(self.context_config)
self.provider = provider
+ self.fallback_providers: list[Provider] = []
+ seen_provider_ids: set[str] = {str(provider.provider_config.get("id", ""))}
+ for fallback_provider in fallback_providers or []:
+ fallback_id = str(fallback_provider.provider_config.get("id", ""))
+ if fallback_provider is provider:
+ continue
+ if fallback_id and fallback_id in seen_provider_ids:
+ continue
+ self.fallback_providers.append(fallback_provider)
+ if fallback_id:
+ seen_provider_ids.add(fallback_id)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.tool_executor = tool_executor
@@ -166,16 +178,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.stats = AgentStats()
self.stats.start_time = time.time()
- async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
+ async def _iter_llm_responses(
+ self, *, include_model: bool = True
+ ) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
payload = {
"contexts": self.run_context.messages, # list[Message]
"func_tool": self.req.func_tool,
- "model": self.req.model, # NOTE: in fact, this arg is None in most cases
"session_id": self.req.session_id,
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
}
-
+ if include_model:
+ # For primary provider we keep explicit model selection if provided.
+ payload["model"] = self.req.model
if self.streaming:
stream = self.provider.text_chat_stream(**payload)
async for resp in stream: # type: ignore
@@ -183,6 +198,83 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
else:
yield await self.provider.text_chat(**payload)
+ async def _iter_llm_responses_with_fallback(
+ self,
+ ) -> T.AsyncGenerator[LLMResponse, None]:
+ """Wrap _iter_llm_responses with provider fallback handling."""
+ candidates = [self.provider, *self.fallback_providers]
+ total_candidates = len(candidates)
+ last_exception: Exception | None = None
+ last_err_response: LLMResponse | None = None
+
+ for idx, candidate in enumerate(candidates):
+ candidate_id = candidate.provider_config.get("id", "
")
+ is_last_candidate = idx == total_candidates - 1
+ if idx > 0:
+ logger.warning(
+ "Switched from %s to fallback chat provider: %s",
+ self.provider.provider_config.get("id", ""),
+ candidate_id,
+ )
+ self.provider = candidate
+ has_stream_output = False
+ try:
+ async for resp in self._iter_llm_responses(include_model=idx == 0):
+ if resp.is_chunk:
+ has_stream_output = True
+ yield resp
+ continue
+
+ if (
+ resp.role == "err"
+ and not has_stream_output
+ and (not is_last_candidate)
+ ):
+ last_err_response = resp
+ logger.warning(
+ "Chat Model %s returns error response, trying fallback to next provider.",
+ candidate_id,
+ )
+ break
+
+ yield resp
+ return
+
+ if has_stream_output:
+ return
+ except Exception as exc: # noqa: BLE001
+ last_exception = exc
+ logger.warning(
+ "Chat Model %s request error: %s",
+ candidate_id,
+ exc,
+ exc_info=True,
+ )
+ continue
+
+ if last_err_response:
+ yield last_err_response
+ return
+ if last_exception:
+ yield LLMResponse(
+ role="err",
+ completion_text=(
+ "All chat models failed: "
+ f"{type(last_exception).__name__}: {last_exception}"
+ ),
+ )
+ return
+ yield LLMResponse(
+ role="err",
+ completion_text="All available chat models are unavailable.",
+ )
+
+ def _simple_print_message_role(self, tag: str = ""):
+ roles = []
+ for message in self.run_context.messages:
+ roles.append(message.role)
+ logger.debug(f"{tag} RunCtx.messages -> [{len(roles)}] {','.join(roles)}")
+
@override
async def step(self):
"""Process a single step of the agent.
@@ -203,11 +295,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# do truncate and compress
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
+ self._simple_print_message_role("[BefCompact]")
self.run_context.messages = await self.context_manager.process(
self.run_context.messages, trusted_token_usage=token_usage
)
+ self._simple_print_message_role("[AftCompact]")
- async for llm_response in self._iter_llm_responses():
+ async for llm_response in self._iter_llm_responses_with_fallback():
if llm_response.is_chunk:
# update ttft
if self.stats.time_to_first_token == 0:
diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py
index 6a35c042f..d0054b3f6 100644
--- a/astrbot/core/astr_main_agent.py
+++ b/astrbot/core/astr_main_agent.py
@@ -56,6 +56,7 @@ from astrbot.core.message.components import File, Image, Reply
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderRequest
+from astrbot.core.provider.manager import llm_tools
from astrbot.core.skills.skill_manager import SkillManager, build_skills_prompt
from astrbot.core.star.context import Context
from astrbot.core.star.star_handler import star_map
@@ -66,6 +67,17 @@ from astrbot.core.tools.cron_tools import (
)
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.llm_metadata import LLM_METADATAS
+from astrbot.core.utils.quoted_message.settings import (
+ SETTINGS as DEFAULT_QUOTED_MESSAGE_SETTINGS,
+)
+from astrbot.core.utils.quoted_message.settings import (
+ QuotedMessageParserSettings,
+)
+from astrbot.core.utils.quoted_message_parser import (
+ extract_quoted_message_images,
+ extract_quoted_message_text,
+)
+from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
@dataclass(slots=True)
@@ -122,6 +134,8 @@ class MainAgentBuildConfig:
provider_settings: dict = field(default_factory=dict)
subagent_orchestrator: dict = field(default_factory=dict)
timezone: str | None = None
+ max_quoted_fallback_images: int = 20
+ """Maximum number of images injected from quoted-message fallback extraction."""
@dataclass(slots=True)
@@ -484,11 +498,29 @@ async def _ensure_img_caption(
logger.error("处理图片描述失败: %s", exc)
+def _append_quoted_image_attachment(req: ProviderRequest, image_path: str) -> None:
+ req.extra_user_content_parts.append(
+ TextPart(text=f"[Image Attachment in quoted message: path {image_path}]")
+ )
+
+
+def _get_quoted_message_parser_settings(
+ provider_settings: dict[str, object] | None,
+) -> QuotedMessageParserSettings:
+ if not isinstance(provider_settings, dict):
+ return DEFAULT_QUOTED_MESSAGE_SETTINGS
+ overrides = provider_settings.get("quoted_message_parser")
+ if not isinstance(overrides, dict):
+ return DEFAULT_QUOTED_MESSAGE_SETTINGS
+ return DEFAULT_QUOTED_MESSAGE_SETTINGS.with_overrides(overrides)
+
+
async def _process_quote_message(
event: AstrMessageEvent,
req: ProviderRequest,
img_cap_prov_id: str,
plugin_context: Context,
+ quoted_message_settings: QuotedMessageParserSettings = DEFAULT_QUOTED_MESSAGE_SETTINGS,
) -> None:
quote = None
for comp in event.message_obj.message:
@@ -500,7 +532,15 @@ async def _process_quote_message(
content_parts = []
sender_info = f"({quote.sender_nickname}): " if quote.sender_nickname else ""
- message_str = quote.message_str or "[Empty Text]"
+ message_str = (
+ await extract_quoted_message_text(
+ event,
+ quote,
+ settings=quoted_message_settings,
+ )
+ or quote.message_str
+ or "[Empty Text]"
+ )
content_parts.append(f"{sender_info}{message_str}")
image_seg = None
@@ -606,11 +646,13 @@ async def _decorate_llm_request(
)
img_cap_prov_id = cfg.get("default_image_caption_provider_id") or ""
+ quoted_message_settings = _get_quoted_message_parser_settings(cfg)
await _process_quote_message(
event,
req,
img_cap_prov_id,
plugin_context,
+ quoted_message_settings,
)
tz = config.timezone
@@ -742,6 +784,14 @@ def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
if plugin.name in event.plugins_name or plugin.reserved:
new_tool_set.add_tool(tool)
req.func_tool = new_tool_set
+ else:
+ # mcp tools
+ tool_set = req.func_tool
+ if not tool_set:
+ tool_set = ToolSet()
+ for tool in llm_tools.func_list:
+ if isinstance(tool, MCPTool):
+ tool_set.add_tool(tool)
async def _handle_webchat(
@@ -863,6 +913,41 @@ def _get_compress_provider(
return provider
+def _get_fallback_chat_providers(
+ provider: Provider, plugin_context: Context, provider_settings: dict
+) -> list[Provider]:
+ fallback_ids = provider_settings.get("fallback_chat_models", [])
+ if not isinstance(fallback_ids, list):
+ logger.warning(
+ "fallback_chat_models setting is not a list, skip fallback providers."
+ )
+ return []
+
+ provider_id = str(provider.provider_config.get("id", ""))
+ seen_provider_ids: set[str] = {provider_id} if provider_id else set()
+ fallbacks: list[Provider] = []
+
+ for fallback_id in fallback_ids:
+ if not isinstance(fallback_id, str) or not fallback_id:
+ continue
+ if fallback_id in seen_provider_ids:
+ continue
+ fallback_provider = plugin_context.get_provider_by_id(fallback_id)
+ if fallback_provider is None:
+ logger.warning("Fallback chat provider `%s` not found, skip.", fallback_id)
+ continue
+ if not isinstance(fallback_provider, Provider):
+ logger.warning(
+ "Fallback chat provider `%s` is invalid type: %s, skip.",
+ fallback_id,
+ type(fallback_provider),
+ )
+ continue
+ fallbacks.append(fallback_provider)
+ seen_provider_ids.add(fallback_id)
+ return fallbacks
+
+
async def build_main_agent(
*,
event: AstrMessageEvent,
@@ -901,6 +986,8 @@ async def build_main_agent(
return None
req.prompt = event.message_str[len(config.provider_wake_prefix) :]
+
+ # media files attachments
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
@@ -916,6 +1003,81 @@ async def build_main_agent(
text=f"[File Attachment: name {file_name}, path {file_path}]"
)
)
+ # quoted message attachments
+ reply_comps = [
+ comp for comp in event.message_obj.message if isinstance(comp, Reply)
+ ]
+ quoted_message_settings = _get_quoted_message_parser_settings(
+ config.provider_settings
+ )
+ fallback_quoted_image_count = 0
+ for comp in reply_comps:
+ has_embedded_image = False
+ if comp.chain:
+ for reply_comp in comp.chain:
+ if isinstance(reply_comp, Image):
+ has_embedded_image = True
+ image_path = await reply_comp.convert_to_file_path()
+ req.image_urls.append(image_path)
+ _append_quoted_image_attachment(req, image_path)
+ elif isinstance(reply_comp, File):
+ file_path = await reply_comp.get_file()
+ file_name = reply_comp.name or os.path.basename(file_path)
+ req.extra_user_content_parts.append(
+ TextPart(
+ text=(
+ f"[File Attachment in quoted message: "
+ f"name {file_name}, path {file_path}]"
+ )
+ )
+ )
+
+ # Fallback quoted image extraction for reply-id-only payloads, or when
+ # embedded reply chain only contains placeholders (e.g. [Forward Message], [Image]).
+ if not has_embedded_image:
+ try:
+ fallback_images = normalize_and_dedupe_strings(
+ await extract_quoted_message_images(
+ event,
+ comp,
+ settings=quoted_message_settings,
+ )
+ )
+ remaining_limit = max(
+ config.max_quoted_fallback_images
+ - fallback_quoted_image_count,
+ 0,
+ )
+ if remaining_limit <= 0 and fallback_images:
+ logger.warning(
+ "Skip quoted fallback images due to limit=%d for umo=%s",
+ config.max_quoted_fallback_images,
+ event.unified_msg_origin,
+ )
+ continue
+ if len(fallback_images) > remaining_limit:
+ logger.warning(
+ "Truncate quoted fallback images for umo=%s, reply_id=%s from %d to %d",
+ event.unified_msg_origin,
+ getattr(comp, "id", None),
+ len(fallback_images),
+ remaining_limit,
+ )
+ fallback_images = fallback_images[:remaining_limit]
+ for image_ref in fallback_images:
+ if image_ref in req.image_urls:
+ continue
+ req.image_urls.append(image_ref)
+ fallback_quoted_image_count += 1
+ _append_quoted_image_attachment(req, image_ref)
+ except Exception as exc: # noqa: BLE001
+ logger.warning(
+ "Failed to resolve fallback quoted images for umo=%s, reply_id=%s: %s",
+ event.unified_msg_origin,
+ getattr(comp, "id", None),
+ exc,
+ exc_info=True,
+ )
conversation = await _get_session_conv(event, plugin_context)
req.conversation = conversation
@@ -924,6 +1086,7 @@ async def build_main_agent(
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
+ req.image_urls = normalize_and_dedupe_strings(req.image_urls)
if config.file_extract_enabled:
try:
@@ -1008,6 +1171,9 @@ async def build_main_agent(
truncate_turns=config.dequeue_context_length,
enforce_max_turns=config.max_context_length,
tool_schema_mode=config.tool_schema_mode,
+ fallback_providers=_get_fallback_chat_providers(
+ provider, plugin_context, config.provider_settings
+ ),
)
if apply_reset:
diff --git a/astrbot/core/astr_main_agent_resources.py b/astrbot/core/astr_main_agent_resources.py
index e04291faf..2e0d8b0aa 100644
--- a/astrbot/core/astr_main_agent_resources.py
+++ b/astrbot/core/astr_main_agent_resources.py
@@ -1,6 +1,7 @@
import base64
import json
import os
+import uuid
from pydantic import Field
from pydantic.dataclasses import dataclass
@@ -254,7 +255,9 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
if "_&exists_" in json.dumps(result):
# Download the file from sandbox
name = os.path.basename(path)
- local_path = os.path.join(get_astrbot_temp_path(), name)
+ local_path = os.path.join(
+ get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}"
+ )
await sb.download_file(path, local_path)
logger.info(f"Downloaded file from sandbox: {path} -> {local_path}")
return local_path, True
@@ -366,11 +369,11 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
MessageChain(chain=components),
)
- if file_from_sandbox:
- try:
- os.remove(local_path)
- except Exception as e:
- logger.error(f"Error removing temp file {local_path}: {e}")
+ # if file_from_sandbox:
+ # try:
+ # os.remove(local_path)
+ # except Exception as e:
+ # logger.error(f"Error removing temp file {local_path}: {e}")
return f"Message sent to session {target_session}"
diff --git a/astrbot/core/backup/constants.py b/astrbot/core/backup/constants.py
index b45b702e7..be206b307 100644
--- a/astrbot/core/backup/constants.py
+++ b/astrbot/core/backup/constants.py
@@ -11,6 +11,7 @@ from astrbot.core.db.po import (
CommandConflict,
ConversationV2,
Persona,
+ PersonaFolder,
PlatformMessageHistory,
PlatformSession,
PlatformStat,
@@ -39,6 +40,7 @@ MAIN_DB_MODELS: dict[str, type[SQLModel]] = {
"platform_stats": PlatformStat,
"conversations": ConversationV2,
"personas": Persona,
+ "persona_folders": PersonaFolder,
"preferences": Preference,
"platform_message_history": PlatformMessageHistory,
"platform_sessions": PlatformSession,
diff --git a/astrbot/core/computer/tools/fs.py b/astrbot/core/computer/tools/fs.py
index 9cf590a61..126da4258 100644
--- a/astrbot/core/computer/tools/fs.py
+++ b/astrbot/core/computer/tools/fs.py
@@ -1,4 +1,5 @@
import os
+import uuid
from dataclasses import dataclass, field
from astrbot.api import FunctionTool, logger
@@ -167,7 +168,9 @@ class FileDownloadTool(FunctionTool):
try:
name = os.path.basename(remote_path)
- local_path = os.path.join(get_astrbot_temp_path(), name)
+ local_path = os.path.join(
+ get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}"
+ )
# Download file from sandbox
await sb.download_file(remote_path, local_path)
@@ -183,12 +186,12 @@ class FileDownloadTool(FunctionTool):
logger.error(f"Error sending file message: {e}")
# remove
- try:
- os.remove(local_path)
- except Exception as e:
- logger.error(f"Error removing temp file {local_path}: {e}")
+ # try:
+ # os.remove(local_path)
+ # except Exception as e:
+ # logger.error(f"Error removing temp file {local_path}: {e}")
- return f"File downloaded successfully to {local_path} and sent to user. The file has been removed from local storage."
+ return f"File downloaded successfully to {local_path} and sent to user."
return f"File downloaded successfully to {local_path}"
except Exception as e:
diff --git a/astrbot/core/computer/tools/python.py b/astrbot/core/computer/tools/python.py
index 333f442f9..2c4ae6c8a 100644
--- a/astrbot/core/computer/tools/python.py
+++ b/astrbot/core/computer/tools/python.py
@@ -5,8 +5,9 @@ import mcp
from astrbot.api import FunctionTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
-from astrbot.core.astr_agent_context import AstrAgentContext
+from astrbot.core.astr_agent_context import AstrAgentContext, AstrMessageEvent
from astrbot.core.computer.computer_client import get_booter, get_local_booter
+from astrbot.core.message.message_event_result import MessageChain
param_schema = {
"type": "object",
@@ -25,7 +26,7 @@ param_schema = {
}
-def handle_result(result: dict) -> ToolExecResult:
+async def handle_result(result: dict, event: AstrMessageEvent) -> ToolExecResult:
data = result.get("data", {})
output = data.get("output", {})
error = data.get("error", "")
@@ -44,6 +45,9 @@ def handle_result(result: dict) -> ToolExecResult:
type="image", data=img["image/png"], mimeType="image/png"
)
)
+
+ if event.get_platform_name() == "webchat":
+ await event.send(message=MessageChain().base64_image(img["image/png"]))
if text:
resp.content.append(mcp.types.TextContent(type="text", text=text))
@@ -68,7 +72,7 @@ class PythonTool(FunctionTool):
)
try:
result = await sb.python.exec(code, silent=silent)
- return handle_result(result)
+ return await handle_result(result, context.context.event)
except Exception as e:
return f"Error executing code: {str(e)}"
@@ -89,6 +93,6 @@ class LocalPythonTool(FunctionTool):
sb = get_local_booter()
try:
result = await sb.python.exec(code, silent=silent)
- return handle_result(result)
+ return await handle_result(result, context.context.event)
except Exception as e:
return f"Error executing code: {str(e)}"
diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py
index 047eefcf2..158be3f72 100644
--- a/astrbot/core/config/default.py
+++ b/astrbot/core/config/default.py
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
-VERSION = "4.15.0"
+VERSION = "4.17.3"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -15,6 +15,7 @@ WEBHOOK_SUPPORTED_PLATFORMS = [
"wecom_ai_bot",
"slack",
"lark",
+ "line",
]
# 默认配置
@@ -67,6 +68,7 @@ DEFAULT_CONFIG = {
"provider_settings": {
"enable": True,
"default_provider_id": "",
+ "fallback_chat_models": [],
"default_image_caption_provider_id": "",
"image_caption_prompt": "Please describe the image using Chinese.",
"provider_pool": ["*"], # "*" 表示使用所有可用的提供者
@@ -99,6 +101,13 @@ DEFAULT_CONFIG = {
"streaming_response": False,
"show_tool_use_status": False,
"sanitize_context_by_modalities": False,
+ "max_quoted_fallback_images": 20,
+ "quoted_message_parser": {
+ "max_component_chain_depth": 4,
+ "max_forward_node_depth": 6,
+ "max_forward_fetch": 32,
+ "warn_on_action_failure": False,
+ },
"agent_runner_type": "local",
"dify_agent_runner_provider_id": "",
"coze_agent_runner_provider_id": "",
@@ -191,6 +200,12 @@ DEFAULT_CONFIG = {
"host": "0.0.0.0",
"port": 6185,
"disable_access_log": True,
+ "ssl": {
+ "enable": False,
+ "cert_file": "",
+ "key_file": "",
+ "ca_certs": "",
+ },
},
"platform": [],
"platform_specific": {
@@ -207,6 +222,7 @@ DEFAULT_CONFIG = {
"log_file_enable": False,
"log_file_path": "logs/astrbot.log",
"log_file_max_mb": 20,
+ "temp_dir_max_size": 1024,
"trace_enable": False,
"trace_log_enable": False,
"trace_log_path": "logs/astrbot.trace.log",
@@ -411,6 +427,7 @@ CONFIG_METADATA_2 = {
"slack_webhook_port": 6197,
"slack_webhook_path": "/astrbot-slack-webhook/callback",
},
+ # LINE's config is located in line_adapter.py
"Satori": {
"id": "satori",
"type": "satori",
@@ -1016,6 +1033,18 @@ CONFIG_METADATA_2 = {
"proxy": "",
"custom_headers": {},
},
+ "NVIDIA": {
+ "id": "nvidia",
+ "provider": "nvidia",
+ "type": "openai_chat_completion",
+ "provider_type": "chat_completion",
+ "enable": True,
+ "key": [],
+ "api_base": "https://integrate.api.nvidia.com/v1",
+ "timeout": 120,
+ "proxy": "",
+ "custom_headers": {},
+ },
"Azure OpenAI": {
"id": "azure_openai",
"provider": "azure",
@@ -2201,6 +2230,10 @@ CONFIG_METADATA_2 = {
"default_provider_id": {
"type": "string",
},
+ "fallback_chat_models": {
+ "type": "list",
+ "items": {"type": "string"},
+ },
"wake_prefix": {
"type": "string",
},
@@ -2395,9 +2428,23 @@ CONFIG_METADATA_2 = {
"type": "string",
"options": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
},
+ "dashboard.ssl.enable": {"type": "bool"},
+ "dashboard.ssl.cert_file": {
+ "type": "string",
+ "condition": {"dashboard.ssl.enable": True},
+ },
+ "dashboard.ssl.key_file": {
+ "type": "string",
+ "condition": {"dashboard.ssl.enable": True},
+ },
+ "dashboard.ssl.ca_certs": {
+ "type": "string",
+ "condition": {"dashboard.ssl.enable": True},
+ },
"log_file_enable": {"type": "bool"},
"log_file_path": {"type": "string", "condition": {"log_file_enable": True}},
"log_file_max_mb": {"type": "int", "condition": {"log_file_enable": True}},
+ "temp_dir_max_size": {"type": "int"},
"trace_log_enable": {"type": "bool"},
"trace_log_path": {
"type": "string",
@@ -2497,15 +2544,22 @@ CONFIG_METADATA_3 = {
},
"ai": {
"description": "模型",
- "hint": "当使用非内置 Agent 执行器时,默认聊天模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
+ "hint": "当使用非内置 Agent 执行器时,默认对话模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
"type": "object",
"items": {
"provider_settings.default_provider_id": {
- "description": "默认聊天模型",
+ "description": "默认对话模型",
"type": "string",
"_special": "select_provider",
"hint": "留空时使用第一个模型",
},
+ "provider_settings.fallback_chat_models": {
+ "description": "回退对话模型列表",
+ "type": "list",
+ "items": {"type": "string"},
+ "_special": "select_providers",
+ "hint": "主聊天模型请求失败时,按顺序切换到这些模型。",
+ },
"provider_settings.default_image_caption_provider_id": {
"description": "默认图片转述模型",
"type": "string",
@@ -2946,6 +3000,46 @@ CONFIG_METADATA_3 = {
"provider_settings.agent_runner_type": "local",
},
},
+ "provider_settings.max_quoted_fallback_images": {
+ "description": "引用图片回退解析上限",
+ "type": "int",
+ "hint": "引用/转发消息回退解析图片时的最大注入数量,超出会截断。",
+ "condition": {
+ "provider_settings.agent_runner_type": "local",
+ },
+ },
+ "provider_settings.quoted_message_parser.max_component_chain_depth": {
+ "description": "引用解析组件链深度",
+ "type": "int",
+ "hint": "解析 Reply 组件链时允许的最大递归深度。",
+ "condition": {
+ "provider_settings.agent_runner_type": "local",
+ },
+ },
+ "provider_settings.quoted_message_parser.max_forward_node_depth": {
+ "description": "引用解析转发节点深度",
+ "type": "int",
+ "hint": "解析合并转发节点时允许的最大递归深度。",
+ "condition": {
+ "provider_settings.agent_runner_type": "local",
+ },
+ },
+ "provider_settings.quoted_message_parser.max_forward_fetch": {
+ "description": "引用解析转发拉取上限",
+ "type": "int",
+ "hint": "递归拉取 get_forward_msg 的最大次数。",
+ "condition": {
+ "provider_settings.agent_runner_type": "local",
+ },
+ },
+ "provider_settings.quoted_message_parser.warn_on_action_failure": {
+ "description": "引用解析 action 失败告警",
+ "type": "bool",
+ "hint": "开启后,get_msg/get_forward_msg 全部尝试失败时输出 warning 日志。",
+ "condition": {
+ "provider_settings.agent_runner_type": "local",
+ },
+ },
"provider_settings.max_agent_step": {
"description": "工具调用轮数上限",
"type": "int",
@@ -3397,6 +3491,29 @@ CONFIG_METADATA_3_SYSTEM = {
"hint": "控制台输出日志的级别。",
"options": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
},
+ "dashboard.ssl.enable": {
+ "description": "启用 WebUI HTTPS",
+ "type": "bool",
+ "hint": "启用后,WebUI 将直接使用 HTTPS 提供服务。",
+ },
+ "dashboard.ssl.cert_file": {
+ "description": "SSL 证书文件路径",
+ "type": "string",
+ "hint": "证书文件路径(PEM)。支持绝对路径和相对路径(相对于当前工作目录)。",
+ "condition": {"dashboard.ssl.enable": True},
+ },
+ "dashboard.ssl.key_file": {
+ "description": "SSL 私钥文件路径",
+ "type": "string",
+ "hint": "私钥文件路径(PEM)。支持绝对路径和相对路径(相对于当前工作目录)。",
+ "condition": {"dashboard.ssl.enable": True},
+ },
+ "dashboard.ssl.ca_certs": {
+ "description": "SSL CA 证书文件路径",
+ "type": "string",
+ "hint": "可选。用于指定 CA 证书文件路径。",
+ "condition": {"dashboard.ssl.enable": True},
+ },
"log_file_enable": {
"description": "启用文件日志",
"type": "bool",
@@ -3412,6 +3529,11 @@ CONFIG_METADATA_3_SYSTEM = {
"type": "int",
"hint": "超过大小后自动轮转,默认 20MB。",
},
+ "temp_dir_max_size": {
+ "description": "临时目录大小上限 (MB)",
+ "type": "int",
+ "hint": "用于限制 data/temp 目录总大小,单位为 MB。系统每 10 分钟检查一次,超限时按文件修改时间从旧到新删除,释放约 30% 当前体积。",
+ },
"trace_log_enable": {
"description": "启用 Trace 文件日志",
"type": "bool",
diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py
index 6b36cca0d..758cf1ccd 100644
--- a/astrbot/core/core_lifecycle.py
+++ b/astrbot/core/core_lifecycle.py
@@ -37,6 +37,7 @@ from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.updator import AstrBotUpdator
from astrbot.core.utils.llm_metadata import update_llm_metadata
from astrbot.core.utils.migra_helper import migra
+from astrbot.core.utils.temp_dir_cleaner import TempDirCleaner
from . import astrbot_config, html_renderer
from .event_bus import EventBus
@@ -57,6 +58,7 @@ class AstrBotCoreLifecycle:
self.subagent_orchestrator: SubAgentOrchestrator | None = None
self.cron_manager: CronJobManager | None = None
+ self.temp_dir_cleaner: TempDirCleaner | None = None
# 设置代理
proxy_config = self.astrbot_config.get("http_proxy", "")
@@ -125,6 +127,12 @@ class AstrBotCoreLifecycle:
ucr=self.umop_config_router,
sp=sp,
)
+ self.temp_dir_cleaner = TempDirCleaner(
+ max_size_getter=lambda: self.astrbot_config_mgr.default_conf.get(
+ TempDirCleaner.CONFIG_KEY,
+ TempDirCleaner.DEFAULT_MAX_SIZE,
+ ),
+ )
# apply migration
try:
@@ -238,6 +246,12 @@ class AstrBotCoreLifecycle:
self.cron_manager.start(self.star_context),
name="cron_manager",
)
+ temp_dir_cleaner_task = None
+ if self.temp_dir_cleaner:
+ temp_dir_cleaner_task = asyncio.create_task(
+ self.temp_dir_cleaner.run(),
+ name="temp_dir_cleaner",
+ )
# 把插件中注册的所有协程函数注册到事件总线中并执行
extra_tasks = []
@@ -247,6 +261,8 @@ class AstrBotCoreLifecycle:
tasks_ = [event_bus_task, *(extra_tasks if extra_tasks else [])]
if cron_task:
tasks_.append(cron_task)
+ if temp_dir_cleaner_task:
+ tasks_.append(temp_dir_cleaner_task)
for task in tasks_:
self.curr_tasks.append(
asyncio.create_task(self._task_wrapper(task), name=task.get_name()),
@@ -298,6 +314,9 @@ class AstrBotCoreLifecycle:
async def stop(self) -> None:
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器."""
+ if self.temp_dir_cleaner:
+ await self.temp_dir_cleaner.stop()
+
# 请求停止所有正在运行的异步任务
for task in self.curr_tasks:
task.cancel()
diff --git a/astrbot/core/log.py b/astrbot/core/log.py
index 264c43197..66a2f3154 100644
--- a/astrbot/core/log.py
+++ b/astrbot/core/log.py
@@ -1,24 +1,4 @@
-"""日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
-
-const:
- CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量
- log_color_config: 日志颜色配置, 定义了不同日志级别的颜色
-
-class:
- LogBroker: 日志代理类, 用于缓存和分发日志消息
- LogQueueHandler: 日志处理器, 用于将日志消息发送到 LogBroker
- LogManager: 日志管理器, 用于创建和配置日志记录器
-
-function:
- is_plugin_path: 检查文件路径是否来自插件目录
- get_short_level_name: 将日志级别名称转换为四个字母的缩写
-
-工作流程:
-1. 通过 LogManager.GetLogger() 获取日志器, 配置了控制台输出和多个格式化过滤器
-2. 通过 set_queue_handler() 设置日志处理器, 将日志消息发送到 LogBroker
-3. logBroker 维护一个订阅者列表, 负责将日志分发给所有订阅者
-4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流
-"""
+"""日志系统,统一将标准 logging 输出转发到 loguru。"""
import asyncio
import logging
@@ -27,54 +7,59 @@ import sys
import time
from asyncio import Queue
from collections import deque
-from logging.handlers import RotatingFileHandler
+from typing import TYPE_CHECKING
-import colorlog
+from loguru import logger as _raw_loguru_logger
from astrbot.core.config.default import VERSION
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
-# 日志缓存大小
CACHED_SIZE = 500
-# 日志颜色配置
-log_color_config = {
- "DEBUG": "green",
- "INFO": "bold_cyan",
- "WARNING": "bold_yellow",
- "ERROR": "red",
- "CRITICAL": "bold_red",
- "RESET": "reset",
- "asctime": "green",
-}
+
+if TYPE_CHECKING:
+ from loguru import Record
-def is_plugin_path(pathname):
- """检查文件路径是否来自插件目录
+class _RecordEnricherFilter(logging.Filter):
+ """为 logging.LogRecord 注入 AstrBot 日志字段。"""
- Args:
- pathname (str): 文件路径
+ def filter(self, record: logging.LogRecord) -> bool:
+ record.plugin_tag = "[Plug]" if _is_plugin_path(record.pathname) else "[Core]"
+ record.short_levelname = _get_short_level_name(record.levelname)
+ record.astrbot_version_tag = (
+ f" [v{VERSION}]" if record.levelno >= logging.WARNING else ""
+ )
+ record.source_file = _build_source_file(record.pathname)
+ record.source_line = record.lineno
+ record.is_trace = record.name == "astrbot.trace"
+ return True
- Returns:
- bool: 如果路径来自插件目录,则返回 True,否则返回 False
- """
+class _QueueAnsiColorFilter(logging.Filter):
+ """Attach ANSI color prefix for WebUI console rendering."""
+
+ _LEVEL_COLOR = {
+ "DEBUG": "\u001b[1;34m",
+ "INFO": "\u001b[1;36m",
+ "WARNING": "\u001b[1;33m",
+ "ERROR": "\u001b[31m",
+ "CRITICAL": "\u001b[1;31m",
+ }
+
+ def filter(self, record: logging.LogRecord) -> bool:
+ record.ansi_prefix = self._LEVEL_COLOR.get(record.levelname, "\u001b[0m")
+ record.ansi_reset = "\u001b[0m"
+ return True
+
+
+def _is_plugin_path(pathname: str | None) -> bool:
if not pathname:
return False
-
norm_path = os.path.normpath(pathname)
return ("data/plugins" in norm_path) or ("astrbot/builtin_stars/" in norm_path)
-def get_short_level_name(level_name):
- """将日志级别名称转换为四个字母的缩写
-
- Args:
- level_name (str): 日志级别名称, 如 "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
-
- Returns:
- str: 四个字母的日志级别缩写
-
- """
+def _get_short_level_name(level_name: str) -> str:
level_map = {
"DEBUG": "DBUG",
"INFO": "INFO",
@@ -85,44 +70,75 @@ def get_short_level_name(level_name):
return level_map.get(level_name, level_name[:4].upper())
-class LogBroker:
- """日志代理类, 用于缓存和分发日志消息
+def _build_source_file(pathname: str | None) -> str:
+ if not pathname:
+ return "unknown"
+ dirname = os.path.dirname(pathname)
+ return (
+ os.path.basename(dirname) + "." + os.path.basename(pathname).replace(".py", "")
+ )
- 发布-订阅模式
- """
+
+def _patch_record(record: "Record") -> None:
+ extra = record["extra"]
+ extra.setdefault("plugin_tag", "[Core]")
+ extra.setdefault("short_levelname", _get_short_level_name(record["level"].name))
+ level_no = record["level"].no
+ extra.setdefault("astrbot_version_tag", f" [v{VERSION}]" if level_no >= 30 else "")
+ extra.setdefault("source_file", _build_source_file(record["file"].path))
+ extra.setdefault("source_line", record["line"])
+ extra.setdefault("is_trace", False)
+
+
+_loguru = _raw_loguru_logger.patch(_patch_record)
+
+
+class _LoguruInterceptHandler(logging.Handler):
+ """将 logging 记录转发到 loguru。"""
+
+ def emit(self, record: logging.LogRecord) -> None:
+ try:
+ level: str | int = _loguru.level(record.levelname).name
+ except ValueError:
+ level = record.levelno
+
+ payload = {
+ "plugin_tag": getattr(record, "plugin_tag", "[Core]"),
+ "short_levelname": getattr(
+ record,
+ "short_levelname",
+ _get_short_level_name(record.levelname),
+ ),
+ "astrbot_version_tag": getattr(record, "astrbot_version_tag", ""),
+ "source_file": getattr(
+ record, "source_file", _build_source_file(record.pathname)
+ ),
+ "source_line": getattr(record, "source_line", record.lineno),
+ "is_trace": getattr(record, "is_trace", record.name == "astrbot.trace"),
+ }
+
+ _loguru.bind(**payload).opt(exception=record.exc_info).log(
+ level,
+ record.getMessage(),
+ )
+
+
+class LogBroker:
+ """日志代理类,用于缓存和分发日志消息。"""
def __init__(self) -> None:
- self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
- self.subscribers: list[Queue] = [] # 订阅者列表
+ self.log_cache = deque(maxlen=CACHED_SIZE)
+ self.subscribers: list[Queue] = []
def register(self) -> Queue:
- """注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列
-
- Returns:
- Queue: 订阅者的队列, 可用于接收日志消息
-
- """
q = Queue(maxsize=CACHED_SIZE + 10)
self.subscribers.append(q)
return q
def unregister(self, q: Queue) -> None:
- """取消订阅
-
- Args:
- q (Queue): 需要取消订阅的队列
-
- """
self.subscribers.remove(q)
def publish(self, log_entry: dict) -> None:
- """发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统
-
- Args:
- log_entry (dict): 日志消息, 包含日志级别和日志内容.
- example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"}
-
- """
self.log_cache.append(log_entry)
for q in self.subscribers:
try:
@@ -132,23 +148,13 @@ class LogBroker:
class LogQueueHandler(logging.Handler):
- """日志处理器, 用于将日志消息发送到 LogBroker
-
- 继承自 logging.Handler
- """
+ """日志处理器,用于将日志消息发送到 LogBroker。"""
def __init__(self, log_broker: LogBroker) -> None:
super().__init__()
self.log_broker = log_broker
- def emit(self, record) -> None:
- """日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布
- 这个方法会在每次日志记录时被调用
-
- Args:
- record (logging.LogRecord): 日志记录对象, 包含日志信息
-
- """
+ def emit(self, record: logging.LogRecord) -> None:
log_entry = self.format(record)
self.log_broker.publish(
{
@@ -160,117 +166,16 @@ class LogQueueHandler(logging.Handler):
class LogManager:
- """日志管理器, 用于创建和配置日志记录器
+ _LOGGER_HANDLER_FLAG = "_astrbot_loguru_handler"
+ _ENRICH_FILTER_FLAG = "_astrbot_enrich_filter"
- 提供了获取默认日志记录器logger和设置队列处理器的方法
- """
-
- _FILE_HANDLER_FLAG = "_astrbot_file_handler"
- _TRACE_FILE_HANDLER_FLAG = "_astrbot_trace_file_handler"
-
- @classmethod
- def GetLogger(cls, log_name: str = "default"):
- """获取指定名称的日志记录器logger
-
- Args:
- log_name (str): 日志记录器的名称, 默认为 "default"
-
- Returns:
- logging.Logger: 返回配置好的日志记录器
-
- """
- logger = logging.getLogger(log_name)
- # 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置
- if logger.hasHandlers():
- return logger
- # 如果logger没有处理器
- console_handler = logging.StreamHandler(
- sys.stdout,
- ) # 创建一个StreamHandler用于控制台输出
- console_handler.setLevel(
- logging.DEBUG,
- ) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
-
- # 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息
- console_formatter = colorlog.ColoredFormatter(
- fmt="%(log_color)s [%(asctime)s] %(plugin_tag)s [%(short_levelname)-4s]%(astrbot_version_tag)s [%(filename)s:%(lineno)d]: %(message)s %(reset)s",
- datefmt="%H:%M:%S",
- log_colors=log_color_config,
- )
-
- class PluginFilter(logging.Filter):
- """插件过滤器类, 用于标记日志来源是插件还是核心组件"""
-
- def filter(self, record) -> bool:
- record.plugin_tag = (
- "[Plug]" if is_plugin_path(record.pathname) else "[Core]"
- )
- return True
-
- class FileNameFilter(logging.Filter):
- """文件名过滤器类, 用于修改日志记录的文件名格式
- 例如: 将文件路径 /path/to/file.py 转换为 file. 格式
- """
-
- # 获取这个文件和父文件夹的名字:. 并且去除 .py
- def filter(self, record) -> bool:
- dirname = os.path.dirname(record.pathname)
- record.filename = (
- os.path.basename(dirname)
- + "."
- + os.path.basename(record.pathname).replace(".py", "")
- )
- return True
-
- class LevelNameFilter(logging.Filter):
- """短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写"""
-
- # 添加短日志级别名称
- def filter(self, record) -> bool:
- record.short_levelname = get_short_level_name(record.levelname)
- return True
-
- class AstrBotVersionTagFilter(logging.Filter):
- """在 WARNING 及以上级别日志后追加当前 AstrBot 版本号。"""
-
- def filter(self, record) -> bool:
- if record.levelno >= logging.WARNING:
- record.astrbot_version_tag = f" [v{VERSION}]"
- else:
- record.astrbot_version_tag = ""
- return True
-
- console_handler.setFormatter(console_formatter) # 设置处理器的格式化器
- logger.addFilter(PluginFilter()) # 添加插件过滤器
- logger.addFilter(FileNameFilter()) # 添加文件名过滤器
- logger.addFilter(LevelNameFilter()) # 添加级别名称过滤器
- logger.addFilter(AstrBotVersionTagFilter()) # 追加版本号(WARNING 及以上)
- logger.setLevel(logging.DEBUG) # 设置日志级别为DEBUG
- logger.addHandler(console_handler) # 添加处理器到logger
-
- return logger
-
- @classmethod
- def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker) -> None:
- """设置队列处理器, 用于将日志消息发送到 LogBroker
-
- Args:
- logger (logging.Logger): 日志记录器
- log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息
-
- """
- handler = LogQueueHandler(log_broker)
- handler.setLevel(logging.DEBUG)
- if logger.handlers:
- handler.setFormatter(logger.handlers[0].formatter)
- else:
- # 为队列处理器设置相同格式的formatter
- handler.setFormatter(
- logging.Formatter(
- "[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s",
- ),
- )
- logger.addHandler(handler)
+ _configured = False
+ _console_sink_id: int | None = None
+ _file_sink_id: int | None = None
+ _trace_sink_id: int | None = None
+ _NOISY_LOGGER_LEVELS: dict[str, int] = {
+ "aiosqlite": logging.WARNING,
+ }
@classmethod
def _default_log_path(cls) -> str:
@@ -285,79 +190,147 @@ class LogManager:
return os.path.join(get_astrbot_data_path(), configured_path)
@classmethod
- def _get_file_handlers(cls, logger: logging.Logger) -> list[logging.Handler]:
- return [
- handler
- for handler in logger.handlers
- if getattr(handler, cls._FILE_HANDLER_FLAG, False)
- ]
+ def _setup_loguru(cls) -> None:
+ if cls._configured:
+ return
- @classmethod
- def _get_trace_file_handlers(cls, logger: logging.Logger) -> list[logging.Handler]:
- return [
- handler
- for handler in logger.handlers
- if getattr(handler, cls._TRACE_FILE_HANDLER_FLAG, False)
- ]
-
- @classmethod
- def _remove_file_handlers(cls, logger: logging.Logger) -> None:
- for handler in cls._get_file_handlers(logger):
- logger.removeHandler(handler)
- try:
- handler.close()
- except Exception:
- pass
-
- @classmethod
- def _remove_trace_file_handlers(cls, logger: logging.Logger) -> None:
- for handler in cls._get_trace_file_handlers(logger):
- logger.removeHandler(handler)
- try:
- handler.close()
- except Exception:
- pass
-
- @classmethod
- def _add_file_handler(
- cls,
- logger: logging.Logger,
- file_path: str,
- max_mb: int | None = None,
- backup_count: int = 3,
- trace: bool = False,
- ) -> None:
- os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
- max_bytes = 0
- if max_mb and max_mb > 0:
- max_bytes = max_mb * 1024 * 1024
- if max_bytes > 0:
- file_handler = RotatingFileHandler(
- file_path,
- maxBytes=max_bytes,
- backupCount=backup_count,
- encoding="utf-8",
- )
- else:
- file_handler = logging.FileHandler(file_path, encoding="utf-8")
- file_handler.setLevel(logger.level)
- if trace:
- formatter = logging.Formatter(
- "[%(asctime)s] %(message)s",
- datefmt="%Y-%m-%d %H:%M:%S",
- )
- else:
- formatter = logging.Formatter(
- "[%(asctime)s] %(plugin_tag)s [%(short_levelname)s]%(astrbot_version_tag)s [%(filename)s:%(lineno)d]: %(message)s",
- datefmt="%Y-%m-%d %H:%M:%S",
- )
- file_handler.setFormatter(formatter)
- setattr(
- file_handler,
- cls._TRACE_FILE_HANDLER_FLAG if trace else cls._FILE_HANDLER_FLAG,
- True,
+ _loguru.remove()
+ cls._console_sink_id = _loguru.add(
+ sys.stdout,
+ level="DEBUG",
+ colorize=True,
+ filter=lambda record: not record["extra"].get("is_trace", False),
+ format=(
+ "[{time:HH:mm:ss.SSS}] {extra[plugin_tag]} "
+ "[{extra[short_levelname]}]{extra[astrbot_version_tag]} "
+ "[{extra[source_file]}:{extra[source_line]}]: {message}"
+ ),
+ )
+ cls._configured = True
+
+ @classmethod
+ def _setup_root_bridge(cls) -> None:
+ root_logger = logging.getLogger()
+
+ has_handler = any(
+ getattr(handler, cls._LOGGER_HANDLER_FLAG, False)
+ for handler in root_logger.handlers
+ )
+ if not has_handler:
+ handler = _LoguruInterceptHandler()
+ setattr(handler, cls._LOGGER_HANDLER_FLAG, True)
+ root_logger.addHandler(handler)
+ root_logger.setLevel(logging.DEBUG)
+ for name, level in cls._NOISY_LOGGER_LEVELS.items():
+ logging.getLogger(name).setLevel(level)
+
+ @classmethod
+ def _ensure_logger_enricher_filter(cls, logger: logging.Logger) -> None:
+ has_filter = any(
+ getattr(existing_filter, cls._ENRICH_FILTER_FLAG, False)
+ for existing_filter in logger.filters
+ )
+ if not has_filter:
+ enrich_filter = _RecordEnricherFilter()
+ setattr(enrich_filter, cls._ENRICH_FILTER_FLAG, True)
+ logger.addFilter(enrich_filter)
+
+ @classmethod
+ def _ensure_logger_intercept_handler(cls, logger: logging.Logger) -> None:
+ has_handler = any(
+ getattr(handler, cls._LOGGER_HANDLER_FLAG, False)
+ for handler in logger.handlers
+ )
+ if not has_handler:
+ handler = _LoguruInterceptHandler()
+ setattr(handler, cls._LOGGER_HANDLER_FLAG, True)
+ logger.addHandler(handler)
+
+ @classmethod
+ def GetLogger(cls, log_name: str = "default") -> logging.Logger:
+ cls._setup_loguru()
+ cls._setup_root_bridge()
+
+ logger = logging.getLogger(log_name)
+ cls._ensure_logger_enricher_filter(logger)
+ cls._ensure_logger_intercept_handler(logger)
+ logger.setLevel(logging.DEBUG)
+ logger.propagate = False
+ return logger
+
+ @classmethod
+ def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker) -> None:
+ cls._ensure_logger_enricher_filter(logger)
+
+ for handler in logger.handlers:
+ if isinstance(handler, LogQueueHandler):
+ return
+
+ handler = LogQueueHandler(log_broker)
+ handler.setLevel(logging.DEBUG)
+ handler.addFilter(_QueueAnsiColorFilter())
+ handler.setFormatter(
+ logging.Formatter(
+ "%(ansi_prefix)s[%(asctime)s.%(msecs)03d] %(plugin_tag)s [%(short_levelname)s]%(astrbot_version_tag)s "
+ "[%(source_file)s:%(source_line)d]: %(message)s%(ansi_reset)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ ),
+ )
+ logger.addHandler(handler)
+
+ @classmethod
+ def _remove_sink(cls, sink_id: int | None) -> None:
+ if sink_id is None:
+ return
+ try:
+ _loguru.remove(sink_id)
+ except ValueError:
+ pass
+
+ @classmethod
+ def _add_file_sink(
+ cls,
+ *,
+ file_path: str,
+ level: int,
+ max_mb: int | None,
+ backup_count: int,
+ trace: bool,
+ ) -> int:
+ os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
+ rotation = f"{max_mb} MB" if max_mb and max_mb > 0 else None
+ retention = (
+ backup_count if rotation and backup_count and backup_count > 0 else None
+ )
+ if trace:
+ return _loguru.add(
+ file_path,
+ level="INFO",
+ format="[{time:YYYY-MM-DD HH:mm:ss.SSS}] {message}",
+ encoding="utf-8",
+ rotation=rotation,
+ retention=retention,
+ enqueue=True,
+ filter=lambda record: record["extra"].get("is_trace", False),
+ )
+
+ logging_level_name = logging.getLevelName(level)
+ if isinstance(logging_level_name, int):
+ logging_level_name = "INFO"
+ return _loguru.add(
+ file_path,
+ level=logging_level_name,
+ format=(
+ "[{time:YYYY-MM-DD HH:mm:ss.SSS}] {extra[plugin_tag]} "
+ "[{extra[short_levelname]}]{extra[astrbot_version_tag]} "
+ "[{extra[source_file]}:{extra[source_line]}]: {message}"
+ ),
+ encoding="utf-8",
+ rotation=rotation,
+ retention=retention,
+ enqueue=True,
+ filter=lambda record: not record["extra"].get("is_trace", False),
)
- logger.addHandler(file_handler)
@classmethod
def configure_logger(
@@ -366,13 +339,6 @@ class LogManager:
config: dict | None,
override_level: str | None = None,
) -> None:
- """根据配置设置日志级别和文件日志。
-
- Args:
- logger: 需要配置的 logger
- config: 配置字典
- override_level: 若提供,将覆盖配置中的日志级别
- """
if not config:
return
@@ -383,7 +349,6 @@ class LogManager:
except Exception:
logger.setLevel(logging.INFO)
- # 兼容旧版嵌套配置
if "log_file" in config:
file_conf = config.get("log_file") or {}
enable_file = bool(file_conf.get("enable", False))
@@ -394,27 +359,25 @@ class LogManager:
file_path = config.get("log_file_path")
max_mb = config.get("log_file_max_mb")
- file_path = cls._resolve_log_path(file_path)
+ cls._remove_sink(cls._file_sink_id)
+ cls._file_sink_id = None
- existing = cls._get_file_handlers(logger)
if not enable_file:
- cls._remove_file_handlers(logger)
return
- # 如果已有文件处理器且路径一致,则仅同步级别
- if existing:
- handler = existing[0]
- base = getattr(handler, "baseFilename", "")
- if base and os.path.abspath(base) == os.path.abspath(file_path):
- handler.setLevel(logger.level)
- return
- cls._remove_file_handlers(logger)
-
- cls._add_file_handler(logger, file_path, max_mb=max_mb)
+ try:
+ cls._file_sink_id = cls._add_file_sink(
+ file_path=cls._resolve_log_path(file_path),
+ level=logger.level,
+ max_mb=max_mb,
+ backup_count=3,
+ trace=False,
+ )
+ except Exception as e:
+ logger.error(f"Failed to add file sink: {e}")
@classmethod
def configure_trace_logger(cls, config: dict | None) -> None:
- """为 trace 事件配置独立的文件日志,不向控制台输出。"""
if not config:
return
@@ -429,28 +392,22 @@ class LogManager:
path = path or legacy.get("trace_path")
max_mb = max_mb or legacy.get("trace_max_mb")
- if not enable:
- trace_logger = logging.getLogger("astrbot.trace")
- cls._remove_trace_file_handlers(trace_logger)
- return
-
- file_path = cls._resolve_log_path(path or "logs/astrbot.trace.log")
trace_logger = logging.getLogger("astrbot.trace")
+ cls._ensure_logger_enricher_filter(trace_logger)
+ cls._ensure_logger_intercept_handler(trace_logger)
trace_logger.setLevel(logging.INFO)
trace_logger.propagate = False
- existing = cls._get_trace_file_handlers(trace_logger)
- if existing:
- handler = existing[0]
- base = getattr(handler, "baseFilename", "")
- if base and os.path.abspath(base) == os.path.abspath(file_path):
- handler.setLevel(trace_logger.level)
- return
- cls._remove_trace_file_handlers(trace_logger)
+ cls._remove_sink(cls._trace_sink_id)
+ cls._trace_sink_id = None
- cls._add_file_handler(
- trace_logger,
- file_path,
+ if not enable:
+ return
+
+ cls._trace_sink_id = cls._add_file_sink(
+ file_path=cls._resolve_log_path(path or "logs/astrbot.trace.log"),
+ level=logging.INFO,
max_mb=max_mb,
+ backup_count=3,
trace=True,
)
diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py
index a192025dc..758381ba2 100644
--- a/astrbot/core/message/components.py
+++ b/astrbot/core/message/components.py
@@ -25,13 +25,17 @@ import asyncio
import base64
import json
import os
+import sys
import uuid
from enum import Enum
-from pydantic.v1 import BaseModel
+if sys.version_info >= (3, 14):
+ from pydantic import BaseModel
+else:
+ from pydantic.v1 import BaseModel
from astrbot.core import astrbot_config, file_token_service, logger
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
@@ -85,7 +89,7 @@ class BaseMessageComponent(BaseModel):
class Plain(BaseMessageComponent):
- type = ComponentType.Plain
+ type: ComponentType = ComponentType.Plain
text: str
convert: bool | None = True
@@ -100,7 +104,7 @@ class Plain(BaseMessageComponent):
class Face(BaseMessageComponent):
- type = ComponentType.Face
+ type: ComponentType = ComponentType.Face
id: int
def __init__(self, **_) -> None:
@@ -108,7 +112,7 @@ class Face(BaseMessageComponent):
class Record(BaseMessageComponent):
- type = ComponentType.Record
+ type: ComponentType = ComponentType.Record
file: str | None = ""
magic: bool | None = False
url: str | None = ""
@@ -156,8 +160,9 @@ class Record(BaseMessageComponent):
if self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
- file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
+ file_path = os.path.join(
+ get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg"
+ )
with open(file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(file_path)
@@ -214,7 +219,7 @@ class Record(BaseMessageComponent):
class Video(BaseMessageComponent):
- type = ComponentType.Video
+ type: ComponentType = ComponentType.Video
file: str
cover: str | None = ""
c: int | None = 2
@@ -245,8 +250,9 @@ class Video(BaseMessageComponent):
if url and url.startswith("file:///"):
return url[8:]
if url and url.startswith("http"):
- download_dir = os.path.join(get_astrbot_data_path(), "temp")
- video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
+ video_file_path = os.path.join(
+ get_astrbot_temp_path(), f"videoseg_{uuid.uuid4().hex}"
+ )
await download_file(url, video_file_path)
if os.path.exists(video_file_path):
return os.path.abspath(video_file_path)
@@ -299,7 +305,7 @@ class Video(BaseMessageComponent):
class At(BaseMessageComponent):
- type = ComponentType.At
+ type: ComponentType = ComponentType.At
qq: int | str # 此处str为all时代表所有人
name: str | None = ""
@@ -321,28 +327,28 @@ class AtAll(At):
class RPS(BaseMessageComponent): # TODO
- type = ComponentType.RPS
+ type: ComponentType = ComponentType.RPS
def __init__(self, **_) -> None:
super().__init__(**_)
class Dice(BaseMessageComponent): # TODO
- type = ComponentType.Dice
+ type: ComponentType = ComponentType.Dice
def __init__(self, **_) -> None:
super().__init__(**_)
class Shake(BaseMessageComponent): # TODO
- type = ComponentType.Shake
+ type: ComponentType = ComponentType.Shake
def __init__(self, **_) -> None:
super().__init__(**_)
class Share(BaseMessageComponent):
- type = ComponentType.Share
+ type: ComponentType = ComponentType.Share
url: str
title: str
content: str | None = ""
@@ -353,7 +359,7 @@ class Share(BaseMessageComponent):
class Contact(BaseMessageComponent): # TODO
- type = ComponentType.Contact
+ type: ComponentType = ComponentType.Contact
_type: str # type 字段冲突
id: int | None = 0
@@ -362,7 +368,7 @@ class Contact(BaseMessageComponent): # TODO
class Location(BaseMessageComponent): # TODO
- type = ComponentType.Location
+ type: ComponentType = ComponentType.Location
lat: float
lon: float
title: str | None = ""
@@ -373,7 +379,7 @@ class Location(BaseMessageComponent): # TODO
class Music(BaseMessageComponent):
- type = ComponentType.Music
+ type: ComponentType = ComponentType.Music
_type: str
id: int | None = 0
url: str | None = ""
@@ -390,7 +396,7 @@ class Music(BaseMessageComponent):
class Image(BaseMessageComponent):
- type = ComponentType.Image
+ type: ComponentType = ComponentType.Image
file: str | None = ""
_type: str | None = ""
subType: int | None = 0
@@ -445,8 +451,9 @@ class Image(BaseMessageComponent):
if url.startswith("base64://"):
bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
- image_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
+ image_file_path = os.path.join(
+ get_astrbot_temp_path(), f"imgseg_{uuid.uuid4()}.jpg"
+ )
with open(image_file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(image_file_path)
@@ -504,7 +511,7 @@ class Image(BaseMessageComponent):
class Reply(BaseMessageComponent):
- type = ComponentType.Reply
+ type: ComponentType = ComponentType.Reply
id: str | int
"""所引用的消息 ID"""
chain: list["BaseMessageComponent"] | None = []
@@ -540,7 +547,7 @@ class Poke(BaseMessageComponent):
class Forward(BaseMessageComponent):
- type = ComponentType.Forward
+ type: ComponentType = ComponentType.Forward
id: str
def __init__(self, **_) -> None:
@@ -550,7 +557,7 @@ class Forward(BaseMessageComponent):
class Node(BaseMessageComponent):
"""群合并转发消息"""
- type = ComponentType.Node
+ type: ComponentType = ComponentType.Node
id: int | None = 0 # 忽略
name: str | None = "" # qq昵称
uin: str | None = "0" # qq号
@@ -602,7 +609,7 @@ class Node(BaseMessageComponent):
class Nodes(BaseMessageComponent):
- type = ComponentType.Nodes
+ type: ComponentType = ComponentType.Nodes
nodes: list[Node]
def __init__(self, nodes: list[Node], **_) -> None:
@@ -628,7 +635,7 @@ class Nodes(BaseMessageComponent):
class Json(BaseMessageComponent):
- type = ComponentType.Json
+ type: ComponentType = ComponentType.Json
data: dict
def __init__(self, data: str | dict, **_) -> None:
@@ -638,14 +645,14 @@ class Json(BaseMessageComponent):
class Unknown(BaseMessageComponent):
- type = ComponentType.Unknown
+ type: ComponentType = ComponentType.Unknown
text: str
class File(BaseMessageComponent):
"""文件消息段"""
- type = ComponentType.File
+ type: ComponentType = ComponentType.File
name: str | None = "" # 名字
file_: str | None = "" # 本地路径
url: str | None = "" # url
@@ -725,13 +732,12 @@ class File(BaseMessageComponent):
"""下载文件"""
if not self.url:
raise ValueError("Download failed: No URL provided in File component.")
- download_dir = os.path.join(get_astrbot_data_path(), "temp")
- os.makedirs(download_dir, exist_ok=True)
+ download_dir = get_astrbot_temp_path()
if self.name:
name, ext = os.path.splitext(self.name)
- filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}"
+ filename = f"fileseg_{name}_{uuid.uuid4().hex[:8]}{ext}"
else:
- filename = f"{uuid.uuid4().hex}"
+ filename = f"fileseg_{uuid.uuid4().hex}"
file_path = os.path.join(download_dir, filename)
await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path)
@@ -781,7 +787,7 @@ class File(BaseMessageComponent):
class WechatEmoji(BaseMessageComponent):
- type = ComponentType.WechatEmoji
+ type: ComponentType = ComponentType.WechatEmoji
md5: str | None = ""
md5_len: int | None = 0
cdnurl: str | None = ""
diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py
index d26f67add..be517dba9 100644
--- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py
+++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py
@@ -123,6 +123,7 @@ class InternalAgentSubStage(Stage):
provider_settings=settings,
subagent_orchestrator=conf.get("subagent_orchestrator", {}),
timezone=self.ctx.plugin_manager.context.get_config().get("timezone"),
+ max_quoted_fallback_images=settings.get("max_quoted_fallback_images", 20),
)
async def process(
@@ -149,6 +150,7 @@ class InternalAgentSubStage(Stage):
logger.debug("ready to request llm provider")
+ await event.send_typing()
await call_event_hook(event, EventType.OnWaitingLLMRequestEvent)
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
@@ -190,6 +192,8 @@ class InternalAgentSubStage(Stage):
)
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
+ if reset_coro:
+ reset_coro.close()
return
# apply reset
diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py
index b57fed29e..b4a7ee7fa 100644
--- a/astrbot/core/pipeline/respond/stage.py
+++ b/astrbot/core/pipeline/respond/stage.py
@@ -61,16 +61,17 @@ class RespondStage(Stage):
self.log_base = float(
ctx.astrbot_config["platform_settings"]["segmented_reply"]["log_base"],
)
- interval_str: str = ctx.astrbot_config["platform_settings"]["segmented_reply"][
- "interval"
- ]
- interval_str_ls = interval_str.replace(" ", "").split(",")
- try:
- self.interval = [float(t) for t in interval_str_ls]
- except BaseException as e:
- logger.error(f"解析分段回复的间隔时间失败。{e}")
- self.interval = [1.5, 3.5]
- logger.info(f"分段回复间隔时间:{self.interval}")
+ self.interval = [1.5, 3.5]
+ if self.enable_seg:
+ interval_str: str = ctx.astrbot_config["platform_settings"][
+ "segmented_reply"
+ ]["interval"]
+ interval_str_ls = interval_str.replace(" ", "").split(",")
+ try:
+ self.interval = [float(t) for t in interval_str_ls]
+ except BaseException as e:
+ logger.error(f"解析分段回复的间隔时间失败。{e}")
+ logger.info(f"分段回复间隔时间:{self.interval}")
async def _word_cnt(self, text: str) -> int:
"""分段回复 统计字数"""
diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py
index 83b9813e0..4cd531c53 100644
--- a/astrbot/core/platform/astr_message_event.py
+++ b/astrbot/core/platform/astr_message_event.py
@@ -244,6 +244,12 @@ class AstrMessageEvent(abc.ABC):
)
self._has_send_oper = True
+ async def send_typing(self) -> None:
+ """发送输入中状态。
+
+ 默认实现为空,由具体平台按需重写。
+ """
+
async def _pre_send(self) -> None:
"""调度器会在执行 send() 前调用该方法 deprecated in v3.5.18"""
diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py
index 9cec8a942..0238779da 100644
--- a/astrbot/core/platform/manager.py
+++ b/astrbot/core/platform/manager.py
@@ -1,6 +1,7 @@
import asyncio
import traceback
from asyncio import Queue
+from dataclasses import dataclass
from astrbot.core import logger
from astrbot.core.config.astrbot_config import AstrBotConfig
@@ -12,12 +13,19 @@ from .register import platform_cls_map
from .sources.webchat.webchat_adapter import WebChatAdapter
+@dataclass
+class PlatformTasks:
+ run: asyncio.Task
+ wrapper: asyncio.Task
+
+
class PlatformManager:
def __init__(self, config: AstrBotConfig, event_queue: Queue) -> None:
self.platform_insts: list[Platform] = []
"""加载的 Platform 的实例"""
self._inst_map: dict[str, dict] = {}
+ self._platform_tasks: dict[str, PlatformTasks] = {}
self.astrbot_config = config
self.platforms_config = config["platform"]
@@ -38,6 +46,44 @@ class PlatformManager:
sanitized = platform_id.replace(":", "_").replace("!", "_")
return sanitized, sanitized != platform_id
+ def _start_platform_task(self, task_name: str, inst: Platform) -> None:
+ run_task = asyncio.create_task(inst.run(), name=task_name)
+ wrapper_task = asyncio.create_task(
+ self._task_wrapper(run_task, platform=inst),
+ name=f"{task_name}_wrapper",
+ )
+ self._platform_tasks[inst.client_self_id] = PlatformTasks(
+ run=run_task,
+ wrapper=wrapper_task,
+ )
+
+ async def _stop_platform_task(self, client_id: str) -> None:
+ tasks = self._platform_tasks.pop(client_id, None)
+ if not tasks:
+ return
+ for task in (tasks.run, tasks.wrapper):
+ if not task.done():
+ task.cancel()
+ await asyncio.gather(tasks.run, tasks.wrapper, return_exceptions=True)
+
+ async def _terminate_inst_and_tasks(self, inst: Platform) -> None:
+ client_id = inst.client_self_id
+ try:
+ if getattr(inst, "terminate", None):
+ try:
+ await inst.terminate()
+ except asyncio.CancelledError:
+ raise
+ except Exception as e:
+ logger.error(
+ "终止平台适配器失败: client_id=%s, error=%s",
+ client_id,
+ e,
+ )
+ logger.error(traceback.format_exc())
+ finally:
+ await self._stop_platform_task(client_id)
+
async def initialize(self) -> None:
"""初始化所有平台适配器"""
for platform in self.platforms_config:
@@ -51,12 +97,7 @@ class PlatformManager:
# 网页聊天
webchat_inst = WebChatAdapter({}, self.settings, self.event_queue)
self.platform_insts.append(webchat_inst)
- asyncio.create_task(
- self._task_wrapper(
- asyncio.create_task(webchat_inst.run(), name="webchat"),
- platform=webchat_inst,
- ),
- )
+ self._start_platform_task("webchat", webchat_inst)
async def load_platform(self, platform_config: dict) -> None:
"""实例化一个平台"""
@@ -135,6 +176,10 @@ class PlatformManager:
from .sources.satori.satori_adapter import (
SatoriPlatformAdapter, # noqa: F401
)
+ case "line":
+ from .sources.line.line_adapter import (
+ LinePlatformAdapter, # noqa: F401
+ )
except (ImportError, ModuleNotFoundError) as e:
logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
@@ -154,15 +199,9 @@ class PlatformManager:
"client_id": inst.client_self_id,
}
self.platform_insts.append(inst)
-
- asyncio.create_task(
- self._task_wrapper(
- asyncio.create_task(
- inst.run(),
- name=f"platform_{platform_config['type']}_{platform_config['id']}",
- ),
- platform=inst,
- ),
+ self._start_platform_task(
+ f"platform_{platform_config['type']}_{platform_config['id']}",
+ inst,
)
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnPlatformLoadedEvent,
@@ -230,13 +269,25 @@ class PlatformManager:
except Exception:
logger.warning(f"可能未完全移除 {platform_id} 平台适配器")
- if getattr(inst, "terminate", None):
- await inst.terminate()
+ await self._terminate_inst_and_tasks(inst)
async def terminate(self) -> None:
- for inst in self.platform_insts:
- if getattr(inst, "terminate", None):
- await inst.terminate()
+ terminated_client_ids: set[str] = set()
+ for platform_id in list(self._inst_map.keys()):
+ info = self._inst_map.get(platform_id)
+ if info:
+ terminated_client_ids.add(info["client_id"])
+ await self.terminate_platform(platform_id)
+
+ for inst in list(self.platform_insts):
+ client_id = inst.client_self_id
+ if client_id in terminated_client_ids:
+ continue
+ await self._terminate_inst_and_tasks(inst)
+
+ self.platform_insts.clear()
+ self._inst_map.clear()
+ self._platform_tasks.clear()
def get_insts(self):
return self.platform_insts
diff --git a/astrbot/core/platform/platform_metadata.py b/astrbot/core/platform/platform_metadata.py
index 00e7dc966..2d01b921d 100644
--- a/astrbot/core/platform/platform_metadata.py
+++ b/astrbot/core/platform/platform_metadata.py
@@ -24,3 +24,14 @@ class PlatformMetadata:
module_path: str | None = None
"""注册该适配器的模块路径,用于插件热重载时清理"""
+ i18n_resources: dict[str, dict] | None = None
+ """国际化资源数据,如 {"zh-CN": {...}, "en-US": {...}}
+
+ 参考 https://github.com/AstrBotDevs/AstrBot/pull/5045
+ """
+
+ config_metadata: dict | None = None
+ """配置项元数据,用于 WebUI 生成表单。对应 config_metadata.json 的内容
+
+ 参考 https://github.com/AstrBotDevs/AstrBot/pull/5045
+ """
diff --git a/astrbot/core/platform/register.py b/astrbot/core/platform/register.py
index 3bbec4809..62ec5070a 100644
--- a/astrbot/core/platform/register.py
+++ b/astrbot/core/platform/register.py
@@ -15,11 +15,14 @@ def register_platform_adapter(
adapter_display_name: str | None = None,
logo_path: str | None = None,
support_streaming_message: bool = True,
+ i18n_resources: dict[str, dict] | None = None,
+ config_metadata: dict | None = None,
):
"""用于注册平台适配器的带参装饰器。
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
logo_path 指定了平台适配器的 logo 文件路径,是相对于插件目录的路径。
+ config_metadata 指定了配置项的元数据,用于 WebUI 生成表单。如果不指定,WebUI 将会把配置项渲染为原始的键值对编辑框。
"""
def decorator(cls):
@@ -49,6 +52,8 @@ def register_platform_adapter(
logo_path=logo_path,
support_streaming_message=support_streaming_message,
module_path=module_path,
+ i18n_resources=i18n_resources,
+ config_metadata=config_metadata,
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls
diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py
index fd0be3f1c..2d9b45cc1 100644
--- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py
+++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py
@@ -21,7 +21,7 @@ from astrbot.api.platform import (
)
from astrbot.core import sp
from astrbot.core.platform.astr_message_event import MessageSesion
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_file
from astrbot.core.utils.media_utils import (
convert_audio_format,
@@ -253,9 +253,9 @@ class DingtalkPlatformAdapter(Platform):
"downloadCode": download_code,
"robotCode": robot_code,
}
- temp_dir = Path(get_astrbot_data_path()) / "temp"
+ temp_dir = Path(get_astrbot_temp_path())
temp_dir.mkdir(parents=True, exist_ok=True)
- f_path = temp_dir / f"dingtalk_file_{uuid.uuid4()}.{ext}"
+ f_path = temp_dir / f"dingtalk_{uuid.uuid4()}.{ext}"
async with (
aiohttp.ClientSession() as session,
session.post(
diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py
index e76572768..be1c81c26 100644
--- a/astrbot/core/platform/sources/lark/lark_adapter.py
+++ b/astrbot/core/platform/sources/lark/lark_adapter.py
@@ -3,10 +3,13 @@ import base64
import json
import re
import time
+from pathlib import Path
from typing import Any, cast
+from uuid import uuid4
import lark_oapi as lark
from lark_oapi.api.im.v1 import (
+ GetMessageRequest,
GetMessageResourceRequest,
)
from lark_oapi.api.im.v1.processor import P2ImMessageReceiveV1Processor
@@ -22,6 +25,7 @@ from astrbot.api.platform import (
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
@@ -91,6 +95,347 @@ class LarkPlatformAdapter(Platform):
self.event_id_timestamps: dict[str, float] = {}
+ async def _download_message_resource(
+ self,
+ *,
+ message_id: str,
+ file_key: str,
+ resource_type: str,
+ ) -> bytes | None:
+ if self.lark_api.im is None:
+ logger.error("[Lark] API Client im 模块未初始化")
+ return None
+
+ request = (
+ GetMessageResourceRequest.builder()
+ .message_id(message_id)
+ .file_key(file_key)
+ .type(resource_type)
+ .build()
+ )
+ response = await self.lark_api.im.v1.message_resource.aget(request)
+ if not response.success():
+ logger.error(
+ f"[Lark] 下载消息资源失败 type={resource_type}, key={file_key}, "
+ f"code={response.code}, msg={response.msg}",
+ )
+ return None
+
+ if response.file is None:
+ logger.error(f"[Lark] 消息资源响应中不包含文件流: {file_key}")
+ return None
+
+ return response.file.read()
+
+ @staticmethod
+ def _build_message_str_from_components(
+ components: list[Comp.BaseMessageComponent],
+ ) -> str:
+ parts: list[str] = []
+ for comp in components:
+ if isinstance(comp, Comp.Plain):
+ text = comp.text.strip()
+ if text:
+ parts.append(text)
+ elif isinstance(comp, Comp.At):
+ name = str(comp.name or comp.qq or "").strip()
+ if name:
+ parts.append(f"@{name}")
+ elif isinstance(comp, Comp.Image):
+ parts.append("[image]")
+ elif isinstance(comp, Comp.File):
+ parts.append(str(comp.name or "[file]"))
+ elif isinstance(comp, Comp.Record):
+ parts.append("[audio]")
+ elif isinstance(comp, Comp.Video):
+ parts.append("[video]")
+
+ return " ".join(parts).strip()
+
+ @staticmethod
+ def _parse_post_content(content: dict[str, Any]) -> list[dict[str, Any]]:
+ result: list[dict[str, Any]] = []
+ for item in content.get("content", []):
+ if isinstance(item, list):
+ for comp in item:
+ if isinstance(comp, dict):
+ result.append(comp)
+ elif isinstance(item, dict):
+ result.append(item)
+ return result
+
+ @staticmethod
+ def _build_at_map(mentions: list[Any] | None) -> dict[str, Comp.At]:
+ at_map: dict[str, Comp.At] = {}
+ if not mentions:
+ return at_map
+
+ for mention in mentions:
+ key = getattr(mention, "key", None)
+ if not key:
+ continue
+
+ mention_id = getattr(mention, "id", None)
+ open_id = ""
+ if mention_id is not None:
+ if hasattr(mention_id, "open_id"):
+ open_id = getattr(mention_id, "open_id", "") or ""
+ else:
+ open_id = str(mention_id)
+
+ mention_name = str(getattr(mention, "name", "") or "")
+ at_map[key] = Comp.At(qq=open_id, name=mention_name)
+
+ return at_map
+
+ async def _parse_message_components(
+ self,
+ *,
+ message_id: str | None,
+ message_type: str,
+ content: dict[str, Any],
+ at_map: dict[str, Comp.At],
+ ) -> list[Comp.BaseMessageComponent]:
+ components: list[Comp.BaseMessageComponent] = []
+
+ if message_type == "text":
+ message_str_raw = str(content.get("text", ""))
+ at_pattern = r"(@_user_\d+)"
+ parts = re.split(at_pattern, message_str_raw)
+ for part in parts:
+ segment = part.strip()
+ if not segment:
+ continue
+ if segment in at_map:
+ components.append(at_map[segment])
+ else:
+ components.append(Comp.Plain(segment))
+ return components
+
+ if message_type in ("post", "image"):
+ if message_type == "image":
+ comp_list = [
+ {
+ "tag": "img",
+ "image_key": content.get("image_key"),
+ },
+ ]
+ else:
+ comp_list = self._parse_post_content(content)
+
+ for comp in comp_list:
+ tag = comp.get("tag")
+ if tag == "at":
+ user_key = str(comp.get("user_id", ""))
+ if user_key in at_map:
+ components.append(at_map[user_key])
+ elif tag == "text":
+ text = str(comp.get("text", "")).strip()
+ if text:
+ components.append(Comp.Plain(text))
+ elif tag == "a":
+ text = str(comp.get("text", "")).strip()
+ href = str(comp.get("href", "")).strip()
+ if text and href:
+ components.append(Comp.Plain(f"{text}({href})"))
+ elif text:
+ components.append(Comp.Plain(text))
+ elif tag == "img":
+ image_key = str(comp.get("image_key", "")).strip()
+ if not image_key:
+ continue
+ if not message_id:
+ logger.error("[Lark] 图片消息缺少 message_id")
+ continue
+ image_bytes = await self._download_message_resource(
+ message_id=message_id,
+ file_key=image_key,
+ resource_type="image",
+ )
+ if image_bytes is None:
+ continue
+ image_base64 = base64.b64encode(image_bytes).decode()
+ components.append(Comp.Image.fromBase64(image_base64))
+ elif tag == "media":
+ file_key = str(comp.get("file_key", "")).strip()
+ file_name = (
+ str(comp.get("file_name", "")).strip() or "lark_media.mp4"
+ )
+ if not file_key:
+ continue
+ if not message_id:
+ logger.error("[Lark] 富文本视频消息缺少 message_id")
+ continue
+ file_path = await self._download_file_resource_to_temp(
+ message_id=message_id,
+ file_key=file_key,
+ message_type="post_media",
+ file_name=file_name,
+ default_suffix=".mp4",
+ )
+ if file_path:
+ components.append(Comp.Video(file=file_path, path=file_path))
+
+ return components
+
+ if message_type == "file":
+ file_key = str(content.get("file_key", "")).strip()
+ file_name = str(content.get("file_name", "")).strip() or "lark_file"
+ if not message_id:
+ logger.error("[Lark] 文件消息缺少 message_id")
+ return components
+ if not file_key:
+ logger.error("[Lark] 文件消息缺少 file_key")
+ return components
+ file_path = await self._download_file_resource_to_temp(
+ message_id=message_id,
+ file_key=file_key,
+ message_type="file",
+ file_name=file_name,
+ )
+ if file_path:
+ components.append(Comp.File(name=file_name, file=file_path))
+ return components
+
+ if message_type == "audio":
+ file_key = str(content.get("file_key", "")).strip()
+ if not message_id:
+ logger.error("[Lark] 音频消息缺少 message_id")
+ return components
+ if not file_key:
+ logger.error("[Lark] 音频消息缺少 file_key")
+ return components
+ file_path = await self._download_file_resource_to_temp(
+ message_id=message_id,
+ file_key=file_key,
+ message_type="audio",
+ default_suffix=".opus",
+ )
+ if file_path:
+ components.append(Comp.Record(file=file_path, url=file_path))
+ return components
+
+ if message_type == "media":
+ file_key = str(content.get("file_key", "")).strip()
+ file_name = str(content.get("file_name", "")).strip() or "lark_media.mp4"
+ if not message_id:
+ logger.error("[Lark] 视频消息缺少 message_id")
+ return components
+ if not file_key:
+ logger.error("[Lark] 视频消息缺少 file_key")
+ return components
+ file_path = await self._download_file_resource_to_temp(
+ message_id=message_id,
+ file_key=file_key,
+ message_type="media",
+ file_name=file_name,
+ default_suffix=".mp4",
+ )
+ if file_path:
+ components.append(Comp.Video(file=file_path, path=file_path))
+ return components
+
+ return components
+
+ async def _build_reply_from_parent_id(
+ self,
+ parent_message_id: str,
+ ) -> Comp.Reply | None:
+ if self.lark_api.im is None:
+ logger.error("[Lark] API Client im 模块未初始化")
+ return None
+
+ request = GetMessageRequest.builder().message_id(parent_message_id).build()
+ response = await self.lark_api.im.v1.message.aget(request)
+ if not response.success():
+ logger.error(
+ f"[Lark] 获取引用消息失败 id={parent_message_id}, "
+ f"code={response.code}, msg={response.msg}",
+ )
+ return None
+
+ if response.data is None or not response.data.items:
+ logger.error(
+ f"[Lark] 引用消息响应为空 id={parent_message_id}",
+ )
+ return None
+
+ parent_message = response.data.items[0]
+ quoted_message_id = parent_message.message_id or parent_message_id
+ quoted_sender_id = (
+ parent_message.sender.id
+ if parent_message.sender and parent_message.sender.id
+ else "unknown"
+ )
+ quoted_time_raw = parent_message.create_time or 0
+ quoted_time = (
+ quoted_time_raw // 1000
+ if isinstance(quoted_time_raw, int) and quoted_time_raw > 10**11
+ else quoted_time_raw
+ )
+ quoted_content = (
+ parent_message.body.content if parent_message.body else ""
+ ) or ""
+ quoted_type = parent_message.msg_type or ""
+ quoted_content_json: dict[str, Any] = {}
+ if quoted_content:
+ try:
+ parsed = json.loads(quoted_content)
+ if isinstance(parsed, dict):
+ quoted_content_json = parsed
+ except json.JSONDecodeError:
+ logger.warning(
+ f"[Lark] 解析引用消息内容失败 id={quoted_message_id}",
+ )
+
+ quoted_at_map = self._build_at_map(parent_message.mentions)
+ quoted_chain = await self._parse_message_components(
+ message_id=quoted_message_id,
+ message_type=quoted_type,
+ content=quoted_content_json,
+ at_map=quoted_at_map,
+ )
+ quoted_text = self._build_message_str_from_components(quoted_chain)
+ sender_nickname = (
+ quoted_sender_id[:8] if quoted_sender_id != "unknown" else "unknown"
+ )
+
+ return Comp.Reply(
+ id=quoted_message_id,
+ chain=quoted_chain,
+ sender_id=quoted_sender_id,
+ sender_nickname=sender_nickname,
+ time=quoted_time,
+ message_str=quoted_text,
+ text=quoted_text,
+ )
+
+ async def _download_file_resource_to_temp(
+ self,
+ *,
+ message_id: str,
+ file_key: str,
+ message_type: str,
+ file_name: str = "",
+ default_suffix: str = ".bin",
+ ) -> str | None:
+ file_bytes = await self._download_message_resource(
+ message_id=message_id,
+ file_key=file_key,
+ resource_type="file",
+ )
+ if file_bytes is None:
+ return None
+
+ suffix = Path(file_name).suffix if file_name else default_suffix
+ temp_dir = Path(get_astrbot_temp_path())
+ temp_dir.mkdir(parents=True, exist_ok=True)
+ temp_path = (
+ temp_dir / f"lark_{message_type}_{file_name}_{uuid4().hex[:4]}{suffix}"
+ )
+ temp_path.write_bytes(file_bytes)
+ return str(temp_path.resolve())
+
def _clean_expired_events(self) -> None:
"""清理超过 30 分钟的事件记录"""
current_time = time.time()
@@ -176,6 +521,11 @@ class LarkPlatformAdapter(Platform):
abm.message_str = ""
at_list = {}
+ if message.parent_id:
+ reply_seg = await self._build_reply_from_parent_id(message.parent_id)
+ if reply_seg:
+ abm.message.append(reply_seg)
+
if message.mentions:
for m in message.mentions:
if m.id is None:
@@ -198,80 +548,19 @@ class LarkPlatformAdapter(Platform):
logger.error(f"[Lark] 解析消息内容失败: {message.content}")
return
- if message.message_type == "text":
- message_str_raw = content_json_b.get("text", "") # 带有 @ 的消息
- at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
- # at_users = re.findall(at_pattern, message_str_raw)
- # 拆分文本,去掉AT符号部分
- parts = re.split(at_pattern, message_str_raw)
- for i in range(len(parts)):
- s = parts[i].strip()
- if not s:
- continue
- if s in at_list:
- abm.message.append(at_list[s])
- else:
- abm.message.append(Comp.Plain(parts[i].strip()))
- elif message.message_type == "post":
- _ls = []
+ if not isinstance(content_json_b, dict):
+ logger.error(f"[Lark] 消息内容不是 JSON Object: {message.content}")
+ return
- content_ls = content_json_b.get("content", [])
- for comp in content_ls:
- if isinstance(comp, list):
- _ls.extend(comp)
- elif isinstance(comp, dict):
- _ls.append(comp)
- content_json_b = _ls
- elif message.message_type == "image":
- content_json_b = [
- {
- "tag": "img",
- "image_key": content_json_b.get("image_key"),
- "style": [],
- },
- ]
-
- if message.message_type in ("post", "image"):
- for comp in content_json_b:
- if comp.get("tag") == "at":
- user_id = comp.get("user_id")
- if user_id in at_list:
- abm.message.append(at_list[user_id])
- elif comp.get("tag") == "text" and comp.get("text", "").strip():
- abm.message.append(Comp.Plain(comp["text"].strip()))
- elif comp.get("tag") == "img":
- image_key = comp.get("image_key")
- if not image_key:
- continue
-
- request = (
- GetMessageResourceRequest.builder()
- .message_id(cast(str, message.message_id))
- .file_key(image_key)
- .type("image")
- .build()
- )
-
- if self.lark_api.im is None:
- logger.error("[Lark] API Client im 模块未初始化")
- continue
-
- response = await self.lark_api.im.v1.message_resource.aget(request)
- if not response.success():
- logger.error(f"无法下载飞书图片: {image_key}")
- continue
-
- if response.file is None:
- logger.error(f"飞书图片响应中不包含文件流: {image_key}")
- continue
-
- image_bytes = response.file.read()
- image_base64 = base64.b64encode(image_bytes).decode()
- abm.message.append(Comp.Image.fromBase64(image_base64))
-
- for comp in abm.message:
- if isinstance(comp, Comp.Plain):
- abm.message_str += comp.text
+ logger.debug(f"[Lark] 解析消息内容: {content_json_b}")
+ parsed_components = await self._parse_message_components(
+ message_id=message.message_id,
+ message_type=message.message_type or "unknown",
+ content=content_json_b,
+ at_map=at_list,
+ )
+ abm.message.extend(parsed_components)
+ abm.message_str = self._build_message_str_from_components(parsed_components)
if message.message_id is None:
logger.error("[Lark] 消息缺少 message_id")
@@ -296,7 +585,6 @@ class LarkPlatformAdapter(Platform):
else:
abm.session_id = abm.sender.user_id
- logger.debug(abm)
await self.handle_msg(abm)
async def handle_msg(self, abm: AstrBotMessage) -> None:
diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py
index 83a455888..92e3a32b9 100644
--- a/astrbot/core/platform/sources/lark/lark_event.py
+++ b/astrbot/core/platform/sources/lark/lark_event.py
@@ -21,7 +21,7 @@ from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import At, File, Plain, Record, Video
from astrbot.api.message_components import Image as AstrBotImage
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.utils.media_utils import (
convert_audio_to_opus,
@@ -202,8 +202,11 @@ class LarkMessageEvent(AstrMessageEvent):
base64_str = comp.file.removeprefix("base64://")
image_data = base64.b64decode(base64_str)
# save as temp file
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
- file_path = os.path.join(temp_dir, f"{uuid.uuid4()}_test.jpg")
+ temp_dir = get_astrbot_temp_path()
+ file_path = os.path.join(
+ temp_dir,
+ f"lark_image_{uuid.uuid4().hex[:8]}.jpg",
+ )
with open(file_path, "wb") as f:
f.write(BytesIO(image_data).getvalue())
else:
diff --git a/astrbot/core/platform/sources/line/line_adapter.py b/astrbot/core/platform/sources/line/line_adapter.py
new file mode 100644
index 000000000..9348ff100
--- /dev/null
+++ b/astrbot/core/platform/sources/line/line_adapter.py
@@ -0,0 +1,474 @@
+import asyncio
+import mimetypes
+import time
+import uuid
+from pathlib import Path
+from typing import Any, cast
+
+from astrbot.api import logger
+from astrbot.api.event import MessageChain
+from astrbot.api.message_components import At, File, Image, Plain, Record, Video
+from astrbot.api.platform import (
+ AstrBotMessage,
+ Group,
+ MessageMember,
+ MessageType,
+ Platform,
+ PlatformMetadata,
+)
+from astrbot.core.platform.astr_message_event import MessageSesion
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
+from astrbot.core.utils.webhook_utils import log_webhook_info
+
+from ...register import register_platform_adapter
+from .line_api import LineAPIClient
+from .line_event import LineMessageEvent
+
+LINE_CONFIG_METADATA = {
+ "channel_access_token": {
+ "description": "LINE Channel Access Token",
+ "type": "string",
+ "hint": "LINE Messaging API 的 channel access token。",
+ },
+ "channel_secret": {
+ "description": "LINE Channel Secret",
+ "type": "string",
+ "hint": "用于校验 LINE Webhook 签名。",
+ },
+}
+
+LINE_I18N_RESOURCES = {
+ "zh-CN": {
+ "channel_access_token": {
+ "description": "LINE Channel Access Token",
+ "hint": "LINE Messaging API 的 channel access token。",
+ },
+ "channel_secret": {
+ "description": "LINE Channel Secret",
+ "hint": "用于校验 LINE Webhook 签名。",
+ },
+ },
+ "en-US": {
+ "channel_access_token": {
+ "description": "LINE Channel Access Token",
+ "hint": "Channel access token for LINE Messaging API.",
+ },
+ "channel_secret": {
+ "description": "LINE Channel Secret",
+ "hint": "Used to verify LINE webhook signatures.",
+ },
+ },
+}
+
+
+@register_platform_adapter(
+ "line",
+ "LINE Messaging API 适配器",
+ support_streaming_message=False,
+ default_config_tmpl={
+ "id": "line",
+ "type": "line",
+ "enable": False,
+ "channel_access_token": "",
+ "channel_secret": "",
+ "unified_webhook_mode": True,
+ "webhook_uuid": "",
+ },
+ config_metadata=LINE_CONFIG_METADATA,
+ i18n_resources=LINE_I18N_RESOURCES,
+)
+class LinePlatformAdapter(Platform):
+ def __init__(
+ self,
+ platform_config: dict,
+ platform_settings: dict,
+ event_queue: asyncio.Queue,
+ ) -> None:
+ super().__init__(platform_config, event_queue)
+ self.config["unified_webhook_mode"] = True
+ self.destination = "unknown"
+ self.settings = platform_settings
+ self._event_id_timestamps: dict[str, float] = {}
+ self.shutdown_event = asyncio.Event()
+
+ channel_access_token = str(platform_config.get("channel_access_token", ""))
+ channel_secret = str(platform_config.get("channel_secret", ""))
+ if not channel_access_token or not channel_secret:
+ raise ValueError(
+ "LINE 适配器需要 channel_access_token 和 channel_secret。",
+ )
+
+ self.line_api = LineAPIClient(
+ channel_access_token=channel_access_token,
+ channel_secret=channel_secret,
+ )
+
+ async def send_by_session(
+ self,
+ session: MessageSesion,
+ message_chain: MessageChain,
+ ) -> None:
+ messages = await LineMessageEvent.build_line_messages(message_chain)
+ if messages:
+ await self.line_api.push_message(session.session_id, messages)
+ await super().send_by_session(session, message_chain)
+
+ def meta(self) -> PlatformMetadata:
+ return PlatformMetadata(
+ name="line",
+ description="LINE Messaging API 适配器",
+ id=cast(str, self.config.get("id", "line")),
+ support_streaming_message=False,
+ )
+
+ async def run(self) -> None:
+ webhook_uuid = self.config.get("webhook_uuid")
+ if webhook_uuid:
+ log_webhook_info(f"{self.meta().id}(LINE)", webhook_uuid)
+ else:
+ logger.warning("[LINE] webhook_uuid 为空,统一 Webhook 可能无法接收消息。")
+ await self.shutdown_event.wait()
+
+ async def terminate(self) -> None:
+ self.shutdown_event.set()
+ await self.line_api.close()
+
+ async def webhook_callback(self, request: Any) -> Any:
+ raw_body = await request.get_data()
+ signature = request.headers.get("x-line-signature")
+ if not self.line_api.verify_signature(raw_body, signature):
+ logger.warning("[LINE] invalid webhook signature")
+ return "invalid signature", 400
+
+ try:
+ payload = await request.get_json(force=True, silent=False)
+ except Exception as e:
+ logger.warning("[LINE] invalid webhook body: %s", e)
+ return "bad request", 400
+
+ if not isinstance(payload, dict):
+ return "bad request", 400
+
+ await self.handle_webhook_event(payload)
+ return "ok", 200
+
+ async def handle_webhook_event(self, payload: dict[str, Any]) -> None:
+ destination = str(payload.get("destination", "")).strip()
+ if destination:
+ self.destination = destination
+
+ events = payload.get("events")
+ if not isinstance(events, list):
+ return
+
+ for event in events:
+ if not isinstance(event, dict):
+ continue
+
+ event_id = str(event.get("webhookEventId", ""))
+ if event_id and self._is_duplicate_event(event_id):
+ logger.debug("[LINE] duplicate event skipped: %s", event_id)
+ continue
+
+ abm = await self.convert_message(event)
+ if abm is None:
+ continue
+ await self.handle_msg(abm)
+
+ async def convert_message(self, event: dict[str, Any]) -> AstrBotMessage | None:
+ if str(event.get("type", "")) != "message":
+ return None
+ if str(event.get("mode", "active")) == "standby":
+ return None
+
+ source = event.get("source", {})
+ if not isinstance(source, dict):
+ return None
+
+ message = event.get("message", {})
+ if not isinstance(message, dict):
+ return None
+
+ source_type = str(source.get("type", ""))
+ user_id = str(source.get("userId", "")).strip()
+ group_id = str(source.get("groupId", "")).strip()
+ room_id = str(source.get("roomId", "")).strip()
+
+ abm = AstrBotMessage()
+ abm.self_id = self.destination or self.meta().id
+ abm.message = []
+ abm.raw_message = event
+ abm.message_id = str(
+ message.get("id")
+ or event.get("webhookEventId")
+ or event.get("deliveryContext", {}).get("deliveryId", "")
+ or uuid.uuid4().hex
+ )
+
+ event_timestamp = event.get("timestamp")
+ if isinstance(event_timestamp, int):
+ abm.timestamp = (
+ event_timestamp // 1000
+ if event_timestamp > 1_000_000_000_000
+ else event_timestamp
+ )
+ else:
+ abm.timestamp = int(time.time())
+
+ if source_type in {"group", "room"}:
+ abm.type = MessageType.GROUP_MESSAGE
+ container_id = group_id or room_id
+ abm.group = Group(group_id=container_id, group_name=container_id)
+ abm.session_id = container_id
+ sender_id = user_id or container_id
+ elif source_type == "user":
+ abm.type = MessageType.FRIEND_MESSAGE
+ abm.session_id = user_id
+ sender_id = user_id
+ else:
+ abm.type = MessageType.OTHER_MESSAGE
+ abm.session_id = user_id or group_id or room_id or "unknown"
+ sender_id = abm.session_id
+
+ abm.sender = MessageMember(user_id=sender_id, nickname=sender_id[:8])
+
+ components = await self._parse_line_message_components(message)
+ if not components:
+ return None
+ abm.message = components
+ abm.message_str = self._build_message_str(components)
+ return abm
+
+ async def _parse_line_message_components(
+ self,
+ message: dict[str, Any],
+ ) -> list:
+ msg_type = str(message.get("type", ""))
+ message_id = str(message.get("id", "")).strip()
+
+ if msg_type == "text":
+ text = str(message.get("text", ""))
+ mention = message.get("mention")
+ if isinstance(mention, dict):
+ return self._parse_text_with_mentions(text, mention)
+ return [Plain(text=text)] if text else []
+
+ if msg_type == "image":
+ image_component = await self._build_image_component(message_id, message)
+ return [image_component] if image_component else [Plain(text="[image]")]
+
+ if msg_type == "video":
+ video_component = await self._build_video_component(message_id, message)
+ return [video_component] if video_component else [Plain(text="[video]")]
+
+ if msg_type == "audio":
+ audio_component = await self._build_audio_component(message_id, message)
+ return [audio_component] if audio_component else [Plain(text="[audio]")]
+
+ if msg_type == "file":
+ file_component = await self._build_file_component(message_id, message)
+ return [file_component] if file_component else [Plain(text="[file]")]
+
+ if msg_type == "sticker":
+ return [Plain(text="[sticker]")]
+
+ return [Plain(text=f"[{msg_type}]")]
+
+ def _parse_text_with_mentions(self, text: str, mention_obj: dict[str, Any]) -> list:
+ mentions = mention_obj.get("mentionees", [])
+ if not isinstance(mentions, list) or not mentions:
+ return [Plain(text=text)] if text else []
+
+ normalized = []
+ for item in mentions:
+ if not isinstance(item, dict):
+ continue
+ start = item.get("index")
+ length = item.get("length")
+ if not isinstance(start, int) or not isinstance(length, int):
+ continue
+ normalized.append((start, length, item))
+ normalized.sort(key=lambda x: x[0])
+
+ ret = []
+ cursor = 0
+ for start, length, item in normalized:
+ if start > cursor:
+ part = text[cursor:start]
+ if part:
+ ret.append(Plain(text=part))
+
+ label = text[start : start + length] or "@user"
+ mention_type = str(item.get("type", ""))
+ if mention_type == "user":
+ target_id = str(item.get("userId", "")).strip()
+ ret.append(At(qq=target_id, name=label.lstrip("@")))
+ else:
+ ret.append(Plain(text=label))
+ cursor = max(cursor, start + length)
+
+ if cursor < len(text):
+ tail = text[cursor:]
+ if tail:
+ ret.append(Plain(text=tail))
+ return ret
+
+ async def _build_image_component(
+ self,
+ message_id: str,
+ message: dict[str, Any],
+ ) -> Image | None:
+ external_url = self._get_external_content_url(message)
+ if external_url:
+ return Image.fromURL(external_url)
+
+ content = await self.line_api.get_message_content(message_id)
+ if not content:
+ return None
+ content_bytes, _, _ = content
+ return Image.fromBytes(content_bytes)
+
+ async def _build_video_component(
+ self,
+ message_id: str,
+ message: dict[str, Any],
+ ) -> Video | None:
+ external_url = self._get_external_content_url(message)
+ if external_url:
+ return Video.fromURL(external_url)
+
+ content = await self.line_api.get_message_content(message_id)
+ if not content:
+ return None
+ content_bytes, content_type, _ = content
+ suffix = self._guess_suffix(content_type, ".mp4")
+ file_path = self._store_temp_content("video", message_id, content_bytes, suffix)
+ return Video(file=file_path, path=file_path)
+
+ async def _build_audio_component(
+ self,
+ message_id: str,
+ message: dict[str, Any],
+ ) -> Record | None:
+ external_url = self._get_external_content_url(message)
+ if external_url:
+ return Record.fromURL(external_url)
+
+ content = await self.line_api.get_message_content(message_id)
+ if not content:
+ return None
+ content_bytes, content_type, _ = content
+ suffix = self._guess_suffix(content_type, ".m4a")
+ file_path = self._store_temp_content("audio", message_id, content_bytes, suffix)
+ return Record(file=file_path, url=file_path)
+
+ async def _build_file_component(
+ self,
+ message_id: str,
+ message: dict[str, Any],
+ ) -> File | None:
+ content = await self.line_api.get_message_content(message_id)
+ if not content:
+ return None
+ content_bytes, content_type, filename = content
+ default_name = str(message.get("fileName", "")).strip() or f"{message_id}.bin"
+ suffix = Path(default_name).suffix or self._guess_suffix(content_type, ".bin")
+ final_name = filename or default_name
+ file_path = self._store_temp_content(
+ "file",
+ message_id,
+ content_bytes,
+ suffix,
+ original_name=final_name,
+ )
+ return File(name=final_name, file=file_path, url=file_path)
+
+ @staticmethod
+ def _get_external_content_url(message: dict[str, Any]) -> str:
+ provider = message.get("contentProvider")
+ if not isinstance(provider, dict):
+ return ""
+ if str(provider.get("type", "")) != "external":
+ return ""
+ return str(provider.get("originalContentUrl", "")).strip()
+
+ @staticmethod
+ def _guess_suffix(content_type: str | None, fallback: str) -> str:
+ if not content_type:
+ return fallback
+ base_type = content_type.split(";", 1)[0].strip().lower()
+ guessed = mimetypes.guess_extension(base_type)
+ if guessed:
+ return guessed
+ return fallback
+
+ @staticmethod
+ def _store_temp_content(
+ content_type: str,
+ message_id: str,
+ content: bytes,
+ suffix: str,
+ original_name: str = "",
+ ) -> str:
+ temp_dir = Path(get_astrbot_temp_path())
+ temp_dir.mkdir(parents=True, exist_ok=True)
+ name_prefix = f"line_{content_type}"
+ if original_name:
+ safe_stem = Path(original_name).stem.strip()
+ safe_stem = "".join(
+ ch if ch.isalnum() or ch in ("-", "_", ".") else "_" for ch in safe_stem
+ )
+ safe_stem = safe_stem.strip("._")
+ if safe_stem:
+ name_prefix = safe_stem[:64]
+ file_path = temp_dir / f"{name_prefix}_{message_id}_{uuid.uuid4().hex[:6]}"
+ file_path = file_path.with_suffix(suffix)
+ file_path.write_bytes(content)
+ return str(file_path.resolve())
+
+ @staticmethod
+ def _build_message_str(components: list) -> str:
+ parts: list[str] = []
+ for comp in components:
+ if isinstance(comp, Plain):
+ parts.append(comp.text)
+ elif isinstance(comp, At):
+ parts.append(f"@{comp.name or comp.qq}")
+ elif isinstance(comp, Image):
+ parts.append("[image]")
+ elif isinstance(comp, Video):
+ parts.append("[video]")
+ elif isinstance(comp, Record):
+ parts.append("[audio]")
+ elif isinstance(comp, File):
+ parts.append(str(comp.name or "[file]"))
+ else:
+ parts.append(f"[{comp.type}]")
+ return " ".join(i for i in parts if i).strip()
+
+ def _clean_expired_events(self) -> None:
+ current = time.time()
+ expired = [
+ event_id
+ for event_id, ts in self._event_id_timestamps.items()
+ if current - ts > 1800
+ ]
+ for event_id in expired:
+ del self._event_id_timestamps[event_id]
+
+ def _is_duplicate_event(self, event_id: str) -> bool:
+ self._clean_expired_events()
+ if event_id in self._event_id_timestamps:
+ return True
+ self._event_id_timestamps[event_id] = time.time()
+ return False
+
+ async def handle_msg(self, abm: AstrBotMessage) -> None:
+ event = LineMessageEvent(
+ message_str=abm.message_str,
+ message_obj=abm,
+ platform_meta=self.meta(),
+ session_id=abm.session_id,
+ line_api=self.line_api,
+ )
+ self._event_queue.put_nowait(event)
diff --git a/astrbot/core/platform/sources/line/line_api.py b/astrbot/core/platform/sources/line/line_api.py
new file mode 100644
index 000000000..32204bd6e
--- /dev/null
+++ b/astrbot/core/platform/sources/line/line_api.py
@@ -0,0 +1,203 @@
+import asyncio
+import base64
+import hmac
+import json
+from hashlib import sha256
+from typing import Any
+from urllib.parse import unquote
+
+import aiohttp
+
+from astrbot.api import logger
+
+
+class LineAPIClient:
+ def __init__(
+ self,
+ *,
+ channel_access_token: str,
+ channel_secret: str,
+ timeout_seconds: int = 30,
+ ) -> None:
+ self.channel_access_token = channel_access_token.strip()
+ self.channel_secret = channel_secret.strip()
+ self.timeout = aiohttp.ClientTimeout(total=timeout_seconds)
+ self._session: aiohttp.ClientSession | None = None
+
+ async def _get_session(self) -> aiohttp.ClientSession:
+ if self._session is None or self._session.closed:
+ self._session = aiohttp.ClientSession(timeout=self.timeout)
+ return self._session
+
+ async def close(self) -> None:
+ if self._session and not self._session.closed:
+ await self._session.close()
+
+ def verify_signature(self, raw_body: bytes, signature: str | None) -> bool:
+ if not signature:
+ return False
+ digest = hmac.new(
+ self.channel_secret.encode("utf-8"),
+ raw_body,
+ sha256,
+ ).digest()
+ expected = base64.b64encode(digest).decode("utf-8")
+ return hmac.compare_digest(expected, signature.strip())
+
+ @property
+ def _auth_headers(self) -> dict[str, str]:
+ return {"Authorization": f"Bearer {self.channel_access_token}"}
+
+ async def reply_message(
+ self,
+ reply_token: str,
+ messages: list[dict[str, Any]],
+ *,
+ notification_disabled: bool = False,
+ ) -> bool:
+ payload = {
+ "replyToken": reply_token,
+ "messages": messages[:5],
+ "notificationDisabled": notification_disabled,
+ }
+ return await self._post_json(
+ "https://api.line.me/v2/bot/message/reply",
+ payload=payload,
+ op_name="reply",
+ )
+
+ async def push_message(
+ self,
+ to: str,
+ messages: list[dict[str, Any]],
+ *,
+ notification_disabled: bool = False,
+ ) -> bool:
+ payload = {
+ "to": to,
+ "messages": messages[:5],
+ "notificationDisabled": notification_disabled,
+ }
+ return await self._post_json(
+ "https://api.line.me/v2/bot/message/push",
+ payload=payload,
+ op_name="push",
+ )
+
+ async def _post_json(
+ self,
+ url: str,
+ *,
+ payload: dict[str, Any],
+ op_name: str,
+ ) -> bool:
+ session = await self._get_session()
+ headers = {
+ **self._auth_headers,
+ "Content-Type": "application/json",
+ }
+ try:
+ async with session.post(url, json=payload, headers=headers) as resp:
+ if resp.status < 400:
+ return True
+ body = await resp.text()
+ logger.error(
+ "[LINE] %s message failed: status=%s body=%s",
+ op_name,
+ resp.status,
+ body,
+ )
+ return False
+ except Exception as e:
+ logger.error("[LINE] %s message request failed: %s", op_name, e)
+ return False
+
+ async def get_message_content(
+ self,
+ message_id: str,
+ ) -> tuple[bytes, str | None, str | None] | None:
+ session = await self._get_session()
+ url = f"https://api-data.line.me/v2/bot/message/{message_id}/content"
+ headers = self._auth_headers
+
+ async with session.get(url, headers=headers) as resp:
+ if resp.status == 202:
+ if not await self._wait_for_transcoding(message_id):
+ return None
+ async with session.get(url, headers=headers) as retry_resp:
+ if retry_resp.status != 200:
+ body = await retry_resp.text()
+ logger.warning(
+ "[LINE] get content retry failed: message_id=%s status=%s body=%s",
+ message_id,
+ retry_resp.status,
+ body,
+ )
+ return None
+ return await self._read_content_response(retry_resp)
+
+ if resp.status != 200:
+ body = await resp.text()
+ logger.warning(
+ "[LINE] get content failed: message_id=%s status=%s body=%s",
+ message_id,
+ resp.status,
+ body,
+ )
+ return None
+ return await self._read_content_response(resp)
+
+ async def _read_content_response(
+ self,
+ resp: aiohttp.ClientResponse,
+ ) -> tuple[bytes, str | None, str | None]:
+ content = await resp.read()
+ content_type = resp.headers.get("Content-Type")
+ disposition = resp.headers.get("Content-Disposition")
+ filename = self._extract_filename_from_disposition(disposition)
+ return content, content_type, filename
+
+ def _extract_filename_from_disposition(self, disposition: str | None) -> str | None:
+ if not disposition:
+ return None
+ for part in disposition.split(";"):
+ token = part.strip()
+ if token.startswith("filename*="):
+ val = token.split("=", 1)[1].strip().strip('"')
+ if val.lower().startswith("utf-8''"):
+ val = val[7:]
+ return unquote(val)
+ if token.startswith("filename="):
+ return token.split("=", 1)[1].strip().strip('"')
+ return None
+
+ async def _wait_for_transcoding(
+ self,
+ message_id: str,
+ *,
+ max_attempts: int = 10,
+ interval_seconds: float = 1.0,
+ ) -> bool:
+ session = await self._get_session()
+ url = (
+ f"https://api-data.line.me/v2/bot/message/{message_id}/content/transcoding"
+ )
+ headers = self._auth_headers
+
+ for _ in range(max_attempts):
+ try:
+ async with session.get(url, headers=headers) as resp:
+ if resp.status != 200:
+ await asyncio.sleep(interval_seconds)
+ continue
+ body = await resp.text()
+ data = json.loads(body)
+ status = str(data.get("status", "")).lower()
+ if status == "succeeded":
+ return True
+ if status == "failed":
+ return False
+ except Exception:
+ pass
+ await asyncio.sleep(interval_seconds)
+ return False
diff --git a/astrbot/core/platform/sources/line/line_event.py b/astrbot/core/platform/sources/line/line_event.py
new file mode 100644
index 000000000..04be53922
--- /dev/null
+++ b/astrbot/core/platform/sources/line/line_event.py
@@ -0,0 +1,285 @@
+import asyncio
+import os
+import re
+import uuid
+from collections.abc import AsyncGenerator
+from pathlib import Path
+
+from astrbot.api import logger
+from astrbot.api.event import AstrMessageEvent, MessageChain
+from astrbot.api.message_components import (
+ At,
+ BaseMessageComponent,
+ File,
+ Image,
+ Plain,
+ Record,
+ Video,
+)
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
+from astrbot.core.utils.media_utils import get_media_duration
+
+from .line_api import LineAPIClient
+
+
+class LineMessageEvent(AstrMessageEvent):
+ def __init__(
+ self,
+ message_str,
+ message_obj,
+ platform_meta,
+ session_id,
+ line_api: LineAPIClient,
+ ) -> None:
+ super().__init__(message_str, message_obj, platform_meta, session_id)
+ self.line_api = line_api
+
+ @staticmethod
+ async def _component_to_message_object(
+ segment: BaseMessageComponent,
+ ) -> dict | None:
+ if isinstance(segment, Plain):
+ text = segment.text.strip()
+ if not text:
+ return None
+ return {"type": "text", "text": text[:5000]}
+
+ if isinstance(segment, At):
+ name = str(segment.name or segment.qq or "").strip()
+ if not name:
+ return None
+ return {"type": "text", "text": f"@{name}"[:5000]}
+
+ if isinstance(segment, Image):
+ image_url = await LineMessageEvent._resolve_image_url(segment)
+ if not image_url:
+ return None
+ return {
+ "type": "image",
+ "originalContentUrl": image_url,
+ "previewImageUrl": image_url,
+ }
+
+ if isinstance(segment, Record):
+ audio_url = await LineMessageEvent._resolve_record_url(segment)
+ if not audio_url:
+ return None
+ duration = await LineMessageEvent._resolve_record_duration(segment)
+ return {
+ "type": "audio",
+ "originalContentUrl": audio_url,
+ "duration": duration,
+ }
+
+ if isinstance(segment, Video):
+ video_url = await LineMessageEvent._resolve_video_url(segment)
+ if not video_url:
+ return None
+ preview_url = await LineMessageEvent._resolve_video_preview_url(segment)
+ if not preview_url:
+ return None
+ return {
+ "type": "video",
+ "originalContentUrl": video_url,
+ "previewImageUrl": preview_url,
+ }
+
+ if isinstance(segment, File):
+ file_url = await LineMessageEvent._resolve_file_url(segment)
+ if not file_url:
+ return None
+ file_name = str(segment.name or "").strip() or "file.bin"
+ file_size = await LineMessageEvent._resolve_file_size(segment)
+ if file_size <= 0:
+ return None
+ return {
+ "type": "file",
+ "fileName": file_name,
+ "fileSize": file_size,
+ "originalContentUrl": file_url,
+ }
+
+ return None
+
+ @staticmethod
+ async def _resolve_image_url(segment: Image) -> str:
+ candidate = (segment.url or segment.file or "").strip()
+ if candidate.startswith("http://") or candidate.startswith("https://"):
+ return candidate
+ try:
+ return await segment.register_to_file_service()
+ except Exception as e:
+ logger.debug("[LINE] resolve image url failed: %s", e)
+ return ""
+
+ @staticmethod
+ async def _resolve_record_url(segment: Record) -> str:
+ candidate = (segment.url or segment.file or "").strip()
+ if candidate.startswith("http://") or candidate.startswith("https://"):
+ return candidate
+ try:
+ return await segment.register_to_file_service()
+ except Exception as e:
+ logger.debug("[LINE] resolve record url failed: %s", e)
+ return ""
+
+ @staticmethod
+ async def _resolve_record_duration(segment: Record) -> int:
+ try:
+ file_path = await segment.convert_to_file_path()
+ duration_ms = await get_media_duration(file_path)
+ if isinstance(duration_ms, int) and duration_ms > 0:
+ return duration_ms
+ except Exception as e:
+ logger.debug("[LINE] resolve record duration failed: %s", e)
+ return 1000
+
+ @staticmethod
+ async def _resolve_video_url(segment: Video) -> str:
+ candidate = (segment.file or "").strip()
+ if candidate.startswith("http://") or candidate.startswith("https://"):
+ return candidate
+ try:
+ return await segment.register_to_file_service()
+ except Exception as e:
+ logger.debug("[LINE] resolve video url failed: %s", e)
+ return ""
+
+ @staticmethod
+ async def _resolve_video_preview_url(segment: Video) -> str:
+ cover_candidate = (segment.cover or "").strip()
+ if cover_candidate.startswith("http://") or cover_candidate.startswith(
+ "https://"
+ ):
+ return cover_candidate
+
+ if cover_candidate:
+ try:
+ cover_seg = Image(file=cover_candidate)
+ return await cover_seg.register_to_file_service()
+ except Exception as e:
+ logger.debug("[LINE] resolve video cover failed: %s", e)
+
+ try:
+ video_path = await segment.convert_to_file_path()
+ temp_dir = Path(get_astrbot_temp_path())
+ temp_dir.mkdir(parents=True, exist_ok=True)
+ thumb_path = temp_dir / f"line_video_preview_{uuid.uuid4().hex}.jpg"
+
+ process = await asyncio.create_subprocess_exec(
+ "ffmpeg",
+ "-y",
+ "-ss",
+ "00:00:01",
+ "-i",
+ video_path,
+ "-frames:v",
+ "1",
+ str(thumb_path),
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE,
+ )
+ await process.communicate()
+ if process.returncode != 0 or not thumb_path.exists():
+ return ""
+
+ cover_seg = Image.fromFileSystem(str(thumb_path))
+ return await cover_seg.register_to_file_service()
+ except Exception as e:
+ logger.debug("[LINE] generate video preview failed: %s", e)
+ return ""
+
+ @staticmethod
+ async def _resolve_file_url(segment: File) -> str:
+ if segment.url and segment.url.startswith(("http://", "https://")):
+ return segment.url
+ try:
+ return await segment.register_to_file_service()
+ except Exception as e:
+ logger.debug("[LINE] resolve file url failed: %s", e)
+ return ""
+
+ @staticmethod
+ async def _resolve_file_size(segment: File) -> int:
+ try:
+ file_path = await segment.get_file(allow_return_url=False)
+ if file_path and os.path.exists(file_path):
+ return int(os.path.getsize(file_path))
+ except Exception as e:
+ logger.debug("[LINE] resolve file size failed: %s", e)
+ return 0
+
+ @classmethod
+ async def build_line_messages(cls, message_chain: MessageChain) -> list[dict]:
+ messages: list[dict] = []
+ for segment in message_chain.chain:
+ obj = await cls._component_to_message_object(segment)
+ if obj:
+ messages.append(obj)
+
+ if not messages:
+ return []
+
+ if len(messages) > 5:
+ logger.warning(
+ "[LINE] message count exceeds 5, extra segments will be dropped."
+ )
+ messages = messages[:5]
+ return messages
+
+ async def send(self, message: MessageChain) -> None:
+ messages = await self.build_line_messages(message)
+ if not messages:
+ return
+
+ raw = self.message_obj.raw_message
+ reply_token = ""
+ if isinstance(raw, dict):
+ reply_token = str(raw.get("replyToken") or "")
+
+ sent = False
+ if reply_token:
+ sent = await self.line_api.reply_message(reply_token, messages)
+
+ if not sent:
+ target_id = self.get_group_id() or self.get_sender_id()
+ if target_id:
+ await self.line_api.push_message(target_id, messages)
+
+ await super().send(message)
+
+ async def send_streaming(
+ self,
+ generator: AsyncGenerator,
+ use_fallback: bool = False,
+ ):
+ if not use_fallback:
+ buffer = None
+ async for chain in generator:
+ if not buffer:
+ buffer = chain
+ else:
+ buffer.chain.extend(chain.chain)
+ if not buffer:
+ return None
+ buffer.squash_plain()
+ await self.send(buffer)
+ return await super().send_streaming(generator, use_fallback)
+
+ buffer = ""
+ pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
+
+ async for chain in generator:
+ if isinstance(chain, MessageChain):
+ for comp in chain.chain:
+ if isinstance(comp, Plain):
+ buffer += comp.text
+ if any(p in buffer for p in "。?!~…"):
+ buffer = await self.process_buffer(buffer, pattern)
+ else:
+ await self.send(MessageChain(chain=[comp]))
+ await asyncio.sleep(1.5)
+
+ if buffer.strip():
+ await self.send(MessageChain([Plain(buffer)]))
+ return await super().send_streaming(generator, use_fallback)
diff --git a/astrbot/core/platform/sources/misskey/misskey_adapter.py b/astrbot/core/platform/sources/misskey/misskey_adapter.py
index 34415b855..fd61c3e50 100644
--- a/astrbot/core/platform/sources/misskey/misskey_adapter.py
+++ b/astrbot/core/platform/sources/misskey/misskey_adapter.py
@@ -21,7 +21,7 @@ try:
except Exception:
magic = None
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from .misskey_event import MisskeyPlatformEvent
from .misskey_utils import (
@@ -498,7 +498,7 @@ class MisskeyPlatformAdapter(Platform):
finally:
# 清理临时文件
if local_path and isinstance(local_path, str):
- data_temp = os.path.join(get_astrbot_data_path(), "temp")
+ data_temp = get_astrbot_temp_path()
if local_path.startswith(data_temp) and os.path.exists(
local_path,
):
diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py
index 88c8fc225..1af4de49b 100644
--- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py
+++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py
@@ -19,7 +19,7 @@ from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Image, Plain, Record
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_image_by_url, file_to_base64
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
@@ -350,10 +350,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
elif isinstance(i, Record):
if i.file:
record_wav_path = await i.convert_to_file_path() # wav 路径
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
record_tecent_silk_path = os.path.join(
temp_dir,
- f"{uuid.uuid4()}.silk",
+ f"qqofficial_{uuid.uuid4()}.silk",
)
try:
duration = await wav_to_tencent_silk(
diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py
index 86ca76db8..603bc8f58 100644
--- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py
+++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py
@@ -8,13 +8,11 @@ from typing import cast
import botpy
import botpy.message
-import botpy.types
-import botpy.types.message
from botpy import Client
from astrbot import logger
from astrbot.api.event import MessageChain
-from astrbot.api.message_components import At, Image, Plain
+from astrbot.api.message_components import At, File, Image, Plain
from astrbot.api.platform import (
AstrBotMessage,
MessageMember,
@@ -143,6 +141,41 @@ class QQOfficialPlatformAdapter(Platform):
support_proactive_message=False,
)
+ @staticmethod
+ def _normalize_attachment_url(url: str | None) -> str:
+ if not url:
+ return ""
+ if url.startswith("http://") or url.startswith("https://"):
+ return url
+ return f"https://{url}"
+
+ @staticmethod
+ def _append_attachments(
+ msg: list[BaseMessageComponent],
+ attachments: list | None,
+ ) -> None:
+ if not attachments:
+ return
+
+ for attachment in attachments:
+ content_type = cast(str, getattr(attachment, "content_type", "") or "")
+ url = QQOfficialPlatformAdapter._normalize_attachment_url(
+ cast(str | None, getattr(attachment, "url", None))
+ )
+ if not url:
+ continue
+
+ if content_type.startswith("image"):
+ msg.append(Image.fromURL(url))
+ else:
+ filename = cast(
+ str,
+ getattr(attachment, "filename", None)
+ or getattr(attachment, "name", None)
+ or "attachment",
+ )
+ msg.append(File(name=filename, file=url, url=url))
+
@staticmethod
def _parse_from_qqofficial(
message: botpy.message.Message
@@ -172,14 +205,7 @@ class QQOfficialPlatformAdapter(Platform):
abm.self_id = "unknown_selfid"
msg.append(At(qq="qq_official"))
msg.append(Plain(abm.message_str))
- if message.attachments:
- for i in message.attachments:
- if i.content_type.startswith("image"):
- url = i.url
- if not url.startswith("http"):
- url = "https://" + url
- img = Image.fromURL(url)
- msg.append(img)
+ QQOfficialPlatformAdapter._append_attachments(msg, message.attachments)
abm.message = msg
elif isinstance(message, botpy.message.Message) or isinstance(
@@ -196,14 +222,7 @@ class QQOfficialPlatformAdapter(Platform):
"",
).strip()
- if message.attachments:
- for i in message.attachments:
- if i.content_type.startswith("image"):
- url = i.url
- if not url.startswith("http"):
- url = "https://" + url
- img = Image.fromURL(url)
- msg.append(img)
+ QQOfficialPlatformAdapter._append_attachments(msg, message.attachments)
abm.message = msg
abm.message_str = plain_content
abm.sender = MessageMember(
diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py
index c709f2cec..6aae6b9ce 100644
--- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py
+++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py
@@ -1,11 +1,11 @@
import asyncio
import logging
+import random
+from types import SimpleNamespace
from typing import Any, cast
import botpy
import botpy.message
-import botpy.types
-import botpy.types.message
from botpy import Client
from astrbot import logger
@@ -15,6 +15,7 @@ from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
+from ..qqofficial.qqofficial_message_event import QQOfficialMessageEvent
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
from .qo_webhook_event import QQOfficialWebhookMessageEvent
from .qo_webhook_server import QQOfficialWebhook
@@ -39,6 +40,7 @@ class botClient(Client):
)
abm.group_id = cast(str, message.group_openid)
abm.session_id = abm.group_id
+ self.platform.remember_session_scene(abm.session_id, "group")
self._commit(abm)
# 收到频道消息
@@ -49,6 +51,7 @@ class botClient(Client):
)
abm.group_id = message.channel_id
abm.session_id = abm.group_id
+ self.platform.remember_session_scene(abm.session_id, "channel")
self._commit(abm)
# 收到私聊消息
@@ -60,6 +63,7 @@ class botClient(Client):
MessageType.FRIEND_MESSAGE,
)
abm.session_id = abm.sender.user_id
+ self.platform.remember_session_scene(abm.session_id, "friend")
self._commit(abm)
# 收到 C2C 消息
@@ -69,9 +73,11 @@ class botClient(Client):
MessageType.FRIEND_MESSAGE,
)
abm.session_id = abm.sender.user_id
+ self.platform.remember_session_scene(abm.session_id, "friend")
self._commit(abm)
def _commit(self, abm: AstrBotMessage) -> None:
+ self.platform.remember_session_message_id(abm.session_id, abm.message_id)
self.platform.commit_event(
QQOfficialWebhookMessageEvent(
abm.message_str,
@@ -109,20 +115,129 @@ class QQOfficialWebhookPlatformAdapter(Platform):
)
self.client.set_platform(self)
self.webhook_helper = None
+ self._session_last_message_id: dict[str, str] = {}
+ self._session_scene: dict[str, str] = {}
async def send_by_session(
self,
session: MessageSesion,
message_chain: MessageChain,
) -> None:
- raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
+ (
+ plain_text,
+ image_base64,
+ image_path,
+ record_file_path,
+ ) = await QQOfficialMessageEvent._parse_to_qqofficial(message_chain)
+ if not plain_text and not image_path:
+ return
+
+ msg_id = self._session_last_message_id.get(session.session_id)
+ if not msg_id:
+ logger.warning(
+ "[QQOfficialWebhook] No cached msg_id for session: %s, skip send_by_session",
+ session.session_id,
+ )
+ return
+
+ payload: dict[str, Any] = {"content": plain_text, "msg_id": msg_id}
+ ret: Any = None
+ send_helper = SimpleNamespace(bot=self.client)
+ if session.message_type == MessageType.GROUP_MESSAGE:
+ scene = self._session_scene.get(session.session_id)
+ if scene == "group":
+ payload["msg_seq"] = random.randint(1, 10000)
+ if image_base64:
+ media = await QQOfficialMessageEvent.upload_group_and_c2c_image(
+ send_helper, # type: ignore
+ image_base64,
+ 1,
+ group_openid=session.session_id,
+ )
+ payload["media"] = media
+ payload["msg_type"] = 7
+ if record_file_path:
+ media = await QQOfficialMessageEvent.upload_group_and_c2c_record(
+ send_helper, # type: ignore
+ record_file_path,
+ 3,
+ group_openid=session.session_id,
+ )
+ payload["media"] = media
+ payload["msg_type"] = 7
+ ret = await self.client.api.post_group_message(
+ group_openid=session.session_id,
+ **payload,
+ )
+ else:
+ if image_path:
+ payload["file_image"] = image_path
+ ret = await self.client.api.post_message(
+ channel_id=session.session_id,
+ **payload,
+ )
+ elif session.message_type == MessageType.FRIEND_MESSAGE:
+ payload["msg_seq"] = random.randint(1, 10000)
+ if image_base64:
+ media = await QQOfficialMessageEvent.upload_group_and_c2c_image(
+ send_helper, # type: ignore
+ image_base64,
+ 1,
+ openid=session.session_id,
+ )
+ payload["media"] = media
+ payload["msg_type"] = 7
+ if record_file_path:
+ media = await QQOfficialMessageEvent.upload_group_and_c2c_record(
+ send_helper, # type: ignore
+ record_file_path,
+ 3,
+ openid=session.session_id,
+ )
+ payload["media"] = media
+ payload["msg_type"] = 7
+ ret = await QQOfficialMessageEvent.post_c2c_message(
+ send_helper, # type: ignore
+ openid=session.session_id,
+ **payload,
+ )
+ else:
+ logger.warning(
+ "[QQOfficialWebhook] Unsupported message type for send_by_session: %s",
+ session.message_type,
+ )
+ return
+
+ sent_message_id = self._extract_message_id(ret)
+ if sent_message_id:
+ self.remember_session_message_id(session.session_id, sent_message_id)
+ await super().send_by_session(session, message_chain)
+
+ def remember_session_message_id(self, session_id: str, message_id: str) -> None:
+ if not session_id or not message_id:
+ return
+ self._session_last_message_id[session_id] = message_id
+
+ def remember_session_scene(self, session_id: str, scene: str) -> None:
+ if not session_id or not scene:
+ return
+ self._session_scene[session_id] = scene
+
+ def _extract_message_id(self, ret: Any) -> str | None:
+ if isinstance(ret, dict):
+ message_id = ret.get("id")
+ return str(message_id) if message_id else None
+ message_id = getattr(ret, "id", None)
+ if message_id:
+ return str(message_id)
+ return None
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="qq_official_webhook",
description="QQ 机器人官方 API 适配器",
id=cast(str, self.config.get("id")),
- support_proactive_message=False,
+ support_proactive_message=True,
)
async def run(self) -> None:
diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py
index 1df289d83..ffa58e1a8 100644
--- a/astrbot/core/platform/sources/telegram/tg_event.py
+++ b/astrbot/core/platform/sources/telegram/tg_event.py
@@ -5,6 +5,7 @@ from typing import Any, cast
import telegramify_markdown
from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji
+from telegram.constants import ChatAction
from telegram.ext import ExtBot
from astrbot import logger
@@ -31,6 +32,14 @@ class TelegramPlatformEvent(AstrMessageEvent):
"word": re.compile(r"\s"),
}
+ # 消息类型到 chat action 的映射,用于优先级判断
+ ACTION_BY_TYPE: dict[type, str] = {
+ Record: ChatAction.UPLOAD_VOICE,
+ File: ChatAction.UPLOAD_DOCUMENT,
+ Image: ChatAction.UPLOAD_PHOTO,
+ Plain: ChatAction.TYPING,
+ }
+
def __init__(
self,
message_str: str,
@@ -67,6 +76,71 @@ class TelegramPlatformEvent(AstrMessageEvent):
return chunks
+ @classmethod
+ async def _send_chat_action(
+ cls,
+ client: ExtBot,
+ chat_id: str,
+ action: ChatAction | str,
+ message_thread_id: str | None = None,
+ ) -> None:
+ """发送聊天状态动作"""
+ try:
+ payload: dict[str, Any] = {"chat_id": chat_id, "action": action}
+ if message_thread_id:
+ payload["message_thread_id"] = message_thread_id
+ await client.send_chat_action(**payload)
+ except Exception as e:
+ logger.warning(f"[Telegram] 发送 chat action 失败: {e}")
+
+ @classmethod
+ def _get_chat_action_for_chain(cls, chain: list[Any]) -> ChatAction | str:
+ """根据消息链中的组件类型确定合适的 chat action(按优先级)"""
+ for seg_type, action in cls.ACTION_BY_TYPE.items():
+ if any(isinstance(seg, seg_type) for seg in chain):
+ return action
+ return ChatAction.TYPING
+
+ @classmethod
+ async def _send_media_with_action(
+ cls,
+ client: ExtBot,
+ upload_action: ChatAction | str,
+ send_coro,
+ *,
+ user_name: str,
+ message_thread_id: str | None = None,
+ **payload: Any,
+ ) -> None:
+ """发送媒体时显示 upload action,发送完成后恢复 typing"""
+ await cls._send_chat_action(client, user_name, upload_action, message_thread_id)
+ await send_coro(**payload)
+ await cls._send_chat_action(
+ client, user_name, ChatAction.TYPING, message_thread_id
+ )
+
+ async def _ensure_typing(
+ self,
+ user_name: str,
+ message_thread_id: str | None = None,
+ ) -> None:
+ """确保显示 typing 状态"""
+ await self._send_chat_action(
+ self.client, user_name, ChatAction.TYPING, message_thread_id
+ )
+
+ async def send_typing(self) -> None:
+ message_thread_id = None
+ if self.get_message_type() == MessageType.GROUP_MESSAGE:
+ user_name = self.message_obj.group_id
+ else:
+ user_name = self.get_sender_id()
+
+ if "#" in user_name:
+ user_name, message_thread_id = user_name.split("#")
+
+ await self._ensure_typing(user_name, message_thread_id)
+
@classmethod
async def send_with_client(
cls,
@@ -91,6 +165,11 @@ class TelegramPlatformEvent(AstrMessageEvent):
if "#" in user_name:
# it's a supergroup chat with message_thread_id
user_name, message_thread_id = user_name.split("#")
+
+ # 根据消息链确定合适的 chat action 并发送
+ action = cls._get_chat_action_for_chain(message.chain)
+ await cls._send_chat_action(client, user_name, action, message_thread_id)
+
for i in message.chain:
payload = {
"chat_id": user_name,
@@ -195,6 +274,12 @@ class TelegramPlatformEvent(AstrMessageEvent):
message_id = None
last_edit_time = 0 # 上次编辑消息的时间
throttle_interval = 0.6 # 编辑消息的间隔时间 (秒)
+ last_chat_action_time = 0 # 上次发送 chat action 的时间
+ chat_action_interval = 0.5 # chat action 的节流间隔 (秒)
+
+ # 发送初始 typing 状态
+ await self._ensure_typing(user_name, message_thread_id)
+ last_chat_action_time = asyncio.get_event_loop().time()
async for chain in generator:
if isinstance(chain, MessageChain):
@@ -219,15 +304,25 @@ class TelegramPlatformEvent(AstrMessageEvent):
delta += i.text
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
- await self.client.send_photo(
- photo=image_path, **cast(Any, payload)
+ await self._send_media_with_action(
+ self.client,
+ ChatAction.UPLOAD_PHOTO,
+ self.client.send_photo,
+ user_name=user_name,
+ message_thread_id=message_thread_id,
+ photo=image_path,
+ **cast(Any, payload),
)
continue
elif isinstance(i, File):
path = await i.get_file()
name = i.name or os.path.basename(path)
-
- await self.client.send_document(
+ await self._send_media_with_action(
+ self.client,
+ ChatAction.UPLOAD_DOCUMENT,
+ self.client.send_document,
+ user_name=user_name,
+ message_thread_id=message_thread_id,
document=path,
filename=name,
**cast(Any, payload),
@@ -235,7 +330,15 @@ class TelegramPlatformEvent(AstrMessageEvent):
continue
elif isinstance(i, Record):
path = await i.convert_to_file_path()
- await self.client.send_voice(voice=path, **cast(Any, payload))
+ await self._send_media_with_action(
+ self.client,
+ ChatAction.UPLOAD_VOICE,
+ self.client.send_voice,
+ user_name=user_name,
+ message_thread_id=message_thread_id,
+ voice=path,
+ **cast(Any, payload),
+ )
continue
else:
logger.warning(f"不支持的消息类型: {type(i)}")
@@ -248,6 +351,11 @@ class TelegramPlatformEvent(AstrMessageEvent):
# 如果距离上次编辑的时间 >= 设定的间隔,等待一段时间
if time_since_last_edit >= throttle_interval:
+ # 发送 typing 状态(带节流)
+ current_time = asyncio.get_event_loop().time()
+ if current_time - last_chat_action_time >= chat_action_interval:
+ await self._ensure_typing(user_name, message_thread_id)
+ last_chat_action_time = current_time
# 编辑消息
try:
await self.client.edit_message_text(
@@ -263,6 +371,11 @@ class TelegramPlatformEvent(AstrMessageEvent):
) # 更新上次编辑的时间
else:
# delta 长度一般不会大于 4096,因此这里直接发送
+ # 发送 typing 状态(带节流)
+ current_time = asyncio.get_event_loop().time()
+ if current_time - last_chat_action_time >= chat_action_interval:
+ await self._ensure_typing(user_name, message_thread_id)
+ last_chat_action_time = current_time
try:
msg = await self.client.send_message(
text=delta, **cast(Any, payload)
diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py
index 5eb62e6b3..047417aaa 100644
--- a/astrbot/core/platform/sources/webchat/webchat_adapter.py
+++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py
@@ -26,14 +26,23 @@ from .webchat_queue_mgr import WebChatQueueMgr, webchat_queue_mgr
class QueueListener:
- def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> None:
+ def __init__(
+ self,
+ webchat_queue_mgr: WebChatQueueMgr,
+ callback: Callable,
+ stop_event: asyncio.Event,
+ ) -> None:
self.webchat_queue_mgr = webchat_queue_mgr
self.callback = callback
+ self.stop_event = stop_event
async def run(self) -> None:
"""Register callback and keep adapter task alive."""
self.webchat_queue_mgr.set_listener(self.callback)
- await asyncio.Event().wait()
+ try:
+ await self.stop_event.wait()
+ finally:
+ await self.webchat_queue_mgr.clear_listener()
@register_platform_adapter("webchat", "webchat")
@@ -56,6 +65,8 @@ class WebChatAdapter(Platform):
id="webchat",
support_proactive_message=False,
)
+ self._shutdown_event = asyncio.Event()
+ self._webchat_queue_mgr = webchat_queue_mgr
async def send_by_session(
self,
@@ -184,7 +195,7 @@ class WebChatAdapter(Platform):
abm = await self.convert_message(data)
await self.handle_msg(abm)
- bot = QueueListener(webchat_queue_mgr, callback)
+ bot = QueueListener(self._webchat_queue_mgr, callback, self._shutdown_event)
return bot.run()
def meta(self) -> PlatformMetadata:
@@ -209,5 +220,4 @@ class WebChatAdapter(Platform):
self.commit_event(message_event)
async def terminate(self) -> None:
- # Do nothing
- pass
+ self._shutdown_event.set()
diff --git a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py
index 688d83e2c..fd35e837c 100644
--- a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py
+++ b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py
@@ -87,6 +87,19 @@ class WebChatQueueMgr:
for conversation_id in list(self.queues.keys()):
self._start_listener_if_needed(conversation_id)
+ async def clear_listener(self) -> None:
+ self._listener_callback = None
+ for close_event in list(self._queue_close_events.values()):
+ close_event.set()
+ self._queue_close_events.clear()
+
+ listener_tasks = list(self._listener_tasks.values())
+ for task in listener_tasks:
+ task.cancel()
+ if listener_tasks:
+ await asyncio.gather(*listener_tasks, return_exceptions=True)
+ self._listener_tasks.clear()
+
def _start_listener_if_needed(self, conversation_id: str):
if self._listener_callback is None:
return
diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py
index 0a2e71b61..6647db89f 100644
--- a/astrbot/core/platform/sources/wecom/wecom_adapter.py
+++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py
@@ -25,7 +25,7 @@ from astrbot.api.platform import (
)
from astrbot.core import logger
from astrbot.core.platform.astr_message_event import MessageSesion
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.media_utils import convert_audio_to_wav
from astrbot.core.utils.webhook_utils import log_webhook_info
@@ -344,7 +344,7 @@ class WecomPlatformAdapter(Platform):
self.client.media.download,
msg.media_id,
)
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"wecom_{msg.media_id}.amr")
with open(path, "wb") as f:
f.write(resp.content)
@@ -400,7 +400,8 @@ class WecomPlatformAdapter(Platform):
self.client.media.download,
media_id,
)
- path = f"data/temp/wechat_kf_{media_id}.jpg"
+ temp_dir = get_astrbot_temp_path()
+ path = os.path.join(temp_dir, f"weixinkefu_{media_id}.jpg")
with open(path, "wb") as f:
f.write(resp.content)
abm.message = [Image(file=path, url=path)]
@@ -412,7 +413,7 @@ class WecomPlatformAdapter(Platform):
media_id,
)
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"weixinkefu_{media_id}.amr")
with open(path, "wb") as f:
f.write(resp.content)
diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py
index 8f12ec82b..28985f757 100644
--- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py
+++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py
@@ -1,4 +1,5 @@
import asyncio
+import os
import sys
import uuid
from collections.abc import Awaitable, Callable
@@ -24,6 +25,7 @@ from astrbot.api.platform import (
)
from astrbot.core import logger
from astrbot.core.platform.astr_message_event import MessageSesion
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.media_utils import convert_audio_to_wav
from astrbot.core.utils.webhook_utils import log_webhook_info
@@ -290,12 +292,16 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
self.client.media.download,
msg.media_id,
)
- path = f"data/temp/wecom_{msg.media_id}.amr"
+ temp_dir = get_astrbot_temp_path()
+ path = os.path.join(temp_dir, f"weixin_offacc_{msg.media_id}.amr")
with open(path, "wb") as f:
f.write(resp.content)
try:
- path_wav = f"data/temp/wecom_{msg.media_id}.wav"
+ path_wav = os.path.join(
+ temp_dir,
+ f"weixin_offacc_{msg.media_id}.wav",
+ )
path_wav = await convert_audio_to_wav(path, path_wav)
except Exception as e:
logger.error(
diff --git a/astrbot/core/provider/sources/azure_tts_source.py b/astrbot/core/provider/sources/azure_tts_source.py
index 29dc3f7a4..0e8f00ce5 100644
--- a/astrbot/core/provider/sources/azure_tts_source.py
+++ b/astrbot/core/provider/sources/azure_tts_source.py
@@ -12,12 +12,13 @@ from httpx import AsyncClient, Timeout
from astrbot import logger
from astrbot.core.config.default import VERSION
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType
from ..provider import TTSProvider
from ..register import register_provider_adapter
-TEMP_DIR = Path("data/temp/azure_tts")
+TEMP_DIR = Path(get_astrbot_temp_path()) / "azure_tts"
TEMP_DIR.mkdir(parents=True, exist_ok=True)
diff --git a/astrbot/core/provider/sources/dashscope_tts.py b/astrbot/core/provider/sources/dashscope_tts.py
index 50bc421fd..9b6816859 100644
--- a/astrbot/core/provider/sources/dashscope_tts.py
+++ b/astrbot/core/provider/sources/dashscope_tts.py
@@ -15,7 +15,7 @@ except (
): # pragma: no cover - older dashscope versions without Qwen TTS support
MultiModalConversation = None
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -45,7 +45,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
if not model:
raise RuntimeError("Dashscope TTS model is not configured.")
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
if self._is_qwen_tts_model(model):
diff --git a/astrbot/core/provider/sources/edge_tts_source.py b/astrbot/core/provider/sources/edge_tts_source.py
index 71a5a82d6..503bd275b 100644
--- a/astrbot/core/provider/sources/edge_tts_source.py
+++ b/astrbot/core/provider/sources/edge_tts_source.py
@@ -6,7 +6,7 @@ import uuid
import edge_tts
from astrbot.core import logger
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -46,7 +46,7 @@ class ProviderEdgeTTS(TTSProvider):
self.set_model("edge_tts")
async def get_audio(self, text: str) -> str:
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3")
wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav")
diff --git a/astrbot/core/provider/sources/fishaudio_tts_api_source.py b/astrbot/core/provider/sources/fishaudio_tts_api_source.py
index dde2736a8..35945b7b6 100644
--- a/astrbot/core/provider/sources/fishaudio_tts_api_source.py
+++ b/astrbot/core/provider/sources/fishaudio_tts_api_source.py
@@ -8,7 +8,7 @@ from httpx import AsyncClient
from pydantic import BaseModel, conint
from astrbot import logger
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -142,7 +142,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
)
async def get_audio(self, text: str) -> str:
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav")
self.headers["content-type"] = "application/msgpack"
request = await self._generate_request(text)
diff --git a/astrbot/core/provider/sources/gemini_tts_source.py b/astrbot/core/provider/sources/gemini_tts_source.py
index 37022f761..d6954ef82 100644
--- a/astrbot/core/provider/sources/gemini_tts_source.py
+++ b/astrbot/core/provider/sources/gemini_tts_source.py
@@ -6,7 +6,7 @@ from google import genai
from google.genai import types
from astrbot import logger
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -49,7 +49,7 @@ class ProviderGeminiTTSAPI(TTSProvider):
self.voice_name: str = provider_config.get("gemini_tts_voice_name", "Leda")
async def get_audio(self, text: str) -> str:
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"gemini_tts_{uuid.uuid4()}.wav")
prompt = f"{self.prefix}: {text}" if self.prefix else text
response = await self.client.models.generate_content(
diff --git a/astrbot/core/provider/sources/genie_tts.py b/astrbot/core/provider/sources/genie_tts.py
index bca92deb7..8f9b6d91d 100644
--- a/astrbot/core/provider/sources/genie_tts.py
+++ b/astrbot/core/provider/sources/genie_tts.py
@@ -6,7 +6,7 @@ from astrbot.core import logger
from astrbot.core.provider.entities import ProviderType
from astrbot.core.provider.provider import TTSProvider
from astrbot.core.provider.register import register_provider_adapter
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
try:
import genie_tts as genie # type: ignore
@@ -54,7 +54,7 @@ class GenieTTSProvider(TTSProvider):
return True
async def get_audio(self, text: str) -> str:
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
filename = f"genie_tts_{uuid.uuid4()}.wav"
path = os.path.join(temp_dir, filename)
@@ -94,7 +94,7 @@ class GenieTTSProvider(TTSProvider):
break
try:
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
filename = f"genie_tts_{uuid.uuid4()}.wav"
path = os.path.join(temp_dir, filename)
diff --git a/astrbot/core/provider/sources/gsv_selfhosted_source.py b/astrbot/core/provider/sources/gsv_selfhosted_source.py
index 029f6af10..fc8bccea8 100644
--- a/astrbot/core/provider/sources/gsv_selfhosted_source.py
+++ b/astrbot/core/provider/sources/gsv_selfhosted_source.py
@@ -5,7 +5,7 @@ import uuid
import aiohttp
from astrbot import logger
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -121,7 +121,7 @@ class ProviderGSVTTS(TTSProvider):
params = self.build_synthesis_params(text)
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
path = os.path.join(temp_dir, f"gsv_tts_{uuid.uuid4().hex}.wav")
diff --git a/astrbot/core/provider/sources/gsvi_tts_source.py b/astrbot/core/provider/sources/gsvi_tts_source.py
index d8b171718..425e801f4 100644
--- a/astrbot/core/provider/sources/gsvi_tts_source.py
+++ b/astrbot/core/provider/sources/gsvi_tts_source.py
@@ -4,7 +4,7 @@ import uuid
import aiohttp
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -29,7 +29,7 @@ class ProviderGSVITTS(TTSProvider):
self.emotion = provider_config.get("emotion")
async def get_audio(self, text: str) -> str:
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav")
params = {"text": text}
diff --git a/astrbot/core/provider/sources/minimax_tts_api_source.py b/astrbot/core/provider/sources/minimax_tts_api_source.py
index dcd29060e..69860111c 100644
--- a/astrbot/core/provider/sources/minimax_tts_api_source.py
+++ b/astrbot/core/provider/sources/minimax_tts_api_source.py
@@ -6,7 +6,7 @@ from collections.abc import AsyncIterator
import aiohttp
from astrbot.api import logger
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -145,7 +145,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
return b"".join(chunks)
async def get_audio(self, text: str) -> str:
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
path = os.path.join(temp_dir, f"minimax_tts_api_{uuid.uuid4()}.mp3")
diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py
index 0708c09c7..5378385e5 100644
--- a/astrbot/core/provider/sources/openai_source.py
+++ b/astrbot/core/provider/sources/openai_source.py
@@ -5,6 +5,7 @@ import json
import random
import re
from collections.abc import AsyncGenerator
+from typing import Any
import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI
@@ -27,6 +28,7 @@ from astrbot.core.utils.network_utils import (
is_connection_error,
log_connection_failure,
)
+from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
from ..register import register_provider_adapter
@@ -36,6 +38,128 @@ from ..register import register_provider_adapter
"OpenAI API Chat Completion 提供商适配器",
)
class ProviderOpenAIOfficial(Provider):
+ _ERROR_TEXT_CANDIDATE_MAX_CHARS = 4096
+
+ @classmethod
+ def _truncate_error_text_candidate(cls, text: str) -> str:
+ if len(text) <= cls._ERROR_TEXT_CANDIDATE_MAX_CHARS:
+ return text
+ return text[: cls._ERROR_TEXT_CANDIDATE_MAX_CHARS]
+
+ @staticmethod
+ def _safe_json_dump(value: Any) -> str | None:
+ try:
+ return json.dumps(value, ensure_ascii=False, default=str)
+ except Exception:
+ return None
+
+ def _get_image_moderation_error_patterns(self) -> list[str]:
+ """Return configured moderation patterns (case-insensitive substring match, not regex)."""
+ configured = self.provider_config.get("image_moderation_error_patterns", [])
+ patterns: list[str] = []
+ if isinstance(configured, str):
+ configured = [configured]
+ if isinstance(configured, list):
+ for pattern in configured:
+ if not isinstance(pattern, str):
+ continue
+ pattern = pattern.strip()
+ if pattern:
+ patterns.append(pattern)
+ return patterns
+
+ @staticmethod
+ def _extract_error_text_candidates(error: Exception) -> list[str]:
+ candidates: list[str] = []
+
+ def _append_candidate(candidate: Any):
+ if candidate is None:
+ return
+ text = str(candidate).strip()
+ if not text:
+ return
+ candidates.append(
+ ProviderOpenAIOfficial._truncate_error_text_candidate(text)
+ )
+
+ _append_candidate(str(error))
+
+ body = getattr(error, "body", None)
+ if isinstance(body, dict):
+ err_obj = body.get("error")
+ body_text = ProviderOpenAIOfficial._safe_json_dump(
+ {"error": err_obj} if isinstance(err_obj, dict) else body
+ )
+ _append_candidate(body_text)
+ if isinstance(err_obj, dict):
+ for field in ("message", "type", "code", "param"):
+ value = err_obj.get(field)
+ if value is not None:
+ _append_candidate(value)
+ elif isinstance(body, str):
+ _append_candidate(body)
+
+ response = getattr(error, "response", None)
+ if response is not None:
+ response_text = getattr(response, "text", None)
+ if isinstance(response_text, str):
+ _append_candidate(response_text)
+
+ return normalize_and_dedupe_strings(candidates)
+
+ def _is_content_moderated_upload_error(self, error: Exception) -> bool:
+ patterns = [
+ pattern.lower() for pattern in self._get_image_moderation_error_patterns()
+ ]
+ if not patterns:
+ return False
+ candidates = [
+ candidate.lower()
+ for candidate in self._extract_error_text_candidates(error)
+ ]
+ for pattern in patterns:
+ if any(pattern in candidate for candidate in candidates):
+ return True
+ return False
+
+ @staticmethod
+ def _context_contains_image(contexts: list[dict]) -> bool:
+ for context in contexts:
+ content = context.get("content")
+ if not isinstance(content, list):
+ continue
+ for item in content:
+ if isinstance(item, dict) and item.get("type") == "image_url":
+ return True
+ return False
+
+ async def _fallback_to_text_only_and_retry(
+ self,
+ payloads: dict,
+ context_query: list,
+ chosen_key: str,
+ available_api_keys: list[str],
+ func_tool: ToolSet | None,
+ reason: str,
+ *,
+ image_fallback_used: bool = False,
+ ) -> tuple:
+ logger.warning(
+ "检测到图片请求失败(%s),已移除图片并重试(保留文本内容)。",
+ reason,
+ )
+ new_contexts = await self._remove_image_from_context(context_query)
+ payloads["messages"] = new_contexts
+ return (
+ False,
+ chosen_key,
+ available_api_keys,
+ payloads,
+ new_contexts,
+ func_tool,
+ image_fallback_used,
+ )
+
def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None:
"""创建带代理的 HTTP 客户端"""
proxy = provider_config.get("proxy", "")
@@ -199,7 +323,8 @@ class ProviderOpenAIOfficial(Provider):
llm_response.reasoning_content = reasoning
_y = True
if delta.content:
- completion_text = delta.content
+ # Don't strip streaming chunks to preserve spaces between words
+ completion_text = self._normalize_content(delta.content, strip=False)
llm_response.result_chain = MessageChain(
chain=[Comp.Plain(completion_text)],
)
@@ -247,6 +372,86 @@ class ProviderOpenAIOfficial(Provider):
output=completion_tokens,
)
+ @staticmethod
+ def _normalize_content(raw_content: Any, strip: bool = True) -> str:
+ """Normalize content from various formats to plain string.
+
+ Some LLM providers return content as list[dict] format
+ like [{'type': 'text', 'text': '...'}] instead of
+ plain string. This method handles both formats.
+
+ Args:
+ raw_content: The raw content from LLM response, can be str, list, or other.
+ strip: Whether to strip whitespace from the result. Set to False for
+ streaming chunks to preserve spaces between words.
+
+ Returns:
+ Normalized plain text string.
+ """
+ if isinstance(raw_content, list):
+ # Check if this looks like OpenAI content-part format
+ # Only process if at least one item has {'type': 'text', 'text': ...} structure
+ has_content_part = any(
+ isinstance(part, dict) and part.get("type") == "text"
+ for part in raw_content
+ )
+ if has_content_part:
+ text_parts = []
+ for part in raw_content:
+ if isinstance(part, dict) and part.get("type") == "text":
+ text_val = part.get("text", "")
+ # Coerce to str in case text is null or non-string
+ text_parts.append(str(text_val) if text_val is not None else "")
+ return "".join(text_parts)
+ # Not content-part format, return string representation
+ return str(raw_content)
+
+ if isinstance(raw_content, str):
+ content = raw_content.strip() if strip else raw_content
+ # Check if the string is a JSON-encoded list (e.g., "[{'type': 'text', ...}]")
+ # This can happen when streaming concatenates content that was originally list format
+ # Only check if it looks like a complete JSON array (requires strip for check)
+ check_content = raw_content.strip()
+ if (
+ check_content.startswith("[")
+ and check_content.endswith("]")
+ and len(check_content) < 8192
+ ):
+ try:
+ # First try standard JSON parsing
+ parsed = json.loads(check_content)
+ except json.JSONDecodeError:
+ # If that fails, try parsing as Python literal (handles single quotes)
+ # This is safer than blind replace("'", '"') which corrupts apostrophes
+ try:
+ import ast
+
+ parsed = ast.literal_eval(check_content)
+ except (ValueError, SyntaxError):
+ parsed = None
+
+ if isinstance(parsed, list):
+ # Only convert if it matches OpenAI content-part schema
+ # i.e., at least one item has {'type': 'text', 'text': ...}
+ has_content_part = any(
+ isinstance(part, dict) and part.get("type") == "text"
+ for part in parsed
+ )
+ if has_content_part:
+ text_parts = []
+ for part in parsed:
+ if isinstance(part, dict) and part.get("type") == "text":
+ text_val = part.get("text", "")
+ # Coerce to str in case text is null or non-string
+ text_parts.append(
+ str(text_val) if text_val is not None else ""
+ )
+ if text_parts:
+ return "".join(text_parts)
+ return content
+
+ return str(raw_content)
+
async def _parse_openai_completion(
self, completion: ChatCompletion, tools: ToolSet | None
) -> LLMResponse:
@@ -259,8 +464,7 @@ class ProviderOpenAIOfficial(Provider):
# parse the text completion
if choice.message.content is not None:
- # text completion
- completion_text = str(choice.message.content).strip()
+ completion_text = self._normalize_content(choice.message.content)
# specially, some providers may set tags around reasoning content in the completion text,
# we use regex to remove them, and store then in reasoning_content field
reasoning_pattern = re.compile(r"(.*?)", re.DOTALL)
@@ -270,6 +474,8 @@ class ProviderOpenAIOfficial(Provider):
[match.strip() for match in matches],
)
completion_text = reasoning_pattern.sub("", completion_text).strip()
+ # Also clean up orphan tags that may leak from some models
+ completion_text = re.sub(r"\s*$", "", completion_text).strip()
llm_response.result_chain = MessageChain().message(completion_text)
# parse the reasoning content if any
@@ -403,6 +609,7 @@ class ProviderOpenAIOfficial(Provider):
available_api_keys: list[str],
retry_cnt: int,
max_retries: int,
+ image_fallback_used: bool = False,
) -> tuple:
"""处理API错误并尝试恢复"""
if "429" in str(e):
@@ -422,6 +629,7 @@ class ProviderOpenAIOfficial(Provider):
payloads,
context_query,
func_tool,
+ image_fallback_used,
)
raise e
if "maximum context length" in str(e):
@@ -437,20 +645,34 @@ class ProviderOpenAIOfficial(Provider):
payloads,
context_query,
func_tool,
+ image_fallback_used,
)
if "The model is not a VLM" in str(e): # siliconcloud
+ if image_fallback_used or not self._context_contains_image(context_query):
+ raise e
# 尝试删除所有 image
- new_contexts = await self._remove_image_from_context(context_query)
- payloads["messages"] = new_contexts
- context_query = new_contexts
- return (
- False,
- chosen_key,
- available_api_keys,
+ return await self._fallback_to_text_only_and_retry(
payloads,
context_query,
+ chosen_key,
+ available_api_keys,
func_tool,
+ "model_not_vlm",
+ image_fallback_used=True,
)
+ if self._is_content_moderated_upload_error(e):
+ if image_fallback_used or not self._context_contains_image(context_query):
+ raise e
+ return await self._fallback_to_text_only_and_retry(
+ payloads,
+ context_query,
+ chosen_key,
+ available_api_keys,
+ func_tool,
+ "image_content_moderated",
+ image_fallback_used=True,
+ )
+
if (
"Function calling is not enabled" in str(e)
or ("tool" in str(e).lower() and "support" in str(e).lower())
@@ -461,7 +683,15 @@ class ProviderOpenAIOfficial(Provider):
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。",
)
payloads.pop("tools", None)
- return False, chosen_key, available_api_keys, payloads, context_query, None
+ return (
+ False,
+ chosen_key,
+ available_api_keys,
+ payloads,
+ context_query,
+ None,
+ image_fallback_used,
+ )
# logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
if "tool" in str(e).lower() and "support" in str(e).lower():
@@ -501,6 +731,7 @@ class ProviderOpenAIOfficial(Provider):
max_retries = 10
available_api_keys = self.api_keys.copy()
chosen_key = random.choice(available_api_keys)
+ image_fallback_used = False
last_exception = None
retry_cnt = 0
@@ -518,6 +749,7 @@ class ProviderOpenAIOfficial(Provider):
payloads,
context_query,
func_tool,
+ image_fallback_used,
) = await self._handle_api_error(
e,
payloads,
@@ -527,6 +759,7 @@ class ProviderOpenAIOfficial(Provider):
available_api_keys,
retry_cnt,
max_retries,
+ image_fallback_used=image_fallback_used,
)
if success:
break
@@ -564,6 +797,7 @@ class ProviderOpenAIOfficial(Provider):
max_retries = 10
available_api_keys = self.api_keys.copy()
chosen_key = random.choice(available_api_keys)
+ image_fallback_used = False
last_exception = None
retry_cnt = 0
@@ -582,6 +816,7 @@ class ProviderOpenAIOfficial(Provider):
payloads,
context_query,
func_tool,
+ image_fallback_used,
) = await self._handle_api_error(
e,
payloads,
@@ -591,6 +826,7 @@ class ProviderOpenAIOfficial(Provider):
available_api_keys,
retry_cnt,
max_retries,
+ image_fallback_used=image_fallback_used,
)
if success:
break
diff --git a/astrbot/core/provider/sources/openai_tts_api_source.py b/astrbot/core/provider/sources/openai_tts_api_source.py
index 489a37b2d..217b18925 100644
--- a/astrbot/core/provider/sources/openai_tts_api_source.py
+++ b/astrbot/core/provider/sources/openai_tts_api_source.py
@@ -5,7 +5,7 @@ import httpx
from openai import NOT_GIVEN, AsyncOpenAI
from astrbot import logger
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -46,7 +46,7 @@ class ProviderOpenAITTSAPI(TTSProvider):
self.set_model(provider_config.get("model", ""))
async def get_audio(self, text: str) -> str:
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}.wav")
async with self.client.audio.speech.with_streaming_response.create(
model=self.model_name,
diff --git a/astrbot/core/provider/sources/volcengine_tts.py b/astrbot/core/provider/sources/volcengine_tts.py
index f5d758f5c..349815907 100644
--- a/astrbot/core/provider/sources/volcengine_tts.py
+++ b/astrbot/core/provider/sources/volcengine_tts.py
@@ -8,6 +8,7 @@ import uuid
import aiohttp
from astrbot import logger
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -92,9 +93,12 @@ class ProviderVolcengineTTS(TTSProvider):
if "data" in resp_data:
audio_data = base64.b64decode(resp_data["data"])
- os.makedirs("data/temp", exist_ok=True)
-
- file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
+ temp_dir = get_astrbot_temp_path()
+ os.makedirs(temp_dir, exist_ok=True)
+ file_path = os.path.join(
+ temp_dir,
+ f"volcengine_tts_{uuid.uuid4()}.mp3",
+ )
loop = asyncio.get_running_loop()
await loop.run_in_executor(
diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py
index 875881b7b..386da063d 100644
--- a/astrbot/core/provider/sources/whisper_api_source.py
+++ b/astrbot/core/provider/sources/whisper_api_source.py
@@ -4,7 +4,7 @@ import uuid
from openai import NOT_GIVEN, AsyncOpenAI
from astrbot.core import logger
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_file
from astrbot.core.utils.tencent_record_helper import (
convert_to_pcm_wav,
@@ -65,9 +65,11 @@ class ProviderOpenAIWhisperAPI(STTProvider):
if "multimedia.nt.qq.com.cn" in audio_url:
is_tencent = True
- name = str(uuid.uuid4())
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
- path = os.path.join(temp_dir, name)
+ temp_dir = get_astrbot_temp_path()
+ path = os.path.join(
+ temp_dir,
+ f"whisper_api_{uuid.uuid4().hex[:8]}.input",
+ )
await download_file(audio_url, path)
audio_url = path
@@ -79,8 +81,11 @@ class ProviderOpenAIWhisperAPI(STTProvider):
# 判断是否需要转换
if file_format in ["silk", "amr"]:
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
- output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
+ temp_dir = get_astrbot_temp_path()
+ output_path = os.path.join(
+ temp_dir,
+ f"whisper_api_{uuid.uuid4().hex[:8]}.wav",
+ )
if file_format == "silk":
logger.info(
diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py
index d5d2dc340..678deb948 100644
--- a/astrbot/core/provider/sources/whisper_selfhosted_source.py
+++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py
@@ -6,7 +6,7 @@ from typing import cast
import whisper
from astrbot.core import logger
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_file
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
@@ -58,9 +58,11 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
if "multimedia.nt.qq.com.cn" in audio_url:
is_tencent = True
- name = str(uuid.uuid4())
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
- path = os.path.join(temp_dir, name)
+ temp_dir = get_astrbot_temp_path()
+ path = os.path.join(
+ temp_dir,
+ f"whisper_selfhost_{uuid.uuid4().hex[:8]}.input",
+ )
await download_file(audio_url, path)
audio_url = path
@@ -71,8 +73,11 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
is_silk = await self._is_silk_file(audio_url)
if is_silk:
logger.info("Converting silk file to wav ...")
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
- output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
+ temp_dir = get_astrbot_temp_path()
+ output_path = os.path.join(
+ temp_dir,
+ f"whisper_selfhost_{uuid.uuid4().hex[:8]}.wav",
+ )
await tencent_silk_to_wav(audio_url, output_path)
audio_url = output_path
diff --git a/astrbot/core/provider/sources/xinference_stt_provider.py b/astrbot/core/provider/sources/xinference_stt_provider.py
index a3e5be352..0a22e456e 100644
--- a/astrbot/core/provider/sources/xinference_stt_provider.py
+++ b/astrbot/core/provider/sources/xinference_stt_provider.py
@@ -7,7 +7,7 @@ from xinference_client.client.restful.async_restful_client import (
)
from astrbot.core import logger
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.tencent_record_helper import (
convert_to_pcm_wav,
tencent_silk_to_wav,
@@ -130,11 +130,17 @@ class ProviderXinferenceSTT(STTProvider):
logger.info(
f"Audio requires conversion ({conversion_type}), using temporary files..."
)
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
- input_path = os.path.join(temp_dir, str(uuid.uuid4()))
- output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
+ input_path = os.path.join(
+ temp_dir,
+ f"xinference_stt_{uuid.uuid4().hex[:8]}.input",
+ )
+ output_path = os.path.join(
+ temp_dir,
+ f"xinference_stt_{uuid.uuid4().hex[:8]}.wav",
+ )
temp_files.extend([input_path, output_path])
with open(input_path, "wb") as f:
diff --git a/astrbot/core/skills/skill_manager.py b/astrbot/core/skills/skill_manager.py
index fab38d1d1..bd2e09bb4 100644
--- a/astrbot/core/skills/skill_manager.py
+++ b/astrbot/core/skills/skill_manager.py
@@ -98,7 +98,6 @@ class SkillManager:
self.config_path = str(data_path / SKILLS_CONFIG_FILENAME)
self.sandbox_skills_cache_path = str(data_path / SANDBOX_SKILLS_CACHE_FILENAME)
os.makedirs(self.skills_root, exist_ok=True)
- os.makedirs(get_astrbot_temp_path(), exist_ok=True)
def _load_config(self) -> dict:
if not os.path.exists(self.config_path):
diff --git a/astrbot/core/star/filter/platform_adapter_type.py b/astrbot/core/star/filter/platform_adapter_type.py
index ff1affa24..1630650a9 100644
--- a/astrbot/core/star/filter/platform_adapter_type.py
+++ b/astrbot/core/star/filter/platform_adapter_type.py
@@ -20,6 +20,7 @@ class PlatformAdapterType(enum.Flag):
WEIXIN_OFFICIAL_ACCOUNT = enum.auto()
SATORI = enum.auto()
MISSKEY = enum.auto()
+ LINE = enum.auto()
ALL = (
AIOCQHTTP
| QQOFFICIAL
@@ -34,6 +35,7 @@ class PlatformAdapterType(enum.Flag):
| WEIXIN_OFFICIAL_ACCOUNT
| SATORI
| MISSKEY
+ | LINE
)
@@ -51,6 +53,7 @@ ADAPTER_NAME_2_TYPE = {
"weixin_official_account": PlatformAdapterType.WEIXIN_OFFICIAL_ACCOUNT,
"satori": PlatformAdapterType.SATORI,
"misskey": PlatformAdapterType.MISSKEY,
+ "line": PlatformAdapterType.LINE,
}
diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py
index 587808956..51f50aedf 100644
--- a/astrbot/core/star/star_manager.py
+++ b/astrbot/core/star/star_manager.py
@@ -62,6 +62,9 @@ class PluginManager:
self._pm_lock = asyncio.Lock()
"""StarManager操作互斥锁"""
+ self.failed_plugin_dict = {}
+ """加载失败插件的信息,用于后续可能的热重载"""
+
self.failed_plugin_info = ""
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
asyncio.create_task(self._watch_plugins_changes())
@@ -193,6 +196,37 @@ class PluginManager:
logger.error(f"更新插件 {p} 的依赖失败。Code: {e!s}")
return True
+ async def _import_plugin_with_dependency_recovery(
+ self,
+ path: str,
+ module_str: str,
+ root_dir_name: str,
+ requirements_path: str,
+ ) -> ModuleType:
+ try:
+ return __import__(path, fromlist=[module_str])
+ except (ModuleNotFoundError, ImportError) as import_exc:
+ if os.path.exists(requirements_path):
+ try:
+ logger.info(
+ f"插件 {root_dir_name} 导入失败,尝试从已安装依赖恢复: {import_exc!s}"
+ )
+ pip_installer.prefer_installed_dependencies(
+ requirements_path=requirements_path
+ )
+ module = __import__(path, fromlist=[module_str])
+ logger.info(
+ f"插件 {root_dir_name} 已从 site-packages 恢复依赖,跳过重新安装。"
+ )
+ return module
+ except Exception as recover_exc:
+ logger.info(
+ f"插件 {root_dir_name} 已安装依赖恢复失败,将重新安装依赖: {recover_exc!s}"
+ )
+
+ await self._check_plugin_dept_update(target_plugin=root_dir_name)
+ return __import__(path, fromlist=[module_str])
+
@staticmethod
def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | None:
"""先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。
@@ -296,6 +330,28 @@ class PluginManager:
except KeyError:
logger.warning(f"模块 {module_name} 未载入")
+ async def reload_failed_plugin(self, dir_name):
+ """
+ 重新加载未注册(加载失败)的插件
+ Args:
+ dir_name (str): 要重载的特定插件名称。
+ Returns:
+ tuple: 返回 load() 方法的结果,包含 (success, error_message)
+ - success (bool): 重载是否成功
+ - error_message (str|None): 错误信息,成功时为 None
+ """
+ async with self._pm_lock:
+ if dir_name in self.failed_plugin_dict:
+ success, error = await self.load(specified_dir_name=dir_name)
+ if success:
+ self.failed_plugin_dict.pop(dir_name, None)
+ if not self.failed_plugin_dict:
+ self.failed_plugin_info = ""
+ return success, None
+ else:
+ return False, error
+ return False, "插件不存在于失败列表中"
+
async def reload(self, specified_plugin_name=None):
"""重新加载插件
@@ -386,6 +442,12 @@ class PluginManager:
"reserved",
False,
) # 是否是保留插件。目前在 astrbot/builtin_stars 目录下的都是保留插件。保留插件不可以卸载。
+ plugin_dir_path = (
+ os.path.join(self.plugin_store_path, root_dir_name)
+ if not reserved
+ else os.path.join(self.reserved_plugin_path, root_dir_name)
+ )
+ requirements_path = os.path.join(plugin_dir_path, "requirements.txt")
path = "data.plugins." if not reserved else "astrbot.builtin_stars."
path += root_dir_name + "." + module_str
@@ -400,11 +462,12 @@ class PluginManager:
# 尝试导入模块
try:
- module = __import__(path, fromlist=[module_str])
- except (ModuleNotFoundError, ImportError):
- # 尝试安装依赖
- await self._check_plugin_dept_update(target_plugin=root_dir_name)
- module = __import__(path, fromlist=[module_str])
+ module = await self._import_plugin_with_dependency_recovery(
+ path=path,
+ module_str=module_str,
+ root_dir_name=root_dir_name,
+ requirements_path=requirements_path,
+ )
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"插件 {root_dir_name} 导入失败。原因:{e!s}")
@@ -412,11 +475,6 @@ class PluginManager:
# 检查 _conf_schema.json
plugin_config = None
- plugin_dir_path = (
- os.path.join(self.plugin_store_path, root_dir_name)
- if not reserved
- else os.path.join(self.reserved_plugin_path, root_dir_name)
- )
plugin_schema_path = os.path.join(
plugin_dir_path,
self.conf_schema_fname,
@@ -455,6 +513,16 @@ class PluginManager:
)
logger.info(metadata)
metadata.config = plugin_config
+ p_name = (metadata.name or "unknown").lower().replace("/", "_")
+ p_author = (metadata.author or "unknown").lower().replace("/", "_")
+ plugin_id = f"{p_author}/{p_name}"
+
+ # 在实例化前注入类属性,保证插件 __init__ 可读取这些值
+ if metadata.star_cls_type:
+ setattr(metadata.star_cls_type, "name", p_name)
+ setattr(metadata.star_cls_type, "author", p_author)
+ setattr(metadata.star_cls_type, "plugin_id", plugin_id)
+
if path not in inactivated_plugins:
# 只有没有禁用插件时才实例化插件类
if plugin_config and metadata.star_cls_type:
@@ -472,17 +540,10 @@ class PluginManager:
context=self.context,
)
- p_name = (metadata.name or "unknown").lower().replace("/", "_")
- p_author = (
- (metadata.author or "unknown").lower().replace("/", "_")
- )
- setattr(metadata.star_cls, "name", p_name)
- setattr(metadata.star_cls, "author", p_author)
- setattr(
- metadata.star_cls,
- "plugin_id",
- f"{p_author}/{p_name}",
- )
+ if metadata.star_cls:
+ setattr(metadata.star_cls, "name", p_name)
+ setattr(metadata.star_cls, "author", p_author)
+ setattr(metadata.star_cls, "plugin_id", plugin_id)
else:
logger.info(f"插件 {metadata.name} 已被禁用。")
@@ -630,6 +691,11 @@ class PluginManager:
logger.error(f"| {line}")
logger.error("----------------------------------")
fail_rec += f"加载 {root_dir_name} 插件时出现问题,原因 {e!s}。\n"
+ self.failed_plugin_dict[root_dir_name] = {
+ "error": str(e),
+ "traceback": errors,
+ }
+ # 记录注册失败的插件名称,以便后续重载插件
# 清除 pip.main 导致的多余的 logging handlers
for handler in logging.root.handlers[:]:
diff --git a/astrbot/core/updator.py b/astrbot/core/updator.py
index e7c2aa54b..049a19789 100644
--- a/astrbot/core/updator.py
+++ b/astrbot/core/updator.py
@@ -148,8 +148,8 @@ class AstrBotUpdator(RepoZipUpdator):
update_data = await self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest)
file_url = None
- if os.environ.get("ASTRBOT_CLI"):
- raise Exception("不支持更新CLI启动的AstrBot") # 避免版本管理混乱
+ if os.environ.get("ASTRBOT_CLI") or os.environ.get("ASTRBOT_LAUNCHER"):
+ raise Exception("不支持更新此方式启动的AstrBot") # 避免版本管理混乱
if latest:
latest_version = update_data[0]["tag_name"]
diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py
index 24b919bdd..0ce3624e8 100644
--- a/astrbot/core/utils/io.py
+++ b/astrbot/core/utils/io.py
@@ -14,7 +14,7 @@ import certifi
import psutil
from PIL import Image
-from .astrbot_path import get_astrbot_data_path
+from .astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
logger = logging.getLogger("astrbot")
@@ -50,21 +50,10 @@ def port_checker(port: int, host: str = "localhost") -> bool:
def save_temp_img(img: Image.Image | bytes) -> str:
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
- # 获得文件创建时间,清除超过 12 小时的
- try:
- for f in os.listdir(temp_dir):
- path = os.path.join(temp_dir, f)
- if os.path.isfile(path):
- ctime = os.path.getctime(path)
- if time.time() - ctime > 3600 * 12:
- os.remove(path)
- except Exception as e:
- print(f"清除临时文件失败: {e}")
-
+ temp_dir = get_astrbot_temp_path()
# 获得时间戳
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
- p = os.path.join(temp_dir, f"{timestamp}.jpg")
+ p = os.path.join(temp_dir, f"io_temp_img_{timestamp}.jpg")
if isinstance(img, Image.Image):
img.save(p)
diff --git a/astrbot/core/utils/media_utils.py b/astrbot/core/utils/media_utils.py
index e2abb0744..8d833514f 100644
--- a/astrbot/core/utils/media_utils.py
+++ b/astrbot/core/utils/media_utils.py
@@ -10,7 +10,7 @@ import uuid
from pathlib import Path
from astrbot import logger
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
async def get_media_duration(file_path: str) -> int | None:
@@ -77,9 +77,9 @@ async def convert_audio_to_opus(audio_path: str, output_path: str | None = None)
# 生成输出文件路径
if output_path is None:
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
- output_path = os.path.join(temp_dir, f"{uuid.uuid4()}.opus")
+ output_path = os.path.join(temp_dir, f"media_audio_{uuid.uuid4().hex}.opus")
try:
# 使用ffmpeg转换为opus格式
@@ -156,9 +156,12 @@ async def convert_video_format(
# 生成输出文件路径
if output_path is None:
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
- output_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{output_format}")
+ output_path = os.path.join(
+ temp_dir,
+ f"media_video_{uuid.uuid4().hex}.{output_format}",
+ )
try:
# 使用ffmpeg转换视频格式
@@ -227,9 +230,9 @@ async def convert_audio_format(
return audio_path
if output_path is None:
- temp_dir = Path(get_astrbot_data_path()) / "temp"
+ temp_dir = Path(get_astrbot_temp_path())
temp_dir.mkdir(parents=True, exist_ok=True)
- output_path = str(temp_dir / f"{uuid.uuid4()}.{output_format}")
+ output_path = str(temp_dir / f"media_audio_{uuid.uuid4().hex}.{output_format}")
args = ["ffmpeg", "-y", "-i", audio_path]
if output_format == "amr":
@@ -283,9 +286,9 @@ async def extract_video_cover(
) -> str:
"""从视频中提取封面图(JPG)。"""
if output_path is None:
- temp_dir = Path(get_astrbot_data_path()) / "temp"
+ temp_dir = Path(get_astrbot_temp_path())
temp_dir.mkdir(parents=True, exist_ok=True)
- output_path = str(temp_dir / f"{uuid.uuid4()}.jpg")
+ output_path = str(temp_dir / f"media_cover_{uuid.uuid4().hex}.jpg")
try:
process = await asyncio.create_subprocess_exec(
diff --git a/astrbot/core/utils/network_utils.py b/astrbot/core/utils/network_utils.py
index feb234a30..727f3762a 100644
--- a/astrbot/core/utils/network_utils.py
+++ b/astrbot/core/utils/network_utils.py
@@ -77,9 +77,7 @@ def log_connection_failure(
f"代理地址: {effective_proxy},错误: {error}"
)
else:
- logger.error(
- f"[{provider_label}] 网络连接失败 ({error_type}),未配置代理。错误: {error}"
- )
+ logger.error(f"[{provider_label}] 网络连接失败 ({error_type})。错误: {error}")
def create_proxy_client(
diff --git a/astrbot/core/utils/pip_installer.py b/astrbot/core/utils/pip_installer.py
index 8d43c11b2..1c8da23c1 100644
--- a/astrbot/core/utils/pip_installer.py
+++ b/astrbot/core/utils/pip_installer.py
@@ -580,6 +580,26 @@ class PipInstaller:
)
importlib.invalidate_caches()
+ def prefer_installed_dependencies(self, requirements_path: str) -> None:
+ """优先使用已安装在插件 site-packages 中的依赖,不执行安装。"""
+ if not is_packaged_electron_runtime():
+ return
+
+ target_site_packages = get_astrbot_site_packages_path()
+ if not os.path.isdir(target_site_packages):
+ return
+
+ requested_requirements = _extract_requirement_names(requirements_path)
+ if not requested_requirements:
+ return
+
+ _prepend_sys_path(target_site_packages)
+ _ensure_plugin_dependencies_preferred(
+ target_site_packages,
+ requested_requirements,
+ )
+ importlib.invalidate_caches()
+
async def _run_pip_in_process(self, args: list[str]) -> int:
pip_main = _get_pip_main()
_patch_distlib_finder_for_frozen_runtime()
diff --git a/astrbot/core/utils/quoted_message/__init__.py b/astrbot/core/utils/quoted_message/__init__.py
new file mode 100644
index 000000000..8421898fd
--- /dev/null
+++ b/astrbot/core/utils/quoted_message/__init__.py
@@ -0,0 +1,8 @@
+from __future__ import annotations
+
+from .extractor import extract_quoted_message_images, extract_quoted_message_text
+
+__all__ = [
+ "extract_quoted_message_text",
+ "extract_quoted_message_images",
+]
diff --git a/astrbot/core/utils/quoted_message/chain_parser.py b/astrbot/core/utils/quoted_message/chain_parser.py
new file mode 100644
index 000000000..5fe89302e
--- /dev/null
+++ b/astrbot/core/utils/quoted_message/chain_parser.py
@@ -0,0 +1,505 @@
+from __future__ import annotations
+
+import json
+import re
+from typing import Any, TypedDict
+
+from astrbot.core.message.components import (
+ At,
+ AtAll,
+ File,
+ Forward,
+ Image,
+ Node,
+ Nodes,
+ Plain,
+ Reply,
+ Video,
+)
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
+from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
+
+from .image_refs import looks_like_image_file_name, normalize_file_like_url
+from .settings import SETTINGS, QuotedMessageParserSettings
+
+_FORWARD_PLACEHOLDER_PATTERN = re.compile(
+ r"^(?:[\(\[]?[^\]:\)]*[\)\]]?\s*:\s*)?\[(?:forward message|转发消息|合并转发)\]$",
+ flags=re.IGNORECASE,
+)
+
+
+class ParsedOneBotPayload(TypedDict):
+ text: str | None
+ forward_ids: list[str]
+ image_refs: list[str]
+
+
+def _build_parsed_payload(
+ text: str | None,
+ forward_ids: list[str] | None = None,
+ image_refs: list[str] | None = None,
+) -> ParsedOneBotPayload:
+ return {
+ "text": text,
+ "forward_ids": forward_ids or [],
+ "image_refs": image_refs or [],
+ }
+
+
+def _join_text_parts(parts: list[str]) -> str | None:
+ text = "".join(parts).strip()
+ return text or None
+
+
+def _find_first_reply_component(event: AstrMessageEvent) -> Reply | None:
+ for comp in event.message_obj.message:
+ if isinstance(comp, Reply):
+ return comp
+ return None
+
+
+def _is_forward_placeholder_only_text(text: str | None) -> bool:
+ if not isinstance(text, str):
+ return False
+ lines = [line.strip() for line in text.splitlines() if line.strip()]
+ if not lines:
+ return False
+ return all(_FORWARD_PLACEHOLDER_PATTERN.match(line) for line in lines)
+
+
+def _extract_image_refs_from_component_chain(
+ chain: list[Any] | None,
+ *,
+ depth: int = 0,
+ settings: QuotedMessageParserSettings = SETTINGS,
+) -> list[str]:
+ if not isinstance(chain, list) or depth > settings.max_component_chain_depth:
+ return []
+
+ image_refs: list[str] = []
+ for seg in chain:
+ if isinstance(seg, Image):
+ for candidate in (seg.url, seg.file, seg.path):
+ if isinstance(candidate, str) and candidate.strip():
+ image_refs.append(candidate.strip())
+ break
+ elif isinstance(seg, Reply):
+ image_refs.extend(
+ _extract_image_refs_from_reply_component(
+ seg,
+ depth=depth + 1,
+ settings=settings,
+ )
+ )
+ elif isinstance(seg, Node):
+ image_refs.extend(
+ _extract_image_refs_from_component_chain(
+ seg.content,
+ depth=depth + 1,
+ settings=settings,
+ )
+ )
+ elif isinstance(seg, Nodes):
+ for node in seg.nodes:
+ image_refs.extend(
+ _extract_image_refs_from_component_chain(
+ node.content,
+ depth=depth + 1,
+ settings=settings,
+ )
+ )
+
+ return normalize_and_dedupe_strings(image_refs)
+
+
+def _extract_text_from_component_chain(
+ chain: list[Any] | None,
+ *,
+ depth: int = 0,
+ settings: QuotedMessageParserSettings = SETTINGS,
+) -> str | None:
+ if not isinstance(chain, list) or depth > settings.max_component_chain_depth:
+ return None
+
+ parts: list[str] = []
+ for seg in chain:
+ if isinstance(seg, Plain):
+ if seg.text:
+ parts.append(seg.text)
+ elif isinstance(seg, At):
+ if seg.name:
+ parts.append(f"@{seg.name}")
+ elif seg.qq:
+ parts.append(f"@{seg.qq}")
+ elif isinstance(seg, AtAll):
+ parts.append("@all")
+ elif isinstance(seg, Image):
+ parts.append("[Image]")
+ elif isinstance(seg, Video):
+ parts.append("[Video]")
+ elif isinstance(seg, File):
+ file_name = seg.name or "file"
+ parts.append(f"[File:{file_name}]")
+ elif isinstance(seg, Forward):
+ parts.append("[Forward Message]")
+ elif isinstance(seg, Reply):
+ nested = _extract_text_from_reply_component(
+ seg,
+ depth=depth + 1,
+ settings=settings,
+ )
+ if nested:
+ parts.append(nested)
+ elif isinstance(seg, Node):
+ node_sender = seg.name or seg.uin or "Unknown User"
+ node_text = _extract_text_from_component_chain(
+ seg.content,
+ depth=depth + 1,
+ settings=settings,
+ )
+ if node_text:
+ parts.append(f"{node_sender}: {node_text}")
+ elif isinstance(seg, Nodes):
+ for node in seg.nodes:
+ node_sender = node.name or node.uin or "Unknown User"
+ node_text = _extract_text_from_component_chain(
+ node.content,
+ depth=depth + 1,
+ settings=settings,
+ )
+ if node_text:
+ parts.append(f"{node_sender}: {node_text}")
+
+ return _join_text_parts(parts)
+
+
+def _extract_image_refs_from_reply_component(
+ reply: Reply,
+ *,
+ depth: int = 0,
+ settings: QuotedMessageParserSettings = SETTINGS,
+) -> list[str]:
+ for attr in ("chain", "message", "origin", "content"):
+ payload = getattr(reply, attr, None)
+ image_refs = _extract_image_refs_from_component_chain(
+ payload,
+ depth=depth,
+ settings=settings,
+ )
+ if image_refs:
+ return image_refs
+ return []
+
+
+def _extract_text_from_reply_component(
+ reply: Reply,
+ *,
+ depth: int = 0,
+ settings: QuotedMessageParserSettings = SETTINGS,
+) -> str | None:
+ for attr in ("chain", "message", "origin", "content"):
+ payload = getattr(reply, attr, None)
+ text = _extract_text_from_component_chain(
+ payload,
+ depth=depth,
+ settings=settings,
+ )
+ if text:
+ return text
+
+ if reply.message_str and reply.message_str.strip():
+ return reply.message_str.strip()
+ return None
+
+
+def _unwrap_onebot_data(payload: Any) -> dict[str, Any]:
+ if not isinstance(payload, dict):
+ return {}
+ data = payload.get("data")
+ if isinstance(data, dict):
+ return data
+ return payload
+
+
+def _extract_text_from_multimsg_json(raw_json: str) -> str | None:
+ try:
+ parsed = json.loads(raw_json)
+ except Exception:
+ return None
+
+ if not isinstance(parsed, dict):
+ return None
+ if parsed.get("app") != "com.tencent.multimsg":
+ return None
+ config = parsed.get("config")
+ if not isinstance(config, dict):
+ return None
+ if config.get("forward") != 1:
+ return None
+
+ meta = parsed.get("meta")
+ if not isinstance(meta, dict):
+ return None
+ detail = meta.get("detail")
+ if not isinstance(detail, dict):
+ return None
+ news_items = detail.get("news")
+ if not isinstance(news_items, list):
+ return None
+
+ texts: list[str] = []
+ for item in news_items:
+ if not isinstance(item, dict):
+ continue
+ text_content = item.get("text")
+ if not isinstance(text_content, str):
+ continue
+ cleaned = text_content.strip().replace("[图片]", "").strip()
+ if cleaned:
+ texts.append(cleaned)
+
+ return "\n".join(texts).strip() or None
+
+
+def _parse_onebot_segments(
+ segments: list[Any],
+ *,
+ settings: QuotedMessageParserSettings = SETTINGS,
+) -> ParsedOneBotPayload:
+ text_parts: list[str] = []
+ forward_ids: list[str] = []
+ image_refs: list[str] = []
+
+ for seg in segments:
+ if not isinstance(seg, dict):
+ continue
+
+ seg_type = seg.get("type")
+ seg_data = seg.get("data", {}) if isinstance(seg.get("data"), dict) else {}
+
+ if seg_type in ("text", "plain"):
+ text = seg_data.get("text")
+ if isinstance(text, str) and text:
+ text_parts.append(text)
+ elif seg_type == "image":
+ text_parts.append("[Image]")
+ candidate = seg_data.get("url") or seg_data.get("file")
+ if isinstance(candidate, str) and candidate.strip():
+ image_refs.append(candidate.strip())
+ elif seg_type == "video":
+ text_parts.append("[Video]")
+ elif seg_type == "file":
+ file_name = (
+ seg_data.get("name")
+ or seg_data.get("file_name")
+ or seg_data.get("file")
+ or "file"
+ )
+ text_parts.append(f"[File:{file_name}]")
+ candidate_url = seg_data.get("url")
+ if (
+ isinstance(candidate_url, str)
+ and candidate_url.strip()
+ and looks_like_image_file_name(normalize_file_like_url(candidate_url))
+ ):
+ image_refs.append(candidate_url.strip())
+ candidate_file = seg_data.get("file")
+ if (
+ isinstance(candidate_file, str)
+ and candidate_file.strip()
+ and looks_like_image_file_name(
+ normalize_file_like_url(
+ seg_data.get("name")
+ or seg_data.get("file_name")
+ or candidate_file
+ )
+ )
+ ):
+ image_refs.append(candidate_file.strip())
+ elif seg_type in ("forward", "forward_msg", "nodes"):
+ fid = seg_data.get("id") or seg_data.get("message_id")
+ if isinstance(fid, (str, int)) and str(fid):
+ forward_ids.append(str(fid))
+ else:
+ nested_nodes = seg_data.get("content")
+ nested_text, nested_forward_ids, nested_images = (
+ _extract_text_forward_ids_and_images_from_forward_nodes(
+ nested_nodes if isinstance(nested_nodes, list) else [],
+ depth=1,
+ settings=settings,
+ )
+ )
+ if nested_text:
+ text_parts.append(nested_text)
+ if nested_forward_ids:
+ forward_ids.extend(nested_forward_ids)
+ if nested_images:
+ image_refs.extend(nested_images)
+ elif seg_type == "json":
+ raw_json = seg_data.get("data")
+ if isinstance(raw_json, str) and raw_json.strip():
+ raw_json = raw_json.replace(",", ",")
+ multimsg_text = _extract_text_from_multimsg_json(raw_json)
+ if multimsg_text:
+ text_parts.append(multimsg_text)
+
+ return _build_parsed_payload(
+ _join_text_parts(text_parts),
+ forward_ids,
+ normalize_and_dedupe_strings(image_refs),
+ )
+
+
+def _extract_text_forward_ids_and_images_from_forward_nodes(
+ nodes: list[Any],
+ *,
+ depth: int = 0,
+ settings: QuotedMessageParserSettings = SETTINGS,
+) -> tuple[str | None, list[str], list[str]]:
+ if not isinstance(nodes, list) or depth > settings.max_forward_node_depth:
+ return None, [], []
+
+ texts: list[str] = []
+ forward_ids: list[str] = []
+ image_refs: list[str] = []
+ indent = " " * depth
+
+ for node in nodes:
+ if not isinstance(node, dict):
+ continue
+
+ sender = node.get("sender") if isinstance(node.get("sender"), dict) else {}
+ sender_name = (
+ sender.get("nickname")
+ or sender.get("card")
+ or sender.get("user_id")
+ or "Unknown User"
+ )
+
+ raw_content = node.get("message") or node.get("content") or []
+ chain: list[Any] = []
+ if isinstance(raw_content, list):
+ chain = raw_content
+ elif isinstance(raw_content, str):
+ raw_content = raw_content.strip()
+ if raw_content:
+ try:
+ parsed = json.loads(raw_content)
+ except Exception:
+ parsed = None
+ if isinstance(parsed, list):
+ chain = parsed
+ else:
+ chain = [{"type": "text", "data": {"text": raw_content}}]
+
+ parsed_segments = _parse_onebot_segments(chain, settings=settings)
+ node_text = parsed_segments["text"]
+ node_forward_ids = parsed_segments["forward_ids"]
+ node_images = parsed_segments["image_refs"]
+ if node_text:
+ texts.append(f"{indent}{sender_name}: {node_text}")
+ if node_forward_ids:
+ forward_ids.extend(node_forward_ids)
+ if node_images:
+ image_refs.extend(node_images)
+
+ return (
+ "\n".join(texts).strip() or None,
+ normalize_and_dedupe_strings(forward_ids),
+ normalize_and_dedupe_strings(image_refs),
+ )
+
+
+def _parse_onebot_get_msg_payload(
+ payload: dict[str, Any],
+ *,
+ settings: QuotedMessageParserSettings = SETTINGS,
+) -> ParsedOneBotPayload:
+ data = _unwrap_onebot_data(payload)
+ segments = data.get("message") or data.get("messages")
+ if isinstance(segments, list):
+ return _parse_onebot_segments(segments, settings=settings)
+
+ text: str | None = None
+ if isinstance(segments, str) and segments.strip():
+ text = segments.strip()
+ else:
+ raw = data.get("raw_message")
+ if isinstance(raw, str) and raw.strip():
+ text = raw.strip()
+ return _build_parsed_payload(text)
+
+
+def _parse_onebot_get_forward_payload(
+ payload: dict[str, Any],
+ *,
+ settings: QuotedMessageParserSettings = SETTINGS,
+) -> ParsedOneBotPayload:
+ data = _unwrap_onebot_data(payload)
+ nodes = (
+ data.get("messages")
+ or data.get("message")
+ or data.get("nodes")
+ or data.get("nodeList")
+ )
+ if not isinstance(nodes, list):
+ return _build_parsed_payload(None)
+
+ text, forward_ids, image_refs = (
+ _extract_text_forward_ids_and_images_from_forward_nodes(
+ nodes,
+ settings=settings,
+ )
+ )
+ return _build_parsed_payload(text, forward_ids, image_refs)
+
+
+class ReplyChainParser:
+ def __init__(self, settings: QuotedMessageParserSettings = SETTINGS):
+ self._settings = settings
+
+ @staticmethod
+ def find_first_reply_component(event: AstrMessageEvent) -> Reply | None:
+ return _find_first_reply_component(event)
+
+ @staticmethod
+ def is_forward_placeholder_only_text(text: str | None) -> bool:
+ return _is_forward_placeholder_only_text(text)
+
+ def extract_text_from_reply_component(
+ self,
+ reply: Reply,
+ *,
+ depth: int = 0,
+ ) -> str | None:
+ return _extract_text_from_reply_component(
+ reply,
+ depth=depth,
+ settings=self._settings,
+ )
+
+ def extract_image_refs_from_reply_component(
+ self,
+ reply: Reply,
+ *,
+ depth: int = 0,
+ ) -> list[str]:
+ return _extract_image_refs_from_reply_component(
+ reply,
+ depth=depth,
+ settings=self._settings,
+ )
+
+
+class OneBotPayloadParser:
+ def __init__(self, settings: QuotedMessageParserSettings = SETTINGS):
+ self._settings = settings
+
+ def parse_get_msg_payload(self, payload: dict[str, Any]) -> ParsedOneBotPayload:
+ return _parse_onebot_get_msg_payload(payload, settings=self._settings)
+
+ def parse_get_forward_payload(
+ self,
+ payload: dict[str, Any],
+ ) -> ParsedOneBotPayload:
+ return _parse_onebot_get_forward_payload(payload, settings=self._settings)
diff --git a/astrbot/core/utils/quoted_message/extractor.py b/astrbot/core/utils/quoted_message/extractor.py
new file mode 100644
index 000000000..83570d66c
--- /dev/null
+++ b/astrbot/core/utils/quoted_message/extractor.py
@@ -0,0 +1,211 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from astrbot import logger
+from astrbot.core.message.components import Reply
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
+from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
+
+from .chain_parser import OneBotPayloadParser, ReplyChainParser
+from .image_resolver import ImageResolver
+from .onebot_client import OneBotClient
+from .settings import SETTINGS, QuotedMessageParserSettings
+
+
+async def _collect_text_and_images_from_forward_ids(
+ onebot_client: OneBotClient,
+ payload_parser: OneBotPayloadParser,
+ forward_ids: list[str],
+ *,
+ max_fetch: int,
+) -> tuple[list[str], list[str]]:
+ texts: list[str] = []
+ image_refs: list[str] = []
+ pending: list[str] = []
+ seen: set[str] = set()
+
+ for fid in forward_ids:
+ if not isinstance(fid, str):
+ continue
+ cleaned = fid.strip()
+ if cleaned:
+ pending.append(cleaned)
+
+ fetch_count = 0
+ while pending and fetch_count < max_fetch:
+ current_id = pending.pop(0)
+ if current_id in seen:
+ continue
+ seen.add(current_id)
+ fetch_count += 1
+
+ forward_payload = await onebot_client.get_forward_msg(current_id)
+ if not forward_payload:
+ continue
+
+ parsed = payload_parser.parse_get_forward_payload(forward_payload)
+ if parsed["text"]:
+ texts.append(parsed["text"])
+ if parsed["image_refs"]:
+ image_refs.extend(parsed["image_refs"])
+ for nested_id in parsed["forward_ids"]:
+ if nested_id not in seen:
+ pending.append(nested_id)
+
+ if pending:
+ logger.warning(
+ "quoted_message_parser: stop fetching nested forward messages after %d hops",
+ max_fetch,
+ )
+
+ return texts, normalize_and_dedupe_strings(image_refs)
+
+
+@dataclass(slots=True)
+class QuotedMessageContent:
+ embedded_text: str | None
+ embedded_image_refs: list[str]
+ reply_id: str
+ direct_text: str | None
+ direct_image_refs: list[str]
+ forward_texts: list[str]
+ forward_image_refs: list[str]
+
+
+class QuotedMessageExtractor:
+ def __init__(
+ self,
+ event: AstrMessageEvent,
+ settings: QuotedMessageParserSettings = SETTINGS,
+ ):
+ self._event = event
+ self._settings = settings
+ self._reply_parser = ReplyChainParser(settings=settings)
+ self._payload_parser = OneBotPayloadParser(settings=settings)
+ self._client = OneBotClient(event, settings=settings)
+ self._image_resolver = ImageResolver(event, self._client)
+
+ async def _fetch_quoted_content(
+ self,
+ reply_component: Reply | None = None,
+ *,
+ fetch_remote: bool,
+ ) -> QuotedMessageContent | None:
+ reply = reply_component or self._reply_parser.find_first_reply_component(
+ self._event
+ )
+ if not reply:
+ return None
+
+ embedded_text = self._reply_parser.extract_text_from_reply_component(reply)
+ embedded_image_refs = list(
+ self._reply_parser.extract_image_refs_from_reply_component(reply)
+ )
+
+ reply_id = getattr(reply, "id", None)
+ reply_id_str = str(reply_id).strip() if reply_id is not None else ""
+ if not fetch_remote or not reply_id_str:
+ return QuotedMessageContent(
+ embedded_text=embedded_text,
+ embedded_image_refs=embedded_image_refs,
+ reply_id=reply_id_str,
+ direct_text=None,
+ direct_image_refs=[],
+ forward_texts=[],
+ forward_image_refs=[],
+ )
+
+ msg_payload = await self._client.get_msg(reply_id_str)
+ if not msg_payload:
+ return QuotedMessageContent(
+ embedded_text=embedded_text,
+ embedded_image_refs=embedded_image_refs,
+ reply_id=reply_id_str,
+ direct_text=None,
+ direct_image_refs=[],
+ forward_texts=[],
+ forward_image_refs=[],
+ )
+
+ parsed = self._payload_parser.parse_get_msg_payload(msg_payload)
+ forward_texts, forward_images = await _collect_text_and_images_from_forward_ids(
+ self._client,
+ self._payload_parser,
+ parsed["forward_ids"],
+ max_fetch=self._settings.max_forward_fetch,
+ )
+ return QuotedMessageContent(
+ embedded_text=embedded_text,
+ embedded_image_refs=embedded_image_refs,
+ reply_id=reply_id_str,
+ direct_text=parsed["text"],
+ direct_image_refs=list(parsed["image_refs"]),
+ forward_texts=forward_texts,
+ forward_image_refs=forward_images,
+ )
+
+ async def text(self, reply_component: Reply | None = None) -> str | None:
+ embedded_content = await self._fetch_quoted_content(
+ reply_component,
+ fetch_remote=False,
+ )
+ if not embedded_content:
+ return None
+
+ if (
+ embedded_content.embedded_text
+ and not self._reply_parser.is_forward_placeholder_only_text(
+ embedded_content.embedded_text
+ )
+ ):
+ return embedded_content.embedded_text
+
+ if not embedded_content.reply_id:
+ return embedded_content.embedded_text
+
+ fetched_content = await self._fetch_quoted_content(
+ reply_component,
+ fetch_remote=True,
+ )
+ if not fetched_content:
+ return embedded_content.embedded_text
+
+ text_parts: list[str] = []
+ if fetched_content.direct_text:
+ text_parts.append(fetched_content.direct_text)
+ text_parts.extend(fetched_content.forward_texts)
+
+ return "\n".join(text_parts).strip() or embedded_content.embedded_text
+
+ async def images(self, reply_component: Reply | None = None) -> list[str]:
+ content = await self._fetch_quoted_content(reply_component, fetch_remote=True)
+ if not content:
+ return []
+
+ image_refs: list[str] = []
+ image_refs.extend(content.embedded_image_refs)
+ image_refs.extend(content.direct_image_refs)
+ image_refs.extend(content.forward_image_refs)
+
+ return await self._image_resolver.resolve_for_llm(image_refs)
+
+
+async def extract_quoted_message_text(
+ event: AstrMessageEvent,
+ reply_component: Reply | None = None,
+ settings: QuotedMessageParserSettings | None = None,
+) -> str | None:
+ return await QuotedMessageExtractor(event, settings=settings or SETTINGS).text(
+ reply_component
+ )
+
+
+async def extract_quoted_message_images(
+ event: AstrMessageEvent,
+ reply_component: Reply | None = None,
+ settings: QuotedMessageParserSettings | None = None,
+) -> list[str]:
+ return await QuotedMessageExtractor(event, settings=settings or SETTINGS).images(
+ reply_component
+ )
diff --git a/astrbot/core/utils/quoted_message/image_refs.py b/astrbot/core/utils/quoted_message/image_refs.py
new file mode 100644
index 000000000..009d6844a
--- /dev/null
+++ b/astrbot/core/utils/quoted_message/image_refs.py
@@ -0,0 +1,94 @@
+from __future__ import annotations
+
+import os
+from urllib.parse import urlsplit
+
+IMAGE_EXTENSIONS = {
+ ".jpg",
+ ".jpeg",
+ ".png",
+ ".webp",
+ ".bmp",
+ ".tif",
+ ".tiff",
+ ".gif",
+}
+
+
+def normalize_file_like_url(path: str | None) -> str | None:
+ if path is None:
+ return None
+ if not isinstance(path, str):
+ return None
+ if "?" not in path and "#" not in path:
+ return path
+ try:
+ split = urlsplit(path)
+ except Exception:
+ return path
+ return split.path or path
+
+
+def looks_like_image_file_name(name: str) -> bool:
+ normalized_name = normalize_file_like_url(name)
+ if not isinstance(normalized_name, str) or not normalized_name.strip():
+ return False
+ _, ext = os.path.splitext(normalized_name.strip().lower())
+ return ext in IMAGE_EXTENSIONS
+
+
+def convert_data_image_to_base64_ref(image_ref: str) -> str | None:
+ if not isinstance(image_ref, str):
+ return None
+ value = image_ref.strip()
+ if not value:
+ return None
+ lower_value = value.lower()
+ if not lower_value.startswith("data:image/"):
+ return None
+
+ comma_index = value.find(",")
+ if comma_index <= 0:
+ return None
+ header = value[:comma_index].lower()
+ payload = value[comma_index + 1 :].strip()
+ if ";base64" not in header or not payload:
+ return None
+ return f"base64://{payload}"
+
+
+def get_existing_local_path(value: str) -> str | None:
+ lower_value = value.lower()
+ if lower_value.startswith("file://"):
+ file_path = value[7:]
+ if file_path.startswith("/") and len(file_path) > 3 and file_path[2] == ":":
+ file_path = file_path[1:]
+ if file_path and os.path.exists(file_path):
+ return os.path.abspath(file_path)
+ return None
+ if os.path.exists(value):
+ return os.path.abspath(value)
+ return None
+
+
+def normalize_image_ref(image_ref: str) -> str | None:
+ if not isinstance(image_ref, str):
+ return None
+ value = image_ref.strip()
+ if not value:
+ return None
+ lower_value = value.lower()
+
+ if lower_value.startswith(("http://", "https://")):
+ return value
+ if lower_value.startswith("base64://"):
+ return value
+
+ data_image_ref = convert_data_image_to_base64_ref(value)
+ if data_image_ref:
+ return data_image_ref
+
+ local_path = get_existing_local_path(value)
+ if local_path and looks_like_image_file_name(local_path):
+ return local_path
+ return None
diff --git a/astrbot/core/utils/quoted_message/image_resolver.py b/astrbot/core/utils/quoted_message/image_resolver.py
new file mode 100644
index 000000000..5a4c21fb2
--- /dev/null
+++ b/astrbot/core/utils/quoted_message/image_resolver.py
@@ -0,0 +1,130 @@
+from __future__ import annotations
+
+import os
+from typing import Any
+
+from astrbot import logger
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
+from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
+
+from .image_refs import IMAGE_EXTENSIONS, get_existing_local_path, normalize_image_ref
+from .onebot_client import OneBotClient
+
+
+def _build_image_id_candidates(image_ref: str) -> list[str]:
+ candidates: list[str] = [image_ref]
+ base_name, ext = os.path.splitext(image_ref)
+ if ext and base_name and base_name not in candidates:
+ if ext.lower() in IMAGE_EXTENSIONS:
+ candidates.append(base_name)
+ return candidates
+
+
+def _build_image_resolve_actions(
+ event: AstrMessageEvent,
+ image_ref: str,
+) -> list[tuple[str, dict[str, Any]]]:
+ actions: list[tuple[str, dict[str, Any]]] = []
+ candidates = _build_image_id_candidates(image_ref)
+
+ for candidate in candidates:
+ actions.extend(
+ [
+ ("get_image", {"file": candidate}),
+ ("get_image", {"file_id": candidate}),
+ ("get_image", {"id": candidate}),
+ ("get_image", {"image": candidate}),
+ ("get_file", {"file_id": candidate}),
+ ("get_file", {"file": candidate}),
+ ]
+ )
+
+ try:
+ group_id = event.get_group_id()
+ except Exception:
+ group_id = None
+ group_id_value = group_id
+ if isinstance(group_id, str) and group_id.isdigit():
+ group_id_value = int(group_id)
+
+ if group_id_value:
+ for candidate in candidates:
+ actions.append(
+ (
+ "get_group_file_url",
+ {"group_id": group_id_value, "file_id": candidate},
+ )
+ )
+ for candidate in candidates:
+ actions.append(("get_private_file_url", {"file_id": candidate}))
+
+ return actions
+
+
+class ImageResolver:
+ def __init__(
+ self,
+ event: AstrMessageEvent,
+ onebot_client: OneBotClient | None = None,
+ ):
+ self._event = event
+ self._client = onebot_client or OneBotClient(event)
+
+ async def resolve_for_llm(self, image_refs: list[str]) -> list[str]:
+ resolved: list[str] = []
+ unresolved: list[str] = []
+
+ for image_ref in normalize_and_dedupe_strings(image_refs):
+ normalized = normalize_image_ref(image_ref)
+ if normalized:
+ resolved.append(normalized)
+ elif get_existing_local_path(image_ref):
+ # Drop non-image local paths instead of treating them as remote IDs.
+ logger.debug(
+ "quoted_message_parser: skip non-image local path ref=%s",
+ image_ref[:128],
+ )
+ else:
+ unresolved.append(image_ref)
+
+ for image_ref in unresolved:
+ resolved_ref = await self._resolve_one(image_ref)
+ if resolved_ref:
+ resolved.append(resolved_ref)
+
+ return normalize_and_dedupe_strings(resolved)
+
+ async def _resolve_one(self, image_ref: str) -> str | None:
+ resolved = normalize_image_ref(image_ref)
+ if resolved:
+ return resolved
+
+ actions = _build_image_resolve_actions(self._event, image_ref)
+ for action, params in actions:
+ data = await self._client.call(
+ action,
+ params,
+ warn_on_all_failed=False,
+ unwrap_data=True,
+ )
+ if not isinstance(data, dict):
+ continue
+
+ url = data.get("url")
+ if isinstance(url, str):
+ normalized = normalize_image_ref(url)
+ if normalized:
+ return normalized
+
+ file_value = data.get("file")
+ if isinstance(file_value, str):
+ normalized = normalize_image_ref(file_value)
+ if normalized:
+ return normalized
+
+ logger.warning(
+ "quoted_message_parser: failed to resolve quoted image ref=%s after %d actions",
+ image_ref[:128],
+ len(actions),
+ )
+ return None
diff --git a/astrbot/core/utils/quoted_message/onebot_client.py b/astrbot/core/utils/quoted_message/onebot_client.py
new file mode 100644
index 000000000..c48785d76
--- /dev/null
+++ b/astrbot/core/utils/quoted_message/onebot_client.py
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+from typing import Any
+
+from astrbot import logger
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
+
+from .settings import SETTINGS, QuotedMessageParserSettings
+
+
+def _unwrap_action_response(ret: dict[str, Any] | None) -> dict[str, Any]:
+ if not isinstance(ret, dict):
+ return {}
+ data = ret.get("data")
+ if isinstance(data, dict):
+ return data
+ return ret
+
+
+class OneBotClient:
+ def __init__(
+ self,
+ event: AstrMessageEvent,
+ settings: QuotedMessageParserSettings = SETTINGS,
+ ):
+ self._call_action = self._resolve_call_action(event)
+ self._settings = settings
+
+ @staticmethod
+ def _resolve_call_action(event: AstrMessageEvent):
+ bot = getattr(event, "bot", None)
+ api = getattr(bot, "api", None)
+ call_action = getattr(api, "call_action", None)
+ if not callable(call_action):
+ return None
+ return call_action
+
+ async def _call_action_try_params(
+ self,
+ action: str,
+ params_list: list[dict[str, Any]],
+ *,
+ warn_on_all_failed: bool | None = None,
+ ) -> dict[str, Any] | None:
+ if self._call_action is None:
+ return None
+ if warn_on_all_failed is None:
+ warn_on_all_failed = self._settings.warn_on_action_failure
+
+ last_error: Exception | None = None
+ last_params: dict[str, Any] | None = None
+ for params in params_list:
+ try:
+ result = await self._call_action(action, **params)
+ if isinstance(result, dict):
+ return result
+ except Exception as exc:
+ last_error = exc
+ last_params = params
+ logger.debug(
+ "quoted_message_parser: action %s failed with params %s: %s",
+ action,
+ {k: str(v)[:64] for k, v in params.items()},
+ exc,
+ )
+ if warn_on_all_failed and last_error is not None:
+ logger.warning(
+ "quoted_message_parser: all attempts failed for action %s, "
+ "last_params=%s, error=%s",
+ action,
+ (
+ {k: str(v)[:64] for k, v in last_params.items()}
+ if isinstance(last_params, dict)
+ else None
+ ),
+ last_error,
+ )
+ return None
+
+ async def call(
+ self,
+ action: str,
+ params: dict[str, Any],
+ *,
+ warn_on_all_failed: bool = False,
+ unwrap_data: bool = True,
+ ) -> dict[str, Any] | None:
+ ret = await self._call_action_try_params(
+ action,
+ [params],
+ warn_on_all_failed=warn_on_all_failed,
+ )
+ if not unwrap_data:
+ return ret
+ return _unwrap_action_response(ret)
+
+ async def _call_action_compat(
+ self,
+ action: str,
+ message_id: str | int,
+ ) -> dict[str, Any] | None:
+ message_id_str = str(message_id).strip()
+ if not message_id_str:
+ return None
+
+ params_list: list[dict[str, Any]] = [
+ {"message_id": message_id_str},
+ {"id": message_id_str},
+ ]
+ if message_id_str.isdigit():
+ int_id = int(message_id_str)
+ params_list.extend([{"message_id": int_id}, {"id": int_id}])
+ return await self._call_action_try_params(action, params_list)
+
+ async def get_msg(self, message_id: str | int) -> dict[str, Any] | None:
+ return await self._call_action_compat("get_msg", message_id)
+
+ async def get_forward_msg(self, forward_id: str | int) -> dict[str, Any] | None:
+ return await self._call_action_compat("get_forward_msg", forward_id)
diff --git a/astrbot/core/utils/quoted_message/settings.py b/astrbot/core/utils/quoted_message/settings.py
new file mode 100644
index 000000000..2f74f41b6
--- /dev/null
+++ b/astrbot/core/utils/quoted_message/settings.py
@@ -0,0 +1,85 @@
+from __future__ import annotations
+
+from collections.abc import Mapping
+from dataclasses import dataclass
+from typing import Any
+
+_DEFAULT_MAX_COMPONENT_CHAIN_DEPTH = 4
+_DEFAULT_MAX_FORWARD_NODE_DEPTH = 6
+_DEFAULT_MAX_FORWARD_FETCH = 32
+
+
+def _read_int_mapping(
+ mapping: Mapping[str, Any],
+ key: str,
+ default: int,
+) -> int:
+ raw = mapping.get(key)
+ if raw is None:
+ return default
+ try:
+ value = int(raw)
+ except (TypeError, ValueError):
+ return default
+ if value <= 0:
+ return default
+ return value
+
+
+def _read_bool_mapping(
+ mapping: Mapping[str, Any],
+ key: str,
+ default: bool,
+) -> bool:
+ raw = mapping.get(key)
+ if raw is None:
+ return default
+ if isinstance(raw, bool):
+ return raw
+ if isinstance(raw, str):
+ lowered = raw.strip().lower()
+ if lowered in {"1", "true", "yes", "on"}:
+ return True
+ if lowered in {"0", "false", "no", "off"}:
+ return False
+ return default
+
+
+@dataclass(frozen=True)
+class QuotedMessageParserSettings:
+ max_component_chain_depth: int = _DEFAULT_MAX_COMPONENT_CHAIN_DEPTH
+ max_forward_node_depth: int = _DEFAULT_MAX_FORWARD_NODE_DEPTH
+ max_forward_fetch: int = _DEFAULT_MAX_FORWARD_FETCH
+ warn_on_action_failure: bool = False
+
+ def with_overrides(
+ self,
+ overrides: Mapping[str, Any] | None = None,
+ ) -> QuotedMessageParserSettings:
+ if not overrides:
+ return self
+ return QuotedMessageParserSettings(
+ max_component_chain_depth=_read_int_mapping(
+ overrides,
+ "max_component_chain_depth",
+ self.max_component_chain_depth,
+ ),
+ max_forward_node_depth=_read_int_mapping(
+ overrides,
+ "max_forward_node_depth",
+ self.max_forward_node_depth,
+ ),
+ max_forward_fetch=_read_int_mapping(
+ overrides,
+ "max_forward_fetch",
+ self.max_forward_fetch,
+ ),
+ warn_on_action_failure=_read_bool_mapping(
+ overrides,
+ "warn_on_action_failure",
+ self.warn_on_action_failure,
+ ),
+ )
+
+
+SETTINGS = QuotedMessageParserSettings()
diff --git a/astrbot/core/utils/quoted_message_parser.py b/astrbot/core/utils/quoted_message_parser.py
new file mode 100644
index 000000000..fa6ac18dd
--- /dev/null
+++ b/astrbot/core/utils/quoted_message_parser.py
@@ -0,0 +1,11 @@
+from __future__ import annotations
+
+from astrbot.core.utils.quoted_message.extractor import (
+ extract_quoted_message_images,
+ extract_quoted_message_text,
+)
+
+__all__ = [
+ "extract_quoted_message_text",
+ "extract_quoted_message_images",
+]
diff --git a/astrbot/core/utils/string_utils.py b/astrbot/core/utils/string_utils.py
new file mode 100644
index 000000000..8c2aacb4d
--- /dev/null
+++ b/astrbot/core/utils/string_utils.py
@@ -0,0 +1,21 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+from typing import Any
+
+
+def normalize_and_dedupe_strings(items: Iterable[Any] | None) -> list[str]:
+ if items is None:
+ return []
+
+ normalized: list[str] = []
+ seen: set[str] = set()
+ for item in items:
+ if not isinstance(item, str):
+ continue
+ cleaned = item.strip()
+ if not cleaned or cleaned in seen:
+ continue
+ seen.add(cleaned)
+ normalized.append(cleaned)
+ return normalized
diff --git a/astrbot/core/utils/temp_dir_cleaner.py b/astrbot/core/utils/temp_dir_cleaner.py
new file mode 100644
index 000000000..c0c060098
--- /dev/null
+++ b/astrbot/core/utils/temp_dir_cleaner.py
@@ -0,0 +1,150 @@
+import asyncio
+from collections.abc import Callable
+from dataclasses import dataclass
+from pathlib import Path
+
+from astrbot import logger
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
+
+
+def parse_size_to_bytes(value: str | int | float | None) -> int:
+ """Parse size in MB to bytes."""
+ if value is None:
+ return 0
+
+ try:
+ size_mb = float(str(value).strip())
+ except (TypeError, ValueError):
+ return 0
+
+ if size_mb <= 0:
+ return 0
+
+ return int(size_mb * 1024**2)
+
+
+@dataclass
+class TempFileInfo:
+ path: Path
+ size: int
+ mtime: float
+
+
+class TempDirCleaner:
+ CONFIG_KEY = "temp_dir_max_size"
+ DEFAULT_MAX_SIZE = 1024
+ CHECK_INTERVAL_SECONDS = 10 * 60
+ CLEANUP_RATIO = 0.30
+
+ def __init__(
+ self,
+ max_size_getter: Callable[[], str | int | float | None],
+ temp_dir: Path | None = None,
+ ) -> None:
+ self._max_size_getter = max_size_getter
+ self._temp_dir = temp_dir or Path(get_astrbot_temp_path())
+ self._stop_event = asyncio.Event()
+
+ def _limit_bytes(self) -> int:
+ configured = self._max_size_getter()
+ parsed = parse_size_to_bytes(configured)
+ if parsed <= 0:
+ fallback = parse_size_to_bytes(self.DEFAULT_MAX_SIZE)
+ logger.warning(
+ f"Invalid {self.CONFIG_KEY}={configured!r}, fallback to {self.DEFAULT_MAX_SIZE}MB.",
+ )
+ return fallback
+ return parsed
+
+ def _scan_temp_files(self) -> tuple[int, list[TempFileInfo]]:
+ if not self._temp_dir.exists():
+ return 0, []
+
+ total_size = 0
+ files: list[TempFileInfo] = []
+ for path in self._temp_dir.rglob("*"):
+ if not path.is_file():
+ continue
+ try:
+ stat = path.stat()
+ except OSError as e:
+ logger.debug(f"Skip temp file {path} due to stat error: {e}")
+ continue
+ total_size += stat.st_size
+ files.append(
+ TempFileInfo(path=path, size=stat.st_size, mtime=stat.st_mtime)
+ )
+
+ return total_size, files
+
+ def _cleanup_empty_dirs(self) -> None:
+ if not self._temp_dir.exists():
+ return
+ for path in sorted(
+ self._temp_dir.rglob("*"), key=lambda p: len(p.parts), reverse=True
+ ):
+ if not path.is_dir():
+ continue
+ try:
+ path.rmdir()
+ except OSError:
+ continue
+
+ def cleanup_once(self) -> None:
+ limit = self._limit_bytes()
+ if limit <= 0:
+ return
+
+ total_size, files = self._scan_temp_files()
+ if total_size <= limit:
+ return
+
+ target_release = max(int(total_size * self.CLEANUP_RATIO), 1)
+ released = 0
+ removed_files = 0
+
+ for file_info in sorted(files, key=lambda item: item.mtime):
+ try:
+ file_info.path.unlink()
+ except OSError as e:
+ logger.warning(f"Failed to delete temp file {file_info.path}: {e}")
+ continue
+
+ released += file_info.size
+ removed_files += 1
+ if released >= target_release:
+ break
+
+ self._cleanup_empty_dirs()
+
+ logger.warning(
+ f"Temp dir exceeded limit ({total_size} > {limit}). "
+ f"Removed {removed_files} files, released {released} bytes "
+ f"(target {target_release} bytes).",
+ )
+
+ async def run(self) -> None:
+ logger.info(
+ f"TempDirCleaner started. interval={self.CHECK_INTERVAL_SECONDS}s "
+ f"cleanup_ratio={self.CLEANUP_RATIO}",
+ )
+ while not self._stop_event.is_set():
+ try:
+ # File-system traversal and deletion are blocking operations.
+ # Run cleanup in a worker thread to avoid blocking the event loop.
+ await asyncio.to_thread(self.cleanup_once)
+ except Exception as e:
+ logger.error(f"TempDirCleaner run failed: {e}", exc_info=True)
+
+ try:
+ await asyncio.wait_for(
+ self._stop_event.wait(),
+ timeout=self.CHECK_INTERVAL_SECONDS,
+ )
+ except asyncio.TimeoutError:
+ continue
+
+ logger.info("TempDirCleaner stopped.")
+
+ async def stop(self) -> None:
+ self._stop_event.set()
diff --git a/astrbot/core/utils/tencent_record_helper.py b/astrbot/core/utils/tencent_record_helper.py
index b58643bd3..f342484bd 100644
--- a/astrbot/core/utils/tencent_record_helper.py
+++ b/astrbot/core/utils/tencent_record_helper.py
@@ -7,7 +7,7 @@ import wave
from io import BytesIO
from astrbot.core import logger
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str:
@@ -117,12 +117,13 @@ async def audio_to_tencent_silk_base64(audio_path: str) -> tuple[str, float]:
except ImportError as e:
raise Exception("未安装 pilk: pip install pilk") from e
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
# 是否需要转换为 WAV
ext = os.path.splitext(audio_path)[1].lower()
temp_wav = tempfile.NamedTemporaryFile(
+ prefix="tencent_record_",
suffix=".wav",
delete=False,
dir=temp_dir,
@@ -140,6 +141,7 @@ async def audio_to_tencent_silk_base64(audio_path: str) -> tuple[str, float]:
rate = wav_file.getframerate()
silk_path = tempfile.NamedTemporaryFile(
+ prefix="tencent_record_",
suffix=".silk",
delete=False,
dir=temp_dir,
diff --git a/astrbot/core/utils/webhook_utils.py b/astrbot/core/utils/webhook_utils.py
index 07abc115a..40dada3cb 100644
--- a/astrbot/core/utils/webhook_utils.py
+++ b/astrbot/core/utils/webhook_utils.py
@@ -1,3 +1,4 @@
+import os
import uuid
from astrbot.core import astrbot_config, logger
@@ -20,6 +21,20 @@ def _get_dashboard_port() -> int:
return 6185
+def _is_dashboard_ssl_enabled() -> bool:
+ env_ssl = os.environ.get("DASHBOARD_SSL_ENABLE") or os.environ.get(
+ "ASTRBOT_DASHBOARD_SSL_ENABLE"
+ )
+ if env_ssl is not None:
+ return env_ssl.strip().lower() in {"1", "true", "yes", "on"}
+
+ try:
+ return bool(astrbot_config.get("dashboard", {}).get("ssl", {}).get("enable"))
+ except Exception as e:
+ logger.error(f"获取 dashboard SSL 配置失败: {e!s}")
+ return False
+
+
def log_webhook_info(platform_name: str, webhook_uuid: str) -> None:
"""打印美观的 webhook 信息日志
@@ -38,12 +53,13 @@ def log_webhook_info(platform_name: str, webhook_uuid: str) -> None:
callback_base = callback_base.rstrip("/")
webhook_url = f"{callback_base}/api/platform/webhook/{webhook_uuid}"
+ scheme = "https" if _is_dashboard_ssl_enabled() else "http"
display_log = (
"\n====================\n"
f"🔗 机器人平台 {platform_name} 已启用统一 Webhook 模式\n"
f"📍 Webhook 回调地址: \n"
- f" ➜ http://:{_get_dashboard_port()}/api/platform/webhook/{webhook_uuid}\n"
+ f" ➜ {scheme}://:{_get_dashboard_port()}/api/platform/webhook/{webhook_uuid}\n"
f" ➜ {webhook_url}\n"
"====================\n"
)
diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py
index c9f7dbd34..1a22ee052 100644
--- a/astrbot/dashboard/routes/config.py
+++ b/astrbot/dashboard/routes/config.py
@@ -1333,6 +1333,30 @@ class ConfigRoute(Route):
f"Unexpected error registering logo for platform {platform.name}: {e}",
)
+ def _inject_platform_metadata_with_i18n(
+ self, platform, metadata, platform_i18n_translations: dict
+ ):
+ """将配置元数据注入到 metadata 中并处理国际化键转换。"""
+ metadata["platform_group"]["metadata"]["platform"].setdefault("items", {})
+ platform_items_to_inject = copy.deepcopy(platform.config_metadata)
+
+ if platform.i18n_resources:
+ i18n_prefix = f"platform_group.platform.{platform.name}"
+
+ for lang, lang_data in platform.i18n_resources.items():
+ platform_i18n_translations.setdefault(lang, {}).setdefault(
+ "platform_group", {}
+ ).setdefault("platform", {})[platform.name] = lang_data
+
+ for field_key, field_value in platform_items_to_inject.items():
+ for key in ("description", "hint", "labels"):
+ if key in field_value:
+ field_value[key] = f"{i18n_prefix}.{field_key}.{key}"
+
+ metadata["platform_group"]["metadata"]["platform"]["items"].update(
+ platform_items_to_inject
+ )
+
async def _get_astrbot_config(self):
config = self.config
metadata = copy.deepcopy(CONFIG_METADATA_2)
@@ -1354,11 +1378,23 @@ class ConfigRoute(Route):
"config_template"
]
+ # 收集平台的 i18n 翻译数据
+ platform_i18n_translations = {}
+
# 收集需要注册logo的平台
logo_registration_tasks = []
for platform in platform_registry:
if platform.default_config_tmpl:
- platform_default_tmpl[platform.name] = platform.default_config_tmpl
+ platform_default_tmpl[platform.name] = copy.deepcopy(
+ platform.default_config_tmpl
+ )
+
+ # 注入配置元数据(在 convert_to_i18n_keys 之后,使用国际化键)
+ if platform.config_metadata:
+ self._inject_platform_metadata_with_i18n(
+ platform, metadata, platform_i18n_translations
+ )
+
# 收集logo注册任务
if platform.logo_path:
logo_registration_tasks.append(
@@ -1377,7 +1413,11 @@ class ConfigRoute(Route):
if provider.default_config_tmpl:
provider_default_tmpl[provider.type] = provider.default_config_tmpl
- return {"metadata": metadata, "config": config}
+ return {
+ "metadata": metadata,
+ "config": config,
+ "platform_i18n_translations": platform_i18n_translations,
+ }
async def _get_plugin_config(self, plugin_name: str):
ret: dict = {"metadata": None, "config": None}
diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py
index efdc3bc6a..f0ac5d43d 100644
--- a/astrbot/dashboard/routes/knowledge_base.py
+++ b/astrbot/dashboard/routes/knowledge_base.py
@@ -12,6 +12,7 @@ from quart import request
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..utils import generate_tsne_visualization
from .route import Response, Route, RouteContext
@@ -703,7 +704,10 @@ class KnowledgeBaseRoute(Route):
file_name = file.filename
# 保存到临时文件
- temp_file_path = f"data/temp/{uuid.uuid4()}_{file_name}"
+ temp_file_path = os.path.join(
+ get_astrbot_temp_path(),
+ f"kb_upload_{uuid.uuid4()}_{file_name}",
+ )
await file.save(temp_file_path)
try:
diff --git a/astrbot/dashboard/routes/live_chat.py b/astrbot/dashboard/routes/live_chat.py
index 85af3cef8..8c922ab69 100644
--- a/astrbot/dashboard/routes/live_chat.py
+++ b/astrbot/dashboard/routes/live_chat.py
@@ -12,7 +12,7 @@ from quart import websocket
from astrbot import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from .route import Route, RouteContext
@@ -60,7 +60,7 @@ class LiveChatSession:
# 组装 WAV 文件
try:
- temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+ temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
audio_path = os.path.join(temp_dir, f"live_audio_{uuid.uuid4()}.wav")
diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py
index f9f8961b4..ca271cdf6 100644
--- a/astrbot/dashboard/routes/plugin.py
+++ b/astrbot/dashboard/routes/plugin.py
@@ -20,6 +20,7 @@ from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.filter.regex import RegexFilter
from astrbot.core.star.star_handler import EventType, star_handlers_registry
from astrbot.core.star.star_manager import PluginManager
+from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from .route import Response, Route, RouteContext
@@ -53,11 +54,13 @@ class PluginRoute(Route):
"/plugin/market_list": ("GET", self.get_online_plugins),
"/plugin/off": ("POST", self.off_plugin),
"/plugin/on": ("POST", self.on_plugin),
+ "/plugin/reload-failed": ("POST", self.reload_failed_plugins),
"/plugin/reload": ("POST", self.reload_plugins),
"/plugin/readme": ("GET", self.get_plugin_readme),
"/plugin/changelog": ("GET", self.get_plugin_changelog),
"/plugin/source/get": ("GET", self.get_custom_source),
"/plugin/source/save": ("POST", self.save_custom_source),
+ "/plugin/source/get-failed-plugins": ("GET", self.get_failed_plugins),
}
self.core_lifecycle = core_lifecycle
self.plugin_manager = plugin_manager
@@ -74,6 +77,33 @@ class PluginRoute(Route):
self._logo_cache = {}
+ async def reload_failed_plugins(self):
+ if DEMO_MODE:
+ return (
+ Response()
+ .error("You are not permitted to do this operation in demo mode")
+ .__dict__
+ )
+ try:
+ data = await request.get_json()
+ dir_name = data.get("dir_name") # 这里拿的是目录名,不是插件名
+
+ if not dir_name:
+ return Response().error("缺少插件目录名").__dict__
+
+ # 调用 star_manager.py 中的函数
+ # 注意:传入的是目录名
+ success, err = await self.plugin_manager.reload_failed_plugin(dir_name)
+
+ if success:
+ return Response().ok(None, f"插件 {dir_name} 重载成功。").__dict__
+ else:
+ return Response().error(f"重载失败: {err}").__dict__
+
+ except Exception as e:
+ logger.error(f"/api/plugin/reload-failed: {traceback.format_exc()}")
+ return Response().error(str(e)).__dict__
+
async def reload_plugins(self):
if DEMO_MODE:
return (
@@ -333,6 +363,10 @@ class PluginRoute(Route):
.__dict__
)
+ async def get_failed_plugins(self):
+ """专门获取加载失败的插件列表(字典格式)"""
+ return Response().ok(self.plugin_manager.failed_plugin_dict).__dict__
+
async def get_plugin_handlers_info(self, handler_full_names: list[str]):
"""解析插件行为"""
handlers = []
@@ -431,7 +465,10 @@ class PluginRoute(Route):
file = await request.files
file = file["file"]
logger.info(f"正在安装用户上传的插件 {file.filename}")
- file_path = f"data/temp/{file.filename}"
+ file_path = os.path.join(
+ get_astrbot_temp_path(),
+ f"plugin_upload_{file.filename}",
+ )
await file.save(file_path)
plugin_info = await self.plugin_manager.install_plugin_from_file(file_path)
# self.core_lifecycle.restart()
diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py
index 054eec995..532238ac7 100644
--- a/astrbot/dashboard/routes/stat.py
+++ b/astrbot/dashboard/routes/stat.py
@@ -4,6 +4,7 @@ import threading
import time
import traceback
from functools import cmp_to_key
+from pathlib import Path
import aiohttp
import psutil
@@ -37,6 +38,7 @@ class StatRoute(Route):
"/stat/test-ghproxy-connection": ("POST", self.test_ghproxy_connection),
"/stat/changelog": ("GET", self.get_changelog),
"/stat/changelog/list": ("GET", self.list_changelog_versions),
+ "/stat/first-notice": ("GET", self.get_first_notice),
}
self.db_helper = db_helper
self.register_routes()
@@ -279,3 +281,40 @@ class StatRoute(Route):
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Error: {e!s}").__dict__
+
+ async def get_first_notice(self):
+ """读取项目根目录 FIRST_NOTICE.md 内容。"""
+ try:
+ locale = (request.args.get("locale") or "").strip()
+ if not re.match(r"^[A-Za-z0-9_-]*$", locale):
+ locale = ""
+
+ base_path = Path(get_astrbot_path())
+ candidates: list[Path] = []
+
+ if locale:
+ candidates.append(base_path / f"FIRST_NOTICE.{locale}.md")
+ if locale.lower().startswith("zh"):
+ candidates.append(base_path / "FIRST_NOTICE.md")
+ candidates.append(base_path / "FIRST_NOTICE.zh-CN.md")
+ elif locale.lower().startswith("en"):
+ candidates.append(base_path / "FIRST_NOTICE.en-US.md")
+
+ candidates.extend(
+ [
+ base_path / "FIRST_NOTICE.md",
+ base_path / "FIRST_NOTICE.en-US.md",
+ ],
+ )
+
+ for notice_path in candidates:
+ if not notice_path.is_file():
+ continue
+ content = notice_path.read_text(encoding="utf-8")
+ if content.strip():
+ return Response().ok({"content": content}).__dict__
+
+ return Response().ok({"content": None}).__dict__
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return Response().error(f"Error: {e!s}").__dict__
diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py
index 604866a87..6a588a5d3 100644
--- a/astrbot/dashboard/server.py
+++ b/astrbot/dashboard/server.py
@@ -2,6 +2,7 @@ import asyncio
import logging
import os
import socket
+from pathlib import Path
from typing import Protocol, cast
import jwt
@@ -36,6 +37,12 @@ class _AddrWithPort(Protocol):
APP: Quart
+def _parse_env_bool(value: str | None, default: bool) -> bool:
+ if value is None:
+ return default
+ return value.strip().lower() in {"1", "true", "yes", "on"}
+
+
class AstrBotDashboard:
def __init__(
self,
@@ -201,19 +208,33 @@ class AstrBotDashboard:
def run(self):
ip_addr = []
- if p := os.environ.get("DASHBOARD_PORT"):
- port = p
- else:
- port = self.core_lifecycle.astrbot_config["dashboard"].get("port", 6185)
- host = self.core_lifecycle.astrbot_config["dashboard"].get("host", "0.0.0.0")
- enable = self.core_lifecycle.astrbot_config["dashboard"].get("enable", True)
+ dashboard_config = self.core_lifecycle.astrbot_config.get("dashboard", {})
+ port = (
+ os.environ.get("DASHBOARD_PORT")
+ or os.environ.get("ASTRBOT_DASHBOARD_PORT")
+ or dashboard_config.get("port", 6185)
+ )
+ host = (
+ os.environ.get("DASHBOARD_HOST")
+ or os.environ.get("ASTRBOT_DASHBOARD_HOST")
+ or dashboard_config.get("host", "0.0.0.0")
+ )
+ enable = dashboard_config.get("enable", True)
+ ssl_config = dashboard_config.get("ssl", {})
+ if not isinstance(ssl_config, dict):
+ ssl_config = {}
+ ssl_enable = _parse_env_bool(
+ os.environ.get("DASHBOARD_SSL_ENABLE")
+ or os.environ.get("ASTRBOT_DASHBOARD_SSL_ENABLE"),
+ bool(ssl_config.get("enable", False)),
+ )
+ scheme = "https" if ssl_enable else "http"
if not enable:
logger.info("WebUI 已被禁用")
return None
- logger.info(f"正在启动 WebUI, 监听地址: http://{host}:{port}")
-
+ logger.info(f"正在启动 WebUI, 监听地址: {scheme}://{host}:{port}")
if host == "0.0.0.0":
logger.info(
"提示: WebUI 将监听所有网络接口,请注意安全。(可在 data/cmd_config.json 中配置 dashboard.host 以修改 host)",
@@ -241,9 +262,9 @@ class AstrBotDashboard:
raise Exception(f"端口 {port} 已被占用")
parts = [f"\n ✨✨✨\n AstrBot v{VERSION} WebUI 已启动,可访问\n\n"]
- parts.append(f" ➜ 本地: http://localhost:{port}\n")
+ parts.append(f" ➜ 本地: {scheme}://localhost:{port}\n")
for ip in ip_addr:
- parts.append(f" ➜ 网络: http://{ip}:{port}\n")
+ parts.append(f" ➜ 网络: {scheme}://{ip}:{port}\n")
parts.append(" ➜ 默认用户名和密码: astrbot\n ✨✨✨\n")
display = "".join(parts)
@@ -257,11 +278,45 @@ class AstrBotDashboard:
# 配置 Hypercorn
config = HyperConfig()
config.bind = [f"{host}:{port}"]
+ if ssl_enable:
+ cert_file = (
+ os.environ.get("DASHBOARD_SSL_CERT")
+ or os.environ.get("ASTRBOT_DASHBOARD_SSL_CERT")
+ or ssl_config.get("cert_file", "")
+ )
+ key_file = (
+ os.environ.get("DASHBOARD_SSL_KEY")
+ or os.environ.get("ASTRBOT_DASHBOARD_SSL_KEY")
+ or ssl_config.get("key_file", "")
+ )
+ ca_certs = (
+ os.environ.get("DASHBOARD_SSL_CA_CERTS")
+ or os.environ.get("ASTRBOT_DASHBOARD_SSL_CA_CERTS")
+ or ssl_config.get("ca_certs", "")
+ )
+
+ cert_path = Path(cert_file).expanduser()
+ key_path = Path(key_file).expanduser()
+ if not cert_file or not key_file:
+ raise ValueError(
+ "dashboard.ssl.enable 为 true 时,必须配置 cert_file 和 key_file。",
+ )
+ if not cert_path.is_file():
+ raise ValueError(f"SSL 证书文件不存在: {cert_path}")
+ if not key_path.is_file():
+ raise ValueError(f"SSL 私钥文件不存在: {key_path}")
+
+ config.certfile = str(cert_path.resolve())
+ config.keyfile = str(key_path.resolve())
+
+ if ca_certs:
+ ca_path = Path(ca_certs).expanduser()
+ if not ca_path.is_file():
+ raise ValueError(f"SSL CA 证书文件不存在: {ca_path}")
+ config.ca_certs = str(ca_path.resolve())
# 根据配置决定是否禁用访问日志
- disable_access_log = self.core_lifecycle.astrbot_config.get(
- "dashboard", {}
- ).get("disable_access_log", True)
+ disable_access_log = dashboard_config.get("disable_access_log", True)
if disable_access_log:
config.accesslog = None
else:
diff --git a/changelogs/v4.16.0.md b/changelogs/v4.16.0.md
new file mode 100644
index 000000000..fd7657865
--- /dev/null
+++ b/changelogs/v4.16.0.md
@@ -0,0 +1,62 @@
+## What's Changed
+
+### 新增
+- QQ 官方机器人平台支持主动推送消息,私聊场景支持接收文件 ([#5066](https://github.com/AstrBotDevs/AstrBot/issues/5066))
+- 为 Telegram 平台适配器新增等待 AI 回复时自动展示 “正在输入”、“正在上传图片” 等状态的功能 ([#5037](https://github.com/AstrBotDevs/AstrBot/issues/5037))
+- 为飞书适配器增加接收文件、读取引用消息的内容(包括引用的图片、视频、文件、文字等) ([#5018](https://github.com/AstrBotDevs/AstrBot/issues/5018))
+- 新增自定义平台适配器 i18n 支持 ([#5045](https://github.com/AstrBotDevs/AstrBot/issues/5045))
+- 新增临时文件处理能力,可在系统配置中限制 data/temp 目录的最大大小。 ([#5026](https://github.com/AstrBotDevs/AstrBot/issues/5026))
+- 增加首次启动公告功能,支持多语言与 WebUI 集成
+
+### 修复
+
+- 修复 OpenRouter DeepSeek 场景下的 chunk 错误 ([#5069](https://github.com/AstrBotDevs/AstrBot/issues/5069))
+- 修复备份时人格文件夹映射缺失问题 ([#5042](https://github.com/AstrBotDevs/AstrBot/issues/5042))
+- 修复更新日志与官方文档弹窗双滚动条问题 ([#5060](https://github.com/AstrBotDevs/AstrBot/issues/5060))
+- 修复 provider 额外参数弹窗 key 显示异常
+- 修复连接失败时错误日志提示不准确的问题
+- 修复提前返回时未等待 reset 协程导致的资源清理问题 ([#5033](https://github.com/AstrBotDevs/AstrBot/issues/5033))
+- 提升打包版桌面端启动稳定性并优化插件依赖处理 ([#5031](https://github.com/AstrBotDevs/AstrBot/issues/5031))
+- 为 Electron 与后端日志增加按大小轮转 ([#5029](https://github.com/AstrBotDevs/AstrBot/issues/5029))
+- 加固冻结运行时(frozen app runtime)插件依赖加载流程 ([#5015](https://github.com/AstrBotDevs/AstrBot/issues/5015))
+
+### 优化
+- 完善合并消息、引用解析与图片回退,并支持配置化控制 ([#5054](https://github.com/AstrBotDevs/AstrBot/issues/5054))
+- 配置页面支持通过侧边栏子项切换普通配置/系统配置,并补充相关路由修复
+- 优化分段回复间隔时间初始化逻辑 ([#5068](https://github.com/AstrBotDevs/AstrBot/issues/5068))
+
+### 文档与维护
+- 同步并修正 README 文档内容与拼写 ([#5055](https://github.com/AstrBotDevs/AstrBot/issues/5055), [#5014](https://github.com/AstrBotDevs/AstrBot/issues/5014))
+- 新增 AUR 安装方式说明 ([#4879](https://github.com/AstrBotDevs/AstrBot/issues/4879))
+- 执行代码格式化(ruff)
+
+## What's Changed (EN)
+
+### New Features
+- Added proactive message push and private-chat file receiving support for the QQ official bot adapter ([#5066](https://github.com/AstrBotDevs/AstrBot/issues/5066))
+- Added automatic "typing..." and "uploading image..." status display while waiting for AI response in the Telegram adapter ([#5037](https://github.com/AstrBotDevs/AstrBot/issues/5037))
+- Added file receiving and quoted message content reading (including quoted images, videos, files, text, etc.) for the Feishu adapter ([#5018](https://github.com/AstrBotDevs/AstrBot/issues/5018))
+- Added i18n support for custom platform adapters ([#5045](https://github.com/AstrBotDevs/AstrBot/issues/5045))
+- Introduced temporary file handling and `TempDirCleaner` ([#5026](https://github.com/AstrBotDevs/AstrBot/issues/5026))
+- Added a first-launch notice feature with multilingual content and WebUI integration
+
+### Fixes
+- Added sidebar child-tab switching for normal/system config and fixed related routing behavior on the config page
+- Fixed chunk errors when using OpenRouter DeepSeek ([#5069](https://github.com/AstrBotDevs/AstrBot/issues/5069))
+- Improved forwarded-quote parsing and image fallback with configurable controls ([#5054](https://github.com/AstrBotDevs/AstrBot/issues/5054))
+- Fixed missing persona-folder mapping in backup exports ([#5042](https://github.com/AstrBotDevs/AstrBot/issues/5042))
+- Fixed double scrollbar issue in changelog and official docs dialogs ([#5060](https://github.com/AstrBotDevs/AstrBot/issues/5060))
+- Fixed key rendering issues in the provider extra-params dialog
+- Improved error log wording for connection failures
+- Fixed unawaited reset coroutine cleanup on early returns ([#5033](https://github.com/AstrBotDevs/AstrBot/issues/5033))
+- Improved packaged desktop startup stability and plugin dependency handling ([#5031](https://github.com/AstrBotDevs/AstrBot/issues/5031))
+- Added size-based log rotation for Electron and backend logs ([#5029](https://github.com/AstrBotDevs/AstrBot/issues/5029))
+- Hardened plugin dependency loading in frozen app runtime ([#5015](https://github.com/AstrBotDevs/AstrBot/issues/5015))
+
+### Improvements
+- Optimized initialization logic for segmented-reply interval timing ([#5068](https://github.com/AstrBotDevs/AstrBot/issues/5068))
+
+### Docs & Maintenance
+- Synced and fixed README docs and typos ([#5055](https://github.com/AstrBotDevs/AstrBot/issues/5055), [#5014](https://github.com/AstrBotDevs/AstrBot/issues/5014))
+- Added AUR installation instructions ([#4879](https://github.com/AstrBotDevs/AstrBot/issues/4879))
+- Applied code formatting with ruff
diff --git a/changelogs/v4.17.0.md b/changelogs/v4.17.0.md
new file mode 100644
index 000000000..c8f2e93a1
--- /dev/null
+++ b/changelogs/v4.17.0.md
@@ -0,0 +1,29 @@
+## What's Changed
+
+### 新增
+- 新增 LINE 平台适配器与相关配置支持 ([#5085](https://github.com/AstrBotDevs/AstrBot/issues/5085))
+- 新增备用回退聊天模型列表,当主模型报错时自动切换到备用模型 ([#5109](https://github.com/AstrBotDevs/AstrBot/issues/5109))
+- 新增插件加载失败后的热重载支持,便于插件修复后快速恢复 ([#5043](https://github.com/AstrBotDevs/AstrBot/issues/5043))
+- WebUI 新增 SSL 配置选项并同步更新相关日志行为 ([#5117](https://github.com/AstrBotDevs/AstrBot/issues/5117))
+
+### 修复
+- 修复 Dockerfile 中依赖导出流程,增加 `uv lock` 步骤并移除不必要的 `--frozen` 参数,提升构建稳定性 ([#5091](https://github.com/AstrBotDevs/AstrBot/issues/5091), [#5089](https://github.com/AstrBotDevs/AstrBot/issues/5089))
+- 修复首次启动公告 `FIRST_NOTICE.md` 的本地化路径解析问题,补充兼容路径处理 ([#5083](https://github.com/AstrBotDevs/AstrBot/issues/5083), [#5082](https://github.com/AstrBotDevs/AstrBot/issues/5082))
+
+### 优化
+- 日志系统由 `colorlog` 切换为 `loguru`,增强日志输出与展示能力 ([#5115](https://github.com/AstrBotDevs/AstrBot/issues/5115))
+
+## What's Changed (EN)
+
+### New Features
+- Added LINE platform adapter support with related configuration options ([#5085](https://github.com/AstrBotDevs/AstrBot/issues/5085))
+- Added fallback chat model chain support in tool loop runner, with corresponding config and improved provider selection display ([#5109](https://github.com/AstrBotDevs/AstrBot/issues/5109))
+- Added hot reload support after plugin load failure for faster recovery during plugin development and maintenance ([#5043](https://github.com/AstrBotDevs/AstrBot/issues/5043))
+- Added SSL configuration options for WebUI and updated related logging behavior ([#5117](https://github.com/AstrBotDevs/AstrBot/issues/5117))
+
+### Fixes
+- Fixed Dockerfile dependency export flow by adding a `uv lock` step and removing unnecessary `--frozen` flag to improve build stability ([#5091](https://github.com/AstrBotDevs/AstrBot/issues/5091), [#5089](https://github.com/AstrBotDevs/AstrBot/issues/5089))
+- Fixed locale path resolution for `FIRST_NOTICE.md` and added compatible fallback handling ([#5083](https://github.com/AstrBotDevs/AstrBot/issues/5083), [#5082](https://github.com/AstrBotDevs/AstrBot/issues/5082))
+
+### Improvements
+- Replaced `colorlog` with `loguru` to improve logging capabilities and console display ([#5115](https://github.com/AstrBotDevs/AstrBot/issues/5115))
diff --git a/changelogs/v4.17.1.md b/changelogs/v4.17.1.md
new file mode 100644
index 000000000..13e7daad6
--- /dev/null
+++ b/changelogs/v4.17.1.md
@@ -0,0 +1,34 @@
+## What's Changed
+
+hotfix of 4.17.0
+
+- 修复:当开启了 “启用文件日志” 后,无法启动 AstrBot,报错 `ValueError: Invalid unit value while parsing duration: 'files'`。这是由于日志轮转设置中保留配置错误导致的,已通过根据备份数量正确设置保留参数进行修复。
+- fix: When "Enable file logging" is turned on, AstrBot fails to start with error `ValueError: Invalid unit value while parsing duration: 'files'`. This is due to an incorrect retention configuration in the log rotation setup, which has been fixed by properly setting the retention parameter based on backup count.
+
+### 新增
+- 新增 LINE 平台适配器与相关配置支持 ([#5085](https://github.com/AstrBotDevs/AstrBot/issues/5085))
+- 新增备用回退聊天模型列表,当主模型报错时自动切换到备用模型 ([#5109](https://github.com/AstrBotDevs/AstrBot/issues/5109))
+- 新增插件加载失败后的热重载支持,便于插件修复后快速恢复 ([#5043](https://github.com/AstrBotDevs/AstrBot/issues/5043))
+- WebUI 新增 SSL 配置选项并同步更新相关日志行为 ([#5117](https://github.com/AstrBotDevs/AstrBot/issues/5117))
+
+### 修复
+- 修复 Dockerfile 中依赖导出流程,增加 `uv lock` 步骤并移除不必要的 `--frozen` 参数,提升构建稳定性 ([#5091](https://github.com/AstrBotDevs/AstrBot/issues/5091), [#5089](https://github.com/AstrBotDevs/AstrBot/issues/5089))
+- 修复首次启动公告 `FIRST_NOTICE.md` 的本地化路径解析问题,补充兼容路径处理 ([#5083](https://github.com/AstrBotDevs/AstrBot/issues/5083), [#5082](https://github.com/AstrBotDevs/AstrBot/issues/5082))
+
+### 优化
+- 日志系统由 `colorlog` 切换为 `loguru`,增强日志输出与展示能力 ([#5115](https://github.com/AstrBotDevs/AstrBot/issues/5115))
+
+## What's Changed (EN)
+
+### New Features
+- Added LINE platform adapter support with related configuration options ([#5085](https://github.com/AstrBotDevs/AstrBot/issues/5085))
+- Added fallback chat model chain support in tool loop runner, with corresponding config and improved provider selection display ([#5109](https://github.com/AstrBotDevs/AstrBot/issues/5109))
+- Added hot reload support after plugin load failure for faster recovery during plugin development and maintenance ([#5043](https://github.com/AstrBotDevs/AstrBot/issues/5043))
+- Added SSL configuration options for WebUI and updated related logging behavior ([#5117](https://github.com/AstrBotDevs/AstrBot/issues/5117))
+
+### Fixes
+- Fixed Dockerfile dependency export flow by adding a `uv lock` step and removing unnecessary `--frozen` flag to improve build stability ([#5091](https://github.com/AstrBotDevs/AstrBot/issues/5091), [#5089](https://github.com/AstrBotDevs/AstrBot/issues/5089))
+- Fixed locale path resolution for `FIRST_NOTICE.md` and added compatible fallback handling ([#5083](https://github.com/AstrBotDevs/AstrBot/issues/5083), [#5082](https://github.com/AstrBotDevs/AstrBot/issues/5082))
+
+### Improvements
+- Replaced `colorlog` with `loguru` to improve logging capabilities and console display ([#5115](https://github.com/AstrBotDevs/AstrBot/issues/5115))
diff --git a/changelogs/v4.17.2.md b/changelogs/v4.17.2.md
new file mode 100644
index 000000000..e65972de3
--- /dev/null
+++ b/changelogs/v4.17.2.md
@@ -0,0 +1,8 @@
+## What's Changed
+
+hotfix of 4.17.0
+
+- 修复:MCP 服务器的 Tools 没有被正确添加到上下文中。
+- 修复:Electron 桌面应用部署时,系统自带插件未被正确加载的问题。
+- fix: Tools from MCP server were not properly added to context.
+- fix: built-in plugins were not properly loaded in Electron desktop application deployment.
diff --git a/changelogs/v4.17.3.md b/changelogs/v4.17.3.md
new file mode 100644
index 000000000..4b87b6243
--- /dev/null
+++ b/changelogs/v4.17.3.md
@@ -0,0 +1,27 @@
+## What's Changed
+
+### 修复
+- ‼️ 修复 Python 3.14 环境下 `'Plain' object has no attribute 'text'` 报错问题 ([#5154](https://github.com/AstrBotDevs/AstrBot/issues/5154))。
+- ‼️ 修复插件元数据处理流程:在实例化前注入必要属性,避免初始化阶段元数据缺失 ([#5155](https://github.com/AstrBotDevs/AstrBot/issues/5155))。
+- 修复桌面端后端构建中 AstrBot 内置插件运行时依赖未打包的问题 ([#5146](https://github.com/AstrBotDevs/AstrBot/issues/5146))。
+- 修复通过 AstrBot Launcher 启动时仍被检测并触发更新的问题。
+
+### 优化
+
+- Webchat 下,使用 `astrbot_execute_ipython` 工具如果返回了图片,会自动将图片发送到聊天中。
+
+### 其他
+- 执行 `ruff format` 代码格式整理。
+
+## What's Changed (EN)
+
+### Fixes
+- ‼️ Fixed plugin metadata handling by injecting required attributes before instantiation to avoid missing metadata during initialization ([#5155](https://github.com/AstrBotDevs/AstrBot/issues/5155)).
+- ‼️ Fixed `'Plain' object has no attribute 'text'` error when using Python 3.14 ([#5154](https://github.com/AstrBotDevs/AstrBot/issues/5154)).
+- Fixed missing runtime dependencies for built-in plugins in desktop backend builds ([#5146](https://github.com/AstrBotDevs/AstrBot/issues/5146)).
+- Fixed update checks being triggered when AstrBot is launched via AstrBot Launcher.
+
+### Improvements
+- In Webchat, when using the `astrbot_execute_ipython` tool, if an image is returned, it will automatically be sent to the chat.
+### Others
+- Applied `ruff format` code formatting.
diff --git a/dashboard/src/assets/images/platform_logos/line.png b/dashboard/src/assets/images/platform_logos/line.png
new file mode 100644
index 000000000..df80d27f1
Binary files /dev/null and b/dashboard/src/assets/images/platform_logos/line.png differ
diff --git a/dashboard/src/components/platform/AddNewPlatform.vue b/dashboard/src/components/platform/AddNewPlatform.vue
index 6595bf129..e771a56e5 100644
--- a/dashboard/src/components/platform/AddNewPlatform.vue
+++ b/dashboard/src/components/platform/AddNewPlatform.vue
@@ -522,7 +522,14 @@ export default {
}
},
methods: {
- getPlatformIcon,
+ getPlatformIcon(platformType) {
+ // Check for plugin-provided logo_token first
+ const template = this.platformTemplates?.[platformType];
+ if (template && template.logo_token) {
+ return `/api/file/${template.logo_token}`;
+ }
+ return getPlatformIcon(platformType);
+ },
getPlatformDescription,
resetForm() {
this.selectedPlatformType = null;
diff --git a/dashboard/src/components/provider/ProviderSourcesPanel.vue b/dashboard/src/components/provider/ProviderSourcesPanel.vue
index adca10832..8d230d4bc 100644
--- a/dashboard/src/components/provider/ProviderSourcesPanel.vue
+++ b/dashboard/src/components/provider/ProviderSourcesPanel.vue
@@ -4,7 +4,7 @@
{{ tm('providerSources.title') }}
-
+
-
-
- {{ sourceType.label }}
-
-
-
+
+
-
+
+
+
+
+
+
+
+
+ mdi-shape-outline
+
+
+
+
+
+
+
+
+
+
mdi-creation
- {{ getSourceDisplayName(source) }}
+ {{ getSourceDisplayName(source) }}
{{ source.api_base || 'N/A' }}
@@ -72,6 +114,8 @@
\ No newline at end of file
+
diff --git a/dashboard/src/components/shared/ProviderSelector.vue b/dashboard/src/components/shared/ProviderSelector.vue
index 3e32bc8e9..7ffbeb6ea 100644
--- a/dashboard/src/components/shared/ProviderSelector.vue
+++ b/dashboard/src/components/shared/ProviderSelector.vue
@@ -1,16 +1,35 @@
-
+
{{ tm('providerSelector.notSelected') }}
- {{ modelValue }}
+
+ {{ tm('providerSelector.selectedCount', { count: selectedProviders.length }) }}
+
+
+ {{ modelValue }}
+
{{ buttonText || tm('providerSelector.buttonText') }}
+
+
+ {{ providerId }}
+
+
+
@@ -32,10 +51,52 @@
+
+
+
+ {{ tm('providerSelector.selectedCount', { count: selectedProviders.length }) }}
+
+
+
+ {{ providerId }}
+
+
+
+
+
+
+
+
+
+
+
{{ provider.id }}
@@ -67,7 +128,7 @@
- mdi-check-circle
+ mdi-check-circle
@@ -121,7 +182,7 @@ import ProviderPage from '@/views/ProviderPage.vue'
const props = defineProps({
modelValue: {
- type: String,
+ type: [String, Array],
default: ''
},
providerType: {
@@ -135,6 +196,10 @@ const props = defineProps({
buttonText: {
type: String,
default: ''
+ },
+ multiple: {
+ type: Boolean,
+ default: false
}
})
@@ -145,8 +210,16 @@ const dialog = ref(false)
const providerList = ref([])
const loading = ref(false)
const selectedProvider = ref('')
+const selectedProviders = ref([])
const providerDrawer = ref(false)
+const hasSelection = computed(() => {
+ if (props.multiple) {
+ return selectedProviders.value.length > 0
+ }
+ return Boolean(props.modelValue)
+})
+
const defaultTab = computed(() => {
if (props.providerType === 'agent_runner' && props.providerSubtype) {
return `select_agent_runner_provider:${props.providerSubtype}`
@@ -156,7 +229,13 @@ const defaultTab = computed(() => {
// 监听 modelValue 变化,同步到 selectedProvider
watch(() => props.modelValue, (newValue) => {
- selectedProvider.value = newValue || ''
+ if (props.multiple) {
+ selectedProviders.value = Array.isArray(newValue)
+ ? [...newValue.filter((v) => typeof v === 'string' && v)]
+ : []
+ return
+ }
+ selectedProvider.value = typeof newValue === 'string' ? newValue : ''
}, { immediate: true })
watch(providerDrawer, (isOpen, wasOpen) => {
@@ -166,7 +245,13 @@ watch(providerDrawer, (isOpen, wasOpen) => {
})
async function openDialog() {
- selectedProvider.value = props.modelValue || ''
+ if (props.multiple) {
+ selectedProviders.value = Array.isArray(props.modelValue)
+ ? [...props.modelValue.filter((v) => typeof v === 'string' && v)]
+ : []
+ } else {
+ selectedProvider.value = typeof props.modelValue === 'string' ? props.modelValue : ''
+ }
dialog.value = true
await loadProviders()
}
@@ -205,19 +290,72 @@ function matchesProviderSubtype(provider, subtype) {
}
function selectProvider(provider) {
+ if (props.multiple) {
+ if (!provider.id) {
+ selectedProviders.value = []
+ return
+ }
+ const idx = selectedProviders.value.indexOf(provider.id)
+ if (idx >= 0) {
+ selectedProviders.value.splice(idx, 1)
+ } else {
+ selectedProviders.value.push(provider.id)
+ }
+ return
+ }
selectedProvider.value = provider.id
}
function confirmSelection() {
- emit('update:modelValue', selectedProvider.value)
+ if (props.multiple) {
+ emit('update:modelValue', [...selectedProviders.value])
+ } else {
+ emit('update:modelValue', selectedProvider.value)
+ }
dialog.value = false
}
function cancelSelection() {
- selectedProvider.value = props.modelValue || ''
+ if (props.multiple) {
+ selectedProviders.value = Array.isArray(props.modelValue)
+ ? [...props.modelValue.filter((v) => typeof v === 'string' && v)]
+ : []
+ } else {
+ selectedProvider.value = typeof props.modelValue === 'string' ? props.modelValue : ''
+ }
dialog.value = false
}
+function isProviderSelected(providerId) {
+ if (props.multiple) {
+ return selectedProviders.value.includes(providerId)
+ }
+ return selectedProvider.value === providerId
+}
+
+function removeSelected(providerId) {
+ const idx = selectedProviders.value.indexOf(providerId)
+ if (idx >= 0) {
+ selectedProviders.value.splice(idx, 1)
+ }
+}
+
+function moveSelected(index, delta) {
+ const targetIndex = index + delta
+ if (
+ targetIndex < 0
+ || targetIndex >= selectedProviders.value.length
+ || index < 0
+ || index >= selectedProviders.value.length
+ ) {
+ return
+ }
+ const copied = [...selectedProviders.value]
+ const [item] = copied.splice(index, 1)
+ copied.splice(targetIndex, 0, item)
+ selectedProviders.value = copied
+}
+
function openProviderDrawer() {
providerDrawer.value = true
}
@@ -236,6 +374,16 @@ function closeProviderDrawer() {
display: inline-block;
}
+.selected-preview {
+ width: 100%;
+ max-width: 100%;
+}
+
+.selected-order-list {
+ background: rgba(var(--v-theme-surface-variant), 0.15);
+ border-radius: 10px;
+}
+
.v-list-item {
transition: all 0.2s ease;
}
diff --git a/dashboard/src/components/shared/ReadmeDialog.vue b/dashboard/src/components/shared/ReadmeDialog.vue
index 04ab7afd1..ddc27cd90 100644
--- a/dashboard/src/components/shared/ReadmeDialog.vue
+++ b/dashboard/src/components/shared/ReadmeDialog.vue
@@ -35,7 +35,7 @@ const props = defineProps({
mode: {
type: String,
default: "readme",
- validator: (value) => ["readme", "changelog"].includes(value),
+ validator: (value) => ["readme", "changelog", "first-notice"].includes(value),
},
});
@@ -166,19 +166,50 @@ const renderedHtml = computed(() => {
});
const modeConfig = computed(() => {
- const isChangelog = props.mode === "changelog";
- const keyBase = `core.common.${isChangelog ? "changelog" : "readme"}`;
+ if (props.mode === "changelog") {
+ return {
+ title: t("core.common.changelog.title"),
+ loading: t("core.common.changelog.loading"),
+ emptyTitle: t("core.common.changelog.empty.title"),
+ emptySubtitle: t("core.common.changelog.empty.subtitle"),
+ apiPath: "/api/plugin/changelog",
+ showGithubButton: false,
+ showRefreshButton: true,
+ refreshLabel: t("core.common.readme.buttons.refresh"),
+ };
+ }
+
+ if (props.mode === "first-notice") {
+ return {
+ title: t("core.common.firstNotice.title"),
+ loading: t("core.common.firstNotice.loading"),
+ emptyTitle: t("core.common.firstNotice.empty.title"),
+ emptySubtitle: t("core.common.firstNotice.empty.subtitle"),
+ apiPath: "/api/stat/first-notice",
+ showGithubButton: false,
+ showRefreshButton: false,
+ refreshLabel: "",
+ };
+ }
+
return {
- title: t(`${keyBase}.title`),
- loading: t(`${keyBase}.loading`),
- emptyTitle: t(`${keyBase}.empty.title`),
- emptySubtitle: t(`${keyBase}.empty.subtitle`),
- apiPath: `/api/plugin/${isChangelog ? "changelog" : "readme"}`,
+ title: t("core.common.readme.title"),
+ loading: t("core.common.readme.loading"),
+ emptyTitle: t("core.common.readme.empty.title"),
+ emptySubtitle: t("core.common.readme.empty.subtitle"),
+ apiPath: "/api/plugin/readme",
+ showGithubButton: true,
+ showRefreshButton: true,
+ refreshLabel: t("core.common.readme.buttons.refresh"),
};
});
+const requiresPluginName = computed(
+ () => props.mode === "readme" || props.mode === "changelog",
+);
+
async function fetchContent() {
- if (!props.pluginName) return;
+ if (requiresPluginName.value && !props.pluginName) return;
const requestId = ++lastRequestId.value;
loading.value = true;
content.value = null;
@@ -186,9 +217,13 @@ async function fetchContent() {
isEmpty.value = false;
try {
- const res = await axios.get(
- `${modeConfig.value.apiPath}?name=${props.pluginName}`,
- );
+ let params;
+ if (requiresPluginName.value) {
+ params = { name: props.pluginName };
+ } else if (props.mode === "first-notice") {
+ params = { locale: locale.value };
+ }
+ const res = await axios.get(modeConfig.value.apiPath, { params });
if (requestId !== lastRequestId.value) return;
if (res.data.status === "ok") {
@@ -207,7 +242,9 @@ async function fetchContent() {
watch(
[() => props.show, () => props.pluginName, () => props.mode],
([show, name]) => {
- if (show && name) fetchContent();
+ if (!show) return;
+ if (requiresPluginName.value && !name) return;
+ fetchContent();
},
{ immediate: true },
);
@@ -273,22 +310,26 @@ function openExternalLink(url) {
if (!url) return;
window.open(url, "_blank", "noopener,noreferrer");
}
+
+const showActionArea = computed(() => {
+ const hasGithub = modeConfig.value.showGithubButton && !!props.repoUrl;
+ return hasGithub || modeConfig.value.showRefreshButton;
+});
- {{ modeConfig.title }}
+ {{ modeConfig.title }}
mdi-close
-
-
-
+
+
- {{ t("core.common.readme.buttons.refresh") }}
+ {{ modeConfig.refreshLabel }}
@@ -357,7 +399,6 @@ function openExternalLink(url) {
-
diff --git a/dashboard/src/components/shared/TemplateListEditor.vue b/dashboard/src/components/shared/TemplateListEditor.vue
index 796dae2dd..9cc49d9a9 100644
--- a/dashboard/src/components/shared/TemplateListEditor.vue
+++ b/dashboard/src/components/shared/TemplateListEditor.vue
@@ -19,8 +19,8 @@
:key="option.value"
@click="addEntry(option.value)"
>
- {{ option.label }}
- {{ option.hint }}
+ {{ translateIfKey(option.label) }}
+ {{ translateIfKey(option.hint) }}
@@ -58,7 +58,7 @@
{{ templateLabel(entry.__template_key) }}
- {{ getTemplate(entry)?.hint || getTemplate(entry)?.description }}
+ {{ translateIfKey(getTemplate(entry)?.hint || getTemplate(entry)?.description) }}
@@ -82,10 +82,10 @@
>
- {{ itemMeta?.description || itemKey }}
+ {{ translateIfKey(itemMeta?.description) || itemKey }}
- {{ itemMeta.hint }}
+ {{ translateIfKey(itemMeta.hint) }}
@@ -94,10 +94,10 @@
- {{ childMeta?.description || childKey }}
+ {{ translateIfKey(childMeta?.description) || childKey }}
- {{ childMeta?.hint }}
+ {{ translateIfKey(childMeta?.hint) }}
@@ -122,11 +122,11 @@
- {{ itemMeta?.description }} ({{ itemKey }})
+ {{ translateIfKey(itemMeta?.description) }} ({{ itemKey }})
{{ itemKey }}
- {{ itemMeta?.hint }}
+ {{ translateIfKey(itemMeta?.hint) }}
@@ -153,7 +153,7 @@
@@ -62,7 +99,6 @@ onMounted(() => {
-
{
/>
-
- {
-
-
+
+
diff --git a/dashboard/src/layouts/full/vertical-sidebar/VerticalSidebar.vue b/dashboard/src/layouts/full/vertical-sidebar/VerticalSidebar.vue
index 2520e8578..bb879d9c5 100644
--- a/dashboard/src/layouts/full/vertical-sidebar/VerticalSidebar.vue
+++ b/dashboard/src/layouts/full/vertical-sidebar/VerticalSidebar.vue
@@ -319,7 +319,7 @@ function openChangelogDialog() {
diff --git a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts
index 69a3791fa..1f358aa0f 100644
--- a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts
+++ b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts
@@ -36,7 +36,19 @@ const sidebarItem: menu[] = [
{
title: 'core.navigation.config',
icon: 'mdi-cog',
- to: '/config',
+ to: '/config#normal',
+ children: [
+ {
+ title: 'core.navigation.configTabs.normal',
+ icon: 'mdi-cog',
+ to: '/config#normal'
+ },
+ {
+ title: 'core.navigation.configTabs.system',
+ icon: 'mdi-cog-outline',
+ to: '/config#system'
+ }
+ ]
},
{
title: 'core.navigation.extension',
diff --git a/dashboard/src/main.ts b/dashboard/src/main.ts
index 451f1616b..687166654 100644
--- a/dashboard/src/main.ts
+++ b/dashboard/src/main.ts
@@ -84,6 +84,10 @@ axios.interceptors.request.use((config) => {
if (token) {
config.headers['Authorization'] = `Bearer ${token}`;
}
+ const locale = localStorage.getItem('astrbot-locale');
+ if (locale) {
+ config.headers['Accept-Language'] = locale;
+ }
return config;
});
@@ -98,6 +102,10 @@ window.fetch = (input: RequestInfo | URL, init?: RequestInit) => {
if (!headers.has('Authorization')) {
headers.set('Authorization', `Bearer ${token}`);
}
+ const locale = localStorage.getItem('astrbot-locale');
+ if (locale && !headers.has('Accept-Language')) {
+ headers.set('Accept-Language', locale);
+ }
return _origFetch(input, { ...init, headers });
};
diff --git a/dashboard/src/router/MainRoutes.ts b/dashboard/src/router/MainRoutes.ts
index ce0706498..024f591cc 100644
--- a/dashboard/src/router/MainRoutes.ts
+++ b/dashboard/src/router/MainRoutes.ts
@@ -41,6 +41,14 @@ const MainRoutes = {
path: '/config',
component: () => import('@/views/ConfigPage.vue')
},
+ {
+ path: '/normal',
+ redirect: '/config#normal'
+ },
+ {
+ path: '/system',
+ redirect: '/config#system'
+ },
{
name: 'Default',
path: '/dashboard/default',
diff --git a/dashboard/src/utils/platformUtils.js b/dashboard/src/utils/platformUtils.js
index 47f494193..fc56b022a 100644
--- a/dashboard/src/utils/platformUtils.js
+++ b/dashboard/src/utils/platformUtils.js
@@ -34,6 +34,8 @@ export function getPlatformIcon(name) {
return new URL('@/assets/images/platform_logos/satori.png', import.meta.url).href
} else if (name === 'misskey') {
return new URL('@/assets/images/platform_logos/misskey.png', import.meta.url).href
+ } else if (name === 'line') {
+ return new URL('@/assets/images/platform_logos/line.png', import.meta.url).href
}
}
diff --git a/dashboard/src/utils/providerUtils.js b/dashboard/src/utils/providerUtils.js
index 917807fc3..93a3ad547 100644
--- a/dashboard/src/utils/providerUtils.js
+++ b/dashboard/src/utils/providerUtils.js
@@ -18,6 +18,7 @@ export function getProviderIcon(type) {
'deepseek': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/deepseek.svg',
'modelscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/modelscope.svg',
'zhipu': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/zhipu.svg',
+ 'nvidia': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/nvidia-color.svg',
'siliconflow': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/siliconcloud.svg',
'moonshot': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/kimi.svg',
'ppio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ppio.svg',
diff --git a/dashboard/src/views/ConfigPage.vue b/dashboard/src/views/ConfigPage.vue
index 4c27081f3..617c6716c 100644
--- a/dashboard/src/views/ConfigPage.vue
+++ b/dashboard/src/views/ConfigPage.vue
@@ -4,29 +4,17 @@
-
+ style="margin-bottom: 16px; align-items: center; gap: 12px; width: 100%; justify-content: space-between;">
-
-
-
- {{ tm('configSelection.normalConfig') }}
-
-
- {{ tm('configSelection.systemConfig') }}
-
-
-
@@ -252,6 +240,9 @@ export default {
config_data_str(val) {
this.config_data_has_changed = true;
},
+ '$route.fullPath'(newVal) {
+ this.syncConfigTypeFromHash(newVal);
+ },
initialConfigId(newVal) {
if (!newVal) {
return;
@@ -299,12 +290,57 @@ export default {
}
},
mounted() {
+ const hashConfigType = this.extractConfigTypeFromHash(
+ this.$route?.fullPath || ''
+ );
+ this.configType = hashConfigType || 'normal';
+ this.isSystemConfig = this.configType === 'system';
+
const targetConfigId = this.initialConfigId || 'default';
this.getConfigInfoList(targetConfigId);
// 初始化配置类型状态
this.configType = this.isSystemConfig ? 'system' : 'normal';
+
+ // 监听语言切换事件,重新加载配置以获取插件的 i18n 数据
+ window.addEventListener('astrbot-locale-changed', this.handleLocaleChange);
+ },
+
+ beforeUnmount() {
+ // 移除语言切换事件监听器
+ window.removeEventListener('astrbot-locale-changed', this.handleLocaleChange);
},
methods: {
+ // 处理语言切换事件,重新加载配置以获取插件的 i18n 数据
+ handleLocaleChange() {
+ // 重新加载当前配置
+ if (this.selectedConfigID) {
+ this.getConfig(this.selectedConfigID);
+ } else if (this.isSystemConfig) {
+ this.getConfig();
+ }
+ },
+
+ },
+ methods: {
+ extractConfigTypeFromHash(hash) {
+ const rawHash = String(hash || '');
+ const lastHashIndex = rawHash.lastIndexOf('#');
+ if (lastHashIndex === -1) {
+ return null;
+ }
+ const cleanHash = rawHash.slice(lastHashIndex + 1);
+ return cleanHash === 'system' || cleanHash === 'normal' ? cleanHash : null;
+ },
+ syncConfigTypeFromHash(hash) {
+ const configType = this.extractConfigTypeFromHash(hash);
+ if (!configType || configType === this.configType) {
+ return false;
+ }
+
+ this.configType = configType;
+ this.onConfigTypeToggle();
+ return true;
+ },
getConfigInfoList(abconf_id) {
// 获取配置列表
axios.get('/api/config/abconfs').then((res) => {
@@ -550,19 +586,7 @@ export default {
// 保持向后兼容性,更新 configType
this.configType = this.isSystemConfig ? 'system' : 'normal';
- this.fetched = false; // 重置加载状态
-
- if (this.isSystemConfig) {
- // 切换到系统配置
- this.getConfig();
- } else {
- // 切换回普通配置,如果有选中的配置文件则加载,否则加载default
- if (this.selectedConfigID) {
- this.getConfig(this.selectedConfigID);
- } else {
- this.getConfigInfoList("default");
- }
- }
+ this.onConfigTypeToggle();
},
openTestChat() {
if (!this.selectedConfigID) {
diff --git a/dashboard/src/views/ExtensionPage.vue b/dashboard/src/views/ExtensionPage.vue
index c62637e3c..fb6055d6a 100644
--- a/dashboard/src/views/ExtensionPage.vue
+++ b/dashboard/src/views/ExtensionPage.vue
@@ -14,7 +14,7 @@ import { useCommonStore } from "@/stores/common";
import { useI18n, useModuleI18n } from "@/i18n/composables";
import defaultPluginIcon from "@/assets/images/plugin_icon.png";
-import { ref, computed, onMounted, reactive, watch } from "vue";
+import { ref, computed, onMounted, onUnmounted, reactive, watch } from "vue";
import { useRoute, useRouter } from "vue-router";
const commonStore = useCommonStore();
@@ -357,11 +357,17 @@ const onLoadingDialogResult = (statusCode, result, timeToClose = 2000) => {
setTimeout(resetLoadingDialog, timeToClose);
};
+const failedPluginsDict = ref({});
+
const getExtensions = async () => {
loading_.value = true;
try {
- const res = await axios.get("/api/plugin/get");
+ const res = await axios.get("/api/plugin/get");
Object.assign(extension_data, res.data);
+
+ const failRes = await axios.get("/api/plugin/source/get-failed-plugins");
+ failedPluginsDict.value = failRes.data.data || {};
+
checkUpdate();
} catch (err) {
toast(err, "error");
@@ -370,6 +376,36 @@ const getExtensions = async () => {
}
};
+const handleReloadAllFailed = async () => {
+ const dirNames = Object.keys(failedPluginsDict.value);
+ if (dirNames.length === 0) {
+ toast("没有需要重载的失败插件", "info");
+ return;
+ }
+
+ loading_.value = true;
+ try {
+ const promises = dirNames.map(dir =>
+ axios.post("/api/plugin/reload-failed", { dir_name: dir })
+ );
+ await Promise.all(promises);
+
+ toast("已尝试重载所有失败插件", "success");
+
+ // 清空 message 关闭对话框
+ extension_data.message = "";
+
+ // 刷新列表
+ await getExtensions();
+
+ } catch (e) {
+ console.error("重载失败:", e);
+ toast("批量重载过程中出现错误", "error");
+ } finally {
+ loading_.value = false;
+ }
+};
+
const checkUpdate = () => {
const onlinePluginsMap = new Map();
const onlinePluginsNameMap = new Map();
@@ -1054,6 +1090,22 @@ onMounted(async () => {
}
});
+// 处理语言切换事件,重新加载插件配置以获取插件的 i18n 数据
+const handleLocaleChange = () => {
+ // 如果配置对话框是打开的,重新加载当前插件的配置
+ if (configDialog.value && currentConfigPlugin.value) {
+ openExtensionConfig(currentConfigPlugin.value);
+ }
+};
+
+// 监听语言切换事件
+window.addEventListener("astrbot-locale-changed", handleLocaleChange);
+
+// 清理事件监听器
+onUnmounted(() => {
+ window.removeEventListener("astrbot-locale-changed", handleLocaleChange);
+});
+
// 搜索防抖处理
let searchDebounceTimer = null;
watch(marketSearch, (newVal) => {
@@ -1257,6 +1309,15 @@ watch(activeTab, (newTab) => {
+
+ 尝试一键重载修复
+
{
this.getPlatformStats();
}, 10000);
+
+ // 监听语言切换事件,重新加载配置以获取插件的 i18n 数据
+ window.addEventListener('astrbot-locale-changed', this.handleLocaleChange);
},
beforeUnmount() {
if (this.statsRefreshInterval) {
clearInterval(this.statsRefreshInterval);
}
+ // 移除语言切换事件监听器
+ window.removeEventListener('astrbot-locale-changed', this.handleLocaleChange);
},
methods: {
+ // 处理语言切换事件,重新加载配置以获取插件的 i18n 数据
+ handleLocaleChange() {
+ this.getConfig();
+ },
+
// 从工具函数导入
getPlatformIcon(platform_id) {
// 首先检查是否有来自插件的 logo_token
@@ -305,6 +315,12 @@ export default {
this.config_data = res.data.data.config;
this.fetched = true
this.metadata = res.data.data.metadata;
+
+ // 将插件平台适配器的 i18n 翻译注入到前端 i18n 系统中
+ const platformI18n = res.data.data.platform_i18n_translations;
+ if (platformI18n && typeof platformI18n === 'object') {
+ mergeDynamicTranslations('features.config-metadata', platformI18n);
+ }
}).catch((err) => {
this.showError(err);
});
diff --git a/dashboard/src/views/WelcomePage.vue b/dashboard/src/views/WelcomePage.vue
index eb7c80308..8d52131b1 100644
--- a/dashboard/src/views/WelcomePage.vue
+++ b/dashboard/src/views/WelcomePage.vue
@@ -70,7 +70,7 @@
{{ tm('resources.title') }}
-
+
@@ -84,7 +84,7 @@
-
+
@@ -98,6 +98,20 @@
+
+
+
+
+ mdi-hand-heart
+ {{ tm('resources.afdianTitle') }}
+
+
+ {{ tm('resources.afdianDesc') }}
+
+
+
+
diff --git a/desktop/README.md b/desktop/README.md
index b5698e3ee..48dcb341a 100644
--- a/desktop/README.md
+++ b/desktop/README.md
@@ -91,6 +91,15 @@ Runtime logs:
- Electron shell log: `~/.astrbot/logs/electron.log`
- Backend stdout/stderr log: `~/.astrbot/logs/backend.log`
+- Both files rotate by size by default: `20MB` per file, keep `3` backups.
+- Electron log rotation envs:
+ - `ASTRBOT_ELECTRON_LOG_MAX_MB`
+ - `ASTRBOT_ELECTRON_LOG_BACKUP_COUNT`
+- Backend log rotation envs:
+ - `ASTRBOT_BACKEND_LOG_MAX_MB`
+ - `ASTRBOT_BACKEND_LOG_BACKUP_COUNT`
+- Rotation debug logging:
+ - `ASTRBOT_LOG_ROTATION_DEBUG=1` (or `NODE_ENV=development`) to print filesystem errors from rotation operations.
- On backend startup failure, the app dialog also shows the backend reason and backend log path.
Timeout and loading controls:
diff --git a/desktop/lib/backend-manager.js b/desktop/lib/backend-manager.js
index 477995027..eb8958a4c 100644
--- a/desktop/lib/backend-manager.js
+++ b/desktop/lib/backend-manager.js
@@ -4,10 +4,21 @@ const fs = require('fs');
const os = require('os');
const path = require('path');
const { spawn, spawnSync } = require('child_process');
-const { delay, ensureDir, normalizeUrl, waitForProcessExit } = require('./common');
+const { BufferedRotatingLogger } = require('./buffered-rotating-logger');
+const {
+ delay,
+ ensureDir,
+ formatLogTimestamp,
+ normalizeUrl,
+ parseLogBackupCount,
+ parseLogMaxBytes,
+ waitForProcessExit,
+} = require('./common');
const PACKAGED_BACKEND_TIMEOUT_FALLBACK_MS = 5 * 60 * 1000;
const GRACEFUL_RESTART_WAIT_FALLBACK_MS = 20 * 1000;
+const BACKEND_LOG_FLUSH_INTERVAL_MS = 120;
+const BACKEND_LOG_MAX_BUFFER_BYTES = 128 * 1024;
function parseBackendTimeoutMs(app) {
const defaultTimeoutMs = app.isPackaged ? 0 : 20000;
@@ -34,10 +45,22 @@ class BackendManager {
);
this.backendAutoStart = process.env.ASTRBOT_BACKEND_AUTO_START !== '0';
this.backendTimeoutMs = parseBackendTimeoutMs(app);
+ this.backendLogMaxBytes = parseLogMaxBytes(
+ process.env.ASTRBOT_BACKEND_LOG_MAX_MB,
+ );
+ this.backendLogBackupCount = parseLogBackupCount(
+ process.env.ASTRBOT_BACKEND_LOG_BACKUP_COUNT,
+ );
this.backendProcess = null;
this.backendConfig = null;
- this.backendLogFd = null;
+ this.backendLogger = new BufferedRotatingLogger({
+ logPath: null,
+ maxBytes: this.backendLogMaxBytes,
+ backupCount: this.backendLogBackupCount,
+ flushIntervalMs: BACKEND_LOG_FLUSH_INTERVAL_MS,
+ maxBufferBytes: BACKEND_LOG_MAX_BUFFER_BYTES,
+ });
this.backendLastExitReason = null;
this.backendStartupFailureReason = null;
this.backendSpawning = false;
@@ -195,14 +218,8 @@ class BackendManager {
return Boolean(this.getBackendConfig().cmd);
}
- closeBackendLogFd() {
- if (this.backendLogFd === null) {
- return;
- }
- try {
- fs.closeSync(this.backendLogFd);
- } catch {}
- this.backendLogFd = null;
+ async flushLogs() {
+ await this.backendLogger.flush();
}
async pingBackend(timeoutMs = 800) {
@@ -355,7 +372,7 @@ class BackendManager {
}
}
- startBackend() {
+ async startBackend() {
if (this.shouldSkipStart()) {
this.log('Skip backend start because app is quitting.');
return;
@@ -375,65 +392,77 @@ class BackendManager {
};
if (this.app.isPackaged) {
env.ASTRBOT_ELECTRON_CLIENT = '1';
+ const hasExplicitDashboardHost = Boolean(
+ process.env.DASHBOARD_HOST || process.env.ASTRBOT_DASHBOARD_HOST,
+ );
+ const hasExplicitDashboardPort = Boolean(
+ process.env.DASHBOARD_PORT || process.env.ASTRBOT_DASHBOARD_PORT,
+ );
+ if (!hasExplicitDashboardHost) {
+ env.DASHBOARD_HOST = '127.0.0.1';
+ }
+ if (!hasExplicitDashboardPort) {
+ env.DASHBOARD_PORT = '6185';
+ }
}
if (backendConfig.webuiDir) {
env.ASTRBOT_WEBUI_DIR = backendConfig.webuiDir;
}
+ let backendLogPath = null;
if (backendConfig.rootDir) {
env.ASTRBOT_ROOT = backendConfig.rootDir;
const logsDir = path.join(backendConfig.rootDir, 'logs');
ensureDir(logsDir);
- const logPath = path.join(logsDir, 'backend.log');
- try {
- this.backendLogFd = fs.openSync(logPath, 'a');
- } catch {
- this.backendLogFd = null;
- }
+ backendLogPath = path.join(logsDir, 'backend.log');
}
+ await this.backendLogger.setLogPath(backendLogPath);
+ const usePipedLogging = Boolean(backendLogPath);
this.backendProcess = spawn(backendConfig.cmd, backendConfig.args || [], {
cwd: backendConfig.cwd,
env,
shell: backendConfig.shell,
- stdio:
- this.backendLogFd === null
- ? 'ignore'
- : ['ignore', this.backendLogFd, this.backendLogFd],
+ stdio: usePipedLogging ? ['ignore', 'pipe', 'pipe'] : 'ignore',
windowsHide: true,
});
- if (this.backendLogFd !== null) {
+ if (usePipedLogging) {
+ if (this.backendProcess.stdout) {
+ this.backendProcess.stdout.on('data', (chunk) => {
+ this.backendLogger.log(chunk);
+ });
+ }
+ if (this.backendProcess.stderr) {
+ this.backendProcess.stderr.on('data', (chunk) => {
+ this.backendLogger.log(chunk);
+ });
+ }
+ }
+
+ if (usePipedLogging) {
const launchLine = [backendConfig.cmd, ...(backendConfig.args || [])]
.map((item) => JSON.stringify(item))
.join(' ');
- try {
- fs.writeSync(
- this.backendLogFd,
- `[${new Date().toISOString()}] [Electron] Start backend ${launchLine}\n`,
- );
- } catch {}
+ this.backendLogger.log(
+ `[${formatLogTimestamp()}] [Electron] Start backend ${launchLine}\n`,
+ );
}
this.backendProcess.on('error', (error) => {
this.backendLastExitReason =
error instanceof Error ? error.message : String(error);
- if (this.backendLogFd !== null) {
- try {
- fs.writeSync(
- this.backendLogFd,
- `[${new Date().toISOString()}] [Electron] Backend spawn error: ${
- error instanceof Error ? error.message : String(error)
- }\n`,
- );
- } catch {}
- }
- this.closeBackendLogFd();
+ this.backendLogger.log(
+ `[${formatLogTimestamp()}] [Electron] Backend spawn error: ${
+ error instanceof Error ? error.message : String(error)
+ }\n`,
+ );
+ void this.backendLogger.flush();
this.backendProcess = null;
});
this.backendProcess.on('exit', (code, signal) => {
this.backendLastExitReason = `Backend process exited (code=${code ?? 'null'}, signal=${signal ?? 'null'}).`;
- this.closeBackendLogFd();
+ void this.backendLogger.flush();
this.backendProcess = null;
});
}
@@ -447,7 +476,7 @@ class BackendManager {
}
this.backendSpawning = true;
try {
- this.startBackend();
+ await this.startBackend();
return await this.waitForBackend(maxWaitMs, true);
} finally {
this.backendSpawning = false;
@@ -506,7 +535,7 @@ class BackendManager {
await waitForProcessExit(processToStop, 1500);
}
}
- this.closeBackendLogFd();
+ await this.backendLogger.flush();
}
findListeningPidsOnWindows(port) {
diff --git a/desktop/lib/buffered-rotating-logger.js b/desktop/lib/buffered-rotating-logger.js
new file mode 100644
index 000000000..7a443a97d
--- /dev/null
+++ b/desktop/lib/buffered-rotating-logger.js
@@ -0,0 +1,162 @@
+'use strict';
+
+const { RotatingLogWriter } = require('./rotating-log-writer');
+const { parseEnvInt } = require('./common');
+
+const DEFAULT_FLUSH_INTERVAL_MS = 120;
+const DEFAULT_MAX_BUFFER_BYTES = 128 * 1024;
+const MIN_FLUSH_INTERVAL_MS = 10;
+const MIN_MAX_BUFFER_BYTES = 4 * 1024;
+const MAX_MAX_BUFFER_BYTES = 16 * 1024 * 1024;
+
+function clampIntOption(raw, { defaultValue, min, max }) {
+ const value = parseEnvInt(raw, defaultValue);
+ return Math.min(Math.max(value, min), max);
+}
+
+class BufferedRotatingLogger {
+ constructor({
+ logPath = null,
+ maxBytes,
+ backupCount,
+ flushIntervalMs,
+ maxBufferBytes,
+ label = 'buffered-log',
+ }) {
+ this.logPath = logPath || null;
+ this.flushIntervalMs = clampIntOption(flushIntervalMs, {
+ defaultValue: DEFAULT_FLUSH_INTERVAL_MS,
+ min: MIN_FLUSH_INTERVAL_MS,
+ max: 60 * 1000,
+ });
+ this.maxBufferBytes = clampIntOption(maxBufferBytes, {
+ defaultValue: DEFAULT_MAX_BUFFER_BYTES,
+ min: MIN_MAX_BUFFER_BYTES,
+ max: MAX_MAX_BUFFER_BYTES,
+ });
+ this.buffer = [];
+ this.bufferBytes = 0;
+ this.flushTimer = null;
+ this.pathSwitch = Promise.resolve();
+ this.writer = new RotatingLogWriter({
+ logPath: this.logPath,
+ maxBytes,
+ backupCount,
+ label,
+ });
+ }
+
+ setLogPath(logPath) {
+ const nextLogPath = logPath || null;
+ this.pathSwitch = this.pathSwitch.then(async () => {
+ if (nextLogPath === this.logPath) {
+ await this.flush();
+ return;
+ }
+
+ const previousLogPath = this.logPath;
+ if (previousLogPath) {
+ await this.flush();
+ }
+
+ this.logPath = null;
+ await this.writer.setLogPath(nextLogPath);
+ this.logPath = nextLogPath;
+ await this.flush();
+ });
+ return this.pathSwitch;
+ }
+
+ log(payload) {
+ if (payload === undefined || payload === null) {
+ return;
+ }
+ const chunk = Buffer.isBuffer(payload)
+ ? payload
+ : Buffer.from(String(payload), 'utf8');
+ if (!chunk.length) {
+ return;
+ }
+
+ if (!this.logPath) {
+ const boundedChunk = this.clipChunkToBufferLimit(chunk);
+ this.dropOldestUntilWithinLimit(boundedChunk.length);
+ this.buffer.push(boundedChunk);
+ this.bufferBytes += boundedChunk.length;
+ return;
+ }
+
+ this.buffer.push(chunk);
+ this.bufferBytes += chunk.length;
+
+ if (this.bufferBytes >= this.maxBufferBytes) {
+ void this.flush();
+ return;
+ }
+ this.scheduleFlush();
+ }
+
+ flush() {
+ this.clearFlushTimer();
+ if (!this.buffer.length) {
+ return this.writer.flush();
+ }
+ if (!this.logPath) {
+ // Path is switching or temporarily unavailable; keep buffered data.
+ this.dropOldestUntilWithinLimit(0);
+ return this.writer.flush();
+ }
+
+ const chunks = this.buffer;
+ this.buffer = [];
+ this.bufferBytes = 0;
+ const payload = chunks.length === 1 ? chunks[0] : Buffer.concat(chunks);
+ this.writer.append(payload);
+ return this.writer.flush();
+ }
+
+ dropOldestUntilWithinLimit(incomingBytes = 0) {
+ while (
+ this.buffer.length &&
+ this.bufferBytes + Math.max(0, incomingBytes) > this.maxBufferBytes
+ ) {
+ const removed = this.buffer.shift();
+ if (removed) {
+ this.bufferBytes -= removed.length;
+ }
+ }
+ if (this.bufferBytes < 0) {
+ this.bufferBytes = 0;
+ }
+ }
+
+ clipChunkToBufferLimit(chunk) {
+ if (chunk.length <= this.maxBufferBytes) {
+ return chunk;
+ }
+ return chunk.subarray(chunk.length - this.maxBufferBytes);
+ }
+
+ scheduleFlush() {
+ if (this.flushTimer !== null) {
+ return;
+ }
+ this.flushTimer = setTimeout(() => {
+ this.flushTimer = null;
+ void this.flush();
+ }, this.flushIntervalMs);
+ this.flushTimer.unref?.();
+ }
+
+ clearFlushTimer() {
+ if (this.flushTimer === null) {
+ return;
+ }
+ clearTimeout(this.flushTimer);
+ this.flushTimer = null;
+ }
+}
+
+module.exports = {
+ BufferedRotatingLogger,
+};
diff --git a/desktop/lib/common.js b/desktop/lib/common.js
index 561eca39e..9f39358dc 100644
--- a/desktop/lib/common.js
+++ b/desktop/lib/common.js
@@ -2,6 +2,9 @@
const fs = require('fs');
+const LOG_ROTATION_DEFAULT_MAX_MB = 20;
+const LOG_ROTATION_DEFAULT_BACKUP_COUNT = 3;
+
function normalizeUrl(value) {
try {
const url = new URL(value);
@@ -24,6 +27,33 @@ function ensureDir(value) {
fs.mkdirSync(value, { recursive: true });
}
+function parseEnvInt(raw, defaultValue) {
+ const parsed = Number.parseInt(`${raw ?? ''}`, 10);
+ return Number.isFinite(parsed) ? parsed : defaultValue;
+}
+
+function isLogRotationDebugEnabled() {
+ return (
+ process.env.ASTRBOT_LOG_ROTATION_DEBUG === '1' ||
+ process.env.NODE_ENV === 'development'
+ );
+}
+
+function parseLogMaxBytes(envValue) {
+ const mb = parseEnvInt(envValue, LOG_ROTATION_DEFAULT_MAX_MB);
+ const maxMb = mb > 0 ? mb : LOG_ROTATION_DEFAULT_MAX_MB;
+ return maxMb * 1024 * 1024;
+}
+
+function parseLogBackupCount(envValue) {
+ const count = parseEnvInt(envValue, LOG_ROTATION_DEFAULT_BACKUP_COUNT);
+ return count >= 0 ? count : LOG_ROTATION_DEFAULT_BACKUP_COUNT;
+}
+
+function isIgnorableFsError(error) {
+ return Boolean(error && typeof error === 'object' && error.code === 'ENOENT');
+}
+
function delay(ms) {
return new Promise((resolve) => setTimeout(resolve, ms));
}
@@ -51,9 +81,35 @@ function waitForProcessExit(child, timeoutMs = 5000) {
});
}
+function formatLogTimestamp(date = new Date()) {
+ const year = date.getFullYear();
+ const month = `${date.getMonth() + 1}`.padStart(2, '0');
+ const day = `${date.getDate()}`.padStart(2, '0');
+ const hour = `${date.getHours()}`.padStart(2, '0');
+ const minute = `${date.getMinutes()}`.padStart(2, '0');
+ const second = `${date.getSeconds()}`.padStart(2, '0');
+ const millisecond = `${date.getMilliseconds()}`.padStart(3, '0');
+
+ const offsetMinutes = -date.getTimezoneOffset();
+ const offsetSign = offsetMinutes >= 0 ? '+' : '-';
+ const absOffsetMinutes = Math.abs(offsetMinutes);
+ const offsetHour = `${Math.floor(absOffsetMinutes / 60)}`.padStart(2, '0');
+ const offsetMinute = `${absOffsetMinutes % 60}`.padStart(2, '0');
+
+ return `${year}-${month}-${day} ${hour}:${minute}:${second}.${millisecond} ${offsetSign}${offsetHour}${offsetMinute}`;
+}
+
module.exports = {
+ LOG_ROTATION_DEFAULT_BACKUP_COUNT,
+ LOG_ROTATION_DEFAULT_MAX_MB,
delay,
ensureDir,
+ formatLogTimestamp,
+ isIgnorableFsError,
+ isLogRotationDebugEnabled,
normalizeUrl,
+ parseEnvInt,
+ parseLogBackupCount,
+ parseLogMaxBytes,
waitForProcessExit,
};
diff --git a/desktop/lib/electron-logger.js b/desktop/lib/electron-logger.js
index b8dc73bc6..6a52d1c76 100644
--- a/desktop/lib/electron-logger.js
+++ b/desktop/lib/electron-logger.js
@@ -1,10 +1,27 @@
'use strict';
-const fs = require('fs');
const path = require('path');
-const { ensureDir } = require('./common');
+const { RotatingLogWriter } = require('./rotating-log-writer');
+const {
+ formatLogTimestamp,
+ parseLogBackupCount,
+ parseLogMaxBytes,
+} = require('./common');
function createElectronLogger({ app, getRootDir }) {
+ const electronLogMaxBytes = parseLogMaxBytes(
+ process.env.ASTRBOT_ELECTRON_LOG_MAX_MB,
+ );
+ const electronLogBackupCount = parseLogBackupCount(
+ process.env.ASTRBOT_ELECTRON_LOG_BACKUP_COUNT,
+ );
+ const writer = new RotatingLogWriter({
+ logPath: null,
+ maxBytes: electronLogMaxBytes,
+ backupCount: electronLogBackupCount,
+ label: 'electron-log',
+ });
+
function getElectronLogPath() {
const rootDir =
process.env.ASTRBOT_ROOT ||
@@ -15,16 +32,19 @@ function createElectronLogger({ app, getRootDir }) {
function logElectron(message) {
const logPath = getElectronLogPath();
- ensureDir(path.dirname(logPath));
- const line = `[${new Date().toISOString()}] ${message}\n`;
- try {
- fs.appendFileSync(logPath, line, 'utf8');
- } catch {}
+ const line = `[${formatLogTimestamp()}] ${message}\n`;
+ void writer.setLogPath(logPath);
+ void writer.append(line);
+ }
+
+ async function flushElectron() {
+ await writer.flush();
}
return {
getElectronLogPath,
logElectron,
+ flushElectron,
};
}
diff --git a/desktop/lib/rotating-log-writer.js b/desktop/lib/rotating-log-writer.js
new file mode 100644
index 000000000..c6c8f8fb1
--- /dev/null
+++ b/desktop/lib/rotating-log-writer.js
@@ -0,0 +1,178 @@
+'use strict';
+
+const fs = require('fs/promises');
+const path = require('path');
+const { isIgnorableFsError, isLogRotationDebugEnabled } = require('./common');
+
+class RotatingLogWriter {
+ constructor({ logPath = null, maxBytes = 0, backupCount = 0, label = 'log' }) {
+ this.logPath = logPath || null;
+ this.maxBytes = Number.isFinite(maxBytes) && maxBytes > 0 ? maxBytes : 0;
+ this.backupCount = Number.isFinite(backupCount) && backupCount >= 0 ? backupCount : 0;
+ this.label = label;
+ this.cachedSize = null;
+ this.dirReadyForPath = null;
+ this.queue = Promise.resolve();
+ }
+
+ setLogPath(logPath) {
+ const nextPath = logPath || null;
+ if (nextPath === this.logPath) {
+ return this.queue;
+ }
+ return this.enqueue(async () => {
+ this.logPath = nextPath;
+ this.cachedSize = null;
+ this.dirReadyForPath = null;
+ });
+ }
+
+ append(payload) {
+ if (payload === undefined || payload === null) {
+ return this.queue;
+ }
+ const content = Buffer.isBuffer(payload)
+ ? payload
+ : Buffer.from(String(payload), 'utf8');
+ if (!content.length) {
+ return this.queue;
+ }
+ return this.enqueue(async () => {
+ if (!this.logPath) {
+ return;
+ }
+ await this.ensureDirReady();
+ await this.ensureSizeLoaded();
+ await this.rotateIfNeeded(content.length);
+ await fs.appendFile(this.logPath, content);
+ if (!Number.isFinite(this.cachedSize)) {
+ this.cachedSize = await this.readSize();
+ } else {
+ this.cachedSize += content.length;
+ }
+ });
+ }
+
+ flush() {
+ return this.queue;
+ }
+
+ enqueue(task) {
+ const run = async () => {
+ try {
+ await task();
+ } catch (error) {
+ this.reportError('write', this.logPath || 'unknown', error);
+ }
+ };
+ this.queue = this.queue.then(run, run);
+ return this.queue;
+ }
+
+ async ensureSizeLoaded() {
+ if (Number.isFinite(this.cachedSize)) {
+ return;
+ }
+ this.cachedSize = await this.readSize();
+ }
+
+ async ensureDirReady() {
+ if (!this.logPath) {
+ return;
+ }
+ if (this.dirReadyForPath === this.logPath) {
+ return;
+ }
+ const dirPath = path.dirname(this.logPath);
+ try {
+ await fs.mkdir(dirPath, { recursive: true });
+ this.dirReadyForPath = this.logPath;
+ } catch (error) {
+ this.reportError('mkdir', dirPath, error);
+ }
+ }
+
+ async readSize() {
+ if (!this.logPath) {
+ return 0;
+ }
+ try {
+ const stat = await fs.stat(this.logPath);
+ return stat.size;
+ } catch (error) {
+ if (isIgnorableFsError(error)) {
+ return 0;
+ }
+ this.reportError('stat', this.logPath, error);
+ return 0;
+ }
+ }
+
+ async rotateIfNeeded(incomingBytes) {
+ if (!this.logPath || this.maxBytes <= 0) {
+ return;
+ }
+
+ const currentSize = Number.isFinite(this.cachedSize) ? this.cachedSize : 0;
+ if (currentSize + Math.max(0, incomingBytes) <= this.maxBytes) {
+ return;
+ }
+
+ if (this.backupCount <= 0) {
+ try {
+ await fs.truncate(this.logPath, 0);
+ } catch (error) {
+ if (!isIgnorableFsError(error)) {
+ this.reportError('truncate', this.logPath, error);
+ }
+ }
+ this.cachedSize = await this.readSize();
+ return;
+ }
+
+ const oldestPath = `${this.logPath}.${this.backupCount}`;
+ try {
+ await fs.unlink(oldestPath);
+ } catch (error) {
+ if (!isIgnorableFsError(error)) {
+ this.reportError('unlink', oldestPath, error);
+ }
+ }
+
+ for (let index = this.backupCount - 1; index >= 1; index -= 1) {
+ const sourcePath = `${this.logPath}.${index}`;
+ const targetPath = `${this.logPath}.${index + 1}`;
+ try {
+ await fs.rename(sourcePath, targetPath);
+ } catch (error) {
+ if (!isIgnorableFsError(error)) {
+ this.reportError('rename', `${sourcePath} -> ${targetPath}`, error);
+ }
+ }
+ }
+
+ try {
+ await fs.rename(this.logPath, `${this.logPath}.1`);
+ } catch (error) {
+ if (!isIgnorableFsError(error)) {
+ this.reportError('rename', `${this.logPath} -> ${this.logPath}.1`, error);
+ }
+ }
+
+ this.cachedSize = await this.readSize();
+ }
+
+ reportError(action, targetPath, error) {
+ if (!isLogRotationDebugEnabled()) {
+ return;
+ }
+ const details = error instanceof Error ? error.message : String(error);
+ console.error(
+ `[astrbot][${this.label}] ${action} failed for ${targetPath}: ${details}`,
+ );
+ }
+}
+
+module.exports = {
+ RotatingLogWriter,
+};
diff --git a/desktop/main.js b/desktop/main.js
index 6118c4360..5adff38b3 100644
--- a/desktop/main.js
+++ b/desktop/main.js
@@ -36,7 +36,7 @@ let backendManager = null;
app.commandLine.appendSwitch('disable-http-cache');
-const { logElectron } = createElectronLogger({
+const { logElectron, flushElectron } = createElectronLogger({
app,
getRootDir: () => (backendManager ? backendManager.getRootDir() : null),
});
@@ -387,8 +387,12 @@ app.on('before-quit', (event) => {
}
}),
)
- .finally(() => {
+ .finally(async () => {
logElectron('Backend stop finished, exiting app.');
+ await Promise.allSettled([
+ flushElectron(),
+ backendManager ? backendManager.flushLogs() : Promise.resolve(),
+ ]);
app.exit(0);
});
});
diff --git a/desktop/package.json b/desktop/package.json
index c70689f1e..3e04f21b1 100644
--- a/desktop/package.json
+++ b/desktop/package.json
@@ -1,6 +1,6 @@
{
"name": "astrbot-desktop",
- "version": "4.15.0",
+ "version": "4.17.3",
"description": "AstrBot desktop wrapper",
"private": true,
"main": "main.js",
@@ -22,7 +22,7 @@
"dist": "pnpm run sync:version && electron-builder"
},
"devDependencies": {
- "electron": "^30.0.0",
+ "electron": "^40.3.0",
"electron-builder": "^24.13.0"
},
"build": {
diff --git a/desktop/pnpm-lock.yaml b/desktop/pnpm-lock.yaml
index f91a21a86..98411a90e 100644
--- a/desktop/pnpm-lock.yaml
+++ b/desktop/pnpm-lock.yaml
@@ -9,8 +9,8 @@ importers:
.:
devDependencies:
electron:
- specifier: ^30.0.0
- version: 30.5.1
+ specifier: ^40.3.0
+ version: 40.3.0
electron-builder:
specifier: ^24.13.0
version: 24.13.3(electron-builder-squirrel-windows@24.13.3)
@@ -92,8 +92,8 @@ packages:
'@types/ms@2.1.0':
resolution: {integrity: sha512-GsCCIZDE/p3i96vtEqx+7dBUGXrc7zeSK3wwPHIaRThS+9OhWIXRqzs4d6k1SVU8g91DrNRWxWUGhp5KXQb2VA==}
- '@types/node@20.19.33':
- resolution: {integrity: sha512-Rs1bVAIdBs5gbTIKza/tgpMuG1k3U/UMJLWecIMxNdJFDMzcM5LOiLVRYh3PilWEYDIeUDv7bpiHPLPsbydGcw==}
+ '@types/node@24.10.13':
+ resolution: {integrity: sha512-oH72nZRfDv9lADUBSo104Aq7gPHpQZc4BTx38r9xf9pg5LfP6EzSyH2n7qFmmxRQXh7YlUXODcYsg6PuTDSxGg==}
'@types/node@25.2.2':
resolution: {integrity: sha512-BkmoP5/FhRYek5izySdkOneRyXYN35I860MFAGupTdebyE66uZaR+bXLHq8k4DirE5DwQi3NuhvRU1jqTVwUrQ==}
@@ -397,8 +397,8 @@ packages:
electron-publish@24.13.1:
resolution: {integrity: sha512-2ZgdEqJ8e9D17Hwp5LEq5mLQPjqU3lv/IALvgp+4W8VeNhryfGhYEQC/PgDPMrnWUp+l60Ou5SJLsu+k4mhQ8A==}
- electron@30.5.1:
- resolution: {integrity: sha512-AhL7+mZ8Lg14iaNfoYTkXQ2qee8mmsQyllKdqxlpv/zrKgfxz6jNVtcRRbQtLxtF8yzcImWdfTQROpYiPumdbw==}
+ electron@40.3.0:
+ resolution: {integrity: sha512-ZaDkTZpNHr863tyZHieoqbaiLI0e3RVCXoEC5y1Ld70/Q5H1mPV9d5TK0h1dWtaSFVOW0w8iDvtdLwAXtasXpg==}
engines: {node: '>= 12.20.55'}
hasBin: true
@@ -992,9 +992,6 @@ packages:
engines: {node: '>=14.17'}
hasBin: true
- undici-types@6.21.0:
- resolution: {integrity: sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==}
-
undici-types@7.16.0:
resolution: {integrity: sha512-Zz+aZWSj8LE6zoxD+xrjh4VfkIG8Ya6LvYkZqtUQGJPZjYl53ypCaUwWqo7eI0x66KBGeRo+mlBEkMSeSZ38Nw==}
@@ -1158,7 +1155,7 @@ snapshots:
dependencies:
'@types/http-cache-semantics': 4.2.0
'@types/keyv': 3.1.4
- '@types/node': 20.19.33
+ '@types/node': 25.2.2
'@types/responselike': 1.0.3
'@types/debug@4.1.12':
@@ -1173,13 +1170,13 @@ snapshots:
'@types/keyv@3.1.4':
dependencies:
- '@types/node': 20.19.33
+ '@types/node': 25.2.2
'@types/ms@2.1.0': {}
- '@types/node@20.19.33':
+ '@types/node@24.10.13':
dependencies:
- undici-types: 6.21.0
+ undici-types: 7.16.0
'@types/node@25.2.2':
dependencies:
@@ -1193,14 +1190,14 @@ snapshots:
'@types/responselike@1.0.3':
dependencies:
- '@types/node': 20.19.33
+ '@types/node': 25.2.2
'@types/verror@1.10.11':
optional: true
'@types/yauzl@2.10.3':
dependencies:
- '@types/node': 20.19.33
+ '@types/node': 25.2.2
optional: true
'@xmldom/xmldom@0.8.11': {}
@@ -1597,10 +1594,10 @@ snapshots:
transitivePeerDependencies:
- supports-color
- electron@30.5.1:
+ electron@40.3.0:
dependencies:
'@electron/get': 2.0.3
- '@types/node': 20.19.33
+ '@types/node': 24.10.13
extract-zip: 2.0.1
transitivePeerDependencies:
- supports-color
@@ -2211,8 +2208,6 @@ snapshots:
typescript@5.9.3: {}
- undici-types@6.21.0: {}
-
undici-types@7.16.0: {}
universalify@0.1.2: {}
diff --git a/desktop/scripts/build-backend.mjs b/desktop/scripts/build-backend.mjs
index e88297c0e..921cf19cb 100644
--- a/desktop/scripts/build-backend.mjs
+++ b/desktop/scripts/build-backend.mjs
@@ -16,6 +16,8 @@ const kbStopwordsSrc = path.join(
'hit_stopwords.txt',
);
const kbStopwordsDest = 'astrbot/core/knowledge_base/retrieval';
+const builtinStarsSrc = path.join(rootDir, 'astrbot', 'builtin_stars');
+const builtinStarsDest = 'astrbot/builtin_stars';
const args = [
'run',
@@ -33,11 +35,25 @@ const args = [
'aiosqlite',
'--collect-all',
'pip',
+ '--collect-all',
+ 'bs4',
+ '--collect-all',
+ 'readability',
+ '--collect-all',
+ 'lxml',
+ '--collect-all',
+ 'lxml_html_clean',
+ '--collect-all',
+ 'rfc3987_syntax',
'--collect-submodules',
'astrbot.api',
+ '--collect-submodules',
+ 'astrbot.builtin_stars',
'--collect-data',
'certifi',
'--add-data',
+ `${builtinStarsSrc}${dataSeparator}${builtinStarsDest}`,
+ '--add-data',
`${kbStopwordsSrc}${dataSeparator}${kbStopwordsDest}`,
'--distpath',
outputDir,
diff --git a/pyproject.toml b/pyproject.toml
index 77d1c110c..f26a2b349 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "AstrBot"
-version = "4.15.0"
+version = "4.17.3"
description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md"
requires-python = ">=3.12"
@@ -17,7 +17,7 @@ dependencies = [
"beautifulsoup4>=4.13.4",
"certifi>=2025.4.26",
"chardet~=5.1.0",
- "colorlog>=6.9.0",
+ "loguru>=0.7.2",
"cryptography>=44.0.3",
"dashscope>=1.23.2",
"defusedxml>=0.7.1",
diff --git a/requirements.txt b/requirements.txt
index 0965a91d8..779136e88 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,7 +10,7 @@ apscheduler>=3.11.0
beautifulsoup4>=4.13.4
certifi>=2025.4.26
chardet~=5.1.0
-colorlog>=6.9.0
+loguru>=0.7.2
cryptography>=44.0.3
dashscope>=1.23.2
defusedxml>=0.7.1
@@ -29,7 +29,7 @@ pillow>=11.2.1
pip>=25.1.1
psutil>=5.8.0,<7.2.0
py-cord>=2.6.1
-pydantic~=2.10.3
+pydantic>=2.12.5
pydub>=0.25.1
pyjwt>=2.10.1
python-telegram-bot>=22.0
@@ -54,4 +54,3 @@ markitdown-no-magika[docx,xls,xlsx]>=0.1.2
xinference-client
tenacity>=9.1.2
shipyard-python-sdk>=0.2.4
-shipyard-neo-sdk @ git+https://github.com/AstrBotDevs/shipyard-neo.git#subdirectory=shipyard-neo-sdk
diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py
new file mode 100644
index 000000000..3172097c7
--- /dev/null
+++ b/tests/test_openai_source.py
@@ -0,0 +1,382 @@
+from types import SimpleNamespace
+
+import pytest
+
+from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial
+
+
+class _ErrorWithBody(Exception):
+ def __init__(self, message: str, body: dict):
+ super().__init__(message)
+ self.body = body
+
+
+class _ErrorWithResponse(Exception):
+ def __init__(self, message: str, response_text: str):
+ super().__init__(message)
+ self.response = SimpleNamespace(text=response_text)
+
+
+def _make_provider(overrides: dict | None = None) -> ProviderOpenAIOfficial:
+ provider_config = {
+ "id": "test-openai",
+ "type": "openai_chat_completion",
+ "model": "gpt-4o-mini",
+ "key": ["test-key"],
+ }
+ if overrides:
+ provider_config.update(overrides)
+ return ProviderOpenAIOfficial(
+ provider_config=provider_config,
+ provider_settings={},
+ )
+
+
+@pytest.mark.asyncio
+async def test_handle_api_error_content_moderated_removes_images():
+ provider = _make_provider(
+ {"image_moderation_error_patterns": ["file:content-moderated"]}
+ )
+ try:
+ payloads = {
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "hello"},
+ {
+ "type": "image_url",
+ "image_url": {"url": "data:image/jpeg;base64,abcd"},
+ },
+ ],
+ }
+ ]
+ }
+ context_query = payloads["messages"]
+
+ success, *_rest = await provider._handle_api_error(
+ Exception("Content is moderated [WKE=file:content-moderated]"),
+ payloads=payloads,
+ context_query=context_query,
+ func_tool=None,
+ chosen_key="test-key",
+ available_api_keys=["test-key"],
+ retry_cnt=0,
+ max_retries=10,
+ )
+
+ assert success is False
+ updated_context = payloads["messages"]
+ assert isinstance(updated_context, list)
+ assert updated_context[0]["content"] == [{"type": "text", "text": "hello"}]
+ finally:
+ await provider.terminate()
+
+
+@pytest.mark.asyncio
+async def test_handle_api_error_model_not_vlm_removes_images_and_retries_text_only():
+ provider = _make_provider()
+ try:
+ payloads = {
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "hello"},
+ {
+ "type": "image_url",
+ "image_url": {"url": "data:image/jpeg;base64,abcd"},
+ },
+ ],
+ }
+ ]
+ }
+ context_query = payloads["messages"]
+
+ success, *_rest = await provider._handle_api_error(
+ Exception("The model is not a VLM and cannot process images"),
+ payloads=payloads,
+ context_query=context_query,
+ func_tool=None,
+ chosen_key="test-key",
+ available_api_keys=["test-key"],
+ retry_cnt=0,
+ max_retries=10,
+ )
+
+ assert success is False
+ updated_context = payloads["messages"]
+ assert isinstance(updated_context, list)
+ assert updated_context[0]["content"] == [{"type": "text", "text": "hello"}]
+ finally:
+ await provider.terminate()
+
+
+@pytest.mark.asyncio
+async def test_handle_api_error_model_not_vlm_after_fallback_raises():
+ provider = _make_provider()
+ try:
+ payloads = {
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "hello"},
+ {
+ "type": "image_url",
+ "image_url": {"url": "data:image/jpeg;base64,abcd"},
+ },
+ ],
+ }
+ ]
+ }
+ context_query = payloads["messages"]
+
+ with pytest.raises(Exception, match="not a VLM"):
+ await provider._handle_api_error(
+ Exception("The model is not a VLM and cannot process images"),
+ payloads=payloads,
+ context_query=context_query,
+ func_tool=None,
+ chosen_key="test-key",
+ available_api_keys=["test-key"],
+ retry_cnt=1,
+ max_retries=10,
+ image_fallback_used=True,
+ )
+ finally:
+ await provider.terminate()
+
+
+@pytest.mark.asyncio
+async def test_handle_api_error_content_moderated_with_unserializable_body():
+ provider = _make_provider({"image_moderation_error_patterns": ["blocked"]})
+ try:
+ payloads = {
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "hello"},
+ {
+ "type": "image_url",
+ "image_url": {"url": "data:image/jpeg;base64,abcd"},
+ },
+ ],
+ }
+ ]
+ }
+ context_query = payloads["messages"]
+ err = _ErrorWithBody(
+ "upstream error",
+ {"error": {"message": "blocked"}, "raw": object()},
+ )
+
+ success, *_rest = await provider._handle_api_error(
+ err,
+ payloads=payloads,
+ context_query=context_query,
+ func_tool=None,
+ chosen_key="test-key",
+ available_api_keys=["test-key"],
+ retry_cnt=0,
+ max_retries=10,
+ )
+ assert success is False
+ assert payloads["messages"][0]["content"] == [{"type": "text", "text": "hello"}]
+ finally:
+ await provider.terminate()
+
+
+def test_extract_error_text_candidates_truncates_long_response_text():
+ long_text = "x" * 20000
+ err = _ErrorWithResponse("upstream error", long_text)
+ candidates = ProviderOpenAIOfficial._extract_error_text_candidates(err)
+ assert candidates
+ assert max(len(candidate) for candidate in candidates) <= (
+ ProviderOpenAIOfficial._ERROR_TEXT_CANDIDATE_MAX_CHARS
+ )
+
+
+@pytest.mark.asyncio
+async def test_handle_api_error_content_moderated_without_images_raises():
+ provider = _make_provider(
+ {"image_moderation_error_patterns": ["file:content-moderated"]}
+ )
+ try:
+ payloads = {
+ "messages": [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": "hello"}],
+ }
+ ]
+ }
+ context_query = payloads["messages"]
+ err = Exception("Content is moderated [WKE=file:content-moderated]")
+
+ with pytest.raises(Exception, match="content-moderated"):
+ await provider._handle_api_error(
+ err,
+ payloads=payloads,
+ context_query=context_query,
+ func_tool=None,
+ chosen_key="test-key",
+ available_api_keys=["test-key"],
+ retry_cnt=0,
+ max_retries=10,
+ )
+ finally:
+ await provider.terminate()
+
+
+@pytest.mark.asyncio
+async def test_handle_api_error_content_moderated_detects_structured_body():
+ provider = _make_provider(
+ {"image_moderation_error_patterns": ["content_moderated"]}
+ )
+ try:
+ payloads = {
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "hello"},
+ {
+ "type": "image_url",
+ "image_url": {"url": "data:image/jpeg;base64,abcd"},
+ },
+ ],
+ }
+ ]
+ }
+ context_query = payloads["messages"]
+ err = _ErrorWithBody(
+ "upstream error",
+ {"error": {"code": "content_moderated", "message": "blocked"}},
+ )
+
+ success, *_rest = await provider._handle_api_error(
+ err,
+ payloads=payloads,
+ context_query=context_query,
+ func_tool=None,
+ chosen_key="test-key",
+ available_api_keys=["test-key"],
+ retry_cnt=0,
+ max_retries=10,
+ )
+ assert success is False
+ assert payloads["messages"][0]["content"] == [{"type": "text", "text": "hello"}]
+ finally:
+ await provider.terminate()
+
+
+@pytest.mark.asyncio
+async def test_handle_api_error_content_moderated_supports_custom_patterns():
+ provider = _make_provider(
+ {"image_moderation_error_patterns": ["blocked_by_policy_code_123"]}
+ )
+ try:
+ payloads = {
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "hello"},
+ {
+ "type": "image_url",
+ "image_url": {"url": "data:image/jpeg;base64,abcd"},
+ },
+ ],
+ }
+ ]
+ }
+ context_query = payloads["messages"]
+ err = Exception("upstream: blocked_by_policy_code_123")
+
+ success, *_rest = await provider._handle_api_error(
+ err,
+ payloads=payloads,
+ context_query=context_query,
+ func_tool=None,
+ chosen_key="test-key",
+ available_api_keys=["test-key"],
+ retry_cnt=0,
+ max_retries=10,
+ )
+ assert success is False
+ assert payloads["messages"][0]["content"] == [{"type": "text", "text": "hello"}]
+ finally:
+ await provider.terminate()
+
+
+@pytest.mark.asyncio
+async def test_handle_api_error_content_moderated_without_patterns_raises():
+ provider = _make_provider()
+ try:
+ payloads = {
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "hello"},
+ {
+ "type": "image_url",
+ "image_url": {"url": "data:image/jpeg;base64,abcd"},
+ },
+ ],
+ }
+ ]
+ }
+ context_query = payloads["messages"]
+ err = Exception("Content is moderated [WKE=file:content-moderated]")
+
+ with pytest.raises(Exception, match="content-moderated"):
+ await provider._handle_api_error(
+ err,
+ payloads=payloads,
+ context_query=context_query,
+ func_tool=None,
+ chosen_key="test-key",
+ available_api_keys=["test-key"],
+ retry_cnt=0,
+ max_retries=10,
+ )
+ finally:
+ await provider.terminate()
+
+
+@pytest.mark.asyncio
+async def test_handle_api_error_unknown_image_error_raises():
+ provider = _make_provider()
+ try:
+ payloads = {
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "hello"},
+ {
+ "type": "image_url",
+ "image_url": {"url": "data:image/jpeg;base64,abcd"},
+ },
+ ],
+ }
+ ]
+ }
+ context_query = payloads["messages"]
+
+ with pytest.raises(Exception, match="unknown provider image upload error"):
+ await provider._handle_api_error(
+ Exception("some unknown provider image upload error"),
+ payloads=payloads,
+ context_query=context_query,
+ func_tool=None,
+ chosen_key="test-key",
+ available_api_keys=["test-key"],
+ retry_cnt=0,
+ max_retries=10,
+ )
+ finally:
+ await provider.terminate()
diff --git a/tests/test_quoted_message_parser.py b/tests/test_quoted_message_parser.py
new file mode 100644
index 000000000..0a0e126d5
--- /dev/null
+++ b/tests/test_quoted_message_parser.py
@@ -0,0 +1,494 @@
+from types import SimpleNamespace
+
+import pytest
+
+from astrbot.core.message.components import Image, Plain, Reply
+from astrbot.core.utils.quoted_message_parser import (
+ extract_quoted_message_images,
+ extract_quoted_message_text,
+)
+
+
+class _DummyAPI:
+ def __init__(
+ self,
+ responses: dict[tuple[str, str], dict],
+ param_responses: dict[tuple[str, tuple[tuple[str, str], ...]], dict]
+ | None = None,
+ ):
+ self._responses = responses
+ self._param_responses = param_responses or {}
+
+ async def call_action(self, action: str, **params):
+ param_key = (action, tuple(sorted((k, str(v)) for k, v in params.items())))
+ if param_key in self._param_responses:
+ return self._param_responses[param_key]
+
+ msg_id = params.get("message_id")
+ if msg_id is None:
+ msg_id = params.get("id")
+ key = (action, str(msg_id))
+ if key not in self._responses:
+ raise RuntimeError(f"no mock response for {key}")
+ return self._responses[key]
+
+
+class _FailIfCalledAPI:
+ async def call_action(self, action: str, **params):
+ raise AssertionError(
+ f"call_action should not be called, got action={action}, params={params}"
+ )
+
+
+def _make_event(
+ reply: Reply,
+ responses: dict[tuple[str, str], dict] | None = None,
+ param_responses: dict[tuple[str, tuple[tuple[str, str], ...]], dict] | None = None,
+):
+ if responses is None:
+ responses = {}
+ if param_responses is None:
+ param_responses = {}
+ return SimpleNamespace(
+ message_obj=SimpleNamespace(message=[reply]),
+ bot=SimpleNamespace(api=_DummyAPI(responses, param_responses)),
+ get_group_id=lambda: "",
+ )
+
+
+@pytest.mark.asyncio
+async def test_extract_quoted_message_text_from_reply_chain():
+ reply = Reply(id="1", chain=[Plain(text="quoted content")], message_str="")
+ event = _make_event(reply)
+ text = await extract_quoted_message_text(event)
+ assert text == "quoted content"
+
+
+@pytest.mark.asyncio
+async def test_extract_quoted_message_text_no_reply_component():
+ event = SimpleNamespace(
+ message_obj=SimpleNamespace(message=[Plain(text="unquoted message")]),
+ bot=SimpleNamespace(api=_DummyAPI({}, {})),
+ get_group_id=lambda: "",
+ )
+
+ text = await extract_quoted_message_text(event)
+ assert text is None
+
+
+@pytest.mark.asyncio
+async def test_extract_quoted_message_images_no_reply_component():
+ event = SimpleNamespace(
+ message_obj=SimpleNamespace(message=[Plain(text="unquoted message")]),
+ bot=SimpleNamespace(api=_FailIfCalledAPI()),
+ get_group_id=lambda: "",
+ )
+
+ images = await extract_quoted_message_images(event)
+ assert images == []
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("reply_id", [None, ""])
+async def test_extract_quoted_message_text_reply_without_id_does_not_call_get_msg(
+ reply_id: str | None,
+):
+ reply = Reply(
+ id="placeholder", chain=[Plain(text="quoted content")], message_str=""
+ )
+ object.__setattr__(reply, "id", reply_id)
+ event = SimpleNamespace(
+ message_obj=SimpleNamespace(message=[reply]),
+ bot=SimpleNamespace(api=_FailIfCalledAPI()),
+ get_group_id=lambda: "",
+ )
+
+ text = await extract_quoted_message_text(event)
+ assert text == "quoted content"
+
+
+@pytest.mark.asyncio
+async def test_extract_quoted_message_text_fallback_get_msg_and_forward():
+ reply = Reply(id="100", chain=None, message_str="")
+ event = _make_event(
+ reply,
+ responses={
+ (
+ "get_msg",
+ "100",
+ ): {
+ "data": {
+ "message": [
+ {"type": "text", "data": {"text": "parent"}},
+ {"type": "forward", "data": {"id": "fwd_1"}},
+ ]
+ }
+ },
+ (
+ "get_forward_msg",
+ "fwd_1",
+ ): {
+ "data": {
+ "messages": [
+ {
+ "sender": {"nickname": "Alice"},
+ "message": [{"type": "text", "data": {"text": "hello"}}],
+ },
+ {
+ "sender": {"nickname": "Bob"},
+ "message": [
+ {"type": "image", "data": {"url": "http://img"}},
+ {"type": "text", "data": {"text": "world"}},
+ ],
+ },
+ ]
+ }
+ },
+ },
+ )
+
+ text = await extract_quoted_message_text(event)
+ assert text is not None
+ assert "parent" in text
+ assert "Alice: hello" in text
+ assert "Bob: [Image]world" in text
+
+
+@pytest.mark.parametrize(
+ "placeholder_text",
+ [
+ "[Forward Message]",
+ "[转发消息]",
+ "[合并转发]",
+ "Alice: [Forward Message]",
+ "(Alice): [转发消息]",
+ "[Forward Message]\n[转发消息]",
+ "Alice: [Forward Message]\n(Bob): [合并转发]",
+ "[转发消息]\n\n[合并转发]",
+ ],
+)
+@pytest.mark.asyncio
+async def test_extract_quoted_message_text_forward_placeholder_variants_trigger_fallback(
+ placeholder_text: str,
+):
+ reply = Reply(id="400", chain=[Plain(text=placeholder_text)], message_str="")
+ event = _make_event(
+ reply,
+ responses={
+ ("get_msg", "400"): {
+ "data": {
+ "message": [
+ {"type": "text", "data": {"text": "Bob: "}},
+ {"type": "image", "data": {}},
+ {"type": "text", "data": {"text": "world"}},
+ ]
+ }
+ }
+ },
+ )
+
+ text = await extract_quoted_message_text(event)
+ assert "Bob: [Image]world" in text
+
+
+@pytest.mark.asyncio
+async def test_extract_quoted_message_text_mixed_placeholder_does_not_trigger_fallback():
+ reply = Reply(
+ id="402",
+ chain=[Plain(text="Alice: [Forward Message]\nreal text")],
+ message_str="",
+ )
+ event = SimpleNamespace(
+ message_obj=SimpleNamespace(message=[reply]),
+ bot=SimpleNamespace(api=_FailIfCalledAPI()),
+ get_group_id=lambda: "",
+ )
+
+ text = await extract_quoted_message_text(event)
+ assert text is not None
+ assert "[Forward Message]" in text
+ assert "real text" in text
+
+
+@pytest.mark.asyncio
+async def test_extract_quoted_message_text_forward_placeholder_fallback_failure():
+ reply = Reply(id="401", chain=[Plain(text="[Forward Message]")], message_str="")
+ event = _make_event(reply, responses={})
+
+ text = await extract_quoted_message_text(event)
+ assert text == "[Forward Message]"
+
+
+@pytest.mark.asyncio
+async def test_extract_quoted_message_text_multimsg_malformed_config_does_not_raise():
+ reply = Reply(id="402", chain=None, message_str="")
+ event = _make_event(
+ reply,
+ responses={
+ ("get_msg", "402"): {
+ "data": {
+ "message": [
+ {
+ "type": "json",
+ "data": {
+ "data": (
+ '{"app":"com.tencent.multimsg",'
+ '"config":"oops","meta":{}}'
+ )
+ },
+ },
+ {"type": "text", "data": {"text": "still works"}},
+ ]
+ }
+ }
+ },
+ )
+
+ text = await extract_quoted_message_text(event)
+ assert text == "still works"
+
+
+@pytest.mark.asyncio
+async def test_extract_quoted_message_images_from_reply_chain():
+ reply = Reply(
+ id="1",
+ chain=[
+ Plain(text="quoted"),
+ Image(file="https://img.example.com/a.jpg"),
+ ],
+ message_str="",
+ )
+ event = _make_event(reply)
+
+ images = await extract_quoted_message_images(event)
+ assert images == ["https://img.example.com/a.jpg"]
+
+
+@pytest.mark.asyncio
+async def test_extract_quoted_message_images_fallback_get_msg_direct_url():
+ reply = Reply(id="200", chain=None, message_str="")
+ event = _make_event(
+ reply,
+ responses={
+ ("get_msg", "200"): {
+ "data": {
+ "message": [
+ {
+ "type": "image",
+ "data": {"url": "https://img.example.com/direct.jpg"},
+ }
+ ]
+ }
+ }
+ },
+ )
+
+ images = await extract_quoted_message_images(event)
+ assert images == ["https://img.example.com/direct.jpg"]
+
+
+@pytest.mark.asyncio
+async def test_extract_quoted_message_images_data_image_ref_normalized_to_base64():
+ data_image_ref = "data:image/png;base64,abcd1234=="
+ reply = Reply(id="201", chain=None, message_str="")
+ event = _make_event(
+ reply,
+ responses={
+ ("get_msg", "201"): {
+ "data": {
+ "message": [
+ {"type": "image", "data": {"url": data_image_ref}},
+ ]
+ }
+ }
+ },
+ )
+
+ images = await extract_quoted_message_images(event)
+ assert images == ["base64://abcd1234=="]
+
+
+@pytest.mark.asyncio
+async def test_extract_quoted_message_images_file_url_with_query_string():
+ url_with_query = "https://img.example.com/direct.jpg?token=abc123#frag"
+ reply = Reply(id="205", chain=None, message_str="")
+ event = _make_event(
+ reply,
+ responses={
+ ("get_msg", "205"): {
+ "data": {
+ "message": [
+ {
+ "type": "file",
+ "data": {
+ "url": url_with_query,
+ "name": "direct.jpg",
+ },
+ }
+ ]
+ }
+ }
+ },
+ )
+
+ images = await extract_quoted_message_images(event)
+ assert images == [url_with_query]
+
+
+@pytest.mark.asyncio
+async def test_extract_quoted_message_images_non_image_local_path_is_ignored(tmp_path):
+ non_image_file = tmp_path / "secret.txt"
+ non_image_file.write_text("not an image", encoding="utf-8")
+
+ reply = Reply(
+ id="placeholder", chain=[Image(file=str(non_image_file))], message_str=""
+ )
+ object.__setattr__(reply, "id", None)
+ event = SimpleNamespace(
+ message_obj=SimpleNamespace(message=[reply]),
+ bot=SimpleNamespace(api=_FailIfCalledAPI()),
+ get_group_id=lambda: "",
+ )
+
+ images = await extract_quoted_message_images(event)
+ assert images == []
+
+
+@pytest.mark.asyncio
+async def test_extract_quoted_message_images_chain_placeholder_triggers_fallback():
+ reply = Reply(id="210", chain=[Plain(text="[Forward Message]")], message_str="")
+ event = _make_event(
+ reply,
+ responses={
+ ("get_msg", "210"): {
+ "data": {
+ "message": [
+ {
+ "type": "image",
+ "data": {
+ "url": "https://img.example.com/from-fallback.jpg"
+ },
+ }
+ ]
+ }
+ }
+ },
+ )
+
+ images = await extract_quoted_message_images(event)
+ assert images == ["https://img.example.com/from-fallback.jpg"]
+
+
+@pytest.mark.asyncio
+async def test_extract_quoted_message_images_fallback_resolve_file_id_with_get_image():
+ reply = Reply(id="300", chain=None, message_str="")
+ event = _make_event(
+ reply,
+ responses={
+ ("get_msg", "300"): {
+ "data": {"message": [{"type": "image", "data": {"file": "abc123.jpg"}}]}
+ }
+ },
+ param_responses={
+ ("get_image", (("file", "abc123.jpg"),)): {
+ "data": {"url": "https://img.example.com/resolved.jpg"}
+ }
+ },
+ )
+
+ images = await extract_quoted_message_images(event)
+ assert images == ["https://img.example.com/resolved.jpg"]
+
+
+@pytest.mark.asyncio
+async def test_extract_quoted_message_images_deduplicates_across_sources():
+ dup_url = "https://img.example.com/dup.jpg"
+ chain_only_url = "https://img.example.com/only-chain.jpg"
+ get_msg_only_url = "https://img.example.com/only-get-msg.jpg"
+ forward_only_url = "https://img.example.com/only-forward.jpg"
+
+ reply = Reply(
+ id="310",
+ chain=[Image(file=dup_url), Image(file=chain_only_url)],
+ message_str="",
+ )
+
+ event = _make_event(
+ reply,
+ responses={
+ ("get_msg", "310"): {
+ "data": {
+ "message": [
+ {"type": "image", "data": {"url": dup_url}},
+ {"type": "image", "data": {"url": get_msg_only_url}},
+ {"type": "forward", "data": {"id": "999"}},
+ ]
+ }
+ },
+ ("get_forward_msg", "999"): {
+ "data": {
+ "messages": [
+ {
+ "sender": {"nickname": "Tester"},
+ "message": [
+ {"type": "image", "data": {"url": dup_url}},
+ {"type": "image", "data": {"url": forward_only_url}},
+ ],
+ }
+ ]
+ }
+ },
+ },
+ )
+
+ images = await extract_quoted_message_images(event)
+ assert images == [
+ dup_url,
+ chain_only_url,
+ get_msg_only_url,
+ forward_only_url,
+ ]
+
+
+@pytest.mark.asyncio
+async def test_extract_quoted_message_nested_forward_id_is_resolved():
+ nested_image = "https://img.example.com/nested.jpg"
+ reply = Reply(id="320", chain=[Plain(text="[Forward Message]")], message_str="")
+ event = _make_event(
+ reply,
+ responses={
+ ("get_msg", "320"): {
+ "data": {"message": [{"type": "forward", "data": {"id": "fwd_1"}}]}
+ },
+ ("get_forward_msg", "fwd_1"): {
+ "data": {
+ "messages": [
+ {
+ "sender": {"nickname": "Alice"},
+ "message": [{"type": "forward", "data": {"id": "fwd_2"}}],
+ }
+ ]
+ }
+ },
+ ("get_forward_msg", "fwd_2"): {
+ "data": {
+ "messages": [
+ {
+ "sender": {"nickname": "Bob"},
+ "message": [
+ {"type": "text", "data": {"text": "deep"}},
+ {"type": "image", "data": {"url": nested_image}},
+ ],
+ }
+ ]
+ }
+ },
+ },
+ )
+
+ text = await extract_quoted_message_text(event)
+ assert text is not None
+ assert "Bob: deep" in text
+
+ images = await extract_quoted_message_images(event)
+ assert images == [nested_image]
diff --git a/tests/test_temp_dir_cleaner.py b/tests/test_temp_dir_cleaner.py
new file mode 100644
index 000000000..01f3e65d0
--- /dev/null
+++ b/tests/test_temp_dir_cleaner.py
@@ -0,0 +1,52 @@
+import os
+import time
+from pathlib import Path
+
+from astrbot.core.utils.temp_dir_cleaner import TempDirCleaner, parse_size_to_bytes
+
+
+def test_parse_size_to_bytes():
+ assert parse_size_to_bytes("1024") == 1024 * 1024**2
+ assert parse_size_to_bytes(2048) == 2048 * 1024**2
+ assert parse_size_to_bytes("0.5") == int(0.5 * 1024**2)
+ assert parse_size_to_bytes(0) == 0
+ assert parse_size_to_bytes("invalid") == 0
+
+
+def _write_file(path: Path, size: int, mtime: float) -> None:
+ path.write_bytes(b"x" * size)
+ os.utime(path, (mtime, mtime))
+
+
+def test_cleanup_once_releases_30_percent_and_prefers_old_files(tmp_path):
+ temp_dir = tmp_path / "temp"
+ temp_dir.mkdir(parents=True, exist_ok=True)
+
+ base_time = time.time() - 1000
+ file_old = temp_dir / "old.bin"
+ file_mid = temp_dir / "mid.bin"
+ file_new = temp_dir / "new.bin"
+ _write_file(file_old, 400, base_time)
+ _write_file(file_mid, 300, base_time + 10)
+ _write_file(file_new, 300, base_time + 20)
+
+ cleaner = TempDirCleaner(max_size_getter=lambda: "0.0008", temp_dir=temp_dir)
+ cleaner.cleanup_once()
+
+ remaining_size = sum(f.stat().st_size for f in temp_dir.rglob("*") if f.is_file())
+ assert remaining_size <= 600
+ assert not file_old.exists()
+ assert file_mid.exists()
+ assert file_new.exists()
+
+
+def test_cleanup_once_noop_when_below_limit(tmp_path):
+ temp_dir = tmp_path / "temp"
+ temp_dir.mkdir(parents=True, exist_ok=True)
+ file_path = temp_dir / "a.bin"
+ _write_file(file_path, 100, time.time())
+
+ cleaner = TempDirCleaner(max_size_getter=lambda: "1", temp_dir=temp_dir)
+ cleaner.cleanup_once()
+
+ assert file_path.exists()
diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py
index f0e90002d..4a91877fd 100644
--- a/tests/test_tool_loop_agent_runner.py
+++ b/tests/test_tool_loop_agent_runner.py
@@ -90,6 +90,21 @@ class MockToolExecutor:
return generator()
+class MockFailingProvider(MockProvider):
+ async def text_chat(self, **kwargs) -> LLMResponse:
+ self.call_count += 1
+ raise RuntimeError("primary provider failed")
+
+
+class MockErrProvider(MockProvider):
+ async def text_chat(self, **kwargs) -> LLMResponse:
+ self.call_count += 1
+ return LLMResponse(
+ role="err",
+ completion_text="primary provider returned error",
+ )
+
+
class MockHooks(BaseAgentRunHooks):
"""模拟钩子函数"""
@@ -321,6 +336,64 @@ async def test_hooks_called_with_max_step(
assert mock_hooks.tool_end_called, "on_tool_end应该被调用"
+@pytest.mark.asyncio
+async def test_fallback_provider_used_when_primary_raises(
+ runner, provider_request, mock_tool_executor, mock_hooks
+):
+ primary_provider = MockFailingProvider()
+ fallback_provider = MockProvider()
+ fallback_provider.should_call_tools = False
+
+ await runner.reset(
+ provider=primary_provider,
+ request=provider_request,
+ run_context=ContextWrapper(context=None),
+ tool_executor=mock_tool_executor,
+ agent_hooks=mock_hooks,
+ streaming=False,
+ fallback_providers=[fallback_provider],
+ )
+
+ async for _ in runner.step_until_done(5):
+ pass
+
+ final_resp = runner.get_final_llm_resp()
+ assert final_resp is not None
+ assert final_resp.role == "assistant"
+ assert final_resp.completion_text == "这是我的最终回答"
+ assert primary_provider.call_count == 1
+ assert fallback_provider.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_fallback_provider_used_when_primary_returns_err(
+ runner, provider_request, mock_tool_executor, mock_hooks
+):
+ primary_provider = MockErrProvider()
+ fallback_provider = MockProvider()
+ fallback_provider.should_call_tools = False
+
+ await runner.reset(
+ provider=primary_provider,
+ request=provider_request,
+ run_context=ContextWrapper(context=None),
+ tool_executor=mock_tool_executor,
+ agent_hooks=mock_hooks,
+ streaming=False,
+ fallback_providers=[fallback_provider],
+ )
+
+ async for _ in runner.step_until_done(5):
+ pass
+
+ final_resp = runner.get_final_llm_resp()
+ assert final_resp is not None
+ assert final_resp.role == "assistant"
+ assert final_resp.completion_text == "这是我的最终回答"
+ assert primary_provider.call_count == 1
+ assert fallback_provider.call_count == 1
+
+
if __name__ == "__main__":
# 运行测试
pytest.main([__file__, "-v"])