Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

複数の出力をもつネットワークの学習

この例では、手書きの数字のラベルと回転角度の両方を予測する、複数の出力をもつ深層学習ネットワークに学習させる方法を説明します。

複数の出力があるネットワークに学習させるには、カスタム学習ループを使用してネットワークに学習させなければなりません。

学習データの読み込み

関数 digitTrain4DArrayData はイメージとその数字ラベル、および垂直方向からの回転角度を読み込みます。イメージ、ラベル、角度について arrayDatastore オブジェクトを作成してから、関数 combine を使用して、すべての学習データを含む単一のデータストアを作成します。クラス名と、離散でない応答の数を抽出します。

[XTrain,YTrain,anglesTrain] = digitTrain4DArrayData;

dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsYTrain = arrayDatastore(YTrain);
dsAnglesTrain = arrayDatastore(anglesTrain);

dsTrain = combine(dsXTrain,dsYTrain,dsAnglesTrain);

classNames = categories(YTrain);
numClasses = numel(classNames);
numObservations = numel(YTrain);

学習データからの一部のイメージを表示します。

idx = randperm(numObservations,64);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I)

深層学習モデルの定義

ラベルと回転角度の両方を予測する次のネットワークを定義します。

  • 16 個の 5 x 5 フィルターをもつ convolution-batchnorm-ReLU ブロック

  • それぞれ 32 個の 3 x 3 フィルターをもつ 2 つの convolution-batchnorm-ReLU ブロック

  • 32 個の 1 x 1 の畳み込みをもつ convolution-batchnorm-ReLU ブロックを含む、前述の 2 つのブロックのスキップ接続

  • 加算を使用するスキップ接続のマージ

  • 分類出力用に、サイズが 10 (クラス数) の全結合演算とソフトマックス演算をもつ分岐

  • 回帰出力用に、サイズが 1 (応答数) の全結合演算をもつ分岐

層のメイン ブロックを層グラフとして定義します。

layers = [
    imageInputLayer([28 28 1],'Normalization','none','Name','in')
    
    convolution2dLayer(5,16,'Padding','same','Name','conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')
    
    convolution2dLayer(3,32,'Padding','same','Stride',2,'Name','conv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,32,'Padding','same','Name','conv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name','relu4')

    additionLayer(2,'Name','addition')
    
    fullyConnectedLayer(numClasses,'Name','fc1')
    softmaxLayer('Name','softmax')];

lgraph = layerGraph(layers);

スキップ接続を追加します。

layers = [
    convolution2dLayer(1,32,'Stride',2,'Name','convSkip')
    batchNormalizationLayer('Name','bnSkip')
    reluLayer('Name','reluSkip')];

lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'relu1','convSkip');
lgraph = connectLayers(lgraph,'reluSkip','addition/in2');

回帰用に全結合層を追加します。

layers = fullyConnectedLayer(1,'Name','fc2');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'addition','fc2');

層グラフをプロットで表示します。

figure
plot(lgraph)

層グラフから dlnetwork オブジェクトを作成します。

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [17×1 nnet.cnn.layer.Layer]
    Connections: [17×2 table]
     Learnables: [20×3 table]
          State: [8×3 table]
     InputNames: {'in'}
    OutputNames: {'softmax'  'fc2'}

モデル勾配関数の定義

例の最後にリストされている関数 modelGradients を作成します。この関数は、dlnetwork オブジェクト dlnet、入力データ dlX のミニバッチと対応するターゲット T1 および T2 (それぞれラベルと角度を含む) を入力として受け取り、学習可能なパラメーターについての損失の勾配、更新されたネットワーク状態、および対応する損失を返します。

学習オプションの指定

学習オプションを指定します。ミニバッチ サイズを 128 として、学習を 30 エポック行います。

numEpochs = 30;
miniBatchSize = 128;

学習の進行状況をプロットに可視化します。

plots = "training-progress";

モデルの学習

minibatchqueueを使用して、イメージのミニバッチを処理および管理します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatch (この例の最後に定義) を使用して、クラス ラベルを one-hot 符号化します。

  • イメージ データを次元ラベル 'SSCB' (spatial、spatial、channel、batch) で書式設定します。既定では、minibatchqueue オブジェクトは、基となる型が singledlarray オブジェクトにデータを変換します。書式をクラス ラベルまたは角度に追加しないでください。

  • GPU が利用できる場合、GPU で学習を行います。既定では、minibatchqueue オブジェクトは、GPU が利用可能な場合、各出力を gpuArray に変換します。GPU を使用するには、Parallel Computing Toolbox™、および Compute Capability 3.0 以上の CUDA® 対応 NVIDIA® GPU が必要です。

mbq = minibatchqueue(dsTrain,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn', @preprocessData,...
    'MiniBatchFormat',{'SSCB','',''});

カスタム学習ループを使用してモデルに学習させます。各エポックについて、データをシャッフルしてデータのミニバッチをループで回します。反復が終了するたびに、学習の進行状況を表示します。各ミニバッチで次を行います。

  • 関数 dlfeval および modelGradients を使用してモデルの勾配と損失を評価。

  • 関数 adamupdate を使用してネットワーク パラメーターを更新。

学習の進行状況プロットを初期化します。

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Adam 用にパラメーターを初期化します。

trailingAvg = [];
trailingAvgSq = [];

モデルに学習させます。

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Shuffle data.
    shuffle(mbq)
    
    % Loop over mini-batches
    while hasdata(mbq)
        
        iteration = iteration + 1;
        
        [dlX,dlY1,dlY2] = next(mbq);
                       
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function.
        [gradients,state,loss] = dlfeval(@modelGradients, dlnet, dlX, dlY1, dlY2);
        dlnet.State = state;
        
        % Update the network parameters using the Adam optimizer.
        [dlnet,trailingAvg,trailingAvgSq] = adamupdate(dlnet,gradients, ...
            trailingAvg,trailingAvgSq,iteration);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end

モデルのテスト

真のラベルと角度をもつテスト セットで予測を比較して、モデルの分類精度をテストします。学習データと同じ設定の minibatchqueue オブジェクトを使用して、テスト データ セットを管理します。

[XTest,Y1Test,anglesTest] = digitTest4DArrayData;

dsXTest = arrayDatastore(XTest,'IterationDimension',4);
dsYTest = arrayDatastore(Y1Test);
dsAnglesTest = arrayDatastore(anglesTest);

dsTest = combine(dsXTest,dsYTest,dsAnglesTest);

mbqTest = minibatchqueue(dsTest,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn', @preprocessData,...
    'MiniBatchFormat',{'SSCB','',''});

検証データのラベルと角度を予測するために、ミニバッチをループ処理し、関数 predict を使用します。予測されたクラスと角度を保存します。予測されたクラスおよび角度を真のクラスおよび角度と比較し、その結果を保存します。

classesPredictions = [];
anglesPredictions = [];
classCorr = [];
angleDiff = [];

% Loop over mini-batches.
while hasdata(mbqTest)
    
    % Read mini-batch of data.
    [dlXTest,dlY1Test,dlY2Test] = next(mbqTest);
    
    % Make predictions using the predict function.
    [dlY1Pred,dlY2Pred] = predict(dlnet,dlXTest,'Outputs',["softmax" "fc2"]);
    
    % Determine predicted classes.
    Y1PredBatch = onehotdecode(dlY1Pred,classNames,1);
    classesPredictions = [classesPredictions Y1PredBatch];
    
    % Dermine predicted angles
    Y2PredBatch = extractdata(dlY2Pred);
    anglesPredictions = [anglesPredictions Y2PredBatch];
    
    % Compare predicted and true classes
    Y1Test = onehotdecode(dlY1Test,classNames,1);
    classCorr = [classCorr Y1PredBatch == Y1Test];
    
    % Compare predicted and true angles
    angleDiffBatch = Y2PredBatch - dlY2Test;
    angleDiff = [angleDiff extractdata(gather(angleDiffBatch))];
    
end

分類精度を評価します。

accuracy = mean(classCorr)
accuracy = 0.9814

回帰精度を評価します。

angleRMSE = sqrt(mean(angleDiff.^2))
angleRMSE = single
    7.7431

一部のイメージと、その予測を表示します。予測角度を赤、正解ラベルを緑で表示します。

idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on
    
    sz = size(I,1);
    offset = sz/2;
    
    thetaPred = anglesPredictions(idx(i));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--')
    
    thetaValidation = anglesTest(idx(i));
    plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],'g--')
    
    hold off
    label = string(classesPredictions(idx(i)));
    title("Label: " + label)
end

モデル勾配関数

関数 modelGradients は、dlnetwork オブジェクト dlnet、入力データ dlX のミニバッチと対応するターゲット T1 および T2 (それぞれラベルと角度を含む) を入力として受け取り、学習可能なパラメーターについての損失の勾配、更新されたネットワークの状態、および対応する損失を返します。

function [gradients,state,loss] = modelGradients(dlnet,dlX,T1,T2)

[dlY1,dlY2,state] = forward(dlnet,dlX,'Outputs',["softmax" "fc2"]);

lossLabels = crossentropy(dlY1,T1);
lossAngles = mse(dlY2,T2);

loss = lossLabels + 0.1*lossAngles;
gradients = dlgradient(loss,dlnet.Learnables);

end

ミニバッチ前処理関数

関数 preprocessMiniBatch は、次の手順でデータを前処理します。

  1. 入力 cell 配列からイメージ データを抽出して数値配列に連結します。4 番目の次元でイメージ データを連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されます。

  2. 入力 cell 配列からラベルと角度データを抽出して、それを 2 番目の次元と共に、categorical 配列および数値配列にそれぞれ連結します。

  3. カテゴリカル ラベルを数値配列に one-hot 符号化します。最初の次元への符号化は、ネットワーク出力の形状と一致する符号化された配列を生成します。

function [X,Y,angle] = preprocessData(XCell,YCell,angleCell)
    
    % Extract image data from cell and concatenate
    X = cat(4,XCell{:});
    % Extract label data from cell and concatenate
    Y = cat(2,YCell{:});
    % Extract angle data from cell and concatenate
    angle = cat(2,angleCell{:});
        
    % One-hot encode labels
    Y = onehotencode(Y,1);
    
end

参考

| | | | | | | | | | |

関連するトピック