How to Force Overfitting of a Deep Learning Network for Classification

Wojciech Czop on 13 January 2020
Commented: Greg Heath on 18 January 2020
How can I force the neural network proposed in the documentation example https://www.mathworks.com/help/deeplearning/examples/create-simple-deep-learning-network-for-classification.html, trained on the MNIST dataset, to overfit?
1 Comment
Greg Heath on 18 January 2020
Spelling: overfitting has 2 "t"s
HTH
Greg


Answers (2)

Srivardhan Gadila on 17 January 2020
Since your question is specifically about overfitting the network in the example "Create Simple Deep Learning Network for Classification", I can suggest the following:
First option:
When splitting the dataset into training and validation sets, do not split it randomly. Instead, split it in order, as follows:
[imdsTrain,imdsValidation] = splitEachLabel(imds,numTrainFiles); % the input argument 'randomize' is removed
Then train the network until the training loss and accuracy saturate.
Second option:
Take only 10% of the original dataset provided in the example. Train on 75% of this reduced dataset and validate on the remaining 25%. Because the network is too large for the reduced dataset, it will overfit as training proceeds over the epochs.
The following code extracts 10% of the original dataset:
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
% Display the number of samples per label
imds.countEachLabel
numFiles = 100;
% Take a random 10% of the original dataset, with an equal number of samples per category
imds = splitEachLabel(imds,numFiles,'randomize');
% Display the number of samples per label after taking 10% of the original set
imds.countEachLabel
numTrainFiles = 75;
[imdsTrain,imdsValidation] = splitEachLabel(imds,numTrainFiles,'randomize');
You can also change the default trainingOptions values, such as 'Momentum' and 'L2Regularization'.
You can also refer to the documentation page "Improve Shallow Neural Network Generalization and Avoid Overfitting" and to questions about overfitting in the MATLAB Answers community.
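Putting these suggestions together, a complete run might look like the following sketch. It assumes the Deep Learning Toolbox is installed and that the layer array matches the one in the linked documentation example (check it against your release). The option values (no L2 regularization, 50 epochs, a mini-batch size of 75, validation every 10 iterations, unlimited validation patience) are illustrative choices, not values taken from the example.

```matlab
% Sketch: deliberately overfit the CNN from "Create Simple Deep Learning
% Network for Classification" using a 10% subset of the digit data.
% Assumptions: Deep Learning Toolbox is installed, and the layer array
% below matches the documentation example for your release.

% Load the digit images shipped with the toolbox
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');

% Keep 100 images per digit (10% of the data), then split 75/25
imds = splitEachLabel(imds,100,'randomize');
[imdsTrain,imdsValidation] = splitEachLabel(imds,75,'randomize');

% Same architecture as the documentation example (assumed)
layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

% Options chosen to encourage overfitting:
%  - 'L2Regularization' 0 removes weight decay (the default is 1e-4)
%  - 'MaxEpochs' well above the example's 4, so training runs long
%  - 'ValidationPatience' Inf, so validation loss never stops training early
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'Momentum',0.9, ...
    'L2Regularization',0, ...
    'MaxEpochs',50, ...
    'MiniBatchSize',75, ...
    'Shuffle','every-epoch', ...
    'ValidationData',imdsValidation, ...
    'ValidationFrequency',10, ...
    'ValidationPatience',Inf, ...
    'Verbose',false);

net = trainNetwork(imdsTrain,layers,options);

% Compare accuracy on the training and validation sets
trainAcc = mean(classify(net,imdsTrain) == imdsTrain.Labels);
valAcc   = mean(classify(net,imdsValidation) == imdsValidation.Labels);
fprintf('Training accuracy:   %.3f\n',trainAcc);
fprintf('Validation accuracy: %.3f\n',valAcc);
```

If the network overfits, the training accuracy should approach 1 while the validation accuracy stays noticeably lower. Adding 'Plots','training-progress' to the options shows the same gap as diverging training and validation curves.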

Greg Heath on 18 January 2020
OVERFITTING = More training unknowns (e.g., weights) than training vectors.
OVERTRAINING1 = Training an overfit network to or past convergence (DANGEROUS)
OVERTRAINING2 = Training any network past convergence (STUPID BUT NOT NECESSARILY DANGEROUS)
HOPE THIS HELPS
THANK YOU FOR FORMALLY ACCEPTING MY ANSWER
GREG
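Applied to the example network, the OVERFITTING definition above can be checked by counting learnable parameters (unknowns) and comparing them with the number of training vectors. This sketch assumes the layer sizes of the documentation example (3-by-3 'same' convolutions with 8, 16 and 32 filters, each followed by batch normalization, two 2-by-2 stride-2 max-pooling layers, and a fully connected layer with 10 outputs) and the example's 750 training images per digit; the non-learnable batch-normalization statistics are not counted.

```matlab
% Count learnable parameters (unknowns) of the documentation CNN.
% Assumed architecture: 28x28x1 input, three 3x3 'same' convolutions with
% 8, 16 and 32 filters, batch normalization after each, 2x2 stride-2 max
% pooling after the first two, then a fully connected layer with 10 outputs.
inChannels = [1 8 16];
numFilters = [8 16 32];

convParams = sum(3*3*inChannels.*numFilters + numFilters); % weights + biases
bnParams   = sum(2*numFilters);                            % scale + offset per channel
fcInputs   = 7*7*32;                                       % 28 -> 14 -> 7 after two poolings
fcParams   = fcInputs*10 + 10;                             % weights + biases

Nw = convParams + bnParams + fcParams;

% Training vectors: assumed 750 images per digit in the example's split,
% and 75 per digit in the 10% subset suggested in the answer above
NtrnFull  = 750*10;
NtrnSmall = 75*10;

fprintf('Learnable parameters: %d\n', Nw);
fprintf('Unknowns per training vector (original split): %.1f\n', Nw/NtrnFull);
fprintf('Unknowns per training vector (10%% subset):    %.1f\n', Nw/NtrnSmall);
```

Under these assumptions the network has 21,690 learnable parameters: about 2.9 times the 7,500 training images of the original split and about 29 times the 750 images of the 10% subset.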
