メインコンテンツ

双方向 LSTM (BiLSTM) 関数の作成

R2023b 以降

この例では、カスタム深層学習関数用の双方向長短期記憶 (BiLSTM) 関数を作成する方法を示します。

深層学習モデルでは、双方向 LSTM (BiLSTM) 演算によって、時系列データまたはシーケンス データのタイム ステップ間の双方向の長期的な依存関係を学習します。これらの依存関係は、各タイム ステップで時系列全体からネットワークに学習させる場合に役立ちます。

ほとんどのタスクでは、bilstmLayerオブジェクトを含むネットワークに学習させることができます。関数内で BiLSTM 演算を使用するには、この例を指針として使用して BiLSTM 関数を作成します。

BiLSTM は、最初のタイム ステップから最後のタイム ステップまで演算を行う "順方向 LSTM" と、最後のタイム ステップから最初のタイム ステップまで演算を行う "逆方向 LSTM" の 2 つの LSTM コンポーネントで構成されます。この演算は、データを 2 つの LSTM コンポーネントに渡した後、チャネル次元に沿って出力を連結します。

BiLSTM 関数の作成

この例の最後にリストされている bilstm 関数を作成します。この関数は、初期の隠れ状態、初期のセル状態、入力重み、再帰重み、およびバイアスを使用して、入力に BiLSTM 演算を適用します。

BiLSTM パラメーターの初期化

入力のサイズ (入力層の埋め込み次元など) と隠れユニットの数を指定します。

inputSize = 256;
numHiddenUnits = 50;

BiLSTM パラメーターを初期化します。BiLSTM 演算では、演算の順方向と逆方向の両方の部分について、入力重み、再帰重み、およびバイアスのセットが必要です。これらのパラメーターには、順方向と逆方向のコンポーネントを連結したものを指定します。この場合、入力重みのサイズは [8*numHiddenUnits inputSize]、再帰重みのサイズは [8*numHiddenUnits numHiddenUnits]、バイアスのサイズは [8*numHiddenUnits 1] です。

それぞれ initializeGlorot 関数、initializeOrthogonal 関数、および initializeUnitForgetGate 関数を使用して、入力重み、再帰重み、バイアスを初期化します。これらの関数は、この例にサポート ファイルとして添付されています。これらの関数にアクセスするには、例をライブ スクリプトとして開きます。

initializeGlorot 関数を使用して入力重みを初期化します。

numOut = 8*numHiddenUnits;
numIn = inputSize;
sz = [numOut numIn];
inputWeights = initializeGlorot(sz,numOut,numIn);

initializeOrthogonal 関数を使用して再帰重みを初期化します。

sz = [8*numHiddenUnits numHiddenUnits];
recurrentWeights = initializeOrthogonal(sz);

initializeUnitForgetGate 関数を使用して入力重みを初期化します。

bias = initializeUnitForgetGate(2*numHiddenUnits);

この例にサポート ファイルとして添付されている initializeZeros 関数を使用して、BiLSTM の隠れ状態とセル状態をゼロで初期化します。この関数にアクセスするには、例をライブ スクリプトとして開きます。パラメーターと同様に、順方向と逆方向のコンポーネントを連結したものを指定します。この場合、隠れ状態とセル状態のサイズはそれぞれ [2*numHiddenUnits 1] です。

sz = [2*numHiddenUnits 1];
H0 = initializeZeros(sz);
C0 = initializeZeros(sz);

BiLSTM 演算の適用

ミニバッチ サイズが 128、シーケンス長が 75 であるランダム データの配列を指定します。入力の最初の次元 (チャネル次元) は、BiLSTM 演算の入力サイズと一致しなければなりません。

miniBatchSize = 128;
sequenceLength = 75;
X = rand([inputSize miniBatchSize sequenceLength],"single");
X = dlarray(X,"CBT");

BiLSTM 演算を適用し、出力のサイズを表示します。

Y = bilstm(X,H0,C0,inputWeights,recurrentWeights,bias);
size(Y)
ans = 1×3

   100   128    75

シーケンスの最後のタイム ステップのみを必要とするモデルの場合は、順方向 LSTM コンポーネントと逆方向 LSTM コンポーネントの最後の出力に対応するベクトルを抽出します。

YLastForward = Y(1:numHiddenUnits,:,end);
YLastBackward = Y(numHiddenUnits+1:end,:,1);

YLast = cat(1, YLastForward, YLastBackward);
size(YLast)
ans = 1×3

   100   128     1

BiLSTM 関数

bilstm 関数は、初期隠れ状態 H0、初期セル状態 C0、ならびにパラメーター weightsrecurrentWeights、および bias を使用して、形式を整えた dlarray 入力 X に BiLSTM 演算を適用します。入力重みのサイズは [8*numHiddenUnits inputSize]、再帰重みのサイズは [8*numHiddenUnits numHiddenUnits]、バイアスのサイズは [8*numHiddenUnits 1] です。隠れ状態とセル状態のサイズはそれぞれ [2*numHiddenUnits 1] です。

function [Y,hiddenState,cellState] = bilstm(X,H0,C0,inputWeights,recurrentWeights,bias)

% Determine forward and backward parameter indices
numHiddenUnits = numel(bias)/8;
idxForward = 1:4*numHiddenUnits;
idxBackward = 4*numHiddenUnits+1:8*numHiddenUnits;

% Forward and backward states
H0Forward = H0(1:numHiddenUnits);
H0Backward = H0(numHiddenUnits+1:end);
C0Forward = C0(1:numHiddenUnits);
C0Backward = C0(numHiddenUnits+1:end);

% Forward and backward parameters
inputWeightsForward = inputWeights(idxForward,:);
inputWeightsBackward = inputWeights(idxBackward,:);
recurrentWeightsForward = recurrentWeights(idxForward,:);
recurrentWeightsBackward = recurrentWeights(idxBackward,:);
biasForward = bias(idxForward);
biasBackward = bias(idxBackward);

% Forward LSTM
[YForward,hiddenStateForward,cellStateForward] = lstm(X,H0Forward,C0Forward,inputWeightsForward, ...
    recurrentWeightsForward,biasForward);

% Backward LSTM
XBackward = X;
idx = finddim(X,"T");
if ~isempty(idx)
    XBackward = flip(XBackward,idx);
end

[YBackward,hiddenStateBackward,cellStateBackward] = lstm(XBackward,H0Backward,C0Backward,inputWeightsBackward, ...
    recurrentWeightsBackward,biasBackward);

if ~isempty(idx)
    YBackward = flip(YBackward,idx);
end

% Output
Y = cat(1,YForward,YBackward);
hiddenState = cat(1,hiddenStateForward,hiddenStateBackward);
cellState = cat(1,cellStateForward,cellStateBackward);

end

参考

| | | |

トピック