メインコンテンツ

predict

アンサンブル分類モデルを使用したラベルの予測

説明

labels = predict(ens,X) は、学習済みのアンサンブル分類モデル (完全またはコンパクト) ens に基づいて、table または行列 X 内の予測子データに対する予測クラス ラベルのベクトルを返します。

labels = predict(ens,X,Name=Value) は、1 つ以上の名前と値の引数を使用して追加オプションを指定します。たとえば、予測に使用する弱学習器を指定したり、計算を並列に実行したりできます。

[labels,scores] = predict(___) は、前の構文におけるいずれかの入力引数の組み合わせを使用して、ラベルが特定のクラスから派生する尤度を示す分類スコアの行列も返します。X 内の各観測値について、予測クラス ラベルは、すべてのクラスの中で最大のスコアに対応します。

すべて折りたたむ

フィッシャーのアヤメのデータ セットを読み込みます。標本サイズを調べます。

load fisheriris
N = size(meas,1);

データを学習セットとテスト セットに分割します。データの 10% をテスト用にホールドアウトします。

rng(1); % For reproducibility
cvp = cvpartition(N,'Holdout',0.1);
idxTrn = training(cvp); % Training set indices
idxTest = test(cvp);    % Test set indices

学習データを table に格納します。

tblTrn = array2table(meas(idxTrn,:));
tblTrn.Y = species(idxTrn);

AdaBoostM2 と学習セットを使用してアンサンブル分類に学習させます。弱学習器として木の切り株を指定します。

t = templateTree('MaxNumSplits',1);
Mdl = fitcensemble(tblTrn,'Y','Method','AdaBoostM2','Learners',t);

テスト セットについてラベルを予測します。モデルの学習にはデータの table を使用しましたが、ラベルの予測には行列を使用できます。

labels = predict(Mdl,meas(idxTest,:));

テスト セットの混同行列を作成します。

confusionchart(species(idxTest),labels)

Figure contains an object of type ConfusionMatrixChart.

Mdl は、テスト セット内の 1 つの versicolor 種のアヤメを virginica として誤分類します。

ブースティング木のアンサンブルを作成し、各予測子の重要度を検査します。テスト データを使用して、アンサンブルの分類精度を評価します。

不整脈データ セットを読み込みます。データのクラス表現を判別します。

load arrhythmia
Y = categorical(Y);
tabulate(Y)
  Value    Count   Percent
      1      245     54.20%
      2       44      9.73%
      3       15      3.32%
      4       15      3.32%
      5       13      2.88%
      6       25      5.53%
      7        3      0.66%
      8        2      0.44%
      9        9      1.99%
     10       50     11.06%
     14        4      0.88%
     15        5      1.11%
     16       22      4.87%

データ セットには 16 個のクラスが含まれていますが、すべてのクラスは表現されていません (たとえば、クラス 13)。ほとんどの観測値は不整脈がないものとして分類されています (クラス 1)。このデータ セットは非常に離散的であり、クラスが不均衡です。

不整脈があるすべての観測値 (クラス 2 ~ 15) を 1 つのクラスに結合します。不整脈の状況が不明である観測値 (クラス 16) をデータ セットから削除します。

idx = (Y ~= "16");
Y = Y(idx);
X = X(idx,:);
Y(Y ~= "1") = "WithArrhythmia";
Y(Y == "1") = "NoArrhythmia";
Y = removecats(Y);

データを学習セットとテスト セットに均等に分割します。

rng("default") % For reproducibility
cvp = cvpartition(Y,"Holdout",0.5);
idxTrain = training(cvp);
idxTest = test(cvp);

cvp は、学習セットとテスト セットを指定する交差検証分割オブジェクトです。

AdaBoostM1 を使用して 100 本のブースティング分類木のアンサンブルに学習をさせます。弱学習器として木の切り株を使用するように指定します。また、欠損値がデータ セットに含まれているので、代理分岐を使用するように指定します。

t = templateTree("MaxNumSplits",1,"Surrogate","on");
numTrees = 100;
mdl = fitcensemble(X(idxTrain,:),Y(idxTrain),"Method","AdaBoostM1", ...
    "NumLearningCycles",numTrees,"Learners",t);

mdl は学習させた ClassificationEnsemble モデルです。

各予測子について重要度を調べます。

predImportance = predictorImportance(mdl);
bar(predImportance)
title("Predictor Importance")
xlabel("Predictor")
ylabel("Importance Measure")

Figure contains an axes object. The axes object with title Predictor Importance, xlabel Predictor, ylabel Importance Measure contains an object of type bar.

重要度が上位 10 番目までの予測子を識別します。

[~,idxSort] = sort(predImportance,"descend");
idx10 = idxSort(1:10)
idx10 = 1×10

   228   233   238    93    15   224    91   177   260   277

テスト セットの観測値を分類します。混同行列を使用して結果を表示します。青色の値は正しい分類を示し、赤色の値は誤分類された観測値を示します。

predictedValues = predict(mdl,X(idxTest,:));
confusionchart(Y(idxTest),predictedValues)

Figure contains an object of type ConfusionMatrixChart.

テスト セットでモデルの精度を計算します。

error = loss(mdl,X(idxTest,:),Y(idxTest), ...
    "LossFun","classiferror");
accuracy = 1 - error
accuracy = 
0.7731

accuracy で、正しく分類された観測値の比率が推定されます。

入力引数

すべて折りたたむ

アンサンブル分類モデル。fitcensemble で学習させた ClassificationEnsemble または ClassificationBaggedEnsemble モデル オブジェクト、または compact で作成した CompactClassificationEnsemble モデル オブジェクトとして指定します。

分類対象の予測子データ。数値行列または table として指定します。

X の各行は 1 つの観測値に対応し、各列は 1 つの変数に対応します。

数値行列の場合

  • X の列を構成する変数の順序は、ens の学習に使用した予測子変数の順序と同じでなければなりません。

  • table (たとえば tbl) を使用して ens に学習させた場合、tbl に含まれている予測子変数が数値変数のみであれば、X を数値行列にすることができます。学習時に tbl 内の数値予測子をカテゴリカルとして扱うには、fitcensemble の名前と値の引数 CategoricalPredictors を使用してカテゴリカル予測子を指定します。tbl に種類の異なる予測子変数 (数値と categorical データ型など) が混在し、X が数値行列である場合、predict でエラーが発行されます。

table の場合

  • predict は、文字ベクトルの cell 配列ではない cell 配列や複数列の変数をサポートしません。

  • table (たとえば tbl) を使用して ens に学習させた場合、X 内のすべての予測子変数の変数名およびデータ型が ens の学習に使用された (ens.PredictorNames に格納されている) 変数と同じでなければなりません。ただし、X の列の順序が tbl の列の順序に対応する必要はありません。tblX に追加の変数 (応答変数や観測値の重みなど) を含めることができますが、predict はこれらを無視します。

  • 数値行列を使用して ens に学習させた場合、ens.PredictorNames 内の予測子名と X 内の対応する予測子変数名が同じでなければなりません。学習時に予測子の名前を指定するには、fitcensemble の名前と値の引数 PredictorNames を使用します。X 内の予測子変数はすべて数値ベクトルでなければなりません。X に追加の変数 (応答変数や観測値の重みなど) を含めることができますが、predict はこれらを無視します。

名前と値の引数

すべて折りたたむ

オプションの引数のペアを Name1=Value1,...,NameN=ValueN として指定します。ここで、Name は引数名で、Value は対応する値です。名前と値の引数は他の引数の後に指定しなければなりませんが、ペアの順序は重要ではありません。

R2021a より前では、名前と値をそれぞれコンマを使って区切り、Name を引用符で囲みます。

例: predict(ens,X,Learners=[1 2 3 5],UseParallel=true) は、アンサンブル ens 内の 1 番目、2 番目、3 番目、および 5 番目の学習器を使用し、計算を並列に実行するように指定します。

predict で使用するアンサンブル内の弱学習器のインデックス。範囲 [1:ens.NumTrained] の正の整数のベクトルとして指定します。既定では、この関数はすべての学習器を使用します。

例: Learners=[1 2 4]

データ型: single | double

学習器の観測値を使用するオプション。NT 列のサイズの logical 行列として指定します。

  • NX の行の数です。

  • T は、ens に存在する弱学習器の数です。

UseObsForLearner(i,j)true (既定) の場合、学習器 jX の行 i のクラスの予測に使用されます。

例: UseObsForLearner=logical([1 1; 0 1; 1 0])

データ型: logical matrix

並列実行のフラグ。数値または logical の 1 (true) または 0 (false) として指定します。UseParallel=true を指定した場合、関数 predictparfor を使用して for ループの反復を実行します。Parallel Computing Toolbox™ がある場合、ループが並列に実行されます。

例: UseParallel=true

データ型: logical

出力引数

すべて折りたたむ

予測クラス ラベル。categorical 配列、文字配列、logical 配列、数値配列、または文字ベクトルの cell 配列として返されます。labels のデータ型は ens の学習に使用されたラベルと同じです。(string 配列は文字ベクトルの cell 配列として扱われます)。

関数 predict は、スコアが最高になるクラスに観測値を分類します。観測値のスコアが NaN の場合、関数はこの観測値を、学習ラベルの最大比率を占める多数クラスに分類します。

クラス スコア。観測値ごとに 1 つの行とクラスごとに 1 つの列をもつ数値行列として返されます。スコアは、各観測値および各クラスについて、その観測値がそのクラスからの派生である信頼度を表します。スコアが高いほど、信頼度が高いことを示します。詳細については、スコア (アンサンブル)を参照してください。

詳細

すべて折りたたむ

代替機能

Simulink ブロック

Simulink® にアンサンブルの予測を統合するには、Statistics and Machine Learning Toolbox™ ライブラリにある ClassificationEnsemble Predict ブロックを使用するか、MATLAB® Function ブロックを関数 predict と共に使用します。例については、ClassificationEnsemble Predict ブロックの使用によるクラス ラベルの予測MATLAB Function ブロックの使用によるクラス ラベルの予測を参照してください。

使用するアプローチを判断する際は、以下を考慮してください。

  • Statistics and Machine Learning Toolbox ライブラリ ブロックを使用する場合、固定小数点ツール (Fixed-Point Designer)を使用して浮動小数点モデルを固定小数点に変換できます。

  • MATLAB Function ブロックを関数 predict と共に使用する場合は、可変サイズの配列に対するサポートを有効にしなければなりません。

  • MATLAB Function ブロックを使用する場合、予測の前処理や後処理のために、同じ MATLAB Function ブロック内で MATLAB 関数を使用することができます。

拡張機能

すべて展開する

バージョン履歴

R2011a で導入