🎈 auto fixes by pre-commit hooks
This commit is contained in:
@@ -176,13 +176,18 @@ class LLMRequestSubStage(Stage):
|
||||
llm_response.tools_call_name, llm_response.tools_call_args
|
||||
):
|
||||
try:
|
||||
if func_tool_name.startswith('mcp:'):
|
||||
_, mcp_server_name, mcp_func_name = func_tool_name.split(':')
|
||||
if func_tool_name.startswith("mcp:"):
|
||||
_, mcp_server_name, mcp_func_name = func_tool_name.split(
|
||||
":"
|
||||
)
|
||||
logger.info(
|
||||
f"从mcp服务 {mcp_server_name} 调用工具函数:{mcp_func_name},参数:{func_tool_args}")
|
||||
f"从mcp服务 {mcp_server_name} 调用工具函数:{mcp_func_name},参数:{func_tool_args}"
|
||||
)
|
||||
|
||||
client = req.func_tool.mcp_client_dict[mcp_server_name]
|
||||
res = await client.session.call_tool(mcp_func_name, func_tool_args)
|
||||
res = await client.session.call_tool(
|
||||
mcp_func_name, func_tool_args
|
||||
)
|
||||
if res:
|
||||
# TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。
|
||||
res_event = event.plain_result(res.content[0].text)
|
||||
|
||||
@@ -15,6 +15,7 @@ from anthropic import Anthropic
|
||||
|
||||
from ... import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class FuncTool:
|
||||
"""
|
||||
@@ -59,28 +60,31 @@ class MCPClient:
|
||||
Args:
|
||||
server_script_path: Path to the server script (.py or .js)
|
||||
"""
|
||||
is_python = server_script_path.endswith('.py')
|
||||
is_js = server_script_path.endswith('.js')
|
||||
is_python = server_script_path.endswith(".py")
|
||||
is_js = server_script_path.endswith(".js")
|
||||
if not (is_python or is_js):
|
||||
raise ValueError("Server script must be a .py or .js file")
|
||||
|
||||
command = "python" if is_python else "node"
|
||||
server_params = StdioServerParameters(
|
||||
command=command,
|
||||
args=[server_script_path],
|
||||
env=None
|
||||
command=command, args=[server_script_path], env=None
|
||||
)
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
stdio_client(server_params)
|
||||
)
|
||||
self.stdio, self.write = stdio_transport
|
||||
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
ClientSession(self.stdio, self.write)
|
||||
)
|
||||
|
||||
await self.session.initialize()
|
||||
|
||||
|
||||
class FuncCall:
|
||||
def __init__(self) -> None:
|
||||
self.func_list: List[FuncTool] = []
|
||||
self.mcp_client_dict: Dict[str: MCPClient] = dict()
|
||||
self.mcp_client_dict: Dict[str:MCPClient] = dict()
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.func_list) == 0
|
||||
@@ -154,14 +158,20 @@ class FuncCall:
|
||||
mcp_json_file = os.path.join(project_root, "mcp_server.json")
|
||||
if not os.path.exists(mcp_json_file):
|
||||
# 配置文件不存在错误处理
|
||||
logger.warning(f"mcp server config file {mcp_json_file} not found. skip init mcp client list.")
|
||||
logger.warning(
|
||||
f"mcp server config file {mcp_json_file} not found. skip init mcp client list."
|
||||
)
|
||||
return
|
||||
|
||||
mcp_server_json_obj = json.load(open(mcp_json_file, "r", encoding="utf-8"))
|
||||
|
||||
for mcp_server_name, mcp_server_script_path in mcp_server_json_obj["mcpServers"].items():
|
||||
for mcp_server_name, mcp_server_script_path in mcp_server_json_obj[
|
||||
"mcpServers"
|
||||
].items():
|
||||
if not os.path.exists(mcp_server_script_path["script_path"]):
|
||||
logger.error(f"mcp server import err: Server script {mcp_server_script_path["script_path"]} not found.")
|
||||
logger.error(
|
||||
f"mcp server import err: Server script {mcp_server_script_path['script_path']} not found."
|
||||
)
|
||||
continue
|
||||
mcp_client = MCPClient()
|
||||
mcp_client.name = mcp_server_name
|
||||
@@ -186,7 +196,7 @@ class FuncCall:
|
||||
"name": f.name,
|
||||
"parameters": f.parameters,
|
||||
"description": f.description,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -194,14 +204,16 @@ class FuncCall:
|
||||
for name, client in self.mcp_client_dict.items():
|
||||
responses = await client.session.list_tools()
|
||||
for tool in responses.tools:
|
||||
_l.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f"mcp:{name}:{tool.name}",
|
||||
"parameters": tool.inputSchema,
|
||||
"description": tool.description,
|
||||
_l.append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f"mcp:{name}:{tool.name}",
|
||||
"parameters": tool.inputSchema,
|
||||
"description": tool.description,
|
||||
},
|
||||
}
|
||||
})
|
||||
)
|
||||
return _l
|
||||
|
||||
def get_func_desc_anthropic_style(self) -> list:
|
||||
|
||||
@@ -124,7 +124,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
for tool in tools.func_list:
|
||||
if tool.name == tool_call.function.name:
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
if tool_call.function.name.startswith("mcp:") and tool_call.function.name.split(':')[1] in tools.mcp_client_dict:
|
||||
if (
|
||||
tool_call.function.name.startswith("mcp:")
|
||||
and tool_call.function.name.split(":")[1] in tools.mcp_client_dict
|
||||
):
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
args_ls.append(args)
|
||||
func_name_ls.append(tool_call.function.name)
|
||||
|
||||
Reference in New Issue
Block a user