LSTM ネットワークの活性化の可視化
この例では、活性化を抽出し、LSTM ネットワークによって学習された特徴を調査して可視化する方法を説明します。
事前学習済みのネットワークを読み込みます。JapaneseVowelsNet
は、[1] および [2] で説明されているように Japanese Vowels データセットで学習させた事前学習済みの LSTM ネットワークです。これは、ミニバッチのサイズ 27 を使用して、シーケンス長で並べ替えられたシーケンスで学習させています。
load JapaneseVowelsNet
ネットワーク アーキテクチャを表示します。
net.Layers
ans = 4x1 Layer array with layers: 1 'sequenceinput' Sequence Input Sequence input with 12 dimensions 2 'lstm' LSTM LSTM with 100 hidden units 3 'fc' Fully Connected 9 fully connected layer 4 'softmax' Softmax softmax
テスト データを読み込みます。
load JapaneseVowelsTestData
最初の時系列をプロットで可視化します。各ラインは特徴に対応しています。
X = XTest{1}; figure plot(XTest{1}') xlabel("Time Step") title("Test Observation 1") numFeatures = size(XTest{1},1); legend("Feature " + string(1:numFeatures),'Location',"northeastoutside")
シーケンスの各タイム ステップについて、LSTM 層 (層 2) がそのタイム ステップ用に出力した活性化を取得して、ネットワークの状態を更新します。
sequenceLength = size(X,2); idxLayer = 2; outputSize = net.Layers(idxLayer).NumHiddenUnits; for i = 1:sequenceLength [features(i,:),state] = predict(net,X(:,1)',Outputs="lstm"); net.State = state; end features = features';
ヒートマップを使用して、最初の 10 個の隠れユニットを可視化します。
figure heatmap(features(1:10,:)); xlabel("Time Step") ylabel("Hidden Unit") title("LSTM Activations")
ヒートマップは、各隠れユニットがどれだけ強く活性化したかを示しており、時間の経過に伴い活性化がどのように変化したかを強調表示します。
参照
[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.
[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
参考
trainnet
| trainingOptions
| dlnetwork
| predict
| forward
| lstmLayer
| bilstmLayer
| sequenceInputLayer