メインコンテンツ

lbfgsupdate

メモリ制限 BFGS (L-BFGS) を使用してパラメーターを更新する

R2023a 以降

    説明

    メモリ制限 BFGS (L-BFGS) アルゴリズムを使用して、カスタム学習ループでネットワークの学習可能なパラメーターを更新します。

    L-BFGS アルゴリズム[1]は、Broyden-Fletcher-Goldfarb-Shanno (BFGS) アルゴリズムを近似する準ニュートン法です。L-BFGS アルゴリズムは、単一のバッチで処理できる小規模なネットワークやデータ セットに使用します。

    メモ

    この関数は、L-BFGS 最適化アルゴリズムを適用して、カスタム学習ループでネットワークのパラメーターを更新します。trainnet 関数を使用して L-BFGS ソルバーでニューラル ネットワークに学習させるには、trainingOptions 関数を使用してソルバーを "lbfgs" に設定します。

    [netUpdated,solverStateUpdated] = lbfgsupdate(net,lossFcn,solverState) は、指定された損失関数とソルバーの状態を L-BFGS アルゴリズムで使用して、ネットワーク net の学習可能なパラメーターを更新します。dlnetwork オブジェクトとして定義されたネットワークを反復的に更新するには、学習ループでこの構文を使用します。

    [parametersUpdated,solverStateUpdated] = lbfgsupdate(parameters,lossFcn,solverState) は、指定された損失関数とソルバーの状態を L-BFGS アルゴリズムで使用して、parameters に含まれる学習可能なパラメーターを更新します。関数として定義されたネットワークの学習可能なパラメーターを反復的に更新するには、学習ループでこの構文を使用します。

    ___ = lbfgsupdate(___,Name=Value) は、1 つ以上の名前と値の引数を使用して追加オプションを指定します。

    すべて折りたたむ

    CSV ファイル "transmissionCasingData.csv" からトランスミッション ケーシング データを読み取ります。

    filename = "transmissionCasingData.csv";
    tbl = readtable(filename,TextType="String");

    関数 convertvars を使用して、予測のラベルを categorical に変換します。

    labelName = "GearToothCondition";
    tbl = convertvars(tbl,labelName,"categorical");

    カテゴリカル特徴量を使用してネットワークに学習させるには、convertvars 関数を使用して、すべてのカテゴリカル入力変数の名前を格納した string 配列を指定することにより、カテゴリカル予測子を categorical に変換します。

    categoricalPredictorNames = ["SensorCondition" "ShaftCondition"];
    tbl = convertvars(tbl,categoricalPredictorNames,"categorical");

    カテゴリカル入力変数をループ処理します。各変数について、関数 onehotencode を使用して categorical 値を one-hot 符号化ベクトルに変換します。

    for i = 1:numel(categoricalPredictorNames)
        name = categoricalPredictorNames(i);
        tbl.(name) = onehotencode(tbl.(name),2);
    end

    table の最初の数行を表示します。

    head(tbl)
        SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    SensorCondition    ShaftCondition    GearToothCondition
        ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    ______________    __________________
    
        -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13             0    1             1    0          No Tooth Fault  
        -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12             0    1             1    0          No Tooth Fault  
          1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13             0    1             0    1          No Tooth Fault  
          1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13             0    1             0    1          No Tooth Fault  
          1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39             0    1             0    1          No Tooth Fault  
          1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39             0    1             0    1          No Tooth Fault  
          1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39             0    1             0    1          No Tooth Fault  
          1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39             0    1             0    1          No Tooth Fault  
    

    学習データを抽出します。

    predictorNames = ["SigMean" "SigMedian" "SigRMS" "SigVar" "SigPeak" "SigPeak2Peak" ...
        "SigSkewness" "SigKurtosis" "SigCrestFactor" "SigMAD" "SigRangeCumSum" ...
        "SigCorrDimension" "SigApproxEntropy" "SigLyapExponent" "PeakFreq" ...
        "HighFreqPower" "EnvPower" "PeakSpecKurtosis" "SensorCondition" "ShaftCondition"];
    XTrain = table2array(tbl(:,predictorNames));
    numInputFeatures = size(XTrain,2);

    ターゲットを抽出し、one-hot 符号化されたベクトルに変換します。

    TTrain = tbl.(labelName);
    TTrain = onehotencode(TTrain,2);
    numClasses = size(TTrain,2);

    予測子とターゲットを、形式 "BC" (バッチ、チャネル) の dlarray オブジェクトに変換します。

    XTrain = dlarray(XTrain,"BC");
    TTrain = dlarray(TTrain,"BC");

    ネットワーク アーキテクチャを定義します。

    numHiddenUnits = 32;
    
    layers = [
        featureInputLayer(numInputFeatures)
        fullyConnectedLayer(16)
        layerNormalizationLayer
        reluLayer
        fullyConnectedLayer(numClasses)
        softmaxLayer];
    
    net = dlnetwork(layers);

    この例のモデル損失関数のセクションにリストされている modelLoss 関数を定義します。この関数は、ニューラル ネットワーク、入力データ、およびターゲットを入力として受け取ります。この関数は、ネットワークの学習可能なパラメーターに関する損失および損失の勾配を返します。

    lbfgsupdate 関数には、構文 [loss,gradients] = f(net) をもつ損失関数が必要です。評価された modelLoss 関数をパラメーター化して、単一の入力引数を受け取る変数を作成します。

    lossFcn = @(net) dlfeval(@modelLoss,net,XTrain,TTrain);

    最大履歴サイズを 3、初期逆ヘッセ近似係数を 1.1 として、L-BFGS のソルバー状態オブジェクトを初期化します。

    solverState = lbfgsState( ...
        HistorySize=3, ...
        InitialInverseHessianFactor=1.1);

    最大 200 回反復してネットワークに学習させます。勾配のノルムまたはステップのノルムが 0.00001 より小さくなった時点で、学習を早期に停止します。10 回の反復ごとに学習損失を出力します。

    maxIterations = 200;
    gradientTolerance = 1e-5;
    stepTolerance = 1e-5;
    
    iteration = 0;
    
    while iteration < maxIterations
        iteration = iteration + 1;
        [net, solverState] = lbfgsupdate(net,lossFcn,solverState);
    
        if iteration==1 || mod(iteration,10)==0
            fprintf("Iteration %d: Loss: %d\n",iteration,solverState.Loss);
        end
    
        if solverState.GradientsNorm < gradientTolerance || ...
                solverState.StepNorm < stepTolerance || ...
                solverState.LineSearchStatus == "failed"
            break
        end
    end
    Iteration 1: Loss: 9.343236e-01
    Iteration 10: Loss: 4.721475e-01
    Iteration 20: Loss: 4.678575e-01
    Iteration 30: Loss: 4.666964e-01
    Iteration 40: Loss: 4.665921e-01
    Iteration 50: Loss: 4.663871e-01
    Iteration 60: Loss: 4.662519e-01
    Iteration 70: Loss: 4.660451e-01
    Iteration 80: Loss: 4.645303e-01
    Iteration 90: Loss: 4.591753e-01
    Iteration 100: Loss: 4.562556e-01
    Iteration 110: Loss: 4.531167e-01
    Iteration 120: Loss: 4.489444e-01
    Iteration 130: Loss: 4.392228e-01
    Iteration 140: Loss: 4.347853e-01
    Iteration 150: Loss: 4.341757e-01
    Iteration 160: Loss: 4.325102e-01
    Iteration 170: Loss: 4.321948e-01
    Iteration 180: Loss: 4.318990e-01
    Iteration 190: Loss: 4.313784e-01
    Iteration 200: Loss: 4.311314e-01
    

    モデル損失関数

    modelLoss 関数は、ニューラル ネットワーク net、入力データ X、およびターゲット T を入力として受け取ります。この関数は、ネットワークの学習可能なパラメーターに関する損失および損失の勾配を返します。

    function [loss, gradients] = modelLoss(net, X, T)
    
    Y = forward(net,X);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss,net.Learnables);
    
    end

    入力引数

    すべて折りたたむ

    ニューラル ネットワーク。dlnetwork オブジェクトとして指定します。

    この関数は、dlnetwork オブジェクトの Learnables プロパティを更新します。net.Learnables は、3 つの変数をもつ table です。

    • Layer — 層の名前。string スカラーとして指定します。

    • Parameter — パラメーター名。string スカラーとして指定します。

    • Value — パラメーターの値。dlarray オブジェクトを含む cell 配列として指定します。

    学習可能なパラメーター。dlarray オブジェクト、数値配列、cell 配列、構造体、または table として指定します。

    parameters を table として指定する場合、次の変数を table に含めなければなりません。

    • Layer — 層の名前。string スカラーとして指定します。

    • Parameter — パラメーター名。string スカラーとして指定します。

    • Value — パラメーターの値。dlarray オブジェクトを含む cell 配列として指定します。

    cell 配列、構造体、table、入れ子になった cell 配列、または入れ子になった構造体を使用し、ネットワークの学習可能なパラメーターのコンテナーとして parameters を指定できます。cell 配列、構造体、または table に含まれる学習可能なパラメーターは、データ型が double または single である dlarray オブジェクトまたは数値でなければなりません。

    parameters が数値配列の場合、lossFcndlgradient 関数を使用してはなりません。

    損失関数。関数ハンドルまたは構文 [loss,gradients] = f(net) をもつ AcceleratedFunction オブジェクトとして指定します。ここで、lossgradients は、それぞれ学習可能なパラメーターに対する損失と損失の勾配に対応します。

    dlgradient 関数を呼び出すモデル損失関数をパラメーター化するには、損失関数を @(net) dlfeval(@modelLoss,net,arg1,...,argN) として指定します。ここで、modelLoss は、[loss,gradients] = modelLoss(net,arg1,...,argN) という構文をもつ関数で、引数 arg1,...,argN が与えられた net 内の学習可能なパラメーターに関する損失と損失の勾配を返します。

    parameters が数値配列の場合、損失関数は dlgradient 関数または dlfeval 関数を使用してはなりません。

    損失関数の出力が 2 つより多い場合は、NumLossFunctionOutputs 引数も指定します。

    データ型: function_handle

    ソルバーの状態。lbfgsState オブジェクトまたは [] として指定します。

    名前と値の引数

    すべて折りたたむ

    オプションの引数のペアを Name1=Value1,...,NameN=ValueN として指定します。ここで、Name は引数名で、Value は対応する値です。名前と値の引数は他の引数の後に指定しなければなりませんが、ペアの順序は重要ではありません。

    例: lbfgsupdate(net,lossFcn,solverState,LineSearchMethod="strong-wolfe") は、net 内の学習可能なパラメーターを更新し、強 Wolfe 条件を満たす学習率を検索します。

    適切な学習率を検出する方法。次の値のいずれかとして指定します。

    • "weak-wolfe" — 弱 Wolfe 条件を満たす学習率を検索します。この方法は、逆ヘッセ行列の正定値近似を維持します。

    • "strong-wolfe" — 強 Wolfe 条件を満たす学習率を検索します。この方法は、逆ヘッセ行列の正定値近似を維持します。

    • "backtracking" — 十分な減少条件を満たす学習率を検索します。この方法は、逆ヘッセ行列の正定値近似を維持しません。

    学習率を決定するための直線探索の反復の最大数。正の整数として指定します。

    データ型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64

    損失関数の出力の数。2 以上の整数として指定します。lossFcn の出力引数が 2 つより多い場合は、このオプションを設定します。

    データ型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64

    出力引数

    すべて折りたたむ

    更新されたネットワーク。dlnetwork オブジェクトとして返されます。

    この関数は、dlnetwork オブジェクトの Learnables プロパティを更新します。

    更新された学習可能なパラメーター。parameters と同じタイプのオブジェクトとして返されます。

    更新されたソルバーの状態。lbfgsState 状態オブジェクトとして返されます。

    アルゴリズム

    すべて折りたたむ

    参照

    [1] Liu, Dong C., and Jorge Nocedal. "On the limited memory BFGS method for large scale optimization." Mathematical programming 45, no. 1 (August 1989): 503-528. https://doi.org/10.1007/BF01589116.

    拡張機能

    すべて展開する

    バージョン履歴

    R2023a で導入