Simulink でのネットワークの状態の分類と更新
この例では、Stateful Classify
ブロックを使用して、Simulink® で学習済みの再帰型ニューラル ネットワークのデータを分類する方法を説明します。この例では、事前学習済みの長短期記憶 (LSTM) ネットワークを使用します。
事前学習済みのネットワークの読み込み
[1] および [2] で説明されているように Japanese Vowels データセットで学習させた事前学習済みの長短期記憶 (LSTM) ネットワーク JapaneseVowelsNet
を読み込みます。このネットワークは、ミニバッチのサイズ 27 を使用して、シーケンス長で並べ替えられたシーケンスで学習させています。
load JapaneseVowelsNet
ネットワーク アーキテクチャを表示します。
analyzeNetwork(net);
テスト データの読み込み
Japanese Vowels テスト データを読み込みます。XTest
は、次元 12 の可変長の 370 個のシーケンスが含まれる cell 配列です。TTest
は、9 人の話者に対応するラベル "1"、"2"、...、"9" から成る categorical ベクトルです。
タイムスタンプ付きの行と X
の反復コピーから成る timetable 配列 simin
を作成します。
[XTest,TTest] = japaneseVowelsTestData;
X = XTest{94};
numTimeSteps = size(X,2);
simin = timetable(repmat(X,1,4)','TimeStep',seconds(0.2));
データ分類用の Simulink モデル
データを分類するための Simulink モデルには、ラベルを予測するための Stateful Classify
ブロックと、タイム ステップでの入力データ シーケンスを読み込むための From Workspace
ブロックが含まれています。
シミュレーション中に再帰型ニューラル ネットワークの状態を初期状態にリセットするには、Stateful Classify
ブロックを Resettable Subsystem
内に配置し、トリガーとして制御信号 Reset
を使用します。
open_system('StatefulClassifyExample');
シミュレーション用モデルの構成
Stateful Classify
ブロックのモデル コンフィギュレーション パラメーターを設定します。
set_param('StatefulClassifyExample/Stateful Classify','NetworkFilePath','JapaneseVowelsNet.mat'); set_param('StatefulClassifyExample','SimulationMode','Normal');
シミュレーションの実行
JapaneseVowelsNet
ネットワークの応答を計算するには、シミュレーションを実行します。予測ラベルは MATLAB® ワークスペースに保存されます。
out = sim('StatefulClassifyExample');
予測されたラベルを階段状プロットにプロットします。このプロットには、タイムス ステップ間の予測の変化が示されます。
labels = squeeze(out.YPred.Data(1:numTimeSteps,1)); figure stairs(labels, '-o') xlim([1 numTimeSteps]) xlabel("Time Step") ylabel("Predicted Class") title("Classification Over Time Steps")
予測と真のラベルを比較します。観測値の真のラベルを示す水平のラインをプロットします。
trueLabel = double(TTest(94)); hold on line([1 numTimeSteps],[trueLabel trueLabel], ... 'Color','red', ... 'LineStyle','--') legend(["Prediction" "True Label"]) axis([1 numTimeSteps+1 0 9]);
参照
[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.
[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
参考
Stateful Predict | Stateful Classify | Predict | Image Classifier