モデル関数を使用したバッチ正規化統計量の更新
この例では、関数として定義されたネットワークにおいて、ネットワークの状態を更新する方法を示します。
バッチ正規化演算は、ミニバッチ全体で各入力チャネルを正規化します。畳み込みニューラル ネットワークの学習速度を上げ、ネットワークの初期化に対する感度を下げるには、畳み込みの間にあるバッチ正規化演算と、ReLU 層などの非線形性を使用します。
学習中、バッチ正規化演算は、まず、ミニバッチの平均を減算し、ミニバッチの標準偏差で除算することにより、各チャネルの活性化を正規化します。その後、この演算は、学習可能なオフセット β だけ入力をシフトし、それを学習可能なスケール係数 γ だけスケーリングします。
学習済みネットワークを使用して新しいデータについて予測を実行する場合、バッチ正規化演算はミニバッチの平均と分散ではなく、学習済みのデータ セットの平均と分散を使用して活性化を正規化します。
データ セット統計を計算するため、継続的に更新される状態を使用して、ミニバッチ統計を追跡しなければなりません。
モデル関数においてバッチ正規化演算を使用する場合、学習と予測の両方の動作を定義しなければなりません。たとえば、boolean オプション doTraining
を指定して、モデルが学習のためにミニバッチ統計を使用するか、予測のためにデータセット統計を使用するかを制御することができます。
次はモデル関数からのコード例で、バッチ正規化演算を適用し、学習中にデータセット統計のみを更新する方法を示します。
if doTraining [Y,trainedMean,trainedVariance] = batchnorm(Y,offset,scale,trainedMean,trainedVariance); % Update state state.batchnorm1.TrainedMean = trainedMean; state.batchnorm1.TrainedVariance = trainedVariance; else Y = batchnorm(Y,offset,scale,trainedMean,trainedVariance); end
学習データの読み込み
関数 digitTrain4DArrayData
はイメージとその数字ラベル、および垂直方向からの回転角度を読み込みます。イメージ、ラベル、角度について arrayDatastore
オブジェクトを作成してから、関数 combine
を使用して、すべての学習データを含む単一のデータストアを作成します。クラス名と、離散でない応答の数を抽出します。
[XTrain,TTrain,anglesTrain] = digitTrain4DArrayData; dsXTrain = arrayDatastore(XTrain,IterationDimension=4); dsTTrain = arrayDatastore(TTrain); dsAnglesTrain = arrayDatastore(anglesTrain); dsTrain = combine(dsXTrain,dsTTrain,dsAnglesTrain); classNames = categories(TTrain); numClasses = numel(classNames); numResponses = size(anglesTrain,2); numObservations = numel(TTrain);
学習データからの一部のイメージを表示します。
idx = randperm(numObservations,64); I = imtile(XTrain(:,:,:,idx)); figure imshow(I)
深層学習モデルの定義
ラベルと回転角度の両方を予測する次のネットワークを定義します。
16 個の 5 x 5 フィルターをもつ convolution-batchnorm-ReLU ブロック
各ブロックに 32 個の 3 x 3 フィルターがあり、間に ReLU 演算をもつ、2 個の convolution-batchnorm ブロックの分岐
32 個の 1 x 1 の畳み込みをもつ convolution-batchnorm ブロックのあるスキップ接続
加算とそれに続く ReLU 演算を使用した両方の分岐の組み合わせ
回帰出力用に、サイズが 1 (応答数) の全結合演算をもつ分岐
分類出力用に、サイズが 10 (クラス数) の全結合演算とソフトマックス演算をもつ分岐
モデルのパラメーターと状態の定義および初期化
各演算のパラメーターを定義して struct に含めます。parameters.OperationName.ParameterName
の形式を使用します。ここで、parameters
は struct、OperationName
は演算名 ("conv1" など)、ParameterName
はパラメーター名 ("Weights" など) です。
モデル パラメーターを含む struct parameters
を作成します。サンプル関数 initializeGlorot
および initializeZeros
を使用して、学習可能な層の重みとバイアスをそれぞれ初期化します。サンプル関数 initializeZeros
および initializeOnes
を使用して、バッチ正規化オフセットとスケール パラメーターをそれぞれ初期化します。
バッチ正規化層を使用して学習や推論を実行するには、ネットワークの状態も管理しなければなりません。予測の前に、学習データから派生するデータセットの平均と分散を指定しなければなりません。状態パラメーターを含む struct state
を作成します。バッチ正規化の統計値は、dlarray
オブジェクトにしないでください。関数 zeros
および ones
を使用して、バッチ正規化の学習済み平均と学習済み分散の状態をそれぞれ初期化します。
この初期化サンプル関数は、この例にサポート ファイルとして添付されています。
最初の畳み込み層のパラメーターを初期化します。
filterSize = [5 5]; numChannels = 1; numFilters = 16; sz = [filterSize numChannels numFilters]; numOut = prod(filterSize) * numFilters; numIn = prod(filterSize) * numFilters; parameters.conv1.Weights = initializeGlorot(sz,numOut,numIn); parameters.conv1.Bias = initializeZeros([numFilters 1]);
最初のバッチ正規化層のパラメーターと状態を初期化します。
parameters.batchnorm1.Offset = initializeZeros([numFilters 1]); parameters.batchnorm1.Scale = initializeOnes([numFilters 1]); state.batchnorm1.TrainedMean = initializeZeros([numFilters 1]); state.batchnorm1.TrainedVariance = initializeOnes([numFilters 1]);
2 番目の畳み込み層のパラメーターを初期化します。
filterSize = [3 3]; numChannels = 16; numFilters = 32; sz = [filterSize numChannels numFilters]; numOut = prod(filterSize) * numFilters; numIn = prod(filterSize) * numFilters; parameters.conv2.Weights = initializeGlorot(sz,numOut,numIn); parameters.conv2.Bias = initializeZeros([numFilters 1]);
2 番目のバッチ正規化層のパラメーターと状態を初期化します。
parameters.batchnorm2.Offset = initializeZeros([numFilters 1]); parameters.batchnorm2.Scale = initializeOnes([numFilters 1]); state.batchnorm2.TrainedMean = initializeZeros([numFilters 1]); state.batchnorm2.TrainedVariance = initializeOnes([numFilters 1]);
3 番目の畳み込み層のパラメーターを初期化します。
filterSize = [3 3]; numChannels = 32; numFilters = 32; sz = [filterSize numChannels numFilters]; numOut = prod(filterSize) * numFilters; numIn = prod(filterSize) * numFilters; parameters.conv3.Weights = initializeGlorot(sz,numOut,numIn); parameters.conv3.Bias = initializeZeros([numFilters 1]);
3 番目のバッチ正規化層のパラメーターと状態を初期化します。
parameters.batchnorm3.Offset = initializeZeros([numFilters 1]); parameters.batchnorm3.Scale = initializeOnes([numFilters 1]); state.batchnorm3.TrainedMean = initializeZeros([numFilters 1]); state.batchnorm3.TrainedVariance = initializeOnes([numFilters 1]);
スキップ接続における畳み込み層のパラメーターを初期化します。
filterSize = [1 1]; numChannels = 16; numFilters = 32; sz = [filterSize numChannels numFilters]; numOut = prod(filterSize) * numFilters; numIn = prod(filterSize) * numFilters; parameters.convSkip.Weights = initializeGlorot(sz,numOut,numIn); parameters.convSkip.Bias = initializeZeros([numFilters 1]);
スキップ接続におけるバッチ正規化層のパラメーターと状態を初期化します。
parameters.batchnormSkip.Offset = initializeZeros([numFilters 1]); parameters.batchnormSkip.Scale = initializeOnes([numFilters 1]); state.batchnormSkip.TrainedMean = initializeZeros([numFilters 1]); state.batchnormSkip.TrainedVariance = initializeOnes([numFilters 1]);
分類出力に対応する全結合層のパラメーターを初期化します。
sz = [numClasses 6272]; numOut = numClasses; numIn = 6272; parameters.fc1.Weights = initializeGlorot(sz,numOut,numIn); parameters.fc1.Bias = initializeZeros([numClasses 1]);
回帰出力に対応する全結合層のパラメーターを初期化します。
sz = [numResponses 6272]; numOut = numResponses; numIn = 6272; parameters.fc2.Weights = initializeGlorot(sz,numOut,numIn); parameters.fc2.Bias = initializeZeros([numResponses 1]);
状態の struct を表示します。
state
state = struct with fields:
batchnorm1: [1×1 struct]
batchnorm2: [1×1 struct]
batchnorm3: [1×1 struct]
batchnormSkip: [1×1 struct]
演算 batchnorm1
の状態パラメーターを表示します。
state.batchnorm1
ans = struct with fields:
TrainedMean: [16×1 dlarray]
TrainedVariance: [16×1 dlarray]
モデルの関数の定義
この例の最後にリストされている関数 model
を作成します。この関数は前に説明した深層学習モデルの出力を計算します。
関数 model
は、モデル パラメーター parameters
、入力データ、モデルが学習と予測のどちらの出力を返すかを指定するフラグ doTraining
、およびネットワークの状態 state
を入力として受け取ります。ネットワークはラベルの予測、角度の予測、および更新されたネットワークの状態を出力します。
モデル損失関数の定義
この例の最後にリストされている関数 modelLoss
を作成します。この関数は入力データのミニバッチとそれに対応するターゲット T1
および T2
(それぞれラベルと角度を含む) を入力として受け取り、損失、学習可能なパラメーターについての損失の勾配、および更新されたネットワークの状態を返します。
学習オプションの指定
学習オプションを指定します。
numEpochs = 20; miniBatchSize = 128;
モデルの学習
カスタム学習ループを使用してモデルに学習させます。minibatchqueue
を使用して、イメージのミニバッチを処理および管理します。各ミニバッチで次を行います。
カスタム ミニバッチ前処理関数
preprocessMiniBatch
(この例の最後に定義) を使用して、クラス ラベルを one-hot 符号化します。イメージ データを次元ラベル
"SSCB"
(spatial、spatial、channel、batch) で形式を整えます。既定では、minibatchqueue
オブジェクトは、基となる型がsingle
のdlarray
オブジェクトにデータを変換します。形式をクラス ラベルまたは角度に追加しないでください。GPU が利用できる場合、GPU で学習を行います。既定では、
minibatchqueue
オブジェクトは、GPU が利用可能な場合、各出力をgpuArray
に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。
mbq = minibatchqueue(dsTrain,... MiniBatchSize=miniBatchSize,... MiniBatchFcn=@preprocessMiniBatch,... MiniBatchFormat=["SSCB" "" ""]);
各エポックについて、データをシャッフルしてデータのミニバッチをループで回します。各エポックの最後に、学習の進行状況を表示します。各ミニバッチで次を行います。
関数
dlfeval
およびmodelLoss
を使用してモデルの損失と勾配を評価します。関数
adamupdate
を使用してネットワーク パラメーターを更新します。
Adam ソルバーのパラメーターを初期化します。
trailingAvg = []; trailingAvgSq = [];
学習の進行状況モニター用に合計反復回数を計算します。
numIterationsPerEpoch = ceil(numObservations / miniBatchSize); numIterations = numEpochs * numIterationsPerEpoch;
TrainingProgressMonitor
オブジェクトを初期化します。monitor オブジェクトを作成するとタイマーが開始されるため、学習ループに近いところでオブジェクトを作成するようにしてください。
monitor = trainingProgressMonitor(Metrics="Loss",Info=["Epoch","Iteration"],XLabel="Iteration");
モデルに学習させます。
iteration = 0; epoch = 0; start = tic; % Loop over epochs. while epoch < numEpochs && ~monitor.Stop epoch = epoch + 1; % Shuffle data. shuffle(mbq) % Loop over mini-batches while hasdata(mbq) && ~monitor.Stop iteration = iteration + 1; [X,T1,T2] = next(mbq); % Evaluate the model loss, gradients, and state using dlfeval and the % modelLoss function. [loss,gradients,state] = dlfeval(@modelLoss,parameters,X,T1,T2,state); % Update the network parameters using the Adam optimizer. [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ... trailingAvg,trailingAvgSq,iteration); recordMetrics(monitor,iteration,Loss=loss); updateInfo(monitor,Epoch=epoch,Iteration=iteration); monitor.Progress = 100*iteration/numIterations; end end
モデルのテスト
真のラベルと角度をもつテスト セットで予測を比較して、モデルの分類精度をテストします。学習データと同じ設定の minibatchqueue
オブジェクトを使用して、テスト データ セットを管理します。
[XTest,T1Test,anglesTest] = digitTest4DArrayData; dsXTest = arrayDatastore(XTest,IterationDimension=4); dsTTest = arrayDatastore(T1Test); dsAnglesTest = arrayDatastore(anglesTest); dsTest = combine(dsXTest,dsTTest,dsAnglesTest); mbqTest = minibatchqueue(dsTest,... MiniBatchSize=miniBatchSize,... MiniBatchFcn=@preprocessMiniBatch,... MiniBatchFormat=["SSCB" "" ""]);
検証データのラベルと角度を予測するために、この例の最後にリストされている関数 modelPredictions
を使用します。この関数は、予測されたクラスと角度、および真の値との比較を返します。
[classesPredictions,anglesPredictions,classCorr,angleDiff] = modelPredictions(parameters,state,mbqTest,classNames);
分類精度を評価します。
accuracy = mean(classCorr)
accuracy = 0.9858
回帰精度を評価します。
angleRMSE = sqrt(mean(angleDiff.^2))
angleRMSE = single
7.1762
一部のイメージと、その予測を表示します。予測された角度を赤、正しい角度を緑で表示します。
idx = randperm(size(XTest,4),9); figure for i = 1:9 subplot(3,3,i) I = XTest(:,:,:,idx(i)); imshow(I) hold on sz = size(I,1); offset = sz/2; thetaPred = anglesPredictions(idx(i)); plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],"r--") thetaValidation = anglesTest(idx(i)); plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],"g--") hold off label = string(classesPredictions(idx(i))); title("Label: " + label) end
モデル関数
関数 model
は、モデル パラメーター parameters
、入力データ X
、モデルが学習と予測のどちらの出力を返すかを指定するフラグ doTraining
、およびネットワークの状態 state
を入力として受け取ります。この関数は、ラベルと角度の予測、および更新されたネットワークの状態を返します。
function [Y1,Y2,state] = model(parameters,X,doTraining,state) % Convolution weights = parameters.conv1.Weights; bias = parameters.conv1.Bias; Y = dlconv(X,weights,bias,Padding=2); % Batch normalization, ReLU offset = parameters.batchnorm1.Offset; scale = parameters.batchnorm1.Scale; trainedMean = state.batchnorm1.TrainedMean; trainedVariance = state.batchnorm1.TrainedVariance; if doTraining [Y,trainedMean,trainedVariance] = batchnorm(Y,offset,scale,trainedMean,trainedVariance); % Update state state.batchnorm1.TrainedMean = trainedMean; state.batchnorm1.TrainedVariance = trainedVariance; else Y = batchnorm(Y,offset,scale,trainedMean,trainedVariance); end Y = relu(Y); % Convolution, batch normalization (skip connection) weights = parameters.convSkip.Weights; bias = parameters.convSkip.Bias; YSkip = dlconv(Y,weights,bias,Stride=2); offset = parameters.batchnormSkip.Offset; scale = parameters.batchnormSkip.Scale; trainedMean = state.batchnormSkip.TrainedMean; trainedVariance = state.batchnormSkip.TrainedVariance; if doTraining [YSkip,trainedMean,trainedVariance] = batchnorm(YSkip,offset,scale,trainedMean,trainedVariance); % Update state state.batchnormSkip.TrainedMean = trainedMean; state.batchnormSkip.TrainedVariance = trainedVariance; else YSkip = batchnorm(YSkip,offset,scale,trainedMean,trainedVariance); end % Convolution weights = parameters.conv2.Weights; bias = parameters.conv2.Bias; Y = dlconv(Y,weights,bias,Padding=1,Stride=2); % Batch normalization, ReLU offset = parameters.batchnorm2.Offset; scale = parameters.batchnorm2.Scale; trainedMean = state.batchnorm2.TrainedMean; trainedVariance = state.batchnorm2.TrainedVariance; if doTraining [Y,trainedMean,trainedVariance] = batchnorm(Y,offset,scale,trainedMean,trainedVariance); % Update state state.batchnorm2.TrainedMean = trainedMean; state.batchnorm2.TrainedVariance = trainedVariance; else Y = batchnorm(Y,offset,scale,trainedMean,trainedVariance); end Y = relu(Y); % Convolution weights = parameters.conv3.Weights; bias = parameters.conv3.Bias; Y = dlconv(Y,weights,bias,Padding=1); % Batch normalization offset = parameters.batchnorm3.Offset; scale = parameters.batchnorm3.Scale; trainedMean = state.batchnorm3.TrainedMean; trainedVariance = state.batchnorm3.TrainedVariance; if doTraining [Y,trainedMean,trainedVariance] = batchnorm(Y,offset,scale,trainedMean,trainedVariance); % Update state state.batchnorm3.TrainedMean = trainedMean; state.batchnorm3.TrainedVariance = trainedVariance; else Y = batchnorm(Y,offset,scale,trainedMean,trainedVariance); end % Addition, ReLU Y = YSkip + Y; Y = relu(Y); % Fully connect, softmax (labels) weights = parameters.fc1.Weights; bias = parameters.fc1.Bias; Y1 = fullyconnect(Y,weights,bias); Y1 = softmax(Y1); % Fully connect (angles) weights = parameters.fc2.Weights; bias = parameters.fc2.Bias; Y2 = fullyconnect(Y,weights,bias); end
モデル損失関数
関数 modelLoss
は、モデル パラメーター、入力データのミニバッチ X
とそれに対応するターゲット T1
および T2
(それぞれラベルと角度を含む) を入力として受け取り、損失、学習可能なパラメーターについての損失の勾配、および更新されたネットワークの状態を返します。
function [loss,gradients,state] = modelLoss(parameters,X,T1,T2,state) doTraining = true; [Y1,Y2,state] = model(parameters,X,doTraining,state); lossLabels = crossentropy(Y1,T1); lossAngles = mse(Y2,T2); loss = lossLabels + 0.1*lossAngles; gradients = dlgradient(loss,parameters); end
モデル予測関数
関数 modelPredictions
は、モデル パラメーター、ネットワークの状態、入力データの minibatchqueue
オブジェクト mbq
、およびネットワーク クラスを受け取り、doTraining
オプションを false
に設定した関数 model
を使用して、minibatchqueue
のすべてのデータを反復処理することにより、モデル予測を返します。この関数は、予測されたクラスと角度、および真の値との比較を返します。クラスについては、比較結果を、予測の正誤を表す 0 と 1 のベクトルで示します。角度については、比較結果を、予測された角度と真の値の差で示します。
function [classesPredictions,anglesPredictions,classCorr,angleDiff] = modelPredictions(parameters,state,mbq,classes) doTraining = false; classesPredictions = []; anglesPredictions = []; classCorr = []; angleDiff = []; while hasdata(mbq) [X,T1,T2] = next(mbq); % Make predictions using the model function. [Y1,Y2] = model(parameters,X,doTraining,state); % Determine predicted classes. Y1PredBatch = onehotdecode(Y1,classes,1); classesPredictions = [classesPredictions Y1PredBatch]; % Dermine predicted angles Y2PredBatch = extractdata(Y2); anglesPredictions = [anglesPredictions Y2PredBatch]; % Compare predicted and true classes Y1 = onehotdecode(T1,classes,1); classCorr = [classCorr Y1PredBatch == Y1]; % Compare predicted and true angles angleDiffBatch = Y2PredBatch - T2; angleDiff = [angleDiff extractdata(gather(angleDiffBatch))]; end end
ミニバッチ前処理関数
関数 preprocessMiniBatch
は、次の手順でデータを前処理します。
入力 cell 配列からイメージ データを抽出して数値配列に連結します。4 番目の次元でイメージ データを連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されます。
入力 cell 配列からラベルと角度データを抽出して、categorical 配列と数値配列にそれぞれ連結します。
カテゴリカル ラベルを数値配列に one-hot 符号化します。最初の次元への符号化は、ネットワーク出力の形状と一致する符号化された配列を生成します。
function [X,T,angle] = preprocessMiniBatch(dataX,dataT,dataAngle) % Extract image data from cell and concatenate X = cat(4,dataX{:}); % Extract label data from cell and concatenate T = cat(2,dataT{:}); % Extract angle data from cell and concatenate angle = cat(2,dataAngle{:}); % One-hot encode labels T = onehotencode(T,1); end
Copyright 2019–2023 The MathWorks, Inc.
参考
dlarray
| sgdmupdate
| dlfeval
| dlgradient
| fullyconnect
| dlconv
| softmax
| relu
| batchnorm
| crossentropy
| minibatchqueue
| onehotencode
| onehotdecode