Main Content

モデル関数の学習可能パラメーターの初期化

層、層グラフ、または dlnetwork オブジェクトを使用してネットワークに学習させる場合、ソフトウェアは、層の初期化プロパティに従って、学習可能パラメーターを自動的に初期化します。関数として深層学習モデルを定義する場合、学習可能パラメーターを手動で初期化しなければなりません。

学習可能パラメーター (重みやバイアスなど) の初期化方法は、深層学習モデルの収束速度に大きな影響を与える可能性があります。

ヒント

このトピックでは、カスタム学習ループにおいて、関数を定義した深層学習モデルの学習可能パラメーターを初期化する方法を説明します。深層学習層の学習可能パラメーターの初期化を指定する方法を学ぶために、対応する層のプロパティを使用します。たとえば、convolution2dLayer オブジェクトの重み初期化子を設定するには、WeightsInitializer プロパティを使用します。

層の既定の初期化

この表は、各層の学習可能パラメーターに関する既定の初期化を示しています。また、同じ初期化を使用して、モデル関数の学習可能パラメーターを初期化する方法を示すリンクが記載されています。

学習可能なパラメーター既定の初期化
convolution2dLayer重みGlorot の初期化
Biasゼロでの初期化
convolution3dLayer重みGlorot の初期化
Biasゼロでの初期化
groupedConvolution2dLayer重みGlorot の初期化
Biasゼロでの初期化
transposedConv2dLayer重みGlorot の初期化
Biasゼロでの初期化
transposedConv3dLayer重みGlorot の初期化
Biasゼロでの初期化
fullyConnectedLayer重みGlorot の初期化
Biasゼロでの初期化
batchNormalizationLayerオフセットゼロでの初期化
スケール1 での初期化
lstmLayer入力重みGlorot の初期化
再帰重み直交初期化
Biasユニット忘却ゲートによる初期化
gruLayer入力重みGlorot の初期化
再帰重み直交初期化
Biasゼロでの初期化
wordEmbeddingLayer重みガウスによる初期化、平均 0、標準偏差 0.01 とする

学習可能なパラメーターのサイズ

モデル関数の学習可能なパラメーターを初期化する場合、正しいサイズのパラメーターを指定しなければなりません。学習可能なパラメーターのサイズは、深層学習演算のタイプによって異なります。

操作学習可能なパラメーターサイズ
batchnormオフセット

[numChannels 1]numChannels は入力チャネル数

スケール

[numChannels 1]numChannels は入力チャネル数

dlconv重み

[filterSize numChannels numFilters]filterSize はフィルター サイズを指定する 1 行 K 列のベクトル、numChannels は入力チャネル数、numFilters はフィルター数、K は空間次元数

Bias

次のいずれかを選択。

  • [numFilters 1]numFilters はフィルター数

  • [1 1]

dlconv (グループ化)重み

[filterSize numChannelsPerGroup numFiltersPerGroup numGroups]filterSize はフィルター サイズを指定する 1 行 K 列のベクトル、numChannelsPerGroup は各グループの入力チャネル数、numFiltersPerGroup は各グループのフィルター数、numGroups はグループ数、K は空間次元数

Bias

次のいずれかを選択。

  • [numFiltersPerGroup 1]numFiltersPerGroup は各グループのフィルター数。

  • [1 1]

dltranspconv重み

[filterSize numFilters numChannels]filterSize はフィルター サイズを指定する 1 行 K 列のベクトル、numChannels は入力チャネル数、numFilters はフィルター数、K は空間次元数

Bias

次のいずれかを選択。

  • [numFilters 1]numFilters は各グループのフィルター数。

  • [1 1]

dltranspconv (グループ化)重み

[filterSize numFiltersPerGroup numChannelsPerGroup numGroups]filterSize はフィルター サイズを指定する 1 行 K 列のベクトル、numChannelsPerGroup は各グループの入力チャネル数、numFiltersPerGroup は各グループのフィルター数、numGroups はグループ数、K は空間次元数

Bias

次のいずれかを選択。

  • [numFiltersPerGroup 1]numFiltersPerGroup は各グループのフィルター数。

  • [1 1]

fullyconnect重み

[outputSize inputSize]outputSizeinputSize はそれぞれ出力チャネル数と入力チャネル数

Bias

[outputSize 1]outputSize は出力チャネル数

gru入力重み

[3*numHiddenUnits inputSize]numHiddenUnits は演算の隠れユニット数、inputSize は入力チャネル数

再帰重み

[3*numHiddenUnits numHiddenUnits]numHiddenUnits は演算の隠れユニット数

Bias

[3*numHiddenUnits 1]numHiddenUnits は演算の隠れユニット数

lstm入力重み

[4*numHiddenUnits inputSize]numHiddenUnits は演算の隠れユニット数、inputSize は入力チャネル数

再帰重み

[4*numHiddenUnits numHiddenUnits]numHiddenUnits は演算の隠れユニット数

Bias

[4*numHiddenUnits 1]numHiddenUnits は演算の隠れユニット数

Glorot の初期化

Glorot (Xavier とも呼ばれる) 初期化子[1]は、範囲指定付きの一様分布 [6No+Ni,6No+Ni] から重みをサンプリングします。ここで、No と Ni の値は、深層学習演算のタイプによって異なります。

操作学習可能なパラメーターNoNi
dlconv重み

prod(filterSize)*numFiltersfilterSize はフィルター サイズを含む 1 行 K 列のベクトル、numFilters はフィルター数、K は空間次元数

prod(filterSize)*numChannelsfilterSize はフィルター サイズを含む 1 行 K 列のベクトル、numChannels は入力チャネル数、K は空間次元数

dlconv (グループ化)重み

prod(filterSize)*numFiltersPerGroupfilterSize はフィルター サイズを含む 1 行 K 列のベクトル、numFiltersPerGroup は各グループのフィルター数、K は空間次元数

prod(filterSize)*numChannelsPerGroupfilterSize はフィルター サイズを含む 1 行 K 列のベクトル、numChannelsPerGroup は各グループの入力チャネル数、K は空間次元数

dltranspconv重み

prod(filterSize)*numFiltersfilterSize はフィルター サイズを含む 1 行 K 列のベクトル、numFilters はフィルター数、K は空間次元数

prod(filterSize)*numChannelsfilterSize はフィルター サイズを含む 1 行 K 列のベクトル、numChannels は入力チャネル数、K は空間次元数

dltranspconv (グループ化)重み

prod(filterSize)*numFiltersPerGroupfilterSize はフィルター サイズを含む 1 行 K 列のベクトル、numFiltersPerGroup は各グループのフィルター数、K は空間次元数

prod(filterSize)*numChannelsPerGroupfilterSize はフィルター サイズを含む 1 行 K 列のベクトル、numChannelsPerGroup は各グループの入力チャネル数、K は空間次元数

fullyconnect重み演算の出力チャネル数演算の入力チャネル数
gru入力重み3*numHiddenUnitsnumHiddenUnits は演算の隠れユニット数演算の入力チャネル数
再帰重み3*numHiddenUnitsnumHiddenUnits は演算の隠れユニット数演算の隠れユニット数
lstm入力重み4*numHiddenUnitsnumHiddenUnits は演算の隠れユニット数演算の入力チャネル数
再帰重み4*numHiddenUnitsnumHiddenUnits は演算の隠れユニット数演算の隠れユニット数

Glorot 初期化子を使用して学習可能なパラメーターを簡単に初期化するため、カスタム関数を定義できます。関数 initializeGlorot は、学習可能パラメーター sz のサイズ、ならびに値 No および Ni (それぞれ numOutnumIn) を入力として受け取り、基となる型が 'single'dlarray オブジェクトとして、サンプリングされた重みを返します。

function weights = initializeGlorot(sz,numOut,numIn)

Z = 2*rand(sz,'single') - 1;
bound = sqrt(6 / (numIn + numOut));

weights = bound * Z;
weights = dlarray(weights);

end

サイズ 5 x 5 のフィルター 128 個、入力チャネル 3 個で、畳み込み演算の重みを初期化します。

filterSize = [5 5];
numChannels = 3;
numFilters = 128;

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

parameters.conv.Weights = initializeGlorot(sz,numOut,numIn);

He の初期化

He 初期化子[2]は、平均 0、分散 2Ni の正規分布から重みをサンプリングします。ここで、値 Ni は、深層学習演算のタイプによって決まります。

操作学習可能なパラメーターNi
dlconv重み

prod(filterSize)*numChannelsPerGroupfilterSize はフィルター サイズを含む 1 行 K 列のベクトル、numChannelsPerGroup は各グループの入力チャネル数、K は空間次元数

dltranspconv重み

prod(filterSize)*numChannelsPerGroupfilterSize はフィルター サイズを含む 1 行 K 列のベクトル、numChannelsPerGroup は各グループの入力チャネル数、K は空間次元数

fullyconnect重み演算の入力チャネル数
gru入力重み演算の入力チャネル数
再帰重み演算の隠れユニット数。
lstm入力重み演算の入力チャネル数
再帰重み演算の隠れユニット数。

He 初期化子を使用して学習可能なパラメーターを簡単に初期化するために、カスタム関数を定義できます。関数 initializeHe は、学習可能パラメーター sz のサイズと値 Ni を入力として受け取り、基となる型が 'single'dlarray オブジェクトとして、サンプリングされた重みを返します。

function weights = initializeHe(sz,numIn)

weights = randn(sz,'single') * sqrt(2/numIn);
weights = dlarray(weights);

end

サイズ 5 x 5 のフィルター 128 個、入力チャネル 3 個で、畳み込み演算の重みを初期化します。

filterSize = [5 5];
numChannels = 3;
numFilters = 128;

sz = [filterSize numChannels numFilters];
numIn = prod(filterSize) * numFilters;

parameters.conv.Weights = initializeHe(sz,numIn);

ガウスによる初期化

ガウス初期化子は、正規分布から重みをサンプリングします。

ガウス初期化子を使用して学習可能なパラメーターを簡単に初期化するために、カスタム関数を定義できます。関数 initializeGaussian は、学習可能パラメーター sz のサイズ、分散平均 mu、および分散標準偏差 sigma を入力として受け取り、基となる型が 'single'dlarray オブジェクトとして、サンプリングされた重みを返します。

function weights = initializeGaussian(sz,mu,sigma)

weights = randn(sz,'single')*sigma + mu;
weights = dlarray(weights);

end

平均 0 および標準偏差 0.01 のガウス初期化子を使用して、次元 300、語彙サイズ 5000 で埋め込み演算の重みを初期化します。

embeddingDimension = 300;
vocabularySize = 5000;
mu = 0;
sigma = 0.01;

sz = [embeddingDimension vocabularySize];

parameters.emb.Weights = initializeGaussian(sz,mu,sigma);

一様分布による初期化

一様分布初期化子は、一様分布から重みをサンプリングします。

一様分布初期化子を使用して学習可能なパラメーターを簡単に初期化するために、カスタム関数を定義できます。関数 initializeUniform は、学習可能パラメーター sz のサイズと分布境界 bound を入力として受け取り、基となる型が 'single'dlarray オブジェクトとして、サンプリングされた重みを返します。

function parameter = initializeUniform(sz,bound)

Z = 2*rand(sz,'single') - 1;
parameter = bound * Z;
parameter = dlarray(parameter);

end

一様分布初期化子を使用して、サイズ 100 x 100、境界 0.1 でアテンション メカニズムの重みを初期化します。

sz = [100 100];
bound = 0.1;

parameters.attentionn.Weights = initializeUniform(sz,bound);

直交初期化

直交初期化子は、Z = QR の QR 分解によって与えられる直交行列 Q を返します。ここで、Z は単位正規分布からサンプリングされ、Z のサイズは学習可能なパラメーターのサイズと一致します。

直交初期化子を使用して学習可能なパラメーターを簡単に初期化するために、カスタム関数を定義できます。関数 initializeOrthogonal は、学習可能パラメーター sz のサイズを入力として受け取り、基となる型が 'single'dlarray オブジェクトとして、直交行列を返します。

function parameter = initializeOrthogonal(sz)

Z = randn(sz,'single');
[Q,R] = qr(Z,0);

D = diag(R);
Q = Q * diag(D ./ abs(D));

parameter = dlarray(Q);

end

直交初期化子を使用して、100 個の隠れユニットで LSTM 演算の反復重みを初期化します。

numHiddenUnits = 100;

sz = [4*numHiddenUnits numHiddenUnits];

parameters.lstm.RecurrentWeights = initializeOrthogonal(sz);

ユニット忘却ゲートによる初期化

ユニット忘却ゲート初期化子は、バイアスの忘却ゲート成分が 1、残りのエントリが 0 になるように、LSTM 演算のバイアスを初期化します。

直交初期化子を使用して学習可能なパラメーターを簡単に初期化するために、カスタム関数を定義できます。関数 initializeUnitForgetGate は、LSTM 演算の隠れユニットの数を入力として受け取り、基となる型が 'single'dlarray オブジェクトとしてバイアスを返します。

function bias = initializeUnitForgetGate(numHiddenUnits)

bias = zeros(4*numHiddenUnits,1,'single');

idx = numHiddenUnits+1:2*numHiddenUnits;
bias(idx) = 1;

bias = dlarray(bias);

end

ユニット忘却ゲート初期化子を使用して、100 個の隠れユニットで LSTM 演算のバイアスを初期化します。

numHiddenUnits = 100;

parameters.lstm.Bias = initializeUnitForgetGate(numHiddenUnits,'single');

1 での初期化

学習可能なパラメーターを 1 で簡単に初期化するために、カスタム関数を定義できます。関数 initializeOnes は、学習可能パラメーター sz のサイズを入力として受け取り、基となる型が 'single'dlarray オブジェクトとして、パラメーターを返します。

function parameter = initializeOnes(sz)

parameter = ones(sz,'single');
parameter = dlarray(parameter);

end

128 個の入力チャネルで 1 を使用してバッチ正規化演算のスケールを初期化します。

numChannels = 128;

sz = [numChannels 1];

parameters.bn.Scale = initializeOnes(sz);

ゼロでの初期化

学習可能なパラメーターをゼロで簡単に初期化するために、カスタム関数を定義できます。関数 initializeZeros は、学習可能パラメーター sz のサイズを入力として受け取り、基となる型が 'single'dlarray オブジェクトとして、パラメーターを返します。

function parameter = initializeZeros(sz)

parameter = zeros(sz,'single');
parameter = dlarray(parameter);

end

128 個の入力チャネルでゼロを使用してバッチ正規化演算のオフセットを初期化します。

numChannels = 128;

sz = [numChannels 1];

parameters.bn.Offset = initializeZeros(sz);

学習可能パラメーターの保存

特定のモデル関数の学習可能パラメーターを、構造体、table、cell 配列などの単一のオブジェクトに保存することを推奨します。学習可能パラメーターを構造体として初期化する方法を示す例については、モデル関数を使用したネットワークの学習を参照してください。

GPU 上のパラメーターの保存

GPU を使用してモデルに学習させる場合、モデル関数の学習可能パラメーターが gpuArray オブジェクトに変換されて GPU に保存されます。

学習可能パラメーターを GPU のないマシンに簡単に読み込めるように、それらを保存する前にすべてのパラメーターをローカル ワークスペースに収集しておくことを推奨します。関数 dlupdate と関数 gather を使用し、構造体、table、または dlarray オブジェクトの cell 配列として保存されている学習可能パラメーターを収集します。たとえば、ネットワークの学習可能パラメーターが、構造体、table、または cell 配列 parameters で GPU に保存されている場合、次のコードを使用してパラメーターをローカル ワークスペースに転送できます。

parameters = dlupdate(@gather,parameters);

GPU にない学習可能パラメーターを読み込む場合、関数 dlupdate と関数 gpuArray を使用することで、パラメーターを GPU に移動できます。そうすることで、入力データが保存されている場所に関係なく、ネットワークが必ず GPU 上で学習と推論を実行するようになります。たとえば、構造体、table、または cell 配列 parameters で保存されている学習可能パラメーターを移動する場合、次のコードを使用してパラメーターを GPU に転送できます。

parameters = dlupdate(@gpuArray,parameters);

参照

[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the Difficulty of Training Deep Feedforward Neural Networks." In Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, 249–356. Sardinia, Italy: AISTATS, 2010.

[2] He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification." In Proceedings of the 2015 IEEE International Conference on Computer Vision, 1026–1034. Washington, DC: IEEE Computer Vision Society, 2015.

参考

| | |

関連するトピック