ドキュメンテーション

最新のリリースでは、このページがまだ翻訳されていません。 このページの最新版は英語でご覧になれます。

チェックポイント ネットワークからの学習の再開

この例では、深層学習ネットワークの学習時にチェックポイント ネットワークを保存する方法と以前に保存したネットワークから学習を再開する方法を説明します。

標本データの読み込み

標本データを 4 次元配列として読み込みます。digitTrain4DArrayData は数字の学習セットを 4 次元配列データとして読み込みます。XTrain は 28 x 28 x 1 x 5000 の配列で、28 はイメージの高さ、28 は幅です。1 はチャネルの数で、5000 は手書きの数字の合成イメージの数です。YTrain は各観測値のラベルを含む categorical ベクトルです。

[XTrain,YTrain] = digitTrain4DArrayData;
size(XTrain)
ans = 1×4

          28          28           1        5000

XTrain のイメージをいくつか表示します。

figure;
perm = randperm(size(XTrain,4),20);
for i = 1:20
    subplot(4,5,i);
    imshow(XTrain(:,:,:,perm(i)));
end

ネットワーク アーキテクチャの定義

ニューラル ネットワーク アーキテクチャを定義します。

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer    
    maxPooling2dLayer(2,'Stride',2) 
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    averagePooling2dLayer(7)  
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

学習オプションの指定とネットワークの学習

モーメンタム項付き確率的勾配降下法 (SGDM) の学習オプションを指定し、チェックポイント ネットワークを保存するためのパスを指定します。

checkpointPath = pwd;
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.1, ...
    'MaxEpochs',20, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'Shuffle','every-epoch', ...
    'CheckpointPath',checkpointPath);

ネットワークに学習させます。trainNetwork は、利用可能なものがある場合は GPU を使用します。利用可能な GPU がない場合、CPU を使用します。trainNetwork は、エポックごとに 1 つのチェックポイント ネットワークを保存し、チェックポイント ファイルに一意の名前を自動的に割り当てます。

net1 = trainNetwork(XTrain,YTrain,layers,options);

チェックポイント ネットワークの読み込みと学習の再開

学習が中断され、完了しなかったとします。学習を最初からやり直すのではなく、最後のチェックポイント ネットワークを読み込んで、その時点から学習を再開できます。trainNetwork は、net_checkpoint__195__2018_07_13__11_59_10.mat の形式のファイル名でチェックポイント ファイルを保存します。ここで、195 は反復回数、2018_07_13trainNetwork がネットワークを保存した日付、11_59_10 はその時刻です。チェックポイント ネットワークの変数名は net です。

チェックポイント ネットワークをワークスペースに読み込みます。

load('net_checkpoint__195__2018_07_13__11_59_10.mat','net')

学習オプションを指定して、エポックの最大回数を減らします。初期学習率などのその他の学習オプションも調整できます。

options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.1, ...
    'MaxEpochs',15, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'Shuffle','every-epoch', ...
    'CheckpointPath',checkpointPath);

新しい学習オプションで読み込んだチェックポイント ネットワークの層を使用して、学習を再開します。チェックポイント ネットワークが DAG ネットワークの場合、net.Layers の代わりに layerGraph(net) を引数として使用します。

net2 = trainNetwork(XTrain,YTrain,net.Layers,options);

参考

|

関連する例

詳細