チェックポイント ネットワークからの学習の再開
この例では、深層学習ネットワークの学習時にチェックポイント ネットワークを保存する方法と以前に保存したネットワークから学習を再開する方法を説明します。
標本データの読み込み
標本データを 4 次元配列として読み込みます。digitTrain4DArrayData
は数字の学習セットを 4 次元配列データとして読み込みます。XTrain
は 28 x 28 x 1 x 5000 の配列で、28 はイメージの高さ、28 は幅です。1 はチャネルの数で、5000 は手書きの数字の合成イメージの数です。YTrain
は各観測値のラベルを含む categorical ベクトルです。
[XTrain,YTrain] = digitTrain4DArrayData; size(XTrain)
ans = 1×4
28 28 1 5000
XTrain
のイメージをいくつか表示します。
figure; perm = randperm(size(XTrain,4),20); for i = 1:20 subplot(4,5,i); imshow(XTrain(:,:,:,perm(i))); end
ネットワーク アーキテクチャの定義
ニューラル ネットワーク アーキテクチャを定義します。
layers = [ imageInputLayer([28 28 1]) convolution2dLayer(3,8,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(2,'Stride',2) convolution2dLayer(3,16,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(2,'Stride',2) convolution2dLayer(3,32,'Padding','same') batchNormalizationLayer reluLayer averagePooling2dLayer(7) fullyConnectedLayer(10) softmaxLayer classificationLayer];
学習オプションの指定とネットワークの学習
モーメンタム項付き確率的勾配降下法 (SGDM) の学習オプションを指定し、チェックポイント ネットワークを保存するためのパスを指定します。
checkpointPath = pwd; options = trainingOptions('sgdm', ... 'InitialLearnRate',0.1, ... 'MaxEpochs',20, ... 'Verbose',false, ... 'Plots','training-progress', ... 'Shuffle','every-epoch', ... 'CheckpointPath',checkpointPath);
ネットワークに学習させます。trainNetwork
は、利用可能なものがある場合は GPU を使用します。利用可能な GPU がない場合、CPU を使用します。trainNetwork
は、エポックごとに 1 つのチェックポイント ネットワークを保存し、チェックポイント ファイルに一意の名前を自動的に割り当てます。
net1 = trainNetwork(XTrain,YTrain,layers,options);
チェックポイント ネットワークの読み込みと学習の再開
学習が中断され、完了しなかったとします。学習を最初からやり直すのではなく、最後のチェックポイント ネットワークを読み込んで、その時点から学習を再開できます。trainNetwork
は、net_checkpoint__195__2018_07_13__11_59_10.mat
の形式のファイル名でチェックポイント ファイルを保存します。ここで、195 は反復回数、2018_07_13
は trainNetwork
がネットワークを保存した日付、11_59_10
はその時刻です。チェックポイント ネットワークの変数名は net
です。
チェックポイント ネットワークをワークスペースに読み込みます。
load('net_checkpoint__195__2018_07_13__11_59_10.mat','net')
学習オプションを指定して、エポックの最大回数を減らします。初期学習率などのその他の学習オプションも調整できます。
options = trainingOptions('sgdm', ... 'InitialLearnRate',0.1, ... 'MaxEpochs',15, ... 'Verbose',false, ... 'Plots','training-progress', ... 'Shuffle','every-epoch', ... 'CheckpointPath',checkpointPath);
新しい学習オプションで読み込んだチェックポイント ネットワークの層を使用して、学習を再開します。チェックポイント ネットワークが DAG ネットワークの場合、net.Layers
の代わりに layerGraph(net)
を引数として使用します。
net2 = trainNetwork(XTrain,YTrain,net.Layers,options);
参考
trainingOptions
| trainNetwork