このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。
回帰用の畳み込みニューラル ネットワークの学習
この例では、畳み込みニューラル ネットワークを使用して回帰モデルを当てはめ、手書きの数字の回転角度を予測する方法を示します。
畳み込みニューラル ネットワーク (CNN または ConvNet) は深層学習に不可欠なツールであり、特にイメージ データの解析に適しています。たとえば、CNN を使用してイメージを分類できます。角度や距離などの連続データを予測するために、ネットワークの最後に回帰層を含めることができます。
この例では、畳み込みニューラル ネットワーク アーキテクチャを構築し、ネットワークの学習を行い、学習済みネットワークを使用して手書きの数字の回転角度を予測します。このような予測は、光学式文字認識に役立ちます。
オプションで、imrotate
(Image Processing Toolbox™) を使用してイメージを回転させ、boxplot
(Statistics and Machine Learning Toolbox™) を使用して残差の箱ひげ図を作成できます。
データの読み込み
データセットには、手書きの数字の合成イメージと各イメージに対応する回転角度 (度単位) が含まれています。
digitTrain4DArrayData
と digitTest4DArrayData
を使用して学習イメージと検証イメージを 4 次元配列として読み込みます。出力 YTrain
および YValidation
は回転角度 (度単位) です。学習データセットと検証データセットにはそれぞれ、5000 個のイメージが含まれています。
[XTrain,~,YTrain] = digitTrain4DArrayData; [XValidation,~,YValidation] = digitTest4DArrayData;
imshow
を使用して、ランダムに選ばれた 20 個の学習イメージを表示します。
numTrainImages = numel(YTrain); figure idx = randperm(numTrainImages,20); for i = 1:numel(idx) subplot(4,5,i) imshow(XTrain(:,:,:,idx(i))) end
データ正規化の確認
ニューラル ネットワークに学習させるときは、ネットワークのすべての段階でデータが正規化されていることを確認すると、多くの場合役に立ちます。正規化は、勾配降下を使用したネットワークの学習の安定化と高速化に有効です。データが適切にスケーリングされていない場合、学習中に損失が NaN
になり、ネットワーク パラメーターが発散する可能性があります。データを正規化する一般的な方法として、データの範囲が [0,1] になるように、または平均が 0 で標準偏差が 1 になるように、データを再スケーリングする方法があります。次のデータを正規化できます。
入力データ。ネットワークに入力する前に予測子を正規化します。この例では、入力イメージは範囲 [0,1] に既に正規化されています。
層出力。バッチ正規化層を使用すると、畳み込み層と全結合層のそれぞれについて出力を正規化できます。
応答。バッチ正規化層を使用してネットワークの最後で層出力を正規化する場合、学習の開始時にネットワークの予測が正規化されます。応答のスケールがこれらの予測と非常に異なる場合、ネットワークの学習が収束しないことがあります。応答が適切にスケーリングされていない場合は、正規化を試して、ネットワーク学習が改善されるか確認してください。学習の前に応答を正規化する場合、学習済みネットワークの予測を変換して、元の応答の予測を求めなければなりません。
応答の分布をプロットします。応答 (度単位の回転角度) は、-45 と 45 の間でほぼ一様分布しており、正規化しなくても問題ありません。分類問題では、出力はクラス確率であり、常に正規化されています。
figure histogram(YTrain) axis tight ylabel('Counts') xlabel('Rotation Angle')
一般的に、データを厳密に正規化する必要はありません。ただし、この例で YTrain
ではなく 100*YTrain
または YTrain+500
を予測するためにネットワークに学習させる場合、学習を開始すると、損失が NaN
になり、ネットワーク パラメーターが発散します。aY + b を予測するネットワークと Y を予測するネットワークとが、最後の全結合層の重みとバイアスの単純な再スケーリングしか違わなくても、このような結果になります。
入力または応答の分布が非常に不均一な場合や偏っている場合は、ネットワークの学習前にデータを非線形変換 (対数を取るなど) することもできます。
ネットワーク層の作成
回帰問題を解くには、ネットワークの層を作成し、ネットワークの最後に回帰層を含めます。
最初の層は、入力データのサイズとタイプを定義します。入力イメージは 28 x 28 x 1 です。学習イメージと同じサイズのイメージ入力層を作成します。
ネットワークの中間層は、計算と学習の大部分が行われる、ネットワークの中核を成すアーキテクチャを定義します。
最後の層は、出力データのサイズとタイプを定義します。回帰問題を解くには、ネットワークの最後の回帰層の前に全結合層を配置しなければなりません。サイズ 1 の全結合出力層、および回帰層を作成します。
配列 Layer
ですべての層をまとめます。
layers = [ imageInputLayer([28 28 1]) convolution2dLayer(3,8,'Padding','same') batchNormalizationLayer reluLayer averagePooling2dLayer(2,'Stride',2) convolution2dLayer(3,16,'Padding','same') batchNormalizationLayer reluLayer averagePooling2dLayer(2,'Stride',2) convolution2dLayer(3,32,'Padding','same') batchNormalizationLayer reluLayer convolution2dLayer(3,32,'Padding','same') batchNormalizationLayer reluLayer dropoutLayer(0.2) fullyConnectedLayer(1) regressionLayer];
ネットワークの学習
ネットワーク学習オプションを作成します。学習を 30 エポック行います。初期学習率を 0.001 に設定し、20 エポック後に学習率を下げます。検証データと検証頻度を指定して、学習中にネットワークの精度を監視します。学習データでネットワークに学習させ、学習中に一定の間隔で検証データに対してその精度を計算します。検証データは、ネットワークの重みの更新には使用されません。学習の進行状況プロットをオンにして、コマンド ウィンドウの出力をオフにします。
miniBatchSize = 128; validationFrequency = floor(numel(YTrain)/miniBatchSize); options = trainingOptions('sgdm', ... 'MiniBatchSize',miniBatchSize, ... 'MaxEpochs',30, ... 'InitialLearnRate',1e-3, ... 'LearnRateSchedule','piecewise', ... 'LearnRateDropFactor',0.1, ... 'LearnRateDropPeriod',20, ... 'Shuffle','every-epoch', ... 'ValidationData',{XValidation,YValidation}, ... 'ValidationFrequency',validationFrequency, ... 'Plots','training-progress', ... 'Verbose',false);
trainNetwork
を使用してネットワークを作成します。このコマンドでは、互換性のある GPU が利用できる場合は、その GPU が使用されます。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、trainNetwork
では CPU が使用されます。
net = trainNetwork(XTrain,YTrain,layers,options);
net
の Layers
プロパティに含まれるネットワーク アーキテクチャの詳細を確認します。
net.Layers
ans = 18x1 Layer array with layers: 1 'imageinput' Image Input 28x28x1 images with 'zerocenter' normalization 2 'conv_1' 2-D Convolution 8 3x3x1 convolutions with stride [1 1] and padding 'same' 3 'batchnorm_1' Batch Normalization Batch normalization with 8 channels 4 'relu_1' ReLU ReLU 5 'avgpool2d_1' 2-D Average Pooling 2x2 average pooling with stride [2 2] and padding [0 0 0 0] 6 'conv_2' 2-D Convolution 16 3x3x8 convolutions with stride [1 1] and padding 'same' 7 'batchnorm_2' Batch Normalization Batch normalization with 16 channels 8 'relu_2' ReLU ReLU 9 'avgpool2d_2' 2-D Average Pooling 2x2 average pooling with stride [2 2] and padding [0 0 0 0] 10 'conv_3' 2-D Convolution 32 3x3x16 convolutions with stride [1 1] and padding 'same' 11 'batchnorm_3' Batch Normalization Batch normalization with 32 channels 12 'relu_3' ReLU ReLU 13 'conv_4' 2-D Convolution 32 3x3x32 convolutions with stride [1 1] and padding 'same' 14 'batchnorm_4' Batch Normalization Batch normalization with 32 channels 15 'relu_4' ReLU ReLU 16 'dropout' Dropout 20% dropout 17 'fc' Fully Connected 1 fully connected layer 18 'regressionoutput' Regression Output mean-squared-error with response 'Response'
ネットワークのテスト
検証データに対する精度を評価することによって、ネットワーク性能をテストします。
predict
を使用して、検証イメージの回転角度を予測します。
YPredicted = predict(net,XValidation);
性能の評価
次の計算を行って、モデルの性能を評価します。
許容誤差限界内にある予測の比率
回転角度の予測値と実際の値の平方根平均二乗誤差 (RMSE)
回転角度の予測値と実際の値との間の予測誤差を計算します。
predictionError = YValidation - YPredicted;
真の角度から許容誤差限界内にある予測の数を計算します。しきい値を 10 度に設定します。このしきい値の範囲内にある予測の比率を計算します。
thr = 10; numCorrect = sum(abs(predictionError) < thr); numValidationImages = numel(YValidation); accuracy = numCorrect/numValidationImages
accuracy = 0.9662
平方根平均二乗誤差 (RMSE) を使用して、回転角度の予測値と実際の値の差を測定します。
squares = predictionError.^2; rmse = sqrt(mean(squares))
rmse = single
4.6092
予測の可視化
散布図で予測を可視化します。真の値に対する予測値をプロットします。
figure scatter(YPredicted,YValidation,'+') xlabel("Predicted Value") ylabel("True Value") hold on plot([-60 60], [-60 60],'r--')
数字の回転補正
Image Processing Toolbox の関数を使用して、数字をまっすぐにし、まとめて表示できます。imrotate
(Image Processing Toolbox) を使用して、49 個の数字標本をそれぞれの予測回転角度に応じて回転させます。
idx = randperm(numValidationImages,49); for i = 1:numel(idx) image = XValidation(:,:,:,idx(i)); predictedAngle = YPredicted(idx(i)); imagesRotated(:,:,:,i) = imrotate(image,predictedAngle,'bicubic','crop'); end
元の数字と回転補正後の数字を表示します。montage
(Image Processing Toolbox) を使用して、数字を 1 つのイメージにまとめて表示できます。
figure subplot(1,2,1) montage(XValidation(:,:,:,idx)) title('Original') subplot(1,2,2) montage(imagesRotated) title('Corrected')
参考
regressionLayer
| classificationLayer