チェックポイント ネットワークからの学習の再開
この例では、深層学習ネットワークの学習時にチェックポイント ネットワークを保存する方法と以前に保存したネットワークから学習を再開する方法を説明します。
サンプル データの読み込み
サンプル データを 4 次元配列として読み込みます。digitTrain4DArrayData は数字の学習セットを 4 次元配列データとして読み込みます。XTrain は 28×28×1×5000 の配列で、28 はイメージの高さ、28 は幅です。1 はチャネルの数で、5000 は手書きの数字の合成イメージの数です。YTrain は各観測値のラベルを含む categorical ベクトルです。
[XTrain,YTrain] = digitTrain4DArrayData; size(XTrain)
ans = 1×4
28 28 1 5000
XTrain のイメージをいくつか表示します。
figure; perm = randperm(size(XTrain,4),20); for i = 1:20 subplot(4,5,i); imshow(XTrain(:,:,:,perm(i))); end

ネットワーク アーキテクチャの定義
ニューラル ネットワーク アーキテクチャを定義します。
layers = [
imageInputLayer([28 28 1])
convolution2dLayer(3,8,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
averagePooling2dLayer(7)
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];学習オプションの指定とネットワークの学習
モーメンタム項付き確率的勾配降下法 (SGDM) の学習オプションを指定し、チェックポイント ネットワークを保存するためのパスを指定します。
checkpointPath = pwd; options = trainingOptions('sgdm', ... 'InitialLearnRate',0.1, ... 'MaxEpochs',20, ... 'Verbose',false, ... 'Plots','training-progress', ... 'Shuffle','every-epoch', ... 'CheckpointPath',checkpointPath);
ネットワークに学習させます。trainNetwork は、利用可能なものがある場合は GPU を使用します。利用可能な GPU がない場合、CPU を使用します。trainNetwork は、エポックごとに 1 つのチェックポイント ネットワークを保存し、チェックポイント ファイルに一意の名前を自動的に割り当てます。
net1 = trainNetwork(XTrain,YTrain,layers,options);

チェックポイント ネットワークの読み込みと学習の再開
学習が中断され、完了しなかったとします。学習を最初からやり直すのではなく、最後のチェックポイント ネットワークを読み込んで、その時点から学習を再開できます。trainNetwork は、net_checkpoint__195__2018_07_13__11_59_10.mat の形式のファイル名でチェックポイント ファイルを保存します。ここで、195 は反復回数、2018_07_13 は trainNetwork がネットワークを保存した日付、11_59_10 はその時刻です。チェックポイント ネットワークの変数名は net です。
チェックポイント ネットワークをワークスペースに読み込みます。
load('net_checkpoint__195__2018_07_13__11_59_10.mat','net')
学習オプションを指定して、最大エポック数を減らします。初期学習率などのその他の学習オプションも調整できます。
options = trainingOptions('sgdm', ... 'InitialLearnRate',0.1, ... 'MaxEpochs',15, ... 'Verbose',false, ... 'Plots','training-progress', ... 'Shuffle','every-epoch', ... 'CheckpointPath',checkpointPath);
新しい学習オプションで読み込んだチェックポイント ネットワークの層を使用して、学習を再開します。チェックポイント ネットワークが DAG ネットワークの場合、net.Layers の代わりに layerGraph(net) を引数として使用します。
net2 = trainNetwork(XTrain,YTrain,net.Layers,options);

参考
trainnet | trainingOptions | dlnetwork