LSTM - Set special loss function

9 ビュー (過去 30 日間)
Oliver Köhn
Oliver Köhn 2018 年 4 月 26 日
編集済み: Stuart Whipp 2018 年 12 月 12 日
Based on this great MatLab-example I would like to build a neural network classifying each timestep of a timeseries (x_i,y_i) (i=1:N) as 1 or 2. For training purpose I created 500 different timeseries and the corresponding target-vectors. In reality, about 85 % of a timeseries is in state 1 and the rest (15 %) in state 2. The training-process lookes like this:
It stagnates at about 85%, because having the Mean Squared Error as a loss-function, a policy classifying every timestep as 1 results in a "good" accuracy of 85 %. So nearly every timestep is classified as 1. I am quite sure this can be avoided by using another loss function, but unfortunately I do not know how I can create an arbitrary loss-function.
I would like to adapt the loss function in a way, that if it falsely classifies a true state 2 as 1, then the loss is weighted by a factor f.
How can this be done?
This is how the training is set up:
featureDimension = 2;
numHiddenUnits = 100;
numClasses = 2;
layers = [ ...
sequenceInputLayer(featureDimension)
lstmLayer(numHiddenUnits,'OutputMode','sequence')
fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer];
options = trainingOptions('adam', ...
'GradientThreshold',1, ...
'InitialLearnRate',0.01, ...
'LearnRateSchedule','piecewise', ...
'LearnRateDropPeriod',20, ...
'Verbose',0, ...
'Plots','training-progress');
net = trainNetwork(XTrain,YTrain,layers,options);

回答 (2 件)

sii
sii 2018 年 5 月 18 日
If it is still relevant, check following documentation. As far as I know, it is only available in the R2018a release.

Stuart Whipp
Stuart Whipp 2018 年 12 月 10 日
I think what is needed is a weighted classification output so you can account for the imbalance in your classes. A custom layer tutorial exists for this, but it only works on image classification problems seemingly.
  1 件のコメント
Stuart Whipp
Stuart Whipp 2018 年 12 月 12 日
編集済み: Stuart Whipp 2018 年 12 月 12 日
Conor Daly (staff) kindly answered my question this morning with a custom output layer that I can confirm has worked for my Use Case. Please take a look at this link as I believe it also answers your question.
Regards Stuart

サインインしてコメントする。

カテゴリ

Help Center および File ExchangeSequence and Numeric Feature Data Workflows についてさらに検索

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!

Translated by