Files
AstrBot/astrbot/core/provider/func_tool_manager.py
T
2025-03-15 20:26:17 +08:00

335 lines
11 KiB
Python

import json
import textwrap
import os
from typing import Dict, List, Awaitable
from dataclasses import dataclass
from typing import Optional
from contextlib import AsyncExitStack
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from astrbot import logger
from anthropic import Anthropic
@dataclass
class FuncTool:
"""
用于描述一个函数调用工具。
"""
name: str
parameters: Dict
description: str
handler: Awaitable
handler_module_path: str = None # 必须要保留这个,handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
active: bool = True
"""是否激活"""
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}), active={self.active})"
SUPPORTED_TYPES = [
"string",
"number",
"object",
"array",
"boolean",
] # json schema 支持的数据类型
class MCPClient:
def __init__(self):
# Initialize session and client objects
self.session: Optional[ClientSession] = None
self.exit_stack = AsyncExitStack()
self.anthropic = Anthropic()
self.name = None
self.active: bool = True
async def connect_to_server(self, server_script_path: str):
"""Connect to an MCP server
Args:
server_script_path: Path to the server script (.py or .js)
"""
is_python = server_script_path.endswith(".py")
is_js = server_script_path.endswith(".js")
if not (is_python or is_js):
raise ValueError("Server script must be a .py or .js file")
command = "python" if is_python else "node"
server_params = StdioServerParameters(
command=command, args=[server_script_path], env=None
)
stdio_transport = await self.exit_stack.enter_async_context(
stdio_client(server_params)
)
self.stdio, self.write = stdio_transport
self.session = await self.exit_stack.enter_async_context(
ClientSession(self.stdio, self.write)
)
await self.session.initialize()
class FuncCall:
def __init__(self) -> None:
self.func_list: List[FuncTool] = []
self.mcp_client_dict: Dict[str:MCPClient] = {}
def empty(self) -> bool:
return len(self.func_list) == 0
def add_func(
self,
name: str,
func_args: list,
desc: str,
handler: Awaitable,
) -> None:
"""添加函数调用工具
@param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@param desc: 函数描述
@param func_obj: 处理函数
"""
# check if the tool has been added before
self.remove_func(name)
params = {
"type": "object", # hard-coded here
"properties": {},
}
for param in func_args:
params["properties"][param["name"]] = {
"type": param["type"],
"description": param["description"],
}
_func = FuncTool(
name=name,
parameters=params,
description=desc,
handler=handler,
)
self.func_list.append(_func)
logger.info(f"添加函数调用工具: {name}")
def remove_func(self, name: str) -> None:
"""
删除一个函数调用工具。
"""
for i, f in enumerate(self.func_list):
if f.name == name:
self.func_list.pop(i)
break
def get_func(self, name) -> FuncTool:
for f in self.func_list:
if f.name == name:
return f
return None
async def init_mcp_client_list(self) -> None:
"""
从项目根目录读取mcp_server.json。提供mcp_server.json.example作为参考。
内容格式为
{
"mcpServers": {
"example_cmp_server": {
"script_path": "path/to/cmp/server/script.py"
}
},
...
}
"""
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, "../../.."))
mcp_json_file = os.path.join(project_root, "mcp_server.json")
if not os.path.exists(mcp_json_file):
# 配置文件不存在错误处理
logger.warning(
f"mcp server config file {mcp_json_file} not found. skip init mcp client list."
)
return
mcp_server_json_obj = json.load(open(mcp_json_file, "r", encoding="utf-8"))
for mcp_server_name, mcp_server_script_path in mcp_server_json_obj[
"mcpServers"
].items():
if not os.path.exists(mcp_server_script_path["script_path"]):
logger.error(
f"mcp server import err: Server script {mcp_server_script_path['script_path']} not found."
)
continue
mcp_client = MCPClient()
mcp_client.name = mcp_server_name
await mcp_client.connect_to_server(mcp_server_script_path["script_path"])
self.mcp_client_dict[mcp_server_name] = mcp_client
logger.info(f"添加mcp服务 {mcp_server_name}.")
if len(self.mcp_client_dict) == 0:
logger.info("未启用任何mcp服务.")
async def get_func_desc_openai_style(self) -> list:
"""
获得 OpenAI API 风格的**已经激活**的工具描述
"""
_l = []
for f in self.func_list:
if not f.active:
continue
_l.append(
{
"type": "function",
"function": {
"name": f.name,
"parameters": f.parameters,
"description": f.description,
},
}
)
for name, client in self.mcp_client_dict.items():
responses = await client.session.list_tools()
for tool in responses.tools:
_l.append(
{
"type": "function",
"function": {
"name": f"mcp:{name}:{tool.name}",
"parameters": tool.inputSchema,
"description": tool.description,
},
}
)
return _l
def get_func_desc_anthropic_style(self) -> list:
"""
获得 Anthropic API 风格的**已经激活**的工具描述
"""
tools = []
for f in self.func_list:
if not f.active:
continue
# Convert internal format to Anthropic style
tool = {
"name": f.name,
"description": f.description,
"input_schema": {
"type": "object",
"properties": f.parameters.get("properties", {}),
# Keep the required field from the original parameters if it exists
"required": f.parameters.get("required", []),
},
}
tools.append(tool)
return tools
def get_func_desc_google_genai_style(self) -> Dict:
declarations = {}
tools = []
for f in self.func_list:
if not f.active:
continue
func_declaration = {"name": f.name, "description": f.description}
# 检查并添加非空的properties参数
params = f.parameters if isinstance(f.parameters, dict) else {}
if params.get("properties", {}):
func_declaration["parameters"] = params
tools.append(func_declaration)
if tools:
declarations["function_declarations"] = tools
return declarations
async def func_call(self, question: str, session_id: str, provider) -> tuple:
_l = []
for f in self.func_list:
if not f.active:
continue
_l.append(
{
"name": f["name"],
"parameters": f["parameters"],
"description": f["description"],
}
)
func_definition = json.dumps(_l, ensure_ascii=False)
prompt = textwrap.dedent(f"""
ROLE:
你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。
TOOLS:
可用的函数列表:
{func_definition}
LIMIT:
1. 你返回的内容应当能够被 Python 的 json 模块解析的 Json 格式字符串。
2. 你的 Json 返回的格式如下:`[{{"name": "<func_name>", "args": <arg_dict>}}, ...]`。参数根据上面提供的函数列表中的参数来填写。
3. 允许必要时返回多个函数调用,但需保证这些函数调用的顺序正确。
4. 如果用户的提问中不需要用到给定的函数,请直接返回 `{{"res": False}}`。
EXAMPLE:
1. `用户提问`:请问一下天气怎么样? `函数调用`:[{{"name": "get_weather", "args": {{"city": "北京"}}}}]
用户的提问是:{question}
""")
_c = 0
while _c < 3:
try:
res = await provider.text_chat(prompt, session_id)
if res.find("```") != -1:
res = res[res.find("```json") + 7 : res.rfind("```")]
res = json.loads(res)
break
except Exception as e:
_c += 1
if _c == 3:
raise e
if "The message you submitted was too long" in str(e):
raise e
if "res" in res and not res["res"]:
return "", False
tool_call_result = []
for tool in res:
# 说明有函数调用
func_name = tool["name"]
args = tool["args"]
# 调用函数
tool_callable = None
for func in self.func_list:
if func.name == func_name:
tool_callable = func.star_handler_metadata.handler
break
if not tool_callable:
raise Exception(f"Request function {func_name} not found.")
ret = await tool_callable(**args)
if ret:
tool_call_result.append(str(ret))
return tool_call_result, True
def __str__(self):
return str(self.func_list)
def __repr__(self):
return str(self.func_list)