Training a shallow neural network (no hidden layer) for MNIST classification. Getting low accuracy
Achint Kumar
25 November 2019
Answered: Srivardhan Gadila
13 January 2020
I am trying to implement classification of the MNIST dataset without any hidden layers, so the input is a 784x1 vector and the output is a 10x1 vector (after one-hot encoding).
The problem is that once I train the network, I get very low accuracy (~1%), and the reason is unclear to me. My guess is that my update rules are incorrect. After training, the output vector on the test data (a_test below) is no longer one-hot but has multiple 1's. I am not able to figure out where I am going wrong.
alp = 0.0001;    % learning rate
epoch = 50;      % number of iterations
dim_train = 784; % dimension of input data vector
for itr = 1:epoch
    z = W*image_train + repmat(b',n_train,1)'; % image_train is 784x60000, W is the 10x784 weight matrix, b is the bias
    a = 1./(1+exp(-z));
    a = a + 0.001; % to avoid zero inside log when calculating cross-entropy loss
    a_flg = sum(a);
    for i = 1:n_train
        a(:,i) = a(:,i)/a_flg(i); % normalizing output
    end
    L = -sum(Y_train.*log(a), 'all');               % calculating loss
    dLdW = 1/dim_train*(a-Y_train)*image_train';    % calculating dL/dW
    dLdb = 1/dim_train*(a-Y_train)*ones(n_train,1); % calculating dL/db
    W = W - alp*dLdW; % updating weights (gradient descent)
    b = b - alp*dLdb; % updating bias (gradient descent)
    loss(itr) = 1/n_train*L;
end
%% Testing
a_test = 1./(1+exp(-(W*image_test + repmat(b',n_test,1)')));
for i = 1:n_test
    a_test(~(a_test(:,i)==max(a_test(:,i))),i) = 0; % zero out everything except the largest activation
end
Accepted Answer
Srivardhan Gadila
13 January 2020
A network architecture defined without any hidden layer may not be able to learn to classify the digits belonging to 10 classes. It is recommended to use a few hidden layers.
You can also refer to Define Custom Deep Learning Layers & Create Simple Deep Learning Network for Classification.
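For instance, a minimal sketch along the lines of the linked Create Simple Deep Learning Network for Classification example (the hidden-layer width and training options below are illustrative assumptions, and it uses the digit dataset that ships with the toolbox rather than MNIST itself):
% Sketch using Deep Learning Toolbox; dataset path and options are illustrative.
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, 'IncludeSubfolders',true, 'LabelSource','foldernames');
layers = [
    imageInputLayer([28 28 1])
    fullyConnectedLayer(100)   % one hidden layer; 100 units is an assumption
    reluLayer
    fullyConnectedLayer(10)    % one output per digit class
    softmaxLayer
    classificationLayer];
options = trainingOptions('sgdm', 'MaxEpochs',10, 'Plots','training-progress');
net = trainNetwork(imds, layers, options);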
More Answers (0)