Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

predict

一般化加法モデル (GAM) を使用した観測値の分類

R2021a 以降

    説明

    label = predict(Mdl,X) は、バイナリ分類用の一般化加法モデル Mdl に基づいて、table または行列 X 内の予測子データに対する予測クラス ラベルのベクトルを返します。学習済みのモデルは、完全でもコンパクトでもかまいません。

    X 内の各観測値について、予測クラス ラベルは最小の予測誤分類コストに対応します。

    label = predict(Mdl,X,'IncludeInteractions',includeInteractions) は、計算に交互作用項を含めるかどうかを指定します。

    [label,score] = predict(___) は、前の構文におけるいずれかの入力引数の組み合わせを使用して、分類スコアも返します。

    すべて折りたたむ

    学習標本を使用して一般化加法モデルに学習させてから、テスト標本にラベルを付けます。

    fisheriris データ セットを読み込みます。versicolor と virginica のアヤメについてのがく片と花弁の測定値が含まれる数値行列 X を作成します。対応するアヤメの種類が含まれる文字ベクトルの cell 配列 Y を作成します。

    load fisheriris
    inds = strcmp(species,'versicolor') | strcmp(species,'virginica');
    X = meas(inds,:);
    Y = species(inds,:);

    Y のクラス情報を使用して、観測値を階層的に学習セットとテスト セットに無作為に分割します。テスト用の 30% のホールドアウト標本を指定します。

    rng('default') % For reproducibility
    cv = cvpartition(Y,'HoldOut',0.30);

    学習インデックスとテスト インデックスを抽出します。

    trainInds = training(cv);
    testInds = test(cv);

    学習データ セットとテスト データ セットを指定します。

    XTrain = X(trainInds,:);
    YTrain = Y(trainInds);
    XTest = X(testInds,:);
    YTest = Y(testInds);

    予測子 XTrain とクラス ラベル YTrain を使用して、一般化加法モデルに学習させます。クラス名を指定することが推奨されます。

    Mdl = fitcgam(XTrain,YTrain,'ClassNames',{'versicolor','virginica'})
    Mdl = 
      ClassificationGAM
                 ResponseName: 'Y'
        CategoricalPredictors: []
                   ClassNames: {'versicolor'  'virginica'}
               ScoreTransform: 'logit'
                    Intercept: -1.1090
              NumObservations: 70
    
    
    

    MdlClassificationGAM モデル オブジェクトです。

    テスト標本のラベルを予測します。

    label = predict(Mdl,XTest);

    真のラベルと予測ラベルを格納する table を作成します。10 件の観測値の無作為なセットについて table を表示します。

    t = table(YTest,label,'VariableNames',{'True Label','Predicted Label'});
    idx = randsample(sum(testInds),10);
    t(idx,:)
    ans=10×2 table
          True Label      Predicted Label
        ______________    _______________
    
        {'virginica' }    {'virginica' } 
        {'virginica' }    {'virginica' } 
        {'versicolor'}    {'virginica' } 
        {'virginica' }    {'virginica' } 
        {'virginica' }    {'virginica' } 
        {'versicolor'}    {'versicolor'} 
        {'versicolor'}    {'versicolor'} 
        {'versicolor'}    {'versicolor'} 
        {'versicolor'}    {'versicolor'} 
        {'virginica' }    {'virginica' } 
    
    

    真のラベル YTest と予測ラベル label から混同チャートを作成します。

    cm = confusionchart(YTest,label);

    Figure contains an object of type ConfusionMatrixChart.

    予測子の線形項と交互作用項の両方が格納されている分類 GAM を使用して、新しい観測値の事後確率のロジットを推定します。メモリ効率の高いモデル オブジェクトを使用して新しい観測値を分類します。新しい観測値を分類する際に交互作用項を含めるかどうかを指定します。

    ionosphere データ セットを読み込みます。このデータ セットには、レーダー反射についての 34 個の予測子と、不良 ('b') または良好 ('g') という 351 個の二項反応が含まれています。

    load ionosphere

    データ セットを 2 つのセットに分割します。1 つは学習データを含め、もう 1 つは新しい未観測のテスト データを含めます。新しいテスト データ セットの 10 件の観測値を保持します。

    rng('default') % For reproducibility
    n = size(X,1);
    newInds = randsample(n,10);
    inds = ~ismember(1:n,newInds);
    XNew = X(newInds,:);
    YNew = Y(newInds);

    予測子 X とクラス ラベル Y を使用して、GAM に学習させます。クラス名を指定することが推奨されます。上位 10 個の最も重要な交互作用項を含めるように指定します。

    Mdl = fitcgam(X(inds,:),Y(inds),'ClassNames',{'b','g'},'Interactions',10);

    MdlClassificationGAM モデル オブジェクトです。

    学習させたモデルのサイズを減らし、メモリの消費量を抑えます。

    CMdl = compact(Mdl);
    whos('Mdl','CMdl')
      Name      Size              Bytes  Class                                                 Attributes
    
      CMdl      1x1             1081260  classreg.learning.classif.CompactClassificationGAM              
      Mdl       1x1             1282819  ClassificationGAM                                               
    

    CMdlCompactClassificationGAM モデル オブジェクトです。

    線形項と交互作用項の両方を使用してラベルを予測してから、線形項のみを使用してラベルを予測します。交互作用項を除外するには、'IncludeInteractions',false を指定します。ScoreTransform プロパティとして 'none' を指定して、事後確率のロジットを推定します。

    CMdl.ScoreTransform = 'none';
    [labels,scores] = predict(CMdl,XNew);
    [labels_nointeraction,scores_nointeraction] = predict(CMdl,XNew,'IncludeInteractions',false);
    t = table(YNew,labels,scores,labels_nointeraction,scores_nointeraction, ...
        'VariableNames',{'True Labels','Predicted Labels','Scores' ...
        'Predicted Labels Without Interactions','Scores Without Interactions'})
    t=10×5 table
        True Labels    Predicted Labels          Scores          Predicted Labels Without Interactions    Scores Without Interactions
        ___________    ________________    __________________    _____________________________________    ___________________________
    
           {'g'}            {'g'}           -40.23      40.23                    {'g'}                        -37.484     37.484     
           {'g'}            {'g'}          -41.215     41.215                    {'g'}                        -38.737     38.737     
           {'g'}            {'g'}          -44.413     44.413                    {'g'}                        -42.186     42.186     
           {'g'}            {'b'}           3.0658    -3.0658                    {'b'}                         1.4338    -1.4338     
           {'g'}            {'g'}          -84.637     84.637                    {'g'}                        -81.269     81.269     
           {'g'}            {'g'}           -27.44      27.44                    {'g'}                        -24.831     24.831     
           {'g'}            {'g'}          -62.989     62.989                    {'g'}                          -60.4       60.4     
           {'g'}            {'g'}          -77.109     77.109                    {'g'}                        -75.937     75.937     
           {'g'}            {'g'}          -48.519     48.519                    {'g'}                        -47.067     47.067     
           {'g'}            {'g'}          -56.256     56.256                    {'g'}                        -53.373     53.373     
    
    

    テスト データ Xnew の予測ラベルは交互作用項を含めても変化しませんが、推定スコア値は異なります。

    一般化加法モデルに学習させてから、1 番目のクラスの確率値を使用して事後確率領域をプロットします。

    fisheriris データ セットを読み込みます。versicolor と virginica のアヤメについての 2 つの花弁の測定値が含まれる数値行列 X を作成します。対応するアヤメの種類が含まれる文字ベクトルの cell 配列 Y を作成します。

    load fisheriris
    inds = strcmp(species,'versicolor') | strcmp(species,'virginica');
    X = meas(inds,3:4);
    Y = species(inds,:);

    予測子 X とクラス ラベル Y を使用して、一般化加法モデルに学習させます。クラス名を指定することが推奨されます。

    Mdl = fitcgam(X,Y,'ClassNames',{'versicolor','virginica'});

    MdlClassificationGAM モデル オブジェクトです。

    観測された予測子領域の値のグリッドを定義します。

    xMax = max(X);
    xMin = min(X);
    x1 = linspace(xMin(1),xMax(1),250);
    x2 = linspace(xMin(2),xMax(2),250);
    [x1Grid,x2Grid] = meshgrid(x1,x2);

    グリッド内の各インスタンスの事後確率を予測します。

    [~,PosteriorRegion] = predict(Mdl,[x1Grid(:),x2Grid(:)]);

    1 番目のクラス 'versicolor' の確率値を使用して事後確率領域をプロットします。

    h = scatter(x1Grid(:),x2Grid(:),1,PosteriorRegion(:,1));
    h.MarkerEdgeAlpha = 0.3;

    学習データをプロットします。

    hold on
    gh = gscatter(X(:,1),X(:,2),Y,'k','dx');
    title('Iris Petal Measurements and Posterior Probabilities')
    xlabel('Petal length (cm)')
    ylabel('Petal width (cm)')
    legend(gh,'Location','Best')
    colorbar
    hold off

    Figure contains an axes object. The axes object with title Iris Petal Measurements and Posterior Probabilities, xlabel Petal length (cm), ylabel Petal width (cm) contains 3 objects of type scatter, line. One or more of the lines displays its values using only markers These objects represent versicolor, virginica.

    入力引数

    すべて折りたたむ

    一般化加法モデル。ClassificationGAM または CompactClassificationGAM モデル オブジェクトとして指定します。

    予測子データ。数値行列またはテーブルとして指定します。

    X の各行は 1 つの観測値に対応し、各列は 1 つの変数に対応します。

    • 数値行列の場合

      • X の列を構成する変数の順序は、Mdl に学習させた予測子変数の順序と同じでなければなりません。

      • table を使用して Mdl に学習をさせた場合、table に含まれている予測子変数がすべて数値変数であれば、X を数値行列にすることができます。

    • テーブルの場合

      • table (たとえば Tbl) を使用して Mdl に学習をさせた場合、X 内のすべての予測子変数は変数名およびデータ型が Tbl 内の変数と同じでなければなりません。ただし、X の列の順序が Tbl の列の順序に対応する必要はありません。

      • 数値行列を使用して Mdl に学習をさせた場合、Mdl.PredictorNames 内の予測子名と X 内の対応する予測子変数名が同じでなければなりません。学習時に予測子の名前を指定するには、名前と値の引数 'PredictorNames' を使用します。X 内の予測子変数はすべて数値ベクトルでなければなりません。

      • X に追加の変数 (応答変数や観測値の重みなど) を含めることができますが、predict はこれらを無視します。

      • predict は、文字ベクトルの cell 配列ではない cell 配列や複数列の変数をサポートしません。

    データ型: table | double | single

    モデルの交互作用項を含むというフラグ。true または false として指定します。

    Mdl に交互作用項が含まれる場合、includeInteractions の既定値は true です。モデルに交互作用項が含まれない場合、値は false でなければなりません。

    データ型: logical

    出力引数

    すべて折りたたむ

    予測クラス ラベル。categorical 配列、文字配列、logical ベクトル、数値ベクトル、または文字ベクトルの cell 配列として返されます。

    Mdl.ScoreTransform'logit' (既定) の場合、label の各エントリは、X の対応する行の予測誤分類コストが最小のクラスに対応します。それ以外の場合は、各エントリはスコアが最大のクラスに対応します。

    label は、Mdl に学習させた観測済みクラス ラベルと同じデータ型になり、X の行数と同じ長さになります。(string 配列は文字ベクトルの cell 配列として扱われます)。

    予測事後確率またはクラス スコア。X と同じ行数の 2 列の数値行列として返されます。score の 1 列目と 2 列目に、対応する観測値の 1 番目のクラス (陰性クラス、Mdl.ClassNames(1)) と 2 番目のクラス (陽性クラス、Mdl.ClassNames(2)) のスコア値がそれぞれ格納されます。

    Mdl.ScoreTransform'logit' (既定) の場合、スコア値は事後確率です。Mdl.ScoreTransform'none' の場合、スコア値は事後確率のロジットです。ソフトウェアに組み込みのスコア変換関数がいくつか用意されています。詳細については、MdlScoreTransform プロパティを参照してください。

    スコア変換を変更するには、学習時に fitcgam の引数 'ScoreTransform' を指定するか、学習後に ScoreTransform プロパティを変更します。

    詳細

    すべて折りたたむ

    予測クラス ラベル

    predict は、予測される誤分類コストを最小化することにより分類します。

    y^=argminy=1,...,Kj=1KP^(j|x)C(y|j),

    ここで、

    • y^ は、予測された分類です。

    • K は、クラスの数です。

    • P^(j|x) は、観測値 x のクラス j の事後確率です。

    • C(y|j) は、真のクラスが j の場合に観測値を y として分類するコストです。

    予測誤分類コスト

    観測値ごとの予測誤分類コストは、観測をそれぞれのクラスに分類する平均コストです。

    学習済みの分類器を使用して Nobs 個の観測値を分類するとします。また、K 個のクラスがあるとします。1 行に 1 観測ずつ、観測値を行列 X に置きます。

    予測コスト行列 CE のサイズは、NobsK列です。CE の各行には、観測をそれぞれのクラス K に分類する予測 (平均) コストが含まれます。CE(n,k) は次のとおりです。

    i=1KP^(i|X(n))C(k|i),

    ここで、

    • K は、クラスの数です。

    • P^(i|X(n)) は、観測値 X(n) のクラス i の事後確率です。

    • C(k|i) は、真のクラスが i である観測値を k に分類する真の誤分類コストです。

    真の誤分類コスト

    真の誤分類コストは、観測値を誤ったクラスに分類するコストです。

    分類器の作成時に、名前と値の引数 Cost を使用してクラスごとの真の誤分類コストを設定できます。Cost(i,j) は、真のクラスが i の場合に観測値をクラス j に分類するコストです。既定では、Cost(i,j)=1 (i~=j の場合) および Cost(i,j)=0 (i=j の場合) です。つまり、正しい分類のコストは 0、誤った分類のコストは 1 です。

    バージョン履歴

    R2021a で導入