Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 602ae4eee2 | |||
| 3d82f42311 | |||
| 5530a2260a | |||
| c24de24ca4 | |||
| b54b4c79ed | |||
| c6cc7aae84 | |||
| 84cd209074 | |||
| afda44fbe3 | |||
| f5d3b93437 | |||
| 069a3628fa | |||
| c81ef2672a | |||
| a5ae27cae0 | |||
| 73faaf6577 | |||
| 29dbd085d4 | |||
| 00b011809a | |||
| 0b46ca7ff3 | |||
| 9294b44831 | |||
| 80fd51119b | |||
| 5af5ad9e36 | |||
| 7b731ebda8 | |||
| aec2e3bb91 |
@@ -37,7 +37,7 @@ jobs:
|
||||
mkdir -p data/temp
|
||||
export TESTING=true
|
||||
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
||||
pytest --cov=. -v -o log_cli=true -o log_level=DEBUG
|
||||
pytest --cov=astrbot -v -o log_cli=true -o log_level=DEBUG
|
||||
|
||||
- name: Upload results to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.18.1"
|
||||
__version__ = "4.18.2"
|
||||
|
||||
@@ -24,15 +24,77 @@ def _should_stop_agent(astr_event) -> bool:
|
||||
return astr_event.is_stopped() or bool(astr_event.get_extra("agent_stop_requested"))
|
||||
|
||||
|
||||
def _truncate_tool_result(text: str, limit: int = 70) -> str:
|
||||
if limit <= 0:
|
||||
return ""
|
||||
if len(text) <= limit:
|
||||
return text
|
||||
if limit <= 3:
|
||||
return text[:limit]
|
||||
return f"{text[: limit - 3]}..."
|
||||
|
||||
|
||||
def _extract_chain_json_data(msg_chain: MessageChain) -> dict | None:
|
||||
if not msg_chain.chain:
|
||||
return None
|
||||
first_comp = msg_chain.chain[0]
|
||||
if isinstance(first_comp, Json) and isinstance(first_comp.data, dict):
|
||||
return first_comp.data
|
||||
return None
|
||||
|
||||
|
||||
def _record_tool_call_name(
|
||||
tool_info: dict | None, tool_name_by_call_id: dict[str, str]
|
||||
) -> None:
|
||||
if not isinstance(tool_info, dict):
|
||||
return
|
||||
tool_call_id = tool_info.get("id")
|
||||
tool_name = tool_info.get("name")
|
||||
if tool_call_id is None or tool_name is None:
|
||||
return
|
||||
tool_name_by_call_id[str(tool_call_id)] = str(tool_name)
|
||||
|
||||
|
||||
def _build_tool_call_status_message(tool_info: dict | None) -> str:
|
||||
if tool_info:
|
||||
return f"🔨 调用工具: {tool_info.get('name', 'unknown')}"
|
||||
return "🔨 调用工具..."
|
||||
|
||||
|
||||
def _build_tool_result_status_message(
|
||||
msg_chain: MessageChain, tool_name_by_call_id: dict[str, str]
|
||||
) -> str:
|
||||
tool_name = "unknown"
|
||||
tool_result = ""
|
||||
|
||||
result_data = _extract_chain_json_data(msg_chain)
|
||||
if result_data:
|
||||
tool_call_id = result_data.get("id")
|
||||
if tool_call_id is not None:
|
||||
tool_name = tool_name_by_call_id.pop(str(tool_call_id), "unknown")
|
||||
tool_result = str(result_data.get("result", ""))
|
||||
|
||||
if not tool_result:
|
||||
tool_result = msg_chain.get_plain_text(with_other_comps_mark=True)
|
||||
tool_result = _truncate_tool_result(tool_result, 70)
|
||||
|
||||
status_msg = f"🔨 调用工具: {tool_name}"
|
||||
if tool_result:
|
||||
status_msg = f"{status_msg}\n📎 返回结果: {tool_result}"
|
||||
return status_msg
|
||||
|
||||
|
||||
async def run_agent(
|
||||
agent_runner: AgentRunner,
|
||||
max_step: int = 30,
|
||||
show_tool_use: bool = True,
|
||||
show_tool_call_result: bool = False,
|
||||
stream_to_general: bool = False,
|
||||
show_reasoning: bool = False,
|
||||
) -> AsyncGenerator[MessageChain | None, None]:
|
||||
step_idx = 0
|
||||
astr_event = agent_runner.run_context.context.event
|
||||
tool_name_by_call_id: dict[str, str] = {}
|
||||
while step_idx < max_step + 1:
|
||||
step_idx += 1
|
||||
|
||||
@@ -90,6 +152,13 @@ async def run_agent(
|
||||
continue
|
||||
if astr_event.get_platform_id() == "webchat":
|
||||
await astr_event.send(msg_chain)
|
||||
elif show_tool_use and show_tool_call_result:
|
||||
status_msg = _build_tool_result_status_message(
|
||||
msg_chain, tool_name_by_call_id
|
||||
)
|
||||
await astr_event.send(
|
||||
MessageChain(type="tool_call").message(status_msg)
|
||||
)
|
||||
# 对于其他情况,暂时先不处理
|
||||
continue
|
||||
elif resp.type == "tool_call":
|
||||
@@ -97,25 +166,22 @@ async def run_agent(
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
|
||||
tool_info = None
|
||||
|
||||
if resp.data["chain"].chain:
|
||||
json_comp = resp.data["chain"].chain[0]
|
||||
if isinstance(json_comp, Json):
|
||||
tool_info = json_comp.data
|
||||
astr_event.trace.record(
|
||||
"agent_tool_call",
|
||||
tool_name=tool_info if tool_info else "unknown",
|
||||
)
|
||||
tool_info = _extract_chain_json_data(resp.data["chain"])
|
||||
astr_event.trace.record(
|
||||
"agent_tool_call",
|
||||
tool_name=tool_info if tool_info else "unknown",
|
||||
)
|
||||
_record_tool_call_name(tool_info, tool_name_by_call_id)
|
||||
|
||||
if astr_event.get_platform_name() == "webchat":
|
||||
await astr_event.send(resp.data["chain"])
|
||||
elif show_tool_use:
|
||||
if tool_info:
|
||||
m = f"🔨 调用工具: {tool_info.get('name', 'unknown')}"
|
||||
else:
|
||||
m = "🔨 调用工具..."
|
||||
chain = MessageChain(type="tool_call").message(m)
|
||||
if show_tool_call_result and isinstance(tool_info, dict):
|
||||
# Delay tool status notification until tool_call_result.
|
||||
continue
|
||||
chain = MessageChain(type="tool_call").message(
|
||||
_build_tool_call_status_message(tool_info)
|
||||
)
|
||||
await astr_event.send(chain)
|
||||
continue
|
||||
|
||||
@@ -202,6 +268,7 @@ async def run_live_agent(
|
||||
tts_provider: TTSProvider | None = None,
|
||||
max_step: int = 30,
|
||||
show_tool_use: bool = True,
|
||||
show_tool_call_result: bool = False,
|
||||
show_reasoning: bool = False,
|
||||
) -> AsyncGenerator[MessageChain | None, None]:
|
||||
"""Live Mode 的 Agent 运行器,支持流式 TTS
|
||||
@@ -211,6 +278,7 @@ async def run_live_agent(
|
||||
tts_provider: TTS Provider 实例
|
||||
max_step: 最大步数
|
||||
show_tool_use: 是否显示工具使用
|
||||
show_tool_call_result: 是否显示工具返回结果
|
||||
show_reasoning: 是否显示推理过程
|
||||
|
||||
Yields:
|
||||
@@ -222,6 +290,7 @@ async def run_live_agent(
|
||||
agent_runner,
|
||||
max_step=max_step,
|
||||
show_tool_use=show_tool_use,
|
||||
show_tool_call_result=show_tool_call_result,
|
||||
stream_to_general=False,
|
||||
show_reasoning=show_reasoning,
|
||||
):
|
||||
@@ -250,7 +319,12 @@ async def run_live_agent(
|
||||
# 1. 启动 Agent Feeder 任务:负责运行 Agent 并将文本分句喂给 text_queue
|
||||
feeder_task = asyncio.create_task(
|
||||
_run_agent_feeder(
|
||||
agent_runner, text_queue, max_step, show_tool_use, show_reasoning
|
||||
agent_runner,
|
||||
text_queue,
|
||||
max_step,
|
||||
show_tool_use,
|
||||
show_tool_call_result,
|
||||
show_reasoning,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -336,6 +410,7 @@ async def _run_agent_feeder(
|
||||
text_queue: asyncio.Queue,
|
||||
max_step: int,
|
||||
show_tool_use: bool,
|
||||
show_tool_call_result: bool,
|
||||
show_reasoning: bool,
|
||||
) -> None:
|
||||
"""运行 Agent 并将文本输出分句放入队列"""
|
||||
@@ -345,6 +420,7 @@ async def _run_agent_feeder(
|
||||
agent_runner,
|
||||
max_step=max_step,
|
||||
show_tool_use=show_tool_use,
|
||||
show_tool_call_result=show_tool_call_result,
|
||||
stream_to_general=False,
|
||||
show_reasoning=show_reasoning,
|
||||
):
|
||||
|
||||
@@ -17,6 +17,12 @@ from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.astr_main_agent_resources import (
|
||||
BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT,
|
||||
EXECUTE_SHELL_TOOL,
|
||||
FILE_DOWNLOAD_TOOL,
|
||||
FILE_UPLOAD_TOOL,
|
||||
LOCAL_EXECUTE_SHELL_TOOL,
|
||||
LOCAL_PYTHON_TOOL,
|
||||
PYTHON_TOOL,
|
||||
SEND_MESSAGE_TO_USER_TOOL,
|
||||
)
|
||||
from astrbot.core.cron.events import CronMessageEvent
|
||||
@@ -91,6 +97,65 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
yield r
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def _get_runtime_computer_tools(cls, runtime: str) -> dict[str, FunctionTool]:
|
||||
if runtime == "sandbox":
|
||||
return {
|
||||
EXECUTE_SHELL_TOOL.name: EXECUTE_SHELL_TOOL,
|
||||
PYTHON_TOOL.name: PYTHON_TOOL,
|
||||
FILE_UPLOAD_TOOL.name: FILE_UPLOAD_TOOL,
|
||||
FILE_DOWNLOAD_TOOL.name: FILE_DOWNLOAD_TOOL,
|
||||
}
|
||||
if runtime == "local":
|
||||
return {
|
||||
LOCAL_EXECUTE_SHELL_TOOL.name: LOCAL_EXECUTE_SHELL_TOOL,
|
||||
LOCAL_PYTHON_TOOL.name: LOCAL_PYTHON_TOOL,
|
||||
}
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def _build_handoff_toolset(
|
||||
cls,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
tools: list[str | FunctionTool] | None,
|
||||
) -> ToolSet | None:
|
||||
ctx = run_context.context.context
|
||||
event = run_context.context.event
|
||||
cfg = ctx.get_config(umo=event.unified_msg_origin)
|
||||
provider_settings = cfg.get("provider_settings", {})
|
||||
runtime = str(provider_settings.get("computer_use_runtime", "local"))
|
||||
runtime_computer_tools = cls._get_runtime_computer_tools(runtime)
|
||||
|
||||
# Keep persona semantics aligned with the main agent: tools=None means
|
||||
# "all tools", including runtime computer-use tools.
|
||||
if tools is None:
|
||||
toolset = ToolSet()
|
||||
for registered_tool in llm_tools.func_list:
|
||||
if isinstance(registered_tool, HandoffTool):
|
||||
continue
|
||||
if registered_tool.active:
|
||||
toolset.add_tool(registered_tool)
|
||||
for runtime_tool in runtime_computer_tools.values():
|
||||
toolset.add_tool(runtime_tool)
|
||||
return None if toolset.empty() else toolset
|
||||
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
toolset = ToolSet()
|
||||
for tool_name_or_obj in tools:
|
||||
if isinstance(tool_name_or_obj, str):
|
||||
registered_tool = llm_tools.get_func(tool_name_or_obj)
|
||||
if registered_tool and registered_tool.active:
|
||||
toolset.add_tool(registered_tool)
|
||||
continue
|
||||
runtime_tool = runtime_computer_tools.get(tool_name_or_obj)
|
||||
if runtime_tool:
|
||||
toolset.add_tool(runtime_tool)
|
||||
elif isinstance(tool_name_or_obj, FunctionTool):
|
||||
toolset.add_tool(tool_name_or_obj)
|
||||
return None if toolset.empty() else toolset
|
||||
|
||||
@classmethod
|
||||
async def _execute_handoff(
|
||||
cls,
|
||||
@@ -101,19 +166,8 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
input_ = tool_args.get("input")
|
||||
image_urls = tool_args.get("image_urls")
|
||||
|
||||
# make toolset for the agent
|
||||
tools = tool.agent.tools
|
||||
if tools:
|
||||
toolset = ToolSet()
|
||||
for t in tools:
|
||||
if isinstance(t, str):
|
||||
_t = llm_tools.get_func(t)
|
||||
if _t:
|
||||
toolset.add_tool(_t)
|
||||
elif isinstance(t, FunctionTool):
|
||||
toolset.add_tool(t)
|
||||
else:
|
||||
toolset = None
|
||||
# Build handoff toolset from registered tools plus runtime computer tools.
|
||||
toolset = cls._build_handoff_toolset(run_context, tool.agent.tools)
|
||||
|
||||
ctx = run_context.context.context
|
||||
event = run_context.context.event
|
||||
|
||||
@@ -11,6 +11,7 @@ from astrbot.core.message.components import File
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
|
||||
from ..computer_client import get_booter
|
||||
from .permissions import check_admin_permission
|
||||
|
||||
# @dataclass
|
||||
# class CreateFileTool(FunctionTool):
|
||||
@@ -102,6 +103,8 @@ class FileUploadTool(FunctionTool):
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
local_path: str,
|
||||
) -> str | None:
|
||||
if permission_error := check_admin_permission(context, "File upload/download"):
|
||||
return permission_error
|
||||
sb = await get_booter(
|
||||
context.context.context,
|
||||
context.context.event.unified_msg_origin,
|
||||
@@ -161,6 +164,8 @@ class FileDownloadTool(FunctionTool):
|
||||
remote_path: str,
|
||||
also_send_to_user: bool = True,
|
||||
) -> ToolExecResult:
|
||||
if permission_error := check_admin_permission(context, "File upload/download"):
|
||||
return permission_error
|
||||
sb = await get_booter(
|
||||
context.context.context,
|
||||
context.context.event.unified_msg_origin,
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
|
||||
|
||||
def check_admin_permission(
|
||||
context: ContextWrapper[AstrAgentContext], operation_name: str
|
||||
) -> str | None:
|
||||
cfg = context.context.context.get_config(
|
||||
umo=context.context.event.unified_msg_origin
|
||||
)
|
||||
provider_settings = cfg.get("provider_settings", {})
|
||||
require_admin = provider_settings.get("computer_use_require_admin", True)
|
||||
if require_admin and context.context.event.role != "admin":
|
||||
return (
|
||||
f"error: Permission denied. {operation_name} is only allowed for admin users. "
|
||||
"Tell user to set admins in `AstrBot WebUI -> Config -> General Config` by adding their user ID to the admins list if they need this feature. "
|
||||
f"User's ID is: {context.context.event.get_sender_id()}. User's ID can be found by using /sid command."
|
||||
)
|
||||
return None
|
||||
@@ -7,6 +7,7 @@ from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import ToolExecResult
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext, AstrMessageEvent
|
||||
from astrbot.core.computer.computer_client import get_booter, get_local_booter
|
||||
from astrbot.core.computer.tools.permissions import check_admin_permission
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
param_schema = {
|
||||
@@ -26,21 +27,6 @@ param_schema = {
|
||||
}
|
||||
|
||||
|
||||
def _check_admin_permission(context: ContextWrapper[AstrAgentContext]) -> str | None:
|
||||
cfg = context.context.context.get_config(
|
||||
umo=context.context.event.unified_msg_origin
|
||||
)
|
||||
provider_settings = cfg.get("provider_settings", {})
|
||||
require_admin = provider_settings.get("computer_use_require_admin", True)
|
||||
if require_admin and context.context.event.role != "admin":
|
||||
return (
|
||||
"error: Permission denied. Python execution is only allowed for admin users. "
|
||||
"Tell user to set admins in `AstrBot WebUI -> Config -> General Config` by adding their user ID to the admins list if they need this feature."
|
||||
f"User's ID is: {context.context.event.get_sender_id()}. User's ID can be found by using /sid command."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def handle_result(result: dict, event: AstrMessageEvent) -> ToolExecResult:
|
||||
data = result.get("data", {})
|
||||
output = data.get("output", {})
|
||||
@@ -81,7 +67,7 @@ class PythonTool(FunctionTool):
|
||||
async def call(
|
||||
self, context: ContextWrapper[AstrAgentContext], code: str, silent: bool = False
|
||||
) -> ToolExecResult:
|
||||
if permission_error := _check_admin_permission(context):
|
||||
if permission_error := check_admin_permission(context, "Python execution"):
|
||||
return permission_error
|
||||
sb = await get_booter(
|
||||
context.context.context,
|
||||
@@ -104,7 +90,7 @@ class LocalPythonTool(FunctionTool):
|
||||
async def call(
|
||||
self, context: ContextWrapper[AstrAgentContext], code: str, silent: bool = False
|
||||
) -> ToolExecResult:
|
||||
if permission_error := _check_admin_permission(context):
|
||||
if permission_error := check_admin_permission(context, "Python execution"):
|
||||
return permission_error
|
||||
sb = get_local_booter()
|
||||
try:
|
||||
|
||||
@@ -7,21 +7,7 @@ from astrbot.core.agent.tool import ToolExecResult
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
|
||||
from ..computer_client import get_booter, get_local_booter
|
||||
|
||||
|
||||
def _check_admin_permission(context: ContextWrapper[AstrAgentContext]) -> str | None:
|
||||
cfg = context.context.context.get_config(
|
||||
umo=context.context.event.unified_msg_origin
|
||||
)
|
||||
provider_settings = cfg.get("provider_settings", {})
|
||||
require_admin = provider_settings.get("computer_use_require_admin", True)
|
||||
if require_admin and context.context.event.role != "admin":
|
||||
return (
|
||||
"error: Permission denied. Shell execution is only allowed for admin users. "
|
||||
"Tell user to set admins in `AstrBot WebUI -> Config -> General Config` by adding their user ID to the admins list if they need this feature."
|
||||
f"User's ID is: {context.context.event.get_sender_id()}. User's ID can be found by using /sid command."
|
||||
)
|
||||
return None
|
||||
from .permissions import check_admin_permission
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -61,7 +47,7 @@ class ExecuteShellTool(FunctionTool):
|
||||
background: bool = False,
|
||||
env: dict = {},
|
||||
) -> ToolExecResult:
|
||||
if permission_error := _check_admin_permission(context):
|
||||
if permission_error := check_admin_permission(context, "Shell execution"):
|
||||
return permission_error
|
||||
|
||||
if self.is_local:
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.18.1"
|
||||
VERSION = "4.18.2"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
@@ -100,6 +100,7 @@ DEFAULT_CONFIG = {
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
"show_tool_use_status": False,
|
||||
"show_tool_call_result": False,
|
||||
"sanitize_context_by_modalities": False,
|
||||
"max_quoted_fallback_images": 20,
|
||||
"quoted_message_parser": {
|
||||
@@ -2306,6 +2307,9 @@ CONFIG_METADATA_2 = {
|
||||
"show_tool_use_status": {
|
||||
"type": "bool",
|
||||
},
|
||||
"show_tool_call_result": {
|
||||
"type": "bool",
|
||||
},
|
||||
"unsupported_streaming_strategy": {
|
||||
"type": "string",
|
||||
},
|
||||
@@ -2994,6 +2998,15 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.show_tool_call_result": {
|
||||
"description": "输出函数调用返回结果",
|
||||
"type": "bool",
|
||||
"hint": "仅在输出函数调用状态启用时生效,展示结果前 70 个字符。",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
"provider_settings.show_tool_use_status": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.sanitize_context_by_modalities": {
|
||||
"description": "按模型能力清理历史上下文",
|
||||
"type": "bool",
|
||||
|
||||
@@ -720,13 +720,38 @@ class File(BaseMessageComponent):
|
||||
if allow_return_url and self.url:
|
||||
return self.url
|
||||
|
||||
if self.file_ and os.path.exists(self.file_):
|
||||
return os.path.abspath(self.file_)
|
||||
if self.file_:
|
||||
path = self.file_
|
||||
if path.startswith("file://"):
|
||||
# 处理 file:// (2 slashes) 或 file:/// (3 slashes)
|
||||
# pathlib.as_uri() 通常生成 file:///
|
||||
path = path[7:]
|
||||
# 兼容 Windows: file:///C:/path -> /C:/path -> C:/path
|
||||
if (
|
||||
os.name == "nt"
|
||||
and len(path) > 2
|
||||
and path[0] == "/"
|
||||
and path[2] == ":"
|
||||
):
|
||||
path = path[1:]
|
||||
|
||||
if os.path.exists(path):
|
||||
return os.path.abspath(path)
|
||||
|
||||
if self.url:
|
||||
await self._download_file()
|
||||
if self.file_:
|
||||
return os.path.abspath(self.file_)
|
||||
path = self.file_
|
||||
if path.startswith("file://"):
|
||||
path = path[7:]
|
||||
if (
|
||||
os.name == "nt"
|
||||
and len(path) > 2
|
||||
and path[0] == "/"
|
||||
and path[2] == ":"
|
||||
):
|
||||
path = path[1:]
|
||||
return os.path.abspath(path)
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
@@ -1,30 +1,60 @@
|
||||
"""Pipeline package exports.
|
||||
|
||||
This module intentionally avoids eager imports of all pipeline stage modules to
|
||||
prevent import-time cycles. Stage classes remain available via lazy attribute
|
||||
resolution for backward compatibility.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib import import_module
|
||||
from typing import Any
|
||||
|
||||
from astrbot.core.message.message_event_result import (
|
||||
EventResultType,
|
||||
MessageEventResult,
|
||||
)
|
||||
|
||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||
from .preprocess_stage.stage import PreProcessStage
|
||||
from .process_stage.stage import ProcessStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .respond.stage import RespondStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
from .session_status_check.stage import SessionStatusCheckStage
|
||||
from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
from .stage_order import STAGES_ORDER
|
||||
|
||||
# 管道阶段顺序
|
||||
STAGES_ORDER = [
|
||||
"WakingCheckStage", # 检查是否需要唤醒
|
||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||
"SessionStatusCheckStage", # 检查会话是否整体启用
|
||||
"RateLimitStage", # 检查会话是否超过频率限制
|
||||
"ContentSafetyCheckStage", # 检查内容安全
|
||||
"PreProcessStage", # 预处理
|
||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
||||
"RespondStage", # 发送消息
|
||||
]
|
||||
_LAZY_EXPORTS = {
|
||||
"ContentSafetyCheckStage": (
|
||||
"astrbot.core.pipeline.content_safety_check.stage",
|
||||
"ContentSafetyCheckStage",
|
||||
),
|
||||
"PreProcessStage": (
|
||||
"astrbot.core.pipeline.preprocess_stage.stage",
|
||||
"PreProcessStage",
|
||||
),
|
||||
"ProcessStage": (
|
||||
"astrbot.core.pipeline.process_stage.stage",
|
||||
"ProcessStage",
|
||||
),
|
||||
"RateLimitStage": (
|
||||
"astrbot.core.pipeline.rate_limit_check.stage",
|
||||
"RateLimitStage",
|
||||
),
|
||||
"RespondStage": (
|
||||
"astrbot.core.pipeline.respond.stage",
|
||||
"RespondStage",
|
||||
),
|
||||
"ResultDecorateStage": (
|
||||
"astrbot.core.pipeline.result_decorate.stage",
|
||||
"ResultDecorateStage",
|
||||
),
|
||||
"SessionStatusCheckStage": (
|
||||
"astrbot.core.pipeline.session_status_check.stage",
|
||||
"SessionStatusCheckStage",
|
||||
),
|
||||
"WakingCheckStage": (
|
||||
"astrbot.core.pipeline.waking_check.stage",
|
||||
"WakingCheckStage",
|
||||
),
|
||||
"WhitelistCheckStage": (
|
||||
"astrbot.core.pipeline.whitelist_check.stage",
|
||||
"WhitelistCheckStage",
|
||||
),
|
||||
}
|
||||
|
||||
__all__ = [
|
||||
"ContentSafetyCheckStage",
|
||||
@@ -36,6 +66,21 @@ __all__ = [
|
||||
"RespondStage",
|
||||
"ResultDecorateStage",
|
||||
"SessionStatusCheckStage",
|
||||
"STAGES_ORDER",
|
||||
"WakingCheckStage",
|
||||
"WhitelistCheckStage",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name not in _LAZY_EXPORTS:
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
module_path, attr_name = _LAZY_EXPORTS[name]
|
||||
module = import_module(module_path)
|
||||
value = getattr(module, attr_name)
|
||||
globals()[name] = value
|
||||
return value
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(set(globals()) | set(__all__))
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Pipeline bootstrap utilities."""
|
||||
|
||||
from importlib import import_module
|
||||
|
||||
from .stage import registered_stages
|
||||
|
||||
_BUILTIN_STAGE_MODULES = (
|
||||
"astrbot.core.pipeline.waking_check.stage",
|
||||
"astrbot.core.pipeline.whitelist_check.stage",
|
||||
"astrbot.core.pipeline.session_status_check.stage",
|
||||
"astrbot.core.pipeline.rate_limit_check.stage",
|
||||
"astrbot.core.pipeline.content_safety_check.stage",
|
||||
"astrbot.core.pipeline.preprocess_stage.stage",
|
||||
"astrbot.core.pipeline.process_stage.stage",
|
||||
"astrbot.core.pipeline.result_decorate.stage",
|
||||
"astrbot.core.pipeline.respond.stage",
|
||||
)
|
||||
|
||||
_EXPECTED_STAGE_NAMES = {
|
||||
"WakingCheckStage",
|
||||
"WhitelistCheckStage",
|
||||
"SessionStatusCheckStage",
|
||||
"RateLimitStage",
|
||||
"ContentSafetyCheckStage",
|
||||
"PreProcessStage",
|
||||
"ProcessStage",
|
||||
"ResultDecorateStage",
|
||||
"RespondStage",
|
||||
}
|
||||
|
||||
_builtin_stages_registered = False
|
||||
|
||||
|
||||
def ensure_builtin_stages_registered() -> None:
|
||||
"""Ensure built-in pipeline stages are imported and registered."""
|
||||
global _builtin_stages_registered
|
||||
|
||||
if _builtin_stages_registered:
|
||||
return
|
||||
|
||||
stage_names = {stage_cls.__name__ for stage_cls in registered_stages}
|
||||
if _EXPECTED_STAGE_NAMES.issubset(stage_names):
|
||||
_builtin_stages_registered = True
|
||||
return
|
||||
|
||||
for module_path in _BUILTIN_STAGE_MODULES:
|
||||
import_module(module_path)
|
||||
|
||||
_builtin_stages_registered = True
|
||||
|
||||
|
||||
__all__ = ["ensure_builtin_stages_registered"]
|
||||
@@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.star import PluginManager
|
||||
|
||||
from .context_utils import call_event_hook, call_handler
|
||||
|
||||
@@ -11,7 +13,7 @@ class PipelineContext:
|
||||
"""上下文对象,包含管道执行所需的上下文信息"""
|
||||
|
||||
astrbot_config: AstrBotConfig # AstrBot 配置对象
|
||||
plugin_manager: PluginManager # 插件管理器对象
|
||||
plugin_manager: Any # 插件管理器对象
|
||||
astrbot_config_id: str
|
||||
call_handler = call_handler
|
||||
call_event_hook = call_event_hook
|
||||
|
||||
@@ -19,6 +19,7 @@ from astrbot.core.message.message_event_result import (
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
)
|
||||
from astrbot.core.pipeline.stage import Stage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
@@ -30,7 +31,6 @@ from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
from .....astr_agent_run_util import run_agent, run_live_agent
|
||||
from ....context import PipelineContext, call_event_hook
|
||||
from ...stage import Stage
|
||||
|
||||
|
||||
class InternalAgentSubStage(Stage):
|
||||
@@ -54,6 +54,7 @@ class InternalAgentSubStage(Stage):
|
||||
if isinstance(self.max_step, bool): # workaround: #2622
|
||||
self.max_step = 30
|
||||
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
|
||||
self.show_tool_call_result: bool = settings.get("show_tool_call_result", False)
|
||||
self.show_reasoning = settings.get("display_reasoning_text", False)
|
||||
self.sanitize_context_by_modalities: bool = settings.get(
|
||||
"sanitize_context_by_modalities",
|
||||
@@ -240,6 +241,7 @@ class InternalAgentSubStage(Stage):
|
||||
tts_provider,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
self.show_tool_call_result,
|
||||
show_reasoning=self.show_reasoning,
|
||||
),
|
||||
),
|
||||
@@ -269,6 +271,7 @@ class InternalAgentSubStage(Stage):
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
self.show_tool_call_result,
|
||||
show_reasoning=self.show_reasoning,
|
||||
),
|
||||
),
|
||||
@@ -297,6 +300,7 @@ class InternalAgentSubStage(Stage):
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
self.show_tool_call_result,
|
||||
stream_to_general,
|
||||
show_reasoning=self.show_reasoning,
|
||||
):
|
||||
|
||||
@@ -8,6 +8,7 @@ from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
|
||||
DashscopeAgentRunner,
|
||||
)
|
||||
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
|
||||
from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
@@ -17,6 +18,7 @@ from astrbot.core.message.message_event_result import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.agent.runners.base import BaseAgentRunner
|
||||
from astrbot.core.pipeline.stage import Stage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
@@ -25,9 +27,7 @@ from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
|
||||
from .....astr_agent_context import AgentContextWrapper, AstrAgentContext
|
||||
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from ....context import PipelineContext, call_event_hook
|
||||
from ...stage import Stage
|
||||
|
||||
AGENT_RUNNER_TYPE_KEY = {
|
||||
"dify": "dify_agent_runner_provider_id",
|
||||
|
||||
@@ -8,15 +8,17 @@ from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
|
||||
)
|
||||
from astrbot.core.utils.active_event_registry import active_event_registry
|
||||
|
||||
from . import STAGES_ORDER
|
||||
from .bootstrap import ensure_builtin_stages_registered
|
||||
from .context import PipelineContext
|
||||
from .stage import registered_stages
|
||||
from .stage_order import STAGES_ORDER
|
||||
|
||||
|
||||
class PipelineScheduler:
|
||||
"""管道调度器,负责调度各个阶段的执行"""
|
||||
|
||||
def __init__(self, context: PipelineContext) -> None:
|
||||
ensure_builtin_stages_registered()
|
||||
registered_stages.sort(
|
||||
key=lambda x: STAGES_ORDER.index(x.__name__),
|
||||
) # 按照顺序排序
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
"""Pipeline stage execution order."""
|
||||
|
||||
STAGES_ORDER = [
|
||||
"WakingCheckStage", # 检查是否需要唤醒
|
||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||
"SessionStatusCheckStage", # 检查会话是否整体启用
|
||||
"RateLimitStage", # 检查会话是否超过频率限制
|
||||
"ContentSafetyCheckStage", # 检查内容安全
|
||||
"PreProcessStage", # 预处理
|
||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
||||
"RespondStage", # 发送消息
|
||||
]
|
||||
|
||||
__all__ = ["STAGES_ORDER"]
|
||||
@@ -52,9 +52,19 @@ class AstrMessageEvent(abc.ABC):
|
||||
self.is_at_or_wake_command = False
|
||||
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
|
||||
self._extras: dict[str, Any] = {}
|
||||
message_type = getattr(message_obj, "type", None)
|
||||
if not isinstance(message_type, MessageType):
|
||||
try:
|
||||
message_type = MessageType(str(message_type))
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
logger.warning(
|
||||
f"Failed to convert message type {message_obj.type!r} to MessageType. "
|
||||
f"Falling back to FRIEND_MESSAGE."
|
||||
)
|
||||
message_type = MessageType.FRIEND_MESSAGE
|
||||
self.session = MessageSession(
|
||||
platform_name=platform_meta.id,
|
||||
message_type=message_obj.type,
|
||||
message_type=message_type,
|
||||
session_id=session_id,
|
||||
)
|
||||
# self.unified_msg_origin = str(self.session)
|
||||
@@ -159,15 +169,18 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
除了文本消息外,其他消息类型会被转换为对应的占位符。如图片消息会被转换为 [图片]。
|
||||
"""
|
||||
return self._outline_chain(self.message_obj.message)
|
||||
return self._outline_chain(getattr(self.message_obj, "message", None))
|
||||
|
||||
def get_messages(self) -> list[BaseMessageComponent]:
|
||||
"""获取消息链。"""
|
||||
return self.message_obj.message
|
||||
return getattr(self.message_obj, "message", [])
|
||||
|
||||
def get_message_type(self) -> MessageType:
|
||||
"""获取消息类型。"""
|
||||
return self.message_obj.type
|
||||
message_type = getattr(self.message_obj, "type", None)
|
||||
if isinstance(message_type, MessageType):
|
||||
return message_type
|
||||
return self.session.message_type
|
||||
|
||||
def get_session_id(self) -> str:
|
||||
"""获取会话id。"""
|
||||
@@ -175,21 +188,30 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
def get_group_id(self) -> str:
|
||||
"""获取群组id。如果不是群组消息,返回空字符串。"""
|
||||
return self.message_obj.group_id
|
||||
return getattr(self.message_obj, "group_id", "")
|
||||
|
||||
def get_self_id(self) -> str:
|
||||
"""获取机器人自身的id。"""
|
||||
return self.message_obj.self_id
|
||||
return getattr(self.message_obj, "self_id", "")
|
||||
|
||||
def get_sender_id(self) -> str:
|
||||
"""获取消息发送者的id。"""
|
||||
return self.message_obj.sender.user_id
|
||||
sender = getattr(self.message_obj, "sender", None)
|
||||
if sender and isinstance(getattr(sender, "user_id", None), str):
|
||||
return sender.user_id
|
||||
return ""
|
||||
|
||||
def get_sender_name(self) -> str:
|
||||
"""获取消息发送者的名称。(可能会返回空字符串)"""
|
||||
if isinstance(self.message_obj.sender.nickname, str):
|
||||
return self.message_obj.sender.nickname
|
||||
return ""
|
||||
sender = getattr(self.message_obj, "sender", None)
|
||||
if not sender:
|
||||
return ""
|
||||
nickname = getattr(sender, "nickname", None)
|
||||
if nickname is None:
|
||||
return ""
|
||||
if isinstance(nickname, str):
|
||||
return nickname
|
||||
return str(nickname)
|
||||
|
||||
def set_extra(self, key, value) -> None:
|
||||
"""设置额外的信息。"""
|
||||
@@ -208,7 +230,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
def is_private_chat(self) -> bool:
|
||||
"""是否是私聊。"""
|
||||
return self.message_obj.type.value == (MessageType.FRIEND_MESSAGE).value
|
||||
return self.get_message_type() == MessageType.FRIEND_MESSAGE
|
||||
|
||||
def is_wake_up(self) -> bool:
|
||||
"""是否是唤醒机器人的事件。"""
|
||||
|
||||
@@ -180,6 +180,10 @@ class PlatformManager:
|
||||
from .sources.line.line_adapter import (
|
||||
LinePlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "heihe":
|
||||
from .sources.heihe.heihe_adapter import (
|
||||
HeihePlatformAdapter, # noqa: F401
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.error(
|
||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
|
||||
|
||||
@@ -45,6 +45,19 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
if isinstance(segment, File):
|
||||
# For File segments, we need to handle the file differently
|
||||
d = await segment.to_dict()
|
||||
file_val = d.get("data", {}).get("file", "")
|
||||
if file_val:
|
||||
import pathlib
|
||||
|
||||
try:
|
||||
# 使用 pathlib 处理路径,能更好地处理 Windows/Linux 差异
|
||||
path_obj = pathlib.Path(file_val)
|
||||
# 如果是绝对路径且不包含协议头 (://),则转换为标准的 file: URI
|
||||
if path_obj.is_absolute() and "://" not in file_val:
|
||||
d["data"]["file"] = path_obj.as_uri()
|
||||
except Exception:
|
||||
# 如果不是合法路径(例如已经是特定的特殊字符串),则跳过转换
|
||||
pass
|
||||
return d
|
||||
if isinstance(segment, Video):
|
||||
d = await segment.to_dict()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import itertools
|
||||
import logging
|
||||
import time
|
||||
@@ -436,7 +437,42 @@ class AiocqhttpAdapter(Platform):
|
||||
return coro
|
||||
|
||||
async def terminate(self) -> None:
|
||||
self.shutdown_event.set()
|
||||
if hasattr(self, "shutdown_event"):
|
||||
self.shutdown_event.set()
|
||||
await self._close_reverse_ws_connections()
|
||||
|
||||
async def _close_reverse_ws_connections(self) -> None:
|
||||
api_clients = getattr(self.bot, "_wsr_api_clients", None)
|
||||
event_clients = getattr(self.bot, "_wsr_event_clients", None)
|
||||
|
||||
ws_clients: set[Any] = set()
|
||||
if isinstance(api_clients, dict):
|
||||
ws_clients.update(api_clients.values())
|
||||
if isinstance(event_clients, set):
|
||||
ws_clients.update(event_clients)
|
||||
|
||||
close_tasks: list[Awaitable[Any]] = []
|
||||
for ws in ws_clients:
|
||||
close_func = getattr(ws, "close", None)
|
||||
if not callable(close_func):
|
||||
continue
|
||||
try:
|
||||
close_result = close_func(code=1000, reason="Adapter shutdown")
|
||||
except TypeError:
|
||||
close_result = close_func()
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if inspect.isawaitable(close_result):
|
||||
close_tasks.append(close_result)
|
||||
|
||||
if close_tasks:
|
||||
await asyncio.gather(*close_tasks, return_exceptions=True)
|
||||
|
||||
if isinstance(api_clients, dict):
|
||||
api_clients.clear()
|
||||
if isinstance(event_clients, set):
|
||||
event_clients.clear()
|
||||
|
||||
async def shutdown_trigger_placeholder(self) -> None:
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
@@ -0,0 +1,523 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
import websockets
|
||||
from websockets.asyncio.client import ClientConnection, connect
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import At, Image, Plain
|
||||
from astrbot.api.platform import (
|
||||
AstrBotMessage,
|
||||
Group,
|
||||
MessageMember,
|
||||
MessageType,
|
||||
Platform,
|
||||
PlatformMetadata,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
|
||||
from ...register import register_platform_adapter
|
||||
from .heihe_event import HeiheMessageEvent
|
||||
|
||||
HEIHE_CONFIG_METADATA = {
|
||||
"heihe_ws_url": {
|
||||
"description": "Heihe WebSocket URL",
|
||||
"type": "string",
|
||||
"hint": "一般情况下不需要修改。",
|
||||
},
|
||||
"heihe_token": {
|
||||
"description": "Bot Token",
|
||||
"type": "string",
|
||||
"hint": "黑盒 Bot Token。可填写纯 Token(推荐),适配器会自动添加 Authorization 头。",
|
||||
},
|
||||
"heihe_origin": {
|
||||
"description": "WebSocket Origin",
|
||||
"type": "string",
|
||||
"hint": "用于 WebSocket 握手的 Origin 头,默认 https://chat.xiaoheihe.cn。",
|
||||
},
|
||||
"heihe_bot_id": {
|
||||
"description": "Bot ID",
|
||||
"type": "string",
|
||||
"hint": "可选。为空时会根据收到的消息自动识别机器人 ID。",
|
||||
},
|
||||
"heihe_auto_reconnect": {
|
||||
"description": "Auto Reconnect",
|
||||
"type": "bool",
|
||||
"hint": "WebSocket 断开后是否自动重连。",
|
||||
},
|
||||
"heihe_heartbeat_interval": {
|
||||
"description": "Heartbeat Interval (seconds)",
|
||||
"type": "int",
|
||||
"hint": "发送心跳包间隔。<=0 表示关闭主动心跳。",
|
||||
},
|
||||
"heihe_reconnect_delay": {
|
||||
"description": "Reconnect Delay (seconds)",
|
||||
"type": "int",
|
||||
"hint": "WebSocket 断开后的重连等待时间。",
|
||||
},
|
||||
"heihe_ignore_self_message": {
|
||||
"description": "Ignore Self Message",
|
||||
"type": "bool",
|
||||
"hint": "是否忽略机器人自身发送的消息。",
|
||||
},
|
||||
}
|
||||
|
||||
HEIHE_I18N_RESOURCES = {
|
||||
"zh-CN": {
|
||||
"heihe_ws_url": {
|
||||
"description": "黑盒 WebSocket 地址",
|
||||
"hint": "一般情况下不需要修改。",
|
||||
},
|
||||
"heihe_token": {
|
||||
"description": "机器人 Token",
|
||||
"hint": "建议填写纯 Token,适配器会自动补齐 Authorization 头。",
|
||||
},
|
||||
"heihe_origin": {
|
||||
"description": "WebSocket Origin",
|
||||
"hint": "用于握手的 Origin 头,默认 https://chat.xiaoheihe.cn。",
|
||||
},
|
||||
"heihe_bot_id": {
|
||||
"description": "机器人 ID",
|
||||
"hint": "可选。为空时会根据收到的消息自动识别机器人 ID。",
|
||||
},
|
||||
"heihe_auto_reconnect": {
|
||||
"description": "自动重连",
|
||||
"hint": "WebSocket 断开后是否自动重连。",
|
||||
},
|
||||
"heihe_heartbeat_interval": {
|
||||
"description": "心跳间隔(秒)",
|
||||
"hint": "设置 <=0 将关闭主动心跳。",
|
||||
},
|
||||
"heihe_reconnect_delay": {
|
||||
"description": "重连间隔(秒)",
|
||||
"hint": "WebSocket 断开后的重连等待时间。",
|
||||
},
|
||||
"heihe_ignore_self_message": {
|
||||
"description": "忽略机器人自身消息",
|
||||
"hint": "开启后,机器人自己发出的消息将不会触发事件处理。",
|
||||
},
|
||||
},
|
||||
"en-US": {
|
||||
"heihe_ws_url": {
|
||||
"description": "Heihe WebSocket URL",
|
||||
"hint": "Usually no need to change this.",
|
||||
},
|
||||
"heihe_token": {
|
||||
"description": "Bot Token",
|
||||
"hint": "Plain token is recommended. Authorization header is added automatically.",
|
||||
},
|
||||
"heihe_origin": {
|
||||
"description": "WebSocket Origin",
|
||||
"hint": "Origin header used in websocket handshake. Default: https://chat.xiaoheihe.cn.",
|
||||
},
|
||||
"heihe_bot_id": {
|
||||
"description": "Bot ID",
|
||||
"hint": "Optional. If empty, the adapter will infer it from incoming messages.",
|
||||
},
|
||||
"heihe_auto_reconnect": {
|
||||
"description": "Auto Reconnect",
|
||||
"hint": "Whether to reconnect automatically after websocket disconnects.",
|
||||
},
|
||||
"heihe_heartbeat_interval": {
|
||||
"description": "Heartbeat Interval (seconds)",
|
||||
"hint": "Set <=0 to disable active heartbeat.",
|
||||
},
|
||||
"heihe_reconnect_delay": {
|
||||
"description": "Reconnect Delay (seconds)",
|
||||
"hint": "Delay before reconnecting after disconnect.",
|
||||
},
|
||||
"heihe_ignore_self_message": {
|
||||
"description": "Ignore Self Message",
|
||||
"hint": "When enabled, messages sent by the bot itself will be ignored.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
"heihe",
|
||||
"黑盒机器人(WebSocket)适配器",
|
||||
support_streaming_message=False,
|
||||
default_config_tmpl={
|
||||
"id": "heihe",
|
||||
"type": "heihe",
|
||||
"enable": False,
|
||||
"heihe_ws_url": "wss://chat.xiaoheihe.cn/chatroom/ws/connect",
|
||||
"heihe_token": "",
|
||||
"heihe_origin": "https://chat.xiaoheihe.cn",
|
||||
"heihe_bot_id": "",
|
||||
"heihe_auto_reconnect": True,
|
||||
"heihe_heartbeat_interval": 20,
|
||||
"heihe_reconnect_delay": 5,
|
||||
"heihe_ignore_self_message": True,
|
||||
},
|
||||
config_metadata=HEIHE_CONFIG_METADATA,
|
||||
i18n_resources=HEIHE_I18N_RESOURCES,
|
||||
)
|
||||
class HeihePlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self,
|
||||
platform_config: dict,
|
||||
platform_settings: dict,
|
||||
event_queue: asyncio.Queue,
|
||||
) -> None:
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settings = platform_settings
|
||||
|
||||
self.ws_url = str(platform_config.get("heihe_ws_url", "")).strip()
|
||||
self.token = str(platform_config.get("heihe_token", "")).strip()
|
||||
self.origin = str(
|
||||
platform_config.get("heihe_origin", "https://chat.xiaoheihe.cn"),
|
||||
).strip()
|
||||
self.bot_id = str(platform_config.get("heihe_bot_id", "")).strip()
|
||||
self.auto_reconnect = bool(platform_config.get("heihe_auto_reconnect", True))
|
||||
self.heartbeat_interval = int(
|
||||
cast(int, platform_config.get("heihe_heartbeat_interval", 20)),
|
||||
)
|
||||
self.reconnect_delay = int(
|
||||
cast(int, platform_config.get("heihe_reconnect_delay", 5)),
|
||||
)
|
||||
self.ignore_self_message = bool(
|
||||
platform_config.get("heihe_ignore_self_message", True),
|
||||
)
|
||||
|
||||
if not self.ws_url:
|
||||
raise ValueError("heihe_ws_url 不能为空。")
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="heihe",
|
||||
description="黑盒机器人(WebSocket)适配器",
|
||||
id=cast(str, self.config.get("id", "heihe")),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
self.ws: ClientConnection | None = None
|
||||
self.running = False
|
||||
self.heartbeat_task: asyncio.Task | None = None
|
||||
self._last_heartbeat_ts = 0
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return self.metadata
|
||||
|
||||
async def run(self) -> None:
|
||||
self.running = True
|
||||
while self.running:
|
||||
try:
|
||||
await self._connect_and_loop()
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.warning("[heihe] websocket disconnected: %s", e)
|
||||
except Exception as e:
|
||||
logger.error("[heihe] websocket failed: %s", e)
|
||||
|
||||
if not self.running:
|
||||
break
|
||||
if not self.auto_reconnect:
|
||||
break
|
||||
await asyncio.sleep(max(1, self.reconnect_delay))
|
||||
|
||||
async def terminate(self) -> None:
|
||||
self.running = False
|
||||
if self.heartbeat_task:
|
||||
self.heartbeat_task.cancel()
|
||||
try:
|
||||
await self.heartbeat_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.ws = None
|
||||
|
||||
async def send_by_session(
|
||||
self,
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
) -> None:
|
||||
await HeiheMessageEvent.send_with_adapter(
|
||||
self,
|
||||
message_chain,
|
||||
session.session_id,
|
||||
)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
async def send_payload(self, payload: Mapping[str, Any]) -> None:
|
||||
if not self.ws:
|
||||
raise RuntimeError("[heihe] websocket not connected")
|
||||
if self.ws.close_code is not None:
|
||||
raise RuntimeError("[heihe] websocket already closed")
|
||||
|
||||
body = dict(payload)
|
||||
body.setdefault("timestamp", int(time.time()))
|
||||
await self.ws.send(json.dumps(body, ensure_ascii=False))
|
||||
|
||||
async def _connect_and_loop(self) -> None:
|
||||
logger.info("[heihe] connecting websocket: %s", self.ws_url)
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
if self.token:
|
||||
headers["Authorization"] = f"Bearer {self.token}"
|
||||
headers["X-Token"] = self.token
|
||||
|
||||
websocket = await connect(
|
||||
self.ws_url,
|
||||
additional_headers=headers,
|
||||
max_size=10 * 1024 * 1024,
|
||||
ping_interval=None,
|
||||
)
|
||||
self.ws = websocket
|
||||
logger.info("[heihe] websocket connected")
|
||||
|
||||
if self.heartbeat_interval > 0:
|
||||
self.heartbeat_task = asyncio.create_task(self._heartbeat_loop())
|
||||
|
||||
try:
|
||||
async for raw in websocket:
|
||||
await self._handle_incoming(raw)
|
||||
finally:
|
||||
if self.heartbeat_task:
|
||||
self.heartbeat_task.cancel()
|
||||
try:
|
||||
await self.heartbeat_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self.heartbeat_task = None
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.ws = None
|
||||
|
||||
async def _heartbeat_loop(self) -> None:
|
||||
try:
|
||||
while self.running and self.ws and self.ws.close_code is None:
|
||||
await asyncio.sleep(self.heartbeat_interval)
|
||||
self._last_heartbeat_ts = int(time.time())
|
||||
await self.send_payload(
|
||||
{
|
||||
"type": "ping",
|
||||
"ping": self._last_heartbeat_ts,
|
||||
},
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("[heihe] heartbeat error: %s", e)
|
||||
|
||||
async def _handle_incoming(self, raw: Any) -> None:
|
||||
if isinstance(raw, bytes):
|
||||
try:
|
||||
raw = raw.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return
|
||||
if not isinstance(raw, str):
|
||||
return
|
||||
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("[heihe] skip non-json frame: %s", raw[:200])
|
||||
return
|
||||
|
||||
if isinstance(data, list):
|
||||
for item in data:
|
||||
if isinstance(item, dict):
|
||||
await self._handle_packet(item)
|
||||
return
|
||||
if isinstance(data, dict):
|
||||
await self._handle_packet(data)
|
||||
|
||||
async def _handle_packet(self, packet: dict[str, Any]) -> None:
|
||||
if "ping" in packet:
|
||||
await self.send_payload({"type": "pong", "pong": packet.get("ping")})
|
||||
return
|
||||
if str(packet.get("type", "")).lower() == "ping":
|
||||
await self.send_payload({"type": "pong", "pong": packet.get("ping")})
|
||||
return
|
||||
|
||||
event_type = str(
|
||||
packet.get("event")
|
||||
or packet.get("event_type")
|
||||
or packet.get("type")
|
||||
or packet.get("topic")
|
||||
or "",
|
||||
).lower()
|
||||
payload_obj = packet.get("data")
|
||||
payload = payload_obj if isinstance(payload_obj, dict) else packet
|
||||
|
||||
if not self._is_message_event(event_type, payload):
|
||||
return
|
||||
|
||||
abm = self._convert_message(payload, packet)
|
||||
if not abm:
|
||||
return
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@staticmethod
|
||||
def _is_message_event(event_type: str, payload: Mapping[str, Any]) -> bool:
|
||||
if "message" in event_type:
|
||||
return True
|
||||
keys = payload.keys()
|
||||
return "content" in keys or "text" in keys or "message" in keys
|
||||
|
||||
def _convert_message(
|
||||
self,
|
||||
payload: Mapping[str, Any],
|
||||
raw_packet: Mapping[str, Any],
|
||||
) -> AstrBotMessage | None:
|
||||
message_obj = payload.get("message")
|
||||
message = message_obj if isinstance(message_obj, Mapping) else payload
|
||||
|
||||
sender_data_obj = (
|
||||
payload.get("sender") or payload.get("author") or payload.get("user") or {}
|
||||
)
|
||||
sender_data = sender_data_obj if isinstance(sender_data_obj, Mapping) else {}
|
||||
sender_id = str(
|
||||
sender_data.get("id")
|
||||
or sender_data.get("user_id")
|
||||
or payload.get("sender_id")
|
||||
or payload.get("user_id")
|
||||
or "",
|
||||
).strip()
|
||||
sender_name = str(
|
||||
sender_data.get("nickname")
|
||||
or sender_data.get("name")
|
||||
or sender_data.get("username")
|
||||
or sender_id
|
||||
or "unknown",
|
||||
)
|
||||
|
||||
self_id = str(
|
||||
payload.get("self_id")
|
||||
or payload.get("bot_id")
|
||||
or self.bot_id
|
||||
or self.meta().id,
|
||||
)
|
||||
if self.ignore_self_message and sender_id and self_id and sender_id == self_id:
|
||||
return None
|
||||
|
||||
channel_id = str(
|
||||
payload.get("channel_id")
|
||||
or payload.get("room_id")
|
||||
or payload.get("chat_id")
|
||||
or payload.get("session_id")
|
||||
or "",
|
||||
).strip()
|
||||
guild_id = str(
|
||||
payload.get("guild_id")
|
||||
or payload.get("server_id")
|
||||
or payload.get("group_id")
|
||||
or "",
|
||||
).strip()
|
||||
is_private = bool(payload.get("is_private", False))
|
||||
if str(payload.get("message_type", "")).lower() in {"private", "friend", "dm"}:
|
||||
is_private = True
|
||||
|
||||
session_id = channel_id or sender_id
|
||||
if not session_id:
|
||||
return None
|
||||
|
||||
text = str(message.get("content") or message.get("text") or "").strip()
|
||||
components = self._build_components(text, payload)
|
||||
if not components:
|
||||
return None
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = self_id
|
||||
abm.message_id = str(
|
||||
message.get("id")
|
||||
or message.get("message_id")
|
||||
or payload.get("message_id")
|
||||
or payload.get("msg_id")
|
||||
or uuid.uuid4().hex
|
||||
)
|
||||
timestamp_raw = (
|
||||
payload.get("timestamp")
|
||||
or payload.get("time")
|
||||
or message.get("timestamp")
|
||||
or message.get("time")
|
||||
)
|
||||
abm.timestamp = int(time.time())
|
||||
if isinstance(timestamp_raw, int):
|
||||
abm.timestamp = (
|
||||
timestamp_raw // 1000
|
||||
if timestamp_raw > 1_000_000_000_000
|
||||
else timestamp_raw
|
||||
)
|
||||
|
||||
if not is_private and (channel_id or guild_id):
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group = Group(
|
||||
group_id=guild_id or channel_id, group_name=guild_id or ""
|
||||
)
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
|
||||
abm.session_id = session_id
|
||||
abm.sender = MessageMember(user_id=sender_id or "unknown", nickname=sender_name)
|
||||
abm.message = components
|
||||
abm.message_str = self._build_message_str(components)
|
||||
abm.raw_message = dict(raw_packet)
|
||||
return abm
|
||||
|
||||
@staticmethod
|
||||
def _build_components(text: str, payload: Mapping[str, Any]) -> list:
|
||||
components: list = []
|
||||
if text:
|
||||
components.append(Plain(text=text))
|
||||
|
||||
mentions_obj = payload.get("mentions")
|
||||
if isinstance(mentions_obj, list):
|
||||
for mention in mentions_obj:
|
||||
if not isinstance(mention, Mapping):
|
||||
continue
|
||||
user_id = str(mention.get("user_id") or mention.get("id") or "").strip()
|
||||
name = str(mention.get("name") or mention.get("nickname") or "").strip()
|
||||
if user_id or name:
|
||||
components.append(At(qq=user_id, name=name))
|
||||
|
||||
attachments_obj = payload.get("attachments")
|
||||
if isinstance(attachments_obj, list):
|
||||
for item in attachments_obj:
|
||||
if not isinstance(item, Mapping):
|
||||
continue
|
||||
url = str(item.get("url") or item.get("file_url") or "").strip()
|
||||
if not url:
|
||||
continue
|
||||
kind = str(item.get("type") or item.get("media_type") or "").lower()
|
||||
if "image" in kind:
|
||||
components.append(Image.fromURL(url))
|
||||
else:
|
||||
components.append(Plain(text=f"[{kind or 'file'}] {url}"))
|
||||
return components
|
||||
|
||||
@staticmethod
|
||||
def _build_message_str(components: list) -> str:
|
||||
parts: list[str] = []
|
||||
for comp in components:
|
||||
if isinstance(comp, Plain):
|
||||
parts.append(comp.text)
|
||||
elif isinstance(comp, At):
|
||||
parts.append(f"@{comp.name or comp.qq}")
|
||||
elif isinstance(comp, Image):
|
||||
parts.append("[image]")
|
||||
else:
|
||||
parts.append(f"[{comp.type}]")
|
||||
return " ".join(i for i in parts if i).strip()
|
||||
|
||||
async def handle_msg(self, abm: AstrBotMessage) -> None:
|
||||
event = HeiheMessageEvent(
|
||||
message_str=abm.message_str,
|
||||
message_obj=abm,
|
||||
platform_meta=self.meta(),
|
||||
session_id=abm.session_id,
|
||||
adapter=self,
|
||||
)
|
||||
self.commit_event(event)
|
||||
@@ -0,0 +1,108 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import At, Image, Plain, Reply
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .heihe_adapter import HeihePlatformAdapter
|
||||
|
||||
|
||||
class HeiheMessageEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj,
|
||||
platform_meta,
|
||||
session_id: str,
|
||||
adapter: "HeihePlatformAdapter",
|
||||
) -> None:
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.adapter = adapter
|
||||
|
||||
@classmethod
|
||||
async def send_with_adapter(
|
||||
cls,
|
||||
adapter: "HeihePlatformAdapter",
|
||||
message: MessageChain,
|
||||
session_id: str,
|
||||
) -> None:
|
||||
payload = await cls._build_send_payload(message, session_id)
|
||||
await adapter.send_payload(payload)
|
||||
|
||||
async def send(self, message: MessageChain) -> None:
|
||||
await self.send_with_adapter(self.adapter, message, self.session_id)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(
|
||||
self,
|
||||
generator: AsyncGenerator,
|
||||
use_fallback: bool = False,
|
||||
):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return None
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@classmethod
|
||||
async def _build_send_payload(
|
||||
cls,
|
||||
message: MessageChain,
|
||||
session_id: str,
|
||||
) -> dict[str, Any]:
|
||||
text_parts: list[str] = []
|
||||
segments: list[dict[str, Any]] = []
|
||||
|
||||
for component in message.chain:
|
||||
if isinstance(component, Plain):
|
||||
if component.text:
|
||||
text_parts.append(component.text)
|
||||
segments.append({"type": "text", "text": component.text})
|
||||
continue
|
||||
|
||||
if isinstance(component, At):
|
||||
at_name = str(component.name or component.qq or "").strip()
|
||||
if at_name:
|
||||
text_parts.append(f"@{at_name}")
|
||||
segments.append(
|
||||
{
|
||||
"type": "mention",
|
||||
"user_id": str(component.qq or ""),
|
||||
"name": at_name,
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
if isinstance(component, Reply):
|
||||
if component.id:
|
||||
segments.append({"type": "reply", "message_id": component.id})
|
||||
continue
|
||||
|
||||
if isinstance(component, Image):
|
||||
image_url = ""
|
||||
try:
|
||||
image_url = await component.register_to_file_service()
|
||||
except Exception as e:
|
||||
logger.debug("[heihe] image upload fallback failed: %s", e)
|
||||
|
||||
if image_url:
|
||||
segments.append({"type": "image", "url": image_url})
|
||||
text_parts.append("[image]")
|
||||
continue
|
||||
|
||||
content = "".join(text_parts).strip()
|
||||
payload: dict[str, Any] = {
|
||||
"action": "send_message",
|
||||
"channel_id": session_id,
|
||||
"content": content,
|
||||
"segments": segments,
|
||||
}
|
||||
return payload
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
@@ -25,6 +26,9 @@ from astrbot.core.star.filter.command import CommandFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.utils.media_utils import convert_audio_to_wav
|
||||
|
||||
from .tg_event import TelegramPlatformEvent
|
||||
|
||||
@@ -375,8 +379,19 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
elif update.message.voice:
|
||||
file = await update.message.voice.get_file()
|
||||
|
||||
file_basename = os.path.basename(file.file_path)
|
||||
temp_dir = get_astrbot_temp_path()
|
||||
temp_path = os.path.join(temp_dir, file_basename)
|
||||
temp_path = await download_image_by_url(file.file_path, path=temp_path)
|
||||
path_wav = os.path.join(
|
||||
temp_dir,
|
||||
f"{file_basename}.wav",
|
||||
)
|
||||
path_wav = await convert_audio_to_wav(temp_path, path_wav)
|
||||
|
||||
message.message = [
|
||||
Comp.Record(file=file.file_path, url=file.file_path),
|
||||
Comp.Record(file=path_wav, url=path_wav),
|
||||
]
|
||||
|
||||
elif update.message.photo:
|
||||
|
||||
@@ -48,6 +48,9 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
result = await self.client.models.embed_content(
|
||||
model=self.model,
|
||||
contents=text,
|
||||
config=types.EmbedContentConfig(
|
||||
output_dimensionality=self.get_dim(),
|
||||
),
|
||||
)
|
||||
assert result.embeddings is not None
|
||||
assert result.embeddings[0].values is not None
|
||||
@@ -61,6 +64,9 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
result = await self.client.models.embed_content(
|
||||
model=self.model,
|
||||
contents=cast(types.ContentListUnion, text),
|
||||
config=types.EmbedContentConfig(
|
||||
output_dimensionality=self.get_dim(),
|
||||
),
|
||||
)
|
||||
assert result.embeddings is not None
|
||||
|
||||
|
||||
@@ -36,12 +36,20 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
|
||||
async def get_embedding(self, text: str) -> list[float]:
|
||||
"""获取文本的嵌入"""
|
||||
embedding = await self.client.embeddings.create(input=text, model=self.model)
|
||||
embedding = await self.client.embeddings.create(
|
||||
input=text,
|
||||
model=self.model,
|
||||
dimensions=self.get_dim(),
|
||||
)
|
||||
return embedding.data[0].embedding
|
||||
|
||||
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
||||
"""批量获取文本的嵌入"""
|
||||
embeddings = await self.client.embeddings.create(input=text, model=self.model)
|
||||
embeddings = await self.client.embeddings.create(
|
||||
input=text,
|
||||
model=self.model,
|
||||
dimensions=self.get_dim(),
|
||||
)
|
||||
return [item.embedding for item in embeddings.data]
|
||||
|
||||
def get_dim(self) -> int:
|
||||
|
||||
@@ -1,68 +1,19 @@
|
||||
from astrbot.core import html_renderer
|
||||
# 兼容导出: Provider 从 provider 模块重新导出
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.star.star_tools import StarTools
|
||||
from astrbot.core.utils.command_parser import CommandParserMixin
|
||||
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
|
||||
|
||||
from .base import Star
|
||||
from .context import Context
|
||||
from .star import StarMetadata, star_map, star_registry
|
||||
from .star_manager import PluginManager
|
||||
from .star_tools import StarTools
|
||||
|
||||
|
||||
class Star(CommandParserMixin, PluginKVStoreMixin):
|
||||
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
||||
|
||||
author: str
|
||||
name: str
|
||||
|
||||
def __init__(self, context: Context, config: dict | None = None) -> None:
|
||||
StarTools.initialize(context)
|
||||
self.context = context
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if not star_map.get(cls.__module__):
|
||||
metadata = StarMetadata(
|
||||
star_cls_type=cls,
|
||||
module_path=cls.__module__,
|
||||
)
|
||||
star_map[cls.__module__] = metadata
|
||||
star_registry.append(metadata)
|
||||
else:
|
||||
star_map[cls.__module__].star_cls_type = cls
|
||||
star_map[cls.__module__].module_path = cls.__module__
|
||||
|
||||
async def text_to_image(self, text: str, return_url=True) -> str:
|
||||
"""将文本转换为图片"""
|
||||
return await html_renderer.render_t2i(
|
||||
text,
|
||||
return_url=return_url,
|
||||
template_name=self.context._config.get("t2i_active_template"),
|
||||
)
|
||||
|
||||
async def html_render(
|
||||
self,
|
||||
tmpl: str,
|
||||
data: dict,
|
||||
return_url=True,
|
||||
options: dict | None = None,
|
||||
) -> str:
|
||||
"""渲染 HTML"""
|
||||
return await html_renderer.render_custom_template(
|
||||
tmpl,
|
||||
data,
|
||||
return_url=return_url,
|
||||
options=options,
|
||||
)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""当插件被激活时会调用这个方法"""
|
||||
|
||||
async def terminate(self) -> None:
|
||||
"""当插件被禁用、重载插件时会调用这个方法"""
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
|
||||
|
||||
|
||||
__all__ = ["Context", "PluginManager", "Provider", "Star", "StarMetadata", "StarTools"]
|
||||
__all__ = [
|
||||
"Context",
|
||||
"PluginManager",
|
||||
"Provider",
|
||||
"Star",
|
||||
"StarMetadata",
|
||||
"StarTools",
|
||||
"star_map",
|
||||
"star_registry",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Protocol
|
||||
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.utils.command_parser import CommandParserMixin
|
||||
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
|
||||
|
||||
from .star import StarMetadata, star_map, star_registry
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
class Star(CommandParserMixin, PluginKVStoreMixin):
|
||||
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
||||
|
||||
author: str
|
||||
name: str
|
||||
|
||||
class _ContextLike(Protocol):
|
||||
def get_config(self, umo: str | None = None) -> Any: ...
|
||||
|
||||
def __init__(self, context: _ContextLike, config: dict | None = None) -> None:
|
||||
self.context = context
|
||||
|
||||
def _get_context_config(self) -> Any:
|
||||
get_config = getattr(self.context, "get_config", None)
|
||||
if callable(get_config):
|
||||
try:
|
||||
return get_config()
|
||||
except Exception as e:
|
||||
logger.debug(f"get_config() failed: {e}")
|
||||
return None
|
||||
return getattr(self.context, "_config", None)
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if not star_map.get(cls.__module__):
|
||||
metadata = StarMetadata(
|
||||
star_cls_type=cls,
|
||||
module_path=cls.__module__,
|
||||
)
|
||||
star_map[cls.__module__] = metadata
|
||||
star_registry.append(metadata)
|
||||
else:
|
||||
star_map[cls.__module__].star_cls_type = cls
|
||||
star_map[cls.__module__].module_path = cls.__module__
|
||||
|
||||
async def text_to_image(self, text: str, return_url=True) -> str:
|
||||
"""将文本转换为图片"""
|
||||
config_obj = self._get_context_config()
|
||||
template_name = None
|
||||
if hasattr(config_obj, "get"):
|
||||
try:
|
||||
template_name = config_obj.get("t2i_active_template")
|
||||
except Exception:
|
||||
template_name = None
|
||||
return await html_renderer.render_t2i(
|
||||
text,
|
||||
return_url=return_url,
|
||||
template_name=template_name,
|
||||
)
|
||||
|
||||
async def html_render(
|
||||
self,
|
||||
tmpl: str,
|
||||
data: dict,
|
||||
return_url=True,
|
||||
options: dict | None = None,
|
||||
) -> str:
|
||||
"""渲染 HTML"""
|
||||
return await html_renderer.render_custom_template(
|
||||
tmpl,
|
||||
data,
|
||||
return_url=return_url,
|
||||
options=options,
|
||||
)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""当插件被激活时会调用这个方法"""
|
||||
|
||||
async def terminate(self) -> None:
|
||||
"""当插件被禁用、重载插件时会调用这个方法"""
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
|
||||
@@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from asyncio import Queue
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
from deprecated import deprecated
|
||||
|
||||
@@ -12,14 +14,12 @@ from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.cron.manager import CronJobManager
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from astrbot.core.platform import Platform
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
||||
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType
|
||||
from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager
|
||||
@@ -45,6 +45,15 @@ from .star_handler import EventType, StarHandlerMetadata, star_handlers_registry
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.cron.manager import CronJobManager
|
||||
else:
|
||||
CronJobManager = Any
|
||||
|
||||
|
||||
class PlatformManagerProtocol(Protocol):
|
||||
platform_insts: list[Platform]
|
||||
|
||||
|
||||
class Context:
|
||||
"""暴露给插件的接口上下文。"""
|
||||
@@ -61,7 +70,7 @@ class Context:
|
||||
config: AstrBotConfig,
|
||||
db: BaseDatabase,
|
||||
provider_manager: ProviderManager,
|
||||
platform_manager: PlatformManager,
|
||||
platform_manager: PlatformManagerProtocol,
|
||||
conversation_manager: ConversationManager,
|
||||
message_history_manager: PlatformMessageHistoryManager,
|
||||
persona_manager: PersonaManager,
|
||||
@@ -448,6 +457,9 @@ class Context:
|
||||
if platform.meta().id == session.platform_name:
|
||||
await platform.send_by_session(session, message_chain)
|
||||
return True
|
||||
logger.warning(
|
||||
f"cannot find platform for session {str(session)}, message not sent"
|
||||
)
|
||||
return False
|
||||
|
||||
def add_llm_tools(self, *tools: FunctionTool) -> None:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import warnings
|
||||
|
||||
from astrbot.core.star import StarMetadata, star_map
|
||||
from astrbot.core.star.star import StarMetadata, star_map
|
||||
|
||||
_warned_register_star = False
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from astrbot.core.agent.agent import Agent
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
@@ -617,7 +616,7 @@ class RegisteringAgent:
|
||||
kwargs["registering_agent"] = self
|
||||
return register_llm_tool(*args, **kwargs)
|
||||
|
||||
def __init__(self, agent: Agent[AstrAgentContext]) -> None:
|
||||
def __init__(self, agent: Agent[Any]) -> None:
|
||||
self._agent = agent
|
||||
|
||||
|
||||
@@ -625,7 +624,7 @@ def register_agent(
|
||||
name: str,
|
||||
instruction: str,
|
||||
tools: list[str | FunctionTool] | None = None,
|
||||
run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
|
||||
run_hooks: BaseAgentRunHooks[Any] | None = None,
|
||||
):
|
||||
"""注册一个 Agent
|
||||
|
||||
@@ -639,12 +638,12 @@ def register_agent(
|
||||
tools_ = tools or []
|
||||
|
||||
def decorator(awaitable: Callable[..., Awaitable[Any]]):
|
||||
AstrAgent = Agent[AstrAgentContext]
|
||||
AstrAgent = Agent[Any]
|
||||
agent = AstrAgent(
|
||||
name=name,
|
||||
instructions=instruction,
|
||||
tools=tools_,
|
||||
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
|
||||
run_hooks=run_hooks or BaseAgentRunHooks[Any](),
|
||||
)
|
||||
handoff_tool = HandoffTool(agent=agent)
|
||||
handoff_tool.handler = awaitable
|
||||
|
||||
@@ -49,10 +49,13 @@ class PluginVersionIncompatibleError(Exception):
|
||||
|
||||
class PluginManager:
|
||||
def __init__(self, context: Context, config: AstrBotConfig) -> None:
|
||||
from .star_tools import StarTools
|
||||
|
||||
self.updator = PluginUpdator()
|
||||
|
||||
self.context = context
|
||||
self.context._star_manager = self # type: ignore
|
||||
StarTools.initialize(context)
|
||||
|
||||
self.config = config
|
||||
self.plugin_store_path = get_astrbot_plugin_path()
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
## What's Changed
|
||||
|
||||
### 新增
|
||||
- 新增 Agent 会话停止能力,并优化 stop 请求处理流程,支持 /stop 指令终止 Agent 运行并尽量不丢失已运行输出的结果。 ([#5380](https://github.com/AstrBotDevs/AstrBot/issues/5380))。
|
||||
- 新增 SubAgent 交接场景下的 computer-use 工具支持 ([#5399](https://github.com/AstrBotDevs/AstrBot/issues/5399))。
|
||||
- 新增 Agent 执行过程中展示工具调用结果的能力,提升执行过程可观测性 ([#5388](https://github.com/AstrBotDevs/AstrBot/issues/5388))。
|
||||
- 新增插件加载/卸载 Hook,扩展插件生命周期能力 ([#5331](https://github.com/AstrBotDevs/AstrBot/issues/5331))。
|
||||
- 新增插件加载失败后的热重载能力,提升插件开发与恢复效率 ([#5334](https://github.com/AstrBotDevs/AstrBot/issues/5334))。
|
||||
- 新增 SubAgent 图片 URL/本地路径输入支持 ([#5348](https://github.com/AstrBotDevs/AstrBot/issues/5348))。
|
||||
- 新增 Dashboard 发布跳转基础 URL 可配置项 ([#5330](https://github.com/AstrBotDevs/AstrBot/issues/5330))。
|
||||
|
||||
### 修复
|
||||
- 修复 Tavily 请求的硬编码 6 秒超时。
|
||||
- 修复 OneBot v11 适配器关闭之后仍然在连接的问题([#5412](https://github.com/AstrBotDevs/AstrBot/issues/5412))。
|
||||
- 修复上下文会话中平台缺失时的日志处理,补充 warning 并改进排查信息。
|
||||
- 修复 embedding 维度未透传到 provider API 的问题 ([#5411](https://github.com/AstrBotDevs/AstrBot/issues/5411))。
|
||||
- 修复 File 组件处理逻辑并增强 OneBot 驱动层路径兼容性 ([#5391](https://github.com/AstrBotDevs/AstrBot/issues/5391))。
|
||||
- 修复 sandbox 文件传输工具缺少管理员权限校验的问题 ([#5402](https://github.com/AstrBotDevs/AstrBot/issues/5402))。
|
||||
- 修复 pipeline 与 `from ... import *` 引发的循环依赖问题 ([#5353](https://github.com/AstrBotDevs/AstrBot/issues/5353))。
|
||||
- 修复配置文件存在 UTF-8 BOM 时的解析问题 ([#5376](https://github.com/AstrBotDevs/AstrBot/issues/5376))。
|
||||
- 修复 ChatUI 复制回滚路径缺失与错误提示不清晰的问题 ([#5352](https://github.com/AstrBotDevs/AstrBot/issues/5352))。
|
||||
- 修复保留插件目录处理逻辑,避免插件目录行为异常 ([#5369](https://github.com/AstrBotDevs/AstrBot/issues/5369))。
|
||||
- 修复 ChatUI 文件消息段无法持久化的问题 ([#5386](https://github.com/AstrBotDevs/AstrBot/issues/5386))。
|
||||
- 修复 `.dockerignore` 误排除 `changelogs` 目录的问题。
|
||||
- 修复 aiohttp 版本过新导致 qq-botpy 报错的问题 ([#5316](https://github.com/AstrBotDevs/AstrBot/issues/5316))。
|
||||
|
||||
### 优化
|
||||
- 完成 SubAgent 编排页面国际化,补齐多语言支持 ([#5400](https://github.com/AstrBotDevs/AstrBot/issues/5400))。
|
||||
- 增补消息事件处理相关测试,并完善测试框架的 fixtures/mocks 覆盖 ([#5355](https://github.com/AstrBotDevs/AstrBot/issues/5355), [#5354](https://github.com/AstrBotDevs/AstrBot/issues/5354))。
|
||||
|
||||
## What's Changed (EN)
|
||||
|
||||
### New Features
|
||||
- Added computer-use tools support in sub-agent handoff scenarios ([#5399](https://github.com/AstrBotDevs/AstrBot/issues/5399)).
|
||||
- Added support for displaying tool call results during agent execution for better observability ([#5388](https://github.com/AstrBotDevs/AstrBot/issues/5388)).
|
||||
- Added plugin load/unload hooks to extend plugin lifecycle capabilities ([#5331](https://github.com/AstrBotDevs/AstrBot/issues/5331)).
|
||||
- Added hot reload support when plugin loading fails, improving recovery during plugin development ([#5334](https://github.com/AstrBotDevs/AstrBot/issues/5334)).
|
||||
- Added image URL/local path input support for sub-agents ([#5348](https://github.com/AstrBotDevs/AstrBot/issues/5348)).
|
||||
- Added stop control for active agent sessions and improved stop request handling ([#5380](https://github.com/AstrBotDevs/AstrBot/issues/5380)).
|
||||
- Added configurable base URL for dashboard release redirects ([#5330](https://github.com/AstrBotDevs/AstrBot/issues/5330)).
|
||||
|
||||
### Fixes
|
||||
- Fixed logging behavior when platform information is missing in context sessions, with clearer warning and diagnostics.
|
||||
- Fixed missing embedding dimensions being passed to provider APIs ([#5411](https://github.com/AstrBotDevs/AstrBot/issues/5411)).
|
||||
- Fixed shutdown stability issues in the aiocqhttp adapter ([#5412](https://github.com/AstrBotDevs/AstrBot/issues/5412)).
|
||||
- Fixed File component handling and improved path compatibility in the OneBot driver layer ([#5391](https://github.com/AstrBotDevs/AstrBot/issues/5391)).
|
||||
- Fixed missing admin guard for sandbox file transfer tools ([#5402](https://github.com/AstrBotDevs/AstrBot/issues/5402)).
|
||||
- Fixed circular import issues related to pipeline and `from ... import *` usage ([#5353](https://github.com/AstrBotDevs/AstrBot/issues/5353)).
|
||||
- Fixed config parsing issues when files contain UTF-8 BOM ([#5376](https://github.com/AstrBotDevs/AstrBot/issues/5376)).
|
||||
- Fixed missing copy rollback path and unclear error messaging in ChatUI ([#5352](https://github.com/AstrBotDevs/AstrBot/issues/5352)).
|
||||
- Fixed reserved plugin directory handling to avoid abnormal plugin path behavior ([#5369](https://github.com/AstrBotDevs/AstrBot/issues/5369)).
|
||||
- Fixed ChatUI file segment persistence issues ([#5386](https://github.com/AstrBotDevs/AstrBot/issues/5386)).
|
||||
- Fixed accidental exclusion of the `changelogs` directory in `.dockerignore`.
|
||||
- Fixed compatibility issues caused by a hard-coded 6-second timeout in Tavily requests.
|
||||
- Fixed qq-botpy runtime errors caused by overly new aiohttp versions ([#5316](https://github.com/AstrBotDevs/AstrBot/issues/5316)).
|
||||
|
||||
### Improvements
|
||||
- Completed internationalization for the sub-agent orchestration page ([#5400](https://github.com/AstrBotDevs/AstrBot/issues/5400)).
|
||||
- Added broader message-event test coverage and improved fixtures/mocks in the test framework ([#5355](https://github.com/AstrBotDevs/AstrBot/issues/5355), [#5354](https://github.com/AstrBotDevs/AstrBot/issues/5354)).
|
||||
- Updated README content and applied repository-wide formatting cleanup (ruff format) ([#5375](https://github.com/AstrBotDevs/AstrBot/issues/5375)).
|
||||
@@ -58,6 +58,18 @@
|
||||
"guideStep2": "Install it and restart AstrBot.",
|
||||
"guideStep3": "If you use Docker, prefer the image update path."
|
||||
},
|
||||
"desktopApp": {
|
||||
"title": "Update Desktop App",
|
||||
"message": "Check and upgrade the AstrBot desktop application.",
|
||||
"currentVersion": "Current version: ",
|
||||
"latestVersion": "Latest version: ",
|
||||
"checking": "Checking desktop app updates...",
|
||||
"hasNewVersion": "A new version is available. Click confirm to upgrade.",
|
||||
"isLatest": "Already on the latest version",
|
||||
"installing": "Downloading and installing update. The app will restart automatically...",
|
||||
"checkFailed": "Failed to check updates. Please try again later.",
|
||||
"installFailed": "Upgrade failed. Please try again later."
|
||||
},
|
||||
"dashboardUpdate": {
|
||||
"title": "Update Dashboard to Latest Version Only",
|
||||
"currentVersion": "Current Version",
|
||||
|
||||
@@ -251,6 +251,10 @@
|
||||
"show_tool_use_status": {
|
||||
"description": "Output Function Call Status"
|
||||
},
|
||||
"show_tool_call_result": {
|
||||
"description": "Output Tool Call Results",
|
||||
"hint": "Only takes effect when \"Output Function Call Status\" is enabled, and shows at most 70 characters."
|
||||
},
|
||||
"sanitize_context_by_modalities": {
|
||||
"description": "Sanitize History by Modalities",
|
||||
"hint": "When enabled, sanitizes contexts before each LLM request by removing image blocks and tool-call structures that the current provider's modalities do not support (this changes what the model sees)."
|
||||
|
||||
@@ -8,11 +8,14 @@
|
||||
"refresh": "Refresh",
|
||||
"save": "Save",
|
||||
"add": "Add SubAgent",
|
||||
"delete": "Delete"
|
||||
"delete": "Delete",
|
||||
"close": "Close"
|
||||
},
|
||||
"switches": {
|
||||
"enable": "Enable SubAgent orchestration",
|
||||
"dedupe": "Deduplicate main LLM tools (hide tools duplicated by SubAgents)"
|
||||
"enableHint": "Enable sub-agent functionality",
|
||||
"dedupe": "Deduplicate main LLM tools (hide tools duplicated by SubAgents)",
|
||||
"dedupeHint": "Remove duplicate tools from main agent"
|
||||
},
|
||||
"description": {
|
||||
"disabled": "When off: SubAgent is disabled; the main LLM mounts tools via persona rules (all by default) and calls them directly.",
|
||||
@@ -29,7 +32,8 @@
|
||||
"transferPrefix": "transfer_to_{name}",
|
||||
"switchLabel": "Enable",
|
||||
"previewTitle": "Preview: handoff tool shown to the main LLM",
|
||||
"personaChip": "Persona: {id}"
|
||||
"personaChip": "Persona: {id}",
|
||||
"personaPreview": "PERSONA PREVIEW"
|
||||
},
|
||||
"form": {
|
||||
"nameLabel": "Agent name (used for transfer_to_{name})",
|
||||
@@ -49,6 +53,13 @@
|
||||
"nameDuplicate": "Duplicate SubAgent name: {name}",
|
||||
"personaMissing": "SubAgent {name} has no persona selected",
|
||||
"saveSuccess": "Saved successfully",
|
||||
"saveFailed": "Failed to save"
|
||||
"saveFailed": "Failed to save",
|
||||
"nameRequired": "Name is required",
|
||||
"namePattern": "Lowercase letters, numbers, underscore only"
|
||||
},
|
||||
"empty": {
|
||||
"title": "No Agents Configured",
|
||||
"subtitle": "Add a new sub-agent to get started",
|
||||
"action": "Create First Agent"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,6 +58,18 @@
|
||||
"guideStep2": "完成安装后重启 AstrBot。",
|
||||
"guideStep3": "如果你使用 Docker,请优先使用镜像更新方式。"
|
||||
},
|
||||
"desktopApp": {
|
||||
"title": "更新桌面应用",
|
||||
"message": "将检查并升级 AstrBot 桌面端程序。",
|
||||
"currentVersion": "当前版本:",
|
||||
"latestVersion": "最新版本:",
|
||||
"checking": "正在检查桌面应用更新...",
|
||||
"hasNewVersion": "发现新版本,可点击确认升级。",
|
||||
"isLatest": "已经是最新版本",
|
||||
"installing": "正在下载并安装更新,完成后将自动重启应用...",
|
||||
"checkFailed": "检查更新失败,请稍后重试。",
|
||||
"installFailed": "升级失败,请稍后重试。"
|
||||
},
|
||||
"dashboardUpdate": {
|
||||
"title": "单独更新管理面板到最新版本",
|
||||
"currentVersion": "当前版本",
|
||||
|
||||
@@ -254,6 +254,10 @@
|
||||
"show_tool_use_status": {
|
||||
"description": "输出函数调用状态"
|
||||
},
|
||||
"show_tool_call_result": {
|
||||
"description": "输出函数调用返回结果",
|
||||
"hint": "仅在启用“输出函数调用状态”时生效,且最多展示 70 个字符。"
|
||||
},
|
||||
"sanitize_context_by_modalities": {
|
||||
"description": "按模型能力清理历史上下文",
|
||||
"hint": "开启后,在每次请求 LLM 前会按当前模型提供商中所选择的模型能力删除对话中不支持的图片/工具调用结构(会改变模型看到的历史)"
|
||||
|
||||
@@ -8,11 +8,14 @@
|
||||
"refresh": "刷新",
|
||||
"save": "保存",
|
||||
"add": "新增 SubAgent",
|
||||
"delete": "删除"
|
||||
"delete": "删除",
|
||||
"close": "关闭"
|
||||
},
|
||||
"switches": {
|
||||
"enable": "启用 SubAgent 编排",
|
||||
"dedupe": "主 LLM 去重重复工具(与 SubAgent 重叠的工具将被隐藏)"
|
||||
"enableHint": "启用子代理功能",
|
||||
"dedupe": "主 LLM 去重重复工具(与 SubAgent 重叠的工具将被隐藏)",
|
||||
"dedupeHint": "从主代理中移除重复工具"
|
||||
},
|
||||
"description": {
|
||||
"disabled": "不启动:SubAgent 关闭;主 LLM 按 persona 规则挂载工具(默认全部),并直接调用。",
|
||||
@@ -39,6 +42,7 @@
|
||||
"providerHint": "留空表示跟随全局默认 provider。",
|
||||
"personaLabel": "选择人格设定",
|
||||
"personaHint": "SubAgent 将直接继承所选 Persona 的系统设定与工具。在人格设定页管理和新建人格。",
|
||||
"personaPreview": "人格预览",
|
||||
"descriptionLabel": "对主 LLM 的描述(用于决定是否 handoff)",
|
||||
"descriptionHint": "这段会作为 transfer_to_* 工具的描述给主 LLM 看,建议简短明确。"
|
||||
},
|
||||
@@ -50,6 +54,13 @@
|
||||
"nameDuplicate": "SubAgent 名称重复:{name}",
|
||||
"personaMissing": "SubAgent {name} 未选择 Persona",
|
||||
"saveSuccess": "保存成功",
|
||||
"saveFailed": "保存失败"
|
||||
"saveFailed": "保存失败",
|
||||
"nameRequired": "名称必填",
|
||||
"namePattern": "仅支持小写字母、数字和下划线"
|
||||
},
|
||||
"empty": {
|
||||
"title": "未配置 SubAgent",
|
||||
"subtitle": "添加一个新的子代理以开始",
|
||||
"action": "创建第一个 Agent"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,24 +50,27 @@ let installLoading = ref(false);
|
||||
const isDesktopReleaseMode = ref(
|
||||
typeof window !== 'undefined' && !!window.astrbotDesktop?.isDesktop
|
||||
);
|
||||
const redirectConfirmDialog = ref(false);
|
||||
const pendingRedirectUrl = ref('');
|
||||
const resolvingReleaseTarget = ref(false);
|
||||
const DEFAULT_ASTRBOT_RELEASE_BASE_URL = 'https://github.com/AstrBotDevs/AstrBot/releases';
|
||||
const resolveReleaseBaseUrl = () => {
|
||||
const raw = import.meta.env.VITE_ASTRBOT_RELEASE_BASE_URL;
|
||||
// Keep upstream default on AstrBot releases; desktop distributors can override via env injection.
|
||||
const normalized = raw?.trim()?.replace(/\/+$/, '') || '';
|
||||
const withoutLatestSuffix = normalized.replace(/\/latest$/i, '');
|
||||
return withoutLatestSuffix || DEFAULT_ASTRBOT_RELEASE_BASE_URL;
|
||||
};
|
||||
const releaseBaseUrl = resolveReleaseBaseUrl();
|
||||
const getReleaseUrlByTag = (tag: string | null | undefined) => {
|
||||
const normalizedTag = (tag || '').trim();
|
||||
if (!normalizedTag || normalizedTag.toLowerCase() === 'latest') {
|
||||
return `${releaseBaseUrl}/latest`;
|
||||
const desktopUpdateDialog = ref(false);
|
||||
const desktopUpdateChecking = ref(false);
|
||||
const desktopUpdateInstalling = ref(false);
|
||||
const desktopUpdateHasNewVersion = ref(false);
|
||||
const desktopUpdateCurrentVersion = ref('-');
|
||||
const desktopUpdateLatestVersion = ref('-');
|
||||
const desktopUpdateStatus = ref('');
|
||||
|
||||
const getAppUpdaterBridge = (): AstrBotAppUpdaterBridge | null => {
|
||||
if (typeof window === 'undefined') {
|
||||
return null;
|
||||
}
|
||||
return `${releaseBaseUrl}/tag/${normalizedTag}`;
|
||||
const bridge = window.astrbotAppUpdater;
|
||||
if (
|
||||
bridge &&
|
||||
typeof bridge.checkForAppUpdate === 'function' &&
|
||||
typeof bridge.installAppUpdate === 'function'
|
||||
) {
|
||||
return bridge;
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
const getSelectedGitHubProxy = () => {
|
||||
@@ -89,16 +92,6 @@ const releasesHeader = computed(() => [
|
||||
{ title: t('core.header.updateDialog.table.sourceUrl'), key: 'zipball_url' },
|
||||
{ title: t('core.header.updateDialog.table.actions'), key: 'switch' }
|
||||
]);
|
||||
const latestReleaseTag = computed(() => {
|
||||
const firstRelease = (releases.value as any[])?.[0];
|
||||
if (firstRelease?.tag_name) {
|
||||
return firstRelease.tag_name as string;
|
||||
}
|
||||
return hasNewVersion.value
|
||||
? t('core.header.updateDialog.redirectConfirm.latestLabel')
|
||||
: (botCurrVersion.value || '-');
|
||||
});
|
||||
|
||||
// Form validation
|
||||
const formValid = ref(true);
|
||||
const passwordRules = computed(() => [
|
||||
@@ -126,47 +119,88 @@ const accountEditStatus = ref({
|
||||
message: ''
|
||||
});
|
||||
|
||||
const open = (link: string) => {
|
||||
window.open(link, '_blank');
|
||||
};
|
||||
|
||||
function requestExternalRedirect(link: string) {
|
||||
pendingRedirectUrl.value = link;
|
||||
redirectConfirmDialog.value = true;
|
||||
function cancelDesktopUpdate() {
|
||||
if (desktopUpdateInstalling.value) {
|
||||
return;
|
||||
}
|
||||
desktopUpdateDialog.value = false;
|
||||
}
|
||||
|
||||
function cancelExternalRedirect() {
|
||||
redirectConfirmDialog.value = false;
|
||||
pendingRedirectUrl.value = '';
|
||||
}
|
||||
async function openDesktopUpdateDialog() {
|
||||
desktopUpdateDialog.value = true;
|
||||
desktopUpdateChecking.value = true;
|
||||
desktopUpdateInstalling.value = false;
|
||||
desktopUpdateHasNewVersion.value = false;
|
||||
desktopUpdateCurrentVersion.value = '-';
|
||||
desktopUpdateLatestVersion.value = '-';
|
||||
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.checking');
|
||||
|
||||
function confirmExternalRedirect() {
|
||||
const targetUrl = pendingRedirectUrl.value;
|
||||
cancelExternalRedirect();
|
||||
if (targetUrl) {
|
||||
open(targetUrl);
|
||||
const bridge = getAppUpdaterBridge();
|
||||
if (!bridge) {
|
||||
desktopUpdateChecking.value = false;
|
||||
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.checkFailed');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await bridge.checkForAppUpdate();
|
||||
if (!result?.ok) {
|
||||
desktopUpdateCurrentVersion.value = result?.currentVersion || '-';
|
||||
desktopUpdateLatestVersion.value =
|
||||
result?.latestVersion || result?.currentVersion || '-';
|
||||
desktopUpdateStatus.value =
|
||||
result?.reason || t('core.header.updateDialog.desktopApp.checkFailed');
|
||||
return;
|
||||
}
|
||||
|
||||
desktopUpdateCurrentVersion.value = result.currentVersion || '-';
|
||||
desktopUpdateLatestVersion.value =
|
||||
result.latestVersion || result.currentVersion || '-';
|
||||
desktopUpdateHasNewVersion.value = !!result.hasUpdate;
|
||||
desktopUpdateStatus.value = result.hasUpdate
|
||||
? t('core.header.updateDialog.desktopApp.hasNewVersion')
|
||||
: t('core.header.updateDialog.desktopApp.isLatest');
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.checkFailed');
|
||||
} finally {
|
||||
desktopUpdateChecking.value = false;
|
||||
}
|
||||
}
|
||||
|
||||
const getReleaseUrlForDesktop = () => {
|
||||
const firstRelease = (releases.value as any[])?.[0];
|
||||
if (firstRelease?.tag_name) {
|
||||
return getReleaseUrlByTag(firstRelease.tag_name as string);
|
||||
async function confirmDesktopUpdate() {
|
||||
if (!desktopUpdateHasNewVersion.value || desktopUpdateInstalling.value) {
|
||||
return;
|
||||
}
|
||||
if (hasNewVersion.value) return getReleaseUrlByTag('latest');
|
||||
const tag = botCurrVersion.value?.startsWith('v') ? botCurrVersion.value : 'latest';
|
||||
return getReleaseUrlByTag(tag);
|
||||
};
|
||||
|
||||
const bridge = getAppUpdaterBridge();
|
||||
if (!bridge) {
|
||||
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.installFailed');
|
||||
return;
|
||||
}
|
||||
|
||||
desktopUpdateInstalling.value = true;
|
||||
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.installing');
|
||||
|
||||
try {
|
||||
const result = await bridge.installAppUpdate();
|
||||
if (result?.ok) {
|
||||
desktopUpdateDialog.value = false;
|
||||
return;
|
||||
}
|
||||
desktopUpdateStatus.value =
|
||||
result?.reason || t('core.header.updateDialog.desktopApp.installFailed');
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
desktopUpdateStatus.value = t('core.header.updateDialog.desktopApp.installFailed');
|
||||
} finally {
|
||||
desktopUpdateInstalling.value = false;
|
||||
}
|
||||
}
|
||||
|
||||
function handleUpdateClick() {
|
||||
if (isDesktopReleaseMode.value) {
|
||||
requestExternalRedirect('');
|
||||
resolvingReleaseTarget.value = true;
|
||||
checkUpdate();
|
||||
void getReleases().finally(() => {
|
||||
pendingRedirectUrl.value = getReleaseUrlForDesktop() || getReleaseUrlByTag('latest');
|
||||
resolvingReleaseTarget.value = false;
|
||||
});
|
||||
void openDesktopUpdateDialog();
|
||||
return;
|
||||
}
|
||||
checkUpdate();
|
||||
@@ -680,40 +714,38 @@ onMounted(async () => {
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<v-dialog v-model="redirectConfirmDialog" max-width="460">
|
||||
<v-dialog v-model="desktopUpdateDialog" max-width="460">
|
||||
<v-card>
|
||||
<v-card-title class="text-h3 pa-4 pl-6 pb-0">
|
||||
{{ t('core.header.updateDialog.redirectConfirm.title') }}
|
||||
{{ t('core.header.updateDialog.desktopApp.title') }}
|
||||
</v-card-title>
|
||||
<v-card-text>
|
||||
<div class="mb-3">
|
||||
{{ t('core.header.updateDialog.redirectConfirm.message') }}
|
||||
{{ t('core.header.updateDialog.desktopApp.message') }}
|
||||
</div>
|
||||
<v-alert type="info" variant="tonal" density="compact">
|
||||
<div>
|
||||
{{ t('core.header.updateDialog.redirectConfirm.targetVersion') }}
|
||||
<strong v-if="!resolvingReleaseTarget">{{ latestReleaseTag }}</strong>
|
||||
<v-progress-circular v-else indeterminate size="16" width="2" class="ml-1" />
|
||||
{{ t('core.header.updateDialog.desktopApp.currentVersion') }}
|
||||
<strong>{{ desktopUpdateCurrentVersion }}</strong>
|
||||
</div>
|
||||
<div class="text-caption">
|
||||
{{ t('core.header.updateDialog.redirectConfirm.currentVersion') }}
|
||||
{{ botCurrVersion || '-' }}
|
||||
<div>
|
||||
{{ t('core.header.updateDialog.desktopApp.latestVersion') }}
|
||||
<strong v-if="!desktopUpdateChecking">{{ desktopUpdateLatestVersion }}</strong>
|
||||
<v-progress-circular v-else indeterminate size="16" width="2" class="ml-1" />
|
||||
</div>
|
||||
</v-alert>
|
||||
<div class="text-caption mt-3">
|
||||
<div>{{ t('core.header.updateDialog.redirectConfirm.guideTitle') }}</div>
|
||||
<div>1. {{ t('core.header.updateDialog.redirectConfirm.guideStep1') }}</div>
|
||||
<div>2. {{ t('core.header.updateDialog.redirectConfirm.guideStep2') }}</div>
|
||||
<div>3. {{ t('core.header.updateDialog.redirectConfirm.guideStep3') }}</div>
|
||||
{{ desktopUpdateStatus }}
|
||||
</div>
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="grey" variant="text" @click="cancelExternalRedirect">
|
||||
<v-btn color="grey" variant="text" @click="cancelDesktopUpdate" :disabled="desktopUpdateInstalling">
|
||||
{{ t('core.common.dialog.cancelButton') }}
|
||||
</v-btn>
|
||||
<v-btn color="primary" variant="flat" @click="confirmExternalRedirect"
|
||||
:loading="resolvingReleaseTarget" :disabled="resolvingReleaseTarget || !pendingRedirectUrl">
|
||||
<v-btn color="primary" variant="flat" @click="confirmDesktopUpdate"
|
||||
:loading="desktopUpdateInstalling"
|
||||
:disabled="desktopUpdateChecking || desktopUpdateInstalling || !desktopUpdateHasNewVersion">
|
||||
{{ t('core.common.dialog.confirmButton') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
|
||||
+19
@@ -1,7 +1,26 @@
|
||||
export {};
|
||||
|
||||
declare global {
|
||||
interface AstrBotDesktopAppUpdateCheckResult {
|
||||
ok: boolean;
|
||||
reason?: string | null;
|
||||
currentVersion?: string;
|
||||
latestVersion?: string | null;
|
||||
hasUpdate: boolean;
|
||||
}
|
||||
|
||||
interface AstrBotDesktopAppUpdateResult {
|
||||
ok: boolean;
|
||||
reason?: string | null;
|
||||
}
|
||||
|
||||
interface AstrBotAppUpdaterBridge {
|
||||
checkForAppUpdate: () => Promise<AstrBotDesktopAppUpdateCheckResult>;
|
||||
installAppUpdate: () => Promise<AstrBotDesktopAppUpdateResult>;
|
||||
}
|
||||
|
||||
interface Window {
|
||||
astrbotAppUpdater?: AstrBotAppUpdaterBridge;
|
||||
astrbotDesktop?: {
|
||||
isDesktop: boolean;
|
||||
isDesktopRuntime: () => Promise<boolean>;
|
||||
@@ -62,7 +62,7 @@
|
||||
<template #label>
|
||||
<div class="d-flex flex-column">
|
||||
<span class="text-body-2 font-weight-medium">{{ tm('switches.enable') }}</span>
|
||||
<span class="text-caption text-medium-emphasis">Enable sub-agent functionality</span>
|
||||
<span class="text-caption text-medium-emphasis">{{ tm('switches.enableHint') }}</span>
|
||||
</div>
|
||||
</template>
|
||||
</v-switch>
|
||||
@@ -80,7 +80,7 @@
|
||||
<template #label>
|
||||
<div class="d-flex flex-column">
|
||||
<span class="text-body-2 font-weight-medium">{{ tm('switches.dedupe') }}</span>
|
||||
<span class="text-caption text-medium-emphasis">Remove duplicate tools from main agent</span>
|
||||
<span class="text-caption text-medium-emphasis">{{ tm('switches.dedupeHint') }}</span>
|
||||
</div>
|
||||
</template>
|
||||
</v-switch>
|
||||
@@ -166,7 +166,7 @@
|
||||
<v-text-field
|
||||
v-model="agent.name"
|
||||
:label="tm('form.nameLabel')"
|
||||
:rules="[v => !!v || 'Name is required', v => /^[a-z][a-z0-9_]*$/.test(v) || 'Lowercase letters, numbers, underscore only']"
|
||||
:rules="[v => !!v || tm('messages.nameRequired'), v => /^[a-z][a-z0-9_]*$/.test(v) || tm('messages.namePattern')]"
|
||||
variant="outlined"
|
||||
density="comfortable"
|
||||
hide-details="auto"
|
||||
@@ -215,7 +215,7 @@
|
||||
<v-col cols="12" md="6">
|
||||
<div class="h-100">
|
||||
<div class="text-caption font-weight-bold text-medium-emphasis mb-2 ml-1">
|
||||
PERSONA PREVIEW
|
||||
{{ tm('cards.personaPreview') }}
|
||||
</div>
|
||||
<PersonaQuickPreview
|
||||
:model-value="agent.persona_id"
|
||||
@@ -231,17 +231,17 @@
|
||||
<!-- Empty State -->
|
||||
<div v-if="cfg.agents.length === 0" class="d-flex flex-column align-center justify-center py-12 text-medium-emphasis">
|
||||
<v-icon icon="mdi-robot-off" size="64" class="mb-4 opacity-50" />
|
||||
<div class="text-h6">No Agents Configured</div>
|
||||
<div class="text-body-2 mb-4">Add a new sub-agent to get started</div>
|
||||
<div class="text-h6">{{ tm('empty.title') }}</div>
|
||||
<div class="text-body-2 mb-4">{{ tm('empty.subtitle') }}</div>
|
||||
<v-btn color="primary" variant="tonal" @click="addAgent">
|
||||
Create First Agent
|
||||
{{ tm('empty.action') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<v-snackbar v-model="snackbar.show" :color="snackbar.color" timeout="3000" location="top">
|
||||
{{ snackbar.message }}
|
||||
<template #actions>
|
||||
<v-btn variant="text" @click="snackbar.show = false">Close</v-btn>
|
||||
<v-btn variant="text" @click="snackbar.show = false">{{ tm('actions.close') }}</v-btn>
|
||||
</template>
|
||||
</v-snackbar>
|
||||
</div>
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "4.18.1"
|
||||
version = "4.18.2"
|
||||
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
|
||||
@@ -0,0 +1,381 @@
|
||||
"""
|
||||
AstrBot 测试配置
|
||||
|
||||
提供共享的 pytest fixtures 和测试工具。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from asyncio import Queue
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
# 使用 tests/fixtures/helpers.py 中的共享工具函数,避免重复定义
|
||||
from tests.fixtures.helpers import create_mock_llm_response, create_mock_message_component
|
||||
|
||||
# 将项目根目录添加到 sys.path
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
# 设置测试环境变量
|
||||
os.environ.setdefault("TESTING", "true")
|
||||
os.environ.setdefault("ASTRBOT_TEST_MODE", "true")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 测试收集和排序
|
||||
# ============================================================
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(session, config, items): # noqa: ARG001
|
||||
"""重新排序测试:单元测试优先,集成测试在后。"""
|
||||
unit_tests = []
|
||||
integration_tests = []
|
||||
deselected = []
|
||||
profile = config.getoption("--test-profile") or os.environ.get(
|
||||
"ASTRBOT_TEST_PROFILE", "all"
|
||||
)
|
||||
|
||||
for item in items:
|
||||
item_path = Path(str(item.path))
|
||||
is_integration = "integration" in item_path.parts
|
||||
|
||||
if is_integration:
|
||||
if item.get_closest_marker("integration") is None:
|
||||
item.add_marker(pytest.mark.integration)
|
||||
item.add_marker(pytest.mark.tier_d)
|
||||
integration_tests.append(item)
|
||||
else:
|
||||
if item.get_closest_marker("unit") is None:
|
||||
item.add_marker(pytest.mark.unit)
|
||||
if any(
|
||||
item.get_closest_marker(marker) is not None
|
||||
for marker in ("platform", "provider", "slow")
|
||||
):
|
||||
item.add_marker(pytest.mark.tier_c)
|
||||
unit_tests.append(item)
|
||||
|
||||
# 单元测试 -> 集成测试
|
||||
ordered_items = unit_tests + integration_tests
|
||||
if profile == "blocking":
|
||||
selected_items = []
|
||||
for item in ordered_items:
|
||||
if item.get_closest_marker("tier_c") or item.get_closest_marker("tier_d"):
|
||||
deselected.append(item)
|
||||
else:
|
||||
selected_items.append(item)
|
||||
if deselected:
|
||||
config.hook.pytest_deselected(items=deselected)
|
||||
items[:] = selected_items
|
||||
return
|
||||
|
||||
items[:] = ordered_items
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
"""增加测试执行档位选择。"""
|
||||
parser.addoption(
|
||||
"--test-profile",
|
||||
action="store",
|
||||
default=None,
|
||||
choices=["all", "blocking"],
|
||||
help="Select test profile. 'blocking' excludes auto-classified tier_c/tier_d tests.",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
"""注册自定义标记。"""
|
||||
config.addinivalue_line("markers", "unit: 单元测试")
|
||||
config.addinivalue_line("markers", "integration: 集成测试")
|
||||
config.addinivalue_line("markers", "slow: 慢速测试")
|
||||
config.addinivalue_line("markers", "platform: 平台适配器测试")
|
||||
config.addinivalue_line("markers", "provider: LLM Provider 测试")
|
||||
config.addinivalue_line("markers", "db: 数据库相关测试")
|
||||
config.addinivalue_line("markers", "tier_c: C-tier tests (optional / non-blocking)")
|
||||
config.addinivalue_line("markers", "tier_d: D-tier tests (extended / integration)")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 临时目录和文件 Fixtures
|
||||
# ============================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir(tmp_path: Path) -> Path:
|
||||
"""创建临时目录用于测试。"""
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_queue() -> Queue:
|
||||
"""Create a shared asyncio queue fixture for tests."""
|
||||
return Queue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def platform_settings() -> dict:
|
||||
"""Create a shared empty platform settings fixture for adapter tests."""
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_data_dir(temp_dir: Path) -> Path:
|
||||
"""创建模拟的 data 目录结构。"""
|
||||
data_dir = temp_dir / "data"
|
||||
data_dir.mkdir()
|
||||
|
||||
# 创建必要的子目录
|
||||
(data_dir / "config").mkdir()
|
||||
(data_dir / "plugins").mkdir()
|
||||
(data_dir / "temp").mkdir()
|
||||
(data_dir / "attachments").mkdir()
|
||||
|
||||
return data_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_config_file(temp_data_dir: Path) -> Path:
|
||||
"""创建临时配置文件。"""
|
||||
config_path = temp_data_dir / "config" / "cmd_config.json"
|
||||
default_config = {
|
||||
"provider": [],
|
||||
"platform": [],
|
||||
"provider_settings": {},
|
||||
"default_personality": None,
|
||||
"timezone": "Asia/Shanghai",
|
||||
}
|
||||
config_path.write_text(json.dumps(default_config, indent=2), encoding="utf-8")
|
||||
return config_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db_file(temp_data_dir: Path) -> Path:
|
||||
"""创建临时数据库文件路径。"""
|
||||
return temp_data_dir / "test.db"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Mock Fixtures
|
||||
# ============================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider():
|
||||
"""创建模拟的 Provider。"""
|
||||
provider = MagicMock()
|
||||
provider.provider_config = {
|
||||
"id": "test-provider",
|
||||
"type": "openai_chat_completion",
|
||||
"model": "gpt-4o-mini",
|
||||
}
|
||||
provider.get_model = MagicMock(return_value="gpt-4o-mini")
|
||||
provider.text_chat = AsyncMock()
|
||||
provider.text_chat_stream = AsyncMock()
|
||||
provider.terminate = AsyncMock()
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_platform():
|
||||
"""创建模拟的 Platform。"""
|
||||
platform = MagicMock()
|
||||
platform.platform_name = "test_platform"
|
||||
platform.platform_meta = MagicMock()
|
||||
platform.platform_meta.support_proactive_message = False
|
||||
platform.send_message = AsyncMock()
|
||||
platform.terminate = AsyncMock()
|
||||
return platform
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation():
|
||||
"""创建模拟的 Conversation。"""
|
||||
from astrbot.core.db.po import ConversationV2
|
||||
|
||||
return ConversationV2(
|
||||
conversation_id="test-conv-id",
|
||||
platform_id="test_platform",
|
||||
user_id="test_user",
|
||||
content=[],
|
||||
persona_id=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event():
|
||||
"""创建模拟的 AstrMessageEvent。"""
|
||||
event = MagicMock()
|
||||
event.unified_msg_origin = "test_umo"
|
||||
event.session_id = "test_session"
|
||||
event.message_str = "Hello, world!"
|
||||
event.message_obj = MagicMock()
|
||||
event.message_obj.message = []
|
||||
event.message_obj.sender = MagicMock()
|
||||
event.message_obj.sender.user_id = "test_user"
|
||||
event.message_obj.sender.nickname = "Test User"
|
||||
event.message_obj.group_id = None
|
||||
event.message_obj.group = None
|
||||
event.get_platform_name = MagicMock(return_value="test_platform")
|
||||
event.get_platform_id = MagicMock(return_value="test_platform")
|
||||
event.get_group_id = MagicMock(return_value=None)
|
||||
event.get_extra = MagicMock(return_value=None)
|
||||
event.set_extra = MagicMock()
|
||||
event.trace = MagicMock()
|
||||
event.platform_meta = MagicMock()
|
||||
event.platform_meta.support_proactive_message = False
|
||||
return event
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 配置 Fixtures
|
||||
# ============================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def astrbot_config(temp_config_file: Path):
|
||||
"""创建 AstrBotConfig 实例。"""
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
|
||||
config = AstrBotConfig()
|
||||
config._config_path = str(temp_config_file) # noqa: SLF001
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def main_agent_build_config():
|
||||
"""创建 MainAgentBuildConfig 实例。"""
|
||||
from astrbot.core.astr_main_agent import MainAgentBuildConfig
|
||||
|
||||
return MainAgentBuildConfig(
|
||||
tool_call_timeout=60,
|
||||
tool_schema_mode="full",
|
||||
provider_wake_prefix="",
|
||||
streaming_response=True,
|
||||
sanitize_context_by_modalities=False,
|
||||
kb_agentic_mode=False,
|
||||
file_extract_enabled=False,
|
||||
context_limit_reached_strategy="truncate_by_turns",
|
||||
llm_safety_mode=True,
|
||||
computer_use_runtime="local",
|
||||
add_cron_tools=True,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 数据库 Fixtures
|
||||
# ============================================================
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def temp_db(temp_db_file: Path):
|
||||
"""创建临时数据库实例。"""
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
|
||||
db = SQLiteDatabase(str(temp_db_file))
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
await db.engine.dispose()
|
||||
if temp_db_file.exists():
|
||||
temp_db_file.unlink()
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Context Fixtures
|
||||
# ============================================================
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def mock_context(
|
||||
astrbot_config,
|
||||
temp_db,
|
||||
mock_provider,
|
||||
mock_platform,
|
||||
):
|
||||
"""创建模拟的插件上下文。"""
|
||||
from asyncio import Queue
|
||||
|
||||
from astrbot.core.star.context import Context
|
||||
|
||||
event_queue = Queue()
|
||||
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_using_provider = MagicMock(return_value=mock_provider)
|
||||
provider_manager.get_provider_by_id = MagicMock(return_value=mock_provider)
|
||||
|
||||
platform_manager = MagicMock()
|
||||
conversation_manager = MagicMock()
|
||||
message_history_manager = MagicMock()
|
||||
persona_manager = MagicMock()
|
||||
persona_manager.personas_v3 = []
|
||||
astrbot_config_mgr = MagicMock()
|
||||
knowledge_base_manager = MagicMock()
|
||||
cron_manager = MagicMock()
|
||||
subagent_orchestrator = None
|
||||
|
||||
context = Context(
|
||||
event_queue,
|
||||
astrbot_config,
|
||||
temp_db,
|
||||
provider_manager,
|
||||
platform_manager,
|
||||
conversation_manager,
|
||||
message_history_manager,
|
||||
persona_manager,
|
||||
astrbot_config_mgr,
|
||||
knowledge_base_manager,
|
||||
cron_manager,
|
||||
subagent_orchestrator,
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Provider Request Fixtures
|
||||
# ============================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_request():
|
||||
"""创建 ProviderRequest 实例。"""
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
|
||||
return ProviderRequest(
|
||||
prompt="Hello",
|
||||
session_id="test_session",
|
||||
image_urls=[],
|
||||
contexts=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 跳过条件
|
||||
# ============================================================
|
||||
|
||||
|
||||
def pytest_runtest_setup(item):
|
||||
"""在测试运行前检查跳过条件。"""
|
||||
# 跳过需要 API Key 但未设置的 Provider 测试
|
||||
if item.get_closest_marker("provider"):
|
||||
if not os.environ.get("TEST_PROVIDER_API_KEY"):
|
||||
pytest.skip("TEST_PROVIDER_API_KEY not set")
|
||||
|
||||
# 跳过需要特定平台的测试
|
||||
if item.get_closest_marker("platform"):
|
||||
required_platform = None
|
||||
marker = item.get_closest_marker("platform")
|
||||
if marker and marker.args:
|
||||
required_platform = marker.args[0]
|
||||
|
||||
if required_platform and not os.environ.get(
|
||||
f"TEST_{required_platform.upper()}_ENABLED"
|
||||
):
|
||||
pytest.skip(f"TEST_{required_platform.upper()}_ENABLED not set")
|
||||
Vendored
+64
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
AstrBot 测试数据
|
||||
|
||||
此目录存放测试用的静态数据和配置文件。
|
||||
|
||||
目录结构:
|
||||
- fixtures/
|
||||
├── configs/ # 测试配置文件
|
||||
├── messages/ # 测试消息数据
|
||||
├── plugins/ # 测试插件
|
||||
├── knowledge_base/ # 测试知识库数据
|
||||
├── mocks/ # Mock 模块
|
||||
└── helpers.py # 辅助函数
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from .helpers import (
|
||||
NoopAwaitable,
|
||||
create_mock_discord_attachment,
|
||||
create_mock_discord_channel,
|
||||
create_mock_discord_user,
|
||||
create_mock_file,
|
||||
create_mock_llm_response,
|
||||
create_mock_message_component,
|
||||
create_mock_update,
|
||||
make_platform_config,
|
||||
)
|
||||
|
||||
FIXTURES_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
def load_fixture(filename: str) -> dict:
|
||||
"""加载 JSON 格式的测试数据。"""
|
||||
filepath = FIXTURES_DIR / filename
|
||||
if not filepath.exists():
|
||||
raise FileNotFoundError(f"Fixture not found: {filepath}")
|
||||
return json.loads(filepath.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def get_fixture_path(filename: str) -> Path:
|
||||
"""获取测试数据文件路径。"""
|
||||
filepath = FIXTURES_DIR / filename
|
||||
if not filepath.exists():
|
||||
raise FileNotFoundError(f"Fixture not found: {filepath}")
|
||||
return filepath
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FIXTURES_DIR",
|
||||
"load_fixture",
|
||||
"get_fixture_path",
|
||||
# 辅助函数
|
||||
"NoopAwaitable",
|
||||
"make_platform_config",
|
||||
"create_mock_update",
|
||||
"create_mock_file",
|
||||
"create_mock_discord_attachment",
|
||||
"create_mock_discord_user",
|
||||
"create_mock_discord_channel",
|
||||
"create_mock_message_component",
|
||||
"create_mock_llm_response",
|
||||
]
|
||||
+21
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"provider": [
|
||||
{
|
||||
"id": "test-openai",
|
||||
"type": "openai_chat_completion",
|
||||
"model": "gpt-4o-mini",
|
||||
"key": ["test-key"]
|
||||
}
|
||||
],
|
||||
"platform": [],
|
||||
"provider_settings": {
|
||||
"default_personality": null,
|
||||
"prompt_prefix": "",
|
||||
"image_caption_provider_id": "",
|
||||
"datetime_system_prompt": true,
|
||||
"identifier": true,
|
||||
"group_name_display": true
|
||||
},
|
||||
"default_personality": null,
|
||||
"timezone": "Asia/Shanghai"
|
||||
}
|
||||
Vendored
+332
@@ -0,0 +1,332 @@
|
||||
"""测试辅助函数和工具类。
|
||||
|
||||
提供统一的测试辅助工具,减少测试代码重复。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from astrbot.core.message.components import BaseMessageComponent
|
||||
|
||||
|
||||
class NoopAwaitable:
|
||||
"""可等待的空操作对象。
|
||||
|
||||
用于 mock 需要返回 awaitable 对象的方法。
|
||||
"""
|
||||
|
||||
def __await__(self):
|
||||
if False:
|
||||
yield
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 平台配置工厂
|
||||
# ============================================================
|
||||
|
||||
|
||||
def make_platform_config(platform_type: str, **kwargs) -> dict:
|
||||
"""平台配置工厂函数。
|
||||
|
||||
Args:
|
||||
platform_type: 平台类型 (telegram, discord, aiocqhttp 等)
|
||||
**kwargs: 覆盖默认配置的字段
|
||||
|
||||
Returns:
|
||||
dict: 平台配置字典
|
||||
"""
|
||||
configs = {
|
||||
"telegram": {
|
||||
"id": "test_telegram",
|
||||
"telegram_token": "test_token_123",
|
||||
"telegram_api_base_url": "https://api.telegram.org/bot",
|
||||
"telegram_file_base_url": "https://api.telegram.org/file/bot",
|
||||
"telegram_command_register": True,
|
||||
"telegram_command_auto_refresh": True,
|
||||
"telegram_command_register_interval": 300,
|
||||
"telegram_media_group_timeout": 2.5,
|
||||
"telegram_media_group_max_wait": 10.0,
|
||||
"start_message": "Welcome to AstrBot!",
|
||||
},
|
||||
"discord": {
|
||||
"id": "test_discord",
|
||||
"discord_token": "test_token_123",
|
||||
"discord_proxy": None,
|
||||
"discord_command_register": True,
|
||||
"discord_guild_id_for_debug": None,
|
||||
"discord_activity_name": "Playing AstrBot",
|
||||
},
|
||||
"aiocqhttp": {
|
||||
"id": "test_aiocqhttp",
|
||||
"ws_reverse_host": "0.0.0.0",
|
||||
"ws_reverse_port": 6199,
|
||||
"ws_reverse_token": "test_token",
|
||||
},
|
||||
"webchat": {
|
||||
"id": "test_webchat",
|
||||
},
|
||||
"wecom": {
|
||||
"id": "test_wecom",
|
||||
"wecom_corpid": "test_corpid",
|
||||
"wecom_secret": "test_secret",
|
||||
},
|
||||
}
|
||||
config = configs.get(platform_type, {"id": f"test_{platform_type}"}).copy()
|
||||
config.update(kwargs)
|
||||
return config
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Telegram 辅助函数
|
||||
# ============================================================
|
||||
|
||||
|
||||
def create_mock_update(
|
||||
message_text: str | None = "Hello World",
|
||||
chat_type: str = "private",
|
||||
chat_id: int = 123456789,
|
||||
user_id: int = 987654321,
|
||||
username: str = "test_user",
|
||||
message_id: int = 1,
|
||||
media_group_id: str | None = None,
|
||||
photo: list | None = None,
|
||||
video: MagicMock | None = None,
|
||||
document: MagicMock | None = None,
|
||||
voice: MagicMock | None = None,
|
||||
sticker: MagicMock | None = None,
|
||||
reply_to_message: MagicMock | None = None,
|
||||
caption: str | None = None,
|
||||
entities: list | None = None,
|
||||
caption_entities: list | None = None,
|
||||
message_thread_id: int | None = None,
|
||||
is_topic_message: bool = False,
|
||||
):
|
||||
"""创建模拟的 Telegram Update 对象。
|
||||
|
||||
Args:
|
||||
message_text: 消息文本
|
||||
chat_type: 聊天类型
|
||||
chat_id: 聊天 ID
|
||||
user_id: 用户 ID
|
||||
username: 用户名
|
||||
message_id: 消息 ID
|
||||
media_group_id: 媒体组 ID
|
||||
photo: 图片列表
|
||||
video: 视频对象
|
||||
document: 文档对象
|
||||
voice: 语音对象
|
||||
sticker: 贴纸对象
|
||||
reply_to_message: 回复的消息
|
||||
caption: 说明文字
|
||||
entities: 实体列表
|
||||
caption_entities: 说明实体列表
|
||||
message_thread_id: 消息线程 ID
|
||||
is_topic_message: 是否为主题消息
|
||||
|
||||
Returns:
|
||||
MagicMock: 模拟的 Update 对象
|
||||
"""
|
||||
update = MagicMock()
|
||||
update.update_id = 1
|
||||
|
||||
# Create message mock
|
||||
message = MagicMock()
|
||||
message.message_id = message_id
|
||||
message.chat = MagicMock()
|
||||
message.chat.id = chat_id
|
||||
message.chat.type = chat_type
|
||||
message.message_thread_id = message_thread_id
|
||||
message.is_topic_message = is_topic_message
|
||||
|
||||
# Create user mock
|
||||
from_user = MagicMock()
|
||||
from_user.id = user_id
|
||||
from_user.username = username
|
||||
message.from_user = from_user
|
||||
|
||||
# Set message content
|
||||
message.text = message_text
|
||||
message.media_group_id = media_group_id
|
||||
message.photo = photo
|
||||
message.video = video
|
||||
message.document = document
|
||||
message.voice = voice
|
||||
message.sticker = sticker
|
||||
message.reply_to_message = reply_to_message
|
||||
message.caption = caption
|
||||
message.entities = entities
|
||||
message.caption_entities = caption_entities
|
||||
|
||||
update.message = message
|
||||
update.effective_chat = message.chat
|
||||
|
||||
return update
|
||||
|
||||
|
||||
def create_mock_file(file_path: str = "https://api.telegram.org/file/test.jpg"):
|
||||
"""创建模拟的 Telegram File 对象。
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
Returns:
|
||||
MagicMock: 模拟的 File 对象
|
||||
"""
|
||||
file = MagicMock()
|
||||
file.file_path = file_path
|
||||
file.get_file = AsyncMock(return_value=file)
|
||||
return file
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Discord 辅助函数
|
||||
# ============================================================
|
||||
|
||||
|
||||
def create_mock_discord_attachment(
|
||||
filename: str = "test.txt",
|
||||
url: str = "https://cdn.discordapp.com/test.txt",
|
||||
content_type: str | None = None,
|
||||
size: int = 1024,
|
||||
):
|
||||
"""创建模拟的 Discord Attachment 对象。
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
url: 文件 URL
|
||||
content_type: 内容类型
|
||||
size: 文件大小
|
||||
|
||||
Returns:
|
||||
MagicMock: 模拟的 Attachment 对象
|
||||
"""
|
||||
attachment = MagicMock()
|
||||
attachment.filename = filename
|
||||
attachment.url = url
|
||||
attachment.content_type = content_type
|
||||
attachment.size = size
|
||||
return attachment
|
||||
|
||||
|
||||
def create_mock_discord_user(
|
||||
user_id: int = 123456789,
|
||||
name: str = "TestUser",
|
||||
display_name: str = "Test User",
|
||||
bot: bool = False,
|
||||
):
|
||||
"""创建模拟的 Discord User 对象。
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
name: 用户名
|
||||
display_name: 显示名
|
||||
bot: 是否为机器人
|
||||
|
||||
Returns:
|
||||
MagicMock: 模拟的 User 对象
|
||||
"""
|
||||
user = MagicMock()
|
||||
user.id = user_id
|
||||
user.name = name
|
||||
user.display_name = display_name
|
||||
user.bot = bot
|
||||
user.mention = f"<@{user_id}>"
|
||||
return user
|
||||
|
||||
|
||||
def create_mock_discord_channel(
|
||||
channel_id: int = 111222333,
|
||||
channel_type: str = "text",
|
||||
name: str = "general",
|
||||
guild_id: int | None = 444555666,
|
||||
):
|
||||
"""创建模拟的 Discord Channel 对象。
|
||||
|
||||
Args:
|
||||
channel_id: 频道 ID
|
||||
channel_type: 频道类型
|
||||
name: 频道名
|
||||
guild_id: 服务器 ID
|
||||
|
||||
Returns:
|
||||
MagicMock: 模拟的 Channel 对象
|
||||
"""
|
||||
channel = MagicMock()
|
||||
channel.id = channel_id
|
||||
channel.name = name
|
||||
channel.type = channel_type
|
||||
|
||||
if guild_id:
|
||||
channel.guild = MagicMock()
|
||||
channel.guild.id = guild_id
|
||||
else:
|
||||
channel.guild = None
|
||||
|
||||
return channel
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 消息组件辅助函数
|
||||
# ============================================================
|
||||
|
||||
|
||||
def create_mock_message_component(
|
||||
component_type: str,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessageComponent:
|
||||
"""创建模拟的消息组件。
|
||||
|
||||
Args:
|
||||
component_type: 组件类型 (plain, image, at, reply, file)
|
||||
**kwargs: 组件参数
|
||||
|
||||
Returns:
|
||||
BaseMessageComponent: 消息组件实例
|
||||
"""
|
||||
from astrbot.core.message import components as Comp
|
||||
|
||||
component_map = {
|
||||
"plain": Comp.Plain,
|
||||
"image": Comp.Image,
|
||||
"at": Comp.At,
|
||||
"reply": Comp.Reply,
|
||||
"file": Comp.File,
|
||||
}
|
||||
|
||||
component_class = component_map.get(component_type.lower())
|
||||
if not component_class:
|
||||
raise ValueError(f"Unknown component type: {component_type}")
|
||||
|
||||
return component_class(**kwargs)
|
||||
|
||||
|
||||
def create_mock_llm_response(
|
||||
completion_text: str = "Hello! How can I help you?",
|
||||
role: str = "assistant",
|
||||
tools_call_name: list[str] | None = None,
|
||||
tools_call_args: list[dict] | None = None,
|
||||
tools_call_ids: list[str] | None = None,
|
||||
):
|
||||
"""创建模拟的 LLM 响应。
|
||||
|
||||
Args:
|
||||
completion_text: 完成文本
|
||||
role: 角色
|
||||
tools_call_name: 工具调用名称列表
|
||||
tools_call_args: 工具调用参数列表
|
||||
tools_call_ids: 工具调用 ID 列表
|
||||
|
||||
Returns:
|
||||
LLMResponse: 模拟的 LLM 响应
|
||||
"""
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||
|
||||
return LLMResponse(
|
||||
role=role,
|
||||
completion_text=completion_text,
|
||||
tools_call_name=tools_call_name or [],
|
||||
tools_call_args=tools_call_args or [],
|
||||
tools_call_ids=tools_call_ids or [],
|
||||
usage=TokenUsage(input_other=10, output=5),
|
||||
)
|
||||
+33
@@ -0,0 +1,33 @@
|
||||
{
|
||||
"plain_message": {
|
||||
"type": "plain",
|
||||
"text": "Hello, this is a test message."
|
||||
},
|
||||
"image_message": {
|
||||
"type": "image",
|
||||
"url": "https://example.com/test.jpg",
|
||||
"file": null
|
||||
},
|
||||
"at_message": {
|
||||
"type": "at",
|
||||
"user_id": "12345",
|
||||
"nickname": "TestUser"
|
||||
},
|
||||
"reply_message": {
|
||||
"type": "reply",
|
||||
"id": "msg_123",
|
||||
"sender_nickname": "OriginalSender",
|
||||
"message_str": "This is the original message"
|
||||
},
|
||||
"file_message": {
|
||||
"type": "file",
|
||||
"name": "test.pdf",
|
||||
"url": "https://example.com/test.pdf"
|
||||
},
|
||||
"combined_message": {
|
||||
"components": [
|
||||
{"type": "at", "user_id": "bot_id"},
|
||||
{"type": "plain", "text": " Hello bot!"}
|
||||
]
|
||||
}
|
||||
}
|
||||
Vendored
+43
@@ -0,0 +1,43 @@
|
||||
"""测试 Mock 模块。
|
||||
|
||||
提供统一的 mock 工具和 fixture,减少测试代码重复。
|
||||
|
||||
使用方式:
|
||||
# 在测试文件顶部导入需要的 fixture
|
||||
from tests.fixtures.mocks import mock_telegram_modules
|
||||
|
||||
# 或使用 Builder 类创建 mock 对象
|
||||
from tests.fixtures.mocks import MockTelegramBuilder
|
||||
bot = MockTelegramBuilder.create_bot()
|
||||
"""
|
||||
|
||||
from .aiocqhttp import (
|
||||
MockAiocqhttpBuilder,
|
||||
create_mock_aiocqhttp_modules,
|
||||
mock_aiocqhttp_modules,
|
||||
)
|
||||
from .discord import (
|
||||
MockDiscordBuilder,
|
||||
create_mock_discord_modules,
|
||||
mock_discord_modules,
|
||||
)
|
||||
from .telegram import (
|
||||
MockTelegramBuilder,
|
||||
create_mock_telegram_modules,
|
||||
mock_telegram_modules,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Telegram
|
||||
"mock_telegram_modules",
|
||||
"create_mock_telegram_modules",
|
||||
"MockTelegramBuilder",
|
||||
# Discord
|
||||
"mock_discord_modules",
|
||||
"create_mock_discord_modules",
|
||||
"MockDiscordBuilder",
|
||||
# Aiocqhttp
|
||||
"mock_aiocqhttp_modules",
|
||||
"create_mock_aiocqhttp_modules",
|
||||
"MockAiocqhttpBuilder",
|
||||
]
|
||||
Vendored
+58
@@ -0,0 +1,58 @@
|
||||
"""Aiocqhttp 模块 Mock 工具。
|
||||
|
||||
提供统一的 aiocqhttp 相关模块 mock 设置,避免在测试文件中重复定义。
|
||||
"""
|
||||
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def create_mock_aiocqhttp_modules():
|
||||
"""创建 aiocqhttp 相关的 mock 模块。
|
||||
|
||||
Returns:
|
||||
dict: 包含 aiocqhttp 和相关模块的 mock 对象
|
||||
"""
|
||||
mock_aiocqhttp = MagicMock()
|
||||
mock_aiocqhttp.CQHttp = MagicMock
|
||||
mock_aiocqhttp.Event = MagicMock
|
||||
mock_aiocqhttp.exceptions = MagicMock()
|
||||
mock_aiocqhttp.exceptions.ActionFailed = Exception
|
||||
|
||||
return mock_aiocqhttp
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def mock_aiocqhttp_modules():
|
||||
"""Mock aiocqhttp 相关模块的 fixture。
|
||||
|
||||
自动应用于使用此 fixture 的测试模块。
|
||||
"""
|
||||
mock_aiocqhttp = create_mock_aiocqhttp_modules()
|
||||
monkeypatch = pytest.MonkeyPatch()
|
||||
|
||||
monkeypatch.setitem(sys.modules, "aiocqhttp", mock_aiocqhttp)
|
||||
monkeypatch.setitem(sys.modules, "aiocqhttp.exceptions", mock_aiocqhttp.exceptions)
|
||||
yield
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
class MockAiocqhttpBuilder:
|
||||
"""构建 aiocqhttp 测试 mock 对象的工具类。"""
|
||||
|
||||
@staticmethod
|
||||
def create_bot():
|
||||
"""创建 mock CQHttp bot 实例。"""
|
||||
from tests.fixtures.helpers import NoopAwaitable
|
||||
|
||||
bot = MagicMock()
|
||||
bot.send = AsyncMock()
|
||||
bot.call_action = AsyncMock()
|
||||
bot.on_request = MagicMock()
|
||||
bot.on_notice = MagicMock()
|
||||
bot.on_message = MagicMock()
|
||||
bot.on_websocket_connection = MagicMock()
|
||||
bot.run_task = MagicMock(return_value=NoopAwaitable())
|
||||
return bot
|
||||
Vendored
+140
@@ -0,0 +1,140 @@
|
||||
"""Discord 模块 Mock 工具。
|
||||
|
||||
提供统一的 Discord 相关模块 mock 设置,避免在测试文件中重复定义。
|
||||
"""
|
||||
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def create_mock_discord_modules():
|
||||
"""创建 Discord 相关的 mock 模块。
|
||||
|
||||
Returns:
|
||||
dict: 包含 discord 和相关模块的 mock 对象
|
||||
"""
|
||||
mock_discord = MagicMock()
|
||||
|
||||
# Mock discord.Intents
|
||||
mock_intents = MagicMock()
|
||||
mock_intents.default = MagicMock(return_value=mock_intents)
|
||||
mock_discord.Intents = mock_intents
|
||||
|
||||
# Mock discord.Status
|
||||
mock_discord.Status = MagicMock()
|
||||
mock_discord.Status.online = "online"
|
||||
|
||||
# Mock discord.Bot
|
||||
mock_bot = MagicMock()
|
||||
mock_discord.Bot = MagicMock(return_value=mock_bot)
|
||||
|
||||
# Mock discord.Embed
|
||||
mock_embed = MagicMock()
|
||||
mock_discord.Embed = MagicMock(return_value=mock_embed)
|
||||
|
||||
# Mock discord.ui
|
||||
mock_ui = MagicMock()
|
||||
mock_ui.View = MagicMock
|
||||
mock_ui.Button = MagicMock
|
||||
mock_discord.ui = mock_ui
|
||||
|
||||
# Mock discord.Message
|
||||
mock_discord.Message = MagicMock
|
||||
|
||||
# Mock discord.Interaction
|
||||
mock_discord.Interaction = MagicMock
|
||||
mock_discord.InteractionType = MagicMock()
|
||||
mock_discord.InteractionType.application_command = 2
|
||||
mock_discord.InteractionType.component = 3
|
||||
|
||||
# Mock discord.File
|
||||
mock_discord.File = MagicMock
|
||||
|
||||
# Mock discord.SlashCommand
|
||||
mock_discord.SlashCommand = MagicMock
|
||||
|
||||
# Mock discord.Option
|
||||
mock_discord.Option = MagicMock
|
||||
|
||||
# Mock discord.SlashCommandOptionType
|
||||
mock_discord.SlashCommandOptionType = MagicMock()
|
||||
mock_discord.SlashCommandOptionType.string = 3
|
||||
|
||||
# Mock discord.errors
|
||||
mock_discord.errors = MagicMock()
|
||||
mock_discord.errors.LoginFailure = Exception
|
||||
mock_discord.errors.ConnectionClosed = Exception
|
||||
mock_discord.errors.NotFound = Exception
|
||||
mock_discord.errors.Forbidden = Exception
|
||||
|
||||
# Mock discord.abc
|
||||
mock_discord.abc = MagicMock()
|
||||
mock_discord.abc.GuildChannel = MagicMock
|
||||
mock_discord.abc.Messageable = MagicMock
|
||||
mock_discord.abc.PrivateChannel = MagicMock
|
||||
|
||||
# Mock discord.channel
|
||||
mock_channel = MagicMock()
|
||||
mock_channel.DMChannel = MagicMock
|
||||
mock_discord.channel = mock_channel
|
||||
|
||||
# Mock discord.types
|
||||
mock_discord.types = MagicMock()
|
||||
mock_discord.types.interactions = MagicMock()
|
||||
|
||||
# Mock discord.ApplicationContext
|
||||
mock_discord.ApplicationContext = MagicMock
|
||||
|
||||
# Mock discord.CustomActivity
|
||||
mock_discord.CustomActivity = MagicMock
|
||||
|
||||
return mock_discord
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def mock_discord_modules():
|
||||
"""Mock Discord 相关模块的 fixture。
|
||||
|
||||
自动应用于使用此 fixture 的测试模块。
|
||||
"""
|
||||
mock_discord = create_mock_discord_modules()
|
||||
monkeypatch = pytest.MonkeyPatch()
|
||||
|
||||
monkeypatch.setitem(sys.modules, "discord", mock_discord)
|
||||
monkeypatch.setitem(sys.modules, "discord.abc", mock_discord.abc)
|
||||
monkeypatch.setitem(sys.modules, "discord.channel", mock_discord.channel)
|
||||
monkeypatch.setitem(sys.modules, "discord.errors", mock_discord.errors)
|
||||
monkeypatch.setitem(sys.modules, "discord.types", mock_discord.types)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"discord.types.interactions",
|
||||
mock_discord.types.interactions,
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "discord.ui", mock_discord.ui)
|
||||
yield
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
class MockDiscordBuilder:
|
||||
"""构建 Discord 测试 mock 对象的工具类。"""
|
||||
|
||||
@staticmethod
|
||||
def create_client():
|
||||
"""创建 mock Discord client 实例。"""
|
||||
client = MagicMock()
|
||||
client.user = MagicMock()
|
||||
client.user.id = 123456789
|
||||
client.user.display_name = "TestBot"
|
||||
client.user.name = "TestBot"
|
||||
client.get_channel = MagicMock()
|
||||
client.fetch_channel = AsyncMock()
|
||||
client.get_message = MagicMock()
|
||||
client.start = AsyncMock()
|
||||
client.close = AsyncMock()
|
||||
client.is_closed = MagicMock(return_value=False)
|
||||
client.add_application_command = MagicMock()
|
||||
client.sync_commands = AsyncMock()
|
||||
client.change_presence = AsyncMock()
|
||||
return client
|
||||
Vendored
+141
@@ -0,0 +1,141 @@
|
||||
"""Telegram 模块 Mock 工具。
|
||||
|
||||
提供统一的 Telegram 相关模块 mock 设置,避免在测试文件中重复定义。
|
||||
"""
|
||||
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def create_mock_telegram_modules():
|
||||
"""创建 Telegram 相关的 mock 模块。
|
||||
|
||||
Returns:
|
||||
dict: 包含 telegram 和相关模块的 mock 对象
|
||||
"""
|
||||
mock_telegram = MagicMock()
|
||||
mock_telegram.BotCommand = MagicMock
|
||||
mock_telegram.Update = MagicMock
|
||||
mock_telegram.constants = MagicMock()
|
||||
mock_telegram.constants.ChatType = MagicMock()
|
||||
mock_telegram.constants.ChatType.PRIVATE = "private"
|
||||
mock_telegram.constants.ChatAction = MagicMock()
|
||||
mock_telegram.constants.ChatAction.TYPING = "typing"
|
||||
mock_telegram.constants.ChatAction.UPLOAD_VOICE = "upload_voice"
|
||||
mock_telegram.constants.ChatAction.UPLOAD_DOCUMENT = "upload_document"
|
||||
mock_telegram.constants.ChatAction.UPLOAD_PHOTO = "upload_photo"
|
||||
mock_telegram.error = MagicMock()
|
||||
mock_telegram.error.BadRequest = Exception
|
||||
mock_telegram.ReactionTypeCustomEmoji = MagicMock
|
||||
mock_telegram.ReactionTypeEmoji = MagicMock
|
||||
|
||||
mock_telegram_ext = MagicMock()
|
||||
mock_telegram_ext.ApplicationBuilder = MagicMock
|
||||
mock_telegram_ext.ContextTypes = MagicMock
|
||||
mock_telegram_ext.ExtBot = MagicMock
|
||||
mock_telegram_ext.filters = MagicMock()
|
||||
mock_telegram_ext.filters.ALL = MagicMock()
|
||||
mock_telegram_ext.MessageHandler = MagicMock
|
||||
|
||||
# Mock telegramify_markdown
|
||||
mock_telegramify = MagicMock()
|
||||
mock_telegramify.markdownify = lambda text, **kwargs: text
|
||||
|
||||
# Mock apscheduler
|
||||
mock_apscheduler = MagicMock()
|
||||
mock_apscheduler.schedulers = MagicMock()
|
||||
mock_apscheduler.schedulers.asyncio = MagicMock()
|
||||
mock_apscheduler.schedulers.asyncio.AsyncIOScheduler = MagicMock
|
||||
mock_apscheduler.schedulers.background = MagicMock()
|
||||
mock_apscheduler.schedulers.background.BackgroundScheduler = MagicMock
|
||||
|
||||
return {
|
||||
"telegram": mock_telegram,
|
||||
"telegram.ext": mock_telegram_ext,
|
||||
"telegramify_markdown": mock_telegramify,
|
||||
"apscheduler": mock_apscheduler,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def mock_telegram_modules():
|
||||
"""Mock Telegram 相关模块的 fixture。
|
||||
|
||||
自动应用于使用此 fixture 的测试模块。
|
||||
"""
|
||||
mocks = create_mock_telegram_modules()
|
||||
monkeypatch = pytest.MonkeyPatch()
|
||||
|
||||
monkeypatch.setitem(sys.modules, "telegram", mocks["telegram"])
|
||||
monkeypatch.setitem(sys.modules, "telegram.constants", mocks["telegram"].constants)
|
||||
monkeypatch.setitem(sys.modules, "telegram.error", mocks["telegram"].error)
|
||||
monkeypatch.setitem(sys.modules, "telegram.ext", mocks["telegram.ext"])
|
||||
monkeypatch.setitem(sys.modules, "telegramify_markdown", mocks["telegramify_markdown"])
|
||||
monkeypatch.setitem(sys.modules, "apscheduler", mocks["apscheduler"])
|
||||
monkeypatch.setitem(
|
||||
sys.modules, "apscheduler.schedulers", mocks["apscheduler"].schedulers
|
||||
)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"apscheduler.schedulers.asyncio",
|
||||
mocks["apscheduler"].schedulers.asyncio,
|
||||
)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"apscheduler.schedulers.background",
|
||||
mocks["apscheduler"].schedulers.background,
|
||||
)
|
||||
yield
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
class MockTelegramBuilder:
|
||||
"""构建 Telegram 测试 mock 对象的工具类。"""
|
||||
|
||||
@staticmethod
|
||||
def create_bot():
|
||||
"""创建 mock Telegram bot 实例。"""
|
||||
bot = MagicMock()
|
||||
bot.username = "test_bot"
|
||||
bot.id = 12345678
|
||||
bot.base_url = "https://api.telegram.org/bottest_token_123/"
|
||||
bot.send_message = AsyncMock()
|
||||
bot.send_photo = AsyncMock()
|
||||
bot.send_document = AsyncMock()
|
||||
bot.send_voice = AsyncMock()
|
||||
bot.send_chat_action = AsyncMock()
|
||||
bot.delete_my_commands = AsyncMock()
|
||||
bot.set_my_commands = AsyncMock()
|
||||
bot.set_message_reaction = AsyncMock()
|
||||
bot.edit_message_text = AsyncMock()
|
||||
return bot
|
||||
|
||||
@staticmethod
|
||||
def create_application():
|
||||
"""创建 mock Telegram Application 实例。"""
|
||||
from tests.fixtures.helpers import NoopAwaitable
|
||||
|
||||
app = MagicMock()
|
||||
app.bot = MagicMock()
|
||||
app.bot.username = "test_bot"
|
||||
app.bot.base_url = "https://api.telegram.org/bottest_token_123/"
|
||||
app.initialize = AsyncMock()
|
||||
app.start = AsyncMock()
|
||||
app.stop = AsyncMock()
|
||||
app.add_handler = MagicMock()
|
||||
app.updater = MagicMock()
|
||||
app.updater.start_polling = MagicMock(return_value=NoopAwaitable())
|
||||
app.updater.stop = AsyncMock()
|
||||
return app
|
||||
|
||||
@staticmethod
|
||||
def create_scheduler():
|
||||
"""创建 mock APScheduler 实例。"""
|
||||
scheduler = MagicMock()
|
||||
scheduler.add_job = MagicMock()
|
||||
scheduler.start = MagicMock()
|
||||
scheduler.running = True
|
||||
scheduler.shutdown = MagicMock()
|
||||
return scheduler
|
||||
+40
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
测试插件 - 用于插件系统测试
|
||||
|
||||
这是一个最小化的测试插件,用于验证插件系统的功能。
|
||||
"""
|
||||
|
||||
from astrbot.api import llm_tool, star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
|
||||
|
||||
|
||||
@star.register("test_plugin", "AstrBot Team", "测试插件 - 用于插件系统测试", "1.0.0")
|
||||
class TestPlugin(star.Star):
|
||||
"""测试插件类"""
|
||||
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
super().__init__(context)
|
||||
self.initialized = True
|
||||
|
||||
async def terminate(self) -> None:
|
||||
"""插件终止"""
|
||||
self.initialized = False
|
||||
|
||||
@filter.command("test_cmd")
|
||||
async def test_command(self, event: AstrMessageEvent) -> None:
|
||||
"""测试命令处理器。"""
|
||||
event.set_result(MessageEventResult().message("测试命令执行成功"))
|
||||
|
||||
@llm_tool("test_tool")
|
||||
async def test_llm_tool(self, query: str) -> str:
|
||||
"""测试 LLM 工具。
|
||||
|
||||
Args:
|
||||
query(string): 查询内容。
|
||||
"""
|
||||
return f"测试工具执行成功: {query}"
|
||||
|
||||
@filter.regex(r"^test_regex_(.+)$")
|
||||
async def test_regex_handler(self, event: AstrMessageEvent) -> None:
|
||||
"""测试正则处理器。"""
|
||||
event.set_result(MessageEventResult().message("正则匹配成功"))
|
||||
Vendored
+5
@@ -0,0 +1,5 @@
|
||||
name: test_plugin
|
||||
description: 测试插件 - 用于插件系统测试
|
||||
version: 1.0.0
|
||||
author: AstrBot Team
|
||||
repo: https://github.com/test/test_plugin
|
||||
@@ -0,0 +1,115 @@
|
||||
"""Smoke tests for critical startup and import paths."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from astrbot.core.pipeline.bootstrap import ensure_builtin_stages_registered
|
||||
from astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal import (
|
||||
InternalAgentSubStage,
|
||||
)
|
||||
from astrbot.core.pipeline.process_stage.method.agent_sub_stages.third_party import (
|
||||
ThirdPartyAgentSubStage,
|
||||
)
|
||||
from astrbot.core.pipeline.stage import Stage, registered_stages
|
||||
from astrbot.core.pipeline.stage_order import STAGES_ORDER
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
|
||||
def _run_code_in_fresh_interpreter(code: str, failure_message: str) -> None:
|
||||
proc = subprocess.run(
|
||||
[sys.executable, "-c", code],
|
||||
cwd=REPO_ROOT,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
assert proc.returncode == 0, (
|
||||
f"{failure_message}\nstdout:\n{proc.stdout}\nstderr:\n{proc.stderr}\n"
|
||||
)
|
||||
|
||||
|
||||
def test_smoke_critical_imports_in_fresh_interpreter() -> None:
|
||||
code = (
|
||||
"import importlib;"
|
||||
"mods=["
|
||||
"'astrbot.core.core_lifecycle',"
|
||||
"'astrbot.core.astr_main_agent',"
|
||||
"'astrbot.core.pipeline.scheduler',"
|
||||
"'astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal',"
|
||||
"'astrbot.core.pipeline.process_stage.method.agent_sub_stages.third_party'"
|
||||
"];"
|
||||
"[importlib.import_module(m) for m in mods]"
|
||||
)
|
||||
_run_code_in_fresh_interpreter(code, "Smoke import check failed.")
|
||||
|
||||
|
||||
def test_smoke_pipeline_stage_registration_matches_order() -> None:
|
||||
ensure_builtin_stages_registered()
|
||||
stage_names = {cls.__name__ for cls in registered_stages}
|
||||
|
||||
assert set(STAGES_ORDER).issubset(stage_names)
|
||||
assert len(stage_names) == len(registered_stages)
|
||||
|
||||
|
||||
def test_smoke_agent_sub_stages_are_stage_subclasses() -> None:
|
||||
assert issubclass(InternalAgentSubStage, Stage)
|
||||
assert issubclass(ThirdPartyAgentSubStage, Stage)
|
||||
|
||||
|
||||
def test_pipeline_package_exports_remain_compatible() -> None:
|
||||
import astrbot.core.pipeline as pipeline
|
||||
|
||||
assert pipeline.ProcessStage is not None
|
||||
assert pipeline.RespondStage is not None
|
||||
assert isinstance(pipeline.STAGES_ORDER, list)
|
||||
assert "ProcessStage" in pipeline.STAGES_ORDER
|
||||
|
||||
|
||||
def test_builtin_stage_bootstrap_is_idempotent() -> None:
|
||||
ensure_builtin_stages_registered()
|
||||
before_count = len(registered_stages)
|
||||
stage_names = {cls.__name__ for cls in registered_stages}
|
||||
|
||||
expected_stage_names = {
|
||||
"WakingCheckStage",
|
||||
"WhitelistCheckStage",
|
||||
"SessionStatusCheckStage",
|
||||
"RateLimitStage",
|
||||
"ContentSafetyCheckStage",
|
||||
"PreProcessStage",
|
||||
"ProcessStage",
|
||||
"ResultDecorateStage",
|
||||
"RespondStage",
|
||||
}
|
||||
|
||||
assert expected_stage_names.issubset(stage_names)
|
||||
|
||||
ensure_builtin_stages_registered()
|
||||
assert len(registered_stages) == before_count
|
||||
|
||||
|
||||
def test_pipeline_import_is_stable_with_mocked_apscheduler() -> None:
|
||||
"""Regression: importing pipeline should not require cron/apscheduler modules."""
|
||||
code = (
|
||||
"import sys;"
|
||||
"from unittest.mock import MagicMock;"
|
||||
"mock_apscheduler = MagicMock();"
|
||||
"mock_apscheduler.schedulers = MagicMock();"
|
||||
"mock_apscheduler.schedulers.asyncio = MagicMock();"
|
||||
"mock_apscheduler.schedulers.background = MagicMock();"
|
||||
"sys.modules['apscheduler'] = mock_apscheduler;"
|
||||
"sys.modules['apscheduler.schedulers'] = mock_apscheduler.schedulers;"
|
||||
"sys.modules['apscheduler.schedulers.asyncio'] = mock_apscheduler.schedulers.asyncio;"
|
||||
"sys.modules['apscheduler.schedulers.background'] = mock_apscheduler.schedulers.background;"
|
||||
"import astrbot.core.pipeline as pipeline;"
|
||||
"assert pipeline.ProcessStage is not None;"
|
||||
"assert pipeline.RespondStage is not None"
|
||||
)
|
||||
_run_code_in_fresh_interpreter(
|
||||
code,
|
||||
"Pipeline import should not depend on real apscheduler package.",
|
||||
)
|
||||
@@ -0,0 +1,781 @@
|
||||
"""Tests for AstrMessageEvent class."""
|
||||
|
||||
import re
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.message.components import (
|
||||
At,
|
||||
AtAll,
|
||||
Face,
|
||||
Forward,
|
||||
Image,
|
||||
Plain,
|
||||
Reply,
|
||||
)
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.platform.platform_metadata import PlatformMetadata
|
||||
|
||||
|
||||
class ConcreteAstrMessageEvent(AstrMessageEvent):
|
||||
"""Concrete implementation of AstrMessageEvent for testing purposes."""
|
||||
|
||||
async def send(self, message):
|
||||
"""Send message implementation."""
|
||||
await super().send(message)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def platform_meta():
|
||||
"""Create platform metadata for testing."""
|
||||
return PlatformMetadata(
|
||||
name="test_platform",
|
||||
description="Test platform",
|
||||
id="test_platform_id",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def message_member():
|
||||
"""Create a message member for testing."""
|
||||
return MessageMember(user_id="user123", nickname="TestUser")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def astrbot_message(message_member):
|
||||
"""Create an AstrBotMessage for testing."""
|
||||
message = AstrBotMessage()
|
||||
message.type = MessageType.FRIEND_MESSAGE
|
||||
message.self_id = "bot123"
|
||||
message.session_id = "session123"
|
||||
message.message_id = "msg123"
|
||||
message.sender = message_member
|
||||
message.message = [Plain(text="Hello world")]
|
||||
message.message_str = "Hello world"
|
||||
message.raw_message = None
|
||||
return message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def astr_message_event(platform_meta, astrbot_message):
|
||||
"""Create an AstrMessageEvent instance for testing."""
|
||||
return ConcreteAstrMessageEvent(
|
||||
message_str="Hello world",
|
||||
message_obj=astrbot_message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
|
||||
|
||||
class TestAstrMessageEventInit:
|
||||
"""Tests for AstrMessageEvent initialization."""
|
||||
|
||||
def test_init_basic(self, astr_message_event):
|
||||
"""Test basic AstrMessageEvent initialization."""
|
||||
assert astr_message_event.message_str == "Hello world"
|
||||
assert astr_message_event.role == "member"
|
||||
assert astr_message_event.is_wake is False
|
||||
assert astr_message_event.is_at_or_wake_command is False
|
||||
assert astr_message_event._extras == {}
|
||||
assert astr_message_event._result is None
|
||||
assert astr_message_event.call_llm is False
|
||||
|
||||
def test_init_session(self, astr_message_event):
|
||||
"""Test session initialization."""
|
||||
assert astr_message_event.session_id == "session123"
|
||||
assert astr_message_event.session.platform_name == "test_platform_id"
|
||||
|
||||
def test_init_platform_reference(self, astr_message_event, platform_meta):
|
||||
"""Test platform reference initialization."""
|
||||
assert astr_message_event.platform_meta == platform_meta
|
||||
assert astr_message_event.platform == platform_meta # back compatibility
|
||||
|
||||
def test_init_created_at(self, astr_message_event):
|
||||
"""Test created_at timestamp is set."""
|
||||
assert astr_message_event.created_at is not None
|
||||
assert isinstance(astr_message_event.created_at, float)
|
||||
|
||||
def test_init_trace(self, astr_message_event):
|
||||
"""Test trace/span initialization."""
|
||||
assert astr_message_event.trace is not None
|
||||
assert astr_message_event.span is not None
|
||||
assert astr_message_event.trace == astr_message_event.span
|
||||
|
||||
|
||||
class TestUnifiedMsgOrigin:
|
||||
"""Tests for unified_msg_origin property."""
|
||||
|
||||
def test_unified_msg_origin_getter(self, astr_message_event):
|
||||
"""Test unified_msg_origin getter."""
|
||||
expected = "test_platform_id:FriendMessage:session123"
|
||||
assert astr_message_event.unified_msg_origin == expected
|
||||
|
||||
def test_unified_msg_origin_setter(self, astr_message_event):
|
||||
"""Test unified_msg_origin setter."""
|
||||
astr_message_event.unified_msg_origin = "new_platform:GroupMessage:new_session"
|
||||
|
||||
assert astr_message_event.session.platform_name == "new_platform"
|
||||
assert astr_message_event.session.session_id == "new_session"
|
||||
|
||||
|
||||
class TestSessionId:
|
||||
"""Tests for session_id property."""
|
||||
|
||||
def test_session_id_getter(self, astr_message_event):
|
||||
"""Test session_id getter."""
|
||||
assert astr_message_event.session_id == "session123"
|
||||
|
||||
def test_session_id_setter(self, astr_message_event):
|
||||
"""Test session_id setter."""
|
||||
astr_message_event.session_id = "new_session_id"
|
||||
|
||||
assert astr_message_event.session_id == "new_session_id"
|
||||
|
||||
|
||||
class TestGetPlatformInfo:
|
||||
"""Tests for platform info methods."""
|
||||
|
||||
def test_get_platform_name(self, astr_message_event):
|
||||
"""Test get_platform_name method."""
|
||||
assert astr_message_event.get_platform_name() == "test_platform"
|
||||
|
||||
def test_get_platform_id(self, astr_message_event):
|
||||
"""Test get_platform_id method."""
|
||||
assert astr_message_event.get_platform_id() == "test_platform_id"
|
||||
|
||||
|
||||
class TestGetMessageInfo:
|
||||
"""Tests for message info methods."""
|
||||
|
||||
def test_get_message_str(self, astr_message_event):
|
||||
"""Test get_message_str method."""
|
||||
assert astr_message_event.get_message_str() == "Hello world"
|
||||
|
||||
def test_get_message_str_none(self, platform_meta, astrbot_message):
|
||||
"""Test get_message_str keeps None when source message_str is None."""
|
||||
astrbot_message.message_str = None
|
||||
event = ConcreteAstrMessageEvent(
|
||||
message_str=None,
|
||||
message_obj=astrbot_message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
assert event.get_message_str() is None
|
||||
|
||||
def test_get_messages(self, astr_message_event):
|
||||
"""Test get_messages method."""
|
||||
messages = astr_message_event.get_messages()
|
||||
assert len(messages) == 1
|
||||
assert isinstance(messages[0], Plain)
|
||||
assert messages[0].text == "Hello world"
|
||||
|
||||
def test_get_message_type(self, astr_message_event):
|
||||
"""Test get_message_type method."""
|
||||
assert astr_message_event.get_message_type() == MessageType.FRIEND_MESSAGE
|
||||
|
||||
def test_get_session_id(self, astr_message_event):
|
||||
"""Test get_session_id method."""
|
||||
assert astr_message_event.get_session_id() == "session123"
|
||||
|
||||
def test_get_group_id_empty_for_private(self, astr_message_event):
|
||||
"""Test get_group_id returns empty for private messages."""
|
||||
assert astr_message_event.get_group_id() == ""
|
||||
|
||||
def test_get_self_id(self, astr_message_event):
|
||||
"""Test get_self_id method."""
|
||||
assert astr_message_event.get_self_id() == "bot123"
|
||||
|
||||
def test_get_sender_id(self, astr_message_event):
|
||||
"""Test get_sender_id method."""
|
||||
assert astr_message_event.get_sender_id() == "user123"
|
||||
|
||||
def test_get_sender_name(self, astr_message_event):
|
||||
"""Test get_sender_name method."""
|
||||
assert astr_message_event.get_sender_name() == "TestUser"
|
||||
|
||||
def test_get_sender_name_empty_when_none(self, platform_meta, astrbot_message):
|
||||
"""Test get_sender_name returns empty string when nickname is None."""
|
||||
astrbot_message.sender = MessageMember(user_id="user123", nickname=None)
|
||||
event = ConcreteAstrMessageEvent(
|
||||
message_str="test",
|
||||
message_obj=astrbot_message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
assert event.get_sender_name() == ""
|
||||
|
||||
def test_get_sender_name_coerces_non_string(self, platform_meta, astrbot_message):
|
||||
"""Test get_sender_name stringifies non-string nickname values."""
|
||||
astrbot_message.sender = MessageMember(user_id="user123", nickname=None)
|
||||
astrbot_message.sender.nickname = 12345
|
||||
event = ConcreteAstrMessageEvent(
|
||||
message_str="test",
|
||||
message_obj=astrbot_message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
assert event.get_sender_name() == "12345"
|
||||
|
||||
|
||||
class TestGetMessageOutline:
|
||||
"""Tests for get_message_outline method."""
|
||||
|
||||
def test_outline_plain_text(self, astr_message_event):
|
||||
"""Test outline with plain text message."""
|
||||
outline = astr_message_event.get_message_outline()
|
||||
assert "Hello world" in outline
|
||||
|
||||
def test_outline_with_image(self, platform_meta, astrbot_message):
|
||||
"""Test outline with image component."""
|
||||
astrbot_message.message = [
|
||||
Plain(text="Look at this"),
|
||||
Image(file="http://example.com/img.jpg"),
|
||||
]
|
||||
event = ConcreteAstrMessageEvent(
|
||||
message_str="Look at this",
|
||||
message_obj=astrbot_message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
outline = event.get_message_outline()
|
||||
assert "Look at this" in outline
|
||||
assert "[图片]" in outline
|
||||
|
||||
def test_outline_with_at(self, platform_meta, astrbot_message):
|
||||
"""Test outline with At component."""
|
||||
astrbot_message.message = [At(qq="12345"), Plain(text=" hello")]
|
||||
event = ConcreteAstrMessageEvent(
|
||||
message_str=" hello",
|
||||
message_obj=astrbot_message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
outline = event.get_message_outline()
|
||||
assert "[At:12345]" in outline
|
||||
|
||||
def test_outline_with_at_all(self, platform_meta, astrbot_message):
|
||||
"""Test outline with AtAll component."""
|
||||
astrbot_message.message = [AtAll()]
|
||||
event = ConcreteAstrMessageEvent(
|
||||
message_str="",
|
||||
message_obj=astrbot_message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
outline = event.get_message_outline()
|
||||
# AtAll format is "[At:all]" in the actual implementation
|
||||
assert "[At:" in outline and "all" in outline.lower()
|
||||
|
||||
def test_outline_with_face(self, platform_meta, astrbot_message):
|
||||
"""Test outline with Face component."""
|
||||
astrbot_message.message = [Face(id="123")]
|
||||
event = ConcreteAstrMessageEvent(
|
||||
message_str="",
|
||||
message_obj=astrbot_message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
outline = event.get_message_outline()
|
||||
assert "[表情:123]" in outline
|
||||
|
||||
def test_outline_with_forward(self, platform_meta, astrbot_message):
|
||||
"""Test outline with Forward component."""
|
||||
# Forward requires an id parameter
|
||||
astrbot_message.message = [Forward(id="test_forward_id")]
|
||||
event = ConcreteAstrMessageEvent(
|
||||
message_str="",
|
||||
message_obj=astrbot_message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
outline = event.get_message_outline()
|
||||
assert "[转发消息]" in outline
|
||||
|
||||
def test_outline_with_reply(self, platform_meta, astrbot_message):
|
||||
"""Test outline with Reply component."""
|
||||
# Reply requires an id parameter
|
||||
reply = Reply(id="test_reply_id")
|
||||
reply.message_str = "Original message"
|
||||
reply.sender_nickname = "Sender"
|
||||
astrbot_message.message = [reply, Plain(text=" reply")]
|
||||
event = ConcreteAstrMessageEvent(
|
||||
message_str=" reply",
|
||||
message_obj=astrbot_message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
outline = event.get_message_outline()
|
||||
assert "[引用消息(Sender: Original message)]" in outline
|
||||
|
||||
def test_outline_with_reply_no_message(self, platform_meta, astrbot_message):
|
||||
"""Test outline with Reply component without message_str."""
|
||||
# Reply requires an id parameter
|
||||
reply = Reply(id="test_reply_id")
|
||||
reply.message_str = None
|
||||
astrbot_message.message = [reply]
|
||||
event = ConcreteAstrMessageEvent(
|
||||
message_str="",
|
||||
message_obj=astrbot_message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
outline = event.get_message_outline()
|
||||
assert "[引用消息]" in outline
|
||||
|
||||
def test_outline_empty_chain(self, platform_meta, astrbot_message):
|
||||
"""Test outline with empty message chain."""
|
||||
astrbot_message.message = []
|
||||
event = ConcreteAstrMessageEvent(
|
||||
message_str="",
|
||||
message_obj=astrbot_message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
outline = event.get_message_outline()
|
||||
assert outline == ""
|
||||
|
||||
def test_outline_very_long_plain_text(self, platform_meta, astrbot_message):
|
||||
"""Test outline generation for very long plain text content."""
|
||||
long_text = "A" * 20000
|
||||
astrbot_message.message = [Plain(text=long_text)]
|
||||
event = ConcreteAstrMessageEvent(
|
||||
message_str=long_text,
|
||||
message_obj=astrbot_message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
outline = event.get_message_outline()
|
||||
assert outline.startswith("A")
|
||||
assert len(outline) >= 20000
|
||||
|
||||
|
||||
class TestExtras:
|
||||
"""Tests for extra information methods."""
|
||||
|
||||
def test_set_extra(self, astr_message_event):
|
||||
"""Test set_extra method."""
|
||||
astr_message_event.set_extra("key1", "value1")
|
||||
assert astr_message_event._extras["key1"] == "value1"
|
||||
|
||||
def test_get_extra_with_key(self, astr_message_event):
|
||||
"""Test get_extra with specific key."""
|
||||
astr_message_event.set_extra("key1", "value1")
|
||||
assert astr_message_event.get_extra("key1") == "value1"
|
||||
|
||||
def test_get_extra_with_default(self, astr_message_event):
|
||||
"""Test get_extra with default value."""
|
||||
result = astr_message_event.get_extra("nonexistent", "default_value")
|
||||
assert result == "default_value"
|
||||
|
||||
def test_get_extra_all(self, astr_message_event):
|
||||
"""Test get_extra without key returns all extras."""
|
||||
astr_message_event.set_extra("key1", "value1")
|
||||
astr_message_event.set_extra("key2", "value2")
|
||||
all_extras = astr_message_event.get_extra()
|
||||
assert all_extras == {"key1": "value1", "key2": "value2"}
|
||||
|
||||
def test_clear_extra(self, astr_message_event):
|
||||
"""Test clear_extra method."""
|
||||
astr_message_event.set_extra("key1", "value1")
|
||||
astr_message_event.clear_extra()
|
||||
assert astr_message_event._extras == {}
|
||||
|
||||
|
||||
class TestSetResult:
|
||||
"""Tests for set_result method."""
|
||||
|
||||
def test_set_result_with_message_event_result(self, astr_message_event):
|
||||
"""Test set_result with MessageEventResult object."""
|
||||
result = MessageEventResult().message("Test message")
|
||||
astr_message_event.set_result(result)
|
||||
|
||||
assert astr_message_event._result == result
|
||||
|
||||
def test_set_result_with_string(self, astr_message_event):
|
||||
"""Test set_result with string creates MessageEventResult."""
|
||||
astr_message_event.set_result("Test message")
|
||||
|
||||
assert astr_message_event._result is not None
|
||||
assert len(astr_message_event._result.chain) == 1
|
||||
assert isinstance(astr_message_event._result.chain[0], Plain)
|
||||
|
||||
def test_set_result_with_empty_chain(self, astr_message_event):
|
||||
"""Test set_result handles empty chain correctly."""
|
||||
result = MessageEventResult()
|
||||
# chain is already an empty list by default
|
||||
astr_message_event.set_result(result)
|
||||
|
||||
assert astr_message_event._result.chain == []
|
||||
|
||||
|
||||
class TestStopContinueEvent:
|
||||
"""Tests for stop_event and continue_event methods."""
|
||||
|
||||
def test_stop_event_creates_result_if_none(self, astr_message_event):
|
||||
"""Test stop_event creates result if none exists."""
|
||||
astr_message_event.stop_event()
|
||||
|
||||
assert astr_message_event._result is not None
|
||||
assert astr_message_event.is_stopped() is True
|
||||
|
||||
def test_stop_event_with_existing_result(self, astr_message_event):
|
||||
"""Test stop_event with existing result."""
|
||||
astr_message_event.set_result(MessageEventResult().message("Test"))
|
||||
astr_message_event.stop_event()
|
||||
|
||||
assert astr_message_event.is_stopped() is True
|
||||
|
||||
def test_continue_event_creates_result_if_none(self, astr_message_event):
|
||||
"""Test continue_event creates result if none exists."""
|
||||
astr_message_event.continue_event()
|
||||
|
||||
assert astr_message_event._result is not None
|
||||
assert astr_message_event.is_stopped() is False
|
||||
|
||||
def test_continue_event_with_existing_result(self, astr_message_event):
|
||||
"""Test continue_event with existing result."""
|
||||
astr_message_event.set_result(MessageEventResult().message("Test"))
|
||||
astr_message_event.stop_event()
|
||||
astr_message_event.continue_event()
|
||||
|
||||
assert astr_message_event.is_stopped() is False
|
||||
|
||||
def test_is_stopped_default_false(self, astr_message_event):
|
||||
"""Test is_stopped returns False by default."""
|
||||
assert astr_message_event.is_stopped() is False
|
||||
|
||||
|
||||
class TestIsPrivateChat:
|
||||
"""Tests for is_private_chat method."""
|
||||
|
||||
def test_is_private_chat_true(self, astr_message_event):
|
||||
"""Test is_private_chat returns True for friend message."""
|
||||
assert astr_message_event.is_private_chat() is True
|
||||
|
||||
def test_is_private_chat_false(self, platform_meta, astrbot_message):
|
||||
"""Test is_private_chat returns False for group message."""
|
||||
astrbot_message.type = MessageType.GROUP_MESSAGE
|
||||
event = ConcreteAstrMessageEvent(
|
||||
message_str="test",
|
||||
message_obj=astrbot_message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
assert event.is_private_chat() is False
|
||||
|
||||
|
||||
class TestIsWakeUp:
|
||||
"""Tests for is_wake_up method."""
|
||||
|
||||
def test_is_wake_up_default_false(self, astr_message_event):
|
||||
"""Test is_wake_up returns False by default."""
|
||||
assert astr_message_event.is_wake_up() is False
|
||||
|
||||
def test_is_wake_up_when_set(self, astr_message_event):
|
||||
"""Test is_wake_up returns True when is_wake is set."""
|
||||
astr_message_event.is_wake = True
|
||||
assert astr_message_event.is_wake_up() is True
|
||||
|
||||
|
||||
class TestIsAdmin:
|
||||
"""Tests for is_admin method."""
|
||||
|
||||
def test_is_admin_default_false(self, astr_message_event):
|
||||
"""Test is_admin returns False by default."""
|
||||
assert astr_message_event.is_admin() is False
|
||||
|
||||
def test_is_admin_when_admin(self, astr_message_event):
|
||||
"""Test is_admin returns True when role is admin."""
|
||||
astr_message_event.role = "admin"
|
||||
assert astr_message_event.is_admin() is True
|
||||
|
||||
|
||||
class TestProcessBuffer:
|
||||
"""Tests for process_buffer method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_buffer_splits_by_pattern(self, astr_message_event):
|
||||
"""Test process_buffer splits buffer by pattern."""
|
||||
buffer = "Line 1\nLine 2\nLine 3\nRemaining"
|
||||
pattern = re.compile(r".*\n")
|
||||
|
||||
with patch.object(
|
||||
astr_message_event, "send", new_callable=AsyncMock
|
||||
) as mock_send:
|
||||
result = await astr_message_event.process_buffer(buffer, pattern)
|
||||
|
||||
# Should have sent 3 lines and remaining should be "Remaining"
|
||||
assert mock_send.call_count == 3
|
||||
assert result == "Remaining"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_buffer_no_match(self, astr_message_event):
|
||||
"""Test process_buffer returns original when no match."""
|
||||
buffer = "No newlines here"
|
||||
pattern = re.compile(r"\n")
|
||||
|
||||
result = await astr_message_event.process_buffer(buffer, pattern)
|
||||
|
||||
assert result == "No newlines here"
|
||||
|
||||
|
||||
class TestResultHelpers:
|
||||
"""Tests for result helper methods."""
|
||||
|
||||
def test_make_result(self, astr_message_event):
|
||||
"""Test make_result creates empty MessageEventResult."""
|
||||
result = astr_message_event.make_result()
|
||||
assert isinstance(result, MessageEventResult)
|
||||
|
||||
def test_plain_result(self, astr_message_event):
|
||||
"""Test plain_result creates result with text."""
|
||||
result = astr_message_event.plain_result("Hello")
|
||||
|
||||
assert isinstance(result, MessageEventResult)
|
||||
assert len(result.chain) == 1
|
||||
assert isinstance(result.chain[0], Plain)
|
||||
assert result.chain[0].text == "Hello"
|
||||
|
||||
def test_image_result_url(self, astr_message_event):
|
||||
"""Test image_result with URL."""
|
||||
result = astr_message_event.image_result("http://example.com/image.jpg")
|
||||
|
||||
assert isinstance(result, MessageEventResult)
|
||||
assert len(result.chain) == 1
|
||||
assert isinstance(result.chain[0], Image)
|
||||
|
||||
def test_image_result_path(self, astr_message_event):
|
||||
"""Test image_result with file path."""
|
||||
result = astr_message_event.image_result("/path/to/image.jpg")
|
||||
|
||||
assert isinstance(result, MessageEventResult)
|
||||
assert len(result.chain) == 1
|
||||
assert isinstance(result.chain[0], Image)
|
||||
|
||||
|
||||
class TestGetResult:
|
||||
"""Tests for get_result and clear_result methods."""
|
||||
|
||||
def test_get_result_returns_none_by_default(self, astr_message_event):
|
||||
"""Test get_result returns None by default."""
|
||||
assert astr_message_event.get_result() is None
|
||||
|
||||
def test_get_result_returns_set_result(self, astr_message_event):
|
||||
"""Test get_result returns set result."""
|
||||
result = MessageEventResult().message("Test")
|
||||
astr_message_event.set_result(result)
|
||||
|
||||
assert astr_message_event.get_result() == result
|
||||
|
||||
def test_clear_result(self, astr_message_event):
|
||||
"""Test clear_result clears the result."""
|
||||
astr_message_event.set_result(MessageEventResult().message("Test"))
|
||||
astr_message_event.clear_result()
|
||||
|
||||
assert astr_message_event.get_result() is None
|
||||
|
||||
|
||||
class TestShouldCallLlm:
|
||||
"""Tests for should_call_llm method."""
|
||||
|
||||
def test_should_call_llm_default(self, astr_message_event):
|
||||
"""Test call_llm default is False."""
|
||||
assert astr_message_event.call_llm is False
|
||||
|
||||
def test_should_call_llm_when_set(self, astr_message_event):
|
||||
"""Test should_call_llm sets call_llm."""
|
||||
astr_message_event.should_call_llm(True)
|
||||
assert astr_message_event.call_llm is True
|
||||
|
||||
|
||||
class TestRequestLlm:
|
||||
"""Tests for request_llm method."""
|
||||
|
||||
def test_request_llm_basic(self, astr_message_event):
|
||||
"""Test request_llm creates ProviderRequest."""
|
||||
request = astr_message_event.request_llm(prompt="Hello")
|
||||
|
||||
assert request.prompt == "Hello"
|
||||
assert request.session_id == ""
|
||||
assert request.image_urls == []
|
||||
assert request.contexts == []
|
||||
|
||||
def test_request_llm_with_all_params(self, astr_message_event):
|
||||
"""Test request_llm with all parameters."""
|
||||
request = astr_message_event.request_llm(
|
||||
prompt="Hello",
|
||||
session_id="session123",
|
||||
image_urls=["http://example.com/img.jpg"],
|
||||
contexts=[{"role": "user", "content": "Hi"}],
|
||||
system_prompt="You are helpful",
|
||||
)
|
||||
|
||||
assert request.prompt == "Hello"
|
||||
assert request.session_id == "session123"
|
||||
assert request.image_urls == ["http://example.com/img.jpg"]
|
||||
assert request.contexts == [{"role": "user", "content": "Hi"}]
|
||||
assert request.system_prompt == "You are helpful"
|
||||
|
||||
|
||||
class TestSendStreaming:
|
||||
"""Tests for send_streaming method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_streaming_sets_has_send_oper(self, astr_message_event):
|
||||
"""Test send_streaming sets _has_send_oper flag."""
|
||||
assert astr_message_event._has_send_oper is False
|
||||
|
||||
async def generator():
|
||||
yield MessageEventResult().message("Test")
|
||||
|
||||
with patch(
|
||||
"astrbot.core.platform.astr_message_event.Metric.upload",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
await astr_message_event.send_streaming(generator())
|
||||
|
||||
assert astr_message_event._has_send_oper is True
|
||||
|
||||
|
||||
class TestSendTyping:
|
||||
"""Tests for send_typing method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_typing_default_empty(self, astr_message_event):
|
||||
"""Test send_typing default implementation is empty."""
|
||||
# Should not raise any exception
|
||||
await astr_message_event.send_typing()
|
||||
|
||||
|
||||
class TestReact:
|
||||
"""Tests for react method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_react_sends_emoji(self, astr_message_event):
|
||||
"""Test react sends emoji as message."""
|
||||
with patch.object(
|
||||
astr_message_event, "send", new_callable=AsyncMock
|
||||
) as mock_send:
|
||||
await astr_message_event.react("👍")
|
||||
|
||||
mock_send.assert_called_once()
|
||||
call_arg = mock_send.call_args[0][0]
|
||||
# MessageChain is a dataclass with chain attribute
|
||||
assert len(call_arg.chain) == 1
|
||||
assert isinstance(call_arg.chain[0], Plain)
|
||||
assert call_arg.chain[0].text == "👍"
|
||||
|
||||
|
||||
class TestGetGroup:
|
||||
"""Tests for get_group method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_group_returns_none_for_private(self, astr_message_event):
|
||||
"""Test get_group returns None for private chat."""
|
||||
result = await astr_message_event.get_group()
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_group_with_group_id_param(self, astr_message_event):
|
||||
"""Test get_group with group_id parameter."""
|
||||
# Default implementation returns None
|
||||
result = await astr_message_event.get_group(group_id="group123")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestMessageTypeHandling:
|
||||
"""Tests for message type handling edge cases."""
|
||||
|
||||
def test_message_type_from_valid_string(self, platform_meta):
|
||||
"""Valid MessageType string should be converted correctly."""
|
||||
message = AstrBotMessage()
|
||||
message.type = "FRIEND_MESSAGE"
|
||||
message.message = []
|
||||
event = ConcreteAstrMessageEvent(
|
||||
message_str="test",
|
||||
message_obj=message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
assert event.session.message_type == MessageType.FRIEND_MESSAGE
|
||||
assert event.get_message_type() == MessageType.FRIEND_MESSAGE
|
||||
|
||||
def test_message_type_from_invalid_string_defaults_to_friend(self, platform_meta):
|
||||
"""Invalid message type should default to FRIEND_MESSAGE."""
|
||||
message = AstrBotMessage()
|
||||
message.type = "InvalidMessageType"
|
||||
message.message = []
|
||||
event = ConcreteAstrMessageEvent(
|
||||
message_str="test",
|
||||
message_obj=message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
assert event.session.message_type == MessageType.FRIEND_MESSAGE
|
||||
assert event.get_message_type() == MessageType.FRIEND_MESSAGE
|
||||
|
||||
def test_message_type_from_none_defaults_to_friend(self, platform_meta):
|
||||
"""None message type should default to FRIEND_MESSAGE."""
|
||||
message = AstrBotMessage()
|
||||
message.type = None
|
||||
message.message = []
|
||||
event = ConcreteAstrMessageEvent(
|
||||
message_str="test",
|
||||
message_obj=message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
assert event.session.message_type == MessageType.FRIEND_MESSAGE
|
||||
assert event.get_message_type() == MessageType.FRIEND_MESSAGE
|
||||
|
||||
def test_message_type_from_integer_defaults_to_friend(self, platform_meta):
|
||||
"""Integer message type should default to FRIEND_MESSAGE."""
|
||||
message = AstrBotMessage()
|
||||
message.type = 123
|
||||
message.message = []
|
||||
event = ConcreteAstrMessageEvent(
|
||||
message_str="test",
|
||||
message_obj=message,
|
||||
platform_meta=platform_meta,
|
||||
session_id="session123",
|
||||
)
|
||||
assert event.session.message_type == MessageType.FRIEND_MESSAGE
|
||||
assert event.get_message_type() == MessageType.FRIEND_MESSAGE
|
||||
|
||||
|
||||
class TestDefensiveGetattr:
|
||||
"""Tests for defensive getattr behavior in AstrMessageEvent."""
|
||||
|
||||
def test_get_messages_without_message_attr(self, astr_message_event):
|
||||
"""get_messages should handle message_obj without 'message' attribute."""
|
||||
astr_message_event.message_obj = type("DummyMessage", (), {})()
|
||||
messages = astr_message_event.get_messages()
|
||||
assert isinstance(messages, list)
|
||||
|
||||
def test_get_message_type_without_type_attr(self, astr_message_event):
|
||||
"""get_message_type should handle message_obj without 'type' attribute."""
|
||||
astr_message_event.message_obj = type("DummyMessage", (), {})()
|
||||
message_type = astr_message_event.get_message_type()
|
||||
assert isinstance(message_type, MessageType)
|
||||
|
||||
def test_get_sender_fields_without_sender_attr(self, astr_message_event):
|
||||
"""get_sender_id and get_sender_name should handle missing 'sender'."""
|
||||
astr_message_event.message_obj = type("DummyMessage", (), {})()
|
||||
sender_id = astr_message_event.get_sender_id()
|
||||
sender_name = astr_message_event.get_sender_name()
|
||||
assert isinstance(sender_id, str)
|
||||
assert isinstance(sender_name, str)
|
||||
|
||||
def test_get_message_type_with_non_enum_type(self, astr_message_event):
|
||||
"""get_message_type should handle message_obj.type that is not a MessageType."""
|
||||
class DummyMessage:
|
||||
def __init__(self):
|
||||
self.type = "not_an_enum"
|
||||
self.message = []
|
||||
astr_message_event.message_obj = DummyMessage()
|
||||
message_type = astr_message_event.get_message_type()
|
||||
assert isinstance(message_type, MessageType)
|
||||
@@ -0,0 +1,268 @@
|
||||
"""Tests for AstrBotMessage and MessageMember classes."""
|
||||
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
from astrbot.core.message.components import Image, Plain
|
||||
from astrbot.core.platform.astrbot_message import AstrBotMessage, Group, MessageMember
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
|
||||
|
||||
class TestMessageMember:
|
||||
"""Tests for MessageMember dataclass."""
|
||||
|
||||
def test_message_member_creation_basic(self):
|
||||
"""Test creating a MessageMember with required fields."""
|
||||
member = MessageMember(user_id="user123")
|
||||
|
||||
assert member.user_id == "user123"
|
||||
assert member.nickname is None
|
||||
|
||||
def test_message_member_creation_with_nickname(self):
|
||||
"""Test creating a MessageMember with nickname."""
|
||||
member = MessageMember(user_id="user123", nickname="TestUser")
|
||||
|
||||
assert member.user_id == "user123"
|
||||
assert member.nickname == "TestUser"
|
||||
|
||||
def test_message_member_str_with_nickname(self):
|
||||
"""Test __str__ method with nickname."""
|
||||
member = MessageMember(user_id="user123", nickname="TestUser")
|
||||
result = str(member)
|
||||
|
||||
assert "User ID: user123" in result
|
||||
assert "Nickname: TestUser" in result
|
||||
|
||||
def test_message_member_str_without_nickname(self):
|
||||
"""Test __str__ method without nickname."""
|
||||
member = MessageMember(user_id="user123")
|
||||
result = str(member)
|
||||
|
||||
assert "User ID: user123" in result
|
||||
assert "Nickname: N/A" in result
|
||||
|
||||
|
||||
class TestGroup:
|
||||
"""Tests for Group dataclass."""
|
||||
|
||||
def test_group_creation_basic(self):
|
||||
"""Test creating a Group with required fields."""
|
||||
group = Group(group_id="group123")
|
||||
|
||||
assert group.group_id == "group123"
|
||||
assert group.group_name is None
|
||||
assert group.group_avatar is None
|
||||
assert group.group_owner is None
|
||||
assert group.group_admins is None
|
||||
assert group.members is None
|
||||
|
||||
def test_group_creation_with_all_fields(self):
|
||||
"""Test creating a Group with all fields."""
|
||||
members = [MessageMember(user_id="user1"), MessageMember(user_id="user2")]
|
||||
group = Group(
|
||||
group_id="group123",
|
||||
group_name="Test Group",
|
||||
group_avatar="http://example.com/avatar.jpg",
|
||||
group_owner="owner123",
|
||||
group_admins=["admin1", "admin2"],
|
||||
members=members,
|
||||
)
|
||||
|
||||
assert group.group_id == "group123"
|
||||
assert group.group_name == "Test Group"
|
||||
assert group.group_avatar == "http://example.com/avatar.jpg"
|
||||
assert group.group_owner == "owner123"
|
||||
assert group.group_admins == ["admin1", "admin2"]
|
||||
assert group.members == members
|
||||
|
||||
def test_group_str_with_all_fields(self):
|
||||
"""Test __str__ method with all fields."""
|
||||
members = [MessageMember(user_id="user1", nickname="User One")]
|
||||
group = Group(
|
||||
group_id="group123",
|
||||
group_name="Test Group",
|
||||
group_avatar="http://example.com/avatar.jpg",
|
||||
group_owner="owner123",
|
||||
group_admins=["admin1"],
|
||||
members=members,
|
||||
)
|
||||
result = str(group)
|
||||
|
||||
assert "Group ID: group123" in result
|
||||
assert "Name: Test Group" in result
|
||||
assert "Avatar: http://example.com/avatar.jpg" in result
|
||||
assert "Owner ID: owner123" in result
|
||||
assert "Admin IDs: ['admin1']" in result
|
||||
assert "Members Len: 1" in result
|
||||
|
||||
def test_group_str_with_minimal_fields(self):
|
||||
"""Test __str__ method with minimal fields."""
|
||||
group = Group(group_id="group123")
|
||||
result = str(group)
|
||||
|
||||
assert "Group ID: group123" in result
|
||||
assert "Name: N/A" in result
|
||||
assert "Avatar: N/A" in result
|
||||
assert "Owner ID: N/A" in result
|
||||
assert "Admin IDs: N/A" in result
|
||||
assert "Members Len: 0" in result
|
||||
assert "First Member: N/A" in result
|
||||
|
||||
|
||||
class TestAstrBotMessage:
|
||||
"""Tests for AstrBotMessage class."""
|
||||
|
||||
def test_astrbot_message_creation(self):
|
||||
"""Test creating an AstrBotMessage."""
|
||||
message = AstrBotMessage()
|
||||
|
||||
assert message.group is None
|
||||
assert message.timestamp is not None
|
||||
assert isinstance(message.timestamp, int)
|
||||
|
||||
def test_astrbot_message_timestamp(self):
|
||||
"""Test timestamp is set on creation."""
|
||||
with patch.object(time, "time", return_value=1234567890):
|
||||
message = AstrBotMessage()
|
||||
assert message.timestamp == 1234567890
|
||||
|
||||
def test_astrbot_message_all_attributes(self):
|
||||
"""Test setting all attributes on AstrBotMessage."""
|
||||
message = AstrBotMessage()
|
||||
message.type = MessageType.FRIEND_MESSAGE
|
||||
message.self_id = "bot123"
|
||||
message.session_id = "session123"
|
||||
message.message_id = "msg123"
|
||||
message.sender = MessageMember(user_id="user123", nickname="TestUser")
|
||||
message.message = [Plain(text="Hello")]
|
||||
message.message_str = "Hello"
|
||||
message.raw_message = {"raw": "data"}
|
||||
|
||||
assert message.type == MessageType.FRIEND_MESSAGE
|
||||
assert message.self_id == "bot123"
|
||||
assert message.session_id == "session123"
|
||||
assert message.message_id == "msg123"
|
||||
assert message.sender.user_id == "user123"
|
||||
assert len(message.message) == 1
|
||||
assert message.message_str == "Hello"
|
||||
assert message.raw_message == {"raw": "data"}
|
||||
|
||||
def test_astrbot_message_str(self):
|
||||
"""Test __str__ method."""
|
||||
message = AstrBotMessage()
|
||||
message.type = MessageType.FRIEND_MESSAGE
|
||||
message.self_id = "bot123"
|
||||
|
||||
result = str(message)
|
||||
assert "'type'" in result
|
||||
assert "'self_id'" in result
|
||||
|
||||
|
||||
class TestAstrBotMessageGroupId:
|
||||
"""Tests for AstrBotMessage group_id property."""
|
||||
|
||||
def test_group_id_returns_empty_when_no_group(self):
|
||||
"""Test group_id returns empty string when group is None."""
|
||||
message = AstrBotMessage()
|
||||
assert message.group_id == ""
|
||||
|
||||
def test_group_id_returns_group_id_when_group_exists(self):
|
||||
"""Test group_id returns the group's id when group exists."""
|
||||
message = AstrBotMessage()
|
||||
message.group = Group(group_id="group123")
|
||||
|
||||
assert message.group_id == "group123"
|
||||
|
||||
def test_group_id_setter_creates_new_group(self):
|
||||
"""Test group_id setter creates a new group if none exists."""
|
||||
message = AstrBotMessage()
|
||||
message.group_id = "new_group123"
|
||||
|
||||
assert message.group is not None
|
||||
assert message.group.group_id == "new_group123"
|
||||
|
||||
def test_group_id_setter_updates_existing_group(self):
|
||||
"""Test group_id setter updates existing group's id."""
|
||||
message = AstrBotMessage()
|
||||
message.group = Group(group_id="old_group")
|
||||
message.group_id = "new_group"
|
||||
|
||||
assert message.group.group_id == "new_group"
|
||||
|
||||
def test_group_id_setter_with_none_removes_group(self):
|
||||
"""Test group_id setter with None removes the group."""
|
||||
message = AstrBotMessage()
|
||||
message.group = Group(group_id="group123")
|
||||
message.group_id = None
|
||||
|
||||
assert message.group is None
|
||||
|
||||
def test_group_id_setter_with_empty_string_removes_group(self):
|
||||
"""Test group_id setter with empty string removes the group."""
|
||||
message = AstrBotMessage()
|
||||
message.group = Group(group_id="group123")
|
||||
message.group_id = ""
|
||||
|
||||
assert message.group is None
|
||||
|
||||
|
||||
class TestAstrBotMessageTypes:
|
||||
"""Tests for AstrBotMessage with different message types."""
|
||||
|
||||
def test_friend_message_type(self):
|
||||
"""Test AstrBotMessage with FRIEND_MESSAGE type."""
|
||||
message = AstrBotMessage()
|
||||
message.type = MessageType.FRIEND_MESSAGE
|
||||
|
||||
assert message.type == MessageType.FRIEND_MESSAGE
|
||||
assert message.type.value == "FriendMessage"
|
||||
|
||||
def test_group_message_type(self):
|
||||
"""Test AstrBotMessage with GROUP_MESSAGE type."""
|
||||
message = AstrBotMessage()
|
||||
message.type = MessageType.GROUP_MESSAGE
|
||||
|
||||
assert message.type == MessageType.GROUP_MESSAGE
|
||||
assert message.type.value == "GroupMessage"
|
||||
|
||||
def test_other_message_type(self):
|
||||
"""Test AstrBotMessage with OTHER_MESSAGE type."""
|
||||
message = AstrBotMessage()
|
||||
message.type = MessageType.OTHER_MESSAGE
|
||||
|
||||
assert message.type == MessageType.OTHER_MESSAGE
|
||||
assert message.type.value == "OtherMessage"
|
||||
|
||||
|
||||
class TestAstrBotMessageChain:
|
||||
"""Tests for AstrBotMessage message chain."""
|
||||
|
||||
def test_message_chain_with_plain_text(self):
|
||||
"""Test message chain with plain text."""
|
||||
message = AstrBotMessage()
|
||||
message.message = [Plain(text="Hello world")]
|
||||
|
||||
assert len(message.message) == 1
|
||||
assert isinstance(message.message[0], Plain)
|
||||
assert message.message[0].text == "Hello world"
|
||||
|
||||
def test_message_chain_with_multiple_components(self):
|
||||
"""Test message chain with multiple components."""
|
||||
message = AstrBotMessage()
|
||||
message.message = [
|
||||
Plain(text="Hello "),
|
||||
Plain(text="world"),
|
||||
Image(file="http://example.com/img.jpg"),
|
||||
]
|
||||
|
||||
assert len(message.message) == 3
|
||||
assert isinstance(message.message[0], Plain)
|
||||
assert isinstance(message.message[1], Plain)
|
||||
assert isinstance(message.message[2], Image)
|
||||
|
||||
def test_message_chain_empty(self):
|
||||
"""Test empty message chain."""
|
||||
message = AstrBotMessage()
|
||||
message.message = []
|
||||
|
||||
assert len(message.message) == 0
|
||||
Reference in New Issue
Block a user