Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

回帰ニューラル ネットワークの性能評価

fitrnet を使用して全結合層をもつフィードフォワード回帰ニューラル ネットワーク モデルを作成します。モデルの過適合を防止するために、検証データを使用して学習プロセスを早期に停止します。その後、モデルのオブジェクト関数を使用して検定データで性能を評価します。

標本データの読み込み

carbig データセットを読み込みます。このデータセットには、1970 年代と 1980 年代初期に製造された自動車の測定値が格納されています。

load carbig

変数 Origin をカテゴリカル変数に変換します。その後、AccelerationDisplacement などの予測子変数と応答変数 MPG を格納する table を作成します。各行に 1 台の自動車の測定値を格納します。

Origin = categorical(cellstr(Origin));
Tbl = table(Acceleration,Displacement,Horsepower, ...
    Model_Year,Origin,Weight,MPG);

データの分割

データを学習セット、検証セット、および検定セットに分割します。まず、観測値の約 3 分の 1 を検定セット用に予約します。その後、残りのデータを半分に分割して学習セットと検証セットを作成します。

rng("default") % For reproducibility of the data partitions
cvp1 = cvpartition(size(Tbl,1),"Holdout",1/3);
testTbl = Tbl(test(cvp1),:);
remainingTbl = Tbl(training(cvp1),:);

cvp2 = cvpartition(size(remainingTbl,1),"Holdout",1/2);
validationTbl = remainingTbl(test(cvp2),:);
trainTbl = remainingTbl(training(cvp2),:);

ニューラル ネットワークの学習

学習セットを使用して回帰ニューラル ネットワーク モデルに学習させます。tblTrain の列 MPG を応答変数として指定し、数値予測子を標準化します。検証セットを使用して各反復でモデルを評価します。名前と値の引数 Verbose を使用して、各反復で学習データを表示するように指定します。既定では、検証損失が 6 回連続でそれまでに計算された検証損失の最小値以上になると、その時点で学習プロセスが早期に終了します。検証損失が最小値以上になる許容回数を変更するには、名前と値の引数 ValidationPatience を指定します。

Mdl = fitrnet(trainTbl,"MPG","Standardize",true, ...
    "ValidationData",validationTbl, ...
    "Verbose",1);
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|           1|   71.063537|   22.623354|    6.466959|    0.001272|   72.648960|           0|
|           2|   48.608700|   22.384995|    1.022929|    0.001808|   43.435698|           0|
|           3|   30.584887|   13.433471|    0.537190|    0.000903|   29.134447|           0|
|           4|   17.781636|   11.159801|    1.401355|    0.000461|   16.542207|           0|
|           5|   13.075804|    4.605991|    0.419875|    0.000387|   12.946670|           0|
|           6|   11.697936|    3.197944|    0.226945|    0.000543|   12.025502|           0|
|           7|    9.494801|    2.269831|    0.751711|    0.000452|   12.596499|           1|
|           8|    8.390979|    1.970589|    0.337301|    0.000398|   11.490990|           0|
|           9|    6.853097|    1.029078|    0.866974|    0.000378|    9.449945|           0|
|          10|    6.531678|    0.924820|    0.306913|    0.000429|    9.350721|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          11|    6.152995|    1.872684|    0.457744|    0.000403|    9.223829|           0|
|          12|    5.924852|    0.718386|    0.447879|    0.000402|    9.656166|           1|
|          13|    5.792836|    0.500170|    0.216351|    0.000387|    9.733226|           2|
|          14|    5.613473|    1.151197|    0.316828|    0.000531|    9.788646|           3|
|          15|    5.415889|    1.513493|    0.327937|    0.000485|    9.607953|           4|
|          16|    5.008195|    1.398069|    1.085660|    0.000430|    9.251971|           5|
|          17|    5.004176|    2.070041|    0.890201|    0.000383|    8.719334|           0|
|          18|    4.738386|    0.483667|    0.338897|    0.000374|    8.523728|           0|
|          19|    4.680213|    0.437918|    0.107667|    0.000371|    8.369271|           0|
|          20|    4.587350|    0.510639|    0.146276|    0.000385|    8.100236|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          21|    4.479929|    0.565635|    0.228198|    0.000381|    8.062927|           0|
|          22|    4.380618|    0.892717|    0.377776|    0.000554|    7.843234|           0|
|          23|    4.189344|    0.403227|    0.362307|    0.000434|    7.834582|           0|
|          24|    4.182775|    1.150234|    1.908768|    0.000408|    9.436226|           1|
|          25|    3.985939|    0.908479|    0.518217|    0.000570|    8.973756|           2|
|          26|    3.873835|    0.826655|    0.477740|    0.000505|    8.863599|           3|
|          27|    3.830830|    0.331936|    0.220000|    0.000539|    8.574682|           4|
|          28|    3.796605|    0.232756|    0.075643|    0.000492|    8.591758|           5|
|          29|    3.706326|    0.470116|    0.249292|    0.000396|    8.517317|           6|
|==========================================================================================|

オブジェクト MdlTrainingHistory プロパティ内の情報を使用して、検証の平均二乗誤差 (MSE) が最小になる対応する反復を確認します。最終的に返されるモデル Mdl は、この反復で学習させたモデルになります。

iteration = Mdl.TrainingHistory.Iteration;
valLosses = Mdl.TrainingHistory.ValidationLoss;
[~,minIdx] = min(valLosses);
iteration(minIdx)
ans = 23

検定セットのパフォーマンスの評価

オブジェクト関数 loss および predict を使用して、学習させたモデル Mdl の性能を検定セット testTbl で評価します。

検定セットの平均二乗誤差 (MSE) を計算します。MSE の値が小さいほど、パフォーマンスが優れていることを示します。

mse = loss(Mdl,testTbl,"MPG")
mse = 25.4145

検定セットの予測応答値と実際の応答値を比較します。予測されるガロンあたりの走行マイル数 (MPG) を縦軸に、実際の MPG を横軸にしてプロットします。基準線上にある点は予測が正しいことを示します。優れたモデルでは、生成された予測が線の近くに分布します。

predictedY = predict(Mdl,testTbl);

plot(testTbl.MPG,predictedY,".")
hold on
plot(testTbl.MPG,testTbl.MPG)
hold off
xlabel("True Miles Per Gallon (MPG)")
ylabel("Predicted Miles Per Gallon (MPG)")

箱ひげ図を使用して、生産国別に MPG の予測値と実際の値の分布を比較します。関数 boxchart を使用して箱ひげ図を作成します。各箱ひげ図には、中央値、第 1 四分位数と第 3 四分位数、外れ値 (四分位数間範囲を使用して計算)、および外れ値ではない最小値と最大値を表示します。特に、各ボックスの内側の線は標本の中央値であり、円形のマーカーは外れ値を示します。

各生産国について、赤色の箱ひげ図 (MPG の予測値の分布を示す) と青色の箱ひげ図 (MPG の実際の値の分布を示す) を比較します。MPG の予測値と実際の値の分布が似ていれば、予測が優れていることを示します。

boxchart(testTbl.Origin,testTbl.MPG)
hold on
boxchart(testTbl.Origin,predictedY)
hold off
legend(["True MPG","Predicted MPG"])
xlabel("Country of Origin")
ylabel("Miles Per Gallon (MPG)")

ほとんどの国については、MPG の予測値と実際の値が同じような分布になっています。ただし、フランス製の自動車については、ニューラル ネットワーク モデルで MPG の値が低く推定されている傾向が見られます。この相違は、学習セットと検定セットでフランス製の自動車の数が少ないことが原因と考えられます。

学習セットと検定セットで、フランス製の自動車についての MPG の値の範囲を比較します。

trainSummary = grpstats(trainTbl(:,["MPG","Origin"]),"Origin", ...
    ["min","max"])
trainSummary=6×4 table
               Origin     GroupCount    min_MPG    max_MPG
               _______    __________    _______    _______

    France     France          3         16.2         27  
    Germany    Germany        11           20       44.3  
    Italy      Italy           1         37.3       37.3  
    Japan      Japan          24           20       40.8  
    Sweden     Sweden          3           19       21.6  
    USA        USA            94            9         39  

testSummary = grpstats(testTbl(:,["MPG","Origin"]),"Origin", ...
    ["min","max"])
testSummary=6×4 table
               Origin     GroupCount    min_MPG    max_MPG
               _______    __________    _______    _______

    France     France          3         28.1       40.9  
    Germany    Germany        12         21.5         44  
    Italy      Italy           3           28         30  
    Japan      Japan          32           18       46.6  
    Sweden     Sweden          3           17         24  
    USA        USA            82           10       36.1  

学習セットでは、フランス製の自動車についての MPG の値の範囲は 16.2 ~ 27 です。一方、検定セットでは、フランス製の自動車についての MPG の値の範囲は 28.1 ~ 40.9 です。

検定セットの残差をプロットします。通常、優れたモデルでは残差が 0 の近くでほぼ対称的に配置されます。残差に明確なパターンがある場合、モデルを改善できる可能性があります。

residuals = testTbl.MPG - predictedY;
plot(testTbl.MPG,residuals,".")
hold on
yline(0)
hold off
xlabel("True Miles Per Gallon (MPG)")
ylabel("MPG Residuals")

プロットから、残差の値にいくつか外れ値があることがわかります。残差が最大の観測値について詳細を調べます。

[outlierResidual,outlierIdx] = max(residuals)
outlierResidual = 37.8727
outlierIdx = 113
testTbl(outlierIdx,:)
ans=1×7 table
    Acceleration    Displacement    Horsepower    Model_Year    Origin    Weight    MPG 
    ____________    ____________    __________    __________    ______    ______    ____

        17.3             85            NaN            80        France     1835     40.9

この観測値は、Horsepower の値がない自動車に対応しており、生産国はフランス (観測数が少ないカテゴリ) になっています。

参考

| | | |