predict
アンサンブル分類モデルを使用したラベルの予測
説明
例
フィッシャーのアヤメのデータ セットを読み込みます。標本サイズを調べます。
load fisheriris
N = size(meas,1);
データを学習セットとテスト セットに分割します。データの 10% をテスト用にホールドアウトします。
rng(1); % For reproducibility cvp = cvpartition(N,'Holdout',0.1); idxTrn = training(cvp); % Training set indices idxTest = test(cvp); % Test set indices
学習データを table に格納します。
tblTrn = array2table(meas(idxTrn,:)); tblTrn.Y = species(idxTrn);
AdaBoostM2 と学習セットを使用してアンサンブル分類に学習させます。弱学習器として木の切り株を指定します。
t = templateTree('MaxNumSplits',1); Mdl = fitcensemble(tblTrn,'Y','Method','AdaBoostM2','Learners',t);
テスト セットについてラベルを予測します。モデルの学習にはデータの table を使用しましたが、ラベルの予測には行列を使用できます。
labels = predict(Mdl,meas(idxTest,:));
テスト セットの混同行列を作成します。
confusionchart(species(idxTest),labels)
Mdl
は、テスト セット内の 1 つの versicolor 種のアヤメを virginica として誤分類します。
ブースティング木のアンサンブルを作成し、各予測子の重要度を検査します。テスト データを使用して、アンサンブルの分類精度を評価します。
不整脈データ セットを読み込みます。データのクラス表現を判別します。
load arrhythmia
Y = categorical(Y);
tabulate(Y)
Value Count Percent 1 245 54.20% 2 44 9.73% 3 15 3.32% 4 15 3.32% 5 13 2.88% 6 25 5.53% 7 3 0.66% 8 2 0.44% 9 9 1.99% 10 50 11.06% 14 4 0.88% 15 5 1.11% 16 22 4.87%
データ セットには 16 個のクラスが含まれていますが、すべてのクラスは表現されていません (たとえば、クラス 13)。ほとんどの観測値は不整脈がないものとして分類されています (クラス 1)。このデータ セットは非常に離散的であり、クラスが不均衡です。
不整脈があるすべての観測値 (クラス 2 ~ 15) を 1 つのクラスに結合します。不整脈の状況が不明である観測値 (クラス 16) をデータ セットから削除します。
idx = (Y ~= "16"); Y = Y(idx); X = X(idx,:); Y(Y ~= "1") = "WithArrhythmia"; Y(Y == "1") = "NoArrhythmia"; Y = removecats(Y);
データを学習セットとテスト セットに均等に分割します。
rng("default") % For reproducibility cvp = cvpartition(Y,"Holdout",0.5); idxTrain = training(cvp); idxTest = test(cvp);
cvp
は、学習セットとテスト セットを指定する交差検証分割オブジェクトです。
AdaBoostM1
を使用して 100 本のブースティング分類木のアンサンブルに学習をさせます。弱学習器として木の切り株を使用するように指定します。また、欠損値がデータ セットに含まれているので、代理分岐を使用するように指定します。
t = templateTree("MaxNumSplits",1,"Surrogate","on"); numTrees = 100; mdl = fitcensemble(X(idxTrain,:),Y(idxTrain),"Method","AdaBoostM1", ... "NumLearningCycles",numTrees,"Learners",t);
mdl
は学習させた ClassificationEnsemble
モデルです。
各予測子について重要度を調べます。
predImportance = predictorImportance(mdl); bar(predImportance) title("Predictor Importance") xlabel("Predictor") ylabel("Importance Measure")
重要度が上位 10 番目までの予測子を識別します。
[~,idxSort] = sort(predImportance,"descend");
idx10 = idxSort(1:10)
idx10 = 1×10
228 233 238 93 15 224 91 177 260 277
テスト セットの観測値を分類します。混同行列を使用して結果を表示します。青色の値は正しい分類を示し、赤色の値は誤分類された観測値を示します。
predictedValues = predict(mdl,X(idxTest,:)); confusionchart(Y(idxTest),predictedValues)
テスト セットでモデルの精度を計算します。
error = loss(mdl,X(idxTest,:),Y(idxTest), ... "LossFun","classiferror"); accuracy = 1 - error
accuracy = 0.7731
accuracy
で、正しく分類された観測値の比率が推定されます。
入力引数
アンサンブル分類モデル。fitcensemble
で学習させた ClassificationEnsemble
または ClassificationBaggedEnsemble
モデル オブジェクト、または compact
で作成した CompactClassificationEnsemble
モデル オブジェクトとして指定します。
分類対象の予測子データ。数値行列または table として指定します。
X
の各行は 1 つの観測値に対応し、各列は 1 つの変数に対応します。
数値行列の場合
X
の列を構成する変数の順序は、ens
の学習に使用した予測子変数の順序と同じでなければなりません。table (たとえば
tbl
) を使用してens
に学習させた場合、tbl
に含まれている予測子変数が数値変数のみであれば、X
を数値行列にすることができます。学習時にtbl
内の数値予測子をカテゴリカルとして扱うには、fitcensemble
の名前と値の引数CategoricalPredictors
を使用してカテゴリカル予測子を指定します。tbl
に種類の異なる予測子変数 (数値と categorical データ型など) が混在し、X
が数値行列である場合、predict
でエラーが発行されます。
table の場合
predict
は、文字ベクトルの cell 配列ではない cell 配列や複数列の変数をサポートしません。table (たとえば
tbl
) を使用してens
に学習させた場合、X
内のすべての予測子変数の変数名およびデータ型がens
の学習に使用された (ens.PredictorNames
に格納されている) 変数と同じでなければなりません。ただし、X
の列の順序がtbl
の列の順序に対応する必要はありません。tbl
とX
に追加の変数 (応答変数や観測値の重みなど) を含めることができますが、predict
はこれらを無視します。数値行列を使用して
ens
に学習させた場合、ens.PredictorNames
内の予測子名とX
内の対応する予測子変数名が同じでなければなりません。学習時に予測子の名前を指定するには、fitcensemble
の名前と値の引数PredictorNames
を使用します。X
内の予測子変数はすべて数値ベクトルでなければなりません。X
に追加の変数 (応答変数や観測値の重みなど) を含めることができますが、predict
はこれらを無視します。
名前と値の引数
オプションの引数のペアを Name1=Value1,...,NameN=ValueN
として指定します。ここで、Name
は引数名で、Value
は対応する値です。名前と値の引数は他の引数の後に指定しなければなりませんが、ペアの順序は重要ではありません。
R2021a より前では、名前と値をそれぞれコンマを使って区切り、Name
を引用符で囲みます。
例: predict(ens,X,Learners=[1 2 3 5],UseParallel=true)
は、アンサンブル ens
内の 1 番目、2 番目、3 番目、および 5 番目の学習器を使用し、計算を並列に実行するように指定します。
predict
で使用するアンサンブル内の弱学習器のインデックス。範囲 [1:ens.NumTrained
] の正の整数のベクトルとして指定します。既定では、この関数はすべての学習器を使用します。
例: Learners=[1 2 4]
データ型: single
| double
並列実行のフラグ。数値または logical の 1
(true
) または 0
(false
) として指定します。UseParallel=true
を指定した場合、関数 predict
は parfor
を使用して for
ループの反復を実行します。Parallel Computing Toolbox™ がある場合、ループが並列に実行されます。
例: UseParallel=true
データ型: logical
出力引数
予測クラス ラベル。categorical 配列、文字配列、logical 配列、数値配列、または文字ベクトルの cell 配列として返されます。labels
のデータ型は ens
の学習に使用されたラベルと同じです。(string 配列は文字ベクトルの cell 配列として扱われます)。
関数 predict
は、スコアが最高になるクラスに観測値を分類します。観測値のスコアが NaN
の場合、関数はこの観測値を、学習ラベルの最大比率を占める多数クラスに分類します。
クラス スコア。観測値ごとに 1 つの行とクラスごとに 1 つの列をもつ数値行列として返されます。スコアは、各観測値および各クラスについて、その観測値がそのクラスからの派生である信頼度を表します。スコアが高いほど、信頼度が高いことを示します。詳細については、スコア (アンサンブル)を参照してください。
詳細
アンサンブルの場合、分類スコアは、観測値が特定のクラスからの派生である信頼度を表します。スコアが高いほど、信頼度も高くなります。
アンサンブル アルゴリズムが異なれば、スコアの定義も違ってきます。さらに、スコアの範囲はアンサンブル タイプによって異なります。以下に例を示します。
Bag
スコアの範囲は0
~1
です。これらのスコアは、アンサンブル内のすべての木で平均した確率と解釈できます。AdaBoostM1
、GentleBoost
、およびLogitBoost
のスコアの範囲は –∞ ~ ∞ です。これらのスコアは、ens
をpredict
に渡す前に、ens
のScoreTransform
プロパティを"doublelogit"
に設定することで確率に変換できます。あるいは、ens.ScoreTransform = "doublelogit"; [labels,scores] = predict(ens,X);
ens
を作成するときにfitcensemble
の呼び出しでScoreTransform="doublelogit"
を指定できます。
各種のアンサンブル アルゴリズムとそれらによるスコアの計算方法の詳細については、アンサンブル アルゴリズムを参照してください。
代替機能
Simulink ブロック
Simulink® にアンサンブルの予測を統合するには、Statistics and Machine Learning Toolbox™ ライブラリにある ClassificationEnsemble Predict ブロックを使用するか、MATLAB® Function ブロックを関数 predict
と共に使用します。例については、ClassificationEnsemble Predict ブロックの使用によるクラス ラベルの予測とMATLAB Function ブロックの使用によるクラス ラベルの予測を参照してください。
使用するアプローチを判断する際は、以下を考慮してください。
Statistics and Machine Learning Toolbox ライブラリ ブロックを使用する場合、固定小数点ツール (Fixed-Point Designer)を使用して浮動小数点モデルを固定小数点に変換できます。
MATLAB Function ブロックを関数
predict
と共に使用する場合は、可変サイズの配列に対するサポートを有効にしなければなりません。MATLAB Function ブロックを使用する場合、予測の前処理や後処理のために、同じ MATLAB Function ブロック内で MATLAB 関数を使用することができます。
拡張機能
使用上の注意および制限:
saveLearnerForCoder
、loadLearnerForCoder
およびcodegen
(MATLAB Coder) を使用して、関数predict
のコードを生成します。saveLearnerForCoder
を使用して、学習済みモデルを保存します。loadLearnerForCoder
を使用して保存済みモデルを読み込んで関数predict
を呼び出す、エントリポイント関数を定義します。次に、codegen
を使用して、エントリポイント関数のコードを生成します。predict
の単精度の C/C++ コードを生成するには、loadLearnerForCoder
関数を呼び出すときにDataType="single"
を指定します。predict
に対する固定小数点の C/C++ コードを生成することもできます。固定小数点コードの生成には、予測に必要な変数の固定小数点データ型を定義する追加の手順が必要です。generateLearnerDataTypeFcn
によって生成されるデータ型関数を使用して固定小数点データ型構造体を作成し、その構造体をエントリポイント関数でloadLearnerForCoder
の入力引数として使用します。固定小数点の C/C++ コードを生成するには、MATLAB Coder™ および Fixed-Point Designer™ が必要です。predict
の固定小数点コードの生成には個々の学習器についてのデータ型の伝播が含まれるため、時間がかかることがあります。次の表は、
predict
の引数に関する注意です。この表に含まれていない引数は、完全にサポートされています。引数 注意と制限 ens
モデル オブジェクトの使用上の注意および制限については、
CompactClassificationEnsemble
オブジェクトのCode Generationを参照してください。X
一般的なコード生成の場合、
X
は、単精度または倍精度の行列か、数値変数、カテゴリカル変数、またはその両方を含む table でなければなりません。固定小数点コードの生成の場合、
X
は固定小数点の行列でなければなりません。X
の行数、または観測値の数は可変サイズにすることができますが、X
の列数は固定でなければなりません。X
を table として指定する場合、モデルは table を使用して学習させたものでなければならず、かつ予測のためのエントリポイント関数で次を行う必要があります。データを配列として受け入れる。
データ入力の引数から table を作成し、その table 内で変数名を指定する。
table を
predict
に渡す。
この table のワークフローの例については、table のデータを分類するためのコードの生成を参照してください。コード生成における table の使用の詳細については、table のコード生成 (MATLAB Coder)およびコード生成における table の制限事項 (MATLAB Coder)を参照してください。
名前と値の引数 名前と値の引数に含まれる名前はコンパイル時の定数でなければなりません。たとえば、生成コードで最大 5 つの弱学習器に対応するユーザー定義インデックスを指定できるようにするには、
{coder.Constant('Learners'),coder.typeof(0,[1,5],[0,1])}
をcodegen
(MATLAB Coder) の-args
の値に含めます。"Learners"
固定小数点コードの生成では、
"Learners"
の値は整数データ型でなければなりません。
詳細は、コード生成の紹介を参照してください。
並列実行するには、この関数を呼び出すときに名前と値の引数 UseParallel
を true
に設定します。
並列計算の全般的な情報については、自動並列サポートを使用した MATLAB 関数の実行 (Parallel Computing Toolbox)を参照してください。
UseParallel
は tall 配列、GPU 配列、コード生成では使用できません。
使用上の注意および制限:
関数
predict
では代理分岐をもつ決定木学習器を使用して学習させたアンサンブルはサポートしていません。
詳細は、GPU での MATLAB 関数の実行 (Parallel Computing Toolbox)を参照してください。
バージョン履歴
R2011a で導入
MATLAB Command
You clicked a link that corresponds to this MATLAB command:
Run the command by entering it in the MATLAB Command Window. Web browsers do not support MATLAB commands.
Web サイトの選択
Web サイトを選択すると、翻訳されたコンテンツにアクセスし、地域のイベントやサービスを確認できます。現在の位置情報に基づき、次のサイトの選択を推奨します:
また、以下のリストから Web サイトを選択することもできます。
最適なサイトパフォーマンスの取得方法
中国のサイト (中国語または英語) を選択することで、最適なサイトパフォーマンスが得られます。その他の国の MathWorks のサイトは、お客様の地域からのアクセスが最適化されていません。
南北アメリカ
- América Latina (Español)
- Canada (English)
- United States (English)
ヨーロッパ
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)