このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。
predict
構文
説明
例
GAM のテスト標本観測値のラベル付け
学習標本を使用して一般化加法モデルに学習させてから、テスト標本にラベルを付けます。
fisheriris
データ セットを読み込みます。versicolor と virginica のアヤメについてのがく片と花弁の測定値が含まれる数値行列 X
を作成します。対応するアヤメの種類が含まれる文字ベクトルの cell 配列 Y
を作成します。
load fisheriris inds = strcmp(species,'versicolor') | strcmp(species,'virginica'); X = meas(inds,:); Y = species(inds,:);
Y
のクラス情報を使用して、観測値を階層的に学習セットとテスト セットに無作為に分割します。テスト用の 30% のホールドアウト標本を指定します。
rng('default') % For reproducibility cv = cvpartition(Y,'HoldOut',0.30);
学習インデックスとテスト インデックスを抽出します。
trainInds = training(cv); testInds = test(cv);
学習データ セットとテスト データ セットを指定します。
XTrain = X(trainInds,:); YTrain = Y(trainInds); XTest = X(testInds,:); YTest = Y(testInds);
予測子 XTrain
とクラス ラベル YTrain
を使用して、一般化加法モデルに学習させます。クラス名を指定することが推奨されます。
Mdl = fitcgam(XTrain,YTrain,'ClassNames',{'versicolor','virginica'})
Mdl = ClassificationGAM ResponseName: 'Y' CategoricalPredictors: [] ClassNames: {'versicolor' 'virginica'} ScoreTransform: 'logit' Intercept: -1.1090 NumObservations: 70
Mdl
は ClassificationGAM
モデル オブジェクトです。
テスト標本のラベルを予測します。
label = predict(Mdl,XTest);
真のラベルと予測ラベルを格納する table を作成します。10 件の観測値の無作為なセットについて table を表示します。
t = table(YTest,label,'VariableNames',{'True Label','Predicted Label'}); idx = randsample(sum(testInds),10); t(idx,:)
ans=10×2 table
True Label Predicted Label
______________ _______________
{'virginica' } {'virginica' }
{'virginica' } {'virginica' }
{'versicolor'} {'virginica' }
{'virginica' } {'virginica' }
{'virginica' } {'virginica' }
{'versicolor'} {'versicolor'}
{'versicolor'} {'versicolor'}
{'versicolor'} {'versicolor'}
{'versicolor'} {'versicolor'}
{'virginica' } {'virginica' }
真のラベル YTest
と予測ラベル label
から混同チャートを作成します。
cm = confusionchart(YTest,label);
事後確率のロジットの比較
予測子の線形項と交互作用項の両方が格納されている分類 GAM を使用して、新しい観測値の事後確率のロジットを推定します。メモリ効率の高いモデル オブジェクトを使用して新しい観測値を分類します。新しい観測値を分類する際に交互作用項を含めるかどうかを指定します。
ionosphere
データ セットを読み込みます。このデータ セットには、レーダー反射についての 34 個の予測子と、不良 ('b'
) または良好 ('g'
) という 351 個の二項反応が含まれています。
load ionosphere
データ セットを 2 つのセットに分割します。1 つは学習データを含め、もう 1 つは新しい未観測のテスト データを含めます。新しいテスト データ セットの 10 件の観測値を保持します。
rng('default') % For reproducibility n = size(X,1); newInds = randsample(n,10); inds = ~ismember(1:n,newInds); XNew = X(newInds,:); YNew = Y(newInds);
予測子 X
とクラス ラベル Y
を使用して、GAM に学習させます。クラス名を指定することが推奨されます。上位 10 個の最も重要な交互作用項を含めるように指定します。
Mdl = fitcgam(X(inds,:),Y(inds),'ClassNames',{'b','g'},'Interactions',10);
Mdl
は ClassificationGAM
モデル オブジェクトです。
学習させたモデルのサイズを減らし、メモリの消費量を抑えます。
CMdl = compact(Mdl); whos('Mdl','CMdl')
Name Size Bytes Class Attributes CMdl 1x1 1081260 classreg.learning.classif.CompactClassificationGAM Mdl 1x1 1282819 ClassificationGAM
CMdl
は CompactClassificationGAM
モデル オブジェクトです。
線形項と交互作用項の両方を使用してラベルを予測してから、線形項のみを使用してラベルを予測します。交互作用項を除外するには、'IncludeInteractions',false
を指定します。ScoreTransform
プロパティとして 'none'
を指定して、事後確率のロジットを推定します。
CMdl.ScoreTransform = 'none'; [labels,scores] = predict(CMdl,XNew); [labels_nointeraction,scores_nointeraction] = predict(CMdl,XNew,'IncludeInteractions',false); t = table(YNew,labels,scores,labels_nointeraction,scores_nointeraction, ... 'VariableNames',{'True Labels','Predicted Labels','Scores' ... 'Predicted Labels Without Interactions','Scores Without Interactions'})
t=10×5 table
True Labels Predicted Labels Scores Predicted Labels Without Interactions Scores Without Interactions
___________ ________________ __________________ _____________________________________ ___________________________
{'g'} {'g'} -40.23 40.23 {'g'} -37.484 37.484
{'g'} {'g'} -41.215 41.215 {'g'} -38.737 38.737
{'g'} {'g'} -44.413 44.413 {'g'} -42.186 42.186
{'g'} {'b'} 3.0658 -3.0658 {'b'} 1.4338 -1.4338
{'g'} {'g'} -84.637 84.637 {'g'} -81.269 81.269
{'g'} {'g'} -27.44 27.44 {'g'} -24.831 24.831
{'g'} {'g'} -62.989 62.989 {'g'} -60.4 60.4
{'g'} {'g'} -77.109 77.109 {'g'} -75.937 75.937
{'g'} {'g'} -48.519 48.519 {'g'} -47.067 47.067
{'g'} {'g'} -56.256 56.256 {'g'} -53.373 53.373
テスト データ Xnew
の予測ラベルは交互作用項を含めても変化しませんが、推定スコア値は異なります。
事後確率領域のプロット
一般化加法モデルに学習させてから、1 番目のクラスの確率値を使用して事後確率領域をプロットします。
fisheriris
データ セットを読み込みます。versicolor と virginica のアヤメについての 2 つの花弁の測定値が含まれる数値行列 X
を作成します。対応するアヤメの種類が含まれる文字ベクトルの cell 配列 Y
を作成します。
load fisheriris inds = strcmp(species,'versicolor') | strcmp(species,'virginica'); X = meas(inds,3:4); Y = species(inds,:);
予測子 X
とクラス ラベル Y
を使用して、一般化加法モデルに学習させます。クラス名を指定することが推奨されます。
Mdl = fitcgam(X,Y,'ClassNames',{'versicolor','virginica'});
Mdl
は ClassificationGAM
モデル オブジェクトです。
観測された予測子領域の値のグリッドを定義します。
xMax = max(X); xMin = min(X); x1 = linspace(xMin(1),xMax(1),250); x2 = linspace(xMin(2),xMax(2),250); [x1Grid,x2Grid] = meshgrid(x1,x2);
グリッド内の各インスタンスの事後確率を予測します。
[~,PosteriorRegion] = predict(Mdl,[x1Grid(:),x2Grid(:)]);
1 番目のクラス 'versicolor'
の確率値を使用して事後確率領域をプロットします。
h = scatter(x1Grid(:),x2Grid(:),1,PosteriorRegion(:,1)); h.MarkerEdgeAlpha = 0.3;
学習データをプロットします。
hold on gh = gscatter(X(:,1),X(:,2),Y,'k','dx'); title('Iris Petal Measurements and Posterior Probabilities') xlabel('Petal length (cm)') ylabel('Petal width (cm)') legend(gh,'Location','Best') colorbar hold off
入力引数
Mdl
— 一般化加法モデル
ClassificationGAM
モデル オブジェクト | CompactClassificationGAM
モデル オブジェクト
一般化加法モデル。ClassificationGAM
または CompactClassificationGAM
モデル オブジェクトとして指定します。
X
— 予測子データ
数値行列 | テーブル
予測子データ。数値行列またはテーブルとして指定します。
X
の各行は 1 つの観測値に対応し、各列は 1 つの変数に対応します。
数値行列の場合
X
の列を構成する変数の順序は、Mdl
に学習させた予測子変数の順序と同じでなければなりません。table を使用して
Mdl
に学習をさせた場合、table に含まれている予測子変数がすべて数値変数であれば、X
を数値行列にすることができます。
テーブルの場合
table (たとえば
Tbl
) を使用してMdl
に学習をさせた場合、X
内のすべての予測子変数は変数名およびデータ型がTbl
内の変数と同じでなければなりません。ただし、X
の列の順序がTbl
の列の順序に対応する必要はありません。数値行列を使用して
Mdl
に学習をさせた場合、Mdl.PredictorNames
内の予測子名とX
内の対応する予測子変数名が同じでなければなりません。学習時に予測子の名前を指定するには、名前と値の引数'PredictorNames'
を使用します。X
内の予測子変数はすべて数値ベクトルでなければなりません。X
に追加の変数 (応答変数や観測値の重みなど) を含めることができますが、predict
はこれらを無視します。predict
は、文字ベクトルの cell 配列ではない cell 配列や複数列の変数をサポートしません。
データ型: table
| double
| single
includeInteractions
— 交互作用項を含むというフラグ
true
| false
モデルの交互作用項を含むというフラグ。true
または false
として指定します。
Mdl
に交互作用項が含まれる場合、includeInteractions
の既定値は true
です。モデルに交互作用項が含まれない場合、値は false
でなければなりません。
データ型: logical
出力引数
label
— 予測クラス ラベル
categorical 配列 | 文字配列 | logical ベクトル | 数値ベクトル | 文字ベクトルの cell 配列
score
— 予測事後確率またはクラス スコア
2 列の数値行列
予測事後確率またはクラス スコア。X
と同じ行数の 2 列の数値行列として返されます。score
の 1 列目と 2 列目に、対応する観測値の 1 番目のクラス (陰性クラス、Mdl.ClassNames(1)
) と 2 番目のクラス (陽性クラス、Mdl.ClassNames(2)
) のスコア値がそれぞれ格納されます。
Mdl.ScoreTransform
が 'logit'
(既定) の場合、スコア値は事後確率です。Mdl.ScoreTransform
が 'none'
の場合、スコア値は事後確率のロジットです。ソフトウェアに組み込みのスコア変換関数がいくつか用意されています。詳細については、Mdl
の ScoreTransform
プロパティを参照してください。
スコア変換を変更するには、学習時に fitcgam
の引数 'ScoreTransform'
を指定するか、学習後に ScoreTransform
プロパティを変更します。
詳細
予測クラス ラベル
predict
は、予測される誤分類コストを最小化することにより分類します。
ここで、
は、予測された分類です。
K は、クラスの数です。
は、観測値 x のクラス j の事後確率です。
は、真のクラスが j の場合に観測値を y として分類するコストです。
予測誤分類コスト
観測値ごとの予測誤分類コストは、観測をそれぞれのクラスに分類する平均コストです。
学習済みの分類器を使用して Nobs
個の観測値を分類するとします。また、K
個のクラスがあるとします。1 行に 1 観測ずつ、観測値を行列 X
に置きます。
予測コスト行列 CE
のサイズは、Nobs
行 K
列です。CE
の各行には、観測をそれぞれのクラス K
に分類する予測 (平均) コストが含まれます。CE(n,k)
は次のとおりです。
ここで、
K は、クラスの数です。
は、観測値 X(n) のクラス i の事後確率です。
は、真のクラスが i である観測値を k に分類する真の誤分類コストです。
真の誤分類コスト
真の誤分類コストは、観測値を誤ったクラスに分類するコストです。
分類器の作成時に、名前と値の引数 Cost
を使用してクラスごとの真の誤分類コストを設定できます。Cost(i,j)
は、真のクラスが i
の場合に観測値をクラス j
に分類するコストです。既定では、Cost(i,j)=1
(i~=j
の場合) および Cost(i,j)=0
(i=j
の場合) です。つまり、正しい分類のコストは 0
、誤った分類のコストは 1
です。
バージョン履歴
R2021a で導入
MATLAB コマンド
次の MATLAB コマンドに対応するリンクがクリックされました。
コマンドを MATLAB コマンド ウィンドウに入力して実行してください。Web ブラウザーは MATLAB コマンドをサポートしていません。
Select a Web Site
Choose a web site to get translated content where available and see local events and offers. Based on your location, we recommend that you select: .
You can also select a web site from the following list:
How to Get Best Site Performance
Select the China site (in Chinese or English) for best site performance. Other MathWorks country sites are not optimized for visits from your location.
Americas
- América Latina (Español)
- Canada (English)
- United States (English)
Europe
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)