Main Content

分類ネットワークの回帰ネットワークへの変換

この例では、学習済み分類ネットワークを回帰ネットワークに変換する方法を説明します。

事前学習済みのイメージ分類ネットワークは、100 万個を超えるイメージで学習しており、イメージを 1000 個のオブジェクト カテゴリ (キーボード、マグ カップ、鉛筆、多くの動物など) に分類できます。このネットワークは広範囲にわたるイメージについての豊富な特徴表現を学習しています。このネットワークは入力としてイメージを取り、イメージ内のオブジェクトのラベルを各オブジェクト カテゴリの確率と共に出力します。

転移学習は、深層学習アプリケーションでよく使用されています。事前学習済みのネットワークを取得して、新しいタスクの学習の開始点として使用できます。この例では、事前学習済みの分類ネットワークを利用して、回帰タスク用に再学習させる方法を説明します。

この例では、分類用の事前学習済みの畳み込みニューラル ネットワーク アーキテクチャを読み込み、分類用の層を置き換え、ネットワークに再学習させて手書きの数字の回転角度を予測します。

事前学習済みのネットワークの読み込み

サポート ファイル digitsClassificationConvolutionNet.mat から事前学習済みのネットワークを読み込みます。このファイルには、手書きの数字を分類する分類ネットワークが含まれています。

load digitsClassificationConvolutionNet
layers = net.Layers
layers = 
  13x1 Layer array with layers:

     1   'imageinput'    Image Input                  28x28x1 images
     2   'conv_1'        2-D Convolution              10 3x3x1 convolutions with stride [2  2] and padding [0  0  0  0]
     3   'batchnorm_1'   Batch Normalization          Batch normalization with 10 channels
     4   'relu_1'        ReLU                         ReLU
     5   'conv_2'        2-D Convolution              20 3x3x10 convolutions with stride [2  2] and padding [0  0  0  0]
     6   'batchnorm_2'   Batch Normalization          Batch normalization with 20 channels
     7   'relu_2'        ReLU                         ReLU
     8   'conv_3'        2-D Convolution              40 3x3x20 convolutions with stride [2  2] and padding [0  0  0  0]
     9   'batchnorm_3'   Batch Normalization          Batch normalization with 40 channels
    10   'relu_3'        ReLU                         ReLU
    11   'gap'           2-D Global Average Pooling   2-D global average pooling
    12   'fc'            Fully Connected              10 fully connected layer
    13   'softmax'       Softmax                      softmax

データの読み込み

データ セットには、手書きの数字の合成イメージと各イメージに対応する回転角度 (度単位) が含まれています。

サポート ファイル DigitsDataTrain.matDigitsDataTest.mat から学習イメージとテスト イメージを 4 次元配列として読み込みます。変数 anglesTrain および anglesTest は回転角度 (度単位) です。学習データ セットとテスト データ セットにはそれぞれ、5000 個のイメージが含まれています。

load DigitsDataTrain
load DigitsDataTest

imshow を使用して、ランダムに選ばれた 20 個の学習イメージを表示します。

numTrainImages = numel(anglesTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
end

最後の層の置き換え

ネットワークの畳み込み層は、最後の学習可能な層が入力イメージの分類に使用したイメージの特徴を抽出します。層 'fc' は、ネットワークによって抽出された特徴を組み合わせてクラス確率にまとめる方法に関する情報を含んでいます。回帰用に事前学習済みのネットワークに再学習させるには、この層およびそれに続くソフトマックス層を、このタスクに適応させた新しい層に置き換えます。

最後の全結合層を、サイズ 1 (応答の数) の全結合層に置き換えます。

numResponses = 1;
layer = fullyConnectedLayer(numResponses,Name="fc");

net = replaceLayer(net,"fc",layer)
net = 
  dlnetwork with properties:

         Layers: [13x1 nnet.cnn.layer.Layer]
    Connections: [12x2 table]
     Learnables: [14x3 table]
          State: [6x3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'}
    Initialized: 0

  View summary with summary.

ソフトマックス層を削除します。

net = removeLayers(net,"softmax");

層の学習率係数の調整

これで、新しいデータでネットワークに再学習させる準備が整いました。学習オプションの指定時、新しい全結合層の学習率を大きくし、グローバル学習率を小さくすることで、必要に応じて、ネットワークの初期の層の重みの学習を遅くすることができます。

係数を使用して全結合層のパラメーターの学習率を大きくするには、関数 setLearnRateFactor を使用します。

net = setLearnRateFactor(net,"fc","Weights",10);
net = setLearnRateFactor(net,"fc","Bias",10);

学習オプションの指定

学習オプションを指定します。オプションの中から選択するには、経験的解析が必要です。実験を実行してさまざまな学習オプションの構成を調べるには、Experiment Managerアプリを使用できます。

  • 小さくした学習率 0.0001 を指定します。

  • 学習の進行状況をプロットに表示します。

  • 詳細出力を無効にします。

options = trainingOptions("sgdm",...
    InitialLearnRate=0.001, ...
    Plots="training-progress",...
    Verbose=false);

ニューラル ネットワークの学習

関数trainnetを使用してニューラル ネットワークに学習させます。回帰の場合は、平均二乗誤差損失を使用します。既定では、関数 trainnet は利用可能な GPU がある場合にそれを使用します。GPU を使用するには、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数は CPU を使用します。実行環境を指定するには、ExecutionEnvironment 学習オプションを使用します。

net = trainnet(XTrain,anglesTrain,net,"mse",options);

ネットワークのテスト

テスト データに対する精度を評価することによって、ネットワーク性能をテストします。

predict を使用して、検証イメージの回転角度を予測します。

YTest = predict(net,XTest);

散布図で予測を可視化します。真の値に対する予測値をプロットします。

figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")

hold on
plot([-60 60], [-60 60],"r--")

参考

| |

関連するトピック