このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。
ベイズ ニューラル ネットワークの学習
この例では、Bayes by Backpropagation [1] を使用してイメージ回帰用にベイズ ニューラル ネットワーク (BNN) に学習させる方法を示します。BNN を使用すると、手書きの数字の回転を予測し、それらの予測の不確実性をモデル化することができます。
ベイズ ニューラル ネットワーク (BNN) は、ベイズ法を使用して深層学習ネットワークの予測の不確実性を定量化する、深層学習ネットワークの一種です。この例では、Bayes by Backpropagation (Bayes by backprop とも呼ばれます) を使用して、ニューラル ネットワークの重みの分布を推定します。単一の重みセットではなく重みの分布を使用することで、ネットワーク予測の不確実性を推定することができます。
次の図は、予測される回転角度と推定された重み分布の不確実性領域の例を示しています。
データの読み込み
数字のデータ セットを読み込みます。このデータ セットには、手書きの数字の合成イメージと各イメージに対応する回転角度 (度単位) が含まれています。
digitTrain4DArrayData
と digitTest4DArrayData
を使用して学習イメージとテスト イメージを 4 次元配列として読み込みます。出力 TTrain
および TTest
は回転角度 (度単位) です。学習データ セットとテスト データ セットにはそれぞれ、5000 個のイメージが含まれています。
[XTrain,~,TTrain] = digitTrain4DArrayData; [XTest,~,TTest] = digitTest4DArrayData;
学習予測子と応答を含む単一のデータストアを作成します。数値配列をデータストアに変換するには、arrayDatastore
を使用します。次に、関数combine
を使用し、それらのデータストアを単一のデータストアに結合します。
dsXTrain = arrayDatastore(XTrain,IterationDimension=4); dsTTrain = arrayDatastore(TTrain); dsTrain = combine(dsXTrain,dsTTrain);
応答のサイズと観測値の数を抽出します。
numResponses = size(TTrain,2)
numResponses = 1
numObservations = numel(TTrain)
numObservations = 5000
ランダムに選ばれた 64 個の学習イメージを表示します。
idx = randperm(numObservations,64); I = imtile(XTrain(:,:,:,idx)); figure imshow(I)
ネットワーク アーキテクチャの定義
単一の確定的なセットではなく分布を使用して重みとバイアスをモデル化するには、重みの確率分布を定義しなければなりません。ベイズの定理を使用して次のように分布を定義できます。
ここで、 は尤度、 は事前分布です。この例では、ガウス分布 (二乗損失に相当) に従うように重みとバイアスを設定します。学習中に、ネットワークは重みとバイアスの分布を決定するガウス分布の平均と分散を学習します。
それぞれが平均 0、分散 sigma1
および sigma2
をもつ 2 つのコンポーネントを使用して、混合ガウス モデル [1] に事前分布を設定します。学習前に分散を修正することも、学習時に分散を学習させることもできます。混合モデルの両方のコンポーネントの混合比率は 0.5 です。
イメージ回帰用にベイズ ニューラル ネットワークを定義します。
イメージ入力用に、学習データと一致する入力サイズのイメージ入力層を指定します。
イメージ入力は正規化しません。入力層の
Normalization
オプションを"none"
に設定します。ReLU 活性化層を間に挟んだ 3 つのベイズ全結合層を指定します。
ベイズ全結合層は、平均重みと予想される重み分布のバイアスを格納する全結合層の一種です。層の活性化を計算するとき、ソフトウェアはランダムなガウス ノイズによって平均の重みとバイアスをシフトし、シフトされた重みとバイアスを使用して層の出力を計算します。
ベイズ全結合層を作成するには、この例にサポート ファイルとして添付されている、bayesFullyConnectedLayer.m
カスタム層を使用します。ベイズ全結合層は、出力サイズと重み分布の事前確率のパラメーター sigma1
と sigma2
を入力として受け取ります。
ネットワークを定義します。
inputSize = [28 28 1];
outputSize = 784;
sigma1 = 1;
sigma2 = 0.5;
layers = [
imageInputLayer(inputSize,Normalization="none")
bayesFullyConnectedLayer(outputSize,Sigma1=sigma1,Sigma2=sigma2)
reluLayer
bayesFullyConnectedLayer(outputSize/2,Sigma1=sigma1,Sigma2=sigma2)
reluLayer
bayesFullyConnectedLayer(1,Sigma1=sigma1,Sigma2=sigma2)];
層配列からdlnetwork
オブジェクトを作成します。
net = dlnetwork(layers);
analyzeNetwork
を使用してネットワークを可視化します。ベイズ全結合層の学習可能なパラメーターに、重みとバイアスの平均と分散が含まれていることがわかります。
analyzeNetwork(net)
学習可能パラメーターの定義
学習可能なパラメーターには、ネットワーク (層) 学習可能パラメーターとグローバル学習可能パラメーターが含まれます。学習の際、アルゴリズムは次の学習可能パラメーターを更新します。
層の重みとバイアスの平均と分散 (層ごと)
重み分布の事前確率 (層ごと)
サンプリング ノイズ (グローバル)
サンプリング ノイズの初期化
サンプリング ノイズを使用して、ニューラル ネットワークの予測におけるノイズを表現します。ネットワークの重みとバイアスを使用してサンプリング ノイズを学習します。
サンプリング ノイズを初期化します。
samplingNoise = dlarray(1);
事前確率の初期化
事前分散パラメーターを修正することも、他の学習可能パラメーターのように学習時に学習させることもできます。学習時には低い学習率で事前パラメーターを学習させ、その値が初期値から離れないようにします。初期学習率を 0.25 に設定します。
doLearnPrior = true; priorLearnRate = 0.25; numLearnables = size(net.Learnables,1); for i=1:numLearnables layerName = net.Learnables.Layer(i); parameterName = net.Learnables.Parameter(i); if parameterName == "Sigma1" || parameterName == "Sigma2" if doLearnPrior net = setLearnRateFactor(net,layerName,parameterName,priorLearnRate); else net = setLearnRateFactor(net,layerName,parameterName,0); end end end
モデル損失関数の定義
モデルの損失と、学習可能パラメーターについての損失の勾配を返す関数を定義します。この例では、証拠下限損失のセクションで定義された証拠下限 (ELBO) 損失を最小限に抑えます。
モデル損失関数のセクションにリストされている関数 modelLoss
を作成します。関数は、dlnetwork
オブジェクトと、対応するターゲットを含む入力データのミニバッチを入力として受け取ります。関数は次の値を返します。
ELBO 損失
平方根平均二乗誤差 (RMSE)
学習可能パラメーターに対する損失の勾配
サンプリング ノイズに対する損失の勾配
ネットワークの状態
学習オプションの指定
ミニバッチ サイズを 128 として 50 エポック学習させます。
numEpochs = 50; miniBatchSize = 128;
学習時の ELBO 損失を追跡します。50 回の反復ごとに損失をプロットし、学習可能パラメーターの 5 つのサンプルにわたる損失を平均します。
numSamplesForavgELBO = 5; averageLossComputationFrequency = 50;
モデルの学習
イメージのミニバッチを処理および管理するためのminibatchqueue
オブジェクトを作成します。各ミニバッチで次を行います。
カスタム ミニバッチ前処理関数
preprocessMiniBatch
(この例の最後に定義) を使用して、クラス ラベルを one-hot 符号化します。イメージ データを次元ラベル
"SSCB"
(spatial、spatial、channel、batch) で書式設定します。既定では、minibatchqueue
オブジェクトは、基となる型が single であるdlarray
オブジェクトにデータを変換します。GPU が利用できる場合、GPU で学習を行います。既定では、
minibatchqueue
オブジェクトは、GPU が利用可能な場合、各出力をgpuArray
オブジェクトに変換します。GPU を使用するには、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。
mbq = minibatchqueue(dsTrain, ... MiniBatchSize=miniBatchSize, ... MiniBatchFcn=@preprocessMiniBatch, ... MiniBatchFormat=["SSCB" "CB"]);
Adam 最適化のパラメーターを初期化します。
trailingAvg = []; trailingAvgSq = []; trailingAvgNoise = []; trailingAvgNoiseSq = [];
学習の進行状況モニター用に合計反復回数を計算します。
numIterationsPerEpoch = ceil(numObservations/miniBatchSize); numIterations = numEpochs*numIterationsPerEpoch;
学習の進行状況モニターを初期化します。
monitor = trainingProgressMonitor( ... Metrics=["RMSE","AverageELBOLoss"], ... Info="Epoch", ... XLabel="Iteration");
カスタム学習ループを使用してモデルに学習させます。各エポックについて、データをシャッフルしてデータのミニバッチをループで回します。反復が終了するたびに、学習の進行状況を表示します。各ミニバッチで次を行います。
関数
dlfeval
およびmodelLoss
を使用してモデルの損失と勾配を評価します。関数
adamupdate
を使用してネットワーク パラメーターを更新します。関数
adamupdate
を使用して、サンプリング ノイズ パラメーター (グローバル パラメーター) を更新します。RMSE と平均 ELBO 損失を記録します。
iteration = 0; epoch = 0; % Loop over epochs. while epoch < numEpochs && ~monitor.Stop epoch = epoch + 1; miniBatchIdx = 0; % Shuffle data. shuffle(mbq); while hasdata(mbq) && ~monitor.Stop iteration = iteration + 1; miniBatchIdx = miniBatchIdx + 1; [X,T] = next(mbq); [elboLoss,rmsError,gradientsNet,gradientsNoise] = dlfeval(@modelLoss, ... net,X,T,samplingNoise,miniBatchIdx,numIterationsPerEpoch); % Update the network parameters using the Adam optimizer. [net,trailingAvg,trailingAvgSq] = adamupdate(net,gradientsNet, ... trailingAvg,trailingAvgSq,iteration); % Update the sampling noise. [samplingNoise,trailingAvgNoise,trailingAvgNoiseSq] = adamupdate(samplingNoise, ... gradientsNoise,trailingAvgNoise,trailingAvgNoiseSq,iteration); % Record the RMSE. recordMetrics(monitor,iteration,RMSE=double(rmsError)) % Record the average ELBO loss. if mod(iteration,averageLossComputationFrequency) == 0 avgELBOLoss = averageNegativeELBO(net,X,T,samplingNoise,miniBatchIdx, ... numIterationsPerEpoch,numSamplesForavgELBO); recordMetrics(monitor,iteration,AverageELBOLoss=double(avgELBOLoss)) end % Update the epoch and progress values in the monitor. updateInfo(monitor,Epoch=string(epoch) + " of " + string(numEpochs)) monitor.Progress = 100*(iteration/numIterations); end end
ネットワークのテスト
BNN は、畳み込みニューラル ネットワークのように単一セットの重みを最適化するのではなく、重みの確率分布を学習します。したがって、学習可能パラメーターの学習済み確率分布から各ネットワークをサンプリングするネットワークのアンサンブルとして BNN を見ることができます。
BNN の精度をテストするには、重みとバイアスの 個のサンプルを生成し、 個のサンプルにわたる平均予測と真の値を比較します。 個の予測間の標準偏差がモデルの不確実性です。モデル予測関数のセクションにリストされている関数 modelPosteriorSample
を使用して、入力セットの予測を生成します。関数は、重みとバイアスの事後分布から 回サンプリングします。 個のサンプルのそれぞれについて、関数は入力イメージの予測を生成します。BNN からの予測では、重みとバイアスのサンプルを使用します。したがって、予測には多少の変動ノイズが含まれます。
テスト データを dlarray
オブジェクトに変換します。
XTest = dlarray(XTest,"SSCB"); if canUseGPU XTest = gpuArray(XTest); end
単一イメージのテスト
modelPosteriorSample
を使用して、最初のテスト イメージについて 10 個のサンプルを生成します。関数は、回転角度に対する 10 個の予測を返します。最終的なモデル予測は、10 個の予測の平均値です。
idx = 1; numSamples = 10; img = XTest(:,:,:,idx); predictions = modelPosteriorSample(net,img,samplingNoise,numSamples); YTestImg = mean(predictions,1);
真の角度、予測された角度、および予測の平均をプロットします。
figure lineWidth = 1.5; uncertaintyColor = "#EDB120"; I = extractdata(img); imshow(I,InitialMagnification=800) hold on inputSize = size(img,1); offset = inputSize/2; thetaActual = TTest(idx); plot(offset*[1 - tand(thetaActual),1 + tand(thetaActual)],[inputSize 0], ... LineWidth=lineWidth) thetaPredAvg = YTestImg; plot(offset*[1 - tand(thetaPredAvg),1 + tand(thetaPredAvg)],[inputSize 0], ... LineWidth=lineWidth) for i=1:numSamples thetaPred = predictions(i); plot(offset*[1 - tand(thetaPred),1 + tand(thetaPred)],[inputSize 0],"--", ... Color=uncertaintyColor) end hold off title("Pred: " + round(thetaPredAvg,2)+" (Mean)" + ", True: " + round(thetaActual,2)) legend(["True","Mean Prediction","Prediction"],Location="southeast")
サンプル数を 500 に増やし、テスト イメージの予測回転角の分布をプロットします。
numSamples = 500; predictions = modelPosteriorSample(net,img,samplingNoise,numSamples); YTestImg = mean(predictions,1); uncertaintyImg = std(predictions,1); figure histogram(predictions) trueColor = "#0072BD"; predColor = "#D95319"; hold on xline(TTest(idx),Color=trueColor,LineWidth=lineWidth) xline(YTestImg,Color=predColor,LineWidth=lineWidth) xline(YTestImg - 2*uncertaintyImg,"--",Color=uncertaintyColor,LineWidth=lineWidth) xline(YTestImg + 2*uncertaintyImg,"--",Color=uncertaintyColor,LineWidth=lineWidth) hold off xlabel("Angle of Rotation") ylabel("Frequency") title("Distribution of Predictions (Number of Samples = " + numSamples + ")") legend("","True","Mean Prediction","+-" + "2\sigma (Standard Deviation)")
すべてのイメージのテスト
学習可能パラメーターの 100 個のサンプルを使用して、各テスト イメージの回転角度を予測します。
numSamples = 100; predictions = modelPosteriorSample(net,XTest,samplingNoise,numSamples); YTest = mean(predictions,1); uncertainty = std(predictions,1);
真の回転角度と予測された回転角度の間の予測誤差を計算します。
predictionError = TTest - YTest';
RMSE を使用して、真の回転角度と予測された回転角度の差を測定します。
squares = predictionError.^2; rmse = sqrt(mean(squares))
rmse = 15.0308
予測角度と不確実性の可視化
いくつかのイメージを予測角度と真の角度で表示します。モデル予測の標準偏差を使用して、予測の不確実性を示します。
numTestImages = numel(TTest); numObservationToShow = 9; idxTestSubset = randperm(numTestImages,numObservationToShow); sdToPlot = 2; tiledlayout("flow",TileSpacing="tight"); for i = 1:numObservationToShow idx = idxTestSubset(i); nexttile I = extractdata(XTest(:,:,:,idx)); imshow(I) hold on thetaActual = TTest(idx); plot(offset*[1 - tand(thetaActual),1 + tand(thetaActual)],[inputSize 0],LineWidth=lineWidth) thetaPred = YTest(idx); plot(offset*[1 - tand(thetaPred),1 + tand(thetaPred)],[inputSize 0],LineWidth=lineWidth) thetaUncertainty = [thetaPred - sdToPlot*uncertainty(idx),thetaPred + sdToPlot*uncertainty(idx)]; % Plot upper and lower bounds. lowerBound = [1 - tand(thetaUncertainty(1)),1 + tand(thetaUncertainty(1))]; upperBound = [1 - tand(thetaUncertainty(2)),1 + tand(thetaUncertainty(2))]; plot(offset*lowerBound,[inputSize 0],"--",Color=uncertaintyColor,LineWidth=lineWidth) plot(offset*upperBound,[inputSize 0],"--",Color=uncertaintyColor,LineWidth=lineWidth) hold off title({"True = " + round(thetaActual,2),"Pred: " + round(thetaPred,2)}) if i == 2 legend(["True","Mean Prediction","+-" + sdToPlot + "\sigma (Standard Deviation)"], ... Location="northoutside", ... NumColumns=3) end end
サポート関数
ミニバッチ前処理関数
関数 preprocessMiniBatch
は、次の手順を使用してデータを前処理します。
入力 cell 配列
dataX
からイメージ データを抽出して、それを数値配列に連結します。4 番目の次元でイメージ データを連結することにより、ネットワークがシングルトン チャネル次元として使用できるように、各イメージに 3 番目の次元が追加されます。入力 cell 配列
dataAng
から角度データを抽出し、それを 2 番目の次元に沿って数値配列に連結します。
function [X,A] = preprocessMiniBatch(dataX,dataAng) X = cat(4,dataX{:}); A = cat(2,dataAng{:}); end
モデル予測関数
関数 modelPosteriorSample
は、dlnetwork
オブジェクト net
、入力イメージ X
、サンプリング ノイズ samplingNoise
、およびサンプル数を入力として受け取り、numSamples
を生成します。関数は、入力イメージに対して numSample
個の予測を返します。
function predictions = modelPosteriorSample(net,X,samplingNoise,numSamples) predictions = zeros(numSamples,size(X,4)); for i=1:numSamples Y = predict(net,X,Acceleration="none"); sigmaY = exp(samplingNoise); predictions(i,:) = Y + sigmaY.*randn(size(Y)); end end
最尤推定関数
関数 logLikelihood
は、真の値とサンプリング ノイズを考慮して、ネットワーク予測の尤度を推定します。関数は、予測 Y
、真の値 T
、およびサンプリング ノイズ samplingNoise
を入力として受け取り、対数尤度 l
を返します。
function l = logLikelihood(Y,T,samplingNoise) sigmaY = exp(samplingNoise); l = sum(logProbabilityNormal(T,Y,sigmaY),"all"); end
モデル損失関数
関数 modelLoss
は、dlnetwork
オブジェクト net
、対応するターゲット T
を含む入力データ X
のミニバッチ、サンプリング ノイズ samplingNoise
、ミニバッチ インデックス miniBatchIdx
、およびバッチ数 numBatches
を入力として受け取ります。関数は、ELBO 損失、RMSE 損失、学習可能なパラメーターに関する損失の勾配、およびサンプリング ノイズに関する損失の勾配を返します。
function [elboLoss,meanError,gradientsNet,gradientsNoise] = modelLoss(net,X,T,samplingNoise,miniBatchIdx,numBatches) [elboLoss,Y] = negativeELBO(net,X,T,samplingNoise,miniBatchIdx,numBatches); [gradientsNet,gradientsNoise] = dlgradient(elboLoss,net.Learnables,samplingNoise); meanError = double(sqrt(mse(Y,T))); end
証拠下限 (ELBO) 損失関数
関数 negativeELBO
は、与えられたミニバッチの ELBO 損失を計算します。
ELBO 損失は、以下の目的を兼ねています。
ネットワーク予測の尤度を最大化する。
変分分布 と事後分布の間のカルバック・ライブラー (KL) ダイバージェンスを最小化する。変分分布 は真の事後分布に近似し、学習時の計算量を軽減します。
関数 negativeELBO
は、dlnetwork
オブジェクト net
、対応するターゲット T
を含む入力データ X
のミニバッチ、サンプリング ノイズ samplingNoise
、ミニバッチ インデックス miniBatchIdx
、およびバッチ数 numBatches
を入力として受け取ります。関数は、ELBO 損失 ELBO
とフォワード パスの結果 (ネットワーク予測) Y
を返します。
function [ELBO,Y] = negativeELBO(net,X,T,samplingNoise,miniBatchIdx,numBatches) [Y,state] = forward(net,X,Acceleration="auto"); beta = KLWeight(miniBatchIdx,numBatches); logPosterior = state.Value(state.Parameter == "LogPosterior"); logPosterior = sum([logPosterior{:}]); logPrior = state.Value(state.Parameter == "LogPrior"); logPrior = sum([logPrior{:}]); l = logLikelihood(Y,T,samplingNoise) ; ELBO = (-1*l) + ((logPosterior - logPrior)*beta); end
平均 ELBO 損失
関数 averageNegativeELBO
は、dlnetwork
オブジェクト net
、対応するターゲット T
を含む入力データ X
のミニバッチ、サンプリング ノイズ samplingNoise
、ミニバッチ インデックス miniBatchIdx
、バッチ数 numBatches
、およびサンプル数 numSamples
を入力として受け取ります。関数は、ELBO 損失の numSamples
個のサンプル全体で平均化した ELBO 損失を返します。
function avgELBO = averageNegativeELBO(net,X,T,samplingNoise,miniBatchIdx,numBatches,numSamples) avgELBO = 0; for i=1: numSamples ELBO = negativeELBO(net,X,T,samplingNoise,miniBatchIdx,numBatches); avgELBO = avgELBO + ELBO; end avgELBO = avgELBO/numSamples; end
ミニバッチと KL 再重み付け
関数 KLWeight
は、現在のバッチ インデックス i
とバッチの総数 m
を入力として受け取ります。関数は、現在のバッチの KL sum をスケールするために使用できる、範囲 [0, 1] のスカラー値 beta
を返します。
次の再重み付け戦略を使用して、各ミニバッチのコストを最小化します。
,
ここで、 です。
は、重みの事後分布の推定値に対するスケーリング係数です [1]。
function beta = KLWeight(i,m) beta = 2^(m - i)/(2^m - 1); end
参考文献
[1] Blundell, Charles, Julien Cornebise, Koray Kavukcuoglu, and Daan Wierstra, “Weight Uncertainty in Neural Networks”. arXiv preprint arXiv:1505.05424 (May 2015)., https://arxiv.org/abs/1505.05424.
参考
dlnetwork
| dlarray
| minibatchqueue
| dlfeval
| adamupdate