Main Content

RobustBoost の調整

RobustBoost アルゴリズムは、学習データにノイズが存在する場合でも、優れた分類予測を実行できます。ただし、RobustBoost パラメーターを既定の設定で使用すると、予測精度が十分ではないアンサンブルが生成される可能性があります。この例では、予測精度を高めるためにパラメーターを調整する方法の 1 つを示します。

ラベル ノイズのあるデータの生成。この例では、観測値あたり 20 の一様乱数を使用しており、最初の 5 つの数字の合計が 2.5 を超えた場合 (つまり、平均より大きい場合) には、観測値を 1 に分類し、それ以外の場合には、0 に分類します。

rng(0,'twister') % for reproducibility
Xtrain = rand(2000,20);
Ytrain = sum(Xtrain(:,1:5),2) > 2.5;

ノイズを追加するために、分類の 10% をランダムに入れ替えます。

idx = randsample(2000,200);
Ytrain(idx) = ~Ytrain(idx);

比較の対象とするために、AdaBoostM1 を使用してアンサンブルを作成します。

ada = fitcensemble(Xtrain,Ytrain,'Method','AdaBoostM1', ...
    'NumLearningCycles',300,'Learners','Tree','LearnRate',0.1);

RobustBoost でアンサンブルを作成します。データの 10% は間違って分類されているため、この場合、15% という誤差の目標値は適切です。

rb1 = fitcensemble(Xtrain,Ytrain,'Method','RobustBoost', ...
    'NumLearningCycles',300,'Learners','Tree','RobustErrorGoal',0.15, ...
    'RobustMaxMargin',1);

誤差の目標値がかなり高い値に設定されている場合、エラーが返されることに注意してください。

非常に楽観的な誤差の目標値 0.01 をもつアンサンブルを作成します。

rb2 = fitcensemble(Xtrain,Ytrain,'Method','RobustBoost', ...
    'NumLearningCycles',300,'Learners','Tree','RobustErrorGoal',0.01);

3 つのアンサンブルの再代入誤差を比較します。

figure
plot(resubLoss(rb1,'Mode','Cumulative'));
hold on
plot(resubLoss(rb2,'Mode','Cumulative'),'r--');
plot(resubLoss(ada,'Mode','Cumulative'),'g.');
hold off;
xlabel('Number of trees');
ylabel('Resubstitution error');
legend('ErrorGoal=0.15','ErrorGoal=0.01',...
    'AdaBoostM1','Location','NE');

Figure contains an axes object. The axes object with xlabel Number of trees, ylabel Resubstitution error contains 3 objects of type line. One or more of the lines displays its values using only markers These objects represent ErrorGoal=0.15, ErrorGoal=0.01, AdaBoostM1.

RobustBoost の曲線はすべて、AdaBoostM1 の曲線より低い再代入誤差を示しています。誤差の目標値が 0.01 の曲線は、ほとんどの範囲で最も低い再代入誤差を示しています。

Xtest = rand(2000,20);
Ytest = sum(Xtest(:,1:5),2) > 2.5;
idx = randsample(2000,200);
Ytest(idx) = ~Ytest(idx);
figure;
plot(loss(rb1,Xtest,Ytest,'Mode','Cumulative'));
hold on
plot(loss(rb2,Xtest,Ytest,'Mode','Cumulative'),'r--');
plot(loss(ada,Xtest,Ytest,'Mode','Cumulative'),'g.');
hold off;
xlabel('Number of trees');
ylabel('Test error');
legend('ErrorGoal=0.15','ErrorGoal=0.01',...
    'AdaBoostM1','Location','NE');

Figure contains an axes object. The axes object with xlabel Number of trees, ylabel Test error contains 3 objects of type line. One or more of the lines displays its values using only markers These objects represent ErrorGoal=0.15, ErrorGoal=0.01, AdaBoostM1.

誤差の目標値が 0.15 の誤差曲線は、プロットの範囲では最も低い値 (最も精度が高い) であることを示しています。AdaBoostM1 は、誤差の目標値 0.15 の曲線より誤差が大きくなっています。過度に楽観的な誤差の目標値 0.01 の曲線は、プロットされたほとんどの範囲において、他の曲線よりも顕著に高い誤差 (低い精度) を示しています。

参考

| |

関連するトピック