Compare commits

..

3 Commits

Author SHA1 Message Date
Soulter cfb0538b32 feat(chat): add websocket API key extraction and scope validation 2026-02-25 17:40:55 +08:00
Soulter f8f7e6d57a feat(webchat): refactor message parsing logic and integrate new parsing function 2026-02-25 17:14:02 +08:00
Soulter 53ae8cd7cf feat: implement websockets transport mode selection for chat
- Added transport mode selection (SSE/WebSocket) in the chat component.
- Updated conversation sidebar to include transport mode options.
- Integrated transport mode handling in message sending logic.
- Refactored message sending functions to support both SSE and WebSocket.
- Enhanced WebSocket connection management and message handling.
- Updated localization files for transport mode labels.
- Configured Vite to support WebSocket proxying.
2026-02-24 21:14:16 +08:00
102 changed files with 5178 additions and 7833 deletions
+1 -1
View File
@@ -37,7 +37,7 @@ jobs:
mkdir -p data/temp
export TESTING=true
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
pytest --cov=astrbot -v -o log_cli=true -o log_level=DEBUG
pytest --cov=. -v -o log_cli=true -o log_level=DEBUG
- name: Upload results to Codecov
uses: codecov/codecov-action@v5
@@ -206,33 +206,16 @@ class ConversationCommands:
_titles[conv.cid] = title
"""遍历分页后的对话生成列表显示"""
provider_settings = cfg.get("provider_settings", {})
platform_name = message.get_platform_name()
for conv in conversations_paged:
(
persona_id,
_,
force_applied_persona_id,
_,
) = await self.context.persona_manager.resolve_selected_persona(
umo=message.unified_msg_origin,
conversation_persona_id=conv.persona_id,
platform_name=platform_name,
provider_settings=provider_settings,
)
if persona_id == "[%None]":
persona_name = ""
elif persona_id:
persona_name = persona_id
else:
persona_name = ""
if force_applied_persona_id:
persona_name = f"{persona_name} (自定义规则)"
persona_id = conv.persona_id
if not persona_id or persona_id == "[%None]":
persona = await self.context.persona_manager.get_default_persona_v3(
umo=message.unified_msg_origin,
)
persona_id = persona["name"]
title = _titles.get(conv.cid, "新对话")
parts.append(
f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_name}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n"
f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n"
)
global_index += 1
@@ -1,7 +1,7 @@
import builtins
from typing import TYPE_CHECKING
from astrbot.api import star
from astrbot.api import sp, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
if TYPE_CHECKING:
@@ -59,7 +59,12 @@ class PersonaCommands:
default_persona = await self.context.persona_manager.get_default_persona_v3(
umo=umo,
)
force_applied_persona_id = None
force_applied_persona_id = (
await sp.get_async(
scope="umo", scope_id=umo, key="session_service_config", default={}
)
).get("persona_id")
curr_cid_title = ""
if cid:
@@ -75,27 +80,10 @@ class PersonaCommands:
),
)
return
provider_settings = self.context.get_config(umo=umo).get(
"provider_settings",
{},
)
(
persona_id,
_,
force_applied_persona_id,
_,
) = await self.context.persona_manager.resolve_selected_persona(
umo=umo,
conversation_persona_id=conv.persona_id,
platform_name=message.get_platform_name(),
provider_settings=provider_settings,
)
if persona_id == "[%None]":
curr_persona_name = ""
elif persona_id:
curr_persona_name = persona_id
if not conv.persona_id and conv.persona_id != "[%None]":
curr_persona_name = default_persona["name"]
else:
curr_persona_name = conv.persona_id
if force_applied_persona_id:
curr_persona_name = f"{curr_persona_name} (自定义规则)"
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.18.3"
__version__ = "4.18.1"
+12 -53
View File
@@ -4,60 +4,19 @@ from ..message import Message
class ContextTruncator:
"""Context truncator."""
def _has_tool_calls(self, message: Message) -> bool:
"""Check if a message contains tool calls."""
return (
message.role == "assistant"
and message.tool_calls is not None
and len(message.tool_calls) > 0
)
def fix_messages(self, messages: list[Message]) -> list[Message]:
"""修复消息列表,确保 tool call 和 tool response 的配对关系有效。
此方法确保:
1. 每个 `tool` 消息前面都有一个包含 tool_calls 的 `assistant` 消息
2. 每个包含 tool_calls 的 `assistant` 消息后面都有对应的 `tool` 响应
这是 OpenAI Chat Completions API 规范的要求(Gemini 对此执行严格检查)。
"""
if not messages:
return messages
fixed_messages: list[Message] = []
pending_assistant: Message | None = None
pending_tools: list[Message] = []
def flush_pending_if_valid() -> None:
nonlocal pending_assistant, pending_tools
if pending_assistant is not None and pending_tools:
fixed_messages.append(pending_assistant)
fixed_messages.extend(pending_tools)
pending_assistant = None
pending_tools = []
for msg in messages:
if msg.role == "tool":
# 只有在有挂起的 assistant(tool_calls) 时才记录 tool 响应
if pending_assistant is not None:
pending_tools.append(msg)
# else: 孤立的 tool 消息,直接忽略
continue
if self._has_tool_calls(msg):
# 遇到新的 assistant(tool_calls) 前,先处理旧的 pending 链
flush_pending_if_valid()
pending_assistant = msg
continue
# 非 tool,且不含 tool_calls 的消息
# 先结束任何 pending 链,再正常追加
flush_pending_if_valid()
fixed_messages.append(msg)
# 结束时处理最后一个 pending 链
flush_pending_if_valid()
fixed_messages = []
for message in messages:
if message.role == "tool":
# tool block 前面必须要有 user 和 assistant block
if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
# 我们直接将之前的上下文都清空
fixed_messages = []
else:
fixed_messages.append(message)
else:
fixed_messages.append(message)
return fixed_messages
def truncate_by_turns(
+16 -92
View File
@@ -24,77 +24,15 @@ def _should_stop_agent(astr_event) -> bool:
return astr_event.is_stopped() or bool(astr_event.get_extra("agent_stop_requested"))
def _truncate_tool_result(text: str, limit: int = 70) -> str:
if limit <= 0:
return ""
if len(text) <= limit:
return text
if limit <= 3:
return text[:limit]
return f"{text[: limit - 3]}..."
def _extract_chain_json_data(msg_chain: MessageChain) -> dict | None:
if not msg_chain.chain:
return None
first_comp = msg_chain.chain[0]
if isinstance(first_comp, Json) and isinstance(first_comp.data, dict):
return first_comp.data
return None
def _record_tool_call_name(
tool_info: dict | None, tool_name_by_call_id: dict[str, str]
) -> None:
if not isinstance(tool_info, dict):
return
tool_call_id = tool_info.get("id")
tool_name = tool_info.get("name")
if tool_call_id is None or tool_name is None:
return
tool_name_by_call_id[str(tool_call_id)] = str(tool_name)
def _build_tool_call_status_message(tool_info: dict | None) -> str:
if tool_info:
return f"🔨 调用工具: {tool_info.get('name', 'unknown')}"
return "🔨 调用工具..."
def _build_tool_result_status_message(
msg_chain: MessageChain, tool_name_by_call_id: dict[str, str]
) -> str:
tool_name = "unknown"
tool_result = ""
result_data = _extract_chain_json_data(msg_chain)
if result_data:
tool_call_id = result_data.get("id")
if tool_call_id is not None:
tool_name = tool_name_by_call_id.pop(str(tool_call_id), "unknown")
tool_result = str(result_data.get("result", ""))
if not tool_result:
tool_result = msg_chain.get_plain_text(with_other_comps_mark=True)
tool_result = _truncate_tool_result(tool_result, 70)
status_msg = f"🔨 调用工具: {tool_name}"
if tool_result:
status_msg = f"{status_msg}\n📎 返回结果: {tool_result}"
return status_msg
async def run_agent(
agent_runner: AgentRunner,
max_step: int = 30,
show_tool_use: bool = True,
show_tool_call_result: bool = False,
stream_to_general: bool = False,
show_reasoning: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
step_idx = 0
astr_event = agent_runner.run_context.context.event
tool_name_by_call_id: dict[str, str] = {}
while step_idx < max_step + 1:
step_idx += 1
@@ -152,13 +90,6 @@ async def run_agent(
continue
if astr_event.get_platform_id() == "webchat":
await astr_event.send(msg_chain)
elif show_tool_use and show_tool_call_result:
status_msg = _build_tool_result_status_message(
msg_chain, tool_name_by_call_id
)
await astr_event.send(
MessageChain(type="tool_call").message(status_msg)
)
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
@@ -166,22 +97,25 @@ async def run_agent(
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
tool_info = _extract_chain_json_data(resp.data["chain"])
astr_event.trace.record(
"agent_tool_call",
tool_name=tool_info if tool_info else "unknown",
)
_record_tool_call_name(tool_info, tool_name_by_call_id)
tool_info = None
if resp.data["chain"].chain:
json_comp = resp.data["chain"].chain[0]
if isinstance(json_comp, Json):
tool_info = json_comp.data
astr_event.trace.record(
"agent_tool_call",
tool_name=tool_info if tool_info else "unknown",
)
if astr_event.get_platform_name() == "webchat":
await astr_event.send(resp.data["chain"])
elif show_tool_use:
if show_tool_call_result and isinstance(tool_info, dict):
# Delay tool status notification until tool_call_result.
continue
chain = MessageChain(type="tool_call").message(
_build_tool_call_status_message(tool_info)
)
if tool_info:
m = f"🔨 调用工具: {tool_info.get('name', 'unknown')}"
else:
m = "🔨 调用工具..."
chain = MessageChain(type="tool_call").message(m)
await astr_event.send(chain)
continue
@@ -268,7 +202,6 @@ async def run_live_agent(
tts_provider: TTSProvider | None = None,
max_step: int = 30,
show_tool_use: bool = True,
show_tool_call_result: bool = False,
show_reasoning: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
"""Live Mode 的 Agent 运行器,支持流式 TTS
@@ -278,7 +211,6 @@ async def run_live_agent(
tts_provider: TTS Provider 实例
max_step: 最大步数
show_tool_use: 是否显示工具使用
show_tool_call_result: 是否显示工具返回结果
show_reasoning: 是否显示推理过程
Yields:
@@ -290,7 +222,6 @@ async def run_live_agent(
agent_runner,
max_step=max_step,
show_tool_use=show_tool_use,
show_tool_call_result=show_tool_call_result,
stream_to_general=False,
show_reasoning=show_reasoning,
):
@@ -319,12 +250,7 @@ async def run_live_agent(
# 1. 启动 Agent Feeder 任务:负责运行 Agent 并将文本分句喂给 text_queue
feeder_task = asyncio.create_task(
_run_agent_feeder(
agent_runner,
text_queue,
max_step,
show_tool_use,
show_tool_call_result,
show_reasoning,
agent_runner, text_queue, max_step, show_tool_use, show_reasoning
)
)
@@ -410,7 +336,6 @@ async def _run_agent_feeder(
text_queue: asyncio.Queue,
max_step: int,
show_tool_use: bool,
show_tool_call_result: bool,
show_reasoning: bool,
) -> None:
"""运行 Agent 并将文本输出分句放入队列"""
@@ -420,7 +345,6 @@ async def _run_agent_feeder(
agent_runner,
max_step=max_step,
show_tool_use=show_tool_use,
show_tool_call_result=show_tool_call_result,
stream_to_general=False,
show_reasoning=show_reasoning,
):
+13 -67
View File
@@ -17,12 +17,6 @@ from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.astr_main_agent_resources import (
BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT,
EXECUTE_SHELL_TOOL,
FILE_DOWNLOAD_TOOL,
FILE_UPLOAD_TOOL,
LOCAL_EXECUTE_SHELL_TOOL,
LOCAL_PYTHON_TOOL,
PYTHON_TOOL,
SEND_MESSAGE_TO_USER_TOOL,
)
from astrbot.core.cron.events import CronMessageEvent
@@ -97,65 +91,6 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
yield r
return
@classmethod
def _get_runtime_computer_tools(cls, runtime: str) -> dict[str, FunctionTool]:
if runtime == "sandbox":
return {
EXECUTE_SHELL_TOOL.name: EXECUTE_SHELL_TOOL,
PYTHON_TOOL.name: PYTHON_TOOL,
FILE_UPLOAD_TOOL.name: FILE_UPLOAD_TOOL,
FILE_DOWNLOAD_TOOL.name: FILE_DOWNLOAD_TOOL,
}
if runtime in {"local", "local_sandboxed"}:
return {
LOCAL_EXECUTE_SHELL_TOOL.name: LOCAL_EXECUTE_SHELL_TOOL,
LOCAL_PYTHON_TOOL.name: LOCAL_PYTHON_TOOL,
}
return {}
@classmethod
def _build_handoff_toolset(
cls,
run_context: ContextWrapper[AstrAgentContext],
tools: list[str | FunctionTool] | None,
) -> ToolSet | None:
ctx = run_context.context.context
event = run_context.context.event
cfg = ctx.get_config(umo=event.unified_msg_origin)
provider_settings = cfg.get("provider_settings", {})
runtime = str(provider_settings.get("computer_use_runtime", "local"))
runtime_computer_tools = cls._get_runtime_computer_tools(runtime)
# Keep persona semantics aligned with the main agent: tools=None means
# "all tools", including runtime computer-use tools.
if tools is None:
toolset = ToolSet()
for registered_tool in llm_tools.func_list:
if isinstance(registered_tool, HandoffTool):
continue
if registered_tool.active:
toolset.add_tool(registered_tool)
for runtime_tool in runtime_computer_tools.values():
toolset.add_tool(runtime_tool)
return None if toolset.empty() else toolset
if not tools:
return None
toolset = ToolSet()
for tool_name_or_obj in tools:
if isinstance(tool_name_or_obj, str):
registered_tool = llm_tools.get_func(tool_name_or_obj)
if registered_tool and registered_tool.active:
toolset.add_tool(registered_tool)
continue
runtime_tool = runtime_computer_tools.get(tool_name_or_obj)
if runtime_tool:
toolset.add_tool(runtime_tool)
elif isinstance(tool_name_or_obj, FunctionTool):
toolset.add_tool(tool_name_or_obj)
return None if toolset.empty() else toolset
@classmethod
async def _execute_handoff(
cls,
@@ -166,8 +101,19 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
input_ = tool_args.get("input")
image_urls = tool_args.get("image_urls")
# Build handoff toolset from registered tools plus runtime computer tools.
toolset = cls._build_handoff_toolset(run_context, tool.agent.tools)
# make toolset for the agent
tools = tool.agent.tools
if tools:
toolset = ToolSet()
for t in tools:
if isinstance(t, str):
_t = llm_tools.get_func(t)
if _t:
toolset.add_tool(_t)
elif isinstance(t, FunctionTool):
toolset.add_tool(t)
else:
toolset = None
ctx = run_context.context.context
event = run_context.context.event
+38 -15
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import builtins
import copy
import datetime
import json
@@ -9,6 +10,7 @@ import zoneinfo
from collections.abc import Coroutine
from dataclasses import dataclass, field
from astrbot.api import sp
from astrbot.core import logger
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.mcp_client import MCPTool
@@ -110,7 +112,7 @@ class MainAgentBuildConfig:
to prevent LLM output harmful information"""
safety_mode_strategy: str = "system_prompt"
computer_use_runtime: str = "local"
"""The runtime for agent computer use: none, local, local_sandboxed, or sandbox."""
"""The runtime for agent computer use: none, local, or sandbox."""
sandbox_cfg: dict = field(default_factory=dict)
add_cron_tools: bool = True
"""This will add cron job management tools to the main agent for proactive cron job execution."""
@@ -273,26 +275,47 @@ async def _ensure_persona_and_skills(
if not req.conversation:
return
(
persona_id,
persona,
_,
use_webchat_special_default,
) = await plugin_context.persona_manager.resolve_selected_persona(
umo=event.unified_msg_origin,
conversation_persona_id=req.conversation.persona_id,
platform_name=event.get_platform_name(),
provider_settings=cfg,
)
# get persona ID
# 1. from session service config - highest priority
persona_id = (
await sp.get_async(
scope="umo",
scope_id=event.unified_msg_origin,
key="session_service_config",
default={},
)
).get("persona_id")
if not persona_id:
# 2. from conversation setting - second priority
persona_id = req.conversation.persona_id
if persona_id == "[%None]":
# explicitly set to no persona
pass
elif persona_id is None:
# 3. from config default persona setting - last priority
persona_id = cfg.get("default_personality")
persona = next(
builtins.filter(
lambda persona: persona["name"] == persona_id,
plugin_context.persona_manager.personas_v3,
),
None,
)
if persona:
# Inject persona system prompt
if prompt := persona["prompt"]:
req.system_prompt += f"\n# Persona Instructions\n\n{prompt}\n"
if begin_dialogs := copy.deepcopy(persona.get("_begin_dialogs_processed")):
req.contexts[:0] = begin_dialogs
elif use_webchat_special_default:
req.system_prompt += CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT
else:
# special handling for webchat persona
if event.get_platform_name() == "webchat" and persona_id != "[%None]":
persona_id = "_chatui_default_"
req.system_prompt += CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT
# Inject skills prompt
runtime = cfg.get("computer_use_runtime", "local")
@@ -1050,7 +1073,7 @@ async def build_main_agent(
if config.computer_use_runtime == "sandbox":
_apply_sandbox_tools(config, req, req.session_id)
elif config.computer_use_runtime in {"local", "local_sandboxed"}:
elif config.computer_use_runtime == "local":
_apply_local_env_tools(req)
agent_runner = AgentRunner()
+56 -208
View File
@@ -2,24 +2,22 @@ from __future__ import annotations
import asyncio
import os
import re
import shlex
import shutil
import subprocess
import sys
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal
from typing import Any
from astrbot.api import logger
from astrbot.core.utils.astrbot_path import get_astrbot_root
from astrbot.core.utils.astrbot_path import (
get_astrbot_data_path,
get_astrbot_root,
get_astrbot_temp_path,
)
from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
from .base import ComputerBooter
SandboxBackend = Literal["none", "bwrap", "seatbelt"]
_BLOCKED_COMMAND_PATTERNS = [
" rm -rf ",
" rm -fr ",
@@ -42,132 +40,20 @@ def _is_safe_command(command: str) -> bool:
return not any(pat in cmd for pat in _BLOCKED_COMMAND_PATTERNS)
def _escape_seatbelt_string(raw: str) -> str:
return raw.replace("\\", "\\\\").replace('"', '\\"')
def _session_workspace_name(session_id: str) -> str:
safe_prefix = re.sub(r"[^A-Za-z0-9._-]+", "_", session_id).strip("._-")
if not safe_prefix:
safe_prefix = "session"
safe_prefix = safe_prefix[:40]
suffix = uuid.uuid5(uuid.NAMESPACE_DNS, session_id).hex[:12]
return f"{safe_prefix}_{suffix}"
def _detect_sandbox_backend() -> SandboxBackend:
if sys.platform.startswith("linux"):
if shutil.which("bwrap"):
return "bwrap"
raise RuntimeError("Local runtime requires 'bwrap' on Linux.")
if sys.platform == "darwin":
if shutil.which("sandbox-exec"):
return "seatbelt"
raise RuntimeError("Local runtime requires 'sandbox-exec' on macOS.")
return "none"
@dataclass(frozen=True)
class LocalSandboxPolicy:
workspace: Path
backend: SandboxBackend
sandboxed: bool
default_cwd: Path
@classmethod
def build_default(cls, session_id: str, sandboxed: bool) -> LocalSandboxPolicy:
workspace_root_raw = os.environ.get(
"ASTRBOT_LOCAL_WORKSPACE_ROOT"
) or os.environ.get("ASTRBOT_LOCAL_WORKSPACE", "~/.astrbot/workspace")
workspace_root = Path(workspace_root_raw).expanduser().resolve()
workspace = workspace_root / _session_workspace_name(session_id)
default_cwd = workspace if sandboxed else Path(get_astrbot_root()).resolve()
return cls(
workspace=workspace,
backend=_detect_sandbox_backend() if sandboxed else "none",
sandboxed=sandboxed,
default_cwd=default_cwd,
)
def ensure_workspace(self) -> None:
try:
self.workspace.mkdir(parents=True, exist_ok=True)
except PermissionError as exc:
raise RuntimeError(
"Cannot create local workspace. "
"Set ASTRBOT_LOCAL_WORKSPACE_ROOT to a writable path."
) from exc
def resolve_path(self, path: str, base: Path | None = None) -> Path:
raw = Path(path).expanduser()
resolved = raw if raw.is_absolute() else (base or self.default_cwd) / raw
return resolved.resolve()
def ensure_writable_path(self, path: str) -> Path:
abs_path = self.resolve_path(path)
if self.sandboxed and not abs_path.is_relative_to(self.workspace):
raise PermissionError(
f"Write path is outside workspace: {self.workspace.as_posix()}"
)
return abs_path
def normalize_working_dir(self, cwd: str | None) -> Path:
target = self.resolve_path(cwd) if cwd else self.default_cwd
if not target.exists():
raise FileNotFoundError(f"Working directory does not exist: {target}")
if not target.is_dir():
raise NotADirectoryError(f"Working directory is not a directory: {target}")
return target
def wrap_command(self, command: list[str], working_dir: Path) -> list[str]:
if not self.sandboxed:
return command
if self.backend == "bwrap":
return [
"bwrap",
"--die-with-parent",
"--new-session",
"--ro-bind",
"/",
"/",
"--bind",
str(self.workspace),
str(self.workspace),
"--proc",
"/proc",
"--dev",
"/dev",
"--chdir",
str(working_dir),
"--",
*command,
]
if self.backend == "seatbelt":
workspace_escaped = _escape_seatbelt_string(str(self.workspace))
profile = "\n".join(
[
"(version 1)",
"(deny default)",
'(import "system.sb")',
"(allow process*)",
"(allow file-read*)",
f'(allow file-write* (subpath "{workspace_escaped}"))',
"(allow network*)",
]
)
return ["sandbox-exec", "-p", profile, *command]
raise RuntimeError("Sandbox backend is not available for local_sandboxed mode.")
def _ensure_safe_path(path: str) -> str:
abs_path = os.path.abspath(path)
allowed_roots = [
os.path.abspath(get_astrbot_root()),
os.path.abspath(get_astrbot_data_path()),
os.path.abspath(get_astrbot_temp_path()),
]
if not any(abs_path.startswith(root) for root in allowed_roots):
raise PermissionError("Path is outside the allowed computer roots.")
return abs_path
@dataclass
class LocalShellComponent(ShellComponent):
policy: LocalSandboxPolicy
async def exec(
self,
command: str,
@@ -181,58 +67,41 @@ class LocalShellComponent(ShellComponent):
raise PermissionError("Blocked unsafe shell command.")
def _run() -> dict[str, Any]:
shell_command = (
["/bin/sh", "-lc", command] if shell else shlex.split(command)
)
run_env = os.environ.copy()
if env:
run_env.update({str(k): str(v) for k, v in env.items()})
working_dir = self.policy.normalize_working_dir(cwd)
wrapped_command = self.policy.wrap_command(shell_command, working_dir)
working_dir = _ensure_safe_path(cwd) if cwd else get_astrbot_root()
if background:
proc = subprocess.Popen(
wrapped_command,
shell=False,
command,
shell=shell,
cwd=working_dir,
env=run_env,
stdin=subprocess.DEVNULL,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
return {"pid": proc.pid, "stdout": "", "stderr": "", "exit_code": None}
try:
result = subprocess.run(
wrapped_command,
shell=False,
cwd=working_dir,
env=run_env,
timeout=timeout,
stdin=subprocess.DEVNULL,
capture_output=True,
text=True,
)
return {
"stdout": result.stdout,
"stderr": result.stderr,
"exit_code": result.returncode,
}
except subprocess.TimeoutExpired:
timeout_seconds = timeout if timeout is not None else "configured"
return {
"stdout": "",
"stderr": f"Execution timed out after {timeout_seconds} seconds.",
"exit_code": 124,
}
result = subprocess.run(
command,
shell=shell,
cwd=working_dir,
env=run_env,
timeout=timeout,
capture_output=True,
text=True,
)
return {
"stdout": result.stdout,
"stderr": result.stderr,
"exit_code": result.returncode,
}
return await asyncio.to_thread(_run)
@dataclass
class LocalPythonComponent(PythonComponent):
policy: LocalSandboxPolicy
async def exec(
self,
code: str,
@@ -241,13 +110,9 @@ class LocalPythonComponent(PythonComponent):
silent: bool = False,
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
python_command = [os.environ.get("PYTHON", sys.executable), "-c", code]
working_dir = self.policy.normalize_working_dir(None)
wrapped_command = self.policy.wrap_command(python_command, working_dir)
try:
result = subprocess.run(
wrapped_command,
cwd=working_dir,
[os.environ.get("PYTHON", sys.executable), "-c", code],
timeout=timeout,
capture_output=True,
text=True,
@@ -273,25 +138,23 @@ class LocalPythonComponent(PythonComponent):
@dataclass
class LocalFileSystemComponent(FileSystemComponent):
policy: LocalSandboxPolicy
async def create_file(
self, path: str, content: str = "", mode: int = 0o644
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = self.policy.ensure_writable_path(path)
abs_path.parent.mkdir(parents=True, exist_ok=True)
with abs_path.open("w", encoding="utf-8") as f:
abs_path = _ensure_safe_path(path)
os.makedirs(os.path.dirname(abs_path), exist_ok=True)
with open(abs_path, "w", encoding="utf-8") as f:
f.write(content)
abs_path.chmod(mode)
return {"success": True, "path": str(abs_path)}
os.chmod(abs_path, mode)
return {"success": True, "path": abs_path}
return await asyncio.to_thread(_run)
async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = self.policy.resolve_path(path)
with abs_path.open(encoding=encoding) as f:
abs_path = _ensure_safe_path(path)
with open(abs_path, encoding=encoding) as f:
content = f.read()
return {"success": True, "content": content}
@@ -301,22 +164,22 @@ class LocalFileSystemComponent(FileSystemComponent):
self, path: str, content: str, mode: str = "w", encoding: str = "utf-8"
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = self.policy.ensure_writable_path(path)
abs_path.parent.mkdir(parents=True, exist_ok=True)
with abs_path.open(mode, encoding=encoding) as f:
abs_path = _ensure_safe_path(path)
os.makedirs(os.path.dirname(abs_path), exist_ok=True)
with open(abs_path, mode, encoding=encoding) as f:
f.write(content)
return {"success": True, "path": str(abs_path)}
return {"success": True, "path": abs_path}
return await asyncio.to_thread(_run)
async def delete_file(self, path: str) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = self.policy.ensure_writable_path(path)
if abs_path.is_dir():
abs_path = _ensure_safe_path(path)
if os.path.isdir(abs_path):
shutil.rmtree(abs_path)
else:
abs_path.unlink()
return {"success": True, "path": str(abs_path)}
os.remove(abs_path)
return {"success": True, "path": abs_path}
return await asyncio.to_thread(_run)
@@ -324,8 +187,8 @@ class LocalFileSystemComponent(FileSystemComponent):
self, path: str = ".", show_hidden: bool = False
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = self.policy.resolve_path(path)
entries = [entry.name for entry in abs_path.iterdir()]
abs_path = _ensure_safe_path(path)
entries = os.listdir(abs_path)
if not show_hidden:
entries = [e for e in entries if not e.startswith(".")]
return {"success": True, "entries": entries}
@@ -334,28 +197,13 @@ class LocalFileSystemComponent(FileSystemComponent):
class LocalBooter(ComputerBooter):
def __init__(self, session_id: str, sandboxed: bool = False) -> None:
self._session_id = session_id
self._policy = LocalSandboxPolicy.build_default(
session_id=session_id, sandboxed=sandboxed
)
if sandboxed:
self._policy.ensure_workspace()
if sandboxed and self._policy.backend == "none":
logger.warning(
f"Local runtime sandbox backend is unavailable on {sys.platform}. "
"Only filesystem tools are restricted to workspace."
)
self._fs = LocalFileSystemComponent(policy=self._policy)
self._python = LocalPythonComponent(policy=self._policy)
self._shell = LocalShellComponent(policy=self._policy)
def __init__(self) -> None:
self._fs = LocalFileSystemComponent()
self._python = LocalPythonComponent()
self._shell = LocalShellComponent()
async def boot(self, session_id: str) -> None:
logger.info(
f"Local computer booter initialized for session: {session_id} "
f"(sandboxed={self._policy.sandboxed}, "
f"backend={self._policy.backend}, workspace={self._policy.workspace})"
)
logger.info(f"Local computer booter initialized for session: {session_id}")
async def shutdown(self) -> None:
logger.info("Local computer booter shutdown complete.")
+6 -6
View File
@@ -15,7 +15,7 @@ from .booters.base import ComputerBooter
from .booters.local import LocalBooter
session_booter: dict[str, ComputerBooter] = {}
local_booters: dict[tuple[str, bool], ComputerBooter] = {}
local_booter: ComputerBooter | None = None
async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None:
@@ -104,8 +104,8 @@ async def get_booter(
return session_booter[session_id]
def get_local_booter(session_id: str, sandboxed: bool = False) -> ComputerBooter:
key = (session_id, sandboxed)
if key not in local_booters:
local_booters[key] = LocalBooter(session_id=session_id, sandboxed=sandboxed)
return local_booters[key]
def get_local_booter() -> ComputerBooter:
global local_booter
if local_booter is None:
local_booter = LocalBooter()
return local_booter
-5
View File
@@ -11,7 +11,6 @@ from astrbot.core.message.components import File
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..computer_client import get_booter
from .permissions import check_admin_permission
# @dataclass
# class CreateFileTool(FunctionTool):
@@ -103,8 +102,6 @@ class FileUploadTool(FunctionTool):
context: ContextWrapper[AstrAgentContext],
local_path: str,
) -> str | None:
if permission_error := check_admin_permission(context, "File upload/download"):
return permission_error
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
@@ -164,8 +161,6 @@ class FileDownloadTool(FunctionTool):
remote_path: str,
also_send_to_user: bool = True,
) -> ToolExecResult:
if permission_error := check_admin_permission(context, "File upload/download"):
return permission_error
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
@@ -1,19 +0,0 @@
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.astr_agent_context import AstrAgentContext
def check_admin_permission(
context: ContextWrapper[AstrAgentContext], operation_name: str
) -> str | None:
cfg = context.context.context.get_config(
umo=context.context.event.unified_msg_origin
)
provider_settings = cfg.get("provider_settings", {})
require_admin = provider_settings.get("computer_use_require_admin", True)
if require_admin and context.context.event.role != "admin":
return (
f"error: Permission denied. {operation_name} is only allowed for admin users. "
"Tell user to set admins in `AstrBot WebUI -> Config -> General Config` by adding their user ID to the admins list if they need this feature. "
f"User's ID is: {context.context.event.get_sender_id()}. User's ID can be found by using /sid command."
)
return None
+19 -16
View File
@@ -7,7 +7,6 @@ from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext, AstrMessageEvent
from astrbot.core.computer.computer_client import get_booter, get_local_booter
from astrbot.core.computer.tools.permissions import check_admin_permission
from astrbot.core.message.message_event_result import MessageChain
param_schema = {
@@ -27,6 +26,21 @@ param_schema = {
}
def _check_admin_permission(context: ContextWrapper[AstrAgentContext]) -> str | None:
cfg = context.context.context.get_config(
umo=context.context.event.unified_msg_origin
)
provider_settings = cfg.get("provider_settings", {})
require_admin = provider_settings.get("computer_use_require_admin", True)
if require_admin and context.context.event.role != "admin":
return (
"error: Permission denied. Python execution is only allowed for admin users. "
"Tell user to set admins in `AstrBot WebUI -> Config -> General Config` by adding their user ID to the admins list if they need this feature."
f"User's ID is: {context.context.event.get_sender_id()}. User's ID can be found by using /sid command."
)
return None
async def handle_result(result: dict, event: AstrMessageEvent) -> ToolExecResult:
data = result.get("data", {})
output = data.get("output", {})
@@ -67,7 +81,7 @@ class PythonTool(FunctionTool):
async def call(
self, context: ContextWrapper[AstrAgentContext], code: str, silent: bool = False
) -> ToolExecResult:
if permission_error := check_admin_permission(context, "Python execution"):
if permission_error := _check_admin_permission(context):
return permission_error
sb = await get_booter(
context.context.context,
@@ -83,27 +97,16 @@ class PythonTool(FunctionTool):
@dataclass
class LocalPythonTool(FunctionTool):
name: str = "astrbot_execute_python"
description: str = (
"Execute code in a local Python environment. "
"In local_sandboxed runtime, writes are restricted to ~/.astrbot/workspace/<session>."
)
description: str = "Execute codes in a Python environment."
parameters: dict = field(default_factory=lambda: param_schema)
async def call(
self, context: ContextWrapper[AstrAgentContext], code: str, silent: bool = False
) -> ToolExecResult:
if permission_error := check_admin_permission(context, "Python execution"):
if permission_error := _check_admin_permission(context):
return permission_error
event = context.context.event
cfg = context.context.context.get_config(umo=event.unified_msg_origin)
runtime = str(
cfg.get("provider_settings", {}).get("computer_use_runtime", "local")
)
sb = get_local_booter(
event.unified_msg_origin,
sandboxed=runtime == "local_sandboxed",
)
sb = get_local_booter()
try:
result = await sb.python.exec(code, silent=silent)
return await handle_result(result, context.context.event)
+21 -29
View File
@@ -7,27 +7,34 @@ from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from ..computer_client import get_booter, get_local_booter
from .permissions import check_admin_permission
def _check_admin_permission(context: ContextWrapper[AstrAgentContext]) -> str | None:
cfg = context.context.context.get_config(
umo=context.context.event.unified_msg_origin
)
provider_settings = cfg.get("provider_settings", {})
require_admin = provider_settings.get("computer_use_require_admin", True)
if require_admin and context.context.event.role != "admin":
return (
"error: Permission denied. Shell execution is only allowed for admin users. "
"Tell user to set admins in `AstrBot WebUI -> Config -> General Config` by adding their user ID to the admins list if they need this feature."
f"User's ID is: {context.context.event.get_sender_id()}. User's ID can be found by using /sid command."
)
return None
@dataclass
class ExecuteShellTool(FunctionTool):
name: str = "astrbot_execute_shell"
description: str = (
"Execute a command in the shell. "
"In local_sandboxed runtime, writes are restricted to ~/.astrbot/workspace/<session>."
)
description: str = "Execute a command in the shell."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The shell command to execute.",
},
"cwd": {
"type": "string",
"description": "Optional working directory for command execution.",
"description": "The bash command to execute. Equal to 'cd {working_dir} && {your_command}'.",
},
"background": {
"type": "boolean",
@@ -51,36 +58,21 @@ class ExecuteShellTool(FunctionTool):
self,
context: ContextWrapper[AstrAgentContext],
command: str,
cwd: str | None = None,
background: bool = False,
env: dict = {},
) -> ToolExecResult:
if permission_error := check_admin_permission(context, "Shell execution"):
if permission_error := _check_admin_permission(context):
return permission_error
event = context.context.event
cfg = context.context.context.get_config(umo=event.unified_msg_origin)
runtime = str(
cfg.get("provider_settings", {}).get("computer_use_runtime", "local")
)
if self.is_local:
sb = get_local_booter(
event.unified_msg_origin,
sandboxed=runtime == "local_sandboxed",
)
sb = get_local_booter()
else:
sb = await get_booter(
context.context.context,
event.unified_msg_origin,
context.context.event.unified_msg_origin,
)
try:
result = await sb.shell.exec(
command,
cwd=cwd,
background=background,
env=env,
)
result = await sb.shell.exec(command, background=background, env=env)
return json.dumps(result)
except Exception as e:
return f"Error executing command: {str(e)}"
+6 -29
View File
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.18.3"
VERSION = "4.18.1"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -100,7 +100,6 @@ DEFAULT_CONFIG = {
"dequeue_context_length": 1,
"streaming_response": False,
"show_tool_use_status": False,
"show_tool_call_result": False,
"sanitize_context_by_modalities": False,
"max_quoted_fallback_images": 20,
"quoted_message_parser": {
@@ -425,15 +424,7 @@ CONFIG_METADATA_2 = {
"slack_webhook_port": 6197,
"slack_webhook_path": "/astrbot-slack-webhook/callback",
},
"Line": {
"id": "line",
"type": "line",
"enable": False,
"channel_access_token": "",
"channel_secret": "",
"unified_webhook_mode": True,
"webhook_uuid": "",
},
# LINE's config is located in line_adapter.py
"Satori": {
"id": "satori",
"type": "satori",
@@ -1471,7 +1462,6 @@ CONFIG_METADATA_2 = {
"type": "openai_embedding",
"provider": "openai",
"provider_type": "embedding",
"hint": "provider_group.provider.openai_embedding.hint",
"enable": True,
"embedding_api_key": "",
"embedding_api_base": "",
@@ -1485,7 +1475,6 @@ CONFIG_METADATA_2 = {
"type": "gemini_embedding",
"provider": "google",
"provider_type": "embedding",
"hint": "provider_group.provider.gemini_embedding.hint",
"enable": True,
"embedding_api_key": "",
"embedding_api_base": "",
@@ -2202,9 +2191,9 @@ CONFIG_METADATA_2 = {
"type": "string",
},
"proxy": {
"description": "provider_group.provider.proxy.description",
"description": "代理地址",
"type": "string",
"hint": "provider_group.provider.proxy.hint",
"hint": "HTTP/HTTPS 代理地址,格式如 http://127.0.0.1:7890。仅对该提供商的 API 请求生效,不影响 Docker 内网通信。",
},
"model": {
"description": "模型 ID",
@@ -2317,9 +2306,6 @@ CONFIG_METADATA_2 = {
"show_tool_use_status": {
"type": "bool",
},
"show_tool_call_result": {
"type": "bool",
},
"unsupported_streaming_strategy": {
"type": "string",
},
@@ -2772,8 +2758,8 @@ CONFIG_METADATA_3 = {
"provider_settings.computer_use_runtime": {
"description": "Computer Use Runtime",
"type": "string",
"options": ["none", "local", "local_sandboxed", "sandbox"],
"labels": ["", "本地", "本地(沙箱增强)", "沙箱"],
"options": ["none", "local", "sandbox"],
"labels": ["", "本地", "沙箱"],
"hint": "选择 Computer Use 运行环境。",
},
"provider_settings.computer_use_require_admin": {
@@ -3008,15 +2994,6 @@ CONFIG_METADATA_3 = {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.show_tool_call_result": {
"description": "输出函数调用返回结果",
"type": "bool",
"hint": "仅在输出函数调用状态启用时生效,展示结果前 70 个字符。",
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.show_tool_use_status": True,
},
},
"provider_settings.sanitize_context_by_modalities": {
"description": "按模型能力清理历史上下文",
"type": "bool",
+6 -6
View File
@@ -4,7 +4,7 @@ import typing as T
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta, timezone
from sqlalchemy import CursorResult, Row
from sqlalchemy import CursorResult
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, delete, desc, func, or_, select, text, update
@@ -626,7 +626,7 @@ class SQLiteDatabase(BaseDatabase):
query = select(ApiKey).where(
ApiKey.key_hash == key_hash,
col(ApiKey.revoked_at).is_(None),
or_(col(ApiKey.expires_at).is_(None), col(ApiKey.expires_at) > now),
or_(col(ApiKey.expires_at).is_(None), ApiKey.expires_at > now),
)
result = await session.execute(query)
return result.scalar_one_or_none()
@@ -638,7 +638,7 @@ class SQLiteDatabase(BaseDatabase):
async with session.begin():
await session.execute(
update(ApiKey)
.where(col(ApiKey.key_id) == key_id)
.where(ApiKey.key_id == key_id)
.values(last_used_at=datetime.now(timezone.utc)),
)
@@ -649,7 +649,7 @@ class SQLiteDatabase(BaseDatabase):
async with session.begin():
query = (
update(ApiKey)
.where(col(ApiKey.key_id) == key_id)
.where(ApiKey.key_id == key_id)
.values(revoked_at=datetime.now(timezone.utc))
)
result = T.cast(CursorResult, await session.execute(query))
@@ -663,7 +663,7 @@ class SQLiteDatabase(BaseDatabase):
result = T.cast(
CursorResult,
await session.execute(
delete(ApiKey).where(col(ApiKey.key_id) == key_id)
delete(ApiKey).where(ApiKey.key_id == key_id)
),
)
return result.rowcount > 0
@@ -1457,7 +1457,7 @@ class SQLiteDatabase(BaseDatabase):
return query
@staticmethod
def _rows_to_session_dicts(rows: T.Sequence[Row[tuple]]) -> list[dict]:
def _rows_to_session_dicts(rows: list[tuple]) -> list[dict]:
sessions_with_projects = []
for row in rows:
platform_session = row[0]
@@ -256,46 +256,6 @@ class KBSQLiteDatabase:
"knowledge_base": row[1],
}
async def get_documents_with_metadata_batch(
self, doc_ids: set[str]
) -> dict[str, dict]:
"""批量获取文档及其所属知识库元数据
Args:
doc_ids: 文档 ID 集合
Returns:
dict: doc_id -> {"document": KBDocument, "knowledge_base": KnowledgeBase}
"""
if not doc_ids:
return {}
metadata_map: dict[str, dict] = {}
# SQLite 参数上限为 999,分片查询避免超限
chunk_size = 900
doc_id_list = list(doc_ids)
async with self.get_db() as session:
for i in range(0, len(doc_id_list), chunk_size):
chunk = doc_id_list[i : i + chunk_size]
stmt = (
select(KBDocument, KnowledgeBase)
.join(
KnowledgeBase,
col(KBDocument.kb_id) == col(KnowledgeBase.kb_id),
)
.where(col(KBDocument.doc_id).in_(chunk))
)
result = await session.execute(stmt)
for row in result.all():
metadata_map[row[0].doc_id] = {
"document": row[0],
"knowledge_base": row[1],
}
return metadata_map
async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB) -> None:
"""删除单个文档及其相关数据"""
# 在知识库表中删除
@@ -142,13 +142,10 @@ class RetrievalManager:
f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results.",
)
# 4. 转换为 RetrievalResult (批量获取元数据)
doc_ids = {fr.doc_id for fr in fused_results}
metadata_map = await self.kb_db.get_documents_with_metadata_batch(doc_ids)
# 4. 转换为 RetrievalResult (获取元数据)
retrieval_results = []
for fr in fused_results:
metadata_dict = metadata_map.get(fr.doc_id)
metadata_dict = await self.kb_db.get_document_with_metadata(fr.doc_id)
if metadata_dict:
retrieval_results.append(
RetrievalResult(
+3 -28
View File
@@ -720,38 +720,13 @@ class File(BaseMessageComponent):
if allow_return_url and self.url:
return self.url
if self.file_:
path = self.file_
if path.startswith("file://"):
# 处理 file:// (2 slashes) 或 file:/// (3 slashes)
# pathlib.as_uri() 通常生成 file:///
path = path[7:]
# 兼容 Windows: file:///C:/path -> /C:/path -> C:/path
if (
os.name == "nt"
and len(path) > 2
and path[0] == "/"
and path[2] == ":"
):
path = path[1:]
if os.path.exists(path):
return os.path.abspath(path)
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
if self.url:
await self._download_file()
if self.file_:
path = self.file_
if path.startswith("file://"):
path = path[7:]
if (
os.name == "nt"
and len(path) > 2
and path[0] == "/"
and path[2] == ":"
):
path = path[1:]
return os.path.abspath(path)
return os.path.abspath(self.file_)
return ""
-55
View File
@@ -1,5 +1,4 @@
from astrbot import logger
from astrbot.api import sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Persona, PersonaFolder, Personality
@@ -59,60 +58,6 @@ class PersonaManager:
except Exception:
return DEFAULT_PERSONALITY
async def resolve_selected_persona(
self,
*,
umo: str | MessageSession,
conversation_persona_id: str | None,
platform_name: str,
provider_settings: dict | None = None,
) -> tuple[str | None, Personality | None, str | None, bool]:
"""解析当前会话最终生效的人格。
Returns:
tuple:
- selected persona_id
- selected persona object
- force applied persona_id from session rule
- whether use webchat special default persona
"""
session_service_config = (
await sp.get_async(
scope="umo",
scope_id=str(umo),
key="session_service_config",
default={},
)
or {}
)
force_applied_persona_id = session_service_config.get("persona_id")
persona_id = force_applied_persona_id
if not persona_id:
persona_id = conversation_persona_id
if persona_id == "[%None]":
pass
elif persona_id is None:
persona_id = (provider_settings or {}).get("default_personality")
persona = next(
(item for item in self.personas_v3 if item["name"] == persona_id),
None,
)
use_webchat_special_default = False
if not persona and platform_name == "webchat" and persona_id != "[%None]":
persona_id = "_chatui_default_"
use_webchat_special_default = True
return (
persona_id,
persona,
force_applied_persona_id,
use_webchat_special_default,
)
async def delete_persona(self, persona_id: str) -> None:
"""删除指定 persona"""
if not await self.db.get_persona_by_id(persona_id):
+21 -77
View File
@@ -1,71 +1,30 @@
"""Pipeline package exports.
This module intentionally avoids eager imports of all pipeline stage modules to
prevent import-time cycles. Stage classes remain available via lazy attribute
resolution for backward compatibility.
"""
from __future__ import annotations
from importlib import import_module
from typing import TYPE_CHECKING, Any
from astrbot.core.message.message_event_result import (
EventResultType,
MessageEventResult,
)
from .stage_order import STAGES_ORDER
from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .rate_limit_check.stage import RateLimitStage
from .respond.stage import RespondStage
from .result_decorate.stage import ResultDecorateStage
from .session_status_check.stage import SessionStatusCheckStage
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
if TYPE_CHECKING:
from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .rate_limit_check.stage import RateLimitStage
from .respond.stage import RespondStage
from .result_decorate.stage import ResultDecorateStage
from .session_status_check.stage import SessionStatusCheckStage
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
_LAZY_EXPORTS = {
"ContentSafetyCheckStage": (
"astrbot.core.pipeline.content_safety_check.stage",
"ContentSafetyCheckStage",
),
"PreProcessStage": (
"astrbot.core.pipeline.preprocess_stage.stage",
"PreProcessStage",
),
"ProcessStage": (
"astrbot.core.pipeline.process_stage.stage",
"ProcessStage",
),
"RateLimitStage": (
"astrbot.core.pipeline.rate_limit_check.stage",
"RateLimitStage",
),
"RespondStage": (
"astrbot.core.pipeline.respond.stage",
"RespondStage",
),
"ResultDecorateStage": (
"astrbot.core.pipeline.result_decorate.stage",
"ResultDecorateStage",
),
"SessionStatusCheckStage": (
"astrbot.core.pipeline.session_status_check.stage",
"SessionStatusCheckStage",
),
"WakingCheckStage": (
"astrbot.core.pipeline.waking_check.stage",
"WakingCheckStage",
),
"WhitelistCheckStage": (
"astrbot.core.pipeline.whitelist_check.stage",
"WhitelistCheckStage",
),
}
# 管道阶段顺序
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"SessionStatusCheckStage", # 检查会话是否整体启用
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
"RespondStage", # 发送消息
]
__all__ = [
"ContentSafetyCheckStage",
@@ -77,21 +36,6 @@ __all__ = [
"RespondStage",
"ResultDecorateStage",
"SessionStatusCheckStage",
"STAGES_ORDER",
"WakingCheckStage",
"WhitelistCheckStage",
]
def __getattr__(name: str) -> Any:
if name not in _LAZY_EXPORTS:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
module_path, attr_name = _LAZY_EXPORTS[name]
module = import_module(module_path)
value = getattr(module, attr_name)
globals()[name] = value
return value
def __dir__() -> list[str]:
return sorted(set(globals()) | set(__all__))
-52
View File
@@ -1,52 +0,0 @@
"""Pipeline bootstrap utilities."""
from importlib import import_module
from .stage import registered_stages
_BUILTIN_STAGE_MODULES = (
"astrbot.core.pipeline.waking_check.stage",
"astrbot.core.pipeline.whitelist_check.stage",
"astrbot.core.pipeline.session_status_check.stage",
"astrbot.core.pipeline.rate_limit_check.stage",
"astrbot.core.pipeline.content_safety_check.stage",
"astrbot.core.pipeline.preprocess_stage.stage",
"astrbot.core.pipeline.process_stage.stage",
"astrbot.core.pipeline.result_decorate.stage",
"astrbot.core.pipeline.respond.stage",
)
_EXPECTED_STAGE_NAMES = {
"WakingCheckStage",
"WhitelistCheckStage",
"SessionStatusCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",
}
_builtin_stages_registered = False
def ensure_builtin_stages_registered() -> None:
"""Ensure built-in pipeline stages are imported and registered."""
global _builtin_stages_registered
if _builtin_stages_registered:
return
stage_names = {stage_cls.__name__ for stage_cls in registered_stages}
if _EXPECTED_STAGE_NAMES.issubset(stage_names):
_builtin_stages_registered = True
return
for module_path in _BUILTIN_STAGE_MODULES:
import_module(module_path)
_builtin_stages_registered = True
__all__ = ["ensure_builtin_stages_registered"]
+2 -4
View File
@@ -1,9 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from astrbot.core.config import AstrBotConfig
from astrbot.core.star import PluginManager
from .context_utils import call_event_hook, call_handler
@@ -13,7 +11,7 @@ class PipelineContext:
"""上下文对象,包含管道执行所需的上下文信息"""
astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: Any # 插件管理器对象
plugin_manager: PluginManager # 插件管理器对象
astrbot_config_id: str
call_handler = call_handler
call_event_hook = call_event_hook
@@ -19,7 +19,6 @@ from astrbot.core.message.message_event_result import (
MessageEventResult,
ResultContentType,
)
from astrbot.core.pipeline.stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import (
LLMResponse,
@@ -31,6 +30,7 @@ from astrbot.core.utils.session_lock import session_lock_manager
from .....astr_agent_run_util import run_agent, run_live_agent
from ....context import PipelineContext, call_event_hook
from ...stage import Stage
class InternalAgentSubStage(Stage):
@@ -54,7 +54,6 @@ class InternalAgentSubStage(Stage):
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
self.show_tool_call_result: bool = settings.get("show_tool_call_result", False)
self.show_reasoning = settings.get("display_reasoning_text", False)
self.sanitize_context_by_modalities: bool = settings.get(
"sanitize_context_by_modalities",
@@ -241,7 +240,6 @@ class InternalAgentSubStage(Stage):
tts_provider,
self.max_step,
self.show_tool_use,
self.show_tool_call_result,
show_reasoning=self.show_reasoning,
),
),
@@ -271,7 +269,6 @@ class InternalAgentSubStage(Stage):
agent_runner,
self.max_step,
self.show_tool_use,
self.show_tool_call_result,
show_reasoning=self.show_reasoning,
),
),
@@ -300,7 +297,6 @@ class InternalAgentSubStage(Stage):
agent_runner,
self.max_step,
self.show_tool_use,
self.show_tool_call_result,
stream_to_general,
show_reasoning=self.show_reasoning,
):
@@ -8,7 +8,6 @@ from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
DashscopeAgentRunner,
)
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
MessageChain,
@@ -18,7 +17,6 @@ from astrbot.core.message.message_event_result import (
if TYPE_CHECKING:
from astrbot.core.agent.runners.base import BaseAgentRunner
from astrbot.core.pipeline.stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import (
ProviderRequest,
@@ -27,7 +25,9 @@ from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric
from .....astr_agent_context import AgentContextWrapper, AstrAgentContext
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
from ....context import PipelineContext, call_event_hook
from ...stage import Stage
AGENT_RUNNER_TYPE_KEY = {
"dify": "dify_agent_runner_provider_id",
+1 -3
View File
@@ -8,17 +8,15 @@ from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
)
from astrbot.core.utils.active_event_registry import active_event_registry
from .bootstrap import ensure_builtin_stages_registered
from . import STAGES_ORDER
from .context import PipelineContext
from .stage import registered_stages
from .stage_order import STAGES_ORDER
class PipelineScheduler:
"""管道调度器,负责调度各个阶段的执行"""
def __init__(self, context: PipelineContext) -> None:
ensure_builtin_stages_registered()
registered_stages.sort(
key=lambda x: STAGES_ORDER.index(x.__name__),
) # 按照顺序排序
-15
View File
@@ -1,15 +0,0 @@
"""Pipeline stage execution order."""
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"SessionStatusCheckStage", # 检查会话是否整体启用
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
"RespondStage", # 发送消息
]
__all__ = ["STAGES_ORDER"]
+11 -33
View File
@@ -52,19 +52,9 @@ class AstrMessageEvent(abc.ABC):
self.is_at_or_wake_command = False
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
self._extras: dict[str, Any] = {}
message_type = getattr(message_obj, "type", None)
if not isinstance(message_type, MessageType):
try:
message_type = MessageType(str(message_type))
except (ValueError, TypeError, AttributeError):
logger.warning(
f"Failed to convert message type {message_obj.type!r} to MessageType. "
f"Falling back to FRIEND_MESSAGE."
)
message_type = MessageType.FRIEND_MESSAGE
self.session = MessageSession(
platform_name=platform_meta.id,
message_type=message_type,
message_type=message_obj.type,
session_id=session_id,
)
# self.unified_msg_origin = str(self.session)
@@ -169,18 +159,15 @@ class AstrMessageEvent(abc.ABC):
除了文本消息外其他消息类型会被转换为对应的占位符如图片消息会被转换为 [图片]
"""
return self._outline_chain(getattr(self.message_obj, "message", None))
return self._outline_chain(self.message_obj.message)
def get_messages(self) -> list[BaseMessageComponent]:
"""获取消息链。"""
return getattr(self.message_obj, "message", [])
return self.message_obj.message
def get_message_type(self) -> MessageType:
"""获取消息类型。"""
message_type = getattr(self.message_obj, "type", None)
if isinstance(message_type, MessageType):
return message_type
return self.session.message_type
return self.message_obj.type
def get_session_id(self) -> str:
"""获取会话id。"""
@@ -188,30 +175,21 @@ class AstrMessageEvent(abc.ABC):
def get_group_id(self) -> str:
"""获取群组id。如果不是群组消息,返回空字符串。"""
return getattr(self.message_obj, "group_id", "")
return self.message_obj.group_id
def get_self_id(self) -> str:
"""获取机器人自身的id。"""
return getattr(self.message_obj, "self_id", "")
return self.message_obj.self_id
def get_sender_id(self) -> str:
"""获取消息发送者的id。"""
sender = getattr(self.message_obj, "sender", None)
if sender and isinstance(getattr(sender, "user_id", None), str):
return sender.user_id
return ""
return self.message_obj.sender.user_id
def get_sender_name(self) -> str:
"""获取消息发送者的名称。(可能会返回空字符串)"""
sender = getattr(self.message_obj, "sender", None)
if not sender:
return ""
nickname = getattr(sender, "nickname", None)
if nickname is None:
return ""
if isinstance(nickname, str):
return nickname
return str(nickname)
if isinstance(self.message_obj.sender.nickname, str):
return self.message_obj.sender.nickname
return ""
def set_extra(self, key, value) -> None:
"""设置额外的信息。"""
@@ -230,7 +208,7 @@ class AstrMessageEvent(abc.ABC):
def is_private_chat(self) -> bool:
"""是否是私聊。"""
return self.get_message_type() == MessageType.FRIEND_MESSAGE
return self.message_obj.type.value == (MessageType.FRIEND_MESSAGE).value
def is_wake_up(self) -> bool:
"""是否是唤醒机器人的事件。"""
-4
View File
@@ -180,10 +180,6 @@ class PlatformManager:
from .sources.line.line_adapter import (
LinePlatformAdapter, # noqa: F401
)
case "email":
from .sources.email.email_adapter import (
EmailPlatformAdapter, # noqa: F401
)
except (ImportError, ModuleNotFoundError) as e:
logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
@@ -45,19 +45,6 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
if isinstance(segment, File):
# For File segments, we need to handle the file differently
d = await segment.to_dict()
file_val = d.get("data", {}).get("file", "")
if file_val:
import pathlib
try:
# 使用 pathlib 处理路径,能更好地处理 Windows/Linux 差异
path_obj = pathlib.Path(file_val)
# 如果是绝对路径且不包含协议头 (://),则转换为标准的 file: URI
if path_obj.is_absolute() and "://" not in file_val:
d["data"]["file"] = path_obj.as_uri()
except Exception:
# 如果不是合法路径(例如已经是特定的特殊字符串),则跳过转换
pass
return d
if isinstance(segment, Video):
d = await segment.to_dict()
@@ -1,5 +1,4 @@
import asyncio
import inspect
import itertools
import logging
import time
@@ -437,42 +436,7 @@ class AiocqhttpAdapter(Platform):
return coro
async def terminate(self) -> None:
if hasattr(self, "shutdown_event"):
self.shutdown_event.set()
await self._close_reverse_ws_connections()
async def _close_reverse_ws_connections(self) -> None:
api_clients = getattr(self.bot, "_wsr_api_clients", None)
event_clients = getattr(self.bot, "_wsr_event_clients", None)
ws_clients: set[Any] = set()
if isinstance(api_clients, dict):
ws_clients.update(api_clients.values())
if isinstance(event_clients, set):
ws_clients.update(event_clients)
close_tasks: list[Awaitable[Any]] = []
for ws in ws_clients:
close_func = getattr(ws, "close", None)
if not callable(close_func):
continue
try:
close_result = close_func(code=1000, reason="Adapter shutdown")
except TypeError:
close_result = close_func()
except Exception:
continue
if inspect.isawaitable(close_result):
close_tasks.append(close_result)
if close_tasks:
await asyncio.gather(*close_tasks, return_exceptions=True)
if isinstance(api_clients, dict):
api_clients.clear()
if isinstance(event_clients, set):
event_clients.clear()
self.shutdown_event.set()
async def shutdown_trigger_placeholder(self) -> None:
await self.shutdown_event.wait()
@@ -65,6 +65,15 @@ LINE_I18N_RESOURCES = {
"line",
"LINE Messaging API 适配器",
support_streaming_message=False,
default_config_tmpl={
"id": "line",
"type": "line",
"enable": False,
"channel_access_token": "",
"channel_secret": "",
"unified_webhook_mode": True,
"webhook_uuid": "",
},
config_metadata=LINE_CONFIG_METADATA,
i18n_resources=LINE_I18N_RESOURCES,
)
@@ -162,8 +162,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
payload["media"] = media
payload["msg_type"] = 7
payload.pop("markdown", None)
payload["content"] = plain_text or None
if record_file_path: # group record msg
media = await self.upload_group_and_c2c_record(
record_file_path,
@@ -172,8 +170,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
payload["media"] = media
payload["msg_type"] = 7
payload.pop("markdown", None)
payload["content"] = plain_text or None
ret = await self._send_with_markdown_fallback(
send_func=lambda retry_payload: self.bot.api.post_group_message(
group_openid=source.group_openid, # type: ignore
@@ -192,8 +188,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
payload["media"] = media
payload["msg_type"] = 7
payload.pop("markdown", None)
payload["content"] = plain_text or None
if record_file_path: # c2c record
media = await self.upload_group_and_c2c_record(
record_file_path,
@@ -202,8 +196,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
payload["media"] = media
payload["msg_type"] = 7
payload.pop("markdown", None)
payload["content"] = plain_text or None
if stream:
ret = await self._send_with_markdown_fallback(
send_func=lambda retry_payload: self.post_c2c_message(
@@ -1,9 +1,7 @@
import asyncio
import os
import re
import sys
import uuid
from typing import cast
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from telegram import BotCommand, Update
@@ -27,9 +25,6 @@ from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_file
from astrbot.core.utils.media_utils import convert_audio_to_wav
from .tg_event import TelegramPlatformEvent
@@ -380,19 +375,8 @@ class TelegramPlatformAdapter(Platform):
elif update.message.voice:
file = await update.message.voice.get_file()
file_basename = os.path.basename(cast(str, file.file_path))
temp_dir = get_astrbot_temp_path()
temp_path = os.path.join(temp_dir, file_basename)
await download_file(cast(str, file.file_path), path=temp_path)
path_wav = os.path.join(
temp_dir,
f"{file_basename}.wav",
)
path_wav = await convert_audio_to_wav(temp_path, path_wav)
message.message = [
Comp.Record(file=path_wav, url=path_wav),
Comp.Record(file=file.file_path, url=file.file_path),
]
elif update.message.photo:
@@ -18,7 +18,6 @@ from astrbot.api.message_components import (
Plain,
Record,
Reply,
Video,
)
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata
@@ -37,7 +36,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
# 消息类型到 chat action 的映射,用于优先级判断
ACTION_BY_TYPE: dict[type, str] = {
Record: ChatAction.UPLOAD_VOICE,
Video: ChatAction.UPLOAD_VIDEO,
File: ChatAction.UPLOAD_DOCUMENT,
Image: ChatAction.UPLOAD_PHOTO,
Plain: ChatAction.TYPING,
@@ -116,18 +114,10 @@ class TelegramPlatformEvent(AstrMessageEvent):
**payload: Any,
) -> None:
"""发送媒体时显示 upload action,发送完成后恢复 typing"""
effective_thread_id = message_thread_id or cast(
str | None, payload.get("message_thread_id")
)
await cls._send_chat_action(client, user_name, upload_action, message_thread_id)
await send_coro(**payload)
await cls._send_chat_action(
client, user_name, upload_action, effective_thread_id
)
send_payload = dict(payload)
if effective_thread_id and "message_thread_id" not in send_payload:
send_payload["message_thread_id"] = effective_thread_id
await send_coro(**send_payload)
await cls._send_chat_action(
client, user_name, ChatAction.TYPING, effective_thread_id
client, user_name, ChatAction.TYPING, message_thread_id
)
@classmethod
@@ -151,16 +141,14 @@ class TelegramPlatformEvent(AstrMessageEvent):
"""
try:
if use_media_action:
media_payload = dict(payload)
if message_thread_id and "message_thread_id" not in media_payload:
media_payload["message_thread_id"] = message_thread_id
await cls._send_media_with_action(
client,
ChatAction.UPLOAD_VOICE,
client.send_voice,
user_name=user_name,
message_thread_id=message_thread_id,
voice=path,
**cast(Any, media_payload),
**cast(Any, payload),
)
else:
await client.send_voice(voice=path, **cast(Any, payload))
@@ -174,17 +162,15 @@ class TelegramPlatformEvent(AstrMessageEvent):
"To enable voice messages, go to Telegram Settings → Privacy and Security → Voice Messages → set to 'Everyone'."
)
if use_media_action:
media_payload = dict(payload)
if message_thread_id and "message_thread_id" not in media_payload:
media_payload["message_thread_id"] = message_thread_id
await cls._send_media_with_action(
client,
ChatAction.UPLOAD_DOCUMENT,
client.send_document,
user_name=user_name,
message_thread_id=message_thread_id,
document=path,
caption=caption,
**cast(Any, media_payload),
**cast(Any, payload),
)
else:
await client.send_document(
@@ -292,13 +278,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
caption=i.text or None,
use_media_action=False,
)
elif isinstance(i, Video):
path = await i.convert_to_file_path()
await client.send_video(
video=path,
caption=getattr(i, "text", None) or None,
**cast(Any, payload),
)
async def send(self, message: MessageChain) -> None:
if self.get_message_type() == MessageType.GROUP_MESSAGE:
@@ -354,7 +333,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
"chat_id": user_name,
}
if message_thread_id:
payload["message_thread_id"] = message_thread_id
payload["reply_to_message_id"] = message_thread_id
delta = ""
current_content = ""
@@ -396,6 +375,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
ChatAction.UPLOAD_PHOTO,
self.client.send_photo,
user_name=user_name,
message_thread_id=message_thread_id,
photo=image_path,
**cast(Any, payload),
)
@@ -408,6 +388,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
ChatAction.UPLOAD_DOCUMENT,
self.client.send_document,
user_name=user_name,
message_thread_id=message_thread_id,
document=path,
filename=name,
**cast(Any, payload),
@@ -425,17 +406,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
use_media_action=True,
)
continue
elif isinstance(i, Video):
path = await i.convert_to_file_path()
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_VIDEO,
self.client.send_video,
user_name=user_name,
video=path,
**cast(Any, payload),
)
continue
else:
logger.warning(f"不支持的消息类型: {type(i)}")
continue
@@ -0,0 +1,465 @@
import json
import mimetypes
import shutil
import uuid
from collections.abc import Awaitable, Callable, Sequence
from pathlib import Path
from typing import Any
from astrbot.core.db.po import Attachment
from astrbot.core.message.components import (
File,
Image,
Json,
Plain,
Record,
Reply,
Video,
)
from astrbot.core.message.message_event_result import MessageChain
AttachmentGetter = Callable[[str], Awaitable[Attachment | None]]
AttachmentInserter = Callable[[str, str, str], Awaitable[Attachment | None]]
ReplyHistoryGetter = Callable[
[Any],
Awaitable[tuple[list[dict], str | None, str | None] | None],
]
MEDIA_PART_TYPES = {"image", "record", "file", "video"}
def strip_message_parts_path_fields(message_parts: list[dict]) -> list[dict]:
return [{k: v for k, v in part.items() if k != "path"} for part in message_parts]
def webchat_message_parts_have_content(message_parts: list[dict]) -> bool:
return any(
part.get("type") in ("plain", "image", "record", "file", "video")
and (part.get("text") or part.get("attachment_id") or part.get("filename"))
for part in message_parts
)
async def parse_webchat_message_parts(
message_parts: list,
*,
strict: bool = False,
include_empty_plain: bool = False,
verify_media_path_exists: bool = True,
reply_history_getter: ReplyHistoryGetter | None = None,
current_depth: int = 0,
max_reply_depth: int = 0,
cast_reply_id_to_str: bool = True,
) -> tuple[list, list[str], bool]:
"""Parse webchat message parts into components/text parts.
Returns:
tuple[list, list[str], bool]:
(components, plain_text_parts, has_non_reply_content)
"""
components = []
text_parts: list[str] = []
has_content = False
for part in message_parts:
if not isinstance(part, dict):
if strict:
raise ValueError("message part must be an object")
continue
part_type = str(part.get("type", "")).strip()
if part_type == "plain":
text = str(part.get("text", ""))
if text or include_empty_plain:
components.append(Plain(text=text))
text_parts.append(text)
if text:
has_content = True
continue
if part_type == "reply":
message_id = part.get("message_id")
if message_id is None:
if strict:
raise ValueError("reply part missing message_id")
continue
reply_chain = []
reply_message_str = str(part.get("selected_text", ""))
sender_id = None
sender_name = None
if reply_message_str:
reply_chain = [Plain(text=reply_message_str)]
elif (
reply_history_getter
and current_depth < max_reply_depth
and message_id is not None
):
reply_info = await reply_history_getter(message_id)
if reply_info:
reply_parts, sender_id, sender_name = reply_info
(
reply_chain,
reply_text_parts,
_,
) = await parse_webchat_message_parts(
reply_parts,
strict=strict,
include_empty_plain=include_empty_plain,
verify_media_path_exists=verify_media_path_exists,
reply_history_getter=reply_history_getter,
current_depth=current_depth + 1,
max_reply_depth=max_reply_depth,
cast_reply_id_to_str=cast_reply_id_to_str,
)
reply_message_str = "".join(reply_text_parts)
reply_id = str(message_id) if cast_reply_id_to_str else message_id
components.append(
Reply(
id=reply_id,
message_str=reply_message_str,
chain=reply_chain,
sender_id=sender_id,
sender_nickname=sender_name,
)
)
continue
if part_type not in MEDIA_PART_TYPES:
if strict:
raise ValueError(f"unsupported message part type: {part_type}")
continue
path = part.get("path")
if not path:
if strict:
raise ValueError(f"{part_type} part missing path")
continue
file_path = Path(str(path))
if verify_media_path_exists and not file_path.exists():
if strict:
raise ValueError(f"file not found: {file_path!s}")
continue
file_path_str = (
str(file_path.resolve()) if verify_media_path_exists else str(file_path)
)
has_content = True
if part_type == "image":
components.append(Image.fromFileSystem(file_path_str))
elif part_type == "record":
components.append(Record.fromFileSystem(file_path_str))
elif part_type == "video":
components.append(Video.fromFileSystem(file_path_str))
else:
filename = str(part.get("filename", "")).strip() or file_path.name
components.append(File(name=filename, file=file_path_str))
return components, text_parts, has_content
async def build_webchat_message_parts(
message_payload: str | list,
*,
get_attachment_by_id: AttachmentGetter,
strict: bool = False,
) -> list[dict]:
if isinstance(message_payload, str):
text = message_payload.strip()
return [{"type": "plain", "text": text}] if text else []
if not isinstance(message_payload, list):
if strict:
raise ValueError("message must be a string or list")
return []
message_parts: list[dict] = []
for part in message_payload:
if not isinstance(part, dict):
if strict:
raise ValueError("message part must be an object")
continue
part_type = str(part.get("type", "")).strip()
if part_type == "plain":
text = str(part.get("text", ""))
if text:
message_parts.append({"type": "plain", "text": text})
continue
if part_type == "reply":
message_id = part.get("message_id")
if message_id is None:
if strict:
raise ValueError("reply part missing message_id")
continue
message_parts.append(
{
"type": "reply",
"message_id": message_id,
"selected_text": str(part.get("selected_text", "")),
}
)
continue
if part_type not in MEDIA_PART_TYPES:
if strict:
raise ValueError(f"unsupported message part type: {part_type}")
continue
attachment_id = part.get("attachment_id")
if not attachment_id:
if strict:
raise ValueError(f"{part_type} part missing attachment_id")
continue
attachment = await get_attachment_by_id(str(attachment_id))
if not attachment:
if strict:
raise ValueError(f"attachment not found: {attachment_id}")
continue
attachment_path = Path(attachment.path)
message_parts.append(
{
"type": attachment.type,
"attachment_id": attachment.attachment_id,
"filename": attachment_path.name,
"path": str(attachment_path),
}
)
return message_parts
def webchat_message_parts_to_message_chain(
message_parts: list[dict],
*,
strict: bool = False,
) -> MessageChain:
components = []
has_content = False
for part in message_parts:
if not isinstance(part, dict):
if strict:
raise ValueError("message part must be an object")
continue
part_type = str(part.get("type", "")).strip()
if part_type == "plain":
text = str(part.get("text", ""))
if text:
components.append(Plain(text=text))
has_content = True
continue
if part_type == "reply":
message_id = part.get("message_id")
if message_id is None:
if strict:
raise ValueError("reply part missing message_id")
continue
components.append(
Reply(
id=str(message_id),
message_str=str(part.get("selected_text", "")),
chain=[],
)
)
continue
if part_type not in MEDIA_PART_TYPES:
if strict:
raise ValueError(f"unsupported message part type: {part_type}")
continue
path = part.get("path")
if not path:
if strict:
raise ValueError(f"{part_type} part missing path")
continue
file_path = Path(str(path))
if not file_path.exists():
if strict:
raise ValueError(f"file not found: {file_path!s}")
continue
file_path_str = str(file_path.resolve())
has_content = True
if part_type == "image":
components.append(Image.fromFileSystem(file_path_str))
elif part_type == "record":
components.append(Record.fromFileSystem(file_path_str))
elif part_type == "video":
components.append(Video.fromFileSystem(file_path_str))
else:
filename = str(part.get("filename", "")).strip() or file_path.name
components.append(File(name=filename, file=file_path_str))
if strict and (not components or not has_content):
raise ValueError("Message content is empty (reply only is not allowed)")
return MessageChain(chain=components)
async def build_message_chain_from_payload(
message_payload: str | list,
*,
get_attachment_by_id: AttachmentGetter,
strict: bool = True,
) -> MessageChain:
message_parts = await build_webchat_message_parts(
message_payload,
get_attachment_by_id=get_attachment_by_id,
strict=strict,
)
components, _, has_content = await parse_webchat_message_parts(
message_parts,
strict=strict,
)
if strict and (not components or not has_content):
raise ValueError("Message content is empty (reply only is not allowed)")
return MessageChain(chain=components)
async def create_attachment_part_from_existing_file(
filename: str,
*,
attach_type: str,
insert_attachment: AttachmentInserter,
attachments_dir: str | Path,
fallback_dirs: Sequence[str | Path] = (),
) -> dict | None:
basename = Path(filename).name
candidate_paths = [Path(attachments_dir) / basename]
candidate_paths.extend(Path(p) / basename for p in fallback_dirs)
file_path = next((path for path in candidate_paths if path.exists()), None)
if not file_path:
return None
mime_type, _ = mimetypes.guess_type(str(file_path))
attachment = await insert_attachment(
str(file_path),
attach_type,
mime_type or "application/octet-stream",
)
if not attachment:
return None
return {
"type": attach_type,
"attachment_id": attachment.attachment_id,
"filename": file_path.name,
}
async def message_chain_to_storage_message_parts(
message_chain: MessageChain,
*,
insert_attachment: AttachmentInserter,
attachments_dir: str | Path,
) -> list[dict]:
target_dir = Path(attachments_dir)
target_dir.mkdir(parents=True, exist_ok=True)
parts: list[dict] = []
for comp in message_chain.chain:
if isinstance(comp, Plain):
if comp.text:
parts.append({"type": "plain", "text": comp.text})
continue
if isinstance(comp, Json):
parts.append(
{"type": "plain", "text": json.dumps(comp.data, ensure_ascii=False)}
)
continue
if isinstance(comp, Image):
file_path = await comp.convert_to_file_path()
attachment_part = await _copy_file_to_attachment_part(
file_path=file_path,
attach_type="image",
insert_attachment=insert_attachment,
attachments_dir=target_dir,
)
if attachment_part:
parts.append(attachment_part)
continue
if isinstance(comp, Record):
file_path = await comp.convert_to_file_path()
attachment_part = await _copy_file_to_attachment_part(
file_path=file_path,
attach_type="record",
insert_attachment=insert_attachment,
attachments_dir=target_dir,
)
if attachment_part:
parts.append(attachment_part)
continue
if isinstance(comp, Video):
file_path = await comp.convert_to_file_path()
attachment_part = await _copy_file_to_attachment_part(
file_path=file_path,
attach_type="video",
insert_attachment=insert_attachment,
attachments_dir=target_dir,
)
if attachment_part:
parts.append(attachment_part)
continue
if isinstance(comp, File):
file_path = await comp.get_file()
attachment_part = await _copy_file_to_attachment_part(
file_path=file_path,
attach_type="file",
insert_attachment=insert_attachment,
attachments_dir=target_dir,
display_name=comp.name,
)
if attachment_part:
parts.append(attachment_part)
continue
return parts
async def _copy_file_to_attachment_part(
*,
file_path: str,
attach_type: str,
insert_attachment: AttachmentInserter,
attachments_dir: Path,
display_name: str | None = None,
) -> dict | None:
src_path = Path(file_path)
if not src_path.exists() or not src_path.is_file():
return None
suffix = src_path.suffix
target_path = attachments_dir / f"{uuid.uuid4().hex}{suffix}"
shutil.copy2(src_path, target_path)
mime_type, _ = mimetypes.guess_type(target_path.name)
attachment = await insert_attachment(
str(target_path),
attach_type,
mime_type or "application/octet-stream",
)
if not attachment:
return None
return {
"type": attach_type,
"attachment_id": attachment.attachment_id,
"filename": display_name or src_path.name,
}
@@ -3,12 +3,12 @@ import os
import time
import uuid
from collections.abc import Callable, Coroutine
from pathlib import Path
from typing import Any
from astrbot import logger
from astrbot.core import db_helper
from astrbot.core.db.po import PlatformMessageHistory
from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform import (
AstrBotMessage,
@@ -21,10 +21,23 @@ from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ...register import register_platform_adapter
from .message_parts_helper import (
message_chain_to_storage_message_parts,
parse_webchat_message_parts,
)
from .webchat_event import WebChatMessageEvent
from .webchat_queue_mgr import WebChatQueueMgr, webchat_queue_mgr
def _extract_conversation_id(session_id: str) -> str:
"""Extract raw webchat conversation id from event/session id."""
if session_id.startswith("webchat!"):
parts = session_id.split("!", 2)
if len(parts) == 3:
return parts[2]
return session_id
class QueueListener:
def __init__(
self,
@@ -57,13 +70,15 @@ class WebChatAdapter(Platform):
self.settings = platform_settings
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
self.attachments_dir = Path(get_astrbot_data_path()) / "attachments"
os.makedirs(self.imgs_dir, exist_ok=True)
self.attachments_dir.mkdir(parents=True, exist_ok=True)
self.metadata = PlatformMetadata(
name="webchat",
description="webchat",
id="webchat",
support_proactive_message=False,
support_proactive_message=True,
)
self._shutdown_event = asyncio.Event()
self._webchat_queue_mgr = webchat_queue_mgr
@@ -73,10 +88,67 @@ class WebChatAdapter(Platform):
session: MessageSesion,
message_chain: MessageChain,
) -> None:
message_id = f"active_{str(uuid.uuid4())}"
await WebChatMessageEvent._send(message_id, message_chain, session.session_id)
conversation_id = _extract_conversation_id(session.session_id)
active_request_ids = self._webchat_queue_mgr.list_back_request_ids(
conversation_id
)
subscription_request_ids = [
req_id for req_id in active_request_ids if req_id.startswith("ws_sub_")
]
target_request_ids = subscription_request_ids or active_request_ids
if target_request_ids:
for request_id in target_request_ids:
await WebChatMessageEvent._send(
request_id,
message_chain,
session.session_id,
)
else:
message_id = f"active_{uuid.uuid4()!s}"
await WebChatMessageEvent._send(
message_id,
message_chain,
session.session_id,
)
should_persist = (
bool(subscription_request_ids)
or not active_request_ids
or all(req_id.startswith("active_") for req_id in active_request_ids)
)
if should_persist:
try:
await self._save_proactive_message(conversation_id, message_chain)
except Exception as e:
logger.error(
f"[WebChatAdapter] Failed to save proactive message: {e}",
exc_info=True,
)
await super().send_by_session(session, message_chain)
async def _save_proactive_message(
self,
conversation_id: str,
message_chain: MessageChain,
) -> None:
message_parts = await message_chain_to_storage_message_parts(
message_chain,
insert_attachment=db_helper.insert_attachment,
attachments_dir=self.attachments_dir,
)
if not message_parts:
return
await db_helper.insert_platform_message_history(
platform_id="webchat",
user_id=conversation_id,
content={"type": "bot", "message": message_parts},
sender_id="bot",
sender_name="bot",
)
async def _get_message_history(
self, message_id: int
) -> PlatformMessageHistory | None:
@@ -98,72 +170,30 @@ class WebChatAdapter(Platform):
Returns:
tuple[list, list[str]]: (消息组件列表, 纯文本列表)
"""
components = []
text_parts = []
for part in message_parts:
part_type = part.get("type")
if part_type == "plain":
text = part.get("text", "")
components.append(Plain(text=text))
text_parts.append(text)
elif part_type == "reply":
message_id = part.get("message_id")
reply_chain = []
reply_message_str = part.get("selected_text", "")
sender_id = None
sender_name = None
async def get_reply_parts(
message_id: Any,
) -> tuple[list[dict], str | None, str | None] | None:
history = await self._get_message_history(message_id)
if not history or not history.content:
return None
if reply_message_str:
reply_chain = [Plain(text=reply_message_str)]
reply_parts = history.content.get("message", [])
if not isinstance(reply_parts, list):
return None
# recursively get the content of the referenced message, if selected_text is empty
if not reply_message_str and depth < max_depth and message_id:
history = await self._get_message_history(message_id)
if history and history.content:
reply_parts = history.content.get("message", [])
if isinstance(reply_parts, list):
(
reply_chain,
reply_text_parts,
) = await self._parse_message_parts(
reply_parts,
depth=depth + 1,
max_depth=max_depth,
)
reply_message_str = "".join(reply_text_parts)
sender_id = history.sender_id
sender_name = history.sender_name
components.append(
Reply(
id=message_id,
chain=reply_chain,
message_str=reply_message_str,
sender_id=sender_id,
sender_nickname=sender_name,
)
)
elif part_type == "image":
path = part.get("path")
if path:
components.append(Image.fromFileSystem(path))
elif part_type == "record":
path = part.get("path")
if path:
components.append(Record.fromFileSystem(path))
elif part_type == "file":
path = part.get("path")
if path:
filename = part.get("filename") or (
os.path.basename(path) if path else "file"
)
components.append(File(name=filename, file=path))
elif part_type == "video":
path = part.get("path")
if path:
components.append(Video.fromFileSystem(path))
return reply_parts, history.sender_id, history.sender_name
components, text_parts, _ = await parse_webchat_message_parts(
message_parts,
strict=False,
include_empty_plain=True,
verify_media_path_exists=False,
reply_history_getter=get_reply_parts,
current_depth=depth,
max_reply_depth=max_depth,
cast_reply_id_to_str=False,
)
return components, text_parts
async def convert_message(self, data: tuple) -> AstrBotMessage:
@@ -14,6 +14,15 @@ from .webchat_queue_mgr import webchat_queue_mgr
attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
def _extract_conversation_id(session_id: str) -> str:
"""Extract raw webchat conversation id from event/session id."""
if session_id.startswith("webchat!"):
parts = session_id.split("!", 2)
if len(parts) == 3:
return parts[2]
return session_id
class WebChatMessageEvent(AstrMessageEvent):
def __init__(self, message_str, message_obj, platform_meta, session_id) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
@@ -27,7 +36,7 @@ class WebChatMessageEvent(AstrMessageEvent):
streaming: bool = False,
) -> str | None:
request_id = str(message_id)
conversation_id = session_id.split("!")[-1]
conversation_id = _extract_conversation_id(session_id)
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
request_id,
conversation_id,
@@ -130,7 +139,7 @@ class WebChatMessageEvent(AstrMessageEvent):
reasoning_content = ""
message_id = self.message_obj.message_id
request_id = str(message_id)
conversation_id = self.session_id.split("!")[-1]
conversation_id = _extract_conversation_id(self.session_id)
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
request_id,
conversation_id,
@@ -75,6 +75,10 @@ class WebChatQueueMgr:
if task is not None:
task.cancel()
def list_back_request_ids(self, conversation_id: str) -> list[str]:
"""List active back-queue request IDs for a conversation."""
return list(self._conversation_back_requests.get(conversation_id, set()))
def has_queue(self, conversation_id: str) -> bool:
"""Check if a queue exists for the given conversation ID"""
return conversation_id in self.queues
@@ -3,7 +3,7 @@ import os
import sys
import time
import uuid
from collections.abc import Callable, Coroutine
from collections.abc import Awaitable, Callable
from typing import Any, cast
import quart
@@ -65,9 +65,7 @@ class WeixinOfficialAccountServer:
self.event_queue = event_queue
self.callback: (
Callable[[BaseMessage], Coroutine[Any, Any, str | None]] | None
) = None
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
self.shutdown_event = asyncio.Event()
self._wx_msg_time_out = 4.0 # 微信服务器要求 5 秒内回复
@@ -48,9 +48,6 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
result = await self.client.models.embed_content(
model=self.model,
contents=text,
config=types.EmbedContentConfig(
output_dimensionality=self.get_dim(),
),
)
assert result.embeddings is not None
assert result.embeddings[0].values is not None
@@ -64,9 +61,6 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
result = await self.client.models.embed_content(
model=self.model,
contents=cast(types.ContentListUnion, text),
config=types.EmbedContentConfig(
output_dimensionality=self.get_dim(),
),
)
assert result.embeddings is not None
@@ -23,16 +23,12 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
if proxy:
logger.info(f"[OpenAI Embedding] 使用代理: {proxy}")
http_client = httpx.AsyncClient(proxy=proxy)
api_base = provider_config.get("embedding_api_base", "").strip()
if not api_base:
api_base = "https://api.openai.com/v1"
else:
api_base = api_base.removesuffix("/")
if not api_base.endswith("/v1"):
api_base = f"{api_base}/v1"
self.client = AsyncOpenAI(
api_key=provider_config.get("embedding_api_key"),
base_url=api_base,
base_url=provider_config.get(
"embedding_api_base",
"https://api.openai.com/v1",
),
timeout=int(provider_config.get("timeout", 20)),
http_client=http_client,
)
@@ -40,20 +36,12 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
async def get_embedding(self, text: str) -> list[float]:
"""获取文本的嵌入"""
embedding = await self.client.embeddings.create(
input=text,
model=self.model,
dimensions=self.get_dim(),
)
embedding = await self.client.embeddings.create(input=text, model=self.model)
return embedding.data[0].embedding
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入"""
embeddings = await self.client.embeddings.create(
input=text,
model=self.model,
dimensions=self.get_dim(),
)
embeddings = await self.client.embeddings.create(input=text, model=self.model)
return [item.embedding for item in embeddings.data]
def get_dim(self) -> int:
+62 -13
View File
@@ -1,19 +1,68 @@
# 兼容导出: Provider 从 provider 模块重新导出
from astrbot.core import html_renderer
from astrbot.core.provider import Provider
from astrbot.core.star.star_tools import StarTools
from astrbot.core.utils.command_parser import CommandParserMixin
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
from .base import Star
from .context import Context
from .star import StarMetadata, star_map, star_registry
from .star_manager import PluginManager
from .star_tools import StarTools
__all__ = [
"Context",
"PluginManager",
"Provider",
"Star",
"StarMetadata",
"StarTools",
"star_map",
"star_registry",
]
class Star(CommandParserMixin, PluginKVStoreMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
author: str
name: str
def __init__(self, context: Context, config: dict | None = None) -> None:
StarTools.initialize(context)
self.context = context
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not star_map.get(cls.__module__):
metadata = StarMetadata(
star_cls_type=cls,
module_path=cls.__module__,
)
star_map[cls.__module__] = metadata
star_registry.append(metadata)
else:
star_map[cls.__module__].star_cls_type = cls
star_map[cls.__module__].module_path = cls.__module__
async def text_to_image(self, text: str, return_url=True) -> str:
"""将文本转换为图片"""
return await html_renderer.render_t2i(
text,
return_url=return_url,
template_name=self.context._config.get("t2i_active_template"),
)
async def html_render(
self,
tmpl: str,
data: dict,
return_url=True,
options: dict | None = None,
) -> str:
"""渲染 HTML"""
return await html_renderer.render_custom_template(
tmpl,
data,
return_url=return_url,
options=options,
)
async def initialize(self) -> None:
"""当插件被激活时会调用这个方法"""
async def terminate(self) -> None:
"""当插件被禁用、重载插件时会调用这个方法"""
def __del__(self) -> None:
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
__all__ = ["Context", "PluginManager", "Provider", "Star", "StarMetadata", "StarTools"]
-87
View File
@@ -1,87 +0,0 @@
from __future__ import annotations
import logging
from typing import Any, Protocol
from astrbot.core import html_renderer
from astrbot.core.utils.command_parser import CommandParserMixin
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
from .star import StarMetadata, star_map, star_registry
logger = logging.getLogger("astrbot")
class Star(CommandParserMixin, PluginKVStoreMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
author: str
name: str
class _ContextLike(Protocol):
def get_config(self, umo: str | None = None) -> Any: ...
def __init__(self, context: _ContextLike, config: dict | None = None) -> None:
self.context = context
def _get_context_config(self) -> Any:
get_config = getattr(self.context, "get_config", None)
if callable(get_config):
try:
return get_config()
except Exception as e:
logger.debug(f"get_config() failed: {e}")
return None
return getattr(self.context, "_config", None)
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not star_map.get(cls.__module__):
metadata = StarMetadata(
star_cls_type=cls,
module_path=cls.__module__,
)
star_map[cls.__module__] = metadata
star_registry.append(metadata)
else:
star_map[cls.__module__].star_cls_type = cls
star_map[cls.__module__].module_path = cls.__module__
async def text_to_image(self, text: str, return_url=True) -> str:
"""将文本转换为图片"""
config_obj = self._get_context_config()
template_name = None
if hasattr(config_obj, "get"):
try:
template_name = config_obj.get("t2i_active_template")
except Exception:
template_name = None
return await html_renderer.render_t2i(
text,
return_url=return_url,
template_name=template_name,
)
async def html_render(
self,
tmpl: str,
data: dict,
return_url=True,
options: dict | None = None,
) -> str:
"""渲染 HTML"""
return await html_renderer.render_custom_template(
tmpl,
data,
return_url=return_url,
options=options,
)
async def initialize(self) -> None:
"""当插件被激活时会调用这个方法"""
async def terminate(self) -> None:
"""当插件被禁用、重载插件时会调用这个方法"""
def __del__(self) -> None:
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
+4 -16
View File
@@ -1,9 +1,7 @@
from __future__ import annotations
import logging
from asyncio import Queue
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, Protocol
from typing import Any
from deprecated import deprecated
@@ -14,12 +12,14 @@ from astrbot.core.agent.tool import ToolSet
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.cron.manager import CronJobManager
from astrbot.core.db import BaseDatabase
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.platform import Platform
from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType
from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager
@@ -45,15 +45,6 @@ from .star_handler import EventType, StarHandlerMetadata, star_handlers_registry
logger = logging.getLogger("astrbot")
if TYPE_CHECKING:
from astrbot.core.cron.manager import CronJobManager
else:
CronJobManager = Any
class PlatformManagerProtocol(Protocol):
platform_insts: list[Platform]
class Context:
"""暴露给插件的接口上下文。"""
@@ -70,7 +61,7 @@ class Context:
config: AstrBotConfig,
db: BaseDatabase,
provider_manager: ProviderManager,
platform_manager: PlatformManagerProtocol,
platform_manager: PlatformManager,
conversation_manager: ConversationManager,
message_history_manager: PlatformMessageHistoryManager,
persona_manager: PersonaManager,
@@ -457,9 +448,6 @@ class Context:
if platform.meta().id == session.platform_name:
await platform.send_by_session(session, message_chain)
return True
logger.warning(
f"cannot find platform for session {str(session)}, message not sent"
)
return False
def add_llm_tools(self, *tools: FunctionTool) -> None:
+1 -1
View File
@@ -1,6 +1,6 @@
import warnings
from astrbot.core.star.star import StarMetadata, star_map
from astrbot.core.star import StarMetadata, star_map
_warned_register_star = False
+5 -4
View File
@@ -11,6 +11,7 @@ from astrbot.core.agent.agent import Agent
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
@@ -616,7 +617,7 @@ class RegisteringAgent:
kwargs["registering_agent"] = self
return register_llm_tool(*args, **kwargs)
def __init__(self, agent: Agent[Any]) -> None:
def __init__(self, agent: Agent[AstrAgentContext]) -> None:
self._agent = agent
@@ -624,7 +625,7 @@ def register_agent(
name: str,
instruction: str,
tools: list[str | FunctionTool] | None = None,
run_hooks: BaseAgentRunHooks[Any] | None = None,
run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
):
"""注册一个 Agent
@@ -638,12 +639,12 @@ def register_agent(
tools_ = tools or []
def decorator(awaitable: Callable[..., Awaitable[Any]]):
AstrAgent = Agent[Any]
AstrAgent = Agent[AstrAgentContext]
agent = AstrAgent(
name=name,
instructions=instruction,
tools=tools_,
run_hooks=run_hooks or BaseAgentRunHooks[Any](),
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
)
handoff_tool = HandoffTool(agent=agent)
handoff_tool.handler = awaitable
-16
View File
@@ -105,22 +105,6 @@ class StarHandlerRegistry(Generic[T]):
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnPluginLoadedEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnPluginUnloadedEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
+10 -44
View File
@@ -49,13 +49,10 @@ class PluginVersionIncompatibleError(Exception):
class PluginManager:
def __init__(self, context: Context, config: AstrBotConfig) -> None:
from .star_tools import StarTools
self.updator = PluginUpdator()
self.context = context
self.context._star_manager = self # type: ignore
StarTools.initialize(context)
self.config = config
self.plugin_store_path = get_astrbot_plugin_path()
@@ -388,33 +385,6 @@ class PluginManager:
except KeyError:
logger.warning(f"模块 {module_name} 未载入")
def _cleanup_plugin_state(self, dir_name: str) -> None:
plugin_root_name = "data.plugins."
# 清理 sys.modules
for key in list(sys.modules.keys()):
if key.startswith(f"{plugin_root_name}{dir_name}"):
logger.info(f"清除了插件{dir_name}中的{key}模块")
del sys.modules[key]
possible_paths = [
f"{plugin_root_name}{dir_name}.main",
f"{plugin_root_name}{dir_name}.{dir_name}",
]
# 清理 handlers
for path in possible_paths:
handlers = star_handlers_registry.get_handlers_by_module_name(path)
for handler in handlers:
star_handlers_registry.remove(handler)
logger.info(f"清理处理器: {handler.handler_name}")
# 清理工具
for tool in list(llm_tools.func_list):
if tool.handler_module_path in possible_paths:
llm_tools.func_list.remove(tool)
logger.info(f"清理工具: {tool.name}")
async def reload_failed_plugin(self, dir_name):
"""
重新加载未注册加载失败的插件
@@ -425,21 +395,17 @@ class PluginManager:
- success (bool): 重载是否成功
- error_message (str|None): 错误信息成功时为 None
"""
async with self._pm_lock:
if dir_name not in self.failed_plugin_dict:
return False, "插件不存在于失败列表中"
self._cleanup_plugin_state(dir_name)
success, error = await self.load(specified_dir_name=dir_name)
if success:
self.failed_plugin_dict.pop(dir_name, None)
if not self.failed_plugin_dict:
self.failed_plugin_info = ""
return success, None
else:
return False, error
if dir_name in self.failed_plugin_dict:
success, error = await self.load(specified_dir_name=dir_name)
if success:
self.failed_plugin_dict.pop(dir_name, None)
if not self.failed_plugin_dict:
self.failed_plugin_info = ""
return success, None
else:
return False, error
return False, "插件不存在于失败列表中"
async def reload(self, specified_plugin_name=None):
"""重新加载插件
+1 -21
View File
@@ -1,5 +1,4 @@
from datetime import datetime
from typing import Any
from pydantic import Field
from pydantic.dataclasses import dataclass
@@ -9,14 +8,6 @@ from astrbot.core.agent.tool import FunctionTool, ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
def _extract_job_session(job: Any) -> str | None:
payload = getattr(job, "payload", None)
if not isinstance(payload, dict):
return None
session = payload.get("session")
return str(session) if session is not None else None
@dataclass
class CreateActiveCronTool(FunctionTool[AstrAgentContext]):
name: str = "create_future_task"
@@ -128,15 +119,9 @@ class DeleteCronJobTool(FunctionTool[AstrAgentContext]):
cron_mgr = context.context.context.cron_manager
if cron_mgr is None:
return "error: cron manager is not available."
current_umo = context.context.event.unified_msg_origin
job_id = kwargs.get("job_id")
if not job_id:
return "error: job_id is required."
job = await cron_mgr.db.get_cron_job(str(job_id))
if not job:
return f"error: cron job {job_id} not found."
if _extract_job_session(job) != current_umo:
return "error: you can only delete future tasks in the current umo."
await cron_mgr.delete_job(str(job_id))
return f"Deleted cron job {job_id}."
@@ -163,13 +148,8 @@ class ListCronJobsTool(FunctionTool[AstrAgentContext]):
cron_mgr = context.context.context.cron_manager
if cron_mgr is None:
return "error: cron manager is not available."
current_umo = context.context.event.unified_msg_origin
job_type = kwargs.get("job_type")
jobs = [
job
for job in await cron_mgr.list_jobs(job_type)
if _extract_job_session(job) == current_umo
]
jobs = await cron_mgr.list_jobs(job_type)
if not jobs:
return "No cron jobs found."
lines = []
@@ -19,7 +19,7 @@ from astrbot.core.message.components import (
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
from .image_refs import looks_like_image_file_name
from .image_refs import looks_like_image_file_name, normalize_file_like_url
from .settings import SETTINGS, QuotedMessageParserSettings
_FORWARD_PLACEHOLDER_PATTERN = re.compile(
@@ -296,11 +296,11 @@ def _parse_onebot_segments(
or "file"
)
text_parts.append(f"[File:{file_name}]")
candidate_url = seg_data.get("url", "")
candidate_url = seg_data.get("url")
if (
isinstance(candidate_url, str)
and candidate_url.strip()
and looks_like_image_file_name(candidate_url)
and looks_like_image_file_name(normalize_file_like_url(candidate_url))
):
image_refs.append(candidate_url.strip())
candidate_file = seg_data.get("file")
@@ -308,7 +308,11 @@ def _parse_onebot_segments(
isinstance(candidate_file, str)
and candidate_file.strip()
and looks_like_image_file_name(
seg_data.get("name") or seg_data.get("file_name") or candidate_file
normalize_file_like_url(
seg_data.get("name")
or seg_data.get("file_name")
or candidate_file
)
)
):
image_refs.append(candidate_file.strip())
@@ -364,9 +368,7 @@ def _extract_text_forward_ids_and_images_from_forward_nodes(
if not isinstance(node, dict):
continue
sender = node.get("sender")
if not isinstance(sender, dict):
sender = {}
sender = node.get("sender") if isinstance(node.get("sender"), dict) else {}
sender_name = (
sender.get("nickname")
or sender.get("card")
@@ -1,7 +1,6 @@
from __future__ import annotations
from collections.abc import Awaitable
from typing import Any, Protocol
from typing import Any
from astrbot import logger
from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -18,10 +17,6 @@ def _unwrap_action_response(ret: dict[str, Any] | None) -> dict[str, Any]:
return ret
class CallAction(Protocol):
def __call__(self, action: str, **params: Any) -> Awaitable[Any] | Any: ...
class OneBotClient:
def __init__(
self,
@@ -32,7 +27,7 @@ class OneBotClient:
self._settings = settings
@staticmethod
def _resolve_call_action(event: AstrMessageEvent) -> CallAction | None:
def _resolve_call_action(event: AstrMessageEvent):
bot = getattr(event, "bot", None)
api = getattr(bot, "api", None)
call_action = getattr(api, "call_action", None)
+26 -92
View File
@@ -1,6 +1,5 @@
import asyncio
import json
import mimetypes
import os
import re
import uuid
@@ -14,6 +13,12 @@ from astrbot.core import logger, sp
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.platform.message_type import MessageType
from astrbot.core.platform.sources.webchat.message_parts_helper import (
build_webchat_message_parts,
create_attachment_part_from_existing_file,
strip_message_parts_path_fields,
webchat_message_parts_have_content,
)
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.active_event_registry import active_event_registry
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
@@ -166,83 +171,24 @@ class ChatRoute(Route):
)
async def _build_user_message_parts(self, message: str | list) -> list[dict]:
"""构建用户消息的部分列表
Args:
message: 文本消息 (str) 或消息段列表 (list)
"""
parts = []
if isinstance(message, list):
for part in message:
part_type = part.get("type")
if part_type == "plain":
parts.append({"type": "plain", "text": part.get("text", "")})
elif part_type == "reply":
parts.append(
{
"type": "reply",
"message_id": part.get("message_id"),
"selected_text": part.get("selected_text", ""),
}
)
elif attachment_id := part.get("attachment_id"):
attachment = await self.db.get_attachment_by_id(attachment_id)
if attachment:
parts.append(
{
"type": attachment.type,
"attachment_id": attachment.attachment_id,
"filename": os.path.basename(attachment.path),
"path": attachment.path, # will be deleted
}
)
return parts
if message:
parts.append({"type": "plain", "text": message})
return parts
"""构建用户消息的部分列表"""
return await build_webchat_message_parts(
message,
get_attachment_by_id=self.db.get_attachment_by_id,
strict=False,
)
async def _create_attachment_from_file(
self, filename: str, attach_type: str
) -> dict | None:
"""从本地文件创建 attachment 并返回消息部分
用于处理 bot 回复中的媒体文件
Args:
filename: 存储的文件名
attach_type: 附件类型 (image, record, file, video)
"""
basename = os.path.basename(filename)
candidate_paths = [
os.path.join(self.attachments_dir, basename),
os.path.join(self.legacy_img_dir, basename),
]
file_path = next((p for p in candidate_paths if os.path.exists(p)), None)
if not file_path:
return None
# guess mime type
mime_type, _ = mimetypes.guess_type(filename)
if not mime_type:
mime_type = "application/octet-stream"
# insert attachment
attachment = await self.db.insert_attachment(
path=file_path,
type=attach_type,
mime_type=mime_type,
"""从本地文件创建 attachment 并返回消息部分"""
return await create_attachment_part_from_existing_file(
filename,
attach_type=attach_type,
insert_attachment=self.db.insert_attachment,
attachments_dir=self.attachments_dir,
fallback_dirs=[self.legacy_img_dir],
)
if not attachment:
return None
return {
"type": attach_type,
"attachment_id": attachment.attachment_id,
"filename": os.path.basename(file_path),
}
def _extract_web_search_refs(
self, accumulated_text: str, accumulated_parts: list
@@ -356,21 +302,6 @@ class ChatRoute(Route):
selected_model = post_data.get("selected_model")
enable_streaming = post_data.get("enable_streaming", True)
# 检查消息是否为空
if isinstance(message, list):
has_content = any(
part.get("type") in ("plain", "image", "record", "file", "video")
for part in message
)
if not has_content:
return (
Response()
.error("Message content is empty (reply only is not allowed)")
.__dict__
)
elif not message:
return Response().error("Message are both empty").__dict__
if not session_id:
return Response().error("session_id is empty").__dict__
@@ -378,6 +309,12 @@ class ChatRoute(Route):
# 构建用户消息段(包含 path 用于传递给 adapter
message_parts = await self._build_user_message_parts(message)
if not webchat_message_parts_have_content(message_parts):
return (
Response()
.error("Message content is empty (reply only is not allowed)")
.__dict__
)
message_id = str(uuid.uuid4())
back_queue = webchat_queue_mgr.get_or_create_back_queue(
@@ -583,10 +520,7 @@ class ChatRoute(Route):
),
)
message_parts_for_storage = []
for part in message_parts:
part_copy = {k: v for k, v in part.items() if k != "path"}
message_parts_for_storage.append(part_copy)
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
await self.platform_history_mgr.insert(
platform_id="webchat",
+1 -17
View File
@@ -754,22 +754,6 @@ class ConfigRoute(Route):
if not provider_type:
return Response().error("provider_config 缺少 type 字段").__dict__
# 首次添加某类提供商时,provider_cls_map 可能尚未注册该适配器
if provider_type not in provider_cls_map:
try:
self.core_lifecycle.provider_manager.dynamic_import_provider(
provider_type,
)
except ImportError:
logger.error(traceback.format_exc())
return (
Response()
.error(
"提供商适配器加载失败,请检查提供商类型配置或查看服务端日志"
)
.__dict__
)
# 获取对应的 provider 类
if provider_type not in provider_cls_map:
return (
@@ -795,7 +779,7 @@ class ConfigRoute(Route):
if inspect.iscoroutinefunction(init_fn):
await init_fn()
# 通过实际请求验证当前 embedding_dimensions 是否可用
# 获取嵌入向量维度
vec = await inst.get_embedding("echo")
dim = len(vec)
+1 -3
View File
@@ -148,6 +148,7 @@ class ConversationRoute(Route):
user_id = data.get("user_id")
cid = data.get("cid")
title = data.get("title")
persona_id = data.get("persona_id", "")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
@@ -157,9 +158,6 @@ class ConversationRoute(Route):
)
if not conversation:
return Response().error("对话不存在").__dict__
persona_id = data.get("persona_id", conversation.persona_id)
if title is not None or persona_id is not None:
await self.conv_mgr.update_conversation(
unified_msg_origin=user_id,
+509 -3
View File
@@ -1,6 +1,7 @@
import asyncio
import json
import os
import re
import time
import uuid
import wave
@@ -10,9 +11,16 @@ import jwt
from quart import websocket
from astrbot import logger
from astrbot.core import sp
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.sources.webchat.message_parts_helper import (
build_webchat_message_parts,
create_attachment_part_from_existing_file,
strip_message_parts_path_fields,
webchat_message_parts_have_content,
)
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
from .route import Route, RouteContext
@@ -30,6 +38,9 @@ class LiveChatSession:
self.audio_frames: list[bytes] = []
self.current_stamp: str | None = None
self.temp_audio_path: str | None = None
self.chat_subscriptions: dict[str, str] = {}
self.chat_subscription_tasks: dict[str, asyncio.Task] = {}
self.ws_send_lock = asyncio.Lock()
def start_speaking(self, stamp: str) -> None:
"""开始说话"""
@@ -106,13 +117,26 @@ class LiveChatRoute(Route):
self.core_lifecycle = core_lifecycle
self.db = db
self.plugin_manager = core_lifecycle.plugin_manager
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
self.sessions: dict[str, LiveChatSession] = {}
self.attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
self.legacy_img_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
os.makedirs(self.attachments_dir, exist_ok=True)
# 注册 WebSocket 路由
self.app.websocket("/api/live_chat/ws")(self.live_chat_ws)
self.app.websocket("/api/unified_chat/ws")(self.unified_chat_ws)
async def live_chat_ws(self) -> None:
"""Live Chat WebSocket 处理器"""
"""Legacy Live Chat WebSocket 处理器(默认 ct=live"""
await self._unified_ws_loop(force_ct="live")
async def unified_chat_ws(self) -> None:
"""Unified Chat WebSocket 处理器(支持 ct=live/chat"""
await self._unified_ws_loop(force_ct=None)
async def _unified_ws_loop(self, force_ct: str | None = None) -> None:
"""统一 WebSocket 循环"""
# WebSocket 不能通过 header 传递 token,需要从 query 参数获取
# 注意:WebSocket 上下文使用 websocket.args 而不是 request.args
token = websocket.args.get("token")
@@ -140,7 +164,11 @@ class LiveChatRoute(Route):
try:
while True:
message = await websocket.receive_json()
await self._handle_message(live_session, message)
ct = force_ct or message.get("ct", "live")
if ct == "chat":
await self._handle_chat_message(live_session, message)
else:
await self._handle_message(live_session, message)
except Exception as e:
logger.error(f"[Live Chat] WebSocket 错误: {e}", exc_info=True)
@@ -148,10 +176,488 @@ class LiveChatRoute(Route):
finally:
# 清理会话
if session_id in self.sessions:
await self._cleanup_chat_subscriptions(live_session)
live_session.cleanup()
del self.sessions[session_id]
logger.info(f"[Live Chat] WebSocket 连接关闭: {username}")
async def _create_attachment_from_file(
self, filename: str, attach_type: str
) -> dict | None:
"""从本地文件创建 attachment 并返回消息部分。"""
return await create_attachment_part_from_existing_file(
filename,
attach_type=attach_type,
insert_attachment=self.db.insert_attachment,
attachments_dir=self.attachments_dir,
fallback_dirs=[self.legacy_img_dir],
)
def _extract_web_search_refs(
self, accumulated_text: str, accumulated_parts: list
) -> dict:
"""从消息中提取 web_search 引用。"""
supported = ["web_search_tavily", "web_search_bocha"]
web_search_results = {}
tool_call_parts = [
p
for p in accumulated_parts
if p.get("type") == "tool_call" and p.get("tool_calls")
]
for part in tool_call_parts:
for tool_call in part["tool_calls"]:
if tool_call.get("name") not in supported or not tool_call.get(
"result"
):
continue
try:
result_data = json.loads(tool_call["result"])
for item in result_data.get("results", []):
if idx := item.get("index"):
web_search_results[idx] = {
"url": item.get("url"),
"title": item.get("title"),
"snippet": item.get("snippet"),
}
except (json.JSONDecodeError, KeyError):
pass
if not web_search_results:
return {}
ref_indices = {
m.strip() for m in re.findall(r"<ref>(.*?)</ref>", accumulated_text)
}
used_refs = []
for ref_index in ref_indices:
if ref_index not in web_search_results:
continue
payload = {"index": ref_index, **web_search_results[ref_index]}
if favicon := sp.temporary_cache.get("_ws_favicon", {}).get(payload["url"]):
payload["favicon"] = favicon
used_refs.append(payload)
return {"used": used_refs} if used_refs else {}
async def _save_bot_message(
self,
webchat_conv_id: str,
text: str,
media_parts: list,
reasoning: str,
agent_stats: dict,
refs: dict,
):
"""保存 bot 消息到历史记录。"""
bot_message_parts = []
bot_message_parts.extend(media_parts)
if text:
bot_message_parts.append({"type": "plain", "text": text})
new_his = {"type": "bot", "message": bot_message_parts}
if reasoning:
new_his["reasoning"] = reasoning
if agent_stats:
new_his["agent_stats"] = agent_stats
if refs:
new_his["refs"] = refs
return await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=webchat_conv_id,
content=new_his,
sender_id="bot",
sender_name="bot",
)
async def _send_chat_payload(self, session: LiveChatSession, payload: dict) -> None:
async with session.ws_send_lock:
await websocket.send_json(payload)
async def _forward_chat_subscription(
self,
session: LiveChatSession,
chat_session_id: str,
request_id: str,
) -> None:
back_queue = webchat_queue_mgr.get_or_create_back_queue(
request_id, chat_session_id
)
try:
while True:
result = await back_queue.get()
if not result:
continue
await self._send_chat_payload(session, {"ct": "chat", **result})
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(
f"[Live Chat] chat subscription forward failed ({chat_session_id}): {e}",
exc_info=True,
)
finally:
webchat_queue_mgr.remove_back_queue(request_id)
if session.chat_subscriptions.get(chat_session_id) == request_id:
session.chat_subscriptions.pop(chat_session_id, None)
session.chat_subscription_tasks.pop(chat_session_id, None)
async def _ensure_chat_subscription(
self,
session: LiveChatSession,
chat_session_id: str,
) -> str:
existing_request_id = session.chat_subscriptions.get(chat_session_id)
existing_task = session.chat_subscription_tasks.get(chat_session_id)
if existing_request_id and existing_task and not existing_task.done():
return existing_request_id
request_id = f"ws_sub_{uuid.uuid4().hex}"
session.chat_subscriptions[chat_session_id] = request_id
task = asyncio.create_task(
self._forward_chat_subscription(session, chat_session_id, request_id),
name=f"chat_ws_sub_{chat_session_id}",
)
session.chat_subscription_tasks[chat_session_id] = task
return request_id
async def _cleanup_chat_subscriptions(self, session: LiveChatSession) -> None:
tasks = list(session.chat_subscription_tasks.values())
for task in tasks:
task.cancel()
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
for request_id in list(session.chat_subscriptions.values()):
webchat_queue_mgr.remove_back_queue(request_id)
session.chat_subscriptions.clear()
session.chat_subscription_tasks.clear()
async def _handle_chat_message(
self, session: LiveChatSession, message: dict
) -> None:
"""处理 Chat Mode 消息(ct=chat"""
msg_type = message.get("t")
if msg_type == "bind":
chat_session_id = message.get("session_id")
if not isinstance(chat_session_id, str) or not chat_session_id:
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "session_id is required",
"code": "INVALID_MESSAGE_FORMAT",
},
)
return
request_id = await self._ensure_chat_subscription(session, chat_session_id)
await self._send_chat_payload(
session,
{
"ct": "chat",
"type": "session_bound",
"session_id": chat_session_id,
"message_id": request_id,
},
)
return
if msg_type == "interrupt":
session.should_interrupt = True
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "INTERRUPTED",
"code": "INTERRUPTED",
},
)
return
if msg_type != "send":
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": f"Unsupported message type: {msg_type}",
"code": "INVALID_MESSAGE_FORMAT",
},
)
return
if session.is_processing:
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "Session is busy",
"code": "PROCESSING_ERROR",
},
)
return
payload = message.get("message")
session_id = message.get("session_id") or session.session_id
message_id = message.get("message_id") or str(uuid.uuid4())
selected_provider = message.get("selected_provider")
selected_model = message.get("selected_model")
selected_stt_provider = message.get("selected_stt_provider")
selected_tts_provider = message.get("selected_tts_provider")
persona_prompt = message.get("persona_prompt")
show_reasoning = message.get("show_reasoning")
enable_streaming = message.get("enable_streaming", True)
if not isinstance(payload, list):
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "message must be list",
"code": "INVALID_MESSAGE_FORMAT",
},
)
return
message_parts = await self._build_chat_message_parts(payload)
has_content = webchat_message_parts_have_content(message_parts)
if not has_content:
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "Message content is empty",
"code": "INVALID_MESSAGE_FORMAT",
},
)
return
await self._ensure_chat_subscription(session, session_id)
session.is_processing = True
session.should_interrupt = False
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
try:
chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
await chat_queue.put(
(
session.username,
session_id,
{
"message": message_parts,
"selected_provider": selected_provider,
"selected_model": selected_model,
"selected_stt_provider": selected_stt_provider,
"selected_tts_provider": selected_tts_provider,
"persona_prompt": persona_prompt,
"show_reasoning": show_reasoning,
"enable_streaming": enable_streaming,
"message_id": message_id,
},
),
)
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=session_id,
content={"type": "user", "message": message_parts_for_storage},
sender_id=session.username,
sender_name=session.username,
)
accumulated_parts = []
accumulated_text = ""
accumulated_reasoning = ""
tool_calls = {}
agent_stats = {}
refs = {}
while True:
if session.should_interrupt:
session.should_interrupt = False
break
try:
result = await asyncio.wait_for(back_queue.get(), timeout=1)
except asyncio.TimeoutError:
continue
if not result:
continue
if result.get("message_id") and result.get("message_id") != message_id:
continue
result_text = result.get("data", "")
msg_type = result.get("type")
streaming = result.get("streaming", False)
chain_type = result.get("chain_type")
if chain_type == "agent_stats":
try:
parsed_agent_stats = json.loads(result_text)
agent_stats = parsed_agent_stats
await self._send_chat_payload(
session,
{
"ct": "chat",
"type": "agent_stats",
"data": parsed_agent_stats,
},
)
except Exception:
pass
continue
outgoing = {"ct": "chat", **result}
await self._send_chat_payload(session, outgoing)
if msg_type == "plain":
if chain_type == "tool_call":
try:
tool_call = json.loads(result_text)
tool_calls[tool_call.get("id")] = tool_call
if accumulated_text:
accumulated_parts.append(
{"type": "plain", "text": accumulated_text}
)
accumulated_text = ""
except Exception:
pass
elif chain_type == "tool_call_result":
try:
tcr = json.loads(result_text)
tc_id = tcr.get("id")
if tc_id in tool_calls:
tool_calls[tc_id]["result"] = tcr.get("result")
tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
accumulated_parts.append(
{
"type": "tool_call",
"tool_calls": [tool_calls[tc_id]],
}
)
tool_calls.pop(tc_id, None)
except Exception:
pass
elif chain_type == "reasoning":
accumulated_reasoning += result_text
elif streaming:
accumulated_text += result_text
else:
accumulated_text = result_text
elif msg_type == "image":
filename = str(result_text).replace("[IMAGE]", "")
part = await self._create_attachment_from_file(filename, "image")
if part:
accumulated_parts.append(part)
elif msg_type == "record":
filename = str(result_text).replace("[RECORD]", "")
part = await self._create_attachment_from_file(filename, "record")
if part:
accumulated_parts.append(part)
elif msg_type == "file":
filename = str(result_text).replace("[FILE]", "").split("|", 1)[0]
part = await self._create_attachment_from_file(filename, "file")
if part:
accumulated_parts.append(part)
elif msg_type == "video":
filename = str(result_text).replace("[VIDEO]", "").split("|", 1)[0]
part = await self._create_attachment_from_file(filename, "video")
if part:
accumulated_parts.append(part)
should_save = False
if msg_type == "end":
should_save = bool(
accumulated_parts
or accumulated_text
or accumulated_reasoning
or refs
or agent_stats
)
elif (streaming and msg_type == "complete") or not streaming:
if chain_type not in (
"tool_call",
"tool_call_result",
"agent_stats",
):
should_save = True
if should_save:
try:
refs = self._extract_web_search_refs(
accumulated_text,
accumulated_parts,
)
except Exception as e:
logger.exception(
f"[Live Chat] Failed to extract web search refs: {e}",
exc_info=True,
)
saved_record = await self._save_bot_message(
session_id,
accumulated_text,
accumulated_parts,
accumulated_reasoning,
agent_stats,
refs,
)
if saved_record:
await self._send_chat_payload(
session,
{
"ct": "chat",
"type": "message_saved",
"data": {
"id": saved_record.id,
"created_at": saved_record.created_at.astimezone().isoformat(),
},
},
)
accumulated_parts = []
accumulated_text = ""
accumulated_reasoning = ""
agent_stats = {}
refs = {}
if msg_type == "end":
break
except Exception as e:
logger.error(f"[Live Chat] 处理 chat 消息失败: {e}", exc_info=True)
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": f"处理失败: {str(e)}",
"code": "PROCESSING_ERROR",
},
)
finally:
session.is_processing = False
webchat_queue_mgr.remove_back_queue(message_id)
async def _build_chat_message_parts(self, message: list[dict]) -> list[dict]:
"""构建 chat websocket 用户消息段(复用 webchat 逻辑)"""
return await build_webchat_message_parts(
message,
get_attachment_by_id=self.db.get_attachment_by_id,
strict=False,
)
async def _handle_message(self, session: LiveChatSession, message: dict) -> None:
"""处理 WebSocket 消息"""
msg_type = message.get("t") # 使用 t 代替 type
+360 -81
View File
@@ -1,15 +1,22 @@
from pathlib import Path
import asyncio
import hashlib
import json
from uuid import uuid4
from quart import g, request
from quart import g, request, websocket
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.message_session import MessageSesion
from astrbot.core.platform.sources.webchat.message_parts_helper import (
build_message_chain_from_payload,
strip_message_parts_path_fields,
webchat_message_parts_have_content,
)
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from .api_key import ALL_OPEN_API_SCOPES
from .chat import ChatRoute
from .route import Response, Route, RouteContext
@@ -37,6 +44,7 @@ class OpenApiRoute(Route):
"/v1/im/bots": ("GET", self.get_bots),
}
self.register_routes()
self.app.websocket("/api/v1/chat/ws")(self.chat_ws)
@staticmethod
def _resolve_open_username(
@@ -181,6 +189,348 @@ class OpenApiRoute(Route):
finally:
g.username = original_username
@staticmethod
def _extract_ws_api_key() -> str | None:
if key := websocket.args.get("api_key"):
return key.strip()
if key := websocket.args.get("key"):
return key.strip()
if key := websocket.headers.get("X-API-Key"):
return key.strip()
auth_header = websocket.headers.get("Authorization", "").strip()
if auth_header.startswith("Bearer "):
return auth_header.removeprefix("Bearer ").strip()
if auth_header.startswith("ApiKey "):
return auth_header.removeprefix("ApiKey ").strip()
return None
async def _authenticate_chat_ws_api_key(self) -> tuple[bool, str | None]:
raw_key = self._extract_ws_api_key()
if not raw_key:
return False, "Missing API key"
key_hash = hashlib.pbkdf2_hmac(
"sha256",
raw_key.encode("utf-8"),
b"astrbot_api_key",
100_000,
).hex()
api_key = await self.db.get_active_api_key_by_hash(key_hash)
if not api_key:
return False, "Invalid API key"
if isinstance(api_key.scopes, list):
scopes = api_key.scopes
else:
scopes = list(ALL_OPEN_API_SCOPES)
if "*" not in scopes and "chat" not in scopes:
return False, "Insufficient API key scope"
await self.db.touch_api_key(api_key.key_id)
return True, None
async def _send_chat_ws_error(self, message: str, code: str) -> None:
await websocket.send_json(
{
"type": "error",
"code": code,
"data": message,
}
)
async def _update_session_config_route(
self,
*,
username: str,
session_id: str,
config_id: str | None,
) -> str | None:
if not config_id:
return None
umo = f"webchat:FriendMessage:webchat!{username}!{session_id}"
try:
if config_id == "default":
await self.core_lifecycle.umop_config_router.delete_route(umo)
else:
await self.core_lifecycle.umop_config_router.update_route(
umo, config_id
)
except Exception as e:
logger.error(
"Failed to update chat config route for %s with %s: %s",
umo,
config_id,
e,
exc_info=True,
)
return f"Failed to update chat config route: {e}"
return None
async def _handle_chat_ws_send(self, post_data: dict) -> None:
effective_username, username_err = self._resolve_open_username(
post_data.get("username")
)
if username_err or not effective_username:
await self._send_chat_ws_error(
username_err or "Invalid username", "BAD_USER"
)
return
message = post_data.get("message")
if message is None:
await self._send_chat_ws_error("Missing key: message", "INVALID_MESSAGE")
return
raw_session_id = post_data.get("session_id", post_data.get("conversation_id"))
session_id = str(raw_session_id).strip() if raw_session_id is not None else ""
if not session_id:
session_id = str(uuid4())
ensure_session_err = await self._ensure_chat_session(
effective_username,
session_id,
)
if ensure_session_err:
await self._send_chat_ws_error(ensure_session_err, "SESSION_ERROR")
return
config_id, resolve_err = self._resolve_chat_config_id(post_data)
if resolve_err:
await self._send_chat_ws_error(resolve_err, "CONFIG_ERROR")
return
config_err = await self._update_session_config_route(
username=effective_username,
session_id=session_id,
config_id=config_id,
)
if config_err:
await self._send_chat_ws_error(config_err, "CONFIG_ERROR")
return
message_parts = await self.chat_route._build_user_message_parts(message)
if not webchat_message_parts_have_content(message_parts):
await self._send_chat_ws_error(
"Message content is empty (reply only is not allowed)",
"INVALID_MESSAGE",
)
return
message_id = str(post_data.get("message_id") or uuid4())
selected_provider = post_data.get("selected_provider")
selected_model = post_data.get("selected_model")
enable_streaming = post_data.get("enable_streaming", True)
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
try:
chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
await chat_queue.put(
(
effective_username,
session_id,
{
"message": message_parts,
"selected_provider": selected_provider,
"selected_model": selected_model,
"enable_streaming": enable_streaming,
"message_id": message_id,
},
)
)
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
await self.chat_route.platform_history_mgr.insert(
platform_id="webchat",
user_id=session_id,
content={"type": "user", "message": message_parts_for_storage},
sender_id=effective_username,
sender_name=effective_username,
)
await websocket.send_json(
{
"type": "session_id",
"data": None,
"session_id": session_id,
"message_id": message_id,
}
)
accumulated_parts = []
accumulated_text = ""
accumulated_reasoning = ""
tool_calls = {}
agent_stats = {}
refs = {}
while True:
try:
result = await asyncio.wait_for(back_queue.get(), timeout=1)
except asyncio.TimeoutError:
continue
if not result:
continue
if "message_id" in result and result["message_id"] != message_id:
logger.warning("openapi ws stream message_id mismatch")
continue
result_text = result.get("data", "")
msg_type = result.get("type")
streaming = result.get("streaming", False)
chain_type = result.get("chain_type")
if chain_type == "agent_stats":
try:
stats_info = {
"type": "agent_stats",
"data": json.loads(result_text),
}
await websocket.send_json(stats_info)
agent_stats = stats_info["data"]
except Exception:
pass
continue
await websocket.send_json(result)
if msg_type == "plain":
if chain_type == "tool_call":
tool_call = json.loads(result_text)
tool_calls[tool_call.get("id")] = tool_call
if accumulated_text:
accumulated_parts.append(
{"type": "plain", "text": accumulated_text}
)
accumulated_text = ""
elif chain_type == "tool_call_result":
tcr = json.loads(result_text)
tc_id = tcr.get("id")
if tc_id in tool_calls:
tool_calls[tc_id]["result"] = tcr.get("result")
tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
accumulated_parts.append(
{"type": "tool_call", "tool_calls": [tool_calls[tc_id]]}
)
tool_calls.pop(tc_id, None)
elif chain_type == "reasoning":
accumulated_reasoning += result_text
elif streaming:
accumulated_text += result_text
else:
accumulated_text = result_text
elif msg_type == "image":
filename = str(result_text).replace("[IMAGE]", "")
part = await self.chat_route._create_attachment_from_file(
filename, "image"
)
if part:
accumulated_parts.append(part)
elif msg_type == "record":
filename = str(result_text).replace("[RECORD]", "")
part = await self.chat_route._create_attachment_from_file(
filename, "record"
)
if part:
accumulated_parts.append(part)
elif msg_type == "file":
filename = str(result_text).replace("[FILE]", "")
part = await self.chat_route._create_attachment_from_file(
filename, "file"
)
if part:
accumulated_parts.append(part)
elif msg_type == "video":
filename = str(result_text).replace("[VIDEO]", "")
part = await self.chat_route._create_attachment_from_file(
filename, "video"
)
if part:
accumulated_parts.append(part)
if msg_type == "end":
break
if (streaming and msg_type == "complete") or not streaming:
if chain_type in ("tool_call", "tool_call_result"):
continue
try:
refs = self.chat_route._extract_web_search_refs(
accumulated_text,
accumulated_parts,
)
except Exception as e:
logger.exception(
f"Open API WS failed to extract web search refs: {e}",
exc_info=True,
)
saved_record = await self.chat_route._save_bot_message(
session_id,
accumulated_text,
accumulated_parts,
accumulated_reasoning,
agent_stats,
refs,
)
if saved_record:
await websocket.send_json(
{
"type": "message_saved",
"data": {
"id": saved_record.id,
"created_at": saved_record.created_at.astimezone().isoformat(),
},
"session_id": session_id,
}
)
accumulated_parts = []
accumulated_text = ""
accumulated_reasoning = ""
agent_stats = {}
refs = {}
except Exception as e:
logger.exception(f"Open API WS chat failed: {e}", exc_info=True)
await self._send_chat_ws_error(
f"Failed to process message: {e}", "PROCESSING_ERROR"
)
finally:
webchat_queue_mgr.remove_back_queue(message_id)
async def chat_ws(self) -> None:
authed, auth_err = await self._authenticate_chat_ws_api_key()
if not authed:
await self._send_chat_ws_error(auth_err or "Unauthorized", "UNAUTHORIZED")
await websocket.close(1008, auth_err or "Unauthorized")
return
try:
while True:
message = await websocket.receive_json()
if not isinstance(message, dict):
await self._send_chat_ws_error(
"message must be an object",
"INVALID_MESSAGE",
)
continue
msg_type = message.get("t", "send")
if msg_type == "ping":
await websocket.send_json({"type": "pong"})
continue
if msg_type != "send":
await self._send_chat_ws_error(
f"Unsupported message type: {msg_type}",
"INVALID_MESSAGE",
)
continue
await self._handle_chat_ws_send(message)
except Exception as e:
logger.debug("Open API WS connection closed: %s", e)
async def upload_file(self):
return await self.chat_route.post_file()
@@ -254,83 +604,12 @@ class OpenApiRoute(Route):
async def _build_message_chain_from_payload(
self,
message_payload: str | list,
) -> MessageChain:
if isinstance(message_payload, str):
text = message_payload.strip()
if not text:
raise ValueError("Message is empty")
return MessageChain(chain=[Plain(text=text)])
if not isinstance(message_payload, list):
raise ValueError("message must be a string or list")
components = []
has_content = False
for part in message_payload:
if not isinstance(part, dict):
raise ValueError("message part must be an object")
part_type = str(part.get("type", "")).strip()
if part_type == "plain":
text = str(part.get("text", ""))
if text:
has_content = True
components.append(Plain(text=text))
continue
if part_type == "reply":
message_id = part.get("message_id")
if message_id is None:
raise ValueError("reply part missing message_id")
components.append(
Reply(
id=str(message_id),
message_str=str(part.get("selected_text", "")),
chain=[],
)
)
continue
if part_type not in {"image", "record", "file", "video"}:
raise ValueError(f"unsupported message part type: {part_type}")
has_content = True
file_path: Path | None = None
resolved_type = part_type
filename = str(part.get("filename", "")).strip()
attachment_id = part.get("attachment_id")
if attachment_id:
attachment = await self.db.get_attachment_by_id(str(attachment_id))
if not attachment:
raise ValueError(f"attachment not found: {attachment_id}")
file_path = Path(attachment.path)
resolved_type = attachment.type
if not filename:
filename = file_path.name
else:
raise ValueError(f"{part_type} part missing attachment_id")
if not file_path.exists():
raise ValueError(f"file not found: {file_path!s}")
file_path_str = str(file_path.resolve())
if resolved_type == "image":
components.append(Image.fromFileSystem(file_path_str))
elif resolved_type == "record":
components.append(Record.fromFileSystem(file_path_str))
elif resolved_type == "video":
components.append(Video.fromFileSystem(file_path_str))
else:
components.append(
File(name=filename or file_path.name, file=file_path_str)
)
if not components or not has_content:
raise ValueError("Message content is empty (reply only is not allowed)")
return MessageChain(chain=components)
):
return await build_message_chain_from_payload(
message_payload,
get_attachment_by_id=self.db.get_attachment_by_id,
strict=True,
)
async def send_message(self):
post_data = await request.json or {}
+5
View File
@@ -204,6 +204,10 @@ class AstrBotDashboard:
@staticmethod
def _extract_raw_api_key() -> str | None:
if key := request.args.get("api_key"):
return key.strip()
if key := request.args.get("key"):
return key.strip()
if key := request.headers.get("X-API-Key"):
return key.strip()
auth_header = request.headers.get("Authorization", "").strip()
@@ -217,6 +221,7 @@ class AstrBotDashboard:
def _get_required_open_api_scope(path: str) -> str | None:
scope_map = {
"/api/v1/chat": "chat",
"/api/v1/chat/ws": "chat",
"/api/v1/chat/sessions": "chat",
"/api/v1/configs": "config",
"/api/v1/file": "file",
-60
View File
@@ -1,60 +0,0 @@
## What's Changed
### 新增
- 新增 Agent 会话停止能力,并优化 stop 请求处理流程,支持 /stop 指令终止 Agent 运行并尽量不丢失已运行输出的结果。 ([#5380](https://github.com/AstrBotDevs/AstrBot/issues/5380))。
- 新增 SubAgent 交接场景下的 computer-use 工具支持 ([#5399](https://github.com/AstrBotDevs/AstrBot/issues/5399))。
- 新增 Agent 执行过程中展示工具调用结果的能力,提升执行过程可观测性 ([#5388](https://github.com/AstrBotDevs/AstrBot/issues/5388))。
- 新增插件加载/卸载 Hook,扩展插件生命周期能力 ([#5331](https://github.com/AstrBotDevs/AstrBot/issues/5331))。
- 新增插件加载失败后的热重载能力,提升插件开发与恢复效率 ([#5334](https://github.com/AstrBotDevs/AstrBot/issues/5334))。
- 新增 SubAgent 图片 URL/本地路径输入支持 ([#5348](https://github.com/AstrBotDevs/AstrBot/issues/5348))。
- 新增 Dashboard 发布跳转基础 URL 可配置项 ([#5330](https://github.com/AstrBotDevs/AstrBot/issues/5330))。
### 修复
- 修复 Tavily 请求的硬编码 6 秒超时。
- 修复 OneBot v11 适配器关闭之后仍然在连接的问题([#5412](https://github.com/AstrBotDevs/AstrBot/issues/5412))。
- 修复上下文会话中平台缺失时的日志处理,补充 warning 并改进排查信息。
- 修复 embedding 维度未透传到 provider API 的问题 ([#5411](https://github.com/AstrBotDevs/AstrBot/issues/5411))。
- 修复 File 组件处理逻辑并增强 OneBot 驱动层路径兼容性 ([#5391](https://github.com/AstrBotDevs/AstrBot/issues/5391))。
- 修复 sandbox 文件传输工具缺少管理员权限校验的问题 ([#5402](https://github.com/AstrBotDevs/AstrBot/issues/5402))。
- 修复 pipeline 与 `from ... import *` 引发的循环依赖问题 ([#5353](https://github.com/AstrBotDevs/AstrBot/issues/5353))。
- 修复配置文件存在 UTF-8 BOM 时的解析问题 ([#5376](https://github.com/AstrBotDevs/AstrBot/issues/5376))。
- 修复 ChatUI 复制回滚路径缺失与错误提示不清晰的问题 ([#5352](https://github.com/AstrBotDevs/AstrBot/issues/5352))。
- 修复保留插件目录处理逻辑,避免插件目录行为异常 ([#5369](https://github.com/AstrBotDevs/AstrBot/issues/5369))。
- 修复 ChatUI 文件消息段无法持久化的问题 ([#5386](https://github.com/AstrBotDevs/AstrBot/issues/5386))。
- 修复 `.dockerignore` 误排除 `changelogs` 目录的问题。
- 修复 aiohttp 版本过新导致 qq-botpy 报错的问题 ([#5316](https://github.com/AstrBotDevs/AstrBot/issues/5316))。
### 优化
- 完成 SubAgent 编排页面国际化,补齐多语言支持 ([#5400](https://github.com/AstrBotDevs/AstrBot/issues/5400))。
- 增补消息事件处理相关测试,并完善测试框架的 fixtures/mocks 覆盖 ([#5355](https://github.com/AstrBotDevs/AstrBot/issues/5355), [#5354](https://github.com/AstrBotDevs/AstrBot/issues/5354))。
## What's Changed (EN)
### New Features
- Added computer-use tools support in sub-agent handoff scenarios ([#5399](https://github.com/AstrBotDevs/AstrBot/issues/5399)).
- Added support for displaying tool call results during agent execution for better observability ([#5388](https://github.com/AstrBotDevs/AstrBot/issues/5388)).
- Added plugin load/unload hooks to extend plugin lifecycle capabilities ([#5331](https://github.com/AstrBotDevs/AstrBot/issues/5331)).
- Added hot reload support when plugin loading fails, improving recovery during plugin development ([#5334](https://github.com/AstrBotDevs/AstrBot/issues/5334)).
- Added image URL/local path input support for sub-agents ([#5348](https://github.com/AstrBotDevs/AstrBot/issues/5348)).
- Added stop control for active agent sessions and improved stop request handling ([#5380](https://github.com/AstrBotDevs/AstrBot/issues/5380)).
- Added configurable base URL for dashboard release redirects ([#5330](https://github.com/AstrBotDevs/AstrBot/issues/5330)).
### Fixes
- Fixed logging behavior when platform information is missing in context sessions, with clearer warning and diagnostics.
- Fixed missing embedding dimensions being passed to provider APIs ([#5411](https://github.com/AstrBotDevs/AstrBot/issues/5411)).
- Fixed shutdown stability issues in the aiocqhttp adapter ([#5412](https://github.com/AstrBotDevs/AstrBot/issues/5412)).
- Fixed File component handling and improved path compatibility in the OneBot driver layer ([#5391](https://github.com/AstrBotDevs/AstrBot/issues/5391)).
- Fixed missing admin guard for sandbox file transfer tools ([#5402](https://github.com/AstrBotDevs/AstrBot/issues/5402)).
- Fixed circular import issues related to pipeline and `from ... import *` usage ([#5353](https://github.com/AstrBotDevs/AstrBot/issues/5353)).
- Fixed config parsing issues when files contain UTF-8 BOM ([#5376](https://github.com/AstrBotDevs/AstrBot/issues/5376)).
- Fixed missing copy rollback path and unclear error messaging in ChatUI ([#5352](https://github.com/AstrBotDevs/AstrBot/issues/5352)).
- Fixed reserved plugin directory handling to avoid abnormal plugin path behavior ([#5369](https://github.com/AstrBotDevs/AstrBot/issues/5369)).
- Fixed ChatUI file segment persistence issues ([#5386](https://github.com/AstrBotDevs/AstrBot/issues/5386)).
- Fixed accidental exclusion of the `changelogs` directory in `.dockerignore`.
- Fixed compatibility issues caused by a hard-coded 6-second timeout in Tavily requests.
- Fixed qq-botpy runtime errors caused by overly new aiohttp versions ([#5316](https://github.com/AstrBotDevs/AstrBot/issues/5316)).
### Improvements
- Completed internationalization for the sub-agent orchestration page ([#5400](https://github.com/AstrBotDevs/AstrBot/issues/5400)).
- Added broader message-event test coverage and improved fixtures/mocks in the test framework ([#5355](https://github.com/AstrBotDevs/AstrBot/issues/5355), [#5354](https://github.com/AstrBotDevs/AstrBot/issues/5354)).
- Updated README content and applied repository-wide formatting cleanup (ruff format) ([#5375](https://github.com/AstrBotDevs/AstrBot/issues/5375)).
-49
View File
@@ -1,49 +0,0 @@
## What's Changed
### 新增
- 新增桌面端通用更新桥接能力,便于接入客户端内更新流程 ([#5424](https://github.com/AstrBotDevs/AstrBot/issues/5424))。
### 修复
- 修复新增平台对话框中 Line 适配器未显示的问题。
- 修复 Telegram 无法发送 Video 的问题 ([#5430](https://github.com/AstrBotDevs/AstrBot/issues/5430))。
- 修复创建 embedding provider 时无法自动识别向量维度的问题 ([#5442](https://github.com/AstrBotDevs/AstrBot/issues/5442))。
- 修复 QQ 官方平台发送媒体消息时 markdown 字段未清理的问题 ([#5445](https://github.com/AstrBotDevs/AstrBot/issues/5445))。
- 修复上下文管理策略 -> 上下文截断时 tool call / response 配对丢失的问题 ([#5417](https://github.com/AstrBotDevs/AstrBot/issues/5417))。
- 修复会话更新时 `persona_id` 被覆盖的问题,并增强 persona 解析逻辑。
- 修复 WebUI 中 GitHub 代理地址显示异常的问题 ([#5438](https://github.com/AstrBotDevs/AstrBot/issues/5438))。
- 修复设置页新建开发者 API Key 后复制失败的问题 ([#5439](https://github.com/AstrBotDevs/AstrBot/issues/5439))。
- 修复 Telegram 语音消息格式与 OpenAI STT 兼容性问题(使用 OGG ([#5389](https://github.com/AstrBotDevs/AstrBot/issues/5389))。
### 优化
- 优化知识库检索流程,改为批量查询元数据,修复 N+1 查询性能问题 ([#5463](https://github.com/AstrBotDevs/AstrBot/issues/5463))。
- 优化 Cron 未来任务执行的会话隔离能力,提升并发稳定性。
- 优化 WebUI 插件页的交互。
## What's Changed (EN)
### New Features
- Added `useExtensionPage` composable for unified plugin extension page state management.
- Added a generic desktop app updater bridge to support in-app update workflows ([#5424](https://github.com/AstrBotDevs/AstrBot/issues/5424)).
### Bug Fixes
- Fixed the Line adapter not appearing in the "Add Platform" dialog.
- Fixed Telegram video sending issues ([#5430](https://github.com/AstrBotDevs/AstrBot/issues/5430)).
- Fixed Pyright static type checking errors ([#5437](https://github.com/AstrBotDevs/AstrBot/issues/5437)).
- Fixed embedding dimension auto-detection when creating embedding providers ([#5442](https://github.com/AstrBotDevs/AstrBot/issues/5442)).
- Fixed stale markdown fields when sending media messages via QQ Official Platform ([#5445](https://github.com/AstrBotDevs/AstrBot/issues/5445)).
- Fixed tool call/response pairing loss during context truncation ([#5417](https://github.com/AstrBotDevs/AstrBot/issues/5417)).
- Fixed `persona_id` being overwritten during conversation updates and improved persona resolution logic.
- Fixed incorrect GitHub proxy display in WebUI ([#5438](https://github.com/AstrBotDevs/AstrBot/issues/5438)).
- Fixed API key copy failure after creating a new key in settings ([#5439](https://github.com/AstrBotDevs/AstrBot/issues/5439)).
- Fixed Telegram voice format compatibility with OpenAI STT by using OGG ([#5389](https://github.com/AstrBotDevs/AstrBot/issues/5389)).
### Improvements
- Improved knowledge base retrieval by batching metadata queries to eliminate the N+1 query pattern ([#5463](https://github.com/AstrBotDevs/AstrBot/issues/5463)).
- Improved session isolation for future cron tasks to increase stability under concurrency.
- Improved WebUI plugin page interactions.
+7 -1
View File
@@ -10,6 +10,7 @@
:selectedSessions="selectedSessions"
:currSessionId="currSessionId"
:selectedProjectId="selectedProjectId"
:transportMode="transportMode"
:isDark="isDark"
:chatboxMode="chatboxMode"
:isMobile="isMobile"
@@ -26,6 +27,7 @@
@createProject="showCreateProjectDialog"
@editProject="showEditProjectDialog"
@deleteProject="handleDeleteProject"
@updateTransportMode="setTransportMode"
/>
<!-- 右侧聊天内容区域 -->
@@ -301,11 +303,14 @@ const {
isStreaming,
isConvRunning,
enableStreaming,
transportMode,
currentSessionProject,
getSessionMessages: getSessionMsg,
sendMessage: sendMsg,
stopMessage: stopMsg,
toggleStreaming
toggleStreaming,
setTransportMode,
cleanupTransport
} = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions);
//
@@ -695,6 +700,7 @@ onMounted(() => {
onBeforeUnmount(() => {
window.removeEventListener('resize', checkMobile);
cleanupMediaCache();
cleanupTransport();
});
</script>
@@ -117,6 +117,27 @@
<v-list-item-title>{{ isDark ? tm('modes.lightMode') : tm('modes.darkMode') }}</v-list-item-title>
</v-list-item>
<!-- 通信传输模式 -->
<v-list-item class="styled-menu-item">
<template v-slot:prepend>
<v-icon>mdi-lan-connect</v-icon>
</template>
<v-list-item-title>{{ tm('transport.title') }}</v-list-item-title>
<template v-slot:append>
<v-select
:model-value="transportMode"
:items="transportOptions"
item-title="label"
item-value="value"
density="compact"
variant="underlined"
hide-details
class="transport-mode-select"
@update:model-value="handleTransportModeChange"
/>
</template>
</v-list-item>
<!-- 全屏/退出全屏 -->
<v-list-item class="styled-menu-item" @click="$emit('toggleFullscreen')">
<template v-slot:prepend>
@@ -156,6 +177,7 @@ interface Props {
selectedSessions: string[];
currSessionId: string;
selectedProjectId?: string | null;
transportMode: 'sse' | 'websocket';
isDark: boolean;
chatboxMode: boolean;
isMobile: boolean;
@@ -179,6 +201,7 @@ const emit = defineEmits<{
createProject: [];
editProject: [project: Project];
deleteProject: [projectId: string];
updateTransportMode: [mode: 'sse' | 'websocket'];
}>();
const { t } = useI18n();
@@ -188,6 +211,10 @@ const confirmDialog = useConfirmDialog();
const sidebarCollapsed = ref(true);
const showProviderConfigDialog = ref(false);
const transportOptions = [
{ label: tm('transport.sse'), value: 'sse' as const },
{ label: tm('transport.websocket'), value: 'websocket' as const }
];
// localStorage
const savedCollapsedState = localStorage.getItem('sidebarCollapsed');
@@ -209,6 +236,12 @@ async function handleDeleteConversation(session: Session) {
emit('deleteConversation', session.session_id);
}
}
function handleTransportModeChange(mode: string | null) {
if (mode === 'sse' || mode === 'websocket') {
emit('updateTransportMode', mode);
}
}
</script>
<style scoped>
@@ -361,4 +394,8 @@ async function handleDeleteConversation(session: Session) {
display: flex;
justify-content: center;
}
.transport-mode-select {
min-width: 120px;
}
</style>
@@ -34,7 +34,6 @@ const platformDisplayList = computed(() =>
const handleInstall = (plugin) => {
emit("install", plugin);
};
</script>
<template>
@@ -124,7 +123,6 @@ const handleInstall = (plugin) => {
v-if="plugin?.social_link"
:href="plugin.social_link"
target="_blank"
@click.stop
class="text-subtitle-2 font-weight-medium"
style="
text-decoration: none;
@@ -215,10 +213,7 @@ const handleInstall = (plugin) => {
</div>
</v-card-text>
<v-card-actions
style="gap: 6px; padding: 8px 12px; padding-top: 0"
@click.stop
>
<v-card-actions style="gap: 6px; padding: 8px 12px; padding-top: 0">
<v-chip
v-for="tag in plugin.tags?.slice(0, 2)"
:key="tag"
@@ -253,24 +248,22 @@ const handleInstall = (plugin) => {
<v-btn
v-if="plugin?.repo"
color="secondary"
size="small"
size="x-small"
variant="tonal"
class="market-action-btn"
:href="plugin.repo"
target="_blank"
style="height: 32px"
style="height: 24px"
>
<v-icon icon="mdi-github" start size="small"></v-icon>
<v-icon icon="mdi-github" start size="x-small"></v-icon>
{{ tm("buttons.viewRepo") }}
</v-btn>
<v-btn
v-if="!plugin?.installed"
color="primary"
size="small"
size="x-small"
@click="handleInstall(plugin)"
variant="flat"
class="market-action-btn"
style="height: 32px"
style="height: 24px"
>
{{ tm("buttons.install") }}
</v-btn>
@@ -313,9 +306,4 @@ const handleInstall = (plugin) => {
.plugin-description::-webkit-scrollbar-thumb:hover {
background-color: rgba(var(--v-theme-primary-rgb), 0.6);
}
.market-action-btn {
font-size: 0.9rem;
font-weight: 600;
}
</style>
@@ -48,40 +48,6 @@ const filteredIterable = computed(() => {
return rest
})
const providerHint = computed(() => {
const hint = props.iterable?.hint
if (typeof hint !== 'string' || !hint) return ''
if (
hint === 'provider_group.provider.openai_embedding.hint'
|| hint === 'provider_group.provider.gemini_embedding.hint'
) {
return ''
}
return hint
})
const getItemHint = (itemKey, itemMeta) => {
if (itemMeta?.hint) return itemMeta.hint
if (itemKey !== 'embedding_api_base') return ''
const providerType = props.iterable?.type
if (providerType === 'openai_embedding') {
return getRaw('provider_group.provider.openai_embedding.hint')
? 'provider_group.provider.openai_embedding.hint'
: ''
}
if (providerType === 'gemini_embedding') {
return getRaw('provider_group.provider.gemini_embedding.hint')
? 'provider_group.provider.gemini_embedding.hint'
: ''
}
return ''
}
const dialog = ref(false)
const currentEditingKey = ref('')
const currentEditingLanguage = ref('json')
@@ -187,14 +153,14 @@ function hasVisibleItemsAfter(items, currentIndex) {
<div v-if="metadata[metadataKey]?.type === 'object' || metadata[metadataKey]?.config_template" class="object-config">
<!-- Provider-level hint -->
<v-alert
v-if="providerHint"
v-if="iterable.hint && !isEditing"
type="info"
variant="tonal"
class="mb-4"
border="start"
density="compact"
>
{{ translateIfKey(providerHint) }}
{{ iterable.hint }}
</v-alert>
<div v-for="(val, key, index) in filteredIterable" :key="key" class="config-item">
@@ -252,9 +218,9 @@ function hasVisibleItemsAfter(items, currentIndex) {
</v-list-item-title>
<v-list-item-subtitle class="property-hint">
<span v-if="metadata[metadataKey].items[key]?.obvious_hint && getItemHint(key, metadata[metadataKey].items[key])"
<span v-if="metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint"
class="important-hint"></span>
{{ translateIfKey(getItemHint(key, metadata[metadataKey].items[key])) }}
{{ translateIfKey(metadata[metadataKey].items[key]?.hint) }}
</v-list-item-subtitle>
</v-list-item>
</v-col>
+232 -357
View File
@@ -1,12 +1,10 @@
<script setup lang="ts">
import { ref, computed, inject, watch } from "vue";
import { ref, computed, inject } from "vue";
import { useCustomizerStore } from "@/stores/customizer";
import { useModuleI18n } from "@/i18n/composables";
import { getPlatformDisplayName, getPlatformIcon } from "@/utils/platformUtils";
import UninstallConfirmDialog from "./UninstallConfirmDialog.vue";
import PluginPlatformChip from "./PluginPlatformChip.vue";
import StyledMenu from "./StyledMenu.vue";
import defaultPluginIcon from "@/assets/images/plugin_icon.png";
const props = defineProps({
extension: {
@@ -61,25 +59,6 @@ const astrbotVersionRequirement = computed(() => {
: "";
});
const logoLoadFailed = ref(false);
const logoSrc = computed(() => {
const logo = props.extension?.logo;
if (logoLoadFailed.value) {
return defaultPluginIcon;
}
return typeof logo === "string" && logo.trim().length
? logo
: defaultPluginIcon;
});
watch(
() => props.extension?.logo,
() => {
logoLoadFailed.value = false;
},
);
//
const configure = () => {
emit("configure", props.extension);
@@ -125,7 +104,6 @@ const viewReadme = () => {
const viewChangelog = () => {
emit("view-changelog", props.extension);
};
</script>
<template>
@@ -151,292 +129,249 @@ const viewChangelog = () => {
style="
padding: 16px;
padding-bottom: 0px;
display: flex;
gap: 16px;
width: 100%;
"
>
<div style="overflow-x: auto; width: 100%">
<div style="width: 100%; margin-bottom: 24px">
<div class="extension-title-row">
<p
class="text-h3 font-weight-black extension-title"
:class="{ 'text-h4': $vuetify.display.xs }"
>
<v-tooltip
location="top"
:text="
extension.display_name?.length &&
extension.display_name !== extension.name
? `${extension.display_name} (${extension.name})`
: extension.name
"
<div v-if="extension?.logo">
<img :src="extension.logo" :alt="extension.name" cover width="100" />
</div>
<div style="overflow-x: auto">
<!-- Top-right three-dot menu -->
<div style="position: absolute; right: 8px; top: 8px; z-index: 5">
<v-menu offset-y>
<template v-slot:activator="{ props: menuProps }">
<v-btn
icon
variant="text"
aria-label="more"
v-if="extension?.repo"
:href="extension?.repo"
target="_blank"
>
<template v-slot:activator="{ props: titleTooltipProps }">
<span v-bind="titleTooltipProps" class="extension-title__text">{{
extension.display_name?.length
? extension.display_name
: extension.name
}}</span>
</template>
</v-tooltip>
<v-tooltip
location="top"
v-if="extension?.has_update && !marketMode"
>
<template v-slot:activator="{ props: tooltipProps }">
<v-icon
v-bind="tooltipProps"
color="warning"
class="ml-2"
icon="mdi-update"
size="small"
></v-icon>
</template>
<span
>{{ tm("card.status.hasUpdate") }}:
{{ extension.online_version }}</span
<v-icon icon="mdi-github"></v-icon>
</v-btn>
<v-btn v-bind="menuProps" icon variant="text" aria-label="more">
<v-icon icon="mdi-dots-vertical"></v-icon>
</v-btn>
</template>
<v-list>
<v-list-item @click="viewReadme">
<v-list-item-title
>📄 {{ tm("buttons.viewDocs") }}</v-list-item-title
>
</v-tooltip>
<v-tooltip
location="top"
v-if="!extension.activated && !marketMode"
</v-list-item>
<v-list-item v-if="!marketMode" @click="viewChangelog">
<v-list-item-title
>📝 {{ tm("pluginChangelog.menuTitle") }}</v-list-item-title
>
</v-list-item>
<v-list-item
v-if="marketMode && !extension?.installed"
@click="installExtension"
>
<template v-slot:activator="{ props: tooltipProps }">
<v-icon
v-bind="tooltipProps"
color="error"
class="ml-2"
icon="mdi-cancel"
size="small"
></v-icon>
</template>
<span>{{ tm("card.status.disabled") }}</span>
</v-tooltip>
</p>
<v-list-item-title>
{{ tm("buttons.install") }}</v-list-item-title
>
</v-list-item>
<template v-if="!marketMode">
<v-tooltip location="left">
<template v-slot:activator="{ props: tooltipProps }">
<div v-bind="tooltipProps" class="extension-switch-wrap" @click.stop>
<v-switch
:model-value="extension.activated"
color="success"
density="compact"
hide-details
inset
@update:model-value="toggleActivation"
></v-switch>
</div>
</template>
<span>{{
extension.activated ? tm("buttons.disable") : tm("buttons.enable")
}}</span>
</v-tooltip>
</template>
<template v-else>
<div class="extension-market-menu-wrap">
<v-menu offset-y>
<template v-slot:activator="{ props: menuProps }">
<v-btn
icon
variant="text"
aria-label="more"
v-if="extension?.repo"
:href="extension?.repo"
target="_blank"
>
<v-icon icon="mdi-github"></v-icon>
</v-btn>
<v-btn v-bind="menuProps" icon variant="text" aria-label="more">
<v-icon icon="mdi-dots-vertical"></v-icon>
</v-btn>
</template>
<v-list-item v-if="marketMode && extension?.installed">
<v-list-item-title class="text--disabled">{{
tm("status.installed")
}}</v-list-item-title>
</v-list-item>
<v-list>
<v-list-item @click="viewReadme">
<v-list-item-title
>📄 {{ tm("buttons.viewDocs") }}</v-list-item-title
>
</v-list-item>
<!-- Divider between market actions and plugin actions -->
<v-divider v-if="!marketMode" />
<v-list-item
v-if="marketMode && !extension?.installed"
@click="installExtension"
>
<v-list-item-title>
{{ tm("buttons.install") }}</v-list-item-title
>
</v-list-item>
<template v-if="!marketMode">
<v-list-item @click="configure">
<v-list-item-title>
{{ tm("card.actions.pluginConfig") }}</v-list-item-title
>
</v-list-item>
<v-list-item v-if="marketMode && extension?.installed">
<v-list-item-title class="text--disabled">{{
tm("status.installed")
}}</v-list-item-title>
</v-list-item>
</v-list>
</v-menu>
</div>
</template>
<v-list-item @click="uninstallExtension">
<v-list-item-title class="text-error">{{
tm("card.actions.uninstallPlugin")
}}</v-list-item-title>
</v-list-item>
<v-list-item @click="reloadExtension">
<v-list-item-title>{{
tm("card.actions.reloadPlugin")
}}</v-list-item-title>
</v-list-item>
<v-list-item @click="toggleActivation">
<v-list-item-title>
{{
extension.activated
? tm("buttons.disable")
: tm("buttons.enable")
}}{{ tm("card.actions.togglePlugin") }}
</v-list-item-title>
</v-list-item>
<v-list-item @click="viewHandlers">
<v-list-item-title
>{{ tm("card.actions.viewHandlers") }} ({{
extension.handlers.length
}})</v-list-item-title
>
</v-list-item>
<v-list-item @click="updateExtension">
<v-list-item-title>
{{
extension.has_update
? tm("card.actions.updateTo") +
" " +
extension.online_version
: tm("card.actions.reinstall")
}}
</v-list-item-title>
</v-list-item>
</template>
</v-list>
</v-menu>
</div>
<div style="width: 100%; margin-bottom: 24px">
<!-- 最多一行 -->
<div
class="text-caption"
style="
color: gray;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
margin-right: 84px;
"
>
{{ extension.author }} / {{ extension.name }}
</div>
<p
class="text-h3 font-weight-black extension-title"
:class="{ 'text-h4': $vuetify.display.xs }"
>
<span class="extension-title__text">{{
extension.display_name?.length
? extension.display_name
: extension.name
}}</span>
<v-tooltip
location="top"
v-if="extension?.has_update && !marketMode"
>
<template v-slot:activator="{ props: tooltipProps }">
<v-icon
v-bind="tooltipProps"
color="warning"
class="ml-2"
icon="mdi-update"
size="small"
></v-icon>
</template>
<span
>{{ tm("card.status.hasUpdate") }}:
{{ extension.online_version }}</span
>
</v-tooltip>
<v-tooltip
location="top"
v-if="!extension.activated && !marketMode"
>
<template v-slot:activator="{ props: tooltipProps }">
<v-icon
v-bind="tooltipProps"
color="error"
class="ml-2"
icon="mdi-cancel"
size="small"
></v-icon>
</template>
<span>{{ tm("card.status.disabled") }}</span>
</v-tooltip>
</p>
<div class="mt-1 d-flex flex-wrap">
<v-chip color="primary" label size="small">
<v-icon icon="mdi-source-branch" start></v-icon>
{{ extension.version }}
</v-chip>
<v-chip
v-if="extension?.has_update"
color="warning"
label
size="small"
class="ml-2"
>
<v-icon icon="mdi-arrow-up-bold" start></v-icon>
{{ extension.online_version }}
</v-chip>
<v-chip
color="primary"
label
size="small"
class="ml-2"
v-if="extension.handlers?.length"
@click="viewHandlers"
style="cursor: pointer"
>
<v-icon icon="mdi-cogs" start></v-icon>
{{ extension.handlers?.length
}}{{ tm("card.status.handlersCount") }}
</v-chip>
<v-chip
v-for="tag in extension.tags"
:key="tag"
:color="tag === 'danger' ? 'error' : 'primary'"
label
size="small"
class="ml-2"
>
{{ tag === "danger" ? tm("tags.danger") : tag }}
</v-chip>
<PluginPlatformChip
:platforms="supportPlatforms"
class="ml-2"
/>
<v-chip
v-if="astrbotVersionRequirement"
color="secondary"
variant="outlined"
label
size="small"
class="ml-2"
>
AstrBot: {{ astrbotVersionRequirement }}
</v-chip>
</div>
<div class="extension-content-row mt-2">
<div class="extension-image-container">
<img
:src="logoSrc"
:alt="extension.name"
class="extension-logo"
@error="logoLoadFailed = true"
/>
</div>
<div class="extension-meta-group">
<div class="extension-chip-group d-flex flex-wrap">
<v-chip color="primary" label size="small">
<v-icon icon="mdi-source-branch" start></v-icon>
{{ extension.version }}
</v-chip>
<v-chip
v-if="extension?.has_update"
color="warning"
label
size="small"
>
<v-icon icon="mdi-arrow-up-bold" start></v-icon>
{{ extension.online_version }}
</v-chip>
<v-chip
v-if="extension.handlers?.length"
color="primary"
label
size="small"
@click="viewHandlers"
style="cursor: pointer"
>
<v-icon icon="mdi-cogs" start></v-icon>
{{ extension.handlers?.length
}}{{ tm("card.status.handlersCount") }}
</v-chip>
<v-chip
v-for="tag in extension.tags"
:key="tag"
:color="tag === 'danger' ? 'error' : 'primary'"
label
size="small"
>
{{ tag === "danger" ? tm("tags.danger") : tag }}
</v-chip>
<PluginPlatformChip :platforms="supportPlatforms" />
<v-chip
v-if="astrbotVersionRequirement"
color="secondary"
variant="outlined"
label
size="small"
>
AstrBot: {{ astrbotVersionRequirement }}
</v-chip>
</div>
<div
class="extension-desc"
:class="{ 'text-caption': $vuetify.display.xs }"
>
{{ extension.desc }}
</div>
</div>
<div
class="mt-2"
:class="{ 'text-caption': $vuetify.display.xs }"
style="overflow-y: auto; height: 70px; font-size: 90%"
>
{{ extension.desc }}
</div>
</div>
</div>
</v-card-text>
<v-card-actions class="extension-actions" @click.stop>
<template v-if="!marketMode">
<v-spacer></v-spacer>
<v-tooltip location="top" :text="tm('buttons.viewDocs')">
<template v-slot:activator="{ props: actionProps }">
<v-btn
v-bind="actionProps"
icon="mdi-book-open-page-variant"
size="small"
variant="tonal"
color="info"
@click="viewReadme"
></v-btn>
</template>
</v-tooltip>
<v-tooltip location="top" :text="tm('card.actions.pluginConfig')">
<template v-slot:activator="{ props: actionProps }">
<v-btn
v-bind="actionProps"
icon="mdi-cog"
size="small"
variant="tonal"
color="primary"
@click="configure"
></v-btn>
</template>
</v-tooltip>
<v-tooltip v-if="extension?.repo" location="top" :text="tm('buttons.viewRepo')">
<template v-slot:activator="{ props: actionProps }">
<v-btn
v-bind="actionProps"
icon="mdi-github"
size="small"
variant="tonal"
color="secondary"
:href="extension.repo"
target="_blank"
></v-btn>
</template>
</v-tooltip>
<v-tooltip location="top" :text="tm('card.actions.reloadPlugin')">
<template v-slot:activator="{ props: actionProps }">
<v-btn
v-bind="actionProps"
icon="mdi-refresh"
size="small"
variant="tonal"
color="primary"
@click="reloadExtension"
></v-btn>
</template>
</v-tooltip>
<StyledMenu location="top end" offset="8">
<template #activator="{ props: menuProps }">
<v-btn
v-bind="menuProps"
icon="mdi-dots-horizontal"
size="small"
variant="tonal"
color="secondary"
></v-btn>
</template>
<v-list-item class="styled-menu-item" prepend-icon="mdi-information" @click="viewHandlers">
<v-list-item-title>{{ tm("buttons.viewInfo") }}</v-list-item-title>
</v-list-item>
<v-list-item class="styled-menu-item" prepend-icon="mdi-update" @click="updateExtension">
<v-list-item-title>{{
extension.has_update
? tm("card.actions.updateTo") + " " + extension.online_version
: tm("card.actions.reinstall")
}}</v-list-item-title>
</v-list-item>
<v-list-item class="styled-menu-item" prepend-icon="mdi-delete" @click="uninstallExtension">
<v-list-item-title class="text-error">{{ tm("card.actions.uninstallPlugin") }}</v-list-item-title>
</v-list-item>
</StyledMenu>
</template>
<template v-else>
<v-btn color="primary" size="small" @click="viewReadme">
{{ tm("buttons.viewDocs") }}
</v-btn>
</template>
<v-card-actions class="extension-actions">
<v-btn color="primary" size="small" @click="viewReadme">
{{ tm("buttons.viewDocs") }}
</v-btn>
<v-btn v-if="!marketMode" color="primary" size="small" @click="configure">
{{ tm("card.actions.pluginConfig") }}
</v-btn>
</v-card-actions>
</v-card>
@@ -450,52 +385,13 @@ const viewChangelog = () => {
<style scoped>
.extension-image-container {
display: flex;
align-items: flex-start;
flex-shrink: 0;
}
.extension-logo {
width: 72px;
height: 72px;
border-radius: 12px;
object-fit: cover;
}
.extension-content-row {
display: flex;
gap: 12px;
align-items: flex-start;
}
.extension-meta-group {
flex: 1;
min-width: 0;
}
.extension-chip-group {
gap: 8px;
}
.extension-desc {
margin-top: 8px;
font-size: 90%;
overflow-y: auto;
height: 70px;
align-items: center;
margin-left: 12px;
}
.extension-title {
display: flex;
align-items: center;
min-width: 0;
flex: 1;
margin: 0;
}
.extension-title-row {
display: flex;
align-items: center;
justify-content: space-between;
gap: 12px;
}
.extension-title__text {
@@ -503,38 +399,17 @@ const viewChangelog = () => {
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.extension-switch-wrap {
display: flex;
align-items: center;
flex-shrink: 0;
}
.extension-switch-wrap :deep(.v-switch) {
margin: 0;
}
.extension-market-menu-wrap {
display: flex;
align-items: center;
flex-shrink: 0;
padding-top: 6px;
}
@media (max-width: 600px) {
.extension-content-row {
flex-direction: column;
}
.extension-logo {
width: 64px;
height: 64px;
.extension-image-container {
margin-left: 8px;
}
}
.extension-actions {
margin-top: auto;
gap: 8px;
justify-content: flex-end;
}
</style>
@@ -15,7 +15,7 @@
<v-expand-transition>
<div v-if="radioValue === '1'" style="margin-left: 16px;">
<v-radio-group v-model="githubProxyRadioControl" class="mt-2" hide-details="true">
<v-radio color="success" v-for="(proxy, idx) in githubProxies" :key="proxy" :value="String(idx)">
<v-radio color="success" v-for="(proxy, idx) in githubProxies" :key="proxy" :value="idx">
<template v-slot:label>
<div class="d-flex align-center">
<span class="mr-2">{{ proxy }}</span>
@@ -37,7 +37,7 @@
</template>
</v-radio>
<v-radio color="primary" value="-1" :label="tm('network.proxySelector.custom')">
<template v-slot:label v-if="String(githubProxyRadioControl) === '-1'">
<template v-slot:label v-if="githubProxyRadioControl === '-1'">
<v-text-field density="compact" v-model="selectedGitHubProxy" variant="outlined"
style="width: 100vw;" :placeholder="tm('network.proxySelector.custom')" hide-details="true">
</v-text-field>
@@ -72,21 +72,9 @@ export default {
loadingTestingConnection: false,
testingProxies: {},
proxyStatus: {},
initializing: true,
}
},
methods: {
getProxyByControl(control) {
const normalizedControl = String(control);
if (normalizedControl === "-1") {
return "";
}
const index = Number.parseInt(normalizedControl, 10);
if (Number.isNaN(index)) {
return "";
}
return this.githubProxies[index] || "";
},
async testSingleProxy(idx) {
this.testingProxies[idx] = true;
@@ -130,60 +118,42 @@ export default {
},
},
mounted() {
this.initializing = true;
const savedProxy = localStorage.getItem('selectedGitHubProxy') || "";
const savedRadio = localStorage.getItem('githubProxyRadioValue') || "0";
const savedControl = String(localStorage.getItem('githubProxyRadioControl') || "0");
this.radioValue = savedRadio;
this.githubProxyRadioControl = savedControl;
if (savedRadio === "1") {
if (savedControl !== "-1") {
this.selectedGitHubProxy = this.getProxyByControl(savedControl);
} else {
this.selectedGitHubProxy = savedProxy;
this.selectedGitHubProxy = localStorage.getItem('selectedGitHubProxy') || "";
this.radioValue = localStorage.getItem('githubProxyRadioValue') || "0";
this.githubProxyRadioControl = localStorage.getItem('githubProxyRadioControl') || "0";
if (this.radioValue === "1") {
if (this.githubProxyRadioControl !== "-1") {
this.selectedGitHubProxy = this.githubProxies[this.githubProxyRadioControl] || "";
}
} else {
this.selectedGitHubProxy = "";
}
this.initializing = false;
},
watch: {
selectedGitHubProxy: function (newVal, oldVal) {
if (this.initializing) {
return;
}
if (!newVal) {
newVal = ""
}
localStorage.setItem('selectedGitHubProxy', newVal);
},
radioValue: function (newVal) {
if (this.initializing) {
return;
}
localStorage.setItem('githubProxyRadioValue', newVal);
if (String(newVal) === "0") {
if (newVal === "0") {
this.selectedGitHubProxy = "";
} else if (String(this.githubProxyRadioControl) !== "-1") {
this.selectedGitHubProxy = this.getProxyByControl(this.githubProxyRadioControl);
} else if (this.githubProxyRadioControl !== "-1") {
this.selectedGitHubProxy = this.githubProxies[this.githubProxyRadioControl] || "";
}
},
githubProxyRadioControl: function (newVal) {
if (this.initializing) {
return;
}
const normalizedVal = String(newVal);
localStorage.setItem('githubProxyRadioControl', normalizedVal);
if (String(this.radioValue) !== "1") {
localStorage.setItem('githubProxyRadioControl', newVal);
if (this.radioValue !== "1") {
this.selectedGitHubProxy = "";
return;
}
if (normalizedVal !== "-1") {
this.selectedGitHubProxy = this.getProxyByControl(normalizedVal);
if (newVal !== "-1") {
this.selectedGitHubProxy = this.githubProxies[newVal] || "";
} else {
this.selectedGitHubProxy = "";
}
}
}
File diff suppressed because it is too large Load Diff
@@ -58,18 +58,6 @@
"guideStep2": "Install it and restart AstrBot.",
"guideStep3": "If you use Docker, prefer the image update path."
},
"desktopApp": {
"title": "Update Desktop App",
"message": "Check and upgrade the AstrBot desktop application.",
"currentVersion": "Current version: ",
"latestVersion": "Latest version: ",
"checking": "Checking desktop app updates...",
"hasNewVersion": "A new version is available. Click confirm to upgrade.",
"isLatest": "Already on the latest version",
"installing": "Downloading and installing update. The app will restart automatically...",
"checkFailed": "Failed to check updates. Please try again later.",
"installFailed": "Upgrade failed. Please try again later."
},
"dashboardUpdate": {
"title": "Update Dashboard to Latest Version Only",
"currentVersion": "Current Version",
@@ -81,9 +81,16 @@
"disabled": "Streaming disabled",
"on": "Stream",
"off": "Normal"
}, "config": {
},
"transport": {
"title": "Transport Mode",
"sse": "SSE",
"websocket": "WebSocket"
},
"config": {
"title": "Config"
}, "reasoning": {
},
"reasoning": {
"thinking": "Thinking Process"
},
"reply": {
@@ -147,7 +147,7 @@
"provider_settings": {
"computer_use_runtime": {
"description": "Computer Use Runtime",
"hint": "sandbox means running in a remote sandbox environment, local means running directly on the local machine, local_sandboxed means local execution with OS-level sandboxing (bwrap/seatbelt), and none means disabling Computer Use. If skills are uploaded, choosing none will cause them to not be usable by the Agent."
"hint": "sandbox means running in a sandbox environment, local means running in a local environment, none means disabling Computer Use. If skills are uploaded, choosing none will cause them to not be usable by the Agent."
},
"computer_use_require_admin": {
"description": "Require AstrBot Admin Permission",
@@ -251,10 +251,6 @@
"show_tool_use_status": {
"description": "Output Function Call Status"
},
"show_tool_call_result": {
"description": "Output Tool Call Results",
"hint": "Only takes effect when \"Output Function Call Status\" is enabled, and shows at most 70 characters."
},
"sanitize_context_by_modalities": {
"description": "Sanitize History by Modalities",
"hint": "When enabled, sanitizes contexts before each LLM request by removing image blocks and tool-call structures that the current provider's modalities do not support (this changes what the model sees)."
@@ -1086,12 +1082,6 @@
"embedding_api_base": {
"description": "API Base URL"
},
"openai_embedding": {
"hint": "OpenAI Embedding automatically appends /v1 at request time."
},
"gemini_embedding": {
"hint": "Gemini Embedding does not require manually adding /v1beta."
},
"volcengine_cluster": {
"description": "Volcengine cluster",
"hint": "For voice cloning models, choose volcano_icl or volcano_icl_concurr; default is volcano_tts."
@@ -1319,10 +1309,6 @@
"api_base": {
"description": "API Base URL"
},
"proxy": {
"description": "Proxy address",
"hint": "HTTP/HTTPS proxy URL, e.g. http://127.0.0.1:7890. Applies only to this provider's API requests and does not affect Docker internal networking."
},
"model": {
"description": "Model ID",
"hint": "Model name, e.g., gpt-4o-mini, deepseek-chat."
@@ -8,9 +8,6 @@
"handlersOperation": "Manage Handlers",
"market": "AstrBot Plugin Market"
},
"titles": {
"installedAstrBotPlugins": "Installed AstrBot Plugins"
},
"search": {
"placeholder": "Search extensions...",
"marketPlaceholder": "Search market extensions..."
@@ -225,7 +222,7 @@
"deleteSuccess": "Deleted successfully",
"deleteFailed": "Delete failed",
"runtimeNoneWarning": "Computer Use runtime is set to None; Skills may not run correctly because no runtime is enabled.",
"runtimeHint": "Set the Computer Use runtime to Local, Local Sandboxed, or Sandbox in settings so AstrBot can use your Skills."
"runtimeHint": "Set the Computer Use runtime to Local or Sandbox in settings so AstrBot can use your Skills."
},
"card": {
"actions": {
@@ -8,14 +8,11 @@
"refresh": "Refresh",
"save": "Save",
"add": "Add SubAgent",
"delete": "Delete",
"close": "Close"
"delete": "Delete"
},
"switches": {
"enable": "Enable SubAgent orchestration",
"enableHint": "Enable sub-agent functionality",
"dedupe": "Deduplicate main LLM tools (hide tools duplicated by SubAgents)",
"dedupeHint": "Remove duplicate tools from main agent"
"dedupe": "Deduplicate main LLM tools (hide tools duplicated by SubAgents)"
},
"description": {
"disabled": "When off: SubAgent is disabled; the main LLM mounts tools via persona rules (all by default) and calls them directly.",
@@ -32,8 +29,7 @@
"transferPrefix": "transfer_to_{name}",
"switchLabel": "Enable",
"previewTitle": "Preview: handoff tool shown to the main LLM",
"personaChip": "Persona: {id}",
"personaPreview": "PERSONA PREVIEW"
"personaChip": "Persona: {id}"
},
"form": {
"nameLabel": "Agent name (used for transfer_to_{name})",
@@ -53,13 +49,6 @@
"nameDuplicate": "Duplicate SubAgent name: {name}",
"personaMissing": "SubAgent {name} has no persona selected",
"saveSuccess": "Saved successfully",
"saveFailed": "Failed to save",
"nameRequired": "Name is required",
"namePattern": "Lowercase letters, numbers, underscore only"
},
"empty": {
"title": "No Agents Configured",
"subtitle": "Add a new sub-agent to get started",
"action": "Create First Agent"
"saveFailed": "Failed to save"
}
}
@@ -58,18 +58,6 @@
"guideStep2": "完成安装后重启 AstrBot。",
"guideStep3": "如果你使用 Docker,请优先使用镜像更新方式。"
},
"desktopApp": {
"title": "更新桌面应用",
"message": "将检查并升级 AstrBot 桌面端程序。",
"currentVersion": "当前版本:",
"latestVersion": "最新版本:",
"checking": "正在检查桌面应用更新...",
"hasNewVersion": "发现新版本,可点击确认升级。",
"isLatest": "已经是最新版本",
"installing": "正在下载并安装更新,完成后将自动重启应用...",
"checkFailed": "检查更新失败,请稍后重试。",
"installFailed": "升级失败,请稍后重试。"
},
"dashboardUpdate": {
"title": "单独更新管理面板到最新版本",
"currentVersion": "当前版本",
@@ -82,6 +82,11 @@
"on": "流式",
"off": "普通"
},
"transport": {
"title": "通信传输模式",
"sse": "SSE",
"websocket": "WebSocket"
},
"config": {
"title": "配置文件"
},
@@ -150,7 +150,7 @@
"provider_settings": {
"computer_use_runtime": {
"description": "运行环境",
"hint": "sandbox 代表在远程沙箱环境中运行, local 代表在本地直接运行, local_sandboxed 代表本地运行但使用系统沙箱(bwrap/seatbelt)增强隔离, none 代表不启用。如果上传了 skills,选择 none 会导致其无法被 Agent 正常使用。"
"hint": "sandbox 代表在沙箱环境中运行, local 代表在本地环境中运行, none 代表不启用。如果上传了 skills,选择 none 会导致其无法被 Agent 正常使用。"
},
"computer_use_require_admin": {
"description": "需要 AstrBot 管理员权限",
@@ -254,10 +254,6 @@
"show_tool_use_status": {
"description": "输出函数调用状态"
},
"show_tool_call_result": {
"description": "输出函数调用返回结果",
"hint": "仅在启用“输出函数调用状态”时生效,且最多展示 70 个字符。"
},
"sanitize_context_by_modalities": {
"description": "按模型能力清理历史上下文",
"hint": "开启后,在每次请求 LLM 前会按当前模型提供商中所选择的模型能力删除对话中不支持的图片/工具调用结构(会改变模型看到的历史)"
@@ -1089,12 +1085,6 @@
"embedding_api_base": {
"description": "API Base URL"
},
"openai_embedding": {
"hint": "OpenAI Embedding 会在请求时自动补上 /v1。"
},
"gemini_embedding": {
"hint": "Gemini Embedding 无需手动添加 /v1beta。"
},
"volcengine_cluster": {
"description": "火山引擎集群",
"hint": "若使用语音复刻大模型,可选volcano_icl或volcano_icl_concurr,默认使用volcano_tts"
@@ -1322,10 +1312,6 @@
"api_base": {
"description": "API Base URL"
},
"proxy": {
"description": "代理地址",
"hint": "HTTP/HTTPS 代理地址,格式如 http://127.0.0.1:7890。仅对该提供商的 API 请求生效,不影响 Docker 内网通信。"
},
"model": {
"description": "模型 ID",
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。"
@@ -8,9 +8,6 @@
"skills": "Skills",
"handlersOperation": "管理行为"
},
"titles": {
"installedAstrBotPlugins": "已安装的 AstrBot 插件"
},
"search": {
"placeholder": "搜索插件...",
"marketPlaceholder": "搜索市场插件..."
@@ -225,7 +222,7 @@
"deleteSuccess": "删除成功",
"deleteFailed": "删除失败",
"runtimeNoneWarning": "Computer Use 运行环境为无,Skills 可能无法正确被 Agent 运行,因为没有启用运行环境。",
"runtimeHint": "需要在配置的 “使用电脑能力” 中将运行环境设置为 “local”、“local_sandboxed” 或 “sandbox” 才能让 AstrBot 正常使用你提供的 Skills。"
"runtimeHint": "需要在配置的 “使用电脑能力” 中将运行环境设置为 “local” 或 “sandbox” 才能让 AstrBot 正常使用你提供的 Skills。"
},
"card": {
"actions": {
@@ -8,14 +8,11 @@
"refresh": "刷新",
"save": "保存",
"add": "新增 SubAgent",
"delete": "删除",
"close": "关闭"
"delete": "删除"
},
"switches": {
"enable": "启用 SubAgent 编排",
"enableHint": "启用子代理功能",
"dedupe": "主 LLM 去重重复工具(与 SubAgent 重叠的工具将被隐藏)",
"dedupeHint": "从主代理中移除重复工具"
"dedupe": "主 LLM 去重重复工具(与 SubAgent 重叠的工具将被隐藏)"
},
"description": {
"disabled": "不启动:SubAgent 关闭;主 LLM 按 persona 规则挂载工具(默认全部),并直接调用。",
@@ -42,7 +39,6 @@
"providerHint": "留空表示跟随全局默认 provider。",
"personaLabel": "选择人格设定",
"personaHint": "SubAgent 将直接继承所选 Persona 的系统设定与工具。在人格设定页管理和新建人格。",
"personaPreview": "人格预览",
"descriptionLabel": "对主 LLM 的描述(用于决定是否 handoff",
"descriptionHint": "这段会作为 transfer_to_* 工具的描述给主 LLM 看,建议简短明确。"
},
@@ -54,13 +50,6 @@
"nameDuplicate": "SubAgent 名称重复:{name}",
"personaMissing": "SubAgent {name} 未选择 Persona",
"saveSuccess": "保存成功",
"saveFailed": "保存失败",
"nameRequired": "名称必填",
"namePattern": "仅支持小写字母、数字和下划线"
},
"empty": {
"title": "未配置 SubAgent",
"subtitle": "添加一个新的子代理以开始",
"action": "创建第一个 Agent"
"saveFailed": "保存失败"
}
}
@@ -50,27 +50,24 @@ let installLoading = ref(false);
const isDesktopReleaseMode = ref(
typeof window !== 'undefined' && !!window.astrbotDesktop?.isDesktop
);
const desktopUpdateDialog = ref(false);
const desktopUpdateChecking = ref(false);
const desktopUpdateInstalling = ref(false);
const desktopUpdateHasNewVersion = ref(false);
const desktopUpdateCurrentVersion = ref('-');
const desktopUpdateLatestVersion = ref('-');
const desktopUpdateStatus = ref('');
const getAppUpdaterBridge = (): AstrBotAppUpdaterBridge | null => {
if (typeof window === 'undefined') {
return null;
const redirectConfirmDialog = ref(false);
const pendingRedirectUrl = ref('');
const resolvingReleaseTarget = ref(false);
const DEFAULT_ASTRBOT_RELEASE_BASE_URL = 'https://github.com/AstrBotDevs/AstrBot/releases';
const resolveReleaseBaseUrl = () => {
const raw = import.meta.env.VITE_ASTRBOT_RELEASE_BASE_URL;
// Keep upstream default on AstrBot releases; desktop distributors can override via env injection.
const normalized = raw?.trim()?.replace(/\/+$/, '') || '';
const withoutLatestSuffix = normalized.replace(/\/latest$/i, '');
return withoutLatestSuffix || DEFAULT_ASTRBOT_RELEASE_BASE_URL;
};
const releaseBaseUrl = resolveReleaseBaseUrl();
const getReleaseUrlByTag = (tag: string | null | undefined) => {
const normalizedTag = (tag || '').trim();
if (!normalizedTag || normalizedTag.toLowerCase() === 'latest') {
return `${releaseBaseUrl}/latest`;
}
const bridge = window.astrbotAppUpdater;
if (
bridge &&
typeof bridge.checkForAppUpdate === 'function' &&
typeof bridge.installAppUpdate === 'function'
) {
return bridge;
}
return null;
return `${releaseBaseUrl}/tag/${normalizedTag}`;
};
const getSelectedGitHubProxy = () => {
@@ -92,6 +89,16 @@ const releasesHeader = computed(() => [
{ title: t('core.header.updateDialog.table.sourceUrl'), key: 'zipball_url' },
{ title: t('core.header.updateDialog.table.actions'), key: 'switch' }
]);
const latestReleaseTag = computed(() => {
const firstRelease = (releases.value as any[])?.[0];
if (firstRelease?.tag_name) {
return firstRelease.tag_name as string;
}
return hasNewVersion.value
? t('core.header.updateDialog.redirectConfirm.latestLabel')
: (botCurrVersion.value || '-');
});
// Form validation
const formValid = ref(true);
const passwordRules = computed(() => [
@@ -119,88 +126,47 @@ const accountEditStatus = ref({
message: ''
});
function cancelDesktopUpdate() {
if (desktopUpdateInstalling.value) {
return;
}
desktopUpdateDialog.value = false;
const open = (link: string) => {
window.open(link, '_blank');
};
function requestExternalRedirect(link: string) {
pendingRedirectUrl.value = link;
redirectConfirmDialog.value = true;
}
async function openDesktopUpdateDialog() {
desktopUpdateDialog.value = true;
desktopUpdateChecking.value = true;
desktopUpdateInstalling.value = false;
desktopUpdateHasNewVersion.value = false;
desktopUpdateCurrentVersion.value = '-';
desktopUpdateLatestVersion.value = '-';
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.checking');
function cancelExternalRedirect() {
redirectConfirmDialog.value = false;
pendingRedirectUrl.value = '';
}
const bridge = getAppUpdaterBridge();
if (!bridge) {
desktopUpdateChecking.value = false;
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.checkFailed');
return;
}
try {
const result = await bridge.checkForAppUpdate();
if (!result?.ok) {
desktopUpdateCurrentVersion.value = result?.currentVersion || '-';
desktopUpdateLatestVersion.value =
result?.latestVersion || result?.currentVersion || '-';
desktopUpdateStatus.value =
result?.reason || t('core.header.updateDialog.desktopApp.checkFailed');
return;
}
desktopUpdateCurrentVersion.value = result.currentVersion || '-';
desktopUpdateLatestVersion.value =
result.latestVersion || result.currentVersion || '-';
desktopUpdateHasNewVersion.value = !!result.hasUpdate;
desktopUpdateStatus.value = result.hasUpdate
? t('core.header.updateDialog.desktopApp.hasNewVersion')
: t('core.header.updateDialog.desktopApp.isLatest');
} catch (error) {
console.error(error);
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.checkFailed');
} finally {
desktopUpdateChecking.value = false;
function confirmExternalRedirect() {
const targetUrl = pendingRedirectUrl.value;
cancelExternalRedirect();
if (targetUrl) {
open(targetUrl);
}
}
async function confirmDesktopUpdate() {
if (!desktopUpdateHasNewVersion.value || desktopUpdateInstalling.value) {
return;
const getReleaseUrlForDesktop = () => {
const firstRelease = (releases.value as any[])?.[0];
if (firstRelease?.tag_name) {
return getReleaseUrlByTag(firstRelease.tag_name as string);
}
const bridge = getAppUpdaterBridge();
if (!bridge) {
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.installFailed');
return;
}
desktopUpdateInstalling.value = true;
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.installing');
try {
const result = await bridge.installAppUpdate();
if (result?.ok) {
desktopUpdateDialog.value = false;
return;
}
desktopUpdateStatus.value =
result?.reason || t('core.header.updateDialog.desktopApp.installFailed');
} catch (error) {
console.error(error);
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.installFailed');
} finally {
desktopUpdateInstalling.value = false;
}
}
if (hasNewVersion.value) return getReleaseUrlByTag('latest');
const tag = botCurrVersion.value?.startsWith('v') ? botCurrVersion.value : 'latest';
return getReleaseUrlByTag(tag);
};
function handleUpdateClick() {
if (isDesktopReleaseMode.value) {
void openDesktopUpdateDialog();
requestExternalRedirect('');
resolvingReleaseTarget.value = true;
checkUpdate();
void getReleases().finally(() => {
pendingRedirectUrl.value = getReleaseUrlForDesktop() || getReleaseUrlByTag('latest');
resolvingReleaseTarget.value = false;
});
return;
}
checkUpdate();
@@ -714,38 +680,40 @@ onMounted(async () => {
</v-card>
</v-dialog>
<v-dialog v-model="desktopUpdateDialog" max-width="460">
<v-dialog v-model="redirectConfirmDialog" max-width="460">
<v-card>
<v-card-title class="text-h3 pa-4 pl-6 pb-0">
{{ t('core.header.updateDialog.desktopApp.title') }}
{{ t('core.header.updateDialog.redirectConfirm.title') }}
</v-card-title>
<v-card-text>
<div class="mb-3">
{{ t('core.header.updateDialog.desktopApp.message') }}
{{ t('core.header.updateDialog.redirectConfirm.message') }}
</div>
<v-alert type="info" variant="tonal" density="compact">
<div>
{{ t('core.header.updateDialog.desktopApp.currentVersion') }}
<strong>{{ desktopUpdateCurrentVersion }}</strong>
</div>
<div>
{{ t('core.header.updateDialog.desktopApp.latestVersion') }}
<strong v-if="!desktopUpdateChecking">{{ desktopUpdateLatestVersion }}</strong>
{{ t('core.header.updateDialog.redirectConfirm.targetVersion') }}
<strong v-if="!resolvingReleaseTarget">{{ latestReleaseTag }}</strong>
<v-progress-circular v-else indeterminate size="16" width="2" class="ml-1" />
</div>
<div class="text-caption">
{{ t('core.header.updateDialog.redirectConfirm.currentVersion') }}
{{ botCurrVersion || '-' }}
</div>
</v-alert>
<div class="text-caption mt-3">
{{ desktopUpdateStatus }}
<div>{{ t('core.header.updateDialog.redirectConfirm.guideTitle') }}</div>
<div>1. {{ t('core.header.updateDialog.redirectConfirm.guideStep1') }}</div>
<div>2. {{ t('core.header.updateDialog.redirectConfirm.guideStep2') }}</div>
<div>3. {{ t('core.header.updateDialog.redirectConfirm.guideStep3') }}</div>
</div>
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="grey" variant="text" @click="cancelDesktopUpdate" :disabled="desktopUpdateInstalling">
<v-btn color="grey" variant="text" @click="cancelExternalRedirect">
{{ t('core.common.dialog.cancelButton') }}
</v-btn>
<v-btn color="primary" variant="flat" @click="confirmDesktopUpdate"
:loading="desktopUpdateInstalling"
:disabled="desktopUpdateChecking || desktopUpdateInstalling || !desktopUpdateHasNewVersion">
<v-btn color="primary" variant="flat" @click="confirmExternalRedirect"
:loading="resolvingReleaseTarget" :disabled="resolvingReleaseTarget || !pendingRedirectUrl">
{{ t('core.common.dialog.confirmButton') }}
</v-btn>
</v-card-actions>
@@ -1,26 +1,7 @@
export {};
declare global {
interface AstrBotDesktopAppUpdateCheckResult {
ok: boolean;
reason?: string | null;
currentVersion?: string;
latestVersion?: string | null;
hasUpdate: boolean;
}
interface AstrBotDesktopAppUpdateResult {
ok: boolean;
reason?: string | null;
}
interface AstrBotAppUpdaterBridge {
checkForAppUpdate: () => Promise<AstrBotDesktopAppUpdateCheckResult>;
installAppUpdate: () => Promise<AstrBotDesktopAppUpdateResult>;
}
interface Window {
astrbotAppUpdater?: AstrBotAppUpdaterBridge;
astrbotDesktop?: {
isDesktop: boolean;
isDesktopRuntime: () => Promise<boolean>;
-1
View File
@@ -61,7 +61,6 @@ export function getTutorialLink(platformType) {
"vocechat": "https://docs.astrbot.app/deploy/platform/vocechat.html",
"satori": "https://docs.astrbot.app/deploy/platform/satori/llonebot.html",
"misskey": "https://docs.astrbot.app/deploy/platform/misskey.html",
"line": "https://docs.astrbot.app/deploy/platform/line.html",
}
return tutorialMap[platformType] || "https://docs.astrbot.app";
}
File diff suppressed because it is too large Load Diff
+3 -44
View File
@@ -333,53 +333,12 @@ const loadApiKeys = async () => {
}
};
const tryExecCommandCopy = (text) => {
let textArea = null;
try {
if (typeof document === 'undefined' || !document.body) return false;
textArea = document.createElement('textarea');
textArea.value = text;
textArea.setAttribute('readonly', '');
textArea.style.position = 'fixed';
textArea.style.opacity = '0';
textArea.style.pointerEvents = 'none';
textArea.style.left = '-9999px';
document.body.appendChild(textArea);
textArea.focus();
textArea.select();
textArea.setSelectionRange(0, text.length);
return document.execCommand('copy');
} catch (_) {
return false;
} finally {
try {
if (textArea?.parentNode) {
textArea.parentNode.removeChild(textArea);
}
} catch (_) {
// ignore cleanup errors
}
}
};
const copyTextToClipboard = async (text) => {
if (!text) return false;
if (tryExecCommandCopy(text)) return true;
if (typeof navigator === 'undefined' || !navigator.clipboard?.writeText) return false;
try {
await navigator.clipboard.writeText(text);
return true;
} catch (_) {
return false;
}
};
const copyCreatedApiKey = async () => {
if (!createdApiKeyPlaintext.value) return;
const ok = await copyTextToClipboard(createdApiKeyPlaintext.value);
if (ok) {
try {
await navigator.clipboard.writeText(createdApiKeyPlaintext.value);
showToast(tm('apiKey.messages.copySuccess'), 'success');
} else {
} catch (_) {
showToast(tm('apiKey.messages.copyFailed'), 'error');
}
};
+8 -8
View File
@@ -62,7 +62,7 @@
<template #label>
<div class="d-flex flex-column">
<span class="text-body-2 font-weight-medium">{{ tm('switches.enable') }}</span>
<span class="text-caption text-medium-emphasis">{{ tm('switches.enableHint') }}</span>
<span class="text-caption text-medium-emphasis">Enable sub-agent functionality</span>
</div>
</template>
</v-switch>
@@ -80,7 +80,7 @@
<template #label>
<div class="d-flex flex-column">
<span class="text-body-2 font-weight-medium">{{ tm('switches.dedupe') }}</span>
<span class="text-caption text-medium-emphasis">{{ tm('switches.dedupeHint') }}</span>
<span class="text-caption text-medium-emphasis">Remove duplicate tools from main agent</span>
</div>
</template>
</v-switch>
@@ -166,7 +166,7 @@
<v-text-field
v-model="agent.name"
:label="tm('form.nameLabel')"
:rules="[v => !!v || tm('messages.nameRequired'), v => /^[a-z][a-z0-9_]*$/.test(v) || tm('messages.namePattern')]"
:rules="[v => !!v || 'Name is required', v => /^[a-z][a-z0-9_]*$/.test(v) || 'Lowercase letters, numbers, underscore only']"
variant="outlined"
density="comfortable"
hide-details="auto"
@@ -215,7 +215,7 @@
<v-col cols="12" md="6">
<div class="h-100">
<div class="text-caption font-weight-bold text-medium-emphasis mb-2 ml-1">
{{ tm('cards.personaPreview') }}
PERSONA PREVIEW
</div>
<PersonaQuickPreview
:model-value="agent.persona_id"
@@ -231,17 +231,17 @@
<!-- Empty State -->
<div v-if="cfg.agents.length === 0" class="d-flex flex-column align-center justify-center py-12 text-medium-emphasis">
<v-icon icon="mdi-robot-off" size="64" class="mb-4 opacity-50" />
<div class="text-h6">{{ tm('empty.title') }}</div>
<div class="text-body-2 mb-4">{{ tm('empty.subtitle') }}</div>
<div class="text-h6">No Agents Configured</div>
<div class="text-body-2 mb-4">Add a new sub-agent to get started</div>
<v-btn color="primary" variant="tonal" @click="addAgent">
{{ tm('empty.action') }}
Create First Agent
</v-btn>
</div>
<v-snackbar v-model="snackbar.show" :color="snackbar.color" timeout="3000" location="top">
{{ snackbar.message }}
<template #actions>
<v-btn variant="text" @click="snackbar.show = false">{{ tm('actions.close') }}</v-btn>
<v-btn variant="text" @click="snackbar.show = false">Close</v-btn>
</template>
</v-snackbar>
</div>
@@ -1,639 +0,0 @@
<script setup>
import ExtensionCard from "@/components/shared/ExtensionCard.vue";
import StyledMenu from "@/components/shared/StyledMenu.vue";
import defaultPluginIcon from "@/assets/images/plugin_icon.png";
const props = defineProps({
state: {
type: Object,
required: true,
},
});
const {
commonStore,
t,
tm,
router,
route,
getSelectedGitHubProxy,
conflictDialog,
checkAndPromptConflicts,
handleConflictConfirm,
fileInput,
activeTab,
validTabs,
isValidTab,
getLocationHash,
extractTabFromHash,
syncTabFromHash,
extension_data,
getInitialShowReserved,
showReserved,
snack_message,
snack_show,
snack_success,
configDialog,
extension_config,
pluginMarketData,
loadingDialog,
showPluginInfoDialog,
selectedPlugin,
curr_namespace,
updatingAll,
readmeDialog,
forceUpdateDialog,
updateAllConfirmDialog,
changelogDialog,
getInitialListViewMode,
isListView,
pluginSearch,
loading_,
currentPage,
dangerConfirmDialog,
selectedDangerPlugin,
selectedMarketInstallPlugin,
installCompat,
versionCompatibilityDialog,
showUninstallDialog,
pluginToUninstall,
showSourceDialog,
showSourceManagerDialog,
sourceName,
sourceUrl,
customSources,
selectedSource,
showRemoveSourceDialog,
sourceToRemove,
editingSource,
originalSourceUrl,
extension_url,
dialog,
upload_file,
uploadTab,
showPluginFullName,
marketSearch,
debouncedMarketSearch,
refreshingMarket,
sortBy,
sortOrder,
randomPluginNames,
normalizeStr,
toPinyinText,
toInitials,
marketCustomFilter,
plugin_handler_info_headers,
pluginHeaders,
filteredExtensions,
filteredPlugins,
filteredMarketPlugins,
sortedPlugins,
RANDOM_PLUGINS_COUNT,
randomPlugins,
shufflePlugins,
refreshRandomPlugins,
displayItemsPerPage,
totalPages,
paginatedPlugins,
updatableExtensions,
toggleShowReserved,
toast,
resetLoadingDialog,
onLoadingDialogResult,
failedPluginsDict,
getExtensions,
handleReloadAllFailed,
checkUpdate,
uninstallExtension,
handleUninstallConfirm,
updateExtension,
showUpdateAllConfirm,
confirmUpdateAll,
cancelUpdateAll,
confirmForceUpdate,
updateAllExtensions,
pluginOn,
pluginOff,
openExtensionConfig,
updateConfig,
showPluginInfo,
reloadPlugin,
viewReadme,
viewChangelog,
handleInstallPlugin,
confirmDangerInstall,
cancelDangerInstall,
loadCustomSources,
saveCustomSources,
addCustomSource,
openSourceManagerDialog,
selectPluginSource,
sourceSelectItems,
editCustomSource,
removeCustomSource,
confirmRemoveSource,
saveCustomSource,
trimExtensionName,
checkAlreadyInstalled,
showVersionCompatibilityWarning,
continueInstallIgnoringVersionWarning,
cancelInstallOnVersionWarning,
newExtension,
normalizePlatformList,
getPlatformDisplayList,
resolveSelectedInstallPlugin,
selectedInstallPlugin,
checkInstallCompatibility,
refreshPluginMarket,
handleLocaleChange,
searchDebounceTimer,
} = props.state;
</script>
<template>
<v-tab-item v-show="activeTab === 'installed'">
<div class="mb-4 pt-4 pb-4">
<div class="d-flex align-center flex-wrap" style="gap: 12px">
<h2 class="text-h2 mb-0">{{ tm("titles.installedAstrBotPlugins") }}</h2>
<div class="d-flex align-center flex-wrap ml-auto" style="gap: 8px">
<v-text-field
v-model="pluginSearch"
density="compact"
:label="tm('search.placeholder')"
prepend-inner-icon="mdi-magnify"
variant="solo-filled"
flat
hide-details
single-line
style="min-width: 220px; max-width: 340px"
>
</v-text-field>
<v-btn-toggle
v-model="isListView"
mandatory
density="compact"
color="primary"
class="view-mode-toggle"
>
<v-btn :value="false" icon="mdi-view-grid"></v-btn>
<v-btn :value="true" icon="mdi-view-list"></v-btn>
</v-btn-toggle>
</div>
</div>
</div>
<v-row class="mb-4">
<v-col cols="12" class="d-flex align-center flex-wrap ga-2">
<v-btn variant="tonal" @click="toggleShowReserved">
<v-icon>{{
showReserved ? "mdi-eye-off" : "mdi-eye"
}}</v-icon>
{{
showReserved
? tm("buttons.hideSystemPlugins")
: tm("buttons.showSystemPlugins")
}}
</v-btn>
<v-btn
class="ml-2"
color="warning"
variant="tonal"
:disabled="updatableExtensions.length === 0"
:loading="updatingAll"
@click="showUpdateAllConfirm"
>
<v-icon>mdi-update</v-icon>
{{ tm("buttons.updateAll") }}
</v-btn>
<v-dialog max-width="500px" v-if="extension_data.message">
<template v-slot:activator="{ props }">
<v-btn
v-bind="props"
icon
size="small"
color="error"
class="ml-auto"
variant="tonal"
>
<v-icon>mdi-alert-circle</v-icon>
</v-btn>
</template>
<template v-slot:default="{ isActive }">
<v-card class="rounded-lg">
<v-card-title class="headline d-flex align-center">
<v-icon color="error" class="mr-2"
>mdi-alert-circle</v-icon
>
{{ tm("dialogs.error.title") }}
</v-card-title>
<v-card-text>
<p class="text-body-1">
{{ extension_data.message }}
</p>
<p class="text-caption mt-2">
{{ tm("dialogs.error.checkConsole") }}
</p>
</v-card-text>
<v-card-actions>
<v-btn
color="error"
variant="tonal"
prepend-icon="mdi-refresh"
@click="handleReloadAllFailed"
>
尝试一键重载修复
</v-btn>
<v-spacer></v-spacer>
<v-btn
color="primary"
@click="isActive.value = false"
>{{ tm("buttons.close") }}</v-btn
>
</v-card-actions>
</v-card>
</template>
</v-dialog>
</v-col>
</v-row>
<v-fade-transition hide-on-leave>
<!-- 表格视图 -->
<div v-if="isListView">
<v-card class="rounded-lg overflow-hidden elevation-0">
<v-data-table
:headers="pluginHeaders"
:items="filteredPlugins"
:loading="loading_"
item-key="name"
hover
>
<template v-slot:loader>
<v-row class="py-8 d-flex align-center justify-center">
<v-progress-circular
indeterminate
color="primary"
></v-progress-circular>
<span class="ml-2">{{ tm("status.loading") }}</span>
</v-row>
</template>
<template v-slot:item.name="{ item }">
<div class="d-flex align-center py-2">
<div
v-if="item.logo"
class="mr-3"
style="flex-shrink: 0"
>
<img
:src="item.logo"
:alt="item.name"
style="
height: 40px;
width: 40px;
border-radius: 8px;
object-fit: cover;
"
/>
</div>
<div v-else class="mr-3" style="flex-shrink: 0">
<img
:src="defaultPluginIcon"
:alt="item.name"
style="
height: 40px;
width: 40px;
border-radius: 8px;
object-fit: cover;
"
/>
</div>
<div>
<div class="text-h5" style="font-family: inherit;">
{{
item.display_name && item.display_name.length
? item.display_name
: item.name
}}
</div>
<div
v-if="item.display_name && item.display_name.length"
class="text-caption text-medium-emphasis mt-1"
>
{{ item.name }}
</div>
<div
v-if="item.reserved"
class="d-flex align-center mt-1"
>
<v-chip
color="primary"
size="x-small"
class="font-weight-medium"
>{{ tm("status.system") }}</v-chip
>
</div>
</div>
</div>
</template>
<template v-slot:item.desc="{ item }">
<div class="py-2">
<div
class="text-body-2 text-medium-emphasis"
style="
display: -webkit-box;
-webkit-line-clamp: 3;
line-clamp: 3;
-webkit-box-orient: vertical;
overflow: hidden;
text-overflow: ellipsis;
"
>
{{ item.desc }}
</div>
<div
v-if="item.support_platforms?.length"
class="d-flex align-center flex-wrap mt-2"
>
<span class="text-caption text-medium-emphasis mr-2">
{{ tm("card.status.supportPlatform") }}:
</span>
<v-chip
v-for="platformId in item.support_platforms"
:key="platformId"
size="x-small"
color="info"
variant="outlined"
class="mr-1 mb-1"
>
{{ platformId }}
</v-chip>
</div>
<div
v-if="item.astrbot_version"
class="d-flex align-center flex-wrap mt-1"
>
<span class="text-caption text-medium-emphasis mr-2">
{{ tm("card.status.astrbotVersion") }}:
</span>
<v-chip
size="x-small"
color="secondary"
variant="outlined"
class="mr-1 mb-1"
>
{{ item.astrbot_version }}
</v-chip>
</div>
</div>
</template>
<template v-slot:item.version="{ item }">
<div class="d-flex align-center">
<span class="text-body-2">{{ item.version }}</span>
<v-icon
v-if="item.has_update"
color="warning"
size="small"
class="ml-1"
>mdi-alert</v-icon
>
<v-tooltip v-if="item.has_update" activator="parent">
<span
>{{ tm("messages.hasUpdate") }}
{{ item.online_version }}</span
>
</v-tooltip>
</div>
</template>
<template v-slot:item.author="{ item }">
<div class="text-body-2">{{ item.author }}</div>
</template>
<template v-slot:item.actions="{ item }">
<div class="table-action-row d-flex align-center flex-nowrap ga-2 py-1">
<v-btn
v-if="!item.activated"
size="small"
variant="tonal"
color="success"
class="table-action-btn"
prepend-icon="mdi-play"
@click="pluginOn(item)"
>
{{ tm("buttons.enable") }}
</v-btn>
<v-btn
v-else
size="small"
variant="tonal"
color="error"
class="table-action-btn"
prepend-icon="mdi-pause"
@click="pluginOff(item)"
>
{{ tm("buttons.disable") }}
</v-btn>
<v-btn
size="small"
variant="tonal"
color="primary"
class="table-action-btn"
prepend-icon="mdi-refresh"
@click="reloadPlugin(item.name)"
>
{{ tm("buttons.reload") }}
</v-btn>
<v-btn
size="small"
variant="tonal"
color="primary"
class="table-action-btn"
prepend-icon="mdi-cog"
@click="openExtensionConfig(item.name)"
>
{{ tm("buttons.configure") }}
</v-btn>
<v-btn
size="small"
variant="tonal"
color="info"
class="table-action-btn"
prepend-icon="mdi-book-open-page-variant"
:disabled="!item.repo"
@click="item.repo && viewReadme(item)"
>
{{ tm("buttons.viewDocs") }}
</v-btn>
<StyledMenu location="bottom end" offset="8">
<template #activator="{ props: menuProps }">
<v-btn
v-bind="menuProps"
icon="mdi-dots-horizontal"
size="small"
variant="tonal"
color="secondary"
class="table-action-btn"
></v-btn>
</template>
<v-list-item
class="styled-menu-item"
prepend-icon="mdi-information"
@click="showPluginInfo(item)"
>
<v-list-item-title>{{ tm("buttons.viewInfo") }}</v-list-item-title>
</v-list-item>
<v-list-item
class="styled-menu-item"
prepend-icon="mdi-update"
@click="updateExtension(item.name)"
>
<v-list-item-title>{{ tm("buttons.update") }}</v-list-item-title>
</v-list-item>
<v-list-item
class="styled-menu-item"
prepend-icon="mdi-delete"
:disabled="item.reserved"
@click="uninstallExtension(item.name)"
>
<v-list-item-title>{{ tm("buttons.uninstall") }}</v-list-item-title>
</v-list-item>
</StyledMenu>
</div>
</template>
<template v-slot:no-data>
<div class="text-center pa-8">
<v-icon size="64" color="info" class="mb-4"
>mdi-puzzle-outline</v-icon
>
<div class="text-h5 mb-2">
{{ tm("empty.noPlugins") }}
</div>
<div class="text-body-1 mb-4">
{{ tm("empty.noPluginsDesc") }}
</div>
</div>
</template>
</v-data-table>
</v-card>
</div>
<!-- 卡片视图 -->
<div v-else>
<v-row v-if="filteredPlugins.length === 0" class="text-center">
<v-col cols="12" class="pa-2">
<v-icon size="64" color="info" class="mb-4"
>mdi-puzzle-outline</v-icon
>
<div class="text-h5 mb-2">{{ tm("empty.noPlugins") }}</div>
<div class="text-body-1 mb-4">
{{ tm("empty.noPluginsDesc") }}
</div>
</v-col>
</v-row>
<v-row>
<v-col
cols="12"
md="6"
lg="4"
v-for="extension in filteredPlugins"
:key="extension.name"
class="pb-2"
>
<ExtensionCard
:extension="extension"
class="rounded-lg"
style="background-color: rgb(var(--v-theme-mcpCardBg))"
@configure="openExtensionConfig(extension.name)"
@uninstall="
(ext, options) => uninstallExtension(ext.name, options)
"
@update="updateExtension(extension.name)"
@reload="reloadPlugin(extension.name)"
@toggle-activation="
extension.activated
? pluginOff(extension)
: pluginOn(extension)
"
@view-handlers="showPluginInfo(extension)"
@view-readme="viewReadme(extension)"
@view-changelog="viewChangelog(extension)"
>
</ExtensionCard>
</v-col>
</v-row>
</div>
</v-fade-transition>
<v-tooltip :text="tm('market.installPlugin')" location="left">
<template v-slot:activator="{ props }">
<button
v-bind="props"
type="button"
class="v-btn v-btn--elevated v-btn--icon v-theme--PurpleThemeDark bg-darkprimary v-btn--density-default v-btn--size-x-large v-btn--variant-elevated fab-button"
style="
position: fixed;
right: 52px;
bottom: 52px;
z-index: 10000;
border-radius: 16px;
"
@click="dialog = true"
>
<span class="v-btn__overlay"></span>
<span class="v-btn__underlay"></span>
<span class="v-btn__content" data-no-activator="">
<i
class="mdi-plus mdi v-icon notranslate v-theme--PurpleThemeDark v-icon--size-default"
aria-hidden="true"
style="font-size: 32px"
></i>
</span>
</button>
</template>
</v-tooltip>
</v-tab-item>
</template>
<style scoped>
.view-mode-toggle :deep(.v-btn) {
min-width: 30px;
height: 28px;
padding: 0 8px;
}
.table-action-btn {
min-height: 34px;
font-size: 0.9rem;
font-weight: 600;
}
.table-action-row {
overflow-x: auto;
white-space: nowrap;
}
.fab-button {
transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1);
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.fab-button:hover {
transform: translateY(-4px) scale(1.05);
box-shadow: 0 12px 20px rgba(var(--v-theme-primary), 0.4);
}
</style>
@@ -1,373 +0,0 @@
<script setup>
import MarketPluginCard from "@/components/extension/MarketPluginCard.vue";
import defaultPluginIcon from "@/assets/images/plugin_icon.png";
import { computed } from "vue";
const props = defineProps({
state: {
type: Object,
required: true,
},
});
const {
commonStore,
t,
tm,
router,
route,
getSelectedGitHubProxy,
conflictDialog,
checkAndPromptConflicts,
handleConflictConfirm,
fileInput,
activeTab,
validTabs,
isValidTab,
getLocationHash,
extractTabFromHash,
syncTabFromHash,
extension_data,
getInitialShowReserved,
showReserved,
snack_message,
snack_show,
snack_success,
configDialog,
extension_config,
pluginMarketData,
loadingDialog,
showPluginInfoDialog,
selectedPlugin,
curr_namespace,
updatingAll,
readmeDialog,
forceUpdateDialog,
updateAllConfirmDialog,
changelogDialog,
getInitialListViewMode,
isListView,
pluginSearch,
loading_,
currentPage,
dangerConfirmDialog,
selectedDangerPlugin,
selectedMarketInstallPlugin,
installCompat,
versionCompatibilityDialog,
showUninstallDialog,
pluginToUninstall,
showSourceDialog,
showSourceManagerDialog,
sourceName,
sourceUrl,
customSources,
selectedSource,
showRemoveSourceDialog,
sourceToRemove,
editingSource,
originalSourceUrl,
extension_url,
dialog,
upload_file,
uploadTab,
showPluginFullName,
marketSearch,
debouncedMarketSearch,
refreshingMarket,
sortBy,
sortOrder,
randomPluginNames,
normalizeStr,
toPinyinText,
toInitials,
marketCustomFilter,
plugin_handler_info_headers,
pluginHeaders,
filteredExtensions,
filteredPlugins,
filteredMarketPlugins,
sortedPlugins,
RANDOM_PLUGINS_COUNT,
randomPlugins,
shufflePlugins,
refreshRandomPlugins,
displayItemsPerPage,
totalPages,
paginatedPlugins,
updatableExtensions,
toggleShowReserved,
toast,
resetLoadingDialog,
onLoadingDialogResult,
failedPluginsDict,
getExtensions,
handleReloadAllFailed,
checkUpdate,
uninstallExtension,
handleUninstallConfirm,
updateExtension,
showUpdateAllConfirm,
confirmUpdateAll,
cancelUpdateAll,
confirmForceUpdate,
updateAllExtensions,
pluginOn,
pluginOff,
openExtensionConfig,
updateConfig,
showPluginInfo,
reloadPlugin,
viewReadme,
viewChangelog,
handleInstallPlugin,
confirmDangerInstall,
cancelDangerInstall,
loadCustomSources,
saveCustomSources,
addCustomSource,
openSourceManagerDialog,
selectPluginSource,
sourceSelectItems,
editCustomSource,
removeCustomSource,
confirmRemoveSource,
saveCustomSource,
trimExtensionName,
checkAlreadyInstalled,
showVersionCompatibilityWarning,
continueInstallIgnoringVersionWarning,
cancelInstallOnVersionWarning,
newExtension,
normalizePlatformList,
getPlatformDisplayList,
resolveSelectedInstallPlugin,
selectedInstallPlugin,
checkInstallCompatibility,
refreshPluginMarket,
handleLocaleChange,
searchDebounceTimer,
} = props.state;
const currentSourceName = computed(() => {
if (!selectedSource.value) {
return tm("market.defaultSource");
}
const matched = customSources.value.find((s) => s.url === selectedSource.value);
return matched?.name || tm("market.defaultSource");
});
</script>
<template>
<v-tab-item v-show="activeTab === 'market'">
<div class="mb-6 pt-4 pb-4">
<div class="d-flex align-center flex-wrap" style="gap: 12px">
<h2 class="text-h2 mb-0">{{ tm("tabs.market") }}</h2>
<v-tooltip location="top" :text="tm('market.sourceManagement')">
<template v-slot:activator="{ props }">
<v-btn
v-bind="props"
variant="tonal"
rounded="md"
color="primary"
class="text-none px-2"
@click="openSourceManagerDialog"
>
<v-icon size="18" class="mr-1">mdi-source-branch</v-icon>
<span class="text-truncate" style="max-width: 180px">
{{ currentSourceName }}
</span>
</v-btn>
</template>
</v-tooltip>
<v-text-field
v-model="marketSearch"
density="compact"
:label="tm('search.marketPlaceholder')"
prepend-inner-icon="mdi-magnify"
variant="solo-filled"
flat
hide-details
single-line
style="min-width: 220px; max-width: 340px"
>
</v-text-field>
</div>
<div
class="d-flex align-center text-caption text-medium-emphasis mt-2"
style="color: grey; line-height: 1.4"
>
<v-icon size="16" class="mr-1">mdi-alert-outline</v-icon>
<span>{{ tm("market.sourceSafetyWarning") }}</span>
</div>
</div>
<!-- <small style="color: var(--v-theme-secondaryText);">每个插件都是作者无偿提供的的劳动成果如果您喜欢某个插件 Star</small> -->
<!-- FAB Button -->
<v-tooltip :text="tm('market.installPlugin')" location="left">
<template v-slot:activator="{ props }">
<button
v-bind="props"
type="button"
class="v-btn v-btn--elevated v-btn--icon v-theme--PurpleThemeDark bg-darkprimary v-btn--density-default v-btn--size-x-large v-btn--variant-elevated fab-button"
style="
position: fixed;
right: 52px;
bottom: 52px;
z-index: 10000;
border-radius: 16px;
"
@click="dialog = true"
>
<span class="v-btn__overlay"></span>
<span class="v-btn__underlay"></span>
<span class="v-btn__content" data-no-activator="">
<i
class="mdi-plus mdi v-icon notranslate v-theme--PurpleThemeDark v-icon--size-default"
aria-hidden="true"
style="font-size: 32px"
></i>
</span>
</button>
</template>
</v-tooltip>
<div class="mt-4">
<div
class="d-flex align-center mb-2"
style="justify-content: space-between; flex-wrap: wrap; gap: 8px"
>
<h2>
{{ tm("market.randomPlugins") }}
</h2>
<v-btn
color="primary"
variant="tonal"
prepend-icon="mdi-shuffle-variant"
:disabled="pluginMarketData.length === 0"
@click="refreshRandomPlugins"
>
{{ tm("buttons.reshuffle") }}
</v-btn>
</div>
<v-row class="mb-6" dense>
<v-col
v-for="plugin in randomPlugins"
:key="`random-${plugin.name}`"
cols="12"
md="6"
lg="4"
class="pb-2"
>
<MarketPluginCard
:plugin="plugin"
:default-plugin-icon="defaultPluginIcon"
:show-plugin-full-name="showPluginFullName"
@install="handleInstallPlugin"
/>
</v-col>
</v-row>
<div
class="d-flex align-center mb-2"
style="
justify-content: space-between;
flex-wrap: wrap;
gap: 8px;
"
>
<div class="d-flex align-center" style="gap: 6px">
<h2>
{{ tm("market.allPlugins") }}({{
filteredMarketPlugins.length
}})
</h2>
<v-btn
icon
variant="text"
@click="refreshPluginMarket"
:loading="refreshingMarket"
>
<v-icon>mdi-refresh</v-icon>
</v-btn>
</div>
<div
class="d-flex align-center"
style="gap: 8px; flex-wrap: wrap"
>
<v-select
v-model="sortBy"
:items="[
{ title: tm('sort.default'), value: 'default' },
{ title: tm('sort.stars'), value: 'stars' },
{ title: tm('sort.author'), value: 'author' },
{ title: tm('sort.updated'), value: 'updated' },
]"
density="compact"
variant="outlined"
hide-details
style="max-width: 150px"
>
<template v-slot:prepend-inner>
<v-icon size="small">mdi-sort</v-icon>
</template>
</v-select>
<v-btn
icon
v-if="sortBy !== 'default'"
@click="sortOrder = sortOrder === 'desc' ? 'asc' : 'desc'"
variant="text"
density="compact"
>
<v-icon>{{
sortOrder === "desc"
? "mdi-sort-descending"
: "mdi-sort-ascending"
}}</v-icon>
<v-tooltip activator="parent" location="top">
{{
sortOrder === "desc"
? tm("sort.descending")
: tm("sort.ascending")
}}
</v-tooltip>
</v-btn>
</div>
</div>
<v-row style="min-height: 26rem" dense>
<v-col
v-for="plugin in paginatedPlugins"
:key="plugin.name"
cols="12"
md="6"
lg="4"
class="pb-2"
>
<MarketPluginCard
:plugin="plugin"
:default-plugin-icon="defaultPluginIcon"
:show-plugin-full-name="showPluginFullName"
@install="handleInstallPlugin"
/>
</v-col>
</v-row>
<div class="d-flex justify-center mt-4" v-if="totalPages > 1">
<v-pagination
v-model="currentPage"
:length="totalPages"
:total-visible="7"
size="small"
></v-pagination>
</div>
</div>
</v-tab-item>
</template>
File diff suppressed because it is too large Load Diff
+1
View File
@@ -43,6 +43,7 @@ export default defineConfig({
'/api': {
target: 'http://127.0.0.1:6185/',
changeOrigin: true,
ws: true
}
}
}
+1 -1
View File
@@ -1,6 +1,6 @@
[project]
name = "AstrBot"
version = "4.18.3"
version = "4.18.1"
description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md"
requires-python = ">=3.12"
-381
View File
@@ -1,381 +0,0 @@
"""
AstrBot 测试配置
提供共享的 pytest fixtures 和测试工具
"""
import json
import os
import sys
from asyncio import Queue
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
import pytest_asyncio
# 使用 tests/fixtures/helpers.py 中的共享工具函数,避免重复定义
from tests.fixtures.helpers import create_mock_llm_response, create_mock_message_component
# 将项目根目录添加到 sys.path
PROJECT_ROOT = Path(__file__).parent.parent
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
# 设置测试环境变量
os.environ.setdefault("TESTING", "true")
os.environ.setdefault("ASTRBOT_TEST_MODE", "true")
# ============================================================
# 测试收集和排序
# ============================================================
def pytest_collection_modifyitems(session, config, items): # noqa: ARG001
"""重新排序测试:单元测试优先,集成测试在后。"""
unit_tests = []
integration_tests = []
deselected = []
profile = config.getoption("--test-profile") or os.environ.get(
"ASTRBOT_TEST_PROFILE", "all"
)
for item in items:
item_path = Path(str(item.path))
is_integration = "integration" in item_path.parts
if is_integration:
if item.get_closest_marker("integration") is None:
item.add_marker(pytest.mark.integration)
item.add_marker(pytest.mark.tier_d)
integration_tests.append(item)
else:
if item.get_closest_marker("unit") is None:
item.add_marker(pytest.mark.unit)
if any(
item.get_closest_marker(marker) is not None
for marker in ("platform", "provider", "slow")
):
item.add_marker(pytest.mark.tier_c)
unit_tests.append(item)
# 单元测试 -> 集成测试
ordered_items = unit_tests + integration_tests
if profile == "blocking":
selected_items = []
for item in ordered_items:
if item.get_closest_marker("tier_c") or item.get_closest_marker("tier_d"):
deselected.append(item)
else:
selected_items.append(item)
if deselected:
config.hook.pytest_deselected(items=deselected)
items[:] = selected_items
return
items[:] = ordered_items
def pytest_addoption(parser):
"""增加测试执行档位选择。"""
parser.addoption(
"--test-profile",
action="store",
default=None,
choices=["all", "blocking"],
help="Select test profile. 'blocking' excludes auto-classified tier_c/tier_d tests.",
)
def pytest_configure(config):
"""注册自定义标记。"""
config.addinivalue_line("markers", "unit: 单元测试")
config.addinivalue_line("markers", "integration: 集成测试")
config.addinivalue_line("markers", "slow: 慢速测试")
config.addinivalue_line("markers", "platform: 平台适配器测试")
config.addinivalue_line("markers", "provider: LLM Provider 测试")
config.addinivalue_line("markers", "db: 数据库相关测试")
config.addinivalue_line("markers", "tier_c: C-tier tests (optional / non-blocking)")
config.addinivalue_line("markers", "tier_d: D-tier tests (extended / integration)")
# ============================================================
# 临时目录和文件 Fixtures
# ============================================================
@pytest.fixture
def temp_dir(tmp_path: Path) -> Path:
"""创建临时目录用于测试。"""
return tmp_path
@pytest.fixture
def event_queue() -> Queue:
"""Create a shared asyncio queue fixture for tests."""
return Queue()
@pytest.fixture
def platform_settings() -> dict:
"""Create a shared empty platform settings fixture for adapter tests."""
return {}
@pytest.fixture
def temp_data_dir(temp_dir: Path) -> Path:
"""创建模拟的 data 目录结构。"""
data_dir = temp_dir / "data"
data_dir.mkdir()
# 创建必要的子目录
(data_dir / "config").mkdir()
(data_dir / "plugins").mkdir()
(data_dir / "temp").mkdir()
(data_dir / "attachments").mkdir()
return data_dir
@pytest.fixture
def temp_config_file(temp_data_dir: Path) -> Path:
"""创建临时配置文件。"""
config_path = temp_data_dir / "config" / "cmd_config.json"
default_config = {
"provider": [],
"platform": [],
"provider_settings": {},
"default_personality": None,
"timezone": "Asia/Shanghai",
}
config_path.write_text(json.dumps(default_config, indent=2), encoding="utf-8")
return config_path
@pytest.fixture
def temp_db_file(temp_data_dir: Path) -> Path:
"""创建临时数据库文件路径。"""
return temp_data_dir / "test.db"
# ============================================================
# Mock Fixtures
# ============================================================
@pytest.fixture
def mock_provider():
"""创建模拟的 Provider。"""
provider = MagicMock()
provider.provider_config = {
"id": "test-provider",
"type": "openai_chat_completion",
"model": "gpt-4o-mini",
}
provider.get_model = MagicMock(return_value="gpt-4o-mini")
provider.text_chat = AsyncMock()
provider.text_chat_stream = AsyncMock()
provider.terminate = AsyncMock()
return provider
@pytest.fixture
def mock_platform():
"""创建模拟的 Platform。"""
platform = MagicMock()
platform.platform_name = "test_platform"
platform.platform_meta = MagicMock()
platform.platform_meta.support_proactive_message = False
platform.send_message = AsyncMock()
platform.terminate = AsyncMock()
return platform
@pytest.fixture
def mock_conversation():
"""创建模拟的 Conversation。"""
from astrbot.core.db.po import ConversationV2
return ConversationV2(
conversation_id="test-conv-id",
platform_id="test_platform",
user_id="test_user",
content=[],
persona_id=None,
)
@pytest.fixture
def mock_event():
"""创建模拟的 AstrMessageEvent。"""
event = MagicMock()
event.unified_msg_origin = "test_umo"
event.session_id = "test_session"
event.message_str = "Hello, world!"
event.message_obj = MagicMock()
event.message_obj.message = []
event.message_obj.sender = MagicMock()
event.message_obj.sender.user_id = "test_user"
event.message_obj.sender.nickname = "Test User"
event.message_obj.group_id = None
event.message_obj.group = None
event.get_platform_name = MagicMock(return_value="test_platform")
event.get_platform_id = MagicMock(return_value="test_platform")
event.get_group_id = MagicMock(return_value=None)
event.get_extra = MagicMock(return_value=None)
event.set_extra = MagicMock()
event.trace = MagicMock()
event.platform_meta = MagicMock()
event.platform_meta.support_proactive_message = False
return event
# ============================================================
# 配置 Fixtures
# ============================================================
@pytest.fixture
def astrbot_config(temp_config_file: Path):
"""创建 AstrBotConfig 实例。"""
from astrbot.core.config.astrbot_config import AstrBotConfig
config = AstrBotConfig()
config._config_path = str(temp_config_file) # noqa: SLF001
return config
@pytest.fixture
def main_agent_build_config():
"""创建 MainAgentBuildConfig 实例。"""
from astrbot.core.astr_main_agent import MainAgentBuildConfig
return MainAgentBuildConfig(
tool_call_timeout=60,
tool_schema_mode="full",
provider_wake_prefix="",
streaming_response=True,
sanitize_context_by_modalities=False,
kb_agentic_mode=False,
file_extract_enabled=False,
context_limit_reached_strategy="truncate_by_turns",
llm_safety_mode=True,
computer_use_runtime="local",
add_cron_tools=True,
)
# ============================================================
# 数据库 Fixtures
# ============================================================
@pytest_asyncio.fixture
async def temp_db(temp_db_file: Path):
"""创建临时数据库实例。"""
from astrbot.core.db.sqlite import SQLiteDatabase
db = SQLiteDatabase(str(temp_db_file))
try:
yield db
finally:
await db.engine.dispose()
if temp_db_file.exists():
temp_db_file.unlink()
# ============================================================
# Context Fixtures
# ============================================================
@pytest_asyncio.fixture
async def mock_context(
astrbot_config,
temp_db,
mock_provider,
mock_platform,
):
"""创建模拟的插件上下文。"""
from asyncio import Queue
from astrbot.core.star.context import Context
event_queue = Queue()
provider_manager = MagicMock()
provider_manager.get_using_provider = MagicMock(return_value=mock_provider)
provider_manager.get_provider_by_id = MagicMock(return_value=mock_provider)
platform_manager = MagicMock()
conversation_manager = MagicMock()
message_history_manager = MagicMock()
persona_manager = MagicMock()
persona_manager.personas_v3 = []
astrbot_config_mgr = MagicMock()
knowledge_base_manager = MagicMock()
cron_manager = MagicMock()
subagent_orchestrator = None
context = Context(
event_queue,
astrbot_config,
temp_db,
provider_manager,
platform_manager,
conversation_manager,
message_history_manager,
persona_manager,
astrbot_config_mgr,
knowledge_base_manager,
cron_manager,
subagent_orchestrator,
)
return context
# ============================================================
# Provider Request Fixtures
# ============================================================
@pytest.fixture
def provider_request():
"""创建 ProviderRequest 实例。"""
from astrbot.core.provider.entities import ProviderRequest
return ProviderRequest(
prompt="Hello",
session_id="test_session",
image_urls=[],
contexts=[],
system_prompt="You are a helpful assistant.",
)
# ============================================================
# 跳过条件
# ============================================================
def pytest_runtest_setup(item):
"""在测试运行前检查跳过条件。"""
# 跳过需要 API Key 但未设置的 Provider 测试
if item.get_closest_marker("provider"):
if not os.environ.get("TEST_PROVIDER_API_KEY"):
pytest.skip("TEST_PROVIDER_API_KEY not set")
# 跳过需要特定平台的测试
if item.get_closest_marker("platform"):
required_platform = None
marker = item.get_closest_marker("platform")
if marker and marker.args:
required_platform = marker.args[0]
if required_platform and not os.environ.get(
f"TEST_{required_platform.upper()}_ENABLED"
):
pytest.skip(f"TEST_{required_platform.upper()}_ENABLED not set")
-64
View File
@@ -1,64 +0,0 @@
"""
AstrBot 测试数据
此目录存放测试用的静态数据和配置文件
目录结构:
- fixtures/
configs/ # 测试配置文件
messages/ # 测试消息数据
plugins/ # 测试插件
knowledge_base/ # 测试知识库数据
mocks/ # Mock 模块
helpers.py # 辅助函数
"""
import json
from pathlib import Path
from .helpers import (
NoopAwaitable,
create_mock_discord_attachment,
create_mock_discord_channel,
create_mock_discord_user,
create_mock_file,
create_mock_llm_response,
create_mock_message_component,
create_mock_update,
make_platform_config,
)
FIXTURES_DIR = Path(__file__).parent
def load_fixture(filename: str) -> dict:
"""加载 JSON 格式的测试数据。"""
filepath = FIXTURES_DIR / filename
if not filepath.exists():
raise FileNotFoundError(f"Fixture not found: {filepath}")
return json.loads(filepath.read_text(encoding="utf-8"))
def get_fixture_path(filename: str) -> Path:
"""获取测试数据文件路径。"""
filepath = FIXTURES_DIR / filename
if not filepath.exists():
raise FileNotFoundError(f"Fixture not found: {filepath}")
return filepath
__all__ = [
"FIXTURES_DIR",
"load_fixture",
"get_fixture_path",
# 辅助函数
"NoopAwaitable",
"make_platform_config",
"create_mock_update",
"create_mock_file",
"create_mock_discord_attachment",
"create_mock_discord_user",
"create_mock_discord_channel",
"create_mock_message_component",
"create_mock_llm_response",
]
-21
View File
@@ -1,21 +0,0 @@
{
"provider": [
{
"id": "test-openai",
"type": "openai_chat_completion",
"model": "gpt-4o-mini",
"key": ["test-key"]
}
],
"platform": [],
"provider_settings": {
"default_personality": null,
"prompt_prefix": "",
"image_caption_provider_id": "",
"datetime_system_prompt": true,
"identifier": true,
"group_name_display": true
},
"default_personality": null,
"timezone": "Asia/Shanghai"
}
-332
View File
@@ -1,332 +0,0 @@
"""测试辅助函数和工具类。
提供统一的测试辅助工具减少测试代码重复
"""
from typing import Any
from unittest.mock import AsyncMock, MagicMock
from astrbot.core.message.components import BaseMessageComponent
class NoopAwaitable:
"""可等待的空操作对象。
用于 mock 需要返回 awaitable 对象的方法
"""
def __await__(self):
if False:
yield
return None
# ============================================================
# 平台配置工厂
# ============================================================
def make_platform_config(platform_type: str, **kwargs) -> dict:
"""平台配置工厂函数。
Args:
platform_type: 平台类型 (telegram, discord, aiocqhttp )
**kwargs: 覆盖默认配置的字段
Returns:
dict: 平台配置字典
"""
configs = {
"telegram": {
"id": "test_telegram",
"telegram_token": "test_token_123",
"telegram_api_base_url": "https://api.telegram.org/bot",
"telegram_file_base_url": "https://api.telegram.org/file/bot",
"telegram_command_register": True,
"telegram_command_auto_refresh": True,
"telegram_command_register_interval": 300,
"telegram_media_group_timeout": 2.5,
"telegram_media_group_max_wait": 10.0,
"start_message": "Welcome to AstrBot!",
},
"discord": {
"id": "test_discord",
"discord_token": "test_token_123",
"discord_proxy": None,
"discord_command_register": True,
"discord_guild_id_for_debug": None,
"discord_activity_name": "Playing AstrBot",
},
"aiocqhttp": {
"id": "test_aiocqhttp",
"ws_reverse_host": "0.0.0.0",
"ws_reverse_port": 6199,
"ws_reverse_token": "test_token",
},
"webchat": {
"id": "test_webchat",
},
"wecom": {
"id": "test_wecom",
"wecom_corpid": "test_corpid",
"wecom_secret": "test_secret",
},
}
config = configs.get(platform_type, {"id": f"test_{platform_type}"}).copy()
config.update(kwargs)
return config
# ============================================================
# Telegram 辅助函数
# ============================================================
def create_mock_update(
message_text: str | None = "Hello World",
chat_type: str = "private",
chat_id: int = 123456789,
user_id: int = 987654321,
username: str = "test_user",
message_id: int = 1,
media_group_id: str | None = None,
photo: list | None = None,
video: MagicMock | None = None,
document: MagicMock | None = None,
voice: MagicMock | None = None,
sticker: MagicMock | None = None,
reply_to_message: MagicMock | None = None,
caption: str | None = None,
entities: list | None = None,
caption_entities: list | None = None,
message_thread_id: int | None = None,
is_topic_message: bool = False,
):
"""创建模拟的 Telegram Update 对象。
Args:
message_text: 消息文本
chat_type: 聊天类型
chat_id: 聊天 ID
user_id: 用户 ID
username: 用户名
message_id: 消息 ID
media_group_id: 媒体组 ID
photo: 图片列表
video: 视频对象
document: 文档对象
voice: 语音对象
sticker: 贴纸对象
reply_to_message: 回复的消息
caption: 说明文字
entities: 实体列表
caption_entities: 说明实体列表
message_thread_id: 消息线程 ID
is_topic_message: 是否为主题消息
Returns:
MagicMock: 模拟的 Update 对象
"""
update = MagicMock()
update.update_id = 1
# Create message mock
message = MagicMock()
message.message_id = message_id
message.chat = MagicMock()
message.chat.id = chat_id
message.chat.type = chat_type
message.message_thread_id = message_thread_id
message.is_topic_message = is_topic_message
# Create user mock
from_user = MagicMock()
from_user.id = user_id
from_user.username = username
message.from_user = from_user
# Set message content
message.text = message_text
message.media_group_id = media_group_id
message.photo = photo
message.video = video
message.document = document
message.voice = voice
message.sticker = sticker
message.reply_to_message = reply_to_message
message.caption = caption
message.entities = entities
message.caption_entities = caption_entities
update.message = message
update.effective_chat = message.chat
return update
def create_mock_file(file_path: str = "https://api.telegram.org/file/test.jpg"):
"""创建模拟的 Telegram File 对象。
Args:
file_path: 文件路径
Returns:
MagicMock: 模拟的 File 对象
"""
file = MagicMock()
file.file_path = file_path
file.get_file = AsyncMock(return_value=file)
return file
# ============================================================
# Discord 辅助函数
# ============================================================
def create_mock_discord_attachment(
filename: str = "test.txt",
url: str = "https://cdn.discordapp.com/test.txt",
content_type: str | None = None,
size: int = 1024,
):
"""创建模拟的 Discord Attachment 对象。
Args:
filename: 文件名
url: 文件 URL
content_type: 内容类型
size: 文件大小
Returns:
MagicMock: 模拟的 Attachment 对象
"""
attachment = MagicMock()
attachment.filename = filename
attachment.url = url
attachment.content_type = content_type
attachment.size = size
return attachment
def create_mock_discord_user(
user_id: int = 123456789,
name: str = "TestUser",
display_name: str = "Test User",
bot: bool = False,
):
"""创建模拟的 Discord User 对象。
Args:
user_id: 用户 ID
name: 用户名
display_name: 显示名
bot: 是否为机器人
Returns:
MagicMock: 模拟的 User 对象
"""
user = MagicMock()
user.id = user_id
user.name = name
user.display_name = display_name
user.bot = bot
user.mention = f"<@{user_id}>"
return user
def create_mock_discord_channel(
channel_id: int = 111222333,
channel_type: str = "text",
name: str = "general",
guild_id: int | None = 444555666,
):
"""创建模拟的 Discord Channel 对象。
Args:
channel_id: 频道 ID
channel_type: 频道类型
name: 频道名
guild_id: 服务器 ID
Returns:
MagicMock: 模拟的 Channel 对象
"""
channel = MagicMock()
channel.id = channel_id
channel.name = name
channel.type = channel_type
if guild_id:
channel.guild = MagicMock()
channel.guild.id = guild_id
else:
channel.guild = None
return channel
# ============================================================
# 消息组件辅助函数
# ============================================================
def create_mock_message_component(
component_type: str,
**kwargs: Any,
) -> BaseMessageComponent:
"""创建模拟的消息组件。
Args:
component_type: 组件类型 (plain, image, at, reply, file)
**kwargs: 组件参数
Returns:
BaseMessageComponent: 消息组件实例
"""
from astrbot.core.message import components as Comp
component_map = {
"plain": Comp.Plain,
"image": Comp.Image,
"at": Comp.At,
"reply": Comp.Reply,
"file": Comp.File,
}
component_class = component_map.get(component_type.lower())
if not component_class:
raise ValueError(f"Unknown component type: {component_type}")
return component_class(**kwargs)
def create_mock_llm_response(
completion_text: str = "Hello! How can I help you?",
role: str = "assistant",
tools_call_name: list[str] | None = None,
tools_call_args: list[dict] | None = None,
tools_call_ids: list[str] | None = None,
):
"""创建模拟的 LLM 响应。
Args:
completion_text: 完成文本
role: 角色
tools_call_name: 工具调用名称列表
tools_call_args: 工具调用参数列表
tools_call_ids: 工具调用 ID 列表
Returns:
LLMResponse: 模拟的 LLM 响应
"""
from astrbot.core.provider.entities import LLMResponse, TokenUsage
return LLMResponse(
role=role,
completion_text=completion_text,
tools_call_name=tools_call_name or [],
tools_call_args=tools_call_args or [],
tools_call_ids=tools_call_ids or [],
usage=TokenUsage(input_other=10, output=5),
)
-33
View File
@@ -1,33 +0,0 @@
{
"plain_message": {
"type": "plain",
"text": "Hello, this is a test message."
},
"image_message": {
"type": "image",
"url": "https://example.com/test.jpg",
"file": null
},
"at_message": {
"type": "at",
"user_id": "12345",
"nickname": "TestUser"
},
"reply_message": {
"type": "reply",
"id": "msg_123",
"sender_nickname": "OriginalSender",
"message_str": "This is the original message"
},
"file_message": {
"type": "file",
"name": "test.pdf",
"url": "https://example.com/test.pdf"
},
"combined_message": {
"components": [
{"type": "at", "user_id": "bot_id"},
{"type": "plain", "text": " Hello bot!"}
]
}
}
-43
View File
@@ -1,43 +0,0 @@
"""测试 Mock 模块。
提供统一的 mock 工具和 fixture减少测试代码重复
使用方式:
# 在测试文件顶部导入需要的 fixture
from tests.fixtures.mocks import mock_telegram_modules
# 或使用 Builder 类创建 mock 对象
from tests.fixtures.mocks import MockTelegramBuilder
bot = MockTelegramBuilder.create_bot()
"""
from .aiocqhttp import (
MockAiocqhttpBuilder,
create_mock_aiocqhttp_modules,
mock_aiocqhttp_modules,
)
from .discord import (
MockDiscordBuilder,
create_mock_discord_modules,
mock_discord_modules,
)
from .telegram import (
MockTelegramBuilder,
create_mock_telegram_modules,
mock_telegram_modules,
)
__all__ = [
# Telegram
"mock_telegram_modules",
"create_mock_telegram_modules",
"MockTelegramBuilder",
# Discord
"mock_discord_modules",
"create_mock_discord_modules",
"MockDiscordBuilder",
# Aiocqhttp
"mock_aiocqhttp_modules",
"create_mock_aiocqhttp_modules",
"MockAiocqhttpBuilder",
]
-58
View File
@@ -1,58 +0,0 @@
"""Aiocqhttp 模块 Mock 工具。
提供统一的 aiocqhttp 相关模块 mock 设置避免在测试文件中重复定义
"""
import sys
from unittest.mock import AsyncMock, MagicMock
import pytest
def create_mock_aiocqhttp_modules():
"""创建 aiocqhttp 相关的 mock 模块。
Returns:
dict: 包含 aiocqhttp 和相关模块的 mock 对象
"""
mock_aiocqhttp = MagicMock()
mock_aiocqhttp.CQHttp = MagicMock
mock_aiocqhttp.Event = MagicMock
mock_aiocqhttp.exceptions = MagicMock()
mock_aiocqhttp.exceptions.ActionFailed = Exception
return mock_aiocqhttp
@pytest.fixture(scope="module", autouse=True)
def mock_aiocqhttp_modules():
"""Mock aiocqhttp 相关模块的 fixture。
自动应用于使用此 fixture 的测试模块
"""
mock_aiocqhttp = create_mock_aiocqhttp_modules()
monkeypatch = pytest.MonkeyPatch()
monkeypatch.setitem(sys.modules, "aiocqhttp", mock_aiocqhttp)
monkeypatch.setitem(sys.modules, "aiocqhttp.exceptions", mock_aiocqhttp.exceptions)
yield
monkeypatch.undo()
class MockAiocqhttpBuilder:
"""构建 aiocqhttp 测试 mock 对象的工具类。"""
@staticmethod
def create_bot():
"""创建 mock CQHttp bot 实例。"""
from tests.fixtures.helpers import NoopAwaitable
bot = MagicMock()
bot.send = AsyncMock()
bot.call_action = AsyncMock()
bot.on_request = MagicMock()
bot.on_notice = MagicMock()
bot.on_message = MagicMock()
bot.on_websocket_connection = MagicMock()
bot.run_task = MagicMock(return_value=NoopAwaitable())
return bot
-140
View File
@@ -1,140 +0,0 @@
"""Discord 模块 Mock 工具。
提供统一的 Discord 相关模块 mock 设置避免在测试文件中重复定义
"""
import sys
from unittest.mock import AsyncMock, MagicMock
import pytest
def create_mock_discord_modules():
"""创建 Discord 相关的 mock 模块。
Returns:
dict: 包含 discord 和相关模块的 mock 对象
"""
mock_discord = MagicMock()
# Mock discord.Intents
mock_intents = MagicMock()
mock_intents.default = MagicMock(return_value=mock_intents)
mock_discord.Intents = mock_intents
# Mock discord.Status
mock_discord.Status = MagicMock()
mock_discord.Status.online = "online"
# Mock discord.Bot
mock_bot = MagicMock()
mock_discord.Bot = MagicMock(return_value=mock_bot)
# Mock discord.Embed
mock_embed = MagicMock()
mock_discord.Embed = MagicMock(return_value=mock_embed)
# Mock discord.ui
mock_ui = MagicMock()
mock_ui.View = MagicMock
mock_ui.Button = MagicMock
mock_discord.ui = mock_ui
# Mock discord.Message
mock_discord.Message = MagicMock
# Mock discord.Interaction
mock_discord.Interaction = MagicMock
mock_discord.InteractionType = MagicMock()
mock_discord.InteractionType.application_command = 2
mock_discord.InteractionType.component = 3
# Mock discord.File
mock_discord.File = MagicMock
# Mock discord.SlashCommand
mock_discord.SlashCommand = MagicMock
# Mock discord.Option
mock_discord.Option = MagicMock
# Mock discord.SlashCommandOptionType
mock_discord.SlashCommandOptionType = MagicMock()
mock_discord.SlashCommandOptionType.string = 3
# Mock discord.errors
mock_discord.errors = MagicMock()
mock_discord.errors.LoginFailure = Exception
mock_discord.errors.ConnectionClosed = Exception
mock_discord.errors.NotFound = Exception
mock_discord.errors.Forbidden = Exception
# Mock discord.abc
mock_discord.abc = MagicMock()
mock_discord.abc.GuildChannel = MagicMock
mock_discord.abc.Messageable = MagicMock
mock_discord.abc.PrivateChannel = MagicMock
# Mock discord.channel
mock_channel = MagicMock()
mock_channel.DMChannel = MagicMock
mock_discord.channel = mock_channel
# Mock discord.types
mock_discord.types = MagicMock()
mock_discord.types.interactions = MagicMock()
# Mock discord.ApplicationContext
mock_discord.ApplicationContext = MagicMock
# Mock discord.CustomActivity
mock_discord.CustomActivity = MagicMock
return mock_discord
@pytest.fixture(scope="module", autouse=True)
def mock_discord_modules():
"""Mock Discord 相关模块的 fixture。
自动应用于使用此 fixture 的测试模块
"""
mock_discord = create_mock_discord_modules()
monkeypatch = pytest.MonkeyPatch()
monkeypatch.setitem(sys.modules, "discord", mock_discord)
monkeypatch.setitem(sys.modules, "discord.abc", mock_discord.abc)
monkeypatch.setitem(sys.modules, "discord.channel", mock_discord.channel)
monkeypatch.setitem(sys.modules, "discord.errors", mock_discord.errors)
monkeypatch.setitem(sys.modules, "discord.types", mock_discord.types)
monkeypatch.setitem(
sys.modules,
"discord.types.interactions",
mock_discord.types.interactions,
)
monkeypatch.setitem(sys.modules, "discord.ui", mock_discord.ui)
yield
monkeypatch.undo()
class MockDiscordBuilder:
"""构建 Discord 测试 mock 对象的工具类。"""
@staticmethod
def create_client():
"""创建 mock Discord client 实例。"""
client = MagicMock()
client.user = MagicMock()
client.user.id = 123456789
client.user.display_name = "TestBot"
client.user.name = "TestBot"
client.get_channel = MagicMock()
client.fetch_channel = AsyncMock()
client.get_message = MagicMock()
client.start = AsyncMock()
client.close = AsyncMock()
client.is_closed = MagicMock(return_value=False)
client.add_application_command = MagicMock()
client.sync_commands = AsyncMock()
client.change_presence = AsyncMock()
return client
-141
View File
@@ -1,141 +0,0 @@
"""Telegram 模块 Mock 工具。
提供统一的 Telegram 相关模块 mock 设置避免在测试文件中重复定义
"""
import sys
from unittest.mock import AsyncMock, MagicMock
import pytest
def create_mock_telegram_modules():
"""创建 Telegram 相关的 mock 模块。
Returns:
dict: 包含 telegram 和相关模块的 mock 对象
"""
mock_telegram = MagicMock()
mock_telegram.BotCommand = MagicMock
mock_telegram.Update = MagicMock
mock_telegram.constants = MagicMock()
mock_telegram.constants.ChatType = MagicMock()
mock_telegram.constants.ChatType.PRIVATE = "private"
mock_telegram.constants.ChatAction = MagicMock()
mock_telegram.constants.ChatAction.TYPING = "typing"
mock_telegram.constants.ChatAction.UPLOAD_VOICE = "upload_voice"
mock_telegram.constants.ChatAction.UPLOAD_DOCUMENT = "upload_document"
mock_telegram.constants.ChatAction.UPLOAD_PHOTO = "upload_photo"
mock_telegram.error = MagicMock()
mock_telegram.error.BadRequest = Exception
mock_telegram.ReactionTypeCustomEmoji = MagicMock
mock_telegram.ReactionTypeEmoji = MagicMock
mock_telegram_ext = MagicMock()
mock_telegram_ext.ApplicationBuilder = MagicMock
mock_telegram_ext.ContextTypes = MagicMock
mock_telegram_ext.ExtBot = MagicMock
mock_telegram_ext.filters = MagicMock()
mock_telegram_ext.filters.ALL = MagicMock()
mock_telegram_ext.MessageHandler = MagicMock
# Mock telegramify_markdown
mock_telegramify = MagicMock()
mock_telegramify.markdownify = lambda text, **kwargs: text
# Mock apscheduler
mock_apscheduler = MagicMock()
mock_apscheduler.schedulers = MagicMock()
mock_apscheduler.schedulers.asyncio = MagicMock()
mock_apscheduler.schedulers.asyncio.AsyncIOScheduler = MagicMock
mock_apscheduler.schedulers.background = MagicMock()
mock_apscheduler.schedulers.background.BackgroundScheduler = MagicMock
return {
"telegram": mock_telegram,
"telegram.ext": mock_telegram_ext,
"telegramify_markdown": mock_telegramify,
"apscheduler": mock_apscheduler,
}
@pytest.fixture(scope="module", autouse=True)
def mock_telegram_modules():
"""Mock Telegram 相关模块的 fixture。
自动应用于使用此 fixture 的测试模块
"""
mocks = create_mock_telegram_modules()
monkeypatch = pytest.MonkeyPatch()
monkeypatch.setitem(sys.modules, "telegram", mocks["telegram"])
monkeypatch.setitem(sys.modules, "telegram.constants", mocks["telegram"].constants)
monkeypatch.setitem(sys.modules, "telegram.error", mocks["telegram"].error)
monkeypatch.setitem(sys.modules, "telegram.ext", mocks["telegram.ext"])
monkeypatch.setitem(sys.modules, "telegramify_markdown", mocks["telegramify_markdown"])
monkeypatch.setitem(sys.modules, "apscheduler", mocks["apscheduler"])
monkeypatch.setitem(
sys.modules, "apscheduler.schedulers", mocks["apscheduler"].schedulers
)
monkeypatch.setitem(
sys.modules,
"apscheduler.schedulers.asyncio",
mocks["apscheduler"].schedulers.asyncio,
)
monkeypatch.setitem(
sys.modules,
"apscheduler.schedulers.background",
mocks["apscheduler"].schedulers.background,
)
yield
monkeypatch.undo()
class MockTelegramBuilder:
"""构建 Telegram 测试 mock 对象的工具类。"""
@staticmethod
def create_bot():
"""创建 mock Telegram bot 实例。"""
bot = MagicMock()
bot.username = "test_bot"
bot.id = 12345678
bot.base_url = "https://api.telegram.org/bottest_token_123/"
bot.send_message = AsyncMock()
bot.send_photo = AsyncMock()
bot.send_document = AsyncMock()
bot.send_voice = AsyncMock()
bot.send_chat_action = AsyncMock()
bot.delete_my_commands = AsyncMock()
bot.set_my_commands = AsyncMock()
bot.set_message_reaction = AsyncMock()
bot.edit_message_text = AsyncMock()
return bot
@staticmethod
def create_application():
"""创建 mock Telegram Application 实例。"""
from tests.fixtures.helpers import NoopAwaitable
app = MagicMock()
app.bot = MagicMock()
app.bot.username = "test_bot"
app.bot.base_url = "https://api.telegram.org/bottest_token_123/"
app.initialize = AsyncMock()
app.start = AsyncMock()
app.stop = AsyncMock()
app.add_handler = MagicMock()
app.updater = MagicMock()
app.updater.start_polling = MagicMock(return_value=NoopAwaitable())
app.updater.stop = AsyncMock()
return app
@staticmethod
def create_scheduler():
"""创建 mock APScheduler 实例。"""
scheduler = MagicMock()
scheduler.add_job = MagicMock()
scheduler.start = MagicMock()
scheduler.running = True
scheduler.shutdown = MagicMock()
return scheduler
-40
View File
@@ -1,40 +0,0 @@
"""
测试插件 - 用于插件系统测试
这是一个最小化的测试插件用于验证插件系统的功能
"""
from astrbot.api import llm_tool, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
@star.register("test_plugin", "AstrBot Team", "测试插件 - 用于插件系统测试", "1.0.0")
class TestPlugin(star.Star):
"""测试插件类"""
def __init__(self, context: star.Context) -> None:
super().__init__(context)
self.initialized = True
async def terminate(self) -> None:
"""插件终止"""
self.initialized = False
@filter.command("test_cmd")
async def test_command(self, event: AstrMessageEvent) -> None:
"""测试命令处理器。"""
event.set_result(MessageEventResult().message("测试命令执行成功"))
@llm_tool("test_tool")
async def test_llm_tool(self, query: str) -> str:
"""测试 LLM 工具。
Args:
query(string): 查询内容
"""
return f"测试工具执行成功: {query}"
@filter.regex(r"^test_regex_(.+)$")
async def test_regex_handler(self, event: AstrMessageEvent) -> None:
"""测试正则处理器。"""
event.set_result(MessageEventResult().message("正则匹配成功"))
-5
View File
@@ -1,5 +0,0 @@
name: test_plugin
description: 测试插件 - 用于插件系统测试
version: 1.0.0
author: AstrBot Team
repo: https://github.com/test/test_plugin
-115
View File
@@ -1,115 +0,0 @@
"""Smoke tests for critical startup and import paths."""
from __future__ import annotations
import subprocess
import sys
from pathlib import Path
from astrbot.core.pipeline.bootstrap import ensure_builtin_stages_registered
from astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal import (
InternalAgentSubStage,
)
from astrbot.core.pipeline.process_stage.method.agent_sub_stages.third_party import (
ThirdPartyAgentSubStage,
)
from astrbot.core.pipeline.stage import Stage, registered_stages
from astrbot.core.pipeline.stage_order import STAGES_ORDER
REPO_ROOT = Path(__file__).resolve().parents[1]
def _run_code_in_fresh_interpreter(code: str, failure_message: str) -> None:
proc = subprocess.run(
[sys.executable, "-c", code],
cwd=REPO_ROOT,
capture_output=True,
text=True,
check=False,
)
assert proc.returncode == 0, (
f"{failure_message}\nstdout:\n{proc.stdout}\nstderr:\n{proc.stderr}\n"
)
def test_smoke_critical_imports_in_fresh_interpreter() -> None:
code = (
"import importlib;"
"mods=["
"'astrbot.core.core_lifecycle',"
"'astrbot.core.astr_main_agent',"
"'astrbot.core.pipeline.scheduler',"
"'astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal',"
"'astrbot.core.pipeline.process_stage.method.agent_sub_stages.third_party'"
"];"
"[importlib.import_module(m) for m in mods]"
)
_run_code_in_fresh_interpreter(code, "Smoke import check failed.")
def test_smoke_pipeline_stage_registration_matches_order() -> None:
ensure_builtin_stages_registered()
stage_names = {cls.__name__ for cls in registered_stages}
assert set(STAGES_ORDER).issubset(stage_names)
assert len(stage_names) == len(registered_stages)
def test_smoke_agent_sub_stages_are_stage_subclasses() -> None:
assert issubclass(InternalAgentSubStage, Stage)
assert issubclass(ThirdPartyAgentSubStage, Stage)
def test_pipeline_package_exports_remain_compatible() -> None:
import astrbot.core.pipeline as pipeline
assert pipeline.ProcessStage is not None
assert pipeline.RespondStage is not None
assert isinstance(pipeline.STAGES_ORDER, list)
assert "ProcessStage" in pipeline.STAGES_ORDER
def test_builtin_stage_bootstrap_is_idempotent() -> None:
ensure_builtin_stages_registered()
before_count = len(registered_stages)
stage_names = {cls.__name__ for cls in registered_stages}
expected_stage_names = {
"WakingCheckStage",
"WhitelistCheckStage",
"SessionStatusCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",
}
assert expected_stage_names.issubset(stage_names)
ensure_builtin_stages_registered()
assert len(registered_stages) == before_count
def test_pipeline_import_is_stable_with_mocked_apscheduler() -> None:
"""Regression: importing pipeline should not require cron/apscheduler modules."""
code = (
"import sys;"
"from unittest.mock import MagicMock;"
"mock_apscheduler = MagicMock();"
"mock_apscheduler.schedulers = MagicMock();"
"mock_apscheduler.schedulers.asyncio = MagicMock();"
"mock_apscheduler.schedulers.background = MagicMock();"
"sys.modules['apscheduler'] = mock_apscheduler;"
"sys.modules['apscheduler.schedulers'] = mock_apscheduler.schedulers;"
"sys.modules['apscheduler.schedulers.asyncio'] = mock_apscheduler.schedulers.asyncio;"
"sys.modules['apscheduler.schedulers.background'] = mock_apscheduler.schedulers.background;"
"import astrbot.core.pipeline as pipeline;"
"assert pipeline.ProcessStage is not None;"
"assert pipeline.RespondStage is not None"
)
_run_code_in_fresh_interpreter(
code,
"Pipeline import should not depend on real apscheduler package.",
)

Some files were not shown because too many files have changed in this diff Show More