Main Content

カスタム学習ループのモデル損失関数の定義

カスタム学習ループを使用して深層学習モデルに学習させる場合、ソフトウェアは、学習可能なパラメーターについての損失を最小化します。損失を最小化するために、ソフトウェアは、学習可能なパラメーターについての損失の勾配を使用します。自動微分を使用してこれらの勾配を計算するには、モデル勾配関数を定義しなければなりません。

dlnetwork オブジェクトを使用して深層学習モデルに学習させる方法を示す例については、カスタム学習ループを使用したネットワークの学習を参照してください。関数として定義される深層学習モデルに学習させる方法を示す例については、モデル関数を使用したネットワークの学習を参照してください。

dlnetwork オブジェクトとして定義されるモデルのモデル損失関数の作成

dlnetwork オブジェクトとして定義される深層学習モデルがある場合、入力として dlnetwork オブジェクトを受け取るモデル損失関数を作成します。

dlnetwork オブジェクトとして指定されるモデルでは、[loss,gradients] = modelLoss(net,X,T) という形式の関数を作成します。ここで、net はネットワーク、X はネットワークの入力で、T にはターゲットが格納され、lossgradients にはそれぞれ損失と勾配が返されます。オプションで、損失関数が必要とする追加情報などを追加引数として勾配関数に渡すことや、更新されたネットワークの状態などを追加引数として返すことができます。

たとえば、この関数は、指定された dlnetwork オブジェクト net に含まれる学習可能なパラメーター、与えられた入力データ X、およびターゲット T についてのクロスエントロピー損失および損失の勾配を返します。

function [loss,gradients] = modelLoss(net,X,T)

    % Forward data through the dlnetwork object.
    Y = forward(net,X);

    % Compute loss.
    loss = crossentropy(Y,T);

    % Compute gradients.
    gradients = dlgradient(loss,net.Learnables);

end

関数として定義されるモデルのモデル損失関数の作成

関数として定義される深層学習モデルがある場合、入力としてモデルの学習可能なパラメーターを受け取るモデル損失関数を作成します。

関数として指定されるモデルでは、[loss,gradients] = modelLoss(parameters,X,T) という形式の関数を作成します。ここで、parameters には学習可能なパラメーターが格納され、X はモデルの入力で、T にはターゲットが格納され、lossgradients にはそれぞれ損失と勾配が返されます。オプションで、損失関数が必要とする追加情報などを追加引数として勾配関数に渡すことや、更新されたモデルの状態などを追加引数として返すことができます。

たとえば、この関数は、学習可能なパラメーター parameters、与えられた入力データ X、およびターゲット T についてのクロスエントロピー損失および損失の勾配を返します。

function [loss,gradients] = modelLoss(parameters,X,T)

    % Forward data through the model function.
    Y = model(parameters,X);

    % Compute loss.
    loss = crossentropy(Y,T);

    % Compute gradients.
    gradients = dlgradient(loss,parameters);

end

モデル損失関数の評価

自動微分を使用してモデル損失関数を評価するには、自動微分を有効にして関数を評価する関数 dlfeval を使用します。dlfeval の最初の入力では、関数ハンドルとして指定されるモデル損失関数を渡します。続く入力では、モデル損失関数に必要な変数を渡します。関数 dlfeval の出力では、モデル損失関数と同じ出力を指定します。

たとえば、dlnetwork オブジェクト net、入力データ X、およびターゲット T でモデル損失関数 modelLoss を評価し、モデルの損失と勾配を返します。

[loss,gradients] = dlfeval(@modelLoss,net,X,T);

同様に、構造体 parameters で指定された学習可能なパラメーター、入力データ X、およびターゲット T をもつモデル関数を使用して、モデル損失関数 modelLoss を評価し、モデルの損失と勾配を返します。

[loss,gradients] = dlfeval(@modelLoss,parameters,X,T);

勾配を使用した学習可能なパラメーターの更新

学習可能なパラメーターを更新するには、次の関数を使用できます。

関数説明
adamupdate適応モーメント推定 (Adam) を使用してパラメーターを更新する
rmspropupdate平方根平均二乗伝播 (RMSProp) を使用してパラメーターを更新する
sgdmupdateモーメンタム項付き確率的勾配降下法 (SGDM) を使用してパラメーターを更新する
lbfgsupdate記憶制限 BFGS (L-BFGS) を使用してパラメーターを更新する
dlupdateカスタム関数を使用してパラメーターを更新する

たとえば、関数 adamupdate を使用して、dlnetwork オブジェクト net の学習可能なパラメーターを更新します。

[net,trailingAvg,trailingAvgSq] = adamupdate(net,gradients, ...
    trailingAvg,trailingAverageSq,iteration);
ここで、gradients は学習可能なパラメーターについての損失の勾配、trailingAvgtrailingAvgSq、および iteration は関数 adamupdate が必要とするハイパーパラメーターです。

同様に、関数 adamupdate を使用してモデル関数 parameters の学習可能パラメーターを更新します。

[parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
    trailingAvg,trailingAverageSq,iteration);
ここで、gradients は学習可能なパラメーターについての損失の勾配、trailingAvgtrailingAvgSq、および iteration は関数 adamupdate が必要とするハイパーパラメーターです。

カスタム学習ループでのモデル損失関数の使用

カスタム学習ループを使用して深層学習モデルに学習させる場合、モデルの損失と勾配を評価し、各ミニバッチの学習可能なパラメーターを更新します。

このコード スニペットは、カスタム学習ループで関数 dlfeval および adamupdate を使用する例を示しています。

iteration = 0;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;

        % Prepare mini-batch.
        % ...

        % Evaluate model loss and gradients.
        [loss,gradients] = dlfeval(@modelLoss,net,X,T);

        % Update learnable parameters.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAverageSq,iteration);

    end
end

dlnetwork オブジェクトを使用して深層学習モデルに学習させる方法を示す例については、カスタム学習ループを使用したネットワークの学習を参照してください。関数として定義される深層学習モデルに学習させる方法を示す例については、モデル関数を使用したネットワークの学習を参照してください。

モデル損失関数のデバッグ

モデル損失関数の実装に問題がある場合、dlfeval を呼び出したときにエラーがスローされる可能性があります。関数 dlfeval を使用する際、どのコード行がエラーをスローしているのかはっきりしない場合があります。エラーの位置を特定するために、以下を試すことができます。

モデル損失関数の直接呼び出し

必要とされるサイズの生成済み入力を使用して、モデル損失関数を直接 (すなわち、関数 dlfeval を使用せずに) 呼び出してみます。コード行のいずれかがエラーをスローした場合、エラー メッセージに追加の詳細が表示されます。関数 dlfeval を使用しない場合、関数 dlgradient へのあらゆる呼び出しでエラーがスローされることに注意してください。

% Generate image input data.
X = rand([28 28 1 100],'single');
X = dlarray(X);

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

[loss,gradients] = modelLoss(net,X,T);

モデル損失コードの手動での実行

必要とされるサイズの生成済み入力を使用して、モデル損失関数内で手動でコードを実行し、出力およびスローされるエラー メッセージを検査します。

たとえば、次のモデル損失関数を考えます。

function [loss,gradients] = modelLoss(net,X,T)

    % Forward data through the dlnetwork object.
    Y = forward(net,X);

    % Compute loss.
    loss = crossentropy(Y,T);

    % Compute gradients.
    gradients = dlgradient(loss,net.Learnables);

end

次のコードを実行して、モデル損失関数をチェックします。

% Generate image input data.
X = rand([28 28 1 100],'single');
X = dlarray(X);

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

% Check forward pass.
Y = forward(net,X);

% Check loss calculation.
loss = crossentropy(Y,T)

関連するトピック