Poor performance of trainNetwork() function as compared to train()

24 views (last 30 days)
Ivan Rodionov on 8 Feb 2025, 20:36
Edited: Matt J on 9 Feb 2025, 0:22
Hello, I am having trouble training a neural network with trainNetwork() compared to train(), and I am stumped. I set up what should be an identical architecture twice: first with fitnet() for the train() function, and second with a layer array for trainNetwork(). The train() version converges rapidly to a good solution, while I cannot get the trainNetwork() version to converge decently. What am I doing wrong?
Code 1: Working with train()
% Select whether to train a new network or fine-tune an existing one
if useExisting == false
    fprintf("Option A: Train New Model\n");
    net = fitnet(hidden_layer_size); % Define a one-hidden-layer fitting network
else
    fprintf("Option B: Load and Fine-tune Existing Model\n");
    net = loadedData.net; % Load the existing neural network
end
% Set the common network training parameters
net.trainFcn = 'trainscg';   % Scaled conjugate gradient backpropagation
net.trainParam.epochs = 250E3;
net.trainParam.goal = 0;
net.trainParam.max_fail = 6; % Maximum validation failures before stopping
% Divide data for training, validation, and testing (80:10:10 split)
net.divideParam.trainRatio = 0.8;
net.divideParam.testRatio = 0.1;
net.divideParam.valRatio = 0.1;
% Train the neural network on the GPU (input and target were shuffled beforehand)
net = train(net, input, target, 'useGPU', 'yes');
Code 2: Not working
options = trainingOptions('rmsprop', ...
    'MaxEpochs', 250E3, ...                  % Very high cap; rely on validation patience to stop early
    'MiniBatchSize', 32, ...                 % Small mini-batches: more frequent (noisier) updates
    'InitialLearnRate', 1E-3, ...            % Modest LR since RMSprop adapts per parameter
    'SquaredGradientDecayFactor', 0.85, ...  % Default is 0.99; lowered for faster adaptation
    'Shuffle', 'every-epoch', ...            % Reshuffle since the data chunks overlap
    'ValidationData', {input(:, 1:round(0.2 * end)), target(:, 1:round(0.2 * end))}, ... % First 20% of columns (note: these also remain in the training set)
    'ValidationFrequency', 50, ...           % Check validation every 50 iterations
    'Verbose', true, ...
    'Plots', 'training-progress', ...
    'ValidationPatience', 12, ...            % More patience for slow convergence
    'ExecutionEnvironment', 'gpu');          % Train on the GPU
% Layer array for 1-D data, mirroring the fitnet() architecture
layers = [
    sequenceInputLayer(chunk_size, 'Name', 'input')
    fullyConnectedLayer(hidden_layer_size, 'Name', 'fc1')
    tanhLayer('Name', 'tanh1')
    fullyConnectedLayer(chunk_size, 'Name', 'output') % Adjust output size if needed
    regressionLayer('Name', 'regression')];
% Select whether to train a new network or fine-tune an existing one
if useExisting
    fprintf("Option B: Load and Fine-tune Existing Model\n");
    net = trainNetwork(input, target, net.Layers, options);
else
    fprintf("Option A: Train New Model\n");
    net = trainNetwork(input, target, layers, options);
end
I have tried adjusting the training parameters for trainNetwork(), and this is the best configuration I could find. Unfortunately, the performance is dismal compared to train().
3 Comments
Ivan Rodionov on 8 Feb 2025, 21:33
Edited: Ivan Rodionov on 8 Feb 2025, 21:33
@Matt J Hello Matt, and thank you for your reply. In some sense you are right; however, if I understand the code correctly, that should not matter, because both scripts are fundamentally doing the same thing: a single-hidden-layer, slightly wide network with the same tanh activation. Whether it is built via fitnet() or a layer array, it should not work well with one and fail completely with the other, unless something is broken or I am misunderstanding what is going on behind the scenes.
EDIT:
If you are curious, I can gladly provide the datasets; it is a distorted signal.
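For reference, here is a sketch of the layer stack I believe fitnet(hidden_layer_size) roughly corresponds to, assuming the columns of input are independent feature vectors rather than sequences (in which case featureInputLayer may be the closer analogue to fitnet's input). One caveat: fitnet() applies 'mapminmax' input/output normalization by default, which a plain layer array does not replicate.
% Rough layer-array analogue of fitnet(hidden_layer_size) for feature data.
% Caveat: fitnet() also normalizes inputs/outputs with 'mapminmax' by default;
% this plain stack does not.
equivalentLayers = [
    featureInputLayer(chunk_size, 'Name', 'input')
    fullyConnectedLayer(hidden_layer_size, 'Name', 'fc1')
    tanhLayer('Name', 'tanh1') % fitnet's default hidden transfer function is 'tansig'
    fullyConnectedLayer(chunk_size, 'Name', 'output')
    regressionLayer('Name', 'regression')];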
Matt J on 9 Feb 2025, 0:07
Edited: Matt J on 9 Feb 2025, 0:08
But the algorithm used by trainscg() is different, and it has fewer tuning parameters than RMSprop. We don't know how performance might improve if you changed InitialLearnRate, MiniBatchSize, and the other RMSprop parameters. You might also try Adam instead of RMSprop; I've heard it is more robust.
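For instance, a minimal sketch of swapping the solver to Adam while keeping your other settings (the two decay factors shown are MATLAB's documented Adam defaults, written out only so they are visible for tuning):
% Sketch: same setup with the Adam solver instead of RMSprop.
optionsAdam = trainingOptions('adam', ...
    'MaxEpochs', 250E3, ...
    'MiniBatchSize', 32, ...
    'InitialLearnRate', 1E-3, ...
    'GradientDecayFactor', 0.9, ...          % First-moment decay (default)
    'SquaredGradientDecayFactor', 0.999, ... % Second-moment decay (default)
    'Shuffle', 'every-epoch', ...
    'ValidationFrequency', 50, ...
    'ValidationPatience', 12, ...
    'Plots', 'training-progress', ...
    'ExecutionEnvironment', 'gpu');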

Sign in to comment.

Answers (1)

Matt J on 9 Feb 2025, 0:19
Edited: Matt J on 9 Feb 2025, 0:22
I have tried adjusting the training parameters for trainNetwork(), and this is the best configuration I could find
You can try using the Experiment Manager to explore the hyperparameter space more systematically.
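If you prefer a scripted alternative to the app, a rough sweep like the one below can screen settings first (the grids, the short MaxEpochs, and using the final training loss as the selection criterion are all just illustrative choices):
% Sketch: scripted grid sweep over learning rate and mini-batch size.
learnRates = [1e-4 1e-3 1e-2];
batchSizes = [32 128 512];
bestLoss = inf;
for lr = learnRates
    for bs = batchSizes
        opts = trainingOptions('adam', ...
            'MaxEpochs', 500, ...           % Short runs for screening only
            'MiniBatchSize', bs, ...
            'InitialLearnRate', lr, ...
            'Shuffle', 'every-epoch', ...
            'Verbose', false, ...
            'ExecutionEnvironment', 'gpu');
        [trialNet, info] = trainNetwork(input, target, layers, opts);
        finalLoss = info.TrainingLoss(end); % Last recorded mini-batch loss
        if finalLoss < bestLoss
            bestLoss = finalLoss;
            net = trialNet;
            fprintf('New best: LR = %g, batch = %d, loss = %g\n', lr, bs, finalLoss);
        end
    end
end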
