feat: implement tool call approval mechanism with dynamic code strategy

This commit is contained in:
Soulter
2026-02-24 09:41:12 +08:00
parent 28bfb3b8b2
commit bb1b6a3814
7 changed files with 482 additions and 0 deletions
+2
View File
@@ -17,6 +17,8 @@ class ContextWrapper(Generic[TContext]):
messages: list[Message] = Field(default_factory=list)
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
tool_call_timeout: int = 60 # Default tool call timeout in seconds
tool_call_approval: dict[str, Any] = Field(default_factory=dict)
"""Tool call approval runtime configuration."""
NoContext = ContextWrapper[None]
@@ -37,6 +37,10 @@ from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..response import AgentResponseData, AgentStats
from ..run_context import ContextWrapper, TContext
from ..tool_call_approval import (
ToolCallApprovalContext,
request_tool_call_approval,
)
from ..tool_executor import BaseFunctionToolExecutor
from .base import AgentResponse, AgentState, BaseAgentRunner
@@ -659,6 +663,41 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 如果没有 handler(如 MCP 工具),使用所有参数
valid_params = func_tool_args
approval_cfg = self.run_context.tool_call_approval
if approval_cfg.get("enable", False):
event = getattr(self.run_context.context, "event", None)
if event is None:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=(
f"error: tool call approval is enabled, but event context is unavailable for `{func_tool_name}`."
),
),
)
continue
approval_result = await request_tool_call_approval(
config=approval_cfg,
ctx=ToolCallApprovalContext(
event=event,
tool_name=func_tool_name,
tool_args=valid_params,
tool_call_id=func_tool_id,
),
)
if not approval_result.approved:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=approval_result.to_tool_result_text(
func_tool_name
),
),
)
continue
try:
await self.agent_hooks.on_tool_start(
self.run_context,
+248
View File
@@ -0,0 +1,248 @@
from __future__ import annotations
import secrets
import string
import typing as T
from abc import ABC, abstractmethod
from dataclasses import dataclass
from astrbot import logger
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.utils.session_waiter import (
FILTERS,
DefaultSessionFilter,
SessionController,
SessionWaiter,
)
ApprovalReason = T.Literal[
"approved",
"rejected",
"timeout",
"unsupported_strategy",
"error",
]
@dataclass(slots=True)
class ToolCallApprovalContext:
event: AstrMessageEvent
tool_name: str
tool_args: dict[str, T.Any]
tool_call_id: str
@dataclass(slots=True)
class ToolCallApprovalResult:
approved: bool
reason: ApprovalReason
detail: str = ""
def to_tool_result_text(self, tool_name: str) -> str:
if self.approved:
return f"tool call approval passed: {tool_name}"
if self.reason == "timeout":
return (
f"error: tool call approval timed out for `{tool_name}`. "
"The tool call was cancelled."
)
if self.reason == "unsupported_strategy":
return (
f"error: tool call approval strategy is unsupported for `{tool_name}`. "
"The tool call was cancelled."
)
if self.reason == "error":
return (
f"error: tool call approval failed for `{tool_name}` ({self.detail}). "
"The tool call was cancelled."
)
return (
f"error: user rejected tool call approval for `{tool_name}`. "
"The tool call was cancelled."
)
class BaseToolCallApprovalStrategy(ABC):
@property
@abstractmethod
def name(self) -> str: ...
@abstractmethod
async def request(
self,
ctx: ToolCallApprovalContext,
config: dict[str, T.Any],
) -> ToolCallApprovalResult: ...
class DynamicCodeApprovalStrategy(BaseToolCallApprovalStrategy):
@property
def name(self) -> str:
return "dynamic_code"
async def request(
self,
ctx: ToolCallApprovalContext,
config: dict[str, T.Any],
) -> ToolCallApprovalResult:
timeout_seconds = _safe_int(config.get("timeout", 60), default=60, minimum=1)
dynamic_cfg = config.get("dynamic_code", {})
if not isinstance(dynamic_cfg, dict):
dynamic_cfg = {}
code_length = _safe_int(dynamic_cfg.get("code_length", 6), default=6, minimum=4)
case_sensitive = bool(dynamic_cfg.get("case_sensitive", False))
code = "".join(secrets.choice(string.digits) for _ in range(code_length))
await ctx.event.send(
MessageChain().message(
"Tool call needs your approval before execution.\n"
f"Tool: `{ctx.tool_name}`\n"
f"Approval code: `{code}`\n"
"Please send this code to continue. "
"Any other message will cancel this tool call."
)
)
try:
result = await _wait_for_code_input(
event=ctx.event,
expected_code=code,
timeout=timeout_seconds,
case_sensitive=case_sensitive,
)
except Exception as exc: # noqa: BLE001
logger.error(
"Tool call approval failed unexpectedly for %s: %s",
ctx.tool_name,
exc,
exc_info=True,
)
return ToolCallApprovalResult(
approved=False,
reason="error",
detail=str(exc),
)
if not result.approved:
if result.reason == "timeout":
await ctx.event.send(
MessageChain().message(
f"Tool call `{ctx.tool_name}` approval timed out. This call was cancelled."
)
)
else:
await ctx.event.send(
MessageChain().message(
f"Tool call `{ctx.tool_name}` was cancelled."
)
)
return result
_STRATEGY_REGISTRY: dict[str, BaseToolCallApprovalStrategy] = {}
def register_tool_call_approval_strategy(
strategy: BaseToolCallApprovalStrategy,
) -> None:
_STRATEGY_REGISTRY[strategy.name] = strategy
def _register_builtin_strategies() -> None:
register_tool_call_approval_strategy(DynamicCodeApprovalStrategy())
_register_builtin_strategies()
async def request_tool_call_approval(
*,
config: dict[str, T.Any] | None,
ctx: ToolCallApprovalContext,
) -> ToolCallApprovalResult:
if not config or not bool(config.get("enable", False)):
return ToolCallApprovalResult(approved=True, reason="approved")
strategy_name = (
str(config.get("strategy", "dynamic_code")).strip() or "dynamic_code"
)
strategy = _STRATEGY_REGISTRY.get(strategy_name)
if not strategy:
logger.warning("Unsupported tool call approval strategy: %s", strategy_name)
return ToolCallApprovalResult(
approved=False,
reason="unsupported_strategy",
detail=strategy_name,
)
return await strategy.request(ctx, config)
async def _wait_for_code_input(
*,
event: AstrMessageEvent,
expected_code: str,
timeout: int,
case_sensitive: bool,
) -> ToolCallApprovalResult:
session_filter = DefaultSessionFilter()
FILTERS.append(session_filter)
waiter = SessionWaiter(
session_filter=session_filter,
session_id=event.unified_msg_origin,
record_history_chains=False,
)
async def _handler(
controller: SessionController, incoming: AstrMessageEvent
) -> None:
raw_input = (incoming.message_str or "").strip()
if _is_code_match(
expected=expected_code,
actual=raw_input,
case_sensitive=case_sensitive,
):
if not controller.future.done():
controller.future.set_result(
ToolCallApprovalResult(approved=True, reason="approved"),
)
else:
if not controller.future.done():
controller.future.set_result(
ToolCallApprovalResult(
approved=False,
reason="rejected",
detail=raw_input,
)
)
controller.stop()
try:
result = await waiter.register_wait(handler=_handler, timeout=timeout)
except TimeoutError:
return ToolCallApprovalResult(approved=False, reason="timeout")
if isinstance(result, ToolCallApprovalResult):
return result
return ToolCallApprovalResult(
approved=False,
reason="error",
detail=f"Invalid approval result type: {type(result).__name__}",
)
def _is_code_match(*, expected: str, actual: str, case_sensitive: bool) -> bool:
if case_sensitive:
return actual == expected
return actual.casefold() == expected.casefold()
def _safe_int(value: T.Any, *, default: int, minimum: int) -> int:
try:
parsed = int(value)
if parsed < minimum:
return minimum
return parsed
except Exception: # noqa: BLE001
return default
+3
View File
@@ -121,6 +121,8 @@ class MainAgentBuildConfig:
timezone: str | None = None
max_quoted_fallback_images: int = 20
"""Maximum number of images injected from quoted-message fallback extraction."""
tool_call_approval: dict = field(default_factory=dict)
"""Tool call approval configuration."""
@dataclass(slots=True)
@@ -1118,6 +1120,7 @@ async def build_main_agent(
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=config.tool_call_timeout,
tool_call_approval=config.tool_call_approval,
),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
+78
View File
@@ -117,6 +117,15 @@ DEFAULT_CONFIG = {
"max_agent_step": 30,
"tool_call_timeout": 60,
"tool_schema_mode": "full",
"tool_call_approval": {
"enable": False,
"strategy": "dynamic_code",
"timeout": 60,
"dynamic_code": {
"code_length": 6,
"case_sensitive": False,
},
},
"llm_safety_mode": True,
"safety_mode_strategy": "system_prompt", # TODO: llm judge
"file_extract": {
@@ -2330,6 +2339,31 @@ CONFIG_METADATA_2 = {
"tool_schema_mode": {
"type": "string",
},
"tool_call_approval": {
"type": "object",
"items": {
"enable": {
"type": "bool",
},
"strategy": {
"type": "string",
},
"timeout": {
"type": "int",
},
"dynamic_code": {
"type": "object",
"items": {
"code_length": {
"type": "int",
},
"case_sensitive": {
"type": "bool",
},
},
},
},
},
"file_extract": {
"type": "object",
"items": {
@@ -3066,6 +3100,50 @@ CONFIG_METADATA_3 = {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.tool_call_approval.enable": {
"description": "启用工具调用确认",
"type": "bool",
"hint": "开启后,工具调用需要用户确认后才会执行。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.tool_call_approval.strategy": {
"description": "工具调用确认策略",
"type": "string",
"options": ["dynamic_code"],
"labels": ["Dynamic Code(动态码)"],
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.tool_call_approval.enable": True,
},
},
"provider_settings.tool_call_approval.timeout": {
"description": "工具调用确认超时(秒)",
"type": "int",
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.tool_call_approval.enable": True,
},
},
"provider_settings.tool_call_approval.dynamic_code.code_length": {
"description": "动态确认码长度",
"type": "int",
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.tool_call_approval.enable": True,
"provider_settings.tool_call_approval.strategy": "dynamic_code",
},
},
"provider_settings.tool_call_approval.dynamic_code.case_sensitive": {
"description": "动态确认码区分大小写",
"type": "bool",
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.tool_call_approval.enable": True,
"provider_settings.tool_call_approval.strategy": "dynamic_code",
},
},
"provider_settings.wake_prefix": {
"description": "LLM 聊天额外唤醒前缀 ",
"type": "string",
@@ -44,6 +44,7 @@ class InternalAgentSubStage(Stage):
]
self.max_step: int = settings.get("max_agent_step", 30)
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
self.tool_call_approval: dict = settings.get("tool_call_approval", {})
self.tool_schema_mode: str = settings.get("tool_schema_mode", "full")
if self.tool_schema_mode not in ("skills_like", "full"):
logger.warning(
@@ -124,6 +125,7 @@ class InternalAgentSubStage(Stage):
subagent_orchestrator=conf.get("subagent_orchestrator", {}),
timezone=self.ctx.plugin_manager.context.get_config().get("timezone"),
max_quoted_fallback_images=settings.get("max_quoted_fallback_images", 20),
tool_call_approval=self.tool_call_approval,
)
async def process(
+110
View File
@@ -0,0 +1,110 @@
import os
import sys
import pytest
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from astrbot.core.agent.tool_call_approval import (
ToolCallApprovalContext,
ToolCallApprovalResult,
request_tool_call_approval,
)
class DummyEvent:
def __init__(self) -> None:
self.unified_msg_origin = "test:friend:test_user"
self.sent_messages = []
self.message_str = ""
async def send(self, message) -> None:
self.sent_messages.append(message)
@pytest.mark.asyncio
async def test_request_tool_call_approval_disabled_returns_approved():
event = DummyEvent()
result = await request_tool_call_approval(
config={"enable": False},
ctx=ToolCallApprovalContext(
event=event,
tool_name="test_tool",
tool_args={},
tool_call_id="call_1",
),
)
assert result.approved is True
assert result.reason == "approved"
assert len(event.sent_messages) == 0
@pytest.mark.asyncio
async def test_dynamic_code_approval_passed(monkeypatch):
async def _mock_wait(*args, **kwargs):
return ToolCallApprovalResult(approved=True, reason="approved")
monkeypatch.setattr(
"astrbot.core.agent.tool_call_approval._wait_for_code_input",
_mock_wait,
)
event = DummyEvent()
result = await request_tool_call_approval(
config={"enable": True, "strategy": "dynamic_code"},
ctx=ToolCallApprovalContext(
event=event,
tool_name="test_tool",
tool_args={"query": "hello"},
tool_call_id="call_2",
),
)
assert result.approved is True
assert result.reason == "approved"
assert len(event.sent_messages) == 1
@pytest.mark.asyncio
async def test_dynamic_code_approval_rejected(monkeypatch):
async def _mock_wait(*args, **kwargs):
return ToolCallApprovalResult(
approved=False,
reason="rejected",
detail="not_code",
)
monkeypatch.setattr(
"astrbot.core.agent.tool_call_approval._wait_for_code_input",
_mock_wait,
)
event = DummyEvent()
result = await request_tool_call_approval(
config={"enable": True, "strategy": "dynamic_code"},
ctx=ToolCallApprovalContext(
event=event,
tool_name="test_tool",
tool_args={},
tool_call_id="call_3",
),
)
assert result.approved is False
assert result.reason == "rejected"
assert len(event.sent_messages) == 2
@pytest.mark.asyncio
async def test_request_tool_call_approval_unsupported_strategy():
event = DummyEvent()
result = await request_tool_call_approval(
config={"enable": True, "strategy": "unknown_strategy"},
ctx=ToolCallApprovalContext(
event=event,
tool_name="test_tool",
tool_args={},
tool_call_id="call_4",
),
)
assert result.approved is False
assert result.reason == "unsupported_strategy"
assert len(event.sent_messages) == 0