Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

ベイズ ニューラル ネットワークの学習

この例では、Bayes by Backpropagation [1] を使用してイメージ回帰用にベイズ ニューラル ネットワーク (BNN) に学習させる方法を示します。BNN を使用すると、手書きの数字の回転を予測し、それらの予測の不確実性をモデル化することができます。

ベイズ ニューラル ネットワーク (BNN) は、ベイズ法を使用して深層学習ネットワークの予測の不確実性を定量化する、深層学習ネットワークの一種です。この例では、Bayes by Backpropagation (Bayes by backprop とも呼ばれます) を使用して、ニューラル ネットワークの重みの分布を推定します。単一の重みセットではなく重みの分布を使用することで、ネットワーク予測の不確実性を推定することができます。

次の図は、予測される回転角度と推定された重み分布の不確実性領域の例を示しています。

データの読み込み

数字のデータ セットを読み込みます。このデータ セットには、手書きの数字の合成イメージと各イメージに対応する回転角度 (度単位) が含まれています。

digitTrain4DArrayDatadigitTest4DArrayData を使用して学習イメージとテスト イメージを 4 次元配列として読み込みます。出力 TTrain および TTest は回転角度 (度単位) です。学習データ セットとテスト データ セットにはそれぞれ、5000 個のイメージが含まれています。

[XTrain,~,TTrain] = digitTrain4DArrayData;
[XTest,~,TTest] = digitTest4DArrayData;

学習予測子と応答を含む単一のデータストアを作成します。数値配列をデータストアに変換するには、arrayDatastoreを使用します。次に、関数combineを使用し、それらのデータストアを単一のデータストアに結合します。

dsXTrain = arrayDatastore(XTrain,IterationDimension=4);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsXTrain,dsTTrain);

応答のサイズと観測値の数を抽出します。

numResponses = size(TTrain,2)
numResponses = 1
numObservations = numel(TTrain)
numObservations = 5000

ランダムに選ばれた 64 個の学習イメージを表示します。

idx = randperm(numObservations,64);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I)

ネットワーク アーキテクチャの定義

単一の確定的なセットではなく分布を使用して重みとバイアスをモデル化するには、重みの確率分布を定義しなければなりません。ベイズの定理を使用して次のように分布を定義できます。

P(parameters|data)=P(data|parameters)×P(parameters)P(data)likelihood×prior

ここで、P(data|parameters)=L(parameters|data) は尤度、P(parameters) は事前分布です。この例では、ガウス分布 (二乗損失に相当) に従うように重みとバイアスを設定します。学習中に、ネットワークは重みとバイアスの分布を決定するガウス分布の平均と分散を学習します。

それぞれが平均 0、分散 sigma1 および sigma2 をもつ 2 つのコンポーネントを使用して、混合ガウス モデル [1] に事前分布を設定します。学習前に分散を修正することも、学習時に分散を学習させることもできます。混合モデルの両方のコンポーネントの混合比率は 0.5 です。

イメージ回帰用にベイズ ニューラル ネットワークを定義します。

  • イメージ入力用に、学習データと一致する入力サイズのイメージ入力層を指定します。

  • イメージ入力は正規化しません。入力層の Normalization オプションを "none" に設定します。

  • ReLU 活性化層を間に挟んだ 3 つのベイズ全結合層を指定します。

ベイズ全結合層は、平均重みと予想される重み分布のバイアスを格納する全結合層の一種です。層の活性化を計算するとき、ソフトウェアはランダムなガウス ノイズによって平均の重みとバイアスをシフトし、シフトされた重みとバイアスを使用して層の出力を計算します。

ベイズ全結合層を作成するには、この例にサポート ファイルとして添付されている、bayesFullyConnectedLayer.m カスタム層を使用します。ベイズ全結合層は、出力サイズと重み分布の事前確率のパラメーター sigma1sigma2 を入力として受け取ります。

ネットワークを定義します。

inputSize = [28 28 1];
outputSize = 784;

sigma1 = 1;
sigma2 = 0.5;

layers = [
    imageInputLayer(inputSize,Normalization="none")
    bayesFullyConnectedLayer(outputSize,Sigma1=sigma1,Sigma2=sigma2)
    reluLayer
    bayesFullyConnectedLayer(outputSize/2,Sigma1=sigma1,Sigma2=sigma2)
    reluLayer
    bayesFullyConnectedLayer(1,Sigma1=sigma1,Sigma2=sigma2)];

層配列からdlnetworkオブジェクトを作成します。

net = dlnetwork(layers);

analyzeNetworkを使用してネットワークを可視化します。ベイズ全結合層の学習可能なパラメーターに、重みとバイアスの平均と分散が含まれていることがわかります。

analyzeNetwork(net)

学習可能パラメーターの定義

学習可能なパラメーターには、ネットワーク (層) 学習可能パラメーターとグローバル学習可能パラメーターが含まれます。学習の際、アルゴリズムは次の学習可能パラメーターを更新します。

  • 層の重みとバイアスの平均と分散 (層ごと)

  • 重み分布の事前確率 (層ごと)

  • サンプリング ノイズ (グローバル)

サンプリング ノイズの初期化

サンプリング ノイズを使用して、ニューラル ネットワークの予測におけるノイズを表現します。ネットワークの重みとバイアスを使用してサンプリング ノイズを学習します。

サンプリング ノイズを初期化します。

samplingNoise = dlarray(1);

事前確率の初期化

事前分散パラメーターを修正することも、他の学習可能パラメーターのように学習時に学習させることもできます。学習時には低い学習率で事前パラメーターを学習させ、その値が初期値から離れないようにします。初期学習率を 0.25 に設定します。

doLearnPrior = true;
priorLearnRate = 0.25;

numLearnables = size(net.Learnables,1);

for i=1:numLearnables
    layerName = net.Learnables.Layer(i);
    parameterName = net.Learnables.Parameter(i);

    if parameterName == "Sigma1" || parameterName == "Sigma2"
        if doLearnPrior
            net = setLearnRateFactor(net,layerName,parameterName,priorLearnRate);
        else
            net = setLearnRateFactor(net,layerName,parameterName,0);
        end
    end
end

モデル損失関数の定義

モデルの損失と、学習可能パラメーターについての損失の勾配を返す関数を定義します。この例では、証拠下限損失のセクションで定義された証拠下限 (ELBO) 損失を最小限に抑えます。

モデル損失関数のセクションにリストされている関数 modelLoss を作成します。関数は、dlnetwork オブジェクトと、対応するターゲットを含む入力データのミニバッチを入力として受け取ります。関数は次の値を返します。

  • ELBO 損失

  • 平方根平均二乗誤差 (RMSE)

  • 学習可能パラメーターに対する損失の勾配

  • サンプリング ノイズに対する損失の勾配

  • ネットワークの状態

学習オプションの指定

ミニバッチ サイズを 128 として 50 エポック学習させます。

numEpochs = 50;
miniBatchSize = 128;

学習時の ELBO 損失を追跡します。50 回の反復ごとに損失をプロットし、学習可能パラメーターの 5 つのサンプルにわたる損失を平均します。

numSamplesForavgELBO = 5;
averageLossComputationFrequency = 50;

モデルの学習

イメージのミニバッチを処理および管理するためのminibatchqueueオブジェクトを作成します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatch (この例の最後に定義) を使用して、クラス ラベルを one-hot 符号化します。

  • イメージ データを次元ラベル "SSCB" (spatial、spatial、channel、batch) で書式設定します。既定では、minibatchqueue オブジェクトは、基となる型が single である dlarray オブジェクトにデータを変換します。

  • GPU が利用できる場合、GPU で学習を行います。既定では、minibatchqueue オブジェクトは、GPU が利用可能な場合、各出力を gpuArray オブジェクトに変換します。GPU を使用するには、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。

mbq = minibatchqueue(dsTrain, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["SSCB" "CB"]);

Adam 最適化のパラメーターを初期化します。

trailingAvg = [];
trailingAvgSq = [];
trailingAvgNoise = [];
trailingAvgNoiseSq = [];

学習の進行状況モニター用に合計反復回数を計算します。

numIterationsPerEpoch = ceil(numObservations/miniBatchSize);
numIterations = numEpochs*numIterationsPerEpoch;

学習の進行状況モニターを初期化します。

monitor = trainingProgressMonitor( ...
    Metrics=["RMSE","AverageELBOLoss"], ...
    Info="Epoch", ...
    XLabel="Iteration");

カスタム学習ループを使用してモデルに学習させます。各エポックについて、データをシャッフルしてデータのミニバッチをループで回します。反復が終了するたびに、学習の進行状況を表示します。各ミニバッチで次を行います。

  • 関数dlfevalおよび modelLoss を使用してモデルの損失と勾配を評価します。

  • 関数adamupdateを使用してネットワーク パラメーターを更新します。

  • 関数 adamupdate を使用して、サンプリング ノイズ パラメーター (グローバル パラメーター) を更新します。

  • RMSE と平均 ELBO 損失を記録します。

iteration = 0;
epoch = 0;

% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;
    miniBatchIdx = 0;

    % Shuffle data.
    shuffle(mbq);

    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;
        miniBatchIdx = miniBatchIdx + 1;

        [X,T] = next(mbq);

        [elboLoss,rmsError,gradientsNet,gradientsNoise] = dlfeval(@modelLoss, ...
            net,X,T,samplingNoise,miniBatchIdx,numIterationsPerEpoch);

        % Update the network parameters using the Adam optimizer.
        [net,trailingAvg,trailingAvgSq] = adamupdate(net,gradientsNet, ...
            trailingAvg,trailingAvgSq,iteration);

        % Update the sampling noise.
        [samplingNoise,trailingAvgNoise,trailingAvgNoiseSq] = adamupdate(samplingNoise, ...
            gradientsNoise,trailingAvgNoise,trailingAvgNoiseSq,iteration);

        % Record the RMSE.
        recordMetrics(monitor,iteration,RMSE=double(rmsError))

        % Record the average ELBO loss.
        if mod(iteration,averageLossComputationFrequency) == 0
            avgELBOLoss = averageNegativeELBO(net,X,T,samplingNoise,miniBatchIdx, ...
                numIterationsPerEpoch,numSamplesForavgELBO);

            recordMetrics(monitor,iteration,AverageELBOLoss=double(avgELBOLoss))
        end

        % Update the epoch and progress values in the monitor.
        updateInfo(monitor,Epoch=string(epoch) + " of " + string(numEpochs))
        monitor.Progress = 100*(iteration/numIterations);
    end
end

ネットワークのテスト

BNN は、畳み込みニューラル ネットワークのように単一セットの重みを最適化するのではなく、重みの確率分布を学習します。したがって、学習可能パラメーターの学習済み確率分布から各ネットワークをサンプリングするネットワークのアンサンブルとして BNN を見ることができます。

BNN の精度をテストするには、重みとバイアスの N 個のサンプルを生成し、N 個のサンプルにわたる平均予測と真の値を比較します。N 個の予測間の標準偏差がモデルの不確実性です。モデル予測関数のセクションにリストされている関数 modelPosteriorSample を使用して、入力セットの予測を生成します。関数は、重みとバイアスの事後分布から N 回サンプリングします。N 個のサンプルのそれぞれについて、関数は入力イメージの予測を生成します。BNN からの予測では、重みとバイアスのサンプルを使用します。したがって、予測には多少の変動ノイズが含まれます。

テスト データを dlarray オブジェクトに変換します。

XTest = dlarray(XTest,"SSCB");
if canUseGPU
    XTest = gpuArray(XTest);
end

単一イメージのテスト

modelPosteriorSample を使用して、最初のテスト イメージについて 10 個のサンプルを生成します。関数は、回転角度に対する 10 個の予測を返します。最終的なモデル予測は、10 個の予測の平均値です。

idx = 1;
numSamples = 10;
img = XTest(:,:,:,idx);

predictions = modelPosteriorSample(net,img,samplingNoise,numSamples);
YTestImg = mean(predictions,1);

真の角度、予測された角度、および予測の平均をプロットします。

figure
lineWidth = 1.5;
uncertaintyColor = "#EDB120";

I = extractdata(img);
imshow(I,InitialMagnification=800)
hold on

inputSize = size(img,1);
offset = inputSize/2;

thetaActual = TTest(idx);
plot(offset*[1 - tand(thetaActual),1 + tand(thetaActual)],[inputSize 0], ...
    LineWidth=lineWidth)

thetaPredAvg = YTestImg;
plot(offset*[1 - tand(thetaPredAvg),1 + tand(thetaPredAvg)],[inputSize 0], ...
    LineWidth=lineWidth)

for i=1:numSamples
    thetaPred = predictions(i);
    plot(offset*[1 - tand(thetaPred),1 + tand(thetaPred)],[inputSize 0],"--", ...
        Color=uncertaintyColor)
end

hold off
title("Pred: " + round(thetaPredAvg,2)+" (Mean)" + ", True: " + round(thetaActual,2))
legend(["True","Mean Prediction","Prediction"],Location="southeast")

サンプル数を 500 に増やし、テスト イメージの予測回転角の分布をプロットします。

numSamples = 500;

predictions = modelPosteriorSample(net,img,samplingNoise,numSamples);
YTestImg = mean(predictions,1);
uncertaintyImg = std(predictions,1);

figure
histogram(predictions)

trueColor = "#0072BD";
predColor = "#D95319";

hold on
xline(TTest(idx),Color=trueColor,LineWidth=lineWidth)
xline(YTestImg,Color=predColor,LineWidth=lineWidth)
xline(YTestImg - 2*uncertaintyImg,"--",Color=uncertaintyColor,LineWidth=lineWidth)
xline(YTestImg + 2*uncertaintyImg,"--",Color=uncertaintyColor,LineWidth=lineWidth)
hold off

xlabel("Angle of Rotation")
ylabel("Frequency")
title("Distribution of Predictions (Number of Samples = " + numSamples + ")")
legend("","True","Mean Prediction","+-" + "2\sigma (Standard Deviation)")

すべてのイメージのテスト

学習可能パラメーターの 100 個のサンプルを使用して、各テスト イメージの回転角度を予測します。

numSamples = 100;
predictions = modelPosteriorSample(net,XTest,samplingNoise,numSamples);
YTest = mean(predictions,1);
uncertainty = std(predictions,1);

真の回転角度と予測された回転角度の間の予測誤差を計算します。

predictionError = TTest - YTest';

RMSE を使用して、真の回転角度と予測された回転角度の差を測定します。

squares = predictionError.^2;
rmse = sqrt(mean(squares))
rmse = 15.0308

予測角度と不確実性の可視化

いくつかのイメージを予測角度と真の角度で表示します。モデル予測の標準偏差を使用して、予測の不確実性を示します。

numTestImages = numel(TTest);
numObservationToShow = 9;
idxTestSubset = randperm(numTestImages,numObservationToShow);

sdToPlot = 2;

tiledlayout("flow",TileSpacing="tight");

for i = 1:numObservationToShow
    idx = idxTestSubset(i);

    nexttile
    I = extractdata(XTest(:,:,:,idx));
    imshow(I)
    hold on

    thetaActual = TTest(idx);
    plot(offset*[1 - tand(thetaActual),1 + tand(thetaActual)],[inputSize 0],LineWidth=lineWidth)

    thetaPred = YTest(idx);
    plot(offset*[1 - tand(thetaPred),1 + tand(thetaPred)],[inputSize 0],LineWidth=lineWidth)

    thetaUncertainty = [thetaPred - sdToPlot*uncertainty(idx),thetaPred + sdToPlot*uncertainty(idx)];

    % Plot upper and lower bounds.
    lowerBound = [1 - tand(thetaUncertainty(1)),1 + tand(thetaUncertainty(1))];
    upperBound = [1 - tand(thetaUncertainty(2)),1 + tand(thetaUncertainty(2))];
    plot(offset*lowerBound,[inputSize 0],"--",Color=uncertaintyColor,LineWidth=lineWidth)
    plot(offset*upperBound,[inputSize 0],"--",Color=uncertaintyColor,LineWidth=lineWidth)

    hold off
    title({"True = " + round(thetaActual,2),"Pred: " + round(thetaPred,2)})
    if i == 2
        legend(["True","Mean Prediction","+-" + sdToPlot + "\sigma (Standard Deviation)"], ...
            Location="northoutside", ...
            NumColumns=3)
    end
end

サポート関数

ミニバッチ前処理関数

関数 preprocessMiniBatch は、次の手順を使用してデータを前処理します。

  1. 入力 cell 配列 dataX からイメージ データを抽出して、それを数値配列に連結します。4 番目の次元でイメージ データを連結することにより、ネットワークがシングルトン チャネル次元として使用できるように、各イメージに 3 番目の次元が追加されます。

  2. 入力 cell 配列 dataAng から角度データを抽出し、それを 2 番目の次元に沿って数値配列に連結します。

function [X,A] = preprocessMiniBatch(dataX,dataAng)

X = cat(4,dataX{:});
A = cat(2,dataAng{:});

end

モデル予測関数

関数 modelPosteriorSample は、dlnetwork オブジェクト net、入力イメージ X、サンプリング ノイズ samplingNoise、およびサンプル数を入力として受け取り、numSamples を生成します。関数は、入力イメージに対して numSample 個の予測を返します。

function predictions = modelPosteriorSample(net,X,samplingNoise,numSamples)

predictions = zeros(numSamples,size(X,4));

for i=1:numSamples
    Y = predict(net,X,Acceleration="none");
    sigmaY = exp(samplingNoise);
    predictions(i,:) = Y + sigmaY.*randn(size(Y));
end

end

最尤推定関数

関数 logLikelihood は、真の値とサンプリング ノイズを考慮して、ネットワーク予測の尤度を推定します。関数は、予測 Y、真の値 T、およびサンプリング ノイズ samplingNoise を入力として受け取り、対数尤度 l を返します。

function l = logLikelihood(Y,T,samplingNoise)

sigmaY = exp(samplingNoise);
l = sum(logProbabilityNormal(T,Y,sigmaY),"all");

end

モデル損失関数

関数 modelLoss は、dlnetwork オブジェクト net、対応するターゲット T を含む入力データ X のミニバッチ、サンプリング ノイズ samplingNoise、ミニバッチ インデックス miniBatchIdx、およびバッチ数 numBatches を入力として受け取ります。関数は、ELBO 損失、RMSE 損失、学習可能なパラメーターに関する損失の勾配、およびサンプリング ノイズに関する損失の勾配を返します。

function [elboLoss,meanError,gradientsNet,gradientsNoise] = modelLoss(net,X,T,samplingNoise,miniBatchIdx,numBatches)

[elboLoss,Y] = negativeELBO(net,X,T,samplingNoise,miniBatchIdx,numBatches);

[gradientsNet,gradientsNoise] = dlgradient(elboLoss,net.Learnables,samplingNoise);

meanError = double(sqrt(mse(Y,T)));

end

証拠下限 (ELBO) 損失関数

関数 negativeELBO は、与えられたミニバッチの ELBO 損失を計算します。

ELBO 損失は、以下の目的を兼ねています。

  • ネットワーク予測の尤度を最大化する。

  • 変分分布 q(w|θ) と事後分布の間のカルバック・ライブラー (KL) ダイバージェンスを最小化する。変分分布 q(w|θ) は真の事後分布に近似し、学習時の計算量を軽減します。

関数 negativeELBO は、dlnetwork オブジェクト net、対応するターゲット T を含む入力データ X のミニバッチ、サンプリング ノイズ samplingNoise、ミニバッチ インデックス miniBatchIdx、およびバッチ数 numBatches を入力として受け取ります。関数は、ELBO 損失 ELBO とフォワード パスの結果 (ネットワーク予測) Y を返します。

function [ELBO,Y] = negativeELBO(net,X,T,samplingNoise,miniBatchIdx,numBatches)

[Y,state] = forward(net,X,Acceleration="auto");

beta = KLWeight(miniBatchIdx,numBatches);

logPosterior = state.Value(state.Parameter == "LogPosterior");
logPosterior = sum([logPosterior{:}]);
logPrior = state.Value(state.Parameter == "LogPrior");
logPrior = sum([logPrior{:}]);

l = logLikelihood(Y,T,samplingNoise) ;

ELBO = (-1*l) + ((logPosterior - logPrior)*beta);

end

平均 ELBO 損失

関数 averageNegativeELBO は、dlnetwork オブジェクト net、対応するターゲット T を含む入力データ X のミニバッチ、サンプリング ノイズ samplingNoise、ミニバッチ インデックス miniBatchIdx、バッチ数 numBatches、およびサンプル数 numSamples を入力として受け取ります。関数は、ELBO 損失の numSamples 個のサンプル全体で平均化した ELBO 損失を返します。

function avgELBO = averageNegativeELBO(net,X,T,samplingNoise,miniBatchIdx,numBatches,numSamples)

avgELBO = 0;

for i=1: numSamples
    ELBO = negativeELBO(net,X,T,samplingNoise,miniBatchIdx,numBatches);
    avgELBO = avgELBO + ELBO;
end

avgELBO = avgELBO/numSamples;

end

ミニバッチと KL 再重み付け

関数 KLWeight は、現在のバッチ インデックス i とバッチの総数 m を入力として受け取ります。関数は、現在のバッチの KL sum をスケールするために使用できる、範囲 [0, 1] のスカラー値 beta を返します。

次の再重み付け戦略を使用して、各ミニバッチのコストを最小化します。

β[0,1]M and i=1Mβi=1,

ここで、βi=2M-i2M-1 です。

β は、重みの事後分布の推定値に対するスケーリング係数です [1]

function beta = KLWeight(i,m)

beta = 2^(m - i)/(2^m - 1);

end

参考文献

[1] Blundell, Charles, Julien Cornebise, Koray Kavukcuoglu, and Daan Wierstra, “Weight Uncertainty in Neural Networks”. arXiv preprint arXiv:1505.05424 (May 2015)., https://arxiv.org/abs/1505.05424.

参考

| | | |

関連するトピック