メインコンテンツ

表形式データを使用したニューラル ネットワークの学習

R2023b 以降

この例では、表形式のデータを使用してニューラル ネットワークに学習させる方法を示します。

数値特徴量およびカテゴリカル特徴量のデータ セット (空間次元や時間次元のない表形式データなど) がある場合、特徴入力層を使用して深層ニューラル ネットワークに学習させることができます。この例では、数値形式およびカテゴリカル形式のセンサー読み取り値から成るテーブルに基づいて、ギアの歯の状態を予測するニューラル ネットワークに学習させます。

学習データの読み込み

CSV ファイル "transmissionCasingData.csv" からトランスミッション ケーシング データを読み取ります。

filename = "transmissionCasingData.csv";
tbl = readtable(filename,TextType="String");

関数 convertvars を使用して、予測のラベルを categorical に変換します。

labelName = "GearToothCondition";
tbl = convertvars(tbl,labelName,"categorical");

データ セットのクラス名を表示します。

classNames = categories(tbl.(labelName))
classNames = 2×1 cell
    {'No Tooth Fault'}
    {'Tooth Fault'   }

カテゴリカル特徴量を使用してネットワークに学習させるには、最初にカテゴリカル特徴量を categorical データ型に変換しなければなりません。convertvars 関数を使用して、すべてのカテゴリカル入力変数の名前を格納した string 配列を指定することにより、カテゴリカル予測子を categorical に変換します。このデータ セットには、"SensorCondition""ShaftCondition" という名前の 2 つのカテゴリカル特徴量があります。

categoricalPredictorNames = ["SensorCondition" "ShaftCondition"];
tbl = convertvars(tbl,categoricalPredictorNames,"categorical");

テスト用のデータを確保します。データの 80% を含む学習セット、データの 10% を含む検証セット、およびデータの残りの 10% を含むテスト セットにデータを分割します。データを分割するには、この例にサポート ファイルとして添付されている関数 trainingPartitions を使用します。このファイルにアクセスするには、例をライブ スクリプトとして開きます。

numObservations = size(tbl,1);
[idxTrain,idxValidation,idxTest] = trainingPartitions(numObservations,[0.80 0.1 0.1]);

tblTrain = tbl(idxTrain,:);
tblValidation = tbl(idxValidation,:);
tblTest = tbl(idxTest,:);

ニューラル ネットワーク アーキテクチャの定義

この例では、カテゴリカル入力特徴量を one-hot 符号化して、その特徴量をニューラル ネットワークに学習させます。ニューラル ネットワークの入力サイズを指定するには、one-hot 符号化された categorical データを含む入力特徴量の数を計算します。特徴量の数は、学習データの数値列の数に、カテゴリカル予測子のカテゴリの合計数を足した数です。

numCategoricalPredictors = numel(categoricalPredictorNames);
numFeatures = size(tblTrain,2) - numCategoricalPredictors - 1;

for name = categoricalPredictorNames
    numCategories = numel(categories(tblTrain.(name)));
    numFeatures = numFeatures + numCategories;
end

ニューラル ネットワーク アーキテクチャを定義します。

  • 特徴入力用に、特徴量の数と一致する入力サイズをもつ特徴入力層を指定します。

  • サイズが 16 の全結合層を指定し、その後に正規化層と ReLU 層を指定します。

  • 分類出力用に、クラス数と一致するサイズをもつ全結合層を指定し、その後にソフトマックス層を指定します。

hiddenSize = 16;

numClasses = numel(classNames);

layers = [
    featureInputLayer(numFeatures)
    fullyConnectedLayer(hiddenSize)
    layerNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];

学習オプションの指定

学習オプションを指定します。

  • L-BFGS ソルバーを使用して学習させます。このソルバーは、ネットワークが小さくデータがメモリに収まるタスクに適しています。

  • CPU を使用して学習させます。ネットワークとデータが小さいため、CPU の方がより適しています。

  • カテゴリカル入力を one-hot 符号化します。

  • 検証データを使用して、5 回の反復ごとにネットワークを検証します。

  • 検証損失が最小のネットワークを返します。

  • 学習の進行状況をプロットで表示し、精度メトリクスを監視します。

  • 詳細出力を非表示にします。

options = trainingOptions("lbfgs", ...
    ExecutionEnvironment="cpu", ...
    CategoricalInputEncoding="one-hot", ...
    ValidationData=tblValidation, ...
    ValidationFrequency=5, ...
    OutputNetwork="best-validation", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

ニューラル ネットワークの学習

関数 trainnet を使用してニューラル ネットワークに学習させます。分類には、クロスエントロピー損失を使用します。

[net,info] = trainnet(tblTrain,layers,"crossentropy",options);

このプロットは、学習と検証の精度と損失を示しています。学習が完了すると、プロットに停止の理由が表示されます。L-BFGS ソルバーを使用する場合、直線探索が失敗して適切な学習率を見つけられなかったことが停止の理由として示される場合があります。このシナリオは、ソルバーが最小損失値に素早く到達した場合、またはステップや勾配ノルムがゼロに近い場合に発生する可能性があります。

ニューラル ネットワークのテスト

学習済みネットワークを使用してテスト データのラベルを予測します。学習済みネットワークを使用して分類スコアを予測し、関数 onehotdecode を使用して予測結果をラベルに変換します。

testnet関数を使用してニューラル ネットワークをテストします。

  • 単一ラベルの分類では、精度を評価します。精度は、正しい予測の割合です。

  • カテゴリカル入力を one-hot 符号化します。

  • CPU を使用してニューラル ネットワークをテストします。

accuracy = testnet(net,tblTest,"accuracy", ...
    CategoricalInputEncoding="one-hot", ...
    ExecutionEnvironment="cpu")
accuracy = 
86.3636

ターゲットをテスト データから分離し、予測を行って、スコアをラベルに変換することで、混同チャートで予測を可視化します。

ターゲットをデータから分離します。

TTest = tblTest.(labelName);

minibatchpredict関数を使用して予測を行い、scores2label関数を使用して分類スコアをラベルに変換します。

  • カテゴリカル入力を one-hot 符号化します。

  • CPU を使用して予測を行います。

scoresTest = minibatchpredict(net,tblTest, ...
    CategoricalInputEncoding="one-hot", ...
    ExecutionEnvironment="cpu");

YTest = scores2label(scoresTest,classNames);

混同チャートで予測を可視化します。

confusionchart(TTest,YTest)

Figure contains an object of type ConfusionMatrixChart.

参考

| | | |

トピック