Main Content

層の重み初期化子の比較

この例では、さまざまな重み初期化子を使用して深層学習ネットワークに学習させる方法を示します。

深層学習ネットワークに学習させる際、層の重みとバイアスの初期化は、ネットワークの学習成果に大きな影響を与える可能性があります。バッチ正規化層をもたないネットワークの場合、初期化子の選択はさらに大きな影響を与えます。

層の種類に応じて、WeightsInitializerInputWeightsInitializerRecurrentWeightsInitializer、および BiasInitializer のオプションを使用して、重みとバイアスの初期化を変更できます。

この例では、LSTM ネットワークの学習に次の 3 つの異なる重み初期化子を使用した場合の効果を示します。

  1. Glorot 初期化子 – Glorot 初期化子を使用して入力の重みを初期化します。[1]

  2. He 初期化子 – He 初期化子を使用して入力の重みを初期化します。[2]

  3. 狭い正規初期化子 – ゼロ平均、標準偏差 0.01 の正規分布から個別にサンプリングを行い、入力の重みを初期化します。

データの読み込み

Japanese Vowels データ セットを読み込みます。これには、特徴次元が 12 で、ラベル 1、2、...、9 の categorical ベクトルをもつ可変長のシーケンスが含まれています。シーケンスは行列で、行数が 12 (特徴ごとに 1 行) で、列数が可変 (タイム ステップごとに 1 列) です。

load JapaneseVowelsTrainData
load JapaneseVowelsTestData

ネットワーク アーキテクチャの指定

ネットワーク アーキテクチャを指定します。各初期化子について、同じネットワーク アーキテクチャを使用します。

入力サイズを 12 (入力データの特徴の数) に指定します。100 個の隠れユニットを持つ LSTM 層を指定して、シーケンスの最後の要素を出力します。最後に、サイズが 9 の全結合層を含めることによって 9 個のクラスを指定し、その後にソフトマックス層を配置します。

numFeatures = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer]
layers = 
  4x1 Layer array with layers:

     1   ''   Sequence Input    Sequence input with 12 dimensions
     2   ''   LSTM              LSTM with 100 hidden units
     3   ''   Fully Connected   9 fully connected layer
     4   ''   Softmax           softmax

学習オプション

学習オプションを指定します。各初期化子について、同じ学習オプションを使用してネットワークに学習させます。

maxEpochs = 30;
miniBatchSize = 27;
numObservations = numel(XTrain);
numIterationsPerEpoch = floor(numObservations / miniBatchSize);

options = trainingOptions("adam", ...
    ExecutionEnvironment="cpu", ...
    MaxEpochs=maxEpochs, ...
    InputDataFormats="CTB", ...
    Metrics="accuracy", ...
    MiniBatchSize=miniBatchSize, ...
    GradientThreshold=2, ...
    ValidationData={XTest,TTest}, ...
    ValidationFrequency=numIterationsPerEpoch, ...
    Verbose=false, ...
    Plots="training-progress");

Glorot 初期化子

この例で前述したネットワーク アーキテクチャを指定し、LSTM 層の入力重み初期化子と全結合層の重み初期化子を "glorot" に設定します。

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,OutputMode="last",InputWeightsInitializer="glorot")
    fullyConnectedLayer(numClasses,WeightsInitializer="glorot")
    softmaxLayer];

Glorot 重み初期化子をもつ関数trainnetを使用して、ネットワークに学習させます。

[netGlorot,infoGlorot] = trainnet(XTrain,TTrain,layers,"crossentropy",options);

He 初期化子

この例で前述したネットワーク アーキテクチャを指定し、LSTM 層の入力重み初期化子と全結合層の重み初期化子を "he" に設定します。

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,OutputMode="last",InputWeightsInitializer="he")
    fullyConnectedLayer(numClasses,WeightsInitializer="he")
    softmaxLayer];

He 重み初期化子をもつ層を使用して、ネットワークに学習させます。

[netHe,infoHe] = trainnet(XTrain,TTrain,layers,"crossentropy",options);

narrow-normal 初期化子

この例で前述したネットワーク アーキテクチャを指定し、LSTM 層の入力重み初期化子と全結合層の重み初期化子を "narrow-normal" に設定します。

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,OutputMode="last",InputWeightsInitializer="narrow-normal")
    fullyConnectedLayer(numClasses,WeightsInitializer="narrow-normal")
    softmaxLayer];

狭い正規重み初期化子をもつ層を使用して、ネットワークに学習させます。

[netNarrowNormal,infoNarrowNormal] = trainnet(XTrain,TTrain,layers,"crossentropy",options);

結果のプロット

関数 trainNetwork で返される情報構造体出力から検証精度を抽出します。

validationAccuracy = [
    infoGlorot.ValidationHistory.Accuracy,...
    infoHe.ValidationHistory.Accuracy,...
    infoNarrowNormal.ValidationHistory.Accuracy];

検証精度のベクトルには、その検証精度が計算されなかった反復については NaN が含まれています。NaN 値を削除します。

idx = all(isnan(validationAccuracy));
validationAccuracy(:,idx) = [];

各初期化子について、検証精度に対するエポック数をプロットします。

figure
epochs = 0:maxEpochs;
plot(epochs,validationAccuracy)
ylim([0 100])
title("Validation Accuracy")
xlabel("Epoch")
ylabel("Validation Accuracy")
legend(["Glorot" "He" "Narrow-Normal"],Location="southeast")

Figure contains an axes object. The axes object with title Validation Accuracy, xlabel Epoch, ylabel Validation Accuracy contains 3 objects of type line. These objects represent Glorot, He, Narrow-Normal.

このプロットでは、さまざまな初期化子の全体的な効果、および各初期化子における学習の収束速度を示しています。

参考文献

  1. Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." In Proceedings of the thirteenth international conference on artificial intelligence and statistics, pp. 249-256. 2010.

  2. He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." In Proceedings of the IEEE international conference on computer vision, pp. 1026-1034. 2015.

参考

| |

関連するトピック