Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

深層学習を使用した sequence-to-one 回帰

この例では、長短期記憶 (LSTM) ニューラル ネットワークを使用して波形の周波数を予測する方法を説明します。

LSTM ニューラル ネットワークを使用すると、シーケンスとターゲット値から成る学習セットを使用して、シーケンスの数値応答を予測できます。LSTM ネットワークは、タイム ステップでループ処理してネットワークの状態を更新することにより入力データを処理する再帰型ニューラル ネットワーク (RNN) です。ネットワークの状態には、前のタイム ステップで記憶された情報が含まれています。シーケンスの数値応答の例には、次のものがあります。

  • 周波数、最大値、平均など、シーケンスのプロパティ。

  • シーケンスの過去または未来のタイム ステップの値。

この例では、波形データ セットを使用して sequence-to-one 回帰 LSTM ネットワークに学習させます。このデータ セットには、3 つのチャネルの異なる波長の合成生成波形が 1,000 個含まれます。従来の手法で波形の周波数を判定する方法については、fftを参照してください。

シーケンス データの読み込み

サンプル データを WaveformData.mat から読み込みます。データは、numObservations 行 1 列のシーケンスの cell 配列です。ここで、numObservations はシーケンスの数です。各シーケンスは numChannelsnumTimeSteps 列の数値配列です。ここで、numChannels はシーケンスのチャネル数、numTimeSteps はシーケンスに含まれるタイム ステップ数です。対応するターゲットは、波形の周波数から成る numObservationsnumResponses 列の数値配列です。ここで、numResponses はターゲットのチャネル数です。

load WaveformData

観測値の数を表示します。

numObservations = numel(data)
numObservations = 1000

最初のいくつかのシーケンスのサイズ、および対応する周波数を表示します。

data(1:4)
ans=4×1 cell array
    {3×103 double}
    {3×136 double}
    {3×140 double}
    {3×124 double}

freq(1:4,:)
ans = 4×1

    5.8922
    2.2557
    4.5250
    4.4418

シーケンスのチャネル数を表示します。ネットワークに学習させるには、各シーケンスに同じ数のチャネルが含まれていなければなりません。

numChannels = size(data{1},1)
numChannels = 3

応答の数 (ターゲットのチャネル数) を表示します。

numResponses = size(freq,2)
numResponses = 1

最初のいくつかのシーケンスをプロットに可視化します。

figure
tiledlayout(2,2)
for i = 1:4
    nexttile
    stackedplot(data{i}', DisplayLabels="Channel " + (1:numChannels))

    xlabel("Time Step")
    title("Frequency: " + freq(i))
end

学習用データの準備

検証用とテスト用のデータを残しておきます。データの 80% を含む学習セット、データの 10% を含む検証セット、およびデータの残りの 10% を含むテスト セットにデータを分割します。

[idxTrain,idxValidation,idxTest] = trainingPartitions(numObservations, [0.8 0.1 0.1]);

XTrain = data(idxTrain);
XValidation = data(idxValidation);
XTest = data(idxTest);

TTrain = freq(idxTrain);
TValidation = freq(idxValidation);
TTest = freq(idxTest);

LSTM ネットワーク アーキテクチャの定義

LSTM 回帰ネットワークを作成します。

  • 入力データのチャネル数と一致する入力サイズのシーケンス入力層を使用します。

  • 良好な適合を実現し、学習の発散を防ぐには、シーケンス入力層の Normalization オプションを "zscore" に設定します。これにより、ゼロ平均と単位分散をもつようにシーケンス データが正規化されます。

  • 100 個の隠れユニットをもつ LSTM 層を使用します。隠れユニットの数によって、層に学習させる情報量が決まります。より大きな値を使用するほど正確な結果が得られますが、学習データに過適合しやすくなる可能性があります。

  • 各シーケンスについて 1 つのタイム ステップを出力するには、LSTM 層の OutputMode オプションを "last" に設定します。

  • 予測する値の数を指定するには、予測子の数に一致するサイズの全結合層と、それに続く回帰層を含めます。

numHiddenUnits = 100;

layers = [ ...
    sequenceInputLayer(numChannels, Normalization="zscore")
    lstmLayer(numHiddenUnits, OutputMode="last")
    fullyConnectedLayer(numResponses)
    regressionLayer]
layers = 
  4×1 Layer array with layers:

     1   ''   Sequence Input      Sequence input with 3 dimensions
     2   ''   LSTM                LSTM with 100 hidden units
     3   ''   Fully Connected     1 fully connected layer
     4   ''   Regression Output   mean-squared-error

学習オプションの指定

学習オプションを指定します。

  • Adam オプティマイザーを使用して学習させる。

  • 学習を 250 エポック行う。より大きなデータ セットでは、良好な適合を実現させるために多くのエポックを学習させる必要がない場合があります。

  • 検証に使用するシーケンスと応答を指定する。

  • 検証損失が最良の (最も少ない) ネットワークを出力する。

  • 学習率を 0.005 に設定する。

  • 各ミニバッチで、最短のシーケンスと同じ長さになるようにシーケンスを切り捨てる。シーケンスを切り捨てると、パディングが必ず追加されないようにすることができますが、データが破棄されます。シーケンスに含まれるすべてのタイム ステップに重要な情報が含まれている可能性があるシーケンスの場合、切り捨てによってネットワークの良好な適合が妨げられる可能性があります。

  • 学習プロセスをプロットに表示する。

  • 詳細出力を無効にします。

options = trainingOptions("adam", ...
    MaxEpochs=250, ...
    ValidationData={XValidation TValidation}, ...
    OutputNetwork="best-validation-loss", ...
    InitialLearnRate=0.005, ...
    SequenceLength="shortest", ...
    Plots="training-progress", ...
    Verbose= false);

LSTM ネットワークの学習

関数 trainNetwork を使用し、指定した学習オプションで LSTM ネットワークに学習させます。

net = trainNetwork(XTrain, TTrain, layers, options);

LSTM ネットワークのテスト

テスト データを使用して、予測を実行します。

YTest = predict(net,XTest, SequenceLength="shortest");

最初のいくつかの予測をプロットに可視化します。

figure
tiledlayout(2,2)
for i = 1:4
    nexttile
    stackedplot(XTest{i}',DisplayLabels="Channel " + (1:numChannels))

    xlabel("Time Step")
    title("Predicted Frequency: " + string(YTest(i)))
end

ヒストグラムで平均二乗誤差を可視化します。

figure
histogram(mean((TTest - YTest).^2,2))
xlabel("Error")
ylabel("Frequency")

全体の平方根平均二乗誤差を計算します。

rmse = sqrt(mean((YTest-TTest).^2))
rmse = single
    0.6865

真の周波数に対する予測された周波数をプロットします。

figure
scatter(YTest,TTest, "b+");
xlabel("Predicted Frequency")
ylabel("Actual Frequency")
hold on

m = min(freq);
M=max(freq);
xlim([m M])
ylim([m M])
plot([m M], [m M], "r--")

参考

| | | |

関連するトピック