双方向 LSTM (BiLSTM) 関数の作成
この例では、カスタム深層学習関数用の双方向長短期記憶 (BiLSTM) 関数を作成する方法を示します。
深層学習モデルでは、双方向 LSTM (BiLSTM) 演算によって、時系列データまたはシーケンス データのタイム ステップ間の双方向の長期的な依存関係を学習します。これらの依存関係は、各タイム ステップで時系列全体からネットワークに学習させる場合に役立ちます。
ほとんどのタスクでは、bilstmLayer
オブジェクトを含むネットワークに学習させることができます。関数内で BiLSTM 演算を使用するには、この例を指針として使用して BiLSTM 関数を作成します。
BiLSTM は、最初のタイム ステップから最後のタイム ステップまで演算を行う "順方向 LSTM" と、最後のタイム ステップから最初のタイム ステップまで演算を行う "逆方向 LSTM" の 2 つの LSTM コンポーネントで構成されます。この演算は、データを 2 つの LSTM コンポーネントに渡した後、チャネル次元に沿って出力を連結します。
BiLSTM 関数の作成
この例の最後にリストされている bilstm
関数を作成します。この関数は、初期の隠れ状態、初期のセル状態、入力重み、再帰重み、およびバイアスを使用して、入力に BiLSTM 演算を適用します。
BiLSTM パラメーターの初期化
入力のサイズ (入力層の埋め込み次元など) と隠れユニットの数を指定します。
inputSize = 256; numHiddenUnits = 50;
BiLSTM パラメーターを初期化します。BiLSTM 演算では、演算の順方向と逆方向の両方の部分について、入力重み、再帰重み、およびバイアスのセットが必要です。これらのパラメーターには、順方向と逆方向のコンポーネントを連結したものを指定します。この場合、入力重みのサイズは [8*numHiddenUnits inputSize]
、再帰重みのサイズは [8*numHiddenUnits numHiddenUnits]
、バイアスのサイズは [8*numHiddenUnits 1]
です。
それぞれ initializeGlorot
関数、initializeOrthogonal
関数、および initializeUnitForgetGate
関数を使用して、入力重み、再帰重み、バイアスを初期化します。これらの関数は、この例にサポート ファイルとして添付されています。これらの関数にアクセスするには、例をライブ スクリプトとして開きます。
initializeGlorot
関数を使用して入力重みを初期化します。
numOut = 8*numHiddenUnits; numIn = inputSize; sz = [numOut numIn]; inputWeights = initializeGlorot(sz,numOut,numIn);
initializeOrthogonal
関数を使用して再帰重みを初期化します。
sz = [8*numHiddenUnits numHiddenUnits]; recurrentWeights = initializeOrthogonal(sz);
initializeUnitForgetGate
関数を使用して入力重みを初期化します。
bias = initializeUnitForgetGate(2*numHiddenUnits);
この例にサポート ファイルとして添付されている initializeZeros
関数を使用して、BiLSTM の隠れ状態とセル状態をゼロで初期化します。この関数にアクセスするには、例をライブ スクリプトとして開きます。パラメーターと同様に、順方向と逆方向のコンポーネントを連結したものを指定します。この場合、隠れ状態とセル状態のサイズはそれぞれ [2*numHiddenUnits 1]
です。
sz = [2*numHiddenUnits 1]; H0 = initializeZeros(sz); C0 = initializeZeros(sz);
BiLSTM 演算の適用
ミニバッチ サイズが 128、シーケンス長が 75 であるランダム データの配列を指定します。入力の最初の次元 (チャネル次元) は、BiLSTM 演算の入力サイズと一致しなければなりません。
miniBatchSize = 128; sequenceLength = 75; X = rand([inputSize miniBatchSize sequenceLength],"single"); X = dlarray(X,"CBT");
BiLSTM 演算を適用し、出力のサイズを表示します。
Y = bilstm(X,H0,C0,inputWeights,recurrentWeights,bias); size(Y)
ans = 1×3
100 128 75
シーケンスの最後のタイム ステップのみを必要とするモデルの場合は、順方向 LSTM コンポーネントと逆方向 LSTM コンポーネントの最後の出力に対応するベクトルを抽出します。
YLastForward = Y(1:numHiddenUnits,:,end); YLastBackward = Y(numHiddenUnits+1:end,:,1); YLast = cat(1, YLastForward, YLastBackward); size(YLast)
ans = 1×3
100 128 1
BiLSTM 関数
bilstm
関数は、初期隠れ状態 H0
、初期セル状態 C0
、ならびにパラメーター weights
、recurrentWeights
、および bias
を使用して、形式を整えた dlarray
入力 X
に BiLSTM 演算を適用します。入力重みのサイズは [8*numHiddenUnits inputSize]
、再帰重みのサイズは [8*numHiddenUnits numHiddenUnits]
、バイアスのサイズは [8*numHiddenUnits 1]
です。隠れ状態とセル状態のサイズはそれぞれ [2*numHiddenUnits 1]
です。
function [Y,hiddenState,cellState] = bilstm(X,H0,C0,inputWeights,recurrentWeights,bias) % Determine forward and backward parameter indices numHiddenUnits = numel(bias)/8; idxForward = 1:4*numHiddenUnits; idxBackward = 4*numHiddenUnits+1:8*numHiddenUnits; % Forward and backward states H0Forward = H0(1:numHiddenUnits); H0Backward = H0(numHiddenUnits+1:end); C0Forward = C0(1:numHiddenUnits); C0Backward = C0(numHiddenUnits+1:end); % Forward and backward parameters inputWeightsForward = inputWeights(idxForward,:); inputWeightsBackward = inputWeights(idxBackward,:); recurrentWeightsForward = recurrentWeights(idxForward,:); recurrentWeightsBackward = recurrentWeights(idxBackward,:); biasForward = bias(idxForward); biasBackward = bias(idxBackward); % Forward LSTM [YForward,hiddenStateForward,cellStateForward] = lstm(X,H0Forward,C0Forward,inputWeightsForward, ... recurrentWeightsForward,biasForward); % Backward LSTM XBackward = X; idx = finddim(X,"T"); if ~isempty(idx) XBackward = flip(XBackward,idx); end [YBackward,hiddenStateBackward,cellStateBackward] = lstm(XBackward,H0Backward,C0Backward,inputWeightsBackward, ... recurrentWeightsBackward,biasBackward); if ~isempty(idx) YBackward = flip(YBackward,idx); end % Output Y = cat(1,YForward,YBackward); hiddenState = cat(1,hiddenStateForward,hiddenStateBackward); cellState = cat(1,cellStateForward,cellStateBackward); end
参考
sequenceInputLayer
| lstmLayer
| bilstmLayer
| dlarray
| lstm