メインコンテンツ

深層学習における学習用のカスタム停止条件

R2023b 以降

この例では、trainnetを使用し、カスタム停止条件に基づいて深層学習ニューラル ネットワークの学習を停止する方法を示します。

trainingOptionsを使用してニューラル ネットワークの学習オプションを指定できます。検証データを使用して、検証損失の減少が止まったときに学習を自動的に停止させることができます。自動検証停止をオンにするには、ValidationPatience 学習オプションを使用します。

カスタム条件が満たされたときに学習を早期に停止するには、trainingOptions の名前と値のペアの引数 "OutputFcn" にカスタム関数ハンドルを渡します。trainnet は、学習の開始前、各学習の反復後、および学習の完了後にこの関数を 1 回ずつ呼び出します。出力関数が呼び出されるたびに、trainnet は現在の反復回数、損失、精度などの情報を含む構造体を渡します。カスタム出力関数が true を返すと、学習は停止します。

この例では、学習済みのネットワークは、数値センサーの読み取り値、統計、およびカテゴリカル ラベルの組み合わせに基づいて、トランスミッション システムの歯車の状態を "Tooth Fault""No Tooth Fault" の 2 つのカテゴリに分類します。詳細については、表形式データを使用したニューラル ネットワークの学習を参照してください。

この例で定義されたカスタム出力関数は、学習損失が目的の損失しきい値よりも低くなると、学習を早期に停止します。

学習データの読み込みと前処理

CSV ファイル "transmissionCasingData.csv" からトランスミッション ケーシング データを読み取ります。

filename = "transmissionCasingData.csv";
tbl = readtable(filename,TextType="String");

convertvars 関数を使用して、予測のラベルとカテゴリカル予測子を categorical に変換します。このデータ セットには、"SensorCondition""ShaftCondition" という 2 つのカテゴリカル特徴量があります。

labelName = "GearToothCondition";
categoricalPredictorNames = ["SensorCondition" "ShaftCondition"];
tbl = convertvars(tbl,[labelName categoricalPredictorNames],"categorical");

カテゴリカル特徴量を使用してネットワークに学習させるには、カテゴリカル特徴量を数値に変換しなければなりません。これは、onehotencode関数を使用して実行できます。

for i = 1:numel(categoricalPredictorNames)
    name = categoricalPredictorNames(i);
    tbl.(name) = onehotencode(tbl.(name),2);
end

テスト用のデータを確保します。データの 80% を含む学習セット、データの 10% を含む検証セット、およびデータの残りの 10% を含むテスト セットにデータを分割します。データを分割するには、この例にサポート ファイルとして添付されている関数 trainingPartitions を使用します。このファイルにアクセスするには、例をライブ スクリプトとして開きます。

numObservations = size(tbl,1);
[idxTrain,idxValidation,idxTest] = trainingPartitions(numObservations,[0.80 0.1 0.1]);

tblTrain = tbl(idxTrain,:);
tblValidation = tbl(idxValidation,:);
tblTest = tbl(idxTest,:);

関数 trainnet がサポートする形式にデータを変換します。table2array関数を使用して、予測変数とターゲットをそれぞれ数値配列と categorical 配列に変換します。

predictorNames = ["SigMean" "SigMedian" "SigRMS" "SigVar" "SigPeak" "SigPeak2Peak" ...
    "SigSkewness" "SigKurtosis" "SigCrestFactor" "SigMAD" "SigRangeCumSum" ...
    "SigCorrDimension" "SigApproxEntropy" "SigLyapExponent" "PeakFreq" ...
    "HighFreqPower" "EnvPower" "PeakSpecKurtosis" "SensorCondition" "ShaftCondition"];

XTrain = table2array(tblTrain(:,predictorNames));
TTrain = tblTrain.(labelName);

XValidation = table2array(tblValidation(:,predictorNames));
TValidation = tblValidation.(labelName);

XTest = table2array(tblTest(:,predictorNames));
TTest = tblTest.(labelName);

ネットワーク アーキテクチャ

ニューラル ネットワーク アーキテクチャを定義します。

  • 特徴入力用に、特徴量の数に応じた特徴入力層を指定します。Z スコア正規化を使用して入力を正規化します。

  • サイズが 16 の全結合層を指定し、その後に正規化層と ReLU 層を指定します。

  • 分類出力用に、クラス数と一致するサイズをもつ全結合層を指定し、その後にソフトマックス層を指定します。

numFeatures = size(XTrain,2);
hiddenSize = 16;
classNames = categories(tbl{:,labelName});
numClasses = numel(classNames);
 
layers = [
    featureInputLayer(numFeatures,Normalization="zscore")
    fullyConnectedLayer(hiddenSize)
    layerNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];

学習オプションの定義

このページの下部で定義されている関数 stopTraining を使用して、学習損失が目的の損失しきい値よりも小さい場合に学習を早期に停止します。この関数を trainnet に渡すには、trainingOptions の名前と値のペアの引数 "OutputFcn" を使用します。

学習オプションを指定します。

  • Adam ソルバーを使用して学習させます。

  • CPU を使用して学習させます。ネットワークとデータが小さいため、CPU の方がより適しています。

  • 検証データを使用して、5 回の反復ごとにネットワークを検証します。

  • エポックの最大数を 200 に設定します。

  • 学習の進行状況をプロットに表示します。

  • 詳細出力を非表示にします。

  • カスタム出力関数 stopTraining を含めます。

損失しきい値を定義します。

lossThreshold = 0.3;
options = trainingOptions("adam", ...
    ExecutionEnvironment="cpu", ...
    ValidationData={XValidation,TValidation}, ...
    ValidationFrequency=5, ...
    MaxEpochs=200, ...
    Plots="training-progress", ...
    Verbose=false, ...
    OutputFcn=@(info)stopTraining(info,lossThreshold));

ニューラル ネットワークの学習

ネットワークに学習をさせます。

[net,info] = trainnet(XTrain,TTrain,layers,"crossentropy",options);

ネットワークのテスト

学習済みネットワークを使用してテスト データのラベルを予測します。学習済みネットワークを使用して分類スコアを予測し、onehotdecode関数を使用して予測結果をラベルに変換します。

scoresTest = predict(net,XTest);
YTest = onehotdecode(scoresTest,classNames,2);
accuracy = mean(YTest==TTest)
accuracy = 
0.8636

カスタム出力関数

学習損失が損失しきい値より小さい場合に学習を停止する出力関数 stopTraining(info,lossThreshold) を定義します。出力関数が true を返すと、学習は停止します。

function stop = stopTraining(info,lossThreshold)
trainingLoss = info.TrainingLoss;
stop = trainingLoss < lossThreshold;
end

参考

| | | |

トピック