メインコンテンツ

カスタム学習ループの進行状況の監視

深層学習のためにネットワークに学習させる場合、学習の進行状況を監視すると役立つことがよくあります。学習中にさまざまなメトリクスをプロットすることで、学習がどのように進んでいるかを知ることができます。たとえば、ネットワークの精度が改善されているかどうか、その改善の速度、さらにネットワークで学習データへの過適合が始まっているかどうかを判定できます。

カスタム学習ループの学習の進行状況を監視およびプロットするには、TrainingProgressMonitor オブジェクトを使用します。TrainingProgressMonitor を使用すると、次のことができます。

  • アニメーション化されたカスタム メトリクス プロットを作成し、学習時にカスタム メトリクスを記録する。

  • 学習時に学習情報を表示および記録する。

  • 早期に学習を停止する。

  • 進行状況バーで学習の進行状況を追跡する。

  • 経過時間を追跡する。

trainnet 関数を使用し、Plots 学習オプションを "training-progress" に設定してネットワークに学習させる場合、ソフトウェアは学習中にメトリクスを自動的にプロットします。詳細については、深層学習における学習の進行状況の監視を参照してください。

学習の進行状況モニターの作成

trainingProgressMonitor 関数を使用して、カスタム学習進行状況モニターを作成します。

monitor = trainingProgressMonitor;

TrainingProgressMonitor オブジェクトは経過時間を自動的に追跡します。TrainingProgressMonitor オブジェクトを作成するとタイマーが開始されます。経過時間が学習時間を正確に反映するようにするには、学習ループの開始直前にモニター オブジェクトを作成します。

[学習の進行状況] ウィンドウ

TrainingProgressMonitor オブジェクトのプロパティを使用して、学習の進行状況ウィンドウの表示を制御します。学習前、学習中、または学習後にプロパティを設定することができます。モニターを使用して学習の進行状況を追跡する方法の例については、学習中のカスタム学習ループの進行状況の監視を参照してください。

[学習の進行状況] ウィンドウの例キープロパティと設定コード例

Training Progress window showing two plots with numbers highlighting parts of the window. The first plot shows the training and validation loss and the second plot shows the training and validation accuracy.

この Figure を生成する方法の例については、学習中のカスタム学習ループの進行状況の監視を参照してください。

1

Metrics プロパティを使用して、プロットするメトリクスを指定します。

学習前にメトリクス名を追加します。

monitor.Metrics = ["TrainingLoss","ValidationLoss"];

プロットに新しい点を追加し、recordMetrics を使用して MetricValues プロパティに値を保存します。recordMetrics 関数には、メトリクス値と、反復やエポックなどの学習ループ ステップが必要です。学習プロットにおいて、メトリクス値は y 座標に対応し、学習ループ ステップは x 座標に対応します。

学習中にメトリクス値を更新します。

recordMetrics(monitor,iteration, ...
TrainingLoss=lossTrain, ...
ValidationLoss=lossValidation);

2

XLabel プロパティを使用して x 軸ラベルを設定します。

x 軸ラベルを Iteration に設定します。

monitor.XLabel = "Iteration";

3

groupSubPlot 関数を使用して、メトリクスを単一の学習サブプロットにグループ化します。

学習精度のプロットと検証精度のプロットをグループ化します。

groupSubPlot(monitor, ...
    "Accuracy",["TrainingAccuracy","ValidationAccuracy"]);

4

Progress プロパティを使用して学習の進行状況を追跡します。進捗値は範囲 [0,100] 内の数値でなければなりません。この値は、[学習の進行状況] ウィンドウ内の進行状況バーとして表示されます。

学習の進行率を設定します。

monitor.Progress = 100*(currentIteration/maxIterations);

5

Status プロパティを設定して学習のステータスを追跡します。

現在のステータスを "Running" に設定します。

monitor.Status = "Running";

6

早期停止を有効にします。[停止] ボタンをクリックすると、Stop プロパティが 1 (true) に変わります。Stop プロパティが 1 のとき、学習ループが終了すると学習が停止します。

早期停止を有効にするには、カスタム学習ループに次のコードを含めます。

while numEpochs < maxEpochs && ~monitor.Stop    
% Custom training loop code.   
end

7開始時間と経過時間を追跡します。TrainingProgressMonitor オブジェクトを作成すると時間の計測が開始されます。

モニターを作成し、タイマーを開始します。

monitor = trainingProgressMonitor;

8

Info プロパティを使用して、[学習の進行状況] ウィンドウに情報を表示します。情報値は [学習の進行状況] ウィンドウに表示されますが、プロットには表示されません。[学習の進行状況] ウィンドウに表示したいテキストと数値の情報を使用します。

学習前に情報名を追加します。

monitor.Info = ["Epoch","LearningRate"];

updateInfo 関数を使用して、[学習の進行状況] ウィンドウの情報値を更新し、InfoData プロパティに値を保存します。

学習中に情報値を更新します。

updateInfo(monitor, ...
Epoch=currentEpoch, ...
LearningRate=learnRate);

9yscale 関数を使用して y 軸のスケールを指定します。座標軸ツールバーの対数スケール ボタンをクリックしてスケールを変更することもできます。

損失プロットのスケールを対数に設定します。

yscale(monitor,"Loss","log")

学習中のカスタム学習ループの進行状況の監視

この例では、深層学習のカスタム学習ループの進行状況を監視する方法を示します。

学習データの読み込み

数字のサンプル データを解凍し、イメージ データストアを作成します。関数 imageDatastore は、フォルダー名に基づいてイメージに自動的にラベルを付けます。

unzip("DigitsData.zip")

imds = imageDatastore("DigitsData", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

データを学習セットと検証セットに分割します。

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.6,0.2,"randomize");

この例のネットワークでは、サイズが 28×28×1 の入力イメージが必要です。学習イメージのサイズを自動的に変更するには、拡張イメージ データストアを使用します。学習イメージに対して実行する追加の拡張演算を指定します。

inputSize = [28 28 1];
pixelRange = [-5 5];

imageAugmenter = imageDataAugmenter( ...
    RandXTranslation=pixelRange, ...
    RandYTranslation=pixelRange);

augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    DataAugmentation=imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

学習データ内のクラスの数を決定します。

classes = categories(imdsTrain.Labels);
numClasses = numel(classes);

ネットワークの定義

イメージ分類用のネットワークを定義します。dlnetwork オブジェクトを作成します。

net = dlnetwork;

分類分岐の層を指定し、それをネットワークに追加します。

layers = [
    imageInputLayer(inputSize,Normalization="none")
    convolution2dLayer(5,20,Padding="same")
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding="same")
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];

net = addLayers(net,layers);
net = initialize(net);

モデル損失関数の定義

modelLoss 関数は、dlnetwork オブジェクト net、入力データ X のミニバッチ、および対応するターゲット T を入力として受け取ります。この関数は、net 内の損失、学習可能なパラメーターに関する損失の勾配、およびネットワークの状態を返します。勾配を自動的に計算するには、関数 dlgradient を使用します。

function [loss,gradients,state] = modelLoss(net,X,T)

% Forward data through network.
[Y,state] = forward(net,X);

% Calculate cross-entropy loss.
loss = crossentropy(Y,T);

% Calculate gradients of loss with respect to learnable parameters.
gradients = dlgradient(loss,net.Learnables);

end

学習オプションの指定

ミニバッチ サイズを 128 として、ネットワークに 10 エポック学習させます。

numEpochs = 10;
miniBatchSize = 128;

モーメンタム項付き確率的勾配降下法 (SGDM) 最適化のオプションを指定します。初期学習率を 0.01、減衰を 0.01、モーメンタムを 0.9 に指定します。

initialLearnRate = 0.01;
decay = 0.01;
momentum = 0.9;

モデルの学習

学習中にイメージのミニバッチを処理および管理するminibatchqueueオブジェクトを作成します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatch (この例の最後に定義) を使用して、ラベルを one-hot 符号化変数に変換します。

  • イメージ データを次元ラベル "SSCB" (spatial、spatial、channel、batch) で形式を整えます。既定では、minibatchqueue オブジェクトは、基となるデータ型が singledlarray オブジェクトにデータを変換します。クラス ラベルの形式は整えません。

  • GPU が利用できる場合、GPU で学習を行います。既定では、minibatchqueue オブジェクトは、GPU が利用可能な場合、各出力を gpuArray オブジェクトに変換します。GPU を使用するには、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスの詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。

学習データを準備します。

mbq = minibatchqueue(augimdsTrain,...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@preprocessMiniBatch,...
    MiniBatchFormat=["SSCB" ""]);

SGDM ソルバーの速度パラメーターを初期化します。

velocity = [];

エポックごとの反復回数を計算します。

numObservationsTrain = numel(imdsTrain.Files);
numIterationsPerEpoch = ceil(numObservationsTrain/miniBatchSize);
numIterations = numEpochs*numIterationsPerEpoch;

学習の進行状況モニターの準備

学習の進行状況を追跡するには、TrainingProgressMonitor オブジェクトを作成します。学習中の学習損失と学習精度、および検証損失と検証精度を記録します。学習の進行状況モニターによって、オブジェクト作成後の経過時間が自動的に追跡されます。この経過時間を学習時間の代わりとして使用するには、学習ループの先頭に近いところで TrainingProgressMonitor オブジェクトを作成するようにしてください。

monitor = trainingProgressMonitor( ...
    Metrics=["TrainingLoss","ValidationLoss","TrainingAccuracy","ValidationAccuracy"]);

groupSubPlotを使用して、学習メトリクスと検証メトリクスを同じサブプロットにプロットします。

groupSubPlot(monitor,"Loss",["TrainingLoss","ValidationLoss"]);
groupSubPlot(monitor,"Accuracy",["TrainingAccuracy","ValidationAccuracy"]);

損失の対数 y 軸スケールを指定します。学習中にスケールを切り替えるには、座標軸ツールバーの対数スケール ボタンをクリックします。

yscale(monitor,"Loss","log")

学習率、エポック、反復、および実行環境の情報値を追跡します。

monitor.Info = ["LearningRate","Epoch","Iteration","ExecutionEnvironment"];

x 軸ラベルを Iteration に設定し、現在のステータスを Configuring に設定します。学習がまだ開始されていないことを示すには、Progress プロパティを 0 に設定します。

monitor.XLabel = "Iteration";
monitor.Status = "Configuring";
monitor.Progress = 0;

実行環境を選択し、updateInfoを使用して、この情報を学習の進行状況モニターに記録します。

executionEnvironment = "auto";

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    updateInfo(monitor,ExecutionEnvironment="GPU");
else
    updateInfo(monitor,ExecutionEnvironment="CPU");
end

カスタム学習ループの開始

カスタム学習ループを使用してネットワークに学習させます。各エポックについて、データをシャッフルしてデータのミニバッチをループで回します。各ミニバッチで次を行います。

  • dlfeval関数と modelLoss 関数を使用して、モデルの損失、勾配、および状態を評価します。ネットワークの状態を更新します。

  • 時間ベースの減衰学習率スケジュールの学習率を決定します。

  • sgdmupdate関数を使用してネットワーク パラメーターを更新します。

  • recordMetricsを使用し、学習損失と精度を記録してプロットします。

  • updateInfoを使用し、学習率、エポック、反復数を更新して表示します。

  • 進行率を更新します。

各エポックの最後に、testnet関数を使用して検証の精度と損失を計算します。recordMetricsを使用して、これらのメトリクスを記録し、プロットします。

学習セットと検証セットの両方の精度と損失をプロットすることは、学習の進行状況を監視し、ネットワークが過適合しているかどうかを確認するのに適した方法です。ただし、これらのメトリクスを計算してプロットすることで学習時間は長くなります。

epoch = 0;
iteration = 0;

monitor.Status = "Running";

% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

    % Shuffle data.
    shuffle(mbq);

    % Loop over mini-batches.
    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;

        % Read mini-batch of data.
        [X,T] = next(mbq);

        % Evaluate the model gradients, state, and loss using the dlfeval and
        % modelLoss functions. Update the network state.
        [loss,gradients,state] = dlfeval(@modelLoss,net,X,T);
        net.State = state;

        % Determine learning rate for time-based decay learning rate schedule.
        learnRate = initialLearnRate/(1 + decay*iteration);

        % Update the network parameters using the SGDM optimizer.
        [net,velocity] = sgdmupdate(net,gradients,velocity,learnRate,momentum);

        % Record training loss and accuracy.
        Tdecode = onehotdecode(T,classes,1);
        scores = predict(net,X);
        Y = onehotdecode(scores,classes,1);
        accuracyTrain = 100*mean(Tdecode == Y);

        recordMetrics(monitor,iteration, ...
            TrainingLoss=loss, ...
            TrainingAccuracy=accuracyTrain);

        % Update learning rate, epoch, and iteration information values.
        updateInfo(monitor, ...
            LearningRate=learnRate, ...
            Epoch=string(epoch) + " of " + string(numEpochs), ...
            Iteration=string(iteration) + " of " + string(numIterations));

        % Record validation loss and accuracy.
        if iteration == 1 || ~hasdata(mbq)
            metrics = testnet(net,augimdsValidation,["crossentropy" "accuracy"], ...
                MiniBatchSize=miniBatchSize);

            recordMetrics(monitor,iteration, ...
                ValidationLoss=metrics(1), ...
                ValidationAccuracy=metrics(2));
        end

        % Update progress percentage.
        monitor.Progress = 100*iteration/numIterations;
    end
end

学習ステータスを更新します。

if monitor.Stop == 1
    monitor.Status = "Training stopped";
else
    monitor.Status = "Training complete";
end

TrainingProgressMonitor オブジェクトは、experiments.Monitorオブジェクトと同じプロパティとメソッドをもちます。そのため、実験マネージャーのセットアップ スクリプトで使用できるように、プロット コードを簡単に適応させることができます。詳細については、カスタム学習実験用のプロット コードの準備を参照してください。

サポート関数

ミニ バッチ前処理関数

preprocessMiniBatch 関数は、次の手順を使用して予測子とラベルのミニバッチを前処理します。

  1. 関数 preprocessMiniBatchPredictors を使用してイメージを前処理します。

  2. 入力 cell 配列からラベル データを抽出し、エントリを 2 番目の次元に沿って categorical 配列に連結します。

  3. カテゴリカル ラベルを数値配列に one-hot 符号化します。最初の次元への符号化は、ネットワーク出力の形状と一致する符号化された配列を生成します。

function [X,T] = preprocessMiniBatch(dataX,dataT)

% Preprocess predictors.
X = preprocessMiniBatchPredictors(dataX);

% Extract label data from cell and concatenate.
T = cat(2,dataT{1:end});

% One-hot encode labels.
T = onehotencode(T,1);

end

ミニバッチ予測子前処理関数

関数 preprocessMiniBatchPredictors は、入力 cell 配列からイメージ データを抽出することで予測子のミニバッチを前処理し、数値配列に連結します。グレースケール入力では、4 番目の次元で連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されることになります。

function X = preprocessMiniBatchPredictors(dataX)

% Concatenate.
X = cat(4,dataX{1:end});

end

参考

| | | |

トピック