Confusion matrix of SVM classifier with k-fold cross-validation
I am using fitcsvm to train an SVM model with k-fold cross-validation.
I would like to access the observations whose predictions are false negatives (FN) and false positives (FP), so I wrote some code to get the indices of these observations.
However, I found that the TN, TP, FP and FN counts summed over the per-fold confusion matrices (one for each kSVMModel.Trained{k}) do not match the confusion matrix computed from "predictions".
Shouldn't they be the same?
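They should match, provided both matrices are built from the same predicted labels. Each observation belongs to exactly one test fold, so the per-fold counts split the pooled counts without overlap. Below is a toolbox-free sketch of that bookkeeping using made-up labels (yTrue, yPred and fold are hypothetical). It uses accumarray to tally a 2-by-2 matrix with true classes in rows and predicted classes in columns.

```matlab
% Hypothetical data: true labels, predicted labels and the test fold of each observation
yTrue = [0 0 1 1 0 1 1 0 1 0]';
yPred = [0 1 1 0 0 1 1 0 0 0]';
fold  = [1 1 1 2 2 2 3 3 3 3]';

% Pooled confusion matrix: rows = true class (0, 1), columns = predicted class (0, 1)
pooled = accumarray([yTrue yPred] + 1, 1, [2 2]);   % [TN FP; FN TP] = [4 1; 2 3]

% Per-fold confusion matrices, accumulated over the folds
summed = zeros(2, 2);
for k = 1:max(fold)
    inFold = (fold == k);
    summed = summed + accumarray([yTrue(inFold) yPred(inFold)] + 1, 1, [2 2]);
end

isequal(pooled, summed)   % logical 1: the fold counts add up to the pooled counts
```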
```matlab
c = cvpartition(fullDataY, 'KFold', 10);   % create stratified folds
kSVMModel = fitcsvm(fullDataX, fullDataY, 'Standardize', true, 'CVPartition', c);
scorekSVMModel = fitSVMPosterior(kSVMModel);   % cross-validated model with posterior scores
[predictions, post_scores] = kfoldPredict(scorekSVMModel);

% Debug: rebuild the confusion counts fold by fold from each trained model
for jj = 1:kSVMModel.KFold
    indTrainFold{jj} = find(training(c, jj));
    indTestFold{jj} = find(test(c, jj));
    predFold{jj} = predict(kSVMModel.Trained{jj}, fullDataX(indTestFold{jj}, :));
    cmFold = confusionchart(fullDataY(indTestFold{jj}, :), predFold{jj});
    TN(jj) = cmFold.NormalizedValues(1,1);
    TP(jj) = cmFold.NormalizedValues(2,2);
    FP(jj) = cmFold.NormalizedValues(1,2);
    FN(jj) = cmFold.NormalizedValues(2,1);
    close all;   % discard the figure created by confusionchart
end

% Pooled confusion matrix from the kfoldPredict labels, compared with the fold sums
% (no trailing semicolons, so each comparison result is displayed)
cm = confusionchart(fullDataY, predictions);
sum(TN) == cm.NormalizedValues(1,1)
sum(TP) == cm.NormalizedValues(2,2)
sum(FP) == cm.NormalizedValues(1,2)
sum(FN) == cm.NormalizedValues(2,1)
```
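If the summed and pooled counts disagree, then some observations must receive different labels in predFold and in predictions. In the code above, the two come from different objects: predFold comes from kSVMModel.Trained{jj}, while predictions comes from scorekSVMModel, the model returned by fitSVMPosterior. The toolbox-free sketch below stitches the per-fold labels back into observation order and lists the rows where they differ. testIdx, predFold and pooledPred are hypothetical variables holding made-up values. They stand in for indTestFold, predFold and predictions.

```matlab
% Hypothetical inputs: testIdx{k} holds the row indices of fold k's test set,
% predFold{k} the labels predicted for those rows by the k-th trained model,
% and pooledPred the labels returned by kfoldPredict (one per observation, in row order)
testIdx    = {[1 2 3]', [4 5 6]', [7 8 9 10]'};
predFold   = {[0 1 1]', [0 0 1]', [1 0 0 0]'};
pooledPred = [0 1 1 0 0 1 1 0 1 0]';

% Put each fold's predictions back at the rows they were made for
stitched = zeros(size(pooledPred));
for k = 1:numel(testIdx)
    stitched(testIdx{k}) = predFold{k};
end

% Rows where the per-fold models and kfoldPredict disagree
mismatch = find(stitched ~= pooledPred)   % row 9 in this made-up example
```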
Answers (1)
Aditya Patil
22 December 2020
You can use confusionmat to get the confusion matrix; with it, the results are consistent. Check the following sample code, which compares the pooled matrix with the sum of the per-fold matrices:
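```matlab
% Generate data: one random feature, label is true when it exceeds 0.5
X = rand(100, 1);
Y = X(:,1) > 0.5;

% Fit a 4-fold cross-validated SVM model
cvp = cvpartition(Y, 'KFold', 4);
mdl = fitcsvm(X, Y, 'CVPartition', cvp);
prediction = kfoldPredict(mdl);
confusionmat(Y, prediction)   % known labels first: rows = true class, columns = predicted class

% Compare with the sum of the per-fold confusion matrices
FoldPredictions = zeros(mdl.KFold, 2, 2);
for counter = 1:mdl.KFold
    index = test(cvp, counter);   % logical mask of this fold's test rows
    predictFolds = predict(mdl.Trained{counter}, X(index, :));
    FoldPredictions(counter,:,:) = confusionmat(Y(index), predictFolds);
end
squeeze(sum(FoldPredictions, 1))   % 2-by-2 sum over the folds
```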
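kfoldPredict returns one label per observation, in the row order of the training data. The FP and FN observations the question asks about can therefore be read directly from those labels, with no need to rebuild anything fold by fold. Here is a minimal sketch, assuming logical labels with true as the positive class. yTrue and yPred are hypothetical stand-ins for Y and prediction (or fullDataY and predictions in the question).

```matlab
% Hypothetical labels standing in for the true labels and the kfoldPredict output
yTrue = logical([0 0 1 1 0 1 1 0 1 0]');
yPred = logical([0 1 1 0 0 1 1 0 0 0]');

% Row indices of the misclassified observations
idxFP = find(~yTrue &  yPred);   % negatives predicted as positive
idxFN = find( yTrue & ~yPred);   % positives predicted as negative

% With real data, the offending rows would be X(idxFP, :) and X(idxFN, :)
```

For numeric or categorical labels, replace the logical tests with comparisons against the positive class, for example `yTrue == posClass`, where `posClass` is a hypothetical variable holding that class.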