分類木および回帰木の改善
fitctree
と fitrtree
に名前と値のペアを設定することによって、ツリーを調整できます。この節の残りの部分では、木の特性の判定方法、設定する名前と値のペアの決定方法、および木のサイズの制御方法について説明します。
再代入誤差の検査
"再代入誤差" とは、応答学習データと、ツリーが入力学習データに基づいて作成した応答と予測の間の差異を意味します。再代入誤差が大きい場合には、ツリーの予測結果が良好であることを期待することはできません。ただし、再代入誤差が小さい場合であっても、新しいデータに対する予測が必ずしも優れているとは限りません。再代入誤差では、新しいデータに対する予測誤差について過度に楽観的な推定が行われる場合が少なくありません。
分類木の再代入誤差
この例では、分類木の再代入誤差を調べる方法を示します。
フィッシャーのアヤメのデータを読み込みます。
load fisheriris
データセット全体を使用して、既定の分類木に学習をさせます。
Mdl = fitctree(meas,species);
再代入誤差を調べます。
resuberror = resubLoss(Mdl)
resuberror = 0.0200
ツリーでは、フィッシャーのアヤメのデータはほとんど正確に分類されています。
交差検証
新しいデータに対するツリーの予測精度を正確に把握するには、ツリーに交差検証を実施します。既定の設定では、交差検証は学習データを無作為に 10 個の部分に分割します。10 本の新しいツリーのそれぞれにおいて、データの 9 個の部分を学習させます。その後に、それぞれの新しいツリーにおいて、そのツリーの学習に含まれないデータに関する予測精度を検査します。この方式では、新しいデータに対して新しいツリーをテストすることになるため、作成するツリーの予測精度を正確に推定できます。
回帰木の交差検証
この例では、carsmall
のデータに基づいて燃費効率を予測するための回帰木について再代入と交差検証の精度を検査する方法を示します。
carsmall
データセットを読み込みます。加速度、排気量、馬力および重量が MPG の予測子であると考えます。
load carsmall
X = [Acceleration Displacement Horsepower Weight];
すべての観測値を使用して回帰木を成長させます。
rtree = fitrtree(X,MPG);
標本内誤差を計算します。
resuberror = resubLoss(rtree)
resuberror = 4.7188
回帰木の再代入損失は平均二乗誤差です。結果の値によれば、ツリーの標準的な予測誤差は 4.7 の平方根、つまり 2 より少し大きな値になります。
交差検証の MSE を推定します。
rng 'default';
cvrtree = crossval(rtree);
cvloss = kfoldLoss(cvrtree)
cvloss = 23.5706
交差検証損失は約 25 なので、新しいデータに対するツリーの標準的な予測誤差は約 5 になります。これは、交差検証損失が通常は単純な再代入損失より高くなることを示しています。
分割予測子選択手法の選択
標準 CART アルゴリズムには、レベル数の多い連続予測子を選択する傾向があります。このような選択は見せかけだけの場合があり、カテゴリカル予測子などのような、より少ないレベルのより重要な予測子を隠してしまう可能性があります。つまり、各ノードにおける予測子選択プロセスが偏ります。また、標準 CART には予測子のペアと応答の間の重要な交互作用を見落とす傾向があります。
選択の偏りを軽減し、より多くの重要な交互作用を検出するため、名前と値のペアの引数 'PredictorSelection'
を使用して曲率検定または交互作用検定の使用を指定できます。曲率検定または交互作用検定を使用すると、標準 CART より予測子重要度推定の精度が向上するというメリットも得られます。
次の表は、サポートされる予測子選択手法をまとめています。
手法 | 'PredictorSelection' の値 | 説明 | 学習速度 | どのような場合に指定するか |
---|---|---|---|---|
標準 CART[1] | 既定の設定 | すべての予測子の可能なすべての分割について分割基準ゲインを最大化する分割予測子を選択。 | 比較のベースライン | 次の条件のいずれかが満たされる場合に指定。
|
曲率検定[2][3] | 'curvature' | 各予測子とその応答との間の独立性についてのカイ二乗検定 p 値を最小化する分割予測子を選択。 | 標準 CART 相当 | 次の条件のいずれかが満たされる場合に指定。
|
交互作用検定[3] | 'interaction-curvature' | 各予測子とその応答との間の独立性についてのカイ二乗検定 p 値の最小化 (つまり、曲率テストの実行)、および各予測子ペアと応答との間の独立性についてのカイ二乗検定 p 値を最小化する分割予測子を選択。 | 多数の予測子変数がデータセットに含まれている場合は特に、標準 CART より低速。 | 次の条件のいずれかが満たされる場合に指定。
|
予測子選択手法の詳細については、以下を参照してください。
分類木については
PredictorSelection
とノード分割規則回帰木については
PredictorSelection
とノード分割規則
深さまたは "Leafiness" の制御
決定木を成長させるときは、単純さと予測力を考慮してください。リーフ数の多い深いツリーでは、通常、学習データについての精度が高くなります。ただし、独立したテスト セットで同等の精度を達成できるとは限りません。リーフ数の多いツリーは過学習 (過適合) になる傾向があるので、多くの場合、そのテスト精度は学習 (再代入) 精度より非常に小さくなります。反対に、浅いツリーの場合には、高い学習精度は達成されません。しかし、浅いツリーは安定性に優れており、学習の場合でも代表的なテスト セットの場合でも、それほど変わらない精度を実現できます。また、ツリーが浅ければ解釈も容易です。学習およびテストに必要なデータが十分でない場合は、交差検証を使用してツリーの精度を推定します。
fitctree
と fitrtree
には、生成される決定木の深さを制御する名前と値のペアの引数が 3 つあります。
MaxNumSplits
― 枝ノード分割の最大数は、ツリーあたりMaxNumSplits
です。MaxNumSplits
に大きい値を設定するとツリーが深くなります。既定の設定はsize(X,1) – 1
です。MinLeafSize
— 各リーフには少なくともMinLeafSize
の観測値があります。MinLeafSize
に小さい値を設定するとツリーが深くなります。既定の設定は1
です。MinParentSize
— ツリーの各枝ノードには少なくともMinParentSize
の観測値があります。MinParentSize
に小さい値を設定するとツリーが深くなります。既定の設定は10
です。
MinParentSize
と MinLeafSize
の両方を指定した場合、リーフが大きくなるツリー (浅いツリー) を生成する設定が学習器で使用されます。
MinParent = max(MinParentSize,2*MinLeafSize)
MaxNumSplits
を設定した場合、3 つの分割基準のいずれかが満たされるまでツリーが分割されます。
ツリーの深さを制御するその他の手法については、枝刈りを参照してください。
適切なツリーの深さの選択
この例では、決定木の深さを制御する方法、および適切な深さを選択する方法について示します。
ionosphere
データを読み込みます。
load ionosphere
10
から 100
までの指数関数的な間隔の値セットを生成します。これは、葉ノードごとの観測値の最小数を表します。
leafs = logspace(1,2,10);
ionosphere
のデータについて交差検証分類木を作成します。leafs
に格納されている葉の最小サイズを使用して各木を成長させるように指定します。
rng('default') N = numel(leafs); err = zeros(N,1); for n=1:N t = fitctree(X,Y,'CrossVal','On',... 'MinLeafSize',leafs(n)); err(n) = kfoldLoss(t); end plot(leafs,err); xlabel('Min Leaf Size'); ylabel('cross-validated error');
最適なリーフ サイズは、リーフあたりの観測が 20
から 50
までということがわかります。
リーフあたりの観測数が少なくとも 40
というほぼ最適に近いツリーと、観測数が親ノードあたり 10
、リーフあたり 1
である既定のツリーを比較します。
DefaultTree = fitctree(X,Y); view(DefaultTree,'Mode','Graph')
OptimalTree = fitctree(X,Y,'MinLeafSize',40); view(OptimalTree,'mode','graph')
resubOpt = resubLoss(OptimalTree); lossOpt = kfoldLoss(crossval(OptimalTree)); resubDefault = resubLoss(DefaultTree); lossDefault = kfoldLoss(crossval(DefaultTree)); resubOpt,resubDefault,lossOpt,lossDefault
resubOpt = 0.0883
resubDefault = 0.0114
lossOpt = 0.1054
lossDefault = 0.1054
最適に近いツリーはサイズがかなり小さくなり、再代入誤差は大きくなります。ただし、精度は交差検証データと同様です。
枝刈り
枝刈りでは、同じ木の枝にある葉をマージすることによって、木の深さ (葉の多さ) が最適化されます。深さまたは "Leafiness" の制御では、最適な木の深さを選択する方法の 1 つについて説明しています。その節とは異なり、すべてのノード サイズごとに新しいツリーを成長させる必要はありません。ここでは、深いツリーを成長させ、そのツリーを選択したレベルになるまで枝刈りします。
ツリーを枝刈りするには、コマンド ラインで prune
メソッド (分類木の場合)、または prune
メソッド (回帰木の場合) を使用します。または、木ビューアーで対話的に木を刈り込みます。
view(tree,'mode','graph')
ツリーを枝刈りするには、ツリーに枝刈り順序が含まれていなければなりません。既定の設定では、fitctree
と fitrtree
のどちらも、ツリーの構築中に枝刈りの順序を計算します。'Prune'
名前と値のペアが 'off'
に設定されたツリーを構築する場合、またはツリーをより小さいレベルまで枝刈りされる場合、ツリーには全体の枝刈りの順序は含まれません。全体の枝刈りの順序を生成するには、コマンド ラインで prune
メソッド (分類木の場合)、または prune
メソッド (回帰木の場合) を使用します。
分類木の枝刈り
この例では、ionosphere
データの分類木を作成し、適切なレベルに枝刈りします。
ionosphere
データを読み込みます。
load ionosphere
データの既定の分類木を構築します。
tree = fitctree(X,Y);
ツリーを対話的なビューアーで表示します。
view(tree,'Mode','Graph')
交差検証損失を最小化することによって、最適な枝刈りレベルを求めます。
[~,~,~,bestlevel] = cvLoss(tree,... 'SubTrees','All','TreeSize','min')
bestlevel = 6
ツリーをレベル 6
まで枝刈りします。
view(tree,'Mode','Graph','Prune',6)
または、対話型ウィンドウを使用してツリーを枝刈りします。
枝刈りされたツリーは、「適切なツリーの深さの選択」の例にある最適に近いツリーと同じになります。
'TreeSize'
を 'SE'
(既定の設定) に設定して、最適なレベルに標準偏差を加えたレベルをツリーの誤差が超過しない範囲で、最大の枝刈りレベルを検出します。
[~,~,~,bestlevel] = cvLoss(tree,'SubTrees','All')
bestlevel = 6
この場合は、'TreeSize'
の設定がどちらでもレベルは同じになります。
ツリーを枝刈りして、別の目的に利用します。
tree = prune(tree,'Level',6); view(tree,'Mode','Graph')
参照
[1] Breiman, L., J. H. Friedman, R. A. Olshen, and C. J. Stone. Classification and Regression Trees. Boca Raton, FL: Chapman & Hall, 1984.
[2] Loh, W.Y. and Y.S. Shih. “Split Selection Methods for Classification Trees.” Statistica Sinica, Vol. 7, 1997, pp. 815–840.
[3] Loh, W.Y. “Regression Trees with Unbiased Variable Selection and Interaction Detection.” Statistica Sinica, Vol. 12, 2002, pp. 361–386.
参考
fitctree
| fitrtree
| ClassificationTree
| RegressionTree
| predict (CompactRegressionTree)
| predict (CompactClassificationTree)
| prune (ClassificationTree)
| prune (RegressionTree)