層の重み初期化子の比較
この例では、さまざまな重み初期化子を使用して深層学習ネットワークに学習させる方法を示します。
深層学習ネットワークに学習させる際、層の重みとバイアスの初期化は、ネットワークの学習成果に大きな影響を与える可能性があります。バッチ正規化層をもたないネットワークの場合、初期化子の選択はさらに大きな影響を与えます。
層の種類に応じて、WeightsInitializer
、InputWeightsInitializer
、RecurrentWeightsInitializer
、および BiasInitializer
のオプションを使用して、重みとバイアスの初期化を変更できます。
この例では、LSTM ネットワークの学習に次の 3 つの異なる重み初期化子を使用した場合の効果を示します。
Glorot 初期化子 – Glorot 初期化子を使用して入力の重みを初期化します。[1]
He 初期化子 – He 初期化子を使用して入力の重みを初期化します。[2]
狭い正規初期化子 – ゼロ平均、標準偏差 0.01 の正規分布から個別にサンプリングを行い、入力の重みを初期化します。
データの読み込み
Japanese Vowels データ セットを読み込みます。これには、特徴次元が 12 で、ラベル 1、2、...、9 の categorical ベクトルをもつ可変長のシーケンスが含まれています。シーケンスは行列で、行数が 12 (特徴ごとに 1 行) で、列数が可変 (タイム ステップごとに 1 列) です。
load JapaneseVowelsTrainData load JapaneseVowelsTestData
ネットワーク アーキテクチャの指定
ネットワーク アーキテクチャを指定します。各初期化子について、同じネットワーク アーキテクチャを使用します。
入力サイズを 12 (入力データの特徴の数) に指定します。100 個の隠れユニットを持つ LSTM 層を指定して、シーケンスの最後の要素を出力します。最後に、サイズが 9 の全結合層を含めることによって 9 個のクラスを指定し、その後にソフトマックス層を配置します。
numFeatures = 12; numHiddenUnits = 100; numClasses = 9; layers = [ ... sequenceInputLayer(numFeatures) lstmLayer(numHiddenUnits,OutputMode="last") fullyConnectedLayer(numClasses) softmaxLayer]
layers = 4x1 Layer array with layers: 1 '' Sequence Input Sequence input with 12 dimensions 2 '' LSTM LSTM with 100 hidden units 3 '' Fully Connected 9 fully connected layer 4 '' Softmax softmax
学習オプション
学習オプションを指定します。各初期化子について、同じ学習オプションを使用してネットワークに学習させます。
maxEpochs = 30; miniBatchSize = 27; numObservations = numel(XTrain); numIterationsPerEpoch = floor(numObservations / miniBatchSize); options = trainingOptions("adam", ... ExecutionEnvironment="cpu", ... MaxEpochs=maxEpochs, ... InputDataFormats="CTB", ... Metrics="accuracy", ... MiniBatchSize=miniBatchSize, ... GradientThreshold=2, ... ValidationData={XTest,TTest}, ... ValidationFrequency=numIterationsPerEpoch, ... Verbose=false, ... Plots="training-progress");
Glorot 初期化子
この例で前述したネットワーク アーキテクチャを指定し、LSTM 層の入力重み初期化子と全結合層の重み初期化子を "glorot"
に設定します。
layers = [ ... sequenceInputLayer(numFeatures) lstmLayer(numHiddenUnits,OutputMode="last",InputWeightsInitializer="glorot") fullyConnectedLayer(numClasses,WeightsInitializer="glorot") softmaxLayer];
Glorot 重み初期化子をもつ関数trainnet
を使用して、ネットワークに学習させます。
[netGlorot,infoGlorot] = trainnet(XTrain,TTrain,layers,"crossentropy",options);
He 初期化子
この例で前述したネットワーク アーキテクチャを指定し、LSTM 層の入力重み初期化子と全結合層の重み初期化子を "he"
に設定します。
layers = [ ... sequenceInputLayer(numFeatures) lstmLayer(numHiddenUnits,OutputMode="last",InputWeightsInitializer="he") fullyConnectedLayer(numClasses,WeightsInitializer="he") softmaxLayer];
He 重み初期化子をもつ層を使用して、ネットワークに学習させます。
[netHe,infoHe] = trainnet(XTrain,TTrain,layers,"crossentropy",options);
narrow-normal 初期化子
この例で前述したネットワーク アーキテクチャを指定し、LSTM 層の入力重み初期化子と全結合層の重み初期化子を "narrow-normal"
に設定します。
layers = [ ... sequenceInputLayer(numFeatures) lstmLayer(numHiddenUnits,OutputMode="last",InputWeightsInitializer="narrow-normal") fullyConnectedLayer(numClasses,WeightsInitializer="narrow-normal") softmaxLayer];
狭い正規重み初期化子をもつ層を使用して、ネットワークに学習させます。
[netNarrowNormal,infoNarrowNormal] = trainnet(XTrain,TTrain,layers,"crossentropy",options);
結果のプロット
関数 trainNetwork
で返される情報構造体出力から検証精度を抽出します。
validationAccuracy = [ infoGlorot.ValidationHistory.Accuracy,... infoHe.ValidationHistory.Accuracy,... infoNarrowNormal.ValidationHistory.Accuracy];
検証精度のベクトルには、その検証精度が計算されなかった反復については NaN
が含まれています。NaN
値を削除します。
idx = all(isnan(validationAccuracy)); validationAccuracy(:,idx) = [];
各初期化子について、検証精度に対するエポック数をプロットします。
figure epochs = 0:maxEpochs; plot(epochs,validationAccuracy) ylim([0 100]) title("Validation Accuracy") xlabel("Epoch") ylabel("Validation Accuracy") legend(["Glorot" "He" "Narrow-Normal"],Location="southeast")
このプロットでは、さまざまな初期化子の全体的な効果、および各初期化子における学習の収束速度を示しています。
参考文献
Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." In Proceedings of the thirteenth international conference on artificial intelligence and statistics, pp. 249-256. 2010.
He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." In Proceedings of the IEEE international conference on computer vision, pp. 1026-1034. 2015.
参考
trainnet
| trainingOptions
| dlnetwork