Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

predict

分類モデルのアンサンブルを使用して観測値を分類

説明

labels = predict(Mdl,X) は、学習済みの完全またはコンパクトなアンサンブル分類 Mdl に基づいて、テーブルまたは行列 X 内の予測子データに対する予測クラス ラベルのベクトルを返します。

labels = predict(Mdl,X,Name,Value) は、1 つ以上の Name,Value 引数のペアによって指定された追加オプションを使用します。

[labels,scores] = predict(___) は、前の構文の入力引数のいずれかを使用して、ラベルが特定のクラスから派生する尤度を示す分類スコアの行列 (scores) も返します。X 内の各観測値について、予測クラス ラベルは、すべてのクラスの中で最大のスコアに対応します。

入力引数

Mdl

fitcensemble で作成されたアンサンブル分類、または compact で作成されたコンパクトなアンサンブル分類。

X

分類対象の予測子データ。数値行列またはテーブルを指定します。

X の各行は 1 つの観測値に対応し、各列は 1 つの変数に対応します。

  • 数値行列の場合

    • X の列を構成する変数の順序は、Mdl に学習させた予測子変数の順序と同じでなければなりません。

    • テーブル (たとえば Tbl) を使用して Mdl に学習をさせた場合、Tbl に含まれている予測子変数がすべて数値変数であれば、X を数値行列にすることができます。学習時に Tbl 内の数値予測子をカテゴリカルとして扱うには、fitcensemble の名前と値のペアの引数 CategoricalPredictors を使用してカテゴリカル予測子を同定します。Tbl に種類の異なる予測子変数 (数値および categorical データ型など) が混在し、X が数値行列である場合、predict でエラーがスローされます。

  • テーブルの場合

    • predict は、文字ベクトルの cell 配列ではない cell 配列や複数列の変数をサポートしません。

    • テーブル (たとえば Tbl) を使用して Mdl に学習をさせた場合、X 内のすべての予測子変数は変数名およびデータ型が、(Mdl.PredictorNames に格納されている) Mdl に学習させた変数と同じでなければなりません。ただし、X の列の順序が Tbl の列の順序に対応する必要はありません。TblX に追加の変数 (応答変数や観測値の重みなど) を含めることができますが、predict はこれらを無視します。

    • 数値行列を使用して Mdl に学習をさせた場合、Mdl.PredictorNames 内の予測子名と X 内の対応する予測子変数名が同じでなければなりません。学習時に予測子の名前を指定する方法については、fitcensemble の名前と値のペアの引数 PredictorNames を参照してください。X 内の予測子変数はすべて数値ベクトルでなければなりません。X に追加の変数 (応答変数や観測値の重みなど) を含めることができますが、predict はこれらを無視します。

名前と値の引数

オプションの引数のペアを Name1=Value1,...,NameN=ValueN として指定します。ここで Name は引数名、Value は対応する値です。名前と値の引数は他の引数の後ろにする必要がありますが、ペアの順序は関係ありません。

R2021a より前では、名前と値をそれぞれコンマを使って区切り、Name を引用符で囲みます。

Learners

弱学習器 predict のインデックスは、応答の計算のために数値ベクトルを使用します。

既定値: 1:T。ここで T は、Mdl 内の弱学習器の数を示します。

UseObsForLearner

NT 列のサイズの logical 行列です。

  • NX の行の数です。

  • T は、Mdl に存在する弱学習器の数です。

UseObsForLearner(i,j)true のとき、学習器 jX の行 i のクラスの予測に使用されます。

既定値: true(N,T)

UseParallel

推定を並列で実行するための指定。false (逐次計算) または true (並列計算) として指定します。並列計算には Parallel Computing Toolbox™ が必要です。特に大規模なデータセットでは、並列推定の方が逐次推定よりも高速になる可能性があります。並列計算は木学習器でのみサポートされます。

既定値: false

出力引数

labels

分類ラベルのベクトル。labels のデータ型は、学習 Mdl で使用されるラベルと同じです。(string 配列は文字ベクトルの cell 配列として扱われます)。

関数 predict は、スコアが最高になるクラスに観測値を分類します。観測値のスコアが NaN の場合、関数はこの観測値を、学習ラベルの最大比率を占める多数クラスに分類します。

scores

1 観測値あたり 1 行、1 クラスあたり 1 列の行列。スコアは、各観測値および各クラスについて、そのクラスから観測値が派生する信頼度を表します。スコアが高いほど、信頼度も高くなります。詳細については、スコア (アンサンブル)を参照してください。

すべて展開する

フィッシャーのアヤメのデータセットを読み込みます。標本サイズを調べます。

load fisheriris
N = size(meas,1);

データを学習セットとテスト セットに分割します。データの 10% を検定用にホールドアウトします。

rng(1); % For reproducibility
cvp = cvpartition(N,'Holdout',0.1);
idxTrn = training(cvp); % Training set indices
idxTest = test(cvp);    % Test set indices

学習データをテーブルに格納します。

tblTrn = array2table(meas(idxTrn,:));
tblTrn.Y = species(idxTrn);

AdaBoostM2 と学習セットを使用して、アンサンブル分類に学習をさせます。弱学習器として木の切り株を指定します。

t = templateTree('MaxNumSplits',1);
Mdl = fitcensemble(tblTrn,'Y','Method','AdaBoostM2','Learners',t);

テスト セットについてラベルを予測します。モデルの学習にはテーブル データを使用しましたが、ラベルの予測には行列を使用できます。

labels = predict(Mdl,meas(idxTest,:));

テスト セットの混同行列を作成します。

confusionchart(species(idxTest),labels)

Figure contains an object of type ConfusionMatrixChart.

Mdl は、テスト セット内の 1 つの versicolor 種のアヤメを virginica として誤分類します。

ブースティング木のアンサンブルを作成し、各予測子の重要度を検査します。テスト データを使用して、アンサンブルの分類精度を評価します。

不整脈データセットを読み込みます。データのクラス表現を判別します。

load arrhythmia
Y = categorical(Y);
tabulate(Y)
  Value    Count   Percent
      1      245     54.20%
      2       44      9.73%
      3       15      3.32%
      4       15      3.32%
      5       13      2.88%
      6       25      5.53%
      7        3      0.66%
      8        2      0.44%
      9        9      1.99%
     10       50     11.06%
     14        4      0.88%
     15        5      1.11%
     16       22      4.87%

データ セットには 16 個のクラスが含まれていますが、すべてのクラスは表現されていません (たとえば、クラス 13)。ほとんどの観測値は不整脈がないものとして分類されています (クラス 1)。このデータセットは非常に離散的であり、クラスが不均衡です。

不整脈があるすべての観測値 (クラス 2 ~ 15) を 1 つのクラスに結合します。不整脈の状況が不明である観測値 (クラス 16) をデータ セットから削除します。

idx = (Y ~= "16");
Y = Y(idx);
X = X(idx,:);
Y(Y ~= "1") = "WithArrhythmia";
Y(Y == "1") = "NoArrhythmia";
Y = removecats(Y);

データを学習セットとテスト セットに均等に分割します。

rng("default") % For reproducibility
cvp = cvpartition(Y,"Holdout",0.5);
idxTrain = training(cvp);
idxTest = test(cvp);

cvp は、学習セットとテスト セットを指定する交差検証分割オブジェクトです。

AdaBoostM1 を使用して 100 本のブースティング分類木のアンサンブルに学習をさせます。弱学習器として木の切り株を使用するように指定します。また、欠損値がデータ セットに含まれているので、代理分岐を使用するように指定します。

t = templateTree("MaxNumSplits",1,"Surrogate","on");
numTrees = 100;
mdl = fitcensemble(X(idxTrain,:),Y(idxTrain),"Method","AdaBoostM1", ...
    "NumLearningCycles",numTrees,"Learners",t);

mdl は学習させた ClassificationEnsemble モデルです。

各予測子について重要度を調べます。

predImportance = predictorImportance(mdl);
bar(predImportance)
title("Predictor Importance")
xlabel("Predictor")
ylabel("Importance Measure")

Figure contains an axes object. The axes object with title Predictor Importance, xlabel Predictor, ylabel Importance Measure contains an object of type bar.

重要度が上位 10 番目までの予測子を識別します。

[~,idxSort] = sort(predImportance,"descend");
idx10 = idxSort(1:10)
idx10 = 1×10

   228   233   238    93    15   224    91   177   260   277

テスト セットの観測値を分類します。混同行列を使用して結果を表示します。青色の値は正しい分類を示し、赤色の値は誤分類された観測値を示します。

predictedValues = predict(mdl,X(idxTest,:));
confusionchart(Y(idxTest),predictedValues)

Figure contains an object of type ConfusionMatrixChart.

テスト セットでモデルの精度を計算します。

error = loss(mdl,X(idxTest,:),Y(idxTest), ...
    "LossFun","classiferror");
accuracy = 1 - error
accuracy = 0.7731

accuracy で、正しく分類された観測値の比率が推定されます。

詳細

すべて展開する

代替機能

Simulink ブロック

Simulink® にアンサンブルの予測を統合するには、Statistics and Machine Learning Toolbox™ ライブラリにある ClassificationEnsemble Predict ブロックを使用するか、MATLAB® Function ブロックを関数 predict と共に使用します。例については、ClassificationEnsemble Predict ブロックの使用によるクラス ラベルの予測MATLAB Function ブロックの使用によるクラス ラベルの予測を参照してください。

使用するアプローチを判断する際は、以下を考慮してください。

  • Statistics and Machine Learning Toolbox ライブラリ ブロックを使用する場合、固定小数点ツール (Fixed-Point Designer)を使用して浮動小数点モデルを固定小数点に変換できます。

  • MATLAB Function ブロックを関数 predict と共に使用する場合は、可変サイズの配列に対するサポートを有効にしなければなりません。

  • MATLAB Function ブロックを使用する場合、予測の前処理や後処理のために、同じ MATLAB Function ブロック内で MATLAB 関数を使用することができます。

拡張機能

バージョン履歴

R2011a で導入