How to plot a confusion matrix?

Adrian Kleffler, 22 May 2023
Edited: Venkat Siddarth, 29 May 2023
Hello, I want to plot a confusion matrix after training a YOLO v4 object detector. Here is my code. How can I plot the confusion matrix?
data = load("letisko_labels_new.mat");
LabelData = data.gTruth.LabelData;
LabelData.imageFilename = fullfile(LabelData.imageFilename);
rng("default");
shuffledIndices = randperm(height(LabelData));
idx = floor(0.6 * length(shuffledIndices) );
trainingIdx = 1:idx;
trainingDataTbl = LabelData(shuffledIndices(trainingIdx),:);
validationIdx = idx+1 : idx + 1 + floor(0.1 * length(shuffledIndices) );
validationDataTbl = LabelData(shuffledIndices(validationIdx),:);
testIdx = validationIdx(end)+1 : length(shuffledIndices);
testDataTbl = LabelData(shuffledIndices(testIdx),:);
imdsTrain = imageDatastore(trainingDataTbl{:,"imageFilename"});
bldsTrain = boxLabelDatastore(trainingDataTbl(:,2:6));
imdsValidation = imageDatastore(validationDataTbl{:,"imageFilename"});
bldsValidation = boxLabelDatastore(validationDataTbl(:,2:6));
imdsTest = imageDatastore(testDataTbl{:,"imageFilename"});
bldsTest = boxLabelDatastore(testDataTbl(:,2:6));
trainingData = combine(imdsTrain,bldsTrain);
validationData = combine(imdsValidation,bldsValidation);
testData = combine(imdsTest,bldsTest);
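% validateInputData, preprocessData and augmentData are supporting functions
% from the MathWorks YOLO v4 object detection example (not shown here).
validateInputData(trainingData);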
validateInputData(validationData);
validateInputData(testData);
inputSize = [256 256 3];
className = ["kamera","lietadlo","satelit","stlp","veza"];
rng("default")
trainingDataForEstimation = transform(trainingData,@(data)preprocessData(data,inputSize));
numAnchors = 9;
[anchors,meanIoU] = estimateAnchorBoxes(trainingDataForEstimation,numAnchors);
area = anchors(:, 1).*anchors(:,2);
[~,idx] = sort(area,"descend");
anchors = anchors(idx,:);
anchorBoxes = {anchors(1:3,:)
    anchors(4:6,:)
    anchors(7:9,:)
    };
detector = yolov4ObjectDetector("csp-darknet53-coco",className,anchorBoxes,InputSize=inputSize);
augmentedTrainingData = transform(trainingData,@augmentData);
options = trainingOptions("adam",...
    GradientDecayFactor=0.9,...
    SquaredGradientDecayFactor=0.999,...
    InitialLearnRate=0.001,...
    LearnRateSchedule="none",...
    MiniBatchSize=4,...
    L2Regularization=0.0005,...
    MaxEpochs=50,...
    BatchNormalizationStatistics="moving",...
    DispatchInBackground=true,...
    ResetInputNormalization=false,...
    Shuffle="every-epoch",...
    VerboseFrequency=20,...
    ValidationFrequency=1000,...
    Plots="training-progress",...
    CheckpointPath='C:\BAKALARKA\checkpointYOLO',...
    ValidationData=validationData);
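doTraining = true;
if doTraining
    % Train the YOLO v4 detector.
    [detector,info] = trainYOLOv4ObjectDetector(augmentedTrainingData,detector,options);
else
    % Load a pretrained detector. downloadPretrainedYOLOv4Detector is a supporting
    % function from the MathWorks example; it does not return a detector trained
    % on the classes in className.
    detector = downloadPretrainedYOLOv4Detector();
end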
detectionResults = detect(detector,testData,'MiniBatchSize',4);
[ap,recall,precision] = evaluateDetectionPrecision(detectionResults,testData);
recallv = cell2mat(recall);
precisionv = cell2mat(precision);
[r,index] = sort(recallv);
p = precisionv(index);
figure
plot(r,p)
xlabel("Recall")
ylabel("Precision")
grid on
title(sprintf("Average Precision = %.2f",mean(ap)))

Answers (1)

Venkat Siddarth, 29 May 2023
Edited: Venkat Siddarth, 29 May 2023
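I understand that you want to plot a confusion matrix for the model. Assuming you want it for the Labels column of detectionResults, you can use the confusionmat function. It takes two vectors, the true labels and the predicted labels, and returns a confusion matrix whose rows correspond to the true classes and whose columns correspond to the predicted classes. Note that the two vectors must have the same length and correspond element by element, so for an object detector each detected box first has to be paired with a ground-truth box (for example, by bounding-box overlap) before the labels can be compared.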
y_true=[1 0 1 1 1 1 0 0];
y_pred=[1 1 1 1 0 0 1 1];
C=confusionmat(y_true,y_pred)
C = 2×2
     0     3
     2     3
After generating the confusion matrix, you can plot it using the confusionchart function:
confusionchart(C)
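When confusionchart receives only a numeric matrix, it does not know the class names. Passing them as a second argument labels the axes, and the returned chart object can show per-class summaries. A minimal sketch using the toy vectors from this answer; the names "class 0" and "class 1" are placeholders (for the detector, the names must follow the same order as the rows of the matrix):

```matlab
% Toy labels from the answer; confusionmat orders numeric classes as 0, 1
y_true = [1 0 1 1 1 1 0 0];
y_pred = [1 1 1 1 0 0 1 1];
C = confusionmat(y_true, y_pred);

% Placeholder class names, in the same order as the rows/columns of C
cm = confusionchart(C, ["class 0", "class 1"]);
cm.Title = "Confusion matrix";
cm.RowSummary = "row-normalized";        % per true class: share correctly classified (recall)
cm.ColumnSummary = "column-normalized";  % per predicted class: share correct (precision)
```

The row and column summaries are optional; they are shown here because for a detector the per-class recall and precision are often more informative than the raw counts.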
For more details, see the documentation for confusionmat and confusionchart.
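For the detector itself, the Labels column of detectionResults cannot be passed to confusionmat as is, because an image can have a different number of detections than ground-truth boxes. One way to obtain paired label vectors is to match each ground-truth box to a detection by intersection over union (IoU), computed with bboxOverlapRatio. The sketch below is a simplified greedy matching under assumptions that are not part of the answer: an IoU threshold of 0.5, an extra "background" class for missed objects (false negatives) and unmatched detections (false positives), and small hypothetical boxes in place of real data. It is not necessarily the same matching that evaluateDetectionPrecision uses internally.

```matlab
% Sketch: detection confusion matrix via greedy IoU matching.
% Assumptions (illustrative only): IoU threshold 0.5 and a "background"
% class for missed objects and unmatched detections.
iouThreshold = 0.5;

% Hypothetical toy data for two images. With the question's variables,
% assuming bldsTest.LabelData holds {boxes, labels} per row, use:
%   gtBoxes  = bldsTest.LabelData(:,1);  gtLabels  = bldsTest.LabelData(:,2);
%   detBoxes = detectionResults.Boxes;   detLabels = detectionResults.Labels;
gtBoxes   = {[10 10 50 50; 100 100 40 40]; [20 20 30 30]};
gtLabels  = {categorical(["veza"; "stlp"]); categorical("kamera")};
detBoxes  = {[12 11 48 50; 200 200 20 20]; zeros(0,4)};
detLabels = {categorical(["veza"; "kamera"]); categorical(strings(0,1))};

allTrue = strings(0,1);
allPred = strings(0,1);
for i = 1:numel(gtBoxes)
    g = gtBoxes{i};  gl = string(gtLabels{i}(:));
    d = detBoxes{i}; dl = string(detLabels{i}(:));
    used = false(size(d,1), 1);            % detections already matched
    if isempty(g) || isempty(d)
        iou = zeros(size(g,1), size(d,1));
    else
        iou = bboxOverlapRatio(g, d);      % rows: ground truth, columns: detections
    end
    % Greedy: each ground-truth box, in order, takes the unused detection
    % with the highest IoU, if that IoU reaches the threshold
    for k = 1:size(g,1)
        scores = iou(k,:);
        scores(used') = 0;                 % a detection can match only once
        [best, j] = max(scores);
        allTrue(end+1,1) = gl(k);
        if ~isempty(best) && best >= iouThreshold
            used(j) = true;
            allPred(end+1,1) = dl(j);      % matched box: compare class labels
        else
            allPred(end+1,1) = "background";   % missed object (false negative)
        end
    end
    % Detections that matched no ground-truth box are false positives
    allTrue = [allTrue; repmat("background", nnz(~used), 1)];
    allPred = [allPred; dl(~used)];
end

cats = unique([allTrue; allPred]);         % shared, sorted class list
t = categorical(allTrue, cats);
p = categorical(allPred, cats);
C = confusionmat(t, p)                     % rows: true class, columns: predicted class
confusionchart(t, p);
```

With the question's variables, replace the toy cell arrays with the commented lines; the images appear in the same order in detectionResults and bldsTest because both come from testData. Detections are not filtered by score here, so the matrix reflects whatever score threshold detect applied.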
I hope this resolves the issue,
Regards
Venkat Siddarth V.

Categories: Image Processing and Computer Vision
