Classify Gender Using LSTM Networks

This example shows how to classify the gender of a speaker using deep learning. In particular, the example uses a bidirectional long short-term memory (BiLSTM) network and gammatone cepstral coefficients (GTCC), pitch, harmonic ratio, and several spectral shape descriptors.

Introduction

Gender classification based on speech signals is an essential component of many audio systems, such as automatic speech recognition, speaker recognition, and content-based multimedia indexing.

This example uses long short-term memory (LSTM) networks, a type of recurrent neural network (RNN) well-suited to study sequence and time-series data. An LSTM network can learn long-term dependencies between time steps of a sequence. An LSTM layer (lstmLayer) can look at the time sequence in the forward direction, while a bidirectional LSTM layer (bilstmLayer) can look at the time sequence in both forward and backward directions. This example uses a bidirectional LSTM layer.
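
For reference, both layer types take the number of hidden units as their first argument. A minimal comparison (the value 100 below is an arbitrary illustration, not the layer size used later in this example):

% Unidirectional LSTM layer: processes the sequence in the forward direction only.
layerForward = lstmLayer(100,"OutputMode","sequence");

% Bidirectional LSTM layer: processes the sequence in both forward and backward directions.
layerBidirectional = bilstmLayer(100,"OutputMode","sequence");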

This example trains the LSTM network with sequences of gammatone cepstral coefficients (GTCC), pitch estimates, harmonic ratio, and several spectral shape descriptors.

The example goes through the following steps:

  1. Create an audioDatastore that points to the audio speech files used to train the LSTM network.

  2. Remove silence and non-speech segments from the speech files using a simple thresholding technique.

  3. Extract feature sequences consisting of GTCC, pitch, harmonic ratio, and several spectral shape descriptors from the speech signals.

  4. Train the LSTM network using the feature sequences.

  5. Measure and visualize the accuracy of the classifier on the training data.

  6. Create an audioDatastore of speech files used to test the trained network.

  7. Remove non-speech segments from these files, generate feature sequences, pass them through the network, and test its accuracy by comparing the predicted and actual gender of the speakers.

To accelerate the training process, run this example on a machine with a GPU. If your machine has a GPU and Parallel Computing Toolbox™, then MATLAB® automatically uses the GPU for training; otherwise, it uses the CPU.

Examine the Dataset

This example uses the Mozilla Common Voice dataset [1]. The dataset contains 48 kHz recordings of subjects speaking short sentences. Download the dataset and untar the downloaded file. Set datafolder to the location of the data.

datafolder = PathToDatabase;

Use audioDatastore to create a datastore for all files in the dataset.

ads0 = audioDatastore(fullfile(datafolder,"clips"));

Since only a fraction of the dataset files are annotated with gender information, you will use both the training and dev sets to train the network. You will use the test set to validate the network. Use readtable to read the metadata associated with the audio files from the training and dev sets. The metadata is contained in the train.tsv and dev.tsv files. Inspect the first few rows of the metadata.

metadataTrain = readtable(fullfile(datafolder,"train.tsv"),"FileType","text");
metadataDev = readtable(fullfile(datafolder,"dev.tsv"),"FileType","text");
metadata = [metadataTrain;metadataDev];
head(metadata)
ans =

  8×8 table

                                                                 client_id                                                                                                                                  path                                                                                                               sentence                                               up_votes    down_votes        age           gender        accent  
    ____________________________________________________________________________________________________________________________________    ____________________________________________________________________________________________________________________________________    ______________________________________________________________________________________________    ________    __________    ____________    __________    __________

    {'55451a804635a88160a09b9b8122e3dddba46c2e6df2d6d9ec9d3445c38180fd18516d76acc9035978f27ee1f798f480dcb55dcbd31a142374c3af566c9be3c4'}    {'f480b8a93bf84b7f74c141284a71c39ff47d264a75dc905dc918286fb67f0333595206ff953a27b8049c7ec09ea895aa66d1cd4f7547535167d3d7901d12feab'}    {'Unfortunately, nobody can warrant the sanctions that will have an effect on the community.'}       3            0         {'twenties'}    {'female'}    {'canada'}
    {'55451a804635a88160a09b9b8122e3dddba46c2e6df2d6d9ec9d3445c38180fd18516d76acc9035978f27ee1f798f480dcb55dcbd31a142374c3af566c9be3c4'}    {'7647873ce81cd81c90b9e0fe3cb6c85cc03df7c0c4fdf2a04c356d75063af4b9de296a24e3bef0ba7ef0b0105d166abf35597e9c9a4b3857fd09c57b79f65a99'}    {'Came down and picked it out himself.'                                                      }       2            0         {'twenties'}    {'female'}    {'canada'}
    {'55451a804635a88160a09b9b8122e3dddba46c2e6df2d6d9ec9d3445c38180fd18516d76acc9035978f27ee1f798f480dcb55dcbd31a142374c3af566c9be3c4'}    {'81a3dd920de6251cc878a940aff258e859ef13efb9a6446610ab907e08832fafdc463eda334ee74a24cc02e3652a09f5573c133e6f46886cb0ba463efc7a6b43'}    {'She crossed the finish line just in time.'                                                 }       2            0         {'twenties'}    {'female'}    {'canada'}
    {'5b8c0f566c1201a94e684a334cf8a2cbced8a009a5a346fc24f1d51446c6b8610fc7bd78f69e559b29d138ab92652a45408ef87c3ec0e426d2fc5f1b2b44935b'}    {'5e6fc96a7bc91ec2261a51e7713bb0ed8a9f4fa9e20a38060dc4544fb0c2600c192d6e849915acaf8ea0766a9e1d481557d674363e780dbb064586352e560f2c'}    {'Please find me the Home at Last trailer.'                                                  }       2            0         {0×0 char  }    {0×0 char}    {0×0 char}
    {'5b8c0f566c1201a94e684a334cf8a2cbced8a009a5a346fc24f1d51446c6b8610fc7bd78f69e559b29d138ab92652a45408ef87c3ec0e426d2fc5f1b2b44935b'}    {'3a0929094a9aac80b961d479a3ee54311cc0d60d51fe8f97071edc2999e7747444b261d2c0c2345f86fb8161f3b73a14dc19da911d19ca8d9db39574c6199a34'}    {'Play something by Louisiana Blues'                                                         }       2            0         {0×0 char  }    {0×0 char}    {0×0 char}
    {'5b8c0f566c1201a94e684a334cf8a2cbced8a009a5a346fc24f1d51446c6b8610fc7bd78f69e559b29d138ab92652a45408ef87c3ec0e426d2fc5f1b2b44935b'}    {'82b8edf3f1420295069b5bb4543b8f349faaca28f45a3279b0cd64c39d16afb590a4cc70ed805020161f8c1f94bc63d3b69756fbc5a0462ce12d2e17c4ebaeeb'}    {'When is The Devil with Hitler playing in Bow Tie Cinemas'                                  }       2            0         {0×0 char  }    {0×0 char}    {0×0 char}
    {'60013a707ac8cdd2b44427418064915f7810b2d58d52d8f81ad3c6406b8922d61c134259747f3c73d2e64c885fc6141761d29f7e6ada7d6007c48577123e4af0'}    {'8606bac841a08bcbf5ddb83c768103c467ffd1bf38b16052414210dc3ce3267561cb0368d227b6eb420dc147387cc1807032102b6248a13a40f83e5ac06d7122'}    {'Give me the list of animated movies playing at the closest movie house'                    }       3            0         {0×0 char  }    {0×0 char}    {0×0 char}
    {'60013a707ac8cdd2b44427418064915f7810b2d58d52d8f81ad3c6406b8922d61c134259747f3c73d2e64c885fc6141761d29f7e6ada7d6007c48577123e4af0'}    {'0f7e63d320cfbf6ea5d1cda674007131d804a06f5866d38f81da7def33a4ce8ee4f2cb7e47b45eee97903cad3160b3a10f715862227e8ecdc3fb3bafc6b4279d'}    {'The better part of valor is discretion'                                                    }       3            0         {0×0 char  }    {0×0 char}    {0×0 char}

Create Training Dataset

Find the files in the datastore corresponding to the training set.

csvFiles = metadata.path;
adsFiles = ads0.Files;
adsFiles = cellfun(@HelperGetFilePart,adsFiles,'UniformOutput',false);
[~,indA,indB] = intersect(adsFiles,csvFiles);
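
HelperGetFilePart is a supporting function that is not listed in this example. A minimal sketch of one possible implementation, assuming the path column of the metadata stores only the file name (with extension) of each clip:

function filePart = HelperGetFilePart(fullFilePath)
% Return the file name (with extension) from a full path so that it can be
% matched against the path column of the metadata table.
[~,name,ext] = fileparts(fullFilePath);
filePart = [name,ext];
end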

Create a subset training set from the large dataset.

adsTrain = subset(ads0,indA);

You will use data corresponding to adult speakers only. Read the gender and age variables from the metadata.

gender = metadata.gender;
gender = gender(indB);
age = metadata.age;
age = age(indB);

Assign gender to the Labels property of the datastore.

adsTrain.Labels = gender;

Not all files in the dataset are annotated with gender and age information. Create a subset of the datastore that only contains files where gender information is available and age is greater than 19.

maleOrfemale = categorical(adsTrain.Labels) == "male" | categorical(adsTrain.Labels) == "female";
isAdult = categorical(age) ~= "" & categorical(age) ~= "teens";
adsTrain = subset(adsTrain,maleOrfemale & isAdult);

You will train the deep learning network on a subset of the files. Create a datastore subset containing an equal number of male and female speakers.

ismale = find(categorical(adsTrain.Labels) == "male");
isfemale = find(categorical(adsTrain.Labels) == "female");
numFilesPerGender = numel(isfemale);
adsTrain = subset(adsTrain,[ismale(1:numFilesPerGender);isfemale(1:numFilesPerGender)]);

Use shuffle to randomize the order of the files in the datastore.

adsTrain = shuffle(adsTrain);

Use countEachLabel to inspect the gender breakdown of the training set.

countEachLabel(adsTrain)
ans =

  2×2 table

    Label     Count
    ______    _____

    female     925 
    male       925 

Isolate Speech Segments

Read the contents of an audio file using read.

[audio,adsInfo] = read(adsTrain);
Fs = adsInfo.SampleRate;

Plot the audio signal and then listen to it using the sound command.

timeVector = (1/Fs) * (0:numel(audio)-1);
figure
plot(timeVector,audio)
ylabel("Amplitude")
xlabel("Time (s)")
title("Sample Audio")
grid on

sound(audio,Fs)

The speech signal has silence segments that do not contain useful information pertaining to the gender of the speaker. This example removes silence using a simplified version of the thresholding approach described in [2]. The steps of the silence removal algorithm are outlined below.

First, compute two features over non-overlapping frames of the audio data: The signal energy and the spectral centroid. The spectral centroid is a measure of the "center of gravity" of a signal spectrum.

Break the audio into 50-millisecond non-overlapping frames.

audio = audio ./ max(abs(audio)); % Normalize amplitude
windowLength = 50e-3 * Fs;
segments = buffer(audio,windowLength);

Compute the energy and spectral centroid for each frame.

win = hann(windowLength,'periodic');
signalEnergy = sum(segments.^2,1)/windowLength;
centroid = spectralCentroid(segments,Fs,'Window',win,'OverlapLength',0);

Next, set thresholds for each feature. Frames where the energy falls below its threshold or the spectral centroid rises above its threshold are disregarded. In this example, the energy threshold is set to half the mean energy and the spectral centroid threshold is set to 5000 Hz.

T_E = mean(signalEnergy)/2;
T_C = 5000;
isSpeechRegion = (signalEnergy>=T_E) & (centroid<=T_C);

Visualize the computed energy and spectral centroid over time.

% Hold the signal energy, spectral centroid, and speech decision values for
% plotting purposes.
CC = repmat(centroid,windowLength,1);
CC = CC(:);
EE = repmat(signalEnergy,windowLength,1);
EE = EE(:);
flags2 = repmat(isSpeechRegion,windowLength,1);
flags2 = flags2(:);

figure

subplot(3,1,1)
plot(timeVector, CC(1:numel(audio)), ...
     timeVector, repmat(T_C,1,numel(timeVector)), "LineWidth",2)
xlabel("Time (s)")
ylabel("Normalized Centroid")
legend("Centroid","Threshold")
title("Spectral Centroid")
grid on

subplot(3,1,2)
plot(timeVector, EE(1:numel(audio)), ...
     timeVector, repmat(T_E,1,numel(timeVector)),"LineWidth",2)
ylabel("Normalized Energy")
legend("Energy","Threshold")
title("Window Energy")
grid on

subplot(3,1,3)
plot(timeVector, audio, ...
     timeVector,flags2(1:numel(audio)),"LineWidth",2)
ylabel("Audio")
legend("Audio","Speech Region")
title("Audio")
grid on
ylim([-1 1.1])

Extract the segments of speech from the audio. Assume speech is present for samples where energy is above its threshold and the spectral centroid is below its threshold.

% Get indices of frames where a speech-to-silence or silence-to-speech
% transition occurs.
regionStartPos = find(diff([isSpeechRegion(1)-1, isSpeechRegion]));

% Get the length of the all-silence or all-speech regions.
RegionLengths = diff([regionStartPos, numel(isSpeechRegion)+1]);

% Get speech-only regions.
isSpeechRegion = isSpeechRegion(regionStartPos) == 1;
regionStartPos = regionStartPos(isSpeechRegion);
RegionLengths = RegionLengths(isSpeechRegion);

% Get start and end indices for each speech region. Extend the region by 5
% windows on each side.
startIndices = zeros(1,numel(RegionLengths));
endIndices = zeros(1,numel(RegionLengths));
for index = 1:numel(RegionLengths)
   startIndices(index) = max(1,(regionStartPos(index) - 5) * windowLength + 1);
   endIndices(index) = min(numel(audio),(regionStartPos(index) + RegionLengths(index) + 5) * windowLength);
end

Finally, merge intersecting speech segments.

activeSegment = 1;
isSegmentsActive = zeros(1,numel(startIndices));
isSegmentsActive(1) = 1;
for index = 2:numel(startIndices)
    if startIndices(index) <= endIndices(activeSegment)
        % Current segment intersects with previous segment
        if endIndices(index) > endIndices(activeSegment)
           endIndices(activeSegment) =  endIndices(index);
        end
    else
        % New speech segment detected
        activeSegment = index;
        isSegmentsActive(index) = 1;
    end
end
numSegments = sum(isSegmentsActive);
segments = cell(1,numSegments);
limits = zeros(2,numSegments);
speechSegmentsIndices  = find(isSegmentsActive);
for index = 1:length(speechSegmentsIndices)
    segments{index} = audio(startIndices(speechSegmentsIndices(index)): ...
                            endIndices(speechSegmentsIndices(index)));
    limits(:,index) = [startIndices(speechSegmentsIndices(index)); ...
                       endIndices(speechSegmentsIndices(index))];
end

Plot the original audio along with the detected speech segments.

figure

plot(timeVector,audio)
hold on
myLegend = cell(1,numel(segments) + 1);
myLegend{1} = "Original Audio";
for index = 1:numel(segments)
    plot(timeVector(limits(1,index):limits(2,index)),segments{index});
    myLegend{index+1} = sprintf("Output Audio Segment %d",index);
end
xlabel("Time (s)")
ylabel("Audio")
grid on
legend(myLegend)

Audio Features

A speech signal is dynamic in nature and changes over time. Speech signals are assumed to be stationary on short time scales, so they are typically processed in windows of 20 to 40 ms. For each speech segment, this example extracts audio features for 30 ms windows with 75% overlap.

win = hamming(0.03*Fs,"periodic");
overlapLength = round(0.75*numel(win));
extractor = audioFeatureExtractor('Window',win, ...
    'OverlapLength',overlapLength, ...
    'SampleRate',Fs, ...
    'SpectralDescriptorInput','melSpectrum', ...
    ...
    'gtcc',true, ...
    'gtccDelta',true, ...
    'gtccDeltaDelta',true, ...
    'spectralSlope',true, ...
    'spectralFlux',true, ...
    'spectralCentroid',true, ...
    'spectralEntropy',true, ...
    'pitch',true, ...
    'harmonicRatio',true);


Extract Features Using Tall Arrays

To speed up processing, extract feature sequences from the speech segments of all audio files in the datastore using tall arrays. Unlike in-memory arrays, tall arrays typically remain unevaluated until you request that the calculations be performed using the gather function. This deferred evaluation allows you to work quickly with large data sets. When you eventually request the output using gather, MATLAB® combines the queued calculations where possible and takes the minimum number of passes through the data. If you have Parallel Computing Toolbox™, you can use tall arrays in your local MATLAB® session, or on a local parallel pool. You can also run tall array calculations on a cluster if you have MATLAB® Parallel Server™ installed.

First, convert the datastore to a tall array:

T = tall(adsTrain)
T =

  M×1 tall cell array

    {280560×1 double}
    {156144×1 double}
    {167664×1 double}
    {190704×1 double}
    {401520×1 double}
    {120432×1 double}
    {354288×1 double}
    {318576×1 double}
        :        :
        :        :

The display indicates that the number of rows (corresponding to the number of files in the datastore), M, is not yet known. M is a placeholder until the calculation completes.

Extract the speech segments from the tall array. This action creates a new tall array variable to use in subsequent calculations. The function HelperSegmentSpeech performs the steps already highlighted in the Isolate Speech Segments section. The cellfun command applies HelperSegmentSpeech to the contents of each audio file in the datastore. Also determine the number of segments per file.

segmentsTall = cellfun(@(x)HelperSegmentSpeech(x,Fs),T,"UniformOutput",false);
segmentsPerFileTall = cellfun(@numel,segmentsTall);

Extract feature sequences from the speech segments using HelperGetFeatureVectors. The helper function applies feature extraction to the segments using audioFeatureExtractor and then transposes the features so that time is along columns (features along rows), as required by sequenceInputLayer.

featureVectorsTall = cellfun(@(x)HelperGetFeatureVectors(x,extractor),segmentsTall,"UniformOutput",false);
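
HelperGetFeatureVectors is a supporting function that is not listed here. A minimal sketch, assuming the input is the cell array of speech segments for one file and the output is a 1-by-NumSegments cell array of numFeatures-by-numFrames matrices:

function featureCell = HelperGetFeatureVectors(segments,extractor)
% Extract features from each speech segment and transpose the result so
% that features are along rows and time is along columns, as expected by
% sequenceInputLayer.
featureCell = cell(1,numel(segments));
for ii = 1:numel(segments)
    features = extract(extractor,segments{ii}); % numFrames-by-numFeatures
    featureCell{ii} = features.';               % numFeatures-by-numFrames
end
end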

Use gather to evaluate featureVectorsTall and segmentsPerFileTall. featureVectors is returned as a NumFiles-by-1 cell array, where each element of the cell array is a 1-by-NumSegmentsPerFile cell array. Unnest the cell array.

[featureVectors,segmentsPerFile] = gather(featureVectorsTall,segmentsPerFileTall);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 2 min 15 sec
Evaluation completed in 2 min 19 sec

featureVectors = cat(2,featureVectors{:});

Replicate the labels so there is one label per segment.

myLabels = adsTrain.Labels;
myLabels = repelem(myLabels,segmentsPerFile);

In classification applications, it is good practice to normalize all features to have zero mean and unit standard deviation.

Compute the mean and standard deviation for each coefficient, and use them to normalize the data.

allFeatures = cat(2,featureVectors{:});
allFeatures(isinf(allFeatures)) = nan;
M = mean(allFeatures,2,'omitnan');
S = std(allFeatures,0,2,'omitnan');
featureVectors = cellfun(@(x)(x-M)./S,featureVectors,'UniformOutput',false);
for ii = 1:numel(featureVectors)
    idx = find(isnan(featureVectors{ii}));
    if ~isempty(idx)
        featureVectors{ii}(idx) = 0;
    end
end

Use HelperFeatureVector2Sequence to buffer the feature vectors into sequences of 20 feature vectors with an overlap of 10 feature vectors.

featureVectorsPerSequence = 20;
featureVectorOverlap = 10;
[featuresTrain,sequencePerSegment] = HelperFeatureVector2Sequence(featureVectors,featureVectorsPerSequence,featureVectorOverlap);
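
HelperFeatureVector2Sequence is a supporting function that is not listed here. A minimal sketch of the buffering logic, assuming each feature matrix has time along columns and that segments with fewer feature vectors than one sequence length are skipped:

function [sequences,sequencePerSegment] = HelperFeatureVector2Sequence(features,featureVectorsPerSequence,featureVectorOverlap)
% Buffer each feature matrix into overlapping sequences of
% featureVectorsPerSequence feature vectors, advancing by the hop length.
hopLength = featureVectorsPerSequence - featureVectorOverlap;
sequences = {};
sequencePerSegment = cell(numel(features),1);
for ii = 1:numel(features)
    numVectors = size(features{ii},2);
    count = 0;
    for startIdx = 1:hopLength:(numVectors - featureVectorsPerSequence + 1)
        sequences{end+1,1} = features{ii}(:,startIdx:startIdx+featureVectorsPerSequence-1); %#ok<AGROW>
        count = count + 1;
    end
    sequencePerSegment{ii} = count;
end
end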

Create a cell array, genderTrain, for the expected gender associated with each training sequence.

genderTrain = repelem(myLabels,[sequencePerSegment{:}]);

Create Validation Dataset

You will create a validation dataset using the same approach you used for the training dataset. Use the readtable function to read the metadata associated with validation files.

metadata = readtable(fullfile(datafolder,"test.tsv"),"FileType","text");

Locate the validation files in the datastore.

csvFiles = metadata.path;
adsFiles = ads0.Files;
adsFiles = cellfun(@HelperGetFilePart,adsFiles,'UniformOutput',false);
[~,indA,indB] = intersect(adsFiles,csvFiles);

Create a validation datastore from the large datastore.

adsVal = subset(ads0,indA);

Similar to the training set, you will use data corresponding to adult speakers only. Read the gender and age variables from the metadata.

gender = metadata.gender;
gender = gender(indB);
age = metadata.age;
age = age(indB);

Assign gender to the Labels property of the datastore.

adsVal.Labels = gender;

Not all files in the dataset are annotated with gender and age information. Create a subset of the datastore that only contains files where gender information is available and age is greater than 19.

maleOrfemale =  categorical(adsVal.Labels) == "female" | categorical(adsVal.Labels) == "male";
isAdult = categorical(age) ~= "" & categorical(age) ~= "teens";
adsVal = subset(adsVal,maleOrfemale & isAdult);

Use countEachLabel to inspect the gender breakdown of the files.

countEachLabel(adsVal)
ans =

  2×2 table

    Label     Count
    ______    _____

    female      83 
    male       532 

Remove silence and extract features from the validation data.

T = tall(adsVal);
segments = cellfun(@(x)HelperSegmentSpeech(x,Fs),T,"UniformOutput",false);
segmentsPerFileTall = cellfun(@numel,segments);

featureVectorsTall = cellfun(@(x)HelperGetFeatureVectors(x,extractor),segments,"UniformOutput",false);

[featureVectors,valSegmentsPerFile] = gather(featureVectorsTall,segmentsPerFileTall);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 45 sec
Evaluation completed in 46 sec

featureVectors = cat(2,featureVectors{:});

valSegmentLabels = repelem(adsVal.Labels,valSegmentsPerFile);

Normalize the feature sequence by the mean and standard deviations computed during the training stage.

featureVectors = cellfun(@(x)(x-M)./S,featureVectors,'UniformOutput',false);
for ii = 1:numel(featureVectors)
    idx = find(isnan(featureVectors{ii}));
    if ~isempty(idx)
        featureVectors{ii}(idx) = 0;
    end
end

Create a cell array containing the sequence predictors.

[featuresValidation,valSequencePerSegment] = HelperFeatureVector2Sequence(featureVectors,featureVectorsPerSequence,featureVectorOverlap);

Create a cell array, genderValidation, containing the expected gender associated with each validation sequence.

genderValidation = repelem(valSegmentLabels,[valSequencePerSegment{:}]);

Define the LSTM Network Architecture

LSTM networks can learn long-term dependencies between time steps of sequence data. This example uses the bidirectional LSTM layer bilstmLayer to look at the sequence in both forward and backward directions.

Specify the input size to be sequences of size NumFeatures (the number of features in each feature vector). Specify a hidden bidirectional LSTM layer with an output size of 50 and output a sequence, followed by a dropout layer to reduce overfitting. Then, specify a bidirectional LSTM layer with an output size of 50 and output the last element of the sequence. This command instructs the bidirectional LSTM layer to map its input into 50 features and then prepares the output for the fully connected layer. Finally, specify two classes by including a fully connected layer of size 2, followed by a softmax layer and a classification layer.

layers = [ ...
    sequenceInputLayer(size(featuresTrain{1},1))
    bilstmLayer(50,"OutputMode","sequence")
    dropoutLayer(0.1)
    bilstmLayer(50,"OutputMode","last")
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer];

Next, specify the training options for the classifier. Set MaxEpochs to 4 so that the network makes 4 passes through the training data. Set MiniBatchSize to 128 so that the network looks at 128 training signals at a time. Specify Plots as "training-progress" to generate plots that show the training progress as the number of iterations increases. Set Verbose to false to disable printing the table output that corresponds to the data shown in the plot. Specify Shuffle as "every-epoch" to shuffle the training sequences at the beginning of each epoch. Specify LearnRateSchedule as "piecewise" to decrease the learning rate by a specified factor (0.1) every time a certain number of epochs (2) has passed. Specify ValidationData and ValidationFrequency so that the validation data is classified approximately once per epoch during training.

This example uses the adaptive moment estimation (ADAM) solver. ADAM performs better with recurrent neural networks (RNNs) like LSTMs than the default stochastic gradient descent with momentum (SGDM) solver.

miniBatchSize = 128;
validationFrequency = floor(numel(genderTrain)/miniBatchSize);
options = trainingOptions("adam", ...
    "MaxEpochs",4, ...
    "MiniBatchSize",miniBatchSize, ...
    "Plots","training-progress", ...
    "Verbose",false, ...
    "Shuffle","every-epoch", ...
    "LearnRateSchedule","piecewise", ...
    "LearnRateDropFactor",0.1, ...
    "LearnRateDropPeriod",2,...
    'ValidationData',{featuresValidation,categorical(genderValidation)}, ...
    'ValidationFrequency',validationFrequency);

Train the LSTM Network

Train the LSTM network with the specified training options and layer architecture using trainNetwork. Because the training set is large, the training process can take several minutes.

net = trainNetwork(featuresTrain,categorical(genderTrain),layers,options);

The top subplot of the training-progress plot represents the training accuracy, which is the classification accuracy on each mini-batch. When training progresses successfully, this value typically increases towards 100%. The bottom subplot displays the training loss, which is the cross-entropy loss on each mini-batch. When training progresses successfully, this value typically decreases towards zero.

If the training is not converging, the plots might oscillate between values without trending in a certain upward or downward direction. This oscillation means that the training accuracy is not improving and the training loss is not decreasing. This situation can occur at the start of training, or after some preliminary improvement in training accuracy. In many cases, changing the training options can help the network achieve convergence. Decreasing MiniBatchSize or decreasing InitialLearnRate might result in a longer training time, but it can help the network learn better.
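
For example, a minimal sketch of adjusted training options (the specific values below are illustrative starting points, not values prescribed by this example):

% Retry training with a smaller mini-batch size and a lower initial
% learning rate if the loss does not decrease.
optionsRetry = trainingOptions("adam", ...
    "MaxEpochs",4, ...
    "MiniBatchSize",64, ...
    "InitialLearnRate",1e-4, ...
    "Plots","training-progress", ...
    "Verbose",false, ...
    "Shuffle","every-epoch");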

Visualize the Training Accuracy

Calculate the training accuracy, which represents the accuracy of the classifier on the signals on which it was trained. First, classify the training data.

trainPred = classify(net,featuresTrain);

Plot the confusion matrix. Display the precision and recall for the two classes by using column and row summaries.

figure
cm = confusionchart(categorical(genderTrain),trainPred,'title','Training Accuracy');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Visualize the Validation Accuracy

Calculate the validation accuracy. First, classify the validation data.

[valPred,valScores] = classify(net,featuresValidation);

Plot the confusion matrix. Display the precision and recall for the two classes by using column and row summaries.

figure
cm = confusionchart(categorical(genderValidation),valPred,'title','Validation Set Accuracy');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

The example generated multiple sequences from each speech file. Higher accuracy can be achieved by considering the output scores of all sequences corresponding to the same file and applying a "max-rule" decision, where each file is assigned the class with the maximum mean confidence score over all sequences generated from that file.

Determine the number of sequences generated per file in the validation set.

sequencePerFile = zeros(size(valSegmentsPerFile));
valSequencePerSegmentMat = cell2mat(valSequencePerSegment);
idx = 1;
for ii = 1:numel(valSegmentsPerFile)
    sequencePerFile(ii) = sum(valSequencePerSegmentMat(idx:idx+valSegmentsPerFile(ii)-1));
    idx = idx + valSegmentsPerFile(ii);
end

Predict the gender for each validation file by considering the output scores of all sequences generated from the same file.

numFiles = numel(adsVal.Files);
actualGender = categorical(adsVal.Labels);
predictedGender = actualGender;
scores = cell(1,numFiles);
counter = 1;
cats = unique(actualGender);
for index = 1:numFiles
    scores{index}      = valScores(counter: counter + sequencePerFile(index) - 1,:);
    m = max(mean(scores{index},1),[],1);
    if m(1) >= m(2)
        predictedGender(index) = cats(1);
    else
        predictedGender(index) = cats(2);
    end
    counter = counter + sequencePerFile(index);
end

Visualize the confusion matrix on the max-rule predictions.

figure
cm = confusionchart(actualGender,predictedGender,'title','Validation Set Accuracy - Max Rule');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

References

[1] Mozilla Common Voice dataset. https://voice.mozilla.org/

[2] Giannakopoulos, Theodoros, and Aggelos Pikrakis. Introduction to Audio Analysis: A MATLAB Approach. Academic Press, 2014.
