メインコンテンツ

チェックポイント ネットワークからの学習の再開

この例では、深層学習ネットワークの学習時にチェックポイント ネットワークを保存する方法と以前に保存したネットワークから学習を再開する方法を説明します。

サンプル データの読み込み

サンプル データを 4 次元配列として読み込みます。digitTrain4DArrayData は数字の学習セットを 4 次元配列データとして読み込みます。XTrain は 28×28×1×5000 の配列で、28 はイメージの高さ、28 は幅です。1 はチャネルの数で、5000 は手書きの数字の合成イメージの数です。YTrain は各観測値のラベルを含む categorical ベクトルです。

[XTrain,YTrain] = digitTrain4DArrayData;
size(XTrain)
ans = 1×4

          28          28           1        5000

XTrain のイメージをいくつか表示します。

figure
idx = randperm(size(XTrain,4),49);
I = imtile(XTrain(:,:,:,idx));
imshow(I)

Figure contains an axes object. The hidden axes object contains an object of type image.

ネットワーク アーキテクチャの定義

ニューラル ネットワーク アーキテクチャを定義します。

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,Padding="same")
    batchNormalizationLayer
    reluLayer    
    maxPooling2dLayer(2,Stride=2) 
    
    convolution2dLayer(3,16,Padding="same")
    batchNormalizationLayer
    reluLayer    
    maxPooling2dLayer(2,Stride=2)
    
    convolution2dLayer(3,32,Padding="same")
    batchNormalizationLayer
    reluLayer   
    averagePooling2dLayer(7)  
    
    fullyConnectedLayer(10)
    softmaxLayer];

学習オプションの指定とネットワークの学習

モーメンタム項付き確率的勾配降下法 (SGDM) の学習オプションを指定し、チェックポイント ネットワークを保存するためのパスを指定します。

checkpointPath = pwd;
options = trainingOptions("sgdm", ...
    InitialLearnRate=0.1, ...
    MaxEpochs=20, ...
    Verbose=false, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Shuffle="every-epoch", ...
    CheckpointPath=checkpointPath);

関数trainnetを使用してニューラル ネットワークに学習させます。trainnet 関数は、エポックごとに 1 つのチェックポイント ネットワークを保存し、チェックポイント ファイルに一意の名前を自動的に割り当てます。分類には、クロスエントロピー損失を使用します。既定では、関数 trainnet は利用可能な GPU がある場合にそれを使用します。GPU での学習には、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスの詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数 trainnet は CPU を使用します。実行環境を手動で選択するには、ExecutionEnvironment 学習オプションを使用します。

net1 = trainnet(XTrain,YTrain,layers,"crossentropy",options);

チェックポイント ネットワークの読み込みと学習の再開

学習が中断され、完了しなかったとします。学習を最初から再開するのではなく、最後のチェックポイント ネットワークを読み込んで、その時点から学習を再開できます。trainnet 関数は、net_checkpoint__195__2025_05_01__10_24_32.mat の形式のファイル名でチェックポイント ファイルを保存します。ここで、195 は反復回数、2025_05_01trainnet がネットワークを保存した日付、10_24_32 はその時刻です。チェックポイント ネットワークの変数名は net です。

チェックポイント ネットワークをワークスペースに読み込みます。

load("net_checkpoint__195__2025_05_01__10_24_32.mat","net")

学習オプションを指定して、最大エポック数を減らします。初期学習率などのその他の学習オプションも調整できます。

options = trainingOptions("sgdm", ...
    InitialLearnRate=0.1, ...
    MaxEpochs=15, ...
    Verbose=false, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Shuffle="every-epoch", ...
    CheckpointPath=checkpointPath);

新しい学習オプションで読み込んだチェックポイント ネットワークの層を使用して、学習を再開します。

net2 = trainnet(XTrain,YTrain,net,"crossentropy",options);

参考

| |

トピック