Main Content

スタッキング アンサンブルへの異種混合モデルの結合

この例では、特定の学習データ セットに対して複数の機械学習モデルをビルドし、"スタッキング" と呼ばれる手法を使用してモデルを結合し、個々のモデルの精度に対するテスト データ セットの精度を向上させる方法を示します。

スタッキングは、複数の異種混合モデルを組み合わせるために使用される手法で、しばしば "スタッキング アンサンブル モデル" または "スタッキング学習器" とも呼ばれます。元の (ベース) モデルの k 分割交差検証された予測 (分類モデルの分類スコアと回帰モデルの予測応答) で、追加のモデルに学習させることによってそれを行います。スタッキングの背後にある考え方は、特定のモデルがテスト観測値を正しく分類する一方で、他のモデルはそうしない場合があるというものです。アルゴリズムはこの多様な予測から学習を行い、モデルを組み合わせてベース モデルの予測精度を向上させようとします。

この例では、1 つのデータ セットで、いくつかの異種混合分類モデルに学習させ、スタッキングを使用してモデルを結合します。

標本データの読み込み

この例では census1994.mat に保存されている 1994 年の国勢調査データを使用します。このデータ セットは、個人の年収が $50,000 を超えるかどうかを予測するための、米国勢調査局の人口統計データから構成されます。この分類タスクでは、年齢、労働階級、教育レベル、婚姻区分、人種などが与えられた人の給与カテゴリを予測するモデルを当てはめます。

標本データ census1994 を読み込み、データ セットの変数を表示します。

load census1994
whos
  Name                 Size              Bytes  Class    Attributes

  Description         20x74               2960  char               
  adultdata        32561x15            1872567  table              
  adulttest        16281x15             944467  table              

census1994 には学習データ セット adultdata およびテスト データ セット adulttest が含まれています。この例では、実行時間を短縮するために、関数 datasample を使用して元の table adultdata および adulttest からそれぞれ 5000 の学習観測値とテスト観測値をサブサンプリングします (完全なデータ セットを使用する場合は、この手順を省略できます)。

NumSamples = 5e3;
s = RandStream('mlfg6331_64','seed',0); % For reproducibility
adultdata = datasample(s,adultdata,NumSamples,'Replace',false);
adulttest = datasample(s,adulttest,NumSamples,'Replace',false);

サポート ベクター マシン (SVM) などの一部のモデルは欠損値を含む観測値を削除しますが、決定木などの他のモデルはそのような観測値を削除しません。モデル間の整合性を維持するには、モデルを当てはめる前に欠損値を含む行を削除します。

adultdata = rmmissing(adultdata);
adulttest = rmmissing(adulttest);

学習データ セットの最初の数行をプレビューします。

head(adultdata)
ans=8×15 table
    age     workClass       fnlwgt       education      education_num      marital_status         occupation         relationship     race      sex      capital_gain    capital_loss    hours_per_week    native_country    salary
    ___    ___________    __________    ____________    _____________    __________________    _________________    ______________    _____    ______    ____________    ____________    ______________    ______________    ______

    39     Private          4.91e+05    Bachelors            13          Never-married         Exec-managerial      Other-relative    Black    Male           0               0                45          United-States     <=50K 
    25     Private        2.2022e+05    11th                  7          Never-married         Handlers-cleaners    Own-child         White    Male           0               0                45          United-States     <=50K 
    24     Private        2.2761e+05    10th                  6          Divorced              Handlers-cleaners    Unmarried         White    Female         0               0                58          United-States     <=50K 
    51     Private        1.7329e+05    HS-grad               9          Divorced              Other-service        Not-in-family     White    Female         0               0                40          United-States     <=50K 
    54     Private        2.8029e+05    Some-college         10          Married-civ-spouse    Sales                Husband           White    Male           0               0                32          United-States     <=50K 
    53     Federal-gov         39643    HS-grad               9          Widowed               Exec-managerial      Not-in-family     White    Female         0               0                58          United-States     <=50K 
    52     Private             81859    HS-grad               9          Married-civ-spouse    Machine-op-inspct    Husband           White    Male           0               0                48          United-States     >50K  
    37     Private        1.2429e+05    Some-college         10          Married-civ-spouse    Adm-clerical         Husband           White    Male           0               0                50          United-States     <=50K 

各行は、年齢、教育、職業など、成人 1 人の属性を表します。最後の列 salary は個人の年収が $50,000 以下か、$50,000 を超えるかどうかを示します。

データの理解および分類モデルの選択

Statistics and Machine Learning Toolbox™ には、分類木、判別分析、単純ベイズ、最近傍、SVM、アンサンブル分類を含む、分類用の複数のオプションが用意されています。アルゴリズムの完全なリストについては、分類を参照してください。

問題に使用するアルゴリズムを選択する前に、データ セットを検査します。国勢調査データには注目すべき複数の特性があります。

  • データは表形式であり、数値変数とカテゴリカル変数が両方含まれています。

  • データには欠損値が含まれています。

  • 応答変数 (salary) には、2 つのクラス (バイナリ分類) があります。

何かを仮定したり、データで十分に機能することが予測されるアルゴリズムの事前知識を使用しないで、単純に表形式のデータとバイナリ分類をサポートするすべてのアルゴリズムに学習させます。誤り訂正出力符号 (ECOC) モデルは 3 つ以上のクラスがあるデータで使用されます。判別分析アルゴリズムおよび最近傍アルゴリズムは数値変数とカテゴリカル変数の両方が含まれるデータを解析しません。したがって、この例に適したアルゴリズムは、SVM、決定木、決定木のアンサンブル、および単純ベイズ モデルです。

ベース モデルの構築

2 つの SVM モデルを、一方はガウス カーネル、もう一方は多項式カーネルを使用して当てはめます。さらに、決定木、単純ベイズ モデル、および決定木のアンサンブルを当てはめます。

% SVM with Gaussian kernel
rng('default') % For reproducibility
mdls{1} = fitcsvm(adultdata,'salary','KernelFunction','gaussian', ...
    'Standardize',true,'KernelScale','auto');

% SVM with polynomial kernel
rng('default')
mdls{2} = fitcsvm(adultdata,'salary','KernelFunction','polynomial', ...
    'Standardize',true,'KernelScale','auto');

% Decision tree
rng('default')
mdls{3} = fitctree(adultdata,'salary');

% Naive Bayes
rng('default')
mdls{4} = fitcnb(adultdata,'salary');

% Ensemble of decision trees
rng('default')
mdls{5} = fitcensemble(adultdata,'salary');

スタッキングを使用したモデルの結合

学習データでベース モデルの予測スコアのみを使用すると、スタッキング アンサンブルが過適合となる可能性があります。過適合を減らすには、代わりに k 分割交差検証されたスコアを使用します。確実に同じ k 分割のデータ分割を使用して各モデルに学習させるには、cvpartitionオブジェクトを作成し、そのオブジェクトを各ベース モデルの関数 crossval に渡します。この例はバイナリ分類問題なので、考慮する必要があるのは陽性または陰性いずれかのクラスのスコアのみです。

k 分割交差検証スコアを取得します。

rng('default') % For reproducibility
N = numel(mdls);
Scores = zeros(size(adultdata,1),N);
cv = cvpartition(adultdata.salary,"KFold",5);
for ii = 1:N
    m = crossval(mdls{ii},'cvpartition',cv);
    [~,s] = kfoldPredict(m);
    Scores(:,ii) = s(:,m.ClassNames=='<=50K');
end

次のオプションで交差検証された分類スコア Scores で学習を行って、スタッキング アンサンブルを作成します。

  • スタッキング アンサンブルに対して最良の結果を得るには、そのハイパーパラメーターを最適化します。近似関数を呼び出し、その名前と値のペアの引数 'OptimizeHyperparameters''auto' に設定することで、学習データ セットを近似し、パラメーターを簡単に調整できます。

  • 'Verbose' を 0 に指定して、メッセージ表示を無効にします。

  • 再現性を得るために、乱数シードを設定し、'expected-improvement-plus' の獲得関数を使用します。また、ランダム フォレスト アルゴリズムの再現性を得るため、木学習器の名前と値のペアの引数 'Reproducible'true に指定します。

rng('default') % For reproducibility
t = templateTree('Reproducible',true);
stckdMdl = fitcensemble(Scores,adultdata.salary, ...
    'OptimizeHyperparameters','auto', ...
    'Learners',t, ...
    'HyperparameterOptimizationOptions',struct('Verbose',0,'AcquisitionFunctionName','expected-improvement-plus'));

予測精度の比較

混同行列およびマクネマーの仮説検定を使用して、テスト データ セットで分類器の性能をチェックします。

テスト データのラベルとスコアの予測

ベース モデルおよびスタッキング アンサンブルに対するテスト データ セットの予測ラベル、スコア、損失値を求めます。

まず、ベース モデルに対して反復処理を行い、予測ラベル、スコア、および損失値を計算します。

label = [];
score = zeros(size(adulttest,1),N);
mdlLoss = zeros(1,numel(mdls));
for i = 1:N
    [lbl,s] = predict(mdls{i},adulttest);
    label = [label,lbl];
    score(:,i) = s(:,m.ClassNames=='<=50K');
    mdlLoss(i) = mdls{i}.loss(adulttest);
end

スタッキング アンサンブルからの予測を label および mdlLoss に追加します。

[lbl,s] = predict(stckdMdl,score);
label = [label,lbl];
mdlLoss(end+1) = stckdMdl.loss(score,adulttest.salary);

スタッキング アンサンブルのスコアをベース モデルのスコアに連結します。

score = [score,s(:,1)];

損失値を表示します。

names = {'SVM-Gaussian','SVM-Polynomial','Decision Tree','Naive Bayes', ...
    'Ensemble of Decision Trees','Stacked Ensemble'};
array2table(mdlLoss,'VariableNames',names)
ans=1×6 table
    SVM-Gaussian    SVM-Polynomial    Decision Tree    Naive Bayes    Ensemble of Decision Trees    Stacked Ensemble
    ____________    ______________    _____________    ___________    __________________________    ________________

      0.15668          0.17473           0.1975          0.16764               0.15833                  0.14519     

スタッキング アンサンブルの損失値は、ベース モデルの損失値よりも低くなっています。

混同行列

関数confusionchartを使用して、テスト データ セットの予測したクラスおよび既知の (true) クラスをもつ混同行列を計算します。

figure
c = cell(N+1,1);
for i = 1:numel(c)
    subplot(2,3,i)
    c{i} = confusionchart(adulttest.salary,label(:,i));
    title(names{i})
end

対角要素は、特定のクラスの正しく分類されたインスタンスの数を示しています。非対角要素は誤分類した観測値のインスタンスです。

マクネマーの仮説検定

予測の改善が有意であるかどうかをテストするには、マクネマーの仮説検定を実行する関数testcholdoutを使用します。スタッキング アンサンブルと単純ベイズ モデルを比較します。

 [hNB,pNB] = testcholdout(label(:,6),label(:,4),adulttest.salary)
hNB = logical
   1

pNB = 9.7646e-07

スタッキング アンサンブルと決定木のアンサンブルを比較します。

 [hE,pE] = testcholdout(label(:,6),label(:,5),adulttest.salary)
hE = logical
   1

pE = 1.9357e-04

いずれの場合も、スタッキング アンサンブルの p 値が低いことは、その予測が他のモデルの予測よりも統計的に優れていることを確証しています。