Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

深層学習における学習の進行状況の監視

この例では、深層学習ネットワークの学習プロセスを監視する方法を示します。

深層学習のネットワークに学習させる場合、学習の進行状況を監視すると役に立つことがよくあります。学習中のさまざまなメトリクスのプロットにより、学習の進行状況を知ることができます。たとえば、ネットワークの精度が改善されているかどうか、その改善の速度、さらにネットワークで学習データへの過適合が始まっているかどうかを判定できます。

trainingOptionsPlots 学習オプションを "training-progress" に設定してネットワークの学習を開始すると、trainNetwork によって Figure が作成され、反復ごとに学習メトリクスが表示されます。各反復は、勾配の推定と、ネットワーク パラメーターの更新で構成されます。trainingOptions に検証データを指定すると、trainNetwork によってネットワークが検証されるたびに Figure に検証メトリクスが表示されます。Figure には次がプロットされます。

  • 学習精度 — 個々のミニバッチの分類精度。

  • 平滑化後の学習精度 — 学習精度に平滑化アルゴリズムを適用することによって求められる、平滑化された学習精度。精度を平滑化していない場合よりノイズが少なく、トレンドを見つけやすくなります。

  • 検証精度 — (trainingOptions を使用して指定された) 検証セット全体に対する分類精度。

  • 学習損失平滑化後の学習損失検証損失 それぞれ、各ミニバッチの損失、その平滑化バージョン、検証セットの損失。ネットワークの最後の層が classificationLayer である場合、損失関数は交差エントロピー損失です。分類問題と回帰問題の損失関数の詳細は、出力層を参照してください。

回帰ネットワークの場合、Figure には精度ではなく平方根平均二乗誤差 (RMSE) がプロットされます。

Figure では、影付きの背景を使用して各学習エポックがマークされます。1 エポックは、データセット全体を一巡することです。

学習中、右上隅の停止ボタンをクリックして学習を停止し、ネットワークの現在の状態を返すことができます。たとえば、ネットワークの精度が横ばい状態に達し、これ以上改善されないことが明確な場合に、学習の停止が必要になることがあります。停止ボタンのクリックの後、学習が完了するまでしばらくかかることがあります。学習が完了すると、trainNetwork が学習済みネットワークを返します。

学習が終了すると、[結果] に最終検証精度と学習の終了理由が表示されます。OutputNetwork 学習オプションが "last-iteration" (既定値) である場合、最終メトリクスは最後の学習反復に対応します。OutputNetwork 学習オプションが "best-validation-loss" である場合、最終メトリクスは検証損失が最小の反復に対応します。最終検証メトリクスが計算される反復には、プロットで "Final" とラベル付けされます。

ネットワークにバッチ正規化層が含まれている場合、最終検証メトリクスは学習中に評価される検証メトリクスと異なる可能性があります。これは、バッチ正規化に使用される平均と分散の値が、学習完了後に変わる可能性があるためです。たとえば、BatchNormalizationStatisics 学習オプションが "population" である場合、学習後に学習データが再度渡され、その結果得られる平均と分散を使用して最終的なバッチ正規化の統計量が決定されます。BatchNormalizationStatisics 学習オプションが "moving" である場合、学習中に実行時推定を使用して統計量が近似され、最新の統計値が使用されます。

右側には、学習の時間と設定に関する情報が表示されます。学習オプションの詳細は、パラメーターの設定と畳み込みニューラル ネットワークの学習を参照してください。

学習の進行状況のプロットを保存するには、学習ウィンドウの [学習プロットのエクスポート] をクリックします。プロットは PNG、JPEG、TIFF、または PDF ファイルとして保存できます。座標軸ツール バーを使用して、損失、精度、および平方根平均二乗誤差の個々のプロットを保存することもできます。

学習時の進行状況のプロット

ネットワークに学習させ、学習中にその進行状況をプロットします。

5000 個の数字のイメージが格納されている学習データを読み込みます。ネットワークの検証用に 1000 個のイメージを残しておきます。

[XTrain,YTrain] = digitTrain4DArrayData;

idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];

数字のイメージ データを分類するネットワークを構築します。

layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3,8,Padding="same")
    batchNormalizationLayer
    reluLayer   
    maxPooling2dLayer(2,Stride=2)
    convolution2dLayer(3,16,Padding="same")
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,Stride=2)
    convolution2dLayer(3,32,Padding="same")
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

ネットワーク学習のオプションを指定します。学習中に一定の間隔でネットワークを検証するための検証データを指定します。エポックごとに約 1 回ネットワークが検証されるように、ValidationFrequency の値を選択します。学習中に学習の進行状況をプロットするには、Plots 学習オプションを "training-progress" に設定します。

options = trainingOptions("sgdm", ...
    MaxEpochs=8, ...
    ValidationData={XValidation,YValidation}, ...
    ValidationFrequency=30, ...
    Verbose=false, ...
    Plots="training-progress");

ネットワークに学習をさせます。

net = trainNetwork(XTrain,YTrain,layers,options);

Figure Training Progress (12-Apr-2022 00:56:24) contains 2 axes objects and another object of type uigridlayout. Axes object 1 contains 15 objects of type patch, text, line. Axes object 2 contains 15 objects of type patch, text, line.

参考

|

関連するトピック