How to add a distance transform map to the loss function of a classification layer

Raza Ali on 27 Aug 2020
Commented: Raza Ali on 10 Sep 2020
Hi everyone, I am trying to include distance map information in the loss function. I am doing this in the classification layer of a CNN.
However, when I calculate the distance map using the "bwdist(Y)" command, MATLAB produces the following errors during training:
"Error using 'forwardLoss' in Layer ClassificationLayer. The function threw an error and could not be executed".
"Expected input image to be a 2-D real-valued, non-sparse gpuArray with underlying class uint8, uint16, uint32, int8, int16, int32, logical, single or double".
How can I add a distance transform map to the loss function, or how can I resolve this issue?
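One possible reading of the second error message is that bwdist is receiving the whole 4-D array, while it expects a single 2-D image. The following is a minimal sketch, not a confirmed fix. It assumes the targets are H-by-W-by-C-by-N and that channel 1 holds the foreground mask; the helper name sliceDistanceMap is hypothetical. It moves the data to the CPU with gather and calls bwdist on one 2-D slice at a time.

```matlab
function D = sliceDistanceMap(T)
% Hypothetical helper (not from the original post): returns one distance
% map per observation of an H-by-W-by-C-by-N target array.
T = gather(T);                      % returns CPU arrays unchanged
numObs = size(T, 4);
D = zeros(size(T, 1), size(T, 2), 1, numObs, 'single');
for n = 1:numObs
    % Assumption: channel 1 is the foreground mask.
    mask = T(:, :, 1, n) > 0.5;
    % bwdist now receives a 2-D logical image.
    D(:, :, 1, n) = bwdist(mask);
end
end
```

The result can be converted back with gpuArray(D) if the rest of the loss is evaluated on the GPU.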
3 Comments
Raza Ali on 10 Sep 2020
The classification layer code:
%%%%%%%%%%%%%%%%%%
classdef CEDLossLayer < nnet.layer.ClassificationLayer
    properties
        % Scalar weight that balances the target (T) and non-target (1-T)
        % terms of the loss.
        Beta = 0.7;
    end
    methods
        function layer = CEDLossLayer(name)
            % layer = CEDLossLayer(name) creates a CEDLossLayer and
            % optionally sets the layer name.
            if nargin == 1
                layer.Name = name;
            end
            % Set layer description
            layer.Description = 'cross entropy';
        end
        function loss = forwardLoss(layer, Y, T)
            % loss = forwardLoss(layer, Y, T) returns the cross entropy
            % loss between the predictions Y and the training targets T.
            N = size(Y, 4) * size(Y, 1) * size(Y, 2);
            Y = squeeze(Y);
            T = squeeze(T);
            addpath('D:\Ditsnave Transform')
            % Weight map computed from the first 2-D slice of the targets
            W = WeightMap(T(:,:,1));
            W = gather(W);
            Yb = nnet.internal.cnn.util.boundAwayFromZero(Y);
            loss_i = (layer.Beta .* W .* T .* log(Yb)) ...
                + ((1 - layer.Beta) .* (1 - W) .* (1 - T) .* log(1 - Yb));
            loss = -sum(sum(sum(sum(loss_i, 3) .* (1./N), 1), 2));
        end
        function dLdY = backwardLoss(layer, Y, T)
            % dLdY = backwardLoss(layer, Y, T) returns the derivatives of
            % the cross entropy loss with respect to the predictions Y.
            N = size(Y, 4) * size(Y, 1) * size(Y, 2);
            Y = squeeze(Y);
            T = squeeze(T);
            addpath('D:\Ditsnave Transform')
            W = WeightMap(T(:,:,1));
            W = gather(W);
            dLdY = (-(W .* T ./ nnet.internal.cnn.util.boundAwayFromZero(Y))) .* (1./N);
            % dLdY= -(1./N).*((((layer.Beta).*T)./nnet.internal.cnn.util.boundAwayFromZero(Y))-((1-layer.Beta).*(1-T))./(1-nnet.internal.cnn.util.boundAwayFromZero(Y)));
        end
    end
end
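For comparison, differentiating the forwardLoss expression above term by term with respect to Y gives a gradient that also contains Beta and the (1-T) term, which the active backwardLoss line does not include. The sketch below writes that derivative as a standalone function so it can be checked numerically. It assumes Y, T and W have compatible sizes and that Y lies strictly between 0 and 1; the name cedGradient and its argument list are hypothetical and not part of the posted layer.

```matlab
function dLdY = cedGradient(Y, T, W, beta, N)
% Hypothetical helper (not from the original post): element-wise
% derivative with respect to Y of
%   L = -(1/N) * sum( beta.*W.*T.*log(Y)
%                     + (1-beta).*(1-W).*(1-T).*log(1-Y) )
% Assumption: 0 < Y < 1, so no clamping is applied here.
dLdY = -(1 ./ N) .* ( beta .* W .* T ./ Y ...
    - (1 - beta) .* (1 - W) .* (1 - T) ./ (1 - Y) );
end
```

Inside the layer, Y would be replaced by the clamped predictions and beta by layer.Beta.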
Raza Ali on 10 Sep 2020
%% Weight Map Function
function weight = WeightMap(gt)
% class balance weights w_c(x): inverse pixel count of each label
uvals = unique(gt);
wmp = zeros(1, length(uvals));
for uv = 1:length(uvals)
    wmp(uv) = 1 / sum(gt(:) == uvals(uv));
end
% this normalization is important!
% after dividing by the maximum, the least frequent class has weight 1
wmp = wmp / max(wmp);
% wc=double(gt);
% wc=uint8(gt);
wc = zeros(size(gt));
for uv = 1:length(uvals)
    wc(gt == uvals(uv)) = wmp(uv);
end
% cell instances for distance computation (4-connected components)
cells = bwlabel(gt == 1, 4);
% cells distance map
bwgt = zeros(size(gt));
maps = zeros(size(gt,1), size(gt,2), max(cells(:)));
if max(cells(:)) >= 2
    for ci = 1:max(cells(:))
        maps(:,:,ci) = bwdist(cells == ci);
    end
    % d1 and d2 are the distances to the nearest and second nearest cell
    maps = sort(maps, 3);
    d1 = maps(:,:,1);
    d2 = maps(:,:,2);
    % border term 10*exp(-(d1+d2)^2/(2*25)), applied to background pixels only
    bwgt = 10 * exp(-((d1 + d2).^2) ./ (2*25)) .* (cells == 0);
end
% unet weights
weight = wc + bwgt;
end


Answers (0)
