Compare commits

...

21 Commits

Author SHA1 Message Date
Soulter 602ae4eee2 feat(heihe): enhance configuration metadata and add new parameters for WebSocket connection 2026-02-25 16:21:54 +08:00
Soulter 3d82f42311 Merge remote-tracking branch 'origin/master' into feat/heibox 2026-02-25 12:00:06 +08:00
エイカク 5530a2260a feat(dashboard): add generic desktop app updater bridge (#5424)
* feat(dashboard): add generic desktop app updater bridge

* fix(dashboard): address updater bridge review feedback

* fix(dashboard): unify updater bridge types and error logging

* fix(dashboard): consolidate updater bridge typings
2026-02-25 10:01:13 +09:00
Soulter c24de24ca4 chore: ruff format 2026-02-24 23:12:18 +08:00
Yunhao Cao b54b4c79ed fix: Telegram voice message format (OGG instead of WAV) causing issues with OpenAI STT API (#5389) 2026-02-24 23:11:56 +08:00
Soulter c6cc7aae84 chore: bump version to 4.18.2 2026-02-24 23:08:53 +08:00
Soulter 84cd209074 chore: bump version to 4.18.2 2026-02-24 22:48:27 +08:00
Soulter afda44fbe3 chore: bump version to 4.18.2 2026-02-24 22:44:35 +08:00
Soulter f5d3b93437 fix(context): improve logging for platform not found in session 2026-02-24 22:37:51 +08:00
Soulter 069a3628fa fix(context): log warning when platform not found for session 2026-02-24 22:37:10 +08:00
氕氙 c81ef2672a fix: pass embedding dimensions to provider apis (#5411) 2026-02-24 22:09:44 +08:00
Soulter a5ae27cae0 fix(aiocqhttp): enhance shutdown process for aiocqhttp adapter (#5412) 2026-02-24 22:07:42 +08:00
Helian Nuits 73faaf6577 i18n(SubAgentPage): complete internationalization for subagent orchestration page (#5400)
* i18n: complete internationalization for subagent orchestration page

- Replace hardcoded English strings in [SubAgentPage.vue] with i18n keys.
- Update `en-US` and `zh-CN` locales with missing hints, validation messages, and empty state translations.
- Fix translation typos and improve consistency across the SubAgent orchestration UI.

* fix(bug_risk): 避免在模板中的翻译调用上使用 || 'Close' 作为回退值。
2026-02-24 21:04:01 +08:00
Helian Nuits 29dbd085d4 fix(core): 优化 File 组件处理逻辑并增强 OneBot 驱动层路径兼容性 (#5391)
* fix(core): 优化 File 组件处理逻辑并增强 OneBot 驱动层路径兼容性

原因 (Necessity):
1. 内核一致性:AstrBot 内核的 Record 和 Video 组件均具备识别 `file:///` 协议头的逻辑,但 File 组件此前缺失此功能,导致行为不统一。
2. OneBot 协议合规:OneBot 11 标准要求本地文件路径必须使用 `file:///` 协议头。此前驱动层未对裸路径进行自动转换,导致发送本地文件时常触发 retcode 1200 (识别URL失败) 错误。
3. 容器环境适配:在 Docker 等路径隔离环境下,裸路径更容易因驱动或协议端的解析歧义而失效。

更改 (Changes):
- [astrbot/core/message/components.py]:
  - 在 File.get_file() 中增加对 `file:///` 前缀的识别与剥离逻辑,使其与 Record/Video 组件行为对齐。
- [astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py]:
  - 在发送文件前增加自动修正逻辑:若路径为绝对路径且未包含协议头,驱动层将自动补全 `file:///` 前缀。
  - 对 http、base64 及已有协议头,确保不干扰原有的正常传输逻辑。

影响 (Impact):
- 以完全兼容的方式增强了文件发送的鲁棒性。
- 解决了插件在发送日志等本地生成的压缩包时,因路径格式不规范导致的发送失败问题。

* refactor(core): 根据 cr 建议,规范化文件 URI 生成与解析逻辑,优化跨平台兼容性

原因 (Necessity):
1. 修复原生路径与 URI 转换在 Windows 下的不对称问题。
2. 规范化 file: 协议头处理,确保符合 RFC 标准并能在 Linux/Windows 间稳健切换。
3. 增强协议判定准确度,防止对普通绝对路径的误处理。

更改 (Changes):
- [astrbot/core/platform/sources/aiocqhttp]:
  - 弃用手动拼接,改用 `pathlib.Path.as_uri()` 生成标准 URI。
  - 将协议检测逻辑从前缀匹配优化为包含性检测 ("://")。
- [astrbot/core/message/components]:
  - 重构 `File.get_file` 解析逻辑,支持对称处理 2/3 斜杠格式。
  - 针对 Windows 环境增加了对 `file:///C:/` 格式的自动修正,避免 `os.path` 识别失效。
- [data/plugins/astrbot_plugin_logplus]:
  - 在直接 API 调用中同步应用 URI 规范化处理。

影响 (Impact):
- 解决 Docker 环境中因路径不规范导致的 "识别URL失败" 报错。
- 提升了本体框架在 Windows 系统下的文件操作鲁棒性。
2026-02-24 21:03:06 +08:00
Axi404 00b011809a fix: enforce admin guard for sandbox file transfer tools (#5402)
* fix: enforce admin guard for sandbox file transfer tools

* refactor: deduplicate computer tools admin permission checks

* fix: add missing space in permission error message
2026-02-24 20:59:44 +08:00
Axi404 0b46ca7ff3 feat: enable computer-use tools for subagent handoff (#5399) 2026-02-24 16:32:12 +08:00
whatevertogo 9294b44831 fix: resolve pipeline and star import cycles (#5353)
* fix: resolve pipeline and star import cycles

- Add bootstrap.py and stage_order.py to break circular dependencies
- Export Context, PluginManager, StarTools from star module
- Update pipeline __init__ to defer imports
- Split pipeline initialization into separate bootstrap module

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: add logging for get_config() failure in Star class

* fix: reorder logger initialization in base.py

---------

Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-24 13:53:29 +08:00
Soulter 80fd51119b feat: add support for showing tool call results in agent execution (#5388)
closes: #5329
2026-02-24 00:46:45 +08:00
whatevertogo 5af5ad9e36 test: add comprehensive tests for message event handling (#5355)
* test: add comprehensive tests for message event handling

- Add AstrMessageEvent unit tests (688 lines)
- Add AstrBotMessage unit tests
- Enhance smoke tests with message event scenarios

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: improve message type handling and add defensive tests

---------

Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-23 23:36:39 +08:00
whatevertogo 7b731ebda8 test: enhance test framework with comprehensive fixtures and mocks (#5354)
* test: enhance test framework with comprehensive fixtures and mocks

- Add shared mock builders for aiocqhttp, discord, telegram
- Add test helpers for platform configs and mock objects
- Expand conftest.py with test profile support
- Update coverage test workflow configuration

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* refactor(tests): 移动并重构模拟 LLM 响应和消息组件函数

* fix(tests): 优化 pytest_runtest_setup 中的标记检查逻辑

---------

Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-23 23:35:15 +08:00
Soulter aec2e3bb91 feat: supports 小黑盒语音机器人 2026-02-14 00:44:35 +08:00
57 changed files with 3930 additions and 272 deletions
+1 -1
View File
@@ -37,7 +37,7 @@ jobs:
mkdir -p data/temp
export TESTING=true
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
pytest --cov=. -v -o log_cli=true -o log_level=DEBUG
pytest --cov=astrbot -v -o log_cli=true -o log_level=DEBUG
- name: Upload results to Codecov
uses: codecov/codecov-action@v5
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.18.1"
__version__ = "4.18.2"
+92 -16
View File
@@ -24,15 +24,77 @@ def _should_stop_agent(astr_event) -> bool:
return astr_event.is_stopped() or bool(astr_event.get_extra("agent_stop_requested"))
def _truncate_tool_result(text: str, limit: int = 70) -> str:
if limit <= 0:
return ""
if len(text) <= limit:
return text
if limit <= 3:
return text[:limit]
return f"{text[: limit - 3]}..."
def _extract_chain_json_data(msg_chain: MessageChain) -> dict | None:
if not msg_chain.chain:
return None
first_comp = msg_chain.chain[0]
if isinstance(first_comp, Json) and isinstance(first_comp.data, dict):
return first_comp.data
return None
def _record_tool_call_name(
tool_info: dict | None, tool_name_by_call_id: dict[str, str]
) -> None:
if not isinstance(tool_info, dict):
return
tool_call_id = tool_info.get("id")
tool_name = tool_info.get("name")
if tool_call_id is None or tool_name is None:
return
tool_name_by_call_id[str(tool_call_id)] = str(tool_name)
def _build_tool_call_status_message(tool_info: dict | None) -> str:
if tool_info:
return f"🔨 调用工具: {tool_info.get('name', 'unknown')}"
return "🔨 调用工具..."
def _build_tool_result_status_message(
msg_chain: MessageChain, tool_name_by_call_id: dict[str, str]
) -> str:
tool_name = "unknown"
tool_result = ""
result_data = _extract_chain_json_data(msg_chain)
if result_data:
tool_call_id = result_data.get("id")
if tool_call_id is not None:
tool_name = tool_name_by_call_id.pop(str(tool_call_id), "unknown")
tool_result = str(result_data.get("result", ""))
if not tool_result:
tool_result = msg_chain.get_plain_text(with_other_comps_mark=True)
tool_result = _truncate_tool_result(tool_result, 70)
status_msg = f"🔨 调用工具: {tool_name}"
if tool_result:
status_msg = f"{status_msg}\n📎 返回结果: {tool_result}"
return status_msg
async def run_agent(
agent_runner: AgentRunner,
max_step: int = 30,
show_tool_use: bool = True,
show_tool_call_result: bool = False,
stream_to_general: bool = False,
show_reasoning: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
step_idx = 0
astr_event = agent_runner.run_context.context.event
tool_name_by_call_id: dict[str, str] = {}
while step_idx < max_step + 1:
step_idx += 1
@@ -90,6 +152,13 @@ async def run_agent(
continue
if astr_event.get_platform_id() == "webchat":
await astr_event.send(msg_chain)
elif show_tool_use and show_tool_call_result:
status_msg = _build_tool_result_status_message(
msg_chain, tool_name_by_call_id
)
await astr_event.send(
MessageChain(type="tool_call").message(status_msg)
)
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
@@ -97,25 +166,22 @@ async def run_agent(
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
tool_info = None
if resp.data["chain"].chain:
json_comp = resp.data["chain"].chain[0]
if isinstance(json_comp, Json):
tool_info = json_comp.data
astr_event.trace.record(
"agent_tool_call",
tool_name=tool_info if tool_info else "unknown",
)
tool_info = _extract_chain_json_data(resp.data["chain"])
astr_event.trace.record(
"agent_tool_call",
tool_name=tool_info if tool_info else "unknown",
)
_record_tool_call_name(tool_info, tool_name_by_call_id)
if astr_event.get_platform_name() == "webchat":
await astr_event.send(resp.data["chain"])
elif show_tool_use:
if tool_info:
m = f"🔨 调用工具: {tool_info.get('name', 'unknown')}"
else:
m = "🔨 调用工具..."
chain = MessageChain(type="tool_call").message(m)
if show_tool_call_result and isinstance(tool_info, dict):
# Delay tool status notification until tool_call_result.
continue
chain = MessageChain(type="tool_call").message(
_build_tool_call_status_message(tool_info)
)
await astr_event.send(chain)
continue
@@ -202,6 +268,7 @@ async def run_live_agent(
tts_provider: TTSProvider | None = None,
max_step: int = 30,
show_tool_use: bool = True,
show_tool_call_result: bool = False,
show_reasoning: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
"""Live Mode 的 Agent 运行器,支持流式 TTS
@@ -211,6 +278,7 @@ async def run_live_agent(
tts_provider: TTS Provider 实例
max_step: 最大步数
show_tool_use: 是否显示工具使用
show_tool_call_result: 是否显示工具返回结果
show_reasoning: 是否显示推理过程
Yields:
@@ -222,6 +290,7 @@ async def run_live_agent(
agent_runner,
max_step=max_step,
show_tool_use=show_tool_use,
show_tool_call_result=show_tool_call_result,
stream_to_general=False,
show_reasoning=show_reasoning,
):
@@ -250,7 +319,12 @@ async def run_live_agent(
# 1. 启动 Agent Feeder 任务:负责运行 Agent 并将文本分句喂给 text_queue
feeder_task = asyncio.create_task(
_run_agent_feeder(
agent_runner, text_queue, max_step, show_tool_use, show_reasoning
agent_runner,
text_queue,
max_step,
show_tool_use,
show_tool_call_result,
show_reasoning,
)
)
@@ -336,6 +410,7 @@ async def _run_agent_feeder(
text_queue: asyncio.Queue,
max_step: int,
show_tool_use: bool,
show_tool_call_result: bool,
show_reasoning: bool,
) -> None:
"""运行 Agent 并将文本输出分句放入队列"""
@@ -345,6 +420,7 @@ async def _run_agent_feeder(
agent_runner,
max_step=max_step,
show_tool_use=show_tool_use,
show_tool_call_result=show_tool_call_result,
stream_to_general=False,
show_reasoning=show_reasoning,
):
+67 -13
View File
@@ -17,6 +17,12 @@ from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.astr_main_agent_resources import (
BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT,
EXECUTE_SHELL_TOOL,
FILE_DOWNLOAD_TOOL,
FILE_UPLOAD_TOOL,
LOCAL_EXECUTE_SHELL_TOOL,
LOCAL_PYTHON_TOOL,
PYTHON_TOOL,
SEND_MESSAGE_TO_USER_TOOL,
)
from astrbot.core.cron.events import CronMessageEvent
@@ -91,6 +97,65 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
yield r
return
@classmethod
def _get_runtime_computer_tools(cls, runtime: str) -> dict[str, FunctionTool]:
if runtime == "sandbox":
return {
EXECUTE_SHELL_TOOL.name: EXECUTE_SHELL_TOOL,
PYTHON_TOOL.name: PYTHON_TOOL,
FILE_UPLOAD_TOOL.name: FILE_UPLOAD_TOOL,
FILE_DOWNLOAD_TOOL.name: FILE_DOWNLOAD_TOOL,
}
if runtime == "local":
return {
LOCAL_EXECUTE_SHELL_TOOL.name: LOCAL_EXECUTE_SHELL_TOOL,
LOCAL_PYTHON_TOOL.name: LOCAL_PYTHON_TOOL,
}
return {}
@classmethod
def _build_handoff_toolset(
cls,
run_context: ContextWrapper[AstrAgentContext],
tools: list[str | FunctionTool] | None,
) -> ToolSet | None:
ctx = run_context.context.context
event = run_context.context.event
cfg = ctx.get_config(umo=event.unified_msg_origin)
provider_settings = cfg.get("provider_settings", {})
runtime = str(provider_settings.get("computer_use_runtime", "local"))
runtime_computer_tools = cls._get_runtime_computer_tools(runtime)
# Keep persona semantics aligned with the main agent: tools=None means
# "all tools", including runtime computer-use tools.
if tools is None:
toolset = ToolSet()
for registered_tool in llm_tools.func_list:
if isinstance(registered_tool, HandoffTool):
continue
if registered_tool.active:
toolset.add_tool(registered_tool)
for runtime_tool in runtime_computer_tools.values():
toolset.add_tool(runtime_tool)
return None if toolset.empty() else toolset
if not tools:
return None
toolset = ToolSet()
for tool_name_or_obj in tools:
if isinstance(tool_name_or_obj, str):
registered_tool = llm_tools.get_func(tool_name_or_obj)
if registered_tool and registered_tool.active:
toolset.add_tool(registered_tool)
continue
runtime_tool = runtime_computer_tools.get(tool_name_or_obj)
if runtime_tool:
toolset.add_tool(runtime_tool)
elif isinstance(tool_name_or_obj, FunctionTool):
toolset.add_tool(tool_name_or_obj)
return None if toolset.empty() else toolset
@classmethod
async def _execute_handoff(
cls,
@@ -101,19 +166,8 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
input_ = tool_args.get("input")
image_urls = tool_args.get("image_urls")
# make toolset for the agent
tools = tool.agent.tools
if tools:
toolset = ToolSet()
for t in tools:
if isinstance(t, str):
_t = llm_tools.get_func(t)
if _t:
toolset.add_tool(_t)
elif isinstance(t, FunctionTool):
toolset.add_tool(t)
else:
toolset = None
# Build handoff toolset from registered tools plus runtime computer tools.
toolset = cls._build_handoff_toolset(run_context, tool.agent.tools)
ctx = run_context.context.context
event = run_context.context.event
+5
View File
@@ -11,6 +11,7 @@ from astrbot.core.message.components import File
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..computer_client import get_booter
from .permissions import check_admin_permission
# @dataclass
# class CreateFileTool(FunctionTool):
@@ -102,6 +103,8 @@ class FileUploadTool(FunctionTool):
context: ContextWrapper[AstrAgentContext],
local_path: str,
) -> str | None:
if permission_error := check_admin_permission(context, "File upload/download"):
return permission_error
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
@@ -161,6 +164,8 @@ class FileDownloadTool(FunctionTool):
remote_path: str,
also_send_to_user: bool = True,
) -> ToolExecResult:
if permission_error := check_admin_permission(context, "File upload/download"):
return permission_error
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
@@ -0,0 +1,19 @@
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.astr_agent_context import AstrAgentContext
def check_admin_permission(
context: ContextWrapper[AstrAgentContext], operation_name: str
) -> str | None:
cfg = context.context.context.get_config(
umo=context.context.event.unified_msg_origin
)
provider_settings = cfg.get("provider_settings", {})
require_admin = provider_settings.get("computer_use_require_admin", True)
if require_admin and context.context.event.role != "admin":
return (
f"error: Permission denied. {operation_name} is only allowed for admin users. "
"Tell user to set admins in `AstrBot WebUI -> Config -> General Config` by adding their user ID to the admins list if they need this feature. "
f"User's ID is: {context.context.event.get_sender_id()}. User's ID can be found by using /sid command."
)
return None
+3 -17
View File
@@ -7,6 +7,7 @@ from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext, AstrMessageEvent
from astrbot.core.computer.computer_client import get_booter, get_local_booter
from astrbot.core.computer.tools.permissions import check_admin_permission
from astrbot.core.message.message_event_result import MessageChain
param_schema = {
@@ -26,21 +27,6 @@ param_schema = {
}
def _check_admin_permission(context: ContextWrapper[AstrAgentContext]) -> str | None:
cfg = context.context.context.get_config(
umo=context.context.event.unified_msg_origin
)
provider_settings = cfg.get("provider_settings", {})
require_admin = provider_settings.get("computer_use_require_admin", True)
if require_admin and context.context.event.role != "admin":
return (
"error: Permission denied. Python execution is only allowed for admin users. "
"Tell user to set admins in `AstrBot WebUI -> Config -> General Config` by adding their user ID to the admins list if they need this feature."
f"User's ID is: {context.context.event.get_sender_id()}. User's ID can be found by using /sid command."
)
return None
async def handle_result(result: dict, event: AstrMessageEvent) -> ToolExecResult:
data = result.get("data", {})
output = data.get("output", {})
@@ -81,7 +67,7 @@ class PythonTool(FunctionTool):
async def call(
self, context: ContextWrapper[AstrAgentContext], code: str, silent: bool = False
) -> ToolExecResult:
if permission_error := _check_admin_permission(context):
if permission_error := check_admin_permission(context, "Python execution"):
return permission_error
sb = await get_booter(
context.context.context,
@@ -104,7 +90,7 @@ class LocalPythonTool(FunctionTool):
async def call(
self, context: ContextWrapper[AstrAgentContext], code: str, silent: bool = False
) -> ToolExecResult:
if permission_error := _check_admin_permission(context):
if permission_error := check_admin_permission(context, "Python execution"):
return permission_error
sb = get_local_booter()
try:
+2 -16
View File
@@ -7,21 +7,7 @@ from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from ..computer_client import get_booter, get_local_booter
def _check_admin_permission(context: ContextWrapper[AstrAgentContext]) -> str | None:
cfg = context.context.context.get_config(
umo=context.context.event.unified_msg_origin
)
provider_settings = cfg.get("provider_settings", {})
require_admin = provider_settings.get("computer_use_require_admin", True)
if require_admin and context.context.event.role != "admin":
return (
"error: Permission denied. Shell execution is only allowed for admin users. "
"Tell user to set admins in `AstrBot WebUI -> Config -> General Config` by adding their user ID to the admins list if they need this feature."
f"User's ID is: {context.context.event.get_sender_id()}. User's ID can be found by using /sid command."
)
return None
from .permissions import check_admin_permission
@dataclass
@@ -61,7 +47,7 @@ class ExecuteShellTool(FunctionTool):
background: bool = False,
env: dict = {},
) -> ToolExecResult:
if permission_error := _check_admin_permission(context):
if permission_error := check_admin_permission(context, "Shell execution"):
return permission_error
if self.is_local:
+14 -1
View File
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.18.1"
VERSION = "4.18.2"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -100,6 +100,7 @@ DEFAULT_CONFIG = {
"dequeue_context_length": 1,
"streaming_response": False,
"show_tool_use_status": False,
"show_tool_call_result": False,
"sanitize_context_by_modalities": False,
"max_quoted_fallback_images": 20,
"quoted_message_parser": {
@@ -2306,6 +2307,9 @@ CONFIG_METADATA_2 = {
"show_tool_use_status": {
"type": "bool",
},
"show_tool_call_result": {
"type": "bool",
},
"unsupported_streaming_strategy": {
"type": "string",
},
@@ -2994,6 +2998,15 @@ CONFIG_METADATA_3 = {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.show_tool_call_result": {
"description": "输出函数调用返回结果",
"type": "bool",
"hint": "仅在输出函数调用状态启用时生效,展示结果前 70 个字符。",
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.show_tool_use_status": True,
},
},
"provider_settings.sanitize_context_by_modalities": {
"description": "按模型能力清理历史上下文",
"type": "bool",
+28 -3
View File
@@ -720,13 +720,38 @@ class File(BaseMessageComponent):
if allow_return_url and self.url:
return self.url
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
if self.file_:
path = self.file_
if path.startswith("file://"):
# 处理 file:// (2 slashes) 或 file:/// (3 slashes)
# pathlib.as_uri() 通常生成 file:///
path = path[7:]
# 兼容 Windows: file:///C:/path -> /C:/path -> C:/path
if (
os.name == "nt"
and len(path) > 2
and path[0] == "/"
and path[2] == ":"
):
path = path[1:]
if os.path.exists(path):
return os.path.abspath(path)
if self.url:
await self._download_file()
if self.file_:
return os.path.abspath(self.file_)
path = self.file_
if path.startswith("file://"):
path = path[7:]
if (
os.name == "nt"
and len(path) > 2
and path[0] == "/"
and path[2] == ":"
):
path = path[1:]
return os.path.abspath(path)
return ""
+66 -21
View File
@@ -1,30 +1,60 @@
"""Pipeline package exports.
This module intentionally avoids eager imports of all pipeline stage modules to
prevent import-time cycles. Stage classes remain available via lazy attribute
resolution for backward compatibility.
"""
from __future__ import annotations
from importlib import import_module
from typing import Any
from astrbot.core.message.message_event_result import (
EventResultType,
MessageEventResult,
)
from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .rate_limit_check.stage import RateLimitStage
from .respond.stage import RespondStage
from .result_decorate.stage import ResultDecorateStage
from .session_status_check.stage import SessionStatusCheckStage
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .stage_order import STAGES_ORDER
# 管道阶段顺序
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"SessionStatusCheckStage", # 检查会话是否整体启用
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
"RespondStage", # 发送消息
]
_LAZY_EXPORTS = {
"ContentSafetyCheckStage": (
"astrbot.core.pipeline.content_safety_check.stage",
"ContentSafetyCheckStage",
),
"PreProcessStage": (
"astrbot.core.pipeline.preprocess_stage.stage",
"PreProcessStage",
),
"ProcessStage": (
"astrbot.core.pipeline.process_stage.stage",
"ProcessStage",
),
"RateLimitStage": (
"astrbot.core.pipeline.rate_limit_check.stage",
"RateLimitStage",
),
"RespondStage": (
"astrbot.core.pipeline.respond.stage",
"RespondStage",
),
"ResultDecorateStage": (
"astrbot.core.pipeline.result_decorate.stage",
"ResultDecorateStage",
),
"SessionStatusCheckStage": (
"astrbot.core.pipeline.session_status_check.stage",
"SessionStatusCheckStage",
),
"WakingCheckStage": (
"astrbot.core.pipeline.waking_check.stage",
"WakingCheckStage",
),
"WhitelistCheckStage": (
"astrbot.core.pipeline.whitelist_check.stage",
"WhitelistCheckStage",
),
}
__all__ = [
"ContentSafetyCheckStage",
@@ -36,6 +66,21 @@ __all__ = [
"RespondStage",
"ResultDecorateStage",
"SessionStatusCheckStage",
"STAGES_ORDER",
"WakingCheckStage",
"WhitelistCheckStage",
]
def __getattr__(name: str) -> Any:
if name not in _LAZY_EXPORTS:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
module_path, attr_name = _LAZY_EXPORTS[name]
module = import_module(module_path)
value = getattr(module, attr_name)
globals()[name] = value
return value
def __dir__() -> list[str]:
return sorted(set(globals()) | set(__all__))
+52
View File
@@ -0,0 +1,52 @@
"""Pipeline bootstrap utilities."""
from importlib import import_module
from .stage import registered_stages
_BUILTIN_STAGE_MODULES = (
"astrbot.core.pipeline.waking_check.stage",
"astrbot.core.pipeline.whitelist_check.stage",
"astrbot.core.pipeline.session_status_check.stage",
"astrbot.core.pipeline.rate_limit_check.stage",
"astrbot.core.pipeline.content_safety_check.stage",
"astrbot.core.pipeline.preprocess_stage.stage",
"astrbot.core.pipeline.process_stage.stage",
"astrbot.core.pipeline.result_decorate.stage",
"astrbot.core.pipeline.respond.stage",
)
_EXPECTED_STAGE_NAMES = {
"WakingCheckStage",
"WhitelistCheckStage",
"SessionStatusCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",
}
_builtin_stages_registered = False
def ensure_builtin_stages_registered() -> None:
"""Ensure built-in pipeline stages are imported and registered."""
global _builtin_stages_registered
if _builtin_stages_registered:
return
stage_names = {stage_cls.__name__ for stage_cls in registered_stages}
if _EXPECTED_STAGE_NAMES.issubset(stage_names):
_builtin_stages_registered = True
return
for module_path in _BUILTIN_STAGE_MODULES:
import_module(module_path)
_builtin_stages_registered = True
__all__ = ["ensure_builtin_stages_registered"]
+4 -2
View File
@@ -1,7 +1,9 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from astrbot.core.config import AstrBotConfig
from astrbot.core.star import PluginManager
from .context_utils import call_event_hook, call_handler
@@ -11,7 +13,7 @@ class PipelineContext:
"""上下文对象,包含管道执行所需的上下文信息"""
astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: PluginManager # 插件管理器对象
plugin_manager: Any # 插件管理器对象
astrbot_config_id: str
call_handler = call_handler
call_event_hook = call_event_hook
@@ -19,6 +19,7 @@ from astrbot.core.message.message_event_result import (
MessageEventResult,
ResultContentType,
)
from astrbot.core.pipeline.stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import (
LLMResponse,
@@ -30,7 +31,6 @@ from astrbot.core.utils.session_lock import session_lock_manager
from .....astr_agent_run_util import run_agent, run_live_agent
from ....context import PipelineContext, call_event_hook
from ...stage import Stage
class InternalAgentSubStage(Stage):
@@ -54,6 +54,7 @@ class InternalAgentSubStage(Stage):
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
self.show_tool_call_result: bool = settings.get("show_tool_call_result", False)
self.show_reasoning = settings.get("display_reasoning_text", False)
self.sanitize_context_by_modalities: bool = settings.get(
"sanitize_context_by_modalities",
@@ -240,6 +241,7 @@ class InternalAgentSubStage(Stage):
tts_provider,
self.max_step,
self.show_tool_use,
self.show_tool_call_result,
show_reasoning=self.show_reasoning,
),
),
@@ -269,6 +271,7 @@ class InternalAgentSubStage(Stage):
agent_runner,
self.max_step,
self.show_tool_use,
self.show_tool_call_result,
show_reasoning=self.show_reasoning,
),
),
@@ -297,6 +300,7 @@ class InternalAgentSubStage(Stage):
agent_runner,
self.max_step,
self.show_tool_use,
self.show_tool_call_result,
stream_to_general,
show_reasoning=self.show_reasoning,
):
@@ -8,6 +8,7 @@ from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
DashscopeAgentRunner,
)
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
MessageChain,
@@ -17,6 +18,7 @@ from astrbot.core.message.message_event_result import (
if TYPE_CHECKING:
from astrbot.core.agent.runners.base import BaseAgentRunner
from astrbot.core.pipeline.stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import (
ProviderRequest,
@@ -25,9 +27,7 @@ from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric
from .....astr_agent_context import AgentContextWrapper, AstrAgentContext
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
from ....context import PipelineContext, call_event_hook
from ...stage import Stage
AGENT_RUNNER_TYPE_KEY = {
"dify": "dify_agent_runner_provider_id",
+3 -1
View File
@@ -8,15 +8,17 @@ from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
)
from astrbot.core.utils.active_event_registry import active_event_registry
from . import STAGES_ORDER
from .bootstrap import ensure_builtin_stages_registered
from .context import PipelineContext
from .stage import registered_stages
from .stage_order import STAGES_ORDER
class PipelineScheduler:
"""管道调度器,负责调度各个阶段的执行"""
def __init__(self, context: PipelineContext) -> None:
ensure_builtin_stages_registered()
registered_stages.sort(
key=lambda x: STAGES_ORDER.index(x.__name__),
) # 按照顺序排序
+15
View File
@@ -0,0 +1,15 @@
"""Pipeline stage execution order."""
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"SessionStatusCheckStage", # 检查会话是否整体启用
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
"RespondStage", # 发送消息
]
__all__ = ["STAGES_ORDER"]
+33 -11
View File
@@ -52,9 +52,19 @@ class AstrMessageEvent(abc.ABC):
self.is_at_or_wake_command = False
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
self._extras: dict[str, Any] = {}
message_type = getattr(message_obj, "type", None)
if not isinstance(message_type, MessageType):
try:
message_type = MessageType(str(message_type))
except (ValueError, TypeError, AttributeError):
logger.warning(
f"Failed to convert message type {message_obj.type!r} to MessageType. "
f"Falling back to FRIEND_MESSAGE."
)
message_type = MessageType.FRIEND_MESSAGE
self.session = MessageSession(
platform_name=platform_meta.id,
message_type=message_obj.type,
message_type=message_type,
session_id=session_id,
)
# self.unified_msg_origin = str(self.session)
@@ -159,15 +169,18 @@ class AstrMessageEvent(abc.ABC):
除了文本消息外,其他消息类型会被转换为对应的占位符。如图片消息会被转换为 [图片]。
"""
return self._outline_chain(self.message_obj.message)
return self._outline_chain(getattr(self.message_obj, "message", None))
def get_messages(self) -> list[BaseMessageComponent]:
"""获取消息链。"""
return self.message_obj.message
return getattr(self.message_obj, "message", [])
def get_message_type(self) -> MessageType:
"""获取消息类型。"""
return self.message_obj.type
message_type = getattr(self.message_obj, "type", None)
if isinstance(message_type, MessageType):
return message_type
return self.session.message_type
def get_session_id(self) -> str:
"""获取会话id。"""
@@ -175,21 +188,30 @@ class AstrMessageEvent(abc.ABC):
def get_group_id(self) -> str:
"""获取群组id。如果不是群组消息,返回空字符串。"""
return self.message_obj.group_id
return getattr(self.message_obj, "group_id", "")
def get_self_id(self) -> str:
"""获取机器人自身的id。"""
return self.message_obj.self_id
return getattr(self.message_obj, "self_id", "")
def get_sender_id(self) -> str:
"""获取消息发送者的id。"""
return self.message_obj.sender.user_id
sender = getattr(self.message_obj, "sender", None)
if sender and isinstance(getattr(sender, "user_id", None), str):
return sender.user_id
return ""
def get_sender_name(self) -> str:
"""获取消息发送者的名称。(可能会返回空字符串)"""
if isinstance(self.message_obj.sender.nickname, str):
return self.message_obj.sender.nickname
return ""
sender = getattr(self.message_obj, "sender", None)
if not sender:
return ""
nickname = getattr(sender, "nickname", None)
if nickname is None:
return ""
if isinstance(nickname, str):
return nickname
return str(nickname)
def set_extra(self, key, value) -> None:
"""设置额外的信息。"""
@@ -208,7 +230,7 @@ class AstrMessageEvent(abc.ABC):
def is_private_chat(self) -> bool:
"""是否是私聊。"""
return self.message_obj.type.value == (MessageType.FRIEND_MESSAGE).value
return self.get_message_type() == MessageType.FRIEND_MESSAGE
def is_wake_up(self) -> bool:
"""是否是唤醒机器人的事件。"""
+4
View File
@@ -180,6 +180,10 @@ class PlatformManager:
from .sources.line.line_adapter import (
LinePlatformAdapter, # noqa: F401
)
case "heihe":
from .sources.heihe.heihe_adapter import (
HeihePlatformAdapter, # noqa: F401
)
except (ImportError, ModuleNotFoundError) as e:
logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
@@ -45,6 +45,19 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
if isinstance(segment, File):
# For File segments, we need to handle the file differently
d = await segment.to_dict()
file_val = d.get("data", {}).get("file", "")
if file_val:
import pathlib
try:
# 使用 pathlib 处理路径,能更好地处理 Windows/Linux 差异
path_obj = pathlib.Path(file_val)
# 如果是绝对路径且不包含协议头 (://),则转换为标准的 file: URI
if path_obj.is_absolute() and "://" not in file_val:
d["data"]["file"] = path_obj.as_uri()
except Exception:
# 如果不是合法路径(例如已经是特定的特殊字符串),则跳过转换
pass
return d
if isinstance(segment, Video):
d = await segment.to_dict()
@@ -1,4 +1,5 @@
import asyncio
import inspect
import itertools
import logging
import time
@@ -436,7 +437,42 @@ class AiocqhttpAdapter(Platform):
return coro
async def terminate(self) -> None:
self.shutdown_event.set()
if hasattr(self, "shutdown_event"):
self.shutdown_event.set()
await self._close_reverse_ws_connections()
async def _close_reverse_ws_connections(self) -> None:
api_clients = getattr(self.bot, "_wsr_api_clients", None)
event_clients = getattr(self.bot, "_wsr_event_clients", None)
ws_clients: set[Any] = set()
if isinstance(api_clients, dict):
ws_clients.update(api_clients.values())
if isinstance(event_clients, set):
ws_clients.update(event_clients)
close_tasks: list[Awaitable[Any]] = []
for ws in ws_clients:
close_func = getattr(ws, "close", None)
if not callable(close_func):
continue
try:
close_result = close_func(code=1000, reason="Adapter shutdown")
except TypeError:
close_result = close_func()
except Exception:
continue
if inspect.isawaitable(close_result):
close_tasks.append(close_result)
if close_tasks:
await asyncio.gather(*close_tasks, return_exceptions=True)
if isinstance(api_clients, dict):
api_clients.clear()
if isinstance(event_clients, set):
event_clients.clear()
async def shutdown_trigger_placeholder(self) -> None:
await self.shutdown_event.wait()
@@ -0,0 +1,523 @@
import asyncio
import json
import time
import uuid
from collections.abc import Mapping
from typing import Any, cast
import websockets
from websockets.asyncio.client import ClientConnection, connect
from astrbot.api import logger
from astrbot.api.event import MessageChain
from astrbot.api.message_components import At, Image, Plain
from astrbot.api.platform import (
AstrBotMessage,
Group,
MessageMember,
MessageType,
Platform,
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from .heihe_event import HeiheMessageEvent
HEIHE_CONFIG_METADATA = {
"heihe_ws_url": {
"description": "Heihe WebSocket URL",
"type": "string",
"hint": "一般情况下不需要修改。",
},
"heihe_token": {
"description": "Bot Token",
"type": "string",
"hint": "黑盒 Bot Token。可填写纯 Token(推荐),适配器会自动添加 Authorization 头。",
},
"heihe_origin": {
"description": "WebSocket Origin",
"type": "string",
"hint": "用于 WebSocket 握手的 Origin 头,默认 https://chat.xiaoheihe.cn。",
},
"heihe_bot_id": {
"description": "Bot ID",
"type": "string",
"hint": "可选。为空时会根据收到的消息自动识别机器人 ID。",
},
"heihe_auto_reconnect": {
"description": "Auto Reconnect",
"type": "bool",
"hint": "WebSocket 断开后是否自动重连。",
},
"heihe_heartbeat_interval": {
"description": "Heartbeat Interval (seconds)",
"type": "int",
"hint": "发送心跳包间隔。<=0 表示关闭主动心跳。",
},
"heihe_reconnect_delay": {
"description": "Reconnect Delay (seconds)",
"type": "int",
"hint": "WebSocket 断开后的重连等待时间。",
},
"heihe_ignore_self_message": {
"description": "Ignore Self Message",
"type": "bool",
"hint": "是否忽略机器人自身发送的消息。",
},
}
HEIHE_I18N_RESOURCES = {
"zh-CN": {
"heihe_ws_url": {
"description": "黑盒 WebSocket 地址",
"hint": "一般情况下不需要修改。",
},
"heihe_token": {
"description": "机器人 Token",
"hint": "建议填写纯 Token,适配器会自动补齐 Authorization 头。",
},
"heihe_origin": {
"description": "WebSocket Origin",
"hint": "用于握手的 Origin 头,默认 https://chat.xiaoheihe.cn。",
},
"heihe_bot_id": {
"description": "机器人 ID",
"hint": "可选。为空时会根据收到的消息自动识别机器人 ID。",
},
"heihe_auto_reconnect": {
"description": "自动重连",
"hint": "WebSocket 断开后是否自动重连。",
},
"heihe_heartbeat_interval": {
"description": "心跳间隔(秒)",
"hint": "设置 <=0 将关闭主动心跳。",
},
"heihe_reconnect_delay": {
"description": "重连间隔(秒)",
"hint": "WebSocket 断开后的重连等待时间。",
},
"heihe_ignore_self_message": {
"description": "忽略机器人自身消息",
"hint": "开启后,机器人自己发出的消息将不会触发事件处理。",
},
},
"en-US": {
"heihe_ws_url": {
"description": "Heihe WebSocket URL",
"hint": "Usually no need to change this.",
},
"heihe_token": {
"description": "Bot Token",
"hint": "Plain token is recommended. Authorization header is added automatically.",
},
"heihe_origin": {
"description": "WebSocket Origin",
"hint": "Origin header used in websocket handshake. Default: https://chat.xiaoheihe.cn.",
},
"heihe_bot_id": {
"description": "Bot ID",
"hint": "Optional. If empty, the adapter will infer it from incoming messages.",
},
"heihe_auto_reconnect": {
"description": "Auto Reconnect",
"hint": "Whether to reconnect automatically after websocket disconnects.",
},
"heihe_heartbeat_interval": {
"description": "Heartbeat Interval (seconds)",
"hint": "Set <=0 to disable active heartbeat.",
},
"heihe_reconnect_delay": {
"description": "Reconnect Delay (seconds)",
"hint": "Delay before reconnecting after disconnect.",
},
"heihe_ignore_self_message": {
"description": "Ignore Self Message",
"hint": "When enabled, messages sent by the bot itself will be ignored.",
},
},
}
@register_platform_adapter(
"heihe",
"黑盒机器人(WebSocket)适配器",
support_streaming_message=False,
default_config_tmpl={
"id": "heihe",
"type": "heihe",
"enable": False,
"heihe_ws_url": "wss://chat.xiaoheihe.cn/chatroom/ws/connect",
"heihe_token": "",
"heihe_origin": "https://chat.xiaoheihe.cn",
"heihe_bot_id": "",
"heihe_auto_reconnect": True,
"heihe_heartbeat_interval": 20,
"heihe_reconnect_delay": 5,
"heihe_ignore_self_message": True,
},
config_metadata=HEIHE_CONFIG_METADATA,
i18n_resources=HEIHE_I18N_RESOURCES,
)
class HeihePlatformAdapter(Platform):
def __init__(
self,
platform_config: dict,
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(platform_config, event_queue)
self.settings = platform_settings
self.ws_url = str(platform_config.get("heihe_ws_url", "")).strip()
self.token = str(platform_config.get("heihe_token", "")).strip()
self.origin = str(
platform_config.get("heihe_origin", "https://chat.xiaoheihe.cn"),
).strip()
self.bot_id = str(platform_config.get("heihe_bot_id", "")).strip()
self.auto_reconnect = bool(platform_config.get("heihe_auto_reconnect", True))
self.heartbeat_interval = int(
cast(int, platform_config.get("heihe_heartbeat_interval", 20)),
)
self.reconnect_delay = int(
cast(int, platform_config.get("heihe_reconnect_delay", 5)),
)
self.ignore_self_message = bool(
platform_config.get("heihe_ignore_self_message", True),
)
if not self.ws_url:
raise ValueError("heihe_ws_url 不能为空。")
self.metadata = PlatformMetadata(
name="heihe",
description="黑盒机器人(WebSocket)适配器",
id=cast(str, self.config.get("id", "heihe")),
support_streaming_message=False,
)
self.ws: ClientConnection | None = None
self.running = False
self.heartbeat_task: asyncio.Task | None = None
self._last_heartbeat_ts = 0
def meta(self) -> PlatformMetadata:
return self.metadata
async def run(self) -> None:
self.running = True
while self.running:
try:
await self._connect_and_loop()
except websockets.exceptions.ConnectionClosed as e:
logger.warning("[heihe] websocket disconnected: %s", e)
except Exception as e:
logger.error("[heihe] websocket failed: %s", e)
if not self.running:
break
if not self.auto_reconnect:
break
await asyncio.sleep(max(1, self.reconnect_delay))
async def terminate(self) -> None:
self.running = False
if self.heartbeat_task:
self.heartbeat_task.cancel()
try:
await self.heartbeat_task
except asyncio.CancelledError:
pass
if self.ws:
try:
await self.ws.close()
except Exception:
pass
self.ws = None
async def send_by_session(
self,
session: MessageSesion,
message_chain: MessageChain,
) -> None:
await HeiheMessageEvent.send_with_adapter(
self,
message_chain,
session.session_id,
)
await super().send_by_session(session, message_chain)
async def send_payload(self, payload: Mapping[str, Any]) -> None:
if not self.ws:
raise RuntimeError("[heihe] websocket not connected")
if self.ws.close_code is not None:
raise RuntimeError("[heihe] websocket already closed")
body = dict(payload)
body.setdefault("timestamp", int(time.time()))
await self.ws.send(json.dumps(body, ensure_ascii=False))
async def _connect_and_loop(self) -> None:
logger.info("[heihe] connecting websocket: %s", self.ws_url)
headers: dict[str, str] = {}
if self.token:
headers["Authorization"] = f"Bearer {self.token}"
headers["X-Token"] = self.token
websocket = await connect(
self.ws_url,
additional_headers=headers,
max_size=10 * 1024 * 1024,
ping_interval=None,
)
self.ws = websocket
logger.info("[heihe] websocket connected")
if self.heartbeat_interval > 0:
self.heartbeat_task = asyncio.create_task(self._heartbeat_loop())
try:
async for raw in websocket:
await self._handle_incoming(raw)
finally:
if self.heartbeat_task:
self.heartbeat_task.cancel()
try:
await self.heartbeat_task
except asyncio.CancelledError:
pass
self.heartbeat_task = None
if self.ws:
try:
await self.ws.close()
except Exception:
pass
self.ws = None
async def _heartbeat_loop(self) -> None:
try:
while self.running and self.ws and self.ws.close_code is None:
await asyncio.sleep(self.heartbeat_interval)
self._last_heartbeat_ts = int(time.time())
await self.send_payload(
{
"type": "ping",
"ping": self._last_heartbeat_ts,
},
)
except asyncio.CancelledError:
pass
except Exception as e:
logger.warning("[heihe] heartbeat error: %s", e)
async def _handle_incoming(self, raw: Any) -> None:
if isinstance(raw, bytes):
try:
raw = raw.decode("utf-8")
except UnicodeDecodeError:
return
if not isinstance(raw, str):
return
try:
data = json.loads(raw)
except json.JSONDecodeError:
logger.debug("[heihe] skip non-json frame: %s", raw[:200])
return
if isinstance(data, list):
for item in data:
if isinstance(item, dict):
await self._handle_packet(item)
return
if isinstance(data, dict):
await self._handle_packet(data)
async def _handle_packet(self, packet: dict[str, Any]) -> None:
if "ping" in packet:
await self.send_payload({"type": "pong", "pong": packet.get("ping")})
return
if str(packet.get("type", "")).lower() == "ping":
await self.send_payload({"type": "pong", "pong": packet.get("ping")})
return
event_type = str(
packet.get("event")
or packet.get("event_type")
or packet.get("type")
or packet.get("topic")
or "",
).lower()
payload_obj = packet.get("data")
payload = payload_obj if isinstance(payload_obj, dict) else packet
if not self._is_message_event(event_type, payload):
return
abm = self._convert_message(payload, packet)
if not abm:
return
await self.handle_msg(abm)
@staticmethod
def _is_message_event(event_type: str, payload: Mapping[str, Any]) -> bool:
if "message" in event_type:
return True
keys = payload.keys()
return "content" in keys or "text" in keys or "message" in keys
def _convert_message(
self,
payload: Mapping[str, Any],
raw_packet: Mapping[str, Any],
) -> AstrBotMessage | None:
message_obj = payload.get("message")
message = message_obj if isinstance(message_obj, Mapping) else payload
sender_data_obj = (
payload.get("sender") or payload.get("author") or payload.get("user") or {}
)
sender_data = sender_data_obj if isinstance(sender_data_obj, Mapping) else {}
sender_id = str(
sender_data.get("id")
or sender_data.get("user_id")
or payload.get("sender_id")
or payload.get("user_id")
or "",
).strip()
sender_name = str(
sender_data.get("nickname")
or sender_data.get("name")
or sender_data.get("username")
or sender_id
or "unknown",
)
self_id = str(
payload.get("self_id")
or payload.get("bot_id")
or self.bot_id
or self.meta().id,
)
if self.ignore_self_message and sender_id and self_id and sender_id == self_id:
return None
channel_id = str(
payload.get("channel_id")
or payload.get("room_id")
or payload.get("chat_id")
or payload.get("session_id")
or "",
).strip()
guild_id = str(
payload.get("guild_id")
or payload.get("server_id")
or payload.get("group_id")
or "",
).strip()
is_private = bool(payload.get("is_private", False))
if str(payload.get("message_type", "")).lower() in {"private", "friend", "dm"}:
is_private = True
session_id = channel_id or sender_id
if not session_id:
return None
text = str(message.get("content") or message.get("text") or "").strip()
components = self._build_components(text, payload)
if not components:
return None
abm = AstrBotMessage()
abm.self_id = self_id
abm.message_id = str(
message.get("id")
or message.get("message_id")
or payload.get("message_id")
or payload.get("msg_id")
or uuid.uuid4().hex
)
timestamp_raw = (
payload.get("timestamp")
or payload.get("time")
or message.get("timestamp")
or message.get("time")
)
abm.timestamp = int(time.time())
if isinstance(timestamp_raw, int):
abm.timestamp = (
timestamp_raw // 1000
if timestamp_raw > 1_000_000_000_000
else timestamp_raw
)
if not is_private and (channel_id or guild_id):
abm.type = MessageType.GROUP_MESSAGE
abm.group = Group(
group_id=guild_id or channel_id, group_name=guild_id or ""
)
else:
abm.type = MessageType.FRIEND_MESSAGE
abm.session_id = session_id
abm.sender = MessageMember(user_id=sender_id or "unknown", nickname=sender_name)
abm.message = components
abm.message_str = self._build_message_str(components)
abm.raw_message = dict(raw_packet)
return abm
@staticmethod
def _build_components(text: str, payload: Mapping[str, Any]) -> list:
components: list = []
if text:
components.append(Plain(text=text))
mentions_obj = payload.get("mentions")
if isinstance(mentions_obj, list):
for mention in mentions_obj:
if not isinstance(mention, Mapping):
continue
user_id = str(mention.get("user_id") or mention.get("id") or "").strip()
name = str(mention.get("name") or mention.get("nickname") or "").strip()
if user_id or name:
components.append(At(qq=user_id, name=name))
attachments_obj = payload.get("attachments")
if isinstance(attachments_obj, list):
for item in attachments_obj:
if not isinstance(item, Mapping):
continue
url = str(item.get("url") or item.get("file_url") or "").strip()
if not url:
continue
kind = str(item.get("type") or item.get("media_type") or "").lower()
if "image" in kind:
components.append(Image.fromURL(url))
else:
components.append(Plain(text=f"[{kind or 'file'}] {url}"))
return components
@staticmethod
def _build_message_str(components: list) -> str:
parts: list[str] = []
for comp in components:
if isinstance(comp, Plain):
parts.append(comp.text)
elif isinstance(comp, At):
parts.append(f"@{comp.name or comp.qq}")
elif isinstance(comp, Image):
parts.append("[image]")
else:
parts.append(f"[{comp.type}]")
return " ".join(i for i in parts if i).strip()
async def handle_msg(self, abm: AstrBotMessage) -> None:
event = HeiheMessageEvent(
message_str=abm.message_str,
message_obj=abm,
platform_meta=self.meta(),
session_id=abm.session_id,
adapter=self,
)
self.commit_event(event)
@@ -0,0 +1,108 @@
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import At, Image, Plain, Reply
if TYPE_CHECKING:
from .heihe_adapter import HeihePlatformAdapter
class HeiheMessageEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj,
platform_meta,
session_id: str,
adapter: "HeihePlatformAdapter",
) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
self.adapter = adapter
@classmethod
async def send_with_adapter(
cls,
adapter: "HeihePlatformAdapter",
message: MessageChain,
session_id: str,
) -> None:
payload = await cls._build_send_payload(message, session_id)
await adapter.send_payload(payload)
async def send(self, message: MessageChain) -> None:
await self.send_with_adapter(self.adapter, message, self.session_id)
await super().send(message)
async def send_streaming(
self,
generator: AsyncGenerator,
use_fallback: bool = False,
):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return None
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
@classmethod
async def _build_send_payload(
cls,
message: MessageChain,
session_id: str,
) -> dict[str, Any]:
text_parts: list[str] = []
segments: list[dict[str, Any]] = []
for component in message.chain:
if isinstance(component, Plain):
if component.text:
text_parts.append(component.text)
segments.append({"type": "text", "text": component.text})
continue
if isinstance(component, At):
at_name = str(component.name or component.qq or "").strip()
if at_name:
text_parts.append(f"@{at_name}")
segments.append(
{
"type": "mention",
"user_id": str(component.qq or ""),
"name": at_name,
},
)
continue
if isinstance(component, Reply):
if component.id:
segments.append({"type": "reply", "message_id": component.id})
continue
if isinstance(component, Image):
image_url = ""
try:
image_url = await component.register_to_file_service()
except Exception as e:
logger.debug("[heihe] image upload fallback failed: %s", e)
if image_url:
segments.append({"type": "image", "url": image_url})
text_parts.append("[image]")
continue
content = "".join(text_parts).strip()
payload: dict[str, Any] = {
"action": "send_message",
"channel_id": session_id,
"content": content,
"segments": segments,
}
return payload
@@ -1,4 +1,5 @@
import asyncio
import os
import re
import sys
import uuid
@@ -25,6 +26,9 @@ from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.utils.media_utils import convert_audio_to_wav
from .tg_event import TelegramPlatformEvent
@@ -375,8 +379,19 @@ class TelegramPlatformAdapter(Platform):
elif update.message.voice:
file = await update.message.voice.get_file()
file_basename = os.path.basename(file.file_path)
temp_dir = get_astrbot_temp_path()
temp_path = os.path.join(temp_dir, file_basename)
temp_path = await download_image_by_url(file.file_path, path=temp_path)
path_wav = os.path.join(
temp_dir,
f"{file_basename}.wav",
)
path_wav = await convert_audio_to_wav(temp_path, path_wav)
message.message = [
Comp.Record(file=file.file_path, url=file.file_path),
Comp.Record(file=path_wav, url=path_wav),
]
elif update.message.photo:
@@ -48,6 +48,9 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
result = await self.client.models.embed_content(
model=self.model,
contents=text,
config=types.EmbedContentConfig(
output_dimensionality=self.get_dim(),
),
)
assert result.embeddings is not None
assert result.embeddings[0].values is not None
@@ -61,6 +64,9 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
result = await self.client.models.embed_content(
model=self.model,
contents=cast(types.ContentListUnion, text),
config=types.EmbedContentConfig(
output_dimensionality=self.get_dim(),
),
)
assert result.embeddings is not None
@@ -36,12 +36,20 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
async def get_embedding(self, text: str) -> list[float]:
"""获取文本的嵌入"""
embedding = await self.client.embeddings.create(input=text, model=self.model)
embedding = await self.client.embeddings.create(
input=text,
model=self.model,
dimensions=self.get_dim(),
)
return embedding.data[0].embedding
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入"""
embeddings = await self.client.embeddings.create(input=text, model=self.model)
embeddings = await self.client.embeddings.create(
input=text,
model=self.model,
dimensions=self.get_dim(),
)
return [item.embedding for item in embeddings.data]
def get_dim(self) -> int:
+13 -62
View File
@@ -1,68 +1,19 @@
from astrbot.core import html_renderer
# 兼容导出: Provider 从 provider 模块重新导出
from astrbot.core.provider import Provider
from astrbot.core.star.star_tools import StarTools
from astrbot.core.utils.command_parser import CommandParserMixin
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
from .base import Star
from .context import Context
from .star import StarMetadata, star_map, star_registry
from .star_manager import PluginManager
from .star_tools import StarTools
class Star(CommandParserMixin, PluginKVStoreMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
author: str
name: str
def __init__(self, context: Context, config: dict | None = None) -> None:
StarTools.initialize(context)
self.context = context
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not star_map.get(cls.__module__):
metadata = StarMetadata(
star_cls_type=cls,
module_path=cls.__module__,
)
star_map[cls.__module__] = metadata
star_registry.append(metadata)
else:
star_map[cls.__module__].star_cls_type = cls
star_map[cls.__module__].module_path = cls.__module__
async def text_to_image(self, text: str, return_url=True) -> str:
"""将文本转换为图片"""
return await html_renderer.render_t2i(
text,
return_url=return_url,
template_name=self.context._config.get("t2i_active_template"),
)
async def html_render(
self,
tmpl: str,
data: dict,
return_url=True,
options: dict | None = None,
) -> str:
"""渲染 HTML"""
return await html_renderer.render_custom_template(
tmpl,
data,
return_url=return_url,
options=options,
)
async def initialize(self) -> None:
"""当插件被激活时会调用这个方法"""
async def terminate(self) -> None:
"""当插件被禁用、重载插件时会调用这个方法"""
def __del__(self) -> None:
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
__all__ = ["Context", "PluginManager", "Provider", "Star", "StarMetadata", "StarTools"]
__all__ = [
"Context",
"PluginManager",
"Provider",
"Star",
"StarMetadata",
"StarTools",
"star_map",
"star_registry",
]
+87
View File
@@ -0,0 +1,87 @@
from __future__ import annotations
import logging
from typing import Any, Protocol
from astrbot.core import html_renderer
from astrbot.core.utils.command_parser import CommandParserMixin
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
from .star import StarMetadata, star_map, star_registry
logger = logging.getLogger("astrbot")
class Star(CommandParserMixin, PluginKVStoreMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
author: str
name: str
class _ContextLike(Protocol):
def get_config(self, umo: str | None = None) -> Any: ...
def __init__(self, context: _ContextLike, config: dict | None = None) -> None:
self.context = context
def _get_context_config(self) -> Any:
get_config = getattr(self.context, "get_config", None)
if callable(get_config):
try:
return get_config()
except Exception as e:
logger.debug(f"get_config() failed: {e}")
return None
return getattr(self.context, "_config", None)
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not star_map.get(cls.__module__):
metadata = StarMetadata(
star_cls_type=cls,
module_path=cls.__module__,
)
star_map[cls.__module__] = metadata
star_registry.append(metadata)
else:
star_map[cls.__module__].star_cls_type = cls
star_map[cls.__module__].module_path = cls.__module__
async def text_to_image(self, text: str, return_url=True) -> str:
"""将文本转换为图片"""
config_obj = self._get_context_config()
template_name = None
if hasattr(config_obj, "get"):
try:
template_name = config_obj.get("t2i_active_template")
except Exception:
template_name = None
return await html_renderer.render_t2i(
text,
return_url=return_url,
template_name=template_name,
)
async def html_render(
self,
tmpl: str,
data: dict,
return_url=True,
options: dict | None = None,
) -> str:
"""渲染 HTML"""
return await html_renderer.render_custom_template(
tmpl,
data,
return_url=return_url,
options=options,
)
async def initialize(self) -> None:
"""当插件被激活时会调用这个方法"""
async def terminate(self) -> None:
"""当插件被禁用、重载插件时会调用这个方法"""
def __del__(self) -> None:
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
+16 -4
View File
@@ -1,7 +1,9 @@
from __future__ import annotations
import logging
from asyncio import Queue
from collections.abc import Awaitable, Callable
from typing import Any
from typing import TYPE_CHECKING, Any, Protocol
from deprecated import deprecated
@@ -12,14 +14,12 @@ from astrbot.core.agent.tool import ToolSet
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.cron.manager import CronJobManager
from astrbot.core.db import BaseDatabase
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.platform import Platform
from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType
from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager
@@ -45,6 +45,15 @@ from .star_handler import EventType, StarHandlerMetadata, star_handlers_registry
logger = logging.getLogger("astrbot")
if TYPE_CHECKING:
from astrbot.core.cron.manager import CronJobManager
else:
CronJobManager = Any
class PlatformManagerProtocol(Protocol):
platform_insts: list[Platform]
class Context:
"""暴露给插件的接口上下文。"""
@@ -61,7 +70,7 @@ class Context:
config: AstrBotConfig,
db: BaseDatabase,
provider_manager: ProviderManager,
platform_manager: PlatformManager,
platform_manager: PlatformManagerProtocol,
conversation_manager: ConversationManager,
message_history_manager: PlatformMessageHistoryManager,
persona_manager: PersonaManager,
@@ -448,6 +457,9 @@ class Context:
if platform.meta().id == session.platform_name:
await platform.send_by_session(session, message_chain)
return True
logger.warning(
f"cannot find platform for session {str(session)}, message not sent"
)
return False
def add_llm_tools(self, *tools: FunctionTool) -> None:
+1 -1
View File
@@ -1,6 +1,6 @@
import warnings
from astrbot.core.star import StarMetadata, star_map
from astrbot.core.star.star import StarMetadata, star_map
_warned_register_star = False
+4 -5
View File
@@ -11,7 +11,6 @@ from astrbot.core.agent.agent import Agent
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
@@ -617,7 +616,7 @@ class RegisteringAgent:
kwargs["registering_agent"] = self
return register_llm_tool(*args, **kwargs)
def __init__(self, agent: Agent[AstrAgentContext]) -> None:
def __init__(self, agent: Agent[Any]) -> None:
self._agent = agent
@@ -625,7 +624,7 @@ def register_agent(
name: str,
instruction: str,
tools: list[str | FunctionTool] | None = None,
run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
run_hooks: BaseAgentRunHooks[Any] | None = None,
):
"""注册一个 Agent
@@ -639,12 +638,12 @@ def register_agent(
tools_ = tools or []
def decorator(awaitable: Callable[..., Awaitable[Any]]):
AstrAgent = Agent[AstrAgentContext]
AstrAgent = Agent[Any]
agent = AstrAgent(
name=name,
instructions=instruction,
tools=tools_,
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
run_hooks=run_hooks or BaseAgentRunHooks[Any](),
)
handoff_tool = HandoffTool(agent=agent)
handoff_tool.handler = awaitable
+3
View File
@@ -49,10 +49,13 @@ class PluginVersionIncompatibleError(Exception):
class PluginManager:
def __init__(self, context: Context, config: AstrBotConfig) -> None:
from .star_tools import StarTools
self.updator = PluginUpdator()
self.context = context
self.context._star_manager = self # type: ignore
StarTools.initialize(context)
self.config = config
self.plugin_store_path = get_astrbot_plugin_path()
+60
View File
@@ -0,0 +1,60 @@
## What's Changed
### 新增
- 新增 Agent 会话停止能力,并优化 stop 请求处理流程,支持 /stop 指令终止 Agent 运行并尽量不丢失已运行输出的结果。 ([#5380](https://github.com/AstrBotDevs/AstrBot/issues/5380))。
- 新增 SubAgent 交接场景下的 computer-use 工具支持 ([#5399](https://github.com/AstrBotDevs/AstrBot/issues/5399))。
- 新增 Agent 执行过程中展示工具调用结果的能力,提升执行过程可观测性 ([#5388](https://github.com/AstrBotDevs/AstrBot/issues/5388))。
- 新增插件加载/卸载 Hook,扩展插件生命周期能力 ([#5331](https://github.com/AstrBotDevs/AstrBot/issues/5331))。
- 新增插件加载失败后的热重载能力,提升插件开发与恢复效率 ([#5334](https://github.com/AstrBotDevs/AstrBot/issues/5334))。
- 新增 SubAgent 图片 URL/本地路径输入支持 ([#5348](https://github.com/AstrBotDevs/AstrBot/issues/5348))。
- 新增 Dashboard 发布跳转基础 URL 可配置项 ([#5330](https://github.com/AstrBotDevs/AstrBot/issues/5330))。
### 修复
- 修复 Tavily 请求的硬编码 6 秒超时。
- 修复 OneBot v11 适配器关闭之后仍然在连接的问题([#5412](https://github.com/AstrBotDevs/AstrBot/issues/5412))。
- 修复上下文会话中平台缺失时的日志处理,补充 warning 并改进排查信息。
- 修复 embedding 维度未透传到 provider API 的问题 ([#5411](https://github.com/AstrBotDevs/AstrBot/issues/5411))。
- 修复 File 组件处理逻辑并增强 OneBot 驱动层路径兼容性 ([#5391](https://github.com/AstrBotDevs/AstrBot/issues/5391))。
- 修复 sandbox 文件传输工具缺少管理员权限校验的问题 ([#5402](https://github.com/AstrBotDevs/AstrBot/issues/5402))。
- 修复 pipeline 与 `from ... import *` 引发的循环依赖问题 ([#5353](https://github.com/AstrBotDevs/AstrBot/issues/5353))。
- 修复配置文件存在 UTF-8 BOM 时的解析问题 ([#5376](https://github.com/AstrBotDevs/AstrBot/issues/5376))。
- 修复 ChatUI 复制回滚路径缺失与错误提示不清晰的问题 ([#5352](https://github.com/AstrBotDevs/AstrBot/issues/5352))。
- 修复保留插件目录处理逻辑,避免插件目录行为异常 ([#5369](https://github.com/AstrBotDevs/AstrBot/issues/5369))。
- 修复 ChatUI 文件消息段无法持久化的问题 ([#5386](https://github.com/AstrBotDevs/AstrBot/issues/5386))。
- 修复 `.dockerignore` 误排除 `changelogs` 目录的问题。
- 修复 aiohttp 版本过新导致 qq-botpy 报错的问题 ([#5316](https://github.com/AstrBotDevs/AstrBot/issues/5316))。
### 优化
- 完成 SubAgent 编排页面国际化,补齐多语言支持 ([#5400](https://github.com/AstrBotDevs/AstrBot/issues/5400))。
- 增补消息事件处理相关测试,并完善测试框架的 fixtures/mocks 覆盖 ([#5355](https://github.com/AstrBotDevs/AstrBot/issues/5355), [#5354](https://github.com/AstrBotDevs/AstrBot/issues/5354))。
## What's Changed (EN)
### New Features
- Added computer-use tools support in sub-agent handoff scenarios ([#5399](https://github.com/AstrBotDevs/AstrBot/issues/5399)).
- Added support for displaying tool call results during agent execution for better observability ([#5388](https://github.com/AstrBotDevs/AstrBot/issues/5388)).
- Added plugin load/unload hooks to extend plugin lifecycle capabilities ([#5331](https://github.com/AstrBotDevs/AstrBot/issues/5331)).
- Added hot reload support when plugin loading fails, improving recovery during plugin development ([#5334](https://github.com/AstrBotDevs/AstrBot/issues/5334)).
- Added image URL/local path input support for sub-agents ([#5348](https://github.com/AstrBotDevs/AstrBot/issues/5348)).
- Added stop control for active agent sessions and improved stop request handling ([#5380](https://github.com/AstrBotDevs/AstrBot/issues/5380)).
- Added configurable base URL for dashboard release redirects ([#5330](https://github.com/AstrBotDevs/AstrBot/issues/5330)).
### Fixes
- Fixed logging behavior when platform information is missing in context sessions, with clearer warning and diagnostics.
- Fixed missing embedding dimensions being passed to provider APIs ([#5411](https://github.com/AstrBotDevs/AstrBot/issues/5411)).
- Fixed shutdown stability issues in the aiocqhttp adapter ([#5412](https://github.com/AstrBotDevs/AstrBot/issues/5412)).
- Fixed File component handling and improved path compatibility in the OneBot driver layer ([#5391](https://github.com/AstrBotDevs/AstrBot/issues/5391)).
- Fixed missing admin guard for sandbox file transfer tools ([#5402](https://github.com/AstrBotDevs/AstrBot/issues/5402)).
- Fixed circular import issues related to pipeline and `from ... import *` usage ([#5353](https://github.com/AstrBotDevs/AstrBot/issues/5353)).
- Fixed config parsing issues when files contain UTF-8 BOM ([#5376](https://github.com/AstrBotDevs/AstrBot/issues/5376)).
- Fixed missing copy rollback path and unclear error messaging in ChatUI ([#5352](https://github.com/AstrBotDevs/AstrBot/issues/5352)).
- Fixed reserved plugin directory handling to avoid abnormal plugin path behavior ([#5369](https://github.com/AstrBotDevs/AstrBot/issues/5369)).
- Fixed ChatUI file segment persistence issues ([#5386](https://github.com/AstrBotDevs/AstrBot/issues/5386)).
- Fixed accidental exclusion of the `changelogs` directory in `.dockerignore`.
- Fixed compatibility issues caused by a hard-coded 6-second timeout in Tavily requests.
- Fixed qq-botpy runtime errors caused by overly new aiohttp versions ([#5316](https://github.com/AstrBotDevs/AstrBot/issues/5316)).
### Improvements
- Completed internationalization for the sub-agent orchestration page ([#5400](https://github.com/AstrBotDevs/AstrBot/issues/5400)).
- Added broader message-event test coverage and improved fixtures/mocks in the test framework ([#5355](https://github.com/AstrBotDevs/AstrBot/issues/5355), [#5354](https://github.com/AstrBotDevs/AstrBot/issues/5354)).
- Updated README content and applied repository-wide formatting cleanup (ruff format) ([#5375](https://github.com/AstrBotDevs/AstrBot/issues/5375)).
@@ -58,6 +58,18 @@
"guideStep2": "Install it and restart AstrBot.",
"guideStep3": "If you use Docker, prefer the image update path."
},
"desktopApp": {
"title": "Update Desktop App",
"message": "Check and upgrade the AstrBot desktop application.",
"currentVersion": "Current version: ",
"latestVersion": "Latest version: ",
"checking": "Checking desktop app updates...",
"hasNewVersion": "A new version is available. Click confirm to upgrade.",
"isLatest": "Already on the latest version",
"installing": "Downloading and installing update. The app will restart automatically...",
"checkFailed": "Failed to check updates. Please try again later.",
"installFailed": "Upgrade failed. Please try again later."
},
"dashboardUpdate": {
"title": "Update Dashboard to Latest Version Only",
"currentVersion": "Current Version",
@@ -251,6 +251,10 @@
"show_tool_use_status": {
"description": "Output Function Call Status"
},
"show_tool_call_result": {
"description": "Output Tool Call Results",
"hint": "Only takes effect when \"Output Function Call Status\" is enabled, and shows at most 70 characters."
},
"sanitize_context_by_modalities": {
"description": "Sanitize History by Modalities",
"hint": "When enabled, sanitizes contexts before each LLM request by removing image blocks and tool-call structures that the current provider's modalities do not support (this changes what the model sees)."
@@ -8,11 +8,14 @@
"refresh": "Refresh",
"save": "Save",
"add": "Add SubAgent",
"delete": "Delete"
"delete": "Delete",
"close": "Close"
},
"switches": {
"enable": "Enable SubAgent orchestration",
"dedupe": "Deduplicate main LLM tools (hide tools duplicated by SubAgents)"
"enableHint": "Enable sub-agent functionality",
"dedupe": "Deduplicate main LLM tools (hide tools duplicated by SubAgents)",
"dedupeHint": "Remove duplicate tools from main agent"
},
"description": {
"disabled": "When off: SubAgent is disabled; the main LLM mounts tools via persona rules (all by default) and calls them directly.",
@@ -29,7 +32,8 @@
"transferPrefix": "transfer_to_{name}",
"switchLabel": "Enable",
"previewTitle": "Preview: handoff tool shown to the main LLM",
"personaChip": "Persona: {id}"
"personaChip": "Persona: {id}",
"personaPreview": "PERSONA PREVIEW"
},
"form": {
"nameLabel": "Agent name (used for transfer_to_{name})",
@@ -49,6 +53,13 @@
"nameDuplicate": "Duplicate SubAgent name: {name}",
"personaMissing": "SubAgent {name} has no persona selected",
"saveSuccess": "Saved successfully",
"saveFailed": "Failed to save"
"saveFailed": "Failed to save",
"nameRequired": "Name is required",
"namePattern": "Lowercase letters, numbers, underscore only"
},
"empty": {
"title": "No Agents Configured",
"subtitle": "Add a new sub-agent to get started",
"action": "Create First Agent"
}
}
@@ -58,6 +58,18 @@
"guideStep2": "完成安装后重启 AstrBot。",
"guideStep3": "如果你使用 Docker,请优先使用镜像更新方式。"
},
"desktopApp": {
"title": "更新桌面应用",
"message": "将检查并升级 AstrBot 桌面端程序。",
"currentVersion": "当前版本:",
"latestVersion": "最新版本:",
"checking": "正在检查桌面应用更新...",
"hasNewVersion": "发现新版本,可点击确认升级。",
"isLatest": "已经是最新版本",
"installing": "正在下载并安装更新,完成后将自动重启应用...",
"checkFailed": "检查更新失败,请稍后重试。",
"installFailed": "升级失败,请稍后重试。"
},
"dashboardUpdate": {
"title": "单独更新管理面板到最新版本",
"currentVersion": "当前版本",
@@ -254,6 +254,10 @@
"show_tool_use_status": {
"description": "输出函数调用状态"
},
"show_tool_call_result": {
"description": "输出函数调用返回结果",
"hint": "仅在启用“输出函数调用状态”时生效,且最多展示 70 个字符。"
},
"sanitize_context_by_modalities": {
"description": "按模型能力清理历史上下文",
"hint": "开启后,在每次请求 LLM 前会按当前模型提供商中所选择的模型能力删除对话中不支持的图片/工具调用结构(会改变模型看到的历史)"
@@ -8,11 +8,14 @@
"refresh": "刷新",
"save": "保存",
"add": "新增 SubAgent",
"delete": "删除"
"delete": "删除",
"close": "关闭"
},
"switches": {
"enable": "启用 SubAgent 编排",
"dedupe": "主 LLM 去重重复工具(与 SubAgent 重叠的工具将被隐藏)"
"enableHint": "启用子代理功能",
"dedupe": "主 LLM 去重重复工具(与 SubAgent 重叠的工具将被隐藏)",
"dedupeHint": "从主代理中移除重复工具"
},
"description": {
"disabled": "不启动:SubAgent 关闭;主 LLM 按 persona 规则挂载工具(默认全部),并直接调用。",
@@ -39,6 +42,7 @@
"providerHint": "留空表示跟随全局默认 provider。",
"personaLabel": "选择人格设定",
"personaHint": "SubAgent 将直接继承所选 Persona 的系统设定与工具。在人格设定页管理和新建人格。",
"personaPreview": "人格预览",
"descriptionLabel": "对主 LLM 的描述(用于决定是否 handoff",
"descriptionHint": "这段会作为 transfer_to_* 工具的描述给主 LLM 看,建议简短明确。"
},
@@ -50,6 +54,13 @@
"nameDuplicate": "SubAgent 名称重复:{name}",
"personaMissing": "SubAgent {name} 未选择 Persona",
"saveSuccess": "保存成功",
"saveFailed": "保存失败"
"saveFailed": "保存失败",
"nameRequired": "名称必填",
"namePattern": "仅支持小写字母、数字和下划线"
},
"empty": {
"title": "未配置 SubAgent",
"subtitle": "添加一个新的子代理以开始",
"action": "创建第一个 Agent"
}
}
@@ -50,24 +50,27 @@ let installLoading = ref(false);
const isDesktopReleaseMode = ref(
typeof window !== 'undefined' && !!window.astrbotDesktop?.isDesktop
);
const redirectConfirmDialog = ref(false);
const pendingRedirectUrl = ref('');
const resolvingReleaseTarget = ref(false);
const DEFAULT_ASTRBOT_RELEASE_BASE_URL = 'https://github.com/AstrBotDevs/AstrBot/releases';
const resolveReleaseBaseUrl = () => {
const raw = import.meta.env.VITE_ASTRBOT_RELEASE_BASE_URL;
// Keep upstream default on AstrBot releases; desktop distributors can override via env injection.
const normalized = raw?.trim()?.replace(/\/+$/, '') || '';
const withoutLatestSuffix = normalized.replace(/\/latest$/i, '');
return withoutLatestSuffix || DEFAULT_ASTRBOT_RELEASE_BASE_URL;
};
const releaseBaseUrl = resolveReleaseBaseUrl();
const getReleaseUrlByTag = (tag: string | null | undefined) => {
const normalizedTag = (tag || '').trim();
if (!normalizedTag || normalizedTag.toLowerCase() === 'latest') {
return `${releaseBaseUrl}/latest`;
const desktopUpdateDialog = ref(false);
const desktopUpdateChecking = ref(false);
const desktopUpdateInstalling = ref(false);
const desktopUpdateHasNewVersion = ref(false);
const desktopUpdateCurrentVersion = ref('-');
const desktopUpdateLatestVersion = ref('-');
const desktopUpdateStatus = ref('');
const getAppUpdaterBridge = (): AstrBotAppUpdaterBridge | null => {
if (typeof window === 'undefined') {
return null;
}
return `${releaseBaseUrl}/tag/${normalizedTag}`;
const bridge = window.astrbotAppUpdater;
if (
bridge &&
typeof bridge.checkForAppUpdate === 'function' &&
typeof bridge.installAppUpdate === 'function'
) {
return bridge;
}
return null;
};
const getSelectedGitHubProxy = () => {
@@ -89,16 +92,6 @@ const releasesHeader = computed(() => [
{ title: t('core.header.updateDialog.table.sourceUrl'), key: 'zipball_url' },
{ title: t('core.header.updateDialog.table.actions'), key: 'switch' }
]);
const latestReleaseTag = computed(() => {
const firstRelease = (releases.value as any[])?.[0];
if (firstRelease?.tag_name) {
return firstRelease.tag_name as string;
}
return hasNewVersion.value
? t('core.header.updateDialog.redirectConfirm.latestLabel')
: (botCurrVersion.value || '-');
});
// Form validation
const formValid = ref(true);
const passwordRules = computed(() => [
@@ -126,47 +119,88 @@ const accountEditStatus = ref({
message: ''
});
const open = (link: string) => {
window.open(link, '_blank');
};
function requestExternalRedirect(link: string) {
pendingRedirectUrl.value = link;
redirectConfirmDialog.value = true;
function cancelDesktopUpdate() {
if (desktopUpdateInstalling.value) {
return;
}
desktopUpdateDialog.value = false;
}
function cancelExternalRedirect() {
redirectConfirmDialog.value = false;
pendingRedirectUrl.value = '';
}
async function openDesktopUpdateDialog() {
desktopUpdateDialog.value = true;
desktopUpdateChecking.value = true;
desktopUpdateInstalling.value = false;
desktopUpdateHasNewVersion.value = false;
desktopUpdateCurrentVersion.value = '-';
desktopUpdateLatestVersion.value = '-';
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.checking');
function confirmExternalRedirect() {
const targetUrl = pendingRedirectUrl.value;
cancelExternalRedirect();
if (targetUrl) {
open(targetUrl);
const bridge = getAppUpdaterBridge();
if (!bridge) {
desktopUpdateChecking.value = false;
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.checkFailed');
return;
}
try {
const result = await bridge.checkForAppUpdate();
if (!result?.ok) {
desktopUpdateCurrentVersion.value = result?.currentVersion || '-';
desktopUpdateLatestVersion.value =
result?.latestVersion || result?.currentVersion || '-';
desktopUpdateStatus.value =
result?.reason || t('core.header.updateDialog.desktopApp.checkFailed');
return;
}
desktopUpdateCurrentVersion.value = result.currentVersion || '-';
desktopUpdateLatestVersion.value =
result.latestVersion || result.currentVersion || '-';
desktopUpdateHasNewVersion.value = !!result.hasUpdate;
desktopUpdateStatus.value = result.hasUpdate
? t('core.header.updateDialog.desktopApp.hasNewVersion')
: t('core.header.updateDialog.desktopApp.isLatest');
} catch (error) {
console.error(error);
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.checkFailed');
} finally {
desktopUpdateChecking.value = false;
}
}
const getReleaseUrlForDesktop = () => {
const firstRelease = (releases.value as any[])?.[0];
if (firstRelease?.tag_name) {
return getReleaseUrlByTag(firstRelease.tag_name as string);
async function confirmDesktopUpdate() {
if (!desktopUpdateHasNewVersion.value || desktopUpdateInstalling.value) {
return;
}
if (hasNewVersion.value) return getReleaseUrlByTag('latest');
const tag = botCurrVersion.value?.startsWith('v') ? botCurrVersion.value : 'latest';
return getReleaseUrlByTag(tag);
};
const bridge = getAppUpdaterBridge();
if (!bridge) {
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.installFailed');
return;
}
desktopUpdateInstalling.value = true;
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.installing');
try {
const result = await bridge.installAppUpdate();
if (result?.ok) {
desktopUpdateDialog.value = false;
return;
}
desktopUpdateStatus.value =
result?.reason || t('core.header.updateDialog.desktopApp.installFailed');
} catch (error) {
console.error(error);
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.installFailed');
} finally {
desktopUpdateInstalling.value = false;
}
}
function handleUpdateClick() {
if (isDesktopReleaseMode.value) {
requestExternalRedirect('');
resolvingReleaseTarget.value = true;
checkUpdate();
void getReleases().finally(() => {
pendingRedirectUrl.value = getReleaseUrlForDesktop() || getReleaseUrlByTag('latest');
resolvingReleaseTarget.value = false;
});
void openDesktopUpdateDialog();
return;
}
checkUpdate();
@@ -680,40 +714,38 @@ onMounted(async () => {
</v-card>
</v-dialog>
<v-dialog v-model="redirectConfirmDialog" max-width="460">
<v-dialog v-model="desktopUpdateDialog" max-width="460">
<v-card>
<v-card-title class="text-h3 pa-4 pl-6 pb-0">
{{ t('core.header.updateDialog.redirectConfirm.title') }}
{{ t('core.header.updateDialog.desktopApp.title') }}
</v-card-title>
<v-card-text>
<div class="mb-3">
{{ t('core.header.updateDialog.redirectConfirm.message') }}
{{ t('core.header.updateDialog.desktopApp.message') }}
</div>
<v-alert type="info" variant="tonal" density="compact">
<div>
{{ t('core.header.updateDialog.redirectConfirm.targetVersion') }}
<strong v-if="!resolvingReleaseTarget">{{ latestReleaseTag }}</strong>
<v-progress-circular v-else indeterminate size="16" width="2" class="ml-1" />
{{ t('core.header.updateDialog.desktopApp.currentVersion') }}
<strong>{{ desktopUpdateCurrentVersion }}</strong>
</div>
<div class="text-caption">
{{ t('core.header.updateDialog.redirectConfirm.currentVersion') }}
{{ botCurrVersion || '-' }}
<div>
{{ t('core.header.updateDialog.desktopApp.latestVersion') }}
<strong v-if="!desktopUpdateChecking">{{ desktopUpdateLatestVersion }}</strong>
<v-progress-circular v-else indeterminate size="16" width="2" class="ml-1" />
</div>
</v-alert>
<div class="text-caption mt-3">
<div>{{ t('core.header.updateDialog.redirectConfirm.guideTitle') }}</div>
<div>1. {{ t('core.header.updateDialog.redirectConfirm.guideStep1') }}</div>
<div>2. {{ t('core.header.updateDialog.redirectConfirm.guideStep2') }}</div>
<div>3. {{ t('core.header.updateDialog.redirectConfirm.guideStep3') }}</div>
{{ desktopUpdateStatus }}
</div>
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="grey" variant="text" @click="cancelExternalRedirect">
<v-btn color="grey" variant="text" @click="cancelDesktopUpdate" :disabled="desktopUpdateInstalling">
{{ t('core.common.dialog.cancelButton') }}
</v-btn>
<v-btn color="primary" variant="flat" @click="confirmExternalRedirect"
:loading="resolvingReleaseTarget" :disabled="resolvingReleaseTarget || !pendingRedirectUrl">
<v-btn color="primary" variant="flat" @click="confirmDesktopUpdate"
:loading="desktopUpdateInstalling"
:disabled="desktopUpdateChecking || desktopUpdateInstalling || !desktopUpdateHasNewVersion">
{{ t('core.common.dialog.confirmButton') }}
</v-btn>
</v-card-actions>
@@ -1,7 +1,26 @@
export {};
declare global {
interface AstrBotDesktopAppUpdateCheckResult {
ok: boolean;
reason?: string | null;
currentVersion?: string;
latestVersion?: string | null;
hasUpdate: boolean;
}
interface AstrBotDesktopAppUpdateResult {
ok: boolean;
reason?: string | null;
}
interface AstrBotAppUpdaterBridge {
checkForAppUpdate: () => Promise<AstrBotDesktopAppUpdateCheckResult>;
installAppUpdate: () => Promise<AstrBotDesktopAppUpdateResult>;
}
interface Window {
astrbotAppUpdater?: AstrBotAppUpdaterBridge;
astrbotDesktop?: {
isDesktop: boolean;
isDesktopRuntime: () => Promise<boolean>;
+8 -8
View File
@@ -62,7 +62,7 @@
<template #label>
<div class="d-flex flex-column">
<span class="text-body-2 font-weight-medium">{{ tm('switches.enable') }}</span>
<span class="text-caption text-medium-emphasis">Enable sub-agent functionality</span>
<span class="text-caption text-medium-emphasis">{{ tm('switches.enableHint') }}</span>
</div>
</template>
</v-switch>
@@ -80,7 +80,7 @@
<template #label>
<div class="d-flex flex-column">
<span class="text-body-2 font-weight-medium">{{ tm('switches.dedupe') }}</span>
<span class="text-caption text-medium-emphasis">Remove duplicate tools from main agent</span>
<span class="text-caption text-medium-emphasis">{{ tm('switches.dedupeHint') }}</span>
</div>
</template>
</v-switch>
@@ -166,7 +166,7 @@
<v-text-field
v-model="agent.name"
:label="tm('form.nameLabel')"
:rules="[v => !!v || 'Name is required', v => /^[a-z][a-z0-9_]*$/.test(v) || 'Lowercase letters, numbers, underscore only']"
:rules="[v => !!v || tm('messages.nameRequired'), v => /^[a-z][a-z0-9_]*$/.test(v) || tm('messages.namePattern')]"
variant="outlined"
density="comfortable"
hide-details="auto"
@@ -215,7 +215,7 @@
<v-col cols="12" md="6">
<div class="h-100">
<div class="text-caption font-weight-bold text-medium-emphasis mb-2 ml-1">
PERSONA PREVIEW
{{ tm('cards.personaPreview') }}
</div>
<PersonaQuickPreview
:model-value="agent.persona_id"
@@ -231,17 +231,17 @@
<!-- Empty State -->
<div v-if="cfg.agents.length === 0" class="d-flex flex-column align-center justify-center py-12 text-medium-emphasis">
<v-icon icon="mdi-robot-off" size="64" class="mb-4 opacity-50" />
<div class="text-h6">No Agents Configured</div>
<div class="text-body-2 mb-4">Add a new sub-agent to get started</div>
<div class="text-h6">{{ tm('empty.title') }}</div>
<div class="text-body-2 mb-4">{{ tm('empty.subtitle') }}</div>
<v-btn color="primary" variant="tonal" @click="addAgent">
Create First Agent
{{ tm('empty.action') }}
</v-btn>
</div>
<v-snackbar v-model="snackbar.show" :color="snackbar.color" timeout="3000" location="top">
{{ snackbar.message }}
<template #actions>
<v-btn variant="text" @click="snackbar.show = false">Close</v-btn>
<v-btn variant="text" @click="snackbar.show = false">{{ tm('actions.close') }}</v-btn>
</template>
</v-snackbar>
</div>
+1 -1
View File
@@ -1,6 +1,6 @@
[project]
name = "AstrBot"
version = "4.18.1"
version = "4.18.2"
description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md"
requires-python = ">=3.12"
+381
View File
@@ -0,0 +1,381 @@
"""
AstrBot 测试配置
提供共享的 pytest fixtures 和测试工具
"""
import json
import os
import sys
from asyncio import Queue
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
import pytest_asyncio
# 使用 tests/fixtures/helpers.py 中的共享工具函数,避免重复定义
from tests.fixtures.helpers import create_mock_llm_response, create_mock_message_component
# 将项目根目录添加到 sys.path
PROJECT_ROOT = Path(__file__).parent.parent
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
# 设置测试环境变量
os.environ.setdefault("TESTING", "true")
os.environ.setdefault("ASTRBOT_TEST_MODE", "true")
# ============================================================
# 测试收集和排序
# ============================================================
def pytest_collection_modifyitems(session, config, items): # noqa: ARG001
"""重新排序测试:单元测试优先,集成测试在后。"""
unit_tests = []
integration_tests = []
deselected = []
profile = config.getoption("--test-profile") or os.environ.get(
"ASTRBOT_TEST_PROFILE", "all"
)
for item in items:
item_path = Path(str(item.path))
is_integration = "integration" in item_path.parts
if is_integration:
if item.get_closest_marker("integration") is None:
item.add_marker(pytest.mark.integration)
item.add_marker(pytest.mark.tier_d)
integration_tests.append(item)
else:
if item.get_closest_marker("unit") is None:
item.add_marker(pytest.mark.unit)
if any(
item.get_closest_marker(marker) is not None
for marker in ("platform", "provider", "slow")
):
item.add_marker(pytest.mark.tier_c)
unit_tests.append(item)
# 单元测试 -> 集成测试
ordered_items = unit_tests + integration_tests
if profile == "blocking":
selected_items = []
for item in ordered_items:
if item.get_closest_marker("tier_c") or item.get_closest_marker("tier_d"):
deselected.append(item)
else:
selected_items.append(item)
if deselected:
config.hook.pytest_deselected(items=deselected)
items[:] = selected_items
return
items[:] = ordered_items
def pytest_addoption(parser):
"""增加测试执行档位选择。"""
parser.addoption(
"--test-profile",
action="store",
default=None,
choices=["all", "blocking"],
help="Select test profile. 'blocking' excludes auto-classified tier_c/tier_d tests.",
)
def pytest_configure(config):
"""注册自定义标记。"""
config.addinivalue_line("markers", "unit: 单元测试")
config.addinivalue_line("markers", "integration: 集成测试")
config.addinivalue_line("markers", "slow: 慢速测试")
config.addinivalue_line("markers", "platform: 平台适配器测试")
config.addinivalue_line("markers", "provider: LLM Provider 测试")
config.addinivalue_line("markers", "db: 数据库相关测试")
config.addinivalue_line("markers", "tier_c: C-tier tests (optional / non-blocking)")
config.addinivalue_line("markers", "tier_d: D-tier tests (extended / integration)")
# ============================================================
# 临时目录和文件 Fixtures
# ============================================================
@pytest.fixture
def temp_dir(tmp_path: Path) -> Path:
"""创建临时目录用于测试。"""
return tmp_path
@pytest.fixture
def event_queue() -> Queue:
"""Create a shared asyncio queue fixture for tests."""
return Queue()
@pytest.fixture
def platform_settings() -> dict:
"""Create a shared empty platform settings fixture for adapter tests."""
return {}
@pytest.fixture
def temp_data_dir(temp_dir: Path) -> Path:
"""创建模拟的 data 目录结构。"""
data_dir = temp_dir / "data"
data_dir.mkdir()
# 创建必要的子目录
(data_dir / "config").mkdir()
(data_dir / "plugins").mkdir()
(data_dir / "temp").mkdir()
(data_dir / "attachments").mkdir()
return data_dir
@pytest.fixture
def temp_config_file(temp_data_dir: Path) -> Path:
"""创建临时配置文件。"""
config_path = temp_data_dir / "config" / "cmd_config.json"
default_config = {
"provider": [],
"platform": [],
"provider_settings": {},
"default_personality": None,
"timezone": "Asia/Shanghai",
}
config_path.write_text(json.dumps(default_config, indent=2), encoding="utf-8")
return config_path
@pytest.fixture
def temp_db_file(temp_data_dir: Path) -> Path:
"""创建临时数据库文件路径。"""
return temp_data_dir / "test.db"
# ============================================================
# Mock Fixtures
# ============================================================
@pytest.fixture
def mock_provider():
"""创建模拟的 Provider。"""
provider = MagicMock()
provider.provider_config = {
"id": "test-provider",
"type": "openai_chat_completion",
"model": "gpt-4o-mini",
}
provider.get_model = MagicMock(return_value="gpt-4o-mini")
provider.text_chat = AsyncMock()
provider.text_chat_stream = AsyncMock()
provider.terminate = AsyncMock()
return provider
@pytest.fixture
def mock_platform():
"""创建模拟的 Platform。"""
platform = MagicMock()
platform.platform_name = "test_platform"
platform.platform_meta = MagicMock()
platform.platform_meta.support_proactive_message = False
platform.send_message = AsyncMock()
platform.terminate = AsyncMock()
return platform
@pytest.fixture
def mock_conversation():
"""创建模拟的 Conversation。"""
from astrbot.core.db.po import ConversationV2
return ConversationV2(
conversation_id="test-conv-id",
platform_id="test_platform",
user_id="test_user",
content=[],
persona_id=None,
)
@pytest.fixture
def mock_event():
"""创建模拟的 AstrMessageEvent。"""
event = MagicMock()
event.unified_msg_origin = "test_umo"
event.session_id = "test_session"
event.message_str = "Hello, world!"
event.message_obj = MagicMock()
event.message_obj.message = []
event.message_obj.sender = MagicMock()
event.message_obj.sender.user_id = "test_user"
event.message_obj.sender.nickname = "Test User"
event.message_obj.group_id = None
event.message_obj.group = None
event.get_platform_name = MagicMock(return_value="test_platform")
event.get_platform_id = MagicMock(return_value="test_platform")
event.get_group_id = MagicMock(return_value=None)
event.get_extra = MagicMock(return_value=None)
event.set_extra = MagicMock()
event.trace = MagicMock()
event.platform_meta = MagicMock()
event.platform_meta.support_proactive_message = False
return event
# ============================================================
# 配置 Fixtures
# ============================================================
@pytest.fixture
def astrbot_config(temp_config_file: Path):
"""创建 AstrBotConfig 实例。"""
from astrbot.core.config.astrbot_config import AstrBotConfig
config = AstrBotConfig()
config._config_path = str(temp_config_file) # noqa: SLF001
return config
@pytest.fixture
def main_agent_build_config():
"""创建 MainAgentBuildConfig 实例。"""
from astrbot.core.astr_main_agent import MainAgentBuildConfig
return MainAgentBuildConfig(
tool_call_timeout=60,
tool_schema_mode="full",
provider_wake_prefix="",
streaming_response=True,
sanitize_context_by_modalities=False,
kb_agentic_mode=False,
file_extract_enabled=False,
context_limit_reached_strategy="truncate_by_turns",
llm_safety_mode=True,
computer_use_runtime="local",
add_cron_tools=True,
)
# ============================================================
# 数据库 Fixtures
# ============================================================
@pytest_asyncio.fixture
async def temp_db(temp_db_file: Path):
"""创建临时数据库实例。"""
from astrbot.core.db.sqlite import SQLiteDatabase
db = SQLiteDatabase(str(temp_db_file))
try:
yield db
finally:
await db.engine.dispose()
if temp_db_file.exists():
temp_db_file.unlink()
# ============================================================
# Context Fixtures
# ============================================================
@pytest_asyncio.fixture
async def mock_context(
astrbot_config,
temp_db,
mock_provider,
mock_platform,
):
"""创建模拟的插件上下文。"""
from asyncio import Queue
from astrbot.core.star.context import Context
event_queue = Queue()
provider_manager = MagicMock()
provider_manager.get_using_provider = MagicMock(return_value=mock_provider)
provider_manager.get_provider_by_id = MagicMock(return_value=mock_provider)
platform_manager = MagicMock()
conversation_manager = MagicMock()
message_history_manager = MagicMock()
persona_manager = MagicMock()
persona_manager.personas_v3 = []
astrbot_config_mgr = MagicMock()
knowledge_base_manager = MagicMock()
cron_manager = MagicMock()
subagent_orchestrator = None
context = Context(
event_queue,
astrbot_config,
temp_db,
provider_manager,
platform_manager,
conversation_manager,
message_history_manager,
persona_manager,
astrbot_config_mgr,
knowledge_base_manager,
cron_manager,
subagent_orchestrator,
)
return context
# ============================================================
# Provider Request Fixtures
# ============================================================
@pytest.fixture
def provider_request():
"""创建 ProviderRequest 实例。"""
from astrbot.core.provider.entities import ProviderRequest
return ProviderRequest(
prompt="Hello",
session_id="test_session",
image_urls=[],
contexts=[],
system_prompt="You are a helpful assistant.",
)
# ============================================================
# 跳过条件
# ============================================================
def pytest_runtest_setup(item):
"""在测试运行前检查跳过条件。"""
# 跳过需要 API Key 但未设置的 Provider 测试
if item.get_closest_marker("provider"):
if not os.environ.get("TEST_PROVIDER_API_KEY"):
pytest.skip("TEST_PROVIDER_API_KEY not set")
# 跳过需要特定平台的测试
if item.get_closest_marker("platform"):
required_platform = None
marker = item.get_closest_marker("platform")
if marker and marker.args:
required_platform = marker.args[0]
if required_platform and not os.environ.get(
f"TEST_{required_platform.upper()}_ENABLED"
):
pytest.skip(f"TEST_{required_platform.upper()}_ENABLED not set")
+64
View File
@@ -0,0 +1,64 @@
"""
AstrBot 测试数据
此目录存放测试用的静态数据和配置文件
目录结构:
- fixtures/
configs/ # 测试配置文件
messages/ # 测试消息数据
plugins/ # 测试插件
knowledge_base/ # 测试知识库数据
mocks/ # Mock 模块
helpers.py # 辅助函数
"""
import json
from pathlib import Path
from .helpers import (
NoopAwaitable,
create_mock_discord_attachment,
create_mock_discord_channel,
create_mock_discord_user,
create_mock_file,
create_mock_llm_response,
create_mock_message_component,
create_mock_update,
make_platform_config,
)
FIXTURES_DIR = Path(__file__).parent
def load_fixture(filename: str) -> dict:
"""加载 JSON 格式的测试数据。"""
filepath = FIXTURES_DIR / filename
if not filepath.exists():
raise FileNotFoundError(f"Fixture not found: {filepath}")
return json.loads(filepath.read_text(encoding="utf-8"))
def get_fixture_path(filename: str) -> Path:
"""获取测试数据文件路径。"""
filepath = FIXTURES_DIR / filename
if not filepath.exists():
raise FileNotFoundError(f"Fixture not found: {filepath}")
return filepath
__all__ = [
"FIXTURES_DIR",
"load_fixture",
"get_fixture_path",
# 辅助函数
"NoopAwaitable",
"make_platform_config",
"create_mock_update",
"create_mock_file",
"create_mock_discord_attachment",
"create_mock_discord_user",
"create_mock_discord_channel",
"create_mock_message_component",
"create_mock_llm_response",
]
+21
View File
@@ -0,0 +1,21 @@
{
"provider": [
{
"id": "test-openai",
"type": "openai_chat_completion",
"model": "gpt-4o-mini",
"key": ["test-key"]
}
],
"platform": [],
"provider_settings": {
"default_personality": null,
"prompt_prefix": "",
"image_caption_provider_id": "",
"datetime_system_prompt": true,
"identifier": true,
"group_name_display": true
},
"default_personality": null,
"timezone": "Asia/Shanghai"
}
+332
View File
@@ -0,0 +1,332 @@
"""测试辅助函数和工具类。
提供统一的测试辅助工具减少测试代码重复
"""
from typing import Any
from unittest.mock import AsyncMock, MagicMock
from astrbot.core.message.components import BaseMessageComponent
class NoopAwaitable:
"""可等待的空操作对象。
用于 mock 需要返回 awaitable 对象的方法
"""
def __await__(self):
if False:
yield
return None
# ============================================================
# 平台配置工厂
# ============================================================
def make_platform_config(platform_type: str, **kwargs) -> dict:
"""平台配置工厂函数。
Args:
platform_type: 平台类型 (telegram, discord, aiocqhttp )
**kwargs: 覆盖默认配置的字段
Returns:
dict: 平台配置字典
"""
configs = {
"telegram": {
"id": "test_telegram",
"telegram_token": "test_token_123",
"telegram_api_base_url": "https://api.telegram.org/bot",
"telegram_file_base_url": "https://api.telegram.org/file/bot",
"telegram_command_register": True,
"telegram_command_auto_refresh": True,
"telegram_command_register_interval": 300,
"telegram_media_group_timeout": 2.5,
"telegram_media_group_max_wait": 10.0,
"start_message": "Welcome to AstrBot!",
},
"discord": {
"id": "test_discord",
"discord_token": "test_token_123",
"discord_proxy": None,
"discord_command_register": True,
"discord_guild_id_for_debug": None,
"discord_activity_name": "Playing AstrBot",
},
"aiocqhttp": {
"id": "test_aiocqhttp",
"ws_reverse_host": "0.0.0.0",
"ws_reverse_port": 6199,
"ws_reverse_token": "test_token",
},
"webchat": {
"id": "test_webchat",
},
"wecom": {
"id": "test_wecom",
"wecom_corpid": "test_corpid",
"wecom_secret": "test_secret",
},
}
config = configs.get(platform_type, {"id": f"test_{platform_type}"}).copy()
config.update(kwargs)
return config
# ============================================================
# Telegram 辅助函数
# ============================================================
def create_mock_update(
message_text: str | None = "Hello World",
chat_type: str = "private",
chat_id: int = 123456789,
user_id: int = 987654321,
username: str = "test_user",
message_id: int = 1,
media_group_id: str | None = None,
photo: list | None = None,
video: MagicMock | None = None,
document: MagicMock | None = None,
voice: MagicMock | None = None,
sticker: MagicMock | None = None,
reply_to_message: MagicMock | None = None,
caption: str | None = None,
entities: list | None = None,
caption_entities: list | None = None,
message_thread_id: int | None = None,
is_topic_message: bool = False,
):
"""创建模拟的 Telegram Update 对象。
Args:
message_text: 消息文本
chat_type: 聊天类型
chat_id: 聊天 ID
user_id: 用户 ID
username: 用户名
message_id: 消息 ID
media_group_id: 媒体组 ID
photo: 图片列表
video: 视频对象
document: 文档对象
voice: 语音对象
sticker: 贴纸对象
reply_to_message: 回复的消息
caption: 说明文字
entities: 实体列表
caption_entities: 说明实体列表
message_thread_id: 消息线程 ID
is_topic_message: 是否为主题消息
Returns:
MagicMock: 模拟的 Update 对象
"""
update = MagicMock()
update.update_id = 1
# Create message mock
message = MagicMock()
message.message_id = message_id
message.chat = MagicMock()
message.chat.id = chat_id
message.chat.type = chat_type
message.message_thread_id = message_thread_id
message.is_topic_message = is_topic_message
# Create user mock
from_user = MagicMock()
from_user.id = user_id
from_user.username = username
message.from_user = from_user
# Set message content
message.text = message_text
message.media_group_id = media_group_id
message.photo = photo
message.video = video
message.document = document
message.voice = voice
message.sticker = sticker
message.reply_to_message = reply_to_message
message.caption = caption
message.entities = entities
message.caption_entities = caption_entities
update.message = message
update.effective_chat = message.chat
return update
def create_mock_file(file_path: str = "https://api.telegram.org/file/test.jpg"):
"""创建模拟的 Telegram File 对象。
Args:
file_path: 文件路径
Returns:
MagicMock: 模拟的 File 对象
"""
file = MagicMock()
file.file_path = file_path
file.get_file = AsyncMock(return_value=file)
return file
# ============================================================
# Discord 辅助函数
# ============================================================
def create_mock_discord_attachment(
filename: str = "test.txt",
url: str = "https://cdn.discordapp.com/test.txt",
content_type: str | None = None,
size: int = 1024,
):
"""创建模拟的 Discord Attachment 对象。
Args:
filename: 文件名
url: 文件 URL
content_type: 内容类型
size: 文件大小
Returns:
MagicMock: 模拟的 Attachment 对象
"""
attachment = MagicMock()
attachment.filename = filename
attachment.url = url
attachment.content_type = content_type
attachment.size = size
return attachment
def create_mock_discord_user(
user_id: int = 123456789,
name: str = "TestUser",
display_name: str = "Test User",
bot: bool = False,
):
"""创建模拟的 Discord User 对象。
Args:
user_id: 用户 ID
name: 用户名
display_name: 显示名
bot: 是否为机器人
Returns:
MagicMock: 模拟的 User 对象
"""
user = MagicMock()
user.id = user_id
user.name = name
user.display_name = display_name
user.bot = bot
user.mention = f"<@{user_id}>"
return user
def create_mock_discord_channel(
channel_id: int = 111222333,
channel_type: str = "text",
name: str = "general",
guild_id: int | None = 444555666,
):
"""创建模拟的 Discord Channel 对象。
Args:
channel_id: 频道 ID
channel_type: 频道类型
name: 频道名
guild_id: 服务器 ID
Returns:
MagicMock: 模拟的 Channel 对象
"""
channel = MagicMock()
channel.id = channel_id
channel.name = name
channel.type = channel_type
if guild_id:
channel.guild = MagicMock()
channel.guild.id = guild_id
else:
channel.guild = None
return channel
# ============================================================
# 消息组件辅助函数
# ============================================================
def create_mock_message_component(
component_type: str,
**kwargs: Any,
) -> BaseMessageComponent:
"""创建模拟的消息组件。
Args:
component_type: 组件类型 (plain, image, at, reply, file)
**kwargs: 组件参数
Returns:
BaseMessageComponent: 消息组件实例
"""
from astrbot.core.message import components as Comp
component_map = {
"plain": Comp.Plain,
"image": Comp.Image,
"at": Comp.At,
"reply": Comp.Reply,
"file": Comp.File,
}
component_class = component_map.get(component_type.lower())
if not component_class:
raise ValueError(f"Unknown component type: {component_type}")
return component_class(**kwargs)
def create_mock_llm_response(
completion_text: str = "Hello! How can I help you?",
role: str = "assistant",
tools_call_name: list[str] | None = None,
tools_call_args: list[dict] | None = None,
tools_call_ids: list[str] | None = None,
):
"""创建模拟的 LLM 响应。
Args:
completion_text: 完成文本
role: 角色
tools_call_name: 工具调用名称列表
tools_call_args: 工具调用参数列表
tools_call_ids: 工具调用 ID 列表
Returns:
LLMResponse: 模拟的 LLM 响应
"""
from astrbot.core.provider.entities import LLMResponse, TokenUsage
return LLMResponse(
role=role,
completion_text=completion_text,
tools_call_name=tools_call_name or [],
tools_call_args=tools_call_args or [],
tools_call_ids=tools_call_ids or [],
usage=TokenUsage(input_other=10, output=5),
)
+33
View File
@@ -0,0 +1,33 @@
{
"plain_message": {
"type": "plain",
"text": "Hello, this is a test message."
},
"image_message": {
"type": "image",
"url": "https://example.com/test.jpg",
"file": null
},
"at_message": {
"type": "at",
"user_id": "12345",
"nickname": "TestUser"
},
"reply_message": {
"type": "reply",
"id": "msg_123",
"sender_nickname": "OriginalSender",
"message_str": "This is the original message"
},
"file_message": {
"type": "file",
"name": "test.pdf",
"url": "https://example.com/test.pdf"
},
"combined_message": {
"components": [
{"type": "at", "user_id": "bot_id"},
{"type": "plain", "text": " Hello bot!"}
]
}
}
+43
View File
@@ -0,0 +1,43 @@
"""测试 Mock 模块。
提供统一的 mock 工具和 fixture减少测试代码重复
使用方式:
# 在测试文件顶部导入需要的 fixture
from tests.fixtures.mocks import mock_telegram_modules
# 或使用 Builder 类创建 mock 对象
from tests.fixtures.mocks import MockTelegramBuilder
bot = MockTelegramBuilder.create_bot()
"""
from .aiocqhttp import (
MockAiocqhttpBuilder,
create_mock_aiocqhttp_modules,
mock_aiocqhttp_modules,
)
from .discord import (
MockDiscordBuilder,
create_mock_discord_modules,
mock_discord_modules,
)
from .telegram import (
MockTelegramBuilder,
create_mock_telegram_modules,
mock_telegram_modules,
)
__all__ = [
# Telegram
"mock_telegram_modules",
"create_mock_telegram_modules",
"MockTelegramBuilder",
# Discord
"mock_discord_modules",
"create_mock_discord_modules",
"MockDiscordBuilder",
# Aiocqhttp
"mock_aiocqhttp_modules",
"create_mock_aiocqhttp_modules",
"MockAiocqhttpBuilder",
]
+58
View File
@@ -0,0 +1,58 @@
"""Aiocqhttp 模块 Mock 工具。
提供统一的 aiocqhttp 相关模块 mock 设置避免在测试文件中重复定义
"""
import sys
from unittest.mock import AsyncMock, MagicMock
import pytest
def create_mock_aiocqhttp_modules():
"""创建 aiocqhttp 相关的 mock 模块。
Returns:
dict: 包含 aiocqhttp 和相关模块的 mock 对象
"""
mock_aiocqhttp = MagicMock()
mock_aiocqhttp.CQHttp = MagicMock
mock_aiocqhttp.Event = MagicMock
mock_aiocqhttp.exceptions = MagicMock()
mock_aiocqhttp.exceptions.ActionFailed = Exception
return mock_aiocqhttp
@pytest.fixture(scope="module", autouse=True)
def mock_aiocqhttp_modules():
"""Mock aiocqhttp 相关模块的 fixture。
自动应用于使用此 fixture 的测试模块
"""
mock_aiocqhttp = create_mock_aiocqhttp_modules()
monkeypatch = pytest.MonkeyPatch()
monkeypatch.setitem(sys.modules, "aiocqhttp", mock_aiocqhttp)
monkeypatch.setitem(sys.modules, "aiocqhttp.exceptions", mock_aiocqhttp.exceptions)
yield
monkeypatch.undo()
class MockAiocqhttpBuilder:
"""构建 aiocqhttp 测试 mock 对象的工具类。"""
@staticmethod
def create_bot():
"""创建 mock CQHttp bot 实例。"""
from tests.fixtures.helpers import NoopAwaitable
bot = MagicMock()
bot.send = AsyncMock()
bot.call_action = AsyncMock()
bot.on_request = MagicMock()
bot.on_notice = MagicMock()
bot.on_message = MagicMock()
bot.on_websocket_connection = MagicMock()
bot.run_task = MagicMock(return_value=NoopAwaitable())
return bot
+140
View File
@@ -0,0 +1,140 @@
"""Discord 模块 Mock 工具。
提供统一的 Discord 相关模块 mock 设置避免在测试文件中重复定义
"""
import sys
from unittest.mock import AsyncMock, MagicMock
import pytest
def create_mock_discord_modules():
"""创建 Discord 相关的 mock 模块。
Returns:
dict: 包含 discord 和相关模块的 mock 对象
"""
mock_discord = MagicMock()
# Mock discord.Intents
mock_intents = MagicMock()
mock_intents.default = MagicMock(return_value=mock_intents)
mock_discord.Intents = mock_intents
# Mock discord.Status
mock_discord.Status = MagicMock()
mock_discord.Status.online = "online"
# Mock discord.Bot
mock_bot = MagicMock()
mock_discord.Bot = MagicMock(return_value=mock_bot)
# Mock discord.Embed
mock_embed = MagicMock()
mock_discord.Embed = MagicMock(return_value=mock_embed)
# Mock discord.ui
mock_ui = MagicMock()
mock_ui.View = MagicMock
mock_ui.Button = MagicMock
mock_discord.ui = mock_ui
# Mock discord.Message
mock_discord.Message = MagicMock
# Mock discord.Interaction
mock_discord.Interaction = MagicMock
mock_discord.InteractionType = MagicMock()
mock_discord.InteractionType.application_command = 2
mock_discord.InteractionType.component = 3
# Mock discord.File
mock_discord.File = MagicMock
# Mock discord.SlashCommand
mock_discord.SlashCommand = MagicMock
# Mock discord.Option
mock_discord.Option = MagicMock
# Mock discord.SlashCommandOptionType
mock_discord.SlashCommandOptionType = MagicMock()
mock_discord.SlashCommandOptionType.string = 3
# Mock discord.errors
mock_discord.errors = MagicMock()
mock_discord.errors.LoginFailure = Exception
mock_discord.errors.ConnectionClosed = Exception
mock_discord.errors.NotFound = Exception
mock_discord.errors.Forbidden = Exception
# Mock discord.abc
mock_discord.abc = MagicMock()
mock_discord.abc.GuildChannel = MagicMock
mock_discord.abc.Messageable = MagicMock
mock_discord.abc.PrivateChannel = MagicMock
# Mock discord.channel
mock_channel = MagicMock()
mock_channel.DMChannel = MagicMock
mock_discord.channel = mock_channel
# Mock discord.types
mock_discord.types = MagicMock()
mock_discord.types.interactions = MagicMock()
# Mock discord.ApplicationContext
mock_discord.ApplicationContext = MagicMock
# Mock discord.CustomActivity
mock_discord.CustomActivity = MagicMock
return mock_discord
@pytest.fixture(scope="module", autouse=True)
def mock_discord_modules():
"""Mock Discord 相关模块的 fixture。
自动应用于使用此 fixture 的测试模块
"""
mock_discord = create_mock_discord_modules()
monkeypatch = pytest.MonkeyPatch()
monkeypatch.setitem(sys.modules, "discord", mock_discord)
monkeypatch.setitem(sys.modules, "discord.abc", mock_discord.abc)
monkeypatch.setitem(sys.modules, "discord.channel", mock_discord.channel)
monkeypatch.setitem(sys.modules, "discord.errors", mock_discord.errors)
monkeypatch.setitem(sys.modules, "discord.types", mock_discord.types)
monkeypatch.setitem(
sys.modules,
"discord.types.interactions",
mock_discord.types.interactions,
)
monkeypatch.setitem(sys.modules, "discord.ui", mock_discord.ui)
yield
monkeypatch.undo()
class MockDiscordBuilder:
"""构建 Discord 测试 mock 对象的工具类。"""
@staticmethod
def create_client():
"""创建 mock Discord client 实例。"""
client = MagicMock()
client.user = MagicMock()
client.user.id = 123456789
client.user.display_name = "TestBot"
client.user.name = "TestBot"
client.get_channel = MagicMock()
client.fetch_channel = AsyncMock()
client.get_message = MagicMock()
client.start = AsyncMock()
client.close = AsyncMock()
client.is_closed = MagicMock(return_value=False)
client.add_application_command = MagicMock()
client.sync_commands = AsyncMock()
client.change_presence = AsyncMock()
return client
+141
View File
@@ -0,0 +1,141 @@
"""Telegram 模块 Mock 工具。
提供统一的 Telegram 相关模块 mock 设置避免在测试文件中重复定义
"""
import sys
from unittest.mock import AsyncMock, MagicMock
import pytest
def create_mock_telegram_modules():
"""创建 Telegram 相关的 mock 模块。
Returns:
dict: 包含 telegram 和相关模块的 mock 对象
"""
mock_telegram = MagicMock()
mock_telegram.BotCommand = MagicMock
mock_telegram.Update = MagicMock
mock_telegram.constants = MagicMock()
mock_telegram.constants.ChatType = MagicMock()
mock_telegram.constants.ChatType.PRIVATE = "private"
mock_telegram.constants.ChatAction = MagicMock()
mock_telegram.constants.ChatAction.TYPING = "typing"
mock_telegram.constants.ChatAction.UPLOAD_VOICE = "upload_voice"
mock_telegram.constants.ChatAction.UPLOAD_DOCUMENT = "upload_document"
mock_telegram.constants.ChatAction.UPLOAD_PHOTO = "upload_photo"
mock_telegram.error = MagicMock()
mock_telegram.error.BadRequest = Exception
mock_telegram.ReactionTypeCustomEmoji = MagicMock
mock_telegram.ReactionTypeEmoji = MagicMock
mock_telegram_ext = MagicMock()
mock_telegram_ext.ApplicationBuilder = MagicMock
mock_telegram_ext.ContextTypes = MagicMock
mock_telegram_ext.ExtBot = MagicMock
mock_telegram_ext.filters = MagicMock()
mock_telegram_ext.filters.ALL = MagicMock()
mock_telegram_ext.MessageHandler = MagicMock
# Mock telegramify_markdown
mock_telegramify = MagicMock()
mock_telegramify.markdownify = lambda text, **kwargs: text
# Mock apscheduler
mock_apscheduler = MagicMock()
mock_apscheduler.schedulers = MagicMock()
mock_apscheduler.schedulers.asyncio = MagicMock()
mock_apscheduler.schedulers.asyncio.AsyncIOScheduler = MagicMock
mock_apscheduler.schedulers.background = MagicMock()
mock_apscheduler.schedulers.background.BackgroundScheduler = MagicMock
return {
"telegram": mock_telegram,
"telegram.ext": mock_telegram_ext,
"telegramify_markdown": mock_telegramify,
"apscheduler": mock_apscheduler,
}
@pytest.fixture(scope="module", autouse=True)
def mock_telegram_modules():
"""Mock Telegram 相关模块的 fixture。
自动应用于使用此 fixture 的测试模块
"""
mocks = create_mock_telegram_modules()
monkeypatch = pytest.MonkeyPatch()
monkeypatch.setitem(sys.modules, "telegram", mocks["telegram"])
monkeypatch.setitem(sys.modules, "telegram.constants", mocks["telegram"].constants)
monkeypatch.setitem(sys.modules, "telegram.error", mocks["telegram"].error)
monkeypatch.setitem(sys.modules, "telegram.ext", mocks["telegram.ext"])
monkeypatch.setitem(sys.modules, "telegramify_markdown", mocks["telegramify_markdown"])
monkeypatch.setitem(sys.modules, "apscheduler", mocks["apscheduler"])
monkeypatch.setitem(
sys.modules, "apscheduler.schedulers", mocks["apscheduler"].schedulers
)
monkeypatch.setitem(
sys.modules,
"apscheduler.schedulers.asyncio",
mocks["apscheduler"].schedulers.asyncio,
)
monkeypatch.setitem(
sys.modules,
"apscheduler.schedulers.background",
mocks["apscheduler"].schedulers.background,
)
yield
monkeypatch.undo()
class MockTelegramBuilder:
"""构建 Telegram 测试 mock 对象的工具类。"""
@staticmethod
def create_bot():
"""创建 mock Telegram bot 实例。"""
bot = MagicMock()
bot.username = "test_bot"
bot.id = 12345678
bot.base_url = "https://api.telegram.org/bottest_token_123/"
bot.send_message = AsyncMock()
bot.send_photo = AsyncMock()
bot.send_document = AsyncMock()
bot.send_voice = AsyncMock()
bot.send_chat_action = AsyncMock()
bot.delete_my_commands = AsyncMock()
bot.set_my_commands = AsyncMock()
bot.set_message_reaction = AsyncMock()
bot.edit_message_text = AsyncMock()
return bot
@staticmethod
def create_application():
"""创建 mock Telegram Application 实例。"""
from tests.fixtures.helpers import NoopAwaitable
app = MagicMock()
app.bot = MagicMock()
app.bot.username = "test_bot"
app.bot.base_url = "https://api.telegram.org/bottest_token_123/"
app.initialize = AsyncMock()
app.start = AsyncMock()
app.stop = AsyncMock()
app.add_handler = MagicMock()
app.updater = MagicMock()
app.updater.start_polling = MagicMock(return_value=NoopAwaitable())
app.updater.stop = AsyncMock()
return app
@staticmethod
def create_scheduler():
"""创建 mock APScheduler 实例。"""
scheduler = MagicMock()
scheduler.add_job = MagicMock()
scheduler.start = MagicMock()
scheduler.running = True
scheduler.shutdown = MagicMock()
return scheduler
+40
View File
@@ -0,0 +1,40 @@
"""
测试插件 - 用于插件系统测试
这是一个最小化的测试插件用于验证插件系统的功能
"""
from astrbot.api import llm_tool, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
@star.register("test_plugin", "AstrBot Team", "测试插件 - 用于插件系统测试", "1.0.0")
class TestPlugin(star.Star):
"""测试插件类"""
def __init__(self, context: star.Context) -> None:
super().__init__(context)
self.initialized = True
async def terminate(self) -> None:
"""插件终止"""
self.initialized = False
@filter.command("test_cmd")
async def test_command(self, event: AstrMessageEvent) -> None:
"""测试命令处理器。"""
event.set_result(MessageEventResult().message("测试命令执行成功"))
@llm_tool("test_tool")
async def test_llm_tool(self, query: str) -> str:
"""测试 LLM 工具。
Args:
query(string): 查询内容
"""
return f"测试工具执行成功: {query}"
@filter.regex(r"^test_regex_(.+)$")
async def test_regex_handler(self, event: AstrMessageEvent) -> None:
"""测试正则处理器。"""
event.set_result(MessageEventResult().message("正则匹配成功"))
+5
View File
@@ -0,0 +1,5 @@
name: test_plugin
description: 测试插件 - 用于插件系统测试
version: 1.0.0
author: AstrBot Team
repo: https://github.com/test/test_plugin
+115
View File
@@ -0,0 +1,115 @@
"""Smoke tests for critical startup and import paths."""
from __future__ import annotations
import subprocess
import sys
from pathlib import Path
from astrbot.core.pipeline.bootstrap import ensure_builtin_stages_registered
from astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal import (
InternalAgentSubStage,
)
from astrbot.core.pipeline.process_stage.method.agent_sub_stages.third_party import (
ThirdPartyAgentSubStage,
)
from astrbot.core.pipeline.stage import Stage, registered_stages
from astrbot.core.pipeline.stage_order import STAGES_ORDER
REPO_ROOT = Path(__file__).resolve().parents[1]
def _run_code_in_fresh_interpreter(code: str, failure_message: str) -> None:
proc = subprocess.run(
[sys.executable, "-c", code],
cwd=REPO_ROOT,
capture_output=True,
text=True,
check=False,
)
assert proc.returncode == 0, (
f"{failure_message}\nstdout:\n{proc.stdout}\nstderr:\n{proc.stderr}\n"
)
def test_smoke_critical_imports_in_fresh_interpreter() -> None:
code = (
"import importlib;"
"mods=["
"'astrbot.core.core_lifecycle',"
"'astrbot.core.astr_main_agent',"
"'astrbot.core.pipeline.scheduler',"
"'astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal',"
"'astrbot.core.pipeline.process_stage.method.agent_sub_stages.third_party'"
"];"
"[importlib.import_module(m) for m in mods]"
)
_run_code_in_fresh_interpreter(code, "Smoke import check failed.")
def test_smoke_pipeline_stage_registration_matches_order() -> None:
ensure_builtin_stages_registered()
stage_names = {cls.__name__ for cls in registered_stages}
assert set(STAGES_ORDER).issubset(stage_names)
assert len(stage_names) == len(registered_stages)
def test_smoke_agent_sub_stages_are_stage_subclasses() -> None:
assert issubclass(InternalAgentSubStage, Stage)
assert issubclass(ThirdPartyAgentSubStage, Stage)
def test_pipeline_package_exports_remain_compatible() -> None:
import astrbot.core.pipeline as pipeline
assert pipeline.ProcessStage is not None
assert pipeline.RespondStage is not None
assert isinstance(pipeline.STAGES_ORDER, list)
assert "ProcessStage" in pipeline.STAGES_ORDER
def test_builtin_stage_bootstrap_is_idempotent() -> None:
ensure_builtin_stages_registered()
before_count = len(registered_stages)
stage_names = {cls.__name__ for cls in registered_stages}
expected_stage_names = {
"WakingCheckStage",
"WhitelistCheckStage",
"SessionStatusCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",
}
assert expected_stage_names.issubset(stage_names)
ensure_builtin_stages_registered()
assert len(registered_stages) == before_count
def test_pipeline_import_is_stable_with_mocked_apscheduler() -> None:
"""Regression: importing pipeline should not require cron/apscheduler modules."""
code = (
"import sys;"
"from unittest.mock import MagicMock;"
"mock_apscheduler = MagicMock();"
"mock_apscheduler.schedulers = MagicMock();"
"mock_apscheduler.schedulers.asyncio = MagicMock();"
"mock_apscheduler.schedulers.background = MagicMock();"
"sys.modules['apscheduler'] = mock_apscheduler;"
"sys.modules['apscheduler.schedulers'] = mock_apscheduler.schedulers;"
"sys.modules['apscheduler.schedulers.asyncio'] = mock_apscheduler.schedulers.asyncio;"
"sys.modules['apscheduler.schedulers.background'] = mock_apscheduler.schedulers.background;"
"import astrbot.core.pipeline as pipeline;"
"assert pipeline.ProcessStage is not None;"
"assert pipeline.RespondStage is not None"
)
_run_code_in_fresh_interpreter(
code,
"Pipeline import should not depend on real apscheduler package.",
)
+781
View File
@@ -0,0 +1,781 @@
"""Tests for AstrMessageEvent class."""
import re
from unittest.mock import AsyncMock, patch
import pytest
from astrbot.core.message.components import (
At,
AtAll,
Face,
Forward,
Image,
Plain,
Reply,
)
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember
from astrbot.core.platform.message_type import MessageType
from astrbot.core.platform.platform_metadata import PlatformMetadata
class ConcreteAstrMessageEvent(AstrMessageEvent):
"""Concrete implementation of AstrMessageEvent for testing purposes."""
async def send(self, message):
"""Send message implementation."""
await super().send(message)
@pytest.fixture
def platform_meta():
"""Create platform metadata for testing."""
return PlatformMetadata(
name="test_platform",
description="Test platform",
id="test_platform_id",
)
@pytest.fixture
def message_member():
"""Create a message member for testing."""
return MessageMember(user_id="user123", nickname="TestUser")
@pytest.fixture
def astrbot_message(message_member):
"""Create an AstrBotMessage for testing."""
message = AstrBotMessage()
message.type = MessageType.FRIEND_MESSAGE
message.self_id = "bot123"
message.session_id = "session123"
message.message_id = "msg123"
message.sender = message_member
message.message = [Plain(text="Hello world")]
message.message_str = "Hello world"
message.raw_message = None
return message
@pytest.fixture
def astr_message_event(platform_meta, astrbot_message):
"""Create an AstrMessageEvent instance for testing."""
return ConcreteAstrMessageEvent(
message_str="Hello world",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
class TestAstrMessageEventInit:
"""Tests for AstrMessageEvent initialization."""
def test_init_basic(self, astr_message_event):
"""Test basic AstrMessageEvent initialization."""
assert astr_message_event.message_str == "Hello world"
assert astr_message_event.role == "member"
assert astr_message_event.is_wake is False
assert astr_message_event.is_at_or_wake_command is False
assert astr_message_event._extras == {}
assert astr_message_event._result is None
assert astr_message_event.call_llm is False
def test_init_session(self, astr_message_event):
"""Test session initialization."""
assert astr_message_event.session_id == "session123"
assert astr_message_event.session.platform_name == "test_platform_id"
def test_init_platform_reference(self, astr_message_event, platform_meta):
"""Test platform reference initialization."""
assert astr_message_event.platform_meta == platform_meta
assert astr_message_event.platform == platform_meta # back compatibility
def test_init_created_at(self, astr_message_event):
"""Test created_at timestamp is set."""
assert astr_message_event.created_at is not None
assert isinstance(astr_message_event.created_at, float)
def test_init_trace(self, astr_message_event):
"""Test trace/span initialization."""
assert astr_message_event.trace is not None
assert astr_message_event.span is not None
assert astr_message_event.trace == astr_message_event.span
class TestUnifiedMsgOrigin:
"""Tests for unified_msg_origin property."""
def test_unified_msg_origin_getter(self, astr_message_event):
"""Test unified_msg_origin getter."""
expected = "test_platform_id:FriendMessage:session123"
assert astr_message_event.unified_msg_origin == expected
def test_unified_msg_origin_setter(self, astr_message_event):
"""Test unified_msg_origin setter."""
astr_message_event.unified_msg_origin = "new_platform:GroupMessage:new_session"
assert astr_message_event.session.platform_name == "new_platform"
assert astr_message_event.session.session_id == "new_session"
class TestSessionId:
"""Tests for session_id property."""
def test_session_id_getter(self, astr_message_event):
"""Test session_id getter."""
assert astr_message_event.session_id == "session123"
def test_session_id_setter(self, astr_message_event):
"""Test session_id setter."""
astr_message_event.session_id = "new_session_id"
assert astr_message_event.session_id == "new_session_id"
class TestGetPlatformInfo:
"""Tests for platform info methods."""
def test_get_platform_name(self, astr_message_event):
"""Test get_platform_name method."""
assert astr_message_event.get_platform_name() == "test_platform"
def test_get_platform_id(self, astr_message_event):
"""Test get_platform_id method."""
assert astr_message_event.get_platform_id() == "test_platform_id"
class TestGetMessageInfo:
"""Tests for message info methods."""
def test_get_message_str(self, astr_message_event):
"""Test get_message_str method."""
assert astr_message_event.get_message_str() == "Hello world"
def test_get_message_str_none(self, platform_meta, astrbot_message):
"""Test get_message_str keeps None when source message_str is None."""
astrbot_message.message_str = None
event = ConcreteAstrMessageEvent(
message_str=None,
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.get_message_str() is None
def test_get_messages(self, astr_message_event):
"""Test get_messages method."""
messages = astr_message_event.get_messages()
assert len(messages) == 1
assert isinstance(messages[0], Plain)
assert messages[0].text == "Hello world"
def test_get_message_type(self, astr_message_event):
"""Test get_message_type method."""
assert astr_message_event.get_message_type() == MessageType.FRIEND_MESSAGE
def test_get_session_id(self, astr_message_event):
"""Test get_session_id method."""
assert astr_message_event.get_session_id() == "session123"
def test_get_group_id_empty_for_private(self, astr_message_event):
"""Test get_group_id returns empty for private messages."""
assert astr_message_event.get_group_id() == ""
def test_get_self_id(self, astr_message_event):
"""Test get_self_id method."""
assert astr_message_event.get_self_id() == "bot123"
def test_get_sender_id(self, astr_message_event):
"""Test get_sender_id method."""
assert astr_message_event.get_sender_id() == "user123"
def test_get_sender_name(self, astr_message_event):
"""Test get_sender_name method."""
assert astr_message_event.get_sender_name() == "TestUser"
def test_get_sender_name_empty_when_none(self, platform_meta, astrbot_message):
"""Test get_sender_name returns empty string when nickname is None."""
astrbot_message.sender = MessageMember(user_id="user123", nickname=None)
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.get_sender_name() == ""
def test_get_sender_name_coerces_non_string(self, platform_meta, astrbot_message):
"""Test get_sender_name stringifies non-string nickname values."""
astrbot_message.sender = MessageMember(user_id="user123", nickname=None)
astrbot_message.sender.nickname = 12345
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.get_sender_name() == "12345"
class TestGetMessageOutline:
"""Tests for get_message_outline method."""
def test_outline_plain_text(self, astr_message_event):
"""Test outline with plain text message."""
outline = astr_message_event.get_message_outline()
assert "Hello world" in outline
def test_outline_with_image(self, platform_meta, astrbot_message):
"""Test outline with image component."""
astrbot_message.message = [
Plain(text="Look at this"),
Image(file="http://example.com/img.jpg"),
]
event = ConcreteAstrMessageEvent(
message_str="Look at this",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "Look at this" in outline
assert "[图片]" in outline
def test_outline_with_at(self, platform_meta, astrbot_message):
"""Test outline with At component."""
astrbot_message.message = [At(qq="12345"), Plain(text=" hello")]
event = ConcreteAstrMessageEvent(
message_str=" hello",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "[At:12345]" in outline
def test_outline_with_at_all(self, platform_meta, astrbot_message):
"""Test outline with AtAll component."""
astrbot_message.message = [AtAll()]
event = ConcreteAstrMessageEvent(
message_str="",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
# AtAll format is "[At:all]" in the actual implementation
assert "[At:" in outline and "all" in outline.lower()
def test_outline_with_face(self, platform_meta, astrbot_message):
"""Test outline with Face component."""
astrbot_message.message = [Face(id="123")]
event = ConcreteAstrMessageEvent(
message_str="",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "[表情:123]" in outline
def test_outline_with_forward(self, platform_meta, astrbot_message):
"""Test outline with Forward component."""
# Forward requires an id parameter
astrbot_message.message = [Forward(id="test_forward_id")]
event = ConcreteAstrMessageEvent(
message_str="",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "[转发消息]" in outline
def test_outline_with_reply(self, platform_meta, astrbot_message):
"""Test outline with Reply component."""
# Reply requires an id parameter
reply = Reply(id="test_reply_id")
reply.message_str = "Original message"
reply.sender_nickname = "Sender"
astrbot_message.message = [reply, Plain(text=" reply")]
event = ConcreteAstrMessageEvent(
message_str=" reply",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "[引用消息(Sender: Original message)]" in outline
def test_outline_with_reply_no_message(self, platform_meta, astrbot_message):
"""Test outline with Reply component without message_str."""
# Reply requires an id parameter
reply = Reply(id="test_reply_id")
reply.message_str = None
astrbot_message.message = [reply]
event = ConcreteAstrMessageEvent(
message_str="",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert "[引用消息]" in outline
def test_outline_empty_chain(self, platform_meta, astrbot_message):
"""Test outline with empty message chain."""
astrbot_message.message = []
event = ConcreteAstrMessageEvent(
message_str="",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert outline == ""
def test_outline_very_long_plain_text(self, platform_meta, astrbot_message):
"""Test outline generation for very long plain text content."""
long_text = "A" * 20000
astrbot_message.message = [Plain(text=long_text)]
event = ConcreteAstrMessageEvent(
message_str=long_text,
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
outline = event.get_message_outline()
assert outline.startswith("A")
assert len(outline) >= 20000
class TestExtras:
"""Tests for extra information methods."""
def test_set_extra(self, astr_message_event):
"""Test set_extra method."""
astr_message_event.set_extra("key1", "value1")
assert astr_message_event._extras["key1"] == "value1"
def test_get_extra_with_key(self, astr_message_event):
"""Test get_extra with specific key."""
astr_message_event.set_extra("key1", "value1")
assert astr_message_event.get_extra("key1") == "value1"
def test_get_extra_with_default(self, astr_message_event):
"""Test get_extra with default value."""
result = astr_message_event.get_extra("nonexistent", "default_value")
assert result == "default_value"
def test_get_extra_all(self, astr_message_event):
"""Test get_extra without key returns all extras."""
astr_message_event.set_extra("key1", "value1")
astr_message_event.set_extra("key2", "value2")
all_extras = astr_message_event.get_extra()
assert all_extras == {"key1": "value1", "key2": "value2"}
def test_clear_extra(self, astr_message_event):
"""Test clear_extra method."""
astr_message_event.set_extra("key1", "value1")
astr_message_event.clear_extra()
assert astr_message_event._extras == {}
class TestSetResult:
"""Tests for set_result method."""
def test_set_result_with_message_event_result(self, astr_message_event):
"""Test set_result with MessageEventResult object."""
result = MessageEventResult().message("Test message")
astr_message_event.set_result(result)
assert astr_message_event._result == result
def test_set_result_with_string(self, astr_message_event):
"""Test set_result with string creates MessageEventResult."""
astr_message_event.set_result("Test message")
assert astr_message_event._result is not None
assert len(astr_message_event._result.chain) == 1
assert isinstance(astr_message_event._result.chain[0], Plain)
def test_set_result_with_empty_chain(self, astr_message_event):
"""Test set_result handles empty chain correctly."""
result = MessageEventResult()
# chain is already an empty list by default
astr_message_event.set_result(result)
assert astr_message_event._result.chain == []
class TestStopContinueEvent:
"""Tests for stop_event and continue_event methods."""
def test_stop_event_creates_result_if_none(self, astr_message_event):
"""Test stop_event creates result if none exists."""
astr_message_event.stop_event()
assert astr_message_event._result is not None
assert astr_message_event.is_stopped() is True
def test_stop_event_with_existing_result(self, astr_message_event):
"""Test stop_event with existing result."""
astr_message_event.set_result(MessageEventResult().message("Test"))
astr_message_event.stop_event()
assert astr_message_event.is_stopped() is True
def test_continue_event_creates_result_if_none(self, astr_message_event):
"""Test continue_event creates result if none exists."""
astr_message_event.continue_event()
assert astr_message_event._result is not None
assert astr_message_event.is_stopped() is False
def test_continue_event_with_existing_result(self, astr_message_event):
"""Test continue_event with existing result."""
astr_message_event.set_result(MessageEventResult().message("Test"))
astr_message_event.stop_event()
astr_message_event.continue_event()
assert astr_message_event.is_stopped() is False
def test_is_stopped_default_false(self, astr_message_event):
"""Test is_stopped returns False by default."""
assert astr_message_event.is_stopped() is False
class TestIsPrivateChat:
"""Tests for is_private_chat method."""
def test_is_private_chat_true(self, astr_message_event):
"""Test is_private_chat returns True for friend message."""
assert astr_message_event.is_private_chat() is True
def test_is_private_chat_false(self, platform_meta, astrbot_message):
"""Test is_private_chat returns False for group message."""
astrbot_message.type = MessageType.GROUP_MESSAGE
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=astrbot_message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.is_private_chat() is False
class TestIsWakeUp:
"""Tests for is_wake_up method."""
def test_is_wake_up_default_false(self, astr_message_event):
"""Test is_wake_up returns False by default."""
assert astr_message_event.is_wake_up() is False
def test_is_wake_up_when_set(self, astr_message_event):
"""Test is_wake_up returns True when is_wake is set."""
astr_message_event.is_wake = True
assert astr_message_event.is_wake_up() is True
class TestIsAdmin:
"""Tests for is_admin method."""
def test_is_admin_default_false(self, astr_message_event):
"""Test is_admin returns False by default."""
assert astr_message_event.is_admin() is False
def test_is_admin_when_admin(self, astr_message_event):
"""Test is_admin returns True when role is admin."""
astr_message_event.role = "admin"
assert astr_message_event.is_admin() is True
class TestProcessBuffer:
"""Tests for process_buffer method."""
@pytest.mark.asyncio
async def test_process_buffer_splits_by_pattern(self, astr_message_event):
"""Test process_buffer splits buffer by pattern."""
buffer = "Line 1\nLine 2\nLine 3\nRemaining"
pattern = re.compile(r".*\n")
with patch.object(
astr_message_event, "send", new_callable=AsyncMock
) as mock_send:
result = await astr_message_event.process_buffer(buffer, pattern)
# Should have sent 3 lines and remaining should be "Remaining"
assert mock_send.call_count == 3
assert result == "Remaining"
@pytest.mark.asyncio
async def test_process_buffer_no_match(self, astr_message_event):
"""Test process_buffer returns original when no match."""
buffer = "No newlines here"
pattern = re.compile(r"\n")
result = await astr_message_event.process_buffer(buffer, pattern)
assert result == "No newlines here"
class TestResultHelpers:
"""Tests for result helper methods."""
def test_make_result(self, astr_message_event):
"""Test make_result creates empty MessageEventResult."""
result = astr_message_event.make_result()
assert isinstance(result, MessageEventResult)
def test_plain_result(self, astr_message_event):
"""Test plain_result creates result with text."""
result = astr_message_event.plain_result("Hello")
assert isinstance(result, MessageEventResult)
assert len(result.chain) == 1
assert isinstance(result.chain[0], Plain)
assert result.chain[0].text == "Hello"
def test_image_result_url(self, astr_message_event):
"""Test image_result with URL."""
result = astr_message_event.image_result("http://example.com/image.jpg")
assert isinstance(result, MessageEventResult)
assert len(result.chain) == 1
assert isinstance(result.chain[0], Image)
def test_image_result_path(self, astr_message_event):
"""Test image_result with file path."""
result = astr_message_event.image_result("/path/to/image.jpg")
assert isinstance(result, MessageEventResult)
assert len(result.chain) == 1
assert isinstance(result.chain[0], Image)
class TestGetResult:
"""Tests for get_result and clear_result methods."""
def test_get_result_returns_none_by_default(self, astr_message_event):
"""Test get_result returns None by default."""
assert astr_message_event.get_result() is None
def test_get_result_returns_set_result(self, astr_message_event):
"""Test get_result returns set result."""
result = MessageEventResult().message("Test")
astr_message_event.set_result(result)
assert astr_message_event.get_result() == result
def test_clear_result(self, astr_message_event):
"""Test clear_result clears the result."""
astr_message_event.set_result(MessageEventResult().message("Test"))
astr_message_event.clear_result()
assert astr_message_event.get_result() is None
class TestShouldCallLlm:
"""Tests for should_call_llm method."""
def test_should_call_llm_default(self, astr_message_event):
"""Test call_llm default is False."""
assert astr_message_event.call_llm is False
def test_should_call_llm_when_set(self, astr_message_event):
"""Test should_call_llm sets call_llm."""
astr_message_event.should_call_llm(True)
assert astr_message_event.call_llm is True
class TestRequestLlm:
"""Tests for request_llm method."""
def test_request_llm_basic(self, astr_message_event):
"""Test request_llm creates ProviderRequest."""
request = astr_message_event.request_llm(prompt="Hello")
assert request.prompt == "Hello"
assert request.session_id == ""
assert request.image_urls == []
assert request.contexts == []
def test_request_llm_with_all_params(self, astr_message_event):
"""Test request_llm with all parameters."""
request = astr_message_event.request_llm(
prompt="Hello",
session_id="session123",
image_urls=["http://example.com/img.jpg"],
contexts=[{"role": "user", "content": "Hi"}],
system_prompt="You are helpful",
)
assert request.prompt == "Hello"
assert request.session_id == "session123"
assert request.image_urls == ["http://example.com/img.jpg"]
assert request.contexts == [{"role": "user", "content": "Hi"}]
assert request.system_prompt == "You are helpful"
class TestSendStreaming:
"""Tests for send_streaming method."""
@pytest.mark.asyncio
async def test_send_streaming_sets_has_send_oper(self, astr_message_event):
"""Test send_streaming sets _has_send_oper flag."""
assert astr_message_event._has_send_oper is False
async def generator():
yield MessageEventResult().message("Test")
with patch(
"astrbot.core.platform.astr_message_event.Metric.upload",
new_callable=AsyncMock,
):
await astr_message_event.send_streaming(generator())
assert astr_message_event._has_send_oper is True
class TestSendTyping:
"""Tests for send_typing method."""
@pytest.mark.asyncio
async def test_send_typing_default_empty(self, astr_message_event):
"""Test send_typing default implementation is empty."""
# Should not raise any exception
await astr_message_event.send_typing()
class TestReact:
"""Tests for react method."""
@pytest.mark.asyncio
async def test_react_sends_emoji(self, astr_message_event):
"""Test react sends emoji as message."""
with patch.object(
astr_message_event, "send", new_callable=AsyncMock
) as mock_send:
await astr_message_event.react("👍")
mock_send.assert_called_once()
call_arg = mock_send.call_args[0][0]
# MessageChain is a dataclass with chain attribute
assert len(call_arg.chain) == 1
assert isinstance(call_arg.chain[0], Plain)
assert call_arg.chain[0].text == "👍"
class TestGetGroup:
"""Tests for get_group method."""
@pytest.mark.asyncio
async def test_get_group_returns_none_for_private(self, astr_message_event):
"""Test get_group returns None for private chat."""
result = await astr_message_event.get_group()
assert result is None
@pytest.mark.asyncio
async def test_get_group_with_group_id_param(self, astr_message_event):
"""Test get_group with group_id parameter."""
# Default implementation returns None
result = await astr_message_event.get_group(group_id="group123")
assert result is None
class TestMessageTypeHandling:
"""Tests for message type handling edge cases."""
def test_message_type_from_valid_string(self, platform_meta):
"""Valid MessageType string should be converted correctly."""
message = AstrBotMessage()
message.type = "FRIEND_MESSAGE"
message.message = []
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.session.message_type == MessageType.FRIEND_MESSAGE
assert event.get_message_type() == MessageType.FRIEND_MESSAGE
def test_message_type_from_invalid_string_defaults_to_friend(self, platform_meta):
"""Invalid message type should default to FRIEND_MESSAGE."""
message = AstrBotMessage()
message.type = "InvalidMessageType"
message.message = []
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.session.message_type == MessageType.FRIEND_MESSAGE
assert event.get_message_type() == MessageType.FRIEND_MESSAGE
def test_message_type_from_none_defaults_to_friend(self, platform_meta):
"""None message type should default to FRIEND_MESSAGE."""
message = AstrBotMessage()
message.type = None
message.message = []
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.session.message_type == MessageType.FRIEND_MESSAGE
assert event.get_message_type() == MessageType.FRIEND_MESSAGE
def test_message_type_from_integer_defaults_to_friend(self, platform_meta):
"""Integer message type should default to FRIEND_MESSAGE."""
message = AstrBotMessage()
message.type = 123
message.message = []
event = ConcreteAstrMessageEvent(
message_str="test",
message_obj=message,
platform_meta=platform_meta,
session_id="session123",
)
assert event.session.message_type == MessageType.FRIEND_MESSAGE
assert event.get_message_type() == MessageType.FRIEND_MESSAGE
class TestDefensiveGetattr:
"""Tests for defensive getattr behavior in AstrMessageEvent."""
def test_get_messages_without_message_attr(self, astr_message_event):
"""get_messages should handle message_obj without 'message' attribute."""
astr_message_event.message_obj = type("DummyMessage", (), {})()
messages = astr_message_event.get_messages()
assert isinstance(messages, list)
def test_get_message_type_without_type_attr(self, astr_message_event):
"""get_message_type should handle message_obj without 'type' attribute."""
astr_message_event.message_obj = type("DummyMessage", (), {})()
message_type = astr_message_event.get_message_type()
assert isinstance(message_type, MessageType)
def test_get_sender_fields_without_sender_attr(self, astr_message_event):
"""get_sender_id and get_sender_name should handle missing 'sender'."""
astr_message_event.message_obj = type("DummyMessage", (), {})()
sender_id = astr_message_event.get_sender_id()
sender_name = astr_message_event.get_sender_name()
assert isinstance(sender_id, str)
assert isinstance(sender_name, str)
def test_get_message_type_with_non_enum_type(self, astr_message_event):
"""get_message_type should handle message_obj.type that is not a MessageType."""
class DummyMessage:
def __init__(self):
self.type = "not_an_enum"
self.message = []
astr_message_event.message_obj = DummyMessage()
message_type = astr_message_event.get_message_type()
assert isinstance(message_type, MessageType)
+268
View File
@@ -0,0 +1,268 @@
"""Tests for AstrBotMessage and MessageMember classes."""
import time
from unittest.mock import patch
from astrbot.core.message.components import Image, Plain
from astrbot.core.platform.astrbot_message import AstrBotMessage, Group, MessageMember
from astrbot.core.platform.message_type import MessageType
class TestMessageMember:
"""Tests for MessageMember dataclass."""
def test_message_member_creation_basic(self):
"""Test creating a MessageMember with required fields."""
member = MessageMember(user_id="user123")
assert member.user_id == "user123"
assert member.nickname is None
def test_message_member_creation_with_nickname(self):
"""Test creating a MessageMember with nickname."""
member = MessageMember(user_id="user123", nickname="TestUser")
assert member.user_id == "user123"
assert member.nickname == "TestUser"
def test_message_member_str_with_nickname(self):
"""Test __str__ method with nickname."""
member = MessageMember(user_id="user123", nickname="TestUser")
result = str(member)
assert "User ID: user123" in result
assert "Nickname: TestUser" in result
def test_message_member_str_without_nickname(self):
"""Test __str__ method without nickname."""
member = MessageMember(user_id="user123")
result = str(member)
assert "User ID: user123" in result
assert "Nickname: N/A" in result
class TestGroup:
"""Tests for Group dataclass."""
def test_group_creation_basic(self):
"""Test creating a Group with required fields."""
group = Group(group_id="group123")
assert group.group_id == "group123"
assert group.group_name is None
assert group.group_avatar is None
assert group.group_owner is None
assert group.group_admins is None
assert group.members is None
def test_group_creation_with_all_fields(self):
"""Test creating a Group with all fields."""
members = [MessageMember(user_id="user1"), MessageMember(user_id="user2")]
group = Group(
group_id="group123",
group_name="Test Group",
group_avatar="http://example.com/avatar.jpg",
group_owner="owner123",
group_admins=["admin1", "admin2"],
members=members,
)
assert group.group_id == "group123"
assert group.group_name == "Test Group"
assert group.group_avatar == "http://example.com/avatar.jpg"
assert group.group_owner == "owner123"
assert group.group_admins == ["admin1", "admin2"]
assert group.members == members
def test_group_str_with_all_fields(self):
"""Test __str__ method with all fields."""
members = [MessageMember(user_id="user1", nickname="User One")]
group = Group(
group_id="group123",
group_name="Test Group",
group_avatar="http://example.com/avatar.jpg",
group_owner="owner123",
group_admins=["admin1"],
members=members,
)
result = str(group)
assert "Group ID: group123" in result
assert "Name: Test Group" in result
assert "Avatar: http://example.com/avatar.jpg" in result
assert "Owner ID: owner123" in result
assert "Admin IDs: ['admin1']" in result
assert "Members Len: 1" in result
def test_group_str_with_minimal_fields(self):
"""Test __str__ method with minimal fields."""
group = Group(group_id="group123")
result = str(group)
assert "Group ID: group123" in result
assert "Name: N/A" in result
assert "Avatar: N/A" in result
assert "Owner ID: N/A" in result
assert "Admin IDs: N/A" in result
assert "Members Len: 0" in result
assert "First Member: N/A" in result
class TestAstrBotMessage:
"""Tests for AstrBotMessage class."""
def test_astrbot_message_creation(self):
"""Test creating an AstrBotMessage."""
message = AstrBotMessage()
assert message.group is None
assert message.timestamp is not None
assert isinstance(message.timestamp, int)
def test_astrbot_message_timestamp(self):
"""Test timestamp is set on creation."""
with patch.object(time, "time", return_value=1234567890):
message = AstrBotMessage()
assert message.timestamp == 1234567890
def test_astrbot_message_all_attributes(self):
"""Test setting all attributes on AstrBotMessage."""
message = AstrBotMessage()
message.type = MessageType.FRIEND_MESSAGE
message.self_id = "bot123"
message.session_id = "session123"
message.message_id = "msg123"
message.sender = MessageMember(user_id="user123", nickname="TestUser")
message.message = [Plain(text="Hello")]
message.message_str = "Hello"
message.raw_message = {"raw": "data"}
assert message.type == MessageType.FRIEND_MESSAGE
assert message.self_id == "bot123"
assert message.session_id == "session123"
assert message.message_id == "msg123"
assert message.sender.user_id == "user123"
assert len(message.message) == 1
assert message.message_str == "Hello"
assert message.raw_message == {"raw": "data"}
def test_astrbot_message_str(self):
"""Test __str__ method."""
message = AstrBotMessage()
message.type = MessageType.FRIEND_MESSAGE
message.self_id = "bot123"
result = str(message)
assert "'type'" in result
assert "'self_id'" in result
class TestAstrBotMessageGroupId:
"""Tests for AstrBotMessage group_id property."""
def test_group_id_returns_empty_when_no_group(self):
"""Test group_id returns empty string when group is None."""
message = AstrBotMessage()
assert message.group_id == ""
def test_group_id_returns_group_id_when_group_exists(self):
"""Test group_id returns the group's id when group exists."""
message = AstrBotMessage()
message.group = Group(group_id="group123")
assert message.group_id == "group123"
def test_group_id_setter_creates_new_group(self):
"""Test group_id setter creates a new group if none exists."""
message = AstrBotMessage()
message.group_id = "new_group123"
assert message.group is not None
assert message.group.group_id == "new_group123"
def test_group_id_setter_updates_existing_group(self):
"""Test group_id setter updates existing group's id."""
message = AstrBotMessage()
message.group = Group(group_id="old_group")
message.group_id = "new_group"
assert message.group.group_id == "new_group"
def test_group_id_setter_with_none_removes_group(self):
"""Test group_id setter with None removes the group."""
message = AstrBotMessage()
message.group = Group(group_id="group123")
message.group_id = None
assert message.group is None
def test_group_id_setter_with_empty_string_removes_group(self):
"""Test group_id setter with empty string removes the group."""
message = AstrBotMessage()
message.group = Group(group_id="group123")
message.group_id = ""
assert message.group is None
class TestAstrBotMessageTypes:
"""Tests for AstrBotMessage with different message types."""
def test_friend_message_type(self):
"""Test AstrBotMessage with FRIEND_MESSAGE type."""
message = AstrBotMessage()
message.type = MessageType.FRIEND_MESSAGE
assert message.type == MessageType.FRIEND_MESSAGE
assert message.type.value == "FriendMessage"
def test_group_message_type(self):
"""Test AstrBotMessage with GROUP_MESSAGE type."""
message = AstrBotMessage()
message.type = MessageType.GROUP_MESSAGE
assert message.type == MessageType.GROUP_MESSAGE
assert message.type.value == "GroupMessage"
def test_other_message_type(self):
"""Test AstrBotMessage with OTHER_MESSAGE type."""
message = AstrBotMessage()
message.type = MessageType.OTHER_MESSAGE
assert message.type == MessageType.OTHER_MESSAGE
assert message.type.value == "OtherMessage"
class TestAstrBotMessageChain:
"""Tests for AstrBotMessage message chain."""
def test_message_chain_with_plain_text(self):
"""Test message chain with plain text."""
message = AstrBotMessage()
message.message = [Plain(text="Hello world")]
assert len(message.message) == 1
assert isinstance(message.message[0], Plain)
assert message.message[0].text == "Hello world"
def test_message_chain_with_multiple_components(self):
"""Test message chain with multiple components."""
message = AstrBotMessage()
message.message = [
Plain(text="Hello "),
Plain(text="world"),
Image(file="http://example.com/img.jpg"),
]
assert len(message.message) == 3
assert isinstance(message.message[0], Plain)
assert isinstance(message.message[1], Plain)
assert isinstance(message.message[2], Image)
def test_message_chain_empty(self):
"""Test empty message chain."""
message = AstrBotMessage()
message.message = []
assert len(message.message) == 0