このページは前リリースの情報です。該当の英語のページはこのリリースで削除されています。

分類ネットワークの回帰ネットワークへの変換

この例では、学習済み分類ネットワークを回帰ネットワークに変換する方法を説明します。

事前学習済みのイメージ分類ネットワークは、100 万枚を超えるイメージで学習しており、イメージを 1000 個のオブジェクト カテゴリ (キーボード、マグ カップ、鉛筆、多くの動物など) に分類できます。このネットワークは広範囲にわたるイメージについての豊富な特徴表現を学習しています。このネットワークは入力としてイメージを取り、イメージ内のオブジェクトのラベルを各オブジェクト カテゴリの確率と共に出力します。

転移学習は、深層学習アプリケーションでよく使用されています。事前学習済みのネットワークを取得して、新しいタスクの学習の開始点として使用できます。この例では、事前学習済みの分類ネットワークを利用して、回帰タスク用に再学習させる方法を説明します。

この例では、分類用の事前学習済みの畳み込みニューラル ネットワーク アーキテクチャを読み込み、分類用の層を置き換え、ネットワークに再学習させて手書きの数字の回転角度を予測します。オプションで、imrotate (Image Processing Toolbox™) を使用して、予測値に基づいてイメージの回転を補正できます。

事前学習済みのネットワークの読み込み

サポート ファイル digitsNet.mat から事前学習済みのネットワークを読み込みます。このファイルには、手書きの数字を分類する分類ネットワークが含まれています。

load digitsNet
layers = net.Layers
layers = 
  15x1 Layer array with layers:

     1   'imageinput'    Image Input             28x28x1 images with 'zerocenter' normalization
     2   'conv_1'        Convolution             8 3x3x1 convolutions with stride [1  1] and padding 'same'
     3   'batchnorm_1'   Batch Normalization     Batch normalization with 8 channels
     4   'relu_1'        ReLU                    ReLU
     5   'maxpool_1'     Max Pooling             2x2 max pooling with stride [2  2] and padding [0  0  0  0]
     6   'conv_2'        Convolution             16 3x3x8 convolutions with stride [1  1] and padding 'same'
     7   'batchnorm_2'   Batch Normalization     Batch normalization with 16 channels
     8   'relu_2'        ReLU                    ReLU
     9   'maxpool_2'     Max Pooling             2x2 max pooling with stride [2  2] and padding [0  0  0  0]
    10   'conv_3'        Convolution             32 3x3x16 convolutions with stride [1  1] and padding 'same'
    11   'batchnorm_3'   Batch Normalization     Batch normalization with 32 channels
    12   'relu_3'        ReLU                    ReLU
    13   'fc'            Fully Connected         10 fully connected layer
    14   'softmax'       Softmax                 softmax
    15   'classoutput'   Classification Output   crossentropyex with '0' and 9 other classes

データの読み込み

データセットには、手書きの数字の合成イメージと各イメージに対応する回転角度 (度単位) が含まれています。

digitTrain4DArrayDatadigitTest4DArrayData を使用して学習イメージと検証イメージを 4 次元配列として読み込みます。出力 YTrain および YValidation は回転角度 (度単位) です。学習データセットと検証データセットにはそれぞれ、5000 枚のイメージが含まれています。

[XTrain,~,YTrain] = digitTrain4DArrayData;
[XValidation,~,YValidation] = digitTest4DArrayData;

imshow を使用して、ランダムに選ばれた 20 枚の学習イメージを表示します。

numTrainImages = numel(YTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
    drawnow
end

最後の層の置き換え

ネットワークの畳み込み層は、入力イメージを分類するために、最後の学習可能な層と最終分類層が使用するイメージの特徴を抽出します。digitsNet のこれらの 2 つの層 'fc' および 'classoutput' は、ネットワークによって抽出された特徴を組み合わせてクラス確率、損失値、および予測ラベルにまとめる方法に関する情報を含んでいます。回帰用に事前学習済みのネットワークに再学習させるには、これら 2 つの層をこのタスクに適応させた新しい層に置き換えます。

最後の全結合層、ソフトマックス層、および分類出力層をサイズ 1 (応答の数) の全結合層と回帰層に置き換えます。

numResponses = 1;
layers = [
    layers(1:12)
    fullyConnectedLayer(numResponses)
    regressionLayer];

初期の層の凍結

これで、新しいデータでネットワークに再学習させる準備が整いました。オプションで、ネットワークの初期の層について学習率を 0 に設定すると、それらの層の重みを "凍結" できます。学習中に trainNetwork は凍結された層のパラメーターを更新しません。凍結された層の勾配は計算する必要がないため、多数の初期の層について重みを凍結すると、ネットワーク学習を大幅に高速化できます。新しいデータセットが小さい場合、初期のネットワーク層を凍結すると、新しいデータセットに対するこれらの層の過適合を防止することもできます。

サポート関数 freezeWeights を使用して、最初の 12 個の層について学習率を 0 に設定します。

layers(1:12) = freezeWeights(layers(1:12));

ネットワークの学習

ネットワーク学習オプションを作成します。初期学習率を 0.001 に設定します。検証データを指定して、学習中にネットワークの精度を監視します。学習の進行状況プロットをオンにして、コマンド ウィンドウの出力をオフにします。

options = trainingOptions('sgdm',...
    'InitialLearnRate',0.001, ...
    'ValidationData',{XValidation,YValidation},...
    'Plots','training-progress',...
    'Verbose',false);

trainNetwork を使用してネットワークを作成します。このコマンドでは、互換性のある GPU が利用できる場合は、その GPU が使用されます。そうでない場合、trainNetwork では CPU が使用されます。GPU で学習を行うには、Compute Capability 3.0 以上の CUDA® 対応 NVIDIA® GPU が必要です。

net = trainNetwork(XTrain,YTrain,layers,options);

ネットワークのテスト

検証データに対する精度を評価することによって、ネットワーク性能をテストします。

predict を使用して、検証イメージの回転角度を予測します。

YPred = predict(net,XValidation);

次の計算を行って、モデルの性能を評価します。

  1. 許容誤差限界内にある予測の比率

  2. 回転角度の予測値と実際の値の平方根平均二乗誤差 (RMSE)

回転角度の予測値と実際の値との間の予測誤差を計算します。

predictionError = YValidation - YPred;

真の角度から許容誤差限界内にある予測の数を計算します。しきい値を 10 度に設定します。このしきい値の範囲内にある予測の比率を計算します。

thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numImagesValidation = numel(YValidation);

accuracy = numCorrect/numImagesValidation
accuracy = 0.7532

平方根平均二乗誤差 (RMSE) を使用して、回転角度の予測値と実際の値の差を測定します。

 rmse = sqrt(mean(predictionError.^2))
rmse = single
    9.0270

数字の回転補正

Image Processing Toolbox の関数を使用して、数字をまっすぐにし、まとめて表示できます。imrotate (Image Processing Toolbox) を使用して、49 個の数字標本をそれぞれの予測回転角度に応じて回転させます。

idx = randperm(numImagesValidation,49);
for i = 1:numel(idx)
    I = XValidation(:,:,:,idx(i));
    Y = YPred(idx(i));  
    XValidationCorrected(:,:,:,i) = imrotate(I,Y,'bicubic','crop');
end

元の数字と回転補正後の数字を表示します。montage (Image Processing Toolbox) を使用して、数字を 1 つのイメージにまとめます。

figure
subplot(1,2,1)
montage(XValidation(:,:,:,idx))
title('Original')

subplot(1,2,2)
montage(XValidationCorrected)
title('Corrected')

参考

|

関連するトピック