Main Content

ニューラル ネットワーク学習時のチェックポイントの自動保存

コンピューターで障害が発生したときや学習プロセスを強制終了した場合に復元できるように、ニューラル ネットワークの学習時に中間結果を定期的に MAT ファイルに保存できます。これにより、長時間にわたる学習の実行の成果を保護することができ、中断された場合に最初からやり直す必要がなくなります。この機能は、計算リソースの障害で中断される可能性が高い、長時間の並列学習セッションの場合に特に便利です。

チェックポイントの保存は、オプションの学習引数 'CheckpointFile' に続けてチェックポイント ファイル名またはパスを指定することで有効になります。ファイル名のみを指定した場合、既定ではそのファイルが作業ディレクトリに保存されます。ファイルには、ファイル拡張子 .mat を付けなければなりませんが、指定しなかった場合は自動的に追加されます。この例では、チェックポイントが現在の作業ディレクトリの MyCheckpoint.mat というファイルに保存されます。

[x,t] = bodyfat_dataset;
net = feedforwardnet(10);
net2 = train(net,x,t,'CheckpointFile','MyCheckpoint.mat');
22-Mar-2013 04:49:05 First Checkpoint #1: /WorkingDir/MyCheckpoint.mat
22-Mar-2013 04:49:06 Final Checkpoint #2: /WorkingDir/MyCheckpoint.mat

既定では、チェックポイントの保存は最高で 60 秒に 1 回行われます。前述の短い学習の例では、学習の開始時と終了時に 1 つずつ、2 つのチェックポイントのみが保存されます。

オプションの学習引数 'CheckpointDelay' を使用して、保存の頻度を変更できます。たとえば、ここでは浮上磁石をモデル化するためにニューラル ネットワークに学習させる時系列問題に対し、最小のチェックポイント遅延を 10 秒に設定しています。

[x,t] = maglev_dataset;
net = narxnet(1:2,1:2,10);
[X,Xi,Ai,T] = preparets(net,x,{},t);
net2 = train(net,X,T,Xi,Ai,'CheckpointFile','MyCheckpoint.mat','CheckpointDelay',10);
22-Mar-2013 04:59:28 First Checkpoint #1: /WorkingDir/MyCheckpoint.mat
22-Mar-2013 04:59:38 Write Checkpoint #2: /WorkingDir/MyCheckpoint.mat
22-Mar-2013 04:59:48 Write Checkpoint #3: /WorkingDir/MyCheckpoint.mat
22-Mar-2013 04:59:58 Write Checkpoint #4: /WorkingDir/MyCheckpoint.mat
22-Mar-2013 05:00:08 Write Checkpoint #5: /WorkingDir/MyCheckpoint.mat
22-Mar-2013 05:00:09 Final Checkpoint #6: /WorkingDir/MyCheckpoint.mat

コンピューターの障害や学習の中断が発生した後で、中断前に取得した最適なニューラル ネットワークを含むチェックポイント構造体と学習記録を再読み込みできます。この場合、stage フィールドの値は 'Final' になっており、学習が正常に終了したので最終エポックで最後の保存が行われたことが示されています。最初のエポックのチェックポイントは 'First' で示され、中間チェックポイントは 'Write' で示されます。

load('MyCheckpoint.mat')
checkpoint = 

      file: '/WorkdingDir/MyCheckpoint.mat'
      time: [2013 3 22 5 0 9.0712]
    number: 6
     stage: 'Final'
       net: [1x1 network]
        tr: [1x1 struct]

データセットを再読み込みし (必要な場合)、復元したネットワークを使用して学習を呼び出すと、最後のチェックポイントから学習を再開できます。

net = checkpoint.net;
[x,t] = maglev_dataset;
load('MyCheckpoint.mat');
[X,Xi,Ai,T] = preparets(net,x,{},t);
net2 = train(net,X,T,Xi,Ai,'CheckpointFile','MyCheckpoint.mat','CheckpointDelay',10);