Customized loss function taking X as input in CNN
Hello,
I am following this example ( https://www.mathworks.com/help/deeplearning/ug/train-network-using-custom-training-loop.html ) to define a custom training loop for a regression problem, because I want to pass the X data of each mini-batch to my custom loss function.
As in the linked example, I wrote a modelGradients() function, but inside it I call my own loss function myLoss() instead of crossentropy(). However, I received this error:
Error using dlfeval (line 43)
Value to differentiate must be a traced dlarray scalar.
when trying to use the dlgradient() function.
Below is my code:
function [gradients,average_loss] = modelGradients(dlnet,dlX,Y)
dlYPred = forward(dlnet,dlX);
loss = myLoss(dlX,dlYPred,Y);
gradients = dlgradient(loss,dlnet.Learnables);
end
function average_loss= myLoss(dlX,dlYPred,Y)
Ypred = extractdata(dlYPred);
X = extractdata(dlX);
sum_loss = 0;
% loop thru all the data in the mini batch and calculate the average
% loss
for i = 1:size(dlX,4)
% f_calculate_loss() is a self-defined loss function takes X,
% Ytarget and Ypred as input
sum_loss = sum_loss + f_calculate_loss(X(:,:,:,i),Y(i,:),Ypred(:,i));
end
% calculate the average loss and need to convert the type into dlarray
average_loss = sum_loss/size(dlX,4); % 1*1 dlarray
average_loss=dlarray(average_loss);
end
In my code, dlX is a 4-D 3(S)*72(S)*1(C)*64(B) single dlarray, dlYPred is a 3*64 dlarray, and Y is a 64*3 double array, where 64 is the miniBatchSize. The calculated loss (nonzero) is a 1*1 dlarray.
I've been stuck on this issue for weeks and still can't understand what's going on. I would really appreciate it if anyone could explain what is happening here and how I should fix it. Thank you so much in advance!
Answers (1)
Hrishikesh Borate
16 Jul 2021
Hi,
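The error comes from calling extractdata before the gradient is computed. extractdata returns plain numeric data, which breaks the derivative trace from the loss back to the network's learnable parameters. Wrapping the result in dlarray() afterwards creates a new, untraced dlarray, so dlgradient cannot differentiate it. Compute the loss directly on dlYPred using functions that support dlarray.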
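As a rough sketch of what that could look like with the shapes from the question (dlX is 3*72*1*64, dlYPred is 3*64, Y is 64*3): the body of f_calculate_loss() was not posted, so the loss below is only a hypothetical stand-in, a squared error per sample weighted by the mean absolute value of that sample's input. The weighting and the variable names w, err and perSample are assumptions for illustration, not the original loss.

```matlab
function [gradients,loss] = modelGradients(dlnet,dlX,Y)
    % Forward pass: dlYPred is a traced dlarray and must stay one.
    dlYPred = forward(dlnet,dlX);
    loss = myLoss(dlX,dlYPred,Y);
    gradients = dlgradient(loss,dlnet.Learnables);
end

function loss = myLoss(dlX,dlYPred,Y)
    % Hypothetical stand-in for f_calculate_loss (its body was not posted).
    % X does not depend on the learnables, so extracting it is harmless.
    X = extractdata(dlX);                       % 3x72x1x64 numeric
    w = reshape(mean(abs(X),[1 2 3]),1,[]);     % 1x64 per-sample weight (assumed)

    % Do NOT call extractdata on dlYPred: keep every step on the dlarray.
    err = stripdims(dlYPred) - Y.';             % 3x64, still traced
    perSample = sum(err.^2,1);                  % 1x64 squared error per sample

    % Average over the mini-batch; the result is a traced 1x1 dlarray.
    loss = sum(w .* perSample,2) / size(dlX,4);
end
```

modelGradients is still called through dlfeval, as in the linked example. If the real f_calculate_loss() has to be kept, the same idea applies: rewrite its body with dlarray-supported operations on the prediction, and vectorize over the batch dimension instead of looping where possible.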