Simulink でのネットワークの状態の予測と更新
この例では、Stateful Predict
ブロックを使用して、学習済みの再帰型ニューラル ネットワークの応答を Simulink® で予測する方法を説明します。この例では、事前学習済みの長短期記憶 (LSTM) ネットワークを使用します。
事前学習済みのネットワークの読み込み
[1] および [2] で説明されているように Japanese Vowels データ セットで学習させた事前学習済みの長短期記憶 (LSTM) ネットワーク JapaneseVowelsNet
を読み込みます。このネットワークは、ミニバッチのサイズ 27 を使用して、シーケンス長で並べ替えられたシーケンスで学習させています。
load JapaneseVowelsNet
ネットワーク アーキテクチャを表示します。
analyzeNetwork(net);
テスト データの読み込み
Japanese Vowels テスト データを読み込みます。XTest
は、次元 12 の可変長の 370 個のシーケンスが含まれる cell 配列です。TTest
は、9 人の話者に対応するラベル "1"、"2"、...、"9" から成る categorical ベクトルです。
タイムスタンプ付きの行と X
の反復コピーから成る timetable 配列 simin
を作成します。
load JapaneseVowelsTestData X = XTest{94}; numTimeSteps = size(X,2); simin = timetable(repmat(X,1,4)','TimeStep',seconds(0.2));
応答を予測するための Simulink モデル
応答を予測するための Simulink モデルには、スコアを予測するための Stateful Predict
ブロックと、タイム ステップでの入力データ シーケンスを読み込むための From Workspace
ブロックが含まれています。
シミュレーション中に再帰型ニューラル ネットワークの状態を初期状態にリセットするには、Stateful Predict
ブロックを Resettable Subsystem
内に配置し、トリガーとして制御信号 Reset
を使用します。
open_system('StatefulPredictExample');
シミュレーション用モデルの構成
Stateful Predict
ブロックのモデル コンフィギュレーション パラメーターを設定します。
set_param('StatefulPredictExample/Stateful Predict','NetworkFilePath','JapaneseVowelsNet.mat'); set_param('StatefulPredictExample', 'SimulationMode', 'Normal');
シミュレーションの実行
JapaneseVowelsNet
ネットワークの応答を計算するには、シミュレーションを実行します。予測スコアは MATLAB® ワークスペースに保存されます。
out = sim('StatefulPredictExample');
予測スコアをプロットします。プロットには、タイム ステップ間での予測スコアの変化が表示されます。
scores = squeeze(out.yPred.Data(:,:,1:numTimeSteps)); classNames = string(net.Layers(end).Classes); figure lines = plot(scores'); xlim([1 numTimeSteps]) legend("Class " + classNames,'Location','northwest') xlabel("Time Step") ylabel("Score") title("Prediction Scores Over Time Steps")
正しいクラスについて、タイム ステップでの予測スコアを強調表示します。
trueLabel = TTest(94); lines(trueLabel).LineWidth = 3;
最後のタイム ステップの予測を棒グラフで表示します。
figure bar(scores(:,end)) title("Final Prediction Scores") xlabel("Class") ylabel("Score")
参照
[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.
[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
参考
Stateful Predict | Stateful Classify | Predict | Image Classifier