入れ子層をもつ深層学習ネットワークの学習
この例では、入れ子層をもつネットワークに学習させる方法を説明します。
それ自体が層グラフを定義するカスタム層を作成するには、学習可能パラメーターとして dlnetwork
オブジェクトを指定できます。この手法は "ネットワーク構成" と呼ばれます。以下の場合にネットワーク構成を使用できます。
学習可能な層のブロックを表す単一のカスタム層 (残差ブロックなど) の作成。
コントロール フローをもつネットワークの作成。たとえば、入力データに応じて動的に変更できるセクションをもつネットワーク。
ループをもつネットワークの作成。たとえば、自分自身に出力をフィードバックするセクションをもつネットワーク。
詳細については、深層学習のネットワーク構成を参照してください。
この例では、複数の畳み込み層、バッチ正規化層、ReLU 層、および 1 つのスキップ接続で構成される残差ブロックを表すカスタム層を使用し、ネットワークに学習させる方法を説明します。このユース ケースの場合、通常、入れ子なしの層グラフを使用するほうが簡単です。カスタム層を使用 "せずに" 残差ネットワークを作成する方法を示す例については、イメージ分類用の残差ネットワークの学習を参照してください。
残差結合は畳み込みニューラル ネットワーク アーキテクチャでよく使用される要素です。残差ネットワークは、メイン ネットワーク層をバイパスする残差 (またはショートカット) 結合のあるネットワークの一種です。残差結合を使用すると、ネットワークを通じた勾配フローが改善し、より深いネットワークの学習が可能になります。このようにネットワークを深くすることで、より難しいタスクで高い精度を得ることができます。
この例では、畳み込み層、バッチ正規化層、ReLU 層、加算層から成る学習可能な層ブロックで構成され、スキップ接続も含むカスタム層 residualBlockLayer
を使用しています。スキップ接続には、オプションとして畳み込み層とバッチ正規化層が含まれることがあります。次の図に、残差ブロックの構造を示します。
カスタム層 residualBlockLayer
を作成する方法を示す例については、入れ子になった深層学習層の定義を参照してください。
データの準備
Flowers のデータセット [1] をダウンロードし、解凍します。
url = "http://download.tensorflow.org/example_images/flower_photos.tgz"; downloadFolder = tempdir; filename = fullfile(downloadFolder,"flower_dataset.tgz"); imageFolder = fullfile(downloadFolder,"flower_photos"); if ~datasetExists(imageFolder) disp("Downloading Flowers data set (218 MB)...") websave(filename,url); untar(filename,downloadFolder) end
写真のイメージ データストアを作成します。
datasetFolder = fullfile(imageFolder); imds = imageDatastore(datasetFolder, ... IncludeSubfolders=true, ... LabelSource="foldernames");
データを学習データ セットと検証データ セットに分割します。イメージの 70% を学習に使用し、30% を検証に使用します。
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,"randomized");
データ セットのクラスの数を表示します。
classes = categories(imds.Labels); numClasses = numel(classes)
numClasses = 5
データ拡張は、ネットワークで過適合が発生したり、学習イメージの正確な詳細が記憶されたりすることを防止するのに役立ちます。imageDataAugmenter
オブジェクトを使用し、学習イメージのサイズを変更して拡張します。
縦軸方向にイメージをランダムに反転します。
垂直方向および水平方向に最大 30 ピクセルまでイメージをランダムに平行移動します。
時計回りおよび反時計回りに最大 45 度までイメージをランダムに回転します。
垂直方向および水平方向に最大 10% までイメージをランダムにスケーリングします。
pixelRange = [-30 30]; scaleRange = [0.9 1.1]; imageAugmenter = imageDataAugmenter( ... RandXReflection=true, ... RandXTranslation=pixelRange, ... RandYTranslation=pixelRange, ... RandRotation=[-45 45], ... RandXScale=scaleRange, ... RandYScale=scaleRange);
イメージ データ オーグメンターを使用して、学習データが格納された拡張イメージ データストアを作成します。ネットワークの入力サイズに合わせてイメージのサイズを自動的に変更するには、ネットワークの入力サイズの高さと幅を指定します。この例では、入力サイズが [224 224 3]
のネットワークを使用します。
inputSize = [224 224 3]; augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,DataAugmentation=imageAugmenter);
他のデータ拡張を実行せずに検証イメージのサイズを自動的に変更するには、追加の前処理演算を指定せずに拡張イメージ データストアを使用します。
augimdsValidation = augmentedImageDatastore([224 224],imdsValidation);
ネットワーク アーキテクチャの定義
カスタム層 residualBlockLayer
を使用して、6 つの残差ブロックをもつ残差ネットワークを定義します。この層にアクセスするには、この例をライブ スクリプトとして開きます。このカスタム層を作成する方法を示す例については、入れ子になった深層学習層の定義を参照してください。
dlnetwork
オブジェクトの入力層の入力サイズを指定しなければならないため、この層を作成するときに入力サイズを指定しなければなりません。層の入力サイズを判断するために、関数 analyzeNetwork
を使用して前の層の活性化サイズをチェックできます。
numFilters = 32;
layers = [
imageInputLayer(inputSize)
convolution2dLayer(7,numFilters,Stride=2,Padding="same")
batchNormalizationLayer
reluLayer
maxPooling2dLayer(3,Stride=2)
residualBlockLayer(numFilters)
residualBlockLayer(numFilters)
residualBlockLayer(2*numFilters,Stride=2,IncludeSkipConvolution=true)
residualBlockLayer(2*numFilters)
residualBlockLayer(4*numFilters,Stride=2,IncludeSkipConvolution=true)
residualBlockLayer(4*numFilters)
globalAveragePooling2dLayer
fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer]
layers = 15×1 Layer array with layers: 1 '' Image Input 224×224×3 images with 'zerocenter' normalization 2 '' 2-D Convolution 32 7×7 convolutions with stride [2 2] and padding 'same' 3 '' Batch Normalization Batch normalization 4 '' ReLU ReLU 5 '' 2-D Max Pooling 3×3 max pooling with stride [2 2] and padding [0 0 0 0] 6 '' Residual Block Residual block with 32 filters, stride 1 7 '' Residual Block Residual block with 32 filters, stride 1 8 '' Residual Block Residual block with 64 filters, stride 2, and skip convolution 9 '' Residual Block Residual block with 64 filters, stride 1 10 '' Residual Block Residual block with 128 filters, stride 2, and skip convolution 11 '' Residual Block Residual block with 128 filters, stride 1 12 '' 2-D Global Average Pooling 2-D global average pooling 13 '' Fully Connected 5 fully connected layer 14 '' Softmax softmax 15 '' Classification Output crossentropyex
ネットワークの学習
学習オプションを次のように指定します。
ミニバッチ サイズを 128 としてネットワークに学習させます。
すべてのエポックでデータをシャッフルします。
検証データを使用してエポックごとに 1 回ネットワークを検証します。
検証損失が最も少ないネットワークを出力します。
プロットに学習の進行状況を表示し、詳細出力を無効にします。
miniBatchSize = 128; numIterationsPerEpoch = floor(augimdsTrain.NumObservations/miniBatchSize); options = trainingOptions("adam", ... MiniBatchSize=miniBatchSize, ... Shuffle="every-epoch", ... ValidationData=augimdsValidation, ... ValidationFrequency=numIterationsPerEpoch, ... OutputNetwork="best-validation-loss", ... Plots="training-progress", ... Verbose=false);
関数 trainNetwork
を使用してネットワークに学習させます。既定で、trainNetwork
は、使用可能な GPU があれば GPU を使用し、なければ CPU を使用します。GPU で学習を行うには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。trainingOptions
の ExecutionEnvironment
オプションを使用して、実行環境を指定することもできます。
net = trainNetwork(augimdsTrain,layers,options);
学習済みネットワークの評価
学習セット (データ拡張なし) と検証セットに対するネットワークの最終精度を計算します。精度は、ネットワークによって正しく分類されるイメージの比率です。
YPred = classify(net,augimdsValidation); YValidation = imdsValidation.Labels; accuracy = mean(YPred == YValidation)
accuracy = 0.7175
混同行列として分類精度を可視化します。列と行の要約を使用して、各クラスの適合率と再現率を表示します。
figure confusionchart(YValidation,YPred, ... RowSummary="row-normalized", ... ColumnSummary="column-normalized");
次のコードを使用すると、予測ラベル、およびイメージがそれらのラベルをもつ予測確率と共に、4 個のサンプル検証イメージを表示できます。
idx = randperm(numel(imdsValidation.Files),4); figure for i = 1:4 subplot(2,2,i) I = readimage(imdsValidation,idx(i)); imshow(I) label = YPred(idx(i)); title("Predicted class: " + string(label)); end
参考文献
The TensorFlow Team.Flowers http://download.tensorflow.org/example_images/flower_photos.tgz
参考
checkLayer
| trainNetwork
| trainingOptions
| analyzeNetwork
| dlnetwork