Main Content

最新のリリースでは、このページがまだ翻訳されていません。 このページの最新版は英語でご覧になれます。

resetState

再帰型ニューラル ネットワークの状態のリセット

説明

updatedNet = resetState(recNet) は、再帰型ニューラル ネットワーク (LSTM ネットワークなど) の状態を初期状態にリセットします。

すべて折りたたむ

シーケンスの予測間のネットワークの状態をリセットします。

[1] および [2] で説明されているように Japanese Vowels データセットで学習させた事前学習済みの長短期記憶 (LSTM) ネットワーク JapaneseVowelsNet を読み込みます。このネットワークは、ミニバッチのサイズ 27 を使用して、シーケンス長で並べ替えられたシーケンスで学習させています。

load JapaneseVowelsNet

ネットワーク アーキテクチャを表示します。

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

テスト データを読み込みます。

[XTest,YTest] = japaneseVowelsTestData;

シーケンスを分類し、ネットワークの状態を更新します。再現性を得るために、rng'shuffle' に設定します。

rng('shuffle')
X = XTest{94};
[net,label] = classifyAndUpdateState(net,X);
label
label = categorical
     3 

更新されたネットワークを使用して別のシーケンスを分類します。

X = XTest{1};
label = classify(net,X)
label = categorical
     7 

最終予測を真のラベルと比較します。

trueLabel = YTest(1)
trueLabel = categorical
     1 

ネットワークの更新後の状態は、分類に悪影響を与える場合があります。ネットワークの状態をリセットし、シーケンスについて再度予測を行います。

net = resetState(net);
label = classify(net,XTest{1})
label = categorical
     1 

入力引数

すべて折りたたむ

学習済み再帰型ニューラル ネットワーク。SeriesNetwork または DAGNetwork オブジェクトとして指定します。事前学習済みのネットワークをインポートするか、関数 trainNetwork を使用して独自のネットワークに学習させることによって、学習済みネットワークを取得できます。

recNet は再帰型ニューラル ネットワークです。これには少なくとも 1 つの再帰層 (LSTM ネットワークなど) を含めなければなりません。入力ネットワークが再帰型ネットワークでない場合、この関数は無効となり、入力ネットワークを返します。

出力引数

すべて折りたたむ

更新されたネットワーク。updatedNet は入力ネットワークと同じタイプのネットワークです。

入力ネットワークが再帰型ネットワークでない場合、この関数は無効となり、入力ネットワークを返します。

参照

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

拡張機能

R2017b で導入