How to plot the loss function on the overall dataset in Training Progress

Maria Grazia, 25 May 2023
Edited: Aneela, 22 September 2024
I am writing a Convolutional Neural Network for regression in MATLAB R2021b. I'm using the trainNetwork function in Deep Learning Toolbox and in options I have 'Plots','training-progress'.
I understand that, for each iteration, the Training Progress plot shows the MSE computed on the current mini-batch.
I would like to know whether the Training Progress plot can instead show the MSE computed on the whole Training Set and on the whole Validation Set.
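To make the distinction precise (notation introduced here for illustration only: $f(\cdot;\theta_k)$ is the network with its weights after iteration $k$, $\mathcal{B}_k$ is the mini-batch used at that iteration, $\mathcal{D}$ is the whole training or validation set, and $R$ is the number of responses):

$$\mathrm{MSE}_{\mathcal{B}_k}(\theta_k) = \frac{1}{R\,|\mathcal{B}_k|}\sum_{n\in\mathcal{B}_k}\sum_{r=1}^{R}\bigl(y_{nr} - f_r(\mathbf{x}_n;\theta_k)\bigr)^2$$

$$\mathrm{MSE}_{\mathcal{D}}(\theta_k) = \frac{1}{R\,|\mathcal{D}|}\sum_{n\in\mathcal{D}}\sum_{r=1}^{R}\bigl(y_{nr} - f_r(\mathbf{x}_n;\theta_k)\bigr)^2$$

The second quantity needs a forward pass over every observation in $\mathcal{D}$ with the current weights, so it is considerably more expensive to evaluate than the first.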
Thanks in advance.

Answers (1)

Aneela, 13 September 2024
Edited: Aneela, 22 September 2024
Hi Maria,
For each iteration, the trainNetwork training progress plot shows metrics computed on the current mini-batch only (for a regression network, the RMSE and the loss, which for regressionLayer is half the mean squared error).
  • If you pass 'ValidationData' to trainingOptions, the validation curves are computed on the entire validation set every 'ValidationFrequency' iterations. However, there is no option to plot the loss over the entire training set.
  • To plot full-dataset metrics for both sets, you need to write a custom training loop.
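For the validation half of the question, the built-in options may already be enough. Below is a minimal sketch, assuming synthetic 8-by-8 images, a small stand-in network, and illustrative option values (none of these come from the original post):

```matlab
% Synthetic data (assumption): 8x8 grayscale images, one response per image
rng(0);
XTrain = rand(8, 8, 1, 200);
YTrain = squeeze(mean(XTrain, [1 2 3]));            % 200-by-1 responses
XValidation = rand(8, 8, 1, 50);
YValidation = squeeze(mean(XValidation, [1 2 3]));  % 50-by-1 responses

% Small stand-in regression CNN
layers = [
    imageInputLayer([8 8 1])
    convolution2dLayer(3, 8, 'Padding', 'same')
    reluLayer
    fullyConnectedLayer(1)
    regressionLayer];

options = trainingOptions('adam', ...
    'MaxEpochs', 5, ...
    'MiniBatchSize', 32, ...
    'ValidationData', {XValidation, YValidation}, ... % evaluated on the whole validation set
    'ValidationFrequency', 10, ...                    % every 10 iterations
    'Plots', 'training-progress', ...
    'Verbose', false);

net = trainNetwork(XTrain, YTrain, layers, options);
```

In the resulting plot the training curves are still per-mini-batch values; only the validation curves cover a whole dataset.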
Here’s a possible workaround:
  • Set hyperparameters such as the learning rate, number of epochs, and mini-batch size.
  • Loop over the epochs.
  • Within each epoch, loop over the mini-batches and compute the network's predictions and loss for each mini-batch.
  • Compute the gradients of the loss with respect to the learnable parameters (for example with dlfeval and dlgradient).
  • Update the parameters to minimize the loss using an optimizer such as adamupdate or sgdmupdate.
  • Compute the MSE over the whole training and validation sets after each mini-batch update (or once per epoch, which is much cheaper for large datasets). Here’s a sample code snippet:
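% net: dlnetwork being trained
% XTrain, XValidation: input data as formatted dlarray objects (e.g. 'SSCB' for images)
% YTrain, YValidation: responses arranged numResponses-by-numObservations
YPredTrain = extractdata(predict(net, XTrain));
trainMSE = mean((YPredTrain - YTrain).^2, 'all');
YPredValidation = extractdata(predict(net, XValidation));
validationMSE = mean((YPredValidation - YValidation).^2, 'all');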
  • Plot both MSE curves as training progresses using animatedline, addpoints, and drawnow:
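% Create the two lines once, before the training loop
trainingPlot = animatedline('Color','r');
validationPlot = animatedline('Color','b');
legend('Training MSE','Validation MSE')
% Inside the loop, after computing trainMSE and validationMSE for this iteration
addpoints(trainingPlot, iteration, trainMSE);
addpoints(validationPlot, iteration, validationMSE);
drawnow;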
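Putting the steps together, here is a minimal end-to-end sketch using dlnetwork, dlfeval, dlgradient, mse, and adamupdate. It assumes synthetic 8-by-8 images with one response each and a small stand-in network, uses plain index-based mini-batching instead of minibatchqueue, and defines hypothetical helpers modelLoss, modelGradients, and datasetMSE as anonymous functions so the script stays linear:

```matlab
% Synthetic data (assumption); responses stored 1-by-numObservations
rng(0);
XTrain = rand(8, 8, 1, 512, 'single');
YTrain = squeeze(mean(XTrain, [1 2 3]))';
XValidation = rand(8, 8, 1, 128, 'single');
YValidation = squeeze(mean(XValidation, [1 2 3]))';

% Stand-in network; no output layer is needed in a custom loop
layers = [
    imageInputLayer([8 8 1], 'Normalization', 'none')
    convolution2dLayer(3, 8, 'Padding', 'same')
    reluLayer
    fullyConnectedLayer(1)];
net = dlnetwork(layerGraph(layers));

% Hypothetical helpers: mini-batch half-MSE loss, its gradients, and plain
% (not halved) MSE over a whole dataset in a single forward pass
modelLoss = @(net, X, Y) mse(forward(net, X), Y);
modelGradients = @(net, X, Y) dlgradient(modelLoss(net, X, Y), net.Learnables);
datasetMSE = @(net, X, Y) mean((extractdata(predict(net, dlarray(X, 'SSCB'))) - Y).^2, 'all');

% Hyperparameters
numEpochs = 10;
miniBatchSize = 64;
learnRate = 1e-3;
numObs = size(XTrain, 4);
numIterPerEpoch = floor(numObs / miniBatchSize);
trailingAvg = [];
trailingAvgSq = [];

figure
trainingPlot = animatedline('Color', 'r');
validationPlot = animatedline('Color', 'b');
legend('Training MSE (whole set)', 'Validation MSE (whole set)')
xlabel('Iteration'); ylabel('MSE')

iteration = 0;
for epoch = 1:numEpochs
    idx = randperm(numObs);                      % shuffle every epoch
    for i = 1:numIterPerEpoch
        iteration = iteration + 1;
        batch = idx((i-1)*miniBatchSize + 1 : i*miniBatchSize);
        X = dlarray(XTrain(:, :, :, batch), 'SSCB');
        Y = YTrain(:, batch);

        % Gradients on the mini-batch, then an Adam parameter update
        gradients = dlfeval(modelGradients, net, X, Y);
        [net, trailingAvg, trailingAvgSq] = adamupdate(net, gradients, ...
            trailingAvg, trailingAvgSq, iteration, learnRate);

        % MSE over the whole training and validation sets with current weights
        trainMSE = datasetMSE(net, XTrain, YTrain);
        validationMSE = datasetMSE(net, XValidation, YValidation);
        addpoints(trainingPlot, iteration, trainMSE);
        addpoints(validationPlot, iteration, validationMSE);
        drawnow limitrate
    end
end
```

Evaluating both full sets every iteration is affordable at this toy size. For real data, moving the two datasetMSE calls to the end of each epoch, and splitting the predict call into chunks if a set does not fit in memory, keeps the overhead manageable.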
Refer to the following MathWorks documentation for more information on:

Categories
Image Data Workflows

Release
R2021b
