Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

カスタム学習ループでの学習オプションの指定

ほとんどのタスクでは、関数 trainingOptionstrainnet、および trainNetwork を使用して学習アルゴリズムの詳細を制御できます。タスク (たとえば、カスタム学習率スケジュール) に必要なオプションが関数 trainingOptions に用意されていない場合、dlnetwork オブジェクトを使用して自分でカスタム学習ループを定義できます。dlnetwork オブジェクトを使用すると、自動微分を使用して、層グラフとして指定したネットワークに学習させることができます。

trainingOptions と同じオプションを指定するには、次の例を指針として使用します。

ソルバー オプション

ソルバーを指定するには、学習ループの更新ステップに関数 adamupdatermspropupdate、および sgdmupdate を使用します。独自のカスタム ソルバーを実装するには、関数 dlupdate を使用して学習可能なパラメーターを更新します。

適応モーメント推定 (Adam)

Adam を使用してネットワーク パラメーターを更新するには、関数 adamupdate を使用します。対応する入力引数を使用して、勾配の減衰率と勾配の二乗の減衰係数を指定します。

平方根平均二乗伝播 (RMSProp)

RMSProp を使用してネットワーク パラメーターを更新するには、関数 rmspropupdate を使用します。対応する入力引数を使用して、分母のオフセット (イプシロン) の値を指定します。

モーメンタム項付き確率的勾配降下法 (SGDM)

SGDM を使用してネットワーク パラメーターを更新するには、関数 sgdmupdate を使用します。対応する入力引数を使用して、モーメンタム項を指定します。

学習率

学習率を指定するには、関数 adamupdatermspropupdate、および sgdmupdate の学習率入力引数を使用します。

学習率の調整、またはカスタム学習率スケジュールの使用を簡単に行うには、カスタム学習ループの前に初期学習率を設定します。

learnRate = 0.01;

区分的な学習率スケジュール

区分的な学習率スケジュールを使用して学習時に学習率を自動的に下げるには、指定された間隔の後に学習率を特定の低下係数で乗算します。

区分的な学習率スケジュールを簡単に指定するには、変数 learnRatelearnRateSchedulelearnRateDropFactor、および learnRateDropPeriod を作成します。ここで、learnRate は初期学習率であり、learnRateSchedule には "piecewise" または "none" が含まれ、learnRateDropFactor は学習率の低下の係数を指定する範囲 [0, 1] のスカラーであり、learnRateDropPeriod は学習率を低下させる間隔のエポック数を指定する正の整数です。

learnRate = 0.01;
learnRateSchedule = "piecewise"
learnRateDropPeriod = 10;
learnRateDropFactor = 0.1;

learnRateSchedule オプションが "piecewise" で現在のエポック数が learnRateDropPeriod の倍数である場合、学習ループ内で、各エポックの最後に学習率を下げます。新しい学習率を、学習率と学習率の低下係数の積に設定します。

if learnRateSchedule == "piecewise" && mod(epoch,learnRateDropPeriod) == 0
    learnRate = learnRate * learnRateDropFactor;
end

プロット

学習時に学習損失と学習精度をプロットするには、モデル損失関数で、ミニバッチの損失、および精度と平方根平均二乗誤差 (RMSE) のいずれかを計算し、TrainingProgressMonitor オブジェクトを使用してこれらをプロットします。

TrainingProgressMonitor オブジェクトの Visible プロパティを設定すると、プロットをオンにするかオフにするかを簡単に指定できます。既定では、Visibletrue に設定されています。Visiblefalse に設定すると、学習メトリクスと学習情報が記録されますが、[学習の進行状況] ウィンドウは表示されません。Visible プロパティを変更することで、学習後に [学習の進行状況] ウィンドウを表示させることができます。検証メトリクスもプロットするには、検証の説明と同じオプション validationData および validationFrequency を使用します。

validationData = {XValidation, YValidation};
validationFrequency = 50;

学習を行う前に、TrainingProgressMonitor オブジェクトを初期化します。モニターによって、オブジェクト作成後の経過時間が自動的に追跡されます。この経過時間を学習時間の代わりとして使用するには、学習ループの先頭に近いところで TrainingProgressMonitor オブジェクトを作成するようにしてください。

分類タスクの場合、学習データと検証データの損失と精度を追跡するためのプロットを作成します。また、エポック数および学習の進行率を追跡します。

monitor = trainingProgressMonitor;

monitor.Metrics = ["TrainingAccuracy","ValidationAccuracy","TrainingLoss","ValidationLoss"];

groupSubPlot(monitor,"Accuracy",["TrainingAccuracy","ValidationAccuracy"]);
groupSubPlot(monitor,"Loss",["TrainingLoss","ValidationLoss"]);

monitor.Info = "Epoch";

monitor.XLabel = "Iteration";
monitor.Progress = 0;

回帰タスクの場合、変数名と変数ラベルを変更することによってコードを調整し、学習精度と検証精度ではなく学習 RMSE と検証 RMSE のプロットが初期化されるようにします。

学習ループ内で、反復の最後に関数 recordMetrics および updateInfo を使用して、学習ループに関する適切なメトリクスと情報が含まれるようにします。分類タスクの場合、ミニバッチの精度とミニバッチの損失に対応する点を追加します。現在の反復が 1 または検証頻度オプションの倍数の場合、検証データの点も追加します。

recordMetrics(monitor,iteration, ...
    TrainingLoss=lossTrain, ...
    TrainingAccuracy=accuracyTrain);

updateInfo(monitor,Epoch=string(epoch) + " of " + string(numEpochs));

if iteration == 1 || mod(iteration,validationFrequency) == 0
    recordMetrics(monitor,iteration, ...
        ValidationLoss=lossValidation, ...
        ValidationAccuracy=accuracyValidation);
end
monitor.Progress = 100*iteration/numIterations;
ここで、accuracyTrain および lossTrain は、モデル損失関数で計算されたミニバッチの精度とミニバッチの損失に対応します。回帰タスクの場合、ミニバッチの精度ではなくミニバッチの RMSE 損失を使用します。

学習を停止するには、[学習の進行状況] ウィンドウの [停止] ボタンを使用します。[停止] ボタンをクリックすると、モニターの Stop プロパティが 1 (true) に変わります。Stop プロパティが 1 のとき、学習ループが終了すると学習が停止します。

while numEpochs < maxEpochs && ~monitor.Stop    
% Custom training loop code.   
end

学習時におけるメトリクスのプロットと記録の詳細については、Monitor Custom Training Loop Progress During Training を参照してください。

検証メトリクスを計算する方法については、検証を参照してください。

詳細出力

学習時に学習損失と学習精度を詳細テーブルに表示するには、モデル損失関数で、ミニバッチの損失と精度 (分類タスクの場合) または RMSE (回帰タスクの場合) のいずれかを計算し、関数 disp を使用して表示します。

詳細テーブルをオンにするかオフにするかを簡単に指定するには、変数 verbose および verboseFrequency を作成します。ここで、verbosetrue または false になり、verbosefrequency は詳細出力間の反復回数を指定します。検証メトリクスを表示するには、検証の説明と同じオプション validationData および validationFrequency を使用します。

verbose = true
verboseFrequency = 50;

validationData = {XValidation, YValidation};
validationFrequency = 50;

学習前に、詳細出力テーブルの見出しを表示し、関数 tic を使用してタイマーを初期化します。

disp("|======================================================================================================================|")
disp("|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |")
disp("|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |")
disp("|======================================================================================================================|")

start = tic;

回帰タスクの場合、コードを調整し、学習精度と検証精度ではなく学習 RMSE と検証 RMSE が表示されるようにします。

verbose オプションが true で、それが最初の反復であるか、または反復回数が verboseFrequency の倍数である場合、学習ループ内で、反復の最後に詳細出力を出力します。

if verbose && (iteration == 1 || mod(iteration,verboseFrequency) == 0
    D = duration(0,0,toc(start),'Format','hh:mm:ss');

    if isempty(validationData) || mod(iteration,validationFrequency) ~= 0 
        accuracyValidation = "";
        lossValidation = "";
    end

    disp("| " + ...
        pad(epoch,7,'left') + " | " + ...
        pad(iteration,11,'left') + " | " + ...
        pad(D,14,'left') + " | " + ...
        pad(accuracyTrain,12,'left') + " | " + ...
        pad(accuracyValidation,12,'left') + " | " + ...
        pad(lossTrain,12,'left') + " | " + ...
        pad(lossValidation,12,'left') + " | " + ...
        pad(learnRate,15,'left') + " |")
end

回帰タスクの場合、コードを調整し、学習精度と検証精度ではなく学習 RMSE と検証 RMSE が表示されるようにします。

学習の終了時に、詳細テーブルの最後の罫線を出力します。

disp("|======================================================================================================================|")

検証メトリクスを計算する方法については、検証を参照してください。

ミニバッチのサイズ

ミニバッチのサイズの設定は、データの形式や使用するデータストアのタイプによって異なります。

ミニバッチのサイズを簡単に指定するには、変数 miniBatchSize を作成します。

miniBatchSize = 128;

イメージ データストア内のデータの場合、データストアの ReadSize プロパティをミニバッチのサイズに設定します。

imds.ReadSize = miniBatchSize;

拡張イメージ データストア内のデータの場合、学習前に、データストアの MiniBatchSize プロパティをミニバッチのサイズに設定します。

augimds.MiniBatchSize = miniBatchSize;

メモリ内のデータの場合、学習中の各反復の開始時に、配列から観測値を直接読み取ります。

idx = ((iteration - 1)*miniBatchSize + 1):(iteration*miniBatchSize);
X = XTrain(:,:,:,idx);

エポック数

学習ループの外側のループで学習を行う最大エポック数を指定します。

最大エポック数を簡単に指定するには、最大エポック数を含む変数 maxEpochs を作成します。

maxEpochs = 30;

学習ループの外側のループで、範囲 1、2、…、maxEpochs をループするように指定します。

epoch = 0;
while numEpochs < maxEpochs
    epoch = epoch + 1;
    ...
end

検証

学習時にネットワークを検証するには、ホールドアウトされた検証セットを残しておき、そのデータに対するネットワークのパフォーマンスを評価します。

検証オプションを簡単に指定するには、変数 validationData および validationFrequency を作成します。ここで、validationData には検証データが含まれるか空であり、validationFrequency はネットワークの検証間の反復回数を指定します。

validationData = {XValidation,TValidation};
validationFrequency = 50;

学習ループ中、ネットワーク パラメーターを更新した後に、関数 predict を使用してホールドアウトされた検証セットでネットワークのパフォーマンスをテストします。検証データが指定されており、それが最初の反復であるか、または現在の反復が validationFrequency オプションの倍数である場合、ネットワークを検証します。

if iteration == 1 || mod(iteration,validationFrequency) == 0
    YValidation = predict(net,XValidation);
    lossValidation = crossentropy(YValidation,TValidation);

    [~,idx] = max(YValidation);
    labelsPredValidation = classNames(idx);

    accuracyValidation = mean(labelsPredValidation == labelsValidation);
end
ここで、TValidationclassNames のラベルの one-hot 符号化配列です。精度を計算するには、TValidation をラベルの配列に変換します。

回帰タスクの場合、コードを調整し、検証精度ではなく検証 RMSE が計算されるようにします。

学習中に検証メトリクスを計算してプロットする方法を示す例については、Monitor Custom Training Loop Progress During Trainingを参照してください。

早期停止

ホールドアウトされた検証の損失が減少しなくなったときに学習を早期に停止するには、学習ループを脱出するフラグを使用します。

検証の許容回数 (ネットワークの学習が停止するまでに、検証損失が以前の最小損失以上になることが許容される回数) を簡単に指定するには、変数 validationPatience を作成します。

validationPatience = 5;

学習前に、変数 earlyStop および validationLosses を初期化します。ここで、earlyStop は学習を早期に停止するフラグで、validationLosses には比較する損失が含まれます。false で早期停止フラグを初期化し、inf で検証損失の配列を初期化します。

earlyStop = false;
if isfinite(validationPatience)
    validationLosses = inf(1,validationPatience);
end

学習ループ内のミニバッチに対するループで、ループ条件に earlyStop フラグを追加します。

while hasdata(ds) && ~earlyStop
    ...
end

検証ステップ中に、配列 validationLosses に新しい検証を追加します。配列の最初の要素が最小である場合、earlyStop フラグを true に設定します。そうでない場合、最初の要素を削除します。

if isfinite(validationPatience)
    validationLosses = [validationLosses validationLoss];
    if min(validationLosses) == validationLosses(1)
        earlyStop = true;
    else
        validationLosses(1) = [];
    end
end

L2 正則化

重みに L2 正則化を適用するには、関数 dlupdate を使用します。

L2 正則化係数を簡単に指定するには、L2 正則化係数が含まれる変数 l2Regularization を作成します。

l2Regularization = 0.0001;

学習中、モデルの損失と勾配を計算した後に、各重みパラメーターについて、関数 dlupdate を使用して L2 正則化係数と重みの積を勾配の計算値に追加します。重みパラメーターのみを更新するには、"Weights" という名前のパラメーターを抽出します。

idx = net.Learnables.Parameter == "Weights";
gradients(idx,:) = dlupdate(@(g,w) g + l2Regularization*w, gradients(idx,:), net.Learnables(idx,:));

L2 正則化パラメーターを勾配に追加した後、ネットワーク パラメーターを更新します。

勾配クリップ

勾配をクリップするには、関数 dlupdate を使用します。

勾配クリップ オプションを簡単に指定するには、変数 gradientThresholdMethod および gradientThreshold を作成します。ここで、gradientThresholdMethod には "global-l2norm""l2norm"、または "absolute-value" が含まれ、gradientThreshold はしきい値または inf が含まれる正のスカラーです。

gradientThresholdMethod = "global-l2norm";
gradientThreshold = 2;

thresholdGlobalL2NormthresholdL2Norm、および thresholdAbsoluteValue という名前の関数を作成します。これらはそれぞれ "global-l2norm""l2norm"、および "absolute-value" しきい値法を適用します。

"global-l2norm" オプションの場合、この関数はモデルのすべての勾配に対して作用します。

function gradients = thresholdGlobalL2Norm(gradients,gradientThreshold)

globalL2Norm = 0;
for i = 1:numel(gradients)
    globalL2Norm = globalL2Norm + sum(gradients{i}(:).^2);
end
globalL2Norm = sqrt(globalL2Norm);

if globalL2Norm > gradientThreshold
    normScale = gradientThreshold / globalL2Norm;
    for i = 1:numel(gradients)
        gradients{i} = gradients{i} * normScale;
    end
end

end

"l2norm" および "absolute-value" オプションの場合、この関数は各勾配に対して個別に作用します。

function gradients = thresholdL2Norm(gradients,gradientThreshold)

gradientNorm = sqrt(sum(gradients(:).^2));
if gradientNorm > gradientThreshold
    gradients = gradients * (gradientThreshold / gradientNorm);
end

end
function gradients = thresholdAbsoluteValue(gradients,gradientThreshold)

gradients(gradients > gradientThreshold) = gradientThreshold;
gradients(gradients < -gradientThreshold) = -gradientThreshold;

end

学習中、モデルの損失と勾配を計算した後に、関数 dlupdate を使用して適切な勾配クリップ法を勾配に適用します。"global-l2norm" オプションにはすべての勾配値が必要であるため、関数 thresholdGlobalL2Norm を勾配に直接適用します。"l2norm" および "absolute-value" オプションの場合、関数 dlupdate を使用して勾配を個別に更新します。

switch gradientThresholdMethod
    case "global-l2norm"
        gradients = thresholdGlobalL2Norm(gradients, gradientThreshold);
    case "l2norm"
        gradients = dlupdate(@(g) thresholdL2Norm(g, gradientThreshold),gradients);
    case "absolute-value"
        gradients = dlupdate(@(g) thresholdAbsoluteValue(g, gradientThreshold),gradients);
end

勾配しきい値演算を適用した後、ネットワーク パラメーターを更新します。

単一の CPU または GPU での学習

既定では、ソフトウェアは CPU のみを使用して計算を実行します。単一の GPU で学習を行うには、データを gpuArray オブジェクトに変換します。GPU を使用するには Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。

実行環境を簡単に指定するには、"cpu""gpu"、または "auto" のいずれかが含まれる変数 executionEnvironment を作成します。

executionEnvironment = "auto"

学習中、ミニバッチを読み取った後に実行環境オプションを確認し、必要に応じてデータを gpuArray に変換します。関数 canUseGPU は使用可能な GPU を確認します。

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    X = gpuArray(X);
end

チェックポイント

学習中にチェックポイント ネットワークを保存するには、関数 save を使用してネットワークを保存します。

チェックポイントをオンにするかどうかを簡単に指定するには、チェックポイント ネットワークのフォルダーが含まれる、または空の変数 checkpointPath を作成します。

checkpointPath = fullfile(tempdir,"checkpoints");

チェックポイント フォルダーが存在していない場合、学習前に、チェックポイント フォルダーを作成します。

if ~exist(checkpointPath,"dir")
    mkdir(checkpointPath)
end

学習中、各エポックの最後にネットワークを MAT ファイルに保存できます。。現在の反復回数、日付、および時刻が含まれるファイル名を指定します。

if ~isempty(checkpointPath)
    D = string(datetime("now",Format="yyyy_MM_dd__HH_mm_ss"));
    filename = "net_checkpoint__" + iteration + "__" + D + ".mat";
    save(filename,"net")
end
ここで、net は保存される dlnetwork オブジェクトです。

参考

| | | | | | |

関連するトピック