Train Network in the Cloud Using Built-in Parallel Support

This example shows how to train a convolutional neural network on the CIFAR-10 data set using MATLAB's built-in support for parallel training. Deep learning training often takes hours or days. You can use parallel computing to speed up your training using multiple GPUs locally or in a cluster in the cloud. If you have access to a machine with multiple GPUs, then you can run this script on a local copy of the data set after setting the ExecutionEnvironment value to 'multi-gpu' in your training options. If you want to use more resources, then you can scale up deep learning training to the cloud. This example guides you through the steps to train a deep neural network in a cloud cluster using MATLAB's built-in parallel support.
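If you train locally on a multi-GPU machine instead, the relevant change is in the training options. The following lines are a minimal sketch, assuming default values for the remaining options; the variable name optionsMultiGPU and the mini-batch size shown are only placeholders that you should adapt to your hardware.

optionsMultiGPU = trainingOptions('sgdm', ...
    'ExecutionEnvironment','multi-gpu', ... % Use all available GPUs on the local machine.
    'MiniBatchSize',256);                   % Placeholder value, adjust to your GPUs.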

Before you can run this example, you need to configure a cluster and upload your data to the cloud. To get started with the cloud, set up Cloud Center, link it to an Amazon Web Services (AWS) account, and create a cluster. For instructions, see Getting Started with Cloud Center. After that, upload your data to an Amazon S3 bucket and use it directly from MATLAB. For instructions, see Upload Deep Learning Data to the Cloud.
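As a minimal sketch of the credential setup (the values below are placeholders, and the full procedure is described in Upload Deep Learning Data to the Cloud), you can specify your AWS credentials as environment variables in the MATLAB client session:

setenv('AWS_ACCESS_KEY_ID','YOUR_AWS_ACCESS_KEY_ID');         % Placeholder, use your own access key.
setenv('AWS_SECRET_ACCESS_KEY','YOUR_AWS_SECRET_ACCESS_KEY'); % Placeholder, use your own secret key.
setenv('AWS_DEFAULT_REGION','us-east-1');                     % Example region, use the region of your bucket.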

Set Up a Parallel Pool

Start a parallel pool in the cluster and set the number of workers to the number of GPUs in your cluster. If you specify more workers than GPUs, then the remaining workers are idle. This example assumes that the cluster you want to use is set as the default in your cluster profiles.

numberOfWorkers = 8;
parpool(numberOfWorkers);
Starting parallel pool (parpool) using the 'MyClusterInTheCloudAWS' profile ...
connected to 8 workers.
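If you are unsure how many GPUs the cluster workers can access, a quick check, sketched here, is to query gpuDeviceCount on each worker in the pool:

spmd
    gpuCount = gpuDeviceCount; % Number of GPUs visible to this worker.
end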

Load the Data Set from the Cloud

Load the training and test data sets from the cloud using imageDatastore. This example uses a copy of the CIFAR-10 data that is already stored in Amazon S3. To ensure that the workers have access to the datastore in the cloud, make sure that the environment variables for the AWS credentials have been set correctly. For instructions, see Upload Deep Learning Data to the Cloud.

imdsTrain = imageDatastore('s3://cifar10cloud/cifar10/train', ...
 'IncludeSubfolders',true, ...
 'LabelSource','foldernames');

imdsTest = imageDatastore('s3://cifar10cloud/cifar10/test', ...
 'IncludeSubfolders',true, ...
 'LabelSource','foldernames');

To train the network with augmented image data, create an augmentedImageDatastore object. Use random translations and horizontal reflections. Data augmentation helps prevent the network from overfitting and memorizing the exact details of the training images.

imageSize = [32 32 3];
pixelRange = [-4 4];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augmentedImdsTrain = augmentedImageDatastore(imageSize,imdsTrain, ...
    'DataAugmentation',imageAugmenter, ...
    'OutputSizeMode','randcrop');
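To verify the augmentation settings, you can optionally inspect a sample of augmented images. This sketch assumes that montage is available (Image Processing Toolbox); preview returns a table whose input column contains the augmented images.

sampleBatch = preview(augmentedImdsTrain);
montage(sampleBatch.input); % Display a sample of augmented training images.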

Define Network Architecture and Training Options

Define a network architecture for the CIFAR-10 data set. To simplify the code, use the convolutionalBlock helper function (defined at the end of this example), which creates repeating blocks of convolutional layers. The pooling layers downsample the spatial dimensions.

netDepth = 4; % netDepth controls the depth of the convolutional blocks
netWidth = 32; % netWidth controls the number of filters in a convolutional block

layers = [
    imageInputLayer(imageSize)

    convolutionalBlock(netWidth,netDepth)
    maxPooling2dLayer(2,'Stride',2)
    convolutionalBlock(2*netWidth,netDepth)
    maxPooling2dLayer(2,'Stride',2)
    convolutionalBlock(4*netWidth,netDepth)
    averagePooling2dLayer(8)

    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer
];
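Before training, you can optionally check the architecture for errors and view the activation sizes of each layer. A one-line sketch using the Deep Learning Toolbox network analyzer:

analyzeNetwork(layers) % Optional: inspect and validate the layer array before training.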

Define the training options. To train the network in parallel using the current cluster, set the execution environment to parallel. Scale the learning rate according to the mini-batch size. Use a learning rate schedule to drop the learning rate as the training progresses. Turn on the training progress plot to obtain visual feedback during training.

miniBatchSize = 64 * numberOfWorkers;
initialLearnRate = 1e-1 * miniBatchSize/256;

options = trainingOptions('sgdm', ...
    'ExecutionEnvironment','parallel', ... % Turn on built-in parallel support.
    'InitialLearnRate',initialLearnRate, ... % Set the initial learning rate.
    'MiniBatchSize',miniBatchSize, ... % Set the mini-batch size.
    'Verbose',false, ... % Suppress command-line output.
    'Plots','training-progress', ... % Turn on the training progress plot.
    'L2Regularization',1e-10, ...
    'MaxEpochs',30, ...
    'Shuffle','every-epoch', ...
    'ValidationData',imdsTest, ...
    'ValidationFrequency',floor(numel(imdsTrain.Files)/miniBatchSize), ...
    'ValidationPatience',Inf, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropFactor',0.1, ...
    'LearnRateDropPeriod',25);

Train the Network and Use It for Classification

Train the network in the cluster. During training, the plot displays the progress.

net = trainNetwork(augmentedImdsTrain,layers,options)
net = 

  SeriesNetwork with properties:

    Layers: [43×1 nnet.cnn.layer.Layer]

To obtain the accuracy of this network, use the trained network to classify the test images on your local machine, then compare the predicted labels to the actual labels.

YPredicted = classify(net,imdsTest);
accuracy = sum(YPredicted == imdsTest.Labels)/numel(imdsTest.Labels)
accuracy =

    0.9036
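To see which classes the network confuses with each other, you can also plot a confusion chart. This is a sketch assuming confusionchart is available in your release:

confusionchart(imdsTest.Labels,YPredicted); % Per-class breakdown of the test-set predictions.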

Define Helper Functions

Define a function to simplify the creation of the convolutional blocks in the network architecture.

function layers = convolutionalBlock(numFilters,numConvLayers)
    % Create numConvLayers repetitions of a 3-by-3 convolutional layer with
    % numFilters filters, each followed by batch normalization and ReLU.
    layers = [
        convolution2dLayer(3,numFilters,'Padding','same')
        batchNormalizationLayer
        reluLayer
    ];

    layers = repmat(layers,numConvLayers,1);
end
