I don't know how to interpret this ANN classification code.
I constructed the following code by adapting someone else's code, and it runs reasonably well. After running it, I got 5 confusion matrices, but I can't interpret this result or the code. Do the five confusion matrices represent the individual results of the K-fold cross-validation? I don't understand why 5 confusion matrices were shown. Above all, is my code written correctly? I use 6 inputs and 2 outputs for 32 samples.
rng(3); % initialize the RNG to the same state before training to obtain reproducibility
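% Convert the tables loaded from the .csv file to matrices.
% Each column must be one sample: input is 6x32 and target is 2x32
% (transpose with ' if the .csv stores one sample per row).
input = table2array(inputs);
target = table2array(targets);
% Number of neurons in the single hidden layer
hiddenLayerSize = 10;
net = patternnet(hiddenLayerSize);
% Choose the divide function (divide the samples by explicit indices)
net.divideFcn = 'divideind';
% Choose the training function (scaled conjugate gradient)
net.trainFcn = 'trainscg';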
% K-Fold cross validation
k = 5;
cvFolds = crossvalind('Kfold', size(target,2), k);
for i = 1:k
    net = configure(net, input, target);
    testIdx = (cvFolds == i);
    trainIdx = ~testIdx;
    trInd = find(trainIdx)
    tstInd = find(testIdx)
    net.trainParam.epochs = 100;
    net.divideParam.trainInd = trInd
    net.divideParam.testInd = tstInd
    % Performance function
    net.performFcn = 'mse'; % Mean squared error
    % Train the network
    [net, tr] = train(net,input,target);
    % Run the trained network on all samples (training and test)
    output = net(input);
    errors = gsubtract(target, output);
    performance = perform(net,target,output)
    % Recalculate training and test performance (each mask is NaN outside its subset)
    trainTargets = target .* tr.trainMask{1};
    testTargets = target .* tr.testMask{1};
    trainPerformance = perform(net,trainTargets,output)
    testPerformance = perform(net,testTargets,output)
    test(i) = testPerformance; % store this fold's test MSE
    % Save the network to net.mat (overwritten in each fold)
    save net
    % Plot the confusion matrix (computed on all samples, not only this fold's test set)
    figure, plotconfusion(target,output)
end
accuracy = mean(test); % mean test MSE over the k folds, not classification accuracy
% View the Network
view(net)
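To see what the k-fold loop above iterates over, here is a minimal sketch of the fold labels that crossvalind returns for 32 samples and k = 5. It assumes the Bioinformatics Toolbox (which provides crossvalind, as in the question) and reuses the question's rng(3) seed; the variable names nSamples and foldSizes are hypothetical. In iteration i, fold i is held out for testing and the remaining four folds train the network, so every sample is a test sample in exactly one iteration.

```matlab
% Sketch: how crossvalind splits 32 samples into k = 5 folds
% (assumes the Bioinformatics Toolbox; nSamples and foldSizes are illustrative names)
rng(3);                                        % same seed as the question's code
nSamples = 32;
k = 5;
cvFolds = crossvalind('Kfold', nSamples, k);   % one fold label (1..k) per sample

% Number of samples each fold holds out for testing
foldSizes = accumarray(cvFolds(:), 1)';
disp(foldSizes)

% In iteration i, fold i is the test set and the other k-1 folds are the training set
for i = 1:k
    fprintf('Fold %d: %d test samples, %d training samples\n', ...
        i, sum(cvFolds == i), sum(cvFolds ~= i));
end
```

The fold sizes add up to 32, so with k = 5 each test set holds only about six samples.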
Answers (1)
Anshika Chaurasia
5 August 2020
It is my understanding that you are seeing 5 confusion matrices in the output after running the code. The reason is the 5-fold cross-validation being performed: in your code, the plotconfusion function is called inside the for loop, which runs k = 5 times:
for i = 1:k
....
figure, plotconfusion(target,output)
end
This is why the output has 5 confusion matrices. Each confusion matrix corresponds to one of the five folds of the cross-validation.
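Note that in the question's code each of these matrices is computed from output = net(input), i.e. on all 32 samples, including the ones that fold trained on. Below is a minimal sketch, not the original author's code, that instead scores each fold only on its held-out samples and also pools the held-out predictions into one matrix. It assumes the Deep Learning Toolbox (patternnet, confusion, plotconfusion, ind2vec) and the Bioinformatics Toolbox (crossvalind). The random stand-in data and the names labels, foldAccuracy, allOut and meanAccuracy are hypothetical. A new network is created in every fold so that no fold starts from weights trained on its own test samples. The first output of confusion is the fraction of samples misclassified.

```matlab
% Sketch: k-fold cross-validation scored on each fold's held-out samples only.
% Hypothetical stand-in data with the question's shape (replace with your own):
% 6 features x 32 samples, 2 classes, one sample per column.
rng(3);
labels = [ones(1, 16), 2 * ones(1, 16)];
input = rand(6, 32);
input(:, labels == 2) = input(:, labels == 2) + 1;   % make the two classes separable
target = full(ind2vec(labels));                        % 2x32 one-hot targets

k = 5;
cvFolds = crossvalind('Kfold', size(target, 2), k);   % fold label 1..k per sample
foldAccuracy = zeros(1, k);        % fraction of held-out samples classified correctly
allOut = zeros(size(target));      % held-out network outputs, filled fold by fold

for i = 1:k
    tstInd = find(cvFolds == i);
    trInd  = find(cvFolds ~= i);

    % New network in every fold, so no fold starts from weights trained on its test samples
    net = patternnet(10, 'trainscg');
    net.divideFcn = 'divideind';
    net.divideParam.trainInd = trInd;
    net.divideParam.valInd   = [];
    net.divideParam.testInd  = tstInd;
    net.trainParam.epochs = 100;
    net.trainParam.showWindow = false;   % do not open the training window

    net = train(net, input, target);

    % Evaluate only the samples held out in this fold
    outTest = net(input(:, tstInd));
    tgtTest = target(:, tstInd);
    foldAccuracy(i) = 1 - confusion(tgtTest, outTest);   % confusion returns the misclassification rate
    allOut(:, tstInd) = outTest;

    figure, plotconfusion(tgtTest, outTest, sprintf('Fold %d (held-out)', i))
end

meanAccuracy = mean(foldAccuracy)

% One pooled confusion matrix: every sample appears exactly once, predicted by
% the network that did not train on it
figure, plotconfusion(target, allOut, 'All folds (held-out)')
```

With only about six held-out samples per fold (32/5), each per-fold accuracy can take only a handful of values, so the pooled matrix at the end is usually the easier one to read. To use your own data, delete the stand-in data lines and keep the input and target matrices from the question.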
Please refer to the Multilayer Shallow Neural Networks and Backpropagation Training documentation for a better understanding of the code.