Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

バイナリ分類用の一般化加法モデルの学習

この例では、最適なパラメーターで一般化加法モデル (GAM) に学習させる方法と、学習済みモデルの予測性能を評価する方法を示します。この例では、最初に一変量の GAM に最適なパラメーター値 (線形項のパラメーター) を特定し、次に二変量の GAM の値 (交互作用項のパラメーター) を特定します。また、この例では、特定の予測に対する項の局所的効果を調べて、予測子に対する予測の部分依存を計算することで、学習済みモデルを解釈する方法についても説明します。

標本データの読み込み

census1994.mat に保存されている 1994 年の国勢調査データを読み込みます。このデータセットは、個人の年収が $50,000 を超えるかどうかを予測するための、米国勢調査局の人口統計データから構成されます。この分類タスクでは、年齢、労働階級、教育レベル、婚姻区分、人種などが与えられた人の給与カテゴリを予測するモデルを近似します。

load census1994

census1994 には学習データセット adultdata および検定データセット adulttest が含まれています。この例では、実行時間を短縮するために、関数datasampleを使用して 500 の学習観測値と 500 の検定観測値をサブサンプリングします。

rng('default')
NumSamples = 5e2;
adultdata = datasample(adultdata,NumSamples,'Replace',false);
adulttest = datasample(adulttest,NumSamples,'Replace',false);

一変量の GAM に最適なパラメーターの特定

関数bayesoptを使用して、一変量の GAM のパラメーターを交差検証に関して最適化します。

optimizableVariableオブジェクトを一変量の GAM の名前と値の引数 MaxNumSplitsPerPredictorNumTreesPerPredictor および InitialLearnRateForPredictors 用に準備します。

maxNumSplitsPerPredictor = optimizableVariable('maxNumSplitsPerPredictor',[1,10],'Type','integer');
numTreesPerPredictor = optimizableVariable('numTreesPerPredictor',[1,500],'Type','integer');
initialLearnRateForPredictors = optimizableVariable('initialLearnRateForPredictors',[1e-3,1],'Type','real');

入力として z = [maxNumSplitsPerPredictor,numTreesPerPredictor,initialLearnRateForPredictors] を受け入れ z のパラメーターにおける交差検証損失値を返す目的関数を作成します。

minfun1 = @(z)kfoldLoss(fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',z.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',z.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',z.numTreesPerPredictor));

交差検証オプション 'CrossVal','on' を指定した場合、関数 fitcgam は交差検証済みモデル オブジェクトClassificationPartitionedGAMを返します。関数kfoldLossは、交差検証済みモデルで取得した分類損失を返します。そのため、関数ハンドル minfun は、z のパラメーターで交差検証損失を計算します。

bayesopt を使用して最良のパラメーターを求めます。再現性を得るために、'expected-improvement-plus' の獲得関数を選択します。既定の獲得関数は実行時に決定されるので、結果が異なる場合があります。

results1 = bayesopt(minfun1, ...
    [initialLearnRateForPredictors,maxNumSplitsPerPredictor,numTreesPerPredictor], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|    1 | Best   |     0.18549 |      5.6957 |     0.18549 |     0.18549 |      0.73503 |            7 |           99 |
|    2 | Accept |     0.19145 |      20.383 |     0.18549 |     0.18549 |      0.72917 |           10 |          399 |
|    3 | Best   |     0.17703 |      13.412 |     0.17703 |     0.17703 |     0.079299 |            8 |          267 |
|    4 | Best   |     0.14955 |       0.402 |     0.14955 |     0.14955 |      0.24236 |            4 |            3 |
|    5 | Accept |     0.15999 |      12.363 |     0.14955 |     0.14955 |      0.25509 |            1 |          377 |
|    6 | Accept |     0.15158 |      1.5035 |     0.14955 |     0.14955 |      0.23051 |            7 |           29 |
|    7 | Accept |     0.16181 |     0.18204 |     0.14955 |     0.14955 |      0.34396 |            4 |            1 |
|    8 | Accept |     0.15079 |     0.38418 |     0.14955 |     0.14955 |      0.26669 |           10 |            5 |
|    9 | Accept |     0.16102 |     0.55525 |     0.14955 |     0.14955 |      0.26065 |            2 |           10 |
|   10 | Accept |     0.19259 |      8.6487 |     0.14955 |     0.14955 |      0.24894 |           10 |          182 |
|   11 | Accept |     0.18628 |     0.20681 |     0.14955 |     0.14955 |      0.13389 |            6 |            2 |
|   12 | Accept |     0.15653 |     0.24643 |     0.14955 |     0.14955 |      0.24172 |           10 |            2 |
|   13 | Best   |     0.14699 |     0.82743 |     0.14699 |     0.14699 |      0.26745 |            7 |           12 |
|   14 | Best   |     0.14634 |     0.47528 |     0.14634 |     0.14634 |      0.25025 |            6 |            6 |
|   15 | Best   |     0.14312 |     0.34493 |     0.14312 |     0.14312 |      0.30452 |            9 |            3 |
|   16 | Accept |     0.14334 |     0.51583 |     0.14312 |     0.14312 |      0.33507 |           10 |            7 |
|   17 | Best   |     0.13791 |     0.32248 |     0.13791 |     0.13791 |      0.33179 |            9 |            4 |
|   18 | Accept |     0.14875 |      0.3551 |     0.13791 |     0.13791 |      0.36806 |            8 |            5 |
|   19 | Accept |      0.1651 |      1.3731 |     0.13791 |     0.13791 |      0.32691 |            8 |           27 |
|   20 | Accept |     0.15895 |     0.37324 |     0.13791 |     0.13791 |      0.32985 |            7 |            5 |
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|   21 | Accept |     0.13946 |     0.26793 |     0.13791 |     0.13791 |      0.36721 |            9 |            3 |
|   22 | Accept |     0.16719 |      1.1276 |     0.13791 |     0.13791 |      0.25385 |            5 |           23 |
|   23 | Accept |     0.17017 |        1.35 |     0.13791 |     0.13791 |      0.23809 |            9 |           26 |
|   24 | Accept |     0.15519 |     0.46246 |     0.13791 |     0.13791 |      0.34831 |            9 |            7 |
|   25 | Accept |     0.15312 |     0.26445 |     0.13791 |     0.13791 |      0.33416 |           10 |            3 |
|   26 | Accept |     0.15852 |     0.31045 |     0.13791 |     0.13791 |       0.6142 |            9 |            4 |
|   27 | Accept |     0.16691 |     0.50559 |     0.13791 |     0.13791 |      0.31446 |            5 |            7 |
|   28 | Accept |     0.14384 |     0.35136 |     0.13791 |     0.13791 |      0.40215 |            9 |            4 |
|   29 | Accept |     0.14773 |     0.33296 |     0.13791 |     0.13791 |      0.34255 |            9 |            4 |
|   30 | Accept |     0.17604 |     0.85847 |     0.13791 |     0.13791 |      0.36565 |            6 |           15 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 97.6656 seconds
Total objective function evaluation time: 74.4022

Best observed feasible point:
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                         4          

Observed objective function value = 0.13791
Estimated objective function value = 0.13791
Function evaluation time = 0.32248

Best estimated feasible point (according to models):
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                         4          

Estimated objective function value = 0.13791
Estimated function evaluation time = 0.33084

results1 から最良の点を取得します。

zbest1 = bestPoint(results1)
zbest1=1×3 table
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                         4          

最適なパラメーターでの一変量の GAM の学習

zbest1 の値を使用して最適化された GAM に学習させます。クラス名を指定することが推奨されます。

Mdl1 = fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'ClassNames',categorical({'<=50K','>50K'}), ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor) 
Mdl1 = 
  ClassificationGAM
           PredictorNames: {'age'  'workClass'  'education'  'education_num'  'marital_status'  'occupation'  'relationship'  'race'  'sex'  'capital_gain'  'capital_loss'  'hours_per_week'  'native_country'}
             ResponseName: 'salary'
    CategoricalPredictors: [2 3 5 6 7 8 9 13]
               ClassNames: [<=50K    >50K]
           ScoreTransform: 'logit'
                Intercept: -1.7383
          NumObservations: 500


  Properties, Methods

Mdl1ClassificationGAM モデル オブジェクトです。モデル表示には、モデルのプロパティの一部のみが表示されます。モデルのプロパティの完全な一覧を表示するには、ワークスペースで変数名 Mdl1 をダブルクリックします。Mdl1 の変数エディターが開きます。あるいは、コマンド ウィンドウでドット表記を使用してプロパティを表示できます。たとえば、ReasonForTermination プロパティを表示します。

Mdl1.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: ''

プロパティ値の PredictorTrees フィールドは、Mdl1 に指定した数の木が含まれていることを示します。fitcgamNumTreesPerPredictor で予測子あたりの木の最大数を指定すると、要求された数の木に学習させる前にこの関数を停止できます。ReasonForTermination プロパティを使用して、学習済みモデルに指定した数の木が含まれているかどうかを確認できます。

交互作用項を含めて fitcgam でそれらについて木に学習させるように指定した場合、InteractionTrees フィールドには空でない値が含まれます。

二変量の GAM に最適なパラメーターの特定

関数 bayesopt を使用して、二変量の GAM の交互作用項のパラメーターを特定します。

optimizableVariable オブジェクトを交互作用項の名前と値の引数 InitialLearnRateForInteractionsMaxNumSplitsPerInteractionNumTreesPerInteraction および InitialLearnRateForInteractions 用に準備します。

initialLearnRateForInteractions = optimizableVariable('initialLearnRateForInteractions',[1e-3,1],'Type','real');
maxNumSplitsPerInteraction = optimizableVariable('maxNumSplitsPerInteraction',[1,10],'Type','integer');
numTreesPerInteraction = optimizableVariable('numTreesPerInteraction',[1,500],'Type','integer');
numInteractions = optimizableVariable('numInteractions',[1,28],'Type','integer');

最適化のための目的関数を作成します。zbest1 の最適なパラメーター値を使用して、交互作用項に最適なパラメーター値が zbest1 の値に基づいて特定されるようにします。

minfun2 = @(z)kfoldLoss(fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',z.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',z.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',z.numTreesPerInteraction, ...
    'Interactions',z.numInteractions));

bayesopt を使用して最良のパラメーターを求めます。最適化プロセスによって複数のモデルの学習が行われ、モデルに交互作用項が含まれていない場合は警告メッセージが表示されます。bayesopt を呼び出す前にすべての警告を無効にして、bayesopt の実行後に警告状態を復元します。警告状態を変更しないままにして警告メッセージを確認できます。

orig_state = warning('query'); 
warning('off')
results2 = bayesopt(minfun2, ...
    [initialLearnRateForInteractions,maxNumSplitsPerInteraction,numTreesPerInteraction,numInteractions], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|    1 | Best   |     0.19671 |      10.999 |     0.19671 |     0.19671 |      0.96444 |            8 |          109 |           22 |
|    2 | Best   |       0.189 |       30.57 |       0.189 |       0.189 |      0.98548 |            6 |          457 |           17 |
|    3 | Best   |     0.16538 |      18.643 |     0.16538 |     0.16538 |      0.28678 |            4 |          383 |           13 |
|    4 | Best   |     0.15243 |      0.4285 |     0.15243 |     0.15243 |      0.28044 |            1 |           45 |            3 |
|    5 | Accept |     0.16065 |     0.69005 |     0.15243 |     0.15243 |      0.20151 |            7 |           60 |            1 |
|    6 | Best   |     0.14831 |     0.36629 |     0.14831 |     0.14831 |     0.032423 |            1 |          151 |            1 |
|    7 | Accept |     0.14887 |     0.36443 |     0.14831 |     0.14831 |     0.021093 |            1 |           15 |            1 |
|    8 | Accept |     0.15039 |     0.42139 |     0.14831 |     0.14831 |     0.012128 |            2 |          482 |            1 |
|    9 | Best   |     0.14787 |     0.42482 |     0.14787 |     0.14787 |      0.10119 |            1 |          121 |            6 |
|   10 | Best   |     0.13902 |     0.38822 |     0.13902 |     0.13902 |       0.1233 |            1 |          281 |            3 |
|   11 | Accept |     0.14721 |     0.39532 |     0.13902 |     0.13902 |     0.065618 |            1 |          291 |            3 |
|   12 | Accept |     0.14586 |     0.39205 |     0.13902 |     0.13902 |      0.18711 |            1 |          117 |            1 |
|   13 | Accept |     0.15073 |       0.383 |     0.13902 |     0.13902 |      0.15072 |            1 |           15 |            3 |
|   14 | Accept |     0.14966 |     0.42744 |     0.13902 |     0.13902 |      0.17155 |            1 |          497 |            4 |
|   15 | Best   |     0.13716 |     0.37599 |     0.13716 |     0.13716 |      0.12601 |            1 |          281 |            1 |
|   16 | Accept |     0.15094 |     0.38197 |     0.13716 |     0.13716 |      0.13962 |            2 |          284 |            1 |
|   17 | Accept |     0.13972 |      4.5994 |     0.13716 |     0.13716 |    0.0028545 |            5 |          481 |            2 |
|   18 | Accept |     0.14788 |      31.639 |     0.13716 |     0.13716 |    0.0024433 |            6 |          489 |           15 |
|   19 | Accept |     0.14565 |       1.276 |     0.13716 |     0.13716 |     0.013118 |            5 |          257 |            1 |
|   20 | Accept |     0.16502 |      28.315 |     0.13716 |     0.13716 |    0.0063353 |            4 |          457 |           16 |
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|   21 | Accept |     0.15693 |      4.9653 |     0.13716 |     0.13716 |     0.016486 |            6 |          466 |            2 |
|   22 | Accept |     0.16312 |      29.942 |     0.13716 |     0.13716 |     0.019904 |            5 |          488 |           15 |
|   23 | Accept |     0.15719 |      4.7423 |     0.13716 |     0.13716 |     0.020155 |            4 |          456 |            3 |
|   24 | Best   |       0.129 |      6.4419 |       0.129 |       0.129 |     0.090858 |            5 |          478 |            3 |
|   25 | Accept |     0.15118 |      6.6757 |       0.129 |       0.129 |      0.15943 |            5 |          494 |            3 |
|   26 | Accept |     0.15343 |      2.2035 |       0.129 |       0.129 |     0.070349 |            5 |          489 |            1 |
|   27 | Best   |     0.12879 |      6.8017 |     0.12879 |     0.12879 |     0.091985 |            5 |          387 |            4 |
|   28 | Accept |     0.19093 |      5.9262 |     0.12879 |     0.12879 |     0.067405 |            5 |          331 |            4 |
|   29 | Accept |     0.16767 |      6.3779 |     0.12879 |     0.12879 |      0.31419 |            5 |          472 |            3 |
|   30 | Accept |     0.17636 |      11.026 |     0.12879 |     0.12879 |     0.054697 |            5 |          383 |            7 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 239.1035 seconds
Total objective function evaluation time: 216.5833

Best observed feasible point:
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                          387                     4       

Observed objective function value = 0.12879
Estimated objective function value = 0.12879
Function evaluation time = 6.8017

Best estimated feasible point (according to models):
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                          387                     4       

Estimated objective function value = 0.12879
Estimated function evaluation time = 6.7245
warning(orig_state)

results2 から最良の点を取得します。

zbest2 = bestPoint(results2)
zbest2=1×4 table
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                          387                     4       

最適なパラメーターでの二変量の GAM の学習

zbest1zbest2 の値を使用して最適化された GAM に学習させます。

Mdl = fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'ClassNames',categorical({'<=50K','>50K'}), ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction, ...   
    'Interactions',zbest2.numInteractions) 
Mdl = 
  ClassificationGAM
           PredictorNames: {'age'  'workClass'  'education'  'education_num'  'marital_status'  'occupation'  'relationship'  'race'  'sex'  'capital_gain'  'capital_loss'  'hours_per_week'  'native_country'}
             ResponseName: 'salary'
    CategoricalPredictors: [2 3 5 6 7 8 9 13]
               ClassNames: [<=50K    >50K]
           ScoreTransform: 'logit'
                Intercept: -1.7755
             Interactions: [4×2 double]
          NumObservations: 500


  Properties, Methods

または、関数 addInteractions を使用して一変量の GAM に交互作用項を追加できます。

Mdl2 = addInteractions(Mdl1,zbest2.numInteractions, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction); 

2 番目の入力引数では交互作用項の最大数を指定し、名前と値の引数 NumTreesPerInteraction では交互作用項あたりの木の最大数を指定します。関数 addInteractions に含める交互作用項の数を減らすと、要求された数の木に学習させる前にこの関数を停止できます。Interactions プロパティと ReasonForTermination プロパティを使用して、学習済みモデルにある実際の交互作用項の数と木の数を確認できます。

Mdl の交互作用項を表示します。

Mdl.Interactions
ans = 4×2

     7    10
     4     7
     7     9
     5    10

Interactions の各行は 1 つの交互作用項を表し、交互作用項の予測子変数の列インデックスを格納します。Interactions プロパティを使用して、モデル内の交互作用項とそれらが fitcgam でモデルに追加された順序を確認できます。

予測子名を使用して Mdl の交互作用項を表示します。

Mdl.PredictorNames(Mdl.Interactions)
ans = 4×2 cell
    {'relationship'  }    {'capital_gain'}
    {'education_num' }    {'relationship'}
    {'relationship'  }    {'sex'         }
    {'marital_status'}    {'capital_gain'}

終了の理由を表示して、指定した数の木がモデルに含まれているかどうかを線形項と交互作用項のそれぞれについて確認します。

Mdl.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: 'Terminated after training the requested number of trees.'

新しい観測値での予測性能の評価

検定標本 adulttest とオブジェクト関数 predictlossedge および margin を使用して、学習済みモデルの性能を評価します。これらの関数をもつ完全またはコンパクトなモデルを使用できます。

  • predict— 観測値の分類

  • loss— 分類損失の計算 (既定では 10 進数の誤分類率)

  • margin— 分類マージンの計算

  • edge— 分類エッジ (分類マージンの平均) の計算

学習データ セットの性能を評価するには、再代入オブジェクト関数resubPredictresubLossresubMarginおよびresubEdgeを使用します。これらの関数を使用するには、学習データを含む完全なモデルを使用しなければなりません。

コンパクトなモデルを作成して、学習済みモデルのサイズを縮小します。

CMdl = compact(Mdl);
whos('Mdl','CMdl')
  Name      Size              Bytes  Class                                                 Attributes

  CMdl      1x1             3272176  classreg.learning.classif.CompactClassificationGAM              
  Mdl       1x1             3389515  ClassificationGAM                                               

検定データ セット (adulttest) のラベルとスコアを予測し、検定データ セットを使用してモデルの統計量 (損失、マージンおよびエッジ) を計算します。

[labels,scores] = predict(CMdl,adulttest);
L = loss(CMdl,adulttest,'Weights',adulttest.fnlwgt);
M = margin(CMdl,adulttest);
E = edge(CMdl,adulttest,'Weights',adulttest.fnlwgt);

学習済みモデルに交互作用項を含めずにラベルとスコアを予測し、統計量を計算します。

[labels_nointeraction,scores_nointeraction] = predict(CMdl,adulttest,'IncludeInteractions',false);
L_nointeractions = loss(CMdl,adulttest,'Weights',adulttest.fnlwgt,'IncludeInteractions',false);
M_nointeractions = margin(CMdl,adulttest,'IncludeInteractions',false);
E_nointeractions = edge(CMdl,adulttest,'Weights',adulttest.fnlwgt,'IncludeInteractions',false);

線形項と交互作用項の両方を含めることにより得た結果と線形項のみを含めることにより得た結果を比較します。

観測されたラベル、予測ラベルおよびスコアが格納されている table を作成します。table の最初の 8 行を表示します。

t = table(adulttest.salary,labels,scores,labels_nointeraction,scores_nointeraction, ...
    'VariableNames',{'True Labels','Predicted Labels','Scores' ...
    'Predicted Labels without interactions','Scores without interactions'});
head(t)
ans=8×5 table
    True Labels    Predicted Labels           Scores            Predicted Labels without interactions    Scores without interactions
    ___________    ________________    _____________________    _____________________________________    ___________________________

       <=50K            <=50K          0.97921      0.020787                    <=50K                       0.98005     0.019951    
       <=50K            <=50K                1     8.258e-17                    <=50K                        0.9713     0.028696    
       <=50K            <=50K                1    1.8297e-19                    <=50K                       0.99449    0.0055054    
       <=50K            <=50K          0.87422       0.12578                    <=50K                       0.87729      0.12271    
       <=50K            <=50K                1    3.5643e-07                    <=50K                       0.99882    0.0011769    
       <=50K            <=50K          0.60371       0.39629                    <=50K                       0.77861      0.22139    
       <=50K            >50K           0.49917       0.50083                    >50K                        0.46877      0.53123    
       >50K             >50K            0.3109        0.6891                    <=50K                       0.53571      0.46429    

真のラベル adulttest.salary と予測ラベルから混同チャートを作成します。

tiledlayout(1,2);
nexttile
confusionchart(adulttest.salary,labels)
title('Linear and Interaction Terms')
nexttile
confusionchart(adulttest.salary,labels_nointeraction)
title('Linear Terms Only')

計算された損失とエッジの値を表示します。

table([L; E], [L_nointeractions; E_nointeractions], ...
    'VariableNames',{'Linear and Interaction Terms','Only Linear Terms'}, ...
    'RowNames',{'Loss','Edge'})
ans=2×2 table
            Linear and Interaction Terms    Only Linear Terms
            ____________________________    _________________

    Loss              0.14868                    0.13852     
    Edge              0.63926                    0.58405     

線形項のみが含まれている場合の方がモデルの損失が小さくなっていますが、線形項と交互作用項の両方が含まれている場合の方がエッジの値が大きくなっています。

箱ひげ図を使用してマージンの分布を表示します。

figure
boxplot([M M_nointeractions],'Labels',{'Linear and Interaction Terms','Linear Terms Only'})
title('Box Plots of Test Sample Margins')

予測の解釈

関数plotLocalEffectsを使用して、最初の検定観測値についての予測を解釈します。また、関数plotPartialDependenceを使用して、モデル内のいくつかの重要な項の部分依存プロットを作成します。

検定データの最初の観測値を分類し、予測に対する CMdl 内の項の局所的効果をプロットします。予測子名に含まれるアンダースコアを表示するには、座標軸の TickLabelInterpreter 値を 'none' に変更します。

label = predict(CMdl,adulttest(1,:))
label = categorical
     <=50K 

f1 = figure;
plotLocalEffects(CMdl,adulttest(1,:))
f1.CurrentAxes.TickLabelInterpreter = 'none';

関数 predict で、最初の観測値 adulttest(1,:)'<=50K' として分類します。関数 plotLocalEffects で、予測に対する上位 10 個の重要な項の局所的効果を示す横棒グラフを作成します。局所的効果の各値は、'<=50K' の分類スコアへの各項の寄与を示します。これは、観測値の分類が '<=50K' となる事後確率のロジットです。

age の部分依存プロットを作成します。学習データ セットと検定データ セットの両方を指定して、両方のセットの部分依存の値を計算します。

figure
plotPartialDependence(CMdl,'age',label,[adultdata; adulttest])

プロットされたラインは、学習済みモデルにおける予測子 age とクラスのスコア <=50K の間の平均化された部分関係性を表します。x 軸の小目盛りは予測子 age の一意の値を表します。

education_num および relationship の部分依存プロットを作成します。

f2 = figure;
plotPartialDependence(CMdl,["education_num","relationship"],label,[adultdata; adulttest])
f2.CurrentAxes.TickLabelInterpreter = 'none';

プロットには、education_num に対する部分依存が示され、relationship の値に応じてトレンドが異なっていることがわかります。

参考

| | | | | |

関連するトピック