Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

層の重み初期化子の比較

この例では、さまざまな重み初期化子を使用して深層学習ネットワークに学習させる方法を示します。

深層学習ネットワークに学習させる際、層の重みとバイアスの初期化は、ネットワークの学習成果に大きな影響を与える可能性があります。バッチ正規化層をもたないネットワークの場合、初期化子の選択はさらに大きな影響を与えます。

層の種類に応じて、WeightsInitializerInputWeightsInitializerRecurrentWeightsInitializer、および BiasInitializer のオプションを使用して、重みとバイアスの初期化を変更できます。

この例では、LSTM ネットワークの学習に次の 3 つの異なる重み初期化子を使用した場合の効果を示します。

  1. Glorot 初期化子 – Glorot 初期化子を使用して入力の重みを初期化します。[1]

  2. He 初期化子 – He 初期化子を使用して入力の重みを初期化します。[2]

  3. 狭い正規初期化子 – ゼロ平均、標準偏差 0.01 の正規分布から個別にサンプリングを行い、入力の重みを初期化します。

データの読み込み

Japanese Vowels データ セットを読み込みます。これには、特徴次元が 12 で、ラベル 1、2、...、9 の categorical ベクトルをもつ可変長のシーケンスが含まれています。シーケンスは行列で、行数が 12 (特徴ごとに 1 行) で、列数が可変 (タイム ステップごとに 1 列) です。

load JapaneseVowelsTrainData
load JapaneseVowelsTestData

ネットワーク アーキテクチャの指定

ネットワーク アーキテクチャを指定します。各初期化子について、同じネットワーク アーキテクチャを使用します。

入力サイズを 12 (入力データの特徴の数) に指定します。100 個の隠れユニットを持つ LSTM 層を指定して、シーケンスの最後の要素を出力します。最後に、サイズが 9 の全結合層を含めることによって 9 個のクラスを指定し、その後にソフトマックス層と分類層を配置します。

numFeatures = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  5x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 12 dimensions
     2   ''   LSTM                    LSTM with 100 hidden units
     3   ''   Fully Connected         9 fully connected layer
     4   ''   Softmax                 softmax
     5   ''   Classification Output   crossentropyex

学習オプション

学習オプションを指定します。各初期化子について、同じ学習オプションを使用してネットワークに学習させます。

maxEpochs = 30;
miniBatchSize = 27;
numObservations = numel(XTrain);
numIterationsPerEpoch = floor(numObservations / miniBatchSize);

options = trainingOptions("adam", ...
    ExecutionEnvironment="cpu", ...
    MaxEpochs=maxEpochs, ...
    MiniBatchSize=miniBatchSize, ...
    GradientThreshold=2, ...
    ValidationData={XTest,TTest}, ...
    ValidationFrequency=numIterationsPerEpoch, ...
    Verbose=false, ...
    Plots="training-progress");

Glorot 初期化子

この例で前述したネットワーク アーキテクチャを指定し、LSTM 層の入力重み初期化子と全結合層の重み初期化子を "glorot" に設定します。

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,OutputMode="last",InputWeightsInitializer="glorot")
    fullyConnectedLayer(numClasses,WeightsInitializer="glorot")
    softmaxLayer
    classificationLayer];

Glorot 重み初期化子をもつ層を使用して、ネットワークに学習させます。

[netGlorot,infoGlorot] = trainNetwork(XTrain,TTrain,layers,options);

Figure Training Progress (29-Aug-2023 21:20:25) contains 2 axes objects and another object of type uigridlayout. Axes object 1 with xlabel Iteration, ylabel Loss contains 10 objects of type patch, text, line. Axes object 2 with xlabel Iteration, ylabel Accuracy (%) contains 10 objects of type patch, text, line.

He 初期化子

この例で前述したネットワーク アーキテクチャを指定し、LSTM 層の入力重み初期化子と全結合層の重み初期化子を "he" に設定します。

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,OutputMode="last",InputWeightsInitializer="he")
    fullyConnectedLayer(numClasses,WeightsInitializer="he")
    softmaxLayer
    classificationLayer];

He 重み初期化子をもつ層を使用して、ネットワークに学習させます。

[netHe,infoHe] = trainNetwork(XTrain,TTrain,layers,options);

Figure Training Progress (29-Aug-2023 21:21:11) contains 2 axes objects and another object of type uigridlayout. Axes object 1 with xlabel Iteration, ylabel Loss contains 10 objects of type patch, text, line. Axes object 2 with xlabel Iteration, ylabel Accuracy (%) contains 10 objects of type patch, text, line.

narrow-normal 初期化子

この例で前述したネットワーク アーキテクチャを指定し、LSTM 層の入力重み初期化子と全結合層の重み初期化子を "narrow-normal" に設定します。

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,OutputMode="last",InputWeightsInitializer="narrow-normal")
    fullyConnectedLayer(numClasses,WeightsInitializer="narrow-normal")
    softmaxLayer
    classificationLayer];

狭い正規重み初期化子をもつ層を使用して、ネットワークに学習させます。

[netNarrowNormal,infoNarrowNormal] = trainNetwork(XTrain,TTrain,layers,options);

Figure Training Progress (29-Aug-2023 21:21:43) contains 2 axes objects and another object of type uigridlayout. Axes object 1 with xlabel Iteration, ylabel Loss contains 10 objects of type patch, text, line. Axes object 2 with xlabel Iteration, ylabel Accuracy (%) contains 10 objects of type patch, text, line.

結果のプロット

関数 trainNetwork で返される情報構造体出力から検証精度を抽出します。

validationAccuracy = [
    infoGlorot.ValidationAccuracy;
    infoHe.ValidationAccuracy;
    infoNarrowNormal.ValidationAccuracy];

検証精度のベクトルには、その検証精度が計算されなかった反復については NaN が含まれています。NaN 値を削除します。

idx = all(isnan(validationAccuracy));
validationAccuracy(:,idx) = [];

各初期化子について、検証精度に対するエポック数をプロットします。

figure
epochs = 0:maxEpochs;
plot(epochs,validationAccuracy)
ylim([0 100])
title("Validation Accuracy")
xlabel("Epoch")
ylabel("Validation Accuracy")
legend(["Glorot" "He" "Narrow-Normal"],Location="southeast")

Figure contains an axes object. The axes object with title Validation Accuracy, xlabel Epoch, ylabel Validation Accuracy contains 3 objects of type line. These objects represent Glorot, He, Narrow-Normal.

このプロットでは、さまざまな初期化子の全体的な効果、および各初期化子における学習の収束速度を示しています。

参考文献

  1. Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." In Proceedings of the thirteenth international conference on artificial intelligence and statistics, pp. 249-256. 2010.

  2. He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." In Proceedings of the IEEE international conference on computer vision, pp. 1026-1034. 2015.

参考

|

関連するトピック