I don't know how to interpret this ANN classification code.

JUN BEOM JEON, July 29, 2020
Answered: Anshika Chaurasia, August 5, 2020
I constructed the following code by adapting someone else's code, and it runs without errors. After running it I get five confusion matrices, but I can't interpret this result or the code. Do the five confusion matrices show the individual results of the k-fold cross-validation? I don't understand why five confusion matrices are shown. Above all, is my code written correctly? I use 6 inputs and 2 outputs for 32 samples.
rng(3); % initialize the RNG to the same state before training to obtain reproducibility
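% Convert the tables (loaded from the .csv file) to numeric matrices.
% patternnet expects one sample per column (here 6 x 32 inputs and 2 x 32 targets);
% if each CSV row is one sample, transpose, e.g. input = table2array(inputs)';
input = table2array(inputs);
target = table2array(targets);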
% One hidden layer with 10 neurons
hiddenLayerSize = 10;
net = patternnet(hiddenLayerSize);
% Divide the data by explicit indices (set per fold inside the loop)
net.divideFcn = 'divideind';
% Choose the Training Function (scaled conjugate gradient)
net.trainFcn = 'trainscg';
% K-Fold cross validation
k = 5;
cvFolds = crossvalind('Kfold', size(target,2), k); % fold label (1..k) for each sample (column)
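test = zeros(1, k);                     % per-fold test MSE
for i = 1:k
    % Reset the network for this fold: configure sizes, then re-initialize the
    % weights so each fold starts from fresh weights, not the previous fold's net
    net = configure(net, input, target);
    net = init(net);
    % Samples labelled i are held out for testing; the rest are used for training
    testIdx = (cvFolds == i);
    trainIdx = ~testIdx;
    trInd = find(trainIdx);
    tstInd = find(testIdx);
    net.trainParam.epochs = 100;
    net.divideParam.trainInd = trInd;
    net.divideParam.testInd = tstInd;   % no valInd, so no validation-based early stopping
    % Performance Function
    net.performFcn = 'mse'; % Mean squared error
    % Train the network on this fold's training indices
    [net, tr] = train(net,input,target);
    % Evaluate the trained network on ALL samples (training and held-out)
    output = net(input);
    errors = gsubtract(target, output);
    performance = perform(net,target,output)
    % Recalculate Training and Test Performance (masks are NaN outside each subset)
    trainTargets = target .* tr.trainMask{1};
    testTargets = target .* tr.testMask{1};
    trainPerformance = perform(net,trainTargets,output)
    testPerformance = perform(net,testTargets,output)
    test(i) = testPerformance;          % store this fold's test MSE (index i, not k)
    % Save this fold's network to its own file
    save(sprintf('net_fold%d.mat', i), 'net');
    % Plot confusion (uses all samples of this fold, not only the held-out ones)
    figure, plotconfusion(target,output)
end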
meanTestMSE = mean(test);   % average test-set MSE over the k folds (an error, not an accuracy)
% View the network (the one trained in the last fold)
view(net)
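The `crossvalind('Kfold', N, k)` call gives every sample a fold label from 1 to k; iteration `i` of the loop holds out the samples labelled `i` and trains on the rest, so the whole loop body, including `plotconfusion`, runs k times. A minimal sketch of that partition in base MATLAB, assuming N = 32 samples and k = 5 as in the question; `mod(randperm(N), k) + 1` is a hypothetical stand-in for `crossvalind` (a Bioinformatics Toolbox function), not its actual algorithm:

```matlab
% Hypothetical stand-in for crossvalind('Kfold', N, k): one fold label per sample.
rng(3);                              % fixed seed for a repeatable split
N = 32;                              % number of samples (columns of target)
k = 5;                               % number of folds
cvFolds = mod(randperm(N), k) + 1;   % labels 1..k, fold sizes differ by at most 1

for i = 1:k
    testIdx  = (cvFolds == i);       % logical mask of held-out samples for fold i
    trainIdx = ~testIdx;             % all other samples are used for training
    fprintf('Fold %d: %d train, %d test\n', i, nnz(trainIdx), nnz(testIdx));
end

% Number of samples held out in each fold; together they cover all N samples once.
foldSizes = accumarray(cvFolds(:), 1, [k 1]);
```

Each sample lands in exactly one test fold, so the k test-set results together cover the whole data set once.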
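Because `performFcn` is set to `'mse'`, each `testPerformance` stored in `test` is a mean squared error, so `mean(test)` is an average error rather than a classification accuracy. If accuracy is wanted, it can be computed from the predicted and true class indices of the held-out columns. A minimal sketch in base MATLAB, assuming one sample per column and one-hot targets; the `target`, `output`, and `tstInd` values below are made-up toy data, not results from the question:

```matlab
% Toy one-hot targets and network outputs (2 classes x 6 samples, made-up values).
target = [1 1 0 0 1 0;
          0 0 1 1 0 1];
output = [0.9 0.4 0.2 0.3 0.7 0.6;
          0.1 0.6 0.8 0.7 0.3 0.4];
tstInd = [2 3 5];                      % hypothetical held-out columns for one fold

% Class = row index of the largest value in each column. (vec2ind from the
% Deep Learning Toolbox is commonly used for this; max keeps the sketch toolbox-free.)
[~, predClass] = max(output, [], 1);
[~, trueClass] = max(target, [], 1);

% Classification accuracy on the held-out samples only.
foldAccuracy = mean(predClass(tstInd) == trueClass(tstInd));

% Mean squared error on the same columns, i.e. the kind of quantity an mse
% performance function reports (assuming no error normalization).
sqErr = (target(:, tstInd) - output(:, tstInd)).^2;
foldMSE = mean(sqErr(:));
```

Storing `foldAccuracy` per fold in the loop and averaging it gives a cross-validated accuracy alongside the averaged MSE.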

Answers (1)

Anshika Chaurasia, August 5, 2020
It is my understanding that you are seeing five confusion matrices after running the code. This is because of the 5-fold cross-validation: in your code, the plotconfusion function is called inside the for loop, which runs k = 5 times:
for i = 1:k
....
figure, plotconfusion(target,output)
end
This is why the output has five confusion matrices: each one corresponds to one fold of the 5-fold cross-validation.
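Since the loop calls `plotconfusion(target,output)` with all 32 columns, each of the five matrices mixes that fold's training samples with its held-out samples. A common alternative is to count only each fold's held-out samples and, optionally, add the per-fold matrices into one pooled matrix in which every sample appears exactly once. A minimal sketch of that counting in base MATLAB, assuming class indices per sample; the labels and fold assignment below are made up for illustration. Inside the question's loop, `plotconfusion(target(:,tstInd), output(:,tstInd))` would plot the held-out-only version for fold `i`:

```matlab
% Made-up true/predicted class indices for 10 samples (2 classes) and a
% hypothetical 5-fold assignment; in practice these come from the CV loop.
trueClass = [1 1 2 2 1 2 1 2 2 1];
predClass = [1 2 2 2 1 1 1 2 2 1];
cvFolds   = [1 1 2 2 3 3 4 4 5 5];
k = 5;
nClasses = 2;

% Rows = true class, columns = predicted class. (plotconfusion draws the
% transpose: output class on the rows, target class on the columns.)
pooled = zeros(nClasses);
for i = 1:k
    testIdx = (cvFolds == i);                   % held-out samples of fold i only
    % Count (true, predicted) pairs for this fold.
    C = accumarray([trueClass(testIdx)', predClass(testIdx)'], 1, [nClasses nClasses]);
    fprintf('Fold %d confusion (rows = true, cols = predicted):\n', i);
    disp(C);
    pooled = pooled + C;                        % each sample is added exactly once
end

pooledAccuracy = trace(pooled) / sum(pooled(:));
```

The pooled matrix summarizes all five folds in a single table, which is often easier to interpret than five separate plots.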
Please refer to the Multilayer Shallow Neural Networks and Backpropagation Training documentation for a better understanding of the code.
