Compare commits

..

3 Commits

Author SHA1 Message Date
Soulter cfb0538b32 feat(chat): add websocket API key extraction and scope validation 2026-02-25 17:40:55 +08:00
Soulter f8f7e6d57a feat(webchat): refactor message parsing logic and integrate new parsing function 2026-02-25 17:14:02 +08:00
Soulter 53ae8cd7cf feat: implement websockets transport mode selection for chat
- Added transport mode selection (SSE/WebSocket) in the chat component.
- Updated conversation sidebar to include transport mode options.
- Integrated transport mode handling in message sending logic.
- Refactored message sending functions to support both SSE and WebSocket.
- Enhanced WebSocket connection management and message handling.
- Updated localization files for transport mode labels.
- Configured Vite to support WebSocket proxying.
2026-02-24 21:14:16 +08:00
62 changed files with 2456 additions and 3679 deletions
+1 -1
View File
@@ -37,7 +37,7 @@ jobs:
mkdir -p data/temp
export TESTING=true
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
pytest --cov=astrbot -v -o log_cli=true -o log_level=DEBUG
pytest --cov=. -v -o log_cli=true -o log_level=DEBUG
- name: Upload results to Codecov
uses: codecov/codecov-action@v5
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.18.2"
__version__ = "4.18.1"
+16 -92
View File
@@ -24,77 +24,15 @@ def _should_stop_agent(astr_event) -> bool:
return astr_event.is_stopped() or bool(astr_event.get_extra("agent_stop_requested"))
def _truncate_tool_result(text: str, limit: int = 70) -> str:
if limit <= 0:
return ""
if len(text) <= limit:
return text
if limit <= 3:
return text[:limit]
return f"{text[: limit - 3]}..."
def _extract_chain_json_data(msg_chain: MessageChain) -> dict | None:
if not msg_chain.chain:
return None
first_comp = msg_chain.chain[0]
if isinstance(first_comp, Json) and isinstance(first_comp.data, dict):
return first_comp.data
return None
def _record_tool_call_name(
tool_info: dict | None, tool_name_by_call_id: dict[str, str]
) -> None:
if not isinstance(tool_info, dict):
return
tool_call_id = tool_info.get("id")
tool_name = tool_info.get("name")
if tool_call_id is None or tool_name is None:
return
tool_name_by_call_id[str(tool_call_id)] = str(tool_name)
def _build_tool_call_status_message(tool_info: dict | None) -> str:
if tool_info:
return f"🔨 调用工具: {tool_info.get('name', 'unknown')}"
return "🔨 调用工具..."
def _build_tool_result_status_message(
msg_chain: MessageChain, tool_name_by_call_id: dict[str, str]
) -> str:
tool_name = "unknown"
tool_result = ""
result_data = _extract_chain_json_data(msg_chain)
if result_data:
tool_call_id = result_data.get("id")
if tool_call_id is not None:
tool_name = tool_name_by_call_id.pop(str(tool_call_id), "unknown")
tool_result = str(result_data.get("result", ""))
if not tool_result:
tool_result = msg_chain.get_plain_text(with_other_comps_mark=True)
tool_result = _truncate_tool_result(tool_result, 70)
status_msg = f"🔨 调用工具: {tool_name}"
if tool_result:
status_msg = f"{status_msg}\n📎 返回结果: {tool_result}"
return status_msg
async def run_agent(
agent_runner: AgentRunner,
max_step: int = 30,
show_tool_use: bool = True,
show_tool_call_result: bool = False,
stream_to_general: bool = False,
show_reasoning: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
step_idx = 0
astr_event = agent_runner.run_context.context.event
tool_name_by_call_id: dict[str, str] = {}
while step_idx < max_step + 1:
step_idx += 1
@@ -152,13 +90,6 @@ async def run_agent(
continue
if astr_event.get_platform_id() == "webchat":
await astr_event.send(msg_chain)
elif show_tool_use and show_tool_call_result:
status_msg = _build_tool_result_status_message(
msg_chain, tool_name_by_call_id
)
await astr_event.send(
MessageChain(type="tool_call").message(status_msg)
)
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
@@ -166,22 +97,25 @@ async def run_agent(
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
tool_info = _extract_chain_json_data(resp.data["chain"])
astr_event.trace.record(
"agent_tool_call",
tool_name=tool_info if tool_info else "unknown",
)
_record_tool_call_name(tool_info, tool_name_by_call_id)
tool_info = None
if resp.data["chain"].chain:
json_comp = resp.data["chain"].chain[0]
if isinstance(json_comp, Json):
tool_info = json_comp.data
astr_event.trace.record(
"agent_tool_call",
tool_name=tool_info if tool_info else "unknown",
)
if astr_event.get_platform_name() == "webchat":
await astr_event.send(resp.data["chain"])
elif show_tool_use:
if show_tool_call_result and isinstance(tool_info, dict):
# Delay tool status notification until tool_call_result.
continue
chain = MessageChain(type="tool_call").message(
_build_tool_call_status_message(tool_info)
)
if tool_info:
m = f"🔨 调用工具: {tool_info.get('name', 'unknown')}"
else:
m = "🔨 调用工具..."
chain = MessageChain(type="tool_call").message(m)
await astr_event.send(chain)
continue
@@ -268,7 +202,6 @@ async def run_live_agent(
tts_provider: TTSProvider | None = None,
max_step: int = 30,
show_tool_use: bool = True,
show_tool_call_result: bool = False,
show_reasoning: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
"""Live Mode 的 Agent 运行器,支持流式 TTS
@@ -278,7 +211,6 @@ async def run_live_agent(
tts_provider: TTS Provider 实例
max_step: 最大步数
show_tool_use: 是否显示工具使用
show_tool_call_result: 是否显示工具返回结果
show_reasoning: 是否显示推理过程
Yields:
@@ -290,7 +222,6 @@ async def run_live_agent(
agent_runner,
max_step=max_step,
show_tool_use=show_tool_use,
show_tool_call_result=show_tool_call_result,
stream_to_general=False,
show_reasoning=show_reasoning,
):
@@ -319,12 +250,7 @@ async def run_live_agent(
# 1. 启动 Agent Feeder 任务:负责运行 Agent 并将文本分句喂给 text_queue
feeder_task = asyncio.create_task(
_run_agent_feeder(
agent_runner,
text_queue,
max_step,
show_tool_use,
show_tool_call_result,
show_reasoning,
agent_runner, text_queue, max_step, show_tool_use, show_reasoning
)
)
@@ -410,7 +336,6 @@ async def _run_agent_feeder(
text_queue: asyncio.Queue,
max_step: int,
show_tool_use: bool,
show_tool_call_result: bool,
show_reasoning: bool,
) -> None:
"""运行 Agent 并将文本输出分句放入队列"""
@@ -420,7 +345,6 @@ async def _run_agent_feeder(
agent_runner,
max_step=max_step,
show_tool_use=show_tool_use,
show_tool_call_result=show_tool_call_result,
stream_to_general=False,
show_reasoning=show_reasoning,
):
+13 -67
View File
@@ -17,12 +17,6 @@ from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.astr_main_agent_resources import (
BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT,
EXECUTE_SHELL_TOOL,
FILE_DOWNLOAD_TOOL,
FILE_UPLOAD_TOOL,
LOCAL_EXECUTE_SHELL_TOOL,
LOCAL_PYTHON_TOOL,
PYTHON_TOOL,
SEND_MESSAGE_TO_USER_TOOL,
)
from astrbot.core.cron.events import CronMessageEvent
@@ -97,65 +91,6 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
yield r
return
@classmethod
def _get_runtime_computer_tools(cls, runtime: str) -> dict[str, FunctionTool]:
if runtime == "sandbox":
return {
EXECUTE_SHELL_TOOL.name: EXECUTE_SHELL_TOOL,
PYTHON_TOOL.name: PYTHON_TOOL,
FILE_UPLOAD_TOOL.name: FILE_UPLOAD_TOOL,
FILE_DOWNLOAD_TOOL.name: FILE_DOWNLOAD_TOOL,
}
if runtime == "local":
return {
LOCAL_EXECUTE_SHELL_TOOL.name: LOCAL_EXECUTE_SHELL_TOOL,
LOCAL_PYTHON_TOOL.name: LOCAL_PYTHON_TOOL,
}
return {}
@classmethod
def _build_handoff_toolset(
cls,
run_context: ContextWrapper[AstrAgentContext],
tools: list[str | FunctionTool] | None,
) -> ToolSet | None:
ctx = run_context.context.context
event = run_context.context.event
cfg = ctx.get_config(umo=event.unified_msg_origin)
provider_settings = cfg.get("provider_settings", {})
runtime = str(provider_settings.get("computer_use_runtime", "local"))
runtime_computer_tools = cls._get_runtime_computer_tools(runtime)
# Keep persona semantics aligned with the main agent: tools=None means
# "all tools", including runtime computer-use tools.
if tools is None:
toolset = ToolSet()
for registered_tool in llm_tools.func_list:
if isinstance(registered_tool, HandoffTool):
continue
if registered_tool.active:
toolset.add_tool(registered_tool)
for runtime_tool in runtime_computer_tools.values():
toolset.add_tool(runtime_tool)
return None if toolset.empty() else toolset
if not tools:
return None
toolset = ToolSet()
for tool_name_or_obj in tools:
if isinstance(tool_name_or_obj, str):
registered_tool = llm_tools.get_func(tool_name_or_obj)
if registered_tool and registered_tool.active:
toolset.add_tool(registered_tool)
continue
runtime_tool = runtime_computer_tools.get(tool_name_or_obj)
if runtime_tool:
toolset.add_tool(runtime_tool)
elif isinstance(tool_name_or_obj, FunctionTool):
toolset.add_tool(tool_name_or_obj)
return None if toolset.empty() else toolset
@classmethod
async def _execute_handoff(
cls,
@@ -166,8 +101,19 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
input_ = tool_args.get("input")
image_urls = tool_args.get("image_urls")
# Build handoff toolset from registered tools plus runtime computer tools.
toolset = cls._build_handoff_toolset(run_context, tool.agent.tools)
# make toolset for the agent
tools = tool.agent.tools
if tools:
toolset = ToolSet()
for t in tools:
if isinstance(t, str):
_t = llm_tools.get_func(t)
if _t:
toolset.add_tool(_t)
elif isinstance(t, FunctionTool):
toolset.add_tool(t)
else:
toolset = None
ctx = run_context.context.context
event = run_context.context.event
-5
View File
@@ -11,7 +11,6 @@ from astrbot.core.message.components import File
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..computer_client import get_booter
from .permissions import check_admin_permission
# @dataclass
# class CreateFileTool(FunctionTool):
@@ -103,8 +102,6 @@ class FileUploadTool(FunctionTool):
context: ContextWrapper[AstrAgentContext],
local_path: str,
) -> str | None:
if permission_error := check_admin_permission(context, "File upload/download"):
return permission_error
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
@@ -164,8 +161,6 @@ class FileDownloadTool(FunctionTool):
remote_path: str,
also_send_to_user: bool = True,
) -> ToolExecResult:
if permission_error := check_admin_permission(context, "File upload/download"):
return permission_error
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
@@ -1,19 +0,0 @@
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.astr_agent_context import AstrAgentContext
def check_admin_permission(
context: ContextWrapper[AstrAgentContext], operation_name: str
) -> str | None:
cfg = context.context.context.get_config(
umo=context.context.event.unified_msg_origin
)
provider_settings = cfg.get("provider_settings", {})
require_admin = provider_settings.get("computer_use_require_admin", True)
if require_admin and context.context.event.role != "admin":
return (
f"error: Permission denied. {operation_name} is only allowed for admin users. "
"Tell user to set admins in `AstrBot WebUI -> Config -> General Config` by adding their user ID to the admins list if they need this feature. "
f"User's ID is: {context.context.event.get_sender_id()}. User's ID can be found by using /sid command."
)
return None
+17 -3
View File
@@ -7,7 +7,6 @@ from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext, AstrMessageEvent
from astrbot.core.computer.computer_client import get_booter, get_local_booter
from astrbot.core.computer.tools.permissions import check_admin_permission
from astrbot.core.message.message_event_result import MessageChain
param_schema = {
@@ -27,6 +26,21 @@ param_schema = {
}
def _check_admin_permission(context: ContextWrapper[AstrAgentContext]) -> str | None:
cfg = context.context.context.get_config(
umo=context.context.event.unified_msg_origin
)
provider_settings = cfg.get("provider_settings", {})
require_admin = provider_settings.get("computer_use_require_admin", True)
if require_admin and context.context.event.role != "admin":
return (
"error: Permission denied. Python execution is only allowed for admin users. "
"Tell user to set admins in `AstrBot WebUI -> Config -> General Config` by adding their user ID to the admins list if they need this feature."
f"User's ID is: {context.context.event.get_sender_id()}. User's ID can be found by using /sid command."
)
return None
async def handle_result(result: dict, event: AstrMessageEvent) -> ToolExecResult:
data = result.get("data", {})
output = data.get("output", {})
@@ -67,7 +81,7 @@ class PythonTool(FunctionTool):
async def call(
self, context: ContextWrapper[AstrAgentContext], code: str, silent: bool = False
) -> ToolExecResult:
if permission_error := check_admin_permission(context, "Python execution"):
if permission_error := _check_admin_permission(context):
return permission_error
sb = await get_booter(
context.context.context,
@@ -90,7 +104,7 @@ class LocalPythonTool(FunctionTool):
async def call(
self, context: ContextWrapper[AstrAgentContext], code: str, silent: bool = False
) -> ToolExecResult:
if permission_error := check_admin_permission(context, "Python execution"):
if permission_error := _check_admin_permission(context):
return permission_error
sb = get_local_booter()
try:
+16 -2
View File
@@ -7,7 +7,21 @@ from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from ..computer_client import get_booter, get_local_booter
from .permissions import check_admin_permission
def _check_admin_permission(context: ContextWrapper[AstrAgentContext]) -> str | None:
cfg = context.context.context.get_config(
umo=context.context.event.unified_msg_origin
)
provider_settings = cfg.get("provider_settings", {})
require_admin = provider_settings.get("computer_use_require_admin", True)
if require_admin and context.context.event.role != "admin":
return (
"error: Permission denied. Shell execution is only allowed for admin users. "
"Tell user to set admins in `AstrBot WebUI -> Config -> General Config` by adding their user ID to the admins list if they need this feature."
f"User's ID is: {context.context.event.get_sender_id()}. User's ID can be found by using /sid command."
)
return None
@dataclass
@@ -47,7 +61,7 @@ class ExecuteShellTool(FunctionTool):
background: bool = False,
env: dict = {},
) -> ToolExecResult:
if permission_error := check_admin_permission(context, "Shell execution"):
if permission_error := _check_admin_permission(context):
return permission_error
if self.is_local:
+1 -14
View File
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.18.2"
VERSION = "4.18.1"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -100,7 +100,6 @@ DEFAULT_CONFIG = {
"dequeue_context_length": 1,
"streaming_response": False,
"show_tool_use_status": False,
"show_tool_call_result": False,
"sanitize_context_by_modalities": False,
"max_quoted_fallback_images": 20,
"quoted_message_parser": {
@@ -2307,9 +2306,6 @@ CONFIG_METADATA_2 = {
"show_tool_use_status": {
"type": "bool",
},
"show_tool_call_result": {
"type": "bool",
},
"unsupported_streaming_strategy": {
"type": "string",
},
@@ -2998,15 +2994,6 @@ CONFIG_METADATA_3 = {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.show_tool_call_result": {
"description": "输出函数调用返回结果",
"type": "bool",
"hint": "仅在输出函数调用状态启用时生效,展示结果前 70 个字符。",
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.show_tool_use_status": True,
},
},
"provider_settings.sanitize_context_by_modalities": {
"description": "按模型能力清理历史上下文",
"type": "bool",
+3 -28
View File
@@ -720,38 +720,13 @@ class File(BaseMessageComponent):
if allow_return_url and self.url:
return self.url
if self.file_:
path = self.file_
if path.startswith("file://"):
# 处理 file:// (2 slashes) 或 file:/// (3 slashes)
# pathlib.as_uri() 通常生成 file:///
path = path[7:]
# 兼容 Windows: file:///C:/path -> /C:/path -> C:/path
if (
os.name == "nt"
and len(path) > 2
and path[0] == "/"
and path[2] == ":"
):
path = path[1:]
if os.path.exists(path):
return os.path.abspath(path)
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
if self.url:
await self._download_file()
if self.file_:
path = self.file_
if path.startswith("file://"):
path = path[7:]
if (
os.name == "nt"
and len(path) > 2
and path[0] == "/"
and path[2] == ":"
):
path = path[1:]
return os.path.abspath(path)
return os.path.abspath(self.file_)
return ""
+21 -66
View File
@@ -1,60 +1,30 @@
"""Pipeline package exports.
This module intentionally avoids eager imports of all pipeline stage modules to
prevent import-time cycles. Stage classes remain available via lazy attribute
resolution for backward compatibility.
"""
from __future__ import annotations
from importlib import import_module
from typing import Any
from astrbot.core.message.message_event_result import (
EventResultType,
MessageEventResult,
)
from .stage_order import STAGES_ORDER
from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .rate_limit_check.stage import RateLimitStage
from .respond.stage import RespondStage
from .result_decorate.stage import ResultDecorateStage
from .session_status_check.stage import SessionStatusCheckStage
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
_LAZY_EXPORTS = {
"ContentSafetyCheckStage": (
"astrbot.core.pipeline.content_safety_check.stage",
"ContentSafetyCheckStage",
),
"PreProcessStage": (
"astrbot.core.pipeline.preprocess_stage.stage",
"PreProcessStage",
),
"ProcessStage": (
"astrbot.core.pipeline.process_stage.stage",
"ProcessStage",
),
"RateLimitStage": (
"astrbot.core.pipeline.rate_limit_check.stage",
"RateLimitStage",
),
"RespondStage": (
"astrbot.core.pipeline.respond.stage",
"RespondStage",
),
"ResultDecorateStage": (
"astrbot.core.pipeline.result_decorate.stage",
"ResultDecorateStage",
),
"SessionStatusCheckStage": (
"astrbot.core.pipeline.session_status_check.stage",
"SessionStatusCheckStage",
),
"WakingCheckStage": (
"astrbot.core.pipeline.waking_check.stage",
"WakingCheckStage",
),
"WhitelistCheckStage": (
"astrbot.core.pipeline.whitelist_check.stage",
"WhitelistCheckStage",
),
}
# 管道阶段顺序
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"SessionStatusCheckStage", # 检查会话是否整体启用
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
"RespondStage", # 发送消息
]
__all__ = [
"ContentSafetyCheckStage",
@@ -66,21 +36,6 @@ __all__ = [
"RespondStage",
"ResultDecorateStage",
"SessionStatusCheckStage",
"STAGES_ORDER",
"WakingCheckStage",
"WhitelistCheckStage",
]
def __getattr__(name: str) -> Any:
if name not in _LAZY_EXPORTS:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
module_path, attr_name = _LAZY_EXPORTS[name]
module = import_module(module_path)
value = getattr(module, attr_name)
globals()[name] = value
return value
def __dir__() -> list[str]:
return sorted(set(globals()) | set(__all__))
-52
View File
@@ -1,52 +0,0 @@
"""Pipeline bootstrap utilities."""
from importlib import import_module
from .stage import registered_stages
_BUILTIN_STAGE_MODULES = (
"astrbot.core.pipeline.waking_check.stage",
"astrbot.core.pipeline.whitelist_check.stage",
"astrbot.core.pipeline.session_status_check.stage",
"astrbot.core.pipeline.rate_limit_check.stage",
"astrbot.core.pipeline.content_safety_check.stage",
"astrbot.core.pipeline.preprocess_stage.stage",
"astrbot.core.pipeline.process_stage.stage",
"astrbot.core.pipeline.result_decorate.stage",
"astrbot.core.pipeline.respond.stage",
)
_EXPECTED_STAGE_NAMES = {
"WakingCheckStage",
"WhitelistCheckStage",
"SessionStatusCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",
}
_builtin_stages_registered = False
def ensure_builtin_stages_registered() -> None:
"""Ensure built-in pipeline stages are imported and registered."""
global _builtin_stages_registered
if _builtin_stages_registered:
return
stage_names = {stage_cls.__name__ for stage_cls in registered_stages}
if _EXPECTED_STAGE_NAMES.issubset(stage_names):
_builtin_stages_registered = True
return
for module_path in _BUILTIN_STAGE_MODULES:
import_module(module_path)
_builtin_stages_registered = True
__all__ = ["ensure_builtin_stages_registered"]
+2 -4
View File
@@ -1,9 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from astrbot.core.config import AstrBotConfig
from astrbot.core.star import PluginManager
from .context_utils import call_event_hook, call_handler
@@ -13,7 +11,7 @@ class PipelineContext:
"""上下文对象,包含管道执行所需的上下文信息"""
astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: Any # 插件管理器对象
plugin_manager: PluginManager # 插件管理器对象
astrbot_config_id: str
call_handler = call_handler
call_event_hook = call_event_hook
@@ -19,7 +19,6 @@ from astrbot.core.message.message_event_result import (
MessageEventResult,
ResultContentType,
)
from astrbot.core.pipeline.stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import (
LLMResponse,
@@ -31,6 +30,7 @@ from astrbot.core.utils.session_lock import session_lock_manager
from .....astr_agent_run_util import run_agent, run_live_agent
from ....context import PipelineContext, call_event_hook
from ...stage import Stage
class InternalAgentSubStage(Stage):
@@ -54,7 +54,6 @@ class InternalAgentSubStage(Stage):
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
self.show_tool_call_result: bool = settings.get("show_tool_call_result", False)
self.show_reasoning = settings.get("display_reasoning_text", False)
self.sanitize_context_by_modalities: bool = settings.get(
"sanitize_context_by_modalities",
@@ -241,7 +240,6 @@ class InternalAgentSubStage(Stage):
tts_provider,
self.max_step,
self.show_tool_use,
self.show_tool_call_result,
show_reasoning=self.show_reasoning,
),
),
@@ -271,7 +269,6 @@ class InternalAgentSubStage(Stage):
agent_runner,
self.max_step,
self.show_tool_use,
self.show_tool_call_result,
show_reasoning=self.show_reasoning,
),
),
@@ -300,7 +297,6 @@ class InternalAgentSubStage(Stage):
agent_runner,
self.max_step,
self.show_tool_use,
self.show_tool_call_result,
stream_to_general,
show_reasoning=self.show_reasoning,
):
@@ -8,7 +8,6 @@ from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
DashscopeAgentRunner,
)
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
MessageChain,
@@ -18,7 +17,6 @@ from astrbot.core.message.message_event_result import (
if TYPE_CHECKING:
from astrbot.core.agent.runners.base import BaseAgentRunner
from astrbot.core.pipeline.stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import (
ProviderRequest,
@@ -27,7 +25,9 @@ from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric
from .....astr_agent_context import AgentContextWrapper, AstrAgentContext
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
from ....context import PipelineContext, call_event_hook
from ...stage import Stage
AGENT_RUNNER_TYPE_KEY = {
"dify": "dify_agent_runner_provider_id",
+1 -3
View File
@@ -8,17 +8,15 @@ from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
)
from astrbot.core.utils.active_event_registry import active_event_registry
from .bootstrap import ensure_builtin_stages_registered
from . import STAGES_ORDER
from .context import PipelineContext
from .stage import registered_stages
from .stage_order import STAGES_ORDER
class PipelineScheduler:
"""管道调度器,负责调度各个阶段的执行"""
def __init__(self, context: PipelineContext) -> None:
ensure_builtin_stages_registered()
registered_stages.sort(
key=lambda x: STAGES_ORDER.index(x.__name__),
) # 按照顺序排序
-15
View File
@@ -1,15 +0,0 @@
"""Pipeline stage execution order."""
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"SessionStatusCheckStage", # 检查会话是否整体启用
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
"RespondStage", # 发送消息
]
__all__ = ["STAGES_ORDER"]
+11 -33
View File
@@ -52,19 +52,9 @@ class AstrMessageEvent(abc.ABC):
self.is_at_or_wake_command = False
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
self._extras: dict[str, Any] = {}
message_type = getattr(message_obj, "type", None)
if not isinstance(message_type, MessageType):
try:
message_type = MessageType(str(message_type))
except (ValueError, TypeError, AttributeError):
logger.warning(
f"Failed to convert message type {message_obj.type!r} to MessageType. "
f"Falling back to FRIEND_MESSAGE."
)
message_type = MessageType.FRIEND_MESSAGE
self.session = MessageSession(
platform_name=platform_meta.id,
message_type=message_type,
message_type=message_obj.type,
session_id=session_id,
)
# self.unified_msg_origin = str(self.session)
@@ -169,18 +159,15 @@ class AstrMessageEvent(abc.ABC):
除了文本消息外,其他消息类型会被转换为对应的占位符。如图片消息会被转换为 [图片]。
"""
return self._outline_chain(getattr(self.message_obj, "message", None))
return self._outline_chain(self.message_obj.message)
def get_messages(self) -> list[BaseMessageComponent]:
"""获取消息链。"""
return getattr(self.message_obj, "message", [])
return self.message_obj.message
def get_message_type(self) -> MessageType:
"""获取消息类型。"""
message_type = getattr(self.message_obj, "type", None)
if isinstance(message_type, MessageType):
return message_type
return self.session.message_type
return self.message_obj.type
def get_session_id(self) -> str:
"""获取会话id。"""
@@ -188,30 +175,21 @@ class AstrMessageEvent(abc.ABC):
def get_group_id(self) -> str:
"""获取群组id。如果不是群组消息,返回空字符串。"""
return getattr(self.message_obj, "group_id", "")
return self.message_obj.group_id
def get_self_id(self) -> str:
"""获取机器人自身的id。"""
return getattr(self.message_obj, "self_id", "")
return self.message_obj.self_id
def get_sender_id(self) -> str:
"""获取消息发送者的id。"""
sender = getattr(self.message_obj, "sender", None)
if sender and isinstance(getattr(sender, "user_id", None), str):
return sender.user_id
return ""
return self.message_obj.sender.user_id
def get_sender_name(self) -> str:
"""获取消息发送者的名称。(可能会返回空字符串)"""
sender = getattr(self.message_obj, "sender", None)
if not sender:
return ""
nickname = getattr(sender, "nickname", None)
if nickname is None:
return ""
if isinstance(nickname, str):
return nickname
return str(nickname)
if isinstance(self.message_obj.sender.nickname, str):
return self.message_obj.sender.nickname
return ""
def set_extra(self, key, value) -> None:
"""设置额外的信息。"""
@@ -230,7 +208,7 @@ class AstrMessageEvent(abc.ABC):
def is_private_chat(self) -> bool:
"""是否是私聊。"""
return self.get_message_type() == MessageType.FRIEND_MESSAGE
return self.message_obj.type.value == (MessageType.FRIEND_MESSAGE).value
def is_wake_up(self) -> bool:
"""是否是唤醒机器人的事件。"""
@@ -45,19 +45,6 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
if isinstance(segment, File):
# For File segments, we need to handle the file differently
d = await segment.to_dict()
file_val = d.get("data", {}).get("file", "")
if file_val:
import pathlib
try:
# 使用 pathlib 处理路径,能更好地处理 Windows/Linux 差异
path_obj = pathlib.Path(file_val)
# 如果是绝对路径且不包含协议头 (://),则转换为标准的 file: URI
if path_obj.is_absolute() and "://" not in file_val:
d["data"]["file"] = path_obj.as_uri()
except Exception:
# 如果不是合法路径(例如已经是特定的特殊字符串),则跳过转换
pass
return d
if isinstance(segment, Video):
d = await segment.to_dict()
@@ -1,5 +1,4 @@
import asyncio
import inspect
import itertools
import logging
import time
@@ -437,42 +436,7 @@ class AiocqhttpAdapter(Platform):
return coro
async def terminate(self) -> None:
if hasattr(self, "shutdown_event"):
self.shutdown_event.set()
await self._close_reverse_ws_connections()
async def _close_reverse_ws_connections(self) -> None:
api_clients = getattr(self.bot, "_wsr_api_clients", None)
event_clients = getattr(self.bot, "_wsr_event_clients", None)
ws_clients: set[Any] = set()
if isinstance(api_clients, dict):
ws_clients.update(api_clients.values())
if isinstance(event_clients, set):
ws_clients.update(event_clients)
close_tasks: list[Awaitable[Any]] = []
for ws in ws_clients:
close_func = getattr(ws, "close", None)
if not callable(close_func):
continue
try:
close_result = close_func(code=1000, reason="Adapter shutdown")
except TypeError:
close_result = close_func()
except Exception:
continue
if inspect.isawaitable(close_result):
close_tasks.append(close_result)
if close_tasks:
await asyncio.gather(*close_tasks, return_exceptions=True)
if isinstance(api_clients, dict):
api_clients.clear()
if isinstance(event_clients, set):
event_clients.clear()
self.shutdown_event.set()
async def shutdown_trigger_placeholder(self) -> None:
await self.shutdown_event.wait()
@@ -0,0 +1,465 @@
import json
import mimetypes
import shutil
import uuid
from collections.abc import Awaitable, Callable, Sequence
from pathlib import Path
from typing import Any
from astrbot.core.db.po import Attachment
from astrbot.core.message.components import (
File,
Image,
Json,
Plain,
Record,
Reply,
Video,
)
from astrbot.core.message.message_event_result import MessageChain
AttachmentGetter = Callable[[str], Awaitable[Attachment | None]]
AttachmentInserter = Callable[[str, str, str], Awaitable[Attachment | None]]
ReplyHistoryGetter = Callable[
[Any],
Awaitable[tuple[list[dict], str | None, str | None] | None],
]
MEDIA_PART_TYPES = {"image", "record", "file", "video"}
def strip_message_parts_path_fields(message_parts: list[dict]) -> list[dict]:
return [{k: v for k, v in part.items() if k != "path"} for part in message_parts]
def webchat_message_parts_have_content(message_parts: list[dict]) -> bool:
return any(
part.get("type") in ("plain", "image", "record", "file", "video")
and (part.get("text") or part.get("attachment_id") or part.get("filename"))
for part in message_parts
)
async def parse_webchat_message_parts(
message_parts: list,
*,
strict: bool = False,
include_empty_plain: bool = False,
verify_media_path_exists: bool = True,
reply_history_getter: ReplyHistoryGetter | None = None,
current_depth: int = 0,
max_reply_depth: int = 0,
cast_reply_id_to_str: bool = True,
) -> tuple[list, list[str], bool]:
"""Parse webchat message parts into components/text parts.
Returns:
tuple[list, list[str], bool]:
(components, plain_text_parts, has_non_reply_content)
"""
components = []
text_parts: list[str] = []
has_content = False
for part in message_parts:
if not isinstance(part, dict):
if strict:
raise ValueError("message part must be an object")
continue
part_type = str(part.get("type", "")).strip()
if part_type == "plain":
text = str(part.get("text", ""))
if text or include_empty_plain:
components.append(Plain(text=text))
text_parts.append(text)
if text:
has_content = True
continue
if part_type == "reply":
message_id = part.get("message_id")
if message_id is None:
if strict:
raise ValueError("reply part missing message_id")
continue
reply_chain = []
reply_message_str = str(part.get("selected_text", ""))
sender_id = None
sender_name = None
if reply_message_str:
reply_chain = [Plain(text=reply_message_str)]
elif (
reply_history_getter
and current_depth < max_reply_depth
and message_id is not None
):
reply_info = await reply_history_getter(message_id)
if reply_info:
reply_parts, sender_id, sender_name = reply_info
(
reply_chain,
reply_text_parts,
_,
) = await parse_webchat_message_parts(
reply_parts,
strict=strict,
include_empty_plain=include_empty_plain,
verify_media_path_exists=verify_media_path_exists,
reply_history_getter=reply_history_getter,
current_depth=current_depth + 1,
max_reply_depth=max_reply_depth,
cast_reply_id_to_str=cast_reply_id_to_str,
)
reply_message_str = "".join(reply_text_parts)
reply_id = str(message_id) if cast_reply_id_to_str else message_id
components.append(
Reply(
id=reply_id,
message_str=reply_message_str,
chain=reply_chain,
sender_id=sender_id,
sender_nickname=sender_name,
)
)
continue
if part_type not in MEDIA_PART_TYPES:
if strict:
raise ValueError(f"unsupported message part type: {part_type}")
continue
path = part.get("path")
if not path:
if strict:
raise ValueError(f"{part_type} part missing path")
continue
file_path = Path(str(path))
if verify_media_path_exists and not file_path.exists():
if strict:
raise ValueError(f"file not found: {file_path!s}")
continue
file_path_str = (
str(file_path.resolve()) if verify_media_path_exists else str(file_path)
)
has_content = True
if part_type == "image":
components.append(Image.fromFileSystem(file_path_str))
elif part_type == "record":
components.append(Record.fromFileSystem(file_path_str))
elif part_type == "video":
components.append(Video.fromFileSystem(file_path_str))
else:
filename = str(part.get("filename", "")).strip() or file_path.name
components.append(File(name=filename, file=file_path_str))
return components, text_parts, has_content
async def build_webchat_message_parts(
message_payload: str | list,
*,
get_attachment_by_id: AttachmentGetter,
strict: bool = False,
) -> list[dict]:
if isinstance(message_payload, str):
text = message_payload.strip()
return [{"type": "plain", "text": text}] if text else []
if not isinstance(message_payload, list):
if strict:
raise ValueError("message must be a string or list")
return []
message_parts: list[dict] = []
for part in message_payload:
if not isinstance(part, dict):
if strict:
raise ValueError("message part must be an object")
continue
part_type = str(part.get("type", "")).strip()
if part_type == "plain":
text = str(part.get("text", ""))
if text:
message_parts.append({"type": "plain", "text": text})
continue
if part_type == "reply":
message_id = part.get("message_id")
if message_id is None:
if strict:
raise ValueError("reply part missing message_id")
continue
message_parts.append(
{
"type": "reply",
"message_id": message_id,
"selected_text": str(part.get("selected_text", "")),
}
)
continue
if part_type not in MEDIA_PART_TYPES:
if strict:
raise ValueError(f"unsupported message part type: {part_type}")
continue
attachment_id = part.get("attachment_id")
if not attachment_id:
if strict:
raise ValueError(f"{part_type} part missing attachment_id")
continue
attachment = await get_attachment_by_id(str(attachment_id))
if not attachment:
if strict:
raise ValueError(f"attachment not found: {attachment_id}")
continue
attachment_path = Path(attachment.path)
message_parts.append(
{
"type": attachment.type,
"attachment_id": attachment.attachment_id,
"filename": attachment_path.name,
"path": str(attachment_path),
}
)
return message_parts
def webchat_message_parts_to_message_chain(
message_parts: list[dict],
*,
strict: bool = False,
) -> MessageChain:
components = []
has_content = False
for part in message_parts:
if not isinstance(part, dict):
if strict:
raise ValueError("message part must be an object")
continue
part_type = str(part.get("type", "")).strip()
if part_type == "plain":
text = str(part.get("text", ""))
if text:
components.append(Plain(text=text))
has_content = True
continue
if part_type == "reply":
message_id = part.get("message_id")
if message_id is None:
if strict:
raise ValueError("reply part missing message_id")
continue
components.append(
Reply(
id=str(message_id),
message_str=str(part.get("selected_text", "")),
chain=[],
)
)
continue
if part_type not in MEDIA_PART_TYPES:
if strict:
raise ValueError(f"unsupported message part type: {part_type}")
continue
path = part.get("path")
if not path:
if strict:
raise ValueError(f"{part_type} part missing path")
continue
file_path = Path(str(path))
if not file_path.exists():
if strict:
raise ValueError(f"file not found: {file_path!s}")
continue
file_path_str = str(file_path.resolve())
has_content = True
if part_type == "image":
components.append(Image.fromFileSystem(file_path_str))
elif part_type == "record":
components.append(Record.fromFileSystem(file_path_str))
elif part_type == "video":
components.append(Video.fromFileSystem(file_path_str))
else:
filename = str(part.get("filename", "")).strip() or file_path.name
components.append(File(name=filename, file=file_path_str))
if strict and (not components or not has_content):
raise ValueError("Message content is empty (reply only is not allowed)")
return MessageChain(chain=components)
async def build_message_chain_from_payload(
message_payload: str | list,
*,
get_attachment_by_id: AttachmentGetter,
strict: bool = True,
) -> MessageChain:
message_parts = await build_webchat_message_parts(
message_payload,
get_attachment_by_id=get_attachment_by_id,
strict=strict,
)
components, _, has_content = await parse_webchat_message_parts(
message_parts,
strict=strict,
)
if strict and (not components or not has_content):
raise ValueError("Message content is empty (reply only is not allowed)")
return MessageChain(chain=components)
async def create_attachment_part_from_existing_file(
filename: str,
*,
attach_type: str,
insert_attachment: AttachmentInserter,
attachments_dir: str | Path,
fallback_dirs: Sequence[str | Path] = (),
) -> dict | None:
basename = Path(filename).name
candidate_paths = [Path(attachments_dir) / basename]
candidate_paths.extend(Path(p) / basename for p in fallback_dirs)
file_path = next((path for path in candidate_paths if path.exists()), None)
if not file_path:
return None
mime_type, _ = mimetypes.guess_type(str(file_path))
attachment = await insert_attachment(
str(file_path),
attach_type,
mime_type or "application/octet-stream",
)
if not attachment:
return None
return {
"type": attach_type,
"attachment_id": attachment.attachment_id,
"filename": file_path.name,
}
async def message_chain_to_storage_message_parts(
message_chain: MessageChain,
*,
insert_attachment: AttachmentInserter,
attachments_dir: str | Path,
) -> list[dict]:
target_dir = Path(attachments_dir)
target_dir.mkdir(parents=True, exist_ok=True)
parts: list[dict] = []
for comp in message_chain.chain:
if isinstance(comp, Plain):
if comp.text:
parts.append({"type": "plain", "text": comp.text})
continue
if isinstance(comp, Json):
parts.append(
{"type": "plain", "text": json.dumps(comp.data, ensure_ascii=False)}
)
continue
if isinstance(comp, Image):
file_path = await comp.convert_to_file_path()
attachment_part = await _copy_file_to_attachment_part(
file_path=file_path,
attach_type="image",
insert_attachment=insert_attachment,
attachments_dir=target_dir,
)
if attachment_part:
parts.append(attachment_part)
continue
if isinstance(comp, Record):
file_path = await comp.convert_to_file_path()
attachment_part = await _copy_file_to_attachment_part(
file_path=file_path,
attach_type="record",
insert_attachment=insert_attachment,
attachments_dir=target_dir,
)
if attachment_part:
parts.append(attachment_part)
continue
if isinstance(comp, Video):
file_path = await comp.convert_to_file_path()
attachment_part = await _copy_file_to_attachment_part(
file_path=file_path,
attach_type="video",
insert_attachment=insert_attachment,
attachments_dir=target_dir,
)
if attachment_part:
parts.append(attachment_part)
continue
if isinstance(comp, File):
file_path = await comp.get_file()
attachment_part = await _copy_file_to_attachment_part(
file_path=file_path,
attach_type="file",
insert_attachment=insert_attachment,
attachments_dir=target_dir,
display_name=comp.name,
)
if attachment_part:
parts.append(attachment_part)
continue
return parts
async def _copy_file_to_attachment_part(
*,
file_path: str,
attach_type: str,
insert_attachment: AttachmentInserter,
attachments_dir: Path,
display_name: str | None = None,
) -> dict | None:
src_path = Path(file_path)
if not src_path.exists() or not src_path.is_file():
return None
suffix = src_path.suffix
target_path = attachments_dir / f"{uuid.uuid4().hex}{suffix}"
shutil.copy2(src_path, target_path)
mime_type, _ = mimetypes.guess_type(target_path.name)
attachment = await insert_attachment(
str(target_path),
attach_type,
mime_type or "application/octet-stream",
)
if not attachment:
return None
return {
"type": attach_type,
"attachment_id": attachment.attachment_id,
"filename": display_name or src_path.name,
}
@@ -3,12 +3,12 @@ import os
import time
import uuid
from collections.abc import Callable, Coroutine
from pathlib import Path
from typing import Any
from astrbot import logger
from astrbot.core import db_helper
from astrbot.core.db.po import PlatformMessageHistory
from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform import (
AstrBotMessage,
@@ -21,10 +21,23 @@ from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ...register import register_platform_adapter
from .message_parts_helper import (
message_chain_to_storage_message_parts,
parse_webchat_message_parts,
)
from .webchat_event import WebChatMessageEvent
from .webchat_queue_mgr import WebChatQueueMgr, webchat_queue_mgr
def _extract_conversation_id(session_id: str) -> str:
"""Extract raw webchat conversation id from event/session id."""
if session_id.startswith("webchat!"):
parts = session_id.split("!", 2)
if len(parts) == 3:
return parts[2]
return session_id
class QueueListener:
def __init__(
self,
@@ -57,13 +70,15 @@ class WebChatAdapter(Platform):
self.settings = platform_settings
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
self.attachments_dir = Path(get_astrbot_data_path()) / "attachments"
os.makedirs(self.imgs_dir, exist_ok=True)
self.attachments_dir.mkdir(parents=True, exist_ok=True)
self.metadata = PlatformMetadata(
name="webchat",
description="webchat",
id="webchat",
support_proactive_message=False,
support_proactive_message=True,
)
self._shutdown_event = asyncio.Event()
self._webchat_queue_mgr = webchat_queue_mgr
@@ -73,10 +88,67 @@ class WebChatAdapter(Platform):
session: MessageSesion,
message_chain: MessageChain,
) -> None:
message_id = f"active_{str(uuid.uuid4())}"
await WebChatMessageEvent._send(message_id, message_chain, session.session_id)
conversation_id = _extract_conversation_id(session.session_id)
active_request_ids = self._webchat_queue_mgr.list_back_request_ids(
conversation_id
)
subscription_request_ids = [
req_id for req_id in active_request_ids if req_id.startswith("ws_sub_")
]
target_request_ids = subscription_request_ids or active_request_ids
if target_request_ids:
for request_id in target_request_ids:
await WebChatMessageEvent._send(
request_id,
message_chain,
session.session_id,
)
else:
message_id = f"active_{uuid.uuid4()!s}"
await WebChatMessageEvent._send(
message_id,
message_chain,
session.session_id,
)
should_persist = (
bool(subscription_request_ids)
or not active_request_ids
or all(req_id.startswith("active_") for req_id in active_request_ids)
)
if should_persist:
try:
await self._save_proactive_message(conversation_id, message_chain)
except Exception as e:
logger.error(
f"[WebChatAdapter] Failed to save proactive message: {e}",
exc_info=True,
)
await super().send_by_session(session, message_chain)
async def _save_proactive_message(
self,
conversation_id: str,
message_chain: MessageChain,
) -> None:
message_parts = await message_chain_to_storage_message_parts(
message_chain,
insert_attachment=db_helper.insert_attachment,
attachments_dir=self.attachments_dir,
)
if not message_parts:
return
await db_helper.insert_platform_message_history(
platform_id="webchat",
user_id=conversation_id,
content={"type": "bot", "message": message_parts},
sender_id="bot",
sender_name="bot",
)
async def _get_message_history(
self, message_id: int
) -> PlatformMessageHistory | None:
@@ -98,72 +170,30 @@ class WebChatAdapter(Platform):
Returns:
tuple[list, list[str]]: (消息组件列表, 纯文本列表)
"""
components = []
text_parts = []
for part in message_parts:
part_type = part.get("type")
if part_type == "plain":
text = part.get("text", "")
components.append(Plain(text=text))
text_parts.append(text)
elif part_type == "reply":
message_id = part.get("message_id")
reply_chain = []
reply_message_str = part.get("selected_text", "")
sender_id = None
sender_name = None
async def get_reply_parts(
message_id: Any,
) -> tuple[list[dict], str | None, str | None] | None:
history = await self._get_message_history(message_id)
if not history or not history.content:
return None
if reply_message_str:
reply_chain = [Plain(text=reply_message_str)]
reply_parts = history.content.get("message", [])
if not isinstance(reply_parts, list):
return None
# recursively get the content of the referenced message, if selected_text is empty
if not reply_message_str and depth < max_depth and message_id:
history = await self._get_message_history(message_id)
if history and history.content:
reply_parts = history.content.get("message", [])
if isinstance(reply_parts, list):
(
reply_chain,
reply_text_parts,
) = await self._parse_message_parts(
reply_parts,
depth=depth + 1,
max_depth=max_depth,
)
reply_message_str = "".join(reply_text_parts)
sender_id = history.sender_id
sender_name = history.sender_name
components.append(
Reply(
id=message_id,
chain=reply_chain,
message_str=reply_message_str,
sender_id=sender_id,
sender_nickname=sender_name,
)
)
elif part_type == "image":
path = part.get("path")
if path:
components.append(Image.fromFileSystem(path))
elif part_type == "record":
path = part.get("path")
if path:
components.append(Record.fromFileSystem(path))
elif part_type == "file":
path = part.get("path")
if path:
filename = part.get("filename") or (
os.path.basename(path) if path else "file"
)
components.append(File(name=filename, file=path))
elif part_type == "video":
path = part.get("path")
if path:
components.append(Video.fromFileSystem(path))
return reply_parts, history.sender_id, history.sender_name
components, text_parts, _ = await parse_webchat_message_parts(
message_parts,
strict=False,
include_empty_plain=True,
verify_media_path_exists=False,
reply_history_getter=get_reply_parts,
current_depth=depth,
max_reply_depth=max_depth,
cast_reply_id_to_str=False,
)
return components, text_parts
async def convert_message(self, data: tuple) -> AstrBotMessage:
@@ -14,6 +14,15 @@ from .webchat_queue_mgr import webchat_queue_mgr
attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
def _extract_conversation_id(session_id: str) -> str:
"""Extract raw webchat conversation id from event/session id."""
if session_id.startswith("webchat!"):
parts = session_id.split("!", 2)
if len(parts) == 3:
return parts[2]
return session_id
class WebChatMessageEvent(AstrMessageEvent):
def __init__(self, message_str, message_obj, platform_meta, session_id) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
@@ -27,7 +36,7 @@ class WebChatMessageEvent(AstrMessageEvent):
streaming: bool = False,
) -> str | None:
request_id = str(message_id)
conversation_id = session_id.split("!")[-1]
conversation_id = _extract_conversation_id(session_id)
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
request_id,
conversation_id,
@@ -130,7 +139,7 @@ class WebChatMessageEvent(AstrMessageEvent):
reasoning_content = ""
message_id = self.message_obj.message_id
request_id = str(message_id)
conversation_id = self.session_id.split("!")[-1]
conversation_id = _extract_conversation_id(self.session_id)
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
request_id,
conversation_id,
@@ -75,6 +75,10 @@ class WebChatQueueMgr:
if task is not None:
task.cancel()
def list_back_request_ids(self, conversation_id: str) -> list[str]:
"""List active back-queue request IDs for a conversation."""
return list(self._conversation_back_requests.get(conversation_id, set()))
def has_queue(self, conversation_id: str) -> bool:
"""Check if a queue exists for the given conversation ID"""
return conversation_id in self.queues
@@ -48,9 +48,6 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
result = await self.client.models.embed_content(
model=self.model,
contents=text,
config=types.EmbedContentConfig(
output_dimensionality=self.get_dim(),
),
)
assert result.embeddings is not None
assert result.embeddings[0].values is not None
@@ -64,9 +61,6 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
result = await self.client.models.embed_content(
model=self.model,
contents=cast(types.ContentListUnion, text),
config=types.EmbedContentConfig(
output_dimensionality=self.get_dim(),
),
)
assert result.embeddings is not None
@@ -36,20 +36,12 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
async def get_embedding(self, text: str) -> list[float]:
"""获取文本的嵌入"""
embedding = await self.client.embeddings.create(
input=text,
model=self.model,
dimensions=self.get_dim(),
)
embedding = await self.client.embeddings.create(input=text, model=self.model)
return embedding.data[0].embedding
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入"""
embeddings = await self.client.embeddings.create(
input=text,
model=self.model,
dimensions=self.get_dim(),
)
embeddings = await self.client.embeddings.create(input=text, model=self.model)
return [item.embedding for item in embeddings.data]
def get_dim(self) -> int:
+62 -13
View File
@@ -1,19 +1,68 @@
# 兼容导出: Provider 从 provider 模块重新导出
from astrbot.core import html_renderer
from astrbot.core.provider import Provider
from astrbot.core.star.star_tools import StarTools
from astrbot.core.utils.command_parser import CommandParserMixin
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
from .base import Star
from .context import Context
from .star import StarMetadata, star_map, star_registry
from .star_manager import PluginManager
from .star_tools import StarTools
__all__ = [
"Context",
"PluginManager",
"Provider",
"Star",
"StarMetadata",
"StarTools",
"star_map",
"star_registry",
]
class Star(CommandParserMixin, PluginKVStoreMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
author: str
name: str
def __init__(self, context: Context, config: dict | None = None) -> None:
StarTools.initialize(context)
self.context = context
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not star_map.get(cls.__module__):
metadata = StarMetadata(
star_cls_type=cls,
module_path=cls.__module__,
)
star_map[cls.__module__] = metadata
star_registry.append(metadata)
else:
star_map[cls.__module__].star_cls_type = cls
star_map[cls.__module__].module_path = cls.__module__
async def text_to_image(self, text: str, return_url=True) -> str:
"""将文本转换为图片"""
return await html_renderer.render_t2i(
text,
return_url=return_url,
template_name=self.context._config.get("t2i_active_template"),
)
async def html_render(
self,
tmpl: str,
data: dict,
return_url=True,
options: dict | None = None,
) -> str:
"""渲染 HTML"""
return await html_renderer.render_custom_template(
tmpl,
data,
return_url=return_url,
options=options,
)
async def initialize(self) -> None:
"""当插件被激活时会调用这个方法"""
async def terminate(self) -> None:
"""当插件被禁用、重载插件时会调用这个方法"""
def __del__(self) -> None:
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
__all__ = ["Context", "PluginManager", "Provider", "Star", "StarMetadata", "StarTools"]
-87
View File
@@ -1,87 +0,0 @@
from __future__ import annotations
import logging
from typing import Any, Protocol
from astrbot.core import html_renderer
from astrbot.core.utils.command_parser import CommandParserMixin
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
from .star import StarMetadata, star_map, star_registry
logger = logging.getLogger("astrbot")
class Star(CommandParserMixin, PluginKVStoreMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
author: str
name: str
class _ContextLike(Protocol):
def get_config(self, umo: str | None = None) -> Any: ...
def __init__(self, context: _ContextLike, config: dict | None = None) -> None:
self.context = context
def _get_context_config(self) -> Any:
get_config = getattr(self.context, "get_config", None)
if callable(get_config):
try:
return get_config()
except Exception as e:
logger.debug(f"get_config() failed: {e}")
return None
return getattr(self.context, "_config", None)
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not star_map.get(cls.__module__):
metadata = StarMetadata(
star_cls_type=cls,
module_path=cls.__module__,
)
star_map[cls.__module__] = metadata
star_registry.append(metadata)
else:
star_map[cls.__module__].star_cls_type = cls
star_map[cls.__module__].module_path = cls.__module__
async def text_to_image(self, text: str, return_url=True) -> str:
"""将文本转换为图片"""
config_obj = self._get_context_config()
template_name = None
if hasattr(config_obj, "get"):
try:
template_name = config_obj.get("t2i_active_template")
except Exception:
template_name = None
return await html_renderer.render_t2i(
text,
return_url=return_url,
template_name=template_name,
)
async def html_render(
self,
tmpl: str,
data: dict,
return_url=True,
options: dict | None = None,
) -> str:
"""渲染 HTML"""
return await html_renderer.render_custom_template(
tmpl,
data,
return_url=return_url,
options=options,
)
async def initialize(self) -> None:
"""当插件被激活时会调用这个方法"""
async def terminate(self) -> None:
"""当插件被禁用、重载插件时会调用这个方法"""
def __del__(self) -> None:
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
+4 -16
View File
@@ -1,9 +1,7 @@
from __future__ import annotations
import logging
from asyncio import Queue
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, Protocol
from typing import Any
from deprecated import deprecated
@@ -14,12 +12,14 @@ from astrbot.core.agent.tool import ToolSet
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.cron.manager import CronJobManager
from astrbot.core.db import BaseDatabase
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.platform import Platform
from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType
from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager
@@ -45,15 +45,6 @@ from .star_handler import EventType, StarHandlerMetadata, star_handlers_registry
logger = logging.getLogger("astrbot")
if TYPE_CHECKING:
from astrbot.core.cron.manager import CronJobManager
else:
CronJobManager = Any
class PlatformManagerProtocol(Protocol):
platform_insts: list[Platform]
class Context:
"""暴露给插件的接口上下文。"""
@@ -70,7 +61,7 @@ class Context:
config: AstrBotConfig,
db: BaseDatabase,
provider_manager: ProviderManager,
platform_manager: PlatformManagerProtocol,
platform_manager: PlatformManager,
conversation_manager: ConversationManager,
message_history_manager: PlatformMessageHistoryManager,
persona_manager: PersonaManager,
@@ -457,9 +448,6 @@ class Context:
if platform.meta().id == session.platform_name:
await platform.send_by_session(session, message_chain)
return True
logger.warning(
f"cannot find platform for session {str(session)}, message not sent"
)
return False
def add_llm_tools(self, *tools: FunctionTool) -> None:
+1 -1
View File
@@ -1,6 +1,6 @@
import warnings
from astrbot.core.star.star import StarMetadata, star_map
from astrbot.core.star import StarMetadata, star_map
_warned_register_star = False
+5 -4
View File
@@ -11,6 +11,7 @@ from astrbot.core.agent.agent import Agent
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
@@ -616,7 +617,7 @@ class RegisteringAgent:
kwargs["registering_agent"] = self
return register_llm_tool(*args, **kwargs)
def __init__(self, agent: Agent[Any]) -> None:
def __init__(self, agent: Agent[AstrAgentContext]) -> None:
self._agent = agent
@@ -624,7 +625,7 @@ def register_agent(
name: str,
instruction: str,
tools: list[str | FunctionTool] | None = None,
run_hooks: BaseAgentRunHooks[Any] | None = None,
run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
):
"""注册一个 Agent
@@ -638,12 +639,12 @@ def register_agent(
tools_ = tools or []
def decorator(awaitable: Callable[..., Awaitable[Any]]):
AstrAgent = Agent[Any]
AstrAgent = Agent[AstrAgentContext]
agent = AstrAgent(
name=name,
instructions=instruction,
tools=tools_,
run_hooks=run_hooks or BaseAgentRunHooks[Any](),
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
)
handoff_tool = HandoffTool(agent=agent)
handoff_tool.handler = awaitable
-3
View File
@@ -49,13 +49,10 @@ class PluginVersionIncompatibleError(Exception):
class PluginManager:
def __init__(self, context: Context, config: AstrBotConfig) -> None:
from .star_tools import StarTools
self.updator = PluginUpdator()
self.context = context
self.context._star_manager = self # type: ignore
StarTools.initialize(context)
self.config = config
self.plugin_store_path = get_astrbot_plugin_path()
+26 -92
View File
@@ -1,6 +1,5 @@
import asyncio
import json
import mimetypes
import os
import re
import uuid
@@ -14,6 +13,12 @@ from astrbot.core import logger, sp
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.platform.message_type import MessageType
from astrbot.core.platform.sources.webchat.message_parts_helper import (
build_webchat_message_parts,
create_attachment_part_from_existing_file,
strip_message_parts_path_fields,
webchat_message_parts_have_content,
)
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.active_event_registry import active_event_registry
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
@@ -166,83 +171,24 @@ class ChatRoute(Route):
)
async def _build_user_message_parts(self, message: str | list) -> list[dict]:
"""构建用户消息的部分列表
Args:
message: 文本消息 (str) 或消息段列表 (list)
"""
parts = []
if isinstance(message, list):
for part in message:
part_type = part.get("type")
if part_type == "plain":
parts.append({"type": "plain", "text": part.get("text", "")})
elif part_type == "reply":
parts.append(
{
"type": "reply",
"message_id": part.get("message_id"),
"selected_text": part.get("selected_text", ""),
}
)
elif attachment_id := part.get("attachment_id"):
attachment = await self.db.get_attachment_by_id(attachment_id)
if attachment:
parts.append(
{
"type": attachment.type,
"attachment_id": attachment.attachment_id,
"filename": os.path.basename(attachment.path),
"path": attachment.path, # will be deleted
}
)
return parts
if message:
parts.append({"type": "plain", "text": message})
return parts
"""构建用户消息的部分列表"""
return await build_webchat_message_parts(
message,
get_attachment_by_id=self.db.get_attachment_by_id,
strict=False,
)
async def _create_attachment_from_file(
self, filename: str, attach_type: str
) -> dict | None:
"""从本地文件创建 attachment 并返回消息部分
用于处理 bot 回复中的媒体文件
Args:
filename: 存储的文件名
attach_type: 附件类型 (image, record, file, video)
"""
basename = os.path.basename(filename)
candidate_paths = [
os.path.join(self.attachments_dir, basename),
os.path.join(self.legacy_img_dir, basename),
]
file_path = next((p for p in candidate_paths if os.path.exists(p)), None)
if not file_path:
return None
# guess mime type
mime_type, _ = mimetypes.guess_type(filename)
if not mime_type:
mime_type = "application/octet-stream"
# insert attachment
attachment = await self.db.insert_attachment(
path=file_path,
type=attach_type,
mime_type=mime_type,
"""从本地文件创建 attachment 并返回消息部分"""
return await create_attachment_part_from_existing_file(
filename,
attach_type=attach_type,
insert_attachment=self.db.insert_attachment,
attachments_dir=self.attachments_dir,
fallback_dirs=[self.legacy_img_dir],
)
if not attachment:
return None
return {
"type": attach_type,
"attachment_id": attachment.attachment_id,
"filename": os.path.basename(file_path),
}
def _extract_web_search_refs(
self, accumulated_text: str, accumulated_parts: list
@@ -356,21 +302,6 @@ class ChatRoute(Route):
selected_model = post_data.get("selected_model")
enable_streaming = post_data.get("enable_streaming", True)
# 检查消息是否为空
if isinstance(message, list):
has_content = any(
part.get("type") in ("plain", "image", "record", "file", "video")
for part in message
)
if not has_content:
return (
Response()
.error("Message content is empty (reply only is not allowed)")
.__dict__
)
elif not message:
return Response().error("Message are both empty").__dict__
if not session_id:
return Response().error("session_id is empty").__dict__
@@ -378,6 +309,12 @@ class ChatRoute(Route):
# 构建用户消息段(包含 path 用于传递给 adapter
message_parts = await self._build_user_message_parts(message)
if not webchat_message_parts_have_content(message_parts):
return (
Response()
.error("Message content is empty (reply only is not allowed)")
.__dict__
)
message_id = str(uuid.uuid4())
back_queue = webchat_queue_mgr.get_or_create_back_queue(
@@ -583,10 +520,7 @@ class ChatRoute(Route):
),
)
message_parts_for_storage = []
for part in message_parts:
part_copy = {k: v for k, v in part.items() if k != "path"}
message_parts_for_storage.append(part_copy)
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
await self.platform_history_mgr.insert(
platform_id="webchat",
+509 -3
View File
@@ -1,6 +1,7 @@
import asyncio
import json
import os
import re
import time
import uuid
import wave
@@ -10,9 +11,16 @@ import jwt
from quart import websocket
from astrbot import logger
from astrbot.core import sp
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.sources.webchat.message_parts_helper import (
build_webchat_message_parts,
create_attachment_part_from_existing_file,
strip_message_parts_path_fields,
webchat_message_parts_have_content,
)
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
from .route import Route, RouteContext
@@ -30,6 +38,9 @@ class LiveChatSession:
self.audio_frames: list[bytes] = []
self.current_stamp: str | None = None
self.temp_audio_path: str | None = None
self.chat_subscriptions: dict[str, str] = {}
self.chat_subscription_tasks: dict[str, asyncio.Task] = {}
self.ws_send_lock = asyncio.Lock()
def start_speaking(self, stamp: str) -> None:
"""开始说话"""
@@ -106,13 +117,26 @@ class LiveChatRoute(Route):
self.core_lifecycle = core_lifecycle
self.db = db
self.plugin_manager = core_lifecycle.plugin_manager
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
self.sessions: dict[str, LiveChatSession] = {}
self.attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
self.legacy_img_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
os.makedirs(self.attachments_dir, exist_ok=True)
# 注册 WebSocket 路由
self.app.websocket("/api/live_chat/ws")(self.live_chat_ws)
self.app.websocket("/api/unified_chat/ws")(self.unified_chat_ws)
async def live_chat_ws(self) -> None:
"""Live Chat WebSocket 处理器"""
"""Legacy Live Chat WebSocket 处理器(默认 ct=live"""
await self._unified_ws_loop(force_ct="live")
async def unified_chat_ws(self) -> None:
"""Unified Chat WebSocket 处理器(支持 ct=live/chat"""
await self._unified_ws_loop(force_ct=None)
async def _unified_ws_loop(self, force_ct: str | None = None) -> None:
"""统一 WebSocket 循环"""
# WebSocket 不能通过 header 传递 token,需要从 query 参数获取
# 注意:WebSocket 上下文使用 websocket.args 而不是 request.args
token = websocket.args.get("token")
@@ -140,7 +164,11 @@ class LiveChatRoute(Route):
try:
while True:
message = await websocket.receive_json()
await self._handle_message(live_session, message)
ct = force_ct or message.get("ct", "live")
if ct == "chat":
await self._handle_chat_message(live_session, message)
else:
await self._handle_message(live_session, message)
except Exception as e:
logger.error(f"[Live Chat] WebSocket 错误: {e}", exc_info=True)
@@ -148,10 +176,488 @@ class LiveChatRoute(Route):
finally:
# 清理会话
if session_id in self.sessions:
await self._cleanup_chat_subscriptions(live_session)
live_session.cleanup()
del self.sessions[session_id]
logger.info(f"[Live Chat] WebSocket 连接关闭: {username}")
async def _create_attachment_from_file(
self, filename: str, attach_type: str
) -> dict | None:
"""从本地文件创建 attachment 并返回消息部分。"""
return await create_attachment_part_from_existing_file(
filename,
attach_type=attach_type,
insert_attachment=self.db.insert_attachment,
attachments_dir=self.attachments_dir,
fallback_dirs=[self.legacy_img_dir],
)
def _extract_web_search_refs(
self, accumulated_text: str, accumulated_parts: list
) -> dict:
"""从消息中提取 web_search 引用。"""
supported = ["web_search_tavily", "web_search_bocha"]
web_search_results = {}
tool_call_parts = [
p
for p in accumulated_parts
if p.get("type") == "tool_call" and p.get("tool_calls")
]
for part in tool_call_parts:
for tool_call in part["tool_calls"]:
if tool_call.get("name") not in supported or not tool_call.get(
"result"
):
continue
try:
result_data = json.loads(tool_call["result"])
for item in result_data.get("results", []):
if idx := item.get("index"):
web_search_results[idx] = {
"url": item.get("url"),
"title": item.get("title"),
"snippet": item.get("snippet"),
}
except (json.JSONDecodeError, KeyError):
pass
if not web_search_results:
return {}
ref_indices = {
m.strip() for m in re.findall(r"<ref>(.*?)</ref>", accumulated_text)
}
used_refs = []
for ref_index in ref_indices:
if ref_index not in web_search_results:
continue
payload = {"index": ref_index, **web_search_results[ref_index]}
if favicon := sp.temporary_cache.get("_ws_favicon", {}).get(payload["url"]):
payload["favicon"] = favicon
used_refs.append(payload)
return {"used": used_refs} if used_refs else {}
async def _save_bot_message(
self,
webchat_conv_id: str,
text: str,
media_parts: list,
reasoning: str,
agent_stats: dict,
refs: dict,
):
"""保存 bot 消息到历史记录。"""
bot_message_parts = []
bot_message_parts.extend(media_parts)
if text:
bot_message_parts.append({"type": "plain", "text": text})
new_his = {"type": "bot", "message": bot_message_parts}
if reasoning:
new_his["reasoning"] = reasoning
if agent_stats:
new_his["agent_stats"] = agent_stats
if refs:
new_his["refs"] = refs
return await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=webchat_conv_id,
content=new_his,
sender_id="bot",
sender_name="bot",
)
async def _send_chat_payload(self, session: LiveChatSession, payload: dict) -> None:
async with session.ws_send_lock:
await websocket.send_json(payload)
async def _forward_chat_subscription(
self,
session: LiveChatSession,
chat_session_id: str,
request_id: str,
) -> None:
back_queue = webchat_queue_mgr.get_or_create_back_queue(
request_id, chat_session_id
)
try:
while True:
result = await back_queue.get()
if not result:
continue
await self._send_chat_payload(session, {"ct": "chat", **result})
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(
f"[Live Chat] chat subscription forward failed ({chat_session_id}): {e}",
exc_info=True,
)
finally:
webchat_queue_mgr.remove_back_queue(request_id)
if session.chat_subscriptions.get(chat_session_id) == request_id:
session.chat_subscriptions.pop(chat_session_id, None)
session.chat_subscription_tasks.pop(chat_session_id, None)
async def _ensure_chat_subscription(
self,
session: LiveChatSession,
chat_session_id: str,
) -> str:
existing_request_id = session.chat_subscriptions.get(chat_session_id)
existing_task = session.chat_subscription_tasks.get(chat_session_id)
if existing_request_id and existing_task and not existing_task.done():
return existing_request_id
request_id = f"ws_sub_{uuid.uuid4().hex}"
session.chat_subscriptions[chat_session_id] = request_id
task = asyncio.create_task(
self._forward_chat_subscription(session, chat_session_id, request_id),
name=f"chat_ws_sub_{chat_session_id}",
)
session.chat_subscription_tasks[chat_session_id] = task
return request_id
async def _cleanup_chat_subscriptions(self, session: LiveChatSession) -> None:
tasks = list(session.chat_subscription_tasks.values())
for task in tasks:
task.cancel()
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
for request_id in list(session.chat_subscriptions.values()):
webchat_queue_mgr.remove_back_queue(request_id)
session.chat_subscriptions.clear()
session.chat_subscription_tasks.clear()
async def _handle_chat_message(
self, session: LiveChatSession, message: dict
) -> None:
"""处理 Chat Mode 消息(ct=chat"""
msg_type = message.get("t")
if msg_type == "bind":
chat_session_id = message.get("session_id")
if not isinstance(chat_session_id, str) or not chat_session_id:
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "session_id is required",
"code": "INVALID_MESSAGE_FORMAT",
},
)
return
request_id = await self._ensure_chat_subscription(session, chat_session_id)
await self._send_chat_payload(
session,
{
"ct": "chat",
"type": "session_bound",
"session_id": chat_session_id,
"message_id": request_id,
},
)
return
if msg_type == "interrupt":
session.should_interrupt = True
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "INTERRUPTED",
"code": "INTERRUPTED",
},
)
return
if msg_type != "send":
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": f"Unsupported message type: {msg_type}",
"code": "INVALID_MESSAGE_FORMAT",
},
)
return
if session.is_processing:
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "Session is busy",
"code": "PROCESSING_ERROR",
},
)
return
payload = message.get("message")
session_id = message.get("session_id") or session.session_id
message_id = message.get("message_id") or str(uuid.uuid4())
selected_provider = message.get("selected_provider")
selected_model = message.get("selected_model")
selected_stt_provider = message.get("selected_stt_provider")
selected_tts_provider = message.get("selected_tts_provider")
persona_prompt = message.get("persona_prompt")
show_reasoning = message.get("show_reasoning")
enable_streaming = message.get("enable_streaming", True)
if not isinstance(payload, list):
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "message must be list",
"code": "INVALID_MESSAGE_FORMAT",
},
)
return
message_parts = await self._build_chat_message_parts(payload)
has_content = webchat_message_parts_have_content(message_parts)
if not has_content:
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "Message content is empty",
"code": "INVALID_MESSAGE_FORMAT",
},
)
return
await self._ensure_chat_subscription(session, session_id)
session.is_processing = True
session.should_interrupt = False
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
try:
chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
await chat_queue.put(
(
session.username,
session_id,
{
"message": message_parts,
"selected_provider": selected_provider,
"selected_model": selected_model,
"selected_stt_provider": selected_stt_provider,
"selected_tts_provider": selected_tts_provider,
"persona_prompt": persona_prompt,
"show_reasoning": show_reasoning,
"enable_streaming": enable_streaming,
"message_id": message_id,
},
),
)
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=session_id,
content={"type": "user", "message": message_parts_for_storage},
sender_id=session.username,
sender_name=session.username,
)
accumulated_parts = []
accumulated_text = ""
accumulated_reasoning = ""
tool_calls = {}
agent_stats = {}
refs = {}
while True:
if session.should_interrupt:
session.should_interrupt = False
break
try:
result = await asyncio.wait_for(back_queue.get(), timeout=1)
except asyncio.TimeoutError:
continue
if not result:
continue
if result.get("message_id") and result.get("message_id") != message_id:
continue
result_text = result.get("data", "")
msg_type = result.get("type")
streaming = result.get("streaming", False)
chain_type = result.get("chain_type")
if chain_type == "agent_stats":
try:
parsed_agent_stats = json.loads(result_text)
agent_stats = parsed_agent_stats
await self._send_chat_payload(
session,
{
"ct": "chat",
"type": "agent_stats",
"data": parsed_agent_stats,
},
)
except Exception:
pass
continue
outgoing = {"ct": "chat", **result}
await self._send_chat_payload(session, outgoing)
if msg_type == "plain":
if chain_type == "tool_call":
try:
tool_call = json.loads(result_text)
tool_calls[tool_call.get("id")] = tool_call
if accumulated_text:
accumulated_parts.append(
{"type": "plain", "text": accumulated_text}
)
accumulated_text = ""
except Exception:
pass
elif chain_type == "tool_call_result":
try:
tcr = json.loads(result_text)
tc_id = tcr.get("id")
if tc_id in tool_calls:
tool_calls[tc_id]["result"] = tcr.get("result")
tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
accumulated_parts.append(
{
"type": "tool_call",
"tool_calls": [tool_calls[tc_id]],
}
)
tool_calls.pop(tc_id, None)
except Exception:
pass
elif chain_type == "reasoning":
accumulated_reasoning += result_text
elif streaming:
accumulated_text += result_text
else:
accumulated_text = result_text
elif msg_type == "image":
filename = str(result_text).replace("[IMAGE]", "")
part = await self._create_attachment_from_file(filename, "image")
if part:
accumulated_parts.append(part)
elif msg_type == "record":
filename = str(result_text).replace("[RECORD]", "")
part = await self._create_attachment_from_file(filename, "record")
if part:
accumulated_parts.append(part)
elif msg_type == "file":
filename = str(result_text).replace("[FILE]", "").split("|", 1)[0]
part = await self._create_attachment_from_file(filename, "file")
if part:
accumulated_parts.append(part)
elif msg_type == "video":
filename = str(result_text).replace("[VIDEO]", "").split("|", 1)[0]
part = await self._create_attachment_from_file(filename, "video")
if part:
accumulated_parts.append(part)
should_save = False
if msg_type == "end":
should_save = bool(
accumulated_parts
or accumulated_text
or accumulated_reasoning
or refs
or agent_stats
)
elif (streaming and msg_type == "complete") or not streaming:
if chain_type not in (
"tool_call",
"tool_call_result",
"agent_stats",
):
should_save = True
if should_save:
try:
refs = self._extract_web_search_refs(
accumulated_text,
accumulated_parts,
)
except Exception as e:
logger.exception(
f"[Live Chat] Failed to extract web search refs: {e}",
exc_info=True,
)
saved_record = await self._save_bot_message(
session_id,
accumulated_text,
accumulated_parts,
accumulated_reasoning,
agent_stats,
refs,
)
if saved_record:
await self._send_chat_payload(
session,
{
"ct": "chat",
"type": "message_saved",
"data": {
"id": saved_record.id,
"created_at": saved_record.created_at.astimezone().isoformat(),
},
},
)
accumulated_parts = []
accumulated_text = ""
accumulated_reasoning = ""
agent_stats = {}
refs = {}
if msg_type == "end":
break
except Exception as e:
logger.error(f"[Live Chat] 处理 chat 消息失败: {e}", exc_info=True)
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": f"处理失败: {str(e)}",
"code": "PROCESSING_ERROR",
},
)
finally:
session.is_processing = False
webchat_queue_mgr.remove_back_queue(message_id)
async def _build_chat_message_parts(self, message: list[dict]) -> list[dict]:
"""构建 chat websocket 用户消息段(复用 webchat 逻辑)"""
return await build_webchat_message_parts(
message,
get_attachment_by_id=self.db.get_attachment_by_id,
strict=False,
)
async def _handle_message(self, session: LiveChatSession, message: dict) -> None:
"""处理 WebSocket 消息"""
msg_type = message.get("t") # 使用 t 代替 type
+360 -81
View File
@@ -1,15 +1,22 @@
from pathlib import Path
import asyncio
import hashlib
import json
from uuid import uuid4
from quart import g, request
from quart import g, request, websocket
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.message_session import MessageSesion
from astrbot.core.platform.sources.webchat.message_parts_helper import (
build_message_chain_from_payload,
strip_message_parts_path_fields,
webchat_message_parts_have_content,
)
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from .api_key import ALL_OPEN_API_SCOPES
from .chat import ChatRoute
from .route import Response, Route, RouteContext
@@ -37,6 +44,7 @@ class OpenApiRoute(Route):
"/v1/im/bots": ("GET", self.get_bots),
}
self.register_routes()
self.app.websocket("/api/v1/chat/ws")(self.chat_ws)
@staticmethod
def _resolve_open_username(
@@ -181,6 +189,348 @@ class OpenApiRoute(Route):
finally:
g.username = original_username
@staticmethod
def _extract_ws_api_key() -> str | None:
if key := websocket.args.get("api_key"):
return key.strip()
if key := websocket.args.get("key"):
return key.strip()
if key := websocket.headers.get("X-API-Key"):
return key.strip()
auth_header = websocket.headers.get("Authorization", "").strip()
if auth_header.startswith("Bearer "):
return auth_header.removeprefix("Bearer ").strip()
if auth_header.startswith("ApiKey "):
return auth_header.removeprefix("ApiKey ").strip()
return None
async def _authenticate_chat_ws_api_key(self) -> tuple[bool, str | None]:
raw_key = self._extract_ws_api_key()
if not raw_key:
return False, "Missing API key"
key_hash = hashlib.pbkdf2_hmac(
"sha256",
raw_key.encode("utf-8"),
b"astrbot_api_key",
100_000,
).hex()
api_key = await self.db.get_active_api_key_by_hash(key_hash)
if not api_key:
return False, "Invalid API key"
if isinstance(api_key.scopes, list):
scopes = api_key.scopes
else:
scopes = list(ALL_OPEN_API_SCOPES)
if "*" not in scopes and "chat" not in scopes:
return False, "Insufficient API key scope"
await self.db.touch_api_key(api_key.key_id)
return True, None
async def _send_chat_ws_error(self, message: str, code: str) -> None:
await websocket.send_json(
{
"type": "error",
"code": code,
"data": message,
}
)
async def _update_session_config_route(
self,
*,
username: str,
session_id: str,
config_id: str | None,
) -> str | None:
if not config_id:
return None
umo = f"webchat:FriendMessage:webchat!{username}!{session_id}"
try:
if config_id == "default":
await self.core_lifecycle.umop_config_router.delete_route(umo)
else:
await self.core_lifecycle.umop_config_router.update_route(
umo, config_id
)
except Exception as e:
logger.error(
"Failed to update chat config route for %s with %s: %s",
umo,
config_id,
e,
exc_info=True,
)
return f"Failed to update chat config route: {e}"
return None
async def _handle_chat_ws_send(self, post_data: dict) -> None:
effective_username, username_err = self._resolve_open_username(
post_data.get("username")
)
if username_err or not effective_username:
await self._send_chat_ws_error(
username_err or "Invalid username", "BAD_USER"
)
return
message = post_data.get("message")
if message is None:
await self._send_chat_ws_error("Missing key: message", "INVALID_MESSAGE")
return
raw_session_id = post_data.get("session_id", post_data.get("conversation_id"))
session_id = str(raw_session_id).strip() if raw_session_id is not None else ""
if not session_id:
session_id = str(uuid4())
ensure_session_err = await self._ensure_chat_session(
effective_username,
session_id,
)
if ensure_session_err:
await self._send_chat_ws_error(ensure_session_err, "SESSION_ERROR")
return
config_id, resolve_err = self._resolve_chat_config_id(post_data)
if resolve_err:
await self._send_chat_ws_error(resolve_err, "CONFIG_ERROR")
return
config_err = await self._update_session_config_route(
username=effective_username,
session_id=session_id,
config_id=config_id,
)
if config_err:
await self._send_chat_ws_error(config_err, "CONFIG_ERROR")
return
message_parts = await self.chat_route._build_user_message_parts(message)
if not webchat_message_parts_have_content(message_parts):
await self._send_chat_ws_error(
"Message content is empty (reply only is not allowed)",
"INVALID_MESSAGE",
)
return
message_id = str(post_data.get("message_id") or uuid4())
selected_provider = post_data.get("selected_provider")
selected_model = post_data.get("selected_model")
enable_streaming = post_data.get("enable_streaming", True)
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
try:
chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
await chat_queue.put(
(
effective_username,
session_id,
{
"message": message_parts,
"selected_provider": selected_provider,
"selected_model": selected_model,
"enable_streaming": enable_streaming,
"message_id": message_id,
},
)
)
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
await self.chat_route.platform_history_mgr.insert(
platform_id="webchat",
user_id=session_id,
content={"type": "user", "message": message_parts_for_storage},
sender_id=effective_username,
sender_name=effective_username,
)
await websocket.send_json(
{
"type": "session_id",
"data": None,
"session_id": session_id,
"message_id": message_id,
}
)
accumulated_parts = []
accumulated_text = ""
accumulated_reasoning = ""
tool_calls = {}
agent_stats = {}
refs = {}
while True:
try:
result = await asyncio.wait_for(back_queue.get(), timeout=1)
except asyncio.TimeoutError:
continue
if not result:
continue
if "message_id" in result and result["message_id"] != message_id:
logger.warning("openapi ws stream message_id mismatch")
continue
result_text = result.get("data", "")
msg_type = result.get("type")
streaming = result.get("streaming", False)
chain_type = result.get("chain_type")
if chain_type == "agent_stats":
try:
stats_info = {
"type": "agent_stats",
"data": json.loads(result_text),
}
await websocket.send_json(stats_info)
agent_stats = stats_info["data"]
except Exception:
pass
continue
await websocket.send_json(result)
if msg_type == "plain":
if chain_type == "tool_call":
tool_call = json.loads(result_text)
tool_calls[tool_call.get("id")] = tool_call
if accumulated_text:
accumulated_parts.append(
{"type": "plain", "text": accumulated_text}
)
accumulated_text = ""
elif chain_type == "tool_call_result":
tcr = json.loads(result_text)
tc_id = tcr.get("id")
if tc_id in tool_calls:
tool_calls[tc_id]["result"] = tcr.get("result")
tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
accumulated_parts.append(
{"type": "tool_call", "tool_calls": [tool_calls[tc_id]]}
)
tool_calls.pop(tc_id, None)
elif chain_type == "reasoning":
accumulated_reasoning += result_text
elif streaming:
accumulated_text += result_text
else:
accumulated_text = result_text
elif msg_type == "image":
filename = str(result_text).replace("[IMAGE]", "")
part = await self.chat_route._create_attachment_from_file(
filename, "image"
)
if part:
accumulated_parts.append(part)
elif msg_type == "record":
filename = str(result_text).replace("[RECORD]", "")
part = await self.chat_route._create_attachment_from_file(
filename, "record"
)
if part:
accumulated_parts.append(part)
elif msg_type == "file":
filename = str(result_text).replace("[FILE]", "")
part = await self.chat_route._create_attachment_from_file(
filename, "file"
)
if part:
accumulated_parts.append(part)
elif msg_type == "video":
filename = str(result_text).replace("[VIDEO]", "")
part = await self.chat_route._create_attachment_from_file(
filename, "video"
)
if part:
accumulated_parts.append(part)
if msg_type == "end":
break
if (streaming and msg_type == "complete") or not streaming:
if chain_type in ("tool_call", "tool_call_result"):
continue
try:
refs = self.chat_route._extract_web_search_refs(
accumulated_text,
accumulated_parts,
)
except Exception as e:
logger.exception(
f"Open API WS failed to extract web search refs: {e}",
exc_info=True,
)
saved_record = await self.chat_route._save_bot_message(
session_id,
accumulated_text,
accumulated_parts,
accumulated_reasoning,
agent_stats,
refs,
)
if saved_record:
await websocket.send_json(
{
"type": "message_saved",
"data": {
"id": saved_record.id,
"created_at": saved_record.created_at.astimezone().isoformat(),
},
"session_id": session_id,
}
)
accumulated_parts = []
accumulated_text = ""
accumulated_reasoning = ""
agent_stats = {}
refs = {}
except Exception as e:
logger.exception(f"Open API WS chat failed: {e}", exc_info=True)
await self._send_chat_ws_error(
f"Failed to process message: {e}", "PROCESSING_ERROR"
)
finally:
webchat_queue_mgr.remove_back_queue(message_id)
async def chat_ws(self) -> None:
authed, auth_err = await self._authenticate_chat_ws_api_key()
if not authed:
await self._send_chat_ws_error(auth_err or "Unauthorized", "UNAUTHORIZED")
await websocket.close(1008, auth_err or "Unauthorized")
return
try:
while True:
message = await websocket.receive_json()
if not isinstance(message, dict):
await self._send_chat_ws_error(
"message must be an object",
"INVALID_MESSAGE",
)
continue
msg_type = message.get("t", "send")
if msg_type == "ping":
await websocket.send_json({"type": "pong"})
continue
if msg_type != "send":
await self._send_chat_ws_error(
f"Unsupported message type: {msg_type}",
"INVALID_MESSAGE",
)
continue
await self._handle_chat_ws_send(message)
except Exception as e:
logger.debug("Open API WS connection closed: %s", e)
async def upload_file(self):
return await self.chat_route.post_file()
@@ -254,83 +604,12 @@ class OpenApiRoute(Route):
async def _build_message_chain_from_payload(
self,
message_payload: str | list,
) -> MessageChain:
if isinstance(message_payload, str):
text = message_payload.strip()
if not text:
raise ValueError("Message is empty")
return MessageChain(chain=[Plain(text=text)])
if not isinstance(message_payload, list):
raise ValueError("message must be a string or list")
components = []
has_content = False
for part in message_payload:
if not isinstance(part, dict):
raise ValueError("message part must be an object")
part_type = str(part.get("type", "")).strip()
if part_type == "plain":
text = str(part.get("text", ""))
if text:
has_content = True
components.append(Plain(text=text))
continue
if part_type == "reply":
message_id = part.get("message_id")
if message_id is None:
raise ValueError("reply part missing message_id")
components.append(
Reply(
id=str(message_id),
message_str=str(part.get("selected_text", "")),
chain=[],
)
)
continue
if part_type not in {"image", "record", "file", "video"}:
raise ValueError(f"unsupported message part type: {part_type}")
has_content = True
file_path: Path | None = None
resolved_type = part_type
filename = str(part.get("filename", "")).strip()
attachment_id = part.get("attachment_id")
if attachment_id:
attachment = await self.db.get_attachment_by_id(str(attachment_id))
if not attachment:
raise ValueError(f"attachment not found: {attachment_id}")
file_path = Path(attachment.path)
resolved_type = attachment.type
if not filename:
filename = file_path.name
else:
raise ValueError(f"{part_type} part missing attachment_id")
if not file_path.exists():
raise ValueError(f"file not found: {file_path!s}")
file_path_str = str(file_path.resolve())
if resolved_type == "image":
components.append(Image.fromFileSystem(file_path_str))
elif resolved_type == "record":
components.append(Record.fromFileSystem(file_path_str))
elif resolved_type == "video":
components.append(Video.fromFileSystem(file_path_str))
else:
components.append(
File(name=filename or file_path.name, file=file_path_str)
)
if not components or not has_content:
raise ValueError("Message content is empty (reply only is not allowed)")
return MessageChain(chain=components)
):
return await build_message_chain_from_payload(
message_payload,
get_attachment_by_id=self.db.get_attachment_by_id,
strict=True,
)
async def send_message(self):
post_data = await request.json or {}
+5
View File
@@ -204,6 +204,10 @@ class AstrBotDashboard:
@staticmethod
def _extract_raw_api_key() -> str | None:
if key := request.args.get("api_key"):
return key.strip()
if key := request.args.get("key"):
return key.strip()
if key := request.headers.get("X-API-Key"):
return key.strip()
auth_header = request.headers.get("Authorization", "").strip()
@@ -217,6 +221,7 @@ class AstrBotDashboard:
def _get_required_open_api_scope(path: str) -> str | None:
scope_map = {
"/api/v1/chat": "chat",
"/api/v1/chat/ws": "chat",
"/api/v1/chat/sessions": "chat",
"/api/v1/configs": "config",
"/api/v1/file": "file",
-60
View File
@@ -1,60 +0,0 @@
## What's Changed
### 新增
- 新增 Agent 会话停止能力,并优化 stop 请求处理流程,支持 /stop 指令终止 Agent 运行并尽量不丢失已运行输出的结果。 ([#5380](https://github.com/AstrBotDevs/AstrBot/issues/5380))。
- 新增 SubAgent 交接场景下的 computer-use 工具支持 ([#5399](https://github.com/AstrBotDevs/AstrBot/issues/5399))。
- 新增 Agent 执行过程中展示工具调用结果的能力,提升执行过程可观测性 ([#5388](https://github.com/AstrBotDevs/AstrBot/issues/5388))。
- 新增插件加载/卸载 Hook,扩展插件生命周期能力 ([#5331](https://github.com/AstrBotDevs/AstrBot/issues/5331))。
- 新增插件加载失败后的热重载能力,提升插件开发与恢复效率 ([#5334](https://github.com/AstrBotDevs/AstrBot/issues/5334))。
- 新增 SubAgent 图片 URL/本地路径输入支持 ([#5348](https://github.com/AstrBotDevs/AstrBot/issues/5348))。
- 新增 Dashboard 发布跳转基础 URL 可配置项 ([#5330](https://github.com/AstrBotDevs/AstrBot/issues/5330))。
### 修复
- 修复 Tavily 请求的硬编码 6 秒超时。
- 修复 OneBot v11 适配器关闭之后仍然在连接的问题([#5412](https://github.com/AstrBotDevs/AstrBot/issues/5412))。
- 修复上下文会话中平台缺失时的日志处理,补充 warning 并改进排查信息。
- 修复 embedding 维度未透传到 provider API 的问题 ([#5411](https://github.com/AstrBotDevs/AstrBot/issues/5411))。
- 修复 File 组件处理逻辑并增强 OneBot 驱动层路径兼容性 ([#5391](https://github.com/AstrBotDevs/AstrBot/issues/5391))。
- 修复 sandbox 文件传输工具缺少管理员权限校验的问题 ([#5402](https://github.com/AstrBotDevs/AstrBot/issues/5402))。
- 修复 pipeline 与 `from ... import *` 引发的循环依赖问题 ([#5353](https://github.com/AstrBotDevs/AstrBot/issues/5353))。
- 修复配置文件存在 UTF-8 BOM 时的解析问题 ([#5376](https://github.com/AstrBotDevs/AstrBot/issues/5376))。
- 修复 ChatUI 复制回滚路径缺失与错误提示不清晰的问题 ([#5352](https://github.com/AstrBotDevs/AstrBot/issues/5352))。
- 修复保留插件目录处理逻辑,避免插件目录行为异常 ([#5369](https://github.com/AstrBotDevs/AstrBot/issues/5369))。
- 修复 ChatUI 文件消息段无法持久化的问题 ([#5386](https://github.com/AstrBotDevs/AstrBot/issues/5386))。
- 修复 `.dockerignore` 误排除 `changelogs` 目录的问题。
- 修复 aiohttp 版本过新导致 qq-botpy 报错的问题 ([#5316](https://github.com/AstrBotDevs/AstrBot/issues/5316))。
### 优化
- 完成 SubAgent 编排页面国际化,补齐多语言支持 ([#5400](https://github.com/AstrBotDevs/AstrBot/issues/5400))。
- 增补消息事件处理相关测试,并完善测试框架的 fixtures/mocks 覆盖 ([#5355](https://github.com/AstrBotDevs/AstrBot/issues/5355), [#5354](https://github.com/AstrBotDevs/AstrBot/issues/5354))。
## What's Changed (EN)
### New Features
- Added computer-use tools support in sub-agent handoff scenarios ([#5399](https://github.com/AstrBotDevs/AstrBot/issues/5399)).
- Added support for displaying tool call results during agent execution for better observability ([#5388](https://github.com/AstrBotDevs/AstrBot/issues/5388)).
- Added plugin load/unload hooks to extend plugin lifecycle capabilities ([#5331](https://github.com/AstrBotDevs/AstrBot/issues/5331)).
- Added hot reload support when plugin loading fails, improving recovery during plugin development ([#5334](https://github.com/AstrBotDevs/AstrBot/issues/5334)).
- Added image URL/local path input support for sub-agents ([#5348](https://github.com/AstrBotDevs/AstrBot/issues/5348)).
- Added stop control for active agent sessions and improved stop request handling ([#5380](https://github.com/AstrBotDevs/AstrBot/issues/5380)).
- Added configurable base URL for dashboard release redirects ([#5330](https://github.com/AstrBotDevs/AstrBot/issues/5330)).
### Fixes
- Fixed logging behavior when platform information is missing in context sessions, with clearer warning and diagnostics.
- Fixed missing embedding dimensions being passed to provider APIs ([#5411](https://github.com/AstrBotDevs/AstrBot/issues/5411)).
- Fixed shutdown stability issues in the aiocqhttp adapter ([#5412](https://github.com/AstrBotDevs/AstrBot/issues/5412)).
- Fixed File component handling and improved path compatibility in the OneBot driver layer ([#5391](https://github.com/AstrBotDevs/AstrBot/issues/5391)).
- Fixed missing admin guard for sandbox file transfer tools ([#5402](https://github.com/AstrBotDevs/AstrBot/issues/5402)).
- Fixed circular import issues related to pipeline and `from ... import *` usage ([#5353](https://github.com/AstrBotDevs/AstrBot/issues/5353)).
- Fixed config parsing issues when files contain UTF-8 BOM ([#5376](https://github.com/AstrBotDevs/AstrBot/issues/5376)).
- Fixed missing copy rollback path and unclear error messaging in ChatUI ([#5352](https://github.com/AstrBotDevs/AstrBot/issues/5352)).
- Fixed reserved plugin directory handling to avoid abnormal plugin path behavior ([#5369](https://github.com/AstrBotDevs/AstrBot/issues/5369)).
- Fixed ChatUI file segment persistence issues ([#5386](https://github.com/AstrBotDevs/AstrBot/issues/5386)).
- Fixed accidental exclusion of the `changelogs` directory in `.dockerignore`.
- Fixed compatibility issues caused by a hard-coded 6-second timeout in Tavily requests.
- Fixed qq-botpy runtime errors caused by overly new aiohttp versions ([#5316](https://github.com/AstrBotDevs/AstrBot/issues/5316)).
### Improvements
- Completed internationalization for the sub-agent orchestration page ([#5400](https://github.com/AstrBotDevs/AstrBot/issues/5400)).
- Added broader message-event test coverage and improved fixtures/mocks in the test framework ([#5355](https://github.com/AstrBotDevs/AstrBot/issues/5355), [#5354](https://github.com/AstrBotDevs/AstrBot/issues/5354)).
- Updated README content and applied repository-wide formatting cleanup (ruff format) ([#5375](https://github.com/AstrBotDevs/AstrBot/issues/5375)).
+7 -1
View File
@@ -10,6 +10,7 @@
:selectedSessions="selectedSessions"
:currSessionId="currSessionId"
:selectedProjectId="selectedProjectId"
:transportMode="transportMode"
:isDark="isDark"
:chatboxMode="chatboxMode"
:isMobile="isMobile"
@@ -26,6 +27,7 @@
@createProject="showCreateProjectDialog"
@editProject="showEditProjectDialog"
@deleteProject="handleDeleteProject"
@updateTransportMode="setTransportMode"
/>
<!-- 右侧聊天内容区域 -->
@@ -301,11 +303,14 @@ const {
isStreaming,
isConvRunning,
enableStreaming,
transportMode,
currentSessionProject,
getSessionMessages: getSessionMsg,
sendMessage: sendMsg,
stopMessage: stopMsg,
toggleStreaming
toggleStreaming,
setTransportMode,
cleanupTransport
} = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions);
//
@@ -695,6 +700,7 @@ onMounted(() => {
onBeforeUnmount(() => {
window.removeEventListener('resize', checkMobile);
cleanupMediaCache();
cleanupTransport();
});
</script>
@@ -117,6 +117,27 @@
<v-list-item-title>{{ isDark ? tm('modes.lightMode') : tm('modes.darkMode') }}</v-list-item-title>
</v-list-item>
<!-- 通信传输模式 -->
<v-list-item class="styled-menu-item">
<template v-slot:prepend>
<v-icon>mdi-lan-connect</v-icon>
</template>
<v-list-item-title>{{ tm('transport.title') }}</v-list-item-title>
<template v-slot:append>
<v-select
:model-value="transportMode"
:items="transportOptions"
item-title="label"
item-value="value"
density="compact"
variant="underlined"
hide-details
class="transport-mode-select"
@update:model-value="handleTransportModeChange"
/>
</template>
</v-list-item>
<!-- 全屏/退出全屏 -->
<v-list-item class="styled-menu-item" @click="$emit('toggleFullscreen')">
<template v-slot:prepend>
@@ -156,6 +177,7 @@ interface Props {
selectedSessions: string[];
currSessionId: string;
selectedProjectId?: string | null;
transportMode: 'sse' | 'websocket';
isDark: boolean;
chatboxMode: boolean;
isMobile: boolean;
@@ -179,6 +201,7 @@ const emit = defineEmits<{
createProject: [];
editProject: [project: Project];
deleteProject: [projectId: string];
updateTransportMode: [mode: 'sse' | 'websocket'];
}>();
const { t } = useI18n();
@@ -188,6 +211,10 @@ const confirmDialog = useConfirmDialog();
const sidebarCollapsed = ref(true);
const showProviderConfigDialog = ref(false);
const transportOptions = [
{ label: tm('transport.sse'), value: 'sse' as const },
{ label: tm('transport.websocket'), value: 'websocket' as const }
];
// localStorage
const savedCollapsedState = localStorage.getItem('sidebarCollapsed');
@@ -209,6 +236,12 @@ async function handleDeleteConversation(session: Session) {
emit('deleteConversation', session.session_id);
}
}
function handleTransportModeChange(mode: string | null) {
if (mode === 'sse' || mode === 'websocket') {
emit('updateTransportMode', mode);
}
}
</script>
<style scoped>
@@ -361,4 +394,8 @@ async function handleDeleteConversation(session: Session) {
display: flex;
justify-content: center;
}
.transport-mode-select {
min-width: 120px;
}
</style>
File diff suppressed because it is too large Load Diff
@@ -81,9 +81,16 @@
"disabled": "Streaming disabled",
"on": "Stream",
"off": "Normal"
}, "config": {
},
"transport": {
"title": "Transport Mode",
"sse": "SSE",
"websocket": "WebSocket"
},
"config": {
"title": "Config"
}, "reasoning": {
},
"reasoning": {
"thinking": "Thinking Process"
},
"reply": {
@@ -251,10 +251,6 @@
"show_tool_use_status": {
"description": "Output Function Call Status"
},
"show_tool_call_result": {
"description": "Output Tool Call Results",
"hint": "Only takes effect when \"Output Function Call Status\" is enabled, and shows at most 70 characters."
},
"sanitize_context_by_modalities": {
"description": "Sanitize History by Modalities",
"hint": "When enabled, sanitizes contexts before each LLM request by removing image blocks and tool-call structures that the current provider's modalities do not support (this changes what the model sees)."
@@ -8,14 +8,11 @@
"refresh": "Refresh",
"save": "Save",
"add": "Add SubAgent",
"delete": "Delete",
"close": "Close"
"delete": "Delete"
},
"switches": {
"enable": "Enable SubAgent orchestration",
"enableHint": "Enable sub-agent functionality",
"dedupe": "Deduplicate main LLM tools (hide tools duplicated by SubAgents)",
"dedupeHint": "Remove duplicate tools from main agent"
"dedupe": "Deduplicate main LLM tools (hide tools duplicated by SubAgents)"
},
"description": {
"disabled": "When off: SubAgent is disabled; the main LLM mounts tools via persona rules (all by default) and calls them directly.",
@@ -32,8 +29,7 @@
"transferPrefix": "transfer_to_{name}",
"switchLabel": "Enable",
"previewTitle": "Preview: handoff tool shown to the main LLM",
"personaChip": "Persona: {id}",
"personaPreview": "PERSONA PREVIEW"
"personaChip": "Persona: {id}"
},
"form": {
"nameLabel": "Agent name (used for transfer_to_{name})",
@@ -53,13 +49,6 @@
"nameDuplicate": "Duplicate SubAgent name: {name}",
"personaMissing": "SubAgent {name} has no persona selected",
"saveSuccess": "Saved successfully",
"saveFailed": "Failed to save",
"nameRequired": "Name is required",
"namePattern": "Lowercase letters, numbers, underscore only"
},
"empty": {
"title": "No Agents Configured",
"subtitle": "Add a new sub-agent to get started",
"action": "Create First Agent"
"saveFailed": "Failed to save"
}
}
@@ -82,6 +82,11 @@
"on": "流式",
"off": "普通"
},
"transport": {
"title": "通信传输模式",
"sse": "SSE",
"websocket": "WebSocket"
},
"config": {
"title": "配置文件"
},
@@ -254,10 +254,6 @@
"show_tool_use_status": {
"description": "输出函数调用状态"
},
"show_tool_call_result": {
"description": "输出函数调用返回结果",
"hint": "仅在启用“输出函数调用状态”时生效,且最多展示 70 个字符。"
},
"sanitize_context_by_modalities": {
"description": "按模型能力清理历史上下文",
"hint": "开启后,在每次请求 LLM 前会按当前模型提供商中所选择的模型能力删除对话中不支持的图片/工具调用结构(会改变模型看到的历史)"
@@ -8,14 +8,11 @@
"refresh": "刷新",
"save": "保存",
"add": "新增 SubAgent",
"delete": "删除",
"close": "关闭"
"delete": "删除"
},
"switches": {
"enable": "启用 SubAgent 编排",
"enableHint": "启用子代理功能",
"dedupe": "主 LLM 去重重复工具(与 SubAgent 重叠的工具将被隐藏)",
"dedupeHint": "从主代理中移除重复工具"
"dedupe": "主 LLM 去重重复工具(与 SubAgent 重叠的工具将被隐藏)"
},
"description": {
"disabled": "不启动:SubAgent 关闭;主 LLM 按 persona 规则挂载工具(默认全部),并直接调用。",
@@ -42,7 +39,6 @@
"providerHint": "留空表示跟随全局默认 provider。",
"personaLabel": "选择人格设定",
"personaHint": "SubAgent 将直接继承所选 Persona 的系统设定与工具。在人格设定页管理和新建人格。",
"personaPreview": "人格预览",
"descriptionLabel": "对主 LLM 的描述(用于决定是否 handoff",
"descriptionHint": "这段会作为 transfer_to_* 工具的描述给主 LLM 看,建议简短明确。"
},
@@ -54,13 +50,6 @@
"nameDuplicate": "SubAgent 名称重复:{name}",
"personaMissing": "SubAgent {name} 未选择 Persona",
"saveSuccess": "保存成功",
"saveFailed": "保存失败",
"nameRequired": "名称必填",
"namePattern": "仅支持小写字母、数字和下划线"
},
"empty": {
"title": "未配置 SubAgent",
"subtitle": "添加一个新的子代理以开始",
"action": "创建第一个 Agent"
"saveFailed": "保存失败"
}
}
+8 -8
View File
@@ -62,7 +62,7 @@
<template #label>
<div class="d-flex flex-column">
<span class="text-body-2 font-weight-medium">{{ tm('switches.enable') }}</span>
<span class="text-caption text-medium-emphasis">{{ tm('switches.enableHint') }}</span>
<span class="text-caption text-medium-emphasis">Enable sub-agent functionality</span>
</div>
</template>
</v-switch>
@@ -80,7 +80,7 @@
<template #label>
<div class="d-flex flex-column">
<span class="text-body-2 font-weight-medium">{{ tm('switches.dedupe') }}</span>
<span class="text-caption text-medium-emphasis">{{ tm('switches.dedupeHint') }}</span>
<span class="text-caption text-medium-emphasis">Remove duplicate tools from main agent</span>
</div>
</template>
</v-switch>
@@ -166,7 +166,7 @@
<v-text-field
v-model="agent.name"
:label="tm('form.nameLabel')"
:rules="[v => !!v || tm('messages.nameRequired'), v => /^[a-z][a-z0-9_]*$/.test(v) || tm('messages.namePattern')]"
:rules="[v => !!v || 'Name is required', v => /^[a-z][a-z0-9_]*$/.test(v) || 'Lowercase letters, numbers, underscore only']"
variant="outlined"
density="comfortable"
hide-details="auto"
@@ -215,7 +215,7 @@
<v-col cols="12" md="6">
<div class="h-100">
<div class="text-caption font-weight-bold text-medium-emphasis mb-2 ml-1">
{{ tm('cards.personaPreview') }}
PERSONA PREVIEW
</div>
<PersonaQuickPreview
:model-value="agent.persona_id"
@@ -231,17 +231,17 @@
<!-- Empty State -->
<div v-if="cfg.agents.length === 0" class="d-flex flex-column align-center justify-center py-12 text-medium-emphasis">
<v-icon icon="mdi-robot-off" size="64" class="mb-4 opacity-50" />
<div class="text-h6">{{ tm('empty.title') }}</div>
<div class="text-body-2 mb-4">{{ tm('empty.subtitle') }}</div>
<div class="text-h6">No Agents Configured</div>
<div class="text-body-2 mb-4">Add a new sub-agent to get started</div>
<v-btn color="primary" variant="tonal" @click="addAgent">
{{ tm('empty.action') }}
Create First Agent
</v-btn>
</div>
<v-snackbar v-model="snackbar.show" :color="snackbar.color" timeout="3000" location="top">
{{ snackbar.message }}
<template #actions>
<v-btn variant="text" @click="snackbar.show = false">{{ tm('actions.close') }}</v-btn>
<v-btn variant="text" @click="snackbar.show = false">Close</v-btn>
</template>
</v-snackbar>
</div>
+1
View File
@@ -43,6 +43,7 @@ export default defineConfig({
'/api': {
target: 'http://127.0.0.1:6185/',
changeOrigin: true,
ws: true
}
}
}
-381
View File
@@ -1,381 +0,0 @@
"""
AstrBot 测试配置
提供共享的 pytest fixtures 和测试工具
"""
import json
import os
import sys
from asyncio import Queue
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
import pytest_asyncio
# 使用 tests/fixtures/helpers.py 中的共享工具函数,避免重复定义
from tests.fixtures.helpers import create_mock_llm_response, create_mock_message_component
# 将项目根目录添加到 sys.path
PROJECT_ROOT = Path(__file__).parent.parent
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
# 设置测试环境变量
os.environ.setdefault("TESTING", "true")
os.environ.setdefault("ASTRBOT_TEST_MODE", "true")
# ============================================================
# 测试收集和排序
# ============================================================
def pytest_collection_modifyitems(session, config, items): # noqa: ARG001
"""重新排序测试:单元测试优先,集成测试在后。"""
unit_tests = []
integration_tests = []
deselected = []
profile = config.getoption("--test-profile") or os.environ.get(
"ASTRBOT_TEST_PROFILE", "all"
)
for item in items:
item_path = Path(str(item.path))
is_integration = "integration" in item_path.parts
if is_integration:
if item.get_closest_marker("integration") is None:
item.add_marker(pytest.mark.integration)
item.add_marker(pytest.mark.tier_d)
integration_tests.append(item)
else:
if item.get_closest_marker("unit") is None:
item.add_marker(pytest.mark.unit)
if any(
item.get_closest_marker(marker) is not None
for marker in ("platform", "provider", "slow")
):
item.add_marker(pytest.mark.tier_c)
unit_tests.append(item)
# 单元测试 -> 集成测试
ordered_items = unit_tests + integration_tests
if profile == "blocking":
selected_items = []
for item in ordered_items:
if item.get_closest_marker("tier_c") or item.get_closest_marker("tier_d"):
deselected.append(item)
else:
selected_items.append(item)
if deselected:
config.hook.pytest_deselected(items=deselected)
items[:] = selected_items
return
items[:] = ordered_items
def pytest_addoption(parser):
"""增加测试执行档位选择。"""
parser.addoption(
"--test-profile",
action="store",
default=None,
choices=["all", "blocking"],
help="Select test profile. 'blocking' excludes auto-classified tier_c/tier_d tests.",
)
def pytest_configure(config):
"""注册自定义标记。"""
config.addinivalue_line("markers", "unit: 单元测试")
config.addinivalue_line("markers", "integration: 集成测试")
config.addinivalue_line("markers", "slow: 慢速测试")
config.addinivalue_line("markers", "platform: 平台适配器测试")
config.addinivalue_line("markers", "provider: LLM Provider 测试")
config.addinivalue_line("markers", "db: 数据库相关测试")
config.addinivalue_line("markers", "tier_c: C-tier tests (optional / non-blocking)")
config.addinivalue_line("markers", "tier_d: D-tier tests (extended / integration)")
# ============================================================
# 临时目录和文件 Fixtures
# ============================================================
@pytest.fixture
def temp_dir(tmp_path: Path) -> Path:
"""创建临时目录用于测试。"""
return tmp_path
@pytest.fixture
def event_queue() -> Queue:
"""Create a shared asyncio queue fixture for tests."""
return Queue()
@pytest.fixture
def platform_settings() -> dict:
"""Create a shared empty platform settings fixture for adapter tests."""
return {}
@pytest.fixture
def temp_data_dir(temp_dir: Path) -> Path:
"""创建模拟的 data 目录结构。"""
data_dir = temp_dir / "data"
data_dir.mkdir()
# 创建必要的子目录
(data_dir / "config").mkdir()
(data_dir / "plugins").mkdir()
(data_dir / "temp").mkdir()
(data_dir / "attachments").mkdir()
return data_dir
@pytest.fixture
def temp_config_file(temp_data_dir: Path) -> Path:
"""创建临时配置文件。"""
config_path = temp_data_dir / "config" / "cmd_config.json"
default_config = {
"provider": [],
"platform": [],
"provider_settings": {},
"default_personality": None,
"timezone": "Asia/Shanghai",
}
config_path.write_text(json.dumps(default_config, indent=2), encoding="utf-8")
return config_path
@pytest.fixture
def temp_db_file(temp_data_dir: Path) -> Path:
"""创建临时数据库文件路径。"""
return temp_data_dir / "test.db"
# ============================================================
# Mock Fixtures
# ============================================================
@pytest.fixture
def mock_provider():
"""创建模拟的 Provider。"""
provider = MagicMock()
provider.provider_config = {
"id": "test-provider",
"type": "openai_chat_completion",
"model": "gpt-4o-mini",
}
provider.get_model = MagicMock(return_value="gpt-4o-mini")
provider.text_chat = AsyncMock()
provider.text_chat_stream = AsyncMock()
provider.terminate = AsyncMock()
return provider
@pytest.fixture
def mock_platform():
"""创建模拟的 Platform。"""
platform = MagicMock()
platform.platform_name = "test_platform"
platform.platform_meta = MagicMock()
platform.platform_meta.support_proactive_message = False
platform.send_message = AsyncMock()
platform.terminate = AsyncMock()
return platform
@pytest.fixture
def mock_conversation():
"""创建模拟的 Conversation。"""
from astrbot.core.db.po import ConversationV2
return ConversationV2(
conversation_id="test-conv-id",
platform_id="test_platform",
user_id="test_user",
content=[],
persona_id=None,
)
@pytest.fixture
def mock_event():
"""创建模拟的 AstrMessageEvent。"""
event = MagicMock()
event.unified_msg_origin = "test_umo"
event.session_id = "test_session"
event.message_str = "Hello, world!"
event.message_obj = MagicMock()
event.message_obj.message = []
event.message_obj.sender = MagicMock()
event.message_obj.sender.user_id = "test_user"
event.message_obj.sender.nickname = "Test User"
event.message_obj.group_id = None
event.message_obj.group = None
event.get_platform_name = MagicMock(return_value="test_platform")
event.get_platform_id = MagicMock(return_value="test_platform")
event.get_group_id = MagicMock(return_value=None)
event.get_extra = MagicMock(return_value=None)
event.set_extra = MagicMock()
event.trace = MagicMock()
event.platform_meta = MagicMock()
event.platform_meta.support_proactive_message = False
return event
# ============================================================
# 配置 Fixtures
# ============================================================
@pytest.fixture
def astrbot_config(temp_config_file: Path):
"""创建 AstrBotConfig 实例。"""
from astrbot.core.config.astrbot_config import AstrBotConfig
config = AstrBotConfig()
config._config_path = str(temp_config_file) # noqa: SLF001
return config
@pytest.fixture
def main_agent_build_config():
"""创建 MainAgentBuildConfig 实例。"""
from astrbot.core.astr_main_agent import MainAgentBuildConfig
return MainAgentBuildConfig(
tool_call_timeout=60,
tool_schema_mode="full",
provider_wake_prefix="",
streaming_response=True,
sanitize_context_by_modalities=False,
kb_agentic_mode=False,
file_extract_enabled=False,
context_limit_reached_strategy="truncate_by_turns",
llm_safety_mode=True,
computer_use_runtime="local",
add_cron_tools=True,
)
# ============================================================
# 数据库 Fixtures
# ============================================================
@pytest_asyncio.fixture
async def temp_db(temp_db_file: Path):
"""创建临时数据库实例。"""
from astrbot.core.db.sqlite import SQLiteDatabase
db = SQLiteDatabase(str(temp_db_file))
try:
yield db
finally:
await db.engine.dispose()
if temp_db_file.exists():
temp_db_file.unlink()
# ============================================================
# Context Fixtures
# ============================================================
@pytest_asyncio.fixture
async def mock_context(
astrbot_config,
temp_db,
mock_provider,
mock_platform,
):
"""创建模拟的插件上下文。"""
from asyncio import Queue
from astrbot.core.star.context import Context
event_queue = Queue()
provider_manager = MagicMock()
provider_manager.get_using_provider = MagicMock(return_value=mock_provider)
provider_manager.get_provider_by_id = MagicMock(return_value=mock_provider)
platform_manager = MagicMock()
conversation_manager = MagicMock()
message_history_manager = MagicMock()
persona_manager = MagicMock()
persona_manager.personas_v3 = []
astrbot_config_mgr = MagicMock()
knowledge_base_manager = MagicMock()
cron_manager = MagicMock()
subagent_orchestrator = None
context = Context(
event_queue,
astrbot_config,
temp_db,
provider_manager,
platform_manager,
conversation_manager,
message_history_manager,
persona_manager,
astrbot_config_mgr,
knowledge_base_manager,
cron_manager,
subagent_orchestrator,
)
return context
# ============================================================
# Provider Request Fixtures
# ============================================================
@pytest.fixture
def provider_request():
"""创建 ProviderRequest 实例。"""
from astrbot.core.provider.entities import ProviderRequest
return ProviderRequest(
prompt="Hello",
session_id="test_session",
image_urls=[],
contexts=[],
system_prompt="You are a helpful assistant.",
)
# ============================================================
# 跳过条件
# ============================================================
def pytest_runtest_setup(item):
"""在测试运行前检查跳过条件。"""
# 跳过需要 API Key 但未设置的 Provider 测试
if item.get_closest_marker("provider"):
if not os.environ.get("TEST_PROVIDER_API_KEY"):
pytest.skip("TEST_PROVIDER_API_KEY not set")
# 跳过需要特定平台的测试
if item.get_closest_marker("platform"):
required_platform = None
marker = item.get_closest_marker("platform")
if marker and marker.args:
required_platform = marker.args[0]
if required_platform and not os.environ.get(
f"TEST_{required_platform.upper()}_ENABLED"
):
pytest.skip(f"TEST_{required_platform.upper()}_ENABLED not set")
-64
View File
@@ -1,64 +0,0 @@
"""
AstrBot 测试数据
此目录存放测试用的静态数据和配置文件
目录结构:
- fixtures/
configs/ # 测试配置文件
messages/ # 测试消息数据
plugins/ # 测试插件
knowledge_base/ # 测试知识库数据
mocks/ # Mock 模块
helpers.py # 辅助函数
"""
import json
from pathlib import Path
from .helpers import (
NoopAwaitable,
create_mock_discord_attachment,
create_mock_discord_channel,
create_mock_discord_user,
create_mock_file,
create_mock_llm_response,
create_mock_message_component,
create_mock_update,
make_platform_config,
)
FIXTURES_DIR = Path(__file__).parent
def load_fixture(filename: str) -> dict:
"""加载 JSON 格式的测试数据。"""
filepath = FIXTURES_DIR / filename
if not filepath.exists():
raise FileNotFoundError(f"Fixture not found: {filepath}")
return json.loads(filepath.read_text(encoding="utf-8"))
def get_fixture_path(filename: str) -> Path:
"""获取测试数据文件路径。"""
filepath = FIXTURES_DIR / filename
if not filepath.exists():
raise FileNotFoundError(f"Fixture not found: {filepath}")
return filepath
__all__ = [
"FIXTURES_DIR",
"load_fixture",
"get_fixture_path",
# 辅助函数
"NoopAwaitable",
"make_platform_config",
"create_mock_update",
"create_mock_file",
"create_mock_discord_attachment",
"create_mock_discord_user",
"create_mock_discord_channel",
"create_mock_message_component",
"create_mock_llm_response",
]
-21
View File
@@ -1,21 +0,0 @@
{
"provider": [
{
"id": "test-openai",
"type": "openai_chat_completion",
"model": "gpt-4o-mini",
"key": ["test-key"]
}
],
"platform": [],
"provider_settings": {
"default_personality": null,
"prompt_prefix": "",
"image_caption_provider_id": "",
"datetime_system_prompt": true,
"identifier": true,
"group_name_display": true
},
"default_personality": null,
"timezone": "Asia/Shanghai"
}
-332
View File
@@ -1,332 +0,0 @@
"""测试辅助函数和工具类。
提供统一的测试辅助工具减少测试代码重复
"""
from typing import Any
from unittest.mock import AsyncMock, MagicMock
from astrbot.core.message.components import BaseMessageComponent
class NoopAwaitable:
"""可等待的空操作对象。
用于 mock 需要返回 awaitable 对象的方法
"""
def __await__(self):
if False:
yield
return None
# ============================================================
# 平台配置工厂
# ============================================================
def make_platform_config(platform_type: str, **kwargs) -> dict:
"""平台配置工厂函数。
Args:
platform_type: 平台类型 (telegram, discord, aiocqhttp )
**kwargs: 覆盖默认配置的字段
Returns:
dict: 平台配置字典
"""
configs = {
"telegram": {
"id": "test_telegram",
"telegram_token": "test_token_123",
"telegram_api_base_url": "https://api.telegram.org/bot",
"telegram_file_base_url": "https://api.telegram.org/file/bot",
"telegram_command_register": True,
"telegram_command_auto_refresh": True,
"telegram_command_register_interval": 300,
"telegram_media_group_timeout": 2.5,
"telegram_media_group_max_wait": 10.0,
"start_message": "Welcome to AstrBot!",
},
"discord": {
"id": "test_discord",
"discord_token": "test_token_123",
"discord_proxy": None,
"discord_command_register": True,
"discord_guild_id_for_debug": None,
"discord_activity_name": "Playing AstrBot",
},
"aiocqhttp": {
"id": "test_aiocqhttp",
"ws_reverse_host": "0.0.0.0",
"ws_reverse_port": 6199,
"ws_reverse_token": "test_token",
},
"webchat": {
"id": "test_webchat",
},
"wecom": {
"id": "test_wecom",
"wecom_corpid": "test_corpid",
"wecom_secret": "test_secret",
},
}
config = configs.get(platform_type, {"id": f"test_{platform_type}"}).copy()
config.update(kwargs)
return config
# ============================================================
# Telegram 辅助函数
# ============================================================
def create_mock_update(
message_text: str | None = "Hello World",
chat_type: str = "private",
chat_id: int = 123456789,
user_id: int = 987654321,
username: str = "test_user",
message_id: int = 1,
media_group_id: str | None = None,
photo: list | None = None,
video: MagicMock | None = None,
document: MagicMock | None = None,
voice: MagicMock | None = None,
sticker: MagicMock | None = None,
reply_to_message: MagicMock | None = None,
caption: str | None = None,
entities: list | None = None,
caption_entities: list | None = None,
message_thread_id: int | None = None,
is_topic_message: bool = False,
):
"""创建模拟的 Telegram Update 对象。
Args:
message_text: 消息文本
chat_type: 聊天类型
chat_id: 聊天 ID
user_id: 用户 ID
username: 用户名
message_id: 消息 ID
media_group_id: 媒体组 ID
photo: 图片列表
video: 视频对象
document: 文档对象
voice: 语音对象
sticker: 贴纸对象
reply_to_message: 回复的消息
caption: 说明文字
entities: 实体列表
caption_entities: 说明实体列表
message_thread_id: 消息线程 ID
is_topic_message: 是否为主题消息
Returns:
MagicMock: 模拟的 Update 对象
"""
update = MagicMock()
update.update_id = 1
# Create message mock
message = MagicMock()
message.message_id = message_id
message.chat = MagicMock()
message.chat.id = chat_id
message.chat.type = chat_type
message.message_thread_id = message_thread_id
message.is_topic_message = is_topic_message
# Create user mock
from_user = MagicMock()
from_user.id = user_id
from_user.username = username
message.from_user = from_user
# Set message content
message.text = message_text
message.media_group_id = media_group_id
message.photo = photo
message.video = video
message.document = document
message.voice = voice
message.sticker = sticker
message.reply_to_message = reply_to_message
message.caption = caption
message.entities = entities
message.caption_entities = caption_entities
update.message = message
update.effective_chat = message.chat
return update
def create_mock_file(file_path: str = "https://api.telegram.org/file/test.jpg"):
"""创建模拟的 Telegram File 对象。
Args:
file_path: 文件路径
Returns:
MagicMock: 模拟的 File 对象
"""
file = MagicMock()
file.file_path = file_path
file.get_file = AsyncMock(return_value=file)
return file
# ============================================================
# Discord 辅助函数
# ============================================================
def create_mock_discord_attachment(
filename: str = "test.txt",
url: str = "https://cdn.discordapp.com/test.txt",
content_type: str | None = None,
size: int = 1024,
):
"""创建模拟的 Discord Attachment 对象。
Args:
filename: 文件名
url: 文件 URL
content_type: 内容类型
size: 文件大小
Returns:
MagicMock: 模拟的 Attachment 对象
"""
attachment = MagicMock()
attachment.filename = filename
attachment.url = url
attachment.content_type = content_type
attachment.size = size
return attachment
def create_mock_discord_user(
user_id: int = 123456789,
name: str = "TestUser",
display_name: str = "Test User",
bot: bool = False,
):
"""创建模拟的 Discord User 对象。
Args:
user_id: 用户 ID
name: 用户名
display_name: 显示名
bot: 是否为机器人
Returns:
MagicMock: 模拟的 User 对象
"""
user = MagicMock()
user.id = user_id
user.name = name
user.display_name = display_name
user.bot = bot
user.mention = f"<@{user_id}>"
return user
def create_mock_discord_channel(
channel_id: int = 111222333,
channel_type: str = "text",
name: str = "general",
guild_id: int | None = 444555666,
):
"""创建模拟的 Discord Channel 对象。
Args:
channel_id: 频道 ID
channel_type: 频道类型
name: 频道名
guild_id: 服务器 ID
Returns:
MagicMock: 模拟的 Channel 对象
"""
channel = MagicMock()
channel.id = channel_id
channel.name = name
channel.type = channel_type
if guild_id:
channel.guild = MagicMock()
channel.guild.id = guild_id
else:
channel.guild = None
return channel
# ============================================================
# 消息组件辅助函数
# ============================================================
def create_mock_message_component(
component_type: str,
**kwargs: Any,
) -> BaseMessageComponent:
"""创建模拟的消息组件。
Args:
component_type: 组件类型 (plain, image, at, reply, file)
**kwargs: 组件参数
Returns:
BaseMessageComponent: 消息组件实例
"""
from astrbot.core.message import components as Comp
component_map = {
"plain": Comp.Plain,
"image": Comp.Image,
"at": Comp.At,
"reply": Comp.Reply,
"file": Comp.File,
}
component_class = component_map.get(component_type.lower())
if not component_class:
raise ValueError(f"Unknown component type: {component_type}")
return component_class(**kwargs)
def create_mock_llm_response(
completion_text: str = "Hello! How can I help you?",
role: str = "assistant",
tools_call_name: list[str] | None = None,
tools_call_args: list[dict] | None = None,
tools_call_ids: list[str] | None = None,
):
"""创建模拟的 LLM 响应。
Args:
completion_text: 完成文本
role: 角色
tools_call_name: 工具调用名称列表
tools_call_args: 工具调用参数列表
tools_call_ids: 工具调用 ID 列表
Returns:
LLMResponse: 模拟的 LLM 响应
"""
from astrbot.core.provider.entities import LLMResponse, TokenUsage
return LLMResponse(
role=role,
completion_text=completion_text,
tools_call_name=tools_call_name or [],
tools_call_args=tools_call_args or [],
tools_call_ids=tools_call_ids or [],
usage=TokenUsage(input_other=10, output=5),
)
-33
View File
@@ -1,33 +0,0 @@
{
"plain_message": {
"type": "plain",
"text": "Hello, this is a test message."
},
"image_message": {
"type": "image",
"url": "https://example.com/test.jpg",
"file": null
},
"at_message": {
"type": "at",
"user_id": "12345",
"nickname": "TestUser"
},
"reply_message": {
"type": "reply",
"id": "msg_123",
"sender_nickname": "OriginalSender",
"message_str": "This is the original message"
},
"file_message": {
"type": "file",
"name": "test.pdf",
"url": "https://example.com/test.pdf"
},
"combined_message": {
"components": [
{"type": "at", "user_id": "bot_id"},
{"type": "plain", "text": " Hello bot!"}
]
}
}
-43
View File
@@ -1,43 +0,0 @@
"""测试 Mock 模块。
提供统一的 mock 工具和 fixture减少测试代码重复
使用方式:
# 在测试文件顶部导入需要的 fixture
from tests.fixtures.mocks import mock_telegram_modules
# 或使用 Builder 类创建 mock 对象
from tests.fixtures.mocks import MockTelegramBuilder
bot = MockTelegramBuilder.create_bot()
"""
from .aiocqhttp import (
MockAiocqhttpBuilder,
create_mock_aiocqhttp_modules,
mock_aiocqhttp_modules,
)
from .discord import (
MockDiscordBuilder,
create_mock_discord_modules,
mock_discord_modules,
)
from .telegram import (
MockTelegramBuilder,
create_mock_telegram_modules,
mock_telegram_modules,
)
__all__ = [
# Telegram
"mock_telegram_modules",
"create_mock_telegram_modules",
"MockTelegramBuilder",
# Discord
"mock_discord_modules",
"create_mock_discord_modules",
"MockDiscordBuilder",
# Aiocqhttp
"mock_aiocqhttp_modules",
"create_mock_aiocqhttp_modules",
"MockAiocqhttpBuilder",
]
-58
View File
@@ -1,58 +0,0 @@
"""Aiocqhttp 模块 Mock 工具。
提供统一的 aiocqhttp 相关模块 mock 设置避免在测试文件中重复定义
"""
import sys
from unittest.mock import AsyncMock, MagicMock
import pytest
def create_mock_aiocqhttp_modules():
"""创建 aiocqhttp 相关的 mock 模块。
Returns:
dict: 包含 aiocqhttp 和相关模块的 mock 对象
"""
mock_aiocqhttp = MagicMock()
mock_aiocqhttp.CQHttp = MagicMock
mock_aiocqhttp.Event = MagicMock
mock_aiocqhttp.exceptions = MagicMock()
mock_aiocqhttp.exceptions.ActionFailed = Exception
return mock_aiocqhttp
@pytest.fixture(scope="module", autouse=True)
def mock_aiocqhttp_modules():
"""Mock aiocqhttp 相关模块的 fixture。
自动应用于使用此 fixture 的测试模块
"""
mock_aiocqhttp = create_mock_aiocqhttp_modules()
monkeypatch = pytest.MonkeyPatch()
monkeypatch.setitem(sys.modules, "aiocqhttp", mock_aiocqhttp)
monkeypatch.setitem(sys.modules, "aiocqhttp.exceptions", mock_aiocqhttp.exceptions)
yield
monkeypatch.undo()
class MockAiocqhttpBuilder:
"""构建 aiocqhttp 测试 mock 对象的工具类。"""
@staticmethod
def create_bot():
"""创建 mock CQHttp bot 实例。"""
from tests.fixtures.helpers import NoopAwaitable
bot = MagicMock()
bot.send = AsyncMock()
bot.call_action = AsyncMock()
bot.on_request = MagicMock()
bot.on_notice = MagicMock()
bot.on_message = MagicMock()
bot.on_websocket_connection = MagicMock()
bot.run_task = MagicMock(return_value=NoopAwaitable())
return bot
-140
View File
@@ -1,140 +0,0 @@
"""Discord 模块 Mock 工具。
提供统一的 Discord 相关模块 mock 设置避免在测试文件中重复定义
"""
import sys
from unittest.mock import AsyncMock, MagicMock
import pytest
def create_mock_discord_modules():
"""创建 Discord 相关的 mock 模块。
Returns:
dict: 包含 discord 和相关模块的 mock 对象
"""
mock_discord = MagicMock()
# Mock discord.Intents
mock_intents = MagicMock()
mock_intents.default = MagicMock(return_value=mock_intents)
mock_discord.Intents = mock_intents
# Mock discord.Status
mock_discord.Status = MagicMock()
mock_discord.Status.online = "online"
# Mock discord.Bot
mock_bot = MagicMock()
mock_discord.Bot = MagicMock(return_value=mock_bot)
# Mock discord.Embed
mock_embed = MagicMock()
mock_discord.Embed = MagicMock(return_value=mock_embed)
# Mock discord.ui
mock_ui = MagicMock()
mock_ui.View = MagicMock
mock_ui.Button = MagicMock
mock_discord.ui = mock_ui
# Mock discord.Message
mock_discord.Message = MagicMock
# Mock discord.Interaction
mock_discord.Interaction = MagicMock
mock_discord.InteractionType = MagicMock()
mock_discord.InteractionType.application_command = 2
mock_discord.InteractionType.component = 3
# Mock discord.File
mock_discord.File = MagicMock
# Mock discord.SlashCommand
mock_discord.SlashCommand = MagicMock
# Mock discord.Option
mock_discord.Option = MagicMock
# Mock discord.SlashCommandOptionType
mock_discord.SlashCommandOptionType = MagicMock()
mock_discord.SlashCommandOptionType.string = 3
# Mock discord.errors
mock_discord.errors = MagicMock()
mock_discord.errors.LoginFailure = Exception
mock_discord.errors.ConnectionClosed = Exception
mock_discord.errors.NotFound = Exception
mock_discord.errors.Forbidden = Exception
# Mock discord.abc
mock_discord.abc = MagicMock()
mock_discord.abc.GuildChannel = MagicMock
mock_discord.abc.Messageable = MagicMock
mock_discord.abc.PrivateChannel = MagicMock
# Mock discord.channel
mock_channel = MagicMock()
mock_channel.DMChannel = MagicMock
mock_discord.channel = mock_channel
# Mock discord.types
mock_discord.types = MagicMock()
mock_discord.types.interactions = MagicMock()
# Mock discord.ApplicationContext
mock_discord.ApplicationContext = MagicMock
# Mock discord.CustomActivity
mock_discord.CustomActivity = MagicMock
return mock_discord
@pytest.fixture(scope="module", autouse=True)
def mock_discord_modules():
"""Mock Discord 相关模块的 fixture。
自动应用于使用此 fixture 的测试模块
"""
mock_discord = create_mock_discord_modules()
monkeypatch = pytest.MonkeyPatch()
monkeypatch.setitem(sys.modules, "discord", mock_discord)
monkeypatch.setitem(sys.modules, "discord.abc", mock_discord.abc)
monkeypatch.setitem(sys.modules, "discord.channel", mock_discord.channel)
monkeypatch.setitem(sys.modules, "discord.errors", mock_discord.errors)
monkeypatch.setitem(sys.modules, "discord.types", mock_discord.types)
monkeypatch.setitem(
sys.modules,
"discord.types.interactions",
mock_discord.types.interactions,
)
monkeypatch.setitem(sys.modules, "discord.ui", mock_discord.ui)
yield
monkeypatch.undo()
class MockDiscordBuilder:
"""构建 Discord 测试 mock 对象的工具类。"""
@staticmethod
def create_client():
"""创建 mock Discord client 实例。"""
client = MagicMock()
client.user = MagicMock()
client.user.id = 123456789
client.user.display_name = "TestBot"
client.user.name = "TestBot"
client.get_channel = MagicMock()
client.fetch_channel = AsyncMock()
client.get_message = MagicMock()
client.start = AsyncMock()
client.close = AsyncMock()
client.is_closed = MagicMock(return_value=False)
client.add_application_command = MagicMock()
client.sync_commands = AsyncMock()
client.change_presence = AsyncMock()
return client
-141
View File
@@ -1,141 +0,0 @@
"""Telegram 模块 Mock 工具。
提供统一的 Telegram 相关模块 mock 设置避免在测试文件中重复定义
"""
import sys
from unittest.mock import AsyncMock, MagicMock
import pytest
def create_mock_telegram_modules():
"""创建 Telegram 相关的 mock 模块。
Returns:
dict: 包含 telegram 和相关模块的 mock 对象
"""
mock_telegram = MagicMock()
mock_telegram.BotCommand = MagicMock
mock_telegram.Update = MagicMock
mock_telegram.constants = MagicMock()
mock_telegram.constants.ChatType = MagicMock()
mock_telegram.constants.ChatType.PRIVATE = "private"
mock_telegram.constants.ChatAction = MagicMock()
mock_telegram.constants.ChatAction.TYPING = "typing"
mock_telegram.constants.ChatAction.UPLOAD_VOICE = "upload_voice"
mock_telegram.constants.ChatAction.UPLOAD_DOCUMENT = "upload_document"
mock_telegram.constants.ChatAction.UPLOAD_PHOTO = "upload_photo"
mock_telegram.error = MagicMock()
mock_telegram.error.BadRequest = Exception
mock_telegram.ReactionTypeCustomEmoji = MagicMock
mock_telegram.ReactionTypeEmoji = MagicMock
mock_telegram_ext = MagicMock()
mock_telegram_ext.ApplicationBuilder = MagicMock
mock_telegram_ext.ContextTypes = MagicMock
mock_telegram_ext.ExtBot = MagicMock
mock_telegram_ext.filters = MagicMock()
mock_telegram_ext.filters.ALL = MagicMock()
mock_telegram_ext.MessageHandler = MagicMock
# Mock telegramify_markdown
mock_telegramify = MagicMock()
mock_telegramify.markdownify = lambda text, **kwargs: text
# Mock apscheduler
mock_apscheduler = MagicMock()
mock_apscheduler.schedulers = MagicMock()
mock_apscheduler.schedulers.asyncio = MagicMock()
mock_apscheduler.schedulers.asyncio.AsyncIOScheduler = MagicMock
mock_apscheduler.schedulers.background = MagicMock()
mock_apscheduler.schedulers.background.BackgroundScheduler = MagicMock
return {
"telegram": mock_telegram,
"telegram.ext": mock_telegram_ext,
"telegramify_markdown": mock_telegramify,
"apscheduler": mock_apscheduler,
}
@pytest.fixture(scope="module", autouse=True)
def mock_telegram_modules():
"""Mock Telegram 相关模块的 fixture。
自动应用于使用此 fixture 的测试模块
"""
mocks = create_mock_telegram_modules()
monkeypatch = pytest.MonkeyPatch()
monkeypatch.setitem(sys.modules, "telegram", mocks["telegram"])
monkeypatch.setitem(sys.modules, "telegram.constants", mocks["telegram"].constants)
monkeypatch.setitem(sys.modules, "telegram.error", mocks["telegram"].error)
monkeypatch.setitem(sys.modules, "telegram.ext", mocks["telegram.ext"])
monkeypatch.setitem(sys.modules, "telegramify_markdown", mocks["telegramify_markdown"])
monkeypatch.setitem(sys.modules, "apscheduler", mocks["apscheduler"])
monkeypatch.setitem(
sys.modules, "apscheduler.schedulers", mocks["apscheduler"].schedulers
)
monkeypatch.setitem(
sys.modules,
"apscheduler.schedulers.asyncio",
mocks["apscheduler"].schedulers.asyncio,
)
monkeypatch.setitem(
sys.modules,
"apscheduler.schedulers.background",
mocks["apscheduler"].schedulers.background,
)
yield
monkeypatch.undo()
class MockTelegramBuilder:
"""构建 Telegram 测试 mock 对象的工具类。"""
@staticmethod
def create_bot():
"""创建 mock Telegram bot 实例。"""
bot = MagicMock()
bot.username = "test_bot"
bot.id = 12345678
bot.base_url = "https://api.telegram.org/bottest_token_123/"
bot.send_message = AsyncMock()
bot.send_photo = AsyncMock()
bot.send_document = AsyncMock()
bot.send_voice = AsyncMock()
bot.send_chat_action = AsyncMock()
bot.delete_my_commands = AsyncMock()
bot.set_my_commands = AsyncMock()
bot.set_message_reaction = AsyncMock()
bot.edit_message_text = AsyncMock()
return bot
@staticmethod
def create_application():
"""创建 mock Telegram Application 实例。"""
from tests.fixtures.helpers import NoopAwaitable
app = MagicMock()
app.bot = MagicMock()
app.bot.username = "test_bot"
app.bot.base_url = "https://api.telegram.org/bottest_token_123/"
app.initialize = AsyncMock()
app.start = AsyncMock()
app.stop = AsyncMock()
app.add_handler = MagicMock()
app.updater = MagicMock()
app.updater.start_polling = MagicMock(return_value=NoopAwaitable())
app.updater.stop = AsyncMock()
return app
@staticmethod
def create_scheduler():
"""创建 mock APScheduler 实例。"""
scheduler = MagicMock()
scheduler.add_job = MagicMock()
scheduler.start = MagicMock()
scheduler.running = True
scheduler.shutdown = MagicMock()
return scheduler
-40
View File
@@ -1,40 +0,0 @@
"""
测试插件 - 用于插件系统测试
这是一个最小化的测试插件用于验证插件系统的功能
"""
from astrbot.api import llm_tool, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
@star.register("test_plugin", "AstrBot Team", "测试插件 - 用于插件系统测试", "1.0.0")
class TestPlugin(star.Star):
"""测试插件类"""
def __init__(self, context: star.Context) -> None:
super().__init__(context)
self.initialized = True
async def terminate(self) -> None:
"""插件终止"""
self.initialized = False
@filter.command("test_cmd")
async def test_command(self, event: AstrMessageEvent) -> None:
"""测试命令处理器。"""
event.set_result(MessageEventResult().message("测试命令执行成功"))
@llm_tool("test_tool")
async def test_llm_tool(self, query: str) -> str:
"""测试 LLM 工具。
Args:
query(string): 查询内容
"""
return f"测试工具执行成功: {query}"
@filter.regex(r"^test_regex_(.+)$")
async def test_regex_handler(self, event: AstrMessageEvent) -> None:
"""测试正则处理器。"""
event.set_result(MessageEventResult().message("正则匹配成功"))
-5
View File
@@ -1,5 +0,0 @@
name: test_plugin
description: 测试插件 - 用于插件系统测试
version: 1.0.0
author: AstrBot Team
repo: https://github.com/test/test_plugin
-115
View File
@@ -1,115 +0,0 @@
"""Smoke tests for critical startup and import paths."""
from __future__ import annotations
import subprocess
import sys
from pathlib import Path
from astrbot.core.pipeline.bootstrap import ensure_builtin_stages_registered
from astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal import (
InternalAgentSubStage,
)
from astrbot.core.pipeline.process_stage.method.agent_sub_stages.third_party import (
ThirdPartyAgentSubStage,
)
from astrbot.core.pipeline.stage import Stage, registered_stages
from astrbot.core.pipeline.stage_order import STAGES_ORDER
REPO_ROOT = Path(__file__).resolve().parents[1]
def _run_code_in_fresh_interpreter(code: str, failure_message: str) -> None:
proc = subprocess.run(
[sys.executable, "-c", code],
cwd=REPO_ROOT,
capture_output=True,
text=True,
check=False,
)
assert proc.returncode == 0, (
f"{failure_message}\nstdout:\n{proc.stdout}\nstderr:\n{proc.stderr}\n"
)
def test_smoke_critical_imports_in_fresh_interpreter() -> None:
code = (
"import importlib;"
"mods=["
"'astrbot.core.core_lifecycle',"
"'astrbot.core.astr_main_agent',"
"'astrbot.core.pipeline.scheduler',"
"'astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal',"
"'astrbot.core.pipeline.process_stage.method.agent_sub_stages.third_party'"
"];"
"[importlib.import_module(m) for m in mods]"
)
_run_code_in_fresh_interpreter(code, "Smoke import check failed.")
def test_smoke_pipeline_stage_registration_matches_order() -> None:
ensure_builtin_stages_registered()
stage_names = {cls.__name__ for cls in registered_stages}
assert set(STAGES_ORDER).issubset(stage_names)
assert len(stage_names) == len(registered_stages)
def test_smoke_agent_sub_stages_are_stage_subclasses() -> None:
assert issubclass(InternalAgentSubStage, Stage)
assert issubclass(ThirdPartyAgentSubStage, Stage)
def test_pipeline_package_exports_remain_compatible() -> None:
import astrbot.core.pipeline as pipeline
assert pipeline.ProcessStage is not None
assert pipeline.RespondStage is not None
assert isinstance(pipeline.STAGES_ORDER, list)
assert "ProcessStage" in pipeline.STAGES_ORDER
def test_builtin_stage_bootstrap_is_idempotent() -> None:
ensure_builtin_stages_registered()
before_count = len(registered_stages)
stage_names = {cls.__name__ for cls in registered_stages}
expected_stage_names = {
"WakingCheckStage",
"WhitelistCheckStage",
"SessionStatusCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",
}
assert expected_stage_names.issubset(stage_names)
ensure_builtin_stages_registered()
assert len(registered_stages) == before_count
def test_pipeline_import_is_stable_with_mocked_apscheduler() -> None:
"""Regression: importing pipeline should not require cron/apscheduler modules."""
code = (
"import sys;"
"from unittest.mock import MagicMock;"
"mock_apscheduler = MagicMock();"
"mock_apscheduler.schedulers = MagicMock();"
"mock_apscheduler.schedulers.asyncio = MagicMock();"
"mock_apscheduler.schedulers.background = MagicMock();"
"sys.modules['apscheduler'] = mock_apscheduler;"
"sys.modules['apscheduler.schedulers'] = mock_apscheduler.schedulers;"
"sys.modules['apscheduler.schedulers.asyncio'] = mock_apscheduler.schedulers.asyncio;"
"sys.modules['apscheduler.schedulers.background'] = mock_apscheduler.schedulers.background;"
"import astrbot.core.pipeline as pipeline;"
"assert pipeline.ProcessStage is not None;"
"assert pipeline.RespondStage is not None"
)
_run_code_in_fresh_interpreter(
code,
"Pipeline import should not depend on real apscheduler package.",
)
-781
View File
@@ -1,781 +0,0 @@
"""Tests for AstrMessageEvent class."""
import re
from unittest.mock import AsyncMock, patch
import pytest
from astrbot.core.message.components import (
At,
AtAll,
Face,
Forward,
Image,
Plain,
Reply,
)
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember
from astrbot.core.platform.message_type import MessageType
from astrbot.core.platform.platform_metadata import PlatformMetadata
class ConcreteAstrMessageEvent(AstrMessageEvent):
"""Concrete implementation of AstrMessageEvent for testing purposes."""
async def send(self, message):
"""Send message implementation."""
await super().send(message)
@pytest.fixture
def platform_meta():
"""Create platform metadata for testing."""
return PlatformMetadata(
name="test_platform",
description="Test platform",
id="test_platform_id",
)
@pytest.fixture
def message_member():
"""Create a message member for testing."""
return MessageMember(user_id="user123", nickname="TestUser")
@pytest.fixture
def astrbot_message(message_member):
"""Create an AstrBotMessage for testing."""
message = AstrBotMessage()
message.type = MessageType.FRIEND_MESSAGE
message.self_id = "bot123"
message.session_id = "session123"
message.message_id = "msg123"
message.sender = message_member
message.message = [Plain(text="Hello world")]
message.message_str = "Hello world"
message.raw_message = None
return message
@pytest.fixture
def astr_message_event(platform_meta, astrbot_message):
"""Create an AstrMessageEvent instance for testing."""
return ConcreteAstrMessageEvent(
message_str="Hello world",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
class TestAstrMessageEventInit:
"""Tests for AstrMessageEvent initialization."""
def test_init_basic(self, astr_message_event):
"""Test basic AstrMessageEvent initialization."""
assert astr_message_event.message_str == "Hello world"
assert astr_message_event.role == "member"
assert astr_message_event.is_wake is False
assert astr_message_event.is_at_or_wake_command is False
assert astr_message_event._extras == {}
assert astr_message_event._result is None
assert astr_message_event.call_llm is False
def test_init_session(self, astr_message_event):
"""Test session initialization."""
assert astr_message_event.session_id == "session123"
assert astr_message_event.session.platform_name == "test_platform_id"
def test_init_platform_reference(self, astr_message_event, platform_meta):
"""Test platform reference initialization."""
assert astr_message_event.platform_meta == platform_meta
assert astr_message_event.platform == platform_meta # back compatibility
def test_init_created_at(self, astr_message_event):
"""Test created_at timestamp is set."""
assert astr_message_event.created_at is not None
assert isinstance(astr_message_event.created_at, float)
def test_init_trace(self, astr_message_event):
"""Test trace/span initialization."""
assert astr_message_event.trace is not None
assert astr_message_event.span is not None
assert astr_message_event.trace == astr_message_event.span
class TestUnifiedMsgOrigin:
"""Tests for unified_msg_origin property."""
def test_unified_msg_origin_getter(self, astr_message_event):
"""Test unified_msg_origin getter."""
expected = "test_platform_id:FriendMessage:session123"
assert astr_message_event.unified_msg_origin == expected
def test_unified_msg_origin_setter(self, astr_message_event):
"""Test unified_msg_origin setter."""
astr_message_event.unified_msg_origin = "new_platform:GroupMessage:new_session"
assert astr_message_event.session.platform_name == "new_platform"
assert astr_message_event.session.session_id == "new_session"
class TestSessionId:
"""Tests for session_id property."""
def test_session_id_getter(self, astr_message_event):
"""Test session_id getter."""
assert astr_message_event.session_id == "session123"
def test_session_id_setter(self, astr_message_event):
"""Test session_id setter."""
astr_message_event.session_id = "new_session_id"
assert astr_message_event.session_id == "new_session_id"
class TestGetPlatformInfo:
"""Tests for platform info methods."""
def test_get_platform_name(self, astr_message_event):
"""Test get_platform_name method."""
assert astr_message_event.get_platform_name() == "test_platform"
def test_get_platform_id(self, astr_message_event):
"""Test get_platform_id method."""
assert astr_message_event.get_platform_id() == "test_platform_id"
class TestGetMessageInfo:
"""Tests for message info methods."""
def test_get_message_str(self, astr_message_event):
"""Test get_message_str method."""
assert astr_message_event.get_message_str() == "Hello world"
def test_get_message_str_none(self, platform_meta, astrbot_message):
"""Test get_message_str keeps None when source message_str is None."""
astrbot_message.message_str = None
event = ConcreteAstrMessageEvent(
message_str=None,
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.get_message_str() is None
def test_get_messages(self, astr_message_event):
"""Test get_messages method."""
messages = astr_message_event.get_messages()
assert len(messages) == 1
assert isinstance(messages[0], Plain)
assert messages[0].text == "Hello world"
def test_get_message_type(self, astr_message_event):
"""Test get_message_type method."""
assert astr_message_event.get_message_type() == MessageType.FRIEND_MESSAGE
def test_get_session_id(self, astr_message_event):
"""Test get_session_id method."""
assert astr_message_event.get_session_id() == "session123"
def test_get_group_id_empty_for_private(self, astr_message_event):
"""Test get_group_id returns empty for private messages."""
assert astr_message_event.get_group_id() == ""
def test_get_self_id(self, astr_message_event):
"""Test get_self_id method."""
assert astr_message_event.get_self_id() == "bot123"
def test_get_sender_id(self, astr_message_event):
"""Test get_sender_id method."""
assert astr_message_event.get_sender_id() == "user123"
def test_get_sender_name(self, astr_message_event):
"""Test get_sender_name method."""
assert astr_message_event.get_sender_name() == "TestUser"
def test_get_sender_name_empty_when_none(self, platform_meta, astrbot_message):
"""Test get_sender_name returns empty string when nickname is None."""
astrbot_message.sender = MessageMember(user_id="user123", nickname=None)
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.get_sender_name() == ""
def test_get_sender_name_coerces_non_string(self, platform_meta, astrbot_message):
"""Test get_sender_name stringifies non-string nickname values."""
astrbot_message.sender = MessageMember(user_id="user123", nickname=None)
astrbot_message.sender.nickname = 12345
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.get_sender_name() == "12345"
class TestGetMessageOutline:
"""Tests for get_message_outline method."""
def test_outline_plain_text(self, astr_message_event):
"""Test outline with plain text message."""
outline = astr_message_event.get_message_outline()
assert "Hello world" in outline
def test_outline_with_image(self, platform_meta, astrbot_message):
"""Test outline with image component."""
astrbot_message.message = [
Plain(text="Look at this"),
Image(file="http://example.com/img.jpg"),
]
event = ConcreteAstrMessageEvent(
message_str="Look at this",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "Look at this" in outline
assert "[图片]" in outline
def test_outline_with_at(self, platform_meta, astrbot_message):
"""Test outline with At component."""
astrbot_message.message = [At(qq="12345"), Plain(text=" hello")]
event = ConcreteAstrMessageEvent(
message_str=" hello",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "[At:12345]" in outline
def test_outline_with_at_all(self, platform_meta, astrbot_message):
"""Test outline with AtAll component."""
astrbot_message.message = [AtAll()]
event = ConcreteAstrMessageEvent(
message_str="",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
# AtAll format is "[At:all]" in the actual implementation
assert "[At:" in outline and "all" in outline.lower()
def test_outline_with_face(self, platform_meta, astrbot_message):
"""Test outline with Face component."""
astrbot_message.message = [Face(id="123")]
event = ConcreteAstrMessageEvent(
message_str="",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "[表情:123]" in outline
def test_outline_with_forward(self, platform_meta, astrbot_message):
"""Test outline with Forward component."""
# Forward requires an id parameter
astrbot_message.message = [Forward(id="test_forward_id")]
event = ConcreteAstrMessageEvent(
message_str="",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "[转发消息]" in outline
def test_outline_with_reply(self, platform_meta, astrbot_message):
"""Test outline with Reply component."""
# Reply requires an id parameter
reply = Reply(id="test_reply_id")
reply.message_str = "Original message"
reply.sender_nickname = "Sender"
astrbot_message.message = [reply, Plain(text=" reply")]
event = ConcreteAstrMessageEvent(
message_str=" reply",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "[引用消息(Sender: Original message)]" in outline
def test_outline_with_reply_no_message(self, platform_meta, astrbot_message):
"""Test outline with Reply component without message_str."""
# Reply requires an id parameter
reply = Reply(id="test_reply_id")
reply.message_str = None
astrbot_message.message = [reply]
event = ConcreteAstrMessageEvent(
message_str="",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "[引用消息]" in outline
def test_outline_empty_chain(self, platform_meta, astrbot_message):
"""Test outline with empty message chain."""
astrbot_message.message = []
event = ConcreteAstrMessageEvent(
message_str="",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert outline == ""
def test_outline_very_long_plain_text(self, platform_meta, astrbot_message):
"""Test outline generation for very long plain text content."""
long_text = "A" * 20000
astrbot_message.message = [Plain(text=long_text)]
event = ConcreteAstrMessageEvent(
message_str=long_text,
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert outline.startswith("A")
assert len(outline) >= 20000
class TestExtras:
"""Tests for extra information methods."""
def test_set_extra(self, astr_message_event):
"""Test set_extra method."""
astr_message_event.set_extra("key1", "value1")
assert astr_message_event._extras["key1"] == "value1"
def test_get_extra_with_key(self, astr_message_event):
"""Test get_extra with specific key."""
astr_message_event.set_extra("key1", "value1")
assert astr_message_event.get_extra("key1") == "value1"
def test_get_extra_with_default(self, astr_message_event):
"""Test get_extra with default value."""
result = astr_message_event.get_extra("nonexistent", "default_value")
assert result == "default_value"
def test_get_extra_all(self, astr_message_event):
"""Test get_extra without key returns all extras."""
astr_message_event.set_extra("key1", "value1")
astr_message_event.set_extra("key2", "value2")
all_extras = astr_message_event.get_extra()
assert all_extras == {"key1": "value1", "key2": "value2"}
def test_clear_extra(self, astr_message_event):
"""Test clear_extra method."""
astr_message_event.set_extra("key1", "value1")
astr_message_event.clear_extra()
assert astr_message_event._extras == {}
class TestSetResult:
"""Tests for set_result method."""
def test_set_result_with_message_event_result(self, astr_message_event):
"""Test set_result with MessageEventResult object."""
result = MessageEventResult().message("Test message")
astr_message_event.set_result(result)
assert astr_message_event._result == result
def test_set_result_with_string(self, astr_message_event):
"""Test set_result with string creates MessageEventResult."""
astr_message_event.set_result("Test message")
assert astr_message_event._result is not None
assert len(astr_message_event._result.chain) == 1
assert isinstance(astr_message_event._result.chain[0], Plain)
def test_set_result_with_empty_chain(self, astr_message_event):
"""Test set_result handles empty chain correctly."""
result = MessageEventResult()
# chain is already an empty list by default
astr_message_event.set_result(result)
assert astr_message_event._result.chain == []
class TestStopContinueEvent:
"""Tests for stop_event and continue_event methods."""
def test_stop_event_creates_result_if_none(self, astr_message_event):
"""Test stop_event creates result if none exists."""
astr_message_event.stop_event()
assert astr_message_event._result is not None
assert astr_message_event.is_stopped() is True
def test_stop_event_with_existing_result(self, astr_message_event):
"""Test stop_event with existing result."""
astr_message_event.set_result(MessageEventResult().message("Test"))
astr_message_event.stop_event()
assert astr_message_event.is_stopped() is True
def test_continue_event_creates_result_if_none(self, astr_message_event):
"""Test continue_event creates result if none exists."""
astr_message_event.continue_event()
assert astr_message_event._result is not None
assert astr_message_event.is_stopped() is False
def test_continue_event_with_existing_result(self, astr_message_event):
"""Test continue_event with existing result."""
astr_message_event.set_result(MessageEventResult().message("Test"))
astr_message_event.stop_event()
astr_message_event.continue_event()
assert astr_message_event.is_stopped() is False
def test_is_stopped_default_false(self, astr_message_event):
"""Test is_stopped returns False by default."""
assert astr_message_event.is_stopped() is False
class TestIsPrivateChat:
"""Tests for is_private_chat method."""
def test_is_private_chat_true(self, astr_message_event):
"""Test is_private_chat returns True for friend message."""
assert astr_message_event.is_private_chat() is True
def test_is_private_chat_false(self, platform_meta, astrbot_message):
"""Test is_private_chat returns False for group message."""
astrbot_message.type = MessageType.GROUP_MESSAGE
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.is_private_chat() is False
class TestIsWakeUp:
"""Tests for is_wake_up method."""
def test_is_wake_up_default_false(self, astr_message_event):
"""Test is_wake_up returns False by default."""
assert astr_message_event.is_wake_up() is False
def test_is_wake_up_when_set(self, astr_message_event):
"""Test is_wake_up returns True when is_wake is set."""
astr_message_event.is_wake = True
assert astr_message_event.is_wake_up() is True
class TestIsAdmin:
"""Tests for is_admin method."""
def test_is_admin_default_false(self, astr_message_event):
"""Test is_admin returns False by default."""
assert astr_message_event.is_admin() is False
def test_is_admin_when_admin(self, astr_message_event):
"""Test is_admin returns True when role is admin."""
astr_message_event.role = "admin"
assert astr_message_event.is_admin() is True
class TestProcessBuffer:
"""Tests for process_buffer method."""
@pytest.mark.asyncio
async def test_process_buffer_splits_by_pattern(self, astr_message_event):
"""Test process_buffer splits buffer by pattern."""
buffer = "Line 1\nLine 2\nLine 3\nRemaining"
pattern = re.compile(r".*\n")
with patch.object(
astr_message_event, "send", new_callable=AsyncMock
) as mock_send:
result = await astr_message_event.process_buffer(buffer, pattern)
# Should have sent 3 lines and remaining should be "Remaining"
assert mock_send.call_count == 3
assert result == "Remaining"
@pytest.mark.asyncio
async def test_process_buffer_no_match(self, astr_message_event):
"""Test process_buffer returns original when no match."""
buffer = "No newlines here"
pattern = re.compile(r"\n")
result = await astr_message_event.process_buffer(buffer, pattern)
assert result == "No newlines here"
class TestResultHelpers:
"""Tests for result helper methods."""
def test_make_result(self, astr_message_event):
"""Test make_result creates empty MessageEventResult."""
result = astr_message_event.make_result()
assert isinstance(result, MessageEventResult)
def test_plain_result(self, astr_message_event):
"""Test plain_result creates result with text."""
result = astr_message_event.plain_result("Hello")
assert isinstance(result, MessageEventResult)
assert len(result.chain) == 1
assert isinstance(result.chain[0], Plain)
assert result.chain[0].text == "Hello"
def test_image_result_url(self, astr_message_event):
"""Test image_result with URL."""
result = astr_message_event.image_result("http://example.com/image.jpg")
assert isinstance(result, MessageEventResult)
assert len(result.chain) == 1
assert isinstance(result.chain[0], Image)
def test_image_result_path(self, astr_message_event):
"""Test image_result with file path."""
result = astr_message_event.image_result("/path/to/image.jpg")
assert isinstance(result, MessageEventResult)
assert len(result.chain) == 1
assert isinstance(result.chain[0], Image)
class TestGetResult:
"""Tests for get_result and clear_result methods."""
def test_get_result_returns_none_by_default(self, astr_message_event):
"""Test get_result returns None by default."""
assert astr_message_event.get_result() is None
def test_get_result_returns_set_result(self, astr_message_event):
"""Test get_result returns set result."""
result = MessageEventResult().message("Test")
astr_message_event.set_result(result)
assert astr_message_event.get_result() == result
def test_clear_result(self, astr_message_event):
"""Test clear_result clears the result."""
astr_message_event.set_result(MessageEventResult().message("Test"))
astr_message_event.clear_result()
assert astr_message_event.get_result() is None
class TestShouldCallLlm:
"""Tests for should_call_llm method."""
def test_should_call_llm_default(self, astr_message_event):
"""Test call_llm default is False."""
assert astr_message_event.call_llm is False
def test_should_call_llm_when_set(self, astr_message_event):
"""Test should_call_llm sets call_llm."""
astr_message_event.should_call_llm(True)
assert astr_message_event.call_llm is True
class TestRequestLlm:
"""Tests for request_llm method."""
def test_request_llm_basic(self, astr_message_event):
"""Test request_llm creates ProviderRequest."""
request = astr_message_event.request_llm(prompt="Hello")
assert request.prompt == "Hello"
assert request.session_id == ""
assert request.image_urls == []
assert request.contexts == []
def test_request_llm_with_all_params(self, astr_message_event):
"""Test request_llm with all parameters."""
request = astr_message_event.request_llm(
prompt="Hello",
session_id="session123",
image_urls=["http://example.com/img.jpg"],
contexts=[{"role": "user", "content": "Hi"}],
system_prompt="You are helpful",
)
assert request.prompt == "Hello"
assert request.session_id == "session123"
assert request.image_urls == ["http://example.com/img.jpg"]
assert request.contexts == [{"role": "user", "content": "Hi"}]
assert request.system_prompt == "You are helpful"
class TestSendStreaming:
"""Tests for send_streaming method."""
@pytest.mark.asyncio
async def test_send_streaming_sets_has_send_oper(self, astr_message_event):
"""Test send_streaming sets _has_send_oper flag."""
assert astr_message_event._has_send_oper is False
async def generator():
yield MessageEventResult().message("Test")
with patch(
"astrbot.core.platform.astr_message_event.Metric.upload",
new_callable=AsyncMock,
):
await astr_message_event.send_streaming(generator())
assert astr_message_event._has_send_oper is True
class TestSendTyping:
"""Tests for send_typing method."""
@pytest.mark.asyncio
async def test_send_typing_default_empty(self, astr_message_event):
"""Test send_typing default implementation is empty."""
# Should not raise any exception
await astr_message_event.send_typing()
class TestReact:
"""Tests for react method."""
@pytest.mark.asyncio
async def test_react_sends_emoji(self, astr_message_event):
"""Test react sends emoji as message."""
with patch.object(
astr_message_event, "send", new_callable=AsyncMock
) as mock_send:
await astr_message_event.react("👍")
mock_send.assert_called_once()
call_arg = mock_send.call_args[0][0]
# MessageChain is a dataclass with chain attribute
assert len(call_arg.chain) == 1
assert isinstance(call_arg.chain[0], Plain)
assert call_arg.chain[0].text == "👍"
class TestGetGroup:
"""Tests for get_group method."""
@pytest.mark.asyncio
async def test_get_group_returns_none_for_private(self, astr_message_event):
"""Test get_group returns None for private chat."""
result = await astr_message_event.get_group()
assert result is None
@pytest.mark.asyncio
async def test_get_group_with_group_id_param(self, astr_message_event):
"""Test get_group with group_id parameter."""
# Default implementation returns None
result = await astr_message_event.get_group(group_id="group123")
assert result is None
class TestMessageTypeHandling:
"""Tests for message type handling edge cases."""
def test_message_type_from_valid_string(self, platform_meta):
"""Valid MessageType string should be converted correctly."""
message = AstrBotMessage()
message.type = "FRIEND_MESSAGE"
message.message = []
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.session.message_type == MessageType.FRIEND_MESSAGE
assert event.get_message_type() == MessageType.FRIEND_MESSAGE
def test_message_type_from_invalid_string_defaults_to_friend(self, platform_meta):
"""Invalid message type should default to FRIEND_MESSAGE."""
message = AstrBotMessage()
message.type = "InvalidMessageType"
message.message = []
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.session.message_type == MessageType.FRIEND_MESSAGE
assert event.get_message_type() == MessageType.FRIEND_MESSAGE
def test_message_type_from_none_defaults_to_friend(self, platform_meta):
"""None message type should default to FRIEND_MESSAGE."""
message = AstrBotMessage()
message.type = None
message.message = []
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.session.message_type == MessageType.FRIEND_MESSAGE
assert event.get_message_type() == MessageType.FRIEND_MESSAGE
def test_message_type_from_integer_defaults_to_friend(self, platform_meta):
"""Integer message type should default to FRIEND_MESSAGE."""
message = AstrBotMessage()
message.type = 123
message.message = []
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.session.message_type == MessageType.FRIEND_MESSAGE
assert event.get_message_type() == MessageType.FRIEND_MESSAGE
class TestDefensiveGetattr:
"""Tests for defensive getattr behavior in AstrMessageEvent."""
def test_get_messages_without_message_attr(self, astr_message_event):
"""get_messages should handle message_obj without 'message' attribute."""
astr_message_event.message_obj = type("DummyMessage", (), {})()
messages = astr_message_event.get_messages()
assert isinstance(messages, list)
def test_get_message_type_without_type_attr(self, astr_message_event):
"""get_message_type should handle message_obj without 'type' attribute."""
astr_message_event.message_obj = type("DummyMessage", (), {})()
message_type = astr_message_event.get_message_type()
assert isinstance(message_type, MessageType)
def test_get_sender_fields_without_sender_attr(self, astr_message_event):
"""get_sender_id and get_sender_name should handle missing 'sender'."""
astr_message_event.message_obj = type("DummyMessage", (), {})()
sender_id = astr_message_event.get_sender_id()
sender_name = astr_message_event.get_sender_name()
assert isinstance(sender_id, str)
assert isinstance(sender_name, str)
def test_get_message_type_with_non_enum_type(self, astr_message_event):
"""get_message_type should handle message_obj.type that is not a MessageType."""
class DummyMessage:
def __init__(self):
self.type = "not_an_enum"
self.message = []
astr_message_event.message_obj = DummyMessage()
message_type = astr_message_event.get_message_type()
assert isinstance(message_type, MessageType)
-268
View File
@@ -1,268 +0,0 @@
"""Tests for AstrBotMessage and MessageMember classes."""
import time
from unittest.mock import patch
from astrbot.core.message.components import Image, Plain
from astrbot.core.platform.astrbot_message import AstrBotMessage, Group, MessageMember
from astrbot.core.platform.message_type import MessageType
class TestMessageMember:
"""Tests for MessageMember dataclass."""
def test_message_member_creation_basic(self):
"""Test creating a MessageMember with required fields."""
member = MessageMember(user_id="user123")
assert member.user_id == "user123"
assert member.nickname is None
def test_message_member_creation_with_nickname(self):
"""Test creating a MessageMember with nickname."""
member = MessageMember(user_id="user123", nickname="TestUser")
assert member.user_id == "user123"
assert member.nickname == "TestUser"
def test_message_member_str_with_nickname(self):
"""Test __str__ method with nickname."""
member = MessageMember(user_id="user123", nickname="TestUser")
result = str(member)
assert "User ID: user123" in result
assert "Nickname: TestUser" in result
def test_message_member_str_without_nickname(self):
"""Test __str__ method without nickname."""
member = MessageMember(user_id="user123")
result = str(member)
assert "User ID: user123" in result
assert "Nickname: N/A" in result
class TestGroup:
"""Tests for Group dataclass."""
def test_group_creation_basic(self):
"""Test creating a Group with required fields."""
group = Group(group_id="group123")
assert group.group_id == "group123"
assert group.group_name is None
assert group.group_avatar is None
assert group.group_owner is None
assert group.group_admins is None
assert group.members is None
def test_group_creation_with_all_fields(self):
"""Test creating a Group with all fields."""
members = [MessageMember(user_id="user1"), MessageMember(user_id="user2")]
group = Group(
group_id="group123",
group_name="Test Group",
group_avatar="http://example.com/avatar.jpg",
group_owner="owner123",
group_admins=["admin1", "admin2"],
members=members,
)
assert group.group_id == "group123"
assert group.group_name == "Test Group"
assert group.group_avatar == "http://example.com/avatar.jpg"
assert group.group_owner == "owner123"
assert group.group_admins == ["admin1", "admin2"]
assert group.members == members
def test_group_str_with_all_fields(self):
"""Test __str__ method with all fields."""
members = [MessageMember(user_id="user1", nickname="User One")]
group = Group(
group_id="group123",
group_name="Test Group",
group_avatar="http://example.com/avatar.jpg",
group_owner="owner123",
group_admins=["admin1"],
members=members,
)
result = str(group)
assert "Group ID: group123" in result
assert "Name: Test Group" in result
assert "Avatar: http://example.com/avatar.jpg" in result
assert "Owner ID: owner123" in result
assert "Admin IDs: ['admin1']" in result
assert "Members Len: 1" in result
def test_group_str_with_minimal_fields(self):
"""Test __str__ method with minimal fields."""
group = Group(group_id="group123")
result = str(group)
assert "Group ID: group123" in result
assert "Name: N/A" in result
assert "Avatar: N/A" in result
assert "Owner ID: N/A" in result
assert "Admin IDs: N/A" in result
assert "Members Len: 0" in result
assert "First Member: N/A" in result
class TestAstrBotMessage:
"""Tests for AstrBotMessage class."""
def test_astrbot_message_creation(self):
"""Test creating an AstrBotMessage."""
message = AstrBotMessage()
assert message.group is None
assert message.timestamp is not None
assert isinstance(message.timestamp, int)
def test_astrbot_message_timestamp(self):
"""Test timestamp is set on creation."""
with patch.object(time, "time", return_value=1234567890):
message = AstrBotMessage()
assert message.timestamp == 1234567890
def test_astrbot_message_all_attributes(self):
"""Test setting all attributes on AstrBotMessage."""
message = AstrBotMessage()
message.type = MessageType.FRIEND_MESSAGE
message.self_id = "bot123"
message.session_id = "session123"
message.message_id = "msg123"
message.sender = MessageMember(user_id="user123", nickname="TestUser")
message.message = [Plain(text="Hello")]
message.message_str = "Hello"
message.raw_message = {"raw": "data"}
assert message.type == MessageType.FRIEND_MESSAGE
assert message.self_id == "bot123"
assert message.session_id == "session123"
assert message.message_id == "msg123"
assert message.sender.user_id == "user123"
assert len(message.message) == 1
assert message.message_str == "Hello"
assert message.raw_message == {"raw": "data"}
def test_astrbot_message_str(self):
"""Test __str__ method."""
message = AstrBotMessage()
message.type = MessageType.FRIEND_MESSAGE
message.self_id = "bot123"
result = str(message)
assert "'type'" in result
assert "'self_id'" in result
class TestAstrBotMessageGroupId:
"""Tests for AstrBotMessage group_id property."""
def test_group_id_returns_empty_when_no_group(self):
"""Test group_id returns empty string when group is None."""
message = AstrBotMessage()
assert message.group_id == ""
def test_group_id_returns_group_id_when_group_exists(self):
"""Test group_id returns the group's id when group exists."""
message = AstrBotMessage()
message.group = Group(group_id="group123")
assert message.group_id == "group123"
def test_group_id_setter_creates_new_group(self):
"""Test group_id setter creates a new group if none exists."""
message = AstrBotMessage()
message.group_id = "new_group123"
assert message.group is not None
assert message.group.group_id == "new_group123"
def test_group_id_setter_updates_existing_group(self):
"""Test group_id setter updates existing group's id."""
message = AstrBotMessage()
message.group = Group(group_id="old_group")
message.group_id = "new_group"
assert message.group.group_id == "new_group"
def test_group_id_setter_with_none_removes_group(self):
"""Test group_id setter with None removes the group."""
message = AstrBotMessage()
message.group = Group(group_id="group123")
message.group_id = None
assert message.group is None
def test_group_id_setter_with_empty_string_removes_group(self):
"""Test group_id setter with empty string removes the group."""
message = AstrBotMessage()
message.group = Group(group_id="group123")
message.group_id = ""
assert message.group is None
class TestAstrBotMessageTypes:
"""Tests for AstrBotMessage with different message types."""
def test_friend_message_type(self):
"""Test AstrBotMessage with FRIEND_MESSAGE type."""
message = AstrBotMessage()
message.type = MessageType.FRIEND_MESSAGE
assert message.type == MessageType.FRIEND_MESSAGE
assert message.type.value == "FriendMessage"
def test_group_message_type(self):
"""Test AstrBotMessage with GROUP_MESSAGE type."""
message = AstrBotMessage()
message.type = MessageType.GROUP_MESSAGE
assert message.type == MessageType.GROUP_MESSAGE
assert message.type.value == "GroupMessage"
def test_other_message_type(self):
"""Test AstrBotMessage with OTHER_MESSAGE type."""
message = AstrBotMessage()
message.type = MessageType.OTHER_MESSAGE
assert message.type == MessageType.OTHER_MESSAGE
assert message.type.value == "OtherMessage"
class TestAstrBotMessageChain:
"""Tests for AstrBotMessage message chain."""
def test_message_chain_with_plain_text(self):
"""Test message chain with plain text."""
message = AstrBotMessage()
message.message = [Plain(text="Hello world")]
assert len(message.message) == 1
assert isinstance(message.message[0], Plain)
assert message.message[0].text == "Hello world"
def test_message_chain_with_multiple_components(self):
"""Test message chain with multiple components."""
message = AstrBotMessage()
message.message = [
Plain(text="Hello "),
Plain(text="world"),
Image(file="http://example.com/img.jpg"),
]
assert len(message.message) == 3
assert isinstance(message.message[0], Plain)
assert isinstance(message.message[1], Plain)
assert isinstance(message.message[2], Image)
def test_message_chain_empty(self):
"""Test empty message chain."""
message = AstrBotMessage()
message.message = []
assert len(message.message) == 0