Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

深層学習を使用したシーケンスの分類

この例では、長短期記憶 (LSTM) ネットワークを使用してシーケンス データを分類する方法を説明します。

シーケンス データを分類するよう深層ニューラル ネットワークに学習させるために、LSTM ネットワークを使用できます。LSTM ネットワークでは、シーケンス データをネットワークに入力し、シーケンス データの個々のタイム ステップに基づいて予測を行うことができます。

この例では、波形データ セットを使用します。この例では、与えられた時系列データについて波形のタイプを認識するように LSTM ネットワークに学習させます。学習データには 4 種類の波形の時系列データが含まれています。各シーケンスには 3 つのチャネルがあり、長さはさまざまです。

シーケンス データの読み込み

サンプル データを WaveformData から読み込みます。シーケンス データは、シーケンスの numObservations 行 1 列の cell 配列です。ここで、numObservations はシーケンスの数です。各シーケンスは numChannels-numTimeSteps 列の数値配列です。ここで、numChannels はシーケンスのチャネル数、numTimeSteps はシーケンスのタイム ステップ数です。ラベル データは、numObservations1 列の categorical ベクトルです。

load WaveformData 

シーケンスの一部をプロットで可視化します。

numChannels = size(data{1},1);

idx = [3 4 5 12];
figure
tiledlayout(2,2)
for i = 1:4
    nexttile
    stackedplot(data{idx(i)}',DisplayLabels="Channel "+string(1:numChannels))
    
    xlabel("Time Step")
    title("Class: " + string(labels(idx(i))))
end

テスト用のデータを確保します。データの 90% から成る学習セットとデータの残りの 10% から成るテスト セットにデータを分割します。データを分割するには、この例にサポート ファイルとして添付されている関数 trainingPartitions を使用します。このファイルにアクセスするには、例をライブ スクリプトとして開きます。

numObservations = numel(data);
[idxTrain,idxTest] = trainingPartitions(numObservations, [0.9 0.1]);
XTrain = data(idxTrain);
TTrain = labels(idxTrain);

XTest = data(idxTest);
TTest = labels(idxTest);

パディング用のデータの準備

既定では、学習中に、学習データはミニバッチに分割され、パディングによってシーケンスの長さが揃えられます。過度なパディングは、ネットワーク性能に悪影響を与える可能性があります。

学習プロセスでの過度のパディングを防ぐため、シーケンス長で学習データを並べ替えて、ミニバッチ内のシーケンスが似たような長さになるようにミニバッチのサイズを選択できます。次の図は、データを並べ替える前と後におけるシーケンスのパディングの効果を示しています。

各観測値のシーケンス長を取得します。

numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,2);
end

シーケンス長でデータを並べ替えます。

[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
TTrain = TTrain(idx);

並べ替えられたシーケンス長を棒グラフで表示します。

figure
bar(sequenceLengths)
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")

LSTM ネットワーク アーキテクチャの定義

LSTM ネットワーク アーキテクチャを定義します。入力データのチャネル数になるように入力サイズを指定します。隠れユニットが 120 個の双方向 LSTM 層を指定して、シーケンスの最後の要素を出力します。最後に、クラス数と一致する出力サイズをもつ全結合層を含め、その後にソフトマックス層と分類層を含めます。

予測時にシーケンス全体にアクセスする場合は、ネットワークで双方向 LSTM 層を使用できます。双方向 LSTM 層は、各タイム ステップでシーケンス全体から学習します。予測時にシーケンス全体にアクセスしない場合、たとえば、値を予想していたり一度に 1 タイム ステップを予測していたりする場合は、代わりに LSTM 層を使用します。

numHiddenUnits = 120;
numClasses = 4;

layers = [ ...
    sequenceInputLayer(numChannels)
    bilstmLayer(numHiddenUnits,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  5×1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 3 dimensions
     2   ''   BiLSTM                  BiLSTM with 120 hidden units
     3   ''   Fully Connected         4 fully connected layer
     4   ''   Softmax                 softmax
     5   ''   Classification Output   crossentropyex

学習オプションを指定します。学習率 0.002、勾配しきい値 1 で Adam ソルバーを使用して学習させます。エポックの最大数を 150 に設定し、シャッフルを無効にします。既定では、ソフトウェアは GPU が利用できる場合に GPU で学習を行います。GPU を使用するには、Parallel Computing Toolbox とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。

options = trainingOptions("adam", ...
    InitialLearnRate=0.002,...
    MaxEpochs=150, ...
    Shuffle="never", ...
    GradientThreshold=1, ...
    Verbose=false, ...
    Plots="training-progress");

LSTM ネットワークの学習

trainNetwork を使用し、指定した学習オプションで LSTM ネットワークに学習させます。

net = trainNetwork(XTrain,TTrain,layers,options);

LSTM ネットワークのテスト

テスト データを分類し、予測の分類精度を計算します。

XTest(1:3)
ans=3×1 cell array
    {3×127 double}
    {3×180 double}
    {3×193 double}

LSTM ネットワーク net は、類似した長さのシーケンスのミニバッチを使用して学習しました。テスト データも同じ方法で構成されるようにします。シーケンス長でテスト データを並べ替えます。

numObservationsTest = numel(XTest);
for i=1:numObservationsTest
    sequence = XTest{i};
    sequenceLengthsTest(i) = size(sequence,2);
end

[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
TTest = TTest(idx);

テスト データを分類し、予測の分類精度を計算します。

YTest = classify(net,XTest);
acc = mean(YTest == TTest)
acc = 0.8400

分類結果を混同チャートで表示します。

figure
confusionchart(TTest,YTest)

参考

| | | |

関連するトピック