Main Content

1 次元畳み込みを使用したシーケンス分類

この例では、1 次元畳み込みニューラル ネットワークを使用してシーケンス データを分類する方法を説明します。

シーケンス データを分類するよう深層ニューラル ネットワークに学習させるために、1 次元畳み込みニューラル ネットワークを使用できます。1 次元畳み込み層は、1 次元入力にスライディング畳み込みフィルターを適用することにより、特徴を学習します。1 次元畳み込み層を使用すると、畳み込み層が 1 回の操作で入力を処理できるため、再帰層を使用するよりも高速になります。一方、再帰層では入力のタイム ステップを反復処理しなければなりません。ただし、ネットワーク アーキテクチャやフィルター サイズによっては、1 次元畳み込み層が、タイム ステップ間の長期的な依存関係を学習できる再帰層ほどには機能しない可能性があります。

この例では、[1] および [2] に記載のある Japanese Vowels データ セットを使用します。この例では、続けて発音された 2 つの日本語の母音を表す時系列データにおいて、その話者を認識するように、1 次元畳み込みニューラル ネットワークに学習させます。学習データには、9 人の話者の時系列データが含まれています。各シーケンスには 12 個の特徴があり、長さはさまざまです。データセットには 270 個の学習観測値と 370 個のテスト観測値が含まれています。

シーケンス データの読み込み

Japanese Vowels 学習データを読み込みます。予測子データは、12 の特徴をもつ可変長のシーケンスが含まれる cell 配列です。ターゲット データは、9 人の話者に対応するラベル "1"、"2"、...、"9" から成る categorical ベクトルです。予測子シーケンスは行列で、行数が 12 (特徴ごとに 1 行) で、列数が可変 (タイム ステップごとに 1 列) です。

[XTrain,TTrain] = japaneseVowelsTrainData;
[XValidation,TValidation] = japaneseVowelsTestData;

最初のいくつかの学習シーケンスを表示します。

XTrain(1:5)
ans=5×1 cell array
    {12x20 double}
    {12x26 double}
    {12x22 double}
    {12x20 double}
    {12x21 double}

最初の時系列をプロットで可視化します。各ラインは特徴に対応しています。

figure
plot(XTrain{1}')
xlabel("Time Step")
title("Training Observation 1")
numFeatures = size(XTrain{1},1);
legend("Feature " + string(1:numFeatures),Location="northeastoutside")

Figure contains an axes object. The axes object with title Training Observation 1, xlabel Time Step contains 12 objects of type line. These objects represent Feature 1, Feature 2, Feature 3, Feature 4, Feature 5, Feature 6, Feature 7, Feature 8, Feature 9, Feature 10, Feature 11, Feature 12.

学習データ内のクラスの数を表示します。

classes = categories(TTrain);
numClasses = numel(classes)
numClasses = 9

1 次元畳み込みネットワーク アーキテクチャの定義

1 次元畳み込みニューラル ネットワーク アーキテクチャを定義します。

  • 入力サイズを入力データの特徴の数として指定する。

  • 畳み込み層のフィルター サイズが 3 である 1 次元畳み込み層、ReLU 層、およびレイヤー正規化層から成るブロックを 2 つ指定する。32 個のフィルターと 64 個のフィルターを最初と 2 番目の畳み込み層にそれぞれ指定する。どちらの畳み込み層に対しても、出力の長さが同じになるように入力を左パディングする (因果的パディング)。

  • 畳み込み層の出力を単一のベクトルに減らすために、1 次元グローバル平均プーリング層を使用する。

  • 出力を確率のベクトルにマッピングするために、出力サイズとクラスの数が一致する全結合層を指定し、その後にソフトマックス層と分類層を配置する。

filterSize = 3;
numFilters = 32;

layers = [ ...
    sequenceInputLayer(numFeatures)
    convolution1dLayer(filterSize,numFilters,Padding="causal")
    reluLayer
    layerNormalizationLayer
    convolution1dLayer(filterSize,2*numFilters,Padding="causal")
    reluLayer
    layerNormalizationLayer
    globalAveragePooling1dLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

学習オプションの指定

学習オプションを指定します。

  • Adam オプティマイザーを使用して学習させる。

  • ミニバッチ サイズを 27 として、学習を 15 エポック行う。

  • シーケンスを左パディングする。

  • 検証データを使用してネットワークを検証します。

  • プロットに表示される学習の進行状況を監視し、詳細出力を非表示にする。

miniBatchSize = 27;

options = trainingOptions("adam", ...
    MiniBatchSize=miniBatchSize, ...
    MaxEpochs=15, ...
    SequencePaddingDirection="left", ...
    ValidationData={XValidation,TValidation}, ...
    Plots="training-progress", ...
    Verbose=0);

ネットワークの学習

関数 trainNetwork を使用し、指定した学習オプションでネットワークに学習させます。

net = trainNetwork(XTrain,TTrain,layers,options);

Figure Training Progress (03-Mar-2023 09:03:23) contains 2 axes objects and another object of type uigridlayout. Axes object 1 with xlabel Iteration, ylabel Loss contains 8 objects of type patch, text, line. Axes object 2 with xlabel Iteration, ylabel Accuracy (%) contains 8 objects of type patch, text, line.

ネットワークのテスト

学習に使用したものと同じミニバッチ サイズ、シーケンスのパディングのオプションを使用して、検証データを分類します。

YPred = classify(net,XValidation, ...
    MiniBatchSize=miniBatchSize, ...
    SequencePaddingDirection="left");

予測の分類精度を計算します。

acc = mean(YPred == TValidation)
acc = 0.9514

混同行列で予測を可視化します。

confusionchart(TValidation,YPred)

Figure contains an object of type ConfusionMatrixChart.

参考文献

[1] Kudo, Mineichi, Jun Toyama, and Masaru Shimbo.“Multidimensional Curve Classification Using Passing-through Regions.” Pattern Recognition Letters 20, no. 11–13 (November 1999): 1103–11. https://doi.org/10.1016/S0167-8655(99)00077-X

[2] Kudo, Mineichi, Jun Toyama, and Masaru Shimbo."Japanese Vowels Data Set." Distributed by UCI Machine Learning Repository. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

参考

| | | | | | |

関連するトピック