Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

深層学習ネットワークの学習時の出力のカスタマイズ

この例では、深層学習ニューラル ネットワークの学習中に各反復で実行される出力関数を定義する方法を説明します。trainingOptions の名前と値のペアの引数 'OutputFcn' を使用して出力関数を指定する場合、trainNetwork は、学習の開始前に 1 回、学習の各反復後、学習の終了後に 1 回、これらの関数を呼び出します。出力関数が呼び出されるたびに、trainNetwork は、現在の反復回数、損失、精度などの情報を含む構造体を渡します。出力関数を使用して、進行状況を表示またはプロットするか、学習を停止できます。学習を早期に停止するには、出力関数が true を返すようにします。いずれかの出力関数から true が返されると、学習が終了し、 trainNetwork から最新のネットワークが返されます。

検証セットに対する損失が減少しなくなった場合に学習を停止するには、trainingOptions の名前と値のペアの引数 'ValidationData' および 'ValidationPatience' を使用して、検証データと検証の許容回数をそれぞれ指定します。検証の許容回数は、ネットワークの学習が停止するまでに、検証セットに対する損失が前の最小損失以上になることが許容される回数です。出力関数を使用して他の停止条件を追加できます。この例では、検証データに対する分類精度が改善されなくなった場合に学習を停止する出力関数を作成する方法を示します。出力関数の定義は、このスクリプトの終わりで行います。

5000 個の数字のイメージが格納されている学習データを読み込みます。ネットワークの検証用に 1000 個のイメージを残しておきます。

[XTrain,YTrain] = digitTrain4DArrayData;

idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];

数字のイメージ データを分類するネットワークを構築します。

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

ネットワーク学習のオプションを指定します。学習中に一定の間隔でネットワークを検証するための検証データを指定します。エポックごとに 1 回ネットワークが検証されるように、'ValidationFrequency' の値を選択します。

検証セットに対する分類精度が改善されなくなった場合に学習を停止するために、出力関数として stopIfAccuracyNotImproving を指定します。stopIfAccuracyNotImproving の 2 番目の入力引数は、ネットワークの学習が停止するまでに、検証セットに対する精度が前の最大精度以下になることが許容される回数です。学習を行うエポックの最大回数に、任意の大きな値を選択します。学習は自動的に停止するため、学習は最終エポックに達しません。

miniBatchSize = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    'VerboseFrequency',validationFrequency, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'OutputFcn',@(info)stopIfAccuracyNotImproving(info,3));

ネットワークに学習をさせます。検証精度が向上しなくなると、学習は停止します。

net = trainNetwork(XTrain,YTrain,layers,options);
Training on single CPU.
Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:03 |        7.81% |       12.70% |       2.7155 |       2.5169 |          0.0100 |
|       1 |          31 |       00:00:06 |       71.09% |       74.70% |       0.8805 |       0.8120 |          0.0100 |
|       2 |          62 |       00:00:08 |       87.50% |       87.90% |       0.3866 |       0.4448 |          0.0100 |
|       3 |          93 |       00:00:11 |       94.53% |       94.30% |       0.2178 |       0.2529 |          0.0100 |
|       4 |         124 |       00:00:13 |       96.09% |       96.60% |       0.1433 |       0.1759 |          0.0100 |
|       5 |         155 |       00:00:15 |      100.00% |       97.40% |       0.0994 |       0.1306 |          0.0100 |
|       6 |         186 |       00:00:18 |       99.22% |       97.90% |       0.0786 |       0.1126 |          0.0100 |
|       7 |         217 |       00:00:20 |       99.22% |       98.20% |       0.0552 |       0.0938 |          0.0100 |
|       8 |         248 |       00:00:23 |      100.00% |       97.60% |       0.0429 |       0.0871 |          0.0100 |
|       9 |         279 |       00:00:26 |      100.00% |       98.00% |       0.0338 |       0.0777 |          0.0100 |
|      10 |         310 |       00:00:28 |      100.00% |       98.50% |       0.0271 |       0.0681 |          0.0100 |
|      11 |         341 |       00:00:31 |      100.00% |       98.20% |       0.0237 |       0.0623 |          0.0100 |
|      12 |         372 |       00:00:33 |      100.00% |       98.60% |       0.0212 |       0.0570 |          0.0100 |
|      13 |         403 |       00:00:36 |      100.00% |       98.70% |       0.0186 |       0.0533 |          0.0100 |
|      14 |         434 |       00:00:38 |      100.00% |       98.70% |       0.0163 |       0.0507 |          0.0100 |
|      15 |         465 |       00:00:41 |      100.00% |       98.80% |       0.0143 |       0.0483 |          0.0100 |
|      16 |         496 |       00:00:43 |      100.00% |       99.00% |       0.0127 |       0.0457 |          0.0100 |
|      17 |         527 |       00:00:46 |      100.00% |       98.90% |       0.0113 |       0.0435 |          0.0100 |
|      18 |         558 |       00:00:48 |      100.00% |       99.00% |       0.0102 |       0.0416 |          0.0100 |
|      19 |         589 |       00:00:51 |      100.00% |       99.10% |       0.0093 |       0.0400 |          0.0100 |
|      20 |         620 |       00:00:53 |      100.00% |       99.10% |       0.0086 |       0.0387 |          0.0100 |
|      21 |         651 |       00:00:56 |      100.00% |       99.20% |       0.0081 |       0.0375 |          0.0100 |
|      22 |         682 |       00:00:58 |      100.00% |       99.10% |       0.0076 |       0.0364 |          0.0100 |
|      23 |         713 |       00:01:01 |      100.00% |       99.10% |       0.0073 |       0.0355 |          0.0100 |
|      24 |         744 |       00:01:03 |      100.00% |       99.10% |       0.0069 |       0.0346 |          0.0100 |
|======================================================================================================================|

出力関数の定義

連続 N 回のネットワーク検証で検証データに対する最良分類精度が改善されなかった場合、ネットワークの学習を停止する出力関数 stopIfAccuracyNotImproving(info,N) を定義します。この条件は、検証損失を使用する組み込み停止条件と似ていますが、ここでは損失ではなく分類精度に適用されます。

function stop = stopIfAccuracyNotImproving(info,N)

stop = false;

% Keep track of the best validation accuracy and the number of validations for which
% there has not been an improvement of the accuracy.
persistent bestValAccuracy
persistent valLag

% Clear the variables when training starts.
if info.State == "start"
    bestValAccuracy = 0;
    valLag = 0;
    
elseif ~isempty(info.ValidationLoss)
    
    % Compare the current validation accuracy to the best accuracy so far,
    % and either set the best accuracy to the current accuracy, or increase
    % the number of validations for which there has not been an improvement.
    if info.ValidationAccuracy > bestValAccuracy
        valLag = 0;
        bestValAccuracy = info.ValidationAccuracy;
    else
        valLag = valLag + 1;
    end
    
    % If the validation lag is at least N, that is, the validation accuracy
    % has not improved for at least N validations, then return true and
    % stop training.
    if valLag >= N
        stop = true;
    end
    
end

end

参考

|

関連するトピック