Main Content

カスタム学習ループでのバッチ正規化統計量の更新

この例では、カスタム学習ループでネットワークの状態を更新する方法を示します。

バッチ正規化層は、ミニバッチ全体で各入力チャネルを正規化します。畳み込みニューラル ネットワークの学習速度を上げ、ネットワークの初期化に対する感度を下げるには、畳み込み層の間にあるバッチ正規化層と、ReLU 層などの非線形性を使用します。

学習中、バッチ正規化層は、まず、ミニバッチの平均を減算し、ミニバッチの標準偏差で除算することにより、各チャネルの活性化を正規化します。その後、この層は、学習可能なオフセット β だけ入力をシフトし、それを学習可能なスケール係数 γ だけスケーリングします。

ネットワークの学習が終了したら、バッチ正規化層は学習セット全体の平均と分散を計算し、その値を TrainedMean プロパティおよび TrainedVariance プロパティに格納します。学習済みネットワークを使用して新しいイメージについて予測を実行する場合、バッチ正規化層はミニバッチの平均と分散ではなく、学習済みの平均と分散を使用して活性化を正規化します。

データセットの統計量を計算するために、バッチ正規化層は継続的に更新される状態を使用してミニバッチの統計量を追跡します。カスタム学習ループを実装している場合、ミニバッチ間でネットワークの状態を更新しなければなりません。

学習データの読み込み

関数 digitTrain4DArrayData は、手書き数字のイメージとその数字ラベルを読み込みます。イメージと角度について arrayDatastore オブジェクトを作成してから、関数 combine を使用してすべての学習データを含む単一のデータストアを作成します。クラス名を抽出します。

[XTrain,TTrain] = digitTrain4DArrayData;

dsXTrain = arrayDatastore(XTrain,IterationDimension=4);
dsTTrain = arrayDatastore(TTrain);

dsTrain = combine(dsXTrain,dsTTrain);

classNames = categories(TTrain);
numClasses = numel(classNames);

ネットワークの定義

ネットワークを定義し、イメージ入力層で Mean オプションを使用して平均イメージを指定します。

layers = [
    imageInputLayer([28 28 1],Mean=mean(XTrain,4))
    convolution2dLayer(5,20)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding=1)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding=1)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];

層配列から dlnetwork オブジェクトを作成します。

net = dlnetwork(layers)
net = 
  dlnetwork with properties:

         Layers: [12×1 nnet.cnn.layer.Layer]
    Connections: [11×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'}
    Initialized: 1

ネットワークの状態を表示します。各バッチ正規化層は、データセットの平均と分散をぞれぞれ含む、TrainedMean パラメーターと TrainedVariance パラメーターをもちます。

net.State
ans=6×3 table
        Layer            Parameter             Value      
    _____________    _________________    ________________

    "batchnorm_1"    "TrainedMean"        {1×1×20 dlarray}
    "batchnorm_1"    "TrainedVariance"    {1×1×20 dlarray}
    "batchnorm_2"    "TrainedMean"        {1×1×20 dlarray}
    "batchnorm_2"    "TrainedVariance"    {1×1×20 dlarray}
    "batchnorm_3"    "TrainedMean"        {1×1×20 dlarray}
    "batchnorm_3"    "TrainedVariance"    {1×1×20 dlarray}

モデル損失関数の定義

この例の最後にリストされている関数 modelLoss を作成します。この関数は dlnetwork オブジェクト、入力データのミニバッチとそれに対応するラベルを入力として受け取り、損失、学習可能なパラメーターについての損失の勾配、および更新されたネットワークの状態を返します。

学習オプションの指定

ミニバッチ サイズを 128 として、学習を 5 エポック行います。SGDM 最適化では、学習率に 0.01、モーメンタムに 0.9 を指定します。

numEpochs = 5;
miniBatchSize = 128;

learnRate = 0.01;
momentum = 0.9;

モデルの学習

minibatchqueueを使用して、イメージのミニバッチを処理および管理します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatch (この例の最後に定義) を使用して、クラス ラベルを one-hot 符号化します。

  • イメージ データを次元ラベル "SSCB" (spatial、spatial、channel、batch) で書式設定します。既定では、minibatchqueue オブジェクトは、基となる型が singledlarray オブジェクトにデータを変換します。書式をクラス ラベルに追加しないでください。

  • GPU が利用できる場合、GPU で学習を行います。既定では、minibatchqueue オブジェクトは、GPU が利用可能な場合、各出力を gpuArray に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。

mbq = minibatchqueue(dsTrain,...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@preprocessMiniBatch,...
    MiniBatchFormat=["SSCB" ""]);

学習の進行状況プロットを初期化します。

figure
C = colororder;
lineLossTrain = animatedline(Color=C(2,:));
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

SGDM ソルバーの速度パラメーターを初期化します。

velocity = [];

カスタム学習ループを使用してモデルに学習させます。各エポックについて、データをシャッフルしてデータのミニバッチをループで回します。反復が終了するたびに、学習の進行状況を表示します。各ミニバッチで次を行います。

  • 関数 dlfeval と関数 modelLoss を使用してモデルの損失、勾配、および状態を評価し、ネットワークの状態を更新。

  • 関数 sgdmupdate を使用してネットワーク パラメーターを更新。

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Shuffle data.
    shuffle(mbq)

    % Loop over mini-batches.
    while hasdata(mbq)

        iteration = iteration + 1;

        % Read mini-batch of data and convert the labels to dummy
        % variables.
        [X,T] = next(mbq);

        % Evaluate the model loss, gradients, and state using dlfeval and the
        % modelLoss function and update the network state.
        [loss,gradients,state] = dlfeval(@modelLoss,net,X,T);
        net.State = state;

        % Update the network parameters using the SGDM optimizer.
        [net, velocity] = sgdmupdate(net, gradients, velocity, learnRate, momentum);

        % Display the training progress.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        loss = double(loss);
        addpoints(lineLossTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end

モデルのテスト

真のラベルをもつテスト セットで予測を比較して、モデルの分類精度をテストします。真のラベルと角度をもつテスト セットで予測を比較して、モデルの分類精度をテストします。

テスト データを読み込み、イメージと特徴を含む結合されたデータストアを作成します。

[XTest,TTest] = digitTest4DArrayData;
dsTest = arrayDatastore(XTest,IterationDimension=4);

テスト中にイメージのミニバッチを処理および管理する minibatchqueue オブジェクトを作成します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatchPredictors (この例の最後に定義) を使用します。

  • 既定では、minibatchqueue オブジェクトは、基となる型が singledlarray オブジェクトにデータを変換します。イメージを次元ラベル "SSCB" (spatial、spatial、channel、batch) で書式設定します。

mbqTest = minibatchqueue(dsTest,...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@preprocessMiniBatchPredictors,...
    MiniBatchFormat="SSCB");

例の最後にリストされている関数 modelPredictions を使用してイメージを分類します。

predictions = modelPredictions(net,mbqTest,classNames);

分類精度を評価します。

accuracy = mean(predictions == TTest)
accuracy = 0.9934

モデル損失関数

関数 modelLoss は、dlnetwork オブジェクト net、入力データのミニバッチ X とそれに対応するラベル T を入力として受け取り、損失、net 内の学習可能パラメーターについての損失の勾配、およびネットワークの状態を返します。勾配を自動的に計算するには、関数 dlgradient を使用します。

function [loss,gradients,state] = modelLoss(net,X,T)

[Y,state] = forward(net,X);

loss = crossentropy(Y,T);
gradients = dlgradient(loss,net.Learnables);

end

モデル予測関数

関数 modelPredictions は、dlnetwork オブジェクト net、入力データの minibatchqueue オブジェクト mbq を入力として受け取り、minibatchqueue のすべてのデータを反復処理することでモデル予測を計算します。この関数は、関数 onehotdecode を使用して、スコアが最も高い予測されたクラスを見つけます。

function predictions = modelPredictions(net,mbq,classes)

predictions = [];

while hasdata(mbq)
    X = next(mbq);

    % Make predictions using the model function.
    Y = predict(net,X);

    % Determine predicted classes.
    YPredBatch = onehotdecode(Y,classes,1);
    predictions = [predictions; YPredBatch'];
end

end

ミニ バッチ前処理関数

関数 preprocessMiniBatch は、次の手順でデータを前処理します。

  1. 関数 preprocessMiniBatchPredictors を使用してイメージと特徴を前処理します。

  2. 入力 cell 配列からラベル データを抽出し、2 番目の次元に沿って連結します。

  3. カテゴリカル ラベルを数値配列に one-hot 符号化します。最初の次元への符号化は、ネットワーク出力の形状と一致する符号化された配列を生成します。

function [X,T] = preprocessMiniBatch(dataX,dataY)

% Extract image data from cell and concatenate
X = cat(4,dataX{:});

% Extract label data from cell and concatenate
T = cat(2,dataY{:});

% One-hot encode labels
T = onehotencode(T,1);

end

ミニバッチ予測子前処理関数

関数 preprocessMiniBatchPredictors は、入力 cell 配列からイメージ データを抽出して数値配列に連結することにより、予測子を前処理します。4 番目の次元でイメージ データを連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されます。

function X = preprocessMiniBatchPredictors(dataX)

% Extract image data from cell and concatenate
X = cat(4,dataX{:});

end

参考

| | | | | | | | |

関連するトピック