feat:新增MCP服务支持并优化工具调用逻辑

引入MCP客户端支持,增加mcp_server.json配置样例,完善工具描述生成及调用逻辑以支持MCP服务工具功能。同时调整相关逻辑以区分本地工具与MCP工具的调用方式,提升扩展性和灵活性。
This commit is contained in:
Alero
2025-03-15 19:47:06 +08:00
parent 7d6975fd31
commit d4d9a1df4c
6 changed files with 136 additions and 20 deletions
@@ -175,21 +175,35 @@ class LLMRequestSubStage(Stage):
for func_tool_name, func_tool_args in zip(
llm_response.tools_call_name, llm_response.tools_call_args
):
func_tool = req.func_tool.get_func(func_tool_name)
logger.info(
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
)
try:
# 尝试调用工具函数
wrapper = self._call_handler(
self.ctx, event, func_tool.handler, **func_tool_args
)
async for resp in wrapper:
if resp is not None: # 有 return 返回
function_calling_result[func_tool_name] = resp
else:
yield # 有生成器返回
event.clear_result() # 清除上一个 handler 的结果
if func_tool_name.startswith('mcp:'):
_, mcp_server_name, mcp_func_name = func_tool_name.split(':')
logger.info(
f"从mcp服务 {mcp_server_name} 调用工具函数:{mcp_func_name},参数:{func_tool_args}")
client = req.func_tool.mcp_client_dict[mcp_server_name]
res = await client.session.call_tool(mcp_func_name, func_tool_args)
if res:
# TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。
res_event = event.plain_result(res.content[0].text)
event.set_result(res_event)
yield
else:
func_tool = req.func_tool.get_func(func_tool_name)
logger.info(
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
)
# 尝试调用工具函数
wrapper = self._call_handler(
self.ctx, event, func_tool.handler, **func_tool_args
)
async for resp in wrapper:
if resp is not None: # 有 return 返回
function_calling_result[func_tool_name] = resp
else:
yield # 有生成器返回
event.clear_result() # 清除上一个 handler 的结果
except BaseException as e:
logger.warning(traceback.format_exc())
function_calling_result[func_tool_name] = (
+96 -2
View File
@@ -1,9 +1,19 @@
import json
import textwrap
import os
import asyncio
from typing import Dict, List, Awaitable
from dataclasses import dataclass
from typing import Optional
from contextlib import AsyncExitStack
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from astrbot import logger
from anthropic import Anthropic
from ... import logger
@dataclass
class FuncTool:
@@ -33,9 +43,44 @@ SUPPORTED_TYPES = [
] # json schema 支持的数据类型
class MCPClient:
def __init__(self):
# Initialize session and client objects
self.session: Optional[ClientSession] = None
self.exit_stack = AsyncExitStack()
self.anthropic = Anthropic()
self.name = None
self.active: bool = True
async def connect_to_server(self, server_script_path: str):
"""Connect to an MCP server
Args:
server_script_path: Path to the server script (.py or .js)
"""
is_python = server_script_path.endswith('.py')
is_js = server_script_path.endswith('.js')
if not (is_python or is_js):
raise ValueError("Server script must be a .py or .js file")
command = "python" if is_python else "node"
server_params = StdioServerParameters(
command=command,
args=[server_script_path],
env=None
)
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
self.stdio, self.write = stdio_transport
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
await self.session.initialize()
class FuncCall:
def __init__(self) -> None:
self.func_list: List[FuncTool] = []
self.mcp_client_dict: Dict[str: MCPClient] = dict()
def empty(self) -> bool:
return len(self.func_list) == 0
@@ -90,7 +135,43 @@ class FuncCall:
return f
return None
def get_func_desc_openai_style(self) -> list:
async def init_mcp_client_list(self) -> None:
"""
从项目根目录读取mcp_server.json。提供mcp_server.json.example作为参考。
内容格式为
{
"mcpServers": {
"example_cmp_server": {
"script_path": "path/to/cmp/server/script.py"
}
},
...
}
"""
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, "../../.."))
mcp_json_file = os.path.join(project_root, "mcp_server.json")
if not os.path.exists(mcp_json_file):
# 配置文件不存在错误处理
logger.warning(f"mcp server config file {mcp_json_file} not found. skip init mcp client list.")
return
mcp_server_json_obj = json.load(open(mcp_json_file, "r", encoding="utf-8"))
for mcp_server_name, mcp_server_script_path in mcp_server_json_obj["mcpServers"].items():
if not os.path.exists(mcp_server_script_path["script_path"]):
logger.error(f"mcp server import err: Server script {mcp_server_script_path["script_path"]} not found.")
continue
mcp_client = MCPClient()
mcp_client.name = mcp_server_name
await mcp_client.connect_to_server(mcp_server_script_path["script_path"])
self.mcp_client_dict[mcp_server_name] = mcp_client
logger.info(f"添加mcp服务 {mcp_server_name}.")
if len(self.mcp_client_dict) == 0:
logger.info("未启用任何mcp服务.")
async def get_func_desc_openai_style(self) -> list:
"""
获得 OpenAI API 风格的**已经激活**的工具描述
"""
@@ -105,9 +186,22 @@ class FuncCall:
"name": f.name,
"parameters": f.parameters,
"description": f.description,
},
}
}
)
loop = asyncio.get_event_loop()
for name, client in self.mcp_client_dict.items():
responses = await client.session.list_tools()
for tool in responses.tools:
_l.append({
"type": "function",
"function": {
"name": f"mcp:{name}:{tool.name}",
"parameters": tool.inputSchema,
"description": tool.description,
}
})
return _l
def get_func_desc_anthropic_style(self) -> list:
+3
View File
@@ -127,6 +127,9 @@ class ProviderManager:
if self.tts_enabled and not self.curr_tts_provider_inst:
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
# 初始化mcpclient连接
await self.llm_tools.init_mcp_client_list()
async def load_provider(self, provider_config: dict):
if not provider_config["enable"]:
return
@@ -85,7 +85,7 @@ class LLMTunerModelLoader(Provider):
"system": system_prompt,
}
if func_tool:
tool_list = func_tool.get_func_desc_openai_style()
tool_list = await func_tool.get_func_desc_openai_style()
if tool_list:
conf["tools"] = tool_list
@@ -80,7 +80,7 @@ class ProviderOpenAIOfficial(Provider):
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
tool_list = tools.get_func_desc_openai_style()
tool_list = await tools.get_func_desc_openai_style()
if tool_list:
payloads["tools"] = tool_list
@@ -124,8 +124,10 @@ class ProviderOpenAIOfficial(Provider):
for tool in tools.func_list:
if tool.name == tool_call.function.name:
args = json.loads(tool_call.function.arguments)
args_ls.append(args)
func_name_ls.append(tool_call.function.name)
if tool_call.function.name.startswith("mcp:") and tool_call.function.name.split(':')[1] in tools.mcp_client_dict:
args = json.loads(tool_call.function.arguments)
args_ls.append(args)
func_name_ls.append(tool_call.function.name)
llm_response.role = "tool"
llm_response.tools_call_args = args_ls
llm_response.tools_call_name = func_name_ls
+3
View File
@@ -21,6 +21,9 @@ psutil>=5.8.0
lark-oapi
ormsgpack
cryptography
mcp~=1.4.1
anthropic~=0.49.0
dashscope
python-telegram-bot
wechatpy