分類
この例では、判別分析、単純ベイズ分類器、決定木による分類を実行する方法を示します。さまざまな変数 (「予測子変数」と呼びます) に関する測定値と既知のクラス ラベルがある観測で構成されるデータ セットがあるとします。新しい観測の予測値を入手した場合、その観測がおそらくどのクラスに属するのかを判定できるでしょうか。これは分類の問題です。
フィッシャーのアヤメのデータ
フィッシャーのアヤメのデータは、アヤメの標本 150 個のがく片の長さと幅、花弁の長さと幅に関する測定値で構成されます。3 種それぞれについて 50 個の標本があります。データを読み込んで、がく片の測定値が種間でどのように異なるのかを調べてみましょう。がく片の測定値を格納した 2 つの列を使用することができます。
load fisheriris f = figure; gscatter(meas(:,1), meas(:,2), species,'rgb','osd'); xlabel('Sepal length'); ylabel('Sepal width');
N = size(meas,1);
1 本のアヤメのがく片と花弁を測定し、その測定値に基づいて種を判定する必要があるとします。この問題を解く 1 つのアプローチは、「判別分析」と呼ばれます。
線形判別分析と 2 次判別分析
関数 fitcdiscr
は、さまざまな判別分析を使用して分類を行うことができます。最初に、既定の線形判別分析 (LDA) を使用してデータを分類します。
lda = fitcdiscr(meas(:,1:2),species); ldaClass = resubPredict(lda);
既知のクラス ラベルがある観測は通常、「学習データ」と呼ばれます。ここで、再代入誤差を計算します。これは、学習セットに関する誤分類誤差 (誤分類された観測の比率) です。
ldaResubErr = resubLoss(lda)
ldaResubErr = 0.2000
学習セットに関する混同行列を計算することもできます。混同行列には、既知のクラス ラベルと予測されたクラス ラベルについての情報が格納されます。一般に、混同行列の (i,j) 要素は、既知のクラス ラベルがクラス i、予測されるクラスが j である標本の数を表します。対角要素は、正しく分類された観測値を表します。
figure ldaResubCM = confusionchart(species,ldaClass);
150 個の学習観測値のうち、20% つまり 30 個の観測値が線形判別関数によって誤分類されています。どの観測値が誤分類されたのかを具体的に確認するには、誤分類された点を通る X を描きます。
figure(f) bad = ~strcmp(ldaClass,species); hold on; plot(meas(bad,1), meas(bad,2), 'kx'); hold off;
この関数により、平面が直線で複数の領域に分割され、種ごとに別の領域に割り当てられました。この領域を可視化する 1 つの方法は、(x,y)
値のグリッドを作成し、そのグリッドに分類関数を適用することです。
[x,y] = meshgrid(4:.1:8,2:.1:4.5); x = x(:); y = y(:); j = classify([x y],meas(:,1:2),species); gscatter(x,y,j,'grb','sod')
データ セットのなかには、さまざまなクラスの領域が直線ではっきりと分割されないものもあります。その場合には、線形判別分析は適切ではありません。むしろ、ここで紹介するデータには 2 次判別分析 (QDA) を試すことができます。
2 次判別分析の再代入誤差を計算します。
qda = fitcdiscr(meas(:,1:2),species,'DiscrimType','quadratic'); qdaResubErr = resubLoss(qda)
qdaResubErr = 0.2000
再代入誤差を計算できました。人は一般に、テスト誤差 (「汎化誤差」とも呼ばれます) の方に関心をもつものです。この誤差は、独立したセットに関して見込まれる予測誤差です。実際、再代入誤差ではテスト誤差が過小評価されがちです。
この場合、ラベル付けされた別のデータ セットはありませんが、交差検証を行うことによりシミュレートすることができます。分類アルゴリズムでのテスト誤差を推定するには、階層化された 10 分割交差検証がよく使用されます。この検定では、学習セットが 10 個の互いに素のサブセットに無作為に分割されます。各サブセットは、サイズがほぼ等しく、学習セット内でのクラスの比率とほぼ同じクラスの比率をもっています。1 つのサブセットを削除し、他の 9 個のサブセットで分類モデルに学習させ、削除されたサブセットを学習済みモデルを使用して分類します。これを一度に 1 つのサブセットを 10 個のサブセットから削除して繰り返します。
交差検証によってデータが無作為に分割されるので、結果は初期乱数シードで決まります。この例でまったく同じ結果を再現するには、次のコマンドを実行します。
rng(0,'twister');
最初に、関数 cvpartition
を使用して 10 個の互いに素の階層化されたサブセットを生成します。
cp = cvpartition(species,'KFold',10)
cp = K-fold cross validation partition NumObservations: 150 NumTestSets: 10 TrainSize: 135 135 135 135 135 135 135 135 135 135 TestSize: 15 15 15 15 15 15 15 15 15 15 IsCustom: 0
crossval
メソッドと kfoldLoss
メソッドでは、指定したデータ分割 cp
を使用して LDA と QDA の両方について誤分類誤差を推定できます。
階層化された 10 分割交差検証を使用して、LDA の真のテスト誤差を推定します。
cvlda = crossval(lda,'CVPartition',cp);
ldaCVErr = kfoldLoss(cvlda)
ldaCVErr = 0.2000
このデータに関する LDA の交差検証誤差は、LDA の再代入誤差と同じ値です。
階層化された 10 分割交差検証を使用して、QDA の真のテスト誤差を推定します。
cvqda = crossval(qda,'CVPartition',cp);
qdaCVErr = kfoldLoss(cvqda)
qdaCVErr = 0.2200
QDA の交差検証誤差の値は、LDA の場合より少し大きいです。これは、モデルが簡単であればあるほど類似度が高くなるか、または複雑なモデルより成績が良くなることを示しています。
単純ベイズ分類器
関数 fitcdiscr
には、他にも 'DiagLinear'
および 'DiagQuadratic'
という 2 つの種類があります。これらは 'linear'
および 'quadratic'
と似ていますが、対角の共分散行列の推定値がある点が異なります。これらの対角性の選択肢は、単純ベイズ分類器の具体例です。クラス ラベルが与えられた場合、変数が条件的に独立しているものと仮定されるからです。単純ベイズ分類器は、最も一般的な分類器の一種です。クラスの条件付きの下で変数が互いに独立であるという仮定は、一般には成り立ちませんが、多くのデータ セットで単純ベイズ分類器が実際のところうまくいくことが確認されています。
関数 fitcnb
を使用すると、より一般的な種類の単純ベイズ分類器を作成できます。
最初に、ガウス分布を使用して各クラスの各変数をモデル化します。再代入誤差と交差検証誤差を計算することができます。
nbGau = fitcnb(meas(:,1:2), species); nbGauResubErr = resubLoss(nbGau)
nbGauResubErr = 0.2200
nbGauCV = crossval(nbGau, 'CVPartition',cp);
nbGauCVErr = kfoldLoss(nbGauCV)
nbGauCVErr = 0.2200
labels = predict(nbGau, [x y]); gscatter(x,y,labels,'grb','sod')
これまでは、各クラスの変数に多変量正規分布があると仮定してきました。たいていの場合、これは理にかなった仮定です。しかし、このように仮定したくないかまたはこの仮定が明らかに無効であるとわかる場合もあります。そこで、各クラスの変数をカーネル密度推定を使用してモデル化してみましょう。これは、より柔軟性に富むノンパラメトリックな手法です。ここで、カーネルを box
に設定します。
nbKD = fitcnb(meas(:,1:2), species, 'DistributionNames','kernel', 'Kernel','box'); nbKDResubErr = resubLoss(nbKD)
nbKDResubErr = 0.2067
nbKDCV = crossval(nbKD, 'CVPartition',cp);
nbKDCVErr = kfoldLoss(nbKDCV)
nbKDCVErr = 0.2133
labels = predict(nbKD, [x y]); gscatter(x,y,labels,'rgb','osd')
このデータ セットの場合、単純ベイズ分類器にカーネル密度推定を適用すると、ガウス分布を適用した場合より再代入誤差と交差検証誤差が小さくなります。
決定木
別の分類アルゴリズムは、決定木に基づきます。決定木は、単純な規則のセットです。たとえば、"がく片の長さが 5.45 未満なら、その標本を setosa (セトサ) に分類する。" です。決定木もノンパラメトリックです。各クラスの変数の分布について仮定がまったく不要だからです。
関数 fitctree
は決定木を作成します。アヤメのデータの決定木を作成して、アヤメが種にどのように分類されるのかを調べます。
t = fitctree(meas(:,1:2), species,'PredictorNames',{'SL' 'SW' });
決定木法で平面が分割される様子を見るのは興味深いものです。上と同じ手法で、各種に割り当てられた領域を可視化します。
[grpname,node] = predict(t,[x y]); gscatter(x,y,grpname,'grb','sod')
決定木を可視化する別の方法は、決定規則とクラス割り当ての図を描くことです。
view(t,'Mode','graph');
この乱れたように見える木では、"SL < 5.45" という形式の一連の規則を使用して、各標本を 19 個の終端ノードのいずれかに分類します。ある観測の種割り当てを判定するため、最上位ノードから開始して規則を適用していきます。点が規則を満たすなら左に進み、そうでないなら右に進みます。最終的には、観測値を 3 つの種のいずれかに代入する終端ノードに到達します。
決定木の再代入誤差と交差検証誤差を計算します。
dtResubErr = resubLoss(t)
dtResubErr = 0.1333
cvt = crossval(t,'CVPartition',cp);
dtCVErr = kfoldLoss(cvt)
dtCVErr = 0.3000
決定木アルゴリズムの場合、交差検証誤差の推定値は再代入誤差より有意に大きくなります。これは、生成された木が学習セットを過適合することを示しています。言い換えると、この木は元の学習セットをうまく分類していますが、木の構造はこの特定の学習セットの影響を受けやすいので、別の新しいデータの性能が低下する可能性があります。別の新しいデータの場合、複雑な木よりも単純な木の方が性能が優れていることがよくあります。
枝刈りを行います。最初に、元の木のさまざまなサブセットについて再代入誤差を計算します。次に、これらのサブツリーについて交差検証誤差を計算します。グラフは、再代入誤差が楽観的すぎることを示しています。これはツリーのサイズが大きくなるにつれて一貫して減っていきますが、ある点を越えると、ツリーのサイズが増えると交差検証誤差率が上昇するようになります。
resubcost = resubLoss(t,'Subtrees','all'); [cost,secost,ntermnodes,bestlevel] = cvloss(t,'Subtrees','all'); plot(ntermnodes,cost,'b-', ntermnodes,resubcost,'r--') figure(gcf); xlabel('Number of terminal nodes'); ylabel('Cost (misclassification error)') legend('Cross-validation','Resubstitution')
どの木を選択すべきでしょうか。簡単な規則は、交差検証誤差が最小の木を選択することです。これで十分かもしれませんが、単純な木を使用しても複雑な木で得られるのとほぼ同じ結果になるのであれば、単純な木を選択することもできるでしょう。この例では、最小値の 1 標準誤差内にある最も単純な木を採用します。これは、ClassificationTree
の cvloss
メソッドで使用される既定の規則です。
このことは、カットオフ値を計算することによってグラフに示すことができます。このカットオフ値は、最小コストに 1 標準誤差を加えた値です。cvloss
法によって計算される "最高" のレベルは、このカットオフ値を下回る最小木です (bestlevel
=0 が何も刈り込まれていない木であることに注意してください。したがって、cvloss
のベクトル出力のインデックスとして使用するには、1 を加える必要があります)。
[mincost,minloc] = min(cost); cutoff = mincost + secost(minloc); hold on plot([0 20], [cutoff cutoff], 'k:') plot(ntermnodes(bestlevel+1), cost(bestlevel+1), 'mo') legend('Cross-validation','Resubstitution','Min + 1 std. err.','Best choice') hold off
最後に、刈り込まれた木を調べて、推定誤分類誤差を計算することができます。
pt = prune(t,'Level',bestlevel); view(pt,'Mode','graph')
cost(bestlevel+1)
ans = 0.2467
まとめ
この例では、Statistics and Machine Learning Toolbox™ の関数を使用して MATLAB® で分類を行う方法を説明します。
この例の目的は、フィッシャーのアヤメのデータの理想的な分析を示すことではありません。実際、がく片ではなく花弁の測定値を使用したり、がく片と花弁の測定値をあわせて使用すると、分類の精度が向上することがあります。また、この例の目的は、さまざまな分類アルゴリズムの強みと弱みを比較することでもありません。データ セットをさまざまに変えて分析を行い、さまざまなアルゴリズムを比較すると、得るところが大きいでしょう。Statistics and Machine Learning Toolbox には、他の分類アルゴリズムを実装した関数もあります。たとえば、TreeBagger の使用による分類木の bootstrap aggregation (バギング)の例で説明されているように、関数 TreeBagger
を使用して決定木のアンサンブルについてバギングを実行できます。