Compare commits

...

1 Commits

Author SHA1 Message Date
Soulter bb1b6a3814 feat: implement tool call approval mechanism with dynamic code strategy 2026-02-24 09:41:12 +08:00
7 changed files with 482 additions and 0 deletions
+2
View File
@@ -17,6 +17,8 @@ class ContextWrapper(Generic[TContext]):
messages: list[Message] = Field(default_factory=list) messages: list[Message] = Field(default_factory=list)
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically.""" """This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
tool_call_timeout: int = 60 # Default tool call timeout in seconds tool_call_timeout: int = 60 # Default tool call timeout in seconds
tool_call_approval: dict[str, Any] = Field(default_factory=dict)
"""Tool call approval runtime configuration."""
NoContext = ContextWrapper[None] NoContext = ContextWrapper[None]
@@ -37,6 +37,10 @@ from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..response import AgentResponseData, AgentStats from ..response import AgentResponseData, AgentStats
from ..run_context import ContextWrapper, TContext from ..run_context import ContextWrapper, TContext
from ..tool_call_approval import (
ToolCallApprovalContext,
request_tool_call_approval,
)
from ..tool_executor import BaseFunctionToolExecutor from ..tool_executor import BaseFunctionToolExecutor
from .base import AgentResponse, AgentState, BaseAgentRunner from .base import AgentResponse, AgentState, BaseAgentRunner
@@ -659,6 +663,41 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 如果没有 handler(如 MCP 工具),使用所有参数 # 如果没有 handler(如 MCP 工具),使用所有参数
valid_params = func_tool_args valid_params = func_tool_args
approval_cfg = self.run_context.tool_call_approval
if approval_cfg.get("enable", False):
event = getattr(self.run_context.context, "event", None)
if event is None:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=(
f"error: tool call approval is enabled, but event context is unavailable for `{func_tool_name}`."
),
),
)
continue
approval_result = await request_tool_call_approval(
config=approval_cfg,
ctx=ToolCallApprovalContext(
event=event,
tool_name=func_tool_name,
tool_args=valid_params,
tool_call_id=func_tool_id,
),
)
if not approval_result.approved:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=approval_result.to_tool_result_text(
func_tool_name
),
),
)
continue
try: try:
await self.agent_hooks.on_tool_start( await self.agent_hooks.on_tool_start(
self.run_context, self.run_context,
+248
View File
@@ -0,0 +1,248 @@
from __future__ import annotations
import secrets
import string
import typing as T
from abc import ABC, abstractmethod
from dataclasses import dataclass
from astrbot import logger
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.utils.session_waiter import (
FILTERS,
DefaultSessionFilter,
SessionController,
SessionWaiter,
)
# Closed set of outcome codes stored in ``ToolCallApprovalResult.reason``.
# "approved" is the only code used together with ``approved=True`` in this
# module; every other code describes why a tool call was not approved.
ApprovalReason = T.Literal[
"approved",
"rejected",
"timeout",
"unsupported_strategy",
"error",
]
@dataclass(slots=True)
class ToolCallApprovalContext:
event: AstrMessageEvent
tool_name: str
tool_args: dict[str, T.Any]
tool_call_id: str
@dataclass(slots=True)
class ToolCallApprovalResult:
approved: bool
reason: ApprovalReason
detail: str = ""
def to_tool_result_text(self, tool_name: str) -> str:
if self.approved:
return f"tool call approval passed: {tool_name}"
if self.reason == "timeout":
return (
f"error: tool call approval timed out for `{tool_name}`. "
"The tool call was cancelled."
)
if self.reason == "unsupported_strategy":
return (
f"error: tool call approval strategy is unsupported for `{tool_name}`. "
"The tool call was cancelled."
)
if self.reason == "error":
return (
f"error: tool call approval failed for `{tool_name}` ({self.detail}). "
"The tool call was cancelled."
)
return (
f"error: user rejected tool call approval for `{tool_name}`. "
"The tool call was cancelled."
)
class BaseToolCallApprovalStrategy(ABC):
    """Abstract interface every tool call approval strategy implements."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Key under which the strategy is stored in the strategy registry."""

    @abstractmethod
    async def request(
        self,
        ctx: ToolCallApprovalContext,
        config: dict[str, T.Any],
    ) -> ToolCallApprovalResult:
        """Ask for approval of the tool call described by ``ctx``.

        Args:
            ctx: The pending tool call.
            config: The tool call approval configuration dict.

        Returns:
            The approval outcome.
        """
class DynamicCodeApprovalStrategy(BaseToolCallApprovalStrategy):
    """Approval strategy that asks the user to send back a one-time code.

    A random numeric code is sent through the event; the tool call is approved
    only when ``_wait_for_code_input`` reports a matching reply.
    """

    @property
    def name(self) -> str:
        """Return the registry key of this strategy."""
        return "dynamic_code"

    async def request(
        self,
        ctx: ToolCallApprovalContext,
        config: dict[str, T.Any],
    ) -> ToolCallApprovalResult:
        """Prompt the user with a fresh code and report the outcome.

        Args:
            ctx: The pending tool call and the event used to reach the user.
            config: Approval configuration. Reads ``timeout`` (default 60,
                at least 1) and the ``dynamic_code`` sub-dict with
                ``code_length`` (default 6, at least 4) and ``case_sensitive``
                (default False).

        Returns:
            The result produced while waiting for the code, or a result with
            ``reason="error"`` when sending the prompt or waiting raised.
        """
        timeout_seconds = _safe_int(config.get("timeout", 60), default=60, minimum=1)
        dynamic_cfg = config.get("dynamic_code", {})
        if not isinstance(dynamic_cfg, dict):
            # A malformed section falls back to the defaults below.
            dynamic_cfg = {}
        code_length = _safe_int(dynamic_cfg.get("code_length", 6), default=6, minimum=4)
        # NOTE(review): codes are drawn from digits only, so this flag has no
        # visible effect on matching for the codes generated here.
        case_sensitive = bool(dynamic_cfg.get("case_sensitive", False))
        code = "".join(secrets.choice(string.digits) for _ in range(code_length))

        try:
            # The prompt is sent inside the guarded section: if it cannot be
            # delivered the user never sees the code, so the request is
            # reported as an "error" result instead of raising to the caller.
            await ctx.event.send(
                MessageChain().message(
                    "Tool call needs your approval before execution.\n"
                    f"Tool: `{ctx.tool_name}`\n"
                    f"Approval code: `{code}`\n"
                    "Please send this code to continue. "
                    "Any other message will cancel this tool call."
                )
            )
            result = await _wait_for_code_input(
                event=ctx.event,
                expected_code=code,
                timeout=timeout_seconds,
                case_sensitive=case_sensitive,
            )
        except Exception as exc:  # noqa: BLE001
            logger.error(
                "Tool call approval failed unexpectedly for %s: %s",
                ctx.tool_name,
                exc,
                exc_info=True,
            )
            return ToolCallApprovalResult(
                approved=False,
                reason="error",
                detail=str(exc),
            )

        if not result.approved:
            if result.reason == "timeout":
                notice = f"Tool call `{ctx.tool_name}` approval timed out. This call was cancelled."
            else:
                notice = f"Tool call `{ctx.tool_name}` was cancelled."
            try:
                await ctx.event.send(MessageChain().message(notice))
            except Exception as exc:  # noqa: BLE001
                # The decision is already made; a failed courtesy notice must
                # not replace it with an exception.
                logger.warning(
                    "Failed to notify user about cancelled tool call %s: %s",
                    ctx.tool_name,
                    exc,
                )
        return result
# Approval strategies available to ``request_tool_call_approval``, keyed by
# each strategy's ``name``.
_STRATEGY_REGISTRY: dict[str, BaseToolCallApprovalStrategy] = {}
def register_tool_call_approval_strategy(
    strategy: BaseToolCallApprovalStrategy,
) -> None:
    """Store ``strategy`` in the registry under its ``name``.

    A strategy already registered under the same name is replaced.
    """
    registry_key = strategy.name
    _STRATEGY_REGISTRY[registry_key] = strategy
def _register_builtin_strategies() -> None:
    """Register the strategies defined in this module."""
    for strategy_cls in (DynamicCodeApprovalStrategy,):
        register_tool_call_approval_strategy(strategy_cls())
# Runs at import time so the built-in strategy is registered before any
# approval request is made.
_register_builtin_strategies()
async def request_tool_call_approval(
    *,
    config: dict[str, T.Any] | None,
    ctx: ToolCallApprovalContext,
) -> ToolCallApprovalResult:
    """Run the configured approval strategy for one tool call.

    Args:
        config: Tool call approval configuration. Reads ``enable`` and
            ``strategy``; the whole dict is forwarded to the strategy.
        ctx: The pending tool call.

    Returns:
        An approved result when approval is disabled or ``config`` is empty;
        a result with ``reason="unsupported_strategy"`` when no strategy is
        registered under the configured name; otherwise whatever the selected
        strategy returns.
    """
    if not config or not config.get("enable", False):
        return ToolCallApprovalResult(approved=True, reason="approved")

    default_strategy = "dynamic_code"
    configured = config.get("strategy", default_strategy)
    # Missing, null and blank values all select the default. Without the
    # explicit None check a null entry would be looked up as the string "None"
    # and reported as unsupported.
    strategy_name = "" if configured is None else str(configured).strip()
    strategy_name = strategy_name or default_strategy

    strategy = _STRATEGY_REGISTRY.get(strategy_name)
    # Identity check: only a missing entry means "unsupported", regardless of
    # how a registered strategy object evaluates in a boolean context.
    if strategy is None:
        logger.warning("Unsupported tool call approval strategy: %s", strategy_name)
        return ToolCallApprovalResult(
            approved=False,
            reason="unsupported_strategy",
            detail=strategy_name,
        )
    return await strategy.request(ctx, config)
async def _wait_for_code_input(
    *,
    event: AstrMessageEvent,
    expected_code: str,
    timeout: int,
    case_sensitive: bool,
) -> ToolCallApprovalResult:
    """Wait for a reply and compare it with the approval code.

    Args:
        event: Event whose ``unified_msg_origin`` is used as the session id.
        expected_code: The code the user has to send.
        timeout: Timeout forwarded to ``SessionWaiter.register_wait``.
        case_sensitive: Whether the comparison is case-sensitive.

    Returns:
        "approved" when the reply matches, "rejected" (with the reply as
        ``detail``) when it does not, "timeout" when waiting raised
        ``TimeoutError``, and "error" when the waiter yields anything that is
        not a ``ToolCallApprovalResult``.
    """
    code_filter = DefaultSessionFilter()
    FILTERS.append(code_filter)
    waiter = SessionWaiter(
        session_filter=code_filter,
        session_id=event.unified_msg_origin,
        record_history_chains=False,
    )

    async def _on_reply(
        controller: SessionController, incoming: AstrMessageEvent
    ) -> None:
        reply = (incoming.message_str or "").strip()
        matched = _is_code_match(
            expected=expected_code,
            actual=reply,
            case_sensitive=case_sensitive,
        )
        # Only the first handled reply decides the outcome.
        if not controller.future.done():
            if matched:
                verdict = ToolCallApprovalResult(approved=True, reason="approved")
            else:
                verdict = ToolCallApprovalResult(
                    approved=False,
                    reason="rejected",
                    detail=reply,
                )
            controller.future.set_result(verdict)
        controller.stop()

    try:
        outcome = await waiter.register_wait(handler=_on_reply, timeout=timeout)
    except TimeoutError:
        return ToolCallApprovalResult(approved=False, reason="timeout")

    if not isinstance(outcome, ToolCallApprovalResult):
        return ToolCallApprovalResult(
            approved=False,
            reason="error",
            detail=f"Invalid approval result type: {type(outcome).__name__}",
        )
    return outcome
def _is_code_match(*, expected: str, actual: str, case_sensitive: bool) -> bool:
if case_sensitive:
return actual == expected
return actual.casefold() == expected.casefold()
def _safe_int(value: T.Any, *, default: int, minimum: int) -> int:
try:
parsed = int(value)
if parsed < minimum:
return minimum
return parsed
except Exception: # noqa: BLE001
return default
+3
View File
@@ -121,6 +121,8 @@ class MainAgentBuildConfig:
timezone: str | None = None timezone: str | None = None
max_quoted_fallback_images: int = 20 max_quoted_fallback_images: int = 20
"""Maximum number of images injected from quoted-message fallback extraction.""" """Maximum number of images injected from quoted-message fallback extraction."""
tool_call_approval: dict = field(default_factory=dict)
"""Tool call approval configuration."""
@dataclass(slots=True) @dataclass(slots=True)
@@ -1118,6 +1120,7 @@ async def build_main_agent(
run_context=AgentContextWrapper( run_context=AgentContextWrapper(
context=astr_agent_ctx, context=astr_agent_ctx,
tool_call_timeout=config.tool_call_timeout, tool_call_timeout=config.tool_call_timeout,
tool_call_approval=config.tool_call_approval,
), ),
tool_executor=FunctionToolExecutor(), tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS, agent_hooks=MAIN_AGENT_HOOKS,
+78
View File
@@ -117,6 +117,15 @@ DEFAULT_CONFIG = {
"max_agent_step": 30, "max_agent_step": 30,
"tool_call_timeout": 60, "tool_call_timeout": 60,
"tool_schema_mode": "full", "tool_schema_mode": "full",
"tool_call_approval": {
"enable": False,
"strategy": "dynamic_code",
"timeout": 60,
"dynamic_code": {
"code_length": 6,
"case_sensitive": False,
},
},
"llm_safety_mode": True, "llm_safety_mode": True,
"safety_mode_strategy": "system_prompt", # TODO: llm judge "safety_mode_strategy": "system_prompt", # TODO: llm judge
"file_extract": { "file_extract": {
@@ -2330,6 +2339,31 @@ CONFIG_METADATA_2 = {
"tool_schema_mode": { "tool_schema_mode": {
"type": "string", "type": "string",
}, },
"tool_call_approval": {
"type": "object",
"items": {
"enable": {
"type": "bool",
},
"strategy": {
"type": "string",
},
"timeout": {
"type": "int",
},
"dynamic_code": {
"type": "object",
"items": {
"code_length": {
"type": "int",
},
"case_sensitive": {
"type": "bool",
},
},
},
},
},
"file_extract": { "file_extract": {
"type": "object", "type": "object",
"items": { "items": {
@@ -3066,6 +3100,50 @@ CONFIG_METADATA_3 = {
"provider_settings.agent_runner_type": "local", "provider_settings.agent_runner_type": "local",
}, },
}, },
"provider_settings.tool_call_approval.enable": {
"description": "启用工具调用确认",
"type": "bool",
"hint": "开启后,工具调用需要用户确认后才会执行。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.tool_call_approval.strategy": {
"description": "工具调用确认策略",
"type": "string",
"options": ["dynamic_code"],
"labels": ["Dynamic Code(动态码)"],
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.tool_call_approval.enable": True,
},
},
"provider_settings.tool_call_approval.timeout": {
"description": "工具调用确认超时(秒)",
"type": "int",
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.tool_call_approval.enable": True,
},
},
"provider_settings.tool_call_approval.dynamic_code.code_length": {
"description": "动态确认码长度",
"type": "int",
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.tool_call_approval.enable": True,
"provider_settings.tool_call_approval.strategy": "dynamic_code",
},
},
"provider_settings.tool_call_approval.dynamic_code.case_sensitive": {
"description": "动态确认码区分大小写",
"type": "bool",
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.tool_call_approval.enable": True,
"provider_settings.tool_call_approval.strategy": "dynamic_code",
},
},
"provider_settings.wake_prefix": { "provider_settings.wake_prefix": {
"description": "LLM 聊天额外唤醒前缀 ", "description": "LLM 聊天额外唤醒前缀 ",
"type": "string", "type": "string",
@@ -44,6 +44,7 @@ class InternalAgentSubStage(Stage):
] ]
self.max_step: int = settings.get("max_agent_step", 30) self.max_step: int = settings.get("max_agent_step", 30)
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60) self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
self.tool_call_approval: dict = settings.get("tool_call_approval", {})
self.tool_schema_mode: str = settings.get("tool_schema_mode", "full") self.tool_schema_mode: str = settings.get("tool_schema_mode", "full")
if self.tool_schema_mode not in ("skills_like", "full"): if self.tool_schema_mode not in ("skills_like", "full"):
logger.warning( logger.warning(
@@ -124,6 +125,7 @@ class InternalAgentSubStage(Stage):
subagent_orchestrator=conf.get("subagent_orchestrator", {}), subagent_orchestrator=conf.get("subagent_orchestrator", {}),
timezone=self.ctx.plugin_manager.context.get_config().get("timezone"), timezone=self.ctx.plugin_manager.context.get_config().get("timezone"),
max_quoted_fallback_images=settings.get("max_quoted_fallback_images", 20), max_quoted_fallback_images=settings.get("max_quoted_fallback_images", 20),
tool_call_approval=self.tool_call_approval,
) )
async def process( async def process(
+110
View File
@@ -0,0 +1,110 @@
import os
import sys
import pytest
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from astrbot.core.agent.tool_call_approval import (
ToolCallApprovalContext,
ToolCallApprovalResult,
request_tool_call_approval,
)
class DummyEvent:
    """Minimal event stand-in that records every message passed to ``send``."""

    def __init__(self) -> None:
        """Start with a fixed session origin, empty text and nothing sent."""
        self.message_str = ""
        self.sent_messages = []
        self.unified_msg_origin = "test:friend:test_user"

    async def send(self, message) -> None:
        """Record ``message`` instead of delivering it anywhere."""
        outbox = self.sent_messages
        outbox.append(message)
@pytest.mark.asyncio
async def test_request_tool_call_approval_disabled_returns_approved():
    """With approval disabled the call is approved and nothing is sent."""
    event = DummyEvent()
    approval_ctx = ToolCallApprovalContext(
        event=event,
        tool_name="test_tool",
        tool_args={},
        tool_call_id="call_1",
    )

    outcome = await request_tool_call_approval(
        config={"enable": False},
        ctx=approval_ctx,
    )

    assert outcome.approved is True
    assert outcome.reason == "approved"
    assert event.sent_messages == []
@pytest.mark.asyncio
async def test_dynamic_code_approval_passed(monkeypatch):
    """An approved wait result is returned after exactly one prompt message."""

    async def _always_approve(*_args, **_kwargs):
        return ToolCallApprovalResult(approved=True, reason="approved")

    monkeypatch.setattr(
        "astrbot.core.agent.tool_call_approval._wait_for_code_input",
        _always_approve,
    )
    event = DummyEvent()
    approval_ctx = ToolCallApprovalContext(
        event=event,
        tool_name="test_tool",
        tool_args={"query": "hello"},
        tool_call_id="call_2",
    )

    outcome = await request_tool_call_approval(
        config={"enable": True, "strategy": "dynamic_code"},
        ctx=approval_ctx,
    )

    assert outcome.approved is True
    assert outcome.reason == "approved"
    # Only the approval prompt is sent on success.
    assert len(event.sent_messages) == 1
@pytest.mark.asyncio
async def test_dynamic_code_approval_rejected(monkeypatch):
    """A rejected wait result is returned and a cancellation notice follows."""
    rejection = dict(approved=False, reason="rejected", detail="not_code")

    async def _always_reject(*_args, **_kwargs):
        return ToolCallApprovalResult(**rejection)

    monkeypatch.setattr(
        "astrbot.core.agent.tool_call_approval._wait_for_code_input",
        _always_reject,
    )
    event = DummyEvent()
    approval_ctx = ToolCallApprovalContext(
        event=event,
        tool_name="test_tool",
        tool_args={},
        tool_call_id="call_3",
    )

    outcome = await request_tool_call_approval(
        config={"enable": True, "strategy": "dynamic_code"},
        ctx=approval_ctx,
    )

    assert outcome.approved is False
    assert outcome.reason == "rejected"
    # The approval prompt plus the cancellation notice.
    assert len(event.sent_messages) == 2
@pytest.mark.asyncio
async def test_request_tool_call_approval_unsupported_strategy():
    """An unregistered strategy name is refused without messaging the user."""
    event = DummyEvent()
    approval_ctx = ToolCallApprovalContext(
        event=event,
        tool_name="test_tool",
        tool_args={},
        tool_call_id="call_4",
    )

    outcome = await request_tool_call_approval(
        config={"enable": True, "strategy": "unknown_strategy"},
        ctx=approval_ctx,
    )

    assert outcome.approved is False
    assert outcome.reason == "unsupported_strategy"
    assert event.sent_messages == []