回帰ニューラル ネットワークの性能評価
fitrnet
を使用して全結合層をもつフィードフォワード回帰ニューラル ネットワーク モデルを作成します。モデルの過適合を防止するために、検証データを使用して学習プロセスを早期に停止します。その後、モデルのオブジェクト関数を使用してテスト データで性能を評価します。
標本データの読み込み
carbig
データセットを読み込みます。このデータセットには、1970 年代と 1980 年代初期に製造された自動車の測定値が格納されています。
load carbig
変数 Origin
をカテゴリカル変数に変換します。その後、Acceleration
、Displacement
などの予測子変数と応答変数 MPG
を格納する table を作成します。各行に 1 台の自動車の測定値を格納します。欠損値がある table の行を table から削除します。
Origin = categorical(cellstr(Origin));
Tbl = table(Acceleration,Displacement,Horsepower, ...
Model_Year,Origin,Weight,MPG);
Tbl = rmmissing(Tbl);
データの分割
データを学習セット、検証セット、およびテスト セットに分割します。まず、観測値の約 3 分の 1 をテスト セット用に予約します。その後、残りのデータを半分に分割して学習セットと検証セットを作成します。
rng("default") % For reproducibility of the data partitions cvp1 = cvpartition(size(Tbl,1),"Holdout",1/3); testTbl = Tbl(test(cvp1),:); remainingTbl = Tbl(training(cvp1),:); cvp2 = cvpartition(size(remainingTbl,1),"Holdout",1/2); validationTbl = remainingTbl(test(cvp2),:); trainTbl = remainingTbl(training(cvp2),:);
ニューラル ネットワークの学習
学習セットを使用して回帰ニューラル ネットワーク モデルに学習させます。tblTrain
の列 MPG
を応答変数として指定し、数値予測子を標準化します。検証セットを使用して各反復でモデルを評価します。名前と値の引数 Verbose
を使用して、各反復で学習データを表示するように指定します。既定では、検証損失が 6 回連続でそれまでに計算された検証損失の最小値以上になると、その時点で学習プロセスが早期に終了します。検証損失が最小値以上になる許容回数を変更するには、名前と値の引数 ValidationPatience
を指定します。
Mdl = fitrnet(trainTbl,"MPG","Standardize",true, ... "ValidationData",validationTbl, ... "Verbose",1);
|==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 1| 102.962345| 46.853164| 6.700877| 0.024993| 115.730384| 0| | 2| 55.403995| 22.171181| 1.811805| 0.016630| 53.086379| 0| | 3| 37.588848| 11.135231| 0.782861| 0.006920| 38.580002| 0| | 4| 29.713458| 8.379231| 0.392009| 0.014157| 31.021379| 0| | 5| 17.523851| 9.958164| 2.137584| 0.001966| 17.594863| 0| | 6| 12.700624| 2.957771| 0.744551| 0.004859| 14.209019| 0| | 7| 11.841152| 1.907378| 0.201770| 0.006246| 13.159899| 0| | 8| 10.162988| 2.542555| 0.576907| 0.005628| 11.352490| 0| | 9| 8.889095| 2.779980| 0.615716| 0.003847| 10.446334| 0| | 10| 7.670335| 2.400272| 0.648711| 0.003937| 10.424337| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 11| 7.416274| 0.505111| 0.214707| 0.009862| 10.522517| 1| | 12| 7.338923| 0.880655| 0.119085| 0.012907| 10.648031| 2| | 13| 7.149407| 1.784821| 0.277908| 0.002171| 10.800952| 3| | 14| 6.866385| 1.904480| 0.472190| 0.008427| 10.839202| 4| | 15| 6.815575| 3.339285| 0.943063| 0.006701| 10.031692| 0| | 16| 6.428137| 0.684771| 0.133729| 0.004101| 9.867819| 0| | 17| 6.363299| 0.456606| 0.125363| 0.009270| 9.720076| 0| | 18| 6.289887| 0.742923| 0.152290| 0.001644| 9.576588| 0| | 19| 6.215407| 0.964684| 0.183503| 0.005140| 9.422910| 0| | 20| 6.078333| 2.124971| 0.566948| 0.001952| 9.599573| 1| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 21| 5.947923| 1.217291| 0.583867| 0.005246| 9.618400| 2| | 22| 5.855505| 0.671774| 0.285123| 0.003429| 9.734680| 3| | 23| 5.831802| 1.882061| 0.657368| 0.003987| 10.365968| 4| | 24| 5.713261| 1.004072| 0.134719| 0.002786| 10.314258| 5| | 25| 5.520766| 0.967032| 0.290156| 0.002881| 10.177322| 6| |==========================================================================================|
オブジェクト Mdl
の TrainingHistory
プロパティ内の情報を使用して、検証の平均二乗誤差 (MSE) が最小になる対応する反復を確認します。最終的に返されるモデル Mdl
は、この反復で学習させたモデルになります。
iteration = Mdl.TrainingHistory.Iteration; valLosses = Mdl.TrainingHistory.ValidationLoss; [~,minIdx] = min(valLosses); iteration(minIdx)
ans = 19
テスト セットのパフォーマンスの評価
オブジェクト関数 loss
および predict
を使用して、学習させたモデル Mdl
の性能をテスト セット testTbl
で評価します。
テスト セットの平均二乗誤差 (MSE) を計算します。MSE の値が小さいほど、パフォーマンスが優れていることを示します。
mse = loss(Mdl,testTbl,"MPG")
mse = 7.4101
テスト セットの予測応答値と実際の応答値を比較します。予測されるガロンあたりの走行マイル数 (MPG) を縦軸に、実際の MPG を横軸にしてプロットします。基準線上にある点は予測が正しいことを示します。優れたモデルでは、生成された予測が線の近くに分布します。
predictedY = predict(Mdl,testTbl); plot(testTbl.MPG,predictedY,".") hold on plot(testTbl.MPG,testTbl.MPG) hold off xlabel("True Miles Per Gallon (MPG)") ylabel("Predicted Miles Per Gallon (MPG)")
箱ひげ図を使用して、生産国別に MPG の予測値と実際の値の分布を比較します。関数 boxchart
を使用して箱ひげ図を作成します。各箱ひげ図には、中央値、第 1 四分位数と第 3 四分位数、外れ値 (四分位数間範囲を使用して計算)、および外れ値ではない最小値と最大値を表示します。特に、各ボックスの内側の線は標本の中央値であり、円形のマーカーは外れ値を示します。
各生産国について、赤色の箱ひげ図 (MPG の予測値の分布を示す) と青色の箱ひげ図 (MPG の実際の値の分布を示す) を比較します。MPG の予測値と実際の値の分布が似ていれば、予測が優れていることを示します。
boxchart(testTbl.Origin,testTbl.MPG) hold on boxchart(testTbl.Origin,predictedY) hold off legend(["True MPG","Predicted MPG"]) xlabel("Country of Origin") ylabel("Miles Per Gallon (MPG)")
ほとんどの国については、MPG の予測値と実際の値が同じような分布になっています。一部に相違があるのは、学習セットとテスト セットで自動車の数が少ないことが原因と考えられます。
学習セットとテスト セットで、自動車についての MPG の値の範囲を比較します。
trainSummary = grpstats(trainTbl(:,["MPG","Origin"]),"Origin", ... "range")
trainSummary=6×3 table
Origin GroupCount range_MPG
_______ __________ _________
France France 2 1.2
Germany Germany 12 23.4
Italy Italy 1 0
Japan Japan 26 26.6
Sweden Sweden 4 8
USA USA 86 27
testSummary = grpstats(testTbl(:,["MPG","Origin"]),"Origin", ... "range")
testSummary=6×3 table
Origin GroupCount range_MPG
_______ __________ _________
France France 4 19.8
Germany Germany 13 20.3
Italy Italy 4 11.3
Japan Japan 26 25.6
Sweden Sweden 1 0
USA USA 82 29
フランス、イタリア、スウェーデンなど、学習セットとテスト セットの自動車が少ない国において、MPG の値の範囲が両方のセットで有意に異なっています。
テスト セットの残差をプロットします。通常、優れたモデルでは残差が 0 の近くでほぼ対称的に配置されます。残差に明確なパターンがある場合、モデルを改善できる可能性があります。
residuals = testTbl.MPG - predictedY; plot(testTbl.MPG,residuals,".") hold on yline(0) hold off xlabel("True Miles Per Gallon (MPG)") ylabel("MPG Residuals")
プロットから、残差が適切に分散していることがわかります。
絶対値を基準に、残差が大きい観測値についての詳細を取得できます。
[~,residualIdx] = sort(residuals,"descend", ... "ComparisonMethod","abs"); residuals(residualIdx)
ans = 130×1
-8.8469
8.4427
8.0493
7.8996
-6.2220
5.8589
5.7007
-5.6733
-5.4545
5.1899
⋮
残差が大きい、つまり振幅が 8 を超えている 3 つの観測値を表示します。
testTbl(residualIdx(1:3),:)
ans=3×7 table
Acceleration Displacement Horsepower Model_Year Origin Weight MPG
____________ ____________ __________ __________ ______ ______ ____
17.6 91 68 82 Japan 1970 31
11.4 168 132 80 Japan 2910 32.7
13.8 91 67 80 Japan 1850 44.6
参考
fitrnet
| loss
| predict
| RegressionNeuralNetwork
| boxchart