I want to train a network and save the weights, then resume training on a new data set starting from the weights learned on the previous data set.

It seems train always randomizes the weights when called. I can't seem to maintain the weights from the first training set.

Accepted Answer

Sonam Gupta, 28 March 2018
You can continue training from the weights obtained on the previous data set by extracting the layers from the trained network's "Layers" property and passing them to "trainNetwork", as follows:
% Train a network on the first data set
net = trainNetwork(XTrain, YTrain, layers, options);
% Extract the layers, including their learned weights, from the trained network
newLayers = net.Layers;
% Retrain on the new data set (XTrain2, YTrain2), starting from where we left off
newNet = trainNetwork(XTrain2, YTrain2, newLayers, options);
'trainNetwork' will always use the weights that are stored in the layers which you pass in for training.
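Since the question also asks about saving the weights, here is a minimal self-contained sketch of the full round trip, assuming the Deep Learning Toolbox is installed. The random toy data, the tiny layer array and the file name firstNet.mat are hypothetical stand-ins for your own data and network; it also assumes the new data set has the same classes as the first one.

```matlab
% Hypothetical toy data standing in for the two data sets (same 3 classes)
XTrain  = rand(28, 28, 1, 200);  YTrain  = categorical(randi(3, 200, 1));
XTrain2 = rand(28, 28, 1, 200);  YTrain2 = categorical(randi(3, 200, 1));

% A deliberately tiny network for illustration only
layers = [
    imageInputLayer([28 28 1])
    fullyConnectedLayer(3)
    softmaxLayer
    classificationLayer];
options = trainingOptions('sgdm', 'MaxEpochs', 2, 'Verbose', false);

% First data set: train, then save the trained network to disk
net = trainNetwork(XTrain, YTrain, layers, options);
save('firstNet.mat', 'net');

% Later, possibly in a new MATLAB session: reload and keep training.
% s.net.Layers already holds the learned weights, so trainNetwork
% starts from them instead of initializing new ones.
s = load('firstNet.mat');
newNet = trainNetwork(XTrain2, YTrain2, s.net.Layers, options);
```

Saving the whole network object rather than individual weight arrays keeps the layer definitions and learned parameters together, so nothing has to be reassembled before the second call.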
Moh. Saadat, 29 August 2022
There is a small caveat to this: check whether your trained 'net' is a SeriesNetwork or a DAGNetwork. If it is a DAGNetwork, use layerGraph(net) instead of net.Layers, because net.Layers returns only the array of layers and drops the connections between them.
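A minimal sketch of that caveat, again assuming the Deep Learning Toolbox. The toy data, the layer names and the skip connection are hypothetical, chosen only so that trainNetwork returns a DAGNetwork; layerGraph(net) carries over both the learned weights and the connections, so the second call resumes from the first training.

```matlab
% Hypothetical toy data: random 8x8 grayscale images in 2 classes
X1 = rand(8, 8, 1, 100);  Y1 = categorical(randi(2, 100, 1));
X2 = rand(8, 8, 1, 100);  Y2 = categorical(randi(2, 100, 1));

% Small branched network: the addition layer makes it a DAG
layers = [
    imageInputLayer([8 8 1], 'Name', 'in')
    convolution2dLayer(3, 1, 'Padding', 'same', 'Name', 'conv1')
    reluLayer('Name', 'relu1')
    additionLayer(2, 'Name', 'add')
    fullyConnectedLayer(2, 'Name', 'fc')
    softmaxLayer('Name', 'sm')
    classificationLayer('Name', 'out')];
lgraph = layerGraph(layers);                    % relu1 -> add/in1 automatically
lgraph = connectLayers(lgraph, 'in', 'add/in2'); % skip connection
options = trainingOptions('sgdm', 'MaxEpochs', 2, 'Verbose', false);

net = trainNetwork(X1, Y1, lgraph, options);     % returns a DAGNetwork

% Pick the right container so connections are not lost
if isa(net, 'DAGNetwork')
    layersToTrain = layerGraph(net);   % keeps learned weights and connections
else
    layersToTrain = net.Layers;        % SeriesNetwork: a layer array suffices
end
newNet = trainNetwork(X2, Y2, layersToTrain, options);
```

Passing net.Layers here instead would hand trainNetwork a plain layer array with no skip connection, so the second training would not resume from the same architecture.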
