implementation of mini-batch stochastic gradient descent
I implemented mini-batch stochastic gradient descent but couldn't find the bug in my code.
I used this implementation for a classification problem, but all my final predictions are 0.
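% x1, x2, label (training set) and tx1, tx2, tlabel (test set) are
% assumed to be defined elsewhere as 1-by-N row vectors.
% Weights and biases initialized uniformly in [-1, 1]
W2 = -1+2*rand(5,2); W3 = -1+2*rand(5,5);
W4 = -1+2*rand(5,5); W5 = -1+2*rand(1,5);
b2 = -1+2*rand(5,1); b3 = -1+2*rand(5,1);
b4 = -1+2*rand(5,1); b5 = -1+2*rand(1,1);
eta = 5e-3;  % learning rate
iter = 1000; % number of epochs (full passes over the training data)
num_data = length(label);
loss_vec = zeros(1,iter);
tloss_vec = zeros(1,iter);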
for it = 1:iter
    % mini-batch method: shuffle, then split into columns of batch_size
    % indices (num_data must be a multiple of batch_size for reshape)
    batch_size = 50;
    rand_idx = randperm(num_data);
    rand_idx = reshape(rand_idx,[],num_data/batch_size);
    for idx = rand_idx   % each idx is one column, i.e. one mini-batch
        % forward pass
        a2 = activate([x1(:,idx);x2(:,idx)], W2, b2);
        a3 = activate(a2,W3,b3);
        a4 = activate(a3,W4,b4);
        a5 = activate(a4,W5,b5);
        % backward pass (gradient)
        delta5 = a5.*(1-a5).*(a5-label(idx));
        delta4 = a4.*(1-a4).*(W5'*delta5);
        delta3 = a3.*(1-a3).*(W4'*delta4);
        delta2 = a2.*(1-a2).*(W3'*delta3);
        % update weights and bias (gradients averaged over the mini-batch)
        W2 = W2 - 1/length(idx)*eta*delta2*[x1(:,idx);x2(:,idx)]';
        W3 = W3 - 1/length(idx)*eta*delta3*a2';
        W4 = W4 - 1/length(idx)*eta*delta4*a3';
        W5 = W5 - 1/length(idx)*eta*delta5*a4';
        b2 = b2 - 1/length(idx)*eta*sum(delta2,2);
        b3 = b3 - 1/length(idx)*eta*sum(delta3,2);
        b4 = b4 - 1/length(idx)*eta*sum(delta4,2);
        b5 = b5 - 1/length(idx)*eta*sum(delta5,2);
    end
    % compute train loss and test loss once per epoch, after all mini-batches
    loss_vec(it) = 1/(2*num_data)*LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,[x1;x2],label);
    tloss_vec(it) = 1/(2*numel(tlabel))*LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,[tx1;tx2],tlabel);
end
%% cost function
function loss = LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,x,y)
    a2 = activate(x, W2, b2);
    a3 = activate(a2, W3, b3);
    a4 = activate(a3, W4, b4);
    a5 = activate(a4, W5, b5);
    loss = norm(a5-y,2)^2;   % sum of squared errors
end
%% prediction
function pred = predict(W2,W3,W4,W5,b2,b3,b4,b5,x)
    a2 = activate(x, W2, b2);
    a3 = activate(a2, W3, b3);
    a4 = activate(a3, W4, b4);
    a5 = activate(a4, W5, b5);
    pred = round(a5);        % threshold the sigmoid output at 0.5
end
%% activation function (sigmoid of an affine map)
function y = activate(x,W,b)
    y = 1./(1+exp(-(W*x+b)));
end
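The line `rand_idx = reshape(rand_idx,[],num_data/batch_size)` only works when `num_data` is an exact multiple of `batch_size`; otherwise `reshape` raises an error. A minimal sketch of a partition that also keeps a final, smaller mini-batch is below. The example value `num_data = 130` and the cell array `batches` are illustrative assumptions, not part of the original code.

```matlab
% Hypothetical sketch: split a random permutation of 1:num_data into
% mini-batches of at most batch_size indices, keeping the remainder.
num_data   = 130;   % example value (assumption)
batch_size = 50;
rand_idx   = randperm(num_data);
num_batches = ceil(num_data / batch_size);
batches = cell(1, num_batches);
for k = 1:num_batches
    first = (k-1)*batch_size + 1;
    last  = min(k*batch_size, num_data);
    batches{k} = rand_idx(first:last);   % row vector of sample indices
end
% In the training loop one would then iterate over the cells:
%   for k = 1:num_batches
%       idx = batches{k};
%       ... forward pass, backward pass, update ...
%   end
disp(cellfun(@numel, batches))   % displays 50 50 30
```

Because `length(idx)` is already used to average the gradients, the smaller last batch is scaled correctly without further changes.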
Answers (2)
Mahesh Taparia
2 April 2021
Hi
You mentioned that you are implementing a classification network. In your code, you are using the square of the L2 norm to calculate the loss, and the loss derivative is also not correct in the backpropagation. Moreover, since it is a classification network, use a classification loss such as cross-entropy loss or focal cross-entropy loss instead of the norm. This may be the reason you are getting 0 every time.
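As a minimal sketch of the cross-entropy suggestion, assuming binary labels in {0, 1} and the single sigmoid output unit used in the question: with binary cross-entropy, the output-layer delta simplifies to `a5 - y`, so `delta5 = a5 - label(idx)` would take the place of the squared-error version, while the hidden-layer deltas and the weight updates keep their form. The snippet checks that gradient numerically on made-up values; `sigmoid` and `bce` are hypothetical helper names.

```matlab
% Binary cross-entropy averaged over a batch (assumption: y in {0,1}):
%   L = -mean( y.*log(a) + (1-y).*log(1-a) ),  a = sigmoid(z)
sigmoid = @(z) 1./(1+exp(-z));
bce     = @(a, y) -mean(y.*log(a) + (1-y).*log(1-a));

z5 = [-2.0 0.5 3.0];   % example output pre-activations (made-up)
y  = [ 0   1   1 ];    % example labels (made-up)
a5 = sigmoid(z5);
delta5 = a5 - y;       % analytic d(per-sample loss)/dz5

% Central finite-difference check of the same gradient
h = 1e-6;
num_grad = zeros(size(z5));
for k = 1:numel(z5)
    zp = z5; zp(k) = zp(k) + h;
    zm = z5; zm(k) = zm(k) - h;
    % scale by numel because bce averages over the batch
    num_grad(k) = numel(z5) * (bce(sigmoid(zp), y) - bce(sigmoid(zm), y)) / (2*h);
end
disp(max(abs(num_grad - delta5)))   % expected to be very small
```

The `a5.*(1-a5)` factor disappears with this loss, which avoids the small gradients a saturated sigmoid output produces under the squared-error loss.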
You can also use MATLAB's built-in functions to perform backpropagation. For this, you can refer to the link given below:
Hope this helps!
Mohamed Salem
22 December 2022
Write MATLAB code that implements the delta learning rule with mini-batches.
Compare (with a graph) your mini-batch algorithm with SGD and the batch algorithm in terms of mean squared error.
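One possible sketch of this exercise, assuming a single linear neuron trained with the delta (LMS) rule on synthetic data. The data, learning rate `eta`, epoch count, and batch sizes are made-up assumptions; batch size 1 corresponds to SGD and batch size `N` to the batch algorithm. The MSE over the whole data set is recorded after every epoch and plotted on a log scale only when a MATLAB desktop is available.

```matlab
% Delta rule for one linear neuron: SGD vs mini-batch vs batch (sketch)
N = 200; D = 2;
X = randn(D, N);                               % inputs, one column per sample
t = [1.5 -2.0]*X + 0.5 + 0.1*randn(1, N);      % noisy linear targets (made-up)
eta = 0.05; epochs = 30;
batch_sizes = [1 20 N];                        % SGD, mini-batch, batch
names = {'SGD', 'mini-batch (20)', 'batch'};
mse = zeros(numel(batch_sizes), epochs);

for m = 1:numel(batch_sizes)
    B = batch_sizes(m);
    w = zeros(1, D); b = 0;                    % same starting point for each method
    for ep = 1:epochs
        perm = randperm(N);
        for first = 1:B:N
            idx = perm(first:min(first+B-1, N));
            y   = w*X(:, idx) + b;             % linear output
            err = t(idx) - y;                  % delta-rule error term
            w = w + eta * err * X(:, idx)' / numel(idx);
            b = b + eta * mean(err);
        end
        mse(m, ep) = mean((t - (w*X + b)).^2); % MSE on all data after this epoch
    end
end

% Plot MSE per epoch (guarded so the script also runs without a display)
if usejava('desktop')
    figure; semilogy(1:epochs, mse', 'LineWidth', 1.5);
    xlabel('epoch'); ylabel('MSE'); legend(names{:}); grid on;
end
disp(mse(:, end)')   % final MSE for SGD, mini-batch, batch
```

With a fixed learning rate, the batch algorithm makes only one update per epoch, so in a setup like this it usually needs more epochs to reach the MSE that SGD or mini-batch training reaches; the exact curves depend on the data and on `eta`.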