U-net loses connections, becomes linear rather than U-shaped (unetLayers)

Allison, 13 June 2023 (commented: Allison, 15 June 2023)
I'm trying to follow the example Semantic Segmentation of Multispectral Images Using Deep Learning. My goal is to use the pretrained network for transfer learning, building my own semantic segmentation network to use on other terrain images. Unfortunately, my U-net becomes linear, and I cannot figure out why.
When training, the input image size is [256, 256, 6] and there are 18 classes. To keep things simple, I train my network with the data linked in the example and use the function unetLayers (rather than the helper function referenced in the example).
inputTileSize = [256,256,6];
lgraph = unetLayers(inputTileSize, 18, 'EncoderDepth', 4);
plot(lgraph)
I train the network using:
%the .mat/.png data files and the matRead6Channels read function come from the
%linked MATLAB example page above; classNames and pixelLabelIds are defined there too
imds = imageDatastore("train_data.mat",FileExtensions=".mat",ReadFcn=@matRead6Channels);
pxds = pixelLabelDatastore("train_labels.png",classNames,pixelLabelIds);
dsTrain = randomPatchExtractionDatastore(imds,pxds,[256,256],PatchesPerImage=1000);
initialLearningRate = 0.05;
maxEpochs = 5; %low b/c proof of concept, not meant for actual use
minibatchSize = 8;
l2reg = 0.0001;
options = trainingOptions("sgdm",...
InitialLearnRate=initialLearningRate, ...
Momentum=0.9,...
L2Regularization=l2reg,...
MaxEpochs=maxEpochs,...
MiniBatchSize=minibatchSize,...
LearnRateSchedule="piecewise",...
Shuffle="every-epoch",...
GradientThresholdMethod="l2norm",...
GradientThreshold=0.05, ...
Plots="training-progress", ...
VerboseFrequency=20);
net = trainNetwork(dsTrain,lgraph,options);
save("my_multispectralUnet_2.mat", "net");
After training, I load and plot the network. It is linear, rather than U-shaped.
data = load("C:\\Work\\CMFD\\my_multispectralUnet_2.mat");
net = data.net;
plot(layerGraph(net.Layers))
No errors occur while running the above code. What happened to the connections between the encoder and decoder sections?
2 Comments
mohd akmal masud, 14 June 2023
You lost the connections between the encoder and decoder networks.
You can edit them in the Deep Network Designer app.
Allison, 14 June 2023
How can I make sure that the connections between the encoder and decoder parts do not get lost? The connections are there before I train the network, but not after.


Accepted Answer

Richard, 15 June 2023
This is not caused by the training, or by saving and loading: the network is likely correct at that point. The loss of the connection data is caused by the form of your second call to the plot() function:
plot(layerGraph(net.Layers))
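This line first extracts just the layers as a linear list when it calls net.Layers, then constructs a new LayerGraph from that list. Given a plain layer array, layerGraph connects each layer only to the next one, so the new graph has none of the original skip connections from net. If you just call: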
plot(net)
then you will see the correct network.
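The effect can be reproduced without any training. The sketch below assumes Deep Learning Toolbox and Computer Vision Toolbox (which provides unetLayers) are installed; the variable names are illustrative. It builds the same U-Net as in the question and compares the size of its Connections table with that of a graph rebuilt from the layer array alone.

```matlab
% Build the same U-Net as in the question (no training needed)
lgraph = unetLayers([256 256 6], 18, EncoderDepth=4);

% Rebuilding from the layer array alone: layerGraph connects the layers
% one after another, so the encoder-to-decoder skip connections are dropped
rebuilt = layerGraph(lgraph.Layers);

fprintf("Original connections: %d\n", height(lgraph.Connections));
fprintf("Rebuilt connections:  %d\n", height(rebuilt.Connections));

% For a trained DAGNetwork, either of these keeps the connections:
%   plot(net)
%   lgraphFromNet = layerGraph(net);
```

The rebuilt graph should report fewer connections, the missing ones being the skip connections into the decoder depth-concatenation layers.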
1 Comment
Allison, 15 June 2023
Thanks! This helped me understand what was going on.
For anyone who finds this in the future: to get the network to keep its U-shape for transfer learning (my next step for this project), I had to include:
layersTransfer = net.Layers(1:end-3);
layers = [
layersTransfer
fullyConnectedLayer(numClasses,'WeightLearnRateFactor',20,'BiasLearnRateFactor',20, 'Name', 'fcl_a')
softmaxLayer('Name','sl_b')
classificationLayer('Name', 'cl_c')];
% Create the layer graph and create connections in the graph
lgraph = layerGraph(layers);
% Connect concatenation layers
lgraph = connectLayers(lgraph, 'Encoder-Stage-1-ReLU-2','Decoder-Stage-4-DepthConcatenation/in2');
lgraph = connectLayers(lgraph, 'Encoder-Stage-2-ReLU-2','Decoder-Stage-3-DepthConcatenation/in2');
lgraph = connectLayers(lgraph, 'Encoder-Stage-3-ReLU-2','Decoder-Stage-2-DepthConcatenation/in2');
lgraph = connectLayers(lgraph, 'Encoder-Stage-4-DropOut','Decoder-Stage-1-DepthConcatenation/in2');
analyzeNetwork(lgraph) %has desired shape now
%specify your options and whatnot
netTransfer = trainNetwork(dsTrain2,lgraph,options);
I've got other error messages now, but that's just coding for ya.
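A possible alternative that avoids reconnecting the skip connections by hand is to start from layerGraph(net), which keeps every connection, and swap only the output layers with replaceLayer. This is a sketch under several assumptions, not a tested recipe. An untrained unetLayers graph stands in for the pretrained network, numClasses = 5 is a hypothetical class count, and "Final-ConvolutionLayer" and "Segmentation-Layer" are assumed to be the default names unetLayers gives those layers (check them with {lgraph.Layers.Name}). The sketch also uses a 1x1 convolution with pixelClassificationLayer rather than fullyConnectedLayer with classificationLayer, because a semantic segmentation network predicts a class for every pixel; the fully connected head may account for some of the later errors.

```matlab
% In practice the starting point would be the trained network, e.g.
%   data = load("my_multispectralUnet_2.mat");
%   lgraph = layerGraph(data.net);   % keeps the skip connections
% Stand-in so this sketch runs on its own: an untrained U-Net of the same shape
lgraph = unetLayers([256 256 6], 18, EncoderDepth=4);

numClasses = 5;   % hypothetical number of classes in the new dataset

% New head: 1x1 convolution with one filter per class, plus a pixel-wise output layer
newConv = convolution2dLayer(1, numClasses, Name="Final-ConvolutionLayer-new", ...
    WeightLearnRateFactor=20, BiasLearnRateFactor=20);
newOutput = pixelClassificationLayer(Name="Segmentation-Layer-new");

% replaceLayer puts each new layer in the old layer's place, so every other
% connection (including the encoder-decoder skip connections) is kept.
% Layer names below are the assumed unetLayers defaults.
newGraph = replaceLayer(lgraph, "Final-ConvolutionLayer", newConv);
newGraph = replaceLayer(newGraph, "Segmentation-Layer", newOutput);

% analyzeNetwork(newGraph)   % opens Network Analyzer to inspect the result
```

Because only two layers are swapped one-for-one, the connection count should be unchanged. The graph could then be passed to trainNetwork with the new datastore and options.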


Category: Image Data Workflows
