Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

深層学習を使用したシーケンスの分類

この例では、長短期記憶 (LSTM) ネットワークを使用してシーケンス データを分類する方法を説明します。

シーケンス データを分類するよう深層ニューラル ネットワークに学習させるために、LSTM ネットワークを使用できます。LSTM ネットワークでは、シーケンス データをネットワークに入力し、シーケンス データの個々のタイム ステップに基づいて予測を行うことができます。

この例では、[1] および [2] に記載のある Japanese Vowels データセットを使用します。この例では、続けて発音された 2 つの日本語の母音を表す時系列データにおいて、その話者を認識するように、LSTM ネットワークに学習させます。学習データには、9 人の話者の時系列データが含まれています。各シーケンスには 12 個の特徴があり、長さはさまざまです。データセットには 270 個の学習観測値と 370 個のテスト観測値が含まれています。

シーケンス データの読み込み

Japanese Vowels 学習データを読み込みます。XTrain は、次元 12 の可変長の 270 個のシーケンスが含まれる cell 配列です。Y は、9 人の話者に対応するラベル "1"、"2"、...、"9" から成る categorical ベクトルです。XTrain のエントリは行列で、行数が 12 (特徴ごとに 1 行) で、列数が可変 (タイム ステップごとに 1 列) です。

[XTrain,YTrain] = japaneseVowelsTrainData;
XTrain(1:5)
ans=5×1 cell array
    {12x20 double}
    {12x26 double}
    {12x22 double}
    {12x20 double}
    {12x21 double}

最初の時系列をプロットで可視化します。各ラインは特徴に対応しています。

figure
plot(XTrain{1}')
xlabel("Time Step")
title("Training Observation 1")
numFeatures = size(XTrain{1},1);
legend("Feature " + string(1:numFeatures),'Location','northeastoutside')

Figure contains an axes object. The axes object with title Training Observation 1 contains 12 objects of type line. These objects represent Feature 1, Feature 2, Feature 3, Feature 4, Feature 5, Feature 6, Feature 7, Feature 8, Feature 9, Feature 10, Feature 11, Feature 12.

パディング用のデータの準備

既定では、学習中に、学習データはミニバッチに分割され、パディングによってシーケンスの長さが揃えられます。過度なパディングは、ネットワーク性能に悪影響を与える可能性があります。

学習プロセスでの過度のパディングを防ぐため、シーケンス長で学習データを並べ替えて、ミニバッチ内のシーケンスが似たような長さになるようにミニバッチのサイズを選択できます。次の図は、データを並べ替える前と後におけるシーケンスのパディングの効果を示しています。

各観測値のシーケンス長を取得します。

numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,2);
end

シーケンス長でデータを並べ替えます。

[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);

並べ替えられたシーケンス長を棒グラフで表示します。

figure
bar(sequenceLengths)
ylim([0 30])
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")

Figure contains an axes object. The axes object with title Sorted Data contains an object of type bar.

学習データを等分し、ミニバッチ内のパディングの量を減らすため、ミニバッチのサイズに 27 を選択します。次の図は、シーケンスに追加されたパディングを示しています。

miniBatchSize = 27;

LSTM ネットワーク アーキテクチャの定義

LSTM ネットワーク アーキテクチャを定義します。サイズ 12 (入力データの次元) のシーケンスになるように入力サイズを指定します。隠れユニットが 100 個の双方向 LSTM 層を指定して、シーケンスの最後の要素を出力します。最後に、サイズが 9 の全結合層を含めることによって 9 個のクラスを指定し、その後にソフトマックス層と分類層を配置します。

予測時にシーケンス全体にアクセスする場合は、ネットワークで双方向 LSTM 層を使用できます。双方向 LSTM 層は、各タイム ステップでシーケンス全体から学習します。予測時にシーケンス全体にアクセスしない場合、たとえば、値を予想していたり一度に 1 タイム ステップを予測していたりする場合は、代わりに LSTM 層を使用します。

inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(inputSize)
    bilstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  5x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 12 dimensions
     2   ''   BiLSTM                  BiLSTM with 100 hidden units
     3   ''   Fully Connected         9 fully connected layer
     4   ''   Softmax                 softmax
     5   ''   Classification Output   crossentropyex

ここで、学習オプションを指定します。ソルバーに 'adam'、勾配のしきい値に 1、エポックの最大数に 100 を指定します。ミニバッチのパディングの量を減らすために、ミニバッチ サイズとして 27 を選択します。データが最長のシーケンスと同じ長さになるようにパディングするために、シーケンスの長さに 'longest' を指定します。データをシーケンス長で並べ替えたままにするために、データをシャッフルしないように指定します。

ミニバッチが小さく、シーケンスが短いため、学習には CPU が適しています。'ExecutionEnvironment''cpu' に指定します。GPU が利用できる場合、GPU で学習を行うには、'ExecutionEnvironment''auto' に設定します (これが既定値です)。

maxEpochs = 100;
miniBatchSize = 27;

options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'GradientThreshold',1, ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'SequenceLength','longest', ...
    'Shuffle','never', ...
    'Verbose',0, ...
    'Plots','training-progress');

LSTM ネットワークの学習

trainNetwork を使用し、指定した学習オプションで LSTM ネットワークに学習させます。

net = trainNetwork(XTrain,YTrain,layers,options);

Figure Training Progress (12-Apr-2022 00:55:03) contains 2 axes objects and another object of type uigridlayout. Axes object 1 contains 14 objects of type patch, text, line. Axes object 2 contains 14 objects of type patch, text, line.

LSTM ネットワークのテスト

テスト セットを読み込み、シーケンスを話者別に分類します。

Japanese Vowels テスト データを読み込みます。XTest は、次元 12 の可変長の 370 個のシーケンスが含まれる cell 配列です。YTest は、9 人の話者に対応するラベル "1"、"2"、...、"9" から成る categorical ベクトルです。

[XTest,YTest] = japaneseVowelsTestData;
XTest(1:3)
ans=3×1 cell array
    {12x19 double}
    {12x17 double}
    {12x19 double}

LSTM ネットワーク net は、類似した長さのシーケンスのミニバッチを使用して学習しました。テスト データも同じ方法で構成されるようにします。シーケンス長でテスト データを並べ替えます。

numObservationsTest = numel(XTest);
for i=1:numObservationsTest
    sequence = XTest{i};
    sequenceLengthsTest(i) = size(sequence,2);
end
[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
YTest = YTest(idx);

テスト データを分類します。分類処理で導入されたパディングの量を減らすために、ミニバッチ サイズを 27 に設定します。学習データと同じパディングを適用するために、シーケンスの長さに 'longest' を指定します。

miniBatchSize = 27;
YPred = classify(net,XTest, ...
    'MiniBatchSize',miniBatchSize, ...
    'SequenceLength','longest');

予測の分類精度を計算します。

acc = sum(YPred == YTest)./numel(YTest)
acc = 0.9757

参照

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

参考

| | | |

関連するトピック