ドキュメンテーション

最新のリリースでは、このページがまだ翻訳されていません。 このページの最新版は英語でご覧になれます。

回帰用の畳み込みニューラル ネットワークの学習

この例では、畳み込みニューラル ネットワークを使用して回帰モデルにあてはめ、手書きの数字の回転角度を予測する方法を示します。

畳み込みニューラル ネットワーク (CNN または ConvNet) は深層学習に不可欠なツールであり、特にイメージ データの解析に適しています。たとえば、CNN を使用してイメージを分類できます。角度や距離などの連続データを予測するために、ネットワークの最後に回帰層を含めることができます。

この例では、畳み込みニューラル ネットワーク アーキテクチャを構築し、ネットワークの学習を行い、学習済みネットワークを使用して手書きの数字の回転角度を予測します。このような予測は、光学式文字認識に役立ちます。

オプションで、imrotate (Image Processing Toolbox™) を使用してイメージを回転させ、boxplot (Statistics and Machine Learning Toolbox™) を使用して残差の箱ひげ図を作成できます。

データの読み込み

データセットには、手書きの数字の合成イメージと各イメージに対応する回転角度 (度単位) が含まれています。

digitTrain4DArrayDatadigitTest4DArrayData を使用して学習イメージと検証イメージを 4 次元配列として読み込みます。出力 YTrain および YValidation は回転角度 (度単位) です。学習データセットと検証データセットにはそれぞれ、5000 枚のイメージが含まれています。

[XTrain,~,YTrain] = digitTrain4DArrayData;
[XValidation,~,YValidation] = digitTest4DArrayData;

imshow を使用して、ランダムに選ばれた 20 枚の学習イメージを表示します。

numTrainImages = numel(YTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
    drawnow
end

データ正規化の確認

ニューラル ネットワークに学習させるときは、ネットワークのすべての段階でデータが正規化されていることを確認すると、多くの場合役に立ちます。正規化は、勾配降下を使用したネットワークの学習の安定化と高速化に有効です。データが適切にスケーリングされていない場合、学習中に損失が NaN になり、ネットワーク パラメーターが発散する可能性があります。データを正規化する一般的な方法として、データの範囲が [0,1] になるように、または平均が 0 で標準偏差が 1 になるように、データを再スケーリングする方法があります。次のデータを正規化できます。

  • 入力データ。ネットワークに入力する前に予測子を正規化します。この例では、入力イメージは範囲 [0,1] に既に正規化されています。

  • 層出力。バッチ正規化層を使用すると、畳み込み層と全結合層のそれぞれについて出力を正規化できます。

  • 応答。バッチ正規化層を使用してネットワークの最後で層出力を正規化する場合、学習の開始時にネットワークの予測が正規化されます。応答のスケールがこれらの予測と非常に異なる場合、ネットワークの学習が収束しないことがあります。応答が適切にスケーリングされていない場合は、正規化を試して、ネットワーク学習が改善されるか確認してください。学習の前に応答を正規化する場合、学習済みネットワークの予測を変換して、元の応答の予測を求めなければなりません。

応答の分布をプロットします。応答 (度単位の回転角度) は、-45 と 45 の間でほぼ一様分布しており、正規化しなくても問題ありません。分類問題では、出力はクラス確率であり、常に正規化されています。

figure
histogram(YTrain)
axis tight
ylabel('Counts')
xlabel('Rotation Angle')

一般的に、データを厳密に正規化する必要はありません。ただし、この例で YTrain ではなく 100*YTrain または YTrain+500 を予測するためにネットワークに学習させる場合、学習を開始すると、損失が NaN になり、ネットワーク パラメーターが発散します。aY + b を予測するネットワークと Y を予測するネットワークとが、最後の全結合層の重みとバイアスの単純な再スケーリングしか違わなくても、このような結果になります。

入力または応答の分布が非常に不均一な場合や偏っている場合は、ネットワークの学習前にデータを非線形変換 (対数を取るなど) することもできます。

ネットワーク層の作成

回帰問題を解くには、ネットワークの層を作成し、ネットワークの最後に回帰層を含めます。

最初の層は、入力データのサイズとタイプを定義します。入力イメージは 28 x 28 x 1 です。学習イメージと同じサイズのイメージ入力層を作成します。

ネットワークの中間層は、計算と学習の大部分が行われる、ネットワークの中核を成すアーキテクチャを定義します。

最後の層は、出力データのサイズとタイプを定義します。回帰問題を解くには、ネットワークの最後の回帰層の前に全結合層を配置しなければなりません。サイズ 1 の全結合出力層、および回帰層を作成します。

配列 Layer ですべての層をまとめます。

layers = [
    imageInputLayer([28 28 1])

    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    averagePooling2dLayer(2,'Stride',2)

    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    averagePooling2dLayer(2,'Stride',2)
  
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    dropoutLayer(0.2)
    fullyConnectedLayer(1)
    regressionLayer];

ネットワークの学習

ネットワーク学習オプションを作成します。学習を 30 エポック行います。初期学習率を 0.001 に設定し、20 エポック後に学習率を下げます。検証データと検証頻度を指定して、学習中にネットワークの精度を監視します。学習データでネットワークに学習させ、学習中に一定の間隔で検証データに対してその精度を計算します。検証データは、ネットワークの重みの更新には使用されません。学習の進行状況プロットをオンにして、コマンド ウィンドウの出力をオフにします。

miniBatchSize  = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',30, ...
    'InitialLearnRate',1e-3, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropFactor',0.1, ...
    'LearnRateDropPeriod',20, ...
    'Shuffle','every-epoch', ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'Verbose',false);

trainNetwork を使用してネットワークを作成します。このコマンドでは、互換性のある GPU が利用できる場合は、その GPU が使用されます。そうでない場合、trainNetwork では CPU が使用されます。GPU で学習を行うには、Compute Capability 3.0 以上の CUDA® 対応 NVIDIA® GPU が必要です。

net = trainNetwork(XTrain,YTrain,layers,options);

netLayers プロパティに含まれるネットワーク アーキテクチャの詳細を確認します。

net.Layers
ans = 
  18x1 Layer array with layers:

     1   'imageinput'         Image Input           28x28x1 images with 'zerocenter' normalization
     2   'conv_1'             Convolution           8 3x3x1 convolutions with stride [1  1] and padding 'same'
     3   'batchnorm_1'        Batch Normalization   Batch normalization with 8 channels
     4   'relu_1'             ReLU                  ReLU
     5   'avgpool2d_1'        Average Pooling       2x2 average pooling with stride [2  2] and padding [0  0  0  0]
     6   'conv_2'             Convolution           16 3x3x8 convolutions with stride [1  1] and padding 'same'
     7   'batchnorm_2'        Batch Normalization   Batch normalization with 16 channels
     8   'relu_2'             ReLU                  ReLU
     9   'avgpool2d_2'        Average Pooling       2x2 average pooling with stride [2  2] and padding [0  0  0  0]
    10   'conv_3'             Convolution           32 3x3x16 convolutions with stride [1  1] and padding 'same'
    11   'batchnorm_3'        Batch Normalization   Batch normalization with 32 channels
    12   'relu_3'             ReLU                  ReLU
    13   'conv_4'             Convolution           32 3x3x32 convolutions with stride [1  1] and padding 'same'
    14   'batchnorm_4'        Batch Normalization   Batch normalization with 32 channels
    15   'relu_4'             ReLU                  ReLU
    16   'dropout'            Dropout               20% dropout
    17   'fc'                 Fully Connected       1 fully connected layer
    18   'regressionoutput'   Regression Output     mean-squared-error with response 'Response'

ネットワークのテスト

検証データに対する精度を評価することによって、ネットワーク性能をテストします。

predict を使用して、検証イメージの回転角度を予測します。

YPredicted = predict(net,XValidation);

性能の評価

次の計算を行って、モデルの性能を評価します。

  1. 許容誤差限界内にある予測の比率

  2. 回転角度の予測値と実際の値の平方根平均二乗誤差 (RMSE)

回転角度の予測値と実際の値との間の予測誤差を計算します。

predictionError = YValidation - YPredicted;

真の角度から許容誤差限界内にある予測の数を計算します。しきい値を 10 度に設定します。このしきい値の範囲内にある予測の比率を計算します。

thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numValidationImages = numel(YValidation);

accuracy = numCorrect/numValidationImages
accuracy = 0.9698

平方根平均二乗誤差 (RMSE) を使用して、回転角度の予測値と実際の値の差を測定します。

squares = predictionError.^2;
rmse = sqrt(mean(squares))
rmse = single
    4.5137

各数字クラスの残差の箱ひげ図の表示

関数 boxplot には、各列が各数字クラスの残差に対応している行列が必要です。

検証データは、イメージを数字クラス 0 ~ 9 でグループ化したもので、各クラスに 500 個の例があります。reshape を使用して、数字クラス別に残差をグループ化します。

residualMatrix = reshape(predictionError,500,10);

residualMatrix の各列は、各数字の残差に対応します。boxplot (Statistics and Machine Learning Toolbox) を使用して各数字の残差の箱ひげ図を作成します。

figure
boxplot(residualMatrix,...
    'Labels',{'0','1','2','3','4','5','6','7','8','9'})
xlabel('Digit Class')
ylabel('Degrees Error')
title('Residuals')

精度が最も高い数字クラスは、平均値が 0 に近くなり、分散が小さくなります。

数字の回転補正

Image Processing Toolbox の関数を使用して、数字をまっすぐにし、まとめて表示できます。imrotate (Image Processing Toolbox) を使用して、49 個の数字標本をそれぞれの予測回転角度に応じて回転させます。

idx = randperm(numValidationImages,49);
for i = 1:numel(idx)
    image = XValidation(:,:,:,idx(i));
    predictedAngle = YPredicted(idx(i));  
    imagesRotated(:,:,:,i) = imrotate(image,predictedAngle,'bicubic','crop');
end

元の数字と回転補正後の数字を表示します。montage (Image Processing Toolbox) を使用して、数字を 1 つのイメージにまとめて表示できます。

figure
subplot(1,2,1)
montage(XValidation(:,:,:,idx))
title('Original')

subplot(1,2,2)
montage(imagesRotated)
title('Corrected')

参考

|

関連するトピック