最新のリリースでは、このページがまだ翻訳されていません。 このページの最新版は英語でご覧になれます。

カスタム学習ループでの学習オプションの指定

ほとんどのタスクでは、関数 trainingOptions および trainNetwork を使用して学習アルゴリズムの詳細を制御できます。タスク (たとえば、カスタム学習率スケジュール) に必要なオプションが関数 trainingOptions に用意されていない場合、自動微分を使用して独自のカスタム学習ループを定義できます。

trainingOptions と同じオプションを指定するには、次の例を指針として使用します。

ソルバー オプション

ソルバーを指定するには、学習ループの更新ステップに関数 adamupdatermspropupdate、および sgdmupdate を使用します。独自のカスタム ソルバーを実装するには、関数 dlupdate を使用して学習可能なパラメーターを更新します。

適応モーメント推定 (Adam)

Adam を使用してネットワーク パラメーターを更新するには、関数 adamupdate を使用します。対応する入力引数を使用して、勾配の減衰率と勾配の二乗の減衰係数を指定します。

平方根平均二乗伝播 (RMSProp)

RMSProp を使用してネットワーク パラメーターを更新するには、関数 rmspropupdate を使用します。対応する入力引数を使用して、分母のオフセット (イプシロン) の値を指定します。

モーメンタム項付き確率的勾配降下法 (SGDM)

SGDM を使用してネットワーク パラメーターを更新するには、関数 sgdmupdate を使用します。対応する入力引数を使用して、モーメンタム項を指定します。

学習率

学習率を指定するには、関数 adamupdatermspropupdate、および sgdmupdate の学習率入力引数を使用します。

学習率の調整、またはカスタム学習率スケジュールの使用を簡単に行うには、カスタム学習ループの前に初期学習率を設定します。

learnRate = 0.01;

区分的な学習率スケジュール

区分的な学習率スケジュールを使用して学習時に学習率を自動的に下げるには、指定された間隔の後に学習率を特定の低下係数で乗算します。

区分的な学習率スケジュールを簡単に指定するには、変数 learnRatelearnRateSchedulelearnRateDropFactor、および learnRateDropPeriod を作成します。ここで、learnRate は初期学習率であり、learnRateScedule には "piecewise" または "none" が含まれ、learnRateDropFactor は学習率の低下の係数を指定する範囲 [0, 1] のスカラーであり、learnRateDropPeriod は学習率を低下させる間隔のエポック数を指定する正の整数です。

learnRate = 0.01;
learnRateSchedule = "piecewise"
learnRateDropPeriod = 10;
learnRateDropFactor = 0.1;

learnRateSchedule オプションが "piecewise" で現在のエポック数が learnRateDropPeriod の倍数である場合、学習ループ内で、各エポックの最後に学習率を下げます。新しい学習率を、学習率と学習率の低下係数の積に設定します。

if learnRateSchedule == "piecewise" && mod(epoch,learnRateDropPeriod) == 0
    learnRate = learnRate * learnRateDropFactor;
end

プロット

学習時に学習損失と学習精度をプロットするには、モデル勾配関数で、ミニバッチの損失と精度または平方根平均二乗誤差 (RMSE) のいずれかを計算し、アニメーションの線を使用してこれらをプロットします。

プロットをオンにするかオフにするかを簡単に指定するには、"training-progress" または "none" のいずれかを含む変数 plots を作成します。検証メトリクスもプロットするには、検証の説明と同じオプション validationData および validationFrequency を使用します。

plots = "training-progress";

validationData = {XValidation, YValidation};
validationFrequency = 50;

学習前に、関数 animatedline を使用してアニメーションの線を初期化します。分類タスクの場合、学習精度と学習損失のプロットを作成します。また、検証データが指定されている場合には検証メトリクスのアニメーションの線を初期化します。

if plots == "training-progress"
    figure
    subplot(2,1,1)
    lineAccuracyTrain = animatedline;
    ylabel("Accuracy")
	
    subplot(2,1,2)
    lineLossTrain = animatedline;
    xlabel("Iteration")
    ylabel("Loss")

    if ~isempty(validationData)
        subplot(1,2,1)
        lineAccuracyValidation = animatedline;

        subplot(1,2,2)
        lineLossValidation = animatedline;
    end
end

回帰タスクの場合、変数名と変数ラベルを変更することによってコードを調整し、学習精度と検証精度ではなく学習 RMSE と検証 RMSE のプロットが初期化されるようにします。

学習ループ内で、反復の最後にプロットを更新してネットワークに適したメトリクスが含まれるようにします。分類タスクの場合、ミニバッチの精度とミニバッチの損失に対応する点を追加します。検証データが空ではなく、現在の反復が 1、または検証頻度オプションの倍数のいずれかの場合、検証データの点も追加します。

if plots == "training-progress"
    addpoints(lineAccuracyTrain,iteration,accuracyTrain)
    addpoints(lineLossTrain,iteration,lossTrain)

    if ~isempty(validationData) && (iteration == 1 || mod(iteration,validationFrequency) == 0)
        addpoints(lineAccuracyValidation,iteration,accuracyValidation)
        addpoints(lineLossValidation,iteration,lossValidation)
    end
end
ここで、accuracyTrain および lossTrain は、モデル勾配関数で計算されたミニバッチの精度とミニバッチの損失に対応します。回帰タスクの場合、ミニバッチの精度ではなくミニバッチの RMSE 損失を使用します。

ヒント

関数 addpoints では、データ点の型が double であることが必要です。dlarray オブジェクトから数値データを抽出するには、関数 extractdata を使用します。GPU からデータを収集するには、関数 gather を使用します。

検証メトリクスを計算する方法については、検証を参照してください。

詳細出力

学習時に学習損失と学習精度を詳細テーブルに表示するには、モデル勾配関数で、ミニバッチの損失と精度 (分類タスクの場合) または RMSE (回帰タスクの場合) のいずれかを計算し、関数 disp を使用して表示します。

詳細テーブルをオンにするかオフにするかを簡単に指定するには、変数 verbose および verboseFrequency を作成します。ここで、verbosetrue または false になり、verbosefrequency は詳細出力間の反復回数を指定します。検証メトリクスを表示するには、検証の説明と同じオプション validationData および validationFrequency を使用します。

verbose = true
verboseFrequency = 50;

validationData = {XValidation, YValidation};
validationFrequency = 50;

学習前に、詳細出力テーブルの見出しを表示し、関数 tic を使用してタイマーを初期化します。

disp("|======================================================================================================================|")
disp("|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |")
disp("|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |")
disp("|======================================================================================================================|")

start = tic;

回帰タスクの場合、コードを調整し、学習精度と検証精度ではなく学習 RMSE と検証 RMSE が表示されるようにします。

verbose オプションが true で、それが最初の反復であるか、または反復回数が verboseFrequency の倍数である場合、学習ループ内で、反復の最後に詳細出力を出力します。

if verbose && (iteration == 1 || mod(iteration,verboseFrequency) == 0
    D = duration(0,0,toc(start),'Format','hh:mm:ss');

    if isempty(validationData) || mod(iteration,validationFrequency) ~= 0 
        accuracyValidation = "";
        lossValidation = "";
    end

    disp("| " + ...
        pad(epoch,7,'left') + " | " + ...
        pad(iteration,11,'left') + " | " + ...
        pad(D,14,'left') + " | " + ...
        pad(accuracyTrain,12,'left') + " | " + ...
        pad(accuracyValidation,12,'left') + " | " + ...
        pad(lossTrain,12,'left') + " | " + ...
        pad(lossValidation,12,'left') + " | " + ...
        pad(learnRate,15,'left') + " |")
end

回帰タスクの場合、コードを調整し、学習精度と検証精度ではなく学習 RMSE と検証 RMSE が表示されるようにします。

学習の終了時に、詳細テーブルの最後の罫線を出力します。

disp("|======================================================================================================================|")

検証メトリクスを計算する方法については、検証を参照してください。

ミニバッチのサイズ

ミニバッチのサイズの設定は、データの形式や使用するデータストアのタイプによって異なります。

ミニバッチのサイズを簡単に指定するには、変数 miniBatchSize を作成します。

miniBatchSize = 128;

イメージ データストア内のデータの場合、データストアの ReadSize プロパティをミニバッチのサイズに設定します。

imds.ReadSize = miniBatchSize;

拡張イメージ データストア内のデータの場合、学習前に、データストアの MiniBatchSize プロパティをミニバッチのサイズに設定します。

augimds.MiniBatchSize = miniBatchSize;

メモリ内のデータの場合、学習中の各反復の開始時に、配列から観測値を直接読み取ります。

idx = ((iteration - 1)*miniBatchSize + 1):(iteration*miniBatchSize);
X = XTrain(:,:,:,idx);

エポック数

学習ループの外側の for ループで学習を行う最大エポック数を指定します。

最大エポック数を簡単に指定するには、最大エポック数を含む変数 maxEpochs を作成します。

maxEpochs = 30;

学習ループの外側の for ループで、範囲 1、2、…、maxEpochs をループするように指定します。

for epoch = 1:maxEpochs
    ...
end

検証

学習時にネットワークを検証するには、ホールドアウトされた検証セットを残しておき、そのデータに対するネットワークのパフォーマンスを評価します。

検証オプションを簡単に指定するには、変数 validationData および validationFrequency を作成します。ここで、validationData には検証データが含まれるか空であり、validationFrequency はネットワークの検証間の反復回数を指定します。

validationData = {XValidation,YValidation};
validationFrequency = 50;

学習ループ中、ネットワーク パラメーターを更新した後に、関数 predict を使用してホールドアウトされた検証セットでネットワークのパフォーマンスをテストします。検証データが指定されており、それが最初の反復であるか、または現在の反復が verboseFrequency オプションの倍数である場合、ネットワークを検証します。

if iteration == 1 || mod(iteration,verboseFrequency) == 0
    dlYPredValidation = predict(dlnet,dlXValidation);
    lossValidation = crossentropy(softmax(dlYPredValidation), YValidation);

    [~,idx] = max(dlYPredValidation);
    labelsPredValidation = classNames(idx);

    accuracyValidation = mean(labelsPredValidation == labelsValidation);
end
ここで、YValidationclassNames のラベルに対応するダミー変数です。精度を計算するには、YValidation をラベルの配列に変換します。

回帰タスクの場合、コードを調整し、検証精度ではなく検証 RMSE が計算されるようにします。

早期停止

ホールドアウトされた検証の損失が減少しなくなったときに学習を早期に停止するには、学習ループを脱出するフラグを設定します。

検証の許容回数 (ネットワークの学習が停止するまでに、検証損失が以前の最小損失以上になることが許容される回数) を簡単に指定するには、変数 validationPatience を作成します。

validationPatience = 5;

学習前に、変数 earlyStop および validationLosses を初期化します。ここで、earlyStop は学習を早期に停止するフラグで、validationLosses には比較する損失が含まれます。false で早期停止フラグを初期化し、inf で検証損失の配列を初期化します。

earlyStop = false;
if isfinite(validationPatience)
    validationLosses = inf(1,validationPatience);
end

学習ループ内のミニバッチに対するループで、ループ条件に earlyStop フラグを追加します。

while hasdata(ds) && ~earlyStop
    ...
end

検証ステップ中に、配列 validationLosses に新しい検証を追加します。配列の最初の要素が最小である場合、earlyStop フラグを true に設定します。そうでない場合、最初の要素を削除します。

if isfinite(validationPatience)
    validationLosses = [validationLosses validationLoss];
    if min(validationLosses) == validationLosses(1)
        earlyStop = true;
    else
        validationLosses(1) = [];
    end
end

L2 正則化

重みに L2 正則化を適用するには、関数 dlupdate を使用します。

L2 正則化係数を簡単に指定するには、L2 正則化係数が含まれる変数 l2Regularization を作成します。

l2Regularization = 0.0001;

学習中、モデル勾配を計算した後に、各重みパラメーターについて、関数 dlupdate を使用して L2 正則化係数と重みの積を勾配の計算値に追加します。重みパラメーターのみを更新するには、"Weights" という名前のパラメーターを抽出します。

idx = dlnet.Learnables.Parameter == "Weights";
gradients(idx,:) = dlupdate(@(g,w) g + l2Regularization*w, gradients(idx,:), dlnet.Learnables(idx,:));

L2 正則化パラメーターを勾配に追加した後、ネットワーク パラメーターを更新します。

勾配クリップ

勾配をクリップするには、関数 dlupdate を使用します。

勾配クリップ オプションを簡単に指定するには、変数 gradientThresholdMethod および gradientThreshold を作成します。ここで、gradientThresholdMethod には "global-l2norm""l2norm"、または "absolute-value" が含まれ、gradientThreshold はしきい値または inf が含まれる正のスカラーです。

gradientThresholdMethod = "global-l2norm";
gradientThreshold = 2;

thresholdGlobalL2NormthresholdL2Norm、および thresholdAbsoluteValue という名前の関数を作成します。これらはそれぞれ "global-l2norm""l2norm"、および "absolute-value" しきい値法を適用します。

"global-l2norm" オプションの場合、この関数はモデルのすべての勾配に対して作用します。

function gradients = thresholdGlobalL2Norm(gradients,gradientThreshold)

globalL2Norm = 0;
for i = 1:numel(gradients)
    globalL2Norm = globalL2Norm + sum(gradients{i}(:).^2);
end
globalL2Norm = sqrt(globalL2Norm);

if globalL2Norm > gradientThreshold
    normScale = gradientThreshold / globalL2Norm;
    for i = 1:numel(gradients)
        gradients{i} = gradients{i} * normScale;
    end
end

end

"l2norm" および "absolute-value" オプションの場合、この関数は各勾配に対して個別に作用します。

function gradients = thresholdL2Norm(gradients,gradientThreshold)

gradientNorm = sqrt(sum(gradients(:).^2));
if gradientNorm > gradientThreshold
    gradients = gradients * (gradientThreshold / gradientNorm);
end

end
function gradients = thresholdAbsoluteValue(gradients,gradientThreshold)

gradients(gradients > gradientThreshold) = gradientThreshold;
gradients(gradients < -gradientThreshold) = -gradientThreshold;

end

学習中、モデル勾配を計算した後に、関数 dlupdate を使用して適切な勾配クリップ法を勾配に適用します。"global-l2norm" オプションにはすべてのモデル勾配が必要であるため、関数 thresholdGlobalL2Norm を勾配に直接適用します。"l2norm" および "absolute-value" オプションの場合、関数 dlupdate を使用して勾配を個別に更新します。

switch gradientThresholdMethod
    case "global-l2norm"
        gradients = thresholdGlobalL2Norm(gradients, gradientThreshold);
    case "l2norm"
        gradients = dlupdate(@(g) thresholdL2Norm(g, gradientThreshold),gradients);
    case "absolute-value"
        gradients = dlupdate(@(g) thresholdAbsoluteValue(g, gradientThreshold),gradients);
end

勾配しきい値演算を適用した後、ネットワーク パラメーターを更新します。

単一の CPU または GPU での学習

既定では、1 つの CPU を使用して計算が実行されます。単一の GPU で学習させるには、データを gpuArray オブジェクトに変換します。GPU を使用するには、Parallel Computing Toolbox™ および Compute Capability 3.0 以上の CUDA® 対応 NVIDIA® GPU が必要です。

実行環境を簡単に指定するには、"cpu""gpu"、または "auto" のいずれかが含まれる変数 executionEnvironment を作成します。

executionEnvironment = "auto"

学習中、ミニバッチを読み取った後に実行環境オプションを確認し、必要に応じてデータを gpuArray に変換します。関数 canUseGPU は使用可能な GPU を確認します。

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlX = gpuArray(dlX);
end

チェックポイント

学習中にチェックポイント ネットワークを保存するには、関数 save を使用してネットワークを保存します。

チェックポイントをオンにするかどうかを簡単に指定するには、チェックポイント ネットワークのフォルダーが含まれる、または空の変数 checkpointPath を作成します。

checkpointPath = fullfile(tempdir,"checkpoints");

チェックポイント フォルダーが存在していない場合、学習前に、チェックポイント フォルダーを作成します。

if ~exist(checkpointPath,"dir")
    mkdir(checkpointPath)
end

学習中、各エポックの最後にネットワークを MAT ファイルに保存できます。。現在の反復回数、日付、および時刻が含まれるファイル名を指定します。

if ~isempty(checkpointPath)
    D = datestr(now,'yyyy_mm_dd__HH_MM_SS');
    filename = "dlnet_checkpoint__" + iteration + "__" + D + ".mat";
    save(filename,"dlnet")
end
ここで、dlnet は保存される dlnetwork オブジェクトです。

参考

| | | | | | |

関連するトピック