Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

イメージ分類用の残差ネットワークの学習

この例では、残差結合のある深層学習ニューラル ネットワークを作成し、CIFAR-10 データで学習を行う方法を説明します。残差結合は畳み込みニューラル ネットワーク アーキテクチャでよく使用される要素です。残差結合を使用すると、ネットワークを通じた勾配フローが改善し、より深いネットワークの学習が可能になります。

多くの用途では、層のシンプルなシーケンスで構成されるネットワークを使用するだけで十分です。ただし、用途によっては、各層に複数の層からの入力と複数の層への出力がある、より複雑なグラフ構造のネットワークが必要です。多くの場合、これらのタイプのネットワークは有向非循環グラフ (DAG) ネットワークと呼ばれます。残差ネットワーク (ResNet) は、メイン ネットワーク層をバイパスする残差 (またはショートカット) 結合のある DAG ネットワークの一種です。残差結合では、パラメーターの勾配がネットワークの出力層からより初期の層へとよりスムーズに伝播するため、更に深いネットワークに学習させることができます。このようにネットワークが深くなると、より難しいタスクで高い精度を実現できます。

ResNet アーキテクチャは、初期層、それに続く "残差ブロック" を含む "スタック"、および最終層で構成されています。次の 3 種類の残差ブロックがあります。

  • 初期残差ブロック — このブロックは、最初のスタックの開始点に出現します。この例では、ボトルネック コンポーネントを使用しています。したがって、このブロックにはダウンサンプリング ブロックと同じ層が含まれますが、最初の畳み込み層のストライドは [1,1] のみです。詳細については、resnetLayersを参照してください。

  • 標準残差ブロック — このブロックは、各スタック内の最初のダウンサンプリング残差ブロックの後に出現します。このブロックは各スタックに複数回出現し、活性化サイズを保持します。

  • ダウンサンプリング残差ブロック — このブロックは、各スタック (最初のスタックを除く) の開始点に出現し、各スタックで 1 回だけ出現します。ダウンサンプリング ブロックの最初の畳み込みユニットは、係数 2 で空間次元をダウンサンプリングします。

各スタックの深さは異なる可能性があります。この例では、徐々に浅くなる 3 つのスタックを使用して残差ネットワークに学習させます。最初のスタックの深さは 4、2 番目のスタックの深さは 3、最後のスタックの深さは 2 です。

各残差ブロックには深層学習層が含まれています。各ブロックの層の詳細については、resnetLayersを参照してください。

イメージ分類に適した残差ネットワークを作成して学習を行うには、次の手順に従います。

  • 関数 resnetLayers を使用して残差ネットワークを作成します。

  • 関数 trainNetwork を使用してネットワークに学習させます。学習済みネットワークは DAGNetwork オブジェクトになります。

  • 関数 classify と関数 predict を使用して、新しいデータで分類と予測を実行します。

イメージ分類用の事前学習済み残差ネットワークを読み込むこともできます。詳細については、事前学習済みの深層ニューラル ネットワークを参照してください。

データの準備

CIFAR-10 データセット [1] をダウンロードします。このデータセットには 60,000 個のイメージが格納されています。各イメージのサイズは 32×32 ピクセルで 3 つのカラー チャネル (RGB) があります。データセットのサイズは 175 MB です。インターネット接続の速度によっては、ダウンロード プロセスに時間がかかることがあります。

datadir = tempdir; 
downloadCIFARData(datadir);

CIFAR-10 学習イメージとテスト イメージを 4 次元配列として読み込みます。学習セットには 50,000 個のイメージが格納されていて、テスト セットには 10,000 個のイメージが格納されています。CIFAR-10 テスト イメージをネットワークの検証用に使用します。

[XTrain,TTrain,XValidation,TValidation] = loadCIFARData(datadir);

次のコードを使用して、ランダムにサンプリングされた学習イメージを表示できます。

figure;
idx = randperm(size(XTrain,4),20);
im = imtile(XTrain(:,:,:,idx),ThumbnailSize=[96,96]);
imshow(im)

ネットワーク学習に使用する augmentedImageDatastore オブジェクトを作成します。学習中に、データストアは縦軸に沿って学習イメージをランダムに反転させ、水平方向および垂直方向に最大 4 ピクセルだけランダムに平行移動させます。データ拡張は、ネットワークで過適合が発生したり、学習イメージの正確な詳細が記憶されたりすることを防止するのに役立ちます。

imageSize = [32 32 3];
pixelRange = [-4 4];
imageAugmenter = imageDataAugmenter( ...
    RandXReflection=true, ...
    RandXTranslation=pixelRange, ...
    RandYTranslation=pixelRange);
augimdsTrain = augmentedImageDatastore(imageSize,XTrain,TTrain, ...
    DataAugmentation=imageAugmenter, ...
    OutputSizeMode="randcrop");

ネットワーク アーキテクチャの定義

関数 resnetLayers を使用して、このデータ セットに適した残差ネットワークを作成します。

  • CIFAR-10 イメージは 32×32 ピクセルであるため、初期ストライドを 1 にして、サイズ 3 の小規模な初期フィルターを使用します。初期フィルターの数を 16 に設定します。

  • ネットワークの最初のスタックは、初期残差ブロックで始まります。後続の各スタックは、ダウンサンプリング残差ブロックで始まります。ダウンサンプリング ブロックの最初の畳み込みユニットは、係数 2 で空間次元をダウンサンプリングします。ネットワーク全体で各畳み込み層に必要な計算量をほぼ同じに保つには、空間のダウンサンプリングを実行するたびに、フィルターの数を 2 倍ずつ増加させます。スタックの深さを [4 3 2] に、フィルターの数を [16 32 64] に設定します。

initialFilterSize = 3;
numInitialFilters = 16;
initialStride = 1;

numFilters = [16 32 64];
stackDepth = [4 3 2];

lgraph = resnetLayers(imageSize,10, ...
    InitialFilterSize=initialFilterSize, ...
    InitialNumFilters=numInitialFilters, ...
    InitialStride=initialStride, ...
    InitialPoolingLayer="none", ...
    StackDepth=[4 3 2], ... 
    NumFilters=[16 32 64]);

ネットワークを可視化します。

plot(lgraph);

学習オプション

学習オプションの指定。ネットワークの学習を 80 エポック行います。ミニバッチ サイズに比例する学習率を選択し、60 エポック後に学習率を 10 分の 1 に下げます。検証データを使用してエポックごとに 1 回ネットワークを検証します。

miniBatchSize = 128;
learnRate = 0.1*miniBatchSize/128;
valFrequency = floor(size(XTrain,4)/miniBatchSize);
options = trainingOptions("sgdm", ...
    InitialLearnRate=learnRate, ...
    MaxEpochs=80, ...
    MiniBatchSize=miniBatchSize, ...
    VerboseFrequency=valFrequency, ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Verbose=false, ...
    ValidationData={XValidation,TValidation}, ...
    ValidationFrequency=valFrequency, ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropFactor=0.1, ...
    LearnRateDropPeriod=60);

ネットワークの学習

trainNetwork を使用してネットワークに学習させるには、doTraining フラグを true に設定します。そうでない場合は、事前学習済みのネットワークを読み込みます。このネットワークの学習を高性能な GPU で行った場合、2 時間以上かかります。GPU がない場合、学習に長い時間がかかります。

doTraining = false;
if doTraining
    net = trainNetwork(augimdsTrain,lgraph,options);
else
    load("trainedResidualNetwork.mat","net");
end

学習済みネットワークの評価

学習セット (データ拡張なし) と検証セットに対するネットワークの最終精度を計算します。

[YValPred,probs] = classify(net,XValidation);
validationError = mean(YValPred ~= TValidation);
YTrainPred = classify(net,XTrain);
trainError = mean(YTrainPred ~= TTrain);
disp("Training error: " + trainError*100 + "%")
Training error: 3.462%
disp("Validation error: " + validationError*100 + "%")
Validation error: 9.27%

混同行列をプロットします。列と行の要約を使用して、各クラスの適合率と再現率を表示します。このネットワークは、猫と犬を混同することがよくあります。

figure(Units="normalized",Position=[0.2 0.2 0.4 0.4]);
cm = confusionchart(TValidation,YValPred);
cm.Title = "Confusion Matrix for Validation Data";
cm.ColumnSummary = "column-normalized";
cm.RowSummary = "row-normalized";

次のコードを使用して、ランダムにサンプリングされた 9 つのテスト イメージを、予測されたクラスとそのクラスである確率と共に表示できます。

figure
idx = randperm(size(XValidation,4),9);
for i = 1:numel(idx)
    subplot(3,3,i)
    imshow(XValidation(:,:,:,idx(i)));
    prob = num2str(100*max(probs(idx(i),:)),3);
    predClass = char(YValPred(idx(i)));
    title([predClass + ", " + prob + "%"])
end

参照

[1] Krizhevsky, Alex. "Learning multiple layers of features from tiny images." (2009). https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf

[2] He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning for image recognition." In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770-778. 2016.

参考

| | | | |

関連するトピック