Precision-Recall curve for Multiclass classification
I have been trying hard to find any documentation or example about plotting a precision-recall curve for multiclass classification, but it seems there is no way to do that. How would I make a precision-recall curve for my model?
Following is the code I use to get the confusion matrix:
fpath = 'E:\Research Data\Four Classes';
testData = fullfile(fpath, 'Test');
% %
testDatastore = imageDatastore(testData, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
allclass = [];
for i = 1:numel(testDatastore.Labels)
    I = readimage(testDatastore, i);
    % incep is the trained network; "label" avoids shadowing the built-in class()
    label = classify(incep, I);
    allclass = [allclass label];
end
predicted = allclass';
% figure
% plotconfusion(testDatastore.Labels, predicted)
Answers (1)
Drew
23 August 2023
Edited: Drew, 24 August 2023
For a multiclass classification problem, you can create one Precision-Recall curve per class by treating each class as a one-vs-all binary problem: that class is the positive class and all other classes together are the negative class. The Precision-Recall curves can be built with:
- The rocmetrics function (introduced in R2022a). See https://www.mathworks.com/help/stats/rocmetrics.html and the specific Precision-Recall curve example: openExample('stats/SpecifyScoresAsVectorExample')
- The perfcurve function (introduced in R2009a). See https://www.mathworks.com/help/stats/perfcurve.html . You can see example code at https://www.mathworks.com/matlabcentral/answers/217604-how-to-calculate-precision-and-recall-in-matlab
The documentation for the plot method of the rocmetrics object https://www.mathworks.com/help/stats/rocmetrics.plot.html has another Precision-Recall curve example: openExample('stats/PlotOtherPerformanceCurveExample')
The following applies to ROC curves but not to Precision-Recall curves: for multiclass problems, the "plot" method of the rocmetrics object can also create ROC curves from averaged metrics using the "AverageROCType" name-value argument, and the "average" method of the rocmetrics object computes these averaged metrics https://www.mathworks.com/help/stats/rocmetrics.average.html . An example of an average ROC curve is here: openExample('stats/PlotAverageROCCurveExample'). The averaging options are micro, macro, and weighted macro.
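To make the three averaging options more concrete, here is a minimal plain-MATLAB sketch (no toolbox functions) that applies micro, macro, and weighted-macro averaging to the one-vs-all TPR and FPR at a single operating point. The confusion matrix is hypothetical, and weighting each class by its share of true instances is an assumption about what "weighted" means here; the average method of rocmetrics averages across all thresholds and may interpolate and weight differently.

```matlab
% Hypothetical 3-class confusion matrix: rows = true class, columns = predicted
C = [8  1 1;
     2 16 2;
     0  1 9];
tp = diag(C);                    % true positives per class (one-vs-all)
fn = sum(C, 2) - tp;             % row sum minus diagonal
fp = sum(C, 1)' - tp;            % column sum minus diagonal
tn = sum(C(:)) - tp - fn - fp;   % everything else
tpr = tp ./ (tp + fn);           % per-class TPR (recall)
fpr = fp ./ (fp + tn);           % per-class FPR

% Micro: pool the counts of all one-vs-all problems, then compute the rate
microTPR = sum(tp) / sum(tp + fn);
microFPR = sum(fp) / sum(fp + tn);

% Macro: compute the rate per class, then take the unweighted mean
macroTPR = mean(tpr);
macroFPR = mean(fpr);

% Weighted macro (assumed weights): each class weighted by its share of true instances
w = sum(C, 2) / sum(C(:));
weightedTPR = sum(w .* tpr);
weightedFPR = sum(w .* fpr);

disp([microTPR macroTPR weightedTPR; microFPR macroFPR weightedFPR])
```

With these imbalanced classes the macro TPR (about 0.833) differs from the micro TPR (0.825), while the weighted-macro TPR coincides with the micro TPR, because weighting per-class recall by class size reproduces the pooled count.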
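To build the rocmetrics object with the rocmetrics function, or to use the perfcurve function, you need the per-class scores, not only the predicted labels. The classify function returns them as its second output, for example [predicted, scores] = classify(incep, testDatastore), where each row of scores holds one image's score for every class, with the columns ordered as the classes of the network's output layer.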
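Conceptually, the one-vs-all Precision-Recall curve for class k comes from sweeping a threshold over column k of the score matrix, which is why predicted labels alone (a single operating point) are not enough. Here is a minimal plain-MATLAB sketch of that sweep; the trueLabels and scores values are hypothetical toy data. rocmetrics and perfcurve do this for you and may add endpoints or treat tied scores differently.

```matlab
% Hypothetical toy data: true class indices (1..3) and an N-by-3 score
% matrix whose column k holds the score for class k
trueLabels = [1; 1; 2; 2; 3; 3];
scores = [0.8 0.1 0.1;
          0.4 0.5 0.1;
          0.2 0.7 0.1;
          0.3 0.3 0.4;
          0.1 0.2 0.7;
          0.2 0.2 0.6];

k = 3;                          % class treated as "positive"
isPos = (trueLabels == k);      % one-vs-all ground truth
s = scores(:, k);               % one-vs-all score for class k

% Sweep the threshold over the distinct scores, highest first;
% an observation is predicted positive when its score >= threshold
thresholds = sort(unique(s), 'descend');
precision = zeros(size(thresholds));
recall    = zeros(size(thresholds));
for j = 1:numel(thresholds)
    predPos = s >= thresholds(j);
    tp = sum(predPos & isPos);
    fp = sum(predPos & ~isPos);
    fn = sum(~predPos & isPos);
    precision(j) = tp / (tp + fp);   % TP/(TP+FP)
    recall(j)    = tp / (tp + fn);   % TP/(TP+FN)
end

disp([thresholds recall precision])
% To draw the curve:
% plot(recall, precision, '-o'); xlabel('Recall'); ylabel('Precision');
```

Repeating the sweep for k = 1, 2, 3 gives one one-vs-all curve per class, which is what rocmetrics plots by default.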
Here is a Precision-Recall curve example for a classification tree trained on the fisheriris data.
t=readtable("fisheriris.csv");
response="Species";
% Create set of models for 5-fold cross-validation
cvmdl=fitctree(t,response,KFold=5);
% Get cross-validation predictions and scores
[yfit,scores]=kfoldPredict(cvmdl);
% View confusion matrix
% The per-class Precision can be seen in the blue boxes in the
% column-summary along the bottom.
% The per-class Recall can be seen in the blue boxes in the row summary
% along the right side.
cm=confusionchart(t{:,response},yfit);
cm.ColumnSummary='column-normalized';
cm.RowSummary='row-normalized';
% Calculate precision, recall, and F1 per-class from the raw confusion
% matrix counts
% Precision = TP/(TP+FP); Recall = TP/(TP+FN);
% F1score is the harmonic mean of Precision and Recall.
% cm.Normalization must be 'absolute' (the default) so that
% cm.NormalizedValues holds raw counts.
counts = cm.NormalizedValues;
precisionPerClass= diag(counts)./ (sum(counts))';
recallPerClass = diag(counts)./ (sum(counts,2));
F1PerClass=2.*diag(counts) ./ ((sum(counts,2)) + (sum(counts))');
% Create rocmetrics object
% Add the metric "PositivePredictiveValue", which is Precision.
% The metric "TruePositiveRate", which is Recall, is in the Metrics by default.
rocObj=rocmetrics(t{:,response}, scores, cvmdl.ClassNames, ...
    AdditionalMetrics="PositivePredictiveValue");
% For illustration, focus on metrics for one class, virginica
classindex=3;
% Plot the precision-recall curve for one class, given by classindex
% (By default, the rocmetrics plot function will plot the one-vs-all PR curves
% for all of the classes at once.)
r=plot(rocObj, YAxisMetric="PositivePredictiveValue", ...
    XAxisMetric="TruePositiveRate", ...
    ClassNames=rocObj.ClassNames(classindex));
hold on;
xlabel("Recall");ylabel("Precision"); title('Precision-Recall Curve');
% Place the operating point on the figure
scatter(recallPerClass(classindex),precisionPerClass(classindex),[],r.Color,"filled");
% Update legend
legend(strcat(string(rocObj.ClassNames(classindex))," one-vs-all P-R curve"), ...
    strcat(string(rocObj.ClassNames(classindex)), ...
        sprintf(' one-vs-all operating point\nP=%4.1f%%, R=%4.1f%%, F1=%4.1f%%', ...
            100*precisionPerClass(classindex),100*recallPerClass(classindex), ...
            100*F1PerClass(classindex))));
hold off;
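As a hand check of the per-class precision, recall, and F1 formulas used in the code above, here is the same arithmetic on a small hypothetical confusion matrix (rows are true classes and columns are predicted classes, as in confusionchart). It also confirms that the F1 expression 2*TP/(row sum + column sum) equals the harmonic mean of precision and recall.

```matlab
% Hypothetical confusion-matrix counts: rows = true class, columns = predicted
counts = [8  1 1;
          2 16 2;
          0  1 9];
tp = diag(counts);
precisionPerClass = tp ./ sum(counts, 1)';                 % TP ./ (TP + FP)
recallPerClass    = tp ./ sum(counts, 2);                  % TP ./ (TP + FN)
F1PerClass = 2 .* tp ./ (sum(counts, 2) + sum(counts, 1)');

% Same F1 computed as the harmonic mean of precision and recall
F1check = 2 .* precisionPerClass .* recallPerClass ./ ...
    (precisionPerClass + recallPerClass);

disp([precisionPerClass recallPerClass F1PerClass])
```

For the middle class, for example, precision is 16/18 (about 0.889), recall is 16/20 = 0.8, and F1 is 32/38 (about 0.842).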