イメージ分類用の残差ネットワークの学習
この例では、残差結合のある深層学習ニューラル ネットワークを作成し、CIFAR-10 データで学習を行う方法を説明します。残差結合は畳み込みニューラル ネットワーク アーキテクチャでよく使用される要素です。残差結合を使用すると、ネットワークを通じた勾配フローが改善し、より深いネットワークの学習が可能になります。
多くの用途では、層のシンプルなシーケンスで構成されるネットワークを使用するだけで十分です。ただし、用途によっては、各層に複数の層からの入力と複数の層への出力がある、より複雑なグラフ構造のネットワークが必要です。多くの場合、これらのタイプのネットワークは有向非循環グラフ (DAG) ネットワークと呼ばれます。残差ネットワーク (ResNet) は、メイン ネットワーク層をバイパスする残差 (またはショートカット) 結合のある DAG ネットワークの一種です。MATLAB では、DAG ネットワークは dlnetwork
オブジェクトで表されます。残差結合では、パラメーターの勾配がネットワークの出力層からより初期の層へとよりスムーズに伝播するため、更に深いネットワークに学習させることができます。このようにネットワークが深くなると、より難しいタスクで高い精度を実現できます。
ResNet アーキテクチャは、初期層、それに続く "残差ブロック" を含む "スタック"、および最終層で構成されています。次の 3 種類の残差ブロックがあります。
初期残差ブロック — このブロックは、最初のスタックの開始点に出現します。この例では、ボトルネック コンポーネントを使用しています。したがって、このブロックにはダウンサンプリング ブロックと同じ層が含まれますが、最初の畳み込み層のストライドは
[1,1]
のみです。詳細については、resnetNetwork
を参照してください。標準残差ブロック — このブロックは、各スタック内の最初のダウンサンプリング残差ブロックの後に出現します。このブロックは各スタックに複数回出現し、活性化サイズを保持します。
ダウンサンプリング残差ブロック — このブロックは、各スタック (最初のスタックを除く) の開始点に出現し、各スタックで 1 回だけ出現します。ダウンサンプリング ブロックの最初の畳み込みユニットは、係数 2 で空間次元をダウンサンプリングします。
各スタックの深さは異なる可能性があります。この例では、徐々に浅くなる 3 つのスタックを使用して残差ネットワークに学習させます。最初のスタックの深さは 4、2 番目のスタックの深さは 3、最後のスタックの深さは 2 です。
各残差ブロックには深層学習層が含まれています。各ブロックの層の詳細については、resnetNetwork
を参照してください。
イメージ分類に適した残差ネットワークを作成して学習を行うには、次の手順に従います。
関数
resnetNetwork
を使用して残差ネットワークを作成します。関数
trainnet
を使用してネットワークに学習させます。学習済みネットワークはdlnetwork
オブジェクトになります。新しいデータで分類と予測を実行します。
イメージ分類用の事前学習済み残差ネットワークを読み込むこともできます。詳細については、事前学習済みの深層ニューラル ネットワークを参照してください。
データの準備
CIFAR-10 データ セット [1] をダウンロードします。このデータ セットには 60,000 個のイメージが格納されています。各イメージのサイズは 32×32 ピクセルで 3 つのカラー チャネル (RGB) があります。データ セットのサイズは 175 MB です。インターネット接続の速度によっては、ダウンロード プロセスに時間がかかることがあります。
datadir = tempdir; downloadCIFARData(datadir);
Downloading CIFAR-10 dataset (175 MB). This can take a while...done.
CIFAR-10 学習イメージとテスト イメージを 4 次元配列として読み込みます。学習セットには 50,000 個のイメージが格納されていて、テスト セットには 10,000 個のイメージが格納されています。CIFAR-10 テスト イメージをネットワークの検証用に使用します。
[XTrain,TTrain,XValidation,TValidation] = loadCIFARData(datadir);
次のコードを使用して、ランダムにサンプリングされた学習イメージを表示できます。
figure; idx = randperm(size(XTrain,4),20); im = imtile(XTrain(:,:,:,idx),ThumbnailSize=[96,96]); imshow(im)
ネットワーク学習に使用する augmentedImageDatastore
オブジェクトを作成します。学習中に、データストアは縦軸に沿って学習イメージをランダムに反転させ、水平方向および垂直方向に最大 4 ピクセルだけランダムに平行移動させます。データ拡張は、ネットワークで過適合が発生したり、学習イメージの正確な詳細が記憶されたりすることを防止するのに役立ちます。
imageSize = [32 32 3]; pixelRange = [-4 4]; imageAugmenter = imageDataAugmenter( ... RandXReflection=true, ... RandXTranslation=pixelRange, ... RandYTranslation=pixelRange); augimdsTrain = augmentedImageDatastore(imageSize,XTrain,TTrain, ... DataAugmentation=imageAugmenter, ... OutputSizeMode="randcrop");
ネットワーク アーキテクチャの定義
関数 resnetNetwork
を使用して、このデータ セットに適した残差ネットワークを作成します。
CIFAR-10 イメージは 32×32 ピクセルであるため、初期ストライドを 1 にして、サイズ 3 の小規模な初期フィルターを使用します。初期フィルターの数を 16 に設定します。
ネットワークの最初のスタックは、初期残差ブロックで始まります。後続の各スタックは、ダウンサンプリング残差ブロックで始まります。ダウンサンプリング ブロックの最初の畳み込みユニットは、係数 2 で空間次元をダウンサンプリングします。ネットワーク全体で各畳み込み層に必要な計算量をほぼ同じに保つには、空間のダウンサンプリングを実行するたびに、フィルターの数を 2 倍ずつ増加させます。スタックの深さを
[4 3 2]
に、フィルターの数を[16 32 64]
に設定します。
initialFilterSize = 3; numInitialFilters = 16; initialStride = 1; numFilters = [16 32 64]; stackDepth = [4 3 2];
2 次元残差ネットワークを作成します。
net = resnetNetwork(imageSize,10, ... InitialFilterSize=initialFilterSize, ... InitialNumFilters=numInitialFilters, ... InitialStride=initialStride, ... InitialPoolingLayer="none", ... StackDepth=[4 3 2], ... NumFilters=[16 32 64]);
ネットワークを可視化します。
plot(net);
学習オプション
学習オプションを指定します。ネットワークの学習を 80 エポック行います。ミニバッチ サイズに比例する学習率を選択し、60 エポック後に学習率を 10 分の 1 に下げます。検証データを使用してエポックごとに 1 回ネットワークを検証します。
miniBatchSize = 128; learnRate = 0.1*miniBatchSize/128; valFrequency = floor(size(XTrain,4)/miniBatchSize); options = trainingOptions("sgdm", ... InitialLearnRate=learnRate, ... MaxEpochs=80, ... MiniBatchSize=miniBatchSize, ... VerboseFrequency=valFrequency, ... Shuffle="every-epoch", ... Plots="training-progress", ... Verbose=false, ... ValidationData={XValidation,TValidation}, ... ValidationFrequency=valFrequency, ... LearnRateSchedule="piecewise", ... LearnRateDropFactor=0.1, ... LearnRateDropPeriod=60);
ネットワークの学習
trainnet
を使用してネットワークに学習させるには、doTraining
フラグを true
に設定します。分類には、クロスエントロピー損失を使用します。既定では、関数 trainnet
は利用可能な GPU がある場合にそれを使用します。GPU での学習には、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数 trainnet
は CPU を使用します。実行環境を指定するには、ExecutionEnvironment
学習オプションを使用します。
そうでない場合は、事前学習済みのネットワークを読み込みます。
doTraining = false; if doTraining net = trainnet(augimdsTrain,net,'crossentropy',options); else load("trainedResidualNetwork.mat","net"); end
学習済みネットワークの評価
学習セット (データ拡張なし) と検証セットに対するネットワークの最終精度を計算します。複数の観測値を使用して予測を行うには、関数 minibatchpredict
を使用します。予測スコアをラベルに変換するには、関数 scores2label
を使用します。関数 minibatchpredict
は利用可能な GPU がある場合に自動的にそれを使用します。GPU を使用するには、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数は CPU を使用します。
scores = minibatchpredict(net,XValidation); [YValPred,probs] = scores2label(scores,categories(TValidation)); validationError = mean(YValPred ~= TValidation); scores = minibatchpredict(net,XTrain); YTrainPred = scores2label(scores,categories(TTrain)); trainError = mean(YTrainPred ~= TTrain); disp("Training error: " + trainError*100 + "%")
Training error: 4.168%
disp("Validation error: " + validationError*100 + "%")
Validation error: 9.13%
混同行列をプロットします。列と行の要約を使用して、各クラスの適合率と再現率を表示します。このネットワークは、猫と犬を混同することがよくあります。
figure(Units="normalized",Position=[0.2 0.2 0.4 0.4]); cm = confusionchart(TValidation,YValPred); cm.Title = "Confusion Matrix for Validation Data"; cm.ColumnSummary = "column-normalized"; cm.RowSummary = "row-normalized";
次のコードを使用して、ランダムにサンプリングされた 9 つのテスト イメージを、予測されたクラスとそのクラスである確率と共に表示できます。
figure idx = randperm(size(XValidation,4),9); for i = 1:numel(idx) subplot(3,3,i) imshow(XValidation(:,:,:,idx(i))); prob = num2str(100*max(probs(idx(i),:)),3); predClass = char(YValPred(idx(i))); title([predClass + ", " + prob + "%"]) end
参照
[1] Krizhevsky, Alex. "Learning multiple layers of features from tiny images." (2009). https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf
[2] He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning for image recognition." In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770-778. 2016.
参考
trainnet
| trainingOptions
| dlnetwork
| resnetNetwork
| resnet3dNetwork
| analyzeNetwork