Retraining a neural network from its previous state using trainNetwork

Avi Sulimarski, 29 April 2021
Commented: Xie Shipley, 16 November 2023
Hi,
I am training a deep neural network using the following MATLAB function:
net = trainNetwork(XTrain,YTrain,layers,options);
Can I use trainNetwork to retrain the network (not from scratch), starting from the network state at the end of the previous training?
Here is part of my code:
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
layers = [
imageInputLayer([1 Nin 5],"Name","imageinput")
convolution2dLayer([1 3],32,"Name","conv1","Padding","same") % layer names must be unique
%batchNormalizationLayer('Name','batchDown')
tanhLayer("Name","tanh1")
convolution2dLayer([1 3],32,"Name","conv2","Padding","same","DilationFactor",[1 3])
%batchNormalizationLayer('Name','batchDown1')
tanhLayer("Name","tanh2")
convolution2dLayer([1 3],2,"Name","conv3","Padding","same","DilationFactor",[1 9])
regressionLayer];
options = trainingOptions('adam', ...
'InitialLearnRate',0.001, ...
'MaxEpochs',50000, ...
'ExecutionEnvironment','parallel',...
'Verbose',false, ...
'Plots','training-progress');
Net1 = trainNetwork(XTrain1,YTrain1,layers,options);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Now I would like to train again on new data to obtain Net2, starting the training from the state of Net1.
For example:
Net2 = trainNetwork(XTrain2,YTrain2,layers,options);
However, it is not clear to me how to start training Net2 from the state of Net1.

Accepted Answer

Yomna Genina, 5 May 2021
You can do the following if Net1 is a SeriesNetwork:
Net2 = trainNetwork(XTrain2, YTrain2, Net1.Layers, options);
or if Net1 is a DAGNetwork:
Net2 = trainNetwork(XTrain2, YTrain2, layerGraph(Net1), options);
This starts training from the learned weights of Net1 instead of from a fresh initialization, so Net1 acts as the initial network.
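The following is a minimal, self-contained sketch of the two calls above, assuming an image-to-image regression setup shaped like the one in the question. The random data, Nin = 16, N = 64, the smaller two-convolution network, "Normalization","none" and MaxEpochs = 5 are illustrative assumptions chosen only so the script runs quickly; they are not taken from the question.

```matlab
% Sketch: train Net1, then continue training on new data starting from Net1.
% Assumptions: synthetic random data stands in for XTrain1/YTrain1 and
% XTrain2/YTrain2; network size and epochs are arbitrary illustration values.
rng(0);
Nin = 16;  N = 64;
XTrain1 = rand(1, Nin, 5, N);   YTrain1 = rand(1, Nin, 2, N);
XTrain2 = rand(1, Nin, 5, N);   YTrain2 = rand(1, Nin, 2, N);

layers = [
    imageInputLayer([1 Nin 5], "Name", "imageinput", "Normalization", "none")
    convolution2dLayer([1 3], 32, "Name", "conv1", "Padding", "same")
    tanhLayer("Name", "tanh1")
    convolution2dLayer([1 3], 2, "Name", "conv2", "Padding", "same")
    regressionLayer("Name", "output")];
options = trainingOptions('adam', 'InitialLearnRate', 1e-3, ...
    'MaxEpochs', 5, 'Verbose', false);

% First training run, from randomly initialized weights.
Net1 = trainNetwork(XTrain1, YTrain1, layers, options);

% Reuse the trained layers (weights included) as the starting point.
if isa(Net1, 'SeriesNetwork')
    initLayers = Net1.Layers;        % layer array holding Net1's weights
else
    initLayers = layerGraph(Net1);   % DAGNetwork -> layer graph with weights
end

% Second training run on the new data, initialized from Net1.
Net2 = trainNetwork(XTrain2, YTrain2, initLayers, options);
```

The isa check simply picks between the two forms given in the answer, so the same script works whichever network type trainNetwork returned for Net1.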
If you also want to prevent the weights in certain layers from changing, you can use the freezeWeights supporting function (see below for how to open it). It sets the learning rates of those layers to zero, so during training trainNetwork does not update the parameters of the "frozen" layers.
edit(fullfile(matlabroot,'examples','nnet','main','freezeWeights.m'))
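As a hedged illustration of what freezing means here, the sketch below does the same thing by hand instead of calling freezeWeights: it sets every property whose name ends in LearnRateFactor to zero for the chosen layers, then retrains from Net1. The synthetic data, the small network and the choice of layer index 2 (the first convolution layer) are assumptions made only so the example is self-contained.

```matlab
% Sketch: continue training from Net1 while freezing the first convolution
% layer, by zeroing all of its *LearnRateFactor properties.
% Assumptions: synthetic random data; layer index 2 is the layer to freeze.
rng(0);
Nin = 16;  N = 64;
XTrain1 = rand(1, Nin, 5, N);   YTrain1 = rand(1, Nin, 2, N);
XTrain2 = rand(1, Nin, 5, N);   YTrain2 = rand(1, Nin, 2, N);
layers = [
    imageInputLayer([1 Nin 5], "Name", "imageinput", "Normalization", "none")
    convolution2dLayer([1 3], 32, "Name", "conv1", "Padding", "same")
    tanhLayer("Name", "tanh1")
    convolution2dLayer([1 3], 2, "Name", "conv2", "Padding", "same")
    regressionLayer("Name", "output")];
options = trainingOptions('adam', 'MaxEpochs', 5, 'Verbose', false);
Net1 = trainNetwork(XTrain1, YTrain1, layers, options);

layers2 = Net1.Layers;   % trained layers, weights included
idxToFreeze = 2;         % hypothetical choice: indices of layers to freeze
for i = idxToFreeze
    props = properties(layers2(i));
    for p = 1:numel(props)
        if endsWith(props{p}, 'LearnRateFactor')
            layers2(i).(props{p}) = 0;   % this parameter's learning rate becomes 0
        end
    end
end

% Frozen parameters keep Net1's values; the other layers keep learning.
Net2 = trainNetwork(XTrain2, YTrain2, layers2, options);
```

Matching on the LearnRateFactor suffix, rather than naming WeightLearnRateFactor and BiasLearnRateFactor explicitly, also covers layers such as batch normalization, whose learnable parameters use differently named factors.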
You might find the following documentation pages useful:
Hope this helps!
4 Comments
NASRIN AKTER, 4 June 2022
Hi Yomna,
Is it necessary to save checkpoint networks, as shown in the first link?
Also, can you suggest an example of incremental learning for image classification using deep learning?
Thanks
Xie Shipley, 16 November 2023
If I set 'BatchNormalizationStatistics' to "population" in `trainingOptions`, will the mean and variance of Net2 (which start out equal to those of Net1 when Net1's training stopped) change during training? @Yomna Genina


Category: Image Data Workflows

Release: R2020b
