969 lines
34 KiB
Python
969 lines
34 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import copy
|
||
import json
|
||
import os
|
||
import threading
|
||
import urllib.parse
|
||
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
|
||
from dataclasses import dataclass
|
||
from types import MappingProxyType
|
||
from typing import Any
|
||
|
||
import aiohttp
|
||
|
||
from astrbot import logger
|
||
from astrbot.core import sp
|
||
from astrbot.core.agent.mcp_client import MCPClient, MCPTool
|
||
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||
|
||
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||
|
||
DEFAULT_MCP_INIT_TIMEOUT_SECONDS = 180.0
|
||
DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS = 180.0
|
||
MCP_INIT_TIMEOUT_ENV = "ASTRBOT_MCP_INIT_TIMEOUT"
|
||
ENABLE_MCP_TIMEOUT_ENV = "ASTRBOT_MCP_ENABLE_TIMEOUT"
|
||
MAX_MCP_TIMEOUT_SECONDS = 300.0
|
||
|
||
|
||
class MCPInitError(Exception):
|
||
"""Base exception for MCP initialization failures."""
|
||
|
||
|
||
class MCPInitTimeoutError(asyncio.TimeoutError, MCPInitError):
|
||
"""Raised when MCP client initialization exceeds the configured timeout."""
|
||
|
||
|
||
class MCPAllServicesFailedError(MCPInitError):
|
||
"""Raised when all configured MCP services fail to initialize."""
|
||
|
||
|
||
class MCPShutdownTimeoutError(asyncio.TimeoutError):
|
||
"""Raised when MCP shutdown exceeds the configured timeout."""
|
||
|
||
def __init__(self, names: list[str], timeout: float) -> None:
|
||
self.names = names
|
||
self.timeout = timeout
|
||
message = f"MCP 服务关闭超时({timeout:g} 秒):{', '.join(names)}"
|
||
super().__init__(message)
|
||
|
||
|
||
@dataclass
|
||
class MCPInitSummary:
|
||
total: int
|
||
success: int
|
||
failed: list[str]
|
||
|
||
|
||
@dataclass
|
||
class _MCPServerRuntime:
|
||
name: str
|
||
client: MCPClient
|
||
shutdown_event: asyncio.Event
|
||
lifecycle_task: asyncio.Task[None]
|
||
|
||
|
||
class _MCPClientDictView(Mapping[str, MCPClient]):
|
||
"""Read-only view of MCP clients derived from runtime state."""
|
||
|
||
def __init__(self, runtime: dict[str, _MCPServerRuntime]) -> None:
|
||
self._runtime = runtime
|
||
|
||
def __getitem__(self, key: str) -> MCPClient:
|
||
return self._runtime[key].client
|
||
|
||
def __iter__(self):
|
||
return iter(self._runtime)
|
||
|
||
def __len__(self) -> int:
|
||
return len(self._runtime)
|
||
|
||
|
||
def _resolve_timeout(
|
||
timeout: float | int | str | None = None,
|
||
*,
|
||
env_name: str = MCP_INIT_TIMEOUT_ENV,
|
||
default: float = DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
|
||
) -> float:
|
||
"""Resolve timeout with precedence: explicit argument > env value > default."""
|
||
source = f"环境变量 {env_name}"
|
||
if timeout is None:
|
||
timeout = os.getenv(env_name, str(default))
|
||
else:
|
||
source = "显式参数 timeout"
|
||
|
||
try:
|
||
timeout_value = float(timeout)
|
||
except (TypeError, ValueError):
|
||
logger.warning(
|
||
f"超时配置({source})={timeout!r} 无效,使用默认值 {default:g} 秒。"
|
||
)
|
||
return default
|
||
|
||
if timeout_value <= 0:
|
||
logger.warning(
|
||
f"超时配置({source})={timeout_value:g} 必须大于 0,使用默认值 {default:g} 秒。"
|
||
)
|
||
return default
|
||
|
||
if timeout_value > MAX_MCP_TIMEOUT_SECONDS:
|
||
logger.warning(
|
||
f"超时配置({source})={timeout_value:g} 过大,已限制为最大值 "
|
||
f"{MAX_MCP_TIMEOUT_SECONDS:g} 秒,以避免长时间等待。"
|
||
)
|
||
return MAX_MCP_TIMEOUT_SECONDS
|
||
|
||
return timeout_value
|
||
|
||
|
||
SUPPORTED_TYPES = [
|
||
"string",
|
||
"number",
|
||
"object",
|
||
"array",
|
||
"boolean",
|
||
] # json schema 支持的数据类型
|
||
|
||
PY_TO_JSON_TYPE = {
|
||
"int": "number",
|
||
"float": "number",
|
||
"bool": "boolean",
|
||
"str": "string",
|
||
"dict": "object",
|
||
"list": "array",
|
||
"tuple": "array",
|
||
"set": "array",
|
||
}
|
||
# alias
|
||
FuncTool = FunctionTool
|
||
|
||
|
||
def _prepare_config(config: dict) -> dict:
|
||
"""准备配置,处理嵌套格式"""
|
||
if config.get("mcpServers"):
|
||
first_key = next(iter(config["mcpServers"]))
|
||
config = config["mcpServers"][first_key]
|
||
config.pop("active", None)
|
||
return config
|
||
|
||
|
||
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||
"""快速测试 MCP 服务器可达性"""
|
||
import aiohttp
|
||
|
||
cfg = _prepare_config(config.copy())
|
||
|
||
url = cfg["url"]
|
||
headers = cfg.get("headers", {})
|
||
timeout = cfg.get("timeout", 10)
|
||
|
||
try:
|
||
async with aiohttp.ClientSession() as session:
|
||
if cfg.get("transport") == "streamable_http":
|
||
test_payload = {
|
||
"jsonrpc": "2.0",
|
||
"method": "initialize",
|
||
"id": 0,
|
||
"params": {
|
||
"protocolVersion": "2024-11-05",
|
||
"capabilities": {},
|
||
"clientInfo": {"name": "test-client", "version": "1.2.3"},
|
||
},
|
||
}
|
||
async with session.post(
|
||
url,
|
||
headers={
|
||
**headers,
|
||
"Content-Type": "application/json",
|
||
"Accept": "application/json, text/event-stream",
|
||
},
|
||
json=test_payload,
|
||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||
) as response:
|
||
if response.status == 200:
|
||
return True, ""
|
||
return False, f"HTTP {response.status}: {response.reason}"
|
||
else:
|
||
async with session.get(
|
||
url,
|
||
headers={
|
||
**headers,
|
||
"Accept": "application/json, text/event-stream",
|
||
},
|
||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||
) as response:
|
||
if response.status == 200:
|
||
return True, ""
|
||
return False, f"HTTP {response.status}: {response.reason}"
|
||
|
||
except asyncio.TimeoutError:
|
||
return False, f"连接超时: {timeout}秒"
|
||
except Exception as e:
|
||
return False, f"{e!s}"
|
||
|
||
|
||
class FunctionToolManager:
|
||
def __init__(self) -> None:
|
||
self.func_list: list[FuncTool] = []
|
||
self._mcp_server_runtime: dict[str, _MCPServerRuntime] = {}
|
||
"""MCP 服务运行时状态(唯一事实来源)"""
|
||
self._mcp_server_runtime_view = MappingProxyType(self._mcp_server_runtime)
|
||
self._mcp_client_dict_view = _MCPClientDictView(self._mcp_server_runtime)
|
||
self._timeout_mismatch_warned = False
|
||
self._timeout_warn_lock = threading.Lock()
|
||
self._runtime_lock = asyncio.Lock()
|
||
self._mcp_starting: set[str] = set()
|
||
self._init_timeout_default = _resolve_timeout(
|
||
timeout=None,
|
||
env_name=MCP_INIT_TIMEOUT_ENV,
|
||
default=DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
|
||
)
|
||
self._enable_timeout_default = _resolve_timeout(
|
||
timeout=None,
|
||
env_name=ENABLE_MCP_TIMEOUT_ENV,
|
||
default=DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS,
|
||
)
|
||
self._warn_on_timeout_mismatch(
|
||
self._init_timeout_default,
|
||
self._enable_timeout_default,
|
||
)
|
||
|
||
@property
|
||
def mcp_client_dict(self) -> Mapping[str, MCPClient]:
|
||
"""Read-only compatibility view for external callers that still read mcp_client_dict.
|
||
|
||
Note: Mutating this mapping is unsupported and will raise TypeError.
|
||
"""
|
||
return self._mcp_client_dict_view
|
||
|
||
@property
|
||
def mcp_server_runtime_view(self) -> Mapping[str, _MCPServerRuntime]:
|
||
"""Read-only view of MCP runtime metadata for external callers."""
|
||
return self._mcp_server_runtime_view
|
||
|
||
@property
|
||
def mcp_server_runtime(self) -> Mapping[str, _MCPServerRuntime]:
|
||
"""Backward-compatible read-only view (deprecated). Do not mutate.
|
||
|
||
Note: Mutations are not supported and will raise TypeError.
|
||
"""
|
||
return self._mcp_server_runtime_view
|
||
|
||
def empty(self) -> bool:
|
||
return len(self.func_list) == 0
|
||
|
||
def spec_to_func(
|
||
self,
|
||
name: str,
|
||
func_args: list[dict],
|
||
desc: str,
|
||
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
|
||
) -> FuncTool:
|
||
params = {
|
||
"type": "object", # hard-coded here
|
||
"properties": {},
|
||
}
|
||
for param in func_args:
|
||
p = copy.deepcopy(param)
|
||
p.pop("name", None)
|
||
params["properties"][param["name"]] = p
|
||
return FuncTool(
|
||
name=name,
|
||
parameters=params,
|
||
description=desc,
|
||
handler=handler,
|
||
)
|
||
|
||
def add_func(
|
||
self,
|
||
name: str,
|
||
func_args: list,
|
||
desc: str,
|
||
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
|
||
) -> None:
|
||
"""添加函数调用工具
|
||
|
||
@param name: 函数名
|
||
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||
@param desc: 函数描述
|
||
@param func_obj: 处理函数
|
||
"""
|
||
# check if the tool has been added before
|
||
self.remove_func(name)
|
||
|
||
self.func_list.append(
|
||
self.spec_to_func(
|
||
name=name,
|
||
func_args=func_args,
|
||
desc=desc,
|
||
handler=handler,
|
||
),
|
||
)
|
||
logger.info(f"添加函数调用工具: {name}")
|
||
|
||
def remove_func(self, name: str) -> None:
|
||
"""删除一个函数调用工具。"""
|
||
for i, f in enumerate(self.func_list):
|
||
if f.name == name:
|
||
self.func_list.pop(i)
|
||
break
|
||
|
||
def get_func(self, name) -> FuncTool | None:
|
||
for f in self.func_list:
|
||
if f.name == name:
|
||
return f
|
||
|
||
def get_full_tool_set(self) -> ToolSet:
|
||
"""获取完整工具集"""
|
||
tool_set = ToolSet(self.func_list.copy())
|
||
return tool_set
|
||
|
||
@staticmethod
|
||
def _log_safe_mcp_debug_config(cfg: dict) -> None:
|
||
# 仅记录脱敏后的摘要,避免泄露 command/args/url 中的敏感信息
|
||
if "command" in cfg:
|
||
cmd = cfg["command"]
|
||
executable = str(cmd[0] if isinstance(cmd, (list, tuple)) and cmd else cmd)
|
||
args_val = cfg.get("args", [])
|
||
args_count = (
|
||
len(args_val)
|
||
if isinstance(args_val, (list, tuple))
|
||
else (0 if args_val is None else 1)
|
||
)
|
||
logger.debug(f" 命令可执行文件: {executable}, 参数数量: {args_count}")
|
||
return
|
||
|
||
if "url" in cfg:
|
||
parsed = urllib.parse.urlparse(str(cfg["url"]))
|
||
host = parsed.hostname or ""
|
||
scheme = parsed.scheme or "unknown"
|
||
try:
|
||
port = f":{parsed.port}" if parsed.port else ""
|
||
except ValueError:
|
||
port = ""
|
||
logger.debug(f" 主机: {scheme}://{host}{port}")
|
||
|
||
async def init_mcp_clients(
|
||
self, raise_on_all_failed: bool = False
|
||
) -> MCPInitSummary:
|
||
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
|
||
```
|
||
{
|
||
"mcpServers": {
|
||
"weather": {
|
||
"command": "uv",
|
||
"args": [
|
||
"--directory",
|
||
"/ABSOLUTE/PATH/TO/PARENT/FOLDER/weather",
|
||
"run",
|
||
"weather.py"
|
||
]
|
||
}
|
||
}
|
||
...
|
||
}
|
||
```
|
||
|
||
Timeout behavior:
|
||
- 初始化超时使用环境变量 ASTRBOT_MCP_INIT_TIMEOUT 或默认值。
|
||
- 动态启用超时使用 ASTRBOT_MCP_ENABLE_TIMEOUT(独立于初始化超时)。
|
||
"""
|
||
data_dir = get_astrbot_data_path()
|
||
|
||
mcp_json_file = os.path.join(data_dir, "mcp_server.json")
|
||
if not os.path.exists(mcp_json_file):
|
||
# 配置文件不存在错误处理
|
||
with open(mcp_json_file, "w", encoding="utf-8") as f:
|
||
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
|
||
logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
|
||
return MCPInitSummary(total=0, success=0, failed=[])
|
||
|
||
with open(mcp_json_file, encoding="utf-8") as f:
|
||
mcp_server_json_obj: dict[str, dict] = json.load(f)["mcpServers"]
|
||
|
||
init_timeout = self._init_timeout_default
|
||
timeout_display = f"{init_timeout:g}"
|
||
|
||
active_configs: list[tuple[str, dict, asyncio.Event]] = []
|
||
for name, cfg in mcp_server_json_obj.items():
|
||
if cfg.get("active", True):
|
||
shutdown_event = asyncio.Event()
|
||
active_configs.append((name, cfg, shutdown_event))
|
||
|
||
if not active_configs:
|
||
return MCPInitSummary(total=0, success=0, failed=[])
|
||
|
||
logger.info(f"等待 {len(active_configs)} 个 MCP 服务初始化...")
|
||
|
||
init_tasks = [
|
||
asyncio.create_task(
|
||
self._start_mcp_server(
|
||
name=name,
|
||
cfg=cfg,
|
||
shutdown_event=shutdown_event,
|
||
timeout=init_timeout,
|
||
),
|
||
name=f"mcp-init:{name}",
|
||
)
|
||
for (name, cfg, shutdown_event) in active_configs
|
||
]
|
||
results = await asyncio.gather(*init_tasks, return_exceptions=True)
|
||
|
||
success_count = 0
|
||
failed_services: list[str] = []
|
||
|
||
for (name, cfg, _), result in zip(active_configs, results, strict=False):
|
||
if isinstance(result, Exception):
|
||
if isinstance(result, MCPInitTimeoutError):
|
||
logger.error(
|
||
f"Connected to MCP server {name} timeout ({timeout_display} seconds)"
|
||
)
|
||
else:
|
||
logger.error(f"Failed to initialize MCP server {name}: {result}")
|
||
self._log_safe_mcp_debug_config(cfg)
|
||
failed_services.append(name)
|
||
async with self._runtime_lock:
|
||
self._mcp_server_runtime.pop(name, None)
|
||
continue
|
||
|
||
success_count += 1
|
||
|
||
if failed_services:
|
||
logger.warning(
|
||
f"The following MCP services failed to initialize: {', '.join(failed_services)}. "
|
||
f"Please check the mcp_server.json file and server availability."
|
||
)
|
||
|
||
summary = MCPInitSummary(
|
||
total=len(active_configs), success=success_count, failed=failed_services
|
||
)
|
||
logger.info(
|
||
f"MCP services initialization completed: {summary.success}/{summary.total} successful, {len(summary.failed)} failed."
|
||
)
|
||
if summary.total > 0 and summary.success == 0:
|
||
msg = "All MCP services failed to initialize, please check the mcp_server.json and server availability."
|
||
if raise_on_all_failed:
|
||
raise MCPAllServicesFailedError(msg)
|
||
logger.error(msg)
|
||
return summary
|
||
|
||
async def _start_mcp_server(
|
||
self,
|
||
name: str,
|
||
cfg: dict,
|
||
*,
|
||
shutdown_event: asyncio.Event | None = None,
|
||
timeout: float,
|
||
) -> None:
|
||
"""Initialize MCP server with timeout and register task/event together.
|
||
|
||
This method is idempotent. If the server is already running, the existing
|
||
runtime is kept and the new config is ignored.
|
||
"""
|
||
async with self._runtime_lock:
|
||
if name in self._mcp_server_runtime or name in self._mcp_starting:
|
||
logger.warning(
|
||
f"Connected to MCP server {name}, ignoring this startup request (timeout={timeout:g})."
|
||
)
|
||
self._log_safe_mcp_debug_config(cfg)
|
||
return
|
||
self._mcp_starting.add(name)
|
||
|
||
if shutdown_event is None:
|
||
shutdown_event = asyncio.Event()
|
||
|
||
mcp_client: MCPClient | None = None
|
||
try:
|
||
mcp_client = await asyncio.wait_for(
|
||
self._init_mcp_client(name, cfg),
|
||
timeout=timeout,
|
||
)
|
||
except asyncio.TimeoutError as exc:
|
||
raise MCPInitTimeoutError(
|
||
f"Connected to MCP server {name} timeout ({timeout:g} seconds)"
|
||
) from exc
|
||
except Exception:
|
||
logger.error(f"Failed to initialize MCP client {name}", exc_info=True)
|
||
raise
|
||
finally:
|
||
if mcp_client is None:
|
||
async with self._runtime_lock:
|
||
self._mcp_starting.discard(name)
|
||
|
||
async def lifecycle() -> None:
|
||
try:
|
||
await shutdown_event.wait()
|
||
logger.info(f"Received shutdown signal for MCP client {name}")
|
||
except asyncio.CancelledError:
|
||
logger.debug(f"MCP client {name} task was cancelled")
|
||
raise
|
||
finally:
|
||
await self._terminate_mcp_client(name)
|
||
|
||
lifecycle_task = asyncio.create_task(lifecycle(), name=f"mcp-client:{name}")
|
||
async with self._runtime_lock:
|
||
self._mcp_server_runtime[name] = _MCPServerRuntime(
|
||
name=name,
|
||
client=mcp_client,
|
||
shutdown_event=shutdown_event,
|
||
lifecycle_task=lifecycle_task,
|
||
)
|
||
self._mcp_starting.discard(name)
|
||
|
||
async def _shutdown_runtimes(
|
||
self,
|
||
runtimes: list[_MCPServerRuntime],
|
||
timeout: float,
|
||
*,
|
||
strict: bool = True,
|
||
) -> list[str]:
|
||
"""Shutdown runtimes and wait for lifecycle tasks to complete."""
|
||
lifecycle_tasks = [
|
||
runtime.lifecycle_task
|
||
for runtime in runtimes
|
||
if not runtime.lifecycle_task.done()
|
||
]
|
||
if not lifecycle_tasks:
|
||
return []
|
||
|
||
for runtime in runtimes:
|
||
runtime.shutdown_event.set()
|
||
|
||
try:
|
||
results = await asyncio.wait_for(
|
||
asyncio.gather(*lifecycle_tasks, return_exceptions=True),
|
||
timeout=timeout,
|
||
)
|
||
except asyncio.TimeoutError:
|
||
pending_names = [
|
||
runtime.name
|
||
for runtime in runtimes
|
||
if not runtime.lifecycle_task.done()
|
||
]
|
||
for task in lifecycle_tasks:
|
||
if not task.done():
|
||
task.cancel()
|
||
await asyncio.gather(*lifecycle_tasks, return_exceptions=True)
|
||
if strict:
|
||
raise MCPShutdownTimeoutError(pending_names, timeout)
|
||
logger.warning(
|
||
"MCP server shutdown timeout (%s seconds), the following servers were not fully closed: %s",
|
||
f"{timeout:g}",
|
||
", ".join(pending_names),
|
||
)
|
||
return pending_names
|
||
else:
|
||
for result in results:
|
||
if isinstance(result, asyncio.CancelledError):
|
||
logger.debug("MCP lifecycle task was cancelled during shutdown.")
|
||
elif isinstance(result, Exception):
|
||
logger.error(
|
||
"MCP lifecycle task failed during shutdown.",
|
||
exc_info=(type(result), result, result.__traceback__),
|
||
)
|
||
return []
|
||
|
||
async def _cleanup_mcp_client_safely(
|
||
self, mcp_client: MCPClient, name: str
|
||
) -> None:
|
||
"""安全清理单个 MCP 客户端,避免清理异常中断主流程。"""
|
||
try:
|
||
await mcp_client.cleanup()
|
||
except Exception as cleanup_exc: # noqa: BLE001 - only log here
|
||
logger.error(
|
||
f"Failed to cleanup MCP client resources {name}: {cleanup_exc}"
|
||
)
|
||
|
||
async def _init_mcp_client(self, name: str, config: dict) -> MCPClient:
|
||
"""初始化单个MCP客户端"""
|
||
mcp_client = MCPClient()
|
||
mcp_client.name = name
|
||
try:
|
||
await mcp_client.connect_to_server(config, name)
|
||
tools_res = await mcp_client.list_tools_and_save()
|
||
except asyncio.CancelledError:
|
||
await self._cleanup_mcp_client_safely(mcp_client, name)
|
||
raise
|
||
except Exception:
|
||
await self._cleanup_mcp_client_safely(mcp_client, name)
|
||
raise
|
||
logger.debug(f"MCP server {name} list tools response: {tools_res}")
|
||
tool_names = [tool.name for tool in tools_res.tools]
|
||
|
||
# 移除该MCP服务之前的工具(如有)
|
||
self.func_list = [
|
||
f
|
||
for f in self.func_list
|
||
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
||
]
|
||
|
||
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
|
||
for tool in mcp_client.tools:
|
||
func_tool = MCPTool(
|
||
mcp_tool=tool,
|
||
mcp_client=mcp_client,
|
||
mcp_server_name=name,
|
||
)
|
||
self.func_list.append(func_tool)
|
||
|
||
logger.info(f"Connected to MCP server {name}, Tools: {tool_names}")
|
||
return mcp_client
|
||
|
||
async def _terminate_mcp_client(self, name: str) -> None:
|
||
"""关闭并清理MCP客户端"""
|
||
async with self._runtime_lock:
|
||
runtime = self._mcp_server_runtime.get(name)
|
||
if runtime:
|
||
client = runtime.client
|
||
# 关闭MCP连接
|
||
await self._cleanup_mcp_client_safely(client, name)
|
||
# 移除关联的FuncTool
|
||
self.func_list = [
|
||
f
|
||
for f in self.func_list
|
||
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
||
]
|
||
async with self._runtime_lock:
|
||
self._mcp_server_runtime.pop(name, None)
|
||
self._mcp_starting.discard(name)
|
||
logger.info(f"Disconnected from MCP server {name}")
|
||
return
|
||
|
||
# Runtime missing but stale tools may still exist after failed flows.
|
||
self.func_list = [
|
||
f
|
||
for f in self.func_list
|
||
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
||
]
|
||
async with self._runtime_lock:
|
||
self._mcp_starting.discard(name)
|
||
|
||
@staticmethod
|
||
async def test_mcp_server_connection(config: dict) -> list[str]:
|
||
if "url" in config:
|
||
success, error_msg = await _quick_test_mcp_connection(config)
|
||
if not success:
|
||
raise Exception(error_msg)
|
||
|
||
mcp_client = MCPClient()
|
||
try:
|
||
logger.debug(f"testing MCP server connection with config: {config}")
|
||
await mcp_client.connect_to_server(config, "test")
|
||
tools_res = await mcp_client.list_tools_and_save()
|
||
tool_names = [tool.name for tool in tools_res.tools]
|
||
finally:
|
||
logger.debug("Cleaning up MCP client after testing connection.")
|
||
await mcp_client.cleanup()
|
||
return tool_names
|
||
|
||
async def enable_mcp_server(
|
||
self,
|
||
name: str,
|
||
config: dict,
|
||
shutdown_event: asyncio.Event | None = None,
|
||
timeout: float | int | str | None = None,
|
||
) -> None:
|
||
"""Enable a new MCP server and initialize it.
|
||
|
||
Args:
|
||
name: The name of the MCP server.
|
||
config: Configuration for the MCP server.
|
||
shutdown_event: Event to signal when the MCP client should shut down.
|
||
timeout: Timeout in seconds for initialization.
|
||
Uses ASTRBOT_MCP_ENABLE_TIMEOUT by default (separate from init timeout).
|
||
|
||
Raises:
|
||
MCPInitTimeoutError: If initialization does not complete within timeout.
|
||
Exception: If there is an error during initialization.
|
||
"""
|
||
if timeout is None:
|
||
timeout_value = self._enable_timeout_default
|
||
else:
|
||
timeout_value = _resolve_timeout(
|
||
timeout=timeout,
|
||
env_name=ENABLE_MCP_TIMEOUT_ENV,
|
||
default=self._enable_timeout_default,
|
||
)
|
||
await self._start_mcp_server(
|
||
name=name,
|
||
cfg=config,
|
||
shutdown_event=shutdown_event,
|
||
timeout=timeout_value,
|
||
)
|
||
|
||
async def disable_mcp_server(
|
||
self,
|
||
name: str | None = None,
|
||
timeout: float = 10,
|
||
) -> None:
|
||
"""Disable an MCP server by its name.
|
||
|
||
Args:
|
||
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
|
||
timeout (int): Timeout.
|
||
|
||
Raises:
|
||
MCPShutdownTimeoutError: If shutdown does not complete within timeout.
|
||
Only raised when disabling a specific server (name is not None).
|
||
|
||
"""
|
||
if name:
|
||
async with self._runtime_lock:
|
||
runtime = self._mcp_server_runtime.get(name)
|
||
if runtime is None:
|
||
return
|
||
|
||
await self._shutdown_runtimes([runtime], timeout, strict=True)
|
||
else:
|
||
async with self._runtime_lock:
|
||
runtimes = list(self._mcp_server_runtime.values())
|
||
await self._shutdown_runtimes(runtimes, timeout, strict=False)
|
||
|
||
def _warn_on_timeout_mismatch(
|
||
self,
|
||
init_timeout: float,
|
||
enable_timeout: float,
|
||
) -> None:
|
||
if init_timeout == enable_timeout:
|
||
return
|
||
with self._timeout_warn_lock:
|
||
if self._timeout_mismatch_warned:
|
||
return
|
||
logger.info(
|
||
"检测到 MCP 初始化超时与动态启用超时配置不同:"
|
||
"初始化使用 %s 秒,动态启用使用 %s 秒。如需一致,请设置相同值。",
|
||
f"{init_timeout:g}",
|
||
f"{enable_timeout:g}",
|
||
)
|
||
self._timeout_mismatch_warned = True
|
||
|
||
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
|
||
"""获得 OpenAI API 风格的**已经激活**的工具描述"""
|
||
tools = [f for f in self.func_list if f.active]
|
||
toolset = ToolSet(tools)
|
||
return toolset.openai_schema(
|
||
omit_empty_parameter_field=omit_empty_parameter_field,
|
||
)
|
||
|
||
def get_func_desc_anthropic_style(self) -> list:
|
||
"""获得 Anthropic API 风格的**已经激活**的工具描述"""
|
||
tools = [f for f in self.func_list if f.active]
|
||
toolset = ToolSet(tools)
|
||
return toolset.anthropic_schema()
|
||
|
||
def get_func_desc_google_genai_style(self) -> dict:
|
||
"""获得 Google GenAI API 风格的**已经激活**的工具描述"""
|
||
tools = [f for f in self.func_list if f.active]
|
||
toolset = ToolSet(tools)
|
||
return toolset.google_schema()
|
||
|
||
def deactivate_llm_tool(self, name: str) -> bool:
|
||
"""停用一个已经注册的函数调用工具。
|
||
|
||
Returns:
|
||
如果没找到,会返回 False
|
||
|
||
"""
|
||
func_tool = self.get_func(name)
|
||
if func_tool is not None:
|
||
func_tool.active = False
|
||
|
||
inactivated_llm_tools: list = sp.get(
|
||
"inactivated_llm_tools",
|
||
[],
|
||
scope="global",
|
||
scope_id="global",
|
||
)
|
||
if name not in inactivated_llm_tools:
|
||
inactivated_llm_tools.append(name)
|
||
sp.put(
|
||
"inactivated_llm_tools",
|
||
inactivated_llm_tools,
|
||
scope="global",
|
||
scope_id="global",
|
||
)
|
||
|
||
return True
|
||
return False
|
||
|
||
# 因为不想解决循环引用,所以这里直接传入 star_map 先了...
|
||
def activate_llm_tool(self, name: str, star_map: dict) -> bool:
|
||
func_tool = self.get_func(name)
|
||
if func_tool is not None:
|
||
if func_tool.handler_module_path in star_map:
|
||
if not star_map[func_tool.handler_module_path].activated:
|
||
raise ValueError(
|
||
f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。",
|
||
)
|
||
|
||
func_tool.active = True
|
||
|
||
inactivated_llm_tools: list = sp.get(
|
||
"inactivated_llm_tools",
|
||
[],
|
||
scope="global",
|
||
scope_id="global",
|
||
)
|
||
if name in inactivated_llm_tools:
|
||
inactivated_llm_tools.remove(name)
|
||
sp.put(
|
||
"inactivated_llm_tools",
|
||
inactivated_llm_tools,
|
||
scope="global",
|
||
scope_id="global",
|
||
)
|
||
|
||
return True
|
||
return False
|
||
|
||
@property
|
||
def mcp_config_path(self):
|
||
data_dir = get_astrbot_data_path()
|
||
return os.path.join(data_dir, "mcp_server.json")
|
||
|
||
def load_mcp_config(self):
|
||
if not os.path.exists(self.mcp_config_path):
|
||
# 配置文件不存在,创建默认配置
|
||
os.makedirs(os.path.dirname(self.mcp_config_path), exist_ok=True)
|
||
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
|
||
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
|
||
return DEFAULT_MCP_CONFIG
|
||
|
||
try:
|
||
with open(self.mcp_config_path, encoding="utf-8") as f:
|
||
return json.load(f)
|
||
except Exception as e:
|
||
logger.error(f"加载 MCP 配置失败: {e}")
|
||
return DEFAULT_MCP_CONFIG
|
||
|
||
def save_mcp_config(self, config: dict) -> bool:
|
||
try:
|
||
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
|
||
json.dump(config, f, ensure_ascii=False, indent=4)
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"保存 MCP 配置失败: {e}")
|
||
return False
|
||
|
||
async def sync_modelscope_mcp_servers(self, access_token: str) -> None:
|
||
"""从 ModelScope 平台同步 MCP 服务器配置"""
|
||
base_url = "https://www.modelscope.cn/openapi/v1"
|
||
url = f"{base_url}/mcp/servers/operational"
|
||
headers = {
|
||
"Authorization": f"Bearer {access_token.strip()}",
|
||
"Content-Type": "application/json",
|
||
}
|
||
|
||
try:
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.get(url, headers=headers) as response:
|
||
if response.status == 200:
|
||
data = await response.json()
|
||
mcp_server_list = data.get("data", {}).get(
|
||
"mcp_server_list",
|
||
[],
|
||
)
|
||
local_mcp_config = self.load_mcp_config()
|
||
|
||
synced_count = 0
|
||
for server in mcp_server_list:
|
||
server_name = server["name"]
|
||
operational_urls = server.get("operational_urls", [])
|
||
if not operational_urls:
|
||
continue
|
||
url_info = operational_urls[0]
|
||
server_url = url_info.get("url")
|
||
if not server_url:
|
||
continue
|
||
# 添加到配置中(同名会覆盖)
|
||
local_mcp_config["mcpServers"][server_name] = {
|
||
"url": server_url,
|
||
"transport": "sse",
|
||
"active": True,
|
||
"provider": "modelscope",
|
||
}
|
||
synced_count += 1
|
||
|
||
if synced_count > 0:
|
||
self.save_mcp_config(local_mcp_config)
|
||
tasks = []
|
||
for server in mcp_server_list:
|
||
name = server["name"]
|
||
tasks.append(
|
||
self.enable_mcp_server(
|
||
name=name,
|
||
config=local_mcp_config["mcpServers"][name],
|
||
),
|
||
)
|
||
await asyncio.gather(*tasks)
|
||
logger.info(
|
||
f"从 ModelScope 同步了 {synced_count} 个 MCP 服务器",
|
||
)
|
||
else:
|
||
logger.warning("没有找到可用的 ModelScope MCP 服务器")
|
||
else:
|
||
raise Exception(
|
||
f"ModelScope API 请求失败: HTTP {response.status}",
|
||
)
|
||
|
||
except aiohttp.ClientError as e:
|
||
raise Exception(f"网络连接错误: {e!s}")
|
||
except Exception as e:
|
||
raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {e!s}")
|
||
|
||
# Module paths whose ``get_all_tools()`` function returns internal tools.
|
||
# To add a new internal-tool provider, simply append its module path here.
|
||
_INTERNAL_TOOL_PROVIDERS: list[str] = [
|
||
"astrbot.core.tools.cron_tools",
|
||
"astrbot.core.tools.kb_query",
|
||
"astrbot.core.tools.send_message",
|
||
"astrbot.core.computer.computer_tool_provider",
|
||
]
|
||
|
||
def register_internal_tools(self) -> None:
|
||
"""Register AstrBot built-in tools from all internal providers.
|
||
|
||
Each provider module is expected to expose a ``get_all_tools()``
|
||
function that returns a list of ``FunctionTool`` instances.
|
||
|
||
Tools are marked with ``source='internal'`` so the WebUI can
|
||
distinguish them from plugin and MCP tools, and subagent
|
||
orchestrators can resolve them by name.
|
||
|
||
Duplicate registration is idempotent (skips if name already present).
|
||
"""
|
||
import importlib
|
||
|
||
existing_names = {t.name for t in self.func_list}
|
||
|
||
for module_path in self._INTERNAL_TOOL_PROVIDERS:
|
||
try:
|
||
mod = importlib.import_module(module_path)
|
||
provider_tools = mod.get_all_tools()
|
||
except Exception as e:
|
||
logger.warning(
|
||
"Failed to load internal tool provider %s: %s",
|
||
module_path,
|
||
e,
|
||
)
|
||
continue
|
||
|
||
for tool in provider_tools:
|
||
tool.source = "internal"
|
||
if tool.name not in existing_names:
|
||
self.func_list.append(tool)
|
||
existing_names.add(tool.name)
|
||
logger.info("Registered internal tool: %s", tool.name)
|
||
|
||
def __str__(self) -> str:
|
||
return str(self.func_list)
|
||
|
||
def __repr__(self) -> str:
|
||
return str(self.func_list)
|
||
|
||
|
||
# alias
|
||
FuncCall = FunctionToolManager
|