Main Content

resetState

ニューラル ネットワークの状態パラメーターのリセット

説明

netUpdated = resetState(net) は、ニューラル ネットワークの状態パラメーターをリセットします。この関数を使用して、LSTM ネットワークなどの再帰型ニューラル ネットワークの状態をリセットします。

すべて折りたたむ

シーケンスの予測間のネットワークの状態をリセットします。

[1] および [2] で説明されているように Japanese Vowels データセットで学習させた事前学習済みの長短期記憶 (LSTM) ネットワーク JapaneseVowelsNet を読み込みます。このネットワークは、ミニバッチのサイズ 27 を使用して、シーケンス長で並べ替えられたシーケンスで学習させています。

load JapaneseVowelsNet

ネットワーク アーキテクチャを表示します。

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

テスト データを読み込みます。

[XTest,YTest] = japaneseVowelsTestData;

シーケンスを分類し、ネットワークの状態を更新します。再現性を得るために、rng'shuffle' に設定します。

rng('shuffle')
X = XTest{94};
[net,label] = classifyAndUpdateState(net,X);
label
label = categorical
     3 

更新されたネットワークを使用して別のシーケンスを分類します。

X = XTest{1};
label = classify(net,X)
label = categorical
     7 

最終予測を真のラベルと比較します。

trueLabel = YTest(1)
trueLabel = categorical
     1 

ネットワークの更新後の状態は、分類に悪影響を与える場合があります。ネットワークの状態をリセットし、シーケンスについて再度予測を行います。

net = resetState(net);
label = classify(net,XTest{1})
label = categorical
     1 

入力引数

すべて折りたたむ

ニューラル ネットワーク。SeriesNetwork オブジェクト、DAGNetwork オブジェクト、または dlnetwork オブジェクトとして指定します。

関数 resetState は、net に状態パラメーターがある場合 (たとえば、LSTM 層などの少なくとも 1 つの再帰層をもつネットワーク) にのみ効果があります。入力ネットワークに状態パラメーターがない場合、この関数は無効となり、入力ネットワークを返します。

出力引数

すべて折りたたむ

更新されたネットワーク。入力ネットワークと同じタイプのネットワークとして返されます。

関数 resetState は、net に状態パラメーターがある場合 (たとえば、LSTM 層などの少なくとも 1 つの再帰層をもつネットワーク) にのみ効果があります。入力ネットワークに状態パラメーターがない場合、この関数は無効となり、入力ネットワークを返します。

参照

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

拡張機能

バージョン履歴

R2017b で導入