Main Content

数値特徴量を使用したネットワークの学習

この例では、深層学習による特徴データ分類用のシンプルなニューラル ネットワークを作成し、学習を行う方法を説明します。

数値特徴量のデータ セット (空間次元や時間次元のない数値データの集合など) がある場合、特徴入力層を使用して深層学習ネットワークに学習させることができます。イメージ分類用のネットワークに学習させる方法を示す例については、分類用のシンプルな深層学習ネットワークの作成を参照してください。

この例では、数値センサーの読み取り値、統計量、カテゴリカル ラベルの混合を所与として、トランスミッション システムの歯車の状態を分類するよう、ネットワークに学習させる方法について説明します。

データの読み込み

学習用のトランスミッション ケーシング データセットを読み込みます。このデータ セットには、以下のような 18 種類の読み取り値と 3 種類のカテゴリカル ラベルで構成される、トランスミッション システムの 208 個の合成読み取り値が含まれています。

  1. SigMean — 振動信号の平均

  2. SigMedian — 振動信号の中央値

  3. SigRMS — 振動信号の RMS

  4. SigVar — 振動信号の分散

  5. SigPeak — 振動信号のピーク

  6. SigPeak2Peak — 振動信号のピーク ツー ピーク

  7. SigSkewness — 振動信号の歪度

  8. SigKurtosis — 振動信号の尖度

  9. SigCrestFactor — 振動信号のクレスト ファクター

  10. SigMAD — 振動信号の MAD

  11. SigRangeCumSum — 振動信号の範囲累積和

  12. SigCorrDimension — 振動信号の相関次元

  13. SigApproxEntropy — 振動信号の近似エントロピー

  14. SigLyapExponent — 振動信号のリアプノフ指数

  15. PeakFreq — ピーク周波数。

  16. HighFreqPower — 高周波のべき乗

  17. EnvPower — 環境のべき乗

  18. PeakSpecKurtosis — スペクトル尖度のピーク周波数

  19. SensorCondition — センサーの状態。"Sensor Drift" または "No Sensor Drift" として指定します。

  20. ShaftCondition — シャフトの状態。"Shaft Wear" または "No Shaft Wear" として指定します。

  21. GearToothCondition — 歯車の状態。"Tooth Fault" または "No Tooth Fault" として指定します。

CSV ファイル "transmissionCasingData.csv" からトランスミッション ケーシング データを読み取ります。

filename = "transmissionCasingData.csv";
tbl = readtable(filename,'TextType','String');

関数 convertvars を使用して、予測のラベルを categorical に変換します。

labelName = "GearToothCondition";
tbl = convertvars(tbl,labelName,'categorical');

table の最初の数行を表示します。

head(tbl)
    SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    SensorCondition    ShaftCondition     GearToothCondition
    ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    _______________    __________________

    -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13         "Sensor Drift"     "No Shaft Wear"      No Tooth Fault  
    -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12         "Sensor Drift"     "No Shaft Wear"      No Tooth Fault  
      1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  

カテゴリカル特徴量を使用してネットワークに学習させるには、最初にカテゴリカル特徴量を数値に変換しなければなりません。まず、関数 convertvars を使用して、すべてのカテゴリカル入力変数の名前を格納した string 配列を指定することにより、カテゴリカル予測子を categorical に変換します。このデータ セットには、"SensorCondition""ShaftCondition" という名前の 2 つのカテゴリカル特徴量があります。

categoricalInputNames = ["SensorCondition" "ShaftCondition"];
tbl = convertvars(tbl,categoricalInputNames,'categorical');

カテゴリカル入力変数をループ処理します。各変数について次を行います。

  • 関数 onehotencode を使用して、categorical 値を one-hot 符号化ベクトルに変換する。

  • 関数 addvars を使用して、one-hot ベクトルを table に追加する。対応する categorical データが含まれる列の後にベクトルを挿入するように指定する。

  • categorical データが含まれる対応する列を削除する。

for i = 1:numel(categoricalInputNames)
    name = categoricalInputNames(i);
    oh = onehotencode(tbl(:,name));
    tbl = addvars(tbl,oh,'After',name);
    tbl(:,name) = [];
end

関数 splitvars を使用して、ベクトルを別々の列に分割します。

tbl = splitvars(tbl);

table の最初の数行を表示します。カテゴリカル予測子が、categorical 値を変数名として複数の列に分割されていることに注意してください。

head(tbl)
    SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    No Sensor Drift    Sensor Drift    No Shaft Wear    Shaft Wear    GearToothCondition
    ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    ____________    _____________    __________    __________________

    -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13                0                1                1              0           No Tooth Fault  
    -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12                0                1                1              0           No Tooth Fault  
      1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39                0                1                0              1           No Tooth Fault  

データ セットのクラス名を表示します。

classNames = categories(tbl{:,labelName})
classNames = 2x1 cell
    {'No Tooth Fault'}
    {'Tooth Fault'   }

学習セットと検証セットへのデータ セットの分割

データ セットを学習用、検証用、およびテスト用の区画に分割します。データの 15% を検証用に、15% をテスト用に残しておきます。

データセットに含まれている観測値の数を表示します。

numObservations = size(tbl,1)
numObservations = 208

各区画の観測数を求めます。

numObservationsTrain = floor(0.7*numObservations)
numObservationsTrain = 145
numObservationsValidation = floor(0.15*numObservations)
numObservationsValidation = 31
numObservationsTest = numObservations - numObservationsTrain - numObservationsValidation
numObservationsTest = 32

観測値に対応するランダムなインデックスの配列を作成し、区画サイズを使用して分割します。

idx = randperm(numObservations);
idxTrain = idx(1:numObservationsTrain);
idxValidation = idx(numObservationsTrain+1:numObservationsTrain+numObservationsValidation);
idxTest = idx(numObservationsTrain+numObservationsValidation+1:end);

インデックスを使用して、データの table を学習用、検証用、およびテスト用の区画に分割します。

tblTrain = tbl(idxTrain,:);
tblValidation = tbl(idxValidation,:);
tblTest = tbl(idxTest,:);

ネットワーク アーキテクチャの定義

分類のためにネットワークを定義します。

特徴入力層を使用してネットワークを定義し、特徴の数を指定します。また、z スコア正規化を使用してデータを正規化するように入力層を構成します。次に、出力サイズが 50 の全結合層を含め、その後にバッチ正規化層と ReLU 層を配置します。分類に向けて、出力サイズがクラスの数に対応する別の全結合層を指定し、その後にソフトマックス層と分類層を配置します。

numFeatures = size(tbl,2) - 1;
numClasses = numel(classNames);
 
layers = [
    featureInputLayer(numFeatures,'Normalization', 'zscore')
    fullyConnectedLayer(50)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

学習オプションの指定

学習オプションを指定します。

  • Adam を使用してネットワークに学習させます。

  • サイズ 16 のミニバッチを使用して学習を行います。

  • すべてのエポックでデータをシャッフルします。

  • 検証データを指定して、学習中にネットワークの精度を監視します。

  • プロットに学習の進行状況を表示し、詳細なコマンド ウィンドウ出力を表示しないようにします。

学習データでネットワークに学習させ、学習中に一定の間隔で検証データに対してその精度を計算します。検証データは、ネットワークの重みの更新には使用されません。

miniBatchSize = 16;

options = trainingOptions('adam', ...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','every-epoch', ...
    'ValidationData',tblValidation, ...
    'Plots','training-progress', ...
    'Verbose',false);

ネットワークの学習

layers、学習データ、および学習オプションによって定義されたアーキテクチャを使用して、ネットワークに学習させます。既定で、trainNetwork は、使用可能な GPU があれば GPU を使用し、なければ CPU を使用します。GPU で学習を行うには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。trainingOptions の名前と値のペアの引数 'ExecutionEnvironment' を使用して、実行環境を指定することもできます。

学習の進行状況プロットには、ミニバッチの損失と精度、および検証の損失と精度が表示されます。学習の進行状況プロットの詳細は、深層学習における学習の進行状況の監視を参照してください。

net = trainNetwork(tblTrain,labelName,layers,options);

{"String":"Figure Training Progress (12-Aug-2022 18:34:28) contains 2 axes objects and another object of type uigridlayout. Axes object 1 contains 10 objects of type patch, text, line. Axes object 2 contains 10 objects of type patch, text, line.","Tex":[],"LaTex":[]}

ネットワークのテスト

学習済みネットワークを使用してテスト データのラベルを予測し、精度を計算します。学習に使用されるサイズと同じミニバッチ サイズを指定します。

YPred = classify(net,tblTest(:,1:end-1),'MiniBatchSize',miniBatchSize);

分類精度を計算します。精度は、ネットワークが正しく予測するラベルの比率です。

YTest = tblTest{:,labelName};
accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9688

結果を混同行列で表示します。

figure
confusionchart(YTest,YPred)

Figure contains an object of type ConfusionMatrixChart.

参考

| | | |

関連する例

詳細