Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

複数の出力をもつネットワークの学習

この例では、手書きの数字のラベルと回転角度の両方を予測する、複数の出力をもつ深層学習ネットワークに学習させる方法を説明します。

複数の出力があるネットワークに学習させるには、カスタム学習ループを使用してネットワークに学習させなければなりません。

学習データの読み込み

関数 digitTrain4DArrayData はイメージとその数字ラベル、および垂直方向からの回転角度を読み込みます。イメージ、ラベル、角度について arrayDatastore オブジェクトを作成してから、関数 combine を使用して、すべての学習データを含む単一のデータストアを作成します。クラス名と、離散でない応答の数を抽出します。

[XTrain,T1Train,T2Train] = digitTrain4DArrayData;

dsXTrain = arrayDatastore(XTrain,IterationDimension=4);
dsT1Train = arrayDatastore(T1Train);
dsT2Train = arrayDatastore(T2Train);

dsTrain = combine(dsXTrain,dsT1Train,dsT2Train);

classNames = categories(T1Train);
numClasses = numel(classNames);
numObservations = numel(T1Train);

学習データからの一部のイメージを表示します。

idx = randperm(numObservations,64);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I)

深層学習モデルの定義

ラベルと回転角度の両方を予測する次のネットワークを定義します。

  • 16 個の 5 x 5 フィルターをもつ convolution-batchnorm-ReLU ブロック

  • それぞれ 32 個の 3 x 3 フィルターをもつ 2 つの convolution-batchnorm-ReLU ブロック

  • 32 個の 1 x 1 の畳み込みをもつ convolution-batchnorm-ReLU ブロックを含む、前述の 2 つのブロックのスキップ接続

  • 加算を使用するスキップ接続のマージ

  • 分類出力用に、サイズが 10 (クラス数) の全結合演算とソフトマックス演算をもつ分岐

  • 回帰出力用に、サイズが 1 (応答数) の全結合演算をもつ分岐

層のメイン ブロックを層グラフとして定義します。

layers = [
    imageInputLayer([28 28 1],Normalization="none")

    convolution2dLayer(5,16,Padding="same")
    batchNormalizationLayer
    reluLayer(Name="relu_1")

    convolution2dLayer(3,32,Padding="same",Stride=2)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,32,Padding="same")
    batchNormalizationLayer
    reluLayer

    additionLayer(2,Name="add")

    fullyConnectedLayer(numClasses)
    softmaxLayer(Name="softmax")];

lgraph = layerGraph(layers);

スキップ接続を追加します。

layers = [
    convolution2dLayer(1,32,Stride=2,Name="conv_skip")
    batchNormalizationLayer
    reluLayer(Name="relu_skip")];

lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,"relu_1","conv_skip");
lgraph = connectLayers(lgraph,"relu_skip","add/in2");

回帰用に全結合層を追加します。

layers = fullyConnectedLayer(1,Name="fc_2");
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,"add","fc_2");

層グラフをプロットで表示します。

figure
plot(lgraph)

層グラフから dlnetwork オブジェクトを作成します。

net = dlnetwork(lgraph)
net = 
  dlnetwork with properties:

         Layers: [17×1 nnet.cnn.layer.Layer]
    Connections: [17×2 table]
     Learnables: [20×3 table]
          State: [8×3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'  'fc_2'}
    Initialized: 1

  View summary with summary.

モデル損失関数の定義

例の最後にリストされている関数 modelLoss を作成します。この関数は、dlnetwork オブジェクト、ならびに入力データのミニバッチとそれに対応するターゲット (ラベルと角度を含む) を入力として受け取り、学習可能なパラメーターについての損失と損失の勾配、および更新されたネットワークの状態を返します。

学習オプションの指定

学習オプションを指定します。ミニバッチ サイズを 128 として、学習を 30 エポック行います。

numEpochs = 30;
miniBatchSize = 128;

モデルの学習

minibatchqueueを使用して、イメージのミニバッチを処理および管理します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatch (この例の最後に定義) を使用して、クラス ラベルを one-hot 符号化します。

  • イメージ データを次元ラベル "SSCB" (spatial、spatial、channel、batch) で書式設定します。既定では、minibatchqueue オブジェクトは、基となる型が singledlarray オブジェクトにデータを変換します。書式をクラス ラベルまたは角度に追加しないでください。

  • GPU が利用できる場合、GPU で学習を行います。既定では、minibatchqueue オブジェクトは、GPU が利用可能な場合、各出力を gpuArray に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。

mbq = minibatchqueue(dsTrain,...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@preprocessData,...
    MiniBatchFormat=["SSCB" "" ""]);

カスタム学習ループを使用してモデルに学習させます。各エポックについて、データをシャッフルしてデータのミニバッチをループで回します。反復が終了するたびに、学習の進行状況を表示します。各ミニバッチで次を行います。

  • 関数 dlfeval および modelLoss を使用してモデルの損失と勾配を評価します。

  • 関数 adamupdate を使用してネットワーク パラメーターを更新します。

Adam 用にパラメーターを初期化します。

trailingAvg = [];
trailingAvgSq = [];

学習の進行状況モニター用に合計反復回数を計算します。

numIterationsPerEpoch = ceil(numObservations / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;

TrainingProgressMonitor オブジェクトを初期化します。monitor オブジェクトを作成するとタイマーが開始されるため、学習ループに近いところでオブジェクトを作成するようにしてください。

monitor = trainingProgressMonitor( ...
    Metrics="Loss", ...
    Info="Epoch", ...
    XLabel="Iteration");

モデルに学習させます。

epoch = 0;
iteration = 0;

while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;
   
    % Shuffle data.
    shuffle(mbq)

    % Loop over mini-batches.
    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;

        [X,T1,T2] = next(mbq);

        % Evaluate the model loss, gradients, and state using 
        % the dlfeval and modelLoss functions.
        [loss,gradients,state] = dlfeval(@modelLoss,net,X,T1,T2);
        net.State = state;

        % Update the network parameters using the Adam optimizer.
        [net,trailingAvg,trailingAvgSq] = adamupdate(net,gradients, ...
            trailingAvg,trailingAvgSq,iteration);

        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch + " of " + numEpochs);
        monitor.Progress = 100*iteration/numIterations;
    end
end

モデルのテスト

真のラベルと角度をもつテスト セットで予測を比較して、モデルの分類精度をテストします。学習データと同じ設定の minibatchqueue オブジェクトを使用して、テスト データ セットを管理します。

[XTest,T1Test,T2Test] = digitTest4DArrayData;

dsXTest = arrayDatastore(XTest,IterationDimension=4);
dsT1Test = arrayDatastore(T1Test);
dsT2Test = arrayDatastore(T2Test);

dsTest = combine(dsXTest,dsT1Test,dsT2Test);

mbqTest = minibatchqueue(dsTest,...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@preprocessData,...
    MiniBatchFormat=["SSCB" "" ""]);

検証データのラベルと角度を予測するために、ミニバッチをループ処理し、関数 predict を使用します。予測されたクラスと角度を保存します。予測されたクラスおよび角度を真のクラスおよび角度と比較し、その結果を保存します。

classesPredictions = [];
anglesPredictions = [];
classCorr = [];
angleDiff = [];

% Loop over mini-batches.
while hasdata(mbqTest)

    % Read mini-batch of data.
    [X,T1,T2] = next(mbqTest);

    % Make predictions using the predict function.
    [Y1,Y2] = predict(net,X,Outputs=["softmax" "fc_2"]);

    % Determine predicted classes.
    Y1 = onehotdecode(Y1,classNames,1);
    classesPredictions = [classesPredictions Y1];

    % Dermine predicted angles
    Y2 = extractdata(Y2);
    anglesPredictions = [anglesPredictions Y2];

    % Compare predicted and true classes
    T1 = onehotdecode(T1,classNames,1);
    classCorr = [classCorr Y1 == T1];

    % Compare predicted and true angles
    angleDiffBatch = Y2 - T2;
    angleDiffBatch = extractdata(gather(angleDiffBatch));
    angleDiff = [angleDiff angleDiffBatch];
end

分類精度を評価します。

accuracy = mean(classCorr)
accuracy = 0.9882

回帰精度を評価します。

angleRMSE = sqrt(mean(angleDiff.^2))
angleRMSE = single
    6.3569

一部のイメージと、その予測を表示します。予測角度を赤、正解ラベルを緑で表示します。

idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on

    sz = size(I,1);
    offset = sz/2;

    thetaPred = anglesPredictions(idx(i));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],"r--")

    thetaValidation = T2Test(idx(i));
    plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],"g--")

    hold off
    label = string(classesPredictions(idx(i)));
    title("Label: " + label)
end

モデル損失関数

関数 modelLoss は、dlnetwork オブジェクト net、入力データのミニバッチ X と対応するターゲット T1 および T2 (それぞれラベルと角度を含む) を入力として受け取り、学習可能なパラメーターについての損失と損失の勾配、および更新されたネットワークの状態を返します。

function [loss,gradients,state] = modelLoss(net,X,T1,T2)

[Y1,Y2,state] = forward(net,X,Outputs=["softmax" "fc_2"]);

lossLabels = crossentropy(Y1,T1);
lossAngles = mse(Y2,T2);

loss = lossLabels + 0.1*lossAngles;
gradients = dlgradient(loss,net.Learnables);

end

ミニバッチ前処理関数

関数 preprocessMiniBatch は、次の手順でデータを前処理します。

  1. 入力 cell 配列からイメージ データを抽出して数値配列に連結します。4 番目の次元でイメージ データを連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されます。

  2. 入力 cell 配列からラベルと角度データを抽出して、それを 2 番目の次元と共に、categorical 配列および数値配列にそれぞれ連結します。

  3. カテゴリカル ラベルを数値配列に one-hot 符号化します。最初の次元への符号化は、ネットワーク出力の形状と一致する符号化された配列を生成します。

function [X,T1,T2] = preprocessData(dataX,dataT1,dataT2)

% Extract image data from cell and concatenate
X = cat(4,dataX{:});

% Extract label data from cell and concatenate
T1 = cat(2,dataT1{:});

% Extract angle data from cell and concatenate
T2 = cat(2,dataT2{:});

% One-hot encode labels
T1 = onehotencode(T1,1);

end

参考

| | | | | | | | | | |

関連するトピック