Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

回帰用の一般化加法モデルの学習

この例では、最適なパラメーターで一般化加法モデル (GAM) に学習させる方法と、学習済みモデルの予測性能を評価する方法を示します。この例では、最初に一変量の GAM に最適なパラメーター値 (線形項のパラメーター) を特定し、次に二変量の GAM の値 (交互作用項のパラメーター) を特定します。また、この例では、特定の予測に対する項の局所的効果を調べて、予測子に対する予測の部分依存を計算することで、学習済みモデルを解釈する方法についても説明します。

標本データの読み込み

標本データセット NYCHousing2015 を読み込みます。

load NYCHousing2015

データ セットには、2015 年のニューヨーク市における不動産の売上に関する情報を持つ 10 の変数が含まれます。この例では、これらの変数を使用して売価 (SALEPRICE) を解析します。

データ セットを前処理します。1000 ドル以下の SALEPRICE は、現金対価なしの所有権移転を示すものと仮定します。このような SALEPRICE をもつ標本を削除します。さらに、関数 isoutlier で識別される外れ値も削除します。その後、datetime 配列 (SALEDATE) を月番号に変換して、応答変数 (SALEPRICE) を最後の列に移動します。LANDSQUAREFEETGROSSSQUAREFEET、および YEARBUILT に含まれるゼロを NaN に変更します。

idx1 = NYCHousing2015.SALEPRICE <= 1000;
idx2 = isoutlier(NYCHousing2015.SALEPRICE);
NYCHousing2015(idx1|idx2,:) = [];
NYCHousing2015.SALEDATE = month(NYCHousing2015.SALEDATE);
NYCHousing2015 = movevars(NYCHousing2015,'SALEPRICE','After','SALEDATE');
NYCHousing2015.LANDSQUAREFEET(NYCHousing2015.LANDSQUAREFEET == 0) = NaN; 
NYCHousing2015.GROSSSQUAREFEET(NYCHousing2015.GROSSSQUAREFEET == 0) = NaN; 
NYCHousing2015.YEARBUILT(NYCHousing2015.YEARBUILT == 0) = NaN; 

テーブルの最初の 3 行を表示します。

head(NYCHousing2015,3)
ans=3×10 table
    BOROUGH    NEIGHBORHOOD       BUILDINGCLASSCATEGORY        RESIDENTIALUNITS    COMMERCIALUNITS    LANDSQUAREFEET    GROSSSQUAREFEET    YEARBUILT    SALEDATE    SALEPRICE
    _______    ____________    ____________________________    ________________    _______________    ______________    _______________    _________    ________    _________

       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   0                1103              1290            1910          2           3e+05 
       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   1                2500              2452            1910          7           4e+05 
       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   2                1911              4080            1931          1         5.1e+05 

関数datasampleを使用して 1000 個の標本を無作為に選択し、関数cvpartitionを使用して観測値を学習セットと検定セットに分割します。検定用に 10% のホールドアウト標本を指定します。

rng('default') % For reproducibility
NumSamples = 1e3;
NYCHousing2015 = datasample(NYCHousing2015,NumSamples,'Replace',false);
cv = cvpartition(size(NYCHousing2015,1),'HoldOut',0.10);

学習インデックスと検定インデックスを抽出し、学習データ セット用と検定データ セット用の table を作成します。

tbl_training = NYCHousing2015(training(cv),:);
tbl_test = NYCHousing2015(test(cv),:);

一変量の GAM に最適なパラメーターの特定

関数bayesoptを使用して、一変量の GAM のパラメーターを交差検証に関して最適化します。

optimizableVariableオブジェクトを一変量の GAM の名前と値の引数 MaxNumSplitsPerPredictorNumTreesPerPredictor および InitialLearnRateForPredictors 用に準備します。

maxNumSplitsPerPredictor = optimizableVariable('maxNumSplitsPerPredictor',[1,10],'Type','integer');
numTreesPerPredictor = optimizableVariable('numTreesPerPredictor',[1,500],'Type','integer');
initialLearnRateForPredictors = optimizableVariable('initialLearnRateForPredictors',[1e-3,1],'Type','real');

入力として z = [maxNumSplitsPerPredictor,numTreesPerPredictor,initialLearnRateForPredictors] を受け入れ z のパラメーターにおける交差検証損失値を返す目的関数を作成します。

minfun1 = @(z)kfoldLoss(fitrgam(tbl_training,'SALEPRICE', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',z.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',z.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',z.numTreesPerPredictor));

交差検証オプション 'CrossVal','on' を指定した場合、関数 fitrgam は交差検証済みモデル オブジェクトRegressionPartitionedGAMを返します。関数kfoldLossは、交差検証済みモデルで取得した回帰損失 (平均二乗誤差) を返します。そのため、関数ハンドル minfun1 は、z のパラメーターで交差検証損失を計算します。

bayesopt を使用して最良のパラメーターを求めます。再現性を得るために、'expected-improvement-plus' の獲得関数を選択します。既定の獲得関数は実行時に決定されるので、結果が異なる場合があります。

results1 = bayesopt(minfun1, ...
    [initialLearnRateForPredictors,maxNumSplitsPerPredictor,numTreesPerPredictor], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|    1 | Best   |  8.4558e+10 |      1.5106 |  8.4558e+10 |  8.4558e+10 |      0.36695 |            2 |           30 |
|    2 | Accept |  8.6891e+10 |       12.01 |  8.4558e+10 |  8.4558e+10 |     0.008213 |            5 |          271 |
|    3 | Accept |  9.6521e+10 |      1.9121 |  8.4558e+10 |  8.4558e+10 |      0.22984 |            9 |           37 |
|    4 | Accept |  1.3402e+11 |      14.388 |  8.4558e+10 |  8.4558e+10 |      0.99932 |            3 |          344 |
|    5 | Accept |  8.7852e+10 |      13.595 |  8.4558e+10 |  8.4558e+10 |      0.16575 |            1 |          456 |
|    6 | Accept |  9.3041e+10 |      11.002 |  8.4558e+10 |  8.4558e+10 |      0.49477 |            1 |          360 |
|    7 | Accept |  1.0558e+11 |      7.7647 |  8.4558e+10 |  8.4558e+10 |      0.24562 |            4 |          175 |
|    8 | Accept |  8.8841e+10 |      1.5763 |  8.4558e+10 |  8.4558e+10 |      0.39298 |            2 |           41 |
|    9 | Accept |  9.9227e+10 |      14.377 |  8.4558e+10 |  8.4558e+10 |     0.091879 |            3 |          358 |
|   10 | Accept |  9.8611e+10 |     0.14914 |  8.4558e+10 |  8.4558e+10 |      0.22487 |            2 |            2 |
|   11 | Accept |  1.2998e+11 |      23.962 |  8.4558e+10 |  8.4558e+10 |      0.25341 |            5 |          500 |
|   12 | Accept |  8.8968e+10 |      5.0028 |  8.4558e+10 |  8.4558e+10 |      0.33109 |            1 |          175 |
|   13 | Accept |  1.2018e+11 |      1.8004 |  8.4558e+10 |  8.4558e+10 |    0.0030413 |            6 |           40 |
|   14 | Accept |  8.7503e+10 |     0.79283 |  8.4558e+10 |  8.4558e+10 |      0.33877 |            1 |           25 |
|   15 | Accept |  9.3798e+10 |      2.9578 |  8.4558e+10 |  8.4558e+10 |      0.32926 |            2 |           80 |
|   16 | Accept |  9.5165e+10 |      8.0635 |  8.4558e+10 |  8.4558e+10 |      0.33878 |            1 |          282 |
|   17 | Best   |  8.3549e+10 |     0.24446 |  8.3549e+10 |  8.3549e+10 |       0.3552 |            2 |            5 |
|   18 | Best   |  8.3104e+10 |      1.4534 |  8.3104e+10 |  8.3104e+10 |       0.2526 |            1 |           49 |
|   19 | Accept |  8.6938e+10 |      3.3234 |  8.3104e+10 |  8.3104e+10 |      0.18293 |            1 |          110 |
|   20 | Accept |  8.7531e+10 |      2.8096 |  8.3104e+10 |  8.3104e+10 |       0.2781 |            1 |           93 |
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|   21 | Accept |  9.1613e+10 |      13.347 |  8.3104e+10 |  8.3104e+10 |      0.31722 |            1 |          464 |
|   22 | Accept |   8.678e+10 |      10.358 |  8.3104e+10 |  8.3104e+10 |      0.11269 |            1 |          358 |
|   23 | Accept |  8.3614e+10 |     0.47001 |  8.3104e+10 |  8.3104e+10 |      0.22278 |            1 |           14 |
|   24 | Accept |  1.3203e+11 |       1.069 |  8.3104e+10 |  8.3104e+10 |    0.0021552 |            5 |           23 |
|   25 | Accept |    8.66e+10 |       7.233 |  8.3104e+10 |  8.3104e+10 |      0.11469 |            1 |          236 |
|   26 | Accept |  8.4535e+10 |      8.7657 |  8.3104e+10 |  8.3104e+10 |    0.0090628 |            1 |          292 |
|   27 | Accept |  1.0315e+11 |      12.297 |  8.3104e+10 |  8.3104e+10 |    0.0014094 |            1 |          413 |
|   28 | Accept |  9.6736e+10 |      5.8323 |  8.3104e+10 |  8.3104e+10 |    0.0040429 |            1 |          202 |
|   29 | Accept |  8.3651e+10 |      8.4999 |  8.3104e+10 |  8.3104e+10 |      0.09375 |            1 |          295 |
|   30 | Accept |  8.7977e+10 |      13.521 |  8.3104e+10 |  8.3104e+10 |     0.016448 |            6 |          292 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 245.1541 seconds
Total objective function evaluation time: 210.0881

Best observed feasible point:
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.2526                           1                         49         

Observed objective function value = 83103839919.908
Estimated objective function value = 83103840296.3186
Function evaluation time = 1.4534

Best estimated feasible point (according to models):
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.2526                           1                         49         

Estimated objective function value = 83103840296.3186
Estimated function evaluation time = 1.803

results1 から最良の点を取得します。

zbest1 = bestPoint(results1)
zbest1=1×3 table
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.2526                           1                         49         

最適なパラメーターでの一変量の GAM の学習

zbest1 の値を使用して最適化された GAM に学習させます。

Mdl1 = fitrgam(tbl_training,'SALEPRICE', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor) 
Mdl1 = 
  RegressionGAM
           PredictorNames: {'BOROUGH'  'NEIGHBORHOOD'  'BUILDINGCLASSCATEGORY'  'RESIDENTIALUNITS'  'COMMERCIALUNITS'  'LANDSQUAREFEET'  'GROSSSQUAREFEET'  'YEARBUILT'  'SALEDATE'}
             ResponseName: 'SALEPRICE'
    CategoricalPredictors: [2 3]
        ResponseTransform: 'none'
                Intercept: 4.9806e+05
          NumObservations: 900


  Properties, Methods

Mdl1RegressionGAM モデル オブジェクトです。モデル表示には、モデルのプロパティの一部のみが表示されます。プロパティの完全な一覧を表示するには、ワークスペースで変数名 Mdl1 をダブルクリックします。Mdl1 の変数エディターが開きます。あるいは、コマンド ウィンドウでドット表記を使用してプロパティを表示できます。たとえば、ReasonForTermination プロパティを表示します。

Mdl1.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: ''

プロパティ値の PredictorTrees フィールドは、Mdl1 に指定した数の木が含まれていることを示します。fitrgamNumTreesPerPredictor で予測子あたりの木の最大数を指定すると、要求された数の木に学習させる前にこの関数を停止できます。ReasonForTermination プロパティを使用して、学習済みモデルに指定した数の木が含まれているかどうかを確認できます。

交互作用項を含めて fitrgam でそれらについて木に学習させるように指定した場合、InteractionTrees フィールドには空でない値が含まれます。

二変量の GAM に最適なパラメーターの特定

関数 bayesopt を使用して、二変量の GAM の交互作用項のパラメーターを特定します。

optimizableVariable を交互作用項の名前と値の引数 InitialLearnRateForInteractionsMaxNumSplitsPerInteractionNumTreesPerInteraction、および InitialLearnRateForInteractions 用に準備します。

initialLearnRateForInteractions = optimizableVariable('initialLearnRateForInteractions',[1e-3,1],'Type','real');
maxNumSplitsPerInteraction = optimizableVariable('maxNumSplitsPerInteraction',[1,10],'Type','integer');
numTreesPerInteraction = optimizableVariable('numTreesPerInteraction',[1,500],'Type','integer');
numInteractions = optimizableVariable('numInteractions',[1,28],'Type','integer');

最適化のための目的関数を作成します。zbest1 の最適なパラメーター値を使用して、交互作用項に最適なパラメーター値が zbest1 の値に基づいて特定されるようにします。

minfun2 = @(z)kfoldLoss(fitrgam(tbl_training,'SALEPRICE', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',z.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',z.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',z.numTreesPerInteraction, ...
    'Interactions',z.numInteractions));

bayesopt を使用して最良のパラメーターを求めます。最適化プロセスによって複数のモデルの学習が行われ、モデルに交互作用項が含まれていない場合は警告メッセージが表示されます。bayesopt を呼び出す前にすべての警告を無効にして、bayesopt の実行後に警告状態を復元します。警告状態を変更しないままにして警告メッセージを確認できます。

orig_state = warning('query'); 
warning('off')
results2 = bayesopt(minfun2, ...
    [initialLearnRateForInteractions,maxNumSplitsPerInteraction,numTreesPerInteraction,numInteractions], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|    1 | Best   |  8.4721e+10 |      1.6996 |  8.4721e+10 |  8.4721e+10 |      0.41774 |            1 |          346 |           28 |
|    2 | Accept |  9.1765e+10 |      8.3313 |  8.4721e+10 |  8.4721e+10 |       0.9565 |            3 |          231 |           14 |
|    3 | Accept |  9.2116e+10 |      2.8341 |  8.4721e+10 |  8.4721e+10 |      0.33578 |            9 |           45 |            5 |
|    4 | Accept |   1.784e+11 |      76.237 |  8.4721e+10 |  8.4721e+10 |      0.91186 |           10 |          479 |           27 |
|    5 | Accept |  8.4906e+10 |      1.8275 |  8.4721e+10 |  8.4721e+10 |        0.296 |            4 |            1 |           27 |
|    6 | Best   |  8.4172e+10 |        1.73 |  8.4172e+10 |  8.4172e+10 |      0.68133 |            1 |           86 |            1 |
|    7 | Best   |   8.234e+10 |      1.7164 |   8.234e+10 |   8.234e+10 |      0.13943 |            1 |          228 |           26 |
|    8 | Accept |  8.3488e+10 |      1.6382 |   8.234e+10 |   8.234e+10 |      0.46764 |            1 |            1 |            5 |
|    9 | Accept |  8.7977e+10 |      1.5655 |   8.234e+10 |   8.234e+10 |       0.8385 |           10 |            1 |            5 |
|   10 | Accept |  8.4431e+10 |      1.5744 |   8.234e+10 |   8.234e+10 |      0.95535 |            1 |          261 |            4 |
|   11 | Accept |  8.5784e+10 |      1.7478 |   8.234e+10 |   8.234e+10 |     0.023058 |            7 |            1 |           14 |
|   12 | Accept |  8.6068e+10 |      1.7304 |   8.234e+10 |   8.234e+10 |      0.77118 |            1 |            5 |           28 |
|   13 | Accept |  8.7004e+10 |      1.5903 |   8.234e+10 |   8.234e+10 |     0.016991 |            1 |          263 |            2 |
|   14 | Accept |  8.3325e+10 |      1.5895 |   8.234e+10 |   8.234e+10 |       0.9468 |            4 |            7 |            1 |
|   15 | Accept |  8.4097e+10 |      1.6357 |   8.234e+10 |   8.234e+10 |      0.97988 |            1 |          250 |           28 |
|   16 | Accept |  8.3106e+10 |      1.6081 |   8.234e+10 |   8.234e+10 |     0.024052 |            1 |          121 |           28 |
|   17 | Accept |   8.469e+10 |      1.6235 |   8.234e+10 |   8.234e+10 |     0.047902 |            3 |            3 |           12 |
|   18 | Best   |  8.1641e+10 |      1.5833 |  8.1641e+10 |  8.1641e+10 |      0.99848 |            6 |            1 |            3 |
|   19 | Accept |  8.5957e+10 |      1.6305 |  8.1641e+10 |  8.1641e+10 |      0.99826 |            6 |            1 |           13 |
|   20 | Accept |  8.2486e+10 |      1.6515 |  8.1641e+10 |  8.1641e+10 |      0.36059 |            7 |            2 |            1 |
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|   21 | Accept |  8.6534e+10 |       1.647 |  8.1641e+10 |  8.1641e+10 |    0.0089186 |            1 |          192 |           18 |
|   22 | Accept |  8.5425e+10 |      1.5316 |  8.1641e+10 |  8.1641e+10 |      0.99842 |            1 |          497 |            1 |
|   23 | Accept |   8.515e+10 |      1.5728 |  8.1641e+10 |  8.1641e+10 |      0.99934 |            1 |            3 |            2 |
|   24 | Accept |   8.593e+10 |      1.6086 |  8.1641e+10 |  8.1641e+10 |    0.0099052 |            1 |            2 |           28 |
|   25 | Accept |  8.7394e+10 |       1.577 |  8.1641e+10 |  8.1641e+10 |      0.96502 |            7 |            5 |            2 |
|   26 | Accept |   8.618e+10 |      1.5714 |  8.1641e+10 |  8.1641e+10 |     0.097871 |            5 |            3 |            1 |
|   27 | Accept |  8.5704e+10 |       1.665 |  8.1641e+10 |  8.1641e+10 |     0.056356 |           10 |            6 |            3 |
|   28 | Accept |  9.5451e+10 |      2.8821 |  8.1641e+10 |  8.1641e+10 |      0.91844 |            3 |           12 |           28 |
|   29 | Accept |  8.4013e+10 |      1.5633 |  8.1641e+10 |  8.1641e+10 |      0.68016 |            6 |            1 |            1 |
|   30 | Accept |  8.3928e+10 |      1.7715 |  8.1641e+10 |  8.1641e+10 |      0.07259 |            5 |            5 |           14 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 155.1459 seconds
Total objective function evaluation time: 132.9347

Best observed feasible point:
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

                0.99848                            6                           1                      3       

Observed objective function value = 81640836929.8637
Estimated objective function value = 81640841484.6238
Function evaluation time = 1.5833

Best estimated feasible point (according to models):
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

                0.99848                            6                           1                      3       

Estimated objective function value = 81640841484.6238
Estimated function evaluation time = 1.5784
warning(orig_state)

results2 から最良の点を取得します。

zbest2 = bestPoint(results2)
zbest2=1×4 table
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

                0.99848                            6                           1                      3       

最適なパラメーターでの二変量の GAM の学習

zbest1zbest2 の値を使用して最適化された GAM に学習させます。

Mdl = fitrgam(tbl_training,'SALEPRICE', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction, ...   
    'Interactions',zbest2.numInteractions) 
Mdl = 
  RegressionGAM
           PredictorNames: {'BOROUGH'  'NEIGHBORHOOD'  'BUILDINGCLASSCATEGORY'  'RESIDENTIALUNITS'  'COMMERCIALUNITS'  'LANDSQUAREFEET'  'GROSSSQUAREFEET'  'YEARBUILT'  'SALEDATE'}
             ResponseName: 'SALEPRICE'
    CategoricalPredictors: [2 3]
        ResponseTransform: 'none'
                Intercept: 4.9741e+05
             Interactions: [3×2 double]
          NumObservations: 900


  Properties, Methods

または、関数 addInteractions を使用して一変量の GAM に交互作用項を追加できます。

Mdl2 = addInteractions(Mdl1,zbest2.numInteractions, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction); 

2 番目の入力引数では交互作用項の最大数を指定し、名前と値の引数 NumTreesPerInteraction では交互作用項あたりの木の最大数を指定します。関数 addInteractions に含める交互作用項の数を減らすと、要求された数の木に学習させる前にこの関数を停止できます。Interactions プロパティと ReasonForTermination プロパティを使用して、学習済みモデルにある実際の交互作用項の数と木の数を確認できます。

Mdl の交互作用項を表示します。

Mdl.Interactions
ans = 3×2

     3     6
     4     6
     6     8

Interactions の各行は 1 つの交互作用項を表し、交互作用項の予測子変数の列インデックスを格納します。Interactions プロパティを使用して、モデル内の交互作用項とそれらが fitrgam でモデルに追加された順序を確認できます。

予測子名を使用して Mdl の交互作用項を表示します。

Mdl.PredictorNames(Mdl.Interactions)
ans = 3×2 cell
    {'BUILDINGCLASSCATEGORY'}    {'LANDSQUAREFEET'}
    {'RESIDENTIALUNITS'     }    {'LANDSQUAREFEET'}
    {'LANDSQUAREFEET'       }    {'YEARBUILT'     }

終了の理由を表示して、指定した数の木がモデルに含まれているかどうかを線形項と交互作用項のそれぞれについて確認します。

Mdl.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: 'Terminated after training the requested number of trees.'

新しい観測値での予測性能の評価

検定標本 tbl_test とオブジェクト関数 predict および loss を使用して、学習済みモデルの性能を評価します。これらの関数をもつ完全またはコンパクトなモデルを使用できます。

  • predict — 応答を予測

  • loss — 回帰損失 (既定では平均二乗誤差) を計算

学習データ セットの性能を評価するには、再代入オブジェクト関数resubPredictおよびresubLossを使用します。これらの関数を使用するには、学習データを含む完全なモデルを使用しなければなりません。

コンパクトなモデルを作成して、学習済みモデルのサイズを縮小します。

CMdl = compact(Mdl);
whos('Mdl','CMdl')
  Name      Size             Bytes  Class                                          Attributes

  CMdl      1x1             370211  classreg.learning.regr.CompactRegressionGAM              
  Mdl       1x1             528102  RegressionGAM                                            

検定データ セット tbl_test について、応答を予測し、平均二乗誤差を計算します。

yFit = predict(CMdl,tbl_test);
L = loss(CMdl,tbl_test)
L = 1.2855e+11

学習済みモデルに交互作用項を含めずに予測された応答と誤差を調べます。

yFit_nointeraction = predict(CMdl,tbl_test,'IncludeInteractions',false);
L_nointeractions = loss(CMdl,tbl_test,'IncludeInteractions',false)
L_nointeractions = 1.3031e+11

線形項と交互作用項の両方が含まれている場合の方が検定データ セットの誤差が小さくなっています。

線形項と交互作用項の両方を含めることにより得た結果と線形項のみを含めることにより得た結果を比較します。観測された応答と予測された応答を格納する table を作成します。table の最初の 8 行を表示します。

t = table(tbl_test.SALEPRICE,yFit,yFit_nointeraction, ...
    'VariableNames',{'Observed Value','Predicted Response','Predicted Response Without Interactions'});
head(t)
ans=8×3 table
    Observed Value    Predicted Response    Predicted Response Without Interactions
    ______________    __________________    _______________________________________

         3.6e+05          4.9812e+05                      5.2712e+05               
         1.8e+05          2.7349e+05                      2.7415e+05               
         1.9e+05          3.3682e+05                      3.3748e+05               
        4.26e+05            6.15e+05                      5.6542e+05               
        3.91e+05          3.1262e+05                      3.1328e+05               
         2.3e+05          1.0606e+05                      1.0672e+05               
      4.7333e+05          1.0773e+06                      1.1399e+06               
           2e+05          2.9506e+05                       3.305e+05               

予測の解釈

関数plotLocalEffectsを使用して、最初の検定観測値についての予測を解釈します。また、関数plotPartialDependenceを使用して、モデル内のいくつかの重要な項の部分依存プロットを作成します。

検定データの最初の観測値に対する応答値を予測し、予測に対する CMdl 内の項の局所的効果をプロットします。プロットに切片項を含めるには、'IncludeIntercept',true を指定します。

yFit = predict(CMdl,tbl_test(1,:))
yFit = 4.9812e+05
plotLocalEffects(CMdl,tbl_test(1,:),'IncludeIntercept',true)

関数 predict で、最初の観測値 tbl_test(1,:) の売価を返します。関数 plotLocalEffects で、予測に対する CMdl 内の項の局所的効果を示す横棒グラフを作成します。局所的効果の各値は、tbl_test(1,:) の予測売価への各項の寄与を示します。

BUILDINGCLASSCATEGORY の部分従属の値を計算し、並べ替えられた値をプロットします。学習データ セットと検定データ セットの両方を指定して、両方のセットの部分依存の値を計算します。

[pd,x,y] = partialDependence(CMdl,'BUILDINGCLASSCATEGORY',[tbl_training; tbl_test]);
[pd_sorted,I] = sort(pd);
x_sorted = x(I);
x_sorted = reordercats(x_sorted,I);
figure
plot(x_sorted,pd_sorted,'o:')
xlabel('BUILDINGCLASSCATEGORY')
ylabel('SALEPRICE')
title('Patial Dependence Plot')

プロットされたラインは、学習済みモデルにおける予測子 BUILDINGCLASSCATEGORY と応答 SALEPRICE の間の平均化された部分関係性を表します。

RESIDENTIALUNITSLANDSQUAREFEET の部分依存プロットを作成します。

figure
plotPartialDependence(CMdl,["RESIDENTIALUNITS","LANDSQUAREFEET"],[tbl_training; tbl_test])

x 軸 (RESIDENTIALUNITS) と y 軸 (LANDSQUAREFEET) 小目盛りは、指定されたデータ内の一意の予測子値を表します。予測子値にはいくつかの外れ値があり、RESIDENTIALUNITSLANDSQUAREFEET の値のほとんどはそれぞれ 10 未満と 50,000 未満です。プロットは、RESIDENTIALUNITSLANDSQUAREFEET の値が 10 と 50,000 を超えると、SALEPRICE の値が大きくは変わらないことを示しています。

参考

| | | | | |

関連するトピック