Main Content

混合ガウス モデルの調整

この例では、成分数と成分の共分散行列の構造を調整することにより最適な混合ガウス モデル (GMM) 近似を決定する方法を示します。

フィッシャーのアヤメのデータセットを読み込みます。花弁の測定値を予測子と考えます。

load fisheriris
X = meas(:,3:4);
[n,p] = size(X);
rng(1) % For reproducibility

figure
plot(X(:,1),X(:,2),'.','MarkerSize',15)
title('Fisher''s Iris Data Set')
xlabel('Petal length (cm)')
ylabel('Petal width (cm)')

Figure contains an axes object. The axes object with title Fisher's Iris Data Set contains an object of type line.

必要な成分またはクラスターの数を k、すべての成分の共分散の構造を Σ とします。次の手順に従って GMM を調整します。

  1. "k"Σ のペアを選択してから、選択したパラメーター指定とデータセット全体を使用して GMM を当てはめます。

  2. AIC と BIC を推定します。

  3. 関心があるすべての "k"Σ のペアを網羅するまで、手順 1 および 2 を繰り返します。

  4. 低い AIC と単純さのバランスがとれている、近似させた GMM を選択します。

この例では、k の値として 2 と 3 および周囲の数を選択します。共分散の構造について可能なすべての選択肢を指定します。k がデータセットに対して大きすぎる場合、推定した成分の共分散は悪条件になる可能性があります。悪条件の共分散行列を回避するため、正則化を使用するよう指定します。EM アルゴリズムの反復数を 10000 に増やします。

k = 1:5;
nK = numel(k);
Sigma = {'diagonal','full'};
nSigma = numel(Sigma);
SharedCovariance = {true,false};
SCtext = {'true','false'};
nSC = numel(SharedCovariance);
RegularizationValue = 0.01;
options = statset('MaxIter',10000);

すべてのパラメーターの組み合わせを使用して GMM を近似させます。各近似について AIC と BIC を計算します。近似ごとに末端の収束状況を追跡します。

% Preallocation
gm = cell(nK,nSigma,nSC);         
aic = zeros(nK,nSigma,nSC);
bic = zeros(nK,nSigma,nSC);
converged = false(nK,nSigma,nSC);

% Fit all models
for m = 1:nSC
    for j = 1:nSigma
        for i = 1:nK
            gm{i,j,m} = fitgmdist(X,k(i),...
                'CovarianceType',Sigma{j},...
                'SharedCovariance',SharedCovariance{m},...
                'RegularizationValue',RegularizationValue,...
                'Options',options);
            aic(i,j,m) = gm{i,j,m}.AIC;
            bic(i,j,m) = gm{i,j,m}.BIC;
            converged(i,j,m) = gm{i,j,m}.Converged;
        end
    end
end

allConverge = (sum(converged(:)) == nK*nSigma*nSC)
allConverge = logical
   1

gm は、近似させたすべての gmdistribution モデル オブジェクトが含まれている cell 配列です。すべての近似インスタンスが収束します。

すべての近似で AIC と BIC を比較するため、別々の棒グラフをプロットします。これらの棒を k でグループ化します。

figure
bar(reshape(aic,nK,nSigma*nSC))
title('AIC For Various $k$ and $\Sigma$ Choices','Interpreter','latex')
xlabel('$k$','Interpreter','Latex')
ylabel('AIC')
legend({'Diagonal-shared','Full-shared','Diagonal-unshared',...
    'Full-unshared'})

Figure contains an axes object. The axes object with title AIC For Various k and Sigma Choices contains 4 objects of type bar. These objects represent Diagonal-shared, Full-shared, Diagonal-unshared, Full-unshared.

figure
bar(reshape(bic,nK,nSigma*nSC))
title('BIC For Various $k$ and $\Sigma$ Choices','Interpreter','latex')
xlabel('$c$','Interpreter','Latex')
ylabel('BIC')
legend({'Diagonal-shared','Full-shared','Diagonal-unshared',...
    'Full-unshared'})

Figure contains an axes object. The axes object with title BIC For Various k and Sigma Choices contains 4 objects of type bar. These objects represent Diagonal-shared, Full-shared, Diagonal-unshared, Full-unshared.

AIC と BIC の値によると、最適なモデルには 3 つの成分があり、共分散行列の構造は非スパースおよび非共有です。

最適な近似モデルを使用して、学習データをクラスタリングします。クラスタリングしたデータと成分の楕円をプロットします。

gmBest = gm{3,2,2};
clusterX = cluster(gmBest,X);
kGMM = gmBest.NumComponents;
d = 500;
x1 = linspace(min(X(:,1)) - 2,max(X(:,1)) + 2,d);
x2 = linspace(min(X(:,2)) - 2,max(X(:,2)) + 2,d);
[x1grid,x2grid] = meshgrid(x1,x2);
X0 = [x1grid(:) x2grid(:)];
mahalDist = mahal(gmBest,X0);
threshold = sqrt(chi2inv(0.99,2));

figure
h1 = gscatter(X(:,1),X(:,2),clusterX);
hold on
for j = 1:kGMM
    idx = mahalDist(:,j)<=threshold;
    Color = h1(j).Color*0.75 + -0.5*(h1(j).Color - 1);
    h2 = plot(X0(idx,1),X0(idx,2),'.','Color',Color,'MarkerSize',1);
    uistack(h2,'bottom')
end
plot(gmBest.mu(:,1),gmBest.mu(:,2),'kx','LineWidth',2,'MarkerSize',10)
title('Clustered Data and Component Structures')
xlabel('Petal length (cm)')
ylabel('Petal width (cm)')
legend(h1,'Cluster 1','Cluster 2','Cluster 3','Location','NorthWest')
hold off

Figure contains an axes object. The axes object with title Clustered Data and Component Structures contains 7 objects of type line. These objects represent Cluster 1, Cluster 2, Cluster 3.

このデータセットにはラベルが含まれています。各予測値を真のラベルと比較して、gmBest によるデータのクラスタリングがどの程度適切であるかを調べます。

species = categorical(species);
Y = zeros(n,1);
Y(species == 'versicolor') = 1;
Y(species == 'virginica') = 2;
Y(species == 'setosa') = 3;

miscluster = Y ~= clusterX;
clusterError = sum(miscluster)/n
clusterError = 0.0800

最適な近似を行う GMM は、観測値の 8% を誤ったクラスターにグループ化します。

cluster ではクラスターの順序が必ず保持されるとは限りません。つまり、複数の近似させた gmdistribution モデルをクラスタリングした場合、cluster は同じ成分に対して異なるクラスター ラベルを割り当てる可能性があります。

参考

| |

関連するトピック