Main Content

resetState

ニューラル ネットワークの状態パラメーターのリセット

説明

netUpdated = resetState(net) は、ニューラル ネットワークの状態パラメーターをリセットします。この関数を使用して、LSTM ネットワークなどの再帰型ニューラル ネットワークの状態をリセットします。

すべて折りたたむ

シーケンスの予測間のネットワークの状態をリセットします。

[1] および [2] で説明されているように Japanese Vowels データ セットで学習させた事前学習済みの長短期記憶 (LSTM) ネットワーク dlnetJapaneseVowels を読み込みます。このネットワークは、ミニバッチのサイズ 27 を使用して、シーケンス長で並べ替えられたシーケンスで学習させています。

load dlnetJapaneseVowels

ネットワーク アーキテクチャを表示します。

net.Layers
ans = 
  4x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input    Sequence input with 12 dimensions
     2   'lstm'            LSTM              LSTM with 100 hidden units
     3   'fc'              Fully Connected   9 fully connected layer
     4   'softmax'         Softmax           softmax

テスト データを読み込みます。

load JapaneseVowelsTestData

シーケンスを分類し、ネットワークの状態を更新します。

X = XTest{94};
[scores,state] = predict(net,X,InputDataFormats="CT");
net.State = state;
label = scores2label(scores,classNames)
label = categorical
     3 

更新されたネットワークを使用して別のシーケンスを分類します。

X = XTest{1};
scores = predict(net,X,InputDataFormats="CT");
label = scores2label(scores,classNames)
label = categorical
     7 

最終予測を真のラベルと比較します。

trueLabel = TTest(1)
trueLabel = categorical
     1 

ネットワークの更新後の状態は、分類に悪影響を与える場合があります。ネットワークの状態をリセットし、シーケンスについて再度予測を行います。

net = resetState(net);
scores = predict(net,X,InputDataFormats="CT");
label = scores2label(scores,classNames)
label = categorical
     1 

入力引数

すべて折りたたむ

ニューラル ネットワーク。dlnetwork オブジェクトとして指定します。

関数 resetState は、net に状態パラメーターがある場合 (たとえば、LSTM 層などの少なくとも 1 つの再帰層をもつネットワーク) にのみ効果があります。入力ネットワークに状態パラメーターがない場合、この関数は無効となり、入力ネットワークを返します。

出力引数

すべて折りたたむ

更新されたネットワーク。dlnetwork オブジェクトとして返されます。

関数 resetState は、net に状態パラメーターがある場合 (たとえば、LSTM 層などの少なくとも 1 つの再帰層をもつネットワーク) にのみ効果があります。入力ネットワークに状態パラメーターがない場合、この関数は無効となり、入力ネットワークを返します。

参照

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

拡張機能

バージョン履歴

R2017b で導入

すべて展開する