双方向 LSTM (BiLSTM) 関数の作成
この例では、カスタム深層学習関数用の双方向長短期記憶 (BiLSTM) 関数を作成する方法を示します。
深層学習モデルでは、双方向 LSTM (BiLSTM) 演算によって、時系列データまたはシーケンス データのタイム ステップ間の双方向の長期的な依存関係を学習します。これらの依存関係は、各タイム ステップで時系列全体からネットワークに学習させる場合に役立ちます。
ほとんどのタスクでは、bilstmLayerオブジェクトを含むネットワークに学習させることができます。関数内で BiLSTM 演算を使用するには、この例を指針として使用して BiLSTM 関数を作成します。
BiLSTM は、最初のタイム ステップから最後のタイム ステップまで演算を行う "順方向 LSTM" と、最後のタイム ステップから最初のタイム ステップまで演算を行う "逆方向 LSTM" の 2 つの LSTM コンポーネントで構成されます。この演算は、データを 2 つの LSTM コンポーネントに渡した後、チャネル次元に沿って出力を連結します。

BiLSTM 関数の作成
この例の最後にリストされている bilstm 関数を作成します。この関数は、初期の隠れ状態、初期のセル状態、入力重み、再帰重み、およびバイアスを使用して、入力に BiLSTM 演算を適用します。
BiLSTM パラメーターの初期化
入力のサイズ (入力層の埋め込み次元など) と隠れユニットの数を指定します。
inputSize = 256; numHiddenUnits = 50;
BiLSTM パラメーターを初期化します。BiLSTM 演算では、演算の順方向と逆方向の両方の部分について、入力重み、再帰重み、およびバイアスのセットが必要です。これらのパラメーターには、順方向と逆方向のコンポーネントを連結したものを指定します。この場合、入力重みのサイズは [8*numHiddenUnits inputSize]、再帰重みのサイズは [8*numHiddenUnits numHiddenUnits]、バイアスのサイズは [8*numHiddenUnits 1] です。
それぞれ initializeGlorot 関数、initializeOrthogonal 関数、および initializeUnitForgetGate 関数を使用して、入力重み、再帰重み、バイアスを初期化します。これらの関数は、この例にサポート ファイルとして添付されています。これらの関数にアクセスするには、例をライブ スクリプトとして開きます。
initializeGlorot 関数を使用して入力重みを初期化します。
numOut = 8*numHiddenUnits; numIn = inputSize; sz = [numOut numIn]; inputWeights = initializeGlorot(sz,numOut,numIn);
initializeOrthogonal 関数を使用して再帰重みを初期化します。
sz = [8*numHiddenUnits numHiddenUnits]; recurrentWeights = initializeOrthogonal(sz);
initializeUnitForgetGate 関数を使用して入力重みを初期化します。
bias = initializeUnitForgetGate(2*numHiddenUnits);
この例にサポート ファイルとして添付されている initializeZeros 関数を使用して、BiLSTM の隠れ状態とセル状態をゼロで初期化します。この関数にアクセスするには、例をライブ スクリプトとして開きます。パラメーターと同様に、順方向と逆方向のコンポーネントを連結したものを指定します。この場合、隠れ状態とセル状態のサイズはそれぞれ [2*numHiddenUnits 1] です。
sz = [2*numHiddenUnits 1]; H0 = initializeZeros(sz); C0 = initializeZeros(sz);
BiLSTM 演算の適用
ミニバッチ サイズが 128、シーケンス長が 75 であるランダム データの配列を指定します。入力の最初の次元 (チャネル次元) は、BiLSTM 演算の入力サイズと一致しなければなりません。
miniBatchSize = 128; sequenceLength = 75; X = rand([inputSize miniBatchSize sequenceLength],"single"); X = dlarray(X,"CBT");
BiLSTM 演算を適用し、出力のサイズを表示します。
Y = bilstm(X,H0,C0,inputWeights,recurrentWeights,bias); size(Y)
ans = 1×3
100 128 75
シーケンスの最後のタイム ステップのみを必要とするモデルの場合は、順方向 LSTM コンポーネントと逆方向 LSTM コンポーネントの最後の出力に対応するベクトルを抽出します。
YLastForward = Y(1:numHiddenUnits,:,end); YLastBackward = Y(numHiddenUnits+1:end,:,1); YLast = cat(1, YLastForward, YLastBackward); size(YLast)
ans = 1×3
100 128 1
BiLSTM 関数
bilstm 関数は、初期隠れ状態 H0、初期セル状態 C0、ならびにパラメーター weights、recurrentWeights、および bias を使用して、形式を整えた dlarray 入力 X に BiLSTM 演算を適用します。入力重みのサイズは [8*numHiddenUnits inputSize]、再帰重みのサイズは [8*numHiddenUnits numHiddenUnits]、バイアスのサイズは [8*numHiddenUnits 1] です。隠れ状態とセル状態のサイズはそれぞれ [2*numHiddenUnits 1] です。
function [Y,hiddenState,cellState] = bilstm(X,H0,C0,inputWeights,recurrentWeights,bias) % Determine forward and backward parameter indices numHiddenUnits = numel(bias)/8; idxForward = 1:4*numHiddenUnits; idxBackward = 4*numHiddenUnits+1:8*numHiddenUnits; % Forward and backward states H0Forward = H0(1:numHiddenUnits); H0Backward = H0(numHiddenUnits+1:end); C0Forward = C0(1:numHiddenUnits); C0Backward = C0(numHiddenUnits+1:end); % Forward and backward parameters inputWeightsForward = inputWeights(idxForward,:); inputWeightsBackward = inputWeights(idxBackward,:); recurrentWeightsForward = recurrentWeights(idxForward,:); recurrentWeightsBackward = recurrentWeights(idxBackward,:); biasForward = bias(idxForward); biasBackward = bias(idxBackward); % Forward LSTM [YForward,hiddenStateForward,cellStateForward] = lstm(X,H0Forward,C0Forward,inputWeightsForward, ... recurrentWeightsForward,biasForward); % Backward LSTM XBackward = X; idx = finddim(X,"T"); if ~isempty(idx) XBackward = flip(XBackward,idx); end [YBackward,hiddenStateBackward,cellStateBackward] = lstm(XBackward,H0Backward,C0Backward,inputWeightsBackward, ... recurrentWeightsBackward,biasBackward); if ~isempty(idx) YBackward = flip(YBackward,idx); end % Output Y = cat(1,YForward,YBackward); hiddenState = cat(1,hiddenStateForward,hiddenStateBackward); cellState = cat(1,cellStateForward,cellStateBackward); end
参考
sequenceInputLayer | lstmLayer | bilstmLayer | dlarray | lstm