Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

カスタム学習ループのモデル勾配関数の定義

カスタム学習ループを使用して深層学習モデルに学習させる場合、ソフトウェアは、学習可能なパラメーターについての損失を最小化します。損失を最小化するために、ソフトウェアは、学習可能なパラメーターについての損失の勾配を使用します。自動微分を使用してこれらの勾配を計算するには、モデル勾配関数を定義しなければなりません。

dlnetwork オブジェクトを使用して深層学習モデルに学習させる方法を示す例については、カスタム学習ループを使用したネットワークの学習を参照してください。関数として定義される深層学習モデルに学習させる方法を示す例については、モデル関数を使用したネットワークの学習を参照してください。

dlnetwork オブジェクトとして定義されるモデルのモデル勾配関数の作成

dlnetwork オブジェクトとして定義される深層学習モデルがある場合、入力として dlnetwork オブジェクトを受け取るモデル勾配関数を作成します。

dlnetwork オブジェクトとして指定されるモデルでは、gradients = modelGradients(dlnet,dlX,T) という形式の関数を作成します。ここで、dlnet はネットワークであり、dlX には入力予測子、T にはターゲット、gradients には返される勾配が含まれます。オプションで、(損失関数に追加情報が必要な場合などに) 勾配関数に追加引数を渡したり、(学習の進行状況をプロットするためのメトリクスなどの) 追加引数を返すことができます。

たとえば、この関数は、指定された dlnetwork オブジェクト dlnet、入力データ dlX、およびターゲット T の勾配と交差エントロピー損失を返します。

function [gradients, loss] = modelGradients(dlnet, dlX, T)

    % Forward data through the dlnetwork object.
    dlY = forward(dlnet,dlX);

    % Compute loss.
    loss = crossentropy(dlX,T);

    % Compute gradients.
    gradients = dlgradient(loss,dlnet);

end

関数として定義されるモデルのモデル勾配関数の作成

dlY = model(parameters,dlX) という形式の関数として定義される深層学習モデルがある場合、gradients = modelGradients(parameters,dlX,T) という形式の関数を作成します。ここで、parameters は学習可能なパラメーターを含む構造体、dlX は入力予測子、T はターゲット、gradients は返される勾配です。オプションで、(損失関数に追加情報が必要な場合などに) 勾配関数に追加引数を渡したり、(学習の進行状況をプロットするためのメトリクスなどの) 追加引数を返すことができます。モデルが関数として定義されている場合、ネットワークを入力引数として渡す必要はありません。

たとえば、この関数は、深層学習モデル関数 model、指定された学習可能なパラメーター parameters、入力データ dlX、およびターゲット T の勾配と交差エントロピー損失を返します。

function [gradients, loss] = modelGradients(parameters, dlX, T)

    % Forward data through the model function.
    dlY = model(parameters,dlX);

    % Compute loss.
    loss = crossentropy(dlX,T);

    % Compute gradients.
    gradients = dlgradient(loss,parameters);

end

モデル勾配関数の評価

自動微分を使用してモデル勾配を評価するには、自動微分を有効にして関数を評価する関数 dlfeval を使用します。dlfeval の最初の入力では、関数ハンドルとして指定されるモデル勾配関数を渡し、次の入力では、モデル勾配関数に必要な変数を渡します。関数 dlfeval の出力では、モデル勾配関数と同じ出力を指定します。

たとえば、dlnetwork オブジェクト dlnet、入力データ dlX、および T でモデル勾配関数 modelGradients を評価し、モデルの勾配と損失を返すには、次のコマンドを使用します。

[gradients, loss] = dlfeval(@modelGradients,dlnet,dlX,T);

同様に、構造体 parameters、入力データ dlX、および T で指定される学習可能なパラメーターをもつモデル関数を使用して、モデル勾配関数 modelGradients を評価し、モデルの勾配と損失を返すには、次のコマンドを使用します。

[gradients, loss] = dlfeval(@modelGradients,parameters,dlX,T);

勾配を使用した学習可能なパラメーターの更新

勾配を使用して学習可能なパラメーターを更新するには、次の関数を使用できます。

関数説明
adamupdate適応モーメント推定 (Adam) を使用してパラメーターを更新する
rmspropupdate平方根平均二乗伝播 (RMSProp) を使用してパラメーターを更新する
sgdmupdateモーメンタム項付き確率的勾配降下法 (SGDM) を使用してパラメーターを更新する
dlupdateカスタム関数を使用してパラメーターを更新する

たとえば、関数 adamupdate を使用して、dlnetwork オブジェクト dlnet の学習可能なパラメーターを更新するには、次のコマンドを使用します。

[dlnet,trailingAvg,trailingAvgSq] = adamupdate(dlnet,gradients, ...
    trailingAvg,trailingAverageSq,iteration);
ここで、gradients はモデル勾配関数の出力、trailingAvgtrailingAvgSqiteration は、関数 adamupdate で必要なハイパーパラメーターです。

同様に、関数 adamupdate を使用してモデル関数 parameters の学習可能なパラメーターを更新するには、次のコマンドを使用します。

[parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
    trailingAvg,trailingAverageSq,iteration);
ここで、gradients はモデル勾配関数の出力、trailingAvgtrailingAvgSqiteration は、関数 adamupdate で必要なハイパーパラメーターです。

カスタム学習ループでのモデル勾配関数の使用

カスタム学習ループを使用して深層学習モデルに学習させる場合、モデル勾配を評価し、各ミニバッチの学習可能なパラメーターを更新します。

このコード スニペットは、カスタム学習ループで関数 dlfeval および adamupdate を使用する例を示しています。

iteration = 0;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;

        % Prepare mini-batch.
        % ...

        % Evaluate model gradients.
        [gradients, loss] = dlfeval(@modelGradients,dlnet,dlX,T);

        % Update learnable parameters.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAverageSq,iteration);

    end
end

dlnetwork オブジェクトを使用して深層学習モデルに学習させる方法を示す例については、カスタム学習ループを使用したネットワークの学習を参照してください。関数として定義される深層学習モデルに学習させる方法を示す例については、モデル関数を使用したネットワークの学習を参照してください。

モデル勾配関数のデバッグ

モデル勾配関数の実装に問題がある場合、dlfeval を呼び出すと、エラーがスローされる可能性があります。関数 dlfeval を使用する際、どのコード行がエラーをスローしているのかはっきりしない場合があります。エラーの位置を特定するために、以下を試すことができます。

モデル勾配関数の直接呼び出し

想定されるサイズの生成済み入力を使用して、モデル勾配関数を直接 (関数 dlfeval を使用せずに) 呼び出してみます。いずれかのコード行がエラーをスローする場合、どの行がスローしたのかがはっきりするはずです。関数 dlfeval を使用しない場合、関数 dlgradient へのあらゆる呼び出しでエラーの発生が予想されることに注意してください。

% Generate image input data.
X = rand([28 28 1 100],'single');
dlX = dlarray(dlX);

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

[gradients, loss] = modelGradients(dlnet,dlX,T);

モデル勾配コードの手動での実行

想定されるサイズの生成済み入力を使用して、モデル勾配関数内で手動でコードを実行し、出力とスローされるエラー メッセージを検査します。

たとえば、次のように定義されるモデル勾配関数をチェックするとします。

function [gradients, loss] = modelGradients(dlnet, dlX, T)

    % Forward data through the dlnetwork object.
    dlY = forward(dlnet,dlX);

    % Compute loss.
    loss = crossentropy(dlX,T);

    % Compute gradients.
    gradients = dlgradient(loss,dlnet);

end

この場合、次のコードを実行します。

% Generate image input data.
X = rand([28 28 1 100],'single');
dlX = dlarray(dlX);

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

% Check forward pass.
dlY = forward(dlnet,dlX);

% Check loss calculation.
loss = crossentropy(dlX,T)

関連するトピック