最新のリリースでは、このページがまだ翻訳されていません。 このページの最新版は英語でご覧になれます。

シンプルなシーケンス分類ネットワークの作成

この例では、シンプルな長短期記憶 (LSTM) 分類ネットワークを作成する方法を説明します。

シーケンス データを分類するよう深層ニューラル ネットワークに学習させるために、LSTM ネットワークを使用できます。LSTM ネットワークは、再帰型ニューラル ネットワーク (RNN) の一種で、シーケンス データのタイム ステップ間の長期的な依存関係を学習します。

この例では、以下を実行する方法を示します。

  • シーケンス データの読み込み。

  • ネットワーク アーキテクチャの定義。

  • 学習オプションの指定。

  • ネットワークの学習。

  • 新しいデータのラベルの予測と分類精度の計算。

データの読み込み

[1] および [2] に記載のある Japanese Vowels データセットを読み込みます。予測子は、特徴次元 12 の可変長のシーケンスが含まれる cell 配列です。ラベルは、ラベル 1、2、...、9 の categorical ベクトルです。

[XTrain,YTrain] = japaneseVowelsTrainData;
[XValidation,YValidation] = japaneseVowelsTestData;

最初のいくつかの学習シーケンスのサイズを表示します。シーケンスは行列で、行数が 12 (特徴ごとに 1 行) で、列数が可変 (タイム ステップごとに 1 列) です。

XTrain(1:5)
ans=5×1 cell
    {12x20 double}
    {12x26 double}
    {12x22 double}
    {12x20 double}
    {12x21 double}

ネットワーク アーキテクチャの定義

LSTM ネットワーク アーキテクチャを定義します。入力層の特徴の数と全結合層のクラス数を指定します。

numFeatures = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

ネットワークの学習

学習オプションを指定し、ネットワークに学習させます。

ミニバッチが小さく、シーケンスが短いため、学習には CPU が適しています。'ExecutionEnvironment''cpu' に設定します。GPU が利用できる場合、GPU で学習を行うには、'ExecutionEnvironment''auto' (既定値) に設定します。

miniBatchSize = 27;

options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    'ValidationData',{XValidation,YValidation}, ...
    'GradientThreshold',2, ...
    'Shuffle','every-epoch', ...
    'Verbose',false, ...
    'Plots','training-progress');

net = trainNetwork(XTrain,YTrain,layers,options);

学習オプションの指定の詳細は、パラメーターの設定と畳み込みニューラル ネットワークの学習を参照してください。

ネットワークのテスト

テスト データを分類し、分類精度を計算します。学習に使用されるサイズと同じミニバッチ サイズを指定します。

YPred = classify(net,XValidation,'MiniBatchSize',miniBatchSize);
acc = mean(YPred == YValidation)
acc = 0.9405

次のステップとして、双方向 LSTM (BiLSTM) 層を使用するか、より深いネットワークを作成して、精度の改善を試みることができます。詳細については、長短期記憶ネットワークを参照してください。

畳み込みネットワークを使用してシーケンス データを分類する方法を示す例については、深層学習を使用した音声コマンド認識を参照してください。

参考文献

  1. M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions."Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

  2. UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

参考

| |

関連するトピック