Main Content

複数の出力をもつネットワークの学習

この例では、手書きの数字のラベルと回転角度の両方を予測する、複数の出力をもつ深層学習ネットワークに学習させる方法を説明します。

学習データの読み込み

数字のデータを読み込みます。データには、数字のイメージ、数字のラベル、垂直方向からの回転角度が含まれています。

load DigitsDataTrain

イメージ、ラベル、角度について arrayDatastore オブジェクトを作成してから、関数 combine を使用して、すべての学習データを含む単一のデータストアを作成します。

dsXTrain = arrayDatastore(XTrain,IterationDimension=4);
dsT1Train = arrayDatastore(labelsTrain);
dsT2Train = arrayDatastore(anglesTrain);

dsTrain = combine(dsXTrain,dsT1Train,dsT2Train);

classNames = categories(labelsTrain);
numClasses = numel(classNames);
numObservations = numel(labelsTrain);

学習データからの一部のイメージを表示します。

idx = randperm(numObservations,64);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I)

深層学習モデルの定義

ラベルと回転角度の両方を予測する次のネットワークを定義します。

  • 16 個の 5 x 5 フィルターをもつ convolution-batchnorm-ReLU ブロック

  • それぞれ 32 個の 3 x 3 フィルターをもつ 2 つの convolution-batchnorm-ReLU ブロック

  • 32 個の 1 x 1 の畳み込みをもつ convolution-batchnorm-ReLU ブロックを含む、前述の 2 つのブロックのスキップ接続

  • 加算を使用するスキップ接続のマージ

  • 分類出力用に、サイズが 10 (クラス数) の全結合演算とソフトマックス演算をもつ分岐

  • 回帰出力用に、サイズが 1 (応答数) の全結合演算をもつ分岐

層のメイン ブロックを定義します。

net = dlnetwork;

layers = [
    imageInputLayer([28 28 1],Normalization="none")

    convolution2dLayer(5,16,Padding="same")
    batchNormalizationLayer
    reluLayer(Name="relu_1")

    convolution2dLayer(3,32,Padding="same",Stride=2)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,32,Padding="same")
    batchNormalizationLayer
    reluLayer

    additionLayer(2,Name="add")

    fullyConnectedLayer(numClasses)
    softmaxLayer(Name="softmax")];

net = addLayers(net,layers);

スキップ接続を追加します。

layers = [
    convolution2dLayer(1,32,Stride=2,Name="conv_skip")
    batchNormalizationLayer
    reluLayer(Name="relu_skip")];

net = addLayers(net,layers);
net = connectLayers(net,"relu_1","conv_skip");
net = connectLayers(net,"relu_skip","add/in2");

回帰用に全結合層を追加します。

layers = fullyConnectedLayer(1,Name="fc_2");
net = addLayers(net,layers);
net = connectLayers(net,"add","fc_2");

層グラフをプロットで表示します。

figure
plot(net)

学習オプションの指定

学習オプションを指定します。オプションの中から選択するには、経験的解析が必要です。実験を実行してさまざまな学習オプションの構成を調べるには、Experiment Managerアプリを使用できます。

options = trainingOptions("adam", ...
    Plots="training-progress", ...
    Verbose=false);

ニューラル ネットワークの学習

関数trainnetを使用してニューラル ネットワークに学習させます。分類には、予測ラベルとターゲット ラベルのクロスエントロピー損失に予測角度とターゲット角度の平均二乗誤差損失の 0.1 倍を加えるカスタム損失関数を使用します。既定では、関数 trainnet は利用可能な GPU がある場合にそれを使用します。GPU を使用するには、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数は CPU を使用します。実行環境を指定するには、ExecutionEnvironment 学習オプションを使用します。

カスタム損失関数を関数ハンドルとして定義します。損失を、予測ラベルとターゲット ラベルのクロスエントロピー損失に予測角度とターゲット角度の平均二乗誤差の 0.1 倍を加えた値として定義します。

lossFcn = @(Y1,Y2,T1,T2) crossentropy(Y1,T1) + 0.1*mse(Y2,T2);

ニューラル ネットワークを学習させます。

net = trainnet(dsTrain,net,lossFcn,options);

モデルのテスト

数字のデータを読み込みます。データには、数字のイメージ、数字のラベル、垂直方向からの回転角度が含まれています。

load DigitsDataTest
dsXTest = arrayDatastore(XTest,IterationDimension=4);
dsT1Test = arrayDatastore(labelsTest);
dsT2Test = arrayDatastore(anglesTest);

dsTest = combine(dsXTest,dsT1Test,dsT2Test);

関数minibatchpredictを使用して予測を行います。既定では、関数 minibatchpredict は利用可能な GPU がある場合にそれを使用します。GPU を使用するには、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数は CPU を使用します。実行環境を指定するには、ExecutionEnvironment オプションを使用します。

[scores,Y2] = minibatchpredict(net,dsTest);
Y1 = scores2label(scores,classNames);

ラベルの分類精度を計算します。

accuracy = mean(Y1 == labelsTest)
accuracy = 0.9838

予測角度とターゲット角度の間の平方根平均二乗誤差を計算します。

err = rmse(Y2,anglesTest)
err = single
    7.5994

一部のイメージと、その予測を表示します。予測角度を赤、正解ラベルを緑で表示します。

idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on

    sz = size(I,1);
    offset = sz/2;

    theta = Y2(idx(i));
    plot(offset*[1-tand(theta) 1+tand(theta)],[sz 0],"r--")

    thetaTest = T2Test(idx(i));
    plot(offset*[1-tand(thetaTest) 1+tand(thetaTest)],[sz 0],"g--")

    hold off
    label = Y1(idx(i));
    title("Label: " + string(label))
end

参考

| | | | | | | | | | |

関連するトピック