feat:新增MCP服务支持并优化工具调用逻辑
引入MCP客户端支持,增加mcp_server.json配置样例,完善工具描述生成及调用逻辑以支持MCP服务工具功能。同时调整相关逻辑以区分本地工具与MCP工具的调用方式,提升扩展性和灵活性。
This commit is contained in:
@@ -175,21 +175,35 @@ class LLMRequestSubStage(Stage):
|
||||
for func_tool_name, func_tool_args in zip(
|
||||
llm_response.tools_call_name, llm_response.tools_call_args
|
||||
):
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
logger.info(
|
||||
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
|
||||
)
|
||||
try:
|
||||
# 尝试调用工具函数
|
||||
wrapper = self._call_handler(
|
||||
self.ctx, event, func_tool.handler, **func_tool_args
|
||||
)
|
||||
async for resp in wrapper:
|
||||
if resp is not None: # 有 return 返回
|
||||
function_calling_result[func_tool_name] = resp
|
||||
else:
|
||||
yield # 有生成器返回
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
if func_tool_name.startswith('mcp:'):
|
||||
_, mcp_server_name, mcp_func_name = func_tool_name.split(':')
|
||||
logger.info(
|
||||
f"从mcp服务 {mcp_server_name} 调用工具函数:{mcp_func_name},参数:{func_tool_args}")
|
||||
|
||||
client = req.func_tool.mcp_client_dict[mcp_server_name]
|
||||
res = await client.session.call_tool(mcp_func_name, func_tool_args)
|
||||
if res:
|
||||
# TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。
|
||||
res_event = event.plain_result(res.content[0].text)
|
||||
event.set_result(res_event)
|
||||
yield
|
||||
else:
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
logger.info(
|
||||
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
|
||||
)
|
||||
|
||||
# 尝试调用工具函数
|
||||
wrapper = self._call_handler(
|
||||
self.ctx, event, func_tool.handler, **func_tool_args
|
||||
)
|
||||
async for resp in wrapper:
|
||||
if resp is not None: # 有 return 返回
|
||||
function_calling_result[func_tool_name] = resp
|
||||
else:
|
||||
yield # 有生成器返回
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
except BaseException as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
function_calling_result[func_tool_name] = (
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
import json
|
||||
import textwrap
|
||||
import os
|
||||
import asyncio
|
||||
from typing import Dict, List, Awaitable
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
from astrbot import logger
|
||||
|
||||
from anthropic import Anthropic
|
||||
|
||||
from ... import logger
|
||||
|
||||
@dataclass
|
||||
class FuncTool:
|
||||
@@ -33,9 +43,44 @@ SUPPORTED_TYPES = [
|
||||
] # json schema 支持的数据类型
|
||||
|
||||
|
||||
class MCPClient:
|
||||
def __init__(self):
|
||||
# Initialize session and client objects
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self.anthropic = Anthropic()
|
||||
|
||||
self.name = None
|
||||
self.active: bool = True
|
||||
|
||||
async def connect_to_server(self, server_script_path: str):
|
||||
"""Connect to an MCP server
|
||||
|
||||
Args:
|
||||
server_script_path: Path to the server script (.py or .js)
|
||||
"""
|
||||
is_python = server_script_path.endswith('.py')
|
||||
is_js = server_script_path.endswith('.js')
|
||||
if not (is_python or is_js):
|
||||
raise ValueError("Server script must be a .py or .js file")
|
||||
|
||||
command = "python" if is_python else "node"
|
||||
server_params = StdioServerParameters(
|
||||
command=command,
|
||||
args=[server_script_path],
|
||||
env=None
|
||||
)
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
|
||||
self.stdio, self.write = stdio_transport
|
||||
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
|
||||
|
||||
await self.session.initialize()
|
||||
|
||||
class FuncCall:
|
||||
def __init__(self) -> None:
|
||||
self.func_list: List[FuncTool] = []
|
||||
self.mcp_client_dict: Dict[str: MCPClient] = dict()
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.func_list) == 0
|
||||
@@ -90,7 +135,43 @@ class FuncCall:
|
||||
return f
|
||||
return None
|
||||
|
||||
def get_func_desc_openai_style(self) -> list:
|
||||
async def init_mcp_client_list(self) -> None:
|
||||
"""
|
||||
从项目根目录读取mcp_server.json。提供mcp_server.json.example作为参考。
|
||||
内容格式为
|
||||
{
|
||||
"mcpServers": {
|
||||
"example_cmp_server": {
|
||||
"script_path": "path/to/cmp/server/script.py"
|
||||
}
|
||||
},
|
||||
...
|
||||
}
|
||||
"""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.abspath(os.path.join(current_dir, "../../.."))
|
||||
|
||||
mcp_json_file = os.path.join(project_root, "mcp_server.json")
|
||||
if not os.path.exists(mcp_json_file):
|
||||
# 配置文件不存在错误处理
|
||||
logger.warning(f"mcp server config file {mcp_json_file} not found. skip init mcp client list.")
|
||||
return
|
||||
|
||||
mcp_server_json_obj = json.load(open(mcp_json_file, "r", encoding="utf-8"))
|
||||
|
||||
for mcp_server_name, mcp_server_script_path in mcp_server_json_obj["mcpServers"].items():
|
||||
if not os.path.exists(mcp_server_script_path["script_path"]):
|
||||
logger.error(f"mcp server import err: Server script {mcp_server_script_path["script_path"]} not found.")
|
||||
continue
|
||||
mcp_client = MCPClient()
|
||||
mcp_client.name = mcp_server_name
|
||||
await mcp_client.connect_to_server(mcp_server_script_path["script_path"])
|
||||
self.mcp_client_dict[mcp_server_name] = mcp_client
|
||||
logger.info(f"添加mcp服务 {mcp_server_name}.")
|
||||
if len(self.mcp_client_dict) == 0:
|
||||
logger.info("未启用任何mcp服务.")
|
||||
|
||||
async def get_func_desc_openai_style(self) -> list:
|
||||
"""
|
||||
获得 OpenAI API 风格的**已经激活**的工具描述
|
||||
"""
|
||||
@@ -105,9 +186,22 @@ class FuncCall:
|
||||
"name": f.name,
|
||||
"parameters": f.parameters,
|
||||
"description": f.description,
|
||||
},
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
for name, client in self.mcp_client_dict.items():
|
||||
responses = await client.session.list_tools()
|
||||
for tool in responses.tools:
|
||||
_l.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f"mcp:{name}:{tool.name}",
|
||||
"parameters": tool.inputSchema,
|
||||
"description": tool.description,
|
||||
}
|
||||
})
|
||||
return _l
|
||||
|
||||
def get_func_desc_anthropic_style(self) -> list:
|
||||
|
||||
@@ -127,6 +127,9 @@ class ProviderManager:
|
||||
if self.tts_enabled and not self.curr_tts_provider_inst:
|
||||
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
|
||||
|
||||
# 初始化mcpclient连接
|
||||
await self.llm_tools.init_mcp_client_list()
|
||||
|
||||
async def load_provider(self, provider_config: dict):
|
||||
if not provider_config["enable"]:
|
||||
return
|
||||
|
||||
@@ -85,7 +85,7 @@ class LLMTunerModelLoader(Provider):
|
||||
"system": system_prompt,
|
||||
}
|
||||
if func_tool:
|
||||
tool_list = func_tool.get_func_desc_openai_style()
|
||||
tool_list = await func_tool.get_func_desc_openai_style()
|
||||
if tool_list:
|
||||
conf["tools"] = tool_list
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
if tools:
|
||||
tool_list = tools.get_func_desc_openai_style()
|
||||
tool_list = await tools.get_func_desc_openai_style()
|
||||
if tool_list:
|
||||
payloads["tools"] = tool_list
|
||||
|
||||
@@ -124,8 +124,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
for tool in tools.func_list:
|
||||
if tool.name == tool_call.function.name:
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
args_ls.append(args)
|
||||
func_name_ls.append(tool_call.function.name)
|
||||
if tool_call.function.name.startswith("mcp:") and tool_call.function.name.split(':')[1] in tools.mcp_client_dict:
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
args_ls.append(args)
|
||||
func_name_ls.append(tool_call.function.name)
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_args = args_ls
|
||||
llm_response.tools_call_name = func_name_ls
|
||||
|
||||
@@ -21,6 +21,9 @@ psutil>=5.8.0
|
||||
lark-oapi
|
||||
ormsgpack
|
||||
cryptography
|
||||
|
||||
mcp~=1.4.1
|
||||
anthropic~=0.49.0
|
||||
dashscope
|
||||
python-telegram-bot
|
||||
wechatpy
|
||||
|
||||
Reference in New Issue
Block a user