Main Content

回帰ニューラル ネットワークの性能評価

fitrnet を使用して全結合層をもつフィードフォワード回帰ニューラル ネットワーク モデルを作成します。モデルの過適合を防止するために、検証データを使用して学習プロセスを早期に停止します。その後、モデルのオブジェクト関数を使用して検定データで性能を評価します。

標本データの読み込み

carbig データセットを読み込みます。このデータセットには、1970 年代と 1980 年代初期に製造された自動車の測定値が格納されています。

load carbig

変数 Origin をカテゴリカル変数に変換します。その後、AccelerationDisplacement などの予測子変数と応答変数 MPG を格納する table を作成します。各行に 1 台の自動車の測定値を格納します。欠損値がある table の行を table から削除します。

Origin = categorical(cellstr(Origin));
Tbl = table(Acceleration,Displacement,Horsepower, ...
    Model_Year,Origin,Weight,MPG);
Tbl = rmmissing(Tbl);

データの分割

データを学習セット、検証セット、および検定セットに分割します。まず、観測値の約 3 分の 1 を検定セット用に予約します。その後、残りのデータを半分に分割して学習セットと検証セットを作成します。

rng("default") % For reproducibility of the data partitions
cvp1 = cvpartition(size(Tbl,1),"Holdout",1/3);
testTbl = Tbl(test(cvp1),:);
remainingTbl = Tbl(training(cvp1),:);

cvp2 = cvpartition(size(remainingTbl,1),"Holdout",1/2);
validationTbl = remainingTbl(test(cvp2),:);
trainTbl = remainingTbl(training(cvp2),:);

ニューラル ネットワークの学習

学習セットを使用して回帰ニューラル ネットワーク モデルに学習させます。tblTrain の列 MPG を応答変数として指定し、数値予測子を標準化します。検証セットを使用して各反復でモデルを評価します。名前と値の引数 Verbose を使用して、各反復で学習データを表示するように指定します。既定では、検証損失が 6 回連続でそれまでに計算された検証損失の最小値以上になると、その時点で学習プロセスが早期に終了します。検証損失が最小値以上になる許容回数を変更するには、名前と値の引数 ValidationPatience を指定します。

Mdl = fitrnet(trainTbl,"MPG","Standardize",true, ...
    "ValidationData",validationTbl, ...
    "Verbose",1);
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|           1|  102.962345|   46.853164|    6.700877|    0.025382|  115.730384|           0|
|           2|   55.403995|   22.171181|    1.811805|    0.008526|   53.086379|           0|
|           3|   37.588848|   11.135231|    0.782861|    0.002037|   38.580002|           0|
|           4|   29.713458|    8.379231|    0.392009|    0.000646|   31.021379|           0|
|           5|   17.523851|    9.958164|    2.137584|    0.001807|   17.594863|           0|
|           6|   12.700624|    2.957771|    0.744551|    0.000633|   14.209019|           0|
|           7|   11.841152|    1.907378|    0.201770|    0.000737|   13.159899|           0|
|           8|   10.162988|    2.542555|    0.576907|    0.000708|   11.352490|           0|
|           9|    8.889095|    2.779980|    0.615716|    0.000655|   10.446334|           0|
|          10|    7.670335|    2.400272|    0.648711|    0.000596|   10.424337|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          11|    7.416274|    0.505111|    0.214707|    0.005149|   10.522517|           1|
|          12|    7.338923|    0.880655|    0.119085|    0.015136|   10.648031|           2|
|          13|    7.149407|    1.784821|    0.277908|    0.003024|   10.800952|           3|
|          14|    6.866385|    1.904480|    0.472190|    0.002526|   10.839202|           4|
|          15|    6.815575|    3.339285|    0.943063|    0.001086|   10.031692|           0|
|          16|    6.428137|    0.684771|    0.133729|    0.000924|    9.867819|           0|
|          17|    6.363299|    0.456606|    0.125363|    0.000946|    9.720076|           0|
|          18|    6.289887|    0.742923|    0.152290|    0.000713|    9.576588|           0|
|          19|    6.215407|    0.964684|    0.183503|    0.000685|    9.422910|           0|
|          20|    6.078333|    2.124971|    0.566948|    0.000708|    9.599573|           1|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          21|    5.947923|    1.217291|    0.583867|    0.000696|    9.618400|           2|
|          22|    5.855505|    0.671774|    0.285123|    0.000807|    9.734680|           3|
|          23|    5.831802|    1.882061|    0.657368|    0.000697|   10.365968|           4|
|          24|    5.713261|    1.004072|    0.134719|    0.000744|   10.314258|           5|
|          25|    5.520766|    0.967032|    0.290156|    0.000721|   10.177322|           6|
|==========================================================================================|

オブジェクト MdlTrainingHistory プロパティ内の情報を使用して、検証の平均二乗誤差 (MSE) が最小になる対応する反復を確認します。最終的に返されるモデル Mdl は、この反復で学習させたモデルになります。

iteration = Mdl.TrainingHistory.Iteration;
valLosses = Mdl.TrainingHistory.ValidationLoss;
[~,minIdx] = min(valLosses);
iteration(minIdx)
ans = 19

検定セットのパフォーマンスの評価

オブジェクト関数 loss および predict を使用して、学習させたモデル Mdl の性能を検定セット testTbl で評価します。

検定セットの平均二乗誤差 (MSE) を計算します。MSE の値が小さいほど、パフォーマンスが優れていることを示します。

mse = loss(Mdl,testTbl,"MPG")
mse = 7.4101

検定セットの予測応答値と実際の応答値を比較します。予測されるガロンあたりの走行マイル数 (MPG) を縦軸に、実際の MPG を横軸にしてプロットします。基準線上にある点は予測が正しいことを示します。優れたモデルでは、生成された予測が線の近くに分布します。

predictedY = predict(Mdl,testTbl);

plot(testTbl.MPG,predictedY,".")
hold on
plot(testTbl.MPG,testTbl.MPG)
hold off
xlabel("True Miles Per Gallon (MPG)")
ylabel("Predicted Miles Per Gallon (MPG)")

Figure contains an axes object. The axes object contains 2 objects of type line.

箱ひげ図を使用して、生産国別に MPG の予測値と実際の値の分布を比較します。関数 boxchart を使用して箱ひげ図を作成します。各箱ひげ図には、中央値、第 1 四分位数と第 3 四分位数、外れ値 (四分位数間範囲を使用して計算)、および外れ値ではない最小値と最大値を表示します。特に、各ボックスの内側の線は標本の中央値であり、円形のマーカーは外れ値を示します。

各生産国について、赤色の箱ひげ図 (MPG の予測値の分布を示す) と青色の箱ひげ図 (MPG の実際の値の分布を示す) を比較します。MPG の予測値と実際の値の分布が似ていれば、予測が優れていることを示します。

boxchart(testTbl.Origin,testTbl.MPG)
hold on
boxchart(testTbl.Origin,predictedY)
hold off
legend(["True MPG","Predicted MPG"])
xlabel("Country of Origin")
ylabel("Miles Per Gallon (MPG)")

Figure contains an axes object. The axes object contains 2 objects of type boxchart. These objects represent True MPG, Predicted MPG.

ほとんどの国については、MPG の予測値と実際の値が同じような分布になっています。一部に相違があるのは、学習セットと検定セットで自動車の数が少ないことが原因と考えられます。

学習セットと検定セットで、自動車についての MPG の値の範囲を比較します。

trainSummary = grpstats(trainTbl(:,["MPG","Origin"]),"Origin", ...
    "range")
trainSummary=6×3 table
               Origin     GroupCount    range_MPG
               _______    __________    _________

    France     France          2           1.2   
    Germany    Germany        12          23.4   
    Italy      Italy           1             0   
    Japan      Japan          26          26.6   
    Sweden     Sweden          4             8   
    USA        USA            86            27   

testSummary = grpstats(testTbl(:,["MPG","Origin"]),"Origin", ...
    "range")
testSummary=6×3 table
               Origin     GroupCount    range_MPG
               _______    __________    _________

    France     France          4          19.8   
    Germany    Germany        13          20.3   
    Italy      Italy           4          11.3   
    Japan      Japan          26          25.6   
    Sweden     Sweden          1             0   
    USA        USA            82            29   

フランス、イタリア、スウェーデンなど、学習セットと検定セットの自動車が少ない国において、MPG の値の範囲が両方のセットで有意に異なっています。

検定セットの残差をプロットします。通常、優れたモデルでは残差が 0 の近くでほぼ対称的に配置されます。残差に明確なパターンがある場合、モデルを改善できる可能性があります。

residuals = testTbl.MPG - predictedY;
plot(testTbl.MPG,residuals,".")
hold on
yline(0)
hold off
xlabel("True Miles Per Gallon (MPG)")
ylabel("MPG Residuals")

Figure contains an axes object. The axes object contains 2 objects of type line, constantline.

プロットから、残差が適切に分散していることがわかります。

絶対値を基準に、残差が大きい観測値についての詳細を取得できます。

[~,residualIdx] = sort(residuals,"descend", ...
    "ComparisonMethod","abs");
residuals(residualIdx)
ans = 130×1

   -8.8469
    8.4427
    8.0493
    7.8996
   -6.2220
    5.8589
    5.7007
   -5.6733
   -5.4545
    5.1899
      ⋮

残差が大きい、つまり振幅が 8 を超えている 3 つの観測値を表示します。

testTbl(residualIdx(1:3),:)
ans=3×7 table
    Acceleration    Displacement    Horsepower    Model_Year    Origin    Weight    MPG 
    ____________    ____________    __________    __________    ______    ______    ____

        17.6             91             68            82        Japan      1970       31
        11.4            168            132            80        Japan      2910     32.7
        13.8             91             67            80        Japan      1850     44.6

参考

| | | |