Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bb1b6a3814 |
@@ -17,6 +17,8 @@ class ContextWrapper(Generic[TContext]):
|
|||||||
messages: list[Message] = Field(default_factory=list)
|
messages: list[Message] = Field(default_factory=list)
|
||||||
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
|
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
|
||||||
tool_call_timeout: int = 60 # Default tool call timeout in seconds
|
tool_call_timeout: int = 60 # Default tool call timeout in seconds
|
||||||
|
tool_call_approval: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
"""Tool call approval runtime configuration."""
|
||||||
|
|
||||||
|
|
||||||
NoContext = ContextWrapper[None]
|
NoContext = ContextWrapper[None]
|
||||||
|
|||||||
@@ -37,6 +37,10 @@ from ..hooks import BaseAgentRunHooks
|
|||||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||||
from ..response import AgentResponseData, AgentStats
|
from ..response import AgentResponseData, AgentStats
|
||||||
from ..run_context import ContextWrapper, TContext
|
from ..run_context import ContextWrapper, TContext
|
||||||
|
from ..tool_call_approval import (
|
||||||
|
ToolCallApprovalContext,
|
||||||
|
request_tool_call_approval,
|
||||||
|
)
|
||||||
from ..tool_executor import BaseFunctionToolExecutor
|
from ..tool_executor import BaseFunctionToolExecutor
|
||||||
from .base import AgentResponse, AgentState, BaseAgentRunner
|
from .base import AgentResponse, AgentState, BaseAgentRunner
|
||||||
|
|
||||||
@@ -659,6 +663,41 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
# 如果没有 handler(如 MCP 工具),使用所有参数
|
# 如果没有 handler(如 MCP 工具),使用所有参数
|
||||||
valid_params = func_tool_args
|
valid_params = func_tool_args
|
||||||
|
|
||||||
|
approval_cfg = self.run_context.tool_call_approval
|
||||||
|
if approval_cfg.get("enable", False):
|
||||||
|
event = getattr(self.run_context.context, "event", None)
|
||||||
|
if event is None:
|
||||||
|
tool_call_result_blocks.append(
|
||||||
|
ToolCallMessageSegment(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
content=(
|
||||||
|
f"error: tool call approval is enabled, but event context is unavailable for `{func_tool_name}`."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
approval_result = await request_tool_call_approval(
|
||||||
|
config=approval_cfg,
|
||||||
|
ctx=ToolCallApprovalContext(
|
||||||
|
event=event,
|
||||||
|
tool_name=func_tool_name,
|
||||||
|
tool_args=valid_params,
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if not approval_result.approved:
|
||||||
|
tool_call_result_blocks.append(
|
||||||
|
ToolCallMessageSegment(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
content=approval_result.to_tool_result_text(
|
||||||
|
func_tool_name
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.agent_hooks.on_tool_start(
|
await self.agent_hooks.on_tool_start(
|
||||||
self.run_context,
|
self.run_context,
|
||||||
|
|||||||
@@ -0,0 +1,248 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
import string
|
||||||
|
import typing as T
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from astrbot import logger
|
||||||
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
|
from astrbot.core.utils.session_waiter import (
|
||||||
|
FILTERS,
|
||||||
|
DefaultSessionFilter,
|
||||||
|
SessionController,
|
||||||
|
SessionWaiter,
|
||||||
|
)
|
||||||
|
|
||||||
|
ApprovalReason = T.Literal[
|
||||||
|
"approved",
|
||||||
|
"rejected",
|
||||||
|
"timeout",
|
||||||
|
"unsupported_strategy",
|
||||||
|
"error",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ToolCallApprovalContext:
|
||||||
|
event: AstrMessageEvent
|
||||||
|
tool_name: str
|
||||||
|
tool_args: dict[str, T.Any]
|
||||||
|
tool_call_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ToolCallApprovalResult:
|
||||||
|
approved: bool
|
||||||
|
reason: ApprovalReason
|
||||||
|
detail: str = ""
|
||||||
|
|
||||||
|
def to_tool_result_text(self, tool_name: str) -> str:
|
||||||
|
if self.approved:
|
||||||
|
return f"tool call approval passed: {tool_name}"
|
||||||
|
if self.reason == "timeout":
|
||||||
|
return (
|
||||||
|
f"error: tool call approval timed out for `{tool_name}`. "
|
||||||
|
"The tool call was cancelled."
|
||||||
|
)
|
||||||
|
if self.reason == "unsupported_strategy":
|
||||||
|
return (
|
||||||
|
f"error: tool call approval strategy is unsupported for `{tool_name}`. "
|
||||||
|
"The tool call was cancelled."
|
||||||
|
)
|
||||||
|
if self.reason == "error":
|
||||||
|
return (
|
||||||
|
f"error: tool call approval failed for `{tool_name}` ({self.detail}). "
|
||||||
|
"The tool call was cancelled."
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
f"error: user rejected tool call approval for `{tool_name}`. "
|
||||||
|
"The tool call was cancelled."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseToolCallApprovalStrategy(ABC):
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def name(self) -> str: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def request(
|
||||||
|
self,
|
||||||
|
ctx: ToolCallApprovalContext,
|
||||||
|
config: dict[str, T.Any],
|
||||||
|
) -> ToolCallApprovalResult: ...
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicCodeApprovalStrategy(BaseToolCallApprovalStrategy):
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "dynamic_code"
|
||||||
|
|
||||||
|
async def request(
|
||||||
|
self,
|
||||||
|
ctx: ToolCallApprovalContext,
|
||||||
|
config: dict[str, T.Any],
|
||||||
|
) -> ToolCallApprovalResult:
|
||||||
|
timeout_seconds = _safe_int(config.get("timeout", 60), default=60, minimum=1)
|
||||||
|
dynamic_cfg = config.get("dynamic_code", {})
|
||||||
|
if not isinstance(dynamic_cfg, dict):
|
||||||
|
dynamic_cfg = {}
|
||||||
|
code_length = _safe_int(dynamic_cfg.get("code_length", 6), default=6, minimum=4)
|
||||||
|
case_sensitive = bool(dynamic_cfg.get("case_sensitive", False))
|
||||||
|
|
||||||
|
code = "".join(secrets.choice(string.digits) for _ in range(code_length))
|
||||||
|
|
||||||
|
await ctx.event.send(
|
||||||
|
MessageChain().message(
|
||||||
|
"Tool call needs your approval before execution.\n"
|
||||||
|
f"Tool: `{ctx.tool_name}`\n"
|
||||||
|
f"Approval code: `{code}`\n"
|
||||||
|
"Please send this code to continue. "
|
||||||
|
"Any other message will cancel this tool call."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await _wait_for_code_input(
|
||||||
|
event=ctx.event,
|
||||||
|
expected_code=code,
|
||||||
|
timeout=timeout_seconds,
|
||||||
|
case_sensitive=case_sensitive,
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.error(
|
||||||
|
"Tool call approval failed unexpectedly for %s: %s",
|
||||||
|
ctx.tool_name,
|
||||||
|
exc,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return ToolCallApprovalResult(
|
||||||
|
approved=False,
|
||||||
|
reason="error",
|
||||||
|
detail=str(exc),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result.approved:
|
||||||
|
if result.reason == "timeout":
|
||||||
|
await ctx.event.send(
|
||||||
|
MessageChain().message(
|
||||||
|
f"Tool call `{ctx.tool_name}` approval timed out. This call was cancelled."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await ctx.event.send(
|
||||||
|
MessageChain().message(
|
||||||
|
f"Tool call `{ctx.tool_name}` was cancelled."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
_STRATEGY_REGISTRY: dict[str, BaseToolCallApprovalStrategy] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_tool_call_approval_strategy(
|
||||||
|
strategy: BaseToolCallApprovalStrategy,
|
||||||
|
) -> None:
|
||||||
|
_STRATEGY_REGISTRY[strategy.name] = strategy
|
||||||
|
|
||||||
|
|
||||||
|
def _register_builtin_strategies() -> None:
|
||||||
|
register_tool_call_approval_strategy(DynamicCodeApprovalStrategy())
|
||||||
|
|
||||||
|
|
||||||
|
_register_builtin_strategies()
|
||||||
|
|
||||||
|
|
||||||
|
async def request_tool_call_approval(
|
||||||
|
*,
|
||||||
|
config: dict[str, T.Any] | None,
|
||||||
|
ctx: ToolCallApprovalContext,
|
||||||
|
) -> ToolCallApprovalResult:
|
||||||
|
if not config or not bool(config.get("enable", False)):
|
||||||
|
return ToolCallApprovalResult(approved=True, reason="approved")
|
||||||
|
|
||||||
|
strategy_name = (
|
||||||
|
str(config.get("strategy", "dynamic_code")).strip() or "dynamic_code"
|
||||||
|
)
|
||||||
|
strategy = _STRATEGY_REGISTRY.get(strategy_name)
|
||||||
|
if not strategy:
|
||||||
|
logger.warning("Unsupported tool call approval strategy: %s", strategy_name)
|
||||||
|
return ToolCallApprovalResult(
|
||||||
|
approved=False,
|
||||||
|
reason="unsupported_strategy",
|
||||||
|
detail=strategy_name,
|
||||||
|
)
|
||||||
|
return await strategy.request(ctx, config)
|
||||||
|
|
||||||
|
|
||||||
|
async def _wait_for_code_input(
|
||||||
|
*,
|
||||||
|
event: AstrMessageEvent,
|
||||||
|
expected_code: str,
|
||||||
|
timeout: int,
|
||||||
|
case_sensitive: bool,
|
||||||
|
) -> ToolCallApprovalResult:
|
||||||
|
session_filter = DefaultSessionFilter()
|
||||||
|
FILTERS.append(session_filter)
|
||||||
|
waiter = SessionWaiter(
|
||||||
|
session_filter=session_filter,
|
||||||
|
session_id=event.unified_msg_origin,
|
||||||
|
record_history_chains=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handler(
|
||||||
|
controller: SessionController, incoming: AstrMessageEvent
|
||||||
|
) -> None:
|
||||||
|
raw_input = (incoming.message_str or "").strip()
|
||||||
|
if _is_code_match(
|
||||||
|
expected=expected_code,
|
||||||
|
actual=raw_input,
|
||||||
|
case_sensitive=case_sensitive,
|
||||||
|
):
|
||||||
|
if not controller.future.done():
|
||||||
|
controller.future.set_result(
|
||||||
|
ToolCallApprovalResult(approved=True, reason="approved"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if not controller.future.done():
|
||||||
|
controller.future.set_result(
|
||||||
|
ToolCallApprovalResult(
|
||||||
|
approved=False,
|
||||||
|
reason="rejected",
|
||||||
|
detail=raw_input,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
controller.stop()
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await waiter.register_wait(handler=_handler, timeout=timeout)
|
||||||
|
except TimeoutError:
|
||||||
|
return ToolCallApprovalResult(approved=False, reason="timeout")
|
||||||
|
|
||||||
|
if isinstance(result, ToolCallApprovalResult):
|
||||||
|
return result
|
||||||
|
return ToolCallApprovalResult(
|
||||||
|
approved=False,
|
||||||
|
reason="error",
|
||||||
|
detail=f"Invalid approval result type: {type(result).__name__}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_code_match(*, expected: str, actual: str, case_sensitive: bool) -> bool:
|
||||||
|
if case_sensitive:
|
||||||
|
return actual == expected
|
||||||
|
return actual.casefold() == expected.casefold()
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_int(value: T.Any, *, default: int, minimum: int) -> int:
|
||||||
|
try:
|
||||||
|
parsed = int(value)
|
||||||
|
if parsed < minimum:
|
||||||
|
return minimum
|
||||||
|
return parsed
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
return default
|
||||||
@@ -121,6 +121,8 @@ class MainAgentBuildConfig:
|
|||||||
timezone: str | None = None
|
timezone: str | None = None
|
||||||
max_quoted_fallback_images: int = 20
|
max_quoted_fallback_images: int = 20
|
||||||
"""Maximum number of images injected from quoted-message fallback extraction."""
|
"""Maximum number of images injected from quoted-message fallback extraction."""
|
||||||
|
tool_call_approval: dict = field(default_factory=dict)
|
||||||
|
"""Tool call approval configuration."""
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
@@ -1118,6 +1120,7 @@ async def build_main_agent(
|
|||||||
run_context=AgentContextWrapper(
|
run_context=AgentContextWrapper(
|
||||||
context=astr_agent_ctx,
|
context=astr_agent_ctx,
|
||||||
tool_call_timeout=config.tool_call_timeout,
|
tool_call_timeout=config.tool_call_timeout,
|
||||||
|
tool_call_approval=config.tool_call_approval,
|
||||||
),
|
),
|
||||||
tool_executor=FunctionToolExecutor(),
|
tool_executor=FunctionToolExecutor(),
|
||||||
agent_hooks=MAIN_AGENT_HOOKS,
|
agent_hooks=MAIN_AGENT_HOOKS,
|
||||||
|
|||||||
@@ -117,6 +117,15 @@ DEFAULT_CONFIG = {
|
|||||||
"max_agent_step": 30,
|
"max_agent_step": 30,
|
||||||
"tool_call_timeout": 60,
|
"tool_call_timeout": 60,
|
||||||
"tool_schema_mode": "full",
|
"tool_schema_mode": "full",
|
||||||
|
"tool_call_approval": {
|
||||||
|
"enable": False,
|
||||||
|
"strategy": "dynamic_code",
|
||||||
|
"timeout": 60,
|
||||||
|
"dynamic_code": {
|
||||||
|
"code_length": 6,
|
||||||
|
"case_sensitive": False,
|
||||||
|
},
|
||||||
|
},
|
||||||
"llm_safety_mode": True,
|
"llm_safety_mode": True,
|
||||||
"safety_mode_strategy": "system_prompt", # TODO: llm judge
|
"safety_mode_strategy": "system_prompt", # TODO: llm judge
|
||||||
"file_extract": {
|
"file_extract": {
|
||||||
@@ -2330,6 +2339,31 @@ CONFIG_METADATA_2 = {
|
|||||||
"tool_schema_mode": {
|
"tool_schema_mode": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
},
|
},
|
||||||
|
"tool_call_approval": {
|
||||||
|
"type": "object",
|
||||||
|
"items": {
|
||||||
|
"enable": {
|
||||||
|
"type": "bool",
|
||||||
|
},
|
||||||
|
"strategy": {
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"timeout": {
|
||||||
|
"type": "int",
|
||||||
|
},
|
||||||
|
"dynamic_code": {
|
||||||
|
"type": "object",
|
||||||
|
"items": {
|
||||||
|
"code_length": {
|
||||||
|
"type": "int",
|
||||||
|
},
|
||||||
|
"case_sensitive": {
|
||||||
|
"type": "bool",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
"file_extract": {
|
"file_extract": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"items": {
|
"items": {
|
||||||
@@ -3066,6 +3100,50 @@ CONFIG_METADATA_3 = {
|
|||||||
"provider_settings.agent_runner_type": "local",
|
"provider_settings.agent_runner_type": "local",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"provider_settings.tool_call_approval.enable": {
|
||||||
|
"description": "启用工具调用确认",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "开启后,工具调用需要用户确认后才会执行。",
|
||||||
|
"condition": {
|
||||||
|
"provider_settings.agent_runner_type": "local",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"provider_settings.tool_call_approval.strategy": {
|
||||||
|
"description": "工具调用确认策略",
|
||||||
|
"type": "string",
|
||||||
|
"options": ["dynamic_code"],
|
||||||
|
"labels": ["Dynamic Code(动态码)"],
|
||||||
|
"condition": {
|
||||||
|
"provider_settings.agent_runner_type": "local",
|
||||||
|
"provider_settings.tool_call_approval.enable": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"provider_settings.tool_call_approval.timeout": {
|
||||||
|
"description": "工具调用确认超时(秒)",
|
||||||
|
"type": "int",
|
||||||
|
"condition": {
|
||||||
|
"provider_settings.agent_runner_type": "local",
|
||||||
|
"provider_settings.tool_call_approval.enable": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"provider_settings.tool_call_approval.dynamic_code.code_length": {
|
||||||
|
"description": "动态确认码长度",
|
||||||
|
"type": "int",
|
||||||
|
"condition": {
|
||||||
|
"provider_settings.agent_runner_type": "local",
|
||||||
|
"provider_settings.tool_call_approval.enable": True,
|
||||||
|
"provider_settings.tool_call_approval.strategy": "dynamic_code",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"provider_settings.tool_call_approval.dynamic_code.case_sensitive": {
|
||||||
|
"description": "动态确认码区分大小写",
|
||||||
|
"type": "bool",
|
||||||
|
"condition": {
|
||||||
|
"provider_settings.agent_runner_type": "local",
|
||||||
|
"provider_settings.tool_call_approval.enable": True,
|
||||||
|
"provider_settings.tool_call_approval.strategy": "dynamic_code",
|
||||||
|
},
|
||||||
|
},
|
||||||
"provider_settings.wake_prefix": {
|
"provider_settings.wake_prefix": {
|
||||||
"description": "LLM 聊天额外唤醒前缀 ",
|
"description": "LLM 聊天额外唤醒前缀 ",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ class InternalAgentSubStage(Stage):
|
|||||||
]
|
]
|
||||||
self.max_step: int = settings.get("max_agent_step", 30)
|
self.max_step: int = settings.get("max_agent_step", 30)
|
||||||
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
|
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
|
||||||
|
self.tool_call_approval: dict = settings.get("tool_call_approval", {})
|
||||||
self.tool_schema_mode: str = settings.get("tool_schema_mode", "full")
|
self.tool_schema_mode: str = settings.get("tool_schema_mode", "full")
|
||||||
if self.tool_schema_mode not in ("skills_like", "full"):
|
if self.tool_schema_mode not in ("skills_like", "full"):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -124,6 +125,7 @@ class InternalAgentSubStage(Stage):
|
|||||||
subagent_orchestrator=conf.get("subagent_orchestrator", {}),
|
subagent_orchestrator=conf.get("subagent_orchestrator", {}),
|
||||||
timezone=self.ctx.plugin_manager.context.get_config().get("timezone"),
|
timezone=self.ctx.plugin_manager.context.get_config().get("timezone"),
|
||||||
max_quoted_fallback_images=settings.get("max_quoted_fallback_images", 20),
|
max_quoted_fallback_images=settings.get("max_quoted_fallback_images", 20),
|
||||||
|
tool_call_approval=self.tool_call_approval,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
|
|||||||
@@ -0,0 +1,110 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
|
from astrbot.core.agent.tool_call_approval import (
|
||||||
|
ToolCallApprovalContext,
|
||||||
|
ToolCallApprovalResult,
|
||||||
|
request_tool_call_approval,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyEvent:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.unified_msg_origin = "test:friend:test_user"
|
||||||
|
self.sent_messages = []
|
||||||
|
self.message_str = ""
|
||||||
|
|
||||||
|
async def send(self, message) -> None:
|
||||||
|
self.sent_messages.append(message)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_request_tool_call_approval_disabled_returns_approved():
|
||||||
|
event = DummyEvent()
|
||||||
|
result = await request_tool_call_approval(
|
||||||
|
config={"enable": False},
|
||||||
|
ctx=ToolCallApprovalContext(
|
||||||
|
event=event,
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_args={},
|
||||||
|
tool_call_id="call_1",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert result.approved is True
|
||||||
|
assert result.reason == "approved"
|
||||||
|
assert len(event.sent_messages) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dynamic_code_approval_passed(monkeypatch):
|
||||||
|
async def _mock_wait(*args, **kwargs):
|
||||||
|
return ToolCallApprovalResult(approved=True, reason="approved")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.core.agent.tool_call_approval._wait_for_code_input",
|
||||||
|
_mock_wait,
|
||||||
|
)
|
||||||
|
|
||||||
|
event = DummyEvent()
|
||||||
|
result = await request_tool_call_approval(
|
||||||
|
config={"enable": True, "strategy": "dynamic_code"},
|
||||||
|
ctx=ToolCallApprovalContext(
|
||||||
|
event=event,
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_args={"query": "hello"},
|
||||||
|
tool_call_id="call_2",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert result.approved is True
|
||||||
|
assert result.reason == "approved"
|
||||||
|
assert len(event.sent_messages) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dynamic_code_approval_rejected(monkeypatch):
|
||||||
|
async def _mock_wait(*args, **kwargs):
|
||||||
|
return ToolCallApprovalResult(
|
||||||
|
approved=False,
|
||||||
|
reason="rejected",
|
||||||
|
detail="not_code",
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"astrbot.core.agent.tool_call_approval._wait_for_code_input",
|
||||||
|
_mock_wait,
|
||||||
|
)
|
||||||
|
|
||||||
|
event = DummyEvent()
|
||||||
|
result = await request_tool_call_approval(
|
||||||
|
config={"enable": True, "strategy": "dynamic_code"},
|
||||||
|
ctx=ToolCallApprovalContext(
|
||||||
|
event=event,
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_args={},
|
||||||
|
tool_call_id="call_3",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert result.approved is False
|
||||||
|
assert result.reason == "rejected"
|
||||||
|
assert len(event.sent_messages) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_request_tool_call_approval_unsupported_strategy():
|
||||||
|
event = DummyEvent()
|
||||||
|
result = await request_tool_call_approval(
|
||||||
|
config={"enable": True, "strategy": "unknown_strategy"},
|
||||||
|
ctx=ToolCallApprovalContext(
|
||||||
|
event=event,
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_args={},
|
||||||
|
tool_call_id="call_4",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert result.approved is False
|
||||||
|
assert result.reason == "unsupported_strategy"
|
||||||
|
assert len(event.sent_messages) == 0
|
||||||
Reference in New Issue
Block a user