Extremely high loss values with a custom training loop
Good day everyone,
I'm currently working on a custom training loop for cardiac segmentation. However, I'm encountering extremely high loss values during training when using the crossentropy function.
function [loss,Y] = modelLoss_test(net,X,T)
Y = forward(net,X);
loss = crossentropy(Y,T);
end
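For context on how a loss function like this is typically driven, here is a hedged sketch of a dlfeval-based custom training loop around it. The names `modelLossWithGradients`, `numEpochs`, and `mbq` (a minibatchqueue) are illustrative assumptions, not from the original post:

```matlab
% Hypothetical driver loop (numEpochs and mbq are assumptions):
avgG = []; avgSqG = []; iteration = 0;
for epoch = 1:numEpochs
    reset(mbq);
    while hasdata(mbq)
        iteration = iteration + 1;
        [X,T] = next(mbq);
        % dlfeval enables automatic differentiation inside the loss function
        [loss,gradients] = dlfeval(@modelLossWithGradients,net,X,T);
        % Adam update of the learnable parameters
        [net,avgG,avgSqG] = adamupdate(net,gradients,avgG,avgSqG,iteration);
    end
end

% Loss function extended to also return gradients, as dlfeval requires
function [loss,gradients] = modelLossWithGradients(net,X,T)
    Y = forward(net,X);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss,net.Learnables);
end
```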
I've checked that X (size: 256 × 208 × 1 × 2) and T (size: 256 × 208 × 2 × 2) are both 4-D dlarray objects, and both Y and T have values between 0 and 1. However, when I call crossentropy(Y,T) directly, the loss is extremely high (e.g., 4.123 × 10^5). When I compute the loss manually with the following code instead, I get a much more reasonable value (e.g., 15.356):
yy = gather(extractdata(Y(:)));
tt = gather(extractdata(T(:)));
loss = crossentropy(yy,tt);
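One thing worth checking before comparing the two numbers is whether Y and T actually carry dimension format labels, since crossentropy reduces and normalizes the loss based on those labels. A minimal check (assuming Y came from calling forward on a formatted dlarray):

```matlab
% dims returns the format string of a dlarray ('' if unformatted).
disp(dims(Y))   % e.g. 'SSCB' for spatial-spatial-channel-batch
disp(dims(T))   % if empty, crossentropy may need the DataFormat option
```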
For context, I'm using a U-Net with the Adam optimizer. I replaced the final convolution layer of the U-Net with a layer that has 2 output channels:
lgraph = replaceLayer(lgraph, 'Final-ConvolutionLayer', ...
    convolution2dLayer(3, 2, 'Padding','same', 'Name','Segmentation-Layer'));
I also tried incorporating class weights into the loss function, which produced only an insignificant reduction in the loss value:
weights = [0.95, 0.05];
loss = crossentropy(Y,T,weights,WeightsFormat="BC");
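As an aside, the documentation's examples pass one weight per class as a channel-formatted vector; if the intent is per-class weighting, a sketch like the following may be closer (the weight values themselves are placeholders, not a recommendation):

```matlab
% One weight per class; "C" marks the vector as spanning the channel dim.
classWeights = [0.95, 0.05];   % placeholder values; tune for class balance
loss = crossentropy(Y,T,classWeights,WeightsFormat="C");
```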
Could someone explain why there is such a large difference in loss values when using MATLAB's built-in crossentropy function versus my manual calculation? I would greatly appreciate any advice or solutions to this problem. Thank you in advance!
4 comments
Answers (1)
Jayanti
16 Oct 2024
Hi Hui,
Let's start by analysing the difference between the MATLAB built-in call and your custom loss calculation.
In image segmentation and classification tasks, true labels are generally provided in one-hot encoded form. The built-in function interprets "T" (the true labels) as one-hot encoded vectors when calculating the loss.
In your manual calculation, however, you extracted the data from the dlarray and stored it in the "tt" variable. Because the data has been extracted from the dlarray, it is no longer treated as one-hot encoded when passed to the crossentropy function. Hence, the two calls produce different values.
If you want to calculate the loss on the extracted values, you can use the code below. This gives the same loss value as the built-in crossentropy function.
yy = gather(extractdata(Y(:)));   % predicted probabilities as a plain vector
tt = gather(extractdata(T(:)));   % one-hot targets as a plain vector
loss_array = -sum(tt .* log(yy)); % cross-entropy summed over all elements
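One practical caveat with this manual formula: any zero in yy makes log(yy) return -Inf, so the sum can become Inf or NaN. A small clamp keeps it finite (the epsilon value here is an arbitrary choice):

```matlab
yy = max(yy, 1e-8);               % avoid log(0) = -Inf
loss_array = -sum(tt .* log(yy)); % now finite even for hard 0 predictions
```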
I tried running the code with the custom loss calculation above, and it gives the same result as the built-in crossentropy function.
0 comments