
Train Network with Multiple Outputs

This example shows how to train a deep learning network with multiple outputs that predicts both the labels and the angles of rotation of handwritten digits.

Load Training Data

Load the training data. The data contains images of digits, together with the digit labels and the angles of rotation from the vertical.

load DigitsDataTrain

Create an arrayDatastore object for the images, labels, and the angles, and then use the combine function to make a single datastore that contains all of the training data.

dsXTrain = arrayDatastore(XTrain,IterationDimension=4);
dsT1Train = arrayDatastore(labelsTrain);
dsT2Train = arrayDatastore(anglesTrain);

dsTrain = combine(dsXTrain,dsT1Train,dsT2Train);

classNames = categories(labelsTrain);
numClasses = numel(classNames);
numObservations = numel(labelsTrain);
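
To check that the combined datastore returns data in the expected format, you can preview it. Each read from the combined datastore returns a 1-by-3 cell array containing an image, its label, and its angle of rotation. This check is an optional addition to the example.

% Preview the first observation of the combined datastore.
preview(dsTrain)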

View some images from the training data.

idx = randperm(numObservations,64);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I)

Define Deep Learning Model

Define the following network that predicts both labels and angles of rotation.

  • A convolution-batchnorm-ReLU block with 16 5-by-5 filters.

  • Two convolution-batchnorm-ReLU blocks, each with 32 3-by-3 filters. The first block uses a stride of 2.

  • A skip connection around the previous two blocks, consisting of a convolution-batchnorm-ReLU block with 32 1-by-1 filters and a stride of 2.

  • An addition layer that merges the skip connection with the main branch.

  • For the classification output, a branch with a fully connected operation of size 10 (the number of classes) and a softmax operation.

  • For the regression output, a branch with a fully connected operation of size 1 (the number of responses).

Define the main block of layers.

net = dlnetwork;

layers = [
    imageInputLayer([28 28 1],Normalization="none")

    convolution2dLayer(5,16,Padding="same")
    batchNormalizationLayer
    reluLayer(Name="relu_1")

    convolution2dLayer(3,32,Padding="same",Stride=2)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,32,Padding="same")
    batchNormalizationLayer
    reluLayer

    additionLayer(2,Name="add")

    fullyConnectedLayer(numClasses)
    softmaxLayer(Name="softmax")];

net = addLayers(net,layers);

Add the skip connection.

layers = [
    convolution2dLayer(1,32,Stride=2,Name="conv_skip")
    batchNormalizationLayer
    reluLayer(Name="relu_skip")];

net = addLayers(net,layers);
net = connectLayers(net,"relu_1","conv_skip");
net = connectLayers(net,"relu_skip","add/in2");

Add the fully connected layer for regression.

layers = fullyConnectedLayer(1,Name="fc_2");
net = addLayers(net,layers);
net = connectLayers(net,"add","fc_2");

View the layer graph in a plot.

figure
plot(net)
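
The softmax layer and the fully connected layer for regression both have unconnected outputs, so the dlnetwork object treats them as network outputs. As an optional check, inspect the OutputNames property of the network to confirm that it has two outputs.

% Confirm that the network outputs are the softmax (classification)
% and fc_2 (regression) layers.
net.OutputNames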

Specify Training Options

Specify the training options. Choosing among the options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager app.

options = trainingOptions("adam", ...
    Plots="training-progress", ...
    Verbose=false);
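
This example uses the default values for most training options. To experiment, you can also specify options such as the number of epochs, the mini-batch size, and the initial learning rate. The values below are illustrative, not tuned for this task.

% Example of additional training options (illustrative values).
options = trainingOptions("adam", ...
    MaxEpochs=15, ...
    MiniBatchSize=128, ...
    InitialLearnRate=0.001, ...
    Plots="training-progress", ...
    Verbose=false);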

Train Neural Network

Train the neural network using the trainnet function. Use a custom loss function that combines the cross-entropy loss for the classification output with the mean squared error for the regression output. By default, the trainnet function uses a GPU if one is available. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the function uses the CPU. To specify the execution environment, use the ExecutionEnvironment training option.

Define the custom loss function as a function handle. The function takes the predicted scores Y1, the predicted angles Y2, the target labels T1, and the target angles T2, and returns the cross-entropy loss between the predicted and target labels plus 0.1 times the mean squared error between the predicted and target angles.

lossFcn = @(Y1,Y2,T1,T2) crossentropy(Y1,T1) + 0.1*mse(Y2,T2);
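
Equivalently, you can define the loss as a named function, which can be easier to read, debug, and extend. The predictions arrive in the order of the network output names (softmax, then fc_2), followed by the targets in the order they appear in the datastore. The name modelLoss below is a sketch of this pattern and is not part of the original example; in a script, place the function definition at the end of the file and pass it to trainnet as lossFcn = @modelLoss.

function loss = modelLoss(Y1,Y2,T1,T2)
    % Y1 - predicted scores, T1 - target labels (classification)
    % Y2 - predicted angles, T2 - target angles (regression)
    loss = crossentropy(Y1,T1) + 0.1*mse(Y2,T2);
end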

Train the neural network.

net = trainnet(dsTrain,net,lossFcn,options);

Test Model

Load the test data. The data contains images of digits, together with the digit labels and the angles of rotation from the vertical. Create a combined datastore using the same steps as for the training data.

load DigitsDataTest
dsXTest = arrayDatastore(XTest,IterationDimension=4);
dsT1Test = arrayDatastore(labelsTest);
dsT2Test = arrayDatastore(anglesTest);

dsTest = combine(dsXTest,dsT1Test,dsT2Test);

Make predictions using the minibatchpredict function. By default, the minibatchpredict function uses a GPU if one is available. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the function uses the CPU. To specify the execution environment, use the ExecutionEnvironment option.

[scores,Y2] = minibatchpredict(net,dsTest);
Y1 = scores2label(scores,classNames);

Calculate the classification accuracy. The accuracy is the proportion of labels that the network predicts correctly.

accuracy = mean(Y1 == labelsTest)
accuracy = 0.9838
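
To see which digits the network most often confuses, you can plot a confusion chart of the target and predicted labels. This visualization is an optional addition to the example.

% Compare target labels with predicted labels.
figure
confusionchart(labelsTest,Y1)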

Calculate the root mean square error between the predicted and target angles.

err = rmse(Y2,anglesTest)
err = single
    7.5994
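
RMSE weights large errors more heavily than small ones. As an optional complementary metric, you can compute the mean absolute error between the predicted and target angles.

% Mean absolute angular error, in degrees.
errMAE = mean(abs(Y2 - anglesTest))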

View some of the images with their predictions. Display the predicted angles in red and the correct angles in green.

idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on

    sz = size(I,1);
    offset = sz/2;

    theta = Y2(idx(i));
    plot(offset*[1-tand(theta) 1+tand(theta)],[sz 0],"r--")

    thetaTest = anglesTest(idx(i));
    plot(offset*[1-tand(thetaTest) 1+tand(thetaTest)],[sz 0],"g--")

    hold off
    label = Y1(idx(i));
    title("Label: " + string(label))
end
