Fine-tuning resnet18 (transfer learning) - layer connections error

Nour Mohamed, January 25, 2021
Commented: Mahesh Taparia, February 4, 2021
I want to fine-tune the pretrained resnet18 model, but when I follow the steps in this example: https://www.mathworks.com/help/deeplearning/ug/transfer-learning-using-alexnet.html
I get an error about the layer connections in the network, and I cannot retrain it. How can I preserve all of resnet18's layer connections, and not just its layers?
The error:
Error using trainNetwork (line 170)
Invalid network.
Error in fine_tune (line 67)
netTransfer = trainNetwork(imdsTrain,lgraph,options);
Caused by:
Layer 'res2a': Unconnected input. Each layer input must be connected to the output of another layer.
Detected unconnected inputs:
input 'in2'
Layer 'res2b': Unconnected input. Each layer input must be connected to the output of another layer.
Detected unconnected inputs:
input 'in2'
Layer 'res3a': Unconnected input. Each layer input must be connected to the output of another layer.
Detected unconnected inputs:
input 'in2'
Layer 'res3b': Unconnected input. Each layer input must be connected to the output of another layer.
Detected unconnected inputs:
input 'in2'
Layer 'res4a': Unconnected input. Each layer input must be connected to the output of another layer.
Detected unconnected inputs:
input 'in2'
Layer 'res4b': Unconnected input. Each layer input must be connected to the output of another layer.
Detected unconnected inputs:
input 'in2'
Layer 'res5a': Unconnected input. Each layer input must be connected to the output of another layer.
Detected unconnected inputs:
input 'in2'
Layer 'res5b': Unconnected input. Each layer input must be connected to the output of another layer.
Detected unconnected inputs:
input 'in2'

Answers (1)

Mahesh Taparia, January 28, 2021
Hi,
It seems that the network's input layer is not connected to the next layer, or that lgraph contains some unconnected layers. Use analyzeNetwork to check the network connections; it lists the unconnected layers. Use connectLayers to connect two layers.
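To see why the addition layers end up with an unconnected 'in2', here is a minimal sketch on a hypothetical toy network (all layer names and sizes are illustrative, not taken from resnet18). layerGraph connects a plain layer array one layer after another, so an additionLayer only receives its first input 'in1'; the second input, which in a ResNet carries the shortcut branch, has to be wired explicitly with connectLayers. Building a graph from resnet18's bare layer array (net.Layers(1:end-3), as in the AlexNet example) hits exactly this case for every residual block.

```matlab
% Toy network: layerGraph connects the array sequentially, so the addition
% layer 'add' gets only 'add/in1' (from 'relu1'); 'add/in2' stays unconnected.
layers = [
    imageInputLayer([28 28 1],'Name','input')
    convolution2dLayer(3,16,'Padding','same','Name','conv1')
    reluLayer('Name','relu1')
    additionLayer(2,'Name','add')
    fullyConnectedLayer(10,'Name','fc')
    softmaxLayer('Name','softmax')
    classificationLayer('Name','output')];
lgraph = layerGraph(layers);

% At this point analyzeNetwork(lgraph) would flag 'add' with unconnected input 'in2'.

% Add the skip connection: route the output of 'conv1' into the second input
% of 'add'. Both branches are 28x28x16, so the element-wise addition is valid.
lgraph = connectLayers(lgraph,'conv1','add/in2');

% analyzeNetwork(lgraph)   % interactive report; should now show no unconnected inputs
```

Recreating all of resnet18's shortcuts this way by hand is error-prone, which is why keeping the original layer graph, as the later reply suggests, is the simpler route.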
If that doesn't work, can you share your code? It will make the problem clearer.
Hope it will help!
2 Comments
Nour Mohamed, February 2, 2021
Thank you for your answer!
The unconnected layers are the addition layers. I tried to use connectLayers to add the missing connections by layer name, but the whole network got messy. Here is the code:
clear
net=resnet18;
layersTransfer = net.Layers(1:end-3);
layers = [
    layersTransfer
    fullyConnectedLayer(100,'WeightLearnRateFactor',20,'BiasLearnRateFactor',20)
    softmaxLayer
    classificationLayer];
layers(69, 1).Name = 'fc100';
layers(70, 1).Name = 'softmax';
layers(71, 1).Name = 'class100';
options = trainingOptions('sgdm', ...
    'MiniBatchSize',32, ...
    'MaxEpochs',10, ...
    'InitialLearnRate',0.01, ...
    'Shuffle','every-epoch', ...
    'ValidationFrequency',3, ...
    'Verbose',false, ...
    'Plots','training-progress');
digitDatasetPath = 'C:\Users\user\Desktop\cifar-100-matlab\cifar-100-matlab\CIFAR-100\TEST';
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
lgraph=layerGraph(layers);
imdsh = transform(imds,@(x) resize(x, [224 224]));
[imdsTrain,imdsValidation] = splitEachLabel(imdsh.UnderlyingDatastore,0.7,'randomize');
netTransfer = trainNetwork(imdsTrain,lgraph,options);
[YPred,scores] = classify(netTransfer,augimdsValidation);
YValidation = imdsValidation.Labels;
accuracy = mean(YPred == YValidation)
Mahesh Taparia, February 4, 2021
Hi,
It seems you want to remove the last few layers of the pretrained network (resnet18 returns a DAGNetwork, not a dlnetwork). Taking only the layer array of the network (net.Layers) discards the connections stored in its layer graph. In your case, you want to remove the last 3 layers and add some new ones. You can remove these layers with removeLayers, as shown below:
net=resnet18;
lgraph=layerGraph(net);
lgraph = removeLayers(lgraph,{'ClassificationLayer_predictions','prob','fc1000'});
analyzeNetwork(lgraph)
After this, add your new layers with addLayers and use connectLayers to connect them to the layer graph. You can also check the documentation of replaceLayer, which swaps a layer in place while keeping its connections. Hope it will help!
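Putting the steps of this reply together gives roughly the following sketch. It assumes Deep Learning Toolbox with the ResNet-18 support package, and the layer names used in the reply ('fc1000', 'prob', 'ClassificationLayer_predictions'). Instead of hard-coding the pooling layer that feeds 'fc1000' (expected to be 'pool5'), it looks that name up in lgraph.Connections. The new layer names and learn-rate factors are copied from the question's code and are otherwise arbitrary.

```matlab
% Sketch: give resnet18 a 100-class head while keeping every shortcut connection.
net = resnet18;
lgraph = layerGraph(net);   % unlike net.Layers, this keeps all connections

% Find the layer that currently feeds 'fc1000' (the global pooling layer)
isFcInput = strcmp(lgraph.Connections.Destination,'fc1000');
poolName = lgraph.Connections.Source{isFcInput};

% Remove the 1000-class head, as in the reply
lgraph = removeLayers(lgraph,{'ClassificationLayer_predictions','prob','fc1000'});

% New 100-class head (names are illustrative)
newHead = [
    fullyConnectedLayer(100,'Name','fc100', ...
        'WeightLearnRateFactor',20,'BiasLearnRateFactor',20)
    softmaxLayer('Name','softmax')
    classificationLayer('Name','class100')];

% addLayers wires the new array sequentially; connectLayers attaches it to the backbone
lgraph = addLayers(lgraph,newHead);
lgraph = connectLayers(lgraph,poolName,'fc100');

% analyzeNetwork(lgraph)   % interactive check: should report no unconnected inputs
```

The resulting lgraph (not a layer array) is what gets passed to trainNetwork. The same swap can be done with replaceLayer, once for 'fc1000' and once for 'ClassificationLayer_predictions', without separate removeLayers/connectLayers calls. Separately from the layer error: in the question's code, splitting imdsh.UnderlyingDatastore drops the transform, so images reach the network at their original size (32-by-32 for CIFAR-100) rather than resnet18's 224-by-224-by-3 input, and augimdsValidation is never defined. Wrapping each split in augmentedImageDatastore([224 224 3],imdsTrain) (and likewise for imdsValidation) resizes the images on the fly and can be passed to both trainNetwork and classify.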
