Main Content

モデル関数を使用したバッチ正規化統計量の更新

この例では、関数として定義されたネットワークにおいて、ネットワークの状態を更新する方法を示します。

バッチ正規化演算は、ミニバッチ全体で各入力チャネルを正規化します。畳み込みニューラル ネットワークの学習速度を上げ、ネットワークの初期化に対する感度を下げるには、畳み込みの間にあるバッチ正規化演算と、ReLU 層などの非線形性を使用します。

学習中、バッチ正規化演算は、まず、ミニバッチの平均を減算し、ミニバッチの標準偏差で除算することにより、各チャネルの活性化を正規化します。その後、この演算は、学習可能なオフセット β だけ入力をシフトし、それを学習可能なスケール係数 γ だけスケーリングします。

学習済みネットワークを使用して新しいデータについて予測を実行する場合、バッチ正規化演算はミニバッチの平均と分散ではなく、学習済みのデータ セットの平均と分散を使用して活性化を正規化します。

データ セット統計を計算するため、継続的に更新される状態を使用して、ミニバッチ統計を追跡しなければなりません。

モデル関数においてバッチ正規化演算を使用する場合、学習と予測の両方の動作を定義しなければなりません。たとえば、boolean オプション doTraining を指定して、モデルが学習のためにミニバッチ統計を使用するか、予測のためにデータセット統計を使用するかを制御することができます。

次はモデル関数からのコード例で、バッチ正規化演算を適用し、学習中にデータセット統計のみを更新する方法を示します。

if doTraining
    [Y,trainedMean,trainedVariance] = batchnorm(Y,offset,scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnorm1.TrainedMean = trainedMean;
    state.batchnorm1.TrainedVariance = trainedVariance;
else
    Y = batchnorm(Y,offset,scale,trainedMean,trainedVariance);
end

学習データの読み込み

関数 digitTrain4DArrayData はイメージとその数字ラベル、および垂直方向からの回転角度を読み込みます。イメージ、ラベル、角度について arrayDatastore オブジェクトを作成してから、関数 combine を使用して、すべての学習データを含む単一のデータストアを作成します。クラス名と、離散でない応答の数を抽出します。

[XTrain,TTrain,anglesTrain] = digitTrain4DArrayData;

dsXTrain = arrayDatastore(XTrain,IterationDimension=4);
dsTTrain = arrayDatastore(TTrain);
dsAnglesTrain = arrayDatastore(anglesTrain);

dsTrain = combine(dsXTrain,dsTTrain,dsAnglesTrain);

classNames = categories(TTrain);
numClasses = numel(classNames);
numResponses = size(anglesTrain,2);
numObservations = numel(TTrain);

学習データからの一部のイメージを表示します。

idx = randperm(numObservations,64);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I)

深層学習モデルの定義

ラベルと回転角度の両方を予測する次のネットワークを定義します。

  • 16 個の 5 x 5 フィルターをもつ convolution-batchnorm-ReLU ブロック

  • 各ブロックに 32 個の 3 x 3 フィルターがあり、間に ReLU 演算をもつ、2 個の convolution-batchnorm ブロックの分岐

  • 32 個の 1 x 1 の畳み込みをもつ convolution-batchnorm ブロックのあるスキップ接続

  • 加算とそれに続く ReLU 演算を使用した両方の分岐の組み合わせ

  • 回帰出力用に、サイズが 1 (応答数) の全結合演算をもつ分岐

  • 分類出力用に、サイズが 10 (クラス数) の全結合演算とソフトマックス演算をもつ分岐

モデルのパラメーターと状態の定義および初期化

各演算のパラメーターを定義して構造体に含めます。parameters.OperationName.ParameterName の形式を使用します。ここで、parameters は構造体、OperationName は演算名 ("conv1" など)、ParameterName はパラメーター名 ("Weights" など) です。

モデル パラメーターを含む構造体 parameters を作成します。サンプル関数 initializeGlorot および initializeZeros を使用して、学習可能な層の重みとバイアスをそれぞれ初期化します。サンプル関数 initializeZeros および initializeOnes を使用して、バッチ正規化オフセットとスケール パラメーターをそれぞれ初期化します。

バッチ正規化層を使用して学習や推論を実行するには、ネットワークの状態も管理しなければなりません。予測の前に、学習データから派生するデータセットの平均と分散を指定しなければなりません。状態パラメーターを含む構造体 state を作成します。バッチ正規化の統計値は、dlarray オブジェクトにしないでください。関数 zeros および ones を使用して、バッチ正規化の学習済み平均と学習済み分散の状態をそれぞれ初期化します。

この初期化サンプル関数は、この例にサポート ファイルとして添付されています。

最初の畳み込み層のパラメーターを初期化します。

filterSize = [5 5];
numChannels = 1;
numFilters = 16;

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

parameters.conv1.Weights = initializeGlorot(sz,numOut,numIn);
parameters.conv1.Bias = initializeZeros([numFilters 1]);

最初のバッチ正規化層のパラメーターと状態を初期化します。

parameters.batchnorm1.Offset = initializeZeros([numFilters 1]);
parameters.batchnorm1.Scale = initializeOnes([numFilters 1]);
state.batchnorm1.TrainedMean = initializeZeros([numFilters 1]);
state.batchnorm1.TrainedVariance = initializeOnes([numFilters 1]);

2 番目の畳み込み層のパラメーターを初期化します。

filterSize = [3 3];
numChannels = 16;
numFilters = 32;

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

parameters.conv2.Weights = initializeGlorot(sz,numOut,numIn);
parameters.conv2.Bias = initializeZeros([numFilters 1]);

2 番目のバッチ正規化層のパラメーターと状態を初期化します。

parameters.batchnorm2.Offset = initializeZeros([numFilters 1]);
parameters.batchnorm2.Scale = initializeOnes([numFilters 1]);
state.batchnorm2.TrainedMean = initializeZeros([numFilters 1]);
state.batchnorm2.TrainedVariance = initializeOnes([numFilters 1]);

3 番目の畳み込み層のパラメーターを初期化します。

filterSize = [3 3];
numChannels = 32;
numFilters = 32;

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

parameters.conv3.Weights = initializeGlorot(sz,numOut,numIn);
parameters.conv3.Bias = initializeZeros([numFilters 1]);

3 番目のバッチ正規化層のパラメーターと状態を初期化します。

parameters.batchnorm3.Offset = initializeZeros([numFilters 1]);
parameters.batchnorm3.Scale = initializeOnes([numFilters 1]);
state.batchnorm3.TrainedMean = initializeZeros([numFilters 1]);
state.batchnorm3.TrainedVariance = initializeOnes([numFilters 1]);

スキップ接続における畳み込み層のパラメーターを初期化します。

filterSize = [1 1];
numChannels = 16;
numFilters = 32;

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

parameters.convSkip.Weights = initializeGlorot(sz,numOut,numIn);
parameters.convSkip.Bias = initializeZeros([numFilters 1]);

スキップ接続におけるバッチ正規化層のパラメーターと状態を初期化します。

parameters.batchnormSkip.Offset = initializeZeros([numFilters 1]);
parameters.batchnormSkip.Scale = initializeOnes([numFilters 1]);
state.batchnormSkip.TrainedMean = initializeZeros([numFilters 1]);
state.batchnormSkip.TrainedVariance = initializeOnes([numFilters 1]);

分類出力に対応する全結合層のパラメーターを初期化します。

sz = [numClasses 6272];
numOut = numClasses;
numIn = 6272;
parameters.fc1.Weights = initializeGlorot(sz,numOut,numIn);
parameters.fc1.Bias = initializeZeros([numClasses 1]);

回帰出力に対応する全結合層のパラメーターを初期化します。

sz = [numResponses 6272];
numOut = numResponses;
numIn = 6272;
parameters.fc2.Weights = initializeGlorot(sz,numOut,numIn);
parameters.fc2.Bias = initializeZeros([numResponses 1]);

状態の構造体を表示します。

state
state = struct with fields:
       batchnorm1: [1×1 struct]
       batchnorm2: [1×1 struct]
       batchnorm3: [1×1 struct]
    batchnormSkip: [1×1 struct]

演算 batchnorm1 の状態パラメーターを表示します。

state.batchnorm1
ans = struct with fields:
        TrainedMean: [16×1 dlarray]
    TrainedVariance: [16×1 dlarray]

モデルの関数の定義

この例の最後にリストされている関数 model を作成します。この関数は前に説明した深層学習モデルの出力を計算します。

関数 model は、モデル パラメーター parameters、入力データ、モデルが学習と予測のどちらの出力を返すかを指定するフラグ doTraining、およびネットワークの状態 state を入力として受け取ります。ネットワークはラベルの予測、角度の予測、および更新されたネットワークの状態を出力します。

モデル損失関数の定義

この例の最後にリストされている関数 modelLoss を作成します。この関数は入力データのミニバッチとそれに対応するターゲット T1 および T2 (それぞれラベルと角度を含む) を入力として受け取り、損失、学習可能なパラメーターについての損失の勾配、および更新されたネットワークの状態を返します。

学習オプションの指定

学習オプションを指定します。

numEpochs = 20;
miniBatchSize = 128;

モデルの学習

カスタム学習ループを使用してモデルに学習させます。minibatchqueueを使用して、イメージのミニバッチを処理および管理します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatch (この例の最後に定義) を使用して、クラス ラベルを one-hot 符号化します。

  • イメージ データを次元ラベル "SSCB" (spatial、spatial、channel、batch) で書式設定します。既定では、minibatchqueue オブジェクトは、基となる型が singledlarray オブジェクトにデータを変換します。書式をクラス ラベルまたは角度に追加しないでください。

  • GPU が利用できる場合、GPU で学習を行います。既定では、minibatchqueue オブジェクトは、GPU が利用可能な場合、各出力を gpuArray に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。

mbq = minibatchqueue(dsTrain,...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@preprocessMiniBatch,...
    MiniBatchFormat=["SSCB" "" ""]);

各エポックについて、データをシャッフルしてデータのミニバッチをループで回します。各エポックの最後に、学習の進行状況を表示します。各ミニバッチで次を行います。

  • 関数 dlfeval および modelLoss を使用してモデルの損失と勾配を評価。

  • 関数 adamupdate を使用してネットワーク パラメーターを更新。

Adam ソルバーのパラメーターを初期化します。

trailingAvg = [];
trailingAvgSq = [];

学習の進行状況プロットを初期化します。

figure
C = colororder;
lineLossTrain = animatedline(Color=C(2,:));
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

モデルに学習させます。

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs

    % Shuffle data.
    shuffle(mbq)

    % Loop over mini-batches
    while hasdata(mbq)
        iteration = iteration + 1;

        [X,T1,T2] = next(mbq);

        % Evaluate the model loss, gradients, and state using dlfeval and the
        % modelLoss function.
        [loss,gradients,state] = dlfeval(@modelLoss,parameters,X,T1,T2,state);

        % Update the network parameters using the Adam optimizer.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAvgSq,iteration);

        % Display the training progress.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        loss = double(loss);
        addpoints(lineLossTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end

モデルのテスト

真のラベルと角度をもつテスト セットで予測を比較して、モデルの分類精度をテストします。学習データと同じ設定の minibatchqueue オブジェクトを使用して、テスト データ セットを管理します。

[XTest,T1Test,anglesTest] = digitTest4DArrayData;

dsXTest = arrayDatastore(XTest,IterationDimension=4);
dsTTest = arrayDatastore(T1Test);
dsAnglesTest = arrayDatastore(anglesTest);

dsTest = combine(dsXTest,dsTTest,dsAnglesTest);

mbqTest = minibatchqueue(dsTest,...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@preprocessMiniBatch,...
    MiniBatchFormat=["SSCB" "" ""]);

検証データのラベルと角度を予測するために、この例の最後にリストされている関数 modelPredictions を使用します。この関数は、予測されたクラスと角度、および真の値との比較を返します。

[classesPredictions,anglesPredictions,classCorr,angleDiff] = modelPredictions(parameters,state,mbqTest,classNames);

分類精度を評価します。

accuracy = mean(classCorr)
accuracy = 0.9824

回帰精度を評価します。

angleRMSE = sqrt(mean(angleDiff.^2))
angleRMSE = single
    7.9194

一部のイメージと、その予測を表示します。予測された角度を赤、正しい角度を緑で表示します。

idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on

    sz = size(I,1);
    offset = sz/2;

    thetaPred = anglesPredictions(idx(i));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],"r--")

    thetaValidation = anglesTest(idx(i));
    plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],"g--")

    hold off
    label = string(classesPredictions(idx(i)));
    title("Label: " + label)
end

モデル関数

関数 model は、モデル パラメーター parameters、入力データ X、モデルが学習と予測のどちらの出力を返すかを指定するフラグ doTraining、およびネットワークの状態 state を入力として受け取ります。この関数は、ラベルと角度の予測、および更新されたネットワークの状態を返します。

function [Y1,Y2,state] = model(parameters,X,doTraining,state)

% Convolution
weights = parameters.conv1.Weights;
bias = parameters.conv1.Bias;
Y = dlconv(X,weights,bias,Padding=2);

% Batch normalization, ReLU
offset = parameters.batchnorm1.Offset;
scale = parameters.batchnorm1.Scale;
trainedMean = state.batchnorm1.TrainedMean;
trainedVariance = state.batchnorm1.TrainedVariance;

if doTraining
    [Y,trainedMean,trainedVariance] = batchnorm(Y,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm1.TrainedMean = trainedMean;
    state.batchnorm1.TrainedVariance = trainedVariance;
else
    Y = batchnorm(Y,offset,scale,trainedMean,trainedVariance);
end

Y = relu(Y);

% Convolution, batch normalization (skip connection)
weights = parameters.convSkip.Weights;
bias = parameters.convSkip.Bias;
YSkip = dlconv(Y,weights,bias,Stride=2);

offset = parameters.batchnormSkip.Offset;
scale = parameters.batchnormSkip.Scale;
trainedMean = state.batchnormSkip.TrainedMean;
trainedVariance = state.batchnormSkip.TrainedVariance;

if doTraining
    [YSkip,trainedMean,trainedVariance] = batchnorm(YSkip,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnormSkip.TrainedMean = trainedMean;
    state.batchnormSkip.TrainedVariance = trainedVariance;
else
    YSkip = batchnorm(YSkip,offset,scale,trainedMean,trainedVariance);
end

% Convolution
weights = parameters.conv2.Weights;
bias = parameters.conv2.Bias;
Y = dlconv(Y,weights,bias,Padding=1,Stride=2);

% Batch normalization, ReLU
offset = parameters.batchnorm2.Offset;
scale = parameters.batchnorm2.Scale;
trainedMean = state.batchnorm2.TrainedMean;
trainedVariance = state.batchnorm2.TrainedVariance;

if doTraining
    [Y,trainedMean,trainedVariance] = batchnorm(Y,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm2.TrainedMean = trainedMean;
    state.batchnorm2.TrainedVariance = trainedVariance;
else
    Y = batchnorm(Y,offset,scale,trainedMean,trainedVariance);
end

Y = relu(Y);

% Convolution
weights = parameters.conv3.Weights;
bias = parameters.conv3.Bias;
Y = dlconv(Y,weights,bias,Padding=1);

% Batch normalization
offset = parameters.batchnorm3.Offset;
scale = parameters.batchnorm3.Scale;
trainedMean = state.batchnorm3.TrainedMean;
trainedVariance = state.batchnorm3.TrainedVariance;

if doTraining
    [Y,trainedMean,trainedVariance] = batchnorm(Y,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm3.TrainedMean = trainedMean;
    state.batchnorm3.TrainedVariance = trainedVariance;
else
    Y = batchnorm(Y,offset,scale,trainedMean,trainedVariance);
end

% Addition, ReLU
Y = YSkip + Y;
Y = relu(Y);

% Fully connect, softmax (labels)
weights = parameters.fc1.Weights;
bias = parameters.fc1.Bias;
Y1 = fullyconnect(Y,weights,bias);
Y1 = softmax(Y1);

% Fully connect (angles)
weights = parameters.fc2.Weights;
bias = parameters.fc2.Bias;
Y2 = fullyconnect(Y,weights,bias);

end

モデル損失関数

関数 modelLoss は、モデル パラメーター、入力データのミニバッチ X とそれに対応するターゲット T1 および T2 (それぞれラベルと角度を含む) を入力として受け取り、損失、学習可能なパラメーターについての損失の勾配、および更新されたネットワークの状態を返します。

function [loss,gradients,state] = modelLoss(parameters,X,T1,T2,state)

doTraining = true;
[Y1,Y2,state] = model(parameters,X,doTraining,state);

lossLabels = crossentropy(Y1,T1);
lossAngles = mse(Y2,T2);

loss = lossLabels + 0.1*lossAngles;
gradients = dlgradient(loss,parameters);

end

モデル予測関数

関数 modelPredictions は、モデル パラメーター、ネットワークの状態、入力データの minibatchqueue オブジェクト mbq、およびネットワーク クラスを受け取り、doTraining オプションを false に設定した関数 model を使用して、minibatchqueue のすべてのデータを反復処理することにより、モデル予測を返します。この関数は、予測されたクラスと角度、および真の値との比較を返します。クラスについては、比較結果を、予測の正誤を表す 0 と 1 のベクトルで示します。角度については、比較結果を、予測された角度と真の値の差で示します。

function [classesPredictions,anglesPredictions,classCorr,angleDiff] = modelPredictions(parameters,state,mbq,classes)

doTraining = false;

classesPredictions = [];
anglesPredictions = [];
classCorr = [];
angleDiff = [];

while hasdata(mbq)
    [X,T1,T2] = next(mbq);

    % Make predictions using the model function.
    [Y1,Y2] = model(parameters,X,doTraining,state);

    % Determine predicted classes.
    Y1PredBatch = onehotdecode(Y1,classes,1);
    classesPredictions = [classesPredictions Y1PredBatch];

    % Dermine predicted angles
    Y2PredBatch = extractdata(Y2);
    anglesPredictions = [anglesPredictions Y2PredBatch];

    % Compare predicted and true classes
    Y1 = onehotdecode(T1,classes,1);
    classCorr = [classCorr Y1PredBatch == Y1];

    % Compare predicted and true angles
    angleDiffBatch = Y2PredBatch - T2;
    angleDiff = [angleDiff extractdata(gather(angleDiffBatch))];

end

end

ミニバッチ前処理関数

関数 preprocessMiniBatch は、次の手順でデータを前処理します。

  1. 入力 cell 配列からイメージ データを抽出して数値配列に連結します。4 番目の次元でイメージ データを連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されます。

  2. 入力 cell 配列からラベルと角度データを抽出して、categorical 配列と数値配列にそれぞれ連結します。

  3. カテゴリカル ラベルを数値配列に one-hot 符号化します。最初の次元への符号化は、ネットワーク出力の形状と一致する符号化された配列を生成します。

function [X,T,angle] = preprocessMiniBatch(dataX,dataT,dataAngle)

% Extract image data from cell and concatenate
X = cat(4,dataX{:});

% Extract label data from cell and concatenate
T = cat(2,dataT{:});

% Extract angle data from cell and concatenate
angle = cat(2,dataAngle{:});

% One-hot encode labels
T = onehotencode(T,1);

end

参考

| | | | | | | | | | | |

関連するトピック