lbfgsupdate
構文
説明
メモリ制限 BFGS (L-BFGS) アルゴリズムを使用して、カスタム学習ループでネットワークの学習可能なパラメーターを更新します。
L-BFGS アルゴリズム[1]は、Broyden-Fletcher-Goldfarb-Shanno (BFGS) アルゴリズムを近似する準ニュートン法です。L-BFGS アルゴリズムは、単一のバッチで処理できる小規模なネットワークやデータ セットに使用します。
メモ
この関数は、L-BFGS 最適化アルゴリズムを適用して、カスタム学習ループでネットワークのパラメーターを更新します。trainnet 関数を使用して L-BFGS ソルバーでニューラル ネットワークに学習させるには、trainingOptions 関数を使用してソルバーを "lbfgs" に設定します。
[ は、指定された損失関数とソルバーの状態を L-BFGS アルゴリズムで使用して、ネットワーク netUpdated,solverStateUpdated] = lbfgsupdate(net,lossFcn,solverState)net の学習可能なパラメーターを更新します。dlnetwork オブジェクトとして定義されたネットワークを反復的に更新するには、学習ループでこの構文を使用します。
[ は、指定された損失関数とソルバーの状態を L-BFGS アルゴリズムで使用して、parametersUpdated,solverStateUpdated] = lbfgsupdate(parameters,lossFcn,solverState)parameters に含まれる学習可能なパラメーターを更新します。関数として定義されたネットワークの学習可能なパラメーターを反復的に更新するには、学習ループでこの構文を使用します。
___ = lbfgsupdate(___, は、1 つ以上の名前と値の引数を使用して追加オプションを指定します。Name=Value)
例
CSV ファイル "transmissionCasingData.csv" からトランスミッション ケーシング データを読み取ります。
filename = "transmissionCasingData.csv"; tbl = readtable(filename,TextType="String");
関数 convertvars を使用して、予測のラベルを categorical に変換します。
labelName = "GearToothCondition"; tbl = convertvars(tbl,labelName,"categorical");
カテゴリカル特徴量を使用してネットワークに学習させるには、convertvars 関数を使用して、すべてのカテゴリカル入力変数の名前を格納した string 配列を指定することにより、カテゴリカル予測子を categorical に変換します。
categoricalPredictorNames = ["SensorCondition" "ShaftCondition"]; tbl = convertvars(tbl,categoricalPredictorNames,"categorical");
カテゴリカル入力変数をループ処理します。各変数について、関数 onehotencode を使用して categorical 値を one-hot 符号化ベクトルに変換します。
for i = 1:numel(categoricalPredictorNames) name = categoricalPredictorNames(i); tbl.(name) = onehotencode(tbl.(name),2); end
table の最初の数行を表示します。
head(tbl)
SigMean SigMedian SigRMS SigVar SigPeak SigPeak2Peak SigSkewness SigKurtosis SigCrestFactor SigMAD SigRangeCumSum SigCorrDimension SigApproxEntropy SigLyapExponent PeakFreq HighFreqPower EnvPower PeakSpecKurtosis SensorCondition ShaftCondition GearToothCondition
________ _________ ______ _______ _______ ____________ ___________ ___________ ______________ _______ ______________ ________________ ________________ _______________ ________ _____________ ________ ________________ _______________ ______________ __________________
-0.94876 -0.9722 1.3726 0.98387 0.81571 3.6314 -0.041525 2.2666 2.0514 0.8081 28562 1.1429 0.031581 79.931 0 6.75e-06 3.23e-07 162.13 0 1 1 0 No Tooth Fault
-0.97537 -0.98958 1.3937 0.99105 0.81571 3.6314 -0.023777 2.2598 2.0203 0.81017 29418 1.1362 0.037835 70.325 0 5.08e-08 9.16e-08 226.12 0 1 1 0 No Tooth Fault
1.0502 1.0267 1.4449 0.98491 2.8157 3.6314 -0.04162 2.2658 1.9487 0.80853 31710 1.1479 0.031565 125.19 0 6.74e-06 2.85e-07 162.13 0 1 0 1 No Tooth Fault
1.0227 1.0045 1.4288 0.99553 2.8157 3.6314 -0.016356 2.2483 1.9707 0.81324 30984 1.1472 0.032088 112.5 0 4.99e-06 2.4e-07 162.13 0 1 0 1 No Tooth Fault
1.0123 1.0024 1.4202 0.99233 2.8157 3.6314 -0.014701 2.2542 1.9826 0.81156 30661 1.1469 0.03287 108.86 0 3.62e-06 2.28e-07 230.39 0 1 0 1 No Tooth Fault
1.0275 1.0102 1.4338 1.0001 2.8157 3.6314 -0.02659 2.2439 1.9638 0.81589 31102 1.0985 0.033427 64.576 0 2.55e-06 1.65e-07 230.39 0 1 0 1 No Tooth Fault
1.0464 1.0275 1.4477 1.0011 2.8157 3.6314 -0.042849 2.2455 1.9449 0.81595 31665 1.1417 0.034159 98.838 0 1.73e-06 1.55e-07 230.39 0 1 0 1 No Tooth Fault
1.0459 1.0257 1.4402 0.98047 2.8157 3.6314 -0.035405 2.2757 1.955 0.80583 31554 1.1345 0.0353 44.223 0 1.11e-06 1.39e-07 230.39 0 1 0 1 No Tooth Fault
学習データを抽出します。
predictorNames = ["SigMean" "SigMedian" "SigRMS" "SigVar" "SigPeak" "SigPeak2Peak" ... "SigSkewness" "SigKurtosis" "SigCrestFactor" "SigMAD" "SigRangeCumSum" ... "SigCorrDimension" "SigApproxEntropy" "SigLyapExponent" "PeakFreq" ... "HighFreqPower" "EnvPower" "PeakSpecKurtosis" "SensorCondition" "ShaftCondition"]; XTrain = table2array(tbl(:,predictorNames)); numInputFeatures = size(XTrain,2);
ターゲットを抽出し、one-hot 符号化されたベクトルに変換します。
TTrain = tbl.(labelName); TTrain = onehotencode(TTrain,2); numClasses = size(TTrain,2);
予測子とターゲットを、形式 "BC" (バッチ、チャネル) の dlarray オブジェクトに変換します。
XTrain = dlarray(XTrain,"BC"); TTrain = dlarray(TTrain,"BC");
ネットワーク アーキテクチャを定義します。
numHiddenUnits = 32;
layers = [
featureInputLayer(numInputFeatures)
fullyConnectedLayer(16)
layerNormalizationLayer
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer];
net = dlnetwork(layers);この例のモデル損失関数のセクションにリストされている modelLoss 関数を定義します。この関数は、ニューラル ネットワーク、入力データ、およびターゲットを入力として受け取ります。この関数は、ネットワークの学習可能なパラメーターに関する損失および損失の勾配を返します。
lbfgsupdate 関数には、構文 [loss,gradients] = f(net) をもつ損失関数が必要です。評価された modelLoss 関数をパラメーター化して、単一の入力引数を受け取る変数を作成します。
lossFcn = @(net) dlfeval(@modelLoss,net,XTrain,TTrain);
最大履歴サイズを 3、初期逆ヘッセ近似係数を 1.1 として、L-BFGS のソルバー状態オブジェクトを初期化します。
solverState = lbfgsState( ... HistorySize=3, ... InitialInverseHessianFactor=1.1);
最大 200 回反復してネットワークに学習させます。勾配のノルムまたはステップのノルムが 0.00001 より小さくなった時点で、学習を早期に停止します。10 回の反復ごとに学習損失を出力します。
maxIterations = 200; gradientTolerance = 1e-5; stepTolerance = 1e-5; iteration = 0; while iteration < maxIterations iteration = iteration + 1; [net, solverState] = lbfgsupdate(net,lossFcn,solverState); if iteration==1 || mod(iteration,10)==0 fprintf("Iteration %d: Loss: %d\n",iteration,solverState.Loss); end if solverState.GradientsNorm < gradientTolerance || ... solverState.StepNorm < stepTolerance || ... solverState.LineSearchStatus == "failed" break end end
Iteration 1: Loss: 9.343236e-01 Iteration 10: Loss: 4.721475e-01 Iteration 20: Loss: 4.678575e-01 Iteration 30: Loss: 4.666964e-01 Iteration 40: Loss: 4.665921e-01 Iteration 50: Loss: 4.663871e-01 Iteration 60: Loss: 4.662519e-01 Iteration 70: Loss: 4.660451e-01 Iteration 80: Loss: 4.645303e-01 Iteration 90: Loss: 4.591753e-01 Iteration 100: Loss: 4.562556e-01 Iteration 110: Loss: 4.531167e-01 Iteration 120: Loss: 4.489444e-01 Iteration 130: Loss: 4.392228e-01 Iteration 140: Loss: 4.347853e-01 Iteration 150: Loss: 4.341757e-01 Iteration 160: Loss: 4.325102e-01 Iteration 170: Loss: 4.321948e-01 Iteration 180: Loss: 4.318990e-01 Iteration 190: Loss: 4.313784e-01 Iteration 200: Loss: 4.311314e-01
モデル損失関数
modelLoss 関数は、ニューラル ネットワーク net、入力データ X、およびターゲット T を入力として受け取ります。この関数は、ネットワークの学習可能なパラメーターに関する損失および損失の勾配を返します。
function [loss, gradients] = modelLoss(net, X, T) Y = forward(net,X); loss = crossentropy(Y,T); gradients = dlgradient(loss,net.Learnables); end
入力引数
学習可能なパラメーター。dlarray オブジェクト、数値配列、cell 配列、構造体、または table として指定します。
parameters を table として指定する場合、次の変数を table に含めなければなりません。
Layer— 層の名前。string スカラーとして指定します。Parameter— パラメーター名。string スカラーとして指定します。Value— パラメーターの値。dlarrayオブジェクトを含む cell 配列として指定します。
cell 配列、構造体、table、入れ子になった cell 配列、または入れ子になった構造体を使用し、ネットワークの学習可能なパラメーターのコンテナーとして parameters を指定できます。cell 配列、構造体、または table に含まれる学習可能なパラメーターは、データ型が double または single である dlarray オブジェクトまたは数値でなければなりません。
parameters が数値配列の場合、lossFcn は dlgradient 関数を使用してはなりません。
損失関数。関数ハンドルまたは構文 [loss,gradients] = f(net) をもつ AcceleratedFunction オブジェクトとして指定します。ここで、loss と gradients は、それぞれ学習可能なパラメーターに対する損失と損失の勾配に対応します。
dlgradient 関数を呼び出すモデル損失関数をパラメーター化するには、損失関数を @(net) dlfeval(@modelLoss,net,arg1,...,argN) として指定します。ここで、modelLoss は、[loss,gradients] = modelLoss(net,arg1,...,argN) という構文をもつ関数で、引数 arg1,...,argN が与えられた net 内の学習可能なパラメーターに関する損失と損失の勾配を返します。
parameters が数値配列の場合、損失関数は dlgradient 関数または dlfeval 関数を使用してはなりません。
損失関数の出力が 2 つより多い場合は、NumLossFunctionOutputs 引数も指定します。
データ型: function_handle
ソルバーの状態。lbfgsState オブジェクトまたは [] として指定します。
名前と値の引数
オプションの引数のペアを Name1=Value1,...,NameN=ValueN として指定します。ここで、Name は引数名で、Value は対応する値です。名前と値の引数は他の引数の後に指定しなければなりませんが、ペアの順序は重要ではありません。
例: lbfgsupdate(net,lossFcn,solverState,LineSearchMethod="strong-wolfe") は、net 内の学習可能なパラメーターを更新し、強 Wolfe 条件を満たす学習率を検索します。
適切な学習率を検出する方法。次の値のいずれかとして指定します。
"weak-wolfe"— 弱 Wolfe 条件を満たす学習率を検索します。この方法は、逆ヘッセ行列の正定値近似を維持します。"strong-wolfe"— 強 Wolfe 条件を満たす学習率を検索します。この方法は、逆ヘッセ行列の正定値近似を維持します。"backtracking"— 十分な減少条件を満たす学習率を検索します。この方法は、逆ヘッセ行列の正定値近似を維持しません。
学習率を決定するための直線探索の反復の最大数。正の整数として指定します。
データ型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
損失関数の出力の数。2 以上の整数として指定します。lossFcn の出力引数が 2 つより多い場合は、このオプションを設定します。
データ型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
出力引数
更新されたネットワーク。dlnetwork オブジェクトとして返されます。
この関数は、dlnetwork オブジェクトの Learnables プロパティを更新します。
更新された学習可能なパラメーター。parameters と同じタイプのオブジェクトとして返されます。
更新されたソルバーの状態。lbfgsState 状態オブジェクトとして返されます。
アルゴリズム
L-BFGS アルゴリズム[1]は、Broyden-Fletcher-Goldfarb-Shanno (BFGS) アルゴリズムを近似する準ニュートン法です。L-BFGS アルゴリズムは、単一のバッチで処理できる小規模なネットワークやデータ セットに使用します。
このアルゴリズムは、次で指定される更新ステップを使用し、反復 k+1 で学習可能なパラメーター W を更新します。
ここで、Wk は反復 k における重みを表し、 は反復 k における学習率です。Bk は反復 k におけるヘッセ行列の近似であり、 は反復 k における学習可能なパラメーターに関する損失の勾配を表します。
L-BFGS アルゴリズムは、行列とベクトルの積 を直接計算します。このアルゴリズムでは、Bk の逆行列を計算する必要がありません。
メモリを節約するため、L-BFGS アルゴリズムでは密なヘッセ行列 B の保存や反転は行われません。代わりに、アルゴリズムは近似 を使用します。ここで、m は履歴サイズであり、逆ヘッセ因子 はスカラーです。また、I は単位行列です。このアルゴリズムは、スカラーの逆ヘッセ因子のみを格納します。アルゴリズムは各ステップで逆ヘッセ因子を更新します。
行列とベクトルの積 を直接計算するために、L-BFGS アルゴリズムは次の再帰的アルゴリズムを使用します。
を設定します。ここで、m は履歴サイズです。
について、以下のようにします。
とします。ここで、 および は、それぞれ反復 に対するステップおよび勾配の差分です。
を設定します。ここで、 は、、、および損失関数に対する損失の勾配から導出されます。詳細については、[1]を参照してください。
を返します。
参照
[1] Liu, Dong C., and Jorge Nocedal. "On the limited memory BFGS method for large scale optimization." Mathematical programming 45, no. 1 (August 1989): 503-528. https://doi.org/10.1007/BF01589116.
拡張機能
lbfgsupdate 関数は GPU 配列入力をサポートしますが、次の使用上の注意および制限があります。
lossFcnが、gpuArray型または基となるデータがgpuArray型であるdlarray型のデータを出力する場合、この関数は GPU で実行されます。
詳細については、GPU での MATLAB 関数の実行 (Parallel Computing Toolbox)を参照してください。
バージョン履歴
R2023a で導入
MATLAB Command
You clicked a link that corresponds to this MATLAB command:
Run the command by entering it in the MATLAB Command Window. Web browsers do not support MATLAB commands.
Web サイトの選択
Web サイトを選択すると、翻訳されたコンテンツにアクセスし、地域のイベントやサービスを確認できます。現在の位置情報に基づき、次のサイトの選択を推奨します:
また、以下のリストから Web サイトを選択することもできます。
最適なサイトパフォーマンスの取得方法
中国のサイト (中国語または英語) を選択することで、最適なサイトパフォーマンスが得られます。その他の国の MathWorks のサイトは、お客様の地域からのアクセスが最適化されていません。
南北アメリカ
- América Latina (Español)
- Canada (English)
- United States (English)
ヨーロッパ
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)