Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

分類ネットワークの回帰ネットワークへの変換

この例では、学習済み分類ネットワークを回帰ネットワークに変換する方法を説明します。

事前学習済みのイメージ分類ネットワークは、100 万個を超えるイメージで学習しており、イメージを 1000 個のオブジェクト カテゴリ (キーボード、マグ カップ、鉛筆、多くの動物など) に分類できます。このネットワークは広範囲にわたるイメージについての豊富な特徴表現を学習しています。このネットワークは入力としてイメージを取り、イメージ内のオブジェクトのラベルを各オブジェクト カテゴリの確率と共に出力します。

転移学習は、深層学習アプリケーションでよく使用されています。事前学習済みのネットワークを取得して、新しいタスクの学習の開始点として使用できます。この例では、事前学習済みの分類ネットワークを利用して、回帰タスク用に再学習させる方法を説明します。

この例では、分類用の事前学習済みの畳み込みニューラル ネットワーク アーキテクチャを読み込み、分類用の層を置き換え、ネットワークに再学習させて手書きの数字の回転角度を予測します。オプションで、imrotate (Image Processing Toolbox™) を使用して、予測値に基づいてイメージの回転を補正できます。

事前学習済みのネットワークの読み込み

サポート ファイル digitsNet.mat から事前学習済みのネットワークを読み込みます。このファイルには、手書きの数字を分類する分類ネットワークが含まれています。

load digitsNet
layers = net.Layers
layers = 
  15x1 Layer array with layers:

     1   'imageinput'    Image Input             28x28x1 images with 'zerocenter' normalization
     2   'conv_1'        2-D Convolution         8 3x3x1 convolutions with stride [1  1] and padding 'same'
     3   'batchnorm_1'   Batch Normalization     Batch normalization with 8 channels
     4   'relu_1'        ReLU                    ReLU
     5   'maxpool_1'     2-D Max Pooling         2x2 max pooling with stride [2  2] and padding [0  0  0  0]
     6   'conv_2'        2-D Convolution         16 3x3x8 convolutions with stride [1  1] and padding 'same'
     7   'batchnorm_2'   Batch Normalization     Batch normalization with 16 channels
     8   'relu_2'        ReLU                    ReLU
     9   'maxpool_2'     2-D Max Pooling         2x2 max pooling with stride [2  2] and padding [0  0  0  0]
    10   'conv_3'        2-D Convolution         32 3x3x16 convolutions with stride [1  1] and padding 'same'
    11   'batchnorm_3'   Batch Normalization     Batch normalization with 32 channels
    12   'relu_3'        ReLU                    ReLU
    13   'fc'            Fully Connected         10 fully connected layer
    14   'softmax'       Softmax                 softmax
    15   'classoutput'   Classification Output   crossentropyex with '0' and 9 other classes

データの読み込み

データセットには、手書きの数字の合成イメージと各イメージに対応する回転角度 (度単位) が含まれています。

digitTrain4DArrayDatadigitTest4DArrayData を使用して学習イメージと検証イメージを 4 次元配列として読み込みます。出力 YTrain および YValidation は回転角度 (度単位) です。学習データセットと検証データセットにはそれぞれ、5000 個のイメージが含まれています。

[XTrain,~,YTrain] = digitTrain4DArrayData;
[XValidation,~,YValidation] = digitTest4DArrayData;

imshow を使用して、ランダムに選ばれた 20 個の学習イメージを表示します。

numTrainImages = numel(YTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
end

Figure contains 20 axes objects. Axes object 1 contains an object of type image. Axes object 2 contains an object of type image. Axes object 3 contains an object of type image. Axes object 4 contains an object of type image. Axes object 5 contains an object of type image. Axes object 6 contains an object of type image. Axes object 7 contains an object of type image. Axes object 8 contains an object of type image. Axes object 9 contains an object of type image. Axes object 10 contains an object of type image. Axes object 11 contains an object of type image. Axes object 12 contains an object of type image. Axes object 13 contains an object of type image. Axes object 14 contains an object of type image. Axes object 15 contains an object of type image. Axes object 16 contains an object of type image. Axes object 17 contains an object of type image. Axes object 18 contains an object of type image. Axes object 19 contains an object of type image. Axes object 20 contains an object of type image.

最後の層の置き換え

ネットワークの畳み込み層は、入力イメージを分類するために、最後の学習可能な層と最終分類層が使用するイメージの特徴を抽出します。digitsNet のこれらの 2 つの層 'fc' および 'classoutput' は、ネットワークによって抽出された特徴を組み合わせてクラス確率、損失値、および予測ラベルにまとめる方法に関する情報を含んでいます。回帰用に事前学習済みのネットワークに再学習させるには、これら 2 つの層をこのタスクに適応させた新しい層に置き換えます。

最後の全結合層、ソフトマックス層、および分類出力層をサイズ 1 (応答の数) の全結合層と回帰層に置き換えます。

numResponses = 1;
layers = [
    layers(1:12)
    fullyConnectedLayer(numResponses)
    regressionLayer];

初期の層の凍結

これで、新しいデータでネットワークに再学習させる準備が整いました。オプションで、ネットワークの初期の層について学習率を 0 に設定すると、それらの層の重みを "凍結" できます。学習中に trainNetwork は凍結された層のパラメーターを更新しません。凍結された層の勾配は計算する必要がないため、多数の初期の層について重みを凍結すると、ネットワーク学習を大幅に高速化できます。新しいデータセットが小さい場合、初期のネットワーク層を凍結すると、新しいデータセットに対するこれらの層の過適合を防止することもできます。

サポート関数 freezeWeights を使用して、最初の 12 個の層について学習率を 0 に設定します。

layers(1:12) = freezeWeights(layers(1:12));

ネットワークの学習

ネットワーク学習オプションを作成します。初期学習率を 0.001 に設定します。検証データを指定して、学習中にネットワークの精度を監視します。学習の進行状況プロットをオンにして、コマンド ウィンドウの出力をオフにします。

options = trainingOptions('sgdm',...
    'InitialLearnRate',0.001, ...
    'ValidationData',{XValidation,YValidation},...
    'Plots','training-progress',...
    'Verbose',false);

trainNetwork を使用してネットワークを作成します。このコマンドでは、互換性のある GPU が利用できる場合は、その GPU が使用されます。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、trainNetwork では CPU が使用されます。

net = trainNetwork(XTrain,YTrain,layers,options);

Figure Training Progress (27-Jul-2023 15:34:10) contains 2 axes objects and another object of type uigridlayout. Axes object 1 with xlabel Iteration, ylabel Loss contains 10 objects of type patch, text, line. Axes object 2 with xlabel Iteration, ylabel RMSE contains 10 objects of type patch, text, line.

ネットワークのテスト

検証データに対する精度を評価することによって、ネットワーク性能をテストします。

predict を使用して、検証イメージの回転角度を予測します。

YPred = predict(net,XValidation);

次の計算を行って、モデルの性能を評価します。

  1. 許容誤差限界内にある予測の比率

  2. 回転角度の予測値と実際の値の平方根平均二乗誤差 (RMSE)

回転角度の予測値と実際の値との間の予測誤差を計算します。

predictionError = YValidation - YPred;

真の角度から許容誤差限界内にある予測の数を計算します。しきい値を 10 度に設定します。このしきい値の範囲内にある予測の比率を計算します。

thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numImagesValidation = numel(YValidation);

accuracy = numCorrect/numImagesValidation
accuracy = 0.7532

平方根平均二乗誤差 (RMSE) を使用して、回転角度の予測値と実際の値の差を測定します。

 rmse = sqrt(mean(predictionError.^2))
rmse = single
    9.0270

数字の回転補正

Image Processing Toolbox の関数を使用して、数字をまっすぐにし、まとめて表示できます。imrotate (Image Processing Toolbox) を使用して、49 個の数字標本をそれぞれの予測回転角度に応じて回転させます。

idx = randperm(numImagesValidation,49);
for i = 1:numel(idx)
    I = XValidation(:,:,:,idx(i));
    Y = YPred(idx(i));  
    XValidationCorrected(:,:,:,i) = imrotate(I,Y,'bicubic','crop');
end

元の数字と回転補正後の数字を表示します。montage (Image Processing Toolbox) を使用して、数字を 1 つのイメージにまとめます。

figure
subplot(1,2,1)
montage(XValidation(:,:,:,idx))
title('Original')

subplot(1,2,2)
montage(XValidationCorrected)
title('Corrected')

Figure contains 2 axes objects. Axes object 1 with title Original contains an object of type image. Axes object 2 with title Corrected contains an object of type image.

参考

|

関連するトピック