Main Content

Simulink でのネットワークの状態の予測と更新

この例では、Stateful Predict ブロックを使用して、学習済みの再帰型ニューラル ネットワークの応答を Simulink® で予測する方法を説明します。この例では、事前学習済みの長短期記憶 (LSTM) ネットワークを使用します。

事前学習済みのネットワークの読み込み

[1] および [2] で説明されているように Japanese Vowels データセットで学習させた事前学習済みの長短期記憶 (LSTM) ネットワーク JapaneseVowelsNet を読み込みます。このネットワークは、ミニバッチのサイズ 27 を使用して、シーケンス長で並べ替えられたシーケンスで学習させています。

load JapaneseVowelsNet

ネットワーク アーキテクチャを表示します。

analyzeNetwork(net);

テスト データの読み込み

Japanese Vowels テスト データを読み込みます。XTest は、次元 12 の可変長の 370 個のシーケンスが含まれる cell 配列です。TTest は、9 人の話者に対応するラベル "1"、"2"、...、"9" から成る categorical ベクトルです。

タイムスタンプ付きの行と X の反復コピーから成る timetable 配列 simin を作成します。

[XTest,TTest] = japaneseVowelsTestData;
X = XTest{94};
numTimeSteps = size(X,2);
simin = timetable(repmat(X,1,4)','TimeStep',seconds(0.2));

応答を予測するための Simulink モデル

応答を予測するための Simulink モデルには、スコアを予測するための Stateful Predict ブロックと、タイム ステップでの入力データ シーケンスを読み込むための From Workspace ブロックが含まれています。

シミュレーション中に再帰型ニューラル ネットワークの状態を初期状態にリセットするには、Stateful Predict ブロックを Resettable Subsystem 内に配置し、トリガーとして制御信号 Reset を使用します。

open_system('StatefulPredictExample');

シミュレーション用モデルの構成

Stateful Predict ブロックのモデル コンフィギュレーション パラメーターを設定します。

set_param('StatefulPredictExample/Stateful Predict','NetworkFilePath','JapaneseVowelsNet.mat');
set_param('StatefulPredictExample', 'SimulationMode', 'Normal');

シミュレーションの実行

JapaneseVowelsNet ネットワークの応答を計算するには、シミュレーションを実行します。予測スコアは MATLAB® ワークスペースに保存されます。

out = sim('StatefulPredictExample');

予測スコアをプロットします。プロットには、タイム ステップ間での予測スコアの変化が表示されます。

scores = squeeze(out.yPred.Data(:,:,1:numTimeSteps));

classNames = string(net.Layers(end).Classes);
figure
lines = plot(scores');
xlim([1 numTimeSteps])
legend("Class " + classNames,'Location','northwest')
xlabel("Time Step")
ylabel("Score")
title("Prediction Scores Over Time Steps")

正しいクラスについて、タイム ステップでの予測スコアを強調表示します。

trueLabel = TTest(94);
lines(trueLabel).LineWidth = 3;

最後のタイム ステップの予測を棒グラフで表示します。

figure
bar(scores(:,end))
title("Final Prediction Scores")
xlabel("Class")
ylabel("Score")

参照

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

参考

| | |

関連するトピック