分類ネットワークの回帰ネットワークへの変換
この例では、学習済み分類ネットワークを回帰ネットワークに変換する方法を説明します。
事前学習済みのイメージ分類ネットワークは、100 万個を超えるイメージで学習しており、イメージを 1000 個のオブジェクト カテゴリ (キーボード、マグ カップ、鉛筆、多くの動物など) に分類できます。このネットワークは広範囲にわたるイメージについての豊富な特徴表現を学習しています。このネットワークは入力としてイメージを取り、イメージ内のオブジェクトのラベルを各オブジェクト カテゴリの確率と共に出力します。
転移学習は、深層学習アプリケーションでよく使用されています。事前学習済みのネットワークを取得して、新しいタスクの学習の開始点として使用できます。この例では、事前学習済みの分類ネットワークを利用して、回帰タスク用に再学習させる方法を説明します。
この例では、分類用の事前学習済みの畳み込みニューラル ネットワーク アーキテクチャを読み込み、分類用の層を置き換え、ネットワークに再学習させて手書きの数字の回転角度を予測します。オプションで、imrotate
(Image Processing Toolbox™) を使用して、予測値に基づいてイメージの回転を補正できます。
事前学習済みのネットワークの読み込み
サポート ファイル digitsNet.mat
から事前学習済みのネットワークを読み込みます。このファイルには、手書きの数字を分類する分類ネットワークが含まれています。
load digitsNet
layers = net.Layers
layers = 15x1 Layer array with layers: 1 'imageinput' Image Input 28x28x1 images with 'zerocenter' normalization 2 'conv_1' 2-D Convolution 8 3x3x1 convolutions with stride [1 1] and padding 'same' 3 'batchnorm_1' Batch Normalization Batch normalization with 8 channels 4 'relu_1' ReLU ReLU 5 'maxpool_1' 2-D Max Pooling 2x2 max pooling with stride [2 2] and padding [0 0 0 0] 6 'conv_2' 2-D Convolution 16 3x3x8 convolutions with stride [1 1] and padding 'same' 7 'batchnorm_2' Batch Normalization Batch normalization with 16 channels 8 'relu_2' ReLU ReLU 9 'maxpool_2' 2-D Max Pooling 2x2 max pooling with stride [2 2] and padding [0 0 0 0] 10 'conv_3' 2-D Convolution 32 3x3x16 convolutions with stride [1 1] and padding 'same' 11 'batchnorm_3' Batch Normalization Batch normalization with 32 channels 12 'relu_3' ReLU ReLU 13 'fc' Fully Connected 10 fully connected layer 14 'softmax' Softmax softmax 15 'classoutput' Classification Output crossentropyex with '0' and 9 other classes
データの読み込み
データセットには、手書きの数字の合成イメージと各イメージに対応する回転角度 (度単位) が含まれています。
digitTrain4DArrayData
と digitTest4DArrayData
を使用して学習イメージと検証イメージを 4 次元配列として読み込みます。出力 YTrain
および YValidation
は回転角度 (度単位) です。学習データセットと検証データセットにはそれぞれ、5000 個のイメージが含まれています。
[XTrain,~,YTrain] = digitTrain4DArrayData; [XValidation,~,YValidation] = digitTest4DArrayData;
imshow
を使用して、ランダムに選ばれた 20 個の学習イメージを表示します。
numTrainImages = numel(YTrain); figure idx = randperm(numTrainImages,20); for i = 1:numel(idx) subplot(4,5,i) imshow(XTrain(:,:,:,idx(i))) end
最後の層の置き換え
ネットワークの畳み込み層は、入力イメージを分類するために、最後の学習可能な層と最終分類層が使用するイメージの特徴を抽出します。digitsNet
のこれらの 2 つの層 'fc'
および 'classoutput'
は、ネットワークによって抽出された特徴を組み合わせてクラス確率、損失値、および予測ラベルにまとめる方法に関する情報を含んでいます。回帰用に事前学習済みのネットワークに再学習させるには、これら 2 つの層をこのタスクに適応させた新しい層に置き換えます。
最後の全結合層、ソフトマックス層、および分類出力層をサイズ 1 (応答の数) の全結合層と回帰層に置き換えます。
numResponses = 1; layers = [ layers(1:12) fullyConnectedLayer(numResponses) regressionLayer];
初期の層の凍結
これで、新しいデータでネットワークに再学習させる準備が整いました。オプションで、ネットワークの初期の層について学習率を 0 に設定すると、それらの層の重みを "凍結" できます。学習中に trainNetwork
は凍結された層のパラメーターを更新しません。凍結された層の勾配は計算する必要がないため、多数の初期の層について重みを凍結すると、ネットワーク学習を大幅に高速化できます。新しいデータセットが小さい場合、初期のネットワーク層を凍結すると、新しいデータセットに対するこれらの層の過適合を防止することもできます。
サポート関数 freezeWeights
を使用して、最初の 12 個の層について学習率を 0 に設定します。
layers(1:12) = freezeWeights(layers(1:12));
ネットワークの学習
ネットワーク学習オプションを作成します。初期学習率を 0.001 に設定します。検証データを指定して、学習中にネットワークの精度を監視します。学習の進行状況プロットをオンにして、コマンド ウィンドウの出力をオフにします。
options = trainingOptions('sgdm',... 'InitialLearnRate',0.001, ... 'ValidationData',{XValidation,YValidation},... 'Plots','training-progress',... 'Verbose',false);
trainNetwork
を使用してネットワークを作成します。このコマンドでは、互換性のある GPU が利用できる場合は、その GPU が使用されます。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、trainNetwork
では CPU が使用されます。
net = trainNetwork(XTrain,YTrain,layers,options);
ネットワークのテスト
検証データに対する精度を評価することによって、ネットワーク性能をテストします。
predict
を使用して、検証イメージの回転角度を予測します。
YPred = predict(net,XValidation);
次の計算を行って、モデルの性能を評価します。
許容誤差限界内にある予測の比率
回転角度の予測値と実際の値の平方根平均二乗誤差 (RMSE)
回転角度の予測値と実際の値との間の予測誤差を計算します。
predictionError = YValidation - YPred;
真の角度から許容誤差限界内にある予測の数を計算します。しきい値を 10 度に設定します。このしきい値の範囲内にある予測の比率を計算します。
thr = 10; numCorrect = sum(abs(predictionError) < thr); numImagesValidation = numel(YValidation); accuracy = numCorrect/numImagesValidation
accuracy = 0.7532
平方根平均二乗誤差 (RMSE) を使用して、回転角度の予測値と実際の値の差を測定します。
rmse = sqrt(mean(predictionError.^2))
rmse = single
9.0270
数字の回転補正
Image Processing Toolbox の関数を使用して、数字をまっすぐにし、まとめて表示できます。imrotate
(Image Processing Toolbox) を使用して、49 個の数字標本をそれぞれの予測回転角度に応じて回転させます。
idx = randperm(numImagesValidation,49); for i = 1:numel(idx) I = XValidation(:,:,:,idx(i)); Y = YPred(idx(i)); XValidationCorrected(:,:,:,i) = imrotate(I,Y,'bicubic','crop'); end
元の数字と回転補正後の数字を表示します。montage
(Image Processing Toolbox) を使用して、数字を 1 つのイメージにまとめます。
figure subplot(1,2,1) montage(XValidation(:,:,:,idx)) title('Original') subplot(1,2,2) montage(XValidationCorrected) title('Corrected')
参考
regressionLayer
| classificationLayer