Trying to make an RNN that classifies 2D signal data into a 2D matrix of class labels. Error using trainNetwork (line 165): Invalid training data. Responses must be a vector of categorical responses, or a cell array of categorical response sequences.

Hello Everyone,
I'm trying to implement a neural network classification algorithm for signal data of shape (800,500), where 800 is the number of time steps and 500 is the number of observations. I want to train a network to identify, for every time step, whether it belongs to class 0, 1, or -1. So my responses should have the same shape as the XTrain data, (800,500).
After trying to use XTrain and YTrain as plain two-dimensional arrays, I understood that this is not possible, so I converted them to cell arrays as the documentation requires.
My XTrain data are a cell array of double sequences, and so are my YTrain data. Out of the 500 observations I used 70%, i.e. 350 observations, for training.
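The error message says the responses must be categorical, so double label sequences are rejected even when they are wrapped in a cell array. The following is a minimal sketch of turning 800-by-500 matrices into the N-by-1 cell layout with a 70/30 split. It assumes the predictors `X` and the labels `Y` are numeric matrices with one column per observation. Both are random stand-ins here, and the names `XCell`, `YCell` and `numTrain` are hypothetical.

```matlab
% Hypothetical stand-ins for the question's data: 800 time steps x 500 observations.
% Replace X and Y with your own matrices in the same 800-by-500 layout.
numTimeSteps = 800;
numObs = 500;
X = randn(numTimeSteps, numObs);          % predictors, one column per observation
Y = randi([-1 1], numTimeSteps, numObs);  % per-time-step labels in {-1, 0, 1}

% Each column becomes one 1-by-800 sequence (1 feature x 800 time steps).
XCell = num2cell(X.', 2);                 % 500-by-1 cell, each 1-by-800 double
% Fix the category set so every observation shares the same three classes,
% even if a class happens not to occur in a given sequence.
YCell = cellfun(@(y) categorical(y, [-1 0 1]), num2cell(Y.', 2), ...
    'UniformOutput', false);              % 500-by-1 cell, each 1-by-800 categorical

% Shuffled 70/30 split by observation (350 training, 150 test).
rng(0);
idx = randperm(numObs);
numTrain = round(0.7 * numObs);
XTrain = XCell(idx(1:numTrain));
YTrain = YCell(idx(1:numTrain));
XTest  = XCell(idx(numTrain+1:end));
YTest  = YCell(idx(numTrain+1:end));
```

Splitting by observation keeps each 800-step sequence intact. Splitting by time step would break the sequences apart.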
I followed the format suggested in the MATLAB trainNetwork documentation, but I'm still getting the same error.
The layers and options for my RNN so far are the following; please make suggestions :)
Please help

Answers (1)

Srivardhan Gadila, 29 March 2021
The trainNetwork documentation describes the input arguments sequences and responses for the syntax net = trainNetwork(sequences,responses,layers,options). For vector sequences, the input data should be an N-by-1 cell array of numeric arrays, where N is the number of observations. Each observation must be a c-by-s matrix, where c is the number of features of the sequences and s is the sequence length. The responses should be an N-by-1 cell array of categorical sequences of labels. Each observation is a 1-by-s sequence of categorical labels, where s is the sequence length of the corresponding predictor sequence.
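As a small illustration of that layout, here is a sketch with toy data (N = 3, c = 1, s = 5) plus a set of checks mirroring the requirements just listed. The variable names and the combined `isValid` check are hypothetical and are not part of the trainNetwork API.

```matlab
% Toy data in the documented layout: N = 3 observations, c = 1 feature, s = 5 time steps.
X = {rand(1,5); rand(1,5); rand(1,5)};        % N-by-1 cell of c-by-s double
Y = {categorical([0 1 0 -1 0], [-1 0 1]); ...
     categorical([0 0 0 0 0], [-1 0 1]); ...
     categorical([1 0 0 0 -1], [-1 0 1])};    % N-by-1 cell of 1-by-s categorical

% Checks that mirror the documented requirements.
N = numel(X);
c = size(X{1}, 1);
isValid = iscell(X) && iscell(Y) ...
    && isequal(size(X), [N 1]) && isequal(size(Y), [N 1]) ...         both N-by-1
    && all(cellfun(@(x) size(x, 1) == c, X)) ...                       same feature count c
    && all(cellfun(@iscategorical, Y)) ...                             responses are categorical
    && all(cellfun(@(x, y) isrow(y) && numel(y) == size(x, 2), X, Y)); % 1-by-s, same s as predictor
disp(isValid)
```

Responses stored as double sequences, as described in the question, would fail the `iscategorical` check. This matches the error reported by trainNetwork.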
The following code may help you:
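```matlab
%% Create network.
sequenceLength = 800;  % time steps per observation (named inputSize in the original post)
numFeatures = 1;       % one signal value per time step
numClasses = 3;        % labels -1, 0, 1
numHiddenUnits = 100;
layers = [ ...
    sequenceInputLayer(numFeatures,'Name','Sequence Input')
    lstmLayer(numHiddenUnits,'OutputMode','sequence','Name','LSTM Layer')  % one output per time step
    fullyConnectedLayer(numClasses,'Name','FC')
    softmaxLayer('Name','Softmax')
    classificationLayer('Name','Classification Layer')];
lgraph = layerGraph(layers);
analyzeNetwork(lgraph)

%% Create random training data.
numTrainSamples = 50;
% N-by-1 cell array; each cell is a 1-by-800 (features-by-time-steps) double sequence.
trainData = arrayfun(@(x) rand([numFeatures sequenceLength]), 1:numTrainSamples, 'UniformOutput', false)';
% N-by-1 cell array; each cell is a 1-by-800 categorical sequence.
% Listing the categories explicitly keeps all three classes in every sequence.
trainLabels = arrayfun(@(x) categorical(randi([-1 1], 1, sequenceLength), [-1 0 1]), 1:numTrainSamples, 'UniformOutput', false)';
size(trainData)    % 50 1
size(trainLabels)  % 50 1

%% Train the network.
options = trainingOptions('adam', ...
    'InitialLearnRate',0.005, ...
    'LearnRateSchedule','piecewise', ...
    'Verbose',1, ...
    'Plots','training-progress');
net = trainNetwork(trainData,trainLabels,lgraph,options);
```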
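The question asks for a 2D matrix of class labels as output. For a sequence-to-sequence network, classify returns a cell array of categorical sequences, one per observation. The sketch below stacks that cell array back into the question's time-steps-by-observations layout. It assumes 150 test sequences of 800 steps. `YPred` is filled with random stand-in labels so the snippet runs without a trained network, and `YPredMat` and `YPredNum` are hypothetical names.

```matlab
% Hypothetical stand-in so the snippet runs on its own; with a trained network use
%   YPred = classify(net, XTest);
numTest = 150;
numTimeSteps = 800;
YPred = arrayfun(@(k) categorical(randi([-1 1], 1, numTimeSteps), [-1 0 1]), ...
    (1:numTest)', 'UniformOutput', false);   % 150-by-1 cell of 1-by-800 categorical

% Stack the row sequences (150-by-800) and transpose to the question's
% time-steps-by-observations layout (800-by-150).
YPredMat = vertcat(YPred{:}).';

% Category names are '-1', '0', '1'; convert the names (not the category indices) to numbers.
YPredNum = str2double(string(YPredMat));
```

`double(YPredMat)` would return category indices 1 to 3 rather than the labels -1, 0 and 1. This is why the names are converted through `string` instead.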
1 comment
Orestis Marantos, 22 April 2021
Hello Srivardhan, thank you so much for your thorough explanations. Could you also recommend a way to overcome the high skewness of my problem? I generated more samples through simulations.
Specifically, I have 2000 samples, each with 800 time steps. In each sample there are some points of interest, where the signal goes up (on) or goes down (off). I would like to detect those points, so I thought that MIMO multi-label classification with an LSTM could solve my problem. I want to classify each time step of each sample as 0 if it is not a point of interest, or 1 if it is (I also changed the -1 labels to 1). Ultimately I'm only interested in the ones, so I'm not sure this is the best method to follow.
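The labeling described above can be sketched as follows. It assumes, hypothetically, that the underlying on/off state of a sample is available as a 0/1 trace `s`. Switch-on events are then +1 and switch-off events are -1, and both are merged into a single class 1 as in the comment. The names `s`, `y3`, `yBin` and `Y3` are illustrative only.

```matlab
% Hypothetical on/off state trace for one sample (1 = on).
s = [0 0 1 1 1 0 0 1];
% Three-class events: +1 where the signal switches on, -1 where it switches off.
y3 = [0 diff(s)];
% Merge both event types into one "point of interest" class, as in the comment.
yBin = categorical(abs(y3), [0 1]);

% For a whole numeric label matrix Y3 (time steps x samples, hypothetical),
% the same mapping per observation would be:
% YBin = cellfun(@(y) categorical(abs(y), [0 1]), num2cell(Y3.', 2), 'UniformOutput', false);
```

Here `y3` is `[0 0 1 0 0 -1 0 1]`, and `yBin` marks three points of interest.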
Because more than 98% of the time steps are zeros, the LSTM predicts only zeros after training. I have implemented a custom weighted binary cross-entropy loss that gives a weight of 0.97 to the ones and 0.03 to the zeros, but the algorithm still predicts only zeros for every sample.
Do you have any suggestions on how to handle this highly skewed problem? I read about over-sampling and under-sampling, but I don't think they will help here, because it is natural for my data to have very few active points (ones).
Thanks a lot in advance
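The weighted cross-entropy mentioned in the comment uses hand-picked weights. A common alternative is to derive them from label frequencies, w_k = N / (K · N_k), where N is the total number of labelled time steps, K the number of classes and N_k the count of class k. The sketch below assumes binary categorical labels stored like `trainLabels` in the answer, that is, an N-by-1 cell of 1-by-s rows. The data are random stand-ins with roughly 2% ones. The commented-out `classificationLayer` call with `'Classes'` and `'ClassWeights'` uses the Deep Learning Toolbox's built-in weighted loss in place of a custom layer; check that your release supports those options.

```matlab
% Hypothetical binary training labels: 10 sequences of 800 steps, about 2% ones.
rng(1);
YTrain = arrayfun(@(k) categorical(double(rand(1, 800) < 0.02), [0 1]), ...
    (1:10)', 'UniformOutput', false);

% Count each class over all time steps of all sequences.
allLabels = [YTrain{:}];                 % 1-by-8000 categorical
classes = categories(allLabels);         % {'0'; '1'}
counts = countcats(allLabels(:));        % 2-by-1 counts, one per class

% Inverse-frequency weights w_k = N / (K * N_k); a perfectly balanced set gives all ones.
classWeights = (numel(allLabels) ./ (numel(classes) * counts)).';

% Built-in weighted cross-entropy output layer (Deep Learning Toolbox):
% outputLayer = classificationLayer('Classes', classes, 'ClassWeights', classWeights);
```

For the 98%/2% split mentioned in the comment, this gives w_0 = 1/(2 × 0.98) ≈ 0.51 and w_1 = 1/(2 × 0.02) = 25, a ratio of about 49:1. The hand-picked 0.97/0.03 corresponds to about 32:1.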
