Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

回帰用の畳み込みニューラル ネットワークの学習

この例では、畳み込みニューラル ネットワークを使用して回帰モデルを当てはめ、手書きの数字の回転角度を予測する方法を示します。

畳み込みニューラル ネットワーク (CNN または ConvNet) は深層学習に不可欠なツールであり、特にイメージ データの解析に適しています。たとえば、CNN を使用してイメージを分類できます。角度や距離などの連続データを予測するために、ネットワークの最後に回帰層を含めることができます。

この例では、畳み込みニューラル ネットワーク アーキテクチャを構築し、ネットワークの学習を行い、学習済みネットワークを使用して手書きの数字の回転角度を予測します。このような予測は、光学式文字認識に役立ちます。

オプションで、imrotate (Image Processing Toolbox™) を使用してイメージを回転させ、boxplot (Statistics and Machine Learning Toolbox™) を使用して残差の箱ひげ図を作成できます。

データの読み込み

データセットには、手書きの数字の合成イメージと各イメージに対応する回転角度 (度単位) が含まれています。

digitTrain4DArrayDatadigitTest4DArrayData を使用して学習イメージと検証イメージを 4 次元配列として読み込みます。出力 YTrain および YValidation は回転角度 (度単位) です。学習データセットと検証データセットにはそれぞれ、5000 枚のイメージが含まれています。

[XTrain,~,YTrain] = digitTrain4DArrayData;
[XValidation,~,YValidation] = digitTest4DArrayData;

imshow を使用して、ランダムに選ばれた 20 枚の学習イメージを表示します。

numTrainImages = numel(YTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
end

Figure contains 20 axes objects. Axes object 1 contains an object of type image. Axes object 2 contains an object of type image. Axes object 3 contains an object of type image. Axes object 4 contains an object of type image. Axes object 5 contains an object of type image. Axes object 6 contains an object of type image. Axes object 7 contains an object of type image. Axes object 8 contains an object of type image. Axes object 9 contains an object of type image. Axes object 10 contains an object of type image. Axes object 11 contains an object of type image. Axes object 12 contains an object of type image. Axes object 13 contains an object of type image. Axes object 14 contains an object of type image. Axes object 15 contains an object of type image. Axes object 16 contains an object of type image. Axes object 17 contains an object of type image. Axes object 18 contains an object of type image. Axes object 19 contains an object of type image. Axes object 20 contains an object of type image.

データ正規化の確認

ニューラル ネットワークに学習させるときは、ネットワークのすべての段階でデータが正規化されていることを確認すると、多くの場合役に立ちます。正規化は、勾配降下を使用したネットワークの学習の安定化と高速化に有効です。データが適切にスケーリングされていない場合、学習中に損失が NaN になり、ネットワーク パラメーターが発散する可能性があります。データを正規化する一般的な方法として、データの範囲が [0,1] になるように、または平均が 0 で標準偏差が 1 になるように、データを再スケーリングする方法があります。次のデータを正規化できます。

  • 入力データ。ネットワークに入力する前に予測子を正規化します。この例では、入力イメージは範囲 [0,1] に既に正規化されています。

  • 層出力。バッチ正規化層を使用すると、畳み込み層と全結合層のそれぞれについて出力を正規化できます。

  • 応答。バッチ正規化層を使用してネットワークの最後で層出力を正規化する場合、学習の開始時にネットワークの予測が正規化されます。応答のスケールがこれらの予測と非常に異なる場合、ネットワークの学習が収束しないことがあります。応答が適切にスケーリングされていない場合は、正規化を試して、ネットワーク学習が改善されるか確認してください。学習の前に応答を正規化する場合、学習済みネットワークの予測を変換して、元の応答の予測を求めなければなりません。

応答の分布をプロットします。応答 (度単位の回転角度) は、-45 と 45 の間でほぼ一様分布しており、正規化しなくても問題ありません。分類問題では、出力はクラス確率であり、常に正規化されています。

figure
histogram(YTrain)
axis tight
ylabel('Counts')
xlabel('Rotation Angle')

Figure contains an axes object. The axes object contains an object of type histogram.

一般的に、データを厳密に正規化する必要はありません。ただし、この例で YTrain ではなく 100*YTrain または YTrain+500 を予測するためにネットワークに学習させる場合、学習を開始すると、損失が NaN になり、ネットワーク パラメーターが発散します。aY + b を予測するネットワークと Y を予測するネットワークとが、最後の全結合層の重みとバイアスの単純な再スケーリングしか違わなくても、このような結果になります。

入力または応答の分布が非常に不均一な場合や偏っている場合は、ネットワークの学習前にデータを非線形変換 (対数を取るなど) することもできます。

ネットワーク層の作成

回帰問題を解くには、ネットワークの層を作成し、ネットワークの最後に回帰層を含めます。

最初の層は、入力データのサイズとタイプを定義します。入力イメージは 28 x 28 x 1 です。学習イメージと同じサイズのイメージ入力層を作成します。

ネットワークの中間層は、計算と学習の大部分が行われる、ネットワークの中核を成すアーキテクチャを定義します。

最後の層は、出力データのサイズとタイプを定義します。回帰問題を解くには、ネットワークの最後の回帰層の前に全結合層を配置しなければなりません。サイズ 1 の全結合出力層、および回帰層を作成します。

配列 Layer ですべての層をまとめます。

layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    averagePooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    averagePooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    dropoutLayer(0.2)
    fullyConnectedLayer(1)
    regressionLayer];

ネットワークの学習

ネットワーク学習オプションを作成します。学習を 30 エポック行います。初期学習率を 0.001 に設定し、20 エポック後に学習率を下げます。検証データと検証頻度を指定して、学習中にネットワークの精度を監視します。学習データでネットワークに学習させ、学習中に一定の間隔で検証データに対してその精度を計算します。検証データは、ネットワークの重みの更新には使用されません。学習の進行状況プロットをオンにして、コマンド ウィンドウの出力をオフにします。

miniBatchSize  = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',30, ...
    'InitialLearnRate',1e-3, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropFactor',0.1, ...
    'LearnRateDropPeriod',20, ...
    'Shuffle','every-epoch', ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'Verbose',false);

trainNetwork を使用してネットワークを作成します。このコマンドでは、互換性のある GPU が利用できる場合は、その GPU が使用されます。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。そうでない場合、trainNetwork では CPU が使用されます。

net = trainNetwork(XTrain,YTrain,layers,options);

Figure Training Progress (12-Apr-2022 00:54:13) contains 2 axes objects and another object of type uigridlayout. Axes object 1 contains 10 objects of type patch, text, line. Axes object 2 contains 10 objects of type patch, text, line.

netLayers プロパティに含まれるネットワーク アーキテクチャの詳細を確認します。

net.Layers
ans = 
  18x1 Layer array with layers:

     1   'imageinput'         Image Input           28x28x1 images with 'zerocenter' normalization
     2   'conv_1'             Convolution           8 3x3x1 convolutions with stride [1  1] and padding 'same'
     3   'batchnorm_1'        Batch Normalization   Batch normalization with 8 channels
     4   'relu_1'             ReLU                  ReLU
     5   'avgpool2d_1'        Average Pooling       2x2 average pooling with stride [2  2] and padding [0  0  0  0]
     6   'conv_2'             Convolution           16 3x3x8 convolutions with stride [1  1] and padding 'same'
     7   'batchnorm_2'        Batch Normalization   Batch normalization with 16 channels
     8   'relu_2'             ReLU                  ReLU
     9   'avgpool2d_2'        Average Pooling       2x2 average pooling with stride [2  2] and padding [0  0  0  0]
    10   'conv_3'             Convolution           32 3x3x16 convolutions with stride [1  1] and padding 'same'
    11   'batchnorm_3'        Batch Normalization   Batch normalization with 32 channels
    12   'relu_3'             ReLU                  ReLU
    13   'conv_4'             Convolution           32 3x3x32 convolutions with stride [1  1] and padding 'same'
    14   'batchnorm_4'        Batch Normalization   Batch normalization with 32 channels
    15   'relu_4'             ReLU                  ReLU
    16   'dropout'            Dropout               20% dropout
    17   'fc'                 Fully Connected       1 fully connected layer
    18   'regressionoutput'   Regression Output     mean-squared-error with response 'Response'

ネットワークのテスト

検証データに対する精度を評価することによって、ネットワーク性能をテストします。

predict を使用して、検証イメージの回転角度を予測します。

YPredicted = predict(net,XValidation);

性能の評価

次の計算を行って、モデルの性能を評価します。

  1. 許容誤差限界内にある予測の比率

  2. 回転角度の予測値と実際の値の平方根平均二乗誤差 (RMSE)

回転角度の予測値と実際の値との間の予測誤差を計算します。

predictionError = YValidation - YPredicted;

真の角度から許容誤差限界内にある予測の数を計算します。しきい値を 10 度に設定します。このしきい値の範囲内にある予測の比率を計算します。

thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numValidationImages = numel(YValidation);

accuracy = numCorrect/numValidationImages
accuracy = 0.9716

平方根平均二乗誤差 (RMSE) を使用して、回転角度の予測値と実際の値の差を測定します。

squares = predictionError.^2;
rmse = sqrt(mean(squares))
rmse = single
    4.5505

予測の可視化

散布図で予測を可視化します。真の値に対する予測値をプロットします。

figure
scatter(YPredicted,YValidation,'+')
xlabel("Predicted Value")
ylabel("True Value")

hold on
plot([-60 60], [-60 60],'r--')

Figure contains an axes object. The axes object contains 2 objects of type scatter, line.

数字の回転補正

Image Processing Toolbox の関数を使用して、数字をまっすぐにし、まとめて表示できます。imrotate (Image Processing Toolbox) を使用して、49 個の数字標本をそれぞれの予測回転角度に応じて回転させます。

idx = randperm(numValidationImages,49);
for i = 1:numel(idx)
    image = XValidation(:,:,:,idx(i));
    predictedAngle = YPredicted(idx(i));  
    imagesRotated(:,:,:,i) = imrotate(image,predictedAngle,'bicubic','crop');
end

元の数字と回転補正後の数字を表示します。montage (Image Processing Toolbox) を使用して、数字を 1 つのイメージにまとめて表示できます。

figure
subplot(1,2,1)
montage(XValidation(:,:,:,idx))
title('Original')

subplot(1,2,2)
montage(imagesRotated)
title('Corrected')

Figure contains 2 axes objects. Axes object 1 with title Original contains an object of type image. Axes object 2 with title Corrected contains an object of type image.

参考

|

関連するトピック